# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for direct learning of per-snapshot quantities.
Directly learnable quantities are, for example, energy, forces, or virial
pressure.
"""
import collections
from typing import Callable, TypedDict, Tuple, List, Dict, DefaultDict, Union
from typing_extensions import Required
from jax import vmap, value_and_grad, numpy as jnp
from jax_sgmc.data.numpy_loader import NumpyDataLoader
from chemtrain.learn import max_likelihood
from chemtrain.ensemble import evaluation
from jax_md_mod import custom_quantity
from chemtrain.typing import EnergyFnTemplate, ArrayLike, NeighborList, ErrorFn, \
ComputeFn
# Note:
# Computing the neighbor list in each snapshot is not efficient for DimeNet++,
# which constructs a sparse graph representation afterward. However, other
# models such as the tabulated potential are inefficient if used without
# neighbor list as many cut-off interactions are otherwise computed.
# For the sake of a simpler implementation, the slight inefficiency
# in the case of DimeNet++ is accepted for now.
[docs]
class AtomisticDataset(TypedDict, total=False):
    """Atomistic data for force-matching.

    Only the particle positions ``R`` are required. All other entries are
    optional and only present if the corresponding reference data is
    available.

    Attributes:
        R: Particle positions.
        U: Potential energies.
        F: Forces.
        p: Virial pressures (filled from ``virial_data`` by
            :func:`build_dataset`).
        kT: Temperatures.
    """
    R: Required[ArrayLike]
    U: ArrayLike
    F: ArrayLike
    p: ArrayLike
    kT: ArrayLike
[docs]
def build_dataset(position_data: ArrayLike,
                  energy_data: ArrayLike = None,
                  force_data: ArrayLike = None,
                  virial_data: ArrayLike = None,
                  kt_data: ArrayLike = None,
                  **extra_data) -> AtomisticDataset:
    """Builds the force-matching dataset depending on available data.

    Only reference data that is provided (i.e., not ``None``) is added to the
    dataset. Additional entries, e.g., per-sample weights, can be passed as
    keyword arguments and are added unchanged.

    Example:
        For force matching, the reference data consist of particle positions
        and target forces.

        >>> from chemtrain.learn.force_matching import build_dataset
        >>> position_data = [...]
        >>> force_data = [...]

        The dataset for training can be created via:

        >>> dataset = build_dataset(
        ...     position_data=position_data, force_data=force_data)
        >>> print(dataset)
        {'R': [Ellipsis], 'F': [Ellipsis]}

    Args:
        position_data: Reference particle positions, stored under ``'R'``.
        energy_data: Reference potential energies, stored under ``'U'``.
        force_data: Reference forces, stored under ``'F'``.
        virial_data: Reference virials, stored under ``'p'``.
        kt_data: Reference temperatures, stored under ``'kT'``.
        **extra_data: Additional entries, added to the dataset as-is. They
            are inserted last and therefore override the standard keys
            above if the names collide.

    Returns:
        Returns the canonicalized dataset as a single dictionary.
    """
    dataset = {'R': position_data}

    # Insertion order (U, F, p, kT) matches the canonical key order.
    optional_data = {
        'U': energy_data,
        'F': force_data,
        'p': virial_data,
        'kT': kt_data,
    }
    dataset.update({
        key: value for key, value in optional_data.items()
        if value is not None
    })

    dataset.update(extra_data)
    return dataset
def _split_targets_inputs(observation, quantities):
dynamic_kwargs, targets = {}, {}
for key in observation.keys():
if key in quantities:
targets[key] = observation[key]
else:
dynamic_kwargs[key] = observation[key]
assert set(quantities.keys()) == set(targets.keys()), (
'All trainig targets must be present in the observation data.'
)
return dynamic_kwargs, targets
def state_from_positions(input_dict: Dict[str, ArrayLike]):
    """Extracts the state of the system from the particle positions.

    The positions are removed from a shallow copy of ``input_dict``, so the
    caller's dictionary is left unmodified.

    Args:
        input_dict: Dictionary containing particle positions under ``'R'``.

    Returns:
        A tuple ``(state, remaining)``. ``state`` is an
        ``evaluation.SimpleState`` built from the positions, and
        ``remaining`` holds all other entries of ``input_dict``.
    """
    # Copy first: popping directly from ``input_dict`` would silently
    # remove the positions from the caller's data.
    remaining = input_dict.copy()
    state = evaluation.SimpleState(remaining.pop('R'))
    return state, remaining
# TODO: Initialize predictions for all kinds of quantities
[docs]
def init_model(nbrs_init: NeighborList,
               quantities: Dict[str, ComputeFn],
               state_from_input: Callable = None,
               feature_extract_fns: Dict[str, Callable] = None):
    """Initialize prediction function for a single snapshot.

    The prediction function computes the energy, force, and virial (if
    provided) based on a single conformation and returns the results in a
    canonical format.

    Note:
        The prediction function does not check whether the neighbor list
        overflowed.

    Args:
        nbrs_init: Initial neighbor list.
        quantities: Dictionary of snapshot functions, e.g., energy and forces.
            The dictionary is not modified. The energy (``'U'``) and force
            (``'F'``) entries are skipped for observations that do not
            contain the corresponding key.
        state_from_input: Function to build a system state from the input
            data. Defaults to :func:`state_from_positions`. Not necessary if
            the state is already a key in the observations.
        feature_extract_fns: Additional quantities, computed before the
            snapshots and available to all snapshot compute functions.

    Returns:
        Returns a function that computes snapshots given energy parameters and
        observations (inputs).
    """
    if feature_extract_fns is None:
        feature_extract_fns = {}
    if state_from_input is None:
        state_from_input = state_from_positions

    def fm_model(energy_params, observations):
        # Skip the default energy/force predictions if the dataset does not
        # provide them. Filter into a fresh dict on every call instead of
        # popping from ``quantities``. Popping would mutate the caller's
        # dictionary and permanently drop the quantity for all later calls,
        # e.g., for a subsequent batch that does contain forces.
        active_quantities = {
            key: compute_fn for key, compute_fn in quantities.items()
            if key not in ('U', 'F') or key in observations
        }
        dynamic_kwargs, _ = _split_targets_inputs(
            observations, active_quantities)

        # Provides the possibility to add a more detailed state of the
        # system, e.g., with velocities, box, etc.
        if 'state' in dynamic_kwargs:
            states = dynamic_kwargs.pop('state')
        else:
            states, dynamic_kwargs = vmap(state_from_input)(dynamic_kwargs)

        batch_size = states.position.shape[0]
        predictions = evaluation.quantity_map(
            states, active_quantities, nbrs_init, dynamic_kwargs, {},
            energy_params, batch_size, feature_extract_fns
        )
        return predictions

    return fm_model
[docs]
def init_loss_fn(error_fns: Union[ErrorFn, dict[str, ErrorFn]] = None,
                 individual: bool = True,
                 gammas: dict[str, float] = None,
                 weights_keys: Dict[str, str] = None):
    """Initializes loss function for energy/force matching.

    The energy (``'U'``) and force (``'F'``) contributions are included
    whenever the corresponding targets are present, with a default weight of
    1.0. All other targets contribute only if they have an entry in
    ``gammas``.

    Args:
        error_fns: Functions quantifying the deviation of the model and the
            targets. Either a single function used for all targets, or a
            dictionary mapping target keys to functions. If ``None`` or a
            dictionary, targets without an entry use the mean-squared error
            ``max_likelihood.mse_loss``.
        individual: Return the loss values for the individual targets, e.g.,
            for testing purposes. If False, the loss function returns only
            the scalar loss value, i.e., the sum of the individual loss
            contributions weighted by ``gammas``.
        gammas: Weights for the per-target losses in the total loss.
        weights_keys: Dictionary mapping a target key to the dataset key of
            its per-sample weights. The weights determine the per-sample
            contribution for the specific target. For ``'U'`` and ``'F'``, a
            weights key that is missing in the targets results in unweighted
            errors. For all other targets, it raises a ``KeyError``.

    Returns:
        Returns a function ``loss_fn(predictions, targets)``, which returns
        the scalar total loss for a batch of predictions and targets. If
        ``individual`` is True, it returns a tuple ``(loss, errors)``
        instead, where ``errors`` maps target keys to their individual
        errors.
    """
    if gammas is None:
        gammas = {}
    if weights_keys is None:
        weights_keys = {}

    # Preserve old behaviour if error_fns is a single function. The builtin
    # ``callable`` does not depend on ``collections.abc`` being imported.
    if callable(error_fns):
        _error_fns = collections.defaultdict(lambda: error_fns)
    else:
        _error_fns = collections.defaultdict(lambda: max_likelihood.mse_loss)
        if isinstance(error_fns, dict):
            _error_fns.update(error_fns)

    def loss_fn(predictions, targets):
        errors = {}
        loss_val = 0.

        # Energy and forces are optional targets. They are included with a
        # default gamma of 1.0 whenever present in the targets.
        for key in ('U', 'F'):
            if key not in targets:
                continue
            weights = targets.get(weights_keys.get(key))
            errors[key] = _error_fns[key](
                predictions[key], targets[key], weights=weights)
            loss_val += gammas.get(key, 1.0) * errors[key]

        # All other targets contribute only if a gamma is specified.
        for key, gamma in gammas.items():
            if key in ('U', 'F'):
                continue
            weights = None
            if key in weights_keys:
                weights = targets[weights_keys[key]]
            errors[key] = _error_fns[key](
                predictions[key], targets[key], weights=weights)
            loss_val += gamma * errors[key]

        if individual:
            return loss_val, errors
        return loss_val

    return loss_fn