Source code for chemtrain.ensemble.reweighting

# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module provides implementations of thermodynamic perturbation theory.

Thermodynamic perturbation theory enables the transfer of information between
perturbed ensembles, e.g., free energy differences or ensemble averages.

A description and example of using free energy perturbation approaches can be
found here: :doc:`/algorithms/relative_entropy`.

Likewise, an example to use the reweighting approach for ensemble averages is
provided here: :doc:`/algorithms/difftre`.

"""
import functools
import time
import warnings
from functools import partial

import numpy as onp

import jax_sgmc.util
from jax import (checkpoint, lax, random, tree_util, vmap, numpy as jnp, jit,
                 debug)

from jax_md_mod import custom_quantity
import jax_md.util
from jax_md import util as jax_md_util, simulate

from chemtrain import util
from chemtrain.ensemble import sampling
from chemtrain.quantity import constants, observables

from typing import Dict, Any, Union, Callable, Tuple, Protocol

try:
    from jax.typing import ArrayLike
except:
    ArrayLike = Any
from jax_md.partition import NeighborFn
from chemtrain.typing import EnergyFnTemplate
from chemtrain.typing import ComputeFn


def dynamic_statepoint(default_kwargs: dict, **kwargs) -> dict:
    """Returns a copy of the default statepoint with ``kwargs`` taking priority.

    Args:
        default_kwargs: Statepoint definition used as the baseline. It is
            copied and not modified.
        **kwargs: Values that replace or extend the baseline entries.

    Returns:
        A new dictionary holding the merged statepoint.
    """
    merged = default_kwargs.copy()
    for name, value in kwargs.items():
        merged[name] = value
    return merged


def checkpoint_quantities(compute_fns: dict[str, ComputeFn]) -> None:
    """Wraps every compute function with ``jax.checkpoint``, in place.

    Checkpointing is applied to save memory on the backward pass.

    Args:
        compute_fns: Dictionary of functions to compute instantaneous
            quantities from simulator states. Each entry is replaced by its
            checkpointed counterpart; nothing is returned.
    """
    # Only values of existing keys are reassigned, so iterating over the
    # dictionary while updating it is safe.
    for name, compute_fn in compute_fns.items():
        compute_fns[name] = checkpoint(compute_fn)
def _estimate_effective_samples(weights):
    """Returns the effective sample size after reweighting.

    The estimate is ``exp(-sum(w * log(w)))`` and is used to judge the
    reweighting quality.

    Args:
        weights: Normalized reweighting weights.

    Returns:
        The effective sample size as a scalar array.
    """
    # Clip tiny weights to avoid NaN from log(0) if a few weights are 0.
    weights = jnp.where(weights > 1.e-10, weights, 1.e-10)
    exponent = -jnp.sum(weights * jnp.log(weights))
    return jnp.exp(exponent)


def _build_weights(exponents):
    """Returns weights and the effective sample size from exponents of the
    reweighting formulas in a numerically stable way.

    Args:
        exponents: Exponents (logits) of the reweighting formula. The input
            is not modified.

    Returns:
        A tuple ``(weights, n_eff)`` with the normalized weights and the
        effective sample size.
    """
    # The reweighting scheme is a softmax, where the exponent above
    # represents the logits. To improve numerical stability and
    # guard against overflow it is good practice to subtract the
    # max of the exponent using the identity softmax(x + c) =
    # softmax(x). With all values in the exponent <=0, this
    # rules out overflow and the 0 value guarantees a denominator >=1.
    #
    # The shift is written out-of-place on purpose: an augmented assignment
    # would modify a NumPy input in place, and callers re-use the exponents
    # after this call (e.g. for free energy estimates).
    shifted = exponents - jnp.max(exponents)
    prob_ratios = jnp.exp(shifted)
    weights = prob_ratios / jax_md_util.high_precision_sum(prob_ratios)
    n_eff = _estimate_effective_samples(weights)
    return weights, n_eff
def reweight_trajectory(traj, **targets):
    """Computes weights to reweight a trajectory from one thermodynamic
    state point to another.

    This function allows re-using an existing trajectory to compute
    observables at slightly perturbed thermodynamic state points. The
    reference trajectory can be generated at a constant state point or at
    different state points, e.g. via non-equilibrium MD. Both NVT and NPT
    trajectories are supported, however reweighting currently only allows
    reweighting into the same ensemble. For NVT, the trajectory can be
    reweighted to a different temperature. For NPT, it can be reweighted to
    a different kT and/or pressure. We assume quantities not included in
    'targets' to be constant over the trajectory, however this is not
    ensured by the code.

    Implemented are cases 1. - 4. of the reference [#plumed]_.

    Args:
        traj: Reference trajectory to be reweighted
        targets: Kwargs containing the targets under 'kT' and/or 'pressure'.
            If a keyword is not provided, the quantity is assumed to be
            and remain constant.

    Returns:
        A tuple (weights, n_eff). Weights can be used to compute
        reweighted observables and n_eff judges the expected
        statistical error from reweighting.

    Raises:
        AssertionError: If no 'kT' target is given for an NVT trajectory or
            neither 'kT' nor 'pressure' is given for an NPT trajectory.
        ValueError: If the trajectory does not provide the energies under
            ``traj.aux['energy']``.

    References:
        .. [#plumed] `<https://www.plumed.org/doc-v2.6/user-doc/html/_r_e_w_e_i_g_h_t__t_e_m_p__p_r_e_s_s.html>`_
    """
    npt_ensemble = util.is_npt_ensemble(traj.sim_state[0])
    if not npt_ensemble:
        assert 'kT' in targets, 'For NVT, a "kT" target needs to be provided.'

    # Note: if temperature / pressure are supposed to remain constant and are
    # hence not provided in the targets, we set them to the respective
    # reference values. Hence, their contribution to reweighting cancels.
    # This should even be at no additional cost under jit as XLA should
    # easily detect the zero contribution. Same applies to combinations in
    # the NPT ensemble.
    target_kbt = targets.get('kT', traj.dynamic_kwargs['kT'])
    target_beta = 1. / target_kbt
    reference_betas = 1. / traj.dynamic_kwargs['kT']

    # temperature reweighting
    if 'energy' not in traj.aux:
        raise ValueError('For reweighting, energies need to be provided '
                         'alongside the trajectory. Add energy to auxilary '
                         'outputs in trajectory generator.')

    exponents = -(target_beta - reference_betas) * traj.aux['energy']

    if npt_ensemble:  # correct for P * V
        # The temperature target is read from the 'kT' keyword above, so the
        # check must use the same key. Checking for another spelling would
        # reject valid temperature-only targets.
        assert 'kT' in targets or 'pressure' in targets, ('At least one target'
                                                          ' needs to be given '
                                                          'for reweighting.')
        target_press = targets.get('pressure', traj.dynamic_kwargs['pressure'])
        target_beta_p = target_beta * target_press
        ref_beta_p = reference_betas * traj.dynamic_kwargs['pressure']
        volumes = observables.volumes(traj)

        # For constant p, reduces to -V * P_ref * (beta_target - beta_ref)
        # For constant T, reduces to -V * beta_ref * (p_target - p_ref)
        exponents -= volumes * (target_beta_p - ref_beta_p)

    return _build_weights(exponents)
def init_reference_trajectory_reweight_fns(energy_fn_template: EnergyFnTemplate,
                                           neighbor_fn: NeighborFn,
                                           target_quantities: Dict[str, Any],
                                           ref_kbt: ArrayLike,
                                           ref_pressure: ArrayLike = None,
                                           compute_fns: Dict[str, Callable] = None,
                                           energy_batch_size: int = 10,
                                           dynamic_dropout: bool = False,
                                           reference_energy_fn_template: EnergyFnTemplate = None,
                                           pressure_correction: bool = False
                                           ) -> Tuple[Callable, Callable]:
    """Initializes reweighting based on a reference trajectory.

    Instead of recomputing a trajectory, this function uses a precomputed
    trajectory and quantities to estimate quantities for a different
    potential model via the reweighting procedure.

    .. code-block :: python

        # Initialize the reference reweighting method
        init_fn, compute_fn = init_reference_trajectory_reweight_fns(
            energy_fn_template, neighbour_fn, target_quantities, ref_kbt)

        # Provide the reference trajectory and reference quantities
        state = init_fn(reference_trajectory, reference_quantities)

        # Compute the new quantities
        results = compute_fn(state, new_energy_parameters)

    Args:
        energy_fn_template: Perturbed potential model
        neighbor_fn: Neighbour list function
        target_quantities: Quantities to estimate via reweighting. Each
            entry must provide a ``'traj_fn'`` that is called with the
            per-snapshot quantities and the ``weights`` keyword.
        ref_kbt: Reference microscopic temperature
        ref_pressure: Reference pressure
        compute_fns: Functions to recompute quantities used in reweighting.
            The energy is automatically recomputed. It is possible to add
            quantities that are not contained in the reference quantities
            and required to recompute quantities that depend directly on
            the potential, i.e. the pressure. The dictionary is copied and
            not modified.
        energy_batch_size: Number of configurations to compute in parallel
        dynamic_dropout: Issues a new dropout key for every state of the
            trajectory.
        reference_energy_fn_template: Energy function to re-compute the
            energies of the potential used to generate the trajectory.
        pressure_correction: Include the pressure in the Boltzmann factor
            for the NPT ensemble.

    Returns:
        A tuple ``(init_fn, compute_fn)``. ``init_fn`` builds the reference
        trajectory state from a reference trajectory and its quantities.
        ``compute_fn`` returns a dictionary containing the quantities
        obtained via reweighting, the original quantities and the weights as
        well as the effective sample size of the reweighting procedure.

    """
    # The energy of the new potential model is necessary for reweighting
    traj_energy_fn = custom_quantity.energy_wrapper(energy_fn_template)

    # Work on a private copy: adding the energy and checkpointing the
    # functions must not alter the dictionary owned by the caller.
    compute_fns = {} if compute_fns is None else dict(compute_fns)
    compute_fns['energy'] = traj_energy_fn

    beta = 1. / ref_kbt
    checkpoint_quantities(compute_fns)

    def init_reference_reweighting(reference_trajectory: sampling.TrajectoryState,
                                   reference_quantities: Dict[str, Any],
                                   extra_acpacity: int = 0,
                                   reference_params: Any = None,
                                   ) -> Any:
        """Inits reweighting based on a reference trajectory.

        Args:
            reference_trajectory: Precomputed reference trajectory.
            reference_quantities: Quantities of the reference trajectory.
                Must contain the key ``'energy'``.
            extra_acpacity: Extra capacity forwarded to the neighbor list
                allocation.
            reference_params: Energy parameters stored in the returned
                state as ``energy_params``.

        Returns:
            The trajectory state expected by the compute function.
        """
        # Take a single frame of the reference trajectory as the simulation
        # state and initialize the neighbour list to create a state as
        # expected by the quantity_traj function
        first_frame = util.tree_get_single(reference_trajectory)
        nbrs_state = util.neighbor_allocate(
            neighbor_fn, first_frame, extra_capacity=extra_acpacity)

        n_samples = reference_quantities['energy'].size
        thermostat = jnp.ones(n_samples) * ref_kbt
        if ref_pressure is not None:
            barostat = jnp.ones(n_samples) * ref_pressure
        else:
            barostat = None

        initial_state = sampling.TrajectoryState(
            sim_state=sampling.SimulatorState(
                sim_state=first_frame, nbrs=nbrs_state
            ),
            trajectory=reference_trajectory,
            overflow=False,
            thermostat_kbt=thermostat,
            barostat_press=barostat,
            aux=reference_quantities,
            energy_params=reference_params
        )

        # TODO: Rework for the new trajectory quantities

        # Without the potential, some reference quantities cannot be
        # calculated. Therefore, we just require all reference quantities
        # upfront. Otherwise, it would be necessary to also provide the
        # reference potential model. We check that all required reference
        # quantities were provided or can be recomputed by a compute_fn.
        # available_quantities = list(compute_fns.keys())
        # available_quantities += list(reference_quantities.keys())
        # missing = []
        # for target_key in target_quantities.keys():
        #     if target_key not in available_quantities:
        #         missing.append(target_key)
        # assert len(missing) == 0, f'Missing quantities, {missing}' \
        #                           f'must be provided as reference quantitiy ' \
        #                           f'or as compute_fn.'

        return initial_state

    def compute_reweighted_quantities(reference_state: sampling.TrajectoryState,
                                      energy_params: Any,
                                      dropout_key: ArrayLike = None
                                      ) -> Dict[str, Union[Any, ArrayLike]]:
        """Computes weights and reweighted quantities for new parameters.

        Args:
            reference_state: State created by the init function.
            energy_params: Parameters of the perturbed potential.
            dropout_key: Optional key. If given, one key per snapshot is
                derived (split if ``dynamic_dropout``, otherwise tiled).

        Returns:
            Dictionary with the keys ``'weights'``, ``'exponent'``,
            ``'n_eff'``, ``'reference_quantities'``,
            ``'unweighted_quantities'`` and ``'reweighted_quantities'``.
        """
        npt_ensemble = util.is_npt_ensemble(reference_state.sim_state[0])

        dropout_keys = None
        if dropout_key is not None:
            n_samples = reference_state.trajectory.position.shape[0]
            if dynamic_dropout:
                dropout_keys = random.split(dropout_key, n_samples)
            else:
                dropout_keys = jnp.tile(dropout_key, (n_samples, 1))

        # TODO: Recompute also the reference quantities if given a reference
        #       energy template
        if reference_energy_fn_template is not None:
            print(f"Recompute reference quantities")
            reference_compute_fns = {}
            reference_compute_fns["energy"] = custom_quantity.energy_wrapper(
                reference_energy_fn_template)

            if npt_ensemble:
                reference_compute_fns["conf_pressure"] = custom_quantity.init_pressure(
                    reference_energy_fn_template, include_kinetic=False
                )

            ref_quantities = sampling.quantity_traj(
                reference_state, reference_compute_fns,
                reference_state.energy_params, energy_batch_size, dropout_keys
            )
        else:
            ref_quantities = reference_state.aux

        # Use a per-call copy so that the pressure function of one call does
        # not leak into the shared dictionary used by later calls.
        traj_compute_fns = dict(compute_fns)
        if npt_ensemble:
            traj_compute_fns['conf_pressure'] = custom_quantity.init_pressure(
                energy_fn_template, include_kinetic=False
            )

        # reweighting properties (U and pressure) under perturbed potential
        quantities = sampling.quantity_traj(
            reference_state, traj_compute_fns, energy_params,
            energy_batch_size, dropout_keys)

        # Note: Difference in pot. Energy is difference in total energy
        # as kinetic energy is the same and cancels.
        exponent = quantities['energy'] - ref_quantities['energy']

        # In the npt ensemble, the pressure depends on the particle positions
        # and thus contribute to the boltzmann factor similar to the
        # potential. The volume instead depends only on the particle
        # positions and is thus equivalent for both states
        if npt_ensemble and pressure_correction:
            print(f"Consider the pressure in the boltzmann factor.")
            # NOTE(review): this replaces the energy difference instead of
            # adding to it, although the comment above suggests an additional
            # contribution. Behavior is kept as-is; confirm the intent.
            exponent = quantities['volume'] * (
                quantities['conf_pressure'] - ref_quantities['conf_pressure']
            )
        exponent *= -beta

        weights, n_eff = _build_weights(exponent)

        # Fill in reference quantities that were not re-computed, such that
        # re-computed values take priority.
        for key, quant in reference_state.aux.items():
            quantities.setdefault(key, quant)

        reweighted_quantities = {
            target_key: target['traj_fn'](quantities, weights=weights)
            for target_key, target in target_quantities.items()
        }

        result = {
            'weights': weights,
            'exponent': exponent,
            'n_eff': n_eff,
            'reference_quantities': ref_quantities,
            'unweighted_quantities': quantities,
            'reweighted_quantities': reweighted_quantities
        }
        return result

    return init_reference_reweighting, compute_reweighted_quantities
class ComputeWeightsFn(Protocol):
    """Call signature of functions computing reweighting weights."""

    def __call__(self,
                 params: Any,
                 traj_state: sampling.TrajectoryState,
                 entropy_and_free_energy: bool = False
                 )-> Tuple[ArrayLike, ArrayLike] | Any:
        """Computes weights for the reweighting approach.

        Args:
            params: Energy parameters to obtain the perturbed potential.
            traj_state: Trajectory of a sufficiently close potential.
                The auxiliary quantities must contain the energy of the
                reference trajectory (key: ``'energy'``).
            entropy_and_free_energy: If True, additionally return the
                entropy and free energy difference to the reference
                potential.

        Returns:
            Returns the weight for each sample and the effective sample
            size. If ``entropy_and_free_energy=True``, the entropy and free
            energy difference to the reference potential are returned in
            addition.

        """
class PropagateFn(Protocol):
    """Call signature of functions propagating a trajectory state."""

    def __call__(self,
                 params: Any,
                 traj_state: sampling.TrajectoryState,
                 *args,
                 **kwargs
                 ) -> sampling.TrajectoryState | Tuple[sampling.TrajectoryState, ...]:
        """Samples from a new reference ensemble if the ESS is insufficient.

        This function computes the ESS and decides whether to update the
        reference potential to the current potential parameters.
        Additionally, the function checks whether overflow occurred and
        increases the neighbor list if necessary.

        Args:
            params: Energy parameters for the perturbed target potential.
            traj_state: Trajectory from the most recent reference potential.
            *args: Additional positional arguments of the implementation.
            **kwargs: Additional keyword arguments of the implementation.

        Returns:
            Returns a trajectory state with adequate ESS and neighbor list,
            which can be the previous trajectory state.
            If obtained via the ``safe_propagate`` decorator, can return
            additional results besides the propagated trajectory state.

        """
# Return type of the reweighting initializer: either three functions or, as
# the second alternative, four functions where the last one is a callable
# producing a ``PropagateFn``.
ReweightingFns = Union[
    Tuple[Callable, ComputeWeightsFn, PropagateFn],
    Tuple[Callable, ComputeWeightsFn, Callable, Callable[..., PropagateFn]]
]
def init_pot_reweight_propagation_fns(energy_fn_template: EnergyFnTemplate,
                                      simulator_template: Callable,
                                      neighbor_fn: NeighborFn,
                                      timings: sampling.TimingClass,
                                      state_kwargs: Dict[str, ArrayLike],
                                      reweight_ratio: float = 0.9,
                                      npt_ensemble: bool = False,
                                      energy_batch_size: int = 1,
                                      entropy_approximation: bool = False,
                                      max_iter_bar: int = 25,
                                      safe_propagation: bool = True,
                                      resample_simstates: bool = False,
                                      ) -> ReweightingFns:
    """
    Initializes all functions necessary for trajectory reweighting for
    a single state point.

    The initialized functions include a function that computes weights for a
    given trajectory and a function that propagates the trajectory forward
    if the statistical error does not allow a re-use of the trajectory.

    The third and (optionally) fourth function depends on the value of
    ``safe_propagation``.
    If set to True, only three functions are returned.
    Additionally, the propagation function checks the neighbor list for
    overflow and the trajectory for NaNs.
    However, in this case, the propagation function is not jit-able.
    Instead, for ``safe_propagation=False``, the fourth function can be
    used as decorator to extend a non-jitable outer function.

    Args:
        energy_fn_template: Energy function template to initialize a new
            energy function with the current parameters.
        simulator_template: Template to create new simulators with
            different energy functions.
        neighbor_fn: Function to re-compute a neighbor list on reference
            positions.
        timings: Timings of the simulation.
        state_kwargs: Dictionary defining the statepoint, e.g., containing
            the reference temperature ``'kT'``.
        reweight_ratio: Minimal fractional ESS to re-use a trajectory.
        npt_ensemble: Whether to reweight in an NPT ensemble.
        energy_batch_size: Batch size for the vectorized energy computation.
        entropy_approximation: Approximation of the entropy difference
            between reference and target potential with similar gradient.
            Currently raises ``NotImplementedError`` when the entropy is
            requested.
        max_iter_bar: Maximum number of iterations for the BAR procedure.
        safe_propagation: Ensure that generated trajectories did not
            encounter any neighbor list overflow.
        resample_simstates: Re-samples the sim states from the trajectory
            for a new simulation.

    Example:

        Here, we increase the number of retries for overflown neighbor lists
        and additionally return additional arguments besides the trajectory
        state.

        .. code:: python

            @partial(safe_propagate, multiple_arguments=True, max_retry=10)
            def outer_fn(params, traj_state):
                traj_state = propagate(params, traj_state)
                weights = compute_weights(params, traj_state)
                loss, predictions = some_loss_fn(traj_state, weights)

                return traj_state, loss, predictions

    Returns:
        Returns a tuple of function to apply the reweighting formalism.
        The first function generates a reference trajectory, starting from
        a reference simulator state.
        The second function computes the weights given a reference trajectory
        state.
        The third function propagates the trajectory state, re-computing
        a trajectory from the current energy parameters if the ESS drops
        below a certain threshold.
        The fourth function is only returned if ``safe_propagation=False``.

    """
    traj_energy_fn = custom_quantity.energy_wrapper(energy_fn_template)
    reweighting_quantities = {'energy': traj_energy_fn}

    bennett_free_energy = init_bar(
        energy_fn_template, state_kwargs['kT'], energy_batch_size,
        max_iter_bar)

    if npt_ensemble:
        # pressure currently only used to print pressure of generated
        # trajectory such that user can ensure correct statepoint of
        # reference trajectory
        pressure_fn = custom_quantity.init_pressure(energy_fn_template)
        reweighting_quantities['pressure'] = pressure_fn

    trajectory_generator = sampling.trajectory_generator_init(
        simulator_template, energy_fn_template, timings,
        reweighting_quantities,
        vmap_batch=energy_batch_size, vmap_sim_batch=energy_batch_size
    )
    trajectory_generator = jit(trajectory_generator)

    checkpoint_quantities(reweighting_quantities)

    def resample_new_simstate(params, traj_state):
        """Draws new chain start states from the reweighted trajectory."""
        ref_sim_state = traj_state.sim_state.sim_state
        ref_trajectory = traj_state.trajectory

        n_chains = ref_sim_state.position.shape[0]

        # Position shape ends with [n_samples, n_particles, 3]
        weights, *_ = compute_weights(params, traj_state)
        num_samples = weights.size

        # Choose new initial positions from the reweighted
        # distribution of all samples.
        key, split1, split2 = random.split(traj_state.key, 3)
        new_position_idx = random.choice(
            split1, jnp.arange(num_samples), shape=(n_chains,),
            replace=False, p=weights
        )

        new_sim_state = ref_sim_state.set(
            position=ref_trajectory.position[new_position_idx, ...]
        )

        # Since velocities are independent of the potential, we can
        # redraw the velocities from the maxwell boltzmann distribution
        new_sim_state = vmap(
            partial(simulate.initialize_momenta, kT=state_kwargs["kT"])
        )(new_sim_state, random.split(split2, n_chains))

        new_traj_state = traj_state.replace(
            sim_state=traj_state.sim_state.replace(
                sim_state=new_sim_state
            ),
            key=key
        )

        return new_traj_state

    def compute_weights(params, traj_state, entropy_and_free_energy = False):
        """Computes weights for the reweighting approach.

        Args:
            params: Energy parameters of the perturbed potential.
            traj_state: Reference trajectory state. Its auxiliary quantities
                must contain the key ``'energy'``.
            entropy_and_free_energy: If True, additionally returns the
                entropy and free energy difference.

        Returns:
            ``(weights, n_eff)`` or, if ``entropy_and_free_energy`` is set,
            ``(weights, n_eff, entropy, free_energy_diff)``.
        """
        # reweighting properties (U and pressure) under perturbed potential
        reweight_properties = sampling.quantity_traj(
            traj_state, reweighting_quantities, params, energy_batch_size)

        beta = 1. / traj_state.static_kwargs['kT']
        assert (jnp.isscalar(beta) or beta.shape == ()), (
            "Reweighting requires a constant temperature."
        )

        # Note: Difference in pot. Energy is difference in total energy
        # as kinetic energy is the same and cancels
        exponent = reweight_properties['energy'] - traj_state.aux['energy']
        exponent *= -beta

        weights, n_eff = _build_weights(exponent)

        if not entropy_and_free_energy:
            return weights, n_eff
        else:
            # Compute the free energy difference and entropy difference to
            # the potential model that generated the trajectory
            max_exp = jnp.max(exponent)
            ratio_sum = jax_md_util.high_precision_sum(
                jnp.exp(exponent - max_exp))
            log_n = jnp.log(exponent.size)
            free_energy_diff = jnp.log(ratio_sum) + max_exp - log_n
            free_energy_diff *= -1. / beta

            if entropy_approximation:
                raise NotImplementedError("Approximation not implemented.")

            # This is the thermodynamic entropy definition
            eng_diff = jnp.sum(reweight_properties['energy'].T * weights)
            eng_diff -= jnp.mean(traj_state.aux['energy'])
            entropy = eng_diff - free_energy_diff

            # Add the differences with respect to the initial potential
            entropy += traj_state.entropy_diff
            free_energy_diff += traj_state.free_energy_diff

            return weights, n_eff, entropy, free_energy_diff

    def trajectory_identity_mapping(inputs):
        """Re-uses trajectory if no recomputation needed."""
        traj_state = inputs[1]
        return traj_state

    def recompute_trajectory(inputs):
        """Recomputes the reference trajectory, starting from the last
        state of the previous trajectory to save equilibration time.

        Args:
            inputs: Tuple ``(params, traj_state, kwargs)``, where ``kwargs``
                is the statepoint passed on to the trajectory generator.
        """
        params, traj_state, kwargs = inputs
        # give kT here as additional input to be handed through to energy_fn
        # for kbt-dependent potentials

        if resample_simstates:
            traj_state = resample_new_simstate(params, traj_state)

        updated_traj = trajectory_generator(
            params, traj_state.sim_state, **kwargs)
        updated_traj = updated_traj.replace(key=traj_state.key)

        # Apply the BAR procedure to obtain the free energy difference
        # between the old and new trajectory

        # TODO: Update BAR method to use the correct statepoint.
        #       E.g., change from U to exp, where exp is the generalized
        #       exponent of the ensemble
        dfe, ds = bennett_free_energy(
            traj_state, updated_traj,
            **traj_state.dynamic_kwargs, **traj_state.static_kwargs
        )
        updated_traj = updated_traj.replace(
            entropy_diff=traj_state.entropy_diff + ds,
            free_energy_diff=traj_state.free_energy_diff + dfe
        )

        return updated_traj

    @jit
    def propagation_fn(params, traj_state, recompute=False, **kwargs):
        """Checks if a trajectory can be re-used. If not, a new trajectory
        is generated.

        Takes params and the traj_state as input and returns a trajectory
        state valid for reweighting. Overflow of the neighbor list is not
        handled here but is reported via the ``overflow`` attribute of the
        returned state.
        """
        kwargs = dynamic_statepoint(state_kwargs, **kwargs)

        _, n_eff = compute_weights(params, traj_state)
        n_snapshots = traj_state.aux['energy'].size
        recompute |= n_eff < reweight_ratio * n_snapshots

        debug.print(f"[Propagate] Effective sample size: {{}} "
                    f"({reweight_ratio * n_snapshots}) "
                    f"-> Recompute is {{}}", n_eff, recompute)

        propagated_state = lax.cond(recompute,
                                    recompute_trajectory,
                                    trajectory_identity_mapping,
                                    (params, traj_state, kwargs))
        return propagated_state

    def safe_propagate(fun, multiple_arguments=True, max_retry=3):
        """Re-executes the wrapped function until errors are resolved.

        Args:
            fun: Function called as ``fun(params, traj_state, ...)``. It
                returns the trajectory state, followed by further results
                if ``multiple_arguments`` is set.
            multiple_arguments: Whether ``fun`` returns more than the
                trajectory state. If False, ``fun`` is called with
                ``params`` and ``traj_state`` only.
            max_retry: Maximum number of attempts to obtain a trajectory
                without neighbor list overflow.

        Returns:
            The wrapped function. It accepts the additional keyword
            ``recompute`` to force a recomputation of the trajectory.
        """
        def wrapped(params, traj_state, *args, **kwargs):
            recompute = kwargs.pop("recompute", False)

            # ``recompute_trajectory`` expects the statepoint as third tuple
            # entry. The statepoint overrides of the wrapped function are
            # not known here, hence the default statepoint is used.
            statepoint = dynamic_statepoint(state_kwargs)

            if jnp.any(jnp.isnan(traj_state.sim_state.sim_state.position)):
                raise RuntimeError(
                    'Last state is NaN. Currently, there is no recovering from '
                    'this. Restart from the last non-overflown state might '
                    'help, but comes at the cost that the reference state is '
                    'likely not representative.')

            for _ in range(max_retry):
                # When only propagating then only the trajectory is returned
                if multiple_arguments:
                    traj_state, *returns = fun(
                        params, traj_state, *args, **kwargs)
                else:
                    traj_state = fun(params, traj_state)
                    returns = None

                if recompute:
                    print(f"[Safe Propagate] Forced recomputation.")
                    traj_state = recompute_trajectory(
                        (params, traj_state, statepoint))

                if traj_state.overflow:
                    print(f"[Safe Propagate] Overflow detected, recompute "
                          f"trajectory with increased neighbor list size.")
                    last_state = traj_state.sim_state.sim_state
                    if last_state.position.ndim > 2:
                        # Multiple chains: allocate for a single chain and
                        # update the list for every chain
                        single_enlarged_nbrs = util.neighbor_allocate(
                            neighbor_fn, util.tree_get_single(last_state))
                        enlarged_nbrs = vmap(util.neighbor_update, (None, 0))(
                            single_enlarged_nbrs, last_state)
                    else:
                        enlarged_nbrs = util.neighbor_allocate(
                            neighbor_fn, last_state)

                    reset_traj_state = traj_state.replace(
                        sim_state=sampling.SimulatorState(
                            sim_state=last_state, nbrs=enlarged_nbrs
                        )
                    )
                    traj_state = recompute_trajectory(
                        (params, reset_traj_state, statepoint))
                else:
                    if multiple_arguments:
                        return traj_state, *returns
                    else:
                        return traj_state

            raise RuntimeError('Multiple neighbor list re-computations did '
                               'not yield a trajectory without overflow. '
                               'Consider increasing the neighbor list '
                               'capacity multiplier.')

        return wrapped

    def propagate(params,
                  traj_state: sampling.TrajectoryState,
                  **kwargs) -> sampling.TrajectoryState:
        """Wrapper around the jitted propagation function.

        This function only forwards to the jitted propagation function.
        The handling of neighbor list overflow, which requires a
        re-allocation and can thus not be jit, is added by wrapping this
        function with ``safe_propagate``.
        """
        new_traj_state = propagation_fn(params, traj_state, **kwargs)
        return new_traj_state

    @functools.partial(jit, static_argnames=("num_runs",))
    def init_first_traj(key, params, reference_state, num_runs=2, **kwargs):
        """Initializes initial trajectory to start optimization from.

        We dump the initial trajectory for equilibration, as initial
        equilibration usually takes much longer than equilibration time of
        each trajectory. If this is still not sufficient, the simulation
        should equilibrate over the course of subsequent updates.
        """
        if resample_simstates:
            assert reference_state.sim_state.position.ndim > 2, (
                f"Please initialize multiple chains to resample new initial "
                f"chain states."
            )

        kwargs = dynamic_statepoint(state_kwargs, **kwargs)

        def _run_fn(sim_state, _):
            traj = trajectory_generator(params, sim_state, **kwargs)
            return traj.sim_state, traj

        # Run multiple times
        _, init_traj = lax.scan(_run_fn, reference_state, jnp.arange(num_runs))

        # Select the last trajectory
        init_traj = tree_util.tree_map(lambda x: x[-1], init_traj)

        # Use the initial trajectory as a reference for entropy and
        # free energy
        init_traj = init_traj.replace(
            entropy_diff=0.0, free_energy_diff=0.0, energy_params=params,
            key=key
        )

        return init_traj

    if safe_propagation:
        safe_propagation_fn = safe_propagate(propagate,
                                             multiple_arguments=False)
        return init_first_traj, compute_weights, safe_propagation_fn
    else:
        warnings.warn(
            'Propagation function is not safe by default. '
            'Do not forget to use the wrapper around the compute function to '
            'ensure that the neighborlist does not overflow.')
        return init_first_traj, compute_weights, propagate, safe_propagate
def init_bar(energy_fn_template: EnergyFnTemplate,
             kT: ArrayLike,
             energy_batch_size: int = 10,
             iter_bisection: int = 25
             ) -> Callable[[sampling.TrajectoryState, sampling.TrajectoryState], Tuple[ArrayLike, ArrayLike]]:
    """Initializes the free energy and entropy computation via the BAR approach.

    The algorithm [#wyczalkowski2010]_ uses the BAR method to derive the
    free energy difference between two trajectories and additionally
    derives the entropy difference via the thermodynamic relation
    :math:`TdS = dU - dF`.

    This implementation relies on the bisection method to solve the implicit
    equation

    .. math ::

        \\Delta F:\\ \\left\\langle\\frac{1}{1 + \\exp(\\beta\\Delta U - \\beta\\Delta F)}\\right\\rangle_0
        - \\left\\langle\\frac{1}{1 + \\exp(-\\beta\\Delta U + \\beta\\Delta F)}\\right\\rangle_1 = 0.

    Args:
        energy_fn_template: Function that returns a potential model when
            called with a set of energy parameters.
        kT: Reference temperature.
        energy_batch_size: Batch size of the vectorized potential energy
            computation.
        iter_bisection: Iterations of the bisection method.

    Returns:
        Returns a function that takes the old and the new trajectory state
        and returns the free energy and the entropy difference between
        both trajectories.

    References:
        .. [#wyczalkowski2010] New Estimators for Calculating Solvation
           Entropy and Enthalpy and Comparative Assessments of Their
           Accuracy and Precision. Matthew A. Wyczalkowski, Andreas Vitalis,
           and Rohit V. Pappu in The Journal of Physical Chemistry B 2010
           114 (24), 8166-8180, DOI: 10.1021/jp103050u

    """
    traj_energy_fn = custom_quantity.energy_wrapper(energy_fn_template)
    reweighting_quantities = {'energy': traj_energy_fn}

    # Helper functions to calculate the free energy and entropy difference
    # via the iterative bar approach

    def _vmap_potential_energy_differences(old_traj, new_traj):
        """Performs the reweighting procedure vectorized."""

        @jax_sgmc.util.list_vmap
        def _inner(pair):
            traj_state, params = pair
            reweighting_properties = sampling.quantity_traj(
                traj_state, reweighting_quantities, params, energy_batch_size)
            return reweighting_properties

        # The BAR method requires the energy difference between the potential
        # for both the perturbed and unperturbed trajectories. We thus have
        # to compute the potential energy of the new potential model on the
        # old trajectory and vice versa.
        return _inner(
            (old_traj, new_traj.energy_params),
            (new_traj, old_traj.energy_params))

    def _fr_free_energy(dV_p, dV_0, df, beta):
        """Returns the forward and reverse estimators for the BAR equation."""
        exponent_p = beta * (dV_p - df)
        exponent_0 = beta * (-dV_0 + df)

        gf = 1.0 / (1 + jnp.exp(exponent_p))
        gr = 1.0 / (1 + jnp.exp(exponent_0))

        return gf, gr

    def _bar_residual(df, dV_p, dV_0, beta):
        """Signed residual of the implicit BAR equation.

        The sign is required by the bisection to decide which half of the
        search interval contains the root.
        """
        gf, gr = _fr_free_energy(dV_p, dV_0, df, beta)
        sum_gf = jax_md.util.high_precision_sum(gf)
        sum_gr = jax_md.util.high_precision_sum(gr)
        return sum_gf - sum_gr

    def _entropy_equation(df, V_0, V_p, rV_p, rV_0, beta):
        """Computes the entropy difference from both trajectories."""
        dV_p = V_p - rV_0
        dV_0 = rV_p - V_0

        # Forward and reverse estimators once for the perturbed ensemble
        # average and once for the reference ensemble average
        gf_p, gr_p = _fr_free_energy(dV_p, dV_p, df, beta)
        gf_0, gr_0 = _fr_free_energy(dV_0, dV_0, df, beta)

        alpha_0 = jnp.mean(gf_0 * V_0) - jnp.mean(gf_0) * jnp.mean(V_0)
        alpha_0 += jnp.mean(gf_0 * gr_0 * dV_0)
        alpha_p = jnp.mean(gr_p * V_p) - jnp.mean(gr_p) * jnp.mean(V_p)
        alpha_p -= jnp.mean(gf_p * gr_p * dV_p)

        du = alpha_0 - alpha_p
        # The small constant guards against a division by zero
        du /= jnp.mean(gf_0 * gr_0) + jnp.mean(gf_p * gr_p) + 1.e-30

        return constants.kb * beta * (du - df)

    def _init_bisection(dV_p, dV_0, beta):
        """Returns the initial search interval and the bisection step."""

        # Helper function to find valid initial points by extending the
        # search interval if necessary
        def _expand_interval(state, _):
            # Expand the interval if both initial points have a residual
            # with equal sign
            df_p, df_0 = state
            loss_p = _bar_residual(df_p, dV_p, dV_0, beta)
            loss_0 = _bar_residual(df_0, dV_p, dV_0, beta)

            # Extend the search interval by a factor of four if solution is
            # not contained in the interval. Ensure that the interval is
            # increased even if the proposals are equal up to a constant
            extend = 1.5 * jnp.abs(df_p - df_0)
            extend += 0.5e-4 * jnp.abs(df_p + df_0) + 1e-8
            # jnp.sign would be zero for identical proposals and cancel the
            # expansion, therefore select the direction explicitly.
            extend *= jnp.where(df_p >= df_0, 1.0, -1.0)

            extend_interval = jnp.where(
                jnp.sign(loss_p) == jnp.sign(loss_0), extend, 0.0)
            df_p += extend_interval
            df_0 -= extend_interval

            return (df_p, df_0), None

        # Initialize the guesses from the one-sided free energy
        # perturbation estimates, evaluated with a shifted exponent for
        # numerical stability
        exponent_p = -beta * dV_p
        exponent_0 = -beta * dV_0
        exp_p = jnp.exp(exponent_p - jnp.max(exponent_p))
        exp_0 = jnp.exp(exponent_0 - jnp.max(exponent_0))
        df_p = jnp.log(jax_md.util.high_precision_sum(exp_p))
        df_0 = jnp.log(jax_md.util.high_precision_sum(exp_0))
        df_p += jnp.max(exponent_p) - jnp.log(exponent_p.size)
        df_0 += jnp.max(exponent_0) - jnp.log(exponent_0.size)
        df_p *= -1 / beta
        df_0 *= -1 / beta

        (df_p, df_0), _ = lax.scan(
            _expand_interval, (df_p, df_0), onp.arange(10))

        def _bisection_step(state, _):
            df_a, df_b = state
            df_c = 0.5 * (df_a + df_b)

            loss_a = _bar_residual(df_a, dV_p, dV_0, beta)
            loss_c = _bar_residual(df_c, dV_p, dV_0, beta)

            # Keep the point A or B that is on the other side of the zero
            # line from point C,
            # i.e. check if the loss for A and C have the same sign.
            # If this is the case, search in [C, B], otherwise in [A, C].
            same_side = jnp.sign(loss_a) == jnp.sign(loss_c)
            new_a = jnp.where(same_side, df_c, df_a)
            new_b = jnp.where(same_side, df_b, df_c)

            # Check if c is already the solution (up to the available
            # precision)
            new_a = jnp.where(loss_c == 0.0, df_c, new_a)
            new_b = jnp.where(loss_c == 0.0, df_c, new_b)

            return (new_a, new_b), loss_c

        return (df_p, df_0), _bisection_step

    def bennett_free_energy(old_traj: sampling.TrajectoryState,
                            new_traj: sampling.TrajectoryState,
                            **kwargs):
        """Compute the free energy and entropy difference.

        The BAR method is used to derive the free energy difference between
        two trajectories. Additionally, the entropy difference is derived
        via the thermodynamic relation $TdS = dU - dF$.

        This implementation relies on the bisection method to solve the
        implicit equation.

        Args:
            old_traj: trajectory generated with the unperturbed potential
            new_traj: trajectory generated with the perturbed potential
            **kwargs: Statepoint. If it contains ``'kT'``, this temperature
                is used instead of the reference temperature.

        Returns:
            Returns a tuple ``(dfe, ds)`` with the free energy and the
            entropy difference between the old and the new trajectory.

        """
        _kT = kwargs.get('kT', kT)
        beta = 1. / _kT
        assert (jnp.isscalar(beta) or beta.shape == ()), (
            "Reweighting requires a constant temperature."
        )

        # Calculate the differences in potential energy for both trajectories
        rew_0, rew_p = _vmap_potential_energy_differences(old_traj, new_traj)

        # Get the potential energy from both trajectory
        V_p = new_traj.aux['energy']
        V_0 = old_traj.aux['energy']
        rV_p = rew_0['energy']
        rV_0 = rew_p['energy']

        # dV_p is the energy difference between the perturbed and unperturbed
        # potential based on the perturbed ensemble, dV_0 is the same
        # difference but on the unperturbed ensemble.
        dV_p = V_p - rV_0
        dV_0 = rV_p - V_0

        init_f, update_f = _init_bisection(dV_p, dV_0, beta)
        (dfe, _), _ = lax.scan(update_f, init_f, onp.arange(iter_bisection))

        ds = _entropy_equation(dfe, V_0, V_p, rV_p, rV_0, beta)

        return dfe, ds

    return bennett_free_energy