compose.utils

class ApplyFn(*args, **kwargs)

GNN apply function protocol.
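As a rough illustration, a callable protocol of this kind could be declared with `typing.Protocol`. The class name `ApplyFnSketch` and the function `constant_energy` are hypothetical, and the argument list is an assumption borrowed from the `batch_apply_fn` parameter description; the actual definition of `ApplyFn` may differ.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class ApplyFnSketch(Protocol):
    """Hypothetical stand-in for ApplyFn; the argument names are assumed."""

    def __call__(
        self,
        params: Any,
        vectors: Any,
        senders: Any,
        receivers: Any,
        species: Any,
        mask: Any,
    ) -> Any:
        """Return per-particle energies."""
        ...


def constant_energy(params, vectors, senders, receivers, species, mask):
    # Toy model: assign zero energy to every particle.
    return [0.0 for _ in species]


# Any callable with a compatible signature satisfies the protocol structurally.
apply_fn: ApplyFnSketch = constant_energy
energies = apply_fn(None, None, None, None, [1, 8, 1], None)
```

A static type checker compares the full signature, whereas a runtime `isinstance` check against a `runtime_checkable` protocol only confirms that the object is callable.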

batch_apply_fn(_apply_fn)

Wraps the apply function with custom vmap rules.

Rather than batching over graphs, the rule combines all graphs into a single supergraph, which avoids vmapping the operations inside the neural network.

Parameters:

_apply_fn – Mapping from (params, vectors, senders, receivers, species, mask) to per-particle energies.

Return type:

ApplyFn

Returns:

A wrapped apply function with custom vmap rules that avoids vmapping in the neural network.
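The idea can be sketched with `jax.custom_batching.custom_vmap`. This is a minimal illustration rather than the library's implementation: `per_node_energy`, `energy`, and `energy_vmap_rule` are hypothetical names, the toy model takes fewer arguments than the documented apply function, and the rule assumes shared parameters and graphs with the same number of nodes.

```python
import jax
import jax.numpy as jnp


def per_node_energy(params, node_features, senders, receivers):
    # Toy stand-in for a GNN: each edge sends a scaled sender feature,
    # and every node sums the messages it receives.
    messages = params * node_features[senders]
    return jnp.zeros_like(node_features).at[receivers].add(messages)


@jax.custom_batching.custom_vmap
def energy(params, node_features, senders, receivers):
    return per_node_energy(params, node_features, senders, receivers)


@energy.def_vmap
def energy_vmap_rule(axis_size, in_batched, params, node_features, senders, receivers):
    params_batched, *graph_batched = in_batched
    # Sketch limitation: parameters are shared, all graph arrays are batched.
    assert not params_batched and all(graph_batched)
    n_nodes = node_features.shape[1]
    # Shift node indices of graph g by g * n_nodes to build one supergraph.
    offsets = (jnp.arange(axis_size) * n_nodes)[:, None]
    flat_out = per_node_energy(
        params,
        node_features.reshape(-1),
        (senders + offsets).reshape(-1),
        (receivers + offsets).reshape(-1),
    )
    # Restore the batch axis and mark the output as batched.
    return flat_out.reshape(axis_size, n_nodes), True


params = jnp.asarray(2.0)
feats = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
senders = jnp.array([[0, 1], [2, 0]])
receivers = jnp.array([[1, 2], [0, 1]])

# vmap dispatches to the custom rule instead of batching each operation.
batched = jax.vmap(energy, in_axes=(None, 0, 0, 0))(params, feats, senders, receivers)
```

Because the graphs stay disconnected inside the supergraph, the result should match applying `per_node_energy` to each graph separately.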

flatten_graph(axis_size, in_batched, senders, receivers, edge_features, node_features)

Flattens batched graphs into a supergraph.
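A minimal sketch of such a flattening step, written with NumPy for brevity (`jax.numpy` offers the same `arange` and `reshape` operations). The helper `flatten_graph_sketch` is hypothetical: it omits the `axis_size` and `in_batched` arguments and assumes that every input carries a leading batch axis and that all graphs are padded to the same number of nodes and edges.

```python
import numpy as np


def flatten_graph_sketch(senders, receivers, edge_features, node_features):
    """Merge a batch of equally sized graphs into one disconnected supergraph."""
    n_graphs, n_nodes = node_features.shape[:2]
    # Shift the node indices of graph g by g * n_nodes so that edges
    # never connect nodes belonging to different graphs.
    offsets = (np.arange(n_graphs) * n_nodes)[:, None]
    flat_senders = (senders + offsets).reshape(-1)
    flat_receivers = (receivers + offsets).reshape(-1)
    # Fold the batch axis into the edge/node axis, keeping feature axes.
    flat_edges = edge_features.reshape(-1, *edge_features.shape[2:])
    flat_nodes = node_features.reshape(-1, *node_features.shape[2:])
    return flat_senders, flat_receivers, flat_edges, flat_nodes


# Two graphs, each with 3 nodes and 2 edges.
senders = np.array([[0, 1], [2, 0]])
receivers = np.array([[1, 2], [0, 1]])
edge_features = np.ones((2, 2, 3))
node_features = np.zeros((2, 3, 4))

flat_senders, flat_receivers, flat_edges, flat_nodes = flatten_graph_sketch(
    senders, receivers, edge_features, node_features
)
```

Here the second graph's node indices are shifted by 3, so its nodes occupy positions 3 to 5 of the supergraph.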