//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
    : kind(k), val(v) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v);
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    assert(x != -1u && y == -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

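/// Adds a tensor expression node of the given kind and returns its index.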
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

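/// Adds an iteration lattice point for tensor t, loop index i, and
/// expression e, and returns its index.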
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

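/// Adds a new, initially empty, set of lattice points and returns its index.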
unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

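/// Computes a single conjunction of two lattice points by taking the union
/// of their loop index bits (effectively constructing a larger "intersection"
/// of the sparse iteration spaces) together with a newly constructed tensor
/// (sub)expression of the given kind. Returns the index of the new point.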
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

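/// Conjunctive merge of two lattice sets: each point in s0 is combined
/// with each point in s1 (a Cartesian product). Returns the index of the
/// new set.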
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

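/// Disjunctive merge of two lattice sets: all conjunctions of s0 and s1,
/// followed by all points in s0 and all points in s1. For the subtraction
/// kinds, the copied points of s1 are negated first, since 0-y yields -y.
/// Returns the index of the new set.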
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

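/// Maps a unary operator over the lattice set of the operand: each lattice
/// point on an expression E is copied over, but with OP E as new expression.
/// Returns the index of the new set.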
unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) {
  assert(kAbsF <= kind && kind <= kBitCast);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

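/// Optimizes the iteration lattice points in the given set. This method
/// should be called right before code generation to avoid generating
/// redundant loops and conditions.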
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(latSets[s0].size() != 0);
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

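/// Simplifies the conditions in a conjunction of a given lattice point
/// within the given set using two basic rules: (1) multiple dense conditions
/// are reduced to a single dense condition, and (2) a *singleton* point that
/// carries a sparse condition drops its dense conditions altogether.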
llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, so that no other point is less than it.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

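/// Returns true if Li > Lj, i.e. the bit vector of Li is a strict superset
/// of the bit vector of Lj.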
bool Merger::latGT(unsigned i, unsigned j) const {
  const llvm::BitVector &bitsi = latPoints[i].bits;
  const llvm::BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

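/// Returns true if lattice points Li and Lj only differ in non-sparse
/// conditions, i.e. they do not differ in any sparse dimension.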
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  llvm::BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

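/// Returns true if any set bit in the given bit vector corresponds to a
/// dimension of the queried kind.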
bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

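/// Returns true if the given tensor iterates *only* in the given tensor
/// expression, i.e. the tensor appears as a conjunctive factor of the
/// expression. For the output tensor, this defines a "simply dynamic"
/// operation [Bik96].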
bool Merger::isConjunction(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    return isConjunction(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isConjunction(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isConjunction(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    return isConjunction(t, tensorExps[e].children.e0) ||
           isConjunction(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const llvm::BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

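/// Builds the iteration lattices in a bottom-up traversal given the
/// remaining tensor (sub)expression and the next loop index in the
/// iteration graph. Returns the index of the resulting lattice set.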
unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = addSet();
    unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    // A zero-preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kMulF:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur on the left-hand side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  }
  llvm_unreachable("unexpected expression kind");
}

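/// Builds a tensor expression from the given Linalg operation.
/// Returns the index of the root expression on success.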
Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
      return c.getValue() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
      return c.getValue().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp);
  return dtp;
}

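/// Recursively builds a tensor (sub)expression for the given value within
/// the Linalg operation. Returns the index of the (sub)expression, or None
/// when the expression cannot be built.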
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<AbsFOp>(def))
        return addExp(kAbsF, e);
      if (isa<CeilFOp>(def))
        return addExp(kCeilF, e);
      if (isa<FloorFOp>(def))
        return addExp(kFloorF, e);
      if (isa<NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<FPTruncOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<FPExtOp>(def))
        return addExp(kExtF, e, v);
      if (isa<FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<SignExtendIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<ZeroExtendIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<TruncateIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<BitcastOp>(def))
        return addExp(kBitCast, e, v);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // TODO: see buildLattices() for an explanation of rejecting
  //       certain division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<SignedDivIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<UnsignedDivIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<AndOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<OrOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<XOrOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<SignedShiftRightOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<UnsignedShiftRightOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<ShiftLeftOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}

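/// Rebuilds SSA format from a tensor expression, emitting the actual
/// operation for expression e on the values v0 and v1 of its children.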
Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<AbsFOp>(loc, v0);
  case kCeilF:
    return rewriter.create<CeilFOp>(loc, v0);
  case kFloorF:
    return rewriter.create<FloorFOp>(loc, v0);
  case kNegF:
    return rewriter.create<NegFOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<SubIOp>(
        loc,
        rewriter.create<ConstantOp>(loc, v0.getType(),
                                    rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<FPTruncOp>(loc, v0, inferType(e, v0));
  case kExtF:
    return rewriter.create<FPExtOp>(loc, v0, inferType(e, v0));
  case kCastFS:
    return rewriter.create<FPToSIOp>(loc, v0, inferType(e, v0));
  case kCastFU:
    return rewriter.create<FPToUIOp>(loc, v0, inferType(e, v0));
  case kCastSF:
    return rewriter.create<SIToFPOp>(loc, v0, inferType(e, v0));
  case kCastUF:
    return rewriter.create<UIToFPOp>(loc, v0, inferType(e, v0));
  case kCastS:
    return rewriter.create<SignExtendIOp>(loc, v0, inferType(e, v0));
  case kCastU:
    return rewriter.create<ZeroExtendIOp>(loc, v0, inferType(e, v0));
  case kTruncI:
    return rewriter.create<TruncateIOp>(loc, v0, inferType(e, v0));
  case kBitCast:
    return rewriter.create<BitcastOp>(loc, v0, inferType(e, v0));
  // Binary ops.
  case kMulF:
    return rewriter.create<MulFOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<SignedDivIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<AddFOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<AndOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<OrOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<XOrOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<SignedShiftRightOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<UnsignedShiftRightOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<ShiftLeftOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir