# Source code for jax_md_mod.model.dimenet_basis_util._src

"""The spherical Bessel Function utility file from the original DimeNet
implementation (https://github.com/klicperajo/dimenet).
Licensed under the Hippocratic License 2.1 (see LICENSE.md).
Adapted to work with newer versions of SciPy and NumPy.
"""

import numpy as np
from scipy.optimize import brentq
from scipy import special as sp
import sympy as sym


def Jn(r, n):
    """Numerical spherical Bessel function of the first kind, j_n(r).

    Evaluated through the half-integer-order Bessel function:
    j_n(r) = sqrt(pi / (2 r)) * J_{n + 1/2}(r). Both factors are NumPy /
    SciPy ufuncs, so scalars and arrays are handled element-wise. The
    expression is undefined at r = 0.

    Args:
        r: Argument(s) at which to evaluate the function.
        n: Order of the spherical Bessel function.

    Returns:
        j_n(r), with the same shape as ``r``.
    """
    half_integer_order = n + 0.5
    scale = np.sqrt(np.pi / (2 * r))
    return scale * sp.jv(half_integer_order, r)
def Jn_zeros(n, k):
    """Compute the first k positive zeros of the spherical Bessel functions
    of orders 0 .. n-1 (order n excluded).

    The zeros of j_0(x) = sin(x)/x are exactly pi, 2*pi, 3*pi, ...  For each
    higher order, every root is searched with Brent's method inside an
    interval bounded by two consecutive roots of the previous order.

    Args:
        n: Number of orders (rows of the result).
        k: Number of zeros per order (columns of the result).

    Returns:
        float32 array of shape (n, k); row ``order`` holds the first k zeros
        of the spherical Bessel function of that order.
    """
    table = np.zeros((n, k), dtype="float32")
    table[0] = np.arange(1, k + 1) * np.pi

    # Every order consumes one bracketing interval, so start from
    # k + n - 1 zeros of j_0 to still have k intervals at the last order.
    brackets = np.arange(1, k + n) * np.pi
    for order in range(1, n):
        n_roots = k + n - 1 - order
        roots = np.array(
            [brentq(Jn, brackets[j], brackets[j + 1], (order,))
             for j in range(n_roots)],
            dtype="float32",
        )
        table[order][:k] = roots[:k]
        # The (float32) roots of this order bracket the roots of the next.
        brackets = roots
    return table
def spherical_bessel_formulas(n):
    """Compute sympy formulas for the spherical Bessel functions up to
    order n (excluded).

    Uses Rayleigh's formula
    j_i(x) = (-x)**i * (1/x d/dx)**i (sin(x)/x),
    applying the operator (1/x d/dx) once per order. The formulas are
    expressions in the sympy symbol ``x``. The list always contains at
    least j_0, even when n < 1.

    Args:
        n: Number of orders to generate.

    Returns:
        List of sympy expressions [j_0(x), j_1(x), ..., j_{n-1}(x)].
    """
    x = sym.symbols('x')
    seed = sym.sin(x) / x
    formulas = [seed]
    # ``current`` holds (1/x d/dx)**(order-1) applied to sin(x)/x.
    current = seed
    for order in range(1, n):
        step = sym.diff(current, x) / x
        formulas.append(sym.simplify(step * (-x) ** order))
        current = sym.simplify(step)
    return formulas
def bessel_basis(n, k):
    """Compute sympy formulas for the normalized and rescaled spherical
    Bessel functions up to order n (excluded) and maximum frequency k
    (excluded).

    Entry ``[order][i]`` is ``c * j_order(z * x)``, where ``z`` is the i-th
    positive zero of ``j_order`` (from ``Jn_zeros``) and
    ``c = 1 / sqrt(0.5 * j_{order+1}(z)**2)``.

    Args:
        n: Number of orders.
        k: Number of basis functions (zeros) per order.

    Returns:
        Nested list ``basis[order][i]`` of sympy expressions in ``x``.
    """
    zeros = Jn_zeros(n, k)
    x = sym.symbols('x')
    radial = spherical_bessel_formulas(n)

    basis = []
    for order in range(n):
        roots = zeros[order]
        # Evaluated one scalar at a time on purpose: vectorizing over the
        # float32 roots could change NumPy's dtype promotion.
        squared = np.array(
            [0.5 * Jn(roots[i], order + 1) ** 2 for i in range(k)])
        weights = 1 / squared ** 0.5
        basis.append([
            sym.simplify(weights[i] * radial[order].subs(x, roots[i] * x))
            for i in range(k)
        ])
    return basis
def sph_harm_prefactor(l: int, m: int):
    """Computes the constant pre-factor of the spherical harmonic of degree
    l and order m.

    The value is sqrt((2l + 1) / (4 pi) * (l - |m|)! / (l + |m|)!), so it
    depends on m only through |m|. Factorials are evaluated in floating
    point via ``scipy.special.factorial``.

    Args:
        l: Degree, :math:`l>=0`.
        m: Order, :math:`-l<=m<=l`.

    Returns:
        The normalization constant as a float.
    """
    abs_m = abs(m)
    numerator = (2*l + 1) * sp.factorial(l - abs_m)
    denominator = 4 * np.pi * sp.factorial(l + abs_m)
    return (numerator / denominator) ** 0.5
def associated_legendre_polynomials(l, zero_m_only=True):
    """Computes sympy formulas of the associated Legendre polynomials up to
    degree l (excluded).

    The polynomials are expressed in the sympy symbol ``z``. For m > 0 the
    factor (1 - z**2)**(m/2) of the usual definition is not included; only
    the polynomial part (with the (-1)**m phase from the diagonal recursion)
    is built.

    Args:
        l: Number of degrees; polynomials for degrees 0 .. l-1 are built.
        zero_m_only: If True, only the m = 0 (Legendre) polynomials are
            computed and every m > 0 entry stays the integer 0.

    Returns:
        Nested list ``P`` where ``P[j]`` has length j + 1 and ``P[j][m]`` is
        the polynomial of degree j and order m. Empty when l < 1.
    """
    z = sym.symbols('z')
    P_l_m = [[0] * (j + 1) for j in range(l)]
    if l < 1:
        # Nothing to build; the seed assignment below would raise IndexError.
        return P_l_m

    P_l_m[0][0] = 1
    if l > 1:
        # Degree-1 seed only exists when at least two degrees are requested
        # (the previous ``l > 0`` guard raised IndexError for l == 1).
        P_l_m[1][0] = z
        # Bonnet's recursion: j P_j = (2j - 1) z P_{j-1} - (j - 1) P_{j-2}.
        for j in range(2, l):
            P_l_m[j][0] = sym.simplify(
                ((2*j - 1)*z*P_l_m[j-1][0] - (j - 1)*P_l_m[j-2][0])/j)

    if not zero_m_only:
        for i in range(1, l):
            # Diagonal: P_i^i = (1 - 2i) * P_{i-1}^{i-1}.
            P_l_m[i][i] = sym.simplify((1 - 2*i)*P_l_m[i-1][i-1])
            # First off-diagonal: P_{i+1}^i = (2i + 1) z P_i^i.
            if i + 1 < l:
                P_l_m[i+1][i] = sym.simplify((2*i + 1)*z*P_l_m[i][i])
            # Upward recursion in the degree for fixed order i:
            # (j - i) P_j^i = (2j - 1) z P_{j-1}^i - (i + j - 1) P_{j-2}^i.
            for j in range(i + 2, l):
                P_l_m[j][i] = sym.simplify(
                    ((2*j - 1) * z * P_l_m[j-1][i]
                     - (i + j - 1) * P_l_m[j-2][i]) / (j - i))
    return P_l_m
def real_sph_harm(l, zero_m_only=True, spherical_coordinates=True):
    """Computes sympy formulas of the real spherical harmonics up to degree
    l (excluded).

    Variables are either the cartesian coordinates ``x``, ``y``, ``z`` on the
    unit sphere or the spherical coordinates ``theta`` and ``phi``.

    Args:
        l: Number of degrees; harmonics for degrees 0 .. l-1 are built.
        zero_m_only: If True, only the m = 0 harmonics are computed.
        spherical_coordinates: If True, substitute ``z = cos(theta)`` and,
            when m != 0 harmonics are requested, also
            ``x = sin(theta)cos(phi)`` and ``y = sin(theta)sin(phi)``.

    Returns:
        Nested list ``Y`` where ``Y[i]`` has length 2*i + 1. ``Y[i][0]`` is
        the m = 0 harmonic of degree i. When ``zero_m_only`` is False,
        ``Y[i][j]`` and ``Y[i][-j]`` (1 <= j <= i) hold the cosine- and
        sine-type harmonics of order j. Entries that are not computed keep
        the placeholder string ``'0'``.
    """
    # Defined unconditionally: previously x/y only existed inside a loop
    # that does not run for l <= 1, causing a NameError further down.
    x, y, z = sym.symbols('x y z')
    theta, phi = sym.symbols('theta phi')

    if not zero_m_only:
        # C_m + i*S_m = (x + i*y)**m, built recursively. The seeds are sympy
        # integers so that ``.subs`` below also works on the m = 0 entries
        # (plain ints raised AttributeError there).
        S_m = [sym.Integer(0)]
        C_m = [sym.Integer(1)]
        for i in range(1, l):
            S_m.append(x*S_m[i-1] + y*C_m[i-1])
            C_m.append(x*C_m[i-1] - y*S_m[i-1])

    P_l_m = associated_legendre_polynomials(l, zero_m_only)
    if spherical_coordinates:
        for i in range(len(P_l_m)):
            for j in range(len(P_l_m[i])):
                # Plain Python ints (e.g. unfilled entries) have no ``.subs``.
                if isinstance(P_l_m[i][j], sym.Basic):
                    P_l_m[i][j] = P_l_m[i][j].subs(z, sym.cos(theta))
        if not zero_m_only:
            def _to_angles(expr):
                # Express a polynomial in x, y on the unit sphere in angles.
                return expr.subs(x, sym.sin(theta)*sym.cos(phi)).subs(
                    y, sym.sin(theta)*sym.sin(phi))

            S_m = [_to_angles(s) for s in S_m]
            C_m = [_to_angles(c) for c in C_m]

    Y_func_l_m = [['0']*(2*j + 1) for j in range(l)]
    for i in range(l):
        Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0])
    if not zero_m_only:
        for i in range(1, l):
            for j in range(1, i + 1):
                # Positive indices hold the cosine-type (C_m) harmonics,
                # negative indices the sine-type (S_m) ones.
                Y_func_l_m[i][j] = sym.simplify(
                    2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j])
                Y_func_l_m[i][-j] = sym.simplify(
                    2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j])
    return Y_func_l_m