Source code for jax_md_mod.io

# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Functions for io: Loading data to and from Jax M.D."""
import jax.numpy as jnp
import mdtraj
from mdtraj import utils as mdtraj_utils
import numpy as onp

from typing import Union, List, Any
from os import PathLike

try:
    from jax.typing import ArrayLike
except ImportError:
    # ``jax.typing`` is unavailable (e.g. on older JAX releases); fall back to
    # an unconstrained annotation so that this module can still be imported.
    ArrayLike = Any

def load_box(filename, frame=-1, top=None):
    """Loads initial configuration using the file loader from MDTraj.

    Args:
        filename: String providing the location of the file to load.
        frame: Frame of the trajectory to read data from.
        top: Topology file, necessary if the file does not contain sufficient
            information to generate the system topology.

    Returns:
        Tuple of jnp arrays of box, coordinates, mass, and species.
        If the file provides unit cell angles, the box is the matrix obtained
        by stacking the MDTraj box vectors as columns. Otherwise, it is the
        vector of unit cell lengths.

    Raises:
        ValueError: If the loaded file does not contain unit cell information.
    """
    if top is not None:
        top = mdtraj.load_topology(top)
    traj = mdtraj.load(filename, top=top)
    coordinates = traj.xyz[frame]

    # MDTraj reports a missing unit cell as ``None`` for the whole trajectory,
    # so the check must happen before selecting a single frame.
    if traj.unitcell_lengths is None:
        raise ValueError(
            f"The file {filename} does not contain unit cell information.")
    lengths = traj.unitcell_lengths[frame]

    if traj.unitcell_angles is not None:
        angles = traj.unitcell_angles[frame]
        vectors = mdtraj_utils.unitcell.lengths_and_angles_to_box_vectors(
            *lengths, *angles)
        # The individual box vectors become the columns of the box matrix
        box = jnp.stack(vectors, axis=-1)
    else:
        box = jnp.asarray(lengths)

    n_atoms = coordinates.shape[0]
    species = onp.zeros(n_atoms, dtype=onp.int32)
    masses = onp.zeros(n_atoms)
    # Fill by atom index so the ordering matches the coordinate array
    for atom in traj.topology.atoms:
        species[atom.index] = atom.element.number
        masses[atom.index] = atom.element.mass

    return (box, jnp.array(coordinates), jnp.array(masses),
            jnp.array(species, dtype=jnp.int32))
def save_gro(positions: ArrayLike,
             box: ArrayLike,
             velocities: ArrayLike = None,
             group: Union[str, List] = "SOL",
             species: Union[str, List] = "CG",
             fractional: bool = False,
             time: ArrayLike = 0.0,
             filename: PathLike = None
             ) -> str:
    """Writes a box of particles to the gro file format [#gromacsgro]_.

    Args:
        positions: ``(N, 3)`` array of particle positions.
        box: Scalar, ``(3,)`` vector or ``(3, 3)`` matrix describing the box.
            Currently, only rectangular boxes are supported, i.e., only the
            diagonal of a box matrix is written to the file.
        velocities: ``(N, 3)`` array of velocities. If not given, the
            velocities are set to zero.
        group: Either a single value if all atoms belong to the same group or
            a sequence with one group per atom.
        species: Either a single value for a single species or a sequence
            with one species type per atom.
        fractional: Whether particle coordinates are fractional. If ``True``,
            the coordinates are transformed with the box before writing.
        time: Time of the frame, written to the title line.
        filename: If given, save the generated string to this file.

    Returns:
        Returns the content of the gro file as string.

    Notes:
        The gro format uses fixed-width columns. Group and atom names are
        truncated to 5 characters, and atom numbers (starting at 0) wrap
        around at 100000. This keeps large systems readable.

    References:
        .. [#gromacsgro] `<https://manual.gromacs.org/archive/5.0.3/online/gro.html>`_
    """
    positions = onp.asarray(positions)
    # Normalizes Python scalars, lists and 0-d (jax) arrays to a single form
    box = onp.asarray(box)

    assert positions.ndim == 2, "Can only save a single frame"
    n_atoms = positions.shape[0]

    # Set all velocities to zero if not given
    if velocities is None:
        velocities = onp.zeros_like(positions)
    else:
        velocities = onp.asarray(velocities)
        assert velocities.shape == positions.shape, (
            "Velocities must have the same shape as positions.")

    # A single (0-dimensional) value applies to all atoms
    if onp.ndim(species) == 0:
        species = [species] * n_atoms
    if onp.ndim(group) == 0:
        group = [group] * n_atoms
    # zip() would silently drop atoms, while the header states n_atoms
    assert len(species) == n_atoms, "Requires one species per atom."
    assert len(group) == n_atoms, "Requires one group per atom."

    # Transform fractional coordinates
    if fractional:
        if box.ndim == 2:
            positions = positions @ box
        else:
            positions = positions * box

    lines = [f"Generated by chemtrain; t={time}\n", "%5d\n" % n_atoms]

    # Save the atom positions and velocities. The precision (".5") caps the
    # string fields at their column width to keep the fixed-width layout.
    atoms = zip(positions, velocities, species, group)
    for idx, (r, v, sp, gp) in enumerate(atoms):
        values = (1, gp, f"{sp}{idx}", idx % 100000, *r.tolist(), *v.tolist())
        lines.append(
            "%5d%-5.5s%5.5s%5d%8.3f%8.3f%8.3f%8.4f%8.4f%8.4f\n" % values)

    # Save the box (only the diagonal for now)
    if box.ndim == 0:
        box = onp.repeat(box, 3)
    elif box.ndim == 2:
        box = onp.diag(box)
    lines.append("%8.3f%8.3f%8.3f\n" % tuple(box.tolist()))

    content = "".join(lines)
    if filename is not None:
        with open(filename, "w") as f:
            f.write(content)
    return content
def save_traj(times: ArrayLike,
              positions: ArrayLike,
              box: ArrayLike,
              velocities: ArrayLike = None,
              dynamic_box: bool = False,
              group: Union[str, List] = "SOL",
              species: Union[str, List] = "CG",
              fractional: bool = False,
              filename: PathLike = None
              ) -> str:
    """Writes a trajectory to the gro file format using :func:`save_gro`.

    Args:
        times: ``(T,)`` array of corresponding times.
        positions: ``(T, N, 3)`` array of particle positions.
        box: Scalar or array describing the box. If ``dynamic_box=True``,
            a time dependent box size ``(T, ...)`` can be given. Currently,
            does not support off-diagonal box entries.
        velocities: ``(T, N, 3)`` array of velocities. If not given, the
            velocities are set to zero.
        dynamic_box: Set to ``True`` if the box is time-dependent.
        group: Either a string if all atoms belong to the same group or a
            list with one group per atom.
        species: Either a string for a single species or a list with one
            species type per atom.
        fractional: Whether particle coordinates are fractional. Transforms
            fractional coordinates.
        filename: If given, save the generated string to this file.

    Returns:
        Returns the content of the trajectory file as string.

    Notes:
        It is inefficient to save trajectories in ASCII format. Consider
        reformatting the saved trajectory to a more efficient format, e.g.
        via:

        .. code-block:: bash

           gmx trjconv -f trajectory.gro -o <new-trajectory>.<trr/xtc>

    """
    # Host arrays are cheap to iterate frame by frame and yield NumPy scalars
    # (instead of 0-d device arrays) for per-frame times and boxes.
    times = onp.asarray(times)
    positions = onp.asarray(positions)

    assert times.ndim == 1, "Times must be a one-dimensional vector."
    assert times.shape[0] == positions.shape[0], (
        "Requires a time for each configuration.")

    if velocities is not None:
        velocities = onp.asarray(velocities)
        assert velocities.shape == positions.shape, (
            "The velocities must have the same shape as the positions.")
    else:
        velocities = [None] * times.size

    if dynamic_box:
        boxes = onp.asarray(box)
        assert boxes.shape[0] == times.shape[0], (
            "Requires a box for each configuration if dynamic_box is set.")
    else:
        boxes = [box] * times.size

    all_frames = "".join(
        save_gro(pos, frame_box, vel, group=group, species=species,
                 fractional=fractional, time=t)
        for pos, frame_box, vel, t in zip(positions, boxes, velocities, times)
    )

    if filename is not None:
        with open(filename, "w") as f:
            f.write(all_frames)
    return all_frames