1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 *
22 * \file tvm/relay/pass/pattern_util.h
23 * \brief Header of internal operator functions
24 * These can be used for writing passes.
25 */
26 #ifndef TVM_RELAY_PASS_PATTERN_UTIL_H_
27 #define TVM_RELAY_PASS_PATTERN_UTIL_H_
28
29 #include <builtin_fp16.h>
30 #include <tvm/data_layout.h>
31 #include <tvm/relay/op.h>
32 #include <tvm/relay/expr.h>
33 #include <tvm/relay/analysis.h>
34 #include <tvm/relay/attrs/nn.h>
35 #include <tvm/relay/attrs/transform.h>
36 #include <tvm/relay/attrs/reduce.h>
37 #include <string>
38 #include <utility>
39
40
41 namespace tvm {
42 namespace relay {
43
44 /*!
45 * \brief Dispatch DataType to the C++ data type
46 * during runtime.
47 */
48 #define TVM_DTYPE_DISPATCH(type, DType, ...) \
49 if (type == Float(64)) { \
50 typedef double DType; \
51 {__VA_ARGS__} \
52 } else if (type == Float(32)) { \
53 typedef float DType; \
54 {__VA_ARGS__} \
55 } else if (type == Float(16)) { \
56 typedef uint16_t DType; \
57 {__VA_ARGS__} \
58 } else if (type == Int(64)) { \
59 typedef int64_t DType; \
60 {__VA_ARGS__} \
61 } else if (type == Int(32)) { \
62 typedef int32_t DType; \
63 {__VA_ARGS__} \
64 } else if (type == Int(16)) { \
65 typedef int16_t DType; \
66 {__VA_ARGS__} \
67 } else if (type == Int(8)) { \
68 typedef int8_t DType; \
69 {__VA_ARGS__} \
70 } else if (type == UInt(64)) { \
71 typedef uint64_t DType; \
72 {__VA_ARGS__} \
73 } else if (type == UInt(32)) { \
74 typedef uint32_t DType; \
75 {__VA_ARGS__} \
76 } else if (type == UInt(16)) { \
77 typedef uint16_t DType; \
78 {__VA_ARGS__} \
79 } else if (type == UInt(8)) { \
80 typedef uint8_t DType; \
81 {__VA_ARGS__} \
82 } else { \
83 LOG(FATAL) << "unknown data type " << type; \
84 }
85
86 /*!
87 * \brief Try to match lhs and rhs via broadcasting rule, such that:
88 *
89 * rhs matches the dimension of lhs specified by lhs_axes
90 * rhs's value equals 1 on rest of dimensions.
91 *
92 * \param tlhs The type of left operand (data)
93 * \param trhs The type right operand (bias)
94 * \param lhs_axes The axes on lhs to match.
95 * \param rhs_value A squeezed version of rhs which only contains matched dimension.
96 * \return Whether match is successful.
97 */
98 inline bool MatchBroadcastToLeftAxes(const TensorTypeNode* tlhs,
99 const TensorTypeNode* trhs,
100 const Array<Integer>& lhs_axes,
101 Expr* rhs_value = nullptr) {
102 if (tlhs->shape.size() < trhs->shape.size()) return false;
103 AttrsEqual equal;
104 size_t base = tlhs->shape.size() - trhs->shape.size();
105 size_t j = 0;
106
107 NodePtr<SqueezeAttrs> squeeze_attrs;
108 if (rhs_value != nullptr) {
109 squeeze_attrs = make_node<SqueezeAttrs>();
110 }
111
112 for (size_t i = 0; i < tlhs->shape.size(); ++i) {
113 if (j < lhs_axes.size() && i == static_cast<size_t>(lhs_axes[j]->value)) {
114 if (i < base || !equal(tlhs->shape[i], trhs->shape[i - base])) {
115 return false;
116 }
117 ++j;
118 } else if (i >= base) {
119 if (!is_const_int(trhs->shape[i - base], 1)) {
120 return false;
121 }
122 if (rhs_value != nullptr) {
123 squeeze_attrs->axis.push_back(static_cast<int>(i - base));
124 }
125 }
126 }
127 if (rhs_value != nullptr && squeeze_attrs->axis.size() != 0) {
128 static const Op& squeeze_op = Op::Get("squeeze");
129 *rhs_value = CallNode::make(squeeze_op, {rhs_value[0]}, Attrs(squeeze_attrs), {});
130 }
131 return true;
132 }
133
134 /*!
135 * \brief Expand 1D Tensor to match axis.
136 *
137 * The result bias can be used to add or multiply to
138 * the target Tensor on the specified axis via broadcasting rule.
139 *
140 * \param bias The bias.
141 * \param target_ndim Target dimension.
142 * \param axes The axis on the output we want to match on.
143 */
ExpandBiasToMatchAxis(Expr bias,int target_ndim,const Array<Integer> & axes)144 inline Expr ExpandBiasToMatchAxis(Expr bias,
145 int target_ndim,
146 const Array<Integer>& axes) {
147 static const Op& expand_dims = Op::Get("expand_dims");
148 for (size_t i = axes.size(); i != 0; --i) {
149 if (i == axes.size()) {
150 int64_t num_pad_axis = target_ndim - axes[i - 1]->value - 1;
151 if (num_pad_axis > 0) {
152 auto attrs = make_node<ExpandDimsAttrs>();
153 attrs->axis = i;
154 attrs->num_newaxis = static_cast<int>(num_pad_axis);
155 bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
156 }
157 } else {
158 int64_t diff = axes[i]->value - axes[i - 1]->value;
159 CHECK_GE(diff, 0L);
160 if (diff > 0) {
161 auto attrs = make_node<ExpandDimsAttrs>();
162 attrs->axis = i;
163 attrs->num_newaxis = static_cast<int>(diff);
164 bias = CallNode::make(expand_dims, {bias}, Attrs(attrs), {});
165 }
166 }
167 }
168 return bias;
169 }
170
171 /*!
172 * \brief Check if the call is depthwise conv2d.
173 *
174 * \param call The conv2d call.
175 * \param param The conv2d attributes.
176 * \return Whether it is depthwise_conv2d.
177 */
IsDepthwiseConv2D(const Call & call,const Conv2DAttrs * param,const Layout & kernel_layout)178 inline bool IsDepthwiseConv2D(const Call& call,
179 const Conv2DAttrs* param,
180 const Layout& kernel_layout) {
181 static const Layout kOIHW("OIHW");
182 const auto bilayout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
183 auto wshape = bilayout.ForwardShape(call->args[1]->type_as<TensorTypeNode>()->shape);
184 return is_const_int(wshape[0], param->groups) &&
185 is_const_int(wshape[1], 1);
186 }
187
188 /*!
189 * \brief Get super-dimension of output channels of conv2d
190 * \param call The conv2d call.
191 * \return Super-dimension size of output channels of conv2d.
192 */
GetConv2DSuperChannelsDim(const CallNode * call)193 inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
194 auto param = call->attrs.as<Conv2DAttrs>();
195 auto tweight = call->args[1]->type_as<TensorTypeNode>();
196 auto index = param->kernel_layout.find('O');
197 CHECK_NE(index, std::string::npos);
198 auto channels = as_const_int(tweight->shape[index]);
199 return *channels;
200 }
201
202 /*!
203 * \brief Create a Constant with a scalar
204 *
205 * \param dtype The data type.
206 * \param value The value of the scalar.
207 * \return A Constant.
208 */
209 template<typename T>
MakeConstantScalar(DataType dtype,T value)210 inline Constant MakeConstantScalar(DataType dtype, T value) {
211 runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0});
212 TVM_DTYPE_DISPATCH(dtype, DType, {
213 if (dtype == Float(16)) {
214 // convert to float16
215 // storage is uint16_t
216 *static_cast<DType*>(arr->data) =
217 __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
218 } else {
219 *static_cast<DType*>(arr->data) = value;
220 }
221 })
222 return ConstantNode::make(arr);
223 }
224
225 /*!
226 * \brief Check if two expressions are equal scalars.
227 * \param a The expression to be checked.
228 * \param b The expression to be checked
229 * \return Whether two expressions are equal scalars.
230 */
IsEqualScalar(const Expr & a,const Expr & b)231 inline bool IsEqualScalar(const Expr& a, const Expr& b) {
232 const auto* constant_a = a.as<ConstantNode>();
233 const auto* constant_b = b.as<ConstantNode>();
234 if (!constant_a || !constant_b || !constant_a->is_scalar() || !constant_b->is_scalar()) {
235 return false;
236 }
237 return AlphaEqual(a, b);
238 }
239
GetField(Expr t,size_t i)240 inline Expr GetField(Expr t, size_t i) {
241 return TupleGetItemNode::make(t, i);
242 }
243
Pair(Expr l,Expr r)244 inline Expr Pair(Expr l, Expr r) {
245 return TupleNode::make({l, r});
246 }
247
Exp(Expr e)248 inline Expr Exp(Expr e) {
249 static const Op& op = Op::Get("exp");
250 return CallNode::make(op, {e});
251 }
252
Log(Expr e)253 inline Expr Log(Expr e) {
254 static const Op& op = Op::Get("log");
255 return CallNode::make(op, {e});
256 }
257 /*!
258 * \brief Get an immediate scalar from a Constant expr.
259 *
260 * \param expr The Constant expr.
261 * \return A scalar with type T.
262 */
263 template <typename T>
GetScalarFromConstant(Expr expr)264 T GetScalarFromConstant(Expr expr) {
265 const auto* n = expr.as<ConstantNode>();
266 CHECK(n->is_scalar());
267 return static_cast<T*>(n->data->data)[0];
268 }
269
Cast(Expr x,DataType dtype)270 inline Expr Cast(Expr x, DataType dtype) {
271 static const Op& op = Op::Get("cast");
272 auto attrs = make_node<CastAttrs>();
273 attrs->dtype = dtype;
274 return CallNode::make(op, {x}, Attrs(attrs), {});
275 }
276
Negative(Expr x)277 inline Expr Negative(Expr x) {
278 static const Op& op = Op::Get("negative");
279 return CallNode::make(op, {x}, Attrs(), {});
280 }
281
282
Sqrt(Expr x)283 inline Expr Sqrt(Expr x) {
284 static const Op& op = Op::Get("sqrt");
285 return CallNode::make(op, {x}, Attrs(), {});
286 }
287
288
Relu(Expr x)289 inline Expr Relu(Expr x) {
290 static const Op& op = Op::Get("nn.relu");
291 return CallNode::make(op, {x}, Attrs(), {});
292 }
293
294
Round(Expr x)295 inline Expr Round(Expr x) {
296 static const Op& op = Op::Get("round");
297 return CallNode::make(op, {x}, Attrs(), {});
298 }
299
300
Clip(Expr x,double a_min,double a_max)301 inline Expr Clip(Expr x, double a_min, double a_max) {
302 static const Op& op = Op::Get("clip");
303 auto attrs = make_node<ClipAttrs>();
304 attrs->a_min = a_min;
305 attrs->a_max = a_max;
306 return CallNode::make(op, {x}, Attrs(attrs), {});
307 }
308
309
Add(Expr lhs,Expr rhs)310 inline Expr Add(Expr lhs, Expr rhs) {
311 static const Op& op = Op::Get("add");
312 return CallNode::make(op, {lhs, rhs}, Attrs(), {});
313 }
314
315
Subtract(Expr lhs,Expr rhs)316 inline Expr Subtract(Expr lhs, Expr rhs) {
317 static const Op& op = Op::Get("subtract");
318 return CallNode::make(op, {lhs, rhs}, Attrs(), {});
319 }
320
321
Multiply(Expr lhs,Expr rhs)322 inline Expr Multiply(Expr lhs, Expr rhs) {
323 static const Op& op = Op::Get("multiply");
324 return CallNode::make(op, {lhs, rhs}, Attrs(), {});
325 }
326
327
Divide(Expr lhs,Expr rhs)328 inline Expr Divide(Expr lhs, Expr rhs) {
329 static const Op& op = Op::Get("divide");
330 return CallNode::make(op, {lhs, rhs}, Attrs(), {});
331 }
332
ZerosLike(Expr e)333 inline Expr ZerosLike(Expr e) {
334 static const Op& op = Op::Get("zeros_like");
335 return CallNode::make(op, {e});
336 }
337
Zeros(Array<IndexExpr> shape,DataType dtype)338 inline Expr Zeros(Array<IndexExpr> shape, DataType dtype) {
339 auto attrs = make_node<InitOpAttrs>();
340 attrs->shape = std::move(shape);
341 attrs->dtype = std::move(dtype);
342 static const Op& op = Op::Get("zeros");
343 return CallNode::make(op, {}, Attrs(attrs), {});
344 }
345
OnesLike(Expr e)346 inline Expr OnesLike(Expr e) {
347 static const Op& op = Op::Get("ones_like");
348 return CallNode::make(op, {e});
349 }
350
CollapseSumLike(Expr e)351 inline Expr CollapseSumLike(Expr e) {
352 static const Op& op = Op::Get("collapse_sum_like");
353 return CallNode::make(op, {e});
354 }
355
Power(Expr lhs,Expr rhs)356 inline Expr Power(Expr lhs, Expr rhs) {
357 static const Op& op = Op::Get("power");
358 return CallNode::make(op, {lhs, rhs}, Attrs(), {});
359 }
360
361
RightShift(Expr x,Expr nbit)362 inline Expr RightShift(Expr x, Expr nbit) {
363 static const Op& op = Op::Get("right_shift");
364 return CallNode::make(op, {x, nbit}, Attrs(), {});
365 }
366
367
LeftShift(Expr x,Expr nbit)368 inline Expr LeftShift(Expr x, Expr nbit) {
369 static const Op& op = Op::Get("left_shift");
370 return CallNode::make(op, {x, nbit}, Attrs(), {});
371 }
372
373
ReshapeLike(Expr lhs,Expr rhs)374 inline Expr ReshapeLike(Expr lhs, Expr rhs) {
375 static const Op& op = Op::Get("reshape_like");
376 return CallNode::make(op, {lhs, rhs}, Attrs(), {});
377 }
378
379
Copy(Expr data)380 inline Expr Copy(Expr data) {
381 static const Op& op = Op::Get("copy");
382 return CallNode::make(op, {data}, Attrs(), {});
383 }
384
385
Mean(Expr data,Array<Integer> axis,bool keepdims,bool exclude)386 inline Expr Mean(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
387 auto attrs = make_node<ReduceAttrs>();
388 attrs->axis = std::move(axis);
389 attrs->keepdims = keepdims;
390 attrs->exclude = exclude;
391 static const Op& op = Op::Get("mean");
392 return CallNode::make(op, {data}, Attrs(attrs), {});
393 }
394
Variance(Expr data,Expr mean,Array<Integer> axis,bool keepdims,bool exclude)395 inline Expr Variance(Expr data, Expr mean, Array<Integer> axis, bool keepdims, bool exclude) {
396 auto attrs = make_node<ReduceAttrs>();
397 attrs->axis = std::move(axis);
398 attrs->keepdims = keepdims;
399 attrs->exclude = exclude;
400 static const Op& op = Op::Get("variance");
401 return CallNode::make(op, {data, mean}, Attrs(attrs), {});
402 }
403
404
/*! \brief Elementwise select x where condition holds, else y. */
static inline Expr Where(const Expr& condition, const Expr& x, const Expr& y) {
  static const Op& where_op = Op::Get("where");
  return CallNode::make(where_op, {condition, x, y});
}
409
/*! \brief Build the comparison lhs >= rhs. */
static inline Expr GreaterEqual(const Expr& lhs, const Expr& rhs) {
  static const Op& ge_op = Op::Get("greater_equal");
  return CallNode::make(ge_op, {lhs, rhs}, Attrs(), {});
}
414
Full(Expr fill_value,Array<IndexExpr> shape,DataType dtype)415 static inline Expr Full(Expr fill_value,
416 Array<IndexExpr> shape,
417 DataType dtype) {
418 auto attrs = make_node<InitOpAttrs>();
419 attrs->shape = std::move(shape);
420 attrs->dtype = std::move(dtype);
421 static const Op& op = Op::Get("full");
422 return CallNode::make(op, {fill_value}, Attrs(attrs), {});
423 }
424
Conv2D(Expr data,Expr weight,Array<IndexExpr> strides,Array<IndexExpr> padding,Array<IndexExpr> dilation,int groups,IndexExpr channels,Array<IndexExpr> kernel_size,std::string data_layout,std::string kernel_layout,std::string out_layout,DataType out_dtype)425 static inline Expr Conv2D(Expr data, Expr weight, Array<IndexExpr> strides,
426 Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
427 IndexExpr channels, Array<IndexExpr> kernel_size, std::string data_layout,
428 std::string kernel_layout, std::string out_layout, DataType out_dtype) {
429 auto attrs = make_node<Conv2DAttrs>();
430 attrs->strides = std::move(strides);
431 attrs->padding = std::move(padding);
432 attrs->dilation = std::move(dilation);
433 attrs->groups = groups;
434 attrs->channels = std::move(channels);
435 attrs->kernel_size = std::move(kernel_size);
436 attrs->data_layout = std::move(data_layout);
437 attrs->kernel_layout = std::move(kernel_layout);
438 attrs->out_layout = std::move(out_layout);
439 attrs->out_dtype = std::move(out_dtype);
440 static const Op& op = Op::Get("nn.conv2d");
441 return CallNode::make(op, {data, weight}, Attrs(attrs), {});
442 }
443
Dense(Expr data,Expr weight,IndexExpr units,DataType out_dtype)444 static inline Expr Dense(Expr data,
445 Expr weight,
446 IndexExpr units,
447 DataType out_dtype) {
448 auto attrs = make_node<DenseAttrs>();
449 attrs->units = units;
450 attrs->out_dtype = out_dtype;
451 static const Op& op = Op::Get("nn.dense");
452 return CallNode::make(op, {data, weight}, Attrs(attrs), {});
453 }
454
Sum(Expr data,Array<Integer> axis,bool keepdims,bool exclude)455 static inline Expr Sum(Expr data, Array<Integer> axis, bool keepdims, bool exclude) {
456 auto attrs = make_node<ReduceAttrs>();
457 attrs->axis = std::move(axis);
458 attrs->keepdims = keepdims;
459 attrs->exclude = exclude;
460 static const Op& op = Op::Get("sum");
461 return CallNode::make(op, {data}, Attrs(attrs), {});
462 }
463
Reshape(Expr data,Array<Integer> newshape)464 static inline Expr Reshape(Expr data, Array<Integer> newshape) {
465 auto attrs = make_node<ReshapeAttrs>();
466 attrs->newshape = std::move(newshape);
467 attrs->reverse = false;
468 static const Op& op = Op::Get("reshape");
469 return CallNode::make(op, {data}, Attrs(attrs), {});
470 }
471
AvgPool2D(Expr data,Array<IndexExpr> pool_size,Array<IndexExpr> strides,Array<IndexExpr> padding,std::string layout,bool ceil_mode,bool count_include_pad)472 static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
473 Array<IndexExpr> padding, std::string layout, bool ceil_mode,
474 bool count_include_pad) {
475 auto attrs = make_node<AvgPool2DAttrs>();
476 attrs->pool_size = std::move(pool_size);
477 attrs->strides = std::move(strides);
478 attrs->padding = std::move(padding);
479 attrs->layout = std::move(layout);
480 attrs->ceil_mode = ceil_mode;
481 attrs->count_include_pad = count_include_pad;
482 static const Op& op = Op::Get("nn.avg_pool2d");
483 return CallNode::make(op, {data}, Attrs(attrs), {});
484 }
485
Pad(Expr data,Array<Array<IndexExpr>> pad_width,double pad_value,std::string pad_mode)486 static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, double pad_value,
487 std::string pad_mode) {
488 auto attrs = make_node<PadAttrs>();
489 attrs->pad_value = pad_value;
490 attrs->pad_width = std::move(pad_width);
491 attrs->pad_mode = std::move(pad_mode);
492 static const Op& op = Op::Get("nn.pad");
493 return CallNode::make(op, {data}, Attrs(attrs), {});
494 }
495
Tile(Expr data,Array<Integer> reps)496 static inline Expr Tile(Expr data, Array<Integer> reps) {
497 auto attrs = make_node<TileAttrs>();
498 attrs->reps = reps;
499 static const Op& op = Op::Get("tile");
500 return CallNode::make(op, {data}, Attrs(attrs), {});
501 }
502
// Operator constructors implemented in the corresponding op-registration
// translation units; declared here so passes can build these ops directly.

Expr MakeConcatenate(Expr data, int axis);

Expr MakeRepeat(Expr data, int repeats, int axis);

Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides);

Expr MakeStack(Expr data, int axis);

Expr MakeSplit(Expr data, NodeRef indices_or_sections, int axis);

Expr MakeSqueeze(Expr data, Array<Integer> axis);

Expr MakeExpandDims(Expr data, int axis, int num_newaxis);

Expr MakeLayoutTransform(Expr data, std::string src_layout, std::string dst_layout);

// Annotation helpers: block fusion across an expression / hint a cast.
Expr StopFusion(Expr data);

Expr CastHint(Expr data, DataType dtype);
522
523 } // namespace relay
524 } // namespace tvm
525 #endif // TVM_RELAY_PASS_PATTERN_UTIL_H_
526