1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 *
 * \file mac_count.cc
 * \brief Pass to roughly count the number of MACs (multiply-accumulate
 * operations) of a model. Only MACs in conv2d, conv2d_transpose, dense,
 * and batch_matmul ops are counted. This pass is valid only after the
 * type inference pass has been run; otherwise, the count is 0.
 */
28
29 #include <tvm/relay/op.h>
30 #include <tvm/relay/attrs/nn.h>
31 #include <tvm/relay/expr_functor.h>
32 #include <tvm/relay/analysis.h>
33 #include <tvm/data_layout.h>
34 #include "pattern_util.h"
35
36 namespace tvm {
37 namespace relay {
38
39 namespace mac_count {
40
GetCartesianProd(Array<IndexExpr> arr)41 inline int64_t GetCartesianProd(Array<IndexExpr> arr) {
42 int64_t ret = 1;
43 for (size_t i = 0; i < arr.size(); i++) {
44 const auto* intImm = arr[i].as<IntImm>();
45 ret *= static_cast<int64_t>(intImm->value);
46 }
47 return ret;
48 }
49
50 /*
51 * \brief Preparation function for MAC count.
52 * \param call_node The call node.
53 * \return The number of MACs.
54 */
55 using FMacCount = runtime::TypedPackedFunc<
56 int64_t(const Call& call_node)>;
57
58 //----------------------------------------------
59 // Per operator defs for MAC count
60 //----------------------------------------------
61
ConvMacCount(const Call & call_node)62 int64_t ConvMacCount(const Call& call_node) {
63 if (!call_node->checked_type_.defined()) {
64 LOG(WARNING) << "The infer type pass should be called before the mac count pass";
65 return 0;
66 }
67 Array<Expr> args = call_node->args;
68 CHECK_EQ(args.size(), 2)
69 << "The number of input arguments of a CONV 2D node should be 2.";
70 const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>();
71 const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
72 Array<IndexExpr> data_shape = data_type->shape;
73 std::string data_layout = conv_2d_attr->data_layout;
74 int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
75 int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
76 CHECK_NE(C_ind, -1)
77 << "There is no input channel dimension.";
78 int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
79 if (c_ind != -1)
80 input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
81 Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
82 CHECK_EQ(kernel_size.size(), 2)
83 << "The dimension of the kernel in Conv 2D should be 2.";
84 const auto* expr = call_node->checked_type().as<TensorTypeNode>();
85 Array<IndexExpr> output_tensor = expr->shape;
86 CHECK(output_tensor.size() == 4 || output_tensor.size() == 5)
87 << "The dimension of the output tensor in Conv 2D should be 4 or 5.";
88 int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
89 CHECK_EQ(input_channel % conv_2d_attr->groups, 0)
90 << "The number of input channels is not divisble by groups.";
91 count *= input_channel/conv_2d_attr->groups;
92 return count;
93 }
94
Conv2dTransposeMacCount(const Call & call_node)95 int64_t Conv2dTransposeMacCount(const Call& call_node) {
96 if (!call_node->checked_type_.defined()) {
97 LOG(WARNING) << "The infer type pass should be called before the mac count pass";
98 return 0;
99 }
100 Array<Expr> args = call_node->args;
101 CHECK_EQ(args.size(), 2)
102 << "The number of input arguments of a CONV 2D Transpose node should be 2.";
103 const auto* conv_2d_transpose_attr = call_node->attrs.as<Conv2DTransposeAttrs>();
104 const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
105 Array<IndexExpr> data_shape = data_type->shape;
106 std::string data_layout = conv_2d_transpose_attr->data_layout;
107 int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
108 int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
109 CHECK_NE(C_ind, -1)
110 << "There is no input channel dimension.";
111 int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
112 if (c_ind != -1)
113 input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
114 Array<IndexExpr> kernel_size = conv_2d_transpose_attr->kernel_size;
115 CHECK_EQ(kernel_size.size(), 2)
116 << "The dimension of the kernel in Conv 2D Transpose should be 2.";
117 const auto* expr = call_node->checked_type().as<TensorTypeNode>();
118 Array<IndexExpr> output_tensor = expr->shape;
119 CHECK(output_tensor.size() == 4 || output_tensor.size() == 5)
120 << "The dimension of the output tensor in Conv 2D Transpose should be 4 or 5.";
121 int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
122 CHECK_EQ(input_channel % conv_2d_transpose_attr->groups, 0)
123 << "The number of input channels is not divisble by groups.";
124 count *= input_channel/conv_2d_transpose_attr->groups;
125 return count;
126 }
127
DenseMacCount(const Call & call_node)128 int64_t DenseMacCount(const Call& call_node) {
129 if (!call_node->checked_type_.defined()) {
130 LOG(WARNING) << "The infer type pass should be called before the mac count pass";
131 return 0;
132 }
133 Array<Expr> args = call_node->args;
134 CHECK_EQ(args.size(), 2)
135 << "The number of input arguments of a Dense node should be 2.";
136 const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
137 const auto* weight_type = args[1]->checked_type().as<TensorTypeNode>();
138 Array<IndexExpr> data_shape = data_type->shape;
139 Array<IndexExpr> weight_shape = weight_type->shape;
140 CHECK(data_shape.size() == 2 && weight_shape.size() == 2)
141 << "The dimension of an input tensor to Dense node should be 2.";
142 int64_t d1 = static_cast<int64_t>(data_shape[0].as<IntImm>()->value);
143 int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImm>()->value);
144 int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImm>()->value);
145 int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImm>()->value);
146 CHECK_EQ(d2, d4)
147 << "The dimensions of input arguments do not match.";
148 int64_t count = d1 * d2 * d3;
149 return count;
150 }
151
BatchMatmulMacCount(const Call & call_node)152 int64_t BatchMatmulMacCount(const Call& call_node) {
153 if (!call_node->checked_type_.defined()) {
154 LOG(WARNING) << "The infer type pass should be called before the mac count pass";
155 return 0;
156 }
157 Array<Expr> args = call_node->args;
158 CHECK_EQ(args.size(), 2);
159 Array<IndexExpr> x_shape = args[0]->checked_type().as<TensorTypeNode>()->shape;
160 Array<IndexExpr> y_shape = args[1]->checked_type().as<TensorTypeNode>()->shape;
161 int64_t batch = x_shape[0].as<IntImm>()->value;
162 int64_t m = x_shape[1].as<IntImm>()->value;
163 int64_t k = x_shape[2].as<IntImm>()->value;
164 int64_t n = y_shape[1].as<IntImm>()->value;
165 return batch * m * k * n;
166 }
167
168 RELAY_REGISTER_OP("nn.conv2d")
169 .set_attr<FMacCount>("FMacCount", ConvMacCount);
170
171 RELAY_REGISTER_OP("nn.conv2d_transpose")
172 .set_attr<FMacCount>("FMacCount", Conv2dTransposeMacCount);
173
174 RELAY_REGISTER_OP("nn.dense")
175 .set_attr<FMacCount>("FMacCount", DenseMacCount);
176
177 RELAY_REGISTER_OP("nn.batch_matmul")
178 .set_attr<FMacCount>("FMacCount", BatchMatmulMacCount);
179
180 class MacCounter : private ExprVisitor {
181 public:
MacCounter()182 MacCounter() {
183 count_ = 0;
184 }
GetTotalMacNumber(const Expr & expr)185 static int64_t GetTotalMacNumber(const Expr& expr) {
186 LOG(INFO) << "This pass only counts MACs in direct conv2d, "
187 << "conv2d_transpose, dense, and batch_matmul ops";
188 MacCounter counter;
189 counter(expr);
190 return counter.count_;
191 }
192
193 private:
VisitExpr_(const CallNode * call_node)194 void VisitExpr_(const CallNode* call_node) final {
195 static const auto& fprep =
196 Op::GetAttr<FMacCount>("FMacCount");
197 auto f = fprep.get(call_node->op, nullptr);
198 if (f != nullptr) count_ += f(GetRef<Call>(call_node));
199 ExprVisitor::VisitExpr_(call_node);
200 }
201
202 int64_t count_;
203 };
204
GetTotalMacNumber(const Expr & expr)205 int64_t GetTotalMacNumber(const Expr& expr) {
206 return MacCounter::GetTotalMacNumber(expr);
207 }
208
209 TVM_REGISTER_API("relay._analysis.GetTotalMacNumber")
210 .set_body_typed(GetTotalMacNumber);
211
212 } // namespace mac_count
213 } // namespace relay
214 } // namespace tvm
215