# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import jax
from jax import numpy as jnp, random
from jax_sgmc.data import numpy_loader, core
from chemtrain.data.preprocessing import train_val_test_split
from chemtrain import util
from typing import NamedTuple
class DataLoaders(NamedTuple):
train_loader: core.DataLoader
val_loader: core.DataLoader
test_loader: core.DataLoader
[docs]
def init_dataloaders(dataset, train_ratio=0.7, val_ratio=0.1, shuffle=False):
"""Splits dataset and initializes dataloaders.
If the validation or test ratios are 0, returns None for the respective
dataloaders.
Args:
dataset: Dictionary containing the whole dataset. The NumpyDataLoader
returns batches with the same kwargs as provided in dataset.
train_ratio: Fraction of dataset to use for training.
val_ratio: Fraction of dataset to use for validation.
shuffle: Whether to shuffle data before splitting into train-val-test.
Returns:
Returns a tuple ``(train_loader, val_loader, test_loader)`` of
NumpyDataLoaders.
"""
def init_subloader(data_subset):
if data_subset is None:
loader = None
else:
loader = numpy_loader.NumpyDataLoader(**data_subset, copy=False)
return loader
train_set, val_set, test_set = train_val_test_split(
dataset, train_ratio, val_ratio, shuffle=shuffle)
train_loader = init_subloader(train_set)
val_loader = init_subloader(val_set)
test_loader = init_subloader(test_set)
return DataLoaders(train_loader, val_loader, test_loader)
[docs]
def init_batch_functions(data_loader: core.HostDataLoader,
mb_size: int,
cache_size: int = 1,
) -> core.RandomBatch:
"""Initializes reference data access outside jit-compiled functions.
Randomly draw batches from a given dataset on the host or the device.
If ``rng_seed=<seed>`` is passed to the ``init_fn``, a ``jax.random.PRNGKey``,
will be added to the batch.
Args:
data_loader: Reads data from storage.
cache_size: Number of batches in the cache. A larger number is
faster, but requires more memory.
mb_size: Size of the data batch.
Returns:
Returns a tuple of functions to initialize a new reference data state, get
a minibatch from the reference data state and release the data loader after
the last computation.
"""
hcb_format, mb_information = data_loader.batch_format(
cache_size, mb_size=mb_size)
mask_shape = (cache_size, mb_size)
def init_fn(random: bool = True, rng_seed=None, **kwargs) -> core.CacheState:
if random:
chain_id = data_loader.register_random_pipeline(
cache_size=cache_size, mb_size=mb_size, **kwargs
)
else:
print(f"Initialize full data pipeline")
chain_id = data_loader.register_ordered_pipeline(
cache_size=cache_size, mb_size=mb_size, **kwargs
)
initial_state, initial_mask = data_loader.get_batches(chain_id)
if initial_mask is None:
initial_mask = jnp.ones((cache_size, mb_size), dtype=jnp.bool_)
initial_internal_state = {}
if rng_seed is not None:
initial_internal_state['rng'] = jax.random.PRNGKey(rng_seed)
inital_cache_state = core.CacheState(
cached_batches=initial_state,
cached_batches_count=jnp.array(cache_size),
current_line=jnp.array(0),
chain_id=jnp.array(chain_id),
valid=initial_mask,
state=initial_internal_state,
)
return inital_cache_state
def _new_cache_fn(state: core.CacheState,
) -> core.CacheState:
new_data, masks = data_loader.get_batches(state.chain_id)
if masks is None:
# Assume all samples to be valid.
masks = jnp.ones(mask_shape, dtype=jnp.bool_)
new_state = core.CacheState(
cached_batches_count=state.cached_batches_count,
cached_batches=new_data,
current_line=jnp.array(0),
chain_id=state.chain_id,
valid=masks,
callback_uuid=state.callback_uuid,
state=state.state
)
return new_state
@jax.jit
def _split_batch(data_state: core.CacheState):
current_line = jnp.mod(
data_state.current_line, data_state.cached_batches_count)
# Read the current line from the cache and add the mask containing
# information about the validity of the individual samples
mini_batch = util.tree_get_single(data_state.cached_batches, current_line)
mask = data_state.valid[current_line, :]
# Add a random key if required
internal_state = data_state.state
if 'rng' in internal_state.keys():
key, split = random.split(internal_state['rng'])
mini_batch['rng'] = random.split(split, mb_information.batch_size)
internal_state['rng'] = key
current_line = current_line + 1
new_state = core.CacheState(
cached_batches=data_state.cached_batches,
cached_batches_count=data_state.cached_batches_count,
current_line=current_line,
chain_id=data_state.chain_id,
valid=data_state.valid,
state=internal_state
)
info = core.MiniBatchInformation(
observation_count = mb_information.observation_count,
batch_size = mb_information.batch_size,
mask = mask)
return new_state, mini_batch, info
def batch_fn(data_state: core.CacheState,
information: bool = False,
) -> core.Batch:
"""Draws a new random batch.
Args:
data_state: State with cached samples
information: Whether to return batch information
device_count: Number of parallel programs calling the batch function
Returns:
Returns the new data state and the next batch. Optionally an additional
struct containing information about the batch can be returned.
"""
# Refresh the cache if necessary, after all cached batches have been used.
if data_state.current_line == data_state.cached_batches_count:
data_state = _new_cache_fn(data_state)
new_state, mini_batch, info = _split_batch(data_state)
if information:
return new_state, (mini_batch, info)
else:
return new_state, mini_batch
def release():
pass
return init_fn, batch_fn, release