1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \brief Softmax op constructions 22 * \file nn/flatten.h 23 */ 24 #ifndef TVM_TOPI_NN_FLATTEN_H_ 25 #define TVM_TOPI_NN_FLATTEN_H_ 26 27 #include <tvm/te/operation.h> 28 #include <tvm/topi/detail/constant_utils.h> 29 #include <tvm/topi/tags.h> 30 31 #include <string> 32 #include <vector> 33 34 namespace tvm { 35 namespace topi { 36 namespace nn { 37 38 using namespace tvm::te; 39 40 /*! 41 * \brief Flattens the input tensor into a 2-D tensor by collapsing higher dimensions. 42 * This requires the input tensor to have constant sized dimensions. 43 * 44 * \param x The input tensor. 45 * \param name The name of the operation 46 * \param tag The tag to mark the operation 47 * 48 * \return A 2-D tensor. 49 */ 50 inline Tensor flatten(const Tensor& x, std::string name = "tensor", std::string tag = kInjective) { 51 auto ishape = x->shape; 52 PrimExpr dim = 1; 53 for (size_t i = 1; i < ishape.size(); ++i) { 54 dim = dim * ishape[i]; 55 } 56 57 Array<PrimExpr> oshape({ishape[0], dim}); 58 59 std::vector<PrimExpr> extra_shape; 60 for (size_t i = 1; i < ishape.size(); ++i) { 61 extra_shape.push_back(ishape[i]); 62 } 63 std::reverse(extra_shape.begin(), extra_shape.end()); 64 65 return tvm::te::compute( 66 oshape, 67 [&](Var i, Var j) { 68 PrimExpr idx = j; 69 std::vector<PrimExpr> index; 70 for (auto s : extra_shape) { 71 index.push_back(indexmod(idx, s)); 72 idx = indexdiv(idx, s); 73 } 74 index.push_back(i); 75 std::reverse(index.begin(), index.end()); 76 return x(index); 77 }, 78 name, tag); 79 } 80 81 } // namespace nn 82 } // namespace topi 83 } // namespace tvm 84 #endif // TVM_TOPI_NN_FLATTEN_H_ 85