# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from collections import Counter, namedtuple
from functools import partial, reduce
from itertools import chain, product
import operator as op
import string
from typing import Callable, Dict, Optional, Sequence, Union, Tuple

import numpy as np

from .. import core, dtypes
from ..tree_util import tree_unflatten
from ..core import ShapedArray, Trace, Tracer
from .._src.util import safe_map, safe_zip, unzip2, prod, wrap_name
from .. import linear_util as lu

map = safe_map
zip = safe_zip

# Registry from primitive to masking rule. `MaskTrace.process_primitive` calls
# a rule as `rule(padded_vals, logical_shapes, **params)`.
masking_rules: Dict[core.Primitive, Callable] = {}


def defvectorized(prim):
  """Registers `vectorized_masking_rule` as the masking rule of `prim`."""
  masking_rules[prim] = partial(vectorized_masking_rule, prim)


def defnaryop(prim):
  """Registers `naryop_masking_rule` as the masking rule of `prim`."""
  masking_rules[prim] = partial(naryop_masking_rule, prim)


def vectorized_masking_rule(prim, padded_vals, logical_shapes, **params):
  """Masking rule for a single-operand primitive applied to the padded value.

  The logical shapes are ignored; `prim` is bound directly on the one padded
  operand with the original params.
  """
  del logical_shapes  # Unused.
  padded_val, = padded_vals
  return prim.bind(padded_val, **params)


def naryop_masking_rule(prim, padded_vals, logical_shapes):
  """Masking rule that binds `prim` directly on all padded operands.

  The logical shapes are ignored. The rule accepts no params, so it only suits
  primitives that are bound without any.
  """
  del logical_shapes  # Unused.
  return prim.bind(*padded_vals)


# The shape environments currently in scope: `logical` and `padded` each map a
# shape variable to its value. `extend_shape_envs` rebinds (never mutates) this.
ShapeEnvs = namedtuple("ShapeEnvs", ["logical", "padded"])
shape_envs = ShapeEnvs({}, {})  # TODO(mattjj): make this a stack for efficiency


def is_tracing():
  """Returns True iff a non-empty padded shape environment is in scope."""
  return bool(shape_envs.padded)


@contextmanager
def extend_shape_envs(logical_env, padded_env):
  """Context manager that extends the global shape environments.

  Entries of `logical_env` / `padded_env` override existing entries with the
  same key. The previous environments are restored on exit, also when the
  body raises.
  """
  global shape_envs
  new_logical = dict(chain(shape_envs.logical.items(), logical_env.items()))
  new_padded = dict(chain(shape_envs.padded.items(), padded_env.items()))
  shape_envs, prev = ShapeEnvs(new_logical, new_padded), shape_envs
  try:
    yield
  finally:
    shape_envs = prev


def shape_as_value(shape):
  """Evaluates `shape` in the current logical environment.

  Polymorphic shapes may only be evaluated while tracing (asserted).
  """
  assert is_tracing() or not is_polymorphic(shape)
  return eval_poly_shape(shape, shape_envs.logical)


def padded_shape_as_value(shape):
  """Evaluates `shape` in the current padded environment.

  Polymorphic shapes may only be evaluated while tracing (asserted).
  """
  assert is_tracing() or not is_polymorphic(shape)
  return eval_poly_shape(shape, shape_envs.padded)


def mask_fun(fun, logical_env, padded_env, in_vals, polymorphic_shapes):
  """Runs `fun` on `in_vals` under a fresh `MaskTrace`.

  The logical shape-variable values are prepended to the inputs, ordered by
  sorted shape variable; `mask_subtrace` splits them off again.

  Returns:
    A pair `(out_vals, out_shapes)` of the output values and their
    polymorphic shapes.
  """
  env_keys, padded_env_vals = unzip2(sorted(padded_env.items()))
  logical_env_vals = [logical_env[k] for k in env_keys]
  # Make padded_env hashable
  padded_env = (env_keys, padded_env_vals)
  with core.new_main(MaskTrace) as main:
    fun, out_shapes = mask_subtrace(fun, main, polymorphic_shapes, padded_env)
    # Splat both sequences so `in_vals` may be a list or a tuple.
    out_vals = fun.call_wrapped(*logical_env_vals, *in_vals)
    # Drop our reference to the main trace before the context exits
    # (presumably so `core.new_main` can release it -- TODO confirm).
    del main
  return out_vals, out_shapes()


@lu.transformation_with_aux
def mask_subtrace(main, shapes, padded_env, *in_vals):
  """Wraps a function so it runs on `MaskTracer`s under extended shape envs.

  The leading `len(env_keys)` inputs are the logical shape-variable values;
  `padded_env` is the hashable pair `(env_keys, padded_env_vals)`. The
  auxiliary output is the list of output polymorphic shapes.
  """
  env_keys, _ = padded_env
  logical_env_vals, in_vals = in_vals[:len(env_keys)], in_vals[len(env_keys):]
  logical_env = dict(zip(env_keys, logical_env_vals))
  padded_env = dict(zip(*padded_env))
  trace = MaskTrace(main, core.cur_sublevel())
  in_tracers = [MaskTracer(trace, x, s).full_lower()
                for x, s in zip(in_vals, shapes)]
  with extend_shape_envs(logical_env, padded_env):
    outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_shapes = unzip2((t.val, t.polymorphic_shape) for t in out_tracers)
  yield out_vals, out_shapes


def eval_poly_shape(shape, values_dict):
  """Evaluates every dimension of `shape` with `eval_poly`."""
  return tuple(eval_poly(dim, values_dict) for dim in shape)


def eval_poly(poly, values_dict):
  """Evaluates `poly` in `values_dict` if it is a `Poly`; returns it otherwise."""
  return poly.evaluate(values_dict) if type(poly) is Poly else poly


def _ensure_poly(p: 'Size') -> 'Poly':
  """Wraps a plain integer `p` as a constant `Poly`; returns Polys unchanged."""
  if isinstance(p, Poly): return p
  return Poly({Mon(): p})


def _polys_to_ints(shape):
  """Replaces the constant `Poly` dimensions of `shape` by Python ints."""
  return tuple(int(d) if type(d) is Poly and d.is_constant else d
               for d in shape)


def is_polymorphic(shape: Sequence['Size']):
  """Returns True iff some dimension of `shape` is a `Poly` (even a constant one)."""
  return any(map(lambda d: type(d) is Poly, shape))


class UndefinedPoly(Exception):
  """Exception raised when an operation involving polynomials is not defined.

  An operation `op` on polynomials `p1` and `p2` either raises this exception,
  or produces a polynomial `res`, such that `op(Val(p1), Val(p2)) = Val(res)`,
  for any `Val`, a non-negative integer valuation of the shape variables.
  """
  pass


class Poly(dict):
  """Polynomial with integer coefficients for polymorphic shapes.

  The shape variables are assumed to range over non-negative integers.

  We overload integer operations, but we do that soundly, raising
  :class:`UndefinedPoly` when the result is not representable as a polynomial.

  The representation of a polynomial is as a dictionary mapping monomials to
  integer coefficients. The special monomial `Mon()` is mapped to the
  free integer coefficient of the polynomial.
  """

  def __init__(self, coeffs: Dict['Mon', int]):
    # Makes sure Polynomials are always in canonical form: zero coefficients
    # are dropped, coefficients go through `operator.index`, and the zero
    # polynomial is represented as {Mon(): 0}.
    coeffs = {mon: op.index(coeff)
              for mon, coeff in coeffs.items() if coeff != 0}
    coeffs = coeffs or {Mon(): 0}
    super().__init__(coeffs)

  def __hash__(self):
    # Hash the sorted canonical terms so that equal Polys hash equally.
    return hash(tuple(sorted(self.items())))

  def __add__(self, other: 'Size') -> 'Poly':
    coeffs = self.copy()
    for mon, coeff in _ensure_poly(other).items():
      coeffs[mon] = coeffs.get(mon, 0) + coeff
    return Poly(coeffs)

  def __sub__(self, other: 'Size') -> 'Poly':
    return self + -other

  def __neg__(self) -> 'Poly':
    return Poly({mon: -coeff for mon, coeff in self.items()})

  def __mul__(self, other: 'Size') -> 'Poly':
    other = _ensure_poly(other)
    coeffs: Dict[Mon, int] = {}
    # Distribute: every term of `self` times every term of `other`.
    for (mon1, coeff1), (mon2, coeff2) in product(self.items(), other.items()):
      mon = mon1 * mon2
      coeffs[mon] = coeffs.get(mon, 0) + coeff1 * coeff2
    return Poly(coeffs)

  def __rmul__(self, other: 'Size') -> 'Poly':
    return self * other  # multiplication commutes

  def __radd__(self, other: 'Size') -> 'Poly':
    return self + other  # addition commutes

  def __rsub__(self, other: 'Size') -> 'Poly':
    return _ensure_poly(other) - self

  def __floordiv__(self, divisor: 'Size') -> 'Poly':
    q, _ = divmod(self, divisor)  # type: ignore
    return q

  def __mod__(self, divisor: 'Size') -> int:
    _, r = divmod(self, divisor)  # type: ignore
    return r

  def __divmod__(self, divisor: 'Size') -> Tuple['Poly', int]:
    """
    Floor division with remainder (divmod) generalized to polynomials. To allow
    ensuring '0 <= remainder < divisor' for consistency with integer divmod, the
    divisor must divide the dividend (up to a constant for constant divisors).
    :return: Quotient resulting from polynomial division and integer remainder.
    """
    divisor = _ensure_poly(divisor)
    dmon, dcount = divisor._leading_term
    dividend, quotient, remainder = self, _ensure_poly(0), _ensure_poly(0)
    while not dividend.is_constant or dividend != 0:  # invariant: dividend == divisor*quotient + remainder
      mon, count = dividend._leading_term
      qcount, rcount = divmod(count, dcount)
      try:
        qmon = mon // dmon
      except UndefinedPoly:
        raise UndefinedPoly(f"Stride {divisor} must divide size {self} "
                            "(up to a constant for constant divisors).")
      r = Poly({mon: rcount})
      q = Poly({qmon: qcount})
      quotient += q
      remainder += r
      dividend -= q * divisor + r
    # int() raises UndefinedPoly if a non-constant remainder accumulated.
    return quotient, int(remainder)

  def __rdivmod__(self, dividend: 'Size') -> Tuple['Poly', int]:
    return divmod(_ensure_poly(dividend), self)  # type: ignore

  def __eq__(self, other):
    """Decides `self == other` soundly for all non-negative valuations.

    Raises:
      UndefinedPoly: if the bounds of `self - other` cannot decide it.
    """
    lb, ub = (self - other).bounds()
    if lb == ub == 0:
      return True
    if lb is not None and lb > 0:
      return False
    if ub is not None and ub < 0:
      return False
    raise UndefinedPoly(f"Polynomial comparison {self} == {other} is inconclusive")

  def __ne__(self, other):
    return not self == other

  def __ge__(self, other: 'Size'):
    """Decides `self >= other` soundly; raises UndefinedPoly if inconclusive."""
    lb, ub = (self - other).bounds()
    if lb is not None and lb >= 0:
      return True
    if ub is not None and ub < 0:
      return False
    raise UndefinedPoly(f"Polynomial comparison {self} >= {other} is inconclusive")

  def __le__(self, other: 'Size'):
    return _ensure_poly(other) >= self

  def __lt__(self, other: 'Size'):
    return not (self >= other)

  def __gt__(self, other: 'Size'):
    return not (_ensure_poly(other) >= self)

  def __str__(self):
    # Highest-order terms first; the constant term (printed as "c ") comes
    # last, so the trailing space is stripped.
    return ' + '.join(f'{c} {mon}' if c != 1 or mon.degree == 0 else str(mon)
                      for mon, c in sorted(self.items(), reverse=True)).strip()

  def __repr__(self):
    return str(self)

  def __int__(self):
    if self.is_constant:
      return op.index(next(iter(self.values())))
    else:
      raise UndefinedPoly(f"Polynomial {self} is not constant")

  def bounds(self) -> Tuple[Optional[int], Optional[int]]:
    """Returns conservative (lower, upper) bounds; None means unbounded.

    Variables are non-negative, so a positive non-constant term removes the
    upper bound and a negative one removes the lower bound.
    """
    lb = ub = self.get(Mon(), 0)
    for mon, coeff in self.items():
      if mon.degree > 0:
        if coeff > 0:
          ub = None
        else:
          lb = None
    return lb, ub

  def evaluate(self, env):
    """Evaluates the polynomial with each variable's value looked up in `env`."""
    product_of = lambda xs: reduce(op.mul, xs) if xs else 1
    terms = [mul(coeff, product_of([pow(env[var], deg) for var, deg in mon.items()]))
             for mon, coeff in self.items()]
    return sum(terms) if len(terms) > 1 else terms[0]

  @property
  def is_constant(self):
    return len(self) == 1 and next(iter(self)).degree == 0

  @property
  def _leading_term(self) -> Tuple['Mon', int]:
    """Returns the highest degree term that comes first lexicographically."""
    return max(self.items())


Size = Union[int, Poly]


def pow(x, deg):
  """Computes `x ** deg`, skipping the power when `deg` is 0 or 1.

  The shortcuts only apply when `int(deg)` succeeds; otherwise `x ** deg` is
  returned as is.
  """
  try:
    deg = int(deg)
  except Exception:
    return x ** deg
  else:
    return 1 if deg == 0 else x if deg == 1 else x ** deg


def mul(coeff, mon):
  """Computes `coeff * mon`, skipping the product when `coeff` is 0 or 1.

  The shortcuts only apply when `int(coeff)` succeeds; otherwise
  `coeff * mon` is returned as is.
  """
  try:
    coeff = int(coeff)
  except Exception:
    return coeff * mon
  else:
    return 0 if coeff == 0 else mon if coeff == 1 else coeff * mon


# Register Poly with core as a dimension type (presumably so shapes such as
# ShapedArray's may contain Polys -- TODO confirm against core).
core._DIMENSION_TYPES.add(Poly)


class Mon(dict):
  # TODO: move this before Poly in the file
  """Represents a multivariate monomial, such as n^3 * m.

  The representation is a dictionary mapping var:exponent. The
  exponent is >= 1.
  """
  def __hash__(self):
    return hash(frozenset(self.items()))

  def __str__(self):
    return ' '.join(f'{key}^{exponent}' if exponent != 1 else str(key)
                    for key, exponent in sorted(self.items()))

  def __lt__(self, other: 'Mon'):
    """
    Comparison to another monomial in graded reverse lexicographic order.
    """
    # TODO: do not override __lt__ for this
    self_key = -self.degree, tuple(sorted(self))
    other_key = -other.degree, tuple(sorted(other))
    return self_key > other_key

  def __mul__(self, other: 'Mon') -> 'Mon':
    """
    Returns the product with another monomial. Example: (n^2*m) * n == n^3 * m.
    """
    return Mon(Counter(self) + Counter(other))

  @property
  def degree(self):
    return sum(self.values())

  def __floordiv__(self, divisor: 'Mon') -> 'Mon':
    """
    Divides by another monomial. Raises UndefinedPoly if impossible.
    For example, (n^3 * m) // n == n^2*m, but n // m fails.
    """
    d = Counter(self)
    for key, exponent in divisor.items():
      diff = self.get(key, 0) - exponent
      if diff < 0: raise UndefinedPoly(f"Cannot divide {self} by {divisor}.")
      elif diff == 0: del d[key]
      elif diff > 0: d[key] = diff
    return Mon(d)


class ShapeError(Exception):
  """Raised when concrete shapes do not match their shape specs."""
  pass


class ShapeSyntaxError(Exception):
  """Raised when a shape annotation string cannot be parsed."""
  pass

# To denote some shape expressions (for annotations) we use a small language.
#
#   data ShapeSpec = ShapeSpec [Dim]
#   data Dim = Id PyObj
#            | Lit Int
#            | Mul Dim Dim
#            | Add Dim Dim
#            | MonomorphicDim
#
# We'll also make a simple concrete syntax for annotation. The grammar is
#
#   shape_spec ::= '(' dims ')'
#   dims ::= dim ',' dims | ''
#   dim ::= str | int | dim '*' dim | dim '+' dim | '_'
#
# ShapeSpecs can have some monomorphic dims inside them, which must be replaced
# with concrete shapes when known.

class ShapeSpec(tuple):
  """Tuple of parsed dims (ints, Polys, or `_monomorphic_dim`)."""
  def __str__(self):
    return 'ShapeSpec({})'.format(', '.join(map(str, self)))


def finalize_spec(polymorphic_shape, padded_shape):
  """Replaces each `_` dim of `polymorphic_shape` with the matching padded dim."""
  # TODO: what if polymorphic_shape has a constant that does not match padded_shape?
  return tuple(_parse_lit(d) if e is _monomorphic_dim else e
               for e, d in zip(polymorphic_shape, padded_shape))


def parse_spec(spec=''):
  """Parses a shape annotation string into a `ShapeSpec`.

  Accepts a parenthesized list such as '(m, n*2, _)' or a bare comma-separated
  list such as 'm, n'. Spaces and leading/trailing commas are ignored. An empty
  string or '()' denotes a rank-0 shape.

  Raises:
    ShapeSyntaxError: if `spec` is malformed.
  """
  if not spec:
    return ShapeSpec(())
  if spec[0] == '(':
    if spec[-1] != ')': raise ShapeSyntaxError(spec)
    spec = spec[1:-1]
  spec = spec.replace(' ', '').strip(',')
  # An empty dim list (e.g. '()') is a rank-0 shape, not one empty dim.
  if not spec:
    return ShapeSpec(())
  dims = map(_parse_dim, spec.split(','))
  return ShapeSpec(dims)


def _parse_dim(spec):
  """Parses one dim: an int, a lowercase identifier, '_', or a sum/product.

  '+' binds looser than '*', since sums are split first.

  Raises:
    ShapeSyntaxError: if `spec` is not a valid dim (including the empty string).
  """
  if '+' in spec:
    # Builtin `sum` keeps all-literal sums as Python ints (np.sum would give a
    # NumPy scalar and build an object array of Polys).
    return sum(map(_parse_dim, spec.split('+')))
  elif '*' in spec:
    return prod(map(_parse_dim, spec.split('*')))
  elif spec.isdigit() or spec.startswith('-') and spec[1:].isdigit():
    return _parse_lit(spec)
  elif spec and spec[0] in _identifiers:
    return _parse_id(spec)
  elif spec == '_':
    return _monomorphic_dim
  else:
    raise ShapeSyntaxError(spec)

_identifiers = frozenset(string.ascii_lowercase)

def _parse_id(name): return Poly({Mon({name: 1}): 1})

def _parse_lit(val_str): return int(val_str)


class MonomorphicDim(object):
  """Marker for a `_` dim; `finalize_spec` replaces it with the padded dim."""
  def __str__(self): return '_'

_monomorphic_dim = MonomorphicDim()

# Two convenient ways to provide shape annotations:
#   1. '(m, n)'
#   2. s_['m', 'n']
class S_(object):
  """Index-syntax sugar for shape specs, e.g. `s_['m', 'n']`.

  A tuple index is joined into a parenthesized spec such as '(m,n)'; any other
  index is converted with `str`. The result is parsed with `parse_spec`.
  """
  def __getitem__(self, idx):
    return parse_spec(('(' + ','.join(map(str, idx)) + ')')
                      if type(idx) is tuple else str(idx))

s_ = S_()


def _shape_spec_consistent(spec, expr):
  """Returns True iff `expr` equals `spec` on every non-`_` dimension."""
  return all(a == b for a, b in zip(spec, expr) if a is not _monomorphic_dim)


class MaskTracer(Tracer):
  """Tracer pairing a padded value with its polymorphic shape."""
  __slots__ = ["val", "polymorphic_shape"]

  def __init__(self, trace, val, polymorphic_shape):
    super().__init__(trace)
    self.val = val  # the underlying (padded) value
    self.polymorphic_shape = polymorphic_shape

  @property
  def aval(self):
    # The abstract value is built from the polymorphic, not the padded, shape.
    return ShapedArray(self.polymorphic_shape, self.dtype)

  @property
  def dtype(self):
    return dtypes.dtype(self.val)

  def is_pure(self):
    """True iff no dimension of the shape is a non-constant `Poly`."""
    return all(type(poly) is not Poly or poly.is_constant
               for poly in self.polymorphic_shape)

  def full_lower(self):
    # With a fully concrete shape there is nothing to mask: unwrap the value.
    if self.is_pure():
      return core.full_lower(self.val)
    else:
      return self


class MaskTrace(Trace):
  """Trace that evaluates primitives on padded values via `masking_rules`."""

  def pure(self, val):
    return MaskTracer(self, val, np.shape(val))

  def lift(self, val):
    return MaskTracer(self, val, np.shape(val))

  def sublift(self, val):
    return MaskTracer(self, val.val, val.polymorphic_shape)

  def process_primitive(self, primitive, tracers, params):
    """Applies the registered masking rule of `primitive` to the padded values.

    Output shapes come from `primitive.abstract_eval` on the polymorphic
    avals, with constant Poly dims turned into ints.

    Raises:
      NotImplementedError: if no masking rule is registered for `primitive`.
    """
    masking_rule = masking_rules.get(primitive)
    if masking_rule is None:
      raise NotImplementedError(
        f'Masking rule for {primitive} not implemented yet.')
    out_aval = primitive.abstract_eval(*(t.aval for t in tracers), **params)
    vals, polymorphic_shapes = unzip2((t.val, t.polymorphic_shape) for t in tracers)
    logical_shapes = map(shape_as_value, polymorphic_shapes)
    # TODO(mattjj): generalize mask rule signature
    if primitive.name == 'reshape':
      # Build a new dict rather than mutating the caller-owned `params`.
      params = dict(params, polymorphic_shapes=polymorphic_shapes)
    out = masking_rule(vals, logical_shapes, **params)
    if primitive.multiple_results:
      out_shapes = map(_polys_to_ints, [o.shape for o in out_aval])
      return map(partial(MaskTracer, self), out, out_shapes)
    else:
      return MaskTracer(self, out, _polys_to_ints(out_aval.shape))

  def process_call(self, call_primitive, f, tracers, params):
    """Masks a call primitive.

    If no input shape is polymorphic, the call is bound on the raw values.
    Otherwise `f` is wrapped with `mask_subtrace`, and the logical env values
    are passed as extra leading operands. When present, `donated_invars` is
    padded with False for those operands.
    """
    assert call_primitive.multiple_results
    params = dict(params, name=wrap_name(params.get('name', f.__name__), 'mask'))
    vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in tracers)
    if not any(is_polymorphic(s) for s in shapes):
      return call_primitive.bind(f, *vals, **params)
    else:
      logical_env, padded_env = shape_envs
      env_keys, padded_env_vals = unzip2(sorted(padded_env.items()))
      logical_env_vals = tuple(logical_env[k] for k in env_keys)
      # Make padded_env hashable
      padded_env = (env_keys, padded_env_vals)
      f, shapes_out = mask_subtrace(f, self.main, shapes, padded_env)
      if 'donated_invars' in params:
        params = dict(params, donated_invars=((False,) * len(logical_env_vals) +
                                              params['donated_invars']))
      vals_out = call_primitive.bind(f, *(logical_env_vals + vals), **params)
      return [MaskTracer(self, v, s) for v, s in zip(vals_out, shapes_out())]

  def post_process_call(self, call_primitive, out_tracers, params):
    """Returns the raw output values and a `todo` that rewraps them.

    `todo` creates MaskTracers with the same shapes in a new MaskTrace for
    this trace's main at the sublevel current when `todo` runs.
    """
    vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = MaskTrace(main, core.cur_sublevel())
      return map(partial(MaskTracer, trace), vals, shapes)
    return vals, todo


class UniqueId:
  """Named identifier compared by identity; ordered and printed by its name."""
  def __init__(self, name):
    self.name = name

  def __repr__(self):
    return self.name

  def __lt__(self, other):
    return self.name < other.name


class UniqueIds(dict):
  """Dict that creates and caches a new `UniqueId` for each missing key."""
  def __missing__(self, key):
    unique_id = UniqueId(key)
    self[key] = unique_id
    return unique_id


def remap_ids(names, shape_spec):
  """Returns `shape_spec` with each Poly's variables renamed via `names[var]`.

  Non-Poly dims are kept unchanged.
  """
  return ShapeSpec(Poly({Mon({names[var]: deg for var, deg in mon.items()})
                         : coeff for mon, coeff in poly.items()})
                   if isinstance(poly, Poly) else
                   poly for poly in shape_spec)
def bind_shapes(polymorphic_shapes, padded_shapes):
  """Solves for shape-variable values by matching specs against padded shapes.

  Each dimension of a polymorphic shape is either
    * an int or constant `Poly`, which must equal the padded dimension, or
    * a linear `Poly` `c * v + k` in a single variable `v`, from which
      `v = (d - k) // c` is derived for padded dimension `d`.

  Returns:
    A dict mapping each shape variable to its value.

  Raises:
    ShapeError: if a constant dimension does not match; a dimension is not
      linear in a single variable; `d - k` is not divisible by `c`; or one
      variable receives conflicting values.
  """
  env = {}
  for polymorphic_shape, padded_shape in zip(polymorphic_shapes, padded_shapes):
    for poly, dim in zip(polymorphic_shape, padded_shape):
      if type(poly) is not Poly or poly.is_constant:
        if int(poly) != dim:
          raise ShapeError(f"Dimension {dim} does not match constant {poly}.")
        continue
      terms = dict(poly)
      const_coeff = terms.pop(Mon(), 0)
      if len(terms) != 1:
        raise ShapeError(f"Cannot bind non-linear dimension {poly}.")
      (mon, linear_coeff), = terms.items()
      if len(mon) != 1:
        raise ShapeError(f"Cannot bind multi-variable dimension {poly}.")
      (var, exponent), = mon.items()
      if exponent != 1:
        raise ShapeError(f"Cannot bind non-linear dimension {poly}.")
      value, rem = divmod(dim - const_coeff, linear_coeff)
      # Raise explicitly: an `assert` would be stripped under `python -O` and
      # silently bind a wrong value.
      if rem != 0:
        raise ShapeError(f"Dimension {dim} is not attainable by {poly}.")
      if env.setdefault(var, value) != value:
        raise ShapeError(f"Shape variable {var} bound to both {env[var]} "
                         f"and {value}.")
  return env


def check_shapes(specs, spec_tree, shapes, tree, message_prefix="Output"):
  """Checks that flat `shapes` (with pytree `tree`) conform to flat `specs`.

  Raises:
    ShapeError: if the pytree structures differ or some shape disagrees with
      its spec on a non-`_` dimension. The message shows both as pytrees.
  """
  if spec_tree != tree or not all(map(_shape_spec_consistent, specs, shapes)):
    specs = tree_unflatten(spec_tree, specs)
    shapes = tree_unflatten(tree, shapes)
    raise ShapeError(f"{message_prefix} shapes should be {specs} but are {shapes}.")