learn.max_likelihood#
A collection of functions to facilitate learning maximum likelihood / single point estimate models.
Loss Functions#
These functions are masked implementations of common loss functions.
- mse_loss(predictions, targets, mask=None, weights=None)[source]#
Computes mean squared error loss for given predictions and targets.
- Parameters:
predictions – Array of predictions
targets – Array of corresponding targets. Must have the same shape as predictions.
mask – Mask that excludes the contribution of selected array elements. Must have the same shape as predictions. The default None applies no mask.
- Returns:
Mean squared error loss value.
- mae_loss(predictions, targets, mask=None, weights=None)[source]#
Computes the mean absolute error for given predictions and targets.
- Parameters:
predictions – Array of predictions
targets – Array of corresponding targets. Must have the same shape as predictions.
mask – Mask that excludes the contribution of selected array elements. Must have the same shape as predictions. The default None applies no mask.
- Returns:
Mean absolute error value.
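To illustrate what a masked loss does, here is a minimal sketch in JAX. It is not the chemtrain implementation: the helper names `masked_mean`, `masked_mse`, and `masked_mae` are hypothetical, the normalization by the number of unmasked elements is an assumption, and the `weights` argument is left out.

```python
import jax.numpy as jnp


def masked_mean(values, mask=None):
    """Mean of `values` over the entries where `mask` is nonzero."""
    if mask is None:
        mask = jnp.ones_like(values)
    # Assumption: normalize by the number of unmasked entries.
    # The maximum guards against division by zero for an all-zero mask.
    return jnp.sum(values * mask) / jnp.maximum(jnp.sum(mask), 1.0)


def masked_mse(predictions, targets, mask=None):
    return masked_mean(jnp.square(targets - predictions), mask)


def masked_mae(predictions, targets, mask=None):
    return masked_mean(jnp.abs(targets - predictions), mask)


predictions = jnp.array([1.0, 2.0, 5.0])
targets = jnp.array([1.0, 4.0, 0.0])
mask = jnp.array([1.0, 1.0, 0.0])  # third entry is ignored, e.g. padding

mse = masked_mse(predictions, targets, mask)  # (0 + 4) / 2
mae = masked_mae(predictions, targets, mask)  # (0 + 2) / 2
```

Masks of this kind are typically useful when samples are padded to a common shape, so that padded entries do not contribute to the loss.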
Dataset Predictions#
Algorithms such as force matching require evaluating the loss function on many samples instead of a single snapshot. Therefore, chemtrain provides functions that evaluate these samples efficiently by vectorizing over each batch and parallelizing across devices.
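As a rough illustration of the vectorization step, the following sketch turns a single-snapshot model into a batched model with the signature model(params, batch) by using `jax.vmap`. The model `single_model` and its parameters are hypothetical and only serve as an example; how chemtrain builds its batched models may differ.

```python
import jax
import jax.numpy as jnp


def single_model(params, position):
    """Hypothetical single-snapshot model: a linear map of one input vector."""
    return jnp.dot(params["w"], position) + params["b"]


# Vectorize over the leading axis of the inputs; params are shared (None).
batched_model = jax.vmap(single_model, in_axes=(None, 0))

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
batch = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

predictions = batched_model(params, batch)  # one prediction per snapshot
```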
- pmap_update_fn(batched_model, loss_fn, optimizer, penalty_fn=None)[source]#
Initializes a pmapped function for updating parameters.
- Usage:
params, opt_state, loss, grad = update_fn(params, opt_state, batch)
loss and grad are returned as a single instance, not replicated per device. params and opt_state must be replicated N_devices times along axis 0. The batch is reshaped by this function.
- Parameters:
batched_model – A model with signature model(params, batch), which predicts a batch of outputs used in the loss function.
loss_fn – Loss function(predictions, targets) returning the scalar loss value for a batch.
optimizer – Optax optimizer
penalty_fn – A penalty function based on the model parameters.
- Returns:
A function that computes the gradient and updates the parameters via the optimizer.
- shmap_update_fn(batched_model, loss_fn, optimizer, penalty_fn=None)[source]#
Initializes a shmapped function for updating parameters.
- Usage:
params, opt_state, loss, grad = update_fn(params, opt_state, batch)
- Parameters:
batched_model – A model with signature model(params, batch), which predicts a batch of outputs used in the loss function.
loss_fn – Loss function(predictions, targets) returning the scalar loss value for a batch.
optimizer – Optax optimizer
penalty_fn – A penalty function based on the model parameters.
- Returns:
A function that computes the gradient and updates the parameters via the optimizer.
- shmap_loss_fn(batched_model, loss_fn, penalty_fn=None)[source]#
Initializes a shmapped function for computing a loss.
- Usage:
loss, per_target_losses = loss_fn(params, batch, per_target=True)
- Parameters:
batched_model – A model with signature model(params, batch), which predicts a batch of outputs used in the loss function.
loss_fn – Loss function(predictions, targets) returning the scalar loss value for a batch.
penalty_fn – A penalty function based on the model parameters.
- Returns:
A function that computes the total loss and per-target loss contributions.
- init_val_predictions(batched_model, val_loader, batch_size=1, batch_cache=10)[source]#
Model predictions for whole validation/test dataset.
- Usage:
predictions, data_state = mapped_model_fn(params, data_state)
params must be replicated N_devices times along axis 0.
- Parameters:
batched_model – A model with signature model(params, batch), which predicts a batch of outputs used in the loss function.
val_loader – Validation or test set NumpyDataLoader.
batch_size – Total batch size that is processed in parallel.
batch_cache – Number of batches to cache.
- Returns:
A function that returns the tuple (predictions, data_state). predictions contains the model predictions for the whole validation dataset, and data_state is used to start the data loading in the next evaluation.
- init_val_loss_fn(model, loss_fn, val_loader, val_targets_keys=None, batch_size=1, batch_cache=100)[source]#
Initializes a pmapped loss function that computes the validation loss.
- Usage:
val_loss, data_state = batched_loss_fn(params, data_state)
params must be replicated N_devices times along axis 0.
- Parameters:
model – A model with signature model(params, batch), which predicts outputs used in the loss function.
loss_fn – Loss function(predictions, targets) returning the scalar loss value for a batch.
val_loader – NumpyDataLoader for validation set.
val_targets_keys – Dict containing the targets of the whole validation set.
batch_size – Total batch size that is processed in parallel.
batch_cache – Number of batches to cache on GPU to reduce host-device communication.
- Returns:
A pmapped function that returns the average validation loss.
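The idea of averaging a loss over a whole validation set in batches can be sketched as below. This is a single-device illustration under stated assumptions, not the behavior of init_val_loss_fn: it replaces the NumpyDataLoader and the data_state with plain array slicing, and the names `batch_loss` and `validation_loss` as well as the toy model are hypothetical.

```python
import jax
import jax.numpy as jnp


def model(params, inputs):
    # Hypothetical model: scales every input by a single parameter.
    return inputs * params["scale"]


def loss_fn(predictions, targets):
    return jnp.mean(jnp.square(predictions - targets))


@jax.jit
def batch_loss(params, inputs, targets):
    return loss_fn(model(params, inputs), targets)


def validation_loss(params, inputs, targets, batch_size):
    """Average the per-batch loss over equally sized batches."""
    n_batches = inputs.shape[0] // batch_size
    losses = []
    for i in range(n_batches):
        window = slice(i * batch_size, (i + 1) * batch_size)
        losses.append(batch_loss(params, inputs[window], targets[window]))
    return jnp.mean(jnp.stack(losses))


params = {"scale": jnp.array(2.0)}
inputs = jnp.arange(8.0)
targets = 2.0 * inputs + 1.0  # every prediction is off by exactly 1

val_loss = validation_loss(params, inputs, targets, batch_size=4)
```

Because all batches have the same size here, the mean of the per-batch losses equals the loss over the full set; with a ragged final batch, the per-batch losses would need to be weighted by batch size.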