
import os
from pathlib import Path

import jax
import jax.numpy as jnp
from jax import tree_util, random

import jax_md_mod
from jax_md import space, energy, partition, simulate

import optax

import matplotlib.pyplot as plt

from chemtrain.data import preprocessing
from chemtrain.trainers import ForceMatching, RelativeEntropy
from chemtrain import ensemble, quantity

base_path = Path(os.environ.get("DATA_PATH", "./data"))

Relative Entropy Minimization#

Principle of Relative Entropy#

Relative entropy provides a fundamental link between models of different scales [1]. Because it measures the information lost through coarse-graining [2], it is a natural objective to minimize.

For a coarse-grained model $p^\text{CG}_\theta(\mathbf R)$ on coarse-grained sites $\mathbf R$, connected to the sites of a fine-scale model $p^\text{AA}(\mathbf r)$ via a mapping $\mathbf R = M(\mathbf r)$, the relative entropy is [2]

$$ S_\text{rel} = S_\text{map} + \int p^\text{AA}(\mathbf r)\log \frac{p^\text{AA}(\mathbf r)}{p^\text{CG}_\theta(M(\mathbf r))}d\mathbf r. $$
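The relative entropy has the form of a Kullback-Leibler divergence. As a minimal illustration that leaves the mapping term aside, the sketch below evaluates $\sum_i p_i \log (p_i / q_i)$ for discrete distributions on the same states. The probabilities are hypothetical and only show that the value is zero for identical distributions and grows as the model departs from the reference.

```python
import math

def kl_divergence(p, q):
    """Return sum_i p_i * log(p_i / q_i) for two discrete distributions."""
    # Terms with p_i == 0 contribute nothing (0 * log 0 is taken as 0).
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

p_ref = [0.5, 0.3, 0.2]        # hypothetical reference probabilities
q_close = [0.45, 0.35, 0.20]   # hypothetical model close to the reference
q_far = [0.10, 0.10, 0.80]     # hypothetical model far from the reference

print(kl_divergence(p_ref, p_ref))    # identical distributions
print(kl_divergence(p_ref, q_close))
print(kl_divergence(p_ref, q_far))
```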

For a canonical ensemble $p(\mathbf r) \propto e^{-\beta U(\mathbf r)}$ at temperature $T = \frac{1}{k_B \beta}$, the relative entropy further decomposes into

$$ S_\text{rel} = S_\text{map} + \beta \left\langle U_\theta^\text{CG}(M(\mathbf r)) - U^\text{AA}(\mathbf r)\right\rangle_\text{AA} - \beta(A_\theta^\text{CG} - A^\text{AA}). $$

The first part $S_\text{map}$ measures the unavoidable loss of information due to the degeneracy of the mapping. However, this part is independent of the fine-grained and coarse-grained distributions.

The second part is the expected difference between the predicted potential energies $U_\theta^\text{CG}(M(\mathbf r)) - U^\text{AA}(\mathbf r)$ in the fine-scale ensemble. This part is simple to estimate: analogous to force matching, the estimation involves pre-computing an atomistic trajectory, followed by batched gradient-based optimization.

The last part is the free energy difference between the coarse-grained and fine-scale ensembles. Since the free energy normalizes a distribution,

$$ A_\theta = -\frac{1}{\beta}\log \int e^{-\beta U_\theta}dx, $$

it cannot be predicted directly from individual samples of the potential energy model. However, several routines exist to estimate the free energy difference $\Delta A_\theta^\text{CG} = A_\theta^\text{CG} - \tilde A^\text{CG}$ with respect to a reference potential $\tilde U^\text{CG}$.
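One well-known routine of this kind is free energy perturbation, which estimates $\Delta A = -\frac{1}{\beta}\log\left\langle e^{-\beta (U_\theta - \tilde U)}\right\rangle_{\tilde U}$ from samples of the reference ensemble. The sketch below applies it to a one-dimensional harmonic well, where the exact difference is known in closed form. It is a plain-Python illustration with hypothetical spring constants and is not meant to reflect how chemtrain estimates free energies internally.

```python
import math
import random

def harmonic(x, k):
    """One-dimensional harmonic potential centered at zero."""
    return 0.5 * k * x * x

def fep_delta_a(samples, k_new, k_ref, kT):
    """Estimate A(k_new) - A(k_ref) from samples of the reference ensemble."""
    beta = 1.0 / kT
    boltzmann = [math.exp(-beta * (harmonic(x, k_new) - harmonic(x, k_ref)))
                 for x in samples]
    return -kT * math.log(sum(boltzmann) / len(boltzmann))

kT, k_ref, k_new = 2.56, 1000.0, 1200.0   # illustrative values
rng = random.Random(0)
# The reference Boltzmann distribution is Gaussian with variance kT / k_ref.
samples = [rng.gauss(0.0, math.sqrt(kT / k_ref)) for _ in range(20000)]

estimate = fep_delta_a(samples, k_new, k_ref, kT)
exact = 0.5 * kT * math.log(k_new / k_ref)   # closed form for harmonic wells
print(estimate, exact)
```

The estimate only needs energies evaluated on reference samples, which is why differences are accessible even though the absolute free energy is not.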

Because only free energy differences are accessible, the exact computation of the relative entropy is infeasible. Nevertheless, we can collect all terms that depend directly on $\theta$ in a new objective

$$ \mathcal L_\text{RE}(\theta) = \beta\left(\left\langle U_\theta^\text{CG}(M(\mathbf r))\right\rangle_\text{AA} - \Delta A_\theta^\text{CG}\right). $$

This objective has precisely the same gradients as the relative entropy

$$ \frac{\partial}{\partial \theta} \mathcal L_\text{RE}(\theta) = \frac{\partial}{\partial \theta}S_\text{rel}. $$

Unfortunately, the objective is no longer bounded from below by $0$, the value the relative entropy attains under perfect preservation of information. Nevertheless, chemtrain can estimate all contributions to the loss, compute the correct gradients via algorithmic differentiation, and thereby enable training with the relative entropy objective.
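For canonical ensembles, the standard identity $\partial A_\theta / \partial\theta = \langle \partial U_\theta / \partial\theta \rangle_\theta$ reduces the gradient of the objective to a difference of two ensemble averages, $\beta\left(\langle \partial_\theta U_\theta^\text{CG}\rangle_\text{AA} - \langle \partial_\theta U_\theta^\text{CG}\rangle_\text{CG}\right)$. The sketch below evaluates this expression for a harmonic bond with log-transformed parameters, using analytic derivatives and hypothetical distance samples. It illustrates the formula only; chemtrain obtains its gradients through automatic differentiation.

```python
import math

def energy_gradient(theta, r):
    """Analytic derivatives of U = 0.5 * kb * (r - b0)**2 w.r.t. log-parameters."""
    b0, kb = math.exp(theta["log_b0"]), math.exp(theta["log_kb"])
    return {
        "log_b0": -kb * (r - b0) * b0,       # dU/db0 * db0/dlog_b0
        "log_kb": 0.5 * kb * (r - b0) ** 2,  # dU/dkb * dkb/dlog_kb
    }

def mean_gradient(theta, distances):
    """Average the energy gradient over a set of bond distances."""
    grads = [energy_gradient(theta, r) for r in distances]
    return {key: sum(g[key] for g in grads) / len(grads) for key in theta}

def re_gradient(theta, aa_distances, cg_distances, kT):
    """beta * (<dU/dtheta>_AA - <dU/dtheta>_CG)."""
    g_aa = mean_gradient(theta, aa_distances)
    g_cg = mean_gradient(theta, cg_distances)
    return {key: (g_aa[key] - g_cg[key]) / kT for key in theta}

theta = {"log_b0": math.log(0.11), "log_kb": math.log(1000.0)}
aa = [0.150, 0.155, 0.160]   # hypothetical reference distances in nm
cg = [0.105, 0.110, 0.115]   # hypothetical distances sampled from the CG model

grad = re_gradient(theta, aa, cg, kT=2.56)
print(grad)
```

Here the reference bond is longer than the model's equilibrium length, so the gradient with respect to `log_b0` is negative and a descent step lengthens the bond. When both ensembles coincide, the gradient vanishes.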

Load Data#

This example follows the Force Matching guide. Again, we use reference data from an all-atom simulation of ethane, which we obtained in the Prior Simulation example.

Ethane
train_ratio = 0.5

box = jnp.asarray([1.0, 1.0, 1.0])
kT = 2.56

all_forces = preprocessing.get_dataset(base_path / "forces_ethane.npy")
all_positions = preprocessing.get_dataset(base_path / "positions_ethane.npy")

Compute Mapping#

The reference data contains only fine-grained forces $\mathbf f_i$ and positions $\mathbf r_i$. Thus, we must define a mapping $M$ that derives the position $\mathbf R_I$ of each coarse-grained site from the atoms $\mathcal I_I$ assigned to it [3]

\[\mathbf R_I = \sum_{i \in \mathcal I_I} c_{Ii} \mathbf r_i.\]

We select the two carbon atoms $C_1$ and $C_2$ as locations of the coarse-grained sites $\mathcal I_1$ and $\mathcal I_2$ and neglect the hydrogen atoms.

# Heavy-atom mapping
displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=True)

# Scale the position data into fractional coordinates
position_dataset = preprocessing.scale_dataset_fractional(all_positions, box)

masses = jnp.asarray([15.035, 1.011, 1.011, 1.011])

weights = jnp.asarray([
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
])

position_dataset = preprocessing.map_dataset(
    position_dataset, displacement_fn, shift_fn, weights, 
)
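The mapping $\mathbf R_I = \sum_{i \in \mathcal I_I} c_{Ii} \mathbf r_i$ is a weighted sum over atom positions, so the heavy-atom weights simply pick out the two carbons. The plain-Python sketch below applies such a weight matrix to one hypothetical ethane frame. It ignores periodic boundaries, which the library call above handles through `displacement_fn` and `shift_fn`, and the coordinates and atom ordering are made up for illustration.

```python
def map_positions(weights, positions):
    """Apply R_I = sum_i c_Ii * r_i to one frame of atom positions."""
    return [
        [sum(c * atom[axis] for c, atom in zip(row, positions))
         for axis in range(3)]
        for row in weights
    ]

# Hypothetical ethane frame: two carbons followed by six hydrogens (nm).
frame = [
    [0.40, 0.50, 0.50], [0.55, 0.50, 0.50],
    [0.36, 0.60, 0.50], [0.36, 0.45, 0.59], [0.36, 0.45, 0.41],
    [0.59, 0.40, 0.50], [0.59, 0.55, 0.59], [0.59, 0.55, 0.41],
]
heavy_atom_weights = [
    [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
]

cg_sites = map_positions(heavy_atom_weights, frame)
print(cg_sites)
```

A center-of-mass mapping would use the same function with mass-weighted rows that each sum to one.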

Setup Model#

As a coarse-grained potential model, we choose a simple spring bond

\[ U(\mathbf R) = \frac{1}{2} k_b (|\mathbf R_1 - \mathbf R_2| - b_0)^2.\]

To ensure that the model parameters remain positive during optimization, we optimize their logarithms $\theta_1 = \log b_0,\ \theta_2= \log k_b$, which are unconstrained.
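The effect of this reparametrization can be checked with a small sketch: however large a gradient step on the log-parameters is, exponentiating them afterwards yields positive values. The gradient values and the learning rate below are hypothetical and chosen to be deliberately extreme.

```python
import math

def gradient_step(theta, grad, learning_rate):
    """Plain gradient descent on the unconstrained log-parameters."""
    return {key: theta[key] - learning_rate * grad[key] for key in theta}

theta = {"log_b0": math.log(0.11), "log_kb": math.log(1000.0)}
# Deliberately large, hypothetical gradient that would push a raw
# parameter far below zero in a single step.
grad = {"log_b0": 50.0, "log_kb": 50.0}

theta = gradient_step(theta, grad, learning_rate=0.5)
b0, kb = math.exp(theta["log_b0"]), math.exp(theta["log_kb"])
print(b0, kb)
```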

r_init = position_dataset[0, ...]

displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=True)
neighbor_fn = partition.neighbor_list(
    displacement_fn, box, 1.0, fractional_coordinates=True, disable_cell_list=True)

nbrs_init = neighbor_fn.allocate(r_init)

init_params = {
    "log_b0": jnp.log(0.11),
    "log_kb": jnp.log(1000.0)
}

def energy_fn_template(energy_params):
    harmonic_energy_fn = energy.simple_spring_bond(
        displacement_fn, bond=jnp.asarray([[0, 1]]),
        length=jnp.exp(energy_params["log_b0"]),
        epsilon=jnp.exp(energy_params["log_kb"]),
        alpha=2.0
    )
    
    return harmonic_energy_fn    

sample_idx = 0

print(f"Energy with initial params is {energy_fn_template(init_params)(position_dataset[sample_idx, ...], neighbor=nbrs_init)}")
Energy with initial params is 1.105982780456543

Analytical Solution#

As our model relies only on the magnitude of the displacement between $C_1$ and $C_2$, we compute this distance and plot it.

disp = jax.vmap(displacement_fn)(position_dataset[:, 0, :], position_dataset[:, 1, :])
dist_CC = jnp.sqrt(jnp.sum(disp ** 2, axis=-1))

plt.figure()
plt.hist(dist_CC, bins=100)
plt.xlabel("Distance C_1 - C_2 [nm]")
plt.ylabel("Count")
[Figure: histogram of the $C_1$ - $C_2$ distance in nm (x-axis) versus count (y-axis).]

Indeed, the distance between the two carbon atoms is approximately Gaussian distributed. Hence, the choice of a harmonic potential model is reasonable.

We can therefore estimate the model parameters from the mean and variance of the carbon-carbon distance.

$$ b_0 = \mathbb E[|\mathbf R_1 - \mathbf R_2|], \quad k_b = \frac{1}{\beta \operatorname{Var}[|\mathbf R_1 - \mathbf R_2|]} $$

# Analytical solution
b0 = jnp.mean(dist_CC)
kb = kT / jnp.var(dist_CC)

print(f"Estimated potential parameters are {kb :.1f} kJ/mol/nm^2 and {b0 :.3f} nm")
Estimated potential parameters are 9598.2 kJ/mol/nm^2 and 0.156 nm
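As a sanity check of these estimators, the sketch below draws distances from a Gaussian with known parameters and recovers them from the sample mean and variance. The ground-truth values are hypothetical, chosen close to the result above, and the sampling uses Python's standard library rather than a simulation.

```python
import math
import random
import statistics

kT = 2.56                             # same value as in this guide
b0_true, kb_true = 0.156, 9600.0      # hypothetical ground-truth parameters

rng = random.Random(0)
sigma = math.sqrt(kT / kb_true)       # Boltzmann width of a harmonic bond
distances = [rng.gauss(b0_true, sigma) for _ in range(20000)]

b0_est = statistics.fmean(distances)
kb_est = kT / statistics.pvariance(distances)
print(f"{kb_est:.1f} kJ/mol/nm^2 and {b0_est:.3f} nm")
```

The spring constant converges more slowly than the bond length because the variance estimate carries a relative error of roughly $\sqrt{2/N}$.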

Setup Optimizer#

epochs = 100
initial_lr = 0.5
lr_decay = 0.1

lrd = int(position_dataset.shape[0] / epochs)
lr_schedule = optax.exponential_decay(initial_lr, lrd, lr_decay)
optimizer = optax.chain(
    optax.scale_by_adam(),
    optax.scale_by_schedule(lr_schedule),
    # Flips the sign of the update for gradient descent
    optax.scale_by_learning_rate(1.0),
)
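Assuming `optax.exponential_decay` is used with its default continuous (non-staircase) behavior, the learning rate follows $\text{lr}(t) = \text{lr}_0 \cdot d^{\,t / t_\text{trans}}$. The sketch below reproduces this formula in plain Python. The value of `transition_steps` is hypothetical, since the guide derives it from the dataset size.

```python
def exponential_decay(init_value, transition_steps, decay_rate, step):
    """Continuous exponential decay: init * rate ** (step / transition_steps)."""
    return init_value * decay_rate ** (step / transition_steps)

initial_lr, lr_decay = 0.5, 0.1
transition_steps = 90   # hypothetical; the guide derives it from the dataset size

for step in (0, 45, 90, 180):
    print(step, exponential_decay(initial_lr, transition_steps, lr_decay, step))
```

With these settings the learning rate drops by a factor of ten every `transition_steps` updates.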

Setup Simulator#

timings = ensemble.sampling.process_printouts(
    time_step=0.002, total_time=1e3, t_equilib=1e2,
    print_every=0.1, t_start=0.0
)
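Assuming the arguments carry their literal meaning and share one time unit, the settings above imply the step and snapshot counts computed in the sketch below. This is back-of-the-envelope arithmetic, not a description of how `process_printouts` is implemented. The resulting snapshot count is consistent with the effective sample sizes of just under 9000 reported in the training log.

```python
time_step, total_time, t_equilib, print_every = 0.002, 1e3, 1e2, 0.1

# Total number of integration steps for the whole trajectory.
n_steps = round(total_time / time_step)
# Integration steps between two retained snapshots.
steps_per_printout = round(print_every / time_step)
# Snapshots retained after discarding the equilibration phase.
n_snapshots = round((total_time - t_equilib) / print_every)

print(n_steps, steps_per_printout, n_snapshots)
```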

init_ref_state, sim_template = ensemble.sampling.initialize_simulator_template(
    simulate.nvt_langevin, shift_fn=shift_fn, nbrs=nbrs_init,
    init_with_PRNGKey=True, extra_simulator_kwargs={"kT": kT, "gamma": 1.0, "dt": 0.002}
)

cg_masses = masses[0]

reference_state = init_ref_state(
    random.PRNGKey(11), r_init,
    energy_or_force_fn=energy_fn_template(init_params),
    init_sim_kwargs={"mass": cg_masses, "neighbor": nbrs_init}
)

Setup Relative Entropy Minimization#

relative_entropy = RelativeEntropy(
    init_params=init_params, optimizer=optimizer,
    reweight_ratio=1.1, sim_batch_size=1,
    energy_fn_template=energy_fn_template,
)

subsampled_dataset = position_dataset[::100, ...]
print(f"Dataset has shape {subsampled_dataset.shape}")

relative_entropy.add_statepoint(
    position_dataset, energy_fn_template,
    sim_template, neighbor_fn, timings,
    {'kT': kT}, reference_state,  
)

relative_entropy.init_step_size_adaption(0.1)
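The trainer log reports an effective sample size, which indicates how well a trajectory sampled with old parameters still represents the updated potential. The sketch below shows one common definition based on reweighting: weights $w_i \propto e^{-\beta (U_\text{new} - U_\text{old})}$ and $N_\text{eff} = \exp(-\sum_i w_i \log w_i)$. Whether chemtrain uses exactly this formula is an assumption, and the energies are hypothetical.

```python
import math

def reweighting_weights(u_new, u_old, kT):
    """Normalized weights proportional to exp(-(U_new - U_old) / kT)."""
    exponents = [-(un - uo) / kT for un, uo in zip(u_new, u_old)]
    shift = max(exponents)                      # subtract max for stability
    unnormalized = [math.exp(e - shift) for e in exponents]
    total = sum(unnormalized)
    return [w / total for w in unnormalized]

def effective_sample_size(weights):
    """Entropy-based estimate N_eff = exp(-sum_i w_i * log(w_i))."""
    return math.exp(-sum(w * math.log(w) for w in weights if w > 0.0))

u_old = [1.0, 1.2, 0.8, 1.1]            # hypothetical energies, old parameters
u_same = list(u_old)                    # unchanged potential
u_shifted = [1.0, 6.0, 0.8, 7.0]        # hypothetical energies after an update

ess_same = effective_sample_size(reweighting_weights(u_same, u_old, kT=2.56))
ess_shifted = effective_sample_size(reweighting_weights(u_shifted, u_old, kT=2.56))
print(ess_same, ess_shifted)
```

An unchanged potential keeps every sample equally useful, while a large parameter update concentrates the weight on a few samples and lowers the effective sample size.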

Dataset has shape (90, 2, 3)
No reference batch size provided. Using number of generated CG snapshots by default.
[Propagation] Time for trajectory compilation 0: 0.024921059608459473 mins
[Propagation] Time for trajectory simulation 0: 9.761253992716472e-06 mins
[Step size] Use 7 iterations for 10 interior points.
relative_entropy.train(epochs)


[Propagate] Effective sample size: 8999.9970703125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 1.6779112815856934

[RE] Epoch 0
	Mean Delta RE loss = 0.46043
	Gradient norm: 0.1396877020597458
	Elapsed time = 0.072 min

[Statepoint 0]

	kT = 2.572 ref_kT = 2.560

[Propagate] Effective sample size: 4818.92724609375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 1.1871552467346191

[RE] Epoch 1
	Mean Delta RE loss = 1.50640
	Gradient norm: 24.065814971923828
	Elapsed time = 0.065 min

[Statepoint 0]

	kT = 2.576 ref_kT = 2.560

[Propagate] Effective sample size: 2949.969482421875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.0589613914489746

[RE] Epoch 2
	Mean Delta RE loss = -0.07650
	Gradient norm: 3.244947910308838
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.543 ref_kT = 2.560

[Propagate] Effective sample size: 7054.046875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.101876735687256

[RE] Epoch 3
	Mean Delta RE loss = 0.10758
	Gradient norm: 31.531980514526367
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.582 ref_kT = 2.560

[Propagate] Effective sample size: 7363.3701171875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 1.8570141792297363

[RE] Epoch 4
	Mean Delta RE loss = -0.16296
	Gradient norm: 8.15678882598877
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.607 ref_kT = 2.560

[Propagate] Effective sample size: 5764.1298828125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.1130032539367676

[RE] Epoch 5
	Mean Delta RE loss = 0.15429
	Gradient norm: 1.7629905939102173
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.560 ref_kT = 2.560

[Propagate] Effective sample size: 7445.74951171875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.297715663909912

[RE] Epoch 6
	Mean Delta RE loss = 0.74985
	Gradient norm: 19.793004989624023
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.562 ref_kT = 2.560

[Propagate] Effective sample size: 8956.287109375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.1877455711364746

[RE] Epoch 7
	Mean Delta RE loss = 0.63629
	Gradient norm: 14.619794845581055
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.552 ref_kT = 2.560

[Propagate] Effective sample size: 8023.58740234375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.1572585105895996

[RE] Epoch 8
	Mean Delta RE loss = 0.16571
	Gradient norm: 2.045010805130005
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.566 ref_kT = 2.560

[Propagate] Effective sample size: 7782.6630859375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2352232933044434

[RE] Epoch 9
	Mean Delta RE loss = -0.14568
	Gradient norm: 1.4026190042495728
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.573 ref_kT = 2.560

[Propagate] Effective sample size: 8413.724609375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.300136089324951

[RE] Epoch 10
	Mean Delta RE loss = -0.19564
	Gradient norm: 9.348515510559082
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.446 ref_kT = 2.560

[Propagate] Effective sample size: 8977.982421875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2740225791931152

[RE] Epoch 11
	Mean Delta RE loss = -0.21241
	Gradient norm: 11.756546020507812
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.623 ref_kT = 2.560

[Propagate] Effective sample size: 8746.5703125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.209519863128662

[RE] Epoch 12
	Mean Delta RE loss = -0.28019
	Gradient norm: 6.527193546295166
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.544 ref_kT = 2.560

[Propagate] Effective sample size: 8200.2109375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.1981425285339355

[RE] Epoch 13
	Mean Delta RE loss = -0.26899
	Gradient norm: 0.3365727365016937
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.505 ref_kT = 2.560

[Propagate] Effective sample size: 8107.443359375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.280087947845459

[RE] Epoch 14
	Mean Delta RE loss = -0.08939
	Gradient norm: 5.170022964477539
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.519 ref_kT = 2.560

[Propagate] Effective sample size: 8799.783203125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.286281108856201

[RE] Epoch 15
	Mean Delta RE loss = 0.04222
	Gradient norm: 15.959290504455566
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.539 ref_kT = 2.560

[Propagate] Effective sample size: 8854.458984375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2135634422302246

[RE] Epoch 16
	Mean Delta RE loss = -0.14746
	Gradient norm: 9.21475887298584
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.564 ref_kT = 2.560

[Propagate] Effective sample size: 8233.4365234375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2236247062683105

[RE] Epoch 17
	Mean Delta RE loss = -0.44032
	Gradient norm: 0.1573340892791748
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.626 ref_kT = 2.560

[Propagate] Effective sample size: 8316.693359375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2903494834899902

[RE] Epoch 18
	Mean Delta RE loss = -0.56104
	Gradient norm: 6.832849502563477
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.615 ref_kT = 2.560

[Propagate] Effective sample size: 8890.5556640625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.281297206878662

[RE] Epoch 19
	Mean Delta RE loss = -0.58122
	Gradient norm: 15.568363189697266
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.479 ref_kT = 2.560

[Propagate] Effective sample size: 8810.439453125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.211243152618408

[RE] Epoch 20
	Mean Delta RE loss = -0.65149
	Gradient norm: 6.963631629943848
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.509 ref_kT = 2.560

[Propagate] Effective sample size: 8214.3544921875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.236067295074463

[RE] Epoch 21
	Mean Delta RE loss = -0.63619
	Gradient norm: 0.3030811846256256
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.513 ref_kT = 2.560

[Propagate] Effective sample size: 8420.8291015625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3010096549987793

[RE] Epoch 22
	Mean Delta RE loss = -0.49812
	Gradient norm: 14.629256248474121
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.577 ref_kT = 2.560

[Propagate] Effective sample size: 8985.837890625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2304205894470215

[RE] Epoch 23
	Mean Delta RE loss = -0.54915
	Gradient norm: 15.004228591918945
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.622 ref_kT = 2.560

[Propagate] Effective sample size: 8373.4052734375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2360873222351074

[RE] Epoch 24
	Mean Delta RE loss = -0.75234
	Gradient norm: 0.264541357755661
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.539 ref_kT = 2.560

[Propagate] Effective sample size: 8420.990234375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3005099296569824

[RE] Epoch 25
	Mean Delta RE loss = -0.80296
	Gradient norm: 9.566757202148438
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.467 ref_kT = 2.560

[Propagate] Effective sample size: 8981.33984375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2715439796447754

[RE] Epoch 26
	Mean Delta RE loss = -0.80435
	Gradient norm: 13.25450325012207
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.597 ref_kT = 2.560

[Propagate] Effective sample size: 8724.9345703125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2464728355407715

[RE] Epoch 27
	Mean Delta RE loss = -0.82305
	Gradient norm: 2.0361173152923584
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.495 ref_kT = 2.560

[Propagate] Effective sample size: 8508.91015625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2904820442199707

[RE] Epoch 28
	Mean Delta RE loss = -0.74118
	Gradient norm: 3.619258403778076
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.486 ref_kT = 2.560

[Propagate] Effective sample size: 8891.7421875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.292391300201416

[RE] Epoch 29
	Mean Delta RE loss = -0.66395
	Gradient norm: 12.157517433166504
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.465 ref_kT = 2.560

[Propagate] Effective sample size: 8908.7353515625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2628302574157715

[RE] Epoch 30
	Mean Delta RE loss = -0.73179
	Gradient norm: 4.115368843078613
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.549 ref_kT = 2.560

[Propagate] Effective sample size: 8649.23046875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2841668128967285

[RE] Epoch 31
	Mean Delta RE loss = -0.80763
	Gradient norm: 0.4202124774456024
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.592 ref_kT = 2.560

[Propagate] Effective sample size: 8835.7578125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.30234956741333

[RE] Epoch 32
	Mean Delta RE loss = -0.80804
	Gradient norm: 6.4892497062683105
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.502 ref_kT = 2.560

[Propagate] Effective sample size: 8997.8857421875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.28920316696167

[RE] Epoch 33
	Mean Delta RE loss = -0.79965
	Gradient norm: 4.524392127990723
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.642 ref_kT = 2.560

[Propagate] Effective sample size: 8880.3701171875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2836079597473145

[RE] Epoch 34
	Mean Delta RE loss = -0.77566
	Gradient norm: 0.448618620634079
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.626 ref_kT = 2.560

[Propagate] Effective sample size: 8830.8212890625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2975687980651855

[RE] Epoch 35
	Mean Delta RE loss = -0.71951
	Gradient norm: 0.9271399974822998
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.583 ref_kT = 2.560

[Propagate] Effective sample size: 8954.98828125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3018040657043457

[RE] Epoch 36
	Mean Delta RE loss = -0.66973
	Gradient norm: 3.8008437156677246
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.569 ref_kT = 2.560

[Propagate] Effective sample size: 8992.978515625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.294050693511963

[RE] Epoch 37
	Mean Delta RE loss = -0.67809
	Gradient norm: 2.2658591270446777
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.494 ref_kT = 2.560

[Propagate] Effective sample size: 8923.53125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.293487071990967

[RE] Epoch 38
	Mean Delta RE loss = -0.71508
	Gradient norm: 0.0824422836303711
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.549 ref_kT = 2.560

[Propagate] Effective sample size: 8918.5029296875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3007235527038574

[RE] Epoch 39
	Mean Delta RE loss = -0.73770
	Gradient norm: 0.8186063170433044
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.510 ref_kT = 2.560

[Propagate] Effective sample size: 8983.2666015625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302211284637451

[RE] Epoch 40
	Mean Delta RE loss = -0.74247
	Gradient norm: 1.900080919265747
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.493 ref_kT = 2.560

[Propagate] Effective sample size: 8996.6416015625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2984652519226074

[RE] Epoch 41
	Mean Delta RE loss = -0.73916
	Gradient norm: 1.287218689918518
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.536 ref_kT = 2.560

[Propagate] Effective sample size: 8963.0029296875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.297333240509033

[RE] Epoch 42
	Mean Delta RE loss = -0.72575
	Gradient norm: 0.11942722648382187
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.517 ref_kT = 2.560

[Propagate] Effective sample size: 8952.853515625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.300912380218506

[RE] Epoch 43
	Mean Delta RE loss = -0.70366
	Gradient norm: 0.30601540207862854
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.541 ref_kT = 2.560

[Propagate] Effective sample size: 8984.962890625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302485942840576

[RE] Epoch 44
	Mean Delta RE loss = -0.68898
	Gradient norm: 1.224875569343567
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.593 ref_kT = 2.560

[Propagate] Effective sample size: 8999.11328125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3006396293640137

[RE] Epoch 45
	Mean Delta RE loss = -0.69725
	Gradient norm: 0.8312013149261475
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.484 ref_kT = 2.560

[Propagate] Effective sample size: 8982.5048828125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.2995667457580566

[RE] Epoch 46
	Mean Delta RE loss = -0.72044
	Gradient norm: 0.16865994036197662
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.576 ref_kT = 2.560

[Propagate] Effective sample size: 8972.8896484375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3010964393615723

[RE] Epoch 47
	Mean Delta RE loss = -0.74343
	Gradient norm: 0.08705897629261017
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.589 ref_kT = 2.560

[Propagate] Effective sample size: 8986.6083984375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302527904510498

[RE] Epoch 48
	Mean Delta RE loss = -0.75747
	Gradient norm: 0.5409376621246338
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.511 ref_kT = 2.560

[Propagate] Effective sample size: 8999.4990234375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3019375801086426

[RE] Epoch 49
	Mean Delta RE loss = -0.76336
	Gradient norm: 0.6264514327049255
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.624 ref_kT = 2.560

[Propagate] Effective sample size: 8994.1708984375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3005709648132324

[RE] Epoch 50
	Mean Delta RE loss = -0.76436
	Gradient norm: 0.36257773637771606
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.575 ref_kT = 2.560

[Propagate] Effective sample size: 8981.9052734375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3010401725769043

[RE] Epoch 51
	Mean Delta RE loss = -0.75831
	Gradient norm: 0.0009316425421275198
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.547 ref_kT = 2.560

[Propagate] Effective sample size: 8986.1201171875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3023085594177246

[RE] Epoch 52
	Mean Delta RE loss = -0.74806
	Gradient norm: 0.27788594365119934
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.582 ref_kT = 2.560

[Propagate] Effective sample size: 8997.5166015625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3023486137390137

[RE] Epoch 53
	Mean Delta RE loss = -0.74472
	Gradient norm: 0.6351782083511353
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.624 ref_kT = 2.560

[Propagate] Effective sample size: 8997.876953125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.301757335662842

[RE] Epoch 54
	Mean Delta RE loss = -0.75437
	Gradient norm: 0.20222587883472443
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.475 ref_kT = 2.560

[Propagate] Effective sample size: 8992.5498046875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.301604747772217

[RE] Epoch 55
	Mean Delta RE loss = -0.76740
	Gradient norm: 0.034628815948963165
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.534 ref_kT = 2.560

[Propagate] Effective sample size: 8991.1865234375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302011013031006

[RE] Epoch 56
	Mean Delta RE loss = -0.77972
	Gradient norm: 0.017051266506314278
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.600 ref_kT = 2.560

[Propagate] Effective sample size: 8994.8486328125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302542209625244

[RE] Epoch 57
	Mean Delta RE loss = -0.78837
	Gradient norm: 0.22289767861366272
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.611 ref_kT = 2.560

[Propagate] Effective sample size: 8999.6103515625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302245616912842

[RE] Epoch 58
	Mean Delta RE loss = -0.79237
	Gradient norm: 0.4771481156349182
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.526 ref_kT = 2.560

[Propagate] Effective sample size: 8996.9501953125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3019328117370605

[RE] Epoch 59
	Mean Delta RE loss = -0.78890
	Gradient norm: 0.07408290356397629
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.569 ref_kT = 2.560

[Propagate] Effective sample size: 8994.1279296875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302039623260498

[RE] Epoch 60
	Mean Delta RE loss = -0.78131
	Gradient norm: 2.311469415872125e-06
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.604 ref_kT = 2.560

[Propagate] Effective sample size: 8995.09765625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302445888519287

[RE] Epoch 61
	Mean Delta RE loss = -0.77376
	Gradient norm: 0.07753054052591324
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.554 ref_kT = 2.560

[Propagate] Effective sample size: 8998.7607421875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025832176208496

[RE] Epoch 62
	Mean Delta RE loss = -0.77069
	Gradient norm: 0.08378595858812332
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.510 ref_kT = 2.560

[Propagate] Effective sample size: 8999.98828125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3024115562438965

[RE] Epoch 63
	Mean Delta RE loss = -0.77056
	Gradient norm: 0.18285134434700012
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.547 ref_kT = 2.560

[Propagate] Effective sample size: 8998.4521484375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3023200035095215

[RE] Epoch 64
	Mean Delta RE loss = -0.77591
	Gradient norm: 0.03559394180774689
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.559 ref_kT = 2.560

[Propagate] Effective sample size: 8997.6201171875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3021512031555176

[RE] Epoch 65
	Mean Delta RE loss = -0.78161
	Gradient norm: 0.029385928064584732
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.679 ref_kT = 2.560

[Propagate] Effective sample size: 8996.1005859375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302490711212158

[RE] Epoch 66
	Mean Delta RE loss = -0.78861
	Gradient norm: 0.07013164460659027
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.584 ref_kT = 2.560

[Propagate] Effective sample size: 8999.15625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302572727203369

[RE] Epoch 67
	Mean Delta RE loss = -0.79133
	Gradient norm: 0.16814275085926056
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.611 ref_kT = 2.560

[Propagate] Effective sample size: 8999.8935546875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3024392127990723

[RE] Epoch 68
	Mean Delta RE loss = -0.79015
	Gradient norm: 0.09761488437652588
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.613 ref_kT = 2.560

[Propagate] Effective sample size: 8998.701171875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302323818206787

[RE] Epoch 69
	Mean Delta RE loss = -0.78625
	Gradient norm: 0.03631487861275673
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.577 ref_kT = 2.560

[Propagate] Effective sample size: 8997.6455078125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302438259124756

[RE] Epoch 70
	Mean Delta RE loss = -0.78082
	Gradient norm: 0.006373928859829903
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.539 ref_kT = 2.560

[Propagate] Effective sample size: 8998.6923828125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025660514831543

[RE] Epoch 71
	Mean Delta RE loss = -0.77623
	Gradient norm: 0.05559137836098671
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.527 ref_kT = 2.560

[Propagate] Effective sample size: 8999.833984375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302567958831787

[RE] Epoch 72
	Mean Delta RE loss = -0.77413
	Gradient norm: 0.08891811966896057
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.556 ref_kT = 2.560

[Propagate] Effective sample size: 8999.8505859375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025078773498535

[RE] Epoch 73
	Mean Delta RE loss = -0.77464
	Gradient norm: 0.039795149117708206
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.564 ref_kT = 2.560

[Propagate] Effective sample size: 8999.310546875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3024744987487793

[RE] Epoch 74
	Mean Delta RE loss = -0.77667
	Gradient norm: 0.011786462739109993
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.526 ref_kT = 2.560

[Propagate] Effective sample size: 8999.0185546875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302520275115967

[RE] Epoch 75
	Mean Delta RE loss = -0.77903
	Gradient norm: 0.002934834687039256
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.579 ref_kT = 2.560

[Propagate] Effective sample size: 8999.421875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025736808776855

[RE] Epoch 76
	Mean Delta RE loss = -0.78070
	Gradient norm: 0.0233423113822937
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.537 ref_kT = 2.560

[Propagate] Effective sample size: 8999.9111328125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302579402923584

[RE] Epoch 77
	Mean Delta RE loss = -0.78077
	Gradient norm: 0.036697857081890106
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.642 ref_kT = 2.560

[Propagate] Effective sample size: 8999.962890625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025593757629395

[RE] Epoch 78
	Mean Delta RE loss = -0.77942
	Gradient norm: 0.019469670951366425
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.575 ref_kT = 2.560

[Propagate] Effective sample size: 8999.7734375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025240898132324

[RE] Epoch 79
	Mean Delta RE loss = -0.77665
	Gradient norm: 0.020259184762835503
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.507 ref_kT = 2.560

[Propagate] Effective sample size: 8999.46484375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302542209625244

[RE] Epoch 80
	Mean Delta RE loss = -0.77250
	Gradient norm: 0.0008714621653780341
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.519 ref_kT = 2.560

[Propagate] Effective sample size: 8999.619140625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302567958831787

[RE] Epoch 81
	Mean Delta RE loss = -0.76891
	Gradient norm: 0.017470210790634155
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.500 ref_kT = 2.560

[Propagate] Effective sample size: 8999.8505859375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025736808776855

[RE] Epoch 82
	Mean Delta RE loss = -0.76556
	Gradient norm: 0.0056398892775177956
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.563 ref_kT = 2.560

[Propagate] Effective sample size: 8999.90234375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.30256986618042

[RE] Epoch 83
	Mean Delta RE loss = -0.76232
	Gradient norm: 0.03979157283902168
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.612 ref_kT = 2.560

[Propagate] Effective sample size: 8999.876953125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302565097808838

[RE] Epoch 84
	Mean Delta RE loss = -0.76062
	Gradient norm: 0.003190131625160575
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.547 ref_kT = 2.560

[Propagate] Effective sample size: 8999.8251953125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302534580230713

[RE] Epoch 85
	Mean Delta RE loss = -0.75960
	Gradient norm: 0.02947920188307762
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.629 ref_kT = 2.560

[Propagate] Effective sample size: 8999.55078125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302542209625244

[RE] Epoch 86
	Mean Delta RE loss = -0.76051
	Gradient norm: 0.0016001646872609854
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.504 ref_kT = 2.560

[Propagate] Effective sample size: 8999.619140625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025641441345215

[RE] Epoch 87
	Mean Delta RE loss = -0.76206
	Gradient norm: 0.0023577141109853983
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.506 ref_kT = 2.560

[Propagate] Effective sample size: 8999.80859375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302577495574951

[RE] Epoch 88
	Mean Delta RE loss = -0.76288
	Gradient norm: 0.00481218658387661
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.635 ref_kT = 2.560

[Propagate] Effective sample size: 8999.9365234375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302579402923584

[RE] Epoch 89
	Mean Delta RE loss = -0.76307
	Gradient norm: 0.0581313893198967
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.588 ref_kT = 2.560

[Propagate] Effective sample size: 8999.9541015625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302560329437256

[RE] Epoch 90
	Mean Delta RE loss = -0.76180
	Gradient norm: 0.028342323377728462
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.574 ref_kT = 2.560

[Propagate] Effective sample size: 8999.7822265625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302550792694092

[RE] Epoch 91
	Mean Delta RE loss = -0.75976
	Gradient norm: 0.008895926177501678
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.528 ref_kT = 2.560

[Propagate] Effective sample size: 8999.6875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025622367858887

[RE] Epoch 92
	Mean Delta RE loss = -0.75747
	Gradient norm: 0.000587449932936579
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.570 ref_kT = 2.560

[Propagate] Effective sample size: 8999.7998046875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025641441345215

[RE] Epoch 93
	Mean Delta RE loss = -0.75547
	Gradient norm: 0.00055323401466012
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.547 ref_kT = 2.560

[Propagate] Effective sample size: 8999.81640625 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025622367858887

[RE] Epoch 94
	Mean Delta RE loss = -0.75405
	Gradient norm: 0.0034412797540426254
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.528 ref_kT = 2.560

[Propagate] Effective sample size: 8999.7998046875 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025832176208496

[RE] Epoch 95
	Mean Delta RE loss = -0.75261
	Gradient norm: 0.032640501856803894
	Elapsed time = 0.012 min

[Statepoint 0]

	kT = 2.603 ref_kT = 2.560

[Propagate] Effective sample size: 8999.98828125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302579402923584

[RE] Epoch 96
	Mean Delta RE loss = -0.75254
	Gradient norm: 0.037296783179044724
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.522 ref_kT = 2.560

[Propagate] Effective sample size: 8999.9453125 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.302567958831787

[RE] Epoch 97
	Mean Delta RE loss = -0.75400
	Gradient norm: 0.021971747279167175
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.534 ref_kT = 2.560

[Propagate] Effective sample size: 8999.8427734375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025660514831543

[RE] Epoch 98
	Mean Delta RE loss = -0.75582
	Gradient norm: 0.00416084099560976
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.496 ref_kT = 2.560

[Propagate] Effective sample size: 8999.833984375 (9900.0) -> Recompute is True
[Step Size] Found optimal step size 1.0 with residual 2.3025660514831543

[RE] Epoch 99
	Mean Delta RE loss = -0.75727
	Gradient norm: 0.0008216393180191517
	Elapsed time = 0.011 min

[Statepoint 0]

	kT = 2.534 ref_kT = 2.560

Results#

plt.figure()
plt.plot(relative_entropy.delta_re[0])
plt.xticks(ticks=range(0, epochs + 1, 25))
plt.xlabel("Epoch")
plt.ylabel("Loss")

plt.figure()
plt.plot(relative_entropy.gradient_norm_history)
plt.xticks(ticks=range(0, epochs + 1, 25))
plt.xlabel("Epoch")
plt.ylabel("Gradient Norm")
[Figures: loss vs. epoch and gradient norm vs. epoch]
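The per-epoch gradient norm in the log above fluctuates by more than an order of magnitude between neighboring epochs, which can make the trend hard to read from the raw curve. A minimal sketch of a moving average that could be applied before plotting is shown here. The helper `moving_average` and the window size are illustrative assumptions, not part of chemtrain, and the input values are rounded gradient norms copied from epochs 65 to 70 of the log.

```python
import jax.numpy as jnp


def moving_average(values, window):
    """Smooth a 1D sequence with a uniform window (hypothetical helper)."""
    values = jnp.asarray(values, dtype=jnp.float32)
    kernel = jnp.ones(window) / window
    # mode="valid" keeps only positions where the window fully overlaps
    # the data, so the result has len(values) - window + 1 entries.
    return jnp.convolve(values, kernel, mode="valid")


# Rounded gradient norms from epochs 65-70 of the training log.
grad_norms = [0.0294, 0.0701, 0.1681, 0.0976, 0.0363, 0.0064]
smoothed = moving_average(grad_norms, window=3)
print(smoothed)
```

In the tutorial, the same helper could be applied to `relative_entropy.gradient_norm_history` before calling `plt.plot`, with a window chosen to match the noise level.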

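Finally, we compare the bond parameters learned by relative entropy minimization with the reference values `b0` and `kb` obtained from the Gaussian fit. As the printed output shows, the two estimates differ by roughly 3% in the equilibrium bond length and 4% in the force constant.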

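# The parameters are stored as logarithms (log_b0, log_kb), so exponentiate
# them to recover the physical values. The dictionary keys keep their "log_" prefix.
pred_parameters = tree_util.tree_map(jnp.exp, relative_entropy.params)

# Absolute deviation from the Gaussian-fit reference values b0 and kb.
b0_err = jnp.abs(b0 - pred_parameters["log_b0"])
kb_err = jnp.abs(kb - pred_parameters["log_kb"])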

print(f"RE min. predicted {pred_parameters['log_b0']:.3f} nm and {pred_parameters['log_kb']:.1f} kJ/mol/nm^2")
print(f"Gaussian fit predicted {b0:.3f} nm and {kb:.1f} kJ/mol/nm^2")
print(f"Absolute error in b0 is {b0_err:.3f} nm and in kb is {kb_err:.1f} kJ/mol/nm^2")
RE min. predicted 0.152 nm and 9245.5 kJ/mol/nm^2
Gaussian fit predicted 0.156 nm and 9598.2 kJ/mol/nm^2
Absolute error in b0 is 0.004 nm and in kb is 352.7 kJ/mol/nm^2
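
To make the comparison concrete, here is a minimal sketch of how a Gaussian fit can yield harmonic bond parameters. It assumes a potential of the form $U(b) = \frac{k_b}{2}(b - b_0)^2$, for which the Boltzmann distribution of the bond length is Gaussian with mean $b_0$ and variance $k_B T / k_b$, and it ignores any Jacobian factor of the bond-length coordinate. The helper `gaussian_fit_harmonic`, the "true" parameters, and the synthetic samples are hypothetical; the fit used earlier in this tutorial may follow a different convention.

```python
import jax.numpy as jnp
from jax import random


def gaussian_fit_harmonic(bond_lengths, kT):
    """Estimate (b0, kb) from sampled bond lengths (hypothetical helper).

    Assumes U(b) = kb / 2 * (b - b0)**2, so that b is Gaussian with
    mean b0 and variance kT / kb.
    """
    b0 = jnp.mean(bond_lengths)
    kb = kT / jnp.var(bond_lengths)
    return b0, kb


# Hypothetical reference parameters; kT matches ref_kT from the log (kJ/mol).
true_b0, true_kb, kT = 0.15, 9000.0, 2.56

# Draw synthetic bond lengths from the corresponding Gaussian distribution.
key = random.PRNGKey(0)
sigma = jnp.sqrt(kT / true_kb)
samples = true_b0 + sigma * random.normal(key, (50_000,))

fit_b0, fit_kb = gaussian_fit_harmonic(samples, kT)
print(f"Fitted b0 = {fit_b0:.3f} nm, kb = {fit_kb:.1f} kJ/mol/nm^2")
```

Under these assumptions, relative entropy minimization of a single harmonic bond targets the same Gaussian distribution, so remaining differences between the two estimates would come from sampling noise and from how far the optimization has converged.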

Further Reading#

Examples#

Publications#

  1. Stephan Thaler, Maximilian Stupp, Julija Zavadlav; Deep coarse-grained potentials via relative entropy minimization. J. Chem. Phys. 28 December 2022; 157 (24): 244103. https://doi.org/10.1063/5.0124538

References#