Source code for chemtrain.deploy.graphs

# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Graphs for exporting potential and force models."""

import abc
import functools
import typing
from itertools import product

import numpy as onp

import jax
from jax import export, numpy as jnp, lax

import jax_md_mod
from jax_md import partition, dataclasses, smap, space

from typing import NamedTuple, Tuple

from . import util
from ._protobuf import model_pb2 as model_proto


# Generic alias for the statistics mapping that is returned together with a
# neighbor list. It is deliberately left untyped: the concrete keys differ
# between the neighbor list implementations in this module.
ListStatistics = typing.Dict


@dataclasses.dataclass
class NeighborList(metaclass=abc.ABCMeta):
    """Common interface of all neighbor list graphs.

    Concrete graphs describe themselves in the exported protobuf model,
    declare the symbolic format of their inputs, and build themselves from
    the arguments passed to the exported function.
    """

    @staticmethod
    @abc.abstractmethod
    def set_properties(proto: model_proto.Model):
        """Writes the graph type into the protobuf message ``proto``."""

    @staticmethod
    @util.define_symbols("")
    @abc.abstractmethod
    def create_symbolic_input_format(*args, **kwargs):
        """Builds the symbolic description of the graph inputs.

        Args:
            max_atoms: The maximum number of atoms, including ghost atoms and
                padding atoms.
            scope: The scope to add more symbolic variables.

        Symbolic variables introduced here should begin with "graph_".

        Returns:
            A symbolic representation of the graph.

        """

    @staticmethod
    def create_from_args(
        displacement_fn,
        r_cutoff,
        num_mpl,
        position,
        species,
        ghost_mask,
        valid_mask,
        newton,
        *args,
        half=True
    ) -> Tuple["NeighborList", "ListStatistics"]:
        """Builds the neighbor list from the inputs of the exported function.

        The base implementation is an empty placeholder and returns ``None``;
        the concrete graphs in this module provide their own implementation.
        """


@dataclasses.dataclass
class SimpleSparseNeighborList(NeighborList):
    """Sparse graph wrapping a precomputed neighbor list.

    The neighbor list is not inferred from the positions. Instead, this class
    is the interface between a precomputed neighbor list, e.g., from LAMMPS,
    and the exported model.

    To make the exported model more efficient and to reduce the necessary
    data transfer, the list is pruned: edges longer than the model cutoff are
    dropped, and so are edges between ghost atoms (atoms outside the local
    domain) that are irrelevant for a correct force computation.

    Attributes:
        senders: The sender indices of the edges.
        receivers: The receiver indices of the edges.
        max_edges: The maximum number of relevant edges in the neighbor list.
    """

    senders: jax.Array
    receivers: jax.Array
    max_edges: jax.Array

    @staticmethod
    def set_properties(proto: model_proto.Model):
        """Marks ``proto`` as using a simple sparse half neighbor list."""
        proto.neighbor_list.half_list = True
        proto.neighbor_list.type = (
            model_proto.Model.NeighborListType.SIMPLE_SPARSE)

    @staticmethod
    @util.define_symbols(
        "max_buffers, max_edges", ["max_edges <= 2 * max_buffers"]
    )
    def create_symbolic_input_format(max_buffers, max_edges, **kwargs):
        """Returns the shape/dtype specs of senders, receivers, and buffer.

        The sender and receiver specs are int32 vectors of length
        ``max_buffers``; the buffer spec is a boolean vector of length
        ``max_edges``.
        """
        index_shape = (max_buffers,)
        senders = jax.ShapeDtypeStruct(index_shape, jnp.int32)
        receivers = jax.ShapeDtypeStruct(index_shape, jnp.int32)
        buffer = jax.ShapeDtypeStruct((max_edges,), jnp.bool_)
        return senders, receivers, buffer

    @staticmethod
    def create_from_args(r_cutoff,
                         nbr_order,
                         position,
                         species,
                         ghost_mask,
                         valid_mask,
                         newton,
                         *args
                         ) -> Tuple["SimpleSparseNeighborList",
                                    "NeighborListStatistics"]:
        """Builds the pruned sparse graph from the exported function inputs.

        ``args`` must hold exactly ``(senders, receivers, buffer)``. The size
        of ``buffer`` fixes the number of edges kept after pruning.

        Returns:
            The pruned graph and a statistics dict with the number of edges
            that survived the pruning (``max_neighbors``) and the number of
            transferred edges that are not longer than the cutoff
            (``overlong``).
        """
        senders, receivers, edge_buffer = args

        # Index used to mark removed edges.
        padding_idx = species.size

        # Edges longer than the cutoff are replaced by the padding index.
        edge_vectors = position[senders] - position[receivers]
        too_long = jnp.linalg.norm(edge_vectors, axis=-1) > r_cutoff
        graph = SimpleSparseNeighborList(
            jnp.where(too_long, padding_idx, senders),
            jnp.where(too_long, padding_idx, receivers),
            edge_buffer,
        )

        # In the newton setting, the transferred neighbor list is treated as
        # a full list; otherwise it is treated as a half list. Both branches
        # keep at most as many edges as the buffer can hold.
        prune_full_list = functools.partial(
            prune_neighbor_list, max_edges=edge_buffer.size,
            nbr_order=nbr_order[0], half_list=False)
        prune_half_list = functools.partial(
            prune_neighbor_list, max_edges=edge_buffer.size,
            nbr_order=nbr_order[1], half_list=True)
        graph, max_neighbors = lax.cond(
            newton, prune_full_list, prune_half_list, graph, ghost_mask)

        # NOTE(review): despite its name, ``overlong`` counts the edges that
        # are NOT longer than the cutoff. Confirm against the consumer of
        # these statistics before changing it.
        statistics = NeighborListStatistics(
            max_neighbors=max_neighbors,
            overlong=jnp.sum(~too_long)
        )
        return graph, statistics

    def to_neighborlist(self):
        """Wraps the edge indices in a sparse ``partition.NeighborList``."""
        edge_idx = jnp.stack([self.senders, self.receivers], axis=0)
        return partition.NeighborList(
            edge_idx, None, None, None, None, partition.Sparse,
            None, None, None)
@dataclasses.dataclass
class SimpleDenseNeighborList(NeighborList):
    """Simple dense neighbor list representation using precomputed neighbor list.

    This neighbor list is a semi-sparse representation of a graph.
    It does not infer the neighbor list from the positions but acts as an
    interface between the precomputed neighbor list, e.g., from LAMMPS, and
    the exported model.

    This class increases the efficiency of the exported model while reducing
    necessary data transfer by pruning the neighbor list. Therefore, the class
    filters out all edges that are longer than the specified model cutoff
    distance. Moreover, it prunes all edges between ghost atoms (atoms not in
    the local domain) that are not relevant for a correct force computation.

    Attributes:
        nbrs: Neighbor indices with one row per atom. Removed entries hold
            the padding index (the number of atoms).
        max_edges: Boolean buffer whose length is the symbolic dimension
            ``max_edges``.
        max_triplets: Boolean buffer whose length is the symbolic dimension
            ``max_triplets``.
    """

    nbrs: jax.Array
    max_edges: jax.Array
    max_triplets: jax.Array

    @staticmethod
    def set_properties(proto: model_proto.Model):
        """Marks ``proto`` as using a simple dense full neighbor list."""
        # Look the enum up on the message class, consistent with
        # SimpleSparseNeighborList.set_properties.
        proto.neighbor_list.type = (
            model_proto.Model.NeighborListType.SIMPLE_DENSE)
        proto.neighbor_list.half_list = False

    @staticmethod
    @util.define_symbols(
        "max_nbrs, max_edges, max_triplets",
        [
            "max_nbrs <= n_atoms",
            "max_edges <= n_atoms * max_nbrs",
            "max_triplets <= max_edges * max_nbrs"
        ]
    )
    def create_symbolic_input_format(max_nbrs, max_edges, max_triplets,
                                     **kwargs):
        """Returns the shape/dtype specs of the three graph inputs.

        ``kwargs`` must contain ``n_atoms``, the number of rows of the
        neighbor index array.
        """
        nbrs = jax.ShapeDtypeStruct((kwargs["n_atoms"], max_nbrs), jnp.int32)
        max_edges = jax.ShapeDtypeStruct((max_edges,), jnp.bool_)
        max_triplets = jax.ShapeDtypeStruct((max_triplets,), jnp.bool_)
        return nbrs, max_edges, max_triplets

    @staticmethod
    def create_from_args(r_cutoff,
                         nbr_order,
                         position,
                         species,
                         ghost_mask,
                         valid_mask,
                         newton,
                         *args
                         ) -> Tuple["SimpleDenseNeighborList",
                                    "NeighborListStatistics"]:
        """Builds the pruned dense graph from the exported function inputs.

        ``args`` must hold exactly ``(nbrs, max_edges, max_triplets)``.

        Returns:
            The pruned graph and a statistics dict. For the dense list,
            ``max_neighbors`` carries the number of valid edges and
            ``overlong`` carries the number of triplets after pruning.
        """
        # Index used to mark removed neighbors.
        invalid_idx = species.size

        nbrs, max_edges, max_triplets = args

        # Remove all edges that are longer than the cutoff distance. The
        # distance is evaluated between atom i (row index) and each of its
        # listed neighbors j.
        def _pair_distance(i, j):
            return jnp.linalg.norm(position[i] - position[j])

        row_distances = jax.vmap(_pair_distance, in_axes=(None, 0))
        dists = jax.vmap(row_distances, in_axes=(0, 0))(
            jnp.arange(nbrs.shape[0]), nbrs)

        invalid = dists > r_cutoff
        nbrs = jnp.where(invalid, invalid_idx, nbrs)

        # Prune all irrelevant edges. The neighbor order used for pruning
        # depends on the newton setting.
        graph = SimpleDenseNeighborList(nbrs, max_edges, max_triplets)
        graph, (n_edges, n_triplets) = lax.cond(
            newton,
            functools.partial(prune_neighbor_list_dense,
                              nbr_order=nbr_order[0]),
            functools.partial(prune_neighbor_list_dense,
                              nbr_order=nbr_order[1]),
            graph, ghost_mask
        )

        statistics = NeighborListStatistics(
            max_neighbors=n_edges,
            overlong=n_triplets
        )

        return graph, statistics

    def to_neighborlist(self):
        """Wraps the neighbor indices in a dense ``partition.NeighborList``."""
        nbrs = partition.NeighborList(
            self.nbrs, None, None, None, None, partition.Dense,
            None, None, None)
        return nbrs
class DeviceSparseNeighborListArgs(NamedTuple):
    """Positional graph inputs of :class:`DeviceSparseNeighborList`.

    Apart from ``senders`` and ``receivers``, the arrays only transport
    dimensions through their shapes (see
    :meth:`DeviceSparseNeighborList.create_symbolic_input_format`).
    """
    update: jax.Array | jax.ShapeDtypeStruct
    xcells: jax.Array | jax.ShapeDtypeStruct
    ycells: jax.Array | jax.ShapeDtypeStruct
    zcells: jax.Array | jax.ShapeDtypeStruct
    capacity: jax.Array | jax.ShapeDtypeStruct
    # ref_pos: jax.Array | jax.ShapeDtypeStruct
    # cutoff: jax.Array | jax.ShapeDtypeStruct
    # skin: jax.Array | jax.ShapeDtypeStruct
    senders: jax.Array | jax.ShapeDtypeStruct
    receivers: jax.Array | jax.ShapeDtypeStruct


@dataclasses.dataclass
class DeviceSparseNeighborList(NeighborList):
    """Creates the neighbor list graph on the device using a cell list.

    Warning:
        This implementation is experimental and work in progress.

    """

    @staticmethod
    def set_properties(proto: model_proto.Model):
        """Marks ``proto`` as using the device-built sparse neighbor list."""
        # Look the enum up on the message class, consistent with
        # SimpleSparseNeighborList.set_properties.
        proto.neighbor_list.type = (
            model_proto.Model.NeighborListType.DEVICE_SPARSE)

    @staticmethod
    @util.define_symbols(
        "max_neighbors, nx, ny, nz, c",
        ["c <= n_atoms", "27*c^2*nx*ny*nz >= max_neighbors"]
    )
    def create_symbolic_input_format(max_neighbors, nx, ny, nz, c, *,
                                     n_atoms, **kwargs):
        """Returns the shape/dtype specs of the graph inputs.

        The order of the returned specs matches the fields of
        :class:`DeviceSparseNeighborListArgs`.
        """
        # Currently, JAX can only infer dimensions from array shapes but not
        # the input. Therefore, the grid dimensions and the cell capacity are
        # passed as the lengths of boolean arrays.
        update = jax.ShapeDtypeStruct((1,), jnp.bool_)
        xcells = jax.ShapeDtypeStruct((nx,), jnp.bool_)
        ycells = jax.ShapeDtypeStruct((ny,), jnp.bool_)
        zcells = jax.ShapeDtypeStruct((nz,), jnp.bool_)
        capacity = jax.ShapeDtypeStruct((c,), jnp.bool_)

        senders = jax.ShapeDtypeStruct((max_neighbors,), jnp.int32)
        receivers = jax.ShapeDtypeStruct((max_neighbors,), jnp.int32)

        return (
            update, xcells, ycells, zcells, capacity, senders, receivers
        )

    @staticmethod
    def create_from_args(r_cutoff, num_mpl, positions, species, ghost_mask,
                         valid_mask, *args):
        """Builds or reuses the sparse neighbor list on the device.

        ``args`` must match the fields of
        :class:`DeviceSparseNeighborListArgs`. If ``update`` is set, the
        neighbor list is recomputed from the positions via a cell list;
        otherwise the transferred senders and receivers are reused.

        Returns:
            A :class:`SimpleSparseNeighborList` and a tuple consisting of the
            statistics (``min_cell_capacity``, ``cell_too_small``,
            ``max_neighbors``) followed by the senders and receivers.
        """
        # NOTE(review): the sibling create_from_args implementations in this
        # file take a ``newton`` argument before ``*args``; this signature
        # does not. Confirm the calling convention.
        nargs = DeviceSparseNeighborListArgs(*args)

        # Only the shape of the buffer matters: it defines the grid
        # dimensions and the capacity of each cell.
        buffer = jnp.zeros(
            (
                nargs.xcells.size, nargs.ycells.size, nargs.zcells.size,
                nargs.capacity.size
            ), dtype=jnp.int32
        )

        # TODO: Skipping the recomputation based on the displacement since
        #   the last build is not implemented yet.

        # Hard-coded skin that is added to the cutoff.
        skin = 2.0
        update_fn = functools.partial(
            compute_neighbor_list, positions, buffer, nargs.senders,
            cutoff=r_cutoff + skin, mask=valid_mask
        )

        def reuse_fn():
            # Return the statistics from the previous build. The keys are
            # those of DeviceListStatistics, so that both branches of the
            # conditional share the same structure.
            statistics = DeviceListStatistics(
                min_cell_capacity=nargs.capacity.size,
                cell_too_small=0,
                max_neighbors=nargs.senders.size)
            return (nargs.senders, nargs.receivers), statistics

        (senders, receivers), statistics = lax.cond(
            nargs.update.squeeze(), update_fn, reuse_fn)

        # SimpleSparseNeighborList has three fields. As in
        # prune_neighbor_list, the third field marks the valid edges; removed
        # edges carry the number of atoms as index.
        edge_mask = senders < positions.shape[0]
        graph = SimpleSparseNeighborList(senders, receivers, edge_mask)

        # The statistics are a plain dict (TypedDict), so the values are
        # read by key in the order of the DeviceListStatistics fields.
        outputs = (
            statistics["min_cell_capacity"],
            statistics["cell_too_small"],
            statistics["max_neighbors"],
            senders,
            receivers,
        )
        return graph, outputs


class DeviceListStatistics(typing.TypedDict, total=True):
    """Statistics for the :class:`DeviceSparseNeighborList`."""
    min_cell_capacity: typing.Required[int]
    cell_too_small: typing.Required[int]
    max_neighbors: typing.Required[int]
[docs] class NeighborListStatistics(typing.TypedDict, total=True): """Statistics for the :class:`SimpleSparseNeighborList`.""" max_neighbors: typing.Required[int] overlong: typing.Optional[int]
@jax.jit
def compute_cell_list(position, id_buffer, cutoff, mask=None, eps=1e-3):
    """Assigns particle IDs into a 3D grid.

    This implementation follows the JAX, M.D. implementation, but aims to
    support building a cell list by only using shape information from the
    input arguments.

    Args:
        position: The positions of the atoms with shape (n, 3).
        id_buffer: Determines the dimensions of the grid and the cell
            capacities. Shape (nx, ny, nz, c) corresponds to the numbers of
            cells in x,y,z dimensions and the maximum capacity per cell c.
            Only the shape is used.
        cutoff: Cutoff to check the dimensions of the cells. If the cell
            dimensions are smaller than the cutoff, the cells are enlarged
            to the cutoff. Has the downside that cells will get fuller than
            usual, but will still yield correct neighbor list results.
        mask: Specifies whether particles should be ignored (mask = 0).
            If None, all particles are considered.
        eps: Currently unused. Kept for interface compatibility.

    Returns:
        Returns a tuple with the particle ids per cell and a dict with the
        keys of :class:`DeviceListStatistics`. Empty slots of a cell hold
        the number of particles as index. ``min_cell_capacity`` is the
        largest number of particles in a single cell, ``cell_too_small``
        encodes the dimensions with cells smaller than the cutoff as bits
        (x = 1, y = 2, z = 4), and ``max_neighbors`` is set to 0.

    """
    n_particles = position.shape[0]
    if mask is None:
        mask = jnp.ones(n_particles, dtype=bool)

    *cell_counts, capacity = id_buffer.shape
    nx, ny, nz = cell_counts
    n_cells = nx * ny * nz

    # Shift the positions to be in the range [0, box]. First, we move the
    # masked particles to the mean position of the valid particles, such that
    # they do not influence the range. The mean must be normalized by the
    # number of valid particles, otherwise it is biased towards the origin.
    # Then we shift the positions to be positive.
    valid = mask[:, jnp.newaxis]
    n_valid = jnp.maximum(jnp.sum(mask), 1)
    mean_position = jnp.sum(
        valid * position, axis=0, keepdims=True) / n_valid
    position = jnp.where(valid, position, mean_position)
    position -= jnp.min(position, axis=0, keepdims=True)

    # TODO: How big should the tolerance be?
    box_lengths = jnp.max(position, axis=0) + 0.5 * cutoff

    # Generally, the minimum cell dimension must be larger than the cutoff,
    # such that all potential neighbors are contained in the neighboring
    # cells. Workaround: Enlarge the too small cells to the cutoff. This
    # works if the cell capacity is big enough.
    cell_sizes = box_lengths / jnp.asarray(cell_counts)
    too_small = cell_sizes < cutoff
    cell_too_small = jnp.sum(too_small * 2 ** jnp.arange(3))
    cell_sizes = jnp.where(too_small, cutoff, cell_sizes)

    # Get the cell ids for each particle in every dimension
    # (n, x_id, y_id, z_id) and transform them into flat ids. The clipping
    # guards against round-off placing a particle outside of the grid.
    # Invalid particles are assigned an invalid cell id such that they are
    # not member of any of the cells.
    grid_ids = jnp.int32(jnp.floor(position / cell_sizes[jnp.newaxis, :]))
    grid_ids = jnp.clip(grid_ids, 0, jnp.asarray(cell_counts) - 1)
    cell_ids = jnp.sum(grid_ids * jnp.asarray([[nz * ny, nz, 1]]), axis=-1)
    cell_ids = jnp.where(mask, cell_ids, n_cells)

    # Count how many valid particles fall into each cell. The number of
    # segments must cover all cells plus the invalid cell.
    cell_occupancy = jax.ops.segment_sum(
        jnp.asarray(mask, dtype=jnp.int32), cell_ids, n_cells + 1)
    min_cell_capacity = jnp.max(cell_occupancy)

    # We sort the particles along their cell id to obtain, e.g.
    # the cell id array (0, 0, 0, 1, 1, 2, 3, ...). If the capacity is
    # sufficiently large, each segment should be no longer than the capacity.
    # We now create a second array with repeating numbers 0 ... capacity,
    # such that within each segment every number appears at most once.
    sort_idx = jnp.argsort(cell_ids)
    particle_ids = jnp.arange(n_particles)
    unique_id_per_segment = jnp.mod(
        lax.iota(jnp.int32, n_particles), capacity)

    # The additional last row collects the invalid particles and is dropped.
    new_id_buffer = jnp.full((n_cells + 1, capacity), n_particles)
    new_id_buffer = new_id_buffer.at[
        cell_ids[sort_idx], unique_id_per_segment
    ].set(particle_ids[sort_idx])
    new_id_buffer = new_id_buffer[:-1, :].reshape(id_buffer.shape)

    # A TypedDict cannot be constructed from positional arguments, so the
    # statistics are assembled as a keyed mapping.
    statistics: "DeviceListStatistics" = {
        "min_cell_capacity": min_cell_capacity,
        "cell_too_small": cell_too_small,
        "max_neighbors": 0,
    }

    return new_id_buffer, statistics
@jax.jit
def compute_neighbor_list(position, id_buffer, senders, cutoff, mask=None,
                          eps=1e-3):
    """Computes a sparse neighbor list using a cell list.

    Args:
        position: The positions of the atoms.
        id_buffer: Determines the dimensions of the grid and the cell
            capacity. Only the shape (nx, ny, nz, c) is used.
        senders: Determines the maximum number of edges via its size.
        cutoff: Includes neighbors up to this distance.
        mask: Specifies whether particles should be ignored (mask = 0).
            If None, all particles are considered.
        eps: Forwarded to :func:`compute_cell_list`.

    Returns:
        Returns a tuple with sender-receiver pairs and statistics of the
        neighbor list construction. Both index arrays have the size of
        ``senders``; unused entries hold the number of atoms as index.

    """
    if mask is None:
        mask = jnp.ones(position.shape[0], dtype=bool)

    invalid_idx = position.shape[0]

    # Compute the offsets of all neighboring cells
    offset_in_dim = jnp.arange(3) - 1
    xn, yn, zn = jnp.meshgrid(
        offset_in_dim, offset_in_dim, offset_in_dim, indexing='ij')

    nx, ny, nz, _ = id_buffer.shape

    id_buffer, statistics = compute_cell_list(
        position, id_buffer, cutoff, mask=mask, eps=eps)

    cut_sq = jnp.square(cutoff)

    # Build the neighbor list for all cells
    @functools.partial(jax.vmap, in_axes=(0, None, None))
    @functools.partial(jax.vmap, in_axes=(None, 0, None))
    @functools.partial(jax.vmap, in_axes=(None, None, 0))
    def cell_candidate_fn(cx, cy, cz):
        # Get the ids of all neighboring cells. For at least
        # three cells per dimension, this should not count edges double.
        all_cx = jnp.mod(cx + xn, nx).ravel()
        all_cy = jnp.mod(cy + yn, ny).ravel()
        all_cz = jnp.mod(cz + zn, nz).ravel()

        # These are the indices of all particles that could be neighbors.
        # Senders are only the particles of the current cell, such that no
        # directed edge is counted double.
        receiver_idxs = id_buffer[all_cx, all_cy, all_cz, :]
        sender_idxs = id_buffer[cx, cy, cz, :]

        # Transform to sparse list
        cell_senders, cell_receivers = jnp.meshgrid(
            sender_idxs, receiver_idxs.ravel(), indexing='ij')
        cell_senders = cell_senders.ravel()
        cell_receivers = cell_receivers.ravel()

        sender_pos = position[cell_senders, :]
        receiver_pos = position[cell_receivers, :]

        # Compute all the distances (senders, receivers)
        dist_sq = jnp.sum((receiver_pos - sender_pos) ** 2, axis=-1)

        # Select valid neighbors within cutoff
        cell_mask = dist_sq < cut_sq

        # Remove edges from or to masked particles
        cell_mask = jnp.logical_and(cell_mask, mask[cell_senders])
        cell_mask = jnp.logical_and(cell_mask, mask[cell_receivers])

        # Remove edges to self
        cell_mask = jnp.logical_and(
            cell_mask, cell_senders != cell_receivers)

        # Remove edges that involve an empty slot of a cell
        cell_mask = jnp.logical_and(cell_mask, cell_senders < invalid_idx)
        cell_mask = jnp.logical_and(cell_mask, cell_receivers < invalid_idx)

        # Apply mask to neighbor list
        cell_senders = jnp.where(cell_mask, cell_senders, invalid_idx)
        cell_receivers = jnp.where(cell_mask, cell_receivers, invalid_idx)

        return cell_senders, cell_receivers

    new_senders, new_receivers = cell_candidate_fn(
        jnp.arange(nx), jnp.arange(ny), jnp.arange(nz)
    )
    new_senders, new_receivers = new_senders.ravel(), new_receivers.ravel()

    max_neighbors = senders.size
    valid_neighbors = jnp.sum(new_receivers < invalid_idx)

    # Keep the entries with the smallest receiver indices. Removed edges
    # carry the largest index and are therefore dropped first.
    _, prune_idx = lax.top_k(-new_receivers, max_neighbors)
    valid_pruned_neighbors = jnp.sum(new_receivers[prune_idx] < invalid_idx)

    # The statistics are a plain dict (TypedDict) and provide no ``set``
    # method, so the updated values are merged into a new dict.
    # NOTE(review): ``cell_too_small`` is overwritten with the number of
    # valid edges after pruning, as in the original code. Confirm that the
    # consumer expects this.
    statistics = dict(
        statistics,
        max_neighbors=valid_neighbors,
        cell_too_small=valid_pruned_neighbors)

    return (new_senders[prune_idx], new_receivers[prune_idx]), statistics
def prune_neighbor_list(list, local, max_edges, nbr_order: int,
                        half_list: bool = False):
    """Drops all edges of a sparse neighbor list irrelevant to local atoms.

    A neighbor list might be built for all atoms within, e.g., a rectangular
    domain and can thus contain atoms that do not influence the forces on
    the local atoms. This function keeps only the edges between atoms that
    can be reached from a local atom within ``nbr_order`` hops. For example,
    for a simple Lennard-Jones potential, only first-order neighbors of
    local atoms are relevant.

    Args:
        list: Sparse neighbor list to prune.
        local: Mask specifying the local atoms.
        max_edges: Maximum number of edges in the pruned list.
        nbr_order: Maximum order of neighbors required for the force
            computation.
        half_list: If True, the neighbor list is a half list. This means
            that an edge from i to j implies an edge from j to i.

    Returns:
        Returns the pruned neighbor list and the number of valid edges.

    """
    # The number of atoms doubles as the index that marks removed edges.
    n_atoms = local.size

    if half_list:
        # Append the reversed edges to obtain a full list.
        extra_senders, extra_receivers = list.receivers, list.senders
    else:
        # Append padding such that both settings yield the same size.
        padding = jnp.full_like(list.senders, n_atoms)
        extra_senders, extra_receivers = padding, padding

    list = list.set(
        senders=jnp.concatenate([list.senders, extra_senders], axis=0),
        receivers=jnp.concatenate([list.receivers, extra_receivers], axis=0),
    )
    senders, receivers = list.senders, list.receivers

    def _propagate(reached, step):
        # An atom becomes reachable if it receives an edge from a reachable
        # atom. The segment maximum acts like a logical any.
        incoming = jax.ops.segment_max(
            reached[senders], receivers, reached.size)
        return reached | incoming, step

    # Starting from the local atoms, grow the set of reachable atoms by one
    # hop per iteration.
    reachable, _ = lax.scan(_propagate, local, jnp.arange(nbr_order))

    # Keep the edges whose two atoms are reachable and not padding.
    keep = reachable[senders] & reachable[receivers]
    keep &= (senders < n_atoms) & (receivers < n_atoms)
    n_valid = jnp.sum(keep)

    kept_senders = jnp.where(keep, senders, n_atoms)
    kept_receivers = jnp.where(keep, receivers, n_atoms)

    # Reduce the size of the neighbor list by moving the kept edges to the
    # front and truncating to max_edges entries.
    edge_mask, order = lax.top_k(keep, k=max_edges)
    pruned = SimpleSparseNeighborList(
        kept_senders[order], kept_receivers[order], edge_mask)

    return pruned, n_valid
def prune_neighbor_list_dense(list, local, nbr_order: int):
    """Prunes a dense neighbor list.

    Only atoms that can be reached from a local atom within ``nbr_order``
    hops keep their edges. All other entries are replaced by the padding
    index, i.e., the number of atoms.

    Args:
        list: Dense neighbor list to prune. Must provide the neighbor
            indices as ``nbrs`` with one row per atom and a ``set`` method
            to replace them.
        local: Boolean mask specifying the local atoms.
        nbr_order: Maximum order of neighbors required for the force
            computation.

    Returns:
        Returns the pruned neighbor list, the number of valid edges, and the
        number of triplets from the valid edges.

    """
    n_atoms = local.size

    # Padding entries must never propagate reachability. Note that the
    # comparison has to be parenthesized, since ``&`` binds tighter
    # than ``<``.
    is_neighbor = list.nbrs < n_atoms

    def _update(reachable, _):
        # Any connection to a reachable node makes the node itself
        # reachable. Nodes that are already reachable, e.g., the local
        # atoms, stay reachable.
        via_neighbor = jnp.any(
            reachable[list.nbrs] & is_neighbor, axis=1, keepdims=False)
        return reachable | via_neighbor, _

    reachable, _ = lax.scan(_update, local, jnp.arange(nbr_order))

    # Every unreachable node does not send out edges (row will be padding).
    # Every unreachable node does not receive edges (check indices).
    nbrs = jnp.where(reachable[:, None], list.nbrs, n_atoms)
    nbrs = jnp.where(reachable[nbrs], nbrs, n_atoms)

    nbrs_per_atom = jnp.sum(nbrs < n_atoms, axis=1)
    max_edges = jnp.sum(nbrs_per_atom)
    max_triplets = jnp.sum(nbrs_per_atom * (nbrs_per_atom - 1))

    return list.set(nbrs=nbrs), (max_edges, max_triplets)