//
//  CPUTopKV2.cpp
//  MNN
//
//  Created by MNN on 2018/08/28.
//  Copyright © 2018, Alibaba Group Holding Limited
//
8
#include "backend/cpu/CPUTopKV2.hpp"
#include "backend/cpu/CPUBackend.hpp"
#include "backend/cpu/compute/CommonOptFunction.h"
#include "core/Concurrency.h"
#include "core/Macro.h"
14
15 namespace MNN {
16
// Incremental top-k selector over one row of values. Indices are pushed one
// at a time; once more than k candidates have been seen, the first k slots of
// the buffer are kept as a heap whose front is the current worst survivor,
// and the (k+1)-th slot is scratch space for evictions.
template <typename T>
class TopContainer {
public:
    TopContainer() = delete;
    // `k` survivors are kept; `rowSize` bounds the reservation so that a k
    // larger than the row does not over-allocate.
    TopContainer(int32_t k, int32_t rowSize) : mK(k) {
        mIndices.reserve(std::min(k, rowSize) + 1);
    }

    // Begins a new row; previously collected indices are discarded.
    void startCollecting(const T* values) {
        mRowValues = values;
        mIndices.clear();
    }

    // Offers index `a` as a candidate for the top-k set.
    void push(int32_t a) {
        const auto before = [this](int32_t x, int32_t y) { return isOrderedBefore(x, y); };
        if (mIndices.size() <= mK) {
            // Still filling: append, and once we overflow to k+1 entries,
            // heapify and evict the worst candidate into the scratch slot.
            mIndices.push_back(a);
            if (mIndices.size() == mK + 1) {
                std::make_heap(mIndices.begin(), mIndices.end(), before);
                std::pop_heap(mIndices.begin(), mIndices.end(), before);
            }
        } else if (before(a, mIndices.front())) {
            // Candidate beats the current worst survivor: overwrite the
            // scratch slot, merge it into the heap, then evict the new worst.
            mIndices.back() = a;
            std::push_heap(mIndices.begin(), mIndices.end(), before);
            std::pop_heap(mIndices.begin(), mIndices.end(), before);
        }
    }

    // Returns the kept indices, best first. One-shot per row: the buffer is
    // sorted (and truncated to k) in place.
    const std::vector<int32_t>& sortedResult() {
        const auto before = [this](int32_t x, int32_t y) { return isOrderedBefore(x, y); };
        if (mIndices.size() <= mK) {
            std::sort(mIndices.begin(), mIndices.end(), before);
        } else {
            // The first k entries form a heap; the back holds the already
            // evicted extra candidate and is dropped by the resize.
            std::sort_heap(mIndices.begin(), mIndices.end() - 1, before);
            mIndices.resize(mK);
        }
        return mIndices;
    }

private:
    int32_t mK;
    std::vector<int32_t> mIndices;
    const T* mRowValues = nullptr;

    // Ranking: descending by value; equal values break ties toward the
    // smaller (earlier) index.
    bool isOrderedBefore(int32_t x, int32_t y) const {
        if (mRowValues[y] < mRowValues[x]) {
            return true;
        } else if (mRowValues[x] < mRowValues[y]) {
            return false;
        } else {
            return x < y;
        }
    }
};
70
71 template <typename T>
findTopK(int32_t rowSize,int32_t numRows,const T * data,int32_t k,int32_t * outputIndexes,T * outputValues)72 void findTopK(int32_t rowSize, int32_t numRows, const T* data, int32_t k, int32_t* outputIndexes, T* outputValues) {
73 TopContainer<T> topc(k, rowSize);
74 for (int row = 0; row < numRows; row++) {
75 const T* valuesRow = data + row * rowSize;
76 topc.startCollecting(valuesRow);
77 for (int c = 0; c < rowSize; c++) {
78 topc.push(c);
79 }
80
81 int32_t* indexesRow = outputIndexes + row * k;
82 T* ouputRow = outputValues + row * k;
83
84 const auto& topK = topc.sortedResult();
85 std::copy(topK.begin(), topK.end(), indexesRow);
86 std::transform(topK.begin(), topK.end(), ouputRow, [valuesRow](const int32_t loc) { return valuesRow[loc]; });
87 }
88 }
89
CPUTopKV2(Backend * b)90 CPUTopKV2::CPUTopKV2(Backend* b) : MNN::Execution(b) {
91 // nothing to do
92 }
93
// inputs[0] holds the data tensor, inputs[1] a scalar int32 k;
// outputs[0] receives the top-k values and outputs[1] their int32 indices.
ErrorCode CPUTopKV2::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
    const int k = inputs[1]->host<int32_t>()[0];
    auto inputTensor = inputs[0];
    auto outputData = outputs[0];
    auto outputIndices = outputs[1];

    const int inputDimension = inputTensor->buffer().dimensions;

    // Selection runs along the innermost dimension; all outer dimensions are
    // flattened into `numRows` independent rows.
    const int rowSize = inputTensor->buffer().dim[inputDimension - 1].extent;
    const int rowC4Blocks = rowSize / 4;          // whole 4-element groups fed to the vector kernel
    const int rowRemain = rowSize % 4;            // leftover tail elements handled by scalar code
    const int rowC4ElementSize = rowC4Blocks * 4; // index of the first tail element
    MNN_ASSERT(k <= rowSize);
    const int numRows = inputTensor->elementSize() / rowSize;

    // Fast path for k == 1 (per-row argmax): vectorized max over the 4-wide
    // blocks, then a scalar sweep over the tail, parallelized across rows.
    if (k == 1) {
        if (halide_type_float == inputTensor->getType().code) {
            float* inputData = inputTensor->host<float>();
            float* topkData = outputData->host<float>();
            int32_t* indicesData = outputIndices->host<int32_t>();

            MNN_CONCURRENCY_BEGIN(i, numRows) {
                float* inputRowData = inputData + i * rowSize;
                float* rowTopkData = topkData + i * k;
                int32_t* rowTopkIndexData = indicesData + i * k;
                // NOTE(review): when rowSize < 4 this runs with 0 blocks;
                // assumes the kernel still seeds rowTopkData[0]/index — verify,
                // otherwise the tail compare below reads an unwritten value.
                MNNVectorTop1Float(inputRowData, rowTopkData, rowTopkIndexData, rowC4Blocks);
                // Scalar tail: fold the remaining < 4 elements into the max.
                for (int j = 0; j < rowRemain; j++) {
                    int index = rowC4ElementSize + j;
                    float value = inputRowData[index];
                    if (value > rowTopkData[0]) {
                        rowTopkData[0] = value;
                        rowTopkIndexData[0] = index;
                    }
                }
            }
            MNN_CONCURRENCY_END();
        } else if (halide_type_int == inputTensor->getType().code && 32 == inputTensor->getType().bits) {
            // Same argmax scheme for int32 data.
            int32_t* inputData = inputTensor->host<int32_t>();
            int32_t* topkData = outputData->host<int32_t>();
            int32_t* indicesData = outputIndices->host<int32_t>();
            MNN_CONCURRENCY_BEGIN(i, numRows) {
                int32_t* inputRowData = inputData + i * rowSize;
                int32_t* rowTopkData = topkData + i * k;
                int32_t* rowTopkIndexData = indicesData + i * k;
                MNNVectorTop1Int32(inputRowData, rowTopkData, rowTopkIndexData, rowC4Blocks);
                for (int j = 0; j < rowRemain; j++) {
                    int index = rowC4ElementSize + j;
                    int32_t value = inputRowData[index];
                    if (value > rowTopkData[0]) {
                        rowTopkData[0] = value;
                        rowTopkIndexData[0] = index;
                    }
                }
            }
            MNN_CONCURRENCY_END();
        } else {
            // Unsupported element type for the fast path.
            MNN_PRINT("TopKV2 data type not supported\n");
            MNN_ASSERT(false);
        }

        return NO_ERROR;
    }

    // General k > 1 path: heap-based per-row selection via findTopK.
    if (halide_type_float == inputTensor->getType().code) {
        auto inputData = inputTensor->host<float>();
        auto topkData = outputData->host<float>();
        int* indicesData = outputIndices->host<int32_t>();
        findTopK<float>(rowSize, numRows, inputData, k, indicesData, topkData);
    } else if(halide_type_int == inputTensor->getType().code && 32 == inputTensor->getType().bits) {
        auto inputData = inputTensor->host<int32_t>();
        auto topkData = outputData->host<int32_t>();
        int* indicesData = outputIndices->host<int32_t>();
        findTopK<int32_t>(rowSize, numRows, inputData, k, indicesData, topkData);
    } else {
        // Other dtypes are not implemented yet.
        MNN_PRINT("TODO\n");
        MNN_ASSERT(false);
    }
    return NO_ERROR;
}
173
174 class CPUTopKV2Creator : public CPUBackend::Creator {
175 public:
onCreate(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,const MNN::Op * op,Backend * backend) const176 virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
177 const MNN::Op* op, Backend* backend) const override {
178 return new CPUTopKV2(backend);
179 }
180 };
181
// Registers the creator so the CPU backend can instantiate TopKV2 ops.
REGISTER_CPU_OP_CREATOR(CPUTopKV2Creator, OpType_TopKV2);
183
184 } // namespace MNN
185