learn.probabilistic.init_dropout_uq_fwd


init_dropout_uq_fwd(batched_model, meta_params, n_dropout_samples=8)

Initializes a function that evaluates the model under several different dropout configurations and returns the resulting set of predictions, e.g. for uncertainty quantification.

Parameters:
  • batched_model – A model with signature model(params, batch), which was trained using dropout.

  • meta_params – The final trained meta_params.

  • n_dropout_samples – Number of dropout configurations to sample, i.e. the number of predictions generated per input. Defaults to 8.

Returns:

The function predict_distribution(key, model_input), which returns n_dropout_samples predictions for model_input, one per dropout configuration, to be used e.g. for uncertainty quantification.
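
The general pattern behind such a function can be illustrated with a minimal JAX sketch. This is not the library's implementation: toy_model, init_mc_dropout_fwd, the parameter layout, and the dropout rate are all hypothetical stand-ins, and the sketch assumes that key is a JAX PRNG key that gets split into one subkey per dropout configuration.

```python
import jax
import jax.numpy as jnp


def toy_model(params, dropout_key, x, rate=0.1):
    """Hypothetical one-hidden-layer model with dropout on the hidden features."""
    hidden = jnp.tanh(x @ params["w1"])
    # Sample a keep-mask and rescale the kept units (inverted dropout).
    keep = jax.random.bernoulli(dropout_key, 1.0 - rate, hidden.shape)
    hidden = jnp.where(keep, hidden / (1.0 - rate), 0.0)
    return hidden @ params["w2"]


def init_mc_dropout_fwd(model, params, n_dropout_samples=8):
    """Hypothetical stand-in that mimics the documented return signature."""

    def predict_distribution(key, model_input):
        # One subkey, and therefore one dropout mask, per sample.
        keys = jax.random.split(key, n_dropout_samples)
        # Evaluate the model once per subkey; samples are stacked on axis 0.
        return jax.vmap(lambda k: model(params, k, model_input))(keys)

    return predict_distribution


# Assumed toy setup: 5 inputs with 3 features each, scalar output per input.
k_w1, k_w2, k_predict = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w1": jax.random.normal(k_w1, (3, 16)),
    "w2": jax.random.normal(k_w2, (16, 1)),
}
x = jnp.ones((5, 3))

predict_distribution = init_mc_dropout_fwd(toy_model, params, n_dropout_samples=8)
samples = predict_distribution(k_predict, x)  # shape (8, 5, 1)

# Simple uncertainty summary: mean and spread over the dropout samples.
mean = samples.mean(axis=0)
std = samples.std(axis=0)
```

In this sketch the spread of the samples along the leading axis serves as a rough uncertainty estimate. Whether the real function stacks its predictions along the first axis in the same way is an assumption here and should be checked against the source.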