quantity.property_prediction
This module provides molecular property predictors that compute properties from the features of a neural network for potential energy prediction.
Molecular Property Predictors
This wrapper function is the core of the molecular property prediction module. It transforms models for (atom-wise) potential energy prediction into predictors of atom- and molecule-level properties.
- molecular_property_predictor(model, n_per_atom=0)
Wraps models that predict per-atom quantities to predict both global and per-atom quantities.
- Parameters:
model – Initialized model predicting per-atom quantities, e.g. DimeNetPP.
n_per_atom – Number of per-atom quantities to predict. Remaining predictions are assumed to be global.
- Returns:
A tuple (global_properties, per_atom_properties) of prediction arrays with shapes (n_globals,) and (n_particles, n_per_atom), respectively.
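The split into global and per-atom properties can be sketched as follows. This is a hypothetical re-implementation, not chemtrain's source: it assumes the wrapped model returns an array of shape (n_particles, n_outputs), that the first n_per_atom columns are the per-atom properties, and that each global property is the sum of one remaining column over all atoms. `toy_model` is a stand-in for a real network such as DimeNetPP.

```python
import jax.numpy as jnp


def molecular_property_predictor_sketch(model, n_per_atom=0):
    """Hypothetical sketch: split per-atom model outputs into global and per-atom properties."""

    def predictor(params, graph, **kwargs):
        # Per-atom outputs of shape (n_particles, n_outputs).
        outputs = model(params, graph, **kwargs)
        # The first n_per_atom columns remain per-atom properties.
        per_atom_properties = outputs[:, :n_per_atom]
        # Assumption: each remaining column is summed over atoms to one global property.
        global_properties = jnp.sum(outputs[:, n_per_atom:], axis=0)
        return global_properties, per_atom_properties

    return predictor


def toy_model(params, graph):
    # Stand-in for a neural network: scales the per-atom inputs.
    return params * graph


predictor = molecular_property_predictor_sketch(toy_model, n_per_atom=1)
graph = jnp.array([[0.5, 1.0, 2.0], [-0.5, 3.0, 4.0]])  # 2 atoms, 3 outputs each
global_props, per_atom_props = predictor(1.0, graph)
```

With the default n_per_atom=0, every output column is treated as a global property under this assumption.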
Properties
chemtrain provides the following property predictors:
- Initializes a prediction of partial charges.
- Initializes a prediction of the potential energy.
Protocols
- class PropertyPredictor(*args, **kwargs)
- __call__(params, graph, **kwargs)
Predicts molecular properties from a molecular graph.
The form of the graph depends on the underlying model. For example, for property predictions with jax_md_mod.model.neural_networks.DimeNetPP, the graph is of the form jax_md_mod.model.sparse_graph.SparseDirectionalGraph.
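Since PropertyPredictor is documented as a protocol, any callable with a matching `__call__(params, graph, **kwargs)` should satisfy it structurally, without inheriting from it. The sketch below illustrates this idea with Python's `typing.Protocol`; `PropertyPredictorLike` and `ConstantPredictor` are hypothetical names and are not chemtrain's definitions.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class PropertyPredictorLike(Protocol):
    """Hypothetical protocol mirroring the documented __call__ signature."""

    def __call__(self, params: Any, graph: Any, **kwargs) -> Any:
        ...


class ConstantPredictor:
    """Toy predictor that ignores the graph and scales a fixed value by params."""

    def __init__(self, value: float):
        self.value = value

    def __call__(self, params, graph, **kwargs):
        return self.value * params


predictor = ConstantPredictor(2.0)
# Runtime checks only verify that __call__ exists, not its exact signature.
is_predictor = isinstance(predictor, PropertyPredictorLike)
```

Note that the runtime check is weak: plain functions also define `__call__`, so static type checkers are the more meaningful way to enforce the signature.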
- class SinglePropertyPredictor(*args, **kwargs)
- static __call__(self, params, graph: Any, **kwargs)
- static __call__(self, features, **kwargs)
Predicts a molecular property from a molecular graph or from pre-computed features.
The form of the graph depends on the underlying model. For example, for property predictions with jax_md_mod.model.neural_networks.DimeNetPP, the graph is of the form jax_md_mod.model.sparse_graph.SparseDirectionalGraph.
- Parameters:
graph – Molecular graph containing neighborhood information. Required if the features are computed by a model within the predictor.
features – Features derived from the molecular graph. Required if no model is provided to compute the features.
**kwargs – Additional arguments to the potential model that extracts the features from the molecular graph.
- Returns:
Returns a single per-atom or molecular property.
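The two call signatures can be sketched as a single function that either runs a feature model on the graph or consumes pre-computed features. Everything here is hypothetical: `make_single_property_predictor`, `feature_model`, `readout`, and the `features` keyword are illustrative names assumed for this sketch, not chemtrain's API.

```python
import jax.numpy as jnp


def make_single_property_predictor(readout, feature_model=None):
    """Hypothetical sketch of a predictor callable with a graph or with features."""

    def predictor(params=None, graph=None, features=None, **kwargs):
        if feature_model is not None:
            # Mode 1: extract the features from the molecular graph with the model.
            if graph is None:
                raise ValueError("A graph is required when a feature model is used.")
            features = feature_model(params, graph, **kwargs)
        elif features is None:
            # Mode 2: without a model, pre-computed features must be provided.
            raise ValueError("Features are required when no feature model is given.")
        return readout(features)

    return predictor


def feature_model(params, graph):
    # Stand-in feature extractor returning (n_atoms, n_features).
    return params * graph


def readout(features):
    # Stand-in readout: one scalar per atom from the mean over features.
    return jnp.mean(features, axis=-1)


with_model = make_single_property_predictor(readout, feature_model)
without_model = make_single_property_predictor(readout)
graph = jnp.array([[1.0, 3.0], [2.0, 4.0]])
from_graph = with_model(2.0, graph)
from_features = without_model(features=2.0 * graph)
```

Both paths yield the same per-atom property here, because the pre-computed features equal what the model would extract.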
Examples
We provide an example that transforms DimeNet++ into a partial charge predictor enforcing charge neutrality of its predictions. For a real-world application of this partial charge predictor in an active learning context, see the code of Thaler et al. (2024).
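A common way to enforce charge neutrality, sketched below and not necessarily the approach of chemtrain's example, is to shift the raw per-atom predictions uniformly so that they sum to the total charge (zero for a neutral molecule). The optional `mask` for padded atoms and the `total_charge` argument are assumptions of this sketch.

```python
import jax.numpy as jnp


def enforce_charge_neutrality(raw_charges, mask=None, total_charge=0.0):
    """Shift per-atom charges uniformly so that real atoms sum to total_charge."""
    if mask is None:
        mask = jnp.ones_like(raw_charges, dtype=bool)
    n_atoms = jnp.sum(mask)
    # Excess charge of the real (unpadded) atoms relative to the target.
    excess = jnp.sum(jnp.where(mask, raw_charges, 0.0)) - total_charge
    # Distribute the excess evenly and keep padded atoms at zero charge.
    return jnp.where(mask, raw_charges - excess / n_atoms, 0.0)


raw = jnp.array([0.3, -0.1, 0.4, 0.7])
mask = jnp.array([True, True, True, False])  # last entry is padding
charges = enforce_charge_neutrality(raw, mask)
```

The uniform shift is differentiable and preserves the differences between the raw predictions, which makes it easy to combine with gradient-based training.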
Utilities
- Initializes a molecular property predictor.
Snapshot Quantities
Based on the predicted properties, we can compute other physical snapshot quantities. For example, with the predicted partial charges, we can compute the dipole moment of a system.
The following function transforms a property predictor on a graph into a snapshot compute function:
- snapshot_quantity(property_predictor, graph_from_neighbor_list=None, features_key='features')
Transforms a single property predictor to a snapshot compute function.
- Parameters:
property_predictor (SinglePropertyPredictor) – Function to predict a property from a molecular graph or from pre-extracted features.
graph_from_neighbor_list (Callable) – Function to build a molecular graph from a neighbor list. Only necessary when the features are extracted within the property predictor. Otherwise, pre-extracted global and per-atom features are required as kwargs.
features_key – Key to the pre-computed features if no model is provided.
- Returns:
Returns a snapshot compute function for the corresponding property.
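The transformation can be sketched as a closure that receives a simulation state, builds the graph when a graph constructor is given, and otherwise reads pre-computed features from the keyword arguments. This is a hypothetical sketch: the `(state, neighbor=None, params=None, **kwargs)` signature, the `state.position` attribute, and the toy helpers are assumptions, not chemtrain's implementation.

```python
from collections import namedtuple

import jax.numpy as jnp

# Hypothetical stand-in for a simulation state with particle positions.
State = namedtuple("State", ["position"])


def snapshot_quantity_sketch(property_predictor, graph_from_neighbor_list=None,
                             features_key="features"):
    """Hypothetical sketch: turn a property predictor into a snapshot function."""

    def compute(state, neighbor=None, params=None, **kwargs):
        if graph_from_neighbor_list is not None:
            # Build the graph so that the predictor can extract features itself.
            graph = graph_from_neighbor_list(state.position, neighbor)
            return property_predictor(params, graph)
        # Otherwise, read the features pre-computed for this snapshot from kwargs.
        return property_predictor(features=kwargs[features_key])

    return compute


def toy_predictor(params=None, graph=None, features=None):
    # Stand-in predictor: sums either the graph data or the given features.
    data = graph if features is None else features
    return jnp.sum(data)


def toy_graph(position, neighbor):
    # Hypothetical graph constructor that simply returns the positions.
    return position


state = State(position=jnp.array([[0.0, 1.0], [2.0, 3.0]]))
from_graph = snapshot_quantity_sketch(toy_predictor, toy_graph)(state)
from_features = snapshot_quantity_sketch(toy_predictor)(
    state, features=jnp.array([1.0, 2.0]))
```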
To extract features only once for all derived snapshot quantities, the following function can be used in combination with chemtrain.trajectory.traj_util.quantity_map():
- init_feature_pre_computation(model, graph_from_neighbor_list=None)
Initializes a function to compute global and per-atom features.
- Parameters:
model (PropertyPredictor) – Model to compute the features from a molecular graph.
graph_from_neighbor_list (Callable) – Function to construct a molecular graph from the neighbor list.
- Returns:
Returns a function to compute the features from a molecular graph.
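The benefit of pre-computing features can be illustrated without chemtrain: the expensive feature extraction runs once per snapshot, and every derived quantity receives the result as a keyword argument. The helper names, the `features` keyword, and the call counter are hypothetical, and the actual interplay with chemtrain.trajectory.traj_util.quantity_map() may differ.

```python
import jax.numpy as jnp

calls = {"features": 0}  # Counts how often the expensive step runs.


def init_feature_pre_computation_sketch(model):
    """Hypothetical sketch: return a function that extracts features once."""

    def compute_features(params, graph):
        calls["features"] += 1
        return model(params, graph)

    return compute_features


def toy_model(params, graph):
    # Stand-in for a costly network forward pass.
    return params * graph


def total(features):
    # Derived quantity 1: sum of all features.
    return jnp.sum(features)


def maximum(features):
    # Derived quantity 2: largest feature value.
    return jnp.max(features)


compute_features = init_feature_pre_computation_sketch(toy_model)
graph = jnp.array([1.0, -2.0, 4.0])
features = compute_features(2.0, graph)  # Extracted once for this snapshot.
quantities = {"total": total, "maximum": maximum}
results = {name: fn(features=features) for name, fn in quantities.items()}
```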
chemtrain provides the following snapshot quantities:
- Computes the dipole moment from partial charges of a molecule.
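For a molecule with partial charges $q_i$ at positions $\mathbf{r}_i$, the dipole moment is $\boldsymbol{\mu} = \sum_i q_i \mathbf{r}_i$. The sketch below evaluates this sum; it assumes unwrapped (non-periodic) positions of a single molecule and an optional padding mask, and it is not chemtrain's function. For a neutral molecule the result does not depend on the choice of origin, since shifting all positions by $\mathbf{a}$ adds $\mathbf{a}\sum_i q_i = 0$.

```python
import jax.numpy as jnp


def dipole_moment(positions, charges, mask=None):
    """Dipole moment mu = sum_i q_i r_i of a single, non-periodic molecule."""
    if mask is not None:
        # Exclude padded atoms by zeroing their charges.
        charges = jnp.where(mask, charges, 0.0)
    # positions: (n_atoms, 3), charges: (n_atoms,) -> dipole: (3,)
    return jnp.sum(charges[:, None] * positions, axis=0)


positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
charges = jnp.array([-0.5, 0.5])
mu = dipole_moment(positions, charges)
mu_shifted = dipole_moment(positions + 3.0, charges)  # same, since charges sum to 0
```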