model.layers
JAX/Haiku implementation of the layers used to build the DimeNet++ architecture.
The DimeNet++ Building Blocks take components of
chemtrain.potential.sparse_graph.SparseDirectionalGraph as input.
Please refer to this class for input descriptions.
Initializers
Initializer that scales the variance of a uniform orthogonal matrix distribution.
Initializer for the frequencies of the RadialBesselLayer.
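A minimal sketch of what the two initializers above might compute, written in plain JAX. It assumes the original DimeNet convention: an orthogonal matrix is rescaled so that its empirical variance equals scale / (fan_in + fan_out) (with scale=2.0 this is the Glorot variance), and the Bessel frequencies start at n·π for n = 1, 2, ..., num_radial. The function names orthogonal_variance_scaling_init and rbf_frequency_init are hypothetical and not part of this module's API.

```python
import jax
import jax.numpy as jnp


def orthogonal_variance_scaling_init(key, shape, scale=2.0):
    """Hypothetical sketch: orthogonal matrix rescaled to variance scale / (fan_in + fan_out)."""
    fan_in, fan_out = shape
    # Sample a Haar-distributed (uniform) matrix with orthonormal columns.
    w = jax.nn.initializers.orthogonal()(key, shape, jnp.float32)
    # Rescale so the empirical variance of the entries equals scale / (fan_in + fan_out).
    # A single scalar factor keeps the columns mutually orthogonal.
    return w * jnp.sqrt(scale / ((fan_in + fan_out) * jnp.var(w)))


def rbf_frequency_init(num_radial):
    """Hypothetical sketch: initial RBF frequencies n * pi for n = 1..num_radial."""
    return jnp.pi * jnp.arange(1, num_radial + 1, dtype=jnp.float32)


key = jax.random.PRNGKey(0)
w = orthogonal_variance_scaling_init(key, (64, 32))
frequencies = rbf_frequency_init(6)
```

Rescaling by the measured variance rather than by a fixed constant matches the target variance exactly for every sample while preserving orthogonality up to a common factor.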
DimeNet++ Layers
Basis Layers
Smoothing envelope function for radial edge embeddings.
Radial Bessel Function (RBF) representation of pairwise distances.
Spherical Bessel Function (SBF) representation of angular triplets.
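The envelope and the radial basis can be sketched as follows, assuming they follow the original DimeNet formulation: distances are scaled by the cutoff, d = r / cutoff, the envelope is u(d) = 1/d + a d^(p-1) + b d^p + c d^(p+1) with p = exponent + 1, and each RBF entry is u(d) sin(f_n d) with trainable frequencies f_n. The helper names and the default exponent=5 are assumptions for illustration, and the spherical basis (SBF) is omitted.

```python
import jax.numpy as jnp


def smoothing_envelope(d_scaled, exponent=5):
    """Hypothetical sketch of the DimeNet polynomial envelope.

    d_scaled = r / cutoff. The coefficients are chosen so that the envelope
    goes smoothly to zero at d_scaled = 1; beyond the cutoff it is set to 0.
    """
    p = exponent + 1
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    env = (1.0 / d_scaled + a * d_scaled ** (p - 1)
           + b * d_scaled ** p + c * d_scaled ** (p + 1))
    return jnp.where(d_scaled < 1.0, env, 0.0)


def radial_bessel(distances, frequencies, cutoff, exponent=5):
    """Hypothetical RBF sketch: envelope(d) * sin(f_n * d), shape (num_edges, num_radial)."""
    d_scaled = (distances / cutoff)[:, None]  # (num_edges, 1) broadcasts over frequencies
    return smoothing_envelope(d_scaled, exponent) * jnp.sin(frequencies * d_scaled)


frequencies = jnp.pi * jnp.arange(1, 7)  # n * pi, n = 1..6
r = jnp.array([0.5, 2.0, 4.9, 6.0])
rbf = radial_bessel(r, frequencies, cutoff=5.0)  # the row for r = 6.0 is all zeros
```

Because the envelope contains the 1/d factor, distances must be strictly positive; edges beyond the cutoff contribute exactly zero.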
DimeNet++ Building Blocks
Residual layer: two activated linear layers and a skip connection.
Embedding block of DimeNet.
DimeNet++ output block.
DimeNet++ interaction block.
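The residual layer description can be illustrated with a plain-JAX sketch (a functional stand-in for the Haiku module, to keep it dependency-light). It assumes the DimeNet convention of applying the activation, here swish, after both linear layers and then adding the unmodified input. The parameter dictionary layout and the helper names init_residual and residual_layer are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_residual(key, units):
    """Hypothetical parameter layout: two (units x units) weights with biases."""
    k1, k2 = jax.random.split(key)
    scale = 1.0 / jnp.sqrt(units)
    return {
        "w1": jax.random.normal(k1, (units, units)) * scale,
        "b1": jnp.zeros(units),
        "w2": jax.random.normal(k2, (units, units)) * scale,
        "b2": jnp.zeros(units),
    }


def residual_layer(params, x, activation=jax.nn.swish):
    """x + act(act(x W1 + b1) W2 + b2): two activated linear layers plus a skip connection."""
    h = activation(x @ params["w1"] + params["b1"])
    h = activation(h @ params["w2"] + params["b2"])
    return x + h  # skip connection keeps the feature dimension unchanged


params = init_residual(jax.random.PRNGKey(0), 8)
y = residual_layer(params, jnp.ones((5, 8)))  # shape (5, 8)
```

Since the skip connection adds the input back, both linear layers must map the feature dimension onto itself.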
Utility Functions
Implements jax.ops.segment_sum, but casts the input to float64 before summation and casts the result back to a target output type afterwards (float32 by default).
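A minimal sketch of the utility described above, assuming only the stated behavior. JAX represents float64 only when x64 mode is enabled; this sketch turns it on globally, which also changes JAX's default dtypes for the rest of the program. The name segment_sum_f64 is hypothetical.

```python
import jax
import jax.numpy as jnp

# float64 requires x64 mode; without it JAX truncates float64 requests to float32 (with a warning).
jax.config.update("jax_enable_x64", True)


def segment_sum_f64(data, segment_ids, num_segments=None, output_type=jnp.float32):
    """Hypothetical sketch: accumulate segment sums in float64, return output_type."""
    summed = jax.ops.segment_sum(data.astype(jnp.float64), segment_ids,
                                 num_segments=num_segments)
    return summed.astype(output_type)


data = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)
ids = jnp.array([0, 0, 1, 1])
out = segment_sum_f64(data, ids, num_segments=2)  # float32 array [3., 7.]
```

Accumulating in float64 reduces round-off when many small per-edge contributions are summed into one node, while the rest of the model keeps working in float32.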