jax_md_mod.model.layers.SmoothingEnvelope

class SmoothingEnvelope(p=6, name='Envelope')

Smoothing envelope function for radial edge embeddings.

Smoothing the cut-off makes the model output, including the potential energy, twice continuously differentiable. As defined in DimeNet, the envelope function equals 1 at 0 and has a root of multiplicity 3 at 1. It is applied to radial edge distances scaled by the cut-off, d_ij / cut_off ∈ [0, 1].
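The polynomial below is written out from the DimeNet paper's definition. It assumes, as stated here, that the layer matches that definition. It gives the envelope on scaled distances d ∈ [0, 1] with power p, which is the constructor's `p`. Factoring the derivative shows where the root of multiplicity 3 at d = 1 comes from.

```latex
u(d) = 1 - \frac{(p+1)(p+2)}{2}\, d^{p} + p(p+2)\, d^{p+1} - \frac{p(p+1)}{2}\, d^{p+2}

u'(d) = -\frac{p(p+1)(p+2)}{2}\, d^{p-1} (1-d)^{2}
```

From these expressions, u(0) = 1 and u(1) = 0. The (1 − d)² factor gives u′(1) = u″(1) = 0. Finally, u‴(1) = −p(p+1)(p+2) ≠ 0, so the root at 1 has multiplicity exactly 3.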

The implementation follows the definition in the DimeNet paper. It differs from the original DimeNet / DimeNet++ implementations, whose envelope deviates from that definition and, as a result, produces incorrect spherical basis layers (a known issue).

Methods

__init__(p=6, name='Envelope')

Initializes the SmoothingEnvelope layer.

Parameters:
  • p – Power of the smoothing polynomial

  • name – Name of the layer

__call__(distances)

Returns the envelope values.
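The following is a minimal sketch of how the envelope could be evaluated in plain `jax.numpy`. It assumes the DimeNet polynomial and assumes the caller passes distances already divided by `cut_off`. The names `envelope_polynomial` and `envelope` and the example cut-off are hypothetical. Zeroing the values at and beyond the cut-off is also an assumption, not confirmed behavior of `SmoothingEnvelope`. The `jax.grad` calls check that the value and the first two derivatives vanish at d = 1 while the third does not.

```python
import jax
import jax.numpy as jnp


def envelope_polynomial(d, p=6):
    """DimeNet envelope polynomial u(d) for scaled distances d = d_ij / cut_off."""
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    return 1.0 + a * d**p + b * d**(p + 1) + c * d**(p + 2)


def envelope(d, p=6):
    """Envelope values; assumed (not confirmed) to be zero for d >= 1."""
    return jnp.where(d < 1.0, envelope_polynomial(d, p), 0.0)


# Hypothetical cut-off and edge distances, in the same length unit.
cut_off = 5.0
distances = jnp.array([0.0, 2.5, 4.9, 5.0, 6.0])
values = envelope(distances / cut_off)  # scale to [0, 1] before applying

# Derivatives of the polynomial at d = 1: u, u' and u'' vanish, u''' does not.
du = jax.grad(envelope_polynomial)
d2u = jax.grad(du)
d3u = jax.grad(d2u)
```

Because u, u′ and u″ all vanish at d = 1, replacing the polynomial by zero beyond the cut-off keeps the envelope twice continuously differentiable there.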

params_dict()

Returns parameters keyed by name for this module and submodules.

state_dict()

Returns state keyed by name for this module and submodules.