learn.max_likelihood#

A collection of functions to facilitate learning maximum likelihood / single point estimate models.

Loss Functions#

These functions are masked implementations of common loss functions.

mse_loss(predictions, targets, mask=None, weights=None)[source]#

Computes mean squared error loss for given predictions and targets.

Parameters:
  • predictions – Array of predictions

  • targets – Array of respective targets. Needs to have same shape as predictions.

  • mask – Mask contribution of some array elements. Needs to have same shape as predictions. Default None applies no mask.

Returns:

Mean squared error loss value.
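To illustrate what a masked loss does, here is a minimal sketch in plain JAX. It is not chemtrain's implementation: the name `masked_mse_sketch` is hypothetical, the `weights` argument is left out, and normalizing by the number of unmasked elements is an assumption.

```python
import jax.numpy as jnp


def masked_mse_sketch(predictions, targets, mask=None):
    """Illustrative masked MSE (hypothetical helper, not the library function)."""
    squared_error = (predictions - targets) ** 2
    if mask is None:
        # No mask: plain mean over all elements.
        return jnp.mean(squared_error)
    mask = jnp.asarray(mask, dtype=squared_error.dtype)
    # Assumption: average only over elements where mask == 1.
    return jnp.sum(squared_error * mask) / jnp.sum(mask)


predictions = jnp.array([1.0, 2.0, 3.0, 4.0])
targets = jnp.array([1.0, 0.0, 3.0, 0.0])
mask = jnp.array([1.0, 1.0, 0.0, 0.0])  # last two entries treated as padding

unmasked = masked_mse_sketch(predictions, targets)      # mean over 4 elements
masked = masked_mse_sketch(predictions, targets, mask)  # mean over 2 elements
```

A mask of this kind is typically useful when samples are padded to a common shape, so that padded entries do not contribute to the loss.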

mae_loss(predictions, targets, mask=None, weights=None)[source]#

Computes the mean absolute error for given predictions and targets.

Parameters:
  • predictions – Array of predictions

  • targets – Array of respective targets. Needs to have same shape as predictions.

  • mask – Mask contribution of some array elements. Needs to have same shape as predictions. Default None applies no mask.

Returns:

Mean absolute error value.

Dataset Predictions#

Algorithms such as force matching require evaluating the loss function on many samples instead of a single snapshot. chemtrain therefore provides functions that vectorize these evaluations and parallelize them across devices.

pmap_update_fn(batched_model, loss_fn, optimizer, penalty_fn=None)[source]#

Initializes a pmapped function for updating parameters.

Usage:
params, opt_state, loss, grad = update_fn(params, opt_state, batch)

loss and grad are returned as a single instance, not as one replica per device. params and opt_state must be replicated N_devices times along axis 0. The batch is reshaped by this function.

Parameters:
  • batched_model – A model with signature model(params, batch), which predicts a batch of outputs used in loss function.

  • loss_fn – Loss function(predictions, targets) returning the scalar loss value for a batch.

  • optimizer – Optax optimizer

  • penalty_fn – A penalty function based on the model parameters.

Returns:

A function that computes the gradient and updates the parameters via the optimizer.
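The data layout described above (replicated parameters, a batch split across devices, gradients averaged over devices) can be sketched with `jax.pmap` alone. This is a hedged illustration, not the library's code: `linear_model`, `batch_loss` and `update_step` are hypothetical names, and a plain gradient descent step stands in for the Optax optimizer, so there is no `opt_state`.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.01  # hypothetical value for the stand-in optimizer


def linear_model(params, inputs):
    return inputs @ params["w"] + params["b"]


def batch_loss(params, batch):
    predictions = linear_model(params, batch["x"])
    return jnp.mean((predictions - batch["y"]) ** 2)


def update_step(params, batch):
    loss, grad = jax.value_and_grad(batch_loss)(params, batch)
    # Average loss and gradients over devices so every replica
    # applies the same update and the parameters stay in sync.
    loss = jax.lax.pmean(loss, axis_name="devices")
    grad = jax.lax.pmean(grad, axis_name="devices")
    params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grad)
    return params, loss, grad


pmapped_update = jax.pmap(update_step, axis_name="devices")

n_devices = jax.local_device_count()
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
# Replicate the parameters N_devices times along axis 0.
replicated_params = jax.tree_util.tree_map(
    lambda p: jnp.stack([p] * n_devices), params)

# Build a toy batch and reshape it to (n_devices, per_device_batch, ...).
x = jax.random.normal(jax.random.PRNGKey(0), (4 * n_devices, 3))
batch = {"x": x, "y": x @ jnp.array([1.0, -2.0, 0.5])}
sharded_batch = jax.tree_util.tree_map(
    lambda a: a.reshape((n_devices, -1) + a.shape[1:]), batch)

replicated_params, loss, grad = pmapped_update(replicated_params, sharded_batch)
single_loss = loss[0]  # identical on all devices after pmean
```

In this sketch the loss and gradient still carry a leading device axis and the first entry is selected by hand; according to the description above, the function returned by `pmap_update_fn` already returns a single instance and reshapes the batch itself.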

shmap_update_fn(batched_model, loss_fn, optimizer, penalty_fn=None)[source]#

Initializes a shmapped function for updating parameters.

Usage:
params, opt_state, loss, grad = update_fn(params, opt_state, batch)

Parameters:
  • batched_model – A model with signature model(params, batch), which predicts a batch of outputs used in loss function.

  • loss_fn – Loss function(predictions, targets) returning the scalar loss value for a batch.

  • optimizer – Optax optimizer

  • penalty_fn – A penalty function based on the model parameters.

Returns:

A function that computes the gradient and updates the parameters via the optimizer.

shmap_loss_fn(batched_model, loss_fn, penalty_fn=None)[source]#

Initializes a shmapped function for computing a loss.

Usage:
loss, per_target_losses = loss_fn(params, batch, per_target=True)

Parameters:
  • batched_model – A model with signature model(params, batch), which predicts a batch of outputs used in loss function.

  • loss_fn – Loss function(predictions, targets) returning the scalar loss value for a batch.

  • penalty_fn – A penalty function based on the model parameters.

Returns:

A function that computes the total loss and per-target loss contributions.
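The relation between a total loss and its per-target contributions can be sketched as follows. This is an illustration under stated assumptions, not chemtrain's implementation: the helper names, the target keys `"U"` and `"F"`, the unweighted sum over targets, and the additive penalty are all hypothetical, and no device sharding is shown.

```python
import jax.numpy as jnp


def mse(predictions, targets):
    return jnp.mean((predictions - targets) ** 2)


def total_and_per_target_loss(params, batch, model, penalty_fn=None):
    """Hypothetical helper: returns (total loss, dict of per-target losses)."""
    predictions = model(params, batch)
    per_target = {key: mse(predictions[key], batch[key]) for key in predictions}
    # Assumption: the total loss is the unweighted sum over targets,
    # plus an optional penalty that depends only on the parameters.
    total = sum(per_target.values())
    if penalty_fn is not None:
        total = total + penalty_fn(params)
    return total, per_target


def toy_model(params, batch):
    # Toy harmonic energy U = scale * sum(R^2) and its force F = -dU/dR.
    energy = params["scale"] * jnp.sum(batch["R"] ** 2, axis=(1, 2))
    force = -2.0 * params["scale"] * batch["R"]
    return {"U": energy, "F": force}


def l2_penalty(params):
    return 1e-2 * params["scale"] ** 2


params = {"scale": jnp.array(1.0)}
batch = {
    "R": jnp.ones((2, 2, 3)),    # 2 samples, 2 particles, 3 dimensions
    "U": jnp.array([5.0, 5.0]),  # energy targets
    "F": -jnp.ones((2, 2, 3)),   # force targets
}

total, per_target = total_and_per_target_loss(
    params, batch, toy_model, penalty_fn=l2_penalty)
```

Reporting the per-target values separately makes it easier to see which target dominates the total during training.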

init_val_predictions(batched_model, val_loader, batch_size=1, batch_cache=10)[source]#

Initializes a function that computes model predictions for the whole validation/test dataset.

Usage:
predictions, data_state = mapped_model_fn(params, data_state)

params must be replicated N_devices times along axis 0.

Parameters:
  • batched_model – A model with signature model(params, batch), which predicts a batch of outputs used in loss function.

  • val_loader – Validation or test set NumpyDataLoader.

  • batch_size – Total batch size that is processed in parallel.

  • batch_cache – Number of batches to cache.

Returns:

A function that returns the tuple (predictions, data_state). predictions contains the model predictions for the whole validation dataset, and data_state is used to start the data loading in the next evaluation.

init_val_loss_fn(model, loss_fn, val_loader, val_targets_keys=None, batch_size=1, batch_cache=100)[source]#

Initializes a pmapped loss function that computes the validation loss.

Usage:
val_loss, data_state = batched_loss_fn(params, data_state)

params must be replicated N_devices times along axis 0.

Parameters:
  • model – A model with signature model(params, batch), which predicts outputs used in loss function.

  • loss_fn – Loss function(predictions, targets) returning the scalar loss value for a batch.

  • val_loader – NumpyDataLoader for validation set.

  • val_targets_keys – Dict containing targets of the whole validation set.

  • batch_size – Total batch size that is processed in parallel.

  • batch_cache – Number of batches to cache on GPU to reduce host-device communication.

Returns:

A pmapped function that returns the average validation loss.
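The idea of an average validation loss over batches can be sketched in plain JAX. This is a simplified illustration, not the library's implementation: it ignores the data loader, the `data_state`, and device parallelism, and the names `batch_loss` and `average_validation_loss` are hypothetical. It also assumes the dataset size is divisible by the batch size.

```python
import jax
import jax.numpy as jnp


def batch_loss(params, batch):
    predictions = batch["x"] @ params["w"]
    return jnp.mean((predictions - batch["y"]) ** 2)


def average_validation_loss(params, dataset, batch_size):
    """Hypothetical helper: mean of the per-batch losses over the dataset."""
    n_samples = dataset["x"].shape[0]
    n_batches = n_samples // batch_size
    # Reshape every array to (n_batches, batch_size, ...).
    batched = jax.tree_util.tree_map(
        lambda a: a.reshape((n_batches, batch_size) + a.shape[1:]), dataset)
    # Evaluate the batches one after another to bound memory use.
    batch_losses = jax.lax.map(lambda b: batch_loss(params, b), batched)
    return jnp.mean(batch_losses)


x = jax.random.normal(jax.random.PRNGKey(1), (12, 3))
dataset = {"x": x, "y": x @ jnp.array([0.5, 1.0, -1.0])}
params = {"w": jnp.zeros((3,))}

val_loss = average_validation_loss(params, dataset, batch_size=4)
full_loss = batch_loss(params, dataset)  # same value, since batches are equal-sized
```

With equally sized batches, the mean of the per-batch means equals the mean over the whole dataset, which is why batching does not change the reported value here.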