Neural Networks

Neural network models for potential energy and molecular property prediction.

DimeNet++

DimeNetPP directly takes a SparseDirectionalGraph as input and predicts per-atom quantities. As DimeNetPP is a Haiku module, it needs to be wrapped inside hk.transform() before it can be applied.

We provide two interfaces to DimeNet++. The function dimenetpp_neighborlist() serves as an interface to JAX MD: the resulting apply function can be used directly as a jax_md energy_fn, e.g. to run molecular dynamics simulations.

For direct prediction of global molecular properties, dimenetpp_property_prediction() can be used.

class chemtrain.neural_networks.DimeNetPP(r_cutoff, n_species, num_targets, kbt_dependent=False, embed_size=128, n_interaction_blocks=4, num_residual_before_skip=1, num_residual_after_skip=2, out_embed_size=None, type_embed_size=None, angle_int_embed_size=None, basis_int_embed_size=8, num_dense_out=3, num_rbf=6, num_sbf=7, activation=jax.nn.silu, envelope_p=6, init_kwargs=None, dropout_mode=None, name='DimeNetPP')

DimeNet++ for molecular property prediction.

This model takes as input a sparse representation of a molecular graph, consisting of pairwise distances and angular triplets, and predicts per-atom properties. Global properties can be obtained by summing over the per-atom predictions.
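The summation step can be illustrated with a minimal NumPy sketch. It aggregates per-atom predictions of shape (n_atoms, num_targets) into one prediction per molecule. The molecule_ids batching layout is an assumption made for illustration, and chemtrain may batch molecules differently. In JAX, jax.ops.segment_sum performs the same reduction.

```python
import numpy as np

def sum_per_atom_to_global(per_atom, molecule_ids, n_molecules):
    """Sum per-atom predictions into per-molecule (global) predictions.

    per_atom:     (n_atoms, num_targets) per-atom predictions
    molecule_ids: (n_atoms,) index of the molecule each atom belongs to
    n_molecules:  number of molecules in the batch
    """
    global_pred = np.zeros((n_molecules, per_atom.shape[1]), dtype=per_atom.dtype)
    # Unbuffered scatter-add: rows with the same molecule id accumulate.
    np.add.at(global_pred, molecule_ids, per_atom)
    return global_pred

# Atoms 0-2 belong to molecule 0, atoms 3-4 to molecule 1.
per_atom = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]])
molecule_ids = np.array([0, 0, 0, 1, 1])
print(sum_per_atom_to_global(per_atom, molecule_ids, 2).tolist())  # [[6.0], [9.0]]
```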

The default values correspond to the original values of DimeNet++.

This custom implementation follows the original DimeNet / DimeNet++ architecture (https://arxiv.org/abs/2011.14115) while correcting known issues (see https://github.com/klicperajo/dimenet).

__init__(r_cutoff, n_species, num_targets, kbt_dependent=False, embed_size=128, n_interaction_blocks=4, num_residual_before_skip=1, num_residual_after_skip=2, out_embed_size=None, type_embed_size=None, angle_int_embed_size=None, basis_int_embed_size=8, num_dense_out=3, num_rbf=6, num_sbf=7, activation=jax.nn.silu, envelope_p=6, init_kwargs=None, dropout_mode=None, name='DimeNetPP')

Initializes the DimeNet++ model.

The default values correspond to the original values of DimeNet++.

Parameters:
  • r_cutoff (float) – Radial cut-off distance of edges

  • n_species (int) – Number of different atom species the network is supposed to process.

  • num_targets (int) – Number of different atomic properties to predict

  • kbt_dependent (bool) – If True, DimeNet++ explicitly depends on temperature. In this case, ‘kT’ needs to be provided as a kwarg when calling the energy_fn. The default (False) results in a temperature-independent model.

  • embed_size (int) – Size of message embeddings. Interaction and output embedding sizes are scaled accordingly if not specified explicitly.

  • n_interaction_blocks (int) – Number of interaction blocks

  • num_residual_before_skip (int) – Number of residual blocks before the skip connection in the Interaction block.

  • num_residual_after_skip (int) – Number of residual blocks after the skip connection in the Interaction block.

  • out_embed_size (Optional[int]) – Embedding size of the output block. If None, set to 2 * embed_size.

  • type_embed_size (Optional[int]) – Embedding size of atom type embeddings. If None, set to 0.5 * embed_size.

  • angle_int_embed_size (Optional[int]) – Embedding size of Linear layers for the down-projected triplet interaction. If None, set to 0.5 * embed_size.

  • basis_int_embed_size (int) – Embedding size of Linear layers for the interaction of the RBF / SBF bases in the interaction block

  • num_dense_out (int) – Number of final Linear layers in output block

  • num_rbf (int) – Number of radial Bessel embedding functions

  • num_sbf (int) – Number of spherical Bessel embedding functions

  • activation (Callable) – Activation function

  • envelope_p (int) – Power of envelope polynomial

  • init_kwargs (Optional[Dict[str, Any]]) – Kwargs for initialization of Linear layers

  • dropout_mode (Optional[Dict[str, Any]]) – A dict defining to which fully connected layers dropout is applied and at which rate (see dropout.dimenetpp_setup). If None, no dropout is applied.

  • name (str) – Name of DimeNet++ model
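The defaults for the optional embedding sizes can be summarized with a small helper. resolve_embed_sizes is hypothetical and is not part of chemtrain. It only restates the documented rules. Truncating 0.5 * embed_size to an integer when embed_size is odd is an assumption.

```python
from typing import Optional, Tuple

def resolve_embed_sizes(
    embed_size: int = 128,
    out_embed_size: Optional[int] = None,
    type_embed_size: Optional[int] = None,
    angle_int_embed_size: Optional[int] = None,
) -> Tuple[int, int, int]:
    """Hypothetical helper restating the documented DimeNetPP defaults."""
    if out_embed_size is None:
        out_embed_size = 2 * embed_size
    if type_embed_size is None:
        type_embed_size = int(0.5 * embed_size)  # truncation is an assumption
    if angle_int_embed_size is None:
        angle_int_embed_size = int(0.5 * embed_size)
    return out_embed_size, type_embed_size, angle_int_embed_size

print(resolve_embed_sizes())                       # (256, 64, 64)
print(resolve_embed_sizes(embed_size=64))          # (128, 32, 32)
print(resolve_embed_sizes(64, out_embed_size=96))  # (96, 32, 32)
```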

__call__(graph, **dyn_kwargs)

Predicts per-atom quantities for a given molecular graph.

Parameters:
  • graph (SparseDirectionalGraph) – An instance of sparse_graph.SparseDirectionalGraph defining the molecular graph connectivity.

  • **dyn_kwargs – Kwargs supplied on-the-fly, such as ‘kT’ for temperature-dependent models or ‘dropout_key’ for Dropout.

Return type:

Array

Returns:

An (n_particles, num_targets) array of predicted per-atom quantities

chemtrain.neural_networks.dimenetpp_neighborlist(displacement, r_cutoff, n_species=100, positions_test=None, neighbor_test=None, max_triplet_multiplier=1.25, max_edge_multiplier=1.25, **dimenetpp_kwargs)

DimeNet++ energy function for JAX MD.

This function provides an interface for the DimeNet++ Haiku model to be used as a jax_md energy_fn. Analogous to jax_md energy_fns, the initialized DimeNet++ energy_fn requires particle positions and a dense neighbor list as input, plus an array for species or other dynamic kwargs, if applicable.

The sparse graph representation with edges and angle triplets is computed from the particle positions and the neighbor list. Because jit requires arrays of constant shape, the jax_md neighbor list contains many masked edges, i.e. pairwise interactions that only “fill” the neighbor list but are set to 0 during computation. These translate into masked edges and triplets in the sparse graph representation.
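The conversion from a dense neighbor list to edges and triplets can be illustrated with a simplified NumPy sketch. It assumes free-space distances with no periodic boundaries. It also assumes, as in jax_md's dense format, that unused neighbor slots hold the index N. The (center, neighbor) edge convention and the triplet layout are simplifications, and the actual SparseDirectionalGraph may order and direct edges differently. The Python loop and the dynamic output shapes are for readability only. Jit-compatible code would keep fixed-size masked arrays instead.

```python
import numpy as np

def sparse_graph_from_dense(positions, idx, r_cutoff):
    """Build edge and triplet index arrays from a dense neighbor list.

    positions: (N, dim) particle positions (free space, no periodic boundaries)
    idx:       (N, max_neighbors) dense neighbor list, padded with the index N
    Returns
      edges:    (n_edges, 2) array of (center, neighbor) atom indices
      triplets: (n_triplets, 2) pairs of edge indices sharing a center atom
    """
    n_particles = positions.shape[0]
    centers = np.repeat(np.arange(n_particles), idx.shape[1])
    neighbors = idx.reshape(-1)

    # Drop padding entries (index N) that only "fill" the neighbor list.
    valid = neighbors < n_particles
    centers, neighbors = centers[valid], neighbors[valid]

    # Drop neighbors beyond the cut-off.
    dist = np.linalg.norm(positions[neighbors] - positions[centers], axis=-1)
    within = dist < r_cutoff
    edges = np.stack([centers[within], neighbors[within]], axis=1)

    # Triplets: ordered pairs of distinct edges around the same center atom.
    triplets = [(a, b)
                for a in range(len(edges))
                for b in range(len(edges))
                if a != b and edges[a, 0] == edges[b, 0]]
    triplets = np.array(triplets, dtype=int).reshape(-1, 2)
    return edges, triplets

# Three atoms on a line; atom 2 lies outside atom 0's cut-off.
positions = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
idx = np.array([[1, 2], [0, 2], [1, 3]])  # 3 == N marks a padding slot
edges, triplets = sparse_graph_from_dense(positions, idx, r_cutoff=1.5)
print(edges.tolist())     # [[0, 1], [1, 0], [1, 2], [2, 1]]
print(triplets.tolist())  # [[1, 2], [2, 1]]
```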

For improved computational efficiency during jax_md simulations, the maximum number of edges and triplets can be estimated during model initialization. Edges and triplets beyond this maximum estimate can then be capped to reduce computational and memory requirements. Capping is enabled by providing sample inputs (positions_test and neighbor_test) at initialization time. However, beware that an overflow of max_edges and max_angles is currently not caught, as this would require passing an error code through jax_md simulators, analogous to the overflow detection in jax_md neighbor lists. If in doubt, increase the max edges/angles multipliers or disable capping.
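A NumPy sketch can show how a capacity estimated from a sample state might be applied, and what an explicit overflow check could look like. The function names, the rounding up with ceil and the fill value are assumptions made for illustration. As noted above, chemtrain does not currently catch such an overflow.

```python
import math
import numpy as np

def estimate_capacity(n_sample, multiplier):
    """Fixed buffer size: sample count times a safety multiplier, rounded up."""
    return math.ceil(n_sample * multiplier)

def cap_edges(edges, max_edges, fill_value):
    """Pad or truncate an (n_edges, 2) edge array to exactly max_edges rows.

    Returns the fixed-size array, a boolean mask marking real edges, and an
    overflow flag that is True if edges had to be dropped.
    """
    n_edges = edges.shape[0]
    n_kept = min(n_edges, max_edges)
    capped = np.full((max_edges, 2), fill_value, dtype=edges.dtype)
    capped[:n_kept] = edges[:n_kept]
    mask = np.arange(max_edges) < n_kept
    overflow = n_edges > max_edges
    return capped, mask, overflow

# Capacity estimated from a sample state with 8 edges.
max_edges = estimate_capacity(8, multiplier=1.25)  # 10

# A later state with 12 edges exceeds the capacity: 2 edges are dropped,
# which goes unnoticed unless the overflow flag is checked.
edges_later = np.arange(24).reshape(12, 2)
capped, mask, overflow = cap_edges(edges_later, max_edges, fill_value=-1)
print(max_edges, int(mask.sum()), overflow)  # 10 10 True
```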

Parameters:
  • displacement (Callable[[Array, Array], Array]) – jax_md displacement function

  • r_cutoff (float) – Radial cut-off distance of DimeNetPP and the neighbor list

  • n_species (int) – Number of different atom species the network is supposed to process.

  • positions_test (Optional[Array]) – Sample positions to estimate max_edges / max_angles. Needs to be provided to enable capping.

  • neighbor_test (Optional[NeighborList]) – Sample neighborlist to estimate max_edges / max_angles. Needs to be provided to enable capping.

  • max_edge_multiplier (float) – Multiplier for initial estimate of maximum edges.

  • max_triplet_multiplier (float) – Multiplier for initial estimate of maximum triplets.

  • dimenetpp_kwargs – Kwargs to change the default structure of DimeNet++. For definition of the kwargs, see DimeNetPP.

Returns:

An init_fn that initializes the model parameters and an energy function that computes the energy of a particular state given the model parameters. The energy function requires the same input as other neighbor-list energy functions in jax_md.energy.

Return type:

A tuple of 2 functions

chemtrain.neural_networks.dimenetpp_property_prediction(r_cutoff, n_targets=1, n_species=100, **model_kwargs)

Initializes a model that predicts global molecular properties.

Parameters:
  • r_cutoff (float) – Radial cut-off distance of DimeNetPP and the neighbor list.

  • n_targets (int) – Number of different molecular properties to predict.

  • n_species (int) – Number of different atom species the network is supposed to process.

  • **model_kwargs – Kwargs to change the default structure of DimeNet++.

Returns:

An init_fn that initializes the model parameters and an apply function that predicts global molecular properties.

Return type:

A tuple of 2 functions

Pairwise NN

PairwiseNN implements a neural network that parametrizes two-body interactions. The function pair_interaction_nn() initializes a pairwise jax_md neighbor list energy_fn as an alternative to classical tabulated potentials.

class chemtrain.neural_networks.PairwiseNN(r_cutoff, hidden_layers, init_kwargs=None, activation=jax.nn.silu, num_rbf=6, envelope_p=6, name='PairNN')

A neural network predicting pairwise edge quantities.

Can be used to predict the energy of pairwise interactions.

__init__(r_cutoff, hidden_layers, init_kwargs=None, activation=jax.nn.silu, num_rbf=6, envelope_p=6, name='PairNN')

Initializes the PairwiseNN model.

Parameters:
  • r_cutoff (float) – Radial cut-off distance of pairwise interactions

  • hidden_layers – A list (or a scalar, in the case of a single hidden layer) of the number of neurons in each hidden layer of the MLP

  • init_kwargs – Kwargs for initialization of Linear layers

  • activation (Callable) – Activation function

  • num_rbf (int) – Number of radial Bessel embedding functions

  • envelope_p (int) – Power of envelope polynomial

  • name (str) – Name of the PairwiseNN model

__call__(distances, species=None, **kwargs)

Predicts pairwise edge quantities for the given pairwise distances.

chemtrain.neural_networks.pair_interaction_nn(displacement, r_cutoff, hidden_layers, **pair_net_kwargs)

An MLP acting on pairwise distances independently and summing the contributions.

Embeds pairwise distances via radial Bessel functions (RBF). The RBF is also used to enforce a differentiable cut-off.
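An enveloped radial Bessel basis can be evaluated with the NumPy sketch below. It follows the formulas of the DimeNet paper: basis functions e_n(d) = sqrt(2/c) sin(nπd/c)/d for n = 1..num_rbf, multiplied by the polynomial envelope u(x) with x = d/c. The envelope satisfies u(0) = 1 and vanishes smoothly at x = 1. The function names are illustrative. It is also an assumption that chemtrain's envelope_p corresponds exactly to the exponent p used here. The explicit zeroing beyond the cut-off is added for clarity. Distances must be positive because of the 1/d factor.

```python
import numpy as np

def envelope(x, p):
    """DimeNet polynomial envelope: u(0) = 1, u(1) = 0, smooth at x = 1."""
    return (1.0
            - (p + 1) * (p + 2) / 2 * x**p
            + p * (p + 2) * x**(p + 1)
            - p * (p + 1) / 2 * x**(p + 2))

def radial_bessel_basis(d, r_cutoff, num_rbf=6, envelope_p=6):
    """Enveloped radial Bessel basis of shape (n_edges, num_rbf).

    e_n(d) = sqrt(2 / c) * sin(n * pi * d / c) / d for n = 1..num_rbf.
    """
    d = np.asarray(d, dtype=float)[:, None]      # (n_edges, 1), d > 0
    n = np.arange(1, num_rbf + 1)[None, :]       # (1, num_rbf)
    x = d / r_cutoff
    rbf = np.sqrt(2.0 / r_cutoff) * np.sin(n * np.pi * x) / d
    # The envelope reaches 0 at the cut-off; set it to 0 beyond it as well.
    u = np.where(x < 1.0, envelope(x, envelope_p), 0.0)
    return u * rbf

basis = radial_bessel_basis([0.5, 2.5, 4.9, 5.0, 6.0], r_cutoff=5.0)
print(basis.shape)  # (5, 6)
```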

Parameters:
  • displacement (Callable[[Array, Array], Array]) – Displacement function

  • r_cutoff (float) – Radial cut-off of pairwise interactions and neighbor list

  • hidden_layers – A list (or a scalar, in the case of a single hidden layer) of the number of neurons in each hidden layer of the MLP

  • pair_net_kwargs – Kwargs to change the default structure of PairwiseNN. For definition of the kwargs, see PairwiseNN.

Returns:

An init_fn that initializes the model parameters and an energy function that computes the energy of a particular state given the model parameters.

Return type:

A tuple of 2 functions
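The summation performed by pair_interaction_nn can be illustrated with a NumPy sketch. A per-pair function is applied to every entry of a dense neighbor list. Padding entries and pairs beyond the cut-off are masked with fixed shapes, as jit would require. The sum is then halved because a symmetric neighbor list contains each pair twice. Several parts are assumptions: the padding convention (index N), the symmetric neighbor list and the free-space distances. The random-weight MLP make_mlp_pair_fn is hypothetical. For brevity, it feeds raw distances instead of the RBF embedding, so unlike PairwiseNN it is not smooth at the cut-off.

```python
import numpy as np

def total_pair_energy(positions, idx, r_cutoff, pair_fn):
    """Sum pair_fn(r_ij) over all pairs of a dense, symmetric neighbor list.

    positions: (N, dim) particle positions (free space, no periodic boundaries)
    idx:       (N, max_neighbors) dense neighbor list, padded with the index N
    pair_fn:   maps a 1D array of distances to per-pair energies
    """
    n_particles = positions.shape[0]
    centers = np.repeat(np.arange(n_particles), idx.shape[1])
    neighbors = idx.reshape(-1)
    valid = neighbors < n_particles
    # Redirect padding slots to a valid index so the gather stays in bounds;
    # their contributions are masked out below (fixed shapes, as under jit).
    safe_neighbors = np.where(valid, neighbors, 0)
    dist = np.linalg.norm(positions[safe_neighbors] - positions[centers], axis=-1)
    mask = valid & (dist < r_cutoff)
    # Feed a harmless dummy distance to masked pairs, then zero them out.
    safe_dist = np.where(mask, dist, r_cutoff)
    energies = np.where(mask, pair_fn(safe_dist), 0.0)
    # Each pair (i, j) appears twice in a symmetric neighbor list.
    return 0.5 * float(np.sum(energies))

def make_mlp_pair_fn(hidden_layers, seed=0):
    """Hypothetical random-weight MLP mapping a distance to a pair energy."""
    rng = np.random.default_rng(seed)
    sizes = [1, *hidden_layers, 1]
    params = [(rng.normal(size=(m, n)) / np.sqrt(m), np.zeros(n))
              for m, n in zip(sizes[:-1], sizes[1:])]

    def pair_fn(dist):
        x = dist[:, None]
        for w, b in params[:-1]:
            z = x @ w + b
            x = z / (1.0 + np.exp(-z))  # SiLU activation
        w, b = params[-1]
        return (x @ w + b)[:, 0]

    return pair_fn

positions = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
idx = np.array([[1, 2], [0, 2], [1, 3]])  # 3 == N marks a padding slot

# Counting pairs: (0, 1) and (1, 2) lie within the cut-off; (0, 2) does not.
print(total_pair_energy(positions, idx, 1.5, np.ones_like))  # 2.0
energy = total_pair_energy(positions, idx, 1.5, make_mlp_pair_fn([16, 16]))
```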