learn.probabilistic.dropout_uq_predictions

dropout_uq_predictions(batched_model, meta_params, val_loader, init_rng_key=Array([0, 0], dtype=uint32), n_dropout_samples=8, batch_size=1, batch_cache=10, include_without_dropout=True)

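Returns forward uncertainty quantification (UQ) predictions for a trained model on a validation dataset. Each observation is predicted several times under different dropout configurations.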

Parameters:
  • batched_model – A model with signature model(params, batch), which was trained using dropout.

  • meta_params – Final trained meta_params

  • val_loader – Validation data loader

  • init_rng_key – Initial PRNGKey to use for sampling dropout configurations

  • n_dropout_samples – Number of predictions with different dropout configurations for each data observation.

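  • batch_size – Number of input observations to vectorize over. The n_dropout_samples predictions per observation are already vectorized with vmap, so this only controls how many observations are processed together.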

  • batch_cache – Number of input observations cached in GPU memory.

  • include_without_dropout – Whether to also return the prediction with dropout disabled.

Returns:

A tuple (uncertainties, no_dropout_predictions). For each data observation, uncertainties contains the n_dropout_samples dropout predictions and no_dropout_predictions contains the mean prediction with dropout disabled. If include_without_dropout is False, only uncertainties is returned.
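The idea behind the two return values can be illustrated with a minimal, self-contained JAX sketch. This is not the library's implementation and does not call dropout_uq_predictions; toy_model, mc_dropout_predictions and DROPOUT_RATE are hypothetical names introduced only for illustration, and the data loader, batching and caching are left out. It only shows, under those assumptions, how one PRNG key can be split into n_dropout_samples keys and how the dropout samples can be vmapped for each observation.

```python
import jax
import jax.numpy as jnp

DROPOUT_RATE = 0.5  # hypothetical rate for this toy model


def toy_model(params, x, dropout_key=None):
    """Hypothetical one-hidden-layer model; dropout is off if dropout_key is None."""
    hidden = jnp.tanh(x @ params["w1"])
    if dropout_key is not None:
        # Sample a keep-mask and rescale the surviving activations.
        keep = jax.random.bernoulli(dropout_key, 1.0 - DROPOUT_RATE, hidden.shape)
        hidden = jnp.where(keep, hidden / (1.0 - DROPOUT_RATE), 0.0)
    return hidden @ params["w2"]


def mc_dropout_predictions(params, inputs, rng_key, n_dropout_samples=8):
    """Return (dropout_preds, no_dropout_preds) for an array of observations."""
    keys = jax.random.split(rng_key, n_dropout_samples)
    # Inner vmap: one observation, many dropout masks.
    per_input = jax.vmap(lambda x, k: toy_model(params, x, k), in_axes=(None, 0))
    # Outer vmap: many observations. For simplicity this sketch reuses the
    # same set of dropout keys for every observation.
    dropout_preds = jax.vmap(per_input, in_axes=(0, None))(inputs, keys)
    no_dropout_preds = jax.vmap(lambda x: toy_model(params, x))(inputs)
    return dropout_preds, no_dropout_preds


key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
params = {
    "w1": jax.random.normal(k1, (3, 16)),
    "w2": jax.random.normal(k2, (16, 1)),
}
inputs = jax.random.normal(k3, (5, 3))  # 5 observations with 3 features each

samples, baseline = mc_dropout_predictions(params, inputs, k4)
# Spread across the dropout samples, one value per observation and output.
spread = samples.std(axis=1)
```

Here samples has one axis for observations and one for dropout samples, which mirrors the description of uncertainties above, while baseline plays the role of no_dropout_predictions. The standard deviation over the dropout axis is one common way to turn the samples into an uncertainty estimate; whether the library reports it this way is not stated in this page.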