1 //
2 //  CPUBinary.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2018/08/02.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "CPUBinary.hpp"
10 #include "CPUBackend.hpp"
11 #include "compute/CommonOptFunction.h"
12 #include "compute/ConvOpt.h"
13 #include "core/Macro.h"
14 #include "core/Concurrency.h"
15 #include "core/OpCommonUtils.hpp"
16 #include "BinaryUtils.hpp"
17 #include "math/Vec.hpp"
18 using Vec4 = MNN::Math::Vec<float, 4>;
19 
20 namespace MNN {
21 
onResize(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)22 ErrorCode CPUBinary::onResize(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
23     const int input0DataCount = inputs[0]->elementSize();
24     const int input1DataCount = inputs[1]->elementSize();
25     if (input1DataCount == input0DataCount) {
26         mNeedBroadcastIndex = -1;
27         mTotalSize = input1DataCount;
28     } else if (input0DataCount == 1) {
29         mNeedBroadcastIndex = 0;
30         mTotalSize = input1DataCount;
31     } else {
32         mNeedBroadcastIndex = 1;
33         mTotalSize = input0DataCount;
34     }
35     MNN_ASSERT(mTotalSize == outputs[0]->elementSize());
36     return NO_ERROR;
37 }
38 
onExecute(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)39 ErrorCode CPUBinary::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
40     const int input0DataCount = ((CPUBackend*)backend())->getTensorSize(inputs[0]);
41     const int input1DataCount = ((CPUBackend*)backend())->getTensorSize(inputs[1]);
42     if (input1DataCount == input0DataCount) {
43         mNeedBroadcastIndex = -1;
44         mTotalSize = input1DataCount;
45     } else if (input0DataCount == 1) {
46         mNeedBroadcastIndex = 0;
47         mTotalSize = input1DataCount;
48     } else {
49         mNeedBroadcastIndex = 1;
50         mTotalSize = input0DataCount;
51     }
52     auto input  = inputs[0];
53     auto input1 = inputs[1];
54     auto output = outputs[0];
55 
56     auto schedule = ((CPUBackend*)backend())->multiThreadDivide(mTotalSize);
57     auto input0Ptr = input->host<uint8_t>();
58     auto input1Ptr = input1->host<uint8_t>();
59     auto outputPtr = output->host<uint8_t>();
60     int inpBytes = input->getType().bytes();
61     int outBytes = output->getType().bytes();
62     if (halide_type_float == input->getType().code) {
63         inpBytes = static_cast<CPUBackend*>(backend())->functions()->bytes;
64     }
65     if (halide_type_float == output->getType().code) {
66         outBytes = static_cast<CPUBackend*>(backend())->functions()->bytes;
67     }
68     auto precision = static_cast<CPUBackend*>(backend())->precisionMode();
69     MNN_CONCURRENCY_BEGIN(tId, schedule.second) {
70         int start = schedule.first * (int)tId;
71         int realSize = schedule.first;
72         if (tId == schedule.second -1 ) {
73             realSize = mTotalSize - start;
74         }
75         if (realSize > 0) {
76             auto inp0 = input0Ptr + start * inpBytes;
77             auto inp1 = input1Ptr + start * inpBytes;
78             if (mNeedBroadcastIndex == 0) {
79                 inp0 = input0Ptr;
80             } else if (mNeedBroadcastIndex == 1) {
81                 inp1 = input1Ptr;
82             }
83             auto out = outputPtr + start * outBytes;
84             mProc(out, inp0, inp1, realSize, mNeedBroadcastIndex);
85         }
86     }
87     MNN_CONCURRENCY_END();
88     return NO_ERROR;
89 }
90 
selectForFloat(int type)91 MNNBinaryExecute CPUBinary::selectForFloat(int type) {
92     auto vecFunction = selectVector<Vec4, 4>(type);
93     if (nullptr != vecFunction) {
94         return vecFunction;
95     }
96     switch (type) {
97         case BinaryOpOperation_REALDIV:
98             return execute<float, float, BinaryRealDiv<float, float, float>>;
99         case BinaryOpOperation_FLOORDIV:
100             return execute<float, float, BinaryFloorDiv<float, float, float>>;
101         case BinaryOpOperation_FLOORMOD:
102             return execute<float, float, BinaryFloorMod<float, float, float>>;
103         case BinaryOpOperation_POW:
104             return execute<float, float, BinaryPow<float, float, float>>;
105         case BinaryOpOperation_ATAN2:
106             return execute<float, float, BinaryAtan2<float, float, float>>;
107         case BinaryOpOperation_MOD:
108             return execute<float, float, BinaryMod<float, float, float>>;
109         case BinaryOpOperation_GREATER:
110             return execute<float, int32_t, BinaryGreater<float, float, int32_t>>;
111         case BinaryOpOperation_LESS:
112             return execute<float, int32_t, BinaryLess<float, float, int32_t>>;
113         case BinaryOpOperation_LESS_EQUAL:
114             return execute<float, int32_t, BinaryLessEqual<float, float, int32_t>>;
115         case BinaryOpOperation_GREATER_EQUAL:
116             return execute<float, int32_t, BinaryGreaterEqual<float, float, int32_t>>;
117         case BinaryOpOperation_EQUAL:
118             return execute<float, int32_t, BinaryEqual<float, float, int32_t>>;
119         case BinaryOpOperation_NOTEQUAL:
120             return execute<float, int32_t, BinaryNotEqual<float, float, int32_t>>;
121         default:
122             MNN_ASSERT(false);
123             break;
124     }
125     return nullptr;
126 }
127 
selectForInt(int type)128 static MNNBinaryExecute selectForInt(int type) {
129     switch (type) {
130         case BinaryOpOperation_MUL:
131             return execute<int32_t, int32_t, BinaryMul<int32_t, int32_t, int32_t>>;
132         case BinaryOpOperation_ADD:
133             return execute<int32_t, int32_t, BinaryAdd<int32_t, int32_t, int32_t>>;
134         case BinaryOpOperation_SUB:
135             return execute<int32_t, int32_t, BinarySub<int32_t, int32_t, int32_t>>;
136         case BinaryOpOperation_REALDIV:
137             return execute<int32_t, int32_t, BinaryRealDiv<int32_t, int32_t, int32_t>>;
138         case BinaryOpOperation_MINIMUM:
139             return execute<int32_t, int32_t, BinaryMin<int32_t, int32_t, int32_t>>;
140             break;
141         case BinaryOpOperation_MAXIMUM:
142             return execute<int32_t, int32_t, BinaryMax<int32_t, int32_t, int32_t>>;
143             break;
144         case BinaryOpOperation_GREATER:
145             return execute<int32_t, int32_t, BinaryGreater<int32_t, int32_t, int32_t>>;
146             break;
147         case BinaryOpOperation_LESS:
148             return execute<int32_t, int32_t, BinaryLess<int32_t, int32_t, int32_t>>;
149             break;
150         case BinaryOpOperation_LESS_EQUAL:
151             return execute<int32_t, int32_t, BinaryLessEqual<int32_t, int32_t, int32_t>>;
152             break;
153         case BinaryOpOperation_GREATER_EQUAL:
154             return execute<int32_t, int32_t, BinaryGreaterEqual<int32_t, int32_t, int32_t>>;
155             break;
156         case BinaryOpOperation_EQUAL:
157             return execute<int32_t, int32_t, BinaryEqual<int32_t, int32_t, int32_t>>;
158             break;
159         case BinaryOpOperation_FLOORDIV:
160             return execute<int32_t, int32_t, BinaryFloorDiv<int32_t, int32_t, int32_t>>;
161             break;
162         case BinaryOpOperation_FLOORMOD:
163             return execute<int32_t, int32_t, BinaryFloorMod<int32_t, int32_t, int32_t>>;
164             break;
165         case BinaryOpOperation_SquaredDifference:
166             return execute<int32_t, int32_t, BinarySquaredDifference<int32_t, int32_t, int32_t>>;
167             break;
168         case BinaryOpOperation_LOGICALOR:
169             return execute<int32_t, int32_t, BinaryLogicalOr<int32_t, int32_t, int32_t>>;
170             break;
171         case BinaryOpOperation_NOTEQUAL:
172             return execute<int32_t, int32_t, BinaryNotEqual<int32_t, int32_t, int32_t>>;
173             break;
174         case BinaryOpOperation_MOD:
175             return execute<int32_t, int32_t, BinaryMod<int32_t, int32_t, int32_t>>;
176             break;
177         default:
178             MNN_ASSERT(false);
179             break;
180     }
181     return nullptr;
182 }
183 
184 class CPUBinaryCreator : public CPUBackend::Creator {
185 public:
onCreate(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,const MNN::Op * op,Backend * backend) const186     virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
187                                 const MNN::Op* op, Backend* backend) const override {
188         int32_t type = op->main_as_BinaryOp()->opType();
189         auto dataType = inputs[0]->getType();
190         auto core = static_cast<CPUBackend*>(backend)->functions();
191         if (dataType.bits == 32) {
192             if (dataType.code == halide_type_int) {
193                 auto func = selectForInt(type);
194                 if (nullptr == func) {
195                     return nullptr;
196                 }
197                 return new CPUBinary(backend, func);
198             } else if (dataType.code == halide_type_float) {
199                 auto func = core->MNNSelectBinaryFunctionForFloat(type);
200                 if (nullptr == func) {
201                     return nullptr;
202                 }
203                 return new CPUBinary(backend, func);
204             }
205         }
206         MNN_ERROR("CpuBinary: unsupported data type (bits: %d, code: %d)\n",
207                   dataType.bits, dataType.code);
208         return nullptr;
209     }
210 };
211 
// Register the creator so the CPU backend can instantiate BinaryOp executions.
REGISTER_CPU_OP_CREATOR(CPUBinaryCreator, OpType_BinaryOp);
213 
214 } // namespace MNN
215