#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Model classes representing a tensor comprehension.

These classes model the language more at an AST level as evaluated. Reasoning
about it typically involves processing this form into config objects that
represent actual op definitions (i.e. YAML).
"""

from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
from enum import Enum

from mlir import ir as _ir

from .affine import *
from .scalar_expr import *
from .types import *
from .yaml_helper import *

# Type aliases.
AffineDimList = Dict[str, _ir.AffineExpr]


class TensorExpression:
  """An expression that can appear on the RHS of a comprehension."""

  def to_scalar_expression(self) -> ScalarExpression:
    raise NotImplementedError()

  def visit_tensor_exprs(self, callback):
    """Visits all tensor expressions reachable from this expression."""
    callback(self)

  def collect_dim_uses(self, uses: Set["DimDef"]):
    """Collects all DimDefs reachable through this expression."""

    def visit_dim_def(dim_def):
      if isinstance(dim_def, DimDef):
        uses.add(dim_def)

    def visit_affine_exprs(expr):
      if isinstance(expr, TensorUse):
        for ind in expr.indices:
          ind.visit_affine_exprs(visit_dim_def)
      if isinstance(expr, ReduceApply):
        for ind in expr.reduce.reduce_dims:
          ind.visit_affine_exprs(visit_dim_def)

    self.visit_tensor_exprs(visit_affine_exprs)

  def collect_tensor_uses(self, uses: Set["TensorUse"]):
    """Collects all TensorUses reachable through this expression."""

    def visit_tensor_use(expr):
      if isinstance(expr, TensorUse):
        uses.add(expr)

    self.visit_tensor_exprs(visit_tensor_use)

  def collect_indices(self, indices: Set["index"]):
    """Collects all index accesses reachable through this expression."""

    def visit_index(expr):
      if isinstance(expr, index):
        indices.add(expr)

    self.visit_tensor_exprs(visit_index)

  def collect_scalar_uses(self, uses: Set["ScalarDef"]):
    """Collects all ScalarDefs reachable through this expression."""

    def visit_scalar_def(expr):
      if isinstance(expr, ScalarDef):
        uses.add(expr)

    self.visit_tensor_exprs(visit_scalar_def)

  def __add__(self, rhs: "TensorExpression") -> "TensorExpression":
    return PrimFn.add(self, rhs)

  def __mul__(self, rhs) -> "TensorExpression":
    return PrimFn.mul(self, rhs)

  def __sub__(self, rhs) -> "TensorExpression":
    return PrimFn.sub(self, rhs)

  def __hash__(self):
    return hash(id(self))


class TensorUse(TensorExpression):
  """A used tensor represented by its (tensor_name, indices).

  Note that forming a comprehension via direct assignment is performed through
  __setitem__ on the TensorDef level. However, performing a reduction with
  compound ops (+=, *=, etc.) is done through the sequence:
    TensorDef.__getitem__
    TensorUse.__iadd__
    TensorDef.__setitem__
  """

  def __init__(self, operand_def: "OperandDef",
               indices: Sequence[AffineExprDef]):
    self.operand_def = operand_def
    self.indices = tuple(indices)

  def to_scalar_expression(self) -> ScalarExpression:
    return ScalarArg(self.tensor_name).expr()

  @property
  def tensor_name(self) -> str:
    name = self.operand_def.name
    assert name is not None, "TensorDef not attached"
    return name

  def __iadd__(self, rhs: TensorExpression) -> TensorExpression:
    return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs)

  def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
    """For implicit reductions, computes default reduction dims.

    Assumes that the rhs is the expression being reduced and self is being
    reduced into. Any indices referenced on the rhs and not in self are
    considered reduction dims.
    """
    rhs_dims = set()
    lhs_dims = set()
    rhs.collect_dim_uses(rhs_dims)
    self.collect_dim_uses(lhs_dims)
    return rhs_dims - lhs_dims

  def __repr__(self):
    return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]"

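# For illustration, a minimal sketch of the two comprehension-forming paths
# described in the TensorUse docstring. It assumes attached TensorDefs
# (A, B, C, I, O), the D dim shortcuts from the sibling .affine module, and the
# U type variable from .types; none of these are defined in this file.
#
#   # Direct assignment: TensorDef.__getitem__ then TensorDef.__setitem__.
#   O[D.m, D.n] = PrimFn.exp(I[D.m, D.n])
#
#   # Compound assignment: TensorDef.__getitem__, TensorUse.__iadd__ (which
#   # wraps the rhs in ReduceFn.add over D.k, the only dim not used on the
#   # lhs), then TensorDef.__setitem__.
#   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])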

class OperandKind(Enum):
  InputTensor = 0
  Scalar = 1
  OutputTensor = 2
  Attribute = 3


class OperandDef:
  """Definition of an operand passed to an operation.

  Keeps the meta information of Tensor, Scalar, and Attribute operands and
  provides the shared registration functionality.
  """

  def __init__(self,
               kind: OperandKind,
               type_var: TypeVar,
               size_exprs: Optional[Sequence[AffineExprDef]] = None,
               index_dims: Optional[Sequence[DimDef]] = None):
    if not isinstance(type_var, TypeVar):
      raise ValueError(
          f"OperandDef requires a TypeVar but got {repr(type_var)}")
    self.owner = None  # type: Optional["LinalgOpDef"]
    self.type_var = type_var
    self.size_exprs = size_exprs
    self.index_dims = index_dims
    self.kind = kind
    self.name = None  # type: Optional[str]
    self.registered_index = -1  # type: int

  def attach(self, index: int, name: str, owner: "LinalgOpDef"):
    if self.owner:
      raise ValueError(f"OperandDef already registered with op: {self}")
    self.registered_index = index
    self.name = name
    self.owner = owner

  def __hash__(self):
    return hash(id(self))

  def __repr__(self):
    return (f"{self.name}:OperandDef(kind={self.kind.name}, "
            f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, "
            f"index_dims={self.index_dims})")


class TensorDef:
  """Tensor operand definition.

  Tensor operands are indexed using the associated indexing_map when forwarded
  to the body of the structured op. A unique name identifies the tensor operands
  and an index determines their position in the operation's parameter list. A
  tensor definition takes a type, a shape, and an optional flag to mark output
  tensors. Additionally, a tuple of index dimensions may be used to map the
  tensor to the loop dimensions of the operation. This mapping is needed to
  compute the indexing map of shape-only tensors that have no uses.
  """

  def __init__(self,
               type_var: TypeVar,
               *shape: AffineExprDef,
               index_dims: Optional[Sequence[DimDef]] = None,
               output: bool = False):
    if index_dims and len(shape) != len(index_dims):
      raise ValueError(f"Expected the shape rank {len(shape)} to match the "
                       f"number of index_dims {len(index_dims)}")
    if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
      raise ValueError(f"TensorDef requires index dims of type DimDef but "
                       f"got {index_dims}")
    kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
    self.operand_def = OperandDef(
        kind, type_var, size_exprs=shape, index_dims=index_dims)

  def __getitem__(self, dims) -> TensorUse:
    assert self.operand_def.owner, "TensorDef is not attached to an op"
    state = AffineBuildState(
        global_state=self.operand_def.owner._affine_state,
        allow_new_symbols=False)
    if not isinstance(dims, tuple):
      dims = (dims,)  # Handle single subscript case.
    # Special case: subscripting with None is a 0d scalar use.
    if dims == (None,):
      dims = ()

    exprs = []
    for expr_def in dims:
      if not isinstance(expr_def, AffineExprDef):
        raise KeyError(
            "A TensorDef can only be subscripted by a tuple of affine dims")
      exprs.append(expr_def)
    return TensorUse(self.operand_def, exprs)

  def __setitem__(self, dims, value):
    """Creates a new 1:1 comprehension by binding this tensor to an expression.

    Note that due to the way assignment works in Python, we have to capture
    direct assignment as a setitem on the TensorDef.
    """
    if not isinstance(value, TensorExpression):
      raise ValueError(f"Only TensorExpressions can be assigned to TensorDefs. "
                       f"Got: {repr(value)}")
    use = self[dims]
    comp = Comprehension((use, value))
    self.operand_def.owner.comprehensions.append(comp)

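# A hypothetical sketch of TensorDef in use, assuming the @linalg_structured_op
# decorator and the D/S/T/U shortcuts from sibling modules of this package:
#
#   @linalg_structured_op
#   def matmul_like(A=TensorDef(T, S.M, S.K),
#                   B=TensorDef(T, S.K, S.N),
#                   C=TensorDef(U, S.M, S.N, output=True)):
#     C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
#
# A shape-only tensor that is never indexed in the body would additionally pass
# index_dims (e.g. TensorDef(T, S.I, index_dims=[D.i])) so that its indexing
# map can still be derived.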

class ScalarDef(TensorExpression):
  """Scalar operand definition.

  Scalar operands are forwarded to the body of the structured op as they are.
  A unique name identifies the scalars and an index determines their position in
  the operation's parameter list.
  """

  def __init__(self, type_var: TypeVar):
    self.operand_def = OperandDef(OperandKind.Scalar, type_var)

  @property
  def scalar_name(self) -> str:
    name = self.operand_def.name
    assert name is not None, "ScalarDef not attached"
    return name

  def to_scalar_expression(self) -> ScalarExpression:
    return ScalarArg(self.scalar_name).expr()

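# A hypothetical sketch of a scalar operand (same assumed decorator and
# shortcuts as above): the scalar is forwarded to the body unchanged and can
# be combined with tensor uses.
#
#   @linalg_structured_op
#   def fill_like(value=ScalarDef(T), O=TensorDef(U, S.M, S.N, output=True)):
#     O[D.m, D.n] = cast(U, value)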

class AttributeDef:
  """Index Attribute definition.

  Index attributes provide a way to define and set symbols that can be used in
  indexing expressions. Every attribute specifies a tuple of symbols that at
  compile-time are replaced by integer values.
  """
  yaml_tag = "!LinalgAttributeDef"

  def __init__(self, *sizes: SymbolDef):
    if any(not isinstance(size, SymbolDef) for size in sizes):
      raise ValueError(f"AttributeDef requires sizes of type SymbolDef but got "
                       f"{sizes}")
    self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes)

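# A hypothetical sketch of an index attribute (same assumed decorator and
# shortcuts): the attribute's symbols may appear in indexing expressions and
# are replaced by concrete integers when the op is instantiated.
#
#   @linalg_structured_op
#   def pooling_sum_like(I=TensorDef(T, S.H, S.W),
#                        K=TensorDef(T, S.KH, S.KW, index_dims=[D.kh, D.kw]),
#                        O=TensorDef(U, S.OH, S.OW, output=True),
#                        strides=AttributeDef(S.SH, S.SW)):
#     O[D.oh, D.ow] += cast(U, I[D.oh * S.SH + D.kh, D.ow * S.SW + D.kw])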

class Comprehension:
  """Represents a single comprehension."""

  def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]):
    self.definitions = list()  # List[TensorUse]
    self.values = list()  # List[TensorExpression]

    # Bind the lhs (the assigned TensorUse) to any reduction rhs.
    for assign, value in bindings:
      if isinstance(value, ReduceApply):
        if value.lhs:
          raise ValueError(f"Reduction expression already assigns: {value}")
        value.lhs = assign
      self.definitions.append(assign)
      self.values.append(value)

  @property
  def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]:
    """Gets the reduction dim tuples used by the comprehension's values.

    Values that are not reductions contribute an empty tuple.
    """
    result = set()
    for use in self.values:
      if isinstance(use, ReduceApply):
        result.add(use.reduce.reduce_dims)
      else:
        result.add(tuple())
    return result

  def __repr__(self):
    if len(self.definitions) > 1:
      defs_repr = f"({', '.join(repr(d) for d in self.definitions)})"
      values_repr = f"({', '.join(repr(v) for v in self.values)})"
    else:
      defs_repr = f"{repr(self.definitions[0])}"
      values_repr = f"{repr(self.values[0])}"

    return f"{defs_repr} = {values_repr}"


class PrimFnType:
  """Primitive operations."""

  def __init__(self, prim_name: str):
    self.prim_name = prim_name

  def __call__(self, *args):
    return PrimApply(self, args)

  def reduce(self, *reduce_dims: DimDef):
    """Shortcut to create a Reduce operation from this primitive."""
    return ReduceFnType(self, *reduce_dims)

  def __repr__(self):
    return f"{self.prim_name}"


class PrimFn:
  add = PrimFnType("add")
  exp = PrimFnType("exp")
  log = PrimFnType("log")
  mul = PrimFnType("mul")
  max = PrimFnType("max")
  min = PrimFnType("min")
  sub = PrimFnType("sub")

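# Note that the arithmetic operators on TensorExpression are sugar over these
# primitives; with the assumed names from the earlier sketches,
#
#   cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
#
# is equivalent to
#
#   PrimFn.mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n]))
#
# and both build a PrimApply that scalarizes to a ScalarApplyFn.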

class ReduceFnType:
  """A reduction operator that reduces into its LHS from its RHS."""

  def __init__(self, operator: PrimFnType, *reduce_dims: DimDef):
    """Initializes the ReduceFn with a primitive function and dims."""
    if not isinstance(operator, PrimFnType):
      raise ValueError(f"Reduce expected a Prim operator but got {operator}")
    self.operator = operator
    self.reduce_dims = tuple(reduce_dims)

  def __call__(self, *args: TensorExpression):
    return ReduceApply(self, args)

  def __repr__(self):
    return (f"reduce_{self.operator.prim_name}"
            f"({', '.join(repr(d) for d in self.reduce_dims)})")


class ReduceFn:
  add = PrimFn.add.reduce
  mul = PrimFn.mul.reduce
  max = PrimFn.max.reduce
  min = PrimFn.min.reduce

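# A hypothetical sketch of an explicit reduction, as opposed to the implicit
# one created by `+=` (same assumed decorator and shortcuts): ReduceFn.max(D.n)
# yields a ReduceFnType, and calling it produces a ReduceApply whose lhs is
# bound by the enclosing Comprehension.
#
#   @linalg_structured_op
#   def row_max_like(A=TensorDef(T, S.M, S.N),
#                    O=TensorDef(T, S.M, output=True)):
#     O[D.m] = ReduceFn.max(D.n)(A[D.m, D.n])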

class PrimApply(TensorExpression):
  """Application of a primitive."""

  def __init__(self, prim: PrimFnType, args: Sequence[TensorExpression]):
    self.prim = prim
    self.args = tuple(args)

  def to_scalar_expression(self) -> ScalarExpression:
    return ScalarApplyFn(
        self.prim.prim_name,
        *[arg.to_scalar_expression() for arg in self.args]).expr()

  def visit_tensor_exprs(self, callback):
    super().visit_tensor_exprs(callback)
    for arg in self.args:
      arg.visit_tensor_exprs(callback)

  def __repr__(self):
    return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})"


class const(TensorExpression):
  """Returns the given constant floating point or integer value."""

  def __init__(self, value: Any):
    with _ir.Context():
      if isinstance(value, float):
        self.value = str(_ir.FloatAttr.get_f64(float(value)))
      elif isinstance(value, int):
        self.value = str(
            _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)))
      else:
        raise ValueError(f"const requires int or float but got {type(value)}")

  def to_scalar_expression(self) -> ScalarExpression:
    return ScalarConst(self.value).expr()

  def __repr__(self):
    return f"const({self.value})"


class index(TensorExpression):
  """Returns the iteration index for a given dimension name.

  Resolves the given dimension name to obtain its position in the iteration
  domain of the operation.
  """

  def __init__(self, dim: DimDef):
    self.dim_def = dim
    self.dim = -1

  def resolve_dimension_name(self, affine_state: AffineBuildState):
    self.dim = affine_state.get_dim(self.dim_def.dimname)

  def to_scalar_expression(self) -> ScalarExpression:
    assert self.dim != -1, "Dimension name not resolved"
    return ScalarIndex(self.dim).expr()

  def __repr__(self):
    return f"index({repr(self.dim)})"

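# A hypothetical sketch of index and const (same assumed decorator and
# shortcuts): the dimension is resolved to its position in the iteration
# domain before scalarization, so the body can compute with the loop indices.
#
#   @linalg_structured_op
#   def iota_like(O=TensorDef(T, S.M, S.N, output=True)):
#     O[D.m, D.n] = cast(T, index(D.m) * const(1000) + index(D.n))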

class cast(TensorExpression):
  """Casts the element type to a type (typically symbolic TypeVar)."""

  def __init__(self, to_type: TypeVar, operand: TensorExpression):
    self.to_type = to_type
    self.operand = operand

  def to_scalar_expression(self) -> ScalarExpression:
    return ScalarSymbolicCast(self.to_type,
                              self.operand.to_scalar_expression()).expr()

  def visit_tensor_exprs(self, callback):
    super().visit_tensor_exprs(callback)
    self.operand.visit_tensor_exprs(callback)

  def __repr__(self):
    return f"cast({self.to_type}, {repr(self.operand)})"


class ReduceApply(TensorExpression):
  """Application of a reduction.

  This captures the lhs (initial value) separately from the rhs.
  """

  def __init__(self, reduce: ReduceFnType, args: Sequence[TensorExpression]):
    self.reduce = reduce
    self.lhs = None  # type: Optional[TensorUse]
    self.args = tuple(args)

  def to_scalar_expression(self) -> ScalarExpression:
    if self.lhs is None:
      raise ValueError(f"Cannot scalarize a ReduceApply that has not been "
                       f"bound to its lhs: {self}")
    full_args = [self.lhs.to_scalar_expression()
                ] + [arg.to_scalar_expression() for arg in self.args]
    return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr()

  def visit_tensor_exprs(self, callback):
    for arg in self.args:
      arg.visit_tensor_exprs(callback)

  def __repr__(self):
    return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})"


class OpInterfaceDef:
  """An interface that an op implements."""

  def __init__(self, cpp_name: str):
    self.cpp_name = cpp_name


ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface")


class OpMetadataDef(YAMLObject):
  """Metadata about the op (generally not behavior impacting)."""
  yaml_tag = "!LinalgOpMetadata"

  def __init__(self, name: str, cpp_class_name: Optional[str],
               doc: Optional[str]):
    self.name = name
    self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name
    self.doc = doc
    self.implements = []  # type: List[OpInterfaceDef]

  def to_yaml_custom_dict(self):
    d = dict(
        name=self.name,
        cpp_class_name=self.cpp_class_name,
        doc=self.doc,
    )
    if self.implements:
      d["implements"] = [intr.cpp_name for intr in self.implements]
    return d


class LinalgOpDef:
  """Definition of a linalg op."""

  def __init__(self,
               name: str,
               cpp_class_name: Optional[str] = None,
               doc: Optional[str] = None):
    self.metadata = OpMetadataDef(
        name=name, cpp_class_name=cpp_class_name, doc=doc)
    self.registered_operands = dict()  # type: Dict[str, OperandDef]
    self.domain = list()  # type: List[DimDef]
    self.comprehensions = list()  # type: List[Comprehension]
    self._affine_state = AffineBuildState()

  def add_operand(self, name: str, operand: OperandDef):
    """Registers an operand."""
    if name in self.registered_operands:
      raise ValueError(f"The operand {name} is already registered "
                       f"to {self.registered_operands[name]}")
    # Ensure output tensors are registered after input tensors and scalars, and
    # that attributes are registered after all other operand types.
    registered_kinds = [
        op.kind.value for op in self.registered_operands.values()
    ]
    if registered_kinds:
      maximum = max(registered_kinds)
      if maximum > operand.kind.value and maximum > OperandKind.Scalar.value:
        raise ValueError(
            f"The operand {name} of kind {operand.kind.name} is registered "
            f"after an operand of kind {OperandKind(maximum).name}")
    operand.attach(len(self.registered_operands), name, self)
    self.registered_operands[name] = operand

  def __repr__(self):
    lines = [
        f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"
    ]
    for operand in self.registered_operands.values():
      lines.append(f"  {operand}")
    if self.comprehensions:
      lines[-1] += " {"
      for comprehension in self.comprehensions:
        lines.append(f"    {comprehension}")
      lines.append("}")
    return "\n".join(lines)

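# For illustration, a minimal sketch of driving these classes directly, without
# the decorator front end. The D/S dim and symbol shortcuts and the T type
# variable are assumed to come from the sibling .affine and .types modules; the
# names below are hypothetical.
#
#   op = LinalgOpDef(name="scale_op", doc="Scales the input by a scalar.")
#   factor = ScalarDef(T)
#   I = TensorDef(T, S.M, S.N)
#   O = TensorDef(T, S.M, S.N, output=True)
#   op.add_operand("factor", factor.operand_def)
#   op.add_operand("I", I.operand_def)
#   op.add_operand("O", O.operand_def)  # Outputs register after inputs/scalars.
#   O[D.m, D.n] = I[D.m, D.n] * factor
#   print(op)  # Shows the registered operands and the comprehension.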