trainers.Difftre
- class Difftre(init_params, optimizer, reweight_ratio=1.0, adaptive_step_size_threshold=0.0001, sim_batch_size=1, energy_fn_template=None, full_checkpoint=False, convergence_criterion='window_median', checkpoint_path='Checkpoints', log_dir=None)
Trainer class for parametrizing potentials via the DiffTRe method.
The Differentiable Trajectory Reweighting (DiffTRe) method [1] computes gradients of ensemble averages without differentiating through the simulation. Instead, it reweights samples from a reference trajectory to the current parameters. The method can therefore train potential models on macroscopic observables efficiently.
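The reweighting idea can be illustrated on a toy problem. The following is a minimal JAX sketch, assuming a hypothetical one-dimensional harmonic potential with stiffness k and a reference trajectory sampled exactly at k_ref; the names energy, reweighting_weights, effective_sample_size and loss are illustrative only and not part of the trainers API. Samples generated once at k_ref are reweighted with w_i ∝ exp(-β[U_k(x_i) - U_ref(x_i)]), so jax.grad only differentiates the reweighted average, never the sampling itself.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import xlogy

KBT = 1.0  # thermal energy k_B T in toy units (assumption)

def energy(k, x):
    """Hypothetical toy potential: 1D harmonic well with stiffness k."""
    return 0.5 * k * x**2

def reweighting_weights(k, k_ref, samples):
    """Normalized weights w_i proportional to exp(-beta [U_k(x_i) - U_ref(x_i)])."""
    log_w = -(energy(k, samples) - energy(k_ref, samples)) / KBT
    return jax.nn.softmax(log_w)  # softmax normalizes the weights to sum to 1

def effective_sample_size(weights):
    """n_eff = exp(-sum_i w_i ln w_i); equals N for uniform weights."""
    return jnp.exp(-jnp.sum(xlogy(weights, weights)))

def loss(k, k_ref, samples, target):
    # Reweighted ensemble average <x^2>; the samples do not depend on k
    weights = reweighting_weights(k, k_ref, samples)
    x2_avg = jnp.sum(weights * samples**2)
    return (x2_avg - target) ** 2

# Reference "trajectory": exact Boltzmann samples of the well at k_ref
k_ref = 1.0
key = jax.random.PRNGKey(0)
samples = jax.random.normal(key, (10_000,)) * jnp.sqrt(KBT / k_ref)

# Gradient w.r.t. k of the loss against a target <x^2> of 0.5
grad_k = jax.grad(loss)(k_ref, k_ref, samples, 0.5)
# Effective sample size after moving the parameter away from k_ref
n_eff = effective_sample_size(reweighting_weights(1.2, k_ref, samples))
```

Because the samples never depend on k, the gradient flows only through the weights. As k moves away from k_ref, n_eff drops below N, which is the quantity a setting such as reweight_ratio monitors.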
The trainer initialization only sets the initial trainer state as well as the checkpointing and saving functionality. For training, target state points with their respective simulations must be added via Difftre.add_statepoint().
- Parameters:
init_params (Any) – Initial energy parameters.
optimizer (GradientTransformationExtraArgs) – Optimizer from optax.
reweight_ratio (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Fraction of the reference samples that the effective sample size n_eff must surpass for the previous reference trajectory state to be re-used. To never re-use trajectories, specify a value > 1.
sim_batch_size (int) – Number of state points to process as a single batch. Gradients are averaged over the batch before the optimizer steps.
energy_fn_template (EnergyFnTemplate) – Function that takes energy parameters and initializes a new energy function. Here, the energy_fn_template is only a reference that is saved alongside the trainer. Each state point requires its own template because the energy function depends on the box size via the displacement function, which can vary between state points.
convergence_criterion (str) – Criterion used to detect convergence; the default is ‘window_median’. With ‘max_loss’, training stops once the maximum loss across all batches in the epoch is smaller than convergence_thresh. ‘ave_loss’ instead evaluates the average loss across the batch. For a single state point, both are equivalent. A criterion based on the rolling standard deviation, ‘std’, might be implemented in the future.
checkpoint_path – Name of the folder to store checkpoints in.
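The reweight_ratio and convergence_criterion descriptions above can be read as the following simple checks. This is a hypothetical sketch: the helper names, the strict ">" comparison and the way the threshold is passed in are assumptions, not the trainer's actual code.

```python
def can_reuse_reference(n_eff, num_samples, reweight_ratio=1.0):
    """Hypothetical helper: re-use the previous reference trajectory only
    while n_eff surpasses reweight_ratio * N. Since n_eff <= N, a ratio
    above 1 can never be met, so trajectories are always regenerated."""
    return n_eff > reweight_ratio * num_samples

def converged(batch_losses, thresh, criterion="max_loss"):
    """Sketch of the 'max_loss' and 'ave_loss' criteria described above."""
    if criterion == "max_loss":
        # Every batch loss of the epoch must be below the threshold
        return max(batch_losses) < thresh
    if criterion == "ave_loss":
        # Only the mean loss over the batches must be below the threshold
        return sum(batch_losses) / len(batch_losses) < thresh
    raise ValueError(f"Unknown criterion: {criterion}")
```

For a single state point the list holds one loss, so both branches return the same result, matching the note in the parameter description.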
- Variables:
weight_fn – Dictionary containing the reweighting functions for each statepoint.
batch_losses – List of losses for each batch in each epoch.
epoch_losses – List of losses for each epoch.
step_size_history – List of step sizes for each batched update.
gradient_norm_history – List of gradient norms for each batched update.
predictions – Dictionary containing the predictions for each statepoint at each epoch.
early_stop – Instance of EarlyStopping to check for convergence.
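The histories listed above are filled once per batched update. As a sketch of what one such update might record, assuming hypothetical Lennard-Jones-style parameter names eps and sigma and a plain gradient-descent step standing in for the optax optimizer, gradients from a batch of state points (cf. sim_batch_size) are averaged before a single step is taken:

```python
import jax
import jax.numpy as jnp

def batched_update(params, per_statepoint_grads, step_size):
    """Hypothetical sketch: average gradients over a batch of state points,
    compute their global norm, and take one plain gradient-descent step."""
    n = len(per_statepoint_grads)
    # Leaf-wise mean over the batch of gradient pytrees
    avg_grad = jax.tree_util.tree_map(lambda *g: sum(g) / n, *per_statepoint_grads)
    # Global L2 norm over all leaves of the averaged gradient
    grad_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(avg_grad)))
    new_params = jax.tree_util.tree_map(lambda p, g: p - step_size * g, params, avg_grad)
    return new_params, avg_grad, grad_norm

gradient_norm_history, step_size_history = [], []
params = {"eps": jnp.array(1.0), "sigma": jnp.array([0.3, 0.4])}
grads = [  # one gradient pytree per state point in the batch
    {"eps": jnp.array(2.0), "sigma": jnp.array([0.0, 4.0])},
    {"eps": jnp.array(4.0), "sigma": jnp.array([0.0, 0.0])},
]
params, avg_grad, grad_norm = batched_update(params, grads, step_size=0.1)
gradient_norm_history.append(float(grad_norm))
step_size_history.append(0.1)
```

Averaging before stepping means the optimizer sees one update per batch, so step_size_history and gradient_norm_history grow by one entry per batch rather than per state point.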
Examples
    trainer = trainers.Difftre(init_params, optimizer)

    # Add all statepoints
    trainer.add_statepoint(energy_fn_template, simulator_template,
                           neighbor_fn, timings, statepoint_dict,
                           compute_fns, reference_state, targets)
    ...

    # Optionally initialize the step size adaption
    trainer.init_step_size_adaption(allowed_reduction=0.5)

    trainer.train(max_epochs)
References
Methods
__init__(init_params, optimizer[, ...]) – A reference energy_fn_template can be provided, but it is not mandatory because the template depends on the box via the displacement function.
add_statepoint(energy_fn_template, ...[, ...]) – Adds a state point to the pool of simulations with respective targets.
add_task(trigger, fn_or_method) – Adds a task to perform regularly during training.
checkpoint(name, object) – Marks an attribute to be saved in a partial checkpoint.
get_sim_state(key) – Gets the simulator state of a state point.
init_step_size_adaption([allowed_reduction, ...]) – Initializes a line search to tune the step size in each iteration.
load_energy_params(file_path) – Loads energy parameters.
Transforms the trainer states to JAX arrays.
predict(*, key) – Gets predictions for a specific state point.
Prints the tasks performed by the trainer.
restore(checkpoint) – Restores the trainer from a checkpoint.
save_energy_params(file_path[, save_format, ...]) – Saves energy parameters.
save_trainer(save_path[, format]) – Saves the whole trainer, e.g. for production after training.
train(max_epochs[, thresh, checkpoint_freq]) – Trains for a maximum number of epochs, checkpoints after a specified number of epochs, and ends training if a convergence criterion is met.
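init_step_size_adaption() sets up a line search for the step size. Its internals are not documented here, so the following is only a generic backtracking sketch, not the trainer's algorithm: loss_fn, shrink and min_scale are hypothetical names, and min_scale merely mimics the idea of a lower bound such as allowed_reduction=0.5. In a DiffTRe setting, loss_fn could be a reweighted loss, which is cheap to evaluate because it needs no new simulation.

```python
def backtracking_step_size(loss_fn, params, update, min_scale=0.5, shrink=0.5):
    """Generic backtracking sketch: shrink the proposed step while it
    increases the loss, but never scale it below min_scale."""
    base_loss = loss_fn(params)
    scale = 1.0
    # Stop shrinking once the next reduction would fall below min_scale
    while scale * shrink >= min_scale and loss_fn(params + scale * update) > base_loss:
        scale *= shrink
    return scale

# Usage on a toy quadratic loss with its minimum at p = 1
quadratic = lambda p: (p - 1.0) ** 2
scale = backtracking_step_size(quadratic, 0.0, 3.0)
```

Capping the reduction keeps a single poor proposal from collapsing the step size, at the cost of occasionally accepting a step that still increases the loss.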
Attributes
- best_params
Returns the best parameters according to the early stopping criterion.
- energy_fn
Returns the energy function for the current parameters.
- params
Current energy parameters.
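best_params is selected by the early stopping criterion. A minimal sketch of the bookkeeping involved could look as follows, assuming the simplest rule of keeping the parameters with the lowest epoch loss; the class BestParamsTracker is hypothetical, and the trainer's default ‘window_median’ criterion may select differently.

```python
class BestParamsTracker:
    """Hypothetical stand-in for early-stopping bookkeeping: keeps the
    parameters that achieved the lowest epoch loss seen so far."""

    def __init__(self):
        self.best_loss = float("inf")
        self.best_params = None

    def update(self, epoch_loss, params):
        # Replace the stored parameters only on a strict improvement
        if epoch_loss < self.best_loss:
            self.best_loss = epoch_loss
            self.best_params = params
        return self.best_params

tracker = BestParamsTracker()
for epoch_loss, params in [(0.5, "epoch_0"), (0.3, "epoch_1"), (0.4, "epoch_2")]:
    tracker.update(epoch_loss, params)
```

Separating best_params from params lets training continue past a good epoch while still returning the best parameters found.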