1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  *
22  * \file tvm/relay/pass/pattern_util.h
23  * \brief Header of internal operator functions
24  *  These can be used for writing passes.
25  */
26 #ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
27 #define TVM_RELAY_PASS_PATTERN_UTIL_H_
28 
29 #include <builtin_fp16.h>
30 #include <tvm/data_layout.h>
31 #include <tvm/relay/op.h>
32 #include <tvm/relay/expr.h>
33 #include <tvm/relay/analysis.h>
34 #include <tvm/relay/attrs/nn.h>
35 #include <tvm/relay/attrs/transform.h>
36 #include <tvm/relay/attrs/reduce.h>
37 #include <string>
38 #include <utility>
39 
40 
41 namespace tvm {
42 namespace relay {
43 
/*!
 * \brief Dispatch DataType to the C++ data type
 *  during runtime.
 *
 *  Expands to an if/else-if chain that binds the typedef \p DType to the
 *  C++ type matching \p type and then executes the trailing statement block.
 *  Note: Float(16) is dispatched to uint16_t because fp16 values are stored
 *  as raw bits; callers must perform the float<->half conversion themselves.
 *  Unrecognized dtypes abort via LOG(FATAL).
 */
#define TVM_DTYPE_DISPATCH(type, DType, ...)            \
  if (type == Float(64)) {                              \
    typedef double DType;                               \
    {__VA_ARGS__}                                       \
  } else if (type == Float(32)) {                       \
    typedef float DType;                                \
    {__VA_ARGS__}                                       \
  } else if (type == Float(16)) {                       \
    typedef uint16_t DType;                             \
    {__VA_ARGS__}                                       \
  } else if (type == Int(64)) {                         \
    typedef int64_t DType;                              \
    {__VA_ARGS__}                                       \
  } else if (type == Int(32)) {                         \
    typedef int32_t DType;                              \
    {__VA_ARGS__}                                       \
  } else if (type == Int(16)) {                         \
    typedef int16_t DType;                              \
    {__VA_ARGS__}                                       \
  } else if (type == Int(8)) {                          \
    typedef int8_t DType;                               \
    {__VA_ARGS__}                                       \
  } else if (type == UInt(64)) {                        \
    typedef uint64_t DType;                             \
    {__VA_ARGS__}                                       \
  } else if (type == UInt(32)) {                        \
    typedef uint32_t DType;                             \
    {__VA_ARGS__}                                       \
  } else if (type == UInt(16)) {                        \
    typedef uint16_t DType;                             \
    {__VA_ARGS__}                                       \
  } else if (type == UInt(8)) {                         \
    typedef uint8_t DType;                              \
    {__VA_ARGS__}                                       \
  } else {                                              \
    LOG(FATAL) << "unknown data type " << type;         \
  }
85 
/*!
 * \brief Try to match lhs and rhs via broadcasting rule, such that:
 *
 * rhs matches the dimension of lhs specified by lhs_axes
 * rhs's value equals 1 on rest of dimensions.
 *
 * \param tlhs The type of left operand (data)
 * \param trhs The type right operand (bias)
 * \param lhs_axes The axes on lhs to match. Expected in ascending order;
 *        out-of-order axes will simply fail to match.
 * \param rhs_value A squeezed version of rhs which only contains matched dimension.
 * \return Whether match is successful.
 */
inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
                                     const TensorTypeNode* trhs,
                                     const Array<Integer>& lhs_axes,
                                     Expr* rhs_value = nullptr) {
  // rhs can never have more dimensions than lhs under the broadcast rule.
  if (tlhs->shape.size() < trhs->shape.size()) return false;
  AttrsEqual equal;
  // rhs is right-aligned against lhs: lhs axis i corresponds to rhs axis i - base.
  size_t base = tlhs->shape.size() - trhs->shape.size();
  size_t j = 0;  // cursor into lhs_axes

  NodePtr<SqueezeAttrs> squeeze_attrs;
  if (rhs_value != nullptr) {
    squeeze_attrs = make_node<SqueezeAttrs>();
  }

  for (size_t i = 0; i < tlhs->shape.size(); ++i) {
    if (j < lhs_axes.size() && i == static_cast<size_t>(lhs_axes[j]->value)) {
      // Matched axis: rhs must cover it and agree with lhs on the extent.
      if (i < base || !equal(tlhs->shape[i], trhs->shape[i - base])) {
        return false;
      }
      ++j;
    } else if (i >= base) {
      // Unmatched axis covered by rhs: must have extent 1 (broadcastable).
      if (!is_const_int(trhs->shape[i - base], 1)) {
        return false;
      }
      if (rhs_value != nullptr) {
        // Remember the unit axis so it can be squeezed away below.
        squeeze_attrs->axis.push_back(static_cast<int>(i - base));
      }
    }
  }
  if (rhs_value != nullptr && squeeze_attrs->axis.size() != 0) {
    // rhs_value[0] is *rhs_value: squeeze the collected unit axes out of rhs.
    static const Op& squeeze_op = Op::Get("squeeze");
    *rhs_value = CallNode::make(squeeze_op, {rhs_value[0]}, Attrs(squeeze_attrs), {});
  }
  return true;
}
133 
/*!
 * \brief Expand 1D Tensor to match axis.
 *
 * The result bias can be used to add or multiply to
 * the target Tensor on the specified axis via broadcasting rule.
 *
 * \param bias The bias.
 * \param target_ndim Target dimension.
 * \param axes The axis on the output we want to match on.
 *        Expected in ascending order (enforced by the CHECK below).
 * \return bias wrapped in expand_dims calls so that it broadcasts against
 *         a target_ndim-dimensional tensor on the given axes.
 */
inline Expr ExpandBiasToMatchAxis(Expr bias,
                                  int target_ndim,
                                  const Array<Integer>& axes) {
  static const Op& expand_dims = Op::Get("expand_dims");
  // Walk the matched axes from last to first, inserting unit dimensions.
  for (size_t i = axes.size(); i != 0; --i) {
    if (i == axes.size()) {
      // Pad out the trailing dimensions after the last matched axis.
      int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1;
      if (num_pad_axis > 0) {
        auto attrs = make_node<ExpandDimsAttrs>();
        attrs->axis = i;
        attrs->num_newaxis = static_cast<int>(num_pad_axis);
        bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
      }
    } else {
      // Fill the gap between two consecutive matched axes.
      int64_t diff = axes[i]->value - axes[i - 1]->value;
      CHECK_GE(diff, 0L);
      if (diff > 0) {
        auto attrs = make_node<ExpandDimsAttrs>();
        attrs->axis = i;
        attrs->num_newaxis = static_cast<int>(diff);
        bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
      }
    }
  }
  return bias;
}
170 
171 /*!
172  * \brief Check if the call is depthwise conv2d.
173  *
174  * \param call The conv2d call.
175  * \param param The conv2d attributes.
176  * \return Whether it is depthwise_conv2d.
177  */
IsDepthwiseConv2D(const Call & call,const Conv2DAttrs * param,const Layout & kernel_layout)178 inline bool IsDepthwiseConv2D(const Call& call,
179                               const Conv2DAttrs* param,
180                               const Layout& kernel_layout) {
181   static const Layout kOIHW("OIHW");
182   const auto bilayout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
183   auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
184   return is_const_int(wshape[0], param->groups) &&
185       is_const_int(wshape[1], 1);
186 }
187 
188 /*!
189  * \brief Get super-dimension of output channels of conv2d
190  * \param call The conv2d call.
191  * \return Super-dimension size of output channels of conv2d.
192  */
GetConv2DSuperChannelsDim(const CallNode * call)193 inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
194     auto param = call->attrs.as<Conv2DAttrs>();
195     auto tweight = call->args[1]->type_as<TensorTypeNode>();
196     auto index = param->kernel_layout.find('O');
197     CHECK_NE(index, std::string::npos);
198     auto channels = as_const_int(tweight->shape[index]);
199     return *channels;
200 }
201 
/*!
 * \brief Create a Constant with a scalar
 *
 * \param dtype The data type.
 * \param value The value of the scalar.
 * \return A Constant.
 */
template<typename T>
inline Constant MakeConstantScalar(DataType dtype, T value) {
  // 0-dim (scalar) NDArray allocated on CPU to hold the value.
  runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0});
  TVM_DTYPE_DISPATCH(dtype, DType, {
    if (dtype == Float(16)) {
      // convert to float16
      // storage is uint16_t
      // Software truncation: float32 bits -> float16 bit pattern.
      *static_cast<DType*>(arr->data) =
        __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
    } else {
      *static_cast<DType*>(arr->data) = value;
    }
  })
  return ConstantNode::make(arr);
}
224 
225 /*!
226  * \brief Check if two expressions are equal scalars.
227  * \param a The expression to be checked.
228  * \param b The expression to be checked
229  * \return Whether two expressions are equal scalars.
230  */
IsEqualScalar(const Expr & a,const Expr & b)231 inline bool IsEqualScalar(const Expr& a, const Expr& b) {
232   const auto* constant_a = a.as<ConstantNode>();
233   const auto* constant_b = b.as<ConstantNode>();
234   if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) {
235     return false;
236   }
237   return AlphaEqual(a, b);
238 }
239 
GetField(Expr t,size_t i)240 inline Expr GetField(Expr t, size_t i) {
241   return TupleGetItemNode::make(t, i);
242 }
243 
Pair(Expr l,Expr r)244 inline Expr Pair(Expr l, Expr r) {
245   return TupleNode::make({l, r});
246 }
247 
Exp(Expr e)248 inline Expr Exp(Expr e) {
249   static const Op& op = Op::Get("exp");
250   return CallNode::make(op, {e});
251 }
252 
Log(Expr e)253 inline Expr Log(Expr e) {
254   static const Op& op = Op::Get("log");
255   return CallNode::make(op, {e});
256 }
257 /*!
258  * \brief Get an immediate scalar from a Constant expr.
259  *
260  * \param expr The Constant expr.
261  * \return A scalar with type T.
262  */
263 template <typename T>
GetScalarFromConstant(Expr expr)264 T GetScalarFromConstant(Expr expr) {
265   const auto* n = expr.as<ConstantNode>();
266   CHECK(n->is_scalar());
267   return static_cast<T*>(n->data->data)[0];
268 }
269 
Cast(Expr x,DataType dtype)270 inline Expr Cast(Expr x, DataType dtype) {
271   static const Op& op = Op::Get("cast");
272   auto attrs = make_node<CastAttrs>();
273   attrs->dtype = dtype;
274   return CallNode::make(op, {x}, Attrs(attrs), {});
275 }
276 
Negative(Expr x)277 inline Expr Negative(Expr x) {
278   static const Op& op = Op::Get("negative");
279   return CallNode::make(op, {x}, Attrs(), {});
280 }
281 
282 
Sqrt(Expr x)283 inline Expr Sqrt(Expr x) {
284   static const Op& op = Op::Get("sqrt");
285   return CallNode::make(op, {x}, Attrs(), {});
286 }
287 
288 
Relu(Expr x)289 inline Expr Relu(Expr x) {
290   static const Op& op = Op::Get("nn.relu");
291   return CallNode::make(op, {x}, Attrs(), {});
292 }
293 
294 
Round(Expr x)295 inline Expr Round(Expr x) {
296   static const Op& op = Op::Get("round");
297   return CallNode::make(op, {x}, Attrs(), {});
298 }
299 
300 
Clip(Expr x,double a_min,double a_max)301 inline Expr Clip(Expr x, double a_min, double a_max) {
302   static const Op& op = Op::Get("clip");
303   auto attrs = make_node<ClipAttrs>();
304   attrs->a_min = a_min;
305   attrs->a_max = a_max;
306   return CallNode::make(op, {x}, Attrs(attrs), {});
307 }
308 
309 
Add(Expr lhs,Expr rhs)310 inline Expr Add(Expr lhs, Expr rhs) {
311   static const Op& op = Op::Get("add");
312   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
313 }
314 
315 
Subtract(Expr lhs,Expr rhs)316 inline Expr Subtract(Expr lhs, Expr rhs) {
317   static const Op& op = Op::Get("subtract");
318   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
319 }
320 
321 
Multiply(Expr lhs,Expr rhs)322 inline Expr Multiply(Expr lhs, Expr rhs) {
323   static const Op& op = Op::Get("multiply");
324   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
325 }
326 
327 
Divide(Expr lhs,Expr rhs)328 inline Expr Divide(Expr lhs, Expr rhs) {
329   static const Op& op = Op::Get("divide");
330   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
331 }
332 
ZerosLike(Expr e)333 inline Expr ZerosLike(Expr e) {
334   static const Op& op = Op::Get("zeros_like");
335   return CallNode::make(op, {e});
336 }
337 
Zeros(Array<IndexExpr> shape,DataType dtype)338 inline Expr Zeros(Array<IndexExpr> shape, DataType dtype) {
339   auto attrs = make_node<InitOpAttrs>();
340   attrs->shape = std::move(shape);
341   attrs->dtype = std::move(dtype);
342   static const Op& op = Op::Get("zeros");
343   return CallNode::make(op, {}, Attrs(attrs), {});
344 }
345 
OnesLike(Expr e)346 inline Expr OnesLike(Expr e) {
347   static const Op& op = Op::Get("ones_like");
348   return CallNode::make(op, {e});
349 }
350 
CollapseSumLike(Expr e)351 inline Expr CollapseSumLike(Expr e) {
352   static const Op& op = Op::Get("collapse_sum_like");
353   return CallNode::make(op, {e});
354 }
355 
Power(Expr lhs,Expr rhs)356 inline Expr Power(Expr lhs, Expr rhs) {
357   static const Op& op = Op::Get("power");
358   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
359 }
360 
361 
RightShift(Expr x,Expr nbit)362 inline Expr RightShift(Expr x, Expr nbit) {
363   static const Op& op = Op::Get("right_shift");
364   return CallNode::make(op, {x, nbit}, Attrs(), {});
365 }
366 
367 
LeftShift(Expr x,Expr nbit)368 inline Expr LeftShift(Expr x, Expr nbit) {
369   static const Op& op = Op::Get("left_shift");
370   return CallNode::make(op, {x, nbit}, Attrs(), {});
371 }
372 
373 
ReshapeLike(Expr lhs,Expr rhs)374 inline Expr ReshapeLike(Expr lhs, Expr rhs) {
375   static const Op& op = Op::Get("reshape_like");
376   return CallNode::make(op, {lhs, rhs}, Attrs(), {});
377 }
378 
379 
Copy(Expr data)380 inline Expr Copy(Expr data) {
381   static const Op& op = Op::Get("copy");
382   return CallNode::make(op, {data}, Attrs(), {});
383 }
384 
385 
Mean(Expr data,Array<Integer> axis,bool keepdims,bool exclude)386 inline Expr Mean(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
387   auto attrs = make_node<ReduceAttrs>();
388   attrs->axis = std::move(axis);
389   attrs->keepdims = keepdims;
390   attrs->exclude = exclude;
391   static const Op& op = Op::Get("mean");
392   return CallNode::make(op, {data}, Attrs(attrs), {});
393 }
394 
Variance(Expr data,Expr mean,Array<Integer> axis,bool keepdims,bool exclude)395 inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude) {
396   auto attrs = make_node<ReduceAttrs>();
397   attrs->axis = std::move(axis);
398   attrs->keepdims = keepdims;
399   attrs->exclude = exclude;
400   static const Op& op = Op::Get("variance");
401   return CallNode::make(op, {data, mean}, Attrs(attrs), {});
402 }
403 
404 
// Elementwise select: x where condition holds, else y.
static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
  static const Op& where_op = Op::Get("where");
  return CallNode::make(where_op, {condition, x, y});
}
409 
// Broadcasting elementwise comparison: lhs >= rhs.
static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
  static const Op& greater_equal_op = Op::Get("greater_equal");
  return CallNode::make(greater_equal_op, {lhs, rhs}, Attrs(), {});
}
414 
Full(Expr fill_value,Array<IndexExpr> shape,DataType dtype)415 static inline Expr Full(Expr fill_value,
416                         Array<IndexExpr> shape,
417                         DataType dtype) {
418   auto attrs = make_node<InitOpAttrs>();
419   attrs->shape = std::move(shape);
420   attrs->dtype = std::move(dtype);
421   static const Op& op = Op::Get("full");
422   return CallNode::make(op, {fill_value}, Attrs(attrs), {});
423 }
424 
Conv2D(Expr data,Expr weight,Array<IndexExpr> strides,Array<IndexExpr> padding,Array<IndexExpr> dilation,int groups,IndexExpr channels,Array<IndexExpr> kernel_size,std::string data_layout,std::string kernel_layout,std::string out_layout,DataType out_dtype)425 static inline Expr Conv2D(Expr data, Expr weight, Array<IndexExpr> strides,
426                           Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
427                           IndexExpr channels, Array<IndexExpr> kernel_size, std::string data_layout,
428                           std::string kernel_layout, std::string out_layout, DataType out_dtype) {
429   auto attrs = make_node<Conv2DAttrs>();
430   attrs->strides = std::move(strides);
431   attrs->padding = std::move(padding);
432   attrs->dilation = std::move(dilation);
433   attrs->groups = groups;
434   attrs->channels = std::move(channels);
435   attrs->kernel_size = std::move(kernel_size);
436   attrs->data_layout = std::move(data_layout);
437   attrs->kernel_layout = std::move(kernel_layout);
438   attrs->out_layout = std::move(out_layout);
439   attrs->out_dtype = std::move(out_dtype);
440   static const Op& op = Op::Get("nn.conv2d");
441   return CallNode::make(op, {data, weight}, Attrs(attrs), {});
442 }
443 
Dense(Expr data,Expr weight,IndexExpr units,DataType out_dtype)444 static inline Expr Dense(Expr data,
445                          Expr weight,
446                          IndexExpr units,
447                          DataType out_dtype) {
448   auto attrs = make_node<DenseAttrs>();
449   attrs->units = units;
450   attrs->out_dtype = out_dtype;
451   static const Op& op = Op::Get("nn.dense");
452   return CallNode::make(op, {data, weight}, Attrs(attrs), {});
453 }
454 
Sum(Expr data,Array<Integer> axis,bool keepdims,bool exclude)455 static inline Expr Sum(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
456   auto attrs = make_node<ReduceAttrs>();
457   attrs->axis = std::move(axis);
458   attrs->keepdims = keepdims;
459   attrs->exclude = exclude;
460   static const Op& op = Op::Get("sum");
461   return CallNode::make(op, {data}, Attrs(attrs), {});
462 }
463 
Reshape(Expr data,Array<Integer> newshape)464 static inline Expr Reshape(Expr data, Array<Integer> newshape) {
465   auto attrs = make_node<ReshapeAttrs>();
466   attrs->newshape = std::move(newshape);
467   attrs->reverse = false;
468   static const Op& op = Op::Get("reshape");
469   return CallNode::make(op, {data}, Attrs(attrs), {});
470 }
471 
AvgPool2D(Expr data,Array<IndexExpr> pool_size,Array<IndexExpr> strides,Array<IndexExpr> padding,std::string layout,bool ceil_mode,bool count_include_pad)472 static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
473                              Array<IndexExpr> padding, std::string layout, bool ceil_mode,
474                              bool count_include_pad) {
475   auto attrs = make_node<AvgPool2DAttrs>();
476   attrs->pool_size = std::move(pool_size);
477   attrs->strides = std::move(strides);
478   attrs->padding = std::move(padding);
479   attrs->layout = std::move(layout);
480   attrs->ceil_mode = ceil_mode;
481   attrs->count_include_pad = count_include_pad;
482   static const Op& op = Op::Get("nn.avg_pool2d");
483   return CallNode::make(op, {data}, Attrs(attrs), {});
484 }
485 
Pad(Expr data,Array<Array<IndexExpr>> pad_width,double pad_value,std::string pad_mode)486 static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value,
487                        std::string pad_mode) {
488   auto attrs = make_node<PadAttrs>();
489   attrs->pad_value = pad_value;
490   attrs->pad_width = std::move(pad_width);
491   attrs->pad_mode = std::move(pad_mode);
492   static const Op& op = Op::Get("nn.pad");
493   return CallNode::make(op, {data}, Attrs(attrs), {});
494 }
495 
Tile(Expr data,Array<Integer> reps)496 static inline Expr Tile(Expr data, Array<Integer> reps) {
497   auto attrs = make_node<TileAttrs>();
498   attrs->reps = reps;
499   static const Op& op = Op::Get("tile");
500   return CallNode::make(op, {data}, Attrs(attrs), {});
501 }
502 
// Constructors for additional transform/annotation ops; implemented in the
// corresponding op definition files and declared here for use by passes.
Expr MakeConcatenate(Expr data, int axis);

Expr MakeRepeat(Expr data, int repeats, int axis);

Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);

Expr MakeStack(Expr data, int axis);

Expr MakeSplit(Expr data, NodeRef indices_or_sections, int axis);

Expr MakeSqueeze(Expr data, Array<Integer> axis);

Expr MakeExpandDims(Expr data, int axis, int num_newaxis);

Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout);

Expr StopFusion(Expr data);

Expr CastHint(Expr data, DataType dtype);
522 
523 }  // namespace relay
524 }  // namespace tvm
525 #endif  // TVM_RELAY_PASS_PATTERN_UTIL_H_
526