trainers.DifftreParallel
- class DifftreParallel(key, init_params, optimizer, energy_fn_template, simulator_template, neighbor_fn, timings, state_kwargs, quantities, targets, observables, initial_trajstates=None, reference_states=None, num_runs_init=1, reweight_ratio=0.9, allowed_reduction=0.95, step_size_scale=0.0001, interior_points=100, sim_batch_size=1, full_checkpoint=False, target_loss_fns=None, loss_fn=None, vmap_batch=10, bucket_recompute=True, resample_simstates=False, convergence_criterion='window_median', checkpoint_path='Checkpoints', log_dir=None)
Trainer class for parametrizing potentials via the DiffTRe method.
This trainer performs simulations and parameter updates for multiple statepoints in parallel using vmap.
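DiffTRe avoids backpropagating through the simulation by reweighting snapshots of a reference trajectory generated with earlier parameters. The pure-Python sketch below shows the reweighting formulas of the DiffTRe method, assuming canonical sampling at temperature kT: normalized weights w_i ∝ exp(-[U_θ(x_i) - U_ref(x_i)] / kT), the reweighted average ⟨O⟩ = Σ_i w_i O(x_i), and the effective sample size n_eff = exp(-Σ_i w_i ln w_i). All function names, including the reuse test modelled on the reweight_ratio description, are hypothetical and not part of chemtrain.

```python
import math
from typing import List, Sequence


def difftre_weights(u_new: Sequence[float], u_ref: Sequence[float], kT: float) -> List[float]:
    """Normalized reweighting weights w_i proportional to exp(-(U_new - U_ref) / kT)."""
    exponents = [-(un - ur) / kT for un, ur in zip(u_new, u_ref)]
    # Subtract the largest exponent before exponentiating for numerical stability.
    max_exp = max(exponents)
    unnormalized = [math.exp(e - max_exp) for e in exponents]
    total = sum(unnormalized)
    return [w / total for w in unnormalized]


def effective_sample_size(weights: Sequence[float]) -> float:
    """Entropy-based effective sample size n_eff = exp(-sum_i w_i ln w_i)."""
    entropy = -sum(w * math.log(w) for w in weights if w > 0.0)
    return math.exp(entropy)


def reweighted_average(weights: Sequence[float], values: Sequence[float]) -> float:
    """Reweighted ensemble average <O> = sum_i w_i O(x_i)."""
    return sum(w * o for w, o in zip(weights, values))


def can_reuse_reference(weights: Sequence[float], reweight_ratio: float) -> bool:
    """Hypothetical reuse test: n_eff must exceed reweight_ratio * N."""
    return effective_sample_size(weights) > reweight_ratio * len(weights)


# Identical potentials give uniform weights, so n_eff equals the number of samples.
u_ref = [1.0, 2.0, 3.0, 4.0]
w = difftre_weights(u_ref, u_ref, kT=2.5)
print(effective_sample_size(w))      # ≈ 4.0
print(can_reuse_reference(w, 0.9))   # True
```

Because n_eff never exceeds the number of samples N, any reweight_ratio > 1 makes this reuse test fail, which matches the documented way of disabling trajectory re-use.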
- Parameters:
init_params (Any) – Initial energy parameters.
optimizer (GradientTransformationExtraArgs) – Optimizer from optax.
energy_fn_template (EnergyFnTemplate) – Function that takes energy parameters and initializes a new energy function.
simulator_template (Callable) – Function that takes an energy function and returns a simulator function.
neighbor_fn (Callable[[Array, Optional[NeighborList], Optional[int]], NeighborList]) – Neighbor function. Must be obtained from jax_md_mod.custom_partition.masked_neighbor_list() if the statepoints have different numbers of atoms.
timings (TimingClass) – Instance of TimingClass containing information about the trajectory length and which states to retain.
state_kwargs (Dict[str, Union[Array, ndarray, bool, number, bool, int, float, complex]]) – Properties defining the thermodynamic state. Must contain at least the temperature 'kT'. For a non-exhaustive list, see chemtrain.ensemble.templates.StatePoint.
quantities (Dict[str, Dict]) – Dict containing, for each observable specified by its key, a corresponding function to compute it for each snapshot using ensemble.sampling.quantity_traj().
targets (Dict[str, Any]) – Dict with the same keys as quantities, containing for each observable another dict that provides 'gamma' and 'target'.
observables (Dict[str, TrajFn]) – Optional dictionary providing the observable functions for the targets.
initial_trajstates – Initial trajectory states of the statepoints. It is usually simpler to let the trainer generate the initial trajectory states by providing reference_states.
reference_states – Initial simulator states from which DiffTRe can compute the initial trajectory states.
num_runs_init (int) – Number of runs to perform for the initial trajectory states. Increasing this number improves equilibration when starting from less favourable initial states.
reweight_ratio (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Fraction of the reference samples that the effective sample size n_eff must exceed to allow re-use of the previous reference trajectory state. To disable re-use of trajectories, specify a value > 1.
allowed_reduction (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Allowed reduction of the effective sample size through a parameter update.
step_size_scale (float) – Initial step size scale for the step size adaptation.
interior_points (int) – Number of interior points to use for the step size adaptation.
sim_batch_size (int) – Number of statepoints to process as a single batch. Gradients are averaged over the batch before stepping the optimizer.
full_checkpoint (bool) – If True, the whole trainer state is saved; otherwise, only important parameters are stored as a dictionary.
target_loss_fns (Dict[str, Callable]) – Dictionary of loss functions to use for each target.
loss_fn – Custom loss function to use for training.
vmap_batch (int) – Number of samples to process simultaneously when computing instantaneous quantities for a trajectory.
bucket_recompute (bool) – If True, groups together statepoints that require recomputation.
resample_simstates (bool) – If True, resamples the simulator states from all trajectories instead of simulating independent chains.
convergence_criterion (str) – Convergence criterion for early stopping (default 'window_median'). If 'max_loss', training stops if the maximum loss across all batches in the epoch is smaller than the convergence threshold. 'ave_loss' instead evaluates the average loss across the batches. For a single statepoint, both are equivalent. A criterion based on the rolling standard deviation, 'std', might be implemented in the future.
checkpoint_path (PathLike) – Name of the folder in which checkpoints are stored.
log_dir (PathLike) – Path to the log file in which training progress is stored.
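The targets argument pairs each key of quantities with a weight 'gamma' and a reference value 'target'. The sketch below shows one way such a dictionary can be reduced to a scalar loss, assuming a gamma-weighted mean squared error per observable; chemtrain's default loss may differ, and target_loss_fns or loss_fn can replace it. The observable names 'density' and 'rdf' and the function weighted_mse_loss are hypothetical.

```python
from typing import Dict, List, Sequence, Union

Value = Union[float, Sequence[float]]


def _as_list(x: Value) -> List[float]:
    # Treat scalars as length-1 vectors so scalar and vector targets share one code path.
    return [float(x)] if isinstance(x, (int, float)) else [float(v) for v in x]


def weighted_mse_loss(predictions: Dict[str, Value],
                      targets: Dict[str, Dict[str, Value]]) -> float:
    """Sum over observables of gamma * mean((prediction - target) ** 2)."""
    loss = 0.0
    for key, spec in targets.items():
        pred = _as_list(predictions[key])
        ref = _as_list(spec['target'])
        if len(pred) != len(ref):
            raise ValueError(f"shape mismatch for observable {key!r}")
        mse = sum((p - r) ** 2 for p, r in zip(pred, ref)) / len(ref)
        loss += float(spec['gamma']) * mse
    return loss


# Hypothetical targets for one statepoint: a scalar density and a short RDF.
targets = {
    'density': {'gamma': 1.0, 'target': 1.0},
    'rdf': {'gamma': 10.0, 'target': [0.0, 1.5, 1.0]},
}
predictions = {'density': 0.9, 'rdf': [0.0, 1.4, 1.0]}
print(weighted_mse_loss(predictions, targets))
```

A larger 'gamma' makes the optimizer prioritize that observable relative to the others.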
- Variables:
batch_losses – List of losses for each batch in each epoch.
epoch_losses – List of losses for each epoch.
step_size_history – List of step sizes for each batched update.
gradient_norm_history – List of gradient norms for each batched update.
batch_gradient_norms – List of gradient norms for each batch.
predictions – Dictionary containing the predictions for each statepoint at each epoch.
early_stop – Instance of EarlyStopping to check for convergence.
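With sim_batch_size > 1, gradients are averaged over the statepoints of a batch before the optimizer step, and one gradient norm is recorded per batched update. The sketch below illustrates that bookkeeping, assuming parameters are represented as flat lists of floats and that the recorded norm is the global L2 norm; the helper names are hypothetical, and chemtrain itself operates on JAX pytrees.

```python
import math
from typing import List, Sequence


def make_batches(indices: Sequence[int], batch_size: int) -> List[List[int]]:
    """Split statepoint indices into consecutive batches of at most batch_size."""
    return [list(indices[i:i + batch_size]) for i in range(0, len(indices), batch_size)]


def average_gradients(grads: Sequence[Sequence[float]]) -> List[float]:
    """Element-wise mean of per-statepoint gradients."""
    n = len(grads)
    return [sum(component) / n for component in zip(*grads)]


def global_norm(grad: Sequence[float]) -> float:
    """L2 norm of a flattened gradient."""
    return math.sqrt(sum(g * g for g in grad))


# Hypothetical gradients for three statepoints, processed with sim_batch_size=2.
per_statepoint_grads = {0: [1.0, 0.0], 1: [3.0, 4.0], 2: [0.0, -2.0]}
gradient_norm_history = []
for batch in make_batches(sorted(per_statepoint_grads), batch_size=2):
    avg = average_gradients([per_statepoint_grads[i] for i in batch])
    gradient_norm_history.append(global_norm(avg))
print(gradient_norm_history)  # first batch: sqrt(8), second batch: 2.0
```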
Methods
__init__(key, init_params, optimizer, ...[, ...]) – A reference energy_fn_template can be provided, but is not mandatory due to the dependence of the template on the box via the displacement function.
add_task(trigger, fn_or_method) – Adds a task to perform regularly during training.
checkpoint(name, object) – Marks an attribute to be saved in a partial checkpoint.
load_energy_params(file_path) – Loads energy parameters.
Transforms the trainer states to JAX arrays.
predict(batch) – Predicts for a batch of statepoints.
Prints the tasks performed by the trainer.
restore(checkpoint) – Restores the trainer from a checkpoint.
save_energy_params(file_path[, save_format, ...]) – Saves energy parameters.
save_trainer(save_path[, format]) – Saves the whole trainer, e.g. for production after training.
train(max_epochs[, thresh, checkpoint_freq]) – Trains for a maximum number of epochs, checkpoints after a specified number of epochs, and ends training if a convergence criterion is met.
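The convergence check in train() reduces the losses of an epoch and compares the result with the threshold. The sketch below covers only the two criteria described for convergence_criterion, 'max_loss' and 'ave_loss'; the function name has_converged is hypothetical, and the default 'window_median' criterion is not reproduced because its definition is not given in this reference.

```python
from statistics import mean
from typing import Sequence


def has_converged(batch_losses: Sequence[float], thresh: float,
                  criterion: str = 'max_loss') -> bool:
    """Reduce the batch losses of one epoch and compare the result with thresh."""
    if criterion == 'max_loss':
        epoch_value = max(batch_losses)   # every batch must be below thresh
    elif criterion == 'ave_loss':
        epoch_value = mean(batch_losses)  # only the average must be below thresh
    else:
        raise ValueError(f"criterion {criterion!r} is not covered by this sketch")
    return epoch_value < thresh


epoch = [0.02, 0.08, 0.03]
print(has_converged(epoch, thresh=0.05, criterion='max_loss'))  # False: max is 0.08
print(has_converged(epoch, thresh=0.05, criterion='ave_loss'))  # True: mean is about 0.043
```

With a single batch, max and mean coincide, which is why both criteria are equivalent for a single statepoint.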
Attributes
- best_params
Returns the best parameters according to the early stopping criterion.
- energy_fn
Returns the energy function for the current parameters.
- n_statepoints
- params
Current energy parameters.