trainers.DifftreParallel.add_task
- DifftreParallel.add_task(trigger, fn_or_method)
Adds a task to perform regularly during training.
- Parameters:
trigger – The trigger at which the task is executed. Can be "pre_training", "post_training", "pre_epoch", "post_epoch", "pre_batch", or "post_batch".
fn_or_method – The function or method to be executed.
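To illustrate how a trigger name maps to a point in the training loop, here is a minimal sketch of a trigger-based task registry. It is a hypothetical stand-in, not chemtrain's implementation: the class name `TaskRegistry`, the method `run_tasks`, and the validation of trigger names are assumptions chosen for illustration. The six trigger names are assumed to be the expansion of "pre/post_training/epoch/batch".

```python
from __future__ import annotations

from collections import defaultdict
from typing import Any, Callable

# Assumed trigger names, expanded from the shorthand "pre/post_training/epoch/batch".
TRIGGERS = tuple(
    f"{when}_{stage}"
    for when in ("pre", "post")
    for stage in ("training", "epoch", "batch")
)


class TaskRegistry:
    """Hypothetical stand-in for a trainer's task bookkeeping (not chemtrain's code)."""

    def __init__(self) -> None:
        self._tasks: dict[str, list[Callable[..., Any]]] = defaultdict(list)

    def add_task(self, trigger: str, fn_or_method: Callable[..., Any]) -> None:
        # Reject unknown triggers early instead of silently never running the task.
        if trigger not in TRIGGERS:
            raise ValueError(f"Unknown trigger {trigger!r}; expected one of {TRIGGERS}")
        self._tasks[trigger].append(fn_or_method)

    def run_tasks(self, trigger: str, trainer: Any, *args: Any, **kwargs: Any) -> None:
        # Tasks registered for a trigger run in registration order, trainer first.
        for task in self._tasks[trigger]:
            task(trainer, *args, **kwargs)


# Usage: a toy loop that fires the epoch hooks (no real trainer involved).
log: list[str] = []
registry = TaskRegistry()
registry.add_task("pre_epoch", lambda trainer, epoch: log.append(f"start {epoch}"))
registry.add_task("post_epoch", lambda trainer, epoch: log.append(f"end {epoch}"))

for epoch in range(2):
    registry.run_tasks("pre_epoch", None, epoch)
    registry.run_tasks("post_epoch", None, epoch)

print(log)  # ['start 0', 'end 0', 'start 1', 'end 1']
```

Validating the trigger at registration time is a design choice of this sketch: a misspelled trigger such as "after_epoch" fails immediately rather than producing a task that never runs.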
Example
The following code adds a task printing a specific energy parameter after each epoch.
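def print_parameter(trainer, *args, **kwargs):
    # Report the current epoch and the value of the "parameter" entry.
    # Single quotes inside the f-string avoid a SyntaxError before Python 3.12.
    print(f"Parameter after epoch {trainer._epoch}: "
          f"{trainer.state.params['parameter']}")

trainer.add_task("post_epoch", print_parameter)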
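Since fn_or_method may also be a method, a bound method is a convenient way to keep state across epochs, for example to record the parameter's trajectory instead of printing it. The following is a sketch under stated assumptions: the class `ParameterHistory` is hypothetical, and the trainer is mocked with `types.SimpleNamespace`, providing only the attributes the task reads (`_epoch` and `state.params`), so the snippet runs without chemtrain. The loop calls `history.record` directly to mimic what the "post_epoch" trigger would do.

```python
from __future__ import annotations

from types import SimpleNamespace
from typing import Any


class ParameterHistory:
    """Hypothetical helper that records one parameter value per epoch."""

    def __init__(self, key: str) -> None:
        self.key = key
        self.values: list[Any] = []

    def record(self, trainer: Any, *args: Any, **kwargs: Any) -> None:
        # Same call convention as print_parameter: the trainer is the first argument.
        self.values.append(trainer.state.params[self.key])


history = ParameterHistory("parameter")
# With a real trainer, registration would be:
# trainer.add_task("post_epoch", history.record)

# Mocked trainer (assumption): only the attributes the task reads are provided.
fake_trainer = SimpleNamespace(_epoch=0, state=SimpleNamespace(params={"parameter": 1.0}))
for epoch, value in enumerate([1.0, 0.9, 0.85]):
    fake_trainer._epoch = epoch
    fake_trainer.state.params["parameter"] = value
    history.record(fake_trainer)  # stands in for the "post_epoch" hook

print(history.values)  # [1.0, 0.9, 0.85]
```

Because the recorded values live on the `ParameterHistory` instance, they remain available after training finishes, which a plain printing function does not offer.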