1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
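/*!
 * \file intrin_rule_llvm.cc
 * \brief Registers the "tvm.intrin.rule.llvm.*" lowering rules. Each rule
 *        either dispatches a TIR intrinsic call directly to an LLVM intrinsic
 *        or legalizes it into simpler TIR expressions.
 */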
23 #ifdef TVM_LLVM_VERSION
24
25 #include "intrin_rule_llvm.h"
26
27 #include <tvm/tir/op.h>
28
29 namespace tvm {
30 namespace codegen {
31 namespace llvm {
32
33 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
34 .set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>);
35
36 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
37 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
38
39 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2")
40 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);
41
42 // TODO(tvm-team): migrate the legalization transformations as a separate
43 // set of rules in TIR that can be shared across backends.
44 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10")
__anon358d77900102(const TVMArgs& targs, TVMRetValue* rv) 45 .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
46 using tir::make_const;
47 using tir::make_zero;
48 PrimExpr e = targs[0];
49 const tir::CallNode* call = e.as<tir::CallNode>();
50 CHECK(call != nullptr);
51 const PrimExpr& x = call->args[0];
52 PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
53 PrimExpr ret = exp(x * ln10);
54 *rv = ret;
55 });
56
57 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
58 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
59
60 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
61 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
62
63 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log2")
64 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>);
65
66 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log10")
67 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>);
68
69 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
70 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
71
72 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor")
73 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
74
75 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil")
76 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
77
78 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc")
79 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
80
81 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs")
82 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);
83
84 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round")
85 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
86
87 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint")
88 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
89
90 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
__anon358d77900202(const TVMArgs& targs, TVMRetValue* rv) 91 .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
92 using tir::make_const;
93 using tir::make_zero;
94 PrimExpr e = targs[0];
95 const tir::CallNode* call = e.as<tir::CallNode>();
96 CHECK(call != nullptr);
97 const PrimExpr& x = call->args[0];
98 PrimExpr one = make_const(x.dtype(), 1);
99 PrimExpr two = make_const(x.dtype(), 2);
100 PrimExpr neg_two = make_const(x.dtype(), -2);
101
102 PrimExpr exp_neg2x = exp(neg_two * x);
103 PrimExpr exp_pos2x = exp(two * x);
104
105 PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
106 PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
107 *rv = tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
108 });
109
110 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
111 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
112
113 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
114 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
115
__anon358d77900302(const TVMArgs& targs, TVMRetValue* rv) 116 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan").set_body([](const TVMArgs& targs, TVMRetValue* rv) {
117 PrimExpr e = targs[0];
118 const tir::CallNode* call = e.as<tir::CallNode>();
119 CHECK(call != nullptr);
120 const PrimExpr& x = call->args[0];
121 PrimExpr tan_x = sin(x) / cos(x);
122 *rv = tan_x;
123 });
124
125 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos")
126 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
127
128 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh")
__anon358d77900402(const TVMArgs& targs, TVMRetValue* rv) 129 .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
130 using tir::make_const;
131 using tir::make_zero;
132 PrimExpr e = targs[0];
133 const tir::CallNode* call = e.as<tir::CallNode>();
134 CHECK(call != nullptr);
135 const PrimExpr& x = call->args[0];
136 PrimExpr two = make_const(x.dtype(), 2);
137 PrimExpr neg_one = make_const(x.dtype(), -1);
138 PrimExpr exp_negx = exp(neg_one * x);
139 PrimExpr exp_posx = exp(x);
140 PrimExpr ret = (exp_posx + exp_negx) / two;
141 *rv = ret;
142 });
143
144 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin")
145 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
146
147 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh")
__anon358d77900502(const TVMArgs& targs, TVMRetValue* rv) 148 .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
149 using tir::make_const;
150 using tir::make_zero;
151 PrimExpr e = targs[0];
152 const tir::CallNode* call = e.as<tir::CallNode>();
153 CHECK(call != nullptr);
154 const PrimExpr& x = call->args[0];
155 PrimExpr two = make_const(x.dtype(), 2);
156 PrimExpr neg_one = make_const(x.dtype(), -1);
157 PrimExpr exp_negx = exp(neg_one * x);
158 PrimExpr exp_posx = exp(x);
159 PrimExpr ret = (exp_posx - exp_negx) / two;
160 *rv = ret;
161 });
162
163 } // namespace llvm
164 } // namespace codegen
165 } // namespace tvm
166
167 #endif // LLVM_VERSION
168