learn.probabilistic.dropout_uq_predictions
- dropout_uq_predictions(batched_model, meta_params, val_loader, init_rng_key=Array([0, 0], dtype=uint32), n_dropout_samples=8, batch_size=1, batch_cache=10, include_without_dropout=True)
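Returns forward uncertainty quantification (UQ) predictions of a trained model on a validation dataset using Monte Carlo dropout, i.e. by evaluating the model several times per observation with randomly sampled dropout configurations.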
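The Monte Carlo dropout procedure described above can be sketched in JAX as follows. This is a minimal illustration under stated assumptions, not this function's implementation: `toy_model` and `mc_dropout_predict` are hypothetical names, and the dropout key is passed to the model explicitly, whereas the documented model signature is `model(params, batch)`. The sketch mirrors the parameters below: an inner `vmap` over `n_dropout_samples` dropout configurations, an outer `vmap` over chunks of `batch_size` observations, and one extra pass with dropout disabled.

```python
import jax
import jax.numpy as jnp


def toy_model(params, x, key, dropout_rate=0.5, train=True):
    """Hypothetical one-hidden-layer model with (inverted) dropout on the hidden layer."""
    h = jnp.tanh(x @ params["w1"] + params["b1"])
    if train:
        # Keep each hidden unit with probability 1 - dropout_rate and rescale survivors.
        keep = jax.random.bernoulli(key, 1.0 - dropout_rate, h.shape)
        h = jnp.where(keep, h / (1.0 - dropout_rate), 0.0)
    return h @ params["w2"] + params["b2"]


def mc_dropout_predict(params, xs, key, n_dropout_samples=8, batch_size=1):
    """Hypothetical helper returning (dropout_preds, no_dropout_preds).

    dropout_preds: shape (n_obs, n_dropout_samples, out_dim)
    no_dropout_preds: shape (n_obs, out_dim)
    """

    def predict_one(x, obs_key):
        # Inner vmap: one PRNG key per dropout configuration of this observation.
        sample_keys = jax.random.split(obs_key, n_dropout_samples)
        return jax.vmap(lambda k: toy_model(params, x, k))(sample_keys)

    chunks = []
    for start in range(0, xs.shape[0], batch_size):
        batch = xs[start:start + batch_size]
        key, batch_key = jax.random.split(key)
        obs_keys = jax.random.split(batch_key, batch.shape[0])
        # Outer vmap: vectorize over the observations of this chunk.
        chunks.append(jax.vmap(predict_one)(batch, obs_keys))
    dropout_preds = jnp.concatenate(chunks, axis=0)

    # Dropout disabled: a single deterministic pass over all observations.
    no_dropout_preds = toy_model(params, xs, None, train=False)
    return dropout_preds, no_dropout_preds


# Usage with random toy parameters and 5 observations of dimension 3.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w1": jax.random.normal(k1, (3, 16)),
    "b1": jnp.zeros(16),
    "w2": jax.random.normal(k2, (16, 1)),
    "b2": jnp.zeros(1),
}
xs = jax.random.normal(k3, (5, 3))
dropout_preds, no_dropout_preds = mc_dropout_predict(
    params, xs, jax.random.PRNGKey(1), n_dropout_samples=8, batch_size=2
)
```

Because all randomness is derived from the initial key, reusing the same key reproduces the same dropout predictions, which is presumably the role of `init_rng_key`.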
- Parameters:
batched_model – A model with signature model(params, batch), which was trained using dropout.
meta_params – Final trained meta_params of the model.
val_loader – Validation data loader.
init_rng_key – Initial PRNGKey used to sample the dropout configurations.
n_dropout_samples – Number of predictions with different dropout configurations for each data observation.
batch_size – Number of input observations to vectorize over at once. The n_dropout_samples predictions per observation are already vectorized (vmapped) over.
batch_cache – Number of input observations cached in GPU memory.
include_without_dropout – Whether to also output predictions with dropout disabled.
- Returns:
A tuple (uncertainties, no_dropout_predictions) containing, for each data observation, the n_dropout_samples predictions obtained with dropout enabled (uncertainties) as well as the mean prediction obtained with dropout disabled (no_dropout_predictions). If include_without_dropout is False, only uncertainties is returned.
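The returned dropout samples are typically reduced to a predictive mean and spread per observation. The sketch below assumes, without confirmation from this page, that uncertainties can be stacked into an array of shape (n_obs, n_dropout_samples, ...) and that no_dropout_predictions has shape (n_obs, ...); `summarize_mc_dropout` is a hypothetical helper, not part of the library.

```python
import jax.numpy as jnp


def summarize_mc_dropout(uncertainties, no_dropout_predictions=None):
    """Hypothetical helper; assumes uncertainties has shape (n_obs, n_dropout_samples, ...)."""
    summary = {
        # Predictive mean over the dropout-sample axis.
        "mean": jnp.mean(uncertainties, axis=1),
        # Sample standard deviation (ddof=1 requires n_dropout_samples >= 2).
        "std": jnp.std(uncertainties, axis=1, ddof=1),
    }
    if no_dropout_predictions is not None:
        # Gap between the MC-dropout mean and the dropout-free prediction.
        summary["mean_minus_no_dropout"] = summary["mean"] - no_dropout_predictions
    return summary


# Toy input: 2 observations, 2 dropout samples each, scalar outputs.
uncertainties = jnp.array([[[1.0], [3.0]], [[2.0], [2.0]]])
no_dropout = jnp.array([[2.5], [2.0]])
summary = summarize_mc_dropout(uncertainties, no_dropout)
```

The standard deviation across dropout samples serves as the per-observation uncertainty estimate; comparing the sample mean with the dropout-free prediction indicates how closely the deterministic pass approximates the Monte Carlo average.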