model.layers

JAX / Haiku implementation of the layers used to build the DimeNet++ architecture.

The DimeNet++ building blocks take components of chemtrain.potential.sparse_graph.SparseDirectionalGraph as input. Refer to that class for descriptions of the inputs.

Initializers

OrthogonalVarianceScalingInit([scale])

Initializer that scales the variance of a uniform orthogonal matrix distribution.

RBFFrequencyInitializer()

Initializer of the frequencies of the RadialBesselLayer.

DimeNet++ Layers

Basis Layers

SmoothingEnvelope([p, name])

Smoothing envelope function for radial edge embeddings.

RadialBesselLayer(cutoff[, num_radial, ...])

Radial Bessel Function (RBF) representation of pairwise distances.

SphericalBesselLayer(r_cutoff, ...[, ...])

Spherical Bessel Function (SBF) representation of angular triplets.
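To make the radial basis concrete, here is a minimal sketch in plain JAX of a smoothing envelope multiplied onto a radial Bessel basis, following the formulas published in the DimeNet paper. It is not the chemtrain implementation: the function names `smoothing_envelope` and `radial_bessel_basis` are hypothetical, the defaults `p=6` and `num_radial=6` are assumptions, and the frequencies are fixed at n·π rather than being trainable parameters.

```python
import jax.numpy as jnp


def smoothing_envelope(d_scaled, p=6):
    """Polynomial envelope u(d) for scaled distances d = r / cutoff.

    u(0) = 1 and u(1) = 0, so the basis decays smoothly to zero at the cutoff.
    """
    a = -(p + 1) * (p + 2) / 2.0
    b = p * (p + 2.0)
    c = -p * (p + 1) / 2.0
    env = 1.0 + a * d_scaled**p + b * d_scaled**(p + 1) + c * d_scaled**(p + 2)
    # Outside the cutoff the envelope is defined to be zero.
    return jnp.where(d_scaled < 1.0, env, 0.0)


def radial_bessel_basis(distances, cutoff, num_radial=6, p=6):
    """Evaluate sqrt(2/c) * sin(n*pi*d/c) / d for n = 1..num_radial."""
    d = distances[:, None]                          # (n_edges, 1)
    d_scaled = d / cutoff
    freqs = jnp.pi * jnp.arange(1, num_radial + 1)  # (num_radial,)
    rbf = jnp.sqrt(2.0 / cutoff) * jnp.sin(freqs * d_scaled) / d
    return smoothing_envelope(d_scaled, p) * rbf    # (n_edges, num_radial)


distances = jnp.array([0.1, 0.25, 0.5])
basis = radial_bessel_basis(distances, cutoff=0.5)
```

In this sketch, the row belonging to the distance that equals the cutoff is zero in every basis function, which is the purpose of the envelope.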

DimeNet++ Building Blocks

ResidualLayer(layer_size[, activation, ...])

Residual layer: two activated linear layers and a skip connection.

EmbeddingBlock(embed_size, n_species[, ...])

Embedding block of DimeNet.

OutputBlock(embed_size[, out_embed_size, ...])

DimeNet++ Output block.

InteractionBlock(embed_size, ...[, ...])

DimeNet++ Interaction block.
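As an illustration of the simplest building block, the residual layer described above can be sketched in plain JAX with explicit parameters. This is a sketch under stated assumptions, not the Haiku module from this package: `init_residual_params` and `residual_layer` are hypothetical names, the swish activation is an assumed default, and the scaled normal initialization stands in for whatever initializer the library actually uses.

```python
import jax
import jax.numpy as jnp


def init_residual_params(key, layer_size):
    """Create weights and biases for two square linear layers (hypothetical helper)."""
    k1, k2 = jax.random.split(key)
    scale = 1.0 / jnp.sqrt(layer_size)
    return {
        "w1": scale * jax.random.normal(k1, (layer_size, layer_size)),
        "b1": jnp.zeros(layer_size),
        "w2": scale * jax.random.normal(k2, (layer_size, layer_size)),
        "b2": jnp.zeros(layer_size),
    }


def residual_layer(params, x, activation=jax.nn.swish):
    """Apply two activated linear layers and add the input back (skip connection)."""
    h = activation(x @ params["w1"] + params["b1"])
    h = activation(h @ params["w2"] + params["b2"])
    return x + h


params = init_residual_params(jax.random.PRNGKey(0), layer_size=8)
x = jnp.ones((4, 8))
y = residual_layer(params, x)
```

Because both linear layers map `layer_size` to `layer_size`, the output has the same shape as the input, which is what allows the skip connection to be a plain addition.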

Utility Functions

high_precision_segment_sum(data, segment_ids)

Wraps jax.ops.segment_sum: casts the input to float64 before summation and casts the result back to a target output type afterwards (float32 by default).
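The cast, sum, cast-back pattern can be sketched as follows. This is a minimal illustration rather than the library function: `segment_sum_in_float64` and its `out_dtype` argument are hypothetical, and the sketch assumes 64-bit mode is enabled, since JAX otherwise silently truncates float64 to float32.

```python
import jax

# Assumption: float64 must be enabled explicitly, ideally before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def segment_sum_in_float64(data, segment_ids, num_segments=None,
                           out_dtype=jnp.float32):
    """Sum data per segment in float64, then cast the result to out_dtype."""
    summed = jax.ops.segment_sum(
        data.astype(jnp.float64), segment_ids, num_segments=num_segments
    )
    return summed.astype(out_dtype)


data = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)
segment_ids = jnp.array([0, 0, 1, 1])
totals = segment_sum_in_float64(data, segment_ids, num_segments=2)
```

Accumulating in float64 reduces round-off when many small contributions are summed into one segment, while the final cast keeps the rest of the model in float32.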