Source code for chemtrain.deploy.exporter

# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Exporting potential models to MLIR."""

import abc
import functools

import jax
from jax import numpy as jnp, export, lax

from typing import Dict, NamedTuple, Any, List, Tuple, Callable, NoReturn

import jax_md_mod
from jax_md import util as md_util, space

from . import graphs, util
from ._protobuf import model_pb2 as model_proto


[docs] class Exporter(metaclass=abc.ABCMeta): """Exports a potential model to an MLIR module. To export a potential model, subclass this class, select an appropriate graph type and define the energy function: Usage: >>> import jax.numpy as jnp >>> from jax_md_mod import custom_energy >>> from jax_md import partition, space >>> from chemtrain.deploy import exporter, graphs >>> class LennardJonesExport(exporter.Exporter): ... ... graph_type = graphs.SimpleSparseNeighborList ... r_cutoff = 5.0 ... unit_style = "real" ... nbr_order = [1, 1] ... ... def energy_fn(self, pos, species, graph): ... ... neighbors = partition.NeighborList( ... jnp.stack((graph.senders, graph.receivers)), ... pos, None, None, graph.senders.size, partition.Sparse, ... None, None, None ... ) ... ... assert neighbors.idx.shape[0] == 2, "Wrong shape" ... ... displacement_fn, _ = space.free() ... apply_fn = custom_energy.customn_lennard_jones_neighbor_list( ... displacement_fn, None, None, ... sigma=jnp.asarray([3.165]), epsilon=jnp.asarray([1.0]), ... r_onset=4.0, r_cutoff=5.0, ... initialize_neighbor_list=False, ... per_particle=True # Important for export ... ) ... ... return apply_fn(pos, neighbors, species=species) ... >>> model = LennardJonesExport() >>> try: ... model.save("out.ptb") ... except AssertionError as e: ... # We need to call the export method first ... print(f"Error: {e}") Error: Model has not been exported yet. Please call `export()` first. >>> >>> model.export() Attributes: graph_type: Specifies the required neighborhood representation and how to generate it from the input data. See ref:`chemtrain.deploy.graphs`. nbr_order: List of two integers specifying the number of neighbors required for the newton and non-newton setting to correctly compute forces. r_cutoff: Cutoff radius for the potential. unit_style: Specifies the units in which the potential requires positions and returns energies. The force units depend solely on the length and energy units. has_aux: If True, the energy function returns additional quantities besides the potential energy as dictionary. """ # Use the default graph containing the full neighbor indices graph_type: graphs.NeighborList = graphs.SimpleSparseNeighborList # Order to which neighbors are required for a correct force computation # in the newton and the non-newton setting nbr_order: List[int] = [1, 1] r_cutoff: float unit_style: str = "real" has_aux: bool = False _symbols: List[str] = None _constraints: List[str] = None _init_fns: List[Callable] = None _proto: model_proto.Model = None
[docs] @abc.abstractmethod def energy_fn(self, position, species, graph): """Computes the energy for positions and a graph representation. Args: position: (N, dim) Array of particle positions, including ghost atoms that are not within the local domain. species: (N) Array of atoms species. graph: Graph representation of the neighborhood around atoms. Returns: Must return an energy contribution associated to each particle. """ pass
@staticmethod @util.define_symbols("n_atoms") def _define_position_shapes(n_atoms, **kwargs): shape_defs = ( jax.ShapeDtypeStruct((n_atoms, 3), jnp.float32), jax.ShapeDtypeStruct((n_atoms,), jnp.int32), jax.ShapeDtypeStruct((), jnp.int32), # n_local jax.ShapeDtypeStruct((), jnp.int32), # n_ghost jax.ShapeDtypeStruct((), jnp.bool_), # newton flag ) return shape_defs def _add_shapes(self, init_fn): init_fn(self._symbols, self._constraints, self._init_fns) def _create_shapes(self): all_symbols = ",".join(self._symbols) symbols = { key: symb for key, symb in zip( self._symbols, export.symbolic_shape(all_symbols, constraints=self._constraints), ) } shapes = [] for init_fn in self._init_fns: shapes.extend(init_fn(**symbols)) # Reset self._symbols, self._constraints, self._init_fns = [], [], [] return shapes def _energy_fn(self, position, species, n_local, n_ghost, newton, *graph_args): # Expects particles to be sorted by local, ghost, and padding atoms valid_mask = jnp.arange(position.shape[0]) < (n_local + n_ghost) local_mask = jnp.arange(position.shape[0]) < n_local graph, build_statistics = self.graph_type.create_from_args( self.r_cutoff, self.nbr_order, position, species, local_mask, valid_mask, newton, *graph_args) graph = lax.stop_gradient(graph) @functools.partial(jax.grad, has_aux=True) def force_and_aux(pos): out = self.energy_fn(pos, species, graph) if self.has_aux: per_atom_energies, aux = out else: per_atom_energies = out aux = {} assert per_atom_energies.shape == local_mask.shape, ( f"Per particle energies have shape {per_atom_energies.shape}, " f"but should have shape {local_mask.shape}." ) # Attention: Force is negative gradient of potential. # Depending on the newton flag, we either compute: # - the gradient of the _total potential_ w.r.t. the _local atoms_ # - the gradient of the _local potential_ w.r.t. _all atoms_ # The latter case equals newton=true and requires additional # communication to sum up the forces on the ghost atoms. total_energy = md_util.high_precision_sum( jnp.where(valid_mask, per_atom_energies, jnp.float32(0.0))) local_energy = md_util.high_precision_sum( jnp.where(local_mask, per_atom_energies, jnp.float32(0.0)) ) force_energy = jnp.where(newton, local_energy, total_energy) force_energy = jnp.negative(force_energy) # Differentiate w.r.t. the total potential in the box, but exclude # ghost atom contributions to the total potential return force_energy, (per_atom_energies, aux) force, (energy, aux) = force_and_aux(position) predictions = dict(U=energy, F=force, **aux) for key, value in predictions.items(): assert value.shape[0] == local_mask.size, ( f"Wrong shape for prediction {value}. All model outputs " f"must be per-atom quantities." ) return predictions, build_statistics
[docs] def export(self) -> None: """Exports the potential model to an MLIR module.""" # Create a new context for each export self._symbols: List[str] = [] self._constraints: List[str] = [] self._init_fns: List[Callable] = [] proto = model_proto.Model() proto.neighbor_list.cutoff = self.r_cutoff proto.unit_style = self.unit_style assert len(self.nbr_order) == 2, ( "The nbr_order must contain the order of required neighbors for " "the newton and non-newton setting." ) proto.neighbor_list.nbr_order.extend(self.nbr_order) self.graph_type.set_properties(proto) # Using the ghost mask in the last layer we can compute correct forces # by accounting for their contribution to the gradient but # mask them out when we compute the total potential to not count # them double. self._add_shapes(self._define_position_shapes) self._add_shapes(self.graph_type.create_symbolic_input_format) export_fn = jax.jit(self._energy_fn) shapes = self._create_shapes() exp: export.Exported = export.export(export_fn, platforms=["cuda"])(*shapes) # Reconstruct the output to save the returned statistics and... predictions, statistics = exp.out_tree.unflatten(exp.out_avals) proto.neighbor_list.statistics_keys.extend(statistics.keys()) proto.quantities.extend(predictions.keys()) proto.mlir_module = exp.mlir_module() self._proto = proto
def __str__(self): assert self._proto is not None, ( "Model has not been exported yet. Please call `export()` first." ) return str(self._proto)
[docs] def save(self, file: str) -> None: """Saves the exported protobuffer to a file. Args: file: Path to the file where the model should be saved. """ assert self._proto is not None, ( "Model has not been exported yet. Please call `export()` first." ) with open(file, "wb") as f: f.write(self._proto.SerializeToString())