1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GEMMLOWP_META_MULTI_THREAD_GEMM_H_
16 #define GEMMLOWP_META_MULTI_THREAD_GEMM_H_
17 
18 #include "../internal/common.h"
19 
20 #ifdef GEMMLOWP_NEON
21 
22 #include "legacy_multi_thread_common.h"
23 #include "legacy_multi_thread_gemv.h"
24 #include "legacy_operations_common.h"
25 #include "legacy_single_thread_gemm.h"
26 
27 namespace gemmlowp {
28 namespace meta {
29 namespace internal {
30 
// Byte budget for the slice of the rhs operand handed to one kernel call by
// CacheFriendlyMatrixMatrix. The ScratchPerThread() helpers of the operation
// classes are expressed as a multiple of this value.
const std::int32_t kMaxCacheFriendlySize = 256 * 1024;

// Runs `operation.ExecuteCacheFriendlyMatrixMatrix` over the whole problem,
// splitting it into several calls when the rhs operand is larger than
// kMaxCacheFriendlySize bytes.
//
// The rhs buffer is treated as `n` consecutive runs of `k` elements. When
// n * k * sizeof(IN_TYPE) fits in the budget (or when n or k is not positive)
// a single call is made with the arguments passed through untouched.
// Otherwise rhs is cut into chunks of `rows_per_chunk` runs, where
// rows_per_chunk is the largest multiple of 4 whose byte size fits in the
// budget (at least 1). For every chunk the rhs pointer advances by
// rows_per_chunk * k elements and the result pointer by rows_per_chunk
// elements; lhs, m, k, scratch and result_stride are the same for all calls.
// The last call also absorbs the remainder, so it receives between
// rows_per_chunk and 2 * rows_per_chunk - 1 runs.
//
// Sizes and pointer offsets are computed in 64 bits so that large n * k does
// not overflow, and the chunk size accounts for sizeof(IN_TYPE) so that the
// chunk count can never become negative for element types wider than a byte.
template <typename IN_TYPE, typename OUT_TYPE, typename F>
void CacheFriendlyMatrixMatrix(std::uint8_t* scratch, const IN_TYPE* lhs,
                               const IN_TYPE* rhs, std::int32_t m,
                               std::int32_t n, std::int32_t k, OUT_TYPE* result,
                               std::int32_t result_stride, const F& operation) {
  const std::int64_t row_bytes =
      static_cast<std::int64_t>(k) * static_cast<std::int64_t>(sizeof(IN_TYPE));
  const std::int64_t rhs_bytes = static_cast<std::int64_t>(n) * row_bytes;

  // Small (or degenerate) problems go to the kernel in one piece.
  if (n <= 0 || k <= 0 || rhs_bytes <= kMaxCacheFriendlySize) {
    operation.ExecuteCacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k,
                                               result, result_stride);
    return;
  }

  // Largest multiple of 4 runs that fits the byte budget, but never zero.
  const std::int32_t rows_per_chunk =
      static_cast<std::int32_t>(std::max<std::int64_t>(
          1, 4 * (kMaxCacheFriendlySize / (4 * row_bytes))));
  // rhs_bytes > budget >= rows_per_chunk * row_bytes (or rows_per_chunk == 1),
  // hence n >= rows_per_chunk and this count is never negative.
  const std::int32_t full_chunks = n / rows_per_chunk - 1;
  const std::int64_t chunk_elements =
      static_cast<std::int64_t>(rows_per_chunk) * k;

  for (std::int32_t i = 0; i < full_chunks; ++i) {
    operation.ExecuteCacheFriendlyMatrixMatrix(
        scratch, lhs, rhs + i * chunk_elements, m, rows_per_chunk, k,
        result + static_cast<std::int64_t>(i) * rows_per_chunk, result_stride);
  }

  // Final call: one full chunk plus whatever did not divide evenly.
  const std::int32_t rows_left = n - full_chunks * rows_per_chunk;
  operation.ExecuteCacheFriendlyMatrixMatrix(
      scratch, lhs, rhs + full_chunks * chunk_elements, m, rows_left, k,
      result + static_cast<std::int64_t>(full_chunks) * rows_per_chunk,
      result_stride);
}
58 
59 class GemmQuantized8BitOperation : public Quantized8BitOperation {
60  public:
GemmQuantized8BitOperation(std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplier,std::int32_t shift)61   GemmQuantized8BitOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
62                              std::int32_t sum_offset, std::int32_t multiplier,
63                              std::int32_t shift)
64       : Quantized8BitOperation(lhs_offset, rhs_offset, sum_offset, multiplier,
65                                shift) {}
66 
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::uint8_t * result,std::int32_t result_stride)67   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
68                            const std::uint8_t* rhs, std::int32_t m,
69                            std::int32_t n, std::int32_t k, std::uint8_t* result,
70                            std::int32_t result_stride) const {
71     CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
72                               *this);
73   }
74 
ExecuteCacheFriendlyMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::uint8_t * result,std::int32_t result_stride)75   void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
76                                         const std::uint8_t* lhs,
77                                         const std::uint8_t* rhs, std::int32_t m,
78                                         std::int32_t n, std::int32_t k,
79                                         std::uint8_t* result,
80                                         std::int32_t result_stride) const {
81     gemm_q8_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
82                     sum_offset, multiplier, shift, result, result_stride);
83   }
84 
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)85   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
86                                        std::int32_t k) {
87     return 4 * kMaxCacheFriendlySize;
88   }
89 };
90 
91 class GemmFloatOperation : public FloatOperation {
92  public:
GemmFloatOperation(std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset)93   GemmFloatOperation(std::int32_t lhs_offset, std::int32_t rhs_offset,
94                      float result_offset)
95       : FloatOperation(lhs_offset, rhs_offset, result_offset) {}
96 
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,float * result,std::int32_t result_stride)97   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
98                            const std::uint8_t* rhs, std::int32_t m,
99                            std::int32_t n, std::int32_t k, float* result,
100                            std::int32_t result_stride) const {
101     CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
102                               *this);
103   }
104 
ExecuteCacheFriendlyMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,float * result,std::int32_t result_stride)105   void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
106                                         const std::uint8_t* lhs,
107                                         const std::uint8_t* rhs, std::int32_t m,
108                                         std::int32_t n, std::int32_t k,
109                                         float* result,
110                                         std::int32_t result_stride) const {
111     gemm_f_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset,
112                    result_offset, result, result_stride);
113   }
114 
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)115   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
116                                        std::int32_t k) {
117     return 4 * kMaxCacheFriendlySize;
118   }
119 };
120 
121 class GemmInt32Operation : public Int32Operation {
122  public:
GemmInt32Operation(std::int32_t lhs_offset,std::int32_t rhs_offset)123   GemmInt32Operation(std::int32_t lhs_offset, std::int32_t rhs_offset)
124       : Int32Operation(lhs_offset, rhs_offset) {}
125 
ExecuteMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t * result,std::int32_t result_stride)126   void ExecuteMatrixMatrix(std::uint8_t* scratch, const std::uint8_t* lhs,
127                            const std::uint8_t* rhs, std::int32_t m,
128                            std::int32_t n, std::int32_t k, std::int32_t* result,
129                            std::int32_t result_stride) const {
130     CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, result_stride,
131                               *this);
132   }
133 
ExecuteCacheFriendlyMatrixMatrix(std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t * result,std::int32_t result_stride)134   void ExecuteCacheFriendlyMatrixMatrix(std::uint8_t* scratch,
135                                         const std::uint8_t* lhs,
136                                         const std::uint8_t* rhs, std::int32_t m,
137                                         std::int32_t n, std::int32_t k,
138                                         std::int32_t* result,
139                                         std::int32_t result_stride) const {
140     gemm_i32_strided(scratch, lhs, rhs, m, n, k, lhs_offset, rhs_offset, result,
141                      result_stride);
142   }
143 
ScratchPerThread(std::int32_t m,std::int32_t n,std::int32_t k)144   static std::int32_t ScratchPerThread(std::int32_t m, std::int32_t n,
145                                        std::int32_t k) {
146     return 4 * kMaxCacheFriendlySize;
147   }
148 };
149 
150 }  // namespace internal
151 
gemm_q8_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)152 std::int32_t gemm_q8_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
153                              std::int32_t max_threads) {
154   return internal::ResolveMaxThreads(max_threads) *
155          internal::GemmQuantized8BitOperation::ScratchPerThread(m, n, k);
156 }
157 
multi_thread_gemm_q8(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t sum_offset,std::int32_t multiplier,std::int32_t shift,std::uint8_t * result)158 void multi_thread_gemm_q8(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
159                           std::uint8_t* scratch, const std::uint8_t* lhs,
160                           const std::uint8_t* rhs, std::int32_t m,
161                           std::int32_t n, std::int32_t k,
162                           std::int32_t lhs_offset, std::int32_t rhs_offset,
163                           std::int32_t sum_offset, std::int32_t multiplier,
164                           std::int32_t shift, std::uint8_t* result) {
165   if (m == 1) {
166     multi_thread_gemv_q8(pool, max_threads, scratch, lhs, rhs, n, k, lhs_offset,
167                          rhs_offset, sum_offset, multiplier, shift, result);
168     return;
169   } else if (n == 1) {
170     multi_thread_gemv_q8(pool, max_threads, scratch, rhs, lhs, m, k, rhs_offset,
171                          lhs_offset, sum_offset, multiplier, shift, result);
172     return;
173   }
174 
175   max_threads = internal::ResolveMaxThreads(max_threads);
176   internal::GemmQuantized8BitOperation operation(lhs_offset, rhs_offset,
177                                                  sum_offset, multiplier, shift);
178   if (max_threads == 1) {
179     internal::CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, n,
180                                         operation);
181   } else {
182     internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
183                                         n, k, result, n, operation);
184   }
185 }
186 
gemm_f_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)187 std::int32_t gemm_f_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
188                             std::int32_t max_threads) {
189   return internal::ResolveMaxThreads(max_threads) *
190          internal::GemmFloatOperation::ScratchPerThread(m, n, k);
191 }
192 
multi_thread_gemm_f(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,float result_offset,float * result)193 void multi_thread_gemm_f(gemmlowp::WorkersPool* pool, std::int32_t max_threads,
194                          std::uint8_t* scratch, const std::uint8_t* lhs,
195                          const std::uint8_t* rhs, std::int32_t m,
196                          std::int32_t n, std::int32_t k,
197                          std::int32_t lhs_offset, std::int32_t rhs_offset,
198                          float result_offset, float* result) {
199   if (m == 1) {
200     multi_thread_gemv_f(pool, max_threads, scratch, lhs, rhs, n, k, lhs_offset,
201                         rhs_offset, result_offset, result);
202     return;
203   } else if (n == 1) {
204     multi_thread_gemv_f(pool, max_threads, scratch, rhs, lhs, m, k, rhs_offset,
205                         lhs_offset, result_offset, result);
206     return;
207   }
208 
209   max_threads = internal::ResolveMaxThreads(max_threads);
210   internal::GemmFloatOperation operation(lhs_offset, rhs_offset, result_offset);
211   if (max_threads == 1) {
212     internal::CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, n,
213                                         operation);
214   } else {
215     internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
216                                         n, k, result, n, operation);
217   }
218 }
219 
gemm_i32_scratch(std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t max_threads)220 std::int32_t gemm_i32_scratch(std::int32_t m, std::int32_t n, std::int32_t k,
221                               std::int32_t max_threads) {
222   return internal::ResolveMaxThreads(max_threads) *
223          internal::GemmInt32Operation::ScratchPerThread(m, n, k);
224 }
225 
multi_thread_gemm_i32(gemmlowp::WorkersPool * pool,std::int32_t max_threads,std::uint8_t * scratch,const std::uint8_t * lhs,const std::uint8_t * rhs,std::int32_t m,std::int32_t n,std::int32_t k,std::int32_t lhs_offset,std::int32_t rhs_offset,std::int32_t * result)226 void multi_thread_gemm_i32(gemmlowp::WorkersPool* pool,
227                            std::int32_t max_threads, std::uint8_t* scratch,
228                            const std::uint8_t* lhs, const std::uint8_t* rhs,
229                            std::int32_t m, std::int32_t n, std::int32_t k,
230                            std::int32_t lhs_offset, std::int32_t rhs_offset,
231                            std::int32_t* result) {
232   if (m == 1) {
233     multi_thread_gemv_i32(pool, max_threads, scratch, lhs, rhs, n, k,
234                           lhs_offset, rhs_offset, result);
235     return;
236   } else if (n == 1) {
237     multi_thread_gemv_i32(pool, max_threads, scratch, rhs, lhs, m, k,
238                           rhs_offset, lhs_offset, result);
239     return;
240   }
241 
242   max_threads = internal::ResolveMaxThreads(max_threads);
243   internal::GemmInt32Operation operation(lhs_offset, rhs_offset);
244   if (max_threads == 1) {
245     internal::CacheFriendlyMatrixMatrix(scratch, lhs, rhs, m, n, k, result, n,
246                                         operation);
247   } else {
248     internal::MultiThreadedMatrixMatrix(pool, max_threads, scratch, lhs, rhs, m,
249                                         n, k, result, n, operation);
250   }
251 }
252 
253 }  // namespace meta
254 }  // namespace gemmlowp
255 
256 #else
257 #warning "Meta gemm fast-path requires GEMMLOWP_NEON_(32|64)!"
258 #endif
259 
260 #endif  // GEMMLOWP_META_MULTI_THREAD_GEMM_H_
261