custom_partition

Custom functions to analyze the neighborlist graph.

masked_neighbor_list(displacement_or_metric, r_cutoff, dr_threshold=None, capacity_multiplier=1.25, format=NeighborListFormat.Dense)

Returns a function that builds a neighbor list for collections of points.

Adapts the JAX, M.D. neighbor list jax_md.partition.neighbor_list() to allow for masking and to enforce rebuilding of the neighbor list.

Parameters:
  • displacement_or_metric – A function d(R_a, R_b) that computes the displacement between pairs of points.

  • r_cutoff (float) – A scalar specifying the neighborhood radius.

  • dr_threshold (Optional[float, None]) – A scalar specifying the maximum distance particles can move before the neighbor list is rebuilt. If set to None, the neighbor list is always rebuilt.

  • capacity_multiplier (float) – A floating point scalar specifying the fractional increase in maximum neighborhood occupancy we allocate compared with the maximum in the example positions.

  • format – The format of the neighbor list; see the NeighborListFormat() enum for details about the different choices for formats. Defaults to Dense.

Return type:

Callable[[Array, Optional[NeighborList, None], Optional[int, None]], NeighborList]

Returns:

A NeighborListFns object that contains a method to allocate a new neighbor list and a method to update an existing neighbor list.
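The roles of dr_threshold and capacity_multiplier can be illustrated with a minimal pure-Python sketch. The helper names needs_rebuild and padded_capacity are hypothetical and not part of the library. The sketch assumes that a rebuild is triggered once any particle has moved farther than dr_threshold since the last build, ignores periodic boundaries, and rounds the capacity up; the library's exact criterion and rounding may differ.

```python
import math


def needs_rebuild(positions, reference_positions, dr_threshold):
    """Decide whether the neighbor list should be rebuilt.

    Assumption: a rebuild is needed once any particle has moved farther
    than dr_threshold from its position at the last build. Periodic
    boundaries are ignored in this sketch.
    """
    if dr_threshold is None:
        # No threshold given: always rebuild, as described for dr_threshold.
        return True
    return any(
        math.dist(r, r_ref) > dr_threshold
        for r, r_ref in zip(positions, reference_positions)
    )


def padded_capacity(max_occupancy, capacity_multiplier=1.25):
    """Neighbor slots to allocate per particle (rounding up is an assumption)."""
    return math.ceil(max_occupancy * capacity_multiplier)
```

For instance, padded_capacity(8) reserves 10 slots, so a moderate rise in local density does not immediately overflow the list.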

mask_neighbor_list(nbrs, mask=None)

Masks the neighbor list indices.

Parameters:
  • nbrs (NeighborList) – Dense or sparse neighbor list.

  • mask (Array) – Boolean array marking valid particles with True. Edges from and to invalid particles (False) are removed from the neighbor list.

Return type:

NeighborList

Returns:

A neighbor list without edges from or to invalid particles.
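As an illustration of the masking semantics for a dense neighbor list, the following pure-Python sketch replaces every edge from or to an invalid particle with a padding index. It is not the library implementation: the helper name mask_dense_neighbors is hypothetical, the neighbor list is modeled as a plain list of rows, and the padding index is assumed to equal the number of particles. The sketch leaves masked entries in place and does not compact the rows.

```python
def mask_dense_neighbors(idx, mask):
    """Replace edges from or to invalid particles with the padding index.

    idx:  list of rows; row i holds the neighbor indices of particle i,
          padded with len(idx) (assumed padding convention).
    mask: list of booleans; True marks a valid particle.
    """
    n = len(idx)  # padding index, assumed to equal the particle count
    masked = []
    for i, row in enumerate(idx):
        if not mask[i]:
            # Particle i is invalid: drop all edges starting at i.
            masked.append([n] * len(row))
            continue
        # Particle i is valid: drop only edges pointing to invalid particles.
        masked.append([j if j < n and mask[j] else n for j in row])
    return masked
```

With three mutually neighboring particles and the first one masked out, mask_dense_neighbors([[1, 2], [0, 2], [0, 1]], [False, True, True]) yields [[3, 3], [3, 2], [3, 1]].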

exclude_from_neighbor_list(neighbor, exclude_idx, exclude_mask)

Excludes edges from the neighbor list.

Parameters:
  • neighbor (NeighborList) – Neighbor list

  • exclude_idx – Indices of the edges that should be excluded, if contained.

  • exclude_mask – Boolean array indicating whether each specified edge should be excluded.

Example

>>> from pathlib import Path
>>> root = Path.cwd().parent
>>> import mdtraj
>>> from jax_md import space
>>> from jax import numpy as jnp
>>> from jax_md_mod.custom_partition import masked_neighbor_list
>>> pdb = mdtraj.load(root / "examples/data/ethane.pdb")
>>> r_init = jnp.asarray(pdb.xyz[0], dtype=jnp.float32)
>>> box = jnp.array(1.0)
>>> displacement_fn, shift_fn = space.periodic_general(box, fractional_coordinates=True)
>>> neighbor_fns = masked_neighbor_list(
...     displacement_fn, r_cutoff=1.0, dr_threshold=0.05,
... )

We can now exclude, e.g., the first C atom from the neighbor list:

>>> mask = jnp.array([0, 1, 1, 1, 1, 1, 1, 1])
>>> nbrs_init = neighbor_fns.allocate(r_init, mask=mask)
>>>
>>> print(nbrs_init.idx)
[[8 8 8 8 8 8 8]
 [2 3 4 5 6 7 8]
 [1 3 4 5 6 7 8]
 [1 2 4 5 6 7 8]
 [1 2 3 5 6 7 8]
 [1 2 3 4 6 7 8]
 [1 2 3 4 5 7 8]
 [1 2 3 4 5 6 8]]

Whenever the neighbor list must be recomputed (as determined by the dr_threshold criterion), the new mask is applied:

>>> mask = jnp.array([1, 0, 1, 1, 1, 1, 1, 1])
>>> print(neighbor_fns.update(r_init, nbrs_init, mask=mask).idx)
[[2 3 4 5 6 7 8]
 [8 8 8 8 8 8 8]
 [0 3 4 5 6 7 8]
 [0 2 4 5 6 7 8]
 [0 2 3 5 6 7 8]
 [0 2 3 4 6 7 8]
 [0 2 3 4 5 7 8]
 [0 2 3 4 5 6 8]]

Return type:

NeighborList

Returns:

A neighbor list of the same format with the specified edges excluded.
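The edge exclusion of exclude_from_neighbor_list can be pictured with a small pure-Python sketch on a dense list. The layout of exclude_idx is not specified above, so the sketch assumes a list of directed (i, j) pairs with one boolean flag per pair; the helper name exclude_edges and the padding convention (index equal to the particle count) are likewise assumptions, not the library's interface.

```python
def exclude_edges(idx, exclude_pairs, exclude_mask):
    """Remove selected directed edges from a dense neighbor list.

    idx:           list of rows; row i holds the neighbors of particle i,
                   padded with len(idx) (assumed padding convention).
    exclude_pairs: hypothetical layout, a list of (i, j) edges.
    exclude_mask:  one boolean per pair; only flagged pairs are excluded.
    """
    n = len(idx)
    # Keep only the pairs whose flag is set.
    active = {pair for pair, flag in zip(exclude_pairs, exclude_mask) if flag}
    # Pairs that do not occur in the list are ignored ("if contained").
    return [
        [n if (i, j) in active else j for j in row]
        for i, row in enumerate(idx)
    ]
```

In this sketch edges are directed, so removing a bond in both directions requires listing both (i, j) and (j, i).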

get_triplet_indices(neighbor)

Returns indices for all triplets of the neighbor list.
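To make the notion of a triplet concrete, here is a minimal pure-Python sketch that enumerates triplets (i, j, k) in which j and k are two distinct neighbors of the central particle i. The helper name triplet_indices is hypothetical, and the sketch assumes a dense list padded with the particle count and ordered neighbor pairs; the format and ordering returned by get_triplet_indices may differ.

```python
from itertools import permutations


def triplet_indices(idx):
    """Enumerate (i, j, k) with j and k distinct neighbors of particle i.

    idx: list of rows; row i holds the neighbors of particle i, padded
         with len(idx) (assumed padding convention).
    """
    n = len(idx)
    triplets = []
    for i, row in enumerate(idx):
        neighbors = [j for j in row if j < n]  # drop padding entries
        # Ordered pairs: both (i, j, k) and (i, k, j) are emitted.
        for j, k in permutations(neighbors, 2):
            triplets.append((i, j, k))
    return triplets
```

A particle with m valid neighbors contributes m * (m - 1) ordered triplets, so the count grows quadratically with the neighborhood size.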

to_networkx(neighbor)