1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file src/relay/op/nn/pooling.h
 * \brief Shared helpers for constructing pooling operator calls.
 */
24 #ifndef TVM_RELAY_OP_NN_POOLING_H_
25 #define TVM_RELAY_OP_NN_POOLING_H_
26
27 #include <tvm/relay/attrs/nn.h>
28 #include <tvm/relay/op.h>
29
30 #include <utility>
31
32 namespace tvm {
33 namespace relay {
34
35 template <typename T>
MakeMaxPool(Expr data,Array<IndexExpr> pool_size,Array<IndexExpr> strides,Array<IndexExpr> padding,String layout,bool ceil_mode,String op_name)36 inline Expr MakeMaxPool(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
37 Array<IndexExpr> padding, String layout, bool ceil_mode, String op_name) {
38 auto attrs = make_object<T>();
39 attrs->pool_size = std::move(pool_size);
40 attrs->strides = std::move(strides);
41 attrs->padding = std::move(padding);
42 attrs->layout = std::move(layout);
43 attrs->ceil_mode = ceil_mode;
44 static const Op& op = Op::Get(op_name);
45 return Call(op, {data}, Attrs(attrs), {});
46 }
47
48 template <typename T>
MakeAvgPool(Expr data,Array<IndexExpr> pool_size,Array<IndexExpr> strides,Array<IndexExpr> padding,String layout,bool ceil_mode,bool count_include_pad,String op_name)49 inline Expr MakeAvgPool(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
50 Array<IndexExpr> padding, String layout, bool ceil_mode,
51 bool count_include_pad, String op_name) {
52 auto attrs = make_object<T>();
53 attrs->pool_size = std::move(pool_size);
54 attrs->strides = std::move(strides);
55 attrs->padding = std::move(padding);
56 attrs->layout = std::move(layout);
57 attrs->ceil_mode = ceil_mode;
58 attrs->count_include_pad = count_include_pad;
59 static const Op& op = Op::Get(op_name);
60 return Call(op, {data}, Attrs(attrs), {});
61 }
62
63 } // namespace relay
64 } // namespace tvm
65 #endif // TVM_RELAY_OP_NN_POOLING_H_
66