Source code for jax_md_mod.custom_partition

# Copyright 2019 Google LLC
# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom functions to analyze the neighborlist graph."""
import importlib
import warnings
from typing import Union

import functools

import jax
import jax.numpy as jnp
from jax import Array, export

import numpy as onp

from jax_md import partition, space
from jax_md_mod.model import sparse_graph


def mask_dense(idx, mask=None):
    """Invalidates self-edges and edges touching masked rows in a dense index.

    Args:
        idx: Integer array whose first axis enumerates the sending particles
            and whose entries are the indices of the receiving particles.
        mask: Optional per-particle array; truthy entries mark valid
            particles. If None, only edges of a particle to itself are
            invalidated.

    Returns:
        Array shaped like ``idx`` in which every invalidated entry is
        replaced by ``idx.shape[0]``.
    """
    n_rows = idx.shape[0]
    row_ids = jnp.arange(n_rows)[:, jnp.newaxis]

    # An entry equal to its own row index is an edge from a particle to itself.
    drop = idx == row_ids

    if mask is not None:
        # Every edge leaving an invalid particle is dropped ...
        invalid_sender = jnp.logical_not(mask)[:, jnp.newaxis]
        drop = jnp.logical_or(invalid_sender, drop)
        # ... and so is every edge arriving at an invalid particle.
        invalid_receiver = jnp.logical_not(mask[idx])
        drop = jnp.logical_or(drop, invalid_receiver)

    return jnp.where(drop, n_rows, idx)


def mask_neighbor_list(nbrs: partition.NeighborList,
                       mask: Array = None) -> partition.NeighborList:
    """Masks the neighbor list indices.

    Args:
        nbrs: Dense or sparse neighbor list.
        mask: Boolean array masking valid particles (True). Edges from and to
            invalid particles (False) are removed from the neighbor list.
            If None, all particles are treated as valid.

    Returns:
        Returns a neighbor list without edges to invalid particles. The
        reference positions of invalid particles are set to zero.
    """
    n_particles = nbrs.reference_position.shape[0]

    # The signature advertises ``mask=None``; treat it as "everything valid"
    # instead of failing when the mask is indexed below.
    if mask is None:
        mask = jnp.ones(n_particles, dtype=jnp.bool_)

    def mask_sparse(idx, mask):
        # Mask out all invalid edges
        senders, receivers = idx
        edge_mask = jnp.logical_or(
            jnp.logical_not(mask[senders]), jnp.logical_not(mask[receivers])
        )
        # Both endpoints of an invalid edge are replaced by the padding index.
        return jnp.where(edge_mask[jnp.newaxis, :], n_particles, idx)

    if nbrs.format == partition.NeighborListFormat.Dense:
        new_idx = mask_dense(nbrs.idx, mask)
    else:
        new_idx = mask_sparse(nbrs.idx, mask)

    new_position = jnp.where(
        mask[:, jnp.newaxis], nbrs.reference_position, 0.0)

    return nbrs.set(idx=new_idx, reference_position=new_position)
def masked_neighbor_list(displacement_or_metric,
                         r_cutoff: float,
                         dr_threshold: Union[float, None] = None,
                         capacity_multiplier: float = 1.25,
                         format = partition.NeighborListFormat.Dense,
                         ) -> partition.NeighborFn:
    """Returns a function that builds a list of neighbors for collections of points.

    Adapts the JAX, M.D. neighbor list :func:`jax_md.partition.neighbor_list`
    to allow for masking and to enforce rebuilding of the neighbor list.

    The returned allocate/update functions read an optional ``mask`` keyword
    argument. Positions of particles whose mask entry is False are replaced
    by ``jnp.inf`` before the candidate edges are pruned. All keyword
    arguments (including ``mask``) are also forwarded to the metric derived
    from ``displacement_or_metric``.
    NOTE(review): this presumably requires the displacement function to
    accept/ignore a ``mask`` keyword - confirm against the callers.

    Args:
        displacement_or_metric: A function `d(R_a, R_b)` that computes the
            displacement between pairs of points.
        r_cutoff: A scalar specifying the neighborhood radius.
        dr_threshold: A scalar specifying the maximum distance particles can
            move before rebuilding the neighbor list. If specified to None,
            the neighbor list will always be rebuilt.
        capacity_multiplier: A floating point scalar specifying the fractional
            increase in maximum neighborhood occupancy we allocate compared
            with the maximum in the example positions.
        format: The format of the neighbor list; see the
            :meth:`NeighborListFormat` enum for details about the different
            choices for formats. Defaults to `Dense`.

    Returns:
        A NeighborListFns object that contains a method to allocate a new
        neighbor list and a method to update an existing neighbor list.

    """
    # ``None`` selects the "always rebuild" mode; the numeric threshold is
    # then zero so that the cutoff below equals ``r_cutoff``.
    always_recompute = dr_threshold is None
    if always_recompute:
        dr_threshold = 0.0

    partition.is_format_valid(format)
    r_cutoff = jax.lax.stop_gradient(r_cutoff)
    dr_threshold = jax.lax.stop_gradient(dr_threshold)

    cutoff = r_cutoff + dr_threshold
    cutoff_sq = cutoff ** 2
    threshold_sq = (dr_threshold / partition.f32(2)) ** 2
    metric_sq = partition._displacement_or_metric_to_metric_sq(displacement_or_metric)

    @functools.partial(jax.jit, static_argnums=0)
    def candidate_fn(positionShape) -> Array:
        # Every particle is a candidate neighbor of every particle.
        candidates = jnp.arange(positionShape[0])
        return jnp.broadcast_to(candidates[None, :],
                                (positionShape[0], positionShape[0]))

    # Not called within this function; ``mask_dense`` is used to mask
    # self-edges instead.
    @jax.jit
    def mask_self_fn(idx: Array) -> Array:
        self_mask = idx == jnp.reshape(jnp.arange(idx.shape[0], dtype=partition.i32),
                                       (idx.shape[0], 1))
        return jnp.where(self_mask, idx.shape[0], idx)

    @jax.jit
    def prune_neighbor_list_dense(position: Array, idx: Array, **kwargs
                                  ) -> Array:
        d = functools.partial(metric_sq, **kwargs)
        d = space.map_neighbor(d)

        N = position.shape[0]
        neigh_position = position[idx]
        dR = d(position, neigh_position)

        # Keep candidates within the (squared) cutoff that are not padding.
        mask = (dR < cutoff_sq) & (idx < N)
        out_idx = N * jnp.ones(idx.shape, partition.i32)

        # Compact the kept entries to the front of each row; rejected entries
        # are all written to the last column.
        cumsum = jnp.cumsum(mask, axis=1)
        index = jnp.where(mask, cumsum - 1, idx.shape[1] - 1)
        p_index = jnp.arange(idx.shape[0])[:, None]
        out_idx = out_idx.at[p_index, index].set(idx)
        max_occupancy = jnp.max(cumsum[:, -1])

        return out_idx, max_occupancy

    @jax.jit
    def prune_neighbor_list_sparse(position: Array, idx: Array, **kwargs
                                   ) -> Array:
        d = functools.partial(metric_sq, **kwargs)
        d = space.map_bond(d)

        N = position.shape[0]
        sender_idx = jnp.broadcast_to(jnp.arange(N)[:, None], idx.shape)

        # Flatten the dense candidate matrix into a list of edges.
        sender_idx = jnp.reshape(sender_idx, (-1,))
        receiver_idx = jnp.reshape(idx, (-1,))
        dR = d(position[sender_idx], position[receiver_idx])

        mask = (dR < cutoff_sq) & (receiver_idx < N)
        if format is partition.NeighborListFormat.OrderedSparse:
            # Keep only edges with receiver index below the sender index.
            mask = mask & (receiver_idx < sender_idx)

        out_idx = N * jnp.ones(receiver_idx.shape, partition.i32)

        # Compact the kept edges to the front; rejected edges are all written
        # to the last slot.
        cumsum = jnp.cumsum(mask)
        index = jnp.where(mask, cumsum - 1, len(receiver_idx) - 1)
        receiver_idx = out_idx.at[index].set(receiver_idx)
        sender_idx = out_idx.at[index].set(sender_idx)
        max_occupancy = cumsum[-1]

        return jnp.stack((receiver_idx, sender_idx)), max_occupancy

    def neighbor_list_fn(position: Array,
                         neighbors = None,
                         extra_capacity: int = 0,
                         **kwargs) -> partition.NeighborList:
        N = position.shape[0]
        # Without an explicit mask all particles are valid.
        mask = kwargs.get("mask", jnp.ones(N, dtype=jnp.bool_))
        # Masked particles are moved to infinity; the update check below
        # compares positions against ``jnp.inf`` to detect mask changes.
        position = jnp.where(mask[:, jnp.newaxis], position, jnp.inf)

        def neighbor_fn(position_and_error, max_occupancy=None):
            position, err = position_and_error

            idx = candidate_fn(position.shape)
            idx = mask_dense(idx, mask=mask)

            if partition.is_sparse(format):
                idx, occupancy = prune_neighbor_list_sparse(position, idx, **kwargs)
            else:
                idx, occupancy = prune_neighbor_list_dense(position, idx, **kwargs)

            # A capacity is only derived when none is passed in (allocation).
            if max_occupancy is None:
                _extra_capacity = (extra_capacity if not partition.is_sparse(format)
                                   else N * extra_capacity)
                max_occupancy = int(occupancy * capacity_multiplier + _extra_capacity)
                if max_occupancy > idx.shape[-1]:
                    max_occupancy = idx.shape[-1]
                if not partition.is_sparse(format):
                    capacity_limit = N - 1
                elif format is partition.NeighborListFormat.Sparse:
                    capacity_limit = N * (N - 1)
                else:
                    capacity_limit = N * (N - 1) // 2
                if max_occupancy > capacity_limit:
                    max_occupancy = capacity_limit
            idx = idx[:, :max_occupancy]
            update_fn = (neighbor_list_fn if neighbors is None else
                         neighbors.update_fn)
            # The overflow error code is set when more edges were found than
            # fit into ``max_occupancy``.
            return partition.NeighborList(
                idx,
                position,
                err.update(partition.PEC.NEIGHBOR_LIST_OVERFLOW,
                           occupancy > max_occupancy),
                None,
                max_occupancy,
                format,
                None,
                None,
                update_fn)  # pytype: disable=wrong-arg-count

        nbrs = neighbors
        if nbrs is None:
            # Allocation: start from a zero error code.
            return neighbor_fn((position,
                                partition.PartitionError(jnp.zeros((), jnp.uint8))))

        # Updates reuse the capacity of the existing neighbor list.
        neighbor_fn = functools.partial(neighbor_fn,
                                        max_occupancy=nbrs.max_occupancy)

        d = functools.partial(metric_sq, **kwargs)
        d = jax.vmap(d)

        if always_recompute:
            print(f"Always recompute the neighbor list!")
            return neighbor_fn((position, nbrs.error))
        else:
            # Rebuild if any particle moved further than the threshold or if
            # the set of particles placed at infinity (masked) changed;
            # otherwise return the existing neighbor list unchanged.
            return jax.lax.cond(
                jnp.logical_or(
                    jnp.any(d(position, nbrs.reference_position) > threshold_sq),
                    jnp.any(jnp.logical_xor(position == jnp.inf,
                                            nbrs.reference_position == jnp.inf))
                ),
                (position, nbrs.error), neighbor_fn,
                nbrs, lambda x: x)

    def allocate_fn(position: Array, extra_capacity: int = 0, **kwargs
                    ):
        return neighbor_list_fn(position, extra_capacity=extra_capacity, **kwargs)

    def update_fn(position: Array, neighbors, **kwargs
                  ):
        return neighbor_list_fn(position, neighbors, **kwargs)

    return partition.NeighborListFns(allocate_fn, update_fn)  # pytype: disable=wrong-arg-count
def exclude_from_neighbor_list(neighbor: partition.NeighborList,
                               exclude_idx,
                               exclude_mask) -> partition.NeighborList:
    """Function excluding edges from the neighbor list.

    An edge ``(i, j)`` of the neighbor list is removed if ``(i, j)`` or
    ``(j, i)`` appears as a row of ``exclude_idx`` whose entry in
    ``exclude_mask`` is True. Removed entries are replaced by the padding
    index, i.e. the number of particles.

    Args:
        neighbor: Neighbor list
        exclude_idx: Indices of the edges that should be excluded, if
            contained. Indexed as ``exclude_idx[:, 0]`` and
            ``exclude_idx[:, 1]``, i.e. one particle pair per row.
        exclude_mask: Boolean array whether specified edges should be excluded.

    Example:

        >>> from pathlib import Path
        >>> root = Path.cwd().parent

        >>> import mdtraj
        >>> from jax_md import space
        >>> from jax import numpy as jnp
        >>> from jax_md_mod.custom_partition import masked_neighbor_list

        >>> pdb = mdtraj.load(root / "examples/data/ethane.pdb")
        >>> r_init = jnp.asarray(pdb.xyz[0], dtype=jnp.float32)
        >>> box = jnp.array(1.0)
        >>> displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=True)
        >>> neighbor_fns = masked_neighbor_list(
        ...     displacement_fn, r_cutoff=1.0, dr_threshold=0.05,
        ... )

        We can now exclude, e.g., the first C atom from the neighbor list

        >>> mask = jnp.array([0, 1, 1, 1, 1, 1, 1, 1])
        >>> nbrs_init = neighbor_fns.allocate(r_init, mask=mask)
        >>>
        >>> print(nbrs_init.idx)
        [[8 8 8 8 8 8 8]
         [2 3 4 5 6 7 8]
         [1 3 4 5 6 7 8]
         [1 2 4 5 6 7 8]
         [1 2 3 5 6 7 8]
         [1 2 3 4 6 7 8]
         [1 2 3 4 5 7 8]
         [1 2 3 4 5 6 8]]

        Whenever the neighbor list must be recomputed (dR threshold), a new
        mask is applied

        >>> mask = jnp.array([1, 0, 1, 1, 1, 1, 1, 1])
        >>> print(neighbor_fns.update(r_init, nbrs_init, mask=mask).idx)
        [[2 3 4 5 6 7 8]
         [8 8 8 8 8 8 8]
         [0 3 4 5 6 7 8]
         [0 2 4 5 6 7 8]
         [0 2 3 5 6 7 8]
         [0 2 3 4 6 7 8]
         [0 2 3 4 5 7 8]
         [0 2 3 4 5 6 8]]

    Returns:
        Returns a neighbor list of the same format with excluded edges.

    """

    @functools.partial(jax.vmap, in_axes=(0, 0, None, None, None, None))
    def _exclude(particle_nbrs, idx, invalid, mask, ref_i, ref_j):
        # Check for all neighbors whether they are already part of the bond
        # list.
        exclude = jax.vmap(
            # Map over all pairs i, j that are part of the neighbor list
            lambda i: jnp.any(
                # Check whether these pairs appear in the bond list
                jnp.logical_and(
                    mask,
                    jnp.logical_or(
                        jnp.logical_and(i == ref_i, idx == ref_j),
                        jnp.logical_and(i == ref_j, idx == ref_i)
                    )
                )
            )
        )(particle_nbrs)

        return jnp.where(exclude, invalid, particle_nbrs)

    @functools.partial(jax.vmap, in_axes=(1, None, None, None, None), out_axes=1)
    def _exclude_sparse(nbr_idx, invalid, ref_i, ref_j, mask):
        # Map through the entries in the neighbor list and exclude them if
        # they are part of the bond list
        exclude = jnp.logical_or(
            jnp.logical_and(nbr_idx[0] == ref_i, nbr_idx[1] == ref_j),
            jnp.logical_and(nbr_idx[0] == ref_j, nbr_idx[1] == ref_i)
        )

        # Mask out invalid bonds or angles
        exclude = jnp.any(jnp.logical_and(exclude, mask))

        return jnp.where(exclude, invalid, nbr_idx)

    # Call the respective function depending on the format of the neighbor list
    if neighbor.format == partition.NeighborListFormat.Dense:
        # Dense indices have one row per particle.
        n_particles = neighbor.idx.shape[0]
        new_idx = _exclude(
            neighbor.idx, jnp.arange(n_particles), n_particles,
            exclude_mask, exclude_idx[:, 0], exclude_idx[:, 1]
        )
    else:
        # Sparse indices are indexed as (senders, receivers) along the first
        # axis, so ``idx.shape[0]`` is not the particle count. Take the
        # padding index from the reference positions instead.
        invalid_idx = neighbor.reference_position.shape[0]
        new_idx = _exclude_sparse(
            neighbor.idx, invalid_idx,
            exclude_idx[:, 0], exclude_idx[:, 1], exclude_mask
        )

    return neighbor.set(idx=new_idx)
def get_triplet_indices(neighbor: partition.NeighborList):
    """Returns indices for all triplets of the neighbor list.

    For every central particle ``j`` all ordered pairs ``(i, k)`` of distinct
    slots of its neighbor row are combined into a triplet.

    Args:
        neighbor: Neighbor list in dense format.

    Returns:
        A tuple ``(ij, jk, mask)``. ``ij`` holds the pairs ``(i, j)`` and
        ``jk`` holds the pairs ``(k, j)``, one triplet per row. ``mask`` is
        True for triplets not containing the padding index. All three arrays
        are sorted such that valid triplets come first.

    Raises:
        NotImplementedError: If the neighbor list is not in dense format.
    """
    # Compare against the enum class, as everywhere else in this module,
    # rather than accessing the member through another member.
    if neighbor.format != partition.NeighborListFormat.Dense:
        raise NotImplementedError(
            f"Neighbor list format {neighbor.format} not yet supported."
        )

    @functools.partial(jax.vmap, in_axes=(None, 0), out_axes=-1)
    def _get_triplets(idx, j):
        max_nbrs = idx.shape[1]

        # Return all bonds idx to j
        to_j = idx[j, :]

        # All permutations (remove diagonal entries). In the flattened
        # (max_nbrs, max_nbrs) grid the diagonal sits at multiples of
        # max_nbrs + 1.
        diagonal_mask = onp.mod(onp.arange(max_nbrs ** 2), max_nbrs + 1) != 0
        ik_to_j = jnp.stack(jnp.meshgrid(to_j, to_j, indexing='ij'), axis=0)
        ik_to_j = ik_to_j.reshape((2, -1))[:, diagonal_mask]

        # Add the reference to the center atom
        ij = ik_to_j.at[1, :].set(j)
        jk = ik_to_j.at[0, :].set(j)

        return ij, jk

    invalid_idx = neighbor.idx.shape[0]

    ij, jk = _get_triplets(neighbor.idx, jnp.arange(neighbor.idx.shape[0]))
    ij = ij.reshape((2, -1)).swapaxes(0, 1)
    jk = jk.reshape((2, -1)).swapaxes(0, 1)

    # Mask out all invalid triplets. ``logical_and`` is already element-wise,
    # no mapping is needed.
    mask = jnp.logical_and(
        jnp.all(ij != invalid_idx, axis=-1),
        jnp.all(jk != invalid_idx, axis=-1),
    )

    # Sort (simpler to later prune the triplet array)
    order = jnp.argsort(-1.0 * mask)

    ij = ij[order, :]
    # Flip the pair from (j, k) to (k, j).
    jk = jnp.flip(jk, axis=-1)[order, :]
    mask = mask[order]

    return ij, jk, mask
def check_connectivity(neighbor: partition.NeighborList, mask=None):
    """Check the connectivity of the neighbor list.

    Starting from the first valid particle, reachability is propagated along
    the edges of the neighbor list until all valid particles are reached or
    as many iterations as there are valid particles have been performed.

    Args:
        neighbor: Neighbor list in dense or sparse format.
        mask: Mask indicating whether particles are real or padded. If None,
            all particles are considered.

    Returns:
        Returns True if all valid particles can be reached from the first
        valid particle.

    Raises:
        NotImplementedError: If the format of the neighbor list is neither
            dense nor sparse.
    """
    if mask is None:
        mask = jnp.ones(neighbor.reference_position.shape[0], dtype=bool)

    def _reached_from_neighbors(reachable):
        # For each particle: is at least one of its neighbors reachable?
        if neighbor.format == partition.NeighborListFormat.Dense:
            # Append an unreachable entry that is addressed by the padding
            # index of the dense neighbor list.
            padded = jnp.concatenate(
                [reachable, jnp.zeros((1,), dtype=reachable.dtype)])
            return jnp.any(padded[neighbor.idx], axis=1)
        elif neighbor.format == partition.NeighborListFormat.Sparse:
            senders, receivers = neighbor.idx
            # Propagate reachable state from senders to receivers
            incoming = jax.ops.segment_sum(
                jnp.int_(reachable[senders]), receivers, reachable.size)
            return incoming > 0
        else:
            raise NotImplementedError(
                f"Neighbor list format {neighbor.format} not yet supported."
            )

    def _update_connectivity(state):
        reachable, idx = state
        newly_reached = jnp.logical_and(
            _reached_from_neighbors(reachable), mask)
        # Accumulate: a particle that was reached once stays reached.
        # Overwriting the state instead makes the reachable set oscillate on
        # bipartite graphs (e.g. a line), which then never complete.
        reachable = jnp.logical_or(reachable, newly_reached)
        return reachable, idx + 1

    def _search(state):
        reachable, idx = state

        # We stop when one of the following conditions is met:
        # 1. Iterations equal to the number of actual particles. Worst case
        #    scenario when graph is line
        # 2. All valid particles are reachable
        return jnp.logical_and(idx < jnp.sum(mask),
                               jnp.sum(reachable) < jnp.sum(mask))

    # Find one non-masked particle and start the search from there
    first_nonzero = jnp.argmax(mask)
    reachable = jnp.logical_and(mask, jnp.arange(mask.size) == first_nonzero)

    reachable, _ = jax.lax.while_loop(
        _search, _update_connectivity, (reachable, 0)
    )

    return jnp.sum(reachable) >= jnp.sum(mask)


def find_clusters(neighbor: partition.NeighborList, mask=None):
    """Discovers separate subgraphs in the neighbor list.

    Every valid particle starts with its own cluster id. In each iteration a
    particle adopts the smallest id among itself and its neighbors; the
    number of iterations equals the number of particles.

    Args:
        neighbor: Neighbor list in dense or sparse format.
        mask: Mask indicating whether particles are real or padded. If None,
            all particles are considered.

    Returns:
        Returns a vector with unique cluster-ids to which a particle belongs
        to and the number of discovered separate subgraphs. Padded particles
        keep their initial id, which is larger than the id of any valid
        particle, and are not counted.

    Raises:
        NotImplementedError: If the format of the neighbor list is neither
            dense nor sparse.
    """
    if mask is None:
        mask = jnp.ones(neighbor.reference_position.shape[0], dtype=bool)

    def _neighbor_minimum(clusters):
        # Smallest cluster id among the neighbors of each particle.
        if neighbor.format == partition.NeighborListFormat.Dense:
            # The padding index addresses an appended entry that never wins
            # the minimum.
            largest = jnp.iinfo(clusters.dtype).max
            padded = jnp.concatenate(
                [clusters, jnp.full((1,), largest, dtype=clusters.dtype)])
            return jnp.min(padded[neighbor.idx], axis=1)
        elif neighbor.format == partition.NeighborListFormat.Sparse:
            senders, receivers = neighbor.idx
            # Propagate cluster state from senders to receivers
            return jax.ops.segment_min(
                jnp.int_(clusters[senders]), receivers, clusters.size)
        else:
            raise NotImplementedError(
                f"Neighbor list format {neighbor.format} not yet supported."
            )

    def _update_connectivity(clusters, _):
        # Particles propagate their cluster information. The own id takes
        # part in the minimum so that ids can only decrease; padded particles
        # keep their id.
        clusters = jnp.where(
            mask, jnp.minimum(clusters, _neighbor_minimum(clusters)), clusters)

        # Count the distinct ids of valid particles. Padded particles are
        # mapped to -1, which sorts in front of all valid (non-negative) ids.
        valid_ids = jnp.sort(jnp.where(mask, clusters, -1))
        n_clusters = (jnp.sum(jnp.diff(valid_ids) > 0)
                      + (valid_ids[0] >= 0))
        return clusters, n_clusters

    # Each valid particle gets its own cluster in the beginning
    clusters = jnp.where(mask, jnp.arange(mask.size), mask.size)
    clusters -= jnp.min(clusters)  # Start the cluster counter with 0

    # Keep the final carry: it holds the converged cluster assignment.
    clusters, nclusters = jax.lax.scan(
        _update_connectivity, clusters, jnp.arange(clusters.size))

    return clusters, nclusters[-1]
def to_networkx(neighbor: partition.NeighborList):
    """Converts the neighbor list into an undirected networkx graph.

    Entries equal to the padding index (the number of particles) are skipped.
    Nodes are only created for particles that take part in at least one edge.

    Args:
        neighbor: Neighbor list in dense or sparse format.

    Returns:
        Returns a ``networkx.Graph`` with one edge per valid entry of the
        neighbor list.
    """
    # networkx is imported lazily, only when this function is called.
    nx = importlib.import_module('networkx')

    graph = nx.Graph()

    # Transfer the indices to the host once instead of indexing the device
    # array element by element inside Python loops.
    idx = onp.asarray(neighbor.idx)

    if neighbor.format == partition.NeighborListFormat.Dense:
        num_particles, _ = idx.shape
        # Row-major order matches iterating rows first, then columns.
        rows, cols = onp.nonzero(idx != num_particles)
        edges = zip(rows.tolist(), idx[rows, cols].tolist())
    else:
        invalid_idx = neighbor.reference_position.shape[0]
        pairs = idx.T
        valid = onp.all(pairs != invalid_idx, axis=1)
        edges = (tuple(pair) for pair in pairs[valid].tolist())

    for i, j in edges:
        graph.add_edge(i, j)

    return graph
def test_graph_statistics(displacement_fn: space.DisplacementFn,
                          position: Array,
                          neighbor: partition.NeighborList,
                          r_cutoff: Union[float, Array],
                          max_edge_multiplier: float = 1.5,
                          ):
    """Computes neighbor list statistics for test conformation.

    Args:
        displacement_fn: Displacement function
        position: Particle position for a representative configuration
        neighbor: Neighbor list for a representative configuration
        r_cutoff: Cutoff radius of edges.
        max_edge_multiplier: Multiplier to increase maximum number of edges
            based on edges found in representative configuration. Only used
            for dense neighbor lists.

    Returns:
        Returns a tuple of average number of neighbors and maximum number
        of edges.

    """
    # Checking only necessary if neighbor list is dense
    if neighbor.format == partition.Dense:
        print('Capping edges and triplets. Beware of overflow, which is '
              'currently not being detected.')

        testgraph, _ = sparse_graph.sparse_graph_from_neighborlist(
            displacement_fn, position, neighbor, r_cutoff)
        max_edges = jnp.int32(jnp.ceil(testgraph.n_edges * max_edge_multiplier))

        # cap maximum edges and angles to avoid overflow from multiplier
        n_particles, n_neighbors = neighbor.idx.shape
        max_edges = min(max_edges, n_particles * n_neighbors)

        print(f"Estimated max. {max_edges} edges.")

        avg_num_neighbors = testgraph.n_edges / n_particles
    else:
        n_particles = neighbor.reference_position.shape[0]
        # Sparse indices are unpacked as (senders, receivers) along the first
        # axis; the number of edge slots is the length of the second axis.
        max_edges = neighbor.idx.shape[1]
        # Count the non-padding entries of the first index row.
        avg_num_neighbors = onp.sum(neighbor.idx[0] < n_particles)
        avg_num_neighbors /= n_particles

    return avg_num_neighbors, max_edges


def readout_vectors(displacement_fn: space.DisplacementFn,
                    r_cutoff: Union[float, Array],
                    position: Array,
                    neighbor: partition.NeighborList,
                    species: Array = None,
                    mask: Array = None,
                    max_edges = None,
                    edges_per_particle: float = None,
                    **kwargs
                    ):
    """Extracts the displacement vectors of the edges in the neighbor list.

    Args:
        displacement_fn: Displacement function
        r_cutoff: Cutoff radius of edges.
        position: Particle position for a representative configuration
        neighbor: Neighbor list for a representative configuration
        species: Species of atoms.
        mask: Mask indicating whether particles are real or padded.
        max_edges: Maximum number of edges to consider.
        edges_per_particle: Limit the number of edges to a value proportional
            to the number of particles.
        kwargs: Keyword arguments passed to the displacement function.

    Returns:
        Returns the extracted vectors, senders, and receivers. All components
        of the vectors of invalid edges are set to ``r_cutoff``. If a maximum
        number of edges is given, the edges are sorted by length and only the
        shortest ones are returned.

    Raises:
        ValueError: If both ``edges_per_particle`` and ``max_edges`` are given.
        TypeError: If the number of particles is a symbolic dimension but
            ``max_edges`` is not.
    """
    dyn_displacement = functools.partial(displacement_fn, **kwargs)

    if edges_per_particle is not None:
        if max_edges is not None:
            raise ValueError(
                "Either edges_per_particle or max_edges can be specified, not both."
            )

        # Restrict to three digits after the decimal point. This step is
        # required because JAX symbolic dimensions support only integer
        # operations.
        factor = int(edges_per_particle * 1000)
        # Cast to a Python int so the arithmetic with a (possibly symbolic)
        # particle count below does not involve NumPy scalars.
        gcd = int(onp.gcd(factor, 1000))
        max_edges = (factor // gcd) * position.shape[0] // (1000 // gcd)

        print(f"Limit the maximum number of edges to {max_edges} "
              f"({factor // gcd} / {1000 // gcd} edges per particle).")

    if species is None:
        species = jnp.ones(position.shape[0], dtype=jnp.int32)
    if mask is None:
        mask = jnp.ones(position.shape[0], dtype=jnp.bool_)

    if max_edges is not None and export.is_symbolic_dim(position.shape[0]):
        if not export.is_symbolic_dim(max_edges):
            raise TypeError(
                "max_edges must be symbolic if used in export."
            )

    if neighbor.format == partition.Dense:
        graph, _ = sparse_graph.sparse_graph_from_neighborlist(
            dyn_displacement, position, neighbor, r_cutoff, species,
            max_edges=max_edges, species_mask=mask
        )
        senders = graph.idx_i
        receivers = graph.idx_j
    else:
        assert neighbor.idx.shape == (
            2, neighbor.idx.shape[1]), "Neighbor list has wrong shape."
        senders, receivers = neighbor.idx

    # Set invalid edges to the cutoff to avoid numerical issues
    vectors = jax.vmap(dyn_displacement)(position[senders], position[receivers])
    # An edge is valid only if both endpoints are real, unmasked particles.
    valid_sender = jnp.logical_and(
        senders < position.shape[0], mask[senders])
    valid_receiver = jnp.logical_and(
        receivers < position.shape[0], mask[receivers])
    vectors = jnp.where(
        jnp.logical_and(valid_sender, valid_receiver)[:, jnp.newaxis],
        vectors, r_cutoff)

    if max_edges is not None:
        # Sort vectors by length and remove up to max_edges edges
        lengths = jnp.linalg.norm(vectors, axis=-1)
        sort_idx = jnp.argsort(lengths)

        vectors = vectors[sort_idx][:max_edges]
        senders = senders[sort_idx][:max_edges]
        receivers = receivers[sort_idx][:max_edges]

    return vectors, senders, receivers