1# -*- coding: utf-8 -*-
2
3from collections import defaultdict
4
5try:
6    import numpy as np
7except ImportError:
8    np = None
9import pytest
10
11from ..util.testing import requires
12from ..units import (
13    amount,
14    allclose,
15    concatenate,
16    concentration,
17    fold_constants,
18    energy,
19    get_derived_unit,
20    is_unitless,
21    linspace,
22    logspace_from_lin,
23    SI_base_registry,
24    unitless_in_registry,
25    format_string,
26    get_physical_dimensionality,
27    to_unitless,
28    length,
29    magnitude,
30    mass,
31    time,
32    default_unit_in_registry,
33    Backend,
34    latex_of_unit,
35    unit_of,
36    unit_registry_to_human_readable,
37    units_library,
38    volume,
39    simplified,
40    uniform,
41    unit_registry_from_human_readable,
42    _sum,
43    UncertainQuantity,
44    compare_equality,
45    default_units as u,
46    patched_numpy as pnp,
47    default_constants as dc,
48)
49
50
def test_dimensionality():
    """Composite dimensionalities equal sums of scaled base dimensionalities."""
    expectations = [
        (mass + 2 * length - 2 * time, energy),  # mass * length**2 / time**2
        (amount - 3 * length, concentration),  # amount / length**3
        (3 * length, volume),  # length**3
    ]
    for composed, named in expectations:
        assert composed == named
55
56
@requires(units_library)
def test_default_units():
    """Every unit name used throughout the test-suite exists in ``default_units``.

    Each name is listed once; ``getattr`` raises ``AttributeError`` (naming the
    missing unit) if ``default_units`` does not provide it.
    """
    names = (
        "metre",
        "second",
        "hour",
        "decimetre",
        "mole",
        "kilogram",
        "ampere",
        "kelvin",
        "candela",
        "molar",
        "per100eV",
        "joule",
        "gray",
        "eV",
        "MeV",
        "centimetre",
        "micrometre",
        "nanometre",
        "gram",
        "perMolar_perSecond",
        "umol",
        "umol_per_J",
    )
    for name in names:
        assert getattr(u, name) is not None, name
86
87
@requires(units_library)
def test_allclose():
    """Exercise ``allclose`` on plain numbers, quantities, arrays and sequences."""
    # Plain numbers.
    assert allclose(42, 42)
    assert not allclose(42, 43)

    # Quantities are compared after unit conversion...
    assert allclose(42 * u.meter, 0.042 * u.km)
    # ...but a plain number never matches a dimensioned quantity.
    assert not allclose(42, 42 * u.meter)
    assert not allclose(42, 43 * u.meter)
    assert not allclose(42 * u.meter, 42)

    # Arrays expressed in different (compatible) time units.
    in_seconds = np.linspace(2, 3) * u.second
    in_hours = np.linspace(2 / 3600.0, 3 / 3600.0) * u.hour
    assert allclose(in_seconds, in_hours)

    # Element-wise comparison of sequences with mixed units.
    lhs = [3600 * u.second, 2 * u.metre / u.hour]
    rhs = [1 * u.hour, 2 / 3600 * u.metre / u.second]
    assert allclose(lhs, rhs)

    # 2-D quantity arrays that differ in some entries.
    first = [[3000, 4000], [3000, 4000]] * u.mol / u.metre ** 3
    second = [[3000, 4000], [436.2, 5281.89]] * u.mol / u.metre ** 3
    assert not allclose(first, second)

    # Zero quantities.
    assert allclose(0 * u.second, 0 * u.second)

    # NOTE: comparing a scalar against a sequence (broadcasting), e.g.
    # ``allclose(2, [2, 2])`` or ``allclose(2 * u.second, [2, 2] * u.second)``,
    # may be allowed in future; add assertions here if it is.
121
122
@requires(units_library)
def test_is_unitless():
    """``is_unitless`` inspects scalars, quantities, dicts and sequences."""
    unitless_cases = (
        1,
        {"a": 1, "b": 2.0},
        7 * u.molar / u.mole * u.dm3,  # the units cancel out
        [2, 3, 4],
        u.dimensionless,  # regression: this used to raise RecursionError
    )
    for case in unitless_cases:
        assert is_unitless(case), case

    dimensioned_cases = (
        1 * u.second,
        {"a": 2, "b": 5.0 * u.second, "c": 3},
        [2 * u.m, 3 * u.m],
        [3, 4 * u.m],
    )
    for case in dimensioned_cases:
        assert not is_unitless(case), case
134
135
@requires(units_library)
def test_unit_of():
    """``unit_of`` returns the unit of quantities and of homogeneous dicts."""
    speed_unit = unit_of(0.1 * u.metre / u.second)
    assert compare_equality(speed_unit, u.metre / u.second)
    assert not compare_equality(speed_unit, u.kilometre / u.second)

    # A plain number has unit 1.
    assert compare_equality(unit_of(7), 1)

    # ``simplified=True`` reduces the gray to J/kg in base units.
    assert unit_of(u.gray).dimensionality == u.gray.dimensionality
    joule_per_kg = (u.joule / u.kg).simplified.dimensionality
    assert unit_of(u.gray, simplified=True).dimensionality == joule_per_kg

    # Dict values sharing one unit report that unit...
    molar_unit = unit_of(dict(foo=3 * u.molar, bar=2 * u.molar))
    assert compare_equality(molar_unit, u.molar)
    assert not compare_equality(molar_unit, u.second)
    # ...which is distinct from mol/m**3...
    assert not compare_equality(molar_unit, u.mol / u.metre ** 3)
    # ...whereas values with mismatching units raise.
    with pytest.raises(Exception):
        unit_of(dict(foo=3 * u.molar, bar=2 * u.second))
156
157
@requires(units_library)
def test_to_unitless():
    """``to_unitless`` rescales to a target unit and returns plain numbers.

    Covers lists and arrays of quantities, bare numbers used as the target
    "unit", ``UncertainQuantity`` input, object arrays, and the ``ValueError``
    raised when unitless data is rescaled to a dimensioned unit.
    """
    dm = u.decimetre

    # List of scalar quantities.
    vals = [1.0 * dm, 2.0 * dm]
    result = to_unitless(vals, u.metre)
    assert result[0] == 0.1
    assert result[1] == 0.2

    # Unitless data cannot be expressed in metres.
    with pytest.raises(ValueError):
        to_unitless([42, 43], u.metre)

    with pytest.raises(ValueError):
        to_unitless(np.array([42, 43]), u.metre)

    # Quantity array.
    vals = [1.0, 2.0] * dm
    result = to_unitless(vals, u.metre)
    assert result[0] == 0.1
    assert result[1] == 0.2

    # The target unit may carry a scale factor (here 1000 m)...
    length_unit = 1000 * u.metre
    result = to_unitless(1.0 * u.metre, length_unit)
    assert abs(result - 1e-3) < 1e-12

    # ...or be a bare number (here a "nano" prefix).
    amount_unit = 1e-9
    assert abs(to_unitless(1.0, amount_unit) - 1e9) < 1e-6

    rate = 3 / (u.second * u.molar)
    rate_si = to_unitless(rate, u.metre ** 3 / u.mole / u.second)
    assert abs(rate_si - 3e-3) < 1e-12
    assert abs(to_unitless(2 * u.dm3, u.cm3) - 2000) < 1e-12
    assert abs(to_unitless(2 * u.m3, u.dm3) - 2000) < 1e-12
    # abs() is essential here: without it any result <= 2000 would pass.
    uq_volume = UncertainQuantity(2, u.dm3, 0.3)
    assert abs(float(to_unitless(uq_volume, u.cm3)) - 2000) < 1e-12

    # Radiolytic yields given per 100 eV, including a negated quantity.
    g_unit = get_derived_unit(SI_base_registry, "radiolytic_yield")
    g1 = UncertainQuantity(4.46, u.per100eV, 0)
    assert abs(to_unitless(g1, g_unit) - 4.46 * 1.036e-7) < 1e-9
    g2 = UncertainQuantity(-4.46, u.per100eV, 0)
    assert abs(to_unitless(-g2, g_unit) - 4.46 * 1.036e-7) < 1e-9

    # Object array holding quantities.
    vals = np.array([1.0 * dm, 2.0 * dm], dtype=object)
    result = to_unitless(vals, u.metre)
    assert result[0] == 0.1
    assert result[1] == 0.2

    one_billionth_molar_in_nanomolar = to_unitless(1e-9 * u.molar, u.nanomolar)
    assert one_billionth_molar_in_nanomolar == 1
206
207
@requires(units_library)
def test_UncertainQuantity():
    """Indexing, negation and scaling of an ``UncertainQuantity`` array."""
    uq = UncertainQuantity([1, 2], u.m, [0.1, 0.2])
    negated = -uq
    flipped = uq * -1

    assert uq[1] == [2.0] * u.m
    assert negated[0] == [-1.0] * u.m
    # Negation leaves the uncertainty positive.
    assert negated.uncertainty[0] == [0.1] * u.m
    # Unary minus agrees with multiplication by -1.
    assert negated[0] == flipped[0]
    assert negated.uncertainty[0] == flipped.uncertainty[0]
    assert allclose(uq, [1, 2] * u.m)
217
218
@requires(units_library, "sympy")
def test_to_unitless__sympy():
    """Sympy constants pass through ``to_unitless`` when no unit is given."""
    import sympy

    assert sympy.cos(to_unitless(sympy.pi)) == -1
    # Asking for a dimensioned target unit raises AttributeError.
    with pytest.raises(AttributeError):
        to_unitless(sympy.pi, u.second)
226
227
@requires(units_library)
def test_linspace():
    """``linspace`` accepts quantity end points; the first sample is the start."""
    samples = linspace(2 * u.second, 3 * u.second)
    first_in_hours = to_unitless(samples[0], u.hour)
    assert abs(first_in_hours - 2 / 3600.0) < 1e-15
232
233
@requires(units_library)
def test_logspace_from_lin():
    """``logspace_from_lin`` reproduces both quantity end points."""
    samples = logspace_from_lin(2 * u.second, 3 * u.second)
    for idx, seconds in ((0, 2), (-1, 3)):
        assert abs(to_unitless(samples[idx], u.hour) - seconds / 3600.0) < 1e-15
239
240
@requires(units_library)
def test_get_derived_unit():
    """Derived units follow the (possibly rescaled) base units of a registry."""
    # Length rescaled to a decimetre -> concentration unit is mol/dm**3.
    dm_registry = SI_base_registry.copy()
    dm_registry["length"] = 1e-1 * dm_registry["length"]
    mol_per_dm3 = u.mole / (u.decimetre ** 3)
    conc_unit = get_derived_unit(dm_registry, "concentration")
    assert abs(conc_unit - 1 * mol_per_dm3) < 1e-12 * mol_per_dm3

    # Plain-number registry (every entry 1) with a "nano" amount.
    nano_registry = defaultdict(lambda: 1)
    nano_registry["amount"] = 1e-9
    nano_conc = get_derived_unit(nano_registry, "concentration")
    assert abs(to_unitless(1.0, nano_conc) - 1e9) < 1e-6
254
255
@requires(units_library)
def test_unit_registry_to_human_readable():
    """Registries serialize to ``{dimension: (factor, unit_symbol)}`` dicts.

    The output is meant to be JSON serializable more than human readable.
    """
    all_ones = defaultdict(lambda: 1)
    assert unit_registry_to_human_readable(all_ones) == {
        key: (1, 1) for key in SI_base_registry
    }

    registry = {
        "length": 1e3 * u.metre,
        "mass": 1e-2 * u.kilogram,
        "time": 1e4 * u.second,
        "current": 1e-1 * u.ampere,
        "temperature": 1e1 * u.kelvin,
        "luminous_intensity": 1e-3 * u.candela,
        "amount": 1e4 * u.mole,
    }
    expected = {
        "length": (1e3, "m"),
        "mass": (1e-2, "kg"),
        "time": (1e4, "s"),
        "current": (1e-1, "A"),
        "temperature": (1e1, "K"),
        "luminous_intensity": (1e-3, "cd"),
        "amount": (1e4, "mol"),
    }
    readable = unit_registry_to_human_readable(registry)
    assert readable == expected
    # A single wrong factor must break equality.
    assert readable != dict(expected, length=(1e2, "m"))
291
292
@requires(units_library)
def test_unit_registry_from_human_readable():
    """``unit_registry_from_human_readable`` inverts the serialization."""
    # Round trip of a plain-number registry.
    plain_hr = unit_registry_to_human_readable(defaultdict(lambda: 1))
    assert plain_hr == {key: (1, 1) for key in SI_base_registry}
    assert unit_registry_from_human_readable(plain_hr) == {
        key: 1 for key in SI_base_registry
    }

    # Round trip of the SI base registry.
    si_literal = {
        "length": (1.0, "m"),
        "mass": (1.0, "kg"),
        "time": (1.0, "s"),
        "current": (1.0, "A"),
        "temperature": (1.0, "K"),
        "luminous_intensity": (1.0, "cd"),
        "amount": (1.0, "mol"),
    }
    si_hr = unit_registry_to_human_readable(SI_base_registry)
    assert si_hr == si_literal
    assert unit_registry_from_human_readable(si_hr) == SI_base_registry

    # The literal SI form yields the named base units.
    assert unit_registry_from_human_readable(dict(si_literal)) == {
        "length": u.metre,
        "mass": u.kilogram,
        "time": u.second,
        "current": u.ampere,
        "temperature": u.kelvin,
        "luminous_intensity": u.candela,
        "amount": u.mole,
    }

    # Scaled factors are applied to the corresponding units.
    scaled = unit_registry_from_human_readable(
        {
            "length": (1e3, "m"),
            "mass": (1e-2, "kg"),
            "time": (1e4, "s"),
            "current": (1e-1, "A"),
            "temperature": (1e1, "K"),
            "luminous_intensity": (1e-3, "cd"),
            "amount": (1e4, "mol"),
        }
    )
    assert scaled == {
        "length": 1e3 * u.metre,
        "mass": 1e-2 * u.kilogram,
        "time": 1e4 * u.second,
        "current": 1e-1 * u.ampere,
        "temperature": 1e1 * u.kelvin,
        "luminous_intensity": 1e-3 * u.candela,
        "amount": 1e4 * u.mole,
    }
    assert scaled != {
        "length": 1e2 * u.metre,
        "mass": 1e-3 * u.kilogram,
        "time": 1e2 * u.second,
        "current": 1e-2 * u.ampere,
        "temperature": 1e0 * u.kelvin,
        "luminous_intensity": 1e-2 * u.candela,
        "amount": 1e3 * u.mole,
    }
364
365
@requires(units_library)
def test_unitless_in_registry():
    """Quantities become plain numbers expressed in the registry's units."""
    per100eV_in_mol_per_J = 1.0364268834527753e-07
    ref = 3 * per100eV_in_mol_per_J

    single = unitless_in_registry(3 * u.per100eV, SI_base_registry)
    assert abs(magnitude(single) - ref) < 1e-14

    several = unitless_in_registry(
        [3 * u.per100eV, 5 * u.mol / u.J], SI_base_registry
    )
    assert allclose(several, [ref, 5], rtol=1e-6)
373
374
@requires(units_library)
def test_compare_equality():
    """``compare_equality`` respects units and handles ``None`` entries."""
    equal_pairs = [
        (3 * u.m, 3 * u.m),
        (3 * u.m, 3e-3 * u.km),
        (3e3 * u.mm, 3 * u.m),
        ([3, None], [3, None]),
        # NOTE(review): a quantity inside a list compares equal to a bare number.
        ([3 * u.m, None], [3, None]),
    ]
    for lhs, rhs in equal_pairs:
        assert compare_equality(lhs, rhs)

    unequal_pairs = [
        (3 * u.m, 2 * u.m),
        (3 * u.m, 3 * u.s),
        (3 * u.m, 3 * u.m ** 2),
        (3 * u.m, np.array(3)),
        (np.array(3), 3 * u.m),
        ([3, None, 3], [3, None, None]),
        ([None, None, 3], [None, None, 2]),
        ([3 * u.m, None], [3 * u.km, None]),
    ]
    for lhs, rhs in unequal_pairs:
        assert not compare_equality(lhs, rhs)
390
391
@requires(units_library)
def test_get_physical_dimensionality():
    """Dimensionality is reported as a dict of base-dimension exponents."""
    amount_only = {"amount": 1}
    for value in (3 * u.mole, [3 * u.mole]):
        assert get_physical_dimensionality(value) == amount_only
    assert get_physical_dimensionality(42) == {}
397
398
@requires(units_library)
def test_default_unit_in_registry():
    """The registry's default unit for a quantity has magnitude one."""
    conc_unit = default_unit_in_registry(3 * u.molar, SI_base_registry)
    assert magnitude(conc_unit) == 1
    assert conc_unit == u.mole / u.metre ** 3

    # Plain numbers map to 1 regardless of int/float.
    for number in (3, 3.0):
        assert default_unit_in_registry(number, SI_base_registry) == 1
407
408
@requires(units_library)
def test__sum():
    """``_sum`` adds quantities given in different units."""
    # The builtin sum() does not work for this input.
    total = _sum([0.1 * u.metre, 1 * u.decimetre])
    assert (total - 2 * u.decimetre) / u.metre == 0
413
414
@requires(units_library)
def test_Backend():
    """The default ``Backend`` accepts dimensionless but rejects dimensioned input."""
    backend = Backend()
    # exp of a length is rejected.
    with pytest.raises(ValueError):
        backend.exp(-3 * u.metre)
    # A dimensionless ratio behaves like the equivalent plain number.
    ratio = 1234 * u.metre / u.kilometre
    assert abs(backend.exp(ratio) - backend.exp(1.234)) < 1e-14
421
422
@requires(units_library, "numpy")
def test_Backend__numpy():
    """A numpy-backed ``Backend`` treats dimensionless quantities as numbers.

    Also checks that attributes numpy does not have (``Piecewise``) raise
    ``AttributeError``.
    """
    import numpy as np

    b = Backend(np)
    # The comparison must be asserted; a bare expression checks nothing.
    total = b.sum([1000 * u.metre / u.kilometre, 1], axis=0)
    assert abs(total - 2.0) < 1e-12

    with pytest.raises(AttributeError):
        b.Piecewise
432
433
@requires("sympy")
def test_Backend__sympy():
    """A ``Backend`` built from ``"sympy"`` exposes sympy's functions.

    Attributes sympy does not have (``min``) raise ``AttributeError``.
    """
    b = Backend("sympy")
    # The comparison must be asserted; a bare expression checks nothing.
    assert b.sin(b.pi) == 0

    with pytest.raises(AttributeError):
        b.min
441
442
@requires(units_library)
def test_format_string():
    """``format_string`` splits a quantity into (magnitude, unit) strings."""
    areal_density = 3 * u.gram / u.metre ** 2
    assert format_string(areal_density) == ("3", "g/m**2")
    expected_tex = ("3", r"\mathrm{\frac{g}{m^{2}}}")
    assert format_string(areal_density, tex=True) == expected_tex
450
451
@requires(units_library)
def test_joule_html():
    """The simplified joule dimensionality renders as HTML."""
    expected = "kg&sdot;m<sup>2</sup>/s<sup>2</sup>"
    simplified_joule = u.J.dimensionality.simplified
    assert simplified_joule.html == expected
457
458
@requires(units_library)
def test_latex_of_unit():
    """``latex_of_unit`` renders a compound unit as LaTeX."""
    expected = r"\mathrm{\frac{g}{m^{2}}}"
    assert latex_of_unit(u.gram / u.metre ** 2) == expected
462
463
@requires(units_library)
def test_concatenate():
    """``concatenate`` joins quantity arrays given in different units."""
    metres = [1, 2] * u.metre
    millimetres = [2, 3] * u.mm
    expected = [1, 2, 2e-3, 3e-3] * u.metre
    assert allclose(concatenate((metres, millimetres)), expected)
470
471
@requires(units_library)
def test_pow0():
    """Power 0 of a quantity array gives plain ones; power 2 squares the unit."""
    lengths = [1, 2] * u.metre
    assert allclose(lengths ** 0, [1, 1])
    assert allclose(lengths ** 2, [1, 4] * u.m ** 2)
480
481
@requires(units_library)
def test_patched_numpy():
    """``patched_numpy.exp`` handles dimensionless quantities and plain input."""
    # see https://github.com/python-quantities/python-quantities/issues/152
    ratio = 3 * u.joule / (2 * u.cal)
    assert allclose(pnp.exp(ratio), 1.43119335, rtol=1e-5)

    # Unitless input must give exactly what numpy.exp gives.
    plain_inputs = ([1, 2], [[1], [2]], [1], 2)
    for value in plain_inputs:
        assert np.all(pnp.exp(value) == np.exp(value))
488
489
@requires(units_library)
def test_tile():
    """``patched_numpy.tile`` repeats a list of quantities with mixed units."""
    pair = [2 * u.m, 3 * u.km]
    tiled = pnp.tile(pair, 2)
    expected = [2 * u.m, 3000 * u.m, 2e-3 * u.km, 3 * u.km]
    assert allclose(tiled, expected)
494
495
@requires(units_library)
def test_simplified():
    """``simplified`` reduces a constant to base units and passes floats through."""
    gas_constant = simplified(dc.molar_gas_constant)
    reference = 8.314 * u.J / u.mol / u.K
    assert allclose(gas_constant, reference, rtol=2e-3)
    assert simplified(2.0) == 2.0
502
503
@requires(units_library)
def test_polyfit_polyval():
    """``polyfit``/``polyval`` from ``patched_numpy`` work with and without units."""
    # Fit y = x**2 to plain numbers.
    coeffs = pnp.polyfit([0, 1, 2], [0, 1, 4], 2)
    assert allclose(coeffs, [1, 0, 0], atol=1e-14)
    assert allclose(pnp.polyval(coeffs, 3), 9)
    assert allclose(pnp.polyval(coeffs, [4, 5]), [16, 25])

    # Same fit with time on x and length on y; each coefficient has its own unit.
    coeffs_q = pnp.polyfit([0, 1, 2] * u.s, [0, 1, 4] * u.m, 2)
    expected = [1 * u.m / u.s ** 2, 0 * u.m / u.s, 0 * u.m]
    tolerances = [0 * u.m / u.s ** 2, 1e-15 * u.m / u.s, 1e-15 * u.m]
    for coeff, ref, tol in zip(coeffs_q, expected, tolerances):
        assert allclose(coeff, ref, atol=tol)
    assert allclose(pnp.polyval(coeffs_q, 3 * u.s), 9 * u.m)
    assert allclose(pnp.polyval(coeffs_q, [4, 5] * u.s), [16, 25] * u.m)
520
521
@requires(units_library)
def test_uniform():
    """``uniform`` expresses all entries in one common unit.

    Both outcomes (everything in metres or everything in kilometres) are
    accepted.
    """
    quantities = [3 * u.km, 200 * u.m]
    candidates = [np.array([3000, 200]), np.array([3, 0.2])]

    def _matches_a_candidate(values):
        # Compare row-wise against every candidate; one full match suffices.
        magnitudes = magnitude(uniform(values))
        assert np.any(np.all(magnitudes == candidates, axis=1))

    for container in (quantities, tuple(quantities)):
        _matches_a_candidate(container)

    keys = ["foo", "bar"]
    as_dict = dict(zip(keys, quantities))
    assert magnitude(uniform(as_dict)) in [
        dict(zip(keys, cand)) for cand in candidates
    ]
536
537
@requires(units_library)
def test_fold_constants():
    """``fold_constants`` turns ``default_constants.pi`` into the value of pi."""
    folded_pi = fold_constants(dc.pi)
    assert abs(folded_pi - np.pi) < 1e-15
541
542
# units_library is required too: the test builds quantities via ``u``.
@requires(units_library, "numpy")
def test_to_unitless___0D_array_with_object():
    """A 0-d object array wrapping a ``Constant`` can be rescaled and stripped."""
    from ..util._expr import Constant

    pi = np.array(Constant(np.pi))
    thousand_pi = to_unitless(pi * u.metre, u.millimeter)
    assert get_physical_dimensionality(thousand_pi) == {}
    # pi m == 1000 * pi mm; pi is computed independently as 4 * arctan(1).
    assert abs(magnitude(thousand_pi) - np.arctan(1) * 4e3) < 1e-12
553