# Source code for jax_md_mod.custom_space
# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Custom functions simplifying the handling of fractional coordinates."""
from typing import Union, Tuple, Callable
import jax
from jax_md import space, util
import jax.numpy as jnp
from jax import vmap
Box = Union[float, util.Array]
def _rectangular_boxtensor(box: Box) -> Box:
    """Expands a 1-dimensional box into a square, diagonal box tensor.

    Args:
        box: Array whose first axis length is taken as the spatial dimension
            and whose entries are written onto the diagonal.

    Returns:
        A (spatial_dim, spatial_dim) array with ``box`` on the diagonal and
        zeros everywhere else.
    """
    n_dims = box.shape[0]
    diagonal = jnp.diag_indices(n_dims)
    # Start from the identity so the off-diagonal entries are zero, then
    # overwrite the diagonal with the box entries.
    tensor = jnp.eye(n_dims)
    return tensor.at[diagonal].set(box)
# [docs]
def init_fractional_coordinates(box: Box) -> Tuple[Box, Callable]:
    """Builds a 2D box tensor and a real-space to fractional scaling function.

    Args:
        box: A 1 or 2-dimensional box. A box that is not 2-dimensional is
            expanded with ``_rectangular_boxtensor``.

    Returns:
        A tuple (box, scale_fn). ``box`` is the 2D box tensor.
        ``scale_fn(positions, **kwargs)`` multiplies the positions with the
        inverse box tensor, which scales positions in real space to the unit
        hypercube. A ``box`` keyword argument passed to ``scale_fn`` replaces
        the default box tensor for that call.
    """

    def _as_tensor(candidate):
        # Only boxes that are not already 2D need to be expanded.
        if candidate.ndim == 2:
            return candidate
        return _rectangular_boxtensor(candidate)

    box_tensor = _as_tensor(box)

    def scale_fn(positions, **kwargs):
        current_box = _as_tensor(kwargs.get('box', box_tensor))
        inverse = jnp.linalg.inv(current_box)
        # For positions of shape (n, dim) the transposes apply the inverse
        # box to each position vector and restore the (n, dim) layout.
        scaled = jnp.dot(inverse, positions.T)
        return scaled.T

    return box_tensor, scale_fn
def general_space(box: Box,
                  periodic: jax.Array,
                  wrapped: bool = True,
                  fractional: bool = False) -> space.Space:
    """Creates displacement and shift functions for a selectively periodic box.

    Periodic images are resolved in fractional coordinates with a side length
    of 1.0 and only where ``periodic`` is true. Elsewhere the plain
    difference (displacement) or sum (shift) is used.

    Args:
        box: Box handed to ``jnp.linalg.inv`` and ``space.transform``. It can
            be replaced for a single call via a ``box`` keyword argument.
        periodic: Mask handed to ``jnp.where`` to choose between the periodic
            and the non-periodic result. Presumably one boolean per spatial
            dimension -- TODO confirm. It can be replaced for a single call
            via a ``periodic`` keyword argument.
        wrapped: If true, the shift function applies ``space.periodic_shift``
            where ``periodic`` is true. Otherwise the shift is a plain
            ``R + dR`` and all keyword arguments are ignored.
        fractional: If true, the inputs are used as given and are not
            multiplied by the inverse box first.

    Returns:
        A tuple (displacement_fn, shift_fn). ``displacement_fn`` always
        transforms its result with ``box``. The wrapped ``shift_fn`` does so
        only when ``fractional`` is false.
    """

    def _resolve(overrides):
        # Per-call keyword arguments take precedence over the defaults
        # captured when the space was created.
        return overrides.get('periodic', periodic), overrides.get('box', box)

    def displacement_fn(Ra: jax.Array, Rb: jax.Array, **kwargs):
        active_periodic, active_box = _resolve(kwargs)

        if not fractional:
            # Bring both positions into fractional coordinates first.
            to_fractional = jnp.linalg.inv(active_box)
            Ra = space.transform(to_fractional, Ra)
            Rb = space.transform(to_fractional, Rb)

        plain = space.pairwise_displacement(Ra, Rb)
        imaged = space.periodic_displacement(1.0, plain)
        selected = jnp.where(active_periodic, imaged, plain)

        # The displacement is transformed with the box regardless of the
        # coordinate convention of the inputs.
        return space.transform(active_box, selected)

    def shift_fn(R: jax.Array, dR: jax.Array, **kwargs):
        if not wrapped:
            # Unwrapped spaces never consult box or periodicity.
            return R + dR

        active_periodic, active_box = _resolve(kwargs)

        if fractional:
            # R and dR are combined as given; no box transform is applied.
            return jnp.where(
                active_periodic, space.periodic_shift(1.0, R, dR), R + dR)

        to_fractional = jnp.linalg.inv(active_box)
        R = space.transform(to_fractional, R)
        dR = space.transform(to_fractional, dR)
        moved = jnp.where(
            active_periodic, space.periodic_shift(1.0, R, dR), R + dR)
        return space.transform(active_box, moved)

    return displacement_fn, shift_fn