learn.force_matching
Functions for direct learning of per-snapshot quantities.
Directly learnable quantities are, for example, energy, forces, or virial pressure.
Dataset
Utility functions to process datasets of per-snapshot quantities. These utilities check the dataset for consistent keys, differentiating between inputs to the model (e.g., atomic positions) and targets (e.g., atomic forces).
- AtomisticDataset()
Atomistic data for force-matching.
- Parameters:
R – Particle positions
U – Potential energies
F – Forces
p – Pressures
kT – Temperatures
- build_dataset(position_data, energy_data=None, force_data=None, virial_data=None, kt_data=None, **extra_data)
Builds the force-matching dataset depending on available data.
Example
For force matching, the reference data consist of particle positions and target forces.
>>> from chemtrain.learn.force_matching import build_dataset
>>> position_data = [...]
>>> force_data = [...]
The dataset for training can be created via:
>>> dataset = build_dataset(
...     position_data=position_data, force_data=force_data)
>>> print(dataset)
{'R': [Ellipsis], 'F': [Ellipsis]}
- Parameters:
position_data (Union[Array,ndarray,bool,number,bool,int,float,complex]) – Reference particle positions
energy_data (Union[Array,ndarray,bool,number,bool,int,float,complex]) – Reference potential energies
force_data (Union[Array,ndarray,bool,number,bool,int,float,complex]) – Reference forces
virial_data (Union[Array,ndarray,bool,number,bool,int,float,complex]) – Reference virials
kt_data (Union[Array,ndarray,bool,number,bool,int,float,complex]) – Reference temperatures
- Return type:
- Returns:
Returns the canonicalized dataset and a list of keys specifying the trainable targets.
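The canonicalization described above can be pictured with a minimal, self-contained sketch. It is not the chemtrain implementation: `build_dataset_sketch`, `KEY_MAP`, and `TARGET_KEYS` are hypothetical names, NumPy stands in for the array library, and the mapping of `virial_data` to the key `p` as well as the choice of which keys count as targets are assumptions based on the key names listed for AtomisticDataset.

```python
import numpy as np

# Assumed mapping from keyword arguments to dataset keys (hypothetical).
KEY_MAP = {
    "energy_data": "U",
    "force_data": "F",
    "virial_data": "p",
    "kt_data": "kT",
}
# Assumption: energies, forces, and pressures are the trainable targets.
TARGET_KEYS = ("U", "F", "p")


def build_dataset_sketch(position_data, **optional_data):
    """Collect the provided arrays in a dict and list the target keys."""
    dataset = {"R": np.asarray(position_data)}
    for name, data in optional_data.items():
        if data is None:
            continue  # quantity not available, so it is left out
        key = KEY_MAP.get(name, name)  # extra data keeps its own name
        dataset[key] = np.asarray(data)

    # Every entry must provide exactly one value per snapshot.
    n_snapshots = dataset["R"].shape[0]
    for key, value in dataset.items():
        if value.shape[0] != n_snapshots:
            raise ValueError(
                f"'{key}' has {value.shape[0]} snapshots, "
                f"expected {n_snapshots}"
            )

    target_keys = [key for key in TARGET_KEYS if key in dataset]
    return dataset, target_keys


positions = np.zeros((4, 3, 3))  # 4 snapshots, 3 particles, 3 dimensions
forces = np.ones((4, 3, 3))
dataset, target_keys = build_dataset_sketch(positions, force_data=forces)
```

In this sketch, passing arrays with mismatched numbers of snapshots raises a ValueError, which mirrors the consistency check mentioned in the section introduction.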
Model
The input to the learning problems is always an energy_fn_template function.
To match forces and/or virials, the computation of these quantities from the
energy function must first be initialized, e.g., using the following
functions.
- init_model(nbrs_init, quantities, state_from_input=None, feature_extract_fns=None)
Initialize prediction function for a single snapshot.
The prediction function computes the energy, force, and virial (if provided) based on a single conformation and returns the results in a canonical format.
Note
The prediction function does not check whether the neighbor list overflowed.
- Parameters:
nbrs_init (NeighborList) – Initial neighbor list.
quantities (Dict[str,ComputeFn]) – Dictionary of snapshot functions, e.g., energy and forces.
state_from_input (Callable) – Function to build a system state from the input data. Not necessary if the state is already a key in the observations.
feature_extract_fns (Dict[str,Callable]) – Additional quantities, computed before the snapshots and available to all snapshot compute functions.
- Returns:
Returns a function that computes snapshots given energy parameters and observations (inputs).
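The relation behind such a prediction function is that forces follow from the energy function as F = -∂U/∂R. The sketch below illustrates this for a single snapshot under loud assumptions: `harmonic_energy`, `forces_from_energy`, and `predict_snapshot` are hypothetical helpers, the toy energy is not a chemtrain model, and central finite differences in NumPy stand in for the automatic differentiation an actual implementation would more likely use. Neighbor lists are omitted entirely.

```python
import numpy as np


def harmonic_energy(positions, k=1.0, r0=1.0):
    """Toy energy: harmonic bonds between consecutive particles."""
    bond_vectors = positions[1:] - positions[:-1]
    bond_lengths = np.linalg.norm(bond_vectors, axis=-1)
    return 0.5 * k * np.sum((bond_lengths - r0) ** 2)


def forces_from_energy(energy_fn, positions, eps=1e-5):
    """Approximate F = -dU/dR with central finite differences."""
    forces = np.zeros_like(positions)
    for index in np.ndindex(positions.shape):
        shift = np.zeros_like(positions)
        shift[index] = eps  # displace one coordinate at a time
        energy_plus = energy_fn(positions + shift)
        energy_minus = energy_fn(positions - shift)
        forces[index] = -(energy_plus - energy_minus) / (2.0 * eps)
    return forces


def predict_snapshot(positions):
    """Return the per-snapshot predictions in one dictionary."""
    return {
        "U": harmonic_energy(positions),
        "F": forces_from_energy(harmonic_energy, positions),
    }


# Two particles at distance 1.5 with rest length 1.0: the bond is stretched,
# so the forces pull the particles towards each other.
positions = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]])
prediction = predict_snapshot(positions)
```

Returning all quantities under fixed keys such as `U` and `F` lets a loss function compare predictions and targets key by key.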
Loss
These are functions to initialize advanced loss functions, e.g., combining the losses for force and energy predictions into a single loss value.
- init_loss_fn(error_fns=None, individual=True, gammas=None, weights_keys=None)
Initializes loss function for energy/force matching.
- Parameters:
error_fns (Union[ErrorFn,dict[str,ErrorFn]]) – Functions quantifying the deviation between the model predictions and the targets. By default, mean-squared error functions.
individual (bool) – Return the loss values for the individual targets, e.g., for testing purposes. If False, the loss function returns a scalar loss value computed from the individual loss contributions, weighted by the gamma coefficients.
gammas (dict[str,float]) – Weights for the per-target losses in the total loss.
weights_keys (Dict[str,str]) – Dictionary specifying weight keys in the dataset for individual targets. The weights determine the per-sample contribution for the specific target.
- Returns:
Returns a function loss_fn(predictions, targets), which returns a scalar loss value for a batch of predictions and targets.
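How per-target losses might be combined into one value can be sketched as follows. This is an illustration under assumptions, not the library's implementation: `mse` and `init_loss_fn_sketch` are hypothetical names, NumPy replaces the actual array library, per-sample weights (`weights_keys`) are left out, and the default of `individual` is chosen for the example only.

```python
import numpy as np


def mse(predictions, targets):
    """Mean-squared error over all entries of a batch."""
    difference = np.asarray(predictions) - np.asarray(targets)
    return float(np.mean(difference ** 2))


def init_loss_fn_sketch(gammas, individual=False):
    """Build loss_fn(predictions, targets) from per-target weights."""

    def loss_fn(predictions, targets):
        # One error value per target key, e.g. "U" and "F".
        per_target = {
            key: mse(predictions[key], targets[key]) for key in gammas
        }
        if individual:
            return per_target
        # Weighted sum of the individual contributions.
        return sum(gammas[key] * per_target[key] for key in gammas)

    return loss_fn


predictions = {"U": np.array([1.0, 2.0]), "F": np.zeros((2, 3, 3))}
targets = {"U": np.array([1.0, 4.0]), "F": np.ones((2, 3, 3))}

loss_fn = init_loss_fn_sketch({"U": 0.5, "F": 2.0})
total = loss_fn(predictions, targets)

individual_fn = init_loss_fn_sketch({"U": 0.5, "F": 2.0}, individual=True)
contributions = individual_fn(predictions, targets)
```

Here the energy error is 2.0 and the force error is 1.0, so the weighted total is 0.5 * 2.0 + 2.0 * 1.0 = 3.0. Inspecting `contributions` separately helps to judge whether one target dominates the total loss.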