Hide code cell content

from pathlib import Path
from urllib import request
import time
import os
import functools

# Configure JAX/XLA through environment variables. They are set before
# ``import jax`` below so that they take effect when the backend initializes.

# Keep a device selection made by the caller; otherwise use GPU 0 only.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', "0")

# Append to (rather than overwrite) any XLA flags the caller already set.
os.environ['XLA_FLAGS'] = ' '.join(
    os.environ.get('XLA_FLAGS', '').split()
    + ['--xla_gpu_triton_gemm_any=False']
)

# Fraction of GPU memory the JAX client may preallocate.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.95'

import zipfile

import numpy as onp

import jax
import optax
from jax import numpy as jnp, random, vmap, tree_util
from jax_md_mod import io, custom_quantity, custom_space, custom_energy
from jax_md import simulate, partition, space, util, energy, quantity as snapshot_quantity


from jax_md_mod.model import layers, neural_networks, prior

import mdtraj


import matplotlib.pyplot as plt

import haiku as hk
import chex
import copy
import contextlib

from chemtrain.data import preprocessing
from chemtrain.ensemble import sampling
from chemtrain import quantity, trainers, util as chem_util
2024-08-28 10:13:40.599139: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.5.82). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
# Render matplotlib figures inline in the notebook, as SVG.
%config InlineBackend.figure_formats = ['svg']
%matplotlib inline

Bottom-Up and Top-Down Training of Atomistic Titanium#

Problem#

This example reproduces the results of the paper "Accurate machine learning force fields via experimental and simulation data fusion" [1].

The paper introduces fused bottom-up and top-down learning. Titanium serves as a showcase: the bottom-up model is trained on DFT energies and forces, and the top-down model on experimental elastic constants and pressure. First, we train a neural network (NN) potential bottom-up via force matching (FM). Second, we train a fused bottom-up and top-down NN potential.

Loading Bottom-Up Reference Data#

def load_subset(data_dir, train_ratio=0.7, val_ratio=0.1):
    """Load one DFT dataset directory and split it into train/val/test subsets.

    The directory must contain ``box.npy``, ``coord.npy``, ``energy.npy``,
    ``force.npy``, ``virial.npy`` and ``types.npy``. The arrays are reshaped
    to a per-sample layout and split in their stored order (no shuffling),
    so that the splits match those used in the original paper.

    Args:
        data_dir: ``pathlib.Path`` of the dataset directory.
        train_ratio: Fraction of samples used for training.
        val_ratio: Fraction of samples used for validation; presumably the
            remainder forms the test set.

    Returns:
        The result of ``preprocessing.train_val_test_split`` applied to a dict
        with keys ``'box'`` (n, 3, 3), ``'R'`` (n, n_atoms, 3), ``'U'`` (n,),
        ``'type'`` (n,), ``'F'`` (n, n_atoms, 3) and ``'virial'`` (n, 3, 3).
        Callers unpack it as ``(train, val, test)``.
    """
    def _load(name):
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays.
        # In this notebook the files come from a downloaded archive, so only
        # use this loader with trusted data.
        return onp.load(data_dir / f'{name}.npy', allow_pickle=True)

    # Local names avoid shadowing the builtin ``type`` and the imported
    # ``energy`` module.
    boxes = _load('box')
    coords = _load('coord')
    energies = _load('energy')
    forces = _load('force')
    virials = _load('virial')
    types = _load('types')

    # We reshape the data to a standard format
    n_samples = boxes.shape[0]

    dataset = dict(
        # Transpose box tensor to conform to JAX-MD format
        box=onp.reshape(boxes, (n_samples, 3, 3)).swapaxes(1, 2),
        R=onp.reshape(coords, (n_samples, -1, 3)),
        U=onp.reshape(energies, (n_samples,)),
        type=onp.reshape(types, (n_samples,)),
        F=onp.reshape(forces, (n_samples, -1, 3)),
        virial=onp.reshape(virials, (n_samples, 3, 3)),
    )

    # Do not shuffle to use same splits as in the paper
    return preprocessing.train_val_test_split(
        dataset, train_ratio=train_ratio, val_ratio=val_ratio, shuffle=False)


def get_train_val_test_set(dir_files):
    """Load several dataset directories and merge their train/val/test splits.

    Each directory must contain ``box.npy``, ``coord.npy``, ``energy.npy``,
    ``force.npy``, ``virial.npy`` and ``types.npy``. Every directory is split
    separately by ``load_subset`` (by default 70 % training, 10 % validation
    and the remaining 20 % testing). The per-directory splits are then
    concatenated along the sample axis, so all directories must share the
    same number of atoms per structure.

    Args:
        dir_files: Iterable of ``pathlib.Path`` directories to load.

    Returns:
        A dict with keys ``'training'``, ``'validation'`` and ``'testing'``.
        Each maps to a dict of arrays: ``'box'`` (n, 3, 3),
        ``'R'`` (n, n_atoms, 3), ``'U'`` (n,), ``'F'`` (n, n_atoms, 3),
        ``'virial'`` (n, 3, 3) and ``'type'`` (n,).

    Raises:
        ValueError: If ``dir_files`` is empty.
    """
    dir_files = list(dir_files)
    if not dir_files:
        raise ValueError("dir_files must contain at least one directory.")

    stages = ('training', 'validation', 'testing')
    keys = ('box', 'R', 'U', 'F', 'virial', 'type')

    # Collect the per-directory arrays for every split and quantity.
    collected = {stage: {key: [] for key in keys} for stage in stages}
    for data_dir in dir_files:
        train_split, val_split, test_split = load_subset(data_dir)
        for stage, split in zip(stages, (train_split, val_split, test_split)):
            for key in keys:
                collected[stage][key].append(split[key])

    # Concatenate along the sample axis to single arrays.
    return {
        stage: {
            key: onp.concatenate(arrays, axis=0)
            for key, arrays in per_key.items()
        }
        for stage, per_key in collected.items()
    }


def scale_dataset(dataset, scale_U=1.0, scale_R=1.0, fractional=True):
    """Convert every split of a dataset to simulation units, in place.

    For each split, the positions, boxes, energies and forces are rescaled.
    The virials are divided by the box volume with inverted sign. Per-sample
    ``'virial_weights'`` are added so that only bulk samples (``type == 0``)
    contribute to the virial loss.

    Args:
        dataset: Mapping ``split -> dict`` with array entries ``'R'``,
            ``'box'``, ``'U'``, ``'F'``, ``'virial'`` and ``'type'`` (e.g. the
            output of ``get_train_val_test_set``). Modified in place.
        scale_U: Factor converting energies to the target energy unit.
        scale_R: Factor converting lengths to the target length unit.
        fractional: If ``True``, positions are converted to fractional
            coordinates using each sample's box. Otherwise, they are
            multiplied by ``scale_R``.

    Returns:
        The same ``dataset`` object.

    Raises:
        ValueError: If a split contains no bulk samples, so the virial
            weights cannot be normalized.
    """
    scale_F = scale_U / scale_R
    for split, data in dataset.items():
        if fractional:
            # The transform is built from the first box, but each snapshot is
            # scaled with its own box passed explicitly. Fractional
            # coordinates are unitless, so scale_R is not applied to them.
            _, scale_fn = custom_space.init_fractional_coordinates(data['box'][0])
            to_fractional = jax.vmap(
                lambda R, box: scale_fn(R, box=box), in_axes=(0, 0))
            data['R'] = to_fractional(data['R'], data['box'])
        else:
            data['R'] = data['R'] * scale_R

        data['box'] *= scale_R
        data['U'] *= scale_U
        data['F'] *= scale_F

        # Scale virials by the (already rescaled) box volume and invert the
        # sign, as in the chemtrain implementation.
        volumes = jax.vmap(snapshot_quantity.volume, (None, 0))(3, data['box'])
        data['virial'] *= -scale_U / volumes[:, None, None]

        # Only bulk virials (type == 0) should contribute to the loss.
        # Normalize the weights to mean one to correct for the reduction in
        # the number of contributing samples.
        is_bulk = 1 - data['type']
        bulk_fraction = onp.mean(is_bulk)
        if not bulk_fraction > 0:
            raise ValueError(
                f"Split '{split}' contains no bulk samples (type == 0); "
                "cannot normalize the virial weights.")
        virial_weights = is_bulk / bulk_fraction

        assert onp.all(virial_weights >= 0), "Virial weights should be positive."
        assert onp.isclose(onp.mean(virial_weights), 1.0)

        data['virial_weights'] = virial_weights

    return dataset

The following cell downloads the reference data published with the original paper [1].

Hide code cell content

# Load data from the link provided in the paper and unpack the nested
# archives. The download is skipped when either the archive or its extracted
# contents already exist. The outer archive is only re-extracted when the
# archive file is present, so the cell also works after the zip was deleted.

url = "https://github.com/tummfm/Fused-EXP-DFT-MLP/raw/main/Dataset/Data_DFT_and_Exp.zip"

data_dir = Path("../_data")
data_dir.mkdir(parents=True, exist_ok=True)

archive_path = data_dir / "TI_DFT_EXP.zip"
extract_dir = data_dir / "TI_DFT_EXP"

if not extract_dir.exists() and not archive_path.exists():
    # Download to a temporary name first, so that an interrupted transfer does
    # not leave a truncated archive that later runs would try to open.
    partial_path = archive_path.with_name(archive_path.name + ".part")
    request.urlretrieve(url, partial_path)
    partial_path.replace(archive_path)

if archive_path.exists():
    with zipfile.ZipFile(archive_path) as zip_f:
        zip_f.extractall(extract_dir)

# The DFT and experimental data are shipped as nested zip archives.
dft_path = extract_dir / "Data DFT and EXP" / "DFT_data" / "InitAndBulk_256atoms_curatedData.zip"
with zipfile.ZipFile(dft_path) as zip_f:
    zip_f.extractall(dft_path.parent / "InitAndBulk_256atoms_curatedData")

exp_path = extract_dir / "Data DFT and EXP" / "Exp_data" / "Exp_Boxes_AtomPositions" / "ExperimentalLattice_Boxes_AtomPositions.zip"
with zipfile.ZipFile(exp_path) as zip_f:
    zip_f.extractall(exp_path.parent / "ExperimentalLattice_Boxes_AtomPositions")

data_list = [dft_path.parent / "InitAndBulk_256atoms_curatedData" / "261022_AllInitAndBulk_256atoms_with_types_curatedData"]

predef_weights = onp.load(data_list[0] / 'types.npy')
predef_weights = jnp.array(predef_weights)

scale_energy = 96.4853722                  # [eV] ->   [kJ/mol]
scale_pos = 0.1                            # [Å] -> [nm]


dataset = get_train_val_test_set(data_list)
dataset = scale_dataset(dataset, scale_R=scale_pos, scale_U=scale_energy, fractional=True)

Model definition#

To learn the underlying potential energy surface, we employ the graph neural network architecture DimeNet++ [2].

print(f"Dataset format: {tree_util.tree_map(jnp.shape, dataset)}")

# Set up NN model
# The first training snapshot (fractional positions and its box) serves as the
# template for allocating the neighbor list.
r_init = jnp.asarray(dataset['training']['R'][0])
# NOTE(review): species_init is not used anywhere in the visible notebook.
species_init = jnp.ones_like(r_init[..., 0], dtype=int)
box_init = jnp.asarray(dataset['training']['box'][0])

# Size of the species embedding; only species 1 is used by
# energy_fn_template, so this is presumably just an upper bound — confirm.
n_species = 10
# Cutoff radius in nm (boxes were converted from Å to nm via scale_pos).
r_cut = 0.5

fractional = True
# Periodic space for a general (possibly triclinic) box with fractional
# coordinates.
displacement_fn, shift_fn = space.periodic_general(
    box_init, fractional_coordinates=fractional)


# Neighbor list without a cell list; capacity_multiplier adds headroom to the
# allocated neighbor slots.
neighbor_fn = partition.neighbor_list(
    displacement_fn, box_init, r_cut, disable_cell_list=True,
    fractional_coordinates=fractional, capacity_multiplier=2.5
)

nbrs_init = neighbor_fn.allocate(r_init, extra_capacity=3)
# Global PRNG key; it is split later to initialize the simulation states.
key = random.PRNGKey(21)
# Initializers for the MLP layers (zero biases, orthogonal variance-scaling
# weights).
mlp_init = {
    'b_init': hk.initializers.Constant(0.),
    'w_init': layers.OrthogonalVarianceScalingInit(scale=1.)
}

# Upper bounds on edges and angle triplets; presumably used to pad the graph
# to static shapes — confirm against dimenetpp_neighborlist.
max_edges, max_triplets = (12000, 470000)
# NOTE(review): init_fn is unused here because the parameters are loaded from
# a pretrained checkpoint below.
init_fn, gnn_energy_fn = neural_networks.dimenetpp_neighborlist(
    displacement_fn, r_cut, n_species, embed_size=32, init_kwargs=mlp_init,
    max_edges=max_edges, max_triplets=max_triplets
)

# Load a pretrained model
# NOTE(review): allow_pickle=True unpickles the file; only load trusted files.
init_params = onp.load("../_data/output/AT_TI_pretrained_bottom_up.pkl", allow_pickle=True)
init_params = tree_util.tree_map(jnp.asarray, init_params)

def energy_fn_template(energy_params):
    """Bind ``energy_params`` to the DimeNet++ model and return an energy function.

    Args:
        energy_params: Model parameters passed as the first argument of
            ``gnn_energy_fn``.

    Returns:
        ``energy_fn(pos, neighbor, **dynamic_kwargs)``, which evaluates
        ``gnn_energy_fn`` with every particle assigned species ``1``. The
        dynamic keyword arguments must include ``'box'`` and are forwarded
        unchanged.
    """

    def energy_fn(pos, neighbor, **dynamic_kwargs):
        assert 'box' in dynamic_kwargs, 'box not in dynamic_kwargs'

        # Only one particle type exists, so all particles share species 1.
        n_particles = pos.shape[0]
        return gnn_energy_fn(
            energy_params,
            pos,
            neighbor,
            species=jnp.ones(n_particles, dtype=int),
            **dynamic_kwargs,
        )

    return energy_fn
Dataset format: {'testing': {'F': (1142, 256, 3), 'R': (1142, 256, 3), 'U': (1142,), 'box': (1142, 3, 3), 'type': (1142,), 'virial': (1142, 3, 3), 'virial_weights': (1142,)}, 'training': {'F': (3992, 256, 3), 'R': (3992, 256, 3), 'U': (3992,), 'box': (3992, 3, 3), 'type': (3992,), 'virial': (3992, 3, 3), 'virial_weights': (3992,)}, 'validation': {'F': (570, 256, 3), 'R': (570, 256, 3), 'U': (570,), 'box': (570, 3, 3), 'type': (570,), 'virial': (570, 3, 3), 'virial_weights': (570,)}}

Setting up Bottom-Up Training#

# First define new bottom-up trainer for combined training
initial_lr_fm_1 = 5e-7
num_epochs_fused = 100
fm_batch_size = 30
fm_batch_cache = 8
fm_batch_per_device = fm_batch_size // jax.device_count()
if fm_batch_per_device < 1:
    raise ValueError(
        f"fm_batch_size={fm_batch_size} is smaller than the number of "
        f"devices ({jax.device_count()}); increase the batch size."
    )

# One optimizer step consumes one global batch, i.e. fm_batch_per_device
# samples on every device (assumed data-parallel behavior of ForceMatching —
# confirm). Dividing by the per-device batch alone would overestimate the
# number of steps on multi-device runs and slow down the decay.
fm_global_batch_size = fm_batch_per_device * jax.device_count()
num_transition_steps_fm_1 = (
    num_epochs_fused * dataset['training']['U'].size
) // fm_global_batch_size

# The learning rate reaches 1 % of its initial value after
# num_transition_steps_fm_1 optimizer steps.
lr_schedule_fm = optax.exponential_decay(initial_lr_fm_1, num_transition_steps_fm_1, 0.01)
optimizer_fm = optax.chain(
    optax.scale_by_adam(),
    optax.scale_by_learning_rate(lr_schedule_fm, flip_sign=True)
)

trainer_fm = trainers.ForceMatching(
    init_params, optimizer_fm, energy_fn_template, nbrs_init,
    batch_per_device=fm_batch_per_device,
    batch_cache=fm_batch_cache,
    # Loss weights of the individual targets.
    gammas={
        'virial': 4e-6, 'U': 1e-6, 'F': 1e-2
    },
    # The virial is not a built-in target; compute it from the energy model.
    additional_targets={
        'virial': custom_quantity.init_virial_stress_tensor(
            energy_fn_template, reference_box=None, include_kinetic=False)
    },
    # Per-sample weights from scale_dataset zero out non-bulk virials.
    weights_keys={
        'virial': 'virial_weights'
    }
)

trainer_fm.set_dataset(dataset['training'], stage='training')
trainer_fm.set_dataset(dataset['validation'], stage='validation', include_all=True)
trainer_fm.set_dataset(dataset['testing'], stage='testing', include_all=True)
test_predictions = trainer_fm.predict(dataset['testing'], batch_size=fm_batch_size)

test_set = dataset['testing']
# Atoms per structure, taken from the data instead of being hard-coded.
n_atoms = test_set['R'].shape[1]
# Only bulk structures (type == 0) have non-zero virial weights, i.e. only
# their virials enter the training loss. The virial MAE is restricted to them
# so that it matches the plotted points.
bulk = onp.asarray(test_set['type']) == 0

# Factors converting training units (kJ/mol, nm) back to eV and Å.
u_to_ev_per_atom = 1. / scale_energy / n_atoms
f_to_ev_per_ang = scale_pos / scale_energy
w_to_ev_per_ang3 = scale_pos ** 3 / scale_energy

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(11, 5), layout="constrained")

fig.suptitle("Predictions on Testset")

ref_u = onp.asarray(test_set['U']) * u_to_ev_per_atom
pred_u = onp.asarray(test_predictions['U']) * u_to_ev_per_atom
mae = onp.mean(onp.abs(pred_u - ref_u))
ax1.set_title(f"Energy (MAE: {mae*1000:.1f} meV/atom)")
ax1.plot(ref_u, pred_u, "*")
ax1.set_xlabel("Ref. U [eV/atom]")
ax1.set_ylabel("Pred. U [eV/atom]")

ref_f = onp.asarray(test_set['F']) * f_to_ev_per_ang
pred_f = onp.asarray(test_predictions['F']) * f_to_ev_per_ang
mae = onp.mean(onp.abs(pred_f - ref_f))
ax2.set_title(f"Force (MAE: {mae*1000:.1f} meV/A)")
# Plot only every 50th structure to keep the figure light.
ax2.plot(ref_f[::50].ravel(), pred_f[::50].ravel(), "*")
ax2.set_xlabel("Ref. F [eV/A]")
ax2.set_ylabel("Pred. F [eV/A]")

ref_w = onp.asarray(test_set['virial'])[bulk] * w_to_ev_per_ang3
pred_w = onp.asarray(test_predictions['virial'])[bulk] * w_to_ev_per_ang3
mae = onp.mean(onp.abs(pred_w - ref_w))
ax3.set_title(f"Virial (MAE: {mae*1000:.1f} meV/A^3)")
ax3.plot(ref_w.ravel(), pred_w.ravel(), "*")
ax3.set_xlabel("Ref. W [eV/A^3]")
ax3.set_ylabel("Pred. W [eV/A^3]")

fig.savefig("../_data/output/TI_fm_predictions.pdf", bbox_inches="tight")
../_images/33c93fc848611539f4fe944e4710113a8669f4219314fb058a0d8a9640ce5496.svg

Loading Top-Down Reference Data#

Load the top-down learning targets, i.e., experimental elastic constants at 323 and 923 K.

# Experimental elastic constants table. Column 0 holds the temperature and
# columns 2: hold the constants (see the parsing below). The first line is
# skipped (presumably a header).
data = onp.loadtxt(data_dir / "TI_DFT_EXP/Data DFT and EXP/Exp_data/Exp_EC/EC_From_ExpLatticeConstants.txt", skiprows=1)

def preprocess_constants(raw_data):
    """Convert raw experimental elastic constants to kJ/mol/nm^3.

    The values are multiplied by 100 to obtain GPa (assumes the file stores
    them in units of 100 GPa — confirm), rounded to one decimal in GPa, and
    converted to kJ/mol/nm^3. The last axis is then reordered.

    Args:
        raw_data: Array whose last axis holds the five elastic constants of
            one temperature; a 2-D table of such rows is also accepted.
            The input is not modified.

    Returns:
        A new array with the last axis reordered by the indices
        ``(0, 3, 4, 1, 2)``. Presumably the input order is C11, C12, C13,
        C33, C44, giving C11, C33, C44, C12, C13 (the order used when
        logging predictions) — confirm.
    """
    # Work on a new array: callers pass row views of the loaded table, which
    # an in-place ``*=`` would silently modify.
    constants_gpa = onp.asarray(raw_data) * 100  # to GPa
    constants_gpa = onp.round(constants_gpa, decimals=1)
    # GPa -> kJ/mol/nm^3 (1 kJ/mol/nm^3 = 1.66054e-3 GPa)
    elastic_constants = constants_gpa * (10 ** 3 / 1.66054)
    return elastic_constants[..., [0, 3, 4, 1, 2]]

# Preprocess the reference data
# Keys are the (float) temperatures from column 0. They are looked up below
# with integer temperatures, which works because 323.0 == 323 and both hash
# equally.
exp_data = {t: preprocess_constants(r) for t, r in zip(data[:, 0], data[:, 2:])}
# We select only a subset of the reference data
temps = onp.asarray([323, 923])
mass = 47.867  # mass of Ti atoms in u

base_path = data_dir / 'TI_DFT_EXP/Data DFT and EXP/Exp_data/Exp_Boxes_AtomPositions/ExperimentalLattice_Boxes_AtomPositions/ExperimentalLattice_Boxes_AtomPositions'

# Per-temperature initial positions (init_kwargs) and statepoint parameters
# (state_kwargs: thermal energy and box).
init_kwargs = {}
state_kwargs = {}
for t in temps:
    
    # Convert from A to nm
    box = jnp.array(onp.load(base_path / f'{t}K_expt_box.npy')) / 10.
    coords = jnp.array(onp.load(base_path / f'{t}K_expt_coordinates.npy')) / 10.
    
    # The experimental lattices are expected to contain 150 atoms.
    assert coords.shape == (150,3), (
        f"Shape of coordinates at temperatre {t} is {coords.shape}."
    )
    
    # Convert positions to fractional coordinates of this box.
    init_box, scale_fn = custom_space.init_fractional_coordinates(box)
    
    # Thermal energy kT = T * k_B in simulation energy units
    state_kwargs[t] = {
        'kT': jnp.asarray(t * quantity.kb), 'box': init_box
    }
    init_kwargs[t] = {
        'r_init': scale_fn(coords)
    }
    
    assert init_kwargs[t]['r_init'].shape == (150,3)

Setting up Simulation#

We run Langevin molecular dynamics to sample the NVT ensemble and collect approximately independent states for the top-down learning procedure.

# Setup simulation
# All times are in ps (dt = 2 fs).
gamma_sim = 100  # Langevin friction coefficient (presumably in 1/ps)
dt = 0.002
total_time_difftre = 20.
t_equilib_difftre = 5.  # discard all states within the first 5 ps as equilibration
print_every_difftre = 0.05  # save states every 0.05 ps for use in averaging

timings = sampling.process_printouts(
    dt, total_time_difftre, t_equilib_difftre, print_every_difftre
)

# All systems have the same number of particles, so we can use the first
# system as reference
box_init_difftre = state_kwargs[temps[0]]['box']
r_init_difftre = init_kwargs[temps[0]]['r_init']

# Pass the coordinate convention explicitly, so that the space and the
# neighbor list below always agree (as in the model definition).
displacement_fn, shift_fn = space.periodic_general(
    box_init_difftre, fractional_coordinates=fractional)
neighbor_fn_difftre = partition.neighbor_list(
    displacement_or_metric=displacement_fn, box=box_init_difftre, r_cutoff=r_cut, disable_cell_list=True,
    fractional_coordinates=fractional, capacity_multiplier=2.5
)

nbrs_init_difftre = neighbor_fn_difftre.allocate(r_init_difftre, extra_capacity=1)

# NVT Langevin simulator template; kT is left unset here and supplied per
# statepoint through state_kwargs.
init_ref_state, sim_template = sampling.initialize_simulator_template(
    simulate.nvt_langevin, shift_fn=shift_fn, nbrs=nbrs_init_difftre,
    init_with_PRNGKey=True, extra_simulator_kwargs={"dt": dt, "gamma": gamma_sim, "kT": None}
)

# We initialize a simulation state for all temperatures
reference_states = {}
for t in temps:
    # Split off a fresh PRNG key per statepoint; the order of temps therefore
    # determines which key each statepoint receives.
    key, split = random.split(key)
    
    # Simulator init arguments: particle mass, the template neighbor list and
    # the statepoint-specific 'kT' and 'box' from state_kwargs.
    init_sim_kwargs = {"mass": mass, "neighbor": nbrs_init_difftre}
    init_sim_kwargs.update(state_kwargs[t])
    # Neighbor-list allocation arguments. NOTE(review): this also forwards
    # 'kT' from state_kwargs; presumably only 'box' is used there — confirm.
    init_nbrs_kwargs = {"extra_capacity": 1}
    init_nbrs_kwargs.update(state_kwargs[t])
    
    # The template neighbor list was allocated for the first statepoint, so
    # every statepoint must have the same number of particles.
    assert init_kwargs[t]['r_init'].shape[0] == nbrs_init_difftre.idx.shape[0], (
        f"Missmatch in neighborlist shape."
    )
    
    # The initial state is computed with the pretrained parameters.
    reference_state = init_ref_state(
        split, init_kwargs[t]['r_init'], energy_or_force_fn=energy_fn_template(
        init_params), init_sim_kwargs=init_sim_kwargs,
        init_nbrs_kwargs=init_nbrs_kwargs
    )
    
    reference_states[t] = reference_state

Setting up Top-Down Targets#

The top-down targets are the experimental elastic constants of titanium at $323$ and $923\ \text{K}$ and a pressure of $1\ \text{bar}$.

# Add targets and compute functions
# Loss weights of the top-down targets.
gamma_elastic_contant = 1.e-10
gamma_pressure_scalar = 1.e-9

# We have to define different target values for 
# the statepoints, but can use the same observables
# and snapshot functions
targets = {
    t: {
        'elastic_constants': {
            'target': exp_data[t],
            'gamma': gamma_elastic_contant
        },
        'pressure': {
            # 1 bar in kJ/mol/nm^3 (1 kJ/mol/nm^3 = 1.66054e6 Pa).
            'target': jnp.asarray(1. / 16.6054),
            'gamma': gamma_pressure_scalar
        }
    } for t in temps
}

# Observables computed from the per-snapshot quantities below. The stiffness
# uses 3 * N degrees of freedom. kT is left unset here, presumably so that it
# can be taken from each statepoint — confirm.
observables = {
    'elastic_constants': quantity.observables.init_born_stiffness_tensor(
            reference_box=None, dof=3 * r_init_difftre.shape[0], kT=None,
            elastic_constant_function=quantity.observables.stiffness_tensor_components_hexagonal_crystal),
    'pressure': quantity.observables.init_traj_mean_fn('pressure')
}

# Per-snapshot quantities evaluated on the sampled states.
compute_fns = {
    'born_stiffness': custom_quantity.init_stiffness_tensor_stress_fluctuation(
        energy_fn_template, reference_box=None),
    'born_stress': custom_quantity.init_sigma_born(
        energy_fn_template, reference_box=None),
    'pressure': custom_quantity.init_pressure(
        energy_fn_template, reference_box=None)
}
def difftre_log_fn(trainer, *args, **kwargs):
    """Print the elastic constants predicted in the trainer's current epoch.

    For each statepoint key in ``trainer.predictions``, a header line is
    printed, followed by the first five entries of that statepoint's
    ``'elastic_constants'`` prediction for epoch ``trainer._epoch``. The
    entries are labelled C11, C33, C44, C12 and C13, and each value is divided
    by 1e3 and multiplied by 1.66054 before printing. This is presumably the
    kJ mol^-1 nm^-3 -> GPa conversion; confirm.

    Args:
        trainer: Object exposing ``predictions`` (statepoint -> per-epoch
            predictions) and the current epoch index ``_epoch``.
        *args, **kwargs: Ignored. Presumably accepted so the function can be
            registered as a task callback.
    """
    labels = ('C11', 'C33', 'C44', 'C12', 'C13')
    for statepoint, per_epoch in trainer.predictions.items():
        print(f"[Statepoint {statepoint}] Elastic constants:")
        for position, label in enumerate(labels):
            raw = per_epoch[trainer._epoch]['elastic_constants'][position]
            in_gpa = raw / 10 ** 3 * 1.66054
            print(f"\t{label}: {in_gpa:.2f} GPa", flush=True)

Setting up the DiffTRe and Interleaved Trainers

The DiffTRe trainer class implements the top-down learning procedure. The InterleaveTrainers class wraps the bottom-up and the top-down trainer and alternates between them, performing one update of each trainer per epoch.

# Setup optimizer for difftre
# NOTE(review): check_freq_difftre is not used in the visible part of this
# notebook — confirm whether it is still needed.
check_freq_difftre = 10.
initial_lr_difftre = 0.0002
# optax.exponential_decay(init_value, transition_steps, decay_rate): the
# learning rate decays by a factor of 0.1 over num_epochs_fused optimizer
# steps.
lr_schedule_difftre = optax.exponential_decay(initial_lr_difftre, num_epochs_fused, 0.1)
# Adam-style moment scaling followed by the scheduled learning rate.
# The positional arguments of scale_by_adam are b1=0.1 and b2=0.4, much
# smaller than the optax defaults (0.9 / 0.999). flip_sign=True negates the
# update so that the optimizer descends the loss.
optimizer_difftre = optax.chain(
    optax.scale_by_adam(0.1, 0.4),
    optax.scale_by_learning_rate(lr_schedule_difftre, flip_sign=True)
)

# Top-down (DiffTRe) trainer; it starts from the same initial parameters and
# energy function template as the rest of the notebook.
# NOTE(review): sim_batch_size=-1 presumably means "simulate all statepoints
# in one batch" — confirm against the chemtrain documentation.
trainer_difftre = trainers.Difftre(
    init_params, optimizer_difftre, energy_fn_template=energy_fn_template,
    checkpoint_path=Path('atomistic_titanium/trained_models/checkpoints'),
    sim_batch_size=-1)

# Add a separate statepoint for each temperature. We can re-use the
# snapshot and observable functions for all statepoints.
# Statepoints are presumably indexed 0, 1, ... in the order of `temps`; the
# training log prints "[Statepoint 0]" and "[Statepoint 1]".
for t in temps:
    trainer_difftre.add_statepoint(
        energy_fn_template, sim_template, neighbor_fn_difftre, timings,
        state_kwargs=state_kwargs[t], quantities=compute_fns,
        reference_state=reference_states[t],
        observables=observables, targets=targets[t])

# Log the predicted elastic constants after each epoch
trainer_difftre.add_task("post_epoch", difftre_log_fn)

# Wrapper that alternates between the bottom-up and top-down trainers.
trainer_fused = trainers.InterleaveTrainers(
    sequential=True, checkpoint_base_path='atomistic_titanium/trained_models/interleave_trainer',
    reference_energy_fn_template=energy_fn_template, full_checkpoint=False)

# Registration order defines the execution order within each epoch: one
# force/energy-matching update, then one DiffTRe update.
trainer_fused.add_trainer(trainer_fm, num_updates=1, name='Force and Energy Matching')
trainer_fused.add_trainer(trainer_difftre, num_updates=1, name='Difftre')

Hide code cell output

/home/paul/chemtrain-dev/chemtrain/ensemble/reweighting.py:777: UserWarning: Propagation function is not safe by default. Do not forget to use the wrapper around the compute function to ensure that the neighborlist does not overflow.
  warnings.warn(
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/reductions.py:213: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in sum is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/reductions.py:213: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in sum is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
Time for trajectory initialization 0: 1.0241082549095153 mins
/home/paul/chemtrain-dev/chemtrain/ensemble/reweighting.py:777: UserWarning: Propagation function is not safe by default. Do not forget to use the wrapper around the compute function to ensure that the neighborlist does not overflow.
  warnings.warn(
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:195: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
Time for trajectory initialization 1: 1.0049342354138693 mins
# Run the (hours-long) fused training only when FUSED_TRAINING=true is set.
# Otherwise reuse the trainer, parameters and log stored by a previous run.
fused_output_dir = Path("../_data/output")
fused_log_path = fused_output_dir / "Fused_AT_Ti_training.log"
fused_params_path = fused_output_dir / "Fused_AT_Ti_params.pkl"
fused_trainer_path = fused_output_dir / "Fused_AT_Ti_trainer.pkl"

if os.environ.get("FUSED_TRAINING", "False").lower() == "true":
    # Create the output directory on a fresh checkout; otherwise opening the
    # log file for writing fails with FileNotFoundError.
    fused_output_dir.mkdir(parents=True, exist_ok=True)
    with open(fused_log_path, "w", encoding="utf-8") as f:
        # Capture all training output (including warnings on stderr) in the log.
        with contextlib.redirect_stdout(f), contextlib.redirect_stderr(f):
            # perf_counter is monotonic, so the elapsed time is not affected
            # by system clock adjustments during the long run.
            start = time.perf_counter()
            trainer_fused.train(num_epochs_fused, checkpoint_frequency=10)
            print(f"Total training time: {(time.perf_counter() - start) / 3600 : .1f} hours")
    trainer_fused.save_energy_params(str(fused_params_path), '.pkl', best=False)
    trainer_fused.save_trainer(str(fused_trainer_path), '.pkl')

# NOTE: allow_pickle=True unpickles these files, which can execute arbitrary
# code. Only load trainer/parameter files that you produced yourself.
trainer_fused = onp.load(fused_trainer_path, allow_pickle=True)
trainer_fused_params = tree_util.tree_map(
    jnp.asarray, onp.load(fused_params_path, allow_pickle=True)
)

with open(fused_log_path, encoding="utf-8") as f:
    print(f.read())

Hide code cell output

---------Starting trainer Force and Energy Matching for 1 updates -----------
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
[Force] Found precomputed forces.
[Potential] Found precomputed forces.
[Virial] Found precomputed forces.
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
[Force] Found precomputed forces.
[Potential] Found precomputed forces.
[Virial] Found precomputed forces.
[Epoch 0]:
	Average train loss: 61.19277
	Average val loss: 73.03468322753906
	Gradient norm: 621970.8125
	Elapsed time = 3.114 min
	Per-target losses:
		F | train loss: 6108.517939746828 | val loss: 7291.6787109375
		U | train loss: 14591.104389391447 | val loss: 14301.1015625
		virial | train loss: 23250.070630066377 | val loss: 25897.12109375

---------Starting trainer Difftre for 1 updates -----------
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/reductions.py:221: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in sum is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric,
/home/paul/miniconda3/envs/chemtrain/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:166: UserWarning: Explicitly requested dtype float64 requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  return asarray(x, dtype=self.dtype)

[DiffTRe] Epoch 0
	Epoch loss = 0.00408
	Gradient norm: 1.7344677448272705
	Elapsed time = 10.034 min
[Statepoint 0]
	kT = 2.675 ref_kT = 2.686
	Predicted entropy: -0.016579577699303627
	Predicted free_energy: -38.480567932128906
	Predicted pressure: -834.55078125
[Statepoint 1]
	kT = 7.661 ref_kT = 7.674
	Predicted entropy: -0.010593105107545853
	Predicted free_energy: -27.34140396118164
	Predicted pressure: -816.9301147460938
[Statepoint 0] Elastic constants:
	C11: 154.08 GPa
	C33: 173.38 GPa
	C44: 35.42 GPa
	C12: 77.87 GPa
	C13: 69.42 GPa
[Statepoint 1] Elastic constants:
	C11: 124.34 GPa
	C33: 145.82 GPa
	C44: 27.52 GPa
	C12: 81.49 GPa
	C13: 67.84 GPa
Finished epoch 0 for all trainers in  13.52 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 1]:
	Average train loss: 127.02905
	Average val loss: 97.69428253173828
	Gradient norm: 105618040.0
	Elapsed time = 0.885 min
	Per-target losses:
		F | train loss: 11854.612088081532 | val loss: 9277.7587890625
		U | train loss: 2130390.2403665413 | val loss: 462353.03125
		virial | train loss: 1588135.6016212406 | val loss: 1113586.625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 1
	Epoch loss = 0.02460
	Gradient norm: 25.922208786010742
	Elapsed time = 3.040 min
[Statepoint 0]
	kT = 2.671 ref_kT = 2.686
	Predicted entropy: -0.27595677971839905
	Predicted free_energy: -858.64501953125
	Predicted pressure: 307.8705139160156
[Statepoint 1]
	kT = 7.683 ref_kT = 7.674
	Predicted entropy: -0.3169091045856476
	Predicted free_energy: -692.7993774414062
	Predicted pressure: 102.58610534667969
[Statepoint 0] Elastic constants:
	C11: 202.53 GPa
	C33: 223.46 GPa
	C44: 44.53 GPa
	C12: 104.88 GPa
	C13: 93.59 GPa
[Statepoint 1] Elastic constants:
	C11: 161.15 GPa
	C33: 188.71 GPa
	C44: 36.50 GPa
	C12: 104.89 GPa
	C13: 87.83 GPa
Finished epoch 1 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 2]:
	Average train loss: 92.51104
	Average val loss: 85.89185333251953
	Gradient norm: 53836580.0
	Elapsed time = 0.879 min
	Per-target losses:
		F | train loss: 8764.924657835996 | val loss: 8384.5224609375
		U | train loss: 3339023.897086466 | val loss: 1133660.25
		virial | train loss: 380692.9406132519 | val loss: 228242.03125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 2
	Epoch loss = 0.03065
	Gradient norm: 293.8670349121094
	Elapsed time = 3.008 min
[Statepoint 0]
	kT = 2.676 ref_kT = 2.686
	Predicted entropy: 0.21706552803516388
	Predicted free_energy: 882.3873291015625
	Predicted pressure: -29.523326873779297
[Statepoint 1]
	kT = 7.695 ref_kT = 7.674
	Predicted entropy: 0.2927267849445343
	Predicted free_energy: 676.7294921875
	Predicted pressure: 36.63535690307617
[Statepoint 0] Elastic constants:
	C11: 95.08 GPa
	C33: 157.70 GPa
	C44: 26.62 GPa
	C12: 125.11 GPa
	C13: 66.40 GPa
[Statepoint 1] Elastic constants:
	C11: 91.34 GPa
	C33: 135.90 GPa
	C44: 22.24 GPa
	C12: 109.21 GPa
	C13: 68.02 GPa
Finished epoch 2 for all trainers in  3.94 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 3]:
	Average train loss: 83.68477
	Average val loss: 85.26952362060547
	Gradient norm: 38489396.0
	Elapsed time = 0.882 min
	Per-target losses:
		F | train loss: 8219.547370623824 | val loss: 8447.7041015625
		U | train loss: 655649.969631109 | val loss: 183838.65625
		virial | train loss: 208412.88187265038 | val loss: 152160.890625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 3
	Epoch loss = 0.01602
	Gradient norm: 8.50614070892334
	Elapsed time = 3.048 min
[Statepoint 0]
	kT = 2.695 ref_kT = 2.686
	Predicted entropy: -0.3312065601348877
	Predicted free_energy: -699.5762329101562
	Predicted pressure: -859.5714721679688
[Statepoint 1]
	kT = 7.632 ref_kT = 7.674
	Predicted entropy: -0.06371491402387619
	Predicted free_energy: -500.88507080078125
	Predicted pressure: -639.8872680664062
[Statepoint 0] Elastic constants:
	C11: 208.75 GPa
	C33: 191.61 GPa
	C44: 48.33 GPa
	C12: 68.84 GPa
	C13: 75.79 GPa
[Statepoint 1] Elastic constants:
	C11: 158.44 GPa
	C33: 165.10 GPa
	C44: 38.16 GPa
	C12: 81.82 GPa
	C13: 73.55 GPa
Finished epoch 3 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 4]:
	Average train loss: 68.07038
	Average val loss: 77.6180419921875
	Gradient norm: 1912670.875
	Elapsed time = 0.881 min
	Per-target losses:
		F | train loss: 6726.067738927397 | val loss: 7701.01611328125
		U | train loss: 396989.29992951127 | val loss: 193573.71875
		virial | train loss: 103178.2598243656 | val loss: 103576.2890625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 4
	Epoch loss = 0.00585
	Gradient norm: 1.8047772645950317
	Elapsed time = 3.033 min
[Statepoint 0]
	kT = 2.694 ref_kT = 2.686
	Predicted entropy: -0.23218852281570435
	Predicted free_energy: -51.30207443237305
	Predicted pressure: -576.2041015625
[Statepoint 1]
	kT = 7.664 ref_kT = 7.674
	Predicted entropy: 0.05917840078473091
	Predicted free_energy: 63.035282135009766
	Predicted pressure: -583.1804809570312
[Statepoint 0] Elastic constants:
	C11: 180.84 GPa
	C33: 174.10 GPa
	C44: 43.42 GPa
	C12: 72.44 GPa
	C13: 66.75 GPa
[Statepoint 1] Elastic constants:
	C11: 144.24 GPa
	C33: 152.91 GPa
	C44: 33.85 GPa
	C12: 78.10 GPa
	C13: 66.45 GPa
Finished epoch 4 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 5]:
	Average train loss: 66.62759
	Average val loss: 76.0431900024414
	Gradient norm: 2443218.5
	Elapsed time = 0.879 min
	Per-target losses:
		F | train loss: 6508.945169319783 | val loss: 7504.4208984375
		U | train loss: 551689.5032894737 | val loss: 251759.453125
		virial | train loss: 246611.29469425516 | val loss: 186806.15625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 5
	Epoch loss = 0.00364
	Gradient norm: 3.6675589084625244
	Elapsed time = 3.032 min
[Statepoint 0]
	kT = 2.666 ref_kT = 2.686
	Predicted entropy: -0.14223842322826385
	Predicted free_energy: 149.51287841796875
	Predicted pressure: -245.0475616455078
[Statepoint 1]
	kT = 7.623 ref_kT = 7.674
	Predicted entropy: 0.13995888829231262
	Predicted free_energy: 190.74667358398438
	Predicted pressure: -161.8952178955078
[Statepoint 0] Elastic constants:
	C11: 162.34 GPa
	C33: 172.53 GPa
	C44: 38.18 GPa
	C12: 83.87 GPa
	C13: 68.02 GPa
[Statepoint 1] Elastic constants:
	C11: 106.14 GPa
	C33: 149.93 GPa
	C44: 30.05 GPa
	C12: 111.95 GPa
	C13: 67.78 GPa
Finished epoch 5 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 6]:
	Average train loss: 77.20200
	Average val loss: 80.23603820800781
	Gradient norm: 19579644.0
	Elapsed time = 0.884 min
	Per-target losses:
		F | train loss: 7547.265037593985 | val loss: 7899.97607421875
		U | train loss: 169750.3536624765 | val loss: 72568.4765625
		virial | train loss: 389901.2630991541 | val loss: 290926.0625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 6
	Epoch loss = 0.00665
	Gradient norm: 4.708769798278809
	Elapsed time = 3.046 min
[Statepoint 0]
	kT = 2.677 ref_kT = 2.686
	Predicted entropy: -0.2704162895679474
	Predicted free_energy: -346.7611389160156
	Predicted pressure: -183.85287475585938
[Statepoint 1]
	kT = 7.673 ref_kT = 7.674
	Predicted entropy: 0.008241275325417519
	Predicted free_energy: -215.4674530029297
	Predicted pressure: -224.67599487304688
[Statepoint 0] Elastic constants:
	C11: 188.77 GPa
	C33: 189.91 GPa
	C44: 44.32 GPa
	C12: 79.75 GPa
	C13: 76.05 GPa
[Statepoint 1] Elastic constants:
	C11: 151.10 GPa
	C33: 164.25 GPa
	C44: 35.31 GPa
	C12: 85.15 GPa
	C13: 74.27 GPa
Finished epoch 6 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 7]:
	Average train loss: 69.41270
	Average val loss: 77.03765106201172
	Gradient norm: 3331178.25
	Elapsed time = 0.881 min
	Per-target losses:
		F | train loss: 6680.386795847039 | val loss: 7578.48486328125
		U | train loss: 1711282.9478383458 | val loss: 582385.8125
		virial | train loss: 224386.83734727444 | val loss: 167605.046875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 7
	Epoch loss = 0.01757
	Gradient norm: 67.33468627929688
	Elapsed time = 3.019 min
[Statepoint 0]
	kT = 2.671 ref_kT = 2.686
	Predicted entropy: 0.03036479279398918
	Predicted free_energy: 486.78155517578125
	Predicted pressure: -148.44235229492188
[Statepoint 1]
	kT = 7.642 ref_kT = 7.674
	Predicted entropy: 0.2542579770088196
	Predicted free_energy: 444.1061096191406
	Predicted pressure: 162.9320068359375
[Statepoint 0] Elastic constants:
	C11: 112.43 GPa
	C33: 166.50 GPa
	C44: 33.84 GPa
	C12: 118.93 GPa
	C13: 67.00 GPa
[Statepoint 1] Elastic constants:
	C11: 96.71 GPa
	C33: 142.36 GPa
	C44: 26.59 GPa
	C12: 107.76 GPa
	C13: 66.38 GPa
Finished epoch 7 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 8]:
	Average train loss: 96.03573
	Average val loss: 90.90355682373047
	Gradient norm: 23380672.0
	Elapsed time = 0.884 min
	Per-target losses:
		F | train loss: 9421.100633664239 | val loss: 8986.5390625
		U | train loss: 593758.8831355734 | val loss: 149849.734375
		virial | train loss: 307741.56251468515 | val loss: 222078.53125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 8
	Epoch loss = 0.01403
	Gradient norm: 13.81141185760498
	Elapsed time = 3.040 min
[Statepoint 0]
	kT = 2.667 ref_kT = 2.686
	Predicted entropy: -0.3896113634109497
	Predicted free_energy: -714.103759765625
	Predicted pressure: -632.54443359375
[Statepoint 1]
	kT = 7.667 ref_kT = 7.674
	Predicted entropy: -0.17299196124076843
	Predicted free_energy: -472.1962585449219
	Predicted pressure: -531.84912109375
[Statepoint 0] Elastic constants:
	C11: 201.18 GPa
	C33: 200.18 GPa
	C44: 50.94 GPa
	C12: 74.18 GPa
	C13: 75.69 GPa
[Statepoint 1] Elastic constants:
	C11: 157.37 GPa
	C33: 170.45 GPa
	C44: 40.44 GPa
	C12: 80.71 GPa
	C13: 72.62 GPa
Finished epoch 8 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 9]:
	Average train loss: 68.81170
	Average val loss: 77.39753723144531
	Gradient norm: 5337505.5
	Elapsed time = 0.886 min
	Per-target losses:
		F | train loss: 6681.771095218515 | val loss: 7631.06201171875
		U | train loss: 1356816.8411654136 | val loss: 517985.25
		virial | train loss: 159294.05738956766 | val loss: 142232.515625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 9
	Epoch loss = 0.00334
	Gradient norm: 0.7888472080230713
	Elapsed time = 3.037 min
[Statepoint 0]
	kT = 2.694 ref_kT = 2.686
	Predicted entropy: -0.07724928110837936
	Predicted free_energy: 244.31846618652344
	Predicted pressure: -35.30990219116211
[Statepoint 1]
	kT = 7.662 ref_kT = 7.674
	Predicted entropy: 0.14242160320281982
	Predicted free_energy: 292.0471496582031
	Predicted pressure: 61.364253997802734
[Statepoint 0] Elastic constants:
	C11: 164.60 GPa
	C33: 172.25 GPa
	C44: 39.72 GPa
	C12: 72.35 GPa
	C13: 66.71 GPa
[Statepoint 1] Elastic constants:
	C11: 127.59 GPa
	C33: 146.40 GPa
	C44: 31.18 GPa
	C12: 83.18 GPa
	C13: 65.97 GPa
Finished epoch 9 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 10]:
	Average train loss: 66.99371
	Average val loss: 76.66167449951172
	Gradient norm: 4592459.5
	Elapsed time = 0.881 min
	Per-target losses:
		F | train loss: 6568.639754904841 | val loss: 7567.62890625
		U | train loss: 241937.91312265038 | val loss: 155759.609375
		virial | train loss: 266343.74165883457 | val loss: 207405.375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 10
	Epoch loss = 0.00214
	Gradient norm: 2.3169281482696533
	Elapsed time = 3.019 min
[Statepoint 0]
	kT = 2.687 ref_kT = 2.686
	Predicted entropy: -0.06885744631290436
	Predicted free_energy: 18.79383087158203
	Predicted pressure: 153.2273406982422
[Statepoint 1]
	kT = 7.640 ref_kT = 7.674
	Predicted entropy: 0.1307915896177292
	Predicted free_energy: 64.47624206542969
	Predicted pressure: 28.258939743041992
[Statepoint 0] Elastic constants:
	C11: 154.56 GPa
	C33: 185.20 GPa
	C44: 37.04 GPa
	C12: 97.90 GPa
	C13: 75.84 GPa
[Statepoint 1] Elastic constants:
	C11: 114.64 GPa
	C33: 159.59 GPa
	C44: 30.43 GPa
	C12: 108.62 GPa
	C13: 73.30 GPa
Finished epoch 10 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 11]:
	Average train loss: 77.68928
	Average val loss: 82.61542510986328
	Gradient norm: 14598580.0
	Elapsed time = 0.886 min
	Per-target losses:
		F | train loss: 7653.509082765508 | val loss: 8194.0380859375
		U | train loss: 451726.97071781015 | val loss: 101334.390625
		virial | train loss: 175614.96518885103 | val loss: 143427.40625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 11
	Epoch loss = 0.00347
	Gradient norm: 4.153076648712158
	Elapsed time = 3.032 min
[Statepoint 0]
	kT = 2.694 ref_kT = 2.686
	Predicted entropy: -0.22466683387756348
	Predicted free_energy: -579.7315063476562
	Predicted pressure: -589.1024169921875
[Statepoint 1]
	kT = 7.693 ref_kT = 7.674
	Predicted entropy: -0.05698510631918907
	Predicted free_energy: -407.595947265625
	Predicted pressure: -389.77838134765625
[Statepoint 0] Elastic constants:
	C11: 177.47 GPa
	C33: 187.22 GPa
	C44: 45.11 GPa
	C12: 85.34 GPa
	C13: 72.48 GPa
[Statepoint 1] Elastic constants:
	C11: 146.23 GPa
	C33: 163.21 GPa
	C44: 36.63 GPa
	C12: 85.77 GPa
	C13: 70.53 GPa
Finished epoch 11 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 12]:
	Average train loss: 77.46583
	Average val loss: 80.5436019897461
	Gradient norm: 13973154.0
	Elapsed time = 0.883 min
	Per-target losses:
		F | train loss: 7531.878528107378 | val loss: 7943.48046875
		U | train loss: 1324754.6038533836 | val loss: 452311.4375
		virial | train loss: 205572.57979910713 | val loss: 164121.84375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 12
	Epoch loss = 0.00736
	Gradient norm: 4.82000732421875
	Elapsed time = 3.023 min
[Statepoint 0]
	kT = 2.668 ref_kT = 2.686
	Predicted entropy: 0.1060604527592659
	Predicted free_energy: 423.52392578125
	Predicted pressure: -286.0054626464844
[Statepoint 1]
	kT = 7.704 ref_kT = 7.674
	Predicted entropy: 0.2785276472568512
	Predicted free_energy: 361.3257751464844
	Predicted pressure: -221.04336547851562
[Statepoint 0] Elastic constants:
	C11: 139.86 GPa
	C33: 160.77 GPa
	C44: 33.49 GPa
	C12: 84.58 GPa
	C13: 65.48 GPa
[Statepoint 1] Elastic constants:
	C11: 105.03 GPa
	C33: 140.14 GPa
	C44: 27.31 GPa
	C12: 98.33 GPa
	C13: 64.92 GPa
Finished epoch 12 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 13]:
	Average train loss: 82.07360
	Average val loss: 83.55987548828125
	Gradient norm: 20259440.0
	Elapsed time = 0.884 min
	Per-target losses:
		F | train loss: 7952.679860050517 | val loss: 8171.85400390625
		U | train loss: 231863.4720982143 | val loss: 78868.671875
		virial | train loss: 578734.2640390038 | val loss: 440616.1875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 13
	Epoch loss = 0.00352
	Gradient norm: 1.4518791437149048
	Elapsed time = 3.044 min
[Statepoint 0]
	kT = 2.663 ref_kT = 2.686
	Predicted entropy: -0.11992435157299042
	Predicted free_energy: -406.4053955078125
	Predicted pressure: -19.296558380126953
[Statepoint 1]
	kT = 7.618 ref_kT = 7.674
	Predicted entropy: -0.045557599514722824
	Predicted free_energy: -298.8771667480469
	Predicted pressure: -164.86697387695312
[Statepoint 0] Elastic constants:
	C11: 175.41 GPa
	C33: 195.85 GPa
	C44: 41.84 GPa
	C12: 96.05 GPa
	C13: 79.47 GPa
[Statepoint 1] Elastic constants:
	C11: 131.74 GPa
	C33: 170.08 GPa
	C44: 34.38 GPa
	C12: 110.00 GPa
	C13: 77.38 GPa
Finished epoch 13 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 14]:
	Average train loss: 70.92974
	Average val loss: 79.16064453125
	Gradient norm: 6290501.5
	Elapsed time = 0.885 min
	Per-target losses:
		F | train loss: 6967.983376409775 | val loss: 7820.1796875
		U | train loss: 442274.05486372183 | val loss: 199527.3125
		virial | train loss: 201909.00676251174 | val loss: 189830.265625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 14
	Epoch loss = 0.00343
	Gradient norm: 0.5437045693397522
	Elapsed time = 3.039 min
[Statepoint 0]
	kT = 2.675 ref_kT = 2.686
	Predicted entropy: 0.031239604577422142
	Predicted free_energy: 45.92497634887695
	Predicted pressure: -126.584716796875
[Statepoint 1]
	kT = 7.678 ref_kT = 7.674
	Predicted entropy: 0.10408172756433487
	Predicted free_energy: 70.31317901611328
	Predicted pressure: 217.18565368652344
[Statepoint 0] Elastic constants:
	C11: 154.04 GPa
	C33: 167.90 GPa
	C44: 39.60 GPa
	C12: 84.48 GPa
	C13: 64.65 GPa
[Statepoint 1] Elastic constants:
	C11: 139.07 GPa
	C33: 148.47 GPa
	C44: 32.45 GPa
	C12: 77.89 GPa
	C13: 66.39 GPa
Finished epoch 14 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 15]:
	Average train loss: 70.97965
	Average val loss: 78.83422088623047
	Gradient norm: 6427225.0
	Elapsed time = 0.881 min
	Per-target losses:
		F | train loss: 6904.35174606438 | val loss: 7747.64111328125
		U | train loss: 681433.6548402256 | val loss: 284322.75
		virial | train loss: 313674.91235902254 | val loss: 268370.625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 15
	Epoch loss = 0.00491
	Gradient norm: 3.7854528427124023
	Elapsed time = 3.023 min
[Statepoint 0]
	kT = 2.667 ref_kT = 2.686
	Predicted entropy: 0.16164761781692505
	Predicted free_energy: 299.3247375488281
	Predicted pressure: 128.1702880859375
[Statepoint 1]
	kT = 7.667 ref_kT = 7.674
	Predicted entropy: 0.28600993752479553
	Predicted free_energy: 225.3889617919922
	Predicted pressure: 437.668701171875
[Statepoint 0] Elastic constants:
	C11: 139.38 GPa
	C33: 171.66 GPa
	C44: 33.57 GPa
	C12: 102.03 GPa
	C13: 71.84 GPa
[Statepoint 1] Elastic constants:
	C11: 111.32 GPa
	C33: 150.39 GPa
	C44: 27.17 GPa
	C12: 108.26 GPa
	C13: 73.48 GPa
Finished epoch 15 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 16]:
	Average train loss: 81.82453
	Average val loss: 86.49129486083984
	Gradient norm: 13780426.0
	Elapsed time = 0.878 min
	Per-target losses:
		F | train loss: 8031.809144443139 | val loss: 8553.326171875
		U | train loss: 718170.6766917293 | val loss: 265873.625
		virial | train loss: 197066.3631050282 | val loss: 173042.171875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 16
	Epoch loss = 0.00732
	Gradient norm: 4.336783409118652
	Elapsed time = 3.036 min
[Statepoint 0]
	kT = 2.688 ref_kT = 2.686
	Predicted entropy: 0.022642644122242928
	Predicted free_energy: -692.7710571289062
	Predicted pressure: -656.1771240234375
[Statepoint 1]
	kT = 7.684 ref_kT = 7.674
	Predicted entropy: 0.01799602434039116
	Predicted free_energy: -593.6630249023438
	Predicted pressure: -529.0773315429688
[Statepoint 0] Elastic constants:
	C11: 187.52 GPa
	C33: 192.63 GPa
	C44: 45.00 GPa
	C12: 82.89 GPa
	C13: 80.47 GPa
[Statepoint 1] Elastic constants:
	C11: 152.34 GPa
	C33: 167.02 GPa
	C44: 36.23 GPa
	C12: 88.80 GPa
	C13: 79.29 GPa
Finished epoch 16 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 17]:
	Average train loss: 80.84922
	Average val loss: 83.60978698730469
	Gradient norm: 29097542.0
	Elapsed time = 0.879 min
	Per-target losses:
		F | train loss: 7841.624104205827 | val loss: 8207.259765625
		U | train loss: 1586079.7030075188 | val loss: 766295.5
		virial | train loss: 211725.6522850094 | val loss: 192723.640625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 17
	Epoch loss = 0.00664
	Gradient norm: 4.128739833831787
	Elapsed time = 3.012 min
[Statepoint 0]
	kT = 2.681 ref_kT = 2.686
	Predicted entropy: 0.26311010122299194
	Predicted free_energy: 508.33544921875
	Predicted pressure: -169.04873657226562
[Statepoint 1]
	kT = 7.701 ref_kT = 7.674
	Predicted entropy: 0.35088127851486206
	Predicted free_energy: 384.5974426269531
	Predicted pressure: 50.115177154541016
[Statepoint 0] Elastic constants:
	C11: 130.25 GPa
	C33: 162.25 GPa
	C44: 32.24 GPa
	C12: 96.36 GPa
	C13: 68.13 GPa
[Statepoint 1] Elastic constants:
	C11: 123.23 GPa
	C33: 144.12 GPa
	C44: 27.47 GPa
	C12: 88.76 GPa
	C13: 70.93 GPa
Finished epoch 17 for all trainers in  3.94 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 18]:
	Average train loss: 77.83877
	Average val loss: 84.26258087158203
	Gradient norm: 18229148.0
	Elapsed time = 0.882 min
	Per-target losses:
		F | train loss: 7596.016597891213 | val loss: 8272.1044921875
		U | train loss: 210112.22609257518 | val loss: 95710.5703125
		virial | train loss: 417122.8795817669 | val loss: 361455.625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 18
	Epoch loss = 0.00212
	Gradient norm: 1.4901490211486816
	Elapsed time = 3.031 min
[Statepoint 0]
	kT = 2.704 ref_kT = 2.686
	Predicted entropy: -0.04398870840668678
	Predicted free_energy: -452.77337646484375
	Predicted pressure: -81.3863296508789
[Statepoint 1]
	kT = 7.640 ref_kT = 7.674
	Predicted entropy: 0.0214199498295784
	Predicted free_energy: -394.13104248046875
	Predicted pressure: -175.10598754882812
[Statepoint 0] Elastic constants:
	C11: 171.37 GPa
	C33: 190.44 GPa
	C44: 42.60 GPa
	C12: 91.22 GPa
	C13: 78.35 GPa
[Statepoint 1] Elastic constants:
	C11: 139.92 GPa
	C33: 164.75 GPa
	C44: 34.98 GPa
	C12: 97.33 GPa
	C13: 78.40 GPa
Finished epoch 18 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 19]:
	Average train loss: 76.94017
	Average val loss: 83.20471954345703
	Gradient norm: 28211662.0
	Elapsed time = 0.886 min
	Per-target losses:
		F | train loss: 7534.540020706062 | val loss: 8195.2333984375
		U | train loss: 636527.2523496241 | val loss: 335417.25
		virial | train loss: 239559.9176456767 | val loss: 229242.03125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 19
	Epoch loss = 0.00275
	Gradient norm: 1.3023756742477417
	Elapsed time = 3.019 min
[Statepoint 0]
	kT = 2.681 ref_kT = 2.686
	Predicted entropy: 0.1933264434337616
	Predicted free_energy: 194.54000854492188
	Predicted pressure: -63.2870979309082
[Statepoint 1]
	kT = 7.638 ref_kT = 7.674
	Predicted entropy: 0.2631434202194214
	Predicted free_energy: 116.17140197753906
	Predicted pressure: -170.3132781982422
[Statepoint 0] Elastic constants:
	C11: 153.38 GPa
	C33: 166.12 GPa
	C44: 36.80 GPa
	C12: 80.39 GPa
	C13: 68.27 GPa
[Statepoint 1] Elastic constants:
	C11: 128.84 GPa
	C33: 146.73 GPa
	C44: 30.11 GPa
	C12: 86.95 GPa
	C13: 68.63 GPa
Finished epoch 19 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 20]:
	Average train loss: 77.51810
	Average val loss: 82.79999542236328
	Gradient norm: 21868298.0
	Elapsed time = 0.880 min
	Per-target losses:
		F | train loss: 7448.975732789004 | val loss: 8040.68408203125
		U | train loss: 324561.2638040414 | val loss: 129737.28125
		virial | train loss: 675946.3893033365 | val loss: 565854.25

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 20
	Epoch loss = 0.00254
	Gradient norm: 1.179019808769226
	Elapsed time = 3.031 min
[Statepoint 0]
	kT = 2.678 ref_kT = 2.686
	Predicted entropy: 0.01051382813602686
	Predicted free_energy: -405.6603698730469
	Predicted pressure: 221.0079803466797
[Statepoint 1]
	kT = 7.664 ref_kT = 7.674
	Predicted entropy: 0.06217942386865616
	Predicted free_energy: -378.9704895019531
	Predicted pressure: 224.65982055664062
[Statepoint 0] Elastic constants:
	C11: 169.28 GPa
	C33: 191.55 GPa
	C44: 40.25 GPa
	C12: 95.49 GPa
	C13: 77.65 GPa
[Statepoint 1] Elastic constants:
	C11: 124.44 GPa
	C33: 165.29 GPa
	C44: 33.30 GPa
	C12: 113.28 GPa
	C13: 76.91 GPa
Finished epoch 20 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 21]:
	Average train loss: 78.24550
	Average val loss: 84.64207458496094
	Gradient norm: 17645184.0
	Elapsed time = 0.886 min
	Per-target losses:
		F | train loss: 7753.705805039944 | val loss: 8399.9189453125
		U | train loss: 258657.4538298872 | val loss: 153968.75
		virial | train loss: 112445.33423402255 | val loss: 122229.3203125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 21
	Epoch loss = 0.00168
	Gradient norm: 0.7270141839981079
	Elapsed time = 3.021 min
[Statepoint 0]
	kT = 2.681 ref_kT = 2.686
	Predicted entropy: 0.13060489296913147
	Predicted free_energy: -33.3587646484375
	Predicted pressure: -351.4255065917969
[Statepoint 1]
	kT = 7.653 ref_kT = 7.674
	Predicted entropy: 0.18200360238552094
	Predicted free_energy: -56.37739562988281
	Predicted pressure: -23.77931785583496
[Statepoint 0] Elastic constants:
	C11: 157.02 GPa
	C33: 169.16 GPa
	C44: 39.97 GPa
	C12: 82.60 GPa
	C13: 67.73 GPa
[Statepoint 1] Elastic constants:
	C11: 126.90 GPa
	C33: 145.92 GPa
	C44: 32.46 GPa
	C12: 93.19 GPa
	C13: 70.51 GPa
Finished epoch 21 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 22]:
	Average train loss: 78.08946
	Average val loss: 84.691162109375
	Gradient norm: 46822640.0
	Elapsed time = 0.872 min
	Per-target losses:
		F | train loss: 7556.749640213816 | val loss: 8275.783203125
		U | train loss: 630366.1629464285 | val loss: 258212.65625
		virial | train loss: 472898.6200364192 | val loss: 418778.59375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 22
	Epoch loss = 0.00472
	Gradient norm: 3.07832670211792
	Elapsed time = 3.029 min
[Statepoint 0]
	kT = 2.684 ref_kT = 2.686
	Predicted entropy: -0.026417499408125877
	Predicted free_energy: -682.5860595703125
	Predicted pressure: 23.9029598236084
[Statepoint 1]
	kT = 7.671 ref_kT = 7.674
	Predicted entropy: 0.015545133501291275
	Predicted free_energy: -618.9674072265625
	Predicted pressure: 224.52667236328125
[Statepoint 0] Elastic constants:
	C11: 180.02 GPa
	C33: 192.65 GPa
	C44: 43.85 GPa
	C12: 87.81 GPa
	C13: 76.07 GPa
[Statepoint 1] Elastic constants:
	C11: 150.98 GPa
	C33: 169.51 GPa
	C44: 36.76 GPa
	C12: 93.19 GPa
	C13: 75.63 GPa
Finished epoch 22 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 23]:
	Average train loss: 79.09395
	Average val loss: 84.64041900634766
	Gradient norm: 45143596.0
	Elapsed time = 0.876 min
	Per-target losses:
		F | train loss: 7694.6592457706765 | val loss: 8315.0732421875
		U | train loss: 1162366.4934210526 | val loss: 647280.9375
		virial | train loss: 246247.0593280075 | val loss: 210603.671875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 23
	Epoch loss = 0.01951
	Gradient norm: 44.36445617675781
	Elapsed time = 3.010 min
[Statepoint 0]
	kT = 2.673 ref_kT = 2.686
	Predicted entropy: 0.3055586516857147
	Predicted free_energy: 528.93115234375
	Predicted pressure: -15.239706039428711
[Statepoint 1]
	kT = 7.689 ref_kT = 7.674
	Predicted entropy: 0.3733007311820984
	Predicted free_energy: 387.6637878417969
	Predicted pressure: 324.7287292480469
[Statepoint 0] Elastic constants:
	C11: 121.16 GPa
	C33: 161.31 GPa
	C44: 32.49 GPa
	C12: 107.03 GPa
	C13: 66.27 GPa
[Statepoint 1] Elastic constants:
	C11: 84.79 GPa
	C33: 141.20 GPa
	C44: 27.54 GPa
	C12: 125.78 GPa
	C13: 68.34 GPa
Finished epoch 23 for all trainers in  3.94 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 24]:
	Average train loss: 78.02711
	Average val loss: 85.64972686767578
	Gradient norm: 16940064.0
	Elapsed time = 0.884 min
	Per-target losses:
		F | train loss: 7709.487139479558 | val loss: 8492.7685546875
		U | train loss: 260281.92310855264 | val loss: 107535.171875
		virial | train loss: 167990.07976973685 | val loss: 153625.9375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 24
	Epoch loss = 0.00383
	Gradient norm: 1.3536980152130127
	Elapsed time = 3.037 min
[Statepoint 0]
	kT = 2.684 ref_kT = 2.686
	Predicted entropy: -0.11052805930376053
	Predicted free_energy: -600.1470947265625
	Predicted pressure: -492.74835205078125
[Statepoint 1]
	kT = 7.632 ref_kT = 7.674
	Predicted entropy: 0.03560691699385643
	Predicted free_energy: -503.41558837890625
	Predicted pressure: -57.24401092529297
[Statepoint 0] Elastic constants:
	C11: 166.99 GPa
	C33: 185.60 GPa
	C44: 45.36 GPa
	C12: 91.88 GPa
	C13: 70.09 GPa
[Statepoint 1] Elastic constants:
	C11: 154.31 GPa
	C33: 165.29 GPa
	C44: 38.24 GPa
	C12: 83.02 GPa
	C13: 71.58 GPa
Finished epoch 24 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 25]:
	Average train loss: 72.30075
	Average val loss: 82.4305648803711
	Gradient norm: 5843546.0
	Elapsed time = 0.873 min
	Per-target losses:
		F | train loss: 7168.961616688205 | val loss: 8180.841796875
		U | train loss: 71647.21597450657 | val loss: 65593.65625
		virial | train loss: 134871.98876585998 | val loss: 139138.796875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 25
	Epoch loss = 0.00070
	Gradient norm: 0.1743520051240921
	Elapsed time = 3.033 min
[Statepoint 0]
	kT = 2.672 ref_kT = 2.686
	Predicted entropy: -0.024502748623490334
	Predicted free_energy: -267.98419189453125
	Predicted pressure: -386.3186950683594
[Statepoint 1]
	kT = 7.663 ref_kT = 7.674
	Predicted entropy: 0.10517144203186035
	Predicted free_energy: -224.54754638671875
	Predicted pressure: -108.11512756347656
[Statepoint 0] Elastic constants:
	C11: 165.01 GPa
	C33: 176.98 GPa
	C44: 42.92 GPa
	C12: 84.34 GPa
	C13: 67.21 GPa
[Statepoint 1] Elastic constants:
	C11: 134.48 GPa
	C33: 157.13 GPa
	C44: 35.32 GPa
	C12: 92.63 GPa
	C13: 68.03 GPa
Finished epoch 25 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 26]:
	Average train loss: 69.15789
	Average val loss: 79.51879119873047
	Gradient norm: 4747943.5
	Elapsed time = 0.880 min
	Per-target losses:
		F | train loss: 6795.073895676692 | val loss: 7839.07080078125
		U | train loss: 178773.13539708647 | val loss: 138098.015625
		virial | train loss: 257095.38237194548 | val loss: 247495.703125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 26
	Epoch loss = 0.00045
	Gradient norm: 0.13858193159103394
	Elapsed time = 3.021 min
[Statepoint 0]
	kT = 2.674 ref_kT = 2.686
	Predicted entropy: 0.01940671168267727
	Predicted free_energy: -89.80198669433594
	Predicted pressure: -119.84151458740234
[Statepoint 1]
	kT = 7.675 ref_kT = 7.674
	Predicted entropy: 0.13360626995563507
	Predicted free_energy: -78.8937759399414
	Predicted pressure: 64.66793060302734
[Statepoint 0] Elastic constants:
	C11: 164.73 GPa
	C33: 179.24 GPa
	C44: 40.67 GPa
	C12: 85.37 GPa
	C13: 69.25 GPa
[Statepoint 1] Elastic constants:
	C11: 133.24 GPa
	C33: 157.79 GPa
	C44: 34.12 GPa
	C12: 95.61 GPa
	C13: 70.33 GPa
Finished epoch 26 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 27]:
	Average train loss: 67.83240
	Average val loss: 78.61652374267578
	Gradient norm: 2897858.5
	Elapsed time = 0.886 min
	Per-target losses:
		F | train loss: 6621.561585849389 | val loss: 7723.859375
		U | train loss: 448954.1420347744 | val loss: 301833.59375
		virial | train loss: 291956.896968985 | val loss: 269024.0625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 27
	Epoch loss = 0.00191
	Gradient norm: 1.1332544088363647
	Elapsed time = 3.022 min
[Statepoint 0]
	kT = 2.686 ref_kT = 2.686
	Predicted entropy: 0.08504495769739151
	Predicted free_energy: 146.43618774414062
	Predicted pressure: 59.812442779541016
[Statepoint 1]
	kT = 7.669 ref_kT = 7.674
	Predicted entropy: 0.20687313377857208
	Predicted free_energy: 114.40319061279297
	Predicted pressure: -38.198638916015625
[Statepoint 0] Elastic constants:
	C11: 144.90 GPa
	C33: 178.10 GPa
	C44: 37.79 GPa
	C12: 102.84 GPa
	C13: 70.70 GPa
[Statepoint 1] Elastic constants:
	C11: 134.39 GPa
	C33: 154.44 GPa
	C44: 31.62 GPa
	C12: 92.43 GPa
	C13: 71.66 GPa
Finished epoch 27 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 28]:
	Average train loss: 78.50879
	Average val loss: 84.54242706298828
	Gradient norm: 23293162.0
	Elapsed time = 0.888 min
	Per-target losses:
		F | train loss: 7681.120836759868 | val loss: 8308.75
		U | train loss: 118743.71328712406 | val loss: 72414.6015625
		virial | train loss: 394708.6413666001 | val loss: 345626.71875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 28
	Epoch loss = 0.00422
	Gradient norm: 3.4356870651245117
	Elapsed time = 3.041 min
[Statepoint 0]
	kT = 2.692 ref_kT = 2.686
	Predicted entropy: -0.18617196381092072
	Predicted free_energy: -421.30401611328125
	Predicted pressure: -69.65828704833984
[Statepoint 1]
	kT = 7.637 ref_kT = 7.674
	Predicted entropy: 0.017005866393446922
	Predicted free_energy: -311.7225036621094
	Predicted pressure: -65.59970092773438
[Statepoint 0] Elastic constants:
	C11: 174.70 GPa
	C33: 194.52 GPa
	C44: 45.78 GPa
	C12: 94.38 GPa
	C13: 75.12 GPa
[Statepoint 1] Elastic constants:
	C11: 150.59 GPa
	C33: 168.49 GPa
	C44: 36.54 GPa
	C12: 88.47 GPa
	C13: 74.76 GPa
Finished epoch 28 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 29]:
	Average train loss: 70.78960
	Average val loss: 79.59300231933594
	Gradient norm: 10156608.0
	Elapsed time = 0.893 min
	Per-target losses:
		F | train loss: 6894.264398789944 | val loss: 7827.5634765625
		U | train loss: 1103125.6630639099 | val loss: 647928.0625
		virial | train loss: 185958.45958646617 | val loss: 167359.515625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 29
	Epoch loss = 0.00463
	Gradient norm: 6.085515975952148
	Elapsed time = 3.026 min
[Statepoint 0]
	kT = 2.689 ref_kT = 2.686
	Predicted entropy: 0.1378229558467865
	Predicted free_energy: 488.7605895996094
	Predicted pressure: -132.04981994628906
[Statepoint 1]
	kT = 7.623 ref_kT = 7.674
	Predicted entropy: 0.2889941334724426
	Predicted free_energy: 414.24029541015625
	Predicted pressure: 37.76528549194336
[Statepoint 0] Elastic constants:
	C11: 149.94 GPa
	C33: 164.45 GPa
	C44: 36.15 GPa
	C12: 81.49 GPa
	C13: 65.17 GPa
[Statepoint 1] Elastic constants:
	C11: 108.95 GPa
	C33: 142.94 GPa
	C44: 28.02 GPa
	C12: 100.94 GPa
	C13: 67.60 GPa
Finished epoch 29 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 30]:
	Average train loss: 77.93647
	Average val loss: 84.57925415039062
	Gradient norm: 30168874.0
	Elapsed time = 0.890 min
	Per-target losses:
		F | train loss: 7662.808391829182 | val loss: 8354.236328125
		U | train loss: 197681.20999765038 | val loss: 95421.5390625
		virial | train loss: 277675.6916889392 | val loss: 235368.40625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 30
	Epoch loss = 0.00122
	Gradient norm: 1.1407909393310547
	Elapsed time = 3.043 min
[Statepoint 0]
	kT = 2.672 ref_kT = 2.686
	Predicted entropy: -0.14048275351524353
	Predicted free_energy: -497.2057800292969
	Predicted pressure: -277.4764099121094
[Statepoint 1]
	kT = 7.670 ref_kT = 7.674
	Predicted entropy: -0.05929484963417053
	Predicted free_energy: -392.8315734863281
	Predicted pressure: -141.830322265625
[Statepoint 0] Elastic constants:
	C11: 172.10 GPa
	C33: 186.70 GPa
	C44: 44.67 GPa
	C12: 90.18 GPa
	C13: 71.53 GPa
[Statepoint 1] Elastic constants:
	C11: 136.77 GPa
	C33: 165.75 GPa
	C44: 36.68 GPa
	C12: 98.95 GPa
	C13: 71.86 GPa
Finished epoch 30 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 31]:
	Average train loss: 68.37350
	Average val loss: 78.86310577392578
	Gradient norm: 4250449.0
	Elapsed time = 0.915 min
	Per-target losses:
		F | train loss: 6736.717208059211 | val loss: 7793.0751953125
		U | train loss: 220118.66934915414 | val loss: 165844.265625
		virial | train loss: 196553.128612547 | val loss: 191628.046875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 31
	Epoch loss = 0.00205
	Gradient norm: 4.508066654205322
	Elapsed time = 3.025 min
[Statepoint 0]
	kT = 2.683 ref_kT = 2.686
	Predicted entropy: 0.040583036839962006
	Predicted free_energy: 45.124847412109375
	Predicted pressure: -84.75879669189453
[Statepoint 1]
	kT = 7.686 ref_kT = 7.674
	Predicted entropy: 0.12730777263641357
	Predicted free_energy: 39.678226470947266
	Predicted pressure: 138.30848693847656
[Statepoint 0] Elastic constants:
	C11: 146.96 GPa
	C33: 172.23 GPa
	C44: 39.60 GPa
	C12: 95.92 GPa
	C13: 65.94 GPa
[Statepoint 1] Elastic constants:
	C11: 115.93 GPa
	C33: 152.69 GPa
	C44: 32.26 GPa
	C12: 104.33 GPa
	C13: 67.53 GPa
Finished epoch 31 for all trainers in  4.00 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 32]:
	Average train loss: 89.63958
	Average val loss: 93.59725952148438
	Gradient norm: 57335416.0
	Elapsed time = 0.893 min
	Per-target losses:
		F | train loss: 8753.866247650376 | val loss: 9207.5615234375
		U | train loss: 745071.2132283835 | val loss: 370806.15625
		virial | train loss: 338961.46076127817 | val loss: 287710.0625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 32
	Epoch loss = 0.00366
	Gradient norm: 2.9698197841644287
	Elapsed time = 3.040 min
[Statepoint 0]
	kT = 2.672 ref_kT = 2.686
	Predicted entropy: -0.2112882286310196
	Predicted free_energy: -859.1232299804688
	Predicted pressure: -249.27500915527344
[Statepoint 1]
	kT = 7.665 ref_kT = 7.674
	Predicted entropy: -0.17357538640499115
	Predicted free_energy: -700.702392578125
	Predicted pressure: -32.1204948425293
[Statepoint 0] Elastic constants:
	C11: 177.92 GPa
	C33: 190.53 GPa
	C44: 48.48 GPa
	C12: 88.82 GPa
	C13: 69.95 GPa
[Statepoint 1] Elastic constants:
	C11: 149.23 GPa
	C33: 168.47 GPa
	C44: 40.27 GPa
	C12: 93.68 GPa
	C13: 70.90 GPa
Finished epoch 32 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 33]:
	Average train loss: 69.74205
	Average val loss: 79.92384338378906
	Gradient norm: 3615938.5
	Elapsed time = 0.894 min
	Per-target losses:
		F | train loss: 6861.964039738017 | val loss: 7893.4453125
		U | train loss: 296641.048637218 | val loss: 221670.375
		virial | train loss: 206442.43883634868 | val loss: 191930.171875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 33
	Epoch loss = 0.00399
	Gradient norm: 1.1707935333251953
	Elapsed time = 3.031 min
[Statepoint 0]
	kT = 2.697 ref_kT = 2.686
	Predicted entropy: -0.004841028247028589
	Predicted free_energy: 62.45187759399414
	Predicted pressure: -106.00189208984375
[Statepoint 1]
	kT = 7.586 ref_kT = 7.674
	Predicted entropy: 0.09315269440412521
	Predicted free_energy: 84.05068969726562
	Predicted pressure: 562.6220092773438
[Statepoint 0] Elastic constants:
	C11: 136.66 GPa
	C33: 170.20 GPa
	C44: 40.51 GPa
	C12: 105.04 GPa
	C13: 63.60 GPa
[Statepoint 1] Elastic constants:
	C11: 120.12 GPa
	C33: 150.58 GPa
	C44: 33.16 GPa
	C12: 96.73 GPa
	C13: 65.62 GPa
Finished epoch 33 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 34]:
	Average train loss: 71.01278
	Average val loss: 81.52849578857422
	Gradient norm: 3622465.75
	Elapsed time = 0.891 min
	Per-target losses:
		F | train loss: 7038.341062617481 | val loss: 8089.4345703125
		U | train loss: 122408.00258458647 | val loss: 102248.9296875
		virial | train loss: 126741.25920758929 | val loss: 132975.40625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 34
	Epoch loss = 0.00337
	Gradient norm: 1.697966456413269
	Elapsed time = 3.036 min
[Statepoint 0]
	kT = 2.689 ref_kT = 2.686
	Predicted entropy: -0.059666965156793594
	Predicted free_energy: -192.0963592529297
	Predicted pressure: -423.8460388183594
[Statepoint 1]
	kT = 7.710 ref_kT = 7.674
	Predicted entropy: 0.018612736836075783
	Predicted free_energy: -128.69406127929688
	Predicted pressure: 39.12260437011719
[Statepoint 0] Elastic constants:
	C11: 172.40 GPa
	C33: 178.75 GPa
	C44: 44.55 GPa
	C12: 82.27 GPa
	C13: 69.19 GPa
[Statepoint 1] Elastic constants:
	C11: 148.46 GPa
	C33: 158.55 GPa
	C44: 36.30 GPa
	C12: 81.58 GPa
	C13: 70.43 GPa
Finished epoch 34 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 35]:
	Average train loss: 74.59399
	Average val loss: 83.01007080078125
	Gradient norm: 25108548.0
	Elapsed time = 0.892 min
	Per-target losses:
		F | train loss: 7272.504743303572 | val loss: 8149.3505859375
		U | train loss: 985669.5935150376 | val loss: 659816.6875
		virial | train loss: 220818.50549224624 | val loss: 214186.609375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 35
	Epoch loss = 0.00214
	Gradient norm: 1.9704078435897827
	Elapsed time = 3.019 min
[Statepoint 0]
	kT = 2.672 ref_kT = 2.686
	Predicted entropy: 0.14301520586013794
	Predicted free_energy: 417.2220153808594
	Predicted pressure: 290.2907409667969
[Statepoint 1]
	kT = 7.639 ref_kT = 7.674
	Predicted entropy: 0.21487590670585632
	Predicted free_energy: 338.69622802734375
	Predicted pressure: 148.60269165039062
[Statepoint 0] Elastic constants:
	C11: 146.32 GPa
	C33: 171.93 GPa
	C44: 36.57 GPa
	C12: 96.96 GPa
	C13: 70.98 GPa
[Statepoint 1] Elastic constants:
	C11: 118.08 GPa
	C33: 154.06 GPa
	C44: 30.42 GPa
	C12: 103.80 GPa
	C13: 71.01 GPa
Finished epoch 35 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 36]:
	Average train loss: 71.40512
	Average val loss: 81.85944366455078
	Gradient norm: 13440403.0
	Elapsed time = 0.889 min
	Per-target losses:
		F | train loss: 7061.118017210996 | val loss: 8106.9765625
		U | train loss: 57591.12870065789 | val loss: 55491.33203125
		virial | train loss: 184087.0470512218 | val loss: 183545.046875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 36
	Epoch loss = 0.00153
	Gradient norm: 0.8104355335235596
	Elapsed time = 3.041 min
[Statepoint 0]
	kT = 2.684 ref_kT = 2.686
	Predicted entropy: -0.07545561343431473
	Predicted free_energy: -328.11090087890625
	Predicted pressure: 258.89251708984375
[Statepoint 1]
	kT = 7.628 ref_kT = 7.674
	Predicted entropy: 0.02206774428486824
	Predicted free_energy: -262.76849365234375
	Predicted pressure: -471.1501159667969
[Statepoint 0] Elastic constants:
	C11: 171.10 GPa
	C33: 188.34 GPa
	C44: 44.02 GPa
	C12: 92.50 GPa
	C13: 75.51 GPa
[Statepoint 1] Elastic constants:
	C11: 137.83 GPa
	C33: 164.50 GPa
	C44: 36.48 GPa
	C12: 100.35 GPa
	C13: 74.59 GPa
Finished epoch 36 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 37]:
	Average train loss: 76.35260
	Average val loss: 84.83145141601562
	Gradient norm: 20416652.0
	Elapsed time = 0.888 min
	Per-target losses:
		F | train loss: 7551.134416852678 | val loss: 8410.6533203125
		U | train loss: 492993.79487781954 | val loss: 347689.375
		virial | train loss: 87066.87128465695 | val loss: 94306.8515625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 37
	Epoch loss = 0.00293
	Gradient norm: 1.1185603141784668
	Elapsed time = 3.038 min
[Statepoint 0]
	kT = 2.677 ref_kT = 2.686
	Predicted entropy: 0.05193214863538742
	Predicted free_energy: 161.36505126953125
	Predicted pressure: -33.51380920410156
[Statepoint 1]
	kT = 7.662 ref_kT = 7.674
	Predicted entropy: 0.15468870103359222
	Predicted free_energy: 144.0112762451172
	Predicted pressure: -265.94384765625
[Statepoint 0] Elastic constants:
	C11: 159.42 GPa
	C33: 170.22 GPa
	C44: 40.48 GPa
	C12: 80.66 GPa
	C13: 66.74 GPa
[Statepoint 1] Elastic constants:
	C11: 139.41 GPa
	C33: 149.96 GPa
	C44: 32.99 GPa
	C12: 79.74 GPa
	C13: 67.25 GPa
Finished epoch 37 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 38]:
	Average train loss: 71.78320
	Average val loss: 81.54167938232422
	Gradient norm: 17595466.0
	Elapsed time = 0.890 min
	Per-target losses:
		F | train loss: 7028.045802984023 | val loss: 8015.97119140625
		U | train loss: 98821.88163768797 | val loss: 81812.640625
		virial | train loss: 350981.39351797465 | val loss: 325039.125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 38
	Epoch loss = 0.00092
	Gradient norm: 0.21880370378494263
	Elapsed time = 3.035 min
[Statepoint 0]
	kT = 2.700 ref_kT = 2.686
	Predicted entropy: -0.04459325224161148
	Predicted free_energy: -222.25527954101562
	Predicted pressure: 492.49114990234375
[Statepoint 1]
	kT = 7.625 ref_kT = 7.674
	Predicted entropy: 0.0401434451341629
	Predicted free_energy: -184.58343505859375
	Predicted pressure: 209.63821411132812
[Statepoint 0] Elastic constants:
	C11: 165.69 GPa
	C33: 185.03 GPa
	C44: 40.83 GPa
	C12: 93.52 GPa
	C13: 73.88 GPa
[Statepoint 1] Elastic constants:
	C11: 131.81 GPa
	C33: 164.85 GPa
	C44: 33.77 GPa
	C12: 105.00 GPa
	C13: 73.18 GPa
Finished epoch 38 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 39]:
	Average train loss: 67.63736
	Average val loss: 78.31224822998047
	Gradient norm: 4077252.25
	Elapsed time = 0.894 min
	Per-target losses:
		F | train loss: 6720.77716018562 | val loss: 7789.0283203125
		U | train loss: 43555.331399788534 | val loss: 38740.75390625
		virial | train loss: 96508.43388745301 | val loss: 95807.5

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 39
	Epoch loss = 0.00331
	Gradient norm: 2.473020076751709
	Elapsed time = 3.043 min
[Statepoint 0]
	kT = 2.683 ref_kT = 2.686
	Predicted entropy: 0.0009742730762809515
	Predicted free_energy: -143.73631286621094
	Predicted pressure: -297.34771728515625
[Statepoint 1]
	kT = 7.682 ref_kT = 7.674
	Predicted entropy: 0.09749108552932739
	Predicted free_energy: -119.00115966796875
	Predicted pressure: -470.8390197753906
[Statepoint 0] Elastic constants:
	C11: 140.96 GPa
	C33: 173.45 GPa
	C44: 39.59 GPa
	C12: 104.32 GPa
	C13: 66.73 GPa
[Statepoint 1] Elastic constants:
	C11: 136.97 GPa
	C33: 154.57 GPa
	C44: 32.75 GPa
	C12: 85.53 GPa
	C13: 68.02 GPa
Finished epoch 39 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 40]:
	Average train loss: 79.68681
	Average val loss: 88.2174072265625
	Gradient norm: 36762888.0
	Elapsed time = 0.888 min
	Per-target losses:
		F | train loss: 7828.874761366306 | val loss: 8713.0673828125
		U | train loss: 676894.2704417293 | val loss: 443515.28125
		virial | train loss: 180292.93497415414 | val loss: 160804.390625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 40
	Epoch loss = 0.00726
	Gradient norm: 2.4455132484436035
	Elapsed time = 3.051 min
[Statepoint 0]
	kT = 2.681 ref_kT = 2.686
	Predicted entropy: -0.22592344880104065
	Predicted free_energy: -896.699462890625
	Predicted pressure: -487.2895812988281
[Statepoint 1]
	kT = 7.659 ref_kT = 7.674
	Predicted entropy: -0.138531893491745
	Predicted free_energy: -733.9419555664062
	Predicted pressure: -736.8145141601562
[Statepoint 0] Elastic constants:
	C11: 186.99 GPa
	C33: 188.62 GPa
	C44: 48.61 GPa
	C12: 79.46 GPa
	C13: 70.11 GPa
[Statepoint 1] Elastic constants:
	C11: 154.59 GPa
	C33: 164.96 GPa
	C44: 39.92 GPa
	C12: 83.49 GPa
	C13: 71.04 GPa
Finished epoch 40 for all trainers in  4.00 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 41]:
	Average train loss: 69.15499
	Average val loss: 79.6964111328125
	Gradient norm: 3845056.5
	Elapsed time = 0.893 min
	Per-target losses:
		F | train loss: 6804.5993009868425 | val loss: 7867.830078125
		U | train loss: 209340.33076832705 | val loss: 161982.53125
		virial | train loss: 224913.5037300282 | val loss: 214030.9375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 41
	Epoch loss = 0.00203
	Gradient norm: 0.37273070216178894
	Elapsed time = 3.038 min
[Statepoint 0]
	kT = 2.666 ref_kT = 2.686
	Predicted entropy: -0.022378558292984962
	Predicted free_energy: -10.025096893310547
	Predicted pressure: 70.8148193359375
[Statepoint 1]
	kT = 7.632 ref_kT = 7.674
	Predicted entropy: 0.05925210937857628
	Predicted free_energy: 24.807865142822266
	Predicted pressure: -189.45558166503906
[Statepoint 0] Elastic constants:
	C11: 166.57 GPa
	C33: 174.17 GPa
	C44: 41.28 GPa
	C12: 79.37 GPa
	C13: 66.05 GPa
[Statepoint 1] Elastic constants:
	C11: 137.77 GPa
	C33: 153.12 GPa
	C44: 34.34 GPa
	C12: 85.02 GPa
	C13: 66.72 GPa
Finished epoch 41 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 42]:
	Average train loss: 74.78506
	Average val loss: 83.99213409423828
	Gradient norm: 12626138.0
	Elapsed time = 0.887 min
	Per-target losses:
		F | train loss: 7313.120869801457 | val loss: 8248.416015625
		U | train loss: 169677.92046522556 | val loss: 150493.453125
		virial | train loss: 371042.47971980734 | val loss: 339368.5

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 42
	Epoch loss = 0.00411
	Gradient norm: 2.0506999492645264
	Elapsed time = 3.029 min
[Statepoint 0]
	kT = 2.687 ref_kT = 2.686
	Predicted entropy: -0.014979971572756767
	Predicted free_energy: -51.51728439331055
	Predicted pressure: 401.7138366699219
[Statepoint 1]
	kT = 7.668 ref_kT = 7.674
	Predicted entropy: 0.06108151748776436
	Predicted free_energy: -28.785259246826172
	Predicted pressure: -23.39794158935547
[Statepoint 0] Elastic constants:
	C11: 147.40 GPa
	C33: 180.79 GPa
	C44: 39.16 GPa
	C12: 105.05 GPa
	C13: 70.76 GPa
[Statepoint 1] Elastic constants:
	C11: 111.08 GPa
	C33: 160.63 GPa
	C44: 32.37 GPa
	C12: 117.83 GPa
	C13: 70.96 GPa
Finished epoch 42 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 43]:
	Average train loss: 75.13374
	Average val loss: 84.69140625
	Gradient norm: 29920744.0
	Elapsed time = 0.891 min
	Per-target losses:
		F | train loss: 7414.2740469337405 | val loss: 8384.1044921875
		U | train loss: 196081.97538768797 | val loss: 124310.4609375
		virial | train loss: 198728.5886248825 | val loss: 181510.859375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 43
	Epoch loss = 0.00207
	Gradient norm: 0.7914893627166748
	Elapsed time = 3.034 min
[Statepoint 0]
	kT = 2.680 ref_kT = 2.686
	Predicted entropy: -0.12425820529460907
	Predicted free_energy: -575.556884765625
	Predicted pressure: -315.0969543457031
[Statepoint 1]
	kT = 7.680 ref_kT = 7.674
	Predicted entropy: -0.05705450847744942
	Predicted free_energy: -464.6083984375
	Predicted pressure: -262.1494445800781
[Statepoint 0] Elastic constants:
	C11: 167.80 GPa
	C33: 185.97 GPa
	C44: 45.39 GPa
	C12: 93.15 GPa
	C13: 71.16 GPa
[Statepoint 1] Elastic constants:
	C11: 146.49 GPa
	C33: 164.07 GPa
	C44: 37.46 GPa
	C12: 88.39 GPa
	C13: 70.96 GPa
Finished epoch 43 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 44]:
	Average train loss: 69.83724
	Average val loss: 80.47298431396484
	Gradient norm: 4482165.5
	Elapsed time = 0.892 min
	Per-target losses:
		F | train loss: 6895.942232289709 | val loss: 7964.32177734375
		U | train loss: 209268.54587640977 | val loss: 163397.46875
		virial | train loss: 167137.4823043938 | val loss: 166592.703125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 44
	Epoch loss = 0.00197
	Gradient norm: 0.2610093057155609
	Elapsed time = 3.023 min
[Statepoint 0]
	kT = 2.697 ref_kT = 2.686
	Predicted entropy: 0.008373860269784927
	Predicted free_energy: 15.191621780395508
	Predicted pressure: 234.8789825439453
[Statepoint 1]
	kT = 7.671 ref_kT = 7.674
	Predicted entropy: 0.08672326803207397
	Predicted free_energy: 36.808937072753906
	Predicted pressure: 94.04469299316406
[Statepoint 0] Elastic constants:
	C11: 158.67 GPa
	C33: 171.93 GPa
	C44: 40.88 GPa
	C12: 83.41 GPa
	C13: 65.03 GPa
[Statepoint 1] Elastic constants:
	C11: 137.41 GPa
	C33: 153.28 GPa
	C44: 33.73 GPa
	C12: 82.85 GPa
	C13: 67.23 GPa
Finished epoch 44 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 45]:
	Average train loss: 71.21906
	Average val loss: 81.7156982421875
	Gradient norm: 9385934.0
	Elapsed time = 0.891 min
	Per-target losses:
		F | train loss: 7032.782125602092 | val loss: 8086.265625
		U | train loss: 87292.82900610902 | val loss: 81738.5703125
		virial | train loss: 200986.42299107142 | val loss: 192826.359375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 45
	Epoch loss = 0.00352
	Gradient norm: 1.4237678050994873
	Elapsed time = 3.019 min
[Statepoint 0]
	kT = 2.698 ref_kT = 2.686
	Predicted entropy: -0.0021023431327193975
	Predicted free_energy: -108.30425262451172
	Predicted pressure: 98.91412353515625
[Statepoint 1]
	kT = 7.745 ref_kT = 7.674
	Predicted entropy: 0.08950632810592651
	Predicted free_energy: -81.52925109863281
	Predicted pressure: 730.02490234375
[Statepoint 0] Elastic constants:
	C11: 140.42 GPa
	C33: 181.46 GPa
	C44: 39.90 GPa
	C12: 112.28 GPa
	C13: 71.84 GPa
[Statepoint 1] Elastic constants:
	C11: 126.22 GPa
	C33: 157.81 GPa
	C44: 32.56 GPa
	C12: 102.40 GPa
	C13: 71.92 GPa
Finished epoch 45 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 46]:
	Average train loss: 71.81721
	Average val loss: 82.41265869140625
	Gradient norm: 1930471.25
	Elapsed time = 0.892 min
	Per-target losses:
		F | train loss: 7144.795399142387 | val loss: 8205.9755859375
		U | train loss: 117994.70089285714 | val loss: 106910.328125
		virial | train loss: 62815.30001762218 | val loss: 61497.95703125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 46
	Epoch loss = 0.00115
	Gradient norm: 0.16663388907909393
	Elapsed time = 3.019 min
[Statepoint 0]
	kT = 2.683 ref_kT = 2.686
	Predicted entropy: -0.041432540863752365
	Predicted free_energy: -114.57950592041016
	Predicted pressure: -715.2369995117188
[Statepoint 1]
	kT = 7.697 ref_kT = 7.674
	Predicted entropy: 0.06480327248573303
	Predicted free_energy: -51.79055404663086
	Predicted pressure: 20.056482315063477
[Statepoint 0] Elastic constants:
	C11: 168.06 GPa
	C33: 179.43 GPa
	C44: 43.84 GPa
	C12: 84.84 GPa
	C13: 71.02 GPa
[Statepoint 1] Elastic constants:
	C11: 136.23 GPa
	C33: 156.50 GPa
	C44: 35.37 GPa
	C12: 90.30 GPa
	C13: 70.52 GPa
Finished epoch 46 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 47]:
	Average train loss: 70.12567
	Average val loss: 81.19659423828125
	Gradient norm: 3159435.0
	Elapsed time = 0.891 min
	Per-target losses:
		F | train loss: 6929.353071399201 | val loss: 8042.86181640625
		U | train loss: 431909.65037593985 | val loss: 342428.125
		virial | train loss: 100057.0258311795 | val loss: 106387.09375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 47
	Epoch loss = 0.00044
	Gradient norm: 0.31901541352272034
	Elapsed time = 3.026 min
[Statepoint 0]
	kT = 2.698 ref_kT = 2.686
	Predicted entropy: 0.022694649174809456
	Predicted free_energy: 147.64112854003906
	Predicted pressure: -319.7126770019531
[Statepoint 1]
	kT = 7.681 ref_kT = 7.674
	Predicted entropy: 0.10351793467998505
	Predicted free_energy: 163.73524475097656
	Predicted pressure: 374.6155700683594
[Statepoint 0] Elastic constants:
	C11: 155.15 GPa
	C33: 177.63 GPa
	C44: 41.68 GPa
	C12: 93.87 GPa
	C13: 71.05 GPa
[Statepoint 1] Elastic constants:
	C11: 125.44 GPa
	C33: 156.18 GPa
	C44: 33.94 GPa
	C12: 98.83 GPa
	C13: 70.79 GPa
Finished epoch 47 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 48]:
	Average train loss: 77.06567
	Average val loss: 86.41337585449219
	Gradient norm: 25899402.0
	Elapsed time = 0.888 min
	Per-target losses:
		F | train loss: 7605.728589050752 | val loss: 8549.4287109375
		U | train loss: 96952.18300634398 | val loss: 70187.03125
		virial | train loss: 227858.82438028665 | val loss: 212226.40625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 48
	Epoch loss = 0.00551
	Gradient norm: 5.1159162521362305
	Elapsed time = 3.043 min
[Statepoint 0]
	kT = 2.668 ref_kT = 2.686
	Predicted entropy: -0.15597909688949585
	Predicted free_energy: -457.3305358886719
	Predicted pressure: -202.9479217529297
[Statepoint 1]
	kT = 7.669 ref_kT = 7.674
	Predicted entropy: -0.0346573106944561
	Predicted free_energy: -341.73388671875
	Predicted pressure: 764.9503173828125
[Statepoint 0] Elastic constants:
	C11: 178.07 GPa
	C33: 189.77 GPa
	C44: 47.00 GPa
	C12: 88.10 GPa
	C13: 74.18 GPa
[Statepoint 1] Elastic constants:
	C11: 154.41 GPa
	C33: 167.43 GPa
	C44: 38.42 GPa
	C12: 83.64 GPa
	C13: 73.82 GPa
Finished epoch 48 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 49]:
	Average train loss: 72.95385
	Average val loss: 83.23175811767578
	Gradient norm: 12883111.0
	Elapsed time = 0.893 min
	Per-target losses:
		F | train loss: 7193.344355762453 | val loss: 8235.8935546875
		U | train loss: 655735.4414943609 | val loss: 493484.96875
		virial | train loss: 91168.30518679511 | val loss: 94834.015625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 49
	Epoch loss = 0.00177
	Gradient norm: 2.1258771419525146
	Elapsed time = 3.026 min
[Statepoint 0]
	kT = 2.656 ref_kT = 2.686
	Predicted entropy: 0.0812186524271965
	Predicted free_energy: 337.1135559082031
	Predicted pressure: -383.8122863769531
[Statepoint 1]
	kT = 7.657 ref_kT = 7.674
	Predicted entropy: 0.1916687786579132
	Predicted free_energy: 312.02862548828125
	Predicted pressure: 86.07785034179688
[Statepoint 0] Elastic constants:
	C11: 149.90 GPa
	C33: 169.93 GPa
	C44: 40.02 GPa
	C12: 91.18 GPa
	C13: 68.64 GPa
[Statepoint 1] Elastic constants:
	C11: 117.55 GPa
	C33: 150.12 GPa
	C44: 31.67 GPa
	C12: 99.87 GPa
	C13: 68.74 GPa
Finished epoch 49 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 50]:
	Average train loss: 71.97142
	Average val loss: 82.94490051269531
	Gradient norm: 11676041.0
	Elapsed time = 0.895 min
	Per-target losses:
		F | train loss: 7143.430686090225 | val loss: 8240.0791015625
		U | train loss: 51698.01541940789 | val loss: 53606.1328125
		virial | train loss: 121353.34877232143 | val loss: 122623.796875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 50
	Epoch loss = 0.00086
	Gradient norm: 0.2720410227775574
	Elapsed time = 3.037 min
[Statepoint 0]
	kT = 2.689 ref_kT = 2.686
	Predicted entropy: -0.06757465749979019
	Predicted free_energy: -207.48643493652344
	Predicted pressure: -396.0620422363281
[Statepoint 1]
	kT = 7.686 ref_kT = 7.674
	Predicted entropy: 0.02925381250679493
	Predicted free_energy: -140.37911987304688
	Predicted pressure: -525.46826171875
[Statepoint 0] Elastic constants:
	C11: 168.32 GPa
	C33: 182.81 GPa
	C44: 44.41 GPa
	C12: 88.44 GPa
	C13: 72.52 GPa
[Statepoint 1] Elastic constants:
	C11: 135.96 GPa
	C33: 160.35 GPa
	C44: 35.70 GPa
	C12: 94.30 GPa
	C13: 71.80 GPa
Finished epoch 50 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 51]:
	Average train loss: 72.04077
	Average val loss: 82.66669464111328
	Gradient norm: 4688355.5
	Elapsed time = 0.895 min
	Per-target losses:
		F | train loss: 7057.24002144032 | val loss: 8125.1328125
		U | train loss: 297775.8171992481 | val loss: 257255.5
		virial | train loss: 292647.46477766684 | val loss: 289527.09375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 51
	Epoch loss = 0.00059
	Gradient norm: 0.11559153348207474
	Elapsed time = 3.032 min
[Statepoint 0]
	kT = 2.676 ref_kT = 2.686
	Predicted entropy: -0.005943298805505037
	Predicted free_energy: 62.06934356689453
	Predicted pressure: 414.5313720703125
[Statepoint 1]
	kT = 7.655 ref_kT = 7.674
	Predicted entropy: 0.08693156391382217
	Predicted free_energy: 80.190185546875
	Predicted pressure: -35.13886642456055
[Statepoint 0] Elastic constants:
	C11: 152.27 GPa
	C33: 178.61 GPa
	C44: 41.90 GPa
	C12: 98.20 GPa
	C13: 70.31 GPa
[Statepoint 1] Elastic constants:
	C11: 130.67 GPa
	C33: 157.16 GPa
	C44: 34.00 GPa
	C12: 96.82 GPa
	C13: 70.62 GPa
Finished epoch 51 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 52]:
	Average train loss: 71.23129
	Average val loss: 82.31294250488281
	Gradient norm: 11150191.0
	Elapsed time = 0.889 min
	Per-target losses:
		F | train loss: 7035.264762247415 | val loss: 8145.1513671875
		U | train loss: 63629.50108670113 | val loss: 56691.85546875
		virial | train loss: 203754.31796287594 | val loss: 201184.03125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 52
	Epoch loss = 0.00070
	Gradient norm: 0.14300042390823364
	Elapsed time = 3.039 min
[Statepoint 0]
	kT = 2.700 ref_kT = 2.686
	Predicted entropy: -0.07042833417654037
	Predicted free_energy: -248.39425659179688
	Predicted pressure: 212.0353240966797
[Statepoint 1]
	kT = 7.677 ref_kT = 7.674
	Predicted entropy: 0.0357457734644413
	Predicted free_energy: -185.7342987060547
	Predicted pressure: -192.32693481445312
[Statepoint 0] Elastic constants:
	C11: 157.69 GPa
	C33: 180.38 GPa
	C44: 44.20 GPa
	C12: 97.15 GPa
	C13: 69.95 GPa
[Statepoint 1] Elastic constants:
	C11: 138.60 GPa
	C33: 158.95 GPa
	C44: 36.10 GPa
	C12: 91.76 GPa
	C13: 71.08 GPa
Finished epoch 52 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 53]:
	Average train loss: 72.42543
	Average val loss: 82.73094940185547
	Gradient norm: 13430323.0
	Elapsed time = 0.893 min
	Per-target losses:
		F | train loss: 7169.4193675105735 | val loss: 8200.884765625
		U | train loss: 172491.88698308272 | val loss: 141502.21875
		virial | train loss: 139686.03419437265 | val loss: 145151.375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 53
	Epoch loss = 0.00119
	Gradient norm: 0.7304670810699463
	Elapsed time = 3.035 min
[Statepoint 0]
	kT = 2.706 ref_kT = 2.686
	Predicted entropy: 0.0076451292261481285
	Predicted free_energy: 51.96684265136719
	Predicted pressure: 122.1203842163086
[Statepoint 1]
	kT = 7.672 ref_kT = 7.674
	Predicted entropy: 0.11160419136285782
	Predicted free_energy: 64.24556732177734
	Predicted pressure: 288.8213195800781
[Statepoint 0] Elastic constants:
	C11: 151.61 GPa
	C33: 170.72 GPa
	C44: 41.89 GPa
	C12: 93.74 GPa
	C13: 66.05 GPa
[Statepoint 1] Elastic constants:
	C11: 131.37 GPa
	C33: 151.63 GPa
	C44: 33.37 GPa
	C12: 90.10 GPa
	C13: 67.54 GPa
Finished epoch 53 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 54]:
	Average train loss: 75.40509
	Average val loss: 85.59130859375
	Gradient norm: 16381124.0
	Elapsed time = 0.893 min
	Per-target losses:
		F | train loss: 7446.334171610667 | val loss: 8474.2734375
		U | train loss: 237641.41230028195 | val loss: 183753.703125
		virial | train loss: 176026.47648907424 | val loss: 166206.0625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 54
	Epoch loss = 0.00262
	Gradient norm: 1.4511604309082031
	Elapsed time = 3.038 min
[Statepoint 0]
	kT = 2.679 ref_kT = 2.686
	Predicted entropy: -0.11934937536716461
	Predicted free_energy: -594.274169921875
	Predicted pressure: -273.9833068847656
[Statepoint 1]
	kT = 7.636 ref_kT = 7.674
	Predicted entropy: -0.034907106310129166
	Predicted free_energy: -472.7720031738281
	Predicted pressure: 61.80819320678711
[Statepoint 0] Elastic constants:
	C11: 168.74 GPa
	C33: 187.59 GPa
	C44: 46.84 GPa
	C12: 98.14 GPa
	C13: 72.28 GPa
[Statepoint 1] Elastic constants:
	C11: 148.96 GPa
	C33: 165.09 GPa
	C44: 37.38 GPa
	C12: 89.58 GPa
	C13: 72.86 GPa
Finished epoch 54 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 55]:
	Average train loss: 72.47877
	Average val loss: 82.6650390625
	Gradient norm: 9030978.0
	Elapsed time = 0.892 min
	Per-target losses:
		F | train loss: 7202.53399612312 | val loss: 8220.875
		U | train loss: 135043.6448543233 | val loss: 108319.8359375
		virial | train loss: 79596.01774700423 | val loss: 86992.359375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 55
	Epoch loss = 0.00124
	Gradient norm: 1.6011042594909668
	Elapsed time = 3.017 min
[Statepoint 0]
	kT = 2.693 ref_kT = 2.686
	Predicted entropy: 0.02473788894712925
	Predicted free_energy: 51.47557067871094
	Predicted pressure: 78.53590393066406
[Statepoint 1]
	kT = 7.650 ref_kT = 7.674
	Predicted entropy: 0.11392994970083237
	Predicted free_energy: 67.00300598144531
	Predicted pressure: 13.938222885131836
[Statepoint 0] Elastic constants:
	C11: 148.43 GPa
	C33: 172.23 GPa
	C44: 41.65 GPa
	C12: 97.03 GPa
	C13: 67.28 GPa
[Statepoint 1] Elastic constants:
	C11: 129.57 GPa
	C33: 151.40 GPa
	C44: 33.37 GPa
	C12: 92.09 GPa
	C13: 68.45 GPa
Finished epoch 55 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 56]:
	Average train loss: 72.89611
	Average val loss: 83.72871398925781
	Gradient norm: 6605907.5
	Elapsed time = 0.891 min
	Per-target losses:
		F | train loss: 7243.138477296758 | val loss: 8329.255859375
		U | train loss: 79368.81311677632 | val loss: 65536.9921875
		virial | train loss: 96339.29105968046 | val loss: 92654.5703125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 56
	Epoch loss = 0.00150
	Gradient norm: 0.5132156610488892
	Elapsed time = 3.031 min
[Statepoint 0]
	kT = 2.678 ref_kT = 2.686
	Predicted entropy: -0.13674423098564148
	Predicted free_energy: -446.0254821777344
	Predicted pressure: -518.5890502929688
[Statepoint 1]
	kT = 7.623 ref_kT = 7.674
	Predicted entropy: -0.01310158520936966
	Predicted free_energy: -332.9237060546875
	Predicted pressure: -599.2649536132812
[Statepoint 0] Elastic constants:
	C11: 173.87 GPa
	C33: 181.67 GPa
	C44: 47.44 GPa
	C12: 85.11 GPa
	C13: 69.25 GPa
[Statepoint 1] Elastic constants:
	C11: 136.81 GPa
	C33: 159.29 GPa
	C44: 37.39 GPa
	C12: 93.32 GPa
	C13: 69.32 GPa
Finished epoch 56 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 57]:
	Average train loss: 71.82325
	Average val loss: 82.70758819580078
	Gradient norm: 3629338.75
	Elapsed time = 0.888 min
	Per-target losses:
		F | train loss: 7091.099407454182 | val loss: 8180.955078125
		U | train loss: 180834.97562265038 | val loss: 158318.984375
		virial | train loss: 182856.2837978736 | val loss: 184929.453125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 57
	Epoch loss = 0.00061
	Gradient norm: 0.12012317031621933
	Elapsed time = 3.025 min
[Statepoint 0]
	kT = 2.683 ref_kT = 2.686
	Predicted entropy: -0.07600364089012146
	Predicted free_energy: -16.38161277770996
	Predicted pressure: 385.3461608886719
[Statepoint 1]
	kT = 7.675 ref_kT = 7.674
	Predicted entropy: 0.051718905568122864
	Predicted free_energy: 38.640350341796875
	Predicted pressure: 141.4992218017578
[Statepoint 0] Elastic constants:
	C11: 160.22 GPa
	C33: 179.00 GPa
	C44: 44.30 GPa
	C12: 92.99 GPa
	C13: 68.48 GPa
[Statepoint 1] Elastic constants:
	C11: 119.78 GPa
	C33: 156.32 GPa
	C44: 34.84 GPa
	C12: 104.92 GPa
	C13: 67.97 GPa
Finished epoch 57 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 58]:
	Average train loss: 74.04492
	Average val loss: 84.07747650146484
	Gradient norm: 12453702.0
	Elapsed time = 0.894 min
	Per-target losses:
		F | train loss: 7314.1461172462405 | val loss: 8321.5400390625
		U | train loss: 476473.7918233083 | val loss: 397378.28125
		virial | train loss: 106746.08778782895 | val loss: 116173.484375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 58
	Epoch loss = 0.00051
	Gradient norm: 0.04260638356208801
	Elapsed time = 3.039 min
[Statepoint 0]
	kT = 2.686 ref_kT = 2.686
	Predicted entropy: -0.03660472854971886
	Predicted free_energy: 195.94537353515625
	Predicted pressure: 221.04177856445312
[Statepoint 1]
	kT = 7.705 ref_kT = 7.674
	Predicted entropy: 0.09787855297327042
	Predicted free_energy: 229.499755859375
	Predicted pressure: 9.48281478881836
[Statepoint 0] Elastic constants:
	C11: 158.48 GPa
	C33: 176.70 GPa
	C44: 44.01 GPa
	C12: 92.63 GPa
	C13: 69.50 GPa
[Statepoint 1] Elastic constants:
	C11: 134.17 GPa
	C33: 153.47 GPa
	C44: 34.56 GPa
	C12: 90.27 GPa
	C13: 69.11 GPa
Finished epoch 58 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 59]:
	Average train loss: 71.95273
	Average val loss: 82.639892578125
	Gradient norm: 3990887.0
	Elapsed time = 0.891 min
	Per-target losses:
		F | train loss: 7112.009556361607 | val loss: 8183.9052734375
		U | train loss: 274996.8504464286 | val loss: 226498.5
		virial | train loss: 139410.6232524671 | val loss: 143583.859375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 59
	Epoch loss = 0.00037
	Gradient norm: 0.040191423147916794
	Elapsed time = 3.019 min
[Statepoint 0]
	kT = 2.671 ref_kT = 2.686
	Predicted entropy: -0.029091840609908104
	Predicted free_energy: 112.55270385742188
	Predicted pressure: -167.16065979003906
[Statepoint 1]
	kT = 7.629 ref_kT = 7.674
	Predicted entropy: 0.10842801630496979
	Predicted free_energy: 139.43936157226562
	Predicted pressure: 469.5727233886719
[Statepoint 0] Elastic constants:
	C11: 159.99 GPa
	C33: 182.46 GPa
	C44: 42.90 GPa
	C12: 96.27 GPa
	C13: 72.71 GPa
[Statepoint 1] Elastic constants:
	C11: 127.97 GPa
	C33: 157.91 GPa
	C44: 33.52 GPa
	C12: 100.40 GPa
	C13: 72.62 GPa
Finished epoch 59 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 60]:
	Average train loss: 74.71305
	Average val loss: 84.6507568359375
	Gradient norm: 22586378.0
	Elapsed time = 0.876 min
	Per-target losses:
		F | train loss: 7435.5050627055925 | val loss: 8429.9580078125
		U | train loss: 109557.88128524437 | val loss: 89169.3984375
		virial | train loss: 62109.89064702773 | val loss: 65502.203125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 60
	Epoch loss = 0.00087
	Gradient norm: 0.2819909155368805
	Elapsed time = 3.026 min
[Statepoint 0]
	kT = 2.684 ref_kT = 2.686
	Predicted entropy: -0.017247943207621574
	Predicted free_energy: 5.238862037658691
	Predicted pressure: -149.50439453125
[Statepoint 1]
	kT = 7.690 ref_kT = 7.674
	Predicted entropy: 0.12039102613925934
	Predicted free_energy: 35.521644592285156
	Predicted pressure: 236.77032470703125
[Statepoint 0] Elastic constants:
	C11: 163.81 GPa
	C33: 175.06 GPa
	C44: 43.58 GPa
	C12: 85.91 GPa
	C13: 68.57 GPa
[Statepoint 1] Elastic constants:
	C11: 119.67 GPa
	C33: 152.19 GPa
	C44: 33.99 GPa
	C12: 103.29 GPa
	C13: 68.55 GPa
Finished epoch 60 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 61]:
	Average train loss: 75.17565
	Average val loss: 85.74877166748047
	Gradient norm: 14561820.0
	Elapsed time = 0.877 min
	Per-target losses:
		F | train loss: 7451.656415207942 | val loss: 8515.376953125
		U | train loss: 136475.5391212406 | val loss: 119128.5234375
		virial | train loss: 130653.40625 | val loss: 118966.296875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 61
	Epoch loss = 0.00347
	Gradient norm: 1.5639370679855347
	Elapsed time = 3.043 min
[Statepoint 0]
	kT = 2.689 ref_kT = 2.686
	Predicted entropy: -0.17030100524425507
	Predicted free_energy: -536.6354370117188
	Predicted pressure: -271.0188903808594
[Statepoint 1]
	kT = 7.624 ref_kT = 7.674
	Predicted entropy: -0.014847511425614357
	Predicted free_energy: -413.8400573730469
	Predicted pressure: 146.34329223632812
[Statepoint 0] Elastic constants:
	C11: 178.07 GPa
	C33: 189.69 GPa
	C44: 48.19 GPa
	C12: 87.97 GPa
	C13: 73.01 GPa
[Statepoint 1] Elastic constants:
	C11: 147.89 GPa
	C33: 165.15 GPa
	C44: 37.85 GPa
	C12: 89.60 GPa
	C13: 72.69 GPa
Finished epoch 61 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 62]:
	Average train loss: 72.53689
	Average val loss: 83.7142105102539
	Gradient norm: 9852955.0
	Elapsed time = 0.893 min
	Per-target losses:
		F | train loss: 7212.971239132989 | val loss: 8332.74609375
		U | train loss: 145242.4263392857 | val loss: 124293.5390625
		virial | train loss: 65483.713169642855 | val loss: 65612.5625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 62
	Epoch loss = 0.00040
	Gradient norm: 0.18606428802013397
	Elapsed time = 3.033 min
[Statepoint 0]
	kT = 2.683 ref_kT = 2.686
	Predicted entropy: -0.012768913991749287
	Predicted free_energy: 39.98147201538086
	Predicted pressure: -105.68759155273438
[Statepoint 1]
	kT = 7.672 ref_kT = 7.674
	Predicted entropy: 0.13417136669158936
	Predicted free_energy: 66.8893051147461
	Predicted pressure: -119.26023864746094
[Statepoint 0] Elastic constants:
	C11: 160.62 GPa
	C33: 176.78 GPa
	C44: 42.40 GPa
	C12: 88.23 GPa
	C13: 68.68 GPa
[Statepoint 1] Elastic constants:
	C11: 122.04 GPa
	C33: 155.84 GPa
	C44: 33.39 GPa
	C12: 99.83 GPa
	C13: 69.04 GPa
Finished epoch 62 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 63]:
	Average train loss: 71.48444
	Average val loss: 82.3320541381836
	Gradient norm: 4876069.0
	Elapsed time = 0.887 min
	Per-target losses:
		F | train loss: 7086.442793996711 | val loss: 8171.57421875
		U | train loss: 64651.02117598684 | val loss: 62086.34765625
		virial | train loss: 138839.62196017388 | val loss: 138556.28125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 63
	Epoch loss = 0.00019
	Gradient norm: 0.0650305524468422
	Elapsed time = 3.032 min
[Statepoint 0]
	kT = 2.682 ref_kT = 2.686
	Predicted entropy: -0.08017852157354355
	Predicted free_energy: -158.16510009765625
	Predicted pressure: 148.99472045898438
[Statepoint 1]
	kT = 7.614 ref_kT = 7.674
	Predicted entropy: 0.06810703128576279
	Predicted free_energy: -100.81542205810547
	Predicted pressure: 206.45773315429688
[Statepoint 0] Elastic constants:
	C11: 163.24 GPa
	C33: 182.54 GPa
	C44: 43.69 GPa
	C12: 93.94 GPa
	C13: 70.84 GPa
[Statepoint 1] Elastic constants:
	C11: 130.01 GPa
	C33: 159.15 GPa
	C44: 35.06 GPa
	C12: 101.67 GPa
	C13: 71.24 GPa
Finished epoch 63 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 64]:
	Average train loss: 72.19003
	Average val loss: 82.82260131835938
	Gradient norm: 2668197.0
	Elapsed time = 0.873 min
	Per-target losses:
		F | train loss: 7187.508326480263 | val loss: 8250.9697265625
		U | train loss: 44782.16845335996 | val loss: 41844.6484375
		virial | train loss: 67541.68055759516 | val loss: 67765.375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 64
	Epoch loss = 0.00067
	Gradient norm: 0.2956697940826416
	Elapsed time = 3.041 min
[Statepoint 0]
	kT = 2.683 ref_kT = 2.686
	Predicted entropy: -0.048219189047813416
	Predicted free_energy: -129.5282745361328
	Predicted pressure: -300.42864990234375
[Statepoint 1]
	kT = 7.685 ref_kT = 7.674
	Predicted entropy: 0.09563940763473511
	Predicted free_energy: -80.05317687988281
	Predicted pressure: -420.6025695800781
[Statepoint 0] Elastic constants:
	C11: 155.03 GPa
	C33: 175.67 GPa
	C44: 43.64 GPa
	C12: 94.98 GPa
	C13: 67.38 GPa
[Statepoint 1] Elastic constants:
	C11: 131.97 GPa
	C33: 154.26 GPa
	C44: 34.43 GPa
	C12: 91.02 GPa
	C13: 68.62 GPa
Finished epoch 64 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 65]:
	Average train loss: 73.40979
	Average val loss: 84.04224395751953
	Gradient norm: 14268102.0
	Elapsed time = 0.876 min
	Per-target losses:
		F | train loss: 7256.548655574483 | val loss: 8321.1435546875
		U | train loss: 81173.84833176692 | val loss: 76295.171875
		virial | train loss: 190781.4850541882 | val loss: 188626.96875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 65
	Epoch loss = 0.00148
	Gradient norm: 0.4942280650138855
	Elapsed time = 3.036 min
[Statepoint 0]
	kT = 2.680 ref_kT = 2.686
	Predicted entropy: -0.1689794510602951
	Predicted free_energy: -276.3451843261719
	Predicted pressure: -90.0777587890625
[Statepoint 1]
	kT = 7.666 ref_kT = 7.674
	Predicted entropy: 0.0013710331404581666
	Predicted free_energy: -173.8125
	Predicted pressure: -451.1978759765625
[Statepoint 0] Elastic constants:
	C11: 170.93 GPa
	C33: 185.35 GPa
	C44: 45.90 GPa
	C12: 91.55 GPa
	C13: 71.20 GPa
[Statepoint 1] Elastic constants:
	C11: 142.62 GPa
	C33: 161.50 GPa
	C44: 36.28 GPa
	C12: 92.16 GPa
	C13: 71.18 GPa
Finished epoch 65 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 66]:
	Average train loss: 73.17243
	Average val loss: 83.8746109008789
	Gradient norm: 7820934.0
	Elapsed time = 0.876 min
	Per-target losses:
		F | train loss: 7172.565870242011 | val loss: 8249.8271484375
		U | train loss: 433971.2770206767 | val loss: 389191.21875
		virial | train loss: 253199.77602208647 | val loss: 246787.71875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 66
	Epoch loss = 0.00065
	Gradient norm: 0.4937017858028412
	Elapsed time = 3.030 min
[Statepoint 0]
	kT = 2.671 ref_kT = 2.686
	Predicted entropy: -0.0525115542113781
	Predicted free_energy: 219.3649139404297
	Predicted pressure: 451.4703369140625
[Statepoint 1]
	kT = 7.632 ref_kT = 7.674
	Predicted entropy: 0.1027144193649292
	Predicted free_energy: 242.27490234375
	Predicted pressure: 357.97369384765625
[Statepoint 0] Elastic constants:
	C11: 153.45 GPa
	C33: 177.27 GPa
	C44: 41.60 GPa
	C12: 96.56 GPa
	C13: 68.01 GPa
[Statepoint 1] Elastic constants:
	C11: 124.71 GPa
	C33: 155.76 GPa
	C44: 32.86 GPa
	C12: 100.81 GPa
	C13: 69.07 GPa
Finished epoch 66 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 67]:
	Average train loss: 75.09245
	Average val loss: 85.54718017578125
	Gradient norm: 21324334.0
	Elapsed time = 0.876 min
	Per-target losses:
		F | train loss: 7426.926904664004 | val loss: 8474.353515625
		U | train loss: 102029.57242716165 | val loss: 99405.328125
		virial | train loss: 180287.9272057096 | val loss: 176059.453125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 67
	Epoch loss = 0.00101
	Gradient norm: 0.06982012838125229
	Elapsed time = 3.036 min
[Statepoint 0]
	kT = 2.672 ref_kT = 2.686
	Predicted entropy: -0.16964191198349
	Predicted free_energy: -242.87083435058594
	Predicted pressure: -360.5431213378906
[Statepoint 1]
	kT = 7.665 ref_kT = 7.674
	Predicted entropy: -0.013443561270833015
	Predicted free_energy: -136.93344116210938
	Predicted pressure: -246.8839569091797
[Statepoint 0] Elastic constants:
	C11: 170.02 GPa
	C33: 184.91 GPa
	C44: 45.31 GPa
	C12: 89.80 GPa
	C13: 69.92 GPa
[Statepoint 1] Elastic constants:
	C11: 123.17 GPa
	C33: 160.62 GPa
	C44: 35.64 GPa
	C12: 107.42 GPa
	C13: 71.00 GPa
Finished epoch 67 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 68]:
	Average train loss: 75.99428
	Average val loss: 86.16388702392578
	Gradient norm: 29130118.0
	Elapsed time = 0.876 min
	Per-target losses:
		F | train loss: 7517.183769971804 | val loss: 8535.6171875
		U | train loss: 148259.8798167293 | val loss: 142314.578125
		virial | train loss: 168545.66984844924 | val loss: 166350.609375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 68
	Epoch loss = 0.00163
	Gradient norm: 0.6371281147003174
	Elapsed time = 3.038 min
[Statepoint 0]
	kT = 2.674 ref_kT = 2.686
	Predicted entropy: -0.18930278718471527
	Predicted free_energy: -187.01956176757812
	Predicted pressure: -275.9463806152344
[Statepoint 1]
	kT = 7.673 ref_kT = 7.674
	Predicted entropy: -0.04723980277776718
	Predicted free_energy: -68.88427734375
	Predicted pressure: -251.8214111328125
[Statepoint 0] Elastic constants:
	C11: 164.21 GPa
	C33: 181.44 GPa
	C44: 46.70 GPa
	C12: 90.79 GPa
	C13: 66.68 GPa
[Statepoint 1] Elastic constants:
	C11: 143.45 GPa
	C33: 160.78 GPa
	C44: 37.58 GPa
	C12: 85.18 GPa
	C13: 67.44 GPa
Finished epoch 68 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 69]:
	Average train loss: 73.87222
	Average val loss: 84.5991439819336
	Gradient norm: 9047969.0
	Elapsed time = 0.874 min
	Per-target losses:
		F | train loss: 7232.281044407895 | val loss: 8314.046875
		U | train loss: 731552.6076127819 | val loss: 655565.6875
		virial | train loss: 204463.73132048873 | val loss: 200777.0

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 69
	Epoch loss = 0.00095
	Gradient norm: 0.48108428716659546
	Elapsed time = 3.032 min
[Statepoint 0]
	kT = 2.662 ref_kT = 2.686
	Predicted entropy: -0.02934451401233673
	Predicted free_energy: 344.0246887207031
	Predicted pressure: 207.9635009765625
[Statepoint 1]
	kT = 7.646 ref_kT = 7.674
	Predicted entropy: 0.11792922019958496
	Predicted free_energy: 372.2294006347656
	Predicted pressure: -48.410953521728516
[Statepoint 0] Elastic constants:
	C11: 153.53 GPa
	C33: 172.60 GPa
	C44: 42.13 GPa
	C12: 90.28 GPa
	C13: 65.06 GPa
[Statepoint 1] Elastic constants:
	C11: 121.25 GPa
	C33: 151.95 GPa
	C44: 33.23 GPa
	C12: 99.53 GPa
	C13: 66.00 GPa
Finished epoch 69 for all trainers in  3.96 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 70]:
	Average train loss: 74.13351
	Average val loss: 84.95600128173828
	Gradient norm: 6280399.0
	Elapsed time = 0.897 min
	Per-target losses:
		F | train loss: 7285.5726511101975 | val loss: 8370.638671875
		U | train loss: 333968.14226973685 | val loss: 309180.6875
		virial | train loss: 235954.35643796992 | val loss: 235107.125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 70
	Epoch loss = 0.00065
	Gradient norm: 0.5244580507278442
	Elapsed time = 3.027 min
[Statepoint 0]
	kT = 2.678 ref_kT = 2.686
	Predicted entropy: -0.12809905409812927
	Predicted free_energy: 23.81038475036621
	Predicted pressure: 39.07788848876953
[Statepoint 1]
	kT = 7.662 ref_kT = 7.674
	Predicted entropy: 0.019674494862556458
	Predicted free_energy: 106.8452377319336
	Predicted pressure: 104.86451721191406
[Statepoint 0] Elastic constants:
	C11: 154.18 GPa
	C33: 181.84 GPa
	C44: 44.77 GPa
	C12: 103.24 GPa
	C13: 69.41 GPa
[Statepoint 1] Elastic constants:
	C11: 130.65 GPa
	C33: 159.52 GPa
	C44: 35.54 GPa
	C12: 99.70 GPa
	C13: 70.11 GPa
Finished epoch 70 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 71]:
	Average train loss: 75.03469
	Average val loss: 86.07313537597656
	Gradient norm: 22504726.0
	Elapsed time = 0.899 min
	Per-target losses:
		F | train loss: 7410.9143708881575 | val loss: 8516.5078125
		U | train loss: 179121.4291588346 | val loss: 171970.53125
		virial | train loss: 186606.99719513627 | val loss: 184022.15625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 71
	Epoch loss = 0.00267
	Gradient norm: 2.8214128017425537
	Elapsed time = 3.045 min
[Statepoint 0]
	kT = 2.680 ref_kT = 2.686
	Predicted entropy: -0.22200529277324677
	Predicted free_energy: -206.2833709716797
	Predicted pressure: -315.1525573730469
[Statepoint 1]
	kT = 7.669 ref_kT = 7.674
	Predicted entropy: -0.056675024330616
	Predicted free_energy: -70.46634674072266
	Predicted pressure: -270.1787414550781
[Statepoint 0] Elastic constants:
	C11: 176.71 GPa
	C33: 183.79 GPa
	C44: 48.55 GPa
	C12: 84.89 GPa
	C13: 69.05 GPa
[Statepoint 1] Elastic constants:
	C11: 144.86 GPa
	C33: 162.76 GPa
	C44: 38.41 GPa
	C12: 88.88 GPa
	C13: 69.49 GPa
Finished epoch 71 for all trainers in  4.00 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 72]:
	Average train loss: 73.20813
	Average val loss: 84.11803436279297
	Gradient norm: 6886694.0
	Elapsed time = 0.896 min
	Per-target losses:
		F | train loss: 7186.246574688675 | val loss: 8283.662109375
		U | train loss: 733593.7570488722 | val loss: 656597.5625
		virial | train loss: 153017.9691685268 | val loss: 156204.234375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 72
	Epoch loss = 0.00091
	Gradient norm: 0.20515236258506775
	Elapsed time = 3.032 min
[Statepoint 0]
	kT = 2.673 ref_kT = 2.686
	Predicted entropy: -0.08465757966041565
	Predicted free_energy: 302.8069152832031
	Predicted pressure: -153.33714294433594
[Statepoint 1]
	kT = 7.663 ref_kT = 7.674
	Predicted entropy: 0.08105793595314026
	Predicted free_energy: 354.43865966796875
	Predicted pressure: 371.85028076171875
[Statepoint 0] Elastic constants:
	C11: 152.86 GPa
	C33: 173.67 GPa
	C44: 44.14 GPa
	C12: 95.24 GPa
	C13: 66.43 GPa
[Statepoint 1] Elastic constants:
	C11: 132.08 GPa
	C33: 152.94 GPa
	C44: 34.75 GPa
	C12: 90.16 GPa
	C13: 67.62 GPa
Finished epoch 72 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 73]:
	Average train loss: 73.15648
	Average val loss: 84.047607421875
	Gradient norm: 4924720.0
	Elapsed time = 0.892 min
	Per-target losses:
		F | train loss: 7180.7003752055925 | val loss: 8273.7265625
		U | train loss: 586932.1785714285 | val loss: 538269.5
		virial | train loss: 190635.1660743656 | val loss: 193016.984375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 73
	Epoch loss = 0.00025
	Gradient norm: 0.045859795063734055
	Elapsed time = 3.034 min
[Statepoint 0]
	kT = 2.685 ref_kT = 2.686
	Predicted entropy: -0.11602111905813217
	Predicted free_energy: 206.68417358398438
	Predicted pressure: 220.6878204345703
[Statepoint 1]
	kT = 7.661 ref_kT = 7.674
	Predicted entropy: 0.05034495145082474
	Predicted free_energy: 273.9913330078125
	Predicted pressure: 222.6054229736328
[Statepoint 0] Elastic constants:
	C11: 166.39 GPa
	C33: 179.55 GPa
	C44: 44.57 GPa
	C12: 88.38 GPa
	C13: 69.62 GPa
[Statepoint 1] Elastic constants:
	C11: 129.80 GPa
	C33: 157.14 GPa
	C44: 35.00 GPa
	C12: 98.69 GPa
	C13: 70.65 GPa
Finished epoch 73 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 74]:
	Average train loss: 73.02216
	Average val loss: 83.95215606689453
	Gradient norm: 4259174.5
	Elapsed time = 0.892 min
	Per-target losses:
		F | train loss: 7211.864301867951 | val loss: 8308.8310546875
		U | train loss: 389298.21428571426 | val loss: 358705.75
		virial | train loss: 128555.05947485902 | val loss: 126283.6796875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 74
	Epoch loss = 0.00031
	Gradient norm: 0.018041962757706642
	Elapsed time = 3.026 min
[Statepoint 0]
	kT = 2.668 ref_kT = 2.686
	Predicted entropy: -0.09396011382341385
	Predicted free_energy: 159.19757080078125
	Predicted pressure: 261.11346435546875
[Statepoint 1]
	kT = 7.660 ref_kT = 7.674
	Predicted entropy: 0.07404359430074692
	Predicted free_energy: 219.39395141601562
	Predicted pressure: 149.67076110839844
[Statepoint 0] Elastic constants:
	C11: 161.80 GPa
	C33: 178.83 GPa
	C44: 43.61 GPa
	C12: 90.62 GPa
	C13: 69.79 GPa
[Statepoint 1] Elastic constants:
	C11: 122.92 GPa
	C33: 157.86 GPa
	C44: 34.04 GPa
	C12: 102.23 GPa
	C13: 70.21 GPa
Finished epoch 74 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 75]:
	Average train loss: 72.79921
	Average val loss: 84.01824951171875
	Gradient norm: 2945724.5
	Elapsed time = 0.896 min
	Per-target losses:
		F | train loss: 7227.669496005639 | val loss: 8351.6298828125
		U | train loss: 221549.29981203008 | val loss: 204999.671875
		virial | train loss: 75240.50866423872 | val loss: 74238.34375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 75
	Epoch loss = 0.00049
	Gradient norm: 0.16720302402973175
	Elapsed time = 3.029 min
[Statepoint 0]
	kT = 2.687 ref_kT = 2.686
	Predicted entropy: -0.09604295343160629
	Predicted free_energy: 47.030609130859375
	Predicted pressure: -60.1083869934082
[Statepoint 1]
	kT = 7.684 ref_kT = 7.674
	Predicted entropy: 0.06875848025083542
	Predicted free_energy: 120.35475158691406
	Predicted pressure: 582.0106201171875
[Statepoint 0] Elastic constants:
	C11: 162.92 GPa
	C33: 179.13 GPa
	C44: 44.01 GPa
	C12: 91.08 GPa
	C13: 70.17 GPa
[Statepoint 1] Elastic constants:
	C11: 135.28 GPa
	C33: 158.21 GPa
	C44: 34.57 GPa
	C12: 92.68 GPa
	C13: 70.57 GPa
Finished epoch 75 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 76]:
	Average train loss: 74.29967
	Average val loss: 85.19709014892578
	Gradient norm: 12590999.0
	Elapsed time = 0.894 min
	Per-target losses:
		F | train loss: 7376.961238545583 | val loss: 8470.9716796875
		U | train loss: 293459.35326597746 | val loss: 262076.828125
		virial | train loss: 59149.80381740484 | val loss: 56323.796875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 76
	Epoch loss = 0.00474
	Gradient norm: 4.593550682067871
	Elapsed time = 3.025 min
[Statepoint 0]
	kT = 2.695 ref_kT = 2.686
	Predicted entropy: -0.00456226198002696
	Predicted free_energy: 216.49452209472656
	Predicted pressure: -547.9622192382812
[Statepoint 1]
	kT = 7.653 ref_kT = 7.674
	Predicted entropy: 0.14685918390750885
	Predicted free_energy: 239.69581604003906
	Predicted pressure: 678.7381591796875
[Statepoint 0] Elastic constants:
	C11: 138.09 GPa
	C33: 173.69 GPa
	C44: 40.69 GPa
	C12: 108.88 GPa
	C13: 69.43 GPa
[Statepoint 1] Elastic constants:
	C11: 114.41 GPa
	C33: 152.52 GPa
	C44: 32.03 GPa
	C12: 108.08 GPa
	C13: 70.14 GPa
Finished epoch 76 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 77]:
	Average train loss: 72.75487
	Average val loss: 84.12726593017578
	Gradient norm: 3335456.5
	Elapsed time = 0.893 min
	Per-target losses:
		F | train loss: 7236.6652373120305 | val loss: 8377.591796875
		U | train loss: 56667.611519031954 | val loss: 55586.921875
		virial | train loss: 82888.93893180216 | val loss: 73940.7578125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 77
	Epoch loss = 0.00027
	Gradient norm: 0.06442465633153915
	Elapsed time = 3.021 min
[Statepoint 0]
	kT = 2.663 ref_kT = 2.686
	Predicted entropy: -0.10814802348613739
	Predicted free_energy: -164.53392028808594
	Predicted pressure: -369.9704895019531
[Statepoint 1]
	kT = 7.673 ref_kT = 7.674
	Predicted entropy: 0.04032430052757263
	Predicted free_energy: -68.70447540283203
	Predicted pressure: 90.98857116699219
[Statepoint 0] Elastic constants:
	C11: 163.93 GPa
	C33: 180.08 GPa
	C44: 45.14 GPa
	C12: 93.87 GPa
	C13: 71.95 GPa
[Statepoint 1] Elastic constants:
	C11: 134.07 GPa
	C33: 157.84 GPa
	C44: 35.47 GPa
	C12: 96.16 GPa
	C13: 70.90 GPa
Finished epoch 77 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 78]:
	Average train loss: 72.39660
	Average val loss: 83.54708099365234
	Gradient norm: 2593178.5
	Elapsed time = 0.892 min
	Per-target losses:
		F | train loss: 7202.165182242717 | val loss: 8319.6611328125
		U | train loss: 90012.1989544173 | val loss: 85996.734375
		virial | train loss: 71233.09037241542 | val loss: 66119.0

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 78
	Epoch loss = 0.00059
	Gradient norm: 0.13858751952648163
	Elapsed time = 3.024 min
[Statepoint 0]
	kT = 2.675 ref_kT = 2.686
	Predicted entropy: -0.10587091743946075
	Predicted free_energy: -89.15959930419922
	Predicted pressure: -768.1002807617188
[Statepoint 1]
	kT = 7.670 ref_kT = 7.674
	Predicted entropy: 0.04723018780350685
	Predicted free_energy: -1.1079225540161133
	Predicted pressure: -408.8241271972656
[Statepoint 0] Elastic constants:
	C11: 159.40 GPa
	C33: 178.79 GPa
	C44: 45.05 GPa
	C12: 95.58 GPa
	C13: 70.14 GPa
[Statepoint 1] Elastic constants:
	C11: 124.80 GPa
	C33: 156.18 GPa
	C44: 35.42 GPa
	C12: 102.54 GPa
	C13: 70.21 GPa
Finished epoch 78 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 79]:
	Average train loss: 72.37704
	Average val loss: 83.2680435180664
	Gradient norm: 3373856.25
	Elapsed time = 0.893 min
	Per-target losses:
		F | train loss: 7167.823855292529 | val loss: 8259.0419921875
		U | train loss: 248261.0467575188 | val loss: 233582.546875
		virial | train loss: 112634.65992128759 | val loss: 111011.6796875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 79
	Epoch loss = 0.00079
	Gradient norm: 0.1968255490064621
	Elapsed time = 3.036 min
[Statepoint 0]
	kT = 2.695 ref_kT = 2.686
	Predicted entropy: -0.1357387751340866
	Predicted free_energy: -1.110770344734192
	Predicted pressure: -415.59149169921875
[Statepoint 1]
	kT = 7.646 ref_kT = 7.674
	Predicted entropy: 0.01618216186761856
	Predicted free_energy: 91.73265075683594
	Predicted pressure: -170.6316375732422
[Statepoint 0] Elastic constants:
	C11: 171.73 GPa
	C33: 180.97 GPa
	C44: 46.01 GPa
	C12: 84.74 GPa
	C13: 70.14 GPa
[Statepoint 1] Elastic constants:
	C11: 130.81 GPa
	C33: 157.05 GPa
	C44: 36.24 GPa
	C12: 98.18 GPa
	C13: 70.17 GPa
Finished epoch 79 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 80]:
	Average train loss: 72.90354
	Average val loss: 83.72473907470703
	Gradient norm: 6674190.5
	Elapsed time = 0.896 min
	Per-target losses:
		F | train loss: 7153.221521822134 | val loss: 8238.611328125
		U | train loss: 582274.774906015 | val loss: 550265.625
		virial | train loss: 197263.27161654134 | val loss: 197090.25

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 80
	Epoch loss = 0.00085
	Gradient norm: 0.5102037787437439
	Elapsed time = 3.031 min
[Statepoint 0]
	kT = 2.668 ref_kT = 2.686
	Predicted entropy: -0.10516239702701569
	Predicted free_energy: 207.26300048828125
	Predicted pressure: -151.1438446044922
[Statepoint 1]
	kT = 7.703 ref_kT = 7.674
	Predicted entropy: 0.051744695752859116
	Predicted free_energy: 271.28729248046875
	Predicted pressure: 438.6955871582031
[Statepoint 0] Elastic constants:
	C11: 167.75 GPa
	C33: 179.64 GPa
	C44: 44.34 GPa
	C12: 85.54 GPa
	C13: 69.38 GPa
[Statepoint 1] Elastic constants:
	C11: 122.19 GPa
	C33: 159.08 GPa
	C44: 35.01 GPa
	C12: 105.66 GPa
	C13: 69.94 GPa
Finished epoch 80 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 81]:
	Average train loss: 75.74182
	Average val loss: 86.73983001708984
	Gradient norm: 31883042.0
	Elapsed time = 0.893 min
	Per-target losses:
		F | train loss: 7477.308028371711 | val loss: 8578.9072265625
		U | train loss: 165798.8200775376 | val loss: 159842.265625
		virial | train loss: 200735.60886836232 | val loss: 197730.21875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 81
	Epoch loss = 0.00114
	Gradient norm: 0.2999221384525299
	Elapsed time = 3.039 min
[Statepoint 0]
	kT = 2.695 ref_kT = 2.686
	Predicted entropy: -0.17894963920116425
	Predicted free_energy: -185.53085327148438
	Predicted pressure: -216.5338592529297
[Statepoint 1]
	kT = 7.672 ref_kT = 7.674
	Predicted entropy: -0.052381183952093124
	Predicted free_energy: -64.20832824707031
	Predicted pressure: -282.4738464355469
[Statepoint 0] Elastic constants:
	C11: 174.07 GPa
	C33: 183.73 GPa
	C44: 46.68 GPa
	C12: 85.71 GPa
	C13: 69.61 GPa
[Statepoint 1] Elastic constants:
	C11: 133.26 GPa
	C33: 163.28 GPa
	C44: 37.22 GPa
	C12: 100.23 GPa
	C13: 70.14 GPa
Finished epoch 81 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 82]:
	Average train loss: 74.52704
	Average val loss: 85.3362045288086
	Gradient norm: 11836681.0
	Elapsed time = 0.896 min
	Per-target losses:
		F | train loss: 7330.60858200188 | val loss: 8412.560546875
		U | train loss: 396546.2791353383 | val loss: 377993.625
		virial | train loss: 206102.02016271147 | val loss: 208151.625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 82
	Epoch loss = 0.00151
	Gradient norm: 0.30076801776885986
	Elapsed time = 3.035 min
[Statepoint 0]
	kT = 2.683 ref_kT = 2.686
	Predicted entropy: -0.12869028747081757
	Predicted free_energy: 85.91639709472656
	Predicted pressure: -9.47510814666748
[Statepoint 1]
	kT = 7.660 ref_kT = 7.674
	Predicted entropy: 0.009865883737802505
	Predicted free_energy: 166.9315643310547
	Predicted pressure: -127.63363647460938
[Statepoint 0] Elastic constants:
	C11: 165.89 GPa
	C33: 177.44 GPa
	C44: 44.51 GPa
	C12: 85.14 GPa
	C13: 66.91 GPa
[Statepoint 1] Elastic constants:
	C11: 141.08 GPa
	C33: 157.50 GPa
	C44: 35.67 GPa
	C12: 85.26 GPa
	C13: 68.27 GPa
Finished epoch 82 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 83]:
	Average train loss: 77.25640
	Average val loss: 87.63298797607422
	Gradient norm: 10679301.0
	Elapsed time = 0.894 min
	Per-target losses:
		F | train loss: 7546.679984874295 | val loss: 8589.046875
		U | train loss: 658569.1872650376 | val loss: 627391.25
		virial | train loss: 282758.4473977914 | val loss: 278780.09375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 83
	Epoch loss = 0.00116
	Gradient norm: 0.15093760192394257
	Elapsed time = 3.032 min
[Statepoint 0]
	kT = 2.676 ref_kT = 2.686
	Predicted entropy: -0.09112221747636795
	Predicted free_energy: 278.3468017578125
	Predicted pressure: 68.45828247070312
[Statepoint 1]
	kT = 7.685 ref_kT = 7.674
	Predicted entropy: 0.059206970036029816
	Predicted free_energy: 327.6415100097656
	Predicted pressure: -3.847748279571533
[Statepoint 0] Elastic constants:
	C11: 152.72 GPa
	C33: 176.07 GPa
	C44: 42.07 GPa
	C12: 96.46 GPa
	C13: 67.87 GPa
[Statepoint 1] Elastic constants:
	C11: 138.18 GPa
	C33: 156.79 GPa
	C44: 33.25 GPa
	C12: 87.47 GPa
	C13: 69.02 GPa
Finished epoch 83 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 84]:
	Average train loss: 79.84636
	Average val loss: 90.72235107421875
	Gradient norm: 28702706.0
	Elapsed time = 0.893 min
	Per-target losses:
		F | train loss: 7831.954509075423 | val loss: 8920.9130859375
		U | train loss: 279528.95876409777 | val loss: 271280.21875
		virial | train loss: 311820.9403782895 | val loss: 310486.15625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 84
	Epoch loss = 0.00237
	Gradient norm: 1.433176040649414
	Elapsed time = 3.047 min
[Statepoint 0]
	kT = 2.687 ref_kT = 2.686
	Predicted entropy: -0.1437886655330658
	Predicted free_energy: -24.663686752319336
	Predicted pressure: 112.7208480834961
[Statepoint 1]
	kT = 7.646 ref_kT = 7.674
	Predicted entropy: 0.004199611954391003
	Predicted free_energy: 59.170166015625
	Predicted pressure: -165.59109497070312
[Statepoint 0] Elastic constants:
	C11: 154.42 GPa
	C33: 181.81 GPa
	C44: 43.43 GPa
	C12: 102.33 GPa
	C13: 69.52 GPa
[Statepoint 1] Elastic constants:
	C11: 116.30 GPa
	C33: 163.10 GPa
	C44: 35.15 GPa
	C12: 115.80 GPa
	C13: 70.74 GPa
Finished epoch 84 for all trainers in  4.00 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 85]:
	Average train loss: 78.84057
	Average val loss: 89.35836791992188
	Gradient norm: 50105512.0
	Elapsed time = 0.896 min
	Per-target losses:
		F | train loss: 7770.58612106438 | val loss: 8823.28125
		U | train loss: 172193.72080592104 | val loss: 166861.0
		virial | train loss: 240628.43529722744 | val loss: 239673.328125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 85
	Epoch loss = 0.00046
	Gradient norm: 0.2833036184310913
	Elapsed time = 3.034 min
[Statepoint 0]
	kT = 2.664 ref_kT = 2.686
	Predicted entropy: -0.19430281221866608
	Predicted free_energy: -208.55487060546875
	Predicted pressure: -3.3037607669830322
[Statepoint 1]
	kT = 7.704 ref_kT = 7.674
	Predicted entropy: -0.04532508924603462
	Predicted free_energy: -89.1484146118164
	Predicted pressure: 230.4102325439453
[Statepoint 0] Elastic constants:
	C11: 165.77 GPa
	C33: 183.24 GPa
	C44: 46.12 GPa
	C12: 93.16 GPa
	C13: 69.34 GPa
[Statepoint 1] Elastic constants:
	C11: 133.56 GPa
	C33: 164.64 GPa
	C44: 37.36 GPa
	C12: 100.70 GPa
	C13: 70.46 GPa
Finished epoch 85 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 86]:
	Average train loss: 74.01465
	Average val loss: 85.01675415039062
	Gradient norm: 11877346.0
	Elapsed time = 0.896 min
	Per-target losses:
		F | train loss: 7308.140364338581 | val loss: 8408.9677734375
		U | train loss: 295965.86501409777 | val loss: 288225.96875
		virial | train loss: 159320.92854940085 | val loss: 159711.625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 86
	Epoch loss = 0.00036
	Gradient norm: 0.0725645050406456
	Elapsed time = 3.040 min
[Statepoint 0]
	kT = 2.670 ref_kT = 2.686
	Predicted entropy: -0.13949011266231537
	Predicted free_energy: 13.261005401611328
	Predicted pressure: -162.36624145507812
[Statepoint 1]
	kT = 7.705 ref_kT = 7.674
	Predicted entropy: 0.007654790300875902
	Predicted free_energy: 102.80685424804688
	Predicted pressure: -233.39588928222656
[Statepoint 0] Elastic constants:
	C11: 162.19 GPa
	C33: 176.75 GPa
	C44: 44.92 GPa
	C12: 88.50 GPa
	C13: 66.93 GPa
[Statepoint 1] Elastic constants:
	C11: 134.00 GPa
	C33: 158.73 GPa
	C44: 35.88 GPa
	C12: 93.42 GPa
	C13: 68.56 GPa
Finished epoch 86 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 87]:
	Average train loss: 74.45125
	Average val loss: 85.1156997680664
	Gradient norm: 11115500.0
	Elapsed time = 0.918 min
	Per-target losses:
		F | train loss: 7312.030567140508 | val loss: 8379.439453125
		U | train loss: 511150.617481203 | val loss: 487515.4375
		virial | train loss: 204949.31041470866 | val loss: 208449.21875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 87
	Epoch loss = 0.00057
	Gradient norm: 0.032795924693346024
	Elapsed time = 3.043 min
[Statepoint 0]
	kT = 2.668 ref_kT = 2.686
	Predicted entropy: -0.10888708382844925
	Predicted free_energy: 175.039794921875
	Predicted pressure: -31.520151138305664
[Statepoint 1]
	kT = 7.730 ref_kT = 7.674
	Predicted entropy: 0.03400111943483353
	Predicted free_energy: 242.10157775878906
	Predicted pressure: 10.914896011352539
[Statepoint 0] Elastic constants:
	C11: 165.82 GPa
	C33: 175.78 GPa
	C44: 43.73 GPa
	C12: 84.07 GPa
	C13: 67.43 GPa
[Statepoint 1] Elastic constants:
	C11: 124.64 GPa
	C33: 156.83 GPa
	C44: 34.94 GPa
	C12: 102.16 GPa
	C13: 68.85 GPa
Finished epoch 87 for all trainers in  4.02 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 88]:
	Average train loss: 73.75571
	Average val loss: 84.71620178222656
	Gradient norm: 5778546.5
	Elapsed time = 0.897 min
	Per-target losses:
		F | train loss: 7255.231728001645 | val loss: 8351.7880859375
		U | train loss: 346864.2624530075 | val loss: 336424.59375
		virial | train loss: 214131.37845101033 | val loss: 215474.984375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 88
	Epoch loss = 0.00018
	Gradient norm: 0.053789567202329636
	Elapsed time = 3.035 min
[Statepoint 0]
	kT = 2.674 ref_kT = 2.686
	Predicted entropy: -0.08358306437730789
	Predicted free_energy: 100.60903930664062
	Predicted pressure: -80.28091430664062
[Statepoint 1]
	kT = 7.654 ref_kT = 7.674
	Predicted entropy: 0.044711317867040634
	Predicted free_energy: 157.05711364746094
	Predicted pressure: -170.31082153320312
[Statepoint 0] Elastic constants:
	C11: 156.91 GPa
	C33: 178.00 GPa
	C44: 43.23 GPa
	C12: 94.28 GPa
	C13: 68.99 GPa
[Statepoint 1] Elastic constants:
	C11: 130.60 GPa
	C33: 157.32 GPa
	C44: 34.49 GPa
	C12: 95.78 GPa
	C13: 69.64 GPa
Finished epoch 88 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 89]:
	Average train loss: 77.23034
	Average val loss: 87.81330871582031
	Gradient norm: 22347566.0
	Elapsed time = 0.890 min
	Per-target losses:
		F | train loss: 7579.5892710291355 | val loss: 8634.5400390625
		U | train loss: 217173.21017387218 | val loss: 211322.984375
		virial | train loss: 304319.6146910244 | val loss: 314146.75

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 89
	Epoch loss = 0.00099
	Gradient norm: 0.3493426740169525
	Elapsed time = 3.039 min
[Statepoint 0]
	kT = 2.678 ref_kT = 2.686
	Predicted entropy: -0.14557115733623505
	Predicted free_energy: -90.76190948486328
	Predicted pressure: 56.06751251220703
[Statepoint 1]
	kT = 7.666 ref_kT = 7.674
	Predicted entropy: -0.010921563021838665
	Predicted free_energy: -6.204138278961182
	Predicted pressure: -23.493144989013672
[Statepoint 0] Elastic constants:
	C11: 156.40 GPa
	C33: 183.56 GPa
	C44: 44.77 GPa
	C12: 101.68 GPa
	C13: 70.62 GPa
[Statepoint 1] Elastic constants:
	C11: 138.92 GPa
	C33: 163.57 GPa
	C44: 36.11 GPa
	C12: 95.55 GPa
	C13: 72.00 GPa
Finished epoch 89 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 90]:
	Average train loss: 72.93469
	Average val loss: 83.87403869628906
	Gradient norm: 5376934.0
	Elapsed time = 0.893 min
	Per-target losses:
		F | train loss: 7154.072548314145 | val loss: 8249.4912109375
		U | train loss: 479848.05239661655 | val loss: 466937.21875
		virial | train loss: 228529.2899215813 | val loss: 228046.0

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 90
	Epoch loss = 0.00018
	Gradient norm: 0.04242158308625221
	Elapsed time = 3.033 min
[Statepoint 0]
	kT = 2.666 ref_kT = 2.686
	Predicted entropy: -0.0778241753578186
	Predicted free_energy: 194.56138610839844
	Predicted pressure: -129.0907440185547
[Statepoint 1]
	kT = 7.689 ref_kT = 7.674
	Predicted entropy: 0.052843332290649414
	Predicted free_energy: 240.67416381835938
	Predicted pressure: 93.09114837646484
[Statepoint 0] Elastic constants:
	C11: 159.07 GPa
	C33: 175.25 GPa
	C44: 43.56 GPa
	C12: 89.72 GPa
	C13: 67.86 GPa
[Statepoint 1] Elastic constants:
	C11: 129.54 GPa
	C33: 157.28 GPa
	C44: 34.26 GPa
	C12: 96.52 GPa
	C13: 69.46 GPa
Finished epoch 90 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 91]:
	Average train loss: 73.06650
	Average val loss: 83.9799575805664
	Gradient norm: 2441417.5
	Elapsed time = 0.898 min
	Per-target losses:
		F | train loss: 7167.6229257225095 | val loss: 8255.9169921875
		U | train loss: 406988.568843985 | val loss: 392764.4375
		virial | train loss: 245821.51077890038 | val loss: 257005.359375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 91
	Epoch loss = 0.00212
	Gradient norm: 0.18385955691337585
	Elapsed time = 3.035 min
[Statepoint 0]
	kT = 2.675 ref_kT = 2.686
	Predicted entropy: -0.09420732408761978
	Predicted free_energy: 128.71847534179688
	Predicted pressure: -13.993939399719238
[Statepoint 1]
	kT = 7.660 ref_kT = 7.674
	Predicted entropy: 0.03927311301231384
	Predicted free_energy: 183.79595947265625
	Predicted pressure: 97.61955261230469
[Statepoint 0] Elastic constants:
	C11: 147.01 GPa
	C33: 178.91 GPa
	C44: 43.88 GPa
	C12: 106.49 GPa
	C13: 69.77 GPa
[Statepoint 1] Elastic constants:
	C11: 139.39 GPa
	C33: 159.53 GPa
	C44: 35.05 GPa
	C12: 89.76 GPa
	C13: 70.95 GPa
Finished epoch 91 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 92]:
	Average train loss: 71.92309
	Average val loss: 82.92314147949219
	Gradient norm: 2736049.25
	Elapsed time = 0.896 min
	Per-target losses:
		F | train loss: 7058.677958323543 | val loss: 8158.63134765625
		U | train loss: 467337.74013157893 | val loss: 453013.75
		virial | train loss: 217244.5199130639 | val loss: 220954.0625

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 92
	Epoch loss = 0.00088
	Gradient norm: 0.32502666115760803
	Elapsed time = 3.033 min
[Statepoint 0]
	kT = 2.688 ref_kT = 2.686
	Predicted entropy: -0.1121392771601677
	Predicted free_energy: 134.38577270507812
	Predicted pressure: -77.34964752197266
[Statepoint 1]
	kT = 7.675 ref_kT = 7.674
	Predicted entropy: 0.031705714762210846
	Predicted free_energy: 201.73199462890625
	Predicted pressure: -41.820438385009766
[Statepoint 0] Elastic constants:
	C11: 166.83 GPa
	C33: 178.86 GPa
	C44: 45.42 GPa
	C12: 87.46 GPa
	C13: 69.54 GPa
[Statepoint 1] Elastic constants:
	C11: 139.03 GPa
	C33: 155.79 GPa
	C44: 35.19 GPa
	C12: 89.74 GPa
	C13: 70.12 GPa
Finished epoch 92 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 93]:
	Average train loss: 72.41411
	Average val loss: 83.18617248535156
	Gradient norm: 4805343.0
	Elapsed time = 0.897 min
	Per-target losses:
		F | train loss: 7070.683355116306 | val loss: 8145.03515625
		U | train loss: 732887.5559210526 | val loss: 705372.75
		virial | train loss: 243596.34515977444 | val loss: 257611.75

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 93
	Epoch loss = 0.00043
	Gradient norm: 0.11096604913473129
	Elapsed time = 3.028 min
[Statepoint 0]
	kT = 2.681 ref_kT = 2.686
	Predicted entropy: -0.05097278207540512
	Predicted free_energy: 341.3049621582031
	Predicted pressure: 115.47044372558594
[Statepoint 1]
	kT = 7.641 ref_kT = 7.674
	Predicted entropy: 0.08422061800956726
	Predicted free_energy: 368.19390869140625
	Predicted pressure: -237.4134979248047
[Statepoint 0] Elastic constants:
	C11: 155.92 GPa
	C33: 175.84 GPa
	C44: 42.91 GPa
	C12: 94.04 GPa
	C13: 69.70 GPa
[Statepoint 1] Elastic constants:
	C11: 131.66 GPa
	C33: 153.64 GPa
	C44: 33.32 GPa
	C12: 93.41 GPa
	C13: 70.80 GPa
Finished epoch 93 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 94]:
	Average train loss: 72.50190
	Average val loss: 83.87460327148438
	Gradient norm: 3935435.0
	Elapsed time = 0.893 min
	Per-target losses:
		F | train loss: 7083.273536624765 | val loss: 8217.875
		U | train loss: 420423.48554981203 | val loss: 409444.28125
		virial | train loss: 312185.7279429041 | val loss: 321602.96875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 94
	Epoch loss = 0.00138
	Gradient norm: 0.1410047858953476
	Elapsed time = 3.042 min
[Statepoint 0]
	kT = 2.672 ref_kT = 2.686
	Predicted entropy: -0.0968511551618576
	Predicted free_energy: 139.2972869873047
	Predicted pressure: 102.52320098876953
[Statepoint 1]
	kT = 7.638 ref_kT = 7.674
	Predicted entropy: 0.0444231741130352
	Predicted free_energy: 192.52581787109375
	Predicted pressure: -200.40737915039062
[Statepoint 0] Elastic constants:
	C11: 158.25 GPa
	C33: 181.84 GPa
	C44: 44.14 GPa
	C12: 97.67 GPa
	C13: 70.92 GPa
[Statepoint 1] Elastic constants:
	C11: 142.81 GPa
	C33: 160.92 GPa
	C44: 34.82 GPa
	C12: 87.30 GPa
	C13: 71.32 GPa
Finished epoch 94 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 95]:
	Average train loss: 72.70493
	Average val loss: 83.33228302001953
	Gradient norm: 6865742.5
	Elapsed time = 0.896 min
	Per-target losses:
		F | train loss: 7097.549055744831 | val loss: 8157.70947265625
		U | train loss: 596955.9487781955 | val loss: 580759.8125
		virial | train loss: 283120.5049401764 | val loss: 293605.8125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 95
	Epoch loss = 0.00124
	Gradient norm: 0.11003584414720535
	Elapsed time = 3.033 min
[Statepoint 0]
	kT = 2.665 ref_kT = 2.686
	Predicted entropy: -0.05278043821454048
	Predicted free_energy: 304.0909118652344
	Predicted pressure: 189.12974548339844
[Statepoint 1]
	kT = 7.637 ref_kT = 7.674
	Predicted entropy: 0.09528447687625885
	Predicted free_energy: 325.496337890625
	Predicted pressure: -167.97755432128906
[Statepoint 0] Elastic constants:
	C11: 150.64 GPa
	C33: 176.74 GPa
	C44: 42.85 GPa
	C12: 98.88 GPa
	C13: 68.84 GPa
[Statepoint 1] Elastic constants:
	C11: 135.72 GPa
	C33: 153.57 GPa
	C44: 33.11 GPa
	C12: 88.30 GPa
	C13: 69.25 GPa
Finished epoch 95 for all trainers in  3.98 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 96]:
	Average train loss: 72.45762
	Average val loss: 83.35900115966797
	Gradient norm: 5572332.5
	Elapsed time = 0.898 min
	Per-target losses:
		F | train loss: 7099.968757342575 | val loss: 8187.07275390625
		U | train loss: 584338.7091165413 | val loss: 567419.3125
		virial | train loss: 218399.667139039 | val loss: 230213.6875

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 96
	Epoch loss = 0.00110
	Gradient norm: 1.2947107553482056
	Elapsed time = 3.038 min
[Statepoint 0]
	kT = 2.688 ref_kT = 2.686
	Predicted entropy: -0.025582559406757355
	Predicted free_energy: 326.6856689453125
	Predicted pressure: 145.86131286621094
[Statepoint 1]
	kT = 7.719 ref_kT = 7.674
	Predicted entropy: 0.12657535076141357
	Predicted free_energy: 334.51922607421875
	Predicted pressure: -190.2701416015625
[Statepoint 0] Elastic constants:
	C11: 148.37 GPa
	C33: 175.79 GPa
	C44: 43.03 GPa
	C12: 100.38 GPa
	C13: 69.57 GPa
[Statepoint 1] Elastic constants:
	C11: 127.06 GPa
	C33: 152.34 GPa
	C44: 32.78 GPa
	C12: 95.96 GPa
	C13: 70.80 GPa
Finished epoch 96 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 97]:
	Average train loss: 71.01707
	Average val loss: 82.11150360107422
	Gradient norm: 3391181.0
	Elapsed time = 0.894 min
	Per-target losses:
		F | train loss: 6990.827957956414 | val loss: 8095.21728515625
		U | train loss: 243040.15730733084 | val loss: 237241.59375
		virial | train loss: 216437.89626409774 | val loss: 230522.09375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 97
	Epoch loss = 0.00064
	Gradient norm: 0.33813363313674927
	Elapsed time = 3.023 min
[Statepoint 0]
	kT = 2.682 ref_kT = 2.686
	Predicted entropy: -0.125667005777359
	Predicted free_energy: 34.1160888671875
	Predicted pressure: 111.8371353149414
[Statepoint 1]
	kT = 7.656 ref_kT = 7.674
	Predicted entropy: 0.034453678876161575
	Predicted free_energy: 92.56702423095703
	Predicted pressure: -102.26790618896484
[Statepoint 0] Elastic constants:
	C11: 170.42 GPa
	C33: 180.96 GPa
	C44: 45.87 GPa
	C12: 85.48 GPa
	C13: 71.25 GPa
[Statepoint 1] Elastic constants:
	C11: 133.29 GPa
	C33: 157.94 GPa
	C44: 35.18 GPa
	C12: 96.82 GPa
	C13: 71.20 GPa
Finished epoch 97 for all trainers in  3.97 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 98]:
	Average train loss: 71.56160
	Average val loss: 82.50232696533203
	Gradient norm: 4556645.0
	Elapsed time = 0.887 min
	Per-target losses:
		F | train loss: 7033.204967986372 | val loss: 8125.1240234375
		U | train loss: 376261.48132048873 | val loss: 366345.125
		virial | train loss: 213322.65915765977 | val loss: 221184.484375

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 98
	Epoch loss = 0.00064
	Gradient norm: 0.030225055292248726
	Elapsed time = 3.046 min
[Statepoint 0]
	kT = 2.686 ref_kT = 2.686
	Predicted entropy: -0.07892695814371109
	Predicted free_energy: 182.00662231445312
	Predicted pressure: -4.9603962898254395
[Statepoint 1]
	kT = 7.612 ref_kT = 7.674
	Predicted entropy: 0.08135756105184555
	Predicted free_energy: 214.29061889648438
	Predicted pressure: -290.5653076171875
[Statepoint 0] Elastic constants:
	C11: 156.95 GPa
	C33: 178.32 GPa
	C44: 44.68 GPa
	C12: 95.03 GPa
	C13: 70.19 GPa
[Statepoint 1] Elastic constants:
	C11: 136.98 GPa
	C33: 157.82 GPa
	C44: 34.23 GPa
	C12: 90.15 GPa
	C13: 71.01 GPa
Finished epoch 98 for all trainers in  3.99 minutes.
---------Starting trainer Force and Energy Matching for 1 updates -----------
[Epoch 99]:
	Average train loss: 72.14637
	Average val loss: 83.63543701171875
	Gradient norm: 5924631.0
	Elapsed time = 0.897 min
	Per-target losses:
		F | train loss: 7066.967266799812 | val loss: 8214.4052734375
		U | train loss: 498878.61936090223 | val loss: 487566.28125
		virial | train loss: 244455.63099154134 | val loss: 250955.078125

---------Starting trainer Difftre for 1 updates -----------

[DiffTRe] Epoch 99
	Epoch loss = 0.00197
	Gradient norm: 0.2017417550086975
	Elapsed time = 3.039 min
[Statepoint 0]
	kT = 2.679 ref_kT = 2.686
	Predicted entropy: -0.05444449186325073
	Predicted free_energy: 272.4189758300781
	Predicted pressure: 146.8052978515625
[Statepoint 1]
	kT = 7.653 ref_kT = 7.674
	Predicted entropy: 0.10703057795763016
	Predicted free_energy: 285.66607666015625
	Predicted pressure: -236.8275604248047
[Statepoint 0] Elastic constants:
	C11: 165.41 GPa
	C33: 177.43 GPa
	C44: 43.90 GPa
	C12: 84.69 GPa
	C13: 69.47 GPa
[Statepoint 1] Elastic constants:
	C11: 142.33 GPa
	C33: 155.05 GPa
	C44: 33.76 GPa
	C12: 82.58 GPa
	C13: 70.87 GPa
Finished epoch 99 for all trainers in  3.99 minutes.
Total training time:  6.8 hours

Postprocessing#

To evaluate the models, we compare the elastic constants predicted by the bottom-up model (FM) and the fused model (FM + DiffTRe) with the experimental reference data.

# Settings for the evaluation simulations.
# NOTE(review): gamma_sim is not used in this cell; presumably kept for
# consistency with the training setup — confirm before removing.
gamma_sim = 100
dt = 0.001
# NOTE(review): the positional arguments presumably are (time step, total
# simulation time, equilibration time, printout interval) in the simulator's
# time units — confirm against chemtrain's `sampling.process_printouts`.
timings = sampling.process_printouts(dt, 70., 10., 0.1)

# Also compute the temperature
# (this mutates the shared compute_fns/observables dicts in place, so every
# later consumer of these dicts also sees the 'kT' entries)
compute_fns['kT'] = custom_quantity.temperature
observables['kT'] = quantity.observables.init_traj_mean_fn('kT')

# Trajectory generator used by postprocess_fn below; it evaluates all
# compute_fns on the sampled snapshots.
# NOTE(review): vmap_batch=2 presumably controls how many snapshots are
# evaluated per vmapped batch — confirm in chemtrain's documentation.
traj_gen = sampling.trajectory_generator_init(
    sim_template, energy_fn_template, timings,
    compute_fns, vmap_batch=2
)

@jax.jit
@functools.partial(jax.vmap, in_axes=(0, None, None))
@functools.partial(jax.vmap, in_axes=(None, 0, 0))
def postprocess_fn(params, reference_states, dynamic_kwargs):
    """Run one simulation and evaluate all observables on its trajectory.

    The innermost vmap (applied first) maps over statepoints: axis 0 of
    ``reference_states`` and ``dynamic_kwargs`` with ``params`` fixed. The
    outer vmap maps over stacked parameter sets with the statepoints fixed.
    Every returned leaf therefore has leading axes (n_params, n_statepoints).

    Returns:
        Dict mapping each key of ``observables`` to its trajectory-averaged
        prediction.
    """
    traj_state = traj_gen(params, reference_states, **dynamic_kwargs)

    # Evaluate each observable on the auxiliary (per-snapshot) quantities.
    results = {}
    for name, observable in observables.items():
        results[name] = observable(traj_state.aux, **dynamic_kwargs)
    return results
    
# Stack the trees along a new leading axis so ``postprocess_fn`` can vmap
# over them: axis 0 of ``all_params`` indexes the models (FM, fused), axis 0
# of the statepoint trees indexes the statepoints.
all_params = chem_util.tree_stack([
    init_params, trainer_fused_params
])
all_reference_states = chem_util.tree_stack(reference_states.values())
all_state_kwargs = chem_util.tree_stack(state_kwargs.values())

# The evaluation simulations are expensive; they are only rerun when the
# EVALUATION environment variable is "true". Otherwise the cached .npz
# results loaded below are used.
if os.environ.get("EVALUATION", "False").lower() == "true":

    all_results = postprocess_fn(all_params, all_reference_states, all_state_kwargs)

    # Unstack: the first axis is the model, the second the statepoint.
    # NOTE(review): assumes `temps` is ordered like `reference_states` and
    # `state_kwargs` — confirm where these are defined.
    fm_results, fused_results = chem_util.tree_unstack(all_results)
    fm_results = {str(temp): val for temp, val in zip(temps, chem_util.tree_unstack(fm_results))}
    fused_results = {str(temp): val for temp, val in zip(temps, chem_util.tree_unstack(fused_results))}

    onp.savez(data_dir / "TI_fm_results.npz", **jax.device_get(fm_results))
    onp.savez(data_dir / "TI_fused_results.npz", **jax.device_get(fused_results))

# Load the results eagerly into plain dicts so the .npz file handles are
# closed immediately instead of leaking. allow_pickle is needed because each
# entry is a dict of observables stored as an object array; only load files
# produced by this notebook, never untrusted ones.
with onp.load(data_dir / "TI_fused_results.npz", allow_pickle=True) as npz:
    fused_data = dict(npz)
with onp.load(data_dir / "TI_fm_results.npz", allow_pickle=True) as npz:
    fm_data = dict(npz)

{"FM": dict(fm_data), "Fused": dict(fused_data)}

Hide code cell content

def plot_elastic_constants(ax, idx, predictions, reference, ylabel="C", set_ylim=True,
                           colors=('#368274', '#C92D39'), markers=("o", "X"),
                           tick_step=10):
    """Plot one elastic constant versus temperature for reference and models.

    The reference is drawn as a dashed black line, each prediction set as
    hollow markers. All values are scaled by ``/ 10**3 * 1.66054``; this
    factor matches the conversion from kJ mol^-1 nm^-3 to GPa (presumably the
    simulation's pressure unit — confirm).

    Args:
        ax: Matplotlib axes to draw on.
        idx: Index of the elastic constant inside each stored constant vector.
        predictions: Sequence of mappings ``{temperature_str: value}``, where
            ``value.tolist()`` yields a dict with an ``'elastic_constants'``
            entry (e.g. object arrays loaded from an ``.npz`` file).
        reference: Mapping ``{temperature: constants}`` of reference data.
        ylabel: Label for the y-axis.
        set_ylim: If True, round the y-limits outward to multiples of
            ``tick_step`` and place ticks at that spacing.
        colors: Marker edge colors, one per prediction set.
        markers: Marker symbols, one per prediction set.
        tick_step: Spacing used for y-limit rounding and ticks.

    Returns:
        The axes ``ax``.

    Raises:
        ValueError: If there are more prediction sets than color/marker
            styles (previously the extra sets were silently not plotted).
    """
    predictions = list(predictions)
    n_styles = min(len(colors), len(markers))
    if len(predictions) > n_styles:
        raise ValueError(
            f"Got {len(predictions)} prediction sets but only {n_styles} "
            f"color/marker styles; pass more `colors` and `markers`."
        )

    def to_gpa(values):
        # Same scaling as used throughout this notebook.
        return onp.asarray(values) / 10 ** 3 * 1.66054

    # Sort by temperature so the dashed reference line is drawn monotonically
    # even if the reference dict is not in temperature order.
    ref_items = sorted(reference.items(), key=lambda item: float(item[0]))
    ref_temps = [temp for temp, _ in ref_items]
    ref_values = onp.asarray([consts[idx] for _, consts in ref_items])
    ax.plot(ref_temps, to_gpa(ref_values), "k--", linewidth=2.0)

    for pred, col, mar in zip(predictions, colors, markers):
        temps = [float(temp) for temp in pred.keys()]
        # ``tolist()`` unwraps the stored (0-d object) array to its dict.
        preds = onp.asarray([vals.tolist()['elastic_constants'][idx] for vals in pred.values()])

        ax.plot(
            temps, to_gpa(preds), mar, color="white", markersize=9, markeredgewidth=2.0,
            markeredgecolor=col)

    ax.set_ylabel(ylabel)

    if set_ylim:
        lower, upper = ax.get_ylim()
        lower = onp.floor(lower / tick_step) * tick_step
        upper = onp.ceil(upper / tick_step) * tick_step

        ax.set_ylim([lower, upper])
        # Half a step past `upper` keeps the final tick for any step size.
        ax.set_yticks(onp.arange(lower, upper + tick_step / 2, tick_step))

    return ax
fig, axes = plt.subplots(2, 2, layout="constrained", sharex=True, figsize=(9, 5.5))

# The legend labels follow the drawing order of plot_elastic_constants:
# the reference line first, then one marker set per entry of all_preds.
all_preds = [fm_data, fused_data]
labels = ["Reference", "FM", "FM + DiffTRe"]

# Constant indices used below: 0 -> C11, 1 -> C33, 2 -> C44, 3 -> C12, 4 -> C13.
ax1, ax2, ax3, ax4 = onp.ravel(axes)
plot_elastic_constants(ax1, 0, all_preds, exp_data, "$C_{11}$ [GPa]")

plot_elastic_constants(ax2, 1, all_preds, exp_data, "$C_{33}$ [GPa]")

# C12 and C13 share one panel; the y-limits are only fixed on the second
# call so that they cover both constants.
plot_elastic_constants(ax3, 3, all_preds, exp_data, "$C_{12}/C_{13}$ [GPa]", set_ylim=False)
plot_elastic_constants(ax3, 4, all_preds, exp_data, "$C_{12}/C_{13}$ [GPa]")

plot_elastic_constants(ax4, 2, all_preds, exp_data, "$C_{44}$ [GPa]")
ax4.legend(labels, loc="upper center", ncol=3)


fig.supxlabel("Temperature [K]")
fig.suptitle("Elastic Constants of Titanium")

# savefig does not create missing directories; make sure the target exists.
output_path = Path("../_data/output/TI_elastic_constants.pdf")
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, bbox_inches="tight")
../_images/679c54dea853c5518c71437c0a05167b4591ab33d798a376fcd7ce7583dc27aa.svg

References#