learn.probabilistic.init_dropout_uq_fwd
- init_dropout_uq_fwd(batched_model, meta_params, n_dropout_samples=8)
Initializes a function that returns a set of predictions, one per sampled dropout configuration (i.e. Monte Carlo dropout), whose spread can be used e.g. for uncertainty quantification.
- Parameters:
batched_model – A model with signature model(params, batch), which was trained using dropout.
meta_params – The final trained meta_params of the model.
n_dropout_samples – Number of predictions to run, each with a different dropout configuration.
- Returns:
A function predict_distribution(key, model_input) that returns n_dropout_samples predictions, each computed with a different dropout configuration, e.g. for uncertainty quantification.
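To illustrate the idea behind the returned predict_distribution, here is a minimal Monte Carlo dropout sketch in JAX. It is not this module's actual implementation. `toy_model`, `init_mc_dropout_fwd`, and the convention of storing the dropout RNG key under `meta_params["dropout_key"]` are hypothetical assumptions made for this example; the real meta_params may carry dropout state differently. The sketch splits the input key into n_dropout_samples dropout keys and maps the model over them with `jax.vmap`, keeping the trained weights fixed.

```python
import jax
import jax.numpy as jnp


def toy_model(params, batch):
    """Hypothetical dropout model: inverted dropout followed by a linear layer.

    Assumption for this sketch only: the dropout RNG key lives in
    params["dropout_key"] next to the trained weights.
    """
    rate = 0.5
    keep = jax.random.bernoulli(params["dropout_key"], 1.0 - rate, batch.shape)
    # Inverted dropout: rescale kept inputs so the expected value is unchanged.
    dropped = jnp.where(keep, batch / (1.0 - rate), 0.0)
    return dropped @ params["w"]


def init_mc_dropout_fwd(batched_model, meta_params, n_dropout_samples=8):
    """Hypothetical stand-in: one forward pass per sampled dropout key."""

    def single_prediction(dropout_key, model_input):
        # Swap in a fresh dropout key; the trained weights stay fixed.
        params = {**meta_params, "dropout_key": dropout_key}
        return batched_model(params, model_input)

    def predict_distribution(key, model_input):
        keys = jax.random.split(key, n_dropout_samples)
        # Map over the dropout keys; all samples share the same model_input.
        return jax.vmap(single_prediction, in_axes=(0, None))(keys, model_input)

    return predict_distribution


# Usage: 8 dropout samples for a batch of 3 inputs with 4 features each.
meta_params = {"w": jnp.ones((4, 1)), "dropout_key": jax.random.PRNGKey(0)}
predict = init_mc_dropout_fwd(toy_model, meta_params, n_dropout_samples=8)
x = jnp.ones((3, 4))
samples = predict(jax.random.PRNGKey(42), x)  # shape (8, 3, 1)

# Point prediction and uncertainty estimate across the dropout samples.
mean = samples.mean(axis=0)
std = samples.std(axis=0)
```

Averaging over the leading sample axis gives a point prediction, and the spread across samples (here the standard deviation) serves as the uncertainty estimate. Vectorizing over the keys with `jax.vmap` evaluates all dropout configurations in one batched call instead of a Python loop.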