quantity.property_prediction

This module provides predictors for molecular properties that are computed from the features of a neural network used for potential energy prediction.

Molecular Property Predictors

The following wrapper function is at the core of the molecular property prediction module. It transforms models for (atom-wise) potential energy prediction into predictors of atom-level and molecule-level properties.

molecular_property_predictor(model, n_per_atom=0)

Wraps models that predict per-atom quantities to predict both global and per-atom quantities.

Parameters:
  • model – Initialized model predicting per-atom quantities, e.g. DimeNetPP.

  • n_per_atom – Number of per-atom quantities to predict. Remaining predictions are assumed to be global.

Returns:

A tuple (global_properties, per_atom_properties), where global_properties is an array of shape (n_globals,) and per_atom_properties is an array of shape (n_particles, n_per_atom).
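To illustrate the shapes involved, the following minimal sketch mimics what such a wrapper does. It is not chemtrain's implementation: `toy_per_atom_model` and `toy_property_predictor` are hypothetical names, and it assumes that the last `n_per_atom` output columns are per-atom properties while the remaining columns are atom-wise contributions that are summed into global properties.

```python
import numpy as np


def toy_per_atom_model(params, positions):
    """Hypothetical stand-in for a per-atom model such as DimeNetPP.

    Returns an array of shape (n_particles, n_outputs).
    """
    # A linear map of the positions, only to produce outputs of the right shape.
    return positions @ params["weights"]


def toy_property_predictor(model, n_per_atom=0):
    """Illustrative wrapper (assumed column layout, not chemtrain's code)."""
    def predictor(params, positions):
        outputs = model(params, positions)
        n_globals = outputs.shape[1] - n_per_atom
        # Sum atom-wise contributions over all atoms -> shape (n_globals,)
        global_properties = outputs[:, :n_globals].sum(axis=0)
        # Keep the per-atom columns -> shape (n_particles, n_per_atom)
        per_atom_properties = outputs[:, n_globals:]
        return global_properties, per_atom_properties
    return predictor


rng = np.random.default_rng(0)
params = {"weights": rng.normal(size=(3, 3))}
positions = rng.normal(size=(5, 3))  # 5 particles in 3D

predictor = toy_property_predictor(toy_per_atom_model, n_per_atom=1)
global_props, per_atom_props = predictor(params, positions)
```

Summing atom-wise contributions mirrors how atom-wise energy models usually produce a total energy; other aggregations are possible.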

Properties

chemtrain provides the following property predictors:

partial_charge_prediction([model, ...])

Initializes a prediction of partial charges.

potential_energy_prediction([model, ...])

Initializes a prediction of the potential energy.

Protocols

class PropertyPredictor(*args, **kwargs)
__call__(params, graph, **kwargs)

Predicts molecular properties from a molecular graph.

The form of the graph depends on the underlying model. For example, for property predictions with jax_md_mod.model.neural_networks.DimeNetPP, the graph is a jax_md_mod.model.sparse_graph.SparseDirectionalGraph.

Parameters:
  • graph (Any) – Molecular graph containing neighborhood information.

  • **kwargs – Additional arguments passed to the potential model that extracts features from the molecular graph.

Return type:

Tuple[Array, Array]

Returns:

A tuple of global and per-atom properties.
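The protocol only fixes the call signature and the (global, per-atom) return convention. A minimal sketch of an equivalent `typing.Protocol` and a conforming toy function follows; `PropertyPredictorLike`, `constant_predictor` and the dictionary-based graph are hypothetical and only serve to show the contract.

```python
from typing import Any, Protocol, Tuple, runtime_checkable

import numpy as np


@runtime_checkable
class PropertyPredictorLike(Protocol):
    """Hypothetical mirror of the PropertyPredictor protocol."""

    def __call__(self, params: Any, graph: Any,
                 **kwargs) -> Tuple[np.ndarray, np.ndarray]:
        ...


def constant_predictor(params, graph, **kwargs):
    """Toy predictor returning one global and one per-atom property."""
    n_atoms = graph["n_atoms"]  # hypothetical graph format: a plain dict
    global_properties = np.array([params["offset"] * n_atoms])
    per_atom_properties = np.full((n_atoms, 1), params["offset"])
    return global_properties, per_atom_properties


g, p = constant_predictor({"offset": 0.5}, {"n_atoms": 4})
```

Note that a runtime `isinstance` check against such a protocol only verifies that the object is callable; it does not check the argument list or the return type, which remain the responsibility of static type checkers.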

class SinglePropertyPredictor(*args, **kwargs)
static __call__(self, params, graph: Any, **kwargs)
static __call__(self, features, **kwargs)

Predicts a molecular property from a molecular graph.

The form of the graph depends on the underlying model. For example, for property predictions with jax_md_mod.model.neural_networks.DimeNetPP, the graph is a jax_md_mod.model.sparse_graph.SparseDirectionalGraph.

Parameters:
  • graph – Molecular graph containing neighborhood information. Required if the features are computed by a model within the predictor.

  • features – Features derived from the molecular graph. Required if no model is provided to compute the features.

  • **kwargs – Additional arguments passed to the potential model that extracts features from the molecular graph.

Returns:

A single per-atom or molecular property.
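The two call forms can be sketched as a single callable that either runs a model on `(params, graph)` or reads pre-computed features. Everything here is hypothetical: `make_single_property_predictor` and `toy_model` are illustrative names, features are assumed to be a `(global_features, per_atom_features)` tuple, and the property is assumed to be one per-atom feature column.

```python
import numpy as np


def make_single_property_predictor(model=None, readout_index=0):
    """Hypothetical sketch of a SinglePropertyPredictor-like callable."""
    def predict(*args, **kwargs):
        if model is not None:
            # Form 1: compute the features from the graph with the model.
            params, graph = args
            _, per_atom_features = model(params, graph, **kwargs)
        else:
            # Form 2: use features that were extracted beforehand.
            (features,) = args
            _, per_atom_features = features
        # Hypothetical readout: select one per-atom feature column.
        return per_atom_features[:, readout_index]
    return predict


def toy_model(params, graph):
    """Toy feature extractor returning (global, per-atom) features."""
    per_atom = graph["positions"] * params["scale"]
    return per_atom.sum(axis=0), per_atom


graph = {"positions": np.arange(6.0).reshape(3, 2)}
params = {"scale": 2.0}

with_model = make_single_property_predictor(toy_model, readout_index=1)
from_features = make_single_property_predictor(readout_index=1)

a = with_model(params, graph)                    # features computed inside
b = from_features(toy_model(params, graph))      # features passed in
```

Both forms yield the same property, which is what allows features to be extracted once and shared.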

Examples

We provide an example that transforms DimeNet++ into a partial charge predictor that enforces charge neutrality of its predictions. For a real-world application of this partial charge predictor in an active learning context, see the code of Thaler et al. (2024).
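One common way to enforce a fixed total charge is to shift all raw per-atom predictions by the same amount. The sketch below shows that technique; it is an assumption for illustration and not necessarily the scheme used in the example referenced above. The `mask` argument is a hypothetical way to exclude padded atoms.

```python
import numpy as np


def enforce_total_charge(raw_charges, total_charge=0.0, mask=None):
    """Shift raw per-atom charges uniformly so that they sum to total_charge.

    `mask` marks real (non-padded) atoms; padded atoms get zero charge.
    """
    raw_charges = np.asarray(raw_charges, dtype=float)
    if mask is None:
        mask = np.ones_like(raw_charges, dtype=bool)
    n_real = mask.sum()
    # Excess charge per real atom relative to the target total charge.
    excess = (np.where(mask, raw_charges, 0.0).sum() - total_charge) / n_real
    return np.where(mask, raw_charges - excess, 0.0)


charges = enforce_total_charge(
    [0.4, -0.1, -0.1, 0.0], mask=np.array([True, True, True, False])
)
```

A uniform shift preserves the relative differences between the predicted charges; weighted corrections are an alternative when some atoms should absorb more of the excess.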

Utilities

apply_model([model])

Initializes a molecular property predictor.

Snapshot Quantities

Based on the predicted properties, we can compute other physical snapshot quantities. For example, with the predicted partial charges, we can compute the dipole moment of a system.
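For point charges $q_i$ at positions $\mathbf{r}_i$, the dipole moment relative to a reference point $\mathbf{r}_{\mathrm{ref}}$ is given below. Shifting the reference point by $\mathbf{a}$ changes the result by $-\mathbf{a}\sum_i q_i$, so for a charge-neutral molecule the dipole moment does not depend on the choice of reference point.

```latex
\boldsymbol{\mu} = \sum_{i=1}^{N} q_i \left(\mathbf{r}_i - \mathbf{r}_{\mathrm{ref}}\right),
\qquad
\boldsymbol{\mu}(\mathbf{r}_{\mathrm{ref}} + \mathbf{a})
  = \boldsymbol{\mu}(\mathbf{r}_{\mathrm{ref}}) - \mathbf{a} \sum_{i=1}^{N} q_i .
```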

The following function transforms a property predictor on a graph into a snapshot compute function:

snapshot_quantity(property_predictor, graph_from_neighbor_list=None, features_key='features')

Transforms a single property predictor to a snapshot compute function.

Parameters:
  • property_predictor (SinglePropertyPredictor) – Function to predict a property from a molecular graph or from pre-extracted features.

  • graph_from_neighbor_list (Callable) – Function to build a molecular graph from a neighbor list. Only necessary when the features are extracted within the property predictor. Otherwise, pre-extracted global and per-atom features are required as keyword arguments.

  • features_key – Key to the pre-computed features if no model is provided.

Return type:

ComputeFn

Returns:

A snapshot compute function for the corresponding property.
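The branching described by the parameters can be sketched as follows. This is a hypothetical illustration, not chemtrain's code: `toy_snapshot_quantity`, the `State` namedtuple, the `params` keyword and the trivial "graph" (just the positions) are all assumptions chosen to keep the example self-contained.

```python
from collections import namedtuple

import numpy as np

State = namedtuple("State", ["position"])  # stand-in for a simulator state


def toy_snapshot_quantity(property_predictor, graph_from_neighbor_list=None,
                          features_key="features"):
    """Hypothetical sketch of the snapshot_quantity branching."""
    def compute_fn(state, neighbor=None, **kwargs):
        if graph_from_neighbor_list is not None:
            # Build the graph; the predictor extracts the features itself.
            graph = graph_from_neighbor_list(state.position, neighbor)
            return property_predictor(kwargs["params"], graph)
        # Otherwise, read pre-computed features from the keyword arguments.
        return property_predictor(kwargs[features_key])
    return compute_fn


def predictor(*args):
    """Toy predictor accepting (params, graph) or (features,)."""
    if len(args) == 2:
        params, graph = args
        return params["scale"] * graph.sum()
    (features,) = args
    return features.sum()


def graph_from_nbrs(position, neighbor):
    return position  # hypothetical: the "graph" is just the positions


state = State(position=np.ones((4, 3)))
from_graph = toy_snapshot_quantity(predictor, graph_from_nbrs)
from_features = toy_snapshot_quantity(predictor)

a = from_graph(state, neighbor=None, params={"scale": 1.0})
b = from_features(state, features=np.ones((4, 3)))
```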

To extract features only once for all derived snapshot quantities, the following function can be used in combination with chemtrain.trajectory.traj_util.quantity_map():

init_feature_pre_computation(model, graph_from_neighbor_list=None)

Initializes a function to compute global and per-atom features.

Parameters:
  • model (PropertyPredictor) – Model to compute the features from a molecular graph.

  • graph_from_neighbor_list (Callable) – Function to construct a molecular graph from the neighbor list.

Return type:

ComputeFn

Returns:

A function that computes the global and per-atom features from a molecular graph.
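The benefit of pre-computation is that the expensive feature extraction runs once per snapshot, however many quantities consume it. The sketch below demonstrates that idea with a hypothetical loop (`compute_all`) standing in for a quantity map; it does not reproduce the actual mechanics of chemtrain.trajectory.traj_util.quantity_map(). The call counter only exists to make the single evaluation visible.

```python
import numpy as np

calls = {"model": 0}


def feature_model(positions):
    """Hypothetical expensive feature extractor that counts its calls."""
    calls["model"] += 1
    per_atom = np.linalg.norm(positions, axis=1, keepdims=True)
    return per_atom.sum(axis=0), per_atom


def compute_all(positions, quantity_fns):
    """Hypothetical stand-in for a quantity map: extract features once and
    pass them to every snapshot quantity as the `features` keyword."""
    features = feature_model(positions)
    return {name: fn(features=features) for name, fn in quantity_fns.items()}


quantity_fns = {
    "n_atoms": lambda features: features[1].shape[0],
    "mean_norm": lambda features: float(features[1].mean()),
}
results = compute_all(np.eye(3), quantity_fns)
```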

chemtrain provides the following snapshot quantities:

init_dipole_moment(displacement_fn[, model, ...])

Computes the dipole moment from partial charges of a molecule.
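The `displacement_fn` argument matters in periodic systems, where a molecule can be split across the box boundary. The sketch below evaluates the dipole moment with displacements relative to a reference atom. The helpers are hypothetical, the minimum-image displacement for a cubic box is written out by hand in the spirit of JAX MD displacement functions, and using the first atom as reference is an assumption, not necessarily what init_dipole_moment does.

```python
import numpy as np


def free_displacement(ra, rb):
    """Displacement vector ra - rb without boundaries."""
    return ra - rb


def make_periodic_displacement(box_size):
    """Minimum-image displacement in a cubic box (illustrative)."""
    def displacement(ra, rb):
        dr = ra - rb
        return dr - box_size * np.round(dr / box_size)
    return displacement


def dipole_moment(positions, charges, displacement_fn, reference=None):
    """mu = sum_i q_i * d(r_i, r_ref), with d given by displacement_fn.

    Assumption: the reference point defaults to the first atom.
    """
    if reference is None:
        reference = positions[0]
    disp = np.stack([displacement_fn(r, reference) for r in positions])
    return (charges[:, None] * disp).sum(axis=0)


charges = np.array([-0.5, 0.5])

# Intact molecule in free space.
mu_free = dipole_moment(
    np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]), charges, free_displacement
)

# Same molecule split across the boundary of a periodic box of size 10.
split = np.array([[9.5, 0.0, 0.0], [0.5, 0.0, 0.0]])
mu_periodic = dipole_moment(split, charges, make_periodic_displacement(10.0))
mu_naive = dipole_moment(split, charges, free_displacement)
```

With the minimum-image displacement, the split molecule yields the same dipole moment as the intact one, while ignoring the boundary gives a spurious result.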