data.preprocessing
Common operations to pre-process datasets for machine learning potentials.
Typical pre-processing steps for molecular reference data are subsampling, splitting into training, validation and testing sets, as well as scaling the positions into fractional coordinates.
Examples
An example of loading a small subset of a reference trajectory, here the ethane position and force arrays in examples/data:
>>> from pathlib import Path
>>> root = Path.cwd().parent
>>> import jax.numpy as jnp
>>> from jax_md_mod import io
>>> from chemtrain.data.data_loaders import init_dataloaders
>>> from chemtrain.data.preprocessing import (
... get_dataset, scale_dataset_fractional, train_val_test_split )
We only get a subset of 10 conformations from the training data and scale the conformations to fractional coordinates:
>>> box = jnp.ones(3)
>>> position_data = get_dataset(
... root / "examples/data/positions_ethane.npy",
... retain=10)
>>> force_data = get_dataset(
... root / "examples/data/forces_ethane.npy",
... retain=10)
>>> position_data = scale_dataset_fractional(position_data, box)
We split the dataset into a training, validation and testing set:
>>> train, val, test = train_val_test_split(position_data, train_ratio=0.8, shuffle=False)
>>> bool(jnp.all(train[0, 4, :] == position_data[0, 4, :]))
True
Alternatively, we can directly instantiate jax_sgmc data loaders from
the split datasets:
>>> dataset = {"positions": position_data, "forces": force_data}
>>> train_loader, val_loader, test_loader = init_dataloaders(dataset)
>>> print(train_loader.static_information)
{'observation_count': 7}
Load Data
- get_dataset(data_location_str, retain=0, subsampling=1, offset=0)
Loads dataset from a "*.npy" array file.
- Parameters:
data_location_str – String of "*.npy" data location.
retain – Number of samples to keep in the dataset. All by default.
subsampling – Only keep every n-th sample of the data.
offset – Select which part of data to be used. Last part by default.
- Returns:
Sub-sampled array of reference data.
Preprocess Data
- scale_dataset_fractional(positions, reference_box=None, box=None)
Scales a dataset of positions from real space to fractional coordinates.
- Parameters:
positions – An array with shape (N_snapshots, N_particles, 3) with particle positions.
reference_box – A 1 or 2-dimensional jax_md box. If not provided, the box is assumed to be dynamic.
box – An array of 1 or 2-dimensional boxes, corresponding to the individual samples.
- Returns:
Returns an array with shape (N_snapshots, N_particles, 3) with particle positions in fractional coordinates.
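The scaling itself can be illustrated with a minimal NumPy sketch. It is not the library's implementation: the helper name `to_fractional` is hypothetical, only a single static box is handled, and for a 2-dimensional box the sketch assumes that real-space positions are obtained by applying the box matrix to the fractional coordinates.

```python
import numpy as np


def to_fractional(positions, box):
    """Hypothetical helper: map real-space positions to fractional ones.

    positions: array of shape (N_snapshots, N_particles, 3).
    box: array of shape (3,) with box lengths, or (3, 3) with a box matrix.
    """
    positions = np.asarray(positions, dtype=float)
    box = np.asarray(box, dtype=float)
    if box.ndim == 1:
        # Orthorhombic box: divide each coordinate by its box length.
        return positions / box
    # Assumption: real = box @ fractional, hence fractional = inv(box) @ real.
    inv_box = np.linalg.inv(box)
    return np.einsum("ij,snj->sni", inv_box, positions)


# One snapshot with one particle in a 2 x 4 x 5 box.
real = np.array([[[1.0, 1.0, 1.0]]])
frac = to_fractional(real, [2.0, 4.0, 5.0])
print(frac.shape)
```

For an orthorhombic box both branches give the same result, since a diagonal box matrix simply divides each coordinate by its box length.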
- map_dataset(position_dataset, displacement_fn, shift_fn, c_map, d_map=None, force_dataset=None)
Maps fine-scaled positions and forces to a coarser scale.
Uses the linear mapping from [Noid2008] to map fine-scaled positions and forces to coarse-grained positions and forces via the relations:
\[
\mathbf{R}_I = \sum_{i \in \mathcal{I}_I} c_{Ii} \mathbf{r}_i,
\quad \text{and} \quad
\mathbf{F}_I = \sum_{i \in \mathcal{I}_I} \frac{d_{Ii}}{c_{Ii}} \mathbf{f}_i.
\]
- Parameters:
position_dataset – Dataset of fine-scaled positions.
displacement_fn – Function to compute the displacement between two sets of coordinates. Necessary to handle boundary conditions.
shift_fn – Ensures that the produced coordinates remain in the box.
c_map – Matrix $c_{Ii}$ defining the linear mapping of positions.
d_map – Matrix $d_{Ii}$ defining the linear mapping of forces in combination with $c_{Ii}$.
force_dataset – Dataset of fine-scaled forces.
- Returns:
Returns the coarse-grained positions and, if provided, coarse-grained forces.
References
[Noid2008] W. G. Noid, Jhih-Wei Chu, Gary S. Ayton, Vinod Krishna, Sergei Izvekov, Gregory A. Voth, Avisek Das, Hans C. Andersen; The multiscale coarse-graining method. I. A rigorous bridge between atomistic and coarse-grained models. J. Chem. Phys. 28 June 2008; 128 (24): 244114. https://doi.org/10.1063/1.2938860
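The two mapping relations can be written as a short NumPy sketch. This is an illustration under stated assumptions, not the library's implementation: the helper name `map_linear` is hypothetical, periodic boundaries are ignored (so no `displacement_fn` or `shift_fn` is used), and a missing `d_map` is assumed to fall back to `c_map`.

```python
import numpy as np


def map_linear(positions, c_map, forces=None, d_map=None):
    """Hypothetical helper: linear coarse-graining without periodic boundaries.

    positions, forces: arrays of shape (N_snapshots, N_atoms, 3).
    c_map, d_map: arrays of shape (N_sites, N_atoms).
    """
    c_map = np.asarray(c_map, dtype=float)
    # R_I = sum_i c_Ii r_i, evaluated for every snapshot.
    cg_positions = np.einsum("ca,sad->scd", c_map, positions)
    if forces is None:
        return cg_positions
    # Assumption of this sketch: without d_map, reuse c_map.
    d_map = c_map if d_map is None else np.asarray(d_map, dtype=float)
    # Weights d_Ii / c_Ii, restricted to atoms that belong to site I.
    weights = np.divide(d_map, c_map, out=np.zeros_like(c_map), where=c_map != 0)
    # F_I = sum_i (d_Ii / c_Ii) f_i.
    cg_forces = np.einsum("ca,sad->scd", weights, forces)
    return cg_positions, cg_forces


# Four atoms on the x-axis, grouped into two sites by center of mass.
masses = np.array([12.0, 1.0, 12.0, 1.0])
c_map = np.array([[12.0, 1.0, 0.0, 0.0], [0.0, 0.0, 12.0, 1.0]]) / 13.0
positions = np.zeros((1, 4, 3))
positions[0, :, 0] = [0.0, 1.0, 2.0, 3.0]
forces = np.ones((1, 4, 3))
cg_positions, cg_forces = map_linear(positions, c_map, forces)
```

With `d_map` equal to `c_map`, every weight inside a site is 1, so the coarse-grained force is the sum of the atomistic forces in that site.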
Create Splits
- train_val_test_split(dataset, train_ratio=0.7, val_ratio=0.1, shuffle=False, shuffle_seed=0)
Split data into disjoint subsets for training, validation and testing.
Splitting works on arbitrary pytrees, including chex.dataclasses, dictionaries, and single arrays.
The function splits the pytree leaves along their first dimension.
If a subset ratio is 0, returns None for the respective subset.
- Parameters:
dataset – Dataset as pytree. Samples are assumed to be stacked along the first dimension of the pytree leaves.
train_ratio – Fraction of dataset to use for training.
val_ratio – Fraction of dataset to use for validation.
shuffle – If True, shuffles data before splitting into train-val-test. Shuffling copies the dataset.
shuffle_seed – PRNG seed for data shuffling.
- Returns:
Returns a tuple (train_data, val_data, test_data), where each tuple element has the same pytree structure as the input pytree.
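The splitting logic described above can be sketched for the special case of a dictionary of NumPy arrays. The helper name `split_sketch`, the rounding of subset sizes, and the use of NumPy's random generator are assumptions of this sketch; the library function works on general pytrees and may compute the subset sizes differently.

```python
import numpy as np


def split_sketch(dataset, train_ratio=0.7, val_ratio=0.1, shuffle=False, seed=0):
    """Hypothetical helper: split a dict of arrays along the first axis."""
    n_samples = len(next(iter(dataset.values())))
    indices = np.arange(n_samples)
    if shuffle:
        # Shuffle sample indices once so that all leaves stay aligned.
        indices = np.random.default_rng(seed).permutation(n_samples)
    # Assumption: subset sizes are rounded to the nearest integer.
    n_train = int(round(n_samples * train_ratio))
    n_val = int(round(n_samples * val_ratio))
    parts = (
        indices[:n_train],
        indices[n_train:n_train + n_val],
        indices[n_train + n_val:],
    )

    def take(selection):
        # An empty subset is returned as None.
        if len(selection) == 0:
            return None
        return {key: leaf[selection] for key, leaf in dataset.items()}

    return tuple(take(selection) for selection in parts)


data = {
    "positions": np.arange(60.0).reshape(10, 2, 3),
    "forces": np.zeros((10, 2, 3)),
}
train, val, test = split_sketch(data)
print(len(train["positions"]), len(val["positions"]), len(test["positions"]))
```

Because the same index selection is applied to every leaf, positions and forces of a given sample always end up in the same subset.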
Create Neighborlists
- allocate_neighborlist(dataset, displacement, box, r_cutoff, capacity_multiplier=1.0, disable_cell_list=True, fractional_coordinates=True, format=NeighborListFormat.Dense, pairwise_distances=True, box_key=None, mask_key=None, reps_key=None, batch_size=1000, init_kwargs=None, count_triplets=False, **static_kwargs)
Allocates an optimally sized neighbor list.
>>> import jax.numpy as jnp
>>> import jax_md_mod
>>> from jax_md import space
>>> # Example Dataset
- Parameters:
dataset – A dictionary containing the dataset with key "R" for positions.
displacement (Union[Callable[[Array, Array], Array], Callable[[Array, Array], float]]) – A function d(R_a, R_b) that computes the displacement between pairs of points.
box (Array) – Either a float specifying the size of the box, an array of shape [spatial_dim] specifying the box size for a cubic box in each spatial dimension, or a matrix of shape [spatial_dim, spatial_dim] that is upper triangular and specifies the lattice vectors of the box.
r_cutoff (float) – A scalar specifying the neighborhood radius.
capacity_multiplier (float) – A floating point scalar specifying the fractional increase in maximum neighborhood occupancy we allocate compared with the maximum in the example positions.
disable_cell_list (bool) – An optional boolean. If set to True then the neighbor list is constructed using only distances. This can be useful for debugging but should generally be left as False.
fractional_coordinates (bool) – An optional boolean. Specifies whether positions will be supplied in fractional coordinates in the unit cube, \([0, 1]^d\). If this is set to True then the box_size will be set to 1.0 and the cell size used in the cell list will be set to cutoff / box_size.
format (NeighborListFormat) – The format of the neighbor list; see the NeighborListFormat() enum for details about the different choices for formats. Defaults to Dense.
pairwise_distances (bool) – Computes pairwise distances between all particles for every sample.
box_key (str) – The key in the dataset dictionary that contains the box. If not provided, uses the box argument.
mask_key (str) – The key in the dataset dictionary that contains the mask. If not provided, all particles are considered valid.
reps_key (str) – The key in the dataset dictionary that contains the number of replicas of a supercell. If set, the neighborlist will only contain edge senders from the first appearing replica.
batch_size (int) – Evaluate multiple samples in parallel.
init_kwargs (dict) – Keyword arguments passed to the neighbor list allocation, e.g., to specify a capacity multiplier.
count_triplets (bool) – An optional boolean. If set to True, the function will return the maximum number of triplets, similar to the maximum number of edges.
**static_kwargs – kwargs that get threaded through the calculation of example positions.
- Return type:
Tuple[NeighborList, Tuple[int, int, float, Optional[int, None]]]
- Returns:
Returns a neighbor list that fits the dataset.
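The idea of sizing a neighbor list from a dataset can be sketched with plain NumPy: count the neighbors of every particle in every sample and take the maximum. This is only an illustration of the principle, not the library's procedure. The helper name `max_neighbors` is hypothetical, and the sketch assumes fractional coordinates in a static orthorhombic periodic box whose side lengths are all larger than twice the cutoff.

```python
import numpy as np


def max_neighbors(frac_positions, box, r_cutoff):
    """Hypothetical helper: largest neighbor count over all samples.

    frac_positions: array of shape (N_snapshots, N_particles, 3) in [0, 1).
    box: array of shape (3,) with the box side lengths.
    """
    box = np.asarray(box, dtype=float)
    largest = 0
    for frac in np.asarray(frac_positions, dtype=float):
        # Pairwise differences in fractional coordinates.
        delta = frac[:, None, :] - frac[None, :, :]
        # Minimum image convention: wrap differences into [-0.5, 0.5].
        delta -= np.round(delta)
        distances = np.linalg.norm(delta * box, axis=-1)
        within = distances < r_cutoff
        # A particle is not its own neighbor.
        np.fill_diagonal(within, False)
        largest = max(largest, int(within.sum(axis=1).max()))
    return largest


samples = np.zeros((2, 3, 3))
samples[0, :, 0] = [0.05, 0.95, 0.5]  # first two are close across the boundary
samples[1, :, 0] = [0.1, 0.2, 0.25]   # all three are mutually close
occupancy = max_neighbors(samples, box=[10.0, 10.0, 10.0], r_cutoff=2.0)
# Padding the measured maximum, in the spirit of capacity_multiplier.
capacity = int(np.ceil(occupancy * 1.25))
```

Taking the maximum over the whole dataset is what lets a single fixed-size neighbor list serve every sample; the multiplier only adds headroom on top of that measured maximum.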