ensemble.reweighting

This module provides implementations of thermodynamic perturbation theory.

Thermodynamic perturbation theory enables transferring information between perturbed ensembles, e.g., estimating free energy differences or ensemble averages of a perturbed ensemble from samples of a reference ensemble.

A description and an example of free energy perturbation approaches are given in the Relative Entropy Minimization example.

Likewise, an example of using the reweighting approach for ensemble averages is given in the Differentiable Trajectory Reweighting (DiffTRe) example.

Routines

Reweighting

init_pot_reweight_propagation_fns(energy_fn_template, simulator_template, neighbor_fn, timings, state_kwargs, reweight_ratio=0.9, npt_ensemble=False, energy_batch_size=1, entropy_approximation=False, max_iter_bar=25, safe_propagation=True, resample_simstates=False)

Initializes all functions necessary for trajectory reweighting for a single state point.

The initialized functions include a function that computes weights for a given trajectory and a function that propagates the trajectory, i.e., samples a new trajectory, if the statistical error is too large to re-use the existing one.

The behavior of the third function and the presence of a fourth function depend on the value of safe_propagation. If set to True, only three functions are returned, and the propagation function additionally checks the neighbor list for overflow and the trajectory for NaNs. In this case, however, the propagation function is not jit-able. If safe_propagation=False, a fourth function is returned, which can be used as a decorator to extend a non-jit-able outer function.
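To illustrate how the weights and the effective sample size (ESS) decide whether a trajectory can be re-used, here is a minimal NumPy sketch for the NVT case. It assumes the standard Boltzmann reweighting formula and Kish's ESS estimate, and compares the fractional ESS against a threshold like reweight_ratio (default 0.9). The helper reweight_nvt is hypothetical and not part of this module.

```python
import numpy as np


def reweight_nvt(u_ref, u_new, kT):
    """Hypothetical helper: Boltzmann weights and fractional ESS (NVT).

    u_ref: energies of the samples under the reference potential.
    u_new: energies of the same samples under the perturbed potential.
    """
    u_ref = np.asarray(u_ref, dtype=float)
    u_new = np.asarray(u_new, dtype=float)
    # Log of the unnormalized weights exp(-beta * (U_new - U_ref))
    log_w = -(u_new - u_ref) / kT
    # Subtract the maximum before exponentiating for numerical stability
    log_w -= np.max(log_w)
    w = np.exp(log_w)
    w /= np.sum(w)
    # Kish effective sample size (sum w)^2 / sum w^2, with sum w = 1
    n_eff = 1.0 / np.sum(w ** 2)
    return w, n_eff / w.size


u_ref = np.array([1.0, 2.0, 3.0, 4.0])
weights, ess_fraction = reweight_nvt(u_ref, u_ref + 0.1, kT=1.0)
# A constant energy shift leaves the weights uniform, so the trajectory
# can be re-used because the fractional ESS exceeds the threshold.
reuse = ess_fraction >= 0.9
```

A perturbation that changes the relative energies of the samples lowers the fractional ESS; once it falls below the threshold, a new trajectory has to be sampled.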

Parameters:
  • energy_fn_template (EnergyFnTemplate) – Energy function template to initialize a new energy function with the current parameters.

  • simulator_template (Callable) – Template to create new simulators with different energy functions.

  • neighbor_fn (Callable[[Array, Optional[NeighborList, None], Optional[int, None]], NeighborList]) – Function to re-compute a neighbor list on reference positions.

  • timings (TimingClass) – Timings of the simulation.

  • state_kwargs (Dict[str, Union[Array, ndarray, bool, number, bool, int, float, complex]]) – Dictionary defining the state point, e.g., containing the reference temperature 'kT'.

  • reweight_ratio (float) – Minimum effective sample size, as a fraction of the number of samples, required to re-use a trajectory.

  • npt_ensemble (bool) – Whether to reweight in an NPT ensemble.

  • energy_batch_size (int) – Batch size for the vectorized energy computation.

  • entropy_approximation (bool) – Whether to use an approximation of the entropy difference between reference and target potential that has a similar gradient.

  • max_iter_bar (int) – Maximum number of iterations for the BAR procedure.

  • safe_propagation (bool) – Ensure that generated trajectories did not encounter any neighbor list overflow.

  • resample_simstates (bool) – Whether to re-sample the simulator states from the trajectory to start a new simulation.

Example

Here, we increase the number of retries for overflowed neighbor lists and return additional results besides the trajectory state.

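from functools import partial

# safe_propagate is the fourth function returned for safe_propagation=False;
# compute_weights and propagate are the second and third returned functions.
# some_loss_fn is a user-defined placeholder.
@partial(safe_propagate, multiple_arguments=True, max_retry=10)
def outer_fn(params, traj_state):
    # Re-use or re-compute the trajectory for the current parameters
    traj_state = propagate(params, traj_state)
    # compute_weights returns the weights and the effective sample size
    weights, _ = compute_weights(params, traj_state)

    loss, predictions = some_loss_fn(traj_state, weights)

    return traj_state, loss, predictions
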
Return type:

Union[Tuple[Callable, ComputeWeightsFn, PropagateFn], Tuple[Callable, ComputeWeightsFn, Callable, Callable[…, PropagateFn]]]

Returns:

Returns a tuple of functions to apply the reweighting formalism. The first function generates a reference trajectory, starting from a reference simulator state. The second function computes the weights given a reference trajectory state. The third function propagates the trajectory state, re-computing a trajectory from the current energy parameters if the ESS drops below a certain threshold. The fourth function, the safe_propagate decorator, is only returned if safe_propagation=False.

class ComputeWeightsFn(*args, **kwargs)
__call__(params, traj_state, entropy_and_free_energy=False)

Computes weights for the reweighting approach.

Parameters:
  • params (Any) – Energy parameters to obtain the perturbed potential.

  • traj_state (TrajectoryState) – Trajectory of a sufficiently close potential. The auxiliary quantities must contain the energy of the reference trajectory (key: 'energy').

  • entropy_and_free_energy (bool) – Whether to additionally return the free energy and entropy differences to the reference potential.

Return type:

Union[Tuple[Union[Array, ndarray, bool, number, bool, int, float, complex], Union[Array, ndarray, bool, number, bool, int, float, complex]], Any]

Returns:

Returns the weight for each sample and the effective sample size.

If entropy_and_free_energy=True, the function additionally returns the free energy and entropy differences to the reference potential.
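For intuition, free energy and entropy differences can in principle be estimated from samples of the reference ensemble alone. The following NumPy sketch uses the exponential-averaging (Zwanzig) estimator for the free energy and the constant-temperature relation T ΔS = ΔU − ΔF for the entropy. This is an illustration under these assumptions, not necessarily the estimator used by ComputeWeightsFn; fep_free_energy_entropy is a hypothetical name.

```python
import numpy as np


def fep_free_energy_entropy(u_ref, u_new, kT):
    """Hypothetical helper: exponential-averaging (Zwanzig) estimates.

    Returns (delta_f, delta_s) of the perturbed relative to the reference
    potential from reference-ensemble samples; delta_s is in units of k_B.
    """
    u_ref = np.asarray(u_ref, dtype=float)
    u_new = np.asarray(u_new, dtype=float)
    x = -(u_new - u_ref) / kT
    x_max = np.max(x)
    # Shifted Boltzmann factors exp(-beta * dU - x_max), safe to evaluate
    boltz = np.exp(x - x_max)
    # delta_F = -kT * ln < exp(-beta * dU) >_ref in log-sum-exp form
    delta_f = -kT * (x_max + np.log(np.mean(boltz)))
    # Reweighted mean energy of the perturbed ensemble minus reference mean
    weights = boltz / np.sum(boltz)
    delta_u = np.sum(weights * u_new) - np.mean(u_ref)
    # Constant temperature: T * delta_S = delta_U - delta_F
    delta_s = (delta_u - delta_f) / kT
    return delta_f, delta_s
```

A constant energy offset changes the free energy and the internal energy by the same amount, so the estimated entropy difference vanishes in that case.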

class PropagateFn(*args, **kwargs)
__call__(params, traj_state, *args, **kwargs)

Samples from a new reference ensemble if the ESS is insufficient.

This function computes the ESS and decides whether to update the reference potential to the current potential parameters.

Additionally, the function checks whether a neighbor list overflow occurred and increases the neighbor list capacity if necessary.

Parameters:
  • params (Any) – Energy parameters for the perturbed target potential.

  • traj_state (TrajectoryState) – Trajectory from the most recent reference potential.

Return type:

Union[TrajectoryState, Tuple[TrajectoryState, …]]

Returns:

Returns a trajectory state with adequate ESS and neighbor list, which can be the previous trajectory state.

When wrapped with the safe_propagate decorator, the function can return additional results besides the propagated trajectory state.
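The re-use decision described above can be sketched independently of any simulator. In this minimal sketch, compute_ess_fraction and simulate are hypothetical user-supplied callables, reweight_ratio mirrors the parameter of init_pot_reweight_propagation_fns, and the neighbor list overflow checks are omitted.

```python
from typing import Any, Callable


def propagate_sketch(params: Any, traj_state: Any,
                     compute_ess_fraction: Callable[[Any, Any], float],
                     simulate: Callable[[Any, Any], Any],
                     reweight_ratio: float = 0.9) -> Any:
    """Hypothetical propagation step: re-use or regenerate a trajectory."""
    ess_fraction = compute_ess_fraction(params, traj_state)
    if ess_fraction >= reweight_ratio:
        # Statistical error is acceptable: keep the previous trajectory
        return traj_state
    # Otherwise sample a new reference trajectory with the current
    # parameters, using information from the previous trajectory state
    return simulate(params, traj_state)
```

A jit-compatible variant would need to replace the Python if statement with a traced conditional such as jax.lax.cond.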

init_reference_trajectory_reweight_fns(energy_fn_template, neighbor_fn, target_quantities, ref_kbt, ref_pressure=None, compute_fns=None, energy_batch_size=10, dynamic_dropout=False, reference_energy_fn_template=None, pressure_correction=False)

Initializes reweighting based on a reference trajectory.

Instead of recomputing a trajectory, this function uses a precomputed trajectory and precomputed quantities to estimate quantities for a different potential model via the reweighting procedure.

# Initialize the reference reweighting method
init_fn, compute_fn = init_reference_trajectory_reweight_fns(
    energy_fn_template, neighbor_fn, target_quantities, ref_kbt)

# Provide the reference trajectory and reference quantities
state = init_fn(reference_trajectory, reference_quantities)

# Compute the new quantities for the perturbed energy parameters
results = compute_fn(state, new_energy_parameters)

Parameters:
  • energy_fn_template (EnergyFnTemplate) – Perturbed potential model

  • neighbor_fn (Callable[[Array, Optional[NeighborList, None], Optional[int, None]], NeighborList]) – Neighbor list function

  • target_quantities (Dict[str, Any]) – Quantities to estimate via reweighting

  • ref_kbt (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Reference thermal energy kbT, i.e., the Boltzmann constant times the reference temperature

  • ref_pressure (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Reference pressure

  • compute_fns (Dict[str, Callable]) – Functions to recompute quantities used in reweighting. The energy is recomputed automatically. It is possible to add quantities that are not contained in the reference quantities but are required to recompute quantities that depend directly on the potential, e.g., the pressure.

  • energy_batch_size (int) – Number of configurations to compute in parallel

  • dynamic_dropout (bool) – Issues a new dropout key for every state of the trajectory.

  • reference_energy_fn_template (EnergyFnTemplate) – Energy function to re-compute the energies of the potential used to generate the trajectory.

  • pressure_correction (bool) – Include the pressure in the Boltzmann factor for the NPT ensemble.

Return type:

[Callable, Callable]

Returns:

Returns an init function and a compute function. The compute function returns a dictionary containing the quantities obtained via reweighting, the original quantities, the weights, and the effective sample size of the reweighting procedure.
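Reweighted ensemble averages of this kind are typically weighted sums over the reference frames. The minimal NumPy sketch below assumes normalized weights with one entry per frame; reweighted_average is a hypothetical helper and not part of this module.

```python
import numpy as np


def reweighted_average(weights, observable):
    """Hypothetical helper: <O>_target ~ sum_i w_i * O_i over reference frames.

    weights: normalized weights with shape (n_frames,).
    observable: per-frame values with shape (n_frames, ...).
    """
    weights = np.asarray(weights, dtype=float)
    observable = np.asarray(observable, dtype=float)
    # Contract the frame axis; works for scalar and array-valued observables
    return np.tensordot(weights, observable, axes=1)
```

Contracting only the frame axis keeps the sketch valid for array-valued observables such as radial distribution functions evaluated on a grid.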

reweight_trajectory(traj, **targets)

Computes weights to reweight a trajectory from one thermodynamic state point to another.

This function allows re-using an existing trajectory to compute observables at slightly perturbed thermodynamic state points. The reference trajectory can be generated at a constant state point or at different state points, e.g., via non-equilibrium MD. Both NVT and NPT trajectories are supported; however, reweighting is currently only possible into the same ensemble. For NVT, the trajectory can be reweighted to a different temperature. For NPT, it can be reweighted to a different kT and/or pressure. Quantities not included in ‘targets’ are assumed to be constant over the trajectory; the code does not check this assumption. Cases 1 to 4 of reference [1] are implemented.

Parameters:
  • traj – Reference trajectory to be reweighted

  • targets – Kwargs containing the targets under ‘kT’ and/or ‘pressure’. If a keyword is not provided, the corresponding quantity is assumed to be constant.

Returns:

A tuple (weights, n_eff). Weights can be used to compute reweighted observables and n_eff judges the expected statistical error from reweighting.
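For the simplest situation, where all samples were drawn at a single reference state point, the weights follow from the ratio of target and reference Boltzmann factors, w_i ∝ exp[−(β_target − β_ref) U_i − (β_target P_target − β_ref P_ref) V_i], where the volume term only appears for NPT. The NumPy sketch below illustrates this case only; state_point_weights is a hypothetical helper, and it does not cover trajectories generated at varying state points.

```python
import numpy as np


def state_point_weights(u, kT_ref, kT_target, volume=None,
                        p_ref=0.0, p_target=0.0):
    """Hypothetical helper: weights to change kT (and pressure for NPT).

    Assumes all samples come from one reference state point (kT_ref, p_ref).
    Returns normalized weights and the (absolute) effective sample size.
    """
    u = np.asarray(u, dtype=float)
    beta_ref, beta_target = 1.0 / kT_ref, 1.0 / kT_target
    log_w = -(beta_target - beta_ref) * u
    if volume is not None:
        # NPT: the Boltzmann factor also contains the beta * P * V term
        v = np.asarray(volume, dtype=float)
        log_w = log_w - (beta_target * p_target - beta_ref * p_ref) * v
    # Stabilize before exponentiating, then normalize
    log_w = log_w - np.max(log_w)
    w = np.exp(log_w)
    w /= np.sum(w)
    n_eff = 1.0 / np.sum(w ** 2)
    return w, n_eff
```

Lowering the target temperature shifts weight toward low-energy frames, and raising the target pressure shifts weight toward small-volume frames, which reduces n_eff as the perturbation grows.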

References

Entropy and Free Energy Calculation

init_bar(energy_fn_template, kT, energy_batch_size=10, iter_bisection=25)

Initializes the free energy and entropy computation via the BAR approach.

The algorithm [2] uses the BAR method to derive the free energy difference between the potentials that generated two trajectories and additionally derives the entropy difference via the thermodynamic relation \(T\Delta S = \Delta U - \Delta F\), which holds at constant temperature.

This implementation relies on the bisection method to solve the implicit equation

\[\Delta F:\ \left\langle\frac{1}{1 + \exp(\beta\Delta U - \beta\Delta F)}\right\rangle_0 - \left\langle\frac{1}{1 + \exp(-\beta\Delta U + \beta\Delta F)}\right\rangle_1 = 0.\]
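The bisection on this equation can be sketched in NumPy as follows. The sketch assumes \(\Delta U = U_1 - U_0\), evaluated on samples from ensemble 0 and from ensemble 1, and equal sample sizes in both ensembles (the equation has no sample-size correction term). It brackets the root by the smallest and largest observed energy differences, where the residual is non-positive and non-negative, respectively. The names fermi and bar_bisection are hypothetical; iter_bisection mirrors the parameter of init_bar.

```python
import numpy as np


def fermi(x):
    # 1 / (1 + exp(x)), written via tanh to avoid overflow for large |x|
    return 0.5 * (1.0 - np.tanh(0.5 * x))


def bar_bisection(du_0, du_1, kT, iter_bisection=25):
    """Hypothetical helper: solve the BAR equation for delta_F by bisection.

    du_0: U_1 - U_0 evaluated on samples from ensemble 0.
    du_1: U_1 - U_0 evaluated on samples from ensemble 1.
    """
    du_0 = np.asarray(du_0, dtype=float)
    du_1 = np.asarray(du_1, dtype=float)
    beta = 1.0 / kT

    def residual(delta_f):
        # <f(beta dU - beta dF)>_0 - <f(-beta dU + beta dF)>_1
        return (np.mean(fermi(beta * (du_0 - delta_f)))
                - np.mean(fermi(-beta * (du_1 - delta_f))))

    # The residual increases monotonically in delta_F and changes sign
    # between the smallest and largest observed energy differences.
    all_du = np.concatenate([du_0, du_1])
    lo, hi = np.min(all_du), np.max(all_du)
    for _ in range(iter_bisection):
        mid = 0.5 * (lo + hi)
        if residual(mid) < 0.0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)
```

Each iteration halves the bracket, so 25 iterations shrink it by a factor of about 3·10^7.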
Parameters:
  • energy_fn_template (EnergyFnTemplate) – Function that returns a potential model when called with a set of energy parameters.

  • kT (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Reference temperature.

  • energy_batch_size (int) – Batch size of the vectorized potential energy computation.

  • iter_bisection (int) – Iterations of the bisection method.

Return type:

Callable[[TrajectoryState, TrajectoryState], Tuple[Union[Array, ndarray, bool, number, bool, int, float, complex], Union[Array, ndarray, bool, number, bool, int, float, complex]]]

Returns:

Returns the new_traj with updated free energy and entropy differences. The differences between the old and the new trajectory are added to the previously stored values, such that the stored values represent the differences relative to the first generated trajectory.

References

Utilities

checkpoint_quantities(compute_fns)

Applies gradient checkpointing to all compute_fns to reduce memory usage during the backward pass.

Parameters:

compute_fns (dict[str, ComputeFn]) – Dictionary of functions to compute instantaneous quantities from simulator states.

Return type:

None