jax_md_mod.model.layers.EmbeddingBlock
- class EmbeddingBlock(embed_size, n_species, type_embed_size=None, activation=jax.nn.silu, init_kwargs=None, kbt_dependent=False, name='Embedding')
Embedding block of DimeNet.
Embeds edges by concatenating radial basis function (RBF) embeddings with the atom type embeddings of both connected atoms. If the block is created with kbt_dependent=True, it additionally adds a temperature (kbT) embedding.
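The docstring only states that the RBF features and the type embeddings of both atoms are concatenated. The following is a minimal NumPy sketch of that idea, not the jax_md_mod implementation. It assumes, following the original DimeNet reference implementation, that the RBF features are first projected by a dense layer and that the concatenated vector passes through one more dense layer with SiLU activation. The helper names `silu` and `embed_edges`, the `params` keys, and all shapes are hypothetical, and the temperature embedding used when `kbt_dependent=True` is omitted.

```python
import numpy as np


def silu(x):
    # SiLU (swish) activation: x * sigmoid(x) == x / (1 + exp(-x))
    return x / (1.0 + np.exp(-x))


def embed_edges(rbf, species, idx_i, idx_j, params):
    """Hypothetical DimeNet-style edge embedding (sketch, not the library code).

    rbf:          (n_edges, n_rbf) radial basis features per edge
    species:      (n_atoms,) integer atom types
    idx_i, idx_j: (n_edges,) indices of the two atoms forming each edge
    """
    # Assumed: project the RBF features to embed_size first.
    rbf_emb = silu(rbf @ params["w_rbf"] + params["b_rbf"])  # (n_edges, embed_size)
    # Look up one type embedding per atom from a learnable table.
    type_emb = params["type_table"][species]                 # (n_atoms, type_embed_size)
    # Gather the type embeddings of both atoms of every edge.
    h_i = type_emb[idx_i]                                     # (n_edges, type_embed_size)
    h_j = type_emb[idx_j]                                     # (n_edges, type_embed_size)
    # Concatenate [h_i, h_j, rbf_emb] and mix them with a final dense layer.
    x = np.concatenate([h_i, h_j, rbf_emb], axis=-1)
    return silu(x @ params["w_out"] + params["b_out"])        # (n_edges, embed_size)


# Usage with random, purely illustrative weights.
rng = np.random.default_rng(0)
embed_size, n_species, n_rbf = 8, 3, 6
type_embed_size = embed_size // 2  # mirrors the documented default embed_size / 2
params = {
    "w_rbf": rng.normal(size=(n_rbf, embed_size)),
    "b_rbf": np.zeros(embed_size),
    "type_table": rng.normal(size=(n_species, type_embed_size)),
    "w_out": rng.normal(size=(2 * type_embed_size + embed_size, embed_size)),
    "b_out": np.zeros(embed_size),
}
species = np.array([0, 1, 2, 1])       # 4 atoms
idx_i = np.array([0, 0, 1, 2, 3])      # 5 edges
idx_j = np.array([1, 2, 3, 3, 0])
rbf = rng.normal(size=(len(idx_i), n_rbf))

edges = embed_edges(rbf, species, idx_i, idx_j, params)
print(edges.shape)  # (5, 8)
```

Because h_i and h_j are concatenated in a fixed order, the edge embedding is direction-dependent: the edge (i, j) and its reverse (j, i) generally receive different embeddings, which matches DimeNet's use of directed messages.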
Methods
- __init__(embed_size, n_species, type_embed_size=None, activation=jax.nn.silu, init_kwargs=None, kbt_dependent=False, name='Embedding')
Initializes an Embedding block.
- Parameters:
embed_size – Size of the edge embedding.
n_species – Number of distinct atom species the network processes.
type_embed_size – Size of the atom type embedding. If None (default), embed_size / 2 is used.
activation – Activation function. Defaults to jax.nn.silu.
init_kwargs – Dictionary of initialization kwargs for the Linear layers.
kbt_dependent – Boolean; whether the network prediction depends on temperature (kbT).
name – Name of the Embedding block.
- __call__(rbf, species, idx_i, idx_j, dropout_dict=None, **kwargs)
Returns the output of the Embedding block, i.e. the edge embeddings.
- params_dict()
Returns parameters keyed by name for this module and submodules.
- state_dict()
Returns state keyed by name for this module and submodules.