from decimal import Decimal

import numpy as np
import pydantic
import pytest

import qcelemental as qcel
from qcelemental.testing import compare_recursive, compare_values


def _line_tokens(text):
    """Split *text* into lines, and each line into whitespace-separated tokens.

    Comparing these nested lists checks line structure, ordering and content
    exactly, while ignoring the column padding used to align printed output.
    """
    return [line.split() for line in text.splitlines()]


@pytest.fixture
def dataset():
    """Return representative ``Datum`` objects keyed by a short type tag.

    Covers a ``Decimal`` scalar, a float ``ndarray``, a plain float, a string,
    and a list of strings containing ``None``.
    """
    datums = {
        "decimal": qcel.Datum(
            "a label", "mdyn/angstrom", Decimal("4.4"), comment="force constant", doi="10.1000/182", numeric=False
        ),
        "ndarray": qcel.Datum("an array", "cm^-1", np.arange(4, dtype=float) * 4 / 3, comment="freqs"),
        "float": qcel.Datum("a float", "kg", 4.4, doi="10.1000/182"),
        "string": qcel.Datum("ze lbl", "ze unit", "ze data", numeric=False),
        "lststr": qcel.Datum("ze lbl", "ze unit", ["V", "R", None], numeric=False),
    }

    return datums


def test_creation(dataset):
    """A Decimal-valued Datum keeps its fields and ends up flagged numeric."""
    datum1 = dataset["decimal"]

    assert datum1.label == "a label"
    assert datum1.units == "mdyn/angstrom"
    assert datum1.data == Decimal("4.4")
    assert datum1.numeric is True  # checking that numeric got properly reset from input


def test_creation_nonnum(dataset):
    """A string-valued Datum built with numeric=False stays non-numeric."""
    datum1 = dataset["string"]

    assert datum1.label == "ze lbl"
    assert datum1.units == "ze unit"
    assert datum1.data == "ze data"
    assert datum1.numeric is False


def test_creation_error():
    """String data without numeric=False is rejected by validation."""
    with pytest.raises(pydantic.ValidationError):
        qcel.Datum("ze lbl", "ze unit", "ze data")


@pytest.mark.parametrize(
    "inp,expected",
    [
        (("decimal", None), 4.4),
        (("decimal", "N/m"), 440),
        (("decimal", "hartree/bohr/bohr"), 0.282614141011 if qcel.constants.name == "CODATA2014" else 0.28261413658),
        (("ndarray", "1/m"), np.arange(4, dtype=float) * 400 / 3),
    ],
)
def test_units(dataset, inp, expected):
    """``to_units`` converts scalar and array data (``None`` keeps native units)."""
    assert compare_values(expected, dataset[inp[0]].to_units(inp[1]), atol=1.0e-9)


def test_printing(dataset):
    """``Datum.__str__`` with a label prints every field, line by line."""
    datum1 = dataset["decimal"]
    str1 = """----------------------------------------
             Datum a label
                 Pytest
----------------------------------------
Data:     4.4
Units:    [mdyn/angstrom]
doi:      10.1000/182
Comment:  force constant
Glossary:
----------------------------------------"""

    # Ignore alignment padding, but (unlike a zip-based check) fail on any
    # missing, extra or reordered token or line.
    str2 = datum1.__str__(label="Pytest")
    assert _line_tokens(str1) == _line_tokens(str2)


def test_mass_printing_blank():
    """``print_variables`` on an empty mapping reports ``(none)``."""
    pvnone = """
  Variable Map:
  ----------------------------------------------------------------------------
  (none)"""

    assert _line_tokens(pvnone) == _line_tokens(qcel.datum.print_variables({}))


def test_mass_printing(dataset):
    """``print_variables`` lists every Datum sorted by key, arrays on their own line."""
    str1 = """
  Variable Map:
  ----------------------------------------------------------------------------
  "decimal"  =>                4.4 [mdyn/angstrom]
  "float"    =>     4.400000000000 [kg]
  "lststr"   =>   ['V', 'R', None] [ze unit]
  "ndarray"  =>                    [cm^-1]
        [0.         1.33333333 2.66666667 4.        ]
  "string"   =>            ze data [ze unit]
"""

    assert _line_tokens(str1) == _line_tokens(qcel.datum.print_variables(dataset))


def test_to_dict(dataset):
    """``dict()`` serializes ndarray data as a plain list."""
    listans = [i * 4 / 3 for i in range(4)]
    ans = {"label": "an array", "units": "cm^-1", "data": listans, "comment": "freqs", "numeric": True}

    dicary = dataset["ndarray"].dict()
    assert compare_recursive(ans, dicary, 9)


def test_complex_scalar():
    """A complex scalar is accepted as numeric data and round-trips through ``dict()``."""
    datum1 = qcel.Datum("complex scalar", "", complex(1, 2))
    ans = {"label": "complex scalar", "units": "", "data": complex(1, 2), "numeric": True}

    assert datum1.label == "complex scalar"
    assert datum1.units == ""
    assert datum1.data.real == 1
    assert datum1.data.imag == 2

    dicary = datum1.dict()
    assert compare_recursive(ans, dicary, 9)


def test_complex_array():
    """A complex ndarray serializes to a list of Python complex numbers."""
    # np.complex128 rather than the np.complex_ alias, which NumPy 2.0 removed.
    datum1 = qcel.Datum("complex array", "", np.arange(3, dtype=np.complex128) + 1j)
    ans = {
        "label": "complex array",
        "units": "",
        "data": [complex(0, 1), complex(1, 1), complex(2, 1)],
        "numeric": True,
    }

    dicary = datum1.dict()
    assert compare_recursive(ans, dicary, 9)


def test_qc_units():
    """Atomic-unit dipole/quadrupole data converts to Debye and Debye*Angstrom."""
    au2D = 2.541746451895025916414946904
    au2Q = au2D * 0.52917721067  # extra bohr -> angstrom factor for the quadrupole

    onedebye = qcel.Datum("CC dipole", "e a0", np.array([0, 0, 1 / au2D]))
    onebuckingham = qcel.Datum("CC quadrupole", "e a0^2", np.array([0, 0, 1 / au2Q, 0, 0, 0, 0, 0, 0]).reshape((3, 3)))

    assert compare_values(np.array([0, 0, 1.0]), onedebye.to_units("D"))
    assert compare_values(np.array([[0, 0, 1.0], [0, 0, 0], [0, 0, 0]]), onebuckingham.to_units("D Å"))