trainers.InterleaveTrainers
- class InterleaveTrainers(sequential=True, checkpoint_base_path='checkpoints', reference_energy_fn_template=None, full_checkpoint=False)
Interleaves updates to train models using multiple algorithms.
This special trainer allows models to be trained simultaneously with different algorithms.
Example
from chemtrain import trainers

# First initialize the base trainers, e.g.
fm_trainer = trainers.ForceMatching(...)
difftre_trainer = trainers.Difftre(...)
difftre_trainer.add_statepoint(...)

# Now combine the trainers. The trainers are executed in the
# order in which they are added. Keyword arguments are required here:
# the first positional argument of InterleaveTrainers is `sequential`.
trainer = trainers.InterleaveTrainers(
    checkpoint_base_path='checkpoint_folder',
    reference_energy_fn_template=energy_fn_template,
    full_checkpoint=False)

# Force matching should run 10 epochs before DiffTRe runs 2 epochs
trainer.add_trainer(fm_trainer, num_updates=10, name='Force Matching')
trainer.add_trainer(difftre_trainer, num_updates=2, name='DiffTRe')
trainer.train(100, checkpoint_frequency=10)
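The order in which the combined trainer runs its sub-trainers can be pictured with a small, library-independent sketch. It assumes that each outer epoch of `train` runs every registered trainer once, in the order in which it was added, for its `num_updates` updates; `ToyTrainer` and `interleave` are hypothetical names and not part of chemtrain.

```python
from typing import List, Tuple


class ToyTrainer:
    """Hypothetical stand-in for a chemtrain trainer that only counts updates."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.updates_done = 0

    def train(self, num_updates: int) -> None:
        # A real trainer would run num_updates optimization steps here.
        self.updates_done += num_updates


def interleave(trainers: List[Tuple[ToyTrainer, int]], epochs: int) -> List[str]:
    """Runs each (trainer, num_updates) pair in insertion order once per epoch
    and returns the sequence of trainer names, one entry per update."""
    schedule: List[str] = []
    for _ in range(epochs):
        for trainer, num_updates in trainers:
            trainer.train(num_updates)
            schedule.extend([trainer.name] * num_updates)
    return schedule


fm = ToyTrainer('Force Matching')
difftre = ToyTrainer('DiffTRe')
schedule = interleave([(fm, 10), (difftre, 2)], epochs=3)
print(fm.updates_done, difftre.updates_done)  # 30 6
print(schedule[9:12])  # ['Force Matching', 'DiffTRe', 'DiffTRe']
```

With the numbers from the example, every epoch of this sketch performs 10 force-matching updates followed by 2 DiffTRe updates.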
- Parameters:
sequential – If True, each trainer starts directly from the parameters optimized by the previous trainer. If False, all trainers start their epoch from the same parameter set, and the final update is a weighted sum of the individual trainers' updates.
checkpoint_base_path – Directory in which the checkpoints of the trainers are stored.
reference_energy_fn_template – Optional energy function template, used to return an energy function with the current parameters.
full_checkpoint – If True, store the complete trainer; otherwise, store only its important properties.
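The difference between the sequential and non-sequential modes can be illustrated on a single scalar parameter. This is a sketch under stated assumptions, not chemtrain's implementation: each trainer is reduced to a hypothetical function mapping parameters to optimized parameters, and the weights in the non-sequential case are assumed to be plain factors in the sum (for example the `weight` passed to `add_trainer`), applied without normalization.

```python
from typing import Callable, Dict, List, Tuple

Params = Dict[str, float]
# Hypothetical: one trainer epoch maps parameters to optimized parameters.
Step = Callable[[Params], Params]


def sequential_epoch(params: Params, steps: List[Step]) -> Params:
    """sequential=True: each trainer starts from the previous trainer's result."""
    for step in steps:
        params = step(params)
    return params


def parallel_epoch(params: Params, steps: List[Tuple[Step, float]]) -> Params:
    """sequential=False: all trainers start from the same parameters and the
    weighted sum of their updates (new - old) is applied once."""
    total = {key: 0.0 for key in params}
    for step, weight in steps:
        new = step(params)
        for key in params:
            total[key] += weight * (new[key] - params[key])
    return {key: params[key] + total[key] for key in params}


def shift(p: Params) -> Params:
    # Toy trainer 1 (hypothetical): moves the parameter by +1.
    return {'w': p['w'] + 1.0}


def double(p: Params) -> Params:
    # Toy trainer 2 (hypothetical): doubles the parameter.
    return {'w': 2.0 * p['w']}


print(sequential_epoch({'w': 3.0}, [shift, double]))              # {'w': 8.0}
print(parallel_epoch({'w': 3.0}, [(shift, 0.5), (double, 0.5)]))  # {'w': 5.0}
```

In this sketch, swapping the two trainers changes the sequential result (3.0 becomes 6.0, then 7.0) but leaves the non-sequential result unchanged, since both trainers see the same starting parameters.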
Methods
__init__([sequential, checkpoint_base_path, ...]) – A reference energy_fn_template can be provided, but is not mandatory because the template depends on the box via the displacement function.
add_trainer(trainer[, num_updates, name, weight]) – Adds a trainer to the combined training.
checkpoint(name, object) – Marks an attribute to be saved in a partial checkpoint.
load_energy_params(file_path) – Loads energy parameters.
Moves all attributes that are expected to be on the device to the device, avoiding TracerExceptions after trainers are loaded from disk, i.e. when NumPy arrays rather than device arrays are loaded.
restore(checkpoint) – Restores the trainer from a checkpoint.
save_energy_params(file_path[, save_format, ...]) – Saves energy parameters.
save_trainer(save_path[, format]) – Saves the whole trainer, e.g. for production use after training.
train(epochs[, checkpoint_frequency]) – Trains the model with the combined algorithms.
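The distinction between partial and full checkpoints (the `full_checkpoint` parameter and the `checkpoint` method) can be sketched with Python's `pickle` module. The class `CheckpointSketch` and its `mark` method are hypothetical; the sketch only assumes that a partial checkpoint stores the explicitly marked attributes, whereas a full checkpoint stores every attribute of the trainer.

```python
import pickle
from typing import Set


class CheckpointSketch:
    """Hypothetical stand-in for a trainer with partial and full checkpoints."""

    def __init__(self, full_checkpoint: bool = False) -> None:
        self.full_checkpoint = full_checkpoint
        self._marked: Set[str] = set()  # attribute names for partial saves
        self.params = {'w': 1.0}        # small, essential state
        self.cache = list(range(1000))  # large state that could be rebuilt
        self.mark('params')

    def mark(self, name: str) -> None:
        """Marks an attribute to be saved in a partial checkpoint."""
        self._marked.add(name)

    def dumps(self) -> bytes:
        if self.full_checkpoint:
            # Full checkpoint: every instance attribute is serialized.
            return pickle.dumps(vars(self))
        # Partial checkpoint: only the marked attributes are serialized.
        return pickle.dumps({name: getattr(self, name) for name in self._marked})


partial = CheckpointSketch(full_checkpoint=False)
full = CheckpointSketch(full_checkpoint=True)
print(sorted(pickle.loads(partial.dumps())))  # ['params']
print(sorted(pickle.loads(full.dumps())))     # ['_marked', 'cache', 'full_checkpoint', 'params']
```

Marking only the essential state keeps partial checkpoints small; the trade-off is that anything left unmarked has to be reconstructed when restoring.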
Attributes
- energy_fn
Returns the energy function for the current parameters.
- params