1 //===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements expansion of tanh op.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Math/IR/Math.h"
14 #include "mlir/Dialect/Math/Transforms/Passes.h"
15 #include "mlir/Dialect/StandardOps/IR/Ops.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 using namespace mlir;
19 
20 /// Expands tanh op into
21 ///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
22 ///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
convertTanhOp(math::TanhOp op,PatternRewriter & rewriter)23 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
24   auto floatType = op.operand().getType();
25   Location loc = op.getLoc();
26   auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
27   auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
28   Value one = rewriter.create<ConstantOp>(loc, floatOne);
29   Value two = rewriter.create<ConstantOp>(loc, floatTwo);
30   Value doubledX = rewriter.create<MulFOp>(loc, op.operand(), two);
31 
32   // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
33   Value negDoubledX = rewriter.create<NegFOp>(loc, doubledX);
34   Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
35   Value dividend = rewriter.create<SubFOp>(loc, one, exp2x);
36   Value divisor = rewriter.create<AddFOp>(loc, one, exp2x);
37   Value positiveRes = rewriter.create<DivFOp>(loc, dividend, divisor);
38 
39   // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
40   exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
41   dividend = rewriter.create<SubFOp>(loc, exp2x, one);
42   divisor = rewriter.create<AddFOp>(loc, exp2x, one);
43   Value negativeRes = rewriter.create<DivFOp>(loc, dividend, divisor);
44 
45   // tanh(x) = x >= 0 ? positiveRes : negativeRes
46   auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
47   Value zero = rewriter.create<ConstantOp>(loc, floatZero);
48   Value cmpRes =
49       rewriter.create<CmpFOp>(loc, CmpFPredicate::OGE, op.operand(), zero);
50   rewriter.replaceOpWithNewOp<SelectOp>(op, cmpRes, positiveRes, negativeRes);
51   return success();
52 }
53 
populateExpandTanhPattern(RewritePatternSet & patterns)54 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
55   patterns.add(convertTanhOp);
56 }
57