1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  *
22  * \file mac_count.cc
 * \brief Pass to roughly count the number of MACs (Multiply-Accumulate)
 * operations of a model. Only MACs in conv2d, conv2d_transpose, dense,
 * and batch_matmul ops are counted.
 * This pass is only valid after the type inference pass has been run;
 * otherwise the count is 0.
27  */
28 
29 #include <tvm/relay/op.h>
30 #include <tvm/relay/attrs/nn.h>
31 #include <tvm/relay/expr_functor.h>
32 #include <tvm/relay/analysis.h>
33 #include <tvm/data_layout.h>
34 #include "pattern_util.h"
35 
36 namespace tvm {
37 namespace relay {
38 
39 namespace mac_count {
40 
GetCartesianProd(Array<IndexExpr> arr)41 inline int64_t GetCartesianProd(Array<IndexExpr> arr) {
42   int64_t ret = 1;
43   for (size_t i = 0; i < arr.size(); i++) {
44     const auto* intImm = arr[i].as<IntImm>();
45     ret *= static_cast<int64_t>(intImm->value);
46   }
47   return ret;
48 }
49 
50 /*
51  * \brief Preparation function for MAC count.
52  * \param call_node The call node.
53  * \return The number of MACs.
54  */
55 using FMacCount = runtime::TypedPackedFunc<
56   int64_t(const Call& call_node)>;
57 
58 //----------------------------------------------
59 // Per operator defs for MAC count
60 //----------------------------------------------
61 
ConvMacCount(const Call & call_node)62 int64_t ConvMacCount(const Call& call_node) {
63   if (!call_node->checked_type_.defined()) {
64     LOG(WARNING) << "The infer type pass should be called before the mac count pass";
65     return 0;
66   }
67   Array<Expr> args = call_node->args;
68   CHECK_EQ(args.size(), 2)
69     << "The number of input arguments of a CONV 2D node should be 2.";
70   const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>();
71   const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
72   Array<IndexExpr> data_shape = data_type->shape;
73   std::string data_layout = conv_2d_attr->data_layout;
74   int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
75   int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
76   CHECK_NE(C_ind, -1)
77     << "There is no input channel dimension.";
78   int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
79   if (c_ind != -1)
80     input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
81   Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
82   CHECK_EQ(kernel_size.size(), 2)
83     << "The dimension of the kernel in Conv 2D should be 2.";
84   const auto* expr = call_node->checked_type().as<TensorTypeNode>();
85   Array<IndexExpr> output_tensor = expr->shape;
86   CHECK(output_tensor.size() == 4 || output_tensor.size() == 5)
87     << "The dimension of the output tensor in Conv 2D should be 4 or 5.";
88   int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
89   CHECK_EQ(input_channel % conv_2d_attr->groups, 0)
90     << "The number of input channels is not divisble by groups.";
91   count *= input_channel/conv_2d_attr->groups;
92   return count;
93 }
94 
Conv2dTransposeMacCount(const Call & call_node)95 int64_t Conv2dTransposeMacCount(const Call& call_node) {
96   if (!call_node->checked_type_.defined()) {
97     LOG(WARNING) << "The infer type pass should be called before the mac count pass";
98     return 0;
99   }
100   Array<Expr> args = call_node->args;
101   CHECK_EQ(args.size(), 2)
102     << "The number of input arguments of a CONV 2D Transpose node should be 2.";
103   const auto* conv_2d_transpose_attr = call_node->attrs.as<Conv2DTransposeAttrs>();
104   const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
105   Array<IndexExpr> data_shape = data_type->shape;
106   std::string data_layout = conv_2d_transpose_attr->data_layout;
107   int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
108   int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
109   CHECK_NE(C_ind, -1)
110     << "There is no input channel dimension.";
111   int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
112   if (c_ind != -1)
113     input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
114   Array<IndexExpr> kernel_size = conv_2d_transpose_attr->kernel_size;
115   CHECK_EQ(kernel_size.size(), 2)
116     << "The dimension of the kernel in Conv 2D Transpose should be 2.";
117   const auto* expr = call_node->checked_type().as<TensorTypeNode>();
118   Array<IndexExpr> output_tensor = expr->shape;
119   CHECK(output_tensor.size() == 4 || output_tensor.size() == 5)
120     << "The dimension of the output tensor in Conv 2D Transpose should be 4 or 5.";
121   int64_t count = GetCartesianProd(output_tensor) * GetCartesianProd(kernel_size);
122   CHECK_EQ(input_channel % conv_2d_transpose_attr->groups, 0)
123     << "The number of input channels is not divisble by groups.";
124   count *= input_channel/conv_2d_transpose_attr->groups;
125   return count;
126 }
127 
DenseMacCount(const Call & call_node)128 int64_t DenseMacCount(const Call& call_node) {
129   if (!call_node->checked_type_.defined()) {
130     LOG(WARNING) << "The infer type pass should be called before the mac count pass";
131     return 0;
132   }
133   Array<Expr> args = call_node->args;
134   CHECK_EQ(args.size(), 2)
135       << "The number of input arguments of a Dense node should be 2.";
136   const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
137   const auto* weight_type = args[1]->checked_type().as<TensorTypeNode>();
138   Array<IndexExpr> data_shape = data_type->shape;
139   Array<IndexExpr> weight_shape = weight_type->shape;
140   CHECK(data_shape.size() == 2 && weight_shape.size() == 2)
141     << "The dimension of an input tensor to Dense node should be 2.";
142   int64_t d1 = static_cast<int64_t>(data_shape[0].as<IntImm>()->value);
143   int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImm>()->value);
144   int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImm>()->value);
145   int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImm>()->value);
146   CHECK_EQ(d2, d4)
147     << "The dimensions of input arguments do not match.";
148   int64_t count = d1 * d2 * d3;
149   return count;
150 }
151 
BatchMatmulMacCount(const Call & call_node)152 int64_t BatchMatmulMacCount(const Call& call_node) {
153   if (!call_node->checked_type_.defined()) {
154     LOG(WARNING) << "The infer type pass should be called before the mac count pass";
155     return 0;
156   }
157   Array<Expr> args = call_node->args;
158   CHECK_EQ(args.size(), 2);
159   Array<IndexExpr> x_shape = args[0]->checked_type().as<TensorTypeNode>()->shape;
160   Array<IndexExpr> y_shape = args[1]->checked_type().as<TensorTypeNode>()->shape;
161   int64_t batch = x_shape[0].as<IntImm>()->value;
162   int64_t m = x_shape[1].as<IntImm>()->value;
163   int64_t k = x_shape[2].as<IntImm>()->value;
164   int64_t n = y_shape[1].as<IntImm>()->value;
165   return batch * m * k * n;
166 }
167 
168 RELAY_REGISTER_OP("nn.conv2d")
169 .set_attr<FMacCount>("FMacCount", ConvMacCount);
170 
171 RELAY_REGISTER_OP("nn.conv2d_transpose")
172 .set_attr<FMacCount>("FMacCount", Conv2dTransposeMacCount);
173 
174 RELAY_REGISTER_OP("nn.dense")
175 .set_attr<FMacCount>("FMacCount", DenseMacCount);
176 
177 RELAY_REGISTER_OP("nn.batch_matmul")
178 .set_attr<FMacCount>("FMacCount", BatchMatmulMacCount);
179 
180 class MacCounter : private ExprVisitor {
181  public:
MacCounter()182   MacCounter() {
183     count_ = 0;
184   }
GetTotalMacNumber(const Expr & expr)185   static int64_t GetTotalMacNumber(const Expr& expr) {
186     LOG(INFO) << "This pass only counts MACs in direct conv2d, "
187               << "conv2d_transpose, dense, and batch_matmul ops";
188     MacCounter counter;
189     counter(expr);
190     return counter.count_;
191   }
192 
193  private:
VisitExpr_(const CallNode * call_node)194   void VisitExpr_(const CallNode* call_node) final {
195     static const auto& fprep =
196         Op::GetAttr<FMacCount>("FMacCount");
197     auto f = fprep.get(call_node->op, nullptr);
198     if (f != nullptr) count_ += f(GetRef<Call>(call_node));
199     ExprVisitor::VisitExpr_(call_node);
200   }
201 
202   int64_t count_;
203 };
204 
GetTotalMacNumber(const Expr & expr)205 int64_t GetTotalMacNumber(const Expr& expr) {
206   return MacCounter::GetTotalMacNumber(expr);
207 }
208 
209 TVM_REGISTER_API("relay._analysis.GetTotalMacNumber")
210 .set_body_typed(GetTotalMacNumber);
211 
212 }  // namespace mac_count
213 }  // namespace relay
214 }  // namespace tvm
215