trainers.DifftreParallel

class DifftreParallel(key, init_params, optimizer, energy_fn_template, simulator_template, neighbor_fn, timings, state_kwargs, quantities, targets, observables, initial_trajstates=None, reference_states=None, num_runs_init=1, reweight_ratio=0.9, allowed_reduction=0.95, step_size_scale=0.0001, interior_points=100, sim_batch_size=1, full_checkpoint=False, target_loss_fns=None, loss_fn=None, vmap_batch=10, bucket_recompute=True, resample_simstates=False, convergence_criterion='window_median', checkpoint_path='Checkpoints', log_dir=None)

Trainer class for parametrizing potentials via the DiffTRe method.

This trainer performs simulations and parameter updates for multiple statepoints in parallel using vmap.
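As a rough illustration of what vmap-based parallelism over statepoints means, the sketch below maps a toy function over several temperatures at once. The function boltzmann_mean_energy, the energy levels, and the kT values are hypothetical and not part of chemtrain; the trainer's own vectorization over simulations and updates is more involved.

```python
import jax
import jax.numpy as jnp


def boltzmann_mean_energy(kT, energies):
    """Boltzmann-weighted mean of a fixed set of energies at temperature kT."""
    # softmax(-E / kT) yields normalized Boltzmann weights.
    weights = jax.nn.softmax(-energies / kT)
    return jnp.sum(weights * energies)


# Hypothetical toy system: three energy levels shared by all statepoints.
energies = jnp.array([0.0, 1.0, 2.0])

# Hypothetical statepoints that differ only in their temperature kT.
kTs = jnp.array([0.5, 1.0, 2.0])

# Map over the leading axis of kTs, broadcasting the shared energies.
per_statepoint = jax.vmap(boltzmann_mean_energy, in_axes=(0, None))(kTs, energies)
```

The result holds one value per statepoint, computed in a single vectorized call instead of a Python loop.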

Parameters:
  • init_params (Any) – Initial energy parameters

  • optimizer (GradientTransformationExtraArgs) – Optimizer from optax

  • energy_fn_template (EnergyFnTemplate) – Function that takes energy parameters and initializes a new energy function.

  • simulator_template (Callable) – Function that takes an energy function and returns a simulator function.

  • neighbor_fn (Callable[[Array, Optional[NeighborList, None], Optional[int, None]], NeighborList]) – Neighbor function. Must be created with jax_md_mod.custom_partition.masked_neighbor_list() if the statepoints have different numbers of atoms.

  • timings (TimingClass) – Instance of TimingClass containing information about the trajectory length and which states to retain

  • state_kwargs (Dict[str, Union[Array, ndarray, bool, number, bool, int, float, complex]]) – Properties defining the thermodynamic state. Must at least contain the temperature ‘kT’. For a non-exhaustive list, see chemtrain.ensemble.templates.StatePoint.

  • quantities (Dict[str, Dict]) – Dict that maps each observable, identified by its key, to a function computing it for each snapshot via ensemble.sampling.quantity_traj().

  • targets (Dict[str, Any]) – Dict with the same keys as quantities. Each value is another dict providing ‘gamma’ and ‘target’ for that observable.

  • observables (Dict[str, TrajFn]) – Optional dictionary providing the observable functions for the targets.

  • initial_trajstates – Initial trajectory states of the statepoints. It is usually simpler to let the trainer generate the initial trajectory states by providing reference_states.

  • reference_states – Initial simulator states from which DiffTRe can compute the initial trajectory states.

  • num_runs_init (int) – Number of runs to perform for the initial trajectory states. This number can be increased to ensure a better equilibration when starting from less favourable initial states.

  • reweight_ratio (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Fraction of the reference samples that the effective sample size n_eff must exceed for the previous reference trajectory state to be re-used. To disable re-use of trajectories, specify a value > 1.

  • allowed_reduction (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Allowed reduction of the effective sample size through a parameter update.

  • step_size_scale (float) – Initial step size scale for the step size adaption.

  • interior_points (int) – Number of interior points to use for the step size adaption.

  • sim_batch_size (int) – Number of state-points to be processed as a single batch. Gradients will be averaged over the batch before stepping the optimizer.

  • full_checkpoint (bool) – If True, the whole trainer state is saved, otherwise only important parameters are stored as a dictionary.

  • target_loss_fns (Dict[str, Callable]) – Dictionary of loss functions to use for each target.

  • loss_fn – Custom loss function to use for the training.

  • vmap_batch (int) – Number of samples to process simultaneously when computing instantaneous quantities for a trajectory.

  • bucket_recompute (bool) – Groups together statepoints that need a recomputation.

  • resample_simstates (bool) – Resample the sim states from all trajectories instead of simulating independent chains.

  • convergence_criterion (str) – Either ‘max_loss’ or ‘ave_loss’. If ‘max_loss’, stops if the maximum loss across all batches in the epoch is smaller than convergence_thresh. ‘ave_loss’ evaluates the average loss across the batch. For a single state point, both are equivalent. A criterion based on the rolling standard deviation ‘std’ might be implemented in the future.

  • checkpoint_path (PathLike) – Name of the folder in which to store checkpoints.

  • log_dir (PathLike) – Path to the log file in which training progress is stored.
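The role of reweight_ratio can be sketched with the entropy-based effective sample size estimate commonly used with DiffTRe, n_eff = exp(-Σ w_i ln w_i), where the w_i are normalized reweighting factors. This is a minimal sketch under that assumption. The helper names effective_sample_size and can_reuse_trajectory and the per-sample energy differences delta_u are hypothetical, and the trainer's internal check may differ in detail.

```python
import math


def effective_sample_size(delta_u, kT):
    """Entropy-based effective sample size of reweighted samples.

    delta_u[i] is the (hypothetical) difference between the energy of
    sample i under the updated and under the reference potential.
    """
    log_w = [-du / kT for du in delta_u]
    shift = max(log_w)  # subtract the maximum for numerical stability
    unnormalized = [math.exp(lw - shift) for lw in log_w]
    total = sum(unnormalized)
    weights = [u / total for u in unnormalized]
    entropy = -sum(w * math.log(w) for w in weights if w > 0.0)
    return math.exp(entropy)


def can_reuse_trajectory(delta_u, kT, reweight_ratio=0.9):
    """Re-use the reference trajectory only if n_eff exceeds the given
    fraction of the number of reference samples."""
    n_eff = effective_sample_size(delta_u, kT)
    return n_eff > reweight_ratio * len(delta_u)
```

Because n_eff never exceeds the number of samples in this estimate, a reweight_ratio above 1 makes the check fail for every update, which matches the documented way of disabling trajectory re-use.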

Variables:
  • batch_losses – List of losses for each batch in each epoch.

  • epoch_losses – List of losses for each epoch.

  • step_size_history – List of step sizes for each batched update.

  • gradient_norm_history – List of gradient norms for each batched update.

  • batch_gradient_norms – List of gradient norms for each batch.

  • predictions – Dictionary containing the predictions for each statepoint at each epoch.

  • early_stop – Instance of EarlyStopping to check for convergence.

Methods

__init__(key, init_params, optimizer, ...[, ...])

A reference energy_fn_template can be provided, but is not mandatory due to the dependence of the template on the box via the displacement function.

add_task(trigger, fn_or_method)

Adds a task to perform regularly during training.

checkpoint(name, object)

Marks attribute to be saved in a partial checkpoint.

load_energy_params(file_path)

Loads energy parameters.

move_to_device()

Transforms the trainer states to JAX arrays.

predict(batch)

Predict for a batch of statepoints.

print_training_tasks()

Prints the tasks performed by the trainer.

restore(checkpoint)

Restores the trainer from a checkpoint.

save_energy_params(file_path[, save_format, ...])

Saves energy parameters.

save_trainer(save_path[, format])

Saves whole trainer, e.g. for production after training.

train(max_epochs[, thresh, checkpoint_freq])

Trains for a maximum number of epochs, checkpoints after a specified number of epochs and ends training if a convergence criterion is met.
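The ‘max_loss’ and ‘ave_loss’ convergence criteria described for convergence_criterion can be sketched as follows. The function is_converged is a hypothetical helper written for illustration, not the trainer's implementation, and it covers only the two criteria described in the parameter list.

```python
def is_converged(batch_losses, thresh, criterion="max_loss"):
    """Check one epoch's batch losses against a convergence threshold."""
    if criterion == "max_loss":
        # Every batch in the epoch must fall below the threshold.
        value = max(batch_losses)
    elif criterion == "ave_loss":
        # Only the mean over the batches must fall below the threshold.
        value = sum(batch_losses) / len(batch_losses)
    else:
        raise ValueError(f"Unsupported criterion in this sketch: {criterion}")
    return value < thresh


# Hypothetical losses for an epoch with two batches.
epoch_batch_losses = [0.1, 0.3]
strict = is_converged(epoch_batch_losses, thresh=0.25, criterion="max_loss")
lenient = is_converged(epoch_batch_losses, thresh=0.25, criterion="ave_loss")
```

Here the maximum-based check is the stricter of the two, and both give the same answer when the epoch contains a single batch.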

Attributes

best_params

Returns the best parameters according to the early stopping criterion.

energy_fn

Returns the energy function for the current parameters.

n_statepoints

params

Current energy parameters.