1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16from collections import namedtuple
17import functools
18import operator as op
19from typing import Optional, Sequence
20
21import numpy as np
22
23from ._src.util import safe_map, safe_zip, unzip2, subvals, taggedtuple
24from .lib import xla_bridge as xb
25from .lib import xla_client as xc
26
27from ._src import traceback_util
# Register this file's path with traceback_util's exclusion list.
# NOTE(review): presumably this hides frames from this file in filtered
# tracebacks -- confirm against ._src.traceback_util.
traceback_util.register_exclusion(__file__)
29
# Short alias for the XLA client op namespace used when staging expressions.
xops = xc.ops

# Shadow the builtins `map` and `zip` with safe_map / safe_zip for the rest of
# this module; every later use of `map` or `zip` here refers to these.
# NOTE(review): presumably the safe variants check that all arguments have the
# same length -- confirm against ._src.util.
map = safe_map
zip = safe_zip
34
35
36### lazy sublanguage
37
38# There are two components to a LazyExpr: an input and a reindexing
39# specification. The input represents a base array to which the reindexing
40# specification is applied.
41#
42# An input can represent an array constructor (Iota, Eye, etc.) or it can be an
43# ArrayVar which encodes that the base array is some exogenous array value (from
44# an environment with only a single value in it). These LazyExprs are attached
45# to DeviceArrays, so when the input part of the expression is ArrayVar that
46# basically means the associated device buffer represents the input, while if
47# the input is an array constructor then the associated device_buffer field of
48# the DeviceArray should be set to a DeviceConstant sentinel value. For the
49# array constructor expressions:
50#   * Iota builds a 1D sequence [0, 1, ..., N-1],
51#   * Eye builds a 2D array with ones on a (possibly offset) diagonal and zeros
52#     elsewhere (like numpy.eye),
53#   * Tri builds a triangular matrix with ones on and below a diagonal and zeros
54#     elsewhere (like numpy.tri), and
55#   * Delta builds a Kronecker delta array with ones along its multidimensional
56#     main diagonal and zeros elsewhere (for use in tensor contractions).
57#
58# The reindexing specification encodes the shape of the final result and a list
59# of dimensions, which are integers or Nones. The integer entries take on values
60# 0, 1, ..., R-1 where R is the rank of the input array, and encode where the
61# axes of the input array are to be mapped in the final output. When an entry is
62# None that indicates that the corresponding axis of the result is a broadcasted
63# one.
64#
65# Here are some examples of lazy expressions and the arrays they represent:
66#
67# LazyExpr(input=Iota(dtype=dtype('float32'), size=3),
68#          shape=(3, 4), dims=(0, None))
69# DeviceArray([[0., 0., 0., 0.],
70#              [1., 1., 1., 1.],
71#              [2., 2., 2., 2.]], dtype=float32)
72#
73# LazyExpr(input=Iota(dtype=dtype('float32'), size=3),
74#          shape=(4, 3), dims=(None, 0))
75# DeviceArray([[0., 1., 2.],
76#              [0., 1., 2.],
77#              [0., 1., 2.],
78#              [0., 1., 2.]], dtype=float32)
79#
80# For performance, some functions on lazy expressions accept None as an input to
81# stand for the identity lazy expression.
82#
# For the input types (ArrayVar, Iota, Eye, Tri, Delta) we use the `taggedtuple`
# class constructor, rather than standard namedtuples, because two namedtuple
# instances of different types but equal elements hash to
85# the same value, e.g.
86#   A = namedtuple('A', ['x', 'y'])
87#   B = namedtuple('B', ['x', 'y'])
88#   hash(A(1, 2)) == hash(B(1, 2))   # True
89# but we want hashes to be sensitive to the type tag (while still being fast).
90
# pytype: disable=wrong-arg-count
# A lazy expression: `input` is the base array (an instance of one of the
# tagged types below), while `shape` and `dims` form the reindexing
# specification applied to it.
LazyExpr = namedtuple('LazyExpr', 'input shape dims')

# Input tags. ArrayVar has no fields and stands for an exogenous array value;
# the remaining types are array constructors.
ArrayVar = taggedtuple('ArrayVar', [])
Iota = taggedtuple('Iota', ['dtype', 'size'])  # 1D sequence, like np.arange(N)
Eye = taggedtuple('Eye', ['dtype', 'shape', 'offset'])  # 2D, like np.eye
Tri = taggedtuple('Tri', ['dtype', 'shape', 'offset'])  # 2D, like np.tri
Delta = taggedtuple('Delta', ['dtype', 'shape'])  # Kronecker delta arrays
# pytype: enable=wrong-arg-count
99
def array(shape):
  """Build the identity lazy expression over an exogenous array of `shape`.

  The input is an ArrayVar and result axis i is mapped from input axis i.
  """
  identity_dims = tuple(range(len(shape)))
  return LazyExpr(input=ArrayVar(), shape=shape, dims=identity_dims)
102
def iota(dtype, size):
  """Build a rank-1 lazy expression for the sequence [0, 1, ..., size-1]."""
  constructor = Iota(dtype, size)
  return LazyExpr(input=constructor, shape=(size,), dims=(0,))
105
def eye(dtype, shape, offset):
  """Build a lazy 2D array with ones on the `offset` diagonal (like np.eye).

  `shape` must have exactly two entries.
  """
  rank = len(shape)
  assert rank == 2
  constructor = Eye(dtype, shape, offset)
  return LazyExpr(input=constructor, shape=shape, dims=(0, 1))
109
def tri(dtype, shape, offset):
  """Build a lazy 2D triangular array of ones (like np.tri).

  Ones lie on and below the `offset` diagonal; `shape` must have exactly two
  entries.
  """
  rank = len(shape)
  assert rank == 2
  constructor = Tri(dtype, shape, offset)
  return LazyExpr(input=constructor, shape=shape, dims=(0, 1))
113
def delta(dtype, shape):
  """Build a lazy Kronecker delta array of the given `shape`.

  Result axis i is mapped from axis i of the Delta constructor.
  """
  identity_dims = tuple(range(len(shape)))
  constructor = Delta(dtype, shape)
  return LazyExpr(input=constructor, shape=shape, dims=identity_dims)
116
def broadcast(lexpr, shape, broadcast_dimensions):
  """Lazily broadcast `lexpr` to `shape`.

  Args:
    lexpr: the LazyExpr to broadcast.
    shape: the shape of the result.
    broadcast_dimensions: for each axis i of `lexpr`, the axis of the result
      at which it is placed.

  Returns:
    A LazyExpr over the same input whose dims carry over the entries of
    `lexpr.dims` at the requested positions; every other result axis is None,
    i.e. a broadcasted axis.
  """
  result_dims = [None] * len(shape)
  for operand_axis, result_axis in enumerate(broadcast_dimensions):
    result_dims[result_axis] = lexpr.dims[operand_axis]
  return LazyExpr(input=lexpr.input, shape=shape, dims=tuple(result_dims))
122
def transpose(lexpr: LazyExpr, perm: Sequence[int]):
  """Lazily permute the axes of `lexpr`.

  Result axis j takes its size and its dims entry from axis perm[j] of
  `lexpr`; the input of the expression is left untouched.
  """
  source_shape, source_dims = lexpr.shape, lexpr.dims
  permuted_shape = tuple([source_shape[axis] for axis in perm])
  permuted_dims = tuple([source_dims[axis] for axis in perm])
  return LazyExpr(input=lexpr.input, shape=permuted_shape, dims=permuted_dims)
127
def is_constant(lexpr: Optional[LazyExpr]):
  """Return True iff `lexpr` is not None and its input is not an ArrayVar."""
  if lexpr is None:
    # None stands for the identity expression, which is not a constant.
    return False
  return type(lexpr.input) is not ArrayVar
130
def is_trivial(lexpr: Optional[LazyExpr]) -> bool:
  """Return True iff `lexpr` is the identity on an exogenous array.

  That is the case when `lexpr` is None, or when its input is an ArrayVar and
  its dims map every result axis i to input axis i.
  """
  if lexpr is None:
    return True
  if type(lexpr.input) is not ArrayVar:
    return False
  identity_dims = tuple(range(len(lexpr.shape)))
  return lexpr.dims == identity_dims
134
135
def eval_lexpr(lexpr, x):
  """Evaluate a lazy expression using NumPy.

  Args:
    lexpr: the LazyExpr to evaluate (or None for the identity expression).
    x: ndarray or None, representing the value of ArrayVar if present. It must
      be None when the input of `lexpr` is an array constructor (Iota, Eye,
      Tri or Delta).

  Returns:
    An ndarray representing the value of the lazy expression.
  """
  # is_trivial also accepts None, which stands for the identity expression.
  if is_trivial(lexpr):
    return x

  input_, shape, dims = lexpr

  # first create a starting ndarray from input_
  t = type(input_)
  if t is ArrayVar:
    assert x is not None and type(x) is np.ndarray
  elif t is Iota:
    assert x is None
    x = np.arange(input_.size, dtype=input_.dtype)
  elif t is Eye:
    assert x is None
    N, M = input_.shape
    x = np.eye(N, M, dtype=input_.dtype, k=input_.offset)
  elif t is Tri:
    assert x is None
    N, M = input_.shape
    x = np.tri(N, M, dtype=input_.dtype, k=input_.offset)
  elif t is Delta:
    # Like the other array constructors, Delta takes no exogenous value.
    assert x is None
    ones = [1] * len(input_.shape)
    # iotas[i] counts along axis i and has size 1 along every other axis, so
    # comparing two of them broadcasts over the axes they span.
    iotas = [np.arange(d).reshape(subvals(ones, [(i, -1)]))
             for i, d in enumerate(input_.shape)]
    eyes = [i1 == i2 for i1, i2 in zip(iotas[:-1], iotas[1:])]
    if eyes:
      on_diagonal = functools.reduce(op.and_, eyes)
    else:
      # Rank 0 or 1: there is no pair of adjacent axes to compare, so every
      # index is on the main diagonal. reduce() would raise TypeError on the
      # empty list, so build the all-ones mask directly.
      on_diagonal = np.ones(input_.shape, dtype=np.bool_)
    x = np.asarray(on_diagonal, input_.dtype)
  else:
    assert False

  # then apply the reindexing operation: the non-None entries of dims, read in
  # result-axis order, give the permutation of the input axes
  perm = [d for d in dims if d is not None]
  if perm != list(range(len(perm))):
    x = np.transpose(x, perm)
  if shape != x.shape:
    # insert size-1 axes where the result is broadcasted, then expand them
    in_shape = [1 if d is None else s for d, s in zip(dims, shape)]
    x = np.broadcast_to(np.reshape(x, in_shape), shape)

  return x
182
183
def _stage_offset_compare(c, compare, matrix_shape, offset, dtype):
  """Stage `compare(row_index + offset, column_index)` converted to `dtype`.

  `compare` is an XLA comparison op builder (xops.Eq for Eye, xops.Ge for Tri)
  and `matrix_shape` is the two-entry shape of the array to build.
  """
  num_rows, num_cols = matrix_shape
  index_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, (num_rows, num_cols))
  row_index = xops.Iota(c, index_shape, 0)
  shift = xb.constant(c, np.array(offset, np.int32))
  shifted_rows = xops.Add(row_index, shift)
  col_index = xops.Iota(c, index_shape, 1)
  mask = compare(shifted_rows, col_index)
  return xops.ConvertElementType(mask, xb.dtype_to_etype(dtype))


def stage_lexpr(c, lexpr: Optional[LazyExpr], x):
  """Stage a lazy expression into an XLA computation.

  Args:
    c: XLA ComputationBuilder into which to stage the expression.
    lexpr: a LazyExpr to evaluate (or None for the identity expression).
    x: XlaOp or None, representing the value of ArrayVar if present.

  Returns:
    An XlaOp representing the value of the lazy expression.
  """
  if lexpr is None or is_trivial(lexpr):
    return x

  base, out_shape, out_dims = lexpr.input, lexpr.shape, lexpr.dims

  # Step 1: obtain an XlaOp for the base array described by the input tag.
  kind = type(base)
  if kind is ArrayVar:
    assert x is not None
  elif kind is Iota:
    assert x is None
    x = xops.Iota(c, xb.dtype_to_etype(base.dtype), base.size)
  elif kind is Eye:
    assert x is None
    # ones where row + offset == column
    x = _stage_offset_compare(c, xops.Eq, base.shape, base.offset, base.dtype)
  elif kind is Tri:
    assert x is None
    # ones where row + offset >= column
    x = _stage_offset_compare(c, xops.Ge, base.shape, base.offset, base.dtype)
  elif kind is Delta:
    etype = xb.dtype_to_etype(base.dtype)
    # One full-shape index array per axis; an element is on the main diagonal
    # when the indices of every pair of adjacent axes agree.
    axis_indices = []
    for axis in range(len(base.shape)):
      index_shape = xc.Shape.array_shape(xc.PrimitiveType.U32, base.shape)
      axis_indices.append(xops.Iota(c, index_shape, axis))
    adjacent_equal = [xops.Eq(lhs, rhs)
                      for lhs, rhs in zip(axis_indices[:-1], axis_indices[1:])]
    on_diagonal = functools.reduce(xops.And, adjacent_equal)
    x = xops.ConvertElementType(on_diagonal, etype)
  else:
    assert False

  # Step 2: apply the reindexing. For each non-broadcasted result axis, record
  # its position in the result and the input axis it is mapped from.
  bcast_dims, perm = unzip2((out_axis, in_axis)
                            for out_axis, in_axis in enumerate(out_dims)
                            if in_axis is not None)
  identity_perm = tuple(range(len(perm)))
  if tuple(perm) != identity_perm:
    x = xops.Transpose(x, perm)
  staged_shape = c.get_shape(x).dimensions()
  if out_shape != staged_shape:
    x = xops.BroadcastInDim(x, out_shape, bcast_dims)

  return x
240