custom_partition
Custom functions to analyze the neighbor list graph.
- masked_neighbor_list(displacement_or_metric, r_cutoff, dr_threshold=None, capacity_multiplier=1.25, format=NeighborListFormat.Dense)
Returns a function that builds a list of neighbors for collections of points.
Adapts the JAX, M.D. neighbor list jax_md.partition.neighbor_list() to allow for masking and to enforce rebuilding of the neighbor list.
- Parameters:
displacement_or_metric – A function d(R_a, R_b) that computes the displacement between pairs of points.
r_cutoff (float) – A scalar specifying the neighborhood radius.
dr_threshold (Optional[float, None]) – A scalar specifying the maximum distance particles can move before the neighbor list is rebuilt. If set to None, the neighbor list is always rebuilt.
capacity_multiplier (float) – A floating-point scalar specifying the fractional increase in maximum neighborhood occupancy that is allocated, relative to the maximum occupancy in the example positions.
format – The format of the neighbor list; see the NeighborListFormat enum for details about the available formats. Defaults to Dense.
- Return type:
Callable[[Array, Optional[NeighborList, None], Optional[int, None]], NeighborList]
- Returns:
A NeighborListFns object that contains a method to allocate a new neighbor list and a method to update an existing neighbor list.
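To make the Dense format concrete, the following is a minimal brute-force sketch of what a masked dense neighbor list contains. It is not how jax_md builds neighbor lists (the library uses cell lists, the supplied displacement function, and the dr_threshold and capacity bookkeeping); the helper name dense_neighbor_idx, its free-space distances, and the explicit capacity argument are assumptions for illustration. It relies only on the Dense convention that row i holds the neighbors of particle i and that the particle count N marks an empty slot.

```python
import jax.numpy as jnp


def dense_neighbor_idx(positions, r_cutoff, mask, capacity):
    """Brute-force dense neighbor indices with particle masking.

    Hypothetical illustration only: free-space Euclidean distances,
    O(N^2) cost, no cell list and no dr_threshold bookkeeping.
    """
    n = positions.shape[0]
    mask = jnp.asarray(mask, dtype=bool)
    # Pairwise Euclidean distances, shape (N, N).
    dr = positions[:, None, :] - positions[None, :, :]
    dist = jnp.sqrt(jnp.sum(dr ** 2, axis=-1))
    # Edge i -> j is kept if j is within the cutoff, j != i,
    # and both particles are marked valid by the mask.
    valid = (dist < r_cutoff) & ~jnp.eye(n, dtype=bool)
    valid = valid & mask[:, None] & mask[None, :]
    # Invalid entries become the padding index N; sorting each row moves
    # the valid neighbors to the front and the padding to the end.
    candidates = jnp.where(valid, jnp.arange(n)[None, :], n)
    idx = jnp.sort(candidates, axis=1)[:, :capacity]
    if capacity > n:
        # Pad extra columns with N so the row width equals the capacity.
        idx = jnp.pad(idx, ((0, 0), (0, capacity - n)), constant_values=n)
    return idx
```

Sorting each row is just one simple way to push the padding index to the end; the library's own row ordering may differ.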
- mask_neighbor_list(nbrs, mask=None)
Masks the neighbor list indices.
- Parameters:
nbrs (NeighborList) – Dense or sparse neighbor list.
mask (Array) – Boolean array marking valid particles (True). Edges from and to invalid particles (False) are removed from the neighbor list.
- Return type:
NeighborList
- Returns:
A neighbor list without edges from or to invalid particles.
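One way to picture what masking does to a Dense neighbor list is the following sketch. The helper mask_dense_idx is hypothetical and assumes only the Dense layout (shape (N, capacity), with N as the padding index); unlike mask_neighbor_list, it does not handle sparse neighbor lists.

```python
import jax.numpy as jnp


def mask_dense_idx(idx, mask):
    """Remove edges from and to masked-out particles in a dense index array.

    Hypothetical helper: assumes the Dense layout (N, capacity) in which
    the value N marks an empty slot.
    """
    n = idx.shape[0]
    mask = jnp.asarray(mask, dtype=bool)
    # Append a False entry so that looking up the padding index N is valid.
    mask_ext = jnp.concatenate([mask, jnp.zeros((1,), dtype=bool)])
    # An entry survives only if both its row particle and its neighbor are valid.
    keep = mask[:, None] & mask_ext[idx]
    masked = jnp.where(keep, idx, n)
    # N is the largest possible value, so sorting moves padding to the end of each row.
    return jnp.sort(masked, axis=1)
```

Applied to a fully connected 8-particle list (each row holding all other indices in ascending order) with the mask [0, 1, 1, 1, 1, 1, 1, 1], this sketch yields the same rows as the first printed neighbor list in the example for exclude_from_neighbor_list.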
- exclude_from_neighbor_list(neighbor, exclude_idx, exclude_mask)
Excludes edges from the neighbor list.
- Parameters:
neighbor (NeighborList) – Neighbor list.
exclude_idx – Indices of the edges to exclude if they are present in the neighbor list.
exclude_mask – Boolean array indicating whether each specified edge should be excluded.
Example
>>> from pathlib import Path
>>> root = Path.cwd().parent
>>> import mdtraj
>>> from jax_md import space
>>> from jax import numpy as jnp
>>> from jax_md_mod.custom_partition import masked_neighbor_list
>>> pdb = mdtraj.load(root / "examples/data/ethane.pdb")
>>> r_init = jnp.asarray(pdb.xyz[0], dtype=jnp.float32)
>>> box = jnp.array(1.0)
>>> displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=True)
>>> neighbor_fns = masked_neighbor_list(
...     displacement_fn, r_cutoff=1.0, dr_threshold=0.05,
... )
We can now exclude, e.g., the first C atom from the neighbor list:
>>> mask = jnp.array([0, 1, 1, 1, 1, 1, 1, 1])
>>> nbrs_init = neighbor_fns.allocate(r_init, mask=mask)
>>> print(nbrs_init.idx)
[[8 8 8 8 8 8 8]
 [2 3 4 5 6 7 8]
 [1 3 4 5 6 7 8]
 [1 2 4 5 6 7 8]
 [1 2 3 5 6 7 8]
 [1 2 3 4 6 7 8]
 [1 2 3 4 5 7 8]
 [1 2 3 4 5 6 8]]
Whenever the neighbor list is recomputed (see dr_threshold), the new mask is applied:
>>> mask = jnp.array([1, 0, 1, 1, 1, 1, 1, 1])
>>> print(neighbor_fns.update(r_init, nbrs_init, mask=mask).idx)
[[2 3 4 5 6 7 8]
 [8 8 8 8 8 8 8]
 [0 3 4 5 6 7 8]
 [0 2 4 5 6 7 8]
 [0 2 3 5 6 7 8]
 [0 2 3 4 6 7 8]
 [0 2 3 4 5 7 8]
 [0 2 3 4 5 6 8]]
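In the printed Dense indices, the value 8 (the number of particles N in ethane) fills empty slots. Assuming only that convention, a short check counts how many neighbors each particle keeps after the second mask; the array is copied from the output above.

```python
import jax.numpy as jnp

# Dense indices copied from the second printed neighbor list (particle 1 masked out).
idx = jnp.array([
    [2, 3, 4, 5, 6, 7, 8],
    [8, 8, 8, 8, 8, 8, 8],
    [0, 3, 4, 5, 6, 7, 8],
    [0, 2, 4, 5, 6, 7, 8],
    [0, 2, 3, 5, 6, 7, 8],
    [0, 2, 3, 4, 6, 7, 8],
    [0, 2, 3, 4, 5, 7, 8],
    [0, 2, 3, 4, 5, 6, 8],
])
n_particles = idx.shape[0]  # N = 8; this value doubles as the padding index
# Entries smaller than N are real neighbors; entries equal to N are padding.
n_neighbors = jnp.sum(idx < n_particles, axis=1)
```

Particle 1 ends up with no neighbors, and every other particle keeps the six remaining ones.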
- Return type:
NeighborList
- Returns:
A neighbor list of the same format with the excluded edges removed.
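For a Dense neighbor list, the effect of exclude_from_neighbor_list can be pictured with the following sketch. The helper exclude_dense_edges is hypothetical, and the layout of exclude_idx is an assumption: here it is taken to be an (E, 2) array of directed (i, j) pairs, with exclude_mask of shape (E,). The actual function may use a different layout and also supports other neighbor list formats.

```python
import jax.numpy as jnp


def exclude_dense_edges(idx, exclude_idx, exclude_mask):
    """Drop selected directed edges i -> j from a dense index array.

    Hypothetical helper: exclude_idx is assumed to have shape (E, 2), one
    (i, j) pair per row, and exclude_mask shape (E,).
    """
    n = idx.shape[0]
    exclude_mask = jnp.asarray(exclude_mask, dtype=bool)
    senders = jnp.arange(n)[:, None]  # row particle i for every slot, shape (N, 1)
    # Compare every slot (i, idx[i, k]) against every excluded pair, shape (N, K, E).
    hit = (senders[..., None] == exclude_idx[:, 0]) & (idx[..., None] == exclude_idx[:, 1])
    # Only pairs whose exclude_mask entry is True count as exclusions.
    excluded = jnp.any(hit & exclude_mask, axis=-1)
    # Excluded slots become padding (N); sorting moves them to the end of each row.
    return jnp.sort(jnp.where(excluded, n, idx), axis=1)
```

Comparing every slot against every pair allocates an array of shape (N, capacity, E), which is fine for short exclusion lists but grows quickly with E.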