1# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 2# See https://llvm.org/LICENSE.txt for license information. 3# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 4"""Model classes representing a tensor comprehension. 5 6These classes model the language more at an AST level as evaluated. Reasoning 7about it typically involves processing this form into config objects that 8represent actual op definitions (i.e. YAML). 9""" 10 11from typing import Any, Dict, List, Optional, Sequence, Set, Tuple 12from enum import Enum 13 14from mlir import ir as _ir 15 16from .affine import * 17from .scalar_expr import * 18from .types import * 19from .yaml_helper import * 20 21# Type aliases. 22AffineDimList = Dict[str, _ir.AffineExpr] 23 24 25class TensorExpression: 26 """An expression that can appear on the RHS of a comprehension.""" 27 28 def to_scalar_expression(self) -> ScalarExpression: 29 raise NotImplementedError() 30 31 def visit_tensor_exprs(self, callback): 32 """Visits all tensor expression reachable by the expression.""" 33 callback(self) 34 35 def collect_dim_uses(self, uses: Set["DimDef"]): 36 """Collects all DimDefs reachable through this expression.""" 37 results = set() 38 39 def visit_dim_def(dim_def): 40 if isinstance(dim_def, DimDef): 41 uses.add(dim_def) 42 43 def visit_affine_exprs(expr): 44 if isinstance(expr, TensorUse): 45 for ind in expr.indices: 46 ind.visit_affine_exprs(visit_dim_def) 47 if isinstance(expr, ReduceApply): 48 for ind in expr.reduce.reduce_dims: 49 ind.visit_affine_exprs(visit_dim_def) 50 51 self.visit_tensor_exprs(visit_affine_exprs) 52 53 def collect_tensor_uses(self, uses: Set["TensorUse"]): 54 """Collects all TensorUses reachable through this expression.""" 55 56 def visit_tensor_use(expr): 57 if isinstance(expr, TensorUse): 58 uses.add(expr) 59 60 self.visit_tensor_exprs(visit_tensor_use) 61 62 def collect_indices(self, indices: Set["index"]): 63 """Collects all index accesses reachable through this expression.""" 64 65 def visit_index(expr): 66 if isinstance(expr, index): 67 indices.add(expr) 68 69 self.visit_tensor_exprs(visit_index) 70 71 def collect_scalar_uses(self, uses: Set["ScalarDef"]): 72 """Collects all ScalarDefs reachable through this expression.""" 73 74 def visit_scalar_def(expr): 75 if isinstance(expr, ScalarDef): 76 uses.add(expr) 77 78 self.visit_tensor_exprs(visit_scalar_def) 79 80 def __add__(self, rhs: "TensorExpression") -> "TensorExpression": 81 return PrimFn.add(self, rhs) 82 83 def __mul__(self, rhs) -> "TensorExpression": 84 return PrimFn.mul(self, rhs) 85 86 def __sub__(self, rhs) -> "TensorExpression": 87 return PrimFn.sub(self, rhs) 88 89 def __hash__(self): 90 return hash(id(self)) 91 92 93class TensorUse(TensorExpression): 94 """A used tensor represented by its (tensor_name, indices). 95 96 Note that forming a comprehension via direct assignment is performed through 97 __setitem__ on the TensorDef level. However, performing a reduction with 98 compound ops (+=, *=, etc) is done by doing a: 99 TensorDef.__getitem__ 100 TensorUse.__iadd__ 101 TensorDef.__setitem__ 102 """ 103 104 def __init__(self, operand_def: "OperandDef", 105 indices: Sequence[AffineExprDef]): 106 self.operand_def = operand_def 107 self.indices = tuple(indices) 108 109 def to_scalar_expression(self) -> ScalarExpression: 110 return ScalarArg(self.tensor_name).expr() 111 112 @property 113 def tensor_name(self) -> str: 114 name = self.operand_def.name 115 assert name is not None, "TensorDef not attached" 116 return name 117 118 def __iadd__(self, rhs: TensorExpression) -> TensorExpression: 119 return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs) 120 121 def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: 122 """For implicit reductions, computes default reduction dims. 123 124 Assumes that the rhs is the expression being reduced and self is being 125 reduced into. Any indices referenced on the rhs and not in self are 126 considered reduction dims and will be ordered as encountered on the rhs. 127 """ 128 rhs_dims = set() 129 lhs_dims = set() 130 rhs.collect_dim_uses(rhs_dims) 131 self.collect_dim_uses(lhs_dims) 132 return rhs_dims - lhs_dims 133 134 def __repr__(self): 135 return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]" 136 137 138class OperandKind(Enum): 139 InputTensor = 0 140 Scalar = 1 141 OutputTensor = 2 142 Attribute = 3 143 144 145class OperandDef: 146 """Definition of an operand passed to an operation. 147 148 Keep the meta information of Tensor, Scalar, and Attribute operands and 149 provide the shared registration functionality. 150 """ 151 152 def __init__(self, 153 kind: OperandKind, 154 type_var: TypeVar, 155 size_exprs: Optional[Sequence[AffineExprDef]] = None, 156 index_dims: Optional[Sequence[DimDef]] = None): 157 if not isinstance(type_var, TypeVar): 158 raise ValueError( 159 f"OperandDef requires a TypeVar but got {repr(type_var)}") 160 self.owner = None # type: Optional["LinalgOpDef"] 161 self.type_var = type_var 162 self.size_exprs = size_exprs 163 self.index_dims = index_dims 164 self.kind = kind 165 self.name = None # type: Optional[str] 166 self.registered_index = -1 # type: int 167 168 def attach(self, index: int, name: str, owner: "LinalgOpDef"): 169 if self.owner: 170 raise ValueError(f"OperandDef already registered with op: {self}") 171 self.registered_index = index 172 self.name = name 173 self.owner = owner 174 175 def __hash__(self): 176 return hash(id(self)) 177 178 def __repr__(self): 179 return (f"{self.name}:OperandDef(kind={self.kind.name}, " 180 f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), " 181 f"index_dims={self.index_dims})") 182 183 184class TensorDef: 185 """Tensor operand definition. 186 187 Tensor operands are indexed using the associated indexing_map when forwarded 188 to the body of the structured op. A unique name identifies the tensor operands 189 and an index determines their position in the operation's parameter list. A 190 tensor definition takes type, a shape, and an optional flag to mark output 191 tensors. Additionally, a tuple of index dimensions may be used to map the 192 tensor to the loop dimensions of the operation. This mapping is needed to 193 compute the indexing map of shape-only tensors that have no uses. 194 """ 195 196 def __init__(self, 197 type_var: TypeVar, 198 *shape: AffineExprDef, 199 index_dims: Optional[Sequence[DimDef]] = None, 200 output: bool = False): 201 if index_dims and len(shape) != len(index_dims): 202 raise ValueError(f"Expected the shape rank {len(shape)} to match the " 203 f"number of index_dims {len(index_dims)}") 204 if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims): 205 raise ValueError(f"TensorDef requires index dims of type DimDef but " 206 f"got {index_dims}") 207 kind = OperandKind.OutputTensor if output else OperandKind.InputTensor 208 self.operand_def = OperandDef( 209 kind, type_var, size_exprs=shape, index_dims=index_dims) 210 211 def __getitem__(self, dims) -> TensorUse: 212 assert self.operand_def.owner, "TensorDef is not attached to an op" 213 state = AffineBuildState( 214 global_state=self.operand_def.owner._affine_state, 215 allow_new_symbols=False) 216 if not isinstance(dims, tuple): 217 dims = (dims,) # Handle single subscript case. 218 # Special case: (None) is a 0d-scalar use. 219 if dims == (None,): 220 dims = () 221 222 exprs = [] 223 for expr_def in dims: 224 if not isinstance(expr_def, AffineExprDef): 225 raise KeyError( 226 "A TensorDef can only be subscripted by a tuple of affine dims") 227 exprs.append(expr_def) 228 return TensorUse(self.operand_def, exprs) 229 230 def __setitem__(self, dims, value): 231 """Creates a new 1:1 comprehension by binding this tensor to an expression. 232 233 Note that due to the way assignment works in Python, we have to capture 234 direct assignment as a setitem on the TensorDef. 235 """ 236 if not isinstance(value, TensorExpression): 237 raise ValueError(f"Only TensorExpressions can be assigned to TensorDefs. " 238 f"Got: {repr(value)}") 239 use = self[dims] 240 comp = Comprehension((use, value)) 241 self.operand_def.owner.comprehensions.append(comp) 242 243 244class ScalarDef(TensorExpression): 245 """Scalar operand definition. 246 247 Scalar operands are forwarded to the body of the structured op as they are. 248 A unique name identifies the scalars and an index determines their position in 249 the operation's parameter list. 250 """ 251 252 def __init__(self, type_var: TypeVar): 253 self.operand_def = OperandDef(OperandKind.Scalar, type_var) 254 255 @property 256 def scalar_name(self) -> str: 257 name = self.operand_def.name 258 assert name is not None, "ScalarDef not attached" 259 return name 260 261 def to_scalar_expression(self) -> ScalarExpression: 262 return ScalarArg(self.scalar_name).expr() 263 264 265class AttributeDef: 266 """Index Attribute definition. 267 268 Index attributes provide a way to define and set symbols that can be used in 269 indexing expressions. Every attribute specifies a tuple of symbols that at 270 compile-time are replaced by integer values. 271 """ 272 yaml_tag = "!LinalgAttributeDef" 273 274 def __init__(self, *sizes: SymbolDef): 275 if any(not isinstance(size, SymbolDef) for size in sizes): 276 raise ValueError(f"AttributeDef requires sizes of type SymbolDef but got " 277 f"{sizes}") 278 self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes) 279 280 281class Comprehension: 282 """Represents a single comprehension.""" 283 284 def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]): 285 self.definitions = list() # List[TensorUse] 286 self.values = list() # List[TensorExpression] 287 288 # Find the lhs to reduction rhs. 289 for assign, value in bindings: 290 if isinstance(value, ReduceApply): 291 if value.lhs: 292 raise ValueError(f"Reduction expression already assigns: {value}") 293 value.lhs = assign 294 self.definitions.append(assign) 295 self.values.append(value) 296 297 @property 298 def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]: 299 """Gets the reduction dims for the comprehension or None.""" 300 result = set() 301 for use in self.values: 302 if isinstance(use, ReduceApply): 303 result.add(use.reduce.reduce_dims) 304 else: 305 result.add(tuple()) 306 return result 307 308 def __repr__(self): 309 if len(self.definitions) > 1: 310 defs_repr = f"({', '.join(repr(d) for d in self.definitions)})" 311 values_repr = f"({', '.join(repr(v) for v in self.values)})" 312 else: 313 defs_repr = f"{repr(self.definitions[0])}" 314 values_repr = f"{repr(self.values[0])}" 315 316 return f"{defs_repr} = {values_repr}" 317 318 319class PrimFnType: 320 """Primitive operations.""" 321 322 def __init__(self, prim_name: str): 323 self.prim_name = prim_name 324 325 def __call__(self, *args): 326 return PrimApply(self, args) 327 328 def reduce(self, *reduce_dims: DimDef): 329 """Shortcut to create a Reduce operation from this primitive.""" 330 return ReduceFnType(self, *reduce_dims) 331 332 def __repr__(self): 333 return f"{self.prim_name}" 334 335 336class PrimFn: 337 add = PrimFnType("add") 338 exp = PrimFnType("exp") 339 log = PrimFnType("log") 340 mul = PrimFnType("mul") 341 max = PrimFnType("max") 342 min = PrimFnType("min") 343 sub = PrimFnType("sub") 344 345 346class ReduceFnType: 347 """A reduction operator that reduces into its LHS from its RHS.""" 348 349 def __init__(self, operator: PrimFnType, *reduce_dims: DimDef): 350 """Initializes the ReduceFn with a primitive function and dims.""" 351 if not isinstance(operator, PrimFnType): 352 raise ValueError(f"Reduce expected a Prim operator but got {operator}") 353 self.operator = operator 354 self.reduce_dims = tuple(reduce_dims) 355 356 def __call__(self, *args: TensorExpression): 357 return ReduceApply(self, args) 358 359 def __repr__(self): 360 return (f"reduce_{self.operator.prim_name}" 361 f"({', '.join(repr(d) for d in self.reduce_dims)})") 362 363 364class ReduceFn: 365 add = PrimFn.add.reduce 366 mul = PrimFn.mul.reduce 367 max = PrimFn.max.reduce 368 min = PrimFn.min.reduce 369 370 371class PrimApply(TensorExpression): 372 """Application of a primitive.""" 373 374 def __init__(self, prim: PrimFnType, args: Sequence[TensorExpression]): 375 self.prim = prim 376 self.args = tuple(args) 377 378 def to_scalar_expression(self) -> ScalarExpression: 379 return ScalarApplyFn(self.prim.prim_name, 380 *[arg.to_scalar_expression() for arg in self.args 381 ]).expr() 382 383 def visit_tensor_exprs(self, callback): 384 super().visit_tensor_exprs(callback) 385 for arg in self.args: 386 arg.visit_tensor_exprs(callback) 387 388 def __repr__(self): 389 return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})" 390 391 392class const(TensorExpression): 393 """Returns the given constant floating point or integer value.""" 394 395 def __init__(self, value: Any): 396 with _ir.Context(): 397 if isinstance(value, float): 398 self.value = str(_ir.FloatAttr.get_f64(float(value))) 399 elif isinstance(value, int): 400 self.value = str( 401 _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))) 402 else: 403 raise ValueError(f"const requires int or float but got {type(value)}") 404 405 def to_scalar_expression(self) -> ScalarExpression: 406 return ScalarConst(self.value).expr() 407 408 def __repr__(self): 409 return f"const({self.type_var}, {self.value})" 410 411 412class index(TensorExpression): 413 """Returns the iteration index for a given dimension name. 414 415 Resolves the given dimension name to obtain its position in the iteration 416 domain of the operation. 417 """ 418 419 def __init__(self, dim: DimDef): 420 self.dim_def = dim 421 self.dim = -1 422 423 def resolve_dimension_name(self, affine_state: AffineBuildState): 424 self.dim = affine_state.get_dim(self.dim_def.dimname) 425 426 def to_scalar_expression(self) -> ScalarExpression: 427 assert self.dim != -1, "Dimension name not resolved" 428 return ScalarIndex(self.dim).expr() 429 430 def __repr__(self): 431 return f"index({repr(self.dim)})" 432 433 434class cast(TensorExpression): 435 """Casts the element type to a type (typically symbolic TypeVar).""" 436 437 def __init__(self, to_type: TypeVar, operand: TensorExpression): 438 self.to_type = to_type 439 self.operand = operand 440 441 def to_scalar_expression(self) -> ScalarExpression: 442 return ScalarSymbolicCast(self.to_type, 443 self.operand.to_scalar_expression()).expr() 444 445 def visit_tensor_exprs(self, callback): 446 super().visit_tensor_exprs(callback) 447 self.operand.visit_tensor_exprs(callback) 448 449 def __repr__(self): 450 return f"cast({self.to_type}, {repr(self.operand)})" 451 452 453class ReduceApply(TensorExpression): 454 """Application of a reduction. 455 456 This captures the lhs separately (initial value) separately from the rhs. 457 """ 458 459 def __init__(self, reduce: ReduceFnType, args: Sequence[TensorExpression]): 460 self.reduce = reduce 461 self.lhs = None # type: Optional[TensorUse] 462 self.args = tuple(args) 463 464 def to_scalar_expression(self) -> ScalarExpression: 465 if self.lhs is None: 466 raise ValueError(f"Cannot scalarize a ReduceApply that has not been " 467 f"bound to its lhs: {self}") 468 full_args = [self.lhs.to_scalar_expression() 469 ] + [arg.to_scalar_expression() for arg in self.args] 470 return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr() 471 472 def visit_tensor_exprs(self, callback): 473 for arg in self.args: 474 arg.visit_tensor_exprs(callback) 475 476 def __repr__(self): 477 return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})" 478 479 480class OpInterfaceDef: 481 """An interface that an op implements.""" 482 483 def __init__(self, cpp_name: str): 484 self.cpp_name = cpp_name 485 486 487ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface") 488 489 490class OpMetadataDef(YAMLObject): 491 """Metadata about the op (generally not behavior impacting).""" 492 yaml_tag = "!LinalgOpMetadata" 493 494 def __init__(self, name: str, cpp_class_name: Optional[str], 495 doc: Optional[str]): 496 self.name = name 497 self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name 498 self.doc = doc 499 self.implements = [] # type: List[OpInterfaceDef] 500 501 def to_yaml_custom_dict(self): 502 d = dict( 503 name=self.name, 504 cpp_class_name=self.cpp_class_name, 505 doc=self.doc, 506 ) 507 if self.implements: 508 d["implements"] = [intr.cpp_name for intr in self.implements] 509 return d 510 511 512class LinalgOpDef: 513 """Definition of a linalg op.""" 514 515 def __init__(self, 516 name: str, 517 cpp_class_name: Optional[str] = None, 518 doc: Optional[str] = None): 519 self.metadata = OpMetadataDef( 520 name=name, cpp_class_name=cpp_class_name, doc=doc) 521 self.registered_operands = dict() # type: Dict[str, OperandDef] 522 self.domain = list() # type: List[DimDef] 523 self.comprehensions = list() # type: List[Comprehension] 524 self._affine_state = AffineBuildState() 525 526 def add_operand(self, name: str, operand: OperandDef): 527 """Registers an operand.""" 528 if name in self.registered_operands: 529 raise ValueError(f"The operand {name} is already registered " 530 f"to {self.registered_operands['name']}") 531 # Ensure output tensors are registered after input tensors and scalars and 532 # attributes are registered after all other operand types. 533 registered_kinds = [ 534 operand.kind.value for operand in self.registered_operands.values() 535 ] 536 if registered_kinds: 537 maximum = max(registered_kinds) 538 if maximum > operand.kind.value and maximum > OperandKind.Scalar.value: 539 raise ValueError( 540 f"The operand {name} of kind {operand.kind.name} is registered " 541 f"after an operand of kind {OperandKind(maximum).name}") 542 operand.attach(len(self.registered_operands), name, self) 543 self.registered_operands[name] = operand 544 545 def __repr__(self): 546 lines = [ 547 f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name}," 548 ] 549 for name, operand in self.registered_operands.items(): 550 lines.append(f" {operand}") 551 if self.comprehensions: 552 lines[-1] += " {" 553 for comprehension in self.comprehensions: 554 lines.append(f" {comprehension}") 555 lines.append("}") 556 return "\n".join(lines) 557