trainers.ForceMatching
- class ForceMatching(init_params, optimizer, energy_fn_template, nbrs_init, gammas=None, error_fns=None, weights_keys=None, additional_targets=None, feature_extract_fns=None, energy_fn_has_aux=False, batch_per_device=1, batch_cache=10, full_checkpoint=False, disable_shmap=False, penalty_fn=None, convergence_criterion='window_median', log_file='force_matching.log', checkpoint_path='checkpoints')
Parametrizes potential models via the Force Matching method.
The Force Matching method can be used to learn atomistic [1] and coarse-grained [2] models from first-principles or atomistic reference data.
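Force matching fits a model so that its predicted forces, the negative gradient of the potential energy with respect to the particle positions, match reference forces, optionally together with reference energies. The following is a minimal JAX sketch of such a loss for a single snapshot, not the trainer's internal implementation. The toy harmonic-chain energy_fn_template, the name force_matching_loss and the default coefficients gamma_u and gamma_f are all hypothetical.

```python
import jax
import jax.numpy as jnp


def energy_fn_template(params):
    """Hypothetical template: maps energy parameters to an energy function.

    Toy model: harmonic bonds between consecutive particles with a
    learnable stiffness params["k"] and rest length params["r0"].
    """
    def energy_fn(positions):
        bonds = positions[1:] - positions[:-1]
        lengths = jnp.linalg.norm(bonds, axis=-1)
        return 0.5 * params["k"] * jnp.sum((lengths - params["r0"]) ** 2)
    return energy_fn


def force_matching_loss(params, positions, ref_energy, ref_forces,
                        gamma_u=1e-3, gamma_f=1.0):
    """Weighted squared errors of energy and forces for one snapshot (sketch)."""
    energy_fn = energy_fn_template(params)
    energy, grad = jax.value_and_grad(energy_fn)(positions)
    forces = -grad  # forces are the negative gradient of the potential
    energy_error = (energy - ref_energy) ** 2
    force_error = jnp.mean((forces - ref_forces) ** 2)
    return gamma_u * energy_error + gamma_f * force_error


# Synthetic reference data generated from known "true" parameters.
true_params = {"k": 2.0, "r0": 1.0}
positions = jnp.array([[0.0, 0.0, 0.0], [1.2, 0.0, 0.0], [1.2, 0.9, 0.0]])
ref_fn = energy_fn_template(true_params)
ref_energy = ref_fn(positions)
ref_forces = -jax.grad(ref_fn)(positions)

# Gradient of the loss with respect to the parameters of an initial guess;
# an optax optimizer would consume such a gradient in a training loop.
guess = {"k": 1.0, "r0": 1.0}
loss_grad = jax.grad(force_matching_loss)(guess, positions, ref_energy, ref_forces)
```

Because the reference data come from the true parameters, the loss vanishes there, while at the guess k = 1 its gradient with respect to k is negative, pointing toward the true stiffness.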
- Parameters:
init_params – Initial energy parameters.
energy_fn_template (EnergyFnTemplate) – Function that takes energy parameters and returns an energy function.
nbrs_init (NeighborList) – Initial neighbor list. The neighbor list must be large enough not to overflow for any sample of the dataset.
optimizer – Optimizer from optax.
gammas (Dict[str, float]) – Coefficients for the individual targets in the weighted loss.
weights_keys (Dict[str, str]) – Dictionary referencing the entries of the dataset that contain a per-sample weight for the total loss.
additional_targets (Dict[str, Dict]) – Additional snapshot targets to train on. Forces and energy are derived automatically from the energy_fn_template.
feature_extract_fns (Dict[str, Callable]) – Features to extract from the data, passed to all snapshot functions as keyword arguments.
energy_fn_has_aux (bool) – Whether the energy function has an auxiliary output. If so, the energy function is called with the argument mode="with_aux" and should return a tuple (pot, aux).
batch_per_device (int) – Number of samples to process vectorized on every device.
batch_cache (int) – Number of batches to load into the device memories.
full_checkpoint (bool) – Save the whole trainer instead of only some statistics.
disable_shmap (bool) – Use pmap instead of shmap for parallelization.
penalty_fn (Callable) – Penalty depending only on the parameters.
convergence_criterion (str) – Criterion used to check convergence via base.EarlyStopping.
checkpoint_path (PathLike) – Path to the folder in which to store checkpoints.
log_file (str) – Path to the file in which to log training progress.
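The gammas and weights_keys parameters describe a loss that sums the errors of the individual targets, each scaled by its coefficient, with optional per-sample weights. A minimal sketch follows, assuming a mean squared error per target and per-sample weights normalized by their sum; neither choice is taken from the trainer's source, and weighted_total_loss and the target names "U" and "F" are hypothetical.

```python
import jax.numpy as jnp


def weighted_total_loss(predictions, targets, gammas, weights=None):
    """Sum of per-target mean squared errors scaled by gammas (sketch).

    predictions, targets: dicts mapping a target name to an array whose
        leading axis is the batch axis.
    gammas: dict mapping a target name to its coefficient.
    weights: optional dict mapping a target name to per-sample weights of
        shape (batch,); normalizing by their sum is an assumption.
    """
    total = 0.0
    for key, gamma in gammas.items():
        squared = (predictions[key] - targets[key]) ** 2
        # Average over all non-batch axes to get one error per sample.
        per_sample = squared.reshape(squared.shape[0], -1).mean(axis=1)
        if weights is not None and key in weights:
            w = weights[key]
            target_loss = jnp.sum(w * per_sample) / jnp.sum(w)
        else:
            target_loss = jnp.mean(per_sample)
        total = total + gamma * target_loss
    return total


# A batch of two snapshots with three particles each.
predictions = {"U": jnp.array([1.0, 2.0]), "F": jnp.zeros((2, 3, 3))}
targets = {"U": jnp.array([1.0, 4.0]), "F": jnp.ones((2, 3, 3))}
gammas = {"U": 0.5, "F": 2.0}
weights = {"U": jnp.array([1.0, 3.0])}

loss = weighted_total_loss(predictions, targets, gammas, weights)
```

Here the energy term contributes 0.5 × (0·1 + 4·3) / 4 = 1.5 and the force term 2.0 × 1 = 2.0, so loss evaluates to 3.5. Without weights, the energy term becomes 0.5 × 2 = 1.0 and the total 3.0.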
Warning
Neighbor list overflow is currently not checked. Make sure to build nbrs_init large enough for every sample of the dataset.
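Because overflow is not checked, one way to size nbrs_init is to compute, before training, the largest number of neighbors any particle has in any sample of the dataset. The NumPy sketch below does this by brute force under the minimum image convention for an orthorhombic periodic box. The helper names are hypothetical, cutoff should include any skin (buffer) distance the neighbor list uses, and how this count maps onto the capacity of a particular neighbor list implementation is not covered here.

```python
import numpy as np


def max_neighbors(positions, box, cutoff):
    """Largest neighbor count of any particle in one sample.

    Brute force O(N^2) with the minimum image convention for an
    orthorhombic periodic box (valid for cutoff < half the box length);
    box is a scalar or an array of side lengths.
    """
    positions = np.asarray(positions, dtype=float)
    disp = positions[:, None, :] - positions[None, :, :]
    disp -= box * np.round(disp / box)  # minimum image convention
    dist = np.linalg.norm(disp, axis=-1)
    np.fill_diagonal(dist, np.inf)  # a particle is not its own neighbor
    return int((dist < cutoff).sum(axis=1).max())


def required_capacity(dataset_positions, box, cutoff):
    """Largest neighbor count over all samples (hypothetical helper)."""
    return max(max_neighbors(p, box, cutoff) for p in dataset_positions)


# Two samples of three particles in a periodic box of side 10.
sample_a = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [9.5, 0.0, 0.0]]
sample_b = [[0.0, 0.0, 0.0], [3.0, 0.0, 0.0], [6.0, 0.0, 0.0]]
capacity = required_capacity([sample_a, sample_b], box=10.0, cutoff=1.2)
```

In sample_a the first particle sees the second at distance 1.0 and, across the periodic boundary, the third at distance 0.5, so capacity is 2 even though sample_b has no pairs within the cutoff.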
References
Methods
__init__(init_params, optimizer, ...[, ...]) – A reference energy_fn_template can be provided, but it is not mandatory because the template depends on the box via the displacement function.
add_task(trigger, fn_or_method) – Adds a task to perform regularly during training.
checkpoint(name, object) – Marks an attribute to be saved in a partial checkpoint.
evaluate([stage, loss_fn, params]) – Computes the loss on the whole dataset.
Prints the Mean Absolute Error for every target on the test set.
limit_batches_per_epoch([max_batches]) – Limits the number of batches per epoch.
load_energy_params(file_path) – Loads energy parameters.
Transforms all arrays of the trainer state to JAX arrays.
predict(dataset[, params, batch_size]) – Computes predictions for a dataset.
Prints the tasks performed by the trainer.
Resets early stopping convergence monitoring.
restore(checkpoint) – Restores the trainer from a checkpoint.
save_energy_params(file_path[, save_format, ...]) – Saves energy parameters.
save_trainer(save_path[, format]) – Saves the whole trainer, e.g., for production use after training.
set_batches_per_epoch([stage, max_batches]) – Limits the number of updates within an epoch.
set_dataset(dataset[, stage, shuffle, ...]) – Sets the dataset for a single stage, e.g., training.
set_datasets(dataset[, train_ratio, ...]) – Sets the datasets for training, testing and validation.
set_loader(data_loader[, stage, ...]) – Sets a data loader for a specific stage, e.g., training.
train(max_epochs[, thresh, checkpoint_freq]) – Trains for at most max_epochs epochs, writes a checkpoint after a specified number of epochs, and ends training early if a convergence criterion is met.
update_with_samples(**sample) – Performs a single parameter update step, in which a batch is taken from the training set and samples of the batch are replaced by the provided samples.
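The default convergence_criterion='window_median' suggests a stopping rule based on medians of recent validation losses. The exact rule implemented by base.EarlyStopping is not documented on this page, so the class below is a hypothetical illustration: it compares the median validation loss of the latest window of epochs with that of the preceding window, reports convergence once the relative improvement falls below thresh, and keeps the parameters with the lowest validation loss seen so far, in the spirit of best_params.

```python
import statistics


class WindowMedianStopping:
    """Hypothetical window-median early stopping; not the library's code."""

    def __init__(self, window=5, thresh=1e-3):
        self.window = window
        self.thresh = thresh
        self.losses = []
        self.best_loss = float("inf")
        self.best_params = None

    def update(self, val_loss, params):
        """Records one epoch; returns True once training has converged."""
        self.losses.append(val_loss)
        if val_loss < self.best_loss:  # keep the best parameters seen so far
            self.best_loss = val_loss
            self.best_params = params
        if len(self.losses) < 2 * self.window:
            return False  # not enough history for two full windows
        recent = statistics.median(self.losses[-self.window:])
        previous = statistics.median(self.losses[-2 * self.window:-self.window])
        # Relative improvement of the median between consecutive windows.
        improvement = (previous - recent) / max(abs(previous), 1e-12)
        return improvement < self.thresh


stopper = WindowMedianStopping(window=3, thresh=0.01)
val_losses = [1.0, 0.5, 0.3, 0.2, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15]
stop_epoch = None
for epoch, val_loss in enumerate(val_losses):
    if stopper.update(val_loss, params={"epoch": epoch}):
        stop_epoch = epoch
        break
```

With these losses the median stops improving at epoch 8 (zero-based), and best_params holds the parameters from epoch 4, the first epoch that reached 0.15. Using medians rather than the latest loss makes the rule less sensitive to a single noisy validation epoch.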
Attributes
- best_inference_params
Returns the best model parameters, irrespective of whether dropout is used.
- best_inference_params_replicated
Returns the best inference parameters replicated on every device.
- best_params
Returns the best parameters based on the validation loss.
If training was performed with early stopping, returns the best parameters according to this criterion instead.
- energy_fn
Returns the energy function for the current parameters.
- params
Current energy parameters.