learn.force_matching#

Functions for direct learning of per-snapshot quantities.

Directly learnable quantities are, for example, energy, forces, or virial pressure.

Dataset#

Utility functions to process datasets of per-snapshot quantities. These utilities check the dataset for consistent keys, differentiating between inputs to the model (e.g., atomic positions) and targets (e.g., atomic forces).
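
As a rough illustration of such a check, the sketch below verifies that every entry of a dataset holds the same number of snapshots and then separates inputs from targets. The helper `split_inputs_and_targets` and the choice of `'R'` as the only input key are assumptions made for this example; they are not part of the documented API.

```python
def split_inputs_and_targets(dataset, input_keys=("R",)):
    """Hypothetical helper: validate a dataset and split it by role."""
    if not dataset:
        raise ValueError("Dataset is empty.")

    # Every key must provide exactly one entry per snapshot.
    lengths = {key: len(values) for key, values in dataset.items()}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"Inconsistent number of snapshots: {lengths}")

    missing = [key for key in input_keys if key not in dataset]
    if missing:
        raise KeyError(f"Missing input keys: {missing}")

    # Assumption: everything that is not an input is a training target.
    inputs = {k: v for k, v in dataset.items() if k in input_keys}
    targets = {k: v for k, v in dataset.items() if k not in input_keys}
    return inputs, targets


# Two snapshots of a single particle in 1D (made-up numbers).
dataset = {"R": [[0.0], [1.0]], "F": [[-0.5], [0.5]]}
inputs, targets = split_inputs_and_targets(dataset)
print(sorted(inputs), sorted(targets))
```

Failing early on mismatched snapshot counts keeps shape errors out of the training loop, where they are harder to trace back to the data.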

AtomisticDataset()[source]#

Atomistic data for force-matching.

Parameters:
  • R – Particle positions

  • U – Potential energies

  • F – Forces

  • p – Pressures

  • kT – Temperatures

build_dataset(position_data, energy_data=None, force_data=None, virial_data=None, kt_data=None, **extra_data)[source]#

Builds the force-matching dataset depending on available data.

Example

For force matching, the reference data consist of particle positions and target forces.

>>> from chemtrain.learn.force_matching import build_dataset
>>> position_data = [...]
>>> force_data = [...]

The training dataset can then be created via:

>>> dataset = build_dataset(
...     position_data=position_data, force_data=force_data)
>>> print(dataset)
{'R': [Ellipsis], 'F': [Ellipsis]}

Return type:

Tuple[AtomisticDataset]

Returns:

Returns the canonicalized dataset and a list of keys specifying the trainable targets.

Model#

The input to the learning problems is always an energy_fn_template function. To match forces and/or virials, the computation of these quantities from the energy function must first be initialized, e.g., using the following functions.

init_model(nbrs_init, quantities, state_from_input=None, feature_extract_fns=None)[source]#

Initialize prediction function for a single snapshot.

The prediction function computes the energy, force, and virial (if provided) based on a single conformation and returns the results in a canonical format.

Note

The prediction function does not check whether the neighbor list overflowed.

Parameters:
  • nbrs_init (NeighborList) – Initial neighbor list.

  • quantities (Dict[str, ComputeFn]) – Dictionary of snapshot functions, e.g., energy and forces.

  • state_from_input (Callable) – Function to build a system state from the input data. Not necessary if the state is already a key in the observations.

  • feature_extract_fns (Dict[str, Callable]) – Additional quantities, computed before the snapshots and available to all snapshot compute functions.

Returns:

Returns a function that computes snapshots given energy parameters and observations (inputs).
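
To make the relation between the energy function and the derived quantities concrete, the following sketch evaluates a toy harmonic energy for one snapshot and obtains the forces as the negative gradient of the energy, F = -dU/dR. It uses plain Python with central finite differences only to stay dependency-free, and it does not reproduce `init_model` (there is no neighbor list and no energy parameters). The names `harmonic_energy` and `predict_snapshot` and the output keys `'U'` and `'F'` are assumptions for this illustration.

```python
def harmonic_energy(positions, k=1.0, r0=1.0):
    """Toy energy: harmonic bonds between consecutive particles in 1D."""
    energy = 0.0
    for left, right in zip(positions[:-1], positions[1:]):
        energy += 0.5 * k * (abs(right - left) - r0) ** 2
    return energy


def predict_snapshot(energy_fn, positions, eps=1e-5):
    """Hypothetical predictor returning energy and forces for one snapshot."""
    forces = []
    for i in range(len(positions)):
        plus = list(positions)
        minus = list(positions)
        plus[i] += eps
        minus[i] -= eps
        # Central difference approximates dU/dR_i; the force is its negative.
        grad_i = (energy_fn(plus) - energy_fn(minus)) / (2.0 * eps)
        forces.append(-grad_i)
    return {"U": energy_fn(positions), "F": forces}


# Two particles stretched beyond the rest length r0 = 1.
prediction = predict_snapshot(harmonic_energy, [0.0, 1.5])
print(prediction)
```

In this example the stretched bond pulls the two particles toward each other, so the forces are equal in magnitude and opposite in sign.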

Loss#

These are functions to initialize advanced loss functions, e.g., combining the losses for force and energy predictions into a single loss value.

init_loss_fn(error_fns=None, individual=True, gammas=None, weights_keys=None)[source]#

Initializes loss function for energy/force matching.

Parameters:
  • error_fns (Union[ErrorFn, dict[str, ErrorFn]]) – Functions quantifying the deviation between the model predictions and the targets. Defaults to mean-squared error functions.

  • individual (bool) – Return the loss values for the individual targets, e.g., for testing purposes. If False, the loss function returns a scalar loss value from the individual loss contributions, weighted by the gammas coefficients.

  • gammas (dict[str, float]) – Weights for the per-target losses in the total loss.

  • weights_keys (Dict[str, str]) – Dictionary specifying weight keys in the dataset for individual targets. The weights determine the per-sample contribution for the specific target.

Returns:

Returns a function loss_fn(predictions, targets), which returns a scalar loss value for a batch of predictions and targets, or the per-target loss values if individual is True.
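
The role of `individual` and `gammas` can be sketched as follows: the per-target mean-squared errors are either returned separately or combined into a gamma-weighted sum. This is a simplified, dependency-free illustration rather than the library's implementation. The names `make_loss_fn` and `mse` are hypothetical, targets are flat lists of floats, and per-sample weights (`weights_keys`) are left out.

```python
def mse(predictions, targets):
    """Mean-squared error between two equally long flat lists."""
    if len(predictions) != len(targets):
        raise ValueError("Predictions and targets differ in length.")
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


def make_loss_fn(gammas, individual):
    """Hypothetical factory for a gamma-weighted multi-target loss."""
    def loss_fn(predictions, targets):
        # One error value per target key, e.g. 'U' and 'F'.
        per_target = {
            key: mse(predictions[key], targets[key]) for key in gammas
        }
        if individual:
            return per_target
        # Scalar loss: sum of per-target errors weighted by gammas.
        return sum(gammas[key] * per_target[key] for key in gammas)
    return loss_fn


# Made-up batch: two energies and four force components.
predictions = {"U": [1.0, 2.0], "F": [0.0, 0.5, -0.5, 0.0]}
targets = {"U": [1.0, 4.0], "F": [0.0, 0.5, 0.5, 0.0]}
gammas = {"U": 0.1, "F": 1.0}

total = make_loss_fn(gammas, individual=False)(predictions, targets)
parts = make_loss_fn(gammas, individual=True)(predictions, targets)
print(total, parts)
```

Energies and forces usually differ in scale and in the number of entries per snapshot, which is why a per-target weight is needed before the contributions are summed.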