Source code for chemtrain.compose.utils


# Copyright 2026 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities to connect models to chemtrain."""

import numpy as onp

import jax
import jax.numpy as jnp
from jax import lax

from typing import Protocol, Any, Tuple


class ApplyFn(Protocol):
    """GNN apply function protocol.

    An apply function evaluates a graph neural network on a graph given by
    edge index arrays and tuples of edge and node feature arrays.
    """

    def __call__(
        self,
        params: Any,
        senders: jnp.ndarray,
        receivers: jnp.ndarray,
        edge_features: Tuple[jnp.ndarray, ...],
        node_features: Tuple[jnp.ndarray, ...],
    ) -> jnp.ndarray:
        """Evaluate the model on a graph.

        Args:
            params: Model parameters (a pytree).
            senders: Index of the sending node for each edge.
            receivers: Index of the receiving node for each edge.
            edge_features: Variable-length tuple of per-edge arrays.
            node_features: Variable-length tuple of per-node arrays.

        Returns:
            Per-node outputs, e.g., per-particle energies.
        """
        ...
def batch_apply_fn(_apply_fn: ApplyFn) -> ApplyFn:
    """Write custom vmap rules for the apply function.

    Instead of batching over graphs, combines all graphs into a supergraph.
    This step avoids vmapping operations in the neural network.

    If any parameter leaf is batched, the supergraph strategy is not
    applicable and the batch is instead evaluated sequentially via
    ``lax.map``.

    Args:
        _apply_fn: Function following the ``ApplyFn`` protocol, i.e., a
            mapping from ``(params, senders, receivers, edge_features,
            node_features)`` to per-node outputs (e.g., per-particle
            energies).

    Returns:
        A wrapped apply function with custom vmap rules that avoids vmapping
        in the neural network. The wrapped function requires at least one
        node feature array.
    """

    def apply_fn(*args, **kwargs):
        return _apply_fn(*args, **kwargs)

    def wrapped(params, senders, receivers, edge_features, node_features):
        # Check inputs. Node features are important to correctly rewrite
        # invalid indices in the supergraph.
        assert len(node_features) > 0, (
            "At least one node feature array is required."
        )
        return wrapped_apply_fn(
            params, senders, receivers, edge_features, node_features
        )

    @jax.custom_vjp
    @jax.custom_batching.custom_vmap
    def wrapped_apply_fn(params, senders, receivers, edge_features,
                         node_features):
        return apply_fn(
            params, senders, receivers, edge_features, node_features
        )

    def wrapped_fun_fwd(*args):
        y = wrapped_apply_fn(*args)
        # Keep the primal inputs as residuals; the backward pass recomputes
        # the VJP from them.
        return y, args

    @jax.custom_batching.custom_vmap
    def wrapped_fun_bwd(res, y_bar):
        _, vjp_fn = jax.vjp(apply_fn, *res)
        return vjp_fn(y_bar)

    @wrapped_fun_bwd.def_vmap
    def wrapped_fun_bwd_batch(axis_size, in_batched, res, y_bar):
        args_batched, y_bar_batched = in_batched

        if not y_bar_batched:
            # Give the cotangent a leading batch axis. Broadcasting (rather
            # than tiling with fixed repetitions) works for outputs of any
            # rank.
            y_bar = jnp.broadcast_to(y_bar, (axis_size,) + jnp.shape(y_bar))

        # The batched function figures out batching of parameters and graphs
        _, vjp_fn = jax.vjp(
            lambda *args: wrapped_fun_batch(
                axis_size, args_batched, *args, compute_vjp=True
            )[0],
            *res
        )
        grads = vjp_fn(y_bar)
        return grads, args_batched

    def wrapped_fun_batch(axis_size, in_batched, params, senders, receivers,
                          edge_features, node_features, compute_vjp=False):
        bparams, *inputs_batched = in_batched

        # We must avoid the custom_vmap decorated function, as no vjp rule
        # is defined for it.
        fn = apply_fn if compute_vjp else wrapped_apply_fn

        # Find out whether any of the parameters is batched. Reducing over
        # the leaves with ``any`` avoids ``numpy.bool``, which does not
        # exist in NumPy 1.24-1.26.
        batched_params = any(jax.tree.leaves(bparams))

        if batched_params:
            # If the parameters are batched, the super-graph strategy does
            # not work. Thus, we fall back to a sequential evaluation via
            # lax.map. Every unbatched leaf gets a new leading batch axis so
            # that all inputs share the same leading dimension.
            args_stacked = jax.tree.map(
                lambda leaf, batched: (
                    leaf if batched
                    else jnp.broadcast_to(
                        leaf, (axis_size,) + jnp.shape(leaf))
                ),
                (params, senders, receivers, edge_features, node_features),
                (bparams, *inputs_batched),
            )
            energies = lax.map(lambda args: fn(*args), args_stacked)
            return energies, True

        flattened_graph = flatten_graph(
            axis_size, inputs_batched, senders, receivers, edge_features,
            node_features
        )
        energies_flat = fn(params, *flattened_graph)
        # Unflatten the results, keeping any trailing per-node output axes.
        energies = energies_flat.reshape(
            (axis_size, -1) + energies_flat.shape[1:]
        )
        return energies, True

    wrapped_apply_fn.defvjp(wrapped_fun_fwd, wrapped_fun_bwd)
    wrapped_apply_fn.def_vmap(wrapped_fun_batch)

    return wrapped
def flatten_graph(axis_size, in_batched, senders, receivers, edge_features,
                  node_features):
    """Flattens batched graphs into a supergraph.

    The node indices of each graph are offset by the number of nodes of all
    previous graphs, so the graphs do not overlap in the supergraph. Sender
    or receiver indices that are not smaller than the number of nodes per
    graph are treated as invalid and mapped to ``axis_size * natoms``, one
    past the last node of the supergraph.

    Args:
        axis_size: Number of graphs in the batch.
        in_batched: Tuple ``(bsenders, breceivers, bedge_features,
            bnode_features)`` of flags indicating which inputs carry a
            leading batch axis. ``bedge_features`` and ``bnode_features``
            contain one flag per feature array.
        senders: Sender indices with shape ``(num_edges,)`` or, if batched,
            ``(axis_size, num_edges)``.
        receivers: Receiver indices, batched exactly like ``senders``.
        edge_features: Tuple of per-edge arrays with a leading edge axis,
            preceded by the batch axis if batched.
        node_features: Non-empty tuple of per-node arrays with a leading
            node axis, preceded by the batch axis if batched. The first
            array determines the number of nodes per graph.

    Returns:
        Tuple ``(senders, receivers, edge_features, node_features)`` of the
        supergraph. The index arrays have shape ``(axis_size * num_edges,)``.
        For every feature array, the batch axis is merged into the edge/node
        axis while trailing feature axes are kept.
    """
    bsenders, breceivers, bedge_features, bnode_features = in_batched
    num_graphs = axis_size

    num_edges = senders.shape[1] if bsenders else senders.shape[0]
    natoms = (
        node_features[0].shape[1] if bnode_features[0]
        else node_features[0].shape[0]
    )

    if bsenders:
        assert breceivers, (
            "If senders are batched, receivers must be batched."
        )
    else:
        assert not breceivers, (
            "If senders are not batched, receivers must not be batched."
        )
        # Share the same connectivity among all graphs of the batch.
        senders = jnp.broadcast_to(senders, (num_graphs,) + senders.shape)
        receivers = jnp.broadcast_to(
            receivers, (num_graphs,) + receivers.shape
        )

    # Relabel the sender and receiver indices. Offset the indices by the
    # number of atoms in previous graphs; invalid indices all point to the
    # same out-of-range node.
    offsets = natoms * jnp.repeat(jnp.arange(num_graphs), num_edges)
    invalid_index = num_graphs * natoms

    def relabel(indices):
        indices = indices.ravel()
        return jnp.where(indices < natoms, indices + offsets, invalid_index)

    senders = relabel(senders)
    receivers = relabel(receivers)

    def merge_batch_axis(feat, batched):
        # Merge the batch axis into the leading edge/node axis and keep the
        # trailing feature axes. Unbatched features are first replicated for
        # every graph.
        if not batched:
            feat = jnp.broadcast_to(feat, (num_graphs,) + feat.shape)
        return jnp.reshape(feat, (-1,) + feat.shape[2:])

    edge_features_flat = tuple(
        merge_batch_axis(feat, b)
        for b, feat in zip(bedge_features, edge_features)
    )
    node_features_flat = tuple(
        merge_batch_axis(feat, b)
        for b, feat in zip(bnode_features, node_features)
    )

    return senders, receivers, edge_features_flat, node_features_flat