trainers.DifftreParallel.add_task


DifftreParallel.add_task(trigger, fn_or_method)

Adds a task to perform regularly during training.

Parameters:
  • trigger – The point in training at which the task is executed: one of "pre_training", "post_training", "pre_epoch", "post_epoch", "pre_batch", or "post_batch".

  • fn_or_method – The function or method to be executed.
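To illustrate how trigger-based tasks behave, here is a minimal sketch of the register-and-dispatch pattern. It is not the DifftreParallel implementation. `MiniTrainer`, `TRIGGERS`, `_run_tasks`, and `train` are hypothetical names. The six trigger strings are an assumed expansion of "pre/post_training/epoch/batch". Passing the trainer as the first argument follows the `print_parameter(trainer, *args, **kwargs)` signature in the example. The `batch` keyword argument is also an assumption.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List

# Hypothetical trigger names, assumed from "pre/post_training/epoch/batch".
TRIGGERS = {
    "pre_training", "post_training",
    "pre_epoch", "post_epoch",
    "pre_batch", "post_batch",
}


class MiniTrainer:
    """Hypothetical stand-in showing trigger-based task dispatch (not chemtrain code)."""

    def __init__(self) -> None:
        self._epoch = 0
        # Maps each trigger name to its tasks, kept in registration order.
        self._tasks: Dict[str, List[Callable[..., Any]]] = defaultdict(list)

    def add_task(self, trigger: str, fn_or_method: Callable[..., Any]) -> None:
        # Reject unknown triggers instead of silently never running the task.
        if trigger not in TRIGGERS:
            raise ValueError(f"Unknown trigger: {trigger!r}")
        self._tasks[trigger].append(fn_or_method)

    def _run_tasks(self, trigger: str, *args: Any, **kwargs: Any) -> None:
        # Each task receives the trainer as its first argument.
        for task in self._tasks[trigger]:
            task(self, *args, **kwargs)

    def train(self, num_epochs: int, batches_per_epoch: int) -> None:
        self._run_tasks("pre_training")
        for epoch in range(num_epochs):
            self._epoch = epoch
            self._run_tasks("pre_epoch")
            for batch in range(batches_per_epoch):
                self._run_tasks("pre_batch", batch=batch)
                # An optimization step would happen here in a real trainer.
                self._run_tasks("post_batch", batch=batch)
            self._run_tasks("post_epoch")
        self._run_tasks("post_training")


if __name__ == "__main__":
    seen_epochs = []
    trainer = MiniTrainer()
    trainer.add_task("post_epoch", lambda t, *a, **k: seen_epochs.append(t._epoch))
    trainer.train(num_epochs=3, batches_per_epoch=2)
    print(seen_epochs)  # [0, 1, 2]
```

Validating the trigger inside `add_task` means a misspelled name such as "after_epoch" fails immediately rather than producing a task that never fires.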

Example

The following code adds a task printing a specific energy parameter after each epoch.

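def print_parameter(trainer, *args, **kwargs):
    # Print the current epoch and the value of the "parameter" entry.
    # Single quotes inside the f-string avoid a SyntaxError before Python 3.12.
    print(f"Parameter after epoch {trainer._epoch}: "
          f"{trainer.state.params['parameter']}")

trainer.add_task("post_epoch", print_parameter)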