trainers.base.DataParallelTrainer

class DataParallelTrainer(loss_fn, model, init_params, optimizer, checkpoint_path, batch_per_device, batch_cache=1, full_checkpoint=True, penalty_fn=None, energy_fn_template=None, convergence_criterion='window_median', log_file=None, disable_shmap=False)

Trainer for parallelized MLE training based on a dataset.

This trainer implements methods for MLE training on a dataset, where parallelization is accomplished by pmapping over batched data. Because pmap requires constant batch dimensions, data with an unequal number of atoms must be padded to be compatible with this trainer.
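The padding requirement can be illustrated with a minimal NumPy sketch. The helpers `pad_positions` and `shard_for_devices` are hypothetical and not part of this trainer's API; they only show an assumed data layout in which samples are zero-padded to a common number of atoms, a boolean mask marks the real atoms, and the batch is reshaped to `(n_devices, batch_per_device, ...)` so that a pmap-style map can run over the leading device axis.

```python
import numpy as np


def pad_positions(samples, max_atoms):
    """Zero-pad a list of (n_atoms_i, 3) position arrays to (max_atoms, 3).

    Returns the stacked positions and a boolean mask that is True for real atoms.
    """
    padded = np.zeros((len(samples), max_atoms, 3))
    mask = np.zeros((len(samples), max_atoms), dtype=bool)
    for i, positions in enumerate(samples):
        n_atoms = positions.shape[0]
        if n_atoms > max_atoms:
            raise ValueError(f"sample {i} has {n_atoms} atoms, max_atoms is {max_atoms}")
        padded[i, :n_atoms] = positions
        mask[i, :n_atoms] = True
    return padded, mask


def shard_for_devices(batch, n_devices, batch_per_device):
    """Reshape a (n_devices * batch_per_device, ...) array to (n_devices, batch_per_device, ...)."""
    if batch.shape[0] != n_devices * batch_per_device:
        raise ValueError("batch size must equal n_devices * batch_per_device")
    return batch.reshape((n_devices, batch_per_device) + batch.shape[1:])


# Four molecules with different atom counts, split over 2 hypothetical devices.
rng = np.random.default_rng(0)
samples = [rng.normal(size=(n, 3)) for n in (2, 3, 5, 4)]
padded, mask = pad_positions(samples, max_atoms=5)
sharded = shard_for_devices(padded, n_devices=2, batch_per_device=2)
sharded_mask = shard_for_devices(mask, n_devices=2, batch_per_device=2)
print(sharded.shape)  # (2, 2, 5, 3)
```

Any per-atom quantity entering the loss would then be multiplied by the mask so that the padded atoms do not contribute.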

Methods

__init__(loss_fn, model, init_params, ...[, ...])

A reference energy_fn_template can be provided but is not mandatory, because the template depends on the box through the displacement function.

add_task(trigger, fn_or_method)

Adds a task to perform regularly during training.

checkpoint(name, object)

Marks an attribute to be saved in a partial checkpoint.

evaluate([stage, loss_fn, params])

Computes the loss on the whole dataset.

limit_batches_per_epoch([max_batches])

Limits the number of batches per epoch.

load_energy_params(file_path)

Loads energy parameters.

move_to_device()

Transforms all arrays of the trainer state to JAX arrays.

predict(dataset[, params, batch_size])

Computes predictions for a dataset.

print_training_tasks()

Prints the tasks performed by the trainer.

reset_convergence_losses()

Resets early stopping convergence monitoring.

restore(checkpoint)

Restores the trainer from a checkpoint.

save_energy_params(file_path[, save_format, ...])

Saves energy parameters.

save_trainer(save_path[, format])

Saves the whole trainer, e.g., for production after training.

set_batches_per_epoch([stage, max_batches])

Limits the number of updates within an epoch.

set_dataset(dataset[, stage, shuffle, ...])

Sets the dataset for a single stage, e.g., training.

set_datasets(dataset[, train_ratio, ...])

Sets the datasets for training, testing and validation.

set_loader(data_loader[, stage, ...])

Sets a data loader for a specific stage, e.g., training.

train(max_epochs[, thresh, checkpoint_freq])

Trains for a maximum number of epochs, checkpoints after a specified number of epochs and ends training if a convergence criterion is met.

update_with_samples(**sample)

A single params update step, where a batch is taken from the training set and samples of the batch are substituted by the provided samples.
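The substitution described for `update_with_samples` can be sketched as follows, assuming the batch is a dictionary of arrays whose leading axis indexes samples and that the provided samples overwrite the first entries of the batch. Both the helper name `substitute_samples` and the choice of which entries are replaced are assumptions for illustration; the point is that the batch size stays fixed, which keeps the shapes constant as pmap requires.

```python
import numpy as np


def substitute_samples(batch, **samples):
    """Return a copy of `batch` whose leading entries are replaced by `samples`.

    Hypothetical helper: each keyword must name a field of the batch and hold an
    array of shape (n_new, ...) with n_new <= batch size.
    """
    new_batch = {key: np.array(value, copy=True) for key, value in batch.items()}
    for key, value in samples.items():
        if key not in new_batch:
            raise KeyError(f"unknown batch field: {key}")
        value = np.asarray(value)
        n_new = value.shape[0]
        if n_new > new_batch[key].shape[0]:
            raise ValueError(f"too many samples for field {key}")
        # Overwrite the first n_new entries; the batch shape is unchanged.
        new_batch[key][:n_new] = value
    return new_batch


batch = {"R": np.zeros((4, 2, 3)), "U": np.zeros(4)}
updated = substitute_samples(batch, R=np.ones((1, 2, 3)), U=np.array([5.0]))
print(updated["U"])  # [5. 0. 0. 0.]
```

Copying the batch before overwriting keeps the original training data untouched, so the same batch can be reused in later epochs.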

Attributes

best_inference_params

Returns the best model params irrespective of whether dropout is used.

best_inference_params_replicated

Returns the best inference params replicated on every device.
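Replicating parameters for pmap-style execution amounts to adding a leading device axis to every array leaf, so that each device receives an identical copy. The NumPy sketch below is hypothetical: it uses nested dictionaries as a stand-in for a parameter pytree and a plain integer for the device count, whereas in JAX the leading axis would be distributed over the available devices.

```python
import numpy as np


def replicate(tree, n_devices):
    """Stack n_devices copies of every leaf along a new leading axis."""
    if isinstance(tree, dict):
        return {key: replicate(value, n_devices) for key, value in tree.items()}
    leaf = np.asarray(tree)
    return np.stack([leaf] * n_devices)


def unreplicate(tree):
    """Take the copy belonging to the first device from every leaf."""
    if isinstance(tree, dict):
        return {key: unreplicate(value) for key, value in tree.items()}
    return np.asarray(tree)[0]


params = {"dense": {"w": np.arange(6.0).reshape(2, 3), "b": np.zeros(3)}, "scale": 0.5}
replicated = replicate(params, n_devices=4)
print(replicated["dense"]["w"].shape)  # (4, 2, 3)
```

Scalar leaves such as `scale` become arrays of shape `(n_devices,)`, and `unreplicate` recovers the original structure from the first device's copy.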

best_params

Returns the best parameters based on the validation loss.

If training was performed with early stopping, returns the best parameters according to this criterion instead.
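One way to combine best-parameter tracking with early stopping is sketched below in plain Python. The window-median rule is an assumed interpretation of the constructor's `convergence_criterion='window_median'` option, not its actual implementation: training is treated as converged once the median validation loss of the latest window improves on the median of the preceding window by less than a relative threshold `thresh`. The class name `EarlyStopping`, the window length, and the exact rule are hypothetical.

```python
import copy
from statistics import median


class EarlyStopping:
    """Tracks the best parameters and checks an assumed window-median criterion."""

    def __init__(self, thresh, window=5):
        self.thresh = thresh
        self.window = window
        self.losses = []
        self.best_loss = float("inf")
        self.best_params = None

    def update(self, params, val_loss):
        """Record a validation loss; return True once training is considered converged."""
        self.losses.append(val_loss)
        if val_loss < self.best_loss:
            # Keep an independent copy so later in-place updates do not alter it.
            self.best_loss = val_loss
            self.best_params = copy.deepcopy(params)
        if len(self.losses) < 2 * self.window:
            return False
        previous = median(self.losses[-2 * self.window:-self.window])
        current = median(self.losses[-self.window:])
        improvement = (previous - current) / max(abs(previous), 1e-12)
        return improvement < self.thresh


# Validation losses that first decrease and then plateau.
val_losses = [1.0, 0.8, 0.6, 0.5, 0.45, 0.44, 0.44, 0.45, 0.44, 0.45, 0.45, 0.44]
stopper = EarlyStopping(thresh=0.05, window=3)
for epoch, loss in enumerate(val_losses):
    params = {"w": epoch}
    if stopper.update(params, loss):
        break
print(epoch, stopper.best_params)  # 8 {'w': 5}
```

Here the loop stops at epoch 8, while the returned parameters are those of epoch 5, where the lowest validation loss was first reached. Using medians over windows makes the stopping decision less sensitive to single noisy epochs.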

energy_fn

Returns the energy function for the current parameters.
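The relation between the template, the parameters, and the box can be sketched in NumPy, assuming the energy function is obtained by applying the template to a set of parameters. The box is fixed when the displacement function, and therefore the template, is built. All names below (`make_energy_fn_template`, the harmonic pair potential, and the parameter keys `k` and `r0`) are hypothetical and only illustrate the pattern; they are not part of this trainer's API.

```python
import numpy as np


def make_periodic_displacement(box):
    """Minimum-image displacement in an orthorhombic box with side lengths `box`."""
    box = np.asarray(box, dtype=float)

    def displacement(r_i, r_j):
        dr = r_i - r_j
        return dr - box * np.round(dr / box)

    return displacement


def make_energy_fn_template(box):
    """Hypothetical template: params -> energy_fn(positions) for a fixed box."""
    displacement = make_periodic_displacement(box)

    def energy_fn_template(params):
        def energy_fn(positions):
            # Harmonic pair potential summed over all pairs i < j.
            energy = 0.0
            for i in range(len(positions)):
                for j in range(i + 1, len(positions)):
                    r = np.linalg.norm(displacement(positions[i], positions[j]))
                    energy += 0.5 * params["k"] * (r - params["r0"]) ** 2
            return energy

        return energy_fn

    return energy_fn_template


params = {"k": 1.0, "r0": 1.0}
positions = np.array([[0.5, 0.0, 0.0], [9.5, 0.0, 0.0]])
energy_small_box = make_energy_fn_template([10.0, 10.0, 10.0])(params)(positions)
energy_large_box = make_energy_fn_template([20.0, 20.0, 20.0])(params)(positions)
print(energy_small_box, energy_large_box)  # 0.0 32.0
```

The same positions and parameters give different energies in the two boxes because the minimum-image distance changes, which is why a template built for one box cannot simply be reused for another.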

params

Current energy parameters.