1 /**
2  * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "glow/Backends/Interpreter/Interpreter.h"
18 
19 #include "glow/Base/TensorSerialization.h"
20 #include "glow/IR/Instrs.h"
21 #include "glow/Quantization/Base/Base.h"
22 #include "glow/Quantization/Base/Profile.h"
23 
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/Support/Casting.h"
26 #include "llvm/Support/raw_ostream.h"
27 
28 #include <chrono>
29 #include <cmath>
30 #include <math.h>
31 
32 #ifdef WIN32
33 #include <corecrt_math_defines.h>
34 #endif
35 
36 using namespace glow;
37 
38 #define dispatchImpl(functionName, elemTy, ...)                                \
39   switch (elemTy) {                                                            \
40   case ElemKind::FloatTy:                                                      \
41     functionName<float>(__VA_ARGS__);                                          \
42     break;                                                                     \
43   case ElemKind::Float16Ty:                                                    \
44     functionName<float16_t>(__VA_ARGS__);                                      \
45     break;                                                                     \
46   case ElemKind::BFloat16Ty:                                                   \
47     functionName<bfloat16_t>(__VA_ARGS__);                                     \
48     break;                                                                     \
49   case ElemKind::Int8QTy:                                                      \
50     functionName<int8_t>(__VA_ARGS__);                                         \
51     break;                                                                     \
52   case ElemKind::Int16QTy:                                                     \
53     functionName<int16_t>(__VA_ARGS__);                                        \
54     break;                                                                     \
55   case ElemKind::Int32QTy:                                                     \
56     functionName<int32_t>(__VA_ARGS__);                                        \
57     break;                                                                     \
58   case ElemKind::Int32ITy:                                                     \
59     functionName<int32_t>(__VA_ARGS__);                                        \
60     break;                                                                     \
61   case ElemKind::Int64ITy:                                                     \
62     functionName<int64_t>(__VA_ARGS__);                                        \
63     break;                                                                     \
64   case ElemKind::BoolTy:                                                       \
65     functionName<bool>(__VA_ARGS__);                                           \
66     break;                                                                     \
67   default:                                                                     \
68     llvm_unreachable("Type is not supported");                                 \
69   }
70 
71 #define dispatchFloatingPointImpl(functionName, elemTy, ...)                   \
72   switch (elemTy) {                                                            \
73   case ElemKind::FloatTy:                                                      \
74     functionName<float>(__VA_ARGS__);                                          \
75     break;                                                                     \
76   case ElemKind::Float16Ty:                                                    \
77     functionName<float16_t>(__VA_ARGS__);                                      \
78     break;                                                                     \
79   case ElemKind::BFloat16Ty:                                                   \
80     functionName<bfloat16_t>(__VA_ARGS__);                                     \
81     break;                                                                     \
82   default:                                                                     \
83     llvm_unreachable("Type is not supported");                                 \
84   }
85 
86 #define dispatchFloatingPointAndIndexImpl(functionName, elemTy, elemTyIndex,   \
87                                           ...)                                 \
88   switch (elemTy) {                                                            \
89   case ElemKind::FloatTy:                                                      \
90     if (elemTyIndex == ElemKind::Int64ITy) {                                   \
91       functionName<float, int64_t>(__VA_ARGS__);                               \
92     } else if (elemTyIndex == ElemKind::Int32ITy) {                            \
93       functionName<float, int32_t>(__VA_ARGS__);                               \
94     }                                                                          \
95     break;                                                                     \
96   case ElemKind::Float16Ty:                                                    \
97     if (elemTyIndex == ElemKind::Int64ITy) {                                   \
98       functionName<float16, int64_t>(__VA_ARGS__);                             \
99     } else if (elemTyIndex == ElemKind::Int32ITy) {                            \
100       functionName<float16, int32_t>(__VA_ARGS__);                             \
101     }                                                                          \
102     break;                                                                     \
103   case ElemKind::BFloat16Ty:                                                   \
104     if (elemTyIndex == ElemKind::Int64ITy) {                                   \
105       functionName<bfloat16, int64_t>(__VA_ARGS__);                            \
106     } else if (elemTyIndex == ElemKind::Int32ITy) {                            \
107       functionName<bfloat16, int32_t>(__VA_ARGS__);                            \
108     }                                                                          \
109     break;                                                                     \
110   default:                                                                     \
111     llvm_unreachable("Type is not supported");                                 \
112   }
113 
114 #define dispatchIndexTypeImpl(functionName, elemTy, ...)                       \
115   switch (elemTy) {                                                            \
116   case ElemKind::Int32ITy:                                                     \
117     functionName<int32_t>(__VA_ARGS__);                                        \
118     break;                                                                     \
119   case ElemKind::Int64ITy:                                                     \
120     functionName<int64_t>(__VA_ARGS__);                                        \
121     break;                                                                     \
122   default:                                                                     \
123     llvm_unreachable("Type is not supported");                                 \
124   }
125 
126 #define dispatchArithmeticImpl(functionName, elemTy, ...)                      \
127   switch (elemTy) {                                                            \
128   case ElemKind::FloatTy:                                                      \
129     functionName<float>(__VA_ARGS__);                                          \
130     break;                                                                     \
131   case ElemKind::Float16Ty:                                                    \
132     functionName<float16_t>(__VA_ARGS__);                                      \
133     break;                                                                     \
134   case ElemKind::BFloat16Ty:                                                   \
135     functionName<bfloat16_t>(__VA_ARGS__);                                     \
136     break;                                                                     \
137   case ElemKind::Int32ITy:                                                     \
138     functionName<int32_t>(__VA_ARGS__);                                        \
139     break;                                                                     \
140   case ElemKind::Int64ITy:                                                     \
141     functionName<int64_t>(__VA_ARGS__);                                        \
142     break;                                                                     \
143   default:                                                                     \
144     llvm_unreachable("Type is not supported");                                 \
145   }
146 
147 #define dispatchQuantizedImpl(functionName, elemTy, ...)                       \
148   switch (elemTy) {                                                            \
149   case ElemKind::Int8QTy:                                                      \
150     functionName<int8_t>(__VA_ARGS__);                                         \
151     break;                                                                     \
152   case ElemKind::Int16QTy:                                                     \
153     functionName<int16_t>(__VA_ARGS__);                                        \
154     break;                                                                     \
155   case ElemKind::Int32QTy:                                                     \
156     functionName<int32_t>(__VA_ARGS__);                                        \
157     break;                                                                     \
158   default:                                                                     \
159     llvm_unreachable("Type is not supported");                                 \
160   }
161 
162 #define dispatchQuantizedWithAccumulationImpl(functionName, elemTy, ...)       \
163   switch (elemTy) {                                                            \
164   case ElemKind::Int8QTy:                                                      \
165     functionName<int8_t, int32_t>(__VA_ARGS__);                                \
166     break;                                                                     \
167   case ElemKind::Int16QTy:                                                     \
168     functionName<int16_t, int64_t>(__VA_ARGS__);                               \
169     break;                                                                     \
170   default:                                                                     \
171     llvm_unreachable("Type is not supported");                                 \
172   }
173 
174 #define dispatchQuantizedWithAccumulationAndBiasImpl(functionName, elemTy,     \
175                                                      biasElemType, ...)        \
176   if (elemTy == ElemKind::Int8QTy && biasElemType == ElemKind::Int8QTy) {      \
177     functionName<int8_t, int32_t, int8_t>(__VA_ARGS__);                        \
178   } else if (elemTy == ElemKind::Int8QTy &&                                    \
179              biasElemType == ElemKind::Int32QTy) {                             \
180     functionName<int8_t, int32_t, int32_t>(__VA_ARGS__);                       \
181   } else if (elemTy == ElemKind::Int16QTy &&                                   \
182              biasElemType == ElemKind::Int16QTy) {                             \
183     functionName<int16_t, int64_t, int16_t>(__VA_ARGS__);                      \
184   } else if (elemTy == ElemKind::Int16QTy &&                                   \
185              biasElemType == ElemKind::Int32QTy) {                             \
186     functionName<int16_t, int64_t, int32_t>(__VA_ARGS__);                      \
187   } else {                                                                     \
188     llvm_unreachable("Type is not supported");                                 \
189   }
190 
191 #define staticAssertFloatingPointType(ElemTy)                                  \
192   static_assert(                                                               \
193       std::is_floating_point<ElemTy>::value ||                                 \
194           std::is_same<float16_t,                                              \
195                        typename std::remove_cv<ElemTy>::type>::value ||        \
196           std::is_same<bfloat16_t,                                             \
197                        typename std::remove_cv<ElemTy>::type>::value,          \
198       "This implementation is for floating-point values only")
199 
200 #define staticAssertArithmeticType(ElemTy)                                     \
201   static_assert(                                                               \
202       std::is_arithmetic<ElemTy>::value ||                                     \
203           std::is_same<float16_t,                                              \
204                        typename std::remove_cv<ElemTy>::type>::value ||        \
205           std::is_same<bfloat16_t,                                             \
206                        typename std::remove_cv<ElemTy>::type>::value,          \
207       "This implementation is for arithmetic values only")
208 
209 //===----------------------------------------------------------------------===//
210 //                       Convolution
211 //===----------------------------------------------------------------------===//
212 
213 /// This is the floating point implementation of Convolution.
214 template <typename ElemTy>
fwdConvolutionInstFloatImpl(Value * inV,Value * outV,Value * filterV,Value * biasV,llvm::ArrayRef<unsigned_t> kernelSizes,llvm::ArrayRef<unsigned_t> strides,llvm::ArrayRef<unsigned_t> pads,size_t group,size_t dilation)215 void BoundInterpreterFunction::fwdConvolutionInstFloatImpl(
216     Value *inV, Value *outV, Value *filterV, Value *biasV,
217     llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
218     llvm::ArrayRef<unsigned_t> pads, size_t group, size_t dilation) {
219   staticAssertFloatingPointType(ElemTy);
220 
221   auto inW = getWeightHandle<ElemTy>(inV);
222   auto outW = getWeightHandle<ElemTy>(outV);
223   auto filterW = getWeightHandle<ElemTy>(filterV);
224   auto biasW = getWeightHandle<ElemTy>(biasV);
225 
226   ShapeNHWC odim(outW.dims());
227   ShapeNHWC idim(inW.dims());
228   ShapeHW kdim(kernelSizes);
229   ShapeHW sdim(strides);
230 
231   assert(idim.c % group == 0 && "Input channels must be divisible by group.");
232   assert(odim.c % group == 0 && "Output channels must be divisible by group.");
233   dim_t inCperG = idim.c / group;
234   dim_t outCperG = odim.c / group;
235 
236   PaddingTLBR pdim(pads);
237 
238   // For each input in the batch:
239   for (dim_t n = 0; n < idim.n; n++) {
240 
241     // For each group of input channels:
242     for (dim_t g = 0; g < group; g++) {
243 
244       // For each output channel in the group:
245       for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {
246 
247         // For each convolution 'jump' in the input tensor:
248         ssize_t x = -ssize_t(pdim.top);
249         for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
250           ssize_t y = -ssize_t(pdim.left);
251           for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
252 
253             // For each element in the convolution-filter:
254             float sum = 0;
255             for (dim_t fx = 0; fx < kdim.height; fx++) {
256               for (dim_t fy = 0; fy < kdim.width; fy++) {
257                 sdim_t ox = x + fx * dilation;
258                 sdim_t oy = y + fy * dilation;
259 
260                 // Ignore index access below zero (this is due to padding).
261                 if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
262                     oy >= ssize_t(idim.w)) {
263                   continue;
264                 }
265                 for (dim_t fd = 0; fd < inCperG; fd++) {
266                   sum += float(
267                       filterW.at({d, fx, fy, fd}) *
268                       inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd}));
269                 }
270               }
271             }
272 
273             sum += float(biasW.at({d}));
274             outW.at({n, ax, ay, d}) = ElemTy(sum);
275           } // W
276         }   // H
277       }     // C
278     }       // G
279   }         // N
280 }
281 
282 /// This is the quantized implementation of Convolution.
283 template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
fwdConvolutionInstQuantizedImpl(Value * inV,Value * outV,Value * filterV,Value * biasV,llvm::ArrayRef<unsigned_t> kernelSizes,llvm::ArrayRef<unsigned_t> strides,llvm::ArrayRef<unsigned_t> pads,size_t group,size_t dilation)284 void BoundInterpreterFunction::fwdConvolutionInstQuantizedImpl(
285     Value *inV, Value *outV, Value *filterV, Value *biasV,
286     llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
287     llvm::ArrayRef<unsigned_t> pads, size_t group, size_t dilation) {
288   auto inW = getWeightHandle<ElemTy>(inV);
289   auto outW = getWeightHandle<ElemTy>(outV);
290   auto filterW = getWeightHandle<ElemTy>(filterV);
291   auto biasW = getWeightHandle<BiasElemTy>(biasV);
292 
293   ShapeNHWC odim(outW.dims());
294   ShapeNHWC idim(inW.dims());
295   ShapeHW kdim(kernelSizes);
296   ShapeHW sdim(strides);
297 
298   assert(idim.c % group == 0 && "Input channels must be divisible by group.");
299   assert(odim.c % group == 0 && "Output channels must be divisible by group.");
300   dim_t inCperG = idim.c / group;
301   dim_t outCperG = odim.c / group;
302 
303   PaddingTLBR pdim(pads);
304   auto outTy = outV->getType();
305   auto inTy = inV->getType();
306   auto filterTy = filterV->getType();
307   auto biasTy = biasV->getType();
308 
309   int32_t outOffset = outTy->getOffset();
310   int32_t inOffset = inTy->getOffset();
311   int32_t filterOffset = filterTy->getOffset();
312   int32_t biasOffset = biasTy->getOffset();
313 
314   float outScale = outTy->getScale();
315   float inScale = inTy->getScale();
316   float filterScale = filterTy->getScale();
317   float biasScale = biasTy->getScale();
318 
319   // Calculate the scale of the values that come out of the matrix
320   // multiplication part of the calculation.
321   float matMulScale = inScale * filterScale;
322 
323   // For each input in the batch:
324   for (dim_t n = 0; n < idim.n; n++) {
325     // For each group of input channels:
326     for (dim_t g = 0; g < group; g++) {
327 
328       // For each output channel in the group:
329       for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {
330 
331         // For each convolution 'jump' in the input tensor:
332         ssize_t x = -ssize_t(pdim.top);
333         for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
334           ssize_t y = -ssize_t(pdim.left);
335           for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
336 
337             // For each element in the convolution-filter:
338             AccumulatorTy sum = 0;
339             for (dim_t fx = 0; fx < kdim.height; fx++) {
340               for (dim_t fy = 0; fy < kdim.width; fy++) {
341                 sdim_t ox = x + fx * dilation;
342                 sdim_t oy = y + fy * dilation;
343 
344                 // Ignore index access below zero (this is due to padding).
345                 if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
346                     oy >= sdim_t(idim.w)) {
347                   continue;
348                 }
349                 for (dim_t fd = 0; fd < inCperG; fd++) {
350 
351                   AccumulatorTy F = filterW.at({d, fx, fy, fd});
352                   AccumulatorTy I =
353                       inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd});
354                   // We represent the element multiplication with offset as
355                   // (value - offset).
356                   sum += (F - filterOffset) * (I - inOffset);
357                 }
358               }
359             }
360 
361             // Scale the bias to match the scale of the matrix multiplication.
362             AccumulatorTy B = std::round(float(biasW.at({d}) - biasOffset) *
363                                          (biasScale / matMulScale));
364 
365             // Add the bias.
366             sum += B;
367 
368             // Scale the result back to the expected destination scale.
369             outW.at({n, ax, ay, d}) = quantization::clip<AccumulatorTy, ElemTy>(
370                 std::round(float(sum) * (matMulScale / outScale) + outOffset));
371           } // W
372         }   // H
373       }     // C
374     }       // G
375   }         // N
376 }
377 
378 /// This is the floating point implementation of ConvTranspose.
379 template <typename ElemTy>
fwdConvTransposeInstFloatImpl(Value * inV,Value * outV,Value * filterV,Value * biasV,llvm::ArrayRef<unsigned_t> kernelSizes,llvm::ArrayRef<unsigned_t> strides,llvm::ArrayRef<unsigned_t> pads,size_t group,size_t dilation)380 void BoundInterpreterFunction::fwdConvTransposeInstFloatImpl(
381     Value *inV, Value *outV, Value *filterV, Value *biasV,
382     llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
383     llvm::ArrayRef<unsigned_t> pads, size_t group, size_t dilation) {
384   staticAssertFloatingPointType(ElemTy);
385 
386   auto inW = getWeightHandle<ElemTy>(inV);
387   auto outW = getWeightHandle<ElemTy>(outV);
388   auto filterW = getWeightHandle<ElemTy>(filterV);
389   auto biasW = getWeightHandle<ElemTy>(biasV);
390 
391   ShapeNHWC odim(outW.dims());
392   ShapeNHWC idim(inW.dims());
393   ShapeHW kdim(kernelSizes);
394   ShapeHW sdim(strides);
395 
396   assert(idim.c % group == 0 && "Input channels must be divisible by group.");
397   assert(odim.c % group == 0 && "Output channels must be divisible by group.");
398   assert(group == 1 && "Group must be 1.");
399 
400   dim_t inCperG = idim.c / group;
401   dim_t outCperG = odim.c / group;
402 
403   PaddingTLBR pdim(pads);
404 
405   // For each input in the batch:
406   for (dim_t n = 0; n < idim.n; n++) {
407 
408     // Initialize bias (TODO take out to a separate function when quant is in).
409     for (dim_t ax = 0; ax < odim.h; ax++) {
410       for (dim_t ay = 0; ay < odim.w; ay++) {
411         for (dim_t d = 0; d < odim.c; d++) {
412           outW.at({n, ax, ay, d}) = static_cast<ElemTy>(biasW.at({d}));
413         }
414       }
415     }
416 
417     // For each group of input channels:
418     for (dim_t g = 0; g < group; g++) {
419 
420       // For each input channel in the group:
421       for (dim_t d = g * inCperG; d < (g + 1) * inCperG; d++) {
422 
423         // For each transposed convolution 'jump' in the input tensor:
424         ssize_t x = -ssize_t(pdim.top);
425         for (dim_t bx = 0; bx < idim.h; bx++, x += sdim.height) {
426           ssize_t y = -ssize_t(pdim.left);
427           for (dim_t by = 0; by < idim.w; by++, y += sdim.width) {
428 
429             // For each element in the each transposed convolution filter:
430             ElemTy input = inW.at({n, bx, by, d});
431 
432             for (dim_t kx = 0; kx < kdim.height; kx++) {
433               for (dim_t ky = 0; ky < kdim.width; ky++) {
434                 ssize_t ax = x + kx * dilation;
435                 ssize_t ay = y + ky * dilation;
436 
437                 // Ignore index access below zero (this is due to padding).
438                 if (ax < 0 || ay < 0 || ax >= ssize_t(odim.h) ||
439                     ay >= ssize_t(odim.w)) {
440                   continue;
441                 }
442                 for (dim_t c = 0; c < outCperG; c++) {
443                   outW.at({n, (dim_t)ax, (dim_t)ay, g * outCperG + c}) +=
444                       filterW.at({c, kx, ky, d}) * input;
445                 }
446               }
447             }
448           } // W
449         }   // H
450       }     // C
451     }       // G
452   }         // N
453 }
454 
fwdConvTransposeInst(const ConvTransposeInst * I)455 void BoundInterpreterFunction::fwdConvTransposeInst(
456     const ConvTransposeInst *I) {
457   auto kernelSizes = I->getKernels();
458   auto pads = I->getPads();
459   auto strides = I->getStrides();
460   size_t group = I->getGroup();
461 
462   if (I->getSrc()->getType()->isQuantizedType()) {
463     llvm_unreachable("Quantized ConvTranspose not supported");
464     return;
465   }
466 
467   dispatchFloatingPointImpl(
468       fwdConvTransposeInstFloatImpl, I->getSrc()->getElementType(), I->getSrc(),
469       I->getDest(), I->getFilter(), I->getBias(), kernelSizes, strides, pads,
470       group, I->getDilation());
471 }
472 
fwdConvolutionInst(const ConvolutionInst * I)473 void BoundInterpreterFunction::fwdConvolutionInst(const ConvolutionInst *I) {
474   auto kernelSizes = I->getKernels();
475   auto pads = I->getPads();
476   auto strides = I->getStrides();
477   size_t group = I->getGroup();
478 
479   if (I->getSrc()->getType()->isQuantizedType()) {
480     dispatchQuantizedWithAccumulationAndBiasImpl(
481         fwdConvolutionInstQuantizedImpl, I->getSrc()->getElementType(),
482         I->getBias()->getElementType(), I->getSrc(), I->getDest(),
483         I->getFilter(), I->getBias(), kernelSizes, strides, pads, group,
484         I->getDilation());
485     return;
486   }
487 
488   dispatchFloatingPointImpl(
489       fwdConvolutionInstFloatImpl, I->getSrc()->getElementType(), I->getSrc(),
490       I->getDest(), I->getFilter(), I->getBias(), kernelSizes, strides, pads,
491       group, I->getDilation());
492 }
493 
fwdConvolutionGradInst(const ConvolutionGradInst * I)494 void BoundInterpreterFunction::fwdConvolutionGradInst(
495     const ConvolutionGradInst *I) {
496   auto inW = getWeightHandle(I->getSrc());
497   auto inG = getWeightHandle(I->getSrcGrad());
498   auto outG = getWeightHandle(I->getDestGrad());
499 
500   auto filterW = getWeightHandle(I->getFilter());
501   auto filterG = getWeightHandle(I->getFilterGrad());
502   auto biasG = getWeightHandle(I->getBiasGrad());
503 
504   size_t group = I->getGroup();
505   size_t dilation = I->getDilation();
506 
507   inG.clear();
508   filterG.clear();
509   biasG.clear();
510 
511   ShapeNHWC odim(outG.dims());
512   ShapeNHWC idim(inW.dims());
513   ShapeHW kdim(I->getKernels());
514   ShapeHW sdim(I->getStrides());
515   PaddingTLBR pdim(I->getPads());
516 
517   assert(idim.c % group == 0 && "Input channels must be divisible by group.");
518   assert(odim.c % group == 0 && "Output channels must be divisible by group.");
519   dim_t inCperG = idim.c / group;
520   dim_t outCperG = odim.c / group;
521 
522   // For each input in the batch:
523   for (dim_t n = 0; n < odim.n; n++) {
524 
525     // For each group of input channels:
526     for (dim_t g = 0; g < group; g++) {
527 
528       // Compute the gradient. For each layer in the output tensor:
529       for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {
530 
531         // For each convolution 'jump' in the input tensor:
532         sdim_t x = -sdim_t(pdim.top);
533         for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
534           sdim_t y = -sdim_t(pdim.left);
535           for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
536 
537             float chainGrad = outG.at({n, ax, ay, d});
538 
539             // For each element in the convolution-filter:
540             for (dim_t fx = 0; fx < kdim.height; fx++) {
541               for (dim_t fy = 0; fy < kdim.width; fy++) {
542                 sdim_t ox = x + fx * dilation;
543                 sdim_t oy = y + fy * dilation;
544 
545                 // Ignore index access below zero (this is due to padding).
546                 if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
547                     oy >= sdim_t(idim.w)) {
548                   continue;
549                 }
550 
551                 for (dim_t fd = 0; fd < inCperG; fd++) {
552                   filterG.at({d, fx, fy, fd}) +=
553                       inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd}) *
554                       chainGrad;
555                   inG.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd}) +=
556                       filterW.at({d, fx, fy, fd}) * chainGrad;
557                 }
558               }
559             }
560 
561             biasG.at({d}) += chainGrad;
562           } // W
563         }   // H
564       }     // C
565     }       // G
566   }         // N
567 }
568 
569 /// This is the floating point implementation of Convolution3D.
570 template <typename ElemTy>
fwdConvolution3DInstFloatImpl(Value * inV,Value * outV,Value * filterV,Value * biasV,llvm::ArrayRef<unsigned_t> kernelSizes,llvm::ArrayRef<unsigned_t> strides,llvm::ArrayRef<unsigned_t> pads,size_t group)571 void BoundInterpreterFunction::fwdConvolution3DInstFloatImpl(
572     Value *inV, Value *outV, Value *filterV, Value *biasV,
573     llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
574     llvm::ArrayRef<unsigned_t> pads, size_t group) {
575   staticAssertFloatingPointType(ElemTy);
576 
577   auto inW = getWeightHandle<ElemTy>(inV);
578   auto outW = getWeightHandle<ElemTy>(outV);
579   auto filterW = getWeightHandle<ElemTy>(filterV);
580   auto biasW = getWeightHandle<ElemTy>(biasV);
581 
582   ShapeNTHWC odim(outW.dims());
583   ShapeNTHWC idim(inW.dims());
584   ShapeTHW kdim(kernelSizes);
585   ShapeTHW sdim(strides);
586 
587   assert(idim.c % group == 0 && "Input channels must be divisible by group.");
588   assert(odim.c % group == 0 && "Output channels must be divisible by group.");
589   dim_t inCperG = idim.c / group;
590   dim_t outCperG = odim.c / group;
591 
592   PaddingNFTBLR pdim(pads);
593 
594   // For each input in the batch:
595   for (dim_t n = 0; n < idim.n; n++) {
596 
597     // For each group of input channels:
598     for (dim_t ig = 0; ig < group; ig++) {
599 
600       // For each output channel in the group:
601       for (dim_t og = ig * outCperG; og < (ig + 1) * outCperG; og++) {
602 
603         ssize_t t = -ssize_t(pdim.near);
604         for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
605           // For each convolution 'jump' in the input tensor:
606           ssize_t x = -ssize_t(pdim.top);
607           for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
608             ssize_t y = -ssize_t(pdim.left);
609             for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
610               // For each element in the 3D convolution-filter:
611               float sum = 0;
612               for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
613                 for (dim_t fx = 0; fx < kdim.height; fx++) {
614                   for (dim_t fy = 0; fy < kdim.width; fy++) {
615                     sdim_t ot = t + ft;
616                     sdim_t ox = x + fx;
617                     sdim_t oy = y + fy;
618 
619                     // Ignore index access below zero (this is due to padding).
620                     if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
621                         ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
622                       continue;
623                     }
624                     for (dim_t fg = 0; fg < inCperG; fg++) {
625                       sum += float(filterW.at({og, ft, fx, fy, fg}) *
626                                    inW.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy,
627                                            ig * inCperG + fg}));
628                     }
629                   }
630                 }
631               }
632 
633               sum += float(biasW.at({og}));
634               outW.at({n, at, ax, ay, og}) = ElemTy(sum);
635             } // D
636           }   // W
637         }     // H
638       }       // C
639     }         // G
640   }           // N
641 }
642 
643 /// This is the quantized implementation of Convolution3D.
644 template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
fwdConvolution3DInstQuantizedImpl(Value * inV,Value * outV,Value * filterV,Value * biasV,llvm::ArrayRef<unsigned_t> kernelSizes,llvm::ArrayRef<unsigned_t> strides,llvm::ArrayRef<unsigned_t> pads,size_t group)645 void BoundInterpreterFunction::fwdConvolution3DInstQuantizedImpl(
646     Value *inV, Value *outV, Value *filterV, Value *biasV,
647     llvm::ArrayRef<unsigned_t> kernelSizes, llvm::ArrayRef<unsigned_t> strides,
648     llvm::ArrayRef<unsigned_t> pads, size_t group) {
649   auto inW = getWeightHandle<ElemTy>(inV);
650   auto outW = getWeightHandle<ElemTy>(outV);
651   auto filterW = getWeightHandle<ElemTy>(filterV);
652   auto biasW = getWeightHandle<BiasElemTy>(biasV);
653 
654   ShapeNTHWC odim(outW.dims());
655   ShapeNTHWC idim(inW.dims());
656   ShapeTHW kdim(kernelSizes);
657   ShapeTHW sdim(strides);
658 
659   assert(idim.c % group == 0 && "Input channels must be divisible by group.");
660   assert(odim.c % group == 0 && "Output channels must be divisible by group.");
661   dim_t inCperG = idim.c / group;
662   dim_t outCperG = odim.c / group;
663 
664   PaddingNFTBLR pdim(pads);
665 
666   auto outTy = outV->getType();
667   auto inTy = inV->getType();
668   auto filterTy = filterV->getType();
669   auto biasTy = biasV->getType();
670 
671   int32_t outOffset = outTy->getOffset();
672   int32_t inOffset = inTy->getOffset();
673   int32_t filterOffset = filterTy->getOffset();
674   int32_t biasOffset = biasTy->getOffset();
675 
676   float outScale = outTy->getScale();
677   float inScale = inTy->getScale();
678   float filterScale = filterTy->getScale();
679   float biasScale = biasTy->getScale();
680 
681   // Calculate the scale of the values that come out of the matrix
682   // multiplication part of the calculation.
683   float matMulScale = inScale * filterScale;
684 
685   // For each input in the batch:
686   for (dim_t n = 0; n < idim.n; n++) {
687 
688     // For each group of input channels:
689     for (dim_t ig = 0; ig < group; ig++) {
690 
691       // For each output channel in the group:
692       for (dim_t og = ig * outCperG; og < (ig + 1) * outCperG; og++) {
693 
694         // For each convolution 'jump' in the input tensor:
695         ssize_t t = -ssize_t(pdim.near);
696         for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
697           ssize_t x = -ssize_t(pdim.top);
698           for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
699             ssize_t y = -ssize_t(pdim.left);
700             for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
701 
702               // For each element in the convolution-filter:
703               AccumulatorTy sum = 0;
704               for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
705                 for (dim_t fx = 0; fx < kdim.height; fx++) {
706                   for (dim_t fy = 0; fy < kdim.width; fy++) {
707                     ssize_t ot = t + ft;
708                     ssize_t ox = x + fx;
709                     ssize_t oy = y + fy;
710 
711                     // Ignore index access below zero (this is due to padding).
712                     if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
713                         ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
714                       continue;
715                     }
716                     for (dim_t fg = 0; fg < inCperG; fg++) {
717 
718                       AccumulatorTy F = filterW.at({og, ft, fx, fy, fg});
719                       AccumulatorTy I = inW.at({n, (dim_t)ot, (dim_t)ox,
720                                                 (dim_t)oy, ig * inCperG + fg});
721                       // We represent the element multiplication with offset as
722                       // (value - offset).
723                       sum += (F - filterOffset) * (I - inOffset);
724                     }
725                   }
726                 }
727               }
728 
729               // Scale the bias to match the scale of the matrix multiplication.
730               AccumulatorTy B = std::round(float(biasW.at({og}) - biasOffset) *
731                                            (biasScale / matMulScale));
732 
733               // Add the bias:
734               sum += B;
735 
736               // Scale the result back to the expected destination scale.
737               outW.at({n, at, ax, ay, og}) =
738                   quantization::clip<AccumulatorTy, ElemTy>(std::round(
739                       float(sum) * (matMulScale / outScale) + outOffset));
740             } // D
741           }   // W
742         }     // H
743       }       // C
744     }         // G
745   }           // N
746 }
747 
fwdConvolution3DInst(const Convolution3DInst * I)748 void BoundInterpreterFunction::fwdConvolution3DInst(
749     const Convolution3DInst *I) {
750   auto kernelSizes = I->getKernels();
751   auto pads = I->getPads();
752   auto strides = I->getStrides();
753   size_t group = I->getGroup();
754 
755   if (I->getSrc()->getType()->isQuantizedType()) {
756     dispatchQuantizedWithAccumulationAndBiasImpl(
757         fwdConvolution3DInstQuantizedImpl, I->getSrc()->getElementType(),
758         I->getBias()->getElementType(), I->getSrc(), I->getDest(),
759         I->getFilter(), I->getBias(), kernelSizes, strides, pads, group);
760     return;
761   }
762 
763   dispatchFloatingPointImpl(fwdConvolution3DInstFloatImpl,
764                             I->getSrc()->getElementType(), I->getSrc(),
765                             I->getDest(), I->getFilter(), I->getBias(),
766                             kernelSizes, strides, pads, group);
767 }
768 
fwdConvolution3DGradInst(const Convolution3DGradInst * I)769 void BoundInterpreterFunction::fwdConvolution3DGradInst(
770     const Convolution3DGradInst *I) {
771   (void)I;
772   // TODO
773   llvm_unreachable("not yet implemented");
774 }
775 
776 //===----------------------------------------------------------------------===//
777 //                       Channelwise quantized Convolution
778 //===----------------------------------------------------------------------===//
779 template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
fwdChannelwiseQuantizedConv2DInstImpl(const ChannelwiseQuantizedConvolutionInst * I)780 void BoundInterpreterFunction::fwdChannelwiseQuantizedConv2DInstImpl(
781     const ChannelwiseQuantizedConvolutionInst *I) {
782   auto inW = getWeightHandle<ElemTy>(I->getSrc());
783   auto outW = getWeightHandle<ElemTy>(I->getDest());
784   auto filterW = getWeightHandle<ElemTy>(I->getFilter());
785   auto biasW = getWeightHandle<BiasElemTy>(I->getBias());
786   auto filterScales = getWeightHandle<float>(I->getFilterScales());
787   auto filterOffsets = getWeightHandle<int32_t>(I->getFilterOffsets());
788   auto biasScales = getWeightHandle<float>(I->getBiasScales());
789   auto biasOffsets = getWeightHandle<int32_t>(I->getBiasOffsets());
790 
791   llvm::ArrayRef<unsigned_t> kernelSizes = I->getKernels();
792   llvm::ArrayRef<unsigned_t> pads = I->getPads();
793   llvm::ArrayRef<unsigned_t> strides = I->getStrides();
794   dim_t group = I->getGroup();
795   dim_t dilation = I->getDilation();
796 
797   ShapeNHWC odim(outW.dims());
798   ShapeNHWC idim(inW.dims());
799   ShapeHW kdim(kernelSizes);
800   ShapeHW sdim(strides);
801 
802   assert(idim.c % group == 0 && "Input channels must be divisible by group.");
803   assert(odim.c % group == 0 && "Output channels must be divisible by group.");
804   dim_t inCperG = idim.c / group;
805   dim_t outCperG = odim.c / group;
806 
807   PaddingTLBR pdim(pads);
808 
809   auto &inTy = inW.getType();
810   auto &outTy = outW.getType();
811 
812   float inScale = inTy.getScale();
813   float outScale = outTy.getScale();
814 
815   int32_t inOffset = inTy.getOffset();
816   int32_t outOffset = outTy.getOffset();
817 
818   // For each input in the batch:
819   for (dim_t n = 0; n < idim.n; n++) {
820     // For each group of input channels:
821     for (dim_t g = 0; g < group; g++) {
822       // For each output channel in the group:
823       for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {
824 
825         // Get channel wise quantization params.
826         int32_t filterOffset = filterOffsets.at(d);
827         float filterScale = filterScales.at(d);
828         int32_t biasOffset = biasOffsets.at(d);
829         float biasScale = biasScales.at(d);
830         float matMulScale = inScale * filterScale;
831 
832         // For each convolution 'jump' in the input tensor:
833         sdim_t x = -sdim_t(pdim.top);
834         for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
835           sdim_t y = -sdim_t(pdim.left);
836           for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
837 
838             // For each element in the convolution-filter:
839             AccumulatorTy sum = 0;
840             for (dim_t fx = 0; fx < kdim.height; fx++) {
841               for (dim_t fy = 0; fy < kdim.width; fy++) {
842                 sdim_t ox = x + fx * dilation;
843                 sdim_t oy = y + fy * dilation;
844 
845                 // Ignore index access below zero (this is due to padding).
846                 if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
847                     oy >= sdim_t(idim.w)) {
848                   continue;
849                 }
850 
851                 // Accumulate along the filter depth.
852                 for (dim_t fd = 0; fd < inCperG; fd++) {
853                   AccumulatorTy F = filterW.at({d, fx, fy, fd});
854                   AccumulatorTy I =
855                       inW.at({n, (dim_t)ox, (dim_t)oy, g * inCperG + fd});
856                   // We represent the element multiplication with offset as
857                   // (value - offset).
858                   sum += (F - filterOffset) * (I - inOffset);
859                 }
860               }
861             }
862 
863             // Scale the bias to match the scale of the matrix multiplication.
864             sum += std::round(float(biasW.at({d}) - biasOffset) *
865                               (biasScale / matMulScale));
866 
867             // Scale the result back to the expected destination scale.
868             outW.at({n, ax, ay, d}) = quantization::clip<AccumulatorTy, ElemTy>(
869                 std::round(float(sum) * (matMulScale / outScale) + outOffset));
870           } // W
871         }   // H
872       }     // C
873     }       // G
874   }         // N
875 }
876 
877 template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
fwdChannelwiseQuantizedConv3DInstImpl(const ChannelwiseQuantizedConvolutionInst * I)878 void BoundInterpreterFunction::fwdChannelwiseQuantizedConv3DInstImpl(
879     const ChannelwiseQuantizedConvolutionInst *I) {
880   auto inW = getWeightHandle<ElemTy>(I->getSrc());
881   auto outW = getWeightHandle<ElemTy>(I->getDest());
882   auto filterW = getWeightHandle<ElemTy>(I->getFilter());
883   auto biasW = getWeightHandle<BiasElemTy>(I->getBias());
884   auto filterScales = getWeightHandle<float>(I->getFilterScales());
885   auto filterOffsets = getWeightHandle<int32_t>(I->getFilterOffsets());
886   auto biasScales = getWeightHandle<float>(I->getBiasScales());
887   auto biasOffsets = getWeightHandle<int32_t>(I->getBiasOffsets());
888 
889   llvm::ArrayRef<unsigned_t> kernelSizes = I->getKernels();
890   llvm::ArrayRef<unsigned_t> pads = I->getPads();
891   llvm::ArrayRef<unsigned_t> strides = I->getStrides();
892   dim_t group = I->getGroup();
893 
894   ShapeNTHWC odim(outW.dims());
895   ShapeNTHWC idim(inW.dims());
896   ShapeTHW kdim(kernelSizes);
897   ShapeTHW sdim(strides);
898 
899   assert(idim.c % group == 0 && "Input channels must be divisible by group.");
900   assert(odim.c % group == 0 && "Output channels must be divisible by group.");
901   dim_t inCperG = idim.c / group;
902   dim_t outCperG = odim.c / group;
903 
904   PaddingNFTBLR pdim(pads);
905 
906   auto &inTy = inW.getType();
907   auto &outTy = outW.getType();
908 
909   float inScale = inTy.getScale();
910   float outScale = outTy.getScale();
911 
912   int32_t inOffset = inTy.getOffset();
913   int32_t outOffset = outTy.getOffset();
914 
915   // For each input in the batch:
916   for (dim_t n = 0; n < idim.n; n++) {
917     // For each group of input channels:
918     for (dim_t g = 0; g < group; g++) {
919       // For each output channel in the group:
920       for (dim_t d = g * outCperG; d < (g + 1) * outCperG; d++) {
921 
922         // Get channel wise quantization params.
923         int32_t filterOffset = filterOffsets.at(d);
924         float filterScale = filterScales.at(d);
925         int32_t biasOffset = biasOffsets.at(d);
926         float biasScale = biasScales.at(d);
927         float matMulScale = inScale * filterScale;
928 
929         // For each convolution 'jump' in the input tensor:
930         sdim_t t = -sdim_t(pdim.near);
931         for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
932           sdim_t x = -sdim_t(pdim.top);
933           for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
934             sdim_t y = -sdim_t(pdim.left);
935             for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
936 
937               // For each element in the convolution-filter:
938               AccumulatorTy sum = 0;
939               for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
940                 for (dim_t fx = 0; fx < kdim.height; fx++) {
941                   for (dim_t fy = 0; fy < kdim.width; fy++) {
942                     sdim_t ot = t + ft;
943                     sdim_t ox = x + fx;
944                     sdim_t oy = y + fy;
945 
946                     // Ignore index access below zero (this is due to
947                     // padding).
948                     if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
949                         ox >= ssize_t(idim.h) || oy >= sdim_t(idim.w)) {
950                       continue;
951                     }
952 
953                     // Accumulate along the filter depth.
954                     for (dim_t fd = 0; fd < inCperG; fd++) {
955 
956                       AccumulatorTy F = filterW.at({d, ft, fx, fy, fd});
957                       AccumulatorTy I = inW.at({n, (dim_t)ot, (dim_t)ox,
958                                                 (dim_t)oy, g * inCperG + fd});
959                       // We represent the element multiplication with offset
960                       // as (value - offset).
961                       sum += (F - filterOffset) * (I - inOffset);
962                     }
963                   }
964                 }
965               }
966 
967               // Scale the bias to match the scale of the matrix multiplication.
968               sum += std::round(float(biasW.at({d}) - biasOffset) *
969                                 (biasScale / matMulScale));
970 
971               // Scale the result back to the expected destination scale.
972               outW.at({n, at, ax, ay, d}) =
973                   quantization::clip<AccumulatorTy, ElemTy>(std::round(
974                       float(sum) * (matMulScale / outScale) + outOffset));
975             } // W
976           }   // H
977         }     // T
978       }       // C
979     }         // G
980   }           // N
981 }
982 
fwdChannelwiseQuantizedConvolutionInst(const ChannelwiseQuantizedConvolutionInst * I)983 void BoundInterpreterFunction::fwdChannelwiseQuantizedConvolutionInst(
984     const ChannelwiseQuantizedConvolutionInst *I) {
985   bool isConv3D = (I->getSrc()->dims().size() == 5);
986   if (isConv3D) {
987     dispatchQuantizedWithAccumulationAndBiasImpl(
988         fwdChannelwiseQuantizedConv3DInstImpl, I->getSrc()->getElementType(),
989         I->getBias()->getElementType(), I);
990   } else {
991     dispatchQuantizedWithAccumulationAndBiasImpl(
992         fwdChannelwiseQuantizedConv2DInstImpl, I->getSrc()->getElementType(),
993         I->getBias()->getElementType(), I);
994   }
995 }
996 
997 //===----------------------------------------------------------------------===//
998 //                       Pooling
999 //===----------------------------------------------------------------------===//
1000 template <class T>
fwdMaxPool(Tensor * inW,Tensor * outW,Tensor * argmaxW,llvm::ArrayRef<unsigned_t> kernelSizes,llvm::ArrayRef<unsigned_t> strides,llvm::ArrayRef<unsigned_t> pads)1001 static void fwdMaxPool(Tensor *inW, Tensor *outW, Tensor *argmaxW,
1002                        llvm::ArrayRef<unsigned_t> kernelSizes,
1003                        llvm::ArrayRef<unsigned_t> strides,
1004                        llvm::ArrayRef<unsigned_t> pads) {
1005   ShapeNHWC odim(outW->dims());
1006   ShapeNHWC idim(inW->dims());
1007   Handle<T> inHandle = inW->getHandle<T>();
1008   Handle<T> outHandle = outW->getHandle<T>();
1009   PaddingTLBR pdim(pads);
1010   ShapeHW kdim(kernelSizes);
1011   ShapeHW sdim(strides);
1012 
1013   llvm::Optional<Handle<int64_t>> argmaxH;
1014   if (argmaxW) {
1015     argmaxH = argmaxW->getHandle<int64_t>();
1016   }
1017   // For each input in the batch:
1018   for (dim_t n = 0; n < odim.n; n++) {
1019 
1020     // For each layer in the output tensor:
1021     for (dim_t z = 0; z < idim.c; z++) {
1022       // For each convolution 'jump' in the input tensor:
1023       sdim_t x = -sdim_t(pdim.top);
1024       for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1025         sdim_t y = -sdim_t(pdim.left);
1026         for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1027 
1028           bool first = true;
1029           T max_value = 0;
1030           dim_t argmaxNHWC = 0;
1031 
1032           for (dim_t fx = 0; fx < kdim.height; fx++) {
1033             for (dim_t fy = 0; fy < kdim.width; fy++) {
1034               sdim_t ox = x + fx;
1035               sdim_t oy = y + fy;
1036 
1037               // Ignore index access below zero (this is due to padding).
1038               if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
1039                   oy >= ssize_t(idim.w)) {
1040                 continue;
1041               }
1042 
1043               T val = inHandle.at({n, (dim_t)ox, (dim_t)oy, z});
1044               if (first || (val >= max_value)) {
1045                 first = false;
1046                 max_value = val;
1047                 if (argmaxW) {
1048                   argmaxNHWC = &inHandle.at({n, (dim_t)ox, (dim_t)oy, z}) -
1049                                &inHandle.raw(0);
1050                 }
1051               }
1052             }
1053           }
1054 
1055           outHandle.at({n, ax, ay, z}) = max_value;
1056 
1057           if (argmaxW) {
1058             (*argmaxH).at({n, ax, ay, z}) = argmaxNHWC;
1059           }
1060         } // W
1061       }   // H
1062     }     // C
1063   }       // N
1064 }
1065 
fwdMaxPoolInst(const MaxPoolInst * I)1066 void BoundInterpreterFunction::fwdMaxPoolInst(const MaxPoolInst *I) {
1067   auto inW = getTensor(I->getSrc());
1068   auto outW = getTensor(I->getDest());
1069 
1070   if (inW->getType().isQuantizedType()) {
1071     dispatchQuantizedImpl(fwdMaxPool, inW->getType().getElementType(), inW,
1072                           outW, nullptr, I->getKernels(), I->getStrides(),
1073                           I->getPads());
1074     return;
1075   }
1076 
1077   dispatchFloatingPointImpl(fwdMaxPool, inW->getType().getElementType(), inW,
1078                             outW, nullptr, I->getKernels(), I->getStrides(),
1079                             I->getPads());
1080 }
1081 
fwdMaxPoolWithArgmaxInst(const MaxPoolWithArgmaxInst * I)1082 void BoundInterpreterFunction::fwdMaxPoolWithArgmaxInst(
1083     const MaxPoolWithArgmaxInst *I) {
1084   auto inW = getTensor(I->getSrc());
1085   auto outW = getTensor(I->getDest());
1086   auto argmaxW = getTensor(I->getArgmax());
1087 
1088   if (inW->getType().isQuantizedType()) {
1089     dispatchQuantizedImpl(fwdMaxPool, inW->getType().getElementType(), inW,
1090                           outW, argmaxW, I->getKernels(), I->getStrides(),
1091                           I->getPads());
1092     return;
1093   }
1094   dispatchFloatingPointImpl(fwdMaxPool, inW->getType().getElementType(), inW,
1095                             outW, argmaxW, I->getKernels(), I->getStrides(),
1096                             I->getPads());
1097 }
1098 
1099 template <typename ElemTy>
fwdAvgPoolInstFloatImpl(const AvgPoolInst * I)1100 void BoundInterpreterFunction::fwdAvgPoolInstFloatImpl(const AvgPoolInst *I) {
1101   staticAssertFloatingPointType(ElemTy);
1102 
1103   ShapeNHWC odim(I->getDest()->dims());
1104   ShapeNHWC idim(I->getSrc()->dims());
1105 
1106   PaddingTLBR pdim(I->getPads());
1107   ShapeHW kdim(I->getKernels());
1108   ShapeHW sdim(I->getStrides());
1109   // Implement the avg pooling operation as defined here:
1110   // https://arxiv.org/abs/1312.4400
1111   float filterArea = kdim.height * kdim.width;
1112 
1113   auto inW = getWeightHandle<ElemTy>(I->getSrc());
1114   auto outW = getWeightHandle<ElemTy>(I->getDest());
1115 
1116   // For each input in the batch:
1117   for (dim_t n = 0; n < odim.n; n++) {
1118     // For each layer in the output tensor:
1119     for (dim_t z = 0; z < idim.c; z++) {
1120       // For each convolution 'jump' in the input tensor:
1121       ssize_t x = -ssize_t(pdim.top);
1122       for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1123         ssize_t y = -ssize_t(pdim.left);
1124         for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1125           float sum = 0;
1126 
1127           for (dim_t fx = 0; fx < kdim.height; fx++) {
1128             for (dim_t fy = 0; fy < kdim.width; fy++) {
1129               sdim_t ox = x + fx;
1130               sdim_t oy = y + fy;
1131 
1132               // Ignore index access below zero (this is due to padding).
1133               if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
1134                   oy >= ssize_t(idim.w)) {
1135                 continue;
1136               }
1137 
1138               sum += float(inW.at({n, (dim_t)ox, (dim_t)oy, z}));
1139             }
1140           }
1141           outW.at({n, ax, ay, z}) = ElemTy(sum / filterArea);
1142         } // W
1143       }   // H
1144     }     // C
1145   }       // N
1146 }
1147 
fwdAvgPoolInstI8Impl(const AvgPoolInst * I)1148 void BoundInterpreterFunction::fwdAvgPoolInstI8Impl(const AvgPoolInst *I) {
1149   ShapeNHWC odim(I->getDest()->dims());
1150   ShapeNHWC idim(I->getSrc()->dims());
1151 
1152   PaddingTLBR pdim(I->getPads());
1153   ShapeHW kdim(I->getKernels());
1154   ShapeHW sdim(I->getStrides());
1155   // Implement the avg pooling operation as defined here:
1156   // https://arxiv.org/abs/1312.4400
1157   float filterArea = kdim.height * kdim.width;
1158 
1159   auto inW = getWeightHandle<int8_t>(I->getSrc());
1160   auto outW = getWeightHandle<int8_t>(I->getDest());
1161   TensorQuantizationParams inQP{I->getSrc()->getType()->getScale(),
1162                                 I->getSrc()->getType()->getOffset()};
1163   TensorQuantizationParams outQP{I->getDest()->getType()->getScale(),
1164                                  I->getDest()->getType()->getOffset()};
1165 
1166   // For each input in the batch:
1167   for (dim_t n = 0; n < odim.n; n++) {
1168     // For each layer in the output tensor:
1169     for (dim_t z = 0; z < idim.c; z++) {
1170       // For each convolution 'jump' in the input tensor:
1171       ssize_t x = -ssize_t(pdim.top);
1172       for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1173         ssize_t y = -ssize_t(pdim.left);
1174         for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1175           int32_t sum = 0;
1176 
1177           for (dim_t fx = 0; fx < kdim.height; fx++) {
1178             for (dim_t fy = 0; fy < kdim.width; fy++) {
1179               sdim_t ox = x + fx;
1180               sdim_t oy = y + fy;
1181 
1182               // Ignore index access below zero (this is due to padding).
1183               if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
1184                   oy >= ssize_t(idim.w)) {
1185                 continue;
1186               }
1187 
1188               sum += inW.at({n, (dim_t)ox, (dim_t)oy, z}) - inQP.offset;
1189             }
1190           }
1191           // Instead of dividing by filterArea, just change scale.
1192           outW.at({n, ax, ay, z}) = quantization::clip<int32_t, int8_t>(
1193               std::round(float(sum) * (inQP.scale / outQP.scale / filterArea) +
1194                          outQP.offset));
1195         } // W
1196       }   // H
1197     }     // C
1198   }       // N
1199 }
1200 
1201 template <typename ElemTy>
fwdAvgPool3DInstFloatImpl(const AvgPoolInst * I)1202 void BoundInterpreterFunction::fwdAvgPool3DInstFloatImpl(const AvgPoolInst *I) {
1203   staticAssertFloatingPointType(ElemTy);
1204 
1205   ShapeNTHWC odim(I->getDest()->dims());
1206   ShapeNTHWC idim(I->getSrc()->dims());
1207 
1208   PaddingNFTBLR pdim(I->getPads());
1209   ShapeTHW kdim(I->getKernels());
1210   ShapeTHW sdim(I->getStrides());
1211   // Implement the avg pooling operation as defined here:
1212   // https://arxiv.org/abs/1312.4400
1213   float filterArea = kdim.temporal_frames * kdim.height * kdim.width;
1214 
1215   auto inW = getWeightHandle<ElemTy>(I->getSrc());
1216   auto outW = getWeightHandle<ElemTy>(I->getDest());
1217 
1218   // For each input in the batch:
1219   for (dim_t n = 0; n < odim.n; n++) {
1220     // For each layer in the output tensor:
1221     for (dim_t z = 0; z < idim.c; z++) {
1222       // For each convolution 'jump' in the input tensor:
1223       ssize_t t = -ssize_t(pdim.near);
1224       for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
1225         ssize_t x = -ssize_t(pdim.top);
1226         for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1227           ssize_t y = -ssize_t(pdim.left);
1228           for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1229             float sum = 0;
1230 
1231             for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
1232               for (dim_t fx = 0; fx < kdim.height; fx++) {
1233                 for (dim_t fy = 0; fy < kdim.width; fy++) {
1234                   sdim_t ot = t + ft;
1235                   sdim_t ox = x + fx;
1236                   sdim_t oy = y + fy;
1237 
1238                   // Ignore index access below zero (this is due to padding).
1239                   if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
1240                       ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
1241                     continue;
1242                   }
1243 
1244                   sum += float(inW.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy, z}));
1245                 }
1246               }
1247               outW.at({n, at, ax, ay, z}) = ElemTy(sum / filterArea);
1248             }
1249           } // W
1250         }   // H
1251       }     // T
1252     }       // C
1253   }         // N
1254 }
1255 
fwdAvgPool3DInstI8Impl(const AvgPoolInst * I)1256 void BoundInterpreterFunction::fwdAvgPool3DInstI8Impl(const AvgPoolInst *I) {
1257   ShapeNTHWC odim(I->getDest()->dims());
1258   ShapeNTHWC idim(I->getSrc()->dims());
1259 
1260   PaddingNFTBLR pdim(I->getPads());
1261   ShapeTHW kdim(I->getKernels());
1262   ShapeTHW sdim(I->getStrides());
1263   // Implement the avg pooling operation as defined here:
1264   // https://arxiv.org/abs/1312.4400
1265   float filterArea = kdim.temporal_frames * kdim.height * kdim.width;
1266 
1267   auto inW = getWeightHandle<int8_t>(I->getSrc());
1268   auto outW = getWeightHandle<int8_t>(I->getDest());
1269   TensorQuantizationParams inQP{I->getSrc()->getType()->getScale(),
1270                                 I->getSrc()->getType()->getOffset()};
1271   TensorQuantizationParams outQP{I->getDest()->getType()->getScale(),
1272                                  I->getDest()->getType()->getOffset()};
1273 
1274   // For each input in the batch:
1275   for (dim_t n = 0; n < odim.n; n++) {
1276     // For each layer in the output tensor:
1277     for (dim_t z = 0; z < idim.c; z++) {
1278       // For each convolution 'jump' in the input tensor:
1279       ssize_t t = -ssize_t(pdim.near);
1280       for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
1281         ssize_t x = -ssize_t(pdim.top);
1282         for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1283           ssize_t y = -ssize_t(pdim.left);
1284           for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1285             int32_t sum = 0;
1286 
1287             for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
1288               for (dim_t fx = 0; fx < kdim.height; fx++) {
1289                 for (dim_t fy = 0; fy < kdim.width; fy++) {
1290                   sdim_t ot = t + ft;
1291                   sdim_t ox = x + fx;
1292                   sdim_t oy = y + fy;
1293 
1294                   // Ignore index access below zero (this is due to padding).
1295                   if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
1296                       ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
1297                     continue;
1298                   }
1299 
1300                   sum += inW.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy, z}) -
1301                          inQP.offset;
1302                 }
1303               }
1304             }
1305             // Instead of dividing by filterArea, just change scale.
1306             outW.at({n, at, ax, ay, z}) =
1307                 quantization::clip<int32_t, int8_t>(std::round(
1308                     float(sum) * (inQP.scale / outQP.scale / filterArea) +
1309                     outQP.offset));
1310           } // W
1311         }   // H
1312       }     // T
1313     }       // C
1314   }         // N
1315 }
1316 
fwdAvgPoolInst(const AvgPoolInst * I)1317 void BoundInterpreterFunction::fwdAvgPoolInst(const AvgPoolInst *I) {
1318   bool isConv3D = is3DData(ConvolutionLayout(I->getLayout()));
1319   bool isQuantized = I->getSrc()->getType()->isQuantizedType();
1320 
1321   if (isConv3D) {
1322     if (isQuantized) {
1323       fwdAvgPool3DInstI8Impl(I);
1324     } else {
1325       dispatchFloatingPointImpl(fwdAvgPool3DInstFloatImpl,
1326                                 I->getSrc()->getElementType(), I);
1327     }
1328   } else {
1329     if (isQuantized) {
1330       fwdAvgPoolInstI8Impl(I);
1331     } else {
1332       dispatchFloatingPointImpl(fwdAvgPoolInstFloatImpl,
1333                                 I->getSrc()->getElementType(), I);
1334     }
1335   }
1336 }
1337 
1338 template <typename ElemTy>
fwdAdaptiveAvgPoolInstFloatImpl(const AdaptiveAvgPoolInst * I)1339 void BoundInterpreterFunction::fwdAdaptiveAvgPoolInstFloatImpl(
1340     const AdaptiveAvgPoolInst *I) {
1341   staticAssertFloatingPointType(ElemTy);
1342 
1343   ShapeNHWC odim(I->getDest()->dims());
1344   ShapeNHWC idim(I->getSrc()->dims());
1345 
1346   auto inW = getWeightHandle<ElemTy>(I->getSrc());
1347   auto outW = getWeightHandle<ElemTy>(I->getDest());
1348 
1349 // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
1350 #define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b))
1351 #define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b))
1352 
1353   // For each input in the batch:
1354   for (dim_t n = 0; n < odim.n; n++) {
1355     // For each layer in the output tensor:
1356     for (dim_t z = 0; z < idim.c; z++) {
1357       // For each value in the output tensor:
1358       for (dim_t ax = 0; ax < odim.h; ax++) {
1359 
1360         dim_t x = START_IND(ax, odim.h, idim.h);
1361         dim_t kH = END_IND(ax, odim.h, idim.h) - x;
1362 
1363         for (dim_t ay = 0; ay < odim.w; ay++) {
1364 
1365           dim_t y = START_IND(ay, odim.w, idim.w);
1366           dim_t kW = END_IND(ay, odim.w, idim.w) - y;
1367 
1368           float sum = 0;
1369           for (dim_t fx = 0; fx < kH; fx++) {
1370             for (dim_t fy = 0; fy < kW; fy++) {
1371               dim_t ox = x + fx;
1372               dim_t oy = y + fy;
1373 
1374               sum += float(inW.at({n, ox, oy, z}));
1375             }
1376           }
1377           outW.at({n, ax, ay, z}) = ElemTy(sum / kW / kH);
1378         } // W
1379       }   // H
1380     }     // C
1381   }       // N
1382 #undef START_IND
1383 #undef END_IND
1384 }
1385 
fwdAdaptiveAvgPoolInstI8Impl(const AdaptiveAvgPoolInst * I)1386 void BoundInterpreterFunction::fwdAdaptiveAvgPoolInstI8Impl(
1387     const AdaptiveAvgPoolInst *I) {
1388   ShapeNHWC odim(I->getDest()->dims());
1389   ShapeNHWC idim(I->getSrc()->dims());
1390 
1391   auto inW = getWeightHandle<int8_t>(I->getSrc());
1392   auto outW = getWeightHandle<int8_t>(I->getDest());
1393 
1394   TensorQuantizationParams inQP{I->getSrc()->getType()->getScale(),
1395                                 I->getSrc()->getType()->getOffset()};
1396   TensorQuantizationParams outQP{I->getDest()->getType()->getScale(),
1397                                  I->getDest()->getType()->getOffset()};
1398 
1399 // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
1400 #define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b))
1401 #define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b))
1402 
1403   // For each input in the batch:
1404   for (dim_t n = 0; n < odim.n; n++) {
1405     // For each layer in the output tensor:
1406     for (dim_t z = 0; z < idim.c; z++) {
1407       // For each value in the output tensor:
1408       for (dim_t ax = 0; ax < odim.h; ax++) {
1409 
1410         dim_t x = START_IND(ax, odim.h, idim.h);
1411         dim_t kH = END_IND(ax, odim.h, idim.h) - x;
1412 
1413         for (dim_t ay = 0; ay < odim.w; ay++) {
1414 
1415           dim_t y = START_IND(ay, odim.w, idim.w);
1416           dim_t kW = END_IND(ay, odim.w, idim.w) - y;
1417 
1418           int32_t sum = 0;
1419           for (dim_t fx = 0; fx < kH; fx++) {
1420             for (dim_t fy = 0; fy < kW; fy++) {
1421               dim_t ox = x + fx;
1422               dim_t oy = y + fy;
1423 
1424               sum += inW.at({n, ox, oy, z}) - inQP.offset;
1425             }
1426           }
1427 
1428           outW.at({n, ax, ay, z}) = quantization::clip<int32_t, int8_t>(
1429               std::round(float(sum) * (inQP.scale / outQP.scale / kW / kH) +
1430                          outQP.offset));
1431         } // W
1432       }   // H
1433     }     // C
1434   }       // N
1435 #undef START_IND
1436 #undef END_IND
1437 }
1438 
fwdAdaptiveAvgPoolInst(const AdaptiveAvgPoolInst * I)1439 void BoundInterpreterFunction::fwdAdaptiveAvgPoolInst(
1440     const AdaptiveAvgPoolInst *I) {
1441   if (I->getSrc()->getType()->isQuantizedType()) {
1442     fwdAdaptiveAvgPoolInstI8Impl(I);
1443     return;
1444   }
1445 
1446   dispatchFloatingPointImpl(fwdAdaptiveAvgPoolInstFloatImpl,
1447                             I->getSrc()->getElementType(), I);
1448 }
1449 
fwdAdaptiveAvgPoolGradInst(const AdaptiveAvgPoolGradInst * I)1450 void BoundInterpreterFunction::fwdAdaptiveAvgPoolGradInst(
1451     const AdaptiveAvgPoolGradInst *I) {
1452   auto inG = getWeightHandle(I->getSrcGrad());
1453   auto outW = getWeightHandle(I->getDest());
1454   auto outG = getWeightHandle(I->getDestGrad());
1455 
1456   inG.clear();
1457 
1458   ShapeNHWC odim(outW.dims());
1459   ShapeNHWC idim(inG.dims());
1460 
1461   const float gradCoefficient = 1. / (odim.h * odim.w);
1462 
1463 #define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b))
1464 #define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b))
1465 
1466   // https://software.intel.com/en-us/daal-programming-guide-2d-average-pooling-backward-layer
1467   // For each input in the batch:
1468   for (dim_t n = 0; n < odim.n; n++) {
1469     // For each layer in the output tensor:
1470     for (dim_t z = 0; z < idim.c; z++) {
1471       // For each value in the output tensor:
1472       for (dim_t ax = 0; ax < odim.h; ax++) {
1473 
1474         dim_t x = START_IND(ax, odim.h, idim.h);
1475         dim_t kH = END_IND(ax, odim.h, idim.h) - x;
1476 
1477         for (dim_t ay = 0; ay < odim.w; ay++) {
1478 
1479           dim_t y = START_IND(ay, odim.w, idim.w);
1480           dim_t kW = END_IND(ay, odim.w, idim.w) - y;
1481 
1482           const float chainGrad = outG.at({n, ax, ay, z}) * gradCoefficient;
1483 
1484           for (dim_t fx = 0; fx < kH; fx++) {
1485             for (dim_t fy = 0; fy < kW; fy++) {
1486               dim_t ox = x + fx;
1487               dim_t oy = y + fy;
1488 
1489               inG.at({n, ox, oy, z}) += chainGrad;
1490             }
1491           }
1492         } // W
1493       }   // H
1494     }     // C
1495   }       // N
1496 #undef START_IND
1497 #undef END_IND
1498 }
1499 
fwdMaxPoolWithArgmaxGradInst(const MaxPoolWithArgmaxGradInst * I)1500 void BoundInterpreterFunction::fwdMaxPoolWithArgmaxGradInst(
1501     const MaxPoolWithArgmaxGradInst *I) {
1502   auto inG = getWeightHandle(I->getSrcGrad());
1503   auto outW = getWeightHandle(I->getDest());
1504   auto outG = getWeightHandle(I->getDestGrad());
1505 
1506   inG.clear();
1507 
1508   ShapeNHWC idim(inG.dims());
1509   ShapeNHWC odim(outW.dims());
1510 
1511   auto argmax = getWeightHandle<int64_t>(I->getArgmax());
1512 
1513   // For each input in the batch:
1514   for (dim_t n = 0; n < odim.n; n++) {
1515 
1516     // Compute the gradient. For each layer in the output tensor:
1517     for (dim_t z = 0; z < odim.c; z++) {
1518 
1519       // For each convolution 'jump' in the input tensor:
1520       for (dim_t ax = 0; ax < odim.h; ax++) {
1521         for (dim_t ay = 0; ay < odim.w; ay++) {
1522           // Reuse precomputed linear index of max element from argmax.
1523           float chainGrad = outG.at({n, ax, ay, z});
1524           inG.raw(argmax.at({n, ax, ay, z})) += chainGrad;
1525         } // W
1526       }   // H
1527     }     // C
1528   }       // N
1529 }
1530 
fwdAvgPool2DGradInst(const AvgPoolGradInst * I)1531 void BoundInterpreterFunction::fwdAvgPool2DGradInst(const AvgPoolGradInst *I) {
1532   auto inG = getWeightHandle(I->getSrcGrad());
1533   auto outW = getWeightHandle(I->getDest());
1534   auto outG = getWeightHandle(I->getDestGrad());
1535 
1536   ShapeNHWC odim(outW.dims());
1537   ShapeNHWC idim(inG.dims());
1538 
1539   PaddingTLBR pdim(I->getPads());
1540   ShapeHW kdim(I->getKernels());
1541   ShapeHW sdim(I->getStrides());
1542 
1543   inG.clear();
1544 
1545   float filterArea = kdim.height * kdim.width;
1546 
1547   // For each input in the batch:
1548   for (dim_t n = 0; n < odim.n; n++) {
1549 
1550     // For each layer in the output tensor:
1551     for (dim_t z = 0; z < odim.c; z++) {
1552       // For each convolution 'jump' in the input tensor:
1553       ssize_t x = -ssize_t(pdim.top);
1554       for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1555         ssize_t y = -ssize_t(pdim.left);
1556         for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1557 
1558           float dy = outG.at({n, ax, ay, z}) / filterArea;
1559 
1560           for (dim_t fx = 0; fx < kdim.height; fx++) {
1561             for (dim_t fy = 0; fy < kdim.width; fy++) {
1562               ssize_t ox = x + fx;
1563               ssize_t oy = y + fy;
1564 
1565               // Ignore index access below zero (this is due to padding).
1566               if (ox < 0 || oy < 0 || ox >= ssize_t(idim.h) ||
1567                   oy >= ssize_t(idim.w)) {
1568                 continue;
1569               }
1570               inG.at({n, (dim_t)ox, (dim_t)oy, z}) += dy;
1571             }
1572           }
1573         } // W
1574       }   // H
1575     }     // C
1576   }       // N
1577 }
1578 
fwdAvgPool3DGradInst(const AvgPoolGradInst * I)1579 void BoundInterpreterFunction::fwdAvgPool3DGradInst(const AvgPoolGradInst *I) {
1580   auto inG = getWeightHandle(I->getSrcGrad());
1581   auto outW = getWeightHandle(I->getDest());
1582   auto outG = getWeightHandle(I->getDestGrad());
1583 
1584   ShapeNTHWC odim(outW.dims());
1585   ShapeNTHWC idim(inG.dims());
1586 
1587   PaddingNFTBLR pdim(I->getPads());
1588   ShapeTHW kdim(I->getKernels());
1589   ShapeTHW sdim(I->getStrides());
1590 
1591   inG.clear();
1592 
1593   float filterArea = kdim.temporal_frames * kdim.height * kdim.width;
1594 
1595   // For each input in the batch:
1596   for (dim_t n = 0; n < odim.n; n++) {
1597 
1598     // For each layer in the output tensor:
1599     for (dim_t z = 0; z < odim.c; z++) {
1600       // For each convolution 'jump' in the input tensor:
1601       ssize_t t = -ssize_t(pdim.near);
1602       for (dim_t at = 0; at < odim.t; t += sdim.temporal_frames, at++) {
1603         ssize_t x = -ssize_t(pdim.top);
1604         for (dim_t ax = 0; ax < odim.h; x += sdim.height, ax++) {
1605           ssize_t y = -ssize_t(pdim.left);
1606           for (dim_t ay = 0; ay < odim.w; y += sdim.width, ay++) {
1607 
1608             float dy = outG.at({n, at, ax, ay, z}) / filterArea;
1609 
1610             for (dim_t ft = 0; ft < kdim.temporal_frames; ft++) {
1611               for (dim_t fx = 0; fx < kdim.height; fx++) {
1612                 for (dim_t fy = 0; fy < kdim.width; fy++) {
1613                   ssize_t ot = t + ft;
1614                   ssize_t ox = x + fx;
1615                   ssize_t oy = y + fy;
1616 
1617                   // Ignore index access below zero (this is due to padding).
1618                   if (ot < 0 || ox < 0 || oy < 0 || ot >= ssize_t(idim.t) ||
1619                       ox >= ssize_t(idim.h) || oy >= ssize_t(idim.w)) {
1620                     continue;
1621                   }
1622                   inG.at({n, (dim_t)ot, (dim_t)ox, (dim_t)oy, z}) += dy;
1623                 }
1624               }
1625             }
1626           } // W
1627         }   // H
1628       }     // T
1629     }       // C
1630   }         // N
1631 }
1632 
fwdAvgPoolGradInst(const AvgPoolGradInst * I)1633 void BoundInterpreterFunction::fwdAvgPoolGradInst(const AvgPoolGradInst *I) {
1634   bool isConv3D = is3DData(ConvolutionLayout(I->getLayout()));
1635 
1636   if (isConv3D) {
1637     fwdAvgPool3DGradInst(I);
1638   } else {
1639     fwdAvgPool2DGradInst(I);
1640   }
1641 }
1642 
1643 //===----------------------------------------------------------------------===//
1644 //                       Activation functions
1645 //===----------------------------------------------------------------------===//
1646 
fwdReluInst(const ReluInst *)1647 void BoundInterpreterFunction::fwdReluInst(const ReluInst *) {
1648   DCHECK(!"Found ReluInst but Relu is lowered on Interpreter");
1649 }
1650 
fwdClipInst(const ClipInst *)1651 void BoundInterpreterFunction::fwdClipInst(const ClipInst *) {
1652   DCHECK(!"Found ClipInst but Clip is lowered on Interpreter");
1653 }
1654 
1655 template <typename ElemTy>
fwdSigmoidInstFloatImpl(const SigmoidInst * I)1656 void BoundInterpreterFunction::fwdSigmoidInstFloatImpl(const SigmoidInst *I) {
1657   staticAssertFloatingPointType(ElemTy);
1658 
1659   auto inW = getWeightHandle<ElemTy>(I->getSrc());
1660   auto outW = getWeightHandle<ElemTy>(I->getDest());
1661 
1662   for (dim_t i = 0, e = outW.size(); i < e; i++) {
1663     float val = inW.raw(i);
1664     outW.raw(i) = ElemTy(1 / (1 + std::exp(-val)));
1665   }
1666 }
1667 
fwdSigmoidInst(const SigmoidInst * I)1668 void BoundInterpreterFunction::fwdSigmoidInst(const SigmoidInst *I) {
1669   dispatchFloatingPointImpl(fwdSigmoidInstFloatImpl,
1670                             I->getSrc()->getElementType(), I);
1671 }
1672 
1673 template <typename ElemTy>
fwdTanhInstFloatImpl(const TanhInst * I)1674 void BoundInterpreterFunction::fwdTanhInstFloatImpl(const TanhInst *I) {
1675   staticAssertFloatingPointType(ElemTy);
1676 
1677   auto inW = getWeightHandle<ElemTy>(I->getSrc());
1678   auto outW = getWeightHandle<ElemTy>(I->getDest());
1679 
1680   for (dim_t i = 0, e = inW.size(); i < e; i++) {
1681     float val = inW.raw(i);
1682     outW.raw(i) = ElemTy(std::tanh(val));
1683   }
1684 }
1685 
fwdTanhInst(const TanhInst * I)1686 void BoundInterpreterFunction::fwdTanhInst(const TanhInst *I) {
1687   dispatchFloatingPointImpl(fwdTanhInstFloatImpl, I->getSrc()->getElementType(),
1688                             I);
1689 }
1690 
1691 //===----------------------------------------------------------------------===//
1692 //                        Loss Functions (Softmax/regression/...)
1693 //===----------------------------------------------------------------------===//
1694 
1695 template <typename ElemTy>
fwdSoftMaxInstImpl(const SoftMaxInst * I)1696 void BoundInterpreterFunction::fwdSoftMaxInstImpl(const SoftMaxInst *I) {
1697   staticAssertFloatingPointType(ElemTy);
1698 
1699   auto inW = getWeightHandle<ElemTy>(I->getSrc());
1700   auto outW = getWeightHandle<ElemTy>(I->getDest());
1701   auto idim = inW.dims();
1702 
1703   for (dim_t n = 0; n < idim[0]; n++) {
1704     // Find Max.
1705     float max = float(inW.at({n, 0}));
1706     for (dim_t i = 1; i < idim[1]; i++) {
1707       max = std::max(max, float(inW.at({n, i})));
1708     }
1709 
1710     // Compute exp.
1711     float sum = 0;
1712     for (dim_t i = 0; i < idim[1]; i++) {
1713       float e = std::exp(float(inW.at({n, i})) - max);
1714       sum += e;
1715       outW.at({n, i}) = ElemTy(e);
1716     }
1717 
1718     // Normalize the output.
1719     for (dim_t i = 0; i < idim[1]; i++) {
1720       outW.at({n, i}) = ElemTy(float(outW.at({n, i})) / sum);
1721     }
1722   } // N
1723 }
1724 
fwdSoftMaxInst(const SoftMaxInst * I)1725 void BoundInterpreterFunction::fwdSoftMaxInst(const SoftMaxInst *I) {
1726   dispatchFloatingPointImpl(fwdSoftMaxInstImpl, I->getSrc()->getElementType(),
1727                             I);
1728 }
1729 
fwdSoftMaxGradInst(const SoftMaxGradInst * I)1730 void BoundInterpreterFunction::fwdSoftMaxGradInst(const SoftMaxGradInst *I) {
1731   auto inG = getWeightHandle(I->getSrcGrad());
1732   auto idim = inG.dims();
1733   auto outW = getWeightHandle(I->getOrigDest());
1734   auto selectedH = getWeightHandle<int64_t>(I->getSelected());
1735 
1736   inG.clear();
1737 
1738   // http://eli.thegreenplace.net/2016/the-softmax-function-and-its-derivative/
1739   // https://stats.stackexchange.com/questions/79454/softmax-layer-in-a-neural-network
1740   for (dim_t n = 0; n < idim[0]; n++) {
1741     for (dim_t i = 0; i < idim[1]; i++) {
1742       float delta = (selectedH.at({n, 0}) == (int64_t)i);
1743       inG.at({n, i}) = outW.at({n, i}) - delta;
1744     }
1745   }
1746 }
1747 
1748 template <typename ElemTy>
fwdCrossEntropyLossInstFloatImpl(const CrossEntropyLossInst * I)1749 void BoundInterpreterFunction::fwdCrossEntropyLossInstFloatImpl(
1750     const CrossEntropyLossInst *I) {
1751   staticAssertFloatingPointType(ElemTy);
1752 
1753   auto P = getWeightHandle<ElemTy>(I->getP());
1754   auto labels = getWeightHandle<int64_t>(I->getLabels());
1755   auto CE = getWeightHandle<ElemTy>(I->getCE());
1756   auto dims = P.dims();
1757   CE.clear();
1758   for (dim_t n = 0; n < dims[0]; ++n) {
1759     assert(labels.raw(n) >= 0 && "Cannot use negative index.");
1760     dim_t y = labels.raw(n);
1761     float p_n = P.at({n, y});
1762     CE.at({0}) -= log(p_n);
1763   }
1764 }
1765 
fwdCrossEntropyLossInst(const CrossEntropyLossInst * I)1766 void BoundInterpreterFunction::fwdCrossEntropyLossInst(
1767     const CrossEntropyLossInst *I) {
1768   dispatchFloatingPointImpl(fwdCrossEntropyLossInstFloatImpl,
1769                             I->getP()->getElementType(), I);
1770 }
1771 
fwdCrossEntropyLossGradInst(const CrossEntropyLossGradInst * I)1772 void BoundInterpreterFunction::fwdCrossEntropyLossGradInst(
1773     const CrossEntropyLossGradInst *I) {
1774   auto P = getWeightHandle(I->getP());
1775   auto Labels = getWeightHandle<int64_t>(I->getLabels());
1776   auto PGrad = getWeightHandle(I->getPgrad());
1777   auto dims = PGrad.dims();
1778   PGrad.clear();
1779   for (dim_t n = 0; n < dims[0]; ++n) {
1780     assert(Labels.raw(n) >= 0 && "Cannot use negative index.");
1781     dim_t y = Labels.raw(n);
1782     PGrad.at({n, y}) = -1 / P.at({n, y}); // * CEGrad.at({0})
1783   }
1784 }
1785 
1786 //===----------------------------------------------------------------------===//
1787 //                       Tensor shape (copy/transpose/concat/...)
1788 //===----------------------------------------------------------------------===//
1789 
fwdCopyInst(const CopyInst * I)1790 void BoundInterpreterFunction::fwdCopyInst(const CopyInst *I) {
1791   auto inT = getTensor(I->getSrc());
1792   auto outT = getTensor(I->getDest());
1793   outT->copyRawFrom(inT);
1794 }
1795 
fwdTransposeInst(const TransposeInst * I)1796 void BoundInterpreterFunction::fwdTransposeInst(const TransposeInst *I) {
1797   auto inT = getTensor(I->getSrc());
1798   (void)inT;
1799   auto outT = getTensor(I->getDest());
1800 
1801   assert(outT->size() == inT->size() && "Invalid tensor dimensions");
1802 
1803   if (I->getSrc()->getType()->isQuantizedType()) {
1804     inT->transpose(outT, I->getShuffle());
1805   } else {
1806     inT->transpose(outT, I->getShuffle());
1807   }
1808 }
1809 
fwdTensorViewInst(const TensorViewInst * I)1810 void BoundInterpreterFunction::fwdTensorViewInst(const TensorViewInst *I) {
1811   getOrCreateUnownedTensor(I, I->getSrc(), I->getOffsets());
1812 }
1813 
fwdSplatInst(const glow::SplatInst * I)1814 void BoundInterpreterFunction::fwdSplatInst(const glow::SplatInst *I) {
1815   auto *T = getTensor(I->getDest());
1816   ElemKind k = T->getElementType();
1817 
1818   if (k == ElemKind::Int32ITy) {
1819     return T->getHandle<int32_t>().clear(I->getValue());
1820   }
1821 
1822   if (k == ElemKind::Int64ITy) {
1823     return T->getHandle<int64_t>().clear(I->getValue());
1824   }
1825 
1826   if (k == ElemKind::Int32ITy) {
1827     return T->getHandle<int32_t>().clear(I->getValue());
1828   }
1829 
1830   if (k == ElemKind::FloatTy) {
1831     return T->getHandle<float>().clear(I->getValue());
1832   }
1833 
1834   if (k == ElemKind::Float16Ty) {
1835     return T->getHandle<float16_t>().clear(I->getValue());
1836   }
1837 
1838   if (k == ElemKind::BFloat16Ty) {
1839     return T->getHandle<bfloat16_t>().clear(I->getValue());
1840   }
1841 
1842   if (k == ElemKind::BoolTy) {
1843     return T->getHandle<bool>().clear(static_cast<bool>(I->getValue()));
1844   }
1845 
1846   if (k == ElemKind::Int8QTy) {
1847     // Quantize the requested floating point splat value into the correct
1848     // integer representation.
1849     auto destTy = I->getDest()->getType();
1850     TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};
1851     float val = I->getValue();
1852     return T->getHandle<int8_t>().clear(quantization::quantize(val, destQ));
1853   }
1854 
1855   if (k == ElemKind::BoolTy) {
1856     return T->getHandle<bool>().clear(static_cast<bool>(I->getValue()));
1857   }
1858 
1859   llvm_unreachable("Unsupported tensor type");
1860 }
1861 
fwdTouchInst(const glow::TouchInst *)1862 void BoundInterpreterFunction::fwdTouchInst(const glow::TouchInst *) {
1863   // Do nothing for a TouchInst
1864 }
1865 
fwdInsertTensorInst(const glow::InsertTensorInst * I)1866 void BoundInterpreterFunction::fwdInsertTensorInst(
1867     const glow::InsertTensorInst *I) {
1868   Tensor *outT = getTensor(I->getDest());
1869   Tensor *inT = getTensor(I->getSrc());
1870   ElemKind k = outT->getElementType();
1871 #define TYPED_INSERT(TY, TYPEKIND)                                             \
1872   if (k == TYPEKIND) {                                                         \
1873     auto OH = outT->getHandle<TY>();                                           \
1874     auto IH = inT->getHandle<TY>();                                            \
1875     return OH.insertTensors(IH, I->getOffsets(), I->getCount(), I->getAxis()); \
1876   }
1877 
1878   TYPED_INSERT(int64_t, ElemKind::Int64ITy);
1879   TYPED_INSERT(int32_t, ElemKind::Int32ITy);
1880   TYPED_INSERT(float, ElemKind::FloatTy);
1881   TYPED_INSERT(float16_t, ElemKind::Float16Ty);
1882   TYPED_INSERT(bfloat16_t, ElemKind::BFloat16Ty);
1883   TYPED_INSERT(int8_t, ElemKind::Int8QTy);
1884   TYPED_INSERT(bool, ElemKind::BoolTy);
1885 #undef TYPED_INSERT
1886 
1887   llvm_unreachable("Unsupported tensor type");
1888 }
1889 
fwdExtractTensorInst(const glow::ExtractTensorInst * I)1890 void BoundInterpreterFunction::fwdExtractTensorInst(
1891     const glow::ExtractTensorInst *I) {
1892   Tensor *outT = getTensor(I->getDest());
1893   Tensor *inT = getTensor(I->getSrc());
1894   ElemKind k = outT->getElementType();
1895 #define TYPED_INSERT(TY, TYPEKIND)                                             \
1896   if (k == TYPEKIND) {                                                         \
1897     auto OH = outT->getHandle<TY>();                                           \
1898     auto IH = inT->getHandle<TY>();                                            \
1899     return IH.extractTensors(OH, I->getOffsets());                             \
1900   }
1901 
1902   TYPED_INSERT(int64_t, ElemKind::Int64ITy);
1903   TYPED_INSERT(float, ElemKind::FloatTy);
1904   TYPED_INSERT(float16_t, ElemKind::Float16Ty);
1905   TYPED_INSERT(bfloat16_t, ElemKind::BFloat16Ty);
1906   TYPED_INSERT(int8_t, ElemKind::Int8QTy);
1907   TYPED_INSERT(int32_t, ElemKind::Int32QTy);
1908   TYPED_INSERT(int32_t, ElemKind::Int32ITy);
1909 #undef TYPED_INSERT
1910 
1911   llvm_unreachable("Unsupported tensor type");
1912 }
1913 
1914 template <typename ElemTy>
fwdGatherInstImpl(const glow::GatherInst * I)1915 void BoundInterpreterFunction::fwdGatherInstImpl(const glow::GatherInst *I) {
1916   Tensor *dataT = getTensor(I->getData());
1917   auto &dataTy = dataT->getType();
1918   Tensor *indicesT = getTensor(I->getIndices());
1919   Tensor *outT = getTensor(I->getDest());
1920   unsigned_t batchDims = I->getBatchDims();
1921 
1922   size_t out_p = 0;
1923   dim_t elementSize = dataTy.getElementSize();
1924   // The size of the sample in the batch.
1925   dim_t dataSampleSize = dataTy.getSliceSize(batchDims) * elementSize;
1926   // The size of the slices that we gather.
1927   dim_t dataSliceSize = dataTy.getSliceSize(batchDims + 1) * elementSize;
1928 
1929   // Calculate the size of each sample in the batch.
1930   dim_t numSamples = (dataT->size() * elementSize) / dataSampleSize;
1931 
1932   // Calculate number of samples in the batch.
1933   dim_t batchSize = dataTy.dims()[batchDims];
1934   (void)batchSize;
1935 
1936   // For each sample in the batch:
1937   for (dim_t sample = 0; sample < numSamples; sample++) {
1938     dim_t sampleStart = sample * dataSampleSize;
1939 
1940     // For each slice (small fragment) that we copy from the source memory:
1941     for (dim_t i = 0, end = indicesT->size(); i < end; i++) {
1942       dim_t slice = indicesT->getHandle<ElemTy>().raw(i);
1943       assert(slice < batchSize && "Invalid index seen during Gather operation");
1944       std::copy(
1945           &dataT->getUnsafePtr()[sampleStart + dataSliceSize * slice],
1946           &dataT->getUnsafePtr()[sampleStart + dataSliceSize * (slice + 1)],
1947           &outT->getUnsafePtr()[out_p]);
1948       out_p += dataSliceSize;
1949     }
1950   }
1951 }
1952 
fwdGatherInst(const glow::GatherInst * I)1953 void BoundInterpreterFunction::fwdGatherInst(const glow::GatherInst *I) {
1954   switch (I->getIndices()->getElementType()) {
1955   case ElemKind::Int64ITy:
1956     fwdGatherInstImpl<int64_t>(I);
1957     break;
1958   case ElemKind::Int32ITy:
1959     fwdGatherInstImpl<int32_t>(I);
1960     break;
1961   default:
1962     llvm_unreachable("Unsupported type for indices input of Gather.");
1963   }
1964 }
1965 
1966 template <typename ElemTy>
fwdGatherRangesInstImpl(const glow::GatherRangesInst * I)1967 void BoundInterpreterFunction::fwdGatherRangesInstImpl(
1968     const glow::GatherRangesInst *I) {
1969   Tensor *dataT = getTensor(I->getData());
1970   auto &dataTy = dataT->getType();
1971   Tensor *rangesT = getTensor(I->getRanges());
1972   auto &rangesTy = rangesT->getType();
1973   Tensor *outT = getTensor(I->getOutput());
1974   Tensor *lengthsT = getTensor(I->getLengths());
1975 
1976   // Offset into the output tensor that keeps track of where to start
1977   // copying data.
1978   size_t outP = 0;
1979 
1980   unsigned dataElementSize = dataTy.getElementSize();
1981   dim_t numExamples = rangesTy.dims()[0];
1982   dim_t exampleSize = rangesTy.dims()[1];
1983 
1984   // Keep track of the total number of elements gathered across all
1985   // examples for a sanity check later.
1986   dim_t grandTotalLen = 0;
1987 
1988   // For each example in ranges:
1989   for (dim_t example = 0; example < numExamples; ++example) {
1990     // Keep a running total of the lengths of all ranges in this example
1991     // to record into lengthsT once the entire example is processed.
1992     ElemTy totalLen = 0;
1993 
1994     // For each range in the example:
1995     for (dim_t range = 0; range < exampleSize; ++range) {
1996       // Get the start index and range length.
1997       ElemTy startIdx = rangesT->getHandle<ElemTy>().at({example, range, 0});
1998       ElemTy len = rangesT->getHandle<ElemTy>().at({example, range, 1});
1999 
2000       // Add the length of this current range to the example length counter.
2001       totalLen += len;
2002 
2003       // Compute the start and end offsets.
2004       dim_t startOffset = startIdx * dataElementSize;
2005       dim_t endOffset = startOffset + (len * dataElementSize);
2006 
2007       // Sanity checks on the offsets.
2008       assert(startOffset < dataT->getSizeInBytes());
2009       assert(endOffset <= dataT->getSizeInBytes());
2010       assert(endOffset >= startOffset);
2011       assert(outP < outT->getSizeInBytes());
2012       assert((outP + (len * dataElementSize)) <= outT->getSizeInBytes());
2013 
2014       // Copy the specified data to outT.
2015       std::copy(&dataT->getUnsafePtr()[startOffset],
2016                 &dataT->getUnsafePtr()[endOffset], &outT->getUnsafePtr()[outP]);
2017 
2018       // Advance the offset into outT.
2019       outP += len * dataElementSize;
2020     }
2021 
2022     // Record the total number of elements gathered for the example in
2023     // lengthsT.
2024     lengthsT->getHandle<ElemTy>().at({example}) = totalLen;
2025 
2026     // Add the total length of the entire example to the grand total.
2027     grandTotalLen += static_cast<size_t>(totalLen);
2028   }
2029 
2030   // Make sure that number of elements written to outT is equal to the
2031   // total of all elements in lengthsT.
2032   assert(grandTotalLen == (outP / dataElementSize));
2033 }
2034 
fwdGatherRangesInst(const glow::GatherRangesInst * I)2035 void BoundInterpreterFunction::fwdGatherRangesInst(
2036     const glow::GatherRangesInst *I) {
2037   switch (I->getRanges()->getElementType()) {
2038   case ElemKind::Int64ITy:
2039     fwdGatherRangesInstImpl<int64_t>(I);
2040     break;
2041   case ElemKind::Int32ITy:
2042     fwdGatherRangesInstImpl<int32_t>(I);
2043     break;
2044   default:
2045     llvm_unreachable("Unsupported type for ranges input of GatherRanges.");
2046   }
2047 }
2048 
2049 template <typename ElemTy>
fwdScatterDataInstCopyImpl(const glow::ScatterDataInst * I)2050 void BoundInterpreterFunction::fwdScatterDataInstCopyImpl(
2051     const glow::ScatterDataInst *I) {
2052   Tensor *dataT = getTensor(I->getData());
2053   Tensor *indicesT = getTensor(I->getIndices());
2054   Tensor *slicesT = getTensor(I->getSlices());
2055 
2056   assert(indicesT->dims().size() == 2 &&
2057          "Index should be stored in 2D tensor!");
2058   const dim_t dataSliceSize = slicesT->size() / slicesT->dims()[0] *
2059                               slicesT->getType().getElementSize();
2060 
2061   auto IH = indicesT->getHandle<int64_t>();
2062   // For each index, copy from the slice at that index into the location in data
2063   // given the offset from the indices tensor.
2064   for (dim_t i = 0, end = indicesT->dims()[0]; i < end; i++) {
2065     dim_t destDataIdx = 0;
2066     for (dim_t j = 0, e = indicesT->dims()[1]; j < e; j++) {
2067       destDataIdx *= dataT->dims()[j];
2068       destDataIdx += IH.at({i, j});
2069     }
2070     std::copy(&slicesT->getUnsafePtr()[i * dataSliceSize],
2071               &slicesT->getUnsafePtr()[(i + 1) * dataSliceSize],
2072               &dataT->getUnsafePtr()[dataSliceSize * destDataIdx]);
2073   }
2074 }
2075 
2076 template <typename ElemTy>
fwdScatterDataInstAddFloatImpl(const glow::ScatterDataInst * I)2077 void BoundInterpreterFunction::fwdScatterDataInstAddFloatImpl(
2078     const glow::ScatterDataInst *I) {
2079   Tensor *dataT = getTensor(I->getData());
2080   Tensor *indicesT = getTensor(I->getIndices());
2081   Tensor *slicesT = getTensor(I->getSlices());
2082 
2083   assert(!dataT->getType().isQuantizedType() && "Should be float type!");
2084   assert(!slicesT->getType().isQuantizedType() && "Should be float type!");
2085 
2086   const size_t numSlices = slicesT->size() / slicesT->dims()[0];
2087 
2088   auto IH = indicesT->getHandle<int64_t>();
2089   // For each index, copy from the slice at that index into the location in data
2090   // given the offset from the indices tensor.
2091   assert(indicesT->dims().size() == 2 &&
2092          "Multi-dimensional index should be stored in 2D tensor!");
2093   auto D = dataT->getHandle<ElemTy>(), S = slicesT->getHandle<ElemTy>();
2094   for (dim_t i = 0, end = indicesT->dims()[0]; i < end; i++) {
2095     size_t destDataIdx = 0;
2096     for (dim_t j = 0, e = indicesT->dims()[1]; j < e; j++) {
2097       destDataIdx *= dataT->dims()[j];
2098       destDataIdx += IH.at({i, j});
2099     }
2100     for (dim_t j = 0; j < numSlices; j++) {
2101       D.raw(destDataIdx * numSlices + j) += S.raw(i * numSlices + j);
2102     }
2103   }
2104 }
2105 
2106 template <typename ElemTy>
fwdScatterDataInstAddQuantizedImpl(const glow::ScatterDataInst * I)2107 void BoundInterpreterFunction::fwdScatterDataInstAddQuantizedImpl(
2108     const glow::ScatterDataInst *I) {
2109   Tensor *dataT = getTensor(I->getData());
2110   Tensor *indicesT = getTensor(I->getIndices());
2111   Tensor *slicesT = getTensor(I->getSlices());
2112 
2113   assert(dataT->getType().isQuantizedType() && "Should be quantized type!");
2114   assert(slicesT->getType().isQuantizedType() && "Should be quantized type!");
2115 
2116   const dim_t numSlices = slicesT->size() / slicesT->dims()[0];
2117 
2118   TensorQuantizationParams dataQ{dataT->getType().getScale(),
2119                                  dataT->getType().getOffset()};
2120   TensorQuantizationParams sliceQ{slicesT->getType().getScale(),
2121                                   slicesT->getType().getOffset()};
2122 
2123   auto IH = indicesT->getHandle<int64_t>();
2124   // For each index, copy from the slice at that index into the location in data
2125   // given the offset from the indices tensor.
2126   assert(indicesT->dims().size() == 2 &&
2127          "Multi-dimensional index should be stored in 2D tensor!");
2128   auto D = dataT->getHandle<ElemTy>(), S = slicesT->getHandle<ElemTy>();
2129   for (dim_t i = 0, end = indicesT->dims()[0]; i < end; i++) {
2130     dim_t destDataIdx = 0;
2131     for (dim_t j = 0, e = indicesT->dims()[1]; j < e; j++) {
2132       destDataIdx *= dataT->dims()[j];
2133       destDataIdx += IH.at({i, j});
2134     }
2135     for (dim_t j = 0; j < numSlices; j++) {
2136       float lhs =
2137           quantization::dequantize(D.raw(destDataIdx * numSlices + j), dataQ);
2138       float rhs = quantization::dequantize(S.raw(i * numSlices + j), sliceQ);
2139       ElemTy result = quantization::quantize(lhs + rhs, dataQ);
2140       D.raw(destDataIdx * numSlices + j) = result;
2141     }
2142   }
2143 }
2144 
fwdScatterDataInst(const glow::ScatterDataInst * I)2145 void BoundInterpreterFunction::fwdScatterDataInst(
2146     const glow::ScatterDataInst *I) {
2147   if (I->getCumulative()) {
2148     switch (I->getData()->getElementType()) {
2149     case ElemKind::FloatTy:
2150       fwdScatterDataInstAddFloatImpl<float>(I);
2151       break;
2152     case ElemKind::Int8QTy:
2153       fwdScatterDataInstAddQuantizedImpl<int8_t>(I);
2154       break;
2155     default:
2156       llvm_unreachable("Unsupported type for ScatterData.");
2157     }
2158   } else {
2159     switch (I->getData()->getElementType()) {
2160     case ElemKind::FloatTy:
2161       fwdScatterDataInstCopyImpl<float>(I);
2162       break;
2163     case ElemKind::Int8QTy:
2164       fwdScatterDataInstCopyImpl<int8_t>(I);
2165       break;
2166     default:
2167       llvm_unreachable("Unsupported type for ScatterData.");
2168     }
2169   }
2170 }
2171 
2172 template <typename ElemTy>
fwdBatchOneHotImpl(const glow::BatchOneHotInst * I)2173 void BoundInterpreterFunction::fwdBatchOneHotImpl(
2174     const glow::BatchOneHotInst *I) {
2175   auto dataH = getWeightHandle<ElemTy>(I->getData());
2176   auto lengthsH = getWeightHandle<int32_t>(I->getLengths());
2177   auto valuesH = getWeightHandle<ElemTy>(I->getValues());
2178   auto destH = getWeightHandle<ElemTy>(I->getDest());
2179 
2180   auto batchSize = dataH.dims()[0];
2181   auto featureCnt = dataH.dims()[1];
2182 
2183   for (dim_t batchId = 0; batchId < batchSize; batchId++) {
2184     size_t offset = 0;
2185     for (dim_t featureId = 0; featureId < featureCnt; featureId++) {
2186       auto curValue = dataH.at({batchId, featureId});
2187       auto curLength = lengthsH.at({featureId});
2188       for (dim_t i = offset, e = offset + curLength; i != e; i++) {
2189         destH.at({batchId, i}) = curValue == valuesH.at({i});
2190       }
2191       offset += curLength;
2192     }
2193     assert(offset == destH.dims()[1] &&
2194            "Sum of Lengths must be equal to size of Values");
2195   }
2196 }
2197 
fwdBatchOneHotInst(const glow::BatchOneHotInst * I)2198 void BoundInterpreterFunction::fwdBatchOneHotInst(
2199     const glow::BatchOneHotInst *I) {
2200   switch (I->getData()->getElementType()) {
2201   case ElemKind::Int64ITy:
2202     fwdBatchOneHotImpl<int64_t>(I);
2203     break;
2204   case ElemKind::Int32ITy:
2205     fwdBatchOneHotImpl<int32_t>(I);
2206     break;
2207   case ElemKind::Int8QTy:
2208     fwdBatchOneHotImpl<int8_t>(I);
2209     break;
2210   default:
2211     dispatchFloatingPointImpl(fwdBatchOneHotImpl,
2212                               I->getData()->getElementType(), I);
2213   }
2214 }
2215 
2216 template <typename ElemTy>
fwdSpaceToDepthInstImpl(const glow::SpaceToDepthInst * I)2217 void BoundInterpreterFunction::fwdSpaceToDepthInstImpl(
2218     const glow::SpaceToDepthInst *I) {
2219   auto *inT = getTensor(I->getSrc());
2220   auto *outT = getTensor(I->getDest());
2221 
2222   auto inH = inT->getHandle<ElemTy>();
2223   auto outH = outT->getHandle<ElemTy>();
2224 
2225   unsigned blockSize = I->getBlockSize();
2226 
2227   dim_t inDepth = inT->dims()[3];
2228 
2229   dim_t outBatch = outT->dims()[0];
2230   dim_t outHeight = outT->dims()[1];
2231   dim_t outWidth = outT->dims()[2];
2232   dim_t outDepth = outT->dims()[3];
2233 
2234   for (dim_t ob = 0; ob < outBatch; ++ob) {
2235     for (dim_t oh = 0; oh < outHeight; ++oh) {
2236       for (dim_t ow = 0; ow < outWidth; ++ow) {
2237         for (dim_t oc = 0; oc < outDepth; ++oc) {
2238           // Gets the block layer we are on
2239           dim_t blockDepthLayer = oc / inDepth;
2240           // every multiple of block size we reset to 0 offset
2241           dim_t iw = ow * blockSize + blockDepthLayer % blockSize;
2242           // every multiple of blockSize we start height traversal + 1
2243           dim_t ih = oh * blockSize + blockDepthLayer / blockSize;
2244           // at every multiple of inDepth index in to input depths resets to 0
2245           dim_t ic = oc % inDepth;
2246 
2247           outH.at({ob, oh, ow, oc}) = inH.at({ob, ih, iw, ic});
2248         }
2249       }
2250     }
2251   }
2252 }
2253 
fwdSpaceToDepthInst(const glow::SpaceToDepthInst * I)2254 void BoundInterpreterFunction::fwdSpaceToDepthInst(
2255     const glow::SpaceToDepthInst *I) {
2256   switch (I->getSrc()->getElementType()) {
2257   case ElemKind::FloatTy:
2258     fwdSpaceToDepthInstImpl<float>(I);
2259     break;
2260   case ElemKind::Int8QTy:
2261     fwdSpaceToDepthInstImpl<int8_t>(I);
2262     break;
2263   default:
2264     llvm_unreachable("Type is not supported");
2265     break;
2266   }
2267 }
2268 
2269 template <typename ElemTy>
fwdResizeNearestInstImpl(const ResizeNearestInst * I)2270 void BoundInterpreterFunction::fwdResizeNearestInstImpl(
2271     const ResizeNearestInst *I) {
2272   auto inW = getWeightHandle<ElemTy>(I->getSrc());
2273   auto scale = I->getScale();
2274   auto outW = getWeightHandle<ElemTy>(I->getDest());
2275 
2276   ShapeNHWC odim(outW.dims());
2277   ShapeNHWC idim(inW.dims());
2278 
2279   for (dim_t ob = 0; ob < odim.n; ++ob) {
2280     auto ib = std::min(dim_t(ob / scale[0]), idim.n - 1);
2281     for (dim_t oh = 0; oh < odim.h; ++oh) {
2282       auto ih = std::min(dim_t(oh / scale[1]), idim.h - 1);
2283       for (dim_t ow = 0; ow < odim.w; ++ow) {
2284         auto iw = std::min(dim_t(ow / scale[2]), idim.w - 1);
2285         for (dim_t oc = 0; oc < odim.c; ++oc) {
2286           auto ic = std::min(dim_t(oc / scale[3]), idim.c - 1);
2287           outW.at({ob, oh, ow, oc}) = inW.at({ib, ih, iw, ic});
2288         }
2289       }
2290     }
2291   }
2292 }
2293 
fwdResizeNearestInst(const ResizeNearestInst * I)2294 void BoundInterpreterFunction::fwdResizeNearestInst(
2295     const ResizeNearestInst *I) {
2296   if (getTensor(I->getSrc())->getType().isQuantizedType()) {
2297     dispatchQuantizedImpl(fwdResizeNearestInstImpl,
2298                           I->getSrc()->getElementType(), I);
2299     return;
2300   }
2301 
2302   dispatchImpl(fwdResizeNearestInstImpl, I->getSrc()->getElementType(), I);
2303 }
2304 
2305 template <typename ElemTy>
fwdResizeBilinearInstImpl(const ResizeBilinearInst * I)2306 void BoundInterpreterFunction::fwdResizeBilinearInstImpl(
2307     const ResizeBilinearInst *I) {
2308   auto inW = getWeightHandle<ElemTy>(I->getSrc());
2309   auto scale = I->getScale();
2310   auto outW = getWeightHandle<ElemTy>(I->getDest());
2311 
2312   ShapeNHWC odim(outW.dims());
2313   ShapeNHWC idim(inW.dims());
2314 
2315   CHECK_EQ(scale[0], 1.0) << "Scaling batch not supported.";
2316   CHECK_EQ(scale[3], 1.0) << "Scaling channel not supported.";
2317 
2318   for (dim_t ob = 0; ob < odim.n; ++ob) {
2319     for (dim_t oh = 0; oh < odim.h; ++oh) {
2320       for (dim_t ow = 0; ow < odim.w; ++ow) {
2321 
2322         float ihf = oh / scale[1];
2323         float iwf = ow / scale[2];
2324         dim_t ih = dim_t(ihf);
2325         dim_t iw = dim_t(iwf);
2326 
2327         auto ih0 = std::min(ih, idim.h - 1);
2328         auto ih1 = std::min(ih + 1, idim.h - 1);
2329         auto iw0 = std::min(iw, idim.w - 1);
2330         auto iw1 = std::min(iw + 1, idim.w - 1);
2331 
2332         for (dim_t oc = 0; oc < odim.c; ++oc) {
2333           auto v00 = inW.at({ob, ih0, iw0, oc});
2334           auto v01 = inW.at({ob, ih0, iw1, oc});
2335           auto v10 = inW.at({ob, ih1, iw0, oc});
2336           auto v11 = inW.at({ob, ih1, iw1, oc});
2337 
2338           auto hd = (float)v00 + (float)(v10 - v00) * (ihf - ih);
2339           auto hw = (float)v01 + (float)(v11 - v01) * (ihf - ih);
2340           float result = hd + (hw - hd) * (iwf - iw);
2341           outW.at({ob, oh, ow, oc}) = result;
2342         }
2343       }
2344     }
2345   }
2346 }
2347 
fwdResizeBilinearInst(const ResizeBilinearInst * I)2348 void BoundInterpreterFunction::fwdResizeBilinearInst(
2349     const ResizeBilinearInst *I) {
2350   if (getTensor(I->getSrc())->getType().isQuantizedType()) {
2351     dispatchQuantizedImpl(fwdResizeBilinearInstImpl,
2352                           I->getSrc()->getElementType(), I);
2353     return;
2354   }
2355 
2356   dispatchImpl(fwdResizeBilinearInstImpl, I->getSrc()->getElementType(), I);
2357 }
2358 
2359 //===----------------------------------------------------------------------===//
2360 //                      Local Response Normalization
2361 //===----------------------------------------------------------------------===//
2362 
2363 template <typename ElemTy>
fwdLocalResponseNormalizationInstFloatImpl(const glow::LocalResponseNormalizationInst * I)2364 void BoundInterpreterFunction::fwdLocalResponseNormalizationInstFloatImpl(
2365     const glow::LocalResponseNormalizationInst *I) {
2366   staticAssertFloatingPointType(ElemTy);
2367 
2368   auto inW = getWeightHandle<ElemTy>(I->getSrc());
2369   auto outW = getWeightHandle<ElemTy>(I->getDest());
2370   auto scaleCache = getWeightHandle<ElemTy>(I->getScale());
2371 
2372   ShapeNHWC odim(outW.dims());
2373   ShapeNHWC idim(inW.dims());
2374 
2375   (void)odim;
2376 
2377   // LRN node does not change the shape of the input.
2378   assert(odim == idim && "Output of LRN node must be same shape as input");
2379 
2380   // LRN node normalizes across channels, so the input must have a minimum
2381   // depth of 1.
2382   assert(idim.c > 0 && "Input of LRN node must have a minimum depth of 1");
2383 
2384   auto halfWindowSize = (size_t)I->getHalfWindowSize();
2385   auto k = I->getK();
2386   auto beta = I->getBeta();
2387   auto windowSize = 2 * halfWindowSize + 1;
2388   auto normedAlpha = I->getAlpha() / windowSize;
2389 
2390   // For every input in the batch:
2391   for (dim_t n = 0; n < idim.n; n++) {
2392 
2393     // For every row:
2394     for (dim_t h = 0; h < idim.h; h++) {
2395 
2396       // For every column:
2397       for (dim_t w = 0; w < idim.w; w++) {
2398 
2399         // For every channel:
2400         for (dim_t c = 0; c < idim.c; c++) {
2401           float squareSum = 0.0;
2402           for (dim_t i = (c >= halfWindowSize ? c - halfWindowSize : 0);
2403                i <= std::min<dim_t>(c + halfWindowSize, (size_t)idim.c - 1);
2404                i++) {
2405             float val = inW.at({n, h, w, i});
2406             squareSum += val * val;
2407           }
2408 
2409           auto scale = k + normedAlpha * squareSum;
2410 
2411           // This will be used to accelerate the backward pass.
2412           scaleCache.at({n, h, w, c}) = ElemTy(scale);
2413 
2414           auto normFactor = std::pow(scale, -beta);
2415           outW.at({n, h, w, c}) =
2416               ElemTy(float(inW.at({n, h, w, c})) * normFactor);
2417         }
2418       }
2419     }
2420   }
2421 }
2422 
fwdLocalResponseNormalizationInst(const LocalResponseNormalizationInst * I)2423 void BoundInterpreterFunction::fwdLocalResponseNormalizationInst(
2424     const LocalResponseNormalizationInst *I) {
2425   dispatchFloatingPointImpl(fwdLocalResponseNormalizationInstFloatImpl,
2426                             I->getSrc()->getElementType(), I);
2427 }
2428 
fwdLocalResponseNormalizationGradInst(const glow::LocalResponseNormalizationGradInst * I)2429 void BoundInterpreterFunction::fwdLocalResponseNormalizationGradInst(
2430     const glow::LocalResponseNormalizationGradInst *I) {
2431   auto inW = getWeightHandle(I->getSrc());
2432   auto inG = getWeightHandle(I->getSrcGrad());
2433   auto outW = getWeightHandle(I->getDest());
2434   auto outG = getWeightHandle(I->getDestGrad());
2435   auto scaleCache = getWeightHandle(I->getScale());
2436 
2437   ShapeNHWC odim(outW.dims());
2438 
2439   auto halfWindowSize = I->getHalfWindowSize();
2440   auto beta = I->getBeta();
2441   auto windowSize = 2 * halfWindowSize + 1;
2442   auto normedAlpha = I->getAlpha() / windowSize;
2443 
2444   // For every input in the batch:
2445   for (dim_t n = 0; n < odim.n; n++) {
2446 
2447     // For every row:
2448     for (dim_t h = 0; h < odim.h; h++) {
2449 
2450       // For every column:
2451       for (dim_t w = 0; w < odim.w; w++) {
2452 
2453         float sum = 0.0;
2454 
2455         // Compute sum for first channel.
2456         for (dim_t c = 0; c <= halfWindowSize && c < odim.c; c++) {
2457           auto outw = outW.at({n, h, w, c});
2458           auto scale = scaleCache.at({n, h, w, c});
2459           auto outg = outG.at({n, h, w, c});
2460           sum += (outg * (outw / scale));
2461         }
2462 
2463         // For every channel:
2464         for (dim_t c = 0; c < odim.c; c++) {
2465           auto outg = outG.at({n, h, w, c});
2466           auto scale = scaleCache.at({n, h, w, c});
2467           auto inw = inW.at({n, h, w, c});
2468 
2469           inG.at({n, h, w, c}) = outg * std::pow(scale, -beta) -
2470                                  2 * normedAlpha * beta * inw * sum;
2471 
2472           // Modify sum for next channel.
2473           auto subIndex = c - halfWindowSize;
2474           auto addIndex = c + halfWindowSize + 1;
2475 
2476           if (c >= halfWindowSize) {
2477             auto outw = outW.at({n, h, w, subIndex});
2478             auto scale = scaleCache.at({n, h, w, subIndex});
2479             auto outg = outG.at({n, h, w, subIndex});
2480 
2481             // Subtract "rear" end of this window.
2482             sum -= (outg * (outw / scale));
2483           }
2484 
2485           if (addIndex < odim.c) {
2486             auto outw = outW.at({n, h, w, addIndex});
2487             auto scale = scaleCache.at({n, h, w, addIndex});
2488             auto outg = outG.at({n, h, w, addIndex});
2489 
2490             // Add "front" end of next window.
2491             sum += (outg * (outw / scale));
2492           }
2493         }
2494       }
2495     }
2496   }
2497 }
2498 
2499 //===----------------------------------------------------------------------===//
2500 //                       Arithmetic operations
2501 //===----------------------------------------------------------------------===//
fwdElementAddInstI8Impl(const ElementAddInst * I)2502 void BoundInterpreterFunction::fwdElementAddInstI8Impl(
2503     const ElementAddInst *I) {
2504   assert(getTensor(I->getLHS())->getType().isQuantizedType() &&
2505          "Wrong function");
2506   auto lhsTy = I->getLHS()->getType();
2507   auto rhsTy = I->getRHS()->getType();
2508   auto destTy = I->getDest()->getType();
2509 
2510   float lhsScale = lhsTy->getScale();
2511   float rhsScale = rhsTy->getScale();
2512   float destScale = destTy->getScale();
2513 
2514   int32_t lhsOffset = lhsTy->getOffset();
2515   int32_t rhsOffset = rhsTy->getOffset();
2516   int32_t destOffset = destTy->getOffset();
2517 
2518   auto outW = getWeightHandle<int8_t>(I->getDest());
2519   auto lhsW = getWeightHandle<int8_t>(I->getLHS());
2520   auto rhsW = getWeightHandle<int8_t>(I->getRHS());
2521   for (dim_t i = 0, e = outW.size(); i < e; i++) {
2522     int32_t L = lhsW.raw(i);
2523     int32_t R = rhsW.raw(i);
2524 
2525     // We increase the size of the integer up to 16 bits to prevent overflow.
2526     const float largeScale = float(1) / (1 << 15);
2527     // Scale both sides from 8-bit to 16-bits.
2528     int32_t L32 = std::round(float(L - lhsOffset) * (lhsScale / largeScale));
2529     int32_t R32 = std::round(float(R - rhsOffset) * (rhsScale / largeScale));
2530     int32_t sum32 = L32 + R32;
2531     sum32 = std::round(float(sum32) * (largeScale / destScale) + destOffset);
2532     outW.raw(i) = quantization::clip<int32_t, int8_t>(sum32);
2533   }
2534 }
2535 
2536 template <typename ElemTy>
fwdElementAddInstArithmeticImpl(const ElementAddInst * I)2537 void BoundInterpreterFunction::fwdElementAddInstArithmeticImpl(
2538     const ElementAddInst *I) {
2539   staticAssertArithmeticType(ElemTy);
2540 
2541   auto outW = getWeightHandle<ElemTy>(I->getDest());
2542   auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
2543   auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
2544   for (size_t i = 0, e = outW.size(); i < e; i++) {
2545     outW.raw(i) = lhsW.raw(i) + rhsW.raw(i);
2546   }
2547 }
2548 
fwdElementAddInst(const ElementAddInst * I)2549 void BoundInterpreterFunction::fwdElementAddInst(const ElementAddInst *I) {
2550   if (getTensor(I->getLHS())->getType().isQuantizedType()) {
2551     fwdElementAddInstI8Impl(I);
2552     return;
2553   }
2554 
2555   dispatchArithmeticImpl(fwdElementAddInstArithmeticImpl,
2556                          I->getLHS()->getType()->getElementType(), I);
2557 }
2558 
2559 template <typename ElemTy>
fwdElementSubInstArithmeticImpl(const ElementSubInst * I)2560 void BoundInterpreterFunction::fwdElementSubInstArithmeticImpl(
2561     const ElementSubInst *I) {
2562   staticAssertArithmeticType(ElemTy);
2563 
2564   auto outW = getWeightHandle<ElemTy>(I->getDest());
2565   auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
2566   auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
2567   for (size_t i = 0, e = outW.size(); i < e; i++) {
2568     outW.raw(i) = lhsW.raw(i) - rhsW.raw(i);
2569   }
2570 }
2571 
fwdElementSubInst(const ElementSubInst * I)2572 void BoundInterpreterFunction::fwdElementSubInst(const ElementSubInst *I) {
2573   if (getTensor(I->getLHS())->getType().isQuantizedType()) {
2574     auto destTy = I->getDest()->getType();
2575     auto lhsTy = I->getLHS()->getType();
2576     auto rhsTy = I->getRHS()->getType();
2577 
2578     float destScale = destTy->getScale();
2579     float lhsScale = lhsTy->getScale();
2580     float rhsScale = rhsTy->getScale();
2581 
2582     int32_t destOffset = destTy->getOffset();
2583     int32_t lhsOffset = lhsTy->getOffset();
2584     int32_t rhsOffset = rhsTy->getOffset();
2585 
2586     auto outW = getWeightHandle<int8_t>(I->getDest());
2587     auto lhsW = getWeightHandle<int8_t>(I->getLHS());
2588     auto rhsW = getWeightHandle<int8_t>(I->getRHS());
2589     for (size_t i = 0, e = outW.size(); i < e; i++) {
2590       //    s_d * (i_d - o_d) = s_l * (i_l - o_l) - s_r * (i_r - o_r)
2591       // => i_d = (s_l / s_d) * (i_l - o_l) - (s_r / s_d) * (i_r - o_r) + o_d
2592       float l = (lhsScale / destScale) * float(lhsW.raw(i) - lhsOffset);
2593       float r = (rhsScale / destScale) * float(rhsW.raw(i) - rhsOffset);
2594       int32_t q = std::round(l - r + destOffset);
2595       outW.raw(i) = quantization::clip<int32_t, int8_t>(q);
2596     }
2597     return;
2598   }
2599 
2600   dispatchArithmeticImpl(fwdElementSubInstArithmeticImpl,
2601                          I->getDest()->getElementType(), I);
2602 }
2603 
2604 template <typename ElemTy>
fwdElementMulInstArithmeticImpl(const ElementMulInst * I)2605 void BoundInterpreterFunction::fwdElementMulInstArithmeticImpl(
2606     const ElementMulInst *I) {
2607   staticAssertArithmeticType(ElemTy);
2608 
2609   auto outW = getWeightHandle<ElemTy>(I->getDest());
2610   auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
2611   auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
2612   for (size_t i = 0, e = outW.size(); i < e; i++) {
2613     outW.raw(i) = lhsW.raw(i) * rhsW.raw(i);
2614   }
2615 }
2616 
fwdElementMulInst(const ElementMulInst * I)2617 void BoundInterpreterFunction::fwdElementMulInst(const ElementMulInst *I) {
2618   if (getTensor(I->getLHS())->getType().isQuantizedType()) {
2619     auto lhsTy = I->getLHS()->getType();
2620     auto rhsTy = I->getRHS()->getType();
2621     auto destTy = I->getDest()->getType();
2622 
2623     TensorQuantizationParams lhsQ{lhsTy->getScale(), lhsTy->getOffset()};
2624     TensorQuantizationParams rhsQ{rhsTy->getScale(), rhsTy->getOffset()};
2625     TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};
2626 
2627     auto outW = getWeightHandle<int8_t>(I->getDest());
2628     auto lhsW = getWeightHandle<int8_t>(I->getLHS());
2629     auto rhsW = getWeightHandle<int8_t>(I->getRHS());
2630     float scale = lhsQ.scale * rhsQ.scale / destQ.scale;
2631     for (size_t i = 0, e = outW.size(); i < e; i++) {
2632       int32_t mul = (lhsW.raw(i) - lhsQ.offset) * (rhsW.raw(i) - rhsQ.offset);
2633       outW.raw(i) = quantization::clip<int32_t, int8_t>(
2634           std::round(mul * scale) + destQ.offset);
2635     }
2636     return;
2637   }
2638 
2639   dispatchArithmeticImpl(fwdElementMulInstArithmeticImpl,
2640                          I->getDest()->getElementType(), I);
2641 }
2642 
fwdElementDivInst(const ElementDivInst * I)2643 void BoundInterpreterFunction::fwdElementDivInst(const ElementDivInst *I) {
2644   if (getTensor(I->getLHS())->getType().isQuantizedType()) {
2645     auto destTy = I->getDest()->getType();
2646     auto lhsTy = I->getLHS()->getType();
2647     auto rhsTy = I->getRHS()->getType();
2648 
2649     float destScale = destTy->getScale();
2650     float lhsScale = lhsTy->getScale();
2651     float rhsScale = rhsTy->getScale();
2652 
2653     int32_t destOffset = destTy->getOffset();
2654     int32_t lhsOffset = lhsTy->getOffset();
2655     int32_t rhsOffset = rhsTy->getOffset();
2656 
2657     auto outW = getWeightHandle<int8_t>(I->getDest());
2658     auto lhsW = getWeightHandle<int8_t>(I->getLHS());
2659     auto rhsW = getWeightHandle<int8_t>(I->getRHS());
2660     for (size_t i = 0, e = outW.size(); i < e; i++) {
2661       //    s_d * (i_d - o_d) = (s_l * (i_l - o_l)) / (s_r * (i_r - o_r))
2662       // => i_d = (s_l * (i_l - o_l)) / (s_d * s_r * (i_r - o_r)) + o_d
2663       float l = lhsScale * float(lhsW.raw(i) - lhsOffset);
2664       float r = rhsScale * destScale * float(rhsW.raw(i) - rhsOffset);
2665       int32_t q = std::round(l / r + destOffset);
2666       outW.raw(i) = quantization::clip<int32_t, int8_t>(q);
2667     }
2668     return;
2669   }
2670 
2671 #define DIV_LOOP(TYPE_)                                                        \
2672   auto outW = getWeightHandle<TYPE_>(I->getDest());                            \
2673   auto lhsW = getWeightHandle<TYPE_>(I->getLHS());                             \
2674   auto rhsW = getWeightHandle<TYPE_>(I->getRHS());                             \
2675   for (size_t i = 0, e = outW.size(); i < e; i++) {                            \
2676     outW.raw(i) = lhsW.raw(i) / rhsW.raw(i);                                   \
2677   }
2678 
2679   auto *T = getTensor(I->getDest());
2680   switch (T->getElementType()) {
2681   case ElemKind::Int64ITy: {
2682     DIV_LOOP(int64_t);
2683     return;
2684   }
2685   case ElemKind::FloatTy: {
2686     DIV_LOOP(float);
2687     return;
2688   }
2689   case ElemKind::Float16Ty: {
2690     DIV_LOOP(float16_t);
2691     return;
2692   }
2693   case ElemKind::BFloat16Ty: {
2694     DIV_LOOP(bfloat16_t);
2695     return;
2696   }
2697   default:
2698     llvm_unreachable("Unsupported type for Div.");
2699   }
2700 }
2701 
fwdElementMaxInstI8Impl(const ElementMaxInst * I)2702 void BoundInterpreterFunction::fwdElementMaxInstI8Impl(
2703     const ElementMaxInst *I) {
2704   assert(getTensor(I->getLHS())->getType().isQuantizedType() &&
2705          "Wrong function");
2706   auto lhsTy = I->getLHS()->getType();
2707   auto rhsTy = I->getRHS()->getType();
2708   auto destTy = I->getDest()->getType();
2709 
2710   TensorQuantizationParams lhsQ{lhsTy->getScale(), lhsTy->getOffset()};
2711   TensorQuantizationParams rhsQ{rhsTy->getScale(), rhsTy->getOffset()};
2712   TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};
2713 
2714   auto outW = getWeightHandle<int8_t>(I->getDest());
2715   auto lhsW = getWeightHandle<int8_t>(I->getLHS());
2716   auto rhsW = getWeightHandle<int8_t>(I->getRHS());
2717   for (size_t i = 0, e = outW.size(); i < e; i++) {
2718     // Convert both sides to the destination scale and perform a regular
2719     // comparison.
2720     int8_t L = quantization::quantize(
2721         quantization::dequantize(lhsW.raw(i), lhsQ), destQ);
2722     int8_t R = quantization::quantize(
2723         quantization::dequantize(rhsW.raw(i), rhsQ), destQ);
2724     outW.raw(i) = std::max(L, R);
2725   }
2726 }
2727 
2728 template <typename ElemTy>
fwdElementMaxInstArithmeticImpl(const ElementMaxInst * I)2729 void BoundInterpreterFunction::fwdElementMaxInstArithmeticImpl(
2730     const ElementMaxInst *I) {
2731   staticAssertArithmeticType(ElemTy);
2732 
2733   auto outW = getWeightHandle<ElemTy>(I->getDest());
2734   auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
2735   auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
2736   for (size_t i = 0, e = outW.size(); i < e; i++) {
2737     outW.raw(i) = std::max(lhsW.raw(i), rhsW.raw(i));
2738   }
2739 }
2740 
fwdElementMaxInst(const ElementMaxInst * I)2741 void BoundInterpreterFunction::fwdElementMaxInst(const ElementMaxInst *I) {
2742   if (getTensor(I->getLHS())->getType().isQuantizedType()) {
2743     fwdElementMaxInstI8Impl(I);
2744     return;
2745   }
2746 
2747   dispatchArithmeticImpl(fwdElementMaxInstArithmeticImpl,
2748                          I->getLHS()->getType()->getElementType(), I);
2749 }
2750 
2751 template <typename ElemTy>
fwdElementMinInstArithmeticImpl(const ElementMinInst * I)2752 void BoundInterpreterFunction::fwdElementMinInstArithmeticImpl(
2753     const ElementMinInst *I) {
2754   staticAssertArithmeticType(ElemTy);
2755 
2756   auto outW = getWeightHandle<ElemTy>(I->getDest());
2757   auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
2758   auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
2759   for (size_t i = 0, e = outW.size(); i < e; i++) {
2760     outW.raw(i) = std::min(lhsW.raw(i), rhsW.raw(i));
2761   }
2762 }
2763 
fwdElementMinInst(const ElementMinInst * I)2764 void BoundInterpreterFunction::fwdElementMinInst(const ElementMinInst *I) {
2765   if (getTensor(I->getLHS())->getType().isQuantizedType()) {
2766     auto lhsTy = I->getLHS()->getType();
2767     auto rhsTy = I->getRHS()->getType();
2768     auto destTy = I->getDest()->getType();
2769 
2770     TensorQuantizationParams lhsQ{lhsTy->getScale(), lhsTy->getOffset()};
2771     TensorQuantizationParams rhsQ{rhsTy->getScale(), rhsTy->getOffset()};
2772     TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};
2773 
2774     auto outW = getWeightHandle<int8_t>(I->getDest());
2775     auto lhsW = getWeightHandle<int8_t>(I->getLHS());
2776     auto rhsW = getWeightHandle<int8_t>(I->getRHS());
2777     for (size_t i = 0, e = outW.size(); i < e; i++) {
2778       // Convert both sides to the destination scale and perform a regular
2779       // comparison.
2780       int8_t L = quantization::quantize(
2781           quantization::dequantize(lhsW.raw(i), lhsQ), destQ);
2782       int8_t R = quantization::quantize(
2783           quantization::dequantize(rhsW.raw(i), rhsQ), destQ);
2784       outW.raw(i) = std::min(L, R);
2785     }
2786     return;
2787   }
2788 
2789   dispatchArithmeticImpl(fwdElementMinInstArithmeticImpl,
2790                          I->getDest()->getElementType(), I);
2791 }
2792 
2793 //===----------------------------------------------------------------------===//
2794 //                              Logical operations
2795 //===----------------------------------------------------------------------===//
fwdElementNotInst(const ElementNotInst * I)2796 void BoundInterpreterFunction::fwdElementNotInst(const ElementNotInst *I) {
2797   auto inpW = getWeightHandle<bool>(I->getSrc());
2798   auto outW = getWeightHandle<bool>(I->getDest());
2799   for (size_t i = 0, e = outW.size(); i < e; ++i) {
2800     outW.raw(i) = (!inpW.raw(i));
2801   }
2802 }
2803 
fwdElementAndInst(const ElementAndInst * I)2804 void BoundInterpreterFunction::fwdElementAndInst(const ElementAndInst *I) {
2805   auto lhsW = getWeightHandle<bool>(I->getLHS());
2806   auto rhsW = getWeightHandle<bool>(I->getRHS());
2807   auto outW = getWeightHandle<bool>(I->getDest());
2808   for (size_t i = 0, e = outW.size(); i < e; ++i) {
2809     outW.raw(i) = (lhsW.raw(i) && rhsW.raw(i));
2810   }
2811 }
2812 
fwdElementOrInst(const ElementOrInst * I)2813 void BoundInterpreterFunction::fwdElementOrInst(const ElementOrInst *I) {
2814   auto lhsW = getWeightHandle<bool>(I->getLHS());
2815   auto rhsW = getWeightHandle<bool>(I->getRHS());
2816   auto outW = getWeightHandle<bool>(I->getDest());
2817   for (size_t i = 0, e = outW.size(); i < e; ++i) {
2818     outW.raw(i) = (lhsW.raw(i) || rhsW.raw(i));
2819   }
2820 }
2821 
fwdElementXorInst(const ElementXorInst * I)2822 void BoundInterpreterFunction::fwdElementXorInst(const ElementXorInst *I) {
2823   auto lhsW = getWeightHandle<bool>(I->getLHS());
2824   auto rhsW = getWeightHandle<bool>(I->getRHS());
2825   auto outW = getWeightHandle<bool>(I->getDest());
2826   for (size_t i = 0, e = outW.size(); i < e; ++i) {
2827     outW.raw(i) = (lhsW.raw(i) ^ rhsW.raw(i));
2828   }
2829 }
2830 
2831 //===----------------------------------------------------------------------===//
2832 //                         Unary arithmetic operations
2833 //===----------------------------------------------------------------------===//
2834 template <typename ElemTy, typename InstKind>
fwdUnaryArithmeticImpl(const InstKind * I,std::function<float (float)> func)2835 void BoundInterpreterFunction::fwdUnaryArithmeticImpl(
2836     const InstKind *I, std::function<float(float)> func) {
2837   Value *inpV = I->getSrc();
2838   Value *outV = I->getDest();
2839   auto inpTy = inpV->getType();
2840   auto outTy = outV->getType();
2841   auto inpH = getWeightHandle<ElemTy>(inpV);
2842   auto outH = getWeightHandle<ElemTy>(outV);
2843 
2844   if (inpTy->isQuantizedType()) {
2845     float inpScale = inpTy->getScale();
2846     int32_t inpOffset = inpTy->getOffset();
2847     float outScale = outTy->getScale();
2848     int32_t outOffset = outTy->getOffset();
2849     for (size_t i = 0, e = outH.size(); i < e; ++i) {
2850       float inpVal =
2851           quantization::dequantize<ElemTy>(inpH.raw(i), {inpScale, inpOffset});
2852       float outVal = func(inpVal);
2853       outH.raw(i) =
2854           quantization::quantize<ElemTy>(outVal, {outScale, outOffset});
2855     }
2856   } else {
2857     for (size_t i = 0, e = outH.size(); i < e; ++i) {
2858       float inpVal = static_cast<float>(inpH.raw(i));
2859       float outVal = func(inpVal);
2860       outH.raw(i) = static_cast<ElemTy>(outVal);
2861     }
2862   }
2863 }
2864 
fwdElementAbsInst(const ElementAbsInst * I)2865 void BoundInterpreterFunction::fwdElementAbsInst(const ElementAbsInst *I) {
2866   auto func = [](float x) -> float { return std::abs(x); };
2867   dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
2868 }
2869 
fwdElementNegInst(const ElementNegInst * I)2870 void BoundInterpreterFunction::fwdElementNegInst(const ElementNegInst *I) {
2871   auto func = [](float x) -> float { return -x; };
2872   dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
2873 }
2874 
fwdElementFloorInst(const ElementFloorInst * I)2875 void BoundInterpreterFunction::fwdElementFloorInst(const ElementFloorInst *I) {
2876   auto func = [](float x) -> float { return std::floor(x); };
2877   dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
2878 }
2879 
fwdElementCeilInst(const ElementCeilInst * I)2880 void BoundInterpreterFunction::fwdElementCeilInst(const ElementCeilInst *I) {
2881   auto func = [](float x) -> float { return std::ceil(x); };
2882   dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
2883 }
2884 
fwdElementRoundInst(const ElementRoundInst * I)2885 void BoundInterpreterFunction::fwdElementRoundInst(const ElementRoundInst *I) {
2886   // Rounding mode required by ONNX, Numpy, TensorFlow is round to even which
2887   // rounds to nearest even integer those values with fractional part 0.5.
2888   auto func = [](float x) -> float { return std::nearbyintf(x); };
2889   dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
2890 }
2891 
fwdElementSqrtInst(const ElementSqrtInst * I)2892 void BoundInterpreterFunction::fwdElementSqrtInst(const ElementSqrtInst *I) {
2893   auto func = [](float x) -> float { return std::sqrt(x); };
2894   dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
2895 }
2896 
fwdElementRsqrtInst(const ElementRsqrtInst * I)2897 void BoundInterpreterFunction::fwdElementRsqrtInst(const ElementRsqrtInst *I) {
2898   auto func = [](float x) -> float { return 1 / std::sqrt(x); };
2899   dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
2900 }
2901 
fwdElementReciprocalInst(const ElementReciprocalInst * I)2902 void BoundInterpreterFunction::fwdElementReciprocalInst(
2903     const ElementReciprocalInst *I) {
2904   auto func = [](float x) -> float { return 1 / x; };
2905   dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
2906 }
2907 
fwdElementSinInst(const ElementSinInst * I)2908 void BoundInterpreterFunction::fwdElementSinInst(const ElementSinInst *I) {
2909   auto func = [](float x) -> float { return std::sin(x); };
2910   dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
2911 }
2912 
fwdElementCosInst(const ElementCosInst * I)2913 void BoundInterpreterFunction::fwdElementCosInst(const ElementCosInst *I) {
2914   auto func = [](float x) -> float { return std::cos(x); };
2915   dispatchImpl(fwdUnaryArithmeticImpl, I->getSrc()->getElementType(), I, func);
2916 }
2917 
2918 //===----------------------------------------------------------------------===//
2919 //                              Compare operations
2920 //===----------------------------------------------------------------------===//
2921 template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
2922           typename CmpTy, typename InstCmpKind>
fwdElementCmpHelperImpl(const InstCmpKind * I,std::function<bool (CmpTy LHS,CmpTy RHS)> cmpHelper)2923 void BoundInterpreterFunction::fwdElementCmpHelperImpl(
2924     const InstCmpKind *I, std::function<bool(CmpTy LHS, CmpTy RHS)> cmpHelper) {
2925   Value *lhsV = I->getLHS();
2926   Value *rhsV = I->getRHS();
2927   Value *outV = I->getDest();
2928 
2929   auto lhsH = getWeightHandle<ElemTy>(lhsV);
2930   auto rhsH = getWeightHandle<ElemTy>(rhsV);
2931   auto oH = getWeightHandle<bool>(outV);
2932 
2933   ElemScaleTy lhsScale = 1.0f;
2934   ElemScaleTy rhsScale = 1.0f;
2935   ElemOffsetTy lhsOffset = 0;
2936   ElemOffsetTy rhsOffset = 0;
2937 
2938   auto lhsTy = lhsV->getType();
2939   auto rhsTy = rhsV->getType();
2940 
2941   if (lhsV->getType()->isQuantizedType()) {
2942     lhsScale = lhsTy->getScale();
2943     rhsScale = rhsTy->getScale();
2944 
2945     lhsOffset = lhsTy->getOffset();
2946     rhsOffset = rhsTy->getOffset();
2947   }
2948 
2949   // For each layer in the batch:
2950   for (size_t i = 0, e = oH.size(); i < e; i++) {
2951     oH.raw(i) = cmpHelper(lhsScale * (lhsH.raw(i) - lhsOffset),
2952                           rhsScale * (rhsH.raw(i) - rhsOffset));
2953   }
2954 }
2955 
2956 template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
2957           typename CmpTy>
fwdElementCmpLTEInstImpl(const ElementCmpLTEInst * I)2958 void BoundInterpreterFunction::fwdElementCmpLTEInstImpl(
2959     const ElementCmpLTEInst *I) {
2960   auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return LHS <= RHS; };
2961   fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy,
2962                           ElementCmpLTEInst>(I, cmpHelper);
2963 }
2964 
fwdElementCmpLTEInst(const ElementCmpLTEInst * I)2965 void BoundInterpreterFunction::fwdElementCmpLTEInst(
2966     const ElementCmpLTEInst *I) {
2967   auto *T = getTensor(I->getLHS());
2968 
2969   if (T->getType().isQuantizedType()) {
2970     fwdElementCmpLTEInstImpl<int8_t, int32_t, float, int32_t>(I);
2971     return;
2972   }
2973 
2974   switch (T->getElementType()) {
2975   case ElemKind::FloatTy:
2976     fwdElementCmpLTEInstImpl<float, float, float>(I);
2977     break;
2978   case ElemKind::Float16Ty:
2979     fwdElementCmpLTEInstImpl<float16_t, float16_t, float16_t>(I);
2980     break;
2981   case ElemKind::BFloat16Ty:
2982     fwdElementCmpLTEInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I);
2983     break;
2984   case ElemKind::Int32ITy:
2985     fwdElementCmpLTEInstImpl<int32_t, int32_t, float>(I);
2986     break;
2987   case ElemKind::Int64ITy:
2988     fwdElementCmpLTEInstImpl<int64_t, int64_t, float>(I);
2989     break;
2990   default:
2991     llvm_unreachable("Type is not supported");
2992   }
2993 }
2994 
2995 template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
2996           typename CmpTy>
fwdElementCmpEQInstImpl(const ElementCmpEQInst * I)2997 void BoundInterpreterFunction::fwdElementCmpEQInstImpl(
2998     const ElementCmpEQInst *I) {
2999   auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return LHS == RHS; };
3000   fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy,
3001                           ElementCmpEQInst>(I, cmpHelper);
3002 }
3003 
fwdElementCmpEQInst(const ElementCmpEQInst * I)3004 void BoundInterpreterFunction::fwdElementCmpEQInst(const ElementCmpEQInst *I) {
3005   auto *T = getTensor(I->getLHS());
3006 
3007   if (T->getType().isQuantizedType()) {
3008     fwdElementCmpEQInstImpl<int8_t, int32_t, float, int32_t>(I);
3009     return;
3010   }
3011 
3012   switch (T->getElementType()) {
3013   case ElemKind::FloatTy:
3014     fwdElementCmpEQInstImpl<float, float, float>(I);
3015     break;
3016   case ElemKind::Float16Ty:
3017     fwdElementCmpEQInstImpl<float16_t, float16_t, float16_t>(I);
3018     break;
3019   case ElemKind::BFloat16Ty:
3020     fwdElementCmpEQInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I);
3021     break;
3022   case ElemKind::Int32ITy:
3023     fwdElementCmpEQInstImpl<int32_t, int32_t, float>(I);
3024     break;
3025   case ElemKind::Int64ITy:
3026     fwdElementCmpEQInstImpl<int64_t, int64_t, float>(I);
3027     break;
3028   default:
3029     llvm_unreachable("Type is not supported");
3030   }
3031 }
3032 
3033 template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
3034           typename CmpTy>
fwdElementCmpNEQInstImpl(const ElementCmpNEQInst * I)3035 void BoundInterpreterFunction::fwdElementCmpNEQInstImpl(
3036     const ElementCmpNEQInst *I) {
3037   auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return !(LHS == RHS); };
3038   fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy,
3039                           ElementCmpNEQInst>(I, cmpHelper);
3040 }
3041 
fwdElementCmpNEQInst(const ElementCmpNEQInst * I)3042 void BoundInterpreterFunction::fwdElementCmpNEQInst(
3043     const ElementCmpNEQInst *I) {
3044   auto *T = getTensor(I->getLHS());
3045 
3046   if (T->getType().isQuantizedType()) {
3047     fwdElementCmpNEQInstImpl<int8_t, int32_t, float, int32_t>(I);
3048     return;
3049   }
3050 
3051   switch (T->getElementType()) {
3052   case ElemKind::FloatTy:
3053     fwdElementCmpNEQInstImpl<float, float, float>(I);
3054     break;
3055   case ElemKind::Float16Ty:
3056     fwdElementCmpNEQInstImpl<float16_t, float16_t, float16_t>(I);
3057     break;
3058   case ElemKind::BFloat16Ty:
3059     fwdElementCmpNEQInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I);
3060     break;
3061   case ElemKind::Int32ITy:
3062     fwdElementCmpNEQInstImpl<int32_t, int32_t, float>(I);
3063     break;
3064   case ElemKind::Int64ITy:
3065     fwdElementCmpNEQInstImpl<int64_t, int64_t, float>(I);
3066     break;
3067   default:
3068     llvm_unreachable("Type is not supported");
3069   }
3070 }
3071 
3072 template <typename ElemTy, typename ElemOffsetTy, typename ElemScaleTy,
3073           typename CmpTy>
fwdElementCmpLTInstImpl(const ElementCmpLTInst * I)3074 void BoundInterpreterFunction::fwdElementCmpLTInstImpl(
3075     const ElementCmpLTInst *I) {
3076   auto cmpHelper = [](CmpTy LHS, CmpTy RHS) -> bool { return LHS < RHS; };
3077   fwdElementCmpHelperImpl<ElemTy, ElemOffsetTy, ElemScaleTy, CmpTy,
3078                           ElementCmpLTInst>(I, cmpHelper);
3079 }
3080 
fwdElementCmpLTInst(ElementCmpLTInst const * I)3081 void BoundInterpreterFunction::fwdElementCmpLTInst(ElementCmpLTInst const *I) {
3082   auto *T = getTensor(I->getLHS());
3083   if (T->getType().isQuantizedType()) {
3084     fwdElementCmpLTInstImpl<int8_t, int32_t, float, int32_t>(I);
3085     return;
3086   }
3087 
3088   switch (T->getElementType()) {
3089   case ElemKind::FloatTy:
3090     fwdElementCmpLTInstImpl<float, float, float>(I);
3091     break;
3092   case ElemKind::Float16Ty:
3093     fwdElementCmpLTInstImpl<float16_t, float16_t, float16_t>(I);
3094     break;
3095   case ElemKind::BFloat16Ty:
3096     fwdElementCmpLTInstImpl<bfloat16_t, bfloat16_t, bfloat16_t>(I);
3097     break;
3098   case ElemKind::Int32ITy:
3099     fwdElementCmpLTInstImpl<int32_t, int32_t, float>(I);
3100     break;
3101   case ElemKind::Int64ITy:
3102     fwdElementCmpLTInstImpl<int64_t, int64_t, float>(I);
3103     break;
3104   default:
3105     llvm_unreachable("Type is not supported");
3106   }
3107 }
3108 
3109 template <typename ElemTy>
fwdElementPowInstFloatImpl(const ElementPowInst * I)3110 void BoundInterpreterFunction::fwdElementPowInstFloatImpl(
3111     const ElementPowInst *I) {
3112   staticAssertFloatingPointType(ElemTy);
3113 
3114   auto baseW = getWeightHandle<ElemTy>(I->getLHS());
3115   auto expW = getWeightHandle<ElemTy>(I->getRHS());
3116   auto outW = getWeightHandle<ElemTy>(I->getDest());
3117   for (size_t i = 0, e = outW.size(); i < e; i++) {
3118     outW.raw(i) = ElemTy(pow(float(baseW.raw(i)), float(expW.raw(i))));
3119   }
3120 }
3121 
fwdElementPowInst(const glow::ElementPowInst * I)3122 void BoundInterpreterFunction::fwdElementPowInst(
3123     const glow::ElementPowInst *I) {
3124   dispatchFloatingPointImpl(fwdElementPowInstFloatImpl,
3125                             I->getLHS()->getElementType(), I);
3126 }
3127 
3128 template <typename ElemTy>
fwdElementIsNaNInstFloatImpl(const ElementIsNaNInst * I)3129 void BoundInterpreterFunction::fwdElementIsNaNInstFloatImpl(
3130     const ElementIsNaNInst *I) {
3131   staticAssertFloatingPointType(ElemTy);
3132 
3133   auto inW = getWeightHandle<ElemTy>(I->getSrc());
3134   auto outW = getWeightHandle<bool>(I->getDest());
3135   for (size_t i = 0, e = inW.size(); i < e; i++) {
3136     float val = inW.raw(i);
3137     outW.raw(i) = std::isnan(val);
3138   }
3139 }
3140 
fwdElementIsNaNInst(const glow::ElementIsNaNInst * I)3141 void BoundInterpreterFunction::fwdElementIsNaNInst(
3142     const glow::ElementIsNaNInst *I) {
3143   dispatchFloatingPointImpl(fwdElementIsNaNInstFloatImpl,
3144                             I->getSrc()->getElementType(), I);
3145 }
3146 
3147 template <typename ElemTy>
fwdElementLogInstFloatImpl(const ElementLogInst * I)3148 void BoundInterpreterFunction::fwdElementLogInstFloatImpl(
3149     const ElementLogInst *I) {
3150   staticAssertFloatingPointType(ElemTy);
3151 
3152   auto inW = getWeightHandle<ElemTy>(I->getSrc());
3153   auto outW = getWeightHandle<ElemTy>(I->getDest());
3154   for (size_t i = 0, e = inW.size(); i < e; i++) {
3155     float val = inW.raw(i);
3156     outW.raw(i) = ElemTy(log(val));
3157   }
3158 }
3159 
fwdElementLogInst(const ElementLogInst * I)3160 void BoundInterpreterFunction::fwdElementLogInst(const ElementLogInst *I) {
3161   dispatchFloatingPointImpl(fwdElementLogInstFloatImpl,
3162                             I->getSrc()->getElementType(), I);
3163 }
3164 
3165 template <typename ElemTy>
fwdElementExpInstFloatImpl(const ElementExpInst * I)3166 void BoundInterpreterFunction::fwdElementExpInstFloatImpl(
3167     const ElementExpInst *I) {
3168   staticAssertFloatingPointType(ElemTy);
3169 
3170   auto inW = getWeightHandle<ElemTy>(I->getSrc());
3171   auto outW = getWeightHandle<ElemTy>(I->getDest());
3172   for (size_t i = 0, e = inW.size(); i < e; i++) {
3173     float val = inW.raw(i);
3174     outW.raw(i) = ElemTy(exp(val));
3175   }
3176 }
3177 
fwdElementExpInst(const ElementExpInst * I)3178 void BoundInterpreterFunction::fwdElementExpInst(const ElementExpInst *I) {
3179   dispatchFloatingPointImpl(fwdElementExpInstFloatImpl,
3180                             I->getSrc()->getElementType(), I);
3181 }
3182 
3183 template <typename ElemTy>
fwdElementSelectInstFloatImpl(const glow::ElementSelectInst * I)3184 void BoundInterpreterFunction::fwdElementSelectInstFloatImpl(
3185     const glow::ElementSelectInst *I) {
3186   staticAssertFloatingPointType(ElemTy);
3187 
3188   auto outW = getWeightHandle<ElemTy>(I->getDest());
3189   auto condW = getWeightHandle<bool>(I->getCond());
3190   auto lhsW = getWeightHandle<ElemTy>(I->getLHS());
3191   auto rhsW = getWeightHandle<ElemTy>(I->getRHS());
3192   for (size_t i = 0, e = outW.size(); i < e; i++) {
3193     outW.raw(i) = condW.raw(i) ? lhsW.raw(i) : rhsW.raw(i);
3194   }
3195 }
3196 
fwdElementSelectInst(const glow::ElementSelectInst * I)3197 void BoundInterpreterFunction::fwdElementSelectInst(
3198     const glow::ElementSelectInst *I) {
3199   if (getTensor(I->getLHS())->getType().isQuantizedType()) {
3200     auto destTy = I->getDest()->getType();
3201     auto lhsTy = I->getLHS()->getType();
3202     auto rhsTy = I->getRHS()->getType();
3203 
3204     float destScale = destTy->getScale();
3205     float lhsScale = lhsTy->getScale();
3206     float rhsScale = rhsTy->getScale();
3207 
3208     int32_t destOffset = destTy->getOffset();
3209     int32_t lhsOffset = lhsTy->getOffset();
3210     int32_t rhsOffset = rhsTy->getOffset();
3211 
3212     auto outW = getWeightHandle<int8_t>(I->getDest());
3213     auto condW = getWeightHandle<bool>(I->getCond());
3214     auto lhsW = getWeightHandle<int8_t>(I->getLHS());
3215     auto rhsW = getWeightHandle<int8_t>(I->getRHS());
3216     for (size_t i = 0, e = outW.size(); i < e; i++) {
3217       float val = condW.raw(i) ? lhsScale * (lhsW.raw(i) - lhsOffset)
3218                                : rhsScale * (rhsW.raw(i) - rhsOffset);
3219       int32_t q = std::round(val / destScale + destOffset);
3220       outW.raw(i) = quantization::clip<int32_t, int8_t>(q);
3221     }
3222     return;
3223   }
3224 
3225   dispatchFloatingPointImpl(fwdElementSelectInstFloatImpl,
3226                             I->getLHS()->getElementType(), I);
3227 }
3228 
3229 template <typename ElemTy>
fwdModuloInstImpl(glow::ModuloInst const * I)3230 void BoundInterpreterFunction::fwdModuloInstImpl(glow::ModuloInst const *I) {
3231   auto srcH = getTensor(I->getSrc())->getHandle<ElemTy>();
3232   auto destH = getTensor(I->getDest())->getHandle<ElemTy>();
3233 
3234   auto divisor = I->getDivisor();
3235   auto signFollowDivisor = I->getSignFollowDivisor();
3236 
3237   for (size_t i = 0, e = srcH.size(); i < e; i++) {
3238     auto res = srcH.raw(i) % divisor;
3239     if (signFollowDivisor && res < 0) {
3240       res += divisor;
3241     }
3242     destH.raw(i) = res;
3243   }
3244 }
3245 
fwdModuloInst(glow::ModuloInst const * I)3246 void BoundInterpreterFunction::fwdModuloInst(glow::ModuloInst const *I) {
3247   dispatchIndexTypeImpl(fwdModuloInstImpl, I->getSrc()->getElementType(), I);
3248 }
3249 
3250 //===----------------------------------------------------------------------===//
3251 //                       Mat Mul
3252 //===----------------------------------------------------------------------===//
3253 template <typename ElemTy, typename AccumulatorTy>
fwdMatMulInstQuantizedImpl(const glow::MatMulInst * I)3254 void BoundInterpreterFunction::fwdMatMulInstQuantizedImpl(
3255     const glow::MatMulInst *I) {
3256   assert(getTensor(I->getLHS())->getType().isQuantizedType());
3257   auto lhs = getWeightHandle<ElemTy>(I->getLHS());
3258   auto rhs = getWeightHandle<ElemTy>(I->getRHS());
3259 
3260   auto dest = getWeightHandle<ElemTy>(I->getDest());
3261 
3262   auto destDim = dest.dims();
3263   auto lhsDim = lhs.dims();
3264 
3265   auto destTy = I->getDest()->getType();
3266   auto lhsTy = I->getLHS()->getType();
3267   auto rhsTy = I->getRHS()->getType();
3268 
3269   dest.clear(0);
3270 
3271   // For matrix multiplication, if the offset is equal to zero the scale
3272   // is defined as the formula (L.scale * R.scale / D.scale).
3273   // In here we assume that the offset for all buffers is zero.
3274   float scale = lhsTy->getScale() * rhsTy->getScale() / destTy->getScale();
3275   int32_t lhsOffset = lhsTy->getOffset();
3276   int32_t rhsOffset = rhsTy->getOffset();
3277   int32_t destOffset = destTy->getOffset();
3278 
3279   // For each (x,y) in the destination matrix:
3280   for (dim_t x = 0; x < destDim[0]; x++) {
3281     for (dim_t y = 0; y < destDim[1]; y++) {
3282 
3283       // Perform DOT on the row an column.
3284       AccumulatorTy sum = 0;
3285       for (dim_t i = 0; i < lhsDim[1]; i++) {
3286         AccumulatorTy L = lhs.at({x, i});
3287         AccumulatorTy R = rhs.at({i, y});
3288         // We represent the element multiplication with offset as
3289         // (value - offset).
3290         sum += (L - lhsOffset) * (R - rhsOffset);
3291       }
3292 
3293       dest.at({x, y}) = quantization::clip<AccumulatorTy, ElemTy>(
3294           std::round(scale * sum + destOffset));
3295     }
3296   }
3297 }
3298 
3299 template <typename ElemTy>
fwdMatMulInstFloatImpl(const MatMulInst * I)3300 void BoundInterpreterFunction::fwdMatMulInstFloatImpl(const MatMulInst *I) {
3301   staticAssertFloatingPointType(ElemTy);
3302 
3303   auto lhs = getWeightHandle<ElemTy>(I->getLHS());
3304   auto rhs = getWeightHandle<ElemTy>(I->getRHS());
3305   auto dest = getWeightHandle<ElemTy>(I->getDest());
3306 
3307   auto destDim = dest.dims();
3308   auto lhsDim = lhs.dims();
3309 
3310   dest.clear(0);
3311 
3312   // For each (x,y) in the destination matrix:
3313   for (dim_t x = 0; x < destDim[0]; x++) {
3314     for (dim_t y = 0; y < destDim[1]; y++) {
3315 
3316       // Perform DOT on the row an column.
3317       float sum = 0;
3318       for (dim_t i = 0; i < lhsDim[1]; i++) {
3319         sum += float(lhs.at({x, i}) * rhs.at({i, y}));
3320       }
3321       dest.at({x, y}) = ElemTy(sum);
3322     }
3323   }
3324 }
3325 
fwdMatMulInst(const glow::MatMulInst * I)3326 void BoundInterpreterFunction::fwdMatMulInst(const glow::MatMulInst *I) {
3327   if (getTensor(I->getLHS())->getType().isQuantizedType()) {
3328     dispatchQuantizedWithAccumulationImpl(fwdMatMulInstQuantizedImpl,
3329                                           I->getLHS()->getElementType(), I);
3330     return;
3331   }
3332 
3333   dispatchFloatingPointImpl(fwdMatMulInstFloatImpl,
3334                             I->getLHS()->getElementType(), I);
3335 }
3336 
fwdBatchMatMulInst(const glow::BatchMatMulInst * I)3337 void BoundInterpreterFunction::fwdBatchMatMulInst(
3338     const glow::BatchMatMulInst *I) {
3339   DCHECK(!"Found BatchMatMulInst but BatchMatMul is lowered on Interpreter");
3340 }
3341 
fwdReluGradInst(const glow::ReluGradInst * I)3342 void BoundInterpreterFunction::fwdReluGradInst(const glow::ReluGradInst *I) {
3343   DCHECK(!"Found ReluGradInst but ReluGrad is lowered on Interpreter");
3344 }
3345 
3346 //===----------------------------------------------------------------------===//
3347 //                                 FC
3348 //===----------------------------------------------------------------------===//
3349 template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
fwdFullyConnectedInstQuantizedImpl(const glow::FullyConnectedInst * I)3350 void BoundInterpreterFunction::fwdFullyConnectedInstQuantizedImpl(
3351     const glow::FullyConnectedInst *I) {
3352   assert(getTensor(I->getSrc())->getType().isQuantizedType());
3353 
3354   auto inW = getWeightHandle<ElemTy>(I->getSrc());
3355   auto weightsW = getWeightHandle<ElemTy>(I->getWeights());
3356   auto biasW = getWeightHandle<BiasElemTy>(I->getBias());
3357   auto outW = getWeightHandle<ElemTy>(I->getDest());
3358 
3359   auto inTy = inW.getType();
3360   auto weightsTy = weightsW.getType();
3361   auto biasTy = biasW.getType();
3362   auto outTy = outW.getType();
3363 
3364   int32_t inOffset = inTy.getOffset();
3365   int32_t weightsOffset = weightsTy.getOffset();
3366   int32_t biasOffset = biasTy.getOffset();
3367   int32_t outOffset = outTy.getOffset();
3368 
3369   float outScale = outTy.getScale();
3370   float weightsScale = weightsTy.getScale();
3371   float biasScale = biasTy.getScale();
3372   float inScale = inTy.getScale();
3373 
3374   ShapeHW idim(inW.dims());
3375   ShapeHW odim(outW.dims());
3376 
3377   // Calculate the scale of the values that come out of the matrix
3378   // multiplication part of the calculation.
3379   float matMulScale = weightsScale * inScale;
3380 
3381   outW.clear(0);
3382 
3383   for (dim_t i = 0; i < idim.height; i++) {
3384     for (dim_t j = 0; j < odim.width; j++) {
3385       AccumulatorTy sum = 0;
3386       for (dim_t k = 0; k < idim.width; k++) {
3387         AccumulatorTy W = weightsW.at({k, j});
3388         AccumulatorTy A = inW.at({i, k});
3389         sum += (W - weightsOffset) * (A - inOffset);
3390       }
3391 
3392       // Scale the bias to match the scale of the matrix multiplication.
3393       AccumulatorTy B = std::round(float(biasW.at({j}) - biasOffset) *
3394                                    (biasScale / matMulScale));
3395 
3396       // Add the bias.
3397       sum += B;
3398 
3399       // Scale the result back to the expected destination scale.
3400       outW.at({i, j}) = quantization::clip<AccumulatorTy, ElemTy>(
3401           std::round(float(sum) * (matMulScale / outScale)) + outOffset);
3402     }
3403   }
3404 }
3405 
3406 template <typename ElemTy>
fwdFullyConnectedInstFloatImpl(const FullyConnectedInst * I)3407 void BoundInterpreterFunction::fwdFullyConnectedInstFloatImpl(
3408     const FullyConnectedInst *I) {
3409   staticAssertFloatingPointType(ElemTy);
3410 
3411   auto inW = getWeightHandle<ElemTy>(I->getSrc());
3412   auto weightsW = getWeightHandle<ElemTy>(I->getWeights());
3413   auto biasW = getWeightHandle<ElemTy>(I->getBias());
3414   auto outW = getWeightHandle<ElemTy>(I->getDest());
3415 
3416   ShapeHW idim(inW.dims());
3417   ShapeHW odim(outW.dims());
3418 
3419   outW.clear(0);
3420 
3421   for (dim_t i = 0; i < idim.height; i++) {
3422     for (dim_t j = 0; j < odim.width; j++) {
3423       float sum = 0;
3424       for (dim_t k = 0; k < idim.width; k++) {
3425         sum += float(inW.at({i, k})) * float(weightsW.at({k, j}));
3426       }
3427 
3428       outW.at({i, j}) = sum + float(biasW.at({j}));
3429     }
3430   }
3431 }
3432 
fwdFullyConnectedInst(const glow::FullyConnectedInst * I)3433 void BoundInterpreterFunction::fwdFullyConnectedInst(
3434     const glow::FullyConnectedInst *I) {
3435 
3436   if (getTensor(I->getSrc())->getType().isQuantizedType()) {
3437     dispatchQuantizedWithAccumulationAndBiasImpl(
3438         fwdFullyConnectedInstQuantizedImpl, I->getSrc()->getElementType(),
3439         I->getBias()->getElementType(), I);
3440     return;
3441   } else {
3442     dispatchFloatingPointImpl(fwdFullyConnectedInstFloatImpl,
3443                               I->getSrc()->getElementType(), I);
3444   }
3445 }
3446 
3447 //===----------------------------------------------------------------------===//
3448 //                       Row-wise quantized FC
3449 //===----------------------------------------------------------------------===//
3450 template <typename ElemTy, typename AccumulatorTy, typename BiasElemTy>
fwdRowwiseQuantizedFullyConnectedInstImpl(Value * inV,Value * outV,Value * weightsV,Value * biasV,Value * scalesV,Value * offsetsV)3451 void BoundInterpreterFunction::fwdRowwiseQuantizedFullyConnectedInstImpl(
3452     Value *inV, Value *outV, Value *weightsV, Value *biasV, Value *scalesV,
3453     Value *offsetsV) {
3454   auto inW = getWeightHandle<ElemTy>(inV);
3455   auto outW = getWeightHandle<ElemTy>(outV);
3456   auto weightsW = getWeightHandle<ElemTy>(weightsV);
3457   auto biasW = getWeightHandle<BiasElemTy>(biasV);
3458   auto scalesW = getWeightHandle<float>(scalesV);
3459   auto offsetsW = getWeightHandle<int32_t>(offsetsV);
3460   ShapeHW idim(inW.dims());
3461   ShapeHW odim(outW.dims());
3462   auto inTy = inW.getType();
3463   auto biasTy = biasW.getType();
3464   auto outTy = outW.getType();
3465   int32_t outOffset = outTy.getOffset();
3466   int32_t inOffset = inTy.getOffset();
3467   int32_t biasOffset = biasTy.getOffset();
3468   float outScale = outTy.getScale();
3469   float inScale = inTy.getScale();
3470   float biasScale = biasTy.getScale();
3471 
3472   for (dim_t i = 0; i < idim.height; i++) {
3473     for (dim_t j = 0; j < odim.width; j++) {
3474       float matMulScale = scalesW.raw(j) * inScale;
3475       AccumulatorTy sum = 0;
3476       for (dim_t k = 0; k < idim.width; k++) {
3477         AccumulatorTy W = weightsW.at({j, k});
3478         AccumulatorTy A = inW.at({i, k});
3479         sum += (W - offsetsW.raw(j)) * (A - inOffset);
3480       }
3481 
3482       // Scale the bias to match the scale of the matrix multiplication.
3483       AccumulatorTy B = std::round(float(biasW.at({j}) - biasOffset) *
3484                                    (biasScale / matMulScale));
3485 
3486       // Add the bias.
3487       sum += B;
3488 
3489       // Scale the result back to the expected destination scale.
3490       outW.at({i, j}) = quantization::clip<AccumulatorTy, ElemTy>(
3491           std::round(float(sum) * (matMulScale / outScale) + outOffset));
3492     }
3493   }
3494 }
3495 
fwdRowwiseQuantizedFullyConnectedInst(const RowwiseQuantizedFullyConnectedInst * I)3496 void BoundInterpreterFunction::fwdRowwiseQuantizedFullyConnectedInst(
3497     const RowwiseQuantizedFullyConnectedInst *I) {
3498   dispatchQuantizedWithAccumulationAndBiasImpl(
3499       fwdRowwiseQuantizedFullyConnectedInstImpl, I->getSrc()->getElementType(),
3500       I->getBias()->getElementType(), I->getSrc(), I->getDest(),
3501       I->getWeights(), I->getBias(), I->getScales(), I->getOffsets());
3502 }
3503 
3504 //===----------------------------------------------------------------------===//
3505 //                       Batched operations
3506 //===----------------------------------------------------------------------===//
3507 template <typename ElemTy, typename AccumulatorTy, typename SliceElemTy>
fwdBatchedAdd(Tensor * batch,Tensor * slice,Tensor * dest)3508 static void fwdBatchedAdd(Tensor *batch, Tensor *slice, Tensor *dest) {
3509   auto batchH = batch->getHandle<ElemTy>();
3510   auto sliceH = slice->getHandle<SliceElemTy>();
3511   auto destH = dest->getHandle<ElemTy>();
3512 
3513   auto batchTy = batch->getType();
3514   auto sliceTy = slice->getType();
3515   auto destTy = dest->getType();
3516 
3517   float sliceScale = sliceTy.getScale();
3518   float batchScale = batchTy.getScale();
3519   float destScale = destTy.getScale();
3520 
3521   int32_t sliceOffset = sliceTy.getOffset();
3522   int32_t batchOffset = batchTy.getOffset();
3523   int32_t destOffset = destTy.getOffset();
3524 
3525   auto bdim = flattenCdr(batchH.dims());
3526   assert(sliceH.size() == bdim.second && "Invalid slice size");
3527   assert(batchH.dims().drop_front() == sliceH.dims() && "Invalid batch size");
3528 
3529   // For each layer in the batch:
3530   for (dim_t n = 0; n < bdim.first; n++) {
3531     size_t base = batchH.getElementPtr({n});
3532 
3533     // For each element in the slice.
3534     for (dim_t i = 0; i < bdim.second; i++) {
3535       AccumulatorTy batchVal = batchH.raw(base + i);
3536       AccumulatorTy sliceVal = sliceH.raw(i);
3537       // We increase the size of the integer up to 16 bits for more accurate
3538       // arithmetic.
3539       const float largeScale = float(1) / (1 << 15);
3540       // Scale both sides from 8-bit to 16-bits.
3541       AccumulatorTy B =
3542           std::round(float(batchVal - batchOffset) * (batchScale / largeScale));
3543       AccumulatorTy S =
3544           std::round(float(sliceVal - sliceOffset) * (sliceScale / largeScale));
3545       AccumulatorTy R = B + S;
3546       destH.raw(base + i) = quantization::clip<AccumulatorTy, ElemTy>(
3547           std::round(float(R) * (largeScale / destScale) + destOffset));
3548     }
3549   }
3550 }
3551 
3552 template <typename ElemTy>
fwdBatchedAddInstFloatImpl(const glow::BatchedAddInst * I)3553 void BoundInterpreterFunction::fwdBatchedAddInstFloatImpl(
3554     const glow::BatchedAddInst *I) {
3555   staticAssertFloatingPointType(ElemTy);
3556 
3557   auto batch = getWeightHandle<ElemTy>(I->getBatch());
3558   auto slice = getWeightHandle<ElemTy>(I->getSlice());
3559   auto dest = getWeightHandle<ElemTy>(I->getDest());
3560 
3561   auto bdim = flattenCdr(batch.dims());
3562   assert(slice.size() == bdim.second && "Invalid slice size");
3563   assert(batch.dims().drop_front() == slice.dims() && "Invalid batch size");
3564 
3565   // For each layer in the batch:
3566   for (dim_t n = 0; n < bdim.first; n++) {
3567     size_t base = batch.getElementPtr({n});
3568 
3569     // For each element in the slice.
3570     for (dim_t i = 0; i < bdim.second; i++) {
3571       dest.raw(base + i) = batch.raw(base + i) + slice.raw(i);
3572     }
3573   }
3574 }
3575 
fwdBatchedAddInst(const glow::BatchedAddInst * I)3576 void BoundInterpreterFunction::fwdBatchedAddInst(
3577     const glow::BatchedAddInst *I) {
3578   if (getTensor(I->getBatch())->getType().isQuantizedType()) {
3579     dispatchQuantizedWithAccumulationAndBiasImpl(
3580         fwdBatchedAdd, I->getBatch()->getElementType(),
3581         I->getSlice()->getElementType(), getTensor(I->getBatch()),
3582         getTensor(I->getSlice()), getTensor(I->getDest()));
3583     return;
3584   }
3585   dispatchFloatingPointImpl(fwdBatchedAddInstFloatImpl,
3586                             I->getBatch()->getElementType(), I);
3587 }
3588 
3589 template <typename ElemTy>
fwdBatchedReduceAddInstFloatImpl(Value * batch,Value * dest,unsigned_t axis,const ShapeVector & eBatchDims,const ShapeVector & eDestDims)3590 void BoundInterpreterFunction::fwdBatchedReduceAddInstFloatImpl(
3591     Value *batch, Value *dest, unsigned_t axis, const ShapeVector &eBatchDims,
3592     const ShapeVector &eDestDims) {
3593   staticAssertFloatingPointType(ElemTy);
3594 
3595   // Get unowned handles of the batch and dest with these new expanded dims.
3596   auto eBatch = getTensor(batch)->getUnowned(eBatchDims);
3597   auto eDest = getTensor(dest)->getUnowned(eDestDims);
3598   auto eBatchH = eBatch.getHandle<ElemTy>();
3599   auto eDestH = eDest.getHandle<ElemTy>();
3600   eDestH.clear();
3601 
3602   // We can use this loop for all shapes. Use the same indices for both the
3603   // batch and dest, except for setting the axis index in the dest to 0.
3604   for (dim_t x = 0; x < eBatchDims[0]; x++) {
3605     for (dim_t y = 0; y < eBatchDims[1]; y++) {
3606       for (dim_t z = 0; z < eBatchDims[2]; z++) {
3607         for (dim_t w = 0; w < eBatchDims[3]; w++) {
3608           for (dim_t q = 0; q < eBatchDims[4]; q++) {
3609             for (dim_t r = 0; r < eBatchDims[5]; r++) {
3610               dim_t destIndices[] = {x, y, z, w, q, r};
3611               destIndices[axis] = 0;
3612               eDestH.at(destIndices) =
3613                   eDestH.at(destIndices) + eBatchH.at({x, y, z, w, q, r});
3614             }
3615           }
3616         }
3617       }
3618     }
3619   }
3620 }
3621 
fwdBatchedReduceAddInst(const glow::BatchedReduceAddInst * I)3622 void BoundInterpreterFunction::fwdBatchedReduceAddInst(
3623     const glow::BatchedReduceAddInst *I) {
3624   static_assert(max_tensor_dimensions == 6,
3625                 "Loops below assume max_tensor_dimensions = 6.");
3626 
3627   auto *batch = I->getBatch();
3628   auto *dest = I->getDest();
3629   const auto axis = I->getAxis();
3630 
3631   // Initialize both expanded batch and dest dims to the expanded batch
3632   // dims. This allows us below to iterate over the tensor regardless of its
3633   // shape using max_tensor_dimensions loops below.
3634   ShapeVector eBatchDims = expandDimsToMax(batch->dims());
3635   ShapeVector eDestDims = eBatchDims;
3636 
3637   // Set the destination axis dimension (the one we are reducing) to 1.
3638   eDestDims[axis] = 1;
3639 
3640   if (getTensor(batch)->getType().isQuantizedType()) {
3641     auto destTy = dest->getType();
3642     auto batchTy = batch->getType();
3643 
3644     float destScale = destTy->getScale();
3645     float batchScale = batchTy->getScale();
3646 
3647     int32_t destOffset = destTy->getOffset();
3648     int32_t batchOffset = batchTy->getOffset();
3649 
3650     // Get unowned handles of the batch and dest with these new expanded dims.
3651     auto eBatch = getTensor(batch)->getUnowned(eBatchDims);
3652     auto eDest = getTensor(dest)->getUnowned(eDestDims);
3653     auto eBatchH = eBatch.getHandle<int8_t>();
3654     auto eDestH = eDest.getHandle<int8_t>();
3655     eDestH.clear();
3656 
3657     // For quantization, we must accumulate in the inner-most loop into a local
3658     // float and then clip the result back into the dest tensor. Here are the
3659     // max_tensor_dimensions cases for this, to ensure the axis is used as the
3660     // inner-most loop.
3661     switch (axis) {
3662 #define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5_AXIS)                      \
3663   case _D5_AXIS:                                                               \
3664     for (dim_t i##_D0 = 0; i##_D0 < eBatchDims[_D0]; i##_D0++)                 \
3665       for (dim_t i##_D1 = 0; i##_D1 < eBatchDims[_D1]; i##_D1++)               \
3666         for (dim_t i##_D2 = 0; i##_D2 < eBatchDims[_D2]; i##_D2++)             \
3667           for (dim_t i##_D3 = 0; i##_D3 < eBatchDims[_D3]; i##_D3++)           \
3668             for (dim_t i##_D4 = 0; i##_D4 < eBatchDims[_D4]; i##_D4++) {       \
3669               float sum = 0.0;                                                 \
3670               for (dim_t i##_D5_AXIS = 0; i##_D5_AXIS < eBatchDims[_D5_AXIS];  \
3671                    i##_D5_AXIS++) {                                            \
3672                 sum += eBatchH.at({i0, i1, i2, i3, i4, i5}) - batchOffset;     \
3673               }                                                                \
3674               dim_t i##_D5_AXIS = 0;                                           \
3675               int32_t res =                                                    \
3676                   std::round(sum * batchScale / destScale) + destOffset;       \
3677               eDestH.at({i0, i1, i2, i3, i4, i5}) =                            \
3678                   quantization::clip<int32_t, int8_t>(res);                    \
3679             }                                                                  \
3680     return;
3681 
3682       // Each loop order, with the inner-most dimension/index equal to the axis.
3683       LOOP_AXIS_CASE(1, 2, 3, 4, 5, 0);
3684       LOOP_AXIS_CASE(0, 2, 3, 4, 5, 1);
3685       LOOP_AXIS_CASE(0, 1, 3, 4, 5, 2);
3686       LOOP_AXIS_CASE(0, 1, 2, 4, 5, 3);
3687       LOOP_AXIS_CASE(0, 1, 2, 3, 5, 4);
3688       LOOP_AXIS_CASE(0, 1, 2, 3, 4, 5);
3689 #undef LOOP_AXIS_CASE
3690     default:
3691       llvm_unreachable("Axis should be less than max_tensor_dimensions.");
3692     }
3693   }
3694   dispatchFloatingPointImpl(fwdBatchedReduceAddInstFloatImpl,
3695                             batch->getElementType(), batch, dest, axis,
3696                             eBatchDims, eDestDims);
3697 }
3698 
3699 template <typename ElemTy>
fwdBatchedReduceMinInstImpl(Value * batch,Value * dest,const ShapeVector & eBatchDims,const ShapeVector & eDestDims,ElemTy max)3700 void BoundInterpreterFunction::fwdBatchedReduceMinInstImpl(
3701     Value *batch, Value *dest, const ShapeVector &eBatchDims,
3702     const ShapeVector &eDestDims, ElemTy max) {
3703   static_assert(max_tensor_dimensions == 6,
3704                 "Loops below assume max_tensor_dimensions = 6.");
3705   // Get unowned handles of the batch and dest with these new expanded dims.
3706   auto eBatch = getTensor(batch)->getUnowned(eBatchDims);
3707   auto eDest = getTensor(dest)->getUnowned(eDestDims);
3708   auto eBatchH = eBatch.getHandle<ElemTy>();
3709   auto eDestH = eDest.getHandle<ElemTy>();
3710   eDestH.clear(max);
3711 
3712   unsigned int axes[max_tensor_dimensions];
3713   for (dim_t i = 0; i < max_tensor_dimensions; i++) {
3714     axes[i] = (eDestDims[i] > 1);
3715   }
3716 
3717   // We can use this loop for all shapes. Use the same indices for both the
3718   // batch and dest, except for setting the axis index in the dest to 0.
3719   for (dim_t x = 0, dx = 0; x < eBatchDims[0]; x++, dx += axes[0]) {
3720     for (dim_t y = 0, dy = 0; y < eBatchDims[1]; y++, dy += axes[1]) {
3721       for (dim_t z = 0, dz = 0; z < eBatchDims[2]; z++, dz += axes[2]) {
3722         for (dim_t w = 0, dw = 0; w < eBatchDims[3]; w++, dw += axes[3]) {
3723           for (dim_t q = 0, dq = 0; q < eBatchDims[4]; q++, dq += axes[4]) {
3724             for (dim_t r = 0, dr = 0; r < eBatchDims[5]; r++, dr += axes[5]) {
3725               dim_t destIndices[] = {dx, dy, dz, dw, dq, dr};
3726               dim_t srcIndices[] = {x, y, z, w, q, r};
3727               eDestH.at(destIndices) =
3728                   eDestH.at(destIndices) < eBatchH.at(srcIndices)
3729                       ? eDestH.at(destIndices)
3730                       : eBatchH.at(srcIndices);
3731             }
3732           }
3733         }
3734       }
3735     }
3736   }
3737 }
3738 
fwdBatchedReduceMinInst(const glow::BatchedReduceMinInst * I)3739 void BoundInterpreterFunction::fwdBatchedReduceMinInst(
3740     const glow::BatchedReduceMinInst *I) {
3741 
3742   auto *batch = I->getBatch();
3743   auto *dest = I->getDest();
3744   const auto axes = I->getAxes();
3745 
3746   // Initialize both expanded batch and dest dims to the expanded batch
3747   // dims. This allows us below to iterate over the tensor regardless of its
3748   // shape using max_tensor_dimensions loops below.
3749   ShapeVector eBatchDims = expandDimsToMax(batch->dims());
3750   ShapeVector eDestDims = eBatchDims;
3751   // Set the destination axes dimensions (the one we are reducing) to 1.
3752   for (dim_t i = 0; i < axes.size(); i++) {
3753     eDestDims[axes[i]] = 1;
3754   }
3755 
3756   dispatchArithmeticImpl(fwdBatchedReduceMinInstImpl, batch->getElementType(),
3757                          batch, dest, eBatchDims, eDestDims,
3758                          std::numeric_limits<int32_t>::max());
3759 }
3760 
3761 template <typename ElemTy>
fwdCumSumInstImpl(Value * input,Value * dest,bool exclusive,bool reverse)3762 void BoundInterpreterFunction::fwdCumSumInstImpl(Value *input, Value *dest,
3763                                                  bool exclusive, bool reverse) {
3764   auto *eInput = getTensor(input);
3765   auto *eDest = getTensor(dest);
3766   auto eInputH = eInput->getHandle<ElemTy>();
3767   auto eDestH = eDest->getHandle<ElemTy>();
3768   eDestH.clear();
3769 
3770   ElemTy accum = 0;
3771 
3772   sdim_t s = 0;
3773   sdim_t n = eDestH.size();
3774   sdim_t dir = 1;
3775 
3776   if (reverse) {
3777     s = n - 1;
3778     n = -1;
3779     dir = -1;
3780   }
3781 
3782   for (sdim_t i = s; i != n; i += dir) {
3783     if (!exclusive) {
3784       accum += eInputH.at(i);
3785     }
3786     eDestH.at(i) = accum;
3787     if (exclusive) {
3788       accum += eInputH.at(i);
3789     }
3790   }
3791 }
3792 
fwdCumSumInst(glow::CumSumInst const * I)3793 void BoundInterpreterFunction::fwdCumSumInst(glow::CumSumInst const *I) {
3794   dispatchArithmeticImpl(fwdCumSumInstImpl, I->getInput()->getElementType(),
3795                          I->getInput(), I->getDest(), I->getExclusive(),
3796                          I->getReverse());
3797 }
3798 
3799 template <typename ElemTy>
fwdLengthsSumInstFloatImpl(const LengthsSumInst * I)3800 void BoundInterpreterFunction::fwdLengthsSumInstFloatImpl(
3801     const LengthsSumInst *I) {
3802   staticAssertFloatingPointType(ElemTy);
3803 
3804   auto out = getTensor(I->getDest());
3805   auto data = getTensor(I->getData());
3806   auto lengths = getTensor(I->getLengths());
3807 
3808   out->zero();
3809 
3810   auto LH = lengths->getHandle<int32_t>();
3811 
3812   size_t segments = lengths->dims()[0];
3813   size_t sliceSize = data->size() / data->dims()[0];
3814 
3815   auto DH = data->getHandle<ElemTy>();
3816   auto OH = out->getHandle<ElemTy>();
3817 
3818   size_t offsetIn = 0;
3819   size_t offsetOut = 0;
3820   for (dim_t i = 0; i < segments; i++) {
3821     for (int32_t j = 0, e = LH.raw(i); j < e; j++) {
3822       for (dim_t k = 0; k < sliceSize; k++) {
3823         OH.raw(offsetOut + k) += DH.raw(offsetIn + k);
3824       }
3825       offsetIn += sliceSize;
3826     }
3827     offsetOut += sliceSize;
3828   }
3829 
3830   assert(offsetIn == data->size() && "All values in Data should be consumed");
3831   assert(offsetOut == out->size() && "All values in Dest should be written to");
3832 }
3833 
fwdLengthsSumInst(const LengthsSumInst * I)3834 void BoundInterpreterFunction::fwdLengthsSumInst(const LengthsSumInst *I) {
3835   dispatchFloatingPointImpl(fwdLengthsSumInstFloatImpl,
3836                             I->getData()->getElementType(), I)
3837 }
3838 
3839 template <typename TI>
fwdSparseLengthsSumInstI8Impl(const SparseLengthsSumInst * I)3840 void BoundInterpreterFunction::fwdSparseLengthsSumInstI8Impl(
3841     const SparseLengthsSumInst *I) {
3842 
3843   auto out = getTensor(I->getDest());
3844   auto data = getTensor(I->getData());
3845   auto indices = getTensor(I->getIndices());
3846   auto lengths = getTensor(I->getLengths());
3847 
3848   out->zero();
3849 
3850   auto IH = indices->getHandle<TI>();
3851   auto LH = lengths->getHandle<int32_t>();
3852 
3853   size_t segments = lengths->dims()[0];
3854   size_t totalLength = 0;
3855   for (size_t i = 0; i < segments; i++) {
3856     totalLength += LH.raw(i);
3857   }
3858   assert(totalLength <= indices->dims()[0] &&
3859          "sum(Lengths) must be equal to len(Indices)");
3860 
3861   size_t lineSize = data->size() / data->dims()[0];
3862 
3863   auto DH = data->getHandle<int8_t>();
3864   auto OH = out->getHandle<int8_t>();
3865 
3866   auto TQP = [](Tensor *T) {
3867     return TensorQuantizationParams{T->getType().getScale(),
3868                                     T->getType().getOffset()};
3869   };
3870 
3871   size_t curIdx = 0;
3872   for (size_t i = 0; i < segments; i++) {
3873     std::vector<float> accum(lineSize, 0.0f);
3874     for (int32_t j = 0; j < LH.raw(i); j++) {
3875       size_t offsetIn = IH.raw(curIdx) * lineSize;
3876       for (size_t k = 0; k < lineSize; k++) {
3877         accum[k] += quantization::dequantize(DH.raw(offsetIn++), TQP(data));
3878       }
3879       curIdx++;
3880     }
3881     size_t offsetOut = i * lineSize;
3882     for (size_t k = 0; k < lineSize; k++) {
3883       OH.raw(offsetOut++) = quantization::quantize(accum[k], TQP(out));
3884     }
3885   }
3886 }
3887 
3888 template <typename ElemTy, typename TI>
fwdSparseLengthsSumInstFloatImpl(const SparseLengthsSumInst * I)3889 void BoundInterpreterFunction::fwdSparseLengthsSumInstFloatImpl(
3890     const SparseLengthsSumInst *I) {
3891   staticAssertFloatingPointType(ElemTy);
3892 
3893   auto out = getTensor(I->getDest());
3894   auto data = getTensor(I->getData());
3895   auto indices = getTensor(I->getIndices());
3896   auto lengths = getTensor(I->getLengths());
3897 
3898   out->zero();
3899 
3900   auto IH = indices->getHandle<TI>();
3901   auto LH = lengths->getHandle<int32_t>();
3902 
3903   size_t segments = lengths->dims()[0];
3904   size_t totalLength = 0;
3905   for (size_t i = 0; i < segments; i++) {
3906     totalLength += LH.raw(i);
3907   }
3908   assert(totalLength <= indices->dims()[0] &&
3909          "sum(Lengths) must be equal to len(Indices)");
3910 
3911   size_t lineSize = data->size() / data->dims()[0];
3912 
3913   auto DH = data->getHandle<ElemTy>();
3914   auto OH = out->getHandle<ElemTy>();
3915 
3916   size_t curIdx = 0;
3917   for (size_t i = 0; i < segments; i++) {
3918     for (size_t j = 0, e = LH.raw(i); j < e; j++) {
3919       size_t offsetIn = IH.raw(curIdx++) * lineSize;
3920       size_t offsetOut = i * lineSize;
3921       for (size_t k = 0; k < lineSize; k++)
3922         OH.raw(offsetOut++) += DH.raw(offsetIn++);
3923     }
3924   }
3925 }
3926 
fwdSparseLengthsSumInst(const SparseLengthsSumInst * I)3927 void BoundInterpreterFunction::fwdSparseLengthsSumInst(
3928     const SparseLengthsSumInst *I) {
3929   if (I->getDest()->getType()->isQuantizedType()) {
3930     dispatchIndexTypeImpl(fwdSparseLengthsSumInstI8Impl,
3931                           I->getIndices()->getElementType(), I);
3932     return;
3933   }
3934   dispatchFloatingPointAndIndexImpl(fwdSparseLengthsSumInstFloatImpl,
3935                                     I->getData()->getElementType(),
3936                                     I->getIndices()->getElementType(), I);
3937 }
3938 
3939 template <typename ElemTy, typename TI>
fwdSparseLengthsWeightedSumInstFloatImpl(const SparseLengthsWeightedSumInst * I)3940 void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInstFloatImpl(
3941     const SparseLengthsWeightedSumInst *I) {
3942   staticAssertFloatingPointType(ElemTy);
3943 
3944   auto out = getTensor(I->getDest());
3945   auto data = getTensor(I->getData());
3946   auto weights = getTensor(I->getWeights());
3947   auto indices = getTensor(I->getIndices());
3948   auto lengths = getTensor(I->getLengths());
3949 
3950   out->zero();
3951 
3952   auto IH = indices->getHandle<TI>();
3953   auto LH = lengths->getHandle<int32_t>();
3954 
3955   size_t segments = lengths->dims()[0];
3956   size_t totalLength = 0;
3957   for (dim_t i = 0; i < segments; i++) {
3958     totalLength += LH.raw(i);
3959   }
3960   assert(totalLength <= indices->dims()[0] &&
3961          "sum(Lengths) must be equal to len(Indices)");
3962 
3963   dim_t lineSize = data->size() / data->dims()[0];
3964 
3965   auto DH = data->getHandle<ElemTy>();
3966   auto WH = weights->getHandle<ElemTy>();
3967   auto OH = out->getHandle<ElemTy>();
3968 
3969   dim_t curIdx = 0;
3970   for (dim_t i = 0; i < segments; i++) {
3971     for (dim_t j = 0, e = LH.raw(i); j < e; j++) {
3972       ElemTy weight = WH.raw(curIdx);
3973       size_t offsetIn = IH.raw(curIdx++) * lineSize;
3974       size_t offsetOut = i * lineSize;
3975       for (dim_t k = 0; k < lineSize; k++)
3976         OH.raw(offsetOut++) += DH.raw(offsetIn++) * weight;
3977     }
3978   }
3979 }
3980 
3981 template <typename TI>
fwdSparseLengthsWeightedSumInstI8Impl(const SparseLengthsWeightedSumInst * I)3982 void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInstI8Impl(
3983     const SparseLengthsWeightedSumInst *I) {
3984 
3985   auto out = getTensor(I->getDest());
3986   auto data = getTensor(I->getData());
3987   auto weights = getTensor(I->getWeights());
3988   auto indices = getTensor(I->getIndices());
3989   auto lengths = getTensor(I->getLengths());
3990 
3991   out->zero();
3992 
3993   auto IH = indices->getHandle<TI>();
3994   auto LH = lengths->getHandle<int32_t>();
3995 
3996   dim_t segments = lengths->dims()[0];
3997   dim_t totalLength = 0;
3998   for (dim_t i = 0; i < segments; i++) {
3999     totalLength += LH.raw(i);
4000   }
4001   assert(totalLength <= indices->dims()[0] &&
4002          "sum(Lengths) must be equal to len(Indices)");
4003 
4004   dim_t lineSize = data->size() / data->dims()[0];
4005 
4006   auto DH = data->getHandle<int8_t>();
4007   auto WH = weights->getHandle<int8_t>();
4008   auto OH = out->getHandle<int8_t>();
4009 
4010   auto TQP = [](Tensor *T) {
4011     return TensorQuantizationParams{T->getType().getScale(),
4012                                     T->getType().getOffset()};
4013   };
4014   using namespace quantization;
4015 
4016   dim_t curIdx = 0;
4017   for (dim_t i = 0; i < segments; i++) {
4018     std::vector<float> accum(lineSize, 0.0f);
4019     for (int32_t j = 0; j < LH.raw(i); j++) {
4020       float weight = dequantize(WH.raw(curIdx), TQP(weights));
4021       size_t offsetIn = IH.raw(curIdx) * lineSize;
4022       for (dim_t k = 0; k < lineSize; k++) {
4023         accum[k] += weight * dequantize(DH.raw(offsetIn++), TQP(data));
4024       }
4025       curIdx++;
4026     }
4027     dim_t offsetOut = i * lineSize;
4028     for (dim_t k = 0; k < lineSize; k++) {
4029       OH.raw(offsetOut++) = quantize(accum[k], TQP(out));
4030     }
4031   }
4032 }
4033 
fwdSparseLengthsSumGradInst(const SparseLengthsSumGradInst *)4034 void BoundInterpreterFunction::fwdSparseLengthsSumGradInst(
4035     const SparseLengthsSumGradInst * /*I*/) {
4036   DCHECK(!"Found SparseLengthsSumGradInst but SparseLengthsSum is lowered on "
4037           "Interpreter");
4038 }
4039 
fwdSparseLengthsWeightedSumInst(const SparseLengthsWeightedSumInst * I)4040 void BoundInterpreterFunction::fwdSparseLengthsWeightedSumInst(
4041     const SparseLengthsWeightedSumInst *I) {
4042   if (I->getDest()->getType()->isQuantizedType()) {
4043     dispatchIndexTypeImpl(fwdSparseLengthsWeightedSumInstI8Impl,
4044                           I->getIndices()->getElementType(), I);
4045     return;
4046   }
4047   dispatchFloatingPointAndIndexImpl(fwdSparseLengthsWeightedSumInstFloatImpl,
4048                                     I->getData()->getElementType(),
4049                                     I->getIndices()->getElementType(), I);
4050 }
4051 
fwdSparseLengthsWeightedSumGradInst(const SparseLengthsWeightedSumGradInst * I)4052 void BoundInterpreterFunction::fwdSparseLengthsWeightedSumGradInst(
4053     const SparseLengthsWeightedSumGradInst *I) {
4054   assert(I->getDataGrad()->getType()->getElementType() == ElemKind::FloatTy &&
4055          "Input type must be float");
4056 
4057   auto destGrad = getTensor(I->getDestGrad());
4058   auto data = getTensor(I->getData());
4059   auto dataGrad = getTensor(I->getDataGrad());
4060   auto weightsGrad = getTensor(I->getWeightsGrad());
4061   auto weights = getTensor(I->getWeights());
4062   auto indices = getTensor(I->getIndices());
4063   auto lengths = getTensor(I->getLengths());
4064 
4065   // The data gradients not touched by this operation should
4066   // be 0, so set the entire buffer to 0 to start with.
4067   dataGrad->zero();
4068 
4069   auto LH = lengths->getHandle<int32_t>();
4070   auto IH = indices->getHandle<int64_t>();
4071 
4072   size_t segments = lengths->dims()[0];
4073   size_t totalLength = 0;
4074   for (size_t i = 0; i < segments; ++i) {
4075     totalLength += LH.raw(i);
4076   }
4077   assert(totalLength == indices->dims()[0] &&
4078          "sum(Lengths) must be equal to len(Indices)");
4079 
4080   size_t lineSize = dataGrad->size() / dataGrad->dims()[0];
4081 
4082   auto IGH = destGrad->getHandle();
4083   auto WH = weights->getHandle();
4084   auto WGH = weightsGrad->getHandle();
4085   auto DH = data->getHandle();
4086   auto OGH = dataGrad->getHandle();
4087 
4088   // For each index in each segment:
4089   //    1) accumulate into the corresponding data gradient the product of the
4090   //    gradient of the result it was added to and the weight that it was
4091   //    multiplied by during the SparseLengthsWeightedSum operation.
4092   //
4093   //    2) accumulate into each weight gradient the reduced sum of the
4094   //    elementwise product of the result slice that the corresponding weight
4095   //    produced and the input slice that the weight was multiplied with.
4096   for (size_t i = 0, curIdx = 0; i < segments; ++i) {
4097     size_t destOffset = i * lineSize;
4098     for (size_t j = 0, e = LH.raw(i); j < e; ++j, ++curIdx) {
4099       float weightGrad = 0.0f;
4100       float weight = WH.raw(curIdx);
4101       size_t dataOffset = IH.raw(curIdx) * lineSize;
4102 
4103       for (size_t k = 0; k < lineSize; ++k) {
4104         OGH.raw(dataOffset + k) += IGH.raw(destOffset + k) * weight;
4105         weightGrad += IGH.raw(destOffset + k) * DH.raw(dataOffset + k);
4106       }
4107 
4108       WGH.raw(curIdx) = weightGrad;
4109     }
4110   }
4111 }
4112 
4113 template <typename ElemTy>
fwdEmbeddingBagInstFloatImpl(const EmbeddingBagInst * I)4114 void BoundInterpreterFunction::fwdEmbeddingBagInstFloatImpl(
4115     const EmbeddingBagInst *I) {
4116   staticAssertFloatingPointType(ElemTy);
4117 
4118   auto out = getTensor(I->getDest());
4119   auto data = getTensor(I->getData());
4120   auto weights = getTensor(I->getWeights());
4121   auto indices = getTensor(I->getIndices());
4122   auto offsets = getTensor(I->getOffsets());
4123   bool hasEndOffset = I->getHasEndOffset();
4124 
4125   out->zero();
4126 
4127   auto IH = indices->getHandle<int64_t>();
4128   auto OFFH = offsets->getHandle<int64_t>();
4129 
4130   // If an end offset is present to mark the end of the last segment then this
4131   // must be subtracted to get the correct number of segments
4132   size_t segments = hasEndOffset ? offsets->dims()[0] - 1 : offsets->dims()[0];
4133   size_t numIndices = indices->dims()[0];
4134 
4135   size_t lineSize = data->size() / data->dims()[0];
4136 
4137   auto DH = data->getHandle<ElemTy>();
4138   auto WH = weights->getHandle<ElemTy>();
4139   auto OH = out->getHandle<ElemTy>();
4140 
4141   dim_t curIdx = 0;
4142   for (dim_t i = 0; i < segments; i++) {
4143     dim_t start = OFFH.raw(i);
4144     dim_t end;
4145     if (!hasEndOffset) {
4146       // Note that in this case we have to use numIndices to find the end of
4147       // the last segment. This is an issue though because it relies on knowing
4148       // the total length of the indices tensor which may not be possible.
4149       // Future implementations of this operator should always give an end
4150       // offset so eventually this case should be removed.
4151       end = i == segments - 1 ? numIndices : OFFH.raw(i + 1);
4152     } else {
4153       end = OFFH.raw(i + 1);
4154     }
4155     if (start == end) {
4156       continue;
4157     } else if (start > end) {
4158       break;
4159     }
4160     for (dim_t j = start; j < end; j++) {
4161       ElemTy weight = WH.raw(curIdx);
4162       dim_t offsetIn = IH.raw(curIdx++) * lineSize;
4163       dim_t offsetOut = i * lineSize;
4164       for (dim_t k = 0; k < lineSize; k++) {
4165         OH.raw(offsetOut++) += DH.raw(offsetIn++) * weight;
4166       }
4167     }
4168   }
4169 }
4170 
fwdEmbeddingBagInst(const EmbeddingBagInst * I)4171 void BoundInterpreterFunction::fwdEmbeddingBagInst(const EmbeddingBagInst *I) {
4172   dispatchFloatingPointImpl(fwdEmbeddingBagInstFloatImpl,
4173                             I->getData()->getElementType(), I);
4174 }
4175 
4176 template <typename T, typename AccumT, typename TI>
fwdRowwiseQuantizedSparseLengthsWeightedSumImpl(const RowwiseQuantizedSparseLengthsWeightedSumInst * I)4177 void BoundInterpreterFunction::fwdRowwiseQuantizedSparseLengthsWeightedSumImpl(
4178     const RowwiseQuantizedSparseLengthsWeightedSumInst *I) {
4179   auto *out = getTensor(I->getDest());
4180   auto *data = getTensor(I->getData());
4181   auto *dataScales = getTensor(I->getScales());
4182   auto *dataOffsets = getTensor(I->getOffsets());
4183   auto *weights = getTensor(I->getWeights());
4184   auto *indices = getTensor(I->getIndices());
4185   auto *lengths = getTensor(I->getLengths());
4186 
4187   out->zero();
4188 
4189   auto IH = indices->getHandle<TI>();
4190   auto LH = lengths->getHandle<int32_t>();
4191 
4192   dim_t segments = lengths->dims()[0];
4193   dim_t totalLength = 0;
4194   for (dim_t i = 0; i < segments; i++) {
4195     totalLength += LH.raw(i);
4196   }
4197   assert(totalLength <= indices->dims()[0] &&
4198          "sum(Lengths) must be equal to len(Indices)");
4199 
4200   dim_t lineSize = data->size() / data->dims()[0];
4201 
4202   auto DH = data->getHandle<uint8_t>();
4203   auto DSH = dataScales->getHandle<T>();
4204   auto DOH = dataOffsets->getHandle<T>();
4205   auto WH = weights->getHandle<T>();
4206   auto OH = out->getHandle<T>();
4207 
4208   dim_t curIdx = 0;
4209   for (dim_t i = 0; i < segments; i++) {
4210     std::vector<AccumT> accum(lineSize, 0.0f);
4211     for (dim_t j = 0, e = LH.raw(i); j < e; j++) {
4212       const float weight = static_cast<float>(WH.raw(curIdx));
4213       const dim_t rowIdx = IH.raw(curIdx++);
4214       const float scale = static_cast<float>(DSH.at({rowIdx}));
4215       const float offset = static_cast<float>(DOH.at({rowIdx}));
4216       size_t offsetIn = rowIdx * lineSize;
4217       for (dim_t k = 0; k < lineSize; k++) {
4218         float d = quantization::dequantizeWithFloatOffset(DH.raw(offsetIn++),
4219                                                           scale, offset);
4220         accum[k] += d * weight;
4221       }
4222     }
4223     // Accumulation in FP32 complete, now copy back to output with cast to T.
4224     size_t offsetOut = i * lineSize;
4225     for (size_t k = 0; k < lineSize; k++) {
4226       OH.raw(offsetOut++) = static_cast<T>(accum[k]);
4227     }
4228   }
4229 }
4230 
fwdRowwiseQuantizedSparseLengthsWeightedSumInst(const RowwiseQuantizedSparseLengthsWeightedSumInst * I)4231 void BoundInterpreterFunction::fwdRowwiseQuantizedSparseLengthsWeightedSumInst(
4232     const RowwiseQuantizedSparseLengthsWeightedSumInst *I) {
4233   const auto ity = I->getIndices()->getElementType();
4234   switch (I->getDest()->getElementType()) {
4235   case ElemKind::FloatTy:
4236     if (ity == ElemKind::Int32ITy) {
4237       fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float, int32_t>(I);
4238     } else if (ity == ElemKind::Int64ITy) {
4239       fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float, int64_t>(I);
4240     } else {
4241       llvm_unreachable("Index type is not supported");
4242     }
4243     break;
4244   case ElemKind::Float16Ty:
4245     if (I->getUseFP16Accumulation()) {
4246       if (ity == ElemKind::Int32ITy) {
4247         fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float16_t,
4248                                                         int32_t>(I);
4249       } else if (ity == ElemKind::Int64ITy) {
4250         fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float16_t,
4251                                                         int64_t>(I);
4252       } else {
4253         llvm_unreachable("Index type is not supported");
4254       }
4255     } else {
4256       if (ity == ElemKind::Int32ITy) {
4257         fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float,
4258                                                         int32_t>(I);
4259       } else if (ity == ElemKind::Int64ITy) {
4260         fwdRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float,
4261                                                         int64_t>(I);
4262       } else {
4263         llvm_unreachable("Index type is not supported");
4264       }
4265     }
4266     break;
4267   default:
4268     llvm_unreachable("Type is not supported");
4269   }
4270 }
4271 
4272 template <typename T, typename AccumT, typename TI>
4273 void BoundInterpreterFunction::
fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl(const FusedRowwiseQuantizedSparseLengthsWeightedSumInst * I)4274     fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl(
4275         const FusedRowwiseQuantizedSparseLengthsWeightedSumInst *I) {
4276   auto *out = getTensor(I->getDest());
4277   auto *data = getTensor(I->getData());
4278   auto *weights = getTensor(I->getWeights());
4279   auto *indices = getTensor(I->getIndices());
4280   auto *lengths = getTensor(I->getLengths());
4281 
4282   out->zero();
4283 
4284   auto IH = indices->getHandle<TI>();
4285   auto LH = lengths->getHandle<int32_t>();
4286 
4287   size_t segments = lengths->dims()[0];
4288   size_t totalLength = 0;
4289   for (size_t i = 0; i < segments; i++) {
4290     totalLength += LH.raw(i);
4291   }
4292   assert(totalLength <= indices->dims()[0] &&
4293          "sum(Lengths) must be equal to len(Indices)");
4294 
4295   const bool using4BitQuantization =
4296       data->getType().getElementType() == ElemKind::UInt4FusedFP16QTy;
4297 
4298   const size_t outLineSize = out->size() / out->dims()[0];
4299 
4300   auto DH = data->getHandle<uint8_t>();
4301   auto WH = weights->getHandle<T>();
4302   auto OH = out->getHandle<T>();
4303 
4304   dim_t curIdx = 0;
4305   for (dim_t i = 0; i < segments; i++) {
4306     std::vector<AccumT> accum(outLineSize, 0.0f);
4307     for (dim_t j = 0, e = LH.raw(i); j < e; j++) {
4308       const float weight = static_cast<float>(WH.raw(curIdx));
4309       const dim_t rowIdx = IH.raw(curIdx++);
4310       // Data type for the Scale and Offset for fused types need not follow
4311       // the type for the output Tensor passed in T.
4312       float scale, offset;
4313       switch (
4314           getScaleOffsetElemKindFromFused(data->getType().getElementType())) {
4315       case ElemKind::FloatTy:
4316         std::tie(scale, offset) = DH.getFusedScaleOffsetFromRow<float>(rowIdx);
4317         break;
4318       case ElemKind::Float16Ty:
4319         std::tie(scale, offset) =
4320             DH.getFusedScaleOffsetFromRow<float16_t>(rowIdx);
4321         break;
4322       default:
4323         llvm_unreachable("Type is not supported");
4324         break;
4325       }
4326 
4327       for (dim_t k = 0; k < outLineSize; k++) {
4328         float d = 0.0f;
4329         if (!using4BitQuantization) {
4330           d = quantization::dequantizeWithFloatOffset(
4331               DH.at({rowIdx, k}), static_cast<float>(scale),
4332               static_cast<float>(offset));
4333         } else {
4334           const bool isMSB = (k % 2 == 1);
4335           d = quantization::dequantize4BitWithFloatOffset(
4336               DH.at({rowIdx, k / 2}), static_cast<float>(scale),
4337               static_cast<float>(offset), isMSB);
4338         }
4339         accum[k] += d * weight;
4340       }
4341     }
4342     // Accumulation in FP32 complete, now copy back to output with cast to T.
4343     dim_t offsetOut = i * outLineSize;
4344     for (dim_t k = 0; k < outLineSize; k++) {
4345       OH.raw(offsetOut++) = static_cast<T>(accum[k]);
4346     }
4347   }
4348 }
4349 
4350 void BoundInterpreterFunction::
fwdFusedRowwiseQuantizedSparseLengthsWeightedSumInst(const FusedRowwiseQuantizedSparseLengthsWeightedSumInst * I)4351     fwdFusedRowwiseQuantizedSparseLengthsWeightedSumInst(
4352         const FusedRowwiseQuantizedSparseLengthsWeightedSumInst *I) {
4353   const auto ity = I->getIndices()->getElementType();
4354   switch (I->getDest()->getElementType()) {
4355   case ElemKind::FloatTy:
4356     if (ity == ElemKind::Int32ITy) {
4357       fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float,
4358                                                            int32_t>(I);
4359     } else if (ity == ElemKind::Int64ITy) {
4360       fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float, float,
4361                                                            int64_t>(I);
4362     } else {
4363       llvm_unreachable("Index type is not supported");
4364     }
4365     break;
4366   case ElemKind::Float16Ty:
4367     if (I->getUseFP16Accumulation()) {
4368       if (ity == ElemKind::Int32ITy) {
4369         fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<
4370             float16_t, float16_t, int32_t>(I);
4371       } else if (ity == ElemKind::Int64ITy) {
4372         fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<
4373             float16_t, float16_t, int64_t>(I);
4374       } else {
4375         llvm_unreachable("Index type is not supported");
4376       }
4377     } else {
4378       if (ity == ElemKind::Int32ITy) {
4379         fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float,
4380                                                              int32_t>(I);
4381       } else if (ity == ElemKind::Int64ITy) {
4382         fwdFusedRowwiseQuantizedSparseLengthsWeightedSumImpl<float16_t, float,
4383                                                              int64_t>(I);
4384       } else {
4385         llvm_unreachable("Index type is not supported");
4386       }
4387     }
4388     break;
4389   default:
4390     llvm_unreachable("Type is not supported");
4391   }
4392 }
4393 
4394 template <typename T, typename AccumT>
fwdEmbeddingBagByteRowwiseOffsetsImpl(const EmbeddingBagByteRowwiseOffsetsInst * I)4395 void BoundInterpreterFunction::fwdEmbeddingBagByteRowwiseOffsetsImpl(
4396     const EmbeddingBagByteRowwiseOffsetsInst *I) {
4397   auto *out = getTensor(I->getDest());
4398   auto *data = getTensor(I->getData());
4399   auto *weights = getTensor(I->getWeights());
4400   auto *indices = getTensor(I->getIndices());
4401   auto *offsets = getTensor(I->getOffsets());
4402   bool hasEndOffset = I->getHasEndOffset();
4403 
4404   out->zero();
4405 
4406   auto IH = indices->getHandle<int64_t>();
4407   auto OFFH = offsets->getHandle<int64_t>();
4408 
4409   // If an end offset is present to mark the end of the last segment then this
4410   // must be subtracted to get the correct number of segments
4411   size_t segments = hasEndOffset ? offsets->dims()[0] - 1 : offsets->dims()[0];
4412   dim_t numIndices = indices->dims()[0];
4413 
4414   const bool using4BitQuantization =
4415       data->getType().getElementType() == ElemKind::UInt4FusedFP16QTy;
4416 
4417   const size_t outLineSize = out->size() / out->dims()[0];
4418 
4419   auto DH = data->getHandle<uint8_t>();
4420   auto WH = weights->getHandle<T>();
4421   auto OH = out->getHandle<T>();
4422 
4423   for (dim_t i = 0; i < segments; i++) {
4424     std::vector<AccumT> accum(outLineSize, 0.0f);
4425     size_t start = OFFH.raw(i);
4426     dim_t end;
4427     if (!hasEndOffset) {
4428       // Note that in this case we have to use numIndices to find the end of
4429       // the last segment. This is an issue though because it relies on knowing
4430       // the total length of the indices tensor which may not be possible.
4431       // Future implementations of this operator should always give an end
4432       // offset so eventually this case should be removed.
4433       end = i == segments - 1 ? numIndices : OFFH.raw(i + 1);
4434     } else {
4435       end = OFFH.raw(i + 1);
4436     }
4437     if (start == end) {
4438       continue;
4439     } else if (start > end) {
4440       break;
4441     }
4442 
4443     for (dim_t j = start; j < end; j++) {
4444       const float weight = static_cast<float>(WH.raw(j));
4445       const dim_t rowIdx = IH.raw(j);
4446       T scale, offset;
4447       std::tie(scale, offset) = DH.getFusedScaleOffsetFromRow<T>(rowIdx);
4448       for (dim_t k = 0; k < outLineSize; k++) {
4449         float d = 0.0f;
4450         if (!using4BitQuantization) {
4451           d = quantization::dequantizeWithFloatOffset(
4452               DH.at({rowIdx, k}), static_cast<float>(scale),
4453               static_cast<float>(offset));
4454         } else {
4455           const bool isMSB = (k % 2 == 1);
4456           d = quantization::dequantize4BitWithFloatOffset(
4457               DH.at({rowIdx, k / 2}), static_cast<float>(scale),
4458               static_cast<float>(offset), isMSB);
4459         }
4460         accum[k] += d * weight;
4461       }
4462     }
4463     // Accumulation in FP32 complete, now copy back to output with cast to T.
4464     dim_t offsetOut = i * outLineSize;
4465     for (dim_t k = 0; k < outLineSize; k++) {
4466       OH.raw(offsetOut++) = static_cast<T>(accum[k]);
4467     }
4468   }
4469 }
4470 
fwdEmbeddingBagByteRowwiseOffsetsInst(const EmbeddingBagByteRowwiseOffsetsInst * I)4471 void BoundInterpreterFunction::fwdEmbeddingBagByteRowwiseOffsetsInst(
4472     const EmbeddingBagByteRowwiseOffsetsInst *I) {
4473   switch (I->getDest()->getElementType()) {
4474   case ElemKind::FloatTy:
4475     fwdEmbeddingBagByteRowwiseOffsetsImpl<float, float>(I);
4476     break;
4477   case ElemKind::Float16Ty:
4478     if (I->getUseFP16Accumulation()) {
4479       fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float16_t>(I);
4480     } else {
4481       fwdEmbeddingBagByteRowwiseOffsetsImpl<float16_t, float>(I);
4482     }
4483     break;
4484   default:
4485     llvm_unreachable("Type is not supported");
4486   }
4487 }
4488 
fwdLengthsToRangesInst(const LengthsToRangesInst * I)4489 void BoundInterpreterFunction::fwdLengthsToRangesInst(
4490     const LengthsToRangesInst *I) {
4491   auto ranges = getTensor(I->getDest())->getHandle<int32_t>();
4492   auto lengths = getTensor(I->getLengths())->getHandle<int32_t>();
4493   int32_t offset = 0;
4494   for (dim_t i = 0; i < lengths.dims()[0]; i++) {
4495     auto length = lengths.at({i});
4496     ranges.at({i, 0}) = offset;
4497     ranges.at({i, 1}) = length;
4498     offset += length;
4499   }
4500 }
4501 
fwdLengthsRangeFillInst(const LengthsRangeFillInst * I)4502 void BoundInterpreterFunction::fwdLengthsRangeFillInst(
4503     const LengthsRangeFillInst *I) {
4504   auto lengthsH = getTensor(I->getLengths())->getHandle<int32_t>();
4505   auto resultH = getTensor(I->getDest())->getHandle<int32_t>();
4506   dim_t curIdx = 0;
4507   for (dim_t i = 0, e = lengthsH.dims()[0]; i < e; i++) {
4508     for (int32_t j = 0, f = lengthsH.at({i}); j < f; j++) {
4509       resultH.at({curIdx++}) = j;
4510     }
4511   }
4512 }
4513 
4514 template <typename ElemTy>
fwdSparseToDenseInstImpl(const SparseToDenseInst * I)4515 void BoundInterpreterFunction::fwdSparseToDenseInstImpl(
4516     const SparseToDenseInst *I) {
4517 
4518   auto out = getTensor(I->getDest());
4519   auto indices = getTensor(I->getIndices());
4520   auto values = getTensor(I->getValues());
4521 
4522   out->zero();
4523 
4524   auto IH = indices->getHandle<int64_t>();
4525 
4526   size_t numIndices = indices->dims()[0];
4527   size_t numOutDims = out->dims().size();
4528 
4529   // Convert sparse representation to dense representation by taking
4530   // slices of output and values and accumulating the value slice into
4531   // the output slice.
4532 
4533   // Dimensions and offsets for the output and values slices. sliceDims
4534   // will always be {1, [rest of output dimensions]} since the first dimension
4535   // is the index in this operation. sliceOffsets will be {indices[j], 0, ...}
4536   // for the output slice and {j, 0, ...} for the values slice so that the
4537   // slice at index j gets mapped to index indices[j] in the dense
4538   // representation.
4539   ShapeVector sliceDims(out->dims().begin(), out->dims().end());
4540   ShapeVector sliceOffsets(numOutDims, 0);
4541   sliceDims[0] = 1;
4542 
4543   for (dim_t j = 0; j < numIndices; ++j) {
4544     // Create values slice with offsets {j, 0, ...}.
4545     sliceOffsets[0] = j;
4546     auto VS = values->getUnowned(sliceDims, sliceOffsets);
4547     auto VSH = VS.getHandle<ElemTy>();
4548 
4549     // Create output slice with offsets {indices[j], 0, ...}.
4550     sliceOffsets[0] = IH.at({j});
4551     auto OS = out->getUnowned(sliceDims, sliceOffsets);
4552     auto OSH = OS.getHandle<ElemTy>();
4553 
4554     // Accumulate values slice into output slice.
4555     size_t outputSliceSize = OS.size();
4556     for (size_t k = 0; k < outputSliceSize; ++k) {
4557       OSH.raw(k) += VSH.raw(k);
4558     }
4559   }
4560 }
4561 
fwdSparseToDenseInst(const SparseToDenseInst * I)4562 void BoundInterpreterFunction::fwdSparseToDenseInst(
4563     const SparseToDenseInst *I) {
4564   dispatchArithmeticImpl(fwdSparseToDenseInstImpl,
4565                          I->getDest()->getElementType(), I);
4566 }
4567 
fwdSparseToDenseMaskInst(const SparseToDenseMaskInst * I)4568 void BoundInterpreterFunction::fwdSparseToDenseMaskInst(
4569     const SparseToDenseMaskInst *I) {
4570   auto out = getTensor(I->getDest());
4571   auto values = getTensor(I->getValues());
4572   auto defaultValue = getTensor(I->getDefaultValue());
4573 
4574   auto indicesH = getTensor(I->getIndices())->getHandle<int64_t>();
4575   auto lengthsH = getTensor(I->getLengths())->getHandle<int32_t>();
4576 
4577   const std::vector<dim_t> &mask = I->getMask();
4578   size_t maskSize = mask.size();
4579   // Create a reverse map from ID to its position in the mask.
4580   std::unordered_map<int64_t, size_t> reverseMap;
4581   for (size_t i = 0; i < maskSize; i++) {
4582     assert(reverseMap.find(mask[i]) == reverseMap.end() &&
4583            "duplicate IDs in the mask");
4584     reverseMap[mask[i]] = i;
4585   }
4586 
4587   auto valueSize = defaultValue->getSizeInBytes();
4588 
4589   // First un-processed index-value pair.
4590   size_t posIn = 0;
4591   // Beginning of output block for first unprocessed batch.
4592   size_t byteOffsetOut = 0;
4593   // Lengths can be scalar, which means that all pairs belong to one batch.
4594   size_t numBatches = lengthsH.dims().empty() ? 1 : lengthsH.dims()[0];
4595   for (size_t batch = 0; batch < numBatches; batch++) {
4596     // Fill everything with maskSize copies of defaultValue.
4597     for (size_t i = 0; i < maskSize; i++) {
4598       std::copy(defaultValue->getUnsafePtr(),
4599                 &defaultValue->getUnsafePtr()[valueSize],
4600                 &out->getUnsafePtr()[byteOffsetOut + valueSize * i]);
4601     }
4602     // Go through input pairs and find matches.
4603     for (size_t i = 0, batchLen = lengthsH.raw(batch); i < batchLen;
4604          i++, posIn++) {
4605       int64_t idx = indicesH.raw(posIn);
4606       auto it = reverseMap.find(idx);
4607       // Skip if ID is not present in the mask.
4608       if (it == reverseMap.end())
4609         continue;
4610       size_t to = it->second;
4611 
4612       std::copy(&values->getUnsafePtr()[posIn * valueSize],
4613                 &values->getUnsafePtr()[(posIn + 1) * valueSize],
4614                 &out->getUnsafePtr()[byteOffsetOut + valueSize * to]);
4615     }
4616 
4617     byteOffsetOut += maskSize * valueSize;
4618   }
4619 
4620   assert(posIn == indicesH.dims()[0] &&
4621          "Sum of Lengths must be equal to size of indices.");
4622 }
4623 
4624 //===----------------------------------------------------------------------===//
4625 //                Instructions used by RNN
4626 //===----------------------------------------------------------------------===//
4627 template <typename T, typename TI>
fwdTopK(Tensor * outW,Tensor * indW,Tensor * inW,size_t k)4628 static void fwdTopK(Tensor *outW, Tensor *indW, Tensor *inW, size_t k) {
4629   auto values = outW->getHandle<T>();
4630   auto indices = indW->getHandle<TI>();
4631   auto in = inW->getHandle<T>();
4632   size_t n = in.dims().back();
4633 
4634   size_t in_p = 0, out_p = 0;
4635   size_t tensor_end = in.size();
4636   using pairType = std::pair<float, size_t>;
4637   std::vector<pairType> buf(n);
4638 
4639   while (in_p < tensor_end) {
4640     for (size_t i = 0; i < n; i++) {
4641       buf[i].first = in.raw(in_p++);
4642       buf[i].second = i;
4643     }
4644     // NOTE: it's possible to do N + KlogK, while this version is NlogN
4645     std::sort(buf.begin(), buf.end(), [](const pairType &a, const pairType &b) {
4646       if (a.first != b.first)
4647         return a.first > b.first;
4648       return a.second < b.second;
4649     });
4650     for (size_t i = 0; i < k; i++) {
4651       values.raw(out_p) = buf[i].first;
4652       indices.raw(out_p) = buf[i].second;
4653       out_p++;
4654     }
4655   }
4656 }
4657 
4658 template <typename inpType, typename outType>
fwdArgMax(Tensor * inpT,Tensor * outT,size_t axis)4659 static void fwdArgMax(Tensor *inpT, Tensor *outT, size_t axis) {
4660 
4661   // Get input/output handles with dimensions expanded to maximum.
4662   ShapeVector inpDims = expandDimsToMax(inpT->dims());
4663   ShapeVector outDims = inpDims;
4664   outDims[axis] = 1;
4665   auto eInpT = inpT->getUnowned(inpDims);
4666   auto eOutT = outT->getUnowned(outDims);
4667   auto inpH = eInpT.getHandle<inpType>();
4668   auto outH = eOutT.getHandle<outType>();
4669 
4670   static_assert(max_tensor_dimensions == 6,
4671                 "Loops below assume max_tensor_dimensions = 6.");
4672 
4673   for (dim_t idx0 = 0; idx0 < outDims[0]; idx0++) {
4674     for (dim_t idx1 = 0; idx1 < outDims[1]; idx1++) {
4675       for (dim_t idx2 = 0; idx2 < outDims[2]; idx2++) {
4676         for (dim_t idx3 = 0; idx3 < outDims[3]; idx3++) {
4677           for (dim_t idx4 = 0; idx4 < outDims[4]; idx4++) {
4678             for (dim_t idx5 = 0; idx5 < outDims[5]; idx5++) {
4679 
4680               // Initialize maximum value/index.
4681               inpType maxVal = std::numeric_limits<inpType>::lowest();
4682               outType maxIdx = 0;
4683 
4684               // Iterate input axis dimension.
4685               for (dim_t axisIdx = 0; axisIdx < inpDims[axis]; axisIdx++) {
4686                 std::vector<dim_t> inpIdx = {idx0, idx1, idx2,
4687                                              idx3, idx4, idx5};
4688                 inpIdx[axis] = axisIdx;
4689                 inpType inpVal = inpH.at(inpIdx);
4690                 if (inpVal > maxVal) {
4691                   maxVal = inpVal;
4692                   maxIdx = axisIdx;
4693                 }
4694               }
4695 
4696               // Store maximum index.
4697               outH.at({idx0, idx1, idx2, idx3, idx4, idx5}) = maxIdx;
4698             }
4699           }
4700         }
4701       }
4702     }
4703   }
4704 }
4705 
4706 template <typename inpType, typename outType>
fwdArgMin(Tensor * inpT,Tensor * outT,size_t axis)4707 static void fwdArgMin(Tensor *inpT, Tensor *outT, size_t axis) {
4708 
4709   // Get input/output handles with dimensions expanded to maximum.
4710   ShapeVector inpDims = expandDimsToMax(inpT->dims());
4711   ShapeVector outDims = inpDims;
4712   outDims[axis] = 1;
4713   auto eInpT = inpT->getUnowned(inpDims);
4714   auto eOutT = outT->getUnowned(outDims);
4715   auto inpH = eInpT.getHandle<inpType>();
4716   auto outH = eOutT.getHandle<outType>();
4717 
4718   static_assert(max_tensor_dimensions == 6,
4719                 "Loops below assume max_tensor_dimensions = 6.");
4720 
4721   for (dim_t idx0 = 0; idx0 < outDims[0]; idx0++) {
4722     for (dim_t idx1 = 0; idx1 < outDims[1]; idx1++) {
4723       for (dim_t idx2 = 0; idx2 < outDims[2]; idx2++) {
4724         for (dim_t idx3 = 0; idx3 < outDims[3]; idx3++) {
4725           for (dim_t idx4 = 0; idx4 < outDims[4]; idx4++) {
4726             for (dim_t idx5 = 0; idx5 < outDims[5]; idx5++) {
4727 
4728               // Initialize minimum value/index.
4729               inpType minVal = std::numeric_limits<inpType>::max();
4730               outType minIdx = 0;
4731 
4732               // Iterate input axis dimension.
4733               for (dim_t axisIdx = 0; axisIdx < inpDims[axis]; axisIdx++) {
4734                 std::vector<dim_t> inpIdx = {idx0, idx1, idx2,
4735                                              idx3, idx4, idx5};
4736                 inpIdx[axis] = axisIdx;
4737                 inpType inpVal = inpH.at(inpIdx);
4738                 if (inpVal < minVal) {
4739                   minVal = inpVal;
4740                   minIdx = axisIdx;
4741                 }
4742               }
4743 
4744               // Store minimum index.
4745               outH.at({idx0, idx1, idx2, idx3, idx4, idx5}) = minIdx;
4746             }
4747           }
4748         }
4749       }
4750     }
4751   }
4752 }
4753 
4754 //===----------------------------------------------------------------------===//
4755 //                       Sorting operators
4756 //===----------------------------------------------------------------------===//
4757 
fwdTopKInst(const TopKInst * I)4758 void BoundInterpreterFunction::fwdTopKInst(const TopKInst *I) {
4759   auto outW = getTensor(I->getValues());
4760   auto indW = getTensor(I->getIndices());
4761   auto inW = getTensor(I->getInput());
4762   size_t k = I->getK();
4763 
4764   if (inW->getType().isQuantizedType()) {
4765     if (indW->getElementType() == ElemKind::Int64ITy) {
4766 
4767       fwdTopK<int8_t, int64_t>(outW, indW, inW, k);
4768     } else if (indW->getElementType() == ElemKind::Int32ITy) {
4769       fwdTopK<int8_t, int32_t>(outW, indW, inW, k);
4770     }
4771     return;
4772   }
4773 
4774   dispatchFloatingPointAndIndexImpl(fwdTopK, inW->getElementType(),
4775                                     indW->getElementType(), outW, indW, inW, k);
4776 }
4777 
4778 #define DISPATCH_ARG_MIN_MAX(functionName, elemTy, elemTyIndex, ...)           \
4779   switch (elemTy) {                                                            \
4780   case ElemKind::FloatTy:                                                      \
4781     if (elemTyIndex == ElemKind::Int64ITy) {                                   \
4782       functionName<float, int64_t>(__VA_ARGS__);                               \
4783     } else if (elemTyIndex == ElemKind::Int32ITy) {                            \
4784       functionName<float, int32_t>(__VA_ARGS__);                               \
4785     }                                                                          \
4786     break;                                                                     \
4787   case ElemKind::Int8QTy:                                                      \
4788     if (elemTyIndex == ElemKind::Int64ITy) {                                   \
4789       functionName<int8_t, int64_t>(__VA_ARGS__);                              \
4790     } else if (elemTyIndex == ElemKind::Int32ITy) {                            \
4791       functionName<int8_t, int32_t>(__VA_ARGS__);                              \
4792     }                                                                          \
4793     break;                                                                     \
4794   default:                                                                     \
4795     llvm_unreachable("Type is not supported");                                 \
4796   }
4797 
fwdArgMaxInst(const ArgMaxInst * I)4798 void BoundInterpreterFunction::fwdArgMaxInst(const ArgMaxInst *I) {
4799   auto inpT = getTensor(I->getSrc());
4800   auto outT = getTensor(I->getDest());
4801   size_t axis = I->getAxis();
4802   auto inpElemType = inpT->getElementType();
4803   auto outElemType = outT->getElementType();
4804   DISPATCH_ARG_MIN_MAX(fwdArgMax, inpElemType, outElemType, inpT, outT, axis);
4805 }
4806 
fwdArgMinInst(const ArgMinInst * I)4807 void BoundInterpreterFunction::fwdArgMinInst(const ArgMinInst *I) {
4808   auto inpT = getTensor(I->getSrc());
4809   auto outT = getTensor(I->getDest());
4810   size_t axis = I->getAxis();
4811   auto inpElemType = inpT->getElementType();
4812   auto outElemType = outT->getElementType();
4813   DISPATCH_ARG_MIN_MAX(fwdArgMin, inpElemType, outElemType, inpT, outT, axis);
4814 }
4815 #undef DISPATCH_ARG_MIN_MAX
4816 
4817 //===----------------------------------------------------------------------===//
4818 //                  Tensor allocation operations
4819 //===----------------------------------------------------------------------===//
4820 
fwdAllocActivationInst(const AllocActivationInst * I)4821 void BoundInterpreterFunction::fwdAllocActivationInst(
4822     const AllocActivationInst *I) {
4823   getOrCreateTensor(I);
4824 }
4825 
fwdDeallocActivationInst(const DeallocActivationInst * I)4826 void BoundInterpreterFunction::fwdDeallocActivationInst(
4827     const DeallocActivationInst *I) {
4828   deleteTensor(I->getSrc());
4829 }
4830 
4831 //===----------------------------------------------------------------------===//
4832 //                       Debug instructions
4833 //===----------------------------------------------------------------------===//
4834 /// Prints a value of the instruction's operand.
4835 /// In most cases it will be the name of the variable and the value of the
4836 /// tensor.
fwdDebugPrintInst(const DebugPrintInst * I)4837 void BoundInterpreterFunction::fwdDebugPrintInst(const DebugPrintInst *I) {
4838   auto *V = I->getSrc();
4839   auto *T = getTensor(V);
4840   std::string format = I->getFormat();
4841   std::string filename = I->getFileName();
4842 
4843   if (format == "console") {
4844     // Dump tensor in console.
4845     llvm::outs() << I->getName() << ": ";
4846     V->dump();
4847     llvm::outs() << "\n";
4848     dumpImpl(T);
4849     llvm::outs() << "\n";
4850   } else if (format == "bin") {
4851     TensorSerializationOptions opts;
4852     opts.withType = true;
4853     glow::dumpTensorToBinaryFile(*T, filename, opts);
4854   } else if (format == "txt") {
4855     TensorSerializationOptions opts;
4856     opts.withType = true;
4857     glow::dumpTensorToTextFile(*T, filename, opts);
4858   } else if (format == "rawbin") {
4859     TensorSerializationOptions opts;
4860     opts.withType = false;
4861     glow::dumpTensorToBinaryFile(*T, filename, opts);
4862   } else if (format == "rawtxt") {
4863     TensorSerializationOptions opts;
4864     opts.withType = false;
4865     glow::dumpTensorToTextFile(*T, filename, opts);
4866   } else {
4867     llvm_unreachable("DebugPrint format not supported!");
4868   }
4869 }
4870 
fwdTraceEventInst(const TraceEventInst * I)4871 void BoundInterpreterFunction::fwdTraceEventInst(const TraceEventInst *I) {
4872   auto T = getTensor(I->getData());
4873   auto IH = T->getHandle<int64_t>();
4874   size_t index = I->getIndex();
4875   IH.raw(index) = std::chrono::duration_cast<std::chrono::microseconds>(
4876                       std::chrono::steady_clock::now().time_since_epoch())
4877                       .count();
4878 }
4879 
4880 //===----------------------------------------------------------------------===//
4881 //                Instructions used by Quantization
4882 //===----------------------------------------------------------------------===//
fwdQuantizationProfileInst(const glow::QuantizationProfileInst * I)4883 void BoundInterpreterFunction::fwdQuantizationProfileInst(
4884     const glow::QuantizationProfileInst *I) {
4885   auto inputTensor = getWeightHandle(I->getInputTensor());
4886   auto currentHistogram = getWeightHandle(I->getHistogram());
4887   auto computationInfo = getWeightHandle(I->getComputationInfo());
4888 
4889   float &min = computationInfo.raw(0);
4890   float &max = computationInfo.raw(1);
4891 
4892   // Update current histogram, min and max based on the inputTensor data.
4893   quantization::generateTensorHistogram(inputTensor, currentHistogram, min,
4894                                         max);
4895 }
4896 
4897 /// Quantize floating point tensor. Scale and Offset are based on return type
4898 /// of the instruction \p I.
fwdQuantizeInst(const glow::QuantizeInst * I)4899 void BoundInterpreterFunction::fwdQuantizeInst(const glow::QuantizeInst *I) {
4900   auto *srcTensor = getTensor(I->getSrc());
4901   auto *destTensor = getTensor(I->getDest());
4902   auto destTy = destTensor->getType();
4903   Tensor qTensor = quantization::quantizeTensor(
4904       *srcTensor, {destTy.getScale(), destTy.getOffset()},
4905       destTy.getElementType());
4906   destTensor->assign(&qTensor);
4907 }
4908 
4909 /// Dequantize integer tensor. Scale and Offset are based
4910 /// on the source tensor type.
fwdDequantizeInst(const glow::DequantizeInst * I)4911 void BoundInterpreterFunction::fwdDequantizeInst(
4912     const glow::DequantizeInst *I) {
4913   auto *srcTensor = getTensor(I->getSrc());
4914   auto *destTensor = getTensor(I->getDest());
4915   auto destTy = destTensor->getType();
4916   Tensor fTensor =
4917       quantization::dequantizeTensor(*srcTensor, destTy.getElementType());
4918   destTensor->assign(&fTensor);
4919 }
4920 
4921 template <class eTy>
fwdRescaleQuantizedInstImpl(Value * src,Value * dest,TensorQuantizationParams & srcQ,TensorQuantizationParams & destQ)4922 void BoundInterpreterFunction::fwdRescaleQuantizedInstImpl(
4923     Value *src, Value *dest, TensorQuantizationParams &srcQ,
4924     TensorQuantizationParams &destQ) {
4925 
4926   auto srcH = getWeightHandle<eTy>(src);
4927   auto destH = getWeightHandle<eTy>(dest);
4928 
4929   for (size_t i = 0, e = destH.size(); i < e; ++i) {
4930     float val = quantization::dequantize(srcH.raw(i), srcQ);
4931     destH.raw(i) = quantization::quantize(val, destQ);
4932   }
4933 }
4934 
fwdRescaleQuantizedInst(const glow::RescaleQuantizedInst * I)4935 void BoundInterpreterFunction::fwdRescaleQuantizedInst(
4936     const glow::RescaleQuantizedInst *I) {
4937   auto src = I->getSrc();
4938   auto dest = I->getDest();
4939   auto srcTy = src->getType();
4940   auto destTy = dest->getType();
4941 
4942   TensorQuantizationParams srcQ{srcTy->getScale(), srcTy->getOffset()};
4943   TensorQuantizationParams destQ{destTy->getScale(), destTy->getOffset()};
4944 
4945   dispatchQuantizedImpl(fwdRescaleQuantizedInstImpl, destTy->getElementType(),
4946                         src, dest, srcQ, destQ);
4947 }
4948 
fwdIntLookupTableInst(const IntLookupTableInst * I)4949 void BoundInterpreterFunction::fwdIntLookupTableInst(
4950     const IntLookupTableInst *I) {
4951   auto srcH = getWeightHandle<int8_t>(I->getSrc());
4952   auto destH = getWeightHandle<int8_t>(I->getDest());
4953   auto mappingH = getWeightHandle<int8_t>(I->getMapping());
4954 
4955   for (size_t i = 0, e = destH.size(); i < e; i++) {
4956     destH.raw(i) = mappingH.raw((int)srcH.raw(i) + 128);
4957   }
4958 }
4959 
fwdConvertToInst(const glow::ConvertToInst * I)4960 void BoundInterpreterFunction::fwdConvertToInst(const glow::ConvertToInst *I) {
4961   Tensor *source = getTensor(I->getInput());
4962   Tensor *dest = getTensor(I->getResult());
4963   auto srcElType = source->getType().getElementType();
4964   auto destElType = dest->getType().getElementType();
4965   if (srcElType == destElType) {
4966     // This is a noop conversion.
4967     dest->copyRawFrom(source);
4968     return;
4969   }
4970 
4971 #define CONVERT(T_FROM, T_TO, DTY_FROM, DTY_TO)                                \
4972   if (srcElType == DTY_FROM && destElType == DTY_TO) {                         \
4973     dest->copyWithCast<T_TO, T_FROM>(source);                                  \
4974     return;                                                                    \
4975   }
4976   CONVERT(float, float16_t, ElemKind::FloatTy, ElemKind::Float16Ty)
4977   CONVERT(float, bfloat16_t, ElemKind::FloatTy, ElemKind::BFloat16Ty)
4978   CONVERT(float, bool, ElemKind::FloatTy, ElemKind::BoolTy)
4979   CONVERT(float, int32_t, ElemKind::FloatTy, ElemKind::Int32ITy)
4980   CONVERT(float, int64_t, ElemKind::FloatTy, ElemKind::Int64ITy)
4981   CONVERT(float16_t, float, ElemKind::Float16Ty, ElemKind::FloatTy)
4982   CONVERT(float16_t, bfloat16_t, ElemKind::Float16Ty, ElemKind::BFloat16Ty)
4983   CONVERT(float16_t, int32_t, ElemKind::Float16Ty, ElemKind::Int32ITy)
4984   CONVERT(float16_t, int64_t, ElemKind::Float16Ty, ElemKind::Int64ITy)
4985   CONVERT(bfloat16_t, float, ElemKind::BFloat16Ty, ElemKind::FloatTy)
4986   CONVERT(bfloat16_t, float16_t, ElemKind::BFloat16Ty, ElemKind::Float16Ty)
4987   CONVERT(bfloat16_t, int32_t, ElemKind::BFloat16Ty, ElemKind::Int32ITy)
4988   CONVERT(bfloat16_t, int64_t, ElemKind::BFloat16Ty, ElemKind::Int64ITy)
4989   CONVERT(bool, float, ElemKind::BoolTy, ElemKind::FloatTy)
4990   CONVERT(bool, bfloat16_t, ElemKind::BoolTy, ElemKind::BFloat16Ty)
4991   CONVERT(int32_t, float, ElemKind::Int32ITy, ElemKind::FloatTy)
4992   CONVERT(int32_t, float16_t, ElemKind::Int32ITy, ElemKind::Float16Ty)
4993   CONVERT(int32_t, bfloat16_t, ElemKind::Int32ITy, ElemKind::BFloat16Ty)
4994   CONVERT(int32_t, int64_t, ElemKind::Int32ITy, ElemKind::Int64ITy)
4995   CONVERT(int64_t, float, ElemKind::Int64ITy, ElemKind::FloatTy)
4996   CONVERT(int64_t, float16_t, ElemKind::Int64ITy, ElemKind::Float16Ty)
4997   CONVERT(int64_t, bfloat16_t, ElemKind::Int64ITy, ElemKind::BFloat16Ty)
4998   CONVERT(int64_t, int32_t, ElemKind::Int64ITy, ElemKind::Int32ITy)
4999 #undef CONVERT
5000 
5001   if (srcElType == ElemKind::UInt8FusedQTy &&
5002       destElType == ElemKind::UInt8FusedFP16QTy) {
5003     dest->convertToType(ElemKind::UInt8FusedFP16QTy);
5004     return;
5005   }
5006   llvm_unreachable("Type not supported");
5007 }
5008 
5009 template <typename ElemTy>
fwdBatchedPairwiseDotProductInstImpl(const BatchedPairwiseDotProductInst * I)5010 void BoundInterpreterFunction::fwdBatchedPairwiseDotProductInstImpl(
5011     const BatchedPairwiseDotProductInst *I) {
5012   auto destT = getTensor(I->getDest());
5013   auto destH = destT->getHandle<ElemTy>();
5014 
5015   dim_t batchCount = destT->getType().dims()[0];
5016 
5017   // Gather all batched vector operands into an array so that they can be
5018   // indexed easily.
5019   std::vector<Value *> srcs;
5020   for (unsigned i = 1, e = I->getNumOperands(); i < e; ++i) {
5021     auto op = I->getOperand(i);
5022     srcs.emplace_back(op.first);
5023   }
5024 
5025   // pairIdx is the total number of pairs (i, j) that have been processed.
5026   unsigned pairIdx = 0;
5027 
5028   // For each src operand:
5029   for (unsigned i = 1, e = I->getNumInputs(); i < e; ++i) {
5030     auto vAH = getTensor(srcs[i])->getHandle<ElemTy>();
5031     dim_t vectorSize = getTensor(srcs[i])->getType().dims()[1];
5032 
5033     // Compute the dot product of src[i] with every other vector with a smaller
5034     // index.
5035     for (unsigned j = 0; j < i; ++j) {
5036       auto vBH = getTensor(srcs[j])->getHandle<ElemTy>();
5037 
5038       // Process all batches for a given pair (i, j).
5039       for (dim_t b = 0; b < batchCount; ++b) {
5040         ElemTy accum = 0;
5041 
5042         for (dim_t k = 0; k < vectorSize; ++k) {
5043           accum += vAH.at({b, k}) * vBH.at({b, k});
5044         }
5045 
5046         destH.at({b, pairIdx}) = accum;
5047       }
5048 
5049       ++pairIdx;
5050     }
5051   }
5052 }
5053 
fwdBatchedPairwiseDotProductInst(const BatchedPairwiseDotProductInst * I)5054 void BoundInterpreterFunction::fwdBatchedPairwiseDotProductInst(
5055     const BatchedPairwiseDotProductInst *I) {
5056   dispatchImpl(fwdBatchedPairwiseDotProductInstImpl,
5057                I->getDest()->getElementType(), I);
5058 }
5059 
5060 template <typename ElemTy>
fwdBatchedPairwiseDotProductGradInstImpl(const BatchedPairwiseDotProductGradInst * I)5061 void BoundInterpreterFunction::fwdBatchedPairwiseDotProductGradInstImpl(
5062     const BatchedPairwiseDotProductGradInst *I) {
5063   auto destGradT = getTensor(I->getDestGrad());
5064   auto destGradH = destGradT->getHandle<ElemTy>();
5065 
5066   dim_t batchCount = destGradT->getType().dims()[0];
5067 
5068   // Gather all batched vector operands into arrays so that they can be
5069   // indexed easily. Operands 1 -> numInputs are gradients of inputs, and
5070   // operands numInputs + 1 -> numOperands - 1 are the corresponding original
5071   // inputs.
5072   std::vector<Value *> srcs, srcGrads;
5073   for (unsigned i = 0, e = I->getNumInputs(); i < e; ++i) {
5074     auto gradOp = I->getOperand(i + 1);
5075     auto inputOp = I->getOperand(i + 1 + e);
5076 
5077     srcGrads.emplace_back(gradOp.first);
5078     srcs.emplace_back(inputOp.first);
5079   }
5080 
5081   // Zero initialize all srcGrad tensors.
5082   for (auto &s : srcGrads) {
5083     getTensor(s)->zero();
5084   }
5085 
5086   // pairIdx is the total number of pairs (i, j) that have been processed.
5087   unsigned pairIdx = 0;
5088 
5089   // For each srcGrad operand:
5090   for (unsigned i = 0, e = I->getNumInputs(); i < e; ++i) {
5091     auto dvAH = getTensor(srcGrads[i])->getHandle<ElemTy>();
5092     dim_t vectorSize = getTensor(srcs[i])->getType().dims()[1];
5093 
5094     // Accmulate into it the product of the gradient of all dot products that
5095     // src[i] contributed to and the corresponding vectors that src[i] was
5096     // dotted with.
5097     for (unsigned j = i + 1; j < e; ++j) {
5098       auto vBH = getTensor(srcs[j])->getHandle<ElemTy>();
5099 
5100       // Process all batches for a given pair (i, j).
5101       for (dim_t b = 0; b < batchCount; ++b) {
5102         ElemTy grad = destGradH.at({b, pairIdx});
5103 
5104         for (dim_t k = 0; k < vectorSize; ++k) {
5105           dvAH.at({b, k}) += grad * vBH.at({b, k});
5106         }
5107       }
5108 
5109       ++pairIdx;
5110     }
5111   }
5112 }
5113 
fwdBatchedPairwiseDotProductGradInst(const BatchedPairwiseDotProductGradInst * I)5114 void BoundInterpreterFunction::fwdBatchedPairwiseDotProductGradInst(
5115     const BatchedPairwiseDotProductGradInst *I) {
5116   dispatchImpl(fwdBatchedPairwiseDotProductGradInstImpl,
5117                I->getDestGrad()->getElementType(), I);
5118 }
5119 
5120 template <typename ElemTy>
fwdFlipInstImpl(const FlipInst * I)5121 void BoundInterpreterFunction::fwdFlipInstImpl(const FlipInst *I) {
5122 
5123   static_assert(max_tensor_dimensions == 6,
5124                 "Loops below assume max_tensor_dimensions = 6.");
5125 
5126   auto *src = I->getSrc();
5127   auto *dest = I->getDest();
5128 
5129   // Get unowned handles of src and dest with dims expanded to maximum.
5130   ShapeVector eDims = expandDimsToMax(src->dims());
5131   auto eSrc = getTensor(src)->getUnowned(eDims);
5132   auto eDest = getTensor(dest)->getUnowned(eDims);
5133   auto srcH = eSrc.getHandle<ElemTy>();
5134   auto destH = eDest.getHandle<ElemTy>();
5135 
5136 #define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5)                           \
5137   for (dim_t idx0 = 0; idx0 < eDims[0]; idx0++)                                \
5138     for (dim_t idx1 = 0; idx1 < eDims[1]; idx1++)                              \
5139       for (dim_t idx2 = 0; idx2 < eDims[2]; idx2++)                            \
5140         for (dim_t idx3 = 0; idx3 < eDims[3]; idx3++)                          \
5141           for (dim_t idx4 = 0; idx4 < eDims[4]; idx4++)                        \
5142             for (dim_t idx5 = 0; idx5 < eDims[5]; idx5++) {                    \
5143               destH.at({_D0, _D1, _D2, _D3, _D4, _D5}) =                       \
5144                   srcH.at({idx0, idx1, idx2, idx3, idx4, idx5});               \
5145             }                                                                  \
5146   return;
5147 
5148   switch (I->getAxis()) {
5149   case 0:
5150     LOOP_AXIS_CASE(eDims[0] - 1 - idx0, idx1, idx2, idx3, idx4, idx5);
5151   case 1:
5152     LOOP_AXIS_CASE(idx0, eDims[1] - 1 - idx1, idx2, idx3, idx4, idx5);
5153   case 2:
5154     LOOP_AXIS_CASE(idx0, idx1, eDims[2] - 1 - idx2, idx3, idx4, idx5);
5155   case 3:
5156     LOOP_AXIS_CASE(idx0, idx1, idx2, eDims[3] - 1 - idx3, idx4, idx5);
5157   case 4:
5158     LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, eDims[4] - 1 - idx4, idx5);
5159   case 5:
5160     LOOP_AXIS_CASE(idx0, idx1, idx2, idx3, idx4, eDims[5] - 1 - idx5);
5161   default:
5162     llvm_unreachable("Axis should be less than max_tensor_dimensions.");
5163   }
5164 }
5165 
fwdFlipInst(const FlipInst * I)5166 void BoundInterpreterFunction::fwdFlipInst(const FlipInst *I) {
5167   dispatchImpl(fwdFlipInstImpl, I->getSrc()->getElementType(), I);
5168 }
5169 
5170 //===----------------------------------------------------------------------===//
5171 //                Instructions used by ObjectDetection
5172 //===----------------------------------------------------------------------===//
maxMin(float lhs,float rhs,float & min,float & max)5173 static void maxMin(float lhs, float rhs, float &min, float &max) {
5174   if (lhs >= rhs) {
5175     min = rhs;
5176     max = lhs;
5177   } else {
5178     min = lhs;
5179     max = rhs;
5180   }
5181 }
5182 
5183 using ClassBox = std::pair<float, dim_t>;
5184 
5185 struct Box {
5186   float classValue{0.0f};
5187   dim_t batchIndex{0};
5188   dim_t classIndex{0};
5189   dim_t boxIndex{0};
5190 };
5191 
5192 template <typename ElemTy>
doIOU(Handle<ElemTy> & boxes,dim_t batchIndex,dim_t selectedBoxIndex,dim_t candidateBoxIndex,int centerPointBox,float iouThreshold,bool isV4)5193 static bool doIOU(Handle<ElemTy> &boxes, dim_t batchIndex,
5194                   dim_t selectedBoxIndex, dim_t candidateBoxIndex,
5195                   int centerPointBox, float iouThreshold, bool isV4) {
5196   float sx[] = {0.0f, 0.0f, 0.0f, 0.0f};
5197   float cx[] = {0.0f, 0.0f, 0.0f, 0.0f};
5198 
5199   if (isV4) {
5200     for (dim_t i = 0; i < 4; ++i) {
5201       sx[i] = boxes.at({selectedBoxIndex, i});
5202       cx[i] = boxes.at({candidateBoxIndex, i});
5203     }
5204   } else {
5205     for (dim_t i = 0; i < 4; ++i) {
5206       sx[i] = boxes.at({batchIndex, selectedBoxIndex, i});
5207       cx[i] = boxes.at({batchIndex, candidateBoxIndex, i});
5208     }
5209   }
5210 
5211   float xSMin = 0.0f;
5212   float ySMin = 0.0f;
5213   float xSMax = 0.0f;
5214   float ySMax = 0.0f;
5215 
5216   float xCMin = 0.0f;
5217   float yCMin = 0.0f;
5218   float xCMax = 0.0f;
5219   float yCMax = 0.0f;
5220 
5221   // Standardizing coordinates so that (xmin, ymin) is upper left corner of a
5222   // box and (xmax, ymax) is lower right corner of the box.
5223   if (!centerPointBox) {
5224     // 0 means coordinates for diagonal ends of a box.
5225     // Coordinates can either be absolute or normalized.
5226     maxMin(sx[0], sx[2], xSMin, xSMax);
5227     maxMin(sx[1], sx[3], ySMin, ySMax);
5228 
5229     maxMin(cx[0], cx[2], xCMin, xCMax);
5230     maxMin(cx[1], cx[3], yCMin, yCMax);
5231   } else {
5232     float halfWidthS = sx[2] / 2.0f;
5233     float halfHeightS = sx[3] / 2.0f;
5234     float halfWidthC = cx[2] / 2.0f;
5235     float halfHeightC = cx[3] / 2.0f;
5236 
5237     xSMin = sx[0] - halfWidthS;
5238     ySMin = sx[1] - halfHeightS;
5239     xSMax = sx[0] + halfWidthS;
5240     ySMax = sx[1] + halfHeightS;
5241 
5242     xCMin = cx[0] - halfWidthC;
5243     yCMin = cx[1] - halfHeightC;
5244     xCMax = cx[0] + halfWidthC;
5245     yCMax = cx[1] + halfHeightC;
5246   }
5247 
5248   // finding upper left and lower right corner of a box formed by intersection.
5249   float xMin = std::max(xSMin, xCMin);
5250   float yMin = std::max(ySMin, yCMin);
5251   float xMax = std::min(xSMax, xCMax);
5252   float yMax = std::min(ySMax, yCMax);
5253 
5254   float intersectionArea =
5255       std::max(0.0f, xMax - xMin) * std::max(0.0f, yMax - yMin);
5256 
5257   if (intersectionArea == 0.0f) {
5258     return false;
5259   }
5260 
5261   float sArea = (xSMax - xSMin) * (ySMax - ySMin);
5262   float cArea = (xCMax - xCMin) * (yCMax - yCMin);
5263   float unionArea = sArea + cArea - intersectionArea;
5264 
5265   return intersectionArea > iouThreshold * unionArea;
5266 }
5267 
5268 template <typename T>
fwdNonMaxSuppressionInstImpl(glow::NonMaxSuppressionInst const * I)5269 void BoundInterpreterFunction::fwdNonMaxSuppressionInstImpl(
5270     glow::NonMaxSuppressionInst const *I) {
5271 
5272   auto boxes = I->getBoxes();
5273   auto scores = I->getScores();
5274   auto indices = I->getIndices();
5275   auto numDetected = I->getNumberOfSelectedIndices();
5276   float iouThreshold = I->getIouThreshold();
5277   dim_t maxBoxesPerClass = I->getMaxOutputBoxesPerClass();
5278   float scoreThreshold = I->getScoreThreshold();
5279   unsigned centerPointBox = I->getCenterPointBox();
5280   bool isV4 = I->getIsTFVersion4();
5281 
5282   auto boxesH = getTensor(boxes)->getHandle<float>();
5283   auto scoresH = getTensor(scores)->getHandle<float>();
5284   auto indicesH = getTensor(indices)->getHandle<T>();
5285   auto numDetectedH = getTensor(numDetected)->getHandle<T>();
5286 
5287   int boxesBoxDim = boxes->dims().size() - 2;
5288 
5289   dim_t numBatches = 1;
5290   dim_t numClasses = 1;
5291   dim_t numBoxes = boxes->dims()[boxesBoxDim];
5292 
5293   size_t maxOutputPerBatch = 0;
5294 
5295   if (!isV4) {
5296     int boxesBatchDim = boxes->dims().size() - 3;
5297 
5298     int scoresBatchDim = scores->dims().size() - 3;
5299     int scoresBoxDim = scores->dims().size() - 1;
5300     int scoresClassDim = scores->dims().size() - 2;
5301     assert(scores->dims()[scoresBoxDim] == boxes->dims()[boxesBoxDim] &&
5302            "Mismatch between number of scores and number of boxes.");
5303     assert(scores->dims()[scoresBatchDim] == boxes->dims()[boxesBatchDim] &&
5304            "Mismatch in batch dimension.");
5305     (void)boxesBatchDim;
5306     (void)scoresBoxDim;
5307     numBatches = scores->dims()[scoresBatchDim];
5308     numClasses = scores->dims()[scoresClassDim];
5309     numBoxes = boxes->dims()[boxesBoxDim];
5310     maxOutputPerBatch =
5311         indices->dims()[indices->dims().size() - 2] / numBatches;
5312   } else {
5313     maxOutputPerBatch =
5314         indices->dims()[indices->dims().size() - 1] / numBatches;
5315   }
5316 
5317   auto cmpFunc = [](const ClassBox &a, const ClassBox &b) {
5318     return a.first < b.first;
5319   };
5320 
5321   std::vector<ClassBox> selectedIndices(numBoxes);
5322   dim_t outPutBoxIndex = 0;
5323 
5324   for (dim_t batchIndex = 0; batchIndex < numBatches; ++batchIndex) {
5325     Box minBox{scoresH.raw(batchIndex * numClasses * numBoxes), batchIndex, 0,
5326                0};
5327     int32_t detectedPerBatch = 0;
5328     for (dim_t classIndex = 0; classIndex < numClasses; ++classIndex) {
5329       selectedIndices.clear();
5330       size_t detectedPerClass = 0;
5331       std::priority_queue<ClassBox, std::vector<ClassBox>, decltype(cmpFunc)>
5332           queue(cmpFunc);
5333 
5334       for (size_t boxIndex = 0; boxIndex < numBoxes; ++boxIndex) {
5335         float classValue = scoresH.raw(
5336             (batchIndex * numClasses + classIndex) * numBoxes + boxIndex);
5337         if (classValue > scoreThreshold) {
5338           queue.emplace(classValue, boxIndex);
5339         }
5340       }
5341 
5342       float tScore = minBox.classValue;
5343       while (!queue.empty()) {
5344         auto priorBox = queue.top();
5345         queue.pop();
5346 
5347         bool selected = true;
5348         for (auto &sBox : selectedIndices) {
5349           if (doIOU(boxesH, batchIndex, sBox.second, priorBox.second,
5350                     centerPointBox, iouThreshold, isV4)) {
5351             selected = false;
5352             break;
5353           }
5354         }
5355 
5356         if (selected) {
5357           selectedIndices.emplace_back(priorBox);
5358           if (isV4) {
5359             indicesH.at({outPutBoxIndex}) = priorBox.second;
5360             tScore = scoresH.at({priorBox.second});
5361           } else {
5362             indicesH.at({outPutBoxIndex, 0}) = batchIndex;
5363             indicesH.at({outPutBoxIndex, 1}) = classIndex;
5364             indicesH.at({outPutBoxIndex, 2}) = priorBox.second;
5365             tScore = scoresH.at({batchIndex, classIndex, priorBox.second});
5366           }
5367 
5368           ++outPutBoxIndex;
5369           ++detectedPerClass;
5370           ++detectedPerBatch;
5371         }
5372         if (maxBoxesPerClass == detectedPerClass) {
5373           break;
5374         }
5375       }
5376 
5377       if (tScore < minBox.classValue) {
5378         minBox.classValue = tScore;
5379         minBox.classIndex = classIndex;
5380         if (isV4) {
5381           minBox.boxIndex = indicesH.at({outPutBoxIndex - 1});
5382         } else {
5383           minBox.boxIndex = indicesH.at({outPutBoxIndex - 1, 2});
5384         }
5385       }
5386     }
5387 
5388     for (size_t i = detectedPerBatch; i < maxOutputPerBatch; ++i) {
5389       if (isV4) {
5390         indicesH.at({outPutBoxIndex}) = minBox.boxIndex;
5391       } else {
5392         indicesH.at({outPutBoxIndex, 0}) = minBox.batchIndex;
5393         indicesH.at({outPutBoxIndex, 1}) = minBox.classIndex;
5394         indicesH.at({outPutBoxIndex, 2}) = minBox.boxIndex;
5395       }
5396 
5397       ++outPutBoxIndex;
5398     }
5399     // For ONNX NMS it's not used, for TF Batch Dimension is 1.
5400     for (dim_t i = 0; i < maxBoxesPerClass; ++i) {
5401       numDetectedH.at({batchIndex * maxBoxesPerClass + i}) = detectedPerBatch;
5402     }
5403   }
5404 }
5405 
fwdNonMaxSuppressionInst(glow::NonMaxSuppressionInst const * I)5406 void BoundInterpreterFunction::fwdNonMaxSuppressionInst(
5407     glow::NonMaxSuppressionInst const *I) {
5408   switch (I->getBoxes()->getElementType()) {
5409   case ElemKind::FloatTy:
5410     if (I->getIndices()->getElementType() == ElemKind::Int32ITy) {
5411       fwdNonMaxSuppressionInstImpl<int32_t>(I);
5412     } else if (I->getIndices()->getElementType() == ElemKind::Int64ITy) {
5413       fwdNonMaxSuppressionInstImpl<int64_t>(I);
5414     } else {
5415       llvm_unreachable("Output type is not supported.");
5416     }
5417     break;
5418   default:
5419     llvm_unreachable("Type is not supported.");
5420     break;
5421   }
5422 }
5423 
fwdAudioSpectrogramInstFloatImpl(glow::AudioSpectrogramInst const * I)5424 void BoundInterpreterFunction::fwdAudioSpectrogramInstFloatImpl(
5425     glow::AudioSpectrogramInst const *I) {
5426 
5427   auto spectrogram = I->getSpectrogram();
5428   auto input = I->getInput();
5429   auto window = I->getWindow();
5430   int64_t windowSize = I->getWindowSize();
5431   int64_t windowStride = I->getWindowStride();
5432 
5433   auto spectrogramH = getTensor(spectrogram)->getHandle<float>();
5434   auto inputH = getTensor(input)->getHandle<float>();
5435   auto windowH = getTensor(window)->getHandle<float>();
5436 
5437   // Compute window count.
5438   int64_t inputLength = input->size();
5439   int64_t windowCount =
5440       std::floor((inputLength - windowSize) / windowStride) + 1;
5441 
5442   // Compute FFT length (next power of 2) and spectrogram length.
5443   dim_t fftLen = 1 << (dim_t)std::ceil(std::log2((double)windowSize));
5444   dim_t specLen = fftLen / 2 + 1;
5445 
5446   // Allocate temporary buffers.
5447   auto winOut = std::make_unique<float[]>(windowSize);
5448   auto fftRealOut = std::make_unique<float[]>(specLen);
5449   auto fftImagOut = std::make_unique<float[]>(specLen);
5450 
5451   // Compute the spectrogram.
5452   for (dim_t winIdx = 0; winIdx < windowCount; winIdx++) {
5453 
5454     // Windowing.
5455     for (dim_t n = 0; n < windowSize; n++) {
5456       winOut[n] = inputH.raw(winIdx * windowStride + n) * windowH.raw(n);
5457     }
5458 
5459     // Compute spectrum (perform FFT).
5460     for (int k = 0; k < specLen; k++) {
5461       fftRealOut[k] = 0;
5462       fftImagOut[k] = 0;
5463       for (int n = 0; n < windowSize; n++) {
5464         fftRealOut[k] +=
5465             winOut[n] * cos(2.0 * M_PI * (double)(n * k) / (double)(fftLen));
5466         fftImagOut[k] -=
5467             winOut[n] * sin(2.0 * M_PI * (double)(n * k) / (double)(fftLen));
5468       }
5469     }
5470 
5471     // Compute spectrum magnitude/power.
5472     if (I->getMagnitudeSquared()) {
5473       for (dim_t k = 0; k < specLen; k++) {
5474         spectrogramH.at({winIdx, k}) =
5475             fftRealOut[k] * fftRealOut[k] + fftImagOut[k] * fftImagOut[k];
5476       }
5477     } else {
5478       for (dim_t k = 0; k < specLen; k++) {
5479         spectrogramH.at({winIdx, k}) =
5480             sqrt(fftRealOut[k] * fftRealOut[k] + fftImagOut[k] * fftImagOut[k]);
5481       }
5482     }
5483   }
5484 }
5485 
fwdAudioSpectrogramInst(glow::AudioSpectrogramInst const * I)5486 void BoundInterpreterFunction::fwdAudioSpectrogramInst(
5487     glow::AudioSpectrogramInst const *I) {
5488   auto inputTy = I->getInput()->getElementType();
5489   auto spectrogramTy = I->getSpectrogram()->getElementType();
5490   if ((inputTy == ElemKind::FloatTy) && (spectrogramTy == ElemKind::FloatTy)) {
5491     fwdAudioSpectrogramInstFloatImpl(I);
5492   } else {
5493     llvm_unreachable("Type is not supported.");
5494   }
5495 }
5496 
fwdMFCCInstFloatImpl(glow::MFCCInst const * I)5497 void BoundInterpreterFunction::fwdMFCCInstFloatImpl(glow::MFCCInst const *I) {
5498 
5499   auto coefficients = I->getCoefficients();
5500   auto spectrogram = I->getSpectrogram();
5501   auto melWeights = I->getMelWeights();
5502   auto melRanges = I->getMelRanges();
5503   auto dctMat = I->getDctMat();
5504   int64_t filterBankCount = I->getFilterBankCount();
5505   int64_t numCoefficients = I->getNumCoefficients();
5506 
5507   auto coefficientsH = getTensor(coefficients)->getHandle<float>();
5508   auto spectrogramH = getTensor(spectrogram)->getHandle<float>();
5509   auto melWeightsH = getTensor(melWeights)->getHandle<float>();
5510   auto melRangesH = getTensor(melRanges)->getHandle<int32_t>();
5511   auto dctMatH = getTensor(dctMat)->getHandle<float>();
5512 
5513   // Perform MFCC for all the windows.
5514   auto winNum = spectrogram->dims()[0];
5515   auto melBuff = std::make_unique<float[]>(filterBankCount);
5516   for (dim_t winIdx = 0; winIdx < winNum; winIdx++) {
5517 
5518     // Apply Mel filter bank mapping. We use sqrt for the spectrogram since we
5519     // assume the spectrogram is a power value and not a magnitude.
5520     dim_t melBinCoeffIdx = 0;
5521     for (dim_t melIdx = 0; melIdx < filterBankCount; melIdx++) {
5522       int32_t freqIdxStart = melRangesH.raw(2 * melIdx + 0);
5523       int32_t freqIdxStop = melRangesH.raw(2 * melIdx + 1);
5524       float melPwr = 0.0f;
5525       for (dim_t freqIdx = freqIdxStart; freqIdx <= freqIdxStop; freqIdx++) {
5526         melPwr += std::sqrt(spectrogramH.at({winIdx, freqIdx})) *
5527                   melWeightsH.raw(melBinCoeffIdx++);
5528       }
5529       melBuff[melIdx] = melPwr;
5530     }
5531 
5532     // Take logarithm in-place (avoid log(0)).
5533     for (dim_t melIdx = 0; melIdx < filterBankCount; melIdx++) {
5534       float melPwr = melBuff[melIdx];
5535       melBuff[melIdx] = (melPwr == 0.0)
5536                             ? logf(std::numeric_limits<float>::min())
5537                             : logf(melPwr);
5538     }
5539 
5540     // Compute DCT transform.
5541     for (dim_t k = 0; k < numCoefficients; k++) {
5542       float dctOut = 0.0f;
5543       for (dim_t n = 0; n < filterBankCount; n++) {
5544         dctOut += dctMatH.at({k, n}) * melBuff[n];
5545       }
5546       coefficientsH.at({winIdx, k}) = dctOut;
5547     }
5548   }
5549 }
5550 
fwdMFCCInst(glow::MFCCInst const * I)5551 void BoundInterpreterFunction::fwdMFCCInst(glow::MFCCInst const *I) {
5552   auto spectrogramTy = I->getSpectrogram()->getElementType();
5553   auto coefficientsTy = I->getCoefficients()->getElementType();
5554   if ((spectrogramTy == ElemKind::FloatTy) &&
5555       (coefficientsTy == ElemKind::FloatTy)) {
5556     fwdMFCCInstFloatImpl(I);
5557   } else {
5558     llvm_unreachable("Type is not supported.");
5559   }
5560 }
5561