1 /**
2  * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <algorithm>
17 #include <assert.h>
18 #include <chrono>
19 #include <cmath>
20 #include <math.h>
21 #include <numeric>
22 #include <stddef.h>
23 #include <stdint.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include <string.h>
27 #include <sys/types.h>
28 
29 #include "libjit_defs.h"
30 
31 namespace {
32 
/// Pretty-prints \p tensor to stdout: its shape, its min/max, and up to the
/// first 100 elements with nested brackets/newlines marking row/slice
/// boundaries. Values are printed as floats with 3 decimals.
template <class ElemTy>
static void libjit_dump_tensor_console_impl(ElemTy *tensor, dim_t *dims,
                                            dim_t numDims) {
  // Check for 0-dimensional tensor.
  if (!numDims) {
    printf("[ Scalar containing: %.3f ]\n", (float)tensor[0]);
    return;
  }

  // Output shape.
  printf("shape: ( ");
  for (size_t i = 0; i < numDims; ++i) {
    printf("%zu ", (size_t)dims[i]);
  }
  printf(")\n");

  // Running min/max over all elements; used for the zero-tensor check and
  // the summary line below.
  ElemTy mx = tensor[0];
  ElemTy mn = tensor[0];

  // 'size' is the total element count; sliceSize[i] is the linear stride of
  // one step along dimension i (sliceSize[numDims - 1] == 1).
  // NOTE(review): sliceSize is a variable-length array — a GNU extension,
  // not standard C++.
  size_t size = 1;
  size_t sliceSize[numDims];
  for (size_t i = 0; i < numDims; ++i) {
    size *= dims[i];
  }

  for (ssize_t i = numDims - 1, curSliceSize = 1; i >= 0; --i) {
    sliceSize[i] = curSliceSize;
    curSliceSize *= dims[i];
  }

  for (size_t i = 0, e = size; i < e; i++) {
    mx = MAX(mx, tensor[i]);
    mn = MIN(mn, tensor[i]);
  }

  // Check for zero tensor.
  if (mn == .0 && mx == .0) {
    printf("[ Zero tensor ]\n");
    return;
  }

  // Output max and min.
  printf("max: %.3f  min: %.3f\n", (float)mx, (float)mn);

  // Cap on how many leading elements get printed.
  const unsigned maxNumElem = 100;

  printf("[");

  for (size_t i = 0, e = MIN(maxNumElem, size); i < e; i++) {

    // Print one open brace at the beginning of every row, slice, and tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      if (i % sliceSize[j] == 0) {
        // This iteration of outer loop is a new row, slice or tensor.
        printf("[");
      }
    }

    // Print the value at the current index.
    printf("%.3f", (float)tensor[i]);

    // Print one closed brace at the end of every row, slice, or tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      size_t next_index = i + 1;
      if (next_index % sliceSize[j] == 0u) {
        printf("]");
      }
    }

    printf(", ");

    // Print one newline at the end of every row, slice, or tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      size_t next_index = i + 1;
      if (next_index % sliceSize[j] == 0u) {
        // Next iteration of outer loop will be a new row, slice or tensor.
        printf("\n");
      }
    }
  }

  // Make truncation visible when the tensor is larger than the cap.
  if (size > maxNumElem) {
    printf("...");
  }

  printf("]\n");
}
120 
/// Writes all \p tensorElemSize elements of \p tensor to the text file
/// \p filename as comma-separated "%f" values, preceded by \p header (plus a
/// newline) when the header string is non-empty. Prints an error to stdout
/// and returns if the file cannot be opened.
template <class ElemTy>
static void libjit_dump_tensor_txt_impl(ElemTy *tensor, size_t tensorElemSize,
                                        const char *filename,
                                        const char *header) {
  FILE *fh = fopen(filename, "w");
  if (!fh) {
    printf("ERROR opening file: '%s'!\n"
           "File name might be too long!\n",
           filename);
    return;
  }
  // An optional header line precedes the data.
  if (header[0] != '\0') {
    fprintf(fh, "%s\n", header);
  }
  // Emit each element followed by a comma separator.
  for (size_t idx = 0; idx < tensorElemSize; ++idx) {
    fprintf(fh, "%f, ", (double)tensor[idx]);
  }
  fclose(fh);
}
140 
141 template <typename ElemTy>
get_element_ptr(const ElemTy * tensor,const dim_t * dims,dim_t numDims,const dim_t * indices,dim_t numIndices)142 static dim_t get_element_ptr(const ElemTy *tensor, const dim_t *dims,
143                              dim_t numDims, const dim_t *indices,
144                              dim_t numIndices) {
145   dim_t index = 0;
146   dim_t subdimensionSize = 1;
147   for (dim_t i = numDims; i > 0; i--) {
148     dim_t curIndicesValue = (i <= numIndices) ? indices[i - 1] : 0;
149     index += subdimensionSize * curIndicesValue;
150     subdimensionSize *= dims[i - 1];
151   }
152   return index;
153 }
154 
/// Inserts \p slice into \p tensor at coordinates \p offset, repeating the
/// insertion \p count times; each repetition advances along dimension
/// \p axis by the slice's extent in that dimension. Handles slice ranks 1-5
/// with a specialized loop nest per rank; any other rank is a no-op.
template <typename ElemTy>
static void libjit_insert_tensor(ElemTy *tensor, ElemTy *slice, dim_t *offset,
                                 dim_t *tensorDim, dim_t *sliceDim,
                                 dim_t numDimsTensor, dim_t numDimsSlice,
                                 dim_t offsetDim, dim_t count, dim_t axis) {
  // Destination coordinates.
  dim_t C[5];

  // A local copy of the offsets buffer. We copy the buffer to make it clear
  // to the optimizer that the inputs don't alias. This loop is optimized away.
  dim_t offsets_cpy[5];
  for (dim_t i = 0; i < numDimsSlice; i++) {
    offsets_cpy[i] = offset[i];
  }

  // Rank-5 slice: destination coordinate = slice coordinate + offset, plus
  // the repetition displacement (c * sliceDim[axis]) on the chosen axis.
  if (numDimsSlice == 5) {
    for (dim_t c = 0; c < count; c++)
      for (dim_t x = 0; x < sliceDim[0]; x++)
        for (dim_t y = 0; y < sliceDim[1]; y++)
          for (dim_t z = 0; z < sliceDim[2]; z++)
            for (dim_t w = 0; w < sliceDim[3]; w++)
              for (dim_t q = 0; q < sliceDim[4]; q++) {
                const dim_t countAxisOffset = c * sliceDim[axis];
                C[0] = x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0);
                C[1] = y + offsets_cpy[1] + ((axis == 1) ? countAxisOffset : 0);
                C[2] = z + offsets_cpy[2] + ((axis == 2) ? countAxisOffset : 0);
                C[3] = w + offsets_cpy[3] + ((axis == 3) ? countAxisOffset : 0);
                C[4] = q + offsets_cpy[4] + ((axis == 4) ? countAxisOffset : 0);
                tensor[libjit_getXYZWQ(tensorDim, C[0], C[1], C[2], C[3],
                                       C[4])] =
                    slice[libjit_getXYZWQ(sliceDim, x, y, z, w, q)];
              }
    return;
  }

  // Rank-4 slice.
  if (numDimsSlice == 4) {
    for (dim_t c = 0; c < count; c++)
      for (dim_t x = 0; x < sliceDim[0]; x++)
        for (dim_t y = 0; y < sliceDim[1]; y++)
          for (dim_t z = 0; z < sliceDim[2]; z++)
            for (dim_t w = 0; w < sliceDim[3]; w++) {
              const dim_t countAxisOffset = c * sliceDim[axis];
              C[0] = x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0);
              C[1] = y + offsets_cpy[1] + ((axis == 1) ? countAxisOffset : 0);
              C[2] = z + offsets_cpy[2] + ((axis == 2) ? countAxisOffset : 0);
              C[3] = w + offsets_cpy[3] + ((axis == 3) ? countAxisOffset : 0);
              tensor[libjit_getXYZW(tensorDim, C[0], C[1], C[2], C[3])] =
                  slice[libjit_getXYZW(sliceDim, x, y, z, w)];
            }
    return;
  }

  // Rank-3 slice.
  if (numDimsSlice == 3) {
    for (dim_t c = 0; c < count; c++)
      for (dim_t x = 0; x < sliceDim[0]; x++)
        for (dim_t y = 0; y < sliceDim[1]; y++)
          for (dim_t z = 0; z < sliceDim[2]; z++) {
            const dim_t countAxisOffset = c * sliceDim[axis];
            C[0] = x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0);
            C[1] = y + offsets_cpy[1] + ((axis == 1) ? countAxisOffset : 0);
            C[2] = z + offsets_cpy[2] + ((axis == 2) ? countAxisOffset : 0);
            tensor[libjit_getXYZ(tensorDim, C[0], C[1], C[2])] =
                slice[libjit_getXYZ(sliceDim, x, y, z)];
          }
    return;
  }

  // Rank-2 slice.
  if (numDimsSlice == 2) {
    for (dim_t c = 0; c < count; c++)
      for (dim_t x = 0; x < sliceDim[0]; x++)
        for (dim_t y = 0; y < sliceDim[1]; y++) {
          const dim_t countAxisOffset = c * sliceDim[axis];
          C[0] = x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0);
          C[1] = y + offsets_cpy[1] + ((axis == 1) ? countAxisOffset : 0);
          tensor[libjit_getXY(tensorDim, C[0], C[1])] =
              slice[libjit_getXY(sliceDim, x, y)];
        }
    return;
  }

  // Rank-1 slice: the tensor is indexed linearly.
  if (numDimsSlice == 1) {
    for (dim_t c = 0; c < count; c++)
      for (dim_t x = 0; x < sliceDim[0]; x++) {
        const dim_t countAxisOffset = c * sliceDim[axis];
        tensor[x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0)] =
            slice[x];
      }
    return;
  }
}
245 
/// Extracts a sub-tensor of shape \p sliceDim from \p tensor, starting at
/// coordinates \p offset, into \p slice. Handles slice ranks 1-5 with a
/// specialized loop nest per rank; any other rank is a no-op.
template <typename ElemTy>
static void libjit_extract_tensor(ElemTy *tensor, ElemTy *slice, dim_t *offset,
                                  dim_t *tensorDim, dim_t *sliceDim,
                                  dim_t numDimsTensor, dim_t numDimsSlice,
                                  dim_t offsetDim) {
  // Source coordinates.
  dim_t C[5];

  // A local copy of the offsets buffer. We copy the buffer to make it clear
  // to the optimizer that the inputs don't alias. This loop is optimized away.
  dim_t offsets_cpy[5];
  for (dim_t i = 0; i < numDimsSlice; i++) {
    offsets_cpy[i] = offset[i];
  }

  // Rank-5 slice: source coordinate = slice coordinate + offset.
  if (numDimsSlice == 5) {
    for (dim_t x = 0; x < sliceDim[0]; x++)
      for (dim_t y = 0; y < sliceDim[1]; y++)
        for (dim_t z = 0; z < sliceDim[2]; z++)
          for (dim_t w = 0; w < sliceDim[3]; w++)
            for (dim_t q = 0; q < sliceDim[4]; q++) {
              C[0] = x + offsets_cpy[0];
              C[1] = y + offsets_cpy[1];
              C[2] = z + offsets_cpy[2];
              C[3] = w + offsets_cpy[3];
              C[4] = q + offsets_cpy[4];
              slice[libjit_getXYZWQ(sliceDim, x, y, z, w, q)] =
                  tensor[libjit_getXYZWQ(tensorDim, C[0], C[1], C[2], C[3],
                                         C[4])];
            }
    return;
  }

  // Rank-4 slice.
  if (numDimsSlice == 4) {
    for (dim_t x = 0; x < sliceDim[0]; x++)
      for (dim_t y = 0; y < sliceDim[1]; y++)
        for (dim_t z = 0; z < sliceDim[2]; z++)
          for (dim_t w = 0; w < sliceDim[3]; w++) {
            C[0] = x + offsets_cpy[0];
            C[1] = y + offsets_cpy[1];
            C[2] = z + offsets_cpy[2];
            C[3] = w + offsets_cpy[3];
            slice[libjit_getXYZW(sliceDim, x, y, z, w)] =
                tensor[libjit_getXYZW(tensorDim, C[0], C[1], C[2], C[3])];
          }
    return;
  }

  // Rank-3 slice.
  if (numDimsSlice == 3) {
    for (dim_t x = 0; x < sliceDim[0]; x++)
      for (dim_t y = 0; y < sliceDim[1]; y++)
        for (dim_t z = 0; z < sliceDim[2]; z++) {
          C[0] = x + offsets_cpy[0];
          C[1] = y + offsets_cpy[1];
          C[2] = z + offsets_cpy[2];
          slice[libjit_getXYZ(sliceDim, x, y, z)] =
              tensor[libjit_getXYZ(tensorDim, C[0], C[1], C[2])];
        }
    return;
  }

  // Rank-2 slice.
  if (numDimsSlice == 2) {
    for (dim_t x = 0; x < sliceDim[0]; x++)
      for (dim_t y = 0; y < sliceDim[1]; y++) {
        C[0] = x + offsets_cpy[0];
        C[1] = y + offsets_cpy[1];
        slice[libjit_getXY(sliceDim, x, y)] =
            tensor[libjit_getXY(tensorDim, C[0], C[1])];
      }
    return;
  }

  // Rank-1 slice: the tensor is indexed linearly.
  if (numDimsSlice == 1) {
    for (dim_t x = 0; x < sliceDim[0]; x++) {
      slice[x] = tensor[x + offsets_cpy[0]];
    }
    return;
  }
}
325 
/// Helper struct for TopK: pairs an element's value with its original index.
/// NOTE: member order matters — libjit_topk aggregate-initializes this as
/// {index, value}; do not reorder the fields.
template <typename T, typename TI> struct value_index {
  TI index;
  T value;
};
331 
332 /// Helper function for TopK
333 template <typename T, typename TI>
value_index_sort(const void * va,const void * vb)334 static int value_index_sort(const void *va, const void *vb) {
335   value_index<T, TI> *a = (value_index<T, TI> *)va;
336   value_index<T, TI> *b = (value_index<T, TI> *)vb;
337   if (a->value != b->value)
338     return a->value > b->value ? -1 : 1;
339   return a->index < b->index ? -1 : 1;
340 }
341 
342 /// Generic Top-K function. Here, \p scratch is some allocated buffer space, \p
343 /// size is the size of the input, and \p n is the size of the last dimension of
344 /// the input.
template <typename T, typename TI>
static void libjit_topk(T *values, TI *indices, const T *input, void *scratch,
                        dim_t k, dim_t n, dim_t size) {
  // Read cursor into \p input and write cursor into \p values / \p indices.
  dim_t in = 0;
  dim_t out = 0;

  // Initialize scratch with 0.
  // NOTE(review): this assumes the caller provided at least 2 * n *
  // sizeof(TI) bytes of scratch and that a row of n value_index<T, TI>
  // records fits in it — confirm for type pairs where sizeof(T) exceeds
  // sizeof(TI).
  memset(scratch, 0, 2 * n * sizeof(TI));

  value_index<T, TI> *buffer = (value_index<T, TI> *)scratch;

  // Specialize TopK for the case where K is 1.
  if (k == 1) {
    while (in < size) {
      // Find the largest value by iterating over the array instead of calling
      // 'sort'. Strict '>' keeps the first occurrence among equal maxima.
      value_index<T, TI> mx = {0, input[in]};
      for (TI i = 1; i < n; i++) {
        if (input[i + in] > mx.value) {
          mx = {i, input[i + in]};
        }
      }
      indices[out] = mx.index;
      values[out] = mx.value;
      out++;
      in += n;
    }
    return;
  }

  // General case: materialize each row of n (index, value) pairs, sort them
  // with value_index_sort (descending by value, ties by ascending index),
  // and emit the top k entries.
  while (in < size) {
    for (dim_t i = 0; i < n; i++) {
      buffer[i].index = i;
      buffer[i].value = input[in++];
    }
    qsort(buffer, n, sizeof(value_index<T, TI>), value_index_sort<T, TI>);
    for (dim_t i = 0; i < k; i++) {
      indices[out] = buffer[i].index;
      values[out] = buffer[i].value;
      out++;
    }
  }
}
388 
389 template <typename T, typename IDX>
libjit_gather(T * dest,const T * data,const IDX * indices,dim_t numIndices,dim_t sliceSize,dim_t numSamples,dim_t sampleSize)390 static void libjit_gather(T *dest, const T *data, const IDX *indices,
391                           dim_t numIndices, dim_t sliceSize, dim_t numSamples,
392                           dim_t sampleSize) {
393   // The index of the slice that is being written.
394   dim_t outIdx = 0;
395 
396   // For each sample in our batch:
397   for (dim_t sample = 0; sample < numSamples; sample++) {
398     dim_t sampleStart = sample * sampleSize;
399 
400     // For each slice that we fetch:
401     for (dim_t i = 0; i < numIndices; i++) {
402       dim_t slice = indices[i];
403 
404       // Copy the slice.
405       memcpy(dest + outIdx * sliceSize, data + sampleStart + slice * sliceSize,
406              sliceSize * sizeof(T));
407 
408       // Point to the next location in the destination tensor.
409       outIdx++;
410     }
411   }
412 }
413 
414 template <typename T, typename U>
libjit_gatherranges(T * output,U * lengths,const T * data,const U * ranges,dim_t numExamples,dim_t exampleSize)415 static void libjit_gatherranges(T *output, U *lengths, const T *data,
416                                 const U *ranges, dim_t numExamples,
417                                 dim_t exampleSize) {
418   // Indices into the output and range buffers.
419   dim_t outputIdx = 0;
420   dim_t rangesIdx = 0;
421 
422   // For each example:
423   for (dim_t example = 0; example < numExamples; ++example) {
424     // Keep track of the total length of the gathered ranges for the example.
425     U totalLen = 0;
426 
427     // For each range:
428     for (dim_t range = 0; range < exampleSize; ++range) {
429       // Get the start and length of the range.
430       const U start = ranges[rangesIdx];
431       const U len = ranges[rangesIdx + 1];
432 
433       // Copy the specified elements.
434       memcpy(output + outputIdx, data + start, len * sizeof(T));
435 
436       // len elements were copied, so increment the output index by len.
437       outputIdx += len;
438 
439       // Each range is of the form (start, len), so increment the ranges
440       // index by 2 to get to the next range.
441       rangesIdx += 2;
442 
443       // Increment the total length for the example by len.
444       totalLen += len;
445     }
446 
447     // Record the total length of gathered ranges for the current example in
448     // the lengths buffer.
449     lengths[example] = totalLen;
450   }
451 }
452 
453 template <typename T, typename T2>
libjit_scatterdatacopy(T * data,const dim_t * dataDims,const T2 * indices,const T * slices,dim_t numIndices,dim_t indexSize,dim_t sliceSize)454 static void libjit_scatterdatacopy(T *data, const dim_t *dataDims,
455                                    const T2 *indices, const T *slices,
456                                    dim_t numIndices, dim_t indexSize,
457                                    dim_t sliceSize) {
458   for (dim_t i = 0; i < numIndices; i++) {
459     dim_t destDataIdx = indices[i * indexSize];
460     for (dim_t j = 1; j < indexSize; j++) {
461       destDataIdx *= dataDims[j];
462       destDataIdx += indices[i * indexSize + j];
463     }
464     memcpy(data + destDataIdx * sliceSize, slices + i * sliceSize,
465            sliceSize * sizeof(T));
466   }
467 }
468 
469 template <typename T, typename T2>
libjit_scatterdataaddfloat(T * data,const dim_t * dataDims,const T2 * indices,const T * slices,dim_t numIndices,dim_t indexSize,dim_t sliceSize)470 static void libjit_scatterdataaddfloat(T *data, const dim_t *dataDims,
471                                        const T2 *indices, const T *slices,
472                                        dim_t numIndices, dim_t indexSize,
473                                        dim_t sliceSize) {
474   for (dim_t i = 0; i < numIndices; i++) {
475     dim_t destDataIdx = indices[i * indexSize];
476     for (dim_t j = 1; j < indexSize; j++) {
477       destDataIdx *= dataDims[j];
478       destDataIdx += indices[i * indexSize + j];
479     }
480     for (dim_t j = 0; j < sliceSize; j++) {
481       data[destDataIdx * sliceSize + j] += slices[i * sliceSize + j];
482     }
483   }
484 }
485 
486 template <typename T, typename T2>
libjit_scatterdataaddquantized(T * data,const dim_t * dataDims,const T2 * indices,const T * slices,dim_t numIndices,dim_t indexSize,dim_t sliceSize,float dataScale,int32_t dataOffset,float sliceScale,int32_t sliceOffset)487 static void libjit_scatterdataaddquantized(T *data, const dim_t *dataDims,
488                                            const T2 *indices, const T *slices,
489                                            dim_t numIndices, dim_t indexSize,
490                                            dim_t sliceSize, float dataScale,
491                                            int32_t dataOffset, float sliceScale,
492                                            int32_t sliceOffset) {
493 
494   for (size_t i = 0; i < numIndices; i++) {
495     size_t destDataIdx = indices[i * indexSize];
496     for (size_t j = 1; j < indexSize; j++) {
497       destDataIdx *= dataDims[j];
498       destDataIdx += indices[i * indexSize + j];
499     }
500     for (size_t j = 0; j < sliceSize; j++) {
501       float lhs = (data[destDataIdx * sliceSize + j] - dataOffset) * dataScale;
502       float rhs = (slices[i * sliceSize + j] - sliceOffset) * sliceScale;
503       T result = libjit_clip((lhs + rhs) / dataScale + dataOffset);
504       data[destDataIdx * sliceSize + j] = result;
505     }
506   }
507 }
508 
/// Transposes \p inW (shape \p idim) into \p outW (shape \p odim) according
/// to the permutation \p shuffle: output coordinate i comes from input
/// dimension shuffle[i]. Ranks 2-5 are handled; ranks 2 and 3 use a tiled
/// traversal of the two innermost dimensions for cache locality.
template <typename T>
static void libjit_transpose_generic(const T *inW, T *outW, const dim_t *idim,
                                     const dim_t *odim, const dim_t *shuffle,
                                     dim_t numDims) {
  // Transpose 2d matrices one tile at a time. This access pattern ensures
  // that the whole tile is kept in L1 cache. When scanning the whole row at
  // once we invalidate many cache lines when we touch a single column.
  const unsigned tileSize = 64;

  // Source coordinate.
  dim_t SC[5];

  if (numDims == 5) {
    for (dim_t x = 0; x < odim[0]; x++)
      for (dim_t y = 0; y < odim[1]; y++)
        for (dim_t z = 0; z < odim[2]; z++)
          for (dim_t w = 0; w < odim[3]; w++)
            for (dim_t q = 0; q < odim[4]; q++) {
              // Scatter the output coordinate into the source coordinate
              // according to the shuffle permutation.
              SC[shuffle[0]] = x;
              SC[shuffle[1]] = y;
              SC[shuffle[2]] = z;
              SC[shuffle[3]] = w;
              SC[shuffle[4]] = q;
              outW[libjit_getXYZWQ(odim, x, y, z, w, q)] =
                  inW[libjit_getXYZWQ(idim, SC[0], SC[1], SC[2], SC[3], SC[4])];
            }
    return;
  }
  if (numDims == 4) {
    for (dim_t x = 0; x < odim[0]; x++)
      for (dim_t y = 0; y < odim[1]; y++)
        for (dim_t z = 0; z < odim[2]; z++)
          for (dim_t w = 0; w < odim[3]; w++) {
            SC[shuffle[0]] = x;
            SC[shuffle[1]] = y;
            SC[shuffle[2]] = z;
            SC[shuffle[3]] = w;
            outW[libjit_getXYZW(odim, x, y, z, w)] =
                inW[libjit_getXYZW(idim, SC[0], SC[1], SC[2], SC[3])];
          }
    return;
  }
  if (numDims == 3) {
    for (dim_t x = 0; x < odim[0]; x++) {
      // Process the tiles in the innermost two dimensions:
      for (dim_t sy = 0; sy < odim[1]; sy += tileSize) {
        for (dim_t sz = 0; sz < odim[2]; sz += tileSize) {
          // Process the inner tile:
          for (dim_t y = sy; y < MIN(sy + tileSize, odim[1]); y++) {
            for (dim_t z = sz; z < MIN(sz + tileSize, odim[2]); z++) {
              SC[shuffle[0]] = x;
              SC[shuffle[1]] = y;
              SC[shuffle[2]] = z;
              outW[libjit_getXYZ(odim, x, y, z)] =
                  inW[libjit_getXYZ(idim, SC[0], SC[1], SC[2])];
            }
          }
        }
      }
    }
    return;
  }

  if (numDims == 2) {
    // Process the tiles in the matrix:
    for (dim_t sx = 0; sx < odim[0]; sx += tileSize) {
      for (dim_t sy = 0; sy < odim[1]; sy += tileSize) {
        // Process the inner tile:
        for (dim_t x = sx; x < MIN(sx + tileSize, odim[0]); x++) {
          for (dim_t y = sy; y < MIN(sy + tileSize, odim[1]); y++) {
            SC[shuffle[0]] = x;
            SC[shuffle[1]] = y;
            outW[libjit_getXY(odim, x, y)] =
                inW[libjit_getXY(idim, SC[0], SC[1])];
          }
        }
      }
    }
    return;
  }
}
590 
/// Reverses \p inW along dimension \p axis into \p outW. The traversal is
/// arranged so the input is read strictly linearly; only the output pointer
/// jumps around.
template <typename T>
static void libjit_flip_generic(const T *inW, T *outW, const dim_t *dims,
                                dim_t axis, dim_t numDims) {

  // Product of outer dimensions excluding the flip dimension.
  dim_t outerLen = 1;
  for (dim_t idx = 0; idx < axis; idx++) {
    outerLen *= dims[idx];
  }

  // Flip dimension.
  dim_t len = dims[axis];

  // Product of inner dimensions excluding the flip dimension.
  dim_t innerLen = 1;
  for (dim_t idx = axis + 1; idx < numDims; idx++) {
    innerLen *= dims[idx];
  }

  // Flip axis such that input data is read linearly.
  const T *inpPtr = inW;
  // Start writing at the LAST flip position of the first outer block.
  T *outPtr = outW + (len - 1) * innerLen;
  for (dim_t outerIdx = 0; outerIdx < outerLen; outerIdx++) {
    for (dim_t idx = 0; idx < len; idx++) {
      for (dim_t innerIdx = 0; innerIdx < innerLen; innerIdx++) {
        *outPtr++ = *inpPtr++;
      }
      // The inner copy advanced outPtr by innerLen; step back two positions
      // so each iteration nets one flip position backwards.
      outPtr -= 2 * innerLen;
    }
    // After len backward steps outPtr sits innerLen before the block start;
    // jump forward to the last flip position of the next outer block.
    outPtr += 2 * len * innerLen;
  }
}
623 
624 template <typename inpT, typename outT>
libjit_arg_max_generic(const inpT * inpW,outT * outW,const dim_t * dims,size_t numDims,size_t axis)625 static void libjit_arg_max_generic(const inpT *inpW, outT *outW,
626                                    const dim_t *dims, size_t numDims,
627                                    size_t axis) {
628 
629   // Product of outer dimensions excluding the axis dimension.
630   dim_t outerLen = 1;
631   for (dim_t idx = 0; idx < axis; ++idx) {
632     outerLen *= dims[idx];
633   }
634 
635   // Axis dimension length.
636   dim_t axisLen = dims[axis];
637 
638   // Product of inner dimensions excluding the axis dimension.
639   dim_t innerLen = 1;
640   for (dim_t idx = axis + 1; idx < numDims; ++idx) {
641     innerLen *= dims[idx];
642   }
643 
644   // Traverse data such that output is written linearly.
645   const inpT *inpPtr = inpW;
646   outT *outPtr = outW;
647   for (dim_t outerIdx = 0; outerIdx < outerLen; ++outerIdx) {
648     for (dim_t innerIdx = 0; innerIdx < innerLen; ++innerIdx) {
649       inpT maxVal = std::numeric_limits<inpT>::lowest();
650       outT maxIdx = 0;
651       for (dim_t axisIdx = 0; axisIdx < axisLen; ++axisIdx) {
652         inpT inpVal = *inpPtr;
653         if (inpVal > maxVal) {
654           maxVal = inpVal;
655           maxIdx = axisIdx;
656         }
657         inpPtr += innerLen;
658       }
659       inpPtr = inpPtr - axisLen * innerLen + 1;
660       *outPtr++ = maxIdx;
661     }
662     inpPtr = inpPtr - innerLen + axisLen * innerLen;
663   }
664 }
665 
666 template <typename inpT, typename outT>
libjit_arg_min_generic(const inpT * inpW,outT * outW,const dim_t * dims,size_t numDims,size_t axis)667 static void libjit_arg_min_generic(const inpT *inpW, outT *outW,
668                                    const dim_t *dims, size_t numDims,
669                                    size_t axis) {
670 
671   // Product of outer dimensions excluding the axis dimension.
672   dim_t outerLen = 1;
673   for (dim_t idx = 0; idx < axis; ++idx) {
674     outerLen *= dims[idx];
675   }
676 
677   // Axis dimension length.
678   dim_t axisLen = dims[axis];
679 
680   // Product of inner dimensions excluding the axis dimension.
681   dim_t innerLen = 1;
682   for (dim_t idx = axis + 1; idx < numDims; ++idx) {
683     innerLen *= dims[idx];
684   }
685 
686   // Traverse data such that output is written linearly.
687   const inpT *inpPtr = inpW;
688   outT *outPtr = outW;
689   for (dim_t outerIdx = 0; outerIdx < outerLen; ++outerIdx) {
690     for (dim_t innerIdx = 0; innerIdx < innerLen; ++innerIdx) {
691       inpT minVal = std::numeric_limits<inpT>::max();
692       outT minIdx = 0;
693       for (dim_t axisIdx = 0; axisIdx < axisLen; ++axisIdx) {
694         inpT inpVal = *inpPtr;
695         if (inpVal < minVal) {
696           minVal = inpVal;
697           minIdx = axisIdx;
698         }
699         inpPtr += innerLen;
700       }
701       inpPtr = inpPtr - axisLen * innerLen + 1;
702       *outPtr++ = minIdx;
703     }
704     inpPtr = inpPtr - innerLen + axisLen * innerLen;
705   }
706 }
707 
708 template <typename T>
libjit_max_pool_generic(const T * inW,T * outW,const dim_t * inWdims,const dim_t * outWdims,dim_t * kernelSizes,dim_t * strides,dim_t * pads)709 static void libjit_max_pool_generic(const T *inW, T *outW, const dim_t *inWdims,
710                                     const dim_t *outWdims, dim_t *kernelSizes,
711                                     dim_t *strides, dim_t *pads) {
712   dim_t pad_t = pads[0];
713   dim_t pad_l = pads[1];
714   dim_t stride_h = strides[0];
715   dim_t stride_w = strides[1];
716   dim_t kernel_h = kernelSizes[0];
717   dim_t kernel_w = kernelSizes[1];
718   // For each sample in the batch:
719   for (dim_t n = 0; n < outWdims[0]; n++) {
720     // For each (x,y) step in the input/output tensor:
721     sdim_t x = -(sdim_t)pad_t;
722     for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
723       sdim_t y = -(sdim_t)pad_l;
724       for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {
725 
726         // For each layer in the output tensor:
727         for (dim_t z = 0; z < inWdims[3]; z++) {
728           int first = 1;
729           T max = 0;
730 
731           // For each element in the pool filter:
732           for (dim_t fx = 0; fx < kernel_h; fx++) {
733             for (dim_t fy = 0; fy < kernel_w; fy++) {
734               sdim_t ox = x + fx;
735               sdim_t oy = y + fy;
736 
737               // Ignore index access below zero (this is due to padding).
738               if (ox < 0 || oy < 0 || ox >= (sdim_t)inWdims[1] ||
739                   oy >= (sdim_t)inWdims[2]) {
740                 continue;
741               }
742 
743               float val =
744                   inW[libjit_getXYZW(inWdims, n, (dim_t)ox, (dim_t)oy, z)];
745 
746               if (first || (val >= max)) {
747                 first = 0;
748                 max = val;
749               }
750             }
751           }
752 
753           outW[libjit_getXYZW(outWdims, n, ax, ay, z)] = max;
754         } // C
755       }   // W
756     }     // H
757   }       // N
758 }
759 
/// NHWC max-pooling that additionally records, for every output element, the
/// flat NHWC index (into \p inW) of the winning input element in \p argmax.
/// With '>=', ties resolve to the last window element scanned. A window
/// fully inside the padding region yields value 0 and argmax 0.
template <typename T, typename T2>
static void
libjit_max_pool_argmax_generic(const T *inW, T *outW, T2 *argmax,
                               const dim_t *inWdims, const dim_t *outWdims,
                               dim_t *kernels, dim_t *strides, dim_t *pads) {
  dim_t pad_t = pads[0];
  dim_t pad_l = pads[1];
  dim_t stride_h = strides[0];
  dim_t stride_w = strides[1];
  dim_t kernel_h = kernels[0];
  dim_t kernel_w = kernels[1];
  // For each input in the batch:
  for (dim_t n = 0; n < outWdims[0]; n++) {

    // For each (x,y) step in the input/output tensor:
    sdim_t x = -(sdim_t)pad_t;
    for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
      sdim_t y = -(sdim_t)pad_l;
      for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {

        // For each channel in the output tensor:
        for (dim_t z = 0; z < outWdims[3]; z++) {
          int64_t argmaxNHWC = 0;
          // 'first' tracks whether any in-bounds element was seen yet.
          int first = 1;
          T max = 0;

          for (dim_t kx = 0; kx < kernel_h; kx++) {
            for (dim_t ky = 0; ky < kernel_w; ky++) {
              sdim_t ox = x + kx;
              sdim_t oy = y + ky;

              // Skip window positions that fall in the padding region.
              if (ox < 0 || oy < 0 || ox >= (sdim_t)inWdims[1] ||
                  oy >= (sdim_t)inWdims[2]) {
                continue;
              }
              const dim_t flatIndex =
                  libjit_getXYZW(inWdims, n, (dim_t)ox, (dim_t)oy, z);
              T val = inW[flatIndex];
              if (first || (val >= max)) {
                first = 0;
                max = val;
                argmaxNHWC = flatIndex;
              }
            }
          }

          const dim_t flatIndex = libjit_getXYZW(outWdims, n, ax, ay, z);
          outW[flatIndex] = max;
          argmax[flatIndex] = argmaxNHWC;
        } // C
      }   // W
    }     // H
  }       // N
}
814 
815 template <typename T>
libjit_resizenearest_generic(T * dst,const T * src,const float * scale,const dim_t * inWdims,const dim_t * outWdims)816 void libjit_resizenearest_generic(T *dst, const T *src, const float *scale,
817                                   const dim_t *inWdims, const dim_t *outWdims) {
818 
819   for (dim_t ob = 0; ob < outWdims[0]; ++ob) {
820     auto ib = std::min(dim_t(ob / (scale[0])), inWdims[0] - 1);
821     for (dim_t oh = 0; oh < outWdims[1]; ++oh) {
822       auto ih = std::min(dim_t(oh / (scale[1])), inWdims[1] - 1);
823       for (dim_t ow = 0; ow < outWdims[2]; ++ow) {
824         auto iw = std::min(dim_t(ow / (scale[2])), inWdims[2] - 1);
825         for (dim_t oc = 0; oc < outWdims[3]; ++oc) {
826           auto ic = std::min(dim_t(oc / (scale[3])), inWdims[3] - 1);
827           const dim_t inIndex = libjit_getXYZW(inWdims, ib, ih, iw, ic);
828           const dim_t outIndex = libjit_getXYZW(outWdims, ob, oh, ow, oc);
829           dst[outIndex] = src[inIndex];
830         }
831       }
832     }
833   }
834 }
835 
836 template <typename T>
837 static void
libjit_resizebilinear_generic(T * dst,const T * src,const float * scale,const dim_t * inWdims,const dim_t * outWdims)838 libjit_resizebilinear_generic(T *dst, const T *src, const float *scale,
839                               const dim_t *inWdims, const dim_t *outWdims) {
840   for (dim_t ob = 0; ob < outWdims[0]; ++ob) {
841     for (dim_t oh = 0; oh < outWdims[1]; ++oh) {
842       for (dim_t ow = 0; ow < outWdims[2]; ++ow) {
843         float ihf = oh / scale[1];
844         float iwf = ow / scale[2];
845         dim_t ih = dim_t(ihf);
846         dim_t iw = dim_t(iwf);
847 
848         auto ih0 = std::min(ih, inWdims[1] - 1);
849         auto ih1 = std::min(ih + 1, inWdims[1] - 1);
850         auto iw0 = std::min(iw, inWdims[2] - 1);
851         auto iw1 = std::min(iw + 1, inWdims[2] - 1);
852 
853         for (dim_t oc = 0; oc < outWdims[3]; ++oc) {
854           float v00 = src[libjit_getXYZW(inWdims, ob, ih0, iw0, oc)];
855           float v01 = src[libjit_getXYZW(inWdims, ob, ih0, iw1, oc)];
856           float v10 = src[libjit_getXYZW(inWdims, ob, ih1, iw0, oc)];
857           float v11 = src[libjit_getXYZW(inWdims, ob, ih1, iw1, oc)];
858 
859           float hd = v00 + (v10 - v00) * (ihf - ih);
860           float hw = v01 + (v11 - v01) * (ihf - ih);
861           float result = hd + (hw - hd) * (iwf - iw);
862           dst[libjit_getXYZW(outWdims, ob, oh, ow, oc)] = result;
863         }
864       }
865     }
866   }
867 }
868 
869 template <typename T>
870 static void
libjit_batchedadd_quantized(int8_t * dest,const int8_t * batch,const T * slice,dim_t numSlice,dim_t sliceSize,int32_t destOffset,int32_t batchOffset,int32_t sliceOffset,int32_t batchPre,int32_t batchPost,int32_t batchScale,int32_t slicePre,int32_t slicePost,int32_t sliceScale)871 libjit_batchedadd_quantized(int8_t *dest, const int8_t *batch, const T *slice,
872                             dim_t numSlice, dim_t sliceSize, int32_t destOffset,
873                             int32_t batchOffset, int32_t sliceOffset,
874                             int32_t batchPre, int32_t batchPost,
875                             int32_t batchScale, int32_t slicePre,
876                             int32_t slicePost, int32_t sliceScale) {
877   for (dim_t n = 0; n < numSlice; n++) {
878     dim_t base = n * sliceSize;
879     for (dim_t i = 0; i < sliceSize; i++) {
880       int32_t b = batch[base + i] - batchOffset;
881       int32_t s = slice[i] - sliceOffset;
882       int32_t x = libjit_scale_i32i8(b, batchPre, batchPost, batchScale, 0);
883       int32_t y = libjit_scale_i32i8(s, slicePre, slicePost, sliceScale, 0);
884       dest[base + i] = libjit_clip(x + y + destOffset);
885     }
886   }
887 }
888 
find_min_max_f(float * tensor,dim_t size,float & min,float & max)889 static void find_min_max_f(float *tensor, dim_t size, float &min, float &max) {
890   min = tensor[0];
891   max = tensor[0];
892 
893   for (dim_t i = 1; i < size; ++i) {
894     float tensorVal = tensor[i];
895     if (tensorVal < min)
896       min = tensorVal;
897 
898     if (tensorVal > max)
899       max = tensorVal;
900 
901     // Sanity check for NaN and Infinity.
902     assert(!std::isnan(tensor[i]) && "NaN value found!");
903     assert(!std::isinf(tensor[i]) && "Infinity value found!");
904   }
905 }
906 
check_all_zeros(float * arrayToCheck,dim_t size)907 static int check_all_zeros(float *arrayToCheck, dim_t size) {
908   for (dim_t i = 0; i < size; ++i) {
909     if (arrayToCheck[i] != 0) {
910       return 0;
911     }
912   }
913   return 1;
914 }
915 
916 /// Gen a bin number to insert \p value into the histogram which has \p nBins
917 /// with \p minValue and binWidth in histogram.
get_bin(dim_t nBins,float binWidth,float minValue,float value)918 static dim_t get_bin(dim_t nBins, float binWidth, float minValue, float value) {
919   dim_t result =
920       binWidth == 0
921           ? 0
922           : MIN(static_cast<dim_t>((value - minValue) / binWidth), nBins - 1);
923   return result;
924 }
925 
926 template <typename T>
libjit_space_to_depth_generic(const T * inPtr,T * outPtr,dim_t blockSize,const dim_t * inDims,const dim_t * outDims)927 static void libjit_space_to_depth_generic(const T *inPtr, T *outPtr,
928                                           dim_t blockSize, const dim_t *inDims,
929                                           const dim_t *outDims) {
930   dim_t inHeight = inDims[1];
931   dim_t inWidth = inDims[2];
932   dim_t inDepth = inDims[3];
933 
934   dim_t outBatch = outDims[0];
935   dim_t outHeight = outDims[1];
936   dim_t outWidth = outDims[2];
937   dim_t outDepth = outDims[3];
938 
939   for (dim_t b = 0; b < outBatch; ++b) {
940     for (dim_t h = 0; h < outHeight; ++h) {
941       for (dim_t w = 0; w < outWidth; ++w) {
942         for (dim_t c = 0; c < outDepth; ++c) {
943           // NHWC
944           // c +
945           // w * outDepth +
946           // h * outDepth * outWidth +
947           // b * outDepth * outWidth * outHeight
948           dim_t outIndex = c + outDepth * (w + outWidth * (h + b * outHeight));
949 
950           // Gets the block layer we are on
951           dim_t blockDepthLayer = c / inDepth;
952           // every multiple of block size we reset to 0 offset
953           dim_t iw = w * blockSize + blockDepthLayer % blockSize;
954           // every multiple of blockSize we start height traversal + 1
955           dim_t ih = h * blockSize + blockDepthLayer / blockSize;
956           // at every multiple of inDepth index in to input depths resets to 0
957           dim_t id = c % inDepth;
958 
959           dim_t inIndex = id + inDepth * (iw + inWidth * (ih + b * inHeight));
960           outPtr[outIndex] = inPtr[inIndex];
961         }
962       }
963     }
964   }
965 }
966 
967 template <typename DstType, typename SrcType>
968 static void
libjit_copy_kernel_with_conversion(DstType * dstPtr,const SrcType * srcPtr,const dim_t * dims,dim_t numDims)969 libjit_copy_kernel_with_conversion(DstType *dstPtr, const SrcType *srcPtr,
970                                    const dim_t *dims, dim_t numDims) {
971   dim_t dimSize = 1;
972   for (dim_t i = 0; i < numDims; ++i) {
973     dimSize *= dims[i];
974   }
975 
976   for (dim_t i = 0; i < dimSize; ++i) {
977     dstPtr[i] = DstType(srcPtr[i]);
978   }
979 }
980 
981 /// The dimensions passed in here are pre-expanded in LLVMIRGen with 1s so that
982 /// we can iterate over the shape here, regardless of the shape of the tensor.
983 template <typename T>
libjit_reducemin(T * dest,const T * batch,size_t destSize,const dim_t * destDims,const dim_t * batchDims,T init)984 static void libjit_reducemin(T *dest, const T *batch, size_t destSize,
985                              const dim_t *destDims, const dim_t *batchDims,
986                              T init) {
987   for (dim_t i = 0; i < destSize; i++) {
988     dest[i] = init;
989   }
990 
991   unsigned int axis[6];
992   for (dim_t i = 0; i < 6; i++) {
993     axis[i] = (destDims[i] > 1);
994   }
995 
996   for (dim_t x = 0, dx = 0; x < batchDims[0]; x++, dx += axis[0]) {
997     for (dim_t y = 0, dy = 0; y < batchDims[1]; y++, dy += axis[1]) {
998       for (dim_t z = 0, dz = 0; z < batchDims[2]; z++, dz += axis[2]) {
999         for (dim_t w = 0, dw = 0; w < batchDims[3]; w++, dw += axis[3]) {
1000           for (dim_t q = 0, dq = 0; q < batchDims[4]; q++, dq += axis[4]) {
1001             for (dim_t r = 0, dr = 0; r < batchDims[5]; r++, dr += axis[5]) {
1002               T fdest =
1003                   dest[libjit_getXYZWQR(destDims, dx, dy, dz, dw, dq, dr)];
1004               T fnew = batch[libjit_getXYZWQR(batchDims, x, y, z, w, q, r)];
1005               dest[libjit_getXYZWQR(destDims, dx, dy, dz, dw, dq, dr)] =
1006                   std::min(fdest, fnew);
1007             }
1008           }
1009         }
1010       }
1011     }
1012   }
1013 }
1014 
1015 template <typename T, typename T2>
libjit_cross_entropy_loss_generic(T * CE,T * P,T2 * labels,dim_t * dims)1016 static void libjit_cross_entropy_loss_generic(T *CE, T *P, T2 *labels,
1017                                               dim_t *dims) {
1018   CE[0] = 0.0;
1019   for (dim_t n = 0; n < dims[0]; ++n) {
1020     auto y = labels[n];
1021     auto p_n = P[libjit_getXY(dims, n, y)];
1022     CE[0] -= log(p_n);
1023   }
1024 }
1025 
1026 template <typename T, typename T2>
libjit_sparse_lengths_sum_generic(T * dest,T * data,T2 * indices,int32_t * lengths,dim_t segments,dim_t lineSize)1027 static void libjit_sparse_lengths_sum_generic(T *dest, T *data, T2 *indices,
1028                                               int32_t *lengths, dim_t segments,
1029                                               dim_t lineSize) {
1030   memset(dest, 0, segments * lineSize * sizeof(float));
1031   dim_t curIndex = 0;
1032   for (dim_t i = 0; i < segments; i++) {
1033     for (int32_t j = 0; j < lengths[i]; j++) {
1034       dim_t line = indices[curIndex];
1035       for (dim_t k = 0; k < lineSize; k++) {
1036         dest[i * lineSize + k] += data[line * lineSize + k];
1037       }
1038       curIndex++;
1039     }
1040   }
1041 }
1042 
1043 template <typename T, typename T2>
1044 static void
libjit_sparse_lengths_weighted_sum_generic(T * dest,T * data,float * weights,T2 * indices,int32_t * lengths,dim_t segments,dim_t lineSize)1045 libjit_sparse_lengths_weighted_sum_generic(T *dest, T *data, float *weights,
1046                                            T2 *indices, int32_t *lengths,
1047                                            dim_t segments, dim_t lineSize) {
1048   memset(dest, 0, segments * lineSize * sizeof(float));
1049   dim_t curIndex = 0;
1050   for (dim_t i = 0; i < segments; i++) {
1051     for (int32_t j = 0; j < lengths[i]; j++) {
1052       float weight = weights[curIndex];
1053       dim_t line = indices[curIndex];
1054       for (dim_t k = 0; k < lineSize; k++) {
1055         dest[i * lineSize + k] += weight * data[line * lineSize + k];
1056       }
1057       curIndex++;
1058     }
1059   }
1060 }
1061 
1062 template <typename T, typename T2>
libjit_sparse_lengths_weighted_sum_grad_generic(const T * destGrad,T * dataGrad,T * weightsGrad,const T * data,const T * weights,const T2 * indices,const int32_t * lengths,dim_t segments,dim_t lineSize,dim_t dataGradRawSize)1063 static void libjit_sparse_lengths_weighted_sum_grad_generic(
1064     const T *destGrad, T *dataGrad, T *weightsGrad, const T *data,
1065     const T *weights, const T2 *indices, const int32_t *lengths, dim_t segments,
1066     dim_t lineSize, dim_t dataGradRawSize) {
1067   // The data gradients not touched by this operation should
1068   // be 0, so set the entire buffer to 0 to start with.
1069   memset(dataGrad, 0, dataGradRawSize);
1070 
1071   for (dim_t i = 0, curIndex = 0; i < segments; ++i) {
1072     for (int32_t j = 0; j < lengths[i]; ++j, ++curIndex) {
1073       // For each index in each segment:
1074       //    1) accumulate into the corresponding data gradient the product of
1075       //    the gradient of the result it was added to and the weight that it
1076       //    was multiplied by during the SparseLengthsWeightedSum operation.
1077       //
1078       //    2) accumulate into each weight gradient the reduced sum of the
1079       //    elementwise product of the result slice that the corresponding
1080       //    weight produced and the input slice that the weight was multiplied
1081       //    with.
1082       float weightGrad = 0.0f;
1083       float weight = weights[curIndex];
1084       dim_t line = indices[curIndex];
1085       for (dim_t k = 0; k < lineSize; ++k) {
1086         dataGrad[line * lineSize + k] += weight * destGrad[i * lineSize + k];
1087         weightGrad += destGrad[i * lineSize + k] * data[line * lineSize + k];
1088       }
1089       weightsGrad[curIndex] = weightGrad;
1090     }
1091   }
1092 }
1093 
1094 template <typename T, typename T2>
libjit_rowwise_quantized_sparse_lengths_weighted_sum_generic(T * dest,uint8_t * data,T * scales,T * offsets,T * weights,T2 * indices,int32_t * lengths,dim_t segments,dim_t lineSize)1095 static void libjit_rowwise_quantized_sparse_lengths_weighted_sum_generic(
1096     T *dest, uint8_t *data, T *scales, T *offsets, T *weights, T2 *indices,
1097     int32_t *lengths, dim_t segments, dim_t lineSize) {
1098   memset(dest, 0, segments * lineSize * sizeof(float));
1099   dim_t curIndex = 0;
1100   for (dim_t i = 0; i < segments; i++) {
1101     for (int32_t j = 0; j < lengths[i]; j++) {
1102       const float weight = weights[curIndex];
1103       const dim_t line = indices[curIndex];
1104       const float scale = scales[line];
1105       const float offset = offsets[line];
1106       for (dim_t k = 0; k < lineSize; k++) {
1107         const float fData = scale * data[line * lineSize + k] + offset;
1108         dest[i * lineSize + k] += weight * fData;
1109       }
1110       curIndex++;
1111     }
1112   }
1113 }
1114 
1115 template <typename T, typename T2>
libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_generic(T * dest,int8_t * data,T * weights,T2 * indices,int32_t * lengths,dim_t segments,dim_t inLineSize,dim_t outLineSize)1116 static void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_generic(
1117     T *dest, int8_t *data, T *weights, T2 *indices, int32_t *lengths,
1118     dim_t segments, dim_t inLineSize, dim_t outLineSize) {
1119   memset(dest, 0, segments * outLineSize * sizeof(float));
1120   dim_t curIndex = 0;
1121   for (dim_t i = 0; i < segments; i++) {
1122     for (int32_t j = 0, e = lengths[i]; j < e; j++) {
1123       const float weight = weights[curIndex];
1124       const dim_t line = indices[curIndex];
1125       const int8_t *currRowScaleOffsetPtr =
1126           data + ((line + 1) * inLineSize) - 2 * sizeof(float);
1127       float scale, offset;
1128       memcpy(&scale, currRowScaleOffsetPtr, sizeof(float));
1129       memcpy(&offset, currRowScaleOffsetPtr + sizeof(float), sizeof(float));
1130       for (dim_t k = 0; k < outLineSize; k++) {
1131         const float fData =
1132             (scale * (uint8_t)(data[line * inLineSize + k])) + offset;
1133         dest[i * outLineSize + k] += weight * fData;
1134       }
1135       curIndex++;
1136     }
1137   }
1138 }
1139 
1140 template <typename T, typename T2>
libjit_sparse_to_dense_generic(T * dest,const T2 * indices,const T * values,dim_t numIndices,dim_t destSize,dim_t valueSize)1141 static void libjit_sparse_to_dense_generic(T *dest, const T2 *indices,
1142                                            const T *values, dim_t numIndices,
1143                                            dim_t destSize, dim_t valueSize) {
1144   memset(dest, 0, destSize * sizeof(float));
1145 
1146   for (dim_t i = 0, valuesOffset = 0; i < numIndices;
1147        ++i, valuesOffset += valueSize) {
1148     dim_t idx = indices[i];
1149     dim_t destOffset = idx * valueSize;
1150 
1151     for (size_t j = 0; j < valueSize; ++j) {
1152       dest[destOffset + j] += values[valuesOffset + j];
1153     }
1154   }
1155 }
1156 
/// A detection candidate for one class during NMS: the box's class score and
/// its index within the batch's box list.
struct ClassBox {
  float score{0.0f};
  size_t index{0};
};
1161 
/// Raw box coordinates as four floats. Their meaning depends on the NMS
/// 'centerPointBox' mode: either two diagonal corners (x1, y1, x2, y2) or
/// center plus size (cx, cy, w, h) — see checkIOU.
struct Box {
  float v0{0.0f};
  float v1{0.0f};
  float v2{0.0f};
  float v3{0.0f};
};
1168 
/// One selected NMS output entry: the class score of the chosen box plus the
/// (batch, class, box) triple that identifies it.
struct OutBox {
  float classValue{0.0f};
  size_t batchIndex{0};
  size_t classIndex{0};
  size_t boxIndex{0};
};
1175 
/// Writes the smaller of \p lhs and \p rhs to \p min and the larger to
/// \p max (ties resolve trivially since both outputs get the same value).
static void maxMin(float lhs, float rhs, float &min, float &max) {
  if (lhs < rhs) {
    min = lhs;
    max = rhs;
  } else {
    min = rhs;
    max = lhs;
  }
}
1185 
checkIOU(const Box & sb,const Box & cb,float iouThreshold,size_t centerPointBox)1186 static bool checkIOU(const Box &sb, const Box &cb, float iouThreshold,
1187                      size_t centerPointBox) {
1188   float xSMin = 0.0f;
1189   float ySMin = 0.0f;
1190   float xSMax = 0.0f;
1191   float ySMax = 0.0f;
1192 
1193   float xCMin = 0.0f;
1194   float yCMin = 0.0f;
1195   float xCMax = 0.0f;
1196   float yCMax = 0.0f;
1197 
1198   // Standardizing coordinates so that (xmin, ymin) is upper left corner of a
1199   // box and (xmax, ymax) is lower right corner of the box.
1200   if (!centerPointBox) {
1201     // 0 means coordinates for diagonal ends of a box.
1202     // Coordinates can either be absolute or normalized.
1203     maxMin(sb.v0, sb.v2, xSMin, xSMax);
1204     maxMin(sb.v1, sb.v3, ySMin, ySMax);
1205 
1206     maxMin(cb.v0, cb.v2, xCMin, xCMax);
1207     maxMin(cb.v1, cb.v3, yCMin, yCMax);
1208   } else {
1209     float halfWidthS = sb.v2 / 2.0f;
1210     float halfHeightS = sb.v3 / 2.0f;
1211     float halfWidthC = cb.v2 / 2.0f;
1212     float halfHeightC = cb.v3 / 2.0f;
1213 
1214     xSMin = sb.v0 - halfWidthS;
1215     ySMin = sb.v1 - halfHeightS;
1216     xSMax = sb.v0 + halfWidthS;
1217     ySMax = sb.v1 + halfHeightS;
1218 
1219     xCMin = cb.v0 - halfWidthC;
1220     yCMin = cb.v1 - halfHeightC;
1221     xCMax = cb.v0 + halfWidthC;
1222     yCMax = cb.v1 + halfHeightC;
1223   }
1224 
1225   // finding upper left and lower right corner of a box formed by intersection.
1226   float xMin = MAX(xSMin, xCMin);
1227   float yMin = MAX(ySMin, yCMin);
1228   float xMax = MIN(xSMax, xCMax);
1229   float yMax = MIN(ySMax, yCMax);
1230 
1231   float intersectionArea = MAX((0.0f), xMax - xMin) * MAX((0.0f), yMax - yMin);
1232 
1233   if (intersectionArea == 0.0f) {
1234     return false;
1235   }
1236 
1237   float sArea = (xSMax - xSMin) * (ySMax - ySMin);
1238   float cArea = (xCMax - xCMin) * (yCMax - yCMin);
1239   float unionArea = sArea + cArea - intersectionArea;
1240 
1241   return intersectionArea > iouThreshold * unionArea;
1242 }
1243 
// Non-maximum suppression. Supports two layouts:
// ONNX NMS:
//   Class/Score [BatchNum][ClassNum][BoxNum]
//   Box         [BatchNum][BoxNum][4]
//   Result      [BatchNum*MaxOutputPerBatch][3] (batch, class, box triples)
// TF NMS V4:
//   Class/Score [BatchNum][BoxNum]
//   Boxes       [BatchNum][BoxNum][4]
//   Result      [BatchNum*MaxOutputPerBatch] (box indices only)
//   NumberOfIndicesDetected [BatchNum*MaxOutputPerBatch]
// Slots beyond the number of real detections are padded with the
// lowest-scoring selection seen so far (minBox).
template <typename T>
static void
libjit_nms_generic(T *indices, T *numDetected, const float *boxTensor,
                   const dim_t *boxTensorDims, dim_t boxTensorDimSize,
                   const float *scoresTensor, const dim_t *scoresTensorDims,
                   dim_t scoresTensorDimSize, const dim_t *resultTensorDims,
                   dim_t resultTensorDimSize, unsigned centerPointBox,
                   unsigned maxOutputBoxesPerClass, float iouThreshold,
                   float scoreThreshold, bool isV4) {
  // Dimensions are addressed from the back so both 2D (V4) and 3D (ONNX)
  // layouts can share the same code.
  int boxesBoxDim = boxTensorDimSize - 2;

  size_t numBatches = 1;
  size_t numClasses = 1;
  size_t numBoxes = boxTensorDims[boxesBoxDim];

  size_t maxOutputPerBatch = 0;
  if (!isV4) {
    // ONNX layout: recover batch/class counts and validate that the scores
    // and boxes tensors agree on the shared dimensions.
    int boxesBatchDim = boxTensorDimSize - 3;
    int scoresBatchDim = scoresTensorDimSize - 3;

    int scoresBoxDim = scoresTensorDimSize - 1;
    int scoresClassDim = scoresTensorDimSize - 2;

    assert(scoresTensorDims[scoresBoxDim] == boxTensorDims[boxesBoxDim] &&
           "Mismatch between number of scores and number of boxes.");
    assert(scoresTensorDims[scoresBatchDim] == boxTensorDims[boxesBatchDim] &&
           "Scores and Box Batch Dimensions don't match.");
    (void)boxesBatchDim;
    (void)scoresBoxDim;
    numBatches = scoresTensorDims[scoresBatchDim];
    numClasses = scoresTensorDims[scoresClassDim];
    numBoxes = boxTensorDims[boxesBoxDim];
    maxOutputPerBatch = resultTensorDims[resultTensorDimSize - 2] / numBatches;
  } else {
    maxOutputPerBatch = resultTensorDims[resultTensorDimSize - 1] / numBatches;
  }

  // Reinterpret the flat float buffer as an array of 4-float boxes.
  static_assert(sizeof(Box) == 4 * sizeof(float),
                "Can't reinterpret raw float data as a Box.");
  const Box *boxes = reinterpret_cast<const Box *>(boxTensor);

  // Sort candidates by descending score.
  auto cmpFunc = [](const ClassBox &cb1, const ClassBox &cb2) -> bool {
    return cb1.score > cb2.score;
  };

  // Running write position into 'indices', shared across batches/classes.
  size_t outPutBoxIndex = 0;
  for (size_t batchIndex = 0; batchIndex < numBatches; ++batchIndex) {
    int32_t detectedPerBatch = 0;
    // Tracks the lowest-scoring selection in this batch; used to pad unused
    // output slots below.
    OutBox minBox{scoresTensor[batchIndex * numClasses], batchIndex, 0, 0};
    for (size_t classIndex = 0; classIndex < numClasses; ++classIndex) {
      // NOTE(review): these are variable-length arrays (a GNU extension) of
      // numBoxes elements on the stack — presumably numBoxes is bounded by
      // the compiler pipeline; confirm for very large box counts.
      ClassBox selectedIndices[numBoxes];
      ClassBox potentialBoxes[numBoxes];
      size_t indexPBoxes = 0;
      const float *currClass =
          &scoresTensor[(batchIndex * numClasses + classIndex) * numBoxes];
      // Collect every box whose score clears the threshold.
      for (size_t boxIndex = 0; boxIndex < numBoxes; ++boxIndex) {
        float classScore = currClass[boxIndex];
        if (classScore > scoreThreshold) {
          ClassBox &b = potentialBoxes[indexPBoxes++];
          b.score = classScore;
          b.index = boxIndex;
        }
      }

      std::sort(potentialBoxes, potentialBoxes + indexPBoxes, cmpFunc);

      // Greedily accept boxes in score order, rejecting any that overlaps an
      // already-selected box beyond the IOU threshold.
      size_t indexSBoxes = 0;
      size_t detectedPerClass = 0;
      float tScore = minBox.classValue;
      for (unsigned int i = 0; i < indexPBoxes; ++i) {
        ClassBox &pbI = potentialBoxes[i];
        const Box &potentialBox = boxes[batchIndex * numBoxes + pbI.index];
        bool selected = true;
        for (unsigned int j = 0; j < indexSBoxes && selected; ++j) {
          ClassBox &sbI = selectedIndices[j];
          const Box &selectedBox = boxes[batchIndex * numBoxes + sbI.index];
          selected = !checkIOU(selectedBox, potentialBox, iouThreshold,
                               centerPointBox);
        }

        if (selected) {
          selectedIndices[indexSBoxes++] = pbI;
          if (isV4) {
            // V4 emits bare box indices.
            indices[outPutBoxIndex] = pbI.index;
          } else {
            // ONNX emits (batch, class, box) triples.
            indices[outPutBoxIndex * 3 + 0] = batchIndex;
            indices[outPutBoxIndex * 3 + 1] = classIndex;
            indices[outPutBoxIndex * 3 + 2] = pbI.index;
          }

          // Scores are sorted descending, so this ends up being the score of
          // the last (lowest) selection for this class.
          tScore = pbI.score;
          ++outPutBoxIndex;
          ++detectedPerClass;
          ++detectedPerBatch;
        }

        if (detectedPerClass == maxOutputBoxesPerClass) {
          break;
        }
      }

      // Remember the weakest selection seen so far as the padding candidate.
      // NOTE(review): if tScore < minBox.classValue while nothing was ever
      // selected, indices[outPutBoxIndex - 1] reads before the buffer —
      // presumably callers guarantee at least one selection; confirm.
      if (tScore < minBox.classValue) {
        minBox.classValue = tScore;
        if (isV4) {
          minBox.boxIndex = indices[outPutBoxIndex - 1];
        } else {
          minBox.boxIndex = indices[(outPutBoxIndex - 1) * 3 + 2];
        }
        minBox.classIndex = classIndex;
      }
    }

    // Filling the rest of the class with minimum value.
    for (size_t i = detectedPerBatch; i < maxOutputPerBatch; ++i) {
      if (isV4) {
        indices[outPutBoxIndex] = minBox.boxIndex;
      } else {
        indices[outPutBoxIndex * 3 + 0] = minBox.batchIndex;
        indices[outPutBoxIndex * 3 + 1] = minBox.classIndex;
        indices[outPutBoxIndex * 3 + 2] = minBox.boxIndex;
      }

      ++outPutBoxIndex;
    }
    // For ONNX NMS it's not used, for TF Batch Dimension is 1.
    for (size_t i = 0; i < maxOutputBoxesPerClass; ++i) {
      numDetected[batchIndex * maxOutputBoxesPerClass + i] = detectedPerBatch;
    }
  }
}
1383 
1384 template <typename T, typename T2>
libjit_softmax_grad_generic(T * inG,T * outW,const T2 * selectedW,const dim_t * idim,const dim_t * selectdim)1385 void libjit_softmax_grad_generic(T *inG, T *outW, const T2 *selectedW,
1386                                  const dim_t *idim, const dim_t *selectdim) {
1387   for (dim_t n = 0; n < idim[0]; n++) {
1388     for (dim_t i = 0; i < idim[1]; i++) {
1389       float delta = (selectedW[libjit_getXY(selectdim, n, 0)] == i);
1390       inG[libjit_getXY(idim, n, i)] = outW[libjit_getXY(idim, n, i)] - delta;
1391     }
1392   }
1393 }
1394 
1395 template <typename T, typename T2>
libjit_max_pool_argmax_grad_generic(T * inG,const T * outG,const T2 * argmax,const dim_t * inGdims,const dim_t * outWdims)1396 void libjit_max_pool_argmax_grad_generic(T *inG, const T *outG,
1397                                          const T2 *argmax, const dim_t *inGdims,
1398                                          const dim_t *outWdims) {
1399   // NHWC format is assumed
1400   for (dim_t n = 0; n < outWdims[0]; n++) {
1401     for (dim_t z = 0; z < outWdims[3]; z++) {
1402       // Clear inG
1403       for (dim_t x = 0; x < inGdims[1]; x++) {
1404         for (dim_t y = 0; y < inGdims[2]; y++) {
1405           inG[libjit_getXYZW(inGdims, n, x, y, z)] = 0.0;
1406         }
1407       }
1408 
1409       for (dim_t ax = 0; ax < outWdims[1]; ax++) {
1410         for (dim_t ay = 0; ay < outWdims[2]; ay++) {
1411           // Reuse precomputed linear index of max element from argmax.
1412           const dim_t flatIndex = libjit_getXYZW(outWdims, n, ax, ay, z);
1413           float df = outG[flatIndex];
1414           inG[argmax[flatIndex]] += df;
1415         } // W
1416       }   // H
1417     }     // C
1418   }       // N
1419 }
1420 } // namespace
1421 
1422 extern "C" {
1423 
/// Macro to define a mini-kernel for data-parallel operations. The body of the
/// kernel is auto-generated by the macro.
/// \p name the name of the kernel
/// \p type the type of the tensor elements and of the return value
/// \p body the operation to be performed for element \p idx; it may reference
/// the operand pointers \p LHS, \p RHS and \p op3
#define DEFINE_DATA_PARALLEL_KERNEL(name, type, body)                          \
  type name(dim_t idx, const type *LHS, const type *RHS, const type *op3) {    \
    return body;                                                               \
  }

/// Macro to define a mini-kernel for data-parallel operations. The body of the
/// kernel is not auto-generated by the macro; this only declares the float
/// signature and the definition body follows the macro invocation.
/// \p name the name of the kernel
#define DEFINE_DATA_PARALLEL_KERNEL_FUNC(name)                                 \
  float name(dim_t idx, const float *LHS, const float *RHS, const float *op3)

/// Macro to define a mini-kernel for data-parallel operations with immediate
/// operands (the scalar \p val is passed by value instead of a third tensor).
/// \p name the name of the kernel
/// \p type the type of the tensor elements and of the return value
/// \p body the operation to be performed
#define DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(name, type, body)         \
  type name(dim_t idx, type val, const type *LHS, const type *RHS) {           \
    return body;                                                               \
  }

/// Macro to define a mini-kernel for data-parallel arithmetic quantized
/// operations. The body of the kernel is auto-generated by the macro. Both
/// operands are first shifted by their offsets and rescaled into the
/// destination's quantization domain; the result is re-offset and clipped.
/// \p name the name of the kernel
/// \p type the type of the tensor elements
/// \p body the operation to be performed on the rescaled int32 values
/// \p lhs and \p rhs
#define DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(name, type, body)                \
  type name(dim_t idx, const type *LHS, const type *RHS, int32_t destOffset,   \
            int32_t lhsOffset, int32_t rhsOffset, int32_t lhsPre,              \
            int32_t lhsPost, int32_t lhsScale, int32_t rhsPre,                 \
            int32_t rhsPost, int32_t rhsScale) {                               \
    int32_t lhs = libjit_scale_i32i8(LHS[idx] - lhsOffset, lhsPre, lhsPost,    \
                                     lhsScale, 0);                             \
    int32_t rhs = libjit_scale_i32i8(RHS[idx] - rhsOffset, rhsPre, rhsPost,    \
                                     rhsScale, 0);                             \
    return libjit_clip((body) + destOffset);                                   \
  }

/// Macro to define a mini-kernel for data-parallel multiplicative quantized
/// operations. The body of the kernel is auto-generated by the macro. The
/// element type is fixed to int8_t; the single pre/post/scale triple rescales
/// the product of the offset-corrected operands.
/// \p name the name of the kernel
/// \p body the operation to be performed on the int32 values \p lhs and \p rhs
#define DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED_M(name, body)                    \
  int8_t name(dim_t idx, const int8_t *LHS, const int8_t *RHS,                 \
              int32_t destOffset, int32_t lhsOffset, int32_t rhsOffset,        \
              int32_t pre, int32_t post, int32_t scale) {                      \
    int32_t lhs = LHS[idx] - lhsOffset;                                        \
    int32_t rhs = RHS[idx] - rhsOffset;                                        \
    return libjit_clip(                                                        \
        libjit_scale_i32i8((body), pre, post, scale, destOffset));             \
  }
1481 
/// Define mini-kernels for all data parallel operations. They are invoked from
/// the generated kernels for sequences of data parallel operations.
///
/// Suffix convention (visible from the instantiations below): _f = float,
/// _u = int64_t, _i8/_i16/_i32 = fixed-width ints, _b = bool stored as int8_t.
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_max_kernel_f, float,
                            MAX(LHS[idx], RHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_min_kernel_f, float,
                            MIN(LHS[idx], RHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_f, float, LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_u, int64_t, LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_i8, int8_t, LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_i16, int16_t, LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_i32, int32_t, LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_b, int8_t, LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_add_kernel_f, float,
                            LHS[idx] + RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_add_kernel_i32, int32_t,
                            LHS[idx] + RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_sub_kernel_f, float,
                            LHS[idx] - RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_div_kernel_f, float,
                            LHS[idx] / RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_div_kernel_u, int64_t,
                            LHS[idx] / RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_div_kernel_i32, int32_t,
                            LHS[idx] / RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_mul_kernel_f, float,
                            LHS[idx] * RHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_mul_kernel_i32, int32_t,
                            LHS[idx] * RHS[idx])
// NOTE(review): pow/log/exp below bind to the double-precision overloads, so
// each float operand is promoted to double and the result converted back.
// The float overloads (std::pow/std::log/std::exp on float, as used for
// abs/floor/ceil below) would avoid the promotion — confirm there is no
// bit-exactness dependency before changing.
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_pow_kernel_f, float,
                            pow(LHS[idx], RHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_log_kernel_f, float, log(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_exp_kernel_f, float, exp(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_abs_kernel_f, float,
                            std::abs(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_neg_kernel_f, float, -LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_floor_kernel_f, float,
                            std::floor(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_ceil_kernel_f, float,
                            std::ceil(LHS[idx]))
// Rounding mode required by ONNX, Numpy, TensorFlow is round to even which
// rounds to nearest even integer those values with fractional part 0.5.
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_round_kernel_f, float,
                            std::nearbyintf(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_sqrt_kernel_f, float,
                            std::sqrt(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_rsqrt_kernel_f, float,
                            1 / std::sqrt(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_reciprocal_kernel_f, float,
                            1 / LHS[idx])
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_sin_kernel_f, float,
                            std::sin(LHS[idx]))
DEFINE_DATA_PARALLEL_KERNEL(libjit_element_cos_kernel_f, float,
                            std::cos(LHS[idx]))
// Quantized int8 element kernels; 'lhs'/'rhs' are the offset-subtracted
// operands provided by the defining macros.
DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(libjit_element_add_kernel_i8, int8_t,
                                      lhs + rhs)
DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(libjit_element_sub_kernel_i8, int8_t,
                                      lhs - rhs)
DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(libjit_element_max_kernel_i8, int8_t,
                                      MAX(lhs, rhs))
DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(libjit_element_min_kernel_i8, int8_t,
                                      MIN(lhs, rhs))
DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED_M(libjit_element_mul_kernel_i8, lhs *rhs)
DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED_M(libjit_element_div_kernel_i8, lhs / rhs)
1545 
/// This is a variable used by Glow backends to determine the actual type used
/// for size_t, dim_t and int variables when libjit was compiled.
/// The variables themselves carry no runtime value of interest; only their
/// compiled-in types/sizes matter.
size_t libjit_sizeTVar;
dim_t libjit_dimTVar;
int libjit_intVar;
1551 
1552 /// Specialize the Modulo kernel into two functions based on the
1553 /// value of SignFollowDivisor.
libjit_element_modulo_kernel_sign_follow_u(dim_t idx,const int64_t divisor,const int64_t * input)1554 int64_t libjit_element_modulo_kernel_sign_follow_u(dim_t idx,
1555                                                    const int64_t divisor,
1556                                                    const int64_t *input) {
1557   int64_t res = input[idx] % divisor;
1558   if (res && ((res > 0) != (divisor > 0))) {
1559     res += divisor;
1560   }
1561   return res;
1562 }
1563 
libjit_element_modulo_kernel_no_sign_follow_u(dim_t idx,const int64_t divisor,const int64_t * input)1564 int64_t libjit_element_modulo_kernel_no_sign_follow_u(dim_t idx,
1565                                                       const int64_t divisor,
1566                                                       const int64_t *input) {
1567   return input[idx] % divisor;
1568 }
1569 
libjit_element_modulo_kernel_sign_follow_i32(dim_t idx,const int64_t divisor,const int32_t * input)1570 int32_t libjit_element_modulo_kernel_sign_follow_i32(dim_t idx,
1571                                                      const int64_t divisor,
1572                                                      const int32_t *input) {
1573   int32_t res = input[idx] % divisor;
1574   if (res && ((res > 0) != (divisor > 0))) {
1575     res += divisor;
1576   }
1577   return res;
1578 }
1579 
libjit_element_modulo_kernel_no_sign_follow_i32(dim_t idx,const int64_t divisor,const int32_t * input)1580 int32_t libjit_element_modulo_kernel_no_sign_follow_i32(dim_t idx,
1581                                                         const int64_t divisor,
1582                                                         const int32_t *input) {
1583   return input[idx] % divisor;
1584 }
1585 
1586 //===----------------------------------------------------------------------===//
1587 //                              Logical operations
1588 //===----------------------------------------------------------------------===//
libjit_element_not_kernel_b(dim_t idx,const bool * input)1589 int8_t libjit_element_not_kernel_b(dim_t idx, const bool *input) {
1590   return !input[idx];
1591 }
1592 
libjit_element_and_kernel_b(dim_t idx,const bool * LHS,const bool * RHS)1593 int8_t libjit_element_and_kernel_b(dim_t idx, const bool *LHS,
1594                                    const bool *RHS) {
1595   return LHS[idx] && RHS[idx];
1596 }
1597 
libjit_element_or_kernel_b(dim_t idx,const bool * LHS,const bool * RHS)1598 int8_t libjit_element_or_kernel_b(dim_t idx, const bool *LHS, const bool *RHS) {
1599   return LHS[idx] || RHS[idx];
1600 }
1601 
libjit_element_xor_kernel_b(dim_t idx,const bool * LHS,const bool * RHS)1602 int8_t libjit_element_xor_kernel_b(dim_t idx, const bool *LHS,
1603                                    const bool *RHS) {
1604   return LHS[idx] ^ RHS[idx];
1605 }
1606 
1607 //===----------------------------------------------------------------------===//
1608 //                              Compare operations
1609 //===----------------------------------------------------------------------===//
/// Defines a quantized int8 comparison mini-kernel returning 1/0.
/// Both operands are offset-subtracted; the LHS value is then rescaled via
/// libjit_scale_i32i8 (with zero output offset) before applying \p cmp so
/// that both sides are compared in the same numeric domain.
#define DEFINE_CMP_KERNEL_QUANTIZED(name, type, cmp)                           \
  int8_t name(dim_t idx, const type *LHS, const type *RHS, int32_t lhsOffset,  \
              int32_t rhsOffset, int32_t pre, int32_t post, int32_t scale) {   \
    int32_t lhs = LHS[idx] - lhsOffset;                                        \
    int32_t rhs = RHS[idx] - rhsOffset;                                        \
    return (libjit_scale_i32i8(lhs, pre, post, scale, 0) cmp rhs) ? 1 : 0;     \
  }
DEFINE_CMP_KERNEL_QUANTIZED(libjit_element_cmp_eq_kernel_i8, int8_t, ==)
DEFINE_CMP_KERNEL_QUANTIZED(libjit_element_cmp_neq_kernel_i8, int8_t, !=)
DEFINE_CMP_KERNEL_QUANTIZED(libjit_element_cmp_lt_kernel_i8, int8_t, <)
DEFINE_CMP_KERNEL_QUANTIZED(libjit_element_cmp_lte_kernel_i8, int8_t, <=)
#undef DEFINE_CMP_KERNEL_QUANTIZED
1622 
/// Defines a comparison mini-kernel for non-quantized element types; applies
/// \p cmp directly to the two elements and returns 1/0 as int8_t.
#define DEFINE_CMP_KERNEL_NON_QUANTIZED(name, type, cmp)                       \
  int8_t name(dim_t idx, const type *LHS, const type *RHS) {                   \
    return (LHS[idx] cmp RHS[idx]) ? 1 : 0;                                    \
  }

DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_eq_kernel_f, float, ==)
DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_eq_kernel_i32, int32_t, ==)
DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_eq_kernel_u, size_t, ==)

DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_neq_kernel_f, float, !=)
DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_neq_kernel_i32, int32_t, !=)
DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_neq_kernel_u, size_t, !=)

DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lt_kernel_f, float, <)
DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lt_kernel_i32, int32_t, <)
DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lt_kernel_u, size_t, <)

DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lte_kernel_f, float, <=)
DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lte_kernel_i32, int32_t, <=)
DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lte_kernel_u, size_t, <=)
#undef DEFINE_CMP_KERNEL_NON_QUANTIZED
1644 
libjit_element_is_nan_kernel_f(dim_t idx,const float * input)1645 int8_t libjit_element_is_nan_kernel_f(dim_t idx, const float *input) {
1646   return std::isnan(input[idx]) ? 1 : 0;
1647 }
1648 
// Tanh cannot be vectorized by LLVM yet. Therefore we use the following
// formula instead: 1 - 2 / (exp(x * 2) + 1), which is also used by Caffe2 and
// provides a good accuracy.
// Once LLVM supports the vectorization of tanh, we can replace this
// approximation by a direct tanh call.
// When the LIBJIT compile option "-ffast-math" is enabled the intermediate
// computation expf(x) for Tanh operator is not handled properly for very
// large positive values which results in NaN values for the Tanh output.
// Therefore when the "-ffast-math" is enabled we compute the Tanh such that
// we avoid computing large values for the "expf" function.
#ifdef FFAST_MATH
DEFINE_DATA_PARALLEL_KERNEL_FUNC(libjit_tanh_kernel_f) {
  // Evaluate on |x| so expf always sees a non-positive argument, then restore
  // the sign: tanh is odd, so tanh(-x) == -tanh(x).
  float inpVal = LHS[idx];
  float tanhVal = -1 + 2 / (expf(-2 * std::abs(inpVal)) + 1);
  return std::copysignf(tanhVal, inpVal);
}
#else
DEFINE_DATA_PARALLEL_KERNEL_FUNC(libjit_tanh_kernel_f) {
  // Direct form of the approximation described above.
  return 1 - 2 / (expf(LHS[idx] * 2) + 1);
}
#endif // FFAST_MATH
1670 
libjit_intlookuptable_kernel_i8(dim_t idx,const int8_t * src,const int8_t * mapping)1671 int8_t libjit_intlookuptable_kernel_i8(dim_t idx, const int8_t *src,
1672                                        const int8_t *mapping) {
1673   return mapping[src[idx] + 128];
1674 }
1675 
libjit_elementselect_kernel_f(dim_t idx,const int8_t * cond,const float * LHS,const float * RHS)1676 float libjit_elementselect_kernel_f(dim_t idx, const int8_t *cond,
1677                                     const float *LHS, const float *RHS) {
1678   return (cond[idx] != 0) ? LHS[idx] : RHS[idx];
1679 }
1680 
libjit_elementselect_kernel_i8(dim_t idx,const int8_t * cond,const int8_t * LHS,const int8_t * RHS,int32_t destOffset,int32_t lhsOffset,int32_t rhsOffset,int32_t lhsPre,int32_t lhsPost,int32_t lhsScale,int32_t rhsPre,int32_t rhsPost,int32_t rhsScale)1681 int8_t libjit_elementselect_kernel_i8(dim_t idx, const int8_t *cond,
1682                                       const int8_t *LHS, const int8_t *RHS,
1683                                       int32_t destOffset, int32_t lhsOffset,
1684                                       int32_t rhsOffset, int32_t lhsPre,
1685                                       int32_t lhsPost, int32_t lhsScale,
1686                                       int32_t rhsPre, int32_t rhsPost,
1687                                       int32_t rhsScale) {
1688   return (cond[idx] != 0)
1689              ? libjit_clip(libjit_scale_i32i8(LHS[idx] - lhsOffset, lhsPre,
1690                                               lhsPost, lhsScale, destOffset))
1691              : libjit_clip(libjit_scale_i32i8(RHS[idx] - rhsOffset, rhsPre,
1692                                               rhsPost, rhsScale, destOffset));
1693 }
1694 
libjit_element_relu_f(dim_t idx,const float * src)1695 float libjit_element_relu_f(dim_t idx, const float *src) {
1696   float srcVal = src[idx];
1697   return MAX(srcVal, 0);
1698 }
1699 
libjit_element_relu_i8(dim_t idx,const int8_t * src,int8_t srcOffset,int8_t destOffset,int32_t destPre,int32_t destPost,int32_t destScale)1700 int8_t libjit_element_relu_i8(dim_t idx, const int8_t *src, int8_t srcOffset,
1701                               int8_t destOffset, int32_t destPre,
1702                               int32_t destPost, int32_t destScale) {
1703   int32_t reluVal = MAX(src[idx], srcOffset);
1704   int32_t scaledVal = libjit_scale_i32i8(reluVal - srcOffset, destPre, destPost,
1705                                          destScale, destOffset);
1706   return libjit_clip(scaledVal);
1707 }
1708 
libjit_element_clip_f(dim_t idx,const float * src,float min,float max)1709 float libjit_element_clip_f(dim_t idx, const float *src, float min, float max) {
1710   float srcVal = src[idx];
1711   return MIN(MAX(srcVal, min), max);
1712 }
1713 
libjit_element_clip_i8(dim_t idx,const int8_t * src,int8_t clipMin,int8_t clipMax,int8_t srcOffset,int8_t destOffset,int32_t destPre,int32_t destPost,int32_t destScale)1714 int8_t libjit_element_clip_i8(dim_t idx, const int8_t *src, int8_t clipMin,
1715                               int8_t clipMax, int8_t srcOffset,
1716                               int8_t destOffset, int32_t destPre,
1717                               int32_t destPost, int32_t destScale) {
1718   int32_t clipVal = MIN(MAX(src[idx], clipMin), clipMax);
1719   int32_t scaledVal = libjit_scale_i32i8(clipVal - srcOffset, destPre, destPost,
1720                                          destScale, destOffset);
1721   return libjit_clip(scaledVal);
1722 }
1723 
// When the LIBJIT compile option "-ffast-math" is enabled the intermediate
// computation expf(x) for Sigmoid operator is not handled properly for very
// large positive values which results in NaN values for the Sigmoid output.
// Therefore when the "-ffast-math" is enabled we compute the Sigmoid such that
// we avoid computing large values for the "expf" function.
#ifdef FFAST_MATH
DEFINE_DATA_PARALLEL_KERNEL_FUNC(libjit_sigmoid_kernel_f) {
  // Evaluate on |x| so expf always sees a non-positive argument:
  // sigmoidVal = sigmoid(|x|) is in [0.5, 1). For x >= 0 the result is
  // sigmoidVal itself; for x < 0, signbit contributes 1 and copysignf flips
  // the sign, yielding 1 - sigmoid(|x|) == sigmoid(x).
  float inpVal = LHS[idx];
  float sigmoidVal = 1 / (1 + expf(-std::abs(inpVal)));
  return (float)(std::signbit(inpVal)) + std::copysignf(sigmoidVal, inpVal);
}
#else
DEFINE_DATA_PARALLEL_KERNEL_FUNC(libjit_sigmoid_kernel_f) {
  // Standard logistic function: 1 / (1 + e^-x).
  float e = expf(-LHS[idx]);
  return 1 / (e + 1);
}
#endif // FFAST_MATH
1741 
/// Splat kernels: fill every element with the immediate operand 'val'.
DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_f, float, val)
DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_u, int64_t,
                                             val)
DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_i8, int8_t,
                                             val)
DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_i32, int32_t,
                                             val)
DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_b, int8_t, val)

// The kernel-defining macros are no longer needed past this point.
// (A duplicated '#undef DEFINE_DATA_PARALLEL_KERNEL_FUNC' was removed here;
// it was a harmless copy-paste error.)
#undef DEFINE_DATA_PARALLEL_KERNEL
#undef DEFINE_DATA_PARALLEL_KERNEL_FUNC
#undef DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND
1755 
1756 void libjit_batchedadd_f(float *dest, const float *batch, const float *slice,
1757                          dim_t numSlice, dim_t sliceSize) {
1758   // For each layer in the batch:
1759   for (dim_t n = 0; n < numSlice; n++) {
1760     dim_t base = n * sliceSize;
1761     // For each element in the slice.
1762     for (dim_t i = 0; i < sliceSize; i++) {
1763       dest[base + i] = batch[base + i] + slice[i];
1764     }
1765   }
1766 }
1767 
libjit_batchedadd_i8(int8_t * dest,const int8_t * batch,const int8_t * slice,dim_t numSlice,dim_t sliceSize,int32_t destOffset,int32_t batchOffset,int32_t sliceOffset,int32_t batchPre,int32_t batchPost,int32_t batchScale,int32_t slicePre,int32_t slicePost,int32_t sliceScale)1768 void libjit_batchedadd_i8(int8_t *dest, const int8_t *batch,
1769                           const int8_t *slice, dim_t numSlice, dim_t sliceSize,
1770                           int32_t destOffset, int32_t batchOffset,
1771                           int32_t sliceOffset, int32_t batchPre,
1772                           int32_t batchPost, int32_t batchScale,
1773                           int32_t slicePre, int32_t slicePost,
1774                           int32_t sliceScale) {
1775   libjit_batchedadd_quantized(dest, batch, slice, numSlice, sliceSize,
1776                               destOffset, batchOffset, sliceOffset, batchPre,
1777                               batchPost, batchScale, slicePre, slicePost,
1778                               sliceScale);
1779 }
1780 
libjit_batchedadd_i32_i8(int8_t * dest,const int8_t * batch,const int32_t * slice,dim_t numSlice,dim_t sliceSize,int32_t destOffset,int32_t batchOffset,int32_t sliceOffset,int32_t batchPre,int32_t batchPost,int32_t batchScale,int32_t slicePre,int32_t slicePost,int32_t sliceScale)1781 void libjit_batchedadd_i32_i8(int8_t *dest, const int8_t *batch,
1782                               const int32_t *slice, dim_t numSlice,
1783                               dim_t sliceSize, int32_t destOffset,
1784                               int32_t batchOffset, int32_t sliceOffset,
1785                               int32_t batchPre, int32_t batchPost,
1786                               int32_t batchScale, int32_t slicePre,
1787                               int32_t slicePost, int32_t sliceScale) {
1788   libjit_batchedadd_quantized(dest, batch, slice, numSlice, sliceSize,
1789                               destOffset, batchOffset, sliceOffset, batchPre,
1790                               batchPost, batchScale, slicePre, slicePost,
1791                               sliceScale);
1792 }
1793 
1794 /// The dimensions passed in here are pre-expanded in LLVMIRGen with 1s so that
1795 /// we can iterate over the shape here, regardless of the shape of the tensor.
libjit_batchedreduceadd_f(float * dest,const float * batch,dim_t destSize,const dim_t * destDims,const dim_t * batchDims,dim_t axis)1796 void libjit_batchedreduceadd_f(float *dest, const float *batch, dim_t destSize,
1797                                const dim_t *destDims, const dim_t *batchDims,
1798                                dim_t axis) {
1799   for (dim_t i = 0; i < destSize; i++)
1800     dest[i] = 0.0;
1801 
1802   for (dim_t x = 0; x < batchDims[0]; x++)
1803     for (dim_t y = 0; y < batchDims[1]; y++)
1804       for (dim_t z = 0; z < batchDims[2]; z++)
1805         for (dim_t w = 0; w < batchDims[3]; w++)
1806           for (dim_t q = 0; q < batchDims[4]; q++)
1807             for (dim_t r = 0; r < batchDims[5]; r++) {
1808               dim_t I[] = {x, y, z, w, q, r};
1809               I[axis] = 0;
1810               dest[libjit_getXYZWQR(destDims, I[0], I[1], I[2], I[3], I[4],
1811                                     I[5])] +=
1812                   batch[libjit_getXYZWQR(batchDims, x, y, z, w, q, r)];
1813             }
1814 }
1815 
libjit_reducemin_f(float * dest,const float * batch,size_t destSize,const dim_t * destDims,const dim_t * batchDims)1816 void libjit_reducemin_f(float *dest, const float *batch, size_t destSize,
1817                         const dim_t *destDims, const dim_t *batchDims) {
1818   libjit_reducemin(dest, batch, destSize, destDims, batchDims,
1819                    std::numeric_limits<float>::max());
1820 }
1821 
libjit_reducemin_i32(int32_t * dest,const int32_t * batch,size_t destSize,const dim_t * destDims,const dim_t * batchDims)1822 void libjit_reducemin_i32(int32_t *dest, const int32_t *batch, size_t destSize,
1823                           const dim_t *destDims, const dim_t *batchDims) {
1824   libjit_reducemin(dest, batch, destSize, destDims, batchDims,
1825                    std::numeric_limits<int32_t>::max());
1826 }
1827 
libjit_reducemin_u(int64_t * dest,const int64_t * batch,size_t destSize,const dim_t * destDims,const dim_t * batchDims)1828 void libjit_reducemin_u(int64_t *dest, const int64_t *batch, size_t destSize,
1829                         const dim_t *destDims, const dim_t *batchDims) {
1830   libjit_reducemin(dest, batch, destSize, destDims, batchDims,
1831                    std::numeric_limits<int64_t>::max());
1832 }
1833 
1834 /// Same as the non-quantized version, the dimensions here are pre-expanded in
1835 /// LLVMIRGen. However, for quantization, we must accumulate in the inner-most
1836 /// loop with higher precision (int32_t) and then clip the result back into the
1837 /// dest tensor. Thus we add max_tensor_dimensions different cases for this to
1838 /// ensure the axis is used as the inner-most loop.
libjit_batchedreduceadd_i8(int8_t * dest,const int8_t * batch,const dim_t * destDims,const dim_t * batchDims,int32_t destOffset,int32_t batchOffset,int32_t batchPre,int32_t batchPost,int32_t batchScale,dim_t axis)1839 void libjit_batchedreduceadd_i8(int8_t *dest, const int8_t *batch,
1840                                 const dim_t *destDims, const dim_t *batchDims,
1841                                 int32_t destOffset, int32_t batchOffset,
1842                                 int32_t batchPre, int32_t batchPost,
1843                                 int32_t batchScale, dim_t axis) {
1844   switch (axis) {
1845 #define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5_AXIS)                      \
1846   case _D5_AXIS:                                                               \
1847     for (dim_t i##_D0 = 0; i##_D0 < batchDims[_D0]; i##_D0++)                  \
1848       for (dim_t i##_D1 = 0; i##_D1 < batchDims[_D1]; i##_D1++)                \
1849         for (dim_t i##_D2 = 0; i##_D2 < batchDims[_D2]; i##_D2++)              \
1850           for (dim_t i##_D3 = 0; i##_D3 < batchDims[_D3]; i##_D3++)            \
1851             for (dim_t i##_D4 = 0; i##_D4 < batchDims[_D4]; i##_D4++) {        \
1852               int32_t sum = 0.0;                                               \
1853               for (dim_t i##_D5_AXIS = 0; i##_D5_AXIS < batchDims[_D5_AXIS];   \
1854                    i##_D5_AXIS++) {                                            \
1855                 sum += batch[libjit_getXYZWQR(batchDims, i0, i1, i2, i3, i4,   \
1856                                               i5)] -                           \
1857                        batchOffset;                                            \
1858               }                                                                \
1859               dim_t i##_D5_AXIS = 0;                                           \
1860               int32_t res = libjit_scale_i32i8(sum, batchPre, batchPost,       \
1861                                                batchScale, destOffset);        \
1862               dest[libjit_getXYZWQR(destDims, i0, i1, i2, i3, i4, i5)] =       \
1863                   libjit_clip(res);                                            \
1864             }                                                                  \
1865     return;
1866 
1867     // Each loop order, with the inner-most dimension/index equal to the axis.
1868     LOOP_AXIS_CASE(1, 2, 3, 4, 5, 0);
1869     LOOP_AXIS_CASE(0, 2, 3, 4, 5, 1);
1870     LOOP_AXIS_CASE(0, 1, 3, 4, 5, 2);
1871     LOOP_AXIS_CASE(0, 1, 2, 4, 5, 3);
1872     LOOP_AXIS_CASE(0, 1, 2, 3, 5, 4);
1873     LOOP_AXIS_CASE(0, 1, 2, 3, 4, 5);
1874 #undef LOOP_AXIS_CASE
1875   }
1876 }
1877 
libjit_cross_entropy_loss_f_u(float * CE,float * P,size_t * labels,dim_t * dims)1878 void libjit_cross_entropy_loss_f_u(float *CE, float *P, size_t *labels,
1879                                    dim_t *dims) {
1880   libjit_cross_entropy_loss_generic(CE, P, labels, dims);
1881 }
1882 
libjit_cross_entropy_loss_f_i32(float * CE,float * P,int32_t * labels,dim_t * dims)1883 void libjit_cross_entropy_loss_f_i32(float *CE, float *P, int32_t *labels,
1884                                      dim_t *dims) {
1885   libjit_cross_entropy_loss_generic(CE, P, labels, dims);
1886 }
1887 
libjit_gather64_f(float * dest,const float * data,const int64_t * indices,dim_t numIndices,dim_t sliceSize,dim_t numSamples,dim_t sampleSize)1888 void libjit_gather64_f(float *dest, const float *data, const int64_t *indices,
1889                        dim_t numIndices, dim_t sliceSize, dim_t numSamples,
1890                        dim_t sampleSize) {
1891   libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1892                 sampleSize);
1893 }
1894 
libjit_gather64_i8(int8_t * dest,const int8_t * data,const int64_t * indices,dim_t numIndices,dim_t sliceSize,dim_t numSamples,dim_t sampleSize)1895 void libjit_gather64_i8(int8_t *dest, const int8_t *data,
1896                         const int64_t *indices, dim_t numIndices,
1897                         dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
1898   libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1899                 sampleSize);
1900 }
1901 
libjit_gather64_u(int64_t * dest,const int64_t * data,const int64_t * indices,dim_t numIndices,dim_t sliceSize,dim_t numSamples,dim_t sampleSize)1902 void libjit_gather64_u(int64_t *dest, const int64_t *data,
1903                        const int64_t *indices, dim_t numIndices,
1904                        dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
1905   libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1906                 sampleSize);
1907 }
1908 
libjit_gather32_f(float * dest,const float * data,const int32_t * indices,dim_t numIndices,dim_t sliceSize,dim_t numSamples,dim_t sampleSize)1909 void libjit_gather32_f(float *dest, const float *data, const int32_t *indices,
1910                        dim_t numIndices, dim_t sliceSize, dim_t numSamples,
1911                        dim_t sampleSize) {
1912   libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1913                 sampleSize);
1914 }
1915 
libjit_gather32_i8(int8_t * dest,const int8_t * data,const int32_t * indices,dim_t numIndices,dim_t sliceSize,dim_t numSamples,dim_t sampleSize)1916 void libjit_gather32_i8(int8_t *dest, const int8_t *data,
1917                         const int32_t *indices, dim_t numIndices,
1918                         dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
1919   libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1920                 sampleSize);
1921 }
1922 
libjit_gather32_u(int64_t * dest,const int64_t * data,const int32_t * indices,dim_t numIndices,dim_t sliceSize,dim_t numSamples,dim_t sampleSize)1923 void libjit_gather32_u(int64_t *dest, const int64_t *data,
1924                        const int32_t *indices, dim_t numIndices,
1925                        dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
1926   libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1927                 sampleSize);
1928 }
1929 
libjit_gather32_i32(int32_t * dest,const int32_t * data,const int32_t * indices,dim_t numIndices,dim_t sliceSize,dim_t numSamples,dim_t sampleSize)1930 void libjit_gather32_i32(int32_t *dest, const int32_t *data,
1931                          const int32_t *indices, dim_t numIndices,
1932                          dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
1933   libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1934                 sampleSize);
1935 }
1936 
libjit_gatherranges64_f(float * output,int64_t * lengths,const float * data,const int64_t * ranges,dim_t numExamples,dim_t exampleSize)1937 void libjit_gatherranges64_f(float *output, int64_t *lengths, const float *data,
1938                              const int64_t *ranges, dim_t numExamples,
1939                              dim_t exampleSize) {
1940   libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1941 }
1942 
libjit_gatherranges64_i8(int8_t * output,int64_t * lengths,const int8_t * data,const int64_t * ranges,dim_t numExamples,dim_t exampleSize)1943 void libjit_gatherranges64_i8(int8_t *output, int64_t *lengths,
1944                               const int8_t *data, const int64_t *ranges,
1945                               dim_t numExamples, dim_t exampleSize) {
1946   libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1947 }
1948 
libjit_gatherranges64_u(int64_t * output,int64_t * lengths,const int64_t * data,const int64_t * ranges,dim_t numExamples,dim_t exampleSize)1949 void libjit_gatherranges64_u(int64_t *output, int64_t *lengths,
1950                              const int64_t *data, const int64_t *ranges,
1951                              dim_t numExamples, dim_t exampleSize) {
1952   libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1953 }
1954 
libjit_gatherranges32_f(float * output,int32_t * lengths,const float * data,const int32_t * ranges,dim_t numExamples,dim_t exampleSize)1955 void libjit_gatherranges32_f(float *output, int32_t *lengths, const float *data,
1956                              const int32_t *ranges, dim_t numExamples,
1957                              dim_t exampleSize) {
1958   libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1959 }
1960 
libjit_gatherranges32_i8(int8_t * output,int32_t * lengths,const int8_t * data,const int32_t * ranges,dim_t numExamples,dim_t exampleSize)1961 void libjit_gatherranges32_i8(int8_t *output, int32_t *lengths,
1962                               const int8_t *data, const int32_t *ranges,
1963                               dim_t numExamples, dim_t exampleSize) {
1964   libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1965 }
1966 
libjit_gatherranges32_u(uint64_t * output,int32_t * lengths,const uint64_t * data,const int32_t * ranges,dim_t numExamples,dim_t exampleSize)1967 void libjit_gatherranges32_u(uint64_t *output, int32_t *lengths,
1968                              const uint64_t *data, const int32_t *ranges,
1969                              dim_t numExamples, dim_t exampleSize) {
1970   libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1971 }
1972 
libjit_gatherranges32_i32(int32_t * output,int32_t * lengths,const int32_t * data,const int32_t * ranges,dim_t numExamples,dim_t exampleSize)1973 void libjit_gatherranges32_i32(int32_t *output, int32_t *lengths,
1974                                const int32_t *data, const int32_t *ranges,
1975                                dim_t numExamples, dim_t exampleSize) {
1976   libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1977 }
1978 
libjit_lengths_range_fill_i32(const int32_t * lengths,int32_t * output,const dim_t lengthsSize)1979 void libjit_lengths_range_fill_i32(const int32_t *lengths, int32_t *output,
1980                                    const dim_t lengthsSize) {
1981   dim_t curIdx = 0;
1982   for (dim_t i = 0, e = lengthsSize; i < e; i++) {
1983     for (int32_t j = 0, f = lengths[i]; j < f; j++) {
1984       output[curIdx++] = j;
1985     }
1986   }
1987 }
1988 
libjit_scatterdata_f_i32(float * data,const dim_t * dataDims,const int32_t * indices,const float * slices,dim_t numIndices,dim_t indexSize,dim_t sliceSize,bool isCumulative)1989 void libjit_scatterdata_f_i32(float *data, const dim_t *dataDims,
1990                               const int32_t *indices, const float *slices,
1991                               dim_t numIndices, dim_t indexSize,
1992                               dim_t sliceSize, bool isCumulative) {
1993   if (isCumulative) {
1994     libjit_scatterdataaddfloat(data, dataDims, indices, slices, numIndices,
1995                                indexSize, sliceSize);
1996   } else {
1997     libjit_scatterdatacopy(data, dataDims, indices, slices, numIndices,
1998                            indexSize, sliceSize);
1999   }
2000 }
2001 
libjit_scatterdata_i8_u(int8_t * data,const dim_t * dataDims,const int64_t * indices,const int8_t * slices,dim_t numIndices,dim_t indexSize,dim_t sliceSize,bool isCumulative,float dataScale,int32_t dataOffset,float sliceScale,int32_t sliceOffset)2002 void libjit_scatterdata_i8_u(int8_t *data, const dim_t *dataDims,
2003                              const int64_t *indices, const int8_t *slices,
2004                              dim_t numIndices, dim_t indexSize, dim_t sliceSize,
2005                              bool isCumulative, float dataScale,
2006                              int32_t dataOffset, float sliceScale,
2007                              int32_t sliceOffset) {
2008   if (isCumulative) {
2009     libjit_scatterdataaddquantized(data, dataDims, indices, slices, numIndices,
2010                                    indexSize, sliceSize, dataScale, dataOffset,
2011                                    sliceScale, sliceOffset);
2012   } else {
2013     libjit_scatterdatacopy(data, dataDims, indices, slices, numIndices,
2014                            indexSize, sliceSize);
2015   }
2016 }
2017 
libjit_scatterdata_i8_i32(int8_t * data,const dim_t * dataDims,const int32_t * indices,const int8_t * slices,dim_t numIndices,dim_t indexSize,dim_t sliceSize,bool isCumulative,float dataScale,int32_t dataOffset,float sliceScale,int32_t sliceOffset)2018 void libjit_scatterdata_i8_i32(int8_t *data, const dim_t *dataDims,
2019                                const int32_t *indices, const int8_t *slices,
2020                                dim_t numIndices, dim_t indexSize,
2021                                dim_t sliceSize, bool isCumulative,
2022                                float dataScale, int32_t dataOffset,
2023                                float sliceScale, int32_t sliceOffset) {
2024   if (isCumulative) {
2025     libjit_scatterdataaddquantized(data, dataDims, indices, slices, numIndices,
2026                                    indexSize, sliceSize, dataScale, dataOffset,
2027                                    sliceScale, sliceOffset);
2028   } else {
2029     libjit_scatterdatacopy(data, dataDims, indices, slices, numIndices,
2030                            indexSize, sliceSize);
2031   }
2032 }
2033 
libjit_lengths_to_ranges_i32(int32_t * ranges,const int32_t * lengths,dim_t size)2034 void libjit_lengths_to_ranges_i32(int32_t *ranges, const int32_t *lengths,
2035                                   dim_t size) {
2036   int32_t offset = 0;
2037   for (dim_t i = 0; i < size; i++) {
2038     auto length = lengths[i];
2039     ranges[i * 2] = offset;
2040     ranges[i * 2 + 1] = length;
2041     offset += length;
2042   }
2043 }
2044 
/// SparseLengthsSum / SparseLengthsWeightedSum instantiations. All forward
/// to the shared generic implementations; the suffix encodes the index type
/// (u = size_t, i32 = int32_t).
void libjit_sparse_lengths_sum_f_u(float *dest, float *data, size_t *indices,
                                   int32_t *lengths, dim_t segments,
                                   dim_t lineSize) {
  libjit_sparse_lengths_sum_generic(dest, data, indices, lengths, segments,
                                    lineSize);
}

void libjit_sparse_lengths_sum_f_i32(float *dest, float *data, int32_t *indices,
                                     int32_t *lengths, dim_t segments,
                                     dim_t lineSize) {
  libjit_sparse_lengths_sum_generic(dest, data, indices, lengths, segments,
                                    lineSize);
}

void libjit_sparse_lengths_weighted_sum_f_u(float *dest, float *data,
                                            float *weights, size_t *indices,
                                            int32_t *lengths, dim_t segments,
                                            dim_t lineSize) {
  libjit_sparse_lengths_weighted_sum_generic(dest, data, weights, indices,
                                             lengths, segments, lineSize);
}

void libjit_sparse_lengths_weighted_sum_f_i32(float *dest, float *data,
                                              float *weights, int32_t *indices,
                                              int32_t *lengths, dim_t segments,
                                              dim_t lineSize) {
  libjit_sparse_lengths_weighted_sum_generic(dest, data, weights, indices,
                                             lengths, segments, lineSize);
}
2074 
libjit_embedding_bag_f(float * dest,float * data,float * weights,size_t * indices,size_t * offsets,dim_t segments,dim_t lineSize,dim_t totalLength,bool hasEndOffset)2075 void libjit_embedding_bag_f(float *dest, float *data, float *weights,
2076                             size_t *indices, size_t *offsets, dim_t segments,
2077                             dim_t lineSize, dim_t totalLength,
2078                             bool hasEndOffset) {
2079   if (hasEndOffset) {
2080     --segments;
2081   }
2082   memset(dest, 0, segments * lineSize * sizeof(float));
2083   dim_t curIndex = 0;
2084   for (dim_t i = 0; i < segments; i++) {
2085     int64_t start = offsets[i];
2086     int64_t end =
2087         !hasEndOffset && i == segments - 1 ? totalLength : offsets[i + 1];
2088     for (int64_t j = start; j < end; j++) {
2089       float weight = weights[curIndex];
2090       dim_t line = indices[curIndex];
2091       for (dim_t k = 0; k < lineSize; k++) {
2092         dest[i * lineSize + k] += weight * data[line * lineSize + k];
2093       }
2094       curIndex++;
2095     }
2096   }
2097 }
2098 
/// SparseLengthsWeightedSum gradient instantiations; the suffix encodes the
/// index type (u = size_t, i32 = int32_t). Both forward to the shared
/// generic implementation.
void libjit_sparse_lengths_weighted_sum_grad_f_u(
    const float *destGrad, float *dataGrad, float *weightsGrad,
    const float *data, const float *weights, const size_t *indices,
    const int32_t *lengths, dim_t segments, dim_t lineSize,
    dim_t dataGradRawSize) {
  libjit_sparse_lengths_weighted_sum_grad_generic(
      destGrad, dataGrad, weightsGrad, data, weights, indices, lengths,
      segments, lineSize, dataGradRawSize);
}

void libjit_sparse_lengths_weighted_sum_grad_f_i32(
    const float *destGrad, float *dataGrad, float *weightsGrad,
    const float *data, const float *weights, const int32_t *indices,
    const int32_t *lengths, dim_t segments, dim_t lineSize,
    dim_t dataGradRawSize) {
  libjit_sparse_lengths_weighted_sum_grad_generic(
      destGrad, dataGrad, weightsGrad, data, weights, indices, lengths,
      segments, lineSize, dataGradRawSize);
}
2118 
/// Row-wise quantized SparseLengthsWeightedSum instantiations. The "rowwise"
/// variants take separate per-row \p scales and \p offsets arrays; the
/// "fused" variants store scale/offset inline at the end of each data row.
/// The suffix encodes the index type (u = size_t, i32 = int32_t).
void libjit_rowwise_quantized_sparse_lengths_weighted_sum_f_u(
    float *dest, uint8_t *data, float *scales, float *offsets, float *weights,
    size_t *indices, int32_t *lengths, dim_t segments, dim_t lineSize) {
  libjit_rowwise_quantized_sparse_lengths_weighted_sum_generic(
      dest, data, scales, offsets, weights, indices, lengths, segments,
      lineSize);
}

void libjit_rowwise_quantized_sparse_lengths_weighted_sum_f_i32(
    float *dest, uint8_t *data, float *scales, float *offsets, float *weights,
    int32_t *indices, int32_t *lengths, dim_t segments, dim_t lineSize) {
  libjit_rowwise_quantized_sparse_lengths_weighted_sum_generic(
      dest, data, scales, offsets, weights, indices, lengths, segments,
      lineSize);
}

void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_f_u(
    float *dest, int8_t *data, float *weights, size_t *indices,
    int32_t *lengths, dim_t segments, dim_t inLineSize, dim_t outLineSize) {
  libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_generic(
      dest, data, weights, indices, lengths, segments, inLineSize, outLineSize);
}

void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_f_i32(
    float *dest, int8_t *data, float *weights, int32_t *indices,
    int32_t *lengths, dim_t segments, dim_t inLineSize, dim_t outLineSize) {
  libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_generic(
      dest, data, weights, indices, lengths, segments, inLineSize, outLineSize);
}
2148 
/// Fused row-wise quantized SparseLengthsWeightedSum (dim_t indices).
/// Each row of \p data is \p inLineSize bytes: the quantized uint8 values,
/// with a trailing (scale, offset) pair stored as two floats in the last
/// 2*sizeof(float) bytes of the row. For every segment i the kernel sums
/// weights[..] * dequantize(row) over the segment's lengths[i] indices.
void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_f(
    float *dest, int8_t *data, float *weights, dim_t *indices, int32_t *lengths,
    dim_t segments, dim_t inLineSize, dim_t outLineSize) {
  // One output row of outLineSize floats per segment; start from zero.
  memset(dest, 0, segments * outLineSize * sizeof(float));
  dim_t curIndex = 0;
  for (dim_t i = 0; i < segments; i++) {
    for (int32_t j = 0, e = lengths[i]; j < e; j++) {
      const float weight = weights[curIndex];
      const dim_t line = indices[curIndex];
      // Locate the scale/offset trailer at the end of the selected row.
      const int8_t *currRowScaleOffsetPtr =
          data + ((line + 1) * inLineSize) - 2 * sizeof(float);
      // memcpy reads the possibly-unaligned floats without violating
      // strict aliasing.
      float scale, offset;
      memcpy(&scale, currRowScaleOffsetPtr, sizeof(float));
      memcpy(&offset, currRowScaleOffsetPtr + sizeof(float), sizeof(float));
      for (dim_t k = 0; k < outLineSize; k++) {
        // Dequantize: the stored bytes are interpreted as unsigned.
        const float fData =
            (scale * (uint8_t)(data[line * inLineSize + k])) + offset;
        dest[i * outLineSize + k] += weight * fData;
      }
      curIndex++;
    }
  }
}
2172 
/// EmbeddingBagByteRowwiseOffsets: like the fused row-wise quantized
/// weighted sum above, but segments are described by \p offsets (start
/// positions into \p indices / \p weights) instead of lengths. Each data
/// row is \p inLineSize bytes ending with a (scale, offset) float pair.
/// When \p hasEndOffset is true the offsets array has one extra trailing
/// end marker; otherwise the last segment ends at \p numIndices.
void libjit_embedding_bag_byte_rowwise_offsets_f(
    float *dest, int8_t *data, float *weights, size_t *indices, size_t *offsets,
    dim_t segments, dim_t numIndices, dim_t inLineSize, dim_t outLineSize,
    bool hasEndOffset) {
  if (hasEndOffset) {
    // Drop the end-marker entry; it is not a segment.
    --segments;
  }
  memset(dest, 0, segments * outLineSize * sizeof(float));
  for (dim_t i = 0; i < segments; i++) {
    dim_t start = offsets[i];
    dim_t end =
        !hasEndOffset && i == segments - 1 ? numIndices : offsets[i + 1];
    for (dim_t j = start; j < end; j++) {
      const float weight = weights[j];
      const dim_t line = indices[j];
      // Scale/offset trailer occupies the last 2*sizeof(float) bytes of
      // the row; memcpy avoids unaligned/aliasing issues.
      const int8_t *currRowScaleOffsetPtr =
          data + ((line + 1) * inLineSize) - 2 * sizeof(float);
      float scale, offset;
      memcpy(&scale, currRowScaleOffsetPtr, sizeof(float));
      memcpy(&offset, currRowScaleOffsetPtr + sizeof(float), sizeof(float));
      for (dim_t k = 0; k < outLineSize; k++) {
        // Dequantize the unsigned byte and accumulate the weighted value.
        const float fData =
            (scale * (uint8_t)(data[line * inLineSize + k])) + offset;
        dest[i * outLineSize + k] += weight * fData;
      }
    }
  }
}
2201 
/// SparseToDense instantiations; suffix encodes the index type
/// (u = size_t, i32 = int32_t). Both forward to the generic implementation.
void libjit_sparse_to_dense_f_u(float *dest, const size_t *indices,
                                const float *values, dim_t numIndices,
                                dim_t destSize, dim_t valueSize) {
  libjit_sparse_to_dense_generic(dest, indices, values, numIndices, destSize,
                                 valueSize);
}

void libjit_sparse_to_dense_f_i32(float *dest, const int32_t *indices,
                                  const float *values, dim_t numIndices,
                                  dim_t destSize, dim_t valueSize) {
  libjit_sparse_to_dense_generic(dest, indices, values, numIndices, destSize,
                                 valueSize);
}
2215 
libjit_lengths_sum_f(float * dest,const float * data,const int32_t * lengths,dim_t destSize,dim_t lengthsSize,dim_t sliceSize)2216 void libjit_lengths_sum_f(float *dest, const float *data,
2217                           const int32_t *lengths, dim_t destSize,
2218                           dim_t lengthsSize, dim_t sliceSize) {
2219   memset(dest, 0, destSize * sizeof(float));
2220 
2221   dim_t offsetOut = 0;
2222   dim_t offsetIn = 0;
2223 
2224   for (dim_t i = 0; i < lengthsSize; ++i) {
2225     for (int32_t j = 0; j < lengths[i]; ++j) {
2226       for (dim_t k = 0; k < sliceSize; ++k) {
2227         dest[offsetOut + k] += data[offsetIn + k];
2228       }
2229       offsetIn += sliceSize;
2230     }
2231     offsetOut += sliceSize;
2232   }
2233 }
2234 
/// Local Response Normalization across channels for NHWC float tensors.
/// For each (n, h, w, c): sum the squares of the inputs in the channel
/// window [c - halfWindow, c + halfWindow] (clamped to valid channels),
/// form scale = k + (alpha / windowSize) * sumOfSquares, and write
/// out = in * scale^(-beta). The per-element scale is saved in
/// \p scaleCache for reuse by the gradient kernel.
void libjit_local_response_normalization_f(
    float *outW, const float *inW, float *scaleCache, const dim_t *outWdims,
    const dim_t *inWdims, dim_t halfWindow, float alpha, float beta, float k) {
  dim_t window = 2 * halfWindow + 1;
  // Alpha is normalized by the full window size.
  float normedAlpha = alpha / window;

  for (dim_t n = 0; n < inWdims[0]; n++) {
    for (dim_t h = 0; h < inWdims[1]; h++) {
      for (dim_t w = 0; w < inWdims[2]; w++) {
        for (dim_t c = 0; c < inWdims[3]; c++) {
          // Sum of squares over the channel window, clamped to [0, C-1].
          float m2 = 0.0;
          for (dim_t i = (c >= halfWindow ? c - halfWindow : 0);
               i <= MIN(c + halfWindow, inWdims[3] - 1); i++) {
            float val = inW[libjit_getXYZW(inWdims, n, h, w, i)];
            m2 += val * val;
          }

          float scale = k + normedAlpha * m2;
          // Cache the scale so the backward pass can avoid recomputing it.
          scaleCache[libjit_getXYZW(inWdims, n, h, w, c)] = scale;
          float normFactor = pow(scale, -beta);
          outW[libjit_getXYZW(outWdims, n, h, w, c)] =
              inW[libjit_getXYZW(inWdims, n, h, w, c)] * normFactor;
        } // C
      }   // W
    }     // H
  }       // N
}
2262 
/// Gradient of LRN. Maintains, per (n, h, w), an incrementally-updated
/// sliding-window sum of outG * (outW / scale) over the channel window
/// centered at c. For each channel:
///   inG = outG * scale^(-beta) - (2 * normedAlpha * beta) * inW * sum.
/// \p scaleCache holds the per-element scales saved by the forward kernel.
void libjit_local_response_normalization_grad_f(
    float *inG, const float *outG, const float *inW, const float *outW,
    const float *scaleCache, const dim_t *outWdims, dim_t halfWindow,
    float alpha, float beta) {
  dim_t window = 2 * halfWindow + 1;
  float normedAlpha = alpha / window;
  float coeff = 2 * normedAlpha * beta;

  for (dim_t n = 0; n < outWdims[0]; n++) {
    for (dim_t h = 0; h < outWdims[1]; h++) {
      for (dim_t w = 0; w < outWdims[2]; w++) {
        // Prepare right half of sliding window based at c = 0
        float sum = 0.0;
        for (dim_t i = 0; i < MIN(halfWindow, outWdims[3]); i++) {
          float outg = outG[libjit_getXYZW(outWdims, n, h, w, i)];
          float outw = outW[libjit_getXYZW(outWdims, n, h, w, i)];
          float scale = scaleCache[libjit_getXYZW(outWdims, n, h, w, i)];
          sum += outg * (outw / scale);
        }

        for (dim_t c = 0; c < outWdims[3]; c++) {
          // Channel c - halfWindow - 1 just left the window: subtract it.
          if (c > halfWindow) {
            dim_t j = c - halfWindow - 1;
            float outg = outG[libjit_getXYZW(outWdims, n, h, w, j)];
            float outw = outW[libjit_getXYZW(outWdims, n, h, w, j)];
            float scale = scaleCache[libjit_getXYZW(outWdims, n, h, w, j)];
            sum -= outg * (outw / scale);
          }

          // Channel c + halfWindow just entered the window: add it.
          dim_t j = c + halfWindow;
          if (j < outWdims[3]) {
            float outg = outG[libjit_getXYZW(outWdims, n, h, w, j)];
            float outw = outW[libjit_getXYZW(outWdims, n, h, w, j)];
            float scale = scaleCache[libjit_getXYZW(outWdims, n, h, w, j)];
            sum += outg * (outw / scale);
          }

          float outg = outG[libjit_getXYZW(outWdims, n, h, w, c)];
          float inw = inW[libjit_getXYZW(outWdims, n, h, w, c)];
          float scale = scaleCache[libjit_getXYZW(outWdims, n, h, w, c)];
          inG[libjit_getXYZW(outWdims, n, h, w, c)] =
              outg * pow(scale, -beta) - coeff * inw * sum;
        }
      } // W
    }   // H
  }     // N
}
2310 
/// MaxPool (and MaxPool-with-argmax) instantiations. All forward to the
/// shared generic implementations; the suffixes encode the element type
/// (i8/f) and, for the argmax variants, the index type (u = int64_t,
/// i32 = int32_t).
void libjit_max_pool_i8(const int8_t *inW, int8_t *outW, const dim_t *inWdims,
                        const dim_t *outWdims, dim_t *kernelSizes,
                        dim_t *strides, dim_t *pads) {
  libjit_max_pool_generic(inW, outW, inWdims, outWdims, kernelSizes, strides,
                          pads);
}

void libjit_max_pool_f(const float *inW, float *outW, const dim_t *inWdims,
                       const dim_t *outWdims, dim_t *kernelSizes,
                       dim_t *strides, dim_t *pads) {
  libjit_max_pool_generic(inW, outW, inWdims, outWdims, kernelSizes, strides,
                          pads);
}

void libjit_max_pool_argmax_i8_u(const int8_t *inW, int8_t *outW,
                                 int64_t *argmax, const dim_t *inWdims,
                                 const dim_t *outWdims, dim_t *kernels,
                                 dim_t *strides, dim_t *pads) {
  libjit_max_pool_argmax_generic(inW, outW, argmax, inWdims, outWdims, kernels,
                                 strides, pads);
}

void libjit_max_pool_argmax_f_u(const float *inW, float *outW, int64_t *argmax,
                                const dim_t *inWdims, const dim_t *outWdims,
                                dim_t *kernels, dim_t *strides, dim_t *pads) {
  libjit_max_pool_argmax_generic(inW, outW, argmax, inWdims, outWdims, kernels,
                                 strides, pads);
}

void libjit_max_pool_argmax_i8_i32(const int8_t *inW, int8_t *outW,
                                   int32_t *argmax, const dim_t *inWdims,
                                   const dim_t *outWdims, dim_t *kernels,
                                   dim_t *strides, dim_t *pads) {
  libjit_max_pool_argmax_generic(inW, outW, argmax, inWdims, outWdims, kernels,
                                 strides, pads);
}

void libjit_max_pool_argmax_f_i32(const float *inW, float *outW,
                                  int32_t *argmax, const dim_t *inWdims,
                                  const dim_t *outWdims, dim_t *kernels,
                                  dim_t *strides, dim_t *pads) {
  libjit_max_pool_argmax_generic(inW, outW, argmax, inWdims, outWdims, kernels,
                                 strides, pads);
}
2355 
/// ArgMax / ArgMin instantiations over axis \p axis. All forward to the
/// generic implementations; suffixes encode the input element type (i8/f)
/// and the output index type (u = int64_t, i32 = int32_t).
void libjit_arg_max_i8_u(const int8_t *inW, int64_t *outW, const dim_t *inWdims,
                         size_t inWNumDims, size_t axis) {
  libjit_arg_max_generic(inW, outW, inWdims, inWNumDims, axis);
}

void libjit_arg_max_i8_i32(const int8_t *inW, int32_t *outW,
                           const dim_t *inWdims, size_t inWNumDims,
                           size_t axis) {
  libjit_arg_max_generic(inW, outW, inWdims, inWNumDims, axis);
}

void libjit_arg_max_f_u(const float *inW, int64_t *outW, const dim_t *inWdims,
                        size_t inWNumDims, size_t axis) {
  libjit_arg_max_generic(inW, outW, inWdims, inWNumDims, axis);
}

void libjit_arg_max_f_i32(const float *inW, int32_t *outW, const dim_t *inWdims,
                          size_t inWNumDims, size_t axis) {
  libjit_arg_max_generic(inW, outW, inWdims, inWNumDims, axis);
}

void libjit_arg_min_i8_u(const int8_t *inW, int64_t *outW, const dim_t *inWdims,
                         size_t inWNumDims, size_t axis) {
  libjit_arg_min_generic(inW, outW, inWdims, inWNumDims, axis);
}

void libjit_arg_min_i8_i32(const int8_t *inW, int32_t *outW,
                           const dim_t *inWdims, size_t inWNumDims,
                           size_t axis) {
  libjit_arg_min_generic(inW, outW, inWdims, inWNumDims, axis);
}

void libjit_arg_min_f_u(const float *inW, int64_t *outW, const dim_t *inWdims,
                        size_t inWNumDims, size_t axis) {
  libjit_arg_min_generic(inW, outW, inWdims, inWNumDims, axis);
}

void libjit_arg_min_f_i32(const float *inW, int32_t *outW, const dim_t *inWdims,
                          size_t inWNumDims, size_t axis) {
  libjit_arg_min_generic(inW, outW, inWdims, inWNumDims, axis);
}
2397 
/// MaxPool-with-argmax gradient instantiations; suffix encodes the argmax
/// index type (u = int64_t, i32 = int32_t). Both forward to the generic
/// implementation.
void libjit_max_pool_argmax_grad_f_u(float *inG, const float *outG,
                                     const int64_t *argmax,
                                     const dim_t *inGdims,
                                     const dim_t *outWdims) {
  libjit_max_pool_argmax_grad_generic(inG, outG, argmax, inGdims, outWdims);
}

void libjit_max_pool_argmax_grad_f_i32(float *inG, const float *outG,
                                       const int32_t *argmax,
                                       const dim_t *inGdims,
                                       const dim_t *outWdims) {
  libjit_max_pool_argmax_grad_generic(inG, outG, argmax, inGdims, outWdims);
}
2411 
/// ResizeNearest / ResizeBilinear instantiations for float, int8, int32 and
/// int64 element types. All forward to the corresponding generic
/// implementation with per-dimension \p scale factors.
void libjit_resizenearest_f(float *dst, const float *src, const float *scale,
                            const dim_t *inWdims, const dim_t *outWdims) {
  libjit_resizenearest_generic(dst, src, scale, inWdims, outWdims);
}

void libjit_resizenearest_i8(int8_t *dst, const int8_t *src, const float *scale,
                             const dim_t *inWdims, const dim_t *outWdims) {
  libjit_resizenearest_generic(dst, src, scale, inWdims, outWdims);
}

void libjit_resizenearest_i32(int32_t *dst, const int32_t *src,
                              const float *scale, const dim_t *inWdims,
                              const dim_t *outWdims) {
  libjit_resizenearest_generic(dst, src, scale, inWdims, outWdims);
}

void libjit_resizenearest_u(int64_t *dst, const int64_t *src,
                            const float *scale, const dim_t *inWdims,
                            const dim_t *outWdims) {
  libjit_resizenearest_generic(dst, src, scale, inWdims, outWdims);
}

void libjit_resizebilinear_f(float *dst, const float *src, const float *scale,
                             const dim_t *inWdims, const dim_t *outWdims) {
  libjit_resizebilinear_generic(dst, src, scale, inWdims, outWdims);
}

void libjit_resizebilinear_i8(int8_t *dst, const int8_t *src,
                              const float *scale, const dim_t *inWdims,
                              const dim_t *outWdims) {
  libjit_resizebilinear_generic(dst, src, scale, inWdims, outWdims);
}

void libjit_resizebilinear_i32(int32_t *dst, const int32_t *src,
                               const float *scale, const dim_t *inWdims,
                               const dim_t *outWdims) {
  libjit_resizebilinear_generic(dst, src, scale, inWdims, outWdims);
}

void libjit_resizebilinear_u(int64_t *dst, const int64_t *src,
                             const float *scale, const dim_t *inWdims,
                             const dim_t *outWdims) {
  libjit_resizebilinear_generic(dst, src, scale, inWdims, outWdims);
}
2456 
/// Quantized (int8) 2D average pool over NHWC tensors.
/// For each output cell, accumulates (input - inOffset) over the kernel
/// window in 32-bit, then rescales the sum into the output quantization
/// domain with libjit_scale_i32i8 (outPre/outPost/outScale/outOffset) and
/// clips to the int8 range. Window cells that fall in the padding region
/// contribute nothing to the sum. NOTE(review): the division by the kernel
/// area appears to be folded into the rescale parameters by the compiler —
/// confirm against the code that emits this call.
void libjit_avg_pool_i8(const int8_t *inW, int8_t *outW, const dim_t *inWdims,
                        const dim_t *outWdims, dim_t *kernelSizes,
                        dim_t *strides, dim_t *pads, int32_t outOffset,
                        int32_t inOffset, int32_t outPre, int32_t outPost,
                        int32_t outScale) {
  dim_t pad_t = pads[0];
  dim_t pad_l = pads[1];
  dim_t stride_h = strides[0];
  dim_t stride_w = strides[1];
  dim_t kernel_h = kernelSizes[0];
  dim_t kernel_w = kernelSizes[1];
  // For each input in the batch:
  for (dim_t n = 0; n < outWdims[0]; n++) {
    // For each (x,y) step in the input/output tensor:
    sdim_t x = -sdim_t(pad_t);
    for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
      sdim_t y = -sdim_t(pad_l);
      for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {
        // For each layer in the output tensor:
        for (dim_t z = 0; z < inWdims[3]; z++) {
          // 32-bit accumulator: holds the offset-corrected window sum.
          int32_t sum = 0;

          for (dim_t fx = 0; fx < kernel_h; fx++) {
            for (dim_t fy = 0; fy < kernel_w; fy++) {
              sdim_t ox = x + fx;
              sdim_t oy = y + fy;

              // Ignore index access below zero (this is due to padding).
              if (ox < 0 || oy < 0 || ox >= (sdim_t)inWdims[1] ||
                  oy >= (sdim_t)inWdims[2]) {
                continue;
              }
              sum += inW[libjit_getXYZW(inWdims, n, (dim_t)ox, (dim_t)oy, z)] -
                     inOffset;
            }
          }

          // Rescale into the output domain and clip to int8.
          outW[libjit_getXYZW(outWdims, n, ax, ay, z)] = libjit_clip(
              libjit_scale_i32i8(sum, outPre, outPost, outScale, outOffset));
        } // C
      }   // W
    }     // H
  }       // N
}
2501 
libjit_avg_pool_f(const float * inW,float * outW,const dim_t * inWdims,const dim_t * outWdims,dim_t * kernelSizes,dim_t * strides,dim_t * pads)2502 void libjit_avg_pool_f(const float *inW, float *outW, const dim_t *inWdims,
2503                        const dim_t *outWdims, dim_t *kernelSizes,
2504                        dim_t *strides, dim_t *pads) {
2505   dim_t pad_t = pads[0];
2506   dim_t pad_l = pads[1];
2507   dim_t stride_h = strides[0];
2508   dim_t stride_w = strides[1];
2509   dim_t kernel_h = kernelSizes[0];
2510   dim_t kernel_w = kernelSizes[1];
2511   float filterArea = kernel_h * kernel_w;
2512   // For each input in the batch:
2513   for (dim_t n = 0; n < outWdims[0]; n++) {
2514     // For each (x,y) step in the input/output tensor:
2515     sdim_t x = -(sdim_t)pad_t;
2516     for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
2517       sdim_t y = -(sdim_t)pad_l;
2518       for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {
2519         // For each layer in the output tensor:
2520         for (dim_t z = 0; z < inWdims[3]; z++) {
2521 
2522           float sum = 0;
2523 
2524           for (dim_t fx = 0; fx < kernel_h; fx++) {
2525             for (dim_t fy = 0; fy < kernel_w; fy++) {
2526               sdim_t ox = x + fx;
2527               sdim_t oy = y + fy;
2528 
2529               // Ignore index access below zero (this is due to padding).
2530               if (ox < 0 || oy < 0 || ox >= (sdim_t)inWdims[1] ||
2531                   oy >= (sdim_t)inWdims[2]) {
2532                 continue;
2533               }
2534 
2535               sum += inW[libjit_getXYZW(inWdims, n, (dim_t)ox, (dim_t)oy, z)];
2536             }
2537           }
2538 
2539           outW[libjit_getXYZW(outWdims, n, ax, ay, z)] = sum / filterArea;
2540         } // C
2541       }   // W
2542     }     // H
2543   }       // N
2544 }
2545 
libjit_adaptive_avg_pool_f(const float * inW,float * outW,const dim_t * inWdims,const dim_t * outWdims)2546 void libjit_adaptive_avg_pool_f(const float *inW, float *outW,
2547                                 const dim_t *inWdims, const dim_t *outWdims) {
2548 // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
2549 #define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b))
2550 #define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b))
2551 
2552   // For each input in the batch:
2553   for (dim_t n = 0; n < outWdims[0]; n++) {
2554     // For each layer in the output tensor:
2555     for (dim_t z = 0; z < inWdims[3]; z++) {
2556       // For each value in the output tensor:
2557       for (dim_t ax = 0; ax < outWdims[1]; ax++) {
2558 
2559         dim_t x = START_IND(ax, outWdims[1], inWdims[1]);
2560         dim_t kH = END_IND(ax, outWdims[1], inWdims[1]) - x;
2561 
2562         for (dim_t ay = 0; ay < outWdims[2]; ay++) {
2563 
2564           dim_t y = START_IND(ay, outWdims[2], inWdims[2]);
2565           dim_t kW = END_IND(ay, outWdims[2], inWdims[2]) - y;
2566 
2567           float sum = 0;
2568           for (dim_t fx = 0; fx < kH; fx++) {
2569             for (dim_t fy = 0; fy < kW; fy++) {
2570               dim_t ox = x + fx;
2571               dim_t oy = y + fy;
2572 
2573               sum += inW[libjit_getXYZW(inWdims, n, ox, oy, z)];
2574             }
2575           }
2576           outW[libjit_getXYZW(outWdims, n, ax, ay, z)] = (sum / kW / kH);
2577         } // W
2578       }   // H
2579     }     // C
2580   }       // N
2581 #undef START_IND
2582 #undef END_IND
2583 }
2584 
libjit_avg_pool_grad_f(float * inG,const float * outG,const dim_t * inGdims,const dim_t * outWdims,dim_t * kernels,dim_t * strides,dim_t * pads)2585 void libjit_avg_pool_grad_f(float *inG, const float *outG, const dim_t *inGdims,
2586                             const dim_t *outWdims, dim_t *kernels,
2587                             dim_t *strides, dim_t *pads) {
2588   dim_t pad_t = pads[0];
2589   dim_t pad_l = pads[1];
2590   dim_t stride_h = strides[0];
2591   dim_t stride_w = strides[1];
2592   dim_t kernel_h = kernels[0];
2593   dim_t kernel_w = kernels[1];
2594   float kernelArea = kernel_h * kernel_w;
2595 
2596   // NHWC format is assumed
2597   for (dim_t n = 0; n < outWdims[0]; n++) {
2598     for (dim_t z = 0; z < outWdims[3]; z++) {
2599       // Clear inG
2600       for (dim_t x = 0; x < inGdims[1]; x++) {
2601         for (dim_t y = 0; y < inGdims[2]; y++) {
2602           inG[libjit_getXYZW(inGdims, n, x, y, z)] = 0.0;
2603         }
2604       }
2605 
2606       sdim_t x = -(sdim_t)pad_t;
2607       for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
2608         sdim_t y = -(sdim_t)pad_l;
2609         for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {
2610           float df = outG[libjit_getXYZW(outWdims, n, ax, ay, z)] / kernelArea;
2611           for (dim_t kx = 0; kx < kernel_h; kx++) {
2612             for (dim_t ky = 0; ky < kernel_w; ky++) {
2613               sdim_t ox = x + kx;
2614               sdim_t oy = y + ky;
2615               if (ox < 0 || oy < 0 || ox >= (sdim_t)inGdims[1] ||
2616                   oy >= (sdim_t)inGdims[2]) {
2617                 continue;
2618               }
2619               inG[libjit_getXYZW(inGdims, n, (dim_t)ox, (dim_t)oy, z)] += df;
2620             }
2621           }
2622         } // W
2623       }   // H
2624     }     // C
2625   }       // N
2626 }
2627 
libjit_element_quantize_kernel_i8(dim_t idx,const float * inW,float scale,int32_t offset)2628 int8_t libjit_element_quantize_kernel_i8(dim_t idx, const float *inW,
2629                                          float scale, int32_t offset) {
2630   int32_t result = (int32_t)nearbyintf(inW[idx] / scale + offset);
2631   return (int8_t)MAX(INT8_MIN, MIN(INT8_MAX, result));
2632 }
2633 
libjit_element_quantize_kernel_i32(dim_t idx,const float * inW,float scale,int32_t offset)2634 int32_t libjit_element_quantize_kernel_i32(dim_t idx, const float *inW,
2635                                            float scale, int32_t offset) {
2636   int32_t result = (int32_t)nearbyintf(inW[idx] / scale + offset);
2637   return result;
2638 }
2639 
libjit_element_dequantize_kernel_f(dim_t idx,const int8_t * inW,float scale,int32_t offset)2640 float libjit_element_dequantize_kernel_f(dim_t idx, const int8_t *inW,
2641                                          float scale, int32_t offset) {
2642   return scale * (inW[idx] - offset);
2643 }
2644 
libjit_element_rescale_kernel_i8(dim_t idx,const int8_t * inW,int32_t outOffset,int32_t inOffset,int32_t pre,int32_t post,int32_t scale)2645 int8_t libjit_element_rescale_kernel_i8(dim_t idx, const int8_t *inW,
2646                                         int32_t outOffset, int32_t inOffset,
2647                                         int32_t pre, int32_t post,
2648                                         int32_t scale) {
2649   int32_t s =
2650       libjit_scale_i32i8(inW[idx] - inOffset, pre, post, scale, outOffset);
2651   return libjit_clip(s);
2652 }
2653 
libjit_softmax_f(const float * inW,float * outW,const dim_t * idim,const dim_t * odim)2654 void libjit_softmax_f(const float *inW, float *outW, const dim_t *idim,
2655                       const dim_t *odim) {
2656   for (dim_t n = 0; n < idim[0]; n++) {
2657     float max = inW[libjit_getXY(idim, n, 0)];
2658 
2659     // Find Max.
2660     for (dim_t i = 1; i < idim[1]; i++) {
2661       max = MAX(max, inW[libjit_getXY(idim, n, i)]);
2662     }
2663 
2664     float sum = 0;
2665 
2666     // Compute exp.
2667     for (dim_t i = 0; i < idim[1]; i++) {
2668       float e = expf(inW[libjit_getXY(idim, n, i)] - max);
2669       sum += e;
2670       outW[libjit_getXY(odim, n, i)] = e;
2671     }
2672 
2673     // Normalize the output.
2674     for (dim_t i = 0; i < idim[1]; i++) {
2675       outW[libjit_getXY(odim, n, i)] = outW[libjit_getXY(odim, n, i)] / sum;
2676     }
2677   } // N
2678 }
2679 
/// Softmax gradient, float data with size_t selected indices; forwards to the
/// shared generic implementation.
void libjit_softmax_grad_f_u(float *inG, float *outW, const size_t *selectedW,
                             const dim_t *idim, const dim_t *selectdim) {
  libjit_softmax_grad_generic(inG, outW, selectedW, idim, selectdim);
}
2684 
/// Softmax gradient, float data with int32_t selected indices; forwards to
/// the shared generic implementation.
void libjit_softmax_grad_f_i32(float *inG, float *outW,
                               const int32_t *selectedW, const dim_t *idim,
                               const dim_t *selectdim) {
  libjit_softmax_grad_generic(inG, outW, selectedW, idim, selectdim);
}
2690 
/// Top-K, float values / size_t indices; forwards to the shared libjit_topk
/// implementation.
void libjit_topk_f_u(float *values, size_t *indices, const float *input,
                     void *scratch, dim_t k, dim_t n, dim_t size) {
  libjit_topk(values, indices, input, scratch, k, n, size);
}
2695 
/// Top-K, float values / int32_t indices; forwards to the shared libjit_topk
/// implementation.
void libjit_topk_f_i32(float *values, int32_t *indices, const float *input,
                       void *scratch, dim_t k, dim_t n, dim_t size) {
  libjit_topk(values, indices, input, scratch, k, n, size);
}
2700 
/// Top-K, int8_t values / size_t indices; forwards to the shared libjit_topk
/// implementation.
void libjit_topk_i8_u(int8_t *values, size_t *indices, const int8_t *input,
                      void *scratch, dim_t k, dim_t n, dim_t size) {
  libjit_topk(values, indices, input, scratch, k, n, size);
}
2705 
/// Top-K, int8_t values / int32_t indices; forwards to the shared libjit_topk
/// implementation.
void libjit_topk_i8_i32(int8_t *values, int32_t *indices, const int8_t *input,
                        void *scratch, dim_t k, dim_t n, dim_t size) {
  libjit_topk(values, indices, input, scratch, k, n, size);
}
2710 
/// Transpose for int8_t elements; forwards to the generic implementation.
void libjit_transpose_i8(const int8_t *inW, int8_t *outW, const dim_t *idim,
                         const dim_t *odim, const dim_t *shuffle,
                         dim_t numDims) {
  libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
}
2716 
/// Transpose for float elements; forwards to the generic implementation.
void libjit_transpose_f(const float *inW, float *outW, const dim_t *idim,
                        const dim_t *odim, const dim_t *shuffle,
                        dim_t numDims) {
  libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
}
2722 
/// Transpose for int64_t elements; forwards to the generic implementation.
void libjit_transpose_u(const int64_t *inW, int64_t *outW, const dim_t *idim,
                        const dim_t *odim, const dim_t *shuffle,
                        dim_t numDims) {
  libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
}
2728 
/// Transpose for bool elements; forwards to the generic implementation.
void libjit_transpose_b(const bool *inW, bool *outW, const dim_t *idim,
                        const dim_t *odim, const dim_t *shuffle,
                        dim_t numDims) {
  libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
}
2734 
/// Flip along \p axis for int8_t elements; forwards to the generic
/// implementation.
void libjit_flip_i8(const int8_t *inW, int8_t *outW, const dim_t *dims,
                    dim_t axis, dim_t numDims) {
  libjit_flip_generic(inW, outW, dims, axis, numDims);
}
2739 
/// Flip along \p axis for int16_t elements; forwards to the generic
/// implementation.
void libjit_flip_i16(const int16_t *inW, int16_t *outW, const dim_t *dims,
                     dim_t axis, dim_t numDims) {
  libjit_flip_generic(inW, outW, dims, axis, numDims);
}
2744 
/// Flip along \p axis for int32_t elements; forwards to the generic
/// implementation.
void libjit_flip_i32(const int32_t *inW, int32_t *outW, const dim_t *dims,
                     dim_t axis, dim_t numDims) {
  libjit_flip_generic(inW, outW, dims, axis, numDims);
}
2749 
/// Flip along \p axis for int64_t elements; forwards to the generic
/// implementation.
void libjit_flip_u(const int64_t *inW, int64_t *outW, const dim_t *dims,
                   dim_t axis, dim_t numDims) {
  libjit_flip_generic(inW, outW, dims, axis, numDims);
}
2754 
/// Flip along \p axis for float elements; forwards to the generic
/// implementation.
void libjit_flip_f(const float *inW, float *outW, const dim_t *dims, dim_t axis,
                   dim_t numDims) {
  libjit_flip_generic(inW, outW, dims, axis, numDims);
}
2759 
/// Flip along \p axis for bool elements; forwards to the generic
/// implementation.
void libjit_flip_b(const bool *inW, bool *outW, const dim_t *dims, dim_t axis,
                   dim_t numDims) {
  libjit_flip_generic(inW, outW, dims, axis, numDims);
}
2764 
/// Insert \p slice into \p tensor at \p offset (float elements); the
/// count/axis tiling parameters are forwarded unchanged to the shared
/// libjit_insert_tensor implementation.
void libjit_insert_tensor_f(float *tensor, float *slice, dim_t *offset,
                            dim_t *tensorDim, dim_t *sliceDim,
                            dim_t numDimsTensor, dim_t numDimsSlice,
                            dim_t offsetDim, dim_t count, dim_t axis) {
  libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
                       numDimsTensor, numDimsSlice, offsetDim, count, axis);
}
2772 
/// Insert \p slice into \p tensor at \p offset (int32_t elements); forwards
/// to the shared libjit_insert_tensor implementation.
void libjit_insert_tensor_i32(int32_t *tensor, int32_t *slice, dim_t *offset,
                              dim_t *tensorDim, dim_t *sliceDim,
                              dim_t numDimsTensor, dim_t numDimsSlice,
                              dim_t offsetDim, dim_t count, dim_t axis) {
  libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
                       numDimsTensor, numDimsSlice, offsetDim, count, axis);
}
2780 
/// Extract \p slice from \p tensor at \p offset (float elements); forwards to
/// the shared libjit_extract_tensor implementation.
void libjit_extract_tensor_f(float *tensor, float *slice, dim_t *offset,
                             dim_t *tensorDim, dim_t *sliceDim,
                             dim_t numDimsTensor, dim_t numDimsSlice,
                             dim_t offsetDim) {
  libjit_extract_tensor(tensor, slice, offset, tensorDim, sliceDim,
                        numDimsTensor, numDimsSlice, offsetDim);
}
2788 
/// Extract \p slice from \p tensor at \p offset (int8_t elements); forwards
/// to the shared libjit_extract_tensor implementation.
void libjit_extract_tensor_i8(int8_t *tensor, int8_t *slice, dim_t *offset,
                              dim_t *tensorDim, dim_t *sliceDim,
                              dim_t numDimsTensor, dim_t numDimsSlice,
                              dim_t offsetDim) {
  libjit_extract_tensor(tensor, slice, offset, tensorDim, sliceDim,
                        numDimsTensor, numDimsSlice, offsetDim);
}
2796 
/// Extract \p slice from \p tensor at \p offset (int32_t elements); forwards
/// to the shared libjit_extract_tensor implementation.
void libjit_extract_tensor_i32(int32_t *tensor, int32_t *slice, dim_t *offset,
                               dim_t *tensorDim, dim_t *sliceDim,
                               dim_t numDimsTensor, dim_t numDimsSlice,
                               dim_t offsetDim) {
  libjit_extract_tensor(tensor, slice, offset, tensorDim, sliceDim,
                        numDimsTensor, numDimsSlice, offsetDim);
}
2804 
/// Insert \p slice into \p tensor at \p offset (int64_t elements); forwards
/// to the shared libjit_insert_tensor implementation.
void libjit_insert_tensor_u(int64_t *tensor, int64_t *slice, dim_t *offset,
                            dim_t *tensorDim, dim_t *sliceDim,
                            dim_t numDimsTensor, dim_t numDimsSlice,
                            dim_t offsetDim, dim_t count, dim_t axis) {
  libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
                       numDimsTensor, numDimsSlice, offsetDim, count, axis);
}
2812 
/// Extract \p slice from \p tensor at \p offset (int64_t elements); forwards
/// to the shared libjit_extract_tensor implementation.
void libjit_extract_tensor_u(int64_t *tensor, int64_t *slice, dim_t *offset,
                             dim_t *tensorDim, dim_t *sliceDim,
                             dim_t numDimsTensor, dim_t numDimsSlice,
                             dim_t offsetDim) {
  libjit_extract_tensor(tensor, slice, offset, tensorDim, sliceDim,
                        numDimsTensor, numDimsSlice, offsetDim);
}
2820 
/// Insert \p slice into \p tensor at \p offset (int8_t elements); forwards to
/// the shared libjit_insert_tensor implementation.
void libjit_insert_tensor_i8(int8_t *tensor, int8_t *slice, dim_t *offset,
                             dim_t *tensorDim, dim_t *sliceDim,
                             dim_t numDimsTensor, dim_t numDimsSlice,
                             dim_t offsetDim, dim_t count, dim_t axis) {
  libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
                       numDimsTensor, numDimsSlice, offsetDim, count, axis);
}
2828 
/// Insert \p slice into \p tensor at \p offset (bool stored as int8_t);
/// forwards to the shared libjit_insert_tensor implementation.
void libjit_insert_tensor_b(int8_t *tensor, int8_t *slice, dim_t *offset,
                            dim_t *tensorDim, dim_t *sliceDim,
                            dim_t numDimsTensor, dim_t numDimsSlice,
                            dim_t offsetDim, dim_t count, dim_t axis) {
  libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
                       numDimsTensor, numDimsSlice, offsetDim, count, axis);
}
2836 
/// SpaceToDepth with block size \p blockSize for float elements; forwards to
/// the generic implementation.
void libjit_space_to_depth_f(const float *inTensor, float *outTensor,
                             dim_t blockSize, const dim_t *inDims,
                             const dim_t *outDims) {
  libjit_space_to_depth_generic(inTensor, outTensor, blockSize, inDims,
                                outDims);
}
2843 
/// SpaceToDepth with block size \p blockSize for int8_t elements; forwards to
/// the generic implementation.
void libjit_space_to_depth_i8(const int8_t *inTensor, int8_t *outTensor,
                              dim_t blockSize, const dim_t *inDims,
                              const dim_t *outDims) {
  libjit_space_to_depth_generic(inTensor, outTensor, blockSize, inDims,
                                outDims);
}
2850 
2851 /// Function to dump a tensor in text format in the console.
libjit_dump_tensor_console(uint8_t * tensor,dim_t * tensorDim,dim_t numDimsTensor,dim_t elemKind,const char * name)2852 __attribute__((noinline)) void libjit_dump_tensor_console(uint8_t *tensor,
2853                                                           dim_t *tensorDim,
2854                                                           dim_t numDimsTensor,
2855                                                           dim_t elemKind,
2856                                                           const char *name) {
2857   printf("%s\n", name);
2858   /// This definition should match the defintion in Glow.
2859   enum class ElemKind : unsigned char {
2860     FloatTy,       // 32-bit float type (float)
2861     Float16Ty,     // 16-bit float type (half, fp16)
2862     BFloat16Ty,    // 16-bit float type (bfloat16)
2863     Int8QTy,       // 8-bit quantized type (int8_t)
2864     UInt8QTy,      // unsigned 8-bit quantized type (uint8_t)
2865     Int16QTy,      // 16-bit quantized type (int16_t)
2866     Int32QTy,      // 32-bit quantized type (int32_t)
2867     Int32ITy,      // 32-bit index type (int32_t)
2868     Int64ITy,      // 64-bit index type (int64_t)
2869     UInt8FusedQTy, // 8-bit quantized type with fused scale/offset (uint8_t)
2870     BoolTy,        // Bool type (bool)
2871   };
2872   // Dump the content of a tensor.
2873   switch ((ElemKind)elemKind) {
2874   case ElemKind::FloatTy:
2875     libjit_dump_tensor_console_impl((float *)tensor, tensorDim, numDimsTensor);
2876     break;
2877   case ElemKind::Int64ITy:
2878     libjit_dump_tensor_console_impl((dim_t *)tensor, tensorDim, numDimsTensor);
2879     break;
2880   case ElemKind::Int8QTy:
2881     libjit_dump_tensor_console_impl((int8_t *)tensor, tensorDim, numDimsTensor);
2882     break;
2883   case ElemKind::Int32QTy:
2884     libjit_dump_tensor_console_impl((int32_t *)tensor, tensorDim,
2885                                     numDimsTensor);
2886     break;
2887   default:
2888     printf("Dumping this type of payload is not supported: %zu\n",
2889            (size_t)elemKind);
2890     break;
2891   }
2892   puts("");
2893 }
2894 
/// Function to dump a tensor in binary format in a file using the raw tensor
/// data pointer \p tensor, the tensor data size \p tensorSize (in bytes) and
/// the file name \p filename. A text header \p header will also be dumped.
__attribute__((noinline)) void libjit_dump_tensor_bin(uint8_t *tensor,
                                                      size_t tensorSize,
                                                      const char *filename,
                                                      const char *header) {
  FILE *fh = fopen(filename, "wb");
  if (!fh) {
    // Report the failure on stderr so it is not lost when stdout is
    // redirected or buffered.
    fprintf(stderr,
            "ERROR opening file: '%s'!\n"
            "File name might be too long!\n",
            filename);
    return;
  }
  // Dump header.
  fprintf(fh, "%s", header);
  // Dump tensor data.
  size_t size = fwrite(tensor, 1, tensorSize, fh);
  assert((size == tensorSize) && "Error dumping tensor to file!");
  (void)size;
  fclose(fh);
}
2917 
/// Functions to dump a tensor in text format in a file using the raw tensor
/// data pointer \p tensor, the tensor data size \p tensorElemSize (number of
/// elements) and the file name \p filename. A text header \p header will also
/// be dumped.
#define DEFINE_DUMP_TENSOR_TXT_KERNEL(type, suffix)                            \
  __attribute__((noinline)) void libjit_dump_tensor_txt_##suffix(              \
      uint8_t *tensor, size_t tensorElemSize, const char *filename,            \
      const char *header) {                                                    \
    libjit_dump_tensor_txt_impl((type *)tensor, tensorElemSize, filename,      \
                                header);                                       \
  }
// Instantiate one text-dump kernel per supported element type.
DEFINE_DUMP_TENSOR_TXT_KERNEL(float, f)
DEFINE_DUMP_TENSOR_TXT_KERNEL(int8_t, i8)
DEFINE_DUMP_TENSOR_TXT_KERNEL(int16_t, i16)
DEFINE_DUMP_TENSOR_TXT_KERNEL(int32_t, i32)
DEFINE_DUMP_TENSOR_TXT_KERNEL(int64_t, u)
DEFINE_DUMP_TENSOR_TXT_KERNEL(bool, b)
#undef DEFINE_DUMP_TENSOR_TXT_KERNEL
2936 
2937 void libjit_write_timestamp(uint64_t *tensor, dim_t offset) {
2938   // We are using C++ timer here to a avoid issues with gettimeofday
2939   // Issue #2397 covers migrating this to a libc approach but if you have issues
2940   // with a lack of C++ symbols at runtime check there first.
2941   uint64_t ts = std::chrono::duration_cast<std::chrono::microseconds>(
2942                     std::chrono::steady_clock::now().time_since_epoch())
2943                     .count();
2944   memcpy(tensor + offset, &ts, sizeof(uint64_t));
2945 }
2946 
/// Copy kernels that convert the element type while copying.
/// Copy bool -> float with element conversion; forwards to the templated
/// copy-with-conversion kernel.
void libjit_convertTo_f_b(float *dstPtr, const bool *srcPtr, const dim_t *dims,
                          dim_t numDims) {
  libjit_copy_kernel_with_conversion<float, bool>(dstPtr, srcPtr, dims,
                                                  numDims);
}
2953 
/// Copy float -> bool with element conversion; forwards to the templated
/// copy-with-conversion kernel.
void libjit_convertTo_b_f(bool *dstPtr, const float *srcPtr, const dim_t *dims,
                          dim_t numDims) {
  libjit_copy_kernel_with_conversion<bool, float>(dstPtr, srcPtr, dims,
                                                  numDims);
}
2959 
/// Copy int32_t -> float with element conversion; forwards to the templated
/// copy-with-conversion kernel.
void libjit_convertTo_f_i32(float *dstPtr, const int32_t *srcPtr,
                            const dim_t *dims, dim_t numDims) {
  libjit_copy_kernel_with_conversion<float, int32_t>(dstPtr, srcPtr, dims,
                                                     numDims);
}
2965 
/// Copy int64_t -> int32_t with element conversion (narrowing); forwards to
/// the templated copy-with-conversion kernel.
void libjit_convertTo_i32_u(int32_t *dstPtr, const int64_t *srcPtr,
                            const dim_t *dims, dim_t numDims) {
  libjit_copy_kernel_with_conversion<int32_t, int64_t>(dstPtr, srcPtr, dims,
                                                       numDims);
}
2971 
/// Copy int32_t -> int64_t with element conversion (widening); forwards to
/// the templated copy-with-conversion kernel.
void libjit_convertTo_u_i32(int64_t *dstPtr, const int32_t *srcPtr,
                            const dim_t *dims, dim_t numDims) {
  libjit_copy_kernel_with_conversion<int64_t, int32_t>(dstPtr, srcPtr, dims,
                                                       numDims);
}
2977 
/// Update min/max values \p compInfo and histogram \p existingHistogram with
/// data collected from tensor \p inputTensor.
/// \p tensorSize is the number of float elements in \p inputTensor and
/// \p histDim[0] is the number of histogram bins (0 disables the histogram).
/// compInfo[0]/compInfo[1] hold the running global min/max and are updated in
/// place.
/// Note: code ported from Profile.cpp: generateTensorHistogram
__attribute__((noinline)) void
libjit_quantization_profile(float *inputTensor, dim_t tensorSize,
                            float *compInfo, float *existingHistogram,
                            dim_t *histDim) {
  dim_t nBins = histDim[0];

  // Min/max computed from previous runs. If this is the first run, compInfo is
  // expected to be initialized as following:
  // compInfo[0]: std::numeric_limits<float>::max()
  // compInfo[1]: std::numeric_limits<float>::lowest()
  float min = compInfo[0];
  float max = compInfo[1];

  // Min/max value for entire current input tensor.
  float minInput;
  float maxInput;
  find_min_max_f(inputTensor, tensorSize, minInput, maxInput);

  // Update the global min/max.
  float newMin = MIN(minInput, min);
  float newMax = MAX(maxInput, max);
  compInfo[0] = newMin;
  compInfo[1] = newMax;

  // If input histogram is empty then return.
  if (nBins == 0) {
    return;
  }

  // Initial profile: if the histogram has no counts yet, treat the current
  // tensor's range as the old range so we don't rescale from the sentinel
  // min/max values described above.
  if (check_all_zeros(existingHistogram, nBins) == 1) {
    min = minInput;
    max = maxInput;
  }

  // If the min/max range changes, there is the need to rescale the histogram.
  if (newMin < min || newMax > max) {
    float destBinWidth = (newMax - newMin) / nBins;
    float srcBinWidth = (max - min) / nBins;
    // NOTE(review): variable-length array is a GNU/Clang extension in C++;
    // fine for the compilers used to build libjit, but not portable C++.
    float scaledHistogram[nBins];
    for (dim_t i = 0; i < nBins; ++i) {
      scaledHistogram[i] = 0.0f;
    }

    // Redistribute every non-empty source bin over the destination bins it
    // overlaps, proportionally to the overlap.
    for (dim_t i = 0; i < nBins; ++i) {
      if (existingHistogram[i] == 0)
        continue;

      float srcBinBegin = min + srcBinWidth * i;
      dim_t destBin = (srcBinBegin - newMin) / destBinWidth;
      float destBinEnd = newMin + destBinWidth * (destBin + 1);

      float srcBinEnd = srcBinBegin + srcBinWidth;
      dim_t destBinToVerify = (srcBinEnd - newMin) / destBinWidth;
      // Make sure that destination bin is mapped at most to 2 final bins, based
      // on that redistribute percentage is calculated.
      assert(destBinToVerify <= destBin + 2);
      (void)destBinToVerify;

      // Calculate how much we need to redistribute.
      uint64_t dstBinCnt = static_cast<uint64_t>(
          MIN(static_cast<float>(round((destBinEnd - srcBinBegin) /
                                       srcBinWidth * existingHistogram[i])),
              existingHistogram[i]));

      dim_t newBin = get_bin(nBins, destBinWidth, newMin, srcBinBegin);
      scaledHistogram[newBin] += dstBinCnt;

      // Whatever did not fit in the first destination bin spills into the
      // following one.
      if (dstBinCnt < existingHistogram[i]) {
        dim_t newBin =
            get_bin(nBins, destBinWidth, newMin, srcBinBegin + destBinWidth);
        scaledHistogram[newBin] += existingHistogram[i] - dstBinCnt;
      }
    }

    // Copy scaled histogram back to the existing histogram.
    for (dim_t i = 0, e = nBins; i < e; ++i) {
      existingHistogram[i] = scaledHistogram[i];
    }

    // Update global min and max.
    min = newMin;
    max = newMax;
  }

  // Update the histogram with the values of the current input tensor.
  float binWidth = (max - min) / nBins;
  for (dim_t i = 0, e = tensorSize; i < e; ++i) {
    dim_t newBin = get_bin(nBins, binWidth, min, inputTensor[i]);
    existingHistogram[newBin]++;
  }
}
3073 
/// Non-maximum suppression with uint64_t index/count outputs; forwards all
/// parameters unchanged to libjit_nms_generic.
__attribute__((noinline)) void
libjit_nms_u(uint64_t *indices, uint64_t *numDetected, const float *boxTensor,
             const dim_t *boxTensorDims, dim_t boxTensorDimSize,
             const float *scoresTensor, const dim_t *scoresTensorDims,
             dim_t scoresTensorDimSize, const dim_t *resultTensorDims,
             dim_t resultTensorDimSize, unsigned centerPointBox,
             unsigned maxOutputBoxesPerClass, float iouThreshold,
             float scoreThreshold, bool isV4) {
  libjit_nms_generic(indices, numDetected, boxTensor, boxTensorDims,
                     boxTensorDimSize, scoresTensor, scoresTensorDims,
                     scoresTensorDimSize, resultTensorDims, resultTensorDimSize,
                     centerPointBox, maxOutputBoxesPerClass, iouThreshold,
                     scoreThreshold, isV4);
}
3088 
/// Non-maximum suppression with int32_t index/count outputs; forwards all
/// parameters unchanged to libjit_nms_generic.
__attribute__((noinline)) void
libjit_nms_i32(int32_t *indices, int32_t *numDetected, const float *boxTensor,
               const dim_t *boxTensorDims, dim_t boxTensorDimSize,
               const float *scoresTensor, const dim_t *scoresTensorDims,
               dim_t scoresTensorDimSize, const dim_t *resultTensorDims,
               dim_t resultTensorDimSize, unsigned centerPointBox,
               unsigned maxOutputBoxesPerClass, float iouThreshold,
               float scoreThreshold, bool isV4) {
  libjit_nms_generic(indices, numDetected, boxTensor, boxTensorDims,
                     boxTensorDimSize, scoresTensor, scoresTensorDims,
                     scoresTensorDimSize, resultTensorDims, resultTensorDimSize,
                     centerPointBox, maxOutputBoxesPerClass, iouThreshold,
                     scoreThreshold, isV4);
}
3103 
/// FFT Radix2 DIT (Decimation In Time) implementation for Complex data.
/// The \p input and \p output buffers hold 2 * \p fftLength floats, i.e.
/// \p fftLength complex samples with real and imaginary parts interleaved:
/// real[0], imag[0], real[1], imag[1], ...
/// The lookup tables \p twiddleFactors (fftLength floats = fftLength / 2
/// interleaved complex twiddles) and \p bitReverseIndices (fftLength
/// entries) are generated at compile time. When \p inPlace is true the
/// transform is computed inside \p input and \p output is never written;
/// otherwise the result is produced in \p output.
void libjit_fft_complex_f(float *output, float *input,
                          const float *twiddleFactors,
                          const int32_t *bitReverseIndices, unsigned fftLength,
                          bool inPlace) {

  // Reorder the complex samples into bit-reversed index order.
  if (inPlace) {
    for (size_t pos = 0; pos < fftLength; pos++) {
      size_t revPos = (size_t)bitReverseIndices[pos];
      // Swap each pair exactly once.
      if (pos < revPos) {
        std::swap(input[2 * pos + 0], input[2 * revPos + 0]);
        std::swap(input[2 * pos + 1], input[2 * revPos + 1]);
      }
    }
  } else {
    for (size_t pos = 0; pos < fftLength; pos++) {
      size_t revPos = (size_t)bitReverseIndices[pos];
      output[2 * pos + 0] = input[2 * revPos + 0];
      output[2 * pos + 1] = input[2 * revPos + 1];
    }
  }

  // Buffer the butterfly stages operate on.
  float *data = inPlace ? input : output;

  // Number of radix2 stages: log2(fftLength).
  size_t stageNum = (size_t)std::log2((double)fftLength);

  // First stage: fftLength / 2 groups with a single butterfly each.
  size_t groupNum = fftLength / 2;
  size_t groupButterNum = 1;

  // Stage loop.
  for (size_t stageIdx = 0; stageIdx < stageNum; stageIdx++) {

    // Butterfly group loop.
    for (size_t groupIdx = 0; groupIdx < groupNum; groupIdx++) {

      // Complex index of the first sample of this group.
      size_t base = groupIdx * 2 * groupButterNum;

      // Butterfly loop within the group.
      for (size_t butterIdx = 0; butterIdx < groupButterNum; butterIdx++) {

        // Float offsets of the butterfly's two complex inputs.
        size_t top = 2 * (base + butterIdx);
        size_t bot = 2 * (base + groupButterNum + butterIdx);

        float top_re = data[top + 0];
        float top_im = data[top + 1];
        float bot_re = data[bot + 0];
        float bot_im = data[bot + 1];

        // Twiddle factor for this butterfly; the twiddle table is walked
        // with a stride of 2 * groupNum floats within each group.
        float tw_re = twiddleFactors[2 * groupNum * butterIdx + 0];
        float tw_im = twiddleFactors[2 * groupNum * butterIdx + 1];

        // Complex multiply of the bottom input with the twiddle factor.
        float mul_re = bot_re * tw_re - bot_im * tw_im;
        float mul_im = bot_re * tw_im + bot_im * tw_re;

        // Radix2 butterfly: sum on top, difference on bottom.
        data[top + 0] = top_re + mul_re;
        data[top + 1] = top_im + mul_im;
        data[bot + 0] = top_re - mul_re;
        data[bot + 1] = top_im - mul_im;
      }
    }

    // Next stage: half the groups, twice the butterflies per group.
    groupNum >>= 1;
    groupButterNum <<= 1;
  }
}
3197 
3198 /// FFT Radix2 DIT (Decimation In Time) implementation for Real data.
3199 /// The implementation uses a fftLength/2 FFT for Complex data followed
3200 /// by a step to map the complex FFT to the real FFT by using a set of
3201 /// of complex weights \p complexToRealWeights A[k] defined as:
3202 ///   A[k] = 1/2 * (1 - j*exp(-j*2*pi*k/N)) for k = 0 .. N/4-1
3203 /// The \p input buffer has \p fftLength float values corresponding
3204 /// to \p fftLength real samples. Since the FFT of a real signal
3205 /// has conjugate symmetry, the \p output buffer only contains
3206 /// 2 * (fftLength/2+1) = fftLength + 2 float values corresponding
3207 /// to fftLength/2+1 complex samples with real and imaginary parts
3208 /// interleaved: real[0], imag[0], real[1], imag[1], ...
3209 /// The lookup tables \p twiddleFactors and \p bitReverseIndices are
3210 /// generated at compile time as if they were generated for a N/2
3211 /// complex FFT. The boolean flag \p inPlace decides whether the FFT
3212 /// computation is done in-place (that is in the \p input buffer
3213 /// without writing in the \p output buffer) or out-of-place (written in
3214 /// the \p output buffer).
void libjit_fft_real_f(float *output, float *input, const float *twiddleFactors,
                       const int32_t *bitReverseIndices,
                       const float *complexToRealWeights, unsigned fftLength,
                       bool inPlace) {

  // Perform N/2 complex FFT (in-place or out-of-place).
  // G[k] with k = 0 .. N/2-1.
  libjit_fft_complex_f(output, input, twiddleFactors, bitReverseIndices,
                       fftLength / 2, inPlace);

  // Complex to Real FFT mapping (in-place).
  //   X[k] = G[k] * A[k] + conj(G[N/2-k]) * (1 - A[k])
  // for k = 0 .. N/2 with the convention G[N/2] = G[0].
  // Particular cases:
  //   real(X[0]) = real(G[0]) + imag(G[0])
  //   imag(X[0]) = 0
  //   real(X[N/2]) = real(G[0]) - imag(G[0])
  //   imag(X[N/2]) = 0
  //   X[N/4] = conj(G[N/4])

  // Skip A[0]: the k == 0 (and mirrored k == N/2) terms are handled
  // explicitly below, so the weight walk starts at A[1].
  const float *Ak = complexToRealWeights + 2;
  float *ptr = inPlace ? input : output;
  // ptr0 walks forward from X[0]; ptr1 walks backward starting at the
  // imaginary part of X[N/2] (float index 2*(N/2)+1 == fftLength + 1).
  float *ptr0 = &ptr[0];
  float *ptr1 = &ptr[2 * fftLength / 2 + 1];
  // Read G[0]; the ++/-- pair leaves ptr0 where it started.
  float inp0_re = *ptr0++;
  float inp0_im = *ptr0--;
  // X[0] = real(G[0]) + imag(G[0]) with zero imaginary part.
  *ptr0++ = inp0_re + inp0_im;
  *ptr0++ = 0;
  // X[N/2] = real(G[0]) - imag(G[0]) with zero imaginary part, written
  // back-to-front (imag first), leaving ptr1 on imag(X[N/2 - 1]).
  *ptr1-- = 0;
  *ptr1-- = inp0_re - inp0_im;

  for (dim_t k = 1; k < fftLength / 4; k++) {

    // G[k] read at the front; the ++/-- pair leaves ptr0 unmoved.
    float inp0_re = *ptr0++;
    float inp0_im = *ptr0--;
    // G[N/2-k] read at the back (imag first); the --/++ pair leaves ptr1
    // on the imaginary part of that sample.
    float inp1_im = *ptr1--;
    float inp1_re = *ptr1++;

    // A[k] weight (interleaved real/imag).
    float Ak_re = *Ak++;
    float Ak_im = *Ak++;

    // Cross terms combining G[k] and G[N/2-k] through A[k].
    float dif_re = inp0_re - inp1_re;
    float sum_im = inp0_im + inp1_im;
    float prod0 = dif_re * Ak_re - sum_im * Ak_im;
    float prod1 = dif_re * Ak_im + sum_im * Ak_re;

    // X[k] overwrites the front slot; ptr0 advances one complex sample.
    *ptr0++ = +prod0 + inp1_re;
    *ptr0++ = +prod1 - inp1_im;
    // X[N/2-k] overwrites the back slot (imag first); ptr1 retreats one
    // complex sample.
    *ptr1-- = +prod1 - inp0_im;
    *ptr1-- = -prod0 + inp0_re;
  }

  // X[N/4] = conj(G[N/4]): the loops above left ptr1 on imag(X[N/4]),
  // so flipping its sign conjugates the middle sample. Skipped for tiny
  // transforms (fftLength < 4) where that sample does not exist.
  if (fftLength >= 4) {
    *ptr1 = -*ptr1;
  }
}
3271 
3272 /// Compute the spectrogram for the given 1D mono audio signal \p input.
3273 /// The input windows are weighted using the \p window function and the
3274 /// FFT LUTs \p twiddleFactors and \p bitReverseIndices are computed at
3275 /// compile-time. More details in Graph.h about the AudioSpectrogram node.
libjit_audio_spectrogram_f(void * winOutScratch,void * fftOutScratch,float * spectrogram,const float * input,const float * window,const float * twiddleFactors,const int32_t * bitReverseIndices,const float * complexToRealWeights,const dim_t * spectrogramDims,const dim_t inputLength,const dim_t windowSize,const dim_t windowStride,const bool magnitudeSquared)3276 void libjit_audio_spectrogram_f(
3277     void *winOutScratch, void *fftOutScratch, float *spectrogram,
3278     const float *input, const float *window, const float *twiddleFactors,
3279     const int32_t *bitReverseIndices, const float *complexToRealWeights,
3280     const dim_t *spectrogramDims, const dim_t inputLength,
3281     const dim_t windowSize, const dim_t windowStride,
3282     const bool magnitudeSquared) {
3283 
3284   dim_t winNum = spectrogramDims[0];
3285   dim_t specLen = spectrogramDims[1];
3286   dim_t fftLen = (specLen - 1) * 2;
3287 
3288   // Scratch buffers.
3289   float *winOut = (float *)winOutScratch;
3290   float *fftOut = (float *)fftOutScratch;
3291   memset(winOut, 0, fftLen * sizeof(float));
3292 
3293   // Compute the spectrogram.
3294   for (dim_t winIdx = 0; winIdx < winNum; winIdx++) {
3295 
3296     // Windowing.
3297     for (dim_t n = 0; n < windowSize; n++) {
3298       winOut[n] = input[winIdx * windowStride + n] * window[n];
3299     }
3300 
3301     // Compute spectrum (perform FFT for real data).
3302     libjit_fft_real_f(fftOut, winOut, twiddleFactors, bitReverseIndices,
3303                       complexToRealWeights, fftLen, false /* inPlace */);
3304 
3305     // Compute spectrum magnitude/power.
3306     for (dim_t k = 0; k < specLen; k++) {
3307       float real = fftOut[2 * k + 0];
3308       float imag = fftOut[2 * k + 1];
3309       float power = real * real + imag * imag;
3310       if (magnitudeSquared) {
3311         *spectrogram++ = power;
3312       } else {
3313         *spectrogram++ = std::sqrt(power);
3314       }
3315     }
3316   }
3317 }
3318 
3319 /// Compute the MFCC (Mel Frequency Cepstral Coefficient) for the given
3320 /// \p spectrogram power. The lookup tables \p melWeights, \p melRanges
3321 /// and \p dctMat are computed at compile-time. More details in Graph.h
3322 /// about the MFCC node.
libjit_mfcc_f(void * scratch,float * coefficients,const float * spectrogram,const float * melWeights,const int32_t * melRanges,const float * dctMat,const dim_t * coefficientsDims,const dim_t * spectrogramDims,const dim_t filterBankCount)3323 void libjit_mfcc_f(void *scratch, float *coefficients, const float *spectrogram,
3324                    const float *melWeights, const int32_t *melRanges,
3325                    const float *dctMat, const dim_t *coefficientsDims,
3326                    const dim_t *spectrogramDims, const dim_t filterBankCount) {
3327 
3328   // Scratch buffer.
3329   float *melBuff = (float *)scratch;
3330 
3331   // Perform MFCC for all the windows.
3332   dim_t winNum = spectrogramDims[0];
3333   dim_t winSize = spectrogramDims[1];
3334   dim_t numCoefficients = coefficientsDims[1];
3335   for (dim_t winIdx = 0; winIdx < winNum; winIdx++) {
3336 
3337     // Pointers backup for this window.
3338     const float *melWeightsPtr = melWeights;
3339     const int32_t *melRangesPtr = melRanges;
3340     const float *dctMatPtr = dctMat;
3341 
3342     // Apply Mel filter bank mapping. We use sqrt for the spectrogram since we
3343     // assume the spectrogram is a power value and not a magnitude.
3344     for (dim_t melIdx = 0; melIdx < filterBankCount; melIdx++) {
3345 
3346       int32_t freqIdxStart = *melRangesPtr++;
3347       int32_t freqIdxStop = *melRangesPtr++;
3348 
3349       // Compute Mel Power.
3350       float melPwr = 0.0f;
3351       for (int32_t freqIdx = freqIdxStart; freqIdx <= freqIdxStop; freqIdx++) {
3352         melPwr += std::sqrt(spectrogram[freqIdx]) * (*melWeightsPtr++);
3353       }
3354 
3355       // Take logarithm in-place (avoid log(0)).
3356       melBuff[melIdx] = (melPwr == 0.0)
3357                             ? logf(std::numeric_limits<float>::min())
3358                             : logf(melPwr);
3359     }
3360 
3361     // Compute DCT transform.
3362     for (dim_t k = 0; k < numCoefficients; k++) {
3363       float dctOut = 0.0f;
3364       for (dim_t n = 0; n < filterBankCount; n++) {
3365         dctOut += (*dctMatPtr++) * melBuff[n];
3366       }
3367       *coefficients++ = dctOut;
3368     }
3369 
3370     // Go to next spectrogram window.
3371     spectrogram += winSize;
3372   }
3373 }
3374 } // extern "C"
3375