Source code for jax_md_mod.custom_energy

# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom definition of some potential energy functions."""
from functools import partial
from typing import Callable, Any

from jax import vmap
import jax.numpy as jnp
from jax_md import space, partition, util, energy, smap

from jax_md_mod import custom_interpolate, custom_quantity
from jax_md_mod.model import sparse_graph

# Types
# Dtype aliases re-exported from jax_md.util.
f32 = util.f32
f64 = util.f64
# Generic array type used in the annotations of this module.
Array = util.Array

# Arbitrary pytree (nested containers of arrays); typed loosely as Any.
PyTree = Any
# Aliases for jax_md space types used in function signatures below.
Box = space.Box
DisplacementFn = space.DisplacementFn
DisplacementOrMetricFn = space.DisplacementOrMetricFn

# Aliases for jax_md neighbor-list types.
NeighborFn = partition.NeighborFn
NeighborList = partition.NeighborList


def stillinger_weber_energy(dr,
                            d_vect,
                            mask=None,
                            a=7.049556277,
                            b=0.6022245584,
                            p=4,
                            lam=21.0,
                            epsilon=2.16826,
                            gamma=1.2,
                            sigma=2.0951,
                            cutoff=1.8*2.0951,
                            three_body_strength=1.0):
    """Computes the Stillinger-Weber potential energy.

    The Stillinger-Weber (SW) potential [#Stillinger]_ is commonly used to
    model silicon and similar systems. This function uses the default SW
    parameters from the original paper.

    The SW potential was originally proposed to model diamond in the diamond
    crystal phase and the liquid phase, and is known to give unphysical
    amorphous configurations [#Holender]_ [#Barkema]_. For this reason, we
    provide a three_body_strength parameter. Changing this number to $1.5$
    or $2.0$ has been known to produce more physical amorphous phase,
    preventing most atoms from having more than four nearest neighbors.
    Note that this function currently assumes nearest-image-convention.

    References:
        .. [#Stillinger] Stillinger, Frank H., and Thomas A. Weber. "Computer
           simulation of local order in condensed phases of silicon."
           Physical review B 31.8 (1985): 5262.
        .. [#Holender] Holender, J. M., and G. J. Morgan. "Generation of a
           large structure (10^5 atoms) of amorphous Si using molecular
           dynamics." Journal of Physics: Condensed Matter 3.38 (1991): 7241.
        .. [#Barkema] Barkema, G. T., and Normand Mousseau. "Event-based
           relaxation of continuous disordered systems." Physical review
           letters 77.21 (1996): 4358.

    Args:
        dr: An ndarray of pairwise distances between particles. Either an
            ``(N, N)`` all-pairs matrix (``mask=None``) or an
            ``(N, max_neighbors)`` dense neighbor-list array.
        d_vect: An ndarray of pairwise displacements between particles,
            matching ``dr`` with an additional trailing spatial dimension.
        mask: Optional ndarray of the same shape as ``dr`` that is nonzero
            for existing neighbors and zero for padded entries of a neighbor
            list. If None, all pairs are used.
        a: A scalar that determines the scale of the two-body term.
        b: Factor for the radial power term.
        p: Power in the radial interaction.
        lam: A scalar that determines the scale of the three-body term.
        epsilon: A scalar that sets the energy scale.
        gamma: Exponential scale in the three-body term.
        sigma: A scalar that sets the length scale.
        cutoff: Cut-off distance of both the two-body and the three-body
            term. The default corresponds to 1.8 times the default sigma.
        three_body_strength: A scalar that determines the relative strength
            of the angular interaction.

    Returns:
        The (scalar) Stillinger-Weber energy of the snapshot.
    """
    if mask is None:
        # All-pairs input: every pair and every triplet is considered.
        n_particles = dr.shape[0]
        mask = jnp.ones([n_particles, n_particles])
        angle_mask = jnp.ones([n_particles, n_particles, n_particles])
    else:
        # Neighbor-list input: a triplet centered on particle i only
        # contributes if both of its neighbor slots hold real neighbors.
        max_neighbors = mask.shape[-1]
        angle_mask1 = jnp.tile(jnp.expand_dims(mask, 1),
                               [1, max_neighbors, 1])
        angle_mask2 = jnp.tile(jnp.expand_dims(mask, -1),
                               [1, 1, max_neighbors])
        angle_mask = angle_mask1 * angle_mask2

    sw_radial_interaction = partial(energy._sw_radial_interaction,
                                    sigma=sigma, p=p, b=b, cutoff=cutoff)
    sw_angle_interaction = partial(energy._sw_angle_interaction, gamma=gamma,
                                   sigma=sigma, cutoff=cutoff)
    # Evaluate the angular term for every combination of two displacement
    # vectors d_vect[i, j] and d_vect[i, k] sharing the central particle i.
    sw_three_body_term = vmap(vmap(vmap(sw_angle_interaction, (0, None)),
                                   (None, 0)), 0)

    radial_interactions = sw_radial_interaction(dr) * mask
    angular_interactions = sw_three_body_term(d_vect, d_vect) * angle_mask
    # Ordered index combinations visit each pair (i, j) / (j, i) and each
    # neighbor pair (j, k) / (k, j) twice, hence the division by 2.
    first_term = a * jnp.sum(radial_interactions) / 2.0
    second_term = lam * jnp.sum(angular_interactions) / 2.0
    return epsilon * (first_term + three_body_strength * second_term)
def stillinger_weber_pair(displacement,
                          a=7.049556277,
                          b=0.6022245584,
                          p=4,
                          lam=21.0,
                          epsilon=2.16826,
                          gamma=1.2,
                          sigma=2.0951,
                          cutoff=1.8*2.0951,
                          three_body_strength=1.0):
    """Builds a Stillinger-Weber energy function over all particle pairs.

    Convenience wrapper around ``stillinger_weber_energy`` that evaluates
    all N x N pairs, i.e. without a neighbor list. The model parameters are
    fixed when the wrapper is called.

    Args:
        displacement: Displacement function. Extra keyword arguments passed
            to the returned function are forwarded to it.
        a, b, p, lam, epsilon, gamma, sigma, cutoff, three_body_strength:
            Stillinger-Weber parameters, see ``stillinger_weber_energy``.

    Returns:
        A function ``compute_fn(pos, **dynamic_kwargs)`` returning the
        Stillinger-Weber energy of the configuration ``pos``.
    """
    sw_params = (a, b, p, lam, epsilon, gamma, sigma, cutoff,
                 three_body_strength)

    def compute_fn(pos, **dynamic_kwargs):
        displacement_fn = partial(displacement, **dynamic_kwargs)
        # N x N x dim displacement matrix and the matching N x N distances.
        pair_displacements = space.map_product(displacement_fn)(pos, pos)
        pair_distances = space.distance(pair_displacements)
        return stillinger_weber_energy(pair_distances, pair_displacements,
                                       None, *sw_params)

    return compute_fn
def stillinger_weber_neighborlist(displacement,
                                  box_size=None,
                                  a=7.049556277,
                                  b=0.6022245584,
                                  p=4,
                                  lam=21.0,
                                  epsilon=2.16826,
                                  gamma=1.2,
                                  sigma=2.0951,
                                  cutoff=1.8*2.0951,
                                  three_body_strength=1.0,
                                  dr_threshold=0.1,
                                  capacity_multiplier=1.25,
                                  initialize_neighbor_list=True):
    """Builds a Stillinger-Weber energy function that uses a neighbor list.

    Args:
        displacement: Displacement function. Extra keyword arguments passed
            to the energy function are forwarded to it.
        box_size: Box used to build the neighbor list. Required if
            ``initialize_neighbor_list`` is True.
        a, b, p, lam, epsilon, gamma, sigma, cutoff, three_body_strength:
            Stillinger-Weber parameters, see ``stillinger_weber_energy``.
            ``cutoff`` is also used as the neighbor-list cutoff.
        dr_threshold: Neighbor-list skin / update threshold.
        capacity_multiplier: Neighbor-list capacity multiplier.
        initialize_neighbor_list: If True, also build a neighbor list
            function. Set to False to construct the energy function inside
            jitted code.

    Returns:
        ``(neighbor_fn, energy_fn)`` if ``initialize_neighbor_list`` is True,
        otherwise only ``energy_fn(pos, neighbor, **dynamic_kwargs)``.
    """
    def energy_fn(pos, neighbor, **dynamic_kwargs):
        displacement_fn = partial(displacement, **dynamic_kwargs)
        # Neighbor indices equal to N mark padded (empty) slots.
        neighbor_mask = neighbor.idx != pos.shape[0]
        neighbor_pos = pos[neighbor.idx]
        neighbor_displacements = space.map_neighbor(displacement_fn)(
            pos, neighbor_pos)
        neighbor_distances = space.distance(neighbor_displacements)
        return stillinger_weber_energy(
            neighbor_distances, neighbor_displacements, neighbor_mask,
            a, b, p, lam, epsilon, gamma, sigma, cutoff, three_body_strength)

    if not initialize_neighbor_list:
        return energy_fn

    assert box_size is not None
    neighbor_fn = partition.neighbor_list(
        displacement, box_size, cutoff, dr_threshold,
        capacity_multiplier=capacity_multiplier)
    return neighbor_fn, energy_fn
def harmonic_angle(displacement_or_metric: DisplacementOrMetricFn,
                   angle_idxs: Array,
                   eq_mean: Array = None,
                   eq_variance: Array = None,
                   kbt: [float, Array] = None,
                   th0: Array = None,
                   kth: Array = None,
                   ):
    """Harmonic angle interaction.

    The energy is a harmonic spring (``energy.simple_spring``) in the bond
    angle, evaluated in degrees. The parameters can be given directly via
    ``th0`` and ``kth``, or derived from angle statistics via
    ``th0 = eq_mean`` and ``kth = kbt / eq_variance``. In the second case,
    the variance of the angle determines the force constant.

    https://manual.gromacs.org/documentation/2019/reference-manual/functions/bonded-interactions.html

    Args:
        displacement_or_metric: Displacement function.
        angle_idxs: Indices of particles (i, j, k) building the angles.
        eq_mean: Equilibrium angle in degrees. Used if ``th0`` is None.
        eq_variance: Angle variance (in degrees^2, since angles are compared
            in degrees). Used together with ``kbt`` if ``kth`` is None.
        kbt: Thermal energy kbT. Used together with ``eq_variance`` if
            ``kth`` is None.
        th0: Equilibrium angle in degrees. Takes precedence over ``eq_mean``.
        kth: Force constant (energy / degree^2). Takes precedence over
            ``kbt / eq_variance``.

    Returns:
        Harmonic angle potential energy function ``energy_fn(pos, **kwargs)``
        returning the summed energy of all angles.

    Raises:
        ValueError: If neither ``th0`` nor ``eq_mean`` is given, or if
            ``kth`` is not given and ``kbt`` or ``eq_variance`` is missing.
    """
    if th0 is None:
        if eq_mean is None:
            raise ValueError('Either th0 or eq_mean must be provided.')
        th0 = eq_mean
    if kth is None:
        if kbt is None or eq_variance is None:
            raise ValueError(
                'Either kth or both kbt and eq_variance must be provided.')
        kbt = jnp.array(kbt, dtype=f32)
        kth = kbt / eq_variance

    # Mask of ones: every row of angle_idxs is used as a valid angle.
    angle_mask = jnp.ones([angle_idxs.shape[0], 1])
    harmonic_fn = partial(energy.simple_spring, length=th0, epsilon=kth)

    def energy_fn(pos, **unused_kwargs):
        angles = sparse_graph.angle_triplets(pos, displacement_or_metric,
                                             angle_idxs, angle_mask)
        # th0 / kth are expressed in degrees, so convert the angles first.
        return jnp.sum(harmonic_fn(jnp.rad2deg(angles)))

    return energy_fn
def dihedral_energy(angle, phase_angle: Array, force_constant: Array,
                    n: [int, Array]):
    """Energy of periodic dihedral angles, summed over all given angles.

    Evaluates ``|k| + k * cos(n * angle - phase_angle)`` with
    ``k = force_constant``. For ``k > 0`` this is the GROMACS proper
    dihedral form ``k * (1 + cos(n * angle - phase_angle))``.

    https://manual.gromacs.org/documentation/2019/reference-manual/functions/bonded-interactions.html

    Args:
        angle: Dihedral angle(s) in radians.
        phase_angle: Phase shift angle(s) in radians.
        force_constant: Force constant(s); may be negative (see below).
        n: Multiplicity.

    Returns:
        The summed dihedral energy.
    """
    # A negative force constant gives |k| * (1 - cos(...)), which equals a
    # phase shift of 180 degrees, so the sign of k can replace varying the
    # phase angle between 0 and 180 degrees.
    k = force_constant
    per_angle = jnp.abs(k) + k * jnp.cos(n * angle - phase_angle)
    return jnp.sum(per_angle)
def periodic_dihedral(displacement_or_metric: DisplacementOrMetricFn,
                      dihedral_idxs: Array,
                      phase_angle: Array,
                      force_constant: Array,
                      multiplicity: [float, Array]):
    """Periodic dihedral angle interaction.

    https://manual.gromacs.org/documentation/2019/reference-manual/functions/bonded-interactions.html

    Args:
        displacement_or_metric: Displacement function.
        dihedral_idxs: Indices of particles (i, j, k, l) building the
            dihedrals.
        phase_angle: Dihedral phase angle(s) in degrees. Either one value
            per dihedral or a scalar shared by all dihedrals.
        force_constant: Force constant(s), per dihedral or a shared scalar.
            A negative value acts like a phase shift of 180 degrees.
        multiplicity: Dihedral multiplicity, per dihedral or a shared
            scalar.

    Returns:
        Periodic dihedral potential energy function
        ``energy_fn(pos, **kwargs)`` returning the total dihedral energy.
    """
    multiplicity = jnp.array(multiplicity, dtype=f32)
    phase_angle = jnp.deg2rad(phase_angle)
    force_constant = jnp.asarray(force_constant)

    def energy_fn(pos, **unused_kwargs):
        dihedral_angles = custom_quantity.dihedral_displacement(
            pos, displacement_or_metric, dihedral_idxs, degrees=False)
        n_dihedrals = dihedral_angles.shape[0]

        def per_dihedral(param):
            # vmap needs a leading axis; share scalar parameters across all
            # dihedrals and leave per-dihedral arrays untouched.
            if jnp.ndim(param) == 0:
                return jnp.broadcast_to(param, (n_dihedrals,))
            return param

        per_angle_u = vmap(dihedral_energy)(
            dihedral_angles, per_dihedral(phase_angle),
            per_dihedral(force_constant), per_dihedral(multiplicity))
        return jnp.sum(per_angle_u)

    return energy_fn
def truncated_lennard_jones(dr: Array,
                            sigma: Array = 1.,
                            epsilon: Array = 1.,
                            exp: Array = 12.,
                            **unused_dynamic_kwargs) -> Array:
    """Lennard-Jones interaction truncated and shifted at its minimum.

    Uses the generalized form
    ``4 * epsilon * ((sigma / r)**exp - (sigma / r)**(exp / 2)) + epsilon``,
    with ``r`` clamped to the location of the minimum
    ``r_min = 2**(2 / exp) * sigma``. The result is purely repulsive and
    zero for ``r >= r_min``.

    Args:
        dr: An ndarray of pairwise distances between particles.
        sigma: Repulsion length scale.
        epsilon: Interaction energy scale.
        exp: Exponent specifying interaction stiffness.

    Returns:
        Array of energies.
    """
    del unused_dynamic_kwargs
    # Zero distances (e.g. masked pairs) would divide by zero; move them
    # far away, where the clamping below yields zero energy.
    safe_dr = jnp.where(dr > 1.e-7, dr, 1.e7)
    r_min = 2.0 ** (2. / exp) * sigma
    clamped_dr = jnp.where(safe_dr > r_min, r_min, safe_dr)

    ratio = sigma / clamped_dr
    repulsive = ratio ** exp
    attractive = ratio ** (exp / 2.)
    # The +epsilon shift makes the energy vanish exactly at r_min.
    return 4 * epsilon * (repulsive - attractive) + epsilon
def truncated_lennard_jones_neighborlist(
        displacement_or_metric: DisplacementOrMetricFn,
        box_size: Box = None,
        species: Array = None,
        sigma: Array = 1.0,
        epsilon: Array = 1.0,
        exp: [int, Array] = 12.,
        dr_threshold: float = 0.2,
        per_particle: bool = False,
        capacity_multiplier: float = 1.25,
        initialize_neighbor_list: bool = True,
        r_cutoff: Array = None):
    """Truncated and shifted Lennard-Jones energy using a neighbor list.

    Convenience wrapper around ``truncated_lennard_jones`` that provides the
    option not to initialize the neighbor list. This is useful if the energy
    function needs to be initialized within a jitted function.

    Args:
        displacement_or_metric: Displacement or metric function.
        box_size: Box used to build the neighbor list. Required if
            ``initialize_neighbor_list`` is True.
        species: Optional species of the particles.
        sigma: Repulsion length scale (scalar, array or tuple parameter).
        epsilon: Interaction energy scale (scalar, array or tuple parameter).
        exp: Exponent specifying interaction stiffness.
        dr_threshold: Neighbor-list skin / update threshold.
        per_particle: If True, return per-particle energies.
        capacity_multiplier: Neighbor-list capacity multiplier.
        initialize_neighbor_list: If True, also build a neighbor list
            function.
        r_cutoff: Neighbor-list cutoff. Defaults to the largest truncation
            radius ``2**(2 / exp) * sigma`` over all given sigma and exp
            values, beyond which the potential is zero.

    Returns:
        ``(neighbor_fn, energy_fn)`` if ``initialize_neighbor_list`` is True,
        otherwise only ``energy_fn``.
    """
    if isinstance(sigma, tuple):
        sigma = (sigma[0], jnp.array(sigma[1], f32))
    if isinstance(epsilon, tuple):
        epsilon = (epsilon[0], jnp.array(epsilon[1], f32))

    energy_fn = smap.pair_neighbor_list(
        truncated_lennard_jones,
        space.canonicalize_displacement_or_metric(displacement_or_metric),
        species=species,
        sigma=sigma,
        epsilon=epsilon,
        exp=exp,
        reduce_axis=(1,) if per_particle else None)

    if initialize_neighbor_list:
        assert box_size is not None
        if r_cutoff is None:
            # Previously an undefined name. The potential vanishes beyond
            # r_min = 2**(2 / exp) * sigma, so the largest possible r_min is
            # a sufficient neighbor-list cutoff. For tuple parameters the
            # numeric values are assumed to live in sigma[1] (as converted
            # above).
            sigma_values = sigma[1] if isinstance(sigma, tuple) else sigma
            r_cutoff = (jnp.max(2.0 ** (2.0 / jnp.asarray(exp, f32)))
                        * jnp.max(jnp.asarray(sigma_values, f32)))
        neighbor_fn = partition.neighbor_list(
            displacement_or_metric, box_size, r_cutoff, dr_threshold,
            capacity_multiplier=capacity_multiplier
        )
        return neighbor_fn, energy_fn
    return energy_fn
def generic_repulsion(dr: Array,
                      sigma: Array = 1.,
                      epsilon: Array = 1.,
                      exp: Array = 12.,
                      **unused_dynamic_kwargs) -> Array:
    """Soft-sphere repulsion ``U(r) = epsilon * (sigma / r)**exp``.

    Args:
        dr: An ndarray of pairwise distances between particles.
        sigma: Repulsion length scale.
        epsilon: Interaction energy scale.
        exp: Exponent specifying interaction stiffness.

    Returns:
        Array of energies.
    """
    del unused_dynamic_kwargs
    # Zero distances (e.g. masked pairs) would divide by zero; move them far
    # away so that their energy becomes vanishingly small.
    safe_dr = jnp.where(dr > 1.e-7, dr, 1.e7)
    return epsilon * (sigma / safe_dr) ** exp
def generic_repulsion_pair(
        displacement_or_metric: DisplacementOrMetricFn,
        species: Array = None,
        sigma: Array = 1.0,
        epsilon: Array = 1.0,
        exp: Array = 12.,
        r_onset: Array = 2.0,
        r_cutoff: Array = 2.5,
        per_particle: bool = False):
    """Builds a smoothly cut-off soft-sphere repulsion over all pairs.

    The ``generic_repulsion`` energy is multiplied by a smooth switching
    function between ``r_onset`` and ``r_cutoff``.

    Args:
        displacement_or_metric: Displacement or metric function.
        species: Optional species of the particles.
        sigma: Repulsion length scale.
        epsilon: Interaction energy scale.
        exp: Exponent specifying interaction stiffness.
        r_onset: Distance at which the switching function starts.
        r_cutoff: Distance beyond which the energy is zero.
        per_particle: If True, return per-particle energies.

    Returns:
        Pairwise energy function.
    """
    sigma, epsilon, exp, r_onset, r_cutoff = (
        jnp.array(value, dtype=f32)
        for value in (sigma, epsilon, exp, r_onset, r_cutoff))

    smoothed_repulsion = energy.multiplicative_isotropic_cutoff(
        generic_repulsion, r_onset, r_cutoff)
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
    reduce_axis = (1,) if per_particle else None
    return smap.pair(smoothed_repulsion, metric, species=species,
                     sigma=sigma, epsilon=epsilon, exp=exp,
                     reduce_axis=reduce_axis)
def generic_repulsion_neighborlist(
        displacement_or_metric: DisplacementOrMetricFn,
        box_size: Box = None,
        species: Array = None,
        sigma: Array = 1.0,
        epsilon: Array = 1.0,
        exp: [int, Array] = 12.,
        r_onset: Array = 0.9,
        r_cutoff: Array = 1.,
        dr_threshold: float = 0.2,
        per_particle: bool = False,
        capacity_multiplier: float = 1.25,
        initialize_neighbor_list: bool = True):
    """Builds a smoothly cut-off soft-sphere repulsion with a neighbor list.

    Provides the option not to initialize the neighbor list. This is useful
    if the energy function needs to be initialized within a jitted function.

    Args:
        displacement_or_metric: Displacement or metric function.
        box_size: Box used to build the neighbor list. Required if
            ``initialize_neighbor_list`` is True.
        species: Optional species of the particles.
        sigma: Repulsion length scale (scalar, array or tuple parameter).
        epsilon: Interaction energy scale (scalar, array or tuple parameter).
        exp: Exponent specifying interaction stiffness.
        r_onset: Distance at which the switching function starts.
        r_cutoff: Interaction and neighbor-list cutoff.
        dr_threshold: Neighbor-list skin / update threshold.
        per_particle: If True, return per-particle energies.
        capacity_multiplier: Neighbor-list capacity multiplier.
        initialize_neighbor_list: If True, also build a neighbor list
            function.

    Returns:
        ``(neighbor_fn, energy_fn)`` if ``initialize_neighbor_list`` is True,
        otherwise only ``energy_fn``.
    """
    def _as_f32_param(value):
        # Tuple parameters keep their first entry; only the numeric payload
        # in the second entry is converted.
        if isinstance(value, tuple):
            return value[0], jnp.array(value[1], f32)
        return value

    sigma = _as_f32_param(sigma)
    epsilon = _as_f32_param(epsilon)
    exp, r_onset, r_cutoff = (
        jnp.array(value, dtype=f32) for value in (exp, r_onset, r_cutoff))

    smoothed_repulsion = energy.multiplicative_isotropic_cutoff(
        generic_repulsion, r_onset, r_cutoff)
    energy_fn = smap.pair_neighbor_list(
        smoothed_repulsion,
        space.canonicalize_displacement_or_metric(displacement_or_metric),
        species=species, sigma=sigma, epsilon=epsilon, exp=exp,
        reduce_axis=(1,) if per_particle else None)

    if not initialize_neighbor_list:
        return energy_fn

    assert box_size is not None
    neighbor_fn = partition.neighbor_list(
        displacement_or_metric, box_size, r_cutoff, dr_threshold,
        capacity_multiplier=capacity_multiplier)
    return neighbor_fn, energy_fn
def generic_repulsion_nonbond(displacement_or_metric: DisplacementOrMetricFn,
                              pair_idxs: Array,
                              sigma: Array = 1.,
                              epsilon: Array = 1.,
                              exp: Array = 12.) -> Callable[[Array], Array]:
    """Soft-sphere repulsion evaluated only for explicitly listed pairs.

    Computes the repulsive part of a Lennard-Jones-like interaction,
    ``epsilon * (sigma / r)**exp``, for the pairs given in ``pair_idxs``.

    Args:
        displacement_or_metric: Displacement function.
        pair_idxs: Set of pair indices (i, j) defining repulsion pairs.
        sigma: Repulsion length scale.
        epsilon: Interaction energy scale.
        exp: Repulsion exponent.

    Returns:
        Pairwise nonbonded repulsion potential energy function.
    """
    sigma_arr = jnp.array(sigma, f32)
    epsilon_arr = jnp.array(epsilon, f32)
    exp_arr = jnp.array(exp, dtype=f32)
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
    return smap.bond(generic_repulsion, metric, pair_idxs,
                     ignore_unused_parameters=True,
                     sigma=sigma_arr, epsilon=epsilon_arr, exp=exp_arr)
def lorentz_berthelot(idxs, species, sigma_dict, epsilon_dict):
    """Applies the Lorentz-Berthelot mixing rules to pairs of particles.

    The Lorentz-Berthelot rules [#wikipedia]_ compute the pair parameters
    $\\sigma_{ij}$ and $\\epsilon_{ij}$ from per-species values given in
    dictionaries.

    .. math::
        \\sigma_{ij} = \\frac{\\sigma_{ii} + \\sigma_{jj}}{2}

        \\epsilon_{ij} = \\sqrt{\\epsilon_{ii} \\epsilon_{jj}}

    Args:
        idxs: Integer array of particle index pairs, shape ``(n_pairs, 2)``.
        species: Integer array with the species of every particle.
        sigma_dict: Maps each species to its ``sigma_ii``.
        epsilon_dict: Maps each species to its ``epsilon_ii``.

    Returns:
        Tuple ``(sigma, epsilon)`` of arrays with one entry per pair.

    References:
        .. [#wikipedia] https://en.wikipedia.org/wiki/Combining_rules
    """
    pairs = species[idxs]
    unique_species, inverse = jnp.unique(pairs, return_inverse=True)
    # JAX arrays are unhashable, so convert the unique species to Python
    # scalars before using them as dictionary keys.
    unique_species = unique_species.tolist()
    sigma_table = jnp.array([sigma_dict[s] for s in unique_species])
    epsilon_table = jnp.array([epsilon_dict[s] for s in unique_species])

    # reshape handles both flat and input-shaped inverse index conventions.
    sigma = sigma_table[inverse].reshape(pairs.shape)
    sigma = jnp.sum(sigma, axis=1) * 0.5
    epsilon = epsilon_table[inverse].reshape(pairs.shape)
    epsilon = jnp.sqrt(jnp.prod(epsilon, axis=1))
    return sigma, epsilon
def lennard_jones_nonbond(displacement_or_metric: DisplacementOrMetricFn,
                          pair_idxs: Array,
                          sigma: Array = 1.,
                          epsilon: Array = 1.) -> Callable[[Array], Array]:
    """Lennard-Jones energy evaluated only for explicitly listed pairs.

    Args:
        displacement_or_metric: Displacement function.
        pair_idxs: Set of pair indices (i, j) defining the interacting
            nonbonded pairs.
        sigma: Lennard-Jones length scale.
        epsilon: Lennard-Jones energy scale.

    Returns:
        Pairwise nonbonded Lennard-Jones potential energy function.
    """
    sigma_arr = jnp.array(sigma, f32)
    epsilon_arr = jnp.array(epsilon, f32)
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
    return smap.bond(energy.lennard_jones, metric, pair_idxs,
                     ignore_unused_parameters=True,
                     sigma=sigma_arr, epsilon=epsilon_arr)
def customn_lennard_jones_neighbor_list(
        displacement_or_metric: DisplacementOrMetricFn,
        box_size: Box,
        species: Array = None,
        sigma: Array = 1.0,
        epsilon: Array = 1.0,
        r_onset: float = 2.0,
        r_cutoff: float = 2.5,
        dr_threshold: float = 0.5,
        per_particle: bool = False,
        capacity_multiplier: float = 1.25,
        initialize_neighbor_list: bool = True,
        fractional: bool = True,
        disable_cell_list: bool = False):
    """Lennard-Jones energy using a neighbor list.

    Differs from the jax_md implementation in how the cutoff is applied, so
    that the cutoff is decoupled from the energy parameters. The neighbor
    list can optionally be skipped, which allows the energy function to be
    built inside jitted code, e.g. for varying sigma and epsilon.

    Args:
        displacement_or_metric: Displacement or metric function.
        box_size: Box used to build the neighbor list.
        species: Optional species of the particles.
        sigma: Lennard-Jones length scale (scalar, array or tuple parameter).
        epsilon: Lennard-Jones energy scale (scalar, array or tuple
            parameter).
        r_onset: Distance at which the switching function starts.
        r_cutoff: Interaction and neighbor-list cutoff.
        dr_threshold: Neighbor-list skin / update threshold.
        per_particle: If True, return per-particle energies.
        capacity_multiplier: Neighbor-list capacity multiplier.
        initialize_neighbor_list: If True, also build a neighbor list
            function.
        fractional: Whether positions are in fractional coordinates.
        disable_cell_list: Whether to build the neighbor list without a
            cell list.

    Returns:
        ``(neighbor_fn, energy_fn)`` if ``initialize_neighbor_list`` is True,
        otherwise only ``energy_fn``.
    """
    sigma, epsilon = (
        (param[0], jnp.array(param[1], f32))
        if isinstance(param, tuple) else param
        for param in (sigma, epsilon))
    r_onset, r_cutoff, dr_threshold = (
        jnp.array(value, f32) for value in (r_onset, r_cutoff, dr_threshold))

    smoothed_lj = energy.multiplicative_isotropic_cutoff(
        energy.lennard_jones, r_onset, r_cutoff)
    energy_fn = smap.pair_neighbor_list(
        smoothed_lj,
        space.canonicalize_displacement_or_metric(displacement_or_metric),
        species=species, sigma=sigma, epsilon=epsilon,
        reduce_axis=(1,) if per_particle else None)

    if not initialize_neighbor_list:
        return energy_fn

    neighbor_fn = partition.neighbor_list(
        displacement_or_metric, box_size, r_cutoff, dr_threshold,
        capacity_multiplier=capacity_multiplier,
        fractional_coordinates=fractional,
        disable_cell_list=disable_cell_list)
    return neighbor_fn, energy_fn
def tabulated(dr: Array, spline: Callable[[Array], Array], **unused_kwargs
              ) -> Array:
    """Tabulated radial potential between particles.

    Args:
        dr: An ndarray of pairwise distances between particles.
        spline: A function computing the (interpolated) energy at a given
            pairwise distance.

    Returns:
        Array of energies.
    """
    energies = spline(dr)
    return energies
def tabulated_pair(displacement_or_metric: DisplacementOrMetricFn,
                   x_vals: Array,
                   y_vals: Array,
                   degree: int = 3,
                   r_onset: Array = 0.9,
                   r_cutoff: Array = 1.,
                   species: Array = None,
                   per_particle: bool = False) -> Callable[[Array], Array]:
    """Builds a tabulated pair energy function over all particle pairs.

    The table ``(x_vals, y_vals)`` is interpolated with a monotonic spline
    and multiplied by a smooth switching function between ``r_onset`` and
    ``r_cutoff``.

    Args:
        displacement_or_metric: Displacement or metric function.
        x_vals: Tabulated distances.
        y_vals: Tabulated energies at ``x_vals``.
        degree: Currently unused; kept for backward compatibility.
        r_onset: Distance at which the switching function starts.
        r_cutoff: Distance beyond which the energy is zero.
        species: Optional species of the particles.
        per_particle: If True, return per-particle energies.

    Returns:
        Pairwise energy function.
    """
    x_vals, y_vals, r_onset, r_cutoff = (
        jnp.array(value, f32) for value in (x_vals, y_vals, r_onset, r_cutoff))

    spline = custom_interpolate.MonotonicInterpolate(x_vals, y_vals)
    pair_energy = energy.multiplicative_isotropic_cutoff(
        partial(tabulated, spline=spline), r_onset, r_cutoff)
    metric = space.canonicalize_displacement_or_metric(displacement_or_metric)
    return smap.pair(pair_energy, metric, species=species,
                     reduce_axis=(1,) if per_particle else None)
def tabulated_neighbor_list(displacement_or_metric: DisplacementOrMetricFn,
                            x_vals: Array,
                            y_vals: Array,
                            box_size: Box,
                            degree: int = 3,
                            r_onset: Array = 0.9,
                            r_cutoff: Array = 1.,
                            dr_threshold: Array = 0.2,
                            species: Array = None,
                            capacity_multiplier: float = 1.25,
                            initialize_neighbor_list: bool = True,
                            per_particle: bool = False,
                            fractional=True):
    """Builds a tabulated pair energy function that uses a neighbor list.

    Provides the option not to initialize the neighbor list. This is useful
    if the energy function needs to be initialized within a jitted function.

    Args:
        displacement_or_metric: Displacement or metric function.
        x_vals: Tabulated distances.
        y_vals: Tabulated energies at ``x_vals``.
        box_size: Box used to build the neighbor list.
        degree: Currently unused; kept for backward compatibility.
        r_onset: Distance at which the switching function starts.
        r_cutoff: Interaction and neighbor-list cutoff.
        dr_threshold: Neighbor-list skin / update threshold.
        species: Optional species of the particles.
        capacity_multiplier: Neighbor-list capacity multiplier.
        initialize_neighbor_list: If True, also build a neighbor list
            function.
        per_particle: If True, return per-particle energies.
        fractional: Whether positions are in fractional coordinates.

    Returns:
        ``(neighbor_fn, energy_fn)`` if ``initialize_neighbor_list`` is True,
        otherwise only ``energy_fn``.
    """
    (x_vals, y_vals, box_size, r_onset, r_cutoff,
     dr_threshold) = (jnp.array(value, f32) for value in (
         x_vals, y_vals, box_size, r_onset, r_cutoff, dr_threshold))

    # The spline parameters cannot be passed via kwargs because only
    # per-particle parameters are supported, so bind the spline here.
    spline = custom_interpolate.MonotonicInterpolate(x_vals, y_vals)
    pair_energy = energy.multiplicative_isotropic_cutoff(
        partial(tabulated, spline=spline), r_onset, r_cutoff)
    energy_fn = smap.pair_neighbor_list(
        pair_energy,
        space.canonicalize_displacement_or_metric(displacement_or_metric),
        species=species,
        reduce_axis=(1,) if per_particle else None)

    if not initialize_neighbor_list:
        return energy_fn

    neighbor_fn = partition.neighbor_list(
        displacement_or_metric, box_size, r_cutoff, dr_threshold,
        capacity_multiplier=capacity_multiplier,
        fractional_coordinates=fractional)
    return neighbor_fn, energy_fn