learn.difftre
Functions built around the Differentiable Trajectory Reweighting (DiffTRe) algorithm [1]. The DiffTRe algorithm builds on umbrella sampling to efficiently compute gradients of ensemble observables.
Chemtrain implements umbrella sampling approaches in the module chemtrain.trajectory.reweighting.
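The core reweighting step can be sketched as follows. This is a minimal NumPy sketch, not chemtrain's API: snapshots \(\mathbf r^{(n)}\) sampled with a reference potential \(U_\text{ref}\) are reweighted to a perturbed potential \(U_\theta\) with \(w^{(n)} \propto \exp[-\beta(U_\theta(\mathbf r^{(n)}) - U_\text{ref}(\mathbf r^{(n)}))]\). The function names are hypothetical, and the entropy-based effective sample size \(N_\text{eff} = \exp(-\sum_n w^{(n)} \ln w^{(n)})\) is an assumed estimator that may differ from the one chemtrain uses.

```python
import numpy as np

def reweighting_weights(u_new, u_ref, kbt):
    """Normalized weights w_n proportional to exp(-(U_new - U_ref) / kbt)."""
    log_w = -(np.asarray(u_new) - np.asarray(u_ref)) / kbt
    log_w -= log_w.max()  # shift for numerical stability before exponentiating
    w = np.exp(log_w)
    return w / w.sum()

def effective_sample_size(w):
    """Assumed entropy-based estimate N_eff = exp(-sum_n w_n ln w_n)."""
    w = np.asarray(w)
    nonzero = w[w > 0]  # 0 * ln(0) is treated as 0
    return float(np.exp(-np.sum(nonzero * np.log(nonzero))))

# Toy example: 4 snapshots; the perturbed potential lowers the energy of snapshot 3.
u_ref = np.array([1.0, 2.0, 3.0, 4.0])
u_new = np.array([1.0, 2.5, 2.0, 4.0])
w = reweighting_weights(u_new, u_ref, kbt=1.0)
n_eff = effective_sample_size(w)  # lies between 1 and the number of snapshots
```

Uniform weights (an unperturbed potential) give \(N_\text{eff} = N\); the more the weights concentrate on few snapshots, the smaller \(N_\text{eff}\) becomes.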
References
Loss Functions
- init_default_loss_fn(observables, loss_fns)
Initializes the MSE loss function for DiffTRe.
The default loss for the Differentiable Trajectory Reweighting (DiffTRe) algorithm [1] is the mean squared error (MSE) between an observable \(\mathcal A\) and an (experimental) reference \(\hat a\), scaled by a weighting factor \(\gamma\):
\[\mathcal L(\theta) = \gamma \left( \hat a - \mathcal A(\mathbf w, \mathbf r) \right)^2.\]
Some observables are simple ensemble averages of instantaneous quantities \(a\) over \(N\) snapshots \(\mathbf r^{(n)}\) with weights \(w^{(n)}\):
\[\mathcal A(\mathbf w, \mathbf r) = \sum_{n=1}^N w^{(n)} a\left(\mathbf r^{(n)}\right).\]
However, some quantities, e.g., the heat capacity \(c_V\), depend on multiple ensemble averages or even on multiple instantaneous quantities. Therefore, each target specifies a traj_fn with access to all instantaneous quantities. For a list of implemented ensemble observables, refer to the module chemtrain.quantity.observables.
- Parameters:
- Returns:
Returns a DiffTRe loss_fn. The loss function accepts a dictionary of snapshots for each sample, the weights for each sample, a dictionary defining the properties of the statepoint, and the training targets.
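The two equations above, together with a traj_fn for an observable that combines several ensemble averages, can be sketched in NumPy. The helper names and the dictionary layout of the instantaneous quantities are illustrative assumptions, not chemtrain's API. The heat capacity uses the canonical fluctuation formula \(c_V / k_B = \beta^2 (\langle E^2\rangle - \langle E\rangle^2)\), which needs two ensemble averages; averaging the squared error over the components of vector-valued observables is also an assumption.

```python
import numpy as np

def ensemble_average(weights, a):
    """A(w, r) = sum_n w^(n) a(r^(n)); a may have extra trailing dimensions."""
    return np.tensordot(weights, a, axes=1)

def heat_capacity_traj_fn(quantities, weights, kbt):
    """Hypothetical traj_fn: c_V in units of k_B from energy fluctuations."""
    energy = quantities["energy"]
    mean_e = ensemble_average(weights, energy)
    mean_e2 = ensemble_average(weights, energy ** 2)
    return (mean_e2 - mean_e ** 2) / kbt ** 2

def mse_loss(prediction, reference, gamma=1.0):
    """gamma * (a_hat - A)^2, averaged over components of vector observables."""
    return gamma * np.mean((np.asarray(reference) - prediction) ** 2)

weights = np.full(4, 0.25)  # uniform weights, e.g., an unperturbed trajectory
quantities = {"energy": np.array([-1.0, 0.0, 1.0, 0.0])}
cv = heat_capacity_traj_fn(quantities, weights, kbt=1.0)
loss = mse_loss(cv, reference=0.4, gamma=2.0)
```

Because the traj_fn receives all instantaneous quantities and the weights, it can combine as many weighted averages as the observable requires, while the loss only sees the final prediction.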
- init_rel_entropy_loss_fn(energy_fn_template, compute_weights, kbt, vmap_batch_size=10)
Initializes a function to compute the relative entropy loss.
The relative entropy between a fine-grained (FG) reference system with potential \(U^\mathrm{FG}(\mathbf r)\) and a coarse-grained (CG) model with potential \(U_\theta^\mathrm{CG}(\mathbf R)\), where the mapping \(\mathbf R = \mathcal M(\mathbf r)\) relates FG to CG coordinates, is [2]
\[S_\text{rel} = \beta\langle U_\theta^\mathrm{CG}(\mathcal M(\mathbf r)) - U^\mathrm{FG}(\mathbf r) \rangle_\text{FG} -\beta(A^\mathrm{CG}_\theta - A^\mathrm{FG}) + S_\text{map}.\]
This relative entropy depends on the free energies \(A^\mathrm{FG}\) and \(A^\mathrm{CG}_\theta\) of the models and on a mapping entropy \(S_\text{map}\).
However, \(U^\mathrm{FG}\), \(A^\mathrm{FG}\), and \(S_\text{map}\) do not depend on \(\theta\). Dropping them yields a surrogate loss function with the same gradients, whose free-energy term can be differentiated using free-energy perturbation approaches:
\[\mathcal L(\theta) = \beta \langle U_\theta^\mathrm{CG}(\mathcal M(\mathbf r))\rangle_\text{FG} -\beta A^\mathrm{CG}_\theta.\]
- Parameters:
energy_fn_template – Energy function template
compute_weights – compute_weights function as initialized from init_pot_reweight_propagation_fns.
kbt – Thermal energy \(k_B T\) of the statepoint.
vmap_batch_size – Batch size for computing the potential energies on the reference positions.
- Returns:
A function that returns the relative entropy loss, i.e., the contributions to the relative entropy that depend on the parameters of the model.
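Differentiating the surrogate loss \(\mathcal L(\theta)\) makes the perturbation step explicit. With \(A^\mathrm{CG}_\theta = -\beta^{-1}\ln Z_\theta\) and \(Z_\theta = \int e^{-\beta U_\theta^\mathrm{CG}(\mathbf R)}\,\mathrm d\mathbf R\), the free-energy gradient is itself a CG ensemble average, so the loss gradient reduces to the difference of an average over the mapped FG reference data and an average over CG samples (which can be estimated by reweighting, as in DiffTRe). This is a derivation sketch of the standard result:

\[\nabla_\theta A^\mathrm{CG}_\theta = -\frac{1}{\beta}\frac{\nabla_\theta Z_\theta}{Z_\theta} = \left\langle \nabla_\theta U_\theta^\mathrm{CG}(\mathbf R)\right\rangle_{\mathrm{CG},\theta},\]
\[\nabla_\theta \mathcal L(\theta) = \beta\left\langle \nabla_\theta U_\theta^\mathrm{CG}(\mathcal M(\mathbf r))\right\rangle_\text{FG} - \beta\left\langle \nabla_\theta U_\theta^\mathrm{CG}(\mathbf R)\right\rangle_{\mathrm{CG},\theta}.\]

The gradient vanishes when the CG model reproduces the FG averages of \(\nabla_\theta U_\theta^\mathrm{CG}\).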
Gradient Computation Routines
- init_difftre_gradient_and_propagation(reweight_fns, loss_fn, quantities, energy_fn_template, wrapped=True, batched=False)
Initializes the function to compute the DiffTRe loss and its gradients.
The Differentiable Trajectory Reweighting (DiffTRe) algorithm [1] computes gradients of ensemble averages via a perturbation approach, initialized in chemtrain.trajectory.reweighting.init_pot_reweight_propagation_fns().
- Parameters:
reweight_fns (Tuple[Callable, Callable, Callable]) – Functions to perform and evaluate umbrella-sampling simulations.
loss_fn – DiffTRe-compatible loss function, e.g., initialized via init_default_loss_fn().
quantities (Dict[str, ComputeFn]) – Dictionary specifying how to compute instantaneous quantities from the simulator states.
energy_fn_template (EnergyFnTemplate) – Template to initialize the energy function that is required to compute the weights.
wrapped (bool) – Return separate weight, propagation, and gradient functions.
batched (bool) – Computes loss for multiple ensembles.
- Returns:
Returns a function to propagate the current trajectory state, compute the loss and gradient, and predict observations.
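The gradient that DiffTRe propagates through the weights can be sketched for a single observable. This minimal NumPy sketch assumes a hypothetical toy potential that is linear in the parameters, \(U_\theta(\mathbf r) = \theta^\top \boldsymbol\phi(\mathbf r)\), and an instantaneous quantity \(a\) that does not depend on \(\theta\). Differentiating \(\mathcal A = \sum_n w^{(n)} a(\mathbf r^{(n)})\) with \(w^{(n)} \propto \exp[-\beta(U_\theta - U_{\hat\theta})]\) gives the covariance form \(\partial_\theta \mathcal A = -\beta(\langle a\,\partial_\theta U\rangle - \langle a\rangle\langle\partial_\theta U\rangle)\), with averages taken under the weights. An autodiff framework such as JAX would compute the same derivative automatically; the finite-difference check only verifies the closed form.

```python
import numpy as np

def weights_fn(theta, features, u_ref, kbt):
    """Normalized weights w_n ~ exp(-(U_theta - U_ref) / kbt), U_theta = features @ theta."""
    log_w = -(features @ theta - u_ref) / kbt
    log_w -= log_w.max()  # shift for numerical stability
    w = np.exp(log_w)
    return w / w.sum()

def loss_and_grad(theta, features, u_ref, a, target, kbt, gamma=1.0):
    """Squared-error loss of the reweighted average of a, and its gradient."""
    w = weights_fn(theta, features, u_ref, kbt)
    obs = w @ a        # reweighted ensemble average A(w, r)
    du = features      # dU_theta/dtheta for the linear toy potential
    # dA/dtheta = -beta * (<a dU> - <a><dU>), averages taken with weights w
    d_obs = -(w @ (a[:, None] * du) - obs * (w @ du)) / kbt
    loss = gamma * (target - obs) ** 2
    grad = -2.0 * gamma * (target - obs) * d_obs
    return loss, grad

rng = np.random.default_rng(0)
phi = rng.normal(size=(200, 3))         # phi(r^(n)) of 200 reference snapshots (toy data)
theta_ref = np.array([0.5, -0.2, 0.1])  # parameters of the reference simulation
u_ref = phi @ theta_ref
a = rng.normal(size=200)                # instantaneous quantity a(r^(n))
theta = theta_ref + 0.05
loss, grad = loss_and_grad(theta, phi, u_ref, a, target=0.3, kbt=1.0)

# Central finite differences as a sanity check of the closed-form gradient.
eps = 1e-6
fd = np.array([
    (loss_and_grad(theta + eps * e, phi, u_ref, a, 0.3, 1.0)[0]
     - loss_and_grad(theta - eps * e, phi, u_ref, a, 0.3, 1.0)[0]) / (2 * eps)
    for e in np.eye(3)
])
```

The reference trajectory is reused for every parameter update until the weights degenerate, which is what makes the gradient cheap compared with differentiating through the simulation.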
- init_rel_entropy_gradient_and_propagation(reference_dataloader, reweight_fns, energy_fn_template, kbt, vmap_batch_size=10)
Initializes a function to compute the relative entropy gradients.
This implementation of the Relative Entropy Minimization algorithm [2] computes the gradients of the free energy similarly to the Differentiable Trajectory Reweighting (DiffTRe) algorithm [1] via a perturbation approach, initialized in chemtrain.trajectory.reweighting.init_pot_reweight_propagation_fns(). The computation of the gradient is batched to increase computational efficiency.
- Parameters:
reference_dataloader – Dataloader containing the mapped atomistic reference positions.
reweight_fns – Functions to perform and evaluate umbrella-sampling simulations to estimate the free energy gradients. Initialized via chemtrain.trajectory.reweighting.init_pot_reweight_propagation_fns().
energy_fn_template – Template to initialize the energy function that is required to compute the weights.
kbt – Thermal energy \(k_B T\) of the statepoint.
vmap_batch_size – Batch size for computing the potential energies on the reference positions.
- Returns:
Returns the gradient and propagation function for the relative entropy minimization algorithm.
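The relative-entropy gradient \(\nabla_\theta \mathcal L = \beta\langle \nabla_\theta U_\theta^\mathrm{CG}(\mathcal M(\mathbf r))\rangle_\text{FG} - \beta\langle \nabla_\theta U_\theta^\mathrm{CG}\rangle_{\mathrm{CG},\theta}\) combines an average over the mapped reference data with a reweighted average over CG samples. The minimal NumPy sketch below assumes a hypothetical CG potential linear in the parameters, so that \(\nabla_\theta U_\theta^\mathrm{CG} = \boldsymbol\phi(\mathbf R)\); the FG average is accumulated in fixed-size chunks, mirroring the role of vmap_batch_size, and all helper names are illustrative rather than chemtrain functions.

```python
import numpy as np

def features(positions):
    """Hypothetical CG basis phi(R): sums of squared and quartic coordinates."""
    x = positions.reshape(len(positions), -1)
    return np.stack([np.sum(x ** 2, axis=1), np.sum(x ** 4, axis=1)], axis=1)

def fg_average_grad(mapped_positions, batch_size):
    """<grad_theta U_CG(M(r))>_FG over mapped reference frames, chunk by chunk."""
    total = np.zeros(2)
    for start in range(0, len(mapped_positions), batch_size):
        chunk = mapped_positions[start:start + batch_size]
        total += features(chunk).sum(axis=0)
    return total / len(mapped_positions)

def cg_average_grad(theta, theta_ref, cg_positions, kbt):
    """<grad_theta U_CG>_CG at theta, reweighted from CG samples drawn at theta_ref."""
    phi = features(cg_positions)
    log_w = -(phi @ (theta - theta_ref)) / kbt  # -(U_theta - U_ref) / kbt
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    return w @ phi

def rel_entropy_grad(theta, theta_ref, mapped_positions, cg_positions, kbt, batch_size=10):
    """beta * (<grad U>_FG - <grad U>_CG) for U_theta(R) = phi(R) @ theta."""
    beta = 1.0 / kbt
    return beta * (fg_average_grad(mapped_positions, batch_size)
                   - cg_average_grad(theta, theta_ref, cg_positions, kbt))

rng = np.random.default_rng(1)
mapped = rng.normal(size=(95, 4, 3))  # 95 mapped FG frames, 4 CG beads in 3D (toy data)
cg = rng.normal(size=(120, 4, 3))     # toy CG samples assumed drawn at theta_ref
theta_ref = np.array([0.5, 0.01])
grad = rel_entropy_grad(theta_ref + 0.1, theta_ref, mapped, cg, kbt=2.5)
```

Chunking only bounds memory use: the FG average is independent of the chunk size, while the CG average reuses one set of CG samples across several parameter updates.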
ESS Line-Search
- init_step_size_adaption(weight_fn, allowed_reduction=0.5, interior_points=10, step_size_scale=1e-07)
Initializes a line search to tune the step size in each iteration.
This method interpolates linearly between the old parameters \(\theta^{(i)}\) and the parameters \(\tilde\theta\) proposed by the optimizer to find the optimal update
\[\theta^{(i + 1)} = (1 - \alpha) \theta^{(i)} + \alpha\tilde\theta\]
that reduces the effective sample size by a predefined factor \(r\):
\[N_\text{eff}(\theta^{(i+1)}) = r\cdot N_\text{eff}(\theta^{(i)}).\]
This method uses a vectorized bisection algorithm with a fixed number of iterations. At each iteration, the algorithm computes the effective sample size for a predefined number of interior points and shrinks the search interval to the two neighboring points that bracket the root of the residual.
The number of required iterations is computed from the number of interior points \(n_i\) and the desired accuracy \(a\) via
\[N = \left\lceil -\log(a) / \log(n_i + 1)\right\rceil.\]
- Parameters:
weight_fn (Callable) – Function computing a tuple (weights, N_eff) from the parameter states.
allowed_reduction (Union[Array, ndarray, bool, number, bool, int, float, complex]) – Target reduction \(r\) of the effective sample size.
interior_points (int) – Number of interior points \(n_i\).
step_size_scale (float) – Accuracy \(a\) of the found optimal interpolation coefficient.
- Return type:
- Returns:
Returns the interpolation coefficient \(\alpha\).
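The vectorized search can be sketched in NumPy. This sketch assumes a hypothetical ess_fn(alpha) in place of weight_fn that accepts an array of interpolation coefficients and returns \(N_\text{eff}\) for each, and it returns \(\alpha = 1\) when the full optimizer step already satisfies the constraint; chemtrain's handling of such edge cases may differ. With the defaults interior_points=10 and step_size_scale=1e-07, the iteration formula gives \(N = \lceil 7/\log_{10} 11\rceil = 7\).

```python
import math
import numpy as np

def num_iterations(interior_points, accuracy):
    """N = ceil(-log(a) / log(n_i + 1)) iterations to shrink [0, 1] below width a."""
    return math.ceil(-math.log(accuracy) / math.log(interior_points + 1))

def ess_line_search(ess_fn, allowed_reduction=0.5, interior_points=10, accuracy=1e-7):
    """Return alpha in [0, 1] with N_eff(alpha) ~= r * N_eff(0)."""
    target = allowed_reduction * ess_fn(0.0)
    if ess_fn(1.0) >= target:  # the full optimizer step is acceptable
        return 1.0
    lo, hi = 0.0, 1.0          # invariant: residual(lo) >= 0 > residual(hi)
    for _ in range(num_iterations(interior_points, accuracy)):
        # Evaluate the residual at both boundaries and all interior points at once.
        alphas = np.linspace(lo, hi, interior_points + 2)
        residual = ess_fn(alphas) - target
        # First grid point with a negative residual; its left neighbor is still >= 0.
        idx = int(np.argmax(residual < 0))
        lo, hi = alphas[idx - 1], alphas[idx]
    return float(lo)           # conservative: N_eff(lo) >= r * N_eff(0)

# Toy ESS that decays exponentially along the interpolation path.
ess = lambda alpha: 1000.0 * np.exp(-3.0 * np.asarray(alpha))
alpha = ess_line_search(ess)   # exact root: ln(2) / 3
```

Each iteration divides the bracket into \(n_i + 1\) subintervals, so more interior points trade extra (parallel) ESS evaluations for fewer sequential iterations.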