1""" Test unit conversion functions used on input and output of functions.
2
3    fixme: We need significant work on scalars.
4"""
5
6# Standard Library imports
7from __future__ import absolute_import
8import unittest
9
10# Numeric library imports
11from numpy import array, all, allclose, ndarray
12
13# Enthought library imports
14from scimath.units.unit import InvalidConversion
15from scimath.units.length import feet, meters
16from scimath.units.time import second
17
18# Numerical modeling library imports
19from scimath.units.api import UnitArray, UnitScalar
20from scimath.units.unit_manipulation import \
21    convert_units, set_units, have_some_units, strip_units
22
23
24class ConvertUnitsTestCase(unittest.TestCase):
25    """ ConvertUnits should pretty much leave anything without units alone
26        and pass them through silently.  UnitArrays do get converted,
27        and so should scalars with units (although we haven't really dealt with
28        those).
29    """
30
31    ##########################################################################
32    # ConvertUnitsTestCase interface.
33    ##########################################################################
34
35    def test_single_float(self):
36        """ Does it pass a single value through correctly?
37        """
38        units = [None]
39        result = convert_units(units, 1.0)
40        self.assertEqual(1.0, result)
41
42    def test_two_float(self):
43        """ Does it pass a two values through correctly?
44        """
45        units = [None, None]
46        result = convert_units(units, 1.0, 2.0)
47        self.assertEqual([1.0, 2.0], result)
48
49    def test_mismatch_raises_error(self):
50        """ Is an exception raised if there aren't enough units specified?
51        """
52        self.assertRaises(ValueError, convert_units, ([None], 1.0, 2.0))
53
54    def test_one_array(self):
55        """ Does it pass a an array through correctly?
56        """
57        units = [None]
58        a = array((1, 2, 3))
59        result = convert_units(units, a)
60        self.assertTrue(all(a == result))
61
62    def test_two_arrays(self):
63        """ Does it pass a two arrays through correctly?
64        """
65        units = [None, None]
66        a = array((1, 2, 3))
67        b = array((3, 4, 5))
68        aa, bb = convert_units(units, a, b)
69        self.assertTrue(all(a == aa))
70        self.assertTrue(all(b == bb))
71
72    def test_convert_array_with_units(self):
73        """ Does it add units to an array correctly?
74
75            fixme: This may be exactly what we don't want to happen!
76        """
77        units = [feet]
78        a = array((1, 2, 3))
79        aa = convert_units(units, a)
80        self.assertTrue(all(a == aa))
81        self.assertTrue(isinstance(aa, ndarray))
82
83    def test_convert_unit_array(self):
84        """ Does it convert an array correctly?
85        """
86        units = [feet]
87        a = UnitArray((1, 2, 3), units=meters)
88        aa = convert_units(units, a)
89        self.assertTrue(allclose(a, aa.as_units(meters)))
90        # fixme: This actually may be something we don't want.  For speed,
91        #        if this were just a standard array, we would be better off.
92        self.assertEqual(aa.units, feet)
93
94    def test_convert_unit_scalar(self):
95        """ Does it convert a scalar correctly?
96        """
97        units = [feet]
98        a = UnitScalar(3., units=meters)
99        aa = convert_units(units, a)
100        self.assertTrue(allclose(a, aa.as_units(meters)))
101        self.assertEqual(aa.units, feet)
102
103    def test_incompatible_array_units_raise_exception(self):
104        """ Does a units mismatch raise an exception?
105
106            fixme: Do we want this configurable?
107        """
108        units = [second]
109        a = UnitArray((1, 2, 3), units=meters)
110        self.assertRaises(InvalidConversion, convert_units, units, a)
111
112    def test_incompatible_scalar_units_raise_exception(self):
113        """ Does a units mismatch raise an exception?
114
115            fixme: Do we want this configurable?
116        """
117        units = [second]
118        a = UnitScalar(3., units=meters)
119        self.assertRaises(InvalidConversion, convert_units, units, a)
120
121    def test_dont_convert_unit_array(self):
122        """ Does it return the same object if units are the same?
123
124            Note: This isn't required for accuracy, but it is a good
125                  optimization.
126        """
127        units = [feet]
128        a = UnitArray((1, 2, 3), units=feet)
129        aa = convert_units(units, a)
130        self.assertTrue(id(a), id(aa))
131
132    def test_dont_convert_unit_scalar(self):
133        """ Does it return the same object if units are the same?
134
135            Note: This isn't required for accuracy, but it is a good
136                  optimization.
137        """
138        units = [feet]
139        a = UnitScalar(3., units=feet)
140        aa = convert_units(units, a)
141        self.assertTrue(id(a), id(aa))
142
143    def test_convert_different_args(self):
144        """ Does it handle multiple different args correctly?
145        """
146        units = [feet, meters, None, feet]
147        a = UnitArray((1, 2, 3), units=meters)
148        b = array((2, 3, 4))
149        c = 1
150        d = UnitScalar(3., units=meters)
151        aa, bb, cc, dd = convert_units(units, a, b, c, d)
152        self.assertTrue(allclose(a, aa.as_units(meters)))
153        self.assertTrue(allclose(b, bb))
154        self.assertEqual(c, cc)
155        self.assertTrue(allclose(d, dd.as_units(meters)))
156
157
158class SetUnitsTestCase(unittest.TestCase):
159
160    ##########################################################################
161    # TestCase interface.
162    ##########################################################################
163
164    def setUp(self):
165        unittest.TestCase.setUp(self)
166
167    def tearDown(self):
168        unittest.TestCase.tearDown(self)
169
170    ##########################################################################
171    # SetUnitsTestCase interface.
172    ##########################################################################
173
174    def test_single_float(self):
175        """ Does it pass a single value through correctly?
176        """
177        units = [None]
178        result = set_units(units, 1.0)
179        self.assertEqual(1.0, result)
180
181    def test_mismatch_raises_error(self):
182        """ Is an exception raised if there aren't enough units specified?
183        """
184        self.assertRaises(ValueError, convert_units, [None], 1.0, 2.0)
185
186    def test_one_array(self):
187        """ Does it pass a an array through correctly?
188        """
189        units = [None]
190        a = array((1, 2, 3))
191        result = set_units(units, a)
192        self.assertTrue(all(a == result))
193
194    def test_set_scalar_with_units(self):
195        """ Does it add units to a scalar correctly?
196        """
197        units = [feet]
198        x = 3.0
199        xx = set_units(units, x)
200        self.assertEqual(float(xx), x)
201        self.assertEqual(xx.units, feet)
202
203    def test_set_array_with_units(self):
204        """ Does it add units to an array correctly?
205
206            fixme: This may be exactly what we don't want to happen!
207        """
208        units = [feet]
209        a = array((1, 2, 3))
210        aa = set_units(units, a)
211        self.assertTrue(all(a == aa))
212        self.assertEqual(aa.units, feet)
213
214    def test_set_zero_dim_array_with_units(self):
215        """ Does it add units to an array with shape () correctly?
216
217            fixme: This may be exactly what we don't want to happen!
218        """
219        units = [feet]
220        a = array(2)
221        aa = set_units(units, a)
222        self.assertTrue(all(a == aa))
223        self.assertEqual(aa.units, feet)
224        assert isinstance(aa, UnitScalar)
225
226    def test_set_unit_overwrite_unit_scalar(self):
227        """ Does it overwrite units on a UnitScalar correctly?
228        """
229        units = [feet]
230        x = UnitScalar(3., units=meters)
231        xx = set_units(units, x)
232        # FIXME:
233        #     Behaves very stangely (on my machine), somethimes it fails,
234        #     other times it works almost like a random generator.
235        #
236        # We found that set_units(units, x) has a sideffect on x which it
237        # should not have.
238        #
239        #self.assertEqual(x, xx)
240        # print x, x.units
241        self.assertEqual(xx.units, feet)
242
243    def test_set_unit_overwrite_unit_array(self):
244        """ Does it overwrite units on a UnitArray correctly?
245        """
246        units = [feet]
247        a = UnitArray((1, 2, 3), units=meters)
248        aa = set_units(units, a)
249        self.assertTrue(all(a == aa))
250        self.assertEqual(aa.units, feet)
251#
252#    def test_raises_exception(self):
253#        """ Does it return the same object if units are the same?
254#
255#            Note: This isn't required for accuracy, but it is a good
256#                  optimization.
257#        """
258#        units = [feet]
259#        a = UnitArray((1,2,3),units=feet)
260#        aa = convert_units(units, a)
261#        self.assertTrue(id(a),id(aa))
262#
263#    def test_convert_different_args(self):
264#        """ Does it handle multiple different args correctly?
265#        """
266#        units = [feet, meters, None]
267#        a = UnitArray((1,2,3),units=meters)
268#        b = array((2,3,4))
269#        c = 1
270#        aa, bb, cc = convert_units(units, a, b, c)
271#        self.assertTrue(allclose(a,aa.as_units(meters)))
272#        self.assertTrue(allclose(b,bb))
273#        self.assertEqual(c,cc)
274
275
276class HaveSomeUnitsTestCase(unittest.TestCase):
277    """ have_some_units should check its arguments for any
278    UnitArrays/UnitScalars.
279    """
280
281    ##########################################################################
282    # TestCase interface.
283    ##########################################################################
284
285    def setUp(self):
286        # Make some useful data.
287        self.unit_array = UnitArray((1, 2, 3), units=meters)
288        self.unit_scalar = UnitScalar(1, units=meters)
289        self.plain_array = array([1, 2, 3])
290        self.plain_scalar = 1
291        unittest.TestCase.setUp(self)
292
293    def test_finds_one(self):
294        self.assertTrue(have_some_units(self.unit_array))
295        self.assertTrue(have_some_units(self.unit_scalar))
296
297    def test_finds_multiple(self):
298        self.assertTrue(have_some_units(self.unit_array, self.unit_array))
299        self.assertTrue(have_some_units(self.unit_scalar, self.unit_scalar))
300
301    def test_finds_mixed_scalar_array(self):
302        self.assertTrue(have_some_units(self.unit_array, self.unit_scalar))
303
304    def test_does_not_find_plain(self):
305        self.assertFalse(have_some_units(self.plain_array))
306        self.assertFalse(have_some_units(self.plain_scalar))
307
308    def test_does_not_find_mixed_plain(self):
309        self.assertFalse(have_some_units(self.plain_array, self.plain_scalar))
310
311    def test_finds_any_unitted(self):
312        self.assertTrue(
313            have_some_units(
314                self.unit_array,
315                self.plain_array,
316                self.plain_scalar))
317        self.assertTrue(
318            have_some_units(
319                self.plain_array,
320                self.unit_array,
321                self.plain_scalar))
322        self.assertTrue(
323            have_some_units(
324                self.unit_scalar,
325                self.plain_array,
326                self.plain_scalar))
327        self.assertTrue(
328            have_some_units(
329                self.plain_array,
330                self.unit_scalar,
331                self.plain_scalar))
332
333
334class StripUnitsTestCase(unittest.TestCase):
335    """ strip_units should remove units from UnitArrays/UnitScalars.
336    """
337
338    def setUp(self):
339        # Make some useful data.
340        self.unit_array = UnitArray((1, 2, 3), units=meters)
341        self.unit_scalar = UnitScalar(1, units=meters)
342        self.plain_array = array([1, 2, 3])
343        self.plain_scalar = 1
344        unittest.TestCase.setUp(self)
345
346    def test_strip_units_one_arg(self):
347        self.assertFalse(isinstance(strip_units(self.unit_array),
348                                    (UnitArray, UnitScalar)))
349        self.assertFalse(isinstance(strip_units(self.unit_scalar),
350                                    (UnitArray, UnitScalar)))
351        self.assertFalse(isinstance(strip_units(self.plain_array),
352                                    (UnitArray, UnitScalar)))
353        self.assertFalse(isinstance(strip_units(self.plain_scalar),
354                                    (UnitArray, UnitScalar)))
355
356        # Check for stupidity when returning only one argument.
357        self.assertFalse(isinstance(strip_units(self.unit_scalar), tuple))
358        self.assertFalse(isinstance(strip_units(self.plain_scalar), tuple))
359
360    def test_strip_units_multi_arg(self):
361        outs = strip_units(self.unit_array, self.unit_scalar)
362        self.assertEquals(len(outs), 2)
363        for x in outs:
364            self.assertFalse(isinstance(x, (UnitArray, UnitScalar)))
365
366        outs = strip_units(self.plain_array, self.plain_scalar)
367        self.assertEquals(len(outs), 2)
368        for x in outs:
369            self.assertFalse(isinstance(x, (UnitArray, UnitScalar)))
370
371        outs = strip_units(self.unit_array, self.unit_scalar, self.plain_array,
372                           self.plain_scalar)
373        self.assertEquals(len(outs), 4)
374        for x in outs:
375            self.assertFalse(isinstance(x, (UnitArray, UnitScalar)))
376
377
378if __name__ == '__main__':
379    unittest.main()
380