# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract templates for trainers, defining common functionality and
requirements."""
import abc
import copy
import dataclasses
import functools
import logging
import pathlib
import sys
import time
import warnings
from abc import abstractmethod
from os import PathLike
import inspect
from typing import Callable, Dict, Any
import cloudpickle as pickle
import jax
import numpy as onp
from jax import (
numpy as jnp, random, device_count, jit, device_get,
tree_util
)
from jax.tree_util import tree_map
from jax_sgmc import data
from chemtrain import util
from chemtrain.data import data_loaders
from chemtrain.learn import max_likelihood, difftre
from jax_md_mod.model import dropout
from chemtrain.ensemble.reweighting import init_pot_reweight_propagation_fns
from chemtrain.ensemble import sampling
from chemtrain.typing import EnergyFnTemplate
from chemtrain.util import format_not_recognized_error
class CaptureStdout:
"""Capture stdout and writes to file.
This context manager writes messages to stdout and a file.
Args:
file: Path to file where to write the stdout.
"""
def __init__(self, file=None):
self.files = []
if file is not None:
self.files = [file]
def write(self, message):
for f in self.out:
f.write(message)
f.flush()
def flush(self):
for f in self.out:
f.flush()
def __enter__(self):
try:
self.files = [open(file_path, "w") for file_path in self.files]
self.out = (sys.stdout, *self.files)
sys.stdout = self
except Exception as e:
self.__exit__(None, None, None)
raise e
return self
def __exit__(self, exc_type, exc_value, traceback):
sys.stdout = self.out[0]
for f in self.files:
try:
f.close()
except Exception as e:
print(f"Error closing file: {e}")
@dataclasses.dataclass
class CheckpointAttr:
name: str
object: Any
class TrainerInterface(metaclass=abc.ABCMeta):
"""Abstract class defining the user interface of trainers as well as
checkpointing functionality.
"""
# TODO write protocol classes for better documentation of initialized
# functions
def __init__(self,
checkpoint_path,
reference_energy_fn_template=None,
full_checkpoint=True):
"""A reference energy_fn_template can be provided, but is not mandatory
due to the dependence of the template on the box via the displacement
function.
"""
checkpoint_path = pathlib.Path(checkpoint_path)
checkpoint_path.mkdir(exist_ok=True, parents=True)
self._statistics: Dict[str, str] = {}
self._full_checkpoint = full_checkpoint
self.checkpoint_path = checkpoint_path
self._epoch = self.checkpoint("epoch", 0)
self.reference_energy_fn_template = reference_energy_fn_template
@property
def energy_fn(self):
"""Returns the energy function for the current parameters."""
if self.reference_energy_fn_template is None:
raise ValueError("Cannot construct energy_fn as no reference "
"energy_fn_template was provided during "
"initialization.")
return self.reference_energy_fn_template(self.params)
def _dump_checkpoint_occasionally(self, *args, checkpoint_frequency=None, **kwargs):
"""Dumps a checkpoint during training, from which training can
be resumed.
"""
assert self.checkpoint_path is not None
if checkpoint_frequency is not None:
pathlib.Path(self.checkpoint_path).mkdir(parents=True,
exist_ok=True)
if self._epoch % checkpoint_frequency == 0: # checkpoint model
epoch = str(self._epoch).rjust(5, "0")
file_path = (
pathlib.Path(self.checkpoint_path) / f"epoch{epoch}.pkl")
self.save_trainer(file_path)
print(f"[{type(self).__name__}] Checkpoint created sucessfully at: {str(file_path)}")
def save_trainer(self, save_path, format=".pkl"):
"""Saves whole trainer, e.g. for production after training."""
if self._full_checkpoint:
data = self
else:
data = {
name: self.__getattribute__(key)
for key, name in self._statistics.items()
}
if format == ".pkl":
leaves, treedef = tree_util.tree_flatten(data)
leaves = [
onp.asarray(leaf) if isinstance(leaf, jnp.ndarray) else leaf
for leaf in leaves
]
with open(save_path, "wb") as pickle_file:
pickle.dump(tree_util.tree_unflatten(treedef, leaves), pickle_file)
elif format == "none":
return data
def save_energy_params(self, file_path, save_format=".hdf5", best=False):
"""Saves energy parameters.
Args:
file_path: Path to the file where to save the energy parameters.
Currently, only saving to pickle files (``"*.pkl"``) is
supported.
save_format: Format in which to save the energy parameters.
best: If True, tries to save the best parameters, e.g., on the
validation loss. If no criterion to determine the best params
was specified, saves the latest parameters instead.
"""
if best:
try:
params = self.best_params
except AttributeError:
warnings.warn(
f"Saving best params is not possible, saving the last "
f"paramters.")
params = self.params
else:
params = self.params
if save_format == ".hdf5":
raise NotImplementedError
elif save_format == ".pkl":
with open(file_path, "wb") as pickle_file:
pickle.dump(device_get(params), pickle_file)
else:
format_not_recognized_error(save_format)
def load_energy_params(self, file_path):
"""Loads energy parameters.
Args:
file_path: Path to the file containing the energy parameters.
Currently, only loading from pickle files (``"*.pkl"``) is
supported.
"""
if file_path.endswith(".hdf5"):
raise NotImplementedError
elif file_path.endswith(".pkl"):
with open(file_path, "rb") as pickle_file:
params = pickle.load(pickle_file)
else:
format_not_recognized_error(file_path[-4:])
self.params = tree_map(jnp.array, params) # move state on device
@property
@abc.abstractmethod
def params(self):
"""Short-cut for parameters. Depends on specific trainer."""
@params.setter
@abc.abstractmethod
def params(self, loaded_params):
raise NotImplementedError()
@abc.abstractmethod
def train(self, *args, **kwargs):
"""Training of any trainer should start by calling train."""
@abc.abstractmethod
def move_to_device(self):
"""Move all attributes that are expected to be on device to device to
avoid TracerExceptions after loading trainers from disk, i.e.
loading numpy rather than device arrays.
"""
def checkpoint(self, name, object):
"""Marks attribute to be saved in a partial checkpoint.
The marked attribute is saved to a checkpoint dictionary under
the specified name.
Args:
name: Name of the statistic in the saved dictionary
object: Object to initialize the attribute
Returns:
Returns the original object wrapped as a CheckpointNode.
"""
return CheckpointAttr(name, object)
def __setattr__(self, key, value):
if isinstance(value, CheckpointAttr):
# The wrapper class is only to identify attributes to be checkpointed.
# We now track these attributes in a dictionary and remove the wrapper,
# which is no longer needed.
if value.name in self._statistics.values():
for duplicate_key, duplicate_value in self._statistics.items():
if duplicate_key == key:
warnings.warn(f"[{self.__class__.__name__}] Attribute {duplicate_key} is marked for checkpoining twice.")
continue
if duplicate_value == value.name:
raise ValueError(
f"Duplicate checkpoint name found for attribute {key}. "
f"Name '{value.name}' is already used for attribute {duplicate_key}."
)
self._statistics[key] = value.name
object.__setattr__(self, key, value.object)
else:
object.__setattr__(self, key, value)
def restore(self, checkpoint):
"""Restores the trainer from a checkpoint.
Args:
checkpoint: Checkpoint to restore from. Can be a path to a file or
a dictionary containing the trainer state.
"""
with open(checkpoint, "rb") as f:
checkpoint = pickle.load(f)
# Restore all attributes that were marked as checkpointable
restored = []
for attr, key in self._statistics.items():
object.__setattr__(self, attr, checkpoint[key])
restored.append(attr)
# Finish this epoch
self._epoch += 1
# Write summary
print(f"[{self.__class__.__name__}] Attributes Restored:")
print(f", ".join(restored))
unchanged = [
attr for attr in self.__dict__.keys() if attr not in restored
]
print(f"[{self.__class__.__name__}] Attributes Unchanged:")
print(f", ".join(unchanged))
[docs]
class MLETrainerTemplate(TrainerInterface):
"""Abstract class implementing common properties and methods of single
point estimate Trainers using optax optimizers.
Args:
optimizer: Optax optimizer
init_state: Initial state of optimizer and model
checkpoint_path: Path to folder where checkpoints are saved
full_checkpoint: Whether to save the full trainer with pickle or only
a subset of attributes.
log_file: Write loggs of Trainer to the file specified by path.
reference_energy_fn_template: Function returning a concrete energy
function for the current parameters
The MLE trainer performs a sequence of task before and after each training,
epoch and batch update. It is possible to add custom tasks to the trainer
via :func:`MLETrainerTemplate.add_task`.
Attributes:
update_times: Computation time of each update
gradient_norm_history: Norms of the gradient for each update
"""
[docs]
def __init__(self,
optimizer,
init_state: util.TrainerState,
checkpoint_path: PathLike,
full_checkpoint: bool = True,
log_file: PathLike = None,
reference_energy_fn_template: EnergyFnTemplate = None):
super().__init__(
checkpoint_path, reference_energy_fn_template, full_checkpoint)
self.optimizer = optimizer
self.state = self.checkpoint("trainer_state", init_state)
self.update_times = self.checkpoint("update_times", [])
self.gradient_norm_history = self.checkpoint("gradient_norm_history", [])
self._converged = self.checkpoint("converged", False)
self._diverged = self.checkpoint("diverged", False)
self.log_file = log_file
self._tasks = {}
# Add standard tasks
self.add_task("pre_epoch", self._update_times_start)
self.add_task("post_epoch", self._update_times_end)
self.add_task("post_epoch", self._evaluate_convergence)
self.add_task("post_epoch", self._dump_checkpoint_occasionally)
# Dropout only if params indicate necessity
if dropout.dropout_is_used(self.params):
self.add_task("post_batch", self._update_dropout)
# Note: make sure not to use such a construct during training as an
# if-statement based on params forces the python part to wait for the
# completion of the batch, hence losing the advantage of asynchronous
# dispatch, which can become the bottleneck in high-throughput learning.
self.release_fns = {}
[docs]
def add_task(self, trigger, fn_or_method):
"""Adds a tasks to perform regularly during training.
Args:
trigger: The trigger at which the task is executed. Can be
``"pre/post_training/epoch/batch"``.
fn_or_method: The function or method to be executed.
Example:
The following code adds a task printing a specific energy parameter
after each epoch.
.. code-block:: python
def print_parameter(trainer, *args, **kwargs):
print(f"Parameter after epoch {trainer._epoch}: "
f"{trainer.state.params["parameter"]}")
trainer.add_task("post_epoch", print_parameter)
"""
valid_triggers = [
"pre_training",
"pre_epoch",
"pre_batch",
"post_batch",
"post_epoch",
"post_training"
]
assert trigger in valid_triggers, (
f"Provided trigger {trigger} is invalid, can only be "
f"{valid_triggers}."
)
if trigger not in self._tasks:
self._tasks[trigger] = []
self._tasks[trigger].append(fn_or_method)
return fn_or_method
def _execute_tasks(self, trigger, *args, **kwargs):
"""Executes a dynamical set of tasks."""
if trigger not in self._tasks.keys():
return
for fn_or_method in self._tasks[trigger]:
if inspect.ismethod(fn_or_method):
fn_or_method(*args, **kwargs)
else:
fn_or_method(self, *args, **kwargs)
[docs]
def print_training_tasks(self):
"""Prints the tasks performed by the trainer."""
print("Preparation:")
if "pre_training" in self._tasks:
for task in self._tasks["pre_training"]:
print(f" - {task}")
else:
print("<no preparation tasks>")
print("\nFor every EPOCH\n===============")
if "pre_epoch" in self._tasks:
for task in self._tasks["pre_epoch"]:
print(f" - {task}")
else:
print("<no pre-epoch tasks>")
print(f"\n\tFor every BATCH in EPOCH\n\t----------------")
if "pre_batch" in self._tasks:
for task in self._tasks["pre_batch"]:
print(f"\t - {task}")
else:
print("\t<no pre-batch tasks>")
print(f"\n\tUPDATE\n")
if "post_batch" in self._tasks:
for task in self._tasks["post_batch"]:
print(f" - {task}")
else:
print("\t<no post-batch tasks>")
print("")
if "post_epoch" in self._tasks:
for task in self._tasks["post_epoch"]:
print(f" - {task}")
else:
print("<no post_epoch tasks>")
print("\nPostprocessing:")
if "post_training" in self._tasks:
for task in self._tasks["post_training"]:
print(f" - {task}")
else:
print("<no postprocessing tasks>")
def _optimizer_step(self, curr_grad):
"""Wrapper around step_optimizer that is useful whenever the
update of the optimizer can be done outside jit-compiled functions.
Returns:
Returns the parameters after an update of the optimizer, but
without updating the internal states.
"""
new_params, _ = max_likelihood.step_optimizer(
self.params, self.state.opt_state, curr_grad, self.optimizer)
return new_params
def _step_optimizer(self, curr_grad, alpha=1.0):
"""Wrapper around step_optimizer that is useful whenever the
update of the optimizer can be done outside of jit-compiled functions.
"""
new_params, new_opt_state = max_likelihood.step_optimizer(
self.params, self.state.opt_state, curr_grad, self.optimizer)
# Do an optimized update
new_params = tree_map(
lambda old, new: old * (1 - alpha) + new * alpha,
self.params, new_params
)
self.state = self.state.replace(params=new_params,
opt_state=new_opt_state)
[docs]
def train(self, max_epochs, thresh=None, checkpoint_freq=None):
"""Trains for a maximum number of epochs, checkpoints after a
specified number of epochs and ends training if a convergence
criterion is met. This function can be called multiple times to extend
training.
This function only implements the training sceleton by splitting the
training into epochs and batches as well as providing checkpointing and
ending of training if the convergence criterion is met. The specifics
of dataloading, parameter updating and convergence criterion evaluation
needs to be implemented in ``_get_batch()``, ``_update()`` and
``_evaluate_convergence()``, respectively, depending on the exact trainer
details to be implemented.
Args:
max_epochs: Maximum number of epochs for which training is
continued. Training will end sooner if convergence criterion is
met.
thresh: Threshold of the early stopping convergence criterion. If
None, no early stopping is applied. Definition of thresh depends
on specific convergence criterion. See :class:`EarlyStopping`.
checkpoint_freq: Number of epochs after which a checkpoint is saved.
By default, do not save checkpoints.
"""
self._converged = False
start_epoch = self._epoch
end_epoch = start_epoch + max_epochs
with CaptureStdout(self.log_file):
self._execute_tasks("pre_training")
for _ in range(start_epoch, end_epoch):
try:
self._execute_tasks("pre_epoch")
for batch in self._get_batch():
self._execute_tasks("pre_batch", batch)
self._update(batch)
self._execute_tasks("post_batch", batch)
self._execute_tasks("post_epoch",
checkpoint_frequency=checkpoint_freq,
convergence_thresh=thresh)
self._epoch += 1
except RuntimeError as err:
# In case the simulation diverges, break the optimization
# and checkpoint the last state such that an analysis can
# be performed.
self._diverged = True
if self.checkpoint_path is not None:
path = (self.checkpoint_path / f"epoch{self._epoch - 1}_error_state.pkl")
self.save_trainer(save_path=path)
print(f"Training has been unsuccessful due to the following"
f" error: {err}")
break
if self._converged:
break
else:
if thresh is not None:
print("Maximum number of epochs reached without convergence.")
self._execute_tasks("post_training")
def _update_dropout(self, batch):
"""Updates params, while keeping track of Dropout."""
# TODO refactor this as this needs to wait for when
# params will again be available, slowing down re-loading
# of batches. We could set dropout key as kwarg and keep
# track of keys in this class. Also refactor dropout in
# DimeNet taking advantage of haiku RNG key management and
# built-in dropout in MLP
params = dropout.next_dropout_params(self.params)
self.params = params
def _update_times_start(self, *args, **kwargs):
self.update_times.append(time.time())
def _update_times_end(self, *args, **kwargs):
self.update_times[self._epoch] -= time.time()
self.update_times[self._epoch] /= -60.
@abc.abstractmethod
def _get_batch(self):
"""A generator that returns the next batch that will be provided to the
_update function. The length of the generator should correspond to the
number of batches per epoch.
"""
@abc.abstractmethod
def _update(self, batch):
"""Uses the current batch to updates self.state via the training scheme
implemented in the specific trainer. Can additionally save auxilary
optimization results, such as losses and observables, that can be
used by _evaluate_convergence and for post-processing.
"""
@abc.abstractmethod
def _evaluate_convergence(self, duration, thresh, *args, **kwargs):
"""Checks whether a convergence criterion has been met. Can also be
used to print callbacks, such as time per epoch and loss vales.
"""
[docs]
def move_to_device(self):
"""Converts all arrays of the trainer state to JAX arrays."""
self.state = tree_map(jnp.array, self.state) # move on device
def _release_data_references(self):
for release in self.release_fns.values():
release()
self.release_fns = {}
class PropagationBase(MLETrainerTemplate):
"""Trainer base class for shared functionality whenever (multiple)
simulations are run during training. Can be used as a template to
build other trainers. Currently used for DiffTRe and relative entropy.
We only save the latest generated trajectory for each state point.
While accumulating trajectories would enable more frequent reweighting,
this effect is likely minor as past trajectories become exponentially
less useful with changing potential. Additionally, saving long trajectories
for each statepoint would increase memory requirements over the course of
the optimization.
"""
def __init__(self, init_trainer_state, optimizer, checkpoint_path,
reweight_ratio=0.9, sim_batch_size=1, energy_fn_template=None,
full_checkpoint=True, key=None, log_dir=None,):
super().__init__(optimizer, init_trainer_state, checkpoint_path,
full_checkpoint, log_dir, energy_fn_template)
self.sim_batch_size = sim_batch_size
self.reweight_ratio = reweight_ratio
if key is None:
self.key = random.PRNGKey(0)
# store for each state point corresponding traj_state and grad_fn
# save in distinct dicts as grad_fns need to be deleted for checkpoint
self.grad_fns, self.statepoints = {}, {}
self.n_statepoints = 0
self.trajectory_states = self.checkpoint("trajectory_states", {})
self.shuffle_key = self.checkpoint("key", random.PRNGKey(0))
self.weight_fn = {}
self._adaptive_step_size = {}
def _init_statepoint(self, reference_state, energy_fn_template,
simulator_template, neighbor_fn, timings, state_kwargs,
set_key=None, energy_batch_size=10,
initialize_traj=True,
safe_propagation=True, entropy_approximation=False,
resample_simstates=False, num_init_runs=2):
"""Initializes the simulation and reweighting functions as well
as the initial trajectory for a statepoint."""
# TODO ref pressure only used in print and to have barostat values.
# Reevaluate this parameter of barostat values not used in reweighting
# TODO document ref_press accordingly
# TODO: Extend this function to allow for multiple statepoints to be
# added. Requires batch argument to be set here.
assert "kT" in state_kwargs, (
"Reweighting requires at least the temperature to be specified in "
"the state_kwargs. "
)
# Backwards compatibility
if isinstance(reference_state, tuple):
warnings.warn(
"Passing the reference state as tuple of simulator state and "
"neighbors is deprecated. "
"Use trajectory.traj_util.SimulatorState instead.",
DeprecationWarning
)
reference_state = sampling.SimulatorState(
sim_state=reference_state[0], nbrs=reference_state[1])
if set_key is not None:
key = set_key
if set_key not in self.statepoints.keys():
self.n_statepoints += 1
else:
key = self.n_statepoints
self.n_statepoints += 1
self.statepoints[key] = state_kwargs
npt_ensemble = util.is_npt_ensemble(reference_state.sim_state)
if npt_ensemble:
assert "pressure" in state_kwargs, (
"Reweighting in the NPT ensemble requires the pressure to be "
"defined in the state_kwargs."
)
gen_init_traj, *reweight_fns = init_pot_reweight_propagation_fns(
energy_fn_template, simulator_template, neighbor_fn, timings,
state_kwargs, self.reweight_ratio, npt_ensemble,
energy_batch_size, safe_propagation=safe_propagation,
entropy_approximation=entropy_approximation,
resample_simstates=resample_simstates
)
self.key, split = random.split(self.key)
# To get the correct timings, first compile before evaluation
start = time.time()
init_traj_fn = gen_init_traj.lower(split, self.params, reference_state,
num_runs=num_init_runs, **state_kwargs)
init_traj_fn = init_traj_fn.compile()
compile_time = (time.time() - start) / 60.
print(
f"[Propagation] Time for trajectory compilation {key}: "
f"{compile_time} mins"
)
if initialize_traj:
start = time.time()
init_traj = init_traj_fn(split, self.params, reference_state, **state_kwargs)
run_time = (time.time() - start) / 60. / num_init_runs
assert not init_traj.overflow, "[Propagation] Neighborlist buffer overflowed."
assert not onp.any(onp.isnan(init_traj.trajectory.position)), "[Propagation] Initial simulation produced NaNs."
print(
f"[Propagation] Time for trajectory simulation {key}: "
f"{run_time} mins"
)
self.trajectory_states[key] = init_traj
else:
self.trajectory_states[key] = functools.partial(init_traj_fn, split, self.params, reference_state, **state_kwargs)
print("Not initializing the initial trajectory is only valid if "
"a checkpoint is loaded. In this case, please be use to add "
"state points in the same sequence, otherwise loaded "
"trajectories will not match its respective simulations.")
return key, *reweight_fns
@abstractmethod
def add_statepoint(self, *args, **kwargs):
"""User interface to add additional state point to train model on."""
raise NotImplementedError()
@property
def params(self):
"""Current energy parameters."""
return self.state.params
@params.setter
def params(self, loaded_params):
"""Replaces the current energy parameters."""
self.state = self.state.replace(params=loaded_params)
def get_sim_state(self, key):
"""Gets the simulator state of a statepoint."""
return self.trajectory_states[key].sim_state
def _get_batch(self):
"""Helper function to re-shuffle simulations and split into batches."""
self.shuffle_key, used_key = random.split(self.shuffle_key, 2)
shuffled_indices = random.permutation(used_key, self.n_statepoints)
if self.sim_batch_size == 1:
batch_list = jnp.split(shuffled_indices, shuffled_indices.size)
elif self.sim_batch_size == -1:
batch_list = jnp.split(shuffled_indices, 1)
else:
raise NotImplementedError("Only batch_size = 1 or -1 implemented.")
return (batch.tolist() for batch in batch_list)
def _print_measured_statepoint(self, sim_key=None):
"""Print meausured kbT (and pressure for npt ensemble) for all
statepoints to ensure the simulation is indeed carried out at the
prescribed state point.
"""
if sim_key is None:
for sim_key in self.trajectory_states.keys():
self._print_measured_statepoint(sim_key)
else:
traj = self.trajectory_states[sim_key]
print(f"[Statepoint {sim_key}]")
statepoint = self.statepoints[sim_key]
measured_kbt = jnp.mean(traj.aux["kT"])
if "pressure" in statepoint: # NPT
measured_press = jnp.mean(traj.aux["pressure"])
press_print = (f"\n\tpress = {measured_press:.2f} ref_press = "
f"{statepoint['pressure']:.2f}")
else:
press_print = ""
print(f"\tkT = {measured_kbt:.3f} ref_kT = "
f"{statepoint['kT']:.3f}" + press_print)
def train(self, max_epochs, thresh=None, checkpoint_freq=None):
assert self.n_statepoints > 0, ("Add at least 1 state point via "
"'add_statepoint' to start training.")
super().train(max_epochs, thresh=thresh,
checkpoint_freq=checkpoint_freq)
@abstractmethod
def _update(self, batch):
"""Implementation of gradient computation, stepping of the optimizer
and logging of auxiliary results. Takes batch of simulation indices
as input.
"""
def init_step_size_adaption(self,
allowed_reduction: float = 0.5,
interior_points: int = 10,
step_size_scale: float = 1e-7
) -> None:
"""Initializes a line search to tune the step size in each iteration.
The line search optimizes step size to limit the decrease in the
effective sample size (ESS) via the algorithm
:func:`chemtrain.learn.difftre.init_step_size_adaption`.
Args:
allowed_reduction: Target reduction of the effective sample size
interior_points: Number of interiour points
step_size_scale: Accuracy of the found optimal interpolation
coefficient
Returns:
Returns the interpolation coefficient :math:`\\alpha`.
"""
for key, weight_fn in self.weight_fn.items():
self._adaptive_step_size[key] = difftre.init_step_size_adaption(
weight_fn, allowed_reduction, interior_points, step_size_scale
)
[docs]
class DataParallelTrainer(MLETrainerTemplate):
"""Trainer for parallelized MLE training based on a dataset.
This trainer implements methods for MLE training on a dataset, where
parallelization can simply be accomplished by pmapping over batched data.
As pmap requires constant batch dimensions, data with unequal number of
atoms needs to be padded and to be compatible with this trainer.
"""
_train_loader: data.DataLoader
_val_loader: data.DataLoader
_test_loader: data.DataLoader
[docs]
def __init__(self, loss_fn, model, init_params, optimizer, checkpoint_path,
batch_per_device: int, batch_cache: int = 1,
full_checkpoint=True, penalty_fn=None, energy_fn_template=None,
convergence_criterion="window_median", log_file=None,
disable_shmap: bool = False):
self._disable_shmap = disable_shmap
self.batched_model = model
if disable_shmap:
self._update_fn = max_likelihood.pmap_update_fn(
self.batched_model, loss_fn, optimizer, penalty_fn)
self._evaluate_fn = None
else:
# shmap performs better, but some replication rules are missing
self._update_fn = max_likelihood.shmap_update_fn(
self.batched_model, loss_fn, optimizer, penalty_fn)
self._evaluate_fn = max_likelihood.shmap_loss_fn(
self.batched_model, loss_fn, penalty_fn)
self._loss_fn = loss_fn
self.batch_cache = batch_cache
self.batch_size = batch_per_device * device_count()
if optimizer is None:
print(f"No optimizer specified")
opt_state = None
else:
opt_state = optimizer.init(init_params) # initialize optimizer state
init_state = util.TrainerState(params=init_params, opt_state=opt_state)
super().__init__(
optimizer=optimizer, init_state=init_state,
checkpoint_path=checkpoint_path,
full_checkpoint=full_checkpoint,
log_file=log_file,
reference_energy_fn_template=energy_fn_template)
self.train_batch_losses = self.checkpoint("train_batch_losses", [])
self.train_losses = self.checkpoint("train_losses", [])
self.val_losses = self.checkpoint("val_losses", [])
self.train_target_losses = self.checkpoint("train_target_losses", {})
self.val_target_losses = self.checkpoint("val_target_losses", {})
self._batch_states: Dict[str, Any] = {}
self._batches_per_epoch: Dict[str, int] = {}
self._get_batch_fns: Dict[str, Callable] = {}
self._early_stop = EarlyStopping(self.params, convergence_criterion)
[docs]
def reset_convergence_losses(self):
"""Resets early stopping convergence monitoring."""
self._early_stop.reset_convergence_losses()
[docs]
def limit_batches_per_epoch(self, max_batches: int = 1):
"""Limits the number of batches per epoch.
Args:
max_batches: Maximum number of batches to use within one epoch.
"""
assert self._batches_per_epoch["training"] >= max_batches, (
"The number of batches per epoch is smaller than the requested "
"maximum."
)
self._batches_per_epoch["training"] = max_batches
[docs]
def set_datasets(self, dataset, train_ratio=0.7, val_ratio=0.1, shuffle=False,
include_all=True):
"""Sets the datasets for training, testing and validation.
Args:
dataset: Dictionary containing input and target data as numpy
arrays.
train_ratio: Percentage of dataset to use for training.
val_ratio: Percentage of dataset to use for validation.
shuffle: Whether to shuffle data before splitting into
train-val-test.
include_all: Compute the loss for all samples of the splits by
padding the last batch and masking out double samples.
Not applied to the training split.
"""
# release all references before allocating new data to avoid memory leak
self._release_data_references()
# Initialize the data loaders
loaders = data_loaders.init_dataloaders(
dataset, train_ratio, val_ratio, shuffle=shuffle)
self.set_loader(loaders.train_loader, stage="training")
self.set_loader(loaders.val_loader, stage="validation", include_all=include_all)
self.set_loader(loaders.test_loader, stage="testing", include_all=include_all)
[docs]
def set_dataset(self, dataset, stage="testing", shuffle=False, include_all=False, **kwargs):
"""Sets the dataset for a single stage, e.g., training.
Args:
dataset: Dictionary containing input and target data as numpy
arrays.
stage: Stage for which to set the dataset. Can be ``"training"``,
``"validation"``, or ``"testing"``.
shuffle: Whether the data should be shuffled.
include_all: Compute the loss for all samples of the split by
padding the last batch and masking out double samples.
Not applied to the training split.
"""
# Will only return one data loader
loaders = data_loaders.init_dataloaders(
dataset, train_ratio=1.0, val_ratio=0.0, shuffle=shuffle
)
self.set_loader(loaders.train_loader, stage=stage, include_all=include_all, **kwargs)
[docs]
def set_loader(self,
data_loader,
stage="training",
include_all=False,
batch_size=None,
rng_seed=None,
**kwargs):
"""Sets a data loader for a specific stage, e.g., training.
If the dataset consists of numpy arrays, it is simpler to use
:func:`set_dataset` or :func:`set_datasets` to set the data loaders.
Args:
data_loader: The data loader to set.
stage: The stage for which to set the data loader. Can be
``"training"``, ``"validation"``, or ``"testing"``.
include_all: Compute the loss for all samples of the split by
padding the last batch and masking out double samples.
Not applied to the training split.
rng_seed: Seed to include random keys in the reference data.
The keys are refreshed whenever a new batch is drawn.
batch_size: Overwrites the default batch size.
"""
if stage in self.release_fns.keys():
self.release_fns[stage]()
observation_count = data_loader.static_information["observation_count"]
assert observation_count > 0
# Overwrite batch size for splits
if batch_size is not None:
print(f"Chose custom batch size {batch_size} for split {stage}")
if batch_size is None:
batch_size = self.batch_size
if batch_size > observation_count:
batch_size = observation_count
# Ensures that the batch size is divisible by the number of devices
batch_size -= onp.mod(batch_size, device_count())
if batch_size != self.batch_size:
logging.info(
f"Batch size for stage {stage} changed to {batch_size} "
f"from {self.batch_size}."
)
if include_all:
assert stage != "training", (f"Including all samples not supported "
f"for the training split.")
# Increase the number of observations to make them divisible by
# the batch size
observation_count += onp.mod(
batch_size - onp.mod(observation_count, batch_size), batch_size
)
if onp.mod(observation_count, batch_size) != 0:
warnings.warn(
f"Batch size {batch_size} does not divide the number of "
f"observations {observation_count}. "
f"Trainer will skip {observation_count % batch_size} samples "
f"for state {stage}"
)
# Initialize the access functions
batch_fns = data_loaders.init_batch_functions(
data_loader, mb_size=batch_size, cache_size=self.batch_cache,
)
init_train_state, get_train_batch, release = batch_fns
train_batch_state = init_train_state(
shuffle=True, in_epochs=include_all, rng_seed=rng_seed, **kwargs
)
self._get_batch_fns[stage] = get_train_batch
self._batch_states[stage] = train_batch_state
self._batches_per_epoch[stage] = observation_count // batch_size
self.release_fns[stage] = release
def _get_batch_stage(self, stage, information=False):
for _ in range(self._batches_per_epoch[stage]):
self._batch_states[stage], train_batch = self._get_batch_fns[stage](
self._batch_states[stage], information=information)
yield train_batch
[docs]
def set_batches_per_epoch(self, stage="training", max_batches: int = 1):
"""Limits the number of updates within an epoch.
Useful, e.g., when the validation loss should be computed more
regularly.
Args:
stage: Key to the stage that should be limited. Currently,
stage other than ``"training"`` is not supported to avoid
wrongful computations of the validation loss.
max_batches: Maximum number of batches, i.e., number of batched
optimizer updates.
"""
if stage != "training":
raise NotImplementedError("Only training stage implemented.")
self._batches_per_epoch[stage] = min([
self._batches_per_epoch[stage], max_batches
])
def _get_batch(self):
return self._get_batch_stage("training")
def _update(self, batch):
"""Function to iterate, optimizing parameters and saving
training and validation loss values.
"""
params, opt_state, train_loss, curr_grad, per_target_losses = self._update_fn(
self.state.params, self.state.opt_state, batch, per_target=True)
# Save the statistics
for key, val in per_target_losses.items():
if key not in self.train_target_losses.keys():
self.train_target_losses[key] = []
self.train_target_losses[key].append(onp.asarray(val))
self.state = self.state.replace(params=params, opt_state=opt_state)
self.train_batch_losses.append(onp.asarray(train_loss))
self.gradient_norm_history.append(util.tree_norm(curr_grad))
[docs]
def predict(self, dataset, params=None, batch_size=10):
"""Computes predictions for a dataset.
Args:
dataset: Dictionary containing input data as numpy arrays. Can be,
e.g., the whole testing split.
params: Parameters for the model. If None, uses the current
parameters.
batch_size: Batch size for predictions.
Returns:
Returns all predictions of the model for the provided inputs.
"""
# Set random to False to prevent shuffling of results by shuffling
# inputs
self.set_dataset(dataset, "predict", include_all=True, random=False)
if params is None:
params = self.params
if self._disable_shmap:
raise NotImplementedError("Pmapped predictions not implemented.")
else:
shmapped_model = max_likelihood.shmap_model(self.batched_model)
all_predictions = None
for batch_with_info in self._get_batch_stage("predict", information=True):
# Compute the total loss and the individual contributions per
# target
batch, batch_info = batch_with_info
# Only get valid samples by masking with numpy
predictions = shmapped_model(params, batch)
predictions = tree_util.tree_map(
lambda x: x[onp.asarray(batch_info.mask), ...],
jax.device_get(predictions)
)
if all_predictions is None:
all_predictions = predictions
else:
all_predictions = util.tree_map(
lambda *leaves: onp.concatenate(leaves, axis=0),
all_predictions, predictions
)
return all_predictions
[docs]
def evaluate(self, stage = "validation", loss_fn = None, params=None):
"""Computes the loss on the whole dataset.
Args:
stage: Stage for which to evaluate the loss. Can be ``"testing"``,
``"validation"``, or ``"training"``.
loss_fn: Loss function to evaluate. If None, evaluates the loss
function used for training.
params: Parameters for the model. If None, uses the current
parameters.
Returns:
Returns the total loss and the loss for each individual target.
"""
assert stage in self._batch_states, (
f"A dataloader (dataset) is required to evaluate the loss on "
f"stage {stage}."
)
if params is None:
params = self.params
# Option to define a new loss function, e.g., for MAE error
if loss_fn is None:
loss_fn = self._evaluate_fn
elif not self._disable_shmap:
loss_fn = max_likelihood.shmap_loss_fn(self.batched_model, loss_fn)
else:
raise NotImplementedError
total_loss, per_target_losses = 0.0, {}
total_samples, valid_samples = 0, 0
for batch_with_info in self._get_batch_stage(stage, information=True):
# Compute the total loss and the individual contributions per
# target
batch, batch_info = batch_with_info
# Compute a correction factor
valid_samples += onp.sum(batch_info.mask)
total_samples += batch_info.batch_size
val_loss, per_target_loss = loss_fn(
params, batch, mask=batch_info.mask, per_target=True)
total_loss += onp.asarray(val_loss)
for key, val in per_target_loss.items():
if key not in per_target_losses:
per_target_losses[key] = []
per_target_losses[key].append(onp.asarray(val))
# The correction factor accounts for including invalid (masked) samples
# in the mean of the split
scale_factor = total_samples / valid_samples
total_loss /= self._batches_per_epoch[stage]
total_loss *= scale_factor
per_target_losses = {
key: sum(val) / len(val) * scale_factor
for key, val in per_target_losses.items()
}
return total_loss, per_target_losses
def _evaluate_convergence(self, *args, thresh=None, **kwargs):
"""Prints progress, saves best obtained params and signals converged if
validation loss improvement over the last epoch is less than the thesh.
"""
batches_per_epoch = self._batches_per_epoch["training"]
mean_train_loss = sum(
self.train_batch_losses[-batches_per_epoch:]
) / batches_per_epoch
self.train_losses.append(mean_train_loss)
duration = self.update_times[self._epoch]
start = time.time()
if "validation" in self._batch_states:
val_loss, val_target_losses = self.evaluate("validation")
self.val_losses.append(onp.asarray(val_loss))
for key, val in val_target_losses.items():
if key not in self.val_target_losses:
self.val_target_losses[key] = []
self.val_target_losses[key].append(val)
self._converged = self._early_stop.early_stopping(
val_loss, thresh, self.params)
else:
val_loss = None
# Add the time for the validation
total_duration = (time.time() - start) / 60. + duration
log_str = (
f"[Epoch {self._epoch}]:\n"
f"\tAverage train loss: {mean_train_loss:.5f}\n"
f"\tAverage val loss: {val_loss}\n"
f"\tGradient norm: {self.gradient_norm_history[-1]}\n"
f"\tElapsed time = {total_duration:.3f} min (total), {duration:.3f} min (training)\n"
f"\tPer-target losses:\n"
)
for key in self.train_target_losses:
train_batches_per_epoch = self._batches_per_epoch["training"]
mean_train_loss = sum(
self.train_target_losses[key][-train_batches_per_epoch:]
) / train_batches_per_epoch
try:
target_val_loss = self.val_target_losses[key][-1]
except IndexError or KeyError:
target_val_loss = "N.A."
log_str += (
f"\t\t{key} | train loss: {mean_train_loss} | "
f"val loss: {target_val_loss}\n"
)
print(log_str)
[docs]
def update_with_samples(self, **sample):
"""A single params update step, where a batch is taken from the training
set and samples of the batch are substituted by the provided samples.
This function is useful in an active learning context to retrain
specifically on newly labeled datapoints. The number of provided samples
must not exceed the trainer batch size.
Args:
sample: Kwargs storing data samples to supply to
``self._build_dataset`` to build samples in the correct
pytree. Analogous usage as update_dataset, but the dataset
only consists of a few observations.
"""
n_samples = util.tree_multiplicity(sample)
assert n_samples <= self.batch_size, ("Number of provided samples must"
" not exceed trainer batch size.")
batch = next(self._get_batch())
updated_batch = util.tree_set(batch, sample, n_samples)
self._update_with_dropout(updated_batch)
@property
def params(self):
"""Current energy parameters."""
single_params = self.state.params
return single_params
@params.setter
def params(self, loaded_params):
self.state = self.state.replace(params=loaded_params)
@property
def best_params(self):
"""Returns the best parameters based on the validation loss.
If training was performed with early stopping, return the best
parameters to this criterion instead.
"""
# if no validation data given, _early_stop.best_params are simply
# init_params
if "validation" in self._batch_states.keys() is None:
return self.params
else:
return self._early_stop.best_params
@property
def best_inference_params(self):
"""Returns best model params irrespective whether dropout is used."""
if dropout.dropout_is_used(self.best_params):
# all nodes present during inference
params, _ = dropout.split_dropout_params(self.best_params)
else:
params = self.best_params
return params
@property
def best_inference_params_replicated(self):
"""Returns the best inference params replicated on every device."""
inference_params = self.best_inference_params
return util.tree_replicate(inference_params)
[docs]
def move_to_device(self):
"""Transforms all arrays of the trainer state to JAX arrays."""
super().move_to_device()
self._early_stop.move_to_device()
class ProbabilisticFMTrainerTemplate(TrainerInterface):
"""Trainer template for methods that result in multiple parameter sets for
Monte-Carlo-style uncertainty quantification, based on a force-matching
formulation.
"""
def __init__(self, checkpoint_path, energy_fn_template,
val_dataloader=None):
super().__init__(checkpoint_path, energy_fn_template)
self.results = []
# TODO use val_loader for some metrics that are interesting for MCMC
# and SG-MCMC
def move_to_device(self):
params = []
for param_set in self.params:
params.append(tree_map(jnp.array, param_set)) # move on device
self.params = params
@property
@abc.abstractmethod
def list_of_params(self):
""" Returns a list containing n single model parameter sets, where n
is the number of samples. This provides a more intuitive parameter
interface that self.params, which returns a large set of parameters,
where n is the leading axis of each leaf. Self.params is most useful,
if parameter sets are mapped via map or vmap in a postprocessing step.
"""
class MCMCForceMatchingTemplate(ProbabilisticFMTrainerTemplate):
"""Initializes log_posterior function to be used for MCMC with blackjax,
including batch-wise evaluation of the likelihood and re-materialization.
"""
def __init__(self, init_state, kernel, checkpoint_path, val_loader=None,
ref_energy_fn_template=None):
super().__init__(checkpoint_path, ref_energy_fn_template, val_loader)
self.kernel = jit(kernel)
self.state = init_state
def train(self, num_samples, checkpoint_freq=None, init_samples=None,
rng_key=random.PRNGKey(0)):
if init_samples is not None:
# TODO implement multiple chains
raise NotImplementedError
for i in range(num_samples):
start_time = time.time()
rng_key, consumed_key = random.split(rng_key)
self.state, info = self.kernel(consumed_key, self.state)
self.results.append(self.state)
print(f"Time for sample {i}: {(time.time() - start_time) / 60.}"
f" min.", info)
self._epoch += 1
self._dump_checkpoint_occasionally(frequency=checkpoint_freq)
@property
def list_of_params(self):
"""Returns a list of sampled parameters."""
return [state.position["params"] for state in self.results]
@property
def params(self):
"""Concatenates the sampled parameters along the first dimension."""
return util.tree_stack(self.list_of_params)
@params.setter
def params(self, loaded_params):
raise NotImplementedError("Setting params seems not meaningful for MCMC"
" samplers.")
[docs]
class EarlyStopping:
"""A class that saves the best parameter obtained so far based on the
validation loss and determines whether the optimization can be stopped based
on some stopping criterion.
The following criteria are implemented:
- ``"window_median"``: 2 windows are placed at the end of the loss
history. Stops when the median of the latter window of size "thresh"
exceeds the median of the prior window of the same size.
- ``"PQ"``: Stops when the PQ criterion exceeds thresh
- ``"max_loss"``: Stops when the loss decreased below the maximum allowed
loss specified via thresh.
Args:
criterion: Convergence criterion to employ
pq_window_size: Window size for PQ method
Attributes:
best_loss: Loss of the best performing parameters
best_params: Parameters with best performance
"""
[docs]
def __init__(self, params, criterion, pq_window_size=5):
self.criterion = criterion
# own loss history that can be reset on the fly if needed.
self._epoch_losses = []
self.best_loss = 1.e16
self.best_params = copy.copy(params) # move on device, if loaded
self.pq_window_size = pq_window_size
def _is_converged(self, thresh):
converged = False
if thresh is not None: # otherwise no early stopping used
if self.criterion == "window_median":
window_size = thresh
if len(self._epoch_losses) >= 2 * window_size:
prior_window = onp.array(
self._epoch_losses[-2 * window_size:-window_size])
latter_window = onp.array(self._epoch_losses[-window_size:])
converged = (onp.median(latter_window)
> onp.median(prior_window))
elif self.criterion == "PQ":
if len(self._epoch_losses) >= self.pq_window_size:
best_loss = min(self._epoch_losses)
loss_window = self._epoch_losses[-self.pq_window_size:]
gen_loss = 100. * (loss_window[-1] / best_loss - 1.)
window_average = sum(loss_window) / self.pq_window_size
window_min = min(loss_window)
progress = 1000. * (window_average / window_min - 1.)
pq = gen_loss / progress
converged = pq > thresh
elif self.criterion == "max_loss":
converged = self._epoch_losses[-1] < thresh
else:
raise ValueError(f"Convergence criterion {self.criterion} "
f"unknown. Select 'max_loss', 'ave_loss' or "
f"'std'.")
return converged
[docs]
def early_stopping(self, curr_epoch_loss, thresh, params=None,
save_best_params=True):
"""Estimates whether the convergence criterion was met and keeps track
of the best parameters obtained so far.
Args:
curr_epoch_loss: Validation loss of the most recent epoch
thresh: Convergence threshold. Specific definition depends on the
selected convergence criterion.
params: Optimization parameters to save in case of being best. Make
sure to supply non-device-replicated params, i.e.
``self.params.``
save_best_params: If best params are supposed to be tracked
Returns:
True if the convergence criterion was met, else False.
"""
self._epoch_losses.append(curr_epoch_loss)
if save_best_params:
assert params is not None, ("If best params are saved, they need to"
" be provided in early_stopping.")
improvement = self.best_loss - curr_epoch_loss
if improvement > 0.:
self.best_loss = curr_epoch_loss
self.best_params = copy.copy(params)
return self._is_converged(thresh)
[docs]
def reset_convergence_losses(self):
"""Resets loss history used for convergence estimation, e.g., to avoid
early stopping when loss increases due to on-the-fly changes in the
dataset or the loss function.
"""
self._epoch_losses = []
self.best_loss = 1.e16
self.best_params = None
[docs]
def move_to_device(self):
"""Moves best_params to device to use them after loading trainer."""
self.best_params = tree_map(jnp.array, self.best_params)