1"""
2===========
3Basic Units
4===========
5
6"""
7import six
8
9import math
10
11import numpy as np
12
13import matplotlib.units as units
14import matplotlib.ticker as ticker
15from matplotlib.axes import Axes
16from matplotlib.cbook import iterable
17
18
class ProxyDelegate(object):
    """Descriptor that manufactures a proxy object on attribute access.

    Parameters
    ----------
    fn_name : str
        Name handed to the proxy factory as its first argument.
    proxy_type : callable
        Factory invoked as ``proxy_type(fn_name, obj)`` each time the
        descriptor is read.
    """

    def __init__(self, fn_name, proxy_type):
        self.proxy_type = proxy_type
        self.fn_name = fn_name

    def __get__(self, obj, objtype=None):
        # Nothing is cached: every lookup builds a brand-new proxy bound to
        # the object the attribute was read from (``objtype`` is unused).
        make_proxy = self.proxy_type
        return make_proxy(self.fn_name, obj)
26
27
class TaggedValueMeta(type):
    """Metaclass that fills in missing methods with proxy descriptors.

    For every ``name -> proxy_type`` entry in the new class's ``_proxies``
    mapping, a `ProxyDelegate` is installed under ``name`` unless the class
    can already resolve an attribute of that name.
    """

    def __init__(cls, name, bases, dict):
        for fn_name, proxy_type in cls._proxies.items():
            try:
                # Only probing for existence; the value itself is not needed.
                getattr(cls, fn_name)
            except AttributeError:
                setattr(cls, fn_name, ProxyDelegate(fn_name, proxy_type))
36
37
class PassThroughProxy(object):
    """Callable that forwards a named method call to ``obj.proxy_target``.

    Parameters
    ----------
    fn_name : str
        Name of the method to look up on the target when called.
    obj : object
        Object exposing a ``proxy_target`` attribute; that attribute is
        captured at construction time.
    """

    def __init__(self, fn_name, obj):
        self.fn_name = fn_name
        self.target = obj.proxy_target

    def __call__(self, *args):
        # The method is resolved at call time, not when the proxy is built.
        method = getattr(self.target, self.fn_name)
        return method(*args)
47
48
class ConvertArgsProxy(PassThroughProxy):
    """Proxy that expresses every argument in the owner's unit first.

    The raw (untagged) values are then forwarded to the proxied method.
    """

    def __init__(self, fn_name, obj):
        PassThroughProxy.__init__(self, fn_name, obj)
        self.unit = obj.unit

    def _as_tagged(self, arg):
        """Return *arg* converted to ``self.unit``, tagging it if needed."""
        try:
            return arg.convert_to(self.unit)
        except AttributeError:
            # No usable ``convert_to``: treat the argument as already being
            # expressed in this proxy's unit.
            return TaggedValue(arg, self.unit)

    def __call__(self, *args):
        # Two passes on purpose: convert everything, then unwrap everything.
        tagged = [self._as_tagged(arg) for arg in args]
        raw_values = tuple([item.get_value() for item in tagged])
        return PassThroughProxy.__call__(self, *raw_values)
63
64
class ConvertReturnProxy(PassThroughProxy):
    """Proxy that tags the proxied method's result with the owner's unit."""

    def __init__(self, fn_name, obj):
        PassThroughProxy.__init__(self, fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        result = PassThroughProxy.__call__(self, *args)
        if result is NotImplemented:
            # Propagate the sentinel untouched instead of wrapping it.
            return NotImplemented
        return TaggedValue(result, self.unit)
74
75
class ConvertAllProxy(PassThroughProxy):
    """Proxy that converts the arguments and tags the result with a unit.

    Arguments offering ``convert_to`` are converted to the owner's unit when
    possible and unwrapped via ``get_value``; other arguments are forwarded
    unchanged. The unit of the result is chosen by ``unit_resolver`` from
    the owner's unit followed by the unit of each argument.
    """

    def __init__(self, fn_name, obj):
        PassThroughProxy.__init__(self, fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        """Invoke the proxied method.

        Returns
        -------
        TaggedValue or NotImplemented
            ``NotImplemented`` if an argument has a unit but cannot be
            converted, if the proxied method returns ``NotImplemented``, or
            if ``unit_resolver`` cannot determine a result unit.
        """
        converted_args = []
        arg_units = [self.unit]
        for a in args:
            convertible = hasattr(a, 'convert_to')
            if hasattr(a, 'get_unit') and not convertible:
                # if this arg has a unit type but no conversion ability,
                # this operation is prohibited
                return NotImplemented

            if convertible:
                try:
                    a = a.convert_to(self.unit)
                except Exception:
                    # Best effort: keep the argument in its own unit and let
                    # unit_resolver decide whether the mix is acceptable.
                    # Deliberately not a bare ``except`` so that
                    # KeyboardInterrupt / SystemExit still propagate.
                    pass
                arg_units.append(a.get_unit())
                converted_args.append(a.get_value())
            else:
                # Anything reaching this branch has no ``get_unit`` either
                # (that combination returned above), so it is unitless.
                converted_args.append(a)
                arg_units.append(None)
        converted_args = tuple(converted_args)
        ret = PassThroughProxy.__call__(self, *converted_args)
        if ret is NotImplemented:
            return NotImplemented
        ret_unit = unit_resolver(self.fn_name, arg_units)
        if ret_unit is NotImplemented:
            return NotImplemented
        return TaggedValue(ret, ret_unit)
111
112
class TaggedValue(six.with_metaclass(TaggedValueMeta)):
    """A value paired with the unit it is expressed in.

    The operators listed in ``_proxies`` that the class does not define
    itself are installed by `TaggedValueMeta` as proxies acting on the
    wrapped value.

    Parameters
    ----------
    value : object
        The wrapped value.
    unit : object
        The unit of *value*; `convert_to` calls its ``convert_value_to``
        method when a real conversion is required.
    """

    _proxies = {'__add__': ConvertAllProxy,
                '__sub__': ConvertAllProxy,
                '__mul__': ConvertAllProxy,
                '__rmul__': ConvertAllProxy,
                '__cmp__': ConvertAllProxy,
                '__lt__': ConvertAllProxy,
                '__gt__': ConvertAllProxy,
                '__len__': PassThroughProxy}

    # Maps (cls, type(value)) to the class that is actually instantiated for
    # that combination, so each combination is built and registered once.
    _subclass_cache = {}

    def __new__(cls, value, unit):
        """Create the instance, preferring a subclass that mixes in ``type(value)``.

        If such a subclass cannot be created or instantiated (``TypeError``),
        a plain instance of *cls* is created instead. Only the class that is
        really instantiated is added to ``units.registry``, and the outcome
        is cached per value type so that repeated construction does not keep
        creating classes or adding registry entries.
        """
        value_class = type(value)
        key = (cls, value_class)
        instance_cls = cls._subclass_cache.get(key)
        if instance_cls is not None:
            return object.__new__(instance_cls)

        try:
            instance_cls = type('TaggedValue_of_%s' % (value_class.__name__),
                                (cls, value_class),
                                {})
            instance = object.__new__(instance_cls)
        except TypeError:
            # Either the bases are incompatible or object.__new__ refuses to
            # build an instance of the mixed class.
            instance_cls = cls
            instance = object.__new__(cls)

        cls._subclass_cache[key] = instance_cls
        if instance_cls not in units.registry:
            units.registry[instance_cls] = basicConverter
        return instance

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit
        # Object the proxies forward their calls to.
        self.proxy_target = self.value

    def __getattribute__(self, name):
        """Resolve *name*, forwarding to the wrapped value where applicable.

        Names starting with ``__`` are resolved normally. Any other name is
        taken from the wrapped value when the value has that attribute and
        the name is not defined directly in the instance's own class
        ``__dict__``; otherwise it is resolved normally.
        """
        if name.startswith('__'):
            return object.__getattribute__(self, name)
        variable = object.__getattribute__(self, 'value')
        if hasattr(variable, name) and name not in self.__class__.__dict__:
            return getattr(variable, name)
        return object.__getattribute__(self, name)

    def __array__(self, dtype=object):
        """Return the wrapped value as an array cast to *dtype*."""
        return np.asarray(self.value).astype(dtype)

    def __array_wrap__(self, array, context=None, return_scalar=False):
        """Wrap *array* in a `TaggedValue` carrying this instance's unit.

        *context* and *return_scalar* are accepted so the method can be
        called with or without them; both are ignored.
        """
        return TaggedValue(array, self.unit)

    def __repr__(self):
        return 'TaggedValue(' + repr(self.value) + ', ' + repr(self.unit) + ')'

    def __str__(self):
        return str(self.value) + ' in ' + str(self.unit)

    def __len__(self):
        return len(self.value)

    def __iter__(self):
        # Return a generator expression rather than use `yield`, so that
        # TypeError is raised by iter(self) if appropriate when checking for
        # iterability.
        return (TaggedValue(inner, self.unit) for inner in self.value)

    def get_compressed_copy(self, mask):
        """Return a new `TaggedValue` holding only the unmasked entries."""
        new_value = np.ma.masked_array(self.value, mask=mask).compressed()
        return TaggedValue(new_value, self.unit)

    def convert_to(self, unit):
        """Return this quantity expressed in *unit*.

        ``self`` is returned unchanged when *unit* equals the current unit
        or is falsy; otherwise the current unit's ``convert_value_to`` does
        the conversion and the result is wrapped in a new `TaggedValue`.
        """
        if unit == self.unit or not unit:
            return self
        new_value = self.unit.convert_value_to(self.value, unit)
        return TaggedValue(new_value, unit)

    def get_value(self):
        """Return the wrapped value."""
        return self.value

    def get_unit(self):
        """Return the unit of the wrapped value."""
        return self.unit
188
189
class BasicUnit(object):
    """A named unit that knows how to convert values into other units.

    Parameters
    ----------
    name : str
        Short name, used by ``repr``.
    fullname : str, optional
        Long name, used by ``str``; falls back to *name* when omitted.
    """

    def __init__(self, name, fullname=None):
        self.name = name
        self.fullname = name if fullname is None else fullname
        # Maps a target unit to a one-argument conversion callable.
        self.conversions = {}

    def __repr__(self):
        return 'BasicUnit(%s)' % self.name

    def __str__(self):
        return self.fullname

    def __call__(self, value):
        """Tag *value* with this unit."""
        return TaggedValue(value, self)

    def __mul__(self, rhs):
        """Tag *rhs* with a unit; mixing units is left to ``unit_resolver``."""
        if not hasattr(rhs, 'get_unit'):
            # Plain value: it simply takes on this unit.
            return TaggedValue(rhs, self)

        value = rhs.get_value()
        resolved = unit_resolver('__mul__', (self, rhs.get_unit()))
        if resolved is NotImplemented:
            return NotImplemented
        return TaggedValue(value, resolved)

    def __rmul__(self, lhs):
        return self * lhs

    def __array_wrap__(self, array, context):
        return TaggedValue(array, self)

    def __array__(self, t=None, context=None):
        """Return ``array([1])``, cast to *t* when *t* is given."""
        one = np.array([1])
        return one if t is None else one.astype(t)

    def add_conversion_factor(self, unit, factor):
        """Register conversion to *unit* as multiplication by *factor*."""
        def scale(x):
            return x*factor
        self.conversions[unit] = scale

    def add_conversion_fn(self, unit, fn):
        """Register *fn* as the conversion to *unit*."""
        self.conversions[unit] = fn

    def get_conversion_fn(self, unit):
        """Return the conversion registered for *unit* (``KeyError`` if none)."""
        return self.conversions[unit]

    def convert_value_to(self, value, unit):
        """Return *value* converted to *unit* (``KeyError`` if unregistered)."""
        return self.conversions[unit](value)

    def get_unit(self):
        """Return ``self``; a unit is its own unit."""
        return self
249
250
class UnitResolver(object):
    """Decide which unit the result of an arithmetic operation carries.

    Instances are called as ``resolver(operation, units)``, where
    *operation* is a special-method name such as ``'__add__'`` and *units*
    is the sequence of operand units.
    """

    def addition_rule(self, units):
        """Return the common unit, or ``NotImplemented`` if any two differ."""
        if any(left != right for left, right in zip(units[:-1], units[1:])):
            return NotImplemented
        return units[0]

    def multiplication_rule(self, units):
        """Return the single truthy unit among *units*.

        Returns ``NotImplemented`` when more than one operand carries a
        unit, and ``None`` when no operand does (the result is unitless).
        """
        non_null = [u for u in units if u]
        if len(non_null) > 1:
            return NotImplemented
        if not non_null:
            # No operand has a unit; avoid indexing into an empty list.
            return None
        return non_null[0]

    # Plain functions are stored here, so __call__ passes ``self`` explicitly.
    op_dict = {
        '__mul__': multiplication_rule,
        '__rmul__': multiplication_rule,
        '__add__': addition_rule,
        '__radd__': addition_rule,
        '__sub__': addition_rule,
        '__rsub__': addition_rule}

    def __call__(self, operation, units):
        """Apply the rule for *operation*; ``NotImplemented`` if there is none."""
        rule = self.op_dict.get(operation)
        if rule is None:
            return NotImplemented
        return rule(self, units)
277
278
# Module-wide resolver instance used by BasicUnit and ConvertAllProxy.
unit_resolver = UnitResolver()

# --- unit definitions ------------------------------------------------------
cm = BasicUnit('cm', 'centimeters')
inch = BasicUnit('inch', 'inches')

radians = BasicUnit('rad', 'radians')
degrees = BasicUnit('deg', 'degrees')

secs = BasicUnit('s', 'seconds')
hertz = BasicUnit('Hz', 'Hertz')
minutes = BasicUnit('min', 'minutes')

# --- conversions -----------------------------------------------------------
# Lengths: registered in both directions.
inch.add_conversion_factor(cm, 2.54)
cm.add_conversion_factor(inch, 1/2.54)

# Angles: registered in both directions.
radians.add_conversion_factor(degrees, 180.0/np.pi)
degrees.add_conversion_factor(radians, np.pi/180.0)

# Time: only conversions *from* seconds are registered here.
secs.add_conversion_fn(hertz, lambda x: 1.0 / x)
secs.add_conversion_factor(minutes, 1/60.0)
297
298
299# radians formatting
def rad_fn(x, pos=None):
    """Format *x* (in radians) as a label in multiples of pi/2.

    *x* is scaled to a count of half-pis, nudged by 0.25 away from zero and
    truncated to an integer ``n``. Small values of ``n`` get hand-written
    labels; other even ``n`` are written as a multiple of pi and other odd
    ``n`` as a multiple of pi/2. *pos* is accepted for the tick-formatter
    call signature and is not used.
    """
    half_pis = (x / np.pi) * 2.0
    n = int(half_pis + 0.25) if x >= 0 else int(half_pis - 0.25)

    special_labels = {
        0: '0',
        1: r'$\pi/2$',
        2: r'$\pi$',
        -1: r'$-\pi/2$',
        -2: r'$-\pi$',
    }
    if n in special_labels:
        return special_labels[n]

    if n % 2:
        # Odd count: cannot be reduced to a whole multiple of pi.
        return r'$%s\pi/2$' % (n,)
    return r'$%s\pi$' % (n//2,)
320
321
class BasicUnitConverter(units.ConversionInterface):
    """Conversion interface for `BasicUnit` / `TaggedValue` data."""

    @staticmethod
    def axisinfo(unit, axis):
        """Return the AxisInfo for *unit*, or None when there is none.

        Radians and degrees get dedicated tick locators and formatters; any
        other unit only contributes an axis label.
        """
        if unit == radians:
            return units.AxisInfo(
                majloc=ticker.MultipleLocator(base=np.pi/2),
                majfmt=ticker.FuncFormatter(rad_fn),
                label=unit.fullname,
            )
        if unit == degrees:
            return units.AxisInfo(
                majloc=ticker.AutoLocator(),
                majfmt=ticker.FormatStrFormatter(r'$%i^\circ$'),
                label=unit.fullname,
            )
        if unit is None:
            return None
        if hasattr(unit, 'fullname'):
            return units.AxisInfo(label=unit.fullname)
        if hasattr(unit, 'unit'):
            # Object that carries its unit in a ``unit`` attribute.
            return units.AxisInfo(label=unit.unit.fullname)
        return None

    @staticmethod
    def convert(val, unit, axis):
        """Return *val* as plain value(s) expressed in *unit*.

        Values reported as number-like by ``ConversionInterface.is_numlike``
        are returned unchanged; iterables are converted element by element
        into a list.
        """
        if units.ConversionInterface.is_numlike(val):
            return val
        if not iterable(val):
            return val.convert_to(unit).get_value()
        return [item.convert_to(unit).get_value() for item in val]

    @staticmethod
    def default_units(x, axis):
        """Return the ``unit`` of *x*, or of its first element if iterable."""
        if iterable(x):
            for first in x:
                # Only the first element is inspected.
                return first.unit
        # Not iterable, or an iterable that yielded nothing.
        return x.unit
362
363
def cos(x):
    """Return the cosine of *x*, converting it to radians first.

    *x* is either a single object with ``convert_to`` or an iterable of
    them; an iterable yields a list of floats.
    """
    def _cos_of(quantity):
        return math.cos(quantity.convert_to(radians).get_value())

    if not iterable(x):
        return _cos_of(x)
    return [_cos_of(val) for val in x]
369
370
# One shared converter instance handles both the unit class and the tagged
# value class; TaggedValue.__new__ also looks up this module-level name.
basicConverter = BasicUnitConverter()
units.registry[BasicUnit] = units.registry[TaggedValue] = basicConverter
374