jax_md_mod.model.layers.OrthogonalVarianceScalingInit

class OrthogonalVarianceScalingInit(scale=2.0)

Initializer scaling variance of uniform orthogonal matrix distribution.

Generates a weight matrix whose variance follows Glorot initialization, starting from a random (semi-)orthogonal matrix. Neural networks are expected to learn better when features are decorrelated, as stated e.g. in “Reducing overfitting in deep networks by decorrelating representations”.

The approach is adopted from the original DimeNet and the implementation is inspired by Haiku’s variance scaling initializer.
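The idea described above can be illustrated with a minimal sketch in plain JAX. This is not the jax_md_mod implementation: the function name `orthogonal_variance_scaling` and its explicit `key` argument are hypothetical, and the target variance `scale / (fan_in + fan_out)` is an assumption based on the Glorot rule for a 2D weight matrix. It relies only on `jax.nn.initializers.orthogonal` to draw the (semi-)orthogonal matrix.

```python
import jax
import jax.numpy as jnp


def orthogonal_variance_scaling(key, shape, scale=2.0, dtype=jnp.float32):
    """Hypothetical helper illustrating the idea; not the library's code."""
    # Assumption: a 2D weight matrix of shape (fan_in, fan_out).
    fan_in, fan_out = shape
    # Draw a random (semi-)orthogonal matrix.
    w = jax.nn.initializers.orthogonal()(key, shape, dtype)
    # Assumed Glorot-style target: with scale=2.0 this is 2 / (fan_in + fan_out).
    target_var = scale / (fan_in + fan_out)
    # Rescale so the empirical variance of the entries matches the target.
    return w * jnp.sqrt(target_var / jnp.var(w))


key = jax.random.PRNGKey(0)
w = orthogonal_variance_scaling(key, (64, 32))
```

Because the rescaling is a single scalar factor, the columns of `w` stay mutually orthogonal; only their common norm changes.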

Variables:

scale – Variance scaling factor

Methods

__init__([scale])

Constructs the OrthogonalVarianceScaling Initializer.