Source code for chemtrain.learn.property_prediction
# Copyright 2023 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains methods to learn a molecular properties from features
of a neural network used for potential energy prediction.
"""
from jax import numpy as jnp, vmap
from jax_md_mod.model import sparse_graph
[docs]
def build_dataset(targets, graph_dataset):
"""Builds dataset in format that is used for data loading and throughout
property predictions.
Args:
targets: Dict containing all targets to be predicted. Can be retrieved
in error_fn under the respective key.
graph_dataset: Dataset of graphs, e.g. as obtained from
:func:`jax_md_mod.model.sparse_graph.convert_dataset_to_graphs`.
Returns:
A dictionary containing the combined dataset and a list of target keys
"""
target_keys = list(targets)
target_keys.append('species_mask')
return {**targets, **graph_dataset.to_dict()}, target_keys
[docs]
def init_model(prediction_model):
"""Initializes a model that returns predictions for a single observation."""
def mol_prediction_model(params, observation):
graph = sparse_graph.SparseDirectionalGraph.from_dict(observation)
predictions = prediction_model(params, graph)
return predictions
return mol_prediction_model
[docs]
def init_loss_fn(error_fn):
"""Returns a loss function to optimize model parameters.
Signature of error_fn::
error = error_fn(predictions, batch, mask)
where mask has the same shape as species to mask padded particles.
Args:
model: Molecular property prediction model (Haiku apply_fn).
error_fn: Error model quantifying the discrepancy between predictions
and respective targets.
"""
def loss_fn(predictions, batch):
mask = jnp.ones_like(predictions) * batch['species_mask']
return error_fn(predictions, batch, mask)
return loss_fn
[docs]
def per_species_results(species, per_atom_quantities, species_idxs):
"""Sorts per-atom results by species and returns a per-species mean.
Only real (non-masked) particles should be input.
Args:
species: An array storing for each atom the corresponding species.
per_atom_quantities: An array with the same shape as species, storing
per-atom quantities to be evaluated per-species.
species_idxs: A (species,) array storing species-types for evaluation,
e.g. jnp.unique(species).
Returns:
A (species,) array of per-species quantities.
"""
@vmap
def process_single_species(species_idx):
species_mask = (species == species_idx)
species_members = jnp.count_nonzero(species_mask)
screened_results = jnp.where(species_mask, per_atom_quantities, 0.)
mean_if_species_exists = jnp.sum(screened_results) / species_members
return jnp.where(species_members == 0, 0., mean_if_species_exists)
return process_single_species(species_idxs)
[docs]
def per_species_box_errors(dataset, per_atom_errors):
"""Computes for each snapshot in the provided graph dataset, the per-species
error.
Args:
dataset: Graph dataset containing the snapshots of interest.
per_atom_errors: Per-atom error for each atom in the dataset.
Has same shape as ``dataset['species']``.
Returns:
Mean per-species error for each snapshot in the dataset.
"""
mask = dataset['species_mask']
species = dataset['species']
real_species = species[mask]
unique_species = jnp.unique(real_species)
species_masked = jnp.where(mask, species, 1000) # species 1000 nonexistant
per_box_and_species_fn = vmap(per_species_results, in_axes=(0, 0, None))
per_box_species_errors = per_box_and_species_fn(
species_masked, per_atom_errors, unique_species)
distinct_per_box_species = jnp.sum(per_box_species_errors > 0., axis=1)
mean_per_box_species_errors = (jnp.sum(per_box_species_errors, axis=1)
/ distinct_per_box_species)
return mean_per_box_species_errors