1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10
11 #include "mlir/IR/Operation.h"
12 #include "llvm/Support/Debug.h"
13
14 namespace mlir {
15 namespace sparse_tensor {
16
17 //===----------------------------------------------------------------------===//
18 // Constructors.
19 //===----------------------------------------------------------------------===//
20
TensorExp(Kind k,unsigned x,unsigned y,Value v)21 TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
22 : kind(k), val(v) {
23 switch (kind) {
24 case kTensor:
25 assert(x != -1u && y == -1u && !v);
26 tensor = x;
27 break;
28 case kInvariant:
29 assert(x == -1u && y == -1u && v);
30 break;
31 case kAbsF:
32 case kCeilF:
33 case kFloorF:
34 case kNegF:
35 case kNegI:
36 assert(x != -1u && y == -1u && !v);
37 children.e0 = x;
38 children.e1 = y;
39 break;
40 case kTruncF:
41 case kExtF:
42 case kCastFS:
43 case kCastFU:
44 case kCastSF:
45 case kCastUF:
46 case kCastS:
47 case kCastU:
48 case kTruncI:
49 case kBitCast:
50 assert(x != -1u && y == -1u && v);
51 children.e0 = x;
52 children.e1 = y;
53 break;
54 default:
55 assert(x != -1u && y != -1u && !v);
56 children.e0 = x;
57 children.e1 = y;
58 break;
59 }
60 }
61
LatPoint(unsigned n,unsigned e,unsigned b)62 LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
63 : bits(n, false), simple(), exp(e) {
64 bits.set(b);
65 }
66
LatPoint(const llvm::BitVector & b,unsigned e)67 LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
68 : bits(b), simple(), exp(e) {}
69
70 //===----------------------------------------------------------------------===//
71 // Lattice methods.
72 //===----------------------------------------------------------------------===//
73
addExp(Kind k,unsigned e0,unsigned e1,Value v)74 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
75 unsigned e = tensorExps.size();
76 tensorExps.push_back(TensorExp(k, e0, e1, v));
77 return e;
78 }
79
addLat(unsigned t,unsigned i,unsigned e)80 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
81 assert(t < numTensors && i < numLoops);
82 unsigned p = latPoints.size();
83 latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
84 return p;
85 }
86
addSet()87 unsigned Merger::addSet() {
88 unsigned s = latSets.size();
89 latSets.emplace_back(SmallVector<unsigned, 16>());
90 return s;
91 }
92
conjLatPoint(Kind kind,unsigned p0,unsigned p1)93 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
94 unsigned p = latPoints.size();
95 llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
96 nb |= latPoints[p1].bits;
97 unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
98 latPoints.push_back(LatPoint(nb, e));
99 return p;
100 }
101
takeConj(Kind kind,unsigned s0,unsigned s1)102 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
103 unsigned s = addSet();
104 for (unsigned p0 : latSets[s0])
105 for (unsigned p1 : latSets[s1])
106 latSets[s].push_back(conjLatPoint(kind, p0, p1));
107 return s;
108 }
109
takeDisj(Kind kind,unsigned s0,unsigned s1)110 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
111 unsigned s = takeConj(kind, s0, s1);
112 // Followed by all in s0.
113 for (unsigned p : latSets[s0])
114 latSets[s].push_back(p);
115 // Map binary 0-y to unary -y.
116 if (kind == kSubF)
117 s1 = mapSet(kNegF, s1);
118 else if (kind == kSubI)
119 s1 = mapSet(kNegI, s1);
120 // Followed by all in s1.
121 for (unsigned p : latSets[s1])
122 latSets[s].push_back(p);
123 return s;
124 }
125
mapSet(Kind kind,unsigned s0,Value v)126 unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) {
127 assert(kAbsF <= kind && kind <= kBitCast);
128 unsigned s = addSet();
129 for (unsigned p : latSets[s0]) {
130 unsigned e = addExp(kind, latPoints[p].exp, v);
131 latPoints.push_back(LatPoint(latPoints[p].bits, e));
132 latSets[s].push_back(latPoints.size() - 1);
133 }
134 return s;
135 }
136
optimizeSet(unsigned s0)137 unsigned Merger::optimizeSet(unsigned s0) {
138 unsigned s = addSet();
139 assert(latSets[s0].size() != 0);
140 unsigned p0 = latSets[s0][0];
141 for (unsigned p1 : latSets[s0]) {
142 bool add = true;
143 if (p0 != p1) {
144 // Is this a straightforward copy?
145 unsigned e = latPoints[p1].exp;
146 if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
147 continue;
148 // Conjunction already covered?
149 for (unsigned p2 : latSets[s]) {
150 assert(!latGT(p1, p2)); // Lj => Li would be bad
151 if (onlyDenseDiff(p2, p1)) {
152 add = false;
153 break;
154 }
155 }
156 assert(!add || latGT(p0, p1));
157 }
158 if (add)
159 latSets[s].push_back(p1);
160 }
161 for (unsigned p : latSets[s])
162 latPoints[p].simple = simplifyCond(s, p);
163 return s;
164 }
165
simplifyCond(unsigned s0,unsigned p0)166 llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
167 // First determine if this lattice point is a *singleton*, i.e.,
168 // the last point in a lattice, no other is less than this one.
169 bool isSingleton = true;
170 for (unsigned p1 : latSets[s0]) {
171 if (p0 != p1 && latGT(p0, p1)) {
172 isSingleton = false;
173 break;
174 }
175 }
176 // Now apply the two basic rules.
177 llvm::BitVector simple = latPoints[p0].bits;
178 bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
179 for (unsigned b = 0, be = simple.size(); b < be; b++) {
180 if (simple[b] && !isDim(b, kSparse)) {
181 if (reset)
182 simple.reset(b);
183 reset = true;
184 }
185 }
186 return simple;
187 }
188
latGT(unsigned i,unsigned j) const189 bool Merger::latGT(unsigned i, unsigned j) const {
190 const llvm::BitVector &bitsi = latPoints[i].bits;
191 const llvm::BitVector &bitsj = latPoints[j].bits;
192 assert(bitsi.size() == bitsj.size());
193 if (bitsi.count() > bitsj.count()) {
194 for (unsigned b = 0, be = bitsj.size(); b < be; b++)
195 if (bitsj[b] && !bitsi[b])
196 return false;
197 return true;
198 }
199 return false;
200 }
201
onlyDenseDiff(unsigned i,unsigned j)202 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
203 llvm::BitVector tmp = latPoints[j].bits;
204 tmp ^= latPoints[i].bits;
205 return !hasAnyDimOf(tmp, kSparse);
206 }
207
hasAnyDimOf(const llvm::BitVector & bits,Dim d) const208 bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
209 for (unsigned b = 0, be = bits.size(); b < be; b++)
210 if (bits[b] && isDim(b, d))
211 return true;
212 return false;
213 }
214
isConjunction(unsigned t,unsigned e) const215 bool Merger::isConjunction(unsigned t, unsigned e) const {
216 switch (tensorExps[e].kind) {
217 case kTensor:
218 return tensorExps[e].tensor == t;
219 case kAbsF:
220 case kCeilF:
221 case kFloorF:
222 case kNegF:
223 case kNegI:
224 case kTruncF:
225 case kExtF:
226 case kCastFS:
227 case kCastFU:
228 case kCastSF:
229 case kCastUF:
230 case kCastS:
231 case kCastU:
232 case kTruncI:
233 case kBitCast:
234 return isConjunction(t, tensorExps[e].children.e0);
235 case kDivF: // note: x / c only
236 case kDivS:
237 case kDivU:
238 assert(!maybeZero(tensorExps[e].children.e1));
239 return isConjunction(t, tensorExps[e].children.e0);
240 case kShrS: // note: x >> inv only
241 case kShrU:
242 case kShlI:
243 assert(isInvariant(tensorExps[e].children.e1));
244 return isConjunction(t, tensorExps[e].children.e0);
245 case kMulF:
246 case kMulI:
247 case kAndI:
248 return isConjunction(t, tensorExps[e].children.e0) ||
249 isConjunction(t, tensorExps[e].children.e1);
250 default:
251 return false;
252 }
253 }
254
255 #ifndef NDEBUG
256
257 //===----------------------------------------------------------------------===//
258 // Print methods (for debugging).
259 //===----------------------------------------------------------------------===//
260
kindToOpSymbol(Kind kind)261 static const char *kindToOpSymbol(Kind kind) {
262 switch (kind) {
263 case kTensor:
264 return "tensor";
265 case kInvariant:
266 return "invariant";
267 case kAbsF:
268 return "abs";
269 case kCeilF:
270 return "ceil";
271 case kFloorF:
272 return "floor";
273 case kNegF:
274 return "-";
275 case kNegI:
276 return "-";
277 case kTruncF:
278 case kExtF:
279 case kCastFS:
280 case kCastFU:
281 case kCastSF:
282 case kCastUF:
283 case kCastS:
284 case kCastU:
285 case kTruncI:
286 case kBitCast:
287 return "cast";
288 case kMulF:
289 return "*";
290 case kMulI:
291 return "*";
292 case kDivF:
293 return "/";
294 case kDivS:
295 return "/";
296 case kDivU:
297 return "/";
298 case kAddF:
299 return "+";
300 case kAddI:
301 return "+";
302 case kSubF:
303 return "-";
304 case kSubI:
305 return "-";
306 case kAndI:
307 return "&";
308 case kOrI:
309 return "|";
310 case kXorI:
311 return "^";
312 case kShrS:
313 return "a>>";
314 case kShrU:
315 return ">>";
316 case kShlI:
317 return "<<";
318 }
319 llvm_unreachable("unexpected kind for symbol");
320 }
321
dumpExp(unsigned e) const322 void Merger::dumpExp(unsigned e) const {
323 switch (tensorExps[e].kind) {
324 case kTensor:
325 if (tensorExps[e].tensor == syntheticTensor)
326 llvm::dbgs() << "synthetic_";
327 else if (tensorExps[e].tensor == outTensor)
328 llvm::dbgs() << "output_";
329 llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
330 break;
331 case kInvariant:
332 llvm::dbgs() << "invariant";
333 break;
334 case kAbsF:
335 case kCeilF:
336 case kFloorF:
337 case kNegF:
338 case kNegI:
339 case kTruncF:
340 case kExtF:
341 case kCastFS:
342 case kCastFU:
343 case kCastSF:
344 case kCastUF:
345 case kCastS:
346 case kCastU:
347 case kTruncI:
348 case kBitCast:
349 llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
350 dumpExp(tensorExps[e].children.e0);
351 break;
352 default:
353 llvm::dbgs() << "(";
354 dumpExp(tensorExps[e].children.e0);
355 llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
356 dumpExp(tensorExps[e].children.e1);
357 llvm::dbgs() << ")";
358 }
359 }
360
dumpLat(unsigned p) const361 void Merger::dumpLat(unsigned p) const {
362 llvm::dbgs() << "lat(";
363 dumpBits(latPoints[p].bits);
364 llvm::dbgs() << " :";
365 dumpBits(latPoints[p].simple);
366 llvm::dbgs() << " : ";
367 dumpExp(latPoints[p].exp);
368 llvm::dbgs() << " )\n";
369 }
370
dumpSet(unsigned s) const371 void Merger::dumpSet(unsigned s) const {
372 llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
373 for (unsigned p : latSets[s]) {
374 llvm::dbgs() << " ";
375 dumpLat(p);
376 }
377 llvm::dbgs() << "}\n";
378 }
379
dumpBits(const llvm::BitVector & bits) const380 void Merger::dumpBits(const llvm::BitVector &bits) const {
381 for (unsigned b = 0, be = bits.size(); b < be; b++) {
382 if (bits[b]) {
383 unsigned t = tensor(b);
384 unsigned i = index(b);
385 llvm::dbgs() << " i_" << t << "_" << i << "_";
386 switch (dims[t][i]) {
387 case kSparse:
388 llvm::dbgs() << "S";
389 break;
390 case kDense:
391 llvm::dbgs() << "D";
392 break;
393 case kSingle:
394 llvm::dbgs() << "T";
395 break;
396 case kUndef:
397 llvm::dbgs() << "U";
398 break;
399 }
400 }
401 }
402 }
403
404 #endif // NDEBUG
405
406 //===----------------------------------------------------------------------===//
407 // Builder methods.
408 //===----------------------------------------------------------------------===//
409
buildLattices(unsigned e,unsigned i)410 unsigned Merger::buildLattices(unsigned e, unsigned i) {
411 Kind kind = tensorExps[e].kind;
412 switch (kind) {
413 case kTensor:
414 case kInvariant: {
415 // Either the index is really used in the tensor expression, or it is
416 // set to the undefined index in that dimension. An invariant expression
417 // is set to a synthetic tensor with undefined indices only.
418 unsigned s = addSet();
419 unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
420 latSets[s].push_back(addLat(t, i, e));
421 return s;
422 }
423 case kAbsF:
424 case kCeilF:
425 case kFloorF:
426 case kNegF:
427 case kNegI:
428 case kTruncF:
429 case kExtF:
430 case kCastFS:
431 case kCastFU:
432 case kCastSF:
433 case kCastUF:
434 case kCastS:
435 case kCastU:
436 case kTruncI:
437 case kBitCast:
438 // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
439 // lattice set of the operand through the operator into a new set.
440 //
441 // -y|!y | y |
442 // --+---+---+
443 // | 0 |-y |
444 return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
445 tensorExps[e].val);
446 case kMulF:
447 case kMulI:
448 case kAndI:
449 // A multiplicative operation only needs to be performed
450 // for the conjunction of sparse iteration spaces.
451 //
452 // x*y|!y | y |
453 // ---+---+---+
454 // !x | 0 | 0 |
455 // x | 0 |x*y|
456 return takeConj(kind, // take binary conjunction
457 buildLattices(tensorExps[e].children.e0, i),
458 buildLattices(tensorExps[e].children.e1, i));
459 case kDivF:
460 case kDivS:
461 case kDivU:
462 // A division is tricky, since 0/0, 0/c, c/0 all have
463 // specific outcomes for floating-point and integers.
464 // Thus, we need to traverse the full iteration space.
465 //
466 // x/y|!y | y |
467 // ---+---+---+
468 // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
469 // x |x/0|x/y| INT: x/0=exception for any x
470 //
471 // TODO: for now we "fixed" this by only accepting x/c cases
472 // during expression building, so that the conjunction
473 // rules applies (viz. x/c = x*(1/c) as far as lattice
474 // construction is concerned).
475 assert(!maybeZero(tensorExps[e].children.e1));
476 return takeConj(kind, // take binary conjunction
477 buildLattices(tensorExps[e].children.e0, i),
478 buildLattices(tensorExps[e].children.e1, i));
479 case kAddF:
480 case kAddI:
481 case kSubF:
482 case kSubI:
483 case kOrI:
484 case kXorI:
485 // An additive operation needs to be performed
486 // for the disjunction of sparse iteration spaces.
487 //
488 // x+y|!y | y | x-y|!y | y |
489 // ---+---+---+ ---+---+---+
490 // !x | 0 | y | !x | 0 |-y |
491 // x | x |x+y| x | x |x-y|
492 return takeDisj(kind, // take binary disjunction
493 buildLattices(tensorExps[e].children.e0, i),
494 buildLattices(tensorExps[e].children.e1, i));
495 case kShrS:
496 case kShrU:
497 case kShlI:
498 // A shift operation by an invariant amount (viz. tensor expressions
499 // can only occur at the left-hand-side of the operator) can be handled
500 // with the conjuction rule.
501 assert(isInvariant(tensorExps[e].children.e1));
502 return takeConj(kind, // take binary conjunction
503 buildLattices(tensorExps[e].children.e0, i),
504 buildLattices(tensorExps[e].children.e1, i));
505 }
506 llvm_unreachable("unexpected expression kind");
507 }
508
buildTensorExpFromLinalg(linalg::GenericOp op)509 Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
510 Operation *yield = op.region().front().getTerminator();
511 return buildTensorExp(op, yield->getOperand(0));
512 }
513
514 /// Only returns false if we are certain this is a nonzero.
maybeZero(unsigned e) const515 bool Merger::maybeZero(unsigned e) const {
516 if (tensorExps[e].kind == kInvariant) {
517 if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
518 return c.getValue() == 0;
519 if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
520 return c.getValue().isZero();
521 }
522 return true;
523 }
524
isInvariant(unsigned e) const525 bool Merger::isInvariant(unsigned e) const {
526 return tensorExps[e].kind == kInvariant;
527 }
528
inferType(unsigned e,Value src)529 Type Merger::inferType(unsigned e, Value src) {
530 // Obtain the destination type from the cast node.
531 Type dtp = tensorExps[e].val.getType();
532 // Inspect source type. For vector types, apply the same
533 // vectorization to the destination type.
534 if (auto vtp = src.getType().dyn_cast<VectorType>())
535 return VectorType::get(vtp.getNumElements(), dtp);
536 return dtp;
537 }
538
buildTensorExp(linalg::GenericOp op,Value v)539 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
540 if (auto arg = v.dyn_cast<BlockArgument>()) {
541 unsigned argN = arg.getArgNumber();
542 // Any argument of the generic op that is not marked as a scalar
543 // argument is considered a tensor, indexed by the implicit loop
544 // bounds. This includes rank-0 tensor arguments.
545 if (arg.getOwner()->getParentOp() == op) {
546 OpOperand *t = op.getInputAndOutputOperands()[argN];
547 if (!op.isScalar(t))
548 return addExp(kTensor, argN);
549 v = t->get(); // get scalar value
550 }
551 // Any other argument (marked as scalar argument for the generic op
552 // or belonging to an enveloping op) is considered invariant.
553 return addExp(kInvariant, v);
554 }
555 // Something defined outside is invariant.
556 Operation *def = v.getDefiningOp();
557 if (def->getBlock() != &op.region().front())
558 return addExp(kInvariant, v);
559 // Construct unary operations if subexpression can be built.
560 if (def->getNumOperands() == 1) {
561 auto x = buildTensorExp(op, def->getOperand(0));
562 if (x.hasValue()) {
563 unsigned e = x.getValue();
564 if (isa<AbsFOp>(def))
565 return addExp(kAbsF, e);
566 if (isa<CeilFOp>(def))
567 return addExp(kCeilF, e);
568 if (isa<FloorFOp>(def))
569 return addExp(kFloorF, e);
570 if (isa<NegFOp>(def))
571 return addExp(kNegF, e); // no negi in std
572 if (isa<FPTruncOp>(def))
573 return addExp(kTruncF, e, v);
574 if (isa<FPExtOp>(def))
575 return addExp(kExtF, e, v);
576 if (isa<FPToSIOp>(def))
577 return addExp(kCastFS, e, v);
578 if (isa<FPToUIOp>(def))
579 return addExp(kCastFU, e, v);
580 if (isa<SIToFPOp>(def))
581 return addExp(kCastSF, e, v);
582 if (isa<UIToFPOp>(def))
583 return addExp(kCastUF, e, v);
584 if (isa<SignExtendIOp>(def))
585 return addExp(kCastS, e, v);
586 if (isa<ZeroExtendIOp>(def))
587 return addExp(kCastU, e, v);
588 if (isa<TruncateIOp>(def))
589 return addExp(kTruncI, e, v);
590 if (isa<BitcastOp>(def))
591 return addExp(kBitCast, e, v);
592 }
593 }
594 // Construct binary operations if subexpressions can be built.
595 // TODO: see buildLattices() for an explanation of rejecting
596 // certain division and shift operations
597 if (def->getNumOperands() == 2) {
598 auto x = buildTensorExp(op, def->getOperand(0));
599 auto y = buildTensorExp(op, def->getOperand(1));
600 if (x.hasValue() && y.hasValue()) {
601 unsigned e0 = x.getValue();
602 unsigned e1 = y.getValue();
603 if (isa<MulFOp>(def))
604 return addExp(kMulF, e0, e1);
605 if (isa<MulIOp>(def))
606 return addExp(kMulI, e0, e1);
607 if (isa<DivFOp>(def) && !maybeZero(e1))
608 return addExp(kDivF, e0, e1);
609 if (isa<SignedDivIOp>(def) && !maybeZero(e1))
610 return addExp(kDivS, e0, e1);
611 if (isa<UnsignedDivIOp>(def) && !maybeZero(e1))
612 return addExp(kDivU, e0, e1);
613 if (isa<AddFOp>(def))
614 return addExp(kAddF, e0, e1);
615 if (isa<AddIOp>(def))
616 return addExp(kAddI, e0, e1);
617 if (isa<SubFOp>(def))
618 return addExp(kSubF, e0, e1);
619 if (isa<SubIOp>(def))
620 return addExp(kSubI, e0, e1);
621 if (isa<AndOp>(def))
622 return addExp(kAndI, e0, e1);
623 if (isa<OrOp>(def))
624 return addExp(kOrI, e0, e1);
625 if (isa<XOrOp>(def))
626 return addExp(kXorI, e0, e1);
627 if (isa<SignedShiftRightOp>(def) && isInvariant(e1))
628 return addExp(kShrS, e0, e1);
629 if (isa<UnsignedShiftRightOp>(def) && isInvariant(e1))
630 return addExp(kShrU, e0, e1);
631 if (isa<ShiftLeftOp>(def) && isInvariant(e1))
632 return addExp(kShlI, e0, e1);
633 }
634 }
635 // Cannot build.
636 return None;
637 }
638
buildExp(PatternRewriter & rewriter,Location loc,unsigned e,Value v0,Value v1)639 Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
640 Value v0, Value v1) {
641 switch (tensorExps[e].kind) {
642 case kTensor:
643 case kInvariant:
644 llvm_unreachable("unexpected non-op");
645 // Unary ops.
646 case kAbsF:
647 return rewriter.create<AbsFOp>(loc, v0);
648 case kCeilF:
649 return rewriter.create<CeilFOp>(loc, v0);
650 case kFloorF:
651 return rewriter.create<FloorFOp>(loc, v0);
652 case kNegF:
653 return rewriter.create<NegFOp>(loc, v0);
654 case kNegI: // no negi in std
655 return rewriter.create<SubIOp>(
656 loc,
657 rewriter.create<ConstantOp>(loc, v0.getType(),
658 rewriter.getZeroAttr(v0.getType())),
659 v0);
660 case kTruncF:
661 return rewriter.create<FPTruncOp>(loc, v0, inferType(e, v0));
662 case kExtF:
663 return rewriter.create<FPExtOp>(loc, v0, inferType(e, v0));
664 case kCastFS:
665 return rewriter.create<FPToSIOp>(loc, v0, inferType(e, v0));
666 case kCastFU:
667 return rewriter.create<FPToUIOp>(loc, v0, inferType(e, v0));
668 case kCastSF:
669 return rewriter.create<SIToFPOp>(loc, v0, inferType(e, v0));
670 case kCastUF:
671 return rewriter.create<UIToFPOp>(loc, v0, inferType(e, v0));
672 case kCastS:
673 return rewriter.create<SignExtendIOp>(loc, v0, inferType(e, v0));
674 case kCastU:
675 return rewriter.create<ZeroExtendIOp>(loc, v0, inferType(e, v0));
676 case kTruncI:
677 return rewriter.create<TruncateIOp>(loc, v0, inferType(e, v0));
678 case kBitCast:
679 return rewriter.create<BitcastOp>(loc, v0, inferType(e, v0));
680 // Binary ops.
681 case kMulF:
682 return rewriter.create<MulFOp>(loc, v0, v1);
683 case kMulI:
684 return rewriter.create<MulIOp>(loc, v0, v1);
685 case kDivF:
686 return rewriter.create<DivFOp>(loc, v0, v1);
687 case kDivS:
688 return rewriter.create<SignedDivIOp>(loc, v0, v1);
689 case kDivU:
690 return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
691 case kAddF:
692 return rewriter.create<AddFOp>(loc, v0, v1);
693 case kAddI:
694 return rewriter.create<AddIOp>(loc, v0, v1);
695 case kSubF:
696 return rewriter.create<SubFOp>(loc, v0, v1);
697 case kSubI:
698 return rewriter.create<SubIOp>(loc, v0, v1);
699 case kAndI:
700 return rewriter.create<AndOp>(loc, v0, v1);
701 case kOrI:
702 return rewriter.create<OrOp>(loc, v0, v1);
703 case kXorI:
704 return rewriter.create<XOrOp>(loc, v0, v1);
705 case kShrS:
706 return rewriter.create<SignedShiftRightOp>(loc, v0, v1);
707 case kShrU:
708 return rewriter.create<UnsignedShiftRightOp>(loc, v0, v1);
709 case kShlI:
710 return rewriter.create<ShiftLeftOp>(loc, v0, v1);
711 }
712 llvm_unreachable("unexpected expression kind in build");
713 }
714
715 } // namespace sparse_tensor
716 } // namespace mlir
717