jax_md_mod.model.layers.InteractionBlock
- class InteractionBlock(embed_size, num_res_before_skip, num_res_after_skip, activation=<PjitFunction of <function silu>>, init_kwargs=None, angle_int_embed_size=None, basis_int_embed_size=8, name='Interaction')
DimeNet++ Interaction block.
Performs directional message passing using the radial basis function (RBF) and spherical basis function (SBF) embeddings together with the messages from the previous message-passing iteration. The updated messages are passed on to the subsequent Output block.
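The directional update can be illustrated with a minimal NumPy sketch of a DimeNet++-style triplet interaction. This is not the jax_md_mod implementation. Assumptions: biases are omitted, the weight names in `p` and all sizes are hypothetical, and `expand_to_kj` / `reduce_to_ji` are taken to hold, for each triplet k -> j -> i, the index of its kj edge and of its ji edge.

```python
import numpy as np

def silu(x):
    # Numerically stable SiLU: x * sigmoid(x), with sigmoid(x) = 0.5 * (1 + tanh(x / 2))
    return x * 0.5 * (1.0 + np.tanh(0.5 * x))

def directional_update(m, rbf, sbf, expand_to_kj, reduce_to_ji, p):
    """DimeNet++-style triplet interaction (bias-free sketch, hypothetical weights p).

    m:            (n_edges, embed_size) messages from the previous iteration
    rbf:          (n_edges, n_rbf) radial basis values per edge
    sbf:          (n_triplets, n_sbf) spherical basis values per triplet
    expand_to_kj: (n_triplets,) index of the kj edge of each triplet (assumed)
    reduce_to_ji: (n_triplets,) index of the ji edge of each triplet (assumed)
    """
    x_ji = silu(m @ p["ji"])
    x_kj = silu(m @ p["kj"])
    # Gate kj messages with the RBF, projected through a basis_int_embed_size bottleneck
    x_kj = x_kj * (rbf @ p["rbf1"] @ p["rbf2"])
    # Down-project to the triplet interaction size (angle_int_embed_size)
    x_kj = silu(x_kj @ p["down"])
    # Gather each triplet's kj message and modulate it with the projected SBF
    t = x_kj[expand_to_kj] * (sbf @ p["sbf1"] @ p["sbf2"])
    # Sum triplet contributions onto their ji edge
    agg = np.zeros((m.shape[0], t.shape[1]))
    np.add.at(agg, reduce_to_ji, t)
    # Up-project back to embed_size and combine with the ji branch
    return x_ji + silu(agg @ p["up"])

# Hypothetical sizes: 4 edges, 3 triplets, embed 16, basis 8, angle 8, 6 RBF, 42 SBF
rng = np.random.default_rng(0)
E, T, D, B, A, R, S = 4, 3, 16, 8, 8, 6, 42
shapes = {"ji": (D, D), "kj": (D, D), "rbf1": (R, B), "rbf2": (B, D),
          "down": (D, A), "sbf1": (S, B), "sbf2": (B, A), "up": (A, D)}
# Fan-in scaled random weights keep activations O(1)
p = {k: rng.normal(size=s) / np.sqrt(s[0]) for k, s in shapes.items()}
m = rng.normal(size=(E, D))
rbf = rng.uniform(size=(E, R))
sbf = rng.uniform(size=(T, S))
expand_to_kj = np.array([1, 2, 3])
reduce_to_ji = np.array([0, 0, 1])  # edges 2 and 3 receive no triplets
out = directional_update(m, rbf, sbf, expand_to_kj, reduce_to_ji, p)
```

Edges that appear in no triplet keep only their own ji branch, because the aggregated term is zero and silu(0) = 0.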
Methods
- __init__(embed_size, num_res_before_skip, num_res_after_skip, activation=<PjitFunction of <function silu>>, init_kwargs=None, angle_int_embed_size=None, basis_int_embed_size=8, name='Interaction')
Initializes an Interaction block.
- Parameters:
embed_size – Size of the edge embedding.
num_res_before_skip – Number of residual blocks before the skip connection.
num_res_after_skip – Number of residual blocks after the skip connection.
activation – Activation function.
init_kwargs – Dict of initialization kwargs for the Linear layers.
angle_int_embed_size – Embedding size of the Linear layers for the down-projected triplet interaction.
basis_int_embed_size – Embedding size of the Linear layers for the interaction of the RBF/SBF bases.
name – Name of the Interaction block.
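The two residual-count parameters can be read as follows, assuming the block follows the DimeNet++ layout: residual layers, then a dense layer with a skip connection back to the block's input messages, then more residual layers. This is a hedged NumPy sketch; the function names and weight tuples are hypothetical and biases are omitted.

```python
import numpy as np

def silu(x):
    # SiLU via the tanh form of the sigmoid, stable for large |x|
    return x * 0.5 * (1.0 + np.tanh(0.5 * x))

def residual(h, w1, w2):
    # Residual layer: h + act(act(h @ W1) @ W2)
    return h + silu(silu(h @ w1) @ w2)

def post_interaction(h, m_input, before, after, w_skip):
    """h: combined messages; m_input: messages entering the block."""
    for w1, w2 in before:            # len(before) == num_res_before_skip
        h = residual(h, w1, w2)
    h = silu(h @ w_skip) + m_input   # skip connection to the block input
    for w1, w2 in after:             # len(after) == num_res_after_skip
        h = residual(h, w1, w2)
    return h

# Hypothetical usage: num_res_before_skip=1, num_res_after_skip=2, embed_size=16
rng = np.random.default_rng(1)
D = 16

def w():
    return rng.normal(size=(D, D)) / np.sqrt(D)

before = [(w(), w()) for _ in range(1)]
after = [(w(), w()) for _ in range(2)]
m_input = rng.normal(size=(5, D))
h = rng.normal(size=(5, D))
out = post_interaction(h, m_input, before, after, w())
```

With all-zero weights every residual layer reduces to the identity and the dense branch vanishes, so the output equals `m_input`; this makes the role of the skip connection easy to check.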
- __call__(m_input, rbf, sbf, reduce_to_ji, expand_to_kj, dropout_dict=None)
Returns the updated messages after directional message passing.
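The index arguments are not described further here. One plausible way to build them is sketched below, assuming directed edges j -> i stored as hypothetical `senders` / `receivers` arrays and assuming `expand_to_kj[t]` and `reduce_to_ji[t]` give the kj and ji edge of triplet t. The actual conventions in jax_md_mod may differ.

```python
import numpy as np

def triplet_indices(senders, receivers):
    """Enumerate triplets k -> j -> i with k != i over directed edges.

    Edge e points from senders[e] to receivers[e]. Returns, per triplet,
    the index of its kj edge and of its ji edge. O(n_edges**2) sketch.
    """
    expand_to_kj, reduce_to_ji = [], []
    for e_ji, (j, i) in enumerate(zip(senders, receivers)):
        for e_kj, (k, j_recv) in enumerate(zip(senders, receivers)):
            # kj must end where ji starts, and must not come back from i
            if j_recv == j and k != i:
                expand_to_kj.append(e_kj)
                reduce_to_ji.append(e_ji)
    return (np.array(expand_to_kj, dtype=np.int32),
            np.array(reduce_to_ji, dtype=np.int32))

# Three atoms, all six directed edges
senders = np.array([0, 1, 0, 2, 1, 2])
receivers = np.array([1, 0, 2, 0, 2, 1])
expand_to_kj, reduce_to_ji = triplet_indices(senders, receivers)
```

Under these assumptions `m_input` and `rbf` are indexed per edge and `sbf` per triplet, so `sbf` would have one row per entry of `expand_to_kj`.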
- params_dict()
Returns parameters keyed by name for this module and submodules.
- state_dict()
Returns state keyed by name for this module and submodules.