trainers.base.MLETrainerTemplate

class MLETrainerTemplate(optimizer, init_state, checkpoint_path, full_checkpoint=True, log_file=None, reference_energy_fn_template=None)

Abstract class implementing common properties and methods of single-point-estimate trainers that use Optax optimizers.

Parameters:
  • optimizer – Optax optimizer

  • init_state (TrainerState) – Initial state of the optimizer and model

  • checkpoint_path (PathLike) – Path to the folder where checkpoints are saved

  • full_checkpoint (bool) – Whether to pickle the full trainer or to save only a subset of attributes.

  • log_file (PathLike) – Path of the file to which the trainer's logs are written.

  • reference_energy_fn_template (EnergyFnTemplate) – Function returning a concrete energy function for the current parameters

The MLE trainer performs a sequence of tasks before and after each training run, epoch, and batch update. Custom tasks can be added to the trainer via MLETrainerTemplate.add_task().
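The trigger mechanism can be pictured as a registry that maps each trigger to an ordered list of callables, which the training loop runs at fixed points. The following is a minimal stdlib sketch under that assumption; TaskRegistry, train() and the trigger strings ("pre_epoch", "post_update", etc.) are hypothetical and are not the library's actual names.

```python
from typing import Callable, Dict, List

Task = Callable[[], None]


class TaskRegistry:
    """Hypothetical registry mapping a trigger name to an ordered list of tasks."""

    def __init__(self) -> None:
        self._tasks: Dict[str, List[Task]] = {}

    def add_task(self, trigger: str, fn: Task) -> None:
        # Tasks registered for the same trigger run in insertion order.
        self._tasks.setdefault(trigger, []).append(fn)

    def run(self, trigger: str) -> None:
        # Triggers without registered tasks are simply skipped.
        for fn in self._tasks.get(trigger, []):
            fn()


def train(registry: TaskRegistry, num_epochs: int, num_batches: int) -> None:
    """Toy loop showing where each trigger fires; trigger names are illustrative."""
    registry.run("pre_training")
    for _ in range(num_epochs):
        registry.run("pre_epoch")
        for _ in range(num_batches):
            registry.run("pre_update")
            # A real trainer would apply the optimizer update here.
            registry.run("post_update")
        registry.run("post_epoch")
    registry.run("post_training")
```

Because tasks are plain callables in this sketch, logging, validation, or checkpointing can be attached without subclassing the trainer.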

Variables:
  • update_times – Computation time of each update

  • gradient_norm_history – Norms of the gradient for each update
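How the two recorded variables might be filled can be illustrated with a stdlib-only sketch. It assumes the gradient norm is the global L2 norm over all parameter entries (the usual convention, as in optax.global_norm); the source does not state which norm is used. UpdateLogger, global_norm and the flat dict-of-lists gradient format are hypothetical, and only the gradient computation is timed here as a stand-in for the full update.

```python
import math
import time
from typing import Callable, Dict, List

# Hypothetical gradient format: parameter name -> flat list of entries.
Grads = Dict[str, List[float]]


def global_norm(grads: Grads) -> float:
    """L2 norm over all gradient entries, treated as one flat vector."""
    return math.sqrt(sum(g * g for leaf in grads.values() for g in leaf))


class UpdateLogger:
    """Records the wall-clock time and gradient norm of each update."""

    def __init__(self) -> None:
        self.update_times: List[float] = []
        self.gradient_norm_history: List[float] = []

    def timed_update(self, compute_grads: Callable[[], Grads]) -> Grads:
        start = time.perf_counter()
        grads = compute_grads()
        self.update_times.append(time.perf_counter() - start)
        self.gradient_norm_history.append(global_norm(grads))
        return grads
```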

Methods

__init__(optimizer, init_state, checkpoint_path)

A reference energy_fn_template can be provided but is not mandatory, because the template depends on the simulation box through the displacement function.

add_task(trigger, fn_or_method)

Adds a task to perform regularly during training.

checkpoint(name, object)

Marks an attribute to be saved in a partial checkpoint.

load_energy_params(file_path)

Loads energy parameters.

move_to_device()

Converts all arrays of the trainer state to JAX arrays.

print_training_tasks()

Prints the tasks performed by the trainer.

restore(checkpoint)

Restores the trainer from a checkpoint.

save_energy_params(file_path[, save_format, ...])

Saves energy parameters.

save_trainer(save_path[, format])

Saves the whole trainer, e.g., for production use after training.

train(max_epochs[, thresh, checkpoint_freq])

Trains for at most max_epochs epochs, saves a checkpoint every checkpoint_freq epochs, and ends training early if a convergence criterion is met.
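A rough sketch of how train(), periodic checkpointing, and the full/partial checkpoint choice could interact, using a toy one-parameter problem. Everything here is hypothetical: ToyTrainer, mark_checkpoint() (a simplified stand-in for checkpoint()), the epoch{N}.pkl file naming, and the assumption that thresh is compared directly against the epoch loss. The library's actual convergence criterion and file layout may differ.

```python
import os
import pickle
import tempfile
from typing import Any, List, Optional


class ToyTrainer:
    """Hypothetical trainer minimizing loss = params**2 / 2 by gradient descent."""

    def __init__(self, checkpoint_path: str, full_checkpoint: bool = False) -> None:
        self.checkpoint_path = checkpoint_path
        self.full_checkpoint = full_checkpoint
        self.params = 10.0
        self.epoch = 0
        self.losses: List[float] = []
        self._marked: List[str] = []
        # Only these attributes end up in a partial checkpoint.
        self.mark_checkpoint("params")
        self.mark_checkpoint("epoch")

    def mark_checkpoint(self, name: str) -> None:
        # Simplified stand-in for checkpoint(name, object): remember an attribute name.
        self._marked.append(name)

    def save_checkpoint(self) -> str:
        os.makedirs(self.checkpoint_path, exist_ok=True)
        path = os.path.join(self.checkpoint_path, f"epoch{self.epoch}.pkl")
        if self.full_checkpoint:
            payload: Any = self  # pickle the whole trainer object
        else:
            payload = {name: getattr(self, name) for name in self._marked}
        with open(path, "wb") as f:
            pickle.dump(payload, f)
        return path

    def _train_epoch(self) -> float:
        # One gradient step with learning rate 0.5; the gradient of params**2 / 2 is params.
        self.params -= 0.5 * self.params
        return 0.5 * self.params ** 2

    def train(self, max_epochs: int, thresh: Optional[float] = None,
              checkpoint_freq: Optional[int] = None) -> None:
        for _ in range(max_epochs):
            loss = self._train_epoch()
            self.losses.append(loss)
            self.epoch += 1
            if checkpoint_freq and self.epoch % checkpoint_freq == 0:
                self.save_checkpoint()
            if thresh is not None and loss < thresh:
                break  # assumed convergence criterion: loss below thresh


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        trainer = ToyTrainer(tmp)
        trainer.train(max_epochs=100, thresh=0.5, checkpoint_freq=2)
        print(trainer.epoch, sorted(os.listdir(tmp)))  # 4 ['epoch2.pkl', 'epoch4.pkl']
```

In this sketch, a partial checkpoint stays small and independent of the class definition, whereas a full checkpoint captures everything but requires the same code to unpickle.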

Attributes

energy_fn

Returns the energy function for the current parameters.

params

Shortcut for the parameters. The exact structure depends on the specific trainer.
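The relationship between an energy function template, the current parameters, and the energy_fn attribute can be sketched as a closure factory. This is an illustrative assumption: real energy functions act on particle positions and typically depend on a box-dependent displacement function, whereas here a hypothetical harmonic energy over precomputed distances stands in. harmonic_template and TemplateTrainer are hypothetical names.

```python
from typing import Callable, Dict, Sequence

Params = Dict[str, float]
EnergyFn = Callable[[Sequence[float]], float]
EnergyFnTemplate = Callable[[Params], EnergyFn]


def harmonic_template(params: Params) -> EnergyFn:
    """Hypothetical template: closes over params, returns a concrete energy function."""
    k, r0 = params["k"], params["r0"]

    def energy_fn(distances: Sequence[float]) -> float:
        return sum(0.5 * k * (r - r0) ** 2 for r in distances)

    return energy_fn


class TemplateTrainer:
    """Minimal stand-in exposing params and energy_fn as properties."""

    def __init__(self, params: Params, energy_fn_template: EnergyFnTemplate) -> None:
        self._params = params
        self._template = energy_fn_template

    @property
    def params(self) -> Params:
        # Shortcut for the current parameters.
        return self._params

    @property
    def energy_fn(self) -> EnergyFn:
        # Rebuilt from the template so it always reflects the current parameters.
        return self._template(self.params)
```

Rebuilding energy_fn on access keeps it consistent with the parameters after every update, at the cost of re-running the template each time.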