1 // ATTENTION: The code in this file is highly EXPERIMENTAL. 2 // Adventurous users should note that the APIs will probably change. 3 4 #pragma once 5 6 // Before: 7 // Z = Conv(X, Y) 8 // B = Z + A 9 // After: 10 // B = Conv(X, Y, A) 11 // 12 // the pass can handle the following cases: 13 // case 1: A is 1D tensor and A.dim[0] == Z.dim[1] 14 // case 2: A is 1-element 1D tensor 15 16 #include <numeric> 17 18 #include "onnx/common/assertions.h" 19 #include "onnx/optimizer/pass.h" 20 21 namespace ONNX_NAMESPACE { 22 namespace optimization { 23 24 struct FuseAddBiasIntoConv final : public PredicateBasedPass { FuseAddBiasIntoConvfinal25 explicit FuseAddBiasIntoConv() 26 : PredicateBasedPass( 27 PassType::Fuse, 28 PassEfficiency::Complete, 29 PassOptimizationType::Compute) {} getPassNamefinal30 std::string getPassName() const override { 31 return "fuse_add_bias_into_conv"; 32 } patternMatchPredicatefinal33 bool patternMatchPredicate(Node* node) override { 34 return node->kind() == kAdd && node->inputs()[0]->node()->kind() == kConv && 35 node->inputs()[0]->node()->inputs().size() == 2; 36 } runTransformfinal37 bool runTransform(Node* n, Graph& graph, NodeDestroyType& destroy_current) 38 override { 39 // due to current broadcasting's constraint, Conv has to be the first 40 // operand 41 destroy_current = NodeDestroyType::DestroyZero; 42 auto orig_conv = n->inputs()[0]; 43 auto orig_bias = n->inputs()[1]; 44 // check if bias is Const or in graph's initializers 45 if (orig_bias->node()->kind() != kConstant && 46 orig_bias->node()->kind() != kParam) { 47 return false; 48 } 49 // check if conv is only used by Add 50 if (orig_conv->uses().size() > 1) { 51 return false; 52 } 53 auto conv_shape = orig_conv->sizes(); 54 auto bias_shape = orig_bias->sizes(); 55 auto weight_shape = orig_conv->node()->inputs()[1]->sizes(); 56 int64_t M = -1; 57 int64_t rank = -1; 58 // try to get feature M and rank from conv_shape 59 if (conv_shape.size() > 1 && conv_shape[1].is_int) { 60 M = conv_shape[1].dim; 61 rank = conv_shape.size(); 62 } 63 // try to get feature M and rank from weight_shape 64 if (weight_shape.size() > 0 && weight_shape[0].is_int) { 65 ONNX_ASSERT(M == -1 || M == weight_shape[0].dim); 66 M = weight_shape[0].dim; 67 ONNX_ASSERT( 68 rank == -1 || rank == static_cast<int64_t>(weight_shape.size())); 69 rank = weight_shape.size(); 70 } 71 int64_t num_el = 1; 72 for (int i = 0; i < static_cast<int64_t>(bias_shape.size()); ++i) { 73 if (bias_shape[i].is_int) { 74 num_el *= bias_shape[i].dim; 75 } else { 76 num_el = -1; 77 return false; 78 } 79 } 80 if (M == -1 || num_el == -1) { 81 // No enough information, bail out 82 return false; 83 } 84 if (rank < static_cast<int64_t>(bias_shape.size())) { 85 return false; 86 } 87 if (num_el == 1) { 88 if (orig_bias->node()->kind() != kParam && 89 orig_conv->node()->isBefore(orig_bias->node())) { 90 orig_bias->node()->moveBefore(orig_conv->node()); 91 } 92 Value* conv_3rd_input = orig_bias; 93 if (bias_shape.size() > 1) { 94 Node* squeeze = graph.create(kSqueeze, 1); 95 std::vector<int64_t> axes(bias_shape.size() - 1); 96 std::iota(axes.begin(), axes.end(), 0); 97 squeeze->is_(kaxes, std::move(axes)); 98 squeeze->addInput(conv_3rd_input); 99 conv_3rd_input = squeeze->output(); 100 squeeze->insertBefore(orig_conv->node()); 101 } 102 if (M > 1) { 103 Node* constant = graph.create(kConstant, 1); 104 Tensor t; 105 t.sizes().push_back(static_cast<int64_t>(1)); 106 t.int64s().push_back(M); 107 t.elem_type() = TensorProto_DataType_INT64; 108 Symbol sym = Symbol("value"); 109 constant->t_(sym, t); 110 std::vector<Dimension> s = {1}; 111 constant->output()->setSizes(s); 112 constant->output()->setElemType(TensorProto_DataType_INT64); 113 constant->insertBefore(orig_conv->node()); 114 Node* tile = graph.create(kTile, 1); 115 tile->addInput(conv_3rd_input); 116 tile->addInput(constant->output()); 117 conv_3rd_input = tile->output(); 118 tile->insertBefore(orig_conv->node()); 119 } 120 orig_conv->node()->addInput(conv_3rd_input); 121 } else if (rank > static_cast<int64_t>(bias_shape.size()) + 1) { 122 return false; 123 } else if ( 124 num_el == M && 125 bias_shape[1 + bias_shape.size() - static_cast<unsigned>(rank)].dim == 126 M) { 127 ONNX_ASSERT(bias_shape.size() > 1); 128 if (orig_bias->node()->kind() != kParam && 129 orig_conv->node()->isBefore(orig_bias->node())) { 130 orig_bias->node()->moveBefore(orig_conv->node()); 131 } 132 Node* squeeze = graph.create(kSqueeze, 1); 133 std::vector<int64_t> axes(bias_shape.size()); 134 std::iota(axes.begin(), axes.end(), static_cast<int64_t>(0)); 135 axes.erase( 136 axes.begin() + (1 + bias_shape.size() - static_cast<unsigned>(rank))); 137 squeeze->is_(kaxes, std::move(axes)); 138 squeeze->addInput(orig_bias); 139 squeeze->insertBefore(orig_conv->node()); 140 orig_conv->node()->addInput(squeeze->output()); 141 } else { 142 return false; 143 } 144 if (orig_conv->sizes().size() == 0 && n->output()->sizes().size() > 0) { 145 orig_conv->setSizes(n->output()->sizes()); 146 } 147 if (n->output()->elemType() != TensorProto_DataType_UNDEFINED) { 148 orig_conv->setElemType(n->output()->elemType()); 149 } 150 n->replaceAllUsesWith(orig_conv->node()); 151 destroy_current = NodeDestroyType::DestroyOne; 152 return true; 153 } 154 }; 155 156 } // namespace optimization 157 } // namespace ONNX_NAMESPACE 158