jax_md_mod.model.layers.ResidualLayer
- class ResidualLayer(layer_size, activation=jax.nn.silu, init_kwargs=None, name='ResLayer')
Residual Layer: 2 activated Linear layers and a skip connection.
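Written out, one plausible reading of "two activated Linear layers and a skip connection" is given below. Here $\phi$ is `activation` (SiLU by default), $W_1, b_1$ and $W_2, b_2$ are the weights and biases of the two Linear layers, and $x$ is assumed to already have `layer_size` features so that the sum is well defined. Where exactly the activations and the skip sit is an assumption, so check the linked source before relying on it.

```latex
h = \phi\left(W_1 x + b_1\right), \qquad
y = x + \phi\left(W_2 h + b_2\right), \qquad
\phi(z) = \operatorname{SiLU}(z) = z\,\mathrm{sigmoid}(z) = \frac{z}{1 + e^{-z}}
```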
Methods
- __init__(layer_size, activation=jax.nn.silu, init_kwargs=None, name='ResLayer')
Initializes the Residual layer.
- Parameters:
layer_size – Output size of the Linear layers
activation – Activation function
init_kwargs – Dict of initialization keyword arguments for the Linear layers
name – Name of the Residual layer
Returns parameters keyed by name for this module and submodules.
Returns state keyed by name for this module and submodules.
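The following hypothetical NumPy stand-in mirrors the documented signature (`layer_size`, `activation`, `init_kwargs`, `name`) to show what each constructor argument is for. It is a sketch, not the jax_md_mod implementation. The `stddev` key in `init_kwargs`, the `init`/`apply` methods, and the `"<name>/linear_<i>"` parameter keys are all invented for illustration. The forward pass assumes the skip connection adds the input to the output of the second activated layer.

```python
import numpy as np


def silu(x):
    # SiLU (swish): x * sigmoid(x), mirroring the documented default activation
    return x / (1.0 + np.exp(-x))


class ResidualLayerSketch:
    """Hypothetical NumPy stand-in mirroring ResidualLayer's constructor.

    Not the jax_md_mod implementation. The "stddev" key in init_kwargs and
    the "<name>/linear_<i>" parameter keys are made up for illustration.
    """

    def __init__(self, layer_size, activation=silu, init_kwargs=None,
                 name="ResLayer"):
        self.layer_size = layer_size
        self.activation = activation
        # Assumption: the same init kwargs are shared by both Linear layers.
        self.init_kwargs = {} if init_kwargs is None else dict(init_kwargs)
        self.name = name

    def init(self, rng, in_size):
        # x + f(x) only works if the input already has layer_size features
        # (assuming no projection on the skip path).
        if in_size != self.layer_size:
            raise ValueError(
                f"in_size ({in_size}) must equal layer_size "
                f"({self.layer_size}) for the skip connection")
        # Hypothetical "stddev" kwarg; falls back to 1/sqrt(fan_in).
        stddev = self.init_kwargs.get("stddev", 1.0 / np.sqrt(in_size))
        params = {}
        for i in range(2):
            params[f"{self.name}/linear_{i}"] = {
                "w": rng.normal(scale=stddev,
                                size=(self.layer_size, self.layer_size)),
                "b": np.zeros(self.layer_size),
            }
        return params

    def apply(self, params, x):
        h = x
        for i in range(2):
            p = params[f"{self.name}/linear_{i}"]
            h = self.activation(h @ p["w"] + p["b"])  # activated Linear layer
        return x + h  # skip connection


layer = ResidualLayerSketch(8, init_kwargs={"stddev": 0.1})
params = layer.init(np.random.default_rng(0), in_size=8)
y = layer.apply(params, np.ones((2, 8)))
print(sorted(params))  # ['ResLayer/linear_0', 'ResLayer/linear_1']
print(y.shape)         # (2, 8)
```

This sketch makes two design assumptions: it reuses one `init_kwargs` dict for both Linear layers, and it rejects inputs whose width differs from `layer_size`. A real implementation might instead project the input on the skip path.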