trainers.RelativeEntropy
- class RelativeEntropy(init_params, optimizer, reweight_ratio=0.9, sim_batch_size=1, energy_fn_template=None, convergence_criterion='window_median', checkpoint_path='Checkpoints', full_checkpoint=False)
Trainer for relative entropy minimization.
The relative entropy minimization procedure coarse-grains potential models by minimizing the relative entropy between the canonical distributions of the atomistic reference and the coarse-grained target [1] [2].
The relative entropy algorithm currently assumes an NVT ensemble.
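In the NVT setting, the gradient of the relative entropy with respect to the coarse-grained parameters θ takes the standard form ∂S_rel/∂θ = β(⟨∂U_θ/∂θ⟩_ref − ⟨∂U_θ/∂θ⟩_θ). The first average runs over mapped reference configurations, the second over configurations sampled with the current potential U_θ, and β = 1/(k_B T). A minimal plain-Python sketch of this estimator follows. It assumes a hypothetical one-parameter harmonic potential U(x) = ½kx² and made-up sample values; the function names are illustrative and not part of the trainer's API.

```python
from statistics import fmean


def harmonic_energy_grad_k(x: float) -> float:
    """dU/dk for the toy potential U(x) = 0.5 * k * x**2 (independent of k)."""
    return 0.5 * x * x


def relative_entropy_grad(ref_samples, cg_samples, beta, du_dtheta):
    """Estimate dS_rel/dtheta = beta * (<dU/dtheta>_ref - <dU/dtheta>_cg).

    ref_samples: mapped reference (atomistic) configurations.
    cg_samples: configurations sampled with the current CG potential (NVT).
    """
    ref_mean = fmean(du_dtheta(x) for x in ref_samples)
    cg_mean = fmean(du_dtheta(x) for x in cg_samples)
    return beta * (ref_mean - cg_mean)


# Hypothetical samples: the CG samples are spread wider than the reference.
ref = [-1.0, 0.0, 1.0]
cg = [-2.0, 0.0, 2.0]
g = relative_entropy_grad(ref, cg, beta=1.0, du_dtheta=harmonic_energy_grad_k)
print(g)
```

The negative gradient means a descent step increases k, stiffening the coarse-grained potential so that its distribution narrows toward the reference samples.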
- Parameters:
init_params – Initial energy parameters.
optimizer – Optimizer from optax.
reweight_ratio (float) – Fraction of the number of reference samples that the effective sample size n_eff must exceed for the previous reference trajectory state to be reused. To disable reuse of previous trajectories, specify a value > 1.
sim_batch_size (int) – Number of state points processed as a single batch. Gradients are averaged over the batch before the optimizer takes a step.
energy_fn_template (EnergyFnTemplate) – Function that takes energy parameters and initializes a new energy function. Here, the energy_fn_template is only a reference that is saved alongside the trainer. Each state point requires its own template because the energy function depends on the box size via the displacement function, and the box can vary between state points.
convergence_criterion (str) – Criterion used to detect convergence, e.g. 'max_loss' or 'ave_loss'; the default is 'window_median'. With 'max_loss', training stops once the gradient norm of every batch in the epoch is smaller than convergence_thresh; 'ave_loss' uses the average gradient norm instead of the maximum. For a single state point, both are equivalent.
checkpoint_path (PathLike) – Path to the folder in which checkpoints are stored.
full_checkpoint (bool) – Save the whole trainer instead of only the inference data.
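The reweight_ratio test can be illustrated with a short sketch. It assumes that stored samples are reweighted with normalized Boltzmann weights w_i ∝ exp(−β[U_new(x_i) − U_old(x_i)]) and that n_eff is estimated as exp(−Σ_i w_i ln w_i), one common choice for trajectory reweighting. The trainer's exact estimator and these helper names are assumptions. Because this n_eff never exceeds the number of samples N, any reweight_ratio > 1 disables reuse.

```python
import math


def reweighting_weights(delta_u, beta):
    """Normalized weights w_i proportional to exp(-beta * delta_u_i).

    delta_u[i] = U_new(x_i) - U_old(x_i): energy change of stored sample i
    after a parameter update. Shifting by the maximum log-weight
    (log-sum-exp trick) keeps exp() from overflowing.
    """
    log_w = [-beta * du for du in delta_u]
    shift = max(log_w)
    unnorm = [math.exp(lw - shift) for lw in log_w]
    total = sum(unnorm)
    return [w / total for w in unnorm]


def effective_sample_size(weights):
    """Entropy-based estimate n_eff = exp(-sum_i w_i ln w_i) (assumed form)."""
    return math.exp(-sum(w * math.log(w) for w in weights if w > 0.0))


def can_reuse_reference(delta_u, beta, reweight_ratio=0.9):
    """Reuse the stored trajectory only if n_eff exceeds reweight_ratio * N."""
    weights = reweighting_weights(delta_u, beta)
    return effective_sample_size(weights) > reweight_ratio * len(delta_u)


# Unchanged potential: uniform weights, so n_eff equals N.
print(can_reuse_reference([0.0] * 100, beta=1.0))                      # True
# reweight_ratio > 1 can never be satisfied, so reuse is disabled.
print(can_reuse_reference([0.0] * 100, beta=1.0, reweight_ratio=1.1))  # False
# Strongly changed energies: one sample dominates, n_eff is close to 1.
print(can_reuse_reference([0.0] + [10.0] * 99, beta=1.0))              # False
```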
- Variables:
data_states – Dictionary containing the dataloader states for each state point.
delta_re – Dictionary containing the improvement of the relative entropy with respect to the initial potential.
step_size_history – List of step size scales for each batched update.
gradient_norm_history – List of gradient norms for each batched update.
weight_fn – Dictionary containing the reweighting functions for each state point.
early_stop – Instance of EarlyStopping to check for convergence.
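The interplay of sim_batch_size, gradient_norm_history and the 'max_loss' / 'ave_loss' criteria can be sketched as follows. This is a hypothetical plain-Python illustration: per-state-point gradients are averaged within each batch, one gradient norm is recorded per batched update, and the epoch counts as converged when the maximum ('max_loss') or the mean ('ave_loss') of these norms falls below a threshold. The helper names and the flat list-of-floats gradients are assumptions, the optimizer step itself is omitted, and the default 'window_median' criterion is not modelled.

```python
import math


def average_gradients(grads):
    """Average per-state-point gradients (flat lists of floats) within one batch."""
    n = len(grads)
    return [sum(components) / n for components in zip(*grads)]


def grad_norm(grad):
    """Euclidean norm of a flat gradient vector."""
    return math.sqrt(sum(g * g for g in grad))


def epoch_converged(batch_norms, thresh, criterion="max_loss"):
    """Convergence check over the gradient norms of all batched updates in an epoch.

    'max_loss': the largest batch norm must be below thresh.
    'ave_loss': the mean batch norm must be below thresh.
    """
    if criterion == "max_loss":
        value = max(batch_norms)
    elif criterion == "ave_loss":
        value = sum(batch_norms) / len(batch_norms)
    else:
        raise ValueError(f"criterion not covered by this sketch: {criterion!r}")
    return value < thresh


# Hypothetical epoch: 4 state points with sim_batch_size = 2 -> 2 batched updates.
per_state_grads = [[0.02, 0.0], [0.0, 0.02], [0.06, 0.0], [0.0, 0.0]]
sim_batch_size = 2
gradient_norm_history = []
for start in range(0, len(per_state_grads), sim_batch_size):
    batch = per_state_grads[start:start + sim_batch_size]
    gradient_norm_history.append(grad_norm(average_gradients(batch)))

print(epoch_converged(gradient_norm_history, thresh=0.025, criterion="max_loss"))  # False
print(epoch_converged(gradient_norm_history, thresh=0.025, criterion="ave_loss"))  # True
```

With a single state point and sim_batch_size = 1, the epoch has only one batch norm, so both criteria coincide, consistent with the parameter description.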
References
Methods
__init__(init_params, optimizer[, ...]) – A reference energy_fn_template can be provided, but is not mandatory because the template depends on the box via the displacement function.
add_statepoint(reference_data, ...[, ...]) – Adds a state point to the pool of simulations.
add_task(trigger, fn_or_method) – Adds a task to perform regularly during training.
checkpoint(name, object) – Marks an attribute to be saved in a partial checkpoint.
get_sim_state(key) – Gets the simulator state of a state point.
init_step_size_adaption([allowed_reduction, ...]) – Initializes a line search to tune the step size in each iteration.
load_energy_params(file_path) – Loads energy parameters.
Converts all arrays of the trainer state to JAX arrays.
Prints the tasks performed by the trainer.
restore(checkpoint) – Restores the trainer from a checkpoint.
save_energy_params(file_path[, save_format, ...]) – Saves energy parameters.
save_trainer(save_path[, format]) – Saves the whole trainer, e.g. for production after training.
train(max_epochs[, thresh, checkpoint_freq]) – Trains for a maximum number of epochs, checkpointing after a specified number of epochs and ending training early if a convergence criterion is met.
Attributes
- energy_fn
Returns the energy function for the current parameters.
- params
Current energy parameters.
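The relationship between energy_fn_template, the box-dependent displacement function and the energy_fn attribute can be sketched with closures. In this hypothetical plain-Python example, a template is built for a fixed periodic 1D box and maps parameters to an energy function; the same positions then give different energies for different box sizes, which is why each state point needs its own template. The harmonic bond potential and all function names are assumptions for illustration.

```python
def make_periodic_displacement(box):
    """Minimum-image displacement in a periodic 1D box of length `box`."""
    def displacement(a, b):
        d = a - b
        return d - box * round(d / box)
    return displacement


def make_harmonic_bond_template(displacement):
    """Energy function template: params -> energy_fn(positions).

    The displacement function is captured here, so a template built for one
    box size cannot be reused for a state point with a different box.
    """
    def energy_fn_template(params):
        k, r0 = params["k"], params["r0"]

        def energy_fn(positions):
            # Harmonic bond between consecutive particles.
            energy = 0.0
            for a, b in zip(positions[:-1], positions[1:]):
                r = abs(displacement(a, b))
                energy += 0.5 * k * (r - r0) ** 2
            return energy

        return energy_fn

    return energy_fn_template


params = {"k": 2.0, "r0": 1.0}
positions = [0.5, 9.0]  # particles close to each other across the periodic boundary
small_box_fn = make_harmonic_bond_template(make_periodic_displacement(10.0))(params)
large_box_fn = make_harmonic_bond_template(make_periodic_displacement(20.0))(params)
print(small_box_fn(positions))  # bond length 1.5 -> 0.5 * 2 * 0.5**2 = 0.25
print(large_box_fn(positions))  # bond length 8.5 -> 0.5 * 2 * 7.5**2 = 56.25
```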