# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for io: Loading data to and from Jax M.D."""
import jax.numpy as jnp
import mdtraj
from mdtraj import utils as mdtraj_utils
import numpy as onp
from typing import Union, List, Any
from os import PathLike
try:
    from jax.typing import ArrayLike
except ImportError:
    # Presumably older JAX releases lack ``jax.typing``; fall back to an
    # unconstrained annotation so the module still imports.
    ArrayLike = Any
[docs]
def load_box(filename, frame=-1, top=None):
    """Loads initial configuration using the file loader from MDTraj.

    Args:
        filename: String providing the location of the file to load.
        frame: Frame of the trajectory to read data from.
        top: Topology file, necessary if the file does not contain sufficient
            information to generate the system topology.

    Returns:
        Tuple of jnp arrays of box, coordinates, mass, and species. The
        species are the atomic numbers of the elements as ``int32``.

    Raises:
        ValueError: If the loaded trajectory carries no unit cell
            information, such that no box can be constructed.
    """
    if top is not None:
        top = mdtraj.load_topology(top)

    traj = mdtraj.load(filename, top=top)
    coordinates = traj.xyz[frame]

    # MDTraj reports a missing unit cell as ``None`` on the trajectory-level
    # attributes, so the check must happen before selecting the frame.
    all_lengths = traj.unitcell_lengths
    all_angles = traj.unitcell_angles
    if all_lengths is None:
        raise ValueError(
            f"The file '{filename}' does not provide unit cell information; "
            "cannot construct a box.")

    lengths = all_lengths[frame]
    if all_angles is not None:
        vectors = mdtraj_utils.unitcell.lengths_and_angles_to_box_vectors(
            *lengths, *all_angles[frame])
        # The box vectors become the columns of the box matrix
        box = jnp.stack(vectors, axis=-1)
    else:
        box = jnp.asarray(lengths)

    # Collect atomic number and mass per atom from the topology
    species = onp.zeros(coordinates.shape[0])
    masses = onp.zeros_like(species)
    for atom in traj.topology.atoms:
        species[atom.index] = atom.element.number
        masses[atom.index] = atom.element.mass

    return (box, jnp.array(coordinates), jnp.array(masses),
            jnp.array(species, dtype=jnp.int32))
[docs]
def save_gro(positions: ArrayLike,
             box: ArrayLike,
             velocities: ArrayLike = None,
             group: Union[str, List] = "SOL",
             species: Union[str, List] = "CG",
             fractional: bool = False,
             time: ArrayLike = 0.0,
             filename: PathLike = None
             ) -> str:
    """Writes a box of particles to the gro file format [#gromacsgro]_.

    Args:
        positions: ``(N, 3)`` array of particle positions.
        box: Scalar, ``(3,)`` array of box lengths, or ``(3, 3)`` box matrix.
            Currently, only diagonal 2D boxes are supported, i.e., only the
            diagonal of a box matrix is written to the file.
        velocities: ``(N, 3)`` array of velocities. If not given, the
            velocities are set to zero.
        group: Either a string if all atoms belong to the same group or a
            list with one group per atom.
        species: Either a string for a single species or a list with one
            species type per atom.
        fractional: Whether particle coordinates are fractional. Fractional
            coordinates are transformed to real space via the box.
        time: Time of the frame.
        filename: If given, the generated string is also saved to this file.

    Returns:
        Returns the content of the gro file as string.

    Raises:
        AssertionError: If ``positions`` is not two-dimensional or the shape
            of ``velocities`` differs from the shape of ``positions``.
        ValueError: If ``box`` has more than two dimensions.

    References:
        .. [#gromacsgro] `<https://manual.gromacs.org/archive/5.0.3/online/gro.html>`_
    """
    # Work on host arrays: this accepts lists as well as array types and
    # avoids iterating row by row over device arrays.
    positions = onp.asarray(positions)
    assert positions.ndim == 2, "Can only save a single frame"

    # Set all velocities to zero if not given
    if velocities is None:
        velocities = onp.zeros_like(positions)
    else:
        velocities = onp.asarray(velocities)
        assert velocities.shape == positions.shape, (
            "Velocities must have the same shape as positions.")

    box = onp.asarray(box)
    if box.ndim > 2:
        raise ValueError(
            "The box must be a scalar, a vector of lengths, or a matrix.")

    # Transform fractional coordinates once for all particles. For the
    # supported diagonal boxes, the product with the box matrix is equivalent
    # to scaling each coordinate by the respective box length.
    if fractional:
        if box.ndim == 2:
            positions = positions @ box
        else:
            positions = positions * box

    n_atoms = positions.shape[0]
    if not isinstance(species, list):
        species = [species] * n_atoms
    if not isinstance(group, list):
        group = [group] * n_atoms

    # Title line followed by the atom count
    lines = [f"Generated by chemtrain; t={time}\n", "%5d\n" % n_atoms]

    # Save the atom positions and velocities. All atoms are assigned to
    # residue 1 and are numbered by their (zero-based) index.
    atom_format = "%5d%-5s%5s%5d%8.3f%8.3f%8.3f%8.4f%8.4f%8.4f\n"
    atoms = zip(positions, velocities, species, group)
    for idx, (r, v, sp, gp) in enumerate(atoms):
        values = (1, gp, f"{sp}{idx}", idx, *r.tolist(), *v.tolist())
        lines.append(atom_format % values)

    # Save the box (only the diag for now)
    if box.ndim == 0:
        box_lengths = onp.repeat(box, 3)
    elif box.ndim == 2:
        box_lengths = onp.diag(box)
    else:
        box_lengths = box
    lines.append("%8.3f%8.3f%8.3f\n" % tuple(box_lengths))

    if filename is not None:
        with open(filename, "w") as f:
            f.writelines(lines)

    return "".join(lines)
[docs]
def save_traj(times: ArrayLike,
              positions: ArrayLike,
              box: ArrayLike,
              velocities: ArrayLike = None,
              dynamic_box: bool = False,
              group: Union[str, List] = "SOL",
              species: Union[str, List] = "CG",
              fractional: bool = False,
              filename: PathLike = None
              ) -> None:
    """Writes a trajectory to the gro file format using :func:`save_gro`.

    Args:
        times: ``(T,)`` array of corresponding times.
        positions: ``(T, N, 3)`` array of particle positions.
        box: Scalar or array describing the box. If ``dynamic_box=True``, a
            time dependent box size ``(T, ...)`` can be given.
            Currently, does not support off-diagonal box entries.
        velocities: ``(T, N, 3)`` array of velocities with the same shape as
            the positions. If not given, the velocities are set to zero.
        dynamic_box: Set to ``True`` if the box is time-dependent.
        group: Either a string if all atoms belong to the same group or a
            list with one group per atom.
        species: Either a string for a single species or a list with one
            species type per atom.
        fractional: Whether particle coordinates are fractional. Transforms
            fractional coordinates.
        filename: File to which the trajectory is written. Required, even
            though the signature provides a default.

    Raises:
        TypeError: If no ``filename`` is given.
        AssertionError: If the shapes of the times, positions, velocities,
            or the dynamic box are inconsistent.

    Notes:
        It is inefficient to save trajectories in ASCII format. Consider
        reformatting the saved trajectory to a more efficient format, e.g.
        via:

        .. code-block :: bash

            gmx trjconv -f trajectory.gro -o <new-trajectory>.<trr/xtc>
    """
    # Fail before formatting any frame; otherwise, the missing filename is
    # only noticed when opening the file after all the work has been done.
    if filename is None:
        raise TypeError("A filename is required to save the trajectory.")

    # Host arrays also accept plain sequences and are cheap to iterate
    times = onp.asarray(times)
    positions = onp.asarray(positions)

    assert times.ndim == 1, (
        "Times must be a one-dimensional vector.")
    assert times.shape[0] == positions.shape[0], (
        "Requires a time for each configuration.")

    n_frames = times.shape[0]

    if velocities is not None:
        velocities = onp.asarray(velocities)
        assert velocities.shape == positions.shape, (
            "The velocities must have the same shape as the positions.")
    else:
        # Leave it to save_gro to fill in zero velocities for each frame
        velocities = [None] * n_frames

    if dynamic_box:
        box = onp.asarray(box)
        assert box.ndim >= 1 and box.shape[0] == n_frames, (
            "Requires a box for each configuration if dynamic_box is set.")
    else:
        # Reuse the same box for every frame
        box = [box] * n_frames

    all_frames = "".join([
        save_gro(
            frame_positions, frame_box, frame_velocities,
            group=group, species=species, fractional=fractional, time=t)
        for frame_positions, frame_box, frame_velocities, t
        in zip(positions, box, velocities, times)
    ])

    with open(filename, "w") as f:
        f.write(all_frames)