1"""
2===========
3Basic Units
4===========
5
6"""
7
8from distutils.version import LooseVersion
9import math
10
11import numpy as np
12
13import matplotlib.units as units
14import matplotlib.ticker as ticker
15
16
class ProxyDelegate:
    """Descriptor that hands out a new proxy object on every attribute read.

    Reading the attribute calls ``proxy_type(fn_name, obj)``, where ``obj``
    is the instance the attribute was accessed through.
    """

    def __init__(self, fn_name, proxy_type):
        self.fn_name = fn_name
        self.proxy_type = proxy_type

    def __get__(self, obj, objtype=None):
        # ``obj`` is None when the attribute is read from the class itself.
        make_proxy = self.proxy_type
        return make_proxy(self.fn_name, obj)
24
25
class TaggedValueMeta(type):
    """Metaclass that installs a ``ProxyDelegate`` for each ``_proxies`` entry.

    ``_proxies`` maps attribute names to proxy classes. Names the class
    already has (directly or by inheritance) are left alone.
    """

    def __init__(self, name, bases, dict):
        for fn_name, proxy_type in self._proxies.items():
            if hasattr(self, fn_name):
                # Respect explicit definitions such as ``__len__``.
                continue
            setattr(self, fn_name, ProxyDelegate(fn_name, proxy_type))
32
33
class PassThroughProxy:
    """Callable that forwards a named method call to ``obj.proxy_target``.

    Arguments and the return value are passed through unchanged.
    """

    def __init__(self, fn_name, obj):
        self.fn_name = fn_name
        self.target = obj.proxy_target

    def __call__(self, *args):
        return getattr(self.target, self.fn_name)(*args)
43
44
class ConvertArgsProxy(PassThroughProxy):
    """Proxy that converts every argument to the owner's unit before the call.

    Arguments exposing ``convert_to`` are converted. Anything else (i.e.
    ``convert_to`` raises AttributeError) is wrapped as a ``TaggedValue`` in
    the owner's unit. The raw values are then passed to the target method.
    """

    def __init__(self, fn_name, obj):
        super().__init__(fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        tagged_args = []
        for arg in args:
            try:
                tagged = arg.convert_to(self.unit)
            except AttributeError:
                tagged = TaggedValue(arg, self.unit)
            tagged_args.append(tagged)
        raw_values = tuple(t.get_value() for t in tagged_args)
        return super().__call__(*raw_values)
59
60
class ConvertReturnProxy(PassThroughProxy):
    """Proxy that tags the forwarded call's result with the owner's unit.

    A ``NotImplemented`` result is returned as-is so Python's binary-operator
    fallback still works.
    """

    def __init__(self, fn_name, obj):
        super().__init__(fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        result = super().__call__(*args)
        if result is NotImplemented:
            return result
        return TaggedValue(result, self.unit)
70
71
class ConvertAllProxy(PassThroughProxy):
    """Proxy for binary operators on tagged values.

    Convertible operands (those with ``convert_to``) are converted to the
    owner's unit where possible. The wrapped operator is applied to the raw
    values, and the result is tagged with the unit chosen by
    ``unit_resolver`` from the owner's and operands' units.

    Returns ``NotImplemented`` when:

    * an operand carries a unit but cannot be converted (it has
      ``get_unit`` but no ``convert_to``, e.g. a bare ``BasicUnit``);
    * the wrapped operator returns ``NotImplemented``;
    * ``unit_resolver`` rejects the combination of units.
    """

    def __init__(self, fn_name, obj):
        super().__init__(fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        converted_args = []
        arg_units = [self.unit]
        for a in args:
            if not hasattr(a, 'convert_to'):
                if hasattr(a, 'get_unit'):
                    # Unit-bearing but not convertible: this operation is
                    # prohibited, so let Python try the reflected operator.
                    return NotImplemented
                # Plain unitless operand: pass it through untouched.
                converted_args.append(a)
                arg_units.append(None)
                continue

            try:
                a = a.convert_to(self.unit)
            except Exception:
                # Deliberate best effort: keep the operand in its own unit;
                # unit_resolver below decides whether the mix is allowed.
                pass
            arg_units.append(a.get_unit())
            converted_args.append(a.get_value())

        ret = super().__call__(*converted_args)
        if ret is NotImplemented:
            return NotImplemented
        ret_unit = unit_resolver(self.fn_name, arg_units)
        if ret_unit is NotImplemented:
            return NotImplemented
        return TaggedValue(ret, ret_unit)
107
108
class TaggedValue(metaclass=TaggedValueMeta):
    """A value paired with a unit.

    The operator dunders listed in ``_proxies`` are installed by
    ``TaggedValueMeta``, unless the class already defines them. They are
    forwarded to the wrapped value through the listed proxy classes.
    Non-dunder attribute lookups are delegated to the wrapped value when it
    has the attribute and this class does not define the name directly.
    """

    # Dunder name -> proxy class implementing it.
    _proxies = {'__add__': ConvertAllProxy,
                '__sub__': ConvertAllProxy,
                '__mul__': ConvertAllProxy,
                '__rmul__': ConvertAllProxy,
                '__cmp__': ConvertAllProxy,
                '__lt__': ConvertAllProxy,
                '__gt__': ConvertAllProxy,
                '__len__': PassThroughProxy}

    def __new__(cls, value, unit):
        # Try to build an ad-hoc subclass that also derives from
        # ``type(value)``. If creating the subclass or the instance raises
        # TypeError, fall back to a plain instance of ``cls``.
        value_class = type(value)
        try:
            subcls = type(f'TaggedValue_of_{value_class.__name__}',
                          (cls, value_class), {})
            return object.__new__(subcls)
        except TypeError:
            return object.__new__(cls)

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit
        # The proxy classes read ``proxy_target`` to find the call target.
        self.proxy_target = self.value

    def __getattribute__(self, name):
        # Dunder lookups always resolve on the TaggedValue itself.
        if name.startswith('__'):
            return object.__getattribute__(self, name)
        variable = object.__getattribute__(self, 'value')
        # Prefer the wrapped value's attribute unless the name is defined
        # directly in this instance's class namespace.
        if hasattr(variable, name) and name not in self.__class__.__dict__:
            return getattr(variable, name)
        return object.__getattribute__(self, name)

    def __array__(self, dtype=object, copy=None):
        # ``copy`` is accepted for the NumPy 2 ``__array__`` protocol. It can
        # be ignored because ``astype`` always returns a new array.
        return np.asarray(self.value).astype(dtype)

    def __array_wrap__(self, array, context=None, return_scalar=False):
        # ``context`` and ``return_scalar`` (the latter passed by NumPy 2)
        # are accepted for compatibility; the unit is simply re-attached.
        return TaggedValue(array, self.unit)

    def __repr__(self):
        return 'TaggedValue({!r}, {!r})'.format(self.value, self.unit)

    def __str__(self):
        return str(self.value) + ' in ' + str(self.unit)

    def __len__(self):
        return len(self.value)

    # Item access is only defined for NumPy >= 1.20. The version is parsed
    # with NumPy's own helper instead of the deprecated distutils
    # LooseVersion (distutils was removed in Python 3.12).
    if np.lib.NumpyVersion(np.__version__) >= '1.20.0':
        def __getitem__(self, key):
            return TaggedValue(self.value[key], self.unit)

    def __iter__(self):
        # Return a generator expression rather than use `yield`, so that
        # TypeError is raised by iter(self) if appropriate when checking for
        # iterability.
        return (TaggedValue(inner, self.unit) for inner in self.value)

    def get_compressed_copy(self, mask):
        """Return a TaggedValue holding only the entries not hidden by *mask*."""
        new_value = np.ma.masked_array(self.value, mask=mask).compressed()
        return TaggedValue(new_value, self.unit)

    def convert_to(self, unit):
        """Return this value expressed in *unit*.

        Returns ``self`` unchanged when *unit* equals the current unit or is
        falsy. If the current unit has no ``convert_value_to``
        (AttributeError), the result wraps ``self`` tagged with *unit*.
        Other errors from the unit's conversion lookup propagate.
        """
        if unit == self.unit or not unit:
            return self
        try:
            new_value = self.unit.convert_value_to(self.value, unit)
        except AttributeError:
            new_value = self
        return TaggedValue(new_value, unit)

    def get_value(self):
        """Return the wrapped (unit-less) value."""
        return self.value

    def get_unit(self):
        """Return the unit this value is tagged with."""
        return self.unit
186
187
class BasicUnit:
    """A named unit that can tag values and convert them to other units.

    Conversions are registered per target unit with
    ``add_conversion_factor`` or ``add_conversion_fn``.
    """

    def __init__(self, name, fullname=None):
        self.name = name
        if fullname is None:
            fullname = name
        self.fullname = fullname
        # Target unit -> callable that converts a value into that unit.
        self.conversions = {}

    def __repr__(self):
        return f'BasicUnit({self.name})'

    def __str__(self):
        return self.fullname

    def __call__(self, value):
        """Tag *value* with this unit."""
        return TaggedValue(value, self)

    def __mul__(self, rhs):
        """Tag *rhs* with this unit.

        Tagged operands have their units combined by ``unit_resolver``.
        """
        value = rhs
        unit = self
        if hasattr(rhs, 'get_unit'):
            value = rhs.get_value()
            unit = rhs.get_unit()
            unit = unit_resolver('__mul__', (self, unit))
        if unit is NotImplemented:
            return NotImplemented
        return TaggedValue(value, unit)

    def __rmul__(self, lhs):
        return self*lhs

    def __array_wrap__(self, array, context=None, return_scalar=False):
        # ``context`` and ``return_scalar`` (the latter passed by NumPy 2)
        # are accepted for compatibility; the result is tagged with this unit.
        return TaggedValue(array, self)

    def __array__(self, t=None, context=None, copy=None):
        # A bare unit converts to the scalar array 1. ``copy`` is accepted
        # for the NumPy 2 ``__array__`` protocol; the array is always freshly
        # created, so there is nothing to share.
        ret = np.array(1)
        if t is not None:
            return ret.astype(t)
        return ret

    def add_conversion_factor(self, unit, factor):
        """Register conversion to *unit* as multiplication by *factor*."""
        def convert(x):
            return x*factor
        self.conversions[unit] = convert

    def add_conversion_fn(self, unit, fn):
        """Register the callable *fn* as the conversion to *unit*."""
        self.conversions[unit] = fn

    def get_conversion_fn(self, unit):
        """Return the conversion callable for *unit* (KeyError if none)."""
        return self.conversions[unit]

    def convert_value_to(self, value, unit):
        """Convert *value* into *unit* (KeyError if no conversion exists)."""
        conversion_fn = self.conversions[unit]
        return conversion_fn(value)

    def get_unit(self):
        return self
247
248
class UnitResolver:
    """Decide the unit of an arithmetic result from its operands' units.

    Calling the resolver with an operation name and a sequence of units
    returns the resulting unit. It returns ``NotImplemented`` when the
    operation is unknown or the units are incompatible.
    """

    def addition_rule(self, units):
        """Addition/subtraction: all units must be equal.

        Returns the common unit, ``NotImplemented`` on a mismatch, and None
        for an empty sequence.
        """
        if not units:
            return None
        for unit_1, unit_2 in zip(units[:-1], units[1:]):
            if unit_1 != unit_2:
                return NotImplemented
        return units[0]

    def multiplication_rule(self, units):
        """Multiplication: at most one operand may carry a (truthy) unit.

        Returns that unit, or None when no operand has a unit (the product
        is dimensionless). Returns ``NotImplemented`` if more than one
        operand has a unit.
        """
        non_null = [u for u in units if u]
        if len(non_null) > 1:
            return NotImplemented
        return non_null[0] if non_null else None

    op_dict = {
        '__mul__': multiplication_rule,
        '__rmul__': multiplication_rule,
        '__add__': addition_rule,
        '__radd__': addition_rule,
        '__sub__': addition_rule,
        '__rsub__': addition_rule}

    def __call__(self, operation, units):
        if operation not in self.op_dict:
            return NotImplemented

        # op_dict holds plain functions, so pass self explicitly.
        return self.op_dict[operation](self, units)
275
276
unit_resolver = UnitResolver()

# Length units.
cm = BasicUnit('cm', 'centimeters')
inch = BasicUnit('inch', 'inches')
inch.add_conversion_factor(cm, 2.54)
cm.add_conversion_factor(inch, 1/2.54)

# Angle units.
radians = BasicUnit('rad', 'radians')
degrees = BasicUnit('deg', 'degrees')
radians.add_conversion_factor(degrees, 180.0/np.pi)
degrees.add_conversion_factor(radians, np.pi/180.0)

# Time / frequency units. Each conversion is registered in both directions,
# matching the length and angle units above.
secs = BasicUnit('s', 'seconds')
hertz = BasicUnit('Hz', 'Hertz')
minutes = BasicUnit('min', 'minutes')

secs.add_conversion_fn(hertz, lambda x: 1./x)
hertz.add_conversion_fn(secs, lambda x: 1./x)
secs.add_conversion_factor(minutes, 1/60.0)
minutes.add_conversion_factor(secs, 60.0)
295
296
# radians formatting
def rad_fn(x, pos=None):
    """Format *x* (radians) as a tick label in multiples of pi/2.

    *x* is rounded to a count of half-turns (pi/2 steps). The count is
    nudged by 0.25 away from zero before truncating toward zero. *pos* is
    unused; it is part of the FuncFormatter signature.
    """
    offset = 0.25 if x >= 0 else -0.25
    half_turns = int(x / np.pi * 2.0 + offset)

    named_labels = {
        0: '0',
        1: r'$\pi/2$',
        2: r'$\pi$',
        -1: r'$-\pi/2$',
        -2: r'$-\pi$',
    }
    if half_turns in named_labels:
        return named_labels[half_turns]
    if half_turns % 2:
        return fr'${half_turns}\pi/2$'
    return fr'${half_turns // 2}\pi$'
318
319
class BasicUnitConverter(units.ConversionInterface):
    """Matplotlib unit converter for ``BasicUnit`` / ``TaggedValue`` data."""

    @staticmethod
    def _is_numlike(val):
        """Return whether *val* is already plain numeric data.

        For an iterable, the first element that is not ``np.ma.masked``
        decides; an empty or fully masked iterable counts as non-numeric.
        This mirrors the former ``units.ConversionInterface.is_numlike``,
        which Matplotlib deprecated in 3.5 and later removed.
        """
        from numbers import Number
        if np.iterable(val):
            for item in val:
                if item is np.ma.masked:
                    continue
                return isinstance(item, Number)
            return False
        return isinstance(val, Number)

    @staticmethod
    def axisinfo(unit, axis):
        """Return the AxisInfo for *unit*, or None if *unit* is None.

        Radians get pi/2 tick spacing with pi-based labels, and degrees get
        degree-formatted labels. Any other unit only contributes its
        ``fullname`` as the axis label.
        """
        if unit == radians:
            return units.AxisInfo(
                majloc=ticker.MultipleLocator(base=np.pi/2),
                majfmt=ticker.FuncFormatter(rad_fn),
                label=unit.fullname,
            )
        elif unit == degrees:
            return units.AxisInfo(
                majloc=ticker.AutoLocator(),
                majfmt=ticker.FormatStrFormatter(r'$%i^\circ$'),
                label=unit.fullname,
            )
        elif unit is not None:
            if hasattr(unit, 'fullname'):
                return units.AxisInfo(label=unit.fullname)
            elif hasattr(unit, 'unit'):
                return units.AxisInfo(label=unit.unit.fullname)
        return None

    @staticmethod
    def convert(val, unit, axis):
        """Convert *val* to plain numbers expressed in *unit*.

        Already-numeric input is returned unchanged. Iterables become a
        float array, with masked entries turned into NaN and elements
        lacking ``convert_to`` copied as-is. A masked scalar becomes NaN.
        """
        if BasicUnitConverter._is_numlike(val):
            return val
        if np.iterable(val):
            if isinstance(val, np.ma.MaskedArray):
                val = val.astype(float).filled(np.nan)
            out = np.empty(len(val))
            for i, thisval in enumerate(val):
                if np.ma.is_masked(thisval):
                    out[i] = np.nan
                else:
                    try:
                        out[i] = thisval.convert_to(unit).get_value()
                    except AttributeError:
                        out[i] = thisval
            return out
        if np.ma.is_masked(val):
            return np.nan
        else:
            return val.convert_to(unit).get_value()

    @staticmethod
    def default_units(x, axis):
        """Return the unit of the first element of *x*, else ``x.unit``."""
        if np.iterable(x):
            for thisx in x:
                return thisx.unit
        return x.unit
373
374
def cos(x):
    """Cosine of a tagged angle, or a list of cosines for an iterable.

    Each angle is converted to radians via ``convert_to(radians)`` before
    ``math.cos`` is applied.
    """
    def _cos_of(angle):
        return math.cos(angle.convert_to(radians).get_value())

    if not np.iterable(x):
        return _cos_of(x)
    return [_cos_of(angle) for angle in x]
380
381
# One shared converter instance handles both bare units and tagged values.
_basic_unit_converter = BasicUnitConverter()
units.registry[BasicUnit] = _basic_unit_converter
units.registry[TaggedValue] = _basic_unit_converter
383