learn.probabilistic
This module contains methods for training and evaluating uncertainty-aware neural network potentials trained bottom-up via energy/force matching.
Bayesian UQ
Defines the Bayesian UQ problem: builds prior and likelihood functions and combines them into a log-posterior, which is used to initialize an SGMCMC force-matching trainer.
- Uniform improper prior function.
- Initializes a prior distribution that acts on all parameters independently.
- Returns the likelihood function for Bayesian potential optimization based on a force-matching formulation.
- Initializes the log-posterior function.
- Initializes a compatible set of prior, likelihood, initial MCMC samples, and training and validation loaders for learning probabilistic potentials via force matching.
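A minimal sketch of how these pieces fit together, written in plain Python for readability: the log-posterior is the log-prior plus the sum of per-sample force-matching log-likelihoods, here assumed to be Gaussian with a fixed noise scale. All function names, the toy harmonic force model, and the Gaussian noise model are illustrative assumptions, not this module's API.

```python
def uniform_log_prior(params):
    """Improper uniform prior: the log-density is constant (zero) everywhere."""
    return 0.0


def gaussian_log_prior(params, std=1.0):
    """Independent N(0, std^2) prior on every parameter (log-density up to a constant)."""
    return sum(-0.5 * (p / std) ** 2 for p in params)


def force_matching_log_likelihood(predicted, reference, noise_std):
    """Assumed Gaussian log-likelihood of reference forces given predictions (up to a constant)."""
    return sum(-0.5 * ((f_pred - f_ref) / noise_std) ** 2
               for f_pred, f_ref in zip(predicted, reference))


def log_posterior(params, predict_forces, dataset, log_prior, noise_std):
    """log p(theta | D) = log p(theta) + sum_i log p(F_i | R_i, theta) + const."""
    total = log_prior(params)
    for positions, ref_forces in dataset:
        predicted = predict_forces(params, positions)
        total += force_matching_log_likelihood(predicted, ref_forces, noise_std)
    return total


# Hypothetical toy 1D harmonic model: F = -k * x, with the single parameter params = [k].
def harmonic_forces(params, positions):
    return [-params[0] * x for x in positions]


dataset = [([1.0, 2.0], [-2.0, -4.0])]  # synthetic reference data generated with k = 2
print(log_posterior([2.0], harmonic_forces, dataset, gaussian_log_prior, 0.1))  # -> -2.0
```

An SGMCMC sampler would then draw parameter samples from such a log-posterior, typically estimating the likelihood sum from mini-batches of the training data.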
Dropout Monte Carlo
Utility functions to perform forward UQ of a model trained via Dropout.
- Initializes a function that returns a distribution of predictions over different dropout configurations, e.g., for uncertainty quantification.
- Returns forward UQ predictions of a trained model on a validation dataset.
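The idea behind Dropout Monte Carlo can be sketched with a toy linear model: dropout stays active at prediction time, and repeating the forward pass with n_dropout random masks yields a sample of predictions whose spread serves as the uncertainty estimate. The model and the helper names below are hypothetical and only illustrate the technique.

```python
import random
import statistics


def dropout_forward(weights, inputs, rng, drop_rate):
    """One stochastic pass of a toy linear model with (inverted) dropout on its inputs."""
    keep_prob = 1.0 - drop_rate
    # Each input is kept with probability keep_prob and rescaled by 1 / keep_prob.
    return sum((w / keep_prob) * x
               for w, x in zip(weights, inputs) if rng.random() < keep_prob)


def dropout_uq(weights, inputs, n_dropout=16, drop_rate=0.5, seed=0):
    """Mean and standard deviation of predictions over n_dropout random dropout masks."""
    rng = random.Random(seed)
    samples = [dropout_forward(weights, inputs, rng, drop_rate) for _ in range(n_dropout)]
    return statistics.mean(samples), statistics.stdev(samples), samples


mean, std, samples = dropout_uq([1.0, 2.0], [3.0, 4.0])
print(f"prediction = {mean:.2f} +/- {std:.2f}")
```

With drop_rate=0.0 every pass is identical and the spread collapses to zero, which is a quick sanity check that the uncertainty comes only from the sampled masks.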
Propagate Uncertainty
Propagates the uncertainty of probabilistic potentials to obtain uncertainty estimates for MD observables.
- uq_trajectories(param_sets, init_state, trajectory_generator, vmap_simulations=1, kt_schedule=None, n_dropout=16)
Computes multiple trajectories in parallel to evaluate the effect of parameter uncertainty.
- Parameters:
param_sets – energy_params stacked along axis 0. For Dropout models, a single parameter set.
init_state – Either a single sim_state (compatible with the trajectory_generator) or sim_states stacked along axis 0 to start each energy_param set from a different sim_state.
trajectory_generator – Trajectory generator as initialized from trajectory_generator_init.
vmap_simulations – Number of simulations to run in parallel via vectorization (vmap).
kt_schedule – k_B T schedule for the simulations. If None, the temperature encoded in simulator_template is used.
n_dropout – Number of Dropout samples to evaluate.
- Returns:
A trajectory state that contains all generated trajectories stacked along axis 0.
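As a rough illustration of this propagation step, the sketch below runs one trajectory per parameter set, processing the sets in batches (a plain loop stands in for a vmap-ed call over vmap_simulations sets), and then summarizes an observable across the resulting ensemble. The toy relaxation dynamics, the function names, and the choice of observable are assumptions for illustration only; a real trajectory_generator returns simulator states rather than lists.

```python
import statistics


def run_uq_trajectories(param_sets, init_state, trajectory_generator, batch_size=1):
    """Run one trajectory per parameter set; sets are processed in batches of batch_size.

    In JAX, each batch would be a single jax.vmap-ed call; a plain loop stands in here.
    Returns the trajectories stacked along the first axis (here: a list).
    """
    trajectories = []
    for start in range(0, len(param_sets), batch_size):
        batch = param_sets[start:start + batch_size]
        trajectories.extend(trajectory_generator(params, init_state) for params in batch)
    return trajectories


def toy_relaxation(params, x0, n_steps=10, dt=0.1):
    """Hypothetical dynamics: x relaxes towards 0 with rate constant params[0]."""
    traj = [x0]
    for _ in range(n_steps - 1):
        traj.append(traj[-1] - dt * params[0] * traj[-1])
    return traj


def observable_uq(trajectories, observable):
    """Ensemble mean and standard deviation of a scalar observable across trajectories."""
    values = [observable(traj) for traj in trajectories]
    return statistics.mean(values), statistics.stdev(values)


param_sets = [[1.0], [2.0], [3.0]]  # e.g. posterior samples of the rate constant
trajs = run_uq_trajectories(param_sets, 1.0, toy_relaxation, batch_size=2)
mean, std = observable_uq(trajs, statistics.mean)  # observable: time-averaged x
```

The spread of the observable across parameter sets is the propagated uncertainty; the batch size only trades memory for throughput and does not change the result.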
UQ Postprocessing
Utility functions to assess statistics of Markov chains and compute error metrics.
- Evaluates the mean absolute error of a list of parameter sets generated via sampling-based methods, based on validation data and a force-matching likelihood.
- Evaluates the root mean squared error of a set of parameters generated via sampling-based methods, based on test data and a force-matching objective.
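A sketch of typical postprocessing, under two stated assumptions: the error metrics are computed from the ensemble-mean force prediction (an implementation could instead average per-sample errors), and the classic Gelman-Rubin statistic R-hat is used as one example of a Markov chain diagnostic. All names and the toy harmonic model are hypothetical.

```python
import math
import statistics


def ensemble_force_errors(param_sets, predict_forces, dataset):
    """MAE and RMSE of the ensemble-mean force prediction against reference forces."""
    abs_errors, sq_errors = [], []
    for positions, ref_forces in dataset:
        per_sample = [predict_forces(params, positions) for params in param_sets]
        # Average each force component over the sampled parameter sets.
        mean_pred = [statistics.mean(component) for component in zip(*per_sample)]
        for f_pred, f_ref in zip(mean_pred, ref_forces):
            abs_errors.append(abs(f_pred - f_ref))
            sq_errors.append((f_pred - f_ref) ** 2)
    return statistics.mean(abs_errors), math.sqrt(statistics.mean(sq_errors))


def gelman_rubin(chains):
    """Potential scale reduction factor R-hat for equal-length chains of a scalar."""
    n = len(chains[0])
    within = statistics.mean(statistics.variance(chain) for chain in chains)
    between = n * statistics.variance([statistics.mean(chain) for chain in chains])
    pooled = (n - 1) / n * within + between / n
    return math.sqrt(pooled / within)


def harmonic_forces(params, positions):
    """Hypothetical toy model: F = -k * x with params = [k]."""
    return [-params[0] * x for x in positions]


dataset = [([1.0, 2.0], [-3.0, -4.0])]
mae, rmse = ensemble_force_errors([[1.0], [3.0]], harmonic_forces, dataset)
```

R-hat values close to 1 suggest the chains sample a common distribution; values well above 1 indicate the chains have not yet mixed.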