trainers.ForceMatching

class ForceMatching(init_params, optimizer, energy_fn_template, nbrs_init, gammas=None, error_fns=None, weights_keys=None, additional_targets=None, feature_extract_fns=None, energy_fn_has_aux=False, batch_per_device=1, batch_cache=10, full_checkpoint=False, disable_shmap=False, penalty_fn=None, convergence_criterion='window_median', log_file='force_matching.log', checkpoint_path='checkpoints')

Parametrizes potential models via the Force Matching method.

The Force Matching method can be used to learn atomistic [1] and coarse-grained [2] models from first-principle or atomistic reference data.
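As a point of reference, for energy and force targets the force-matching objective is often written as the weighted sum below. This assumes mean squared errors, N reference snapshots R_i with n particles each, reference energies U_i^ref and forces F_i^ref, and predicted forces obtained as the negative gradient of the learned potential U_θ. It is an illustrative form only: the trainer takes its own error functions (error_fns) and coefficients (gammas), so its exact loss may differ.

```latex
\mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \left[
  \gamma_U \left( U_\theta(\mathbf{R}_i) - U_i^{\mathrm{ref}} \right)^2
  + \frac{\gamma_F}{3n} \left\lVert -\nabla_{\mathbf{R}} U_\theta(\mathbf{R}_i) - \mathbf{F}_i^{\mathrm{ref}} \right\rVert^2
\right]
```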

Parameters:
  • init_params – Initial energy parameters.

  • optimizer – Optimizer from optax.

  • energy_fn_template (EnergyFnTemplate) – Function that takes energy parameters and returns an energy function.

  • nbrs_init (NeighborList) – Initial neighbor list. Its capacity must be large enough that it does not overflow for any sample of the dataset.

  • gammas (Dict[str, float]) – Coefficients for the individual targets in the weighted loss.

  • weights_keys (Dict[str, str]) – Dictionary pointing to the dataset entries that contain per-sample weights for the total loss.

  • additional_targets (Dict[str, Dict]) – Additional snapshot targets to train on. Forces and energy are derived automatically from the energy_fn_template.

  • feature_extract_fns (Dict[str, Callable]) – Features to extract from the data, passed to all snapshot functions as keyword arguments.

  • energy_fn_has_aux (bool) – Whether the energy function has an auxiliary output. If True, the energy function is called with the argument mode="with_aux" and should return a tuple (pot, aux).

  • batch_per_device (int) – Number of samples to process vectorized on every device.

  • batch_cache (int) – Number of batches to cache in device memory.

  • full_checkpoint (bool) – Save the whole trainer instead of only some statistics.

  • disable_shmap (bool) – Use pmap instead of shmap for parallelization.

  • penalty_fn (Callable) – Penalty term added to the loss that depends only on the parameters.

  • convergence_criterion (str) – Criterion used by base.EarlyStopping to check convergence.

  • checkpoint_path (PathLike) – Path to the folder to store checkpoints.

  • log_file (str) – Path to the file where training progress is logged.
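The gammas and weights_keys arguments describe a loss in which each target's error is scaled by its coefficient and each sample's loss is scaled by its weight. The following is a minimal pure-Python sketch of that combination. It assumes mean squared errors and a plain average over samples. The helpers weighted_loss and mse are hypothetical and are not chemtrain functions. The trainer's actual reduction may differ.

```python
from typing import Dict, List, Optional


def mse(pred: List[float], ref: List[float]) -> float:
    """Mean squared error between two equally long sequences."""
    return sum((p - r) ** 2 for p, r in zip(pred, ref)) / len(pred)


def weighted_loss(
    predictions: Dict[str, List[List[float]]],
    references: Dict[str, List[List[float]]],
    gammas: Dict[str, float],
    sample_weights: Optional[List[float]] = None,
) -> float:
    """Hypothetical gamma- and sample-weighted loss (not the chemtrain implementation).

    predictions[target][i] and references[target][i] hold the flattened
    values of `target` for sample i.
    """
    n_samples = len(next(iter(references.values())))
    if sample_weights is None:
        sample_weights = [1.0] * n_samples
    total = 0.0
    for i in range(n_samples):
        # Sum the gamma-scaled errors of all targets for sample i ...
        sample_loss = sum(
            gamma * mse(predictions[t][i], references[t][i])
            for t, gamma in gammas.items()
        )
        # ... then scale the whole sample loss by its per-sample weight.
        total += sample_weights[i] * sample_loss
    return total / n_samples


# Two samples, one energy value and two force components each.
preds = {"energy": [[1.0], [2.0]], "forces": [[0.0, 1.0], [1.0, 1.0]]}
refs = {"energy": [[0.0], [2.0]], "forces": [[0.0, 0.0], [1.0, 1.0]]}
gammas = {"energy": 1.0, "forces": 10.0}
print(weighted_loss(preds, refs, gammas))              # 3.0
print(weighted_loss(preds, refs, gammas, [0.5, 1.0]))  # 1.5
```

In this example, the first sample contributes 1.0 (energy) + 10 × 0.5 (forces) = 6.0 and the second sample contributes 0. Down-weighting the first sample therefore halves the total.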

Warning

Neighbor list overflow is currently not checked. Make sure nbrs_init is built with enough capacity for every sample in the dataset.
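Because overflow is not detected, it can help to measure the densest environment in the dataset before building nbrs_init. The following is a brute-force O(n²) sketch that assumes a cubic periodic box and a plain distance cutoff. If the neighbor list uses an additional skin, pass the enlarged cutoff. The helper max_neighbors is hypothetical and is not part of chemtrain or JAX MD. Its result is a lower bound on the per-particle capacity the neighbor list needs.

```python
import itertools
from typing import List, Sequence

Position = Sequence[float]


def max_neighbors(samples: List[List[Position]], box: float, cutoff: float) -> int:
    """Hypothetical helper: largest neighbor count of any particle in any sample.

    Uses the minimum-image convention for a cubic periodic box of side `box`.
    """
    cutoff_sq = cutoff ** 2
    worst = 0
    for positions in samples:
        counts = [0] * len(positions)
        for i, j in itertools.combinations(range(len(positions)), 2):
            dist_sq = 0.0
            for a, b in zip(positions[i], positions[j]):
                d = a - b
                d -= box * round(d / box)  # minimum-image displacement
                dist_sq += d * d
            if dist_sq < cutoff_sq:
                # A pair within the cutoff is a neighbor of both particles.
                counts[i] += 1
                counts[j] += 1
        worst = max(worst, max(counts, default=0))
    return worst


samples = [
    [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 5.0, 5.0)],
    [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (9.5, 0.0, 0.0)],  # last atom wraps around
]
print(max_neighbors(samples, box=10.0, cutoff=1.2))  # 2
```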

References

Methods

__init__(init_params, optimizer, ...[, ...])

A reference energy_fn_template can be provided but is not mandatory, because the template depends on the box via the displacement function.

add_task(trigger, fn_or_method)

Adds a task to perform regularly during training.

checkpoint(name, object)

Marks an attribute to be saved in a partial checkpoint.

evaluate([stage, loss_fn, params])

Computes the loss on the whole dataset.

evaluate_mae_testset()

Prints the Mean Absolute Error for every target on the test set.

limit_batches_per_epoch([max_batches])

Limits the number of batches per epoch.

load_energy_params(file_path)

Loads energy parameters.

move_to_device()

Transforms all arrays of the trainer state to JAX arrays.

predict(dataset[, params, batch_size])

Computes predictions for a dataset.

print_training_tasks()

Prints the tasks performed by the trainer.

reset_convergence_losses()

Resets early stopping convergence monitoring.

restore(checkpoint)

Restores the trainer from a checkpoint.

save_energy_params(file_path[, save_format, ...])

Saves energy parameters.

save_trainer(save_path[, format])

Saves the whole trainer, e.g., for production after training.

set_batches_per_epoch([stage, max_batches])

Limits the number of updates within an epoch.

set_dataset(dataset[, stage, shuffle, ...])

Sets the dataset for a single stage, e.g., training.

set_datasets(dataset[, train_ratio, ...])

Sets the datasets for training, testing and validation.

set_loader(data_loader[, stage, ...])

Sets a data loader for a specific stage, e.g., training.

train(max_epochs[, thresh, checkpoint_freq])

Trains for a maximum number of epochs, checkpoints after a specified number of epochs and ends training if a convergence criterion is met.

update_with_samples(**sample)

Performs a single parameter update step. A batch is taken from the training set, and samples in that batch are replaced by the provided samples.
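The train method ends early once a convergence criterion is met, and the default convergence_criterion is 'window_median'. The exact rule used by base.EarlyStopping is not described on this page. The following is only a sketch of what a window-median test could look like. It assumes per-epoch losses are compared against the thresh argument of train. The function window_median_converged and its window parameter are hypothetical.

```python
from statistics import median
from typing import List


def window_median_converged(losses: List[float], window: int, thresh: float) -> bool:
    """Hypothetical window-median convergence test (not chemtrain's exact rule).

    Compares the median loss of the most recent `window` epochs with the
    median of the `window` epochs before them. Training counts as converged
    when the relative improvement drops below `thresh`; a loss that gets
    worse gives a negative improvement and therefore also stops training.
    """
    if len(losses) < 2 * window:
        return False  # not enough history yet
    previous = median(losses[-2 * window:-window])
    recent = median(losses[-window:])
    improvement = (previous - recent) / max(abs(previous), 1e-12)
    return improvement < thresh


print(window_median_converged([10.0, 8.0, 6.0, 4.0], window=2, thresh=0.01))  # False
print(window_median_converged([5.0, 5.0, 5.0, 5.0], window=2, thresh=0.01))  # True
```

Using medians rather than single-epoch values makes such a test less sensitive to an occasional noisy epoch, at the cost of reacting a few epochs later.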

Attributes

best_inference_params

Returns the best model parameters, irrespective of whether dropout is used.

best_inference_params_replicated

Returns the best inference params replicated on every device.

best_params

Returns the best parameters based on the validation loss.

If training was performed with early stopping, returns the best parameters according to the early-stopping criterion instead.

energy_fn

Returns the energy function for the current parameters.

params

Current energy parameters.
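The best_params attribute returns the parameters with the lowest validation loss seen during training. The bookkeeping behind such an attribute can be sketched as follows. This assumes the validation loss is evaluated once per epoch. BestParamsTracker is a hypothetical illustration, not the trainer's internal implementation.

```python
import copy
from typing import Any, Optional


class BestParamsTracker:
    """Hypothetical helper that keeps the parameters with the lowest validation loss."""

    def __init__(self) -> None:
        self.best_loss: float = float("inf")
        self.best_params: Optional[Any] = None

    def update(self, params: Any, val_loss: float) -> bool:
        """Store a copy of `params` if `val_loss` improves on the best so far."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            # Deep copy so later in-place changes to mutable containers
            # do not alter the stored best parameters.
            self.best_params = copy.deepcopy(params)
            return True
        return False


tracker = BestParamsTracker()
history = [({"w": 1.0}, 0.9), ({"w": 2.0}, 0.5), ({"w": 3.0}, 0.7)]
for params, val_loss in history:
    tracker.update(params, val_loss)
print(tracker.best_params)  # {'w': 2.0}
```

JAX arrays are immutable, so the deep copy is only a defensive choice for parameters stored in mutable Python containers.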