jax_md_mod.model.layers.EmbeddingBlock

class EmbeddingBlock(embed_size, n_species, type_embed_size=None, activation=jax.nn.silu, init_kwargs=None, kbt_dependent=False, name='Embedding')

Embedding block of DimeNet.

Embeds edges by concatenating the RBF embedding of each edge with the atom type embeddings of both atoms it connects. If the network is defined to be kbT-dependent (kbt_dependent=True), a temperature embedding is added as well.
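The edge embedding described above can be illustrated with plain JAX arrays. This is a minimal sketch, not the library's implementation: the parameter layout (an RBF projection, a species lookup table, one output Linear layer), the integer default for type_embed_size, and the way kbT is embedded (projected and concatenated to every edge) are all assumptions. The names init_params and embed_edges are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_rbf, n_species, embed_size, type_embed_size=None,
                kbt_dependent=False):
    """Random weights for the sketch below (hypothetical layout, not the library's)."""
    if type_embed_size is None:
        # Documented default is embed_size / 2; integer division is an assumption.
        type_embed_size = embed_size // 2
    k_rbf, k_type, k_out, k_kbt = jax.random.split(key, 4)
    in_size = embed_size + 2 * type_embed_size
    if kbt_dependent:
        in_size += type_embed_size  # room for the hypothetical temperature embedding
    params = {
        "w_rbf": jax.random.normal(k_rbf, (n_rbf, embed_size)) / jnp.sqrt(n_rbf),
        "b_rbf": jnp.zeros(embed_size),
        "type_table": jax.random.normal(k_type, (n_species, type_embed_size)),
        "w_out": jax.random.normal(k_out, (in_size, embed_size)) / jnp.sqrt(in_size),
        "b_out": jnp.zeros(embed_size),
    }
    if kbt_dependent:
        params["w_kbt"] = jax.random.normal(k_kbt, (1, type_embed_size))
    return params


def embed_edges(params, rbf, species, idx_i, idx_j, kbt=None,
                activation=jax.nn.silu):
    """DimeNet-style edge embedding (illustrative sketch).

    rbf: (n_edges, n_rbf); species: (n_atoms,) ints; idx_i, idx_j: (n_edges,) ints.
    Returns an (n_edges, embed_size) array.
    """
    # Project the radial basis features of each edge.
    rbf_embed = activation(rbf @ params["w_rbf"] + params["b_rbf"])
    # Look up the type embedding of both atoms joined by each edge.
    type_i = params["type_table"][species[idx_i]]
    type_j = params["type_table"][species[idx_j]]
    features = [rbf_embed, type_i, type_j]
    if kbt is not None:
        # Hypothetical temperature embedding: project kbT and broadcast it to all edges.
        kbt_col = jnp.full((rbf.shape[0], 1), kbt)
        features.append(activation(kbt_col @ params["w_kbt"]))
    # Concatenate and mix into the final edge embedding.
    h = jnp.concatenate(features, axis=-1)
    return activation(h @ params["w_out"] + params["b_out"])


key = jax.random.PRNGKey(0)
rbf = jnp.ones((4, 6))            # 4 edges, 6 radial basis functions
species = jnp.array([0, 1, 2])    # 3 atoms of 3 different species
idx_i = jnp.array([0, 0, 1, 2])
idx_j = jnp.array([1, 2, 0, 0])
params = init_params(key, n_rbf=6, n_species=3, embed_size=8)
edges = embed_edges(params, rbf, species, idx_i, idx_j)
print(edges.shape)  # (4, 8)
```

Concatenating rather than summing keeps the RBF and type information in separate channels until the output layer mixes them; whether the real block adds an extra projection before or after the concatenation is not stated here.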

Methods

__init__(embed_size, n_species, type_embed_size=None, activation=jax.nn.silu, init_kwargs=None, kbt_dependent=False, name='Embedding')

Initializes an Embedding block.

Parameters:
  • embed_size – Size of the edge embedding.

  • n_species – Number of distinct atom species the network is expected to process.

  • type_embed_size – Size of the atom type embedding. The default, None, uses embed_size / 2.

  • activation – Activation function (default: jax.nn.silu).

  • init_kwargs – Dict of initialization kwargs for the Linear layers.

  • kbt_dependent – Boolean; whether the network prediction should depend on temperature.

  • name – Name of the Embedding block.
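The sizing parameters above interact: type_embed_size falls back to embed_size / 2, and kbt_dependent adds a temperature channel. The hypothetical helper below resolves those sizes and the width of the concatenated per-edge features, assuming the RBF projection has size embed_size, both atom type embeddings are concatenated, and a temperature embedding of size type_embed_size is appended when kbt_dependent is True. The helper name and the integer division are assumptions, not part of the documented API.

```python
def embedding_sizes(embed_size, type_embed_size=None, kbt_dependent=False):
    """Hypothetical helper: resolve the type-embedding size and the width of the
    concatenated per-edge features seen by the final Linear layer."""
    if type_embed_size is None:
        # Documented default: embed_size / 2 (integer division is an assumption).
        type_embed_size = embed_size // 2
    # RBF projection (embed_size) + type embeddings of atoms i and j.
    concat_width = embed_size + 2 * type_embed_size
    if kbt_dependent:
        # Assumed: a temperature embedding of size type_embed_size is appended.
        concat_width += type_embed_size
    return type_embed_size, concat_width


print(embedding_sizes(128))                      # (64, 256)
print(embedding_sizes(128, type_embed_size=32))  # (32, 192)
print(embedding_sizes(128, kbt_dependent=True))  # (64, 320)
```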

__call__(rbf, species, idx_i, idx_j, dropout_dict=None, **kwargs)

Returns the edge embeddings computed by the Embedding block.
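The inputs rbf, idx_i and idx_j describe a graph of edges between atoms. The sketch below shows one way to build them from positions, assuming a dense all-pairs cutoff search without periodic boundaries, idx_i and idx_j as the two endpoints of each directed edge, and the DimeNet radial Bessel basis sqrt(2/c) sin(nπr/c)/r without its smooth envelope. Whether jax_md_mod uses this exact basis or index convention is an assumption; edge_list and bessel_rbf are hypothetical helpers.

```python
import jax.numpy as jnp


def edge_list(positions, cutoff):
    """Ordered pairs (i, j), i != j, within `cutoff` (dense O(N^2) search, no PBC)."""
    dist = jnp.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
    mask = (dist < cutoff) & ~jnp.eye(positions.shape[0], dtype=bool)
    # jnp.nonzero has a data-dependent output size, so call it outside jax.jit.
    idx_i, idx_j = jnp.nonzero(mask)
    return idx_i, idx_j, dist[idx_i, idx_j]


def bessel_rbf(r, cutoff, n_rbf=6):
    """DimeNet radial Bessel basis sqrt(2/c) * sin(n*pi*r/c) / r, n = 1..n_rbf (no envelope)."""
    n = jnp.arange(1, n_rbf + 1)
    return jnp.sqrt(2.0 / cutoff) * jnp.sin(n * jnp.pi * r[:, None] / cutoff) / r[:, None]


positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 3.0, 0.0]])
species = jnp.array([0, 1, 1])
idx_i, idx_j, r = edge_list(positions, cutoff=2.0)
rbf = bessel_rbf(r, cutoff=2.0)
# Only atoms 0 and 1 are within the cutoff: idx_i = [0, 1], idx_j = [1, 0].
print(rbf.shape)  # (2, 6)
```

The resulting rbf has shape (n_edges, n_rbf) and idx_i, idx_j have shape (n_edges,), which is the layout assumed for the rbf, idx_i and idx_j arguments; species holds one integer type per atom.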

params_dict()

Returns parameters keyed by name for this module and submodules.

state_dict()

Returns state keyed by name for this module and submodules.