// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level1.h"

#include "../utils.h"

namespace pnnx {

class RNN : public FuseModulePass
{
public:
    const char* match_type_str() const
    {
        return "__torch__.torch.nn.modules.rnn.RNN";
    }

    const char* type_str() const
    {
        return "nn.RNN";
    }
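
    // collect the nn.RNN hyperparameters and per-layer weights from the
    // TorchScript module and its traced graph, and record them on the
    // pnnx Operator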
    void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const
    {
        //         mod.dump(true, true, true);

        //         graph->dump();

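        // a traced nn.RNN lowers to aten::rnn_tanh or aten::rnn_relu
        // depending on the configured nonlinearity; whichever node is
        // present carries the module hyperparameters as named inputs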
        const torch::jit::Node* rnn = find_node_by_kind(graph, "aten::rnn_tanh");
        const torch::jit::Node* rnn_relu = find_node_by_kind(graph, "aten::rnn_relu");

        if (rnn_relu)
        {
            rnn = rnn_relu;
        }

        const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct");
        if (return_tuple && return_tuple->inputs().size() == 2 && rnn->outputs().size() == 2
                && return_tuple->inputs()[0] == rnn->outputs()[1] && return_tuple->inputs()[1] == rnn->outputs()[0])
        {
            // mark the swapped output tuple
            // the correct order is restored later in pass_level3/fuse_rnn_unpack
            fprintf(stderr, "swapped detected!\n");
            op->params["pnnx_rnn_output_swapped"] = 1;
        }

        //         for (auto aa : rnn->schema().arguments())
        //         {
        //             fprintf(stderr, "arg %s\n", aa.name().c_str());
        //         }

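        // weight_ih_l0 has shape (hidden_size, input_size) for nn.RNN,
        // so both sizes can be read back from the tensor dimensions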
        const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor();

        op->params["input_size"] = weight_ih_l0.size(1);
        op->params["hidden_size"] = weight_ih_l0.size(0);
        op->params["num_layers"] = rnn->namedInput("num_layers");
        op->params["nonlinearity"] = rnn_relu ? "relu" : "tanh";
        op->params["bias"] = rnn->namedInput("has_biases");
        op->params["batch_first"] = rnn->namedInput("batch_first");
        op->params["bidirectional"] = rnn->namedInput("bidirectional");

        const int num_layers = op->params["num_layers"].i;
        const bool bias = op->params["bias"].b;
        const bool bidirectional = op->params["bidirectional"].b;

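        // capture per-layer weights using the torch.nn.RNN parameter naming
        // convention: weight_ih_l{k} / weight_hh_l{k}, plus bias_ih_l{k} /
        // bias_hh_l{k} when bias is enabled, with a _reverse suffix for the
        // backward direction of a bidirectional RNN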
        for (int k = 0; k < num_layers; k++)
        {
            std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k);
            std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k);

            op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor();
            op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor();

            if (bias)
            {
                std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k);
                std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k);

                op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor();
                op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor();
            }

            if (bidirectional)
            {
                std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse";
                std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse";

                op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor();
                op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor();

                if (bias)
                {
                    std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse";
                    std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse";

                    op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor();
                    op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor();
                }
            }
        }
    }
};

REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(RNN)

} // namespace pnnx