Source code for chemtrain.quantity.targets.util

# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Create quantity targets for training. """
import functools

from jax.typing import ArrayLike

from jax_md.partition import NeighborList

from typing import TypedDict, Callable, Dict, Tuple, Protocol, List
from chemtrain.typing import ComputeFn, EnergyFnTemplate, TargetDict

InitReturn = Tuple[Dict[str, TargetDict], Dict[str, ComputeFn]]
ComputeFnInit = Callable[..., ComputeFn]
ComputeFnInitDict = Dict[str, Callable[..., ComputeFnInit]]
ComputeFnDict = Dict[str, ComputeFn]

class TargetInit(Protocol):

    required_args: List[str]
    optional_args: List[str]

    def __call__(self,
                 key,
                 target_dicts,
                 compute_fns,
                 system) -> Tuple[TargetDict, ComputeFnDict]:
        ...


[docs] class InitArguments(TypedDict, total=False): """Non-exhaustive dictionary of arguments to initialize the compute functions. Arguments: displacement_fn: Function to compute the displacement from particles. Initialized via :mod:`jax_md.space`. energy_fn_template: Template to build energy function from a set of parameters. r_init: Initial positions of the system, e.g., to infer the number of particles in the system. nbrs_init: Initial neighbor list of the system, e.g., to infer the number of close triplets in the system. reference_box: Reference box of the system. kT: Fixed temperature of the system. pressure: Fixed pressure of the system """ displacement_fn: Callable energy_fn_template: EnergyFnTemplate nbrs_init: NeighborList r_init: ArrayLike reference_box: ArrayLike kT: ArrayLike pressure: ArrayLike
[docs] class TargetBuilder: """Class to simplify the initialization of DiffTRe targets. For each added target, the target builder will check whether compute functions for all required snapshots are provided. Additionally, the target builder delays the initialization of the compute functions until all information about the system is available. Args: system: Dictionary containing properties of the system, e.g., the displacement between particles. strict: If True, the target builder will check whether all required system properties are provided for each target. Otherwise, the properties of the system can be provided when building the targets. Example: First, we add the RDF function to the target builder: .. code:: target_builder = TargetBuilder() target_builder["rdf"] = quantity.structure.init_radial_distribution_target( target, rdf_start=0.00, rdf_cut=1.0, nbins=50 ) Then, as soon as the box and displacement function of the system are known, the compute functions and observables can be initialized: .. code:: targets, compute_fns = target_builder.build({ 'displacement_fn': displacement_fn, 'reference_box': box }) """ def __init__(self, system: InitArguments = None, strict: bool = False): if system is None: system = {} self._strict = strict self._system = system self._targets = {} def __setitem__(self, key, value: TargetInit): assert key not in self._targets, ( f"Duplicate target {key}." ) if self._strict: for required_arg in value.required_args: assert required_arg in self._system.keys(), ( f"Required argument {required_arg} not provided." ) self._targets[key] = value def build(self, system: InitArguments = None, ) -> Tuple[TargetDict, ComputeFnDict]: """Initializes the targets and compute functions. Args: system: Dictionary containing properties of the system, e.g., the displacement between particles. Returns: Returns a dictionary of observables and a dictionary of functions to compute instantaneous properties from the simulator states. """ if system is not None: self._system.update(system) target_dicts = {} compute_fns = {} for key, init_fn in self._targets.items(): target_dicts, compute_fns = init_fn( key, target_dicts, compute_fns, self._system) return target_dicts, compute_fns
[docs] def split_target_dict(target_dict: TargetDict ) -> Tuple[Dict[str, Callable], Dict[str, Callable], Dict[str, Dict[str, ArrayLike]]]: """Splits the target dictionary into observable functions, loss functions, and target values. Args: target_dict: Dictionary of targets. Returns: Returns a tuple of observable functions, loss functions, and target values. """ observables = { key: target['traj_fn'] for key, target in target_dict.items() } target_loss_fns = { key: target['loss_fn'] for key, target in target_dict.items() if 'loss_fn' in target } targets = { key: {k: v for k, v in target.items() if k in ['gamma', 'target']} for key, target in target_dict.items() } return observables, target_loss_fns, targets
def target_quantity(required: list = None, optional: list = None): """Initializes a decorator defining required and optional arguments to initialize a target. """ if required is None: required = [] if optional is None: optional = [] def quantity_init_wrapper(init_fn): @functools.wraps(init_fn) def quantity_init(key: str, targets_dict: TargetDict, compute_fns: ComputeFnInitDict, init_args: InitArguments ) -> Tuple[TargetDict, ComputeFnDict]: # Assert that all required arguments are provided and collect all # optional arguments init_kwargs = { init_key: init_args[init_key] for init_key in required } optional_kwargs = { init_key: value for init_key in optional if (value := init_args.get(init_key)) is not None } init_kwargs.update(optional_kwargs) # Initialize the compute function and target function with the # provided arguments target_dict, compute_fn = init_fn( key, compute_fns, init_kwargs) # Add the initialized compute function and target dict if target_dict is not None: assert key not in target_dict.keys(), ( f"Duplicate target {key} provided.") targets_dict.update({key: target_dict}) if compute_fn is not None: assert key not in compute_fns.keys(), ( f"Duplicate quantitiy {key} initialized.") compute_fns.update({key: compute_fn}) return targets_dict, compute_fns # Enables checking whether system arguments are already provided quantity_init.required_args = required quantity_init.optional_args = optional return quantity_init return quantity_init_wrapper