jax_md_mod.model.layers.OutputBlock


class OutputBlock(embed_size, out_embed_size=None, num_dense=3, num_targets=1, activation=jax.nn.silu, init_kwargs=None, name='Output', outscale=False)

DimeNet++ Output block.

Predicts per-atom quantities from radial basis function (RBF) embeddings and edge messages.

Methods

__init__(embed_size, out_embed_size=None, num_dense=3, num_targets=1, activation=jax.nn.silu, init_kwargs=None, name='Output', outscale=False)

Initializes an Output block.

Parameters:
  • embed_size – Size of the edge embedding.

  • out_embed_size – Output size of the Linear layers after upsampling.

  • num_dense – Number of dense layers.

  • num_targets – Number of target quantities to predict.

  • activation – Activation function.

  • init_kwargs – Dict of initialization kwargs for the Linear layers.

  • name – Name of the Output block.

  • outscale – If True, scale the output so that the block initially predicts zero energy.
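Assuming this block follows the layer layout of the original DimeNet++ output block, the constructor arguments roughly determine the weight shapes as sketched below. The helper name, the dictionary keys, and the num_radial argument (the RBF width, which is not a constructor argument here and would presumably be inferred from rbf at call time) are all hypothetical. The sketch does not model the out_embed_size=None default, init_kwargs, or outscale.

```python
def output_block_param_shapes(num_radial, embed_size, out_embed_size,
                              num_dense=3, num_targets=1):
    """Hypothetical weight shapes of a DimeNet++-style output block.

    This is an illustrative sketch, not the library's parameter layout.
    """
    shapes = {
        # Bias-free projection of the RBF features onto the edge embedding.
        "w_rbf": (num_radial, embed_size),
        # Bias-free up-projection applied after summing messages per atom.
        "w_up": (embed_size, out_embed_size),
    }
    # num_dense activated dense layers, each with a weight matrix and a bias.
    for k in range(num_dense):
        shapes[f"dense_{k}/w"] = (out_embed_size, out_embed_size)
        shapes[f"dense_{k}/b"] = (out_embed_size,)
    # Final bias-free projection to num_targets outputs per atom.
    shapes["w_out"] = (out_embed_size, num_targets)
    return shapes
```

Under this assumption, num_dense and out_embed_size dominate the parameter count, because each hidden layer contributes an out_embed_size × out_embed_size matrix.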

__call__(messages, rbf, idx_i, n_particles, dropout_dict=None)

Returns predicted per-atom quantities.
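The call arguments match the forward pass of the original DimeNet++ output block. The following minimal NumPy sketch assumes that layout: each edge message is gated by a bias-free projection of its RBF features, the gated messages are summed onto their receiving atoms idx_i (giving n_particles rows), and the result is up-projected, passed through the activated dense layers, and projected to num_targets. The function name and parameter-dict keys are hypothetical, dropout_dict is not modeled, and this is not the library's implementation. In JAX, the np.add.at scatter would typically be written as jax.ops.segment_sum(x, idx_i, num_segments=n_particles).

```python
import numpy as np


def silu(x):
    # SiLU (swish) activation: x * sigmoid(x).
    return x / (1.0 + np.exp(-x))


def output_block_sketch(params, messages, rbf, idx_i, n_particles,
                        activation=silu):
    """Hypothetical DimeNet++-style output block forward pass.

    messages: (n_edges, embed_size) edge messages
    rbf:      (n_edges, num_radial) radial basis expansion of each edge
    idx_i:    (n_edges,) index of the atom that receives each edge message
    Returns an (n_particles, num_targets) array of per-atom predictions.
    """
    # Gate each message with a bias-free linear projection of its RBF features.
    x = messages * (rbf @ params["w_rbf"])
    # Scatter-sum edge contributions onto their receiving atoms.
    per_atom = np.zeros((n_particles, x.shape[1]), dtype=x.dtype)
    np.add.at(per_atom, idx_i, x)
    # Bias-free up-projection to out_embed_size.
    h = per_atom @ params["w_up"]
    # Activated dense layers (num_dense of them), each with a bias.
    for w, b in params["dense"]:
        h = activation(h @ w + b)
    # Bias-free projection to num_targets quantities per atom.
    return h @ params["w_out"]


# Usage with random weights: 5 edges pointing into 3 atoms.
rng = np.random.default_rng(0)
n_edges, num_radial, embed, out_embed, n_particles = 5, 6, 4, 8, 3
params = {
    "w_rbf": rng.normal(size=(num_radial, embed)),
    "w_up": rng.normal(size=(embed, out_embed)),
    "dense": [(rng.normal(size=(out_embed, out_embed)), np.zeros(out_embed))
              for _ in range(3)],
    "w_out": rng.normal(size=(out_embed, 1)),
}
messages = rng.normal(size=(n_edges, embed))
rbf = rng.normal(size=(n_edges, num_radial))
idx_i = np.array([0, 0, 1, 1, 1])  # atom 2 receives no messages
per_atom_out = output_block_sketch(params, messages, rbf, idx_i, n_particles)
```

Because the aggregation is a plain sum, an atom that receives no messages gets zero features. With zero biases and silu(0) = 0, its prediction in this sketch is exactly zero.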

params_dict()

Returns parameters keyed by name for this module and submodules.

state_dict()

Returns state keyed by name for this module and submodules.