1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "pass_ncnn.h"
16 
17 namespace pnnx {
18 
19 namespace ncnn {
20 
21 class Tensor_reshape : public GraphRewriterPass
22 {
23 public:
match_pattern_graph() const24     const char* match_pattern_graph() const
25     {
26         return R"PNNXIR(7767517
27 3 2
28 pnnx.Input              input       0 1 input
29 Tensor.reshape          op_0        1 1 input out shape=%shape
30 pnnx.Output             output      1 0 out
31 )PNNXIR";
32     }
33 
type_str() const34     const char* type_str() const
35     {
36         return "Reshape";
37     }
38 
name_str() const39     const char* name_str() const
40     {
41         return "reshape";
42     }
43 
write(Operator * op,const std::map<std::string,Parameter> & captured_params) const44     void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
45     {
46         const std::vector<int>& shape = captured_params.at("shape").ai;
47 
48         const int batch_index = op->inputs[0]->params["__batch_index"].i;
49 
50         if (batch_index != 0)
51         {
52             fprintf(stderr, "reshape tensor with batch index %d is not supported yet!\n", batch_index);
53         }
54 
55         // drop shape batch index
56         std::vector<int> new_shape;
57         for (int i = 0; i < (int)shape.size(); i++)
58         {
59             if (i == batch_index && shape[i] == 1)
60                 continue;
61 
62             new_shape.push_back(shape[i]);
63         }
64 
65         const int shape_rank = (int)new_shape.size();
66 
67         if (shape_rank > 5)
68         {
69             fprintf(stderr, "reshape to %d-rank tensor is not supported yet!\n", shape_rank);
70             return;
71         }
72 
73         if (shape_rank == 1)
74         {
75             op->params["0"] = new_shape[0];
76         }
77         if (shape_rank == 2)
78         {
79             op->params["0"] = new_shape[1];
80             op->params["1"] = new_shape[0];
81         }
82         if (shape_rank == 3)
83         {
84             op->params["0"] = new_shape[2];
85             op->params["1"] = new_shape[1];
86             op->params["2"] = new_shape[0];
87         }
88         if (shape_rank == 4)
89         {
90             op->params["0"] = new_shape[3];
91             op->params["1"] = new_shape[2];
92             op->params["11"] = new_shape[1];
93             op->params["2"] = new_shape[0];
94         }
95     }
96 };
97 
98 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(Tensor_reshape, 20)
99 
100 } // namespace ncnn
101 
102 } // namespace pnnx
103