Source code for chemtrain.typing
# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Documents commonly used types in chemtrain."""
import typing
from typing import Callable, Any, Optional, Protocol, TypedDict, Dict, TypeAlias
from typing_extensions import NotRequired
try:
    from jax.typing import ArrayLike
except ImportError:
    # ``jax.typing`` is not importable (e.g., a jax release that predates
    # it). Fall back to ``Any`` so the annotations below stay valid. Only
    # import failures are handled, so that unrelated errors and interpreter
    # signals (KeyboardInterrupt, SystemExit) are not silently swallowed.
    ArrayLike = Any
import jax_md_mod
from jax_md.energy import NeighborList
# Energy Functions
[docs]
class EnergyFn(Protocol):
    """Protocol for callables that compute the energy of a conformation."""

    def __call__(self,
                 position: ArrayLike,
                 neighbor: Optional[NeighborList] = None,
                 **kwargs) -> ArrayLike:
        """Computes the energy for a given conformation.

        Args:
            position: Positions of the particles.
            neighbor: Updated neighborlist. Defaults to ``None``.
            **kwargs: Additional parameters to the energy function, e.g., the
                thermostat temperature ``"kbt"``.

        Returns:
            Returns the potential energy of the system.
        """
[docs]
class EnergyFnTemplate(Protocol):
    """Protocol for callables that turn parameters into an energy function."""

    def __call__(self, energy_params: Any) -> EnergyFn:
        """Initializes the energy function with parameters.

        Args:
            energy_params: Parameters for the energy function.

        Returns:
            Returns a concrete potential energy function.
        """
        ...
[docs]
class ErrorFn(Protocol):
    """Protocol for callables that compute an error between predictions and targets."""

    def __call__(self,
                 predictions: ArrayLike,
                 targets: ArrayLike,
                 mask: Optional[ArrayLike] = None,
                 weights: Optional[ArrayLike] = None) -> ArrayLike:
        """Computes the error of the predictions.

        Args:
            predictions: Predicted values with same shape as targets.
            targets: Target values.
            mask: Masks out invalid predictions along the first axis.
                Defaults to ``None``.
            weights: Weights for the error calculation. Defaults to ``None``.

        Returns:
            Returns the masked error value.
        """
# Quantities
[docs]
class TrajFn(Protocol):
    """Protocol for callables that reduce quantity trajectories to an array."""

    def __call__(self,
                 quantity_trajs: Dict[str, ArrayLike],
                 weights: Optional[ArrayLike] = None) -> ArrayLike:
        """Evaluates the function on a dictionary of quantity trajectories.

        Args:
            quantity_trajs: Dictionary mapping string keys to array-like
                trajectories.
            weights: Optional weights, defaults to ``None``.
                NOTE(review): presumably per-sample weights of the
                trajectory; confirm against the implementations.

        Returns:
            Returns an array-like result.
        """
        ...
[docs]
class SingleTarget(TypedDict):
    """Dictionary describing a single target.

    Only ``traj_fn`` is required; all other keys may be omitted.
    """

    # Required: function evaluated on the quantity trajectories.
    traj_fn: TrajFn

    # Optional: callable taking two array-like arguments and returning an
    # array-like value.
    loss_fn: NotRequired[Callable[[ArrayLike, ArrayLike], ArrayLike]]

    # Optional: array-like target value.
    target: NotRequired[ArrayLike]

    # Optional: array-like value.
    # NOTE(review): presumably a weight for this target; confirm with users
    # of this dictionary.
    gamma: NotRequired[ArrayLike]
[docs]
class QuantityComputeFunction(Protocol):
    """Protocol for callables that compute a quantity from a state."""

    def __call__(self, state: Any, **kwargs) -> ArrayLike:
        """Computes an array-like quantity for the given state.

        Args:
            state: State to evaluate the quantity on.
            **kwargs: Additional keyword arguments.

        Returns:
            Returns the computed array-like value.
        """
# Dictionary from string keys (presumably quantity names; confirm with
# callers) to the functions computing the corresponding quantities.
QuantityDict: TypeAlias = Dict[str, QuantityComputeFunction]
# Dictionary from string keys (presumably target names; confirm with
# callers) to the corresponding target definitions.
TargetDict: TypeAlias = Dict[str, SingleTarget]
[docs]
class ComputeFn(Protocol):
    """Protocol for callables evaluated on a state.

    The callable takes a state, optionally a neighbor list, and additional
    keyword arguments.
    """

    # NOTE(review): only a single ``@typing.overload`` variant is declared
    # before the undecorated definition. Static type checkers generally
    # expect at least two overload variants; confirm the intended
    # signatures before restructuring this protocol.
    @typing.overload
    def __call__(self,
                 state,
                 neighbor: Optional[NeighborList] = None,
                 **kwargs) -> Any: ...

    def __call__(self, state, **kwargs) -> Any:
        """Evaluates the function on the given state.

        Args:
            state: State to evaluate the function on.
            **kwargs: Additional keyword arguments.

        Returns:
            Returns the result of the computation.
        """
        ...