Layers
JAX / Haiku implementation of the layers used to build the DimeNet++ architecture.
The DimeNet++ Building Blocks take components of SparseDirectionalGraph as input. Please refer to this class for input descriptions.
Initializers
- class chemtrain.layers.OrthogonalVarianceScalingInit(scale=2.0)
Initializer that scales the variance of a matrix drawn from the uniform distribution over orthogonal matrices.
Generates a weight matrix whose variance follows Glorot initialization, based on a random (semi-)orthogonal matrix. Neural networks are expected to learn better when features are decorrelated, as stated e.g. in “Reducing overfitting in deep networks by decorrelating representations”.
The approach is adopted from the original DimeNet and the implementation is inspired by Haiku’s variance scaling initializer.
- scale
Variance scaling factor
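To illustrate the rescaling, here is a minimal sketch in plain JAX rather than Haiku. It assumes the DimeNet-style rule W ← W · sqrt(scale / ((fan_in + fan_out) · Var(W))) applied to a random (semi-)orthogonal matrix from `jax.nn.initializers.orthogonal`; the function name `orthogonal_variance_scaling_init` is hypothetical and not chemtrain's API.

```python
import jax
import jax.numpy as jnp


def orthogonal_variance_scaling_init(key, shape, scale=2.0, dtype=jnp.float32):
    """Hypothetical sketch: (semi-)orthogonal matrix rescaled to a target variance.

    The target variance is scale / (fan_in + fan_out); for scale=2.0 this is
    the Glorot variance 2 / (fan_in + fan_out).
    """
    fan_in, fan_out = shape
    # Random (semi-)orthogonal matrix: its columns are orthonormal.
    w = jax.nn.initializers.orthogonal()(key, shape, dtype)
    # Multiply by one global factor so the empirical variance hits the target.
    target_var = scale / (fan_in + fan_out)
    return w * jnp.sqrt(target_var / jnp.var(w))


key = jax.random.PRNGKey(0)
w = orthogonal_variance_scaling_init(key, (64, 32))
print(w.shape, float(jnp.var(w)))  # variance approximately 2 / (64 + 32)
```

With scale=2.0 the target variance is 2 / (fan_in + fan_out), the Glorot variance, and because the whole matrix is multiplied by a single factor its columns remain mutually orthogonal.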
DimeNet++ Layers
Basis Layers
- class chemtrain.layers.SmoothingEnvelope(*args, **kwargs)
Smoothing envelope function for radial edge embeddings.
Smoothing the cut-off makes the model output, including the potential energy, twice continuously differentiable. The envelope function is 1 at 0 and has a root of multiplicity 3 at 1, as defined in DimeNet. It is applied to the scaled radial edge distances d_ij / cut_off ∈ [0, 1].
The implementation corresponds to the definition in the DimeNet paper. It differs from the original DimeNet / DimeNet++ implementations, which as a result define incorrect spherical basis layers (a known issue).
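The envelope polynomial from the DimeNet paper is u(x) = 1 - (p+1)(p+2)/2 · x^p + p(p+2) · x^(p+1) - p(p+1)/2 · x^(p+2). The following JAX sketch evaluates it; the function name is hypothetical, and p = 6 is an assumption chosen to match the `envelope_p=6` default of SphericalBesselLayer.

```python
import jax
import jax.numpy as jnp


def smoothing_envelope(x, p=6):
    """Hypothetical sketch of the DimeNet envelope u(x), x = d_ij / cut_off."""
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    u = 1.0 + a * x**p + b * x**(p + 1) + c * x**(p + 2)
    # Beyond the cut-off the envelope (and hence the embedding) is zero.
    return jnp.where(x < 1.0, u, 0.0)


x = jnp.linspace(0.0, 1.0, 5)
print(smoothing_envelope(x))           # decays from 1 at x = 0 to 0 at x = 1
print(jax.grad(smoothing_envelope)(0.999))  # slope is almost zero near the cut-off
```

For p = 6 the derivative factors as u'(x) = -168 x^5 (1 - x)^2, which makes the double root of u' (and hence the triple root of u) at x = 1 explicit.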
- class chemtrain.layers.RadialBesselLayer(*args, **kwargs)
Radial Bessel Function (RBF) representation of pairwise distances.
- freq_init
RBFFrequencyInitializer
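The DimeNet paper defines the radial basis as e_n(d) = sqrt(2/c) · sin(f_n d / c) / d with cut-off c, multiplied by the envelope u(d / c). A minimal JAX sketch follows. It assumes the frequencies f_n start at nπ, as in DimeNet; whether RBFFrequencyInitializer does exactly this is an assumption, and the function names are hypothetical.

```python
import jax.numpy as jnp


def envelope(x, p=6):
    # DimeNet polynomial envelope u(x) with x = d / cut_off (see SmoothingEnvelope).
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    u = 1.0 + a * x**p + b * x**(p + 1) + c * x**(p + 2)
    return jnp.where(x < 1.0, u, 0.0)


def radial_bessel(d, r_cutoff, num_radial=6, p=6):
    """Hypothetical sketch: e_n(d) = sqrt(2/c) sin(f_n d/c) / d * u(d/c)."""
    # Assumed initial frequencies f_n = n * pi for n = 1..num_radial.
    freq = jnp.arange(1, num_radial + 1) * jnp.pi
    x = d[:, None] / r_cutoff  # scaled distances, shape (n_edges, 1)
    rbf = jnp.sqrt(2.0 / r_cutoff) * jnp.sin(freq * x) / d[:, None]
    return rbf * envelope(x, p)  # shape (n_edges, num_radial)


d = jnp.array([0.1, 0.25, 0.5])
print(radial_bessel(d, r_cutoff=0.5).shape)
```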
- class chemtrain.layers.SphericalBesselLayer(*args, **kwargs)
Spherical Bessel Function (SBF) representation of angular triplets.
- __init__(r_cutoff, num_spherical, num_radial, envelope_p=6, name='BesselSpherical')
Initializes the SBF layer.
- Parameters:
r_cutoff – Radial cut-off
num_spherical – Number of spherical Bessel embedding functions
num_radial – Number of radial Bessel embedding functions
envelope_p – Power of envelope polynomial
name – Name of SBF layer
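In DimeNet the spherical basis is a product of a radial factor j_l(z_ln · d / c), built from zeros z_ln of spherical Bessel functions, and an angular factor Y_l^0(α) = sqrt((2l + 1) / (4π)) · P_l(cos α). The JAX sketch below covers only the angular factor, computed from Legendre polynomials via Bonnet's recurrence; it omits the radial factor and its root finding, and the function names are hypothetical.

```python
import jax.numpy as jnp


def legendre(cos_angle, num_spherical):
    """Legendre polynomials P_0..P_{num_spherical-1}, stacked on the last axis."""
    p_prev = jnp.ones_like(cos_angle)  # P_0
    polys = [p_prev]
    if num_spherical > 1:
        p_curr = cos_angle  # P_1
        polys.append(p_curr)
        for l in range(1, num_spherical - 1):
            # Bonnet's recurrence: (l + 1) P_{l+1} = (2l + 1) x P_l - l P_{l-1}
            p_next = ((2 * l + 1) * cos_angle * p_curr - l * p_prev) / (l + 1)
            p_prev, p_curr = p_curr, p_next
            polys.append(p_curr)
    return jnp.stack(polys, axis=-1)


def angular_basis(angle, num_spherical=7):
    """Hypothetical sketch: Y_l^0(angle) for l = 0..num_spherical-1."""
    l = jnp.arange(num_spherical)
    norm = jnp.sqrt((2 * l + 1) / (4 * jnp.pi))
    return norm * legendre(jnp.cos(angle), num_spherical)


angles = jnp.array([0.0, jnp.pi / 2])  # one angle per triplet
print(angular_basis(angles, num_spherical=4).shape)
```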
DimeNet++ Building Blocks
- class chemtrain.layers.ResidualLayer(*args, **kwargs)
Residual Layer: 2 activated Linear layers and a skip connection.
- __init__(layer_size, activation=jax.nn.silu, init_kwargs=None, name='ResLayer')
Initializes the Residual layer.
- Parameters:
layer_size – Output size of the Linear layers
activation – Activation function
init_kwargs – Dict of initialization kwargs for Linear layers
name – Name of the Residual layer
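A residual layer as described above can be sketched in plain JAX (without Haiku) as follows. The parameter dictionary layout, the Glorot-uniform weight initialization, and the function names are assumptions for illustration, not chemtrain's internals.

```python
import jax
import jax.numpy as jnp


def init_residual_params(key, layer_size):
    # Hypothetical layout: two square Linear layers with biases.
    k1, k2 = jax.random.split(key)
    init = jax.nn.initializers.glorot_uniform()
    return {
        "w1": init(k1, (layer_size, layer_size)), "b1": jnp.zeros(layer_size),
        "w2": init(k2, (layer_size, layer_size)), "b2": jnp.zeros(layer_size),
    }


def residual_layer(params, x, activation=jax.nn.silu):
    """x + act(W2 act(W1 x + b1) + b2): two activated Linear layers plus a skip."""
    h = activation(x @ params["w1"] + params["b1"])
    h = activation(h @ params["w2"] + params["b2"])
    return x + h


params = init_residual_params(jax.random.PRNGKey(0), layer_size=8)
x = jnp.ones((5, 8))  # 5 edges with 8 features each
print(residual_layer(params, x).shape)
```

Setting all weights and biases to zero turns the layer into the identity, since silu(0) = 0.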
- class chemtrain.layers.EmbeddingBlock(*args, **kwargs)
Embedding block of DimeNet.
Embeds edges by concatenating RBF embeddings with the atom type embeddings of both connected atoms. If the network is defined to be kbT-dependent, a temperature embedding is added.
- __init__(embed_size, n_species, type_embed_size=None, activation=jax.nn.silu, init_kwargs=None, kbt_dependent=False, name='Embedding')
Initializes an Embedding block.
- Parameters:
embed_size – Size of the edge embedding.
n_species – Number of different atom species the network is supposed to process.
type_embed_size – Embedding size of atom type embedding. Default None results in embed_size / 2.
activation – Activation function
init_kwargs – Dict of initialization kwargs for Linear layers
kbt_dependent – Boolean, whether the network prediction should depend on temperature.
name – Name of Embedding block
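The edge embedding described above can be sketched in plain JAX. It assumes edges are given as hypothetical index arrays `idx_i` and `idx_j` (the actual SparseDirectionalGraph field names may differ) and, following DimeNet++, projects the RBF features to `embed_size` before concatenation; the temperature embedding for `kbt_dependent=True` is omitted.

```python
import jax
import jax.numpy as jnp


def embedding_block(params, rbf, species, idx_i, idx_j, activation=jax.nn.silu):
    """Hypothetical sketch of the DimeNet++ edge embedding m_ij.

    rbf:          (n_edges, num_radial) radial Bessel features of d_ij
    species:      (n_atoms,) integer atom types
    idx_i, idx_j: (n_edges,) indices of the two atoms connected by each edge
    params["type_embedding"]: (n_species, type_embed_size) lookup table
    """
    # Project the RBF features to embed_size.
    rbf_emb = activation(rbf @ params["w_rbf"] + params["b_rbf"])
    # Look up the type embeddings of both atoms of each edge.
    h_i = params["type_embedding"][species[idx_i]]
    h_j = params["type_embedding"][species[idx_j]]
    # Concatenate and map to the edge embedding of size embed_size.
    edge_input = jnp.concatenate([h_i, h_j, rbf_emb], axis=-1)
    return activation(edge_input @ params["w_edge"] + params["b_edge"])
```

With the default `type_embed_size=None`, the type embedding would have size `embed_size / 2`, so `w_edge` maps `2 * (embed_size / 2) + embed_size` input features to `embed_size`.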
- class chemtrain.layers.OutputBlock(*args, **kwargs)
DimeNet++ Output block.
Predicts per-atom quantities given RBF embeddings and messages.
- __init__(embed_size, out_embed_size=None, num_dense=3, num_targets=1, activation=jax.nn.silu, init_kwargs=None, name='Output')
Initializes an Output block.
- Parameters:
embed_size – Size of the edge embedding.
out_embed_size – Output size of Linear layers after upsampling
num_dense – Number of dense layers
num_targets – Number of target quantities to be predicted
activation – Activation function
init_kwargs – Dict of initialization kwargs for Linear layers
name – Name of Output block
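A sketch of a DimeNet++-style output block in plain JAX: edge messages are gated by a linear projection of the RBF features, summed onto atoms with `jax.ops.segment_sum`, up-projected to `out_embed_size`, passed through `num_dense` activated layers, and mapped to `num_targets` values per atom. The parameter layout and the assumption that `idx_i` holds the atom index each edge is summed onto are hypothetical.

```python
import jax
import jax.numpy as jnp


def output_block(params, messages, rbf, idx_i, n_atoms, activation=jax.nn.silu):
    """Hypothetical sketch of a DimeNet++-style output block (per-atom targets).

    messages: (n_edges, embed_size), rbf: (n_edges, num_radial)
    params["dense"]: list of num_dense (weight, bias) pairs
    """
    # Gate edge messages m_ij by a linear projection of their RBF features.
    x = messages * (rbf @ params["w_rbf"])
    # Sum edge contributions onto the atoms they belong to.
    x = jax.ops.segment_sum(x, idx_i, num_segments=n_atoms)
    # Up-projection to out_embed_size (no bias here, an assumption).
    x = x @ params["w_up"]
    for w, b in params["dense"]:
        x = activation(x @ w + b)
    # Final linear map to num_targets quantities per atom.
    return x @ params["w_out"]
```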
- class chemtrain.layers.InteractionBlock(*args, **kwargs)
DimeNet++ Interaction block.
Performs directional message passing based on RBF and SBF embeddings as well as the messages from the previous message-passing iteration. The updated messages are used in the subsequent Output block.
- __init__(embed_size, num_res_before_skip, num_res_after_skip, activation=jax.nn.silu, init_kwargs=None, angle_int_embed_size=None, basis_int_embed_size=8, name='Interaction')
Initializes an Interaction block.
- Parameters:
embed_size – Size of the edge embedding.
num_res_before_skip – Number of Residual blocks before the skip connection
num_res_after_skip – Number of Residual blocks after the skip connection
activation – Activation function
init_kwargs – Dict of initialization kwargs for Linear layers
angle_int_embed_size – Embedding size of Linear layers for the down-projected triplet interaction
basis_int_embed_size – Embedding size of Linear layers for the interaction of the RBF / SBF bases
name – Name of Interaction block
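The core of directional message passing is aggregating messages over triplets k → j → i. The plain-JAX sketch below shows only this step: each down-projected message m_kj is multiplied element-wise by the projected SBF features of its triplet, and the products are summed onto edge ji. The index arrays `idx_kj` and `idx_ji` are hypothetical names; the residual layers, skip connection, and up-projection of the full block are omitted.

```python
import jax
import jax.numpy as jnp


def triplet_aggregation(m_kj_down, sbf_emb, idx_kj, idx_ji, n_edges):
    """Hypothetical sketch: aggregate messages over triplets k -> j -> i.

    m_kj_down:      (n_edges, angle_int_embed_size) down-projected edge messages
    sbf_emb:        (n_triplets, angle_int_embed_size) projected SBF features
    idx_kj, idx_ji: (n_triplets,) indices of the kj and ji edge of each triplet
    """
    # Each triplet modulates the message of its incoming edge kj.
    per_triplet = m_kj_down[idx_kj] * sbf_emb
    # Sum all triplet contributions onto the outgoing edge ji.
    return jax.ops.segment_sum(per_triplet, idx_ji, num_segments=n_edges)


m = jnp.ones((3, 2))
sbf = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(triplet_aggregation(m, sbf, jnp.array([0, 1]), jnp.array([2, 2]), n_edges=3))
```

Both triplets in the usage example end in edge 2, so their contributions are summed there, while edges 0 and 1 receive nothing.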
Utility Functions
Implements jax.ops.segment_sum, but casts the input to float64 before summation and casts the result back to a target output type afterwards (float32 by default).
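The float64 accumulation can be sketched as follows. The helper name `high_precision_segment_sum` and its signature are hypothetical. JAX only creates float64 arrays when x64 mode is enabled, so the sketch turns it on globally.

```python
import jax
import jax.numpy as jnp

# Without x64 mode, astype(jnp.float64) would silently stay float32.
jax.config.update("jax_enable_x64", True)


def high_precision_segment_sum(data, segment_ids, num_segments=None,
                               output_type=jnp.float32):
    """Hypothetical helper: jax.ops.segment_sum accumulated in float64."""
    summed = jax.ops.segment_sum(data.astype(jnp.float64), segment_ids,
                                 num_segments=num_segments)
    return summed.astype(output_type)


data = jnp.array([1e8, 1.0, -1e8, 2.0], dtype=jnp.float32)
ids = jnp.array([0, 0, 0, 1])
print(high_precision_segment_sum(data, ids, num_segments=2))
```

In float64, 1e8 + 1 - 1e8 is exactly 1 regardless of summation order, whereas float32 cannot represent 1e8 + 1 and may lose the small term.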