1 // Tencent is pleased to support the open source community by making ncnn available. 2 // 3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 4 // 5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 6 // in compliance with the License. You may obtain a copy of the License at 7 // 8 // https://opensource.org/licenses/BSD-3-Clause 9 // 10 // Unless required by applicable law or agreed to in writing, software distributed 11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the 13 // specific language governing permissions and limitations under the License. 14 15 #include "pass_ncnn.h" 16 17 namespace pnnx { 18 19 namespace ncnn { 20 21 class Tensor_reshape : public GraphRewriterPass 22 { 23 public: match_pattern_graph() const24 const char* match_pattern_graph() const 25 { 26 return R"PNNXIR(7767517 27 3 2 28 pnnx.Input input 0 1 input 29 Tensor.reshape op_0 1 1 input out shape=%shape 30 pnnx.Output output 1 0 out 31 )PNNXIR"; 32 } 33 type_str() const34 const char* type_str() const 35 { 36 return "Reshape"; 37 } 38 name_str() const39 const char* name_str() const 40 { 41 return "reshape"; 42 } 43 write(Operator * op,const std::map<std::string,Parameter> & captured_params) const44 void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const 45 { 46 const std::vector<int>& shape = captured_params.at("shape").ai; 47 48 const int batch_index = op->inputs[0]->params["__batch_index"].i; 49 50 if (batch_index != 0) 51 { 52 fprintf(stderr, "reshape tensor with batch index %d is not supported yet!\n", batch_index); 53 } 54 55 // drop shape batch index 56 std::vector<int> new_shape; 57 for (int i = 0; i < (int)shape.size(); i++) 58 { 59 if (i == batch_index && shape[i] == 1) 60 continue; 61 62 new_shape.push_back(shape[i]); 63 } 64 65 const int shape_rank = (int)new_shape.size(); 66 67 if (shape_rank > 5) 68 { 69 fprintf(stderr, "reshape to %d-rank tensor is not supported yet!\n", shape_rank); 70 return; 71 } 72 73 if (shape_rank == 1) 74 { 75 op->params["0"] = new_shape[0]; 76 } 77 if (shape_rank == 2) 78 { 79 op->params["0"] = new_shape[1]; 80 op->params["1"] = new_shape[0]; 81 } 82 if (shape_rank == 3) 83 { 84 op->params["0"] = new_shape[2]; 85 op->params["1"] = new_shape[1]; 86 op->params["2"] = new_shape[0]; 87 } 88 if (shape_rank == 4) 89 { 90 op->params["0"] = new_shape[3]; 91 op->params["1"] = new_shape[2]; 92 op->params["11"] = new_shape[1]; 93 op->params["2"] = new_shape[0]; 94 } 95 } 96 }; 97 98 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(Tensor_reshape, 20) 99 100 } // namespace ncnn 101 102 } // namespace pnnx 103