/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/Backends/Interpreter/Interpreter.h"

#include "glow/Base/TensorSerialization.h"
#include "glow/IR/Instrs.h"
#include "glow/Quantization/Base/Base.h"
#include "glow/Quantization/Base/Profile.h"

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"

#include <chrono>
#include <cmath>
#include <math.h>

#ifdef WIN32
#include <corecrt_math_defines.h>
#endif

using namespace glow;

#define dispatchImpl(functionName, elemTy, ...)                                \
  switch (elemTy) {                                                            \
  case ElemKind::FloatTy:                                                      \
    functionName<float>(__VA_ARGS__);                                          \
    break;                                                                     \
  case ElemKind::Float16Ty:                                                    \
    functionName<float16_t>(__VA_ARGS__);                                      \
    break;                                                                     \
  case ElemKind::BFloat16Ty:                                                   \
    functionName<bfloat16_t>(__VA_ARGS__);                                     \
    break;                                                                     \
  case ElemKind::Int8QTy:                                                      \
    functionName<int8_t>(__VA_ARGS__);                                         \
    break;                                                                     \
  case ElemKind::Int16QTy:                                                     \
    functionName<int16_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  case ElemKind::Int32QTy:                                                     \
    functionName<int32_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  case ElemKind::Int32ITy:                                                     \
    functionName<int32_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  case ElemKind::Int64ITy:                                                     \
    functionName<int64_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  case ElemKind::BoolTy:                                                       \
    functionName<bool>(__VA_ARGS__);                                           \
    break;                                                                     \
  default:                                                                     \
    llvm_unreachable("Type is not supported");                                 \
  }
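
// Usage sketch (illustrative only; fwdFooInstImpl is a hypothetical kernel):
// given a templated implementation
//
//   template <typename ElemTy> void fwdFooInstImpl(const FooInst *I);
//
// a caller selects the instantiation from the runtime element type with
//
//   dispatchImpl(fwdFooInstImpl, I->getSrc()->getElementType(), I);
//
// which expands to the switch above and calls fwdFooInstImpl<float>,
// fwdFooInstImpl<int8_t>, etc., depending on the ElemKind.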

#define dispatchFloatingPointImpl(functionName, elemTy, ...)                   \
  switch (elemTy) {                                                            \
  case ElemKind::FloatTy:                                                      \
    functionName<float>(__VA_ARGS__);                                          \
    break;                                                                     \
  case ElemKind::Float16Ty:                                                    \
    functionName<float16_t>(__VA_ARGS__);                                      \
    break;                                                                     \
  case ElemKind::BFloat16Ty:                                                   \
    functionName<bfloat16_t>(__VA_ARGS__);                                     \
    break;                                                                     \
  default:                                                                     \
    llvm_unreachable("Type is not supported");                                 \
  }

#define dispatchFloatingPointAndIndexImpl(functionName, elemTy, elemTyIndex,   \
                                          ...)                                 \
  switch (elemTy) {                                                            \
  case ElemKind::FloatTy:                                                      \
    if (elemTyIndex == ElemKind::Int64ITy) {                                   \
      functionName<float, int64_t>(__VA_ARGS__);                               \
    } else if (elemTyIndex == ElemKind::Int32ITy) {                            \
      functionName<float, int32_t>(__VA_ARGS__);                               \
    }                                                                          \
    break;                                                                     \
  case ElemKind::Float16Ty:                                                    \
    if (elemTyIndex == ElemKind::Int64ITy) {                                   \
      functionName<float16_t, int64_t>(__VA_ARGS__);                           \
    } else if (elemTyIndex == ElemKind::Int32ITy) {                            \
      functionName<float16_t, int32_t>(__VA_ARGS__);                           \
    }                                                                          \
    break;                                                                     \
  case ElemKind::BFloat16Ty:                                                   \
    if (elemTyIndex == ElemKind::Int64ITy) {                                   \
      functionName<bfloat16_t, int64_t>(__VA_ARGS__);                          \
    } else if (elemTyIndex == ElemKind::Int32ITy) {                            \
      functionName<bfloat16_t, int32_t>(__VA_ARGS__);                          \
    }                                                                          \
    break;                                                                     \
  default:                                                                     \
    llvm_unreachable("Type is not supported");                                 \
  }

#define dispatchIndexTypeImpl(functionName, elemTy, ...)                       \
  switch (elemTy) {                                                            \
  case ElemKind::Int32ITy:                                                     \
    functionName<int32_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  case ElemKind::Int64ITy:                                                     \
    functionName<int64_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  default:                                                                     \
    llvm_unreachable("Type is not supported");                                 \
  }

#define dispatchArithmeticImpl(functionName, elemTy, ...)                      \
  switch (elemTy) {                                                            \
  case ElemKind::FloatTy:                                                      \
    functionName<float>(__VA_ARGS__);                                          \
    break;                                                                     \
  case ElemKind::Float16Ty:                                                    \
    functionName<float16_t>(__VA_ARGS__);                                      \
    break;                                                                     \
  case ElemKind::BFloat16Ty:                                                   \
    functionName<bfloat16_t>(__VA_ARGS__);                                     \
    break;                                                                     \
  case ElemKind::Int32ITy:                                                     \
    functionName<int32_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  case ElemKind::Int64ITy:                                                     \
    functionName<int64_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  default:                                                                     \
    llvm_unreachable("Type is not supported");                                 \
  }

#define dispatchQuantizedImpl(functionName, elemTy, ...)                       \
  switch (elemTy) {                                                            \
  case ElemKind::Int8QTy:                                                      \
    functionName<int8_t>(__VA_ARGS__);                                         \
    break;                                                                     \
  case ElemKind::Int16QTy:                                                     \
    functionName<int16_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  case ElemKind::Int32QTy:                                                     \
    functionName<int32_t>(__VA_ARGS__);                                        \
    break;                                                                     \
  default:                                                                     \
    llvm_unreachable("Type is not supported");                                 \
  }

#define dispatchQuantizedWithAccumulationImpl(functionName, elemTy, ...)       \
  switch (elemTy) {                                                            \
  case ElemKind::Int8QTy:                                                      \
    functionName<int8_t, int32_t>(__VA_ARGS__);                                \
    break;                                                                     \
  case ElemKind::Int16QTy:                                                     \
    functionName<int16_t, int64_t>(__VA_ARGS__);                               \
    break;                                                                     \
  default:                                                                     \
    llvm_unreachable("Type is not supported");                                 \
  }

#define dispatchQuantizedWithAccumulationAndBiasImpl(functionName, elemTy,     \
                                                     biasElemType, ...)        \
  if (elemTy == ElemKind::Int8QTy && biasElemType == ElemKind::Int8QTy) {      \
    functionName<int8_t, int32_t, int8_t>(__VA_ARGS__);                        \
  } else if (elemTy == ElemKind::Int8QTy &&                                    \
             biasElemType == ElemKind::Int32QTy) {                             \
    functionName<int8_t, int32_t, int32_t>(__VA_ARGS__);                       \
  } else if (elemTy == ElemKind::Int16QTy &&                                   \
             biasElemType == ElemKind::Int16QTy) {                             \
    functionName<int16_t, int64_t, int16_t>(__VA_ARGS__);                      \
  } else if (elemTy == ElemKind::Int16QTy &&                                   \
             biasElemType == ElemKind::Int32QTy) {                             \
    functionName<int16_t, int64_t, int32_t>(__VA_ARGS__);                      \
  } else {                                                                     \
    llvm_unreachable("Type is not supported");                                 \
  }
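
// For example, dispatching an Int8QTy convolution whose bias is Int32QTy
// through the macro above instantiates
// functionName<int8_t, /* AccumulatorTy */ int32_t, /* BiasElemTy */ int32_t>,
// i.e. int8 data is accumulated in int32 and the bias is read as int32.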

#define staticAssertFloatingPointType(ElemTy)                                  \
  static_assert(                                                               \
      std::is_floating_point<ElemTy>::value ||                                 \
          std::is_same<float16_t,                                              \
                       typename std::remove_cv<ElemTy>::type>::value ||        \
          std::is_same<bfloat16_t,                                             \
                       typename std::remove_cv<ElemTy>::type>::value,          \
      "This implementation is for floating-point values only")

#define staticAssertArithmeticType(ElemTy)                                     \
  static_assert(                                                               \
      std::is_arithmetic<ElemTy>::value ||                                     \
          std::is_same<float16_t,                                              \
                       typename std::remove_cv<ElemTy>::type>::value ||        \
          std::is_same<bfloat16_t,                                             \
                       typename std::remove_cv<ElemTy>::type>::value,          \
      "This implementation is for arithmetic values only")

//===----------------------------------------------------------------------===//
// Convolution
//===----------------------------------------------------------------------===//

/// This is the floating point implementation of Convolution.
template <typename ElemTy>
void BoundInterpreterFunction::fwdConvolutionInstFloatImpl(
    Value *inV, Value *outV, Value *filterV, Value *biasV,
    llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
    llvm::ArrayRef<unsigned_t> pads, size_t group, size_t dilation) {
  staticAssertFloatingPointType(ElemTy);

  auto inW = getWeightHandle<ElemTy>(inV);
  auto outW = getWeightHandle<ElemTy>(outV);
  auto filterW = getWeightHandle<ElemTy>(filterV);
  auto biasW = getWeightHandle<ElemTy>(biasV);

  ShapeNHWC odim(outW.dims());
  ShapeNHWC idim(inW.dims());
  ShapeHW kdim(kernelSizes);
  ShapeHW sdim(strides);

  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
  dim_t inCperG = idim.c / group;
  dim_t outCperG = odim.c / group;

  PaddingTLBR pdim(pads);
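
  // The loop bounds below assume the usual convolution geometry, i.e.
  // odim.h == (idim.h + pdim.top + pdim.bottom - (kdim.height - 1) * dilation
  // - 1) / sdim.height + 1 (integer division), and likewise for the width.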

  // For each input in the batch:
  for (dim_t n = 0; n < idim.n; n++) {

    // For each group of input channels:
    for (dim_t g = 0; g < group; g++) {

      // For each output channel in the group:
      for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {

        // For each convolution 'jump' in the input tensor:
        ssize_t x = -ssize_t(pdim.top);
        for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
          ssize_t y = -ssize_t(pdim.left);
          for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {

            // For each element in the convolution-filter:
            float sum = 0;
            for (dim_t fx = 0; fx < kdim.height; fx++) {
              for (dim_t fy = 0; fy < kdim.width; fy++) {
                sdim_t ox = x + fx * dilation;
                sdim_t oy = y + fy * dilation;

                // Ignore index access below zero (this is due to padding).
                if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
                    oy >= ssize_t(idim.w)) {
                  continue;
                }
                for (dim_t fd = 0; fd < inCperG; fd++) {
                  sum += float(
                      filterW.at({d, fx, fy, fd}) *
                      inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd}));
                }
              }
            }

            sum += float(biasW.at({d}));
            outW.at({n, ax, ay, d}) = ElemTy(sum);
          } // W
        }   // H
      }     // C
    }       // G
  }         // N
}

/// This is the quantized implementation of Convolution.
template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
void BoundInterpreterFunction::fwdConvolutionInstQuantizedImpl(
    Value *inV, Value *outV, Value *filterV, Value *biasV,
    llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
    llvm::ArrayRef<unsigned_t> pads, size_t group, size_t dilation) {
  auto inW = getWeightHandle<ElemTy>(inV);
  auto outW = getWeightHandle<ElemTy>(outV);
  auto filterW = getWeightHandle<ElemTy>(filterV);
  auto biasW = getWeightHandle<BiasElemTy>(biasV);

  ShapeNHWC odim(outW.dims());
  ShapeNHWC idim(inW.dims());
  ShapeHW kdim(kernelSizes);
  ShapeHW sdim(strides);

  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
  dim_t inCperG = idim.c / group;
  dim_t outCperG = odim.c / group;

  PaddingTLBR pdim(pads);
  auto outTy = outV->getType();
  auto inTy = inV->getType();
  auto filterTy = filterV->getType();
  auto biasTy = biasV->getType();

  int32_t outOffset = outTy->getOffset();
  int32_t inOffset = inTy->getOffset();
  int32_t filterOffset = filterTy->getOffset();
  int32_t biasOffset = biasTy->getOffset();

  float outScale = outTy->getScale();
  float inScale = inTy->getScale();
  float filterScale = filterTy->getScale();
  float biasScale = biasTy->getScale();

  // Calculate the scale of the values that come out of the matrix
  // multiplication part of the calculation.
  float matMulScale = inScale * filterScale;
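
  // Illustrative numbers: with inScale = 0.5, filterScale = 0.1 and
  // biasScale = 0.2, the raw products accumulated below live on the scale
  // matMulScale = 0.05; the bias is therefore rescaled by
  // biasScale / matMulScale = 4 before it is added, and the final sum is
  // multiplied by matMulScale / outScale to land on the output scale.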

  // For each input in the batch:
  for (dim_t n = 0; n < idim.n; n++) {
    // For each group of input channels:
    for (dim_t g = 0; g < group; g++) {

      // For each output channel in the group:
      for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {

        // For each convolution 'jump' in the input tensor:
        ssize_t x = -ssize_t(pdim.top);
        for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
          ssize_t y = -ssize_t(pdim.left);
          for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {

            // For each element in the convolution-filter:
            AccumulatorTy sum = 0;
            for (dim_t fx = 0; fx < kdim.height; fx++) {
              for (dim_t fy = 0; fy < kdim.width; fy++) {
                sdim_t ox = x + fx * dilation;
                sdim_t oy = y + fy * dilation;

                // Ignore index access below zero (this is due to padding).
                if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
                    oy >= sdim_t(idim.w)) {
                  continue;
                }
                for (dim_t fd = 0; fd < inCperG; fd++) {

                  AccumulatorTy F = filterW.at({d, fx, fy, fd});
                  AccumulatorTy I =
                      inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd});
                  // We represent the element multiplication with offset as
                  // (value - offset).
                  sum += (F - filterOffset) * (I - inOffset);
                }
              }
            }

            // Scale the bias to match the scale of the matrix multiplication.
            AccumulatorTy B = std::round(float(biasW.at({d}) - biasOffset) *
                                         (biasScale / matMulScale));

            // Add the bias.
            sum += B;

            // Scale the result back to the expected destination scale.
            outW.at({n, ax, ay, d}) = quantization::clip<AccumulatorTy, ElemTy>(
                std::round(float(sum) * (matMulScale / outScale) + outOffset));
          } // W
        }   // H
      }     // C
    }       // G
  }         // N
}

/// This is the floating point implementation of ConvTranspose.
template <typename ElemTy>
void BoundInterpreterFunction::fwdConvTransposeInstFloatImpl(
    Value *inV, Value *outV, Value *filterV, Value *biasV,
    llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
    llvm::ArrayRef<unsigned_t> pads, size_t group, size_t dilation) {
  staticAssertFloatingPointType(ElemTy);

  auto inW = getWeightHandle<ElemTy>(inV);
  auto outW = getWeightHandle<ElemTy>(outV);
  auto filterW = getWeightHandle<ElemTy>(filterV);
  auto biasW = getWeightHandle<ElemTy>(biasV);

  ShapeNHWC odim(outW.dims());
  ShapeNHWC idim(inW.dims());
  ShapeHW kdim(kernelSizes);
  ShapeHW sdim(strides);

  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
  assert(group == 1 && "Group must be 1.");

  dim_t inCperG = idim.c / group;
  dim_t outCperG = odim.c / group;

  PaddingTLBR pdim(pads);

  // For each input in the batch:
  for (dim_t n = 0; n < idim.n; n++) {

    // Initialize the bias (TODO: factor this out into a separate function
    // once quantized support is added).
    for (dim_t ax = 0; ax < odim.h; ax++) {
      for (dim_t ay = 0; ay < odim.w; ay++) {
        for (dim_t d = 0; d < odim.c; d++) {
          outW.at({n, ax, ay, d}) = static_cast<ElemTy>(biasW.at({d}));
        }
      }
    }

    // For each group of input channels:
    for (dim_t g = 0; g < group; g++) {

      // For each input channel in the group:
      for (dim_t d = g * inCperG; d < (g + 1) * inCperG; d++) {

        // For each transposed convolution 'jump' in the input tensor:
        ssize_t x = -ssize_t(pdim.top);
        for (dim_t bx = 0; bx < idim.h; bx++, x += sdim.height) {
          ssize_t y = -ssize_t(pdim.left);
          for (dim_t by = 0; by < idim.w; by++, y += sdim.width) {

            // For each element in the transposed convolution filter:
            ElemTy input = inW.at({n, bx, by, d});

            for (dim_t kx = 0; kx < kdim.height; kx++) {
              for (dim_t ky = 0; ky < kdim.width; ky++) {
                ssize_t ax = x + kx * dilation;
                ssize_t ay = y + ky * dilation;

                // Ignore index access below zero (this is due to padding).
                if (ax < 0 || ay < 0 || ax >= ssize_t(odim.h) ||
                    ay >= ssize_t(odim.w)) {
                  continue;
                }
                for (dim_t c = 0; c < outCperG; c++) {
                  outW.at({n, (dim_t)ax, (dim_t)ay, g * outCperG + c}) +=
                      filterW.at({c, kx, ky, d}) * input;
                }
              }
            }
          } // W
        }   // H
      }     // C
    }       // G
  }         // N
}

void BoundInterpreterFunction::fwdConvTransposeInst(
    const ConvTransposeInst *I) {
  auto kernelSizes = I->getKernels();
  auto pads = I->getPads();
  auto strides = I->getStrides();
  size_t group = I->getGroup();

  if (I->getSrc()->getType()->isQuantizedType()) {
    llvm_unreachable("Quantized ConvTranspose not supported");
    return;
  }

  dispatchFloatingPointImpl(
      fwdConvTransposeInstFloatImpl, I->getSrc()->getElementType(), I->getSrc(),
      I->getDest(), I->getFilter(), I->getBias(), kernelSizes, strides, pads,
      group, I->getDilation());
}

void BoundInterpreterFunction::fwdConvolutionInst(const ConvolutionInst *I) {
  auto kernelSizes = I->getKernels();
  auto pads = I->getPads();
  auto strides = I->getStrides();
  size_t group = I->getGroup();

  if (I->getSrc()->getType()->isQuantizedType()) {
    dispatchQuantizedWithAccumulationAndBiasImpl(
        fwdConvolutionInstQuantizedImpl, I->getSrc()->getElementType(),
        I->getBias()->getElementType(), I->getSrc(), I->getDest(),
        I->getFilter(), I->getBias(), kernelSizes, strides, pads, group,
        I->getDilation());
    return;
  }

  dispatchFloatingPointImpl(
      fwdConvolutionInstFloatImpl, I->getSrc()->getElementType(), I->getSrc(),
      I->getDest(), I->getFilter(), I->getBias(), kernelSizes, strides, pads,
      group, I->getDilation());
}

void BoundInterpreterFunction::fwdConvolutionGradInst(
    const ConvolutionGradInst *I) {
  auto inW = getWeightHandle(I->getSrc());
  auto inG = getWeightHandle(I->getSrcGrad());
  auto outG = getWeightHandle(I->getDestGrad());

  auto filterW = getWeightHandle(I->getFilter());
  auto filterG = getWeightHandle(I->getFilterGrad());
  auto biasG = getWeightHandle(I->getBiasGrad());

  size_t group = I->getGroup();
  size_t dilation = I->getDilation();

  inG.clear();
  filterG.clear();
  biasG.clear();

  ShapeNHWC odim(outG.dims());
  ShapeNHWC idim(inW.dims());
  ShapeHW kdim(I->getKernels());
  ShapeHW sdim(I->getStrides());
  PaddingTLBR pdim(I->getPads());

  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
  dim_t inCperG = idim.c / group;
  dim_t outCperG = odim.c / group;

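  // Each output-gradient element chainGrad = dL/dy contributes to three
  // accumulations below: dL/dF += x * chainGrad (filter gradient),
  // dL/dx += F * chainGrad (input gradient), and dL/db += chainGrad
  // (bias gradient).
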
  // For each input in the batch:
  for (dim_t n = 0; n < odim.n; n++) {

    // For each group of input channels:
    for (dim_t g = 0; g < group; g++) {

      // Compute the gradient. For each layer in the output tensor:
      for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {

        // For each convolution 'jump' in the input tensor:
        sdim_t x = -sdim_t(pdim.top);
        for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
          sdim_t y = -sdim_t(pdim.left);
          for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {

            float chainGrad = outG.at({n, ax, ay, d});

            // For each element in the convolution-filter:
            for (dim_t fx = 0; fx < kdim.height; fx++) {
              for (dim_t fy = 0; fy < kdim.width; fy++) {
                sdim_t ox = x + fx * dilation;
                sdim_t oy = y + fy * dilation;

                // Ignore index access below zero (this is due to padding).
                if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
                    oy >= sdim_t(idim.w)) {
                  continue;
                }

                for (dim_t fd = 0; fd < inCperG; fd++) {
                  filterG.at({d, fx, fy, fd}) +=
                      inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd}) *
                      chainGrad;
                  inG.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd}) +=
                      filterW.at({d, fx, fy, fd}) * chainGrad;
                }
              }
            }

            biasG.at({d}) += chainGrad;
          } // W
        }   // H
      }     // C
    }       // G
  }         // N
}

/// This is the floating point implementation of Convolution3D.
template <typename ElemTy>
void BoundInterpreterFunction::fwdConvolution3DInstFloatImpl(
    Value *inV, Value *outV, Value *filterV, Value *biasV,
    llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
    llvm::ArrayRef<unsigned_t> pads, size_t group) {
  staticAssertFloatingPointType(ElemTy);

  auto inW = getWeightHandle<ElemTy>(inV);
  auto outW = getWeightHandle<ElemTy>(outV);
  auto filterW = getWeightHandle<ElemTy>(filterV);
  auto biasW = getWeightHandle<ElemTy>(biasV);

  ShapeNTHWC odim(outW.dims());
  ShapeNTHWC idim(inW.dims());
  ShapeTHW kdim(kernelSizes);
  ShapeTHW sdim(strides);

  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
  dim_t inCperG = idim.c / group;
  dim_t outCperG = odim.c / group;

  PaddingNFTBLR pdim(pads);

  // For each input in the batch:
  for (dim_t n = 0; n < idim.n; n++) {

    // For each group of input channels:
    for (dim_t ig = 0; ig < group; ig++) {

      // For each output channel in the group:
      for (dim_t og = ig * outCperG; og < (ig + 1) * outCperG; og++) {

        ssize_t t = -ssize_t(pdim.near);
        for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
          // For each convolution 'jump' in the input tensor:
          ssize_t x = -ssize_t(pdim.top);
          for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
            ssize_t y = -ssize_t(pdim.left);
            for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
              // For each element in the 3D convolution-filter:
              float sum = 0;
              for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
                for (dim_t fx = 0; fx < kdim.height; fx++) {
                  for (dim_t fy = 0; fy < kdim.width; fy++) {
                    sdim_t ot = t + ft;
                    sdim_t ox = x + fx;
                    sdim_t oy = y + fy;

                    // Ignore index access below zero (this is due to padding).
                    if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
                        ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
                      continue;
                    }
                    for (dim_t fg = 0; fg < inCperG; fg++) {
                      sum += float(filterW.at({og, ft, fx, fy, fg}) *
                                   inW.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy,
                                           ig * inCperG + fg}));
                    }
                  }
                }
              }

              sum += float(biasW.at({og}));
              outW.at({n, at, ax, ay, og}) = ElemTy(sum);
            } // D
          }   // W
        }     // H
      }       // C
    }         // G
  }           // N
}

/// This is the quantized implementation of Convolution3D.
template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
void BoundInterpreterFunction::fwdConvolution3DInstQuantizedImpl(
    Value *inV, Value *outV, Value *filterV, Value *biasV,
    llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
    llvm::ArrayRef<unsigned_t> pads, size_t group) {
  auto inW = getWeightHandle<ElemTy>(inV);
  auto outW = getWeightHandle<ElemTy>(outV);
  auto filterW = getWeightHandle<ElemTy>(filterV);
  auto biasW = getWeightHandle<BiasElemTy>(biasV);

  ShapeNTHWC odim(outW.dims());
  ShapeNTHWC idim(inW.dims());
  ShapeTHW kdim(kernelSizes);
  ShapeTHW sdim(strides);

  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
  dim_t inCperG = idim.c / group;
  dim_t outCperG = odim.c / group;

  PaddingNFTBLR pdim(pads);

  auto outTy = outV->getType();
  auto inTy = inV->getType();
  auto filterTy = filterV->getType();
  auto biasTy = biasV->getType();

  int32_t outOffset = outTy->getOffset();
  int32_t inOffset = inTy->getOffset();
  int32_t filterOffset = filterTy->getOffset();
  int32_t biasOffset = biasTy->getOffset();

  float outScale = outTy->getScale();
  float inScale = inTy->getScale();
  float filterScale = filterTy->getScale();
  float biasScale = biasTy->getScale();

  // Calculate the scale of the values that come out of the matrix
  // multiplication part of the calculation.
  float matMulScale = inScale * filterScale;

  // For each input in the batch:
  for (dim_t n = 0; n < idim.n; n++) {

    // For each group of input channels:
    for (dim_t ig = 0; ig < group; ig++) {

      // For each output channel in the group:
      for (dim_t og = ig * outCperG; og < (ig + 1) * outCperG; og++) {

        // For each convolution 'jump' in the input tensor:
        ssize_t t = -ssize_t(pdim.near);
        for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
          ssize_t x = -ssize_t(pdim.top);
          for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
            ssize_t y = -ssize_t(pdim.left);
            for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {

              // For each element in the convolution-filter:
              AccumulatorTy sum = 0;
              for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
                for (dim_t fx = 0; fx < kdim.height; fx++) {
                  for (dim_t fy = 0; fy < kdim.width; fy++) {
                    ssize_t ot = t + ft;
                    ssize_t ox = x + fx;
                    ssize_t oy = y + fy;

                    // Ignore index access below zero (this is due to padding).
                    if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
                        ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
                      continue;
                    }
                    for (dim_t fg = 0; fg < inCperG; fg++) {

                      AccumulatorTy F = filterW.at({og, ft, fx, fy, fg});
                      AccumulatorTy I = inW.at({n, (dim_t)ot, (dim_t)ox,
                                                (dim_t)oy, ig * inCperG + fg});
                      // We represent the element multiplication with offset as
                      // (value - offset).
                      sum += (F - filterOffset) * (I - inOffset);
                    }
                  }
                }
              }

              // Scale the bias to match the scale of the matrix multiplication.
              AccumulatorTy B = std::round(float(biasW.at({og}) - biasOffset) *
                                           (biasScale / matMulScale));

              // Add the bias:
              sum += B;

              // Scale the result back to the expected destination scale.
              outW.at({n, at, ax, ay, og}) =
                  quantization::clip<AccumulatorTy, ElemTy>(std::round(
                      float(sum) * (matMulScale / outScale) + outOffset));
            } // D
          }   // W
        }     // H
      }       // C
    }         // G
  }           // N
}

void BoundInterpreterFunction::fwdConvolution3DInst(
    const Convolution3DInst *I) {
  auto kernelSizes = I->getKernels();
  auto pads = I->getPads();
  auto strides = I->getStrides();
  size_t group = I->getGroup();

  if (I->getSrc()->getType()->isQuantizedType()) {
    dispatchQuantizedWithAccumulationAndBiasImpl(
        fwdConvolution3DInstQuantizedImpl, I->getSrc()->getElementType(),
        I->getBias()->getElementType(), I->getSrc(), I->getDest(),
        I->getFilter(), I->getBias(), kernelSizes, strides, pads, group);
    return;
  }

  dispatchFloatingPointImpl(fwdConvolution3DInstFloatImpl,
                            I->getSrc()->getElementType(), I->getSrc(),
                            I->getDest(), I->getFilter(), I->getBias(),
                            kernelSizes, strides, pads, group);
}

void BoundInterpreterFunction::fwdConvolution3DGradInst(
    const Convolution3DGradInst *I) {
  (void)I;
  // TODO
  llvm_unreachable("not yet implemented");
}

//===----------------------------------------------------------------------===//
// Channelwise quantized Convolution
//===----------------------------------------------------------------------===//
template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
void BoundInterpreterFunction::fwdChannelwiseQuantizedConv2DInstImpl(
    const ChannelwiseQuantizedConvolutionInst *I) {
  auto inW = getWeightHandle<ElemTy>(I->getSrc());
  auto outW = getWeightHandle<ElemTy>(I->getDest());
  auto filterW = getWeightHandle<ElemTy>(I->getFilter());
  auto biasW = getWeightHandle<BiasElemTy>(I->getBias());
  auto filterScales = getWeightHandle<float>(I->getFilterScales());
  auto filterOffsets = getWeightHandle<int32_t>(I->getFilterOffsets());
  auto biasScales = getWeightHandle<float>(I->getBiasScales());
  auto biasOffsets = getWeightHandle<int32_t>(I->getBiasOffsets());

  llvm::ArrayRef<unsigned_t> kernelSizes = I->getKernels();
  llvm::ArrayRef<unsigned_t> pads = I->getPads();
  llvm::ArrayRef<unsigned_t> strides = I->getStrides();
  dim_t group = I->getGroup();
  dim_t dilation = I->getDilation();

  ShapeNHWC odim(outW.dims());
  ShapeNHWC idim(inW.dims());
  ShapeHW kdim(kernelSizes);
  ShapeHW sdim(strides);

  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
  dim_t inCperG = idim.c / group;
  dim_t outCperG = odim.c / group;

  PaddingTLBR pdim(pads);

  auto &inTy = inW.getType();
  auto &outTy = outW.getType();

  float inScale = inTy.getScale();
  float outScale = outTy.getScale();

  int32_t inOffset = inTy.getOffset();
  int32_t outOffset = outTy.getOffset();

  // For each input in the batch:
  for (dim_t n = 0; n < idim.n; n++) {
    // For each group of input channels:
    for (dim_t g = 0; g < group; g++) {
      // For each output channel in the group:
      for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {

        // Get the channel-wise quantization params.
        int32_t filterOffset = filterOffsets.at(d);
        float filterScale = filterScales.at(d);
        int32_t biasOffset = biasOffsets.at(d);
        float biasScale = biasScales.at(d);
        float matMulScale = inScale * filterScale;

        // For each convolution 'jump' in the input tensor:
        sdim_t x = -sdim_t(pdim.top);
        for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
          sdim_t y = -sdim_t(pdim.left);
          for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {

            // For each element in the convolution-filter:
            AccumulatorTy sum = 0;
            for (dim_t fx = 0; fx < kdim.height; fx++) {
              for (dim_t fy = 0; fy < kdim.width; fy++) {
                sdim_t ox = x + fx * dilation;
                sdim_t oy = y + fy * dilation;

                // Ignore index access below zero (this is due to padding).
                if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
                    oy >= sdim_t(idim.w)) {
                  continue;
                }

                // Accumulate along the filter depth.
                for (dim_t fd = 0; fd < inCperG; fd++) {
                  AccumulatorTy F = filterW.at({d, fx, fy, fd});
                  AccumulatorTy I =
                      inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd});
                  // We represent the element multiplication with offset as
                  // (value - offset).
                  sum += (F - filterOffset) * (I - inOffset);
                }
              }
            }

            // Scale the bias to match the scale of the matrix multiplication.
            sum += std::round(float(biasW.at({d}) - biasOffset) *
                              (biasScale / matMulScale));

            // Scale the result back to the expected destination scale.
            outW.at({n, ax, ay, d}) = quantization::clip<AccumulatorTy, ElemTy>(
                std::round(float(sum) * (matMulScale / outScale) + outOffset));
          } // W
        }   // H
      }     // C
    }       // G
  }         // N
}

template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
void BoundInterpreterFunction::fwdChannelwiseQuantizedConv3DInstImpl(
    const ChannelwiseQuantizedConvolutionInst *I) {
  auto inW = getWeightHandle<ElemTy>(I->getSrc());
  auto outW = getWeightHandle<ElemTy>(I->getDest());
  auto filterW = getWeightHandle<ElemTy>(I->getFilter());
  auto biasW = getWeightHandle<BiasElemTy>(I->getBias());
  auto filterScales = getWeightHandle<float>(I->getFilterScales());
  auto filterOffsets = getWeightHandle<int32_t>(I->getFilterOffsets());
  auto biasScales = getWeightHandle<float>(I->getBiasScales());
  auto biasOffsets = getWeightHandle<int32_t>(I->getBiasOffsets());

  llvm::ArrayRef<unsigned_t> kernelSizes = I->getKernels();
  llvm::ArrayRef<unsigned_t> pads = I->getPads();
  llvm::ArrayRef<unsigned_t> strides = I->getStrides();
  dim_t group = I->getGroup();

  ShapeNTHWC odim(outW.dims());
  ShapeNTHWC idim(inW.dims());
  ShapeTHW kdim(kernelSizes);
  ShapeTHW sdim(strides);

  assert(idim.c % group == 0 && "Input channels must be divisible by group.");
  assert(odim.c % group == 0 && "Output channels must be divisible by group.");
  dim_t inCperG = idim.c / group;
  dim_t outCperG = odim.c / group;

  PaddingNFTBLR pdim(pads);

  auto &inTy = inW.getType();
  auto &outTy = outW.getType();

  float inScale = inTy.getScale();
  float outScale = outTy.getScale();

  int32_t inOffset = inTy.getOffset();
  int32_t outOffset = outTy.getOffset();

  // For each input in the batch:
  for (dim_t n = 0; n < idim.n; n++) {
    // For each group of input channels:
    for (dim_t g = 0; g < group; g++) {
      // For each output channel in the group:
      for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {

        // Get the channel-wise quantization params.
        int32_t filterOffset = filterOffsets.at(d);
        float filterScale = filterScales.at(d);
        int32_t biasOffset = biasOffsets.at(d);
        float biasScale = biasScales.at(d);
        float matMulScale = inScale * filterScale;

        // For each convolution 'jump' in the input tensor:
        sdim_t t = -sdim_t(pdim.near);
        for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
          sdim_t x = -sdim_t(pdim.top);
          for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
            sdim_t y = -sdim_t(pdim.left);
            for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {

              // For each element in the convolution-filter:
              AccumulatorTy sum = 0;
              for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
                for (dim_t fx = 0; fx < kdim.height; fx++) {
                  for (dim_t fy = 0; fy < kdim.width; fy++) {
                    sdim_t ot = t + ft;
                    sdim_t ox = x + fx;
                    sdim_t oy = y + fy;

                    // Ignore index access below zero (this is due to
                    // padding).
                    if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
                        ox >= ssize_t(idim.h) || oy >= sdim_t(idim.w)) {
                      continue;
                    }

                    // Accumulate along the filter depth.
                    for (dim_t fd = 0; fd < inCperG; fd++) {

                      AccumulatorTy F = filterW.at({d, ft, fx, fy, fd});
                      AccumulatorTy I = inW.at({n, (dim_t)ot, (dim_t)ox,
                                                (dim_t)oy, g * inCperG + fd});
                      // We represent the element multiplication with offset
                      // as (value - offset).
                      sum += (F - filterOffset) * (I - inOffset);
                    }
                  }
                }
              }

              // Scale the bias to match the scale of the matrix multiplication.
              sum += std::round(float(biasW.at({d}) - biasOffset) *
                                (biasScale / matMulScale));

              // Scale the result back to the expected destination scale.
              outW.at({n, at, ax, ay, d}) =
                  quantization::clip<AccumulatorTy, ElemTy>(std::round(
                      float(sum) * (matMulScale / outScale) + outOffset));
            } // W
          }   // H
        }     // T
      }       // C
    }         // G
  }           // N
}

void BoundInterpreterFunction::fwdChannelwiseQuantizedConvolutionInst(
    const ChannelwiseQuantizedConvolutionInst *I) {
  bool isConv3D = (I->getSrc()->dims().size() == 5);
  if (isConv3D) {
    dispatchQuantizedWithAccumulationAndBiasImpl(
        fwdChannelwiseQuantizedConv3DInstImpl, I->getSrc()->getElementType(),
        I->getBias()->getElementType(), I);
  } else {
    dispatchQuantizedWithAccumulationAndBiasImpl(
        fwdChannelwiseQuantizedConv2DInstImpl, I->getSrc()->getElementType(),
        I->getBias()->getElementType(), I);
  }
}

//===----------------------------------------------------------------------===//
// Pooling
//===----------------------------------------------------------------------===//
template <class T>
static void fwdMaxPool(Tensor *inW, Tensor *outW, Tensor *argmaxW,
                       llvm::ArrayRef<unsigned_t> kernelSizes,
                       llvm::ArrayRef<unsigned_t> strides,
                       llvm::ArrayRef<unsigned_t> pads) {
  ShapeNHWC odim(outW->dims());
  ShapeNHWC idim(inW->dims());
  Handle<T> inHandle = inW->getHandle<T>();
  Handle<T> outHandle = outW->getHandle<T>();
  PaddingTLBR pdim(pads);
  ShapeHW kdim(kernelSizes);
  ShapeHW sdim(strides);

  llvm::Optional<Handle<int64_t>> argmaxH;
  if (argmaxW) {
    argmaxH = argmaxW->getHandle<int64_t>();
  }
  // For each input in the batch:
  for (dim_t n = 0; n < odim.n; n++) {

    // For each layer in the output tensor:
    for (dim_t z = 0; z < idim.c; z++) {
      // For each convolution 'jump' in the input tensor:
      sdim_t x = -sdim_t(pdim.top);
      for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
        sdim_t y = -sdim_t(pdim.left);
        for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {

          bool first = true;
          T max_value = 0;
          dim_t argmaxNHWC = 0;

          for (dim_t fx = 0; fx < kdim.height; fx++) {
            for (dim_t fy = 0; fy < kdim.width; fy++) {
              sdim_t ox = x + fx;
              sdim_t oy = y + fy;

              // Ignore index access below zero (this is due to padding).
              if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
                  oy >= ssize_t(idim.w)) {
                continue;
              }

              T val = inHandle.at({n, (dim_t)ox, (dim_t)oy, z});
              if (first || (val >= max_value)) {
                first = false;
                max_value = val;
                if (argmaxW) {
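                  // Pointer arithmetic against raw(0) yields the flat NHWC
                  // index of the max element in the input, i.e.
                  // ((n * idim.h + ox) * idim.w + oy) * idim.c + z; the
                  // gradient pass reads it back through Handle::raw().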
                  argmaxNHWC = &inHandle.at({n, (dim_t)ox, (dim_t)oy, z}) -
                               &inHandle.raw(0);
                }
              }
            }
          }

          outHandle.at({n, ax, ay, z}) = max_value;

          if (argmaxW) {
            (*argmaxH).at({n, ax, ay, z}) = argmaxNHWC;
          }
        } // W
      }   // H
    }     // C
  }       // N
}

void BoundInterpreterFunction::fwdMaxPoolInst(const MaxPoolInst *I) {
  auto inW = getTensor(I->getSrc());
  auto outW = getTensor(I->getDest());

  if (inW->getType().isQuantizedType()) {
    dispatchQuantizedImpl(fwdMaxPool, inW->getType().getElementType(), inW,
                          outW, nullptr, I->getKernels(), I->getStrides(),
                          I->getPads());
    return;
  }

  dispatchFloatingPointImpl(fwdMaxPool, inW->getType().getElementType(), inW,
                            outW, nullptr, I->getKernels(), I->getStrides(),
                            I->getPads());
}

void BoundInterpreterFunction::fwdMaxPoolWithArgmaxInst(
    const MaxPoolWithArgmaxInst *I) {
  auto inW = getTensor(I->getSrc());
  auto outW = getTensor(I->getDest());
  auto argmaxW = getTensor(I->getArgmax());

  if (inW->getType().isQuantizedType()) {
    dispatchQuantizedImpl(fwdMaxPool, inW->getType().getElementType(), inW,
                          outW, argmaxW, I->getKernels(), I->getStrides(),
                          I->getPads());
    return;
  }
  dispatchFloatingPointImpl(fwdMaxPool, inW->getType().getElementType(), inW,
                            outW, argmaxW, I->getKernels(), I->getStrides(),
                            I->getPads());
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdAvgPoolInstFloatImpl(const AvgPoolInst *I) {
  staticAssertFloatingPointType(ElemTy);

  ShapeNHWC odim(I->getDest()->dims());
  ShapeNHWC idim(I->getSrc()->dims());

  PaddingTLBR pdim(I->getPads());
  ShapeHW kdim(I->getKernels());
  ShapeHW sdim(I->getStrides());
  // Implement the avg pooling operation as defined here:
  // https://arxiv.org/abs/1312.4400
  float filterArea = kdim.height * kdim.width;

  auto inW = getWeightHandle<ElemTy>(I->getSrc());
  auto outW = getWeightHandle<ElemTy>(I->getDest());

  // For each input in the batch:
  for (dim_t n = 0; n < odim.n; n++) {
    // For each layer in the output tensor:
    for (dim_t z = 0; z < idim.c; z++) {
      // For each convolution 'jump' in the input tensor:
      ssize_t x = -ssize_t(pdim.top);
      for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
        ssize_t y = -ssize_t(pdim.left);
        for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
          float sum = 0;

          for (dim_t fx = 0; fx < kdim.height; fx++) {
            for (dim_t fy = 0; fy < kdim.width; fy++) {
              sdim_t ox = x + fx;
              sdim_t oy = y + fy;

              // Ignore index access below zero (this is due to padding).
              if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
                  oy >= ssize_t(idim.w)) {
                continue;
              }

              sum += float(inW.at({n, (dim_t)ox, (dim_t)oy, z}));
            }
          }
          outW.at({n, ax, ay, z}) = ElemTy(sum / filterArea);
        } // W
      }   // H
    }     // C
  }       // N
}

void BoundInterpreterFunction::fwdAvgPoolInstI8Impl(const AvgPoolInst *I) {
  ShapeNHWC odim(I->getDest()->dims());
  ShapeNHWC idim(I->getSrc()->dims());

  PaddingTLBR pdim(I->getPads());
  ShapeHW kdim(I->getKernels());
  ShapeHW sdim(I->getStrides());
  // Implement the avg pooling operation as defined here:
  // https://arxiv.org/abs/1312.4400
  float filterArea = kdim.height * kdim.width;

  auto inW = getWeightHandle<int8_t>(I->getSrc());
  auto outW = getWeightHandle<int8_t>(I->getDest());
  TensorQuantizationParams inQP{I->getSrc()->getType()->getScale(),
                                I->getSrc()->getType()->getOffset()};
  TensorQuantizationParams outQP{I->getDest()->getType()->getScale(),
                                 I->getDest()->getType()->getOffset()};

  // For each input in the batch:
  for (dim_t n = 0; n < odim.n; n++) {
    // For each layer in the output tensor:
    for (dim_t z = 0; z < idim.c; z++) {
      // For each convolution 'jump' in the input tensor:
      ssize_t x = -ssize_t(pdim.top);
      for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
        ssize_t y = -ssize_t(pdim.left);
        for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
          int32_t sum = 0;

          for (dim_t fx = 0; fx < kdim.height; fx++) {
            for (dim_t fy = 0; fy < kdim.width; fy++) {
              sdim_t ox = x + fx;
              sdim_t oy = y + fy;

              // Ignore index access below zero (this is due to padding).
              if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
                  oy >= ssize_t(idim.w)) {
                continue;
              }

              sum += inW.at({n, (dim_t)ox, (dim_t)oy, z}) - inQP.offset;
            }
          }
          // Instead of dividing by filterArea, just change scale.
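          // E.g. for a 3 x 3 kernel the accumulated sum is 9x the average, so
          // folding filterArea into the requantization scale divides it out
          // without an extra pass.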
          outW.at({n, ax, ay, z}) = quantization::clip<int32_t, int8_t>(
              std::round(float(sum) * (inQP.scale / outQP.scale / filterArea) +
                         outQP.offset));
        } // W
      }   // H
    }     // C
  }       // N
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdAvgPool3DInstFloatImpl(const AvgPoolInst *I) {
  staticAssertFloatingPointType(ElemTy);

  ShapeNTHWC odim(I->getDest()->dims());
  ShapeNTHWC idim(I->getSrc()->dims());

  PaddingNFTBLR pdim(I->getPads());
  ShapeTHW kdim(I->getKernels());
  ShapeTHW sdim(I->getStrides());
  // Implement the avg pooling operation as defined here:
  // https://arxiv.org/abs/1312.4400
  float filterArea = kdim.temporal_frames * kdim.height * kdim.width;

  auto inW = getWeightHandle<ElemTy>(I->getSrc());
  auto outW = getWeightHandle<ElemTy>(I->getDest());

  // For each input in the batch:
  for (dim_t n = 0; n < odim.n; n++) {
    // For each layer in the output tensor:
    for (dim_t z = 0; z < idim.c; z++) {
      // For each convolution 'jump' in the input tensor:
      ssize_t t = -ssize_t(pdim.near);
      for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
        ssize_t x = -ssize_t(pdim.top);
        for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
          ssize_t y = -ssize_t(pdim.left);
          for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
            float sum = 0;

            for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
              for (dim_t fx = 0; fx < kdim.height; fx++) {
                for (dim_t fy = 0; fy < kdim.width; fy++) {
                  sdim_t ot = t + ft;
                  sdim_t ox = x + fx;
                  sdim_t oy = y + fy;

                  // Ignore index access below zero (this is due to padding).
                  if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
                      ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
                    continue;
                  }

                  sum += float(inW.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy, z}));
                }
              }
              outW.at({n, at, ax, ay, z}) = ElemTy(sum / filterArea);
            }
          } // W
        }   // H
      }     // T
    }       // C
  }         // N
}

void BoundInterpreterFunction::fwdAvgPool3DInstI8Impl(const AvgPoolInst *I) {
  ShapeNTHWC odim(I->getDest()->dims());
  ShapeNTHWC idim(I->getSrc()->dims());

  PaddingNFTBLR pdim(I->getPads());
  ShapeTHW kdim(I->getKernels());
  ShapeTHW sdim(I->getStrides());
  // Implement the avg pooling operation as defined here:
  // https://arxiv.org/abs/1312.4400
  float filterArea = kdim.temporal_frames * kdim.height * kdim.width;

  auto inW = getWeightHandle<int8_t>(I->getSrc());
  auto outW = getWeightHandle<int8_t>(I->getDest());
  TensorQuantizationParams inQP{I->getSrc()->getType()->getScale(),
                                I->getSrc()->getType()->getOffset()};
  TensorQuantizationParams outQP{I->getDest()->getType()->getScale(),
                                 I->getDest()->getType()->getOffset()};

  // For each input in the batch:
  for (dim_t n = 0; n < odim.n; n++) {
    // For each layer in the output tensor:
    for (dim_t z = 0; z < idim.c; z++) {
      // For each convolution 'jump' in the input tensor:
      ssize_t t = -ssize_t(pdim.near);
      for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
        ssize_t x = -ssize_t(pdim.top);
        for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
          ssize_t y = -ssize_t(pdim.left);
          for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
            int32_t sum = 0;

            for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
              for (dim_t fx = 0; fx < kdim.height; fx++) {
                for (dim_t fy = 0; fy < kdim.width; fy++) {
                  sdim_t ot = t + ft;
                  sdim_t ox = x + fx;
                  sdim_t oy = y + fy;

                  // Ignore index access below zero (this is due to padding).
                  if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
                      ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
                    continue;
                  }

                  sum += inW.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy, z}) -
                         inQP.offset;
                }
              }
            }
            // Instead of dividing by filterArea, just change scale.
            outW.at({n, at, ax, ay, z}) =
                quantization::clip<int32_t, int8_t>(std::round(
                    float(sum) * (inQP.scale / outQP.scale / filterArea) +
                    outQP.offset));
          } // W
        }   // H
      }     // T
    }       // C
  }         // N
}

void BoundInterpreterFunction::fwdAvgPoolInst(const AvgPoolInst *I) {
  bool isConv3D = is3DData(ConvolutionLayout(I->getLayout()));
  bool isQuantized = I->getSrc()->getType()->isQuantizedType();

  if (isConv3D) {
    if (isQuantized) {
      fwdAvgPool3DInstI8Impl(I);
    } else {
      dispatchFloatingPointImpl(fwdAvgPool3DInstFloatImpl,
                                I->getSrc()->getElementType(), I);
    }
  } else {
    if (isQuantized) {
      fwdAvgPoolInstI8Impl(I);
    } else {
      dispatchFloatingPointImpl(fwdAvgPoolInstFloatImpl,
                                I->getSrc()->getElementType(), I);
    }
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdAdaptiveAvgPoolInstFloatImpl(
    const AdaptiveAvgPoolInst *I) {
  staticAssertFloatingPointType(ElemTy);

  ShapeNHWC odim(I->getDest()->dims());
  ShapeNHWC idim(I->getSrc()->dims());

  auto inW = getWeightHandle<ElemTy>(I->getSrc());
  auto outW = getWeightHandle<ElemTy>(I->getDest());

// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
#define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b))
#define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b))
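// Illustrative example: pooling idim.h = 10 down to odim.h = 4 makes the four
// output rows average input rows [0, 3), [2, 5), [5, 8) and [7, 10); the
// windows may overlap and jointly cover the whole input.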

  // For each input in the batch:
  for (dim_t n = 0; n < odim.n; n++) {
    // For each layer in the output tensor:
    for (dim_t z = 0; z < idim.c; z++) {
      // For each value in the output tensor:
      for (dim_t ax = 0; ax < odim.h; ax++) {

        dim_t x = START_IND(ax, odim.h, idim.h);
        dim_t kH = END_IND(ax, odim.h, idim.h) - x;

        for (dim_t ay = 0; ay < odim.w; ay++) {

          dim_t y = START_IND(ay, odim.w, idim.w);
          dim_t kW = END_IND(ay, odim.w, idim.w) - y;

          float sum = 0;
          for (dim_t fx = 0; fx < kH; fx++) {
            for (dim_t fy = 0; fy < kW; fy++) {
              dim_t ox = x + fx;
              dim_t oy = y + fy;

              sum += float(inW.at({n, ox, oy, z}));
            }
          }
          outW.at({n, ax, ay, z}) = ElemTy(sum / kW / kH);
        } // W
      }   // H
    }     // C
  }       // N
#undef START_IND
#undef END_IND
}

void BoundInterpreterFunction::fwdAdaptiveAvgPoolInstI8Impl(
    const AdaptiveAvgPoolInst *I) {
  ShapeNHWC odim(I->getDest()->dims());
  ShapeNHWC idim(I->getSrc()->dims());

  auto inW = getWeightHandle<int8_t>(I->getSrc());
  auto outW = getWeightHandle<int8_t>(I->getDest());

  TensorQuantizationParams inQP{I->getSrc()->getType()->getScale(),
                                I->getSrc()->getType()->getOffset()};
  TensorQuantizationParams outQP{I->getDest()->getType()->getScale(),
                                 I->getDest()->getType()->getOffset()};

// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
#define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b))
#define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b))

  // For each input in the batch:
  for (dim_t n = 0; n < odim.n; n++) {
    // For each layer in the output tensor:
    for (dim_t z = 0; z < idim.c; z++) {
      // For each value in the output tensor:
      for (dim_t ax = 0; ax < odim.h; ax++) {

        dim_t x = START_IND(ax, odim.h, idim.h);
        dim_t kH = END_IND(ax, odim.h, idim.h) - x;

        for (dim_t ay = 0; ay < odim.w; ay++) {

          dim_t y = START_IND(ay, odim.w, idim.w);
          dim_t kW = END_IND(ay, odim.w, idim.w) - y;

          int32_t sum = 0;
          for (dim_t fx = 0; fx < kH; fx++) {
            for (dim_t fy = 0; fy < kW; fy++) {
              dim_t ox = x + fx;
              dim_t oy = y + fy;

              sum += inW.at({n, ox, oy, z}) - inQP.offset;
            }
          }

          outW.at({n, ax, ay, z}) = quantization::clip<int32_t, int8_t>(
              std::round(float(sum) * (inQP.scale / outQP.scale / kW / kH) +
                         outQP.offset));
        } // W
      }   // H
    }     // C
  }       // N
#undef START_IND
#undef END_IND
}

void BoundInterpreterFunction::fwdAdaptiveAvgPoolInst(
    const AdaptiveAvgPoolInst *I) {
  if (I->getSrc()->getType()->isQuantizedType()) {
    fwdAdaptiveAvgPoolInstI8Impl(I);
    return;
  }

  dispatchFloatingPointImpl(fwdAdaptiveAvgPoolInstFloatImpl,
                            I->getSrc()->getElementType(), I);
}

void BoundInterpreterFunction::fwdAdaptiveAvgPoolGradInst(
    const AdaptiveAvgPoolGradInst *I) {
  auto inG = getWeightHandle(I->getSrcGrad());
  auto outW = getWeightHandle(I->getDest());
  auto outG = getWeightHandle(I->getDestGrad());

  inG.clear();

  ShapeNHWC odim(outW.dims());
  ShapeNHWC idim(inG.dims());

  const float gradCoefficient = 1. / (odim.h * odim.w);

#define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b))
#define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b))

  // https://software.intel.com/en-us/daal-programming-guide-2d-average-pooling-backward-layer
  // For each input in the batch:
  for (dim_t n = 0; n < odim.n; n++) {
    // For each layer in the output tensor:
    for (dim_t z = 0; z < idim.c; z++) {
      // For each value in the output tensor:
      for (dim_t ax = 0; ax < odim.h; ax++) {

        dim_t x = START_IND(ax, odim.h, idim.h);
        dim_t kH = END_IND(ax, odim.h, idim.h) - x;

        for (dim_t ay = 0; ay < odim.w; ay++) {

          dim_t y = START_IND(ay, odim.w, idim.w);
          dim_t kW = END_IND(ay, odim.w, idim.w) - y;

          const float chainGrad = outG.at({n, ax, ay, z}) * gradCoefficient;

          for (dim_t fx = 0; fx < kH; fx++) {
            for (dim_t fy = 0; fy < kW; fy++) {
              dim_t ox = x + fx;
              dim_t oy = y + fy;

              inG.at({n, ox, oy, z}) += chainGrad;
            }
          }
        } // W
      }   // H
    }     // C
  }       // N
#undef START_IND
#undef END_IND
}

void BoundInterpreterFunction::fwdMaxPoolWithArgmaxGradInst(
    const MaxPoolWithArgmaxGradInst *I) {
  auto inG = getWeightHandle(I->getSrcGrad());
  auto outW = getWeightHandle(I->getDest());
  auto outG = getWeightHandle(I->getDestGrad());

  inG.clear();

  ShapeNHWC idim(inG.dims());
  ShapeNHWC odim(outW.dims());

  auto argmax = getWeightHandle<int64_t>(I->getArgmax());

  // For each input in the batch:
  for (dim_t n = 0; n < odim.n; n++) {

    // Compute the gradient. For each layer in the output tensor:
    for (dim_t z = 0; z < odim.c; z++) {

      // For each convolution 'jump' in the input tensor:
      for (dim_t ax = 0; ax < odim.h; ax++) {
        for (dim_t ay = 0; ay < odim.w; ay++) {
          // Reuse precomputed linear index of max element from argmax.
          float chainGrad = outG.at({n, ax, ay, z});
          inG.raw(argmax.at({n, ax, ay, z})) += chainGrad;
        } // W
      }   // H
    }     // C
  }       // N
}
1530
fwdAvgPool2DGradInst(const AvgPoolGradInst * I)1531 void BoundInterpreterFunction::fwdAvgPool2DGradInst(const AvgPoolGradInst *I) {
1532 auto inG = getWeightHandle(I->getSrcGrad());
1533 auto outW = getWeightHandle(I->getDest());
1534 auto outG = getWeightHandle(I->getDestGrad());
1535
1536 ShapeNHWC odim(outW.dims());
1537 ShapeNHWC idim(inG.dims());
1538
1539 PaddingTLBR pdim(I->getPads());
1540 ShapeHW kdim(I->getKernels());
1541 ShapeHW sdim(I->getStrides());
1542
1543 inG.clear();
1544
1545 float filterArea = kdim.height * kdim.width;
1546
1547 // For each input in the batch:
1548 for (dim_t n = 0; n < odim.n; n++) {
1549
1550 // For each layer in the output tensor:
1551 for (dim_t z = 0; z < odim.c; z++) {
1552 // For each convolution 'jump' in the input tensor:
1553 ssize_t x = -ssize_t(pdim.top);
1554 for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1555 ssize_t y = -ssize_t(pdim.left);
1556 for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1557
1558 float dy = outG.at({n, ax, ay, z}) / filterArea;
1559
1560 for (dim_t fx = 0; fx < kdim.height; fx++) {
1561 for (dim_t fy = 0; fy < kdim.width; fy++) {
1562 ssize_t ox = x + fx;
1563 ssize_t oy = y + fy;
1564
1565 // Ignore index access below zero (this is due to padding).
1566 if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
1567 oy >= ssize_t(idim.w)) {
1568 continue;
1569 }
1570 inG.at({n, (dim_t)ox, (dim_t)oy, z}) += dy;
1571 }
1572 }
1573 } // W
1574 } // H
1575 } // C
1576 } // N
1577 }

void BoundInterpreterFunction::fwdAvgPool3DGradInst(const AvgPoolGradInst *I) {
  auto inG = getWeightHandle(I->getSrcGrad());
  auto outW = getWeightHandle(I->getDest());
  auto outG = getWeightHandle(I->getDestGrad());

  ShapeNTHWC odim(outW.dims());
  ShapeNTHWC idim(inG.dims());

  PaddingNFTBLR pdim(I->getPads());
  ShapeTHW kdim(I->getKernels());
  ShapeTHW sdim(I->getStrides());

  inG.clear();

  float filterArea = kdim.temporal_frames * kdim.height * kdim.width;

  // For each input in the batch:
  for (dim_t n = 0; n < odim.n; n++) {

    // For each layer in the output tensor:
    for (dim_t z = 0; z < odim.c; z++) {
      // For each output spatio-temporal position:
      ssize_t t = -ssize_t(pdim.near);
      for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
        ssize_t x = -ssize_t(pdim.top);
        for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
          ssize_t y = -ssize_t(pdim.left);
          for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {

            float dy = outG.at({n, at, ax, ay, z}) / filterArea;

            for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
              for (dim_t fx = 0; fx < kdim.height; fx++) {
                for (dim_t fy = 0; fy < kdim.width; fy++) {
                  ssize_t ot = t + ft;
                  ssize_t ox = x + fx;
                  ssize_t oy = y + fy;

                  // Ignore out-of-bounds accesses (possible due to padding).
                  if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
                      ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
                    continue;
                  }
                  inG.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy, z}) += dy;
                }
              }
            }
          } // W
        } // H
      } // T
    } // C
  } // N
}

void BoundInterpreterFunction::fwdAvgPoolGradInst(const AvgPoolGradInst *I) {
  bool isConv3D = is3DData(ConvolutionLayout(I->getLayout()));

  if (isConv3D) {
    fwdAvgPool3DGradInst(I);
  } else {
    fwdAvgPool2DGradInst(I);
  }
}

//===----------------------------------------------------------------------===//
//                       Activation functions
//===----------------------------------------------------------------------===//

void BoundInterpreterFunction::fwdReluInst(const ReluInst *) {
  DCHECK(!"Found ReluInst but Relu is lowered on Interpreter");
}

void BoundInterpreterFunction::fwdClipInst(const ClipInst *) {
  DCHECK(!"Found ClipInst but Clip is lowered on Interpreter");
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdSigmoidInstFloatImpl(const SigmoidInst *I) {
  staticAssertFloatingPointType(ElemTy);

  auto inW = getWeightHandle<ElemTy>(I->getSrc());
  auto outW = getWeightHandle<ElemTy>(I->getDest());

  for (dim_t i = 0, e = outW.size(); i < e; i++) {
    float val = inW.raw(i);
    outW.raw(i) = ElemTy(1 / (1 + std::exp(-val)));
  }
}

void BoundInterpreterFunction::fwdSigmoidInst(const SigmoidInst *I) {
  dispatchFloatingPointImpl(fwdSigmoidInstFloatImpl,
                            I->getSrc()->getElementType(), I);
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdTanhInstFloatImpl(const TanhInst *I) {
  staticAssertFloatingPointType(ElemTy);

  auto inW = getWeightHandle<ElemTy>(I->getSrc());
  auto outW = getWeightHandle<ElemTy>(I->getDest());

  for (dim_t i = 0, e = inW.size(); i < e; i++) {
    float val = inW.raw(i);
    outW.raw(i) = ElemTy(std::tanh(val));
  }
}

void BoundInterpreterFunction::fwdTanhInst(const TanhInst *I) {
  dispatchFloatingPointImpl(fwdTanhInstFloatImpl, I->getSrc()->getElementType(),
                            I);
}

//===----------------------------------------------------------------------===//
//                   Loss Functions (Softmax/regression/...)
//===----------------------------------------------------------------------===//

template <typename ElemTy>
void BoundInterpreterFunction::fwdSoftMaxInstImpl(const SoftMaxInst *I) {
  staticAssertFloatingPointType(ElemTy);

  auto inW = getWeightHandle<ElemTy>(I->getSrc());
  auto outW = getWeightHandle<ElemTy>(I->getDest());
  auto idim = inW.dims();
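
  // Note: softmax is shift invariant, i.e.
  //   exp(x_i - m) / sum_j exp(x_j - m) == exp(x_i) / sum_j exp(x_j),
  // so subtracting the per-row maximum m below leaves the result unchanged
  // while keeping every exponent <= 0, which avoids overflow in exp().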

  for (dim_t n = 0; n < idim[0]; n++) {
    // Find Max.
    float max = float(inW.at({n, 0}));
    for (dim_t i = 1; i < idim[1]; i++) {
      max = std::max(max, float(inW.at({n, i})));
    }

    // Compute exp.
    float sum = 0;
    for (dim_t i = 0; i < idim[1]; i++) {
      float e = std::exp(float(inW.at({n, i})) - max);
      sum += e;
      outW.at({n, i}) = ElemTy(e);
    }

    // Normalize the output.
    for (dim_t i = 0; i < idim[1]; i++) {
      outW.at({n, i}) = ElemTy(float(outW.at({n, i})) / sum);
    }
  } // N
}

void BoundInterpreterFunction::fwdSoftMaxInst(const SoftMaxInst *I) {
  dispatchFloatingPointImpl(fwdSoftMaxInstImpl, I->getSrc()->getElementType(),
                            I);
}

void BoundInterpreterFunction::fwdSoftMaxGradInst(const SoftMaxGradInst *I) {
  auto inG = getWeightHandle(I->getSrcGrad());
  auto idim = inG.dims();
  auto outW = getWeightHandle(I->getOrigDest());
  auto selectedH = getWeightHandle<int64_t>(I->getSelected());

  inG.clear();

  // http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/
  // https://stats.stackexchange.com/questions/79454/softmax-layer-in-a-neural-network
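  //
  // With a one-hot target y the combined softmax + cross-entropy gradient
  // simplifies to dL/dx_i = softmax(x)_i - 1{i == y}, which is exactly what
  // the loop below computes.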
  for (dim_t n = 0; n < idim[0]; n++) {
    for (dim_t i = 0; i < idim[1]; i++) {
      float delta = (selectedH.at({n, 0}) == (int64_t)i);
      inG.at({n, i}) = outW.at({n, i}) - delta;
    }
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdCrossEntropyLossInstFloatImpl(
    const CrossEntropyLossInst *I) {
  staticAssertFloatingPointType(ElemTy);

  auto P = getWeightHandle<ElemTy>(I->getP());
  auto labels = getWeightHandle<int64_t>(I->getLabels());
  auto CE = getWeightHandle<ElemTy>(I->getCE());
  auto dims = P.dims();
  CE.clear();
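  // CE = -sum_n log(P[n, labels[n]]), accumulated into the single-element CE
  // tensor.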
  for (dim_t n = 0; n < dims[0]; ++n) {
    assert(labels.raw(n) >= 0 && "Cannot use negative index.");
    dim_t y = labels.raw(n);
    float p_n = P.at({n, y});
    CE.at({0}) -= log(p_n);
  }
}

void BoundInterpreterFunction::fwdCrossEntropyLossInst(
    const CrossEntropyLossInst *I) {
  dispatchFloatingPointImpl(fwdCrossEntropyLossInstFloatImpl,
                            I->getP()->getElementType(), I);
}

void BoundInterpreterFunction::fwdCrossEntropyLossGradInst(
    const CrossEntropyLossGradInst *I) {
  auto P = getWeightHandle(I->getP());
  auto Labels = getWeightHandle<int64_t>(I->getLabels());
  auto PGrad = getWeightHandle(I->getPgrad());
  auto dims = PGrad.dims();
  PGrad.clear();
  for (dim_t n = 0; n < dims[0]; ++n) {
    assert(Labels.raw(n) >= 0 && "Cannot use negative index.");
    dim_t y = Labels.raw(n);
    // d(CE)/d(P[n, y]) = -1 / P[n, y]; the incoming gradient CEGrad.at({0})
    // is assumed to be 1 and is therefore omitted.
    PGrad.at({n, y}) = -1 / P.at({n, y});
  }
}

//===----------------------------------------------------------------------===//
//                  Tensor shape (copy/transpose/concat/...)
//===----------------------------------------------------------------------===//

void BoundInterpreterFunction::fwdCopyInst(const CopyInst *I) {
  auto inT = getTensor(I->getSrc());
  auto outT = getTensor(I->getDest());
  outT->copyRawFrom(inT);
}

void BoundInterpreterFunction::fwdTransposeInst(const TransposeInst *I) {
  auto inT = getTensor(I->getSrc());
  auto outT = getTensor(I->getDest());

  assert(outT->size() == inT->size() && "Invalid tensor dimensions");

  // The same shuffle is applied whether or not the tensor is quantized.
  inT->transpose(outT, I->getShuffle());
}

void BoundInterpreterFunction::fwdTensorViewInst(const TensorViewInst *I) {
  getOrCreateUnownedTensor(I, I->getSrc(), I->getOffsets());
}

void BoundInterpreterFunction::fwdSplatInst(const glow::SplatInst *I) {
  auto *T = getTensor(I->getDest());
  ElemKind k = T->getElementType();

  if (k == ElemKind::Int32ITy) {
    return T->getHandle<int32_t>().clear(I->getValue());
  }

  if (k == ElemKind::Int64ITy) {
    return T->getHandle<int64_t>().clear(I->getValue());
  }

  if (k == ElemKind::FloatTy) {
    return T->getHandle<float>().clear(I->getValue());
  }

  if (k == ElemKind::Float16Ty) {
    return T->getHandle<float16_t>().clear(I->getValue());
  }

  if (k == ElemKind::BFloat16Ty) {
    return T->getHandle<bfloat16_t>().clear(I->getValue());
  }

  if (k == ElemKind::BoolTy) {
    return T->getHandle<bool>().clear(static_cast<bool>(I->getValue()));
  }

  if (k == ElemKind::Int8QTy) {
    // Quantize the requested floating point splat value into the correct
    // integer representation.
    auto destTy = I->getDest()->getType();
    TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};
    float val = I->getValue();
    return T->getHandle<int8_t>().clear(quantization::quantize(val, destQ));
  }

  llvm_unreachable("Unsupported tensor type");
}

void BoundInterpreterFunction::fwdTouchInst(const glow::TouchInst *) {
  // Do nothing for a TouchInst.
}

void BoundInterpreterFunction::fwdInsertTensorInst(
    const glow::InsertTensorInst *I) {
  Tensor *outT = getTensor(I->getDest());
  Tensor *inT = getTensor(I->getSrc());
  ElemKind k = outT->getElementType();
#define TYPED_INSERT(TY, TYPEKIND)                                             \
  if (k == TYPEKIND) {                                                         \
    auto OH = outT->getHandle<TY>();                                           \
    auto IH = inT->getHandle<TY>();                                            \
    return OH.insertTensors(IH, I->getOffsets(), I->getCount(), I->getAxis()); \
  }

  TYPED_INSERT(int64_t, ElemKind::Int64ITy);
  TYPED_INSERT(int32_t, ElemKind::Int32ITy);
  TYPED_INSERT(float, ElemKind::FloatTy);
  TYPED_INSERT(float16_t, ElemKind::Float16Ty);
  TYPED_INSERT(bfloat16_t, ElemKind::BFloat16Ty);
  TYPED_INSERT(int8_t, ElemKind::Int8QTy);
  TYPED_INSERT(bool, ElemKind::BoolTy);
#undef TYPED_INSERT

  llvm_unreachable("Unsupported tensor type");
}

void BoundInterpreterFunction::fwdExtractTensorInst(
    const glow::ExtractTensorInst *I) {
  Tensor *outT = getTensor(I->getDest());
  Tensor *inT = getTensor(I->getSrc());
  ElemKind k = outT->getElementType();
#define TYPED_EXTRACT(TY, TYPEKIND)                                            \
  if (k == TYPEKIND) {                                                         \
    auto OH = outT->getHandle<TY>();                                           \
    auto IH = inT->getHandle<TY>();                                            \
    return IH.extractTensors(OH, I->getOffsets());                             \
  }

  TYPED_EXTRACT(int64_t, ElemKind::Int64ITy);
  TYPED_EXTRACT(float, ElemKind::FloatTy);
  TYPED_EXTRACT(float16_t, ElemKind::Float16Ty);
  TYPED_EXTRACT(bfloat16_t, ElemKind::BFloat16Ty);
  TYPED_EXTRACT(int8_t, ElemKind::Int8QTy);
  TYPED_EXTRACT(int32_t, ElemKind::Int32QTy);
  TYPED_EXTRACT(int32_t, ElemKind::Int32ITy);
#undef TYPED_EXTRACT

  llvm_unreachable("Unsupported tensor type");
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdGatherInstImpl(const glow::GatherInst *I) {
  Tensor *dataT = getTensor(I->getData());
  auto &dataTy = dataT->getType();
  Tensor *indicesT = getTensor(I->getIndices());
  Tensor *outT = getTensor(I->getDest());
  unsigned_t batchDims = I->getBatchDims();

  size_t out_p = 0;
  dim_t elementSize = dataTy.getElementSize();
  // The size (in bytes) of each sample in the batch.
  dim_t dataSampleSize = dataTy.getSliceSize(batchDims) * elementSize;
  // The size (in bytes) of the slices that we gather.
  dim_t dataSliceSize = dataTy.getSliceSize(batchDims + 1) * elementSize;

  // Calculate the number of samples in the batch.
  dim_t numSamples = (dataT->size() * elementSize) / dataSampleSize;

  // The size of the batch dimension, used only to validate the indices.
  dim_t batchSize = dataTy.dims()[batchDims];
  (void)batchSize;
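
  // Illustrative example: with data of shape [4, 2], batchDims == 0 and
  // indices {2, 0}, the output has shape [2, 2] and contains rows 2 and 0 of
  // data, in that order.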

  // For each sample in the batch:
  for (dim_t sample = 0; sample < numSamples; sample++) {
    dim_t sampleStart = sample * dataSampleSize;

    // For each slice (small fragment) that we copy from the source memory:
    for (dim_t i = 0, end = indicesT->size(); i < end; i++) {
      dim_t slice = indicesT->getHandle<ElemTy>().raw(i);
      assert(slice < batchSize && "Invalid index seen during Gather operation");
      std::copy(
          &dataT->getUnsafePtr()[sampleStart + dataSliceSize * slice],
          &dataT->getUnsafePtr()[sampleStart + dataSliceSize * (slice + 1)],
          &outT->getUnsafePtr()[out_p]);
      out_p += dataSliceSize;
    }
  }
}

void BoundInterpreterFunction::fwdGatherInst(const glow::GatherInst *I) {
  switch (I->getIndices()->getElementType()) {
  case ElemKind::Int64ITy:
    fwdGatherInstImpl<int64_t>(I);
    break;
  case ElemKind::Int32ITy:
    fwdGatherInstImpl<int32_t>(I);
    break;
  default:
    llvm_unreachable("Unsupported type for indices input of Gather.");
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdGatherRangesInstImpl(
    const glow::GatherRangesInst *I) {
  Tensor *dataT = getTensor(I->getData());
  auto &dataTy = dataT->getType();
  Tensor *rangesT = getTensor(I->getRanges());
  auto &rangesTy = rangesT->getType();
  Tensor *outT = getTensor(I->getOutput());
  Tensor *lengthsT = getTensor(I->getLengths());

  // Offset into the output tensor that keeps track of where to start
  // copying data.
  size_t outP = 0;

  unsigned dataElementSize = dataTy.getElementSize();
  dim_t numExamples = rangesTy.dims()[0];
  dim_t exampleSize = rangesTy.dims()[1];
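
  // Illustrative example: with data {1, 2, 3, 4, 5, 6} and ranges
  // {{{0, 2}, {4, 2}}} (one example with two (start, length) pairs), the
  // output is {1, 2, 5, 6} and lengths is {4}.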

  // Keep track of the total number of elements gathered across all
  // examples for a sanity check later.
  dim_t grandTotalLen = 0;

  // For each example in ranges:
  for (dim_t example = 0; example < numExamples; ++example) {
    // Keep a running total of the lengths of all ranges in this example
    // to record into lengthsT once the entire example is processed.
    ElemTy totalLen = 0;

    // For each range in the example:
    for (dim_t range = 0; range < exampleSize; ++range) {
      // Get the start index and range length.
      ElemTy startIdx = rangesT->getHandle<ElemTy>().at({example, range, 0});
      ElemTy len = rangesT->getHandle<ElemTy>().at({example, range, 1});

      // Add the length of this current range to the example length counter.
      totalLen += len;

      // Compute the start and end offsets.
      dim_t startOffset = startIdx * dataElementSize;
      dim_t endOffset = startOffset + (len * dataElementSize);

      // Sanity checks on the offsets.
      assert(startOffset < dataT->getSizeInBytes());
      assert(endOffset <= dataT->getSizeInBytes());
      assert(endOffset >= startOffset);
      assert(outP < outT->getSizeInBytes());
      assert((outP + (len * dataElementSize)) <= outT->getSizeInBytes());

      // Copy the specified data to outT.
      std::copy(&dataT->getUnsafePtr()[startOffset],
                &dataT->getUnsafePtr()[endOffset], &outT->getUnsafePtr()[outP]);

      // Advance the offset into outT.
      outP += len * dataElementSize;
    }

    // Record the total number of elements gathered for the example in
    // lengthsT.
    lengthsT->getHandle<ElemTy>().at({example}) = totalLen;

    // Add the total length of the entire example to the grand total.
    grandTotalLen += static_cast<size_t>(totalLen);
  }

  // Make sure that number of elements written to outT is equal to the
  // total of all elements in lengthsT.
  assert(grandTotalLen == (outP / dataElementSize));
}

void BoundInterpreterFunction::fwdGatherRangesInst(
    const glow::GatherRangesInst *I) {
  switch (I->getRanges()->getElementType()) {
  case ElemKind::Int64ITy:
    fwdGatherRangesInstImpl<int64_t>(I);
    break;
  case ElemKind::Int32ITy:
    fwdGatherRangesInstImpl<int32_t>(I);
    break;
  default:
    llvm_unreachable("Unsupported type for ranges input of GatherRanges.");
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdScatterDataInstCopyImpl(
    const glow::ScatterDataInst *I) {
  Tensor *dataT = getTensor(I->getData());
  Tensor *indicesT = getTensor(I->getIndices());
  Tensor *slicesT = getTensor(I->getSlices());

  assert(indicesT->dims().size() == 2 &&
         "Index should be stored in 2D tensor!");
  // The size (in bytes) of each slice.
  const dim_t dataSliceSize = slicesT->size() / slicesT->dims()[0] *
                              slicesT->getType().getElementSize();
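
  // Illustrative example: with data of shape [4, 2], indices {{3}, {1}} and
  // slices {{-1, -2}, {-3, -4}}, rows 3 and 1 of data are overwritten with
  // {-1, -2} and {-3, -4} respectively.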

  auto IH = indicesT->getHandle<int64_t>();
  // For each index, copy the corresponding slice into the location in data
  // addressed by that (possibly multi-dimensional) index.
  for (dim_t i = 0, end = indicesT->dims()[0]; i < end; i++) {
    dim_t destDataIdx = 0;
    for (dim_t j = 0, e = indicesT->dims()[1]; j < e; j++) {
      destDataIdx *= dataT->dims()[j];
      destDataIdx += IH.at({i, j});
    }
    std::copy(&slicesT->getUnsafePtr()[i * dataSliceSize],
              &slicesT->getUnsafePtr()[(i + 1) * dataSliceSize],
              &dataT->getUnsafePtr()[dataSliceSize * destDataIdx]);
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdScatterDataInstAddFloatImpl(
    const glow::ScatterDataInst *I) {
  Tensor *dataT = getTensor(I->getData());
  Tensor *indicesT = getTensor(I->getIndices());
  Tensor *slicesT = getTensor(I->getSlices());

  assert(!dataT->getType().isQuantizedType() && "Should be float type!");
  assert(!slicesT->getType().isQuantizedType() && "Should be float type!");

  // The number of elements in each slice.
  const size_t numSlices = slicesT->size() / slicesT->dims()[0];

  auto IH = indicesT->getHandle<int64_t>();
  // For each index, accumulate the corresponding slice into the location in
  // data addressed by that (possibly multi-dimensional) index.
  assert(indicesT->dims().size() == 2 &&
         "Multi-dimensional index should be stored in 2D tensor!");
  auto D = dataT->getHandle<ElemTy>(), S = slicesT->getHandle<ElemTy>();
  for (dim_t i = 0, end = indicesT->dims()[0]; i < end; i++) {
    size_t destDataIdx = 0;
    for (dim_t j = 0, e = indicesT->dims()[1]; j < e; j++) {
      destDataIdx *= dataT->dims()[j];
      destDataIdx += IH.at({i, j});
    }
    for (dim_t j = 0; j < numSlices; j++) {
      D.raw(destDataIdx * numSlices + j) += S.raw(i * numSlices + j);
    }
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdScatterDataInstAddQuantizedImpl(
    const glow::ScatterDataInst *I) {
  Tensor *dataT = getTensor(I->getData());
  Tensor *indicesT = getTensor(I->getIndices());
  Tensor *slicesT = getTensor(I->getSlices());

  assert(dataT->getType().isQuantizedType() && "Should be quantized type!");
  assert(slicesT->getType().isQuantizedType() && "Should be quantized type!");

  // The number of elements in each slice.
  const dim_t numSlices = slicesT->size() / slicesT->dims()[0];

  TensorQuantizationParams dataQ{dataT->getType().getScale(),
                                 dataT->getType().getOffset()};
  TensorQuantizationParams sliceQ{slicesT->getType().getScale(),
                                  slicesT->getType().getOffset()};

  auto IH = indicesT->getHandle<int64_t>();
  // For each index, accumulate the corresponding slice into the location in
  // data addressed by that (possibly multi-dimensional) index. The addition
  // is performed in the float domain: dequantize both operands, add, and
  // requantize into data's scale and offset.
  assert(indicesT->dims().size() == 2 &&
         "Multi-dimensional index should be stored in 2D tensor!");
  auto D = dataT->getHandle<ElemTy>(), S = slicesT->getHandle<ElemTy>();
  for (dim_t i = 0, end = indicesT->dims()[0]; i < end; i++) {
    dim_t destDataIdx = 0;
    for (dim_t j = 0, e = indicesT->dims()[1]; j < e; j++) {
      destDataIdx *= dataT->dims()[j];
      destDataIdx += IH.at({i, j});
    }
    for (dim_t j = 0; j < numSlices; j++) {
      float lhs =
          quantization::dequantize(D.raw(destDataIdx * numSlices + j), dataQ);
      float rhs = quantization::dequantize(S.raw(i * numSlices + j), sliceQ);
      ElemTy result = quantization::quantize(lhs + rhs, dataQ);
      D.raw(destDataIdx * numSlices + j) = result;
    }
  }
}

void BoundInterpreterFunction::fwdScatterDataInst(
    const glow::ScatterDataInst *I) {
  if (I->getCumulative()) {
    switch (I->getData()->getElementType()) {
    case ElemKind::FloatTy:
      fwdScatterDataInstAddFloatImpl<float>(I);
      break;
    case ElemKind::Int8QTy:
      fwdScatterDataInstAddQuantizedImpl<int8_t>(I);
      break;
    default:
      llvm_unreachable("Unsupported type for ScatterData.");
    }
  } else {
    switch (I->getData()->getElementType()) {
    case ElemKind::FloatTy:
      fwdScatterDataInstCopyImpl<float>(I);
      break;
    case ElemKind::Int8QTy:
      fwdScatterDataInstCopyImpl<int8_t>(I);
      break;
    default:
      llvm_unreachable("Unsupported type for ScatterData.");
    }
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdBatchOneHotImpl(
    const glow::BatchOneHotInst *I) {
  auto dataH = getWeightHandle<ElemTy>(I->getData());
  auto lengthsH = getWeightHandle<int32_t>(I->getLengths());
  auto valuesH = getWeightHandle<ElemTy>(I->getValues());
  auto destH = getWeightHandle<ElemTy>(I->getDest());

  auto batchSize = dataH.dims()[0];
  auto featureCnt = dataH.dims()[1];
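
  // Illustrative example: with data {{5}}, lengths {2} and values {4, 5}, the
  // output row is {0, 1}; each feature is compared against its own segment of
  // values.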

  for (dim_t batchId = 0; batchId < batchSize; batchId++) {
    size_t offset = 0;
    for (dim_t featureId = 0; featureId < featureCnt; featureId++) {
      auto curValue = dataH.at({batchId, featureId});
      auto curLength = lengthsH.at({featureId});
      for (dim_t i = offset, e = offset + curLength; i != e; i++) {
        destH.at({batchId, i}) = curValue == valuesH.at({i});
      }
      offset += curLength;
    }
    assert(offset == destH.dims()[1] &&
           "Sum of Lengths must be equal to size of Values");
  }
}

void BoundInterpreterFunction::fwdBatchOneHotInst(
    const glow::BatchOneHotInst *I) {
  switch (I->getData()->getElementType()) {
  case ElemKind::Int64ITy:
    fwdBatchOneHotImpl<int64_t>(I);
    break;
  case ElemKind::Int32ITy:
    fwdBatchOneHotImpl<int32_t>(I);
    break;
  case ElemKind::Int8QTy:
    fwdBatchOneHotImpl<int8_t>(I);
    break;
  default:
    dispatchFloatingPointImpl(fwdBatchOneHotImpl,
                              I->getData()->getElementType(), I);
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdSpaceToDepthInstImpl(
    const glow::SpaceToDepthInst *I) {
  auto *inT = getTensor(I->getSrc());
  auto *outT = getTensor(I->getDest());

  auto inH = inT->getHandle<ElemTy>();
  auto outH = outT->getHandle<ElemTy>();

  unsigned blockSize = I->getBlockSize();

  dim_t inDepth = inT->dims()[3];

  dim_t outBatch = outT->dims()[0];
  dim_t outHeight = outT->dims()[1];
  dim_t outWidth = outT->dims()[2];
  dim_t outDepth = outT->dims()[3];
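
  // Illustrative example: with blockSize == 2, an input of shape [N, 4, 4, 1]
  // becomes an output of shape [N, 2, 2, 4]; each 2x2 spatial block of the
  // input is folded into the 4 channels of a single output pixel.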

  for (dim_t ob = 0; ob < outBatch; ++ob) {
    for (dim_t oh = 0; oh < outHeight; ++oh) {
      for (dim_t ow = 0; ow < outWidth; ++ow) {
        for (dim_t oc = 0; oc < outDepth; ++oc) {
          // The depth group (block layer) this output channel belongs to.
          dim_t blockDepthLayer = oc / inDepth;
          // The width offset within the block resets every blockSize layers.
          dim_t iw = ow * blockSize + blockDepthLayer % blockSize;
          // The height offset advances by one every blockSize layers.
          dim_t ih = oh * blockSize + blockDepthLayer / blockSize;
          // The input channel index wraps around every inDepth channels.
          dim_t ic = oc % inDepth;

          outH.at({ob, oh, ow, oc}) = inH.at({ob, ih, iw, ic});
        }
      }
    }
  }
}

void BoundInterpreterFunction::fwdSpaceToDepthInst(
    const glow::SpaceToDepthInst *I) {
  switch (I->getSrc()->getElementType()) {
  case ElemKind::FloatTy:
    fwdSpaceToDepthInstImpl<float>(I);
    break;
  case ElemKind::Int8QTy:
    fwdSpaceToDepthInstImpl<int8_t>(I);
    break;
  default:
    llvm_unreachable("Type is not supported");
    break;
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdResizeNearestInstImpl(
    const ResizeNearestInst *I) {
  auto inW = getWeightHandle<ElemTy>(I->getSrc());
  auto scale = I->getScale();
  auto outW = getWeightHandle<ElemTy>(I->getDest());

  ShapeNHWC odim(outW.dims());
  ShapeNHWC idim(inW.dims());

  for (dim_t ob = 0; ob < odim.n; ++ob) {
    auto ib = std::min(dim_t(ob / scale[0]), idim.n - 1);
    for (dim_t oh = 0; oh < odim.h; ++oh) {
      auto ih = std::min(dim_t(oh / scale[1]), idim.h - 1);
      for (dim_t ow = 0; ow < odim.w; ++ow) {
        auto iw = std::min(dim_t(ow / scale[2]), idim.w - 1);
        for (dim_t oc = 0; oc < odim.c; ++oc) {
          auto ic = std::min(dim_t(oc / scale[3]), idim.c - 1);
          outW.at({ob, oh, ow, oc}) = inW.at({ib, ih, iw, ic});
        }
      }
    }
  }
}

void BoundInterpreterFunction::fwdResizeNearestInst(
    const ResizeNearestInst *I) {
  if (getTensor(I->getSrc())->getType().isQuantizedType()) {
    dispatchQuantizedImpl(fwdResizeNearestInstImpl,
                          I->getSrc()->getElementType(), I);
    return;
  }

  dispatchImpl(fwdResizeNearestInstImpl, I->getSrc()->getElementType(), I);
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdResizeBilinearInstImpl(
    const ResizeBilinearInst *I) {
  auto inW = getWeightHandle<ElemTy>(I->getSrc());
  auto scale = I->getScale();
  auto outW = getWeightHandle<ElemTy>(I->getDest());

  ShapeNHWC odim(outW.dims());
  ShapeNHWC idim(inW.dims());

  CHECK_EQ(scale[0], 1.0) << "Scaling batch not supported.";
  CHECK_EQ(scale[3], 1.0) << "Scaling channel not supported.";

  for (dim_t ob = 0; ob < odim.n; ++ob) {
    for (dim_t oh = 0; oh < odim.h; ++oh) {
      for (dim_t ow = 0; ow < odim.w; ++ow) {

        float ihf = oh / scale[1];
        float iwf = ow / scale[2];
        dim_t ih = dim_t(ihf);
        dim_t iw = dim_t(iwf);

        auto ih0 = std::min(ih, idim.h - 1);
        auto ih1 = std::min(ih + 1, idim.h - 1);
        auto iw0 = std::min(iw, idim.w - 1);
        auto iw1 = std::min(iw + 1, idim.w - 1);
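
        // Standard bilinear interpolation: blend the four neighboring input
        // pixels, first along the height axis and then along the width axis,
        // weighted by the fractional offsets (ihf - ih) and (iwf - iw).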

        for (dim_t oc = 0; oc < odim.c; ++oc) {
          auto v00 = inW.at({ob, ih0, iw0, oc});
          auto v01 = inW.at({ob, ih0, iw1, oc});
          auto v10 = inW.at({ob, ih1, iw0, oc});
          auto v11 = inW.at({ob, ih1, iw1, oc});

          auto hd = (float)v00 + (float)(v10 - v00) * (ihf - ih);
          auto hw = (float)v01 + (float)(v11 - v01) * (ihf - ih);
          float result = hd + (hw - hd) * (iwf - iw);
          outW.at({ob, oh, ow, oc}) = result;
        }
      }
    }
  }
}

void BoundInterpreterFunction::fwdResizeBilinearInst(
    const ResizeBilinearInst *I) {
  if (getTensor(I->getSrc())->getType().isQuantizedType()) {
    dispatchQuantizedImpl(fwdResizeBilinearInstImpl,
                          I->getSrc()->getElementType(), I);
    return;
  }

  dispatchImpl(fwdResizeBilinearInstImpl, I->getSrc()->getElementType(), I);
}

//===----------------------------------------------------------------------===//
//                     Local Response Normalization
//===----------------------------------------------------------------------===//

template <typename ElemTy>
void BoundInterpreterFunction::fwdLocalResponseNormalizationInstFloatImpl(
    const glow::LocalResponseNormalizationInst *I) {
  staticAssertFloatingPointType(ElemTy);

  auto inW = getWeightHandle<ElemTy>(I->getSrc());
  auto outW = getWeightHandle<ElemTy>(I->getDest());
  auto scaleCache = getWeightHandle<ElemTy>(I->getScale());

  ShapeNHWC odim(outW.dims());
  ShapeNHWC idim(inW.dims());

  (void)odim;

  // LRN node does not change the shape of the input.
  assert(odim == idim && "Output of LRN node must be same shape as input");

  // LRN node normalizes across channels, so the input must have a minimum
  // depth of 1.
  assert(idim.c > 0 && "Input of LRN node must have a minimum depth of 1");

  auto halfWindowSize = (size_t)I->getHalfWindowSize();
  auto k = I->getK();
  auto beta = I->getBeta();
  auto windowSize = 2 * halfWindowSize + 1;
  auto normedAlpha = I->getAlpha() / windowSize;
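
  // Each element is normalized across the channel dimension:
  //   out(n, h, w, c) = in(n, h, w, c) * scale(n, h, w, c)^-beta
  // where scale = k + normedAlpha * (sum of in^2 over the channel window
  // [c - halfWindowSize, c + halfWindowSize]).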

  // For every input in the batch:
  for (dim_t n = 0; n < idim.n; n++) {

    // For every row:
    for (dim_t h = 0; h < idim.h; h++) {

      // For every column:
      for (dim_t w = 0; w < idim.w; w++) {

        // For every channel:
        for (dim_t c = 0; c < idim.c; c++) {
          float squareSum = 0.0;
          for (dim_t i = (c >= halfWindowSize ? c - halfWindowSize : 0);
               i <= std::min<dim_t>(c + halfWindowSize, (size_t)idim.c - 1);
               i++) {
            float val = inW.at({n, h, w, i});
            squareSum += val * val;
          }

          auto scale = k + normedAlpha * squareSum;

          // This will be used to accelerate the backward pass.
          scaleCache.at({n, h, w, c}) = ElemTy(scale);

          auto normFactor = std::pow(scale, -beta);
          outW.at({n, h, w, c}) =
              ElemTy(float(inW.at({n, h, w, c})) * normFactor);
        }
      }
    }
  }
}

void BoundInterpreterFunction::fwdLocalResponseNormalizationInst(
    const LocalResponseNormalizationInst *I) {
  dispatchFloatingPointImpl(fwdLocalResponseNormalizationInstFloatImpl,
                            I->getSrc()->getElementType(), I);
}

void BoundInterpreterFunction::fwdLocalResponseNormalizationGradInst(
    const glow::LocalResponseNormalizationGradInst *I) {
  auto inW = getWeightHandle(I->getSrc());
  auto inG = getWeightHandle(I->getSrcGrad());
  auto outW = getWeightHandle(I->getDest());
  auto outG = getWeightHandle(I->getDestGrad());
  auto scaleCache = getWeightHandle(I->getScale());

  ShapeNHWC odim(outW.dims());

  auto halfWindowSize = I->getHalfWindowSize();
  auto beta = I->getBeta();
  auto windowSize = 2 * halfWindowSize + 1;
  auto normedAlpha = I->getAlpha() / windowSize;

  // For every input in the batch:
  for (dim_t n = 0; n < odim.n; n++) {

    // For every row:
    for (dim_t h = 0; h < odim.h; h++) {

      // For every column:
      for (dim_t w = 0; w < odim.w; w++) {

        float sum = 0.0;
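
        // The window sum is updated incrementally while sweeping over the
        // channels, so each output pixel costs O(C) rather than
        // O(C * windowSize).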

        // Compute sum for first channel.
        for (dim_t c = 0; c <= halfWindowSize && c < odim.c; c++) {
          auto outw = outW.at({n, h, w, c});
          auto scale = scaleCache.at({n, h, w, c});
          auto outg = outG.at({n, h, w, c});
          sum += (outg * (outw / scale));
        }

        // For every channel:
        for (dim_t c = 0; c < odim.c; c++) {
          auto outg = outG.at({n, h, w, c});
          auto scale = scaleCache.at({n, h, w, c});
          auto inw = inW.at({n, h, w, c});

          inG.at({n, h, w, c}) = outg * std::pow(scale, -beta) -
                                 2 * normedAlpha * beta * inw * sum;

          // Modify sum for next channel.
          auto subIndex = c - halfWindowSize;
          auto addIndex = c + halfWindowSize + 1;

          if (c >= halfWindowSize) {
            auto outw = outW.at({n, h, w, subIndex});
            auto scale = scaleCache.at({n, h, w, subIndex});
            auto outg = outG.at({n, h, w, subIndex});

            // Subtract "rear" end of this window.
            sum -= (outg * (outw / scale));
          }

          if (addIndex < odim.c) {
            auto outw = outW.at({n, h, w, addIndex});
            auto scale = scaleCache.at({n, h, w, addIndex});
            auto outg = outG.at({n, h, w, addIndex});

            // Add "front" end of next window.
            sum += (outg * (outw / scale));
          }
        }
      }
    }
  }
}

//===----------------------------------------------------------------------===//
//                      Arithmetic operations
//===----------------------------------------------------------------------===//
void BoundInterpreterFunction::fwdElementAddInstI8Impl(
    const ElementAddInst *I) {
  assert(getTensor(I->getLHS())->getType().isQuantizedType() &&
         "Wrong function");
  auto lhsTy = I->getLHS()->getType();
  auto rhsTy = I->getRHS()->getType();
  auto destTy = I->getDest()->getType();

  float lhsScale = lhsTy->getScale();
  float rhsScale = rhsTy->getScale();
  float destScale = destTy->getScale();

  int32_t lhsOffset = lhsTy->getOffset();
  int32_t rhsOffset = rhsTy->getOffset();
  int32_t destOffset = destTy->getOffset();

  auto outW = getWeightHandle<int8_t>(I->getDest());
  auto lhsW = getWeightHandle<int8_t>(I->getLHS());
  auto rhsW = getWeightHandle<int8_t>(I->getRHS());
  for (dim_t i = 0, e = outW.size(); i < e; i++) {
    int32_t L = lhsW.raw(i);
    int32_t R = rhsW.raw(i);

    // We widen the representation to 16 bits to prevent overflow.
    const float largeScale = float(1) / (1 << 15);
    // Scale both sides from their 8-bit scales to a shared 16-bit scale.
    int32_t L32 = std::round(float(L - lhsOffset) * (lhsScale / largeScale));
    int32_t R32 = std::round(float(R - rhsOffset) * (rhsScale / largeScale));
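    // E.g. a dequantized value of 0.5 becomes round(0.5 / 2^-15) == 16384 in
    // the shared 16-bit scale.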
    int32_t sum32 = L32 + R32;
    sum32 = std::round(float(sum32) * (largeScale / destScale) + destOffset);
    outW.raw(i) = quantization::clip<int32_t, int8_t>(sum32);
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdElementAddInstArithmeticImpl(
    const ElementAddInst *I) {
  staticAssertArithmeticType(ElemTy);

  auto outW = getWeightHandle<ElemTy>(I->getDest());
  auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
  auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
  for (size_t i = 0, e = outW.size(); i < e; i++) {
    outW.raw(i) = lhsW.raw(i) + rhsW.raw(i);
  }
}

void BoundInterpreterFunction::fwdElementAddInst(const ElementAddInst *I) {
  if (getTensor(I->getLHS())->getType().isQuantizedType()) {
    fwdElementAddInstI8Impl(I);
    return;
  }

  dispatchArithmeticImpl(fwdElementAddInstArithmeticImpl,
                         I->getLHS()->getType()->getElementType(), I);
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdElementSubInstArithmeticImpl(
    const ElementSubInst *I) {
  staticAssertArithmeticType(ElemTy);

  auto outW = getWeightHandle<ElemTy>(I->getDest());
  auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
  auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
  for (size_t i = 0, e = outW.size(); i < e; i++) {
    outW.raw(i) = lhsW.raw(i) - rhsW.raw(i);
  }
}

void BoundInterpreterFunction::fwdElementSubInst(const ElementSubInst *I) {
  if (getTensor(I->getLHS())->getType().isQuantizedType()) {
    auto destTy = I->getDest()->getType();
    auto lhsTy = I->getLHS()->getType();
    auto rhsTy = I->getRHS()->getType();

    float destScale = destTy->getScale();
    float lhsScale = lhsTy->getScale();
    float rhsScale = rhsTy->getScale();

    int32_t destOffset = destTy->getOffset();
    int32_t lhsOffset = lhsTy->getOffset();
    int32_t rhsOffset = rhsTy->getOffset();

    auto outW = getWeightHandle<int8_t>(I->getDest());
    auto lhsW = getWeightHandle<int8_t>(I->getLHS());
    auto rhsW = getWeightHandle<int8_t>(I->getRHS());
    for (size_t i = 0, e = outW.size(); i < e; i++) {
      //    s_d * (i_d - o_d) = s_l * (i_l - o_l) - s_r * (i_r - o_r)
      // => i_d = (s_l / s_d) * (i_l - o_l) - (s_r / s_d) * (i_r - o_r) + o_d
      float l = (lhsScale / destScale) * float(lhsW.raw(i) - lhsOffset);
      float r = (rhsScale / destScale) * float(rhsW.raw(i) - rhsOffset);
      int32_t q = std::round(l - r + destOffset);
      outW.raw(i) = quantization::clip<int32_t, int8_t>(q);
    }
    return;
  }

  dispatchArithmeticImpl(fwdElementSubInstArithmeticImpl,
                         I->getDest()->getElementType(), I);
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdElementMulInstArithmeticImpl(
    const ElementMulInst *I) {
  staticAssertArithmeticType(ElemTy);

  auto outW = getWeightHandle<ElemTy>(I->getDest());
  auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
  auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
  for (size_t i = 0, e = outW.size(); i < e; i++) {
    outW.raw(i) = lhsW.raw(i) * rhsW.raw(i);
  }
}

void BoundInterpreterFunction::fwdElementMulInst(const ElementMulInst *I) {
  if (getTensor(I->getLHS())->getType().isQuantizedType()) {
    auto lhsTy = I->getLHS()->getType();
    auto rhsTy = I->getRHS()->getType();
    auto destTy = I->getDest()->getType();

    TensorQuantizationParams lhsQ{lhsTy->getScale(), lhsTy->getOffset()};
    TensorQuantizationParams rhsQ{rhsTy->getScale(), rhsTy->getOffset()};
    TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};

    auto outW = getWeightHandle<int8_t>(I->getDest());
    auto lhsW = getWeightHandle<int8_t>(I->getLHS());
    auto rhsW = getWeightHandle<int8_t>(I->getRHS());
    float scale = lhsQ.scale * rhsQ.scale / destQ.scale;
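    // (i_l - o_l) * (i_r - o_r) carries a scale of s_l * s_r, so multiplying
    // by s_l * s_r / s_d rescales the integer product into the destination
    // scale.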
    for (size_t i = 0, e = outW.size(); i < e; i++) {
      int32_t mul = (lhsW.raw(i) - lhsQ.offset) * (rhsW.raw(i) - rhsQ.offset);
      outW.raw(i) = quantization::clip<int32_t, int8_t>(
          std::round(mul * scale) + destQ.offset);
    }
    return;
  }

  dispatchArithmeticImpl(fwdElementMulInstArithmeticImpl,
                         I->getDest()->getElementType(), I);
}

void BoundInterpreterFunction::fwdElementDivInst(const ElementDivInst *I) {
  if (getTensor(I->getLHS())->getType().isQuantizedType()) {
    auto destTy = I->getDest()->getType();
    auto lhsTy = I->getLHS()->getType();
    auto rhsTy = I->getRHS()->getType();

    float destScale = destTy->getScale();
    float lhsScale = lhsTy->getScale();
    float rhsScale = rhsTy->getScale();

    int32_t destOffset = destTy->getOffset();
    int32_t lhsOffset = lhsTy->getOffset();
    int32_t rhsOffset = rhsTy->getOffset();

    auto outW = getWeightHandle<int8_t>(I->getDest());
    auto lhsW = getWeightHandle<int8_t>(I->getLHS());
    auto rhsW = getWeightHandle<int8_t>(I->getRHS());
    for (size_t i = 0, e = outW.size(); i < e; i++) {
      //    s_d * (i_d - o_d) = (s_l * (i_l - o_l)) / (s_r * (i_r - o_r))
      // => i_d = (s_l * (i_l - o_l)) / (s_d * s_r * (i_r - o_r)) + o_d
      float l = lhsScale * float(lhsW.raw(i) - lhsOffset);
      float r = rhsScale * destScale * float(rhsW.raw(i) - rhsOffset);
      int32_t q = std::round(l / r + destOffset);
      outW.raw(i) = quantization::clip<int32_t, int8_t>(q);
    }
    return;
  }

#define DIV_LOOP(TYPE_)                                                        \
  auto outW = getWeightHandle<TYPE_>(I->getDest());                            \
  auto lhsW = getWeightHandle<TYPE_>(I->getLHS());                             \
  auto rhsW = getWeightHandle<TYPE_>(I->getRHS());                             \
  for (size_t i = 0, e = outW.size(); i < e; i++) {                            \
    outW.raw(i) = lhsW.raw(i) / rhsW.raw(i);                                   \
  }

  auto *T = getTensor(I->getDest());
  switch (T->getElementType()) {
  case ElemKind::Int64ITy: {
    DIV_LOOP(int64_t);
    return;
  }
  case ElemKind::FloatTy: {
    DIV_LOOP(float);
    return;
  }
  case ElemKind::Float16Ty: {
    DIV_LOOP(float16_t);
    return;
  }
  case ElemKind::BFloat16Ty: {
    DIV_LOOP(bfloat16_t);
    return;
  }
  default:
    llvm_unreachable("Unsupported type for Div.");
  }
}

void BoundInterpreterFunction::fwdElementMaxInstI8Impl(
    const ElementMaxInst *I) {
  assert(getTensor(I->getLHS())->getType().isQuantizedType() &&
         "Wrong function");
  auto lhsTy = I->getLHS()->getType();
  auto rhsTy = I->getRHS()->getType();
  auto destTy = I->getDest()->getType();

  TensorQuantizationParams lhsQ{lhsTy->getScale(), lhsTy->getOffset()};
  TensorQuantizationParams rhsQ{rhsTy->getScale(), rhsTy->getOffset()};
  TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};

  auto outW = getWeightHandle<int8_t>(I->getDest());
  auto lhsW = getWeightHandle<int8_t>(I->getLHS());
  auto rhsW = getWeightHandle<int8_t>(I->getRHS());
  for (size_t i = 0, e = outW.size(); i < e; i++) {
    // Convert both sides to the destination scale and perform a regular
    // comparison.
    int8_t L = quantization::quantize(
        quantization::dequantize(lhsW.raw(i), lhsQ), destQ);
    int8_t R = quantization::quantize(
        quantization::dequantize(rhsW.raw(i), rhsQ), destQ);
    outW.raw(i) = std::max(L, R);
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdElementMaxInstArithmeticImpl(
    const ElementMaxInst *I) {
  staticAssertArithmeticType(ElemTy);

  auto outW = getWeightHandle<ElemTy>(I->getDest());
  auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
  auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
  for (size_t i = 0, e = outW.size(); i < e; i++) {
    outW.raw(i) = std::max(lhsW.raw(i), rhsW.raw(i));
  }
}

void BoundInterpreterFunction::fwdElementMaxInst(const ElementMaxInst *I) {
  if (getTensor(I->getLHS())->getType().isQuantizedType()) {
    fwdElementMaxInstI8Impl(I);
    return;
  }

  dispatchArithmeticImpl(fwdElementMaxInstArithmeticImpl,
                         I->getLHS()->getType()->getElementType(), I);
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdElementMinInstArithmeticImpl(
    const ElementMinInst *I) {
  staticAssertArithmeticType(ElemTy);

  auto outW = getWeightHandle<ElemTy>(I->getDest());
  auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
  auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
  for (size_t i = 0, e = outW.size(); i < e; i++) {
    outW.raw(i) = std::min(lhsW.raw(i), rhsW.raw(i));
  }
}

void BoundInterpreterFunction::fwdElementMinInst(const ElementMinInst *I) {
  if (getTensor(I->getLHS())->getType().isQuantizedType()) {
    auto lhsTy = I->getLHS()->getType();
    auto rhsTy = I->getRHS()->getType();
    auto destTy = I->getDest()->getType();

    TensorQuantizationParams lhsQ{lhsTy->getScale(), lhsTy->getOffset()};
    TensorQuantizationParams rhsQ{rhsTy->getScale(), rhsTy->getOffset()};
    TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};

    auto outW = getWeightHandle<int8_t>(I->getDest());
    auto lhsW = getWeightHandle<int8_t>(I->getLHS());
    auto rhsW = getWeightHandle<int8_t>(I->getRHS());
    for (size_t i = 0, e = outW.size(); i < e; i++) {
      // Convert both sides to the destination scale and perform a regular
      // comparison.
      int8_t L = quantization::quantize(
          quantization::dequantize(lhsW.raw(i), lhsQ), destQ);
      int8_t R = quantization::quantize(
          quantization::dequantize(rhsW.raw(i), rhsQ), destQ);
      outW.raw(i) = std::min(L, R);
    }
    return;
  }

  dispatchArithmeticImpl(fwdElementMinInstArithmeticImpl,
                         I->getDest()->getElementType(), I);
}

//===----------------------------------------------------------------------===//
//                        Logical operations
//===----------------------------------------------------------------------===//
void BoundInterpreterFunction::fwdElementNotInst(const ElementNotInst *I) {
  auto inpW = getWeightHandle<bool>(I->getSrc());
  auto outW = getWeightHandle<bool>(I->getDest());
  for (size_t i = 0, e = outW.size(); i < e; ++i) {
    outW.raw(i) = (!inpW.raw(i));
  }
}

void BoundInterpreterFunction::fwdElementAndInst(const ElementAndInst *I) {
  auto lhsW = getWeightHandle<bool>(I->getLHS());
  auto rhsW = getWeightHandle<bool>(I->getRHS());
  auto outW = getWeightHandle<bool>(I->getDest());
  for (size_t i = 0, e = outW.size(); i < e; ++i) {
    outW.raw(i) = (lhsW.raw(i) && rhsW.raw(i));
  }
}

void BoundInterpreterFunction::fwdElementOrInst(const ElementOrInst *I) {
  auto lhsW = getWeightHandle<bool>(I->getLHS());
  auto rhsW = getWeightHandle<bool>(I->getRHS());
  auto outW = getWeightHandle<bool>(I->getDest());
  for (size_t i = 0, e = outW.size(); i < e; ++i) {
    outW.raw(i) = (lhsW.raw(i) || rhsW.raw(i));
  }
}

void BoundInterpreterFunction::fwdElementXorInst(const ElementXorInst *I) {
  auto lhsW = getWeightHandle<bool>(I->getLHS());
  auto rhsW = getWeightHandle<bool>(I->getRHS());
  auto outW = getWeightHandle<bool>(I->getDest());
  for (size_t i = 0, e = outW.size(); i < e; ++i) {
    outW.raw(i) = (lhsW.raw(i) ^ rhsW.raw(i));
  }
}

//===----------------------------------------------------------------------===//
//                     Unary arithmetic operations
//===----------------------------------------------------------------------===//
template <typename ElemTy, typename InstKind>
void BoundInterpreterFunction::fwdUnaryArithmeticImpl(
    const InstKind *I, std::function<float(float)> func) {
  Value *inpV = I->getSrc();
  Value *outV = I->getDest();
  auto inpTy = inpV->getType();
  auto outTy = outV->getType();
  auto inpH = getWeightHandle<ElemTy>(inpV);
  auto outH = getWeightHandle<ElemTy>(outV);
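
  // The functor operates in the float domain: quantized inputs are
  // dequantized per element, transformed, and requantized into the output
  // scale/offset; non-quantized inputs are simply converted through float.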

  if (inpTy->isQuantizedType()) {
    float inpScale = inpTy->getScale();
    int32_t inpOffset = inpTy->getOffset();
    float outScale = outTy->getScale();
    int32_t outOffset = outTy->getOffset();
    for (size_t i = 0, e = outH.size(); i < e; ++i) {
      float inpVal =
          quantization::dequantize<ElemTy>(inpH.raw(i), {inpScale, inpOffset});
      float outVal = func(inpVal);
      outH.raw(i) =
          quantization::quantize<ElemTy>(outVal, {outScale, outOffset});
    }
  } else {
    for (size_t i = 0, e = outH.size(); i < e; ++i) {
      float inpVal = static_cast<float>(inpH.raw(i));
      float outVal = func(inpVal);
      outH.raw(i) = static_cast<ElemTy>(outVal);
    }
  }
}

void BoundInterpreterFunction::fwdElementAbsInst(const ElementAbsInst *I) {
  auto func = [](float x) -> float { return std::abs(x); };
  dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
}

void BoundInterpreterFunction::fwdElementNegInst(const ElementNegInst *I) {
  auto func = [](float x) -> float { return -x; };
  dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
}

void BoundInterpreterFunction::fwdElementFloorInst(const ElementFloorInst *I) {
  auto func = [](float x) -> float { return std::floor(x); };
  dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
}

void BoundInterpreterFunction::fwdElementCeilInst(const ElementCeilInst *I) {
  auto func = [](float x) -> float { return std::ceil(x); };
  dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
}

void BoundInterpreterFunction::fwdElementRoundInst(const ElementRoundInst *I) {
  // ONNX, NumPy and TensorFlow specify round-half-to-even: values whose
  // fractional part is exactly 0.5 are rounded to the nearest even integer.
  // std::nearbyintf rounds according to the current floating-point rounding
  // mode, which defaults to round-to-nearest-even.
  auto func = [](float x) -> float { return std::nearbyintf(x); };
  dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
}

void BoundInterpreterFunction::fwdElementSqrtInst(const ElementSqrtInst *I) {
  auto func = [](float x) -> float { return std::sqrt(x); };
  dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
}

void BoundInterpreterFunction::fwdElementRsqrtInst(const ElementRsqrtInst *I) {
  auto func = [](float x) -> float { return 1 / std::sqrt(x); };
  dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
}

void BoundInterpreterFunction::fwdElementReciprocalInst(
    const ElementReciprocalInst *I) {
  auto func = [](float x) -> float { return 1 / x; };
  dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
}

void BoundInterpreterFunction::fwdElementSinInst(const ElementSinInst *I) {
  auto func = [](float x) -> float { return std::sin(x); };
  dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
}

void BoundInterpreterFunction::fwdElementCosInst(const ElementCosInst *I) {
  auto func = [](float x) -> float { return std::cos(x); };
  dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
}

//===----------------------------------------------------------------------===//
//                        Compare operations
//===----------------------------------------------------------------------===//
template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
          typename CmpTy, typename InstCmpKind>
void BoundInterpreterFunction::fwdElementCmpHelperImpl(
    const InstCmpKind *I, std::function<bool(CmpTy LHS, CmpTy RHS)> cmpHelper) {
  Value *lhsV = I->getLHS();
  Value *rhsV = I->getRHS();
  Value *outV = I->getDest();

  auto lhsH = getWeightHandle<ElemTy>(lhsV);
  auto rhsH = getWeightHandle<ElemTy>(rhsV);
  auto oH = getWeightHandle<bool>(outV);

  ElemScaleTy lhsScale = 1.0f;
  ElemScaleTy rhsScale = 1.0f;
  ElemOffsetTy lhsOffset = 0;
  ElemOffsetTy rhsOffset = 0;

  auto lhsTy = lhsV->getType();
  auto rhsTy = rhsV->getType();

  if (lhsV->getType()->isQuantizedType()) {
    lhsScale = lhsTy->getScale();
    rhsScale = rhsTy->getScale();

    lhsOffset = lhsTy->getOffset();
    rhsOffset = rhsTy->getOffset();
  }
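
  // For quantized inputs the comparison is done on dequantized values,
  // scale * (raw - offset); for non-quantized types scale == 1 and
  // offset == 0, so the raw values are compared directly.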

  // Perform the comparison element-wise.
  for (size_t i = 0, e = oH.size(); i < e; i++) {
    oH.raw(i) = cmpHelper(lhsScale * (lhsH.raw(i) - lhsOffset),
                          rhsScale * (rhsH.raw(i) - rhsOffset));
  }
}

// Note: CmpTy defaults to ElemTy in the member declaration, which is why the
// dispatchers below instantiate these helpers with only three template
// arguments for the non-quantized cases.
template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
          typename CmpTy>
void BoundInterpreterFunction::fwdElementCmpLTEInstImpl(
    const ElementCmpLTEInst *I) {
  auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return LHS <= RHS; };
  fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy,
                          ElementCmpLTEInst>(I, cmpHelper);
}

void BoundInterpreterFunction::fwdElementCmpLTEInst(
    const ElementCmpLTEInst *I) {
  auto *T = getTensor(I->getLHS());

  if (T->getType().isQuantizedType()) {
    fwdElementCmpLTEInstImpl<int8_t, int32_t, float, int32_t>(I);
    return;
  }

  switch (T->getElementType()) {
  case ElemKind::FloatTy:
    fwdElementCmpLTEInstImpl<float, float, float>(I);
    break;
  case ElemKind::Float16Ty:
    fwdElementCmpLTEInstImpl<float16_t, float16_t, float16_t>(I);
    break;
  case ElemKind::BFloat16Ty:
    fwdElementCmpLTEInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I);
    break;
  case ElemKind::Int32ITy:
    fwdElementCmpLTEInstImpl<int32_t, int32_t, float>(I);
    break;
  case ElemKind::Int64ITy:
    fwdElementCmpLTEInstImpl<int64_t, int64_t, float>(I);
    break;
  default:
    llvm_unreachable("Type is not supported");
  }
}

template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
          typename CmpTy>
void BoundInterpreterFunction::fwdElementCmpEQInstImpl(
    const ElementCmpEQInst *I) {
  auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return LHS == RHS; };
  fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy,
                          ElementCmpEQInst>(I, cmpHelper);
}

void BoundInterpreterFunction::fwdElementCmpEQInst(const ElementCmpEQInst *I) {
  auto *T = getTensor(I->getLHS());

  if (T->getType().isQuantizedType()) {
    fwdElementCmpEQInstImpl<int8_t, int32_t, float, int32_t>(I);
    return;
  }

  switch (T->getElementType()) {
  case ElemKind::FloatTy:
    fwdElementCmpEQInstImpl<float, float, float>(I);
    break;
  case ElemKind::Float16Ty:
    fwdElementCmpEQInstImpl<float16_t, float16_t, float16_t>(I);
    break;
  case ElemKind::BFloat16Ty:
    fwdElementCmpEQInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I);
    break;
  case ElemKind::Int32ITy:
    fwdElementCmpEQInstImpl<int32_t, int32_t, float>(I);
    break;
  case ElemKind::Int64ITy:
    fwdElementCmpEQInstImpl<int64_t, int64_t, float>(I);
    break;
  default:
    llvm_unreachable("Type is not supported");
  }
}

template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
          typename CmpTy>
void BoundInterpreterFunction::fwdElementCmpNEQInstImpl(
    const ElementCmpNEQInst *I) {
  auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return !(LHS == RHS); };
  fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy,
                          ElementCmpNEQInst>(I, cmpHelper);
}

void BoundInterpreterFunction::fwdElementCmpNEQInst(
    const ElementCmpNEQInst *I) {
  auto *T = getTensor(I->getLHS());

  if (T->getType().isQuantizedType()) {
    fwdElementCmpNEQInstImpl<int8_t, int32_t, float, int32_t>(I);
    return;
  }

  switch (T->getElementType()) {
  case ElemKind::FloatTy:
    fwdElementCmpNEQInstImpl<float, float, float>(I);
    break;
  case ElemKind::Float16Ty:
    fwdElementCmpNEQInstImpl<float16_t, float16_t, float16_t>(I);
    break;
  case ElemKind::BFloat16Ty:
    fwdElementCmpNEQInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I);
    break;
  case ElemKind::Int32ITy:
    fwdElementCmpNEQInstImpl<int32_t, int32_t, float>(I);
    break;
  case ElemKind::Int64ITy:
    fwdElementCmpNEQInstImpl<int64_t, int64_t, float>(I);
    break;
  default:
    llvm_unreachable("Type is not supported");
  }
}

template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
          typename CmpTy>
void BoundInterpreterFunction::fwdElementCmpLTInstImpl(
    const ElementCmpLTInst *I) {
  auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return LHS < RHS; };
  fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy,
                          ElementCmpLTInst>(I, cmpHelper);
}

void BoundInterpreterFunction::fwdElementCmpLTInst(const ElementCmpLTInst *I) {
  auto *T = getTensor(I->getLHS());
  if (T->getType().isQuantizedType()) {
    fwdElementCmpLTInstImpl<int8_t, int32_t, float, int32_t>(I);
    return;
  }

  switch (T->getElementType()) {
  case ElemKind::FloatTy:
    fwdElementCmpLTInstImpl<float, float, float>(I);
    break;
  case ElemKind::Float16Ty:
    fwdElementCmpLTInstImpl<float16_t, float16_t, float16_t>(I);
    break;
  case ElemKind::BFloat16Ty:
    fwdElementCmpLTInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I);
    break;
  case ElemKind::Int32ITy:
    fwdElementCmpLTInstImpl<int32_t, int32_t, float>(I);
    break;
  case ElemKind::Int64ITy:
    fwdElementCmpLTInstImpl<int64_t, int64_t, float>(I);
    break;
  default:
    llvm_unreachable("Type is not supported");
  }
}
3108
3109 template <typename ElemTy>
fwdElementPowInstFloatImpl(const ElementPowInst * I)3110 void BoundInterpreterFunction::fwdElementPowInstFloatImpl(
3111 const ElementPowInst *I) {
3112 staticAssertFloatingPointType(ElemTy);
3113
3114 auto baseW = getWeightHandle<ElemTy>(I->getLHS());
3115 auto expW = getWeightHandle<ElemTy>(I->getRHS());
3116 auto outW = getWeightHandle<ElemTy>(I->getDest());
3117 for (size_t i = 0, e = outW.size(); i < e; i++) {
3118 outW.raw(i) = ElemTy(pow(float(baseW.raw(i)), float(expW.raw(i))));
3119 }
3120 }
3121
fwdElementPowInst(const glow::ElementPowInst * I)3122 void BoundInterpreterFunction::fwdElementPowInst(
3123 const glow::ElementPowInst *I) {
3124 dispatchFloatingPointImpl(fwdElementPowInstFloatImpl,
3125 I->getLHS()->getElementType(), I);
3126 }
3127
3128 template <typename ElemTy>
fwdElementIsNaNInstFloatImpl(const ElementIsNaNInst * I)3129 void BoundInterpreterFunction::fwdElementIsNaNInstFloatImpl(
3130 const ElementIsNaNInst *I) {
3131 staticAssertFloatingPointType(ElemTy);
3132
3133 auto inW = getWeightHandle<ElemTy>(I->getSrc());
3134 auto outW = getWeightHandle<bool>(I->getDest());
3135 for (size_t i = 0, e = inW.size(); i < e; i++) {
3136 float val = inW.raw(i);
3137 outW.raw(i) = std::isnan(val);
3138 }
3139 }
3140
fwdElementIsNaNInst(const glow::ElementIsNaNInst * I)3141 void BoundInterpreterFunction::fwdElementIsNaNInst(
3142 const glow::ElementIsNaNInst *I) {
3143 dispatchFloatingPointImpl(fwdElementIsNaNInstFloatImpl,
3144 I->getSrc()->getElementType(), I);
3145 }
3146
3147 template <typename ElemTy>
fwdElementLogInstFloatImpl(const ElementLogInst * I)3148 void BoundInterpreterFunction::fwdElementLogInstFloatImpl(
3149 const ElementLogInst *I) {
3150 staticAssertFloatingPointType(ElemTy);
3151
3152 auto inW = getWeightHandle<ElemTy>(I->getSrc());
3153 auto outW = getWeightHandle<ElemTy>(I->getDest());
3154 for (size_t i = 0, e = inW.size(); i < e; i++) {
3155 float val = inW.raw(i);
3156 outW.raw(i) = ElemTy(log(val));
3157 }
3158 }
3159
fwdElementLogInst(const ElementLogInst * I)3160 void BoundInterpreterFunction::fwdElementLogInst(const ElementLogInst *I) {
3161 dispatchFloatingPointImpl(fwdElementLogInstFloatImpl,
3162 I->getSrc()->getElementType(), I);
3163 }
3164
3165 template <typename ElemTy>
fwdElementExpInstFloatImpl(const ElementExpInst * I)3166 void BoundInterpreterFunction::fwdElementExpInstFloatImpl(
3167 const ElementExpInst *I) {
3168 staticAssertFloatingPointType(ElemTy);
3169
3170 auto inW = getWeightHandle<ElemTy>(I->getSrc());
3171 auto outW = getWeightHandle<ElemTy>(I->getDest());
3172 for (size_t i = 0, e = inW.size(); i < e; i++) {
3173 float val = inW.raw(i);
3174 outW.raw(i) = ElemTy(exp(val));
3175 }
3176 }
3177
fwdElementExpInst(const ElementExpInst * I)3178 void BoundInterpreterFunction::fwdElementExpInst(const ElementExpInst *I) {
3179 dispatchFloatingPointImpl(fwdElementExpInstFloatImpl,
3180 I->getSrc()->getElementType(), I);
3181 }
3182
3183 template <typename ElemTy>
fwdElementSelectInstFloatImpl(const glow::ElementSelectInst * I)3184 void BoundInterpreterFunction::fwdElementSelectInstFloatImpl(
3185 const glow::ElementSelectInst *I) {
3186 staticAssertFloatingPointType(ElemTy);
3187
3188 auto outW = getWeightHandle<ElemTy>(I->getDest());
3189 auto condW = getWeightHandle<bool>(I->getCond());
3190 auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
3191 auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
3192 for (size_t i = 0, e = outW.size(); i < e; i++) {
3193 outW.raw(i) = condW.raw(i) ? lhsW.raw(i) : rhsW.raw(i);
3194 }
3195 }
3196
fwdElementSelectInst(const glow::ElementSelectInst * I)3197 void BoundInterpreterFunction::fwdElementSelectInst(
3198 const glow::ElementSelectInst *I) {
3199 if (getTensor(I->getLHS())->getType().isQuantizedType()) {
3200 auto destTy = I->getDest()->getType();
3201 auto lhsTy = I->getLHS()->getType();
3202 auto rhsTy = I->getRHS()->getType();
3203
3204 float destScale = destTy->getScale();
3205 float lhsScale = lhsTy->getScale();
3206 float rhsScale = rhsTy->getScale();
3207
3208 int32_t destOffset = destTy->getOffset();
3209 int32_t lhsOffset = lhsTy->getOffset();
3210 int32_t rhsOffset = rhsTy->getOffset();
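
    // The selected operand is dequantized to a real value (scale * (q -
    // offset)) and requantized with the destination's parameters below, so
    // LHS and RHS may carry different quantization parameters.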
    auto outW = getWeightHandle<int8_t>(I->getDest());
    auto condW = getWeightHandle<bool>(I->getCond());
    auto lhsW = getWeightHandle<int8_t>(I->getLHS());
    auto rhsW = getWeightHandle<int8_t>(I->getRHS());
    for (size_t i = 0, e = outW.size(); i < e; i++) {
      float val = condW.raw(i) ? lhsScale * (lhsW.raw(i) - lhsOffset)
                               : rhsScale * (rhsW.raw(i) - rhsOffset);
      int32_t q = std::round(val / destScale + destOffset);
      outW.raw(i) = quantization::clip<int32_t, int8_t>(q);
    }
    return;
  }

  dispatchFloatingPointImpl(fwdElementSelectInstFloatImpl,
                            I->getLHS()->getElementType(), I);
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdModuloInstImpl(const glow::ModuloInst *I) {
  auto srcH = getTensor(I->getSrc())->getHandle<ElemTy>();
  auto destH = getTensor(I->getDest())->getHandle<ElemTy>();

  auto divisor = I->getDivisor();
  auto signFollowDivisor = I->getSignFollowDivisor();

  for (size_t i = 0, e = srcH.size(); i < e; i++) {
    auto res = srcH.raw(i) % divisor;
    if (signFollowDivisor && res < 0) {
      res += divisor;
    }
    destH.raw(i) = res;
  }
}
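
// Note: C++ '%' truncates toward zero, so e.g. -1 % 3 == -1. With
// signFollowDivisor the adjustment above maps it to -1 + 3 == 2, giving the
// result the divisor's sign (Python-style modulo).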

void BoundInterpreterFunction::fwdModuloInst(const glow::ModuloInst *I) {
  dispatchIndexTypeImpl(fwdModuloInstImpl, I->getSrc()->getElementType(), I);
}

//===----------------------------------------------------------------------===//
// Mat Mul
//===----------------------------------------------------------------------===//
template <typename ElemTy, typename AccumulatorTy>
void BoundInterpreterFunction::fwdMatMulInstQuantizedImpl(
    const glow::MatMulInst *I) {
  assert(getTensor(I->getLHS())->getType().isQuantizedType());
  auto lhs = getWeightHandle<ElemTy>(I->getLHS());
  auto rhs = getWeightHandle<ElemTy>(I->getRHS());

  auto dest = getWeightHandle<ElemTy>(I->getDest());

  auto destDim = dest.dims();
  auto lhsDim = lhs.dims();

  auto destTy = I->getDest()->getType();
  auto lhsTy = I->getLHS()->getType();
  auto rhsTy = I->getRHS()->getType();

  dest.clear(0);

  // For matrix multiplication the effective output scale is
  // (L.scale * R.scale / D.scale). It can be folded into this single
  // multiplier because the offsets are subtracted explicitly per element
  // below.
  float scale = lhsTy->getScale() * rhsTy->getScale() / destTy->getScale();
  int32_t lhsOffset = lhsTy->getOffset();
  int32_t rhsOffset = rhsTy->getOffset();
  int32_t destOffset = destTy->getOffset();
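
  // Derivation (from the definitions above): each real value is
  // scale * (q - offset), so the real dot product is
  //   D.real = sum_i [L.scale * (L_i - L.off)] * [R.scale * (R_i - R.off)]
  //          = L.scale * R.scale * sum_i (L_i - L.off) * (R_i - R.off),
  // and quantizing into D gives
  //   D_q = round(D.real / D.scale) + D.off = round(scale * sum) + D.off,
  // which is exactly what the loop below computes.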

  // For each (x,y) in the destination matrix:
  for (dim_t x = 0; x < destDim[0]; x++) {
    for (dim_t y = 0; y < destDim[1]; y++) {

      // Perform DOT on the row and column.
      AccumulatorTy sum = 0;
      for (dim_t i = 0; i < lhsDim[1]; i++) {
        AccumulatorTy L = lhs.at({x, i});
        AccumulatorTy R = rhs.at({i, y});
        // We represent the element multiplication with offset as
        // (value - offset).
        sum += (L - lhsOffset) * (R - rhsOffset);
      }

      dest.at({x, y}) = quantization::clip<AccumulatorTy, ElemTy>(
          std::round(scale * sum + destOffset));
    }
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdMatMulInstFloatImpl(const MatMulInst *I) {
  staticAssertFloatingPointType(ElemTy);

  auto lhs = getWeightHandle<ElemTy>(I->getLHS());
  auto rhs = getWeightHandle<ElemTy>(I->getRHS());
  auto dest = getWeightHandle<ElemTy>(I->getDest());

  auto destDim = dest.dims();
  auto lhsDim = lhs.dims();

  dest.clear(0);

  // For each (x,y) in the destination matrix:
  for (dim_t x = 0; x < destDim[0]; x++) {
    for (dim_t y = 0; y < destDim[1]; y++) {

      // Perform DOT on the row and column.
      float sum = 0;
      for (dim_t i = 0; i < lhsDim[1]; i++) {
        sum += float(lhs.at({x, i}) * rhs.at({i, y}));
      }
      dest.at({x, y}) = ElemTy(sum);
    }
  }
}

void BoundInterpreterFunction::fwdMatMulInst(const glow::MatMulInst *I) {
  if (getTensor(I->getLHS())->getType().isQuantizedType()) {
    dispatchQuantizedWithAccumulationImpl(fwdMatMulInstQuantizedImpl,
                                          I->getLHS()->getElementType(), I);
    return;
  }

  dispatchFloatingPointImpl(fwdMatMulInstFloatImpl,
                            I->getLHS()->getElementType(), I);
}

void BoundInterpreterFunction::fwdBatchMatMulInst(
    const glow::BatchMatMulInst *I) {
  DCHECK(!"Found BatchMatMulInst but BatchMatMul is lowered on Interpreter");
}

void BoundInterpreterFunction::fwdReluGradInst(const glow::ReluGradInst *I) {
  DCHECK(!"Found ReluGradInst but ReluGrad is lowered on Interpreter");
}

//===----------------------------------------------------------------------===//
// FC
//===----------------------------------------------------------------------===//
template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
void BoundInterpreterFunction::fwdFullyConnectedInstQuantizedImpl(
    const glow::FullyConnectedInst *I) {
  assert(getTensor(I->getSrc())->getType().isQuantizedType());

  auto inW = getWeightHandle<ElemTy>(I->getSrc());
  auto weightsW = getWeightHandle<ElemTy>(I->getWeights());
  auto biasW = getWeightHandle<BiasElemTy>(I->getBias());
  auto outW = getWeightHandle<ElemTy>(I->getDest());

  auto inTy = inW.getType();
  auto weightsTy = weightsW.getType();
  auto biasTy = biasW.getType();
  auto outTy = outW.getType();

  int32_t inOffset = inTy.getOffset();
  int32_t weightsOffset = weightsTy.getOffset();
  int32_t biasOffset = biasTy.getOffset();
  int32_t outOffset = outTy.getOffset();

  float outScale = outTy.getScale();
  float weightsScale = weightsTy.getScale();
  float biasScale = biasTy.getScale();
  float inScale = inTy.getScale();

  ShapeHW idim(inW.dims());
  ShapeHW odim(outW.dims());

  // Calculate the scale of the values that come out of the matrix
  // multiplication part of the calculation.
  float matMulScale = weightsScale * inScale;

  outW.clear(0);

  for (dim_t i = 0; i < idim.height; i++) {
    for (dim_t j = 0; j < odim.width; j++) {
      AccumulatorTy sum = 0;
      for (dim_t k = 0; k < idim.width; k++) {
        AccumulatorTy W = weightsW.at({k, j});
        AccumulatorTy A = inW.at({i, k});
        sum += (W - weightsOffset) * (A - inOffset);
      }

      // Scale the bias to match the scale of the matrix multiplication.
      AccumulatorTy B = std::round(float(biasW.at({j}) - biasOffset) *
                                   (biasScale / matMulScale));
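      // The bias is stored as biasScale * (B_q - biasOffset); multiplying by
      // (biasScale / matMulScale) re-expresses it in the accumulator's scale
      // (inScale * weightsScale) so it can be added to 'sum' directly.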

      // Add the bias.
      sum += B;

      // Scale the result back to the expected destination scale.
      outW.at({i, j}) = quantization::clip<AccumulatorTy, ElemTy>(
          std::round(float(sum) * (matMulScale / outScale)) + outOffset);
    }
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdFullyConnectedInstFloatImpl(
    const FullyConnectedInst *I) {
  staticAssertFloatingPointType(ElemTy);

  auto inW = getWeightHandle<ElemTy>(I->getSrc());
  auto weightsW = getWeightHandle<ElemTy>(I->getWeights());
  auto biasW = getWeightHandle<ElemTy>(I->getBias());
  auto outW = getWeightHandle<ElemTy>(I->getDest());

  ShapeHW idim(inW.dims());
  ShapeHW odim(outW.dims());

  outW.clear(0);

  for (dim_t i = 0; i < idim.height; i++) {
    for (dim_t j = 0; j < odim.width; j++) {
      float sum = 0;
      for (dim_t k = 0; k < idim.width; k++) {
        sum += float(inW.at({i, k})) * float(weightsW.at({k, j}));
      }

      outW.at({i, j}) = sum + float(biasW.at({j}));
    }
  }
}

void BoundInterpreterFunction::fwdFullyConnectedInst(
    const glow::FullyConnectedInst *I) {

  if (getTensor(I->getSrc())->getType().isQuantizedType()) {
    dispatchQuantizedWithAccumulationAndBiasImpl(
        fwdFullyConnectedInstQuantizedImpl, I->getSrc()->getElementType(),
        I->getBias()->getElementType(), I);
    return;
  } else {
    dispatchFloatingPointImpl(fwdFullyConnectedInstFloatImpl,
                              I->getSrc()->getElementType(), I);
  }
}

//===----------------------------------------------------------------------===//
// Row-wise quantized FC
//===----------------------------------------------------------------------===//
template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
void BoundInterpreterFunction::fwdRowwiseQuantizedFullyConnectedInstImpl(
    Value *inV, Value *outV, Value *weightsV, Value *biasV, Value *scalesV,
    Value *offsetsV) {
  auto inW = getWeightHandle<ElemTy>(inV);
  auto outW = getWeightHandle<ElemTy>(outV);
  auto weightsW = getWeightHandle<ElemTy>(weightsV);
  auto biasW = getWeightHandle<BiasElemTy>(biasV);
  auto scalesW = getWeightHandle<float>(scalesV);
  auto offsetsW = getWeightHandle<int32_t>(offsetsV);
  ShapeHW idim(inW.dims());
  ShapeHW odim(outW.dims());
  auto inTy = inW.getType();
  auto biasTy = biasW.getType();
  auto outTy = outW.getType();
  int32_t outOffset = outTy.getOffset();
  int32_t inOffset = inTy.getOffset();
  int32_t biasOffset = biasTy.getOffset();
  float outScale = outTy.getScale();
  float inScale = inTy.getScale();
  float biasScale = biasTy.getScale();
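
  // Unlike the per-tensor quantized FC above, each output channel j carries
  // its own weight scale/offset (scalesW/offsetsW), so matMulScale is
  // recomputed per output column inside the loop below.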
  for (dim_t i = 0; i < idim.height; i++) {
    for (dim_t j = 0; j < odim.width; j++) {
      float matMulScale = scalesW.raw(j) * inScale;
      AccumulatorTy sum = 0;
      for (dim_t k = 0; k < idim.width; k++) {
        AccumulatorTy W = weightsW.at({j, k});
        AccumulatorTy A = inW.at({i, k});
        sum += (W - offsetsW.raw(j)) * (A - inOffset);
      }

      // Scale the bias to match the scale of the matrix multiplication.
      AccumulatorTy B = std::round(float(biasW.at({j}) - biasOffset) *
                                   (biasScale / matMulScale));

      // Add the bias.
      sum += B;

      // Scale the result back to the expected destination scale.
      outW.at({i, j}) = quantization::clip<AccumulatorTy, ElemTy>(
          std::round(float(sum) * (matMulScale / outScale) + outOffset));
    }
  }
}

void BoundInterpreterFunction::fwdRowwiseQuantizedFullyConnectedInst(
    const RowwiseQuantizedFullyConnectedInst *I) {
  dispatchQuantizedWithAccumulationAndBiasImpl(
      fwdRowwiseQuantizedFullyConnectedInstImpl, I->getSrc()->getElementType(),
      I->getBias()->getElementType(), I->getSrc(), I->getDest(),
      I->getWeights(), I->getBias(), I->getScales(), I->getOffsets());
}

//===----------------------------------------------------------------------===//
// Batched operations
//===----------------------------------------------------------------------===//
template <typename ElemTy, typename AccumulatorTy, typename SliceElemTy>
static void fwdBatchedAdd(Tensor *batch, Tensor *slice, Tensor *dest) {
  auto batchH = batch->getHandle<ElemTy>();
  auto sliceH = slice->getHandle<SliceElemTy>();
  auto destH = dest->getHandle<ElemTy>();

  auto batchTy = batch->getType();
  auto sliceTy = slice->getType();
  auto destTy = dest->getType();

  float sliceScale = sliceTy.getScale();
  float batchScale = batchTy.getScale();
  float destScale = destTy.getScale();

  int32_t sliceOffset = sliceTy.getOffset();
  int32_t batchOffset = batchTy.getOffset();
  int32_t destOffset = destTy.getOffset();

  auto bdim = flattenCdr(batchH.dims());
  assert(sliceH.size() == bdim.second && "Invalid slice size");
  assert(batchH.dims().drop_front() == sliceH.dims() && "Invalid batch size");

  // For each layer in the batch:
  for (dim_t n = 0; n < bdim.first; n++) {
    size_t base = batchH.getElementPtr({n});

    // For each element in the slice.
    for (dim_t i = 0; i < bdim.second; i++) {
      AccumulatorTy batchVal = batchH.raw(base + i);
      AccumulatorTy sliceVal = sliceH.raw(i);
      // We increase the size of the integer up to 16 bits for more accurate
      // arithmetic.
      const float largeScale = float(1) / (1 << 15);
      // Scale both sides from 8-bit to 16-bits.
      AccumulatorTy B =
          std::round(float(batchVal - batchOffset) * (batchScale / largeScale));
      AccumulatorTy S =
          std::round(float(sliceVal - sliceOffset) * (sliceScale / largeScale));
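      // Both addends are now fixed-point values with the common scale 2^-15,
      // so the plain integer sum below represents (batch + slice) at that
      // scale and a single requantization maps it into the destination type.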
      AccumulatorTy R = B + S;
      destH.raw(base + i) = quantization::clip<AccumulatorTy, ElemTy>(
          std::round(float(R) * (largeScale / destScale) + destOffset));
    }
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdBatchedAddInstFloatImpl(
    const glow::BatchedAddInst *I) {
  staticAssertFloatingPointType(ElemTy);

  auto batch = getWeightHandle<ElemTy>(I->getBatch());
  auto slice = getWeightHandle<ElemTy>(I->getSlice());
  auto dest = getWeightHandle<ElemTy>(I->getDest());

  auto bdim = flattenCdr(batch.dims());
  assert(slice.size() == bdim.second && "Invalid slice size");
  assert(batch.dims().drop_front() == slice.dims() && "Invalid batch size");

  // For each layer in the batch:
  for (dim_t n = 0; n < bdim.first; n++) {
    size_t base = batch.getElementPtr({n});

    // For each element in the slice.
    for (dim_t i = 0; i < bdim.second; i++) {
      dest.raw(base + i) = batch.raw(base + i) + slice.raw(i);
    }
  }
}

void BoundInterpreterFunction::fwdBatchedAddInst(
    const glow::BatchedAddInst *I) {
  if (getTensor(I->getBatch())->getType().isQuantizedType()) {
    dispatchQuantizedWithAccumulationAndBiasImpl(
        fwdBatchedAdd, I->getBatch()->getElementType(),
        I->getSlice()->getElementType(), getTensor(I->getBatch()),
        getTensor(I->getSlice()), getTensor(I->getDest()));
    return;
  }
  dispatchFloatingPointImpl(fwdBatchedAddInstFloatImpl,
                            I->getBatch()->getElementType(), I);
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdBatchedReduceAddInstFloatImpl(
    Value *batch, Value *dest, unsigned_t axis, const ShapeVector &eBatchDims,
    const ShapeVector &eDestDims) {
  staticAssertFloatingPointType(ElemTy);

  // Get unowned handles of the batch and dest with these new expanded dims.
  auto eBatch = getTensor(batch)->getUnowned(eBatchDims);
  auto eDest = getTensor(dest)->getUnowned(eDestDims);
  auto eBatchH = eBatch.getHandle<ElemTy>();
  auto eDestH = eDest.getHandle<ElemTy>();
  eDestH.clear();

  // We can use this loop for all shapes. Use the same indices for both the
  // batch and dest, except for setting the axis index in the dest to 0.
  for (dim_t x = 0; x < eBatchDims[0]; x++) {
    for (dim_t y = 0; y < eBatchDims[1]; y++) {
      for (dim_t z = 0; z < eBatchDims[2]; z++) {
        for (dim_t w = 0; w < eBatchDims[3]; w++) {
          for (dim_t q = 0; q < eBatchDims[4]; q++) {
            for (dim_t r = 0; r < eBatchDims[5]; r++) {
              dim_t destIndices[] = {x, y, z, w, q, r};
              destIndices[axis] = 0;
              eDestH.at(destIndices) =
                  eDestH.at(destIndices) + eBatchH.at({x, y, z, w, q, r});
            }
          }
        }
      }
    }
  }
}

void BoundInterpreterFunction::fwdBatchedReduceAddInst(
    const glow::BatchedReduceAddInst *I) {
  static_assert(max_tensor_dimensions == 6,
                "Loops below assume max_tensor_dimensions = 6.");

  auto *batch = I->getBatch();
  auto *dest = I->getDest();
  const auto axis = I->getAxis();

  // Initialize both expanded batch and dest dims to the expanded batch
  // dims. This allows us to iterate over the tensor regardless of its shape
  // using the max_tensor_dimensions loops below.
  ShapeVector eBatchDims = expandDimsToMax(batch->dims());
  ShapeVector eDestDims = eBatchDims;

  // Set the destination axis dimension (the one we are reducing) to 1.
  eDestDims[axis] = 1;

  if (getTensor(batch)->getType().isQuantizedType()) {
    auto destTy = dest->getType();
    auto batchTy = batch->getType();

    float destScale = destTy->getScale();
    float batchScale = batchTy->getScale();

    int32_t destOffset = destTy->getOffset();
    int32_t batchOffset = batchTy->getOffset();

    // Get unowned handles of the batch and dest with these new expanded dims.
    auto eBatch = getTensor(batch)->getUnowned(eBatchDims);
    auto eDest = getTensor(dest)->getUnowned(eDestDims);
    auto eBatchH = eBatch.getHandle<int8_t>();
    auto eDestH = eDest.getHandle<int8_t>();
    eDestH.clear();

    // For quantization, we must accumulate in the inner-most loop into a local
    // float and then clip the result back into the dest tensor. Here are the
    // max_tensor_dimensions cases for this, to ensure the axis is used as the
    // inner-most loop.
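    // For N elements along the reduced axis the accumulated value is
    //   sum_i (q_i - batchOffset) = (sum_i q_i) - N * batchOffset,
    // i.e. the offset must be subtracted once per element, which is why the
    // subtraction sits inside the inner-most loop of the macro below.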
    switch (axis) {
#define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5_AXIS)                     \
  case _D5_AXIS:                                                              \
    for (dim_t i##_D0 = 0; i##_D0 < eBatchDims[_D0]; i##_D0++)                \
      for (dim_t i##_D1 = 0; i##_D1 < eBatchDims[_D1]; i##_D1++)              \
        for (dim_t i##_D2 = 0; i##_D2 < eBatchDims[_D2]; i##_D2++)            \
          for (dim_t i##_D3 = 0; i##_D3 < eBatchDims[_D3]; i##_D3++)          \
            for (dim_t i##_D4 = 0; i##_D4 < eBatchDims[_D4]; i##_D4++) {      \
              float sum = 0.0;                                                \
              for (dim_t i##_D5_AXIS = 0; i##_D5_AXIS < eBatchDims[_D5_AXIS]; \
                   i##_D5_AXIS++) {                                           \
                sum += eBatchH.at({i0, i1, i2, i3, i4, i5}) - batchOffset;    \
              }                                                               \
              dim_t i##_D5_AXIS = 0;                                          \
              int32_t res =                                                   \
                  std::round(sum * batchScale / destScale) + destOffset;      \
              eDestH.at({i0, i1, i2, i3, i4, i5}) =                           \
                  quantization::clip<int32_t, int8_t>(res);                   \
            }                                                                 \
    return;

      // Each loop order, with the inner-most dimension/index equal to the
      // axis.
      LOOP_AXIS_CASE(1, 2, 3, 4, 5, 0);
      LOOP_AXIS_CASE(0, 2, 3, 4, 5, 1);
      LOOP_AXIS_CASE(0, 1, 3, 4, 5, 2);
      LOOP_AXIS_CASE(0, 1, 2, 4, 5, 3);
      LOOP_AXIS_CASE(0, 1, 2, 3, 5, 4);
      LOOP_AXIS_CASE(0, 1, 2, 3, 4, 5);
#undef LOOP_AXIS_CASE
    default:
      llvm_unreachable("Axis should be less than max_tensor_dimensions.");
    }
  }
  dispatchFloatingPointImpl(fwdBatchedReduceAddInstFloatImpl,
                            batch->getElementType(), batch, dest, axis,
                            eBatchDims, eDestDims);
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdBatchedReduceMinInstImpl(
    Value *batch, Value *dest, const ShapeVector &eBatchDims,
    const ShapeVector &eDestDims, ElemTy max) {
  static_assert(max_tensor_dimensions == 6,
                "Loops below assume max_tensor_dimensions = 6.");
  // Get unowned handles of the batch and dest with these new expanded dims.
  auto eBatch = getTensor(batch)->getUnowned(eBatchDims);
  auto eDest = getTensor(dest)->getUnowned(eDestDims);
  auto eBatchH = eBatch.getHandle<ElemTy>();
  auto eDestH = eDest.getHandle<ElemTy>();
  eDestH.clear(max);

  unsigned int axes[max_tensor_dimensions];
  for (dim_t i = 0; i < max_tensor_dimensions; i++) {
    axes[i] = (eDestDims[i] > 1);
  }
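
  // axes[i] is 1 for dimensions that survive the reduction and 0 for reduced
  // ones, so the dest indices (dx..dr) below advance only along surviving
  // dimensions and stay pinned at 0 along the reduced axes.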

  // We can use this loop for all shapes. Use the same indices for both the
  // batch and dest, except that along reduced axes the dest index stays at 0.
  for (dim_t x = 0, dx = 0; x < eBatchDims[0]; x++, dx += axes[0]) {
    for (dim_t y = 0, dy = 0; y < eBatchDims[1]; y++, dy += axes[1]) {
      for (dim_t z = 0, dz = 0; z < eBatchDims[2]; z++, dz += axes[2]) {
        for (dim_t w = 0, dw = 0; w < eBatchDims[3]; w++, dw += axes[3]) {
          for (dim_t q = 0, dq = 0; q < eBatchDims[4]; q++, dq += axes[4]) {
            for (dim_t r = 0, dr = 0; r < eBatchDims[5]; r++, dr += axes[5]) {
              dim_t destIndices[] = {dx, dy, dz, dw, dq, dr};
              dim_t srcIndices[] = {x, y, z, w, q, r};
              eDestH.at(destIndices) =
                  eDestH.at(destIndices) < eBatchH.at(srcIndices)
                      ? eDestH.at(destIndices)
                      : eBatchH.at(srcIndices);
            }
          }
        }
      }
    }
  }
}

void BoundInterpreterFunction::fwdBatchedReduceMinInst(
    const glow::BatchedReduceMinInst *I) {

  auto *batch = I->getBatch();
  auto *dest = I->getDest();
  const auto axes = I->getAxes();

  // Initialize both expanded batch and dest dims to the expanded batch
  // dims. This allows us to iterate over the tensor regardless of its shape
  // using the max_tensor_dimensions loops below.
  ShapeVector eBatchDims = expandDimsToMax(batch->dims());
  ShapeVector eDestDims = eBatchDims;
  // Set the destination axes dimensions (the ones we are reducing) to 1.
  for (dim_t i = 0; i < axes.size(); i++) {
    eDestDims[axes[i]] = 1;
  }

  dispatchArithmeticImpl(fwdBatchedReduceMinInstImpl, batch->getElementType(),
                         batch, dest, eBatchDims, eDestDims,
                         std::numeric_limits<int32_t>::max());
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdCumSumInstImpl(Value *input, Value *dest,
                                                 bool exclusive, bool reverse) {
  auto *eInput = getTensor(input);
  auto *eDest = getTensor(dest);
  auto eInputH = eInput->getHandle<ElemTy>();
  auto eDestH = eDest->getHandle<ElemTy>();
  eDestH.clear();

  ElemTy accum = 0;

  sdim_t s = 0;
  sdim_t n = eDestH.size();
  sdim_t dir = 1;

  if (reverse) {
    s = n - 1;
    n = -1;
    dir = -1;
  }
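
  // Worked example: for input [1, 2, 3] the inclusive scan yields [1, 3, 6];
  // exclusive yields [0, 1, 3]; reverse inclusive yields [6, 5, 3].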
  for (sdim_t i = s; i != n; i += dir) {
    if (!exclusive) {
      accum += eInputH.at(i);
    }
    eDestH.at(i) = accum;
    if (exclusive) {
      accum += eInputH.at(i);
    }
  }
}

void BoundInterpreterFunction::fwdCumSumInst(const glow::CumSumInst *I) {
  dispatchArithmeticImpl(fwdCumSumInstImpl, I->getInput()->getElementType(),
                         I->getInput(), I->getDest(), I->getExclusive(),
                         I->getReverse());
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdLengthsSumInstFloatImpl(
    const LengthsSumInst *I) {
  staticAssertFloatingPointType(ElemTy);

  auto out = getTensor(I->getDest());
  auto data = getTensor(I->getData());
  auto lengths = getTensor(I->getLengths());

  out->zero();

  auto LH = lengths->getHandle<int32_t>();

  size_t segments = lengths->dims()[0];
  size_t sliceSize = data->size() / data->dims()[0];

  auto DH = data->getHandle<ElemTy>();
  auto OH = out->getHandle<ElemTy>();

  size_t offsetIn = 0;
  size_t offsetOut = 0;
  for (dim_t i = 0; i < segments; i++) {
    for (int32_t j = 0, e = LH.raw(i); j < e; j++) {
      for (dim_t k = 0; k < sliceSize; k++) {
        OH.raw(offsetOut + k) += DH.raw(offsetIn + k);
      }
      offsetIn += sliceSize;
    }
    offsetOut += sliceSize;
  }

  assert(offsetIn == data->size() && "All values in Data should be consumed");
  assert(offsetOut == out->size() && "All values in Dest should be written to");
}

void BoundInterpreterFunction::fwdLengthsSumInst(const LengthsSumInst *I) {
  dispatchFloatingPointImpl(fwdLengthsSumInstFloatImpl,
                            I->getData()->getElementType(), I);
}

template <typename TI>
void BoundInterpreterFunction::fwdSparseLengthsSumInstI8Impl(
    const SparseLengthsSumInst *I) {

  auto out = getTensor(I->getDest());
  auto data = getTensor(I->getData());
  auto indices = getTensor(I->getIndices());
  auto lengths = getTensor(I->getLengths());

  out->zero();

  auto IH = indices->getHandle<TI>();
  auto LH = lengths->getHandle<int32_t>();

  size_t segments = lengths->dims()[0];
  size_t totalLength = 0;
  for (size_t i = 0; i < segments; i++) {
    totalLength += LH.raw(i);
  }
  assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");

  size_t lineSize = data->size() / data->dims()[0];

  auto DH = data->getHandle<int8_t>();
  auto OH = out->getHandle<int8_t>();

  auto TQP = [](Tensor *T) {
    return TensorQuantizationParams{T->getType().getScale(),
                                    T->getType().getOffset()};
  };

  size_t curIdx = 0;
  for (size_t i = 0; i < segments; i++) {
    std::vector<float> accum(lineSize, 0.0f);
    for (int32_t j = 0; j < LH.raw(i); j++) {
      size_t offsetIn = IH.raw(curIdx) * lineSize;
      for (size_t k = 0; k < lineSize; k++) {
        accum[k] += quantization::dequantize(DH.raw(offsetIn++), TQP(data));
      }
      curIdx++;
    }
    size_t offsetOut = i * lineSize;
    for (size_t k = 0; k < lineSize; k++) {
      OH.raw(offsetOut++) = quantization::quantize(accum[k], TQP(out));
    }
  }
}

template <typename ElemTy, typename TI>
void BoundInterpreterFunction::fwdSparseLengthsSumInstFloatImpl(
    const SparseLengthsSumInst *I) {
  staticAssertFloatingPointType(ElemTy);

  auto out = getTensor(I->getDest());
  auto data = getTensor(I->getData());
  auto indices = getTensor(I->getIndices());
  auto lengths = getTensor(I->getLengths());

  out->zero();

  auto IH = indices->getHandle<TI>();
  auto LH = lengths->getHandle<int32_t>();

  size_t segments = lengths->dims()[0];
  size_t totalLength = 0;
  for (size_t i = 0; i < segments; i++) {
    totalLength += LH.raw(i);
  }
  assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");

  size_t lineSize = data->size() / data->dims()[0];

  auto DH = data->getHandle<ElemTy>();
  auto OH = out->getHandle<ElemTy>();

  size_t curIdx = 0;
  for (size_t i = 0; i < segments; i++) {
    for (size_t j = 0, e = LH.raw(i); j < e; j++) {
      size_t offsetIn = IH.raw(curIdx++) * lineSize;
      size_t offsetOut = i * lineSize;
      for (size_t k = 0; k < lineSize; k++)
        OH.raw(offsetOut++) += DH.raw(offsetIn++);
    }
  }
}

void BoundInterpreterFunction::fwdSparseLengthsSumInst(
    const SparseLengthsSumInst *I) {
  if (I->getDest()->getType()->isQuantizedType()) {
    dispatchIndexTypeImpl(fwdSparseLengthsSumInstI8Impl,
                          I->getIndices()->getElementType(), I);
    return;
  }
  dispatchFloatingPointAndIndexImpl(fwdSparseLengthsSumInstFloatImpl,
                                    I->getData()->getElementType(),
                                    I->getIndices()->getElementType(), I);
}

template <typename ElemTy, typename TI>
void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInstFloatImpl(
    const SparseLengthsWeightedSumInst *I) {
  staticAssertFloatingPointType(ElemTy);

  auto out = getTensor(I->getDest());
  auto data = getTensor(I->getData());
  auto weights = getTensor(I->getWeights());
  auto indices = getTensor(I->getIndices());
  auto lengths = getTensor(I->getLengths());

  out->zero();

  auto IH = indices->getHandle<TI>();
  auto LH = lengths->getHandle<int32_t>();

  size_t segments = lengths->dims()[0];
  size_t totalLength = 0;
  for (dim_t i = 0; i < segments; i++) {
    totalLength += LH.raw(i);
  }
  assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");

  dim_t lineSize = data->size() / data->dims()[0];

  auto DH = data->getHandle<ElemTy>();
  auto WH = weights->getHandle<ElemTy>();
  auto OH = out->getHandle<ElemTy>();

  dim_t curIdx = 0;
  for (dim_t i = 0; i < segments; i++) {
    for (dim_t j = 0, e = LH.raw(i); j < e; j++) {
      ElemTy weight = WH.raw(curIdx);
      size_t offsetIn = IH.raw(curIdx++) * lineSize;
      size_t offsetOut = i * lineSize;
      for (dim_t k = 0; k < lineSize; k++)
        OH.raw(offsetOut++) += DH.raw(offsetIn++) * weight;
    }
  }
}

template <typename TI>
void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInstI8Impl(
    const SparseLengthsWeightedSumInst *I) {

  auto out = getTensor(I->getDest());
  auto data = getTensor(I->getData());
  auto weights = getTensor(I->getWeights());
  auto indices = getTensor(I->getIndices());
  auto lengths = getTensor(I->getLengths());

  out->zero();

  auto IH = indices->getHandle<TI>();
  auto LH = lengths->getHandle<int32_t>();

  dim_t segments = lengths->dims()[0];
  dim_t totalLength = 0;
  for (dim_t i = 0; i < segments; i++) {
    totalLength += LH.raw(i);
  }
  assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");

  dim_t lineSize = data->size() / data->dims()[0];

  auto DH = data->getHandle<int8_t>();
  auto WH = weights->getHandle<int8_t>();
  auto OH = out->getHandle<int8_t>();

  auto TQP = [](Tensor *T) {
    return TensorQuantizationParams{T->getType().getScale(),
                                    T->getType().getOffset()};
  };
  using namespace quantization;

  dim_t curIdx = 0;
  for (dim_t i = 0; i < segments; i++) {
    std::vector<float> accum(lineSize, 0.0f);
    for (int32_t j = 0; j < LH.raw(i); j++) {
      float weight = dequantize(WH.raw(curIdx), TQP(weights));
      size_t offsetIn = IH.raw(curIdx) * lineSize;
      for (dim_t k = 0; k < lineSize; k++) {
        accum[k] += weight * dequantize(DH.raw(offsetIn++), TQP(data));
      }
      curIdx++;
    }
    dim_t offsetOut = i * lineSize;
    for (dim_t k = 0; k < lineSize; k++) {
      OH.raw(offsetOut++) = quantize(accum[k], TQP(out));
    }
  }
}

void BoundInterpreterFunction::fwdSparseLengthsSumGradInst(
    const SparseLengthsSumGradInst * /*I*/) {
  DCHECK(!"Found SparseLengthsSumGradInst but SparseLengthsSum is lowered on "
          "Interpreter");
}

void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInst(
    const SparseLengthsWeightedSumInst *I) {
  if (I->getDest()->getType()->isQuantizedType()) {
    dispatchIndexTypeImpl(fwdSparseLengthsWeightedSumInstI8Impl,
                          I->getIndices()->getElementType(), I);
    return;
  }
  dispatchFloatingPointAndIndexImpl(fwdSparseLengthsWeightedSumInstFloatImpl,
                                    I->getData()->getElementType(),
                                    I->getIndices()->getElementType(), I);
}

void BoundInterpreterFunction::fwdSparseLengthsWeightedSumGradInst(
    const SparseLengthsWeightedSumGradInst *I) {
  assert(I->getDataGrad()->getType()->getElementType() == ElemKind::FloatTy &&
         "Input type must be float");

  auto destGrad = getTensor(I->getDestGrad());
  auto data = getTensor(I->getData());
  auto dataGrad = getTensor(I->getDataGrad());
  auto weightsGrad = getTensor(I->getWeightsGrad());
  auto weights = getTensor(I->getWeights());
  auto indices = getTensor(I->getIndices());
  auto lengths = getTensor(I->getLengths());

  // The data gradients not touched by this operation should
  // be 0, so set the entire buffer to 0 to start with.
  dataGrad->zero();

  auto LH = lengths->getHandle<int32_t>();
  auto IH = indices->getHandle<int64_t>();

  size_t segments = lengths->dims()[0];
  size_t totalLength = 0;
  for (size_t i = 0; i < segments; ++i) {
    totalLength += LH.raw(i);
  }
  assert(totalLength == indices->dims()[0] &&
         "sum(Lengths) must be equal to len(Indices)");

  size_t lineSize = dataGrad->size() / dataGrad->dims()[0];

  auto IGH = destGrad->getHandle();
  auto WH = weights->getHandle();
  auto WGH = weightsGrad->getHandle();
  auto DH = data->getHandle();
  auto OGH = dataGrad->getHandle();

  // For each index in each segment:
  //    1) accumulate into the corresponding data gradient the product of the
  //    gradient of the result it was added to and the weight that it was
  //    multiplied by during the SparseLengthsWeightedSum operation.
  //
  //    2) accumulate into each weight gradient the reduced sum of the
  //    elementwise product of the result slice that the corresponding weight
  //    produced and the input slice that the weight was multiplied with.
  for (size_t i = 0, curIdx = 0; i < segments; ++i) {
    size_t destOffset = i * lineSize;
    for (size_t j = 0, e = LH.raw(i); j < e; ++j, ++curIdx) {
      float weightGrad = 0.0f;
      float weight = WH.raw(curIdx);
      size_t dataOffset = IH.raw(curIdx) * lineSize;

      for (size_t k = 0; k < lineSize; ++k) {
        OGH.raw(dataOffset + k) += IGH.raw(destOffset + k) * weight;
        weightGrad += IGH.raw(destOffset + k) * DH.raw(dataOffset + k);
      }

      WGH.raw(curIdx) = weightGrad;
    }
  }
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdEmbeddingBagInstFloatImpl(
    const EmbeddingBagInst *I) {
  staticAssertFloatingPointType(ElemTy);

  auto out = getTensor(I->getDest());
  auto data = getTensor(I->getData());
  auto weights = getTensor(I->getWeights());
  auto indices = getTensor(I->getIndices());
  auto offsets = getTensor(I->getOffsets());
  bool hasEndOffset = I->getHasEndOffset();

  out->zero();

  auto IH = indices->getHandle<int64_t>();
  auto OFFH = offsets->getHandle<int64_t>();

  // If an end offset is present to mark the end of the last segment, then it
  // must be subtracted to get the correct number of segments.
  size_t segments = hasEndOffset ? offsets->dims()[0] - 1 : offsets->dims()[0];
  size_t numIndices = indices->dims()[0];
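
  // For example, with 6 indices and offsets [0, 2, 5]: without an end offset
  // there are 3 segments covering indices [0,2), [2,5) and [5,6); with an end
  // offset the trailing entry marks the end, giving 2 segments [0,2), [2,5).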

  size_t lineSize = data->size() / data->dims()[0];

  auto DH = data->getHandle<ElemTy>();
  auto WH = weights->getHandle<ElemTy>();
  auto OH = out->getHandle<ElemTy>();

  dim_t curIdx = 0;
  for (dim_t i = 0; i < segments; i++) {
    dim_t start = OFFH.raw(i);
    dim_t end;
    if (!hasEndOffset) {
      // Note that in this case we have to use numIndices to find the end of
      // the last segment. This is an issue though because it relies on knowing
      // the total length of the indices tensor which may not be possible.
      // Future implementations of this operator should always give an end
      // offset so eventually this case should be removed.
      end = i == segments - 1 ? numIndices : OFFH.raw(i + 1);
    } else {
      end = OFFH.raw(i + 1);
    }
    if (start == end) {
      continue;
    } else if (start > end) {
      break;
    }
    for (dim_t j = start; j < end; j++) {
      ElemTy weight = WH.raw(curIdx);
      dim_t offsetIn = IH.raw(curIdx++) * lineSize;
      dim_t offsetOut = i * lineSize;
      for (dim_t k = 0; k < lineSize; k++) {
        OH.raw(offsetOut++) += DH.raw(offsetIn++) * weight;
      }
    }
  }
}

void BoundInterpreterFunction::fwdEmbeddingBagInst(const EmbeddingBagInst *I) {
  dispatchFloatingPointImpl(fwdEmbeddingBagInstFloatImpl,
                            I->getData()->getElementType(), I);
}

template <typename T, typename AccumT, typename TI>
void BoundInterpreterFunction::fwdRowwiseQuantizedSparseLengthsWeightedSumImpl(
    const RowwiseQuantizedSparseLengthsWeightedSumInst *I) {
  auto *out = getTensor(I->getDest());
  auto *data = getTensor(I->getData());
  auto *dataScales = getTensor(I->getScales());
  auto *dataOffsets = getTensor(I->getOffsets());
  auto *weights = getTensor(I->getWeights());
  auto *indices = getTensor(I->getIndices());
  auto *lengths = getTensor(I->getLengths());

  out->zero();

  auto IH = indices->getHandle<TI>();
  auto LH = lengths->getHandle<int32_t>();

  dim_t segments = lengths->dims()[0];
  dim_t totalLength = 0;
  for (dim_t i = 0; i < segments; i++) {
    totalLength += LH.raw(i);
  }
  assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");

  dim_t lineSize = data->size() / data->dims()[0];

  auto DH = data->getHandle<uint8_t>();
  auto DSH = dataScales->getHandle<T>();
  auto DOH = dataOffsets->getHandle<T>();
  auto WH = weights->getHandle<T>();
  auto OH = out->getHandle<T>();

  dim_t curIdx = 0;
  for (dim_t i = 0; i < segments; i++) {
    std::vector<AccumT> accum(lineSize, 0.0f);
    for (dim_t j = 0, e = LH.raw(i); j < e; j++) {
      const float weight = static_cast<float>(WH.raw(curIdx));
      const dim_t rowIdx = IH.raw(curIdx++);
      const float scale = static_cast<float>(DSH.at({rowIdx}));
      const float offset = static_cast<float>(DOH.at({rowIdx}));
      size_t offsetIn = rowIdx * lineSize;
      for (dim_t k = 0; k < lineSize; k++) {
        float d = quantization::dequantizeWithFloatOffset(DH.raw(offsetIn++),
                                                          scale, offset);
        accum[k] += d * weight;
      }
    }
    // Accumulation complete; copy back to the output with a cast to T.
    size_t offsetOut = i * lineSize;
    for (size_t k = 0; k < lineSize; k++) {
      OH.raw(offsetOut++) = static_cast<T>(accum[k]);
    }
  }
}

void BoundInterpreterFunction::fwdRowwiseQuantizedSparseLengthsWeightedSumInst(
    const RowwiseQuantizedSparseLengthsWeightedSumInst *I) {
  const auto ity = I->getIndices()->getElementType();
  switch (I->getDest()->getElementType()) {
  case ElemKind::FloatTy:
    if (ity == ElemKind::Int32ITy) {
      fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float, int32_t>(I);
    } else if (ity == ElemKind::Int64ITy) {
      fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float, int64_t>(I);
    } else {
      llvm_unreachable("Index type is not supported");
    }
    break;
  case ElemKind::Float16Ty:
    if (I->getUseFP16Accumulation()) {
      if (ity == ElemKind::Int32ITy) {
        fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float16_t,
                                                        int32_t>(I);
      } else if (ity == ElemKind::Int64ITy) {
        fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float16_t,
                                                        int64_t>(I);
      } else {
        llvm_unreachable("Index type is not supported");
      }
    } else {
      if (ity == ElemKind::Int32ITy) {
        fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float,
                                                        int32_t>(I);
      } else if (ity == ElemKind::Int64ITy) {
        fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float,
                                                        int64_t>(I);
      } else {
        llvm_unreachable("Index type is not supported");
      }
    }
    break;
  default:
    llvm_unreachable("Type is not supported");
  }
}

template <typename T, typename AccumT, typename TI>
void BoundInterpreterFunction::
    fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl(
        const FusedRowwiseQuantizedSparseLengthsWeightedSumInst *I) {
  auto *out = getTensor(I->getDest());
  auto *data = getTensor(I->getData());
  auto *weights = getTensor(I->getWeights());
  auto *indices = getTensor(I->getIndices());
  auto *lengths = getTensor(I->getLengths());

  out->zero();

  auto IH = indices->getHandle<TI>();
  auto LH = lengths->getHandle<int32_t>();

  size_t segments = lengths->dims()[0];
  size_t totalLength = 0;
  for (size_t i = 0; i < segments; i++) {
    totalLength += LH.raw(i);
  }
  assert(totalLength <= indices->dims()[0] &&
         "sum(Lengths) must not exceed len(Indices)");

  const bool using4BitQuantization =
      data->getType().getElementType() == ElemKind::UInt4FusedFP16QTy;
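
  // In the 4-bit fused layout two output elements share one byte: element k
  // lives in the low nibble of byte k/2 when k is even and in the high nibble
  // (isMSB) when k is odd, hence the k/2 indexing in the loop below.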

  const size_t outLineSize = out->size() / out->dims()[0];

  auto DH = data->getHandle<uint8_t>();
  auto WH = weights->getHandle<T>();
  auto OH = out->getHandle<T>();

  dim_t curIdx = 0;
  for (dim_t i = 0; i < segments; i++) {
    std::vector<AccumT> accum(outLineSize, 0.0f);
    for (dim_t j = 0, e = LH.raw(i); j < e; j++) {
      const float weight = static_cast<float>(WH.raw(curIdx));
      const dim_t rowIdx = IH.raw(curIdx++);
      // The data type of the scale and offset for fused types need not match
      // the output tensor's type T.
      float scale, offset;
      switch (
          getScaleOffsetElemKindFromFused(data->getType().getElementType())) {
      case ElemKind::FloatTy:
        std::tie(scale, offset) = DH.getFusedScaleOffsetFromRow<float>(rowIdx);
        break;
      case ElemKind::Float16Ty:
        std::tie(scale, offset) =
            DH.getFusedScaleOffsetFromRow<float16_t>(rowIdx);
        break;
      default:
        llvm_unreachable("Type is not supported");
        break;
      }

      for (dim_t k = 0; k < outLineSize; k++) {
        float d = 0.0f;
        if (!using4BitQuantization) {
          d = quantization::dequantizeWithFloatOffset(
              DH.at({rowIdx, k}), static_cast<float>(scale),
              static_cast<float>(offset));
        } else {
          const bool isMSB = (k % 2 == 1);
          d = quantization::dequantize4BitWithFloatOffset(
              DH.at({rowIdx, k / 2}), static_cast<float>(scale),
              static_cast<float>(offset), isMSB);
        }
        accum[k] += d * weight;
      }
    }
    // Accumulation complete; copy back to the output with a cast to T.
    dim_t offsetOut = i * outLineSize;
    for (dim_t k = 0; k < outLineSize; k++) {
      OH.raw(offsetOut++) = static_cast<T>(accum[k]);
    }
  }
}

void BoundInterpreterFunction::
    fwdFusedRowwiseQuantizedSparseLengthsWeightedSumInst(
        const FusedRowwiseQuantizedSparseLengthsWeightedSumInst *I) {
  const auto ity = I->getIndices()->getElementType();
  switch (I->getDest()->getElementType()) {
  case ElemKind::FloatTy:
    if (ity == ElemKind::Int32ITy) {
      fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float,
                                                           int32_t>(I);
    } else if (ity == ElemKind::Int64ITy) {
      fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float,
                                                           int64_t>(I);
    } else {
      llvm_unreachable("Index type is not supported");
    }
    break;
  case ElemKind::Float16Ty:
    if (I->getUseFP16Accumulation()) {
      if (ity == ElemKind::Int32ITy) {
        fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<
            float16_t, float16_t, int32_t>(I);
      } else if (ity == ElemKind::Int64ITy) {
        fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<
            float16_t, float16_t, int64_t>(I);
      } else {
        llvm_unreachable("Index type is not supported");
      }
    } else {
      if (ity == ElemKind::Int32ITy) {
        fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float,
                                                             int32_t>(I);
      } else if (ity == ElemKind::Int64ITy) {
        fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float,
                                                             int64_t>(I);
      } else {
        llvm_unreachable("Index type is not supported");
      }
    }
    break;
  default:
    llvm_unreachable("Type is not supported");
  }
}

template <typename T, typename AccumT>
void BoundInterpreterFunction::fwdEmbeddingBagByteRowwiseOffsetsImpl(
    const EmbeddingBagByteRowwiseOffsetsInst *I) {
  auto *out = getTensor(I->getDest());
  auto *data = getTensor(I->getData());
  auto *weights = getTensor(I->getWeights());
  auto *indices = getTensor(I->getIndices());
  auto *offsets = getTensor(I->getOffsets());
  bool hasEndOffset = I->getHasEndOffset();

  out->zero();

  auto IH = indices->getHandle<int64_t>();
  auto OFFH = offsets->getHandle<int64_t>();

  // If an end offset is present to mark the end of the last segment, then it
  // must be subtracted to get the correct number of segments.
  size_t segments = hasEndOffset ? offsets->dims()[0] - 1 : offsets->dims()[0];
  dim_t numIndices = indices->dims()[0];

  const bool using4BitQuantization =
      data->getType().getElementType() == ElemKind::UInt4FusedFP16QTy;

  const size_t outLineSize = out->size() / out->dims()[0];

  auto DH = data->getHandle<uint8_t>();
  auto WH = weights->getHandle<T>();
  auto OH = out->getHandle<T>();

  for (dim_t i = 0; i < segments; i++) {
    std::vector<AccumT> accum(outLineSize, 0.0f);
    size_t start = OFFH.raw(i);
    dim_t end;
    if (!hasEndOffset) {
      // Note that in this case we have to use numIndices to find the end of
      // the last segment. This is an issue though because it relies on knowing
      // the total length of the indices tensor which may not be possible.
      // Future implementations of this operator should always give an end
      // offset so eventually this case should be removed.
      end = i == segments - 1 ? numIndices : OFFH.raw(i + 1);
    } else {
      end = OFFH.raw(i + 1);
    }
    if (start == end) {
      continue;
    } else if (start > end) {
      break;
    }

    for (dim_t j = start; j < end; j++) {
      const float weight = static_cast<float>(WH.raw(j));
      const dim_t rowIdx = IH.raw(j);
      T scale, offset;
      std::tie(scale, offset) = DH.getFusedScaleOffsetFromRow<T>(rowIdx);
      for (dim_t k = 0; k < outLineSize; k++) {
        float d = 0.0f;
        if (!using4BitQuantization) {
          d = quantization::dequantizeWithFloatOffset(
              DH.at({rowIdx, k}), static_cast<float>(scale),
              static_cast<float>(offset));
        } else {
          const bool isMSB = (k % 2 == 1);
          d = quantization::dequantize4BitWithFloatOffset(
              DH.at({rowIdx, k / 2}), static_cast<float>(scale),
              static_cast<float>(offset), isMSB);
        }
        accum[k] += d * weight;
      }
    }
    // Accumulation complete; copy back to the output with a cast to T.
    dim_t offsetOut = i * outLineSize;
    for (dim_t k = 0; k < outLineSize; k++) {
      OH.raw(offsetOut++) = static_cast<T>(accum[k]);
    }
  }
}

void BoundInterpreterFunction::fwdEmbeddingBagByteRowwiseOffsetsInst(
    const EmbeddingBagByteRowwiseOffsetsInst *I) {
  switch (I->getDest()->getElementType()) {
  case ElemKind::FloatTy:
    fwdEmbeddingBagByteRowwiseOffsetsImpl<float, float>(I);
    break;
  case ElemKind::Float16Ty:
    if (I->getUseFP16Accumulation()) {
      fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float16_t>(I);
    } else {
      fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float>(I);
    }
    break;
  default:
    llvm_unreachable("Type is not supported");
  }
}

void BoundInterpreterFunction::fwdLengthsToRangesInst(
    const LengthsToRangesInst *I) {
  auto ranges = getTensor(I->getDest())->getHandle<int32_t>();
  auto lengths = getTensor(I->getLengths())->getHandle<int32_t>();
  int32_t offset = 0;
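  // For example, lengths [2, 3, 1] produce ranges [[0, 2], [2, 3], [5, 1]];
  // each row holds (start offset, length).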
  for (dim_t i = 0; i < lengths.dims()[0]; i++) {
    auto length = lengths.at({i});
    ranges.at({i, 0}) = offset;
    ranges.at({i, 1}) = length;
    offset += length;
  }
}

void BoundInterpreterFunction::fwdLengthsRangeFillInst(
    const LengthsRangeFillInst *I) {
  auto lengthsH = getTensor(I->getLengths())->getHandle<int32_t>();
  auto resultH = getTensor(I->getDest())->getHandle<int32_t>();
  dim_t curIdx = 0;
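  // For example, lengths [2, 3] fill the result with [0, 1, 0, 1, 2].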
4507 for (dim_t i = 0, e = lengthsH.dims()[0]; i < e; i++) {
4508 for (int32_t j = 0, f = lengthsH.at({i}); j < f; j++) {
4509 resultH.at({curIdx++}) = j;
4510 }
4511 }
4512 }

template <typename ElemTy>
void BoundInterpreterFunction::fwdSparseToDenseInstImpl(
    const SparseToDenseInst *I) {

  auto out = getTensor(I->getDest());
  auto indices = getTensor(I->getIndices());
  auto values = getTensor(I->getValues());

  out->zero();

  auto IH = indices->getHandle<int64_t>();

  size_t numIndices = indices->dims()[0];
  size_t numOutDims = out->dims().size();

  // Convert sparse representation to dense representation by taking
  // slices of output and values and accumulating the value slice into
  // the output slice.

  // Dimensions and offsets for the output and values slices. sliceDims
  // will always be {1, [rest of output dimensions]} since the first dimension
  // is the index in this operation. sliceOffsets will be {indices[j], 0, ...}
  // for the output slice and {j, 0, ...} for the values slice so that the
  // slice at index j gets mapped to index indices[j] in the dense
  // representation.
  ShapeVector sliceDims(out->dims().begin(), out->dims().end());
  ShapeVector sliceOffsets(numOutDims, 0);
  sliceDims[0] = 1;

  for (dim_t j = 0; j < numIndices; ++j) {
    // Create values slice with offsets {j, 0, ...}.
    sliceOffsets[0] = j;
    auto VS = values->getUnowned(sliceDims, sliceOffsets);
    auto VSH = VS.getHandle<ElemTy>();

    // Create output slice with offsets {indices[j], 0, ...}.
    sliceOffsets[0] = IH.at({j});
    auto OS = out->getUnowned(sliceDims, sliceOffsets);
    auto OSH = OS.getHandle<ElemTy>();

    // Accumulate values slice into output slice.
    size_t outputSliceSize = OS.size();
    for (size_t k = 0; k < outputSliceSize; ++k) {
      OSH.raw(k) += VSH.raw(k);
    }
  }
}
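
// Worked example for SparseToDense (illustrative only): with output shape
// {3, 2}, indices = [1, 0, 1], and value rows v0, v1, v2:
//   out[0] = v1
//   out[1] = v0 + v2   // duplicate indices accumulate
//   out[2] = 0         // untouched rows stay zeroed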

void BoundInterpreterFunction::fwdSparseToDenseInst(
    const SparseToDenseInst *I) {
  dispatchArithmeticImpl(fwdSparseToDenseInstImpl,
                         I->getDest()->getElementType(), I);
}

void BoundInterpreterFunction::fwdSparseToDenseMaskInst(
    const SparseToDenseMaskInst *I) {
  auto out = getTensor(I->getDest());
  auto values = getTensor(I->getValues());
  auto defaultValue = getTensor(I->getDefaultValue());

  auto indicesH = getTensor(I->getIndices())->getHandle<int64_t>();
  auto lengthsH = getTensor(I->getLengths())->getHandle<int32_t>();

  const std::vector<dim_t> &mask = I->getMask();
  size_t maskSize = mask.size();
  // Create a reverse map from ID to its position in the mask.
  std::unordered_map<int64_t, size_t> reverseMap;
  for (size_t i = 0; i < maskSize; i++) {
    assert(reverseMap.find(mask[i]) == reverseMap.end() &&
           "duplicate IDs in the mask");
    reverseMap[mask[i]] = i;
  }

  auto valueSize = defaultValue->getSizeInBytes();

  // First un-processed index-value pair.
  size_t posIn = 0;
  // Beginning of output block for first unprocessed batch.
  size_t byteOffsetOut = 0;
  // Lengths can be scalar, which means that all pairs belong to one batch.
  size_t numBatches = lengthsH.dims().empty() ? 1 : lengthsH.dims()[0];
  for (size_t batch = 0; batch < numBatches; batch++) {
    // Fill everything with maskSize copies of defaultValue.
    for (size_t i = 0; i < maskSize; i++) {
      std::copy(defaultValue->getUnsafePtr(),
                &defaultValue->getUnsafePtr()[valueSize],
                &out->getUnsafePtr()[byteOffsetOut + valueSize * i]);
    }
    // Go through input pairs and find matches.
    for (size_t i = 0, batchLen = lengthsH.raw(batch); i < batchLen;
         i++, posIn++) {
      int64_t idx = indicesH.raw(posIn);
      auto it = reverseMap.find(idx);
      // Skip if ID is not present in the mask.
      if (it == reverseMap.end())
        continue;
      size_t to = it->second;

      std::copy(&values->getUnsafePtr()[posIn * valueSize],
                &values->getUnsafePtr()[(posIn + 1) * valueSize],
                &out->getUnsafePtr()[byteOffsetOut + valueSize * to]);
    }

    byteOffsetOut += maskSize * valueSize;
  }

  assert(posIn == indicesH.dims()[0] &&
         "Sum of Lengths must be equal to size of indices.");
}
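
// Worked example for SparseToDenseMask (illustrative only): with
// mask = [10, 20], lengths = [2], indices = [20, 5], and value rows v0, v1:
//   out = [defaultValue, v0]
// Index 20 sits at mask position 1, so v0 lands there; index 5 is not in the
// mask and is skipped, leaving position 0 at its default fill.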

//===----------------------------------------------------------------------===//
//                       Instructions used by RNN
//===----------------------------------------------------------------------===//
template <typename T, typename TI>
static void fwdTopK(Tensor *outW, Tensor *indW, Tensor *inW, size_t k) {
  auto values = outW->getHandle<T>();
  auto indices = indW->getHandle<TI>();
  auto in = inW->getHandle<T>();
  size_t n = in.dims().back();

  size_t in_p = 0, out_p = 0;
  size_t tensor_end = in.size();
  using pairType = std::pair<float, size_t>;
  std::vector<pairType> buf(n);

  while (in_p < tensor_end) {
    for (size_t i = 0; i < n; i++) {
      buf[i].first = in.raw(in_p++);
      buf[i].second = i;
    }
    // NOTE: this could be done in O(N + K log K) with a partial sort; this
    // version is O(N log N).
    std::sort(buf.begin(), buf.end(), [](const pairType &a, const pairType &b) {
      if (a.first != b.first)
        return a.first > b.first;
      return a.second < b.second;
    });
    for (size_t i = 0; i < k; i++) {
      values.raw(out_p) = buf[i].first;
      indices.raw(out_p) = buf[i].second;
      out_p++;
    }
  }
}
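
// Worked example for fwdTopK (illustrative only): for one row
//   in = [1, 3, 2, 4], k = 2
// sorting descending (ties broken by lower index) yields
//   values = [4, 3], indices = [3, 1]
// The while loop repeats this for every slice of the innermost dimension.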

template <typename inpType, typename outType>
static void fwdArgMax(Tensor *inpT, Tensor *outT, size_t axis) {

  // Get input/output handles with dimensions expanded to maximum.
  ShapeVector inpDims = expandDimsToMax(inpT->dims());
  ShapeVector outDims = inpDims;
  outDims[axis] = 1;
  auto eInpT = inpT->getUnowned(inpDims);
  auto eOutT = outT->getUnowned(outDims);
  auto inpH = eInpT.getHandle<inpType>();
  auto outH = eOutT.getHandle<outType>();

  static_assert(max_tensor_dimensions == 6,
                "Loops below assume max_tensor_dimensions = 6.");

  for (dim_t idx0 = 0; idx0 < outDims[0]; idx0++) {
    for (dim_t idx1 = 0; idx1 < outDims[1]; idx1++) {
      for (dim_t idx2 = 0; idx2 < outDims[2]; idx2++) {
        for (dim_t idx3 = 0; idx3 < outDims[3]; idx3++) {
          for (dim_t idx4 = 0; idx4 < outDims[4]; idx4++) {
            for (dim_t idx5 = 0; idx5 < outDims[5]; idx5++) {

              // Initialize maximum value/index.
              inpType maxVal = std::numeric_limits<inpType>::lowest();
              outType maxIdx = 0;

              // Iterate input axis dimension.
              for (dim_t axisIdx = 0; axisIdx < inpDims[axis]; axisIdx++) {
                std::vector<dim_t> inpIdx = {idx0, idx1, idx2,
                                             idx3, idx4, idx5};
                inpIdx[axis] = axisIdx;
                inpType inpVal = inpH.at(inpIdx);
                if (inpVal > maxVal) {
                  maxVal = inpVal;
                  maxIdx = axisIdx;
                }
              }

              // Store maximum index.
              outH.at({idx0, idx1, idx2, idx3, idx4, idx5}) = maxIdx;
            }
          }
        }
      }
    }
  }
}
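
// Worked example for fwdArgMax (illustrative only): for a {2, 3} input and
// axis = 1,
//   inp = [[1, 5, 2],
//          [7, 3, 9]]
// the axis dimension collapses to 1 and the stored indices are
//   out = [[1],
//          [2]]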

template <typename inpType, typename outType>
static void fwdArgMin(Tensor *inpT, Tensor *outT, size_t axis) {

  // Get input/output handles with dimensions expanded to maximum.
  ShapeVector inpDims = expandDimsToMax(inpT->dims());
  ShapeVector outDims = inpDims;
  outDims[axis] = 1;
  auto eInpT = inpT->getUnowned(inpDims);
  auto eOutT = outT->getUnowned(outDims);
  auto inpH = eInpT.getHandle<inpType>();
  auto outH = eOutT.getHandle<outType>();

  static_assert(max_tensor_dimensions == 6,
                "Loops below assume max_tensor_dimensions = 6.");

  for (dim_t idx0 = 0; idx0 < outDims[0]; idx0++) {
    for (dim_t idx1 = 0; idx1 < outDims[1]; idx1++) {
      for (dim_t idx2 = 0; idx2 < outDims[2]; idx2++) {
        for (dim_t idx3 = 0; idx3 < outDims[3]; idx3++) {
          for (dim_t idx4 = 0; idx4 < outDims[4]; idx4++) {
            for (dim_t idx5 = 0; idx5 < outDims[5]; idx5++) {

              // Initialize minimum value/index.
              inpType minVal = std::numeric_limits<inpType>::max();
              outType minIdx = 0;

              // Iterate input axis dimension.
              for (dim_t axisIdx = 0; axisIdx < inpDims[axis]; axisIdx++) {
                std::vector<dim_t> inpIdx = {idx0, idx1, idx2,
                                             idx3, idx4, idx5};
                inpIdx[axis] = axisIdx;
                inpType inpVal = inpH.at(inpIdx);
                if (inpVal < minVal) {
                  minVal = inpVal;
                  minIdx = axisIdx;
                }
              }

              // Store minimum index.
              outH.at({idx0, idx1, idx2, idx3, idx4, idx5}) = minIdx;
            }
          }
        }
      }
    }
  }
}

//===----------------------------------------------------------------------===//
//                       Sorting operators
//===----------------------------------------------------------------------===//

void BoundInterpreterFunction::fwdTopKInst(const TopKInst *I) {
  auto outW = getTensor(I->getValues());
  auto indW = getTensor(I->getIndices());
  auto inW = getTensor(I->getInput());
  size_t k = I->getK();

  if (inW->getType().isQuantizedType()) {
    if (indW->getElementType() == ElemKind::Int64ITy) {
      fwdTopK<int8_t, int64_t>(outW, indW, inW, k);
    } else if (indW->getElementType() == ElemKind::Int32ITy) {
      fwdTopK<int8_t, int32_t>(outW, indW, inW, k);
    }
    return;
  }

  dispatchFloatingPointAndIndexImpl(fwdTopK, inW->getElementType(),
                                    indW->getElementType(), outW, indW, inW, k);
}

#define DISPATCH_ARG_MIN_MAX(functionName, elemTy, elemTyIndex, ...)          \
  switch (elemTy) {                                                           \
  case ElemKind::FloatTy:                                                     \
    if (elemTyIndex == ElemKind::Int64ITy) {                                  \
      functionName<float, int64_t>(__VA_ARGS__);                              \
    } else if (elemTyIndex == ElemKind::Int32ITy) {                           \
      functionName<float, int32_t>(__VA_ARGS__);                              \
    }                                                                         \
    break;                                                                    \
  case ElemKind::Int8QTy:                                                     \
    if (elemTyIndex == ElemKind::Int64ITy) {                                  \
      functionName<int8_t, int64_t>(__VA_ARGS__);                             \
    } else if (elemTyIndex == ElemKind::Int32ITy) {                           \
      functionName<int8_t, int32_t>(__VA_ARGS__);                             \
    }                                                                         \
    break;                                                                    \
  default:                                                                    \
    llvm_unreachable("Type is not supported");                                \
  }

void BoundInterpreterFunction::fwdArgMaxInst(const ArgMaxInst *I) {
  auto inpT = getTensor(I->getSrc());
  auto outT = getTensor(I->getDest());
  size_t axis = I->getAxis();
  auto inpElemType = inpT->getElementType();
  auto outElemType = outT->getElementType();
  DISPATCH_ARG_MIN_MAX(fwdArgMax, inpElemType, outElemType, inpT, outT, axis);
}

void BoundInterpreterFunction::fwdArgMinInst(const ArgMinInst *I) {
  auto inpT = getTensor(I->getSrc());
  auto outT = getTensor(I->getDest());
  size_t axis = I->getAxis();
  auto inpElemType = inpT->getElementType();
  auto outElemType = outT->getElementType();
  DISPATCH_ARG_MIN_MAX(fwdArgMin, inpElemType, outElemType, inpT, outT, axis);
}
#undef DISPATCH_ARG_MIN_MAX

//===----------------------------------------------------------------------===//
//                       Tensor allocation operations
//===----------------------------------------------------------------------===//

void BoundInterpreterFunction::fwdAllocActivationInst(
    const AllocActivationInst *I) {
  getOrCreateTensor(I);
}

void BoundInterpreterFunction::fwdDeallocActivationInst(
    const DeallocActivationInst *I) {
  deleteTensor(I->getSrc());
}

//===----------------------------------------------------------------------===//
//                       Debug instructions
//===----------------------------------------------------------------------===//
/// Prints a value of the instruction's operand.
/// In most cases it will be the name of the variable and the value of the
/// tensor.
void BoundInterpreterFunction::fwdDebugPrintInst(const DebugPrintInst *I) {
  auto *V = I->getSrc();
  auto *T = getTensor(V);
  std::string format = I->getFormat();
  std::string filename = I->getFileName();

  if (format == "console") {
    // Dump tensor in console.
    llvm::outs() << I->getName() << ": ";
    V->dump();
    llvm::outs() << "\n";
    dumpImpl(T);
    llvm::outs() << "\n";
  } else if (format == "bin") {
    TensorSerializationOptions opts;
    opts.withType = true;
    glow::dumpTensorToBinaryFile(*T, filename, opts);
  } else if (format == "txt") {
    TensorSerializationOptions opts;
    opts.withType = true;
    glow::dumpTensorToTextFile(*T, filename, opts);
  } else if (format == "rawbin") {
    TensorSerializationOptions opts;
    opts.withType = false;
    glow::dumpTensorToBinaryFile(*T, filename, opts);
  } else if (format == "rawtxt") {
    TensorSerializationOptions opts;
    opts.withType = false;
    glow::dumpTensorToTextFile(*T, filename, opts);
  } else {
    llvm_unreachable("DebugPrint format not supported!");
  }
}

void BoundInterpreterFunction::fwdTraceEventInst(const TraceEventInst *I) {
  auto T = getTensor(I->getData());
  auto IH = T->getHandle<int64_t>();
  size_t index = I->getIndex();
  IH.raw(index) = std::chrono::duration_cast<std::chrono::microseconds>(
                      std::chrono::steady_clock::now().time_since_epoch())
                      .count();
}

//===----------------------------------------------------------------------===//
//                  Instructions used by Quantization
//===----------------------------------------------------------------------===//
void BoundInterpreterFunction::fwdQuantizationProfileInst(
    const glow::QuantizationProfileInst *I) {
  auto inputTensor = getWeightHandle(I->getInputTensor());
  auto currentHistogram = getWeightHandle(I->getHistogram());
  auto computationInfo = getWeightHandle(I->getComputationInfo());

  float &min = computationInfo.raw(0);
  float &max = computationInfo.raw(1);

  // Update current histogram, min and max based on the inputTensor data.
  quantization::generateTensorHistogram(inputTensor, currentHistogram, min,
                                        max);
}

/// Quantize floating point tensor. Scale and Offset are based on return type
/// of the instruction \p I.
void BoundInterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) {
  auto *srcTensor = getTensor(I->getSrc());
  auto *destTensor = getTensor(I->getDest());
  auto destTy = destTensor->getType();
  Tensor qTensor = quantization::quantizeTensor(
      *srcTensor, {destTy.getScale(), destTy.getOffset()},
      destTy.getElementType());
  destTensor->assign(&qTensor);
}

/// Dequantize integer tensor. Scale and Offset are based
/// on the source tensor type.
void BoundInterpreterFunction::fwdDequantizeInst(
    const glow::DequantizeInst *I) {
  auto *srcTensor = getTensor(I->getSrc());
  auto *destTensor = getTensor(I->getDest());
  auto destTy = destTensor->getType();
  Tensor fTensor =
      quantization::dequantizeTensor(*srcTensor, destTy.getElementType());
  destTensor->assign(&fTensor);
}

template <class eTy>
void BoundInterpreterFunction::fwdRescaleQuantizedInstImpl(
    Value *src, Value *dest, TensorQuantizationParams &srcQ,
    TensorQuantizationParams &destQ) {

  auto srcH = getWeightHandle<eTy>(src);
  auto destH = getWeightHandle<eTy>(dest);

  for (size_t i = 0, e = destH.size(); i < e; ++i) {
    float val = quantization::dequantize(srcH.raw(i), srcQ);
    destH.raw(i) = quantization::quantize(val, destQ);
  }
}
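
// Worked numeric example for the rescale above (illustrative; assumes the
// usual affine mapping value = scale * (q - offset)):
//   srcQ  = {scale = 0.1, offset = 0}
//   destQ = {scale = 0.2, offset = 10}
//   q = 50  ->  val = 0.1 * (50 - 0) = 5.0  ->  round(5.0 / 0.2) + 10 = 35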

void BoundInterpreterFunction::fwdRescaleQuantizedInst(
    const glow::RescaleQuantizedInst *I) {
  auto src = I->getSrc();
  auto dest = I->getDest();
  auto srcTy = src->getType();
  auto destTy = dest->getType();

  TensorQuantizationParams srcQ{srcTy->getScale(), srcTy->getOffset()};
  TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};

  dispatchQuantizedImpl(fwdRescaleQuantizedInstImpl, destTy->getElementType(),
                        src, dest, srcQ, destQ);
}

void BoundInterpreterFunction::fwdIntLookupTableInst(
    const IntLookupTableInst *I) {
  auto srcH = getWeightHandle<int8_t>(I->getSrc());
  auto destH = getWeightHandle<int8_t>(I->getDest());
  auto mappingH = getWeightHandle<int8_t>(I->getMapping());

  for (size_t i = 0, e = destH.size(); i < e; i++) {
    destH.raw(i) = mappingH.raw((int)srcH.raw(i) + 128);
  }
}
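
// The mapping table above has one int8 entry per possible int8 input: input
// -128 reads mapping[0] and input 127 reads mapping[255]. Illustratively, a
// quantized nonlinearity such as ReLU could be baked in offline by setting
// mapping[i] to the quantized value of max(0, dequantize(i - 128)).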

void BoundInterpreterFunction::fwdConvertToInst(const glow::ConvertToInst *I) {
  Tensor *source = getTensor(I->getInput());
  Tensor *dest = getTensor(I->getResult());
  auto srcElType = source->getType().getElementType();
  auto destElType = dest->getType().getElementType();
  if (srcElType == destElType) {
    // This is a noop conversion.
    dest->copyRawFrom(source);
    return;
  }

#define CONVERT(T_FROM, T_TO, DTY_FROM, DTY_TO)                               \
  if (srcElType == DTY_FROM && destElType == DTY_TO) {                        \
    dest->copyWithCast<T_TO, T_FROM>(source);                                 \
    return;                                                                   \
  }
  CONVERT(float, float16_t, ElemKind::FloatTy, ElemKind::Float16Ty)
  CONVERT(float, bfloat16_t, ElemKind::FloatTy, ElemKind::BFloat16Ty)
  CONVERT(float, bool, ElemKind::FloatTy, ElemKind::BoolTy)
  CONVERT(float, int32_t, ElemKind::FloatTy, ElemKind::Int32ITy)
  CONVERT(float, int64_t, ElemKind::FloatTy, ElemKind::Int64ITy)
  CONVERT(float16_t, float, ElemKind::Float16Ty, ElemKind::FloatTy)
  CONVERT(float16_t, bfloat16_t, ElemKind::Float16Ty, ElemKind::BFloat16Ty)
  CONVERT(float16_t, int32_t, ElemKind::Float16Ty, ElemKind::Int32ITy)
  CONVERT(float16_t, int64_t, ElemKind::Float16Ty, ElemKind::Int64ITy)
  CONVERT(bfloat16_t, float, ElemKind::BFloat16Ty, ElemKind::FloatTy)
  CONVERT(bfloat16_t, float16_t, ElemKind::BFloat16Ty, ElemKind::Float16Ty)
  CONVERT(bfloat16_t, int32_t, ElemKind::BFloat16Ty, ElemKind::Int32ITy)
  CONVERT(bfloat16_t, int64_t, ElemKind::BFloat16Ty, ElemKind::Int64ITy)
  CONVERT(bool, float, ElemKind::BoolTy, ElemKind::FloatTy)
  CONVERT(bool, bfloat16_t, ElemKind::BoolTy, ElemKind::BFloat16Ty)
  CONVERT(int32_t, float, ElemKind::Int32ITy, ElemKind::FloatTy)
  CONVERT(int32_t, float16_t, ElemKind::Int32ITy, ElemKind::Float16Ty)
  CONVERT(int32_t, bfloat16_t, ElemKind::Int32ITy, ElemKind::BFloat16Ty)
  CONVERT(int32_t, int64_t, ElemKind::Int32ITy, ElemKind::Int64ITy)
  CONVERT(int64_t, float, ElemKind::Int64ITy, ElemKind::FloatTy)
  CONVERT(int64_t, float16_t, ElemKind::Int64ITy, ElemKind::Float16Ty)
  CONVERT(int64_t, bfloat16_t, ElemKind::Int64ITy, ElemKind::BFloat16Ty)
  CONVERT(int64_t, int32_t, ElemKind::Int64ITy, ElemKind::Int32ITy)
#undef CONVERT

  if (srcElType == ElemKind::UInt8FusedQTy &&
      destElType == ElemKind::UInt8FusedFP16QTy) {
    dest->convertToType(ElemKind::UInt8FusedFP16QTy);
    return;
  }
  llvm_unreachable("Type not supported");
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdBatchedPairwiseDotProductInstImpl(
    const BatchedPairwiseDotProductInst *I) {
  auto destT = getTensor(I->getDest());
  auto destH = destT->getHandle<ElemTy>();

  dim_t batchCount = destT->getType().dims()[0];

  // Gather all batched vector operands into an array so that they can be
  // indexed easily.
  std::vector<Value *> srcs;
  for (unsigned i = 1, e = I->getNumOperands(); i < e; ++i) {
    auto op = I->getOperand(i);
    srcs.emplace_back(op.first);
  }

  // pairIdx is the total number of pairs (i, j) that have been processed.
  unsigned pairIdx = 0;

  // For each src operand:
  for (unsigned i = 1, e = I->getNumInputs(); i < e; ++i) {
    auto vAH = getTensor(srcs[i])->getHandle<ElemTy>();
    dim_t vectorSize = getTensor(srcs[i])->getType().dims()[1];

    // Compute the dot product of src[i] with every other vector with a smaller
    // index.
    for (unsigned j = 0; j < i; ++j) {
      auto vBH = getTensor(srcs[j])->getHandle<ElemTy>();

      // Process all batches for a given pair (i, j).
      for (dim_t b = 0; b < batchCount; ++b) {
        ElemTy accum = 0;

        for (dim_t k = 0; k < vectorSize; ++k) {
          accum += vAH.at({b, k}) * vBH.at({b, k});
        }

        destH.at({b, pairIdx}) = accum;
      }

      ++pairIdx;
    }
  }
}
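
// Output layout note (derived from the loops above): with inputs x0, x1, x2,
// pairIdx enumerates (1,0), (2,0), (2,1), so for every batch b
//   dest[b] = [x1.x0, x2.x0, x2.x1]
// i.e. one column per unordered pair, grouped by the larger index.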

void BoundInterpreterFunction::fwdBatchedPairwiseDotProductInst(
    const BatchedPairwiseDotProductInst *I) {
  dispatchImpl(fwdBatchedPairwiseDotProductInstImpl,
               I->getDest()->getElementType(), I);
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdBatchedPairwiseDotProductGradInstImpl(
    const BatchedPairwiseDotProductGradInst *I) {
  auto destGradT = getTensor(I->getDestGrad());
  auto destGradH = destGradT->getHandle<ElemTy>();

  dim_t batchCount = destGradT->getType().dims()[0];

  // Gather all batched vector operands into arrays so that they can be
  // indexed easily. Operands 1 -> numInputs are gradients of inputs, and
  // operands numInputs + 1 -> numOperands - 1 are the corresponding original
  // inputs.
  std::vector<Value *> srcs, srcGrads;
  for (unsigned i = 0, e = I->getNumInputs(); i < e; ++i) {
    auto gradOp = I->getOperand(i + 1);
    auto inputOp = I->getOperand(i + 1 + e);

    srcGrads.emplace_back(gradOp.first);
    srcs.emplace_back(inputOp.first);
  }

  // Zero initialize all srcGrad tensors.
  for (auto &s : srcGrads) {
    getTensor(s)->zero();
  }

  // pairIdx is the total number of pairs (i, j) that have been processed.
  unsigned pairIdx = 0;

  // For each srcGrad operand:
  for (unsigned i = 0, e = I->getNumInputs(); i < e; ++i) {
    auto dvAH = getTensor(srcGrads[i])->getHandle<ElemTy>();
    dim_t vectorSize = getTensor(srcs[i])->getType().dims()[1];

    // Accumulate into it the product of the gradient of each dot product
    // that src[i] contributed to and the corresponding vector that src[i]
    // was dotted with.
    for (unsigned j = i + 1; j < e; ++j) {
      auto vBH = getTensor(srcs[j])->getHandle<ElemTy>();

      // Process all batches for a given pair (i, j).
      for (dim_t b = 0; b < batchCount; ++b) {
        ElemTy grad = destGradH.at({b, pairIdx});

        for (dim_t k = 0; k < vectorSize; ++k) {
          dvAH.at({b, k}) += grad * vBH.at({b, k});
        }
      }

      ++pairIdx;
    }
  }
}

void BoundInterpreterFunction::fwdBatchedPairwiseDotProductGradInst(
    const BatchedPairwiseDotProductGradInst *I) {
  dispatchImpl(fwdBatchedPairwiseDotProductGradInstImpl,
               I->getDestGrad()->getElementType(), I);
}

template <typename ElemTy>
void BoundInterpreterFunction::fwdFlipInstImpl(const FlipInst *I) {

  static_assert(max_tensor_dimensions == 6,
                "Loops below assume max_tensor_dimensions = 6.");

  auto *src = I->getSrc();
  auto *dest = I->getDest();

  // Get unowned handles of src and dest with dims expanded to maximum.
  ShapeVector eDims = expandDimsToMax(src->dims());
  auto eSrc = getTensor(src)->getUnowned(eDims);
  auto eDest = getTensor(dest)->getUnowned(eDims);
  auto srcH = eSrc.getHandle<ElemTy>();
  auto destH = eDest.getHandle<ElemTy>();

#define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5)                          \
  for (dim_t idx0 = 0; idx0 < eDims[0]; idx0++)                               \
    for (dim_t idx1 = 0; idx1 < eDims[1]; idx1++)                             \
      for (dim_t idx2 = 0; idx2 < eDims[2]; idx2++)                           \
        for (dim_t idx3 = 0; idx3 < eDims[3]; idx3++)                         \
          for (dim_t idx4 = 0; idx4 < eDims[4]; idx4++)                       \
            for (dim_t idx5 = 0; idx5 < eDims[5]; idx5++) {                   \
              destH.at({_D0, _D1, _D2, _D3, _D4, _D5}) =                      \
                  srcH.at({idx0, idx1, idx2, idx3, idx4, idx5});              \
            }                                                                 \
  return;

  switch (I->getAxis()) {
  case 0:
    LOOP_AXIS_CASE(eDims[0] - 1 - idx0, idx1, idx2, idx3, idx4, idx5);
  case 1:
    LOOP_AXIS_CASE(idx0, eDims[1] - 1 - idx1, idx2, idx3, idx4, idx5);
  case 2:
    LOOP_AXIS_CASE(idx0, idx1, eDims[2] - 1 - idx2, idx3, idx4, idx5);
  case 3:
    LOOP_AXIS_CASE(idx0, idx1, idx2, eDims[3] - 1 - idx3, idx4, idx5);
  case 4:
    LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, eDims[4] - 1 - idx4, idx5);
  case 5:
    LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, idx4, eDims[5] - 1 - idx5);
  default:
    llvm_unreachable("Axis should be less than max_tensor_dimensions.");
  }
}

void BoundInterpreterFunction::fwdFlipInst(const FlipInst *I) {
  dispatchImpl(fwdFlipInstImpl, I->getSrc()->getElementType(), I);
}

//===----------------------------------------------------------------------===//
//                Instructions used by ObjectDetection
//===----------------------------------------------------------------------===//
static void maxMin(float lhs, float rhs, float &min, float &max) {
  if (lhs >= rhs) {
    min = rhs;
    max = lhs;
  } else {
    min = lhs;
    max = rhs;
  }
}

using ClassBox = std::pair<float, dim_t>;

struct Box {
  float classValue{0.0f};
  dim_t batchIndex{0};
  dim_t classIndex{0};
  dim_t boxIndex{0};
};

template <typename ElemTy>
static bool doIOU(Handle<ElemTy> &boxes, dim_t batchIndex,
                  dim_t selectedBoxIndex, dim_t candidateBoxIndex,
                  int centerPointBox, float iouThreshold, bool isV4) {
  float sx[] = {0.0f, 0.0f, 0.0f, 0.0f};
  float cx[] = {0.0f, 0.0f, 0.0f, 0.0f};

  if (isV4) {
    for (dim_t i = 0; i < 4; ++i) {
      sx[i] = boxes.at({selectedBoxIndex, i});
      cx[i] = boxes.at({candidateBoxIndex, i});
    }
  } else {
    for (dim_t i = 0; i < 4; ++i) {
      sx[i] = boxes.at({batchIndex, selectedBoxIndex, i});
      cx[i] = boxes.at({batchIndex, candidateBoxIndex, i});
    }
  }

  float xSMin = 0.0f;
  float ySMin = 0.0f;
  float xSMax = 0.0f;
  float ySMax = 0.0f;

  float xCMin = 0.0f;
  float yCMin = 0.0f;
  float xCMax = 0.0f;
  float yCMax = 0.0f;

  // Standardizing coordinates so that (xmin, ymin) is upper left corner of a
  // box and (xmax, ymax) is lower right corner of the box.
  if (!centerPointBox) {
    // 0 means coordinates for diagonal ends of a box.
    // Coordinates can either be absolute or normalized.
    maxMin(sx[0], sx[2], xSMin, xSMax);
    maxMin(sx[1], sx[3], ySMin, ySMax);

    maxMin(cx[0], cx[2], xCMin, xCMax);
    maxMin(cx[1], cx[3], yCMin, yCMax);
  } else {
    float halfWidthS = sx[2] / 2.0f;
    float halfHeightS = sx[3] / 2.0f;
    float halfWidthC = cx[2] / 2.0f;
    float halfHeightC = cx[3] / 2.0f;

    xSMin = sx[0] - halfWidthS;
    ySMin = sx[1] - halfHeightS;
    xSMax = sx[0] + halfWidthS;
    ySMax = sx[1] + halfHeightS;

    xCMin = cx[0] - halfWidthC;
    yCMin = cx[1] - halfHeightC;
    xCMax = cx[0] + halfWidthC;
    yCMax = cx[1] + halfHeightC;
  }

  // Finding upper left and lower right corner of a box formed by intersection.
  float xMin = std::max(xSMin, xCMin);
  float yMin = std::max(ySMin, yCMin);
  float xMax = std::min(xSMax, xCMax);
  float yMax = std::min(ySMax, yCMax);

  float intersectionArea =
      std::max(0.0f, xMax - xMin) * std::max(0.0f, yMax - yMin);

  if (intersectionArea == 0.0f) {
    return false;
  }

  float sArea = (xSMax - xSMin) * (ySMax - ySMin);
  float cArea = (xCMax - xCMin) * (yCMax - yCMin);
  float unionArea = sArea + cArea - intersectionArea;

  return intersectionArea > iouThreshold * unionArea;
}
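
// Worked IoU example (illustrative only), with corner-style coordinates
// (centerPointBox == 0): boxes s = (0, 0, 2, 2) and c = (1, 1, 3, 3) overlap
// in the unit square (1, 1)-(2, 2), so
//   intersection = 1, union = 4 + 4 - 1 = 7, IoU = 1/7 ~ 0.14
// and doIOU() returns true exactly when IoU exceeds iouThreshold.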

template <typename T>
void BoundInterpreterFunction::fwdNonMaxSuppressionInstImpl(
    glow::NonMaxSuppressionInst const *I) {

  auto boxes = I->getBoxes();
  auto scores = I->getScores();
  auto indices = I->getIndices();
  auto numDetected = I->getNumberOfSelectedIndices();
  float iouThreshold = I->getIouThreshold();
  dim_t maxBoxesPerClass = I->getMaxOutputBoxesPerClass();
  float scoreThreshold = I->getScoreThreshold();
  unsigned centerPointBox = I->getCenterPointBox();
  bool isV4 = I->getIsTFVersion4();

  auto boxesH = getTensor(boxes)->getHandle<float>();
  auto scoresH = getTensor(scores)->getHandle<float>();
  auto indicesH = getTensor(indices)->getHandle<T>();
  auto numDetectedH = getTensor(numDetected)->getHandle<T>();

  int boxesBoxDim = boxes->dims().size() - 2;

  dim_t numBatches = 1;
  dim_t numClasses = 1;
  dim_t numBoxes = boxes->dims()[boxesBoxDim];

  size_t maxOutputPerBatch = 0;

  if (!isV4) {
    int boxesBatchDim = boxes->dims().size() - 3;

    int scoresBatchDim = scores->dims().size() - 3;
    int scoresBoxDim = scores->dims().size() - 1;
    int scoresClassDim = scores->dims().size() - 2;
    assert(scores->dims()[scoresBoxDim] == boxes->dims()[boxesBoxDim] &&
           "Mismatch between number of scores and number of boxes.");
    assert(scores->dims()[scoresBatchDim] == boxes->dims()[boxesBatchDim] &&
           "Mismatch in batch dimension.");
    (void)boxesBatchDim;
    (void)scoresBoxDim;
    numBatches = scores->dims()[scoresBatchDim];
    numClasses = scores->dims()[scoresClassDim];
    numBoxes = boxes->dims()[boxesBoxDim];
    maxOutputPerBatch =
        indices->dims()[indices->dims().size() - 2] / numBatches;
  } else {
    maxOutputPerBatch =
        indices->dims()[indices->dims().size() - 1] / numBatches;
  }

  auto cmpFunc = [](const ClassBox &a, const ClassBox &b) {
    return a.first < b.first;
  };

  std::vector<ClassBox> selectedIndices(numBoxes);
  dim_t outPutBoxIndex = 0;

  for (dim_t batchIndex = 0; batchIndex < numBatches; ++batchIndex) {
    Box minBox{scoresH.raw(batchIndex * numClasses * numBoxes), batchIndex, 0,
               0};
    int32_t detectedPerBatch = 0;
    for (dim_t classIndex = 0; classIndex < numClasses; ++classIndex) {
      selectedIndices.clear();
      size_t detectedPerClass = 0;
      std::priority_queue<ClassBox, std::vector<ClassBox>, decltype(cmpFunc)>
          queue(cmpFunc);

      for (size_t boxIndex = 0; boxIndex < numBoxes; ++boxIndex) {
        float classValue = scoresH.raw(
            (batchIndex * numClasses + classIndex) * numBoxes + boxIndex);
        if (classValue > scoreThreshold) {
          queue.emplace(classValue, boxIndex);
        }
      }

      float tScore = minBox.classValue;
      while (!queue.empty()) {
        auto priorBox = queue.top();
        queue.pop();

        bool selected = true;
        for (auto &sBox : selectedIndices) {
          if (doIOU(boxesH, batchIndex, sBox.second, priorBox.second,
                    centerPointBox, iouThreshold, isV4)) {
            selected = false;
            break;
          }
        }

        if (selected) {
          selectedIndices.emplace_back(priorBox);
          if (isV4) {
            indicesH.at({outPutBoxIndex}) = priorBox.second;
            tScore = scoresH.at({priorBox.second});
          } else {
            indicesH.at({outPutBoxIndex, 0}) = batchIndex;
            indicesH.at({outPutBoxIndex, 1}) = classIndex;
            indicesH.at({outPutBoxIndex, 2}) = priorBox.second;
            tScore = scoresH.at({batchIndex, classIndex, priorBox.second});
          }

          ++outPutBoxIndex;
          ++detectedPerClass;
          ++detectedPerBatch;
        }
        if (maxBoxesPerClass == detectedPerClass) {
          break;
        }
      }

      if (tScore < minBox.classValue) {
        minBox.classValue = tScore;
        minBox.classIndex = classIndex;
        if (isV4) {
          minBox.boxIndex = indicesH.at({outPutBoxIndex - 1});
        } else {
          minBox.boxIndex = indicesH.at({outPutBoxIndex - 1, 2});
        }
      }
    }

    for (size_t i = detectedPerBatch; i < maxOutputPerBatch; ++i) {
      if (isV4) {
        indicesH.at({outPutBoxIndex}) = minBox.boxIndex;
      } else {
        indicesH.at({outPutBoxIndex, 0}) = minBox.batchIndex;
        indicesH.at({outPutBoxIndex, 1}) = minBox.classIndex;
        indicesH.at({outPutBoxIndex, 2}) = minBox.boxIndex;
      }

      ++outPutBoxIndex;
    }
    // The numDetected output is unused for ONNX NMS; for TF NMS (V4) the
    // batch dimension is 1.
    for (dim_t i = 0; i < maxBoxesPerClass; ++i) {
      numDetectedH.at({batchIndex * maxBoxesPerClass + i}) = detectedPerBatch;
    }
  }
}

void BoundInterpreterFunction::fwdNonMaxSuppressionInst(
    glow::NonMaxSuppressionInst const *I) {
  switch (I->getBoxes()->getElementType()) {
  case ElemKind::FloatTy:
    if (I->getIndices()->getElementType() == ElemKind::Int32ITy) {
      fwdNonMaxSuppressionInstImpl<int32_t>(I);
    } else if (I->getIndices()->getElementType() == ElemKind::Int64ITy) {
      fwdNonMaxSuppressionInstImpl<int64_t>(I);
    } else {
      llvm_unreachable("Output type is not supported.");
    }
    break;
  default:
    llvm_unreachable("Type is not supported.");
    break;
  }
}

void BoundInterpreterFunction::fwdAudioSpectrogramInstFloatImpl(
    glow::AudioSpectrogramInst const *I) {

  auto spectrogram = I->getSpectrogram();
  auto input = I->getInput();
  auto window = I->getWindow();
  int64_t windowSize = I->getWindowSize();
  int64_t windowStride = I->getWindowStride();

  auto spectrogramH = getTensor(spectrogram)->getHandle<float>();
  auto inputH = getTensor(input)->getHandle<float>();
  auto windowH = getTensor(window)->getHandle<float>();

  // Compute window count.
  int64_t inputLength = input->size();
  int64_t windowCount =
      std::floor((inputLength - windowSize) / windowStride) + 1;

  // Compute FFT length (next power of 2) and spectrogram length.
  dim_t fftLen = 1 << (dim_t)std::ceil(std::log2((double)windowSize));
  dim_t specLen = fftLen / 2 + 1;

  // Allocate temporary buffers.
  auto winOut = std::make_unique<float[]>(windowSize);
  auto fftRealOut = std::make_unique<float[]>(specLen);
  auto fftImagOut = std::make_unique<float[]>(specLen);

  // Compute the spectrogram.
  for (dim_t winIdx = 0; winIdx < windowCount; winIdx++) {

    // Windowing.
    for (dim_t n = 0; n < windowSize; n++) {
      winOut[n] = inputH.raw(winIdx * windowStride + n) * windowH.raw(n);
    }

    // Compute spectrum (perform FFT).
    for (int k = 0; k < specLen; k++) {
      fftRealOut[k] = 0;
      fftImagOut[k] = 0;
      for (int n = 0; n < windowSize; n++) {
        fftRealOut[k] +=
            winOut[n] * cos(2.0 * M_PI * (double)(n * k) / (double)(fftLen));
        fftImagOut[k] -=
            winOut[n] * sin(2.0 * M_PI * (double)(n * k) / (double)(fftLen));
      }
    }

    // Compute spectrum magnitude/power.
    if (I->getMagnitudeSquared()) {
      for (dim_t k = 0; k < specLen; k++) {
        spectrogramH.at({winIdx, k}) =
            fftRealOut[k] * fftRealOut[k] + fftImagOut[k] * fftImagOut[k];
      }
    } else {
      for (dim_t k = 0; k < specLen; k++) {
        spectrogramH.at({winIdx, k}) = sqrt(fftRealOut[k] * fftRealOut[k] +
                                            fftImagOut[k] * fftImagOut[k]);
      }
    }
  }
}
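
// Sizing example for the spectrogram above (illustrative values only): for a
// 1 s, 16 kHz input with windowSize = 400 and windowStride = 160,
//   windowCount = floor((16000 - 400) / 160) + 1 = 98
//   fftLen = 512 (next power of 2 above 400), specLen = 257
// Note the spectrum loop is a direct O(windowSize * specLen) DFT rather than
// a fast FFT, which is acceptable for the interpreter's reference semantics.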

void BoundInterpreterFunction::fwdAudioSpectrogramInst(
    glow::AudioSpectrogramInst const *I) {
  auto inputTy = I->getInput()->getElementType();
  auto spectrogramTy = I->getSpectrogram()->getElementType();
  if ((inputTy == ElemKind::FloatTy) && (spectrogramTy == ElemKind::FloatTy)) {
    fwdAudioSpectrogramInstFloatImpl(I);
  } else {
    llvm_unreachable("Type is not supported.");
  }
}

void BoundInterpreterFunction::fwdMFCCInstFloatImpl(glow::MFCCInst const *I) {

  auto coefficients = I->getCoefficients();
  auto spectrogram = I->getSpectrogram();
  auto melWeights = I->getMelWeights();
  auto melRanges = I->getMelRanges();
  auto dctMat = I->getDctMat();
  int64_t filterBankCount = I->getFilterBankCount();
  int64_t numCoefficients = I->getNumCoefficients();

  auto coefficientsH = getTensor(coefficients)->getHandle<float>();
  auto spectrogramH = getTensor(spectrogram)->getHandle<float>();
  auto melWeightsH = getTensor(melWeights)->getHandle<float>();
  auto melRangesH = getTensor(melRanges)->getHandle<int32_t>();
  auto dctMatH = getTensor(dctMat)->getHandle<float>();

  // Perform MFCC for all the windows.
  auto winNum = spectrogram->dims()[0];
  auto melBuff = std::make_unique<float[]>(filterBankCount);
  for (dim_t winIdx = 0; winIdx < winNum; winIdx++) {

    // Apply Mel filter bank mapping. We use sqrt for the spectrogram since we
    // assume the spectrogram is a power value and not a magnitude.
    dim_t melBinCoeffIdx = 0;
    for (dim_t melIdx = 0; melIdx < filterBankCount; melIdx++) {
      int32_t freqIdxStart = melRangesH.raw(2 * melIdx + 0);
      int32_t freqIdxStop = melRangesH.raw(2 * melIdx + 1);
      float melPwr = 0.0f;
      for (dim_t freqIdx = freqIdxStart; freqIdx <= freqIdxStop; freqIdx++) {
        melPwr += std::sqrt(spectrogramH.at({winIdx, freqIdx})) *
                  melWeightsH.raw(melBinCoeffIdx++);
      }
      melBuff[melIdx] = melPwr;
    }

    // Take logarithm in-place (avoid log(0)).
    for (dim_t melIdx = 0; melIdx < filterBankCount; melIdx++) {
      float melPwr = melBuff[melIdx];
      melBuff[melIdx] = (melPwr == 0.0)
                            ? logf(std::numeric_limits<float>::min())
                            : logf(melPwr);
    }

    // Compute DCT transform.
    for (dim_t k = 0; k < numCoefficients; k++) {
      float dctOut = 0.0f;
      for (dim_t n = 0; n < filterBankCount; n++) {
        dctOut += dctMatH.at({k, n}) * melBuff[n];
      }
      coefficientsH.at({winIdx, k}) = dctOut;
    }
  }
}
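
// Layout example for the Mel inputs above (illustrative only): with
// filterBankCount = 2 and melRanges = [1, 3, 2, 5], filter 0 spans frequency
// bins 1..3 and filter 1 spans bins 2..5, so melWeights stores 3 + 4
// contiguous coefficients consumed in order via melBinCoeffIdx. Each window
// then yields numCoefficients DCT outputs of the log Mel energies.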

void BoundInterpreterFunction::fwdMFCCInst(glow::MFCCInst const *I) {
  auto spectrogramTy = I->getSpectrogram()->getElementType();
  auto coefficientsTy = I->getCoefficients()->getElementType();
  if ((spectrogramTy == ElemKind::FloatTy) &&
      (coefficientsTy == ElemKind::FloatTy)) {
    fwdMFCCInstFloatImpl(I);
  } else {
    llvm_unreachable("Type is not supported.");
  }
}