1 /**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <algorithm>
17 #include <assert.h>
18 #include <chrono>
19 #include <cmath>
20 #include <math.h>
21 #include <numeric>
22 #include <stddef.h>
23 #include <stdint.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include <string.h>
27 #include <sys/types.h>
28
29 #include "libjit_defs.h"
30
31 namespace {
32
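/// Dump the shape, the min/max values and up to the first 100 elements of
/// \p tensor (with \p numDims dimensions of sizes \p dims) to stdout.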
33 template <class ElemTy>
34 static void libjit_dump_tensor_console_impl(ElemTy *tensor, dim_t *dims,
35 dim_t numDims) {
36 // Check for 0-dimensional tensor.
37 if (!numDims) {
38 printf("[ Scalar containing: %.3f ]\n", (float)tensor[0]);
39 return;
40 }
41
42 // Output shape.
43 printf("shape: ( ");
44 for (size_t i = 0; i < numDims; ++i) {
45 printf("%zu ", (size_t)dims[i]);
46 }
47 printf(")\n");
48
49 ElemTy mx = tensor[0];
50 ElemTy mn = tensor[0];
51
52 size_t size = 1;
53 size_t sliceSize[numDims];
54 for (size_t i = 0; i < numDims; ++i) {
55 size *= dims[i];
56 }
57
58 for (ssize_t i = numDims - 1, curSliceSize = 1; i >= 0; --i) {
59 sliceSize[i] = curSliceSize;
60 curSliceSize *= dims[i];
61 }
62
63 for (size_t i = 0, e = size; i < e; i++) {
64 mx = MAX(mx, tensor[i]);
65 mn = MIN(mn, tensor[i]);
66 }
67
68 // Check for zero tensor.
69 if (mn == .0 && mx == .0) {
70 printf("[ Zero tensor ]\n");
71 return;
72 }
73
74 // Output max and min.
75 printf("max: %.3f min: %.3f\n", (float)mx, (float)mn);
76
77 const unsigned maxNumElem = 100;
78
79 printf("[");
80
81 for (size_t i = 0, e = MIN(maxNumElem, size); i < e; i++) {
82
83 // Print one open brace at the beginning of every row, slice, and tensor.
84 for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
85 if (i % sliceSize[j] == 0) {
86 // This iteration of outer loop is a new row, slice or tensor.
87 printf("[");
88 }
89 }
90
91 // Print the value at the current index.
92 printf("%.3f", (float)tensor[i]);
93
94 // Print one closed brace at the end of every row, slice, or tensor.
95 for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
96 size_t next_index = i + 1;
97 if (next_index % sliceSize[j] == 0u) {
98 printf("]");
99 }
100 }
101
102 printf(", ");
103
104 // Print one newline at the end of every row, slice, or tensor.
105 for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
106 size_t next_index = i + 1;
107 if (next_index % sliceSize[j] == 0u) {
108 // Next iteration of outer loop will be a new row, slice or tensor.
109 printf("\n");
110 }
111 }
112 }
113
114 if (size > maxNumElem) {
115 printf("...");
116 }
117
118 printf("]\n");
119 }
120
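/// Write the \p tensorElemSize elements of \p tensor to the text file
/// \p filename, preceded by \p header when it is non-empty.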
121 template <class ElemTy>
122 static void libjit_dump_tensor_txt_impl(ElemTy *tensor, size_t tensorElemSize,
123 const char *filename,
124 const char *header) {
125 FILE *fh = fopen(filename, "w");
126 if (!fh) {
127 printf("ERROR opening file: '%s'!\n"
128 "File name might be too long!\n",
129 filename);
130 return;
131 }
132 if (strlen(header)) {
133 fprintf(fh, "%s\n", header);
134 }
135 for (size_t idx = 0, end = tensorElemSize; idx < end; idx++) {
136 fprintf(fh, "%f, ", (double)tensor[idx]);
137 }
138 fclose(fh);
139 }
140
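/// Return the flat offset of the element addressed by \p indices (of length
/// \p numIndices) in a tensor of shape \p dims; unspecified trailing indices
/// are treated as 0.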
141 template <typename ElemTy>
142 static dim_t get_element_ptr(const ElemTy *tensor, const dim_t *dims,
143 dim_t numDims, const dim_t *indices,
144 dim_t numIndices) {
145 dim_t index = 0;
146 dim_t subdimensionSize = 1;
147 for (dim_t i = numDims; i > 0; i--) {
148 dim_t curIndicesValue = (i <= numIndices) ? indices[i - 1] : 0;
149 index += subdimensionSize * curIndicesValue;
150 subdimensionSize *= dims[i - 1];
151 }
152 return index;
153 }
154
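/// Insert \p count copies of \p slice into \p tensor starting at \p offset,
/// advancing along \p axis by the slice size between copies. Slices with up
/// to 5 dimensions are supported.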
155 template <typename ElemTy>
156 static void libjit_insert_tensor(ElemTy *tensor, ElemTy *slice, dim_t *offset,
157 dim_t *tensorDim, dim_t *sliceDim,
158 dim_t numDimsTensor, dim_t numDimsSlice,
159 dim_t offsetDim, dim_t count, dim_t axis) {
160 // Destination coordinates.
161 dim_t C[5];
162
163 // A local copy of the offsets buffer. We copy the buffer to make it clear
164 // to the optimizer that the inputs don't alias. This loop is optimized away.
165 dim_t offsets_cpy[5];
166 for (dim_t i = 0; i < numDimsSlice; i++) {
167 offsets_cpy[i] = offset[i];
168 }
169
170 if (numDimsSlice == 5) {
171 for (dim_t c = 0; c < count; c++)
172 for (dim_t x = 0; x < sliceDim[0]; x++)
173 for (dim_t y = 0; y < sliceDim[1]; y++)
174 for (dim_t z = 0; z < sliceDim[2]; z++)
175 for (dim_t w = 0; w < sliceDim[3]; w++)
176 for (dim_t q = 0; q < sliceDim[4]; q++) {
177 const dim_t countAxisOffset = c * sliceDim[axis];
178 C[0] = x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0);
179 C[1] = y + offsets_cpy[1] + ((axis == 1) ? countAxisOffset : 0);
180 C[2] = z + offsets_cpy[2] + ((axis == 2) ? countAxisOffset : 0);
181 C[3] = w + offsets_cpy[3] + ((axis == 3) ? countAxisOffset : 0);
182 C[4] = q + offsets_cpy[4] + ((axis == 4) ? countAxisOffset : 0);
183 tensor[libjit_getXYZWQ(tensorDim, C[0], C[1], C[2], C[3],
184 C[4])] =
185 slice[libjit_getXYZWQ(sliceDim, x, y, z, w, q)];
186 }
187 return;
188 }
189
190 if (numDimsSlice == 4) {
191 for (dim_t c = 0; c < count; c++)
192 for (dim_t x = 0; x < sliceDim[0]; x++)
193 for (dim_t y = 0; y < sliceDim[1]; y++)
194 for (dim_t z = 0; z < sliceDim[2]; z++)
195 for (dim_t w = 0; w < sliceDim[3]; w++) {
196 const dim_t countAxisOffset = c * sliceDim[axis];
197 C[0] = x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0);
198 C[1] = y + offsets_cpy[1] + ((axis == 1) ? countAxisOffset : 0);
199 C[2] = z + offsets_cpy[2] + ((axis == 2) ? countAxisOffset : 0);
200 C[3] = w + offsets_cpy[3] + ((axis == 3) ? countAxisOffset : 0);
201 tensor[libjit_getXYZW(tensorDim, C[0], C[1], C[2], C[3])] =
202 slice[libjit_getXYZW(sliceDim, x, y, z, w)];
203 }
204 return;
205 }
206
207 if (numDimsSlice == 3) {
208 for (dim_t c = 0; c < count; c++)
209 for (dim_t x = 0; x < sliceDim[0]; x++)
210 for (dim_t y = 0; y < sliceDim[1]; y++)
211 for (dim_t z = 0; z < sliceDim[2]; z++) {
212 const dim_t countAxisOffset = c * sliceDim[axis];
213 C[0] = x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0);
214 C[1] = y + offsets_cpy[1] + ((axis == 1) ? countAxisOffset : 0);
215 C[2] = z + offsets_cpy[2] + ((axis == 2) ? countAxisOffset : 0);
216 tensor[libjit_getXYZ(tensorDim, C[0], C[1], C[2])] =
217 slice[libjit_getXYZ(sliceDim, x, y, z)];
218 }
219 return;
220 }
221
222 if (numDimsSlice == 2) {
223 for (dim_t c = 0; c < count; c++)
224 for (dim_t x = 0; x < sliceDim[0]; x++)
225 for (dim_t y = 0; y < sliceDim[1]; y++) {
226 const dim_t countAxisOffset = c * sliceDim[axis];
227 C[0] = x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0);
228 C[1] = y + offsets_cpy[1] + ((axis == 1) ? countAxisOffset : 0);
229 tensor[libjit_getXY(tensorDim, C[0], C[1])] =
230 slice[libjit_getXY(sliceDim, x, y)];
231 }
232 return;
233 }
234
235 if (numDimsSlice == 1) {
236 for (dim_t c = 0; c < count; c++)
237 for (dim_t x = 0; x < sliceDim[0]; x++) {
238 const dim_t countAxisOffset = c * sliceDim[axis];
239 tensor[x + offsets_cpy[0] + ((axis == 0) ? countAxisOffset : 0)] =
240 slice[x];
241 }
242 return;
243 }
244 }
245
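/// Copy the region of \p tensor that starts at \p offset and has shape
/// \p sliceDim into \p slice. Slices with up to 5 dimensions are supported.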
246 template <typename ElemTy>
247 static void libjit_extract_tensor(ElemTy *tensor, ElemTy *slice, dim_t *offset,
248 dim_t *tensorDim, dim_t *sliceDim,
249 dim_t numDimsTensor, dim_t numDimsSlice,
250 dim_t offsetDim) {
251 // Source coordinates.
252 dim_t C[5];
253
254 // A local copy of the offsets buffer. We copy the buffer to make it clear
255 // to the optimizer that the inputs don't alias. This loop is optimized away.
256 dim_t offsets_cpy[5];
257 for (dim_t i = 0; i < numDimsSlice; i++) {
258 offsets_cpy[i] = offset[i];
259 }
260
261 if (numDimsSlice == 5) {
262 for (dim_t x = 0; x < sliceDim[0]; x++)
263 for (dim_t y = 0; y < sliceDim[1]; y++)
264 for (dim_t z = 0; z < sliceDim[2]; z++)
265 for (dim_t w = 0; w < sliceDim[3]; w++)
266 for (dim_t q = 0; q < sliceDim[4]; q++) {
267 C[0] = x + offsets_cpy[0];
268 C[1] = y + offsets_cpy[1];
269 C[2] = z + offsets_cpy[2];
270 C[3] = w + offsets_cpy[3];
271 C[4] = q + offsets_cpy[4];
272 slice[libjit_getXYZWQ(sliceDim, x, y, z, w, q)] =
273 tensor[libjit_getXYZWQ(tensorDim, C[0], C[1], C[2], C[3],
274 C[4])];
275 }
276 return;
277 }
278
279 if (numDimsSlice == 4) {
280 for (dim_t x = 0; x < sliceDim[0]; x++)
281 for (dim_t y = 0; y < sliceDim[1]; y++)
282 for (dim_t z = 0; z < sliceDim[2]; z++)
283 for (dim_t w = 0; w < sliceDim[3]; w++) {
284 C[0] = x + offsets_cpy[0];
285 C[1] = y + offsets_cpy[1];
286 C[2] = z + offsets_cpy[2];
287 C[3] = w + offsets_cpy[3];
288 slice[libjit_getXYZW(sliceDim, x, y, z, w)] =
289 tensor[libjit_getXYZW(tensorDim, C[0], C[1], C[2], C[3])];
290 }
291 return;
292 }
293
294 if (numDimsSlice == 3) {
295 for (dim_t x = 0; x < sliceDim[0]; x++)
296 for (dim_t y = 0; y < sliceDim[1]; y++)
297 for (dim_t z = 0; z < sliceDim[2]; z++) {
298 C[0] = x + offsets_cpy[0];
299 C[1] = y + offsets_cpy[1];
300 C[2] = z + offsets_cpy[2];
301 slice[libjit_getXYZ(sliceDim, x, y, z)] =
302 tensor[libjit_getXYZ(tensorDim, C[0], C[1], C[2])];
303 }
304 return;
305 }
306
307 if (numDimsSlice == 2) {
308 for (dim_t x = 0; x < sliceDim[0]; x++)
309 for (dim_t y = 0; y < sliceDim[1]; y++) {
310 C[0] = x + offsets_cpy[0];
311 C[1] = y + offsets_cpy[1];
312 slice[libjit_getXY(sliceDim, x, y)] =
313 tensor[libjit_getXY(tensorDim, C[0], C[1])];
314 }
315 return;
316 }
317
318 if (numDimsSlice == 1) {
319 for (dim_t x = 0; x < sliceDim[0]; x++) {
320 slice[x] = tensor[x + offsets_cpy[0]];
321 }
322 return;
323 }
324 }
325
326 /// Helper struct for TopK
327 template <typename T, typename TI> struct value_index {
328 TI index;
329 T value;
330 };
331
332 /// Helper function for TopK
333 template <typename T, typename TI>
334 static int value_index_sort(const void *va, const void *vb) {
335 value_index<T, TI> *a = (value_index<T, TI> *)va;
336 value_index<T, TI> *b = (value_index<T, TI> *)vb;
337 if (a->value != b->value)
338 return a->value > b->value ? -1 : 1;
339 return a->index < b->index ? -1 : 1;
340 }
341
342 /// Generic Top-K function. Here, \p scratch is some allocated buffer space, \p
343 /// size is the size of the input, and \p n is the size of the last dimension of
344 /// the input.
345 template <typename T, typename TI>
346 static void libjit_topk(T *values, TI *indices, const T *input, void *scratch,
347 dim_t k, dim_t n, dim_t size) {
348 dim_t in = 0;
349 dim_t out = 0;
350
351 // Initialize scratch with 0.
352 memset(scratch, 0, 2 * n * sizeof(TI));
353
354 value_index<T, TI> *buffer = (value_index<T, TI> *)scratch;
355
356 // Specialize TopK for the case where K is 1.
357 if (k == 1) {
358 while (in < size) {
359 // Find the largest value by iterating over the array instead of calling
360 // 'sort'.
361 value_index<T, TI> mx = {0, input[in]};
362 for (TI i = 1; i < n; i++) {
363 if (input[i + in] > mx.value) {
364 mx = {i, input[i + in]};
365 }
366 }
367 indices[out] = mx.index;
368 values[out] = mx.value;
369 out++;
370 in += n;
371 }
372 return;
373 }
374
375 while (in < size) {
376 for (dim_t i = 0; i < n; i++) {
377 buffer[i].index = i;
378 buffer[i].value = input[in++];
379 }
380 qsort(buffer, n, sizeof(value_index<T, TI>), value_index_sort<T, TI>);
381 for (dim_t i = 0; i < k; i++) {
382 indices[out] = buffer[i].index;
383 values[out] = buffer[i].value;
384 out++;
385 }
386 }
387 }
388
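/// Gather, for each of the \p numSamples samples of \p data (each spanning
/// \p sampleSize elements), the \p numIndices slices of \p sliceSize elements
/// selected by \p indices, writing them contiguously into \p dest.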
389 template <typename T, typename IDX>
390 static void libjit_gather(T *dest, const T *data, const IDX *indices,
391 dim_t numIndices, dim_t sliceSize, dim_t numSamples,
392 dim_t sampleSize) {
393 // The index of the slice that is being written.
394 dim_t outIdx = 0;
395
396 // For each sample in our batch:
397 for (dim_t sample = 0; sample < numSamples; sample++) {
398 dim_t sampleStart = sample * sampleSize;
399
400 // For each slice that we fetch:
401 for (dim_t i = 0; i < numIndices; i++) {
402 dim_t slice = indices[i];
403
404 // Copy the slice.
405 memcpy(dest + outIdx * sliceSize, data + sampleStart + slice * sliceSize,
406 sliceSize * sizeof(T));
407
408 // Point to the next location in the destination tensor.
409 outIdx++;
410 }
411 }
412 }
413
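/// For each of the \p numExamples examples, copy the (start, length) ranges
/// listed in \p ranges from \p data into \p output and record the total
/// gathered length of the example in \p lengths.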
414 template <typename T, typename U>
415 static void libjit_gatherranges(T *output, U *lengths, const T *data,
416 const U *ranges, dim_t numExamples,
417 dim_t exampleSize) {
418 // Indices into the output and range buffers.
419 dim_t outputIdx = 0;
420 dim_t rangesIdx = 0;
421
422 // For each example:
423 for (dim_t example = 0; example < numExamples; ++example) {
424 // Keep track of the total length of the gathered ranges for the example.
425 U totalLen = 0;
426
427 // For each range:
428 for (dim_t range = 0; range < exampleSize; ++range) {
429 // Get the start and length of the range.
430 const U start = ranges[rangesIdx];
431 const U len = ranges[rangesIdx + 1];
432
433 // Copy the specified elements.
434 memcpy(output + outputIdx, data + start, len * sizeof(T));
435
436 // len elements were copied, so increment the output index by len.
437 outputIdx += len;
438
439 // Each range is of the form (start, len), so increment the ranges
440 // index by 2 to get to the next range.
441 rangesIdx += 2;
442
443 // Increment the total length for the example by len.
444 totalLen += len;
445 }
446
447 // Record the total length of gathered ranges for the current example in
448 // the lengths buffer.
449 lengths[example] = totalLen;
450 }
451 }
452
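/// Scatter copy: for each of the \p numIndices indices (each with
/// \p indexSize components), copy one \p sliceSize-element slice from
/// \p slices into \p data at the location addressed by the index.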
453 template <typename T, typename T2>
454 static void libjit_scatterdatacopy(T *data, const dim_t *dataDims,
455 const T2 *indices, const T *slices,
456 dim_t numIndices, dim_t indexSize,
457 dim_t sliceSize) {
458 for (dim_t i = 0; i < numIndices; i++) {
459 dim_t destDataIdx = indices[i * indexSize];
460 for (dim_t j = 1; j < indexSize; j++) {
461 destDataIdx *= dataDims[j];
462 destDataIdx += indices[i * indexSize + j];
463 }
464 memcpy(data + destDataIdx * sliceSize, slices + i * sliceSize,
465 sliceSize * sizeof(T));
466 }
467 }
468
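/// Same as the scatter copy above, but accumulates the slices into \p data
/// instead of overwriting them.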
469 template <typename T, typename T2>
470 static void libjit_scatterdataaddfloat(T *data, const dim_t *dataDims,
471 const T2 *indices, const T *slices,
472 dim_t numIndices, dim_t indexSize,
473 dim_t sliceSize) {
474 for (dim_t i = 0; i < numIndices; i++) {
475 dim_t destDataIdx = indices[i * indexSize];
476 for (dim_t j = 1; j < indexSize; j++) {
477 destDataIdx *= dataDims[j];
478 destDataIdx += indices[i * indexSize + j];
479 }
480 for (dim_t j = 0; j < sliceSize; j++) {
481 data[destDataIdx * sliceSize + j] += slices[i * sliceSize + j];
482 }
483 }
484 }
485
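/// Quantized scatter-add: dequantize both operands using their scale/offset
/// pairs, add them in floating point and requantize the result into \p data.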
486 template <typename T, typename T2>
487 static void libjit_scatterdataaddquantized(T *data, const dim_t *dataDims,
488 const T2 *indices, const T *slices,
489 dim_t numIndices, dim_t indexSize,
490 dim_t sliceSize, float dataScale,
491 int32_t dataOffset, float sliceScale,
492 int32_t sliceOffset) {
493
494 for (size_t i = 0; i < numIndices; i++) {
495 size_t destDataIdx = indices[i * indexSize];
496 for (size_t j = 1; j < indexSize; j++) {
497 destDataIdx *= dataDims[j];
498 destDataIdx += indices[i * indexSize + j];
499 }
500 for (size_t j = 0; j < sliceSize; j++) {
501 float lhs = (data[destDataIdx * sliceSize + j] - dataOffset) * dataScale;
502 float rhs = (slices[i * sliceSize + j] - sliceOffset) * sliceScale;
503 T result = libjit_clip((lhs + rhs) / dataScale + dataOffset);
504 data[destDataIdx * sliceSize + j] = result;
505 }
506 }
507 }
508
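/// Transpose \p inW (shape \p idim) into \p outW (shape \p odim) according to
/// the permutation \p shuffle.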
509 template <typename T>
510 static void libjit_transpose_generic(const T *inW, T *outW, const dim_t *idim,
511 const dim_t *odim, const dim_t *shuffle,
512 dim_t numDims) {
513 // Transpose 2d matrices one tile at a time. This access pattern ensures
514 // that the whole tile is kept in L1 cache. When scanning the whole row at
515 // once we invalidate many cache lines when we touch a single column.
516 const unsigned tileSize = 64;
517
518 // Source coordinate.
519 dim_t SC[5];
520
521 if (numDims == 5) {
522 for (dim_t x = 0; x < odim[0]; x++)
523 for (dim_t y = 0; y < odim[1]; y++)
524 for (dim_t z = 0; z < odim[2]; z++)
525 for (dim_t w = 0; w < odim[3]; w++)
526 for (dim_t q = 0; q < odim[4]; q++) {
527 SC[shuffle[0]] = x;
528 SC[shuffle[1]] = y;
529 SC[shuffle[2]] = z;
530 SC[shuffle[3]] = w;
531 SC[shuffle[4]] = q;
532 outW[libjit_getXYZWQ(odim, x, y, z, w, q)] =
533 inW[libjit_getXYZWQ(idim, SC[0], SC[1], SC[2], SC[3], SC[4])];
534 }
535 return;
536 }
537 if (numDims == 4) {
538 for (dim_t x = 0; x < odim[0]; x++)
539 for (dim_t y = 0; y < odim[1]; y++)
540 for (dim_t z = 0; z < odim[2]; z++)
541 for (dim_t w = 0; w < odim[3]; w++) {
542 SC[shuffle[0]] = x;
543 SC[shuffle[1]] = y;
544 SC[shuffle[2]] = z;
545 SC[shuffle[3]] = w;
546 outW[libjit_getXYZW(odim, x, y, z, w)] =
547 inW[libjit_getXYZW(idim, SC[0], SC[1], SC[2], SC[3])];
548 }
549 return;
550 }
551 if (numDims == 3) {
552 for (dim_t x = 0; x < odim[0]; x++) {
553 // Process the tiles in the innermost two dimensions:
554 for (dim_t sy = 0; sy < odim[1]; sy += tileSize) {
555 for (dim_t sz = 0; sz < odim[2]; sz += tileSize) {
556 // Process the inner tile:
557 for (dim_t y = sy; y < MIN(sy + tileSize, odim[1]); y++) {
558 for (dim_t z = sz; z < MIN(sz + tileSize, odim[2]); z++) {
559 SC[shuffle[0]] = x;
560 SC[shuffle[1]] = y;
561 SC[shuffle[2]] = z;
562 outW[libjit_getXYZ(odim, x, y, z)] =
563 inW[libjit_getXYZ(idim, SC[0], SC[1], SC[2])];
564 }
565 }
566 }
567 }
568 }
569 return;
570 }
571
572 if (numDims == 2) {
573 // Process the tiles in the matrix:
574 for (dim_t sx = 0; sx < odim[0]; sx += tileSize) {
575 for (dim_t sy = 0; sy < odim[1]; sy += tileSize) {
576 // Process the inner tile:
577 for (dim_t x = sx; x < MIN(sx + tileSize, odim[0]); x++) {
578 for (dim_t y = sy; y < MIN(sy + tileSize, odim[1]); y++) {
579 SC[shuffle[0]] = x;
580 SC[shuffle[1]] = y;
581 outW[libjit_getXY(odim, x, y)] =
582 inW[libjit_getXY(idim, SC[0], SC[1])];
583 }
584 }
585 }
586 }
587 return;
588 }
589 }
590
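/// Reverse \p inW along dimension \p axis and write the result to \p outW.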
591 template <typename T>
592 static void libjit_flip_generic(const T *inW, T *outW, const dim_t *dims,
593 dim_t axis, dim_t numDims) {
594
595 // Product of outer dimensions excluding the flip dimension.
596 dim_t outerLen = 1;
597 for (dim_t idx = 0; idx < axis; idx++) {
598 outerLen *= dims[idx];
599 }
600
601 // Flip dimension.
602 dim_t len = dims[axis];
603
604 // Product of inner dimensions excluding the flip dimension.
605 dim_t innerLen = 1;
606 for (dim_t idx = axis + 1; idx < numDims; idx++) {
607 innerLen *= dims[idx];
608 }
609
610 // Flip axis such that input data is read linearly.
611 const T *inpPtr = inW;
612 T *outPtr = outW + (len - 1) * innerLen;
613 for (dim_t outerIdx = 0; outerIdx < outerLen; outerIdx++) {
614 for (dim_t idx = 0; idx < len; idx++) {
615 for (dim_t innerIdx = 0; innerIdx < innerLen; innerIdx++) {
616 *outPtr++ = *inpPtr++;
617 }
618 outPtr -= 2 * innerLen;
619 }
620 outPtr += 2 * len * innerLen;
621 }
622 }
623
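/// Write to \p outW the index of the maximum element along dimension \p axis
/// of \p inpW.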
624 template <typename inpT, typename outT>
625 static void libjit_arg_max_generic(const inpT *inpW, outT *outW,
626 const dim_t *dims, size_t numDims,
627 size_t axis) {
628
629 // Product of outer dimensions excluding the axis dimension.
630 dim_t outerLen = 1;
631 for (dim_t idx = 0; idx < axis; ++idx) {
632 outerLen *= dims[idx];
633 }
634
635 // Axis dimension length.
636 dim_t axisLen = dims[axis];
637
638 // Product of inner dimensions excluding the axis dimension.
639 dim_t innerLen = 1;
640 for (dim_t idx = axis + 1; idx < numDims; ++idx) {
641 innerLen *= dims[idx];
642 }
643
644 // Traverse data such that output is written linearly.
645 const inpT *inpPtr = inpW;
646 outT *outPtr = outW;
647 for (dim_t outerIdx = 0; outerIdx < outerLen; ++outerIdx) {
648 for (dim_t innerIdx = 0; innerIdx < innerLen; ++innerIdx) {
649 inpT maxVal = std::numeric_limits<inpT>::lowest();
650 outT maxIdx = 0;
651 for (dim_t axisIdx = 0; axisIdx < axisLen; ++axisIdx) {
652 inpT inpVal = *inpPtr;
653 if (inpVal > maxVal) {
654 maxVal = inpVal;
655 maxIdx = axisIdx;
656 }
657 inpPtr += innerLen;
658 }
659 inpPtr = inpPtr - axisLen * innerLen + 1;
660 *outPtr++ = maxIdx;
661 }
662 inpPtr = inpPtr - innerLen + axisLen * innerLen;
663 }
664 }
665
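/// Write to \p outW the index of the minimum element along dimension \p axis
/// of \p inpW.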
666 template <typename inpT, typename outT>
667 static void libjit_arg_min_generic(const inpT *inpW, outT *outW,
668 const dim_t *dims, size_t numDims,
669 size_t axis) {
670
671 // Product of outer dimensions excluding the axis dimension.
672 dim_t outerLen = 1;
673 for (dim_t idx = 0; idx < axis; ++idx) {
674 outerLen *= dims[idx];
675 }
676
677 // Axis dimension length.
678 dim_t axisLen = dims[axis];
679
680 // Product of inner dimensions excluding the axis dimension.
681 dim_t innerLen = 1;
682 for (dim_t idx = axis + 1; idx < numDims; ++idx) {
683 innerLen *= dims[idx];
684 }
685
686 // Traverse data such that output is written linearly.
687 const inpT *inpPtr = inpW;
688 outT *outPtr = outW;
689 for (dim_t outerIdx = 0; outerIdx < outerLen; ++outerIdx) {
690 for (dim_t innerIdx = 0; innerIdx < innerLen; ++innerIdx) {
691 inpT minVal = std::numeric_limits<inpT>::max();
692 outT minIdx = 0;
693 for (dim_t axisIdx = 0; axisIdx < axisLen; ++axisIdx) {
694 inpT inpVal = *inpPtr;
695 if (inpVal < minVal) {
696 minVal = inpVal;
697 minIdx = axisIdx;
698 }
699 inpPtr += innerLen;
700 }
701 inpPtr = inpPtr - axisLen * innerLen + 1;
702 *outPtr++ = minIdx;
703 }
704 inpPtr = inpPtr - innerLen + axisLen * innerLen;
705 }
706 }
707
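/// Max pooling over the NHWC input \p inW into \p outW using the given
/// \p kernelSizes, \p strides and \p pads.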
708 template <typename T>
709 static void libjit_max_pool_generic(const T *inW, T *outW, const dim_t *inWdims,
710 const dim_t *outWdims, dim_t *kernelSizes,
711 dim_t *strides, dim_t *pads) {
712 dim_t pad_t = pads[0];
713 dim_t pad_l = pads[1];
714 dim_t stride_h = strides[0];
715 dim_t stride_w = strides[1];
716 dim_t kernel_h = kernelSizes[0];
717 dim_t kernel_w = kernelSizes[1];
718 // For each sample in the batch:
719 for (dim_t n = 0; n < outWdims[0]; n++) {
720 // For each (x,y) step in the input/output tensor:
721 sdim_t x = -(sdim_t)pad_t;
722 for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
723 sdim_t y = -(sdim_t)pad_l;
724 for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {
725
726 // For each layer in the output tensor:
727 for (dim_t z = 0; z < inWdims[3]; z++) {
728 int first = 1;
729 T max = 0;
730
731 // For each element in the pool filter:
732 for (dim_t fx = 0; fx < kernel_h; fx++) {
733 for (dim_t fy = 0; fy < kernel_w; fy++) {
734 sdim_t ox = x + fx;
735 sdim_t oy = y + fy;
736
737 // Ignore index access below zero (this is due to padding).
738 if (ox < 0 || oy < 0 || ox >= (sdim_t)inWdims[1] ||
739 oy >= (sdim_t)inWdims[2]) {
740 continue;
741 }
742
743 float val =
744 inW[libjit_getXYZW(inWdims, n, (dim_t)ox, (dim_t)oy, z)];
745
746 if (first || (val >= max)) {
747 first = 0;
748 max = val;
749 }
750 }
751 }
752
753 outW[libjit_getXYZW(outWdims, n, ax, ay, z)] = max;
754 } // C
755 } // W
756 } // H
757 } // N
758 }
759
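/// Max pooling that additionally records in \p argmax the flat NHWC index of
/// the input element selected for every output location.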
760 template <typename T, typename T2>
761 static void
762 libjit_max_pool_argmax_generic(const T *inW, T *outW, T2 *argmax,
763 const dim_t *inWdims, const dim_t *outWdims,
764 dim_t *kernels, dim_t *strides, dim_t *pads) {
765 dim_t pad_t = pads[0];
766 dim_t pad_l = pads[1];
767 dim_t stride_h = strides[0];
768 dim_t stride_w = strides[1];
769 dim_t kernel_h = kernels[0];
770 dim_t kernel_w = kernels[1];
771 // For each input in the batch:
772 for (dim_t n = 0; n < outWdims[0]; n++) {
773
774 // For each (x,y) step in the input/output tensor:
775 sdim_t x = -(sdim_t)pad_t;
776 for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
777 sdim_t y = -(sdim_t)pad_l;
778 for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {
779
780 // For each channel in the output tensor:
781 for (dim_t z = 0; z < outWdims[3]; z++) {
782 int64_t argmaxNHWC = 0;
783 int first = 1;
784 T max = 0;
785
786 for (dim_t kx = 0; kx < kernel_h; kx++) {
787 for (dim_t ky = 0; ky < kernel_w; ky++) {
788 sdim_t ox = x + kx;
789 sdim_t oy = y + ky;
790
791 if (ox < 0 || oy < 0 || ox >= (sdim_t)inWdims[1] ||
792 oy >= (sdim_t)inWdims[2]) {
793 continue;
794 }
795 const dim_t flatIndex =
796 libjit_getXYZW(inWdims, n, (dim_t)ox, (dim_t)oy, z);
797 T val = inW[flatIndex];
798 if (first || (val >= max)) {
799 first = 0;
800 max = val;
801 argmaxNHWC = flatIndex;
802 }
803 }
804 }
805
806 const dim_t flatIndex = libjit_getXYZW(outWdims, n, ax, ay, z);
807 outW[flatIndex] = max;
808 argmax[flatIndex] = argmaxNHWC;
809 } // C
810 } // W
811 } // H
812 } // N
813 }
814
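/// Nearest-neighbor resize of the NHWC tensor \p src into \p dst using the
/// per-dimension \p scale factors.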
815 template <typename T>
816 void libjit_resizenearest_generic(T *dst, const T *src, const float *scale,
817 const dim_t *inWdims, const dim_t *outWdims) {
818
819 for (dim_t ob = 0; ob < outWdims[0]; ++ob) {
820 auto ib = std::min(dim_t(ob / (scale[0])), inWdims[0] - 1);
821 for (dim_t oh = 0; oh < outWdims[1]; ++oh) {
822 auto ih = std::min(dim_t(oh / (scale[1])), inWdims[1] - 1);
823 for (dim_t ow = 0; ow < outWdims[2]; ++ow) {
824 auto iw = std::min(dim_t(ow / (scale[2])), inWdims[2] - 1);
825 for (dim_t oc = 0; oc < outWdims[3]; ++oc) {
826 auto ic = std::min(dim_t(oc / (scale[3])), inWdims[3] - 1);
827 const dim_t inIndex = libjit_getXYZW(inWdims, ib, ih, iw, ic);
828 const dim_t outIndex = libjit_getXYZW(outWdims, ob, oh, ow, oc);
829 dst[outIndex] = src[inIndex];
830 }
831 }
832 }
833 }
834 }
835
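/// Bilinear resize of the NHWC tensor \p src into \p dst, interpolating
/// between the four nearest input pixels in the H/W plane.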
836 template <typename T>
837 static void
838 libjit_resizebilinear_generic(T *dst, const T *src, const float *scale,
839 const dim_t *inWdims, const dim_t *outWdims) {
840 for (dim_t ob = 0; ob < outWdims[0]; ++ob) {
841 for (dim_t oh = 0; oh < outWdims[1]; ++oh) {
842 for (dim_t ow = 0; ow < outWdims[2]; ++ow) {
843 float ihf = oh / scale[1];
844 float iwf = ow / scale[2];
845 dim_t ih = dim_t(ihf);
846 dim_t iw = dim_t(iwf);
847
848 auto ih0 = std::min(ih, inWdims[1] - 1);
849 auto ih1 = std::min(ih + 1, inWdims[1] - 1);
850 auto iw0 = std::min(iw, inWdims[2] - 1);
851 auto iw1 = std::min(iw + 1, inWdims[2] - 1);
852
853 for (dim_t oc = 0; oc < outWdims[3]; ++oc) {
854 float v00 = src[libjit_getXYZW(inWdims, ob, ih0, iw0, oc)];
855 float v01 = src[libjit_getXYZW(inWdims, ob, ih0, iw1, oc)];
856 float v10 = src[libjit_getXYZW(inWdims, ob, ih1, iw0, oc)];
857 float v11 = src[libjit_getXYZW(inWdims, ob, ih1, iw1, oc)];
858
859 float hd = v00 + (v10 - v00) * (ihf - ih);
860 float hw = v01 + (v11 - v01) * (ihf - ih);
861 float result = hd + (hw - hd) * (iwf - iw);
862 dst[libjit_getXYZW(outWdims, ob, oh, ow, oc)] = result;
863 }
864 }
865 }
866 }
867 }
868
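/// Quantized BatchedAdd: rescale each element of \p batch and of the
/// broadcast \p slice into the destination scale, add them and clip the
/// result into \p dest.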
869 template <typename T>
870 static void
871 libjit_batchedadd_quantized(int8_t *dest, const int8_t *batch, const T *slice,
872 dim_t numSlice, dim_t sliceSize, int32_t destOffset,
873 int32_t batchOffset, int32_t sliceOffset,
874 int32_t batchPre, int32_t batchPost,
875 int32_t batchScale, int32_t slicePre,
876 int32_t slicePost, int32_t sliceScale) {
877 for (dim_t n = 0; n < numSlice; n++) {
878 dim_t base = n * sliceSize;
879 for (dim_t i = 0; i < sliceSize; i++) {
880 int32_t b = batch[base + i] - batchOffset;
881 int32_t s = slice[i] - sliceOffset;
882 int32_t x = libjit_scale_i32i8(b, batchPre, batchPost, batchScale, 0);
883 int32_t y = libjit_scale_i32i8(s, slicePre, slicePost, sliceScale, 0);
884 dest[base + i] = libjit_clip(x + y + destOffset);
885 }
886 }
887 }
888
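/// Find the minimum and maximum values of the \p size elements of \p tensor
/// and return them through \p min and \p max.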
889 static void find_min_max_f(float *tensor, dim_t size, float &min, float &max) {
890 min = tensor[0];
891 max = tensor[0];
892
893 for (dim_t i = 1; i < size; ++i) {
894 float tensorVal = tensor[i];
895 if (tensorVal < min)
896 min = tensorVal;
897
898 if (tensorVal > max)
899 max = tensorVal;
900
901 // Sanity check for NaN and Infinity.
902 assert(!std::isnan(tensor[i]) && "NaN value found!");
903 assert(!std::isinf(tensor[i]) && "Infinity value found!");
904 }
905 }
906
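/// Return 1 if all \p size elements of \p arrayToCheck are zero, 0 otherwise.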
907 static int check_all_zeros(float *arrayToCheck, dim_t size) {
908 for (dim_t i = 0; i < size; ++i) {
909 if (arrayToCheck[i] != 0) {
910 return 0;
911 }
912 }
913 return 1;
914 }
915
916 /// Get the bin number in which to insert \p value, for a histogram that
917 /// has \p nBins bins of width \p binWidth starting at \p minValue.
918 static dim_t get_bin(dim_t nBins, float binWidth, float minValue, float value) {
919 dim_t result =
920 binWidth == 0
921 ? 0
922 : MIN(static_cast<dim_t>((value - minValue) / binWidth), nBins - 1);
923 return result;
924 }
925
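/// SpaceToDepth for NHWC tensors: move blockSize x blockSize spatial blocks
/// of \p inPtr into the depth dimension of \p outPtr.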
926 template <typename T>
927 static void libjit_space_to_depth_generic(const T *inPtr, T *outPtr,
928 dim_t blockSize, const dim_t *inDims,
929 const dim_t *outDims) {
930 dim_t inHeight = inDims[1];
931 dim_t inWidth = inDims[2];
932 dim_t inDepth = inDims[3];
933
934 dim_t outBatch = outDims[0];
935 dim_t outHeight = outDims[1];
936 dim_t outWidth = outDims[2];
937 dim_t outDepth = outDims[3];
938
939 for (dim_t b = 0; b < outBatch; ++b) {
940 for (dim_t h = 0; h < outHeight; ++h) {
941 for (dim_t w = 0; w < outWidth; ++w) {
942 for (dim_t c = 0; c < outDepth; ++c) {
943 // NHWC
944 // c +
945 // w * outDepth +
946 // h * outDepth * outWidth +
947 // b * outDepth * outWidth * outHeight
948 dim_t outIndex = c + outDepth * (w + outWidth * (h + b * outHeight));
949
950 // The block layer (within the output depth) that this channel belongs to.
951 dim_t blockDepthLayer = c / inDepth;
952 // Horizontal offset inside the block; wraps every blockSize layers.
953 dim_t iw = w * blockSize + blockDepthLayer % blockSize;
954 // Vertical offset inside the block; advances every blockSize layers.
955 dim_t ih = h * blockSize + blockDepthLayer / blockSize;
956 // Index into the input depth; resets every inDepth channels.
957 dim_t id = c % inDepth;
958
959 dim_t inIndex = id + inDepth * (iw + inWidth * (ih + b * inHeight));
960 outPtr[outIndex] = inPtr[inIndex];
961 }
962 }
963 }
964 }
965 }
966
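/// Copy \p srcPtr into \p dstPtr element by element, converting every element
/// from SrcType to DstType.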
967 template <typename DstType, typename SrcType>
968 static void
969 libjit_copy_kernel_with_conversion(DstType *dstPtr, const SrcType *srcPtr,
970 const dim_t *dims, dim_t numDims) {
971 dim_t dimSize = 1;
972 for (dim_t i = 0; i < numDims; ++i) {
973 dimSize *= dims[i];
974 }
975
976 for (dim_t i = 0; i < dimSize; ++i) {
977 dstPtr[i] = DstType(srcPtr[i]);
978 }
979 }
980
981 /// The dimensions passed in here are pre-expanded in LLVMIRGen with 1s so that
982 /// we can iterate over the shape here, regardless of the shape of the tensor.
983 template <typename T>
984 static void libjit_reducemin(T *dest, const T *batch, size_t destSize,
985 const dim_t *destDims, const dim_t *batchDims,
986 T init) {
987 for (dim_t i = 0; i < destSize; i++) {
988 dest[i] = init;
989 }
990
991 unsigned int axis[6];
992 for (dim_t i = 0; i < 6; i++) {
993 axis[i] = (destDims[i] > 1);
994 }
995
996 for (dim_t x = 0, dx = 0; x < batchDims[0]; x++, dx += axis[0]) {
997 for (dim_t y = 0, dy = 0; y < batchDims[1]; y++, dy += axis[1]) {
998 for (dim_t z = 0, dz = 0; z < batchDims[2]; z++, dz += axis[2]) {
999 for (dim_t w = 0, dw = 0; w < batchDims[3]; w++, dw += axis[3]) {
1000 for (dim_t q = 0, dq = 0; q < batchDims[4]; q++, dq += axis[4]) {
1001 for (dim_t r = 0, dr = 0; r < batchDims[5]; r++, dr += axis[5]) {
1002 T fdest =
1003 dest[libjit_getXYZWQR(destDims, dx, dy, dz, dw, dq, dr)];
1004 T fnew = batch[libjit_getXYZWQR(batchDims, x, y, z, w, q, r)];
1005 dest[libjit_getXYZWQR(destDims, dx, dy, dz, dw, dq, dr)] =
1006 std::min(fdest, fnew);
1007 }
1008 }
1009 }
1010 }
1011 }
1012 }
1013 }
1014
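/// Cross entropy loss: accumulate -log(P[n][labels[n]]) over the batch into
/// CE[0].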
1015 template <typename T, typename T2>
1016 static void libjit_cross_entropy_loss_generic(T *CE, T *P, T2 *labels,
1017 dim_t *dims) {
1018 CE[0] = 0.0;
1019 for (dim_t n = 0; n < dims[0]; ++n) {
1020 auto y = labels[n];
1021 auto p_n = P[libjit_getXY(dims, n, y)];
1022 CE[0] -= log(p_n);
1023 }
1024 }
1025
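/// SparseLengthsSum: for each of the \p segments segments, sum the
/// lengths[i] rows of \p data selected by \p indices into \p dest.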
1026 template <typename T, typename T2>
1027 static void libjit_sparse_lengths_sum_generic(T *dest, T *data, T2 *indices,
1028 int32_t *lengths, dim_t segments,
1029 dim_t lineSize) {
1030 memset(dest, 0, segments * lineSize * sizeof(float));
1031 dim_t curIndex = 0;
1032 for (dim_t i = 0; i < segments; i++) {
1033 for (int32_t j = 0; j < lengths[i]; j++) {
1034 dim_t line = indices[curIndex];
1035 for (dim_t k = 0; k < lineSize; k++) {
1036 dest[i * lineSize + k] += data[line * lineSize + k];
1037 }
1038 curIndex++;
1039 }
1040 }
1041 }
1042
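/// Weighted variant of SparseLengthsSum: each gathered row is scaled by the
/// corresponding entry of \p weights before being accumulated.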
1043 template <typename T, typename T2>
1044 static void
1045 libjit_sparse_lengths_weighted_sum_generic(T *dest, T *data, float *weights,
1046 T2 *indices, int32_t *lengths,
1047 dim_t segments, dim_t lineSize) {
1048 memset(dest, 0, segments * lineSize * sizeof(float));
1049 dim_t curIndex = 0;
1050 for (dim_t i = 0; i < segments; i++) {
1051 for (int32_t j = 0; j < lengths[i]; j++) {
1052 float weight = weights[curIndex];
1053 dim_t line = indices[curIndex];
1054 for (dim_t k = 0; k < lineSize; k++) {
1055 dest[i * lineSize + k] += weight * data[line * lineSize + k];
1056 }
1057 curIndex++;
1058 }
1059 }
1060 }
1061
1062 template <typename T, typename T2>
1063 static void libjit_sparse_lengths_weighted_sum_grad_generic(
1064 const T *destGrad, T *dataGrad, T *weightsGrad, const T *data,
1065 const T *weights, const T2 *indices, const int32_t *lengths, dim_t segments,
1066 dim_t lineSize, dim_t dataGradRawSize) {
1067 // The data gradients not touched by this operation should
1068 // be 0, so set the entire buffer to 0 to start with.
1069 memset(dataGrad, 0, dataGradRawSize);
1070
1071 for (dim_t i = 0, curIndex = 0; i < segments; ++i) {
1072 for (int32_t j = 0; j < lengths[i]; ++j, ++curIndex) {
1073 // For each index in each segment:
1074 // 1) accumulate into the corresponding data gradient the product of
1075 // the gradient of the result it was added to and the weight that it
1076 // was multiplied by during the SparseLengthsWeightedSum operation.
1077 //
1078 // 2) accumulate into each weight gradient the reduced sum of the
1079 // elementwise product of the result slice that the corresponding
1080 // weight produced and the input slice that the weight was multiplied
1081 // with.
1082 float weightGrad = 0.0f;
1083 float weight = weights[curIndex];
1084 dim_t line = indices[curIndex];
1085 for (dim_t k = 0; k < lineSize; ++k) {
1086 dataGrad[line * lineSize + k] += weight * destGrad[i * lineSize + k];
1087 weightGrad += destGrad[i * lineSize + k] * data[line * lineSize + k];
1088 }
1089 weightsGrad[curIndex] = weightGrad;
1090 }
1091 }
1092 }
1093
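/// Row-wise quantized SparseLengthsWeightedSum: the rows of \p data are
/// stored as uint8 and dequantized with the per-row \p scales and \p offsets
/// before accumulation.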
1094 template <typename T, typename T2>
1095 static void libjit_rowwise_quantized_sparse_lengths_weighted_sum_generic(
1096 T *dest, uint8_t *data, T *scales, T *offsets, T *weights, T2 *indices,
1097 int32_t *lengths, dim_t segments, dim_t lineSize) {
1098 memset(dest, 0, segments * lineSize * sizeof(float));
1099 dim_t curIndex = 0;
1100 for (dim_t i = 0; i < segments; i++) {
1101 for (int32_t j = 0; j < lengths[i]; j++) {
1102 const float weight = weights[curIndex];
1103 const dim_t line = indices[curIndex];
1104 const float scale = scales[line];
1105 const float offset = offsets[line];
1106 for (dim_t k = 0; k < lineSize; k++) {
1107 const float fData = scale * data[line * lineSize + k] + offset;
1108 dest[i * lineSize + k] += weight * fData;
1109 }
1110 curIndex++;
1111 }
1112 }
1113 }
1114
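/// Fused row-wise quantized variant: the per-row scale and offset are stored
/// as two floats at the end of each row of \p data.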
1115 template <typename T, typename T2>
1116 static void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_generic(
1117 T *dest, int8_t *data, T *weights, T2 *indices, int32_t *lengths,
1118 dim_t segments, dim_t inLineSize, dim_t outLineSize) {
1119 memset(dest, 0, segments * outLineSize * sizeof(float));
1120 dim_t curIndex = 0;
1121 for (dim_t i = 0; i < segments; i++) {
1122 for (int32_t j = 0, e = lengths[i]; j < e; j++) {
1123 const float weight = weights[curIndex];
1124 const dim_t line = indices[curIndex];
1125 const int8_t *currRowScaleOffsetPtr =
1126 data + ((line + 1) * inLineSize) - 2 * sizeof(float);
1127 float scale, offset;
1128 memcpy(&scale, currRowScaleOffsetPtr, sizeof(float));
1129 memcpy(&offset, currRowScaleOffsetPtr + sizeof(float), sizeof(float));
1130 for (dim_t k = 0; k < outLineSize; k++) {
1131 const float fData =
1132 (scale * (uint8_t)(data[line * inLineSize + k])) + offset;
1133 dest[i * outLineSize + k] += weight * fData;
1134 }
1135 curIndex++;
1136 }
1137 }
1138 }
1139
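/// SparseToDense: accumulate each \p valueSize-element row of \p values into
/// \p dest at the row selected by the corresponding index.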
1140 template <typename T, typename T2>
1141 static void libjit_sparse_to_dense_generic(T *dest, const T2 *indices,
1142 const T *values, dim_t numIndices,
1143 dim_t destSize, dim_t valueSize) {
1144 memset(dest, 0, destSize * sizeof(float));
1145
1146 for (dim_t i = 0, valuesOffset = 0; i < numIndices;
1147 ++i, valuesOffset += valueSize) {
1148 dim_t idx = indices[i];
1149 dim_t destOffset = idx * valueSize;
1150
1151 for (size_t j = 0; j < valueSize; ++j) {
1152 dest[destOffset + j] += values[valuesOffset + j];
1153 }
1154 }
1155 }
1156
1157 struct ClassBox {
1158 float score{0.0f};
1159 size_t index{0};
1160 };
1161
1162 struct Box {
1163 float v0{0.0f};
1164 float v1{0.0f};
1165 float v2{0.0f};
1166 float v3{0.0f};
1167 };
1168
1169 struct OutBox {
1170 float classValue{0.0f};
1171 size_t batchIndex{0};
1172 size_t classIndex{0};
1173 size_t boxIndex{0};
1174 };
1175
1176 static void maxMin(float lhs, float rhs, float &min, float &max) {
1177 if (lhs >= rhs) {
1178 min = rhs;
1179 max = lhs;
1180 } else {
1181 min = lhs;
1182 max = rhs;
1183 }
1184 }
1185
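/// Return true if the intersection-over-union of boxes \p sb and \p cb
/// exceeds \p iouThreshold. \p centerPointBox selects between the corner and
/// the center/size box encodings.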
1186 static bool checkIOU(const Box &sb, const Box &cb, float iouThreshold,
1187 size_t centerPointBox) {
1188 float xSMin = 0.0f;
1189 float ySMin = 0.0f;
1190 float xSMax = 0.0f;
1191 float ySMax = 0.0f;
1192
1193 float xCMin = 0.0f;
1194 float yCMin = 0.0f;
1195 float xCMax = 0.0f;
1196 float yCMax = 0.0f;
1197
1198 // Standardize coordinates so that (xmin, ymin) is the upper left corner of a
1199 // box and (xmax, ymax) is the lower right corner of the box.
1200 if (!centerPointBox) {
1201 // 0 means coordinates for diagonal ends of a box.
1202 // Coordinates can either be absolute or normalized.
1203 maxMin(sb.v0, sb.v2, xSMin, xSMax);
1204 maxMin(sb.v1, sb.v3, ySMin, ySMax);
1205
1206 maxMin(cb.v0, cb.v2, xCMin, xCMax);
1207 maxMin(cb.v1, cb.v3, yCMin, yCMax);
1208 } else {
1209 float halfWidthS = sb.v2 / 2.0f;
1210 float halfHeightS = sb.v3 / 2.0f;
1211 float halfWidthC = cb.v2 / 2.0f;
1212 float halfHeightC = cb.v3 / 2.0f;
1213
1214 xSMin = sb.v0 - halfWidthS;
1215 ySMin = sb.v1 - halfHeightS;
1216 xSMax = sb.v0 + halfWidthS;
1217 ySMax = sb.v1 + halfHeightS;
1218
1219 xCMin = cb.v0 - halfWidthC;
1220 yCMin = cb.v1 - halfHeightC;
1221 xCMax = cb.v0 + halfWidthC;
1222 yCMax = cb.v1 + halfHeightC;
1223 }
1224
1225 // Find the upper left and lower right corners of the intersection box.
1226 float xMin = MAX(xSMin, xCMin);
1227 float yMin = MAX(ySMin, yCMin);
1228 float xMax = MIN(xSMax, xCMax);
1229 float yMax = MIN(ySMax, yCMax);
1230
1231 float intersectionArea = MAX((0.0f), xMax - xMin) * MAX((0.0f), yMax - yMin);
1232
1233 if (intersectionArea == 0.0f) {
1234 return false;
1235 }
1236
1237 float sArea = (xSMax - xSMin) * (ySMax - ySMin);
1238 float cArea = (xCMax - xCMin) * (yCMax - yCMin);
1239 float unionArea = sArea + cArea - intersectionArea;
1240
1241 return intersectionArea > iouThreshold * unionArea;
1242 }
1243
1244 // ONNX
1245 // Class/Score [BatchNum][ClassNum][BoxNum]
1246 // Box [BatchNum][BoxNum][4]
1247 // Result [BatchNum*MaxOutputPerBatch][3]
1248 // V4
1249 // Class/Score [BatchNum][BoxNum]
1250 // Boxes [BatchNum][BoxNum][4]
1251 // Result [BatchNum*MaxOutputPerBatch]
1252 // NumberOfIndicesDetected [BatchNum*MaxOutputPerBatch]
1253 template <typename T>
1254 static void
1255 libjit_nms_generic(T *indices, T *numDetected, const float *boxTensor,
1256 const dim_t *boxTensorDims, dim_t boxTensorDimSize,
1257 const float *scoresTensor, const dim_t *scoresTensorDims,
1258 dim_t scoresTensorDimSize, const dim_t *resultTensorDims,
1259 dim_t resultTensorDimSize, unsigned centerPointBox,
1260 unsigned maxOutputBoxesPerClass, float iouThreshold,
1261 float scoreThreshold, bool isV4) {
1262 int boxesBoxDim = boxTensorDimSize - 2;
1263
1264 size_t numBatches = 1;
1265 size_t numClasses = 1;
1266 size_t numBoxes = boxTensorDims[boxesBoxDim];
1267
1268 size_t maxOutputPerBatch = 0;
1269 if (!isV4) {
1270 int boxesBatchDim = boxTensorDimSize - 3;
1271 int scoresBatchDim = scoresTensorDimSize - 3;
1272
1273 int scoresBoxDim = scoresTensorDimSize - 1;
1274 int scoresClassDim = scoresTensorDimSize - 2;
1275
1276 assert(scoresTensorDims[scoresBoxDim] == boxTensorDims[boxesBoxDim] &&
1277 "Mismatch between number of scores and number of boxes.");
1278 assert(scoresTensorDims[scoresBatchDim] == boxTensorDims[boxesBatchDim] &&
1279 "Scores and Box Batch Dimensions don't match.");
1280 (void)boxesBatchDim;
1281 (void)scoresBoxDim;
1282 numBatches = scoresTensorDims[scoresBatchDim];
1283 numClasses = scoresTensorDims[scoresClassDim];
1284 numBoxes = boxTensorDims[boxesBoxDim];
1285 maxOutputPerBatch = resultTensorDims[resultTensorDimSize - 2] / numBatches;
1286 } else {
1287 maxOutputPerBatch = resultTensorDims[resultTensorDimSize - 1] / numBatches;
1288 }
1289
1290 static_assert(sizeof(Box) == 4 * sizeof(float),
1291 "Can't reinterpret raw float data as a Box.");
1292 const Box *boxes = reinterpret_cast<const Box *>(boxTensor);
1293
1294 auto cmpFunc = [](const ClassBox &cb1, const ClassBox &cb2) -> bool {
1295 return cb1.score > cb2.score;
1296 };
1297
1298 size_t outPutBoxIndex = 0;
1299 for (size_t batchIndex = 0; batchIndex < numBatches; ++batchIndex) {
1300 int32_t detectedPerBatch = 0;
1301 OutBox minBox{scoresTensor[batchIndex * numClasses], batchIndex, 0, 0};
1302 for (size_t classIndex = 0; classIndex < numClasses; ++classIndex) {
1303 ClassBox selectedIndices[numBoxes];
1304 ClassBox potentialBoxes[numBoxes];
1305 size_t indexPBoxes = 0;
1306 const float *currClass =
1307 &scoresTensor[(batchIndex * numClasses + classIndex) * numBoxes];
1308 for (size_t boxIndex = 0; boxIndex < numBoxes; ++boxIndex) {
1309 float classScore = currClass[boxIndex];
1310 if (classScore > scoreThreshold) {
1311 ClassBox &b = potentialBoxes[indexPBoxes++];
1312 b.score = classScore;
1313 b.index = boxIndex;
1314 }
1315 }
1316
1317 std::sort(potentialBoxes, potentialBoxes + indexPBoxes, cmpFunc);
1318
1319 size_t indexSBoxes = 0;
1320 size_t detectedPerClass = 0;
1321 float tScore = minBox.classValue;
1322 for (unsigned int i = 0; i < indexPBoxes; ++i) {
1323 ClassBox &pbI = potentialBoxes[i];
1324 const Box &potentialBox = boxes[batchIndex * numBoxes + pbI.index];
1325 bool selected = true;
1326 for (unsigned int j = 0; j < indexSBoxes && selected; ++j) {
1327 ClassBox &sbI = selectedIndices[j];
1328 const Box &selectedBox = boxes[batchIndex * numBoxes + sbI.index];
1329 selected = !checkIOU(selectedBox, potentialBox, iouThreshold,
1330 centerPointBox);
1331 }
1332
1333 if (selected) {
1334 selectedIndices[indexSBoxes++] = pbI;
1335 if (isV4) {
1336 indices[outPutBoxIndex] = pbI.index;
1337 } else {
1338 indices[outPutBoxIndex * 3 + 0] = batchIndex;
1339 indices[outPutBoxIndex * 3 + 1] = classIndex;
1340 indices[outPutBoxIndex * 3 + 2] = pbI.index;
1341 }
1342
1343 tScore = pbI.score;
1344 ++outPutBoxIndex;
1345 ++detectedPerClass;
1346 ++detectedPerBatch;
1347 }
1348
1349 if (detectedPerClass == maxOutputBoxesPerClass) {
1350 break;
1351 }
1352 }
1353
1354 if (tScore < minBox.classValue) {
1355 minBox.classValue = tScore;
1356 if (isV4) {
1357 minBox.boxIndex = indices[outPutBoxIndex - 1];
1358 } else {
1359 minBox.boxIndex = indices[(outPutBoxIndex - 1) * 3 + 2];
1360 }
1361 minBox.classIndex = classIndex;
1362 }
1363 }
1364
1365 // Fill the rest of the output for this batch with the minimum box.
1366 for (size_t i = detectedPerBatch; i < maxOutputPerBatch; ++i) {
1367 if (isV4) {
1368 indices[outPutBoxIndex] = minBox.boxIndex;
1369 } else {
1370 indices[outPutBoxIndex * 3 + 0] = minBox.batchIndex;
1371 indices[outPutBoxIndex * 3 + 1] = minBox.classIndex;
1372 indices[outPutBoxIndex * 3 + 2] = minBox.boxIndex;
1373 }
1374
1375 ++outPutBoxIndex;
1376 }
1377 // Not used for ONNX NMS; for the TF variant the batch dimension is 1.
1378 for (size_t i = 0; i < maxOutputBoxesPerClass; ++i) {
1379 numDetected[batchIndex * maxOutputBoxesPerClass + i] = detectedPerBatch;
1380 }
1381 }
1382 }
1383
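/// Gradient of softmax with cross entropy: inG[n][i] = outW[n][i] - 1 for the
/// selected class of sample n, and outW[n][i] otherwise.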
1384 template <typename T, typename T2>
1385 void libjit_softmax_grad_generic(T *inG, T *outW, const T2 *selectedW,
1386 const dim_t *idim, const dim_t *selectdim) {
1387 for (dim_t n = 0; n < idim[0]; n++) {
1388 for (dim_t i = 0; i < idim[1]; i++) {
1389 float delta = (selectedW[libjit_getXY(selectdim, n, 0)] == i);
1390 inG[libjit_getXY(idim, n, i)] = outW[libjit_getXY(idim, n, i)] - delta;
1391 }
1392 }
1393 }
1394
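/// Backward pass of max pooling using the saved \p argmax indices: each
/// output gradient is routed to the input element that produced the maximum.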
1395 template <typename T, typename T2>
1396 void libjit_max_pool_argmax_grad_generic(T *inG, const T *outG,
1397 const T2 *argmax, const dim_t *inGdims,
1398 const dim_t *outWdims) {
1399 // NHWC format is assumed
1400 for (dim_t n = 0; n < outWdims[0]; n++) {
1401 for (dim_t z = 0; z < outWdims[3]; z++) {
1402 // Clear inG
1403 for (dim_t x = 0; x < inGdims[1]; x++) {
1404 for (dim_t y = 0; y < inGdims[2]; y++) {
1405 inG[libjit_getXYZW(inGdims, n, x, y, z)] = 0.0;
1406 }
1407 }
1408
1409 for (dim_t ax = 0; ax < outWdims[1]; ax++) {
1410 for (dim_t ay = 0; ay < outWdims[2]; ay++) {
1411 // Reuse precomputed linear index of max element from argmax.
1412 const dim_t flatIndex = libjit_getXYZW(outWdims, n, ax, ay, z);
1413 float df = outG[flatIndex];
1414 inG[argmax[flatIndex]] += df;
1415 } // W
1416 } // H
1417 } // C
1418 } // N
1419 }
1420 } // namespace
1421
1422 extern "C" {
1423
1424 /// Macro to define a mini-kernel for data-parallel operations. The body of the
1425 /// kernel is auto-generated by the macro.
1426 /// \p name the name of the kernel
1427 /// \p type the type of the tensor elements and of the return value
1428 /// \p body the operation to be performed
1429 #define DEFINE_DATA_PARALLEL_KERNEL(name, type, body) \
1430 type name(dim_t idx, const type *LHS, const type *RHS, const type *op3) { \
1431 return body; \
1432 }
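// For illustration, a sketch of what this macro expands to for one of the
// kernels defined below (libjit_element_add_kernel_f):
//   float libjit_element_add_kernel_f(dim_t idx, const float *LHS,
//                                     const float *RHS, const float *op3) {
//     return LHS[idx] + RHS[idx];
//   }
// The unused op3 parameter is part of the common mini-kernel signature.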
1433
1434 /// Macro to define a mini-kernel for data-parallel operations. The body of the
1435 /// kernel is not auto-generated by the macro.
1436 /// \p name the name of the kernel
1437 #define DEFINE_DATA_PARALLEL_KERNEL_FUNC(name) \
1438 float name(dim_t idx, const float *LHS, const float *RHS, const float *op3)
1439
1440 /// Macro to define a mini-kernel for data-parallel operations with immediate
1441 /// operands.
1442 /// \p name the name of the kernel
1443 /// \p type the type of the tensor elements and of the return value
1444 /// \p body the operation to be performed
1445 #define DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(name, type, body) \
1446 type name(dim_t idx, type val, const type *LHS, const type *RHS) { \
1447 return body; \
1448 }
1449
1450 /// Macro to define a mini-kernel for data-parallel arithmetic quantized
1451 /// operations. The body of the kernel is auto-generated by the macro.
1452 /// \p name the name of the kernel
1453 /// \p type the type of the tensor elements
1454 /// \p body the operation to be performed
1455 #define DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(name, type, body) \
1456 type name(dim_t idx, const type *LHS, const type *RHS, int32_t destOffset, \
1457 int32_t lhsOffset, int32_t rhsOffset, int32_t lhsPre, \
1458 int32_t lhsPost, int32_t lhsScale, int32_t rhsPre, \
1459 int32_t rhsPost, int32_t rhsScale) { \
1460 int32_t lhs = libjit_scale_i32i8(LHS[idx] - lhsOffset, lhsPre, lhsPost, \
1461 lhsScale, 0); \
1462 int32_t rhs = libjit_scale_i32i8(RHS[idx] - rhsOffset, rhsPre, rhsPost, \
1463 rhsScale, 0); \
1464 return libjit_clip((body) + destOffset); \
1465 }
1466
1467 /// Macro to define a mini-kernel for data-parallel multiplicative quantized
1468 /// operations. The body of the kernel is auto-generated by the macro.
1469 /// \p name the name of the kernel
1470 /// \p type the type of the tensor elements
1471 /// \p body the operation to be performed
1472 #define DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED_M(name, body) \
1473 int8_t name(dim_t idx, const int8_t *LHS, const int8_t *RHS, \
1474 int32_t destOffset, int32_t lhsOffset, int32_t rhsOffset, \
1475 int32_t pre, int32_t post, int32_t scale) { \
1476 int32_t lhs = LHS[idx] - lhsOffset; \
1477 int32_t rhs = RHS[idx] - rhsOffset; \
1478 return libjit_clip( \
1479 libjit_scale_i32i8((body), pre, post, scale, destOffset)); \
1480 }
1481
1482 /// Define mini-kernels for all data parallel operations. They are invoked from
1483 /// the generated kernels for sequences of data parallel operations.
1484 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_max_kernel_f, float,
1485 MAX(LHS[idx], RHS[idx]))
1486 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_min_kernel_f, float,
1487 MIN(LHS[idx], RHS[idx]))
1488 DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_f, float, LHS[idx])
1489 DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_u, int64_t, LHS[idx])
1490 DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_i8, int8_t, LHS[idx])
1491 DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_i16, int16_t, LHS[idx])
1492 DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_i32, int32_t, LHS[idx])
1493 DEFINE_DATA_PARALLEL_KERNEL(libjit_copy_kernel_b, int8_t, LHS[idx])
1494 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_add_kernel_f, float,
1495 LHS[idx] + RHS[idx])
1496 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_add_kernel_i32, int32_t,
1497 LHS[idx] + RHS[idx])
1498 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_sub_kernel_f, float,
1499 LHS[idx] - RHS[idx])
1500 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_div_kernel_f, float,
1501 LHS[idx] / RHS[idx])
1502 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_div_kernel_u, int64_t,
1503 LHS[idx] / RHS[idx])
1504 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_div_kernel_i32, int32_t,
1505 LHS[idx] / RHS[idx])
1506 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_mul_kernel_f, float,
1507 LHS[idx] * RHS[idx])
1508 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_mul_kernel_i32, int32_t,
1509 LHS[idx] * RHS[idx])
1510 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_pow_kernel_f, float,
1511 pow(LHS[idx], RHS[idx]))
1512 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_log_kernel_f, float, log(LHS[idx]))
1513 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_exp_kernel_f, float, exp(LHS[idx]))
1514 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_abs_kernel_f, float,
1515 std::abs(LHS[idx]))
1516 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_neg_kernel_f, float, -LHS[idx])
1517 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_floor_kernel_f, float,
1518 std::floor(LHS[idx]))
1519 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_ceil_kernel_f, float,
1520 std::ceil(LHS[idx]))
1521 // The rounding mode required by ONNX, NumPy and TensorFlow is round-half-to-even,
1522 // which rounds values with a fractional part of 0.5 to the nearest even integer.
1523 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_round_kernel_f, float,
1524 std::nearbyintf(LHS[idx]))
1525 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_sqrt_kernel_f, float,
1526 std::sqrt(LHS[idx]))
1527 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_rsqrt_kernel_f, float,
1528 1 / std::sqrt(LHS[idx]))
1529 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_reciprocal_kernel_f, float,
1530 1 / LHS[idx])
1531 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_sin_kernel_f, float,
1532 std::sin(LHS[idx]))
1533 DEFINE_DATA_PARALLEL_KERNEL(libjit_element_cos_kernel_f, float,
1534 std::cos(LHS[idx]))
1535 DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(libjit_element_add_kernel_i8, int8_t,
1536 lhs + rhs)
1537 DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(libjit_element_sub_kernel_i8, int8_t,
1538 lhs - rhs)
1539 DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(libjit_element_max_kernel_i8, int8_t,
1540 MAX(lhs, rhs))
1541 DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED(libjit_element_min_kernel_i8, int8_t,
1542 MIN(lhs, rhs))
1543 DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED_M(libjit_element_mul_kernel_i8, lhs * rhs)
1544 DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED_M(libjit_element_div_kernel_i8, lhs / rhs)
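// Rough sketch of the quantized element-wise flow implemented by the
// DEFINE_DATA_PARALLEL_KERNEL_QUANTIZED* macros above (the add kernel shown as
// an example): both operands are shifted to a zero offset, the operation runs
// in int32, and the result is rescaled into the destination's quantization
// parameters and clipped back to int8:
//   int32_t lhs = LHS[idx] - lhsOffset;
//   int32_t rhs = RHS[idx] - rhsOffset;
//   return libjit_clip(libjit_scale_i32i8(lhs + rhs, pre, post, scale, destOffset));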
1545
1546 /// These variables are used by Glow backends to determine the actual types
1547 /// used for size_t, dim_t and int variables when libjit was compiled.
1548 size_t libjit_sizeTVar;
1549 dim_t libjit_dimTVar;
1550 int libjit_intVar;
1551
1552 /// Specialize the Modulo kernel into two functions based on the
1553 /// value of SignFollowDivisor.
1554 int64_t libjit_element_modulo_kernel_sign_follow_u(dim_t idx,
1555 const int64_t divisor,
1556 const int64_t *input) {
1557 int64_t res = input[idx] % divisor;
1558 if (res && ((res > 0) != (divisor > 0))) {
1559 res += divisor;
1560 }
1561 return res;
1562 }
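// Example: with divisor = 3 and input[idx] = -7, C's truncating remainder gives
// -7 % 3 == -1; since the remainder and divisor have different signs, the
// sign-follow variant adds the divisor and returns 2 (Python-style modulo).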
1563
1564 int64_t libjit_element_modulo_kernel_no_sign_follow_u(dim_t idx,
1565 const int64_t divisor,
1566 const int64_t *input) {
1567 return input[idx] % divisor;
1568 }
1569
1570 int32_t libjit_element_modulo_kernel_sign_follow_i32(dim_t idx,
1571 const int64_t divisor,
1572 const int32_t *input) {
1573 int32_t res = input[idx] % divisor;
1574 if (res && ((res > 0) != (divisor > 0))) {
1575 res += divisor;
1576 }
1577 return res;
1578 }
1579
1580 int32_t libjit_element_modulo_kernel_no_sign_follow_i32(dim_t idx,
1581 const int64_t divisor,
1582 const int32_t *input) {
1583 return input[idx] % divisor;
1584 }
1585
1586 //===----------------------------------------------------------------------===//
1587 // Logical operations
1588 //===----------------------------------------------------------------------===//
1589 int8_t libjit_element_not_kernel_b(dim_t idx, const bool *input) {
1590 return !input[idx];
1591 }
1592
1593 int8_t libjit_element_and_kernel_b(dim_t idx, const bool *LHS,
1594 const bool *RHS) {
1595 return LHS[idx] && RHS[idx];
1596 }
1597
1598 int8_t libjit_element_or_kernel_b(dim_t idx, const bool *LHS, const bool *RHS) {
1599 return LHS[idx] || RHS[idx];
1600 }
1601
1602 int8_t libjit_element_xor_kernel_b(dim_t idx, const bool *LHS,
1603 const bool *RHS) {
1604 return LHS[idx] ^ RHS[idx];
1605 }
1606
1607 //===----------------------------------------------------------------------===//
1608 // Compare operations
1609 //===----------------------------------------------------------------------===//
1610 #define DEFINE_CMP_KERNEL_QUANTIZED(name, type, cmp) \
1611 int8_t name(dim_t idx, const type *LHS, const type *RHS, int32_t lhsOffset, \
1612 int32_t rhsOffset, int32_t pre, int32_t post, int32_t scale) { \
1613 int32_t lhs = LHS[idx] - lhsOffset; \
1614 int32_t rhs = RHS[idx] - rhsOffset; \
1615 return (libjit_scale_i32i8(lhs, pre, post, scale, 0) cmp rhs) ? 1 : 0; \
1616 }
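// Note: only the LHS operand is rescaled before the comparison; the caller is
// expected to pass pre/post/scale values that map the LHS quantization scale
// onto the RHS scale, so the int32 comparison is performed on a common scale.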
1617 DEFINE_CMP_KERNEL_QUANTIZED(libjit_element_cmp_eq_kernel_i8, int8_t, ==)
1618 DEFINE_CMP_KERNEL_QUANTIZED(libjit_element_cmp_neq_kernel_i8, int8_t, !=)
1619 DEFINE_CMP_KERNEL_QUANTIZED(libjit_element_cmp_lt_kernel_i8, int8_t, <)
1620 DEFINE_CMP_KERNEL_QUANTIZED(libjit_element_cmp_lte_kernel_i8, int8_t, <=)
1621 #undef DEFINE_CMP_KERNEL_QUANTIZED
1622
1623 #define DEFINE_CMP_KERNEL_NON_QUANTIZED(name, type, cmp) \
1624 int8_t name(dim_t idx, const type *LHS, const type *RHS) { \
1625 return (LHS[idx] cmp RHS[idx]) ? 1 : 0; \
1626 }
1627
1628 DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_eq_kernel_f, float, ==)
1629 DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_eq_kernel_i32, int32_t, ==)
1630 DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_eq_kernel_u, size_t, ==)
1631
1632 DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_neq_kernel_f, float, !=)
1633 DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_neq_kernel_i32, int32_t, !=)
1634 DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_neq_kernel_u, size_t, !=)
1635
1636 DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lt_kernel_f, float, <)
1637 DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lt_kernel_i32, int32_t, <)
1638 DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lt_kernel_u, size_t, <)
1639
1640 DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lte_kernel_f, float, <=)
1641 DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lte_kernel_i32, int32_t, <=)
1642 DEFINE_CMP_KERNEL_NON_QUANTIZED(libjit_element_cmp_lte_kernel_u, size_t, <=)
1643 #undef DEFINE_CMP_KERNEL_NON_QUANTIZED
1644
1645 int8_t libjit_element_is_nan_kernel_f(dim_t idx, const float *input) {
1646 return std::isnan(input[idx]) ? 1 : 0;
1647 }
1648
1649 // Tanh cannot be vectorized by LLVM yet. Therefore we use the following
1650 // formula instead: 1 - 2 / (exp(x * 2) + 1), which is also used by Caffe2
1651 // and provides good accuracy.
1652 // Once LLVM supports the vectorization of tanh, we can replace this
1653 // approximation with a direct tanh call.
1654 // When the LIBJIT compile option "-ffast-math" is enabled, the intermediate
1655 // computation expf(x) for the Tanh operator overflows for very large
1656 // positive arguments, which results in NaN values in the Tanh output.
1657 // Therefore, when "-ffast-math" is enabled we compute the Tanh in a form
1658 // that avoids passing large arguments to "expf".
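// The FFAST_MATH variant below relies on the identity
//   tanh(x) = copysign(-1 + 2 / (expf(-2 * |x|) + 1), x),
// so the argument passed to expf is never positive and the intermediate value
// stays within (0, 1], which avoids the overflow described above.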
1659 #ifdef FFAST_MATH
1660 DEFINE_DATA_PARALLEL_KERNEL_FUNC(libjit_tanh_kernel_f) {
1661 float inpVal = LHS[idx];
1662 float tanhVal = -1 + 2 / (expf(-2 * std::abs(inpVal)) + 1);
1663 return std::copysignf(tanhVal, inpVal);
1664 }
1665 #else
1666 DEFINE_DATA_PARALLEL_KERNEL_FUNC(libjit_tanh_kernel_f) {
1667 return 1 - 2 / (expf(LHS[idx] * 2) + 1);
1668 }
1669 #endif // FFAST_MATH
1670
1671 int8_t libjit_intlookuptable_kernel_i8(dim_t idx, const int8_t *src,
1672 const int8_t *mapping) {
1673 return mapping[src[idx] + 128];
1674 }
1675
1676 float libjit_elementselect_kernel_f(dim_t idx, const int8_t *cond,
1677 const float *LHS, const float *RHS) {
1678 return (cond[idx] != 0) ? LHS[idx] : RHS[idx];
1679 }
1680
1681 int8_t libjit_elementselect_kernel_i8(dim_t idx, const int8_t *cond,
1682 const int8_t *LHS, const int8_t *RHS,
1683 int32_t destOffset, int32_t lhsOffset,
1684 int32_t rhsOffset, int32_t lhsPre,
1685 int32_t lhsPost, int32_t lhsScale,
1686 int32_t rhsPre, int32_t rhsPost,
1687 int32_t rhsScale) {
1688 return (cond[idx] != 0)
1689 ? libjit_clip(libjit_scale_i32i8(LHS[idx] - lhsOffset, lhsPre,
1690 lhsPost, lhsScale, destOffset))
1691 : libjit_clip(libjit_scale_i32i8(RHS[idx] - rhsOffset, rhsPre,
1692 rhsPost, rhsScale, destOffset));
1693 }
1694
1695 float libjit_element_relu_f(dim_t idx, const float *src) {
1696 float srcVal = src[idx];
1697 return MAX(srcVal, 0);
1698 }
1699
1700 int8_t libjit_element_relu_i8(dim_t idx, const int8_t *src, int8_t srcOffset,
1701 int8_t destOffset, int32_t destPre,
1702 int32_t destPost, int32_t destScale) {
1703 int32_t reluVal = MAX(src[idx], srcOffset);
1704 int32_t scaledVal = libjit_scale_i32i8(reluVal - srcOffset, destPre, destPost,
1705 destScale, destOffset);
1706 return libjit_clip(scaledVal);
1707 }
1708
1709 float libjit_element_clip_f(dim_t idx, const float *src, float min, float max) {
1710 float srcVal = src[idx];
1711 return MIN(MAX(srcVal, min), max);
1712 }
1713
1714 int8_t libjit_element_clip_i8(dim_t idx, const int8_t *src, int8_t clipMin,
1715 int8_t clipMax, int8_t srcOffset,
1716 int8_t destOffset, int32_t destPre,
1717 int32_t destPost, int32_t destScale) {
1718 int32_t clipVal = MIN(MAX(src[idx], clipMin), clipMax);
1719 int32_t scaledVal = libjit_scale_i32i8(clipVal - srcOffset, destPre, destPost,
1720 destScale, destOffset);
1721 return libjit_clip(scaledVal);
1722 }
1723
1724 // When the LIBJIT compile option "-ffast-math" is enabled, the intermediate
1725 // computation expf(x) for the Sigmoid operator overflows for very large
1726 // positive arguments, which results in NaN values in the Sigmoid output.
1727 // Therefore, when "-ffast-math" is enabled we compute the Sigmoid in a form
1728 // that avoids passing large arguments to "expf".
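// The FFAST_MATH variant below uses the equivalent form
//   sigmoid(x) = (x >= 0) ? 1 / (1 + expf(-x)) : 1 - 1 / (1 + expf(x)),
// written with signbit/copysign, so expf only ever sees a non-positive
// argument and cannot overflow.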
1729 #ifdef FFAST_MATH
1730 DEFINE_DATA_PARALLEL_KERNEL_FUNC(libjit_sigmoid_kernel_f) {
1731 float inpVal = LHS[idx];
1732 float sigmoidVal = 1 / (1 + expf(-std::abs(inpVal)));
1733 return (float)(std::signbit(inpVal)) + std::copysignf(sigmoidVal, inpVal);
1734 }
1735 #else
1736 DEFINE_DATA_PARALLEL_KERNEL_FUNC(libjit_sigmoid_kernel_f) {
1737 float e = expf(-LHS[idx]);
1738 return 1 / (e + 1);
1739 }
1740 #endif // FFAST_MATH
1741
1742 DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_f, float, val)
1743 DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_u, int64_t,
1744 val)
1745 DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_i8, int8_t,
1746 val)
1747 DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_i32, int32_t,
1748 val)
1749 DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND(libjit_splat_kernel_b, int8_t, val)
1750
1751 #undef DEFINE_DATA_PARALLEL_KERNEL
1752 #undef DEFINE_DATA_PARALLEL_KERNEL_FUNC
1754 #undef DEFINE_DATA_PARALLEL_KERNEL_WITH_IMM_OPERAND
1755
1756 void libjit_batchedadd_f(float *dest, const float *batch, const float *slice,
1757 dim_t numSlice, dim_t sliceSize) {
1758 // For each layer in the batch:
1759 for (dim_t n = 0; n < numSlice; n++) {
1760 dim_t base = n * sliceSize;
1761 // For each element in the slice.
1762 for (dim_t i = 0; i < sliceSize; i++) {
1763 dest[base + i] = batch[base + i] + slice[i];
1764 }
1765 }
1766 }
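// Shape sketch: viewing the batch as [numSlice, sliceSize] and the slice as
// [sliceSize], every row of the batch gets the same slice added to it, e.g.
//   batch = {{1, 2}, {3, 4}}, slice = {10, 20}  ->  dest = {{11, 22}, {13, 24}}.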
1767
1768 void libjit_batchedadd_i8(int8_t *dest, const int8_t *batch,
1769 const int8_t *slice, dim_t numSlice, dim_t sliceSize,
1770 int32_t destOffset, int32_t batchOffset,
1771 int32_t sliceOffset, int32_t batchPre,
1772 int32_t batchPost, int32_t batchScale,
1773 int32_t slicePre, int32_t slicePost,
1774 int32_t sliceScale) {
1775 libjit_batchedadd_quantized(dest, batch, slice, numSlice, sliceSize,
1776 destOffset, batchOffset, sliceOffset, batchPre,
1777 batchPost, batchScale, slicePre, slicePost,
1778 sliceScale);
1779 }
1780
1781 void libjit_batchedadd_i32_i8(int8_t *dest, const int8_t *batch,
1782 const int32_t *slice, dim_t numSlice,
1783 dim_t sliceSize, int32_t destOffset,
1784 int32_t batchOffset, int32_t sliceOffset,
1785 int32_t batchPre, int32_t batchPost,
1786 int32_t batchScale, int32_t slicePre,
1787 int32_t slicePost, int32_t sliceScale) {
1788 libjit_batchedadd_quantized(dest, batch, slice, numSlice, sliceSize,
1789 destOffset, batchOffset, sliceOffset, batchPre,
1790 batchPost, batchScale, slicePre, slicePost,
1791 sliceScale);
1792 }
1793
1794 /// The dimensions passed in here are pre-expanded in LLVMIRGen with 1s so that
1795 /// we can iterate over the shape here, regardless of the shape of the tensor.
1796 void libjit_batchedreduceadd_f(float *dest, const float *batch, dim_t destSize,
1797 const dim_t *destDims, const dim_t *batchDims,
1798 dim_t axis) {
1799 for (dim_t i = 0; i < destSize; i++)
1800 dest[i] = 0.0;
1801
1802 for (dim_t x = 0; x < batchDims[0]; x++)
1803 for (dim_t y = 0; y < batchDims[1]; y++)
1804 for (dim_t z = 0; z < batchDims[2]; z++)
1805 for (dim_t w = 0; w < batchDims[3]; w++)
1806 for (dim_t q = 0; q < batchDims[4]; q++)
1807 for (dim_t r = 0; r < batchDims[5]; r++) {
1808 dim_t I[] = {x, y, z, w, q, r};
1809 I[axis] = 0;
1810 dest[libjit_getXYZWQR(destDims, I[0], I[1], I[2], I[3], I[4],
1811 I[5])] +=
1812 batch[libjit_getXYZWQR(batchDims, x, y, z, w, q, r)];
1813 }
1814 }
1815
1816 void libjit_reducemin_f(float *dest, const float *batch, size_t destSize,
1817 const dim_t *destDims, const dim_t *batchDims) {
1818 libjit_reducemin(dest, batch, destSize, destDims, batchDims,
1819 std::numeric_limits<float>::max());
1820 }
1821
1822 void libjit_reducemin_i32(int32_t *dest, const int32_t *batch, size_t destSize,
1823 const dim_t *destDims, const dim_t *batchDims) {
1824 libjit_reducemin(dest, batch, destSize, destDims, batchDims,
1825 std::numeric_limits<int32_t>::max());
1826 }
1827
1828 void libjit_reducemin_u(int64_t *dest, const int64_t *batch, size_t destSize,
1829 const dim_t *destDims, const dim_t *batchDims) {
1830 libjit_reducemin(dest, batch, destSize, destDims, batchDims,
1831 std::numeric_limits<int64_t>::max());
1832 }
1833
1834 /// Same as the non-quantized version, the dimensions here are pre-expanded in
1835 /// LLVMIRGen. However, for quantization, we must accumulate in the inner-most
1836 /// loop with higher precision (int32_t) and then clip the result back into the
1837 /// dest tensor. Thus we add max_tensor_dimensions different cases for this to
1838 /// ensure the axis is used as the inner-most loop.
1839 void libjit_batchedreduceadd_i8(int8_t *dest, const int8_t *batch,
1840 const dim_t *destDims, const dim_t *batchDims,
1841 int32_t destOffset, int32_t batchOffset,
1842 int32_t batchPre, int32_t batchPost,
1843 int32_t batchScale, dim_t axis) {
1844 switch (axis) {
1845 #define LOOP_AXIS_CASE(_D0, _D1, _D2, _D3, _D4, _D5_AXIS) \
1846 case _D5_AXIS: \
1847 for (dim_t i##_D0 = 0; i##_D0 < batchDims[_D0]; i##_D0++) \
1848 for (dim_t i##_D1 = 0; i##_D1 < batchDims[_D1]; i##_D1++) \
1849 for (dim_t i##_D2 = 0; i##_D2 < batchDims[_D2]; i##_D2++) \
1850 for (dim_t i##_D3 = 0; i##_D3 < batchDims[_D3]; i##_D3++) \
1851 for (dim_t i##_D4 = 0; i##_D4 < batchDims[_D4]; i##_D4++) { \
1852 int32_t sum = 0;                                                   \
1853 for (dim_t i##_D5_AXIS = 0; i##_D5_AXIS < batchDims[_D5_AXIS]; \
1854 i##_D5_AXIS++) { \
1855 sum += batch[libjit_getXYZWQR(batchDims, i0, i1, i2, i3, i4, \
1856 i5)] - \
1857 batchOffset; \
1858 } \
1859 dim_t i##_D5_AXIS = 0; \
1860 int32_t res = libjit_scale_i32i8(sum, batchPre, batchPost, \
1861 batchScale, destOffset); \
1862 dest[libjit_getXYZWQR(destDims, i0, i1, i2, i3, i4, i5)] = \
1863 libjit_clip(res); \
1864 } \
1865 return;
1866
1867 // Each loop order, with the inner-most dimension/index equal to the axis.
1868 LOOP_AXIS_CASE(1, 2, 3, 4, 5, 0);
1869 LOOP_AXIS_CASE(0, 2, 3, 4, 5, 1);
1870 LOOP_AXIS_CASE(0, 1, 3, 4, 5, 2);
1871 LOOP_AXIS_CASE(0, 1, 2, 4, 5, 3);
1872 LOOP_AXIS_CASE(0, 1, 2, 3, 5, 4);
1873 LOOP_AXIS_CASE(0, 1, 2, 3, 4, 5);
1874 #undef LOOP_AXIS_CASE
1875 }
1876 }
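// Accumulating in int32_t inside LOOP_AXIS_CASE is what lets the sum along the
// reduction axis (roughly batchDims[axis] values, each within +/-255 after the
// offset subtraction) be computed without overflow before the single rescale
// back to int8.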
1877
1878 void libjit_cross_entropy_loss_f_u(float *CE, float *P, size_t *labels,
1879 dim_t *dims) {
1880 libjit_cross_entropy_loss_generic(CE, P, labels, dims);
1881 }
1882
1883 void libjit_cross_entropy_loss_f_i32(float *CE, float *P, int32_t *labels,
1884 dim_t *dims) {
1885 libjit_cross_entropy_loss_generic(CE, P, labels, dims);
1886 }
1887
1888 void libjit_gather64_f(float *dest, const float *data, const int64_t *indices,
1889 dim_t numIndices, dim_t sliceSize, dim_t numSamples,
1890 dim_t sampleSize) {
1891 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1892 sampleSize);
1893 }
1894
1895 void libjit_gather64_i8(int8_t *dest, const int8_t *data,
1896 const int64_t *indices, dim_t numIndices,
1897 dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
1898 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1899 sampleSize);
1900 }
1901
1902 void libjit_gather64_u(int64_t *dest, const int64_t *data,
1903 const int64_t *indices, dim_t numIndices,
1904 dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
1905 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1906 sampleSize);
1907 }
1908
1909 void libjit_gather32_f(float *dest, const float *data, const int32_t *indices,
1910 dim_t numIndices, dim_t sliceSize, dim_t numSamples,
1911 dim_t sampleSize) {
1912 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1913 sampleSize);
1914 }
1915
1916 void libjit_gather32_i8(int8_t *dest, const int8_t *data,
1917 const int32_t *indices, dim_t numIndices,
1918 dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
1919 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1920 sampleSize);
1921 }
1922
1923 void libjit_gather32_u(int64_t *dest, const int64_t *data,
1924 const int32_t *indices, dim_t numIndices,
1925 dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
1926 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1927 sampleSize);
1928 }
1929
1930 void libjit_gather32_i32(int32_t *dest, const int32_t *data,
1931 const int32_t *indices, dim_t numIndices,
1932 dim_t sliceSize, dim_t numSamples, dim_t sampleSize) {
1933 libjit_gather(dest, data, indices, numIndices, sliceSize, numSamples,
1934 sampleSize);
1935 }
1936
1937 void libjit_gatherranges64_f(float *output, int64_t *lengths, const float *data,
1938 const int64_t *ranges, dim_t numExamples,
1939 dim_t exampleSize) {
1940 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1941 }
1942
1943 void libjit_gatherranges64_i8(int8_t *output, int64_t *lengths,
1944 const int8_t *data, const int64_t *ranges,
1945 dim_t numExamples, dim_t exampleSize) {
1946 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1947 }
1948
1949 void libjit_gatherranges64_u(int64_t *output, int64_t *lengths,
1950 const int64_t *data, const int64_t *ranges,
1951 dim_t numExamples, dim_t exampleSize) {
1952 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1953 }
1954
1955 void libjit_gatherranges32_f(float *output, int32_t *lengths, const float *data,
1956 const int32_t *ranges, dim_t numExamples,
1957 dim_t exampleSize) {
1958 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1959 }
1960
1961 void libjit_gatherranges32_i8(int8_t *output, int32_t *lengths,
1962 const int8_t *data, const int32_t *ranges,
1963 dim_t numExamples, dim_t exampleSize) {
1964 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1965 }
1966
1967 void libjit_gatherranges32_u(uint64_t *output, int32_t *lengths,
1968 const uint64_t *data, const int32_t *ranges,
1969 dim_t numExamples, dim_t exampleSize) {
1970 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1971 }
1972
1973 void libjit_gatherranges32_i32(int32_t *output, int32_t *lengths,
1974 const int32_t *data, const int32_t *ranges,
1975 dim_t numExamples, dim_t exampleSize) {
1976 libjit_gatherranges(output, lengths, data, ranges, numExamples, exampleSize);
1977 }
1978
1979 void libjit_lengths_range_fill_i32(const int32_t *lengths, int32_t *output,
1980 const dim_t lengthsSize) {
1981 dim_t curIdx = 0;
1982 for (dim_t i = 0, e = lengthsSize; i < e; i++) {
1983 for (int32_t j = 0, f = lengths[i]; j < f; j++) {
1984 output[curIdx++] = j;
1985 }
1986 }
1987 }
1988
1989 void libjit_scatterdata_f_i32(float *data, const dim_t *dataDims,
1990 const int32_t *indices, const float *slices,
1991 dim_t numIndices, dim_t indexSize,
1992 dim_t sliceSize, bool isCumulative) {
1993 if (isCumulative) {
1994 libjit_scatterdataaddfloat(data, dataDims, indices, slices, numIndices,
1995 indexSize, sliceSize);
1996 } else {
1997 libjit_scatterdatacopy(data, dataDims, indices, slices, numIndices,
1998 indexSize, sliceSize);
1999 }
2000 }
2001
2002 void libjit_scatterdata_i8_u(int8_t *data, const dim_t *dataDims,
2003 const int64_t *indices, const int8_t *slices,
2004 dim_t numIndices, dim_t indexSize, dim_t sliceSize,
2005 bool isCumulative, float dataScale,
2006 int32_t dataOffset, float sliceScale,
2007 int32_t sliceOffset) {
2008 if (isCumulative) {
2009 libjit_scatterdataaddquantized(data, dataDims, indices, slices, numIndices,
2010 indexSize, sliceSize, dataScale, dataOffset,
2011 sliceScale, sliceOffset);
2012 } else {
2013 libjit_scatterdatacopy(data, dataDims, indices, slices, numIndices,
2014 indexSize, sliceSize);
2015 }
2016 }
2017
2018 void libjit_scatterdata_i8_i32(int8_t *data, const dim_t *dataDims,
2019 const int32_t *indices, const int8_t *slices,
2020 dim_t numIndices, dim_t indexSize,
2021 dim_t sliceSize, bool isCumulative,
2022 float dataScale, int32_t dataOffset,
2023 float sliceScale, int32_t sliceOffset) {
2024 if (isCumulative) {
2025 libjit_scatterdataaddquantized(data, dataDims, indices, slices, numIndices,
2026 indexSize, sliceSize, dataScale, dataOffset,
2027 sliceScale, sliceOffset);
2028 } else {
2029 libjit_scatterdatacopy(data, dataDims, indices, slices, numIndices,
2030 indexSize, sliceSize);
2031 }
2032 }
2033
2034 void libjit_lengths_to_ranges_i32(int32_t *ranges, const int32_t *lengths,
2035 dim_t size) {
2036 int32_t offset = 0;
2037 for (dim_t i = 0; i < size; i++) {
2038 auto length = lengths[i];
2039 ranges[i * 2] = offset;
2040 ranges[i * 2 + 1] = length;
2041 offset += length;
2042 }
2043 }
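// Example: lengths = {2, 3, 1} produces ranges = {0, 2, 2, 3, 5, 1}, i.e. the
// (offset, length) pairs (0, 2), (2, 3) and (5, 1).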
2044
2045 void libjit_sparse_lengths_sum_f_u(float *dest, float *data, size_t *indices,
2046 int32_t *lengths, dim_t segments,
2047 dim_t lineSize) {
2048 libjit_sparse_lengths_sum_generic(dest, data, indices, lengths, segments,
2049 lineSize);
2050 }
2051
2052 void libjit_sparse_lengths_sum_f_i32(float *dest, float *data, int32_t *indices,
2053 int32_t *lengths, dim_t segments,
2054 dim_t lineSize) {
2055 libjit_sparse_lengths_sum_generic(dest, data, indices, lengths, segments,
2056 lineSize);
2057 }
2058
2059 void libjit_sparse_lengths_weighted_sum_f_u(float *dest, float *data,
2060 float *weights, size_t *indices,
2061 int32_t *lengths, dim_t segments,
2062 dim_t lineSize) {
2063 libjit_sparse_lengths_weighted_sum_generic(dest, data, weights, indices,
2064 lengths, segments, lineSize);
2065 }
2066
2067 void libjit_sparse_lengths_weighted_sum_f_i32(float *dest, float *data,
2068 float *weights, int32_t *indices,
2069 int32_t *lengths, dim_t segments,
2070 dim_t lineSize) {
2071 libjit_sparse_lengths_weighted_sum_generic(dest, data, weights, indices,
2072 lengths, segments, lineSize);
2073 }
2074
2075 void libjit_embedding_bag_f(float *dest, float *data, float *weights,
2076 size_t *indices, size_t *offsets, dim_t segments,
2077 dim_t lineSize, dim_t totalLength,
2078 bool hasEndOffset) {
2079 if (hasEndOffset) {
2080 --segments;
2081 }
2082 memset(dest, 0, segments * lineSize * sizeof(float));
2083 dim_t curIndex = 0;
2084 for (dim_t i = 0; i < segments; i++) {
2085 int64_t start = offsets[i];
2086 int64_t end =
2087 !hasEndOffset && i == segments - 1 ? totalLength : offsets[i + 1];
2088 for (int64_t j = start; j < end; j++) {
2089 float weight = weights[curIndex];
2090 dim_t line = indices[curIndex];
2091 for (dim_t k = 0; k < lineSize; k++) {
2092 dest[i * lineSize + k] += weight * data[line * lineSize + k];
2093 }
2094 curIndex++;
2095 }
2096 }
2097 }
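// Offsets semantics as implemented above: offsets[i] is the position in
// `indices`/`weights` where segment i starts. When hasEndOffset is true, the
// last entry of `offsets` holds the end position of the final segment and is
// not a segment itself; otherwise the final segment ends at totalLength.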
2098
2099 void libjit_sparse_lengths_weighted_sum_grad_f_u(
2100 const float *destGrad, float *dataGrad, float *weightsGrad,
2101 const float *data, const float *weights, const size_t *indices,
2102 const int32_t *lengths, dim_t segments, dim_t lineSize,
2103 dim_t dataGradRawSize) {
2104 libjit_sparse_lengths_weighted_sum_grad_generic(
2105 destGrad, dataGrad, weightsGrad, data, weights, indices, lengths,
2106 segments, lineSize, dataGradRawSize);
2107 }
2108
2109 void libjit_sparse_lengths_weighted_sum_grad_f_i32(
2110 const float *destGrad, float *dataGrad, float *weightsGrad,
2111 const float *data, const float *weights, const int32_t *indices,
2112 const int32_t *lengths, dim_t segments, dim_t lineSize,
2113 dim_t dataGradRawSize) {
2114 libjit_sparse_lengths_weighted_sum_grad_generic(
2115 destGrad, dataGrad, weightsGrad, data, weights, indices, lengths,
2116 segments, lineSize, dataGradRawSize);
2117 }
2118
2119 void libjit_rowwise_quantized_sparse_lengths_weighted_sum_f_u(
2120 float *dest, uint8_t *data, float *scales, float *offsets, float *weights,
2121 size_t *indices, int32_t *lengths, dim_t segments, dim_t lineSize) {
2122 libjit_rowwise_quantized_sparse_lengths_weighted_sum_generic(
2123 dest, data, scales, offsets, weights, indices, lengths, segments,
2124 lineSize);
2125 }
2126
2127 void libjit_rowwise_quantized_sparse_lengths_weighted_sum_f_i32(
2128 float *dest, uint8_t *data, float *scales, float *offsets, float *weights,
2129 int32_t *indices, int32_t *lengths, dim_t segments, dim_t lineSize) {
2130 libjit_rowwise_quantized_sparse_lengths_weighted_sum_generic(
2131 dest, data, scales, offsets, weights, indices, lengths, segments,
2132 lineSize);
2133 }
2134
2135 void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_f_u(
2136 float *dest, int8_t *data, float *weights, size_t *indices,
2137 int32_t *lengths, dim_t segments, dim_t inLineSize, dim_t outLineSize) {
2138 libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_generic(
2139 dest, data, weights, indices, lengths, segments, inLineSize, outLineSize);
2140 }
2141
2142 void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_f_i32(
2143 float *dest, int8_t *data, float *weights, int32_t *indices,
2144 int32_t *lengths, dim_t segments, dim_t inLineSize, dim_t outLineSize) {
2145 libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_generic(
2146 dest, data, weights, indices, lengths, segments, inLineSize, outLineSize);
2147 }
2148
2149 void libjit_fused_rowwise_quantized_sparse_lengths_weighted_sum_f(
2150 float *dest, int8_t *data, float *weights, dim_t *indices, int32_t *lengths,
2151 dim_t segments, dim_t inLineSize, dim_t outLineSize) {
2152 memset(dest, 0, segments * outLineSize * sizeof(float));
2153 dim_t curIndex = 0;
2154 for (dim_t i = 0; i < segments; i++) {
2155 for (int32_t j = 0, e = lengths[i]; j < e; j++) {
2156 const float weight = weights[curIndex];
2157 const dim_t line = indices[curIndex];
2158 const int8_t *currRowScaleOffsetPtr =
2159 data + ((line + 1) * inLineSize) - 2 * sizeof(float);
2160 float scale, offset;
2161 memcpy(&scale, currRowScaleOffsetPtr, sizeof(float));
2162 memcpy(&offset, currRowScaleOffsetPtr + sizeof(float), sizeof(float));
2163 for (dim_t k = 0; k < outLineSize; k++) {
2164 const float fData =
2165 (scale * (uint8_t)(data[line * inLineSize + k])) + offset;
2166 dest[i * outLineSize + k] += weight * fData;
2167 }
2168 curIndex++;
2169 }
2170 }
2171 }
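// Fused row layout assumed by the loop above: each data row is inLineSize bytes
// and stores outLineSize uint8-quantized values followed by a float scale and a
// float offset (so inLineSize == outLineSize + 2 * sizeof(float)); each value is
// dequantized as scale * q + offset before the weighted accumulation.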
2172
2173 void libjit_embedding_bag_byte_rowwise_offsets_f(
2174 float *dest, int8_t *data, float *weights, size_t *indices, size_t *offsets,
2175 dim_t segments, dim_t numIndices, dim_t inLineSize, dim_t outLineSize,
2176 bool hasEndOffset) {
2177 if (hasEndOffset) {
2178 --segments;
2179 }
2180 memset(dest, 0, segments * outLineSize * sizeof(float));
2181 for (dim_t i = 0; i < segments; i++) {
2182 dim_t start = offsets[i];
2183 dim_t end =
2184 !hasEndOffset && i == segments - 1 ? numIndices : offsets[i + 1];
2185 for (dim_t j = start; j < end; j++) {
2186 const float weight = weights[j];
2187 const dim_t line = indices[j];
2188 const int8_t *currRowScaleOffsetPtr =
2189 data + ((line + 1) * inLineSize) - 2 * sizeof(float);
2190 float scale, offset;
2191 memcpy(&scale, currRowScaleOffsetPtr, sizeof(float));
2192 memcpy(&offset, currRowScaleOffsetPtr + sizeof(float), sizeof(float));
2193 for (dim_t k = 0; k < outLineSize; k++) {
2194 const float fData =
2195 (scale * (uint8_t)(data[line * inLineSize + k])) + offset;
2196 dest[i * outLineSize + k] += weight * fData;
2197 }
2198 }
2199 }
2200 }
2201
2202 void libjit_sparse_to_dense_f_u(float *dest, const size_t *indices,
2203 const float *values, dim_t numIndices,
2204 dim_t destSize, dim_t valueSize) {
2205 libjit_sparse_to_dense_generic(dest, indices, values, numIndices, destSize,
2206 valueSize);
2207 }
2208
2209 void libjit_sparse_to_dense_f_i32(float *dest, const int32_t *indices,
2210 const float *values, dim_t numIndices,
2211 dim_t destSize, dim_t valueSize) {
2212 libjit_sparse_to_dense_generic(dest, indices, values, numIndices, destSize,
2213 valueSize);
2214 }
2215
2216 void libjit_lengths_sum_f(float *dest, const float *data,
2217 const int32_t *lengths, dim_t destSize,
2218 dim_t lengthsSize, dim_t sliceSize) {
2219 memset(dest, 0, destSize * sizeof(float));
2220
2221 dim_t offsetOut = 0;
2222 dim_t offsetIn = 0;
2223
2224 for (dim_t i = 0; i < lengthsSize; ++i) {
2225 for (int32_t j = 0; j < lengths[i]; ++j) {
2226 for (dim_t k = 0; k < sliceSize; ++k) {
2227 dest[offsetOut + k] += data[offsetIn + k];
2228 }
2229 offsetIn += sliceSize;
2230 }
2231 offsetOut += sliceSize;
2232 }
2233 }
2234
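// Local response normalization, computed per channel c over a window of
// 2 * halfWindow + 1 neighboring channels:
//   scale(c) = k + (alpha / windowSize) * sum_{i in window(c)} in(i)^2
//   out(c)   = in(c) * scale(c)^(-beta)
// The scale values are cached in scaleCache so the gradient kernel below can
// reuse them.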
2235 void libjit_local_response_normalization_f(
2236 float *outW, const float *inW, float *scaleCache, const dim_t *outWdims,
2237 const dim_t *inWdims, dim_t halfWindow, float alpha, float beta, float k) {
2238 dim_t window = 2 * halfWindow + 1;
2239 float normedAlpha = alpha / window;
2240
2241 for (dim_t n = 0; n < inWdims[0]; n++) {
2242 for (dim_t h = 0; h < inWdims[1]; h++) {
2243 for (dim_t w = 0; w < inWdims[2]; w++) {
2244 for (dim_t c = 0; c < inWdims[3]; c++) {
2245 float m2 = 0.0;
2246 for (dim_t i = (c >= halfWindow ? c - halfWindow : 0);
2247 i <= MIN(c + halfWindow, inWdims[3] - 1); i++) {
2248 float val = inW[libjit_getXYZW(inWdims, n, h, w, i)];
2249 m2 += val * val;
2250 }
2251
2252 float scale = k + normedAlpha * m2;
2253 scaleCache[libjit_getXYZW(inWdims, n, h, w, c)] = scale;
2254 float normFactor = pow(scale, -beta);
2255 outW[libjit_getXYZW(outWdims, n, h, w, c)] =
2256 inW[libjit_getXYZW(inWdims, n, h, w, c)] * normFactor;
2257 } // C
2258 } // W
2259 } // H
2260 } // N
2261 }
2262
2263 void libjit_local_response_normalization_grad_f(
2264 float *inG, const float *outG, const float *inW, const float *outW,
2265 const float *scaleCache, const dim_t *outWdims, dim_t halfWindow,
2266 float alpha, float beta) {
2267 dim_t window = 2 * halfWindow + 1;
2268 float normedAlpha = alpha / window;
2269 float coeff = 2 * normedAlpha * beta;
2270
2271 for (dim_t n = 0; n < outWdims[0]; n++) {
2272 for (dim_t h = 0; h < outWdims[1]; h++) {
2273 for (dim_t w = 0; w < outWdims[2]; w++) {
2274 // Prepare right half of sliding window based at c = 0
2275 float sum = 0.0;
2276 for (dim_t i = 0; i < MIN(halfWindow, outWdims[3]); i++) {
2277 float outg = outG[libjit_getXYZW(outWdims, n, h, w, i)];
2278 float outw = outW[libjit_getXYZW(outWdims, n, h, w, i)];
2279 float scale = scaleCache[libjit_getXYZW(outWdims, n, h, w, i)];
2280 sum += outg * (outw / scale);
2281 }
2282
2283 for (dim_t c = 0; c < outWdims[3]; c++) {
2284 if (c > halfWindow) {
2285 dim_t j = c - halfWindow - 1;
2286 float outg = outG[libjit_getXYZW(outWdims, n, h, w, j)];
2287 float outw = outW[libjit_getXYZW(outWdims, n, h, w, j)];
2288 float scale = scaleCache[libjit_getXYZW(outWdims, n, h, w, j)];
2289 sum -= outg * (outw / scale);
2290 }
2291
2292 dim_t j = c + halfWindow;
2293 if (j < outWdims[3]) {
2294 float outg = outG[libjit_getXYZW(outWdims, n, h, w, j)];
2295 float outw = outW[libjit_getXYZW(outWdims, n, h, w, j)];
2296 float scale = scaleCache[libjit_getXYZW(outWdims, n, h, w, j)];
2297 sum += outg * (outw / scale);
2298 }
2299
2300 float outg = outG[libjit_getXYZW(outWdims, n, h, w, c)];
2301 float inw = inW[libjit_getXYZW(outWdims, n, h, w, c)];
2302 float scale = scaleCache[libjit_getXYZW(outWdims, n, h, w, c)];
2303 inG[libjit_getXYZW(outWdims, n, h, w, c)] =
2304 outg * pow(scale, -beta) - coeff * inw * sum;
2305 }
2306 } // W
2307 } // H
2308 } // N
2309 }
2310
2311 void libjit_max_pool_i8(const int8_t *inW, int8_t *outW, const dim_t *inWdims,
2312 const dim_t *outWdims, dim_t *kernelSizes,
2313 dim_t *strides, dim_t *pads) {
2314 libjit_max_pool_generic(inW, outW, inWdims, outWdims, kernelSizes, strides,
2315 pads);
2316 }
2317
2318 void libjit_max_pool_f(const float *inW, float *outW, const dim_t *inWdims,
2319 const dim_t *outWdims, dim_t *kernelSizes,
2320 dim_t *strides, dim_t *pads) {
2321 libjit_max_pool_generic(inW, outW, inWdims, outWdims, kernelSizes, strides,
2322 pads);
2323 }
2324
2325 void libjit_max_pool_argmax_i8_u(const int8_t *inW, int8_t *outW,
2326 int64_t *argmax, const dim_t *inWdims,
2327 const dim_t *outWdims, dim_t *kernels,
2328 dim_t *strides, dim_t *pads) {
2329 libjit_max_pool_argmax_generic(inW, outW, argmax, inWdims, outWdims, kernels,
2330 strides, pads);
2331 }
2332
2333 void libjit_max_pool_argmax_f_u(const float *inW, float *outW, int64_t *argmax,
2334 const dim_t *inWdims, const dim_t *outWdims,
2335 dim_t *kernels, dim_t *strides, dim_t *pads) {
2336 libjit_max_pool_argmax_generic(inW, outW, argmax, inWdims, outWdims, kernels,
2337 strides, pads);
2338 }
2339
2340 void libjit_max_pool_argmax_i8_i32(const int8_t *inW, int8_t *outW,
2341 int32_t *argmax, const dim_t *inWdims,
2342 const dim_t *outWdims, dim_t *kernels,
2343 dim_t *strides, dim_t *pads) {
2344 libjit_max_pool_argmax_generic(inW, outW, argmax, inWdims, outWdims, kernels,
2345 strides, pads);
2346 }
2347
2348 void libjit_max_pool_argmax_f_i32(const float *inW, float *outW,
2349 int32_t *argmax, const dim_t *inWdims,
2350 const dim_t *outWdims, dim_t *kernels,
2351 dim_t *strides, dim_t *pads) {
2352 libjit_max_pool_argmax_generic(inW, outW, argmax, inWdims, outWdims, kernels,
2353 strides, pads);
2354 }
2355
2356 void libjit_arg_max_i8_u(const int8_t *inW, int64_t *outW, const dim_t *inWdims,
2357 size_t inWNumDims, size_t axis) {
2358 libjit_arg_max_generic(inW, outW, inWdims, inWNumDims, axis);
2359 }
2360
2361 void libjit_arg_max_i8_i32(const int8_t *inW, int32_t *outW,
2362 const dim_t *inWdims, size_t inWNumDims,
2363 size_t axis) {
2364 libjit_arg_max_generic(inW, outW, inWdims, inWNumDims, axis);
2365 }
2366
2367 void libjit_arg_max_f_u(const float *inW, int64_t *outW, const dim_t *inWdims,
2368 size_t inWNumDims, size_t axis) {
2369 libjit_arg_max_generic(inW, outW, inWdims, inWNumDims, axis);
2370 }
2371
2372 void libjit_arg_max_f_i32(const float *inW, int32_t *outW, const dim_t *inWdims,
2373 size_t inWNumDims, size_t axis) {
2374 libjit_arg_max_generic(inW, outW, inWdims, inWNumDims, axis);
2375 }
2376
2377 void libjit_arg_min_i8_u(const int8_t *inW, int64_t *outW, const dim_t *inWdims,
2378 size_t inWNumDims, size_t axis) {
2379 libjit_arg_min_generic(inW, outW, inWdims, inWNumDims, axis);
2380 }
2381
2382 void libjit_arg_min_i8_i32(const int8_t *inW, int32_t *outW,
2383 const dim_t *inWdims, size_t inWNumDims,
2384 size_t axis) {
2385 libjit_arg_min_generic(inW, outW, inWdims, inWNumDims, axis);
2386 }
2387
2388 void libjit_arg_min_f_u(const float *inW, int64_t *outW, const dim_t *inWdims,
2389 size_t inWNumDims, size_t axis) {
2390 libjit_arg_min_generic(inW, outW, inWdims, inWNumDims, axis);
2391 }
2392
2393 void libjit_arg_min_f_i32(const float *inW, int32_t *outW, const dim_t *inWdims,
2394 size_t inWNumDims, size_t axis) {
2395 libjit_arg_min_generic(inW, outW, inWdims, inWNumDims, axis);
2396 }
2397
2398 void libjit_max_pool_argmax_grad_f_u(float *inG, const float *outG,
2399 const int64_t *argmax,
2400 const dim_t *inGdims,
2401 const dim_t *outWdims) {
2402 libjit_max_pool_argmax_grad_generic(inG, outG, argmax, inGdims, outWdims);
2403 }
2404
2405 void libjit_max_pool_argmax_grad_f_i32(float *inG, const float *outG,
2406 const int32_t *argmax,
2407 const dim_t *inGdims,
2408 const dim_t *outWdims) {
2409 libjit_max_pool_argmax_grad_generic(inG, outG, argmax, inGdims, outWdims);
2410 }
2411
2412 void libjit_resizenearest_f(float *dst, const float *src, const float *scale,
2413 const dim_t *inWdims, const dim_t *outWdims) {
2414 libjit_resizenearest_generic(dst, src, scale, inWdims, outWdims);
2415 }
2416
2417 void libjit_resizenearest_i8(int8_t *dst, const int8_t *src, const float *scale,
2418 const dim_t *inWdims, const dim_t *outWdims) {
2419 libjit_resizenearest_generic(dst, src, scale, inWdims, outWdims);
2420 }
2421
2422 void libjit_resizenearest_i32(int32_t *dst, const int32_t *src,
2423 const float *scale, const dim_t *inWdims,
2424 const dim_t *outWdims) {
2425 libjit_resizenearest_generic(dst, src, scale, inWdims, outWdims);
2426 }
2427
2428 void libjit_resizenearest_u(int64_t *dst, const int64_t *src,
2429 const float *scale, const dim_t *inWdims,
2430 const dim_t *outWdims) {
2431 libjit_resizenearest_generic(dst, src, scale, inWdims, outWdims);
2432 }
2433
2434 void libjit_resizebilinear_f(float *dst, const float *src, const float *scale,
2435 const dim_t *inWdims, const dim_t *outWdims) {
2436 libjit_resizebilinear_generic(dst, src, scale, inWdims, outWdims);
2437 }
2438
2439 void libjit_resizebilinear_i8(int8_t *dst, const int8_t *src,
2440 const float *scale, const dim_t *inWdims,
2441 const dim_t *outWdims) {
2442 libjit_resizebilinear_generic(dst, src, scale, inWdims, outWdims);
2443 }
2444
2445 void libjit_resizebilinear_i32(int32_t *dst, const int32_t *src,
2446 const float *scale, const dim_t *inWdims,
2447 const dim_t *outWdims) {
2448 libjit_resizebilinear_generic(dst, src, scale, inWdims, outWdims);
2449 }
2450
2451 void libjit_resizebilinear_u(int64_t *dst, const int64_t *src,
2452 const float *scale, const dim_t *inWdims,
2453 const dim_t *outWdims) {
2454 libjit_resizebilinear_generic(dst, src, scale, inWdims, outWdims);
2455 }
2456
2457 void libjit_avg_pool_i8(const int8_t *inW, int8_t *outW, const dim_t *inWdims,
2458 const dim_t *outWdims, dim_t *kernelSizes,
2459 dim_t *strides, dim_t *pads, int32_t outOffset,
2460 int32_t inOffset, int32_t outPre, int32_t outPost,
2461 int32_t outScale) {
2462 dim_t pad_t = pads[0];
2463 dim_t pad_l = pads[1];
2464 dim_t stride_h = strides[0];
2465 dim_t stride_w = strides[1];
2466 dim_t kernel_h = kernelSizes[0];
2467 dim_t kernel_w = kernelSizes[1];
2468 // For each input in the batch:
2469 for (dim_t n = 0; n < outWdims[0]; n++) {
2470 // For each (x,y) step in the input/output tensor:
2471 sdim_t x = -sdim_t(pad_t);
2472 for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
2473 sdim_t y = -sdim_t(pad_l);
2474 for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {
2475 // For each layer in the output tensor:
2476 for (dim_t z = 0; z < inWdims[3]; z++) {
2477 int32_t sum = 0;
2478
2479 for (dim_t fx = 0; fx < kernel_h; fx++) {
2480 for (dim_t fy = 0; fy < kernel_w; fy++) {
2481 sdim_t ox = x + fx;
2482 sdim_t oy = y + fy;
2483
2484 // Ignore index access below zero (this is due to padding).
2485 if (ox < 0 || oy < 0 || ox >= (sdim_t)inWdims[1] ||
2486 oy >= (sdim_t)inWdims[2]) {
2487 continue;
2488 }
2489 sum += inW[libjit_getXYZW(inWdims, n, (dim_t)ox, (dim_t)oy, z)] -
2490 inOffset;
2491 }
2492 }
2493
2494 outW[libjit_getXYZW(outWdims, n, ax, ay, z)] = libjit_clip(
2495 libjit_scale_i32i8(sum, outPre, outPost, outScale, outOffset));
2496 } // C
2497 } // W
2498 } // H
2499 } // N
2500 }
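// Note on the quantized average pool above: the loop only accumulates
// (in - inOffset) over the window; the division by the filter area is expected
// to be folded by the compiler into the outPre/outPost/outScale parameters of
// libjit_scale_i32i8, so no explicit divide appears here.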
2501
2502 void libjit_avg_pool_f(const float *inW, float *outW, const dim_t *inWdims,
2503 const dim_t *outWdims, dim_t *kernelSizes,
2504 dim_t *strides, dim_t *pads) {
2505 dim_t pad_t = pads[0];
2506 dim_t pad_l = pads[1];
2507 dim_t stride_h = strides[0];
2508 dim_t stride_w = strides[1];
2509 dim_t kernel_h = kernelSizes[0];
2510 dim_t kernel_w = kernelSizes[1];
2511 float filterArea = kernel_h * kernel_w;
2512 // For each input in the batch:
2513 for (dim_t n = 0; n < outWdims[0]; n++) {
2514 // For each (x,y) step in the input/output tensor:
2515 sdim_t x = -(sdim_t)pad_t;
2516 for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
2517 sdim_t y = -(sdim_t)pad_l;
2518 for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {
2519 // For each layer in the output tensor:
2520 for (dim_t z = 0; z < inWdims[3]; z++) {
2521
2522 float sum = 0;
2523
2524 for (dim_t fx = 0; fx < kernel_h; fx++) {
2525 for (dim_t fy = 0; fy < kernel_w; fy++) {
2526 sdim_t ox = x + fx;
2527 sdim_t oy = y + fy;
2528
2529 // Ignore index access below zero (this is due to padding).
2530 if (ox < 0 || oy < 0 || ox >= (sdim_t)inWdims[1] ||
2531 oy >= (sdim_t)inWdims[2]) {
2532 continue;
2533 }
2534
2535 sum += inW[libjit_getXYZW(inWdims, n, (dim_t)ox, (dim_t)oy, z)];
2536 }
2537 }
2538
2539 outW[libjit_getXYZW(outWdims, n, ax, ay, z)] = sum / filterArea;
2540 } // C
2541 } // W
2542 } // H
2543 } // N
2544 }
2545
2546 void libjit_adaptive_avg_pool_f(const float *inW, float *outW,
2547 const dim_t *inWdims, const dim_t *outWdims) {
2548 // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/AdaptiveAveragePooling.cpp
2549 #define START_IND(a, b, c) (size_t) std::floor((float)((a) * (c)) / (b))
2550 #define END_IND(a, b, c) (size_t) std::ceil((float)(((a) + 1) * (c)) / (b))
2551
2552 // For each input in the batch:
2553 for (dim_t n = 0; n < outWdims[0]; n++) {
2554 // For each layer in the output tensor:
2555 for (dim_t z = 0; z < inWdims[3]; z++) {
2556 // For each value in the output tensor:
2557 for (dim_t ax = 0; ax < outWdims[1]; ax++) {
2558
2559 dim_t x = START_IND(ax, outWdims[1], inWdims[1]);
2560 dim_t kH = END_IND(ax, outWdims[1], inWdims[1]) - x;
2561
2562 for (dim_t ay = 0; ay < outWdims[2]; ay++) {
2563
2564 dim_t y = START_IND(ay, outWdims[2], inWdims[2]);
2565 dim_t kW = END_IND(ay, outWdims[2], inWdims[2]) - y;
2566
2567 float sum = 0;
2568 for (dim_t fx = 0; fx < kH; fx++) {
2569 for (dim_t fy = 0; fy < kW; fy++) {
2570 dim_t ox = x + fx;
2571 dim_t oy = y + fy;
2572
2573 sum += inW[libjit_getXYZW(inWdims, n, ox, oy, z)];
2574 }
2575 }
2576 outW[libjit_getXYZW(outWdims, n, ax, ay, z)] = (sum / kW / kH);
2577 } // W
2578 } // H
2579 } // C
2580 } // N
2581 #undef START_IND
2582 #undef END_IND
2583 }
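// Example of the adaptive bins above (illustrative): for an input height of 5
// and an output height of 3, START_IND/END_IND yield the (possibly overlapping)
// windows [0, 2), [1, 4) and [3, 5), so every input row is covered and each
// output value averages exactly its own window.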
2584
2585 void libjit_avg_pool_grad_f(float *inG, const float *outG, const dim_t *inGdims,
2586 const dim_t *outWdims, dim_t *kernels,
2587 dim_t *strides, dim_t *pads) {
2588 dim_t pad_t = pads[0];
2589 dim_t pad_l = pads[1];
2590 dim_t stride_h = strides[0];
2591 dim_t stride_w = strides[1];
2592 dim_t kernel_h = kernels[0];
2593 dim_t kernel_w = kernels[1];
2594 float kernelArea = kernel_h * kernel_w;
2595
2596 // NHWC format is assumed
2597 for (dim_t n = 0; n < outWdims[0]; n++) {
2598 for (dim_t z = 0; z < outWdims[3]; z++) {
2599 // Clear inG
2600 for (dim_t x = 0; x < inGdims[1]; x++) {
2601 for (dim_t y = 0; y < inGdims[2]; y++) {
2602 inG[libjit_getXYZW(inGdims, n, x, y, z)] = 0.0;
2603 }
2604 }
2605
2606 sdim_t x = -(sdim_t)pad_t;
2607 for (dim_t ax = 0; ax < outWdims[1]; x += stride_h, ax++) {
2608 sdim_t y = -(sdim_t)pad_l;
2609 for (dim_t ay = 0; ay < outWdims[2]; y += stride_w, ay++) {
2610 float df = outG[libjit_getXYZW(outWdims, n, ax, ay, z)] / kernelArea;
2611 for (dim_t kx = 0; kx < kernel_h; kx++) {
2612 for (dim_t ky = 0; ky < kernel_w; ky++) {
2613 sdim_t ox = x + kx;
2614 sdim_t oy = y + ky;
2615 if (ox < 0 || oy < 0 || ox >= (sdim_t)inGdims[1] ||
2616 oy >= (sdim_t)inGdims[2]) {
2617 continue;
2618 }
2619 inG[libjit_getXYZW(inGdims, n, (dim_t)ox, (dim_t)oy, z)] += df;
2620 }
2621 }
2622 } // W
2623 } // H
2624 } // C
2625 } // N
2626 }
2627
2628 int8_t libjit_element_quantize_kernel_i8(dim_t idx, const float *inW,
2629 float scale, int32_t offset) {
2630 int32_t result = (int32_t)nearbyintf(inW[idx] / scale + offset);
2631 return (int8_t)MAX(INT8_MIN, MIN(INT8_MAX, result));
2632 }
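// Worked example (illustrative values): with scale = 0.1f and offset = -5, an
// input of 1.23f maps to nearbyintf(1.23f / 0.1f + (-5)) = nearbyintf(7.3f) = 7,
// which already lies inside [INT8_MIN, INT8_MAX] and is returned unchanged.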
2633
2634 int32_t libjit_element_quantize_kernel_i32(dim_t idx, const float *inW,
2635 float scale, int32_t offset) {
2636 int32_t result = (int32_t)nearbyintf(inW[idx] / scale + offset);
2637 return result;
2638 }
2639
2640 float libjit_element_dequantize_kernel_f(dim_t idx, const int8_t *inW,
2641 float scale, int32_t offset) {
2642 return scale * (inW[idx] - offset);
2643 }
2644
2645 int8_t libjit_element_rescale_kernel_i8(dim_t idx, const int8_t *inW,
2646 int32_t outOffset, int32_t inOffset,
2647 int32_t pre, int32_t post,
2648 int32_t scale) {
2649 int32_t s =
2650 libjit_scale_i32i8(inW[idx] - inOffset, pre, post, scale, outOffset);
2651 return libjit_clip(s);
2652 }
2653
2654 void libjit_softmax_f(const float *inW, float *outW, const dim_t *idim,
2655 const dim_t *odim) {
2656 for (dim_t n = 0; n < idim[0]; n++) {
2657 float max = inW[libjit_getXY(idim, n, 0)];
2658
2659 // Find Max.
2660 for (dim_t i = 1; i < idim[1]; i++) {
2661 max = MAX(max, inW[libjit_getXY(idim, n, i)]);
2662 }
2663
2664 float sum = 0;
2665
2666 // Compute exp.
2667 for (dim_t i = 0; i < idim[1]; i++) {
2668 float e = expf(inW[libjit_getXY(idim, n, i)] - max);
2669 sum += e;
2670 outW[libjit_getXY(odim, n, i)] = e;
2671 }
2672
2673 // Normalize the output.
2674 for (dim_t i = 0; i < idim[1]; i++) {
2675 outW[libjit_getXY(odim, n, i)] = outW[libjit_getXY(odim, n, i)] / sum;
2676 }
2677 } // N
2678 }
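// The kernel above computes a numerically stable softmax row by row:
//
//   out[n][i] = exp(in[n][i] - m) / sum_k exp(in[n][k] - m),  m = max_j in[n][j]
//
// Subtracting the row maximum leaves the result unchanged (the exp(-m) factor
// cancels between numerator and denominator) while keeping expf() away from
// overflow for large inputs.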
2679
2680 void libjit_softmax_grad_f_u(float *inG, float *outW, const size_t *selectedW,
2681 const dim_t *idim, const dim_t *selectdim) {
2682 libjit_softmax_grad_generic(inG, outW, selectedW, idim, selectdim);
2683 }
2684
2685 void libjit_softmax_grad_f_i32(float *inG, float *outW,
2686 const int32_t *selectedW, const dim_t *idim,
2687 const dim_t *selectdim) {
2688 libjit_softmax_grad_generic(inG, outW, selectedW, idim, selectdim);
2689 }
2690
2691 void libjit_topk_f_u(float *values, size_t *indices, const float *input,
2692 void *scratch, dim_t k, dim_t n, dim_t size) {
2693 libjit_topk(values, indices, input, scratch, k, n, size);
2694 }
2695
2696 void libjit_topk_f_i32(float *values, int32_t *indices, const float *input,
2697 void *scratch, dim_t k, dim_t n, dim_t size) {
2698 libjit_topk(values, indices, input, scratch, k, n, size);
2699 }
2700
2701 void libjit_topk_i8_u(int8_t *values, size_t *indices, const int8_t *input,
2702 void *scratch, dim_t k, dim_t n, dim_t size) {
2703 libjit_topk(values, indices, input, scratch, k, n, size);
2704 }
2705
2706 void libjit_topk_i8_i32(int8_t *values, int32_t *indices, const int8_t *input,
2707 void *scratch, dim_t k, dim_t n, dim_t size) {
2708 libjit_topk(values, indices, input, scratch, k, n, size);
2709 }
2710
2711 void libjit_transpose_i8(const int8_t *inW, int8_t *outW, const dim_t *idim,
2712 const dim_t *odim, const dim_t *shuffle,
2713 dim_t numDims) {
2714 libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
2715 }
2716
2717 void libjit_transpose_f(const float *inW, float *outW, const dim_t *idim,
2718 const dim_t *odim, const dim_t *shuffle,
2719 dim_t numDims) {
2720 libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
2721 }
2722
2723 void libjit_transpose_u(const int64_t *inW, int64_t *outW, const dim_t *idim,
2724 const dim_t *odim, const dim_t *shuffle,
2725 dim_t numDims) {
2726 libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
2727 }
2728
2729 void libjit_transpose_b(const bool *inW, bool *outW, const dim_t *idim,
2730 const dim_t *odim, const dim_t *shuffle,
2731 dim_t numDims) {
2732 libjit_transpose_generic(inW, outW, idim, odim, shuffle, numDims);
2733 }
2734
2735 void libjit_flip_i8(const int8_t *inW, int8_t *outW, const dim_t *dims,
2736 dim_t axis, dim_t numDims) {
2737 libjit_flip_generic(inW, outW, dims, axis, numDims);
2738 }
2739
2740 void libjit_flip_i16(const int16_t *inW, int16_t *outW, const dim_t *dims,
2741 dim_t axis, dim_t numDims) {
2742 libjit_flip_generic(inW, outW, dims, axis, numDims);
2743 }
2744
2745 void libjit_flip_i32(const int32_t *inW, int32_t *outW, const dim_t *dims,
2746 dim_t axis, dim_t numDims) {
2747 libjit_flip_generic(inW, outW, dims, axis, numDims);
2748 }
2749
2750 void libjit_flip_u(const int64_t *inW, int64_t *outW, const dim_t *dims,
2751 dim_t axis, dim_t numDims) {
2752 libjit_flip_generic(inW, outW, dims, axis, numDims);
2753 }
2754
2755 void libjit_flip_f(const float *inW, float *outW, const dim_t *dims, dim_t axis,
2756 dim_t numDims) {
2757 libjit_flip_generic(inW, outW, dims, axis, numDims);
2758 }
2759
2760 void libjit_flip_b(const bool *inW, bool *outW, const dim_t *dims, dim_t axis,
2761 dim_t numDims) {
2762 libjit_flip_generic(inW, outW, dims, axis, numDims);
2763 }
2764
2765 void libjit_insert_tensor_f(float *tensor, float *slice, dim_t *offset,
2766 dim_t *tensorDim, dim_t *sliceDim,
2767 dim_t numDimsTensor, dim_t numDimsSlice,
2768 dim_t offsetDim, dim_t count, dim_t axis) {
2769 libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
2770 numDimsTensor, numDimsSlice, offsetDim, count, axis);
2771 }
2772
2773 void libjit_insert_tensor_i32(int32_t *tensor, int32_t *slice, dim_t *offset,
2774 dim_t *tensorDim, dim_t *sliceDim,
2775 dim_t numDimsTensor, dim_t numDimsSlice,
2776 dim_t offsetDim, dim_t count, dim_t axis) {
2777 libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
2778 numDimsTensor, numDimsSlice, offsetDim, count, axis);
2779 }
2780
2781 void libjit_extract_tensor_f(float *tensor, float *slice, dim_t *offset,
2782 dim_t *tensorDim, dim_t *sliceDim,
2783 dim_t numDimsTensor, dim_t numDimsSlice,
2784 dim_t offsetDim) {
2785 libjit_extract_tensor(tensor, slice, offset, tensorDim, sliceDim,
2786 numDimsTensor, numDimsSlice, offsetDim);
2787 }
2788
2789 void libjit_extract_tensor_i8(int8_t *tensor, int8_t *slice, dim_t *offset,
2790 dim_t *tensorDim, dim_t *sliceDim,
2791 dim_t numDimsTensor, dim_t numDimsSlice,
2792 dim_t offsetDim) {
2793 libjit_extract_tensor(tensor, slice, offset, tensorDim, sliceDim,
2794 numDimsTensor, numDimsSlice, offsetDim);
2795 }
2796
2797 void libjit_extract_tensor_i32(int32_t *tensor, int32_t *slice, dim_t *offset,
2798 dim_t *tensorDim, dim_t *sliceDim,
2799 dim_t numDimsTensor, dim_t numDimsSlice,
2800 dim_t offsetDim) {
2801 libjit_extract_tensor(tensor, slice, offset, tensorDim, sliceDim,
2802 numDimsTensor, numDimsSlice, offsetDim);
2803 }
2804
2805 void libjit_insert_tensor_u(int64_t *tensor, int64_t *slice, dim_t *offset,
2806 dim_t *tensorDim, dim_t *sliceDim,
2807 dim_t numDimsTensor, dim_t numDimsSlice,
2808 dim_t offsetDim, dim_t count, dim_t axis) {
2809 libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
2810 numDimsTensor, numDimsSlice, offsetDim, count, axis);
2811 }
2812
2813 void libjit_extract_tensor_u(int64_t *tensor, int64_t *slice, dim_t *offset,
2814 dim_t *tensorDim, dim_t *sliceDim,
2815 dim_t numDimsTensor, dim_t numDimsSlice,
2816 dim_t offsetDim) {
2817 libjit_extract_tensor(tensor, slice, offset, tensorDim, sliceDim,
2818 numDimsTensor, numDimsSlice, offsetDim);
2819 }
2820
2821 void libjit_insert_tensor_i8(int8_t *tensor, int8_t *slice, dim_t *offset,
2822 dim_t *tensorDim, dim_t *sliceDim,
2823 dim_t numDimsTensor, dim_t numDimsSlice,
2824 dim_t offsetDim, dim_t count, dim_t axis) {
2825 libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
2826 numDimsTensor, numDimsSlice, offsetDim, count, axis);
2827 }
2828
2829 void libjit_insert_tensor_b(int8_t *tensor, int8_t *slice, dim_t *offset,
2830 dim_t *tensorDim, dim_t *sliceDim,
2831 dim_t numDimsTensor, dim_t numDimsSlice,
2832 dim_t offsetDim, dim_t count, dim_t axis) {
2833 libjit_insert_tensor(tensor, slice, offset, tensorDim, sliceDim,
2834 numDimsTensor, numDimsSlice, offsetDim, count, axis);
2835 }
2836
2837 void libjit_space_to_depth_f(const float *inTensor, float *outTensor,
2838 dim_t blockSize, const dim_t *inDims,
2839 const dim_t *outDims) {
2840 libjit_space_to_depth_generic(inTensor, outTensor, blockSize, inDims,
2841 outDims);
2842 }
2843
2844 void libjit_space_to_depth_i8(const int8_t *inTensor, int8_t *outTensor,
2845 dim_t blockSize, const dim_t *inDims,
2846 const dim_t *outDims) {
2847 libjit_space_to_depth_generic(inTensor, outTensor, blockSize, inDims,
2848 outDims);
2849 }
2850
2851 /// Function to dump a tensor in text format to the console.
2852 __attribute__((noinline)) void libjit_dump_tensor_console(uint8_t *tensor,
2853 dim_t *tensorDim,
2854 dim_t numDimsTensor,
2855 dim_t elemKind,
2856 const char *name) {
2857 printf("%s\n", name);
2858 /// This definition should match the definition in Glow.
2859 enum class ElemKind : unsigned char {
2860 FloatTy, // 32-bit float type (float)
2861 Float16Ty, // 16-bit float type (half, fp16)
2862 BFloat16Ty, // 16-bit float type (bfloat16)
2863 Int8QTy, // 8-bit quantized type (int8_t)
2864 UInt8QTy, // unsigned 8-bit quantized type (uint8_t)
2865 Int16QTy, // 16-bit quantized type (int16_t)
2866 Int32QTy, // 32-bit quantized type (int32_t)
2867 Int32ITy, // 32-bit index type (int32_t)
2868 Int64ITy, // 64-bit index type (int64_t)
2869 UInt8FusedQTy, // 8-bit quantized type with fused scale/offset (uint8_t)
2870 BoolTy, // Bool type (bool)
2871 };
2872 // Dump the content of a tensor.
2873 switch ((ElemKind)elemKind) {
2874 case ElemKind::FloatTy:
2875 libjit_dump_tensor_console_impl((float *)tensor, tensorDim, numDimsTensor);
2876 break;
2877 case ElemKind::Int64ITy:
2878 libjit_dump_tensor_console_impl((dim_t *)tensor, tensorDim, numDimsTensor);
2879 break;
2880 case ElemKind::Int8QTy:
2881 libjit_dump_tensor_console_impl((int8_t *)tensor, tensorDim, numDimsTensor);
2882 break;
2883 case ElemKind::Int32QTy:
2884 libjit_dump_tensor_console_impl((int32_t *)tensor, tensorDim,
2885 numDimsTensor);
2886 break;
2887 default:
2888 printf("Dumping this type of payload is not supported: %zu\n",
2889 (size_t)elemKind);
2890 break;
2891 }
2892 puts("");
2893 }
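// Minimal usage sketch (illustration only; relies on ElemKind::FloatTy being
// the first enumerator, i.e. 0, as defined above):
//
//   float data[4] = {1.f, 2.f, 3.f, 4.f};
//   dim_t dims[2] = {2, 2};
//   libjit_dump_tensor_console((uint8_t *)data, dims, 2, 0, "my_tensor");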
2894
2895 /// Function to dump a tensor in binary format in a file using the raw tensor
2896 /// data pointer \p tensor, the tensor data size \p tensorSize (in bytes) and
2897 /// the file name \p filename. A text header \p header will also be dumped.
2898 __attribute__((noinline)) void libjit_dump_tensor_bin(uint8_t *tensor,
2899 size_t tensorSize,
2900 const char *filename,
2901 const char *header) {
2902 FILE *fh = fopen(filename, "wb");
2903 if (!fh) {
2904 printf("ERROR opening file: '%s'!\n"
2905 "File name might be too long!\n",
2906 filename);
2907 return;
2908 }
2909 // Dump header.
2910 fprintf(fh, "%s", header);
2911 // Dump tensor data.
2912 size_t size = fwrite(tensor, 1, tensorSize, fh);
2913 assert((size == tensorSize) && "Error dumping tensor to file!");
2914 (void)size;
2915 fclose(fh);
2916 }
2917
2918 /// Functions to dump a tensor in text format in a file using the raw tensor
2919 /// data pointer \p tensor, the tensor data size \p tensorElemSize (number of
2920 /// elements) and the file name \p filename. A text header \p header will also
2921 /// be dumped.
2922 #define DEFINE_DUMP_TENSOR_TXT_KERNEL(type, suffix) \
2923 __attribute__((noinline)) void libjit_dump_tensor_txt_##suffix( \
2924 uint8_t *tensor, size_t tensorElemSize, const char *filename, \
2925 const char *header) { \
2926 libjit_dump_tensor_txt_impl((type *)tensor, tensorElemSize, filename, \
2927 header); \
2928 }
2929 DEFINE_DUMP_TENSOR_TXT_KERNEL(float, f)
2930 DEFINE_DUMP_TENSOR_TXT_KERNEL(int8_t, i8)
2931 DEFINE_DUMP_TENSOR_TXT_KERNEL(int16_t, i16)
2932 DEFINE_DUMP_TENSOR_TXT_KERNEL(int32_t, i32)
2933 DEFINE_DUMP_TENSOR_TXT_KERNEL(int64_t, u)
2934 DEFINE_DUMP_TENSOR_TXT_KERNEL(bool, b)
2935 #undef DEFINE_DUMP_TENSOR_TXT_KERNEL
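// For reference, DEFINE_DUMP_TENSOR_TXT_KERNEL(float, f) above expands to:
//
//   __attribute__((noinline)) void libjit_dump_tensor_txt_f(
//       uint8_t *tensor, size_t tensorElemSize, const char *filename,
//       const char *header) {
//     libjit_dump_tensor_txt_impl((float *)tensor, tensorElemSize, filename,
//                                 header);
//   }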
2936
2937 void libjit_write_timestamp(uint64_t *tensor, dim_t offset) {
2938 // We use the C++ <chrono> timer here to avoid issues with gettimeofday.
2939 // Issue #2397 covers migrating this to a libc approach; if you run into
2940 // missing C++ symbols at runtime, check there first.
2941 uint64_t ts = std::chrono::duration_cast<std::chrono::microseconds>(
2942 std::chrono::steady_clock::now().time_since_epoch())
2943 .count();
2944 memcpy(tensor + offset, &ts, sizeof(uint64_t));
2945 }
2946
2947 /// Copy kernels with element type conversion.
2948 void libjit_convertTo_f_b(float *dstPtr, const bool *srcPtr, const dim_t *dims,
2949 dim_t numDims) {
2950 libjit_copy_kernel_with_conversion<float, bool>(dstPtr, srcPtr, dims,
2951 numDims);
2952 }
2953
2954 void libjit_convertTo_b_f(bool *dstPtr, const float *srcPtr, const dim_t *dims,
2955 dim_t numDims) {
2956 libjit_copy_kernel_with_conversion<bool, float>(dstPtr, srcPtr, dims,
2957 numDims);
2958 }
2959
2960 void libjit_convertTo_f_i32(float *dstPtr, const int32_t *srcPtr,
2961 const dim_t *dims, dim_t numDims) {
2962 libjit_copy_kernel_with_conversion<float, int32_t>(dstPtr, srcPtr, dims,
2963 numDims);
2964 }
2965
2966 void libjit_convertTo_i32_u(int32_t *dstPtr, const int64_t *srcPtr,
2967 const dim_t *dims, dim_t numDims) {
2968 libjit_copy_kernel_with_conversion<int32_t, int64_t>(dstPtr, srcPtr, dims,
2969 numDims);
2970 }
2971
2972 void libjit_convertTo_u_i32(int64_t *dstPtr, const int32_t *srcPtr,
2973 const dim_t *dims, dim_t numDims) {
2974 libjit_copy_kernel_with_conversion<int64_t, int32_t>(dstPtr, srcPtr, dims,
2975 numDims);
2976 }
2977
2978 /// Update min/max values \p compInfo and histogram \p existingHistogram with
2979 /// data collected from tensor \p inputTensor.
2980 /// Note: code ported from Profile.cpp: generateTensorHistogram
2981 __attribute__((noinline)) void
2982 libjit_quantization_profile(float *inputTensor, dim_t tensorSize,
2983 float *compInfo, float *existingHistogram,
2984 dim_t *histDim) {
2985 dim_t nBins = histDim[0];
2986
2987 // Min/max computed from previous runs. If this is the first run, compInfo is
2988 // expected to be initialized as follows:
2989 // compInfo[0]: std::numeric_limits<float>::max()
2990 // compInfo[1]: std::numeric_limits<float>::lowest()
2991 float min = compInfo[0];
2992 float max = compInfo[1];
2993
2994 // Min/max value for entire current input tensor.
2995 float minInput;
2996 float maxInput;
2997 find_min_max_f(inputTensor, tensorSize, minInput, maxInput);
2998
2999 // Update the global min/max.
3000 float newMin = MIN(minInput, min);
3001 float newMax = MAX(maxInput, max);
3002 compInfo[0] = newMin;
3003 compInfo[1] = newMax;
3004
3005 // If input histogram is empty then return.
3006 if (nBins == 0) {
3007 return;
3008 }
3009
3010 // Initial profile.
3011 if (check_all_zeros(existingHistogram, nBins) == 1) {
3012 min = minInput;
3013 max = maxInput;
3014 }
3015
3016 // If the min/max range changes, the histogram must be rescaled.
3017 if (newMin < min || newMax > max) {
3018 float destBinWidth = (newMax - newMin) / nBins;
3019 float srcBinWidth = (max - min) / nBins;
3020 float scaledHistogram[nBins];
3021 for (dim_t i = 0; i < nBins; ++i) {
3022 scaledHistogram[i] = 0.0f;
3023 }
3024
3025 for (dim_t i = 0; i < nBins; ++i) {
3026 if (existingHistogram[i] == 0)
3027 continue;
3028
3029 float srcBinBegin = min + srcBinWidth * i;
3030 dim_t destBin = (srcBinBegin - newMin) / destBinWidth;
3031 float destBinEnd = newMin + destBinWidth * (destBin + 1);
3032
3033 float srcBinEnd = srcBinBegin + srcBinWidth;
3034 dim_t destBinToVerify = (srcBinEnd - newMin) / destBinWidth;
3035 // Make sure each source bin is mapped to at most 2 destination bins; the
3036 // redistribution percentage below is computed based on that assumption.
3037 assert(destBinToVerify <= destBin + 2);
3038 (void)destBinToVerify;
3039
3040 // Calculate how much we need to redistribute.
3041 uint64_t dstBinCnt = static_cast<uint64_t>(
3042 MIN(static_cast<float>(round((destBinEnd - srcBinBegin) /
3043 srcBinWidth * existingHistogram[i])),
3044 existingHistogram[i]));
3045
3046 dim_t newBin = get_bin(nBins, destBinWidth, newMin, srcBinBegin);
3047 scaledHistogram[newBin] += dstBinCnt;
3048
3049 if (dstBinCnt < existingHistogram[i]) {
3050 dim_t newBin =
3051 get_bin(nBins, destBinWidth, newMin, srcBinBegin + destBinWidth);
3052 scaledHistogram[newBin] += existingHistogram[i] - dstBinCnt;
3053 }
3054 }
3055
3056 // Copy scaled histogram back to the existing histogram.
3057 for (dim_t i = 0, e = nBins; i < e; ++i) {
3058 existingHistogram[i] = scaledHistogram[i];
3059 }
3060
3061 // Update global min and max.
3062 min = newMin;
3063 max = newMax;
3064 }
3065
3066 // Update the histogram with the values of the current input tensor.
3067 float binWidth = (max - min) / nBins;
3068 for (dim_t i = 0, e = tensorSize; i < e; ++i) {
3069 dim_t newBin = get_bin(nBins, binWidth, min, inputTensor[i]);
3070 existingHistogram[newBin]++;
3071 }
3072 }
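// Worked example of the rescaling above (illustration only): suppose
// nBins = 4, the previous range is [0, 4) (srcBinWidth = 1) and the new range
// grows to [0, 8) (destBinWidth = 2). Then old bins 0 and 1 both map into new
// bin 0 and old bins 2 and 3 into new bin 1, with counts split across two
// destination bins only when a source bin straddles a destination boundary.
// The current tensor's values are then accumulated into the rescaled
// histogram using the final binWidth = (max - min) / nBins.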
3073
3074 __attribute__((noinline)) void
3075 libjit_nms_u(uint64_t *indices, uint64_t *numDetected, const float *boxTensor,
3076 const dim_t *boxTensorDims, dim_t boxTensorDimSize,
3077 const float *scoresTensor, const dim_t *scoresTensorDims,
3078 dim_t scoresTensorDimSize, const dim_t *resultTensorDims,
3079 dim_t resultTensorDimSize, unsigned centerPointBox,
3080 unsigned maxOutputBoxesPerClass, float iouThreshold,
3081 float scoreThreshold, bool isV4) {
3082 libjit_nms_generic(indices, numDetected, boxTensor, boxTensorDims,
3083 boxTensorDimSize, scoresTensor, scoresTensorDims,
3084 scoresTensorDimSize, resultTensorDims, resultTensorDimSize,
3085 centerPointBox, maxOutputBoxesPerClass, iouThreshold,
3086 scoreThreshold, isV4);
3087 }
3088
3089 __attribute__((noinline)) void
3090 libjit_nms_i32(int32_t *indices, int32_t *numDetected, const float *boxTensor,
3091 const dim_t *boxTensorDims, dim_t boxTensorDimSize,
3092 const float *scoresTensor, const dim_t *scoresTensorDims,
3093 dim_t scoresTensorDimSize, const dim_t *resultTensorDims,
3094 dim_t resultTensorDimSize, unsigned centerPointBox,
3095 unsigned maxOutputBoxesPerClass, float iouThreshold,
3096 float scoreThreshold, bool isV4) {
3097 libjit_nms_generic(indices, numDetected, boxTensor, boxTensorDims,
3098 boxTensorDimSize, scoresTensor, scoresTensorDims,
3099 scoresTensorDimSize, resultTensorDims, resultTensorDimSize,
3100 centerPointBox, maxOutputBoxesPerClass, iouThreshold,
3101 scoreThreshold, isV4);
3102 }
3103
3104 /// FFT Radix2 DIT (Decimation In Time) implementation for Complex data.
3105 /// The \p input and \p output buffers have 2 * \p fftLength float
3106 /// samples corresponding to \p fftLength complex samples with real and
3107 /// imaginary parts interleaved: real[0], imag[0], real[1], imag[1], ...
3108 /// The lookup tables \p twiddleFactors and \p bitReverseIndices are
3109 /// generated at compile time. The boolean flag \p inPlace decides whether
3110 /// the FFT computation is done in-place (that is in the \p input buffer
3111 /// without writing in the \p output buffer) or out-of-place (written in
3112 /// the \p output buffer).
3113 void libjit_fft_complex_f(float *output, float *input,
3114 const float *twiddleFactors,
3115 const int32_t *bitReverseIndices, unsigned fftLength,
3116 bool inPlace) {
3117
3118 // Bit Reverse Reordering.
3119 if (inPlace) {
3120 for (dim_t idx = 0; idx < fftLength; idx++) {
3121 int32_t bitRevIdx = bitReverseIndices[idx];
3122 if (idx < bitRevIdx) {
3123 // Swap complex pair.
3124 float real = input[2 * idx + 0];
3125 float imag = input[2 * idx + 1];
3126 input[2 * idx + 0] = input[2 * bitRevIdx + 0];
3127 input[2 * idx + 1] = input[2 * bitRevIdx + 1];
3128 input[2 * bitRevIdx + 0] = real;
3129 input[2 * bitRevIdx + 1] = imag;
3130 }
3131 }
3132 } else {
3133 for (dim_t idx = 0; idx < fftLength; idx++) {
3134 int32_t bitRevIdx = bitReverseIndices[idx];
3135 output[2 * idx + 0] = input[2 * bitRevIdx + 0];
3136 output[2 * idx + 1] = input[2 * bitRevIdx + 1];
3137 }
3138 }
3139
3140 // FFT output pointer.
3141 float *bitRevOut = inPlace ? input : output;
3142
3143 // Number of FFT stages.
3144 dim_t stageNum = std::log2((double)fftLength);
3145
3146 // Number of radix2 butterfly groups for 1st stage.
3147 dim_t groupNum = fftLength / 2;
3148
3149 // Number of radix2 butterflies per group for 1st stage.
3150 dim_t groupButterNum = 1;
3151
3152 // Stage loop.
3153 for (dim_t stageIdx = 0; stageIdx < stageNum; stageIdx++) {
3154
3155 // Butterfly input/output pointers.
3156 float *inp1Ptr = bitRevOut + 0 * groupButterNum;
3157 float *inp2Ptr = bitRevOut + 2 * groupButterNum;
3158
3159 // Butterfly group loop.
3160 for (dim_t groupIdx = 0; groupIdx < groupNum; groupIdx++) {
3161
3162 // Twiddle factors pointer.
3163 const float *twPtr = twiddleFactors;
3164
3165 // Butterfly loop within group.
3166 for (dim_t groupButterIdx = 0; groupButterIdx < groupButterNum;
3167 groupButterIdx++) {
3168
3169 // Radix 2 butterfly.
3170 float inp0_re = *inp1Ptr++;
3171 float inp0_im = *inp1Ptr--;
3172 float inp1_re = *inp2Ptr++;
3173 float inp1_im = *inp2Ptr--;
3174
3175 float tw_re = *twPtr++;
3176 float tw_im = *twPtr--;
3177 twPtr += (2 * groupNum);
3178
3179 float inp1_tw_mult_re = inp1_re * tw_re - inp1_im * tw_im;
3180 float inp1_tw_mult_im = inp1_re * tw_im + inp1_im * tw_re;
3181
3182 *inp1Ptr++ = inp0_re + inp1_tw_mult_re;
3183 *inp1Ptr++ = inp0_im + inp1_tw_mult_im;
3184 *inp2Ptr++ = inp0_re - inp1_tw_mult_re;
3185 *inp2Ptr++ = inp0_im - inp1_tw_mult_im;
3186 }
3187
3188 inp1Ptr += 2 * groupButterNum;
3189 inp2Ptr += 2 * groupButterNum;
3190 }
3191
3192 // Update parameters for next stage.
3193 groupNum >>= 1;
3194 groupButterNum <<= 1;
3195 }
3196 }
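// Each radix-2 butterfly above combines two complex inputs using the usual
// DIT twiddle factor W = exp(-j * 2 * pi * k / fftLength):
//
//   out0 = in0 + W * in1
//   out1 = in0 - W * in1
//
// The complex product W * in1 is expanded into the four real multiplies that
// form inp1_tw_mult_re / inp1_tw_mult_im, and log2(fftLength) stages of
// fftLength / 2 butterflies produce the full transform.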
3197
3198 /// FFT Radix2 DIT (Decimation In Time) implementation for Real data.
3199 /// The implementation uses an fftLength/2 FFT for Complex data followed
3200 /// by a step to map the complex FFT to the real FFT by using a set of
3201 /// complex weights \p complexToRealWeights A[k] defined as:
3202 /// A[k] = 1/2 * (1 - j*exp(-j*2*pi*k/N)) for k = 0 .. N/4-1
3203 /// The \p input buffer has \p fftLength float values corresponding
3204 /// to \p fftLength real samples. Since the FFT of a real signal
3205 /// has conjugate symmetry, the \p output buffer only contains
3206 /// 2 * (fftLength/2+1) = fftLength + 2 float values corresponding
3207 /// to fftLength/2+1 complex samples with real and imaginary parts
3208 /// interleaved: real[0], imag[0], real[1], imag[1], ...
3209 /// The lookup tables \p twiddleFactors and \p bitReverseIndices are
3210 /// generated at compile time as if they were generated for an N/2
3211 /// complex FFT. The boolean flag \p inPlace decides whether the FFT
3212 /// computation is done in-place (that is in the \p input buffer
3213 /// without writing in the \p output buffer) or out-of-place (written in
3214 /// the \p output buffer).
3215 void libjit_fft_real_f(float *output, float *input, const float *twiddleFactors,
3216 const int32_t *bitReverseIndices,
3217 const float *complexToRealWeights, unsigned fftLength,
3218 bool inPlace) {
3219
3220 // Perform N/2 complex FFT (in-place or out-of-place).
3221 // G[k] with k = 0 .. N/2-1.
3222 libjit_fft_complex_f(output, input, twiddleFactors, bitReverseIndices,
3223 fftLength / 2, inPlace);
3224
3225 // Complex to Real FFT mapping (in-place).
3226 // X[k] = G[k] * A[k] + conj(G[N/2-k]) * (1 - A[k])
3227 // for k = 0 .. N/2 with the convention G[N/2] = G[0].
3228 // Particular cases:
3229 // real(X[0]) = real(G[0]) + imag(G[0])
3230 // imag(X[0]) = 0
3231 // real(X[N/2]) = real(G[0]) - imag(G[0])
3232 // imag(X[N/2]) = 0
3233 // X[N/4] = conj(G[N/4])
3234
3235 const float *Ak = complexToRealWeights + 2;
3236 float *ptr = inPlace ? input : output;
3237 float *ptr0 = &ptr[0];
3238 float *ptr1 = &ptr[2 * fftLength / 2 + 1];
3239 float inp0_re = *ptr0++;
3240 float inp0_im = *ptr0--;
3241 *ptr0++ = inp0_re + inp0_im;
3242 *ptr0++ = 0;
3243 *ptr1-- = 0;
3244 *ptr1-- = inp0_re - inp0_im;
3245
3246 for (dim_t k = 1; k < fftLength / 4; k++) {
3247
3248 float inp0_re = *ptr0++;
3249 float inp0_im = *ptr0--;
3250 float inp1_im = *ptr1--;
3251 float inp1_re = *ptr1++;
3252
3253 float Ak_re = *Ak++;
3254 float Ak_im = *Ak++;
3255
3256 float dif_re = inp0_re - inp1_re;
3257 float sum_im = inp0_im + inp1_im;
3258 float prod0 = dif_re * Ak_re - sum_im * Ak_im;
3259 float prod1 = dif_re * Ak_im + sum_im * Ak_re;
3260
3261 *ptr0++ = +prod0 + inp1_re;
3262 *ptr0++ = +prod1 - inp1_im;
3263 *ptr1-- = +prod1 - inp0_im;
3264 *ptr1-- = -prod0 + inp0_re;
3265 }
3266
3267 if (fftLength >= 4) {
3268 *ptr1 = -*ptr1;
3269 }
3270 }
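// Sizing note (illustration only): for fftLength = 8 real samples the output
// holds fftLength / 2 + 1 = 5 complex bins, i.e. 10 interleaved floats
// {re[0], im[0], ..., re[4], im[4]}, with im[0] and im[fftLength / 2] forced
// to zero by the special cases handled before the loop above.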
3271
3272 /// Compute the spectrogram for the given 1D mono audio signal \p input.
3273 /// The input windows are weighted using the \p window function and the
3274 /// FFT LUTs \p twiddleFactors and \p bitReverseIndices are computed at
3275 /// compile-time. More details in Graph.h about the AudioSpectrogram node.
3276 void libjit_audio_spectrogram_f(
3277 void *winOutScratch, void *fftOutScratch, float *spectrogram,
3278 const float *input, const float *window, const float *twiddleFactors,
3279 const int32_t *bitReverseIndices, const float *complexToRealWeights,
3280 const dim_t *spectrogramDims, const dim_t inputLength,
3281 const dim_t windowSize, const dim_t windowStride,
3282 const bool magnitudeSquared) {
3283
3284 dim_t winNum = spectrogramDims[0];
3285 dim_t specLen = spectrogramDims[1];
3286 dim_t fftLen = (specLen - 1) * 2;
3287
3288 // Scratch buffers.
3289 float *winOut = (float *)winOutScratch;
3290 float *fftOut = (float *)fftOutScratch;
3291 memset(winOut, 0, fftLen * sizeof(float));
3292
3293 // Compute the spectrogram.
3294 for (dim_t winIdx = 0; winIdx < winNum; winIdx++) {
3295
3296 // Windowing.
3297 for (dim_t n = 0; n < windowSize; n++) {
3298 winOut[n] = input[winIdx * windowStride + n] * window[n];
3299 }
3300
3301 // Compute spectrum (perform FFT for real data).
3302 libjit_fft_real_f(fftOut, winOut, twiddleFactors, bitReverseIndices,
3303 complexToRealWeights, fftLen, false /* inPlace */);
3304
3305 // Compute spectrum magnitude/power.
3306 for (dim_t k = 0; k < specLen; k++) {
3307 float real = fftOut[2 * k + 0];
3308 float imag = fftOut[2 * k + 1];
3309 float power = real * real + imag * imag;
3310 if (magnitudeSquared) {
3311 *spectrogram++ = power;
3312 } else {
3313 *spectrogram++ = std::sqrt(power);
3314 }
3315 }
3316 }
3317 }
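// Shape relationship used above (illustration only): specLen is
// fftLen / 2 + 1, hence fftLen = (specLen - 1) * 2. Each of the winNum
// windows of windowSize samples (assumed to be <= fftLen; the tail of winOut
// stays zero-padded from the memset) produces one row of specLen magnitude
// or power values in the spectrogram.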
3318
3319 /// Compute the MFCC (Mel Frequency Cepstral Coefficient) for the given
3320 /// \p spectrogram power. The lookup tables \p melWeights, \p melRanges
3321 /// and \p dctMat are computed at compile-time. More details in Graph.h
3322 /// about the MFCC node.
3323 void libjit_mfcc_f(void *scratch, float *coefficients, const float *spectrogram,
3324 const float *melWeights, const int32_t *melRanges,
3325 const float *dctMat, const dim_t *coefficientsDims,
3326 const dim_t *spectrogramDims, const dim_t filterBankCount) {
3327
3328 // Scratch buffer.
3329 float *melBuff = (float *)scratch;
3330
3331 // Perform MFCC for all the windows.
3332 dim_t winNum = spectrogramDims[0];
3333 dim_t winSize = spectrogramDims[1];
3334 dim_t numCoefficients = coefficientsDims[1];
3335 for (dim_t winIdx = 0; winIdx < winNum; winIdx++) {
3336
3337 // Pointers backup for this window.
3338 const float *melWeightsPtr = melWeights;
3339 const int32_t *melRangesPtr = melRanges;
3340 const float *dctMatPtr = dctMat;
3341
3342 // Apply Mel filter bank mapping. We use sqrt for the spectrogram since we
3343 // assume the spectrogram is a power value and not a magnitude.
3344 for (dim_t melIdx = 0; melIdx < filterBankCount; melIdx++) {
3345
3346 int32_t freqIdxStart = *melRangesPtr++;
3347 int32_t freqIdxStop = *melRangesPtr++;
3348
3349 // Compute Mel Power.
3350 float melPwr = 0.0f;
3351 for (int32_t freqIdx = freqIdxStart; freqIdx <= freqIdxStop; freqIdx++) {
3352 melPwr += std::sqrt(spectrogram[freqIdx]) * (*melWeightsPtr++);
3353 }
3354
3355 // Take logarithm in-place (avoid log(0)).
3356 melBuff[melIdx] = (melPwr == 0.0)
3357 ? logf(std::numeric_limits<float>::min())
3358 : logf(melPwr);
3359 }
3360
3361 // Compute DCT transform.
3362 for (dim_t k = 0; k < numCoefficients; k++) {
3363 float dctOut = 0.0f;
3364 for (dim_t n = 0; n < filterBankCount; n++) {
3365 dctOut += (*dctMatPtr++) * melBuff[n];
3366 }
3367 *coefficients++ = dctOut;
3368 }
3369
3370 // Go to next spectrogram window.
3371 spectrogram += winSize;
3372 }
3373 }
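// In the kernel above each window is reduced to numCoefficients values with a
// dense DCT applied to the log Mel energies:
//
//   coefficients[k] = sum_n dctMat[k * filterBankCount + n] * melBuff[n]
//
// where melBuff[n] = log(melPwr[n]) and melPwr[n] is the weighted sum of
// sqrt(spectrogram[freqIdx]) * melWeights over the filter's bin range
// [freqIdxStart, freqIdxStop].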
3374 } // extern "C"
3375