1"""
2A Python wrapper for the compiled GG functions.
3"""
4
5import ctypes
6import ctypes.util
7import os
8
9import numpy as np
10
11from . import docs_generator
12from . import utility
13
# Attempt to load the compiled C library (libgg). Module-level state:
#   cgg          -- ctypes handle from np.ctypeslib.load_library, or None if not found
#   __lib_found  -- True once the library has been loaded
#   __libgg_path -- filesystem path of the loaded shared object (None if not loaded)
__lib_found = False
__libgg_path = None
cgg = None

# Only the directory containing this module is searched. The shared object may
# be named without ("gg") or with ("libgg") the conventional "lib" prefix; the
# first name that loads wins.
abs_path = os.path.dirname(os.path.abspath(__file__))
for __libgg_name in ("gg", "libgg"):
    try:
        cgg = np.ctypeslib.load_library(__libgg_name, abs_path)
    except OSError:
        continue
    __libgg_path = os.path.join(abs_path, cgg._name)
    __lib_found = True
    break
del __libgg_name
32
# Ordering name -> integer enum handed to the C routines, keyed first by basis
# type ("spherical" or "cartesian"); looked up by _wrapper_checks.
__order_enum = dict(
    spherical=dict(cca=300, gaussian=301),
    cartesian=dict(cca=400, molden=401),
)
43
44
def _build_collocation_ctype(nout, orbital=False):
    """
    Assemble the ctypes ``argtypes`` tuple for a libgg collocation/orbital entry point.

    Parameters
    ----------
    nout : int
        Number of trailing writeable 2-D double output arrays the C function fills.
    orbital : bool, optional
        If True, an orbital-matrix pointer and the orbital count are placed right
        after ``L`` (the layout used for ``gg_orbitals``).

    Returns
    -------
    tuple
        Argument types in C call order.
    """
    double_vec = np.ctypeslib.ndpointer(dtype=np.double, ndim=1, flags=("C", "A"))
    center_vec = np.ctypeslib.ndpointer(dtype=np.double, shape=(3, ), ndim=1, flags=("C", "A"))
    output_mat = np.ctypeslib.ndpointer(dtype=np.double, ndim=2, flags=("W", "C", "A"))

    # L
    head = [ctypes.c_int]
    if orbital:
        # orbs, norbs
        head.extend([np.ctypeslib.ndpointer(dtype=np.double, ndim=2, flags=("C", "A")), ctypes.c_ulong])

    body = [
        ctypes.c_ulong,  # npoints
        double_vec,  # xyz
        ctypes.c_ulong,  # xyz stride
        ctypes.c_int,  # nprim
        double_vec,  # coef
        double_vec,  # exp
        center_vec,  # center
        ctypes.c_int,  # ordering enum (value from __order_enum)
    ]

    return tuple(head + body + [output_mat] * nout)
76
77
# Bind the C object: declare ctypes argument/return types so NumPy arrays are
# dtype/flag-checked on the way in and routines declared with restype None are
# not read as returning an int.
if cgg is not None:

    # Helpers
    cgg.gg_ncomponents.argtypes = (ctypes.c_int, ctypes.c_int)
    cgg.gg_ncomponents.restype = ctypes.c_int

    # Transposes
    cgg.gg_naive_transpose.restype = None
    cgg.gg_naive_transpose.argtypes = (ctypes.c_ulong, ctypes.c_ulong, np.ctypeslib.ndpointer(),
                                       np.ctypeslib.ndpointer())

    cgg.gg_fast_transpose.restype = None
    cgg.gg_fast_transpose.argtypes = (ctypes.c_ulong, ctypes.c_ulong, np.ctypeslib.ndpointer(),
                                      np.ctypeslib.ndpointer())

    # Collocation generators
    cgg.gg_orbitals.restype = None
    cgg.gg_orbitals.argtypes = _build_collocation_ctype(1, orbital=True)

    # Each function gets its own restype; the output count is the values (1)
    # plus the 3/6/10 derivative components added at each derivative order.
    cgg.gg_collocation.restype = None
    cgg.gg_collocation.argtypes = _build_collocation_ctype(1)

    cgg.gg_collocation_deriv1.restype = None
    cgg.gg_collocation_deriv1.argtypes = _build_collocation_ctype(4)

    cgg.gg_collocation_deriv2.restype = None
    cgg.gg_collocation_deriv2.argtypes = _build_collocation_ctype(10)

    cgg.gg_collocation_deriv3.restype = None
    cgg.gg_collocation_deriv3.argtypes = _build_collocation_ctype(20)
109
110
def c_compiled():
    """
    Report whether the compiled libgg library was loaded when this module was imported.

    Returns
    -------
    bool
        True if libgg was found and loaded, False otherwise.
    """
    loaded = __lib_found
    return loaded
116
117
def _validate_c_import():
    """
    Guard used by every function that needs the compiled backend.

    Raises
    ------
    ImportError
        If libgg could not be loaded when this module was imported; the message
        points users at the slower NumPy reference implementation.
    """
    # Truthiness test instead of an identity comparison against False (PEP 8).
    if not c_compiled():
        raise ImportError("Compiled libgg not found. Please compile gau2grid before calling these\n"
                          "  functions. Alternatively, use the NumPy-based collocation functions found at\n"
                          "  gau2grid.ref.collocation or gau2grid.ref.collocation_basis. It should\n"
                          "  be noted that these functions are dramatically slower (4-20x).\n")
127
128
def cgg_path():
    """
    Locate the loaded libgg shared object on disk.

    Returns
    -------
    str
        Path of the libgg.so/dylib/dll file that was loaded at import time.

    Raises
    ------
    ImportError
        If libgg was not found.
    """
    _validate_c_import()
    library_file = __libgg_path
    return library_file
135
136
def get_cgg_shared_object():
    """
    Hand out the ctypes library handle for direct access to the libgg functions.

    Returns
    -------
    ctypes.CDLL
        The object returned by ``np.ctypeslib.load_library``, with argtypes and
        restypes already bound.

    Raises
    ------
    ImportError
        If libgg was not found.
    """
    _validate_c_import()
    handle = cgg
    return handle
144
145
def max_L():
    """
    Return the maximum compiled angular momentum.

    Returns
    -------
    int
        The highest AM that libgg was compiled for.

    Raises
    ------
    ImportError
        If libgg was not found (consistent with the other C-backed helpers,
        instead of an AttributeError on ``None``).
    """
    _validate_c_import()

    return cgg.gg_max_L()
152
153
def ncomponents(L, spherical=True):
    """
    Computes the number of components for spherical and cartesian gaussians of a given L

    Parameters
    ----------
    L : int
        The angular momentum of the basis function
    spherical : bool, optional
        Spherical (True) or Cartesian (False) number of components

    Returns
    -------
    int
        The number of components in the gaussian.

    Raises
    ------
    ImportError
        If libgg was not found.
    """
    _validate_c_import()

    # gg_ncomponents is bound with a c_int flag argument; normalize truthy values
    # (e.g. np.bool_, which ctypes rejects for c_int) to a plain Python bool.
    return cgg.gg_ncomponents(L, bool(spherical))
172
173
def _wrapper_checks(L, xyz, spherical, spherical_order, cartesian_order):
    """
    Validate the arguments shared by the C-backed collocation/orbital wrappers.

    Parameters
    ----------
    L : int
        Angular momentum; must not exceed the AM libgg was compiled for.
    xyz : np.ndarray
        Grid coordinates; must be two-dimensional with a leading dimension of 3.
    spherical : bool
        Selects the spherical (True) or cartesian (False) ordering table.
    spherical_order, cartesian_order : str
        Ordering names looked up in ``__order_enum``.

    Returns
    -------
    int
        The ordering enum value to pass to the C routines.

    Raises
    ------
    ValueError
        If ``L`` exceeds the compiled AM or ``xyz`` is not of shape (3, N).
    KeyError
        If the requested ordering name is not known for the selected basis type.
    """
    # Query once so the comparison and the error message report the same value.
    max_am = cgg.gg_max_L()
    if L > max_am:
        raise ValueError("LibGG was only compiled to AM=%d, requested AM=%d." % (max_am, L))

    # Check XYZ: callers read xyz.shape[1] as the number of points, so a 1-D
    # array of length 3 must be rejected here as well.
    if len(xyz.shape) != 2 or xyz.shape[0] != 3:
        raise ValueError("XYZ array must be of shape (3, N), found %s" % str(xyz.shape))

    # Validate the ordering name and translate it to the C enum value
    if spherical:
        basis_type, order_name = "spherical", spherical_order
    else:
        basis_type, order_name = "cartesian", cartesian_order

    try:
        return __order_enum[basis_type][order_name]
    except KeyError:
        raise KeyError("Order Spherical=%s:%s not understood." % (spherical, order_name)) from None
195
196
def collocation_basis(xyz, basis, grad=0, spherical=True, out=None, cartesian_order="cca", spherical_order="cca"):
    # The public docstring is attached below via docs_generator. Basis handling is
    # delegated to utility.wrap_basis_collocation, which receives this module's
    # C-backed ``collocation`` routine together with the forwarded options.
    forwarded = {
        "spherical": spherical,
        "out": out,
        "cartesian_order": cartesian_order,
        "spherical_order": spherical_order,
    }
    return utility.wrap_basis_collocation(collocation, xyz, basis, grad, **forwarded)


# Write common docs
collocation_basis.__doc__ = docs_generator.build_collocation_basis_docs(
    "This function uses a optimized C library as a backend.")
212
213
def orbital_basis(orbs, xyz, basis, spherical=True, out=None, cartesian_order="cca", spherical_order="cca"):
    # The public docstring is attached below via docs_generator. Basis handling is
    # delegated to utility.wrap_basis_orbital, which receives this module's
    # C-backed ``orbital`` routine together with the forwarded options.
    forwarded = {
        "spherical": spherical,
        "out": out,
        "cartesian_order": cartesian_order,
        "spherical_order": spherical_order,
    }
    return utility.wrap_basis_orbital(orbital, orbs, xyz, basis, **forwarded)


orbital_basis.__doc__ = docs_generator.build_orbital_basis_docs(
    "This function uses a optimized C library as a backend.")
228
229
def collocation(xyz,
                L,
                coeffs,
                exponents,
                center,
                grad=0,
                spherical=True,
                out=None,
                cartesian_order="cca",
                spherical_order="cca"):
    # NOTE: the public docstring is attached after this definition via
    # docs_generator, so implementation notes live in comments here.

    # Validates we loaded correctly
    _validate_c_import()

    # Fail fast on unsupported derivative orders, before any output is allocated.
    if grad not in (0, 1, 2, 3):
        raise ValueError("Only up to grad=3 is supported.")
    grad = int(grad)

    # The C routines take double buffers declared as ndpointer(dtype=double,
    # flags "C"), so coerce lists, non-float64 arrays and strided/transposed views
    # to C-contiguous float64 up front instead of failing inside ctypes.
    xyz = np.ascontiguousarray(xyz, dtype=np.double)
    order_enum = _wrapper_checks(L, xyz, spherical, spherical_order, cartesian_order)

    # Check gaussian
    coeffs = np.ascontiguousarray(coeffs, dtype=np.double)
    exponents = np.ascontiguousarray(exponents, dtype=np.double)
    center = np.ascontiguousarray(center, dtype=np.double)
    if coeffs.shape[0] != exponents.shape[0]:
        raise ValueError("Coefficients (N=%d) and exponents (N=%d) must have the same shape." %
                         (coeffs.shape[0], exponents.shape[0]))

    # Find the output shape: (number of components for L, number of grid points)
    npoints = xyz.shape[1]
    nvals = utility.nspherical(L) if spherical else utility.ncartesian(L)

    # Build the outputs
    out = utility.validate_coll_output(grad, (nvals, npoints), out)

    # Output keys in the positional order the C entry points expect: the values,
    # then the components of every derivative order up to ``grad``.
    keys = ["PHI"]
    if grad >= 1:
        keys += ["PHI_X", "PHI_Y", "PHI_Z"]
    if grad >= 2:
        keys += ["PHI_XX", "PHI_XY", "PHI_XZ", "PHI_YY", "PHI_YZ", "PHI_ZZ"]
    if grad >= 3:
        keys += [
            "PHI_XXX", "PHI_XXY", "PHI_XXZ", "PHI_XYY", "PHI_XYZ", "PHI_XZZ", "PHI_YYY", "PHI_YYZ", "PHI_YZZ",
            "PHI_ZZZ"
        ]

    # Select the correct function for the requested derivative order
    c_func = (cgg.gg_collocation, cgg.gg_collocation_deriv1, cgg.gg_collocation_deriv2,
              cgg.gg_collocation_deriv3)[grad]

    # xyz is a C-ordered (3, N) array; its flattened buffer is passed with a stride argument of 1.
    c_func(L, npoints, xyz.ravel(), 1, coeffs.shape[0], coeffs, exponents, center, order_enum,
           *[out[key] for key in keys])

    return out


collocation.__doc__ = docs_generator.build_collocation_docs("This function uses a optimized C library as a backend.")
288
289
def orbital(orbs,
            xyz,
            L,
            coeffs,
            exponents,
            center,
            spherical=True,
            out=None,
            cartesian_order="cca",
            spherical_order="cca"):
    # NOTE: the public docstring is attached after this definition via
    # docs_generator, so implementation notes live in comments here.

    # Validates we loaded correctly
    _validate_c_import()

    # The C routine takes buffers declared as ndpointer(dtype=double, flags "C"),
    # so coerce lists, non-float64 arrays and strided/transposed views (e.g. orbs.T)
    # to C-contiguous float64 up front instead of failing inside ctypes.
    xyz = np.ascontiguousarray(xyz, dtype=np.double)
    order_enum = _wrapper_checks(L, xyz, spherical, spherical_order, cartesian_order)

    # Check gaussian
    orbs = np.ascontiguousarray(orbs, dtype=np.double)
    coeffs = np.ascontiguousarray(coeffs, dtype=np.double)
    exponents = np.ascontiguousarray(exponents, dtype=np.double)
    center = np.ascontiguousarray(center, dtype=np.double)
    if coeffs.shape[0] != exponents.shape[0]:
        raise ValueError("Coefficients (N=%d) and exponents (N=%d) must have the same shape." %
                         (coeffs.shape[0], exponents.shape[0]))

    # Find the output shape
    npoints = xyz.shape[1]
    nvals = utility.nspherical(L) if spherical else utility.ncartesian(L)

    # orbs is read as (norbs, nvals): rows are orbitals, columns are shell components.
    if orbs.ndim != 2:
        raise ValueError("Orbital array must be two-dimensional (norbs, nvals), found shape %s" %
                         str(orbs.shape))
    if nvals != orbs.shape[1]:
        raise ValueError("Orbital block, must be equal to the shell size.")

    # Build the outputs: validate_coll_output works on a dict, so wrap a
    # user-supplied array and unwrap the single "PHI" result.
    if out is not None:
        out = {"PHI": out}
    out = utility.validate_coll_output(0, (orbs.shape[0], npoints), out)["PHI"]

    # Compute the orbitals on the grid; xyz's flattened buffer is passed with a stride argument of 1
    cgg.gg_orbitals(L, orbs, orbs.shape[0], npoints, xyz.ravel(), 1, coeffs.shape[0], coeffs, exponents, center,
                    order_enum, out)

    return out


orbital.__doc__ = docs_generator.build_orbital_docs("This function uses a optimized C library as a backend.")
338