jax_md_mod.model.layers.ResidualLayer

class ResidualLayer(layer_size, activation=<PjitFunction of <function silu>>, init_kwargs=None, name='ResLayer')

Residual layer: two activated Linear layers whose output is added back to the input through a skip connection.
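
The computation described above can be sketched in plain JAX. This is an illustration only, not the library's implementation: the helper names `init_residual_params` and `residual_forward`, the weight initialization, and the assumption that the input's last dimension equals `layer_size` (so the skip connection can be added element-wise) are all hypothetical.

```python
import jax
import jax.numpy as jnp


def init_residual_params(key, layer_size):
    """Create weights and biases for two square Linear layers (hypothetical helper)."""
    k1, k2 = jax.random.split(key)
    scale = 1.0 / jnp.sqrt(layer_size)  # illustrative scaling, not the library default
    return {
        "w1": scale * jax.random.normal(k1, (layer_size, layer_size)),
        "b1": jnp.zeros(layer_size),
        "w2": scale * jax.random.normal(k2, (layer_size, layer_size)),
        "b2": jnp.zeros(layer_size),
    }


def residual_forward(params, inputs, activation=jax.nn.silu):
    """Apply two activated Linear layers, then add the input back (skip connection)."""
    hidden = activation(inputs @ params["w1"] + params["b1"])
    hidden = activation(hidden @ params["w2"] + params["b2"])
    return inputs + hidden


key = jax.random.PRNGKey(0)
params = init_residual_params(key, layer_size=8)
x = jnp.ones((4, 8))  # batch of 4 feature vectors of width layer_size
y = residual_forward(params, x)  # same shape as x
```

Because the skip connection adds the input to the transformed features, the output in this sketch has the same shape as the input, and the block reduces to the identity when both Linear layers output zero.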

Methods

__init__(layer_size, activation=<PjitFunction of <function silu>>, init_kwargs=None, name='ResLayer')

Initializes the Residual layer.

Parameters:
  • layer_size – Output size of the Linear layers

  • activation – Activation function

  • init_kwargs – Dict of initialization kwargs for Linear layers

  • name – Name of the Residual layer

__call__(inputs, dropout_dict=None)

Returns the output of the Residual layer.

params_dict()

Returns parameters keyed by name for this module and submodules.

state_dict()

Returns state keyed by name for this module and submodules.