trainers.SGMCForceMatching
- class SGMCForceMatching(sgmc_solver, init_samples, val_dataloader=None, energy_fn_template=None)
Trainer for stochastic gradient Markov-chain Monte Carlo training based on force-matching.
- init_samples: A list, possibly of size 1, of sets of initial MCMC samples,
where each set spawns a dedicated MCMC chain.
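The one-chain-per-entry layout of init_samples can be illustrated with a minimal, library-independent sketch. Everything in it is an assumption made for illustration: run_toy_chain is a hypothetical helper, each initial sample is reduced to a single float, and a plain unadjusted Langevin update on a standard normal target stands in for the actual sgmc_solver. It is not the trainer's implementation.

```python
import math
import random


def run_toy_chain(init_sample, n_steps, step_size=0.01, seed=0):
    """Hypothetical helper: one toy Langevin chain started at init_sample."""
    rng = random.Random(seed)
    x = init_sample
    trace = []
    for _ in range(n_steps):
        grad_log_p = -x  # gradient of the log-density of a standard normal
        noise = rng.gauss(0.0, 1.0)
        # Unadjusted Langevin update: drift along the gradient plus Gaussian noise.
        x = x + 0.5 * step_size * grad_log_p + math.sqrt(step_size) * noise
        trace.append(x)
    return trace


# Assumed layout: one entry per chain; a list of size 1 gives a single chain.
init_samples = [-2.0, 0.0, 2.0]

# Each initial sample spawns its own, independently seeded chain.
chains = [
    run_toy_chain(sample, n_steps=100, seed=i)
    for i, sample in enumerate(init_samples)
]
```

Here len(chains) equals len(init_samples), which mirrors the description above: the number of entries in init_samples determines the number of chains.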
Methods
__init__(sgmc_solver, init_samples[, ...]) A reference energy_fn_template can be provided, but is not mandatory due to the dependence of the template on the box via the displacement function.
checkpoint(name, object) Marks attribute to be saved in a partial checkpoint.
load_energy_params(file_path) Loads energy parameters.
Moves all attributes that are expected to be on the device to the device. This avoids TracerExceptions after loading a trainer from disk, which yields NumPy arrays rather than device arrays.
restore(checkpoint) Restores the trainer from a checkpoint.
save_energy_params(file_path[, save_format, ...]) Saves energy parameters.
save_trainer(save_path) Saves the trainer to a file.
train(iterations) Training of any trainer should start by calling train.
Attributes
- energy_fn
Returns the energy function for the current parameters.
- list_of_params
A list of the sampled parameters.
- params
Get the sampled parameters from all chains.