1# Copyright 2019 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15 16from collections import namedtuple 17import functools 18import operator as op 19from typing import Optional, Sequence 20 21import numpy as np 22 23from ._src.util import safe_map, safe_zip, unzip2, subvals, taggedtuple 24from .lib import xla_bridge as xb 25from .lib import xla_client as xc 26 27from ._src import traceback_util 28traceback_util.register_exclusion(__file__) 29 30xops = xc.ops 31 32map = safe_map 33zip = safe_zip 34 35 36### lazy sublanguage 37 38# There are two components to a LazyExpr: an input and a reindexing 39# specification. The input represents a base array to which the reindexing 40# specification is applied. 41# 42# An input can represent an array constructor (Iota, Eye, etc.) or it can be an 43# ArrayVar which encodes that the base array is some exogenous array value (from 44# an environment with only a single value in it). These LazyExprs are attached 45# to DeviceArrays, so when the input part of the expression is ArrayVar that 46# basically means the associated device buffer represents the input, while if 47# the input is an array constructor then the associated device_buffer field of 48# the DeviceArray should be set to a DeviceConstant sentinel value. 
For the 49# array constructor expressions: 50# * Iota builds a 1D sequence [0, 1, ..., N-1], 51# * Eye builds a 2D array with ones on a (possibly offset) diagonal and zeros 52# elsewhere (like numpy.eye), 53# * Tri builds a triangular matrix with ones on and below a diagonal and zeros 54# elsewhere (like numpy.tri), and 55# * Delta builds a Kronecker delta array with ones along its multidimensional 56# main diagonal and zeros elsewhere (for use in tensor contractions). 57# 58# The reindexing specification encodes the shape of the final result and a list 59# of dimensions, which are integers or Nones. The integer entries take on values 60# 0, 1, ..., R-1 where R is the rank of the input array, and encode where the 61# axes of the input array are to be mapped in the final output. When an entry is 62# None that indicates that the corresponding axis of the result is a broadcasted 63# one. 64# 65# Here are some examples of lazy expressions and the arrays they represent: 66# 67# LazyExpr(input=Iota(dtype=dtype('float32'), size=3), 68# shape=(3, 4), dims=(0, None)) 69# DeviceArray([[0., 0., 0., 0.], 70# [1., 1., 1., 1.], 71# [2., 2., 2., 2.]], dtype=float32) 72# 73# LazyExpr(input=Iota(dtype=dtype('float32'), size=3), 74# shape=(4, 3), dims=(None, 0)) 75# DeviceArray([[0., 1., 2.], 76# [0., 1., 2.], 77# [0., 1., 2.], 78# [0., 1., 2.]], dtype=float32) 79# 80# For performance, some functions on lazy expressions accept None as an input to 81# stand for the identity lazy expression. 82# 83# We use the `taggedtuple` class constructor, rather than standard namedtuples, 84# because two namedtuple instances of different types but equal elements hash to 85# the same value, e.g. 86# A = namedtuple('A', ['x', 'y']) 87# B = namedtuple('B', ['x', 'y']) 88# hash(A(1, 2)) == hash(B(1, 2)) # True 89# but we want hashes to be sensitive to the type tag (while still being fast). 
# pytype: disable=wrong-arg-count
LazyExpr = namedtuple('LazyExpr', ['input', 'shape', 'dims'])
ArrayVar = taggedtuple('ArrayVar', [])
Iota = taggedtuple('Iota', ['dtype', 'size'])  # like np.arange(N)
Eye = taggedtuple('Eye', ['dtype', 'shape', 'offset'])  # like np.eye
Tri = taggedtuple('Tri', ['dtype', 'shape', 'offset'])  # like np.tri
Delta = taggedtuple('Delta', ['dtype', 'shape'])  # kronecker delta arrays
# pytype: enable=wrong-arg-count

def array(shape):
  """Returns the identity lazy expression over an ArrayVar of `shape`."""
  return LazyExpr(ArrayVar(), shape, tuple(range(len(shape))))

def iota(dtype, size):
  """Returns a lazy 1D array [0, 1, ..., size-1] with the given dtype."""
  return LazyExpr(Iota(dtype, size), (size,), (0,))

def eye(dtype, shape, offset):
  """Returns a lazy 2D array that is one where col == row + offset, else zero.

  `shape` must have length 2; this mirrors np.eye(N, M, k=offset).
  """
  assert len(shape) == 2
  return LazyExpr(Eye(dtype, shape, offset), shape, (0, 1))

def tri(dtype, shape, offset):
  """Returns a lazy 2D array that is one where col <= row + offset, else zero.

  `shape` must have length 2; this mirrors np.tri(N, M, k=offset).
  """
  assert len(shape) == 2
  return LazyExpr(Tri(dtype, shape, offset), shape, (0, 1))

def delta(dtype, shape):
  """Returns a lazy array that is one where all indices are equal, else zero.

  For a shape of rank 0 or 1 there is no pair of indices to compare, so every
  entry is one.
  """
  return LazyExpr(Delta(dtype, shape), shape, tuple(range(len(shape))))

def broadcast(lexpr: LazyExpr, shape, broadcast_dimensions: Sequence[int]):
  """Broadcasts `lexpr` to `shape` without touching its input.

  Args:
    lexpr: the LazyExpr to broadcast (must not be None).
    shape: the shape of the result.
    broadcast_dimensions: for each axis i of `lexpr`, the axis of the result
      that it maps to. Result axes not listed here become broadcast (None) axes.
  Returns:
    A LazyExpr sharing `lexpr.input`, with the reindexing updated.
  """
  new_dims = [None] * len(shape)
  for i, d in enumerate(broadcast_dimensions):
    new_dims[d] = lexpr.dims[i]
  return LazyExpr(lexpr.input, shape, tuple(new_dims))

def transpose(lexpr: LazyExpr, perm: Sequence[int]):
  """Permutes the axes of `lexpr`: result axis k is axis perm[k] of `lexpr`."""
  new_shape = tuple(lexpr.shape[i] for i in perm)
  new_dims = tuple(lexpr.dims[i] for i in perm)
  return LazyExpr(lexpr.input, new_shape, new_dims)

def is_constant(lexpr: Optional[LazyExpr]) -> bool:
  """True if `lexpr` is built from an array constructor rather than ArrayVar."""
  return lexpr is not None and type(lexpr.input) is not ArrayVar

def is_trivial(lexpr: Optional[LazyExpr]) -> bool:
  """True if `lexpr` is None or is an ArrayVar with identity reindexing."""
  return lexpr is None or (type(lexpr.input) is ArrayVar and
                           lexpr.dims == tuple(range(len(lexpr.shape))))


def eval_lexpr(lexpr: Optional[LazyExpr], x):
  """Evaluate a lazy expression using NumPy.

  Args:
    lexpr: the LazyExpr to evaluate (or None for the identity expression).
    x: ndarray or None, representing the value of ArrayVar if present. Must be
      None when the input of `lexpr` is an array constructor.
  Returns:
    An ndarray representing the value of the lazy expression.
  """
  # Fast path: None / identity expressions just return the input value.
  if lexpr is None or is_trivial(lexpr):
    return x

  input_, shape, dims = lexpr

  # first create a starting ndarray from input_
  t = type(input_)
  if t is ArrayVar:
    assert x is not None and type(x) is np.ndarray
  elif t is Iota:
    assert x is None
    x = np.arange(input_.size, dtype=input_.dtype)
  elif t is Eye:
    assert x is None
    N, M = input_.shape
    x = np.eye(N, M, dtype=input_.dtype, k=input_.offset)
  elif t is Tri:
    assert x is None
    N, M = input_.shape
    x = np.tri(N, M, dtype=input_.dtype, k=input_.offset)
  elif t is Delta:
    assert x is None
    ones = [1] * len(input_.shape)
    # One iota per axis, each shaped to vary only along its own axis so that
    # comparing neighbours broadcasts to the full shape.
    iotas = [np.arange(d).reshape(subvals(ones, [(i, -1)]))
             for i, d in enumerate(input_.shape)]
    eyes = [i1 == i2 for i1, i2 in zip(iotas[:-1], iotas[1:])]
    if eyes:
      mask = functools.reduce(op.and_, eyes)
    else:
      # Rank 0 or 1: no index pairs to compare, so every entry is on the
      # "diagonal". reduce() would raise on the empty list.
      mask = np.ones(input_.shape, dtype=bool)
    x = np.asarray(mask, input_.dtype)
  else:
    assert False

  # then apply the reindexing operation. Dropping the Nones from dims gives
  # the input axes in the order they appear in the output, i.e. a transpose.
  perm = [d for d in dims if d is not None]
  if perm != list(range(len(perm))):
    x = np.transpose(x, perm)
  if shape != x.shape:
    # Insert size-1 axes at broadcast (None) positions, then broadcast.
    in_shape = [1 if d is None else s for d, s in zip(dims, shape)]
    x = np.broadcast_to(np.reshape(x, in_shape), shape)

  return x


def stage_lexpr(c, lexpr: Optional[LazyExpr], x):
  """Stage a lazy expression into an XLA computation.

  Args:
    c: XLA ComputationBuilder into which to stage the expression.
    lexpr: a LazyExpr to evaluate (or None for the identity expression).
    x: XlaOp or None, representing the value of ArrayVar if present. Must be
      None when the input of `lexpr` is an array constructor.
  Returns:
    An XlaOp representing the value of the lazy expression.
  """
  # Fast path: None / identity expressions just return the input op.
  if lexpr is None or is_trivial(lexpr):
    return x

  input_, shape, dims = lexpr

  # first create a starting XlaOp from input_
  t = type(input_)
  if t is ArrayVar:
    assert x is not None
  elif t is Iota:
    assert x is None
    x = xops.Iota(c, xb.dtype_to_etype(input_.dtype), input_.size)
  elif t is Eye:
    assert x is None
    N, M = input_.shape
    xla_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, (N, M))
    # row + offset == col, matching np.eye(N, M, k=offset).
    bool_eye = xops.Eq(
      xops.Add(xops.Iota(c, xla_shape, 0),
               xb.constant(c, np.array(input_.offset, np.int32))),
      xops.Iota(c, xla_shape, 1))
    x = xops.ConvertElementType(bool_eye, xb.dtype_to_etype(input_.dtype))
  elif t is Tri:
    assert x is None
    N, M = input_.shape
    xla_shape = xc.Shape.array_shape(xc.PrimitiveType.S32, (N, M))
    # row + offset >= col, matching np.tri(N, M, k=offset).
    bool_tri = xops.Ge(
      xops.Add(xops.Iota(c, xla_shape, 0),
               xb.constant(c, np.array(input_.offset, np.int32))),
      xops.Iota(c, xla_shape, 1))
    x = xops.ConvertElementType(bool_tri, xb.dtype_to_etype(input_.dtype))
  elif t is Delta:
    assert x is None
    etype = xb.dtype_to_etype(input_.dtype)
    iota_shape = xc.Shape.array_shape(xc.PrimitiveType.U32, input_.shape)
    iotas = [xops.Iota(c, iota_shape, i) for i in range(len(input_.shape))]
    eyes = [xops.Eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
    if eyes:
      mask = functools.reduce(xops.And, eyes)
    else:
      # Rank 0 or 1: no index pairs to compare, so the result is all ones.
      # reduce() would raise on the empty list.
      mask = xops.Broadcast(xb.constant(c, np.array(True)), input_.shape)
    x = xops.ConvertElementType(mask, etype)
  else:
    assert False

  # then apply the operations encoded in reindex: bcast_dims are the output
  # positions of non-None dims, perm the input axes in output order. After
  # the transpose, operand axis k maps to output axis bcast_dims[k].
  bcast_dims, perm = unzip2((i, d) for i, d in enumerate(dims) if d is not None)
  if tuple(perm) != tuple(range(len(perm))):
    x = xops.Transpose(x, perm)
  if shape != c.get_shape(x).dimensions():
    x = xops.BroadcastInDim(x, shape, bcast_dims)

  return x