# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains several Trainer classes as a quickstart for users."""
import functools
import os
import pickle
import time
import warnings
from os import PathLike
from typing import Any, Mapping, Dict, Callable
import jax.tree_util
import numpy as onp
from jax import numpy as jnp, tree_util, jit, random
from jax_sgmc.data import numpy_loader
from jax_md_mod import custom_quantity
from chemtrain import (util)
from chemtrain.learn import (
force_matching, max_likelihood, difftre, property_prediction
)
from chemtrain.quantity import property_prediction
from chemtrain.trainers import base as tt
from chemtrain.ensemble import sampling, reweighting
from chemtrain.data import data_loaders
try:
from jax.typing import ArrayLike
except:
ArrayLike = Any
from optax import GradientTransformationExtraArgs
from jax_md.partition import NeighborFn, NeighborList
from chemtrain.typing import EnergyFnTemplate, TrajFn
[docs]
class PropertyPrediction(tt.DataParallelTrainer):
    """Trainer for direct prediction of molecular properties.

    Args:
        error_fn: Error function handed to
            ``property_prediction.init_loss_fn`` to build the training loss.
        prediction_model: Model handed to ``property_prediction.init_model``.
        init_params: Initial model parameters.
        optimizer: Optimizer forwarded to the base trainer.
        graph_dataset: Input graphs, combined with ``targets`` through
            ``property_prediction.build_dataset``.
        targets: Prediction targets belonging to ``graph_dataset``.
        batch_per_device: Forwarded to the base trainer.
        batch_cache: Forwarded to the base trainer.
        train_ratio: Forwarded to ``set_datasets``.
        val_ratio: Forwarded to ``set_datasets``.
        test_error_fn: Optional error function used by
            :meth:`evaluate_testset_error`.
        shuffle: Forwarded to ``set_datasets``.
        convergence_criterion: Forwarded to the base trainer.
        checkpoint_folder: Folder name appended to
            ``output/property_prediction/`` to form the checkpoint path.
    """

    def __init__(self, error_fn, prediction_model, init_params, optimizer,
                 graph_dataset, targets, batch_per_device=1, batch_cache=10,
                 train_ratio=0.7, val_ratio=0.1, test_error_fn=None,
                 shuffle=False, convergence_criterion="window_median",
                 checkpoint_folder="Checkpoints"):
        # TODO build graph on-the-fly as memory moving might be bottleneck here
        # NOTE(review): this module imports ``property_prediction`` twice
        # (from chemtrain.learn and afterwards from chemtrain.quantity); the
        # later import wins. Confirm that it provides init_model,
        # init_loss_fn and build_dataset.
        model = property_prediction.init_model(prediction_model)
        checkpoint_path = "output/property_prediction/" + str(checkpoint_folder)
        loss_fn = property_prediction.init_loss_fn(error_fn)

        super().__init__(
            loss_fn, model, init_params, optimizer, checkpoint_path,
            batch_per_device, batch_cache,
            convergence_criterion=convergence_criterion
        )

        dataset_dict, _ = property_prediction.build_dataset(targets, graph_dataset)
        self.set_datasets(
            dataset_dict, train_ratio=train_ratio, val_ratio=val_ratio,
            shuffle=shuffle
        )
        self.test_error_fn = test_error_fn

    def predict(self, single_observation):
        """Prediction for a single input graph using the current param state.

        A leading batch axis is added to every leaf of the observation before
        calling the batched model and removed again from the prediction.
        """
        batched_observation = tree_util.tree_map(
            functools.partial(jnp.expand_dims, axis=0), single_observation
        )
        batched_prediction = self.batched_model(
            self.best_inference_params, batched_observation)
        single_prediction = tree_util.tree_map(
            functools.partial(jnp.squeeze, axis=0), batched_prediction
        )
        return single_prediction

    def evaluate_testset_error(self, best_params=True):
        """Evaluates, prints and returns the error on the test split.

        Args:
            best_params: If True, evaluate with
                ``best_inference_params_replicated``, otherwise with the
                current ``state.params``.

        Returns:
            The value returned by ``self.evaluate`` for the ``"testing"``
            split.

        Raises:
            AssertionError: If no test split exists or no test error function
                is available.
        """
        assert "testing" in self._batch_states, (
            "No test set available. Check train and val ratios."
        )

        # ``_test_fn`` is not assigned anywhere in this class, so a direct
        # attribute access could fail with an AttributeError even though a
        # ``test_error_fn`` was supplied. Look it up defensively and fall back
        # to a loss built from ``test_error_fn``.
        test_fn = getattr(self, "_test_fn", None)
        if test_fn is None and self.test_error_fn is not None:
            # NOTE(review): assumes the test loss is built like the training
            # loss in ``__init__`` - confirm against the base trainer.
            test_fn = property_prediction.init_loss_fn(self.test_error_fn)
        assert test_fn is not None, (
            "`test_error_fn` is necessary during initialization."
        )

        params = (self.best_inference_params_replicated
                  if best_params else self.state.params)
        error = self.evaluate("testing", test_fn, params=params)
        print(f"Error on test set: {error}")
        return error
[docs]
class ForceMatching(tt.DataParallelTrainer):
    """Trains potential models with the Force Matching method.

    Force Matching can be applied to learn atomistic [#Ercolessi1994]_ as
    well as coarse-grained [#Noid2008]_ models from first-principle or
    atomistic reference data.

    Args:
        init_params: Initial energy parameters.
        optimizer: Optimizer from optax.
        energy_fn_template: Function that takes energy parameters and returns
            an energy function.
        nbrs_init: Initial neighbor list. The neighbor list must be large
            enough to not overflow for any sample of the dataset.
        gammas: Coefficients for the individual targets in the weighted loss.
        error_fns: Error functions forwarded to
            ``force_matching.init_loss_fn``.
        weights_keys: Dictionary to entries of the dataset that contain a
            per-sample weight for the total loss.
        additional_targets: Additional snapshot targets to train on. Forces
            and energy are derived automatically from the energy_fn_template.
        feature_extract_fns: Features to extract from the data, passed to
            all snapshot functions as keyword arguments.
        energy_fn_has_aux: Energy function has an auxiliary output. The
            energy function will be called with argument ``mode="with_aux"``
            and should return a tuple ``(pot, aux)``.
        batch_per_device: Number of samples to process vectorized on every
            device.
        batch_cache: Number of batches to load into the device memories.
        full_checkpoint: Save the whole trainer instead of only some
            statistics.
        disable_shmap: Use ``pmap`` instead of ``shmap`` for parallelization.
        penalty_fn: Penalty depending only on the parameters.
        convergence_criterion: Check convergence via
            :class:`base.EarlyStopping`.
        log_file: Path to file where to log training progress.
        checkpoint_path: Path to the folder to store checkpoints.

    Warning:
        Currently neighborlist overflow is not checked.
        Make sure to build nbrs_init large enough.

    References:
        .. [#Ercolessi1994] Ercolessi, F.; Adams, J. B. Interatomic Potentials
           from First-Principles Calculations: The Force-Matching Method.
           Europhys. Lett. 1994, 26 (8), 583–588.
           https://doi.org/10.1209/0295-5075/26/8/005.
        .. [#Noid2008] Noid, W. G.; Chu, J.-W.; Ayton, G. S.; Krishna, V.;
           Izvekov, S.; Voth, G. A.; Das, A.; Andersen, H. C. The Multiscale
           Coarse-Graining Method. I. A Rigorous Bridge between Atomistic and
           Coarse-Grained Models. J Chem Phys 2008, 128 (24), 244114.
           https://doi.org/10.1063/1.2938860.
    """

    def __init__(self,
                 init_params,
                 optimizer,
                 energy_fn_template: EnergyFnTemplate,
                 nbrs_init: NeighborList,
                 gammas: Dict[str, float] = None,
                 error_fns: Dict[str, Callable] = None,
                 weights_keys: Dict[str, str] = None,
                 additional_targets: Dict[str, Dict] = None,
                 feature_extract_fns: Dict[str, Callable] = None,
                 energy_fn_has_aux: bool = False,
                 batch_per_device: int = 1,
                 batch_cache: int = 10,
                 full_checkpoint: bool = False,
                 disable_shmap: bool = False,
                 penalty_fn: Callable = None,
                 convergence_criterion: str = "window_median",
                 log_file: str = "force_matching.log",
                 checkpoint_path: PathLike = "checkpoints"):
        gammas = {} if gammas is None else gammas

        # One shared extractor evaluates the energy function a single time;
        # its result serves every quantity that needs energy or forces.
        extractors = {
            "energy_and_force": custom_quantity.energy_force_wrapper(
                energy_fn_template, has_aux=energy_fn_has_aux
            )
        }

        # Default training quantities. They read the pre-extracted energy and
        # forces, hence no energy function is handed to the wrappers.
        snapshot_quantities = {
            "F": custom_quantity.force_wrapper(None),
            "U": custom_quantity.energy_wrapper(None)
        }

        # User-supplied entries extend (and may override) the defaults.
        snapshot_quantities.update(additional_targets or {})
        extractors.update(feature_extract_fns or {})

        model = force_matching.init_model(
            nbrs_init, snapshot_quantities, feature_extract_fns=extractors
        )
        loss_fn = force_matching.init_loss_fn(
            error_fns=error_fns, gammas=gammas, weights_keys=weights_keys
        )

        trainer_options = dict(
            disable_shmap=disable_shmap,
            penalty_fn=penalty_fn,
            convergence_criterion=convergence_criterion,
            full_checkpoint=full_checkpoint,
            log_file=log_file,
            energy_fn_template=energy_fn_template,
        )
        super().__init__(
            loss_fn, model, init_params, optimizer, checkpoint_path,
            batch_per_device, batch_cache, **trainer_options
        )
        self._nbrs_init = nbrs_init

    def evaluate_mae_testset(self):
        """Prints the per-target Mean Absolute Error on the test set."""
        per_target_mae_fn = force_matching.init_loss_fn(
            max_likelihood.mae_loss, individual=True
        )
        _, per_target_errors = self.evaluate(
            "testing", per_target_mae_fn, params=self.best_inference_params
        )
        for target_name, error in per_target_errors.items():
            print(f"{target_name}: MAE = {error:.4f}")
[docs]
class DifftreParallel(tt.MLETrainerTemplate):
    """Trainer class for parametrizing potentials via the DiffTRe method.

    This method performs simulations and updates for multiple statepoints in
    parallel using vmap.

    Args:
        key: PRNG key. It is split to seed the generation of the initial
            trajectory states and the selection of statepoint batches.
        init_params: Initial energy parameters
        optimizer: Optimizer from optax
        energy_fn_template: Function that takes energy parameters and
            initializes a new energy function.
        simulator_template: Function that takes an energy function and
            returns a simulator function.
        neighbor_fn: Neighbor function. Must be of
            :func:`jax_md_mod.custom_partition.masked_neighbor_list` if the
            statepoints have a different number of atoms.
        timings: Instance of TimingClass containing information about the
            trajectory length and which states to retain
        state_kwargs: Properties defining the thermodynamic state. Must at
            least contain the temperature 'kT'. For a non-exhaustive list, see
            :class:`chemtrain.ensemble.templates.StatePoint`.
        quantities: Dict containing for each observable specified by the key a
            corresponding function to compute it for each snapshot using
            :func:`ensemble.sampling.quantity_traj`.
        targets: Dict containing the same keys as quantities and containing
            another dict providing 'gamma' and 'target' for each observable.
        observables: Dictionary providing the observable functions
            for the targets.
        initial_trajstates: Initial trajectory states of the statepoints.
            It is usually simpler to let the trainer generate the initial
            trajectory states by providing `reference_states`.
        reference_states: Initial simulator states from which DiffTRe can
            compute the initial trajectory states.
        num_runs_init: Number of runs to perform for the initial trajectory
            states. This number can be increased to ensure a better
            equilibration when starting from less favourable initial states.
        reweight_ratio: Ratio of reference samples required for n_eff to
            surpass to allow re-use of previous reference trajectory state.
            If trajectories should not be re-used, a value > 1 can be
            specified.
        allowed_reduction: Allowed reduction of the effective sample size
            through a parameter update. If None, the step size adaption is
            disabled and a constant step size scale of 1.0 is used.
        step_size_scale: Initial step size scale for the step size adaption.
        interior_points: Number of interior points to use for the step size
            adaption.
        sim_batch_size: Number of state-points to be processed as a single
            batch. Gradients will be averaged over the batch before stepping
            the optimizer.
        full_checkpoint: If True, the whole trainer state is saved, otherwise
            only important parameters are stored as a dictionary.
        target_loss_fns: Dictionary of loss functions to use for each target.
        loss_fn: Custom loss function to use for the training.
        vmap_batch: Number of samples to process simultaneously when computing
            instantaneous quantities for a trajectory.
        bucket_recompute: Groups together statepoints that need a
            recomputation.
        resample_simstates: Resample the sim states from all trajectories
            instead of simulating independent chains.
        convergence_criterion: Name of the criterion handed to
            :class:`base.EarlyStopping`; defaults to ``'window_median'``.
            NOTE(review): earlier documentation listed 'max_loss' and
            'ave_loss' as the options - confirm the supported values against
            EarlyStopping.
        checkpoint_path: Name of folders to store checkpoints in.
        log_dir: Path to the log file where to store training progress.

    Attributes:
        batch_losses: List of losses for each batch in each epoch.
        epoch_losses: List of losses for each epoch.
        step_size_history: List of step sizes for each batched update.
        gradient_norm_history: List of gradient norms for each batched update.
        batch_gradient_norms: List of gradient norms for each batch.
        predictions: Dictionary containing the predictions for each statepoint
            at each epoch.
        early_stop: Instance of EarlyStopping to check for convergence.
    """

    def __init__(self,
                 key: jax.Array,
                 init_params: Any,
                 optimizer: GradientTransformationExtraArgs,
                 energy_fn_template: EnergyFnTemplate,
                 simulator_template: Callable,
                 neighbor_fn: NeighborFn,
                 timings: sampling.TimingClass,
                 state_kwargs: Dict[str, ArrayLike],
                 quantities: Dict[str, Dict],
                 targets: Dict[str, Any],
                 observables: Dict[str, TrajFn],
                 initial_trajstates = None,
                 reference_states = None,
                 num_runs_init: int = 1,
                 reweight_ratio: ArrayLike = 0.9,
                 allowed_reduction: ArrayLike = 0.95,
                 step_size_scale: float = 1e-4,
                 interior_points: int = 100,
                 sim_batch_size: int = 1,
                 full_checkpoint: bool = False,
                 target_loss_fns: Dict[str, Callable] = None,
                 loss_fn=None,
                 vmap_batch: int = 10,
                 bucket_recompute: bool = True,
                 resample_simstates: bool = False,
                 convergence_criterion: str = "window_median",
                 checkpoint_path: os.PathLike = "Checkpoints",
                 log_dir: os.PathLike = None):
        init_state = util.TrainerState(params=init_params,
                                       opt_state=optimizer.init(init_params))

        self._recompute = False

        gen_init_traj, *reweight_fns = (
            reweighting.init_pot_reweight_propagation_fns(
                energy_fn_template, simulator_template, neighbor_fn, timings,
                state_kwargs, reweight_ratio, False,
                vmap_batch, safe_propagation=False,
                entropy_approximation=False,
                resample_simstates=resample_simstates
            )
        )

        self._bucket_recompute = bucket_recompute

        # TODO: Parallelize over multiple devices
        if target_loss_fns is None:
            target_loss_fns = {}
        if loss_fn is None:
            loss_fn = difftre.init_default_loss_fn(observables, target_loss_fns)

        batched_model, batched_propagation, batched_weights = (
            difftre.init_difftre_gradient_and_propagation(
                reweight_fns, loss_fn, quantities, energy_fn_template,
                wrapped=False, batched=True
            )
        )

        self.reweight_ratio = reweight_ratio
        self.key = key
        self.batch_size = sim_batch_size
        self.statepoints = state_kwargs

        self.model = jax.jit(
            jax.value_and_grad(batched_model, argnums=0, has_aux=True))
        self.propagate = jax.jit(batched_propagation)
        self.weights = jax.jit(batched_weights)

        self.targets = targets

        if initial_trajstates is not None:
            self.traj_states = initial_trajstates
        else:
            # The leading axis of any target array enumerates the statepoints.
            n_statepoints = next(iter(targets.values()))["target"].shape[0]
            self.key, split = random.split(key)
            self.traj_states = util.batch_map(
                lambda ops: gen_init_traj(
                    ops[0][0], init_params, ops[0][1],
                    num_runs=num_runs_init, **ops[1]
                ),
                (
                    (random.split(split, n_statepoints), reference_states),
                    state_kwargs
                ), batch_size=sim_batch_size
            )

        if allowed_reduction is not None:
            # The line search only looks at the smallest effective sample
            # size within the batch.
            self._adaptive_step_size = difftre.init_step_size_adaption(
                lambda *args: (None, jnp.min(batched_weights(*args)[1])),
                allowed_reduction, step_size_scale=step_size_scale,
                interior_points=interior_points
            )
        else:
            self._adaptive_step_size = lambda *args: (1.0, None)

        super().__init__(
            init_state=init_state,
            optimizer=optimizer,
            checkpoint_path=checkpoint_path,
            full_checkpoint=full_checkpoint,
            log_file=log_dir
        )

        self.batch_losses = self.checkpoint("batch_losses", [])
        self.batch_gradient_norms = self.checkpoint("batch_gradient_norms", [])
        self.epoch_losses = self.checkpoint("epoch_losses", [])
        self.step_size_history = self.checkpoint("step_size_history", [])
        self.gradient_norm_history = self.checkpoint("gradient_norm_history", [])
        self.predictions = self.checkpoint("predictions", {})

        # Initial trajstates should be set by now
        for statepoint in range(self.n_statepoints):
            self.predictions[statepoint] = {}

        self.early_stop = tt.EarlyStopping(
            self.params, convergence_criterion)

    @property
    def params(self):
        """Current energy parameters."""
        return self.state.params

    @params.setter
    def params(self, loaded_params):
        """Replaces the current energy parameters."""
        self.state = self.state.replace(params=loaded_params)

    @property
    def n_statepoints(self):
        """Size of the leading axis of the stored trajectory positions."""
        return self.traj_states.trajectory.position.shape[0]

    def _min_effective_samples(self):
        """Effective sample size below which trajectories are recomputed."""
        return (self.traj_states.trajectory.position.shape[1]
                * self.reweight_ratio)

    def _gather_and_refresh_batch(self, batch):
        """Selects the data of a batch and recomputes stale trajectories.

        Prints the effective sample size of every statepoint in the batch. If
        any of them falls below the limit, the trajectories of the whole batch
        are propagated again and written back to ``self.traj_states``.

        Returns:
            Tuple ``(trajstates, statepoints, targets)`` restricted to the
            batch.
        """
        # Select the relevant trajstates and targets
        trajstates = util.tree_take(self.traj_states, batch, on_cpu=False)
        targets = util.tree_take(self.targets, batch, on_cpu=False)
        statepoints = util.tree_take(self.statepoints, batch, on_cpu=False)

        # Compute the effective sample sizes and print
        _, n_eff = self.weights(self.params, trajstates)
        min_n_eff = self._min_effective_samples()

        print(f"[DifftreParallel] Effective sample sizes (limit: {min_n_eff})")
        for b, eff in zip(batch, n_eff):
            info = "-> recompute" if eff < min_n_eff else ""
            print(f"\t[Statepoint {b}] Effective sample size: {eff:.2f} {info}")

        if onp.any(n_eff < min_n_eff):
            print("[DifftreParallel] Recomputing trajectories...")
            start = time.time()
            trajstates = self.propagate(self.params, trajstates, statepoints)
            print(f"[DifftreParallel] Recomputed trajectories in {(time.time() - start) / 60.:.2f} min")

            # Save the recomputed trajectories
            self.traj_states = util.tree_put(
                self.traj_states, batch, trajstates, on_cpu=False)

        return trajstates, statepoints, targets

    def _get_batch(self):
        """Returns the next batch of statepoints to be processed."""
        self.key, key = random.split(self.key)

        num_statepoints = self.traj_states.trajectory.position.shape[0]
        # Entries are set to zero once the statepoint has been drawn.
        mask = jnp.ones(num_statepoints)
        for _ in range(num_statepoints // self.batch_size):
            key, split = random.split(key)

            # If bucketing is no longer possible or disabled, return back
            # a random batch of statepoints. Otherwise, return a full batch of
            # statepoints that either need or don't need reweighting.
            if not self._bucket_recompute or 2 * self.batch_size > jnp.sum(mask):
                batches = random.choice(
                    split, num_statepoints, (self.batch_size,),
                    replace=False, p=mask
                )
            else:
                # Compute the effective sample size for twice as many samples
                candidates = random.choice(
                    split, num_statepoints, (2 * self.batch_size,),
                    replace=False, p=mask
                )
                trajstates = util.tree_take(self.traj_states, candidates,
                                            on_cpu=False)

                # Compute the effective sample sizes. This must happen inside
                # the loop since the parameters may have changed between two
                # yielded batches.
                _, n_eff = self.weights(self.params, trajstates)
                recompute = n_eff < self._min_effective_samples()

                # Select samples only from the largest class. At least one
                # of the conditions should be fulfilled:
                # a) At least BS samples must be recomputed
                # b) At least BS samples do not need a recomputation
                key, split = random.split(key)
                if jnp.sum(recompute) > self.batch_size:
                    select = jnp.float32(recompute)
                else:
                    select = jnp.float32(~recompute)

                # Select from the class at random. The not selected samples
                # should remain in the pool.
                batches = random.choice(
                    split, candidates, (self.batch_size,),
                    replace=False, p=select
                )

            # Mark the samples as drawn
            mask = mask.at[batches].set(0.0)
            yield batches

    def _update(self, batch):
        """Computes gradient averaged over the sim_batch by propagating
        respective state points. Additionally saves predictions and loss
        for postprocessing."""
        ## Determine if recompute is necessary #################################
        trajstates, statepoints, targets = self._gather_and_refresh_batch(batch)

        ## Compute the loss ####################################################
        print("[DifftreParallel] Computing loss...")
        start = time.time()
        (loss, state_point_predictions), grad = self.model(
            self.params, trajstates, statepoints, targets
        )
        batch_norm = util.tree_norm(grad)
        self.batch_gradient_norms.append(onp.asarray(batch_norm))
        print(f"[DifftreParallel] Computed loss {loss} in {(time.time() - start) / 60.:.2f} min")

        ## Optimize the step size ##############################################
        proposal = self._optimizer_step(grad)

        # Perform stepsize optimization
        start = time.time()
        alpha, residual = self._adaptive_step_size(
            self.params, grad, proposal, trajstates)
        print(
            f"[Step Size] Found optimal step size for {alpha} with residual "
            f"{residual} in {(time.time() - start):.1f} s", flush=True)

        self._step_optimizer(grad, alpha=alpha)

        ## Save the predictions for the respective batches #####################
        print("[DifftreParallel] Predictions:")
        for idx, b in enumerate(batch):
            self.predictions[int(b)][self._epoch] = {
                name: onp.asarray(val[idx])
                for name, val in state_point_predictions.items()
            }

            # Print scalar predictions
            print(f"\t[Statepoint {b}]")
            for name, value in state_point_predictions.items():
                if jnp.shape(value[idx]) == ():
                    target = ""
                    if name in targets:
                        target = f"(target: {targets[name]['target'][idx]})"
                    print(f"\t\t{name} = {value[idx]} {target}")

        # Save the loss and the step size
        self.batch_losses.append(onp.asarray(loss))
        self.step_size_history.append(onp.asarray(alpha))

    def predict(self, batch):
        """Predict for a batch of statepoints.

        Trajectories of the batch are recomputed first if the effective sample
        size of any statepoint is below the limit.

        Args:
            batch: Indices of the statepoints to predict.

        Returns:
            The predictions returned by the model for the selected statepoints.
        """
        ## Determine if recompute is necessary #################################
        trajstates, statepoints, targets = self._gather_and_refresh_batch(batch)

        ## Compute the predictions #############################################
        print("[DifftreParallel] Start predictions...")
        for idx, b in enumerate(batch):
            print(f"\t[Statepoint {b}]")
            for name, val in statepoints.items():
                if not jnp.isscalar(val[idx]):
                    continue
                print(f"\t\t{name} = {val[idx]}")

        (_, state_point_predictions), _ = self.model(
            self.params, trajstates, statepoints, targets
        )

        return state_point_predictions

    def _evaluate_convergence(self, *args, thresh=None, **kwargs):
        """Aggregates the batch statistics of the last epoch and checks for
        convergence via the early stopping criterion."""
        # An epoch consists of as many batches as fit into the statepoints.
        batches_per_epoch = self.n_statepoints // self.batch_size
        last_losses = jnp.array(self.batch_losses[-batches_per_epoch:])
        epoch_loss = jnp.mean(last_losses)
        duration = self.update_times[self._epoch]
        self.epoch_losses.append(epoch_loss)
        self.gradient_norm_history.append(
            onp.mean(self.batch_gradient_norms[-batches_per_epoch:])
        )

        print(
            f"\n[DiffTRe] Epoch {self._epoch}"
            f"\n\tEpoch loss = {epoch_loss:.5f}"
            f"\n\tGradient norm: {self.gradient_norm_history[-1]}"
            f"\n\tElapsed time = {duration:.3f} min")

        self._converged = self.early_stop.early_stopping(
            epoch_loss, thresh, self.params)

    @property
    def best_params(self):
        """Returns the best parameters according to the early stopping criterion."""
        return self.early_stop.best_params

    def move_to_device(self):
        """Transforms the trainer states to JAX arrays."""
        super().move_to_device()
        self.early_stop.move_to_device()
[docs]
class Difftre(tt.PropagationBase):
"""Trainer class for parametrizing potentials via the DiffTRe method.
The Differentiable Trajectory Reweighting (DiffTRe) method [#Thaler2021]_
is a method to compute the gradients of ensemble averages without
differentiating through the simulation. Therefore, the method can
efficiently train potential models on macroscopic observables.
The trainer initialization only sets the initial trainer state
as well as checkpointing and save-functionality. For training,
target state points with respective simulations need to be added
via :func:`Difftre.add_statepoint`.
Args:
init_params: Initial energy parameters
optimizer: Optimizer from optax
reweight_ratio: Ratio of reference samples required for n_eff to
surpass to allow re-use of previous reference trajectory state.
If trajectories should not be re-used, a value > 1 can be
specified.
sim_batch_size: Number of state-points to be processed as a single
batch. Gradients will be averaged over the batch before stepping the
optimizer.
energy_fn_template: Function that takes energy parameters and
initializes a new energy function. Here, the energy_fn_template
is only a reference that will be saved alongside the trainer.
Each state point requires its own due to the dependence on the
box size via the displacement function, which can vary between
state points.
convergence_criterion: Either 'max_loss' or 'ave_loss'.
If 'max_loss', stops if the maximum loss across all batches in
the epoch is smaller than convergence_thresh. 'ave_loss'
evaluates the average loss across the batch. For a single state
point, both are equivalent. A criterion based on the rolling
standard deviation 'std' might be implemented in the future.
checkpoint_path: Path of the folder to store checkpoints in.
Attributes:
weight_fn: Dictionary containing the reweighting functions for each
statepoint.
batch_losses: List of losses for each batch in each epoch.
epoch_losses: List of losses for each epoch.
step_size_history: List of step sizes for each batched update.
gradient_norm_history: List of gradient norms for each batched update.
predictions: Dictionary containing the predictions for each statepoint
at each epoch.
early_stop: Instance of EarlyStopping to check for convergence.
Examples:
.. code-block :: python
trainer = trainers.Difftre(init_params, optimizer)
# Add all statepoints
trainer.add_statepoint(energy_fn_template, simulator_template,
neighbor_fn, timings, statepoint_dict,
compute_fns, reference_state, targets)
...
# Optionally initialize the step size adaption
trainer.init_step_size_adaption(allowed_reduction=0.5)
trainer.train(num_updates)
References:
.. [#Thaler2021] Thaler, S.; Zavadlav, J. Learning Neural Network
Potentials from Experimental Data via Differentiable Trajectory
Reweighting. Nat Commun **2021**, 12 (1), 6884.
https://doi.org/10.1038/s41467-021-27241-4.
"""
[docs]
def __init__(self,
             init_params: Any,
             optimizer: GradientTransformationExtraArgs,
             reweight_ratio: ArrayLike = 1.0,
             adaptive_step_size_threshold: float = 1e-4,
             sim_batch_size: int = 1,
             energy_fn_template: EnergyFnTemplate = None,
             full_checkpoint: bool = False,
             convergence_criterion: str = "window_median",
             checkpoint_path: os.PathLike = "Checkpoints",
             log_dir: os.PathLike = None):
    """Creates the initial trainer state and the checkpointed histories.

    Statepoints are not set up here; the per-statepoint containers start
    empty.
    """
    opt_state = optimizer.init(init_params)
    initial_state = util.TrainerState(params=init_params,
                                      opt_state=opt_state)

    # Per-statepoint containers, empty until statepoints are added.
    self.targets = {}
    self.weight_fn = {}
    self.state_dicts = {}

    self._adaptive_step_size_threshold = adaptive_step_size_threshold
    self._recompute = False

    super().__init__(
        init_trainer_state=initial_state,
        optimizer=optimizer,
        checkpoint_path=checkpoint_path,
        reweight_ratio=reweight_ratio,
        sim_batch_size=sim_batch_size,
        full_checkpoint=full_checkpoint,
        energy_fn_template=energy_fn_template,
        log_dir=log_dir,
    )

    # Register the training histories with the checkpointing mechanism.
    tracked = (
        ("batch_losses", []),
        ("epoch_losses", []),
        ("step_size_history", []),
        ("predictions", {}),
    )
    for attribute, initial_value in tracked:
        setattr(self, attribute, self.checkpoint(attribute, initial_value))

    self.early_stop = tt.EarlyStopping(self.params, convergence_criterion)
[docs]
def add_statepoint(self,
                   energy_fn_template: EnergyFnTemplate,
                   simulator_template: Callable,
                   neighbor_fn: NeighborFn,
                   timings: sampling.TimingClass,
                   state_kwargs: Dict[str, ArrayLike],
                   quantities: Dict[str, Dict],
                   reference_state,
                   targets: Dict[str, Any] = None,
                   observables: Dict[str, TrajFn] = None,
                   target_loss_fns: Dict[str, Callable] = None,
                   loss_fn = None,
                   vmap_batch: int = 10,
                   initialize_traj: bool = True,
                   set_key: str = None,
                   resample_simstates: bool = False,
                   allowed_reduction: ArrayLike = None,
                   adaption_kwargs: Dict = None
                   ):
    """
    Adds a state point to the pool of simulations with respective targets.

    Each statepoints initializes a new gradient and propagation function via
    :func:`chemtrain.learn.difftre.init_difftre_gradient_and_propagation`.

    Args:
        energy_fn_template: Function that takes energy parameters and
            initializes a new energy function.
        simulator_template: Function that takes an energy function and
            returns a simulator function.
        neighbor_fn: Neighbor function
        timings: Instance of TimingClass containing information
            about the trajectory length and which states to retain
        state_kwargs: Properties defining the thermodynamic state. Must
            at least contain the temperature 'kT'. For a non-exhaustive
            list, see :class:`chemtrain.ensemble.templates.StatePoint`.
        quantities: Dict containing for each observable specified by the
            key a corresponding function to compute it for each snapshot
            using :func:`ensemble.sampling.quantity_traj`.
        reference_state: Tuple of initial simulation state and neighbor list
        targets: Dict containing the same keys as quantities and containing
            another dict providing 'gamma' and 'target' for each observable.
            Targets are only necessary when using the 'independent_loss_fn'.
            ``None`` is treated like an empty dict.
        observables: Optional dictionary providing the observable functions
            for the targets. This is only necessary when the observable
            functions are not already contained in the targets dict.
        target_loss_fns: Optional dictionary providing the loss functions
            for the individual targets. This is only necessary when the
            loss functions are not already contained in the targets dict
            or should be different from the MSE loss.
        loss_fn: Custom loss function taking the trajectory of quantities
            and weights and returning the loss and predictions;
            By default, initializes an independent MSE loss, which computes
            reweighting averages from snapshot-based observables.
            In many applications, the default loss function will be
            sufficient. For a description, see
            :func:`chemtrain.learn.difftre.init_default_loss_fn`.
        vmap_batch: Batch size of vmapping of per-snapshot energy for weight
            computation.
        initialize_traj: True, if an initial trajectory should be generated.
            Should only be set to False if a checkpoint is loaded before
            starting any training.
        set_key: Specify a key in order to restart from same statepoint.
            By default, uses the index of the sequence statepoints are
            added, i.e. self.trajectory_states[0] for the first added
            statepoint.
            Can be used for changing the timings of the simulation during
            training.
        resample_simstates: Resample the sim states from all trajectories
            instead of simulating independent chains.
        allowed_reduction: Allowed reduction of the effective sample size
            for the given statepoint.
        adaption_kwargs: Additional keyword arguments for the step size
            line search. For a description, see
            :func:`chemtrain.learn.difftre.init_step_size_adaption`.
    """
    # ``targets`` is optional (e.g. together with a custom loss function),
    # but it is iterated below. Normalize the default so that omitting it
    # does not fail with an AttributeError on ``None``.
    if targets is None:
        targets = {}

    # init simulation, reweighting functions and initial trajectory
    (key, *reweight_fns) = self._init_statepoint(
        reference_state,
        energy_fn_template,
        simulator_template,
        neighbor_fn,
        timings,
        state_kwargs,
        set_key,
        vmap_batch,
        initialize_traj,
        safe_propagation=False,
        entropy_approximation=False,
        resample_simstates=resample_simstates
    )

    # For backwards compatibility and ease of use for a single statepoint
    if observables is None:
        observables = {
            name: spec["traj_fn"] for name, spec in targets.items()
        }
    if target_loss_fns is None:
        target_loss_fns = {
            name: spec["loss_fn"] for name, spec in targets.items()
            if "loss_fn" in spec
        }

    # Enables a greater flexibility by separating data from functions:
    # only 'gamma' and 'target' of entries with an actual target are kept.
    targets = {
        name: {k: v for k, v in spec.items() if k in ["gamma", "target"]}
        for name, spec in targets.items() if spec.get("target") is not None
    }

    # build loss function for current state point
    if loss_fn is None:
        loss_fn = difftre.init_default_loss_fn(observables, target_loss_fns)
    else:
        print("Using custom loss function. Ignoring 'target' dict.")

    difftre_grad_and_propagation = difftre.init_difftre_gradient_and_propagation(
        reweight_fns, loss_fn, quantities, energy_fn_template
    )

    self.grad_fns[key] = difftre_grad_and_propagation
    self.predictions[key] = {}  # init saving predictions for this point
    self.weight_fn[key] = jax.jit(reweight_fns[0])

    self.state_dicts[key] = state_kwargs
    self.targets[key] = targets

    if allowed_reduction is not None:
        if adaption_kwargs is None:
            adaption_kwargs = {}

        self._adaptive_step_size[key] = difftre.init_step_size_adaption(
            self.weight_fn[key], allowed_reduction, **adaption_kwargs
        )

    # Reset loss measures if a new state point is added since loss values
    # are not necessarily comparable
    self.early_stop.reset_convergence_losses()
[docs]
def predict(self, *, key: int):
    """Get predictions for a specific statepoint.

    This method predicts the target quantities for a specific
    statepoint. If necessary, the statepoint performs a trajectory
    regeneration. The trajectory state returned by the gradient function
    is stored back into ``self.trajectory_states[key]``.

    Args:
        key: The key of the statepoint to predict.

    Returns:
        Returns a dictionary containing the predicted observables
        given the current parameter values.
    """
    traj_state = self.trajectory_states[key]
    try:
        traj_state.overflow
    except AttributeError:
        # An entry without an ``overflow`` attribute is treated as a
        # deferred initializer and is called to create the actual state.
        # Only AttributeError is caught, so interrupts and unrelated
        # errors are no longer swallowed.
        start = time.time()
        traj_state = traj_state()
        compute_time = (time.time() - start) / 60.
        print(
            f"Delayed initialization of trajectory state in {compute_time :.2f} min.")

    grad_fn = self.grad_fns[key]
    # Only the first (updated trajectory state) and the last (predictions)
    # outputs of the gradient function are needed here.
    (new_traj_state, *_, state_point_predictions) = grad_fn(
        self.params, traj_state, self.state_dicts[key], self.targets[key],
        recompute=self._recompute
    )
    self.trajectory_states[key] = new_traj_state
    return state_point_predictions
def _update(self, batch):
    """Computes gradient averaged over the sim_batch by propagating
    respective state points. Additionally saves predictions and loss
    for postprocessing.

    For every statepoint key in ``batch`` the registered gradient function
    is evaluated with the current parameters. The returned trajectory
    state and predictions are stored, and losses and gradients are
    accumulated. The accumulated gradient is divided by ``len(batch)``,
    an optimizer proposal is computed, the step size is limited to the
    smallest value proposed by any registered adaptive step size function,
    and the optimizer is stepped.

    Args:
        batch: Sized collection of statepoint keys to process.
    """
    # TODO parallelization? Maybe lift batch requirement and only
    #  sync sporadically?
    # https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
    # https://github.com/mpi4jax/mpi4jax
    # TODO split gradient and loss computation from stepping optimizer for
    #  building hybrid trainers?

    # TODO is there good way to reuse this function in BaseClass?

    # Note: in principle, we could move all the use of instance attributes
    # into difftre_grad_and_propagation, which would increase re-usability
    # with relative_entropy. However, this would probably stop all
    # parallelization efforts
    losses = 0.0
    grads = None
    for sim_key in batch:
        traj_state = self.trajectory_states[sim_key]
        try:
            traj_state.overflow
        except AttributeError:
            # An entry without an ``overflow`` attribute is treated as a
            # deferred initializer and is called to create the state.
            # Only AttributeError is caught so that interrupts and
            # unrelated errors are not swallowed.
            start = time.time()
            traj_state = traj_state()
            compute_time = (time.time() - start) / 60.
            print(f"Delayed initialization of trajectory state in {compute_time :.2f} min.")

        grad_fn = self.grad_fns[sim_key]
        (new_traj_state, loss_val, curr_grad,
         state_point_predictions) = grad_fn(
            self.params, traj_state,
            self.state_dicts[sim_key], self.targets[sim_key],
            recompute=self._recompute
        )

        self.trajectory_states[sim_key] = new_traj_state
        # Predictions are converted to numpy arrays before being stored
        # per statepoint and epoch.
        self.predictions[sim_key][self._epoch] = tree_util.tree_map(
            onp.asarray, state_point_predictions)
        losses += loss_val
        if grads is None:
            grads = curr_grad
        else:
            grads = util.tree_sum(grads, curr_grad)

        # Print scalar predictions and statepoint measurements
        self._print_measured_statepoint(sim_key=sim_key)
        last_predictions = self.predictions[sim_key][self._epoch]
        for quantity, value in last_predictions.items():
            if value.ndim == 0:
                if quantity in self.targets[sim_key]:
                    target = f"({self.targets[sim_key][quantity]['target']})"
                else:
                    target = ""
                print(f"\tPredicted {quantity}: {value} {target}")

        if jnp.isnan(loss_val):
            warnings.warn(f"Loss of state point {sim_key} in epoch "
                          f"{self._epoch} is NaN. This was likely caused by"
                          f" divergence of the optimization or a bad model "
                          f"setup causing a NaN trajectory.")
            self._diverged = True  # ends training
            break

    self.batch_losses.append(onp.asarray(losses / len(batch)))
    batch_grad = tree_util.tree_map(lambda x: x / len(batch), grads)

    step_size = 1.0
    # NOTE: ``recompute`` is tracked but currently not written back to
    # ``self._recompute`` (see the disabled assignment below).
    recompute = False
    proposal = self._optimizer_step(batch_grad)
    for sim_key in batch:
        if sim_key not in self._adaptive_step_size:
            continue

        alpha, residual = self._adaptive_step_size[sim_key](
            self.params, batch_grad, proposal, self.trajectory_states[sim_key]
        )
        recompute |= alpha < self._adaptive_step_size_threshold

        print(f"[Step Size] Found optimal step size for {alpha} for statepoint {sim_key} with residual "
              f"{residual}", flush=True)

        # The most restrictive statepoint determines the step size.
        if alpha < step_size:
            step_size = alpha

    # self._recompute = recompute

    self._step_optimizer(batch_grad, alpha=step_size)
    batch_norm = util.tree_norm(batch_grad)
    self.gradient_norm_history.append(onp.asarray(batch_norm))
    self.step_size_history.append(onp.asarray(step_size))
def _evaluate_convergence(self, *args, thresh=None, **kwargs):
# sim_batch_size = -1 means all statepoints are processed in one batch.
if self.sim_batch_size < 0:
batches_per_epoch = 1
else:
batches_per_epoch = self.n_statepoints // self.sim_batch_size
last_losses = jnp.array(self.batch_losses[-batches_per_epoch:])
epoch_loss = jnp.mean(last_losses)
duration = self.update_times[self._epoch]
self.epoch_losses.append(epoch_loss)
print(
f"\n[DiffTRe] Epoch {self._epoch}"
f"\n\tEpoch loss = {epoch_loss:.5f}"
f"\n\tGradient norm: {self.gradient_norm_history[-1]}"
f"\n\tElapsed time = {duration:.3f} min")
self._converged = self.early_stop.early_stopping(
epoch_loss, thresh, self.params)
@property
def best_params(self):
    """Parameters the early stopping criterion currently regards as best."""
    stopper = self.early_stop
    return stopper.best_params
[docs]
def move_to_device(self):
    """Transforms the trainer states to JAX arrays.

    Delegates to the base class first and then to the early stopping
    helper stored in ``self.early_stop``.
    """
    super().move_to_device()
    stopper = self.early_stop
    stopper.move_to_device()
[docs]
class RelativeEntropy(tt.PropagationBase):
    """Trainer for relative entropy minimization.

    The Relative Entropy Minimization procedure coarse-grains potential
    models by minimizing the relative entropy between the atomistic reference
    and coarse-grained target canonical distributions [#Shell2008]_
    [#Thaler2022]_.

    The relative entropy algorithm currently assumes an NVT ensemble.

    Args:
        init_params: Initial energy parameters.
        optimizer: Optimizer from optax.
        reweight_ratio: Ratio of reference samples required for n_eff to
            surpass to allow re-use of previous reference trajectory state.
            If trajectories should not be re-used, a value > 1 can be specified.
        sim_batch_size: Number of state-points to be processed as a single
            batch. Gradients will be averaged over the batch before stepping the
            optimizer.
        energy_fn_template: Function that takes energy parameters and
            initializes a new energy function. Here, the ``energy_fn_template``
            is only a reference that will be saved alongside the trainer.
            Each state point requires its own due to the dependence on the box
            size via the displacement function, which can vary between state points.
        convergence_criterion: Either ``'max_loss'`` or ``'ave_loss'``.
            If ``'max_loss'``, stops if the gradient norm across all batches in
            the epoch is smaller than convergence_thresh.
            ``'ave_loss'`` evaluates the average gradient norm across the batch.
            For a single state point, both are equivalent.
            NOTE(review): the default is ``'window_median'``, which is not
            described here; the value is passed to ``tt.EarlyStopping`` —
            confirm the accepted criteria there.
        checkpoint_path: Path to the folder to store checkpoints in.
        full_checkpoint: Save the whole trainer instead of only the inference
            data.

    Attributes:
        data_states: Dictionary containing the dataloader states for each
            state point.
        delta_re: Dictionary containing the improvement of the relative entropy
            with respect to the initial potential.
        step_size_history: List of step size scales for each batched update.
        gradient_norm_history: List of gradient norms for each batched update.
        weight_fn: Dictionary containing the reweighting functions for each
            statepoint.
        early_stop: Instance of EarlyStopping to check for convergence.

    References:
        .. [#Shell2008] Shell, M. S. The Relative Entropy Is Fundamental to
           Multiscale and Inverse Thermodynamic Problems. J. Chem. Phys. 2008,
           129 (14), 144108. https://doi.org/10.1063/1.2992060.
        .. [#Thaler2022] Thaler, S.; Stupp, M.; Zavadlav, J. Deep Coarse-Grained
           Potentials via Relative Entropy Minimization. The Journal of Chemical
           Physics 2022, 157 (24), 244103. https://doi.org/10.1063/5.0124538.

    """

    def __init__(self,
                 init_params,
                 optimizer,
                 reweight_ratio: float = 0.9,
                 sim_batch_size: int = 1,
                 energy_fn_template: EnergyFnTemplate = None,
                 convergence_criterion: str = "window_median",
                 checkpoint_path: os.PathLike = "Checkpoints",
                 full_checkpoint: bool = False):
        init_trainer_state = util.TrainerState(
            params=init_params, opt_state=optimizer.init(init_params))
        super().__init__(init_trainer_state, optimizer, checkpoint_path,
                         reweight_ratio, sim_batch_size, energy_fn_template,
                         full_checkpoint)

        # in addition to the standard trajectory state, we also need to keep
        # track of dataloader states for reference snapshots
        self.data_states = {}

        self.delta_re = self.checkpoint("delta_re", {})
        self.step_size_history = self.checkpoint("step_size_history", [])
        self.gradient_norm_history = self.checkpoint("gradient_norm_history", [])

        self.early_stop = tt.EarlyStopping(self.params, convergence_criterion)

    def _set_dataset(self, key, reference_data, reference_batch_size,
                     batch_cache=1):
        """Set dataset and loader corresponding to current state point.

        Stores the initial batch state in ``self.data_states[key]`` and
        returns the function used to draw reference batches.
        """
        reference_loader = numpy_loader.NumpyDataLoader(
            R=reference_data, copy=False)
        init_ref_batch, get_ref_batch, _ = data_loaders.init_batch_functions(
            data_loader=reference_loader, mb_size=reference_batch_size,
            cache_size=batch_cache
        )
        init_reference_batch_state = init_ref_batch(shuffle=True)
        self.data_states[key] = init_reference_batch_state
        return get_ref_batch

    def add_statepoint(self,
                       reference_data: ArrayLike,
                       energy_fn_template: EnergyFnTemplate,
                       simulator_template: Callable,
                       neighbor_fn: NeighborFn,
                       timings: sampling.TimingClass,
                       state_kwargs: Dict[str, ArrayLike],
                       reference_state,
                       reference_batch_size: int = None,
                       batch_cache: int = 1,
                       initialize_traj: bool = True,
                       set_key: str = None,
                       vmap_batch: int = 10,
                       resample_simstates: bool = False,
                       allowed_reduction: float = None,
                       adaption_kwargs: Dict = None):
        """
        Adds a state point to the pool of simulations.

        The gradient of the relative entropy is computed via the gradient
        function initialized by
        :func:`chemtrain.learn.difftre.init_rel_entropy_gradient_and_propagation`.

        As each reference dataset / trajectory corresponds to a single
        state point, we initialize the dataloader together with the
        simulation.

        Currently only supports NVT simulations.

        Args:
            reference_data: De-correlated reference trajectory
            energy_fn_template: Function that takes energy parameters and
                initializes a new energy function.
            simulator_template: Function that takes an energy function and
                returns a simulator function.
            neighbor_fn: Neighbor function
            timings: Instance of TimingClass containing information
                about the trajectory length and which states to retain
            state_kwargs: Properties defining the thermodynamic state. Must
                at least contain the temperature 'kT'.
            reference_state: Initial simulation state and neighbor list.
                ``reference_state.sim_state.position`` is read to infer the
                number of parallel trajectories when no
                ``reference_batch_size`` is given.
            reference_batch_size: Batch size of dataloader for reference
                trajectory. If None, will use the same number of snapshots
                as generated by the CG simulation.
            batch_cache: Number of reference batches to cache in order to
                minimize host-device communication. Make sure the cached data
                size does not exceed the full dataset size.
            initialize_traj: True, if an initial trajectory should be generated.
                Should only be set to False if a checkpoint is loaded before
                starting any training.
            set_key: Specify a key in order to restart from same statepoint.
                By default, uses the index of the sequence statepoints are
                added, i.e. ``self.trajectory_states[0]`` for the first added
                statepoint. Can be used for changing the timings of the
                simulation during training.
            vmap_batch: Batch size of vmapping of per-snapshot energy and
                gradient calculation.
            resample_simstates: Forwarded unchanged to
                ``self._init_statepoint``.
            allowed_reduction: Allowed reduction of the effective sample size
                for the given statepoint.
            adaption_kwargs: Additional keyword arguments for the step size
                line search. For a description, see
                :func:`chemtrain.learn.difftre.init_step_size_adaption`.

        """
        if reference_batch_size is None:
            print("No reference batch size provided. Using number of generated "
                  "CG snapshots by default.")
            states_per_traj = jnp.size(timings.t_production_start)
            # More than two position dimensions are interpreted as a
            # leading axis over multiple trajectories.
            if reference_state.sim_state.position.ndim > 2:
                n_trajectories = reference_state.sim_state.position.shape[0]
                reference_batch_size = n_trajectories * states_per_traj
            else:
                reference_batch_size = states_per_traj

        (key, *reweight_fns) = self._init_statepoint(reference_state,
                                                     energy_fn_template,
                                                     simulator_template,
                                                     neighbor_fn,
                                                     timings,
                                                     state_kwargs,
                                                     set_key,
                                                     vmap_batch,
                                                     initialize_traj,
                                                     entropy_approximation=False,
                                                     resample_simstates=resample_simstates,
                                                     safe_propagation=False)

        reference_dataloader = self._set_dataset(key,
                                                 reference_data,
                                                 reference_batch_size,
                                                 batch_cache)

        propagation_and_grad = difftre.init_rel_entropy_gradient_and_propagation(
            reference_dataloader, reweight_fns, energy_fn_template,
            state_kwargs["kT"], vmap_batch
        )

        self.grad_fns[key] = propagation_and_grad
        self.delta_re[key] = []
        self.weight_fn[key] = jax.jit(reweight_fns[0])

        if allowed_reduction is not None:
            if adaption_kwargs is None:
                adaption_kwargs = {}

            # Uses the jitted reweighting function stored just above; the
            # attribute is ``weight_fn`` (there is no ``weights_fn``).
            self._adaptive_step_size[key] = difftre.init_step_size_adaption(
                self.weight_fn[key], allowed_reduction, **adaption_kwargs
            )

    def _update(self, batch):
        """Updates the potential using the gradient from relative entropy.

        Args:
            batch: Collection of statepoint keys to process in this update.
        """
        grads = []
        for sim_key in batch:
            grad_fn = self.grad_fns[sim_key]

            (self.trajectory_states[sim_key], delta_re, curr_grad,
             self.data_states[sim_key]) = grad_fn(
                self.params,
                self.trajectory_states[sim_key],
                self.data_states[sim_key])

            grads.append(curr_grad)
            self.delta_re[sim_key].append(delta_re)

        batch_grad = util.tree_mean(grads)

        step_size = 1.0
        proposal = self._optimizer_step(batch_grad)
        for sim_key in batch:
            if sim_key not in self._adaptive_step_size:
                continue

            alpha, residual = self._adaptive_step_size[sim_key](
                self.params, batch_grad, proposal,
                self.trajectory_states[sim_key]
            )
            # The most restrictive statepoint determines the step size.
            if alpha < step_size:
                step_size = alpha

            # Logged inside the loop so that ``residual`` is always bound,
            # also when no statepoint uses an adaptive step size.
            print(f"[Step Size] Found optimal step size {step_size} with residual "
                  f"{residual}", flush=True)

        self._step_optimizer(batch_grad, alpha=step_size)
        batch_norm = util.tree_norm(batch_grad)
        self.gradient_norm_history.append(onp.asarray(batch_norm))
        self.step_size_history.append(onp.asarray(step_size))

    def _evaluate_convergence(self, *args, thresh=None, **kwargs):
        """Log the epoch summary and query the early stopping criterion.

        The latest gradient norm is passed to the early stopping criterion
        and the result is stored in ``self._converged``.
        """
        curr_grad_norm = self.gradient_norm_history[-1]

        # Mean loss from last simbatch
        mean_delta_re = onp.mean(
            [delta_re[-1] for delta_re in self.delta_re.values()]
        )
        duration = self.update_times[self._epoch]

        print(
            f"\n[RE] Epoch {self._epoch}"
            f"\n\tMean Delta RE loss = {mean_delta_re:.5f}"
            f"\n\tGradient norm: {curr_grad_norm}"
            f"\n\tElapsed time = {duration:.3f} min")

        self._print_measured_statepoint()

        self._converged = self.early_stop.early_stopping(
            curr_grad_norm, thresh, save_best_params=False)
[docs]
class SGMCForceMatching(tt.ProbabilisticFMTrainerTemplate):
    """Force-matching trainer based on stochastic gradient Markov-chain
    Monte Carlo (SG-MCMC) sampling.

    Args:
        sgmc_solver: Callable that runs the sampler. It is invoked with the
            unpacked initial samples and an ``iterations`` keyword.
        init_samples: A list, possibly of size 1, of sets of initial MCMC
            samples, where each spawns a dedicated MCMC chain. Every entry
            must provide a ``"params"`` item.
        val_dataloader: Currently unused.
        energy_fn_template: Passed on to the base class.
    """

    def __init__(self, sgmc_solver, init_samples, val_dataloader=None,
                 energy_fn_template=None):
        # TODO: Where does alias.py get checkpoint_path info?
        super().__init__(None, energy_fn_template)

        initial_params = []
        for sample in init_samples:
            initial_params.append(sample["params"])
        self._params = initial_params

        self.sgmcmc_run_fn = sgmc_solver
        self.init_samples = init_samples

        # TODO use val dataloader to compute posterior predictive p value or
        #  other convergence metric. In ProbabilisticFMTrainerTemplate??

        # TODO also use test_set?

    def train(self, iterations):
        """Run the sampler and keep its output in ``self.results``."""
        run_sampler = self.sgmcmc_run_fn
        self.results = run_sampler(*self.init_samples, iterations=iterations)

    @property
    def params(self):
        """Sampled parameters, merged over all chains."""
        chains = self.results
        if len(chains) != 1:
            per_chain = [
                chain["samples"]["variables"]["params"] for chain in chains
            ]
            return util.tree_combine(util.tree_stack(per_chain))
        # single chain
        return chains[0]["samples"]["variables"]["params"]

    @params.setter
    def params(self, loaded_params):
        message = ("Setting params seems not meaningful in"
                   " the case of SG-MCMC samplers.")
        raise NotImplementedError(message)

    @property
    def list_of_params(self):
        """The sampled parameters as a list."""
        sampled = self.params
        return util.tree_unstack(sampled)

    def save_trainer(self, save_path):
        """Saving is not supported for this trainer; always raises."""
        message = ("Saving the trainer currently does not work"
                   " for SGMCMC.")
        raise NotImplementedError(message)
[docs]
class EnsembleOfModels(tt.ProbabilisticFMTrainerTemplate):
    """Ensemble of independently trained models, e.g. for uncertainty
    quantification, where each member starts from its own initial
    parameter set.

    Example:

        .. code-block:: python

            trainer_list = []
            for i in range(4):
                trainer_list.append(trainers.ForceMatching(...))

            trainer_ensemble = trainers.EnsembleOfModels(trainer_list)
            trainer_ensemble.train(*args, **kwargs)
            trained_params = trainer_ensemble.list_of_params

    """

    def __init__(self, trainers, ref_energy_fn_template=None):
        super().__init__(None, ref_energy_fn_template)
        self.trainers = trainers

    def train(self, *args, **kwargs):
        """Train every member in turn, forwarding all arguments."""
        for index, member in enumerate(self.trainers):
            print(f"---------Starting trainer {index}-----------")
            member.train(*args, **kwargs)
        print("Finished training all models.")

    @property
    def params(self):
        """Parameters of all members, stacked into a single tree."""
        member_params = self.list_of_params
        return util.tree_stack(member_params)

    @params.setter
    def params(self, loaded_params):
        # The position in ``loaded_params`` selects the member to update.
        for index, member_params in enumerate(loaded_params):
            self.trainers[index].params = member_params

    @property
    def list_of_params(self):
        """Per-member parameters; ``best_params`` is used where available."""
        return [
            member.best_params if hasattr(member, "best_params")
            else member.params
            for member in self.trainers
        ]
[docs]
class InterleaveTrainers(tt.TrainerInterface):
    """Interleaves updates to train models using multiple algorithms.

    This special trainer allows training models simultaneously with
    different algorithms.

    Example:

        .. code-block::

            # First initialize the base-trainers, e.g.
            fm_trainer = trainers.ForceMatching(...)

            difftre_trainer = trainers.Difftre(...)
            difftre_trainer.add_statepoint(...)

            # Now combine the trainers. The trainers are executed in the
            # order in which they are added
            trainer = trainers.InterleaveTrainers(
                checkpoint_base_path='checkpoint_folder',
                reference_energy_fn_template=energy_fn_template,
                full_checkpoint=False)

            # Force matching should run 10 epochs before difftre runs 2 epochs
            trainer.add_trainer(fm_trainer, num_updates=10, name='Force Matching')
            trainer.add_trainer(difftre_trainer, num_updates=2, name='DiffTRe')

            trainer.train(100, checkpoint_frequency=10)

    Args:
        sequential: Start the next trainer directly with the optimized
            parameters of the previous trainer. In the non-sequential case,
            the trainers start their epoch on the same parameter set and
            the final update is a weighted sum of both updates.
        checkpoint_base_path: Location to store checkpoints of the trainers.
        reference_energy_fn_template: Energy function template to optionally
            return an energy function with current parameters.
        full_checkpoint: Store the complete trainer or important properties
            only.

    """

    def __init__(self,
                 sequential = True,
                 checkpoint_base_path = "checkpoints",
                 reference_energy_fn_template=None,
                 full_checkpoint=False):
        super().__init__(checkpoint_base_path, reference_energy_fn_template,
                         full_checkpoint)

        self.sequential = sequential

        # Each entry is a dict describing one trainer (see ``add_trainer``).
        self._trainers = []
        self._epoch = 0

    def add_trainer(self, trainer, num_updates: int = 1, name: str = "trainer",
                    weight: float = 1.0, **trainer_kwargs):
        """Adds a trainer to the combined training.

        The trainers are executed in the order they are added to this instance.
        It is possible to specify how many epochs each trainer should train
        before the next trainer starts again.

        Args:
            trainer: Trainer to add to the chain.
            num_updates: Consecutive updates of the trainer in one epoch of the
                interleaved trainer.
            name: Display name of the trainer.
            weight: Weight for the interpolated update of the parameters.
            trainer_kwargs: Additional arguments for the training method
                of the trainer.

        """
        self._trainers.append(
            {"trainer": trainer, "num_updates": num_updates, "name": name,
             "kwargs": trainer_kwargs, "weight": weight}
        )

    @property
    def params(self):
        """Parameters of the trainer that was added last."""
        return self._trainers[-1]["trainer"].params

    @params.setter
    def params(self, params):
        # Setting the parameters synchronizes all trainers.
        for trainer in self._trainers:
            trainer["trainer"].params = params

    @property
    def _all_params(self):
        """Parameters of every trainer, in the order they were added."""
        return [t["trainer"].params for t in self._trainers]

    @property
    def _all_weights(self):
        """Interpolation weights of every trainer, in the order added."""
        return [t["weight"] for t in self._trainers]

    def _init_interpolated_update(self):
        """Build the function computing the weighted mean of parameter sets.

        Returns:
            A jitted function that takes a list with one parameter tree per
            trainer (all with identical structure) and returns a single
            tree whose leaves are the sums of the corresponding leaves
            weighted by the normalized trainer weights.
        """
        weights = jnp.asarray(self._all_weights)
        weights /= jnp.sum(weights)

        @jit
        def update(parameters):
            structure = tree_util.tree_structure(parameters[0])
            leaves = [tree_util.tree_leaves(t) for t in parameters]
            # Stack matching leaves along a new leading trainer axis and
            # contract that axis with the weights. This keeps the original
            # leaf shape for scalars and arrays of any rank.
            summed = [
                jnp.tensordot(weights, jnp.stack(group), axes=1)
                for group in zip(*leaves)
            ]
            return tree_util.tree_unflatten(structure, summed)

        return update

    def train(self, epochs, checkpoint_frequency=None):
        """Train model with combined algorithms.

        Args:
            epochs: Number of epochs, where one epoch can contain multiple
                epochs for each added trainer.
            checkpoint_frequency: Save a checkpoint in the given frequency.

        """
        interpolated_update = self._init_interpolated_update()

        self._converged = False
        start_epoch = self._epoch
        end_epoch = start_epoch + epochs
        n_trainers = len(self._trainers)
        for e in range(start_epoch, end_epoch):
            start = time.time()
            for t, trainer in enumerate(self._trainers):
                print(f"---------Starting trainer {trainer['name']} for {trainer['num_updates']} updates -----------")
                trainer["trainer"].train(trainer["num_updates"], **trainer["kwargs"])

                if self.sequential:
                    # Pass updated parameters to the next trainer
                    next_idx = (t + 1) % n_trainers
                    self._trainers[next_idx]["trainer"].params = trainer["trainer"].params

            if not self.sequential:
                # Update the parameters of all trainers with a weighted sum of
                # the individual parameters. The update needs the parameters
                # of every trainer, not only those of the last one.
                self.params = interpolated_update(self._all_params)

            duration = (time.time() - start) / 60.
            self._epoch += 1

            print(f"Finished epoch {e} for all trainers in {duration : .2f} minutes.")

            self._dump_checkpoint_occasionally(frequency=checkpoint_frequency)

    def move_to_device(self):
        """Call ``move_to_device`` on every added trainer."""
        for trainer in self._trainers:
            trainer["trainer"].move_to_device()

    def save_trainer(self, save_path, format=".pkl"):
        """Collect the saved data of all trainers and optionally pickle it.

        Args:
            save_path: File to write when ``format`` is ``".pkl"``.
            format: ``".pkl"`` writes the collected data to ``save_path``;
                ``"none"`` returns the collected data instead. Any other
                value collects the data but neither writes nor returns it.

        Returns:
            The collected data if ``format == "none"``, otherwise None.
        """
        data = {}
        for t, trainer in enumerate(self._trainers):
            number = str(t + 1).rjust(3, "0")
            key = "trainer_{0}_{1}".format(trainer["name"], number)
            # Each trainer returns its data instead of writing a file.
            data[key] = trainer["trainer"].save_trainer(None, format="none")

        if format == ".pkl":
            with open(save_path, "wb") as pickle_file:
                pickle.dump(data, pickle_file)
        elif format == "none":
            return data