Source code for chemtrain.trainers.extensions

# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Simple extensions, e.g., to log trainer statistics to MLops frameworks."""
import importlib

from chemtrain.trainers import trainers, base

[docs] def wandb_log_difftre(run, trainer: trainers.Difftre, plot_fns=None): """Logs DiffTRe training statistics to Weights & Biases. Args: run: Active W&B run trainer: Trainer to log to W&B plot_fns: Dictionary with functions to plot the selected predictions for selected statepoints Example usage: After initiating the trainer and the run, add a W&B tracking task to the trainer via: .. code:: def plot_fn(some_prediction): fig = plt.figure() ... return fig plot_fns = { 0: {'some_prediction': plot_fn} } wandb_log_difftre(run, difftre_trainer, plot_fn) """ wandb = importlib.import_module("wandb") if plot_fns is None: plot_fns = {} def log_fn(trainer: trainers.Difftre, *args, **kwargs): plots = {} assert issubclass(type(trainer), trainers.Difftre), ( f"Supports only DiffTRe trainer." ) for statepoint_key, statepoint_fns in plot_fns.items(): recent_predictions = trainer.predictions[statepoint_key][trainer._epoch] plots[statepoint_key] = { pred_key: wandb.Image( plot_fn(recent_predictions[pred_key]) ) for pred_key, plot_fn in statepoint_fns.items() } run.log( data={ "Epoch loss": trainer.epoch_losses[-1], "Gradient norm": trainer.gradient_norm_history[-1], "Elapsed time": trainer.update_times[-1], "Predictions": plots }, commit=True ) trainer.add_task("post_epoch", log_fn)
def wandb_log_data_parallel(run, trainer: base.DataParallelTrainer): """Logs DataParallel training statistics to Weights & Biases. Args: run: Active W&B run trainer: Trainer to log to W&B """ wandb = importlib.import_module("wandb") def get_validation_loss(key): try: return trainer.val_target_losses[key][-1] except IndexError: return "N.A." def log_fn(trainer: base.DataParallelTrainer, *args, **kwargs): assert issubclass(type(trainer), base.DataParallelTrainer), ( f"Supports only DataParallalTrainer trainers." ) duration = trainer.update_times[trainer._epoch] statistics = { "training": trainer.train_batch_losses[-1], "validation": trainer.val_losses[-1], "gradient_norm": trainer.gradient_norm_history[-1], "duration": duration, "targets": { key: { "training": trainer.train_target_losses[key][-1], "validation": get_validation_loss(key), } for key in trainer.train_target_losses.keys() } } run.log(data=statistics, commit=True) trainer.add_task("post_epoch", log_fn)