1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file intrin_rule_llvm.cc
22  */
23 #ifdef TVM_LLVM_VERSION
24 
25 #include "intrin_rule_llvm.h"
26 
27 #include <tvm/tir/op.h>
28 
29 namespace tvm {
30 namespace codegen {
31 namespace llvm {
32 
33 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.prefetch")
34     .set_body(DispatchLLVMIntrin<::llvm::Intrinsic::prefetch, 4>);
35 
36 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp")
37     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
38 
39 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp2")
40     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>);
41 
42 // TODO(tvm-team): migrate the legalization transformations as a separate
43 //                 set of rules in TIR that can be shared across backends.
44 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.exp10")
__anon358d77900102(const TVMArgs& targs, TVMRetValue* rv) 45     .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
46       using tir::make_const;
47       using tir::make_zero;
48       PrimExpr e = targs[0];
49       const tir::CallNode* call = e.as<tir::CallNode>();
50       CHECK(call != nullptr);
51       const PrimExpr& x = call->args[0];
52       PrimExpr ln10 = make_const(x.dtype(), 2.302585093);
53       PrimExpr ret = exp(x * ln10);
54       *rv = ret;
55     });
56 
57 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fma")
58     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
59 
60 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log")
61     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
62 
63 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log2")
64     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log2, 1>);
65 
66 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.log10")
67     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log10, 1>);
68 
69 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sqrt")
70     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
71 
72 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.floor")
73     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
74 
75 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.ceil")
76     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
77 
78 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.trunc")
79     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
80 
81 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.fabs")
82     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);
83 
84 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.round")
85     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
86 
87 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.nearbyint")
88     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
89 
90 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tanh")
__anon358d77900202(const TVMArgs& targs, TVMRetValue* rv) 91     .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
92       using tir::make_const;
93       using tir::make_zero;
94       PrimExpr e = targs[0];
95       const tir::CallNode* call = e.as<tir::CallNode>();
96       CHECK(call != nullptr);
97       const PrimExpr& x = call->args[0];
98       PrimExpr one = make_const(x.dtype(), 1);
99       PrimExpr two = make_const(x.dtype(), 2);
100       PrimExpr neg_two = make_const(x.dtype(), -2);
101 
102       PrimExpr exp_neg2x = exp(neg_two * x);
103       PrimExpr exp_pos2x = exp(two * x);
104 
105       PrimExpr tanh_pos = (one - exp_neg2x) / (one + exp_neg2x);
106       PrimExpr tanh_neg = (exp_pos2x - one) / (exp_pos2x + one);
107       *rv = tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg);
108     });
109 
110 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
111     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>);
112 
113 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
114     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
115 
__anon358d77900302(const TVMArgs& targs, TVMRetValue* rv) 116 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan").set_body([](const TVMArgs& targs, TVMRetValue* rv) {
117   PrimExpr e = targs[0];
118   const tir::CallNode* call = e.as<tir::CallNode>();
119   CHECK(call != nullptr);
120   const PrimExpr& x = call->args[0];
121   PrimExpr tan_x = sin(x) / cos(x);
122   *rv = tan_x;
123 });
124 
125 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos")
126     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
127 
128 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cosh")
__anon358d77900402(const TVMArgs& targs, TVMRetValue* rv) 129     .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
130       using tir::make_const;
131       using tir::make_zero;
132       PrimExpr e = targs[0];
133       const tir::CallNode* call = e.as<tir::CallNode>();
134       CHECK(call != nullptr);
135       const PrimExpr& x = call->args[0];
136       PrimExpr two = make_const(x.dtype(), 2);
137       PrimExpr neg_one = make_const(x.dtype(), -1);
138       PrimExpr exp_negx = exp(neg_one * x);
139       PrimExpr exp_posx = exp(x);
140       PrimExpr ret = (exp_posx + exp_negx) / two;
141       *rv = ret;
142     });
143 
144 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sin")
145     .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>);
146 
147 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.sinh")
__anon358d77900502(const TVMArgs& targs, TVMRetValue* rv) 148     .set_body([](const TVMArgs& targs, TVMRetValue* rv) {
149       using tir::make_const;
150       using tir::make_zero;
151       PrimExpr e = targs[0];
152       const tir::CallNode* call = e.as<tir::CallNode>();
153       CHECK(call != nullptr);
154       const PrimExpr& x = call->args[0];
155       PrimExpr two = make_const(x.dtype(), 2);
156       PrimExpr neg_one = make_const(x.dtype(), -1);
157       PrimExpr exp_negx = exp(neg_one * x);
158       PrimExpr exp_posx = exp(x);
159       PrimExpr ret = (exp_posx - exp_negx) / two;
160       *rv = ret;
161     });
162 
163 }  // namespace llvm
164 }  // namespace codegen
165 }  // namespace tvm
166 
167 #endif  // LLVM_VERSION
168