1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from contextlib import contextmanager
16from collections import Counter, namedtuple
17from functools import partial, reduce
18from itertools import chain, product
19import operator as op
20import string
21from typing import Callable, Dict, Optional, Sequence, Union, Tuple
22
23import numpy as np
24
25from .. import core, dtypes
26from ..tree_util import tree_unflatten
27from ..core import ShapedArray, Trace, Tracer
28from .._src.util import safe_map, safe_zip, unzip2, prod, wrap_name
29from .. import linear_util as lu
30
# Shadow the builtins with the `safe_map`/`safe_zip` helpers imported above, so
# every `map`/`zip` call in this module goes through them.
map = safe_map
zip = safe_zip

# Registry mapping each primitive to its masking rule. It is populated by
# `defvectorized`/`defnaryop` and consulted by `MaskTrace.process_primitive`.
masking_rules: Dict[core.Primitive, Callable] = {}
35
def defvectorized(prim):
  """Register `vectorized_masking_rule` as the masking rule of `prim`."""
  rule = partial(vectorized_masking_rule, prim)
  masking_rules[prim] = rule
38
def defnaryop(prim):
  """Register `naryop_masking_rule` as the masking rule of `prim`."""
  rule = partial(naryop_masking_rule, prim)
  masking_rules[prim] = rule
41
def vectorized_masking_rule(prim, padded_vals, logical_shapes, **params):
  """Masking rule for a single-operand primitive: bind it on the padded value.

  `padded_vals` must contain exactly one value (unpacking fails otherwise);
  the logical shapes are ignored and `params` are forwarded to `prim.bind`.
  """
  del logical_shapes  # Unused.
  (operand,) = padded_vals
  return prim.bind(operand, **params)
46
def naryop_masking_rule(prim, padded_vals, logical_shapes):
  """Masking rule for an n-ary primitive: bind it on all padded operands.

  The logical shapes are ignored; no extra parameters are accepted.
  """
  del logical_shapes  # Unused.
  operands = padded_vals
  return prim.bind(*operands)
50
# Pair of environments used to evaluate polymorphic shapes: `logical` is read
# by `shape_as_value` and `padded` by `padded_shape_as_value`.
ShapeEnvs = namedtuple("ShapeEnvs", ["logical", "padded"])
# Module-global current environments; `extend_shape_envs` swaps in an extended
# pair for the duration of its `with` block and restores the previous one.
shape_envs = ShapeEnvs({}, {})  # TODO(mattjj): make this a stack for efficiency
53
def is_tracing():
  """Return True when the current padded shape environment is non-empty."""
  padded_env = shape_envs.padded
  return bool(padded_env)
56
@contextmanager
def extend_shape_envs(logical_env, padded_env):
  """Context manager that layers new bindings on top of `shape_envs`.

  Inside the `with` block the global `shape_envs` holds the previous bindings
  updated with `logical_env` / `padded_env` (new entries win on key clashes).
  The previous environments are restored on exit, even if the body raises.
  """
  global shape_envs
  outer = shape_envs
  logical = dict(outer.logical.items())
  logical.update(logical_env.items())
  padded = dict(outer.padded.items())
  padded.update(padded_env.items())
  shape_envs = ShapeEnvs(logical, padded)
  try:
    yield
  finally:
    shape_envs = outer
67
def shape_as_value(shape):
  """Evaluate `shape` against the current logical shape environment.

  Asserts that a polymorphic shape is only evaluated while tracing.
  """
  assert is_tracing() or not is_polymorphic(shape)
  logical_env = shape_envs.logical
  return eval_poly_shape(shape, logical_env)
71
def padded_shape_as_value(shape):
  """Evaluate `shape` against the current padded shape environment.

  Asserts that a polymorphic shape is only evaluated while tracing.
  """
  assert is_tracing() or not is_polymorphic(shape)
  padded_env = shape_envs.padded
  return eval_poly_shape(shape, padded_env)
75
def mask_fun(fun, logical_env, padded_env, in_vals, polymorphic_shapes):
  """Run `fun` under a fresh `MaskTrace` main trace.

  Args:
    fun: wrapped function; it is transformed with `mask_subtrace` and invoked
      through `call_wrapped`.
    logical_env: mapping from shape-variable key to its logical value. Every
      key of `padded_env` must be present.
    padded_env: mapping from shape-variable key to its padded value.
    in_vals: sequence of input values (any iterable; list or tuple).
    polymorphic_shapes: one polymorphic shape per entry of `in_vals`.

  Returns:
    A pair `(out_vals, out_shapes)` with the outputs of the traced call and
    the polymorphic output shapes reported by `mask_subtrace`.
  """
  env_keys, padded_env_vals = unzip2(sorted(padded_env.items()))
  logical_env_vals = [logical_env[k] for k in env_keys]
  # Make padded_env hashable
  padded_env = (env_keys, padded_env_vals)
  with core.new_main(MaskTrace) as main:
    fun, out_shapes = mask_subtrace(fun, main, polymorphic_shapes, padded_env)
    # Unpack both sequences instead of concatenating them: `list + tuple`
    # raises TypeError, so concatenation only worked when in_vals was a list.
    out_vals = fun.call_wrapped(*logical_env_vals, *in_vals)
    del main
  return out_vals, out_shapes()
86
@lu.transformation_with_aux
def mask_subtrace(main, shapes, padded_env, *in_vals):
  """Transformation that traces the wrapped function with `MaskTracer`s.

  `padded_env` is a `(keys, padded_values)` pair. The first `len(keys)`
  entries of `in_vals` are the matching logical values; the remaining entries
  are the function's real inputs, paired with `shapes`. Yields the output
  values together with their polymorphic shapes (the auxiliary output).
  """
  env_keys, _ = padded_env
  num_env = len(env_keys)
  env_vals, fun_args = in_vals[:num_env], in_vals[num_env:]
  logical = dict(zip(env_keys, env_vals))
  padded = dict(zip(*padded_env))
  trace = MaskTrace(main, core.cur_sublevel())
  in_tracers = [MaskTracer(trace, arg, shape).full_lower()
                for arg, shape in zip(fun_args, shapes)]
  # The wrapped function runs while the shape environments are extended.
  with extend_shape_envs(logical, padded):
    outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_shapes = unzip2((tracer.val, tracer.polymorphic_shape)
                                for tracer in out_tracers)
  yield out_vals, out_shapes
101
def eval_poly_shape(shape, values_dict):
  """Evaluate every dimension of `shape` with `eval_poly`; returns a tuple."""
  evaluated = (eval_poly(dim, values_dict) for dim in shape)
  return tuple(evaluated)
104
def eval_poly(poly, values_dict):
  """Evaluate `poly` in `values_dict` if it is exactly a `Poly`.

  Any other value (e.g. a plain int) is returned unchanged.
  """
  if type(poly) is Poly:
    return poly.evaluate(values_dict)
  return poly
107
def _ensure_poly(p: 'Size') -> 'Poly':
  """Return `p` itself if it is a `Poly`, else the constant polynomial `p`."""
  return p if isinstance(p, Poly) else Poly({Mon(): p})
111
def _polys_to_ints(shape):
  """Replace constant `Poly` dimensions of `shape` by plain ints.

  Non-constant polynomials and non-`Poly` entries are kept as they are.
  """
  def concretize(dim):
    if type(dim) is Poly and dim.is_constant:
      return int(dim)
    return dim

  return tuple(concretize(dim) for dim in shape)
115
def is_polymorphic(shape: Sequence['Size']):
  """Return True if any dimension of `shape` is exactly a `Poly` instance."""
  return any(type(dim) is Poly for dim in shape)
118
class UndefinedPoly(Exception):
  """Exception raised when an operation involving polynomials is not defined.

  An operation `op` on polynomials `p1` and `p2` either raises this exception
  or produces a polynomial `res` such that `op(Val(p1), Val(p2)) = Val(res)`
  for any `Val`, a non-negative integer valuation of the shape variables.
  """
127
class Poly(dict):
  """Polynomial with integer coefficients for polymorphic shapes.

  The shape variables are assumed to range over non-negative integers.

  We overload integer operations, but we do that soundly, raising
  :class:`UndefinedPoly` when the result is not representable as a polynomial.

  The representation of a polynomial is as a dictionary mapping monomials to
  integer coefficients. The special monomial `Mon()` is mapped to the
  free integer coefficient of the polynomial.
  """

  def __init__(self, coeffs: Dict['Mon', int]):
    """Builds the polynomial from a mapping of monomials to coefficients.

    Zero coefficients are dropped and the remaining ones are converted with
    `operator.index`. If nothing remains, the polynomial is the constant 0,
    stored as `{Mon(): 0}`.
    """
    # Makes sure Polynomials are always in canonical form
    coeffs = {mon: op.index(coeff)
              for mon, coeff in coeffs.items() if coeff != 0}
    coeffs = coeffs or {Mon(): 0}
    super().__init__(coeffs)

  def __hash__(self):
    """Hashes the sorted (monomial, coefficient) pairs."""
    return hash(tuple(sorted(self.items())))

  def __add__(self, other: 'Size') -> 'Poly':
    """Returns the sum, adding coefficients monomial by monomial."""
    coeffs = self.copy()
    for mon, coeff in _ensure_poly(other).items():
      coeffs[mon] = coeffs.get(mon, 0) + coeff
    return Poly(coeffs)

  def __sub__(self, other: 'Size') -> 'Poly':
    """Returns `self - other`, computed as `self + (-other)`."""
    return self + -other

  def __neg__(self) -> 'Poly':
    """Returns the polynomial with every coefficient negated."""
    return Poly({mon: -coeff for mon, coeff in self.items()})

  def __mul__(self, other: 'Size') -> 'Poly':
    """Returns the product, multiplying every pair of terms."""
    other = _ensure_poly(other)
    coeffs: Dict[Mon, int] = {}
    for (mon1, coeff1), (mon2, coeff2) in product(self.items(), other.items()):
      mon = mon1 * mon2
      coeffs[mon] = coeffs.get(mon, 0) + coeff1 * coeff2
    return Poly(coeffs)

  def __rmul__(self, other: 'Size') -> 'Poly':
    return self * other  # multiplication commutes

  def __radd__(self, other: 'Size') -> 'Poly':
    return self + other  # addition commutes

  def __rsub__(self, other: 'Size') -> 'Poly':
    """Returns `other - self` for a left operand that is not a `Poly`."""
    return _ensure_poly(other) - self

  def __floordiv__(self, divisor: 'Size') -> 'Poly':
    """Returns the quotient part of `divmod(self, divisor)`."""
    q, _ = divmod(self, divisor)  # type: ignore
    return q

  def __mod__(self, divisor: 'Size') -> int:
    """Returns the integer remainder part of `divmod(self, divisor)`."""
    _, r = divmod(self, divisor)  # type: ignore
    return r

  def __divmod__(self, divisor: 'Size') -> Tuple['Poly', int]:
    """
    Floor division with remainder (divmod) generalized to polynomials. To allow
    ensuring '0 <= remainder < divisor' for consistency with integer divmod, the
    divisor must divide the dividend (up to a constant for constant divisors).
    :return: Quotient resulting from polynomial division and integer remainder.
    :raises UndefinedPoly: if a leading monomial of the dividend is not
      divisible by the divisor's leading monomial, or if the accumulated
      remainder is not a constant (raised by the final `int(remainder)`).
    """
    divisor = _ensure_poly(divisor)
    dmon, dcount = divisor._leading_term
    dividend, quotient, remainder = self, _ensure_poly(0), _ensure_poly(0)
    # Long division on the leading term of what is left of the dividend.
    while not dividend.is_constant or dividend != 0:  # invariant: self == divisor*quotient + remainder + dividend
      mon, count = dividend._leading_term
      qcount, rcount = divmod(count, dcount)
      try:
        qmon = mon // dmon
      except UndefinedPoly:
        raise UndefinedPoly(f"Stride {divisor} must divide size {self} "
                            "(up to a constant for constant divisors).")
      r = Poly({mon: rcount})
      q = Poly({qmon: qcount})
      quotient += q
      remainder += r
      dividend -= q * divisor + r
    return quotient, int(remainder)

  def __rdivmod__(self, dividend: 'Size') -> Tuple['Poly', int]:
    """Returns `divmod(dividend, self)` for a non-`Poly` dividend."""
    return divmod(_ensure_poly(dividend), self)  # type: ignore

  def __eq__(self, other):
    """Decides equality from the bounds of `self - other`.

    Returns True when both bounds are 0, False when the bounds exclude 0, and
    raises :class:`UndefinedPoly` when the bounds do not settle the question.
    """
    lb, ub = (self - other).bounds()
    if lb == ub == 0:
      return True
    if lb is not None and lb > 0:
      return False
    if ub is not None and ub < 0:
      return False
    raise UndefinedPoly(f"Polynomial comparison {self} == {other} is inconclusive")

  def __ne__(self, other):
    """Negation of `__eq__`; propagates its :class:`UndefinedPoly`."""
    return not self == other

  def __ge__(self, other: 'Size'):
    """Decides `self >= other` from the bounds of `self - other`.

    Raises :class:`UndefinedPoly` when the bounds are inconclusive.
    """
    lb, ub = (self - other).bounds()
    if lb is not None and lb >= 0:
      return True
    if ub is not None and ub < 0:
      return False
    raise UndefinedPoly(f"Polynomial comparison {self} >= {other} is inconclusive")

  def __le__(self, other: 'Size'):
    """`self <= other`, expressed as `other >= self`."""
    return _ensure_poly(other) >= self

  def __lt__(self, other: 'Size'):
    """`self < other`, expressed as `not (self >= other)`."""
    return not (self >= other)

  def __gt__(self, other: 'Size'):
    """`self > other`, expressed as `not (other >= self)`."""
    return not (_ensure_poly(other) >= self)

  def __str__(self):
    """Formats the terms joined by ' + ', in descending sorted order.

    A coefficient of 1 is omitted for monomials of positive degree.
    """
    return ' + '.join(f'{c} {mon}' if c != 1 or mon.degree == 0 else str(mon)
                      for mon, c in sorted(self.items(), reverse=True)).strip()

  def __repr__(self):
    return str(self)

  def __int__(self):
    """Returns the constant value; raises UndefinedPoly if not constant."""
    if self.is_constant:
      return op.index(next(iter(self.values())))
    else:
      raise UndefinedPoly(f"Polynomial {self} is not constant")

  def bounds(self) -> Tuple[Optional[int], Optional[int]]:
    """Returns the lower and upper bounds, if defined."""
    # Both bounds start at the constant term. A non-constant term with a
    # positive coefficient removes the upper bound; any other non-constant
    # term removes the lower bound.
    lb = ub = self.get(Mon(), 0)
    for mon, coeff in self.items():
      if mon.degree > 0:
        if coeff > 0:
          ub = None
        else:
          lb = None
    return lb, ub

  def evaluate(self, env):
    """Evaluates the polynomial, looking each variable up in `env`.

    Terms are built with the module-level `mul` and `pow` helpers. A single
    term is returned directly rather than passed through `sum`.
    """
    prod = lambda xs: reduce(op.mul, xs) if xs else 1
    terms = [mul(coeff, prod([pow(env[id], deg) for id, deg in mon.items()]))
             for mon, coeff in self.items()]
    return sum(terms) if len(terms) > 1 else terms[0]

  @property
  def is_constant(self):
    """True when the only term is the degree-0 (constant) monomial."""
    return len(self) == 1 and next(iter(self)).degree == 0

  @property
  def _leading_term(self) -> Tuple['Mon', int]:
    """Returns the highest degree term that comes first lexicographically."""
    return max(self.items())
284
# A dimension size: either a concrete int or a symbolic `Poly`.
Size = Union[int, Poly]
286
def pow(x, deg):
  """Return `x ** deg`, short-circuiting the exponents 0 and 1.

  If `deg` converts to a Python int, exponent 0 yields the literal `1` and
  exponent 1 yields `x` itself, so no power operation is emitted for them.
  If the conversion fails (presumably because `deg` is an abstract value;
  verify against callers), the generic `x ** deg` is returned.

  Only `Exception` subclasses are treated as a failed conversion, so
  `KeyboardInterrupt` and `SystemExit` propagate instead of being swallowed.
  """
  try:
    deg = int(deg)
  except Exception:
    return x ** deg
  else:
    return 1 if deg == 0 else x if deg == 1 else x ** deg
294
def mul(coeff, mon):
  """Return `coeff * mon`, short-circuiting the coefficients 0 and 1.

  If `coeff` converts to a Python int, coefficient 0 yields the literal `0`
  and coefficient 1 yields `mon` itself. If the conversion fails, the generic
  `coeff * mon` is returned.

  Only `Exception` subclasses are treated as a failed conversion, so
  `KeyboardInterrupt` and `SystemExit` propagate instead of being swallowed.
  """
  try:
    coeff = int(coeff)
  except Exception:
    return coeff * mon
  else:
    return 0 if coeff == 0 else mon if coeff == 1 else coeff * mon
302
303
# Add `Poly` to core's set of dimension types (presumably so that shapes may
# contain polynomials; confirm against `core`).
core._DIMENSION_TYPES.add(Poly)
305
class Mon(dict):
  """Multivariate monomial, such as n^3 * m.

  Stored as a dictionary mapping each variable to its exponent, where every
  exponent is >= 1. The empty monomial `Mon()` has degree 0.
  """
  # TODO: move this before Poly in the file

  def __hash__(self):
    # Order-independent hash over the (variable, exponent) pairs.
    return hash(frozenset(self.items()))

  def __str__(self):
    factors = []
    for var, exponent in sorted(self.items()):
      factors.append(str(var) if exponent == 1 else f'{var}^{exponent}')
    return ' '.join(factors)

  def __lt__(self, other: 'Mon'):
    # TODO: do not override __lt__ for this
    """
    Comparison to another monomial in graded reverse lexicographic order:
    a monomial of lower degree is smaller; for equal degrees the sorted
    variable tuples are compared in reverse.
    """
    mine = (-self.degree, tuple(sorted(self)))
    theirs = (-other.degree, tuple(sorted(other)))
    return mine > theirs

  def __mul__(self, other: 'Mon') -> 'Mon':
    """
    Returns the product with another monomial, adding the exponents of shared
    variables. Example: (n^2*m) * n == n^3 * m.
    """
    exponents = Counter(self) + Counter(other)
    return Mon(exponents)

  @property
  def degree(self):
    """Total degree: the sum of all exponents."""
    return sum(self.values())

  def __floordiv__(self, divisor: 'Mon') -> 'Mon':
    """
    Divides by another monomial. Raises UndefinedPoly if the divisor has a
    higher exponent than this monomial for some variable.
    For example, (n^3 * m) // n == n^2*m, but n // m fails.
    """
    remaining = Counter(self)
    for var, exponent in divisor.items():
      left = self.get(var, 0) - exponent
      if left < 0:
        raise UndefinedPoly(f"Cannot divide {self} by {divisor}.")
      if left == 0:
        # The variable cancels out entirely.
        remaining.pop(var, None)
      else:
        remaining[var] = left
    return Mon(remaining)
351
class ShapeError(Exception):
  """Raised when shapes do not agree with their polymorphic shape specs."""
353
class ShapeSyntaxError(Exception):
  """Raised when a shape specification string cannot be parsed."""
355
356# To denote some shape expressions (for annotations) we use a small language.
357#
358#   data ShapeSpec = ShapeSpec [Dim]
359#   data Dim = Id PyObj
360#            | Lit Int
361#            | Mul Dim Dim
362#            | Add Dim Dim
363#            | MonomorphicDim
364#
365# We'll also make a simple concrete syntax for annotation. The grammar is
366#
367#   shape_spec ::= '(' dims ')'
368#   dims       ::= dim ',' dims | ''
369#   dim        ::= str | int | dim '*' dim | dim '+' dim | '_'
370#
371# ShapeSpecs can have some monomorphic dims inside them, which must be replaced
372# with concrete shapes when known.
373
class ShapeSpec(tuple):
  """Tuple of dimension expressions describing one (polymorphic) shape."""

  def __str__(self):
    dims = ', '.join(str(dim) for dim in self)
    return f'ShapeSpec({dims})'
377
def finalize_spec(polymorphic_shape, padded_shape):
  """Fill the monomorphic (`_`) dims of a spec from the padded shape.

  Each entry that is the `_monomorphic_dim` sentinel is replaced by the int
  value of the matching padded dimension; other entries are kept unchanged.
  """
  # TODO: what if polymorphic_shape has a constant that does not match padded_shape?
  return tuple(dim if dim is not _monomorphic_dim else _parse_lit(size)
               for dim, size in zip(polymorphic_shape, padded_shape))
382
def parse_spec(spec=''):
  """Parse a shape annotation string such as '(m, n)' into a `ShapeSpec`.

  The surrounding parentheses are optional, spaces are ignored, and leading
  or trailing commas are tolerated (so '(n,)' is a one-dimensional spec).
  An empty spec, including '()' and whitespace-only strings, denotes the
  empty shape, as the grammar's `dims ::= ... | ''` allows.

  Raises:
    ShapeSyntaxError: if an opening parenthesis is not closed, if the spec
      consists of commas only, or if a dimension cannot be parsed.
  """
  if not spec:
    return ShapeSpec(())
  body = spec
  if body[0] == '(':
    if body[-1] != ')': raise ShapeSyntaxError(spec)
    body = body[1:-1]
  body = body.replace(' ', '')
  if not body:
    # '()' or blanks only: without this guard ''.split(',') yields [''] and
    # parsing the empty dimension crashed with an IndexError.
    return ShapeSpec(())
  body = body.strip(',')
  if not body:
    # Only commas, e.g. '(,)': no dimension to parse.
    raise ShapeSyntaxError(spec)
  dims = map(_parse_dim, body.split(','))
  return ShapeSpec(dims)
391
def _parse_dim(spec):
  """Parse one dimension expression of a shape spec.

  Handled forms, checked in this order: sums ('a+b'), products ('a*b'),
  integer literals (optionally negative), identifiers starting with a
  lowercase ASCII letter, and the monomorphic placeholder '_'.

  Raises:
    ShapeSyntaxError: if `spec` is empty or matches none of the forms above.
  """
  if not spec:
    # An empty token (e.g. from '(m,,n)' or 'n+') used to fall through to
    # `spec[0]` below and crash with an IndexError.
    raise ShapeSyntaxError("empty dimension in shape specification")
  if '+' in spec:
    return np.sum(map(_parse_dim, spec.split('+')))
  elif '*' in spec:
    return prod(map(_parse_dim, spec.split('*')))
  elif spec.isdigit() or spec.startswith('-') and spec[1:].isdigit():
    return _parse_lit(spec)
  elif spec[0] in _identifiers:
    return _parse_id(spec)
  elif spec == '_':
    return _monomorphic_dim
  else:
    raise ShapeSyntaxError(spec)
405
# Characters that may start an identifier in a dimension spec (see _parse_dim).
_identifiers = frozenset(string.ascii_lowercase)
407
def _parse_id(name):
  """Return the polynomial consisting of the single variable `name`."""
  variable = Mon({name: 1})
  return Poly({variable: 1})
409
def _parse_lit(val_str):
  """Convert a literal dimension to a Python int."""
  return int(val_str)
411
class MonomorphicDim(object):
  """Placeholder for a dimension written as '_' in a shape spec."""

  def __str__(self):
    return '_'

# Shared sentinel instance; the rest of the module compares against it by
# identity.
_monomorphic_dim = MonomorphicDim()
416
417# Two convenient ways to provide shape annotations:
418#   1. '(m, n)'
419#   2. s_['m', 'n']
420
class S_(object):
  """Indexing helper that builds a `ShapeSpec`, e.g. `s_['m', 'n']`."""

  def __getitem__(self, idx):
    if type(idx) is tuple:
      # Several dims: rebuild the parenthesized, comma-separated string form.
      text = '(' + ','.join(map(str, idx)) + ')'
    else:
      text = str(idx)
    return parse_spec(text)

s_ = S_()
427
def _shape_spec_consistent(spec, expr):
  """Return True if `expr` equals `spec` on every non-monomorphic dim.

  Positions where `spec` holds the `_monomorphic_dim` sentinel are skipped.
  """
  return all(expected == actual
             for expected, actual in zip(spec, expr)
             if expected is not _monomorphic_dim)
430
class MaskTracer(Tracer):
  """Tracer carrying a value together with its polymorphic shape."""
  __slots__ = ["val", "polymorphic_shape"]

  def __init__(self, trace, val, polymorphic_shape):
    super().__init__(trace)
    self.val = val
    self.polymorphic_shape = polymorphic_shape

  @property
  def aval(self):
    """Abstract value built from the polymorphic shape and the value's dtype."""
    shape, dtype = self.polymorphic_shape, self.dtype
    return ShapedArray(shape, dtype)

  @property
  def dtype(self):
    return dtypes.dtype(self.val)

  def is_pure(self):
    """True if no dimension is a non-constant `Poly`."""
    for dim in self.polymorphic_shape:
      if type(dim) is Poly and not dim.is_constant:
        return False
    return True

  def full_lower(self):
    """Return the fully lowered value when pure, otherwise this tracer."""
    if not self.is_pure():
      return self
    return core.full_lower(self.val)
456
457
class MaskTrace(Trace):
  """Trace that propagates polymorphic shapes alongside padded values."""

  def pure(self, val):
    """Wraps `val` in a MaskTracer whose shape is `np.shape(val)`."""
    return MaskTracer(self, val, np.shape(val))

  def lift(self, val):
    """Wraps `val` in a MaskTracer whose shape is `np.shape(val)`."""
    return MaskTracer(self, val, np.shape(val))

  def sublift(self, val):
    """Re-wraps the value and polymorphic shape of tracer `val` in this trace."""
    return MaskTracer(self, val.val, val.polymorphic_shape)

  def process_primitive(self, primitive, tracers, params):
    """Applies the registered masking rule of `primitive` to `tracers`.

    Raises NotImplementedError if `masking_rules` has no entry for the
    primitive. Output shapes come from the primitive's abstract evaluation on
    the tracers' avals, with constant polynomials converted to ints.
    """
    masking_rule = masking_rules.get(primitive)
    if masking_rule is None:
      raise NotImplementedError(
        f'Masking rule for {primitive} not implemented yet.')
    out_aval = primitive.abstract_eval(*(t.aval for t in tracers), **params)
    vals, polymorphic_shapes = unzip2((t.val, t.polymorphic_shape) for t in tracers)
    # Logical shapes are the polymorphic shapes evaluated in the current
    # logical shape environment.
    logical_shapes = map(shape_as_value, polymorphic_shapes)
    # TODO(mattjj): generalize mask rule signature
    # The reshape rule additionally receives the polymorphic shapes; note that
    # this writes into the `params` dict passed in.
    if primitive.name == 'reshape': params['polymorphic_shapes'] = polymorphic_shapes
    out = masking_rule(vals, logical_shapes, **params)
    if primitive.multiple_results:
      out_shapes = map(_polys_to_ints, [o.shape for o in out_aval])
      return map(partial(MaskTracer, self), out, out_shapes)
    else:
      return MaskTracer(self, out, _polys_to_ints(out_aval.shape))

  def process_call(self, call_primitive, f, tracers, params):
    """Binds a call primitive, masking the callee if any input is polymorphic.

    With no polymorphic input shapes the call is bound directly on the values.
    Otherwise `f` is transformed with `mask_subtrace` and the logical values
    of the current shape environment are passed as extra leading arguments.
    """
    assert call_primitive.multiple_results
    params = dict(params, name=wrap_name(params.get('name', f.__name__), 'mask'))
    vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in tracers)
    if not any(is_polymorphic(s) for s in shapes):
      return call_primitive.bind(f, *vals, **params)
    else:
      logical_env, padded_env = shape_envs
      env_keys, padded_env_vals = unzip2(sorted(padded_env.items()))
      logical_env_vals = tuple(logical_env[k] for k in env_keys)
      # Make padded_env hashable
      padded_env = (env_keys, padded_env_vals)
      f, shapes_out = mask_subtrace(f, self.main, shapes, padded_env)
      # The extra leading arguments are marked as not donated.
      if 'donated_invars' in params:
        params = dict(params, donated_invars=((False,) * len(logical_env_vals) +
                                              params['donated_invars']))
      vals_out = call_primitive.bind(f, *(logical_env_vals + vals), **params)
      return [MaskTracer(self, v, s) for v, s in zip(vals_out, shapes_out())]

  def post_process_call(self, call_primitive, out_tracers, params):
    """Unwraps `out_tracers` and returns a callback that re-wraps values.

    The callback `todo` builds MaskTracers in a new MaskTrace at the then
    current sublevel, reusing the polymorphic shapes captured here.
    """
    vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in out_tracers)
    main = self.main
    def todo(vals):
      trace = MaskTrace(main, core.cur_sublevel())
      return map(partial(MaskTracer, trace), vals, shapes)
    return vals, todo
511
class UniqueId:
  """Named identifier object; ordering follows the `name` attribute.

  No `__eq__` is defined, so two instances are equal only if they are the
  same object.
  """

  def __init__(self, name):
    self.name = name

  def __repr__(self):
    return self.name

  def __lt__(self, other):
    mine, theirs = self.name, other.name
    return mine < theirs
521
class UniqueIds(dict):
  """Dict that creates and stores a `UniqueId(key)` for each missing key."""

  def __missing__(self, key):
    fresh = self[key] = UniqueId(key)
    return fresh
527
def remap_ids(names, shape_spec):
  """Rename the variables of every `Poly` in `shape_spec` through `names`.

  Each variable `v` is replaced by `names[v]`; exponents and coefficients are
  kept. Entries that are not `Poly` instances pass through unchanged.
  """
  def rename(dim):
    if not isinstance(dim, Poly):
      return dim
    return Poly({Mon({names[var]: deg for var, deg in mon.items()}): coeff
                 for mon, coeff in dim.items()})

  return ShapeSpec(rename(dim) for dim in shape_spec)
533
def bind_shapes(polymorphic_shapes, padded_shapes):
  """Solve for the shape variables given concrete padded shapes.

  Each polymorphic dimension is matched with its padded size. A constant
  dimension must equal the size. Any other dimension must have the form
  `linear_coeff * var + const_coeff` with a single variable of exponent 1,
  and is solved for `var`.

  Returns:
    A dict mapping each shape variable to its inferred value.

  Raises:
    ShapeError: if a constant dimension differs from its size, a variable
      has an exponent other than 1, the size is not reachable with an integer
      variable value, or one variable gets two different values.
    ValueError: from tuple unpacking, if a dimension has more than one
      non-constant term or a term with more than one variable.
  """
  env = {}
  for polymorphic_shape, padded_shape in zip(polymorphic_shapes, padded_shapes):
    for poly, size in zip(polymorphic_shape, padded_shape):
      if type(poly) is not Poly or poly.is_constant:
        if int(poly) != size:
          raise ShapeError(
              f"Dimension {poly} does not match padded size {size}.")
      else:
        terms = poly.copy()
        const_coeff = terms.pop(Mon({}), 0)
        (mon, linear_coeff), = terms.items()
        (var, exponent), = mon.items()
        if exponent != 1:
          raise ShapeError(
              f"Cannot solve dimension {poly}: {var} has exponent {exponent}.")
        value, remainder = divmod(size - const_coeff, linear_coeff)
        # Raise instead of assert: asserts are stripped under `python -O`,
        # which would silently bind a wrong value for the variable.
        if remainder != 0:
          raise ShapeError(
              f"Dimension {poly} cannot equal padded size {size}.")
        if env.setdefault(var, value) != value:
          raise ShapeError(
              f"Conflicting values {env[var]} and {value} for {var}.")
  return env
550
def check_shapes(specs, spec_tree, shapes, tree, message_prefix="Output"):
  """Raise `ShapeError` unless `shapes` agree with `specs`.

  The check passes when the two tree structures are equal and every shape is
  consistent with its spec per `_shape_spec_consistent`. On failure the error
  message shows both sides unflattened into their tree structures.
  """
  mismatch = (spec_tree != tree or
              not all(map(_shape_spec_consistent, specs, shapes)))
  if not mismatch:
    return
  expected = tree_unflatten(spec_tree, specs)
  actual = tree_unflatten(tree, shapes)
  raise ShapeError(f"{message_prefix} shapes should be {expected} but are {actual}.")
556