"""
===========
Basic Units
===========

"""

import math

import numpy as np

import matplotlib.units as units
import matplotlib.ticker as ticker


class ProxyDelegate:
    """Descriptor that returns ``proxy_type(fn_name, obj)`` on every access."""

    def __init__(self, fn_name, proxy_type):
        self.proxy_type = proxy_type
        self.fn_name = fn_name

    def __get__(self, obj, objtype=None):
        # When accessed on the class itself (obj is None) the proxy
        # constructor raises AttributeError on ``None.proxy_target``, so
        # ``hasattr(cls, fn_name)`` reports False for these delegates.
        return self.proxy_type(self.fn_name, obj)


class TaggedValueMeta(type):
    """Metaclass installing a ProxyDelegate for each name in ``_proxies``."""

    def __init__(self, name, bases, namespace):
        super().__init__(name, bases, namespace)
        for fn_name, proxy_type in self._proxies.items():
            # Only fill in names the class cannot already resolve.
            # NOTE(review): ``object`` defines __lt__ and __gt__, so those two
            # entries of ``_proxies`` are never installed by this check.
            if not hasattr(self, fn_name):
                setattr(self, fn_name, ProxyDelegate(fn_name, proxy_type))


class PassThroughProxy:
    """Callable forwarding ``fn_name(*args)`` to ``obj.proxy_target``."""

    def __init__(self, fn_name, obj):
        self.fn_name = fn_name
        self.target = obj.proxy_target

    def __call__(self, *args):
        fn = getattr(self.target, self.fn_name)
        return fn(*args)


class ConvertArgsProxy(PassThroughProxy):
    """Forward a call after converting every argument into ``obj.unit``.

    Arguments lacking ``convert_to`` are first wrapped in a TaggedValue that
    carries the proxy's unit; the raw values are then passed through.
    """

    def __init__(self, fn_name, obj):
        super().__init__(fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        converted_args = []
        for a in args:
            try:
                converted_args.append(a.convert_to(self.unit))
            except AttributeError:
                converted_args.append(TaggedValue(a, self.unit))
        raw_args = tuple(c.get_value() for c in converted_args)
        return super().__call__(*raw_args)


class ConvertReturnProxy(PassThroughProxy):
    """Forward a call and tag its (non-NotImplemented) result with obj.unit."""

    def __init__(self, fn_name, obj):
        super().__init__(fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        ret = super().__call__(*args)
        return (NotImplemented if ret is NotImplemented
                else TaggedValue(ret, self.unit))


class ConvertAllProxy(PassThroughProxy):
    """Forward a call with converted arguments and a resolved result unit.

    Each unit-aware argument is converted to ``obj.unit`` when possible; the
    result unit is then chosen by ``unit_resolver`` from all operand units.
    Returns NotImplemented when the operation or unit combination is not
    supported.
    """

    def __init__(self, fn_name, obj):
        super().__init__(fn_name, obj)
        self.unit = obj.unit

    def __call__(self, *args):
        converted_args = []
        arg_units = [self.unit]
        for a in args:
            if hasattr(a, 'get_unit') and not hasattr(a, 'convert_to'):
                # if this arg has a unit type but no conversion ability,
                # this operation is prohibited
                return NotImplemented

            if hasattr(a, 'convert_to'):
                try:
                    a = a.convert_to(self.unit)
                except Exception:
                    # Deliberate best-effort: if conversion fails, keep *a* in
                    # its own unit and let unit_resolver accept or reject the
                    # resulting unit combination below.
                    pass
                arg_units.append(a.get_unit())
                converted_args.append(a.get_value())
            else:
                # Anything with get_unit but no convert_to returned above, so
                # a plain argument contributes no unit.
                converted_args.append(a)
                arg_units.append(None)
        ret = super().__call__(*converted_args)
        if ret is NotImplemented:
            return NotImplemented
        ret_unit = unit_resolver(self.fn_name, arg_units)
        if ret_unit is NotImplemented:
            return NotImplemented
        return TaggedValue(ret, ret_unit)


class TaggedValue(metaclass=TaggedValueMeta):
    """A value paired with a unit; operators are routed through ``_proxies``."""

    _proxies = {'__add__': ConvertAllProxy,
                '__sub__': ConvertAllProxy,
                '__mul__': ConvertAllProxy,
                '__rmul__': ConvertAllProxy,
                '__cmp__': ConvertAllProxy,
                '__lt__': ConvertAllProxy,
                '__gt__': ConvertAllProxy,
                '__len__': PassThroughProxy}

    def __new__(cls, value, unit):
        # generate a new subclass for value; fall back to a plain instance
        # whenever creating/instantiating that subclass raises TypeError
        value_class = type(value)
        try:
            subcls = type(f'TaggedValue_of_{value_class.__name__}',
                          (cls, value_class), {})
            return object.__new__(subcls)
        except TypeError:
            return object.__new__(cls)

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit
        # Object the PassThroughProxy family forwards calls to.
        self.proxy_target = self.value

    def __getattribute__(self, name):
        # Dunder lookups are never delegated.
        if name.startswith('__'):
            return object.__getattribute__(self, name)
        # Prefer the wrapped value's attribute unless the instance's own
        # class (not its bases) lists *name* in its ``__dict__``.
        variable = object.__getattribute__(self, 'value')
        if hasattr(variable, name) and name not in self.__class__.__dict__:
            return getattr(variable, name)
        return object.__getattribute__(self, name)

    def __array__(self, dtype=object, copy=None):
        # *copy* is accepted so NumPy 2 can pass it without a deprecation
        # warning; ``astype`` always returns a new array anyway.
        return np.asarray(self.value).astype(dtype)

    def __array_wrap__(self, array, context=None, return_scalar=False):
        # *context*/*return_scalar* default so both NumPy 1 and 2 call styles
        # are accepted.
        return TaggedValue(array, self.unit)

    def __repr__(self):
        return f'TaggedValue({self.value!r}, {self.unit!r})'

    def __str__(self):
        return f'{self.value!s} in {self.unit!s}'

    def __len__(self):
        return len(self.value)

    # Compare NumPy's major/minor version as ints; distutils' LooseVersion
    # is unavailable on Python 3.12+.
    # NOTE(review): the reason __getitem__ is gated on NumPy >= 1.20 is not
    # visible here (presumably array-coercion behaviour) — confirm.
    if tuple(map(int, np.__version__.split('.')[:2])) >= (1, 20):
        def __getitem__(self, key):
            return TaggedValue(self.value[key], self.unit)

    def __iter__(self):
        # Return a generator expression rather than use `yield`, so that
        # TypeError is raised by iter(self) if appropriate when checking for
        # iterability.
        return (TaggedValue(inner, self.unit) for inner in self.value)

    def get_compressed_copy(self, mask):
        """Return the unmasked entries of ``value`` (under *mask*), tagged."""
        new_value = np.ma.masked_array(self.value, mask=mask).compressed()
        return TaggedValue(new_value, self.unit)

    def convert_to(self, unit):
        """Return this value expressed in *unit*.

        Returns ``self`` when *unit* equals the current unit or is falsy.
        If the current unit has no ``convert_value_to`` (AttributeError), the
        value is re-tagged with *unit* without conversion.
        """
        if unit == self.unit or not unit:
            return self
        try:
            new_value = self.unit.convert_value_to(self.value, unit)
        except AttributeError:
            new_value = self
        return TaggedValue(new_value, unit)

    def get_value(self):
        return self.value

    def get_unit(self):
        return self.unit


class BasicUnit:
    """A named unit that tags values and stores conversions to other units."""

    def __init__(self, name, fullname=None):
        self.name = name
        if fullname is None:
            fullname = name
        self.fullname = fullname
        # Target unit -> callable converting a value into that unit.
        self.conversions = {}

    def __repr__(self):
        return f'BasicUnit({self.name})'

    def __str__(self):
        return self.fullname

    def __call__(self, value):
        return TaggedValue(value, self)

    def __mul__(self, rhs):
        value = rhs
        unit = self
        if hasattr(rhs, 'get_unit'):
            # rhs already carries a unit: combine via the resolver.
            value = rhs.get_value()
            unit = rhs.get_unit()
            unit = unit_resolver('__mul__', (self, unit))
        if unit is NotImplemented:
            return NotImplemented
        return TaggedValue(value, unit)

    def __rmul__(self, lhs):
        return self*lhs

    def __array_wrap__(self, array, context=None, return_scalar=False):
        # *context*/*return_scalar* default so both NumPy 1 and 2 call styles
        # are accepted.
        return TaggedValue(array, self)

    def __array__(self, t=None, context=None, copy=None):
        # A bare unit behaves like the array ``1``; *copy* is accepted for
        # NumPy 2 compatibility and ignored (a fresh array is built anyway).
        ret = np.array(1)
        if t is not None:
            return ret.astype(t)
        else:
            return ret

    def add_conversion_factor(self, unit, factor):
        """Register conversion to *unit* as multiplication by *factor*."""
        def convert(x):
            return x*factor
        self.conversions[unit] = convert

    def add_conversion_fn(self, unit, fn):
        """Register an arbitrary callable converting values into *unit*."""
        self.conversions[unit] = fn

    def get_conversion_fn(self, unit):
        return self.conversions[unit]

    def convert_value_to(self, value, unit):
        """Convert *value* into *unit*; raises KeyError if unregistered."""
        return self.conversions[unit](value)

    def get_unit(self):
        return self


class UnitResolver:
    """Pick the result unit of an operation from its operands' units."""

    def addition_rule(self, units):
        """Return the common unit if all *units* are equal, else NotImplemented."""
        for unit_1, unit_2 in zip(units[:-1], units[1:]):
            if unit_1 != unit_2:
                return NotImplemented
        return units[0]

    def multiplication_rule(self, units):
        """Return the single non-null unit among *units*.

        More than one non-null unit is unsupported (NotImplemented); when no
        operand carries a unit the result is unitless (``None``).
        """
        non_null = [u for u in units if u]
        if len(non_null) > 1:
            return NotImplemented
        return non_null[0] if non_null else None

    op_dict = {
        '__mul__': multiplication_rule,
        '__rmul__': multiplication_rule,
        '__add__': addition_rule,
        '__radd__': addition_rule,
        '__sub__': addition_rule,
        '__rsub__': addition_rule}

    def __call__(self, operation, units):
        if operation not in self.op_dict:
            return NotImplemented

        return self.op_dict[operation](self, units)


unit_resolver = UnitResolver()

cm = BasicUnit('cm', 'centimeters')
inch = BasicUnit('inch', 'inches')
inch.add_conversion_factor(cm, 2.54)
cm.add_conversion_factor(inch, 1/2.54)

radians = BasicUnit('rad', 'radians')
degrees = BasicUnit('deg', 'degrees')
radians.add_conversion_factor(degrees, 180.0/np.pi)
degrees.add_conversion_factor(radians, np.pi/180.0)

secs = BasicUnit('s', 'seconds')
hertz = BasicUnit('Hz', 'Hertz')
minutes = BasicUnit('min', 'minutes')

secs.add_conversion_fn(hertz, lambda x: 1./x)
secs.add_conversion_factor(minutes, 1/60.0)


# radians formatting
def rad_fn(x, pos=None):
    """Format *x* (radians) as a mathtext multiple of pi/2.

    ``2*x/pi`` is offset by 0.25 away from zero and truncated to an integer
    *n*; the label is ``n * pi/2``.  *pos* is unused (FuncFormatter passes
    it).
    """
    if x >= 0:
        n = int((x / np.pi) * 2.0 + 0.25)
    else:
        n = int((x / np.pi) * 2.0 - 0.25)

    if n == 0:
        return '0'
    elif n == 1:
        return r'$\pi/2$'
    elif n == 2:
        return r'$\pi$'
    elif n == -1:
        return r'$-\pi/2$'
    elif n == -2:
        return r'$-\pi$'
    elif n % 2 == 0:
        return fr'${n//2}\pi$'
    else:
        return fr'${n}\pi/2$'


class BasicUnitConverter(units.ConversionInterface):
    """Matplotlib conversion interface for BasicUnit / TaggedValue data."""

    @staticmethod
    def axisinfo(unit, axis):
        """Return AxisInfo instance for x and unit."""

        if unit == radians:
            return units.AxisInfo(
                majloc=ticker.MultipleLocator(base=np.pi/2),
                majfmt=ticker.FuncFormatter(rad_fn),
                label=unit.fullname,
            )
        elif unit == degrees:
            return units.AxisInfo(
                majloc=ticker.AutoLocator(),
                majfmt=ticker.FormatStrFormatter(r'$%i^\circ$'),
                label=unit.fullname,
            )
        elif unit is not None:
            if hasattr(unit, 'fullname'):
                return units.AxisInfo(label=unit.fullname)
            elif hasattr(unit, 'unit'):
                return units.AxisInfo(label=unit.unit.fullname)
        return None

    @staticmethod
    def convert(val, unit, axis):
        """Convert *val* into plain numbers expressed in *unit*.

        Number-like input is returned unchanged.  Iterables become a float
        array with masked entries set to NaN; entries lacking ``convert_to``
        are stored as-is.
        """
        # NOTE(review): ConversionInterface.is_numlike may be deprecated in
        # newer Matplotlib releases — confirm against the installed version.
        if units.ConversionInterface.is_numlike(val):
            return val
        if np.iterable(val):
            if isinstance(val, np.ma.MaskedArray):
                val = val.astype(float).filled(np.nan)
            out = np.empty(len(val))
            for i, thisval in enumerate(val):
                if np.ma.is_masked(thisval):
                    out[i] = np.nan
                else:
                    try:
                        out[i] = thisval.convert_to(unit).get_value()
                    except AttributeError:
                        out[i] = thisval
            return out
        if np.ma.is_masked(val):
            return np.nan
        else:
            return val.convert_to(unit).get_value()

    @staticmethod
    def default_units(x, axis):
        """Return the default unit for x or None."""
        if np.iterable(x):
            for thisx in x:
                return thisx.unit
        # Empty iterables (other than unit-carrying ones) and objects without
        # a ``unit`` attribute have no default unit.
        return getattr(x, 'unit', None)


def cos(x):
    """Return the cosine of tagged angle(s) *x*, converted to radians first.

    An iterable *x* yields a list of floats; otherwise a single float.
    """
    if np.iterable(x):
        return [math.cos(val.convert_to(radians).get_value()) for val in x]
    else:
        return math.cos(x.convert_to(radians).get_value())
# Register a single shared BasicUnitConverter instance for both BasicUnit and
# TaggedValue in Matplotlib's unit registry, then drop the temporary name so
# no extra module global is left behind.
_converter = BasicUnitConverter()
units.registry[BasicUnit] = _converter
units.registry[TaggedValue] = _converter
del _converter