# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A collection of functions evaluating quantities of trajectories.
For easiest integration into chemtrain, functions should be compatible with
traj_util.quantity_traj. Functions provided to quantity_traj need to take the
state and additional kwargs.
"""
from functools import partial
import jax
import numpy as onp
from jax import grad, vmap, lax, jacrev, jacfwd, numpy as jnp, jit
from jax.scipy.stats import norm
from jax_md import space, util, dataclasses, quantity, simulate, partition
from jax_md_mod.model import sparse_graph
from jax_md_mod import custom_partition
Array = util.Array
[docs]
def energy_wrapper(energy_fn_template, fixed_energy_params=None):
    """Builds a potential-energy snapshot function for
    traj_util.quantity_traj.

    Args:
        energy_fn_template: Function creating an energy function when called
            with energy parameters.
        fixed_energy_params: If given, the energy function is always built
            from these parameters and the ``energy_params`` passed at call
            time are ignored. If not given, the dynamically specified
            parameters are used.

    Returns:
        A function ``energy(state, neighbor, energy_params, ...)`` that
        returns the potential energy of a single snapshot.
    """
    def energy(state, neighbor, energy_params, energy_and_force=None, **kwargs):
        # A precomputed result short-circuits the model evaluation.
        if energy_and_force is not None:
            print(f"[Potential] Found precomputed forces.")
            return energy_and_force['energy']

        params = (energy_params if fixed_energy_params is None
                  else fixed_energy_params)
        potential = energy_fn_template(params)
        return potential(state.position, neighbor=neighbor, **kwargs)

    return energy
def force_wrapper(energy_fn_template, fixed_energy_params=None):
    """Builds a force snapshot function for traj_util.quantity_traj.

    Args:
        energy_fn_template: Function creating an energy function when called
            with energy parameters.
        fixed_energy_params: If given, the energy function is always built
            from these parameters and the ``energy_params`` passed at call
            time are ignored. If not given, the dynamically specified
            parameters are used.

    Returns:
        A function ``force(state, neighbor, energy_params, ...)`` that
        returns the forces of a single snapshot, obtained via
        ``quantity.force`` applied to the energy function.
    """
    def force(state, neighbor, energy_params, energy_and_force=None, **kwargs):
        # A precomputed result short-circuits the model evaluation.
        if energy_and_force is not None:
            print(f"[Force] Found precomputed forces.")
            return energy_and_force['force']

        params = (energy_params if fixed_energy_params is None
                  else fixed_energy_params)
        force_fn = quantity.force(energy_fn_template(params))
        return force_fn(state.position, neighbor=neighbor, **kwargs)

    return force
def energy_force_wrapper(energy_fn_template,
                         fixed_energy_params=None,
                         has_aux=False):
    """Builds a snapshot function returning energy, forces and box gradient
    for traj_util.quantity_traj.

    Args:
        energy_fn_template: Function creating an energy function when called
            with energy parameters.
        fixed_energy_params: If given, the energy function is always built
            from these parameters. If not given, the function uses the
            dynamically specified parameters.
        has_aux: Whether the energy function has an auxiliary output. In this
            case, the energy function will be called with the argument
            ``mode="with_aux"`` and should return a tuple with
            ``(energy, aux)``.

    Returns:
        A function returning a dict with the keys ``'energy'``, ``'force'``
        (negative position gradient), ``'box_grad'`` (gradient with respect
        to the ``box`` kwarg) and ``'aux'`` (``None`` without ``has_aux``).
    """
    def energy_and_force_fn(state, neighbor, energy_params, **kwargs):
        params = (energy_params if fixed_energy_params is None
                  else fixed_energy_params)
        energy_fn = energy_fn_template(params)

        if has_aux:
            kwargs['mode'] = 'with_aux'
        # The box is differentiated separately, so it must not stay in kwargs.
        box = kwargs.pop('box', None)

        def _energy(R, _box):
            # Only forward a box if the caller supplied one.
            if box is None:
                return energy_fn(R, neighbor=neighbor, **kwargs)
            return energy_fn(R, neighbor=neighbor, box=_box, **kwargs)

        value_and_grads = jax.value_and_grad(
            _energy, argnums=(0, 1), has_aux=has_aux)
        value, (pos_grad, box_grad) = value_and_grads(state.position, box)

        energy, aux = value if has_aux else (value, None)
        return {
            'energy': energy,
            'force': -pos_grad,
            'box_grad': box_grad,
            'aux': aux,
        }

    return energy_and_force_fn
def get_aux(aux_key=""):
    """Creates a snapshot function reading one entry of the auxiliary output.

    Args:
        aux_key: Key of the entry in ``energy_and_force['aux']`` to return.

    Returns:
        A function taking the state and the precomputed ``energy_and_force``
        dict and returning ``energy_and_force['aux'][aux_key]``.
    """
    def snapshot_fn(state, energy_and_force=None, **kwargs):
        missing = f"Need to provide aux for {aux_key}."
        assert energy_and_force is not None, missing
        aux = energy_and_force['aux']
        assert aux_key in aux.keys(), missing
        print(f"Read out {aux_key} from aux.")
        return aux[aux_key]

    return snapshot_fn
[docs]
def kinetic_energy_wrapper(state, **unused_kwargs):
    """Evaluates the kinetic energy of a snapshot with the signature expected
    by traj_util.quantity_traj.

    Args:
        state: Simulation state providing ``velocity`` and ``mass``.
        **unused_kwargs: Ignored.
    """
    velocity, mass = state.velocity, state.mass
    return quantity.kinetic_energy(velocity=velocity, mass=mass)
[docs]
def total_energy_wrapper(energy_fn_template):
    """Builds a total-energy (potential + kinetic) snapshot function for
    traj_util.quantity_traj.

    Args:
        energy_fn_template: Function creating an energy function when called
            with energy parameters.

    Returns:
        A function ``energy(state, neighbor, energy_params, ...)`` returning
        the sum of potential and kinetic energy of a snapshot.
    """
    def energy(state, neighbor, energy_params, **kwargs):
        potential_fn = energy_fn_template(energy_params)
        potential = potential_fn(state.position, neighbor=neighbor, **kwargs)
        kinetic = quantity.kinetic_energy(state.velocity, state.mass)
        return potential + kinetic

    return energy
def connected(state, neighbor, mask=None, **kwargs):
    """Evaluates ``custom_partition.check_connectivity`` on the neighbor list.

    Args:
        state: Unused; present for the quantity_traj interface.
        neighbor: Neighbor list passed to the connectivity check.
        mask: Optional mask forwarded to the connectivity check.
        **kwargs: Ignored.
    """
    is_connected = custom_partition.check_connectivity(neighbor, mask=mask)
    return is_connected
[docs]
def temperature(state, mask=None, **kwargs):
    """Temperature function that is consistent with quantity_traj interface.

    Computes ``sum(mass * velocity ** 2) / dof``, where ``dof`` is the number
    of selected particles times the size of the last velocity axis.

    Args:
        state: Simulation state providing ``velocity`` and ``mass``.
        mask: Optional boolean array over particles. Velocities of
            deselected particles are zeroed and excluded from ``dof``.
        **kwargs: Ignored.
    """
    velocity, mass = state.velocity, state.mass

    if mask is not None:
        print(f"Masked temperature computation")
        selected = mask[:, jnp.newaxis]
        velocity = jnp.where(selected, velocity, 0.0)
        # Masses of deselected particles are irrelevant (velocity is zero).
        mass = jnp.where(selected, mass, 1.0)
    else:
        mask = jnp.ones(velocity.shape[0], dtype=bool)

    dof = jnp.sum(mask) * jnp.shape(velocity)[-1]
    return jnp.sum(velocity ** 2 * mass) / dof
def _dyn_box(reference_box, state, **kwargs):
    """Resolves the simulation box for a snapshot.

    The box is looked up in this order: the NPT state (``box_position``
    attribute), a ``box`` attribute of the state, the ``box`` kwarg, and
    finally ``reference_box``.

    Returns:
        A tuple ``(box, kwargs)``. The ``box`` entry is removed from the
        returned kwargs only when the box was not taken from the state.
    """
    # Get the dynamic box if simulation is in NPT ensemble
    if hasattr(state, 'box_position'):
        return simulate.npt_box(state), kwargs
    if hasattr(state, 'box'):
        return state.box, kwargs

    box = kwargs.pop('box', reference_box)
    assert box is not None, ('If no reference box is given, needs to be '
                             'given as kwarg "box".')
    return box, kwargs
def _dyn_kT(kT, **kwargs):
"""Gets kT dynamically from kwargs, if provided, otherwise defaults to
reference. Ensures that a kT is provided and deletes from kwargs.
"""
kT = kwargs.pop('kT', kT)
assert kT is not None, ('If no reference kT is given, needs to be '
'given as kwarg "kT".')
return kT, kwargs
[docs]
def volume(state, **kwargs):
    """Returns volume of a single snapshot.

    The box is taken from the state if available, otherwise it must be
    provided via the ``box`` kwarg.
    """
    spatial_dim = state.position.shape[-1]
    box, _ = _dyn_box(None, state, **kwargs)
    return quantity.volume(spatial_dim, box)
def _canonicalized_masses(state):
    """Returns per-particle masses, broadcasting a scalar ``state.mass`` to
    one entry per row of ``state.position``.
    """
    if state.mass.ndim != 0:
        return state.mass
    n_particles = state.position.shape[0]
    return jnp.ones(n_particles) * state.mass
[docs]
def density(state, **unused_kwargs):
    """Returns density of a single snapshot of the NPT ensemble.

    The density is the total mass divided by the volume of the snapshot.

    Args:
        state: Simulation state providing ``mass`` and ``position``.
        **unused_kwargs: Forwarded to ``volume`` so that a box given as kwarg
            ``box`` can be used when the state itself carries no box.
    """
    masses = _canonicalized_masses(state)
    total_mass = jnp.sum(masses)
    # The local must not be named ``volume``: assigning to that name would
    # make it function-local and the call below would raise
    # UnboundLocalError.
    snapshot_volume = volume(state, **unused_kwargs)
    return total_mass / snapshot_volume
# TODO distinct classes and discretization functions don't seem optimal
# --> possible refactor
[docs]
@dataclasses.dataclass
class RDFParams:
    """Hyperparameters to initialize the radial distribution function (RDF).

    Attributes:
        reference: The target rdf; initialize with None if no target
            available
        rdf_bin_centers: The radial positions of the centers of the rdf bins
        rdf_bin_boundaries: The radial positions of the edges of the rdf bins
        sigma: Standard deviation of smoothing Gaussian
    """
    reference: Array
    rdf_bin_centers: Array
    rdf_bin_boundaries: Array
    sigma: Array
def rdf_discretization(rdf_cut, nbins=300, rdf_start=0.):
    """Computes discretization parameters for initialization of RDF function.

    Args:
        rdf_cut: Cut-off length inside which pairs of particles are considered
        nbins: Number of bins in radial direction
        rdf_start: Minimal distance after which particle pairs are considered

    Returns:
        Arrays with radial positions of bin centers, bin edges and the
        standard deviation of the Gaussian smoothing kernel (set to the bin
        width).
    """
    bin_width = (rdf_cut - rdf_start) / float(nbins)
    first_center = rdf_start + bin_width / 2.
    last_center = rdf_cut - bin_width / 2.

    centers = jnp.linspace(first_center, last_center, nbins)
    edges = jnp.linspace(rdf_start, rdf_cut, nbins + 1)
    return centers, edges, jnp.array(bin_width)
[docs]
@dataclasses.dataclass
class ADFParams:
    """Hyperparameters to initialize an angular distribution function (ADF).

    Attributes:
        reference: The target adf; initialize with None if no target
            available
        adf_bin_centers: The positions of the centers of the adf bins over
            theta
        sigma: Standard deviation of smoothing Gaussian
        r_outer: Outer radius beyond which particle triplets are not
            considered
        r_inner: Inner radius below which particle triplets are not
            considered
    """
    reference: Array
    adf_bin_centers: Array
    sigma: Array
    r_outer: Array
    r_inner: Array
def adf_discretization(nbins=200):
    """Computes discretization parameters for initialization of ADF function.

    The interval ``[0, pi]`` is split into ``nbins`` equally sized bins.

    Args:
        nbins: Number of bins discretizing theta

    Returns:
        Arrays containing bin centers in theta direction and the standard
        deviation of the Gaussian smoothing kernel (set to the bin width).
    """
    bin_width = jnp.pi / float(nbins)
    half_width = bin_width / 2.
    centers = jnp.linspace(half_width, jnp.pi - half_width, nbins)
    return centers, util.f32(bin_width)
def dihedral_discretization(nbins=150):
    """Computes discretization parameters for a dihedral distribution.

    The interval ``[-pi, pi]`` is split into ``nbins`` equally sized bins.

    Args:
        nbins: Number of bins discretizing the dihedral angle

    Returns:
        Array of bin centers and the standard deviation of the smoothing
        kernel (set to the bin width).
    """
    bin_width = 2 * jnp.pi / float(nbins)
    centers = bin_width * (jnp.arange(nbins) + 0.5) - jnp.pi
    return centers, util.f32(bin_width)
[docs]
@dataclasses.dataclass
class TCFParams:
    """Hyperparameters to initialize a triplet correlation function (TCF).

    The triplet is defined via the sides x, y, z. Implementation according to
    https://aip.scitation.org/doi/10.1063/1.4898755 and
    https://aip.scitation.org/doi/10.1063/5.0048450.

    Attributes:
        reference: The target tcf; initialize with None if no target
            available
        sigma_tcf: Standard deviation of smoothing Gaussian
        volume: Histogram volume element according to
            https://journals.aps.org/pra/abstract/10.1103/PhysRevA.42.849
        tcf_x_bin_centers: The radial positions of the centers of the tcf
            bins in x direction
        tcf_y_bin_centers: The radial positions of the centers of the tcf
            bins in y direction
        tcf_z_bin_centers: The radial positions of the centers of the tcf
            bins in z direction
    """
    reference: Array
    sigma_tcf: Array
    volume: Array
    tcf_x_bin_centers: Array
    tcf_y_bin_centers: Array
    tcf_z_bin_centers: Array
def tcf_discretization(tcf_cut, nbins=30, tcf_start=0.1):
    """Computes discretization parameters for initialization of TCF function.

    Args:
        tcf_cut: Cut-off length inside which pairs of particles are considered
        nbins: Number of bins in all three radial directions
        tcf_start: Minimal distance after which particle pairs are considered

    Returns:
        Tuple containing standard deviation of the Gaussian smoothing kernel,
        histogram volume array and arrays with radial positions of bin centers
        in x, y, z. The bin center arrays are sparse meshgrid outputs that
        broadcast against each other to the full 3D grid.
    """
    bin_width = (tcf_cut - tcf_start) / float(nbins)
    centers_1d = jnp.linspace(tcf_start + bin_width / 2.,
                              tcf_cut - bin_width / 2.,
                              nbins)
    x_centers, y_centers, z_centers = jnp.meshgrid(
        centers_1d, centers_1d, centers_1d, sparse=True)

    sigma_tcf = jnp.array(bin_width)
    # volume for non-linear triplets (sigma / min(x,y,z)->0)
    prefactor = 8 * jnp.pi**2
    volume_tcf = (prefactor * x_centers * y_centers
                  * z_centers * sigma_tcf**3)
    return sigma_tcf, volume_tcf, x_centers, y_centers, z_centers
[docs]
@dataclasses.dataclass
class BondAngleParams:
    """Hyperparameters of a bond angle distribution.

    Attributes:
        reference: Reference value of the distribution; not read by
            ``init_bond_angle_distribution``.
        sigma: Width of the Gaussian used to smooth the histogram.
        bonds: Index array describing the angles; forwarded to
            ``angular_displacement`` (exact layout defined there).
        bin_centers: Positions of the centers of the angle bins.
        bin_boundaries: Positions of the edges of the angle bins; their
            differences are used as bin sizes.
    """
    reference: Array
    sigma: Array
    bonds: Array
    bin_centers: Array
    bin_boundaries: Array
[docs]
def init_bond_angle_distribution(displacement_fn, bond_angle_params: BondAngleParams, reference_box=None):
    """Initializes a function computing a bond angle distribution.

    Args:
        displacement_fn: Displacement to compute bond angles
        bond_angle_params: Struct describing the bond angles and expected
            format of the computed distribution
        reference_box: Unused

    Returns:
        Returns a function that computes a smoothed distribution of bond
        angles (in degrees) given a simulation state.
    """
    _, sigma, bonds, bin_centers, bin_boundaries = dataclasses.astuple(bond_angle_params)
    bin_size = jnp.diff(bin_boundaries)

    def angle_fn(state, neighbor, **kwargs):
        angles = angular_displacement(
            state.position, displacement_fn, bonds, degrees=True)

        # Gaussian ensures that discrete integral over distribution is 1.
        # The exponent must use the variance sigma**2 to be consistent with
        # the normalization factor sqrt(2 * pi * sigma**2) below.
        exp = jnp.exp(
            util.f32(-0.5) * (angles[:, jnp.newaxis] - bin_centers)**2
            / sigma**2)
        gaussian_distances = exp * bin_size / jnp.sqrt(2 * jnp.pi * sigma**2)

        # NOTE(review): if ``angles`` is one-dimensional, summing over axis 1
        # collapses the bin axis and the result below is a scalar rather than
        # a per-bin distribution -- confirm against the output shape of
        # angular_displacement.
        per_bond = util.high_precision_sum(gaussian_distances,
                                           axis=1)  # sum nbrs
        mean_angle_dist = util.high_precision_sum(per_bond,
                                                  axis=0) / bonds.shape[0]
        return mean_angle_dist

    return angle_fn
[docs]
@dataclasses.dataclass
class BondDihedralParams:
    """Hyperparameters of a dihedral angle distribution.

    Attributes:
        reference: Reference value of the distribution; not read by
            ``init_bond_dihedral_distribution``.
        sigma: Standard deviation of the Gaussian smoothing kernel.
        bonds: Index array describing the dihedrals; forwarded to
            ``dihedral_displacement`` (exact layout defined there).
        bin_centers: Positions of the centers of the dihedral bins.
        bin_boundaries: Positions of the edges of the dihedral bins; their
            differences are used as bin sizes.
    """
    reference: Array
    sigma: Array
    bonds: Array
    bin_centers: Array
    bin_boundaries: Array
[docs]
def init_bond_dihedral_distribution(displacement_fn,
                                    bond_dihedral_params: BondDihedralParams,
                                    smoothing = 'gaussian'):
    """Initializes a function computing a dihedral distribution.

    Args:
        displacement_fn: Displacement to compute dihedral angles
        bond_dihedral_params: Struct describing the dihedral angles and
            expected format of the computed distribution
        smoothing: Kernel used to smooth the histogram, either ``'gaussian'``
            (width ``sigma``) or ``'epanechnikov'`` (width given by the size
            of each bin).

    Returns:
        Returns a function that computes a normalized distribution of
        dihedral angles (in radians) given a simulation state.

    Raises:
        ValueError: When the returned function is called and ``smoothing``
            is not one of the supported kernels.
    """
    _, sigma, bonds, bin_centers, bin_boundaries = dataclasses.astuple(bond_dihedral_params)
    bin_size = jnp.diff(bin_boundaries)

    def dihedral_fn(state, neighbor, **kwargs):
        dihedrals = dihedral_displacement(
            state.position, displacement_fn, bonds, degrees=False)

        # Axes of distances are [bins, dihedrals]
        distances = dihedrals[None, :] - bin_centers[:, None]

        if smoothing == 'gaussian':
            # Smooth the bins
            exponent = -0.5 * jnp.square(distances)
            exponent /= sigma ** 2
            kernel = jnp.exp(exponent)
        elif smoothing == 'epanechnikov':
            # Scale each row by the size of its own bin. The bin sizes need
            # an explicit trailing axis, otherwise they would broadcast
            # along the dihedral axis instead of the bin axis.
            distances = distances / bin_size[:, None]
            kernel = 0.75 * (1 - distances ** 2)
            kernel = jnp.where(kernel >= 0, kernel, 0.0)
        else:
            raise ValueError(f"Smoothing {smoothing} unknown.")

        # Sum over all contributing bonds
        dihedral_distribution = util.high_precision_sum(kernel, axis=1)

        # Normalize the distribution
        dihedral_distribution /= util.high_precision_sum(
            dihedral_distribution * bin_size, axis=0)
        return dihedral_distribution

    return dihedral_fn
def _ideal_gas_density(particle_density, bin_boundaries):
    """Returns bin densities that would correspond to an ideal gas.

    Each bin is a spherical shell between two consecutive boundaries; its
    weight is the shell volume times the particle density.
    """
    inner_radius, outer_radius = bin_boundaries[:-1], bin_boundaries[1:]
    shell_volume = (4. / 3.) * jnp.pi * (outer_radius**3 - inner_radius**3)
    return shell_volume * particle_density
[docs]
def init_rdf(displacement_fn,
             rdf_params,
             reference_box=None,
             rdf_species=None):
    """Initializes a function that computes the radial distribution function
    (RDF) for a single state.

    All pairwise distances are computed directly (no neighbor list) and
    binned with a Gaussian kernel of width ``sigma`` around each bin center.

    Args:
        displacement_fn: Displacement function
        rdf_params: RDFParams defining the hyperparameters of the RDF
        reference_box: Simulation box. Can be provided here for constant boxes
            or on-the-fly as kwarg 'box', e.g. for NPT ensemble
        rdf_species: Array of species pairs for which the RDF should be
            computed. If not provided, compute the RDF for all particles
            irrespectively of their species. If provided, the returned
            function must be called with the ``species`` of each particle.

    Returns:
        A function taking a simulation state and returning the instantaneous
        RDF. With ``rdf_species``, the result has one leading entry per
        species pair.
    """
    _, bin_centers, bin_boundaries, sigma = dataclasses.astuple(rdf_params)
    distance_metric = space.canonicalize_displacement_or_metric(displacement_fn)
    bin_size = jnp.diff(bin_boundaries)

    def pair_corr_fun(position, box, species):
        """Returns instantaneous pair correlation function while ensuring
        each particle pair contributes exactly 1.
        """
        n_particles = position.shape[0]
        metric = partial(distance_metric, box=box)
        metric = space.map_product(metric)
        dr = metric(position, position)
        # neglect same particles i.e. distance = 0.
        # They are moved to a huge distance so that their Gaussian
        # contribution to every bin vanishes.
        dr = jnp.where(dr > util.f32(1.e-7), dr, util.f32(1.e7))

        # Gaussian ensures that discrete integral over distribution is 1
        exponent = (dr[:, :, jnp.newaxis] - bin_centers) ** 2.0
        exponent *= -0.5 / sigma ** 2
        exp = jnp.exp(exponent)
        gdist = exp * bin_size / jnp.sqrt(2 * jnp.pi * sigma ** 2)

        if rdf_species is not None:
            # Find the species of each particle
            is_species_i = rdf_species[:, (0,)] == species[None, :]
            is_species_j = rdf_species[:, (1,)] == species[None, :]
            # Pair mask per requested species pair
            mask = jnp.logical_and(
                is_species_i[:, None, :],
                is_species_j[:, :, None]
            )
            masked_gdist = gdist[None, ...] * mask[:, ..., None]

            # Axes of masked_gdist are:
            # [rdf_idx, particle_i, particle_j, bins]

            # Sum over neighbors of corresponding species
            mean_pair_corr = util.high_precision_sum(
                masked_gdist, axis=(1, 2)
            )

            # TODO: Improve normalization for per-species RDF
            mean_pair_corr /= jnp.sum(is_species_i, axis=1)[:, None]
            mean_pair_corr *= n_particles / jnp.sum(is_species_j, axis=1, keepdims=True)
        else:
            mean_pair_corr = util.high_precision_sum(
                gdist, axis=(0, 1))  # sum nbrs
            mean_pair_corr /= n_particles
        return mean_pair_corr

    def rdf_compute_fun(state, species=None, **kwargs):
        """Returns the RDF of a single snapshot; the box may be supplied
        dynamically via the state or the 'box' kwarg.
        """
        box, _ = _dyn_box(reference_box, state, **kwargs)
        # Note: we cannot use neighbor list since RDF cutoff and
        # neighbor list cut-off don't coincide in general
        n_particles, spatial_dim = state.position.shape
        total_vol = quantity.volume(spatial_dim, box)
        mean_pair_corr = pair_corr_fun(state.position, box, species)

        # RDF is defined to relate the particle densities to an ideal gas.
        particle_density = n_particles / total_vol
        normalization = _ideal_gas_density(particle_density, bin_boundaries)
        rdf = mean_pair_corr / normalization
        return rdf

    return rdf_compute_fun
def _triplet_pairwise_displacements(position,
                                    neighbor: partition.NeighborList,
                                    displacement_fn,
                                    species = None,
                                    max_triplets: int = None,
                                    return_mask: bool = False):
    """Computes the displacements r_ij and r_kj between triplets of particles.

    The triplet edges ``ij`` and ``kj`` are obtained from
    ``custom_partition.get_triplet_indices``. For each edge, the displacement
    function is evaluated on the positions of its first and second index.

    Args:
        position: Particle positions.
        neighbor: Neighbor list from which the triplets are built.
        displacement_fn: Displacement function applied per edge.
        species: Unused.
        max_triplets: If given, only the first ``max_triplets`` triplets are
            considered.
        return_mask: Whether to additionally return the triplet mask.

    Returns:
        Returns a tuple (r_kj, r_ij) that contains the displacement vectors
        for all triplets, followed by the mask if ``return_mask`` is set.
    """
    # Compute the indices of the triplet edges
    ij, kj, mask = custom_partition.get_triplet_indices(neighbor)
    if max_triplets is not None:
        ij, kj = ij[:max_triplets, ...], kj[:max_triplets, ...]
        mask = mask[:max_triplets]

    # Compute the displacements
    edge_displacement = vmap(displacement_fn)
    r_ij = edge_displacement(position[ij[:, 0]], position[ij[:, 1]])
    r_kj = edge_displacement(position[kj[:, 0]], position[kj[:, 1]])

    return (r_kj, r_ij, mask) if return_mask else (r_kj, r_ij)
def _triplet_species(neighbor,
                     species,
                     max_triplets = None,
                     return_mask: bool = False):
    """Looks up the species of the three particles of every triplet.

    Args:
        neighbor: Neighbor list from which the triplets are built.
        species: Species of each particle.
        max_triplets: If given, only the first ``max_triplets`` triplets are
            considered.
        return_mask: Whether to additionally return the triplet mask.

    Returns:
        A tuple ``(si, sj, sk)`` with the species of the triplet particles,
        followed by the mask if ``return_mask`` is set.
    """
    ij, kj, mask = custom_partition.get_triplet_indices(neighbor)
    if max_triplets is not None:
        ij, kj, mask = ij[:max_triplets], kj[:max_triplets], mask[:max_triplets]

    si, sj = species[ij[:, 0]], species[ij[:, 1]]
    sk = species[kj[:, 0]]

    return (si, sj, sk, mask) if return_mask else (si, sj, sk)
[docs]
def init_adf_nbrs(displacement_fn,
                  adf_params: ADFParams,
                  adf_species: Array = None,
                  smoothing_dr: float = 0.01,
                  r_init: Array = None,
                  nbrs_init: partition.NeighborList = None,
                  max_weight_multiplier: float = 2.):
    """Initializes a function that computes the angular distribution function
    (ADF).

    Smoothens the histogram in angular direction via a Gaussian kernel
    (compare RDF function).
    In radial direction, triplets are weighted according to a
    Gaussian cumulative distribution function, such that triplets with both
    radii inside the cut-off band are weighted approximately 1 and the weights
    of triplets towards the band edges are smoothly reduced to 0.

    For computational speed-up and reduced memory needs, ``r_init`` and
    ``nbrs_init`` can be provided to estimate the maximum number of triplets.

    Warning:
        Currently, the user does not receive information whether overflow
        occurred.

    Note:
        This function assumes that r_outer is smaller than the neighbor
        list cut-off. If this is not the case, a function computing all
        pairwise distances is necessary.

    Args:
        displacement_fn: Displacement function
        adf_params: Hyperparameters of the ADF
        adf_species: Array of species triplets for which a separate ADF is
            computed. Only takes effect if the returned function is also
            called with the ``species`` of each particle.
        smoothing_dr: Standard deviation of Gaussian smoothing in radial
            direction
        r_init: Initial positions to estimate maximum number of triplets
        nbrs_init: Initial neighborlist to estimate maximum number of triplets
        max_weight_multiplier: Multiplier to increase maximum number of
            triplets

    Returns:
        Returns a function that takes a simulation state with neighborlist and
        computes the instantaneous adf.

    Raises:
        ValueError: If ``r_init`` is given without ``nbrs_init``.
    """
    _, bin_centers, sigma_theta, r_outer, r_inner = dataclasses.astuple(
        adf_params)
    sigma_theta = util.f32(sigma_theta)
    bin_centers = util.f32(bin_centers)

    def cut_off_weights(r_kj, r_ij, mask):
        """Smoothly constrains triplets to a radial band such that both
        distances are between r_inner and r_outer. The Gaussian cdf is used
        for smoothing. The smoothing width can be controlled by the gaussian
        standard deviation.
        """
        dist_kj = space.distance(r_kj)
        dist_ij = space.distance(r_ij)
        # get d_small and d_large for each triplet
        pair_dist = jnp.column_stack((dist_kj, dist_ij)).sort(axis=1)
        # get inner boundary weight from r_small and outer weight from r_large
        # NOTE(review): ``scale`` of norm.cdf is a standard deviation, but
        # smoothing_dr**2 is passed although smoothing_dr is documented as
        # the standard deviation. Kept unchanged to preserve results --
        # confirm which one is intended.
        inner_weight = norm.cdf(pair_dist[:, 0], loc=r_inner,
                                scale=smoothing_dr**2)
        outer_weight = 1 - norm.cdf(pair_dist[:, 1], loc=r_outer,
                                    scale=smoothing_dr**2)
        weights = outer_weight * inner_weight
        return weights * mask

    def weighted_adf(angles, weights):
        """Compute weighted ADF contribution of each triplet. For
        differentiability, each triplet contribution is smoothed via a
        Gaussian.
        """
        exp = jnp.exp(util.f32(-0.5) * (angles[:, jnp.newaxis] - bin_centers)**2
                      / sigma_theta**2)
        gaussians = exp / jnp.sqrt(2 * jnp.pi * sigma_theta**2)
        gaussians *= weights[:, jnp.newaxis]
        unnormed_adf = util.high_precision_sum(gaussians, axis=0)
        adf = unnormed_adf / jnp.trapezoid(unnormed_adf, bin_centers)
        return adf

    # We use initial configuration to estimate the maximum number of triplets
    # inside the cutoff radii.
    if r_init is not None:
        if nbrs_init is None:
            raise ValueError(
                'If we estimate the maximum number of triplets, the initial '
                'neighbor list is a necessary input.'
            )

        r_kj, r_ij, mask = _triplet_pairwise_displacements(
            r_init, nbrs_init, displacement_fn, return_mask=True)
        weights = cut_off_weights(r_kj, r_ij, mask)

        # Never request more entries than triplets exist
        max_weights = min(
            int(jnp.sum(weights > 1.e-6) * max_weight_multiplier),
            mask.size
        )
        max_triplets = min(
            int(jnp.sum(mask) * max_weight_multiplier),
            mask.size
        )

        print(f"[ADF] Estimates {max_triplets} max. triplets in neighbor list "
              f"and {max_weights} max. triplets in cutoff-shell.")
    else:
        max_weights = None
        max_triplets = None

    def adf_fn(state, neighbor, species=None, **kwargs):
        """Returns ADF for a single snapshot. Allows changing the box
        on-the-fly via the 'box' kwarg.
        """
        dyn_displacement = partial(displacement_fn, **kwargs)  # box kwarg

        r_kj, r_ij, mask = _triplet_pairwise_displacements(
            state.position, neighbor, dyn_displacement,
            max_triplets=max_triplets, return_mask=True
        )
        weights = cut_off_weights(r_kj, r_ij, mask)

        # The species-resolved ADF needs both the requested species triplets
        # and the species of the particles. Checking only ``species`` would
        # index into ``adf_species=None`` whenever a caller passes species
        # without having requested a species-resolved ADF.
        if species is not None and adf_species is not None:
            # Compute a second mask depending on the species of the triplets
            si, sj, sk = _triplet_species(neighbor, species, max_triplets)
            selection = si[None, :] == adf_species[:, (0,)]
            selection = jnp.logical_and(
                selection, sj[None, :] == adf_species[:, (1,)])
            selection = jnp.logical_and(
                selection, sk[None, :] == adf_species[:, (2,)])
        else:
            selection = None

        if max_triplets is not None:
            # Prune triplets by returning the most important weights
            _, sorting_idxs = lax.top_k(weights, max_weights)
            weights = weights[sorting_idxs]
            r_ij = r_ij[sorting_idxs]
            r_kj = r_kj[sorting_idxs]
            mask = mask[sorting_idxs]

            if selection is not None:
                selection = selection[:, sorting_idxs]

        # TODO check for overflow

        if selection is not None:
            # One weight vector (and hence one ADF) per species triplet
            weights = jnp.einsum('sn,n->sn', selection, weights)
            _adf_fn = vmap(weighted_adf, in_axes=(None, 0))
        else:
            _adf_fn = weighted_adf

        # Masked triplets are passed through safe_angle_mask before the
        # angle computation (presumably to keep the angle differentiable
        # for padded entries -- see sparse_graph.safe_angle_mask).
        r_ij_safe, r_kj_safe = sparse_graph.safe_angle_mask(r_ij, r_kj, mask)
        angles = vmap(sparse_graph.angle)(r_ij_safe, r_kj_safe)

        return _adf_fn(angles, weights)

    return adf_fn
[docs]
def init_tcf_nbrs(displacement_fn,
                  tcf_params: TCFParams,
                  reference_box: Array = None,
                  nbrs_init: partition.NeighborList = None,
                  batch_size: int = 1000,
                  max_weight_multiplier: float = 1.2,
                  tcf_species: Array = None):
    """Initializes a function to compute the triplet correlation function (TCF).

    This function assumes that the neighbor list cutoff matches the TCF cutoff.

    Args:
        displacement_fn: Displacement function
        tcf_params: TCFParams defining the hyperparameters of the TCF
        reference_box: Simulation box. Can be provided here for constant boxes
            or on-the-fly as kwarg ``'box'``, e.g., for NPT ensemble
        nbrs_init: Initial neighborlist to estimate maximum number of
            triplets. Currently required.
        batch_size: Batch size for more efficient binning of triplets
        max_weight_multiplier: Multiplier for estimate of number of triplets
        tcf_species: Not supported yet; must be None.

    Returns:
        A function that takes a simulation state with neighborlist and returns
        the instantaneous tcf.

    Raises:
        NotImplementedError: If ``tcf_species`` is given or ``nbrs_init`` is
            not given.
    """
    if tcf_species is not None:
        raise NotImplementedError("Species-dependent TCF not yet implemented.")

    (_, sigma, volume, x_bin_centers, y_bin_centers,
     z_bin_centers) = dataclasses.astuple(tcf_params)
    # presumably x_bin_centers is the sparse meshgrid output of
    # tcf_discretization, whose second axis holds the bins -- TODO confirm
    nbins = x_bin_centers.shape[1]

    # We use the initial configuration to estimate the maximum number of
    # non-zero weights to speed up the computation and improve the memory
    # footprint
    if nbrs_init is None:
        raise NotImplementedError('nbrs_init currently needs to be provided.')

    # Only the mask is used here; the displacements are discarded.
    r_ij, r_kj, mask = _triplet_pairwise_displacements(
        nbrs_init.reference_position, nbrs_init, displacement_fn,
        return_mask=True
    )
    max_triplets = int(jnp.count_nonzero(mask > 1.e-6) * max_weight_multiplier)

    # Increase the maximum number of triplets to enable simple batching
    # (always rounds up to the next multiple of batch_size, adding one full
    # batch if the estimate already is a multiple).
    rem = jnp.remainder(max_triplets, batch_size)
    max_triplets = max_triplets + (batch_size - rem)

    def gaussian_3d_bins(exp, inputs):
        """Scan body: adds the Gaussian-smoothed histogram contribution of
        one batch of triplets to the carried histogram ``exp``.
        """
        triplet_distances, triplet_mask = inputs
        batch_exp = jnp.exp(util.f32(-0.5) * (
            (triplet_distances[:, 0, jnp.newaxis, jnp.newaxis, jnp.newaxis]
             - x_bin_centers)**2 / sigma**2
            + (triplet_distances[:, 1, jnp.newaxis, jnp.newaxis,
                                 jnp.newaxis] - y_bin_centers)**2 / sigma**2
            + (triplet_distances[:, 2, jnp.newaxis, jnp.newaxis,
                                 jnp.newaxis] - z_bin_centers)**2 / sigma**2
        ))
        # Remove contributions of masked (padded) triplets
        batch_exp *= triplet_mask[:, jnp.newaxis, jnp.newaxis, jnp.newaxis]
        batch_exp = jnp.sum(batch_exp, axis=0)
        exp += batch_exp
        # The second output is the unused per-step result of the scan.
        return exp, 0

    def triplet_corr_fun(r_kj, r_ij, triplet_mask):
        """Returns instantaneous triplet correlation function while ensuring
        each particle pair contributes exactly 1.
        """
        # Close the triplet triangle
        r_ki = r_kj - r_ij
        dist_kj = space.distance(r_kj)
        dist_ij = space.distance(r_ij)
        dist_ki = space.distance(r_ki)
        histogram = jnp.zeros((nbins, nbins, nbins))
        distances = jnp.stack((dist_kj, dist_ij, dist_ki), axis=1)
        # NOTE(review): the reshapes require the number of triplets to be a
        # multiple of batch_size -- verify that the triplet arrays always
        # contain at least max_triplets entries.
        distances = jnp.reshape(distances, (-1, batch_size, 3))
        triplet_mask = jnp.reshape(triplet_mask, (-1, batch_size))

        # scan over per-batch computations for computational efficiency
        histogram = lax.scan(gaussian_3d_bins, histogram,
                             (distances, triplet_mask))[0]
        return histogram / volume / jnp.sqrt((2 * jnp.pi)**3)

    def tcf_fn(state, neighbor, **kwargs):
        """Returns TCF for a single snapshot. Allows changing the box
        on-the-fly via the 'box' kwarg.
        """
        dyn_displacement = partial(displacement_fn, **kwargs)  # box kwarg
        r_kj, r_ij, triplet_mask = _triplet_pairwise_displacements(
            state.position, neighbor, dyn_displacement,
            max_triplets=max_triplets, return_mask=True)

        box, _ = _dyn_box(reference_box, state, **kwargs)
        n_particles, spatial_dim = state.position.shape
        total_vol = quantity.volume(spatial_dim, box)
        particle_density = n_particles / total_vol

        tcf = triplet_corr_fun(r_kj, r_ij, triplet_mask)
        return tcf / n_particles / particle_density ** 2

    return tcf_fn
def _nearest_tetrahedral_nbrs(displacement_fn, position, nbrs):
    """Returns the displacement vectors r_ij of the 4 nearest neighbors.

    Args:
        displacement_fn: Displacement function
        position: Particle positions
        nbrs: Neighbor list whose ``idx`` array uses the number of rows as
            index for non-existing neighbors.

    Returns:
        The displacements to the four closest neighbors of every particle.
    """
    neighbor_displacement = space.map_neighbor(displacement_fn)
    n_particles, _ = nbrs.idx.shape
    neighbor_mask = nbrs.idx != n_particles
    r_neigh = position[nbrs.idx]
    # R_ij = R_i - R_j; i = central atom
    displacements = neighbor_displacement(position, r_neigh)
    distances = space.distance(displacements)
    # Mask non-existing neighbors. jnp.where returns a new array, so the
    # result has to be assigned; otherwise padded entries could be selected
    # as nearest neighbors.
    distances = jnp.where(neighbor_mask, distances, 1.e7)
    _, nearest_idxs = lax.top_k(-1 * distances, 4)  # 4 nearest neighbor indices
    nearest_displ = jnp.take_along_axis(
        displacements, jnp.expand_dims(nearest_idxs, -1), axis=1)
    return nearest_displ
[docs]
def init_tetrahedral_order_parameter(displacement_fn):
    """Initializes a function that computes the tetrahedral order parameter q
    for a single state.

    Args:
        displacement_fn: Displacement function

    Returns:
        A function that takes a simulation state with neighborlist and returns
        the instantaneous q value.
    """
    @partial(vmap, in_axes=(None, 0, 0))
    def _masked_inner(nn_disp, j, k):
        # Squared deviation of the cosine between neighbors j and k from
        # -1/3; only pairs with k > j contribute.
        cos_jk = vmap(quantity.cosine_angle_between_two_vectors)(
            nn_disp[:, j], nn_disp[:, k])
        return (k > j) * jnp.square(cos_jk + (1. / 3.))

    def q_fn(state, neighbor, **kwargs):
        dyn_displacement = partial(displacement_fn, **kwargs)
        nn_displacements = _nearest_tetrahedral_nbrs(
            dyn_displacement, state.position, neighbor)

        # Enumerate all (j, k) index combinations of the nearest neighbors
        all_j, all_k = jnp.meshgrid(jnp.arange(3), jnp.arange(4))
        pair_terms = _masked_inner(
            nn_displacements, all_j.ravel(), all_k.ravel())

        per_particle = jnp.sum(pair_terms, axis=0)
        return 1 - (3. / 8.) * jnp.mean(per_particle)

    return q_fn
[docs]
def init_local_structure_index(displacement_fn,
                               r_cut: float = 0.37,
                               reference_box = None,
                               r_init = None,
                               max_pairs_multiplier: float = 3.0):
    """Initializes function to compute the local structure index (LSI).

    The LSI measures the gap between the first and second solvation shell
    [#dobouedijon2015]_.

    Args:
        displacement_fn: Function to compute the particle distances
        r_cut: Cutoff of second solvation shell
        reference_box: Reference box to compute particle distances. Necessary
            if no dynamic box is provided.
        r_init: Initial coordinates to estimate the number of particles in the
            shell. If not given, all particles are considered.
        max_pairs_multiplier: Multiplies the estimated maximum number of
            particles in a shell.

    Returns:
        A function that takes a simulation state and returns the mean LSI
        over all particles.

    References:
        .. [#dobouedijon2015] E. Duboué-Dijon und D. Laage,
           „Characterization of the Local Structure in Liquid Water by Various
           Order Parameters“, J. Phys. Chem. B, Bd. 119, Nr. 26, S. 8406–8418,
           Juli 2015, doi: 10.1021/acs.jpcb.5b02936.
    """
    # Estimate the number of pairs inside the first two solvation shells to
    # speed up the computation
    distance_metric = space.canonicalize_displacement_or_metric(
        displacement_fn)

    def _estimate_max_pairs():
        # Without initial positions the number of particles is unknown at
        # this point; the selection size is then derived from the distance
        # matrix at call time.
        if r_init is None:
            return None

        num_atoms = r_init.shape[0]
        metric = space.map_product(distance_metric)
        distances = metric(r_init, r_init)
        mp = jnp.max(jnp.sum(distances < r_cut, axis=1))
        mp = int(mp * max_pairs_multiplier)
        # top_k cannot select more entries than there are particles
        mp = min(mp, num_atoms)

        print(f"[LSI] Consider {mp} number of pairs.")
        return mp

    max_pairs = _estimate_max_pairs()

    @vmap
    def _single_lsi(dist):
        # Speed up the computation by only considering a subset of all
        # particles.
        # We seek to include the closest particles, so we have to select k
        # maxima of the negative distance.
        # Additionally, the calculation of the lsi requires sorting
        # these closest particles.
        n_selected = dist.shape[0] if max_pairs is None else max_pairs
        _, selected = lax.top_k(-dist, n_selected)
        # Drop the smallest entry (the distance of the particle to itself)
        dr = lax.sort(dist[selected])[1:]

        mask = (dr < util.f32(r_cut))[:-1]

        # Mask out particles that are not close to any other particles
        count = jnp.sum(mask)
        mask /= jnp.where(count > 0, count, 1)

        # Compute variance between the particle distance increments
        delta = jnp.diff(dr)
        lsi = jnp.sum(mask * jnp.square(delta - jnp.sum(mask * delta)))

        # Set the LSI to zero for isolated particles
        return (count > 0) * lsi

    def lsi_fn(state, **kwargs):
        box, _ = _dyn_box(reference_box, state, **kwargs)

        # Incorporate the dynamic box and compute the distance between all
        # pairs of the particles
        dyn_metric = partial(distance_metric, box=box)
        distance_fn = space.map_product(dyn_metric)
        distances = distance_fn(state.position, state.position)

        single_lsi = _single_lsi(distances)
        return jnp.mean(single_lsi)

    return lsi_fn
[docs]
def init_rmsd(reference_positions,
              displacement_fn,
              reference_box,
              idx=None,
              weights=None):
    """Initializes the root mean squared distance from a reference structure.

    The RMSD is a common measure in the analysis of
    macrostructures [#sargsyan2017]_.

    The weighted RMSD between current positions $p$ and reference positions
    $q$ is defined as

    .. math ::

        \\mathrm{RMSD} = \\sqrt{\\frac{\\sum_{i=1}^n w_i || (Rp_i + t) - q_i||^2}{\\sum_{i=1}^n w_i}},

    where $R$ and $t$ define a rigid body motion that minimizes the
    RMSD [#hornung2017]_.

    Args:
        reference_positions: Reference positions including all atoms.
        displacement_fn: Function to compute displacement between particles.
            It is called with a ``box`` keyword argument.
        reference_box: Reference box of the reference structure.
        idx: Indices selecting only the structure of interest, e.g. for
            a protein in a solvent. Defaults to all atoms.
        weights: Weight the rmsd, e.g., with masses of the particles. One
            weight per selected atom; an array with one weight per atom of
            the full system is restricted to ``idx``. Defaults to uniform
            weights. The passed array is never modified.

    Returns:
        A function ``rmsd_fn(state, **kwargs)`` returning the weighted RMSD
        of the selected atoms in ``state.position`` to the reference.

    References:
        .. [#sargsyan2017] K. Sargsyan, C. Grauffel, and C. Lim,
           "How Molecular Size Impacts RMSD Applications in Molecular Dynamics
           Simulations", J. Chem. Theory Comput., vol. 13, no. 4,
           pp. 1518-1524, Apr. 2017, doi: 10.1021/acs.jctc.7b00028.
        .. [#hornung2017] O. Sorkine-Hornung and M. Rabinovich,
           "Least-Squares Rigid Motion Using SVD".
           https://igl.ethz.ch/projects/ARAP/svd_rot.pdf
    """
    n_atoms = reference_positions.shape[0]
    if idx is None:
        idx = onp.arange(n_atoms)
    else:
        idx = onp.asarray(idx)
    n_selected = idx.shape[0]

    if weights is None:
        weights = jnp.ones(n_selected)
    else:
        weights = jnp.asarray(weights)
        # Weights given for the full system (e.g. all masses) are reduced to
        # the selected atoms.
        if weights.shape[0] != n_selected and weights.shape[0] == n_atoms:
            weights = weights[idx]
    # Out-of-place normalization: never mutate the caller's array
    weights = weights / jnp.sum(weights)

    # The center of the positions does not matter as the structure is
    # fit later by a rigid body motion
    selected_reference = reference_positions[idx]
    q = vmap(partial(displacement_fn, box=reference_box),
             in_axes=(None, 0))(selected_reference[0], selected_reference)
    qbar = jnp.sum(weights[:, jnp.newaxis] * q, axis=0)
    Y = q - qbar

    def rmsd_fn(state, **kwargs):
        box, _ = _dyn_box(reference_box, state, **kwargs)
        dyn_displacement = partial(displacement_fn, box=box)

        # Compute the displacements with respect to the first selected atom
        # to deal with different kinds of boundary conditions
        selected = state.position[idx]
        p = vmap(dyn_displacement, in_axes=(None, 0))(selected[0], selected)
        pbar = jnp.sum(weights[:, jnp.newaxis] * p, axis=0)
        X = p - pbar

        # Compute the [d, d] covariance matrix for p.shape = (N, d) and perform
        # a singular value decomposition to obtain the optimal rotation and
        # translation that minimizes the weighted squared distance
        cov = jnp.einsum('ji,j,jk->ik', X, weights, Y)
        print(f"Covariance has shape {cov.shape}")
        U, _, Vh = jnp.linalg.svd(cov, full_matrices=True, compute_uv=True)
        print(f"Shapes are V: {Vh.shape}, U: {U.shape}")

        # The last singular direction is flipped if necessary such that the
        # result is a proper rotation (det = +1) and not a reflection
        det = jnp.linalg.det(jnp.dot(U, Vh.T).T)
        sig = jnp.append(jnp.ones(p.shape[1] - 1), det)
        rotation = jnp.einsum('ji,j,kj->ik', Vh, sig, U)
        translation = qbar - jnp.dot(rotation, pbar)

        # With the rigid body motion we can now compute the rmsd
        p_opt = jnp.einsum('ij,nj->ni', rotation, p)
        p_opt += translation[jnp.newaxis, :]
        msd = jnp.sum(weights[:, jnp.newaxis] * jnp.square(p_opt - q))
        return jnp.sqrt(msd)

    return rmsd_fn
def init_rigid_body_alignment(displacement_fn, reference_position, weights=None, **kwargs):
    """Initializes a function that aligns a structure to a reference structure.

    The aligned structure minimizes the (weighted) root mean squared distance
    to the reference structure under rotations and translations, i.e.,
    rigid body motions.

    Args:
        displacement_fn: Displacement function
        reference_position: Reference positions including all atoms.
        weights: Weight the rmsd, e.g., with masses of the particles.
            Defaults to uniform weights.
        **kwargs: Additional arguments for the displacement function when
            applied to the reference positions.

    Returns:
        Returns a function ``align_fn(position, **kwargs)`` to compute
        optimally aligned positions. Keyword arguments of ``align_fn`` are
        forwarded to the displacement function applied to ``position``.
        The aligned positions are expressed like the reference
        displacements, i.e., relative to the first reference particle.
    """
    ref_displacement_fn = partial(displacement_fn, **kwargs)

    n_particles = reference_position.shape[0]
    if weights is None:
        weights = jnp.ones(n_particles)
    else:
        weights = jnp.asarray(weights)
    total_weight = jnp.sum(weights)

    # The reference does not change between calls, so its displacements
    # with respect to the first particle and its center are computed once.
    q = vmap(ref_displacement_fn, in_axes=(0, None))(
        reference_position, reference_position[0, :])
    q_bar = jnp.sum(weights[:, jnp.newaxis] * q, axis=0) / total_weight
    q_centered = q - q_bar[jnp.newaxis, :]

    def align_fn(position, **dyn_kwargs):
        # Displacements with respect to the first particle deal with
        # different kinds of boundary conditions
        dyn_displacement_fn = partial(displacement_fn, **dyn_kwargs)
        p = vmap(dyn_displacement_fn, in_axes=(0, None))(
            position, position[0, :])
        p_bar = jnp.sum(weights[:, jnp.newaxis] * p, axis=0) / total_weight
        p_centered = p - p_bar[jnp.newaxis, :]

        # Compute the [d, d] covariance matrix for p.shape = (N, d) and perform
        # a singular value decomposition to obtain the optimal rotation
        # that minimizes the weighted squared distance
        cov = jnp.einsum('ji,j,jk->ik', p_centered, weights, q_centered)
        U, _, Vh = jnp.linalg.svd(cov, full_matrices=True, compute_uv=True)

        # Flip the last singular direction if necessary to obtain a proper
        # rotation instead of a reflection
        det = jnp.linalg.det(jnp.dot(Vh.T, U.T))
        sig = jnp.append(jnp.ones(p.shape[1] - 1), det)
        rotation = jnp.einsum('ji,j,kj->ik', Vh, sig, U)

        # The points are already centered, hence the optimal motion is the
        # rotation about the center followed by a shift onto the center of
        # the reference.
        p_opt = jnp.einsum('ij,nj->ni', rotation, p_centered)
        return p_opt + q_bar[jnp.newaxis, :]

    return align_fn
[docs]
def init_velocity_autocorrelation(num_lags):
    """Initializes a function computing the velocity autocorrelation (VACF).

    Args:
        num_lags: Number of time lags to compute VACF values for. The time lag
            is implicitly defined by the time difference between two adjacent
            states in the trajectory.

    Returns:
        A function that takes a trajectory state and returns an array
        containing the value of the VACF for each considered time lag.
    """
    # TODO this quadratic-scaling implementation of autocorrelation is not
    #  optimal. Long-term this should be using FFT if efficiency is critical

    def _vel_correlation(vel, lag):
        # Assume that array is of shape (Frames, Particles, Dimension).
        # We roll around the frame axis to create a lag and average over all
        # particles. This function handles a single (scalar) lag; the lags
        # are iterated by lax.map below, which passes one lag at a time.
        lagged_vel = jnp.roll(vel, axis=0, shift=lag)
        corr = jnp.mean(jnp.sum(vel * lagged_vel, axis=-1), axis=-1)

        # Since the first elementwise products are now (v_(n-lag) * v_0),
        # etc. we have to mask them out
        mask = jnp.arange(vel.shape[0]) >= lag
        return jnp.sum(mask * corr) / jnp.sum(mask)

    @jit
    def vac_fn(state, **kwargs):
        del kwargs

        # Due to a broadcasting error, it is necessary to compute the
        # velocity from momentum when the mass has only a single axis
        if state.mass.ndim == 1:
            velocity = state.momentum / state.mass[:, None, None]
        else:
            velocity = state.velocity

        lag_array = jnp.arange(num_lags)
        return lax.map(partial(_vel_correlation, velocity), lag_array)

    return vac_fn
[docs]
def self_diffusion_green_kubo(traj_state, time_step, t_cut):
    """Green-Kubo formulation to compute self-diffusion D via the velocity
    autocorrelation function (VACF).

    .. math::

        D = \\frac{1}{dim} \\int_0^{t_{cut}} VACF(\\tau) d\\tau

    Args:
        traj_state: TrajectoryState containing a finely resolved trajectory.
        time_step: Time lag between 2 adjacent states in the trajectory. The
            simulation time step, in the usual case where every state is
            retained.
        t_cut: Cut-off time: Biggest time-difference to consider in the VACF.

    Returns:
        A tuple (D, VACF). Estimate of self-diffusion D and VACF that can be
        used for additional post-processing / analysis.
    """
    num_lags = int(t_cut / time_step)

    # The VACF is evaluated on the stacked states of the trajectory
    vac_fn = init_velocity_autocorrelation(num_lags)
    vel_autocorr = vac_fn(traj_state.trajectory)

    dim = traj_state.trajectory.velocity.shape[-1]
    diffusion = jnp.trapezoid(vel_autocorr, dx=time_step) / dim
    return diffusion, vel_autocorr
[docs]
def init_bond_length(displacement_fn, bonds, average=False):
    """Initializes a function that computes bond lengths for given atom pairs.

    Args:
        displacement_fn: Displacement function
        bonds: (n, 2) array defining IDs of bonded particles
        average: If False, returns per-pair bond lengths.
            If True, returns scalar average over all pairs

    Returns:
        A function that takes a simulation state and returns bond lengths.
        Additional keyword arguments are forwarded to the metric.
    """
    metric = space.canonicalize_displacement_or_metric(displacement_fn)

    def bond_length(state, **kwargs):
        # Bind the keyword arguments (e.g., a dynamic box) before vmapping.
        # Keyword arguments of a vmapped function are always mapped over
        # their leading axis, which is not intended for shared arguments.
        dyn_metric = vmap(partial(metric, **kwargs))
        r1 = state.position[bonds[:, 0]]
        r2 = state.position[bonds[:, 1]]
        distances = dyn_metric(r1, r2)
        if average:
            return jnp.mean(distances)
        return distances

    return bond_length
def _bond_length(bonds, positions, displacement_fn):
    """Evaluates the length of every bond for one snapshot or a trajectory.

    A rank-2 ``positions`` array is treated as a single snapshot and a rank-3
    array as a stack of snapshots along the leading axis. Any other rank
    raises a ValueError.
    """
    def length_of_bond(pair, coords):
        separation = displacement_fn(coords[pair[0]], coords[pair[1]])
        return space.distance(separation)

    # Vectorize over the bonds of a single snapshot
    over_bonds = vmap(length_of_bond, (0, None))

    rank = positions.ndim
    if rank not in (2, 3):
        raise ValueError('Positions must be either of shape Ntimestep x Natoms'
                         ' x spatial_dim or N_atoms x spatial_dim')
    if rank == 2:
        return over_bonds(bonds, positions)
    # Additionally vectorize over the snapshots
    return vmap(over_bonds, (None, 0))(bonds, positions)
[docs]
def estimate_bond_constants(positions, bonds, displacement_fn):
    """Estimates equilibrium harmonic bond constants from given positions.

    Can be used to estimate the bond constants from an atomistic
    simulation to be used as a coarse-grained prior.

    Args:
        positions: Position vector of size [Ntimestep x Natoms x spatial_dim]
            or [N_atoms x spatial_dim]
        bonds: (n_bonds, 2) array defining IDs of bonded particles
        displacement_fn: Displacement function

    Returns:
        Tuple (eq_distances, eq_variances) of harmonic bond coefficients,
        i.e., mean and variance of the bond lengths along the leading axis.
    """
    bond_lengths = _bond_length(bonds, positions, displacement_fn)
    return jnp.mean(bond_lengths, axis=0), jnp.var(bond_lengths, axis=0)
[docs]
def angular_displacement(positions, displacement_fn, angle_idxs, degrees=True):
    """Computes the angle for all triplets of atoms given in idxs.

    For each triplet, the angle is measured between the two vectors
    connecting the second (central) atom with the first and the third atom.

    Args:
        positions: Positions of atoms in box
        displacement_fn: Displacement function
        angle_idxs: (n, 3) array defining IDs of triplet particles
        degrees: If False, returns angles in rads.
            If True, returns angles in degrees.

    Returns:
        An array (n,) of the angles.
    """
    batched_displacement = vmap(displacement_fn)
    first, center, last = (positions[angle_idxs[:, col]] for col in range(3))

    # Both vectors are oriented consistently with respect to the center
    to_first = -1. * batched_displacement(center, first)
    to_last = batched_displacement(last, center)

    cosines = vmap(quantity.cosine_angle_between_two_vectors)(to_first, to_last)
    angles = jnp.arccos(cosines)
    return jnp.degrees(angles) if degrees else angles
[docs]
def dihedral_displacement(positions, displacement_fn, dihedral_idxs,
                          degrees=True):
    """Evaluates the torsion angle of every atom quadruple listed in idxs.

    Args:
        positions: Positions of atoms in box
        displacement_fn: Displacement function
        dihedral_idxs: (n, 4) array defining IDs of quadruple particles
        degrees: If False, returns angles in rads.
            If True, returns angles in degrees.

    Returns:
        An array (n,) of the dihedral angles.
    """
    batched_displacement = vmap(displacement_fn)
    p0, p1, p2, p3 = (positions[dihedral_idxs[:, col]] for col in range(4))

    b0 = -1. * batched_displacement(p1, p0)
    axis = batched_displacement(p2, p1)
    b2 = batched_displacement(p3, p2)

    # The central bond is normalized such that its length does not scale the
    # vector rejections below
    axis = axis / jnp.linalg.norm(axis, axis=1)[:, None]

    def reject(vec):
        # Part of vec perpendicular to the central bond: vec minus the
        # component aligned with the axis
        return vec - jnp.sum(vec * axis, axis=1)[:, None] * axis

    v = reject(b0)
    w = reject(b2)

    # The torsion is the in-plane angle between v and w. Their norms cancel
    # in arctan2, so no further normalization is required.
    x = jnp.sum(v * w, axis=1)
    y = jnp.sum(vmap(jnp.cross)(axis, v) * w, axis=1)
    torsion = jnp.arctan2(y, x)
    return jnp.degrees(torsion) if degrees else torsion
[docs]
def kinetic_energy_tensor(state):
    """Evaluates the kinetic energy tensor for one snapshot.

    The mean velocity over the leading axis is removed before the
    mass-weighted dyadic products of the velocities are summed up.

    Args:
        state: Jax_md simulation state

    Returns:
        Kinetic energy tensor
    """
    velocities = state.velocity
    thermal_velocities = velocities - jnp.mean(velocities, axis=0)

    # Scaling by sqrt(m) lets each dyadic product carry one factor of mass
    scaled = jnp.sqrt(state.mass) * thermal_velocities
    dyadic_products = vmap(jnp.outer)(scaled, scaled)
    return util.high_precision_sum(dyadic_products, axis=0)
[docs]
def virial_potential_part(energy_fn,
                          state,
                          nbrs,
                          reference_box,
                          energy_and_force=None,
                          fractional_coordinates=True,
                          **kwargs):
    """Interaction part of the virial pressure tensor for a single snapshot
    based on the formulation of Chen et al. (2020). See
    init_virial_stress_tensor for details.

    Args:
        energy_fn: Energy function accepting positions as well as the
            ``neighbor`` and ``box`` keyword arguments.
        state: Simulation state providing ``position``.
        nbrs: Neighbor list passed to the energy function.
        reference_box: Box used if no dynamic box is provided.
        energy_and_force: Optional dictionary with precomputed results. If
            given, the box gradient is read from its ``'box_grad'`` entry
            instead of differentiating the energy function.
        fractional_coordinates: Currently unused.
        **kwargs: Additional arguments passed to the energy function.

    Returns:
        The product of the energy gradient with respect to the box and the
        transposed box.
    """
    position = state.position  # in unit box if fractional coordinates used
    box, kwargs = _dyn_box(reference_box, state, **kwargs)

    if energy_and_force is None:
        def energy_fn_(pos, neighbor, box):  # for grad
            return energy_fn(pos, neighbor=neighbor, box=box, **kwargs)

        # Only the gradient with respect to the box enters the virial
        box_gradient = grad(energy_fn_, argnums=2)(position, nbrs, box)
    else:
        print("[Virial] Found precomputed forces.")
        box_gradient = energy_and_force['box_grad']

    return jnp.dot(box_gradient, box.T)
[docs]
def init_virial_stress_tensor(energy_fn_template, reference_box=None,
                              include_kinetic=True, pressure_tensor=False):
    """Builds a function evaluating the virial stress tensor of one state.

    This function is applicable to arbitrary many-body interactions without
    explicit volume dependence, e.g., that only model interactions between
    images with minimum distance. It does not respect the volume dependence
    of, e.g., long-range electrostatic correction in periodic systems.

    Chen et al. "TensorAlloy: An automatic atomistic neural network program
    for alloys". Computer Physics Communications 250 (2020): 107057

    Args:
        energy_fn_template: A function that takes energy parameters as input
            and returns an energy function
        reference_box: The transformation T of general periodic boundary
            conditions. If None, box_tensor needs to be provided as ``'box'``
            during function call, e.g. for the NPT ensemble.
        include_kinetic: Whether kinetic part of stress tensor should be added.
        pressure_tensor: If False (default), returns the stress tensor. If True,
            returns the pressure tensor, i.e. the negative stress tensor.

    Returns:
        A function that takes a simulation state with neighbor list,
        energy_params and box (if applicable) and returns the instantaneous
        virial stress tensor.
    """
    # The pressure tensor is the negative stress tensor
    sign = -1. if pressure_tensor else 1.

    def virial_stress_tensor_neighborlist(state, neighbor, energy_params,
                                          **kwargs):
        # Note: this workaround with the energy_template was needed to keep
        # the function jitable when changing energy_params on-the-fly
        # TODO function to transform box to box-tensor
        box, kwargs = _dyn_box(reference_box, state, **kwargs)
        energy_fn = energy_fn_template(energy_params)

        tensor = virial_potential_part(
            energy_fn, state, neighbor, box, **kwargs)
        volume = quantity.volume(state.position.shape[-1], box)

        if include_kinetic:
            tensor = -1 * kinetic_energy_tensor(state) + tensor
        return sign * tensor / volume

    return virial_stress_tensor_neighborlist
[docs]
def init_pressure(energy_fn_template, reference_box=None, include_kinetic=True):
    """Initializes a function that computes the pressure for a single state.

    This function is applicable to arbitrary many-body interactions, even
    under periodic boundary conditions. See `init_virial_stress_tensor`
    for details.

    Args:
        energy_fn_template: A function that takes energy parameters as input
            and returns an energy function
        reference_box: The transformation T of general periodic boundary
            conditions. If None, box_tensor needs to be provided as
            'box' during function call, e.g. for NPT ensemble.
        include_kinetic: Whether kinetic part of stress tensor should be added.

    Returns:
        A function that takes a simulation state with neighbor list,
        energy_params and box (if applicable) and returns the instantaneous
        pressure.
    """
    # pressure is negative hydrostatic stress
    stress_tensor_fn = init_virial_stress_tensor(
        energy_fn_template, reference_box, include_kinetic=include_kinetic,
        pressure_tensor=True
    )

    def pressure_neighborlist(state, neighbor, energy_params, **kwargs):
        pressure_tensor = stress_tensor_fn(state, neighbor, energy_params,
                                           **kwargs)
        # Average the diagonal over the actual number of spatial dimensions
        # instead of assuming a three-dimensional system
        spatial_dim = pressure_tensor.shape[-1]
        return jnp.trace(pressure_tensor) / spatial_dim

    return pressure_neighborlist
[docs]
def energy_under_strain(epsilon, energy_fn, box_tensor, state, neighbor,
                        **kwargs):
    """Evaluates the potential energy of a state in a linearly strained box.

    The box handed to the energy function is ``box_tensor (I + epsilon)``.
    """
    # Note: When computing the gradient, we deal with infinitesimally
    # small strains. Linear strain theory is therefore valid and
    # additionally tan(gamma) = gamma. These assumptions are used
    # when computing the box after applying the strain.
    deformation = jnp.eye(box_tensor.shape[0]) + epsilon
    return energy_fn(state.position,
                     neighbor=neighbor,
                     box=jnp.dot(box_tensor, deformation),
                     **kwargs)
[docs]
def init_sigma_born(energy_fn_template, reference_box=None):
    """Builds a function evaluating the Born contribution to the stress
    tensor.

    sigma^B_ij = d U / d epsilon_ij

    Can also be used to compute the stress tensor at kbT = 0, when called
    on the state of minimum energy. This function requires that `energy_fn`
    takes a `box` keyword argument, usually alongside `periodic_general`
    boundary conditions.

    Args:
        energy_fn_template: A function that takes energy parameters as input
            and returns an energy function
        reference_box: The transformation T of general periodic boundary
            conditions. If None, box_tensor needs to be provided as
            'box' during function call, e.g. for the NPT ensemble.

    Returns:
        A function that takes a simulation state with neighbor list,
        energy_params and box (if applicable) and returns the instantaneous
        Born contribution to the stress tensor.
    """
    def sigma_born(state, neighbor, energy_params, **kwargs):
        box, kwargs = _dyn_box(reference_box, state, **kwargs)
        n_dim = state.position.shape[-1]
        volume = quantity.volume(n_dim, box)

        # Differentiate the energy with respect to the strain, evaluated at
        # the unstrained configuration
        zero_strain = jnp.zeros((n_dim, n_dim))
        strain_gradient_fn = jacrev(energy_under_strain)
        energy_fn = energy_fn_template(energy_params)
        strain_gradient = strain_gradient_fn(
            zero_strain, energy_fn, box, state, neighbor, **kwargs)
        return strain_gradient / volume

    return sigma_born
[docs]
def init_stiffness_tensor_stress_fluctuation(energy_fn_template, reference_box):
    """Builds the function evaluating the Born contribution to the elastic
    stiffness tensor, as required by the stress fluctuation method.

    The returned function computes an instantaneous property of a single
    snapshot. For compatibility with DiffTRe, (weighted) ensemble averages
    need to be computed manually. The implementation follows the formulation
    derived by Van Workum et al., "Isothermal stress and elasticity tensors
    for ions and point dipoles using Ewald summations",
    PHYSICAL REVIEW E 71, 061102 (2005).

    # TODO provide sample usage

    Args:
        energy_fn_template: A function that takes energy parameters as input
            and returns an energy function
        reference_box: The transformation T of general periodic boundary
            conditions. Used if no box is passed in the dynamic keyword
            arguments of the returned function.

    Returns:
        born_term_fn: A function computing the Born contribution to the
        stiffness tensor for a single snapshot.
    """
    # TODO this function simplifies a lot if split between per-snapshot
    #  and per-trajectory functions

    def born_term_fn(state, neighbor, energy_params, **kwargs):
        """Born contribution to the stiffness tensor:
        C^B_ijkl = d^2 U / d epsilon_ij d epsilon_kl
        """
        # Use the box passed in the dynamic kwargs if provided, else the
        # reference box
        box, kwargs = _dyn_box(reference_box, state, **kwargs)
        n_dim = state.position.shape[-1]
        volume = quantity.volume(n_dim, box)

        # Second derivative of the energy with respect to the strain,
        # evaluated at the unstrained configuration
        zero_strain = jnp.zeros((n_dim, n_dim))
        energy_fn = energy_fn_template(energy_params)
        strain_hessian_fn = jax.hessian(energy_under_strain)
        strain_hessian = strain_hessian_fn(
            zero_strain, energy_fn, box, state, neighbor, **kwargs)
        return strain_hessian / volume

    return born_term_fn