1 // Tencent is pleased to support the open source community by making ncnn available. 2 // 3 // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 4 // 5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 6 // in compliance with the License. You may obtain a copy of the License at 7 // 8 // https://opensource.org/licenses/BSD-3-Clause 9 // 10 // Unless required by applicable law or agreed to in writing, software distributed 11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the 13 // specific language governing permissions and limitations under the License. 14 15 #include "pass_level1.h" 16 17 #include "../utils.h" 18 19 namespace pnnx { 20 21 class RNN : public FuseModulePass 22 { 23 public: match_type_str() const24 const char* match_type_str() const 25 { 26 return "__torch__.torch.nn.modules.rnn.RNN"; 27 } 28 type_str() const29 const char* type_str() const 30 { 31 return "nn.RNN"; 32 } 33 write(Operator * op,const std::shared_ptr<torch::jit::Graph> & graph,const torch::jit::Module & mod) const34 void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const 35 { 36 // mod.dump(true, true, true); 37 38 // graph->dump(); 39 40 const torch::jit::Node* rnn = find_node_by_kind(graph, "aten::rnn_tanh"); 41 const torch::jit::Node* rnn_relu = find_node_by_kind(graph, "aten::rnn_relu"); 42 43 if (rnn_relu) 44 { 45 rnn = rnn_relu; 46 } 47 48 const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct"); 49 if (return_tuple && return_tuple->inputs().size() == 2 && rnn->outputs().size() == 2 50 && return_tuple->inputs()[0] == rnn->outputs()[1] && return_tuple->inputs()[1] == rnn->outputs()[0]) 51 { 52 // mark the swapped output tuple 53 // we would restore the fine order in pass_level3/fuse_rnn_unpack 54 fprintf(stderr, "swapped detected !\n"); 55 op->params["pnnx_rnn_output_swapped"] = 1; 56 } 57 58 // for (auto aa : rnn->schema().arguments()) 59 // { 60 // fprintf(stderr, "arg %s\n", aa.name().c_str()); 61 // } 62 63 const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor(); 64 65 op->params["input_size"] = weight_ih_l0.size(1); 66 op->params["hidden_size"] = weight_ih_l0.size(0); 67 op->params["num_layers"] = rnn->namedInput("num_layers"); 68 op->params["nonlinearity"] = rnn_relu ? "relu" : "tanh"; 69 op->params["bias"] = rnn->namedInput("has_biases"); 70 op->params["batch_first"] = rnn->namedInput("batch_first"); 71 op->params["bidirectional"] = rnn->namedInput("bidirectional"); 72 73 const int num_layers = op->params["num_layers"].i; 74 const bool bias = op->params["bias"].b; 75 const bool bidirectional = op->params["bidirectional"].b; 76 77 for (int k = 0; k < num_layers; k++) 78 { 79 std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k); 80 std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k); 81 82 op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor(); 83 op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor(); 84 85 if (bias) 86 { 87 std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k); 88 std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k); 89 90 op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor(); 91 op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor(); 92 } 93 94 if (bidirectional) 95 { 96 std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse"; 97 std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse"; 98 99 op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor(); 100 op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor(); 101 102 if (bias) 103 { 104 std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse"; 105 std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse"; 106 107 op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor(); 108 op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor(); 109 } 110 } 111 } 112 } 113 }; 114 115 REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(RNN) 116 117 } // namespace pnnx 118