compose.utils
- batch_apply_fn(apply_fn)
Wraps the apply function with a custom vmap rule.
Instead of batching over individual graphs, the rule combines all graphs in the batch into a single supergraph. The neural network then runs once on this supergraph, and its operations are not vmapped.
- Parameters:
apply_fn – Mapping from (params, vectors, senders, receivers, species, mask) to per-particle energies.
- Return type:
- Returns:
The wrapped apply function, whose custom vmap rule avoids vmapping the neural network.
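The supergraph idea can be illustrated with a minimal sketch built on JAX's `jax.custom_batching.custom_vmap`. This is not the library's implementation. `toy_apply_fn` is a hypothetical stand-in for the neural network, and the array layouts are assumptions: `vectors` is `(n_edges, 3)`, `senders`/`receivers`/`species`/`mask` are integer or boolean arrays with one entry per edge or node, and `params` is shared by every graph in the batch. Under `vmap`, the rule shifts each graph's node indices by `g * n_nodes` so the graphs stay disjoint, flattens the batch, calls `apply_fn` once, and reshapes the per-particle energies back to `(batch, n_nodes)`.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap


def toy_apply_fn(params, vectors, senders, receivers, species, mask):
    """Hypothetical per-particle energy model standing in for a GNN."""
    n_nodes = species.shape[0]
    edge_energy = params["w"] * jnp.sum(vectors ** 2, axis=-1)
    # Split each edge energy evenly between its two endpoints.
    per_node = (jax.ops.segment_sum(0.5 * edge_energy, senders, num_segments=n_nodes)
                + jax.ops.segment_sum(0.5 * edge_energy, receivers, num_segments=n_nodes))
    return jnp.where(mask, per_node + params["b"][species], 0.0)


def batch_apply_fn(apply_fn):
    """Sketch: under vmap, evaluate the whole batch as one disjoint supergraph."""

    @custom_vmap
    def wrapped(params, vectors, senders, receivers, species, mask):
        # Unbatched call: forward directly to the single-graph function.
        return apply_fn(params, vectors, senders, receivers, species, mask)

    @wrapped.def_vmap
    def rule(axis_size, in_batched, params, vectors, senders, receivers,
             species, mask):
        # in_batched mirrors the argument structure with booleans.
        if any(jax.tree_util.tree_leaves(in_batched[0])):
            raise NotImplementedError("sketch assumes params are shared by all graphs")

        # Give any unbatched graph array a leading batch axis.
        def with_batch_axis(x, batched):
            return x if batched else jnp.broadcast_to(x, (axis_size,) + x.shape)

        vectors, senders, receivers, species, mask = (
            with_batch_axis(x, b)
            for x, b in zip((vectors, senders, receivers, species, mask),
                            in_batched[1:]))

        # Shift graph g's node indices by g * n_nodes so the graphs stay disjoint.
        # Assumes every edge index lies in [0, n_nodes), i.e. no sentinel padding.
        n_nodes = species.shape[1]
        offsets = (jnp.arange(axis_size) * n_nodes)[:, None]
        energies = apply_fn(
            params,
            vectors.reshape(-1, vectors.shape[-1]),
            (senders + offsets).reshape(-1),
            (receivers + offsets).reshape(-1),
            species.reshape(-1),
            mask.reshape(-1),
        )
        # One network call for the whole batch; split back into per-graph rows.
        return energies.reshape(axis_size, n_nodes), True

    return wrapped


# Usage: a batch of 3 graphs with 4 nodes and 5 edges each.
params = {"w": jnp.array(2.0), "b": jnp.array([0.5, -1.0])}
vectors = jnp.arange(3 * 5 * 3, dtype=jnp.float32).reshape(3, 5, 3) / 10.0
senders = jnp.array([[0, 1, 2, 3, 0]] * 3, dtype=jnp.int32)
receivers = jnp.array([[1, 2, 3, 0, 2], [0, 0, 1, 1, 3], [3, 2, 1, 0, 0]], dtype=jnp.int32)
species = jnp.array([[0, 1, 0, 1], [1, 1, 0, 0], [0, 0, 0, 1]])
mask = jnp.array([[True, True, True, False], [True] * 4, [True, False, True, True]])

in_axes = (None, 0, 0, 0, 0, 0)
batched_fn = batch_apply_fn(toy_apply_fn)
energies = jax.vmap(batched_fn, in_axes=in_axes)(params, vectors, senders, receivers, species, mask)
reference = jax.vmap(toy_apply_fn, in_axes=in_axes)(params, vectors, senders, receivers, species, mask)
```

The index offset is what keeps the graphs disjoint. Without it, edges from different graphs would scatter into the same node slots. The sketch does not handle padded edges that use an out-of-range sentinel index, because after shifting, such an index would land on a real node of the next graph.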