"""
===========
Basic Units
===========

A small, self-contained units framework for matplotlib.

Values are wrapped in ``TaggedValue`` objects that carry a ``BasicUnit``.
Arithmetic on tagged values is forwarded to the wrapped value through proxy
objects, and a ``BasicUnitConverter`` registered in
``matplotlib.units.registry`` converts tagged values for plotting and
supplies axis locators, formatters and labels.
"""
import six

import math

import numpy as np

import matplotlib.units as units
import matplotlib.ticker as ticker
from matplotlib.axes import Axes

try:
    from matplotlib.cbook import iterable
except ImportError:
    # ``matplotlib.cbook.iterable`` is not available in every matplotlib
    # release; ``numpy.iterable`` performs the same "does iter() succeed"
    # test.
    iterable = np.iterable


class ProxyDelegate(object):
    """Descriptor that builds a fresh proxy object on every attribute access.

    Reading ``obj.<fn_name>`` returns ``proxy_type(fn_name, obj)``; the proxy
    class decides how arguments and the result of the underlying call are
    converted.
    """

    def __init__(self, fn_name, proxy_type):
        self.proxy_type = proxy_type
        self.fn_name = fn_name

    def __get__(self, obj, objtype=None):
        # NOTE: on class access ``obj`` is None, and the proxy constructors
        # below then raise AttributeError (``None.proxy_target``), so
        # TaggedValueMeta sees the name as missing on subclasses.
        return self.proxy_type(self.fn_name, obj)


class TaggedValueMeta(type):
    """Metaclass that installs a ProxyDelegate for each entry of ``_proxies``.

    A delegate is only installed for names whose lookup on the new class
    raises AttributeError; names already resolvable (defined in the class
    body or inherited) are left untouched.
    """

    def __init__(cls, name, bases, namespace):
        super(TaggedValueMeta, cls).__init__(name, bases, namespace)
        for fn_name, proxy_type in cls._proxies.items():
            try:
                getattr(cls, fn_name)
            except AttributeError:
                setattr(cls, fn_name, ProxyDelegate(fn_name, proxy_type))


class PassThroughProxy(object):
    """Call method ``fn_name`` of ``obj.proxy_target`` with unchanged args."""

    def __init__(self, fn_name, obj):
        self.fn_name = fn_name
        self.target = obj.proxy_target

    def __call__(self, *args):
        return getattr(self.target, self.fn_name)(*args)


class ConvertArgsProxy(PassThroughProxy):
    """Proxy that converts every argument to the proxied object's unit.

    Arguments without ``convert_to`` are treated as already being in that
    unit; the plain values are then passed to the wrapped method.
    """

    def __init__(self, fn_name, obj):
        PassThroughProxy.__init__(self, fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        converted_args = []
        for a in args:
            try:
                converted_args.append(a.convert_to(self.unit))
            except AttributeError:
                converted_args.append(TaggedValue(a, self.unit))
        converted_args = tuple([c.get_value() for c in converted_args])
        return PassThroughProxy.__call__(self, *converted_args)


class ConvertReturnProxy(PassThroughProxy):
    """Proxy that tags the wrapped method's result with the object's unit."""

    def __init__(self, fn_name, obj):
        PassThroughProxy.__init__(self, fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        ret = PassThroughProxy.__call__(self, *args)
        return (NotImplemented if ret is NotImplemented
                else TaggedValue(ret, self.unit))


class ConvertAllProxy(PassThroughProxy):
    """Proxy that unwraps tagged arguments and re-tags the result.

    Tagged arguments are converted to the proxied object's unit when
    possible, the plain values are passed to the wrapped method, and the
    result's unit is chosen by ``unit_resolver`` from the operand units.
    Returns NotImplemented when the operation or its unit combination is not
    supported, so Python can fall back to the reflected operation.
    """

    def __init__(self, fn_name, obj):
        PassThroughProxy.__init__(self, fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        converted_args = []
        arg_units = [self.unit]
        for a in args:
            if hasattr(a, 'get_unit') and not hasattr(a, 'convert_to'):
                # If this arg has a unit type but no conversion ability,
                # this operation is prohibited.
                return NotImplemented

            if hasattr(a, 'convert_to'):
                try:
                    a = a.convert_to(self.unit)
                except Exception:
                    # Best effort: keep the argument in its own unit and let
                    # unit_resolver decide whether the units can combine.
                    pass
                arg_units.append(a.get_unit())
                converted_args.append(a.get_value())
            else:
                # Plain value: the guard above already rejected anything
                # with ``get_unit``, so it carries no unit.
                converted_args.append(a)
                arg_units.append(None)
        converted_args = tuple(converted_args)
        ret = PassThroughProxy.__call__(self, *converted_args)
        if ret is NotImplemented:
            return NotImplemented
        ret_unit = unit_resolver(self.fn_name, arg_units)
        if ret_unit is NotImplemented:
            return NotImplemented
        return TaggedValue(ret, ret_unit)


# Cache of TaggedValue subclasses keyed by (requested class, type(value)).
# A stored None records that no instantiable subclass exists for that pair,
# so the plain requested class is used instead.
_tagged_subclass_cache = {}


class TaggedValue(six.with_metaclass(TaggedValueMeta)):
    """A value carrying a BasicUnit.

    Non-dunder attribute lookups that the wrapped value can answer, and that
    are not in the instance's own class ``__dict__``, are forwarded to the
    wrapped value.  The special methods listed in ``_proxies`` are routed
    through the proxy classes above.
    """

    _proxies = {'__add__': ConvertAllProxy,
                '__sub__': ConvertAllProxy,
                '__mul__': ConvertAllProxy,
                '__rmul__': ConvertAllProxy,
                '__cmp__': ConvertAllProxy,
                '__lt__': ConvertAllProxy,
                '__gt__': ConvertAllProxy,
                '__len__': PassThroughProxy}

    def __new__(cls, value, unit):
        # Try to create a subclass deriving from both ``cls`` and
        # ``type(value)``, once per pair.  Building the class or calling
        # object.__new__ on it can raise TypeError (e.g. for built-in bases
        # such as float that define their own __new__); in that case fall
        # back to ``cls``.  Only the class actually instantiated gets a
        # registry entry, so failed attempts no longer accumulate there.
        value_class = type(value)
        key = (cls, value_class)
        if key not in _tagged_subclass_cache:
            subcls = None
            try:
                candidate = type('TaggedValue_of_%s' % (value_class.__name__),
                                 tuple([cls, value_class]),
                                 {})
                object.__new__(candidate)  # probe instantiability once
            except TypeError:
                pass
            else:
                subcls = candidate
            _tagged_subclass_cache[key] = subcls
        subcls = _tagged_subclass_cache[key]
        target = cls if subcls is None else subcls
        if target not in units.registry:
            units.registry[target] = basicConverter
        return object.__new__(target)

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit
        self.proxy_target = self.value

    def __getattribute__(self, name):
        if name.startswith('__'):
            return object.__getattribute__(self, name)
        variable = object.__getattribute__(self, 'value')
        if hasattr(variable, name) and name not in self.__class__.__dict__:
            return getattr(variable, name)
        return object.__getattribute__(self, name)

    def __array__(self, dtype=object):
        return np.asarray(self.value).astype(dtype)

    def __array_wrap__(self, array, context=None, return_scalar=False):
        # ``context`` and ``return_scalar`` default so that every NumPy
        # calling convention is accepted; both are ignored.
        return TaggedValue(array, self.unit)

    def __repr__(self):
        return 'TaggedValue(' + repr(self.value) + ', ' + repr(self.unit) + ')'

    def __str__(self):
        return str(self.value) + ' in ' + str(self.unit)

    def __len__(self):
        return len(self.value)

    def __iter__(self):
        # Return a generator expression rather than use `yield`, so that
        # TypeError is raised by iter(self) if appropriate when checking for
        # iterability.
        return (TaggedValue(inner, self.unit) for inner in self.value)

    def get_compressed_copy(self, mask):
        """Return a tagged copy holding only the unmasked entries."""
        new_value = np.ma.masked_array(self.value, mask=mask).compressed()
        return TaggedValue(new_value, self.unit)

    def convert_to(self, unit):
        """Return this value expressed in *unit*.

        Returns ``self`` when *unit* equals the current unit or is falsy;
        raises KeyError when no conversion to *unit* is registered.
        """
        if unit == self.unit or not unit:
            return self
        new_value = self.unit.convert_value_to(self.value, unit)
        return TaggedValue(new_value, unit)

    def get_value(self):
        return self.value

    def get_unit(self):
        return self.unit


class BasicUnit(object):
    """A named unit with a table of conversion functions to other units.

    Calling or multiplying a unit by a value produces a TaggedValue.
    """

    def __init__(self, name, fullname=None):
        self.name = name
        if fullname is None:
            fullname = name
        self.fullname = fullname
        self.conversions = dict()

    def __repr__(self):
        return 'BasicUnit(%s)' % self.name

    def __str__(self):
        return self.fullname

    def __call__(self, value):
        return TaggedValue(value, self)

    def __mul__(self, rhs):
        value = rhs
        unit = self
        if hasattr(rhs, 'get_unit'):
            value = rhs.get_value()
            unit = rhs.get_unit()
            unit = unit_resolver('__mul__', (self, unit))
        if unit is NotImplemented:
            return NotImplemented
        return TaggedValue(value, unit)

    def __rmul__(self, lhs):
        return self*lhs

    def __array_wrap__(self, array, context=None, return_scalar=False):
        # ``context`` and ``return_scalar`` default so that every NumPy
        # calling convention is accepted; both are ignored.
        return TaggedValue(array, self)

    def __array__(self, t=None, context=None):
        ret = np.array([1])
        if t is not None:
            return ret.astype(t)
        else:
            return ret

    def add_conversion_factor(self, unit, factor):
        """Register conversion to *unit* as multiplication by *factor*."""
        def convert(x):
            return x*factor
        self.conversions[unit] = convert

    def add_conversion_fn(self, unit, fn):
        """Register an arbitrary one-argument conversion function."""
        self.conversions[unit] = fn

    def get_conversion_fn(self, unit):
        return self.conversions[unit]

    def convert_value_to(self, value, unit):
        """Convert *value*; raises KeyError if *unit* is not registered."""
        conversion_fn = self.conversions[unit]
        ret = conversion_fn(value)
        return ret

    def get_unit(self):
        return self


class UnitResolver(object):
    """Decide the unit of an arithmetic result from its operand units.

    Calling the resolver returns the resulting unit, or NotImplemented when
    the operation is unknown or the units are incompatible.
    """

    def addition_rule(self, units):
        # Addition/subtraction require all operand units to be equal.
        for unit_1, unit_2 in zip(units[:-1], units[1:]):
            if (unit_1 != unit_2):
                return NotImplemented
        return units[0]

    def multiplication_rule(self, units):
        # At most one operand may carry a unit; it becomes the result unit.
        non_null = [u for u in units if u]
        if (len(non_null) > 1):
            return NotImplemented
        # No operand carried a unit: the product is unitless.
        return non_null[0] if non_null else None

    op_dict = {
        '__mul__': multiplication_rule,
        '__rmul__': multiplication_rule,
        '__add__': addition_rule,
        '__radd__': addition_rule,
        '__sub__': addition_rule,
        '__rsub__': addition_rule}

    def __call__(self, operation, units):
        if (operation not in self.op_dict):
            return NotImplemented

        return self.op_dict[operation](self, units)


unit_resolver = UnitResolver()

cm = BasicUnit('cm', 'centimeters')
inch = BasicUnit('inch', 'inches')
inch.add_conversion_factor(cm, 2.54)
cm.add_conversion_factor(inch, 1/2.54)

radians = BasicUnit('rad', 'radians')
degrees = BasicUnit('deg', 'degrees')
radians.add_conversion_factor(degrees, 180.0/np.pi)
degrees.add_conversion_factor(radians, np.pi/180.0)

secs = BasicUnit('s', 'seconds')
hertz = BasicUnit('Hz', 'Hertz')
minutes = BasicUnit('min', 'minutes')

secs.add_conversion_fn(hertz, lambda x: 1./x)
secs.add_conversion_factor(minutes, 1/60.0)


# radians formatting
def rad_fn(x, pos=None):
    """Format *x* (in radians) as a multiple of pi/2 for tick labels.

    *x* is turned into an integer count ``n`` of pi/2 steps (a 0.25 offset
    away from zero, then truncation).  *pos* is accepted for the
    FuncFormatter calling convention and is unused.
    """
    if x >= 0:
        n = int((x / np.pi) * 2.0 + 0.25)
    else:
        n = int((x / np.pi) * 2.0 - 0.25)

    if n == 0:
        return '0'
    elif n == 1:
        return r'$\pi/2$'
    elif n == 2:
        return r'$\pi$'
    elif n == -1:
        return r'$-\pi/2$'
    elif n == -2:
        return r'$-\pi$'
    elif n % 2 == 0:
        return r'$%s\pi$' % (n//2,)
    else:
        return r'$%s\pi/2$' % (n,)


class BasicUnitConverter(units.ConversionInterface):
    """matplotlib unit converter for TaggedValue / BasicUnit data."""

    @staticmethod
    def axisinfo(unit, axis):
        """Return an AxisInfo for *unit*.

        Radians get pi/2 ticks and labels, degrees get degree-formatted
        labels, other units get only a label.  Returns None when *unit* is
        None or has neither ``fullname`` nor ``unit``.
        """
        if unit == radians:
            return units.AxisInfo(
                majloc=ticker.MultipleLocator(base=np.pi/2),
                majfmt=ticker.FuncFormatter(rad_fn),
                label=unit.fullname,
            )
        elif unit == degrees:
            return units.AxisInfo(
                majloc=ticker.AutoLocator(),
                majfmt=ticker.FormatStrFormatter(r'$%i^\circ$'),
                label=unit.fullname,
            )
        elif unit is not None:
            if hasattr(unit, 'fullname'):
                return units.AxisInfo(label=unit.fullname)
            elif hasattr(unit, 'unit'):
                return units.AxisInfo(label=unit.unit.fullname)
        return None

    @staticmethod
    def convert(val, unit, axis):
        """Convert *val* (tagged scalar or iterable of them) to *unit*.

        Plain numbers are returned unchanged; iterables become a list of
        plain converted values.
        """
        if units.ConversionInterface.is_numlike(val):
            return val
        if iterable(val):
            return [thisval.convert_to(unit).get_value() for thisval in val]
        else:
            return val.convert_to(unit).get_value()

    @staticmethod
    def default_units(x, axis):
        """Return the default unit for x or None.

        For iterables the unit of the first element is used; an empty
        iterable or an object without a ``unit`` attribute yields None.
        """
        if iterable(x):
            for thisx in x:
                return thisx.unit
        return getattr(x, 'unit', None)


def cos(x):
    """Return the cosine of a tagged angle (or a list of cosines).

    Each angle is converted to radians before ``math.cos`` is applied.
    """
    if iterable(x):
        return [math.cos(val.convert_to(radians).get_value()) for val in x]
    else:
        return math.cos(x.convert_to(radians).get_value())


basicConverter = BasicUnitConverter()
units.registry[BasicUnit] = basicConverter
units.registry[TaggedValue] = basicConverter