# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains molecular properties, computed from features of a
neural network used for potential energy prediction.
"""
import functools
import typing
from functools import wraps
from typing import Tuple, Callable, Any
import jax
from jax import numpy as jnp, Array
from jax_md import partition
from chemtrain.typing import ComputeFn
def readout_wrapper(energy_fn_template, mode="energy"):
    """Builds a snapshot function reading out one prediction of an energy model.

    The energy function created by ``energy_fn_template`` must accept a
    ``mode`` keyword argument selecting the quantity to return, e.g.::

        def energy_fn(pos, neighbor, mode=None, **kwargs):
            ...
            if mode is None:
                return pot
            if mode == "some_property":
                return some_property

    The same template can then be reused to predict further quantities::

        quantity_dict = {
            "some_property": readout_wrapper(
                energy_fn_template, mode="some_property")
        }

    Args:
        energy_fn_template: Callable mapping energy parameters to an energy
            function with a signature as shown above.
        mode: Value forwarded as ``mode`` keyword to the energy function.

    Returns:
        A function ``snapshot_fn(state, neighbor=None, energy_params=None,
        **kwargs)`` that returns the output of the energy function for the
        positions of ``state`` in the selected mode.
    """
    def snapshot_fn(state, neighbor=None, energy_params=None, **kwargs):
        assert energy_params is not None, "Energy parameters must be provided."
        # Instantiate the energy function for the current parameters first,
        # then evaluate it on the positions stored in the state.
        energy_fn = energy_fn_template(energy_params)
        return energy_fn(state.position, neighbor, mode=mode, **kwargs)

    return snapshot_fn
[docs]
def molecular_property_predictor(model, n_per_atom=0):
    """Wraps a per-atom model to predict global and per-atom quantities.

    The wrapped model must return a two-dimensional array whose first axis
    is summed to obtain global quantities and whose second axis enumerates
    the predicted quantities. The last ``n_per_atom`` columns are returned
    unchanged as per-atom properties; all preceding columns are summed over
    the first axis and returned as global properties.

    Args:
        model: Initialized model predicting per-atom quantities, e.g.
            DimeNetPP.
        n_per_atom: Number of per-atom quantities to predict. Remaining
            predictions are assumed to be global.

    Returns:
        A function with the call signature of ``model`` that returns a tuple
        ``(global_properties, per_atom_properties)``, being a
        ``(n_globals,)`` and a ``(n_particles, n_per_atom)`` array of
        predictions.

    Raises:
        ValueError: When the returned function is called and ``n_per_atom``
            is negative or exceeds the number of quantities predicted by
            the model.
    """
    @wraps(model)
    def property_wrapper(*args, **kwargs):
        per_atom_quantities = model(*args, **kwargs)
        n_predicted = per_atom_quantities.shape[1]

        # Guard against an out-of-range split: a negative ``n_global`` would
        # silently select the wrong columns through negative slice indices.
        if not 0 <= n_per_atom <= n_predicted:
            raise ValueError(
                f"n_per_atom must be between 0 and the number of predicted "
                f"quantities ({n_predicted}), but got {n_per_atom}."
            )

        n_global = n_predicted - n_per_atom
        per_atom_properties = per_atom_quantities[:, n_global:]
        global_properties = jnp.sum(per_atom_quantities[:, :n_global], axis=0)
        return global_properties, per_atom_properties

    return property_wrapper
[docs]
class PropertyPredictor(typing.Protocol):
    """Call signature of models predicting global and per-atom properties."""

    def __call__(self, params, graph: Any, **kwargs) -> Tuple[Array, Array]:
        """Predicts molecular properties from a molecular graph.

        The expected form of the graph is determined by the underlying
        model. E.g., for property predictions with
        :class:`jax_md_mod.model.neural_networks.DimeNetPP`, the graph is of
        the form :class:`jax_md_mod.model.sparse_graph.SparseDirectionalGraph`.

        Args:
            params: Parameters passed to the underlying model.
            graph: Molecular graph containing neighborhood information.
            **kwargs: Additional arguments to the potential model,
                extracting features from the molecular graph.

        Returns:
            A tuple of global and per-atom properties.
        """
        ...
[docs]
class SinglePropertyPredictor(typing.Protocol):
    """Call signature of functions predicting a single molecular property.

    A predictor is called either with model parameters and a molecular
    graph, or with features that were extracted beforehand.
    """

    @typing.overload
    def __call__(self, params, graph: Any, **kwargs):
        ...

    @typing.overload
    def __call__(self, features, **kwargs):
        ...

    @staticmethod
    def __call__(params=None, graph=None, features=None, **kwargs):
        """Predicts a molecular property from a molecular graph.

        The expected form of the graph is determined by the underlying
        model. E.g., for property predictions with
        :class:`jax_md_mod.model.neural_networks.DimeNetPP`, the graph is of
        the form :class:`jax_md_mod.model.sparse_graph.SparseDirectionalGraph`.

        Args:
            params: Parameters passed to the model. Required together with
                ``graph`` if the features should be computed from a model
                within the predictor.
            graph: Molecular graph containing neighborhood information.
                Required, if features should be computed from a model within
                the predictor.
            features: Features derived from the molecular graph.
                Required if no model is provided to compute the features.
            **kwargs: Additional arguments to the potential model,
                extracting features from the molecular graph.

        Returns:
            Returns a single per-atom or molecular property.
        """
        ...
def apply_model(model: PropertyPredictor = None):
    """Creates a decorator turning a feature readout into a property predictor.

    The decorated function is called as ``fn(features, **kwargs)``.

    If a model is provided, the predictor evaluates the model on a molecular
    graph and passes a dictionary with the entries ``"global_features"`` and
    ``"per_atom_features"`` to the decorated function. The predictor must
    then be called with model parameters and a molecular graph::

        prop = property_predictor(params, graph, **kwargs)

    If no model is provided, the features must be pre-computed. They are
    forwarded unchanged to the decorated function::

        prop = property_predictor(features=features, **kwargs)

    Args:
        model: Optional model to compute the molecular features within the
            property predictor.

    Returns:
        A decorator that wraps a function ``fn(features, **kwargs)`` into a
        predictor ``predictor(params=None, graph=None, features=None,
        **kwargs)``.
    """
    def decorator(fn) -> SinglePropertyPredictor:
        def predictor(params=None, graph=None, features=None, **kwargs):
            # Without a model, the caller is responsible for the features.
            if model is None:
                assert features is not None, (
                    "If molecular features are not pre-computed, a model must "
                    "be provided to compute them."
                )
                return fn(features, **kwargs)

            # With a model, any passed features are ignored and re-computed
            # from the molecular graph.
            assert graph is not None and params is not None, (
                "A graph is required to compute molecular features"
            )
            model_output = model(params, graph, **kwargs)
            global_features, per_atom_features = model_output
            computed_features = {
                "global_features": global_features,
                "per_atom_features": per_atom_features,
            }
            return fn(computed_features, **kwargs)

        return predictor

    return decorator
[docs]
def snapshot_quantity(property_predictor: SinglePropertyPredictor,
                      graph_from_neighbor_list: Callable = None,
                      features_key = "features") -> ComputeFn:
    """Transforms a single property predictor to a snapshot compute function.

    Args:
        property_predictor: Function to predict a property from a molecular
            graph or from pre-extracted features.
        graph_from_neighbor_list: Function to build a molecular graph from
            a neighbor list. It is called with the positions, the neighbor
            list and all additional keyword arguments, and must return a
            tuple ``(graph, remaining_kwargs)``. Only necessary, when the
            features should be extracted within the property predictor.
            Otherwise, pre-extracted features are required as kwargs.
        features_key: Key to the pre-computed features if no model is provided.

    Returns:
        Returns a snapshot compute function for the corresponding property.
        The function has the signature ``compute_fn(state, neighbor=None,
        energy_params=None, **kwargs)``.
    """
    def compute_fn(state,
                   neighbor=None,
                   energy_params=None,
                   **kwargs):
        if graph_from_neighbor_list is None:
            assert features_key in kwargs, (
                f"Features {features_key} must be pre-computed if model is not "
                f"provided."
            )
            # Pass the features by keyword. Passed positionally, they would
            # fill the ``params`` slot of the predictor and only reach the
            # ``features`` argument if the key happened to be "features".
            # ``kwargs`` is local to this call, so removing the key is safe.
            features = kwargs.pop(features_key)
            return property_predictor(features=features, **kwargs)

        assert neighbor is not None, (
            "A neighbor list is required to build a molecular graph."
        )
        # The graph builder returns the remaining kwargs, which enables it
        # to remove the arguments it already processed.
        mol_graph, kwargs = graph_from_neighbor_list(
            state.position, neighbor, **kwargs)
        return property_predictor(energy_params, mol_graph, **kwargs)

    return compute_fn
[docs]
def init_feature_pre_computation(model: PropertyPredictor,
                                 graph_from_neighbor_list: Callable = None
                                 ) -> ComputeFn:
    """Initializes a function to compute global and per-atom features.

    Args:
        model: Model to compute the features from a molecular graph.
        graph_from_neighbor_list: Function to construct a molecular graph
            from the neighbor list. It may return either the graph alone or
            a tuple ``(graph, remaining_kwargs)``, as expected by
            :func:`snapshot_quantity`.

    Returns:
        Returns a function ``feature_computation_fn(state, neighbor=None,
        energy_params=None, **kwargs)`` that builds the molecular graph and
        returns a dictionary with the entries ``"global_features"`` and
        ``"per_atom_features"``.

    Raises:
        TypeError: When the returned function is called, but no
            ``graph_from_neighbor_list`` function was provided.
    """
    def feature_computation_fn(state, neighbor=None, energy_params=None, **kwargs):
        if graph_from_neighbor_list is None:
            raise TypeError(
                "A function to build a molecular graph from a neighbor list "
                "is required to pre-compute features."
            )

        graph = graph_from_neighbor_list(state.position, neighbor, **kwargs)
        # Graph builders used with ``snapshot_quantity`` return a tuple of
        # the graph and the not yet processed kwargs. Accept this convention
        # as well, instead of passing the whole tuple to the model. The exact
        # type check keeps named-tuple graphs from being unpacked.
        if (type(graph) is tuple and len(graph) == 2
                and isinstance(graph[1], dict)):
            graph, kwargs = graph

        global_features, per_atom_features = model(energy_params, graph, **kwargs)
        return {
            "global_features": global_features,
            "per_atom_features": per_atom_features
        }

    return feature_computation_fn
[docs]
def potential_energy_prediction(model: PropertyPredictor = None,
                                feature_number: int = 0
                                ) -> SinglePropertyPredictor:
    """Initializes a prediction of the potential energy.

    With this wrapper, the features used to predict the potential energy in
    a simulation can be shared with the prediction of other molecular
    properties.

    Args:
        model: Particle property prediction model.
        feature_number: Number of the global features to interpret as
            potential energy.

    Example::

        init_property_predictor, property_predictor = neural_networks.dimenetpp_property_prediction(
            r_cutoff = 1.0, n_targets = 2, n_species = 2, n_per_atom = 0)

        # Initialize the prediction of potential energy (the first global property)
        potential_energy_predictor = property_prediction.potential_energy_prediction(
            model=property_predictor, feature_number=0
        )

        # Initialize a function to compute the potential energy for a simulator
        # snapshot. The snapshot function first constructs a molecular graph
        # from a provided neighbor list.
        energy_snapshot_fn = property_prediction.snapshot_quantity(
            potential_energy_predictor, graph_from_neighbor_list
        )

        # The snapshot function can be used as learnable model or as
        # compute function for traj_util.quantity_traj
        def energy_fn_template(energy_params):
            def energy_fn(position, neighbor=None, **kwargs):
                # Wrap positions in pseudo simulator state
                state = force_matching.State(position)
                return energy_snapshot_fn(
                    state, neighbor, energy_params=energy_params, **kwargs)
            return energy_fn

    Returns:
        Returns a function to predict the potential energy from a molecular graph.
    """
    def potential_energy(features, **kwargs) -> Array:
        # The potential energy is one entry of the global features.
        global_features = features["global_features"]
        return global_features[feature_number]

    return apply_model(model)(potential_energy)
[docs]
def partial_charge_prediction(model: PropertyPredictor = None,
                              feature_number: int = 1,
                              total_charge: Array = 0.0,
                              ) -> SinglePropertyPredictor:
    """Initializes a prediction of partial charges.

    The raw per-atom predictions are shifted uniformly, such that the
    charges of all unmasked particles sum up to ``total_charge``. Charges of
    masked particles are set to zero.

    For a usage with or without model, see the decorator
    :func:`apply_model`.

    Args:
        model: Model extracting particle properties. If not provided, the
            global and per-atom features must be pre-computed.
        feature_number: Number of the per-atom features to base the prediction on.
        total_charge: Total charge of the system. By default, the system should
            be charge neutral.

    Returns:
        Returns a function to predict partial charges from a molecular graph.
        An optional ``mask`` keyword argument of this function marks the
        particles to consider; without a mask (or with ``mask=None``), all
        particles are considered.
    """
    @apply_model(model)
    def partial_charge(features, **kwargs) -> Array:
        raw_partial_charges = features["per_atom_features"][:, feature_number]

        # If masked particles are present, remove their partial charges.
        # A missing mask and an explicit ``mask=None`` are treated alike.
        mask = kwargs.get("mask")
        if mask is None:
            mask = jnp.ones_like(raw_partial_charges)

        # Avoid a division by zero (and thus NaN charges) if all particles
        # are masked; the masked charges are zero in this case.
        n_unmasked = jnp.sum(mask)
        n_unmasked = jnp.where(n_unmasked > 0, n_unmasked, 1)

        # Correct the charges to ensure the requested total charge
        charge_correction = jnp.sum(raw_partial_charges * mask)
        charge_correction -= total_charge
        charge_correction /= n_unmasked

        partial_charges = (raw_partial_charges - charge_correction) * mask
        return partial_charges

    return partial_charge
[docs]
def init_dipole_moment(displacement_fn: Callable,
                       model: PropertyPredictor = None,
                       graph_from_neighbor_list: Callable = None,
                       partial_charge_feature: int = 1,
                       reference_position_fn: Callable = None,
                       features_key: str = "features"):
    """Initializes the computation of the dipole moment from partial charges.

    The dipole moment is the sum of the predicted partial charges multiplied
    by the displacements between the particle positions and a reference
    position.

    Args:
        displacement_fn: Function to compute displacement between reference
            point and particle positions. All additional keyword arguments
            of the snapshot function are forwarded to it.
        model: Model to predict the partial charges.
        graph_from_neighbor_list: Function to create molecular graph from
            neighbor list.
        partial_charge_feature: Number of the per-atom feature corresponding
            to the partial charge.
        reference_position_fn: Function to compute the reference position for
            the dipole moment, called as ``reference_position_fn(state,
            **kwargs)``. If None, the origin (zero vector) is used.
        features_key: Key to the pre-computed features if no model is provided.

    Returns:
        Returns a function ``dipole_moment_snapshot(state, mask=None,
        **kwargs)`` to compute dipole moment snapshots.
    """
    # The partial charges are obtained via a snapshot function
    charge_snapshot_fn = snapshot_quantity(
        partial_charge_prediction(model, partial_charge_feature),
        graph_from_neighbor_list,
        features_key,
    )

    def dipole_moment_snapshot(state, mask=None, **kwargs):
        # Without a mask, all particles contribute to the dipole moment
        if mask is None:
            mask = jnp.ones(state.position.shape[0])

        # The reference point is user-defined or defaults to the origin
        if reference_position_fn is not None:
            ref_position = reference_position_fn(state, **kwargs)
        else:
            ref_position = jnp.zeros(state.position.shape[1])

        partial_charges = charge_snapshot_fn(state, mask=mask, **kwargs)

        # Displacement of every particle with respect to the single
        # reference position
        displacement_to_ref = jax.vmap(
            functools.partial(displacement_fn, **kwargs), in_axes=(0, None))
        displacements = displacement_to_ref(state.position, ref_position)

        per_particle_moments = (
            partial_charges[:, None] * displacements * mask[:, None])
        return jnp.sum(per_particle_moments, axis=0)

    return dipole_moment_snapshot