deploy.exporter
Exporting potential models to MLIR.
- class Exporter
Exports a potential model to an MLIR module.
To export a potential model, subclass this class, select an appropriate graph type and define the energy function:
Usage:
>>> import jax.numpy as jnp
>>> from jax_md_mod import custom_energy
>>> from jax_md import partition, space
>>> from chemtrain.deploy import exporter, graphs
>>> class LennardJonesExport(exporter.Exporter):
...
...     graph_type = graphs.SimpleSparseNeighborList
...     r_cutoff = 5.0
...     unit_style = "real"
...     nbr_order = [1, 1]
...
...     def energy_fn(self, pos, species, graph):
...
...         neighbors = partition.NeighborList(
...             jnp.stack((graph.senders, graph.receivers)),
...             pos, None, None, graph.senders.size, partition.Sparse,
...             None, None, None
...         )
...
...         assert neighbors.idx.shape[0] == 2, "Wrong shape"
...
...         displacement_fn, _ = space.free()
...         apply_fn = custom_energy.customn_lennard_jones_neighbor_list(
...             displacement_fn, None, None,
...             sigma=jnp.asarray([3.165]), epsilon=jnp.asarray([1.0]),
...             r_onset=4.0, r_cutoff=5.0,
...             initialize_neighbor_list=False,
...             per_particle=True  # Important for export
...         )
...
...         return apply_fn(pos, neighbors, species=species)
...
>>> model = LennardJonesExport()
>>> try:
...     model.save("out.ptb")
... except AssertionError as e:
...     # We need to call the export method first
...     print(f"Error: {e}")
Error: Model has not been exported yet. Please call `export()` first.
>>> model.export()
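When checking the outputs of an exported model, a plain-Python reference of the pair energy configured in the example can be handy. The sketch below assumes that `customn_lennard_jones_neighbor_list` multiplies a 12-6 Lennard-Jones potential by a polynomial switch between `r_onset` and `r_cutoff` of the same form as jax_md's `multiplicative_isotropic_cutoff`; this is an assumption, not something stated on this page. The helpers `lennard_jones`, `switch` and `smoothed_lj` are hypothetical and not part of chemtrain.

```python
# Parameters copied from the LennardJonesExport example.
SIGMA = 3.165
EPSILON = 1.0
R_ONSET = 4.0
R_CUTOFF = 5.0


def lennard_jones(r, sigma=SIGMA, epsilon=EPSILON):
    """Plain 12-6 Lennard-Jones pair energy (hypothetical helper)."""
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 * sr6 - sr6)


def switch(r, r_onset=R_ONSET, r_cutoff=R_CUTOFF):
    """Polynomial switch: 1 below r_onset, 0 at and beyond r_cutoff.

    Assumption: same form as jax_md's multiplicative_isotropic_cutoff.
    """
    if r < r_onset:
        return 1.0
    if r >= r_cutoff:
        return 0.0
    rc2, ro2, r2 = r_cutoff ** 2, r_onset ** 2, r ** 2
    return (rc2 - r2) ** 2 * (rc2 + 2.0 * r2 - 3.0 * ro2) / (rc2 - ro2) ** 3


def smoothed_lj(r):
    """Pair energy that decays smoothly to zero at the cutoff."""
    return lennard_jones(r) * switch(r)
```

The switch equals 1 at `r_onset` and 0 at `r_cutoff`, so the energy is continuous at both ends of the switching region while remaining an unmodified Lennard-Jones potential below `r_onset`.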
- Variables:
graph_type (chemtrain.deploy.graphs.NeighborList) – Specifies the required neighborhood representation and how to generate it from the input data. See chemtrain.deploy.graphs.
nbr_order (List[int]) – List of two integers specifying the number of neighbors required in the Newton and non-Newton settings to compute forces correctly.
r_cutoff (float) – Cutoff radius for the potential.
unit_style (str) – Specifies the units in which the potential requires positions and returns energies. The force units depend solely on the length and energy units.
has_aux (bool) – If True, the energy function returns additional quantities besides the potential energy as a dictionary.
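The example sets `unit_style = "real"`, which matches the name of a LAMMPS unit style. Assuming the accepted names follow LAMMPS conventions (an assumption; the valid values are not listed here), the sketch below shows how the force unit follows from the length and energy units. The table `UNIT_STYLES` and the function `force_unit` are hypothetical and cover only two styles.

```python
# Assumption: unit_style names follow LAMMPS conventions.
# Only two styles are listed here for illustration.
UNIT_STYLES = {
    "real": {"length": "Angstrom", "energy": "kcal/mol"},
    "metal": {"length": "Angstrom", "energy": "eV"},
}


def force_unit(unit_style):
    """Force is energy per length, so its unit follows from the other two."""
    units = UNIT_STYLES[unit_style]
    return f"{units['energy']}/{units['length']}"


print(force_unit("real"))   # kcal/mol/Angstrom
print(force_unit("metal"))  # eV/Angstrom
```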
- abstractmethod energy_fn(position, species, graph)
Computes the energy for positions and a graph representation.
- Parameters:
position – (N, dim) Array of particle positions, including ghost atoms that are not within the local domain.
species – Array of atom species.
graph – Graph representation of the neighborhood around atoms.
- Returns:
Must return the energy contribution associated with each particle (per-particle energies), not a single total energy.
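To illustrate the per-particle return contract, the sketch below distributes pair energies over particles using a sparse edge list given as `senders` and `receivers`, the same fields the example reads from `graph`. It assumes every interacting pair appears in both directions (i to j and j to i), so each directed edge contributes half of the pair energy to its sender; whether `SimpleSparseNeighborList` stores both directions is an assumption here. The function `per_particle_energies` is hypothetical and written in plain Python rather than JAX.

```python
import math


def per_particle_energies(positions, senders, receivers, pair_energy):
    """Distribute pair energies over particles (hypothetical helper).

    Assumes each interacting pair appears twice in the edge list
    (i -> j and j -> i), so every directed edge adds half of the
    pair energy to its sender.
    """
    energies = [0.0] * len(positions)
    for i, j in zip(senders, receivers):
        r = math.dist(positions[i], positions[j])
        energies[i] += 0.5 * pair_energy(r)
    return energies


# Three particles, two interacting pairs: (0, 1) and (0, 2).
positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
senders = [0, 1, 0, 2]
receivers = [1, 0, 2, 0]
# With the pair energy equal to the distance, the result is [1.5, 0.5, 1.0].
print(per_particle_energies(positions, senders, receivers, lambda r: r))
```

Summing the returned list recovers the total pair energy, since each unordered pair contributes exactly once overall.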
- graph_type
alias of SimpleSparseNeighborList