1from decimal import Decimal
2
3import numpy as np
4import pydantic
5import pytest
6
7import qcelemental as qcel
8from qcelemental.testing import compare_recursive, compare_values
9
10
@pytest.fixture
def dataset():
    """Provide a mapping of short names to sample ``qcel.Datum`` objects.

    Entries cover a Decimal scalar (constructed with ``numeric=False``), a
    float ndarray, a plain float, a non-numeric string and a non-numeric
    list mixing strings with ``None``.
    """
    shared_doi = "10.1000/182"
    freqs = np.arange(4, dtype=float) * 4 / 3

    return {
        "decimal": qcel.Datum(
            "a label",
            "mdyn/angstrom",
            Decimal("4.4"),
            comment="force constant",
            doi=shared_doi,
            numeric=False,
        ),
        "ndarray": qcel.Datum("an array", "cm^-1", freqs, comment="freqs"),
        "float": qcel.Datum("a float", "kg", 4.4, doi=shared_doi),
        "string": qcel.Datum("ze lbl", "ze unit", "ze data", numeric=False),
        "lststr": qcel.Datum("ze lbl", "ze unit", ["V", "R", None], numeric=False),
    }
24
25
def test_creation(dataset):
    """A Decimal datum keeps its fields and ends up flagged as numeric."""
    dec = dataset["decimal"]

    assert (dec.label, dec.units) == ("a label", "mdyn/angstrom")
    assert dec.data == Decimal("4.4")
    # The fixture passes numeric=False; the flag is expected to be reset to True
    # because the data itself is numeric.
    assert dec.numeric is True
33
34
def test_creation_nonnum(dataset):
    """A string datum built with ``numeric=False`` keeps its fields and stays non-numeric."""
    text_datum = dataset["string"]

    expected_fields = {"label": "ze lbl", "units": "ze unit", "data": "ze data"}
    for attr_name, attr_value in expected_fields.items():
        assert getattr(text_datum, attr_name) == attr_value
    assert text_datum.numeric is False
42
43
def test_creation_error():
    """Non-numeric data without ``numeric=False`` must be rejected at validation.

    The same string payload is accepted when ``numeric=False`` is passed (see the
    "string" fixture entry), so the failure here comes from the numeric check.
    """
    # Pin the validator's message so an unrelated ValidationError can't satisfy the test.
    with pytest.raises(pydantic.ValidationError, match="Datum data should be float"):
        qcel.Datum("ze lbl", "ze unit", "ze data")
49
50
@pytest.mark.parametrize(
    "inp,expected",
    [
        (("decimal", None), 4.4),
        (("decimal", "N/m"), 440),
        (
            ("decimal", "hartree/bohr/bohr"),
            0.282614141011 if qcel.constants.name == "CODATA2014" else 0.28261413658,
        ),
        (("ndarray", "1/m"), np.arange(4, dtype=float) * 400 / 3),
    ],
)
def test_units(dataset, inp, expected):
    """Convert a fixture datum to the requested units and compare to a reference.

    ``inp`` is a ``(fixture key, target units)`` pair; a target of ``None`` asks
    ``to_units`` for the datum's own units.
    """
    key, target_units = inp
    converted = dataset[key].to_units(target_units)
    assert compare_values(expected, converted, atol=1.0e-9)
62
63
def test_printing(dataset):
    """Check ``Datum.__str__`` output for the Decimal datum, including a label line."""
    datum1 = dataset["decimal"]
    expected = """----------------------------------------
             Datum a label
                 Pytest
----------------------------------------
Data:     4.4
Units:    [mdyn/angstrom]
doi:      10.1000/182
Comment:  force constant
Glossary:
----------------------------------------"""

    rendered = datum1.__str__(label="Pytest")
    # Compare whitespace-separated tokens so centering/padding differences are ignored.
    # Compare the whole token lists: zip() would stop at the shorter sequence and let
    # missing or extra output lines (even an empty rendering) pass unnoticed.
    assert expected.split() == rendered.split()
80
81
def test_mass_printing_blank():
    """An empty variable mapping prints only the header and a ``(none)`` marker."""
    expected_lines = [
        "",
        "  Variable Map:",
        "  ----------------------------------------------------------------------------",
        "  (none)",
    ]
    assert qcel.datum.print_variables({}) == "\n".join(expected_lines)
89
90
def test_mass_printing(dataset):
    """Print every fixture datum as a variable map and compare the exact text."""
    expected_table = """
  Variable Map:
  ----------------------------------------------------------------------------
  "decimal" =>                    4.4 [mdyn/angstrom]
  "float"   =>         4.400000000000 [kg]
  "lststr"  =>       ['V', 'R', None] [ze unit]
  "ndarray" =>                        [cm^-1]
        [0.         1.33333333 2.66666667 4.        ]
  "string"  =>                ze data [ze unit]
"""

    rendered_table = qcel.datum.print_variables(dataset)
    assert rendered_table == expected_table
104
105
def test_to_dict(dataset):
    """``Datum.dict()`` of the ndarray fixture matches the expected field mapping."""
    expected = {
        "label": "an array",
        "units": "cm^-1",
        "data": [idx * 4 / 3 for idx in range(4)],
        "comment": "freqs",
        "numeric": True,
    }

    serialized = dataset["ndarray"].dict()
    assert compare_recursive(expected, serialized, 9)
112
113
def test_complex_scalar():
    """A complex scalar is accepted as data and survives ``dict()`` serialization."""
    value = complex(1, 2)
    datum1 = qcel.Datum("complex scalar", "", value)

    assert datum1.label == "complex scalar"
    assert datum1.units == ""
    assert datum1.data.real == 1
    assert datum1.data.imag == 2

    expected = {"label": "complex scalar", "units": "", "data": value, "numeric": True}
    assert compare_recursive(expected, datum1.dict(), 9)
125
126
def test_complex_array():
    """A complex ndarray datum serializes to the matching list of Python complexes."""
    # ``np.complex_`` was removed in NumPy 2.0; ``np.complex128`` is the same dtype
    # and exists on every NumPy version.
    datum1 = qcel.Datum("complex array", "", np.arange(3, dtype=np.complex128) + 1j)
    ans = {
        "label": "complex array",
        "units": "",
        "data": [complex(0, 1), complex(1, 1), complex(2, 1)],
        "numeric": True,
    }

    dicary = datum1.dict()
    assert compare_recursive(ans, dicary, 9)
138
139
def test_qc_units():
    """Atomic-unit dipole/quadrupole data convert to ``D`` and ``D Å`` respectively."""
    debye_per_au = 2.541746451895025916414946904
    # "e a0^2" -> "D Å" needs one more factor than the dipole: presumably the
    # bohr-to-angstrom length (0.52917721067) for the second a0.
    buckingham_per_au = debye_per_au * 0.52917721067

    dipole = qcel.Datum("CC dipole", "e a0", np.array([0, 0, 1 / debye_per_au]))

    quad_flat = np.zeros(9)
    quad_flat[2] = 1 / buckingham_per_au
    quadrupole = qcel.Datum("CC quadrupole", "e a0^2", quad_flat.reshape((3, 3)))

    expected_dipole = np.array([0, 0, 1.0])
    expected_quadrupole = np.array([[0, 0, 1.0], [0, 0, 0], [0, 0, 0]])
    assert compare_values(expected_dipole, dipole.to_units("D"))
    assert compare_values(expected_quadrupole, quadrupole.to_units("D Å"))
149