
import os
import functools
import contextlib
from pathlib import Path
from urllib import request
import time

import numpy as onp

import jax
import optax
from jax import numpy as jnp, random, tree_util

from jax_md_mod import io, custom_quantity, custom_space
from jax_md_mod.model import layers, neural_networks, prior
from jax_md import simulate, partition, space

import mdtraj

import matplotlib.pyplot as plt

import haiku as hk

from chemtrain.data import preprocessing
from chemtrain.ensemble import sampling
from chemtrain import quantity, trainers, util

out_dir = Path("../_data/output")
out_dir.mkdir(exist_ok=True)

base_path = Path(os.environ.get("DATA_PATH", "./data"))
%config InlineBackend.figure_formats = ['svg']
%matplotlib inline

Alanine Dipeptide in Implicit Water#

This example reproduces the results of the paper Deep Coarse-Grained Potentials via Relative Entropy Minimization [1].

This paper applies the bottom-up coarse-graining approach of Relative Entropy Minimization (REM) to Neural Network (NN) potential models. Using alanine dipeptide in implicit water as an example, the paper compares REM to the conventional Force Matching (FM) scheme.

We outline the theoretical background of the REM and FM schemes in the toy examples on Relative Entropy Minimization and Force Matching. We refer to the original paper for a more detailed analysis of the connection between these approaches.

Problem#

Alanine dipeptide is a good test case for developing coarse-grained protein models. Torsional potentials derived from alanine dipeptide generalize well to larger amino acids [2]. Thus, this example compares two algorithms for learning Graph Neural Network (GNN) potentials by how well they reproduce the torsional preferences of alanine dipeptide.

../_images/alanine_heavy.png

Alanine Dipeptide#

The coarse-grained model of alanine dipeptide only preserves the heavy atoms $\mathrm C$, $\mathrm N$, and $\mathrm O$ but distinguishes between carbon atoms in different chemical environments. Therefore, the coarse-grained model contains the five species $\mathrm{CH_3}, \mathrm{CH}, \mathrm{C}, \mathrm{O}, \mathrm{N}$. Additionally, the GNN does not explicitly consider water and must therefore implicitly account for the interactions with the solvent.

# Set random key and thermodynamic statepoint (300 K)
key = random.PRNGKey(21)
kT = 300. * quantity.kb
state_kwargs = {'kT': kT}
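
As a quick sanity check on units, the thermal energy at 300 K can be evaluated by hand. The sketch below assumes energies in kJ/mol and hard-codes the molar Boltzmann constant instead of importing `quantity.kb`; the names `KB_KJ_PER_MOL_K`, `kT_check`, and `beta_check` are introduced here for illustration only.

```python
# Molar Boltzmann constant (gas constant) in kJ/(mol K).
# Assumption: chemtrain's `quantity.kb` uses the same unit system.
KB_KJ_PER_MOL_K = 0.0083144626

temperature = 300.0                        # K
kT_check = temperature * KB_KJ_PER_MOL_K   # thermal energy in kJ/mol
beta_check = 1.0 / kT_check                # inverse temperature in mol/kJ

print(f"kT = {kT_check:.3f} kJ/mol, beta = {beta_check:.4f} mol/kJ")
```

The result of roughly 2.494 kJ/mol agrees with the `ref_kT` value reported in the training log further below.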

Before defining the potential model, we set up the space in which it acts. To this end, we load an initial conformation, define the corresponding periodic space, and construct a neighbor list with a cutoff radius of $0.5~\text{nm}$.

# To construct a NN model, we first need to create a neighborlist.
# The neighborlist is necessary to construct a graph containing the environment
# of each atom.
box, r_init, _, _ = io.load_box("../_data/alanine_heavy_2_7nm.gro")

n_species = 5
r_cut = 0.5

fractional = True

displacement_fn, shift_fn = space.periodic_general(
    box, fractional_coordinates=fractional)

if fractional:
    box_tensor, scale_fn = custom_space.init_fractional_coordinates(box)
    r_init = scale_fn(r_init)
else:
    box_tensor = box

neighbor_fn = partition.neighbor_list(
    displacement_fn, box_tensor, r_cut, disable_cell_list=True,
    fractional_coordinates=fractional, capacity_multiplier=1.5
)

nbrs_init = neighbor_fn.allocate(r_init, extra_capacity=1)
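
To illustrate what the neighbor list encodes, the following sketch finds all pairs within the cutoff under the minimum-image convention in plain Python. It assumes a cubic box with a hypothetical edge length of 2.7 nm (suggested by the file name, not read from the file) and uses made-up fractional coordinates. The helper `neighbor_pairs` is illustrative and not part of jax_md.

```python
import math

def neighbor_pairs(frac_positions, box_length, r_cut):
    """Return index pairs (i, j), i < j, closer than r_cut in a cubic periodic box."""
    pairs = []
    n = len(frac_positions)
    for i in range(n):
        for j in range(i + 1, n):
            dist_sq = 0.0
            for a, b in zip(frac_positions[i], frac_positions[j]):
                d = a - b
                d -= round(d)  # minimum image in fractional coordinates
                dist_sq += (d * box_length) ** 2
            if math.sqrt(dist_sq) < r_cut:
                pairs.append((i, j))
    return pairs

# Hypothetical fractional coordinates (not taken from the .gro file).
example_positions = [
    (0.01, 0.50, 0.50),  # close to the lower box face
    (0.99, 0.50, 0.50),  # close to the upper face, a periodic neighbor of atom 0
    (0.50, 0.50, 0.50),  # box center, far from both
]
example_pairs = neighbor_pairs(example_positions, box_length=2.7, r_cut=0.5)
print(example_pairs)
```

Atoms 0 and 1 sit on opposite sides of the box but are neighbors through the periodic boundary, which is the case a neighbor list built with a periodic displacement function has to capture.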

To improve the stability of the learned potential, we employ a $\Delta$-learning approach. As a prior, we combine commonly used intramolecular potentials with repulsive non-bonded terms:

\[\begin{split}U^\text{prior}(\mathbf R) = \sum_{(i,j) \in \mathcal B} U_b(|\mathbf r_i - \mathbf r_j|) + \sum_{(i,j,k) \in \mathcal A}U_\alpha(\alpha(\mathbf r_i, \mathbf r_j, \mathbf r_k)) \\ + \sum_{(i,j,k,l) \in \mathcal D}U_\phi(\phi(\mathbf r_j - \mathbf r_i, \mathbf r_l - \mathbf r_k)) + \sum_{(i,j) \notin\mathcal B,\mathcal A, \mathcal D } U_r(|\mathbf r_i - \mathbf r_j|),\end{split}\]

where $\mathcal B, \mathcal A, \mathcal D$ are the index sets of the atoms that form bonds, angles, and dihedral angles, respectively. Note that two atoms do not interact via the nonbonded term if they are involved in the same bond, angle, or dihedral angle.

For the bonds and angles, we choose a harmonic potential $$ U_b(x) = \frac{k}{2}(x - x_0)^2, $$ where we derived the parameters $k$ and $x_0$ for each combination of species from the mean and variance of $x$ in the all-atom (AT) reference simulation, $$ x_0 = \langle x \rangle_\text{AT}, \quad k = \frac{1}{2\beta\langle(x - x_0)^2\rangle_\text{AT}}. $$

For the dihedral angles, we choose a cosine series up to third order, $$ U_\phi(\phi) = \sum_{n=1}^3 k_{\phi, n} (1 + \cos(n\phi - \phi_{0, n})), $$ where we took the values for the force constants $k_{\phi, n}$ and phase shifts $\phi_{0, n}$ from the Amber03 force field [3].

The non-bonded repulsion has the form $$ U_r(x) = \varepsilon\left(\frac{\sigma}{x}\right)^{12}. $$ We computed the pairwise parameters $\varepsilon_{ij} = \sqrt{\varepsilon_{ii}\varepsilon_{jj}}$ and $\sigma_{ij} = \frac{\sigma_{ii} + \sigma_{jj}}{2}$ via the Lorentz-Berthelot combining rules [4] from the Amber03 parameters.
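
The combining rules and the cosine series can be made concrete with a few lines of Python. The sketch below copies the C and O entries of the `atomtypes` table and the C-N-CA-C entries of the `dihedraltypes` table from the parameter file printed later in this example. It assumes lengths in nm, energies in kJ/mol, and phases in degrees; `lorentz_berthelot` and `dihedral_energy` are illustrative helpers, not chemtrain functions.

```python
import math

def lorentz_berthelot(sigma_ii, eps_ii, sigma_jj, eps_jj):
    """Arithmetic mean for sigma, geometric mean for epsilon."""
    sigma_ij = 0.5 * (sigma_ii + sigma_jj)
    eps_ij = math.sqrt(eps_ii * eps_jj)
    return sigma_ij, eps_ij

def dihedral_energy(phi_deg, terms):
    """Cosine series: sum over terms of kd * (1 + cos(pn * phi - phase))."""
    phi = math.radians(phi_deg)
    return sum(
        kd * (1.0 + math.cos(pn * phi - math.radians(phase)))
        for phase, kd, pn in terms
    )

# Per-species parameters for C and O (assumed units: nm and kJ/mol).
sigma_c, eps_c = 3.39967e-01, 3.59824e-01
sigma_o, eps_o = 2.95992e-01, 8.78640e-01
sigma_co, eps_co = lorentz_berthelot(sigma_c, eps_c, sigma_o, eps_o)
print(f"sigma_CO = {sigma_co:.5f}, epsilon_CO = {eps_co:.5f}")

# C-N-CA-C terms as (phase in degrees, force constant kd, multiplicity pn).
phi_terms = [(0.0, 4.251, 1), (180.0, 1.444, 2), (0.0, 0.945, 3)]
for angle in (-180.0, -60.0, 0.0, 60.0):
    print(f"U({angle:6.1f} deg) = {dihedral_energy(angle, phi_terms):.3f}")
```

Evaluating the series on a grid of angles like this is a simple way to check that the phases and multiplicities were read in the intended convention.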

We now construct this prior potential in chemtrain. First, we select the corresponding potential terms, i.e., harmonic bond and angle terms, cosine dihedral angle terms, and repulsive nonbonded terms in the correct space.

prior_energy = prior.init_prior_potential(displacement_fn, nonbonded_type="repulsion")

Next, we load the potential parameters.

with open(base_path / "alanine_heavy.toml") as f:
    print(f.read())
    
force_field = prior.ForceField.load_ff(base_path / "alanine_heavy.toml")
[bonded]
bondtypes = """
#    i,    j,    b0,    kb
    C,  CH3,0.15172,271037.3
    C,    O,0.12325,476729.7
    C,    N,0.13359,416066.4
    C,   CA,0.15445,277737.0
   CA,  CH3,0.15370,268630.6
   CA,    N,0.14683,290439.9
  CH3,    N,0.14592,287462.0
"""
angletypes = """
#    i,    j,    k,    th0,    kth
  CH3,    C,    O,119.896,  0.319
  CH3,    C,    N,116.911,  0.305
    N,    C,    O,122.485,  0.338
    C,    N,   CA,124.654,  0.247
  CH3,   CA,    N,108.107,  0.260
    C,   CA,    N,113.044,  0.231
    C,   CA,  CH3,111.665,  0.230
   CA,    C,    O,120.259,  0.320
   CA,    C,    N,117.381,  0.301
    C,    N,  CH3,124.797,  0.245
"""
dihedraltypes = """
#    i,    j,    k,    l,    phase,    kd    pn
    C,    N,   CA,    C,   0.00,  4.251,   1
    C,    N,   CA,    C, 180.00,  1.444,   2
    C,    N,   CA,    C,   0.00,  0.945,   3
    N,   CA,    C,    N, 180.00,  2.861,   1
    N,   CA,    C,    N, 180.00,  6.082,   2
    N,   CA,    C,    N, 180.00,  1.931,   3
   CA,    N,    C,  CH3, 180.00, 10.460,   2
   CA,    N,    C,    O, 180.00, 10.460,   2
    C,    N,   CA,  CH3, 180.00,  1.480,   1
    C,    N,   CA,  CH3, 180.00,  3.697,   2
    C,    N,   CA,  CH3, 180.00,  0.950,   3
  CH3,   CA,    C,    N, 180.00,  3.257,   1
  CH3,   CA,    C,    N, 180.00,  0.275,   2
  CH3,   CA,    C,    N, 180.00,  0.234,   3
   CA,    C,    N,  CH3, 180.00, 10.460,   2
  CH3,    N,    C,    O, 180.00, 10.460,   2
"""

[nonbonded]
# No non-bonded interaction
atomtypes = """
# name,    species,    mass,        sigma,      epsilon
  CH3,    0, 15.035,3.39967e-01,3.59824e-01
    C,    1, 12.011,3.39967e-01,3.59824e-01
    O,    2, 15.999,2.95992e-01,8.78640e-01
    N,    3, 14.007,3.25000e-01,7.11280e-01
   CA,    4, 13.019,3.39967e-01,3.59824e-01
"""

Finally, we must define the index sets for the bonds $\mathcal B$, angles $\mathcal A$, and dihedral angles $\mathcal D$. Fortunately, we do not have to gather these indices by hand. Instead, chemtrain can identify them automatically by traversing a molecular graph, e.g., one constructed with the mdtraj package [5].

top = mdtraj.load_topology("../_data/alanine_heavy_2_7nm.gro")

_mapping = force_field.mapping(by_name=True)
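def mapping(name="", residue="", **kwargs):
    # The carbon of the NME cap and the side-chain carbon CB are both
    # mapped to the CH3 species; all other atoms keep their own name.
    if residue == "NME" and name == "C":
        return _mapping(name="CH3", **kwargs)
    if name == "CB":
        return _mapping(name="CH3", **kwargs)
    return _mapping(name=name, **kwargs)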

topology = prior.Topology.from_mdtraj(top, mapping)

species = topology.get_atom_species()
masses, *_ = force_field.get_nonbonded_params(species)[0].T
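
The idea of deriving angle and dihedral indices from a bond graph can be sketched in plain Python. The atom ordering below is a hypothetical indexing of the ten heavy atoms and the enumeration is a generic graph traversal, not necessarily the algorithm used inside `prior.Topology`.

```python
from itertools import combinations

# Hypothetical heavy-atom ordering:
# ACE (0: CH3, 1: C, 2: O), ALA (3: N, 4: CA, 5: CB, 6: C, 7: O), NME (8: N, 9: C)
bonds = [(0, 1), (1, 2), (1, 3), (3, 4), (4, 5), (4, 6), (6, 7), (6, 8), (8, 9)]

# Build an adjacency list from the bond list.
neighbors = {}
for i, j in bonds:
    neighbors.setdefault(i, set()).add(j)
    neighbors.setdefault(j, set()).add(i)

# Angles i-j-k: every unordered pair of neighbors of a central atom j.
angles = [
    (i, j, k)
    for j in sorted(neighbors)
    for i, k in combinations(sorted(neighbors[j]), 2)
]

# Proper dihedrals i-j-k-l: extend each bond j-k by one neighbor on each side.
dihedrals = [
    (i, j, k, l)
    for j, k in bonds
    for i in sorted(neighbors[j] - {k})
    for l in sorted(neighbors[k] - {j})
    if i != l
]

print(len(bonds), len(angles), len(dihedrals))
```

In this indexing, the backbone dihedrals correspond to the tuples `(1, 3, 4, 6)` (C-N-CA-C) and `(3, 4, 6, 8)` (N-CA-C-N).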

Model definition#

With the prior potential defined, we can now construct the learnable difference $\Delta U_\theta$. We select the DimeNet++ graph neural network architecture [6], shipped with chemtrain.

# We need an example graph (neighbor list) to determine the number
# of nodes, edges, angles, etc.

mlp_init = {
    'b_init': hk.initializers.Constant(0.),
    'w_init': layers.OrthogonalVarianceScalingInit(scale=1.)
}


init_fn, gnn_energy_fn = neural_networks.dimenetpp_neighborlist(
    displacement_fn, r_cut, n_species, r_init, nbrs_init,
    embed_size=32, init_kwargs=mlp_init,
)


# Create the parametrizable energy function and an initial
# parametrization
key, split = random.split(key)
init_params = init_fn(
    split, r_init, neighbor=nbrs_init, species=species
)

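def energy_fn_template(energy_params):
    # The prior depends only on the fixed topology and force field
    prior_energy_fn = prior_energy(topology, force_field)

    def energy_fn(pos, neighbor, **dynamic_kwargs):
        # Learnable correction predicted by the GNN
        gnn_energy = gnn_energy_fn(
            energy_params, pos, neighbor, species=species,
            **dynamic_kwargs
        )

        # Distinct name to avoid shadowing the global `prior_energy`
        prior_contribution = prior_energy_fn(pos, neighbor=neighbor)
        return gnn_energy + prior_contribution
    return energy_fn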
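
The template pattern used above can be illustrated with a one-dimensional toy problem. All names in the sketch (`make_energy_template`, `toy_prior`, `toy_correction`) are hypothetical stand-ins: a harmonic prior and a linear function replace the prior potential and the GNN.

```python
def make_energy_template(learned_fn, prior_fn):
    """Return a template that maps parameters to a total energy function."""
    def energy_template(params):
        def energy(x):
            # Total potential = learnable correction + fixed prior
            return learned_fn(params, x) + prior_fn(x)
        return energy
    return energy_template

def toy_prior(x):
    # Harmonic prior with made-up parameters
    return 0.5 * 100.0 * (x - 0.15) ** 2

def toy_correction(params, x):
    # Linear stand-in for the neural network
    return params["w"] * x + params["b"]

toy_template = make_energy_template(toy_correction, toy_prior)
untrained = toy_template({"w": 0.0, "b": 0.0})
trained = toy_template({"w": 2.0, "b": -0.3})

print(untrained(0.2), trained(0.2))
```

With all-zero parameters the total energy reduces to the prior, which is why a reasonable prior keeps early simulations stable before the learnable part has been trained.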

Reference Data#

We use the reference data from the original paper [1]. This reference data consists of $5\times10^{5}$ conformations with corresponding forces, subsampled every $200~\text{fs}$ from a $100~\text{ns}$ all-atom simulation of alanine dipeptide in TIP3P water.

# Download if not present
position_url = "https://drive.usercontent.google.com/download?id=1yKVHiI8y7ZNzyduh8bosR6YKScLyFezU&export=download&confirm=t&uuid=cff71a05-a45a-446e-bf2f-17f62e84263f"
force_url = "https://drive.usercontent.google.com/download?id=1JhRQcZ3tE2w-mLqTGN0JHJQst5uZijJx&export=download&confirm=t&uuid=ad4f279f-18b7-4b65-a144-ab1115110549" 

forces_path = "../_data/forces_heavy_100ns.npy"
positions_path = "../_data/positions_heavy_100ns.npy"

if not Path(forces_path).exists():
    request.urlretrieve(force_url, forces_path)
if not Path(positions_path).exists():
    request.urlretrieve(position_url, positions_path)

force_dataset = preprocessing.get_dataset(forces_path)
position_dataset = preprocessing.get_dataset(positions_path)

if fractional:
    position_dataset = preprocessing.scale_dataset_fractional(
        position_dataset, box
    )

Simulation#

As outlined in the Relative Entropy example, estimating the gradients of the relative entropy requires samples from the current coarse-grained ensemble. Moreover, we want to compare how well the learned models predict the distributions of the backbone dihedral angles. Thus, we need to simulate CG alanine dipeptide based on the neural network and prior potential. However, we do not set up a single long simulation. Instead, we run $100$ shorter simulations in parallel to accelerate the sampling. These simulations start from conformations randomly selected from the reference data without replacement.

dt = 0.002
n_chains = 100
gamma = 100.

key, split = random.split(key)
selection = random.choice(
    split, jnp.arange(position_dataset.shape[0]), shape=(n_chains,), replace=False)
r_init = position_dataset[selection, ...]

init_ref_state, sim_template = sampling.initialize_simulator_template(
    simulate.nvt_langevin, shift_fn=shift_fn, nbrs=nbrs_init,
    init_with_PRNGKey=True, extra_simulator_kwargs={"kT": kT, "gamma": gamma, "dt": dt}
)

key, split = random.split(key)
reference_state = init_ref_state(
    split, r_init,
    energy_or_force_fn=energy_fn_template(init_params),
    init_sim_kwargs={"mass": masses, "neighbor": nbrs_init}
)

Relative Entropy Minimization#

With the reference data, potential model, and simulation routine set up, we can now instantiate the REM algorithm. We use only a subset of $80~\%$ of the data for training. We also set the reweighting ratio such that a new trajectory is generated after every update. Nevertheless, we tune the step size such that at least $25~\%$ of the samples remain effective after each update.
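
The notion of effective samples can be illustrated with a small reweighting calculation. The sketch below assumes the entropy-based estimate $N_\text{eff} = \exp(-\sum_i w_i \ln w_i)$ with weights proportional to $\exp(-\beta \Delta U_i)$; chemtrain's internal definition may differ, and the energy differences are made up for illustration.

```python
import math

def effective_sample_fraction(delta_u, kT):
    """Fraction of effective samples after reweighting by exp(-delta_u / kT)."""
    # Shift by the minimum for numerical stability before exponentiating.
    shift = min(delta_u)
    unnormalized = [math.exp(-(du - shift) / kT) for du in delta_u]
    total = sum(unnormalized)
    weights = [w / total for w in unnormalized]
    # Entropy-based effective sample size: N_eff = exp(-sum_i w_i ln w_i)
    entropy = -sum(w * math.log(w) for w in weights if w > 0.0)
    return math.exp(entropy) / len(weights)

kT_example = 2.494  # kJ/mol at 300 K
# Hypothetical energy changes U_new - U_old (kJ/mol) for four samples.
no_change = effective_sample_fraction([0.0, 0.0, 0.0, 0.0], kT_example)
large_change = effective_sample_fraction([0.0, 5.0, 10.0, 20.0], kT_example)
print(no_change, large_change)
```

A parameter update that changes the energies unevenly concentrates the weight on a few samples, so limiting the step size keeps this fraction above a chosen threshold.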

re_initial_lr = 0.003
re_epochs = 300
re_used_dataset_size = 400000

t_sample = 0.5
total_time = 110.
t_eq = 10.

re_timings = sampling.process_printouts(
    time_step=dt, total_time=total_time,
    t_equilib=t_eq, print_every=t_sample
)

lr_schedule = optax.exponential_decay(re_initial_lr, re_epochs, 0.01)
optimizer = optax.chain(
    optax.scale_by_adam(0.1, 0.4),
    optax.scale_by_schedule(lr_schedule),
    optax.scale_by_learning_rate(1.0)
)

relative_entropy = trainers.RelativeEntropy(
    init_params, optimizer, reweight_ratio=1.1,
    energy_fn_template=energy_fn_template)

relative_entropy.add_statepoint(
    position_dataset[:re_used_dataset_size, ...],
    energy_fn_template, sim_template, neighbor_fn,
    re_timings, state_kwargs, reference_state,
    reference_batch_size=re_used_dataset_size,
    vmap_batch=n_chains, resample_simstates=True)

relative_entropy.init_step_size_adaption(0.25)


/home/paul/chemtrain_rerun_ad/chemtrain/ensemble/reweighting.py:777: UserWarning: Propagation function is not safe by default. Do not forget to use the wrapper around the compute function to ensure that the neighborlist does not overflow.
  warnings.warn(
/home/paul/miniconda3/envs/chemtrain_rerun_AD/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain_rerun_AD/lib/python3.11/site-packages/jax/_src/numpy/reductions.py:221: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in sum is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
Time for trajectory initialization 0: 2.66248121658961 mins
[Step size] Use 7 iterations for 10 interior points.
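
The settings of the preceding cell translate into a concrete sampling and optimization budget. The arithmetic below assumes that one sample is kept every `t_sample` after the equilibration time and that the schedule follows the continuous form `initial_lr * decay_rate ** (step / n_epochs)`; the function `exponential_decay` is a plain-Python stand-in, not the optax implementation.

```python
# Settings copied from the cells above (times in ps).
dt, total_time, t_eq, t_sample = 0.002, 110.0, 10.0, 0.5
n_chains, n_reference, n_train = 100, 500_000, 400_000
initial_lr, n_epochs, decay_rate = 0.003, 300, 0.01

steps_per_chain = round(total_time / dt)
samples_per_chain = round((total_time - t_eq) / t_sample)
samples_per_update = samples_per_chain * n_chains
train_fraction = n_train / n_reference

def exponential_decay(step):
    """Continuous decay: lr(step) = initial_lr * decay_rate ** (step / n_epochs)."""
    return initial_lr * decay_rate ** (step / n_epochs)

print(steps_per_chain, samples_per_chain, samples_per_update, train_fraction)
print([f"{exponential_decay(s):.2e}" for s in (0, 150, 300)])
```

Under these assumptions, the learning-rate scale drops by two orders of magnitude over the 300 epochs.
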
if os.environ.get("RM_TRAINING", "False").lower() == "true":
    # Save the training log
    with open("../_data/output/alanine_dipeptide_rm_training.log", "w") as f:
        with contextlib.redirect_stdout(f), contextlib.redirect_stderr(f):
            start = time.time()
            relative_entropy.train(re_epochs)
            print(f"Total training time: {(time.time() - start) / 3600 : .1f} hours")
    
    relative_entropy.save_energy_params("../_data/output/alanine_dipeptide_re_params.pkl", '.pkl')
    relative_entropy.save_trainer("../_data/output/alanine_dipeptide_re_trainer.pkl", '.pkl')

relative_entropy = onp.load("../_data/output/alanine_dipeptide_re_trainer.pkl", allow_pickle=True)
relative_entropy_params = tree_util.tree_map(
    jnp.asarray, onp.load("../_data/output/alanine_dipeptide_re_params.pkl", allow_pickle=True)
)

with open("../_data/output/alanine_dipeptide_rm_training.log") as f:
    print(f.read())


/home/paul/miniconda3/envs/chemtrain/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
[Step Size] Found optimal step size 1.0 with residual 0.5050363540649414

[RE] Epoch 0
	Mean Delta RE loss = 15.53915
	Gradient norm: 96.78986358642578
	Elapsed time = 4.870 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 0.47622090578079224 with residual 0.0

[RE] Epoch 1
	Mean Delta RE loss = 1.09101
	Gradient norm: 136.9994659423828
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 0.6281028985977173 with residual 0.003246307373046875

[RE] Epoch 2
	Mean Delta RE loss = 21.48284
	Gradient norm: 101.16975402832031
	Elapsed time = 3.050 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 0.6273515820503235 with residual 0.0023374557495117188

[RE] Epoch 3
	Mean Delta RE loss = -3.64929
	Gradient norm: 105.43058013916016
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.431 ref_kT = 2.494
[Step Size] Found optimal step size 0.7043964862823486 with residual 6.67572021484375e-06

[RE] Epoch 4
	Mean Delta RE loss = 16.55299
	Gradient norm: 78.60031127929688
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 0.5539656281471252 with residual 2.6702880859375e-05

[RE] Epoch 5
	Mean Delta RE loss = 2.45981
	Gradient norm: 165.91661071777344
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 0.8504898548126221 with residual 0.0003910064697265625

[RE] Epoch 6
	Mean Delta RE loss = 7.85277
	Gradient norm: 88.90467834472656
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.434 ref_kT = 2.494
[Step Size] Found optimal step size 0.48502036929130554 with residual 2.86102294921875e-06

[RE] Epoch 7
	Mean Delta RE loss = 6.67087
	Gradient norm: 123.801025390625
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 0.6894725561141968 with residual 0.0

[RE] Epoch 8
	Mean Delta RE loss = -3.47516
	Gradient norm: 96.43707275390625
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.430 ref_kT = 2.494
[Step Size] Found optimal step size 0.7450771331787109 with residual 3.814697265625e-06

[RE] Epoch 9
	Mean Delta RE loss = 12.66153
	Gradient norm: 76.37648010253906
	Elapsed time = 3.073 min
[Statepoint 0]
	kT = 2.425 ref_kT = 2.494
[Step Size] Found optimal step size 0.8334217071533203 with residual 1.239776611328125e-05

[RE] Epoch 10
	Mean Delta RE loss = 2.25065
	Gradient norm: 97.9099349975586
	Elapsed time = 3.072 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 0.6281028985977173 with residual 0.00045490264892578125

[RE] Epoch 11
	Mean Delta RE loss = 7.31625
	Gradient norm: 133.9579620361328
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.432 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.42595577239990234

[RE] Epoch 12
	Mean Delta RE loss = -4.42134
	Gradient norm: 35.630035400390625
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.05788135528564453

[RE] Epoch 13
	Mean Delta RE loss = 2.27360
	Gradient norm: 28.912837982177734
	Elapsed time = 3.070 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 0.6424583792686462 with residual 0.0

[RE] Epoch 14
	Mean Delta RE loss = 7.81918
	Gradient norm: 67.10161590576172
	Elapsed time = 3.075 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 0.7190110683441162 with residual 0.007254600524902344

[RE] Epoch 15
	Mean Delta RE loss = 0.08986
	Gradient norm: 123.29833984375
	Elapsed time = 3.072 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 0.839114785194397 with residual 2.47955322265625e-05

[RE] Epoch 16
	Mean Delta RE loss = -3.79533
	Gradient norm: 76.82270812988281
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 0.6378015279769897 with residual 0.00025653839111328125

[RE] Epoch 17
	Mean Delta RE loss = 13.96824
	Gradient norm: 174.5135955810547
	Elapsed time = 3.073 min
[Statepoint 0]
	kT = 2.434 ref_kT = 2.494
[Step Size] Found optimal step size 0.8099193572998047 with residual 0.0020742416381835938

[RE] Epoch 18
	Mean Delta RE loss = -1.41317
	Gradient norm: 72.4830322265625
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.431 ref_kT = 2.494
[Step Size] Found optimal step size 0.4546304941177368 with residual 1.239776611328125e-05

[RE] Epoch 19
	Mean Delta RE loss = -6.33745
	Gradient norm: 299.31365966796875
	Elapsed time = 3.067 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.26747894287109375

[RE] Epoch 20
	Mean Delta RE loss = 7.38566
	Gradient norm: 130.53680419921875
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.3663339614868164

[RE] Epoch 21
	Mean Delta RE loss = -15.23911
	Gradient norm: 65.86870574951172
	Elapsed time = 3.069 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 0.6409062147140503 with residual 1.239776611328125e-05

[RE] Epoch 22
	Mean Delta RE loss = 17.11715
	Gradient norm: 92.77953338623047
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 0.5589826107025146 with residual 3.337860107421875e-05

[RE] Epoch 23
	Mean Delta RE loss = 13.40763
	Gradient norm: 250.1896209716797
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 0.8531784415245056 with residual 1.239776611328125e-05

[RE] Epoch 24
	Mean Delta RE loss = 9.53603
	Gradient norm: 77.25128936767578
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 0.6131675243377686 with residual 1.9073486328125e-06

[RE] Epoch 25
	Mean Delta RE loss = 11.22598
	Gradient norm: 259.5020446777344
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.432 ref_kT = 2.494
[Step Size] Found optimal step size 0.8540397882461548 with residual 9.5367431640625e-07

[RE] Epoch 26
	Mean Delta RE loss = 10.39308
	Gradient norm: 99.2017822265625
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 0.5352264046669006 with residual 0.0

[RE] Epoch 27
	Mean Delta RE loss = -4.08759
	Gradient norm: 184.17913818359375
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 0.9090918302536011 with residual 0.12293529510498047

[RE] Epoch 28
	Mean Delta RE loss = -1.68255
	Gradient norm: 52.259273529052734
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 0.7378681898117065 with residual 2.193450927734375e-05

[RE] Epoch 29
	Mean Delta RE loss = 0.76418
	Gradient norm: 298.23956298828125
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.4898519515991211

[RE] Epoch 30
	Mean Delta RE loss = 3.71359
	Gradient norm: 85.60546875
	Elapsed time = 3.069 min
[Statepoint 0]
	kT = 2.429 ref_kT = 2.494
[Step Size] Found optimal step size 0.861372709274292 with residual 1.9073486328125e-06

[RE] Epoch 31
	Mean Delta RE loss = 7.09233
	Gradient norm: 198.5961151123047
	Elapsed time = 3.072 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.6121788024902344

[RE] Epoch 32
	Mean Delta RE loss = -17.12446
	Gradient norm: 67.64187622070312
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 0.9090918898582458 with residual 0.04456806182861328

[RE] Epoch 33
	Mean Delta RE loss = -9.19825
	Gradient norm: 38.26338195800781
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 0.8045551776885986 with residual 1.239776611328125e-05

[RE] Epoch 34
	Mean Delta RE loss = 10.70455
	Gradient norm: 213.989013671875
	Elapsed time = 3.051 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.6927347183227539

[RE] Epoch 35
	Mean Delta RE loss = -1.73648
	Gradient norm: 33.243560791015625
	Elapsed time = 3.066 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 0.9090918302536011 with residual 0.07998371124267578

[RE] Epoch 36
	Mean Delta RE loss = 13.79533
	Gradient norm: 104.1267318725586
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.016901016235351562

[RE] Epoch 37
	Mean Delta RE loss = -6.25355
	Gradient norm: 111.04705047607422
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 0.9090918302536011 with residual 0.033603668212890625

[RE] Epoch 38
	Mean Delta RE loss = -1.05056
	Gradient norm: 96.11393737792969
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.11266708374023438

[RE] Epoch 39
	Mean Delta RE loss = -0.73153
	Gradient norm: 164.17063903808594
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.3326883316040039

[RE] Epoch 40
	Mean Delta RE loss = 6.60476
	Gradient norm: 44.0087890625
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.48657703399658203

[RE] Epoch 41
	Mean Delta RE loss = 12.16525
	Gradient norm: 108.45708465576172
	Elapsed time = 3.055 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.6175136566162109

[RE] Epoch 42
	Mean Delta RE loss = -15.87684
	Gradient norm: 53.473575592041016
	Elapsed time = 3.070 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.23824787139892578

[RE] Epoch 43
	Mean Delta RE loss = 6.12766
	Gradient norm: 115.73217010498047
	Elapsed time = 3.055 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.3355598449707031

[RE] Epoch 44
	Mean Delta RE loss = -10.64977
	Gradient norm: 51.03236770629883
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.2854146957397461

[RE] Epoch 45
	Mean Delta RE loss = -4.77047
	Gradient norm: 89.69039154052734
	Elapsed time = 3.066 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.4124765396118164

[RE] Epoch 46
	Mean Delta RE loss = 0.05812
	Gradient norm: 50.834102630615234
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.428 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.4334859848022461

[RE] Epoch 47
	Mean Delta RE loss = -6.81593
	Gradient norm: 51.84660339355469
	Elapsed time = 3.073 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.6274814605712891

[RE] Epoch 48
	Mean Delta RE loss = 3.58504
	Gradient norm: 53.69062805175781
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.457611083984375

[RE] Epoch 49
	Mean Delta RE loss = -1.91072
	Gradient norm: 31.703033447265625
	Elapsed time = 3.072 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.08647632598876953

[RE] Epoch 50
	Mean Delta RE loss = 0.24307
	Gradient norm: 141.77261352539062
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.8804912567138672

[RE] Epoch 51
	Mean Delta RE loss = -4.86190
	Gradient norm: 65.97657012939453
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.2674398422241211

[RE] Epoch 52
	Mean Delta RE loss = 2.63218
	Gradient norm: 112.94928741455078
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.7255525588989258

[RE] Epoch 53
	Mean Delta RE loss = -7.78312
	Gradient norm: 47.47684860229492
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.7090339660644531

[RE] Epoch 54
	Mean Delta RE loss = -11.33452
	Gradient norm: 38.9448356628418
	Elapsed time = 3.055 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.573089599609375

[RE] Epoch 55
	Mean Delta RE loss = 2.51961
	Gradient norm: 55.06791687011719
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.428 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.7267780303955078

[RE] Epoch 56
	Mean Delta RE loss = -2.48782
	Gradient norm: 53.38599395751953
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.8906764984130859

[RE] Epoch 57
	Mean Delta RE loss = -7.47405
	Gradient norm: 43.80025863647461
	Elapsed time = 3.076 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.79827880859375

[RE] Epoch 58
	Mean Delta RE loss = -2.67337
	Gradient norm: 24.660694122314453
	Elapsed time = 3.052 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.6769723892211914

[RE] Epoch 59
	Mean Delta RE loss = -8.23746
	Gradient norm: 40.17720413208008
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.5663900375366211

[RE] Epoch 60
	Mean Delta RE loss = 2.89655
	Gradient norm: 50.7210807800293
	Elapsed time = 3.066 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.6876688003540039

[RE] Epoch 61
	Mean Delta RE loss = -6.82373
	Gradient norm: 70.7453384399414
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.915095329284668

[RE] Epoch 62
	Mean Delta RE loss = 2.44837
	Gradient norm: 25.503345489501953
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.7188739776611328

[RE] Epoch 63
	Mean Delta RE loss = 4.58812
	Gradient norm: 85.23062133789062
	Elapsed time = 3.066 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.8026313781738281

[RE] Epoch 64
	Mean Delta RE loss = -11.87895
	Gradient norm: 32.247100830078125
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.430 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.9392538070678711

[RE] Epoch 65
	Mean Delta RE loss = -11.43642
	Gradient norm: 32.92289733886719
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.6600704193115234

[RE] Epoch 66
	Mean Delta RE loss = -4.53551
	Gradient norm: 35.656639099121094
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.432 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.6591463088989258

[RE] Epoch 67
	Mean Delta RE loss = -12.34015
	Gradient norm: 75.13743591308594
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.8978986740112305

[RE] Epoch 68
	Mean Delta RE loss = -1.57562
	Gradient norm: 48.27482223510742
	Elapsed time = 3.072 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.9930248260498047

[RE] Epoch 69
	Mean Delta RE loss = -2.45973
	Gradient norm: 14.779623985290527
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.7990493774414062

[RE] Epoch 70
	Mean Delta RE loss = -7.78443
	Gradient norm: 40.037986755371094
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.935725212097168

[RE] Epoch 71
	Mean Delta RE loss = -0.29148
	Gradient norm: 35.628578186035156
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.042037010192871

[RE] Epoch 72
	Mean Delta RE loss = -2.55216
	Gradient norm: 16.427093505859375
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.7313337326049805

[RE] Epoch 73
	Mean Delta RE loss = -3.32064
	Gradient norm: 32.10789108276367
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.7728166580200195

[RE] Epoch 74
	Mean Delta RE loss = -1.69896
	Gradient norm: 94.91574096679688
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1665067672729492

[RE] Epoch 75
	Mean Delta RE loss = -4.83087
	Gradient norm: 18.975688934326172
	Elapsed time = 3.055 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.027780532836914

[RE] Epoch 76
	Mean Delta RE loss = -2.47687
	Gradient norm: 13.485079765319824
	Elapsed time = 3.067 min
[Statepoint 0]
	kT = 2.447 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.9475574493408203

[RE] Epoch 77
	Mean Delta RE loss = -10.63978
	Gradient norm: 23.51221466064453
	Elapsed time = 3.068 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.9697885513305664

[RE] Epoch 78
	Mean Delta RE loss = -5.17230
	Gradient norm: 20.220495223999023
	Elapsed time = 3.070 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.796137809753418

[RE] Epoch 79
	Mean Delta RE loss = -6.97275
	Gradient norm: 56.07233428955078
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.0164403915405273

[RE] Epoch 80
	Mean Delta RE loss = -7.80621
	Gradient norm: 33.16873550415039
	Elapsed time = 3.054 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1218576431274414

[RE] Epoch 81
	Mean Delta RE loss = -6.80875
	Gradient norm: 19.53645896911621
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.9825096130371094

[RE] Epoch 82
	Mean Delta RE loss = -7.17308
	Gradient norm: 19.23621940612793
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.8963098526000977

[RE] Epoch 83
	Mean Delta RE loss = -11.45778
	Gradient norm: 41.62761688232422
	Elapsed time = 3.068 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.0531206130981445

[RE] Epoch 84
	Mean Delta RE loss = -5.52613
	Gradient norm: 25.43317222595215
	Elapsed time = 3.066 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.0830745697021484

[RE] Epoch 85
	Mean Delta RE loss = -5.61816
	Gradient norm: 15.72153091430664
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.9488162994384766

[RE] Epoch 86
	Mean Delta RE loss = -10.62728
	Gradient norm: 37.572628021240234
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.448 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.0278472900390625

[RE] Epoch 87
	Mean Delta RE loss = -4.25052
	Gradient norm: 39.52288818359375
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 0.9733171463012695

[RE] Epoch 88
	Mean Delta RE loss = -6.24033
	Gradient norm: 31.289201736450195
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.063115119934082

[RE] Epoch 89
	Mean Delta RE loss = -8.59588
	Gradient norm: 26.60273551940918
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.431 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1394128799438477

[RE] Epoch 90
	Mean Delta RE loss = -2.71474
	Gradient norm: 17.07054328918457
	Elapsed time = 3.068 min
[Statepoint 0]
	kT = 2.431 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.0851221084594727

[RE] Epoch 91
	Mean Delta RE loss = -8.62135
	Gradient norm: 15.888648986816406
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1025466918945312

[RE] Epoch 92
	Mean Delta RE loss = -10.05921
	Gradient norm: 19.19649887084961
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1487436294555664

[RE] Epoch 93
	Mean Delta RE loss = -3.02613
	Gradient norm: 18.183650970458984
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.157358169555664

[RE] Epoch 94
	Mean Delta RE loss = -16.53636
	Gradient norm: 14.75916862487793
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1031970977783203

[RE] Epoch 95
	Mean Delta RE loss = -7.46530
	Gradient norm: 18.357086181640625
	Elapsed time = 3.067 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.0449161529541016

[RE] Epoch 96
	Mean Delta RE loss = -15.64567
	Gradient norm: 26.463340759277344
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.066427230834961

[RE] Epoch 97
	Mean Delta RE loss = -7.23178
	Gradient norm: 25.799190521240234
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1202049255371094

[RE] Epoch 98
	Mean Delta RE loss = -9.62063
	Gradient norm: 24.015247344970703
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2235956192016602

[RE] Epoch 99
	Mean Delta RE loss = -11.24260
	Gradient norm: 13.153402328491211
	Elapsed time = 3.054 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1425952911376953

[RE] Epoch 100
	Mean Delta RE loss = -8.83402
	Gradient norm: 18.418441772460938
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.108438491821289

[RE] Epoch 101
	Mean Delta RE loss = -11.28405
	Gradient norm: 24.47221565246582
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1634435653686523

[RE] Epoch 102
	Mean Delta RE loss = -11.49883
	Gradient norm: 22.502002716064453
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1685523986816406

[RE] Epoch 103
	Mean Delta RE loss = -5.68649
	Gradient norm: 18.725576400756836
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2153282165527344

[RE] Epoch 104
	Mean Delta RE loss = -11.69006
	Gradient norm: 11.638129234313965
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1633386611938477

[RE] Epoch 105
	Mean Delta RE loss = -8.44737
	Gradient norm: 10.894610404968262
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.426 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.084050178527832

[RE] Epoch 106
	Mean Delta RE loss = -5.28741
	Gradient norm: 34.60201644897461
	Elapsed time = 3.066 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2235403060913086

[RE] Epoch 107
	Mean Delta RE loss = -9.32830
	Gradient norm: 14.529751777648926
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.434 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.211207389831543

[RE] Epoch 108
	Mean Delta RE loss = -11.01207
	Gradient norm: 10.092248916625977
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2004690170288086

[RE] Epoch 109
	Mean Delta RE loss = -8.15697
	Gradient norm: 13.046310424804688
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1976251602172852

[RE] Epoch 110
	Mean Delta RE loss = -9.09556
	Gradient norm: 14.230477333068848
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.434 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2117996215820312

[RE] Epoch 111
	Mean Delta RE loss = -9.59851
	Gradient norm: 15.308122634887695
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.230630874633789

[RE] Epoch 112
	Mean Delta RE loss = -8.70173
	Gradient norm: 10.201995849609375
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.429 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2004737854003906

[RE] Epoch 113
	Mean Delta RE loss = -12.01889
	Gradient norm: 17.237102508544922
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2618656158447266

[RE] Epoch 114
	Mean Delta RE loss = -10.74353
	Gradient norm: 5.978794574737549
	Elapsed time = 3.067 min
[Statepoint 0]
	kT = 2.449 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.206130027770996

[RE] Epoch 115
	Mean Delta RE loss = -11.29639
	Gradient norm: 13.167731285095215
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2181987762451172

[RE] Epoch 116
	Mean Delta RE loss = -12.73764
	Gradient norm: 15.388191223144531
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2417802810668945

[RE] Epoch 117
	Mean Delta RE loss = -11.33344
	Gradient norm: 13.308270454406738
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.249124526977539

[RE] Epoch 118
	Mean Delta RE loss = -10.94403
	Gradient norm: 12.970610618591309
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2337455749511719

[RE] Epoch 119
	Mean Delta RE loss = -14.07195
	Gradient norm: 13.982844352722168
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2831544876098633

[RE] Epoch 120
	Mean Delta RE loss = -11.69041
	Gradient norm: 6.284595966339111
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.451 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2248811721801758

[RE] Epoch 121
	Mean Delta RE loss = -13.44534
	Gradient norm: 9.83364486694336
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.240443229675293

[RE] Epoch 122
	Mean Delta RE loss = -13.15900
	Gradient norm: 13.235137939453125
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2880144119262695

[RE] Epoch 123
	Mean Delta RE loss = -10.65569
	Gradient norm: 7.317370414733887
	Elapsed time = 3.052 min
[Statepoint 0]
	kT = 2.434 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2332639694213867

[RE] Epoch 124
	Mean Delta RE loss = -16.56692
	Gradient norm: 14.440974235534668
	Elapsed time = 3.054 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2523813247680664

[RE] Epoch 125
	Mean Delta RE loss = -9.24221
	Gradient norm: 18.928131103515625
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.448 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3072795867919922

[RE] Epoch 126
	Mean Delta RE loss = -12.67651
	Gradient norm: 5.689342021942139
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2710142135620117

[RE] Epoch 127
	Mean Delta RE loss = -15.54162
	Gradient norm: 7.400149345397949
	Elapsed time = 3.052 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2272453308105469

[RE] Epoch 128
	Mean Delta RE loss = -13.40695
	Gradient norm: 18.344871520996094
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2813749313354492

[RE] Epoch 129
	Mean Delta RE loss = -14.56014
	Gradient norm: 12.533761978149414
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2994842529296875

[RE] Epoch 130
	Mean Delta RE loss = -13.11172
	Gradient norm: 6.96432638168335
	Elapsed time = 3.079 min
[Statepoint 0]
	kT = 2.432 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3035774230957031

[RE] Epoch 131
	Mean Delta RE loss = -12.90831
	Gradient norm: 5.745017051696777
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2682819366455078

[RE] Epoch 132
	Mean Delta RE loss = -15.57068
	Gradient norm: 13.34791088104248
	Elapsed time = 3.055 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2822542190551758

[RE] Epoch 133
	Mean Delta RE loss = -10.42317
	Gradient norm: 13.19522762298584
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.299311637878418

[RE] Epoch 134
	Mean Delta RE loss = -12.58752
	Gradient norm: 8.079959869384766
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3193445205688477

[RE] Epoch 135
	Mean Delta RE loss = -12.24893
	Gradient norm: 3.677492141723633
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.284529685974121

[RE] Epoch 136
	Mean Delta RE loss = -13.59531
	Gradient norm: 6.899271011352539
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3006486892700195

[RE] Epoch 137
	Mean Delta RE loss = -11.37333
	Gradient norm: 5.6504669189453125
	Elapsed time = 3.049 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3173418045043945

[RE] Epoch 138
	Mean Delta RE loss = -14.35630
	Gradient norm: 4.512545108795166
	Elapsed time = 3.053 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.293848991394043

[RE] Epoch 139
	Mean Delta RE loss = -13.45678
	Gradient norm: 7.4067535400390625
	Elapsed time = 3.066 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2825345993041992

[RE] Epoch 140
	Mean Delta RE loss = -14.48671
	Gradient norm: 8.129594802856445
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2653942108154297

[RE] Epoch 141
	Mean Delta RE loss = -14.86857
	Gradient norm: 20.03192138671875
	Elapsed time = 3.054 min
[Statepoint 0]
	kT = 2.428 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3422412872314453

[RE] Epoch 142
	Mean Delta RE loss = -13.06235
	Gradient norm: 5.95358419418335
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3358564376831055

[RE] Epoch 143
	Mean Delta RE loss = -14.27866
	Gradient norm: 3.150526523590088
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.430 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2985849380493164

[RE] Epoch 144
	Mean Delta RE loss = -13.59745
	Gradient norm: 8.030147552490234
	Elapsed time = 3.054 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3178749084472656

[RE] Epoch 145
	Mean Delta RE loss = -15.68148
	Gradient norm: 8.274618148803711
	Elapsed time = 3.071 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3094310760498047

[RE] Epoch 146
	Mean Delta RE loss = -13.71014
	Gradient norm: 8.106793403625488
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3409957885742188

[RE] Epoch 147
	Mean Delta RE loss = -14.36578
	Gradient norm: 3.832988739013672
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3013286590576172

[RE] Epoch 148
	Mean Delta RE loss = -14.57996
	Gradient norm: 7.073241233825684
	Elapsed time = 3.068 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.306929588317871

[RE] Epoch 149
	Mean Delta RE loss = -15.28099
	Gradient norm: 8.403550148010254
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.309636116027832

[RE] Epoch 150
	Mean Delta RE loss = -14.28960
	Gradient norm: 9.017448425292969
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3185796737670898

[RE] Epoch 151
	Mean Delta RE loss = -13.06892
	Gradient norm: 8.319890975952148
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.447 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.364908218383789

[RE] Epoch 152
	Mean Delta RE loss = -15.55031
	Gradient norm: 1.773384690284729
	Elapsed time = 3.054 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3499641418457031

[RE] Epoch 153
	Mean Delta RE loss = -14.65472
	Gradient norm: 1.3933762311935425
	Elapsed time = 3.066 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3001222610473633

[RE] Epoch 154
	Mean Delta RE loss = -16.79922
	Gradient norm: 8.729778289794922
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3167343139648438

[RE] Epoch 155
	Mean Delta RE loss = -16.29221
	Gradient norm: 11.500153541564941
	Elapsed time = 3.054 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3429069519042969

[RE] Epoch 156
	Mean Delta RE loss = -15.36625
	Gradient norm: 5.574641704559326
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3513031005859375

[RE] Epoch 157
	Mean Delta RE loss = -15.76413
	Gradient norm: 2.627825975418091
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3184099197387695

[RE] Epoch 158
	Mean Delta RE loss = -16.32934
	Gradient norm: 7.235553741455078
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3438377380371094

[RE] Epoch 159
	Mean Delta RE loss = -14.96679
	Gradient norm: 4.246687889099121
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3552160263061523

[RE] Epoch 160
	Mean Delta RE loss = -15.48582
	Gradient norm: 2.5405290126800537
	Elapsed time = 3.052 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3351325988769531

[RE] Epoch 161
	Mean Delta RE loss = -15.86650
	Gradient norm: 5.132160663604736
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.340165138244629

[RE] Epoch 162
	Mean Delta RE loss = -17.16809
	Gradient norm: 3.496100425720215
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3436107635498047

[RE] Epoch 163
	Mean Delta RE loss = -16.17754
	Gradient norm: 4.020005702972412
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.434 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3275814056396484

[RE] Epoch 164
	Mean Delta RE loss = -17.20761
	Gradient norm: 7.757768154144287
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3662996292114258

[RE] Epoch 165
	Mean Delta RE loss = -15.72056
	Gradient norm: 2.0566227436065674
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3373441696166992

[RE] Epoch 166
	Mean Delta RE loss = -16.74858
	Gradient norm: 4.561529159545898
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3400888442993164

[RE] Epoch 167
	Mean Delta RE loss = -17.50462
	Gradient norm: 6.793392181396484
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3585405349731445

[RE] Epoch 168
	Mean Delta RE loss = -16.41686
	Gradient norm: 2.575246810913086
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3513708114624023

[RE] Epoch 169
	Mean Delta RE loss = -18.04861
	Gradient norm: 3.0459911823272705
	Elapsed time = 3.049 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3456945419311523

[RE] Epoch 170
	Mean Delta RE loss = -17.50878
	Gradient norm: 5.8458356857299805
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.358428955078125

[RE] Epoch 171
	Mean Delta RE loss = -17.77328
	Gradient norm: 2.994828462600708
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.364471435546875

[RE] Epoch 172
	Mean Delta RE loss = -16.73382
	Gradient norm: 1.8513425588607788
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.427 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3552560806274414

[RE] Epoch 173
	Mean Delta RE loss = -16.64303
	Gradient norm: 3.1560323238372803
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3469552993774414

[RE] Epoch 174
	Mean Delta RE loss = -16.25961
	Gradient norm: 4.5493855476379395
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.432 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3457307815551758

[RE] Epoch 175
	Mean Delta RE loss = -17.31482
	Gradient norm: 4.59910249710083
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3571901321411133

[RE] Epoch 176
	Mean Delta RE loss = -17.85289
	Gradient norm: 3.5356557369232178
	Elapsed time = 3.069 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3582515716552734

[RE] Epoch 177
	Mean Delta RE loss = -17.85704
	Gradient norm: 2.8813414573669434
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3598823547363281

[RE] Epoch 178
	Mean Delta RE loss = -16.79643
	Gradient norm: 3.2053639888763428
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3594560623168945

[RE] Epoch 179
	Mean Delta RE loss = -19.03428
	Gradient norm: 2.5991015434265137
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3706293106079102

[RE] Epoch 180
	Mean Delta RE loss = -17.81890
	Gradient norm: 1.0632922649383545
	Elapsed time = 3.055 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.359151840209961

[RE] Epoch 181
	Mean Delta RE loss = -17.11016
	Gradient norm: 2.4218761920928955
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.448 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3598880767822266

[RE] Epoch 182
	Mean Delta RE loss = -18.64445
	Gradient norm: 2.542487859725952
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.360128402709961

[RE] Epoch 183
	Mean Delta RE loss = -17.80390
	Gradient norm: 2.790783166885376
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3568305969238281

[RE] Epoch 184
	Mean Delta RE loss = -17.78707
	Gradient norm: 5.197903633117676
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3666810989379883

[RE] Epoch 185
	Mean Delta RE loss = -18.87475
	Gradient norm: 3.037252187728882
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.373042106628418

[RE] Epoch 186
	Mean Delta RE loss = -16.61665
	Gradient norm: 1.3816590309143066
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.360682487487793

[RE] Epoch 187
	Mean Delta RE loss = -19.04077
	Gradient norm: 2.9555537700653076
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.447 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3741493225097656

[RE] Epoch 188
	Mean Delta RE loss = -18.13439
	Gradient norm: 0.9765499234199524
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3671894073486328

[RE] Epoch 189
	Mean Delta RE loss = -18.17826
	Gradient norm: 1.481730341911316
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3590631484985352

[RE] Epoch 190
	Mean Delta RE loss = -17.38437
	Gradient norm: 2.910466432571411
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3619976043701172

[RE] Epoch 191
	Mean Delta RE loss = -16.96972
	Gradient norm: 3.4688901901245117
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.362229347229004

[RE] Epoch 192
	Mean Delta RE loss = -17.64009
	Gradient norm: 4.103027820587158
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3700456619262695

[RE] Epoch 193
	Mean Delta RE loss = -17.62687
	Gradient norm: 2.306272268295288
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3638944625854492

[RE] Epoch 194
	Mean Delta RE loss = -17.63587
	Gradient norm: 3.8597655296325684
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.431 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.368668556213379

[RE] Epoch 195
	Mean Delta RE loss = -17.41372
	Gradient norm: 2.4297127723693848
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3703765869140625

[RE] Epoch 196
	Mean Delta RE loss = -17.43218
	Gradient norm: 1.7903109788894653
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3720722198486328

[RE] Epoch 197
	Mean Delta RE loss = -18.69017
	Gradient norm: 1.6432973146438599
	Elapsed time = 3.068 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3737726211547852

[RE] Epoch 198
	Mean Delta RE loss = -17.37750
	Gradient norm: 1.012855052947998
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3719921112060547

[RE] Epoch 199
	Mean Delta RE loss = -18.25661
	Gradient norm: 1.7335808277130127
	Elapsed time = 3.068 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3704967498779297

[RE] Epoch 200
	Mean Delta RE loss = -18.81231
	Gradient norm: 2.7832722663879395
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3717575073242188

[RE] Epoch 201
	Mean Delta RE loss = -18.75987
	Gradient norm: 3.146716356277466
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.379791259765625

[RE] Epoch 202
	Mean Delta RE loss = -18.87007
	Gradient norm: 0.4509009122848511
	Elapsed time = 3.068 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3695077896118164

[RE] Epoch 203
	Mean Delta RE loss = -17.70953
	Gradient norm: 3.6651222705841064
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3766002655029297

[RE] Epoch 204
	Mean Delta RE loss = -18.51177
	Gradient norm: 0.9601520299911499
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3739690780639648

[RE] Epoch 205
	Mean Delta RE loss = -18.77411
	Gradient norm: 2.244922637939453
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3775825500488281

[RE] Epoch 206
	Mean Delta RE loss = -18.58600
	Gradient norm: 0.7444601058959961
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3743934631347656

[RE] Epoch 207
	Mean Delta RE loss = -19.77736
	Gradient norm: 0.9672713875770569
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3729734420776367

[RE] Epoch 208
	Mean Delta RE loss = -19.02344
	Gradient norm: 1.5159153938293457
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3706073760986328

[RE] Epoch 209
	Mean Delta RE loss = -19.29267
	Gradient norm: 2.7292375564575195
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3718252182006836

[RE] Epoch 210
	Mean Delta RE loss = -19.47379
	Gradient norm: 4.1125874519348145
	Elapsed time = 3.052 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3832206726074219

[RE] Epoch 211
	Mean Delta RE loss = -19.28185
	Gradient norm: 0.35569366812705994
	Elapsed time = 3.066 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3768682479858398

[RE] Epoch 212
	Mean Delta RE loss = -19.52115
	Gradient norm: 0.8547153472900391
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3746023178100586

[RE] Epoch 213
	Mean Delta RE loss = -19.18061
	Gradient norm: 1.59763503074646
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.377192497253418

[RE] Epoch 214
	Mean Delta RE loss = -19.06534
	Gradient norm: 1.365769624710083
	Elapsed time = 3.071 min
[Statepoint 0]
	kT = 2.434 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.375197410583496

[RE] Epoch 215
	Mean Delta RE loss = -20.22354
	Gradient norm: 1.591800570487976
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3750219345092773

[RE] Epoch 216
	Mean Delta RE loss = -20.33745
	Gradient norm: 2.5360400676727295
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.428 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3791894912719727

[RE] Epoch 217
	Mean Delta RE loss = -20.33985
	Gradient norm: 1.0544513463974
	Elapsed time = 3.066 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3777790069580078

[RE] Epoch 218
	Mean Delta RE loss = -19.04720
	Gradient norm: 1.7686244249343872
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3793220520019531

[RE] Epoch 219
	Mean Delta RE loss = -19.99051
	Gradient norm: 1.0310062170028687
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3802728652954102

[RE] Epoch 220
	Mean Delta RE loss = -18.30895
	Gradient norm: 0.8116942644119263
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.378011703491211

[RE] Epoch 221
	Mean Delta RE loss = -19.24472
	Gradient norm: 0.9649312496185303
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.376633644104004

[RE] Epoch 222
	Mean Delta RE loss = -20.27152
	Gradient norm: 2.4916460514068604
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3828868865966797

[RE] Epoch 223
	Mean Delta RE loss = -19.60775
	Gradient norm: 0.3950134813785553
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3773527145385742

[RE] Epoch 224
	Mean Delta RE loss = -19.38570
	Gradient norm: 2.0825283527374268
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.381577491760254

[RE] Epoch 225
	Mean Delta RE loss = -20.16046
	Gradient norm: 0.5681160688400269
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.379227638244629

[RE] Epoch 226
	Mean Delta RE loss = -20.01424
	Gradient norm: 0.823721706867218
	Elapsed time = 3.053 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3818626403808594

[RE] Epoch 227
	Mean Delta RE loss = -20.37399
	Gradient norm: 0.8904100060462952
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.379563331604004

[RE] Epoch 228
	Mean Delta RE loss = -20.07648
	Gradient norm: 1.5096255540847778
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3817243576049805

[RE] Epoch 229
	Mean Delta RE loss = -21.14618
	Gradient norm: 0.9427226781845093
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3807945251464844

[RE] Epoch 230
	Mean Delta RE loss = -20.61677
	Gradient norm: 1.0540525913238525
	Elapsed time = 3.067 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3811731338500977

[RE] Epoch 231
	Mean Delta RE loss = -19.75109
	Gradient norm: 0.7723918557167053
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.380136489868164

[RE] Epoch 232
	Mean Delta RE loss = -20.35836
	Gradient norm: 0.6921436190605164
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.432 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3807144165039062

[RE] Epoch 233
	Mean Delta RE loss = -20.70519
	Gradient norm: 1.0420913696289062
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.431 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3803510665893555

[RE] Epoch 234
	Mean Delta RE loss = -20.40872
	Gradient norm: 1.5631952285766602
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.434 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.381667137145996

[RE] Epoch 235
	Mean Delta RE loss = -20.02763
	Gradient norm: 1.2444238662719727
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3813962936401367

[RE] Epoch 236
	Mean Delta RE loss = -20.70097
	Gradient norm: 0.7559975981712341
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3811883926391602

[RE] Epoch 237
	Mean Delta RE loss = -20.93683
	Gradient norm: 1.2606236934661865
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.383584976196289

[RE] Epoch 238
	Mean Delta RE loss = -20.99779
	Gradient norm: 0.5830962061882019
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.448 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3799123764038086

[RE] Epoch 239
	Mean Delta RE loss = -21.04637
	Gradient norm: 1.9518108367919922
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3811836242675781

[RE] Epoch 240
	Mean Delta RE loss = -20.57918
	Gradient norm: 1.3013523817062378
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.381204605102539

[RE] Epoch 241
	Mean Delta RE loss = -20.68107
	Gradient norm: 1.3157799243927002
	Elapsed time = 3.054 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.382817268371582

[RE] Epoch 242
	Mean Delta RE loss = -20.51052
	Gradient norm: 1.0846946239471436
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.432 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3822698593139648

[RE] Epoch 243
	Mean Delta RE loss = -20.98361
	Gradient norm: 1.60535728931427
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3828697204589844

[RE] Epoch 244
	Mean Delta RE loss = -20.62286
	Gradient norm: 0.890728235244751
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.383316993713379

[RE] Epoch 245
	Mean Delta RE loss = -21.32641
	Gradient norm: 0.6880960464477539
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3828134536743164

[RE] Epoch 246
	Mean Delta RE loss = -21.17336
	Gradient norm: 0.8725855350494385
	Elapsed time = 3.068 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.382314682006836

[RE] Epoch 247
	Mean Delta RE loss = -21.55794
	Gradient norm: 1.718909502029419
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3822288513183594

[RE] Epoch 248
	Mean Delta RE loss = -21.16269
	Gradient norm: 1.4922560453414917
	Elapsed time = 3.054 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.382063865661621

[RE] Epoch 249
	Mean Delta RE loss = -21.36727
	Gradient norm: 2.784839391708374
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3837671279907227

[RE] Epoch 250
	Mean Delta RE loss = -21.45856
	Gradient norm: 0.8262881636619568
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3829317092895508

[RE] Epoch 251
	Mean Delta RE loss = -21.34903
	Gradient norm: 1.9633859395980835
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3849353790283203

[RE] Epoch 252
	Mean Delta RE loss = -21.55167
	Gradient norm: 0.3001972734928131
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3839836120605469

[RE] Epoch 253
	Mean Delta RE loss = -21.67530
	Gradient norm: 0.6494846940040588
	Elapsed time = 3.067 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.383493423461914

[RE] Epoch 254
	Mean Delta RE loss = -20.77075
	Gradient norm: 0.6158787608146667
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3835153579711914

[RE] Epoch 255
	Mean Delta RE loss = -20.63047
	Gradient norm: 0.8258280754089355
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.38336181640625

[RE] Epoch 256
	Mean Delta RE loss = -20.86545
	Gradient norm: 1.2712103128433228
	Elapsed time = 3.066 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3846197128295898

[RE] Epoch 257
	Mean Delta RE loss = -20.50692
	Gradient norm: 0.2751625180244446
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3839454650878906

[RE] Epoch 258
	Mean Delta RE loss = -20.55837
	Gradient norm: 0.7165603041648865
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3832759857177734

[RE] Epoch 259
	Mean Delta RE loss = -20.28492
	Gradient norm: 0.8787643909454346
	Elapsed time = 3.067 min
[Statepoint 0]
	kT = 2.447 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.383742332458496

[RE] Epoch 260
	Mean Delta RE loss = -20.67515
	Gradient norm: 0.4644695520401001
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3837223052978516

[RE] Epoch 261
	Mean Delta RE loss = -20.52976
	Gradient norm: 1.1340821981430054
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3835840225219727

[RE] Epoch 262
	Mean Delta RE loss = -20.87739
	Gradient norm: 1.3011189699172974
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3834638595581055

[RE] Epoch 263
	Mean Delta RE loss = -21.09548
	Gradient norm: 2.172177314758301
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.384577751159668

[RE] Epoch 264
	Mean Delta RE loss = -20.98687
	Gradient norm: 0.7959808707237244
	Elapsed time = 3.056 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.384622573852539

[RE] Epoch 265
	Mean Delta RE loss = -20.81060
	Gradient norm: 0.6294431090354919
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.383925437927246

[RE] Epoch 266
	Mean Delta RE loss = -21.14562
	Gradient norm: 0.799331784248352
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3845205307006836

[RE] Epoch 267
	Mean Delta RE loss = -21.00014
	Gradient norm: 0.5184598565101624
	Elapsed time = 3.049 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.384185791015625

[RE] Epoch 268
	Mean Delta RE loss = -20.85096
	Gradient norm: 0.9197970628738403
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3848485946655273

[RE] Epoch 269
	Mean Delta RE loss = -21.43647
	Gradient norm: 0.48857057094573975
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.384206771850586

[RE] Epoch 270
	Mean Delta RE loss = -21.31525
	Gradient norm: 0.845851480960846
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3842144012451172

[RE] Epoch 271
	Mean Delta RE loss = -21.30834
	Gradient norm: 0.7804965972900391
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3838081359863281

[RE] Epoch 272
	Mean Delta RE loss = -21.50130
	Gradient norm: 1.9039368629455566
	Elapsed time = 3.051 min
[Statepoint 0]
	kT = 2.446 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3845233917236328

[RE] Epoch 273
	Mean Delta RE loss = -21.40242
	Gradient norm: 1.6037694215774536
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3840951919555664

[RE] Epoch 274
	Mean Delta RE loss = -21.60691
	Gradient norm: 4.168537616729736
	Elapsed time = 3.050 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.384657859802246

[RE] Epoch 275
	Mean Delta RE loss = -21.68963
	Gradient norm: 1.6718472242355347
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3855705261230469

[RE] Epoch 276
	Mean Delta RE loss = -21.57642
	Gradient norm: 0.43147537112236023
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.385148048400879

[RE] Epoch 277
	Mean Delta RE loss = -21.93787
	Gradient norm: 0.8026583790779114
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.434 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3844528198242188

[RE] Epoch 278
	Mean Delta RE loss = -21.61170
	Gradient norm: 2.393871784210205
	Elapsed time = 3.065 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.385289192199707

[RE] Epoch 279
	Mean Delta RE loss = -21.38023
	Gradient norm: 0.8842772841453552
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3845062255859375

[RE] Epoch 280
	Mean Delta RE loss = -21.27612
	Gradient norm: 2.326753616333008
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3849306106567383

[RE] Epoch 281
	Mean Delta RE loss = -21.30926
	Gradient norm: 1.68629789352417
	Elapsed time = 3.057 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3853425979614258

[RE] Epoch 282
	Mean Delta RE loss = -21.35375
	Gradient norm: 1.0329744815826416
	Elapsed time = 3.066 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3858747482299805

[RE] Epoch 283
	Mean Delta RE loss = -21.67064
	Gradient norm: 0.2095198631286621
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3851404190063477

[RE] Epoch 284
	Mean Delta RE loss = -21.75385
	Gradient norm: 0.9696405529975891
	Elapsed time = 3.055 min
[Statepoint 0]
	kT = 2.434 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3846197128295898

[RE] Epoch 285
	Mean Delta RE loss = -22.23992
	Gradient norm: 4.410651206970215
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.447 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3847742080688477

[RE] Epoch 286
	Mean Delta RE loss = -22.16305
	Gradient norm: 4.064218044281006
	Elapsed time = 3.062 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3852310180664062

[RE] Epoch 287
	Mean Delta RE loss = -22.03316
	Gradient norm: 2.4372479915618896
	Elapsed time = 3.067 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3851661682128906

[RE] Epoch 288
	Mean Delta RE loss = -21.74469
	Gradient norm: 2.439411163330078
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3856897354125977

[RE] Epoch 289
	Mean Delta RE loss = -21.77463
	Gradient norm: 0.6691542267799377
	Elapsed time = 3.060 min
[Statepoint 0]
	kT = 2.432 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3855218887329102

[RE] Epoch 290
	Mean Delta RE loss = -21.63434
	Gradient norm: 0.736581563949585
	Elapsed time = 3.064 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3853273391723633

[RE] Epoch 291
	Mean Delta RE loss = -21.49836
	Gradient norm: 1.1609324216842651
	Elapsed time = 3.067 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.384993553161621

[RE] Epoch 292
	Mean Delta RE loss = -21.48102
	Gradient norm: 1.83067786693573
	Elapsed time = 3.061 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3852176666259766

[RE] Epoch 293
	Mean Delta RE loss = -21.55817
	Gradient norm: 1.1055381298065186
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3855857849121094

[RE] Epoch 294
	Mean Delta RE loss = -21.41875
	Gradient norm: 0.48371773958206177
	Elapsed time = 3.066 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3855619430541992

[RE] Epoch 295
	Mean Delta RE loss = -21.47185
	Gradient norm: 0.507348895072937
	Elapsed time = 3.053 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3854360580444336

[RE] Epoch 296
	Mean Delta RE loss = -21.49526
	Gradient norm: 0.44295579195022583
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.431 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3855581283569336

[RE] Epoch 297
	Mean Delta RE loss = -21.84192
	Gradient norm: 0.7438615560531616
	Elapsed time = 3.063 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3852729797363281

[RE] Epoch 298
	Mean Delta RE loss = -21.96592
	Gradient norm: 1.0665724277496338
	Elapsed time = 3.058 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3857364654541016

[RE] Epoch 299
	Mean Delta RE loss = -21.93011
	Gradient norm: 0.4649997055530548
	Elapsed time = 3.059 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
Total training time:  15.3 hours
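
The log above also reports the sampled temperature of every epoch, which the plots do not show. As a minimal sketch, the entries can be parsed to check how far `kT` deviates from `ref_kT`. The helper `parse_re_log` and the regular expression are illustrative assumptions about the log layout shown above and are not part of chemtrain.

```python
import re

# Assumed layout of one log entry, based on the output shown above.
ENTRY = re.compile(
    r"\[RE\] Epoch (\d+)\s+"
    r"Mean Delta RE loss = (-?\d+\.\d+)\s+"
    r"Gradient norm: (\d+\.\d+(?:e[-+]?\d+)?)\s+"
    r"Elapsed time = \d+\.\d+ min\s+"
    r"\[Statepoint \d+\]\s+"
    r"kT = (\d+\.\d+) ref_kT = (\d+\.\d+)"
)


def parse_re_log(text):
    """Return one dict per epoch found in the relative entropy training log."""
    records = []
    for match in ENTRY.finditer(text):
        epoch, delta_re, grad_norm, kT, ref_kT = match.groups()
        records.append({
            "epoch": int(epoch),
            "delta_re": float(delta_re),
            "gradient_norm": float(grad_norm),
            "kT": float(kT),
            "ref_kT": float(ref_kT),
        })
    return records


# Two entries copied from the log above (epochs 298 and 299).
sample = (
    "[RE] Epoch 298\n"
    "\tMean Delta RE loss = -21.96592\n"
    "\tGradient norm: 1.0665724277496338\n"
    "\tElapsed time = 3.058 min\n"
    "[Statepoint 0]\n"
    "\tkT = 2.444 ref_kT = 2.494\n"
    "[Step Size] Found optimal step size 1.0 with residual 1.3857364654541016\n"
    "\n"
    "[RE] Epoch 299\n"
    "\tMean Delta RE loss = -21.93011\n"
    "\tGradient norm: 0.4649997055530548\n"
    "\tElapsed time = 3.059 min\n"
    "[Statepoint 0]\n"
    "\tkT = 2.445 ref_kT = 2.494\n"
)

records = parse_re_log(sample)
for rec in records:
    # Relative deviation of the sampled temperature from the reference.
    deviation = (rec["kT"] - rec["ref_kT"]) / rec["ref_kT"]
    print(f"epoch {rec['epoch']}: kT deviates by {deviation:+.1%}")
```

In the epochs shown here, the sampled `kT` stays roughly 2 % below `ref_kT`.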

The change in relative entropy and the gradient norm, plotted over the $300$ training epochs, indicate that the algorithm converges.

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4), layout="constrained")

ax1.plot(relative_entropy["delta_re"][0])
ax1.set_xticks(ticks=range(0, re_epochs + 1, 50))
ax1.set_xlabel("Epoch")
ax1.set_ylabel("RE Loss")

ax2.plot(relative_entropy["gradient_norm_history"])
ax2.set_xticks(ticks=range(0, re_epochs + 1, 50))
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Gradient Norm")
../_images/245f6b071832ef410162712370cb6beed08df942606c4b3439952459d96310bd.svg
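
Both curves fluctuate from epoch to epoch, so the trend is easier to judge after smoothing. The following sketch applies a simple moving average. The helper `moving_average` and the window of five epochs are illustrative choices and not part of chemtrain.

```python
def moving_average(values, window=10):
    """Average each run of `window` consecutive values."""
    values = list(values)
    if window < 1 or window > len(values):
        raise ValueError("window must be between 1 and len(values)")
    return [
        sum(values[i:i + window]) / window
        for i in range(len(values) - window + 1)
    ]


# Mean Delta RE loss of epochs 290 to 299, copied from the log above.
delta_re_tail = [
    -21.63434, -21.49836, -21.48102, -21.55817, -21.41875,
    -21.47185, -21.49526, -21.84192, -21.96592, -21.93011,
]

smoothed = moving_average(delta_re_tail, window=5)
print([round(value, 3) for value in smoothed])
```

In the notebook, the same helper could be applied to `relative_entropy["delta_re"][0]`, assuming it holds one value per epoch.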

Force Matching#

As a reference, we train a second instance of the potential model via FM. As with REM, we use only a subset of the data. FM, however, additionally requires held-out validation data, mainly to guard against overfitting, so the share of the data available for training is smaller.
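
To make the data budget concrete, the following sketch converts split ratios into sample counts for the values used in this example. It assumes that samples assigned neither to training nor to validation are held out for testing. The helper `split_sizes` is hypothetical and only mirrors the arithmetic, not chemtrain's implementation.

```python
def split_sizes(n_samples, train_ratio, val_ratio):
    """Return the number of training, validation, and remaining samples."""
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1:
        raise ValueError("ratios must be non-negative and sum to at most 1")
    n_train = round(n_samples * train_ratio)
    n_val = round(n_samples * val_ratio)
    # Assumption: everything left over is kept as a test set.
    n_test = n_samples - n_train - n_val
    return n_train, n_val, n_test


n_train, n_val, n_test = split_sizes(500_000, train_ratio=0.7, val_ratio=0.1)
print(f"train: {n_train}, validation: {n_val}, remaining: {n_test}")
```

With a batch size of 500 samples, this training split corresponds to 700 batches per epoch.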

fm_epochs = 100
fm_used_dataset_size = 500000
fm_train_ratio = 0.7
fm_val_ratio = 0.1
fm_batch_per_device = 500 // len(jax.devices())
fm_batch_cache = 50
fm_initial_lr = 0.0003

lrd = int(fm_used_dataset_size / fm_batch_per_device / len(jax.devices()) * fm_epochs)
lr_schedule = optax.exponential_decay(fm_initial_lr, lrd, 0.01)
fm_optimizer = optax.chain(
    optax.scale_by_adam(),
    optax.scale_by_schedule(lr_schedule),
    optax.scale_by_learning_rate(1.0)
)

force_matching = trainers.ForceMatching(
    init_params, fm_optimizer, energy_fn_template, nbrs_init,
    batch_per_device=fm_batch_per_device,
    batch_cache=fm_batch_cache,
)

force_matching.set_datasets({
    'R': position_dataset[:fm_used_dataset_size, ...],
    'F': force_dataset[:fm_used_dataset_size, ...]
}, train_ratio=fm_train_ratio, val_ratio=fm_val_ratio)
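
The schedule defined above decays the learning rate exponentially from `fm_initial_lr` to 1 % of that value after `lrd` optimizer steps. The sketch below re-implements the formula of `optax.exponential_decay` without staircase or offset, `lr(step) = init_value * decay_rate ** (step / transition_steps)`, in plain Python to make the numbers explicit. The local function `exponential_decay` is a stand-in for illustration, and a single device is assumed.

```python
def exponential_decay(init_value, transition_steps, decay_rate):
    """Plain-Python stand-in for a continuous exponential decay schedule."""
    def schedule(step):
        return init_value * decay_rate ** (step / transition_steps)
    return schedule


fm_epochs = 100
fm_used_dataset_size = 500_000
n_devices = 1  # assumption; the notebook uses len(jax.devices())
fm_batch_per_device = 500 // n_devices
fm_initial_lr = 0.0003

# Same expression as in the notebook: batches in the subset times epochs.
lrd = int(fm_used_dataset_size / fm_batch_per_device / n_devices * fm_epochs)
lr_schedule = exponential_decay(fm_initial_lr, lrd, 0.01)

for step in (0, lrd // 2, lrd):
    print(f"step {step:>6d}: lr = {lr_schedule(step):.2e}")
```

In the optimizer chain, `optax.scale_by_schedule` multiplies the Adam-scaled update by this value, and `optax.scale_by_learning_rate(1.0)` flips its sign so that the update is a descent step. Note that `lrd` is computed from the full subset size. If only the training split produces optimizer steps, training ends before the schedule reaches its final value.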
if os.environ.get("FM_TRAINING", "False").lower() == "true":
    # Save the training log
    with open("../_data/output/alanine_dipeptide_fm_training.log", "w") as f:
        with contextlib.redirect_stdout(f), contextlib.redirect_stderr(f):
            
            print(f"Visible devices: {jax.devices()}")
            
            start = time.time()
            force_matching.train(fm_epochs)
            print(f"Total training time: {(time.time() - start) / 3600 : .1f} hours")
    
    force_matching.save_energy_params("../_data/output/alanine_dipeptide_fm_params.pkl", '.pkl', best=False)
    force_matching.save_trainer("../_data/output/alanine_dipeptide_fm_trainer.pkl", '.pkl')
    
force_matching = onp.load("../_data/output/alanine_dipeptide_fm_trainer.pkl", allow_pickle=True)
force_matching_params = tree_util.tree_map(
    jnp.asarray, onp.load("../_data/output/alanine_dipeptide_fm_params.pkl", allow_pickle=True)
)

with open("../_data/output/alanine_dipeptide_fm_training.log") as f:
    print(f.read())


Visible devices: [cuda(id=0)]
/home/paul/miniconda3/envs/chemtrain_rerun_AD/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
[Epoch 0]:
	Average train loss: 176929.70580
	Average val loss: 172895.71875
	Gradient norm: 9123601408.0
	Elapsed time = 3.230 min
	Per-target losses:
		F | train loss: 176929.7058035714 | val loss: 172895.71875

[Epoch 1]:
	Average train loss: 172596.42212
	Average val loss: 172559.46875
	Gradient norm: 11774687232.0
	Elapsed time = 0.680 min
	Per-target losses:
		F | train loss: 172596.42212053572 | val loss: 172559.46875

[Epoch 2]:
	Average train loss: 172276.97830
	Average val loss: 172431.96875
	Gradient norm: 20744620032.0
	Elapsed time = 0.692 min
	Per-target losses:
		F | train loss: 172276.97830357144 | val loss: 172431.96875

[Epoch 3]:
	Average train loss: 172154.51933
	Average val loss: 172473.640625
	Gradient norm: 6008239616.0
	Elapsed time = 0.663 min
	Per-target losses:
		F | train loss: 172154.51933035714 | val loss: 172473.640625

[Epoch 4]:
	Average train loss: 172082.00011
	Average val loss: 172211.15625
	Gradient norm: 10019186688.0
	Elapsed time = 0.658 min
	Per-target losses:
		F | train loss: 172082.00011160714 | val loss: 172211.15625

[Epoch 5]:
	Average train loss: 171997.42679
	Average val loss: 172236.0
	Gradient norm: 10897506304.0
	Elapsed time = 0.681 min
	Per-target losses:
		F | train loss: 171997.42678571428 | val loss: 172236.0

[Epoch 6]:
	Average train loss: 171949.75645
	Average val loss: 172006.609375
	Gradient norm: 8550966784.0
	Elapsed time = 0.704 min
	Per-target losses:
		F | train loss: 171949.75645089286 | val loss: 172006.609375

[Epoch 7]:
	Average train loss: 171910.50455
	Average val loss: 172095.1875
	Gradient norm: 4817648640.0
	Elapsed time = 0.690 min
	Per-target losses:
		F | train loss: 171910.50455357143 | val loss: 172095.1875

[Epoch 8]:
	Average train loss: 171866.08663
	Average val loss: 172141.65625
	Gradient norm: 15040643072.0
	Elapsed time = 0.666 min
	Per-target losses:
		F | train loss: 171866.0866294643 | val loss: 172141.65625

[Epoch 9]:
	Average train loss: 171849.00362
	Average val loss: 171915.90625
	Gradient norm: 2684147456.0
	Elapsed time = 0.664 min
	Per-target losses:
		F | train loss: 171849.00361607142 | val loss: 171915.90625

[Epoch 10]:
	Average train loss: 171815.13138
	Average val loss: 172070.078125
	Gradient norm: 27397423104.0
	Elapsed time = 0.696 min
	Per-target losses:
		F | train loss: 171815.13138392856 | val loss: 172070.078125

[Epoch 11]:
	Average train loss: 171779.52540
	Average val loss: 171947.890625
	Gradient norm: 1806733568.0
	Elapsed time = 0.705 min
	Per-target losses:
		F | train loss: 171779.5254017857 | val loss: 171947.890625

[Epoch 12]:
	Average train loss: 171770.32016
	Average val loss: 171976.6875
	Gradient norm: 3632666112.0
	Elapsed time = 0.686 min
	Per-target losses:
		F | train loss: 171770.32015625 | val loss: 171976.6875

[Epoch 13]:
	Average train loss: 171749.34475
	Average val loss: 171890.625
	Gradient norm: 7541457920.0
	Elapsed time = 0.668 min
	Per-target losses:
		F | train loss: 171749.34475446428 | val loss: 171890.625

[Epoch 14]:
	Average train loss: 171721.61147
	Average val loss: 171821.234375
	Gradient norm: 1706345728.0
	Elapsed time = 0.678 min
	Per-target losses:
		F | train loss: 171721.6114732143 | val loss: 171821.234375

[Epoch 15]:
	Average train loss: 171709.31304
	Average val loss: 171871.328125
	Gradient norm: 9120953344.0
	Elapsed time = 0.695 min
	Per-target losses:
		F | train loss: 171709.3130357143 | val loss: 171871.328125

[Epoch 16]:
	Average train loss: 171696.01545
	Average val loss: 171937.96875
	Gradient norm: 2631065088.0
	Elapsed time = 0.708 min
	Per-target losses:
		F | train loss: 171696.01544642856 | val loss: 171937.96875

[Epoch 17]:
	Average train loss: 171670.82723
	Average val loss: 171854.390625
	Gradient norm: 3674040576.0
	Elapsed time = 0.680 min
	Per-target losses:
		F | train loss: 171670.82723214285 | val loss: 171854.390625

[Epoch 18]:
	Average train loss: 171660.99710
	Average val loss: 171839.640625
	Gradient norm: 5002995200.0
	Elapsed time = 0.667 min
	Per-target losses:
		F | train loss: 171660.99709821428 | val loss: 171839.640625

[Epoch 19]:
	Average train loss: 171659.13237
	Average val loss: 171944.4375
	Gradient norm: 1226780416.0
	Elapsed time = 0.689 min
	Per-target losses:
		F | train loss: 171659.13236607143 | val loss: 171944.4375

[Epoch 20]:
	Average train loss: 171645.16540
	Average val loss: 171778.65625
	Gradient norm: 1932255360.0
	Elapsed time = 0.697 min
	Per-target losses:
		F | train loss: 171645.16540178572 | val loss: 171778.65625

[Epoch 21]:
	Average train loss: 171623.16170
	Average val loss: 171983.484375
	Gradient norm: 7034304000.0
	Elapsed time = 0.706 min
	Per-target losses:
		F | train loss: 171623.16169642858 | val loss: 171983.484375

[Epoch 22]:
	Average train loss: 171621.38083
	Average val loss: 171830.9375
	Gradient norm: 4508636672.0
	Elapsed time = 0.674 min
	Per-target losses:
		F | train loss: 171621.38082589285 | val loss: 171830.9375

[Epoch 23]:
	Average train loss: 171604.86754
	Average val loss: 171777.171875
	Gradient norm: 2017770112.0
	Elapsed time = 0.671 min
	Per-target losses:
		F | train loss: 171604.86754464285 | val loss: 171777.171875

[Epoch 24]:
	Average train loss: 171600.03873
	Average val loss: 171885.9375
	Gradient norm: 2574169344.0
	Elapsed time = 0.696 min
	Per-target losses:
		F | train loss: 171600.03872767856 | val loss: 171885.9375

[Epoch 25]:
	Average train loss: 171585.58248
	Average val loss: 171806.875
	Gradient norm: 2361884416.0
	Elapsed time = 0.712 min
	Per-target losses:
		F | train loss: 171585.58247767857 | val loss: 171806.875

[Epoch 26]:
	Average train loss: 171581.53357
	Average val loss: 171844.015625
	Gradient norm: 3771243008.0
	Elapsed time = 0.695 min
	Per-target losses:
		F | train loss: 171581.53357142856 | val loss: 171844.015625

[Epoch 27]:
	Average train loss: 171568.04761
	Average val loss: 171795.5625
	Gradient norm: 5091220992.0
	Elapsed time = 0.669 min
	Per-target losses:
		F | train loss: 171568.04761160715 | val loss: 171795.5625

[Epoch 28]:
	Average train loss: 171552.08833
	Average val loss: 171883.46875
	Gradient norm: 8399463424.0
	Elapsed time = 0.686 min
	Per-target losses:
		F | train loss: 171552.08832589287 | val loss: 171883.46875

[Epoch 29]:
	Average train loss: 171551.66145
	Average val loss: 171775.5
	Gradient norm: 1917335168.0
	Elapsed time = 0.695 min
	Per-target losses:
		F | train loss: 171551.66145089286 | val loss: 171775.5

[Epoch 30]:
	Average train loss: 171545.76705
	Average val loss: 171759.71875
	Gradient norm: 1608831616.0
	Elapsed time = 0.710 min
	Per-target losses:
		F | train loss: 171545.76705357144 | val loss: 171759.71875

[Epoch 31]:
	Average train loss: 171535.17933
	Average val loss: 171810.46875
	Gradient norm: 5032909824.0
	Elapsed time = 0.687 min
	Per-target losses:
		F | train loss: 171535.17933035715 | val loss: 171810.46875

[Epoch 32]:
	Average train loss: 171533.74353
	Average val loss: 171789.703125
	Gradient norm: 2831327744.0
	Elapsed time = 0.668 min
	Per-target losses:
		F | train loss: 171533.74352678572 | val loss: 171789.703125

[Epoch 33]:
	Average train loss: 171519.66949
	Average val loss: 171854.921875
	Gradient norm: 3584940032.0
	Elapsed time = 0.696 min
	Per-target losses:
		F | train loss: 171519.66948660713 | val loss: 171854.921875

[Epoch 34]:
	Average train loss: 171517.76920
	Average val loss: 171778.578125
	Gradient norm: 903441664.0
	Elapsed time = 0.697 min
	Per-target losses:
		F | train loss: 171517.76919642856 | val loss: 171778.578125

[Epoch 35]:
	Average train loss: 171512.97766
	Average val loss: 171751.21875
	Gradient norm: 1391817344.0
	Elapsed time = 0.712 min
	Per-target losses:
		F | train loss: 171512.97765625 | val loss: 171751.21875

[Epoch 36]:
	Average train loss: 171504.98132
	Average val loss: 171776.484375
	Gradient norm: 1910281728.0
	Elapsed time = 0.674 min
	Per-target losses:
		F | train loss: 171504.9813169643 | val loss: 171776.484375

[Epoch 37]:
	Average train loss: 171500.50136
	Average val loss: 171787.21875
	Gradient norm: 2352769792.0
	Elapsed time = 0.674 min
	Per-target losses:
		F | train loss: 171500.50136160714 | val loss: 171787.21875

[Epoch 38]:
	Average train loss: 171493.98058
	Average val loss: 171779.21875
	Gradient norm: 1320057984.0
	Elapsed time = 0.697 min
	Per-target losses:
		F | train loss: 171493.98058035714 | val loss: 171779.21875

[Epoch 39]:
	Average train loss: 171489.56167
	Average val loss: 171746.890625
	Gradient norm: 2118475264.0
	Elapsed time = 0.696 min
	Per-target losses:
		F | train loss: 171489.56167410716 | val loss: 171746.890625

[Epoch 40]:
	Average train loss: 171489.27815
	Average val loss: 171867.328125
	Gradient norm: 1183177856.0
	Elapsed time = 0.703 min
	Per-target losses:
		F | train loss: 171489.27814732143 | val loss: 171867.328125

[Epoch 41]:
	Average train loss: 171475.39185
	Average val loss: 171737.53125
	Gradient norm: 3577387520.0
	Elapsed time = 0.673 min
	Per-target losses:
		F | train loss: 171475.3918526786 | val loss: 171737.53125

[Epoch 42]:
	Average train loss: 171472.20924
	Average val loss: 171743.03125
	Gradient norm: 4128272384.0
	Elapsed time = 0.689 min
	Per-target losses:
		F | train loss: 171472.20924107142 | val loss: 171743.03125

[Epoch 43]:
	Average train loss: 171472.44147
	Average val loss: 171736.890625
	Gradient norm: 4214556160.0
	Elapsed time = 0.697 min
	Per-target losses:
		F | train loss: 171472.44147321428 | val loss: 171736.890625

[Epoch 44]:
	Average train loss: 171459.43170
	Average val loss: 171746.390625
	Gradient norm: 1125263744.0
	Elapsed time = 0.695 min
	Per-target losses:
		F | train loss: 171459.43169642857 | val loss: 171746.390625

[Epoch 45]:
	Average train loss: 171456.98261
	Average val loss: 171724.40625
	Gradient norm: 1239106304.0
	Elapsed time = 0.692 min
	Per-target losses:
		F | train loss: 171456.98261160715 | val loss: 171724.40625

[Epoch 46]:
	Average train loss: 171450.39911
	Average val loss: 171722.3125
	Gradient norm: 1474299904.0
	Elapsed time = 0.670 min
	Per-target losses:
		F | train loss: 171450.39910714285 | val loss: 171722.3125

[Epoch 47]:
	Average train loss: 171450.30375
	Average val loss: 171723.5
	Gradient norm: 581114176.0
	Elapsed time = 0.701 min
	Per-target losses:
		F | train loss: 171450.30375 | val loss: 171723.5

[Epoch 48]:
	Average train loss: 171445.57556
	Average val loss: 171750.328125
	Gradient norm: 1099431424.0
	Elapsed time = 0.700 min
	Per-target losses:
		F | train loss: 171445.57555803572 | val loss: 171750.328125

[Epoch 49]:
	Average train loss: 171442.09750
	Average val loss: 171704.609375
	Gradient norm: 2439143168.0
	Elapsed time = 0.717 min
	Per-target losses:
		F | train loss: 171442.0975 | val loss: 171704.609375

[Epoch 50]:
	Average train loss: 171441.89283
	Average val loss: 171731.046875
	Gradient norm: 1725326592.0
	Elapsed time = 0.672 min
	Per-target losses:
		F | train loss: 171441.89283482142 | val loss: 171731.046875

[Epoch 51]:
	Average train loss: 171430.63596
	Average val loss: 171784.375
	Gradient norm: 2093662080.0
	Elapsed time = 0.683 min
	Per-target losses:
		F | train loss: 171430.63595982143 | val loss: 171784.375

[Epoch 52]:
	Average train loss: 171433.62161
	Average val loss: 171714.84375
	Gradient norm: 1248705152.0
	Elapsed time = 0.698 min
	Per-target losses:
		F | train loss: 171433.62160714285 | val loss: 171714.84375

[Epoch 53]:
	Average train loss: 171427.84569
	Average val loss: 171737.71875
	Gradient norm: 2855405056.0
	Elapsed time = 0.707 min
	Per-target losses:
		F | train loss: 171427.8456919643 | val loss: 171737.71875

[Epoch 54]:
	Average train loss: 171420.17080
	Average val loss: 171714.203125
	Gradient norm: 2034437632.0
	Elapsed time = 0.701 min
	Per-target losses:
		F | train loss: 171420.17080357144 | val loss: 171714.203125

[Epoch 55]:
	Average train loss: 171417.33560
	Average val loss: 171710.609375
	Gradient norm: 784035264.0
	Elapsed time = 0.685 min
	Per-target losses:
		F | train loss: 171417.33560267856 | val loss: 171710.609375

[Epoch 56]:
	Average train loss: 171419.52964
	Average val loss: 171712.84375
	Gradient norm: 1742775552.0
	Elapsed time = 0.699 min
	Per-target losses:
		F | train loss: 171419.52964285715 | val loss: 171712.84375

[Epoch 57]:
	Average train loss: 171413.24033
	Average val loss: 171718.59375
	Gradient norm: 1332674688.0
	Elapsed time = 0.706 min
	Per-target losses:
		F | train loss: 171413.24033482143 | val loss: 171718.59375

[Epoch 58]:
	Average train loss: 171412.05859
	Average val loss: 171696.890625
	Gradient norm: 1831571712.0
	Elapsed time = 0.701 min
	Per-target losses:
		F | train loss: 171412.05859375 | val loss: 171696.890625

[Epoch 59]:
	Average train loss: 171406.59828
	Average val loss: 171721.0625
	Gradient norm: 1702457600.0
	Elapsed time = 0.693 min
	Per-target losses:
		F | train loss: 171406.59828125 | val loss: 171721.0625

[Epoch 60]:
	Average train loss: 171406.02123
	Average val loss: 171692.390625
	Gradient norm: 917907776.0
	Elapsed time = 0.682 min
	Per-target losses:
		F | train loss: 171406.02122767858 | val loss: 171692.390625

[Epoch 61]:
	Average train loss: 171401.17757
	Average val loss: 171712.875
	Gradient norm: 993114368.0
	Elapsed time = 0.706 min
	Per-target losses:
		F | train loss: 171401.1775669643 | val loss: 171712.875

[Epoch 62]:
	Average train loss: 171398.06547
	Average val loss: 171711.015625
	Gradient norm: 1238548224.0
	Elapsed time = 0.711 min
	Per-target losses:
		F | train loss: 171398.06546875 | val loss: 171711.015625

[Epoch 63]:
	Average train loss: 171392.70897
	Average val loss: 171695.015625
	Gradient norm: 1207603200.0
	Elapsed time = 0.698 min
	Per-target losses:
		F | train loss: 171392.7089732143 | val loss: 171695.015625

[Epoch 64]:
	Average train loss: 171392.10917
	Average val loss: 171719.34375
	Gradient norm: 1597482240.0
	Elapsed time = 0.685 min
	Per-target losses:
		F | train loss: 171392.10917410714 | val loss: 171719.34375

[Epoch 65]:
	Average train loss: 171388.53275
	Average val loss: 171716.5625
	Gradient norm: 3158761984.0
	Elapsed time = 0.689 min
	Per-target losses:
		F | train loss: 171388.53274553572 | val loss: 171716.5625

[Epoch 66]:
	Average train loss: 171386.20917
	Average val loss: 171748.21875
	Gradient norm: 2455853824.0
	Elapsed time = 0.703 min
	Per-target losses:
		F | train loss: 171386.20917410715 | val loss: 171748.21875

[Epoch 67]:
	Average train loss: 171384.47759
	Average val loss: 171691.296875
	Gradient norm: 2556560384.0
	Elapsed time = 0.705 min
	Per-target losses:
		F | train loss: 171384.47758928573 | val loss: 171691.296875

[Epoch 68]:
	Average train loss: 171381.41324
	Average val loss: 171720.453125
	Gradient norm: 1343042176.0
	Elapsed time = 0.689 min
	Per-target losses:
		F | train loss: 171381.41323660716 | val loss: 171720.453125

[Epoch 69]:
	Average train loss: 171376.76217
	Average val loss: 171699.796875
	Gradient norm: 722508608.0
	Elapsed time = 0.691 min
	Per-target losses:
		F | train loss: 171376.76216517857 | val loss: 171699.796875

[Epoch 70]:
	Average train loss: 171376.53127
	Average val loss: 171706.5
	Gradient norm: 2094104448.0
	Elapsed time = 0.696 min
	Per-target losses:
		F | train loss: 171376.53127232142 | val loss: 171706.5

[Epoch 71]:
	Average train loss: 171376.45504
	Average val loss: 171710.84375
	Gradient norm: 2235779584.0
	Elapsed time = 0.702 min
	Per-target losses:
		F | train loss: 171376.45504464285 | val loss: 171710.84375

[Epoch 72]:
	Average train loss: 171371.00333
	Average val loss: 171708.03125
	Gradient norm: 789579584.0
	Elapsed time = 0.707 min
	Per-target losses:
		F | train loss: 171371.00332589285 | val loss: 171708.03125

[Epoch 73]:
	Average train loss: 171371.73154
	Average val loss: 171703.90625
	Gradient norm: 2599988736.0
	Elapsed time = 0.680 min
	Per-target losses:
		F | train loss: 171371.73154017856 | val loss: 171703.90625

[Epoch 74]:
	Average train loss: 171368.44261
	Average val loss: 171695.96875
	Gradient norm: 2016055552.0
	Elapsed time = 0.694 min
	Per-target losses:
		F | train loss: 171368.44261160714 | val loss: 171695.96875

[Epoch 75]:
	Average train loss: 171364.46337
	Average val loss: 171700.015625
	Gradient norm: 1387525248.0
	Elapsed time = 0.706 min
	Per-target losses:
		F | train loss: 171364.46337053573 | val loss: 171700.015625

[Epoch 76]:
	Average train loss: 171362.98002
	Average val loss: 171681.78125
	Gradient norm: 520471392.0
	Elapsed time = 0.705 min
	Per-target losses:
		F | train loss: 171362.98002232143 | val loss: 171681.78125

[Epoch 77]:
	Average train loss: 171362.73337
	Average val loss: 171675.4375
	Gradient norm: 808506304.0
	Elapsed time = 0.696 min
	Per-target losses:
		F | train loss: 171362.73337053572 | val loss: 171675.4375

[Epoch 78]:
	Average train loss: 171357.42449
	Average val loss: 171682.46875
	Gradient norm: 1157684608.0
	Elapsed time = 0.679 min
	Per-target losses:
		F | train loss: 171357.42448660714 | val loss: 171682.46875

[Epoch 79]:
	Average train loss: 171356.19129
	Average val loss: 171712.125
	Gradient norm: 982082112.0
	Elapsed time = 0.707 min
	Per-target losses:
		F | train loss: 171356.19129464286 | val loss: 171712.125

[Epoch 80]:
	Average train loss: 171357.02752
	Average val loss: 171707.359375
	Gradient norm: 732747200.0
	Elapsed time = 0.703 min
	Per-target losses:
		F | train loss: 171357.02752232144 | val loss: 171707.359375

[Epoch 81]:
	Average train loss: 171353.26801
	Average val loss: 171684.8125
	Gradient norm: 1249264512.0
	Elapsed time = 0.700 min
	Per-target losses:
		F | train loss: 171353.26801339287 | val loss: 171684.8125

[Epoch 82]:
	Average train loss: 171350.82027
	Average val loss: 171681.640625
	Gradient norm: 1534267520.0
	Elapsed time = 0.688 min
	Per-target losses:
		F | train loss: 171350.82026785714 | val loss: 171681.640625

[Epoch 83]:
	Average train loss: 171350.81031
	Average val loss: 171680.8125
	Gradient norm: 1149604352.0
	Elapsed time = 0.689 min
	Per-target losses:
		F | train loss: 171350.8103125 | val loss: 171680.8125

[Epoch 84]:
	Average train loss: 171348.69100
	Average val loss: 171693.578125
	Gradient norm: 1036749952.0
	Elapsed time = 0.706 min
	Per-target losses:
		F | train loss: 171348.6910044643 | val loss: 171693.578125

[Epoch 85]:
	Average train loss: 171348.47580
	Average val loss: 171684.65625
	Gradient norm: 1325895040.0
	Elapsed time = 0.707 min
	Per-target losses:
		F | train loss: 171348.47580357143 | val loss: 171684.65625

[Epoch 86]:
	Average train loss: 171344.73935
	Average val loss: 171686.203125
	Gradient norm: 1524068352.0
	Elapsed time = 0.703 min
	Per-target losses:
		F | train loss: 171344.73935267856 | val loss: 171686.203125

[Epoch 87]:
	Average train loss: 171343.08781
	Average val loss: 171687.953125
	Gradient norm: 2213083904.0
	Elapsed time = 0.677 min
	Per-target losses:
		F | train loss: 171343.0878125 | val loss: 171687.953125

[Epoch 88]:
	Average train loss: 171342.23192
	Average val loss: 171689.421875
	Gradient norm: 2924069120.0
	Elapsed time = 0.705 min
	Per-target losses:
		F | train loss: 171342.23191964286 | val loss: 171689.421875

[Epoch 89]:
	Average train loss: 171340.33558
	Average val loss: 171685.125
	Gradient norm: 655298496.0
	Elapsed time = 0.707 min
	Per-target losses:
		F | train loss: 171340.33558035715 | val loss: 171685.125

[Epoch 90]:
	Average train loss: 171340.86862
	Average val loss: 171683.84375
	Gradient norm: 2268468480.0
	Elapsed time = 0.702 min
	Per-target losses:
		F | train loss: 171340.86861607144 | val loss: 171683.84375

[Epoch 91]:
	Average train loss: 171339.36000
	Average val loss: 171683.625
	Gradient norm: 876853952.0
	Elapsed time = 0.690 min
	Per-target losses:
		F | train loss: 171339.36 | val loss: 171683.625

[Epoch 92]:
	Average train loss: 171337.82900
	Average val loss: 171687.859375
	Gradient norm: 796945856.0
	Elapsed time = 0.676 min
	Per-target losses:
		F | train loss: 171337.8289955357 | val loss: 171687.859375

[Epoch 93]:
	Average train loss: 171336.64951
	Average val loss: 171691.265625
	Gradient norm: 1060686848.0
	Elapsed time = 0.712 min
	Per-target losses:
		F | train loss: 171336.64950892856 | val loss: 171691.265625

[Epoch 94]:
	Average train loss: 171336.63949
	Average val loss: 171694.0625
	Gradient norm: 2352739072.0
	Elapsed time = 0.706 min
	Per-target losses:
		F | train loss: 171336.63948660714 | val loss: 171694.0625

[Epoch 95]:
	Average train loss: 171332.43127
	Average val loss: 171685.46875
	Gradient norm: 1889593472.0
	Elapsed time = 0.704 min
	Per-target losses:
		F | train loss: 171332.43127232144 | val loss: 171685.46875

[Epoch 96]:
	Average train loss: 171332.11388
	Average val loss: 171684.859375
	Gradient norm: 1890586880.0
	Elapsed time = 0.679 min
	Per-target losses:
		F | train loss: 171332.11388392857 | val loss: 171684.859375

[Epoch 97]:
	Average train loss: 171329.96730
	Average val loss: 171693.203125
	Gradient norm: 2480389120.0
	Elapsed time = 0.681 min
	Per-target losses:
		F | train loss: 171329.96729910714 | val loss: 171693.203125

[Epoch 98]:
	Average train loss: 171330.64931
	Average val loss: 171689.8125
	Gradient norm: 5878473728.0
	Elapsed time = 0.713 min
	Per-target losses:
		F | train loss: 171330.6493080357 | val loss: 171689.8125

[Epoch 99]:
	Average train loss: 171329.47799
	Average val loss: 171688.484375
	Gradient norm: 501658976.0
	Elapsed time = 0.700 min
	Per-target losses:
		F | train loss: 171329.47799107144 | val loss: 171688.484375

Total training time:  1.3 hours

The training and validation losses plotted below indicate convergence after approximately $50$ epochs. Further training improves the validation loss only slightly but does not lead to overfitting. Therefore, we save the final model after $100$ training epochs.

fig, ax = plt.subplots(1, 1, figsize=(5, 4), layout="constrained")
ax.plot(force_matching["train_losses"], label="Training")
ax.plot(force_matching["val_losses"], label="Validation")
ax.set_xticks(ticks=range(0, fm_epochs + 1, 25))
ax.set_xlabel("Epoch")
ax.set_ylabel("MSE Force Loss")
ax.legend()
[Figure: training and validation MSE force loss per epoch for the force matching run]
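
The convergence statement can be checked against the log. The following is a minimal sketch, not part of the original notebook: it uses four validation losses copied from the training log above (epochs 0, 49, 77, and 99), and the names `val_loss` and `late_fraction` are illustrative.

```python
# Validation losses copied from the force matching log above.
# Only four epochs are included here; this is not the full loss history.
val_loss = {
    0: 172895.71875,
    49: 171704.609375,
    77: 171675.4375,
    99: 171688.484375,
}

# Total reduction over the run and the part gained after epoch 49.
total_drop = val_loss[0] - val_loss[99]
late_drop = val_loss[49] - val_loss[99]
late_fraction = late_drop / total_drop

# Epoch with the lowest validation loss among the selected epochs.
best_epoch = min(val_loss, key=val_loss.get)

print(f"Reduction after epoch 49: {late_drop:.3f} "
      f"({100 * late_fraction:.1f} % of the total reduction)")
print(f"Lowest validation loss among selected epochs: epoch {best_epoch}")
```

With these values, the second half of training accounts for only about 1 % of the total reduction in validation loss, and the lowest validation loss in the log occurs at epoch 77 rather than in the final epoch.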

Force Matching and Relative Entropy Minimization#

Because training with relative entropy minimization is expensive, we train a third model from an improved initial state. As the improved initial state, we use the FM-trained model with the lowest validation loss.

rm_post_epochs = 50
rm_post_initial_lr = 0.0005

lr_schedule = optax.exponential_decay(rm_post_initial_lr, rm_post_epochs, 0.01)
optimizer = optax.chain(
    optax.scale_by_adam(0.1, 0.4),
    optax.scale_by_schedule(lr_schedule),
    optax.scale_by_learning_rate(1.0)
)

rm_post_fm = trainers.RelativeEntropy(
    force_matching_params, optimizer, reweight_ratio=1.1,
    energy_fn_template=energy_fn_template)

rm_post_fm.add_statepoint(
    position_dataset[:re_used_dataset_size, ...],
    energy_fn_template, sim_template, neighbor_fn,
    re_timings, state_kwargs, reference_state,
    reference_batch_size=re_used_dataset_size,
    vmap_batch=n_chains, resample_simstates=True)

rm_post_fm.init_step_size_adaption(0.25)
/home/paul/chemtrain_rerun_ad/chemtrain/ensemble/reweighting.py:777: UserWarning: Propagation function is not safe by default. Do not forget to use the wrapper around the compute function to ensure that the neighborlist does not overflow.
  warnings.warn(
/home/paul/miniconda3/envs/chemtrain_rerun_AD/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
Time for trajectory initialization 0: 2.6623266140619912 mins
[Step size] Use 7 iterations for 10 interior points.
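
To see what the learning rate schedule above does to the step size, the sketch below re-implements the continuous exponential decay formula, `init_value * decay_rate ** (step / transition_steps)`, in plain Python. It is an illustration only. It assumes that `optax.exponential_decay` is used with its default, non-staircase behavior and that the schedule advances once per optimizer update. The helper name `exponential_decay_sketch` is hypothetical and not part of optax or chemtrain.

```python
def exponential_decay_sketch(init_value, transition_steps, decay_rate):
    """Return a schedule mapping an update count to a learning rate.

    Assumed formula: init_value * decay_rate ** (step / transition_steps).
    """
    def schedule(step):
        return init_value * decay_rate ** (step / transition_steps)
    return schedule


# Same arguments as in the cell above:
# rm_post_initial_lr = 0.0005, rm_post_epochs = 50, decay factor 0.01.
lr_schedule_sketch = exponential_decay_sketch(0.0005, 50, 0.01)

# Learning rate at the start, the midpoint, and the end of the run.
lr_values = {step: lr_schedule_sketch(step) for step in (0, 25, 50)}
for step, lr in lr_values.items():
    print(f"update {step:2d}: learning rate = {lr:.2e}")
```

Under these assumptions, the learning rate falls from 5e-4 to 5e-5 after 25 updates and to 5e-6 after 50 updates, a factor of 100 over the run.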
if os.environ.get("FM_RM_TRAINING", "False").lower() == "true":
    # Save the training log
    with open("../_data/output/alanine_dipeptide_fm+rm_training.log", "w") as f:
        with contextlib.redirect_stdout(f), contextlib.redirect_stderr(f):

            start = time.time()
            rm_post_fm.train(rm_post_epochs)
            print(f"Total training time: {(time.time() - start) / 3600 : .1f} hours")
    
    rm_post_fm.save_energy_params("../_data/output/alanine_dipeptide_fm+rm_params.pkl", '.pkl')
    rm_post_fm.save_trainer("../_data/output/alanine_dipeptide_fm+rm_trainer.pkl", '.pkl')
    
rm_post_fm = onp.load("../_data/output/alanine_dipeptide_fm+rm_trainer.pkl", allow_pickle=True)
rm_post_fm_params = tree_util.tree_map(
    jnp.asarray, onp.load("../_data/output/alanine_dipeptide_fm+rm_params.pkl", allow_pickle=True)
)

with open("../_data/output/alanine_dipeptide_fm+rm_training.log") as f:
    print(f.read())


[Step Size] Found optimal step size 1.0 with residual 0.7550487518310547

[RE] Epoch 0
	Mean Delta RE loss = 36.00563
	Gradient norm: 124.37580108642578
	Elapsed time = 4.856 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2072954177856445

[RE] Epoch 1
	Mean Delta RE loss = 29.18939
	Gradient norm: 28.948904037475586
	Elapsed time = 3.113 min
[Statepoint 0]
	kT = 2.450 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1973085403442383

[RE] Epoch 2
	Mean Delta RE loss = 31.97235
	Gradient norm: 17.787336349487305
	Elapsed time = 3.107 min
[Statepoint 0]
	kT = 2.443 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.1749677658081055

[RE] Epoch 3
	Mean Delta RE loss = 32.02235
	Gradient norm: 33.27383041381836
	Elapsed time = 3.106 min
[Statepoint 0]
	kT = 2.435 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.222977638244629

[RE] Epoch 4
	Mean Delta RE loss = 30.61776
	Gradient norm: 29.2092227935791
	Elapsed time = 3.110 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2871942520141602

[RE] Epoch 5
	Mean Delta RE loss = 31.03158
	Gradient norm: 14.097911834716797
	Elapsed time = 3.107 min
[Statepoint 0]
	kT = 2.448 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2399444580078125

[RE] Epoch 6
	Mean Delta RE loss = 31.19930
	Gradient norm: 14.768195152282715
	Elapsed time = 3.114 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.2291173934936523

[RE] Epoch 7
	Mean Delta RE loss = 32.76921
	Gradient norm: 28.707584381103516
	Elapsed time = 3.110 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.286931037902832

[RE] Epoch 8
	Mean Delta RE loss = 30.51094
	Gradient norm: 22.228130340576172
	Elapsed time = 3.115 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3354434967041016

[RE] Epoch 9
	Mean Delta RE loss = 32.36247
	Gradient norm: 9.235512733459473
	Elapsed time = 3.109 min
[Statepoint 0]
	kT = 2.431 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3513851165771484

[RE] Epoch 10
	Mean Delta RE loss = 32.07355
	Gradient norm: 5.61235237121582
	Elapsed time = 3.120 min
[Statepoint 0]
	kT = 2.444 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3605690002441406

[RE] Epoch 11
	Mean Delta RE loss = 31.21972
	Gradient norm: 3.3407294750213623
	Elapsed time = 3.114 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3429327011108398

[RE] Epoch 12
	Mean Delta RE loss = 32.20423
	Gradient norm: 3.997459650039673
	Elapsed time = 3.107 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3401222229003906

[RE] Epoch 13
	Mean Delta RE loss = 32.43740
	Gradient norm: 7.247232437133789
	Elapsed time = 3.116 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3477020263671875

[RE] Epoch 14
	Mean Delta RE loss = 30.45119
	Gradient norm: 9.020105361938477
	Elapsed time = 3.114 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3727474212646484

[RE] Epoch 15
	Mean Delta RE loss = 32.33013
	Gradient norm: 1.9302358627319336
	Elapsed time = 3.108 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3629083633422852

[RE] Epoch 16
	Mean Delta RE loss = 32.00609
	Gradient norm: 4.440040588378906
	Elapsed time = 3.105 min
[Statepoint 0]
	kT = 2.447 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3645515441894531

[RE] Epoch 17
	Mean Delta RE loss = 31.72719
	Gradient norm: 5.785034656524658
	Elapsed time = 3.116 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3814029693603516

[RE] Epoch 18
	Mean Delta RE loss = 31.20028
	Gradient norm: 0.7455909848213196
	Elapsed time = 3.114 min
[Statepoint 0]
	kT = 2.434 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3774614334106445

[RE] Epoch 19
	Mean Delta RE loss = 30.88225
	Gradient norm: 1.4139515161514282
	Elapsed time = 3.110 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3779487609863281

[RE] Epoch 20
	Mean Delta RE loss = 31.14375
	Gradient norm: 1.2635220289230347
	Elapsed time = 3.108 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.373457908630371

[RE] Epoch 21
	Mean Delta RE loss = 31.90127
	Gradient norm: 4.567408561706543
	Elapsed time = 3.117 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.380157470703125

[RE] Epoch 22
	Mean Delta RE loss = 32.06521
	Gradient norm: 2.277247667312622
	Elapsed time = 3.114 min
[Statepoint 0]
	kT = 2.449 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3764352798461914

[RE] Epoch 23
	Mean Delta RE loss = 32.54579
	Gradient norm: 9.940178871154785
	Elapsed time = 3.106 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.383091926574707

[RE] Epoch 24
	Mean Delta RE loss = 31.94402
	Gradient norm: 1.2865967750549316
	Elapsed time = 3.113 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3852910995483398

[RE] Epoch 25
	Mean Delta RE loss = 31.36543
	Gradient norm: 0.4130084216594696
	Elapsed time = 3.116 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3854036331176758

[RE] Epoch 26
	Mean Delta RE loss = 31.31074
	Gradient norm: 0.33154529333114624
	Elapsed time = 3.105 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3821048736572266

[RE] Epoch 27
	Mean Delta RE loss = 31.12829
	Gradient norm: 2.488618850708008
	Elapsed time = 3.114 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3844594955444336

[RE] Epoch 28
	Mean Delta RE loss = 31.57828
	Gradient norm: 0.7623170018196106
	Elapsed time = 3.112 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3846654891967773

[RE] Epoch 29
	Mean Delta RE loss = 31.79181
	Gradient norm: 0.509951114654541
	Elapsed time = 3.108 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3838434219360352

[RE] Epoch 30
	Mean Delta RE loss = 32.07230
	Gradient norm: 2.2512776851654053
	Elapsed time = 3.111 min
[Statepoint 0]
	kT = 2.431 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.384939193725586

[RE] Epoch 31
	Mean Delta RE loss = 32.02459
	Gradient norm: 1.5240644216537476
	Elapsed time = 3.120 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.384537696838379

[RE] Epoch 32
	Mean Delta RE loss = 32.04109
	Gradient norm: 2.2417218685150146
	Elapsed time = 3.105 min
[Statepoint 0]
	kT = 2.448 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3852167129516602

[RE] Epoch 33
	Mean Delta RE loss = 32.19444
	Gradient norm: 2.984471321105957
	Elapsed time = 3.110 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3848495483398438

[RE] Epoch 34
	Mean Delta RE loss = 31.98781
	Gradient norm: 8.29981803894043
	Elapsed time = 3.116 min
[Statepoint 0]
	kT = 2.441 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3858633041381836

[RE] Epoch 35
	Mean Delta RE loss = 31.75151
	Gradient norm: 1.2233120203018188
	Elapsed time = 3.113 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3856868743896484

[RE] Epoch 36
	Mean Delta RE loss = 31.56382
	Gradient norm: 2.1066699028015137
	Elapsed time = 3.105 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3856697082519531

[RE] Epoch 37
	Mean Delta RE loss = 31.67507
	Gradient norm: 3.8059933185577393
	Elapsed time = 3.112 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3860588073730469

[RE] Epoch 38
	Mean Delta RE loss = 31.53700
	Gradient norm: 0.7111489176750183
	Elapsed time = 3.101 min
[Statepoint 0]
	kT = 2.447 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3856821060180664

[RE] Epoch 39
	Mean Delta RE loss = 31.39881
	Gradient norm: 4.373356342315674
	Elapsed time = 3.111 min
[Statepoint 0]
	kT = 2.438 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3858509063720703

[RE] Epoch 40
	Mean Delta RE loss = 31.57184
	Gradient norm: 2.320279836654663
	Elapsed time = 3.109 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3860750198364258

[RE] Epoch 41
	Mean Delta RE loss = 31.60853
	Gradient norm: 0.4560565948486328
	Elapsed time = 3.112 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3861942291259766

[RE] Epoch 42
	Mean Delta RE loss = 31.63672
	Gradient norm: 0.7079092860221863
	Elapsed time = 3.109 min
[Statepoint 0]
	kT = 2.440 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.386124610900879

[RE] Epoch 43
	Mean Delta RE loss = 31.65482
	Gradient norm: 0.44317826628685
	Elapsed time = 3.106 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.385946273803711

[RE] Epoch 44
	Mean Delta RE loss = 31.71917
	Gradient norm: 4.31718635559082
	Elapsed time = 3.106 min
[Statepoint 0]
	kT = 2.442 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3859977722167969

[RE] Epoch 45
	Mean Delta RE loss = 31.77694
	Gradient norm: 6.135014057159424
	Elapsed time = 3.115 min
[Statepoint 0]
	kT = 2.439 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3860893249511719

[RE] Epoch 46
	Mean Delta RE loss = 31.84545
	Gradient norm: 3.542604923248291
	Elapsed time = 3.111 min
[Statepoint 0]
	kT = 2.433 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3862218856811523

[RE] Epoch 47
	Mean Delta RE loss = 31.91901
	Gradient norm: 0.5241677761077881
	Elapsed time = 3.113 min
[Statepoint 0]
	kT = 2.437 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3861351013183594

[RE] Epoch 48
	Mean Delta RE loss = 31.86957
	Gradient norm: 1.8216795921325684
	Elapsed time = 3.111 min
[Statepoint 0]
	kT = 2.436 ref_kT = 2.494
[Step Size] Found optimal step size 1.0 with residual 1.3861684799194336

[RE] Epoch 49
	Mean Delta RE loss = 31.92449
	Gradient norm: 1.3624800443649292
	Elapsed time = 3.107 min
[Statepoint 0]
	kT = 2.445 ref_kT = 2.494
Total training time:  2.6 hours
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4), layout="constrained")

ax1.plot(rm_post_fm["delta_re"][0])
ax1.set_xticks(ticks=range(0, rm_post_epochs + 1, 10))
ax1.set_xlabel("Epoch")
ax1.set_ylabel("RE Loss")

ax2.plot(rm_post_fm["gradient_norm_history"])
ax2.set_xticks(ticks=range(0, rm_post_epochs + 1, 10))
ax2.set_xlabel("Epoch")
ax2.set_ylabel("Gradient Norm")
../_images/bbaf98fc71c2f8de06e7cd1b579097d6563a30062efefc66323b4c405a6615f6.svg

Evaluation#

With the trained models, we perform a series of simulations to evaluate how well each model agrees with the reference data. To reduce the total running time, we again run $100$ shorter simulations, which still correspond to a total time of $100~\text{ns}$.

eval_total_time = 2100.
eval_t_eq = 100.
eval_t_sample = .5

eval_timings = sampling.process_printouts(
    time_step=dt, total_time=eval_total_time,
    t_equilib=eval_t_eq, print_every=eval_t_sample
)

trajectory_generator = sampling.trajectory_generator_init(
    sim_template, energy_fn_template, eval_timings)
trajectory_generator = jax.vmap(
    functools.partial(trajectory_generator, **state_kwargs), (0, None)
)
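
The timing settings above determine how many frames each evaluation trajectory contributes. The following is a minimal sketch of that arithmetic, assuming that one frame is retained every `eval_t_sample` after the first `eval_t_eq` is discarded; `count_retained_frames` is a hypothetical helper and not part of chemtrain, and `sampling.process_printouts` may differ in details such as endpoint handling.

```python
def count_retained_frames(total_time, t_equilib, print_every):
    """Frames kept when sampling every `print_every` after `t_equilib`.

    Assumption: frames are spaced uniformly over the production part of
    the trajectory (total_time - t_equilib).
    """
    production_time = total_time - t_equilib
    return round(production_time / print_every)


# Values copied from the evaluation cell above
eval_total_time = 2100.
eval_t_eq = 100.
eval_t_sample = .5

frames_per_trajectory = count_retained_frames(
    eval_total_time, eval_t_eq, eval_t_sample)

print(f"Production time per trajectory: {eval_total_time - eval_t_eq:.0f}")
print(f"Retained frames per trajectory: {frames_per_trajectory}")
```

Multiplying the per-trajectory count by the number of parallel simulations gives the total number of samples that enter the histograms in the evaluation.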
if os.environ.get("EVALUATION", "False").lower() == "true":
 
    all_parameters = util.tree_stack((
        relative_entropy_params,
        force_matching_params,
        rm_post_fm_params
    ))
    
    t_start = time.time()
    all_traj_states = trajectory_generator(all_parameters, reference_state)
    
    assert not jnp.any(all_traj_states.overflow), (
        'Neighborlist overflow during trajectory generation. '
        'Increase capacity and re-run.'
    )
    
    print(f'Total runtime: {(time.time() - t_start) / 3600 :.1f} hours')
    
    (re_traj_state, fm_traj_state, fm_rm_traj_state) = util.tree_unstack(all_traj_states)
/home/paul/miniconda3/envs/chemtrain_rerun_AD/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
Total runtime: 4.4 hours
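
The cell above evaluates all parameter sets in one call by stacking them and mapping the trajectory generator over the stacked axis. The sketch below illustrates the pattern on a toy function. The helpers `tree_stack` and `tree_unstack` are written here from `jax.tree_util` primitives and are only assumed to behave like `util.tree_stack` and `util.tree_unstack`; `toy_simulation` is a hypothetical stand-in for the trajectory generator.

```python
import jax
import jax.numpy as jnp


def tree_stack(trees):
    """Stack matching leaves of several pytrees along a new leading axis."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)


def tree_unstack(tree):
    """Split the leading axis of every leaf back into a list of pytrees."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    n_trees = leaves[0].shape[0]
    return [
        jax.tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves])
        for i in range(n_trees)
    ]


def toy_simulation(params, state):
    """Hypothetical stand-in for a trajectory generator."""
    return {"final": params["w"] * state + params["b"]}


# Three parameter sets with identical structure (illustrative values)
params_a = {"w": jnp.array(1.0), "b": jnp.array(0.0)}
params_b = {"w": jnp.array(2.0), "b": jnp.array(1.0)}
params_c = {"w": jnp.array(3.0), "b": jnp.array(-1.0)}

stacked = tree_stack([params_a, params_b, params_c])
shared_state = jnp.arange(4.0)

# Map over the parameter axis (0) and share the initial state (None)
batched_simulation = jax.vmap(toy_simulation, in_axes=(0, None))
stacked_results = batched_simulation(stacked, shared_state)

result_a, result_b, result_c = tree_unstack(stacked_results)
```

Stacking only works when all parameter sets share the same tree structure and leaf shapes, which holds here because every model uses the same network architecture.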
def postprocess_fn(positions):
    # Compute the dihedral angles
    dihedral_idxs = jnp.array([[1, 3, 4, 6], [3, 4, 6, 8]])  # 0: phi    1: psi
    batched_dihedrals = jax.vmap(
        custom_quantity.dihedral_displacement, (0, None, None)
    )
    
    dihedral_angles = batched_dihedrals(positions, displacement_fn, dihedral_idxs)
    
    return dihedral_angles.T

if os.environ.get("EVALUATION", "False").lower() == "true":
    ref_phi, ref_psi = postprocess_fn(position_dataset)
    fm_phi, fm_psi = postprocess_fn(fm_traj_state.trajectory.position)
    re_phi, re_psi = postprocess_fn(re_traj_state.trajectory.position)
    fm_rm_phi, fm_rm_psi = postprocess_fn(fm_rm_traj_state.trajectory.position)
    
    onp.savez(
        "../_data/output/alanine_dipeptide_dihedral_angles.npz",
        ref_phi=ref_phi, ref_psi=ref_psi,
        fm_phi=fm_phi, fm_psi=fm_psi,
        re_phi=re_phi, re_psi=re_psi,
        fm_rm_phi=fm_rm_phi, fm_rm_psi=fm_rm_psi
    )
    
results = onp.load("../_data/output/alanine_dipeptide_dihedral_angles.npz")

ref_phi=results["ref_phi"]
ref_psi=results["ref_psi"]
fm_phi=results["fm_phi"]
fm_psi=results["fm_psi"]
re_phi=results["re_phi"]
re_psi=results["re_psi"]
fm_rm_phi=results["fm_rm_phi"]
fm_rm_psi=results["fm_rm_psi"]
def plot_1d_dihedral(ax, angles, labels, bins=60, degrees=True,
                     xlabel=r'$\phi$ in deg', ylabel=True):
    """Plot 1D histogram curves for a dihedral angle."""
    color = ['#368274', '#0C7CBA', '#C92D39', 'k']
    line = ['-', '-', '-', '--']
    
    n_models = len(angles)
    for i in range(n_models):
        if degrees:
            angles_conv = angles[i]
        else:
            # Convert radians to degrees so that all curves share one axis
            angles_conv = onp.rad2deg(angles[i])
        # The angles are in degrees at this point, so the range is too
        hist_range = [-180, 180]

        # Compute the histogram
        hist, x_bins = jnp.histogram(angles_conv, bins=bins, density=True, range=hist_range)
        width = x_bins[1] - x_bins[0]
        bin_center = x_bins + width / 2
        
        ax.plot(
            bin_center[:-1], hist, label=labels[i], color=color[i],
            linestyle=line[i], linewidth=2.0
        )

    ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel('Density')
    
    return ax

def plot_histogram_free_energy(ax, phi, psi, kbt, degrees=True, ylabel=False, title=""):
    """Plot 2D free energy histogram for alanine from the dihedral angles."""
    cmap = plt.get_cmap('viridis')

    if degrees:
        phi = jnp.deg2rad(phi)
        psi = jnp.deg2rad(psi)

    h, x_edges, y_edges = jnp.histogram2d(phi, psi, bins=60, density=True)

    # Free energy F = -kT ln p; dividing by 4.184 converts kJ/mol to kcal/mol
    h = jnp.log(h) * -(kbt / 4.184)
    x, y = onp.meshgrid(x_edges, y_edges)

    cax = ax.pcolormesh(x, y, h.T, cmap=cmap, vmax=5.25)
    ax.set_xlabel(r'$\phi$ [rad]')
    if ylabel:
        ax.set_ylabel(r'$\psi$ [rad]')
    ax.set_title(title)
    
    return ax, cax
    

Plotting the dihedral angle distributions reveals that both the FM-trained and the REM-trained models identify the preferred torsional states of alanine dipeptide. However, the REM model reproduces the relative populations of these states much more accurately.

labels = ["FM", "RM", "FM + RM", "AA Reference"]

fig, (ax1, ax2) = plt.subplots(1, 2, layout="constrained", figsize=(9, 3), sharey=True)
ax1 = plot_1d_dihedral(ax1, [fm_phi, re_phi, fm_rm_phi, ref_phi], labels, xlabel=r"$\phi\ [deg]$")
ax2 = plot_1d_dihedral(ax2, [fm_psi, re_psi, fm_rm_psi, ref_psi], labels, xlabel=r"$\psi\ [deg]$", ylabel=False)
fig.legend(labels, ncols=len(labels), bbox_to_anchor=(0.5, 1.01), loc="lower center")

fig.savefig("../_data/output/alanine_dipeptide_1D_dihedral_angles.pdf", bbox_inches='tight')
../_images/37728cb36d38210def2f99d71d31d3a8b9ff17a5761679dfb84bea935ace0609.svg
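
The visual comparison of the dihedral distributions can be complemented by a single number per model. The sketch below computes the Jensen-Shannon divergence between binned dihedral distributions. This metric is not part of the original example and is only one possible choice; the samples are synthetic von Mises draws that stand in for `ref_phi` and the model angles, and the helper names are hypothetical.

```python
import numpy as np


def dihedral_histogram(angles_deg, bins=60):
    """Normalized histogram (bin probabilities) of angles in degrees."""
    hist, _ = np.histogram(angles_deg, bins=bins, range=(-180.0, 180.0))
    return hist / hist.sum()


def js_divergence(p, q):
    """Jensen-Shannon divergence between two discrete distributions."""
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0  # empty bins contribute zero
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


# Synthetic stand-ins for the reference and two models (illustrative only)
rng = np.random.default_rng(0)
ref = np.rad2deg(rng.vonmises(mu=-1.4, kappa=8.0, size=20000))
model_close = np.rad2deg(rng.vonmises(mu=-1.4, kappa=8.0, size=20000))
model_shifted = np.rad2deg(rng.vonmises(mu=-0.5, kappa=8.0, size=20000))

p_ref = dihedral_histogram(ref)
js_close = js_divergence(p_ref, dihedral_histogram(model_close))
js_shifted = js_divergence(p_ref, dihedral_histogram(model_shifted))

print(f"JS(ref, close model)   = {js_close:.4f}")
print(f"JS(ref, shifted model) = {js_shifted:.4f}")
```

The divergence is zero for identical histograms and bounded by $\ln 2$, so smaller values indicate closer agreement with the reference.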

The free energy surface over the backbone dihedral angles shows a similar picture. Both models identify the regions of low free energy. However, only the REM model correctly predicts the depth of these regions.

labels = ["AA Reference"]

fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, layout="constrained", figsize=(9, 3), sharey=True)
ax1, _ = plot_histogram_free_energy(ax1, ref_phi, ref_psi, kT, ylabel=True, title="AA Reference")
ax2, _ = plot_histogram_free_energy(ax2, fm_phi, fm_psi, kT, title="Force Matching")
ax3, _ = plot_histogram_free_energy(ax3, re_phi, re_psi, kT, title="Relative Entropy")
ax4, cax = plot_histogram_free_energy(ax4, fm_rm_phi, fm_rm_psi, kT, title="FM and RM")

cbar = fig.colorbar(cax)
cbar.set_label('Free Energy (kcal/mol)')

fig.savefig("../_data/output/alanine_dipeptide_free_energy_dihedral_angles.pdf", bbox_inches='tight')
../_images/453e7224947234cfff36ac6ee4ddb4bb048b6e08fef0cad00decbba116490367.svg
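
The depth of a free energy basin follows directly from the histogram through $F = -k_B T \ln p$. The following sketch, using NumPy instead of the plotting helper above, estimates the free energy difference between two basins from synthetic samples. The basin positions and the 3:1 population ratio are illustrative assumptions, `free_energy_surface` is a hypothetical helper, and the value of `kbt` is the `ref_kT` reported in the training log.

```python
import numpy as np


def free_energy_surface(phi, psi, kbt, bins=60):
    """Free energy in kcal/mol from dihedral angles in radians.

    Empty bins yield +inf; the global minimum is shifted to zero.
    """
    lim = (-np.pi, np.pi)
    hist, _, _ = np.histogram2d(
        phi, psi, bins=bins, range=[lim, lim], density=True)
    with np.errstate(divide="ignore"):
        fe = -(kbt / 4.184) * np.log(hist)  # kJ/mol -> kcal/mol
    return fe - fe[np.isfinite(fe)].min()


kbt = 2.494  # kJ/mol, ref_kT from the training log
rng = np.random.default_rng(0)

# Two synthetic basins with a 3:1 population ratio (illustrative values)
n_major, n_minor = 150000, 50000
phi = np.concatenate([rng.vonmises(-1.5, 10.0, n_major),
                      rng.vonmises(1.0, 10.0, n_minor)])
psi = np.concatenate([rng.vonmises(1.0, 10.0, n_major),
                      rng.vonmises(-1.0, 10.0, n_minor)])

fe = free_energy_surface(phi, psi, kbt)

# The first histogram axis is phi: the first 30 bins cover phi < 0
depth_major = fe[:30].min()
depth_minor = fe[30:].min()
expected_gap = kbt / 4.184 * np.log(n_major / n_minor)

print(f"Estimated basin gap: {depth_minor - depth_major:.2f} kcal/mol")
print(f"Expected from populations: {expected_gap:.2f} kcal/mol")
```

Because both basins have the same width in this toy setup, the gap between their minima reflects the population ratio. A model that misjudges the relative populations therefore shows basins of the wrong depth.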

References#