Source code for chemtrain.quantity.observables

# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Molecular dynamics observable functions acting on trajectories rather than
single snapshots.

Builds on the TrajectoryState object defined in traj_util.py.
"""
import functools

from jax import vmap, numpy as jnp, lax
from jax_md import simulate, quantity

from chemtrain.quantity import constants

from typing import Dict
from jax.typing import ArrayLike
from chemtrain.typing import TrajFn


[docs] def dynamic_statepoint(keys_with_defaults: Dict[str, ArrayLike] = None): """Initializes a decorator enabeling a dynamically provided statepoint definition. Args: keys_with_defaults: List of keys that can be provided dynamically with default values. Returns: Returns a decorator to filter the dynamically provided arguments. """ if keys_with_defaults is None: keys_with_defaults = {} def decorator(obs_fn): @functools.wraps(obs_fn) def wrapper(quantity_trajs, weights=None, **state_kwargs): if len(keys_with_defaults) == 0: return obs_fn(quantity_trajs, weights) state_dict = { key: state_kwargs.pop(key, default) for key, default in keys_with_defaults.items() } return obs_fn(quantity_trajs, state_dict, weights) return wrapper return decorator
[docs] def init_identity_fn(quantity_key) -> TrajFn: """Initialize a wrapper that returns the snapshots unchanged. Args: quantity_key: Key referencing the quantity in the computed snapshots. Returns: Returns the snapshots of the quantity. """ @dynamic_statepoint() def identity(quantity_trajs, weights=None): return quantity_trajs[quantity_key] return identity
[docs] def init_traj_mean_fn(quantity_key) -> TrajFn: """Initialize the computation of a (perturbed) ensemble average. This function builds the ``'traj_fn'`` of the DiffTRe ``target`` dict for the common case of target observables that simply consist of a trajectory-average of instantaneous quantities, such as RDF, ADF, pressure or density. .. math:: a = \\left\\langle w(\\mathbf r) a(\\mathbf r)\\right\\rangle_{\\tilde U},\\quad \\text{where} \\\\ w(\\mathbf r) = \\frac{e^{-\\beta (U(\\mathbf r) - \\tilde U(\\mathbf r))}}{\\left\\langle e^{-\\beta (U(\\mathbf r) - \\tilde U(\\mathbf r))}\\right\\rangle} Args: quantity_key: Key referencing the quantity in the computed snapshots. Returns: Returns a function to compute an ensemble average when given a dictionary of snapshots. Returns a perturbed ensemble average if weights are provided. """ @dynamic_statepoint() def traj_mean(quantity_trajs, weights=None): quantity_traj = quantity_trajs[quantity_key] return _traj_mean(quantity_traj, weights, linearized=False) return traj_mean
def _weighted_mean(quantity_traj, weights): if weights is not None: weights *= weights.size quantity_traj = (quantity_traj.T * weights).T return jnp.mean(quantity_traj, axis=0) def _linearized_mean(quantity_traj, weights): # Mask out zero weights mask = weights > 1e-30 weights = jnp.where(mask, weights, 1.0) n_nonzero = jnp.sum(mask) # Calculate the masked means quantity_mean = jnp.mean(quantity_traj, axis=0) correction = jnp.sum((quantity_traj.T * jnp.log(weights) * mask).T, axis=0) correction -= quantity_mean * jnp.sum(jnp.log(weights) * mask) return quantity_mean + correction / n_nonzero def _traj_mean(quantity_traj, weights=None, linearized=False, **kwargs): del kwargs if linearized: assert weights is not None, ( "The linearized trajectory average requires weights to be " "specified.") return _linearized_mean(quantity_traj, weights) else: return _weighted_mean(quantity_traj, weights)
[docs] def init_linear_traj_mean_fn(quantity_key) -> TrajFn: """Initializes the linear approximation of an observable average. This function approximates ensemble averages of instantaneous quantities for a perturbed ensemble based on a cumulant expansion [#imbalzano2021]_. .. [#imbalzano2021] Giulio Imbalzano, Yongbin Zhuang, Venkat Kapil, Kevin Rossi, Edgar A. Engel, Federico Grasselli, Michele Ceriotti; _Uncertainty estimation for molecular dynamics and sampling._ J. Chem. Phys. 21 February 2021; 154 (7): 074102. https://doi.org/10.1063/5.0036522 Args: quantity_key: Key referencing the quantity in the computed snapshots. Returns: Returns a function to approximate a perturbed ensemble average when given a dictionary of snapshots and weights. """ @dynamic_statepoint() def traj_mean(quantity_trajs, weights=None): quantity_traj = quantity_trajs[quantity_key] return _traj_mean(quantity_traj, weights, linearized=True) return traj_mean
[docs] def init_relative_entropy_traj_fn(kT: float, reference_key = 'ref_energy') -> TrajFn: r"""Initializes the computation of the relative entropy difference between the current canonical distribution and reference distribution. The relative entropy is given as .. math:: S_\text{rel} = -\int p(x)\log\frac{p(x)}{q(x)}dx. For two canonical distributions defined by the potentials U_p and U_q, this relative entropy computes as .. math:: S_\text{rel} = \beta\left\langle\left(U_p - U_q\right)\right\rangle_{U_p} - \log\left\langle e^{-\beta(U_q - U_p)}\right\rangle_{U_p}. Args: kT: Reference temperature. reference_key: Key corresponding to the reference energy quantity. Returns: Returns a function that computes the relative entropy given a dictionary of quantities along a trajectory. """ @dynamic_statepoint({'kT': kT}) def relative_entropy_traj_fn(quantity_trajs, state_dict, weights=None): beta = 1.0 / state_dict['kT'] energies = quantity_trajs['energy'] ref_energies = quantity_trajs[reference_key] if weights is None: weights = jnp.ones_like(energies) / energies.size energy_difference = jnp.sum(weights * (energies - ref_energies)) # Compute the free energy from the current to the reference potential # via the forward perturbation and reverse the sign exponents = -beta * (ref_energies - energies) max_exponent = jnp.max(exponents) exp = jnp.exp(exponents - max_exponent) # We use this method of free energy calculation only for the value but # not for the gradient. # The quality of the estimation via the perturbation formula depends # on the difference between the potentials, which is perhaps large. # Instead, the gradients of the free energy are quite easily computable # with respect to any other reference state. free_energy_difference = jnp.log(jnp.sum(weights * exp)) free_energy_difference += max_exponent free_energy_difference *= -1.0 / beta free_energy_difference = lax.stop_gradient(free_energy_difference) # By using stop_gradient, we can ensure the contribution to the gradient # without a contribution to the absolute value. free_energy_difference -= quantity_trajs['free_energy'] free_energy_difference += lax.stop_gradient( quantity_trajs['free_energy']) entropy_difference = energy_difference + free_energy_difference entropy_difference *= beta * constants.kb return entropy_difference return relative_entropy_traj_fn
def init_free_energy(kT, pressure=None, **kwargs): # TODO: Documentation assert pressure is None, ( "The free energy calculation is not yet implemented for the NPT ensemble." ) @dynamic_statepoint({'kT': kT}) def compute_fn(quantity_trajs, state_dict, weights=None): # Free energy already computed logZ = quantity_trajs['free_energy'] free_energy = -state_dict['kT'] * logZ return free_energy return compute_fn
[docs] def init_heat_capacity_nvt(kT, dof, **kwargs): """Returns the specific heat capacity of a system in the NPT ensemble via the fluctuation formula. The specific heat capacity is returned in units of the Boltzmann constant. Args: kT: Reference temperature dof: Number of degrees of freedom in the system kwargs: Additional arguments determining the calculation of trajectory averages. Returns: Returns a function to compute the (perturbed) specific heat capacity based on potential energy snapshots. References: `<https://journals-aps-org.eaccess.tum.edu/pre/pdf/10.1103/PhysRevE.99.012139>` """ @dynamic_statepoint({'kT': kT, 'dof': dof}) def specific_heat_capacity_fn(quantity_trajs, state_dict, weights=None): energy = quantity_trajs['energy'] if weights is None: weights = jnp.ones_like(energy) / energy.size fluctuation = _traj_mean(energy ** 2, weights, **kwargs) fluctuation -= _traj_mean(energy, weights, **kwargs) ** 2 c_v = (0.5 * state_dict['dof'] + fluctuation / state_dict['kT'] ** 2) return c_v return specific_heat_capacity_fn
[docs] def init_heat_capacity_npt(kT, dof, pressure, **kwargs): """Initializes the isobaric heat capacity of a system in the NPT ensemble. The heat capacity of the system depends on the fluctuation of the conformational enthalpy [#stroeker2021]_. Args: kT: Thermostat temperature dof: Number of degrees of freedom in the system kwargs: Extra arguments to the function computing the ensemble averages Returns: Returns a function to compute the (perturbed) specific heat capacity based on potential energy snapshots and volume snapshots. References: .. [#stroeker2021] P. Ströker, R. Hellmann, und K. Meier, „Systematic formulation of thermodynamic properties in the N p T ensemble“, Phys. Rev. E, Bd. 103, Nr. 2, S. 023305, Feb. 2021, doi: 10.1103/PhysRevE.103.023305. """ @dynamic_statepoint({'kT': kT, 'pressure': pressure, 'dof': dof}) def cp_fn(quantity_traj, state_dict, weights=None): volume_traj = quantity_traj['volume'] energy_traj = quantity_traj['energy'] enthalpy = energy_traj + volume_traj * state_dict['pressure'] fluctuation = _traj_mean(enthalpy ** 2, weights, **kwargs) fluctuation -= _traj_mean(enthalpy, weights, **kwargs) ** 2 cp = (0.5 * state_dict['dof'] + fluctuation / state_dict['kT'] ** 2) return cp return cp_fn
[docs] def init_born_stiffness_tensor(reference_box, dof, kT, elastic_constant_function=None): """Computes the elastic stiffness tensor via the stress fluctuation method in the NVT ensemble. Args: reference_box: Default box used to compute the volume dof: Default number of degrees of freedom kT: Default temperature elastic_constant_function: Function to compute the elastic constants from the stiffness tensor. If None, the stiffness tensor is returned. Returns: born_stiffness: Either the stiffness tenosr or the elastic constants if the elastic_constant_function is provided. """ def stiffness_tensor_fn(mean_born, mean_sigma, mean_sig_ij_sig_kl, state_dict): """Computes the stiffness tensor given ensemble averages of C^B_ijkl, sigma^B_ij and sigma^B_ij * sigma^B_kl. """ spatial_dim = mean_born.shape[0] volume = quantity.volume(spatial_dim, state_dict['box']) delta_ij = jnp.eye(spatial_dim) delta_ik_delta_jl = jnp.einsum('ik,jl->ijkl', delta_ij, delta_ij) delta_il_delta_jk = jnp.einsum('il,jk->ijkl', delta_ij, delta_ij) sigma_prod = jnp.einsum('ij,kl->ijkl', mean_sigma, mean_sigma) # Note: maybe use real kinetic energy of trajectory rather than target # kbt? kinetic_term = delta_ik_delta_jl + delta_il_delta_jk kinetic_term *= state_dict['dof'] * state_dict['kT'] kinetic_term /= spatial_dim * volume sigma_term = mean_sig_ij_sig_kl - sigma_prod sigma_term *= volume / state_dict['kT'] return mean_born - sigma_term + kinetic_term @dynamic_statepoint({'kT': kT, 'dof': dof, 'box': reference_box}) def born_stiffness(quantity_trajs, state_dict, weights=None): sigma = quantity_trajs['born_stress'] born_stress_prods = jnp.einsum('nij,nkl->nijkl', sigma, sigma) stress_product_born_mean = _weighted_mean(born_stress_prods, weights) born_stress_tensor_mean = _weighted_mean( quantity_trajs['born_stress'], weights) born_stiffness_mean = _weighted_mean( quantity_trajs['born_stiffness'], weights) stiffness_tensor = stiffness_tensor_fn( born_stiffness_mean, born_stress_tensor_mean, stress_product_born_mean, state_dict) if elastic_constant_function is None: return stiffness_tensor else: return elastic_constant_function(stiffness_tensor) return born_stiffness
[docs] def stiffness_tensor_components_cubic_crystal(stiffness_tensor): """Computes the 3 independent elastic stiffness components of a cubic crystal from the whole stiffness tensor. The number of independent components in a general stiffness tensor is 21 for isotropic pressure. For a cubic crystal, these 21 parameters only take 3 distinct values: c11, c12 and c44. We compute these values from averages using all 21 components for variance reduction purposes. Args: stiffness_tensor: The full (3, 3, 3, 3) elastic stiffness tensor Returns: A (3,) ndarray containing (c11, c12, c44) """ # TODO likely there exists a better formulation via Einstein notation c = stiffness_tensor c11 = (c[0, 0, 0, 0] + c[1, 1, 1, 1] + c[2, 2, 2, 2]) / 3. c12 = (c[0, 0, 1, 1] + c[1, 1, 0, 0] + c[0, 0, 2, 2] + c[2, 2, 0, 0] + c[1, 1, 2, 2] + c[2, 2, 1, 1]) / 6. c44 = (c[0, 1, 0, 1] + c[1, 0, 0, 1] + c[0, 1, 1, 0] + c[1, 0, 1, 0] + c[0, 2, 0, 2] + c[2, 0, 0, 2] + c[0, 2, 2, 0] + c[2, 0, 2, 0] + c[2, 1, 2, 1] + c[1, 2, 2, 1] + c[2, 1, 1, 2] + c[1, 2, 1, 2]) / 12. return jnp.array([c11, c12, c44])
[docs] def stiffness_tensor_components_hexagonal_crystal(stiffness_tensor): """ Computes 5 independent elastic stiff components of a hexagonal crystal from the whole stiffness tensor For a hexagonal structure the 5 components are c11, c33, c44, c12, c13 Args: stiffness_tensor: The full(3,3,3,3) elastic stiffness tensor Returns: A (5,) ndarray containing (c11, c33, c44, c12, c13) """ c = stiffness_tensor c11 = (c[0, 0, 0, 0] + c[1, 1, 1, 1]) / 2. c33 = c[2, 2, 2, 2] c44 = (c[0, 2, 0, 2] + c[2, 0, 0, 2] + c[0, 2, 2, 0] + c[2, 0, 2, 0] + c[2, 1, 2, 1] + c[1, 2, 2, 1] + c[2, 1, 1, 2] + c[1, 2, 1, 2]) / 8. c12 = (c[0, 0, 1, 1] + c[1, 1, 0, 0]) / 2. c13 = (c[0, 0, 2, 2] + c[2, 2, 0, 0] + c[1, 1, 2, 2] + c[2, 2, 1, 1]) / 4. return jnp.array([c11, c33, c44, c12, c13])
# TODO: Update the remaining compute functions (see reference above for npt # formulations) def volumes(traj_state): """Returns array of volumes for all boxes in a NPT trajectory. Args: traj_state: TrajectoryState containing the NPT trajectory """ dim = traj_state.sim_state.sim_state.position.shape[-1] boxes = vmap(simulate.npt_box)(traj_state.trajectory) return vmap(quantity.volume, (None, 0))(dim, boxes) def isothermal_compressibility_npt(volume_traj, kbt): """Returns isothermal compressibility of a system in the NPT ensemble via the fluctuation formula. Args: volume_traj: Trajectory of box volumes of a NPT simulation, e.g. computed via traj_quantity.volumes. kbt: Temperature * Boltzmann constant """ mean_volume = jnp.mean(volume_traj) kappa = (jnp.mean(volume_traj**2) - mean_volume**2) / (mean_volume * kbt) return kappa def thermal_expansion_coefficient_npt(volume_traj, energy_traj, temperature, kbt, pressure): """Returns the thermal expansion coefficient of a system in the NPT ensemble via the fluctuation formula. Args: volume_traj: Trajectory of box volumes of a NPT simulation, e.g. computed via traj_quantity.volumes. energy_traj: Trajectory of potential energies temperature: Thermostat temperature kbt: Temperature * Boltzmann constant pressure: Barostat pressure """ mean_volume = jnp.mean(volume_traj) fluctuation = (jnp.mean(volume_traj * energy_traj) - jnp.mean(energy_traj) * mean_volume + pressure * (jnp.mean(volume_traj**2) - mean_volume**2)) alpha = fluctuation / temperature / kbt / mean_volume return alpha