1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
 * \brief Flatten op constructions
22  * \file nn/flatten.h
23  */
24 #ifndef TVM_TOPI_NN_FLATTEN_H_
25 #define TVM_TOPI_NN_FLATTEN_H_
26 
#include <tvm/te/operation.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <string>
#include <vector>
33 
34 namespace tvm {
35 namespace topi {
36 namespace nn {
37 
38 using namespace tvm::te;
39 
40 /*!
41  * \brief Flattens the input tensor into a 2-D tensor by collapsing higher dimensions.
42  * This requires the input tensor to have constant sized dimensions.
43  *
44  * \param x The input tensor.
45  * \param name The name of the operation
46  * \param tag The tag to mark the operation
47  *
48  * \return A 2-D tensor.
49  */
50 inline Tensor flatten(const Tensor& x, std::string name = "tensor", std::string tag = kInjective) {
51   auto ishape = x->shape;
52   PrimExpr dim = 1;
53   for (size_t i = 1; i < ishape.size(); ++i) {
54     dim = dim * ishape[i];
55   }
56 
57   Array<PrimExpr> oshape({ishape[0], dim});
58 
59   std::vector<PrimExpr> extra_shape;
60   for (size_t i = 1; i < ishape.size(); ++i) {
61     extra_shape.push_back(ishape[i]);
62   }
63   std::reverse(extra_shape.begin(), extra_shape.end());
64 
65   return tvm::te::compute(
66       oshape,
67       [&](Var i, Var j) {
68         PrimExpr idx = j;
69         std::vector<PrimExpr> index;
70         for (auto s : extra_shape) {
71           index.push_back(indexmod(idx, s));
72           idx = indexdiv(idx, s);
73         }
74         index.push_back(i);
75         std::reverse(index.begin(), index.end());
76         return x(index);
77       },
78       name, tag);
79 }
80 
81 }  // namespace nn
82 }  // namespace topi
83 }  // namespace tvm
84 #endif  // TVM_TOPI_NN_FLATTEN_H_
85