data.data_loaders

Numpy Data Loaders

init_dataloaders(dataset, train_ratio=0.7, val_ratio=0.1, shuffle=False)

Splits dataset and initializes dataloaders.

If the validation or test ratio is 0, None is returned in place of the respective dataloader.

Parameters:
  • dataset – Dictionary containing the whole dataset. The NumpyDataLoader returns batches with the same kwargs as provided in dataset.

  • train_ratio – Fraction of dataset to use for training.

  • val_ratio – Fraction of dataset to use for validation.

  • shuffle – Whether to shuffle data before splitting into train-val-test.

Returns:

A tuple (train_loader, val_loader, test_loader) of NumpyDataLoaders.
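The splitting behavior described above can be illustrated with a small NumPy-only sketch. This is not the library's implementation: `split_dataset` is a hypothetical helper, and it assumes that the test split is whatever remains after the training and validation fractions, that split sizes are truncated to integers, and that all arrays share the same leading dimension.

```python
import numpy as np


def split_dataset(dataset, train_ratio=0.7, val_ratio=0.1, shuffle=False, seed=0):
    """Hypothetical helper: split a dict of arrays along the first axis."""
    n = len(next(iter(dataset.values())))
    idx = np.arange(n)
    if shuffle:
        # Shuffle once, before splitting into train, val and test
        idx = np.random.default_rng(seed).permutation(n)

    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)
    parts = (
        idx[:n_train],                  # training samples
        idx[n_train:n_train + n_val],   # validation samples
        idx[n_train + n_val:],          # assumed: test set is the remainder
    )
    # An empty split is reported as None, mirroring the documented behavior
    return tuple(
        {key: arr[sel] for key, arr in dataset.items()} if len(sel) > 0 else None
        for sel in parts
    )


dataset = {"x": np.arange(20).reshape(10, 2), "y": np.arange(10)}
train, val, test = split_dataset(dataset, train_ratio=0.8, val_ratio=0.0)
```

Each split keeps the same keys as the input dictionary, which mirrors the statement that batches are returned with the same kwargs as provided in dataset.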

init_batch_functions(data_loader, mb_size, cache_size=1)

Initializes reference data access outside jit-compiled functions.

Randomly draws batches from a given dataset on the host or the device. If rng_seed=<seed> is passed to the init_fn, a jax.random.PRNGKey will be added to the batch.

Parameters:
  • data_loader (HostDataLoader) – Reads data from storage.

  • cache_size (int) – Number of batches in the cache. A larger cache is faster but requires more memory.

  • mb_size (int) – Size of the data batch.

Return type:

Tuple[Any, GetBatchFunction, Callable[[], None]]

Returns:

A tuple of three functions: one to initialize a new reference data state, one to get a minibatch from the reference data state, and one to release the data loader after the last computation.
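To make the roles of the three returned functions concrete, here is a toy, NumPy-only stand-in. It is a sketch under assumptions, not the library code: `toy_batch_functions` is a hypothetical name, the state layout is invented for illustration, and the call signature `get_batch(state) -> (state, batch)` is assumed rather than taken from the documentation.

```python
import numpy as np


def toy_batch_functions(dataset, mb_size, cache_size=1, seed=0):
    """Toy stand-in for an (init, get_batch, release) triple."""
    n = len(next(iter(dataset.values())))
    rng = np.random.default_rng(seed)
    released = {"done": False}

    def _fill_cache():
        # Draw cache_size random batches at once (sampling with replacement)
        idx = rng.integers(0, n, size=(cache_size, mb_size))
        return {key: arr[idx] for key, arr in dataset.items()}

    def init_fn():
        # State: the cached batches plus the position of the next unused one
        return {"cache": _fill_cache(), "pos": 0}

    def get_batch(state):
        if released["done"]:
            raise RuntimeError("data loader was released")
        if state["pos"] == cache_size:
            # Cache exhausted: go back to the data source for new batches
            state = {"cache": _fill_cache(), "pos": 0}
        batch = {key: arr[state["pos"]] for key, arr in state["cache"].items()}
        new_state = {"cache": state["cache"], "pos": state["pos"] + 1}
        return new_state, batch

    def release():
        released["done"] = True

    return init_fn, get_batch, release


data = {"x": np.arange(40.0).reshape(20, 2), "y": np.arange(20)}
init_fn, get_batch, release = toy_batch_functions(data, mb_size=4, cache_size=2)

state = init_fn()
shapes = []
for _ in range(5):
    state, batch = get_batch(state)
    shapes.append(batch["x"].shape)
release()
```

In this sketch a larger cache_size means fewer refills from the data source at the cost of holding more batches in memory, which is the trade-off the cache_size parameter describes.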