1 // ATTENTION: The code in this file is highly EXPERIMENTAL.
2 // Adventurous users should note that the APIs will probably change.
3 
4 #pragma once
5 
6 // Before:
7 //   Z = Conv(X, Y)
8 //   B = Z + A
9 // After:
10 //   B = Conv(X, Y, A)
11 //
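// the pass can handle the following cases:
//   case 1: A has M = Z.dim[1] elements and either rank(Z) or rank(Z) - 1
//           dimensions, all of size 1 except the one that broadcasts against
//           Z's channel axis (e.g. [1, M, 1, 1] or [M, 1, 1] for a 4-D Z)
//   case 2: A is a 1-element tensor with at most rank(Z) dimensions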
15 
16 #include <numeric>
17 
18 #include "onnx/common/assertions.h"
19 #include "onnx/optimizer/pass.h"
20 
21 namespace ONNX_NAMESPACE {
22 namespace optimization {
23 
24 struct FuseAddBiasIntoConv final : public PredicateBasedPass {
FuseAddBiasIntoConvfinal25   explicit FuseAddBiasIntoConv()
26       : PredicateBasedPass(
27             PassType::Fuse,
28             PassEfficiency::Complete,
29             PassOptimizationType::Compute) {}
getPassNamefinal30   std::string getPassName() const override {
31     return "fuse_add_bias_into_conv";
32   }
patternMatchPredicatefinal33   bool patternMatchPredicate(Node* node) override {
34     return node->kind() == kAdd && node->inputs()[0]->node()->kind() == kConv &&
35         node->inputs()[0]->node()->inputs().size() == 2;
36   }
runTransformfinal37   bool runTransform(Node* n, Graph& graph, NodeDestroyType& destroy_current)
38       override {
39     // due to current broadcasting's constraint, Conv has to be the first
40     // operand
41     destroy_current = NodeDestroyType::DestroyZero;
42     auto orig_conv = n->inputs()[0];
43     auto orig_bias = n->inputs()[1];
44     // check if bias is Const or in graph's initializers
45     if (orig_bias->node()->kind() != kConstant &&
46         orig_bias->node()->kind() != kParam) {
47       return false;
48     }
49     // check if conv is only used by Add
50     if (orig_conv->uses().size() > 1) {
51       return false;
52     }
53     auto conv_shape = orig_conv->sizes();
54     auto bias_shape = orig_bias->sizes();
55     auto weight_shape = orig_conv->node()->inputs()[1]->sizes();
56     int64_t M = -1;
57     int64_t rank = -1;
58     // try to get feature M and rank from conv_shape
59     if (conv_shape.size() > 1 && conv_shape[1].is_int) {
60       M = conv_shape[1].dim;
61       rank = conv_shape.size();
62     }
63     // try to get feature M and rank from weight_shape
64     if (weight_shape.size() > 0 && weight_shape[0].is_int) {
65       ONNX_ASSERT(M == -1 || M == weight_shape[0].dim);
66       M = weight_shape[0].dim;
67       ONNX_ASSERT(
68           rank == -1 || rank == static_cast<int64_t>(weight_shape.size()));
69       rank = weight_shape.size();
70     }
71     int64_t num_el = 1;
72     for (int i = 0; i < static_cast<int64_t>(bias_shape.size()); ++i) {
73       if (bias_shape[i].is_int) {
74         num_el *= bias_shape[i].dim;
75       } else {
76         num_el = -1;
77         return false;
78       }
79     }
80     if (M == -1 || num_el == -1) {
81       // No enough information, bail out
82       return false;
83     }
84     if (rank < static_cast<int64_t>(bias_shape.size())) {
85       return false;
86     }
87     if (num_el == 1) {
88       if (orig_bias->node()->kind() != kParam &&
89           orig_conv->node()->isBefore(orig_bias->node())) {
90         orig_bias->node()->moveBefore(orig_conv->node());
91       }
92       Value* conv_3rd_input = orig_bias;
93       if (bias_shape.size() > 1) {
94         Node* squeeze = graph.create(kSqueeze, 1);
95         std::vector<int64_t> axes(bias_shape.size() - 1);
96         std::iota(axes.begin(), axes.end(), 0);
97         squeeze->is_(kaxes, std::move(axes));
98         squeeze->addInput(conv_3rd_input);
99         conv_3rd_input = squeeze->output();
100         squeeze->insertBefore(orig_conv->node());
101       }
102       if (M > 1) {
103         Node* constant = graph.create(kConstant, 1);
104         Tensor t;
105         t.sizes().push_back(static_cast<int64_t>(1));
106         t.int64s().push_back(M);
107         t.elem_type() = TensorProto_DataType_INT64;
108         Symbol sym = Symbol("value");
109         constant->t_(sym, t);
110         std::vector<Dimension> s = {1};
111         constant->output()->setSizes(s);
112         constant->output()->setElemType(TensorProto_DataType_INT64);
113         constant->insertBefore(orig_conv->node());
114         Node* tile = graph.create(kTile, 1);
115         tile->addInput(conv_3rd_input);
116         tile->addInput(constant->output());
117         conv_3rd_input = tile->output();
118         tile->insertBefore(orig_conv->node());
119       }
120       orig_conv->node()->addInput(conv_3rd_input);
121     } else if (rank > static_cast<int64_t>(bias_shape.size()) + 1) {
122       return false;
123     } else if (
124         num_el == M &&
125         bias_shape[1 + bias_shape.size() - static_cast<unsigned>(rank)].dim ==
126             M) {
127       ONNX_ASSERT(bias_shape.size() > 1);
128       if (orig_bias->node()->kind() != kParam &&
129           orig_conv->node()->isBefore(orig_bias->node())) {
130         orig_bias->node()->moveBefore(orig_conv->node());
131       }
132       Node* squeeze = graph.create(kSqueeze, 1);
133       std::vector<int64_t> axes(bias_shape.size());
134       std::iota(axes.begin(), axes.end(), static_cast<int64_t>(0));
135       axes.erase(
136           axes.begin() + (1 + bias_shape.size() - static_cast<unsigned>(rank)));
137       squeeze->is_(kaxes, std::move(axes));
138       squeeze->addInput(orig_bias);
139       squeeze->insertBefore(orig_conv->node());
140       orig_conv->node()->addInput(squeeze->output());
141     } else {
142       return false;
143     }
144     if (orig_conv->sizes().size() == 0 && n->output()->sizes().size() > 0) {
145       orig_conv->setSizes(n->output()->sizes());
146     }
147     if (n->output()->elemType() != TensorProto_DataType_UNDEFINED) {
148       orig_conv->setElemType(n->output()->elemType());
149     }
150     n->replaceAllUsesWith(orig_conv->node());
151     destroy_current = NodeDestroyType::DestroyOne;
152     return true;
153   }
154 };
155 
156 } // namespace optimization
157 } // namespace ONNX_NAMESPACE
158