# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions build around the Differentiable Trajectory Reweighting (DiffTRe)
algorithm [#Thaler2021]_. The DiffTRe algorithm builds
on umbrella sampling to efficiently compute gradients of ensemble observables.
Chemtrain implements umbrella sampling approaches in the module
:mod:`chemtrain.trajectory.reweighting`.
References:
.. [#Thaler2021] Thaler, S.; Zavadlav, J. Learning Neural Network
Potentials from Experimental Data via Differentiable Trajectory
Reweighting. Nat Commun **2021**, 12 (1), 6884.
https://doi.org/10.1038/s41467-021-27241-4.
.. [#Shell2008] Shell, M. S. The Relative Entropy Is Fundamental to
Multiscale and Inverse Thermodynamic Problems. J. Chem. Phys. 2008,
129 (14), 144108. https://doi.org/10.1063/1.2992060.
"""
import functools
from typing import Dict, Any, Callable, Tuple
import numpy as onp
import jax
from jax import jit, numpy as jnp, lax
from jax.typing import ArrayLike
from jax_md_mod import custom_quantity
from chemtrain.learn import max_likelihood
from chemtrain.typing import TargetDict, EnergyFnTemplate, ComputeFn
from chemtrain.ensemble import reweighting, evaluation, sampling
from chemtrain import util
from chemtrain.typing import TrajFn
[docs]
def init_default_loss_fn(observables: Dict[str, TrajFn],
                         loss_fns: Dict[str, Callable]
                         ):
    r"""Initializes the default (MSE) loss function for DiffTRe.

    The default loss for the Differentiable Trajectory Reweighting (DiffTRe)
    algorithm [#Thaler2021]_ is the mean squared error (MSE) between an
    observable :math:`\mathcal A` and an (experimental) reference
    :math:`\hat a`

    .. math::

        \mathcal L(\theta) = \gamma \left(
        \hat a - \mathcal A(\mathbf w_N, \mathbf r_N) \right)^2.

    Some observables are simple ensemble averages of instantaneous quantities
    :math:`a`

    .. math::

        \mathcal A(\mathbf w, \mathbf r^n) =
        \sum_{n=1}^N w^{(n)} a\left(\mathbf r^{(n)}\right).

    However, some quantities, e.g., the heat capacity :math:`c_V`,
    relate to multiple ensemble averages or even multiple quantities.
    Therefore, each target specifies a ``traj_fn`` with access to all
    instantaneous quantities.

    For a list of implemented ensemble observables, refer to the module
    :mod:`chemtrain.quantity.observables`.

    Args:
        observables: Dictionary containing functions to compute ensemble
            observables from snapshots. Each function is called as
            ``obs_fn(quantity_trajs, weights=weights, **state_dict)``.
        loss_fns: Dictionary containing loss functions for the individual
            targets, called as ``loss_fn(prediction, target)``. Targets
            without an entry fall back to ``max_likelihood.mse_loss``.

    Returns:
        Returns a DiffTRe loss_fn. The loss function accepts a dictionary of
        snapshots for each sample, the weights for each sample, a dict
        defining properties of the statepoint, and the targets of the
        training. It returns the tuple ``(loss, predictions)``, where
        ``predictions`` holds the value of every observable.

    """
    def loss_fn(quantity_trajs, weights, state_dict, targets):
        # Every registered observable is evaluated, even if it is not part
        # of the targets, such that all predictions can be reported.
        predictions = {
            key: obs_fn(quantity_trajs, weights=weights, **state_dict)
            for key, obs_fn in observables.items()
        }

        # Weighted sum of the per-target losses. A target may carry an
        # optional weight 'gamma' (default 1.0).
        loss = 0.
        for target_key, target in targets.items():
            # Distinct name to avoid shadowing the enclosing ``loss_fn``.
            target_loss_fn = loss_fns.get(
                target_key, max_likelihood.mse_loss)
            gamma = target.get('gamma', 1.0)
            loss += gamma * target_loss_fn(
                predictions[target_key], target['target'])

        return loss, predictions

    return loss_fn
[docs]
def init_difftre_gradient_and_propagation(
        reweight_fns: Tuple[Callable, Callable, Callable],
        loss_fn,
        quantities: Dict[str, ComputeFn],
        energy_fn_template: EnergyFnTemplate,
        wrapped: bool = True,
        batched: bool = False,
):
    """Initializes the function to compute the DiffTRe loss and its gradients.

    The Differentiable Trajectory Reweighing (DiffTRe) algorithm
    [#Thaler2021]_ computes gradients of ensemble averages via a perturbation
    approach, initialized in
    :func:`chemtrain.trajectory.reweighting.init_pot_reweight_propagation_fns`.

    Args:
        reweight_fns: Tuple ``(weights_fn, propagate_fn, safe_propagate)`` of
            functions to perform and evaluate umbrella-sampling simulations.
        loss_fn: DiffTRe compatible loss function, e.g.,
            initialized via :func:`init_default_loss_fn`.
        quantities: Dictionary specifying how to compute instantaneous
            quantities from the simulator states. An ``'energy'`` entry is
            added to this dictionary.
        energy_fn_template: Template to initialize the energy function that
            is required to compute the weights.
        wrapped: If True, return a single function that propagates the
            trajectory and computes loss, gradient, and predictions.
            If False, return the separate loss, propagation, and weight
            functions instead.
        batched: Computes loss for multiple ensembles. Only supported
            together with ``wrapped=False``.

    Returns:
        If ``wrapped`` is True, returns a function to propagate the current
        trajectory state, compute the loss and gradient, and predict
        observations. Otherwise, returns the tuple
        ``(difftre_loss_fn, difftre_propagation, difftre_weights_fn)``.

    """
    weights_fn, propagate_fn, safe_propagate = reweight_fns

    # The potential energy is always tracked along the trajectory.
    # NOTE(review): this writes into the dictionary passed by the caller.
    quantities['energy'] = custom_quantity.energy_wrapper(
        energy_fn_template)
    reweighting.checkpoint_quantities(quantities)

    def _difftre_loss(params, traj_state, state_dict, targets):
        """Evaluates the loss and the model predictions for one ensemble."""
        weights, _, entropy, free_energy = weights_fn(
            params, traj_state, entropy_and_free_energy=True)

        snapshots = sampling.quantity_traj(traj_state, quantities, params)
        snapshots.update(entropy=entropy, free_energy=free_energy)

        loss, predictions = loss_fn(snapshots, weights, state_dict, targets)

        # Free energy and entropy are always reported, even if they are not
        # part of the loss.
        predictions.update(entropy=entropy, free_energy=free_energy)
        return loss, predictions

    # TODO: Maybe separate the functions like for force matching

    def difftre_weights_fn(params, traj_state, reduction="min"):
        """Computes the reweighting results for one or multiple ensembles."""
        if batched:
            # NOTE(review): ``reduction`` is not forwarded on the batched
            # path -- confirm that this is intended.
            return jax.vmap(functools.partial(weights_fn, params))(traj_state)
        return weights_fn(params, traj_state, reduction=reduction)

    def difftre_loss_fn(params, traj_state, state_dict, targets):
        """Computes the loss (mean over ensembles if batched) and the
        predictions.
        """
        if batched:
            per_ensemble = jax.vmap(functools.partial(_difftre_loss, params))
            losses, predictions = per_ensemble(
                traj_state, state_dict, targets)
            return jnp.mean(losses), predictions
        return _difftre_loss(params, traj_state, state_dict, targets)

    def difftre_propagation(params, traj_state, state_dict):
        """Propagates the trajectory state(s), calling the propagation
        function with ``recompute=True``.
        """
        propagate = functools.partial(propagate_fn, params, recompute=True)
        if batched:
            propagate = jax.vmap(propagate)
        return propagate(traj_state, **state_dict)

    if not wrapped:
        return difftre_loss_fn, difftre_propagation, difftre_weights_fn

    assert not batched, "Batched computation requires 'wrapped=False'."

    loss_grad_fn = jax.value_and_grad(
        difftre_loss_fn, has_aux=True, argnums=0)

    # TODO: There is more opportunity to make this general.
    #       We could extend the propagation and gradient function to take
    #       additional args besides the first two, e.g., batch state and
    #       output additional args besides the two mandatory traj state and
    #       loss grad

    def difftre_grad_and_propagation(params, traj_state, state_dict, targets):
        """Propagates the trajectory state and computes the loss, its
        gradient wrt. the energy function parameters, and the predictions
        for a single state point.
        """
        traj_state = propagate_fn(params, traj_state, **state_dict)
        (loss_val, predictions), loss_grad = loss_grad_fn(
            params, traj_state, state_dict, targets)
        return traj_state, loss_val, loss_grad, predictions

    # Compile first, then guard the compiled function via safe_propagate.
    return safe_propagate(jit(difftre_grad_and_propagation))
[docs]
def init_rel_entropy_loss_fn(energy_fn_template,
                             compute_weights,
                             kbt,
                             vmap_batch_size=10):
    r"""Initializes a function to compute the relative entropy loss.

    The relative entropy between a fine-grained (FG) reference system
    with :math:`U^\mathrm{FG}(\mathbf r)` and a coarse-grained (CG) system
    with :math:`U_\theta^\mathrm{CG}(\mathbf R)` is [#Shell2008]_

    .. math::

        S_\text{rel} = \beta\langle
        U_\theta^\mathrm{CG}(\mathcal M(\mathbf r))
        - U^\mathrm{FG}(\mathbf r) \rangle_\text{FG}
        -\beta(A^\mathrm{CG}_\theta - A^\mathrm{FG}) + S_\text{map}.

    This relative entropy depends on the free energies of the models and
    a mapping entropy.
    However, using free-energy perturbation approaches, one can create
    a replacement loss function that has the same gradients

    .. math::

        \mathcal L(\theta) = \beta \langle
        U_\theta^\mathrm{CG}(\mathcal M(\mathbf r))\rangle_\text{FG}
        -\beta A^\mathrm{CG}_\theta.

    Args:
        energy_fn_template: Energy function template
        compute_weights: compute_weights function as initialized from
            init_pot_reweight_propagation_fns.
        kbt: Temperature of the statepoint.
        vmap_batch_size: Batch size for computing the potential energies on
            the reference positions.

    Returns:
        A function ``loss_fn(params, traj_state, reference_batch)`` that
        returns the relative entropy loss, i.e., the contributions to the
        relative entropy that depend on the parameters of the model.

    """
    ref_quantities = dict(
        ref_energy=custom_quantity.energy_wrapper(energy_fn_template))
    reweighting.checkpoint_quantities(ref_quantities)

    def loss_fn(params, traj_state, reference_batch):
        # Free energy difference with respect to the initial state: the
        # perturbation estimate plus the difference stored in the state.
        *_, free_energy = compute_weights(
            params, traj_state, entropy_and_free_energy=True)
        free_energy = free_energy + traj_state.free_energy_diff

        # Use a single neighbor list if the trajectory state carries
        # a stack of them (extra leading axis on the reference positions).
        stored_nbrs = traj_state.sim_state.nbrs
        if stored_nbrs.reference_position.ndim > 2:
            nbrs = util.tree_get_single(stored_nbrs)
        else:
            nbrs = stored_nbrs

        # Potential energy predictions on the reference positions.
        ref_states = evaluation.SimpleState(position=reference_batch['R'])
        ref_results = evaluation.quantity_map(
            ref_states, ref_quantities, nbrs, {},
            {"kT": kbt}, params, vmap_batch_size,
        )
        mean_ref_energy = jnp.mean(ref_results["ref_energy"])

        return (mean_ref_energy - free_energy) / kbt

    return loss_fn
[docs]
def init_rel_entropy_gradient_and_propagation(reference_dataloader,
                                              reweight_fns,
                                              energy_fn_template,
                                              kbt,
                                              vmap_batch_size=10):
    """Initialize function to compute the relative entropy gradients.

    This implementation of the Relative Entropy Minimization algorithm
    [#Shell2008]_ computes the gradients of the free energy similar to the
    Differentiable Trajectory Reweighting (DiffTRe) algorithm [#Thaler2021]_
    via a perturbation approach, initialized in
    :func:`chemtrain.trajectory.reweighting.init_pot_reweight_propagation_fns`.

    The computation of the gradient is batched to increase computational
    efficiency.

    Args:
        reference_dataloader: Dataloader containing the mapped atomistic
            reference positions. Called with the batch state, it returns
            the new batch state and the next reference batch.
        reweight_fns: Functions to perform and evaluate umbrella-sampling
            simulations to estimate the free energy gradients.
            Initialized via
            :func:`chemtrain.trajectory.reweighting.init_pot_reweight_propagation_fns`.
        energy_fn_template: Template to initialize the energy function that
            is required to compute the weights.
        kbt: Temperature of the statepoint
        vmap_batch_size: Batch for computing the potential energies on the
            reference positions.

    Returns:
        Returns the gradient and propagation function for the relative entropy
        minimization algorithm. The function takes the parameters, the
        trajectory state, and the batch state and returns the outputs of the
        propagation and gradient computation followed by the new batch state.

    """
    weights_fn, propagate_fn, safe_propagate = reweight_fns

    loss_and_grad = jax.value_and_grad(
        init_rel_entropy_loss_fn(
            energy_fn_template, weights_fn, kbt, vmap_batch_size),
        argnums=0)

    def safe_propagation_and_grad(params, traj_state, reference_batch):
        """Propagates the trajectory, if necessary, and computes the
        gradient via the relative entropy formalism.
        """
        traj_state = propagate_fn(params, traj_state)
        loss, grad = loss_and_grad(params, traj_state, reference_batch)
        return traj_state, loss, grad

    # Compile first, then guard the compiled function via safe_propagate.
    safe_propagation_and_grad = safe_propagate(
        jax.jit(safe_propagation_and_grad))

    def propagation_and_grad(params, traj_state, batch_state):
        # Drawing the next reference batch advances the dataloader state,
        # which is handed back to the caller as the last output.
        new_batch_state, reference_batch = reference_dataloader(batch_state)
        results = safe_propagation_and_grad(
            params, traj_state, reference_batch)
        return (*results, new_batch_state)

    return propagation_and_grad
[docs]
def init_step_size_adaption(weight_fn: Callable,
                            allowed_reduction: ArrayLike = 0.5,
                            interior_points: int = 10,
                            step_size_scale: float = 1e-7
                            ) -> Callable:
    r"""Initializes a line search to tune the step size in each iteration.

    This method interpolates linearly between the old parameters
    :math:`\theta^{(i)}` and the parameters :math:`\tilde\theta`
    proposed by the optimizer to find the optimal update

    .. math ::

        \theta^{(i + 1)} = (1 - \alpha) \theta^{(i)} + \alpha\tilde\theta

    that reduces the effective sample size to a predefined constant

    .. math ::

        N_\text{eff}(\theta^{(i+1)}) = r\cdot N_\text{eff}(\theta^{(i)}).

    This method uses a vectorized bisection algorithm with fixed number of
    iterations. At each iteration, the algorithm computes the effective
    sample size for a predefined number of interior points and updates the
    search interval boundaries to include the two closest points bisecting
    the residual.

    The number of required iterations computes from the number of interior
    points :math:`n_i` and the desired accuracy :math:`a` via

    .. math ::

        N = \left\lceil -\log(a) / \log(n_i + 1)\right\rceil.

    The search starts on the interval :math:`[10^{-5}, 1]`. If even the
    smallest step reduces the effective sample size too much, the smallest
    step is returned. If the full step does not reduce the effective sample
    size below the threshold, the full step is returned.

    Args:
        weight_fn: Function computing a tuple (weights, N_eff) from the
            parameters and the trajectory state.
        allowed_reduction: Target reduction :math:`r` of the effective
            sample size.
        interior_points: Number of interior points evaluated per iteration.
            Must be at least 1.
        step_size_scale: Accuracy of the found optimal interpolation
            coefficient. Must be positive.

    Returns:
        Returns a jit-compiled function
        ``(params, batch_grad, proposal, traj_state) -> (alpha, residual)``
        computing the interpolation coefficient :math:`\alpha` and the
        residual :math:`\log N_\text{eff}(\alpha) - \log N_\text{eff}(0)
        - \log r` at this coefficient. The argument ``batch_grad`` is
        accepted for interface compatibility but not used.

    Raises:
        ValueError: If ``interior_points < 1`` or ``step_size_scale <= 0``.

    """
    # TODO: Makes this more general to work on an arbitrary measure,
    #       not only the effective sample size.
    #       Then, it is possible to move out the step size adaption
    #       into a more general trainer.

    # Fail early with a clear message instead of an OverflowError from
    # converting an infinite iteration count to int.
    if interior_points < 1:
        raise ValueError(
            f"interior_points must be at least 1, got {interior_points}.")
    if step_size_scale <= 0:
        raise ValueError(
            f"step_size_scale must be positive, got {step_size_scale}.")

    iterations = int(onp.ceil(
        -onp.log(step_size_scale) / onp.log(interior_points + 1)))
    print(f"[Step size] Use {iterations} iterations for {interior_points} interior points.")

    # Vectorized over the interpolation coefficients only.
    @functools.partial(jax.vmap, in_axes=(0, None, None, None, None))
    def _residual(alpha, params, n_eff, proposal, traj_state):
        new_params = jax.tree_util.tree_map(
            lambda old, new: old * (1 - alpha) + new * alpha,
            params, proposal
        )
        _, n_eff_new = weight_fn(new_params, traj_state)
        reduction = jnp.log(n_eff_new) - jnp.log(n_eff)

        # A non-negative residual means that the effective sample size is
        # still at least the allowed fraction of the current one.
        return reduction - jnp.log(allowed_reduction)

    @jit
    def _adaptive_step_size(params, batch_grad, proposal, traj_state):
        del batch_grad  # Unused, kept for interface compatibility.

        _, n_eff = weight_fn(params, traj_state)

        def _bisection_step(state, _):
            lower, upper, res_lower, res_upper = state

            # Only evaluate interior points; the residuals at the interval
            # boundaries are carried along.
            inner = jnp.reshape(
                jnp.linspace(lower, upper, interior_points + 2)[1:-1], (-1,))
            res_inner = _residual(inner, params, n_eff, proposal, traj_state)

            # Add boundary points to the possible candidates
            candidates = jnp.concatenate(
                (jnp.asarray([lower, upper]), inner))
            residuals = jnp.concatenate(
                (jnp.asarray([res_lower, res_upper]), res_inner))

            # New lower bound: point with the smallest non-negative residual.
            # New upper bound: point with the largest non-positive residual.
            # Residuals of the wrong sign are masked by the extreme value
            # such that they are never preferred.
            non_negative = jnp.where(
                residuals < 0, jnp.max(residuals), residuals)
            non_positive = jnp.where(
                residuals > 0, jnp.min(residuals), residuals)
            lower_idx = jnp.argmin(non_negative)
            upper_idx = jnp.argmax(non_positive)

            new_state = (
                candidates[lower_idx], candidates[upper_idx],
                residuals[lower_idx], residuals[upper_idx],
            )
            return new_state, None

        lower, upper = 1.0e-5, 1.0
        res_lower, res_upper = _residual(
            jnp.asarray([lower, upper]), params, n_eff, proposal, traj_state)

        # Even the minimum step reduces the effective sample size too much
        keep_min_step = res_lower <= 0
        # The full step does not reduce the effective sample size below the
        # threshold
        take_full_step = jnp.logical_and(res_lower > 0, res_upper > 0)

        # Collapse the interval onto the minimum or the full step. The
        # residuals are collapsed alongside such that the carried state is
        # consistent, even if no bisection iteration is performed.
        upper, res_upper = (
            jnp.where(keep_min_step, lower, upper),
            jnp.where(keep_min_step, res_lower, res_upper),
        )
        lower, res_lower = (
            jnp.where(take_full_step, upper, lower),
            jnp.where(take_full_step, res_upper, res_lower),
        )

        # In the other case, do the bisection with the unchanged initial
        # interval
        (alpha, _, res_alpha, _), _ = lax.scan(
            _bisection_step,
            (lower, upper, res_lower, res_upper), onp.arange(iterations)
        )
        return alpha, res_alpha

    return _adaptive_step_size