jax_md_mod.model.layers.InteractionBlock


class InteractionBlock(embed_size, num_res_before_skip, num_res_after_skip, activation=<PjitFunction of <function silu>>, init_kwargs=None, angle_int_embed_size=None, basis_int_embed_size=8, name='Interaction')

DimeNet++ Interaction block.

Performs directional message passing based on radial basis function (RBF) and spherical basis function (SBF) embeddings, together with the messages from the previous message-passing iteration. The updated messages are used in the subsequent Output block.

Methods

__init__(embed_size, num_res_before_skip, num_res_after_skip, activation=<PjitFunction of <function silu>>, init_kwargs=None, angle_int_embed_size=None, basis_int_embed_size=8, name='Interaction')

Initializes an Interaction block.

Parameters:
  • embed_size – Size of the edge embedding.

  • num_res_before_skip – Number of Residual blocks before the skip connection.

  • num_res_after_skip – Number of Residual blocks after the skip connection.

  • activation – Activation function.

  • init_kwargs – Dict of initialization kwargs for Linear layers.

  • angle_int_embed_size – Embedding size of Linear layers for the down-projected triplet interaction.

  • basis_int_embed_size – Embedding size of Linear layers for the interaction of the RBF/SBF basis.

  • name – Name of the Interaction block.
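
To illustrate how num_res_before_skip and num_res_after_skip might shape the block, the following is a minimal sketch in plain JAX. It is not the library's implementation: the layout follows the published DimeNet++ architecture (residual blocks, a linear layer with a skip connection back to the input messages, then more residual blocks), and the helper names residual, interaction_tail, and init_params as well as the parameter dictionary layout are hypothetical.

```python
import jax
import jax.numpy as jnp


def residual(x, w1, w2, activation=jax.nn.silu):
    # Hypothetical residual block: x + act(act(x @ W1) @ W2).
    return x + activation(activation(x @ w1) @ w2)


def interaction_tail(h, m_input, params, activation=jax.nn.silu):
    # Residual blocks applied before the skip connection.
    for w1, w2 in params["before_skip"]:
        h = residual(h, w1, w2, activation)
    # Linear layer followed by the skip connection to the input messages.
    h = activation(h @ params["skip_linear"]) + m_input
    # Residual blocks applied after the skip connection.
    for w1, w2 in params["after_skip"]:
        h = residual(h, w1, w2, activation)
    return h


def init_params(key, embed_size, num_res_before_skip, num_res_after_skip):
    # Assumed layout: square weight matrices, no biases, small random init.
    n_mats = 2 * (num_res_before_skip + num_res_after_skip) + 1
    keys = iter(jax.random.split(key, n_mats))

    def mat():
        return 0.1 * jax.random.normal(next(keys), (embed_size, embed_size))

    return {
        "before_skip": [(mat(), mat()) for _ in range(num_res_before_skip)],
        "skip_linear": mat(),
        "after_skip": [(mat(), mat()) for _ in range(num_res_after_skip)],
    }


embed_size = 8
params = init_params(jax.random.PRNGKey(0), embed_size, 1, 2)
m_input = jnp.ones((5, embed_size))  # 5 edges, one message per edge
out = interaction_tail(m_input, m_input, params)
```

In this sketch the output keeps the shape of the input messages, (n_edges, embed_size), so blocks of this form can be stacked.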

__call__(m_input, rbf, sbf, reduce_to_ji, expand_to_kj, dropout_dict=None)

Returns messages after interaction via message-passing.
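
The argument names reduce_to_ji and expand_to_kj suggest index arrays that map edge messages onto triplets and back. The sketch below shows one way such a gather-and-sum step could look in JAX. It is an assumption about the semantics, not the library's code: the function aggregate_triplets, the per-triplet weights sbf_weights, and the toy index arrays are hypothetical.

```python
import jax
import jax.numpy as jnp


def aggregate_triplets(m, sbf_weights, expand_to_kj, reduce_to_ji, num_edges):
    # Gather the incoming message k->j for every triplet (k->j->i).
    m_triplet = m[expand_to_kj]              # (n_triplets, embed_size)
    # Modulate each gathered message by its per-triplet angular weight.
    m_triplet = m_triplet * sbf_weights      # (n_triplets, embed_size)
    # Sum the triplet contributions onto their target edge j->i.
    return jax.ops.segment_sum(m_triplet, reduce_to_ji, num_segments=num_edges)


# Toy example: 3 edges, 4 triplets, embedding size 2.
m = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
expand_to_kj = jnp.array([1, 2, 0, 2])   # source edge of each triplet
reduce_to_ji = jnp.array([0, 0, 1, 2])   # target edge of each triplet
sbf_weights = jnp.ones((4, 2))           # placeholder for SBF-derived weights

out = aggregate_triplets(m, sbf_weights, expand_to_kj, reduce_to_ji, num_edges=3)
```

With unit weights, edge 0 receives the sum of messages 1 and 2, while edges 1 and 2 each receive a single message, so the result has one aggregated row per edge.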

params_dict()

Returns parameters keyed by name for this module and submodules.

state_dict()

Returns state keyed by name for this module and submodules.