Source code for chemtrain.data.preprocessing

# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common operations to pre-process datasets for machine learning potentials.

Typical pre-processing steps for molecular reference data are subsampling,
splitting into training, validation and testing sets, as well as scaling
the positions into fractional coordinates.

Examples:

    An example of loading a small subset of a heavy-atom trajectory of alanine:

    >>> from pathlib import Path
    >>> root = Path.cwd().parent

    >>> import jax.numpy as jnp
    >>> from jax_md_mod import io
    >>> from chemtrain.data.data_loaders import init_dataloaders
    >>> from chemtrain.data.preprocessing import (
    ...     get_dataset, scale_dataset_fractional, train_val_test_split )

    We only get a subset of 10 conformations from the training data and scale the
    conformations to fractional coordinates:

    >>> box = jnp.ones(3)
    >>> position_data = get_dataset(
    ...     root / "examples/data/positions_ethane.npy",
    ...     retain=10)
    >>> force_data = get_dataset(
    ...     root / "examples/data/forces_ethane.npy",
    ...     retain=10)
    >>> position_data = scale_dataset_fractional(position_data, box)

    We split the dataset into a training, validation and testing set:

    >>> train, val, test = train_val_test_split(position_data, train_ratio=0.8, shuffle=False)
    >>> bool(jnp.all(train[0, 4, :] == position_data[0, 4, :]))
    True

    Alternatively, we can directly instantiate ``jax_sgmc`` data-loaders based
    on the split datasets by using:

    >>> dataset = {"positions": position_data, "forces": force_data}
    >>> train_loader, val_loader, test_loader = init_dataloaders(dataset)
    >>> print(train_loader.static_information)
    {'observation_count': 7}

"""

import functools

import numpy as onp

import jax
from jax import numpy as jnp, lax, tree_util

from jax_md_mod import custom_space, custom_partition
from jax_md import partition, space
from chemtrain import util

from typing import Tuple, Optional



def get_dataset(data_location_str, retain=0, subsampling=1, offset=0):
    """Loads dataset from a ``"*.npy"`` array file.

    The selected window is counted from the end of the loaded array: the last
    ``offset`` samples are skipped, the ``retain`` samples in front of them
    are selected and every ``subsampling``-th sample of this window is kept.

    Args:
        data_location_str: String of ``"*.npy"`` data location. Passed
            directly to ``numpy.load``.
        retain: Number of samples to keep in the dataset. All by default.
        subsampling: Only keep every n-th sample of the data.
        offset: Number of samples to skip at the end of the data. Selects
            which part of data is used. Last part by default.

    Returns:
        Sub-sampled array of reference data.

    Raises:
        AssertionError: If ``retain`` (plus ``offset``) exceeds the number of
            samples in the loaded array.
    """
    loaded_data = onp.load(data_location_str)
    n_samples = loaded_data.shape[0]

    if offset == 0:
        assert retain <= n_samples, (
            f"Cannot retain more than {n_samples} samples, got "
            f"retain = {retain}."
        )
        # ``-0 == 0``: the default ``retain=0`` starts at the first sample
        # and therefore keeps the complete dataset.
        return loaded_data[-retain::subsampling]

    # Validate against the size of the full dataset *before* slicing.
    assert retain + offset <= n_samples, (
        f"Cannot retain more than {n_samples - offset} given "
        f"an offset of {offset}. Got retain = {retain}."
    )

    # ``retain == 0`` means "keep everything" in front of the offset. A start
    # index of ``-offset`` would produce an empty window instead.
    start = -(retain + offset) if retain else None
    return loaded_data[start:-offset:subsampling]
def train_val_test_split(dataset, train_ratio=0.7, val_ratio=0.1,
                         shuffle=False, shuffle_seed=0):
    """Split data into disjoint subsets for training, validation and testing.

    Splitting works on arbitrary pytrees, including chex.dataclasses,
    dictionaries, and single arrays. The function splits the pytree leaves
    along their first dimension. The number of samples is read from the first
    leaf. If a subset turns out to be empty, ``None`` is returned for the
    respective subset.

    Args:
        dataset: Dataset as pytree. Samples are assumed to be stacked along
            the first dimension of the pytree leaves.
        train_ratio: Fraction of dataset to use for training.
        val_ratio: Fraction of dataset to use for validation.
        shuffle: If True, shuffles data before splitting into train-val-test.
            Shuffling copies the dataset.
        shuffle_seed: PRNG Seed for data shuffling

    Returns:
        Returns a tuple ``(train_data, val_data, test_data)``, where each
        tuple element has the same pytree structure as the input pytree.
        The test set receives all samples not assigned to the other two.
    """
    assert train_ratio + val_ratio <= 1., 'Distribution of data exceeds 100%.'

    n_samples = tree_util.tree_leaves(dataset)[0].shape[0]
    train_end = int(n_samples * train_ratio)
    val_end = train_end + int(n_samples * val_ratio)

    def _none_if_empty(subset):
        # Empty subsets are signalled by ``None`` instead of empty arrays.
        n_subset = tree_util.tree_leaves(subset)[0].shape[0]
        return None if n_subset == 0 else subset

    if shuffle:
        # Shuffle an index array and gather the samples of each subset.
        permutation = onp.arange(n_samples)
        onp.random.default_rng(shuffle_seed).shuffle(permutation)
        index_sets = (
            permutation[:train_end],
            permutation[train_end:val_end],
            permutation[val_end:],
        )
        return tuple(
            _none_if_empty(util.tree_take(dataset, idxs, axis=0))
            for idxs in index_sets
        )

    # Without shuffling, the subsets are contiguous slices of the dataset.
    bounds = ((0, train_end), (train_end, val_end), (val_end, None))
    return tuple(
        _none_if_empty(
            util.tree_get_slice(dataset, start, end, to_device=False))
        for start, end in bounds
    )
def scale_dataset_fractional(positions, reference_box=None, box=None):
    """Scales a dataset of positions from real space to fractional coordinates.

    Args:
        positions: An array with shape ``(N_snapshots, N_particles, 3)`` with
            particle positions.
        reference_box: A 1 or 2-dimensional ``jax_md`` box. If not provided,
            the box is assumed to be dynamic and an identity matrix matching
            the last dimension of ``positions`` is used as reference.
        box: An array of 1 or 2-dimensional boxes, corresponding to the
            individual samples.

    Returns:
        Returns an array with shape ``(N_snapshots, N_particles, 3)`` with
        particle positions in fractional coordinates.
    """
    if reference_box is None:
        spatial_dim = positions.shape[-1]
        reference_box = jnp.eye(spatial_dim)

    _unused, scale_fn = custom_space.init_fractional_coordinates(
        reference_box)

    # Without per-sample boxes, every snapshot is scaled identically.
    if box is None:
        return jax.vmap(scale_fn)(positions)

    def _scale_with_box(sample_positions, sample_box):
        return scale_fn(sample_positions, box=sample_box)

    return jax.vmap(_scale_with_box)(positions, box)
def scale_dataset(dataset, scale_R, scale_U, scale_e, fractional=True):
    """Rescales positions, energies, forces, charges and dipoles of a dataset.

    Intended for unit conversions such as Hartree to kJ/mol and Bohr to nm;
    the actual factors are supplied by the caller. The dataset dictionary is
    modified in place and additionally receives a cubic ``'box'`` per sample
    with an edge length of ten times the extent of all positions.

    Args:
        dataset: Dictionary with the keys ``'R'``, ``'U'``, ``'F'``,
            ``'charge'`` and ``'dipole'``.
        scale_R: Scaling factor for lengths.
        scale_U: Scaling factor for energies.
        scale_e: Scaling factor for charges.
        fractional: If True, positions are divided by the box length instead
            of being multiplied by ``scale_R``.

    Returns:
        The same dictionary object with rescaled entries.
    """
    positions = dataset["R"]
    box = 10 * (positions.max() - positions.min())

    dataset['R'] = positions / box if fractional else positions * scale_R

    print(f"Scale dataset by {scale_R} for R and {scale_U} for U.")

    n_samples = dataset['R'].shape[0]
    dataset['box'] = scale_R * onp.tile(box * onp.eye(3), (n_samples, 1, 1))

    # Forces are energies per length, dipoles are charges times length.
    factors = (
        ('U', scale_U),
        ('F', scale_U / scale_R),
        ('charge', scale_e),
        ('dipole', scale_e * scale_R),
    )
    for key, factor in factors:
        dataset[key] *= factor

    return dataset
def map_dataset(position_dataset,
                displacement_fn,
                shift_fn,
                c_map,
                d_map=None,
                force_dataset=None):
    """Maps fine-scaled positions and forces to a coarser scale.

    Uses the linear mapping from [Noid2008]_ to map fine-scaled positions and
    forces to coarse grained positions and forces via the relations:

    .. math::

        \\mathbf R_I = \\sum_{i \\in \\mathcal I_I} c_{Ii} \\mathbf r_i,\\quad
        \\text{and}\\quad
        \\mathbf{F}_I = \\sum_{i \\in \\mathcal I_I}
        \\frac{d_{Ii}}{c_{Ii}} \\mathbf f_i.

    Both mapping matrices are normalised such that each row sums to one.
    Positions are mapped through displacements with respect to a reference
    atom of each bead (the atom with the largest weight), such that the
    provided ``displacement_fn`` can account for boundary conditions.

    Args:
        position_dataset: Dataset of fine-scaled positions.
        displacement_fn: Function to compute the displacement between two
            sets of coordinates. Necessary to handle boundary conditions.
        shift_fn: Ensures that the produced coordinates remain in the box.
        c_map: Matrix $c_{Ii}$ defining the linear mapping of positions.
        d_map: Matrix $d_{Ii}$ defining the linear mapping of forces in
            combination with $c_{Ii}$.
        force_dataset: Dataset of fine-scaled forces.

    Returns:
        Returns the coarse-grained positions and, if ``force_dataset`` is
        provided, a tuple of coarse-grained positions and forces. The forces
        are ``None`` if no ``d_map`` was given.

    References:
        .. [Noid2008] W. G. Noid, Jhih-Wei Chu, Gary S. Ayton, Vinod Krishna,
           Sergei Izvekov, Gregory A. Voth, Avisek Das, Hans C. Andersen;
           *The multiscale coarse-graining method. I. A rigorous bridge
           between atomistic and coarse-grained models*. J. Chem. Phys.
           28 June 2008; 128 (24): 244114.
           https://doi.org/10.1063/1.2938860

    """
    # Normalise mapping weights
    c_norm = c_map / jnp.sum(c_map, axis=1, keepdims=True)
    d_norm = None
    if d_map is not None:
        d_norm = d_map / jnp.sum(d_map, axis=1, keepdims=True)

    # Choose reference for each CG bead. The choice does not depend on the
    # sample, so it is computed once outside of the mapped function.
    ref_idx = jnp.argmax(c_map, axis=1)

    # Weights d_Ii / c_Ii of the force mapping. Atoms that do not belong to a
    # bead (c_Ii = 0) are masked to avoid a division by zero.
    force_weights = None
    if d_norm is not None:
        mask = (c_norm > 0.0)
        safe_c = jnp.where(mask, c_norm, 1.0)
        force_weights = mask * d_norm / safe_c

    def _map_positions(pos):
        ref_positions = pos[ref_idx, :]

        # Compute displacements for each reference position and map
        disp = jax.vmap(
            lambda r: jax.vmap(lambda p: displacement_fn(p, r))(pos)
        )(ref_positions)

        # The weights must sum to one per bead. Otherwise, shifting the
        # reference by the weighted displacement does not reproduce
        # sum_i c_Ii r_i.
        cg_disp = jnp.einsum('Ii,Iid->Id', c_norm, disp)
        return jax.vmap(shift_fn)(ref_positions, cg_disp)

    if force_dataset is None:
        # map positions only
        return lax.map(_map_positions, position_dataset)

    def _map_single(sample):
        pos, forc = sample
        cg_forces = None
        if force_weights is not None:
            cg_forces = jnp.einsum('Ii,id->Id', force_weights, forc)
        return _map_positions(pos), cg_forces

    return lax.map(_map_single, (position_dataset, force_dataset))
def allocate_neighborlist(dataset,
                          displacement: space.DisplacementOrMetricFn,
                          box: space.Box,
                          r_cutoff: float,
                          capacity_multiplier: float = 1.0,
                          disable_cell_list: bool = True,
                          fractional_coordinates: bool = True,
                          format: partition.NeighborListFormat = partition.NeighborListFormat.Dense,
                          pairwise_distances: bool = True,
                          box_key: str = None,
                          mask_key: str = None,
                          reps_key: str = None,
                          batch_size: int = 1000,
                          init_kwargs: dict = None,
                          count_triplets: bool = False,
                          **static_kwargs
                          ) -> Tuple[partition.NeighborList,
                                     Tuple[int, int, float, Optional[int]]]:
    """Allocates an optimally sized neighbor list.

    The function counts the neighbors of all particles for every sample of
    the dataset and allocates the neighbor list on the sample that requires
    the largest capacity for the chosen format.

    .. doctest::

        >>> import jax.numpy as jnp
        >>> import jax_md_mod
        >>> from jax_md import space

        >>> # Example Dataset

    Args:
        dataset: A dictionary containing the dataset with key ``"R"`` for
            positions.
        displacement: A function `d(R_a, R_b)` that computes the displacement
            between pairs of points.
        box: Either a float specifying the size of the box, an array of
            shape `[spatial_dim]` specifying the box size for a cubic box in
            each spatial dimension, or a matrix of shape
            `[spatial_dim, spatial_dim]` that is upper triangular and
            specifies the lattice vectors of the box.
        r_cutoff: A scalar specifying the neighborhood radius.
        capacity_multiplier: A floating point scalar specifying the fractional
            increase in maximum neighborhood occupancy we allocate compared
            with the maximum in the example positions.
        disable_cell_list: An optional boolean. If set to `True` then the
            neighbor list is constructed using only distances. This can be
            useful for debugging but should generally be left as `False`.
        fractional_coordinates: An optional boolean. Specifies whether
            positions will be supplied in fractional coordinates in the unit
            cube, :math:`[0, 1]^d`. If this is set to True then the
            `box_size` will be set to `1.0` and the cell size used in the
            cell list will be set to `cutoff / box_size`.
        format: The format of the neighbor list; see the
            :meth:`NeighborListFormat` enum for details about the different
            choices for formats. Defaults to `Dense`.
        pairwise_distances: Computes pairwise distances between every
            particles for every sample. Must currently be ``True``.
        box_key: The key in the dataset dictionary that contains the box.
            If not provided, uses the box argument.
        mask_key: The key in the dataset dictionary that contains the mask.
            If not provided, all particles are considered valid.
        reps_key: The key in the dataset dictionary that contains the number
            of replicas a supercell. If set, the neighborlist will only
            contain edge senders from the first appearing replica.
        batch_size: Evaluate multiple samples in parallel.
        init_kwargs: Keyword arguments passed to the neighbor list
            allocation, e.g., to specify a capacity multiplier. The passed
            dictionary is not modified.
        count_triplets: An optional boolean. If set to `True`, the function
            will return the maximum number of triplets, similar to the
            maximum number of edges.
        **static_kwargs: kwargs that get threaded through the calculation of
            example positions.

    Returns:
        Returns a tuple ``(nbrs, stats)``. ``nbrs`` is a neighbor list that
        fits the dataset and ``stats`` contains the maximum number of
        neighbors per particle, the maximum number of edges per sample, the
        average number of neighbors and, if ``count_triplets`` is set, the
        maximum number of triplets.

    Raises:
        AssertionError: If ``pairwise_distances`` is ``False``.
        ValueError: If the neighbor list format is neither dense nor sparse.
    """
    # NOTE(review): ``box``, ``disable_cell_list``, ``fractional_coordinates``
    # and ``static_kwargs`` are not forwarded anywhere in this function.
    # Confirm against the signature of ``masked_neighbor_list`` whether they
    # should be passed on.

    # We use the masked neighbor list to avoid interference of masked particles
    # and required neighbor list capacity.
    neighbor_fn = custom_partition.masked_neighbor_list(
        displacement, r_cutoff, dr_threshold=None,
        capacity_multiplier=capacity_multiplier, format=format,
    )

    assert pairwise_distances, (
        "Currently, this function only works when computing distances between "
        "all pairs of particles (``pairwise_distances=True``)."
    )

    @jax.jit
    def find_max_neighbors_and_edges(dataset):

        def number_of_neighbors(sample):
            position, box, mask, reps = sample

            if box is None:
                metric = space.canonicalize_displacement_or_metric(
                    displacement)
            else:
                metric = space.canonicalize_displacement_or_metric(
                    functools.partial(displacement, box=box))

            pair_distances = space.map_product(metric)(position, position)
            n_particles = pair_distances.shape[0]

            # Find neighbors, discarding self-interactions and masked
            # particles.
            is_neighbor = pair_distances <= r_cutoff
            is_neighbor = jnp.logical_and(
                is_neighbor, ~jnp.eye(n_particles, dtype=jnp.bool_))

            # Invalid particles cannot receive or send edges.
            if mask is not None:
                is_neighbor = jnp.logical_and(
                    is_neighbor, mask[jnp.newaxis, :])
                is_neighbor = jnp.logical_and(
                    is_neighbor, mask[:, jnp.newaxis])

            # Remove all replicated receivers
            if reps is not None:
                # Without a mask, all particles are valid.
                n_valid = n_particles if mask is None else jnp.sum(mask)
                max_local = n_valid // reps
                # Only the particles of the first replica, i.e., the first
                # ``max_local`` rows, keep their edges.
                include = jnp.arange(n_particles) < max_local
                is_neighbor = jnp.where(
                    include[:, jnp.newaxis], is_neighbor, False)

            # Sets the number of neighbors to 0 for masked particles
            neighbors = jnp.sum(is_neighbor, axis=1)
            if mask is not None:
                neighbors *= mask

            extra_out = []
            if count_triplets:
                # Compute the number of triplets.
                # First, we evaluate whether the pair of nodes are connected
                # by an edge to the same node. The required arrays scale
                # cubically with the particle number and are thus only built
                # on request.
                ji, jk = jax.vmap(
                    functools.partial(jnp.meshgrid, indexing="ij")
                )(is_neighbor, is_neighbor)

                # We mask out pairs of identical edges.
                is_triplet = jnp.logical_and(ji, jk)
                is_triplet = jnp.logical_and(
                    is_triplet,
                    ~jnp.eye(
                        is_triplet.shape[0], dtype=jnp.bool_
                    )[jnp.newaxis, ...]
                )
                extra_out += [jnp.sum(is_triplet)]

            # Average over the valid particles only.
            avg_neighbors = jnp.mean(neighbors)
            if mask is not None:
                avg_neighbors /= jnp.mean(mask)

            max_neighbors = jnp.max(neighbors)
            max_edges = jnp.sum(neighbors)

            return max_neighbors, max_edges, avg_neighbors, *extra_out

        # We find the sample with the maximum number of neighbors or edges
        return util.batch_map(
            number_of_neighbors,
            (
                dataset["R"], dataset.get(box_key), dataset.get(mask_key),
                dataset.get(reps_key)
            ),
            batch_size=batch_size
        )

    n_neighbors, n_edges, avg_neighbors, *extra = find_max_neighbors_and_edges(
        dataset)

    print(
        f"The dataset has max. {jnp.max(n_neighbors)} neighbors per particle "
        f"and max. {jnp.max(n_edges)} edges in total.")

    if format == partition.Dense:
        # The maximum neighbors per particle determine the capacity of the
        # neighbor list.
        sample_idx = jnp.argmax(n_neighbors)
    elif format == partition.Sparse:
        # The maximum number of edges determine the capacity of the neighbor
        # list.
        sample_idx = jnp.argmax(n_edges)
    else:
        raise ValueError(f"Unsupported neighbor list format: {format}")

    extra_out = []
    if count_triplets:
        n_triplets, = extra
        extra_out += [jnp.max(n_triplets)]

    # Copy the keyword arguments to not modify the dictionary of the caller.
    init_kwargs = {} if init_kwargs is None else dict(init_kwargs)
    if box_key is not None:
        init_kwargs['box'] = jnp.asarray(dataset[box_key][sample_idx])
    if mask_key is not None:
        init_kwargs['mask'] = jnp.asarray(dataset[mask_key][sample_idx])

    nbrs_init = neighbor_fn.allocate(
        jnp.asarray(dataset["R"][sample_idx]), **init_kwargs)

    return nbrs_init, (n_neighbors.max(), n_edges.max(),
                       avg_neighbors.mean(), *extra_out)