trainers.PropertyPrediction

class PropertyPrediction(error_fn, prediction_model, init_params, optimizer, graph_dataset, targets, batch_per_device=1, batch_cache=10, train_ratio=0.7, val_ratio=0.1, test_error_fn=None, shuffle=False, convergence_criterion='window_median', checkpoint_folder='Checkpoints')

Trainer for direct prediction of molecular properties.

Methods

__init__(error_fn, prediction_model, ...[, ...])

A reference energy_fn_template can be provided, but it is not mandatory because the template depends on the box via the displacement function.

add_task(trigger, fn_or_method)

Adds a task to perform regularly during training.
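The accepted values of `trigger` are not documented on this page. The following minimal sketch assumes a trigger is a string event name (the name `"post_epoch"` is hypothetical) and shows how a registry of regularly performed tasks could work in general. `TaskRegistry` and `fire` are illustrative names, not part of the trainer's API.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List


class TaskRegistry:
    """Hypothetical registry mapping trigger names to callbacks."""

    def __init__(self) -> None:
        self._tasks: DefaultDict[str, List[Callable[[], None]]] = defaultdict(list)

    def add_task(self, trigger: str, fn: Callable[[], None]) -> None:
        # Register fn to run every time `trigger` fires.
        self._tasks[trigger].append(fn)

    def fire(self, trigger: str) -> None:
        # Run all callbacks registered for this trigger, in insertion order.
        for fn in self._tasks[trigger]:
            fn()


log: List[str] = []
registry = TaskRegistry()
registry.add_task("post_epoch", lambda: log.append("saved"))

for epoch in range(3):
    # ... one epoch of training would happen here ...
    registry.fire("post_epoch")

print(log)  # ['saved', 'saved', 'saved']
```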

checkpoint(name, object)

Marks an attribute to be saved in a partial checkpoint.
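As a rough illustration of a partial checkpoint, the sketch below assumes that `checkpoint(name, object)` both stores the object under `name` and records that name, so that only marked attributes are serialized. `CheckpointableSketch`, `partial_state`, and the use of `pickle` are assumptions for illustration, not the trainer's actual storage format.

```python
import pickle
from typing import Any, Dict, List


class CheckpointableSketch:
    """Hypothetical object that saves only explicitly marked attributes."""

    def __init__(self) -> None:
        self._checkpoint_names: List[str] = []

    def checkpoint(self, name: str, obj: Any) -> None:
        # Store the object as an attribute and mark it for partial checkpoints.
        setattr(self, name, obj)
        if name not in self._checkpoint_names:
            self._checkpoint_names.append(name)

    def partial_state(self) -> bytes:
        # Serialize only the marked attributes.
        state: Dict[str, Any] = {n: getattr(self, n) for n in self._checkpoint_names}
        return pickle.dumps(state)


trainer = CheckpointableSketch()
trainer.checkpoint("epoch", 5)
trainer.scratch = [1, 2, 3]  # not marked, so it is not saved
restored = pickle.loads(trainer.partial_state())
print(restored)  # {'epoch': 5}
```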

evaluate([stage, loss_fn, params])

Computes the loss on the whole dataset.
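One common way to compute a loss over a whole dataset that does not fit into a single batch is to accumulate per-batch losses weighted by batch size, so that a shorter final batch does not bias the average. The sketch below assumes a mean-type loss; `evaluate_whole_dataset` and `mse` are hypothetical helpers, not the trainer's implementation.

```python
from typing import Callable, Sequence


def evaluate_whole_dataset(
    loss_fn: Callable[[Sequence[float], Sequence[float]], float],
    predictions: Sequence[float],
    targets: Sequence[float],
    batch_size: int,
) -> float:
    """Hypothetical: mean loss over all samples, accumulated batch by batch."""
    total, count = 0.0, 0
    for start in range(0, len(targets), batch_size):
        pred_b = predictions[start:start + batch_size]
        tgt_b = targets[start:start + batch_size]
        # Weight each batch's mean loss by its size so a short final batch
        # does not bias the dataset-wide average.
        total += loss_fn(pred_b, tgt_b) * len(tgt_b)
        count += len(tgt_b)
    return total / count


def mse(pred: Sequence[float], tgt: Sequence[float]) -> float:
    # Mean squared error of one batch.
    return sum((p - t) ** 2 for p, t in zip(pred, tgt)) / len(tgt)


preds = [1.0, 2.0, 3.0, 4.0, 5.0]
tgts = [1.0, 2.0, 3.0, 4.0, 7.0]
print(evaluate_whole_dataset(mse, preds, tgts, batch_size=2))  # 0.8
```

With this weighting, the result matches the loss computed on the full dataset at once, regardless of the batch size.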

evaluate_testset_error([best_params])

limit_batches_per_epoch([max_batches])

Limits the number of batches per epoch.
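Capping the number of batches per epoch amounts to stopping iteration over the loader early. A minimal sketch, assuming batches come from an arbitrary iterable and that `None` means "no limit" (an assumption about the default of `max_batches`); `limited_epoch` is a hypothetical helper built on `itertools.islice`.

```python
import itertools
from typing import Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")


def limited_epoch(batches: Iterable[T], max_batches: Optional[int] = None) -> Iterator[T]:
    """Hypothetical: yield at most `max_batches` batches (None yields all)."""
    # islice with stop=None consumes the whole iterable.
    return itertools.islice(batches, max_batches)


data = [[0, 1], [2, 3], [4, 5], [6, 7]]
print(list(limited_epoch(data, 2)))  # [[0, 1], [2, 3]]
```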

load_energy_params(file_path)

Loads energy parameters.

move_to_device()

Transforms all arrays of the trainer state to JAX arrays.

predict(single_observation)

Predicts properties for a single input graph using the current parameters.

print_training_tasks()

Prints the tasks performed by the trainer.

reset_convergence_losses()

Resets early stopping convergence monitoring.

restore(checkpoint)

Restores the trainer from a checkpoint.

save_energy_params(file_path[, save_format, ...])

Saves energy parameters.

save_trainer(save_path[, format])

Saves the whole trainer, e.g., for production use after training.

set_batches_per_epoch([stage, max_batches])

Limits the number of updates within an epoch.

set_dataset(dataset[, stage, shuffle, ...])

Sets the dataset for a single stage, e.g., training.

set_datasets(dataset[, train_ratio, ...])

Sets the datasets for training, testing and validation.
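The class defaults `train_ratio=0.7` and `val_ratio=0.1` suggest that the remaining fraction (0.2) forms the test set. The sketch below shows such a ratio-based split under that assumption, with optional seeded shuffling; `split_dataset` and its `seed` argument are hypothetical and not the trainer's own splitting code.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_dataset(
    samples: Sequence[T],
    train_ratio: float = 0.7,
    val_ratio: float = 0.1,
    shuffle: bool = False,
    seed: int = 0,
) -> Tuple[List[T], List[T], List[T]]:
    """Hypothetical train/val/test split; the test set takes the remainder."""
    if train_ratio + val_ratio > 1.0:
        raise ValueError("train_ratio + val_ratio must not exceed 1")
    items = list(samples)
    if shuffle:
        # Seeded shuffle so the split is reproducible.
        random.Random(seed).shuffle(items)
    n_train = int(len(items) * train_ratio)
    n_val = int(len(items) * val_ratio)
    train = items[:n_train]
    val = items[n_train:n_train + n_val]
    test = items[n_train + n_val:]
    return train, val, test


train, val, test = split_dataset(range(100))
print(len(train), len(val), len(test))  # 70 10 20
```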

set_loader(data_loader[, stage, ...])

Sets a data loader for a specific stage, e.g., training.

train(max_epochs[, thresh, checkpoint_freq])

Trains for a maximum number of epochs, checkpoints after a specified number of epochs, and ends training if a convergence criterion is met.
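The default `convergence_criterion='window_median'` is not specified further on this page. As a hedged sketch of what a window-median criterion can look like, the function below compares the median validation loss of the latest window of epochs with that of the preceding window and reports convergence once the relative improvement drops below `thresh`. The window size, the relative-improvement rule, and the mapping to the `thresh` argument of `train` are all assumptions.

```python
import statistics
from typing import Sequence


def window_median_converged(
    losses: Sequence[float], window: int = 5, thresh: float = 1e-3
) -> bool:
    """Hypothetical window-median early-stopping check.

    Compares the median loss of the latest `window` epochs with the median of
    the `window` epochs before them and reports convergence once the relative
    improvement drops below `thresh`.
    """
    if len(losses) < 2 * window:
        return False  # not enough history to form two windows yet
    history = list(losses)
    previous = statistics.median(history[-2 * window:-window])
    current = statistics.median(history[-window:])
    # Guard against division by zero when the previous median is exactly 0.
    improvement = (previous - current) / max(abs(previous), 1e-12)
    return improvement < thresh


val_losses = [1.0, 0.8, 0.6, 0.5, 0.45, 0.44, 0.44, 0.44, 0.44, 0.44]
print(window_median_converged(val_losses))  # False
```

Using medians rather than the latest loss makes the check robust to single noisy epochs.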

update_with_samples(**sample)

Performs a single parameter update step, in which a batch is taken from the training set and samples of the batch are substituted by the provided samples.
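The substitution step can be illustrated on a batch stored as a dictionary of per-field lists. Which entries of the batch are replaced is not stated here; the sketch assumes the leading entries are overwritten and the rest of the batch is kept. `substitute_samples` and the field name `"energy"` are hypothetical, and the actual parameter update is omitted.

```python
from typing import Dict, List


def substitute_samples(
    batch: Dict[str, List[float]], **samples: List[float]
) -> Dict[str, List[float]]:
    """Hypothetical: overwrite the leading entries of a batch with given samples."""
    # Copy so the original batch is left unchanged.
    new_batch = {key: list(values) for key, values in batch.items()}
    for key, values in samples.items():
        if key not in new_batch:
            raise KeyError(f"unknown batch field: {key}")
        if len(values) > len(new_batch[key]):
            raise ValueError(f"too many samples for field {key}")
        # Replace the first len(values) entries; the rest of the batch is kept.
        new_batch[key][: len(values)] = values
    return new_batch


batch = {"energy": [0.1, 0.2, 0.3, 0.4]}
result = substitute_samples(batch, energy=[9.0, 8.0])
print(result)  # {'energy': [9.0, 8.0, 0.3, 0.4]}
```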

Attributes

best_inference_params

Returns the best model parameters, irrespective of whether dropout is used.

best_inference_params_replicated

Returns the best inference params replicated on every device.

best_params

Returns the best parameters based on the validation loss.

If training was performed with early stopping, returns the best parameters according to this criterion instead.
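Selecting the best parameters by validation loss usually means keeping a copy of the parameters whenever a new minimum is observed. The sketch below is a generic illustration under that assumption; `BestParamsTracker` is hypothetical, and the dictionary `{"w": epoch}` merely stands in for a real parameter tree.

```python
import copy
import math
from typing import Any, Optional


class BestParamsTracker:
    """Hypothetical helper that keeps the params with the lowest validation loss."""

    def __init__(self) -> None:
        self.best_loss = math.inf
        self.best_params: Optional[Any] = None

    def update(self, params: Any, val_loss: float) -> bool:
        # Keep a deep copy so later in-place updates do not alter the stored best.
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_params = copy.deepcopy(params)
            return True
        return False


tracker = BestParamsTracker()
for epoch, val_loss in enumerate([0.9, 0.5, 0.7, 0.6]):
    tracker.update({"w": epoch}, val_loss)
print(tracker.best_params, tracker.best_loss)  # {'w': 1} 0.5
```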

energy_fn

Returns the energy function for the current parameters.

params

Current energy parameters.