trainers.PropertyPrediction
- class PropertyPrediction(error_fn, prediction_model, init_params, optimizer, graph_dataset, targets, batch_per_device=1, batch_cache=10, train_ratio=0.7, val_ratio=0.1, test_error_fn=None, shuffle=False, convergence_criterion='window_median', checkpoint_folder='Checkpoints')
Trainer for direct prediction of molecular properties.
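The `train_ratio=0.7` and `val_ratio=0.1` defaults in the signature suggest a three-way split of the dataset. The following is a minimal sketch of such a split, assuming the remaining fraction (0.2 with the defaults) forms the test set and that samples are taken in order unless `shuffle` is set. `split_indices` and its `seed` argument are hypothetical and not part of the trainer's API.

```python
import random


def split_indices(n_samples, train_ratio=0.7, val_ratio=0.1, shuffle=False, seed=0):
    """Hypothetical helper: partition sample indices into train/val/test.

    Assumption: the test set receives whatever is left after the
    training and validation fractions have been taken.
    """
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1.0:
        raise ValueError("ratios must be non-negative and sum to at most 1")

    indices = list(range(n_samples))
    if shuffle:
        # Seeded shuffle so that the split is reproducible.
        random.Random(seed).shuffle(indices)

    # round() avoids losing a sample to floating-point truncation.
    n_train = round(n_samples * train_ratio)
    n_val = round(n_samples * val_ratio)

    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test
```

With 100 samples and the default ratios, this sketch yields 70 training, 10 validation, and 20 test indices.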
Methods
__init__(error_fn, prediction_model, ...[, ...]): A reference energy_fn_template can be provided, but is not mandatory due to the dependence of the template on the box via the displacement function.
add_task(trigger, fn_or_method): Adds a task to perform regularly during training.
checkpoint(name, object): Marks an attribute to be saved in a partial checkpoint.
evaluate([stage, loss_fn, params]): Computes the loss on the whole dataset.
evaluate_testset_error([best_params])
limit_batches_per_epoch([max_batches]): Limits the number of batches per epoch.
load_energy_params(file_path): Loads energy parameters.
Transforms all arrays of the trainer state to JAX arrays.
predict(single_observation): Prediction for a single input graph using the current param state.
Prints the tasks performed by the trainer.
Resets early stopping convergence monitoring.
restore(checkpoint): Restores the trainer from a checkpoint.
save_energy_params(file_path[, save_format, ...]): Saves energy parameters.
save_trainer(save_path[, format]): Saves the whole trainer, e.g. for production after training.
set_batches_per_epoch([stage, max_batches]): Limits the number of updates within an epoch.
set_dataset(dataset[, stage, shuffle, ...]): Sets the dataset for a single stage, e.g., training.
set_datasets(dataset[, train_ratio, ...]): Sets the datasets for training, testing and validation.
set_loader(data_loader[, stage, ...]): Sets a data loader for a specific stage, e.g., training.
train(max_epochs[, thresh, checkpoint_freq]): Trains for a maximum number of epochs, checkpoints after a specified number of epochs, and ends training if a convergence criterion is met.
update_with_samples(**sample): A single params update step, where a batch is taken from the training set and samples of the batch are substituted by the provided samples.
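The `train` entry describes a loop that runs for at most `max_epochs`, writes a checkpoint after a specified number of epochs, and stops early once a convergence criterion is met. The sketch below illustrates that control flow in plain Python. It is not the trainer's implementation: `run_epoch`, `save_checkpoint`, and `window_median_converged` are hypothetical stand-ins. The reading of `'window_median'` used here is also an assumption: it compares the median validation loss of the last `thresh` epochs against that of the preceding `thresh` epochs.

```python
from statistics import median


def window_median_converged(val_losses, thresh):
    """Assumed criterion: the latest window's median did not improve."""
    if len(val_losses) < 2 * thresh:
        return False  # not enough history to compare two windows
    previous = median(val_losses[-2 * thresh:-thresh])
    latest = median(val_losses[-thresh:])
    return latest >= previous


def train(run_epoch, save_checkpoint, max_epochs, thresh=None, checkpoint_freq=None):
    """Hypothetical loop; run_epoch(epoch) must return a validation loss."""
    val_losses = []
    for epoch in range(max_epochs):
        val_losses.append(run_epoch(epoch))

        # Checkpoint every checkpoint_freq epochs, if requested.
        if checkpoint_freq is not None and (epoch + 1) % checkpoint_freq == 0:
            save_checkpoint(epoch)

        # Stop early once the assumed convergence criterion is met.
        if thresh is not None and window_median_converged(val_losses, thresh):
            break
    return val_losses
```

Comparing medians rather than single losses makes the stopping decision less sensitive to one noisy epoch, which is one plausible motivation for a window-based criterion.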
Attributes
- best_inference_params
Returns the best model params, irrespective of whether dropout is used.
- best_inference_params_replicated
Returns the best inference params replicated on every device.
- best_params
Returns the best parameters based on the validation loss.
If training was performed with early stopping, returns the best parameters according to that criterion instead.
- energy_fn
Returns the energy function for the current parameters.
- params
Current energy parameters.
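`best_params` is described as the parameters with the best validation loss. A minimal sketch of that bookkeeping follows, assuming one validation loss is recorded per epoch and that lower is better. `BestParamsTracker` is a hypothetical name used for illustration only; it is not a class of the library.

```python
import copy


class BestParamsTracker:
    """Hypothetical bookkeeping: keep the params with the lowest validation loss."""

    def __init__(self):
        self.best_loss = float("inf")
        self.best_params = None

    def update(self, params, val_loss):
        # Keep a copy so that later in-place updates of `params`
        # cannot change the stored best parameters.
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_params = copy.deepcopy(params)
        return self.best_params
```

Calling `update` once per epoch with the current parameters and validation loss leaves `best_params` pointing at the best epoch seen so far, even if later epochs overfit.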