1 //
2 //  CPUTopKV2.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2018/08/28.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "backend/cpu/CPUTopKV2.hpp"
10 #include "backend/cpu/CPUBackend.hpp"
11 #include "core/Macro.h"
12 #include "core/Concurrency.h"
13 #include "backend/cpu/compute/CommonOptFunction.h"
14 
15 namespace MNN {
16 
template <typename T>
// Tracks the indices of the k largest values of a row (same strategy as
// TFLite's TopContainer): up to k+1 candidate indices are kept; once more
// than k have been seen they are maintained as a heap whose front is the
// current *minimum*, so each new candidate only has to beat the front.
// Ordering is by descending value, ties broken by the smaller index.
class TopContainer {
public:
    TopContainer() = delete;
    // k: number of results to keep. rowSize: candidates per row — only used
    // to cap the reserved capacity when k > rowSize.
    TopContainer(int32_t k, int32_t rowSize) : mK(k) {
        mContainer.reserve(std::min(k, rowSize) + 1);
    }

    // Start a new row: remember its value array and drop old candidates.
    void startCollecting(const T* values) {
        mValues = values;
        mContainer.clear();
    }

    // Offer candidate index `a` for inclusion in the top-k set.
    void push(int32_t a) {
        auto comparator = [this](int32_t a, int32_t b) { return compareFunc(a, b); };
        // Casts avoid the implicit signed->unsigned conversion of mK that the
        // size_t comparison would otherwise perform silently.
        if (mContainer.size() <= static_cast<size_t>(mK)) {
            mContainer.push_back(a);
            if (mContainer.size() == static_cast<size_t>(mK) + 1) {
                // We now hold k+1 entries: heapify, then park the current
                // minimum at the back so it can be overwritten cheaply.
                std::make_heap(mContainer.begin(), mContainer.end(), comparator);
                std::pop_heap(mContainer.begin(), mContainer.end(), comparator);
            }
        } else if (comparator(a, mContainer.front())) {
            // `a` beats the smallest kept entry: overwrite the parked slot,
            // sift it into the heap, then park the new minimum again.
            mContainer.back() = a;
            std::push_heap(mContainer.begin(), mContainer.end(), comparator);
            std::pop_heap(mContainer.begin(), mContainer.end(), comparator);
        }
    }

    // Returns the kept indices sorted by descending value (ties: smaller
    // index first). After this call the container is no longer a valid heap;
    // call startCollecting() before pushing again.
    const std::vector<int32_t>& sortedResult() {
        auto comparator = [this](int32_t a, int32_t b) { return compareFunc(a, b); };
        if (mContainer.size() <= static_cast<size_t>(mK)) {
            // Fewer than k+1 candidates were ever collected: plain sort.
            std::sort(mContainer.begin(), mContainer.end(), comparator);
        } else {
            // The first k entries form a heap; the back holds the discarded
            // minimum, which must be excluded from the sort and dropped.
            std::sort_heap(mContainer.begin(), mContainer.end() - 1, comparator);
            mContainer.resize(mK);
        }
        return mContainer;
    }

private:
    int32_t mK;                      // requested k
    std::vector<int32_t> mContainer; // candidate indices (heap once size > k)
    const T* mValues = nullptr;      // current row's values, owned by caller

    // Strict weak ordering over indices: a precedes b when values[a] is
    // larger, with equal values ordered by the smaller index.
    bool compareFunc(int32_t a, int32_t b) const {
        if (mValues[b] < mValues[a]) {
            return true;
        } else if (mValues[b] > mValues[a]) {
            return false;
        } else {
            return a < b;
        }
    }
};
70 
71 template <typename T>
findTopK(int32_t rowSize,int32_t numRows,const T * data,int32_t k,int32_t * outputIndexes,T * outputValues)72 void findTopK(int32_t rowSize, int32_t numRows, const T* data, int32_t k, int32_t* outputIndexes, T* outputValues) {
73     TopContainer<T> topc(k, rowSize);
74     for (int row = 0; row < numRows; row++) {
75         const T* valuesRow = data + row * rowSize;
76         topc.startCollecting(valuesRow);
77         for (int c = 0; c < rowSize; c++) {
78             topc.push(c);
79         }
80 
81         int32_t* indexesRow = outputIndexes + row * k;
82         T* ouputRow         = outputValues + row * k;
83 
84         const auto& topK = topc.sortedResult();
85         std::copy(topK.begin(), topK.end(), indexesRow);
86         std::transform(topK.begin(), topK.end(), ouputRow, [valuesRow](const int32_t loc) { return valuesRow[loc]; });
87     }
88 }
89 
CPUTopKV2(Backend * b)90 CPUTopKV2::CPUTopKV2(Backend* b) : MNN::Execution(b) {
91     // nothing to do
92 }
93 
onExecute(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs)94 ErrorCode CPUTopKV2::onExecute(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs) {
95     const int k        = inputs[1]->host<int32_t>()[0];
96     auto inputTensor   = inputs[0];
97     auto outputData    = outputs[0];
98     auto outputIndices = outputs[1];
99 
100     const int inputDimension = inputTensor->buffer().dimensions;
101 
102     const int rowSize = inputTensor->buffer().dim[inputDimension - 1].extent;
103     const int rowC4Blocks = rowSize / 4;
104     const int rowRemain = rowSize % 4;
105     const int rowC4ElementSize = rowC4Blocks * 4;
106     MNN_ASSERT(k <= rowSize);
107     const int numRows = inputTensor->elementSize() / rowSize;
108 
109     if (k == 1) {
110         if (halide_type_float == inputTensor->getType().code) {
111             float* inputData   = inputTensor->host<float>();
112             float* topkData    = outputData->host<float>();
113             int32_t* indicesData = outputIndices->host<int32_t>();
114 
115             MNN_CONCURRENCY_BEGIN(i, numRows) {
116                 float* inputRowData = inputData + i * rowSize;
117                 float* rowTopkData = topkData + i * k;
118                 int32_t* rowTopkIndexData = indicesData + i * k;
119                 MNNVectorTop1Float(inputRowData, rowTopkData, rowTopkIndexData, rowC4Blocks);
120                 for (int j = 0; j < rowRemain; j++) {
121                     int index = rowC4ElementSize + j;
122                     float value = inputRowData[index];
123                     if (value > rowTopkData[0]) {
124                         rowTopkData[0] = value;
125                         rowTopkIndexData[0] = index;
126                     }
127                 }
128             }
129             MNN_CONCURRENCY_END();
130         } else if (halide_type_int == inputTensor->getType().code && 32 == inputTensor->getType().bits) {
131             int32_t* inputData   = inputTensor->host<int32_t>();
132             int32_t* topkData    = outputData->host<int32_t>();
133             int32_t* indicesData = outputIndices->host<int32_t>();
134             MNN_CONCURRENCY_BEGIN(i, numRows) {
135                 int32_t* inputRowData = inputData + i * rowSize;
136                 int32_t* rowTopkData = topkData + i * k;
137                 int32_t* rowTopkIndexData = indicesData + i * k;
138                 MNNVectorTop1Int32(inputRowData, rowTopkData, rowTopkIndexData, rowC4Blocks);
139                 for (int j = 0; j < rowRemain; j++) {
140                     int index = rowC4ElementSize + j;
141                     int32_t value = inputRowData[index];
142                     if (value > rowTopkData[0]) {
143                         rowTopkData[0] = value;
144                         rowTopkIndexData[0] = index;
145                     }
146                 }
147             }
148             MNN_CONCURRENCY_END();
149         } else {
150             MNN_PRINT("TopKV2 data type not supported\n");
151             MNN_ASSERT(false);
152         }
153 
154         return NO_ERROR;
155     }
156 
157     if (halide_type_float == inputTensor->getType().code) {
158         auto inputData   = inputTensor->host<float>();
159         auto topkData    = outputData->host<float>();
160         int* indicesData = outputIndices->host<int32_t>();
161         findTopK<float>(rowSize, numRows, inputData, k, indicesData, topkData);
162     } else if(halide_type_int == inputTensor->getType().code && 32 == inputTensor->getType().bits) {
163         auto inputData   = inputTensor->host<int32_t>();
164         auto topkData    = outputData->host<int32_t>();
165         int* indicesData = outputIndices->host<int32_t>();
166         findTopK<int32_t>(rowSize, numRows, inputData, k, indicesData, topkData);
167     } else {
168         MNN_PRINT("TODO\n");
169         MNN_ASSERT(false);
170     }
171     return NO_ERROR;
172 }
173 
174 class CPUTopKV2Creator : public CPUBackend::Creator {
175 public:
onCreate(const std::vector<Tensor * > & inputs,const std::vector<Tensor * > & outputs,const MNN::Op * op,Backend * backend) const176     virtual Execution* onCreate(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs,
177                                 const MNN::Op* op, Backend* backend) const override {
178         return new CPUTopKV2(backend);
179     }
180 };
181 
// Register the creator so the CPU backend can build TopKV2 ops at load time.
REGISTER_CPU_OP_CREATOR(CPUTopKV2Creator, OpType_TopKV2);
183 
184 } // namespace MNN
185