trainers.base.MLETrainerTemplate
- class MLETrainerTemplate(optimizer, init_state, checkpoint_path, full_checkpoint=True, log_file=None, reference_energy_fn_template=None)
Abstract class implementing common properties and methods of single point estimate Trainers using optax optimizers.
- Parameters:
optimizer – Optax optimizer
init_state (TrainerState) – Initial state of optimizer and model
checkpoint_path (PathLike) – Path to folder where checkpoints are saved
full_checkpoint (bool) – Whether to save the full trainer with pickle or only a subset of attributes.
log_file (PathLike) – Write logs of the Trainer to the file specified by this path.
reference_energy_fn_template (EnergyFnTemplate) – Function returning a concrete energy function for the current parameters
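To illustrate what "a function returning a concrete energy function for the current parameters" can look like, here is a minimal sketch in plain Python. The harmonic potential, the parameter names `k` and `r0`, and the function `harmonic_energy_fn_template` are hypothetical and chosen only for illustration; they are not part of the library.

```python
from typing import Callable, Dict

EnergyFn = Callable[[float], float]


def harmonic_energy_fn_template(params: Dict[str, float]) -> EnergyFn:
    """Hypothetical template: returns an energy function with params baked in."""
    k, r0 = params["k"], params["r0"]

    def energy_fn(r: float) -> float:
        # Harmonic energy of a single distance r (illustrative only).
        return 0.5 * k * (r - r0) ** 2

    return energy_fn


# Calling the template with the current parameters yields a concrete function.
energy_fn = harmonic_energy_fn_template({"k": 2.0, "r0": 1.0})
print(energy_fn(1.5))  # 0.25
```

Calling the template again with updated parameters yields a new concrete energy function, which is the pattern the parameter description suggests.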
The MLE trainer performs a sequence of tasks before and after each training run, epoch, and batch update. Custom tasks can be added to the trainer via MLETrainerTemplate.add_task().
- Variables:
update_times – Computation time of each update
gradient_norm_history – Norms of the gradient for each update
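The task mechanism described above (tasks running before and after training, each epoch, and each batch) can be sketched as a small registry of callbacks. This is a self-contained illustration, not the library's implementation: the class `TaskRunner`, the trigger names such as `"post_epoch"`, and the callback signature are all assumptions.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List


class TaskRunner:
    """Hypothetical stand-in for a trainer's task registry."""

    def __init__(self) -> None:
        # Maps a trigger name to the callables registered for it.
        self._tasks: DefaultDict[str, List[Callable[..., None]]] = defaultdict(list)
        self.log: List[str] = []

    def add_task(self, trigger: str, fn: Callable[..., None]) -> None:
        """Registers fn to run whenever the given trigger fires."""
        self._tasks[trigger].append(fn)

    def _fire(self, trigger: str, **kwargs) -> None:
        # Runs all tasks for this trigger in registration order.
        for fn in self._tasks[trigger]:
            fn(self, **kwargs)

    def train(self, max_epochs: int, batches_per_epoch: int) -> None:
        self._fire("pre_training")
        for epoch in range(max_epochs):
            self._fire("pre_epoch", epoch=epoch)
            for batch in range(batches_per_epoch):
                self._fire("pre_batch", epoch=epoch, batch=batch)
                # A real trainer would perform the parameter update here.
                self._fire("post_batch", epoch=epoch, batch=batch)
            self._fire("post_epoch", epoch=epoch)
        self._fire("post_training")


def record_epoch(runner: TaskRunner, epoch: int) -> None:
    """Custom task: notes that an epoch has finished."""
    runner.log.append(f"finished epoch {epoch}")


runner = TaskRunner()
runner.add_task("post_epoch", record_epoch)
runner.train(max_epochs=2, batches_per_epoch=3)
print(runner.log)  # ['finished epoch 0', 'finished epoch 1']
```

Keeping the triggers as plain strings keeps the sketch short; the valid trigger names and the arguments passed to tasks in the actual trainer should be taken from the documentation of MLETrainerTemplate.add_task().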
Methods
__init__(optimizer, init_state, checkpoint_path) – A reference energy_fn_template can be provided, but is not mandatory due to the dependence of the template on the box via the displacement function.
add_task(trigger, fn_or_method) – Adds a task to perform regularly during training.
checkpoint(name, object) – Marks an attribute to be saved in a partial checkpoint.
load_energy_params(file_path) – Loads energy parameters.
Converts all arrays of the trainer state to JAX arrays.
Prints the tasks performed by the trainer.
restore(checkpoint) – Restores the trainer from a checkpoint.
save_energy_params(file_path[, save_format, ...]) – Saves energy parameters.
save_trainer(save_path[, format]) – Saves the whole trainer, e.g. for production after training.
train(max_epochs[, thresh, checkpoint_freq]) – Trains for a maximum number of epochs, checkpoints after a specified number of epochs, and ends training early if a convergence criterion is met.
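The control flow summarized for train (an epoch limit, periodic checkpoints, and early stopping on convergence) might look roughly like the toy loop below. It is only a sketch under stated assumptions: the meaning of `thresh` (stop once the loss falls below it) and of `checkpoint_freq` (checkpoint every n-th epoch) are guesses for illustration, and the quadratic objective with plain gradient descent stands in for a real model and optax optimizer.

```python
from typing import List, Optional, Tuple


def train(max_epochs: int,
          thresh: Optional[float] = None,
          checkpoint_freq: Optional[int] = None) -> Tuple[float, List[int], int]:
    """Toy loop minimizing (x - 3)^2; returns (x, checkpointed epochs, epochs run)."""
    x, lr = 0.0, 0.25
    checkpoints: List[int] = []
    epochs_run = 0
    for epoch in range(max_epochs):
        grad = 2.0 * (x - 3.0)
        x -= lr * grad  # plain gradient-descent update
        epochs_run = epoch + 1
        # Assumed meaning of checkpoint_freq: checkpoint every n-th epoch.
        if checkpoint_freq is not None and epochs_run % checkpoint_freq == 0:
            checkpoints.append(epochs_run)
        # Assumed convergence criterion: stop once the loss drops below thresh.
        loss = (x - 3.0) ** 2
        if thresh is not None and loss < thresh:
            break
    return x, checkpoints, epochs_run


x, checkpoints, epochs_run = train(max_epochs=100, thresh=1e-3, checkpoint_freq=2)
print(epochs_run, checkpoints)  # 7 [2, 4, 6]
```

With both optional arguments left at None, the loop simply runs for max_epochs epochs and never checkpoints, which mirrors their optional status in the signature.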
Attributes
- energy_fn
Returns the energy function for the current parameters.
- params
Shortcut for the parameters. The exact content depends on the specific trainer.