learn.difftre

Functions built around the Differentiable Trajectory Reweighting (DiffTRe) algorithm [1]. The DiffTRe algorithm builds on umbrella sampling to efficiently compute gradients of ensemble observables.

Chemtrain implements umbrella sampling approaches in the module chemtrain.trajectory.reweighting.
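
The reweighting idea can be illustrated with a minimal, library-independent sketch. It assumes the standard perturbation weights \(w^{(n)} \propto \exp[-(U_\theta(\mathbf r^{(n)}) - U_\text{ref}(\mathbf r^{(n)})) / k_B T]\) and uses the Kish effective sample size as a diagnostic. The function names and the energy values are hypothetical and are not part of the chemtrain API.

```python
import math

def reweighting_weights(u_new, u_ref, kbt):
    """Normalized weights proportional to exp(-(U_new - U_ref) / kT)."""
    exponents = [-(un - ur) / kbt for un, ur in zip(u_new, u_ref)]
    # Subtract the largest exponent for numerical stability.
    shift = max(exponents)
    unnormalized = [math.exp(e - shift) for e in exponents]
    total = sum(unnormalized)
    return [w / total for w in unnormalized]

def effective_sample_size(weights):
    """Kish effective sample size for normalized weights (an assumed diagnostic)."""
    return 1.0 / sum(w * w for w in weights)

# Hypothetical potential energies of four snapshots (arbitrary units).
u_ref = [1.0, 2.0, 1.5, 3.0]
u_perturbed = [1.2, 1.9, 1.5, 3.5]
kbt = 2.5

w_same = reweighting_weights(u_ref, u_ref, kbt)        # unperturbed potential
w_pert = reweighting_weights(u_perturbed, u_ref, kbt)  # perturbed potential
```

With an unperturbed potential all weights are uniform. A perturbation skews the weights and lowers the effective sample size, which indicates how far the stored trajectory can be reused before new sampling is needed.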

References

Loss Functions

init_default_loss_fn(observables, loss_fns)

Initializes the MSE loss function for DiffTRe.

The default loss for the Differentiable Trajectory Reweighting (DiffTRe) algorithm [1] is the mean squared error (MSE) between an observable \(\mathcal A\) and an (experimental) reference \(\hat a\)

\[\mathcal L(\theta) = \gamma \left( \hat a - \mathcal A(\mathbf w_N, \mathbf r_N) \right)^2.\]

Some observables are simple ensemble averages of instantaneous quantities \(a\)

\[\mathcal A(\mathbf w_N, \mathbf r_N) = \sum_{n=1}^N w^{(n)} a\left(\mathbf r^{(n)}\right).\]
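
As a minimal sketch of the two formulas above, the weighted average and the MSE loss can be written in a few lines. The helper names `ensemble_average` and `mse_loss` and all numbers are hypothetical. The actual loss returned by init_default_loss_fn operates on dictionaries of snapshots and targets.

```python
def ensemble_average(weights, values):
    """Weighted ensemble average: sum_n w_n * a(r_n)."""
    return sum(w * a for w, a in zip(weights, values))

def mse_loss(prediction, target, gamma=1.0):
    """Scaled squared error: gamma * (target - prediction)^2."""
    return gamma * (target - prediction) ** 2

# Hypothetical normalized weights and instantaneous values a(r_n).
weights = [0.1, 0.2, 0.3, 0.4]
a_values = [1.0, 2.0, 3.0, 4.0]

prediction = ensemble_average(weights, a_values)    # approximately 3.0
loss = mse_loss(prediction, target=2.5, gamma=2.0)  # approximately 0.5
```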

However, some quantities, e.g., the heat capacity \(c_V\), relate to multiple ensemble averages or even multiple quantities. Therefore, each target specifies a traj_fn with access to all instantaneous quantities.
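
To illustrate why a trajectory-level function is needed, the sketch below computes a dimensionless heat capacity from energy fluctuations, \(C_V / k_B = (\langle E^2 \rangle - \langle E \rangle^2) / (k_B T)^2\), which combines two weighted averages. The signature of `heat_capacity_traj_fn` and the key `"energy"` are assumptions made for this example and do not reproduce the chemtrain TrajFn interface.

```python
def weighted_mean(weights, values):
    """Weighted average of per-snapshot values."""
    return sum(w * v for w, v in zip(weights, values))

def heat_capacity_traj_fn(quantities, weights, kbt):
    """Hypothetical traj_fn: C_V / k_B from canonical energy fluctuations."""
    energies = quantities["energy"]
    mean_e = weighted_mean(weights, energies)
    mean_e2 = weighted_mean(weights, [e * e for e in energies])
    # Variance of the energy, divided by (k_B T)^2.
    return (mean_e2 - mean_e ** 2) / kbt ** 2

# Hypothetical instantaneous energies with uniform weights.
quantities = {"energy": [1.0, 3.0, 1.0, 3.0]}
weights = [0.25, 0.25, 0.25, 0.25]

c_v = heat_capacity_traj_fn(quantities, weights, kbt=2.0)
```

The result cannot be written as a single sum over one instantaneous quantity, because it is a nonlinear combination of the averages of \(E\) and \(E^2\).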

For a list of implemented ensemble observables, refer to the module chemtrain.quantity.observables.

Parameters:
  • observables (Dict[str, TrajFn]) – Dictionary containing functions to compute ensemble observables from snapshots.

  • loss_fns (Dict[str, Callable]) – Dictionary containing loss functions for the individual targets.

Returns:

Returns a DiffTRe loss_fn. The loss function accepts a dictionary of snapshots for each sample, the weights for each sample, a dict defining properties of the statepoint, and the training targets.

init_rel_entropy_loss_fn(energy_fn_template, compute_weights, kbt, vmap_batch_size=10)

Initializes a function to compute the relative entropy loss.

The relative entropy between a fine-grained (FG) reference system with potential \(U^\mathrm{FG}(\mathbf r)\) and a coarse-grained (CG) system with potential \(U_\theta^\mathrm{CG}(\mathbf R)\) is [2]

\[S_\text{rel} = \beta\langle U_\theta^\mathrm{CG}(\mathcal M(\mathbf r)) - U^\mathrm{FG}(\mathbf r) \rangle_\text{FG} -\beta(A^\mathrm{CG}_\theta - A^\mathrm{FG}) + S_\text{map}.\]

This relative entropy depends on the free energies of the models and a mapping entropy.

However, using free-energy perturbation approaches, one can construct a replacement loss function that has the same gradients

\[\mathcal L(\theta) = \beta \langle U_\theta^\mathrm{CG}(\mathcal M(\mathbf r))\rangle_\text{FG} -\beta A^\mathrm{CG}_\theta.\]
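
The terms \(U^\mathrm{FG}\), \(A^\mathrm{FG}\), and \(S_\text{map}\) do not depend on \(\theta\), so they drop out of the gradient. As a short sketch, assuming the canonical-ensemble definition \(A^\mathrm{CG}_\theta = -\beta^{-1} \ln \int e^{-\beta U_\theta^\mathrm{CG}(\mathbf R)}\, \mathrm d\mathbf R\) (not stated above), the free-energy derivative is an ensemble average over the CG model

\[\nabla_\theta A^\mathrm{CG}_\theta = \left\langle \nabla_\theta U_\theta^\mathrm{CG}(\mathbf R) \right\rangle_\text{CG},\]

so that

\[\nabla_\theta \mathcal L(\theta) = \beta \left\langle \nabla_\theta U_\theta^\mathrm{CG}(\mathcal M(\mathbf r)) \right\rangle_\text{FG} - \beta \left\langle \nabla_\theta U_\theta^\mathrm{CG}(\mathbf R) \right\rangle_\text{CG}.\]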

Parameters:
  • energy_fn_template – Energy function template

  • compute_weights – compute_weights function as initialized from init_pot_reweight_propagation_fns.

  • kbt – Temperature of the statepoint, given as the thermal energy \(k_B T\).

  • vmap_batch_size – Batch size for computing the potential energies on the reference positions.

Returns:

A function that returns the relative entropy loss, i.e., the contributions to the relative entropy that depend on the parameters of the model.

Gradient Computation Routines

init_difftre_gradient_and_propagation(reweight_fns, loss_fn, quantities, energy_fn_template, wrapped=True, batched=False)

Initializes the function to compute the DiffTRe loss and its gradients.

The Differentiable Trajectory Reweighting (DiffTRe) algorithm [1] computes gradients of ensemble averages via a perturbation approach, initialized in chemtrain.trajectory.reweighting.init_pot_reweight_propagation_fns().
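
The perturbation approach can be checked on a toy problem. The sketch below assumes a one-dimensional harmonic potential \(U_\theta(x) = \theta x^2 / 2\), a fixed set of hypothetical snapshots, and the observable \(a(x) = x^2\). Differentiating the normalized weights gives \(\partial_\theta \langle a \rangle = -\beta \left( \langle a\, \partial_\theta U \rangle - \langle a \rangle \langle \partial_\theta U \rangle \right)\), which is compared against a finite difference of the reweighted average. None of the names belong to the chemtrain API, which obtains this gradient through automatic differentiation.

```python
import math

def weights(theta, theta_ref, xs, beta):
    """Normalized reweighting factors for U_theta(x) = 0.5 * theta * x^2."""
    exps = [-beta * 0.5 * (theta - theta_ref) * x * x for x in xs]
    shift = max(exps)  # numerical stabilization
    unnorm = [math.exp(e - shift) for e in exps]
    z = sum(unnorm)
    return [u / z for u in unnorm]

def reweighted_observable(theta, theta_ref, xs, beta):
    """Reweighted ensemble average of a(x) = x^2."""
    w = weights(theta, theta_ref, xs, beta)
    return sum(wi * x * x for wi, x in zip(w, xs))

def analytic_gradient(theta, theta_ref, xs, beta):
    """Covariance formula: -beta * (<a dU> - <a><dU>) under the weights."""
    w = weights(theta, theta_ref, xs, beta)
    a = [x * x for x in xs]
    du = [0.5 * x * x for x in xs]  # dU/dtheta for each snapshot
    mean_a = sum(wi * ai for wi, ai in zip(w, a))
    mean_du = sum(wi * di for wi, di in zip(w, du))
    mean_a_du = sum(wi * ai * di for wi, ai, di in zip(w, a, du))
    return -beta * (mean_a_du - mean_a * mean_du)

# Hypothetical snapshots, assumed to be sampled at theta_ref.
xs = [-1.5, -0.7, -0.2, 0.1, 0.4, 0.9, 1.3, 2.0]
beta, theta_ref, theta = 1.0, 1.0, 1.1

eps = 1e-6
finite_difference = (
    reweighted_observable(theta + eps, theta_ref, xs, beta)
    - reweighted_observable(theta - eps, theta_ref, xs, beta)
) / (2 * eps)
grad = analytic_gradient(theta, theta_ref, xs, beta)
```

Both estimates agree, and the gradient is negative: a stiffer potential lowers the reweighted average of \(x^2\). No new simulation is required to evaluate either expression.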

Parameters:
  • reweight_fns (Tuple[Callable, Callable, Callable]) – Functions to perform and evaluate umbrella-sampling simulations.

  • loss_fn – DiffTRe compatible loss function, e.g., initialized via init_default_loss_fn().

  • quantities (Dict[str, ComputeFn]) – Dictionary specifying how to compute instantaneous quantities from the simulator states.

  • energy_fn_template (EnergyFnTemplate) – Template to initialize the energy function that is required to compute the weights.

  • wrapped (bool) – Return separate weight, propagation, and gradient functions.

  • batched (bool) – Computes loss for multiple ensembles.

Returns:

Returns a function to propagate the current trajectory state, compute the loss and gradient, and predict observations.

init_rel_entropy_gradient_and_propagation(reference_dataloader, reweight_fns, energy_fn_template, kbt, vmap_batch_size=10)

Initializes a function to compute the relative entropy gradients.

This implementation of the Relative Entropy Minimization algorithm [2] computes the gradients of the free energy similar to the Differentiable Trajectory Reweighting (DiffTRe) algorithm [1] via a perturbation approach, initialized in chemtrain.trajectory.reweighting.init_pot_reweight_propagation_fns().

The computation of the gradient is batched to increase computational efficiency.
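
A toy example may help to see what the gradient looks like. The sketch below assumes a one-dimensional harmonic CG potential \(U_\theta(R) = \theta R^2 / 2\) and hypothetical mapped reference positions. The relative entropy gradient is the difference between the FG and CG averages of \(\partial_\theta U_\theta\), scaled by \(\beta\). To keep the example deterministic, the CG average uses the analytic result \(\langle R^2 \rangle_\text{CG} = 1 / (\beta \theta)\) in place of the reweighted simulation average that the library estimates.

```python
def rel_entropy_gradient(theta, fg_positions, beta):
    """beta * (<dU/dtheta>_FG - <dU/dtheta>_CG) for U = 0.5 * theta * R^2."""
    # FG average over the mapped reference positions.
    fg_mean = sum(0.5 * r * r for r in fg_positions) / len(fg_positions)
    # Analytic CG average for the harmonic toy model (assumption, see above).
    cg_mean = 0.5 / (beta * theta)
    return beta * (fg_mean - cg_mean)

# Hypothetical mapped reference positions with <R^2>_FG = 0.5.
fg_positions = [-1.0, -0.5, 0.0, 0.5, 1.0]
beta = 1.0

theta = 1.0          # initial guess for the CG parameter
learning_rate = 1.0
for _ in range(500):
    theta -= learning_rate * rel_entropy_gradient(theta, fg_positions, beta)
```

Plain gradient descent converges to \(\theta = 1 / (\beta \langle R^2 \rangle_\text{FG}) = 2\), the value at which the CG model reproduces the second moment of the reference data.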

Parameters:
  • reference_dataloader – Dataloader containing the mapped atomistic reference positions.

  • reweight_fns – Functions to perform and evaluate umbrella-sampling simulations to estimate the free energy gradients. Initialized via chemtrain.trajectory.reweighting.init_pot_reweight_propagation_fns().

  • energy_fn_template – Template to initialize the energy function that is required to compute the weights.

  • kbt – Temperature of the statepoint, given as the thermal energy \(k_B T\).

  • vmap_batch_size – Batch size for computing the potential energies on the reference positions.

Returns:

Returns the gradient and propagation function for the relative entropy minimization algorithm.