deploy.exporter

Exporting potential models to MLIR.

class Exporter

Exports a potential model to an MLIR module.

To export a potential model, subclass this class, select an appropriate graph type and define the energy function:

Usage:

>>> import jax.numpy as jnp
>>> from jax_md_mod import custom_energy
>>> from jax_md import partition, space
>>> from chemtrain.deploy import exporter, graphs
>>> class LennardJonesExport(exporter.Exporter):
...
...     graph_type = graphs.SimpleSparseNeighborList
...     r_cutoff = 5.0
...     unit_style = "real"
...     nbr_order = [1, 1]
...
...     def energy_fn(self, pos, species, graph):
...
...         neighbors = partition.NeighborList(
...             jnp.stack((graph.senders, graph.receivers)),
...             pos, None, None, graph.senders.size, partition.Sparse,
...             None, None, None
...         )
...
...         assert neighbors.idx.shape[0] == 2, "Wrong shape"
...
...         displacement_fn, _ = space.free()
...         apply_fn = custom_energy.customn_lennard_jones_neighbor_list(
...             displacement_fn, None, None,
...             sigma=jnp.asarray([3.165]), epsilon=jnp.asarray([1.0]),
...             r_onset=4.0, r_cutoff=5.0,
...             initialize_neighbor_list=False,
...             per_particle=True  # export requires per-particle energies
...         )
...
...         return apply_fn(pos, neighbors, species=species)
...
>>> model = LennardJonesExport()
>>> try:
...     model.save("out.ptb")
... except AssertionError as e:
...     # We need to call the export method first
...     print(f"Error: {e}")
Error: Model has not been exported yet. Please call `export()` first.
>>> model.export()
>>> model.save("out.ptb")

Variables:
  • graph_type (chemtrain.deploy.graphs.NeighborList) – Specifies the required neighborhood representation and how to generate it from the input data. See chemtrain.deploy.graphs.

  • nbr_order (List[int]) – List of two integers specifying the order of neighbors required to compute forces correctly, for the Newton and non-Newton settings, respectively.

  • r_cutoff (float) – Cutoff radius for the potential.

  • unit_style (str) – Specifies the units in which the potential requires positions and returns energies. The force units depend solely on the length and energy units.

  • has_aux (bool) – If True, the energy function returns additional quantities besides the potential energy as a dictionary.
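The value "real" in the usage example appears to follow the LAMMPS unit-style names. Assuming that convention (which styles chemtrain actually accepts is not listed on this page), the sketch below shows how the force unit follows from the length and energy units, since force = energy / length. The mapping `LAMMPS_UNITS` and the helper `force_unit` are hypothetical illustrations, not part of chemtrain.

```python
# Hypothetical helper, not part of chemtrain.
# Assumption: `unit_style` uses LAMMPS unit-style names; only a few styles shown.
LAMMPS_UNITS = {
    "real": ("Angstrom", "kcal/mol"),  # (length unit, energy unit)
    "metal": ("Angstrom", "eV"),
    "si": ("m", "J"),
}


def force_unit(unit_style: str) -> str:
    """Return the force unit implied by a unit style, i.e. energy per length."""
    length, energy = LAMMPS_UNITS[unit_style]
    return f"({energy})/{length}"
```

For example, `force_unit("real")` returns `"(kcal/mol)/Angstrom"`, so a potential exported with `unit_style = "real"` receives positions in Angstrom and returns energies in kcal/mol.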

abstractmethod energy_fn(position, species, graph)

Computes the energy for positions and a graph representation.

Parameters:
  • position – (N, dim) Array of particle positions, including ghost atoms that are not within the local domain.

  • species – (N,) Array of atom species.

  • graph – Graph representation of the neighborhood around atoms.

Returns:

Must return an array with one energy contribution per particle.
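To illustrate what returning one energy contribution per particle means, the NumPy sketch below evaluates a plain Lennard-Jones pair energy on an edge list given as `senders`/`receivers` index arrays, mirroring the `graph.senders` and `graph.receivers` fields used in the usage example. It assumes a full edge list that stores every pair in both directions, so each directed edge credits half of the pair energy to its sender; summing the per-particle contributions then recovers the total energy. The function names are hypothetical, and the sketch omits the cutoff and smoothing applied by `customn_lennard_jones_neighbor_list`.

```python
import numpy as np


def lj_pair_energy(r, sigma=3.165, epsilon=1.0):
    """Plain 12-6 Lennard-Jones pair energy (no cutoff or smoothing)."""
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 ** 2 - sr6)


def per_particle_energy(pos, senders, receivers):
    """Return an (N,) array of per-particle energies.

    Assumes a full edge list: every pair (i, j) appears as i -> j and j -> i,
    so each directed edge credits half of the pair energy to its sender.
    """
    dr = pos[receivers] - pos[senders]   # (E, dim) edge vectors in free space
    r = np.linalg.norm(dr, axis=-1)      # (E,) edge lengths
    energies = np.zeros(pos.shape[0])
    np.add.at(energies, senders, 0.5 * lj_pair_energy(r))  # scatter-add per sender
    return energies


# Three particles connected by all ordered pairs (i != j).
pos = np.array([[0.0, 0.0, 0.0], [3.5, 0.0, 0.0], [0.0, 4.0, 0.0]])
senders = np.array([0, 0, 1, 1, 2, 2])
receivers = np.array([1, 2, 0, 2, 0, 1])

e_i = per_particle_energy(pos, senders, receivers)

# Reference: sum the pair energy over unordered pairs i < j.
total = sum(
    lj_pair_energy(np.linalg.norm(pos[i] - pos[j]))
    for i in range(3)
    for j in range(i + 1, 3)
)
print(e_i, e_i.sum(), total)
```

Keeping the contributions per particle lets a caller sum over any subset of atoms, for example only those in the local domain rather than the ghost atoms; how chemtrain combines the contributions is not described on this page.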

export()

Exports the potential model to an MLIR module.

Return type:

None

graph_type

alias of SimpleSparseNeighborList

save(file)

Saves the exported protocol buffer (protobuf) to a file.

Parameters:

file (str) – Path to the file where the model should be saved.

Return type:

None
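The usage example shows that `save()` raises an `AssertionError` until `export()` has been called. The plain-Python sketch below reproduces only that ordering contract; the class `ExportStateSketch` and its `_module` attribute are hypothetical, and the real `Exporter` additionally lowers the energy function to an MLIR module before writing the exported protocol buffer.

```python
import os
import tempfile


class ExportStateSketch:
    """Hypothetical stand-in that mimics the export-then-save ordering."""

    def __init__(self):
        self._module = None  # None means "not exported yet"

    def export(self) -> None:
        # Placeholder for the real export step; here just a byte string.
        self._module = b"serialized-model"

    def save(self, file: str) -> None:
        assert self._module is not None, (
            "Model has not been exported yet. Please call `export()` first."
        )
        with open(file, "wb") as f:
            f.write(self._module)


model = ExportStateSketch()
path = os.path.join(tempfile.mkdtemp(), "out.ptb")
try:
    model.save(path)
except AssertionError as e:
    print(f"Error: {e}")  # save() before export() fails
model.export()
model.save(path)  # succeeds once export() has run
```

Using `assert` mirrors the error type shown in the usage example; note that assertions are skipped when Python runs with `-O`, so raising an explicit exception is more robust in code you write yourself.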