Source code for chemtrain.ensemble.evaluation

# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Callable, NamedTuple

import numpy as onp
from jax import tree_util, numpy as jnp, vmap, lax, Array
from jax._src.basearray import ArrayLike

from jax_md_mod import custom_partition
from jax_md import simulate
from jax_md.partition import NeighborList

from chemtrain import util
from chemtrain.typing import QuantityDict

from typing import Protocol

"""Functions to compute snapshots quantities for many samples."""

[docs] class State(Protocol): """State of a molecular system. All states must at least prescribe the particle positions. Other attributes, such as velocities, forces, etc., might be necessary for some instantaneous quantities. Attributes: position: Particle positions """ position: ArrayLike
[docs] class SimpleState(NamedTuple): """Simplest state of a molecular system. Args: position: Particle positions """ position: ArrayLike
[docs] def quantity_map(states: State, quantities: QuantityDict, nbrs: NeighborList = None, state_kwargs: Dict[str, Array] = None, constant_state_kwargs: Dict[str, Array] = None, energy_params: Any = None, batch_size: int = 1, feature_extract_fns: Dict[str, Callable] = None): """Computes quantities of interest for all states in a trajectory. Arbitrary quantity functions can be provided via the quantities-dict. The quantities dict provides the function to compute the quantity on a single snapshot. The resulting quantity trajectory will be saved in a dict under the same key as the input quantity function. Example usage: .. code-block:: python def custom_compute_fn(state, neighbor=None, feature=None, **kwargs): ... return quantity_snapshot quantities = { 'energy': custom_quantity.energy_wrapper(energy_template_fn), 'custom_quantity': custom_compute_fn } # Results will be available to all snapshot compute functions feature_extract_fns = { 'feature': custom_feature_compute_fn } quantity_trajs = quantity_map( trajectory, quantities, reference_nbrs, dynamic_kwargs, energy_params, feature_extract_fns=feature_extract_fns ) custom_quantity = quantity_trajs['custom_quantity'] Args: states: System states, concatenated along the first dimensions of the arrays. quantities: The quantity dict containing for each target quantity the snapshot compute function nbrs: Reference neighbor list to compute new neighbor list state_kwargs: Kwargs to supply reference ``'kT'`` and/or ``'pressure'`` to the energy function or the quantity functions. constant_state_kwargs: Kwargs to supply information to the energy function that is constant over all states. energy_params: Energy params for energy_fn_template to initialize the current energy_fn batch_size: Number of batches for vmap feature_extract_fns: Callables to compute features accessible to all snapshot compute functions. Returns: A dict of quantity trajectories saved under the same key as the input quantity function. """ return quantity_multimap( states, quantities=quantities, nbrs=nbrs, state_kwargs=state_kwargs, constant_state_kwargs=constant_state_kwargs, energy_params=energy_params, batch_size=batch_size, feature_extract_fns=feature_extract_fns)
[docs] def quantity_multimap(*states: State, quantities: QuantityDict, nbrs: NeighborList = None, state_kwargs: Dict[str, Array] = None, constant_state_kwargs: Dict[str, Array] = None, energy_params: Any = None, batch_size: int = 1, feature_extract_fns: Dict[str, Callable] = None): """Computes quantities of interest for all states in a trajectory. This function extends :func:`quantity_map` to quantities with respect to multiple reference states. Therefore, the quantity function signature changes to .. code-block:: python def quantity_fn(*states, neighbor=None, energy_params=None, **kwargs): ... The keywords arguments, i.e. the neighbor list, are with respect to the first state of `*states`. Args: states: System states, concatenated along the first dimensions of the arrays. quantities: The quantity dict containing for each target quantity the snapshot compute function nbrs: Reference neighbor list to compute new neighbor list state_kwargs: Kwargs to supply reference ``'kT'`` and/or ``'pressure'`` to the energy function or the quantity functions. constant_state_kwargs: Kwargs to supply information to the energy function that is constant over all states. energy_params: Energy params for energy_fn_template to initialize the current energy_fn batch_size: Number of batches for vmap feature_extract_fns: Callables to compute features accessible to all snapshot compute functions. Returns: A dict of quantity trajectories saved under the same key as the input quantity function. """ nbrs_update_fn = nbrs.update_fn # Check that all states have the same format if state_kwargs is None: state_kwargs = {} if constant_state_kwargs is None: constant_state_kwargs = {} assert len(states) > 0, 'Need at least one trajectory.' ref_leaves, ref_struct = tree_util.tree_flatten(states[0]) for traj in states: assert ref_struct == tree_util.tree_structure(traj), ( "All trajectory states must have the same tree structure." ) assert onp.all([ jnp.shape(l) == jnp.shape(r) for r, l in zip(ref_leaves, tree_util.tree_leaves(traj)) ]), "All trajectory state leaves must be of identical shape." # Extract additional features, making them accessible to all snapshot # compute functions. For example, when predicting molecular properties # using a neural network. if feature_extract_fns is None: feature_extract_fns = {} else: assert len(states) == 1, ( "Feature extraction functions are only supported for single " "trajectory." ) @vmap def single_state_quantities(single_snapshot): states, kwargs = single_snapshot kwargs.update(energy_params=energy_params) kwargs.update(constant_state_kwargs) # Add a masked neighbor list if masked and neighbor list are provided if util.is_npt_ensemble(states): box = simulate.npt_box(states[0]) kwargs['box'] = box if nbrs is not None: new_nbrs = nbrs_update_fn(states[0].position, nbrs, **kwargs) mask = kwargs.get( "mask", jnp.ones(new_nbrs.reference_position.shape[0])) kwargs["neighbor"] = custom_partition.mask_neighbor_list( new_nbrs, mask) # Extract additional features to all snapshot computation functions, # e.g., the neighbor list graph. Next features can be beased on # previously computed features. for key in feature_extract_fns.keys(): kwargs[key] = feature_extract_fns[key](states[0], **kwargs) if len(states) == 1: computed_quantities = { quantity_fn_key: quantities[quantity_fn_key](states[0], **kwargs) for quantity_fn_key in quantities } else: computed_quantities = { quantity_fn_key: quantities[quantity_fn_key](*states, **kwargs) for quantity_fn_key in quantities } return computed_quantities # If the batch size is larger than the number of samples, reduce the # batch size to compute all samples in one batch. # If the batch size does not divide the number of samples, we compute # the remainder in a separate call to avoid padding. ipt = (states, state_kwargs) if states[0].position.shape[0] < batch_size: batch_size = states[0].position.shape[0] remainder = states[0].position.shape[0] % batch_size if remainder > 0: rmd = tree_util.tree_map(lambda x: x[-remainder:], ipt) ipt = tree_util.tree_map(lambda x: x[:-remainder], ipt) batched_samples = util.tree_vmap_split(ipt, batch_size) batched_quantity_trajs = lax.map(single_state_quantities, batched_samples) quantity_trajs = util.tree_combine(batched_quantity_trajs) if remainder > 0: rmd_trajs = single_state_quantities(rmd) quantity_trajs = tree_util.tree_map( lambda a, b: jnp.concatenate([a, b], axis=0), quantity_trajs, rmd_trajs ) return quantity_trajs