1 //
2 // CPUBinary.cpp
3 // MNN
4 //
5 // Created by MNN on 2018/08/02.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "CPUBinary.hpp"
10 #include "CPUBackend.hpp"
11 #include "compute/CommonOptFunction.h"
12 #include "compute/ConvOpt.h"
13 #include "core/Macro.h"
14 #include "core/Concurrency.h"
15 #include "core/OpCommonUtils.hpp"
16 #include "BinaryUtils.hpp"
17 #include "math/Vec.hpp"
18 using Vec4 = MNN::Math::Vec<float, 4>;
19
20 namespace MNN {
21
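// Compare the element counts of the two inputs to decide whether broadcasting is
// needed: -1 means both inputs have the same size (no broadcast), 0 means input 0
// is broadcast against input 1, and 1 means input 1 is broadcast against input 0.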
ErrorCode CPUBinary::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    const int input0DataCount = inputs[0]->elementSize();
    const int input1DataCount = inputs[1]->elementSize();
    if (input1DataCount == input0DataCount) {
        mNeedBroadcastIndex = -1;
        mTotalSize = input1DataCount;
    } else if (input0DataCount == 1) {
        mNeedBroadcastIndex = 0;
        mTotalSize = input1DataCount;
    } else {
        mNeedBroadcastIndex = 1;
        mTotalSize = input0DataCount;
    }
    MNN_ASSERT(mTotalSize == outputs[0]->elementSize());
    return NO_ERROR;
}

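// Re-derive the broadcast decision from the runtime tensor sizes, then split the
// elementwise work evenly across threads; each thread runs mProc on its own slice.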
ErrorCode CPUBinary::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    const int input0DataCount = ((CPUBackend*)backend())->getTensorSize(inputs[0]);
    const int input1DataCount = ((CPUBackend*)backend())->getTensorSize(inputs[1]);
    if (input1DataCount == input0DataCount) {
        mNeedBroadcastIndex = -1;
        mTotalSize = input1DataCount;
    } else if (input0DataCount == 1) {
        mNeedBroadcastIndex = 0;
        mTotalSize = input1DataCount;
    } else {
        mNeedBroadcastIndex = 1;
        mTotalSize = input0DataCount;
    }
    auto input  = inputs[0];
    auto input1 = inputs[1];
    auto output = outputs[0];

    auto schedule  = ((CPUBackend*)backend())->multiThreadDivide(mTotalSize);
    auto input0Ptr = input->host<uint8_t>();
    auto input1Ptr = input1->host<uint8_t>();
    auto outputPtr = output->host<uint8_t>();
    int inpBytes = input->getType().bytes();
    int outBytes = output->getType().bytes();
    // Float tensors may be stored in the backend's working precision, so take the
    // element size from the backend's function table instead of the tensor type.
    if (halide_type_float == input->getType().code) {
        inpBytes = static_cast<CPUBackend*>(backend())->functions()->bytes;
    }
    if (halide_type_float == output->getType().code) {
        outBytes = static_cast<CPUBackend*>(backend())->functions()->bytes;
    }
    auto precision = static_cast<CPUBackend*>(backend())->precisionMode();
    MNN_CONCURRENCY_BEGIN(tId, schedule.second) {
        int start    = schedule.first * (int)tId;
        int realSize = schedule.first;
        // The last thread picks up the remainder of the division.
        if (tId == schedule.second - 1) {
            realSize = mTotalSize - start;
        }
        if (realSize > 0) {
            auto inp0 = input0Ptr + start * inpBytes;
            auto inp1 = input1Ptr + start * inpBytes;
            // The broadcast input stays at its base pointer; mProc reads it as a scalar.
            if (mNeedBroadcastIndex == 0) {
                inp0 = input0Ptr;
            } else if (mNeedBroadcastIndex == 1) {
                inp1 = input1Ptr;
            }
            auto out = outputPtr + start * outBytes;
            mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex);
        }
    }
    MNN_CONCURRENCY_END();
    return NO_ERROR;
}

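// Pick the float kernel for a binary op type: prefer the vectorized (Vec4)
// implementation when one exists, otherwise fall back to the scalar templates.
// Comparison ops write int32 results.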
MNNBinaryExecute CPUBinary::selectForFloat(int type) {
    auto vecFunction = selectVector<Vec4, 4>(type);
    if (nullptr != vecFunction) {
        return vecFunction;
    }
    switch (type) {
        case BinaryOpOperation_REALDIV:
            return execute<float, float, BinaryRealDiv<float, float, float>>;
        case BinaryOpOperation_FLOORDIV:
            return execute<float, float, BinaryFloorDiv<float, float, float>>;
        case BinaryOpOperation_FLOORMOD:
            return execute<float, float, BinaryFloorMod<float, float, float>>;
        case BinaryOpOperation_POW:
            return execute<float, float, BinaryPow<float, float, float>>;
        case BinaryOpOperation_ATAN2:
            return execute<float, float, BinaryAtan2<float, float, float>>;
        case BinaryOpOperation_MOD:
            return execute<float, float, BinaryMod<float, float, float>>;
        case BinaryOpOperation_GREATER:
            return execute<float, int32_t, BinaryGreater<float, float, int32_t>>;
        case BinaryOpOperation_LESS:
            return execute<float, int32_t, BinaryLess<float, float, int32_t>>;
        case BinaryOpOperation_LESS_EQUAL:
            return execute<float, int32_t, BinaryLessEqual<float, float, int32_t>>;
        case BinaryOpOperation_GREATER_EQUAL:
            return execute<float, int32_t, BinaryGreaterEqual<float, float, int32_t>>;
        case BinaryOpOperation_EQUAL:
            return execute<float, int32_t, BinaryEqual<float, float, int32_t>>;
        case BinaryOpOperation_NOTEQUAL:
            return execute<float, int32_t, BinaryNotEqual<float, float, int32_t>>;
        default:
            MNN_ASSERT(false);
            break;
    }
    return nullptr;
}

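// Pick the scalar int32 kernel for a binary op type; unsupported ops return nullptr.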
static MNNBinaryExecute selectForInt(int type) {
    switch (type) {
        case BinaryOpOperation_MUL:
            return execute<int32_t, int32_t, BinaryMul<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_ADD:
            return execute<int32_t, int32_t, BinaryAdd<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_SUB:
            return execute<int32_t, int32_t, BinarySub<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_REALDIV:
            return execute<int32_t, int32_t, BinaryRealDiv<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_MINIMUM:
            return execute<int32_t, int32_t, BinaryMin<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_MAXIMUM:
            return execute<int32_t, int32_t, BinaryMax<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_GREATER:
            return execute<int32_t, int32_t, BinaryGreater<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_LESS:
            return execute<int32_t, int32_t, BinaryLess<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_LESS_EQUAL:
            return execute<int32_t, int32_t, BinaryLessEqual<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_GREATER_EQUAL:
            return execute<int32_t, int32_t, BinaryGreaterEqual<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_EQUAL:
            return execute<int32_t, int32_t, BinaryEqual<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_FLOORDIV:
            return execute<int32_t, int32_t, BinaryFloorDiv<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_FLOORMOD:
            return execute<int32_t, int32_t, BinaryFloorMod<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_SquaredDifference:
            return execute<int32_t, int32_t, BinarySquaredDifference<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_LOGICALOR:
            return execute<int32_t, int32_t, BinaryLogicalOr<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_NOTEQUAL:
            return execute<int32_t, int32_t, BinaryNotEqual<int32_t, int32_t, int32_t>>;
        case BinaryOpOperation_MOD:
            return execute<int32_t, int32_t, BinaryMod<int32_t, int32_t, int32_t>>;
        default:
            MNN_ASSERT(false);
            break;
    }
    return nullptr;
}

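// Creator registered for OpType_BinaryOp: dispatches on the input element type and
// builds a CPUBinary with the matching kernel (int32 via selectForInt, float via the
// backend function table's MNNSelectBinaryFunctionForFloat).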
class CPUBinaryCreator : public CPUBackend::Creator {
public:
    virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
                                const MNN::Op* op, Backend* backend) const override {
        int32_t type = op->main_as_BinaryOp()->opType();
        auto dataType = inputs[0]->getType();
        auto core = static_cast<CPUBackend*>(backend)->functions();
        if (dataType.bits == 32) {
            if (dataType.code == halide_type_int) {
                auto func = selectForInt(type);
                if (nullptr == func) {
                    return nullptr;
                }
                return new CPUBinary(backend, func);
            } else if (dataType.code == halide_type_float) {
                auto func = core->MNNSelectBinaryFunctionForFloat(type);
                if (nullptr == func) {
                    return nullptr;
                }
                return new CPUBinary(backend, func);
            }
        }
        MNN_ERROR("CpuBinary: unsupported data type (bits: %d, code: %d)\n",
                  dataType.bits, dataType.code);
        return nullptr;
    }
};

REGISTER_CPU_OP_CREATOR(CPUBinaryCreator, OpType_BinaryOp);

} // namespace MNN