jax_md_mod.model.layers.OutputBlock
- class OutputBlock(embed_size, out_embed_size=None, num_dense=3, num_targets=1, activation=jax.nn.silu, init_kwargs=None, name='Output', outscale=False)
DimeNet++ output block.
Predicts per-atom quantities from radial basis function (RBF) embeddings and edge messages.
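The data flow can be pictured with the minimal sketch below. It assumes OutputBlock follows the output block of the original DimeNet++ model: each message is gated by a projection of its RBF features, summed onto its atom, upsampled to `out_embed_size`, passed through `num_dense` activated layers, and projected to `num_targets` outputs. It is written in plain JAX rather than in the module's own framework, and the helper names (`dense`, `init_output_params`, `output_block_forward`) and the parameter layout are hypothetical. The `zero_final` flag mimics the zero-initialized final layer of DimeNet++; whether `outscale` works this way is an assumption.

```python
import jax
import jax.numpy as jnp


def dense(layer, x):
    """Affine map; a layer without a "b" entry is bias-free."""
    y = x @ layer["w"]
    return y + layer["b"] if "b" in layer else y


def init_output_params(key, n_rbf, embed_size, out_embed_size, num_dense,
                       num_targets, zero_final=True):
    """Hypothetical parameter pytree for the sketch (not the library's)."""
    keys = jax.random.split(key, num_dense + 3)
    glorot = jax.nn.initializers.glorot_uniform()
    params = {
        # Bias-free projection of RBF features to the message width.
        "rbf": {"w": glorot(keys[0], (n_rbf, embed_size))},
        # Bias-free upsampling from embed_size to out_embed_size.
        "up": {"w": glorot(keys[1], (embed_size, out_embed_size))},
        # num_dense hidden layers with bias.
        "dense": [
            {"w": glorot(k, (out_embed_size, out_embed_size)),
             "b": jnp.zeros(out_embed_size)}
            for k in keys[2:2 + num_dense]
        ],
    }
    final_shape = (out_embed_size, num_targets)
    # ASSUMPTION: a zero final layer stands in for "outscale".
    params["final"] = {"w": jnp.zeros(final_shape) if zero_final
                       else glorot(keys[-1], final_shape)}
    return params


def output_block_forward(params, messages, rbf, idx_i, n_particles,
                         activation=jax.nn.silu):
    """DimeNet++-style output block sketch; returns (n_particles, num_targets)."""
    x = dense(params["rbf"], rbf) * messages                      # (n_edges, embed_size)
    x = jax.ops.segment_sum(x, idx_i, num_segments=n_particles)  # per-atom sums
    x = dense(params["up"], x)                                    # (n_particles, out_embed_size)
    for layer in params["dense"]:
        x = activation(dense(layer, x))
    return dense(params["final"], x)                              # (n_particles, num_targets)


key = jax.random.PRNGKey(0)
k_params, k_msg, k_rbf = jax.random.split(key, 3)
n_edges, n_rbf, embed_size = 5, 6, 8
params = init_output_params(k_params, n_rbf, embed_size, out_embed_size=16,
                            num_dense=3, num_targets=1)
messages = jax.random.normal(k_msg, (n_edges, embed_size))
rbf = jax.random.normal(k_rbf, (n_edges, n_rbf))
idx_i = jnp.array([0, 0, 1, 2, 2])  # receiving atom of each edge
per_atom = output_block_forward(params, messages, rbf, idx_i, n_particles=3)
print(per_atom.shape)  # (3, 1)
```

With the zero-initialized final layer every per-atom prediction starts at exactly zero, so a summed energy also starts at zero; training then moves the final weights away from zero.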
Methods
- __init__(embed_size, out_embed_size=None, num_dense=3, num_targets=1, activation=jax.nn.silu, init_kwargs=None, name='Output', outscale=False)
Initializes an Output block.
- Parameters:
embed_size – Size of the edge embedding.
out_embed_size – Output size of the Linear layers after upsampling.
num_dense – Number of dense layers.
num_targets – Number of target quantities to predict.
activation – Activation function.
init_kwargs – Dict of initialization keyword arguments for the Linear layers.
name – Name of the Output block.
outscale – If True, scale the output so that the block initially predicts zero energy.
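To see how the constructor arguments might fix the layer widths, the hypothetical helper below enumerates them under the same DimeNet++ assumption. It only lists shapes. The number of radial basis functions `n_rbf` is not a constructor argument (it would come from the `rbf` input at call time), so it is passed in explicitly, and the fallback of `out_embed_size=None` to `embed_size` is a guess rather than documented behavior.

```python
def output_block_layer_shapes(embed_size, n_rbf, out_embed_size=None,
                              num_dense=3, num_targets=1):
    """Return (layer_name, in_features, out_features) tuples.

    Layer names are hypothetical labels for a DimeNet++-style output block.
    """
    # ASSUMPTION: None falls back to embed_size; the docs do not say.
    if out_embed_size is None:
        out_embed_size = embed_size
    layers = [("rbf_projection", n_rbf, embed_size),
              ("upsample", embed_size, out_embed_size)]
    # num_dense hidden layers keep the upsampled width.
    layers += [(f"dense_{k}", out_embed_size, out_embed_size)
               for k in range(num_dense)]
    # Final projection to num_targets values per atom.
    layers.append(("final", out_embed_size, num_targets))
    return layers


for layer in output_block_layer_shapes(embed_size=128, n_rbf=6,
                                       out_embed_size=256, num_dense=3):
    print(layer)
# ('rbf_projection', 6, 128)
# ('upsample', 128, 256)
# ('dense_0', 256, 256)
# ('dense_1', 256, 256)
# ('dense_2', 256, 256)
# ('final', 256, 1)
```

Only `num_dense` changes the depth; `out_embed_size` widens the layers after the atom-wise sum, and `num_targets` sets how many values each atom receives.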
- __call__(messages, rbf, idx_i, n_particles, dropout_dict=None)
Returns predicted per-atom quantities.
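The `idx_i` and `n_particles` arguments suggest that per-edge contributions are summed onto their receiving atoms before the per-atom layers run. The aggregation step can be sketched with `jax.ops.segment_sum`, assuming `idx_i[e]` is the index of the atom that edge `e` contributes to; the array values are illustrative only.

```python
import jax.numpy as jnp
from jax.ops import segment_sum

# Three edges: edges 0 and 2 feed atom 0, edge 1 feeds atom 2; atom 1 has none.
edge_values = jnp.array([[1.0], [2.0], [3.0]])
idx_i = jnp.array([0, 2, 0])
n_particles = 3

# num_segments fixes the output length, so atoms without edges still appear.
per_atom = segment_sum(edge_values, idx_i, num_segments=n_particles)
print(per_atom.tolist())      # [[4.0], [0.0], [2.0]]
print(float(per_atom.sum()))  # 6.0
```

Passing `n_particles` as `num_segments` keeps the output shape static, which matters under `jax.jit`, and atoms with no incoming edges receive zero.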
- params_dict()
Returns parameters keyed by name for this module and submodules.
- state_dict()
Returns state keyed by name for this module and submodules.