1 /**
2  * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "BackendTestUtils.h"
18 
19 #include "glow/ExecutionEngine/ExecutionEngine.h"
20 #include "glow/Exporter/ONNXModelWriter.h"
21 #include "glow/Graph/Graph.h"
22 #include "glow/IR/IR.h"
23 #include "glow/IR/IRBuilder.h"
24 #include "glow/IR/Instrs.h"
25 #include "glow/Importer/ONNXModelLoader.h"
26 #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
27 #include "glow/Quantization/Base/Base.h"
28 
29 #include "llvm/ADT/SmallVector.h"
30 #include "llvm/Support/FileSystem.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 #include <functional>
34 #include <numeric>
35 
36 using namespace glow;
37 
38 class OperatorStatelessTest : public BackendStatelessTest {};
39 
40 class OperatorTest : public BackendTest {
41 protected:
42   PlaceholderBindings bindings_;
  /// Use this for storing tensors that would otherwise be unowned, i.e.
  /// tensors that would normally be stack-local and so could not be read in
  /// TearDown.
45   std::vector<Tensor> unownedTensors_;
  virtual void SetUp() override {
47     // Skip stripping the module so that we can inspect Constants after
48     // compilation.
49     EE_.setSkipModuleStrip(true);
50   }
51 
  virtual void TearDown() override {
53     if (::testing::Test::IsSkipped()) {
54       return;
55     }
56 
57     EXPECT_TRUE(F_->getNodes().size() != 0)
58         << "Functions should have nodes at the end of the test.";
59 
60     ASSERT_TRUE(F_->verify(&EE_.getBackend()))
61         << "Function must pass verification.";
62 
63     // Now export the model to later import it back in.
64     llvm::SmallString<64> path;
65     auto tempFileRes =
66         llvm::sys::fs::createTemporaryFile("exporter", "output.onnxtxt", path);
67     ASSERT_EQ(tempFileRes.value(), 0)
68         << "Failed to create temp file to write into.";
69     std::string pathToModel(path.c_str());
70 
71     Error err = Error::empty();
72     ONNXModelWriter onnxWR(pathToModel, *F_, 7, 9, &err,
73                            /* textMode */ true, /* zipMode */ false,
74                            /* useGlowCustomOps */ true);
75     ASSERT_FALSE(ERR_TO_BOOL(std::move(err))) << "Error exporting model";
76 
    // Now that we've exported, load the model back into a new module/function,
    // run it, and compare its results against those from the original run.
79     PlaceholderBindings loadedBindings;
80     ExecutionEngine loadedEE{getBackendName()};
81     Module &loadedMod = loadedEE.getModule();
82     Function *loadedF = loadedMod.createFunction(F_->getName());
83     {
84       Error err = Error::empty();
      // Note: We disable constant folding here because it is only needed to
      // calculate shapes that result from constant compute in the proto, which
      // is not the case when exporting with useGlowCustomOps.
88       ONNXModelLoader onnxLD(pathToModel, {}, {}, *loadedF, &err,
89                              /* zipMode */ false, /* perNodeOpts */ nullptr,
90                              /* disableConstFoldInLoader */ true,
91                              /* loadIntoExistingModule */ false,
92                              &loadedEE.getBackend());
93       if (ERR_TO_BOOL(std::move(err))) {
94         llvm::sys::fs::remove(pathToModel);
95         FAIL() << "Error loading exported model";
96       }
97     }
98 
    // Note that we use the backend for verification here, because the function
    // has already been through the optimization pipeline and so has
    // backend-specific requirements built in, e.g. for required layout.
102     ASSERT_TRUE(loadedF->verify(&loadedEE.getBackend()))
103         << "Loaded Function must pass verification";
104 
105     // String representations of original and loaded functions must be the same.
106     // Note that we skip printing users for Storage because some tests have
107     // other Functions sharing Storage for testing purposes.
108     EXPECT_EQ(F_->toString(/* skipUsersForStorage */ true),
109               loadedF->toString(/* skipUsersForStorage */ true));
110 
111     // Copy over inputs from previous bindings to newly loaded bindings. We have
112     // new Placeholders so can't reuse the bindings from before.
113     for (const auto &p : bindings_.pairs()) {
114       if (!isInput(p.first, *F_)) {
115         continue;
116       }
117 
118       // Look for an input PH by the same name as the original Function.
119       Placeholder *inputPH =
120           loadedMod.getPlaceholderByNameSlow(p.first->getName());
121       ASSERT_TRUE(inputPH);
122       loadedBindings.insert(inputPH, p.second.getUnowned(inputPH->dims()));
123     }
124 
125     // Allocate all other PHs/tensors that need it (i.e. result PHs/tensors).
126     loadedBindings.allocate(loadedF->findPlaceholders());
127 
128     // Skip the optimization pipeline for loadedF (via onlyLowerFuns), as we
129     // already passed it through the optimization pipeline before exporting it.
130     CompilationContext cctx;
131     cctx.optimizationOpts.onlyLowerFuns.insert(loadedF);
132     loadedEE.compile(cctx);
133     loadedEE.run(loadedBindings);
134 
135     // Now bitwise-equal compare result tensors from before and after.
136     for (const auto &p : bindings_.pairs()) {
137       const Placeholder *resultPH = p.first;
138       if (!isOutput(resultPH, *F_)) {
139         continue;
140       }
141       const Tensor &resultT = p.second;
142 
143       // Find the result PH by the same name in the loaded Function.
144       Placeholder *loadedResultPH =
145           loadedMod.getPlaceholderByNameSlow(resultPH->getName());
146       ASSERT_TRUE(loadedResultPH);
147       const Tensor *loadedResultT = loadedBindings.get(loadedResultPH);
148 
149       EXPECT_TRUE(resultT.isBitwiseEqual(*loadedResultT, /* verbose */ true));
150     }
151 
152     llvm::sys::fs::remove(pathToModel);
153   }
154 };
155 
156 /// Helper to create a Placeholder; if \p T is quantized, then it will include a
157 /// dummy scale and offset, otherwise it will not.
static Placeholder *createPlaceholderConditionallyQuantized(
159     Module &mod, ElemKind T, llvm::ArrayRef<dim_t> dims, llvm::StringRef name,
160     bool isTrainable, llvm::StringRef layout = ANY_LAYOUT) {
161   return isQuantizedElemKind(T)
162              ? mod.createPlaceholder(T, dims, 1.0, 0, name, isTrainable, layout)
163              : mod.createPlaceholder(T, dims, name, isTrainable, layout);
164 }
165 
166 /// Helper to get a unique Type; if \p T is quantized, then it will include a
167 /// dummy scale and offset, otherwise it will not.
static TypeRef uniqueTypeConditionallyQuantized(Module &mod, ElemKind T,
169                                                 llvm::ArrayRef<dim_t> dims) {
170   return isQuantizedElemKind(T) ? mod.uniqueType(T, dims, 1.0, 0)
171                                 : mod.uniqueType(T, dims);
172 }
173 
174 /// Helper to create a Tensor; if \p T is quantized, then it will include a
175 /// dummy scale and offset, otherwise it will not.
static Tensor createTensorConditionallyQuantized(ElemKind T,
177                                                  llvm::ArrayRef<dim_t> dims) {
178   return isQuantizedElemKind(T) ? Tensor(T, dims, 1.0, 0) : Tensor(T, dims);
179 }
180 
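/// Helper to test the CmpLT (less-than) operator with broadcast: creates
/// Placeholders of element type \p DTy with shapes \p xDims and \p yDims,
/// fills them with \p xValues and \p yValues, builds a broadcasted CmpLTNode,
/// runs inference, and returns a Handle to the boolean result tensor.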
181 template <typename DataType>
182 glow::Handle<bool>
lessHelper(glow::PlaceholderBindings &bindings, glow::Module &mod,
184            glow::Function *F, glow::ExecutionEngine &EE, ElemKind DTy,
185            llvm::ArrayRef<DataType> xValues, llvm::ArrayRef<DataType> yValues,
186            llvm::ArrayRef<dim_t> xDims, llvm::ArrayRef<dim_t> yDims) {
187   auto *X = createPlaceholderConditionallyQuantized(mod, DTy, xDims, "X",
188                                                     /* isTrainable */ false);
189 
190   auto *Y = createPlaceholderConditionallyQuantized(mod, DTy, yDims, "Y",
191                                                     /* isTrainable */ false);
192 
193   bindings.allocate(llvm::dyn_cast<Placeholder>(X))->getHandle<DataType>() =
194       xValues;
195 
196   bindings.allocate(llvm::dyn_cast<Placeholder>(Y))->getHandle<DataType>() =
197       yValues;
198 
199   auto *cmpr =
200       F->createNodeWithBroadcast<CmpLTNode>("cmpLT", /* axis */ -1, X, Y);
201 
202   auto *save = F->createSave("save", cmpr);
203   auto *saveAlloc = bindings.allocate(save->getPlaceholder());
204 
205   EE.compile(CompilationMode::Infer);
206   EE.run(bindings);
207 
208   return saveAlloc->getHandle<bool>();
209 }
210 
TEST_P(OperatorTest, less_int8) {
212   CHECK_IF_ENABLED();
213 
214   int8_t xValues[] = {3, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
215 
216                       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
217 
218                       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
219 
220                       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1};
221 
222   int8_t yValues[] = {3, 4, 5, 7, 2, 5, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
223 
224                       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
225 
226                       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
227 
228                       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6};
229 
230   dim_t xDims[] = {2, 2, 4, 4};
231   dim_t yDims[] = {2, 2, 4, 4};
232 
233   Handle<bool> saveH =
234       lessHelper<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy, xValues,
235                          yValues, xDims, yDims);
236 
237   bool refResults[] = {
238       false, true,  true,  true, false, false, false, true,
239       false, false, false, true, true,  true,  false, true,
240 
241       true,  true,  true,  true, false, false, false, true,
242       false, false, false, true, true,  true,  false, true,
243 
244       true,  true,  true,  true, false, false, false, true,
245       false, false, false, true, true,  true,  false, true,
246 
247       true,  true,  true,  true, false, false, false, true,
248       false, false, false, true, true,  true,  false, true,
249   };
250 
251   int counter = 0;
252   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
253     for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
254       for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
255         for (dim_t f = 0; f < saveH.dims()[3]; ++f) {
256           EXPECT_FLOAT_EQ(refResults[counter++], saveH.at({i, j, k, f}));
257         }
258       }
259     }
260   }
261 }
262 
TEST_P(OperatorTest, less_floatCases) {
264   CHECK_IF_ENABLED();
265 
266   float xValues[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
267 
268   float yValues[] = {5.0f, 4.0f, 3.0f, 2.0f, 1.0f};
269 
270   dim_t xDims[] = {5};
271   dim_t yDims[] = {5};
272 
273   Handle<bool> saveH =
274       lessHelper<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, xValues,
275                         yValues, xDims, yDims);
276 
277   bool refResults[] = {true, true, false, false, false};
278 
279   int counter = 0;
280   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
281     EXPECT_FLOAT_EQ(refResults[counter++], saveH.at({i}));
282   }
283 }
284 
TEST_P(OperatorTest, less_float16Cases) {
286   CHECK_IF_ENABLED();
287 
288   float16 xValues[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
289 
290   float16 yValues[] = {5.0f, 4.0f, 3.0f, 2.0f, 1.0f};
291 
292   dim_t xDims[] = {5};
293   dim_t yDims[] = {5};
294 
295   Handle<bool> saveH =
296       lessHelper<float16>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
297                           xValues, yValues, xDims, yDims);
298 
299   bool refResults[] = {true, true, false, false, false};
300 
301   int counter = 0;
302   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
303     EXPECT_FLOAT_EQ(refResults[counter++], saveH.at({i}));
304   }
305 }
306 
TEST_P(OperatorTest, less_bfloat16Cases) {
308   CHECK_IF_ENABLED();
309 
310   bfloat16 xValues[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
311 
312   bfloat16 yValues[] = {5.0f, 4.0f, 3.0f, 2.0f, 1.0f};
313 
314   dim_t xDims[] = {5};
315   dim_t yDims[] = {5};
316 
317   Handle<bool> saveH =
318       lessHelper<bfloat16>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
319                            xValues, yValues, xDims, yDims);
320 
321   bool refResults[] = {true, true, false, false, false};
322 
323   int counter = 0;
324   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
325     EXPECT_FLOAT_EQ(refResults[counter++], saveH.at({i}));
326   }
327 }
328 
TEST_P(OperatorTest, less_int64Cases) {
330   CHECK_IF_ENABLED();
331 
332   int64_t xValues[] = {1, 2, 3, 4, 5};
333 
334   int64_t yValues[] = {5, 4, 3, 2, 1};
335 
336   dim_t xDims[] = {5};
337   dim_t yDims[] = {5};
338 
339   Handle<bool> saveH =
340       lessHelper<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy, xValues,
341                           yValues, xDims, yDims);
342 
343   bool refResults[] = {true, true, false, false, false};
344 
345   int counter = 0;
346   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
347     EXPECT_FLOAT_EQ(refResults[counter++], saveH.at({i}));
348   }
349 }
350 
TEST_P(OperatorTest, less_float) {
352   CHECK_IF_ENABLED();
353 
354   float xValues[] = {1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
355                      7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f,
356 
357                      1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
358                      7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f,
359 
360                      1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
361                      7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f,
362 
363                      1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
364                      7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f};
365 
366   float yValues[] = {3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
367                      4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f,
368 
369                      3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
370                      4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f,
371 
372                      3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
373                      4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f,
374 
375                      3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
376                      4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f};
377 
378   dim_t xDims[] = {2, 2, 4, 4};
379   dim_t yDims[] = {2, 2, 4, 4};
380 
381   Handle<bool> saveH =
382       lessHelper<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, xValues,
383                         yValues, xDims, yDims);
384 
385   bool refResults[] = {
386       true,  true,  true,  true, false, false, false, true,
387       false, false, false, true, true,  true,  false, true,
388 
389       true,  true,  true,  true, false, false, false, true,
390       false, false, false, true, true,  true,  false, true,
391 
392       true,  true,  true,  true, false, false, false, true,
393       false, false, false, true, true,  true,  false, true,
394 
395       true,  true,  true,  true, false, false, false, true,
396       false, false, false, true, true,  true,  false, true,
397   };
398 
399   int counter = 0;
400   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
401     for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
402       for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
403         for (dim_t f = 0; f < saveH.dims()[3]; ++f) {
404           EXPECT_FLOAT_EQ(refResults[counter++], saveH.at({i, j, k, f}));
405         }
406       }
407     }
408   }
409 }
410 
TEST_P(OperatorTest, less_broadcast_float) {
412   CHECK_IF_ENABLED();
413 
414   float xValues[] = {1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
415                      7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f,
416 
417                      1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
418                      7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f,
419 
420                      1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
421                      7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f,
422 
423                      1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
424                      7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f};
425 
426   float yValues[] = {3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
427                      4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f,
428 
429                      3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
430                      4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f};
431 
432   dim_t xDims[] = {2, 2, 4, 4};
433   dim_t yDims[] = {1, 2, 4, 4};
434 
435   Handle<bool> saveH =
436       lessHelper<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, xValues,
437                         yValues, xDims, yDims);
438 
439   bool refResults[] = {true,  true,  true,  true, false, false, false, true,
440                        false, false, false, true, true,  true,  false, true,
441 
442                        true,  true,  true,  true, false, false, false, true,
443                        false, false, false, true, true,  true,  false, true,
444 
445                        true,  true,  true,  true, false, false, false, true,
446                        false, false, false, true, true,  true,  false, true,
447 
448                        true,  true,  true,  true, false, false, false, true,
449                        false, false, false, true, true,  true,  false, true};
450 
451   int counter = 0;
452   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
453     for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
454       for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
455         for (dim_t f = 0; f < saveH.dims()[3]; ++f) {
456           EXPECT_FLOAT_EQ(refResults[counter++], saveH.at({i, j, k, f}));
457         }
458       }
459     }
460   }
461 }
462 
TEST_P(OperatorTest, less_int32Cases) {
464   CHECK_IF_ENABLED();
465 
466   int32_t xValues[] = {1, 2, 3, 4, 5};
467   int32_t yValues[] = {5, 4, 3, 2, 1};
468 
469   dim_t xDims[] = {1, 1, 1, 5};
470   dim_t yDims[] = {1, 1, 1, 5};
471 
472   Handle<bool> saveH =
473       lessHelper<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy, xValues,
474                           yValues, xDims, yDims);
475 
476   bool refResults[] = {true, true, false, false, false};
477 
478   int counter = 0;
479   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
480     for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
481       for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
482         for (dim_t f = 0; f < saveH.dims()[3]; ++f) {
483           EXPECT_FLOAT_EQ(refResults[counter++], saveH.at({i, j, k, f}));
484         }
485       }
486     }
487   }
488 }
489 
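/// Helper to test the Select (where) operator with broadcast: selects
/// elementwise from \p xValues or \p yValues based on the boolean condition
/// \p cValues, using shapes \p xDims, \p yDims, and \p cDims, runs inference,
/// and returns a Handle to the result tensor.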
490 template <typename DataType>
491 glow::Handle<DataType>
whereHelper(glow::PlaceholderBindings &bindings, glow::Module &mod,
493             glow::Function *F, glow::ExecutionEngine &EE, ElemKind DTy,
494             llvm::ArrayRef<DataType> xValues, llvm::ArrayRef<DataType> yValues,
495             llvm::ArrayRef<bool> cValues, llvm::ArrayRef<dim_t> xDims,
496             llvm::ArrayRef<dim_t> yDims, llvm::ArrayRef<dim_t> cDims) {
497   auto *cond = createPlaceholderConditionallyQuantized(mod, ElemKind::BoolTy,
498                                                        cDims, "cond", false);
499   auto *X = createPlaceholderConditionallyQuantized(mod, DTy, xDims, "X",
500                                                     DTy != ElemKind::FloatTy);
501 
502   auto *Y = createPlaceholderConditionallyQuantized(mod, DTy, yDims, "Y",
503                                                     DTy != ElemKind::FloatTy);
504 
505   bindings.allocate(llvm::dyn_cast<Placeholder>(cond))->getHandle<bool>() =
506       cValues;
507 
508   bindings.allocate(llvm::dyn_cast<Placeholder>(X))->getHandle<DataType>() =
509       xValues;
510 
511   bindings.allocate(llvm::dyn_cast<Placeholder>(Y))->getHandle<DataType>() =
512       yValues;
513 
514   auto *whr = F->createNodeWithBroadcast<SelectNode>("Select", /* axis */ -1,
515                                                      cond, X, Y);
516 
517   auto *save = F->createSave("save", whr);
518   auto *saveAlloc = bindings.allocate(save->getPlaceholder());
519 
520   EE.compile(CompilationMode::Infer);
521   EE.run(bindings);
522 
523   return saveAlloc->getHandle<DataType>();
524 }
525 
TEST_P(OperatorTest, where_2d_broadcast_x_y_i8) {
527   CHECK_IF_ENABLED();
528   llvm::SmallVector<int8_t, 16> xValues = {3, 5, 7};
529 
530   llvm::SmallVector<int8_t, 16> yValues = {2, 4, 6};
531 
532   llvm::SmallVector<bool, 4> cValues = {1, 0, 1};
533 
534   llvm::SmallVector<dim_t, 4> condDims = {3, 1, 1};
535 
536   llvm::SmallVector<dim_t, 4> xDims = {1, 3, 1};
537   llvm::SmallVector<dim_t, 4> yDims = {3, 1, 1};
538 
539   Handle<int8_t> saveH =
540       whereHelper<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy, xValues,
541                           yValues, cValues, xDims, yDims, condDims);
542 
543   llvm::SmallVector<int8_t, 16> refResults = {3, 5, 7, 4, 4, 4, 3, 5, 7};
544 
545   int counter = 0;
546   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
547     for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
548       for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
549         EXPECT_EQ(refResults[counter++], saveH.at({i, j, k}));
550       }
551     }
552   }
553 }
554 
TEST_P(OperatorTest, where_2d_wise_i8) {
556   CHECK_IF_ENABLED();
557   llvm::SmallVector<int8_t, 16> xValues = {
558       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
559 
560       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
561 
562       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
563 
564       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1};
565 
566   llvm::SmallVector<int8_t, 16> yValues = {
567       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
568 
569       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
570 
571       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
572 
573       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6};
574 
575   llvm::SmallVector<bool, 4> cValues = {1, 0, 1, 0};
576 
577   llvm::SmallVector<dim_t, 4> condDims = {2, 2, 1, 1};
578 
579   llvm::SmallVector<dim_t, 4> xDims = {2, 2, 4, 4};
580   llvm::SmallVector<dim_t, 4> yDims = {2, 2, 4, 4};
581 
582   Handle<int8_t> saveH =
583       whereHelper<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy, xValues,
584                           yValues, cValues, xDims, yDims, condDims);
585 
586   llvm::SmallVector<int8_t, 16> refResults = {
587       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
588 
589       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
590 
591       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
592 
593       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6};
594 
595   int counter = 0;
596   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
597     for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
598       for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
599         for (dim_t f = 0; f < saveH.dims()[3]; ++f) {
600           EXPECT_EQ(refResults[counter++], saveH.at({i, j, k, f}));
601         }
602       }
603     }
604   }
605 }
606 
TEST_P(OperatorTest, where_2d_wise_float) {
608   CHECK_IF_ENABLED();
609 
610   llvm::SmallVector<float, 16> xValues = {
611       1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
612       7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f,
613 
614       1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
615       7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f,
616 
617       1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
618       7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f,
619 
620       1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
621       7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f};
622 
623   llvm::SmallVector<float, 16> yValues = {
624       3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
625       4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f,
626 
627       3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
628       4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f,
629 
630       3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
631       4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f,
632 
633       3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
634       4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f};
635 
636   llvm::SmallVector<bool, 4> cValues = {1, 0, 1, 0};
637 
638   llvm::SmallVector<dim_t, 4> condDims = {2, 2, 1, 1};
639 
640   llvm::SmallVector<dim_t, 4> xDims = {2, 2, 4, 4};
641   llvm::SmallVector<dim_t, 4> yDims = {2, 2, 4, 4};
642 
643   Handle<float> saveH =
644       whereHelper<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, xValues,
645                          yValues, cValues, xDims, yDims, condDims);
646 
647   llvm::SmallVector<float, 16> refResults = {
648       1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
649       7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f,
650 
651       3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
652       4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f,
653 
654       1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
655       7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f,
656 
657       3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
658       4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f};
659 
660   int counter = 0;
661   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
662     for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
663       for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
664         for (dim_t f = 0; f < saveH.dims()[3]; ++f) {
665           EXPECT_FLOAT_EQ(refResults[counter++], saveH.at({i, j, k, f}));
666         }
667       }
668     }
669   }
670 }
671 
TEST_P(OperatorTest, where_row_wise_float) {
673   CHECK_IF_ENABLED();
674 
675   llvm::SmallVector<bool, 4> cValues = {1, 1, 1, 0, 0, 1, 0, 0};
676 
677   llvm::SmallVector<dim_t, 4> condDims = {2, 4, 1};
678 
679   llvm::SmallVector<dim_t, 4> xDims = {2, 4, 4};
680   llvm::SmallVector<dim_t, 4> yDims = {2, 4, 4};
681 
682   llvm::SmallVector<float, 16> xValues = {
683       1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
684       7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f,
685 
686       1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
687       7.0f, 8.0f, 9.0f, 2.0f, 3.0f, 5.0f, 7.0f, 1.0f};
688 
689   llvm::SmallVector<float, 16> yValues = {
690       3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
691       4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f,
692 
693       3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f, 0.0f, 6.0f,
694       4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f};
695 
696   Handle<float> saveH =
697       whereHelper<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, xValues,
698                          yValues, cValues, xDims, yDims, condDims);
699 
700   llvm::SmallVector<float, 16> refResults = {
701       1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f, 6.0f, 3.0f,
702       7.0f, 8.0f, 9.0f, 2.0f, 5.0f, 9.0f, 2.0f, 6.0f,
703 
704       3.0f, 4.0f, 5.0f, 7.0f, 4.0f, 5.0f, 6.0f, 3.0f,
705       4.0f, 2.0f, 1.0f, 8.0f, 5.0f, 9.0f, 2.0f, 6.0f,
706   };
707 
708   int counter = 0;
709   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
710     for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
711       for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
712         EXPECT_FLOAT_EQ(refResults[counter++], saveH.at({i, j, k}));
713       }
714     }
715   }
716 }
717 
TEST_P(OperatorTest, where_element_wise_float) {
719   CHECK_IF_ENABLED();
720 
721   llvm::SmallVector<dim_t, 4> condDims = {1, 4, 4};
722 
723   llvm::SmallVector<dim_t, 4> xDims = {1, 4, 4};
724   llvm::SmallVector<dim_t, 4> yDims = {1, 4, 4};
725 
726   llvm::SmallVector<bool, 4> cValues = {1, 1, 1, 0, 0, 1, 0, 0,
727                                         0, 1, 0, 1, 1, 0, 1, 0};
728 
729   llvm::SmallVector<float, 16> xValues = {1.0f, 2.0f, 3.0f, 6.0f, 4.0f, 5.0f,
730                                           6.0f, 3.0f, 7.0f, 8.0f, 9.0f, 2.0f,
731                                           3.0f, 5.0f, 7.0f, 1.0f};
732 
733   llvm::SmallVector<float, 16> yValues = {3.0f, 4.0f, 5.0f, 7.0f, 2.0f, 1.0f,
734                                           0.0f, 6.0f, 4.0f, 2.0f, 1.0f, 8.0f,
735                                           5.0f, 9.0f, 2.0f, 6.0f};
736 
737   Handle<float> saveH =
738       whereHelper<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, xValues,
739                          yValues, cValues, xDims, yDims, condDims);
740 
741   llvm::SmallVector<float, 16> refResults = {1.0f, 2.0f, 3.0f, 7.0f, 2.0f, 5.0f,
742                                              0.0f, 6.0f, 4.0f, 8.0f, 1.0f, 2.0f,
743                                              3.0f, 9.0f, 7.0f, 6.0f};
744 
745   int counter = 0;
746   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
747     for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
748       for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
749         EXPECT_FLOAT_EQ(refResults[counter++], saveH.at({i, j, k}));
750       }
751     }
752   }
753 }
754 
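/// Bundle of NonMaxSuppression attributes used by the NMS tests below: the
/// box encoding flag (centerPoint), the maximum number of boxes kept per
/// class, and the IoU and score thresholds.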
755 struct NMSMetaData {
756   int centerPoint{0};
757   size_t maxOutputPerClass{0};
758   float iouThreshold{0.0};
759   float scoreThreshold{0.0};
760 };
761 
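/// Expected NMS output entry, given as (batch index, class index, box index).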
762 struct SelectedBox {
763   int batchIndex{0};
764   int classIndex{0};
765   int boxIndex{0};
766 };
767 
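/// Expected box coordinates, matching the four values stored per box in the
/// boxes tensor.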
768 struct Box {
769   float x;
770   float y;
771   float h;
772   float w;
773 };
774 
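/// Helper to test NonMaxSuppression (the TF-style V4 node if \p isV4,
/// otherwise the ONNX-style node): builds the NMS node from \p boxesData and
/// the \p classes scores, runs inference, checks the number of selected boxes
/// against \p refNumSelected and the selected indices against \p refResults,
/// and returns a Handle to the indices result.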
775 template <typename DataType, typename outType = int64_t>
static Handle<outType> testNonMaxSuppression(
777     glow::PlaceholderBindings &bindings, glow::Module &mod, glow::Function *F,
778     glow::ExecutionEngine &EE, ElemKind DTy, llvm::ArrayRef<dim_t> boxesDims,
779     llvm::ArrayRef<dim_t> scoresDims, llvm::ArrayRef<DataType> boxesData,
780     llvm::ArrayRef<DataType> classes, llvm::ArrayRef<SelectedBox> refResults,
781     llvm::ArrayRef<int32_t> refNumSelected, const NMSMetaData &metaData,
782     bool isV4) {
783 
784   // NHW
785   auto *boxes = createPlaceholderConditionallyQuantized(mod, DTy, boxesDims,
786                                                         "boxes", false);
787 
788   auto *scores = createPlaceholderConditionallyQuantized(mod, DTy, scoresDims,
789                                                          "scores", false);
790 
791   NonMaxSuppressionNode *nms = nullptr;
792 
793   if (isV4) {
794     nms = F->createNonMaxSuppressionV4(
795         "NMS", boxes, scores, metaData.centerPoint, metaData.maxOutputPerClass,
796         metaData.iouThreshold, metaData.scoreThreshold);
797   } else {
798     nms = F->createNonMaxSuppressionONNX(
799         "NMS", boxes, scores, metaData.centerPoint, metaData.maxOutputPerClass,
800         metaData.iouThreshold, metaData.scoreThreshold);
801   }
802 
803   auto *saveIndices = F->createSave("save", nms->getIndices());
804   auto *saveNumSelected =
805       F->createSave("numSelected", nms->getNumberOfSelectedIndices());
806   auto *result = bindings.allocate(saveIndices->getPlaceholder());
807   auto *result2 = bindings.allocate(saveNumSelected->getPlaceholder());
808 
809   bindings.allocate(boxes)->getHandle<DataType>() = boxesData;
810   bindings.allocate(scores)->getHandle<DataType>() = classes;
811 
812   CompilationContext cctx;
813   cctx.compMode = CompilationMode::Infer;
814   EE.compile(cctx);
815   EE.run(bindings);
816 
817   Handle<outType> result2H = result2->getHandle<outType>();
818   for (dim_t i = 0; i < (dim_t)refNumSelected.size(); ++i) {
819     EXPECT_EQ(result2H.at({i}), refNumSelected[i]);
820   }
821 
822   Handle<outType> resultH = result->getHandle<outType>();
823 
824   if (isV4) {
825     for (dim_t i = 0; i < (dim_t)metaData.maxOutputPerClass; ++i) {
826       EXPECT_EQ(refResults[i].boxIndex, resultH.at({i}));
827     }
828   } else {
829     for (dim_t i = 0; i < (dim_t)metaData.maxOutputPerClass; ++i) {
830       EXPECT_EQ(refResults[i].batchIndex, resultH.at({i, (dim_t)0}));
831       EXPECT_EQ(refResults[i].classIndex, resultH.at({i, (dim_t)1}));
832       EXPECT_EQ(refResults[i].boxIndex, resultH.at({i, (dim_t)2}));
833     }
834   }
835 
836   return resultH;
837 }
838 
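/// Helper similar to testNonMaxSuppression, but additionally gathers the
/// selected boxes using \p boxIndicesData and compares the gathered box
/// coordinates against \p refBoxResults. Returns a Handle to the gathered
/// boxes.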
839 template <typename DataType, typename outType = int64_t>
static Handle<float> testNonMaxSuppressionWithGather(
841     glow::PlaceholderBindings &bindings, glow::Module &mod, glow::Function *F,
842     glow::ExecutionEngine &EE, ElemKind DTy, llvm::ArrayRef<dim_t> boxesDims,
843     llvm::ArrayRef<dim_t> scoresDims, llvm::ArrayRef<dim_t> boxIndicesDim,
844     llvm::ArrayRef<DataType> boxesData, llvm::ArrayRef<DataType> classes,
845     llvm::ArrayRef<int32_t> boxIndicesData, llvm::ArrayRef<Box> refBoxResults,
846     llvm::ArrayRef<int32_t> refNumSelected, const NMSMetaData &metaData,
847     bool isV4) {
848   // NHW
849   auto *boxes = createPlaceholderConditionallyQuantized(mod, DTy, boxesDims,
850                                                         "boxes", false);
851 
852   auto *scores = createPlaceholderConditionallyQuantized(mod, DTy, scoresDims,
853                                                          "scores", false);
854 
855   auto *boxIndices = createPlaceholderConditionallyQuantized(
856       mod, ElemKind::Int32ITy, boxIndicesDim, "boxIndices", false);
857 
858   NonMaxSuppressionNode *nms = nullptr;
859 
860   unsigned axis = 1;
861   if (isV4) {
862     nms = F->createNonMaxSuppressionV4(
863         "NMS", boxes, scores, metaData.centerPoint, metaData.maxOutputPerClass,
864         metaData.iouThreshold, metaData.scoreThreshold);
865     axis = 0;
866   } else {
867 
868     nms = F->createNonMaxSuppressionONNX(
869         "NMS", boxes, scores, metaData.centerPoint, metaData.maxOutputPerClass,
870         metaData.iouThreshold, metaData.scoreThreshold);
871   }
872 
  // Extract all the box indices.
874   auto *gthI =
875       F->createGather("gatherBoxIndices", nms->getIndices(), boxIndices, axis);
876   auto *gthB = F->createGather("gatherClassIndices", boxes, gthI, axis);
877   Node *fltn2 = nullptr;
878 
879   if (isV4) {
880     fltn2 = gthB;
881   } else {
882     fltn2 = F->createFlatten("flatten", gthB, 2);
883   }
884 
885   auto *saveBoxes = F->createSave("saveBoxes", fltn2);
886   auto saveNumSelected =
887       F->createSave("numSelected", nms->getNumberOfSelectedIndices());
888 
889   auto *result = bindings.allocate(saveBoxes->getPlaceholder());
890   auto *result2 = bindings.allocate(saveNumSelected->getPlaceholder());
891 
892   bindings.allocate(boxes)->getHandle<DataType>() = boxesData;
893   bindings.allocate(scores)->getHandle<DataType>() = classes;
894   bindings.allocate(boxIndices)->getHandle<int32_t>() = boxIndicesData;
895 
896   CompilationContext cctx;
897   cctx.compMode = CompilationMode::Infer;
898   EE.compile(cctx);
899   EE.run(bindings);
900 
901   Handle<outType> result2H = result2->getHandle<outType>();
902   for (dim_t i = 0; i < (dim_t)refNumSelected.size(); ++i) {
903     EXPECT_EQ(result2H.at({i}), refNumSelected[i]);
904   }
905 
906   Handle<float> resultH = result->getHandle<float>();
907 
908   for (dim_t i = 0; i < (dim_t)refBoxResults.size(); ++i) {
909     EXPECT_EQ(refBoxResults[i].x, resultH.at({i, (dim_t)0}));
910     EXPECT_EQ(refBoxResults[i].y, resultH.at({i, (dim_t)1}));
911     EXPECT_EQ(refBoxResults[i].h, resultH.at({i, (dim_t)2}));
912     EXPECT_EQ(refBoxResults[i].w, resultH.at({i, (dim_t)3}));
913   }
914 
915   return resultH;
916 }
917 
TEST_P(OperatorTest, nms_center_point_box_with_gather_float) {
919   CHECK_IF_ENABLED();
920   llvm::SmallVector<dim_t, 3> boxesDims = {1, 6, 4};
921   llvm::SmallVector<dim_t, 3> scoresDims = {1, 1, 6};
922   llvm::SmallVector<dim_t, 1> boxIndexesDms = {1};
923 
924   llvm::SmallVector<float, 24> boxes = {
925       0.5, 0.5,  1.0, 1.0, 0.5, 0.6,  1.0, 1.0, 0.5, 0.4,   1.0, 1.0,
926       0.5, 10.5, 1.0, 1.0, 0.5, 10.6, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0};
927 
928   llvm::SmallVector<float, 6> classes = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
929   llvm::SmallVector<int32_t, 1> boxIndices = {2};
930   llvm::SmallVector<Box, 3> refResults = {
931       {0.5, 10.5, 1.0, 1.0}, {0.5, 0.5, 1.0, 1.0}, {0.5, 0.5, 1.0, 1.0}};
932   NMSMetaData metaData = {1, 3, 0.5, 0.4};
933   llvm::SmallVector<int32_t, 1> refNumSelected = {2};
934 
935   testNonMaxSuppressionWithGather<float>(
936       bindings_, mod_, F_, EE_, ElemKind::FloatTy, boxesDims, scoresDims,
937       boxIndexesDms, boxes, classes, boxIndices, refResults, refNumSelected,
938       metaData, false);
939 }
940 
TEST_P(OperatorTest, nms_v4_center_point_box_with_gather_float) {
942   CHECK_IF_ENABLED();
943   llvm::SmallVector<dim_t, 3> boxesDims = {6, 4};
944   llvm::SmallVector<dim_t, 1> scoresDims = {6};
945   llvm::SmallVector<dim_t, 1> boxIndexesDims = {3};
946 
947   llvm::SmallVector<float, 24> boxes = {
948       0.5, 0.5,  1.0, 1.0, 0.5, 0.6,  1.0, 1.0, 0.5, 0.4,   1.0, 1.0,
949       0.5, 10.5, 1.0, 1.0, 0.5, 10.6, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0};
950 
951   llvm::SmallVector<float, 6> classes = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
952   llvm::SmallVector<int32_t, 3> boxIndices = {0, 1, 2};
953   llvm::SmallVector<Box, 3> refResults = {
954       {0.5, 10.5, 1.0, 1.0}, {0.5, 0.5, 1.0, 1.0}, {0.5, 0.5, 1.0, 1.0}};
955   NMSMetaData metaData = {1, 3, 0.5, 0.4};
956   llvm::SmallVector<int32_t, 1> refNumSelected{2};
957 
958   testNonMaxSuppressionWithGather<float>(
959       bindings_, mod_, F_, EE_, ElemKind::FloatTy, boxesDims, scoresDims,
960       boxIndexesDims, boxes, classes, boxIndices, refResults, refNumSelected,
961       metaData, true);
962 }
963 
TEST_P(OperatorTest, nms_center_point_box_float) {
965   CHECK_IF_ENABLED();
966   llvm::SmallVector<dim_t, 3> boxesDims = {1, 6, 4};
967   llvm::SmallVector<dim_t, 3> scoresDims = {1, 1, 6};
968   llvm::SmallVector<float, 24> boxes = {
969       0.5, 0.5,  1.0, 1.0, 0.5, 0.6,  1.0, 1.0, 0.5, 0.4,   1.0, 1.0,
970       0.5, 10.5, 1.0, 1.0, 0.5, 10.6, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0};
971   llvm::SmallVector<float, 6> classes = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
972   llvm::SmallVector<SelectedBox, 3> refResults = {
973       {0, 0, 3}, {0, 0, 0}, {0, 0, 5}};
974   NMSMetaData metaData = {1, 3, 0.5, 0.0};
975   llvm::SmallVector<int32_t, 1> refNumSelected{3};
976 
977   testNonMaxSuppression<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
978                                boxesDims, scoresDims, boxes, classes,
979                                refResults, refNumSelected, metaData, false);
980 }
981 
TEST_P(OperatorTest, nms_v4_center_point_box_float) {
983   CHECK_IF_ENABLED();
984   llvm::SmallVector<dim_t, 3> boxesDims = {6, 4};
985   llvm::SmallVector<dim_t, 1> scoresDims = {6};
986   llvm::SmallVector<float, 24> boxes = {
987       0.5, 0.5,  1.0, 1.0, 0.5, 0.6,  1.0, 1.0, 0.5, 0.4,   1.0, 1.0,
988       0.5, 10.5, 1.0, 1.0, 0.5, 10.6, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0};
989   llvm::SmallVector<float, 6> classes = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
990   llvm::SmallVector<SelectedBox, 3> refResults = {
991       {0, 0, 3}, {0, 0, 0}, {0, 0, 5}};
992   NMSMetaData metaData = {1, 3, 0.5, 0.0};
993   llvm::SmallVector<int32_t, 1> refNumSelected{3};
994 
995   testNonMaxSuppression<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
996                                boxesDims, scoresDims, boxes, classes,
997                                refResults, refNumSelected, metaData, true);
998 }
999 
TEST_P(OperatorTest, nms_flipped_coordinates_float) {
1001   CHECK_IF_ENABLED();
1002   llvm::SmallVector<dim_t, 3> boxesDims = {1, 6, 4};
1003   llvm::SmallVector<dim_t, 3> scoresDims = {1, 1, 6};
1004   llvm::SmallVector<float, 24> boxes = {
1005       1.0, 1.0,  0.0, 0.0,  0.0, 0.1,  1.0, 1.1,  0.0, 0.9,   1.0, -0.1,
1006       0.0, 10.0, 1.0, 11.0, 1.0, 10.1, 0.0, 11.1, 1.0, 101.0, 0.0, 100.0};
1007   llvm::SmallVector<float, 6> classes = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
1008   llvm::SmallVector<SelectedBox, 3> refResults = {
1009       {0, 0, 3}, {0, 0, 0}, {0, 0, 5}};
1010   NMSMetaData metaData = {0, 3, 0.5, 0.0};
1011   llvm::SmallVector<int32_t, 1> refNumSelected{3};
1012 
1013   testNonMaxSuppression<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
1014                                boxesDims, scoresDims, boxes, classes,
1015                                refResults, refNumSelected, metaData, false);
1016 }
1017 
TEST_P(OperatorTest, nms_identical_boxes_float) {
1019   CHECK_IF_ENABLED();
1020   llvm::SmallVector<dim_t, 3> boxesDims = {1, 10, 4};
1021   llvm::SmallVector<dim_t, 3> scoresDims = {1, 1, 10};
1022   llvm::SmallVector<float, 40> boxes = {
1023       0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0,
1024       1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0,
1025       0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0};
1026   llvm::SmallVector<float, 10> classes = {0.9, 0.9, 0.9, 0.9, 0.9,
1027                                           0.9, 0.9, 0.9, 0.9, 0.9};
1028   llvm::SmallVector<SelectedBox, 3> refResults = {{0, 0, 0}};
1029   NMSMetaData metaData = {0, 1, 0.5, 0.0};
1030   llvm::SmallVector<int32_t, 1> refNumSelected{1};
1031 
1032   testNonMaxSuppression<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
1033                                boxesDims, scoresDims, boxes, classes,
1034                                refResults, refNumSelected, metaData, false);
1035 }
1036 
TEST_P(OperatorTest, nms_limit_output_size_float) {
1038   CHECK_IF_ENABLED();
1039   llvm::SmallVector<dim_t, 3> boxesDims = {1, 6, 4};
1040   llvm::SmallVector<dim_t, 3> scoresDims = {1, 1, 6};
1041   llvm::SmallVector<float, 24> boxes = {
1042       0.0, 0.0,  1.0, 1.0,  0.0, 0.1,  1.0, 1.1,  0.0, -0.1,  1.0, 0.9,
1043       0.0, 10.0, 1.0, 11.0, 0.0, 10.1, 1.0, 11.1, 0.0, 100.0, 1.0, 101.0};
1044   llvm::SmallVector<float, 6> classes = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
1045   llvm::SmallVector<SelectedBox, 2> refResults = {{0, 0, 3}, {0, 0, 0}};
1046   NMSMetaData metaData = {0, 2, 0.5, 0.0};
1047   llvm::SmallVector<int32_t, 1> refNumSelected{2};
1048 
1049   testNonMaxSuppression<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
1050                                boxesDims, scoresDims, boxes, classes,
1051                                refResults, refNumSelected, metaData, false);
1052 }
1053 
TEST_P(OperatorTest, nms_single_box_float) {
1055   CHECK_IF_ENABLED();
1056   llvm::SmallVector<dim_t, 3> boxesDims = {1, 1, 4};
1057   llvm::SmallVector<dim_t, 3> scoresDims = {1, 1, 1};
1058   llvm::SmallVector<float, 4> boxes = {0.0, 0.0, 1.0, 1.0};
1059   llvm::SmallVector<float, 1> classes = {0.9};
1060   llvm::SmallVector<SelectedBox, 1> refResults = {
1061       {0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
1062   NMSMetaData metaData = {0, 3, 0.5, 0.0};
1063   llvm::SmallVector<int32_t, 1> refNumSelected{1};
1064 
1065   testNonMaxSuppression<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
1066                                boxesDims, scoresDims, boxes, classes,
1067                                refResults, refNumSelected, metaData, false);
1068 }
1069 
TEST_P(OperatorTest, nms_by_iou_float) {
1071   CHECK_IF_ENABLED();
1072   llvm::SmallVector<dim_t, 3> boxesDims = {1, 6, 4};
1073   llvm::SmallVector<dim_t, 3> scoresDims = {1, 1, 6};
1074   llvm::SmallVector<float, 24> boxes = {
1075       0.0, 0.0,  1.0, 1.0,  0.0, 0.1,  1.0, 1.1,  0.0, -0.1,  1.0, 0.9,
1076       0.0, 10.0, 1.0, 11.0, 0.0, 10.1, 1.0, 11.1, 0.0, 100.0, 1.0, 101.0};
1077   llvm::SmallVector<float, 6> classes = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
1078   llvm::SmallVector<SelectedBox, 2> refResults = {
1079       {0, 0, 3}, {0, 0, 0}, {0, 0, 5}};
1080   NMSMetaData metaData = {0, 3, 0.5, 0.0};
1081   llvm::SmallVector<int32_t, 1> refNumSelected{3};
1082 
1083   testNonMaxSuppression<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
1084                                boxesDims, scoresDims, boxes, classes,
1085                                refResults, refNumSelected, metaData, false);
1086 }
1087 
TEST_P(OperatorTest, nms_by_iou_and_scores_float) {
1089   CHECK_IF_ENABLED();
1090   llvm::SmallVector<dim_t, 3> boxesDims = {1, 6, 4};
1091   llvm::SmallVector<dim_t, 3> scoresDims = {1, 1, 6};
1092   llvm::SmallVector<float, 24> boxes = {
1093       0.0, 0.0,  1.0, 1.0,  0.0, 0.1,  1.0, 1.1,  0.0, -0.1,  1.0, 0.9,
1094       0.0, 10.0, 1.0, 11.0, 0.0, 10.1, 1.0, 11.1, 0.0, 100.0, 1.0, 101.0};
1095   llvm::SmallVector<float, 6> classes = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
1096   llvm::SmallVector<SelectedBox, 2> refResults = {{0, 0, 3}, {0, 0, 0}};
1097   NMSMetaData metaData = {0, 2, 0.5, 0.4};
1098   llvm::SmallVector<int32_t, 1> refNumSelected{2};
1099 
1100   testNonMaxSuppression<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
1101                                boxesDims, scoresDims, boxes, classes,
1102                                refResults, refNumSelected, metaData, false);
1103 }
1104 
TEST_P(OperatorTest, nms_two_batches_float) {
1106   CHECK_IF_ENABLED();
1107   llvm::SmallVector<dim_t, 3> boxesDims = {2, 6, 4};
1108   llvm::SmallVector<dim_t, 3> scoresDims = {2, 1, 6};
1109   llvm::SmallVector<float, 48> boxes = {
1110       0.0, 0.0,  1.0, 1.0,  0.0, 0.1,  1.0, 1.1,  0.0, -0.1,  1.0, 0.9,
1111       0.0, 10.0, 1.0, 11.0, 0.0, 10.1, 1.0, 11.1, 0.0, 100.0, 1.0, 101.0,
1112       0.0, 0.0,  1.0, 1.0,  0.0, 0.1,  1.0, 1.1,  0.0, -0.1,  1.0, 0.9,
1113       0.0, 10.0, 1.0, 11.0, 0.0, 10.1, 1.0, 11.1, 0.0, 100.0, 1.0, 101.0};
1114   llvm::SmallVector<float, 12> classes = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3,
1115                                           0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
1116   llvm::SmallVector<SelectedBox, 4> refResults = {
1117       {0, 0, 3}, {0, 0, 0}, {1, 0, 3}, {1, 0, 0}};
1118   NMSMetaData metaData = {0, 2, 0.5, 0.0};
1119   llvm::SmallVector<int32_t, 2> refNumSelected{2, 2};
1120 
1121   testNonMaxSuppression<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
1122                                boxesDims, scoresDims, boxes, classes,
1123                                refResults, refNumSelected, metaData, false);
1124 }
1125 
TEST_P(OperatorTest, nms_two_classes_float) {
1127   CHECK_IF_ENABLED();
1128   llvm::SmallVector<dim_t, 3> boxesDims = {1, 6, 4};
1129   llvm::SmallVector<dim_t, 3> scoresDims = {1, 2, 6};
1130   llvm::SmallVector<float, 24> boxes = {
1131       0.0, 0.0,  1.0, 1.0,  0.0, 0.1,  1.0, 1.1,  0.0, -0.1,  1.0, 0.9,
1132       0.0, 10.0, 1.0, 11.0, 0.0, 10.1, 1.0, 11.1, 0.0, 100.0, 1.0, 101.0};
1133   llvm::SmallVector<float, 12> classes = {0.9, 0.75, 0.6, 0.95, 0.5, 0.3,
1134                                           0.9, 0.75, 0.6, 0.95, 0.5, 0.3};
1135   llvm::SmallVector<SelectedBox, 4> refResults = {
1136       {0, 0, 3}, {0, 0, 0}, {0, 1, 3}, {0, 1, 0}};
1137   NMSMetaData metaData = {0, 2, 0.5, 0.4};
1138   llvm::SmallVector<int32_t, 1> refNumSelected{4};
1139 
1140   testNonMaxSuppression<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
1141                                boxesDims, scoresDims, boxes, classes,
1142                                refResults, refNumSelected, metaData, false);
1143 }
1144 
TEST_P(OperatorTest, nms_two_boxes_float) {
1146   CHECK_IF_ENABLED();
1147   llvm::SmallVector<dim_t, 3> boxesDims = {1, 2, 4};
1148   llvm::SmallVector<dim_t, 3> scoresDims = {1, 1, 2};
1149   llvm::SmallVector<float, 4> boxes = {0.0, 0.0, 1.0, 1.0, 0.1, 0.1, 0.9, 0.9};
1150   llvm::SmallVector<float, 2> classes = {0.8, 0.9};
1151   llvm::SmallVector<SelectedBox, 1> refResults = {{0, 0, 1}};
1152   NMSMetaData metaData = {0, 1, 0.5, 0.0};
1153   llvm::SmallVector<int32_t, 1> refNumSelected{1};
1154 
1155   testNonMaxSuppression<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
1156                                boxesDims, scoresDims, boxes, classes,
1157                                refResults, refNumSelected, metaData, false);
1158 }
1159 
1160 /// Helper function to test AudioSpectrogram node.
1161 template <size_t windowCount, size_t windowSize, bool magnitudeSquared>
1162 static FunctionTensorPair
createAndInitBasicAudioSpectrogramTest(glow::PlaceholderBindings &bindings,
1164                                        glow::ExecutionEngine &EE) {
1165   auto &mod = EE.getModule();
1166   Function *F = mod.createFunction("main");
1167 
1168   // Create random input audio signal.
1169   dim_t windowStride = 320;
1170   dim_t inputLength = windowSize + (windowCount - 1) * windowStride;
1171   auto *input = mod.createPlaceholder(ElemKind::FloatTy, {inputLength}, "input",
1172                                       false /* isTrainable */);
1173   bindings.allocate(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
1174 
1175   // Create AudioSpectrogram node.
1176   auto *audioSpec = F->createAudioSpectrogram(
1177       "audio_spectrogram", input, windowSize, windowStride, magnitudeSquared);
1178   auto *res = F->createSave("save", audioSpec);
1179   auto *resultTensor = bindings.allocate(res->getPlaceholder());
1180   return std::make_pair(F, resultTensor);
1181 }
1182 
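/// Instantiate an AudioSpectrogram test for the given window count, window
/// size, and magnitudeSquared flag, comparing the backend under test against
/// the Interpreter reference within the given tolerance.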
1183 #define TEST_AUDIO_SPECTROGRAM(WCOUNT, WSIZE, MSQUARED, TOL)                   \
1184   TEST_P(OperatorStatelessTest,                                                \
1185          AudioSpectrogram_##WCOUNT##x##WSIZE##_##MSQUARED##_Float) {           \
1186     ENABLED_BACKENDS("Interpreter", "CPU");                                    \
1187     compareAgainstInterpreter(                                                 \
1188         getBackendName(),                                                      \
1189         createAndInitBasicAudioSpectrogramTest<WCOUNT, WSIZE, MSQUARED>,       \
1190         ElemKind::FloatTy, ElemKind::FloatTy, TOL);                            \
1191   }
1192 
1193 /// Test one window magnitude spectrograms.
1194 TEST_AUDIO_SPECTROGRAM(1, 2, false, 1e-6)
1195 TEST_AUDIO_SPECTROGRAM(1, 4, false, 1e-6)
1196 TEST_AUDIO_SPECTROGRAM(1, 8, false, 1e-6)
1197 TEST_AUDIO_SPECTROGRAM(1, 16, false, 1e-6)
1198 TEST_AUDIO_SPECTROGRAM(1, 32, false, 1e-6)
1199 TEST_AUDIO_SPECTROGRAM(1, 64, false, 5e-6)
1200 TEST_AUDIO_SPECTROGRAM(1, 128, false, 5e-6)
1201 TEST_AUDIO_SPECTROGRAM(1, 256, false, 1e-5)
1202 TEST_AUDIO_SPECTROGRAM(1, 512, false, 5e-5)
1203 TEST_AUDIO_SPECTROGRAM(1, 1024, false, 5e-5)
1204 
1205 /// Test multiple window magnitude spectrograms.
1206 TEST_AUDIO_SPECTROGRAM(2, 256, false, 1e-5)
1207 TEST_AUDIO_SPECTROGRAM(3, 320, false, 1e-5)
1208 TEST_AUDIO_SPECTROGRAM(4, 640, false, 5e-5)
1209 
1210 /// Test multiple window power spectrograms.
1211 TEST_AUDIO_SPECTROGRAM(2, 256, true, 5e-4)
1212 TEST_AUDIO_SPECTROGRAM(3, 320, true, 5e-4)
1213 TEST_AUDIO_SPECTROGRAM(4, 640, true, 1e-3)
1214 
1215 /// Helper function to test MFCC node.
1216 template <size_t winNum, size_t specLen>
1217 static FunctionTensorPair
createAndInitBasicMFCCTest(glow::PlaceholderBindings &bindings,
1219                            glow::ExecutionEngine &EE) {
1220   auto &mod = EE.getModule();
1221   Function *F = mod.createFunction("main");
1222 
1223   // Create random input spectrogram.
1224   auto *spectrogram =
1225       mod.createPlaceholder(ElemKind::FloatTy, {winNum, specLen}, "spectrogram",
1226                             false /* isTrainable */);
1227   bindings.allocate(spectrogram)
1228       ->getHandle()
1229       .randomize(10.0, 100.0, mod.getPRNG());
1230 
1231   // Create MFCC node.
1232   float sampleRate = 16000.0;
1233   float lowerFrequency = 20.0;
1234   float upperFrequency = 4000.0;
1235   size_t filterBankCount = 40;
1236   size_t numCoefficients = 13;
1237   auto *mfcc = F->createMFCC("mfcc", spectrogram, sampleRate, lowerFrequency,
1238                              upperFrequency, filterBankCount, numCoefficients);
1239   auto *res = F->createSave("save", mfcc);
1240   auto *resultTensor = bindings.allocate(res->getPlaceholder());
1241   return std::make_pair(F, resultTensor);
1242 }
1243 
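/// Instantiate an MFCC test for the given window count and spectrogram
/// length, comparing the backend under test against the Interpreter reference
/// within the given tolerance.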
1244 #define TEST_MFCC(WNUM, SLEN, TOL)                                             \
1245   TEST_P(OperatorStatelessTest, MFCC_##WNUM##x##SLEN##_Float) {                \
1246     ENABLED_BACKENDS("Interpreter", "CPU");                                    \
1247     compareAgainstInterpreter(getBackendName(),                                \
1248                               createAndInitBasicMFCCTest<WNUM, SLEN>,          \
1249                               ElemKind::FloatTy, ElemKind::FloatTy, TOL);      \
1250   }
1251 
1252 TEST_MFCC(1, 17, 5e-5)
1253 TEST_MFCC(1, 33, 5e-5)
1254 TEST_MFCC(1, 65, 2e-5)
1255 TEST_MFCC(1, 129, 1e-5)
1256 TEST_MFCC(2, 257, 1e-5)
1257 TEST_MFCC(3, 513, 1e-5)
1258 TEST_MFCC(3, 1025, 1e-5)
1259 
1260 // Helper to test SpaceToDepth using \p DTy.
1261 template <typename DataType>
static void testSpaceToDepthBlock3(glow::PlaceholderBindings &bindings,
1263                                    glow::Module &mod, glow::Function *F,
1264                                    glow::ExecutionEngine &EE, ElemKind DTy) {
1265   unsigned blockSize = 3;
1266   auto *in = createPlaceholderConditionallyQuantized(mod, DTy, {1, 2, 6, 6},
1267                                                      "in", false, "NHWC");
1268   auto *tri = F->createTranspose("sptdTransposeIn", in, {0, 2, 3, 1}, "NHWC");
1269   auto *stdn = F->createSpaceToDepth("spacetodepth", tri, blockSize);
1270   auto *tro =
1271       F->createTranspose("sptdTransposeOut", stdn, {0, 3, 1, 2}, "NCHW");
1272   auto *save = F->createSave("save", tro);
1273   auto *result = bindings.allocate(save->getPlaceholder());
1274 
1275   /*
1276     Example for first batch.
1277   FROM:
1278   C0:             C1:
1279   [0  1  2  3  16 17]    [ 0  -1  -2  -3  -16 -17]
1280   [4  5  6  7  18 19]    [-4  -5  -6  -7  -18 -19]
1281   [8  9  10 11 20 21]    [-8  -9  -10 -11 -20 -21]
1282   [12 13 14 15 22 23]    [-12 -13 -14 -15 -22 -23]
1283   [24 25 26 27 28 29]    [-24 -25 -26 -27 -28 -29]
1284   [30 31 32 33 34 35]    [-30 -31 -32 -33 -34 -35]
1285 
1286   TO:
1287   C = 0
1288   [0,3]
1289   [12,15]
1290 
1291   C = 1
1292   [0,-3]
1293   [-12,-15]
1294 
1295   C = 2
1296   [1,16]
1297   [13,22]
1298 
1299   C = 3
1300   [-1,-16]
1301   [-13,-22]
1302 
1303   C = 4
1304   [2,17]
1305   [14,23]
1306 
1307   C = 5
1308   [-2,-17]
1309   [-14,-23]
1310 
1311   C = 6
1312   [4,7]
1313   [24,27]
1314 
1315   C = 7
1316   [-4,-7]
1317   [-24,-27]
1318 
1319   C = 8
1320   [5,18]
1321   [25,28]
1322 
1323   C = 9
1324   [-5,-18]
1325   [-25,-28]
1326 
1327   C = 10
1328   [6,19]
1329   [26,29]
1330 
1331   C = 11
1332   [-6,-19]
1333   [-26,-29]
1334 
1335   C = 12
1336   [8,11]
1337   [30,33]
1338 
1339   C = 13
1340   [-8,-11]
1341   [-30,-33]
1342 
1343   C = 14
1344   [9,20]
1345   [31,34]
1346 
1347   C = 15
1348   [-9,-20]
1349   [-31,-34]
1350 
1351   C = 16
1352   [10,21]
1353   [32,35]
1354 
1355   C = 17
1356   [-10,-21]
1357   [-32,-35]
1358   */
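  // Note (added): in NHWC the mapping sketched above corresponds to
  //   out[n][oh][ow][(bh * blockSize + bw) * C + c] =
  //       in[n][oh * blockSize + bh][ow * blockSize + bw][c]
  // where (bh, bw) indexes the position inside each blockSize x blockSize
  // block and C is the number of input channels; the reference values below
  // are consistent with this layout.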
1359 
1360   bindings.allocate(in)->getHandle<DataType>() = {
1361       0,   1,   2,   3,   16,  17,  4,   5,   6,   7,   18,  19,  8,   9,   10,
1362       11,  20,  21,  12,  13,  14,  15,  22,  23,  24,  25,  26,  27,  28,  29,
1363       30,  31,  32,  33,  34,  35,  0,   -1,  -2,  -3,  -16, -17, -4,  -5,  -6,
1364       -7,  -18, -19, -8,  -9,  -10, -11, -20, -21, -12, -13, -14, -15, -22, -23,
1365       -24, -25, -26, -27, -28, -29, -30, -31, -32, -33, -34, -35};
1366 
1367   std::vector<DataType> refResult = {
1368       0,   3,   12,  15,  0,  -3, -12, -15, 1,   16,  13,  22, -1, -16, -13,
1369       -22, 2,   17,  14,  23, -2, -17, -14, -23, 4,   7,   24, 27, -4,  -7,
1370       -24, -27, 5,   18,  25, 28, -5,  -18, -25, -28, 6,   19, 26, 29,  -6,
1371       -19, -26, -29, 8,   11, 30, 33,  -8,  -11, -30, -33, 9,  20, 31,  34,
1372       -9,  -20, -31, -34, 10, 21, 32,  35,  -10, -21, -32, -35};
1373 
1374   EE.compile(CompilationMode::Infer);
1375   EE.run(bindings);
1376 
1377   Handle<DataType> resultH = result->getHandle<DataType>();
1378 
1379   auto iDims = in->dims();
1380   auto oDims = resultH.dims();
1381   EXPECT_EQ(iDims[0], oDims[0]);
1382   EXPECT_EQ(iDims[1] * blockSize * blockSize, oDims[1]);
1383   EXPECT_EQ(iDims[2], oDims[2] * blockSize);
1384   EXPECT_EQ(iDims[3], oDims[3] * blockSize);
1385 
1386   // NCHW format
1387   dim_t resIndex = 0;
1388   for (dim_t on = 0; on < oDims[0]; ++on) {
1389     for (dim_t oc = 0; oc < oDims[1]; ++oc) {
1390       for (dim_t oh = 0; oh < oDims[2]; ++oh) {
1391         for (dim_t ow = 0; ow < oDims[3]; ++ow) {
1392           DataType resultVal = resultH.at({on, oc, oh, ow});
1393           DataType refVal = refResult[resIndex++];
1394           EXPECT_EQ(resultVal, refVal);
1395         }
1396       }
1397     }
1398   }
1399 }
1400 
1401 /// Verify that the SpaceToDepth operator works correctly for int8. Block
1402 /// Size 3.
1403 TEST_P(OperatorTest, spaceToDepth_block3_int8) {
1404   CHECK_IF_ENABLED();
1405   testSpaceToDepthBlock3<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
1406 }
1407 
1408 /// Verify that the SpaceToDepth operator works correctly for Float. Block
1409 /// Size 3.
1410 TEST_P(OperatorTest, spaceToDepth_block3_Float) {
1411   CHECK_IF_ENABLED();
1412   testSpaceToDepthBlock3<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
1413 }
1414 
1415 // Helper to test SpaceToDepth with block size 2 using \p DTy.
1416 template <typename DataType>
1417 static void testSpaceToDepth(glow::PlaceholderBindings &bindings,
1418                              glow::Module &mod, glow::Function *F,
1419                              glow::ExecutionEngine &EE, ElemKind DTy) {
1420   unsigned blockSize = 2;
1421   auto *in = createPlaceholderConditionallyQuantized(mod, DTy, {2, 2, 4, 4},
1422                                                      "in", false, "NHWC");
1423   auto *tri = F->createTranspose("sptdTransposeIn", in, {0, 2, 3, 1}, "NHWC");
1424   auto *stdn = F->createSpaceToDepth("spacetodepth", tri, blockSize);
1425   auto *tro =
1426       F->createTranspose("sptdTransposeOut", stdn, {0, 3, 1, 2}, "NCHW");
1427   auto *save = F->createSave("save", tro);
1428   auto *result = bindings.allocate(save->getPlaceholder());
1429 
1430   /*
1431     Example for first batch.
1432   FROM:
1433   C0:             C1:
1434   [0  1  2  3]    [ 0  -1  -2  -3]
1435   [4  5  6  7]    [-4  -5  -6  -7]
1436   [8  9  10 11]   [-8  -9  -10 -11]
1437   [12 13 14 15]   [-12 -13 -14 -15]
1438 
1439   TO:
1440   C0:
1441   [0,  2]
1442   [8,  10]
1443 
1444   C1:
1445   [ 0,  -2]
1446   [-8, -10]
1447 
1448   C2:
1449   [1, 3]
1450   [9, 11]
1451 
1452   C3:
1453   [-1, -3]
1454   [-9, -11]
1455 
1456   C4:
1457   [4,  6]
1458   [12, 14]
1459 
1460   C5:
1461   [-4,  -6]
1462   [-12, -14]
1463 
1464   C6:
1465   [5, 7]
1466   [13, 15]
1467 
1468   C7:
1469   [-5,  -7]
1470   [-13, -15]
1471   */
1472 
1473   bindings.allocate(in)->getHandle<DataType>() = {
1474       0, 1,   2,   3,   4,  5,  6,   7,   8,  9,  10,  11,  12,  13,  14,  15,
1475       0, -1,  -2,  -3,  -4, -5, -6,  -7,  -8, -9, -10, -11, -12, -13, -14, -15,
1476       0, 7,   9,   23,  24, 25, 26,  27,  8,  9,  10,  33,  12,  13,  14,  15,
1477       0, -21, -22, -23, -4, -5, -26, -27, -8, -9, -10, -11, -12, -13, -14, -15};
1478 
1479   std::vector<DataType> refResult = {
1480       0,  2,  8,  10, 0,  -2,  -8,  -10, 1,  3,  9,  11, -1,  -3,  -9,  -11,
1481       4,  6,  12, 14, -4, -6,  -12, -14, 5,  7,  13, 15, -5,  -7,  -13, -15,
1482       0,  9,  8,  10, 0,  -22, -8,  -10, 7,  23, 9,  33, -21, -23, -9,  -11,
1483       24, 26, 12, 14, -4, -26, -12, -14, 25, 27, 13, 15, -5,  -27, -13, -15};
1484 
1485   EE.compile(CompilationMode::Infer);
1486   EE.run(bindings);
1487 
1488   Handle<DataType> resultH = result->getHandle<DataType>();
1489 
1490   auto iDims = in->dims();
1491   auto oDims = resultH.dims();
1492   EXPECT_EQ(iDims[0], oDims[0]);
1493   EXPECT_EQ(iDims[1] * blockSize * blockSize, oDims[1]);
1494   EXPECT_EQ(iDims[2], oDims[2] * blockSize);
1495   EXPECT_EQ(iDims[3], oDims[3] * blockSize);
1496 
1497   // NCHW format
1498   dim_t resIndex = 0;
1499   for (dim_t on = 0; on < oDims[0]; ++on) {
1500     for (dim_t oc = 0; oc < oDims[1]; ++oc) {
1501       for (dim_t oh = 0; oh < oDims[2]; ++oh) {
1502         for (dim_t ow = 0; ow < oDims[3]; ++ow) {
1503           DataType resultVal = resultH.at({on, oc, oh, ow});
1504           DataType refVal = refResult[resIndex++];
1505           EXPECT_EQ(resultVal, refVal);
1506         }
1507       }
1508     }
1509   }
1510 }
1511 
1512 /// Verify that the SpaceToDepth operator works correctly for int8. Block
1513 /// Size 2.
1514 TEST_P(OperatorTest, spaceToDepth_block2_int8) {
1515   CHECK_IF_ENABLED();
1516   testSpaceToDepth<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
1517 }
1518 
1519 /// Verify that the SpaceToDepth operator works correctly for Float. Block
1520 /// Size 2.
1521 TEST_P(OperatorTest, spaceToDepth_block2_Float) {
1522   CHECK_IF_ENABLED();
1523   testSpaceToDepth<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
1524 }
1525 
1526 /// Helper to test ResizeNearest using \p DTy.
1527 template <typename DataType>
1528 static void testResizeNearest(glow::PlaceholderBindings &bindings,
1529                               glow::Module &mod, glow::Function *F,
1530                               glow::ExecutionEngine &EE, ElemKind DTy,
1531                               bool v11 = false) {
1532   auto *input = createPlaceholderConditionallyQuantized(mod, DTy, {1, 2, 2, 1},
1533                                                         "input", false, "NHWC");
1534   bindings.allocate(input)->getHandle<DataType>() = {2, 4, 8, 16};
1535 
1536   ResizeNearestNode *resizeUp = nullptr;
1537   ResizeNearestNode *resizeDown = nullptr;
1538 
1539   std::vector<float> scaleUp = {1, 2.0f, 1.5f, 1};
1540 
1541   if (v11) {
1542     dim_t newH = std::floor(2 * 2.0f);
1543     dim_t newW = std::floor(2 * 1.5f);
1544     auto outTy =
1545         mod.uniqueTypeWithNewShape(input->getType(), {1, newH, newW, 1});
1546     resizeUp = F->createResizeNearest("resizeUp", input, outTy);
1547   } else {
1548     resizeUp = F->createResizeNearest("resizeUp", input, scaleUp);
1549   }
1550   auto *saveUp = F->createSave("saveUp", resizeUp);
1551   auto *resultUp = bindings.allocate(saveUp->getPlaceholder());
1552 
1553   std::vector<float> scaleDown = {1, 0.9f, 0.6f, 1};
1554 
1555   if (v11) {
1556     dim_t newH = std::floor(2 * 0.9f);
1557     dim_t newW = std::floor(2 * 0.6f);
1558     auto outTy =
1559         mod.uniqueTypeWithNewShape(input->getType(), {1, newH, newW, 1});
1560     resizeDown = F->createResizeNearest("resizeDown", input, outTy);
1561   } else {
1562     resizeDown = F->createResizeNearest("resizeDown", input, scaleDown);
1563   }
1564 
1565   auto *saveDown = F->createSave("saveDown", resizeDown);
1566   auto *resultDown = bindings.allocate(saveDown->getPlaceholder());
1567 
1568   ::glow::convertPlaceholdersToConstants(
1569       F, bindings,
1570       {input, saveUp->getPlaceholder(), saveDown->getPlaceholder()});
1571 
1572   EE.compile(CompilationMode::Infer);
1573   EE.run(bindings);
1574 
1575   auto resultUpH = resultUp->getHandle<DataType>();
1576   std::vector<dim_t> expectedDimsUp = {1, 4, 3, 1};
1577   ASSERT_TRUE(resultUpH.dims().vec() == expectedDimsUp);
1578 
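  // Note (added): the expected values below are consistent with each output
  // pixel reading the input at floor(outIndex / scale); for
  // scaleUp = {1, 2.0, 1.5, 1} the 4x3 output therefore samples input rows
  // {0, 0, 1, 1} and columns {0, 0, 1}.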
1579   EXPECT_EQ(resultUpH.at({0, 0, 0, 0}), static_cast<DataType>(2));
1580   EXPECT_EQ(resultUpH.at({0, 0, 1, 0}), static_cast<DataType>(2));
1581   EXPECT_EQ(resultUpH.at({0, 0, 2, 0}), static_cast<DataType>(4));
1582 
1583   EXPECT_EQ(resultUpH.at({0, 1, 0, 0}), static_cast<DataType>(2));
1584   EXPECT_EQ(resultUpH.at({0, 1, 1, 0}), static_cast<DataType>(2));
1585   EXPECT_EQ(resultUpH.at({0, 1, 2, 0}), static_cast<DataType>(4));
1586 
1587   EXPECT_EQ(resultUpH.at({0, 2, 0, 0}), static_cast<DataType>(8));
1588   EXPECT_EQ(resultUpH.at({0, 2, 1, 0}), static_cast<DataType>(8));
1589   EXPECT_EQ(resultUpH.at({0, 2, 2, 0}), static_cast<DataType>(16));
1590 
1591   EXPECT_EQ(resultUpH.at({0, 3, 0, 0}), static_cast<DataType>(8));
1592   EXPECT_EQ(resultUpH.at({0, 3, 1, 0}), static_cast<DataType>(8));
1593   EXPECT_EQ(resultUpH.at({0, 3, 2, 0}), static_cast<DataType>(16));
1594 
1595   auto resultDownH = resultDown->getHandle<DataType>();
1596   std::vector<dim_t> expectedDimsDown = {1, 1, 1, 1};
1597   ASSERT_TRUE(resultDownH.dims().vec() == expectedDimsDown);
1598   EXPECT_EQ(resultDownH.at({0, 0, 0, 0}), static_cast<DataType>(2));
1599 }
1600 
1601 /// Verify that the ResizeNearest operator works correctly for Float.
1602 TEST_P(OperatorTest, ResizeNearest_Float) {
1603   CHECK_IF_ENABLED();
1604   testResizeNearest<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
1605 }
1606 
1607 /// Verify that the ResizeNearest operator works correctly for Float16.
1608 TEST_P(OperatorTest, ResizeNearest_Float16) {
1609   CHECK_IF_ENABLED();
1610   testResizeNearest<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
1611 }
1612 
1613 /// Verify that the ResizeNearest operator works correctly for BFloat16.
1614 TEST_P(OperatorTest, ResizeNearest_BFloat16) {
1615   CHECK_IF_ENABLED();
1616   testResizeNearest<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty);
1617 }
1618 
1619 /// Verify that the ResizeNearest operator works correctly for Int8Q.
1620 TEST_P(OperatorTest, ResizeNearest_Int8) {
1621   CHECK_IF_ENABLED();
1622   testResizeNearest<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
1623 }
1624 
1625 /// Verify that the ResizeNearest operator works correctly for Int16Q.
1626 TEST_P(OperatorTest, ResizeNearest_Int16) {
1627   CHECK_IF_ENABLED();
1628   testResizeNearest<int16_t>(bindings_, mod_, F_, EE_, ElemKind::Int16QTy);
1629 }
1630 
1631 /// Verify that the ResizeNearest operator works correctly for Int32Q.
1632 TEST_P(OperatorTest, ResizeNearest_Int32) {
1633   CHECK_IF_ENABLED();
1634   testResizeNearest<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32QTy);
1635 }
1636 
1637 TEST_P(OperatorTest, ResizeNearest_Float_outTy) {
1638   CHECK_IF_ENABLED();
1639   testResizeNearest<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, true);
1640 }
1641 
1642 TEST_P(OperatorTest, ResizeNearest_Float16_outTy) {
1643   CHECK_IF_ENABLED();
1644   testResizeNearest<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
1645                                true);
1646 }
1647 
1648 TEST_P(OperatorTest, ResizeNearest_BFloat16_outTy) {
1649   CHECK_IF_ENABLED();
1650   testResizeNearest<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
1651                                 true);
1652 }
1653 
1654 TEST_P(OperatorTest, ResizeNearest_Int8_outTy) {
1655   CHECK_IF_ENABLED();
1656   testResizeNearest<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy, true);
1657 }
1658 TEST_P(OperatorTest, ResizeNearest_Int16_outTy) {
1659   CHECK_IF_ENABLED();
1660   testResizeNearest<int16_t>(bindings_, mod_, F_, EE_, ElemKind::Int16QTy,
1661                              true);
1662 }
1663 TEST_P(OperatorTest, ResizeNearest_Int32_outTy) {
1664   CHECK_IF_ENABLED();
1665   testResizeNearest<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32QTy,
1666                              true);
1667 }
1668 
1669 /// Helper to test ResizeBilinear using \p DTy.
1670 template <typename DataType>
1671 static void testResizeBilinear(glow::PlaceholderBindings &bindings,
1672                                glow::Module &mod, glow::Function *F,
1673                                glow::ExecutionEngine &EE, ElemKind DTy,
1674                                bool v11 = false) {
1675   auto *input = createPlaceholderConditionallyQuantized(mod, DTy, {1, 2, 2, 1},
1676                                                         "input", false, "NHWC");
1677   bindings.allocate(input)->getHandle<DataType>() = {2, 4, 8, 16};
1678 
1679   std::vector<float> scaleUp = {1, 2.0f, 1.5f, 1};
1680 
1681   ResizeBilinearNode *resizeUp = nullptr;
1682   ResizeBilinearNode *resizeDown = nullptr;
1683 
1684   if (v11) {
1685     dim_t newH = std::floor(2 * 2.0f);
1686     dim_t newW = std::floor(2 * 1.5f);
1687     auto outTy =
1688         mod.uniqueTypeWithNewShape(input->getType(), {1, newH, newW, 1});
1689     resizeUp = F->createResizeBilinear("resizeUp", input, outTy);
1690   } else {
1691     resizeUp = F->createResizeBilinear("resizeUp", input, scaleUp);
1692   }
1693 
1694   auto *saveUp = F->createSave("saveUp", resizeUp);
1695   auto *resultUp = bindings.allocate(saveUp->getPlaceholder());
1696 
1697   std::vector<float> scaleDown = {1, 0.9f, 0.6f, 1};
1698 
1699   if (v11) {
1700     dim_t newH = std::floor(2 * 0.9f);
1701     dim_t newW = std::floor(2 * 0.6f);
1702     auto outTy =
1703         mod.uniqueTypeWithNewShape(input->getType(), {1, newH, newW, 1});
1704     resizeDown = F->createResizeBilinear("resizeDown", input, outTy);
1705   } else {
1706     resizeDown = F->createResizeBilinear("resizeDown", input, scaleDown);
1707   }
1708 
1709   auto *saveDown = F->createSave("saveDown", resizeDown);
1710   auto *resultDown = bindings.allocate(saveDown->getPlaceholder());
1711 
1712   ::glow::convertPlaceholdersToConstants(
1713       F, bindings,
1714       {input, saveUp->getPlaceholder(), saveDown->getPlaceholder()});
1715 
1716   EE.compile(CompilationMode::Infer);
1717   EE.run(bindings);
1718 
1719   auto resultUpH = resultUp->getHandle<DataType>();
1720   std::vector<dim_t> expectedDimsUp = {1, 4, 3, 1};
1721   ASSERT_TRUE(resultUpH.dims().vec() == expectedDimsUp);
1722 
1723 // Use EXPECT_FLOAT_EQ for floating-point types, otherwise EXPECT_EQ. The
1724 // optional extra arguments are accepted for convenience but currently unused.
1725 #define EXPECT_EQF(a, b, ...)                                                  \
1726   if ((std::is_same<DataType, float>::value) ||                                \
1727       (std::is_same<DataType, float16_t>::value) ||                            \
1728       (std::is_same<DataType, bfloat16_t>::value)) {                           \
1729     EXPECT_FLOAT_EQ(a, b);                                                     \
1730   } else {                                                                     \
1731     EXPECT_EQ(a, b);                                                           \
1732   }
1733 
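  // Note (added): the expected values below can be reproduced by hand with
  // bilinear interpolation on the 2x2 input {2, 4, 8, 16}: e.g. at output
  // column 1 the source x-coordinate is 1 / 1.5 = 0.667, giving
  // 2 + 0.667 * (4 - 2) = 3.333 in the first row, and at output row 1 the
  // source y-coordinate is 1 / 2 = 0.5, giving 2 + 0.5 * (8 - 2) = 5 in the
  // first column.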
1734   EXPECT_EQF(resultUpH.at({0, 0, 0, 0}), static_cast<DataType>(2));
1735   EXPECT_EQF(resultUpH.at({0, 0, 1, 0}), static_cast<DataType>(3.333333));
1736   EXPECT_EQF(resultUpH.at({0, 0, 2, 0}), static_cast<DataType>(4));
1737 
1738   EXPECT_EQF(resultUpH.at({0, 1, 0, 0}), static_cast<DataType>(5));
1739   EXPECT_EQF(resultUpH.at({0, 1, 1, 0}), static_cast<DataType>(8.333333));
1740   EXPECT_EQF(resultUpH.at({0, 1, 2, 0}), static_cast<DataType>(10));
1741 
1742   EXPECT_EQF(resultUpH.at({0, 2, 0, 0}), static_cast<DataType>(8));
1743   EXPECT_EQF(resultUpH.at({0, 2, 1, 0}), static_cast<DataType>(13.33333));
1744   EXPECT_EQF(resultUpH.at({0, 2, 2, 0}), static_cast<DataType>(16));
1745 
1746   EXPECT_EQF(resultUpH.at({0, 3, 0, 0}), static_cast<DataType>(8));
1747   EXPECT_EQF(resultUpH.at({0, 3, 1, 0}), static_cast<DataType>(13.33333));
1748   EXPECT_EQF(resultUpH.at({0, 3, 2, 0}), static_cast<DataType>(16));
1749 
1750   auto resultDownH = resultDown->getHandle<DataType>();
1751   std::vector<dim_t> expectedDimsDown = {1, 1, 1, 1};
1752   ASSERT_TRUE(resultDownH.dims().vec() == expectedDimsDown);
1753   EXPECT_EQF(resultDownH.at({0, 0, 0, 0}), static_cast<DataType>(2));
1754 }
1755 
1756 /// Verify that the ResizeBilinear operator works correctly for Float.
1757 TEST_P(OperatorTest, ResizeBilinear_Float) {
1758   CHECK_IF_ENABLED();
1759   testResizeBilinear<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
1760 }
1761 
1762 /// Verify that the ResizeBilinear operator works correctly for Float16.
1763 TEST_P(OperatorTest, ResizeBilinear_Float16) {
1764   CHECK_IF_ENABLED();
1765   testResizeBilinear<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
1766 }
1767 
1768 /// Verify that the ResizeBilinear operator works correctly for BFloat16.
1769 TEST_P(OperatorTest, ResizeBilinear_BFloat16) {
1770   CHECK_IF_ENABLED();
1771   testResizeBilinear<bfloat16_t>(bindings_, mod_, F_, EE_,
1772                                  ElemKind::BFloat16Ty);
1773 }
1774 
1775 /// Verify that the ResizeBilinear operator works correctly for Int8Q.
1776 TEST_P(OperatorTest, ResizeBilinear_Int8) {
1777   CHECK_IF_ENABLED();
1778   testResizeBilinear<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
1779 }
1780 
1781 /// Verify that the ResizeBilinear operator works correctly for Int16Q.
1782 TEST_P(OperatorTest, ResizeBilinear_Int16) {
1783   CHECK_IF_ENABLED();
1784   testResizeBilinear<int16_t>(bindings_, mod_, F_, EE_, ElemKind::Int16QTy);
1785 }
1786 
1787 /// Verify that the ResizeBilinear operator works correctly for Int32Q.
1788 TEST_P(OperatorTest, ResizeBilinear_Int32) {
1789   CHECK_IF_ENABLED();
1790   testResizeBilinear<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32QTy);
1791 }
1792 
1793 TEST_P(OperatorTest, ResizeBilinear_Float_outTy) {
1794   CHECK_IF_ENABLED();
1795   testResizeBilinear<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, true);
1796 }
1797 TEST_P(OperatorTest, ResizeBilinear_Float16_outTy) {
1798   CHECK_IF_ENABLED();
1799   testResizeBilinear<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
1800                                 true);
1801 }
1802 TEST_P(OperatorTest, ResizeBilinear_BFloat16_outTy) {
1803   CHECK_IF_ENABLED();
1804   testResizeBilinear<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
1805                                  true);
1806 }
1807 TEST_P(OperatorTest, ResizeBilinear_Int8_outTy) {
1808   CHECK_IF_ENABLED();
1809   testResizeBilinear<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy, true);
1810 }
1811 TEST_P(OperatorTest, ResizeBilinear_Int16_outTy) {
1812   CHECK_IF_ENABLED();
1813   testResizeBilinear<int16_t>(bindings_, mod_, F_, EE_, ElemKind::Int16QTy,
1814                               true);
1815 }
1816 TEST_P(OperatorTest, ResizeBilinear_Int32_outTy) {
1817   CHECK_IF_ENABLED();
1818   testResizeBilinear<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32QTy,
1819                               true);
1820 }
1821 
1822 TEST_P(OperatorTest, pow) {
1823   CHECK_IF_ENABLED();
1824 
1825   auto *X = mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 3}, "X", false);
1826   auto *Y = mod_.createPlaceholder(ElemKind::FloatTy, {2}, "Y", false);
1827   auto *Exp = mod_.createPlaceholder(ElemKind::FloatTy, {2}, "Exp", false);
1828 
1829   bindings_.allocate(X)->getHandle() = {5, 0.1f, -3};
1830   bindings_.allocate(Y)->getHandle() = {2, 100};
1831   bindings_.allocate(Exp)->getHandle() = {2, -1};
1832 
1833   auto *Pow1 = F_->createPow("Pow1", X, 2.0);
1834   auto *Pow2 = F_->createPow("Pow2", Y, 0.5);
1835   auto *Pow3 = F_->createPow("Pow3", Y, Exp);
1836 
1837   auto *save1 = F_->createSave("save", Pow1);
1838   auto *savePlaceholder1 = save1->getPlaceholder();
1839 
1840   auto *save2 = F_->createSave("save", Pow2);
1841   auto *savePlaceholder2 = save2->getPlaceholder();
1842 
1843   auto *save3 = F_->createSave("save", Pow3);
1844   auto *savePlaceholder3 = save3->getPlaceholder();
1845 
1846   bindings_.allocate(savePlaceholder1);
1847   bindings_.allocate(savePlaceholder2);
1848   bindings_.allocate(savePlaceholder3);
1849 
1850   EE_.compile(CompilationMode::Infer);
1851 
1852   EE_.run(bindings_);
1853 
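  // Note (added): expected values are the element-wise powers:
  // Pow1 = X^2 = {25, 0.01, 9}, Pow2 = Y^0.5 = {sqrt(2), 10}, and
  // Pow3 = Y^Exp = {2^2, 100^-1} = {4, 0.01}.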
1854   auto H_X = bindings_.get(savePlaceholder1)->getHandle();
1855   EXPECT_NEAR(H_X.at({0, 0, 0}), 25, 1E-5);
1856   EXPECT_NEAR(H_X.at({0, 0, 1}), 0.01, 1E-5);
1857   EXPECT_NEAR(H_X.at({0, 0, 2}), 9, 1E-5);
1858 
1859   auto H_Y = bindings_.get(savePlaceholder2)->getHandle();
1860   EXPECT_NEAR(H_Y.at({0}), sqrt(2.0), 1E-5);
1861   EXPECT_NEAR(H_Y.at({1}), 10, 1E-5);
1862 
1863   auto H_Z = bindings_.get(savePlaceholder3)->getHandle();
1864   EXPECT_NEAR(H_Z.at({0}), 4, 1E-5);
1865   EXPECT_NEAR(H_Z.at({1}), 0.01, 1E-5);
1866 }
1867 
1868 /// Helper to test ReplaceNaN using \p DTy.
1869 template <typename DataType>
1870 static void testReplaceNaN(glow::PlaceholderBindings &bindings,
1871                            glow::Module &mod, glow::Function *F,
1872                            glow::ExecutionEngine &EE, ElemKind DTy) {
1873   auto value = 1.0f;
1874   auto *X = mod.createPlaceholder(DTy, {6}, "X", false);
1875   auto XH = bindings.allocate(X)->getHandle<DataType>();
1876   XH = {1, NAN, 2, NAN, 3, NAN};
1877 
1878   auto *RNN = F->createReplaceNaN("replaceNaN", X, value);
1879 
1880   auto *save = F->createSave("save", RNN);
1881   auto *saveTensor = bindings.allocate(save->getPlaceholder());
1882 
1883   EE.compile(CompilationMode::Infer);
1884 
1885   EE.run(bindings);
1886 
1887   auto saveH = saveTensor->getHandle<DataType>();
1888 
1889   for (size_t i = 0; i < 6; i++) {
1890     if (std::isnan((float)XH.raw(i))) {
1891       EXPECT_EQ(saveH.raw(i), (DataType)value);
1892     } else {
1893       EXPECT_EQ(XH.raw(i), saveH.raw(i));
1894     }
1895   }
1896 }
1897 
1898 /// Test that ReplaceNaN is correctly supported in FloatTy.
1899 TEST_P(OperatorTest, replaceNaN_Float) {
1900   CHECK_IF_ENABLED();
1901   testReplaceNaN<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
1902 }
1903 
1904 /// Test that ReplaceNaN is correctly supported in Float16Ty.
1905 TEST_P(OperatorTest, replaceNaN_Float16) {
1906   CHECK_IF_ENABLED();
1907   testReplaceNaN<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
1908 }
1909 
1910 /// Test that ReplaceNaN is correctly supported in BFloat16Ty.
1911 TEST_P(OperatorTest, replaceNaN_BFloat16) {
1912   CHECK_IF_ENABLED();
1913   testReplaceNaN<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty);
1914 }
1915 
1916 TEST_P(OperatorTest, log) {
1917   CHECK_IF_ENABLED();
1918 
1919   auto *X = mod_.createPlaceholder(ElemKind::FloatTy, {6}, "X", false);
1920   auto XH = bindings_.allocate(X)->getHandle();
1921   XH = {210030, 600, 4, 0.7f, .005f, 0.000829f};
1922 
1923   auto *LN = F_->createLog("log", X);
1924 
1925   auto *save = F_->createSave("save", LN);
1926   auto *saveTensor = bindings_.allocate(save->getPlaceholder());
1927 
1928   EE_.compile(CompilationMode::Infer);
1929 
1930   EE_.run(bindings_);
1931 
1932   auto saveH = saveTensor->getHandle();
1933 
1934   for (dim_t i = 0; i < 6; i++) {
1935     EXPECT_NEAR(saveH.at({i}), log(XH.at({i})), 1E-5);
1936   }
1937 }
1938 
1939 /// Helper to test Logit using \p DTy.
1940 template <typename DataType>
1941 static void testLogit(glow::PlaceholderBindings &bindings, glow::Module &mod,
1942                       glow::Function *F, glow::ExecutionEngine &EE,
1943                       ElemKind DTy, float allowedError) {
1944   constexpr auto eps = 1E-6f; // the default in Caffe2
1945   constexpr dim_t size = 10;  // sample size for randomized tests
1946 
1947   auto *input = mod.createPlaceholder(DTy, {size}, "input", false);
1948   // Generate the input data in (0.0f, 1.0f) (probabilities including
1949   // degenerate cases) and test that afterward the input data is clamped to
1950   // (eps, 1 - eps) as in Caffe2.
1951   bindings.allocate(input)->getHandle<DataType>().randomize(0.0f, 1.0f,
1952                                                             mod.getPRNG());
1953 
1954   auto *logitDiff = F->createLogit("logitDiff", input, eps);
1955   auto *saveDiff = F->createSave("saveDiff", logitDiff);
1956   bindings.allocate(saveDiff->getPlaceholder());
1957 
1958   // property: the log-odds of complementary probabilities sum to zero,
1959   // i.e., logit(p) + logit(1 - p) == 0
1960   Node *const1 = F->createSplat("const1", input->getType(), 1.0);
1961   Node *complInput = F->createSub("sub", const1, input);
1962   Node *logitCompl = F->createLogit("logitCompl", complInput, eps);
1963   auto *saveCompl = F->createSave("saveCompl", logitCompl);
1964   bindings.allocate(saveCompl->getPlaceholder());
1965 
1966   EE.compile(CompilationMode::Infer);
1967   EE.run(bindings);
1968 
1969   // results: differential test against the oracle
1970   auto resultDiffH =
1971       bindings.get(saveDiff->getPlaceholder())->getHandle<DataType>();
1972   auto inputH = bindings.get(input)->getHandle<DataType>();
1973 
1974   // results: zero-sum property
1975   auto resultComplH =
1976       bindings.get(saveCompl->getPlaceholder())->getHandle<DataType>();
1977 
1978   // differential test:
1979   // ensure we match an oracle `logit_test` (a C++ reimplementation test)
1980   auto clamp_test = [](float v, float lo, float hi) {
1981     return std::max(std::min(v, hi), lo);
1982   };
1983   auto logit_test = [clamp_test](float x, float eps = 1E-6f) {
1984     float p = clamp_test(x, eps, 1.0f - eps);
1985     return std::log(p / (1.0f - p));
1986   };
1987 
1988   // property: the logit function is the right-inverse of the logistic function
1989   // i.e., logistic(logit(p)) == p
1990   auto logistic_test = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
1991 
1992   for (dim_t i = 0; i != size; ++i) {
1993     // differential test against the oracle
1994     EXPECT_NEAR(resultDiffH.at({i}), logit_test(inputH.at({i})), allowedError);
1995     // zero-sum property
1996     EXPECT_NEAR(resultComplH.at({i}) + resultDiffH.at({i}), 0.0f, allowedError);
1997     // right-inverse property
1998     EXPECT_NEAR(logistic_test(resultDiffH.at({i})),
1999                 clamp_test(inputH.at({i}), eps, 1.0f - eps), allowedError);
2000   }
2001 }
2002 
2003 /// Test the Logit operator using FloatTy.
2004 TEST_P(OperatorTest, Logit_Float) {
2005   CHECK_IF_ENABLED();
2006   testLogit<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, 1E-5);
2007 }
2008 
2009 /// Test the Logit operator using Float16Ty.
2010 TEST_P(OperatorTest, Logit_Float16) {
2011   CHECK_IF_ENABLED();
2012   testLogit<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty, 0.002);
2013 }
2014 
2015 /// Test the Logit operator using BFloat16Ty.
2016 TEST_P(OperatorTest, Logit_BFloat16) {
2017   CHECK_IF_ENABLED();
2018   testLogit<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty, 0.05);
2019 }
2020 
2021 /// Helper to test CmpEQ using \p DTy.
2022 template <typename DataType>
2023 static void testCmpEQ(glow::PlaceholderBindings &bindings, glow::Module &mod,
2024                       glow::Function *F, glow::ExecutionEngine &EE,
2025                       ElemKind DTy) {
2026   auto *X = mod.createPlaceholder(DTy, {2, 7}, "X", false);
2027   // Values listed here are in the dynamic range of both int32_t and int64_t.
2028   bindings.allocate(X)->getHandle<DataType>() = {
2029       0, 1, 17, 876, 1000, 44444, 65535, 0, 1, 17, 876, 1000, 44444, 65535};
2030   auto *Y = mod.createPlaceholder(DTy, {2, 7}, "Y", false);
2031   bindings.allocate(Y)->getHandle<DataType>() = {
2032       1, 2, 16, 900, 1111, 44544, 65534, 0, 1, 17, 876, 1000, 44444, 65535};
2033 
2034   auto *cmpEQ = F->createCmpEQ("cmpEQ", X, Y);
2035   auto *save = F->createSave("save", cmpEQ);
2036   auto *saveTensor = bindings.allocate(save->getPlaceholder());
2037 
2038   EE.compile(CompilationMode::Infer);
2039 
2040   EE.run(bindings);
2041 
2042   auto saveH = saveTensor->getHandle<bool>();
2043   for (dim_t i = 0; i < 7; ++i) {
2044     EXPECT_FALSE(saveH.at({0, i}));
2045   }
2046   for (dim_t i = 0; i < 7; ++i) {
2047     EXPECT_TRUE(saveH.at({1, i}));
2048   }
2049 }
2050 
2051 /// Test the CmpEQ operator using Int64ITy.
2052 TEST_P(OperatorTest, CmpEQ_Int64) {
2053   CHECK_IF_ENABLED();
2054   testCmpEQ<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy);
2055 }
2056 
2057 /// Test the CmpEQ operator using Int32ITy.
2058 TEST_P(OperatorTest, CmpEQ_Int32) {
2059   CHECK_IF_ENABLED();
2060   testCmpEQ<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy);
2061 }
2062 
2063 /// Check that the add operator works properly with FP16.
2064 TEST_P(OperatorTest, FP16Add) {
2065   CHECK_IF_ENABLED();
2066 
2067   PseudoRNG PRNG;
2068 
2069   auto *inputA =
2070       mod_.createPlaceholder(ElemKind::Float16Ty, {1, 3, 3, 1}, "A", false);
2071   bindings_.allocate(inputA)->getHandle<float16_t>().randomize(-3.0, 3.0, PRNG);
2072   auto *inputB =
2073       mod_.createPlaceholder(ElemKind::Float16Ty, {1, 3, 3, 1}, "B", false);
2074   bindings_.allocate(inputB)->getHandle<float16_t>().randomize(-3.0, 3.0, PRNG);
2075   auto *Pool = F_->createAdd("pool", inputA, inputB);
2076   auto *S = F_->createSave("save", Pool);
2077   bindings_.allocate(S->getPlaceholder());
2078 
2079   EE_.compile(CompilationMode::Infer);
2080   EE_.run(bindings_);
2081 
2082   auto result = bindings_.get(S->getPlaceholder())->getHandle<float16_t>();
2083   auto handleA = bindings_.get(inputA)->getHandle<float16_t>();
2084   auto handleB = bindings_.get(inputB)->getHandle<float16_t>();
2085   ASSERT_EQ(result.size(), handleA.size());
2086   for (size_t idx = 0, end = result.size(); idx != end; ++idx) {
2087     EXPECT_EQ(result.raw(idx), handleA.raw(idx) + handleB.raw(idx));
2088   }
2089 }
2090 
2091 /// Check that the add operator works properly with BFloat16.
2092 TEST_P(OperatorTest, BFloat16Add) {
2093   CHECK_IF_ENABLED();
2094 
2095   PseudoRNG PRNG;
2096 
2097   auto *inputA =
2098       mod_.createPlaceholder(ElemKind::BFloat16Ty, {1, 3, 3, 1}, "A", false);
2099   bindings_.allocate(inputA)->getHandle<bfloat16_t>().randomize(-3.0, 3.0,
2100                                                                 PRNG);
2101   auto *inputB =
2102       mod_.createPlaceholder(ElemKind::BFloat16Ty, {1, 3, 3, 1}, "B", false);
2103   bindings_.allocate(inputB)->getHandle<bfloat16_t>().randomize(-3.0, 3.0,
2104                                                                 PRNG);
2105   auto *Pool = F_->createAdd("pool", inputA, inputB);
2106   auto *S = F_->createSave("save", Pool);
2107   bindings_.allocate(S->getPlaceholder());
2108 
2109   EE_.compile(CompilationMode::Infer);
2110   EE_.run(bindings_);
2111 
2112   auto result = bindings_.get(S->getPlaceholder())->getHandle<bfloat16_t>();
2113   auto handleA = bindings_.get(inputA)->getHandle<bfloat16_t>();
2114   auto handleB = bindings_.get(inputB)->getHandle<bfloat16_t>();
2115   ASSERT_EQ(result.size(), handleA.size());
2116   for (size_t idx = 0, end = result.size(); idx != end; ++idx) {
2117     EXPECT_EQ(result.raw(idx), handleA.raw(idx) + handleB.raw(idx));
2118   }
2119 }
2120 
2121 TEST_P(OperatorTest, matmul) {
2122   CHECK_IF_ENABLED();
2123 
2124   auto *lhs = mod_.createPlaceholder(ElemKind::FloatTy, {3, 2}, "lhs", false);
2125   auto *rhs = mod_.createPlaceholder(ElemKind::FloatTy, {2, 1}, "rhs", false);
2126   bindings_.allocate(lhs)->getHandle() = {1, 2, 3, 4, 5, 6};
2127   bindings_.allocate(rhs)->getHandle() = {7, 10};
2128 
2129   auto *R = F_->createMatMul("MM", lhs, rhs);
2130 
2131   auto *save = F_->createSave("save", R);
2132   auto *saveTensor = bindings_.allocate(save->getPlaceholder());
2133 
2134   EE_.compile(CompilationMode::Infer);
2135   EE_.run(bindings_);
2136 
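  // Note (added): expected result of [1 2; 3 4; 5 6] * [7; 10] is
  // {1*7 + 2*10, 3*7 + 4*10, 5*7 + 6*10} = {27, 61, 95}.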
2137   auto H = saveTensor->getHandle();
2138   EXPECT_NEAR(H.at({0, 0}), 27, 0.001);
2139   EXPECT_NEAR(H.at({1, 0}), 61, 0.001);
2140   EXPECT_NEAR(H.at({2, 0}), 95, 0.001);
2141 }
2142 
2143 /// Test that cloneFunInsideFun works correctly with matmuls.
2144 TEST_P(OperatorTest, matmul_ParCloneTest10) {
2145   CHECK_IF_ENABLED();
2146 
2147   auto *lhs = mod_.createPlaceholder(ElemKind::FloatTy, {3, 2}, "lhs", false);
2148   auto *rhs = mod_.createPlaceholder(ElemKind::FloatTy, {2, 1}, "rhs", false);
2149   bindings_.allocate(lhs)->getHandle() = {1, 2, 3, 4, 5, 6};
2150   bindings_.allocate(rhs)->getHandle() = {7, 10};
2151 
2152   auto *R = F_->createMatMul("MM", lhs, rhs);
2153 
2154   auto *save = F_->createSave("save", R);
2155   auto *saveTensor = bindings_.allocate(save->getPlaceholder());
2156 
2157   CompilationContext cctx;
2158   const unsigned parallelCount = 10;
2159   auto resultTensors = cloneFunInsideFun(std::make_pair(F_, saveTensor),
2160                                          &bindings_, cctx, parallelCount);
2161 
2162   EXPECT_EQ(resultTensors.size(), parallelCount);
2163 
2164   EE_.compile(cctx);
2165   EE_.run(bindings_);
2166 
2167   for (Tensor *T : resultTensors) {
2168     auto H = T->getHandle();
2169     EXPECT_NEAR(H.at({0, 0}), 27, 0.001);
2170     EXPECT_NEAR(H.at({1, 0}), 61, 0.001);
2171     EXPECT_NEAR(H.at({2, 0}), 95, 0.001);
2172   }
2173 }
2174 
2175 /// Test that compareAgainstInterpreter works correctly along with quantization
2176 /// and parallel cloning.
2177 TEST_P(OperatorStatelessTest, matmulQuantized_InterpCompareParClone) {
2178   CHECK_IF_ENABLED();
2179 
2180   constexpr unsigned parallelCount = 10;
2181   compareAgainstInterpreter(
2182       getBackendName(),
2183       [](PlaceholderBindings &bindings, ExecutionEngine &EE) {
2184         Module &mod = EE.getModule();
2185         Function *F = mod.createFunction("main");
2186         Placeholder *lhs =
2187             mod.createPlaceholder(ElemKind::FloatTy, {3, 2}, "lhs", false);
2188         Placeholder *rhs =
2189             mod.createPlaceholder(ElemKind::FloatTy, {2, 1}, "rhs", false);
2190         bindings.allocate(lhs)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
2191         bindings.allocate(rhs)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
2192 
2193         MatMulNode *R = F->createMatMul("MM", lhs, rhs);
2194 
2195         SaveNode *save = F->createSave("save", R);
2196         Tensor *saveTensor = bindings.allocate(save->getPlaceholder());
2197         return std::make_pair(F, saveTensor);
2198       },
2199       ElemKind::FloatTy, ElemKind::Int8QTy, 0.006, parallelCount);
2200 }
2201 
2202 /// Check that the matmul operator behaves correctly with FP16.
2203 TEST_P(OperatorTest, FP16Matmul) {
2204   CHECK_IF_ENABLED();
2205 
2206   auto *lhs = mod_.createPlaceholder(ElemKind::Float16Ty, {3, 2}, "lhs", false);
2207   auto *rhs = mod_.createPlaceholder(ElemKind::Float16Ty, {2, 1}, "rhs", false);
2208   bindings_.allocate(lhs)->getHandle<float16_t>() = {1, 2, 3, 4, 5, 6};
2209   bindings_.allocate(rhs)->getHandle<float16_t>() = {7, 10};
2210 
2211   auto *R = F_->createMatMul("MM", lhs, rhs);
2212 
2213   auto *save = F_->createSave("save", R);
2214   auto *saveTensor = bindings_.allocate(save->getPlaceholder());
2215 
2216   EE_.compile(CompilationMode::Infer);
2217   EE_.run(bindings_);
2218 
2219   auto H = saveTensor->getHandle<float16_t>();
2220   EXPECT_NEAR(H.at({0, 0}), 27, 0.001);
2221   EXPECT_NEAR(H.at({1, 0}), 61, 0.001);
2222   EXPECT_NEAR(H.at({2, 0}), 95, 0.001);
2223 }
2224 
2225 /// Check that the matmul operator behaves correctly with BFloat16.
2226 TEST_P(OperatorTest, BFloat16Matmul) {
2227   CHECK_IF_ENABLED();
2228 
2229   auto *lhs =
2230       mod_.createPlaceholder(ElemKind::BFloat16Ty, {3, 2}, "lhs", false);
2231   auto *rhs =
2232       mod_.createPlaceholder(ElemKind::BFloat16Ty, {2, 1}, "rhs", false);
2233   bindings_.allocate(lhs)->getHandle<bfloat16_t>() = {1, 2, 3, 4, 5, 6};
2234   bindings_.allocate(rhs)->getHandle<bfloat16_t>() = {7, 10};
2235 
2236   auto *R = F_->createMatMul("MM", lhs, rhs);
2237 
2238   auto *save = F_->createSave("save", R);
2239   auto *saveTensor = bindings_.allocate(save->getPlaceholder());
2240 
2241   EE_.compile(CompilationMode::Infer);
2242   EE_.run(bindings_);
2243 
2244   auto H = saveTensor->getHandle<bfloat16_t>();
2245   EXPECT_NEAR(H.at({0, 0}), 27, 0.001);
2246   EXPECT_NEAR(H.at({1, 0}), 61, 0.001);
2247   EXPECT_NEAR(H.at({2, 0}), 95, 0.001);
2248 }
2249 
2250 /// Test that the broadcasted batch mat mul operator works as expected.
2251 TEST_P(OperatorTest, BroadcastedBatchMatMul) {
2252   CHECK_IF_ENABLED();
2253 
2254   auto *lhs =
2255       mod_.createPlaceholder(ElemKind::FloatTy, {2, 3, 2}, "lhs", false);
2256   auto *rhs = mod_.createPlaceholder(ElemKind::FloatTy, {2, 1}, "rhs", false);
2257   bindings_.allocate(lhs)->getHandle() = {1,  2,  3,  4,  5,  6,
2258                                           -1, -2, -3, -4, -5, -6};
2259   bindings_.allocate(rhs)->getHandle() = {7, 10};
2260 
2261   auto *R = F_->createBatchMatMul("BMM", lhs, rhs);
2262 
2263   auto *save = F_->createSave("save", R);
2264   auto *result = bindings_.allocate(save->getPlaceholder());
2265 
2266   EE_.compile(CompilationMode::Infer);
2267   EE_.run(bindings_);
2268 
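  // Note (added): the {2, 1} RHS is reused for every batch, so batch 0 yields
  // {27, 61, 95} as in the plain matmul test and batch 1 (the negated LHS)
  // yields {-27, -61, -95}.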
2269   auto H = result->getHandle();
2270   EXPECT_NEAR(H.at({0, 0, 0}), 27, 0.001);
2271   EXPECT_NEAR(H.at({0, 1, 0}), 61, 0.001);
2272   EXPECT_NEAR(H.at({0, 2, 0}), 95, 0.001);
2273   EXPECT_NEAR(H.at({1, 0, 0}), -27, 0.001);
2274   EXPECT_NEAR(H.at({1, 1, 0}), -61, 0.001);
2275   EXPECT_NEAR(H.at({1, 2, 0}), -95, 0.001);
2276 }
2277 
2278 /// Test that the broadcasted batch mat mul operator works as expected when the
2279 /// RHS does not have to be tiled.
2280 TEST_P(OperatorTest, NonBroadcastedBatchMatMul) {
2281   CHECK_IF_ENABLED();
2282   auto *lhs =
2283       mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 2}, "lhs", false);
2284   auto *rhs = mod_.createPlaceholder(ElemKind::FloatTy, {2, 1}, "rhs", false);
2285   bindings_.allocate(lhs)->getHandle() = {1, 2, 3, 4, 5, 6};
2286   bindings_.allocate(rhs)->getHandle() = {7, 10};
2287 
2288   auto *R = F_->createBatchMatMul("BMM", lhs, rhs);
2289 
2290   auto *save = F_->createSave("save", R);
2291   auto *result = bindings_.allocate(save->getPlaceholder());
2292 
2293   EE_.compile(CompilationMode::Infer);
2294   EE_.run(bindings_);
2295 
2296   auto H = result->getHandle();
2297   EXPECT_NEAR(H.at({0, 0, 0}), 27, 0.001);
2298   EXPECT_NEAR(H.at({0, 1, 0}), 61, 0.001);
2299   EXPECT_NEAR(H.at({0, 2, 0}), 95, 0.001);
2300 }
2301 
2302 TEST_P(OperatorTest, ParallelBatchMatMul) {
2303   CHECK_IF_ENABLED();
2304 
2305   auto *lhs =
2306       mod_.createPlaceholder(ElemKind::FloatTy, {2, 3, 2}, "lhs", false);
2307   auto *rhs =
2308       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 1}, "rhs", false);
2309   bindings_.allocate(lhs)->getHandle() = {1,  2,  3,  4,  5,  6,
2310                                           -1, -2, -3, -4, -5, -6};
2311   bindings_.allocate(rhs)->getHandle() = {7, 10, 12, -1};
2312 
2313   auto *R = F_->createBatchMatMul("BMM", lhs, rhs);
2314 
2315   auto *save = F_->createSave("save", R);
2316   auto *result = bindings_.allocate(save->getPlaceholder());
2317 
2318   EE_.compile(CompilationMode::Infer);
2319   EE_.run(bindings_);
2320 
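  // Note (added): each batch uses its own RHS column here; batch 0 multiplies
  // by {7, 10} giving {27, 61, 95}, and batch 1 multiplies {-1..-6} by
  // {12, -1} giving {-12 + 2, -36 + 4, -60 + 6} = {-10, -32, -54}.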
2321   auto H = result->getHandle();
2322   EXPECT_NEAR(H.at({0, 0, 0}), 27, 0.001);
2323   EXPECT_NEAR(H.at({0, 1, 0}), 61, 0.001);
2324   EXPECT_NEAR(H.at({0, 2, 0}), 95, 0.001);
2325   EXPECT_NEAR(H.at({1, 0, 0}), -10, 0.001);
2326   EXPECT_NEAR(H.at({1, 1, 0}), -32, 0.001);
2327   EXPECT_NEAR(H.at({1, 2, 0}), -54, 0.001);
2328 }
2329 
2330 static FunctionTensorPair
2331 createAndInitParallelBatchMatMulTest(glow::PlaceholderBindings &bindings,
2332                                      glow::ExecutionEngine &EE) {
2333   auto &mod = EE.getModule();
2334   Function *F = mod.createFunction("main");
2335 
2336   auto *lhs =
2337       mod.createPlaceholder(ElemKind::FloatTy, {10, 50, 100}, "lhs", false);
2338   auto *rhs =
2339       mod.createPlaceholder(ElemKind::FloatTy, {10, 100, 80}, "rhs", false);
2340   bindings.allocate(lhs)->getHandle().randomize(-0.1, 0.1, mod.getPRNG());
2341   bindings.allocate(rhs)->getHandle().randomize(-0.1, 0.1, mod.getPRNG());
2342 
2343   auto *R = F->createBatchMatMul("BMM", lhs, rhs);
2344 
2345   auto *save = F->createSave("save", R);
2346   auto *resultTensor = bindings.allocate(save->getPlaceholder());
2347 
2348   return std::make_pair(F, resultTensor);
2349 }
2350 
2351 TEST_P(OperatorStatelessTest, ParallelBatchMatMul_Float16) {
2352   CHECK_IF_ENABLED();
2353   compareAgainstInterpreter(
2354       getBackendName(), createAndInitParallelBatchMatMulTest, ElemKind::FloatTy,
2355       ElemKind::Float16Ty, 0.0005f, parCloneCountOpt);
2356 }
2357 
2358 TEST_P(OperatorStatelessTest, ParallelBatchMatMul_BFloat16) {
2359   CHECK_IF_ENABLED();
2360   compareAgainstInterpreter(
2361       getBackendName(), createAndInitParallelBatchMatMulTest, ElemKind::FloatTy,
2362       ElemKind::BFloat16Ty, 0.0005f, parCloneCountOpt);
2363 }
2364 
2365 TEST_P(OperatorStatelessTest, ParallelBatchMatMul_Int8) {
2366   CHECK_IF_ENABLED();
2367   compareAgainstInterpreter(
2368       getBackendName(), createAndInitParallelBatchMatMulTest, ElemKind::FloatTy,
2369       ElemKind::Int8QTy, 0.002f, parCloneCountOpt);
2370 }
2371 
2372 /// Helper to test BatchedReduceAdd using \p DTy.
2373 template <typename DataType>
2374 static void testBatchedReduceAdd(glow::PlaceholderBindings &bindings,
2375                                  glow::Module &mod, glow::Function *F,
2376                                  glow::ExecutionEngine &EE, ElemKind DTy) {
2377   auto *batch = mod.createPlaceholder(DTy, {2, 4}, "batch", false);
2378   bindings.allocate(batch)->getHandle<DataType>() = {10, 20, 30, 40,
2379                                                      1,  2,  3,  4};
2380 
2381   auto *R = F->createBatchedReduceAdd("reduce.add", batch, /* axis */ 0);
2382 
2383   auto *save = F->createSave("save", R);
2384   auto *result = bindings.allocate(save->getPlaceholder());
2385 
2386   EE.compile(CompilationMode::Infer);
2387   EE.run(bindings);
2388 
2389   Tensor expected(DTy, {4});
2390   expected.getHandle<DataType>() = {11, 22, 33, 44};
2391   EXPECT_TRUE(result->isEqual(expected));
2392 }
2393 
2394 /// Test that BatchedReduceAdd is correctly supported in FloatTy.
2395 TEST_P(OperatorTest, batchedReduceAdd_Float) {
2396   CHECK_IF_ENABLED();
2397 
2398   testBatchedReduceAdd<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
2399 }
2400 
2401 /// Test that BatchedReduceAdd is correctly supported in Float16Ty.
2402 TEST_P(OperatorTest, batchedReduceAdd_Float16) {
2403   CHECK_IF_ENABLED();
2404   testBatchedReduceAdd<float16_t>(bindings_, mod_, F_, EE_,
2405                                   ElemKind::Float16Ty);
2406 }
2407 
2408 /// Test that BatchedReduceAdd is correctly supported in BFloat16Ty.
2409 TEST_P(OperatorTest, batchedReduceAdd_BFloat16) {
2410   CHECK_IF_ENABLED();
2411   testBatchedReduceAdd<bfloat16_t>(bindings_, mod_, F_, EE_,
2412                                    ElemKind::BFloat16Ty);
2413 }
2414 
2415 /// Test that BatchedReduceAdd works correctly reducing the outermost axis.
2416 TEST_P(OperatorTest, batchedReduceAdd_outerAxis) {
2417   CHECK_IF_ENABLED();
2418 
2419   auto *batch =
2420       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 4}, "batch", false);
2421   bindings_.allocate(batch)->getHandle<float>() = {10, 20, 30, 40, 1, 2, 3, 4,
2422                                                    10, 20, 30, 40, 1, 2, 3, 4};
2423 
2424   auto *R = F_->createBatchedReduceAdd("reduce.add", batch, /* axis */ 0);
2425 
2426   auto *save = F_->createSave("save", R);
2427   auto *result = bindings_.allocate(save->getPlaceholder());
2428 
2429   EE_.compile(CompilationMode::Infer);
2430   EE_.run(bindings_);
2431 
2432   Tensor expected(ElemKind::FloatTy, {2, 4});
2433   expected.getHandle<float>() = {20, 40, 60, 80, 2, 4, 6, 8};
2434 
2435   EXPECT_TRUE(result->isEqual(expected));
2436 }
2437 
2438 /// Test that BatchedReduceAdd works correctly reducing an internal axis.
2439 TEST_P(OperatorTest, batchedReduceAdd_innerAxis) {
2440   CHECK_IF_ENABLED();
2441 
2442   auto *batch =
2443       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 4}, "batch", false);
2444   bindings_.allocate(batch)->getHandle<float>() = {10, 20, 30, 40, 1, 2, 3, 4,
2445                                                    10, 20, 30, 40, 1, 2, 3, 4};
2446 
2447   auto *R = F_->createBatchedReduceAdd("reduce.add", batch, /* axis */ 1);
2448 
2449   auto *save = F_->createSave("save", R);
2450   auto *result = bindings_.allocate(save->getPlaceholder());
2451 
2452   EE_.compile(CompilationMode::Infer);
2453   EE_.run(bindings_);
2454 
2455   Tensor expected(ElemKind::FloatTy, {2, 4});
2456   expected.getHandle<float>() = {11, 22, 33, 44, 11, 22, 33, 44};
2457 
2458   EXPECT_TRUE(result->isEqual(expected));
2459 }
2460 
2461 /// Test that BatchedReduceAdd works correctly reducing the innermost axis.
2462 TEST_P(OperatorTest, batchedReduceAdd_lastAxis) {
2463   CHECK_IF_ENABLED();
2464 
2465   auto *batch =
2466       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 4}, "batch", false);
2467   bindings_.allocate(batch)->getHandle<float>() = {10, 20, 30, 40, 1, 2, 3, 4,
2468                                                    10, 20, 30, 40, 1, 2, 3, 4};
2469   auto *R = F_->createBatchedReduceAdd("reduce.add", batch, /* axis */ 2);
2470 
2471   auto *save = F_->createSave("save", R);
2472   auto *result = bindings_.allocate(save->getPlaceholder());
2473 
2474   EE_.compile(CompilationMode::Infer);
2475   EE_.run(bindings_);
2476 
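  // Note (added): reducing the innermost axis sums each group of four values,
  // so {10, 20, 30, 40} -> 100 and {1, 2, 3, 4} -> 10 for every row.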
2477   Tensor expected(ElemKind::FloatTy, {2, 2});
2478   expected.getHandle<float>() = {100, 10, 100, 10};
2479 
2480   EXPECT_TRUE(result->isEqual(expected));
2481 }
2482 
2483 /// Test that BatchedReduceAdd works on a 4D input.
2484 TEST_P(OperatorTest, batchedReduceAdd_4Dinput) {
2485   CHECK_IF_ENABLED();
2486 
2487   auto *batch =
2488       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 2, 4}, "batch", false);
2489   bindings_.allocate(batch)->getHandle<float>() = {
2490       10, 20, 30, 40, 1, 2, 3, 4, 10, 20, 30, 40, 1, 2, 3, 4,
2491       10, 20, 30, 40, 1, 2, 3, 4, 10, 20, 30, 40, 1, 2, 3, 4};
2492 
2493   auto *R = F_->createBatchedReduceAdd("reduce.add", batch, /* axis */ 0);
2494 
2495   auto *save = F_->createSave("save", R);
2496   auto *result = bindings_.allocate(save->getPlaceholder());
2497 
2498   EE_.compile(CompilationMode::Infer);
2499   EE_.run(bindings_);
2500 
2501   Tensor expected(ElemKind::FloatTy, {2, 2, 4});
2502   expected.getHandle<float>() = {20, 40, 60, 80, 2, 4, 6, 8,
2503                                  20, 40, 60, 80, 2, 4, 6, 8};
2504 
2505   EXPECT_TRUE(result->isEqual(expected));
2506 }
2507 
2508 /// Test that BatchedReduceAdd works on a 5D input.
2509 TEST_P(OperatorTest, batchedReduceAdd_5Dinput) {
2510   CHECK_IF_ENABLED();
2511   auto *batch = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 2, 2, 4},
2512                                        "batch", false);
2513   bindings_.allocate(batch)->getHandle<float>() = {
2514       10, 20, 30, 40, 1, 2, 3, 4, 10, 20, 30, 40, 1, 2, 3, 4,
2515       10, 20, 30, 40, 1, 2, 3, 4, 10, 20, 30, 40, 1, 2, 3, 4,
2516       10, 20, 30, 40, 1, 2, 3, 4, 10, 20, 30, 40, 1, 2, 3, 4,
2517       10, 20, 30, 40, 1, 2, 3, 4, 10, 20, 30, 40, 1, 2, 3, 4};
2518 
2519   auto *R = F_->createBatchedReduceAdd("reduce.add", batch, /* axis */ 2);
2520 
2521   auto *save = F_->createSave("save", R);
2522   auto *result = bindings_.allocate(save->getPlaceholder());
2523 
2524   EE_.compile(CompilationMode::Infer);
2525   EE_.run(bindings_);
2526 
2527   Tensor expected(ElemKind::FloatTy, {2, 2, 2, 4});
2528   expected.getHandle<float>() = {20, 40, 60, 80, 2,  4,  6,  8,  20, 40, 60,
2529                                  80, 2,  4,  6,  8,  20, 40, 60, 80, 2,  4,
2530                                  6,  8,  20, 40, 60, 80, 2,  4,  6,  8};
2531 
2532   EXPECT_TRUE(result->isEqual(expected));
2533 }
2534 
2535 /// Helper to test BatchedReduceMin using \p DTy.
2536 template <typename DataType>
2537 static void testBatchedReduceMin(glow::PlaceholderBindings &bindings,
2538                                  glow::Module &mod, glow::Function *F,
2539                                  glow::ExecutionEngine &EE, ElemKind DTy) {
2540 
2541   auto *batch = mod.createPlaceholder(DTy, {2, 4}, "batch", false);
2542   bindings.allocate(batch)->getHandle<DataType>() = {10, 20, 30, 40,
2543                                                      1,  2,  3,  4};
2544   auto *R = F->createBatchedReduceMin("reduce.min", batch, /* axis */ 0);
2545 
2546   auto *save = F->createSave("save", R);
2547   auto *result = bindings.allocate(save->getPlaceholder());
2548 
2549   EE.compile(CompilationMode::Infer);
2550   EE.run(bindings);
2551 
2552   Tensor expected(DTy, {4});
2553   expected.getHandle<DataType>() = {1, 2, 3, 4};
2554 
2555   EXPECT_TRUE(result->isEqual(expected));
2556 }
2557 
2558 /// Helper to test multi-axis BatchedReduceMin using \p DTy.
2559 template <typename DataType>
2560 static void testBatchedReduceMinMultiAxis(glow::PlaceholderBindings &bindings,
2561                                           glow::Module &mod, glow::Function *F,
2562                                           glow::ExecutionEngine &EE,
2563                                           ElemKind DTy) {
2564   auto *batch = mod.createPlaceholder(DTy, {2, 2, 2, 2}, "batch", false);
2565   bindings.allocate(batch)->getHandle<DataType>() = {
2566       1, -2, 3, -4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
2567   auto *R = F->createBatchedReduceMin("reduce.min", batch, /* axis */ {1, 3});
2568   auto *save = F->createSave("save", R);
2569   auto *result = bindings.allocate(save->getPlaceholder());
2570 
2571   EE.compile(CompilationMode::Infer);
2572   EE.run(bindings);
2573 
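  // Note (added): reducing axes 1 and 3 of the {2, 2, 2, 2} input keeps axes
  // 0 and 2; e.g. expected(0, 0) = min(1, -2, 5, 6) = -2 and
  // expected(1, 1) = min(11, 12, 15, 16) = 11.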
2574   Tensor expected(DTy, {2, 2});
2575   expected.getHandle<DataType>() = {-2, -4, 9, 11};
2576   EXPECT_TRUE(result->isEqual(expected));
2577 }
2578 
2579 /// Test that BatchedReduceMin is correctly supported in FloatTy.
2580 TEST_P(OperatorTest, batchedReduceMin_Float) {
2581   CHECK_IF_ENABLED();
2582   testBatchedReduceMin<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
2583 }
2584 
2585 /// Test that BatchedReduceMin is correctly supported in Int32Ty.
2586 TEST_P(OperatorTest, batchedReduceMin_Int32) {
2587   CHECK_IF_ENABLED();
2588   testBatchedReduceMin<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy);
2589 }
2590 
2591 /// Test that BatchedReduceMin is correctly supported in Int64Ty.
2592 TEST_P(OperatorTest, batchedReduceMin_Int64) {
2593   CHECK_IF_ENABLED();
2594   testBatchedReduceMin<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy);
2595 }
2596 
2597 /// Test that multi-axis BatchedReduceMin is correctly supported in FloatTy.
2598 TEST_P(OperatorTest, batchedReduceMinMultiAxis_Float) {
2599   CHECK_IF_ENABLED();
2600   testBatchedReduceMinMultiAxis<float>(bindings_, mod_, F_, EE_,
2601                                        ElemKind::FloatTy);
2602 }
2603 
2604 /// Test that multi-axis BatchedReduceMin is correctly supported in Int32Ty.
2605 TEST_P(OperatorTest, batchedReduceMinMultiAxis_Int32) {
2606   CHECK_IF_ENABLED();
2607   testBatchedReduceMinMultiAxis<int32_t>(bindings_, mod_, F_, EE_,
2608                                          ElemKind::Int32ITy);
2609 }
2610 
2611 /// Test that multi-axis BatchedReduceMin is correctly supported in Int64Ty.
2612 TEST_P(OperatorTest, batchedReduceMinMultiAxis_Int64) {
2613   CHECK_IF_ENABLED();
2614   testBatchedReduceMinMultiAxis<int64_t>(bindings_, mod_, F_, EE_,
2615                                          ElemKind::Int64ITy);
2616 }
2617 
2618 /// Helper to test BatchedReduceZeroDimResult using \p DTy.
2619 template <typename DataType>
2620 static void testBatchedReduceZeroDimResult(glow::PlaceholderBindings &bindings,
2621                                            glow::Module &mod, glow::Function *F,
2622                                            glow::ExecutionEngine &EE,
2623                                            ElemKind DTy) {
2624   auto *batch = createPlaceholderConditionallyQuantized(
2625       mod, DTy, {4}, "batch", /* isTrainable */ false, "N");
2626   bindings.allocate(batch)->getHandle<DataType>() = {2, 4, 6, 8};
2627 
2628   auto OT = uniqueTypeConditionallyQuantized(mod, DTy, {});
2629   auto *RA = F->createBatchedReduceAdd("reduce.add", OT, batch, /* axis */ 0);
2630   auto *RM = F->createBatchedReduceMean("reduce.mean", OT, batch, /* axis */ 0);
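  // Reducing the {4} batch over axis 0 produces rank-0 (scalar) outputs:
  // the sum 2 + 4 + 6 + 8 = 20 and the mean 20 / 4 = 5, verified below for
  // both the quantized and non-quantized variants.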
2631   auto *saveRA = F->createSave("saveRA", RA);
2632   auto *saveRM = F->createSave("saveRM", RM);
2633   auto *resultRA = bindings.allocate(saveRA->getPlaceholder());
2634   auto *resultRM = bindings.allocate(saveRM->getPlaceholder());
2635 
2636   EE.compile(CompilationMode::Infer);
2637   EE.run(bindings);
2638 
2639   auto RAH = resultRA->getHandle<DataType>();
2640   auto RMH = resultRM->getHandle<DataType>();
2641   if (isQuantizedElemKind(DTy)) {
2642     EXPECT_EQ(RAH.at({}), static_cast<DataType>(20));
2643     EXPECT_EQ(RMH.at({}), static_cast<DataType>(5));
2644   } else {
2645     EXPECT_NEAR(RAH.at({}), 20, 0.001);
2646     EXPECT_NEAR(RMH.at({}), 5, 0.001);
2647   }
2648 }
2649 
2650 /// Test reduction down to a zero-dim tensor on FloatTy.
2651 TEST_P(OperatorTest, batchedReduceZeroDimResult_Float) {
2652   CHECK_IF_ENABLED();
2653   testBatchedReduceZeroDimResult<float>(bindings_, mod_, F_, EE_,
2654                                         ElemKind::FloatTy);
2655 }
2656 
2657 /// Test reduction down to a zero-dim tensor on Float16Ty.
2658 TEST_P(OperatorTest, batchedReduceZeroDimResult_Float16) {
2659   CHECK_IF_ENABLED();
2660   testBatchedReduceZeroDimResult<float16_t>(bindings_, mod_, F_, EE_,
2661                                             ElemKind::Float16Ty);
2662 }
2663 
2664 /// Test reduction down to a zero-dim tensor on BFloat16Ty.
2665 TEST_P(OperatorTest, batchedReduceZeroDimResult_BFloat16) {
2666   CHECK_IF_ENABLED();
2667   testBatchedReduceZeroDimResult<bfloat16_t>(bindings_, mod_, F_, EE_,
2668                                              ElemKind::BFloat16Ty);
2669 }
2670 
2671 /// Test reduction down to a zero-dim tensor on Int8QTy.
2672 TEST_P(OperatorTest, batchedReduceZeroDimResult_Int8) {
2673   CHECK_IF_ENABLED();
2674   testBatchedReduceZeroDimResult<int8_t>(bindings_, mod_, F_, EE_,
2675                                          ElemKind::Int8QTy);
2676 }
2677 
2678 /// Helper to test BatchedReduceAddWithAxis using \p DTy.
2679 template <typename DataType>
2680 static void testBatchedReduceAddWithAxis(glow::PlaceholderBindings &bindings,
2681                                          glow::Module &mod, glow::Function *F,
2682                                          glow::ExecutionEngine &EE,
2683                                          ElemKind DTy) {
2684   auto *batch = createPlaceholderConditionallyQuantized(mod, DTy, {2, 3, 2},
2685                                                         "batch", false);
2686   bindings.allocate(batch)->getHandle<DataType>() = {0, 1, 2, 3, 4,  5,
2687                                                      6, 7, 8, 9, 10, 11};
2688 
2689   auto OT1 = uniqueTypeConditionallyQuantized(mod, DTy, {2, 2});
2690   auto *R1 =
2691       F->createBatchedReduceAdd("reduce.add.axis.1", OT1, batch, /* axis */ 1);
2692   auto OT2 = uniqueTypeConditionallyQuantized(mod, DTy, {2, 3});
2693   auto *R2 =
2694       F->createBatchedReduceAdd("reduce.add.axis.2", OT2, batch, /* axis */ 2);
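  // For the 2x3x2 batch, reducing axis 1 sums the three middle slices:
  // {0+2+4, 1+3+5, 6+8+10, 7+9+11} = {6, 9, 24, 27}; reducing axis 2 sums
  // adjacent pairs: {0+1, 2+3, ...} = {1, 5, 9, 13, 17, 21}.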
2695   auto *save1 = F->createSave("save1", R1);
2696   auto *save2 = F->createSave("save2", R2);
2697 
2698   auto *result1 = bindings.allocate(save1->getPlaceholder());
2699   auto *result2 = bindings.allocate(save2->getPlaceholder());
2700 
2701   EE.compile(CompilationMode::Infer);
2702   EE.run(bindings);
2703 
2704   auto expected1 = createTensorConditionallyQuantized(DTy, {2, 2});
2705   expected1.getHandle<DataType>() = {6, 9, 24, 27};
2706   EXPECT_TRUE(result1->isEqual(expected1));
2707 
2708   auto expected2 = createTensorConditionallyQuantized(DTy, {2, 3});
2709   expected2.getHandle<DataType>() = {1, 5, 9, 13, 17, 21};
2710   EXPECT_TRUE(result2->isEqual(expected2));
2711 }
2712 
2713 /// Test that batchedReduceAddWithAxis is correctly supported in FloatTy.
2714 TEST_P(OperatorTest, batchedReduceAddWithAxis_Float) {
2715   CHECK_IF_ENABLED();
2716   testBatchedReduceAddWithAxis<float>(bindings_, mod_, F_, EE_,
2717                                       ElemKind::FloatTy);
2718 }
2719 
2720 /// Test that batchedReduceAddWithAxis is correctly supported in Float16Ty.
2721 TEST_P(OperatorTest, batchedReduceAddWithAxis_Float16) {
2722   CHECK_IF_ENABLED();
2723   testBatchedReduceAddWithAxis<float16_t>(bindings_, mod_, F_, EE_,
2724                                           ElemKind::Float16Ty);
2725 }
2726 
2727 /// Test that batchedReduceAddWithAxis is correctly supported in BFloat16Ty.
2728 TEST_P(OperatorTest, batchedReduceAddWithAxis_BFloat16) {
2729   CHECK_IF_ENABLED();
2730   testBatchedReduceAddWithAxis<bfloat16_t>(bindings_, mod_, F_, EE_,
2731                                            ElemKind::BFloat16Ty);
2732 }
2733 
2734 /// Test that batchedReduceAddWithAxis is correctly supported in Int8QTy.
2735 TEST_P(OperatorTest, batchedReduceAddWithAxis_Int8Q) {
2736   CHECK_IF_ENABLED();
2737   testBatchedReduceAddWithAxis<int8_t>(bindings_, mod_, F_, EE_,
2738                                        ElemKind::Int8QTy);
2739 }
2740 
2741 TEST_P(OperatorTest, batchedReduceAddQuantized) {
2742   CHECK_IF_ENABLED();
2743 
2744   auto BT = mod_.uniqueType(ElemKind::Int8QTy, {3, 8}, 0.5, 3);
2745   auto OT = mod_.uniqueType(ElemKind::Int8QTy, {8}, 2.0, -1);
2746 
2747   auto *batch =
2748       mod_.createPlaceholder(ElemKind::Int8QTy, {3, 8}, BT->getScale(),
2749                              BT->getOffset(), "batch", false);
2750 
2751   bindings_.allocate(batch)->getHandle<int8_t>() = {
2752       27, -31, 16,  7,  20, 34, -2, 8,   -10, 83, 29,  -17,
2753       19, 13,  -11, -9, 50, 58, 0,  -20, -72, 43, -25, -1};
2754 
2755   auto BH = bindings_.get(batch)->getHandle<int8_t>();
2756 
2757   auto *R =
2758       F_->createBatchedReduceAdd("batched.reduce.add", OT, batch, /* axis */ 0);
2759 
2760   auto *save = F_->createSave("save", R);
2761   auto OH = bindings_.allocate(save->getPlaceholder())->getHandle<int8_t>();
2762 
2763   EE_.compile(CompilationMode::Infer);
2764   EE_.run(bindings_);
2765 
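  // Reference check: each quantized value q represents (q - offset) * scale,
  // so the sum of a column is computed in the input's integer domain and then
  // requantized into the output type. The per-element rescale factor folds
  // into s = BT->getScale() / OT->getScale().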
2766   for (dim_t i = 0; i < 8; i++) {
2767     std::array<int32_t, 3> b{{BH.at({0, i}), BH.at({1, i}), BH.at({2, i})}};
2768     float s = BT->getScale() / OT->getScale();
2769     int32_t o = BT->getOffset();
2770     float result = (b[0] - o) + (b[1] - o) + (b[2] - o);
2771     result = s * result + OT->getOffset();
2772 
2773     EXPECT_NEAR(std::round(result), OH.at({i}), 1.0);
2774   }
2775 }
2776 
2777 TEST_P(OperatorTest, batchedReduceAddQuantizedWithAxis) {
2778   CHECK_IF_ENABLED();
2779 
2780   auto BT = mod_.uniqueType(ElemKind::Int8QTy, {2, 3, 4}, 0.5, 3);
2781   auto OT = mod_.uniqueType(ElemKind::Int8QTy, {2, 4}, 2.0, -1);
2782 
2783   auto *batch =
2784       mod_.createPlaceholder(ElemKind::Int8QTy, {2, 3, 4}, BT->getScale(),
2785                              BT->getOffset(), "batch", false);
2786 
2787   bindings_.allocate(batch)->getHandle<int8_t>() = {
2788       27, -31, 16,  7,  20, 34, -2, 8,   -10, 83, 29,  -17,
2789       19, 13,  -11, -9, 50, 58, 0,  -20, -72, 43, -25, -1};
2790 
2791   auto BH = bindings_.get(batch)->getHandle<int8_t>();
2792 
2793   auto *R =
2794       F_->createBatchedReduceAdd("batched.reduce.add", OT, batch, /* axis */ 1);
2795   auto *save = F_->createSave("save", R);
2796   auto OH = bindings_.allocate(save->getPlaceholder())->getHandle<int8_t>();
2797 
2798   EE_.compile(CompilationMode::Infer);
2799   EE_.run(bindings_);
2800 
2801   for (dim_t i = 0; i < 2; i++) {
2802     for (dim_t j = 0; j < 4; j++) {
2803       std::array<int32_t, 3> b{
2804           {BH.at({i, 0, j}), BH.at({i, 1, j}), BH.at({i, 2, j})}};
2805       float s = BT->getScale() / OT->getScale();
2806       int32_t o = BT->getOffset();
2807       float result = (b[0] - o) + (b[1] - o) + (b[2] - o);
2808       result = s * result + OT->getOffset();
2809 
2810       EXPECT_NEAR(std::round(result), OH.at({i, j}), 1.0);
2811     }
2812   }
2813 }
2814 
2815 TEST_P(OperatorTest, batchedReduceMean) {
2816   CHECK_IF_ENABLED();
2817 
2818   auto *batch =
2819       mod_.createPlaceholder(ElemKind::FloatTy, {2, 4}, "batch", false);
2820   bindings_.allocate(batch)->getHandle() = {10, 20, 30, 40, 1, 2, 3, 4};
2821 
2822   auto *R = F_->createBatchedReduceMean("reduce.add", batch, /* axis */ 0);
2823 
2824   auto *save = F_->createSave("save", R);
2825   auto *result = bindings_.allocate(save->getPlaceholder());
2826 
2827   EE_.compile(CompilationMode::Infer);
2828   EE_.run(bindings_);
2829 
2830   auto H = result->getHandle();
2831   EXPECT_NEAR(H.at({0}), 5.5, 0.001);
2832   EXPECT_NEAR(H.at({1}), 11.0, 0.001);
2833   EXPECT_NEAR(H.at({2}), 16.5, 0.001);
2834   EXPECT_NEAR(H.at({3}), 22.0, 0.001);
2835 }
2836 
2837 TEST_P(OperatorTest, batchedReduceMeanWithAxis) {
2838   CHECK_IF_ENABLED();
2839 
2840   auto *batch =
2841       mod_.createPlaceholder(ElemKind::FloatTy, {2, 3, 2}, "batch", false);
2842   bindings_.allocate(batch)->getHandle() = {0, 1, 2, 3, 4,  5,
2843                                             6, 7, 8, 9, 10, 11};
2844 
2845   auto *R = F_->createBatchedReduceMean("reduce.add", batch, /* axis */ 1);
2846 
2847   auto *save = F_->createSave("save", R);
2848   auto *result = bindings_.allocate(save->getPlaceholder());
2849 
2850   EE_.compile(CompilationMode::Infer);
2851   EE_.run(bindings_);
2852 
2853   auto H = result->getHandle();
2854   EXPECT_NEAR(H.at({0, 0}), 2.0, 0.001);
2855   EXPECT_NEAR(H.at({0, 1}), 3.0, 0.001);
2856   EXPECT_NEAR(H.at({1, 0}), 8.0, 0.001);
2857   EXPECT_NEAR(H.at({1, 1}), 9.0, 0.001);
2858 }
2859 
2860 TEST_P(OperatorTest, batchedReduceMeanQuantized) {
2861   CHECK_IF_ENABLED();
2862 
2863   auto BT = mod_.uniqueType(ElemKind::Int8QTy, {3, 8}, 0.5, 3);
2864   auto OT = mod_.uniqueType(ElemKind::Int8QTy, {8}, 2.0, -1);
2865 
2866   auto *batch =
2867       mod_.createPlaceholder(ElemKind::Int8QTy, {3, 8}, BT->getScale(),
2868                              BT->getOffset(), "batch", false);
2869 
2870   bindings_.allocate(batch)->getHandle<int8_t>() = {
2871       27, -31, 16,  7,  20, 34, -2, 8,   -10, 83, 29,  -17,
2872       19, 13,  -11, -9, 50, 58, 0,  -20, -72, 43, -25, -1};
2873 
2874   auto BH = bindings_.get(batch)->getHandle<int8_t>();
2875 
2876   auto *R = F_->createBatchedReduceMean("batched.reduce.add", OT, batch,
2877                                         /* axis */ 0);
2878 
2879   auto *save = F_->createSave("save", R);
2880   auto OH = bindings_.allocate(save->getPlaceholder())->getHandle<int8_t>();
2881 
2882   EE_.compile(CompilationMode::Infer);
2883   EE_.run(bindings_);
2884 
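  // Same reference computation as for the quantized reduce-add above, except
  // the column sum is divided by the batch size (3) before requantization.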
2885   for (dim_t i = 0; i < 8; i++) {
2886     std::array<int32_t, 3> b{{BH.at({0, i}), BH.at({1, i}), BH.at({2, i})}};
2887     float s = BT->getScale() / OT->getScale();
2888     int32_t o = BT->getOffset();
2889     float result = ((b[0] - o) + (b[1] - o) + (b[2] - o)) / 3;
2890     result = s * result + OT->getOffset();
2891 
2892     EXPECT_NEAR(std::round(result), OH.at({i}), 1.0);
2893   }
2894 }
2895 
2896 TEST_P(OperatorTest, batchedReduceMeanQuantizedWithAxis) {
2897   CHECK_IF_ENABLED();
2898 
2899   auto BT = mod_.uniqueType(ElemKind::Int8QTy, {2, 3, 4}, 0.5, 3);
2900   auto OT = mod_.uniqueType(ElemKind::Int8QTy, {2, 4}, 2.0, -1);
2901 
2902   auto *batch =
2903       mod_.createPlaceholder(ElemKind::Int8QTy, {2, 3, 4}, BT->getScale(),
2904                              BT->getOffset(), "batch", false);
2905 
2906   bindings_.allocate(batch)->getHandle<int8_t>() = {
2907       27, -31, 16,  7,  20, 34, -2, 8,   -10, 83, 29,  -17,
2908       19, 13,  -11, -9, 50, 58, 0,  -20, -72, 43, -25, -1};
2909 
2910   auto BH = bindings_.get(batch)->getHandle<int8_t>();
2911 
2912   auto *R = F_->createBatchedReduceMean("batched.reduce.add", OT, batch,
2913                                         /* axis */ 1);
2914   auto *save = F_->createSave("save", R);
2915   auto OH = bindings_.allocate(save->getPlaceholder())->getHandle<int8_t>();
2916 
2917   EE_.compile(CompilationMode::Infer);
2918   EE_.run(bindings_);
2919 
2920   for (dim_t i = 0; i < 2; i++) {
2921     for (dim_t j = 0; j < 4; j++) {
2922       std::array<int32_t, 3> b{
2923           {BH.at({i, 0, j}), BH.at({i, 1, j}), BH.at({i, 2, j})}};
2924       float s = BT->getScale() / OT->getScale();
2925       int32_t o = BT->getOffset();
2926       float result = ((b[0] - o) + (b[1] - o) + (b[2] - o)) / 3;
2927       result = s * result + OT->getOffset();
2928 
2929       EXPECT_NEAR(std::round(result), OH.at({i, j}), 1.0);
2930     }
2931   }
2932 }
2933 
2934 /// Verify that batchedReduceMean optimization using AvgPool works correctly.
2935 TEST_P(OperatorTest, batchedReduceMeanUsingAvgPool) {
2936   CHECK_IF_ENABLED();
2937 
2938   std::vector<dim_t> dims = {3, 20, 4, 8};
2939 
2940   auto *batch =
2941       mod_.createPlaceholder(ElemKind::FloatTy, dims, "batch", false, "NHWC");
2942 
2943   auto IH = bindings_.allocate(batch)->getHandle();
2944   IH.randomize(1.0, 100.0, mod_.getPRNG());
2945 
2946   auto *R = F_->createBatchedReduceMean("reduce.mean", batch, {2, 3});
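  // Reducing the two innermost dimensions {2, 3} of the {3, 20, 4, 8} input is
  // equivalent to a global average pool over a 4x8 window; the reference loop
  // below recomputes the same per-(i, j) average directly.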
2947 
2948   auto *save = F_->createSave("save", R);
2949   auto *result = bindings_.allocate(save->getPlaceholder());
2950   EE_.compile(CompilationMode::Infer);
2951 
2952   EE_.run(bindings_);
2953   auto H = result->getHandle();
2954 
2955   std::array<std::array<float, 20>, 3> results{};
2956   for (dim_t i = 0; i < dims[0]; i++) {
2957     for (dim_t j = 0; j < dims[1]; j++) {
2958       for (dim_t k = 0; k < dims[2]; k++) {
2959         for (dim_t l = 0; l < dims[3]; l++) {
2960           results[i][j] += IH.at({i, j, k, l});
2961         }
2962       }
2963       results[i][j] /= (dims[2] * dims[3]);
2964       EXPECT_NEAR(H.at({i, j}), results[i][j], 0.001);
2965     }
2966   }
2967 }
2968 
2969 /// Verify that quantized batchedReduceMean optimization using AvgPool works
2970 /// correctly.
2971 TEST_P(OperatorTest, batchedReduceMeanUsingAvgPoolQuantized) {
2972   CHECK_IF_ENABLED();
2973 
2974   std::vector<dim_t> dims = {2, 3, 3, 4};
2975 
2976   auto BT = mod_.uniqueType(ElemKind::Int8QTy, dims, 1, 0);
2977   auto OT = mod_.uniqueType(ElemKind::Int8QTy, {dims[0], dims[1]}, 1, 0);
2978   auto *batch = mod_.createPlaceholder(ElemKind::Int8QTy, dims, BT->getScale(),
2979                                        BT->getOffset(), "batch", false);
2980 
2981   auto IH = bindings_.allocate(batch)->getHandle<int8_t>();
2982   IH.randomize(1, 100, mod_.getPRNG());
2983 
2984   auto *R = F_->createBatchedReduceMean("reduce.mean", OT, batch, {2, 3});
2985 
2986   auto *save = F_->createSave("save", R);
2987   auto OH = bindings_.allocate(save->getPlaceholder())->getHandle<int8_t>();
2988 
2989   EE_.compile(CompilationMode::Infer);
2990   EE_.run(bindings_);
2991 
2992   std::array<std::array<float, 3>, 2> results{};
2993   float s = BT->getScale() / OT->getScale();
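  // The reference below rescales the integer sum and adds the output offset
  // before dividing by the window size; since OT was created with offset 0,
  // this matches dividing first and then requantizing.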
2994   for (dim_t i = 0; i < dims[0]; i++) {
2995     for (dim_t j = 0; j < dims[1]; j++) {
2996       for (dim_t k = 0; k < dims[2]; k++) {
2997         int32_t o = BT->getOffset();
2998         for (dim_t l = 0; l < dims[3]; l++) {
2999           results[i][j] += IH.at({i, j, k, l}) - o;
3000         }
3001       }
3002       results[i][j] = s * results[i][j] + OT->getOffset();
3003       results[i][j] /= (dims[2] * dims[3]);
3004       EXPECT_NEAR(std::round(results[i][j]), OH.at({i, j}), 1.0);
3005     }
3006   }
3007 }
3008 
3009 /// Test that the BatchedAdd operator works.
3010 TEST_P(OperatorTest, BatchedAdd) {
3011   CHECK_IF_ENABLED();
3012 
3013   auto *batch =
3014       mod_.createPlaceholder(ElemKind::FloatTy, {2, 3, 3}, "batch", false);
3015   auto *added =
3016       mod_.createPlaceholder(ElemKind::FloatTy, {3, 3}, "added", false);
3017 
3018   bindings_.allocate(batch)->getHandle() = {9, 8, 7, 6, 5,  4,  3,  4,  5,
3019                                             6, 7, 8, 9, 10, 11, 12, 13, 14};
3020   bindings_.allocate(added)->getHandle().clear(1.0);
3021 
3022   auto *R = F_->createBatchedAdd("batch.add", batch, added);
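  // BatchedAdd adds the {3, 3} "added" slice to each of the two {3, 3} slices
  // of the batch; with "added" filled with 1.0, every output element is simply
  // the corresponding batch element plus one, as checked below.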
3023   auto *save = F_->createSave("save", R);
3024   auto *result = bindings_.allocate(save->getPlaceholder());
3025 
3026   EE_.compile(CompilationMode::Infer);
3027   EE_.run(bindings_);
3028 
3029   auto BH = bindings_.get(batch)->getHandle();
3030   auto RH = result->getHandle();
3031   for (dim_t i = 0; i < 2; i++) {
3032     for (dim_t j = 0; j < 3; j++) {
3033       for (dim_t k = 0; k < 3; k++) {
3034         EXPECT_NEAR(RH.at({i, j, k}), BH.at({i, j, k}) + 1.0, 0.001);
3035       }
3036     }
3037   }
3038 }
3039 
3040 /// Broadcast Tensor of shape (2,1,1) to (2,4,2) with axis 0.
3041 TEST_P(OperatorTest, broadcastSimple) {
3042   CHECK_IF_ENABLED();
3043 
3044   const dim_t numDims_A = 3;
3045   const dim_t dimY_A = 2;
3046   const dim_t dimZ_A = 4;
3047   const dim_t dimW_A = 2;
3048   const dim_t dims_A[numDims_A] = {dimY_A, dimZ_A, dimW_A};
3049 
3050   const dim_t numDims_B = 3;
3051   const dim_t dimY_B = 2;
3052   const dim_t dimZ_B = 1;
3053   const dim_t dimW_B = 1;
3054   const dim_t dims_B[numDims_B] = {dimY_B, dimZ_B, dimW_B};
3055 
3056   auto *B = mod_.createPlaceholder(ElemKind::FloatTy, dims_B, "B", false);
3057   auto *QB =
3058       mod_.createPlaceholder(ElemKind::Int8QTy, dims_B, 1.1, -2, "QB", false);
3059   auto H_B = bindings_.allocate(B)->getHandle();
3060   auto H_QB = bindings_.allocate(QB)->getHandle<int8_t>();
3061   H_B = {20, 10};
3062   H_QB = {35, -18};
3063 
3064   const unsigned axis = 0;
3065 
3066   auto *R = F_->createBroadcast("broadcasted", B, dims_A, axis);
3067   auto *QR = F_->createBroadcast("broadcastedQ", QB, dims_A, axis);
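  // Broadcasting {2, 1, 1} to {2, 4, 2} with axis 0 keeps the leading
  // dimension and replicates each of the two scalars across a 4x2 block, so
  // every element of output slice j equals B[j][0][0].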
3068 
3069   auto *save = F_->createSave("save", R);
3070   auto *broadcasted = bindings_.allocate(save->getPlaceholder());
3071 
3072   auto *saveQ = F_->createSave("saveQ", QR);
3073   auto *broadcastedQ = bindings_.allocate(saveQ->getPlaceholder());
3074 
3075   EE_.compile(CompilationMode::Infer);
3076   EE_.run(bindings_);
3077 
3078   auto broadcastedBHandle = broadcasted->getHandle();
3079   auto broadcastedQBHandle = broadcastedQ->getHandle<int8_t>();
3080   // Verify broadcasted B has same shape.
3081   EXPECT_EQ(broadcastedBHandle.dims().size(), numDims_A);
3082   EXPECT_EQ(broadcastedQBHandle.dims().size(), numDims_A);
3083   for (size_t i = 0; i < broadcastedBHandle.dims().size(); i++) {
3084     EXPECT_EQ(broadcastedBHandle.dims()[i], dims_A[i]);
3085     EXPECT_EQ(broadcastedQBHandle.dims()[i], dims_A[i]);
3086   }
3087 
3088   // Look at the two values in B and verify that they were correctly
3089   // replicated along the two dimensions that were broadcast.
3090   const dim_t k_B = 0;
3091   const dim_t l_B = 0;
3092   for (dim_t j_B = 0; j_B < dimY_B; ++j_B) {
3093     const float origVal = H_B.at({j_B, k_B, l_B});
3094     const int8_t origValQ = H_QB.at({j_B, k_B, l_B});
3095     const dim_t j_A = j_B; // This dim was not broadcasted (dims were equal).
3096     for (dim_t k_A = 0; k_A < dimZ_A; k_A++) {
3097       for (dim_t l_A = 0; l_A < dimW_A; l_A++) {
3098         EXPECT_EQ(broadcastedBHandle.at({j_A, k_A, l_A}), origVal);
3099         EXPECT_EQ(broadcastedQBHandle.at({j_A, k_A, l_A}), origValQ);
3100       }
3101     }
3102   }
3103 }
3104 
3105 /// Broadcast a Tensor of shape (2,1) to (3,2,4,2) with axis 1.
3106 TEST_P(OperatorTest, broadcast) {
3107   CHECK_IF_ENABLED();
3108 
3109   const dim_t numDims_A = 4;
3110   const dim_t dimX_A = 3;
3111   const dim_t dimY_A = 2;
3112   const dim_t dimZ_A = 4;
3113   const dim_t dimW_A = 2;
3114   const dim_t dims_A[numDims_A] = {dimX_A, dimY_A, dimZ_A, dimW_A};
3115 
3116   const dim_t numDims_B = 2;
3117   const dim_t dimY_B = 2;
3118   const dim_t dimZ_B = 1;
3119   const dim_t dims_B[numDims_B] = {dimY_B, dimZ_B};
3120 
3121   auto *B = mod_.createPlaceholder(ElemKind::FloatTy, dims_B, "B", false);
3122   auto *QB =
3123       mod_.createPlaceholder(ElemKind::Int8QTy, dims_B, 0.8, 3, "QB", false);
3124 
3125   auto H_B = bindings_.allocate(B)->getHandle();
3126   auto H_QB = bindings_.allocate(QB)->getHandle<int8_t>();
3127   H_B = {20, 10};
3128   H_QB = {-8, 41};
3129 
3130   const unsigned axis = 1;
3131 
3132   auto *R = F_->createBroadcast("broadcasted", B, dims_A, axis);
3133   auto *QR = F_->createBroadcast("broadcastedQ", QB, dims_A, axis);
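  // With axis 1, the {2, 1} input is aligned to the {dimY_A, dimZ_A}
  // dimensions of the {3, 2, 4, 2} output and tiled across the remaining
  // dimensions, so out[i][j][k][l] == B[j][0] for all i, k, l.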
3134 
3135   auto *save = F_->createSave("save", R);
3136   auto *broadcasted = bindings_.allocate(save->getPlaceholder());
3137 
3138   auto *saveQ = F_->createSave("saveQ", QR);
3139   auto *broadcastedQ = bindings_.allocate(saveQ->getPlaceholder());
3140 
3141   EE_.compile(CompilationMode::Infer);
3142   EE_.run(bindings_);
3143 
3144   auto broadcastedBHandle = broadcasted->getHandle();
3145   auto broadcastedQBHandle = broadcastedQ->getHandle<int8_t>();
3146   // Verify broadcasted B has same shape.
3147   EXPECT_EQ(broadcastedBHandle.dims().size(), numDims_A);
3148   EXPECT_EQ(broadcastedQBHandle.dims().size(), numDims_A);
3149   for (size_t i = 0; i < broadcastedBHandle.dims().size(); i++) {
3150     EXPECT_EQ(broadcastedBHandle.dims()[i], dims_A[i]);
3151     EXPECT_EQ(broadcastedQBHandle.dims()[i], dims_A[i]);
3152   }
3153 
3154   // Look at the two values in B and verify that they were correctly
3155   // replicated along the three dimensions that were broadcast.
3156   const dim_t k_B = 0;
3157   for (dim_t j_B = 0; j_B < dimY_B; ++j_B) {
3158     const float origVal = H_B.at({j_B, k_B});
3159     const int8_t origValQ = H_QB.at({j_B, k_B});
3160     const dim_t j_A = j_B; // This dim was not broadcasted (dims were equal).
3161     for (dim_t i_A = 0; i_A < dimX_A; i_A++) {
3162       for (dim_t k_A = 0; k_A < dimZ_A; k_A++) {
3163         for (dim_t l_A = 0; l_A < dimW_A; l_A++) {
3164           EXPECT_EQ(broadcastedBHandle.at({i_A, j_A, k_A, l_A}), origVal);
3165           EXPECT_EQ(broadcastedQBHandle.at({i_A, j_A, k_A, l_A}), origValQ);
3166         }
3167       }
3168     }
3169   }
3170 }
3171 
3172 /// Perform a simple weighted sum.
3173 TEST_P(OperatorTest, weightedSum) {
3174   CHECK_IF_ENABLED();
3175 
3176   // Create the data.
3177   auto *A = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "A", false);
3178   bindings_.allocate(A)->getHandle() = {1.0, 2.0, 3.0, 4.0};
3179 
3180   auto *B = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "B", false);
3181   bindings_.allocate(B)->getHandle() = {5.0, 6.0, 7.0, 8.0};
3182 
3183   // Create the weights.
3184   auto *AW = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "AW", false);
3185   bindings_.allocate(AW)->getHandle() = {0.1f};
3186 
3187   auto *BW = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "BW", false);
3188   bindings_.allocate(BW)->getHandle() = {10.0f};
3189 
3190   // Create the weighted sum with the data and weights, and save it.
3191   auto *WS = F_->createWeightedSum("ws", {A, B}, {AW, BW});
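  // The weighted sum is 0.1 * A + 10.0 * B elementwise, e.g. the first output
  // element is 0.1 * 1.0 + 10.0 * 5.0 = 50.1.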
3192   auto *save = F_->createSave("save", WS);
3193   auto *saveTensor = bindings_.allocate(save->getPlaceholder());
3194 
3195   EE_.compile(CompilationMode::Infer);
3196   EE_.run(bindings_);
3197 
3198   // Verify the weighted sum was correctly calculated.
3199   auto resultH = saveTensor->getHandle();
3200   EXPECT_NEAR(resultH.at({0, 0}), 50.1, 1E-5);
3201   EXPECT_NEAR(resultH.at({0, 1}), 60.2, 1E-5);
3202   EXPECT_NEAR(resultH.at({1, 0}), 70.3, 1E-5);
3203   EXPECT_NEAR(resultH.at({1, 1}), 80.4, 1E-5);
3204 }
3205 
3206 /// Helper to test ReluSimple using \p DTy.
3207 template <typename DataType>
3208 static void testReluSimple(glow::PlaceholderBindings &bindings,
3209                            glow::Module &mod, glow::Function *F,
3210                            glow::ExecutionEngine &EE, ElemKind DTy) {
3211   auto *in = mod.createPlaceholder(DTy, {7}, "in", false);
3212   auto *relu = F->createRELU("relu", in);
3213   auto *save = F->createSave("relu", relu);
3214   auto *result = bindings.allocate(save->getPlaceholder());
3215 
3216   bindings.allocate(in)->getHandle<DataType>() = {0, -1, -2, -3, 4, 5, 6};
3217 
3218   EE.compile(CompilationMode::Infer);
3219   EE.run(bindings);
3220 
3221   auto resultH = result->getHandle<DataType>();
3222 
3223   for (size_t i = 0; i < 7; i++) {
3224     if (i < 4) {
3225       EXPECT_EQ(resultH.raw(i), static_cast<DataType>(0));
3226     } else {
3227       EXPECT_EQ(resultH.raw(i), static_cast<DataType>(i));
3228     }
3229   }
3230 }
3231 
3232 /// Verify that the RELU operator works correctly for Float.
3233 TEST_P(OperatorTest, ReluSimple_Float) {
3234   CHECK_IF_ENABLED();
3235 
3236   testReluSimple<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
3237 }
3238 
3239 /// Verify that the RELU operator works correctly for Float16.
3240 TEST_P(OperatorTest, ReluSimple_Float16) {
3241   CHECK_IF_ENABLED();
3242   testReluSimple<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
3243 }
3244 
3245 /// Verify that the RELU operator works correctly for BFloat16.
3246 TEST_P(OperatorTest, ReluSimple_BFloat16) {
3247   CHECK_IF_ENABLED();
3248   testReluSimple<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty);
3249 }
3250 
3251 /// Helper to test PReluSimple using \p DTy.
3252 template <typename DataType>
3253 static void testPReluSimple(glow::PlaceholderBindings &bindings,
3254                             glow::Module &mod, glow::Function *F,
3255                             glow::ExecutionEngine &EE, ElemKind DTy,
3256                             double allowedError) {
3257   auto *in = mod.createPlaceholder(DTy, {7}, "in", false);
3258   auto *slope = mod.createPlaceholder(DTy, {7}, "slope", false);
3259   auto *prelu = F->createPRELU("prelu", in, slope);
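  // PReLU passes positive inputs through unchanged and scales negative inputs
  // by the per-element slope: out = slope * min(0, x) + max(0, x), which is
  // exactly the reference computed in the loop below.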
3260   auto *save = F->createSave("prelu", prelu);
3261   auto *result = bindings.allocate(save->getPlaceholder());
3262 
3263   bindings.allocate(in)->getHandle<DataType>() = {0, -1, -2, -3, 4, 5, 6};
3264   bindings.allocate(slope)->getHandle<DataType>().randomize(0.1, 3.0,
3265                                                             mod.getPRNG());
3266 
3267   EE.compile(CompilationMode::Infer);
3268   EE.run(bindings);
3269 
3270   auto resultH = result->getHandle<DataType>();
3271   auto inH = bindings.get(in)->getHandle<DataType>();
3272   auto slopeH = bindings.get(slope)->getHandle<DataType>();
3273 
3274   for (size_t i = 0; i < 7; i++) {
3275     DataType expectedResult =
3276         slopeH.raw(i) * std::min<DataType>(0, inH.raw(i)) +
3277         std::max<DataType>(0, inH.raw(i));
3278     EXPECT_NEAR(resultH.at(i), expectedResult, allowedError);
3279   }
3280 }
3281 
3282 /// Verify that the PRELU operator works correctly for Float.
3283 TEST_P(OperatorTest, PReluSimple_Float) {
3284   CHECK_IF_ENABLED();
3285   testPReluSimple<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, 1E-32);
3286 }
3287 
3288 /// Verify that the PRELU operator works correctly for Float16.
3289 TEST_P(OperatorTest, PReluSimple_Float16) {
3290   CHECK_IF_ENABLED();
3291   testPReluSimple<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
3292                              1E-16);
3293 }
3294 
3295 /// Verify that the PRELU operator works correctly for BFloat16.
3296 TEST_P(OperatorTest, PReluSimple_BFloat16) {
3297   CHECK_IF_ENABLED();
3298   testPReluSimple<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
3299                               1E-16);
3300 }
3301 
3302 TEST_P(OperatorTest, TopK) {
3303   CHECK_IF_ENABLED();
3304 
3305   auto *inp =
3306       mod_.createPlaceholder(ElemKind::FloatTy, {3, 1, 5}, "input", false);
3307   auto *values =
3308       mod_.createPlaceholder(ElemKind::FloatTy, {3, 1, 3}, "values", false);
3309   auto *indices =
3310       mod_.createPlaceholder(ElemKind::Int64ITy, {3, 1, 3}, "indices", false);
3311 
3312   bindings_.allocate(inp)->getHandle() = {
3313       28, 4, 411, 19, 42, 0.4f, 0.4f, 0.4f, -0.4f, 0.45f, 7, 5, 9, 8, 100,
3314   };
3315   bindings_.allocate(values);
3316   bindings_.allocate(indices);
3317 
3318   auto *R = F_->createTopK("TopK", inp, 3);
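  // TopK with k = 3 returns, for each {row, 0} slice along the last dimension,
  // the three largest values in descending order together with their original
  // indices, e.g. row 0 {28, 4, 411, 19, 42} -> values {411, 42, 28} at
  // indices {2, 4, 0}.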
3319 
3320   F_->createSave("save.values", {R, 0}, values);
3321   F_->createSave("save.indices", {R, 1}, indices);
3322 
3323   EE_.compile(CompilationMode::Infer);
3324 
3325   EE_.run(bindings_);
3326 
3327   auto V = bindings_.get(values)->getHandle();
3328   auto I = bindings_.get(indices)->getHandle<int64_t>();
3329 
3330   EXPECT_FLOAT_EQ(V.at({0, 0, 0}), 411);
3331   EXPECT_EQ(I.at({0, 0, 0}), 2);
3332   EXPECT_FLOAT_EQ(V.at({0, 0, 1}), 42);
3333   EXPECT_EQ(I.at({0, 0, 1}), 4);
3334   EXPECT_FLOAT_EQ(V.at({0, 0, 2}), 28);
3335   EXPECT_EQ(I.at({0, 0, 2}), 0);
3336 
3337   EXPECT_FLOAT_EQ(V.at({1, 0, 0}), 0.45);
3338   EXPECT_EQ(I.at({1, 0, 0}), 4);
3339   EXPECT_FLOAT_EQ(V.at({1, 0, 1}), 0.4);
3340   EXPECT_EQ(I.at({1, 0, 1}), 0);
3341   EXPECT_FLOAT_EQ(V.at({1, 0, 2}), 0.4);
3342   EXPECT_EQ(I.at({1, 0, 2}), 1);
3343 
3344   EXPECT_FLOAT_EQ(V.at({2, 0, 0}), 100);
3345   EXPECT_EQ(I.at({2, 0, 0}), 4);
3346   EXPECT_FLOAT_EQ(V.at({2, 0, 1}), 9);
3347   EXPECT_EQ(I.at({2, 0, 1}), 2);
3348   EXPECT_FLOAT_EQ(V.at({2, 0, 2}), 8);
3349   EXPECT_EQ(I.at({2, 0, 2}), 3);
3350 }
3351 
3352 template <typename DataType>
3353 static void testArgMaxKeepDim(glow::PlaceholderBindings &bindings,
3354                               glow::Module &mod, glow::Function *F,
3355                               glow::ExecutionEngine &EE, ElemKind DTy) {
3356   auto *input = createPlaceholderConditionallyQuantized(mod, DTy, {2, 3, 2, 2},
3357                                                         "input", false, "NHWC");
3358   auto *argmax = mod.createPlaceholder(ElemKind::Int64ITy, {1, 3, 2, 2},
3359                                        "argmax", false, "NHWC");
3360 
3361   bindings.allocate(input)->getHandle<DataType>() = {
3362       11, 24, 33, 41, 15, 26, 37, 48, 12, 28, 31, 42,
3363       13, 24, 35, 46, 12, 28, 39, 40, 11, 22, 33, 47};
3364   bindings.allocate(argmax);
3365 
3366   auto *AM = F->createArgMax("argmax", input, 0, true);
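  // ArgMax over axis 0 with keepDims=true collapses that dimension to 1,
  // giving shape {1, 3, 2, 2}; each of the 12 outputs is 0 or 1 depending on
  // which of the two {3, 2, 2} slices holds the larger value at that position.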
3367   F->createSave("save.argmax", AM, argmax);
3368 
3369   EE.compile(CompilationMode::Infer);
3370   EE.run(bindings);
3371 
3372   auto I = bindings.get(argmax)->getHandle<int64_t>();
3373   EXPECT_EQ(I.raw(0), 1);
3374   EXPECT_EQ(I.raw(1), 0);
3375   EXPECT_EQ(I.raw(2), 1);
3376   EXPECT_EQ(I.raw(3), 1);
3377   EXPECT_EQ(I.raw(4), 0);
3378   EXPECT_EQ(I.raw(5), 1);
3379   EXPECT_EQ(I.raw(6), 1);
3380   EXPECT_EQ(I.raw(7), 0);
3381   EXPECT_EQ(I.raw(8), 0);
3382   EXPECT_EQ(I.raw(9), 0);
3383   EXPECT_EQ(I.raw(10), 1);
3384   EXPECT_EQ(I.raw(11), 1);
3385 }
3386 
3387 TEST_P(OperatorTest, FloatArgMaxKeepDim) {
3388   CHECK_IF_ENABLED();
3389   testArgMaxKeepDim<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
3390 }
3391 
3392 TEST_P(OperatorTest, QuantizedArgMaxKeepDim) {
3393   CHECK_IF_ENABLED();
3394   testArgMaxKeepDim<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
3395 }
3396 
3397 template <typename DataType>
3398 static void testArgMaxNoKeepDim(glow::PlaceholderBindings &bindings,
3399                                 glow::Module &mod, glow::Function *F,
3400                                 glow::ExecutionEngine &EE, ElemKind DTy) {
3401   auto *input = createPlaceholderConditionallyQuantized(mod, DTy, {2, 3, 2, 2},
3402                                                         "input", false, "NHWC");
3403   auto *argmax =
3404       mod.createPlaceholder(ElemKind::Int64ITy, {2, 2, 2}, "argmax", false);
3405 
3406   bindings.allocate(input)->getHandle<DataType>() = {
3407       11, 24, 33, 41, 15, 26, 37, 48, 12, 28, 31, 42,
3408       13, 24, 35, 46, 12, 28, 39, 40, 11, 22, 33, 47};
3409   bindings.allocate(argmax);
3410 
3411   auto *AM = F->createArgMax("argmax", input, 1, false);
3412   F->createSave("save.argmax", AM, argmax);
3413 
3414   EE.compile(CompilationMode::Infer);
3415   EE.run(bindings);
3416 
3417   auto I = bindings.get(argmax)->getHandle<int64_t>();
3418   EXPECT_EQ(I.raw(0), 1);
3419   EXPECT_EQ(I.raw(1), 2);
3420   EXPECT_EQ(I.raw(2), 1);
3421   EXPECT_EQ(I.raw(3), 1);
3422   EXPECT_EQ(I.raw(4), 0);
3423   EXPECT_EQ(I.raw(5), 1);
3424   EXPECT_EQ(I.raw(6), 1);
3425   EXPECT_EQ(I.raw(7), 2);
3426 }
3427 
3428 TEST_P(OperatorTest, FloatArgMaxNoKeepDim) {
3429   CHECK_IF_ENABLED();
3430   testArgMaxNoKeepDim<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
3431 }
3432 
3433 TEST_P(OperatorTest, QuantizedArgMaxNoKeepDim) {
3434   CHECK_IF_ENABLED();
3435   testArgMaxNoKeepDim<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
3436 }
3437 
3438 TEST_P(OperatorTest, FloatArgMaxNoKeepDimWithAxis1) {
3439   CHECK_IF_ENABLED();
3440 
3441   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "input",
3442                                        false, "NHWC");
3443   auto *argmax =
3444       mod_.createPlaceholder(ElemKind::Int64ITy, {1, 3, 4}, "argmax", false);
3445 
3446   bindings_.allocate(input)->getHandle<float>() = {
3447       -2.0031254,  1.6150867,  -0.7161922,  -0.25389647, -2.3863597,
3448       1.3052065,   -1.2064048, -0.12670185, 1.4289513,   0.38050872,
3449       -0.15112245, 1.360533,   -1.9638863,  -0.7602536,  0.68145376,
3450       1.1685915,   0.35476854, 1.0272173,   -1.554366,   -1.6835353,
3451       -1.4499142,  0.9042695,  1.0751117,   -1.0798755};
3452 
3453   bindings_.allocate(argmax);
3454 
3455   auto *AM =
3456       F_->createArgMax("argmax", input, /* axis */ 1, /* keepDims */ false);
3457   F_->createSave("save.argmax", AM, argmax);
3458 
3459   EE_.compile(CompilationMode::Infer);
3460   EE_.run(bindings_);
3461 
3462   auto I = bindings_.get(argmax)->getHandle<int64_t>();
3463   EXPECT_EQ(I.raw(0), 1);
3464   EXPECT_EQ(I.raw(1), 0);
3465   EXPECT_EQ(I.raw(2), 1);
3466   EXPECT_EQ(I.raw(3), 1);
3467   EXPECT_EQ(I.raw(4), 1);
3468   EXPECT_EQ(I.raw(5), 0);
3469   EXPECT_EQ(I.raw(6), 0);
3470   EXPECT_EQ(I.raw(7), 0);
3471   EXPECT_EQ(I.raw(8), 0);
3472   EXPECT_EQ(I.raw(9), 1);
3473   EXPECT_EQ(I.raw(10), 1);
3474   EXPECT_EQ(I.raw(11), 0);
3475 }
3476 
3477 TEST_P(OperatorTest, FloatArgMaxNoKeepDimWithAxis2) {
3478   CHECK_IF_ENABLED();
3479 
3480   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "input",
3481                                        false, "NHWC");
3482   auto *argmax =
3483       mod_.createPlaceholder(ElemKind::Int64ITy, {1, 2, 4}, "argmax", false);
3484 
3485   bindings_.allocate(input)->getHandle<float>() = {
3486       -0.11289205, -0.13215652, -1.184799,  0.2295995,   0.03064479,
3487       -0.28138036, -0.51807016, 0.89983666, -0.46122625, -0.70558083,
3488       0.43882176,  -0.6988644,  2.0838234,  -0.22806482, -0.6829437,
3489       0.70269305,  -0.8199907,  0.25597557, 0.3598691,   -0.9919779,
3490       2.069314,    -1.8825238,  1.2604765,  -0.78306365};
3491 
3492   bindings_.allocate(argmax);
3493 
3494   auto *AM =
3495       F_->createArgMax("argmax", input, /* axis */ 2, /* keepDims */ false);
3496   F_->createSave("save.argmax", AM, argmax);
3497 
3498   EE_.compile(CompilationMode::Infer);
3499   EE_.run(bindings_);
3500 
3501   auto I = bindings_.get(argmax)->getHandle<int64_t>();
3502   EXPECT_EQ(I.raw(0), 1);
3503   EXPECT_EQ(I.raw(1), 0);
3504   EXPECT_EQ(I.raw(2), 2);
3505   EXPECT_EQ(I.raw(3), 1);
3506   EXPECT_EQ(I.raw(4), 0);
3507   EXPECT_EQ(I.raw(5), 1);
3508   EXPECT_EQ(I.raw(6), 2);
3509   EXPECT_EQ(I.raw(7), 0);
3510 }
3511 
3512 template <typename DataType>
3513 static void testArgMinKeepDim(glow::PlaceholderBindings &bindings,
3514                               glow::Module &mod, glow::Function *F,
3515                               glow::ExecutionEngine &EE, ElemKind DTy) {
3516   auto *input = createPlaceholderConditionallyQuantized(mod, DTy, {2, 3, 2, 2},
3517                                                         "input", false, "NHWC");
3518   auto *argmin = mod.createPlaceholder(ElemKind::Int64ITy, {1, 3, 2, 2},
3519                                        "argmin", false, "NHWC");
3520 
3521   bindings.allocate(input)->getHandle<DataType>() = {
3522       11, 24, 33, 41, 15, 26, 37, 48, 12, 28, 31, 42,
3523       13, 24, 35, 46, 12, 28, 39, 40, 11, 22, 33, 47};
3524   bindings.allocate(argmin);
3525 
3526   auto *AM = F->createArgMin("argmin", input, 0, true);
3527   F->createSave("save.argmin", AM, argmin);
3528 
3529   EE.compile(CompilationMode::Infer);
3530   EE.run(bindings);
3531 
3532   auto I = bindings.get(argmin)->getHandle<int64_t>();
3533   EXPECT_EQ(I.raw(0), 0);
3534   EXPECT_EQ(I.raw(1), 0);
3535   EXPECT_EQ(I.raw(2), 0);
3536   EXPECT_EQ(I.raw(3), 0);
3537   EXPECT_EQ(I.raw(4), 1);
3538   EXPECT_EQ(I.raw(5), 0);
3539   EXPECT_EQ(I.raw(6), 0);
3540   EXPECT_EQ(I.raw(7), 1);
3541   EXPECT_EQ(I.raw(8), 1);
3542   EXPECT_EQ(I.raw(9), 1);
3543   EXPECT_EQ(I.raw(10), 0);
3544   EXPECT_EQ(I.raw(11), 0);
3545 }
3546 
3547 TEST_P(OperatorTest, FloatArgMinKeepDim) {
3548   CHECK_IF_ENABLED();
3549   testArgMinKeepDim<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
3550 }
3551 
3552 TEST_P(OperatorTest, QuantizedArgMinKeepDim) {
3553   CHECK_IF_ENABLED();
3554   testArgMinKeepDim<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
3555 }
3556 
3557 template <typename DataType>
3558 static void testArgMinNoKeepDim(glow::PlaceholderBindings &bindings,
3559                                 glow::Module &mod, glow::Function *F,
3560                                 glow::ExecutionEngine &EE, ElemKind DTy) {
3561   auto *input = createPlaceholderConditionallyQuantized(mod, DTy, {2, 3, 2, 2},
3562                                                         "input", false, "NHWC");
3563   auto *argmin =
3564       mod.createPlaceholder(ElemKind::Int64ITy, {2, 2, 2}, "argmin", false);
3565 
3566   bindings.allocate(input)->getHandle<DataType>() = {
3567       11, 24, 33, 41, 15, 26, 37, 48, 12, 28, 31, 42,
3568       13, 24, 35, 46, 12, 28, 39, 40, 11, 22, 33, 47};
3569   bindings.allocate(argmin);
3570 
3571   auto *AM = F->createArgMin("argmin", input, 1, false);
3572   F->createSave("save.argmin", AM, argmin);
3573 
3574   EE.compile(CompilationMode::Infer);
3575   EE.run(bindings);
3576 
3577   auto I = bindings.get(argmin)->getHandle<int64_t>();
3578   EXPECT_EQ(I.raw(0), 0);
3579   EXPECT_EQ(I.raw(1), 0);
3580   EXPECT_EQ(I.raw(2), 2);
3581   EXPECT_EQ(I.raw(3), 0);
3582   EXPECT_EQ(I.raw(4), 2);
3583   EXPECT_EQ(I.raw(5), 2);
3584   EXPECT_EQ(I.raw(6), 2);
3585   EXPECT_EQ(I.raw(7), 1);
3586 }
3587 
3588 TEST_P(OperatorTest, FloatArgMinNoKeepDim) {
3589   CHECK_IF_ENABLED();
3590   testArgMinNoKeepDim<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
3591 }
3592 
3593 TEST_P(OperatorTest, QuantizedArgMinNoKeepDim) {
3594   CHECK_IF_ENABLED();
3595   testArgMinNoKeepDim<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
3596 }
3597 
3598 TEST_P(OperatorTest, FloatArgMinNoKeepDimWithAxis1) {
3599   CHECK_IF_ENABLED();
3600 
3601   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "input",
3602                                        false, "NHWC");
3603   auto *argmin =
3604       mod_.createPlaceholder(ElemKind::Int64ITy, {1, 3, 4}, "argmin", false);
3605 
3606   bindings_.allocate(input)->getHandle<float>() = {
3607       -2.0031254,  1.6150867,  -0.7161922,  -0.25389647, -2.3863597,
3608       1.3052065,   -1.2064048, -0.12670185, 1.4289513,   0.38050872,
3609       -0.15112245, 1.360533,   -1.9638863,  -0.7602536,  0.68145376,
3610       1.1685915,   0.35476854, 1.0272173,   -1.554366,   -1.6835353,
3611       -1.4499142,  0.9042695,  1.0751117,   -1.0798755};
3612 
3613   bindings_.allocate(argmin);
3614 
3615   auto *AM =
3616       F_->createArgMin("argmin", input, /* axis */ 1, /* keepDims */ false);
3617   F_->createSave("save.argmin", AM, argmin);
3618 
3619   EE_.compile(CompilationMode::Infer);
3620   EE_.run(bindings_);
3621 
3622   auto I = bindings_.get(argmin)->getHandle<int64_t>();
3623   EXPECT_EQ(I.raw(0), 0);
3624   EXPECT_EQ(I.raw(1), 1);
3625   EXPECT_EQ(I.raw(2), 0);
3626   EXPECT_EQ(I.raw(3), 0);
3627   EXPECT_EQ(I.raw(4), 0);
3628   EXPECT_EQ(I.raw(5), 1);
3629   EXPECT_EQ(I.raw(6), 1);
3630   EXPECT_EQ(I.raw(7), 1);
3631   EXPECT_EQ(I.raw(8), 1);
3632   EXPECT_EQ(I.raw(9), 0);
3633   EXPECT_EQ(I.raw(10), 0);
3634   EXPECT_EQ(I.raw(11), 1);
3635 }
3636 
3637 TEST_P(OperatorTest, FloatArgMinNoKeepDimWithAxis2) {
3638   CHECK_IF_ENABLED();
3639 
3640   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "input",
3641                                        false, "NHWC");
3642   auto *argmin =
3643       mod_.createPlaceholder(ElemKind::Int64ITy, {1, 2, 4}, "argmin", false);
3644 
3645   bindings_.allocate(input)->getHandle<float>() = {
3646       -0.11289205, -0.13215652, -1.184799,  0.2295995,   0.03064479,
3647       -0.28138036, -0.51807016, 0.89983666, -0.46122625, -0.70558083,
3648       0.43882176,  -0.6988644,  2.0838234,  -0.22806482, -0.6829437,
3649       0.70269305,  -0.8199907,  0.25597557, 0.3598691,   -0.9919779,
3650       2.069314,    -1.8825238,  1.2604765,  -0.78306365};
3651 
3652   bindings_.allocate(argmin);
3653 
3654   auto *AM =
3655       F_->createArgMin("argmin", input, /* axis */ 2, /* keepDims */ false);
3656   F_->createSave("save.argmin", AM, argmin);
3657 
3658   EE_.compile(CompilationMode::Infer);
3659   EE_.run(bindings_);
3660 
3661   auto I = bindings_.get(argmin)->getHandle<int64_t>();
3662   EXPECT_EQ(I.raw(0), 2);
3663   EXPECT_EQ(I.raw(1), 2);
3664   EXPECT_EQ(I.raw(2), 0);
3665   EXPECT_EQ(I.raw(3), 2);
3666   EXPECT_EQ(I.raw(4), 1);
3667   EXPECT_EQ(I.raw(5), 2);
3668   EXPECT_EQ(I.raw(6), 0);
3669   EXPECT_EQ(I.raw(7), 1);
3670 }
3671 
3672 // Check that concatenating Nodes with multiple outputs works correctly.
3673 TEST_P(OperatorTest, ConcatTopK) {
3674   CHECK_IF_ENABLED();
3675 
3676   auto *inp1 =
3677       mod_.createPlaceholder(ElemKind::FloatTy, {2, 1, 3}, "input", false);
3678   auto *inp2 =
3679       mod_.createPlaceholder(ElemKind::FloatTy, {2, 1, 3}, "input", false);
3680   auto *indices =
3681       mod_.createPlaceholder(ElemKind::Int64ITy, {4, 1, 2}, "indices", false);
3682 
3683   bindings_.allocate(inp1)->getHandle() = {1, 2, 3, 17.4f, -0.1f, -10.1f};
3684   bindings_.allocate(inp2)->getHandle() = {1, 2, -3, -17.4f, -0.1f, -10.1f};
3685 
3686   auto *R1 = F_->createTopK("TopK1", inp1, 2);
3687   auto *R2 = F_->createTopK("TopK2", inp2, 2);
3688 
3689   // Concat the values and indices separately, both on the 0th dimension,
3690   // matching the shapes of the values and indices variables above.
3691   auto *CV =
3692       F_->createConcat("Concat.Values", {R1->getValues(), R2->getValues()}, 0);
3693   auto *CI = F_->createConcat("Concat.Indices",
3694                               {R1->getIndices(), R2->getIndices()}, 0);
3695 
3696   auto *saveValues = F_->createSave("Save.Values", CV);
3697   auto *saveValuesTensor = bindings_.allocate(saveValues->getPlaceholder());
3698 
3699   auto *saveIndices = F_->createSave("Save.Indices", CI, indices);
3700   auto *saveIndicesTensor = bindings_.allocate(saveIndices->getPlaceholder());
3701 
3702   EE_.compile(CompilationMode::Infer);
3703 
3704   EE_.run(bindings_);
3705 
3706   auto V = saveValuesTensor->getHandle();
3707   auto I = saveIndicesTensor->getHandle<int64_t>();
3708 
3709   EXPECT_FLOAT_EQ(V.at({0, 0, 0}), 3);
3710   EXPECT_FLOAT_EQ(I.at({0, 0, 0}), 2);
3711   EXPECT_FLOAT_EQ(V.at({0, 0, 1}), 2);
3712   EXPECT_FLOAT_EQ(I.at({0, 0, 1}), 1);
3713 
3714   EXPECT_FLOAT_EQ(V.at({1, 0, 0}), 17.4f);
3715   EXPECT_FLOAT_EQ(I.at({1, 0, 0}), 0);
3716   EXPECT_FLOAT_EQ(V.at({1, 0, 1}), -0.1f);
3717   EXPECT_FLOAT_EQ(I.at({1, 0, 1}), 1);
3718 
3719   EXPECT_FLOAT_EQ(V.at({2, 0, 0}), 2);
3720   EXPECT_FLOAT_EQ(I.at({2, 0, 0}), 1);
3721   EXPECT_FLOAT_EQ(V.at({2, 0, 1}), 1);
3722   EXPECT_FLOAT_EQ(I.at({2, 0, 1}), 0);
3723 
3724   EXPECT_FLOAT_EQ(V.at({3, 0, 0}), -0.1f);
3725   EXPECT_FLOAT_EQ(I.at({3, 0, 0}), 1);
3726   EXPECT_FLOAT_EQ(V.at({3, 0, 1}), -10.1f);
3727   EXPECT_FLOAT_EQ(I.at({3, 0, 1}), 2);
3728 }
3729 
3730 // Check that matrix multiplication works well on some predefined values.
3731 TEST_P(OperatorTest, matmul2) {
3732   CHECK_IF_ENABLED();
3733 
3734   auto *inp0 =
3735       mod_.createPlaceholder(ElemKind::FloatTy, {1, 2}, "input0", false);
3736   auto *inp1 =
3737       mod_.createPlaceholder(ElemKind::FloatTy, {1, 2}, "input1", false);
3738   auto *inp2 =
3739       mod_.createPlaceholder(ElemKind::FloatTy, {1, 2}, "input2", false);
3740   auto *rot = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "rot", false);
3741 
3742   float deg = 45.0 / 180.0 * 3.1415926;
3743   // Use the rotation matrix to manipulate some values.
3744   // https://en.wikipedia.org/wiki/Rotation_matrix
3745   bindings_.allocate(rot)->getHandle() = {
3746       cosf(deg),
3747       -sinf(deg),
3748       sinf(deg),
3749       cosf(deg),
3750   };
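  // Each 1x2 input row is multiplied by the 2x2 rotation matrix, rotating it
  // by 45 degrees; e.g. (1, 4) maps to (cos45 + 4*sin45, -sin45 + 4*cos45),
  // roughly (3.54, 2.12), which matches the expectations checked below.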
3751 
3752   // Some test vectors.
3753   bindings_.allocate(inp0)->getHandle() = {1, 4};
3754   bindings_.allocate(inp1)->getHandle() = {14, 2};
3755   bindings_.allocate(inp2)->getHandle() = {5, 2};
3756 
3757   auto *A0 = F_->createMatMul("m0", inp0, rot);
3758   auto *A1 = F_->createMatMul("m1", inp1, rot);
3759   auto *A2 = F_->createMatMul("m2", inp2, rot);
3760 
3761   auto *res0 = F_->createSave("save.values", A0);
3762   bindings_.allocate(res0->getPlaceholder());
3763   auto *res1 = F_->createSave("save.values", A1);
3764   bindings_.allocate(res1->getPlaceholder());
3765   auto *res2 = F_->createSave("save.values", A2);
3766   bindings_.allocate(res2->getPlaceholder());
3767 
3768   EE_.compile(CompilationMode::Infer);
3769 
3770   EE_.run(bindings_);
3771 
3772   auto R0 = bindings_.get(res0->getPlaceholder())->getHandle();
3773   auto R1 = bindings_.get(res1->getPlaceholder())->getHandle();
3774   auto R2 = bindings_.get(res2->getPlaceholder())->getHandle();
3775 
3776   EXPECT_FLOAT_EQ(R0.at({0, 0}), 3.5355339);
3777   EXPECT_FLOAT_EQ(R0.at({0, 1}), 2.1213205);
3778   EXPECT_FLOAT_EQ(R1.at({0, 0}), 11.313709);
3779   EXPECT_FLOAT_EQ(R1.at({0, 1}), -8.485281);
3780   EXPECT_FLOAT_EQ(R2.at({0, 0}), 4.9497476);
3781   EXPECT_FLOAT_EQ(R2.at({0, 1}), -2.1213202);
3782 }
3783 
3784 template <typename HandleTy>
3785 static void topK1Template(Module &mod_, Function *F_, ExecutionEngine &EE_,
3786                           PlaceholderBindings &bindings_,
3787                           ElemKind topKElemKind) {
3788   auto *inp =
3789       mod_.createPlaceholder(ElemKind::FloatTy, {3, 1, 5}, "input", false);
3790 
3791   bindings_.allocate(inp)->getHandle() = {
3792       0, 18, 7, 16, 5, 14, 33, 2, 41, 0, 1, -23, 34, 4, -5,
3793   };
3794 
3795   auto *R = F_->createTopK("TopK", inp, 1, topKElemKind);
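  // With k = 1, TopK degenerates to returning the maximum value and its index
  // along the last dimension for each of the three rows; the index tensor uses
  // the element kind requested by the caller.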
3796 
3797   auto *values = F_->createSave("save.values", {R, 0});
3798   bindings_.allocate(values->getPlaceholder());
3799 
3800   auto *indices = F_->createSave("save.indices", {R, 1});
3801   bindings_.allocate(indices->getPlaceholder());
3802 
3803   EE_.compile(CompilationMode::Infer);
3804   EE_.run(bindings_);
3805 
3806   auto V = bindings_.get(values->getPlaceholder())->getHandle();
3807   auto I = bindings_.get(indices->getPlaceholder())->getHandle<HandleTy>();
3808 
3809   EXPECT_FLOAT_EQ(V.at({0, 0, 0}), 18);
3810   EXPECT_EQ(I.at({0, 0, 0}), 1);
3811   EXPECT_FLOAT_EQ(V.at({1, 0, 0}), 41);
3812   EXPECT_EQ(I.at({1, 0, 0}), 3);
3813   EXPECT_FLOAT_EQ(V.at({2, 0, 0}), 34);
3814   EXPECT_EQ(I.at({2, 0, 0}), 2);
3815 }
3816 // Check the TopK operator for the special case of K=1.
3817 TEST_P(OperatorTest, TopK1) {
3818   CHECK_IF_ENABLED();
3819 
3820   topK1Template<int64_t>(mod_, F_, EE_, bindings_, ElemKind::Int64ITy);
3821 }
3822 
3823 // Check the TopK operator for the special case of K=1.
3824 TEST_P(OperatorTest, TopK1int32) {
3825   CHECK_IF_ENABLED();
3826 
3827   topK1Template<int32_t>(mod_, F_, EE_, bindings_, ElemKind::Int32ITy);
3828 }
3829 
3830 TEST_P(OperatorTest, QuantizedTopK) {
3831   CHECK_IF_ENABLED();
3832 
3833   auto *INV = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 1, 5}, 1.2, 5,
3834                                      "input", false);
3835   bindings_.allocate(INV)->getHandle<int8_t>() = {
3836       -12, -28, -7, 8, -93, 0, 10, 3, -1, 10, -2, 3, -2, 3, 3,
3837   };
3838 
3839   auto *TK = F_->createTopK("TopK", INV, 3);
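  // For quantized inputs TopK orders the raw int8 values directly, e.g. row 0
  // {-12, -28, -7, 8, -93} yields values {8, -7, -12} at indices {3, 2, 0}.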
3840 
3841   auto *values = F_->createSave("save.values", TK->getValues());
3842   bindings_.allocate(values->getPlaceholder());
3843   auto *indices = F_->createSave("save.indices", TK->getIndices());
3844   bindings_.allocate(indices->getPlaceholder());
3845 
3846   EE_.compile(CompilationMode::Infer);
3847   EE_.run(bindings_);
3848 
3849   auto VH = bindings_.get(values->getPlaceholder())->getHandle<int8_t>();
3850   auto IH = bindings_.get(indices->getPlaceholder())->getHandle<int64_t>();
3851 
3852   EXPECT_EQ(VH.at({0, 0, 0}), 8);
3853   EXPECT_EQ(IH.at({0, 0, 0}), 3);
3854   EXPECT_EQ(VH.at({0, 0, 1}), -7);
3855   EXPECT_EQ(IH.at({0, 0, 1}), 2);
3856   EXPECT_EQ(VH.at({0, 0, 2}), -12);
3857   EXPECT_EQ(IH.at({0, 0, 2}), 0);
3858 
3859   EXPECT_EQ(VH.at({1, 0, 0}), 10);
3860   EXPECT_EQ(IH.at({1, 0, 0}), 1);
3861   EXPECT_EQ(VH.at({1, 0, 1}), 10);
3862   EXPECT_EQ(IH.at({1, 0, 1}), 4);
3863   EXPECT_EQ(VH.at({1, 0, 2}), 3);
3864   EXPECT_EQ(IH.at({1, 0, 2}), 2);
3865 
3866   EXPECT_EQ(VH.at({2, 0, 0}), 3);
3867   EXPECT_EQ(IH.at({2, 0, 0}), 1);
3868   EXPECT_EQ(VH.at({2, 0, 1}), 3);
3869   EXPECT_EQ(IH.at({2, 0, 1}), 3);
3870   EXPECT_EQ(VH.at({2, 0, 2}), 3);
3871   EXPECT_EQ(IH.at({2, 0, 2}), 4);
3872 }
3873 
3874 /// Helper for testing Gather with different \p DTy / \p ITy data and index types.
3875 template <typename DataType, typename IndexType>
3876 static void gatherFloatInputTest(glow::PlaceholderBindings &bindings,
3877                                  glow::Module &mod, glow::Function *F,
3878                                  glow::ExecutionEngine &EE, ElemKind DTy,
3879                                  ElemKind ITy) {
3880   /*
3881     DATA  = [
3882         [1.0, 1.2],
3883         [2.3, 3.4],
3884         [4.5, 5.7],
3885     ]
3886     INDICES = [
3887         [0, 1, 0, 1],
3888         [1, 2, 2, 0],
3889     ]
3890     OUTPUT = [
3891         [
3892             [1.0, 1.2],
3893             [2.3, 3.4],
3894             [1.0, 1.2],
3895             [2.3, 3.4],
3896         ],
3897         [
3898             [2.3, 3.4],
3899             [4.5, 5.7],
3900             [4.5, 5.7],
3901             [1.0, 1.2],
3902         ],
3903     ]
3904   */
3905   auto *data = mod.createPlaceholder(DTy, {3, 2}, "data", false);
3906   auto *indices = mod.createPlaceholder(ITy, {2, 4}, "indices", false);
3907 
3908   bindings.allocate(data)->getHandle<DataType>() = {
3909       1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f,
3910   };
3911   bindings.allocate(indices)->getHandle<IndexType>() = {
3912       0, 1, 0, 1, 1, 2, 2, 0,
3913   };
3914 
3915   auto *R = F->createGather("gather", data, indices);
3916 
3917   auto *result = F->createSave("save", R);
3918   bindings.allocate(result->getPlaceholder());
3919 
3920   EE.compile(CompilationMode::Infer);
3921   EE.run(bindings);
3922 
3923   Tensor *resultT = bindings.get(result->getPlaceholder());
3924   Tensor expectedT(DTy, {2, 4, 2});
3925   expectedT.getHandle<DataType>() = {1.0, 1.2, 2.3, 3.4, 1.0, 1.2, 2.3, 3.4,
3926                                      2.3, 3.4, 4.5, 5.7, 4.5, 5.7, 1.0, 1.2};
3927 
3928   EXPECT_TRUE(resultT->isEqual(expectedT));
3929 }
3930 
3931 /// Test that Gather works with Float data and Int32 indices.
3932 TEST_P(OperatorTest, GatherDataFloatIdxInt32) {
3933   CHECK_IF_ENABLED();
3934   gatherFloatInputTest<float, int32_t>(bindings_, mod_, F_, EE_,
3935                                        ElemKind::FloatTy, ElemKind::Int32ITy);
3936 }
3937 
3938 #if DIM_T_BITWIDTH >= 64
3939 /// Test that Gather works with Float data and Int64 indices.
3940 TEST_P(OperatorTest, GatherDataFloatIdxInt64) {
3941   CHECK_IF_ENABLED();
3942   gatherFloatInputTest<float, int64_t>(bindings_, mod_, F_, EE_,
3943                                        ElemKind::FloatTy, ElemKind::Int64ITy);
3944 }
3945 #endif
3946 
3947 /// Test that Gather works with Float16 data and Int32 indices.
3948 TEST_P(OperatorTest, GatherDataFloat16IdxInt32) {
3949   CHECK_IF_ENABLED();
3950   gatherFloatInputTest<float16_t, int32_t>(
3951       bindings_, mod_, F_, EE_, ElemKind::Float16Ty, ElemKind::Int32ITy);
3952 }
3953 
3954 /// Test that Gather works with BFloat16 data and Int32 indices.
3955 TEST_P(OperatorTest, GatherDataBFloat16IdxInt32) {
3956   CHECK_IF_ENABLED();
3957   gatherFloatInputTest<bfloat16_t, int32_t>(
3958       bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty, ElemKind::Int32ITy);
3959 }
3960 
3961 /// Test that Gather works with Float16 data and Int64 indices.
3962 TEST_P(OperatorTest, GatherDataFloat16IdxInt64) {
3963   CHECK_IF_ENABLED();
3964   gatherFloatInputTest<float16_t, int64_t>(
3965       bindings_, mod_, F_, EE_, ElemKind::Float16Ty, ElemKind::Int64ITy);
3966 }
3967 
3968 /// Test that Gather works with BFloat16 data and Int64 indices.
3969 TEST_P(OperatorTest, GatherDataBFloat16IdxInt64) {
3970   CHECK_IF_ENABLED();
3971   gatherFloatInputTest<bfloat16_t, int64_t>(
3972       bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty, ElemKind::Int64ITy);
3973 }
3974 
3975 /// Helper for testing Gather with Int8 data and different \p ITy index types.
3976 template <typename IndexType>
3977 static void gatherInt8InputTest(glow::PlaceholderBindings &bindings,
3978                                 glow::Module &mod, glow::Function *F,
3979                                 glow::ExecutionEngine &EE, ElemKind ITy) {
3980   /*
3981     DATA  = [
3982         [1, 2],
3983         [3, 4],
3984         [5, 6],
3985     ]
3986     INDICES = [
3987         [0, 1, 0, 1],
3988         [1, 2, 2, 0],
3989     ]
3990     OUTPUT = [
3991         [
3992             [1, 2],
3993             [3, 4],
3994             [1, 2],
3995             [3, 4],
3996         ],
3997         [
3998             [3, 4],
3999             [5, 6],
4000             [5, 6],
4001             [1, 2],
4002         ],
4003     ]
4004   */
4005   auto *data =
4006       mod.createPlaceholder(ElemKind::Int8QTy, {3, 2}, 1.0, 0, "data", false);
4007   auto *indices = mod.createPlaceholder(ITy, {2, 4}, "indices", false);
4008 
4009   bindings.allocate(data)->getHandle<int8_t>() = {
4010       1, 2, 3, 4, 5, 6,
4011   };
4012   bindings.allocate(indices)->getHandle<IndexType>() = {
4013       0, 1, 0, 1, 1, 2, 2, 0,
4014   };
4015 
4016   auto *R = F->createGather("gather", data, indices);
4017 
4018   auto *result = F->createSave("save", R);
4019   bindings.allocate(result->getPlaceholder());
4020 
4021   EE.compile(CompilationMode::Infer);
4022   EE.run(bindings);
4023 
4024   Tensor *resultT = bindings.get(result->getPlaceholder());
4025   Tensor expectedT(ElemKind::Int8QTy, {2, 4, 2}, 1.0, 0);
4026   expectedT.getHandle<int8_t>() = {1, 2, 3, 4, 1, 2, 3, 4,
4027                                    3, 4, 5, 6, 5, 6, 1, 2};
4028 
4029   EXPECT_TRUE(resultT->isEqual(expectedT));
4030 }
4031 
4032 /// Test that Gather works with Int8 data and Int32 indices.
4033 TEST_P(OperatorTest, GatherDataInt8IdxInt32) {
4034   CHECK_IF_ENABLED();
4035   gatherInt8InputTest<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy);
4036 }
4037 
4038 #if DIM_T_BITWIDTH >= 64
4039 /// Test that Gather works with Int8 data and Int64 indices.
4040 TEST_P(OperatorTest, GatherDataInt8IdxInt64) {
4041   CHECK_IF_ENABLED();
4042   gatherInt8InputTest<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy);
4043 }
4044 #endif
4045 
4046 /// Helper for testing GatherRanges with different \p DTy / \p ITy data and index types.
4047 template <typename DataType, typename IndexType>
4048 void gatherRangesTest(glow::PlaceholderBindings &bindings_, glow::Module &mod_,
4049                       glow::Function *F_, glow::ExecutionEngine &EE_,
4050                       ElemKind DTy, ElemKind ITy) {
4051   /*
4052     DATA  = [1, 2, 3, 4, 5, 6]
4053     RANGES = [
4054       [
4055         [0, 1],
4056         [2, 2],
4057       ],
4058       [
4059         [4, 1],
4060         [5, 1],
4061       ]
4062     ]
4063     OUTPUT = [1, 3, 4, 5, 6]
4064     LENGTHS = [3, 2]
4065   */
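  // GatherRanges semantics exercised here: each [start, length] pair selects
  // DATA[start .. start + length); the selected elements are concatenated into
  // OUTPUT, and LENGTHS records how many elements each example gathered
  // (3 for the first example, 2 for the second).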
4066   auto *data = createPlaceholderConditionallyQuantized(mod_, DTy, {6}, "data",
4067                                                        false, "N");
4068   auto *ranges = mod_.createPlaceholder(ITy, {2, 2, 2}, "ranges", false);
4069 
4070   bindings_.allocate(data)->getHandle<DataType>() = {1, 2, 3, 4, 5, 6};
4071   bindings_.allocate(ranges)->getHandle<IndexType>() = {0, 1, 2, 2, 4, 1, 5, 1};
4072 
4073   auto *R =
4074       F_->createGatherRanges("gatherranges", data, ranges, /*maxOutputSize=*/5);
4075 
4076   auto *output = F_->createSave("output", R->getOutput());
4077   auto *lengths = F_->createSave("lengths", R->getLengths());
4078 
4079   Tensor *outputT = bindings_.allocate(output->getPlaceholder());
4080   Tensor *lengthsT = bindings_.allocate(lengths->getPlaceholder());
4081 
4082   EE_.compile(CompilationMode::Infer);
4083   EE_.run(bindings_);
4084 
4085   auto expectedOutputT = createTensorConditionallyQuantized(DTy, {5});
4086   expectedOutputT.getHandle<DataType>() = {1, 3, 4, 5, 6};
4087   EXPECT_TRUE(outputT->isEqual(expectedOutputT));
4088 
4089   Tensor expectedLengthsT(ITy, {2});
4090   expectedLengthsT.getHandle<IndexType>() = {3, 2};
4091   EXPECT_TRUE(lengthsT->isEqual(expectedLengthsT));
4092 }
4093 
4094 /// Test GatherRanges with Int64 data and Int32 indices.
4095 TEST_P(OperatorTest, GatherRangesDataInt64IdxInt32) {
4096   CHECK_IF_ENABLED();
4097   gatherRangesTest<int64_t, int32_t>(bindings_, mod_, F_, EE_,
4098                                      ElemKind::Int64ITy, ElemKind::Int32ITy);
4099 }
4100 
4101 #if DIM_T_BITWIDTH >= 64
4102 /// Test GatherRanges with Int64 data and Int64 indices.
4103 TEST_P(OperatorTest, GatherRangesDataInt64IdxInt64) {
4104   CHECK_IF_ENABLED();
4105   gatherRangesTest<int64_t, int64_t>(bindings_, mod_, F_, EE_,
4106                                      ElemKind::Int64ITy, ElemKind::Int64ITy);
4107 }
4108 #endif
4109 
4110 /// Test GatherRanges with Float data and Int32 indices.
4111 TEST_P(OperatorTest, GatherRangesDataFloatIdxInt32) {
4112   CHECK_IF_ENABLED();
4113   gatherRangesTest<float, int32_t>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
4114                                    ElemKind::Int32ITy);
4115 }
4116 
4117 #if DIM_T_BITWIDTH >= 64
4118 /// Test GatherRanges with Float data and Int64 indices.
4119 TEST_P(OperatorTest, GatherRangesDataFloatIdxInt64) {
4120   CHECK_IF_ENABLED();
4121   gatherRangesTest<float, int64_t>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
4122                                    ElemKind::Int64ITy);
4123 }
4124 #endif
4125 
4126 /// Test GatherRanges with Float16 data and Int32 indices.
4127 TEST_P(OperatorTest, GatherRangesDataFloat16IdxInt32) {
4128   CHECK_IF_ENABLED();
4129   gatherRangesTest<float16_t, int32_t>(bindings_, mod_, F_, EE_,
4130                                        ElemKind::Float16Ty, ElemKind::Int32ITy);
4131 }
4132 
4133 /// Test GatherRanges with BFloat16 data and Int32 indices.
4134 TEST_P(OperatorTest, GatherRangesDataBFloat16IdxInt32) {
4135   CHECK_IF_ENABLED();
4136   gatherRangesTest<bfloat16_t, int32_t>(
4137       bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty, ElemKind::Int32ITy);
4138 }
4139 
4140 #if DIM_T_BITWIDTH >= 64
4141 /// Test GatherRanges with Float16 data and Int64 indices.
4142 TEST_P(OperatorTest, GatherRangesDataFloat16IdxInt64) {
4143   CHECK_IF_ENABLED();
4144   gatherRangesTest<float16_t, int64_t>(bindings_, mod_, F_, EE_,
4145                                        ElemKind::Float16Ty, ElemKind::Int64ITy);
4146 }
4147 
4148 /// Test GatherRanges with BFloat16 data and Int64 indices.
4149 TEST_P(OperatorTest, GatherRangesDataBFloat16IdxInt64) {
4150   CHECK_IF_ENABLED();
4151   gatherRangesTest<bfloat16_t, int64_t>(
4152       bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty, ElemKind::Int64ITy);
4153 }
4154 #endif
4155 
4156 /// Test GatherRanges with Int8Q data and Int32 indices.
4157 TEST_P(OperatorTest, GatherRangesDataInt8QIdxInt32) {
4158   CHECK_IF_ENABLED();
4159   gatherRangesTest<int8_t, int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy,
4160                                     ElemKind::Int32ITy);
4161 }
4162 
4163 #if DIM_T_BITWIDTH >= 64
4164 /// Test GatherRanges with Int8Q data and Int64 indices.
4165 TEST_P(OperatorTest, GatherRangesDataInt8QIdxInt64) {
4166   CHECK_IF_ENABLED();
4167   gatherRangesTest<int8_t, int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy,
4168                                     ElemKind::Int64ITy);
4169 }
4170 #endif
4171 
4172 /// Check if the code generation of transposes
4173 /// is correct for tensors with 2 dimensions.
4174 /// Note: This assumes that Tensor::transpose is correct.
4175 TEST_P(OperatorTest, Transpose2Dims) {
4176   CHECK_IF_ENABLED();
4177 
4178   auto *A = mod_.createPlaceholder(ElemKind::FloatTy, {20, 13}, "A", false);
4179   bindings_.allocate(A)->getHandle().randomize(-3.0, 3.0, mod_.getPRNG());
4180 
4181   auto *tr = F_->createTranspose("tr", A, {1, 0});
4182   auto *result = F_->createSave("saveTranspose", tr);
4183   bindings_.allocate(result->getPlaceholder());
4184 
4185   EE_.compile(CompilationMode::Infer);
4186   EE_.run(bindings_);
4187 
4188   Tensor dest(ElemKind::FloatTy, {13, 20});
4189   bindings_.get(A)->transpose(&dest, {1, 0});
4190   EXPECT_TRUE(bindings_.get(result->getPlaceholder())->isEqual(dest));
4191 }
4192 
4193 /// Check that transpose is supported for FP16.
4194 TEST_P(OperatorTest, FP16Transpose2Dims) {
4195   CHECK_IF_ENABLED();
4196 
4197   auto *A = mod_.createPlaceholder(ElemKind::Float16Ty, {20, 13}, "A", false);
4198   bindings_.allocate(A)->getHandle<float16_t>().randomize(-3.0, 3.0,
4199                                                           mod_.getPRNG());
4200 
4201   auto *tr = F_->createTranspose("tr", A, {1, 0});
4202   auto *result = F_->createSave("saveTranspose", tr);
4203   bindings_.allocate(result->getPlaceholder());
4204 
4205   EE_.compile(CompilationMode::Infer);
4206   EE_.run(bindings_);
4207 
4208   Tensor dest(ElemKind::Float16Ty, {13, 20});
4209   bindings_.get(A)->transpose(&dest, {1, 0});
4210   EXPECT_TRUE(bindings_.get(result->getPlaceholder())->isEqual(dest));
4211 }
4212 
4213 /// Check that transpose is supported for BFloat16.
4214 TEST_P(OperatorTest, BFloat16Transpose2Dims) {
4215   CHECK_IF_ENABLED();
4216 
4217   auto *A = mod_.createPlaceholder(ElemKind::BFloat16Ty, {20, 13}, "A", false);
4218   bindings_.allocate(A)->getHandle<bfloat16_t>().randomize(-3.0, 3.0,
4219                                                            mod_.getPRNG());
4220 
4221   auto *tr = F_->createTranspose("tr", A, {1, 0});
4222   auto *result = F_->createSave("saveTranspose", tr);
4223   bindings_.allocate(result->getPlaceholder());
4224 
4225   EE_.compile(CompilationMode::Infer);
4226   EE_.run(bindings_);
4227 
4228   Tensor dest(ElemKind::BFloat16Ty, {13, 20});
4229   bindings_.get(A)->transpose(&dest, {1, 0});
4230   EXPECT_TRUE(bindings_.get(result->getPlaceholder())->isEqual(dest));
4231 }
4232 
4233 /// Check that transpose is supported for BoolTy.
4234 TEST_P(OperatorTest, BoolTranspose2Dims) {
4235   CHECK_IF_ENABLED();
4236 
4237   auto *A = mod_.createPlaceholder(ElemKind::BoolTy, {20, 13}, "A", false);
4238   bindings_.allocate(A)->getHandle<bool>().randomize(0, 1, mod_.getPRNG());
4239 
4240   auto *tr = F_->createTranspose("tr", A, {1, 0});
4241   auto *result = F_->createSave("saveTranspose", tr);
4242   bindings_.allocate(result->getPlaceholder());
4243 
4244   EE_.compile(CompilationMode::Infer);
4245   EE_.run(bindings_);
4246 
4247   Tensor dest(ElemKind::BoolTy, {13, 20});
4248   bindings_.get(A)->transpose(&dest, {1, 0});
4249   EXPECT_TRUE(bindings_.get(result->getPlaceholder())->isEqual(dest));
4250 }
4251 
4252 /// Helper to check if the code generation of transposes
4253 /// is correct for tensors with 3 dimensions using \p DTy.
4254 /// Note: This assumes that Tensor::transpose is correct.
4255 template <typename DataType>
4256 static void testTranspose3Dims(glow::PlaceholderBindings &bindings,
4257                                glow::Module &mod, glow::Function *F,
4258                                glow::ExecutionEngine &EE, ElemKind DTy) {
4259   constexpr dim_t dims[] = {20, 13, 7};
4260   auto *A = createPlaceholderConditionallyQuantized(mod, DTy, dims, "A", false);
4261   bindings.allocate(A)->getHandle<DataType>().randomize(-3.0, 3.0,
4262                                                         mod.getPRNG());
4263 
4264   int nbOfShuffle = 0;
4265   SaveNode *savedTransposes[6];
4266   unsigned_t shuffles[6][3];
4267 
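  // Enumerate all 6 permutations of the three axes (including the identity
  // permutation {0, 1, 2}) and create one transpose + save pair for each.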
4268   for (unsigned_t i = 0; i < 3; ++i) {
4269     for (unsigned_t j = 0; j < 3; ++j) {
4270       if (j == i) {
4271         continue;
4272       }
4273       for (unsigned_t k = 0; k < 3; ++k) {
4274         if (k == j || k == i) {
4275           continue;
4276         }
4277         shuffles[nbOfShuffle][0] = i;
4278         shuffles[nbOfShuffle][1] = j;
4279         shuffles[nbOfShuffle][2] = k;
4280         auto *tr = F->createTranspose("tr", A, shuffles[nbOfShuffle]);
4281         savedTransposes[nbOfShuffle] = F->createSave("saveTranspose", tr);
4282         bindings.allocate(savedTransposes[nbOfShuffle]->getPlaceholder());
4283         ++nbOfShuffle;
4284       }
4285     }
4286   }
4287 
4288   // We should have exactly 6 possible permutations for 3 dimensions.
4289   EXPECT_EQ(6, nbOfShuffle);
4290 
4291   EE.compile(CompilationMode::Infer);
4292   EE.run(bindings);
4293 
4294   for (int i = 0; i < 6; ++i) {
4295     auto dest = createTensorConditionallyQuantized(
4296         DTy,
4297         {dims[shuffles[i][0]], dims[shuffles[i][1]], dims[shuffles[i][2]]});
4298     bindings.get(A)->transpose(&dest, shuffles[i]);
4299     EXPECT_TRUE(
4300         bindings.get(savedTransposes[i]->getPlaceholder())->isEqual(dest));
4301   }
4302 }
4303 
4304 /// Test Transpose3Dims with Float.
4305 TEST_P(OperatorTest, Transpose3Dims_Float) {
4306   CHECK_IF_ENABLED();
4307   testTranspose3Dims<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
4308 }
4309 
4310 /// Test Transpose3Dims with Float16.
4311 TEST_P(OperatorTest, Transpose3Dims_Float16) {
4312   CHECK_IF_ENABLED();
4313   testTranspose3Dims<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
4314 }
4315 
4316 /// Test Transpose3Dims with BFloat16.
4317 TEST_P(OperatorTest, Transpose3Dims_BFloat16) {
4318   CHECK_IF_ENABLED();
4319   testTranspose3Dims<bfloat16_t>(bindings_, mod_, F_, EE_,
4320                                  ElemKind::BFloat16Ty);
4321 }
4322 
4323 /// Test Transpose3Dims with Int8.
4324 TEST_P(OperatorTest, Transpose3Dims_Int8) {
4325   CHECK_IF_ENABLED();
4326   testTranspose3Dims<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
4327 }
4328 
4329 /// Test that Transpose optimization into Reshape yields expected results.
4330 TEST_P(OperatorTest, TransposeIntoReshapeOptim) {
4331   CHECK_IF_ENABLED();
4332   auto *batch = mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 2, 4}, "batch",
4333                                        false, "NHWC");
4334   auto IH = bindings_.allocate(batch)->getHandle();
4335   for (size_t i = 0; i < 24; i++) {
4336     IH.raw(i) = i + 1;
4337   }
4338 
4339   Node *T = F_->createTranspose("transpose", batch, {1, 2, 0, 3}, "HWNC");
4340   Node *R = F_->createBatchedReduceMean("reduce.mean", T, {2, 3});
4341   SaveNode *O = F_->createSave("ret", R);
4342   bindings_.allocate(mod_.getPlaceholders());
4343   EE_.compile(CompilationMode::Infer);
4344   EE_.run(bindings_);
4345 
4346   auto result = bindings_.get(O->getPlaceholder())->getHandle();
4347   std::vector<dim_t> expectedDims = {3, 2};
4348   EXPECT_TRUE(result.dims().vec() == expectedDims);
4349 
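  // Each output element is the mean over the reduced N and C dimensions of one
  // (H, W) position, i.e. mean{1, 2, 3, 4} = 2.5, mean{5, 6, 7, 8} = 6.5, etc.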
4350   std::vector<float> expectedValues = {2.5f, 6.5f, 10.5f, 14.5f, 18.5f, 22.5f};
4351   for (size_t i = 0; i < 3 * 2; i++) {
4352     EXPECT_EQ(result.raw(i), expectedValues[i]);
4353   }
4354 }
4355 
4356 /// Helper to check the code generation for flip nodes.
4357 template <typename elemType>
4358 static void testFlip(glow::PlaceholderBindings &bindings, glow::Module &mod,
4359                      glow::Function *F, glow::ExecutionEngine &EE,
4360                      std::vector<elemType> inputData,
4361                      std::vector<elemType> expectedData,
4362                      llvm::ArrayRef<dim_t> dims, dim_t axis,
4363                      ElemKind elemKind = ElemKind::FloatTy) {
4364 
4365   // Create network.
4366   auto *input =
4367       createPlaceholderConditionallyQuantized(mod, elemKind, dims, "input",
4368                                               /* isTrainable */ false);
4369   auto *flip = F->createFlip("flip", input, axis);
4370   Placeholder *output = F->createSave("save", flip)->getPlaceholder();
4371 
4372   // Allocate input/output and initialize input.
4373   auto inputH = bindings.allocate(input)->getHandle<elemType>();
4374   auto outputH = bindings.allocate(output)->getHandle<elemType>();
4375   inputH = inputData;
4376 
4377   // Compile and run.
4378   EE.compile(CompilationMode::Infer);
4379   EE.run(bindings);
4380 
4381   // Compare output with reference.
4382   EXPECT_EQ(outputH.size(), expectedData.size());
4383   for (size_t i = 0; i < expectedData.size(); i++) {
4384     EXPECT_EQ(outputH.raw(i), expectedData[i]);
4385   }
4386 }
4387 
4388 /// Test Flip 1D with Int8.
4389 TEST_P(OperatorTest, Flip1D_Int8) {
4390   ENABLED_BACKENDS("Interpreter", "CPU");
4391   testFlip<int8_t>(bindings_, mod_, F_, EE_, {1, 2, 3, 4}, {4, 3, 2, 1}, {4}, 0,
4392                    ElemKind::Int8QTy);
4393 }
4394 
4395 /// Test Flip 1D with Int32.
4396 TEST_P(OperatorTest, Flip1D_Int32) {
4397   ENABLED_BACKENDS("Interpreter", "CPU");
4398   testFlip<int32_t>(bindings_, mod_, F_, EE_, {1, 2, 3, 4}, {4, 3, 2, 1}, {4},
4399                     0, ElemKind::Int32QTy);
4400 }
4401 
4402 /// Test Flip 1D with Int64.
4403 TEST_P(OperatorTest, Flip1D_Int64) {
4404   ENABLED_BACKENDS("Interpreter", "CPU");
4405   testFlip<int64_t>(bindings_, mod_, F_, EE_, {1, 2, 3, 4}, {4, 3, 2, 1}, {4},
4406                     0, ElemKind::Int64ITy);
4407 }
4408 
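// Reference data for the flip tests below: the input is filled with 1..N, and
// flipping axis k reverses the order of the slices along dimension k. For the
// {2, 2, 2} 3D input, axis 0 swaps the two outer halves, axis 1 swaps the rows
// within each half, and axis 2 swaps the elements within each row.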
4409 #define FLIP_3D_INPUT                                                          \
4410   { 1, 2, 3, 4, 5, 6, 7, 8 }
4411 #define FLIP_3D_AXIS0                                                          \
4412   { 5, 6, 7, 8, 1, 2, 3, 4 }
4413 #define FLIP_3D_AXIS1                                                          \
4414   { 3, 4, 1, 2, 7, 8, 5, 6 }
4415 #define FLIP_3D_AXIS2                                                          \
4416   { 2, 1, 4, 3, 6, 5, 8, 7 }
4417 
4418 #define FLIP_4D_INPUT                                                          \
4419   { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 }
4420 #define FLIP_4D_AXIS0                                                          \
4421   { 9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8 }
4422 #define FLIP_4D_AXIS1                                                          \
4423   { 5, 6, 7, 8, 1, 2, 3, 4, 13, 14, 15, 16, 9, 10, 11, 12 }
4424 #define FLIP_4D_AXIS2                                                          \
4425   { 3, 4, 1, 2, 7, 8, 5, 6, 11, 12, 9, 10, 15, 16, 13, 14 }
4426 #define FLIP_4D_AXIS3                                                          \
4427   { 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15 }
4428 
4429 #define FLIP_5D_INPUT                                                          \
4430   {                                                                            \
4431     1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, \
4432         22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32                             \
4433   }
4434 #define FLIP_5D_AXIS0                                                          \
4435   {                                                                            \
4436     17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 1, 2, 3,   \
4437         4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16                           \
4438   }
4439 #define FLIP_5D_AXIS1                                                          \
4440   {                                                                            \
4441     9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 25, 26, 27, 28, 29, \
4442         30, 31, 32, 17, 18, 19, 20, 21, 22, 23, 24                             \
4443   }
4444 #define FLIP_5D_AXIS2                                                          \
4445   {                                                                            \
4446     5, 6, 7, 8, 1, 2, 3, 4, 13, 14, 15, 16, 9, 10, 11, 12, 21, 22, 23, 24, 17, \
4447         18, 19, 20, 29, 30, 31, 32, 25, 26, 27, 28                             \
4448   }
4449 #define FLIP_5D_AXIS3                                                          \
4450   {                                                                            \
4451     3, 4, 1, 2, 7, 8, 5, 6, 11, 12, 9, 10, 15, 16, 13, 14, 19, 20, 17, 18, 23, \
4452         24, 21, 22, 27, 28, 25, 26, 31, 32, 29, 30                             \
4453   }
4454 #define FLIP_5D_AXIS4                                                          \
4455   {                                                                            \
4456     2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, \
4457         21, 24, 23, 26, 25, 28, 27, 30, 29, 32, 31                             \
4458   }
4459 
4460 #define FLIP_6D_INPUT                                                          \
4461   {                                                                            \
4462     1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, \
4463         22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,    \
4464         39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,    \
4465         56, 57, 58, 59, 60, 61, 62, 63, 64                                     \
4466   }
4467 #define FLIP_6D_AXIS0                                                          \
4468   {                                                                            \
4469     33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,    \
4470         51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 1, 2, 3, 4, 5, \
4471         6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,    \
4472         24, 25, 26, 27, 28, 29, 30, 31, 32                                     \
4473   }
4474 #define FLIP_6D_AXIS1                                                          \
4475   {                                                                            \
4476     17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 1, 2, 3,   \
4477         4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 49, 50, 51, 52, 53, 54,  \
4478         55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 33, 34, 35, 36, 37, 38, 39,    \
4479         40, 41, 42, 43, 44, 45, 46, 47, 48                                     \
4480   }
4481 #define FLIP_6D_AXIS2                                                          \
4482   {                                                                            \
4483     9, 10, 11, 12, 13, 14, 15, 16, 1, 2, 3, 4, 5, 6, 7, 8, 25, 26, 27, 28, 29, \
4484         30, 31, 32, 17, 18, 19, 20, 21, 22, 23, 24, 41, 42, 43, 44, 45, 46,    \
4485         47, 48, 33, 34, 35, 36, 37, 38, 39, 40, 57, 58, 59, 60, 61, 62, 63,    \
4486         64, 49, 50, 51, 52, 53, 54, 55, 56                                     \
4487   }
4488 #define FLIP_6D_AXIS3                                                          \
4489   {                                                                            \
4490     5, 6, 7, 8, 1, 2, 3, 4, 13, 14, 15, 16, 9, 10, 11, 12, 21, 22, 23, 24, 17, \
4491         18, 19, 20, 29, 30, 31, 32, 25, 26, 27, 28, 37, 38, 39, 40, 33, 34,    \
4492         35, 36, 45, 46, 47, 48, 41, 42, 43, 44, 53, 54, 55, 56, 49, 50, 51,    \
4493         52, 61, 62, 63, 64, 57, 58, 59, 60                                     \
4494   }
4495 #define FLIP_6D_AXIS4                                                          \
4496   {                                                                            \
4497     3, 4, 1, 2, 7, 8, 5, 6, 11, 12, 9, 10, 15, 16, 13, 14, 19, 20, 17, 18, 23, \
4498         24, 21, 22, 27, 28, 25, 26, 31, 32, 29, 30, 35, 36, 33, 34, 39, 40,    \
4499         37, 38, 43, 44, 41, 42, 47, 48, 45, 46, 51, 52, 49, 50, 55, 56, 53,    \
4500         54, 59, 60, 57, 58, 63, 64, 61, 62                                     \
4501   }
4502 #define FLIP_6D_AXIS5                                                          \
4503   {                                                                            \
4504     2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, \
4505         21, 24, 23, 26, 25, 28, 27, 30, 29, 32, 31, 34, 33, 36, 35, 38, 37,    \
4506         40, 39, 42, 41, 44, 43, 46, 45, 48, 47, 50, 49, 52, 51, 54, 53, 56,    \
4507         55, 58, 57, 60, 59, 62, 61, 64, 63                                     \
4508   }
4509 
4510 /// Test Flip 1D with Float.
4511 TEST_P(OperatorTest, Flip1D_Axis0_Float) {
4512   ENABLED_BACKENDS("Interpreter", "CPU");
4513   testFlip<float>(bindings_, mod_, F_, EE_, {1, 2}, {2, 1}, {2}, 0);
4514 }
4515 
4516 /// Test Flip 2D with Float.
4517 TEST_P(OperatorTest, Flip2D_Axis0_Float) {
4518   ENABLED_BACKENDS("Interpreter", "CPU");
4519   testFlip<float>(bindings_, mod_, F_, EE_, {1, 2, 3, 4}, {3, 4, 1, 2}, {2, 2},
4520                   0);
4521 }
4522 TEST_P(OperatorTest, Flip2D_Axis1_Float) {
4523   ENABLED_BACKENDS("Interpreter", "CPU");
4524   testFlip<float>(bindings_, mod_, F_, EE_, {1, 2, 3, 4}, {2, 1, 4, 3}, {2, 2},
4525                   1);
4526 }
4527 
4528 /// Test Flip 3D with Float.
4529 TEST_P(OperatorTest, Flip3D_Axis0_Float) {
4530   ENABLED_BACKENDS("Interpreter", "CPU");
4531   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_3D_INPUT, FLIP_3D_AXIS0,
4532                   {2, 2, 2}, 0);
4533 }
4534 TEST_P(OperatorTest, Flip3D_Axis1_Float) {
4535   ENABLED_BACKENDS("Interpreter", "CPU");
4536   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_3D_INPUT, FLIP_3D_AXIS1,
4537                   {2, 2, 2}, 1);
4538 }
4539 TEST_P(OperatorTest, Flip3D_Axis2_Float) {
4540   ENABLED_BACKENDS("Interpreter", "CPU");
4541   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_3D_INPUT, FLIP_3D_AXIS2,
4542                   {2, 2, 2}, 2);
4543 }
4544 
4545 /// Test Flip 4D with Float.
4546 TEST_P(OperatorTest, Flip4D_Axis0_Float) {
4547   ENABLED_BACKENDS("Interpreter", "CPU");
4548   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_4D_INPUT, FLIP_4D_AXIS0,
4549                   {2, 2, 2, 2}, 0);
4550 }
4551 TEST_P(OperatorTest, Flip4D_Axis1_Float) {
4552   ENABLED_BACKENDS("Interpreter", "CPU");
4553   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_4D_INPUT, FLIP_4D_AXIS1,
4554                   {2, 2, 2, 2}, 1);
4555 }
4556 TEST_P(OperatorTest, Flip4D_Axis2_Float) {
4557   ENABLED_BACKENDS("Interpreter", "CPU");
4558   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_4D_INPUT, FLIP_4D_AXIS2,
4559                   {2, 2, 2, 2}, 2);
4560 }
4561 TEST_P(OperatorTest, Flip4D_Axis3_Float) {
4562   ENABLED_BACKENDS("Interpreter", "CPU");
4563   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_4D_INPUT, FLIP_4D_AXIS3,
4564                   {2, 2, 2, 2}, 3);
4565 }
4566 
4567 /// Test Flip 5D with Float.
4568 TEST_P(OperatorTest, Flip5D_Axis0_Float) {
4569   ENABLED_BACKENDS("Interpreter", "CPU");
4570   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_5D_INPUT, FLIP_5D_AXIS0,
4571                   {2, 2, 2, 2, 2}, 0);
4572 }
4573 TEST_P(OperatorTest, Flip5D_Axis1_Float) {
4574   ENABLED_BACKENDS("Interpreter", "CPU");
4575   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_5D_INPUT, FLIP_5D_AXIS1,
4576                   {2, 2, 2, 2, 2}, 1);
4577 }
4578 TEST_P(OperatorTest, Flip5D_Axis2_Float) {
4579   ENABLED_BACKENDS("Interpreter", "CPU");
4580   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_5D_INPUT, FLIP_5D_AXIS2,
4581                   {2, 2, 2, 2, 2}, 2);
4582 }
4583 TEST_P(OperatorTest, Flip5D_Axis3_Float) {
4584   ENABLED_BACKENDS("Interpreter", "CPU");
4585   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_5D_INPUT, FLIP_5D_AXIS3,
4586                   {2, 2, 2, 2, 2}, 3);
4587 }
4588 TEST_P(OperatorTest, Flip5D_Axis4_Float) {
4589   ENABLED_BACKENDS("Interpreter", "CPU");
4590   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_5D_INPUT, FLIP_5D_AXIS4,
4591                   {2, 2, 2, 2, 2}, 4);
4592 }
4593 
4594 /// Test Flip 6D with Float.
4595 TEST_P(OperatorTest, Flip6D_Axis0_Float) {
4596   ENABLED_BACKENDS("Interpreter", "CPU");
4597   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_6D_INPUT, FLIP_6D_AXIS0,
4598                   {2, 2, 2, 2, 2, 2}, 0);
4599 }
4600 TEST_P(OperatorTest, Flip6D_Axis1_Float) {
4601   ENABLED_BACKENDS("Interpreter", "CPU");
4602   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_6D_INPUT, FLIP_6D_AXIS1,
4603                   {2, 2, 2, 2, 2, 2}, 1);
4604 }
4605 TEST_P(OperatorTest, Flip6D_Axis2_Float) {
4606   ENABLED_BACKENDS("Interpreter", "CPU");
4607   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_6D_INPUT, FLIP_6D_AXIS2,
4608                   {2, 2, 2, 2, 2, 2}, 2);
4609 }
4610 TEST_P(OperatorTest, Flip6D_Axis3_Float) {
4611   ENABLED_BACKENDS("Interpreter", "CPU");
4612   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_6D_INPUT, FLIP_6D_AXIS3,
4613                   {2, 2, 2, 2, 2, 2}, 3);
4614 }
4615 TEST_P(OperatorTest, Flip6D_Axis4_Float) {
4616   ENABLED_BACKENDS("Interpreter", "CPU");
4617   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_6D_INPUT, FLIP_6D_AXIS4,
4618                   {2, 2, 2, 2, 2, 2}, 4);
4619 }
4620 TEST_P(OperatorTest, Flip6D_Axis5_Float) {
4621   ENABLED_BACKENDS("Interpreter", "CPU");
4622   testFlip<float>(bindings_, mod_, F_, EE_, FLIP_6D_INPUT, FLIP_6D_AXIS5,
4623                   {2, 2, 2, 2, 2, 2}, 5);
4624 }
4625 
4626 #undef FLIP_3D_INPUT
4627 #undef FLIP_3D_AXIS0
4628 #undef FLIP_3D_AXIS1
4629 #undef FLIP_3D_AXIS2
4630 #undef FLIP_4D_INPUT
4631 #undef FLIP_4D_AXIS0
4632 #undef FLIP_4D_AXIS1
4633 #undef FLIP_4D_AXIS2
4634 #undef FLIP_4D_AXIS3
4635 #undef FLIP_5D_INPUT
4636 #undef FLIP_5D_AXIS0
4637 #undef FLIP_5D_AXIS1
4638 #undef FLIP_5D_AXIS2
4639 #undef FLIP_5D_AXIS3
4640 #undef FLIP_5D_AXIS4
4641 #undef FLIP_6D_INPUT
4642 #undef FLIP_6D_AXIS0
4643 #undef FLIP_6D_AXIS1
4644 #undef FLIP_6D_AXIS2
4645 #undef FLIP_6D_AXIS3
4646 #undef FLIP_6D_AXIS4
4647 #undef FLIP_6D_AXIS5
4648 
4649 /// Check that gather on Int64ITy/size_t works.
4650 TEST_P(OperatorTest, GatherSizeT) {
4651   CHECK_IF_ENABLED();
4652 
4653   /*
4654     DATA  = [
4655         [1, 2],
4656         [3, 4],
4657         [5, 6],
4658     ]
4659     INDICES = [
4660         [0, 1, 0, 1],
4661         [1, 2, 2, 0],
4662     ]
4663     OUTPUT = [
4664         [
4665             [1, 2],
4666             [3, 4],
4667             [1, 2],
4668             [3, 4],
4669         ],
4670         [
4671             [3, 4],
4672             [5, 6],
4673             [5, 6],
4674             [1, 2],
4675         ],
4676     ]
4677   */
4678   auto *data =
4679       mod_.createPlaceholder(ElemKind::Int64ITy, {3, 2}, "data", false);
4680   auto *indices =
4681       mod_.createPlaceholder(ElemKind::Int64ITy, {2, 4}, "indices", false);
4682 
4683   bindings_.allocate(data)->getHandle<int64_t>() = {
4684       1, 2, 3, 4, 5, 6,
4685   };
4686   bindings_.allocate(indices)->getHandle<int64_t>() = {
4687       0, 1, 0, 1, 1, 2, 2, 0,
4688   };
4689 
4690   auto *R = F_->createGather("gather", data, indices);
4691 
4692   auto *result = F_->createSave("save", R);
4693   bindings_.allocate(result->getPlaceholder());
4694 
4695   EE_.compile(CompilationMode::Infer);
4696   EE_.run(bindings_);
4697 
4698   auto H = bindings_.get(result->getPlaceholder())->getHandle<int64_t>();
4699 
4700   EXPECT_EQ(H.at({0, 0, 0}), 1);
4701   EXPECT_EQ(H.at({0, 0, 1}), 2);
4702   EXPECT_EQ(H.at({0, 1, 0}), 3);
4703   EXPECT_EQ(H.at({0, 1, 1}), 4);
4704   EXPECT_EQ(H.at({0, 2, 0}), 1);
4705   EXPECT_EQ(H.at({0, 2, 1}), 2);
4706   EXPECT_EQ(H.at({0, 3, 0}), 3);
4707   EXPECT_EQ(H.at({0, 3, 1}), 4);
4708 
4709   EXPECT_EQ(H.at({1, 0, 0}), 3);
4710   EXPECT_EQ(H.at({1, 0, 1}), 4);
4711   EXPECT_EQ(H.at({1, 1, 0}), 5);
4712   EXPECT_EQ(H.at({1, 1, 1}), 6);
4713   EXPECT_EQ(H.at({1, 2, 0}), 5);
4714   EXPECT_EQ(H.at({1, 2, 1}), 6);
4715   EXPECT_EQ(H.at({1, 3, 0}), 1);
4716   EXPECT_EQ(H.at({1, 3, 1}), 2);
4717 }
4718 
4719 TEST_P(OperatorTest, BatchedGather) {
4720   CHECK_IF_ENABLED();
4721 
4722   /*
4723    DATA  = [
4724     [1.0, 1.2, 2.4, 4.5],
4725     [2.3, 3.4, 3.6, 2.3],
4726     [4.5, 5.7, 1.2, 4.5],
4727    ]
4728 
4729    INDICES = [0, 2],
4730 
4731    OUTPUT = [
4732     [1.0, 2.4],
4733     [2.3, 3.6],
4734     [4.5, 1.2],
4735    ]
4736    */
4737   auto *data = mod_.createPlaceholder(ElemKind::FloatTy, {3, 4}, "data", false);
4738   auto *indices =
4739       mod_.createPlaceholder(ElemKind::Int64ITy, {2}, "indices", false);
4740 
4741   bindings_.allocate(data)->getHandle() = {
4742       1.0f, 1.2f, 2.4f, 4.5f, 2.3f, 3.4f, 3.6f, 2.3f, 4.5f, 5.7f, 1.2f, 4.5f,
4743   };
4744   bindings_.allocate(indices)->getHandle<int64_t>() = {
4745       0,
4746       2,
4747   };
4748 
4749   // Create a batched gather (a single batch dimension).
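  // With a single batch dimension the indices select along dimension 1 of the
  // data, i.e. output[i][j] = data[i][indices[j]], so row 0 of the result is
  // {1.0, 2.4}.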
4750   auto *R = F_->createGather("gather", data, indices, 1);
4751 
4752   auto *result = F_->createSave("save", R);
4753   bindings_.allocate(result->getPlaceholder());
4754 
4755   EE_.compile(CompilationMode::Infer);
4756   EE_.run(bindings_);
4757 
4758   auto H = bindings_.get(result->getPlaceholder())->getHandle();
4759   EXPECT_FLOAT_EQ(H.at({0, 0}), 1.0);
4760   EXPECT_FLOAT_EQ(H.at({0, 1}), 2.4);
4761   EXPECT_FLOAT_EQ(H.at({1, 0}), 2.3);
4762   EXPECT_FLOAT_EQ(H.at({1, 1}), 3.6);
4763   EXPECT_FLOAT_EQ(H.at({2, 0}), 4.5);
4764   EXPECT_FLOAT_EQ(H.at({2, 1}), 1.2);
4765 }
4766 
4767 TEST_P(OperatorTest, ScatterData) {
4768   CHECK_IF_ENABLED();
4769 
4770   auto *data = mod_.createPlaceholder(ElemKind::FloatTy, {5, 2}, "data", false);
4771   auto *indices =
4772       mod_.createPlaceholder(ElemKind::Int64ITy, {2, 1}, "indices", false);
4773   auto *slices =
4774       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "slices", false);
4775 
4776   bindings_.allocate(data)->getHandle() = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
4777   bindings_.allocate(indices)->getHandle<int64_t>() = {1, 3};
4778   bindings_.allocate(slices)->getHandle() = {-3, -4, -7, -8};
4779 
4780   auto *R = F_->createScatterData("scatterdata", data, indices, slices);
4781 
4782   auto *result = F_->createSave("save", R);
4783   bindings_.allocate(result->getPlaceholder());
4784 
4785   EE_.compile(CompilationMode::Infer);
4786   EE_.run(bindings_);
4787 
4788   auto H = bindings_.get(result->getPlaceholder())->getHandle();
4789 
4790   EXPECT_FLOAT_EQ(H.at({0, 0}), 1.0);
4791   EXPECT_FLOAT_EQ(H.at({0, 1}), 2.0);
4792   EXPECT_FLOAT_EQ(H.at({1, 0}), -3.0);
4793   EXPECT_FLOAT_EQ(H.at({1, 1}), -4.0);
4794   EXPECT_FLOAT_EQ(H.at({2, 0}), 5.0);
4795   EXPECT_FLOAT_EQ(H.at({2, 1}), 6.0);
4796   EXPECT_FLOAT_EQ(H.at({3, 0}), -7.0);
4797   EXPECT_FLOAT_EQ(H.at({3, 1}), -8.0);
4798   EXPECT_FLOAT_EQ(H.at({4, 0}), 9.0);
4799   EXPECT_FLOAT_EQ(H.at({4, 1}), 10.0);
4800 }
4801 
4802 TEST_P(OperatorTest, ScatterDataQuantized) {
4803   CHECK_IF_ENABLED();
4804 
4805   auto *data = mod_.createPlaceholder(ElemKind::FloatTy, {5, 2}, "data", false);
4806   auto *indices =
4807       mod_.createPlaceholder(ElemKind::Int64ITy, {2, 1}, "indices", false);
4808   auto *slices =
4809       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "slices", false);
4810 
4811   bindings_.allocate(data)->getHandle() = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
4812   bindings_.allocate(indices)->getHandle<int64_t>() = {1, 3};
4813   bindings_.allocate(slices)->getHandle() = {-3, -4, -7, -8};
4814 
4815   auto qParams = glow::quantization::chooseQuantizationParams({-11, 11});
4816   auto dataTy =
4817       mod_.uniqueType(ElemKind::Int8QTy, {5, 2}, qParams.scale, qParams.offset);
4818   auto slicesTy =
4819       mod_.uniqueType(ElemKind::Int8QTy, {2, 2}, qParams.scale, qParams.offset);
4820 
4821   auto *dataQ = F_->createQuantize("quantizeQ", data, dataTy);
4822   auto *slicesQ = F_->createQuantize("quantizeS", slices, slicesTy);
4823   auto *SA = F_->createScatterData("scatterdata", dataQ, indices, slicesQ);
4824   auto *DQ = F_->createDequantize("dequantize", SA, ElemKind::FloatTy);
4825 
4826   auto *result = F_->createSave("save", DQ);
4827   bindings_.allocate(result->getPlaceholder());
4828 
4829   EE_.compile(CompilationMode::Infer);
4830   EE_.run(bindings_);
4831 
4832   auto H = bindings_.get(result->getPlaceholder())->getHandle();
4833 
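  // Rough arithmetic: chooseQuantizationParams({-11, 11}) yields an int8 scale
  // of about 22/255 ~= 0.086, so the 0.05 tolerance below covers the half-step
  // rounding error introduced by quantize/dequantize.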
4834   EXPECT_NEAR(H.at({0, 0}), 1.0, 0.05);
4835   EXPECT_NEAR(H.at({0, 1}), 2.0, 0.05);
4836   EXPECT_NEAR(H.at({1, 0}), -3.0, 0.05);
4837   EXPECT_NEAR(H.at({1, 1}), -4.0, 0.05);
4838   EXPECT_NEAR(H.at({2, 0}), 5.0, 0.05);
4839   EXPECT_NEAR(H.at({2, 1}), 6.0, 0.05);
4840   EXPECT_NEAR(H.at({3, 0}), -7.0, 0.05);
4841   EXPECT_NEAR(H.at({3, 1}), -8.0, 0.05);
4842   EXPECT_NEAR(H.at({4, 0}), 9.0, 0.05);
4843   EXPECT_NEAR(H.at({4, 1}), 10.0, 0.05);
4844 }
4845 
4846 TEST_P(OperatorTest, ScatterDataNDimensionalSimple) {
4847   CHECK_IF_ENABLED();
4848 
4849   // Data = {{1,2},{3,4},{5,6}}
4850   // Slices = {-3,-4}
4851   // Indices = {{1,0},{1,1}}
4852   // Result = {{1,2},{-3,-4},{5,6}}
4853   auto *data = mod_.createPlaceholder(ElemKind::FloatTy, {3, 2}, "data", false);
4854   auto *indices =
4855       mod_.createPlaceholder(ElemKind::Int64ITy, {2, 2}, "indices", false);
4856   auto *slices =
4857       mod_.createPlaceholder(ElemKind::FloatTy, {2}, "slices", false);
4858 
4859   // Fill tensor with consecutive data.
4860   std::vector<float> init(6);
4861   std::iota(init.begin(), init.end(), 1);
4862   bindings_.allocate(data)->getHandle() = init;
4863   bindings_.allocate(indices)->getHandle<int64_t>() = {1, 0, 1, 1};
4864   bindings_.allocate(slices)->getHandle() = {-3., -4.};
4865   auto *R = F_->createScatterData("scatterdata", data, indices, slices);
4866 
4867   auto *result = F_->createSave("save", R);
4868   bindings_.allocate(result->getPlaceholder());
4869 
4870   EE_.compile(CompilationMode::Infer);
4871   EE_.run(bindings_);
4872 
4873   std::vector<dim_t> expectedDims = {3, 2};
4874   std::vector<float> expectedValues = {1., 2., -3., -4., 5., 6.};
4875   auto H = bindings_.get(result->getPlaceholder())->getHandle();
4876   EXPECT_TRUE(H.dims().vec() == expectedDims);
4877   for (dim_t i = 0; i < expectedValues.size(); i++) {
4878     EXPECT_EQ(expectedValues[i], H.raw(i));
4879   }
4880 }
4881 
4882 TEST_P(OperatorTest, ScatterDataNDimensional) {
4883   CHECK_IF_ENABLED();
4884 
4885   // In tensor 2x4x4x3, make two updates with 2-dimensional slices by
4886   // 2-dimensional indices:
4887   // 1. By index [0, 3], set [[-1.,  -2.,  -3.]
4888   //                          [-4.,  -5.,  -6.]
4889   //                          [-7.,  -8.,  -9.]
4890   //                          [-10., -11., -12.]];
4891   //
4892   // 2. By index [1, 1], set [[-13., -14., -15.]
4893   //                          [-16., -17., -18.]
4894   //                          [-19., -20., -21.]
4895   //                          [-22., -23., -24.]];
4896   //
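  // Each row of INDICES addresses the first two dimensions of DATA, and the
  // corresponding {4, 3} slice from SLICES overwrites the addressed sub-tensor.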
4897   auto *data =
4898       mod_.createPlaceholder(ElemKind::FloatTy, {2, 4, 4, 3}, "data", false);
4899   auto *indices =
4900       mod_.createPlaceholder(ElemKind::Int64ITy, {2, 2}, "indices", false);
4901   auto *slices =
4902       mod_.createPlaceholder(ElemKind::FloatTy, {2, 4, 3}, "slices", false);
4903 
4904   // Fill tensor with consecutive data.
4905   std::vector<float> init(2 * 4 * 4 * 3);
4906   std::iota(init.begin(), init.end(), 0);
4907   bindings_.allocate(data)->getHandle() = init;
4908   bindings_.allocate(indices)->getHandle<int64_t>() = {0, 3, 1, 1};
4909   std::vector<float> initUpdates;
4910   for (int32_t i = -1; i > -25; i--) {
4911     initUpdates.push_back(static_cast<float>(i));
4912   }
4913   bindings_.allocate(slices)->getHandle() = initUpdates;
4914 
4915   auto *R = F_->createScatterData("scatterdata", data, indices, slices);
4916 
4917   auto *result = F_->createSave("save", R);
4918   bindings_.allocate(result->getPlaceholder());
4919 
4920   EE_.compile(CompilationMode::Infer);
4921   EE_.run(bindings_);
4922 
4923   std::vector<dim_t> expectedDims = {2, 4, 4, 3};
4924   std::vector<float> expectedValues = {
4925       0.0f,   1.0f,   2.0f,   3.0f,   4.0f,   5.0f,
4926       6.0f,   7.0f,   8.0f,   9.0f,   10.0f,  11.0f,
4927 
4928       12.0f,  13.0f,  14.0f,  15.0f,  16.0f,  17.0f,
4929       18.0f,  19.0f,  20.0f,  21.0f,  22.0f,  23.0f,
4930 
4931       24.0f,  25.0f,  26.0f,  27.0f,  28.0f,  29.0f,
4932       30.0f,  31.0f,  32.0f,  33.0f,  34.0f,  35.0f,
4933 
4934       -1.0f,  -2.0f,  -3.0f,  -4.0f,  -5.0f,  -6.0f,
4935       -7.0f,  -8.0f,  -9.0f,  -10.0f, -11.0f, -12.0f,
4936 
4937       48.0f,  49.0f,  50.0f,  51.0f,  52.0f,  53.0f,
4938       54.0f,  55.0f,  56.0f,  57.0f,  58.0f,  59.0f,
4939 
4940       -13.0f, -14.0f, -15.0f, -16.0f, -17.0f, -18.0f,
4941       -19.0f, -20.0f, -21.0f, -22.0f, -23.0f, -24.0f,
4942 
4943       72.0f,  73.0f,  74.0f,  75.0f,  76.0f,  77.0f,
4944       78.0f,  79.0f,  80.0f,  81.0f,  82.0f,  83.0f,
4945 
4946       84.0f,  85.0f,  86.0f,  87.0f,  88.0f,  89.0f,
4947       90.0f,  91.0f,  92.0f,  93.0f,  94.0f,  95.0f};
4948   auto H = bindings_.get(result->getPlaceholder())->getHandle();
4949   EXPECT_TRUE(H.dims().vec() == expectedDims);
4950   for (dim_t i = 0; i < expectedValues.size(); i++) {
4951     EXPECT_EQ(expectedValues[i], H.raw(i));
4952   }
4953 }
4954 
4955 TEST_P(OperatorTest, ScatterAddQuantized) {
4956   CHECK_IF_ENABLED();
4957 
4958   auto *data = mod_.createPlaceholder(ElemKind::FloatTy, {5, 2}, "data", false);
4959   auto *indices =
4960       mod_.createPlaceholder(ElemKind::Int64ITy, {2, 1}, "indices", false);
4961   auto *slices =
4962       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "slices", false);
4963 
4964   bindings_.allocate(data)->getHandle() = {1, 2, -3, -8, 5, 6, 7, 8, 9, 10};
4965   bindings_.allocate(indices)->getHandle<int64_t>() = {1, 3};
4966   bindings_.allocate(slices)->getHandle() = {3, -8, -7, 8};
4967 
4968   auto qParams = glow::quantization::chooseQuantizationParams({-11, 11});
4969   auto dataTy =
4970       mod_.uniqueType(ElemKind::Int8QTy, {5, 2}, qParams.scale, qParams.offset);
4971   auto slicesTy =
4972       mod_.uniqueType(ElemKind::Int8QTy, {2, 2}, qParams.scale, qParams.offset);
4973 
4974   auto *dataQ = F_->createQuantize("quantizeQ", data, dataTy);
4975   auto *slicesQ = F_->createQuantize("quantizeS", slices, slicesTy);
4976   auto *SA = F_->createScatterData("scatteradd", dataQ, indices, slicesQ,
4977                                    /*Cumulative*/ true);
4978   auto *DQ = F_->createDequantize("dequantize", SA, ElemKind::FloatTy);
4979 
4980   auto *result = F_->createSave("save", DQ);
4981   bindings_.allocate(result->getPlaceholder());
4982 
4983   EE_.compile(CompilationMode::Infer);
4984   EE_.run(bindings_);
4985 
4986   auto H = bindings_.get(result->getPlaceholder())->getHandle();
4987 
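  // Rows 1 and 3 accumulate the slices: {-3, -8} + {3, -8} = {0, -16} and
  // {7, 8} + {-7, 8} = {0, 16}; the +/-16 sums saturate to +/-11, the limits
  // of the chosen quantization range, which is what the checks below expect.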
4988   EXPECT_NEAR(H.at({0, 0}), 1.0, 0.05);
4989   EXPECT_NEAR(H.at({0, 1}), 2.0, 0.05);
4990   EXPECT_NEAR(H.at({1, 0}), 0.0, 0.05);
4991   EXPECT_NEAR(H.at({1, 1}), -11.0, 0.05);
4992   EXPECT_NEAR(H.at({2, 0}), 5.0, 0.05);
4993   EXPECT_NEAR(H.at({2, 1}), 6.0, 0.05);
4994   EXPECT_NEAR(H.at({3, 0}), 0.0, 0.05);
4995   EXPECT_NEAR(H.at({3, 1}), 11.0, 0.05);
4996   EXPECT_NEAR(H.at({4, 0}), 9.0, 0.05);
4997   EXPECT_NEAR(H.at({4, 1}), 10.0, 0.05);
4998 }
4999 
5000 TEST_P(OperatorTest, ScatterAddNDimensionalSimple) {
5001   CHECK_IF_ENABLED();
5002   // Test that scatter addition works.
5003   // Data = {{1,2},{3,4},{5,6}}
5004   // Slices = {-3,-4}
5005   // Indices = {{1,0},{1,1}}
5006   // Result = {{1,2},{0,0},{5,6}}
5007   auto *data = mod_.createPlaceholder(ElemKind::FloatTy, {3, 2}, "data", false);
5008   auto *indices =
5009       mod_.createPlaceholder(ElemKind::Int64ITy, {2, 2}, "indices", false);
5010   auto *slices =
5011       mod_.createPlaceholder(ElemKind::FloatTy, {2}, "slices", false);
5012 
5013   // Fill tensor with consecutive data.
5014   std::vector<float> init;
5015   for (int32_t i = 1; i < 7; i++) {
5016     init.push_back(static_cast<float>(i));
5017   }
5018   bindings_.allocate(data)->getHandle() = init;
5019   bindings_.allocate(indices)->getHandle<int64_t>() = {1, 0, 1, 1};
5020   bindings_.allocate(slices)->getHandle() = {-3., -4.};
5021   auto *R = F_->createScatterData("scatteradd", data, indices, slices,
5022                                   /*Cumulative*/ true);
5023 
5024   auto *result = F_->createSave("save", R);
5025   bindings_.allocate(result->getPlaceholder());
5026 
5027   EE_.compile(CompilationMode::Infer);
5028   EE_.run(bindings_);
5029 
5030   std::vector<dim_t> expectedDims = {3, 2};
5031   std::vector<float> expectedValues = {1., 2., 0., 0., 5., 6.};
5032   auto H = bindings_.get(result->getPlaceholder())->getHandle();
5033   EXPECT_TRUE(H.dims().vec() == expectedDims);
5034   for (dim_t i = 0; i < expectedValues.size(); i++) {
5035     EXPECT_EQ(expectedValues[i], H.raw(i));
5036   }
5037 }
5038 
5039 TEST_P(OperatorTest, ScatterAddNDimensionalDuplicatingIndices) {
5040   CHECK_IF_ENABLED();
5041   // Test that scatter addition with duplicating indices works.
5042   // Data = {{1,2},{3,4},{5,6}}
5043   // Slices = {-3,-4,-3,-4}
5044   // Indices = {{1,0},{1,1},{1,0},{1,1}}
5045   // Result = {{1,2},{-3,-4},{5,6}}
5046   auto *data = mod_.createPlaceholder(ElemKind::FloatTy, {3, 2}, "data", false);
5047   auto *indices =
5048       mod_.createPlaceholder(ElemKind::Int64ITy, {4, 2}, "indices", false);
5049   auto *slices =
5050       mod_.createPlaceholder(ElemKind::FloatTy, {4}, "slices", false);
5051 
5052   // Fill tensor with consecutive data.
5053   std::vector<float> init;
5054   for (int32_t i = 1; i < 7; i++) {
5055     init.push_back(static_cast<float>(i));
5056   }
5057   bindings_.allocate(data)->getHandle() = init;
5058   bindings_.allocate(indices)->getHandle<int64_t>() = {1, 0, 1, 1, 1, 0, 1, 1};
5059   bindings_.allocate(slices)->getHandle() = {-3., -4., -3., -4.};
5060   auto *R = F_->createScatterData("scatteradd", data, indices, slices,
5061                                   /*Cumulative*/ true);
5062 
5063   auto *result = F_->createSave("save", R);
5064   bindings_.allocate(result->getPlaceholder());
5065 
5066   EE_.compile(CompilationMode::Infer);
5067   EE_.run(bindings_);
5068 
5069   std::vector<dim_t> expectedDims = {3, 2};
5070   std::vector<float> expectedValues = {1., 2., -3., -4., 5., 6.};
5071   auto H = bindings_.get(result->getPlaceholder())->getHandle();
5072   EXPECT_TRUE(H.dims().vec() == expectedDims);
5073   for (dim_t i = 0; i < expectedValues.size(); i++) {
5074     EXPECT_EQ(expectedValues[i], H.raw(i));
5075   }
5076 }
5077 
5078 #define COMPARE_ARITH_FUN(_OP_NAME_)                                           \
5079   static FunctionTensorPair createAndInitBasic##_OP_NAME_##Test(               \
5080       glow::PlaceholderBindings &bindings, glow::ExecutionEngine &EE) {        \
5081     auto &mod = EE.getModule();                                                \
5082     Function *F = mod.createFunction("main");                                  \
5083                                                                                \
5084     auto *A = mod.createPlaceholder(ElemKind::FloatTy, {1, 4}, "A", false);    \
5085     auto *B = mod.createPlaceholder(ElemKind::FloatTy, {1, 4}, "B", false);    \
5086     bindings.allocate(A)->getHandle() = {1.0f, -1.2f, 0.5f, -1.3f};            \
5087     bindings.allocate(B)->getHandle() = {1.8f, -0.2f, -2.4f, 2.7f};            \
5088                                                                                \
5089     auto *add = F->create##_OP_NAME_("arith", A, B);                           \
5090     auto *result = F->createSave("save", add);                                 \
5091     auto *resultTensor = bindings.allocate(result->getPlaceholder());          \
5092                                                                                \
5093     return std::make_pair(F, resultTensor);                                    \
5094   }
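// Instantiate one builder per arithmetic op. For example, COMPARE_ARITH_FUN(Add)
// expands to createAndInitBasicAddTest(), which builds F->createAdd("arith", A, B)
// and returns the function together with its result tensor.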
5095 COMPARE_ARITH_FUN(Add)
5096 COMPARE_ARITH_FUN(Sub)
5097 COMPARE_ARITH_FUN(Mul)
5098 COMPARE_ARITH_FUN(Div)
5099 COMPARE_ARITH_FUN(Max)
5100 COMPARE_ARITH_FUN(Min)
5101 #undef COMPARE_ARITH_FUN
5102 
5103 #define COMPARE_ARITH_FLOAT_VS_INT8(_OP_NAME_)                                 \
5104   TEST_P(OperatorStatelessTest, Basic##_OP_NAME_##NetFloatVsInt8) {            \
5105     CHECK_IF_ENABLED();                                                        \
5106     compareAgainstInterpreter(                                                 \
5107         getBackendName(), createAndInitBasic##_OP_NAME_##Test,                 \
5108         ElemKind::FloatTy, ElemKind::Int8QTy, 0.035f, parCloneCountOpt);       \
5109   }
5110 COMPARE_ARITH_FLOAT_VS_INT8(Add)
5111 COMPARE_ARITH_FLOAT_VS_INT8(Sub)
5112 COMPARE_ARITH_FLOAT_VS_INT8(Mul)
5113 COMPARE_ARITH_FLOAT_VS_INT8(Div)
5114 COMPARE_ARITH_FLOAT_VS_INT8(Max)
5115 COMPARE_ARITH_FLOAT_VS_INT8(Min)
5116 #undef COMPARE_ARITH_FLOAT_VS_INT8
5117 
5118 #define COMPARE_ARITH_FLOAT_VS_FLOAT16(_OP_NAME_)                              \
5119   TEST_P(OperatorStatelessTest, Basic##_OP_NAME_##NetFloatVsFloat16) {         \
5120     CHECK_IF_ENABLED();                                                        \
5121     compareAgainstInterpreter(                                                 \
5122         getBackendName(), createAndInitBasic##_OP_NAME_##Test,                 \
5123         ElemKind::FloatTy, ElemKind::Float16Ty, 0.01f, parCloneCountOpt);      \
5124   }
5125 
5126 #define COMPARE_ARITH_FLOAT_VS_BFLOAT16(_OP_NAME_)                             \
5127   TEST_P(OperatorStatelessTest, Basic##_OP_NAME_##NetFloatVsBFloat16) {        \
5128     CHECK_IF_ENABLED();                                                        \
5129     compareAgainstInterpreter(                                                 \
5130         getBackendName(), createAndInitBasic##_OP_NAME_##Test,                 \
5131         ElemKind::FloatTy, ElemKind::BFloat16Ty, 0.01f, parCloneCountOpt);     \
5132   }
5133 COMPARE_ARITH_FLOAT_VS_FLOAT16(Add)
5134 COMPARE_ARITH_FLOAT_VS_FLOAT16(Sub)
5135 COMPARE_ARITH_FLOAT_VS_FLOAT16(Mul)
5136 COMPARE_ARITH_FLOAT_VS_FLOAT16(Div)
5137 COMPARE_ARITH_FLOAT_VS_FLOAT16(Max)
5138 COMPARE_ARITH_FLOAT_VS_FLOAT16(Min)
5139 
5140 COMPARE_ARITH_FLOAT_VS_BFLOAT16(Add)
5141 COMPARE_ARITH_FLOAT_VS_BFLOAT16(Sub)
5142 COMPARE_ARITH_FLOAT_VS_BFLOAT16(Mul)
5143 COMPARE_ARITH_FLOAT_VS_BFLOAT16(Div)
5144 COMPARE_ARITH_FLOAT_VS_BFLOAT16(Max)
5145 COMPARE_ARITH_FLOAT_VS_BFLOAT16(Min)
5146 #undef COMPARE_ARITH_FLOAT_VS_FLOAT16
5147 #undef COMPARE_ARITH_FLOAT_VS_BFLOAT16
5148 
5149 #define ARITH_FUN_IMPL(_OP_NAME_, _REFERENCE_FUNCTION_, _PARENTHESES_)         \
5150   template <typename DataType>                                                 \
5151   static void testArithmetic##_OP_NAME_##Impl(                                 \
5152       glow::PlaceholderBindings &bindings, glow::Module &mod,                  \
5153       glow::Function *F, glow::ExecutionEngine &EE, ElemKind DTy) {            \
5154     std::vector<DataType> data1 = {3, 17, 7, 23};                              \
5155     std::vector<DataType> data2 = {13, 5, 19, 11};                             \
5156     auto *A = mod.createPlaceholder(DTy, {1, 4}, "A", false);                  \
5157     auto *B = mod.createPlaceholder(DTy, {1, 4}, "B", false);                  \
5158     bindings.allocate(A)->getHandle<DataType>() = data1;                       \
5159     bindings.allocate(B)->getHandle<DataType>() = data2;                       \
5160                                                                                \
5161     auto *add = F->create##_OP_NAME_("arith", A, B);                           \
5162     auto *result = F->createSave("save", add);                                 \
5163     auto *resultTensor = bindings.allocate(result->getPlaceholder());          \
5164                                                                                \
5165     EE.compile(CompilationMode::Infer);                                        \
5166     EE.run(bindings);                                                          \
5167     std::vector<DataType> reference;                                           \
5168     assert(data1.size() == data2.size() && "Size mismatch!");                  \
5169     for (size_t i = 0; i < data1.size(); i++) {                                \
5170       reference.push_back(                                                     \
5171           _REFERENCE_FUNCTION_<DataType> _PARENTHESES_(data1[i], data2[i]));   \
5172     }                                                                          \
5173     auto RH = resultTensor->getHandle<DataType>();                             \
5174     EXPECT_EQ(reference.size(), RH.size());                                    \
5175     for (size_t i = 0; i < reference.size(); i++) {                            \
5176       EXPECT_EQ(reference[i], RH.raw(i));                                      \
5177     }                                                                          \
5178   }
5179 
5180 #define ARITH_FUNC_TEST_TYPED(_OP_NAME_, _DATA_TYPE_, _ELEM_KIND_)             \
5181   TEST_P(OperatorTest, Arith##_OP_NAME_##_##_DATA_TYPE_) {                     \
5182     CHECK_IF_ENABLED();                                                        \
5183     testArithmetic##_OP_NAME_##Impl<_DATA_TYPE_>(bindings_, mod_, F_, EE_,     \
5184                                                  _ELEM_KIND_);                 \
5185   }
5186 
5187 #define ARITH_FUNC_TEST(_OP_NAME_, _REFERENCE_FUNCTION_, _PARENTHESES_)        \
5188   ARITH_FUN_IMPL(_OP_NAME_, _REFERENCE_FUNCTION_, _PARENTHESES_)               \
5189   ARITH_FUNC_TEST_TYPED(_OP_NAME_, int32_t, ElemKind::Int32ITy)                \
5190   ARITH_FUNC_TEST_TYPED(_OP_NAME_, int64_t, ElemKind::Int64ITy)                \
5191   ARITH_FUNC_TEST_TYPED(_OP_NAME_, float, ElemKind::FloatTy)                   \
5192   ARITH_FUNC_TEST_TYPED(_OP_NAME_, float16_t, ElemKind::Float16Ty)             \
5193   ARITH_FUNC_TEST_TYPED(_OP_NAME_, bfloat16_t, ElemKind::BFloat16Ty)
5194 
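// Note on _PARENTHESES_: std::plus/std::minus/std::multiplies are function
// objects, so "()" is passed to construct an instance before invoking it,
// i.e. std::plus<DataType>()(a, b); std::max/std::min are plain function
// templates and take no parentheses, i.e. std::max<DataType>(a, b).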
5195 ARITH_FUNC_TEST(Add, std::plus, ())
5196 ARITH_FUNC_TEST(Sub, std::minus, ())
5197 ARITH_FUNC_TEST(Mul, std::multiplies, ())
5198 ARITH_FUNC_TEST(Max, std::max, )
5199 ARITH_FUNC_TEST(Min, std::min, )
5200 #undef ARITH_FUN_IMPL
5201 #undef ARITH_FUNC_TEST_TYPED
5202 #undef ARITH_FUNC_TEST
5203 
5204 TEST_P(OperatorTest, IntMatMul) {
5205   CHECK_IF_ENABLED();
5206 
5207   // The scaling factor 1.4x was carefully selected to make sure we don't
5208   // overflow or underflow the calculation.
5209   TypeRef resTy = mod_.uniqueType(ElemKind::Int8QTy, {3, 3}, 0.60, 4);
5210   TypeRef lhsTy = mod_.uniqueType(ElemKind::Int8QTy, {3, 3}, 0.075, -2);
5211   TypeRef rhsTy = mod_.uniqueType(ElemKind::Int8QTy, {3, 3}, 0.075, 2);
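  // Affine int8 quantization maps q to real = scale * (q - offset). With the
  // parameters above the inputs can represent roughly +/-9.5 and the result
  // roughly [-79.2, 73.8], which covers the operands (|x| <= 9) and the
  // expected products (|x| <= ~70) used below.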
5212 
5213   auto *lhs = mod_.createPlaceholder(ElemKind::FloatTy, {3, 3}, "lhs", false);
5214   auto *rhs = mod_.createPlaceholder(ElemKind::FloatTy, {3, 3}, "rhs", false);
5215 
5216   bindings_.allocate(lhs)->getHandle() = {
5217       1.0, 2.0, 3.0, 4.0, 5.0, -5.0, -4.0, -3.0, 9.0,
5218   };
5219 
5220   bindings_.allocate(rhs)->getHandle() = {
5221       0.1f, -0.2f, 0.3f, 9.0f, -8.0f, 7.0f, 6.0f, 5.0f, 9.0f,
5222   };
5223 
5224   auto *lhsq = F_->createQuantize("lhs.q", lhs, lhsTy);
5225   auto *rhsq = F_->createQuantize("rhs.q", rhs, rhsTy);
5226 
5227   auto *matmulq = F_->createMatMul("matmul.q", resTy, lhsq, rhsq);
5228 
5229   auto *rq = F_->createDequantize("dequant", matmulq, ElemKind::FloatTy);
5230 
5231   auto *result = F_->createSave("save", rq);
5232   bindings_.allocate(result->getPlaceholder());
5233 
5234   EE_.compile(CompilationMode::Infer);
5235   EE_.run(bindings_);
5236 
5237   /*
5238    Test the following matrix multiplication:
5239    A = [[1.0, 2.0, 3.0], [4.0, 5.0, -5.0], [-4.0, -3.0, 9.0]]
5240    B = [[0.1, -0.2, 0.3], [9.0, -8.0, 7.0], [6.0, 5.0, 9.0]]
5241    A x B = [[36.1, -1.2, 41.3], [15.4, -65.8, -8.8], [26.6, 69.8, 58.8]]
5242    */
5243 
5244   auto H = bindings_.get(result->getPlaceholder())->getHandle();
5245   EXPECT_NEAR(H.at({0, 0}), 36.1, 1.0);
5246   EXPECT_NEAR(H.at({0, 1}), -1.2, 1.0);
5247   EXPECT_NEAR(H.at({0, 2}), 41.3, 1.0);
5248   EXPECT_NEAR(H.at({1, 0}), 15.4, 1.0);
5249   EXPECT_NEAR(H.at({1, 1}), -65.8, 1.0);
5250   EXPECT_NEAR(H.at({1, 2}), -8.8, 1.0);
5251   EXPECT_NEAR(H.at({2, 0}), 26.6, 1.0);
5252   EXPECT_NEAR(H.at({2, 1}), 69.8, 1.0);
5253   EXPECT_NEAR(H.at({2, 2}), 58.8, 1.0);
5254 }
5255 
5256 TEST_P(OperatorTest, IntBatchedArith) {
5257   CHECK_IF_ENABLED();
5258 
5259   TypeRef resTy = mod_.uniqueType(ElemKind::Int8QTy, {1, 3, 3}, 0.10, 1.0);
5260   TypeRef lhsTy = mod_.uniqueType(ElemKind::Int8QTy, {1, 3, 3}, 0.11, 4.0);
5261   TypeRef rhsTy = mod_.uniqueType(ElemKind::Int8QTy, {3, 3}, 0.14, -2.0);
5262 
5263   auto *lhs =
5264       mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 3}, "lhs", false);
5265   bindings_.allocate(lhs);
5266   auto *rhs = mod_.createPlaceholder(ElemKind::FloatTy, {3, 3}, "rhs", false);
5267   bindings_.allocate(rhs);
5268 
5269   bindings_.get(lhs)->getHandle() = {
5270       8.7f, 6.5f, 4.3f, 2.1f, 1.0f, -5.1f, -4.0f, -12.0f, 0.2f,
5271   };
5272 
5273   bindings_.get(rhs)->getHandle() = {
5274       -9.1f, -0.4f, 1.3f, 2.2f, -8.1f, 7.6f, -6.4f, 10.0f, 9.1f,
5275   };
5276 
5277   auto *lhsq = F_->createQuantize("lhs.q", lhs, lhsTy);
5278   auto *rhsq = F_->createQuantize("rhs.q", rhs, rhsTy);
5279 
5280   auto *matmulq = F_->createBatchedAdd("add", resTy, lhsq, rhsq);
5281 
5282   auto *rq = F_->createDequantize("dequant", matmulq, ElemKind::FloatTy);
5283 
5284   auto *result = F_->createSave("save", rq);
5285   bindings_.allocate(result->getPlaceholder());
5286   EE_.compile(CompilationMode::Infer);
5287 
5288   EE_.run(bindings_);
5289 
5290   // A = [8.7, 6.5, 4.3, 2.1, 1.0, -5.1, -4.0, -12.0, 0.2]
5291   // B = [-9.1, -0.4, 1.3, 2.2, -8.1, 7.6, -6.4, 10.0, 9.1]
5292   // A + B = [-0.4, 6.1, 5.6, 4.3, -7.1, 2.5, -10.4, -2.0, 9.3]
5293   auto H = bindings_.get(result->getPlaceholder())->getHandle();
5294   constexpr float allowedError = 0.105;
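  // 0.105 is just over one quantization step of the result type (scale 0.10),
  // so each element may land one step away from the exact sum after the
  // quantize/add/dequantize round trip.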
5295   EXPECT_NEAR(H.at({0, 0, 0}), -0.4, allowedError);
5296   EXPECT_NEAR(H.at({0, 0, 1}), 6.1, allowedError);
5297   EXPECT_NEAR(H.at({0, 0, 2}), 5.6, allowedError);
5298   EXPECT_NEAR(H.at({0, 1, 0}), 4.3, allowedError);
5299   EXPECT_NEAR(H.at({0, 1, 1}), -7.1, allowedError);
5300   EXPECT_NEAR(H.at({0, 1, 2}), 2.5, allowedError);
5301   EXPECT_NEAR(H.at({0, 2, 0}), -10.4, allowedError);
5302   EXPECT_NEAR(H.at({0, 2, 1}), -2, allowedError);
5303   EXPECT_NEAR(H.at({0, 2, 2}), 9.3, allowedError);
5304 }
5305 
5306 TEST_P(OperatorTest, convTest) {
5307   CHECK_IF_ENABLED();
5308   auto *input =
5309       mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 3, 1}, "input", false);
5310   auto IH = bindings_.allocate(input)->getHandle();
5311   IH = {1, 1, 1, 1, 1, 1, 1, 1, 1};
5312 
5313   auto filter =
5314       mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 3, 1}, "filter", false);
5315   auto FH = bindings_.allocate(filter)->getHandle();
5316   FH = {0, 0, 0, 1, 1, 1, 0, 0, 0};
5317 
5318   auto *zeroBias =
5319       mod_.createPlaceholder(ElemKind::FloatTy, {1}, "bias", false);
5320   bindings_.allocate(zeroBias)->zero();
5321 
5322   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 3, 3, 1});
5323 
5324   ConvolutionNode *CN =
5325       F_->createConv("Conv", input, filter, zeroBias, outTy, 3, 1, 1, 1);
5326   SaveNode *S = F_->createSave("save", CN);
5327   bindings_.allocate(S->getPlaceholder());
5328 
5329   EE_.compile(CompilationMode::Infer);
5330   EE_.run(bindings_);
5331 
5332   auto result = bindings_.get(S->getPlaceholder());
5333 
5334   Tensor expected(outTy);
5335   expected.getHandle() = {2, 3, 2, 2, 3, 2, 2, 3, 2};
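  // With an all-ones input, a filter whose only nonzero row is the middle row
  // of ones, and pad 1, each output element simply counts the in-bounds
  // columns covered by that row: 2 at the edges, 3 in the middle, giving the
  // {2, 3, 2} pattern in every row.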
5336 
5337   EXPECT_TRUE(expected.isEqual(*result));
5338 }
5339 
5340 TEST_P(OperatorTest, convTest_Float16) {
5341   CHECK_IF_ENABLED();
5342   auto *input =
5343       mod_.createPlaceholder(ElemKind::Float16Ty, {1, 3, 3, 1}, "input", false);
5344   auto IH = bindings_.allocate(input)->getHandle<float16_t>();
5345   IH = {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9};
5346 
5347   auto filter = mod_.createPlaceholder(ElemKind::Float16Ty, {1, 3, 3, 1},
5348                                        "filter", false);
5349   auto FH = bindings_.allocate(filter)->getHandle<float16_t>();
5350   FH = {0.25, 0.5, 0.25, 1, 1, 1, 0.25, 0.5, 0.25};
5351 
5352   auto *zeroBias =
5353       mod_.createPlaceholder(ElemKind::Float16Ty, {1}, "bias", false);
5354   bindings_.allocate(zeroBias)->zero();
5355 
5356   auto outTy = mod_.uniqueType(ElemKind::Float16Ty, {1, 3, 3, 1});
5357 
5358   ConvolutionNode *CN =
5359       F_->createConv("Conv", input, filter, zeroBias, outTy, 3, 1, 1, 1);
5360   SaveNode *S = F_->createSave("save", CN);
5361   bindings_.allocate(S->getPlaceholder());
5362 
5363   EE_.compile(CompilationMode::Infer);
5364   EE_.run(bindings_);
5365 
5366   auto result = bindings_.get(S->getPlaceholder())->getHandle<float16_t>();
5367 
5368   Tensor expected(outTy);
5369   auto expectedH = expected.getHandle<float16_t>();
5370   expectedH = {3.375, 5.102, 3.676, 5.051, 7.5, 5.449, 4.574, 6.898, 4.875};
5371 
5372   for (dim_t x = 0; x < 3; x++) {
5373     for (dim_t y = 0; y < 3; y++) {
5374       EXPECT_NEAR(result.at({0, x, y, 0}), expectedH.at({0, x, y, 0}), 0.001);
5375     }
5376   }
5377 }
5378 
5379 TEST_P(OperatorTest, convTest_BFloat16) {
5380   CHECK_IF_ENABLED();
5381   auto *input = mod_.createPlaceholder(ElemKind::BFloat16Ty, {1, 3, 3, 1},
5382                                        "input", false);
5383   auto IH = bindings_.allocate(input)->getHandle<bfloat16_t>();
5384   IH = {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9};
5385 
5386   auto filter = mod_.createPlaceholder(ElemKind::BFloat16Ty, {1, 3, 3, 1},
5387                                        "filter", false);
5388   auto FH = bindings_.allocate(filter)->getHandle<bfloat16_t>();
5389   FH = {0.25, 0.5, 0.25, 1, 1, 1, 0.25, 0.5, 0.25};
5390 
5391   auto *zeroBias =
5392       mod_.createPlaceholder(ElemKind::BFloat16Ty, {1}, "bias", false);
5393   bindings_.allocate(zeroBias)->zero();
5394 
5395   auto outTy = mod_.uniqueType(ElemKind::BFloat16Ty, {1, 3, 3, 1});
5396 
5397   ConvolutionNode *CN =
5398       F_->createConv("Conv", input, filter, zeroBias, outTy, 3, 1, 1, 1);
5399   SaveNode *S = F_->createSave("save", CN);
5400   bindings_.allocate(S->getPlaceholder());
5401 
5402   EE_.compile(CompilationMode::Infer);
5403   EE_.run(bindings_);
5404 
5405   auto result = bindings_.get(S->getPlaceholder())->getHandle<bfloat16_t>();
5406 
5407   Tensor expected(outTy);
5408   auto expectedH = expected.getHandle<bfloat16_t>();
5409   expectedH = {3.375, 5.102, 3.676, 5.051, 7.5, 5.449, 4.574, 6.898, 4.875};
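  // BFloat16 keeps only 7 mantissa bits (vs. 10 for Float16), i.e. roughly
  // 2-3 significant decimal digits, so the comparison below uses a looser
  // 0.05 tolerance than the 0.001 used in the Float16 version of this test.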
5410 
5411   for (dim_t x = 0; x < 3; x++) {
5412     for (dim_t y = 0; y < 3; y++) {
5413       EXPECT_NEAR(result.at({0, x, y, 0}), expectedH.at({0, x, y, 0}), 0.05);
5414     }
5415   }
5416 }
5417 
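/// Helper that builds a single Conv node (output depth \p convDepth, 5x5
/// kernel, stride 1, no padding) over a randomized {1, 10, 10, 3} input and
/// returns the Function together with its result tensor, for use with
/// compareAgainstInterpreter.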
5418 template <size_t convDepth>
5419 static FunctionTensorPair
5420 createAndInitConvDepthTest(glow::PlaceholderBindings &bindings,
5421                            glow::ExecutionEngine &EE) {
5422   auto &mod = EE.getModule();
5423   Function *F = mod.createFunction("main");
5424 
5425   auto *input =
5426       mod.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 3}, "in", false);
5427   auto *conv = F->createConv(bindings, "conv", input, convDepth, 5, 1, 0, 1);
5428   auto *bias = llvm::cast<Placeholder>(conv->getBias().getNode());
5429 
5430   bindings.allocate(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
5431   bindings.get(bias)->getHandle().randomize(-2.0, 2.0, mod.getPRNG());
5432 
5433   auto *res = F->createSave("save", conv);
5434   ::glow::convertPlaceholdersToConstants(F, bindings,
5435                                          {input, res->getPlaceholder()});
5436   auto *resultTensor = bindings.allocate(res->getPlaceholder());
5437 
5438   return std::make_pair(F, resultTensor);
5439 }
5440 
5441 TEST_P(OperatorStatelessTest, Int8ConvolutionDepth10) {
5442   CHECK_IF_ENABLED();
5443   compareAgainstInterpreter(getBackendName(), createAndInitConvDepthTest<10>,
5444                             ElemKind::FloatTy, ElemKind::Int8QTy, 0.045f,
5445                             parCloneCountOpt);
5446 }
5447 
5448 TEST_P(OperatorStatelessTest, Int16ConvolutionDepth10) {
5449   CHECK_IF_ENABLED();
5450   compareAgainstInterpreter(getBackendName(), createAndInitConvDepthTest<10>,
5451                             ElemKind::FloatTy, ElemKind::Int16QTy, 0.03f,
5452                             parCloneCountOpt);
5453 }
5454 
5455 TEST_P(OperatorStatelessTest, Int8ConvolutionDepth8) {
5456   CHECK_IF_ENABLED();
5457   compareAgainstInterpreter(getBackendName(), createAndInitConvDepthTest<8>,
5458                             ElemKind::FloatTy, ElemKind::Int8QTy, 0.03f,
5459                             parCloneCountOpt);
5460 }
5461 TEST_P(OperatorStatelessTest, Int16ConvolutionDepth8) {
5462   CHECK_IF_ENABLED();
5463   compareAgainstInterpreter(getBackendName(), createAndInitConvDepthTest<8>,
5464                             ElemKind::FloatTy, ElemKind::Int16QTy, 0.03f,
5465                             parCloneCountOpt);
5466 }
5467 
5468 TEST_P(OperatorStatelessTest, FP16ConvolutionDepth10) {
5469   CHECK_IF_ENABLED();
5470   compareAgainstInterpreter(getBackendName(), createAndInitConvDepthTest<10>,
5471                             ElemKind::FloatTy, ElemKind::Float16Ty, 0.015f,
5472                             parCloneCountOpt);
5473 }
5474 
5475 TEST_P(OperatorStatelessTest, BFloat16ConvolutionDepth10) {
5476   CHECK_IF_ENABLED();
5477   compareAgainstInterpreter(getBackendName(), createAndInitConvDepthTest<10>,
5478                             ElemKind::FloatTy, ElemKind::BFloat16Ty, 0.015f,
5479                             parCloneCountOpt);
5480 }
5481 
5482 TEST_P(OperatorStatelessTest, FP16ConvolutionDepth8) {
5483   CHECK_IF_ENABLED();
5484   compareAgainstInterpreter(getBackendName(), createAndInitConvDepthTest<8>,
5485                             ElemKind::FloatTy, ElemKind::Float16Ty, 0.015f,
5486                             parCloneCountOpt);
5487 }
5488 
5489 TEST_P(OperatorStatelessTest, BFloat16ConvolutionDepth8) {
5490   CHECK_IF_ENABLED();
5491   compareAgainstInterpreter(getBackendName(), createAndInitConvDepthTest<8>,
5492                             ElemKind::FloatTy, ElemKind::BFloat16Ty, 0.015f,
5493                             parCloneCountOpt);
5494 }
5495 
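// The following tests additionally pass a bias precision (Int8/Int16/Int32)
// as the trailing ElemKind argument to compareAgainstInterpreter, so bias
// quantization is exercised independently of the activation precision.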
5496 TEST_P(OperatorStatelessTest, ConvolutionDepth10_Int8_BiasInt8) {
5497   ENABLED_BACKENDS("Interpreter", "CPU");
5498   compareAgainstInterpreter(
5499       getBackendName(), createAndInitConvDepthTest<10>, ElemKind::FloatTy,
5500       ElemKind::Int8QTy, 0.03f, parCloneCountOpt,
5501       /* convertToRowwiseQuantization */ false,
5502       quantization::Schema::Asymmetric, ElemKind::Int8QTy);
5503 }
5504 
5505 TEST_P(OperatorStatelessTest, ConvolutionDepth10_Int8_BiasInt32) {
5506   ENABLED_BACKENDS("Interpreter", "CPU");
5507   compareAgainstInterpreter(
5508       getBackendName(), createAndInitConvDepthTest<10>, ElemKind::FloatTy,
5509       ElemKind::Int8QTy, 0.03f, parCloneCountOpt,
5510       /* convertToRowwiseQuantization */ false,
5511       quantization::Schema::Asymmetric, ElemKind::Int32QTy);
5512 }
5513 
5514 TEST_P(OperatorStatelessTest, ConvolutionDepth10_Int16_BiasInt16) {
5515   ENABLED_BACKENDS("Interpreter");
5516   compareAgainstInterpreter(
5517       getBackendName(), createAndInitConvDepthTest<10>, ElemKind::FloatTy,
5518       ElemKind::Int16QTy, 0.0003f, parCloneCountOpt,
5519       /* convertToRowwiseQuantization */ false,
5520       quantization::Schema::Asymmetric, ElemKind::Int16QTy);
5521 }
5522 
5523 TEST_P(OperatorStatelessTest, ConvolutionDepth10_Int16_BiasInt32) {
5524   ENABLED_BACKENDS("Interpreter");
5525   compareAgainstInterpreter(
5526       getBackendName(), createAndInitConvDepthTest<10>, ElemKind::FloatTy,
5527       ElemKind::Int16QTy, 0.0003f, parCloneCountOpt,
5528       /* convertToRowwiseQuantization */ false,
5529       quantization::Schema::Asymmetric, ElemKind::Int32QTy);
5530 }
5531 
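/// Helper that concatenates a randomized {3, 3} and {2, 3} tensor along
/// axis 0 and returns the Function and its result tensor, for use with
/// compareAgainstInterpreter.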
5532 static FunctionTensorPair
5533 createAndInitBasicConcatTest(glow::PlaceholderBindings &bindings,
5534                              glow::ExecutionEngine &EE) {
5535   auto &mod = EE.getModule();
5536   Function *F = mod.createFunction("main");
5537 
5538   auto *A = mod.createPlaceholder(ElemKind::FloatTy, {3, 3}, "A", false);
5539   auto *B = mod.createPlaceholder(ElemKind::FloatTy, {2, 3}, "B", false);
5540   bindings.allocate(A)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
5541   bindings.allocate(B)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
5542 
5543   auto *C = F->createConcat("concat", {A, B}, 0);
5544   auto *res = F->createSave("save", C);
5545   auto *resultTensor = bindings.allocate(res->getPlaceholder());
5546 
5547   ::glow::convertPlaceholdersToConstants(F, bindings,
5548                                          {A, B, res->getPlaceholder()});
5549 
5550   return std::make_pair(F, resultTensor);
5551 }
5552 
5553 TEST_P(OperatorStatelessTest, IntConcat) {
5554   CHECK_IF_ENABLED();
5555   compareAgainstInterpreter(getBackendName(), createAndInitBasicConcatTest,
5556                             ElemKind::FloatTy, ElemKind::Int8QTy, 0.05f,
5557                             parCloneCountOpt);
5558 }
5559 
5560 TEST_P(OperatorTest, FCWithFlatten) {
5561   CHECK_IF_ENABLED();
5562 
5563   auto *input =
5564       mod_.createPlaceholder(ElemKind::FloatTy, {2, 1, 3}, "input", false);
5565   Constant *weights = mod_.createConstant(ElemKind::FloatTy, {3, 4}, "weights");
5566   Constant *bias = mod_.createConstant(ElemKind::FloatTy, {4}, "bias");
5567 
5568   bindings_.allocate(input)->getHandle() = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
5569   weights->getPayloadMutable().getHandle() = {1.0f, 4.0f, 7.0f, 10.0f, //
5570                                               2.0f, 5.0f, 8.0f, 11.0f, //
5571                                               3.0f, 6.0f, 9.0f, 12.0f};
5572   bias->getPayloadMutable().getHandle() = {0.1f, 0.2f, 0.3f, 0.4f};
5573 
5574   auto *FC = F_->createFullyConnected("fc", input, weights, bias);
5575   auto *S = F_->createSave("save", FC);
5576   bindings_.allocate(S->getPlaceholder());
5577 
5578   EE_.compile(CompilationMode::Infer);
5579   EE_.run(bindings_);
5580 
5581   auto result = bindings_.get(S->getPlaceholder())->getHandle();
5582   std::vector<dim_t> expectedDimensions = {2, 4};
5583   std::vector<float> expectedValues = {14.1f, 32.2f, 50.3f,  68.4f,
5584                                        32.1f, 77.2f, 122.3f, 167.4f};
5585   EXPECT_TRUE(result.dims().vec() == expectedDimensions);
5586   for (size_t i = 0; i < 2 * 4; i++) {
5587     EXPECT_FLOAT_EQ(result.raw(i), expectedValues[i]);
5588   }
5589 }
5590 
5591 TEST_P(OperatorTest, TestFP32Accumulator) {
5592   CHECK_IF_ENABLED();
5593   auto *input =
5594       mod_.createPlaceholder(ElemKind::Float16Ty, {1, 3}, "input", false);
5595   Constant *weights =
5596       mod_.createConstant(ElemKind::Float16Ty, {3, 2}, "weights");
5597   Constant *bias = mod_.createConstant(ElemKind::Float16Ty, {2}, "bias");
5598 
5599   /* The FP16 spacing (ulp) just above 1.0 is 2^-10 ~= 9.766e-4. For
5600     output 0 the weights are 0.5, so each partial product
5601     (9.7e-4 * 0.5 ~= 4.85e-4) is below half an ulp and is rounded away
5602     when accumulating in FP16, leaving exactly 1.0. Accumulating in FP32
5603     keeps both updates, giving ~1.00097, which rounds to 1.00098 in FP16.
5604   */
5605   bindings_.allocate(input)->getHandle<float16_t>() = {1.0f, 9.7e-4, 9.7e-4f};
5606   weights->getPayloadMutable().getHandle<float16_t>() = {1.0f, 1.0f, 0.5f,
5607                                                          1.0f, 0.5f, 1.0f};
5608   bias->getPayloadMutable().getHandle<float16_t>() = {0.0f, 0.0f};
5609 
5610   auto *FC = F_->createFullyConnected("fc", input, weights, bias);
5611   auto *S = F_->createSave("save", FC);
5612   bindings_.allocate(S->getPlaceholder());
5613   EE_.compile(CompilationMode::Infer);
5614   EE_.run(bindings_);
5615   auto result = bindings_.get(S->getPlaceholder())->getHandle<float16_t>();
5616   std::vector<dim_t> expectedDimensions = {1, 2};
5617 
5618   EXPECT_TRUE(result.dims().vec() == expectedDimensions);
5619   float finalResult = result.raw(0);
5620   if (finalResult == 1.0) {
5621     llvm::outs() << "fp16 accumulator\n";
5622   } else if (fabs(finalResult - 1.00098) < 1e-3) {
5623     llvm::outs() << "fp32 accumulator\n";
5624   } else {
5625     // Unhandled case
5626     FAIL() << "unknown " << finalResult;
5627   }
5628   llvm::outs().flush();
5629 }
5630 
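/// Helper that builds a single FullyConnected layer with 30 output channels
/// over a randomized {1, 10, 10, 3} input and returns the Function and its
/// result tensor, for use with compareAgainstInterpreter.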
5631 static FunctionTensorPair
5632 createAndInitBasicFCTest(glow::PlaceholderBindings &bindings,
5633                          glow::ExecutionEngine &EE) {
5634   auto &mod = EE.getModule();
5635   Function *F = mod.createFunction("main");
5636 
5637   auto *input =
5638       mod.createPlaceholder(ElemKind::FloatTy, {1, 10, 10, 3}, "in", false);
5639   auto *fc = F->createFullyConnected(bindings, "FC", input, 30);
5640 
5641   auto *weights = llvm::cast<Placeholder>(fc->getWeights());
5642   auto *bias = llvm::cast<Placeholder>(fc->getBias());
5643 
5644   bindings.allocate(input)->getHandle().randomize(-0.5, 0.5, mod.getPRNG());
5645   bindings.get(bias)->getHandle().randomize(0, 0.00001, mod.getPRNG());
5646   bindings.get(weights)->getHandle().randomize(-0.7, 0.7, mod.getPRNG());
5647 
5648   auto *res = F->createSave("save", fc);
5649   ::glow::convertPlaceholdersToConstants(F, bindings,
5650                                          {input, res->getPlaceholder()});
5651   auto *resultTensor = bindings.allocate(res->getPlaceholder());
5652 
5653   return std::make_pair(F, resultTensor);
5654 }
5655 
5656 TEST_P(OperatorStatelessTest, IntFC) {
5657   CHECK_IF_ENABLED();
5658   compareAgainstInterpreter(getBackendName(), createAndInitBasicFCTest,
5659                             ElemKind::FloatTy, ElemKind::Int8QTy, 0.05f,
5660                             parCloneCountOpt);
5661 }
5662 
5663 /// Test FC with Float16.
5664 TEST_P(OperatorStatelessTest, FC_Float16) {
5665   CHECK_IF_ENABLED();
5666   compareAgainstInterpreter(getBackendName(), createAndInitBasicFCTest,
5667                             ElemKind::FloatTy, ElemKind::Float16Ty, 0.02f,
5668                             parCloneCountOpt);
5669 }
5670 
5671 /// Test FC with BFloat16.
5672 TEST_P(OperatorStatelessTest, FC_BFloat16) {
5673   CHECK_IF_ENABLED();
5674   compareAgainstInterpreter(getBackendName(), createAndInitBasicFCTest,
5675                             ElemKind::FloatTy, ElemKind::BFloat16Ty, 0.02f,
5676                             parCloneCountOpt);
5677 }
5678 
5679 /// Test Int8 FullyConnected with Int8 bias.
5680 TEST_P(OperatorStatelessTest, FullyConnected_Int8_BiasInt8) {
5681   ENABLED_BACKENDS("Interpreter", "CPU");
5682   compareAgainstInterpreter(
5683       getBackendName(), createAndInitBasicFCTest, ElemKind::FloatTy,
5684       ElemKind::Int8QTy, 0.05f, parCloneCountOpt,
5685       /* convertToRowwiseQuantization */ false,
5686       quantization::Schema::Asymmetric, ElemKind::Int8QTy);
5687 }
5688 
5689 /// Test Int8 FullyConnected with Int32 bias.
5690 TEST_P(OperatorStatelessTest, FullyConnected_Int8_BiasInt32) {
5691   ENABLED_BACKENDS("Interpreter", "CPU", "NNPI");
5692   compareAgainstInterpreter(
5693       getBackendName(), createAndInitBasicFCTest, ElemKind::FloatTy,
5694       ElemKind::Int8QTy, 0.05f, parCloneCountOpt,
5695       /* convertToRowwiseQuantization */ false,
5696       quantization::Schema::Asymmetric, ElemKind::Int32QTy);
5697 }
5698 
5699 /// Test Int16 FullyConnected with Int16 bias.
5700 TEST_P(OperatorStatelessTest, FullyConnected_Int16_BiasInt16) {
5701   ENABLED_BACKENDS("Interpreter");
5702   compareAgainstInterpreter(
5703       getBackendName(), createAndInitBasicFCTest, ElemKind::FloatTy,
5704       ElemKind::Int16QTy, 0.0005f, parCloneCountOpt,
5705       /* convertToRowwiseQuantization */ false,
5706       quantization::Schema::Asymmetric, ElemKind::Int16QTy);
5707 }
5708 
5709 /// Test Int16 FullyConnected with Int32 bias.
5710 TEST_P(OperatorStatelessTest, FullyConnected_Int16_BiasInt32) {
5711   ENABLED_BACKENDS("Interpreter");
5712   compareAgainstInterpreter(
5713       getBackendName(), createAndInitBasicFCTest, ElemKind::FloatTy,
5714       ElemKind::Int16QTy, 0.0005f, parCloneCountOpt,
5715       /* convertToRowwiseQuantization */ false,
5716       quantization::Schema::Asymmetric, ElemKind::Int32QTy);
5717 }
5718 
5719 TEST_P(OperatorTest, EntropyLossTest) {
5720   CHECK_IF_ENABLED();
5721 
5722   auto *P = mod_.createPlaceholder(ElemKind::FloatTy, {2, 3}, "P", false);
5723   auto *Y = mod_.createPlaceholder(ElemKind::Int64ITy, {2}, "Y", false);
5724 
5725   bindings_.allocate(P)->getHandle() = {0.2f, 0.5f, 0.3f, 0.4f, 0.3f, 0.3f};
5726   bindings_.allocate(Y)->getHandle<int64_t>() = {1, 2};
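  // CrossEntropyLoss sums -log(P[i, Y[i]]) over the batch; with the labels
  // {1, 2} above that is -log(0.5) - log(0.3), which is what we check below.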
5727   auto *ceLoss = F_->createCrossEntropyLoss("CELoss", P, Y);
5728   auto *L = F_->createSave("save", ceLoss);
5729   bindings_.allocate(L->getPlaceholder());
5730 
5731   EE_.compile(CompilationMode::Infer);
5732   EE_.run(bindings_);
5733 
5734   auto R = bindings_.get(L->getPlaceholder())->getHandle();
5735   EXPECT_NEAR(R.at({0}), -log(0.5) - log(0.3), 0.1);
5736 }
5737 
5738 /// Check that the max operator works properly with FP16.
5739 TEST_P(OperatorTest, FP16Max) {
5740   CHECK_IF_ENABLED();
5741 
5742   PseudoRNG PRNG;
5743 
5744   auto *inputA =
5745       mod_.createPlaceholder(ElemKind::Float16Ty, {1, 3, 3, 1}, "A", false);
5746   bindings_.allocate(inputA)->getHandle<float16_t>().randomize(-3.0, 3.0, PRNG);
5747   auto *inputB =
5748       mod_.createPlaceholder(ElemKind::Float16Ty, {1, 3, 3, 1}, "B", false);
5749   bindings_.allocate(inputB)->getHandle<float16_t>().randomize(-3.0, 3.0, PRNG);
5750   auto *Max = F_->createMax("max", inputA, inputB);
5751   auto *S = F_->createSave("save", Max);
5752   bindings_.allocate(S->getPlaceholder());
5753 
5754   EE_.compile(CompilationMode::Infer);
5755   EE_.run(bindings_);
5756 
5757   auto result = bindings_.get(S->getPlaceholder())->getHandle<float16_t>();
5758   auto handleA = bindings_.get(inputA)->getHandle<float16_t>();
5759   auto handleB = bindings_.get(inputB)->getHandle<float16_t>();
5760   ASSERT_EQ(result.size(), handleA.size());
5761   for (size_t idx = 0, end = result.size(); idx != end; ++idx) {
5762     EXPECT_EQ(result.raw(idx), std::max(handleA.raw(idx), handleB.raw(idx)));
5763   }
5764 }
5765 
5766 /// Check that the max operator works properly with BFloat16.
5767 TEST_P(OperatorTest, BFloat16Max) {
5768   CHECK_IF_ENABLED();
5769 
5770   PseudoRNG PRNG;
5771 
5772   auto *inputA =
5773       mod_.createPlaceholder(ElemKind::BFloat16Ty, {1, 3, 3, 1}, "A", false);
5774   bindings_.allocate(inputA)->getHandle<bfloat16_t>().randomize(-3.0, 3.0,
5775                                                                 PRNG);
5776   auto *inputB =
5777       mod_.createPlaceholder(ElemKind::BFloat16Ty, {1, 3, 3, 1}, "B", false);
5778   bindings_.allocate(inputB)->getHandle<bfloat16_t>().randomize(-3.0, 3.0,
5779                                                                 PRNG);
5780   auto *Max = F_->createMax("max", inputA, inputB);
5781   auto *S = F_->createSave("save", Max);
5782   bindings_.allocate(S->getPlaceholder());
5783 
5784   EE_.compile(CompilationMode::Infer);
5785   EE_.run(bindings_);
5786 
5787   auto result = bindings_.get(S->getPlaceholder())->getHandle<bfloat16_t>();
5788   auto handleA = bindings_.get(inputA)->getHandle<bfloat16_t>();
5789   auto handleB = bindings_.get(inputB)->getHandle<bfloat16_t>();
5790   ASSERT_EQ(result.size(), handleA.size());
5791   for (size_t idx = 0, end = result.size(); idx != end; ++idx) {
5792     EXPECT_EQ(result.raw(idx), std::max(handleA.raw(idx), handleB.raw(idx)));
5793   }
5794 }
5795 
5796 /// Helper to test Broadcast Max/Min using \p DTy and \p NTy
5797 template <typename DataType, typename NodeType>
5798 static void testBroadcastMaxMin(glow::PlaceholderBindings &bindings,
5799                                 glow::Module &mod, glow::Function *F,
5800                                 glow::ExecutionEngine &EE, ElemKind DTy) {
5801 
5802   auto *inputA = mod.createPlaceholder(DTy, {1, 3, 3, 1}, "A", false);
5803   bindings.allocate(inputA)->getHandle<DataType>().randomize(-3.0, 3.0,
5804                                                              mod.getPRNG());
5805   auto *inputB = mod.createPlaceholder(DTy, {1, 3, 3, 1}, "B", false);
5806   bindings.allocate(inputB)->getHandle<DataType>().randomize(-3.0, 3.0,
5807                                                              mod.getPRNG());
5808 
5809   Node *maxorMinOp = F->createNodeWithBroadcast<NodeType>(
5810       "maxormin", -1 /*axis */, inputA, inputB);
5811 
5812   auto *S = F->createSave("save", maxorMinOp);
5813   bindings.allocate(S->getPlaceholder());
5814 
5815   EE.compile(CompilationMode::Infer);
5816   EE.run(bindings);
5817 
5818   ASSERT_TRUE(F->verify(&EE.getBackend()))
5819       << "Function must pass verification.";
5820 
5821   auto result = bindings.get(S->getPlaceholder())->getHandle<DataType>();
5822   auto handleA = bindings.get(inputA)->getHandle<DataType>();
5823   auto handleB = bindings.get(inputB)->getHandle<DataType>();
5824   ASSERT_EQ(result.size(), handleA.size());
5825   for (size_t idx = 0, end = result.size(); idx != end; ++idx) {
5826     if (std::is_same<NodeType, MaxNode>::value) {
5827       EXPECT_EQ(result.raw(idx), std::max(handleA.raw(idx), handleB.raw(idx)));
5828     } else {
5829       EXPECT_EQ(result.raw(idx), std::min(handleA.raw(idx), handleB.raw(idx)));
5830     }
5831   }
5832 }
5833 
5834 TEST_P(OperatorTest, BroadCastMax) {
5835   CHECK_IF_ENABLED();
5836   testBroadcastMaxMin<int64_t, MaxNode>(bindings_, mod_, F_, EE_,
5837                                         ElemKind::Int64ITy);
5838 }
5839 
5840 TEST_P(OperatorTest, BroadCastMin) {
5841   CHECK_IF_ENABLED();
5842   testBroadcastMaxMin<int64_t, MinNode>(bindings_, mod_, F_, EE_,
5843                                         ElemKind::Int64ITy);
5844 }
5845 
5846 TEST_P(OperatorTest, RescaleNode) {
5847   CHECK_IF_ENABLED();
5848 
5849   // Check the outputs of the RescaleQuantized operation.
5850   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.4, -3,
5851                                        "input", false);
5852   bindings_.allocate(input)->init(Tensor::InitKind::Broadcast, 40,
5853                                   mod_.getPRNG());
5854 
5855   auto T1 = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.7, 5);
5856   auto T2 = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.3, -4);
5857   auto resTy = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.4, -4);
5858 
5859   // Test a sequence of rescale operations that the optimizer may try to
5860   // optimize at some point.
5861   auto *X = F_->createRescaleQuantized("R1", input, T1);
5862   auto *Y = F_->createRescaleQuantized("R2", X, T2);
5863   auto *Z = F_->createRescaleQuantized("R3", Y, resTy);
5864 
5865   auto *output = F_->createSave("save", Z);
5866   bindings_.allocate(output->getPlaceholder());
5867 
5868   EE_.compile(CompilationMode::Infer);
5869   EE_.run(bindings_);
5870 
5871   auto RI = bindings_.get(input)->getHandle<int8_t>();
5872   auto RO = bindings_.get(output->getPlaceholder())->getHandle<int8_t>();
5873 
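  // The input's raw value 40 under (scale 0.4, offset -3) represents
  // 0.4 * (40 + 3) = 17.2; re-encoded in the output type (scale 0.4,
  // offset -4) that is 17.2 / 0.4 - 4 = 39. The intermediate rescales add
  // only a little rounding, hence the +/-1 check against 40 below.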
5874   EXPECT_EQ(RI.raw(0), 40);
5875   EXPECT_NEAR(RO.raw(0), 40, 1);
5876 }
5877 
5878 TEST_P(OperatorTest, QuantizedArithmeticRescaled) {
5879   CHECK_IF_ENABLED();
5880 
5881   const dim_t len = 100;
5882 
5883   // In this test we check the correctness of the quantized Max, Min, Add,
5884   // Sub, Mul, and Div nodes as well as how they interact with the rescaling
5885   // node.
5886   auto *A = mod_.createPlaceholder(ElemKind::FloatTy, {len}, "A", false);
5887   auto *B = mod_.createPlaceholder(ElemKind::FloatTy, {len}, "B", false);
5888   auto *C = mod_.createPlaceholder(ElemKind::FloatTy, {len}, "C", false);
5889 
5890   auto AH = bindings_.allocate(A)->getHandle();
5891   auto BH = bindings_.allocate(B)->getHandle();
5892   auto CH = bindings_.allocate(C)->getHandle();
5893 
5894   AH.randomize(-10, 10, mod_.getPRNG());
5895   BH.randomize(-10, 10, mod_.getPRNG());
5896   // Below, randomize between 1 and 10 to avoid division by 0 later.
5897   CH.randomize(1, 10, mod_.getPRNG());
5898 
5899   auto TA = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.2, 0);
5900   auto TB = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.1, 0);
5901   auto TC = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.3, 0);
5902 
5903   auto TI1 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.1, 0);
5904   auto TI2 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.8, 0);
5905   auto TI3 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.9, 0);
5906   auto TI4 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.0, 0);
5907   auto TI5 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.2, 0);
5908   auto TI6 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.7, 0);
5909 
5910   auto TO1 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.0, 0);
5911   auto TO2 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.9, 0);
5912   auto TO3 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.1, 0);
5913   auto TO4 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.2, 0);
5914   auto TO5 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.0, 0);
5915   auto TO6 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.1, 0);
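  // TI1..TI6 are the quantized output types produced directly by each
  // arithmetic node; TO1..TO6 are the types the results are rescaled to
  // before being dequantized and compared against the float reference.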
5916 
5917   // Quantize input vars and apply max/min/add/sub/mul/div quantized.
5918   auto *QA = F_->createQuantize("QA", A, TA);
5919   auto *QB = F_->createQuantize("QB", B, TB);
5920   auto *QC = F_->createQuantize("QC", C, TC);
5921 
5922   Node *max = F_->createMax("max", TI1, QA, QB);
5923   Node *min = F_->createMin("min", TI2, QA, QB);
5924   Node *add = F_->createAdd("add", TI3, QA, QB);
5925   Node *sub = F_->createSub("sub", TI4, QA, QB);
5926   Node *mul = F_->createMul("mul", TI5, QA, QB);
5927   Node *div = F_->createDiv("div", TI6, QB, QC);
5928 
5929   // Rescale quantized results.
5930   max = F_->createRescaleQuantized("rescaleMax", max, TO1);
5931   min = F_->createRescaleQuantized("rescaleMin", min, TO2);
5932   add = F_->createRescaleQuantized("rescaleAdd", add, TO3);
5933   sub = F_->createRescaleQuantized("rescaleSub", sub, TO4);
5934   mul = F_->createRescaleQuantized("rescaleMul", mul, TO5);
5935   div = F_->createRescaleQuantized("rescaleDiv", div, TO6);
5936 
5937   // Dequantize results back to floating-point.
5938   max = F_->createDequantize("maxDQ", max, ElemKind::FloatTy);
5939   min = F_->createDequantize("minDQ", min, ElemKind::FloatTy);
5940   add = F_->createDequantize("addDQ", add, ElemKind::FloatTy);
5941   sub = F_->createDequantize("subDQ", sub, ElemKind::FloatTy);
5942   mul = F_->createDequantize("mulDQ", mul, ElemKind::FloatTy);
5943   div = F_->createDequantize("divDQ", div, ElemKind::FloatTy);
5944 
5945   // Save results of the operations.
5946   auto *O1 = F_->createSave("saveMax", max);
5947   auto *O2 = F_->createSave("saveMin", min);
5948   auto *O3 = F_->createSave("saveAdd", add);
5949   auto *O4 = F_->createSave("saveSub", sub);
5950   auto *O5 = F_->createSave("saveMul", mul);
5951   auto *O6 = F_->createSave("saveDiv", div);
5952 
5953   bindings_.allocate(O1->getPlaceholder());
5954   bindings_.allocate(O2->getPlaceholder());
5955   bindings_.allocate(O3->getPlaceholder());
5956   bindings_.allocate(O4->getPlaceholder());
5957   bindings_.allocate(O5->getPlaceholder());
5958   bindings_.allocate(O6->getPlaceholder());
5959 
5960   EE_.compile(CompilationMode::Infer);
5961   EE_.run(bindings_);
5962 
5963   for (dim_t i = 0; i < len; i++) {
5964     auto max = std::max(AH.at({i}), BH.at({i}));
5965     auto min = std::min(AH.at({i}), BH.at({i}));
5966     auto add = AH.at({i}) + BH.at({i});
5967     auto sub = AH.at({i}) - BH.at({i});
5968     auto mul = AH.at({i}) * BH.at({i});
5969     auto div = BH.at({i}) / CH.at({i});
5970 
5971     // We generate numbers up to 110, so a difference of 2 (~2%) is
5972     // reasonable.
5973     EXPECT_NEAR(max, bindings_.get(O1->getPlaceholder())->getHandle().at({i}),
5974                 2.0);
5975     EXPECT_NEAR(min, bindings_.get(O2->getPlaceholder())->getHandle().at({i}),
5976                 2.0);
5977     EXPECT_NEAR(add, bindings_.get(O3->getPlaceholder())->getHandle().at({i}),
5978                 2.0);
5979     EXPECT_NEAR(sub, bindings_.get(O4->getPlaceholder())->getHandle().at({i}),
5980                 2.0);
5981     EXPECT_NEAR(mul, bindings_.get(O5->getPlaceholder())->getHandle().at({i}),
5982                 2.0);
5983     EXPECT_NEAR(div, bindings_.get(O6->getPlaceholder())->getHandle().at({i}),
5984                 2.0);
5985   }
5986 }
5987 
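/// Helper that transposes a small {2, 3} tensor with fixed values and returns
/// the Function and its result tensor, for use with compareAgainstInterpreter.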
5988 static FunctionTensorPair
5989 createAndInitTransposeNet(glow::PlaceholderBindings &bindings,
5990                           glow::ExecutionEngine &EE) {
5991   auto &mod = EE.getModule();
5992   Function *F = mod.createFunction("main");
5993 
5994   auto *A = mod.createPlaceholder(ElemKind::FloatTy, {2, 3}, "A", false);
5995   bindings.allocate(A)->getHandle() = {1, 1.2f, 0.5f, 1.3f, 2.7f, 3.1f};
5996   auto *tr = F->createTranspose("Tr", A, {1, 0});
5997   auto *result = F->createSave("Ret", tr);
5998   auto *resultTensor = bindings.allocate(result->getPlaceholder());
5999 
6000   return std::make_pair(F, resultTensor);
6001 }
6002 
6003 TEST_P(OperatorStatelessTest, QuantizedTranspose) {
6004   CHECK_IF_ENABLED();
6005   compareAgainstInterpreter(getBackendName(), createAndInitTransposeNet,
6006                             ElemKind::FloatTy, ElemKind::Int8QTy, 0.0045f,
6007                             parCloneCountOpt);
6008 }
6009 
6010 TEST_P(OperatorTest, QuantizedArithmeticUnrescaled) {
6011   CHECK_IF_ENABLED();
6012 
6013   const dim_t len = 1000;
6014 
6015   // In this test we check the correctness of the quantized Max, Min, Add,
6016   // Sub, Mul, and Div operations.
6017   auto TQA = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.1, -1);
6018   auto TQB = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.9, 2);
6019   // For TQC, set offset to -11 to avoid division by 0 later.
6020   auto TQC = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.2, -11);
6021   auto TO1 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.4, 3);
6022   auto TO2 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.8, 2);
6023   auto TO3 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.7, 5);
6024   auto TO4 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.3, -7);
6025   auto TO5 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.2, 3);
6026   auto TO6 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.0, -2);
6027 
6028   auto *QA = mod_.createPlaceholder(ElemKind::Int8QTy, {len}, TQA->getScale(),
6029                                     TQA->getOffset(), "QA", false);
6030   auto *QB = mod_.createPlaceholder(ElemKind::Int8QTy, {len}, TQB->getScale(),
6031                                     TQB->getOffset(), "QB", false);
6032   auto *QC = mod_.createPlaceholder(ElemKind::Int8QTy, {len}, TQC->getScale(),
6033                                     TQC->getOffset(), "QC", false);
6034 
6035   bindings_.allocate(QA)->getHandle<int8_t>().randomize(-10, 10,
6036                                                         mod_.getPRNG());
6037   bindings_.allocate(QB)->getHandle<int8_t>().randomize(-10, 10,
6038                                                         mod_.getPRNG());
6039   bindings_.allocate(QC)->getHandle<int8_t>().randomize(-10, 10,
6040                                                         mod_.getPRNG());
6041 
6042   // Apply max/min/add/sub/mul/div quantized.
6043   Node *max = F_->createMax("max", TO1, QA, QB);
6044   Node *min = F_->createMin("min", TO2, QA, QB);
6045   Node *add = F_->createAdd("add", TO3, QA, QB);
6046   Node *sub = F_->createSub("sub", TO4, QA, QB);
6047   Node *mul = F_->createMul("mul", TO5, QA, QB);
6048   Node *div = F_->createDiv("div", TO6, QB, QC);
6049 
6050   // Save results of the operations.
6051   auto *O1 = F_->createSave("saveMax", max);
6052   auto *O2 = F_->createSave("saveMin", min);
6053   auto *O3 = F_->createSave("saveAdd", add);
6054   auto *O4 = F_->createSave("saveSub", sub);
6055   auto *O5 = F_->createSave("saveMul", mul);
6056   auto *O6 = F_->createSave("saveDiv", div);
6057 
6058   bindings_.allocate(O1->getPlaceholder());
6059   bindings_.allocate(O2->getPlaceholder());
6060   bindings_.allocate(O3->getPlaceholder());
6061   bindings_.allocate(O4->getPlaceholder());
6062   bindings_.allocate(O5->getPlaceholder());
6063   bindings_.allocate(O6->getPlaceholder());
6064 
6065   auto QAH = bindings_.get(QA)->getHandle<int8_t>();
6066   auto QBH = bindings_.get(QB)->getHandle<int8_t>();
6067   auto QCH = bindings_.get(QC)->getHandle<int8_t>();
6068   auto O1H = bindings_.get(O1->getPlaceholder())->getHandle<int8_t>();
6069   auto O2H = bindings_.get(O2->getPlaceholder())->getHandle<int8_t>();
6070   auto O3H = bindings_.get(O3->getPlaceholder())->getHandle<int8_t>();
6071   auto O4H = bindings_.get(O4->getPlaceholder())->getHandle<int8_t>();
6072   auto O5H = bindings_.get(O5->getPlaceholder())->getHandle<int8_t>();
6073   auto O6H = bindings_.get(O6->getPlaceholder())->getHandle<int8_t>();
6074 
6075   EE_.compile(CompilationMode::Infer);
6076   EE_.run(bindings_);
6077 
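  // Build the reference in float: dequantize each input as
  // real = scale * (q - offset), apply the op, then re-encode the expected
  // result with the corresponding output type's scale and offset.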
6078   for (dim_t i = 0; i < len; i++) {
6079     float a = TQA->getScale() * (QAH.at({i}) - TQA->getOffset());
6080     float b = TQB->getScale() * (QBH.at({i}) - TQB->getOffset());
6081     float c = TQC->getScale() * (QCH.at({i}) - TQC->getOffset());
6082     float max = std::max(a, b) / TO1->getScale() + TO1->getOffset();
6083     float min = std::min(a, b) / TO2->getScale() + TO2->getOffset();
6084     float add = (a + b) / TO3->getScale() + TO3->getOffset();
6085     float sub = (a - b) / TO4->getScale() + TO4->getOffset();
6086     float mul = (a * b) / TO5->getScale() + TO5->getOffset();
6087     float div = (b / c) / TO6->getScale() + TO6->getOffset();
6088 
6089     EXPECT_NEAR(std::round(max), O1H.at({i}), 1.0);
6090     EXPECT_NEAR(std::round(min), O2H.at({i}), 1.0);
6091     EXPECT_NEAR(std::round(add), O3H.at({i}), 1.0);
6092     EXPECT_NEAR(std::round(sub), O4H.at({i}), 1.0);
6093     EXPECT_NEAR(std::round(mul), O5H.at({i}), 1.0);
6094     EXPECT_NEAR(std::round(div), O6H.at({i}), 1.0);
6095   }
6096 }
6097 
6098 TEST_P(OperatorTest, QuantizedCmpLTEAndSelect) {
6099   CHECK_IF_ENABLED();
6100 
6101   // In this test we check the correctness of the quantized
6102   // less-than-or-equal-to comparison operator feeding a quantized select.
6103   const dim_t len = 1000;
6104   auto TQA = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.1, -3);
6105   auto TQB = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.9, 5);
6106   auto TQC = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.8, 3);
6107   auto TQD = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.2, -4);
6108   auto OT = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.5, -2);
6109 
6110   auto *QA = mod_.createPlaceholder(ElemKind::Int8QTy, {len}, TQA->getScale(),
6111                                     TQA->getOffset(), "QA", false);
6112   auto *QB = mod_.createPlaceholder(ElemKind::Int8QTy, {len}, TQB->getScale(),
6113                                     TQB->getOffset(), "QB", false);
6114   auto *QC = mod_.createPlaceholder(ElemKind::Int8QTy, {len}, TQC->getScale(),
6115                                     TQC->getOffset(), "QC", false);
6116   auto *QD = mod_.createPlaceholder(ElemKind::Int8QTy, {len}, TQD->getScale(),
6117                                     TQD->getOffset(), "QD", false);
6118 
6119   auto QAH = bindings_.allocate(QA)->getHandle<int8_t>();
6120   auto QBH = bindings_.allocate(QB)->getHandle<int8_t>();
6121   auto QCH = bindings_.allocate(QC)->getHandle<int8_t>();
6122   auto QDH = bindings_.allocate(QD)->getHandle<int8_t>();
6123 
6124   QAH.randomize(-128, 127, mod_.getPRNG());
6125   QBH.randomize(-128, 127, mod_.getPRNG());
6126   QCH.randomize(-128, 127, mod_.getPRNG());
6127   QDH.randomize(-128, 127, mod_.getPRNG());
6128 
6129   // Apply comparison and selection quantized.
6130   Node *cmpLTE = F_->createCmpLTE("cmpLTE", QA, QB);
6131   Node *select = F_->createSelect("select", OT, cmpLTE, QC, QD);
6132 
6133   // Save result of the operation.
6134   auto *out = F_->createSave("save", select);
6135   auto OH = bindings_.allocate(out->getPlaceholder())->getHandle<int8_t>();
6136 
6137   EE_.compile(CompilationMode::Infer);
6138   EE_.run(bindings_);
6139 
6140   int count_strict = 0;
6141   int count = 0;
6142   for (dim_t i = 0; i < len; i++) {
6143     float a = TQA->getScale() * (QAH.at({i}) - TQA->getOffset());
6144     float b = TQB->getScale() * (QBH.at({i}) - TQB->getOffset());
6145     float c = TQC->getScale() * (QCH.at({i}) - TQC->getOffset());
6146     float d = TQD->getScale() * (QDH.at({i}) - TQD->getOffset());
6147     float tmp = (a <= b) ? c : d;
6148     int32_t q = std::round(tmp / 1.5 - 2);
6149     int8_t select = quantization::clip<int32_t, int8_t>(q);
6150 
6151     if (OH.at({i}) != select) {
6152       count_strict++;
6153       if (std::abs(OH.at({i}) - select) > 1) {
6154         count++;
6155       }
6156     }
6157   }
6158   // Allow <=0.6% of elements to mismatch at all, and <=0.4% to be off by more than 1.
6159   EXPECT_LE(count_strict, 6);
6160   EXPECT_LE(count, 4);
6161 }
6162 
6163 TEST_P(OperatorTest, TestQuantizedRescaleSequence) {
6164   CHECK_IF_ENABLED();
6165 
6166   const dim_t len = 100;
6167 
6168   auto *A = mod_.createPlaceholder(ElemKind::FloatTy, {len}, "A", false);
6169 
6170   auto AH = bindings_.allocate(A)->getHandle();
6171 
6172   // Notice that the range below is chosen to fit the small-scale types T3
6173   // and T4 (scale 0.1, int8), which can only represent about +/-12.7; a
6174   // wider range would start losing values to saturation.
6175   AH.randomize(-12, 12, mod_.getPRNG());
6176 
6177   auto T1 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 1.0, 0);
6178   auto T2 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.9, 2);
6179   auto T3 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.1, -3);
6180   auto T4 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.1, 7);
6181   auto T5 = mod_.uniqueType(ElemKind::Int8QTy, {len}, 0.3, -3);
6182 
6183   Node *R = F_->createQuantize("R", A, T1);
6184   // Check that a sequence of type conversions does not change the result.
6185   R = F_->createRescaleQuantized("R", R, T1);
6186   R = F_->createRescaleQuantized("R", R, T2);
6187   R = F_->createRescaleQuantized("R", R, T3);
6188   // Check that adding the quantized zero does not change the result.
6189   auto *G = F_->createSplat("splatZero", T3, 0.0);
6190   R = F_->createAdd("addZero", G, R);
6191   R = F_->createRescaleQuantized("R", R, T4);
6192   R = F_->createRescaleQuantized("R", R, T5);
6193   R = F_->createRescaleQuantized("R", R, T1);
6194   auto *DQ = F_->createDequantize("DQ", R, ElemKind::FloatTy);
6195 
6196   // Test that the sequence of rescale operations above preserves the values.
6197   auto *result = F_->createSave("save", DQ);
6198   auto OH = bindings_.allocate(result->getPlaceholder())->getHandle();
6199   EE_.compile(CompilationMode::Infer);
6200   EE_.run(bindings_);
6201 
6202   for (dim_t i = 0; i < len; i++) {
6203     EXPECT_NEAR(AH.at({i}), OH.at({i}), 1.0);
6204   }
6205 }
6206 
6207 /// Helper to test concatVectors using \p DTy.
6208 template <typename DataType>
6209 static void testConcatVectors(glow::PlaceholderBindings &bindings,
6210                               glow::Module &mod, glow::Function *F,
6211                               glow::ExecutionEngine &EE, ElemKind DTy) {
6212   F->setName("concatVectors");
6213 
6214   auto *V1 =
6215       createPlaceholderConditionallyQuantized(mod, DTy, {10}, "V1", false);
6216   auto *V2 =
6217       createPlaceholderConditionallyQuantized(mod, DTy, {20}, "V2", false);
6218   auto *V3 =
6219       createPlaceholderConditionallyQuantized(mod, DTy, {30}, "V3", false);
6220 
6221   bindings.allocate(V1);
6222   bindings.allocate(V2);
6223   bindings.allocate(V3);
6224 
6225   Node *L = F->createConcat("concat", {V1, V2, V3}, 0);
6226   auto *result = F->createSave("ret", L);
6227   bindings.allocate(result->getPlaceholder());
6228 
6229   auto I1 = createTensorConditionallyQuantized(DTy, {10});
6230   auto I2 = createTensorConditionallyQuantized(DTy, {20});
6231   auto I3 = createTensorConditionallyQuantized(DTy, {30});
6232 
6233   for (dim_t i = 0; i < 10; i++) {
6234     I1.getHandle<DataType>().at({i}) = i;
6235 
6236     I2.getHandle<DataType>().at({i}) = i + 10;
6237     I2.getHandle<DataType>().at({i + 10}) = i + 20;
6238     I3.getHandle<DataType>().at({i}) = i + 30;
6239     I3.getHandle<DataType>().at({i + 10}) = i + 40;
6240     I3.getHandle<DataType>().at({i + 20}) = i + 50;
6241   }
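  // I1 holds 0..9, I2 holds 10..29 and I3 holds 30..59, so the concatenation
  // along axis 0 should produce the values 0..59 in order.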
6242 
6243   EE.compile(CompilationMode::Infer);
6244 
6245   // Testing the output vector.
6246   updateInputPlaceholders(bindings, {V1, V2, V3}, {&I1, &I2, &I3});
6247   EE.run(bindings);
6248 
6249   auto RNWH = bindings.get(result->getPlaceholder())->getHandle<DataType>();
6250   (void)RNWH;
6251 
6252   for (dim_t i = 0; i < 60; i++) {
6253     EXPECT_NEAR(RNWH.at({i}), static_cast<DataType>(i), 0.001);
6254   }
6255 }
6256 
6257 /// Test concatenating vectors that are Int64ITy.
6258 TEST_P(OperatorTest, concatVectors_Int64) {
6259   CHECK_IF_ENABLED();
6260   testConcatVectors<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy);
6261 }
6262 
6263 /// Test concatenating vectors that are Int32ITy.
6264 TEST_P(OperatorTest, concatVectors_Int32) {
6265   CHECK_IF_ENABLED();
6266   testConcatVectors<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy);
6267 }
6268 
6269 /// Test concatenating vectors that are Int8QTy.
6270 TEST_P(OperatorTest, concatVectors_Int8) {
6271   CHECK_IF_ENABLED();
6272   testConcatVectors<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
6273 }
6274 
6275 /// Test concatenating vectors that are BoolTy.
6276 TEST_P(OperatorTest, concatVectors_Bool) {
6277   CHECK_IF_ENABLED();
6278   testConcatVectors<bool>(bindings_, mod_, F_, EE_, ElemKind::BoolTy);
6279 }
6280 
6281 /// Test concatenating vectors that are FloatTy.
6282 TEST_P(OperatorTest, concatVectors_Float) {
6283   CHECK_IF_ENABLED();
6284   testConcatVectors<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
6285 }
6286 
6287 /// Test concatenating vectors that are Float16Ty.
6288 TEST_P(OperatorTest, concatVectors_Float16) {
6289   CHECK_IF_ENABLED();
6290   testConcatVectors<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
6291 }
6292 
6293 /// Test concatenating vectors that are BFloat16Ty.
6294 TEST_P(OperatorTest, concatVectors_BFloat16) {
6295   CHECK_IF_ENABLED();
6296   testConcatVectors<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty);
6297 }
6298 
6299 /// Helper to test ConcatVectorsRepeated using \p DTy.
6300 template <typename DataType>
6301 static void testConcatVectorsRepeated(glow::PlaceholderBindings &bindings,
6302                                       glow::Module &mod, glow::Function *F,
6303                                       glow::ExecutionEngine &EE, ElemKind DTy) {
6304   F->setName("concatVectors");
6305 
6306   auto *V1 =
6307       createPlaceholderConditionallyQuantized(mod, DTy, {10}, "V1", false);
6308   auto *V2 =
6309       createPlaceholderConditionallyQuantized(mod, DTy, {20}, "V2", false);
6310   bindings.allocate(V1);
6311   bindings.allocate(V2);
6312 
6313   // Alternate adding sequences of V1 and V2, so that the IRGen'd
6314   // InsertTensors have different counts.
6315   Node *L = F->createConcat("concat", {V2, V1, V1, V1, V2, V2, V1, V1, V2}, 0);
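  // Expected layout of the 130-element result: indices [0, 20) come from V2,
  // [20, 50) from three copies of V1, [50, 90) from two copies of V2,
  // [90, 110) from two copies of V1, and [110, 130) from the final V2.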
6316   auto *result = F->createSave("ret", L);
6317   bindings.allocate(result->getPlaceholder());
6318 
6319   auto I1 = createTensorConditionallyQuantized(DTy, {10});
6320   auto I2 = createTensorConditionallyQuantized(DTy, {20});
6321   auto I1H = I1.getHandle<DataType>();
6322   auto I2H = I2.getHandle<DataType>();
6323   for (dim_t i = 0; i < 10; i++) {
6324     I1H.at({i}) = 1;
6325 
6326     I2H.at({i}) = 2;
6327     I2H.at({i + 10}) = 2;
6328   }
6329 
6330   EE.compile(CompilationMode::Infer);
6331 
6332   // Testing the output vector.
6333   updateInputPlaceholders(bindings, {V1, V2}, {&I1, &I2});
6334   EE.run(bindings);
6335 
6336   auto outH = bindings.get(result->getPlaceholder())->getHandle<DataType>();
6337 
6338   // Simply verify here that the values are in their correct places, based on
6339   // the number of times/order V1 and V2 are concatenated and their sizes.
6340   for (dim_t i = 0; i < 130; i++) {
6341     if ((i < 20) || (i >= 50 && i < 90) || (i >= 110)) {
6342       EXPECT_EQ(outH.at({i}), static_cast<DataType>(2));
6343     } else {
6344       EXPECT_EQ(outH.at({i}), static_cast<DataType>(1));
6345     }
6346   }
6347 }
6348 
6349 /// Check that concatenating two tensors repeatedly is correct. This is
6350 /// intended to verify that IRGen to InsertTensor instructions with axis/count
6351 /// works correctly. Testing Int64ITy data.
6352 TEST_P(OperatorTest, concatVectorsRepeated_Int64) {
6353   CHECK_IF_ENABLED();
6354   testConcatVectorsRepeated<int64_t>(bindings_, mod_, F_, EE_,
6355                                      ElemKind::Int64ITy);
6356 }
6357 
6358 /// Check that concatenating two tensors repeatedly is correct. This is
6359 /// intended to verify that IRGen to InsertTensor instructions with axis/count
6360 /// works correctly. Testing Int32ITy data.
6361 TEST_P(OperatorTest, concatVectorsRepeated_Int32) {
6362   CHECK_IF_ENABLED();
6363   testConcatVectorsRepeated<int32_t>(bindings_, mod_, F_, EE_,
6364                                      ElemKind::Int32ITy);
6365 }
6366 
6367 /// Check that concatenating two tensors repeatedly is correct. This is
6368 /// intended to verify that IRGen to InsertTensor instructions with axis/count
6369 /// works correctly. Testing Int8QTy data.
6370 TEST_P(OperatorTest, concatVectorsRepeated_Int8) {
6371   CHECK_IF_ENABLED();
6372   testConcatVectorsRepeated<int8_t>(bindings_, mod_, F_, EE_,
6373                                     ElemKind::Int8QTy);
6374 }
6375 
6376 /// Check that concatenating two tensors repeatedly is correct. This is
6377 /// intended to verify that IRGen to InsertTensor instructions with axis/count
6378 /// works correctly. Testing BoolTy data.
6379 TEST_P(OperatorTest, concatVectorsRepeated_Bool) {
6380   CHECK_IF_ENABLED();
6381   testConcatVectorsRepeated<bool>(bindings_, mod_, F_, EE_, ElemKind::BoolTy);
6382 }
6383 
6384 /// Check that concatenating two tensors repeatedly is correct. This is
6385 /// intended to verify that IRGen to InsertTensor instructions with axis/count
6386 /// works correctly. Testing FloatTy data.
6387 TEST_P(OperatorTest, concatVectorsRepeated_Float) {
6388   CHECK_IF_ENABLED();
6389   testConcatVectorsRepeated<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
6390 }
6391 
6392 /// Check that concatenating two tensors repeatedly is correct. This is
6393 /// intended to verify that IRGen to InsertTensor instructions with axis/count
6394 /// works correctly. Testing Float16Ty data.
6395 TEST_P(OperatorTest, concatVectorsRepeated_Float16) {
6396   CHECK_IF_ENABLED();
6397   testConcatVectorsRepeated<float16_t>(bindings_, mod_, F_, EE_,
6398                                        ElemKind::Float16Ty);
6399 }
6400 
6401 /// Check that concatenating two tensors repeatedly is correct. This is
6402 /// intended to verify that IRGen to InsertTensor instructions with axis/count
6403 /// works correctly. Testing BFloat16Ty data.
6404 TEST_P(OperatorTest, concatVectorsRepeated_BFloat16) {
6405   CHECK_IF_ENABLED();
6406   testConcatVectorsRepeated<bfloat16_t>(bindings_, mod_, F_, EE_,
6407                                         ElemKind::BFloat16Ty);
6408 }
6409 
6410 /// Helper to test SliceVectors using \p DTy.
6411 template <typename DataType>
6412 static void testSliceVectors(glow::PlaceholderBindings &bindings,
6413                              glow::Module &mod, glow::Function *F,
6414                              glow::ExecutionEngine &EE, ElemKind DTy) {
6415   F->setName("sliceVectors");
6416 
6417   auto *V =
6418       createPlaceholderConditionallyQuantized(mod, DTy, {3, 30}, "V", false);
6419   bindings.allocate(V);
6420 
6421   Node *S1 = F->createSlice("slice1", V, {0, 10}, {3, 13});
6422   Node *S2 = F->createSlice("slice2", V, {1, 0}, {2, 30});
6423   Node *S3 = F->createSlice("slice3", V, {2, 10}, {3, 12});
6424 
6425   auto *result1 = F->createSave("ret1", S1);
6426   auto *result2 = F->createSave("ret2", S2);
6427   auto *result3 = F->createSave("ret3", S3);
6428 
6429   bindings.allocate(result1->getPlaceholder());
6430   bindings.allocate(result2->getPlaceholder());
6431   bindings.allocate(result3->getPlaceholder());
6432 
6433   auto I = createTensorConditionallyQuantized(DTy, {3, 30});
6434   auto IH = I.getHandle<DataType>();
6435   for (dim_t j = 0; j < 30; j++) {
6436     IH.at({0, j}) = j;
6437     IH.at({1, j}) = j + 30;
6438     IH.at({2, j}) = j + 60;
6439   }
6440 
6441   EE.compile(CompilationMode::Infer);
6442 
6443   // Testing the output slices.
6444   updateInputPlaceholders(bindings, {V}, {&I});
6445   EE.run(bindings);
6446 
6447   auto RNWH1 = bindings.get(result1->getPlaceholder())->getHandle<DataType>();
6448   auto RNWH2 = bindings.get(result2->getPlaceholder())->getHandle<DataType>();
6449   auto RNWH3 = bindings.get(result3->getPlaceholder())->getHandle<DataType>();
6450 
6451   EXPECT_EQ(3, RNWH1.dims()[0]);
6452   EXPECT_EQ(3, RNWH1.dims()[1]);
6453   for (dim_t i = 0; i < 3; i++) {
6454     for (dim_t j = 10; j < 13; j++) {
6455       EXPECT_NEAR(RNWH1.at({i, j - 10}), j + i * 30, 0.001);
6456     }
6457   }
6458   EXPECT_EQ(1, RNWH2.dims()[0]);
6459   EXPECT_EQ(30, RNWH2.dims()[1]);
6460   for (dim_t j = 0; j < 30; j++) {
6461     EXPECT_NEAR(RNWH2.at({0, j}), j + 30, 0.001);
6462   }
6463   EXPECT_EQ(1, RNWH3.dims()[0]);
6464   EXPECT_EQ(2, RNWH3.dims()[1]);
6465   for (dim_t j = 10; j < 12; j++) {
6466     EXPECT_NEAR(RNWH3.at({0, j - 10}), j + 60, 0.001);
6467   }
6468 }
6469 
6470 /// Test slicing with Int64ITy.
6471 TEST_P(OperatorTest, sliceVectors_Int64) {
6472   CHECK_IF_ENABLED();
6473   testSliceVectors<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy);
6474 }
6475 
6476 /// Test slicing with FloatTy.
6477 TEST_P(OperatorTest, sliceVectors_Float) {
6478   CHECK_IF_ENABLED();
6479   testSliceVectors<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
6480 }
6481 
6482 /// Test slicing with Float16Ty.
6483 TEST_P(OperatorTest, sliceVectors_Float16) {
6484   CHECK_IF_ENABLED();
6485   testSliceVectors<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
6486 }
6487 
6488 /// Test slicing with BFloat16Ty.
6489 TEST_P(OperatorTest, sliceVectors_BFloat16) {
6490   CHECK_IF_ENABLED();
6491   testSliceVectors<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty);
6492 }
6493 
6494 /// Test slicing with Int8QTy.
6495 TEST_P(OperatorTest, sliceVectors_Int8) {
6496   CHECK_IF_ENABLED();
6497   testSliceVectors<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
6498 }
6499 
6500 /// Test slicing with Int32QTy.
6501 TEST_P(OperatorTest, sliceVectors_Int32Q) {
6502   CHECK_IF_ENABLED();
6503   testSliceVectors<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32QTy);
6504 }
6505 
6506 /// Test slicing with Int32ITy.
6507 TEST_P(OperatorTest, sliceVectors_Int32I) {
6508   CHECK_IF_ENABLED();
6509   testSliceVectors<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy);
6510 }
6511 
6512 /// Helper to test SliceConcatVectors using \p DTy.
6513 template <typename DataType>
6514 static void testSliceConcatVectors(glow::PlaceholderBindings &bindings,
6515                                    glow::Module &mod, glow::Function *F,
6516                                    glow::ExecutionEngine &EE, ElemKind DTy) {
6517   F->setName("sliceConcatVectors");
6518 
6519   auto *V =
6520       createPlaceholderConditionallyQuantized(mod, DTy, {5, 4}, "V", false);
6521   bindings.allocate(V);
6522 
6523   auto I = createTensorConditionallyQuantized(DTy, {5, 4});
6524   auto IH = I.getHandle<DataType>();
6525   for (dim_t i = 0; i < 5; i++) {
6526     for (dim_t j = 0; j < 4; j++) {
6527       IH.at({i, j}) = i * 10 + j;
6528     }
6529   }
6530 
6531   Node *S0 = F->createSlice("slice0", V, {1, 0}, {5, 4});
6532   Node *S1 = F->createSlice("slice1", S0, {0, 0}, {2, 4});
6533   Node *S2 = F->createSlice("slice2", S0, {2, 0}, {4, 4});
6534   Node *S3 = F->createSlice("slice3", S0, {0, 0}, {2, 2});
6535   Node *S4 = F->createSlice("slice4", S0, {2, 2}, {4, 4});
6536   Node *S5 = F->createSlice("slice5", V, {0, 0}, {1, 4});
6537 
6538   Node *C0 = F->createConcat("concat0", {S5, S1}, 0);
6539   Node *C1 = F->createConcat("concat1", {S3, S4}, 1);
6540   Node *C2 = F->createConcat("concat2", {S2, C1, C0}, 0);
6541 
6542   auto *result = F->createSave("ret", C2);
6543   bindings.allocate(result->getPlaceholder());
6544 
6545   EE.compile(CompilationMode::Infer);
6546 
6547   updateInputPlaceholders(bindings, {V}, {&I});
6548   EE.run(bindings);
6549 
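  // Row i of V holds {10 * i, ..., 10 * i + 3}. C2 stacks S2 (rows 3-4 of V),
  // C1 (the left half of rows 1-2 next to the right half of rows 3-4), and
  // C0 (row 0 followed by rows 1-2).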
6550   const DataType expected[7][4] = {
6551       {30, 31, 32, 33}, {40, 41, 42, 43}, {10, 11, 32, 33}, {20, 21, 42, 43},
6552       {0, 1, 2, 3},     {10, 11, 12, 13}, {20, 21, 22, 23}};
6553 
6554   auto resultH = bindings.get(result->getPlaceholder())->getHandle<DataType>();
6555   EXPECT_EQ(7, resultH.dims()[0]);
6556   EXPECT_EQ(4, resultH.dims()[1]);
6557   for (dim_t i = 0; i < 7; i++) {
6558     for (dim_t j = 0; j < 4; j++) {
6559       EXPECT_EQ(resultH.at({i, j}), expected[i][j]);
6560     }
6561   }
6562 }
6563 
6564 /// Test a combination of slicing and concating, in Int64ITy.
6565 TEST_P(OperatorTest, sliceConcatVectors_Int64) {
6566   CHECK_IF_ENABLED();
6567   testSliceConcatVectors<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy);
6568 }
6569 
6570 /// Test a combination of slicing and concating, in Int8QTy.
6571 TEST_P(OperatorTest, sliceConcatVectors_Int8) {
6572   CHECK_IF_ENABLED();
6573   testSliceConcatVectors<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
6574 }
6575 
6576 /// Test a combination of slicing and concating, in FloatTy.
6577 TEST_P(OperatorTest, sliceConcatVectors_Float) {
6578   CHECK_IF_ENABLED();
6579   testSliceConcatVectors<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
6580 }
6581 
6582 /// Test a combination of slicing and concating, in Float16Ty.
6583 TEST_P(OperatorTest, sliceConcatVectors_Float16) {
6584   CHECK_IF_ENABLED();
6585   testSliceConcatVectors<float16_t>(bindings_, mod_, F_, EE_,
6586                                     ElemKind::Float16Ty);
6587 }
6588 
6589 /// Test a combination of slicing and concating, in BFloat16Ty.
6590 TEST_P(OperatorTest, sliceConcatVectors_BFloat16) {
6591   CHECK_IF_ENABLED();
6592   testSliceConcatVectors<bfloat16_t>(bindings_, mod_, F_, EE_,
6593                                      ElemKind::BFloat16Ty);
6594 }
6595 
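/// Test Tile with 3 tiles along axis 0 and along axis 1 on a 4x5 float tensor.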
6596 TEST_P(OperatorTest, Tile) {
6597   CHECK_IF_ENABLED();
6598 
6599   F_->setName("Tile");
6600 
6601   auto *V = mod_.createPlaceholder(ElemKind::FloatTy, {4, 5}, "V", false);
6602   bindings_.allocate(V);
6603 
6604   Node *T0 = F_->createTile("tile0", V, /* tiles */ 3, /* axis */ 0);
6605   auto *result0 = F_->createSave("res0", T0);
6606   bindings_.allocate(result0->getPlaceholder());
6607 
6608   Node *T1 = F_->createTile("tile1", V, /* tiles */ 3, /* axis */ 1);
6609   auto *result1 = F_->createSave("res1", T1);
6610   bindings_.allocate(result1->getPlaceholder());
6611 
6612   Tensor VT(ElemKind::FloatTy, {4, 5});
6613 
6614   for (dim_t i = 0; i < 4; i++) {
6615     for (dim_t j = 0; j < 5; j++) {
6616       VT.getHandle<float>().at({i, j}) = i * 5 + j;
6617     }
6618   }
6619 
6620   EE_.compile(CompilationMode::Infer);
6621 
6622   updateInputPlaceholders(bindings_, {V}, {&VT});
6623   EE_.run(bindings_);
6624 
6625   // Testing the output vector with axis 0.
6626   auto res0 = bindings_.get(result0->getPlaceholder())->getHandle<float>();
6627   for (dim_t i = 0; i < res0.dims()[0]; i++) {
6628     for (dim_t j = 0; j < res0.dims()[1]; j++) {
6629       EXPECT_EQ(res0.at({i, j}), (i % 4) * 5 + j);
6630     }
6631   }
6632 
6633   // Testing the output vector with axis 1.
6634   auto res1 = bindings_.get(result1->getPlaceholder())->getHandle<float>();
6635   for (dim_t i = 0; i < res1.dims()[0]; i++) {
6636     for (dim_t j = 0; j < res1.dims()[1]; j++) {
6637       EXPECT_EQ(res1.at({i, j}), i * 5 + (j % 5));
6638     }
6639   }
6640 }
6641 
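/// Test Tile on Int8 quantized data: quantize the input, tile along each axis,
/// then dequantize and compare against the float reference values.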
6642 TEST_P(OperatorTest, QuantizedTile) {
6643   CHECK_IF_ENABLED();
6644 
6645   F_->setName("QuantizedTile");
6646 
6647   auto *V = mod_.createPlaceholder(ElemKind::FloatTy, {4, 5}, "V", false);
6648   bindings_.allocate(V);
6649 
6650   auto quantizationParams =
6651       glow::quantization::chooseQuantizationParams({0, 20});
6652   auto quantizeTy =
6653       mod_.uniqueType(ElemKind::Int8QTy, {4, 5}, quantizationParams.scale,
6654                       quantizationParams.offset);
6655   auto *Q = F_->createQuantize("quantize", V, quantizeTy);
6656 
6657   Node *T0 = F_->createTile("tile0", Q, /* tiles */ 3, /* axis */ 0);
6658   auto *DQ0 = F_->createDequantize("dequantize0", T0, ElemKind::FloatTy);
6659   auto *result0 = F_->createSave("res0", DQ0);
6660   bindings_.allocate(result0->getPlaceholder());
6661 
6662   Node *T1 = F_->createTile("tile1", Q, /* tiles */ 3, /* axis */ 1);
6663   auto *DQ1 = F_->createDequantize("dequantize1", T1, ElemKind::FloatTy);
6664   auto *result1 = F_->createSave("res1", DQ1);
6665   bindings_.allocate(result1->getPlaceholder());
6666 
6667   Tensor VT(ElemKind::FloatTy, {4, 5});
6668 
6669   for (dim_t i = 0; i < 4; i++) {
6670     for (dim_t j = 0; j < 5; j++) {
6671       VT.getHandle<float>().at({i, j}) = i * 5 + j;
6672     }
6673   }
6674 
6675   EE_.compile(CompilationMode::Infer);
6676 
6677   updateInputPlaceholders(bindings_, {V}, {&VT});
6678   EE_.run(bindings_);
6679 
6680   // Testing the output vector with axis 0.
6681   auto res0 = bindings_.get(result0->getPlaceholder())->getHandle<float>();
6682   for (dim_t i = 0; i < res0.dims()[0]; i++) {
6683     for (dim_t j = 0; j < res0.dims()[1]; j++) {
6684       EXPECT_NEAR(res0.at({i, j}), (i % 4) * 5 + j, 0.05);
6685     }
6686   }
6687 
6688   // Testing the output vector with axis 1.
6689   auto res1 = bindings_.get(result1->getPlaceholder())->getHandle<float>();
6690   (void)res1;
6691   for (dim_t i = 0; i < res1.dims()[0]; i++) {
6692     for (dim_t j = 0; j < res1.dims()[1]; j++) {
6693       EXPECT_NEAR(res1.at({i, j}), i * 5 + (j % 5), 0.05);
6694     }
6695   }
6696 }
6697 
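/// Test Clip on a 5x5 float tensor, clamping every element to the range
/// [20, 60].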
6698 TEST_P(OperatorTest, Clip) {
6699   CHECK_IF_ENABLED();
6700 
6701   auto *X = mod_.createPlaceholder(ElemKind::FloatTy, {5, 5}, "X", false);
6702   auto xHandle = bindings_.allocate(X)->getHandle();
6703   xHandle = {45.0, 16.0, 59.0, 99.0, 48.0, 12.0, 44.0, 46.0, 82.0,
6704              28.0, 1.0,  91.0, 18.0, 9.0,  71.0, 24.0, 37.0, 61.0,
6705              12.0, 81.0, 36.0, 38.0, 30.0, 84.0, 40.0};
6706 
6707   float min = 20.0;
6708   float max = 60.0;
6709   auto *node = F_->createClip("clip", X, min, max);
6710   auto *save = F_->createSave("save", node);
6711   auto *saveTensor = bindings_.allocate(save->getPlaceholder());
6712   EE_.compile(CompilationMode::Infer);
6713   EE_.run(bindings_);
6714 
6715   auto result = saveTensor->getHandle();
6716   std::vector<dim_t> expectedDims = {5, 5};
6717   std::vector<float> expectedValues = {45.0, 20.0, 59.0, 60.0, 48.0, 20.0, 44.0,
6718                                        46.0, 60.0, 28.0, 20.0, 60.0, 20.0, 20.0,
6719                                        60.0, 24.0, 37.0, 60.0, 20.0, 60.0, 36.0,
6720                                        38.0, 30.0, 60.0, 40.0};
6721   EXPECT_TRUE(result.dims().vec() == expectedDims);
6722   for (size_t i = 0; i < 5 * 5; i++) {
6723     EXPECT_FLOAT_EQ(result.raw(i), expectedValues[i]);
6724   }
6725 }
6726 
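/// Test element-wise logical Not on boolean data.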
6727 TEST_P(OperatorTest, Not) {
6728   CHECK_IF_ENABLED();
6729   auto *input = mod_.createPlaceholder(ElemKind::BoolTy, {2}, "inp", false);
6730   bindings_.allocate(input)->getHandle<bool>() = {false, true};
6731   auto *node = F_->createNot("not", input);
6732   auto *save = F_->createSave("save", node);
6733   auto *outT = bindings_.allocate(save->getPlaceholder());
6734   EE_.compile(CompilationMode::Infer);
6735   EE_.run(bindings_);
6736   auto outH = outT->getHandle<bool>();
6737   EXPECT_EQ(outH.size(), 2);
6738   EXPECT_EQ(outH.raw(0), true);
6739   EXPECT_EQ(outH.raw(1), false);
6740 }
6741 
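/// Test element-wise logical And on boolean data.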
6742 TEST_P(OperatorTest, And) {
6743   CHECK_IF_ENABLED();
6744   auto *LHS = mod_.createPlaceholder(ElemKind::BoolTy, {4}, "LHS", false);
6745   auto *RHS = mod_.createPlaceholder(ElemKind::BoolTy, {4}, "RHS", false);
6746   bindings_.allocate(LHS)->getHandle<bool>() = {false, true, false, true};
6747   bindings_.allocate(RHS)->getHandle<bool>() = {false, false, true, true};
6748   auto *node = F_->createAnd("and", LHS, RHS);
6749   auto *save = F_->createSave("save", node);
6750   auto *outT = bindings_.allocate(save->getPlaceholder());
6751   EE_.compile(CompilationMode::Infer);
6752   EE_.run(bindings_);
6753   auto outH = outT->getHandle<bool>();
6754   EXPECT_EQ(outH.size(), 4);
6755   EXPECT_EQ(outH.raw(0), false);
6756   EXPECT_EQ(outH.raw(1), false);
6757   EXPECT_EQ(outH.raw(2), false);
6758   EXPECT_EQ(outH.raw(3), true);
6759 }
6760 
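/// Test element-wise logical Or on boolean data.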
6761 TEST_P(OperatorTest, Or) {
6762   CHECK_IF_ENABLED();
6763   auto *LHS = mod_.createPlaceholder(ElemKind::BoolTy, {4}, "LHS", false);
6764   auto *RHS = mod_.createPlaceholder(ElemKind::BoolTy, {4}, "RHS", false);
6765   bindings_.allocate(LHS)->getHandle<bool>() = {false, true, false, true};
6766   bindings_.allocate(RHS)->getHandle<bool>() = {false, false, true, true};
6767   auto *node = F_->createOr("or", LHS, RHS);
6768   auto *save = F_->createSave("save", node);
6769   auto *outT = bindings_.allocate(save->getPlaceholder());
6770   EE_.compile(CompilationMode::Infer);
6771   EE_.run(bindings_);
6772   auto outH = outT->getHandle<bool>();
6773   EXPECT_EQ(outH.size(), 4);
6774   EXPECT_EQ(outH.raw(0), false);
6775   EXPECT_EQ(outH.raw(1), true);
6776   EXPECT_EQ(outH.raw(2), true);
6777   EXPECT_EQ(outH.raw(3), true);
6778 }
6779 
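/// Test element-wise logical Xor on boolean data.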
6780 TEST_P(OperatorTest, Xor) {
6781   CHECK_IF_ENABLED();
6782   auto *LHS = mod_.createPlaceholder(ElemKind::BoolTy, {4}, "LHS", false);
6783   auto *RHS = mod_.createPlaceholder(ElemKind::BoolTy, {4}, "RHS", false);
6784   bindings_.allocate(LHS)->getHandle<bool>() = {false, true, false, true};
6785   bindings_.allocate(RHS)->getHandle<bool>() = {false, false, true, true};
6786   auto *node = F_->createXor("xor", LHS, RHS);
6787   auto *save = F_->createSave("save", node);
6788   auto *outT = bindings_.allocate(save->getPlaceholder());
6789   EE_.compile(CompilationMode::Infer);
6790   EE_.run(bindings_);
6791   auto outH = outT->getHandle<bool>();
6792   EXPECT_EQ(outH.size(), 4);
6793   EXPECT_EQ(outH.raw(0), false);
6794   EXPECT_EQ(outH.raw(1), true);
6795   EXPECT_EQ(outH.raw(2), true);
6796   EXPECT_EQ(outH.raw(3), false);
6797 }
6798 
6799 TEST_P(OperatorTest, Abs_FloatTy) {
6800   CHECK_IF_ENABLED();
6801   auto *inp = mod_.createPlaceholder(ElemKind::FloatTy, {2}, "inp", false);
6802   bindings_.allocate(inp)->getHandle<float>() = {-1.0, 1.0};
6803   auto *node = F_->createAbs("abs", inp);
6804   auto *save = F_->createSave("save", node);
6805   auto *outT = bindings_.allocate(save->getPlaceholder());
6806   EE_.compile(CompilationMode::Infer);
6807   EE_.run(bindings_);
6808   auto outH = outT->getHandle<float>();
6809   EXPECT_EQ(outH.size(), 2);
6810   EXPECT_FLOAT_EQ(outH.raw(0), 1.0);
6811   EXPECT_FLOAT_EQ(outH.raw(1), 1.0);
6812 }
6813 
6814 TEST_P(OperatorTest, Abs_Int8QTy) {
6815   CHECK_IF_ENABLED();
6816   auto *inp =
6817       mod_.createPlaceholder(ElemKind::Int8QTy, {2}, 1.0, 0, "inp", false);
6818   bindings_.allocate(inp)->getHandle<int8_t>() = {-1, 1};
6819   auto *node = F_->createAbs("abs", inp);
6820   auto *save = F_->createSave("save", node);
6821   auto *outT = bindings_.allocate(save->getPlaceholder());
6822   EE_.compile(CompilationMode::Infer);
6823   EE_.run(bindings_);
6824   auto outH = outT->getHandle<int8_t>();
6825   EXPECT_EQ(outH.size(), 2);
6826   EXPECT_EQ(outH.raw(0), 1);
6827   EXPECT_EQ(outH.raw(1), 1);
6828 }
6829 
6830 TEST_P(OperatorTest, Neg_FloatTy) {
6831   CHECK_IF_ENABLED();
6832   auto *inp = mod_.createPlaceholder(ElemKind::FloatTy, {2}, "inp", false);
6833   bindings_.allocate(inp)->getHandle<float>() = {1.0, -1.0};
6834   auto *node = F_->createNeg("neg", inp);
6835   auto *save = F_->createSave("save", node);
6836   auto *outT = bindings_.allocate(save->getPlaceholder());
6837   EE_.compile(CompilationMode::Infer);
6838   EE_.run(bindings_);
6839   auto outH = outT->getHandle<float>();
6840   EXPECT_EQ(outH.size(), 2);
6841   EXPECT_FLOAT_EQ(outH.raw(0), -1.0);
6842   EXPECT_FLOAT_EQ(outH.raw(1), 1.0);
6843 }
6844 
6845 TEST_P(OperatorTest, Neg_Int8QTy) {
6846   CHECK_IF_ENABLED();
6847   auto *inp =
6848       mod_.createPlaceholder(ElemKind::Int8QTy, {2}, 1.0, 0, "inp", false);
6849   bindings_.allocate(inp)->getHandle<int8_t>() = {-1, 1};
6850   auto *node = F_->createNeg("neg", inp);
6851   auto *save = F_->createSave("save", node);
6852   auto *outT = bindings_.allocate(save->getPlaceholder());
6853   EE_.compile(CompilationMode::Infer);
6854   EE_.run(bindings_);
6855   auto outH = outT->getHandle<int8_t>();
6856   EXPECT_EQ(outH.size(), 2);
6857   EXPECT_EQ(outH.raw(0), 1);
6858   EXPECT_EQ(outH.raw(1), -1);
6859 }
6860 
6861 TEST_P(OperatorTest, Floor_FloatTy) {
6862   CHECK_IF_ENABLED();
6863   auto *inp = mod_.createPlaceholder(ElemKind::FloatTy, {3}, "inp", false);
6864   bindings_.allocate(inp)->getHandle<float>() = {-0.2, 1.0, 1.99};
6865   auto *node = F_->createFloor("floor", inp);
6866   auto *save = F_->createSave("save", node);
6867   auto *outT = bindings_.allocate(save->getPlaceholder());
6868   EE_.compile(CompilationMode::Infer);
6869   EE_.run(bindings_);
6870   auto outH = outT->getHandle<float>();
6871   EXPECT_EQ(outH.size(), 3);
6872   EXPECT_FLOAT_EQ(outH.raw(0), -1.0);
6873   EXPECT_FLOAT_EQ(outH.raw(1), 1.0);
6874   EXPECT_FLOAT_EQ(outH.raw(2), 1.0);
6875 }
6876 
6877 TEST_P(OperatorTest, Floor_Int8QTy) {
6878   CHECK_IF_ENABLED();
6879   auto *inp =
6880       mod_.createPlaceholder(ElemKind::Int8QTy, {5}, 0.5, 0, "inp", false);
6881   bindings_.allocate(inp)->getHandle<int8_t>() = {-2, -1, 0, 1, 2};
6882   auto *node = F_->createFloor("floor", inp);
6883   auto *save = F_->createSave("save", node);
6884   auto *outT = bindings_.allocate(save->getPlaceholder());
6885   EE_.compile(CompilationMode::Infer);
6886   EE_.run(bindings_);
6887   auto outH = outT->getHandle<int8_t>();
6888   EXPECT_EQ(outH.size(), 5);
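  // With scale 0.5 the inputs represent {-1.0, -0.5, 0.0, 0.5, 1.0}; floor
  // gives {-1, -1, 0, 0, 1}, which requantizes to {-2, -2, 0, 0, 2}.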
6889   EXPECT_EQ(outH.raw(0), -2);
6890   EXPECT_EQ(outH.raw(1), -2);
6891   EXPECT_EQ(outH.raw(2), 0);
6892   EXPECT_EQ(outH.raw(3), 0);
6893   EXPECT_EQ(outH.raw(4), 2);
6894 }
6895 
6896 TEST_P(OperatorTest, Ceil_FloatTy) {
6897   CHECK_IF_ENABLED();
6898   auto *inp = mod_.createPlaceholder(ElemKind::FloatTy, {3}, "inp", false);
6899   bindings_.allocate(inp)->getHandle<float>() = {-0.2, 1.0, 1.99};
6900   auto *node = F_->createCeil("ceil", inp);
6901   auto *save = F_->createSave("save", node);
6902   auto *outT = bindings_.allocate(save->getPlaceholder());
6903   EE_.compile(CompilationMode::Infer);
6904   EE_.run(bindings_);
6905   auto outH = outT->getHandle<float>();
6906   EXPECT_EQ(outH.size(), 3);
6907   EXPECT_FLOAT_EQ(outH.raw(0), 0.0);
6908   EXPECT_FLOAT_EQ(outH.raw(1), 1.0);
6909   EXPECT_FLOAT_EQ(outH.raw(2), 2.0);
6910 }
6911 
6912 TEST_P(OperatorTest, Ceil_Int8QTy) {
6913   CHECK_IF_ENABLED();
6914   auto *inp =
6915       mod_.createPlaceholder(ElemKind::Int8QTy, {5}, 0.5, 0, "inp", false);
6916   bindings_.allocate(inp)->getHandle<int8_t>() = {-2, -1, 0, 1, 2};
6917   auto *node = F_->createCeil("ceil", inp);
6918   auto *save = F_->createSave("save", node);
6919   auto *outT = bindings_.allocate(save->getPlaceholder());
6920   EE_.compile(CompilationMode::Infer);
6921   EE_.run(bindings_);
6922   auto outH = outT->getHandle<int8_t>();
6923   EXPECT_EQ(outH.size(), 5);
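  // With scale 0.5 the inputs represent {-1.0, -0.5, 0.0, 0.5, 1.0}; ceil
  // gives {-1, 0, 0, 1, 1}, which requantizes to {-2, 0, 0, 2, 2}.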
6924   EXPECT_EQ(outH.raw(0), -2);
6925   EXPECT_EQ(outH.raw(1), 0);
6926   EXPECT_EQ(outH.raw(2), 0);
6927   EXPECT_EQ(outH.raw(3), 2);
6928   EXPECT_EQ(outH.raw(4), 2);
6929 }
6930 
6931 TEST_P(OperatorTest, Round_FloatTy) {
6932   CHECK_IF_ENABLED();
6933   auto *inp = mod_.createPlaceholder(ElemKind::FloatTy, {5}, "inp", false);
6934   bindings_.allocate(inp)->getHandle<float>() = {0.9, 2.5, 2.3, 1.5, -4.5};
6935   auto *node = F_->createRound("round", inp);
6936   auto *save = F_->createSave("save", node);
6937   auto *outT = bindings_.allocate(save->getPlaceholder());
6938   EE_.compile(CompilationMode::Infer);
6939   EE_.run(bindings_);
6940   auto outH = outT->getHandle<float>();
6941   EXPECT_EQ(outH.size(), 5);
6942   // ONNX, NumPy, and TensorFlow require round-half-to-even: values with a
6943   // fractional part of exactly 0.5 are rounded to the nearest even integer.
6944   EXPECT_FLOAT_EQ(outH.raw(0), 1.0);
6945   EXPECT_FLOAT_EQ(outH.raw(1), 2.0);
6946   EXPECT_FLOAT_EQ(outH.raw(2), 2.0);
6947   EXPECT_FLOAT_EQ(outH.raw(3), 2.0);
6948   EXPECT_FLOAT_EQ(outH.raw(4), -4.0);
6949 }
6950 
6951 TEST_P(OperatorTest, Round_Int8QTy) {
6952   CHECK_IF_ENABLED();
6953   auto *inp =
6954       mod_.createPlaceholder(ElemKind::Int8QTy, {5}, 0.1, 0, "inp", false);
6955   bindings_.allocate(inp)->getHandle<int8_t>() = {-8, -2, 0, 2, 8};
6956   auto *node = F_->createRound("round", inp);
6957   auto *save = F_->createSave("save", node);
6958   auto *outT = bindings_.allocate(save->getPlaceholder());
6959   EE_.compile(CompilationMode::Infer);
6960   EE_.run(bindings_);
6961   auto outH = outT->getHandle<int8_t>();
6962   EXPECT_EQ(outH.size(), 5);
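  // With scale 0.1 the inputs represent {-0.8, -0.2, 0.0, 0.2, 0.8}; rounding
  // gives {-1, 0, 0, 0, 1}, which requantizes to {-10, 0, 0, 0, 10}.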
6963   EXPECT_EQ(outH.raw(0), -10);
6964   EXPECT_EQ(outH.raw(1), 0);
6965   EXPECT_EQ(outH.raw(2), 0);
6966   EXPECT_EQ(outH.raw(3), 0);
6967   EXPECT_EQ(outH.raw(4), 10);
6968 }
6969 
6970 TEST_P(OperatorTest, Sqrt_FloatTy) {
6971   CHECK_IF_ENABLED();
6972   auto *inp = mod_.createPlaceholder(ElemKind::FloatTy, {4}, "inp", false);
6973   bindings_.allocate(inp)->getHandle<float>() = {0.0, 1.0, 4.0, 9.0};
6974   auto *node = F_->createSqrt("sqrt", inp);
6975   auto *save = F_->createSave("save", node);
6976   auto *outT = bindings_.allocate(save->getPlaceholder());
6977   EE_.compile(CompilationMode::Infer);
6978   EE_.run(bindings_);
6979   auto outH = outT->getHandle<float>();
6980   EXPECT_EQ(outH.size(), 4);
6981   EXPECT_FLOAT_EQ(outH.raw(0), 0.0);
6982   EXPECT_FLOAT_EQ(outH.raw(1), 1.0);
6983   EXPECT_FLOAT_EQ(outH.raw(2), 2.0);
6984   EXPECT_FLOAT_EQ(outH.raw(3), 3.0);
6985 }
6986 
6987 TEST_P(OperatorTest, Sqrt_Int8QTy) {
6988   CHECK_IF_ENABLED();
6989   auto *inp =
6990       mod_.createPlaceholder(ElemKind::Int8QTy, {4}, 1.0, 0, "inp", false);
6991   bindings_.allocate(inp)->getHandle<int8_t>() = {0, 1, 4, 9};
6992   auto *node = F_->createSqrt("sqrt", inp);
6993   auto *save = F_->createSave("save", node);
6994   auto *outT = bindings_.allocate(save->getPlaceholder());
6995   EE_.compile(CompilationMode::Infer);
6996   EE_.run(bindings_);
6997   auto outH = outT->getHandle<int8_t>();
6998   EXPECT_EQ(outH.size(), 4);
6999   EXPECT_EQ(outH.raw(0), 0);
7000   EXPECT_EQ(outH.raw(1), 1);
7001   EXPECT_EQ(outH.raw(2), 2);
7002   EXPECT_EQ(outH.raw(3), 3);
7003 }
7004 
7005 TEST_P(OperatorTest, Rsqrt_FloatTy) {
7006   CHECK_IF_ENABLED();
7007   auto *inp = mod_.createPlaceholder(ElemKind::FloatTy, {4}, "inp", false);
7008   bindings_.allocate(inp)->getHandle<float>() = {1.0, 4.0, 16.0, 64.0};
7009   auto *node = F_->createRsqrt("rsqrt", inp);
7010   auto *save = F_->createSave("save", node);
7011   auto *outT = bindings_.allocate(save->getPlaceholder());
7012   EE_.compile(CompilationMode::Infer);
7013   EE_.run(bindings_);
7014   auto outH = outT->getHandle<float>();
7015   EXPECT_EQ(outH.size(), 4);
7016   EXPECT_FLOAT_EQ(outH.raw(0), 1.0);
7017   EXPECT_FLOAT_EQ(outH.raw(1), 0.5);
7018   EXPECT_FLOAT_EQ(outH.raw(2), 0.25);
7019   EXPECT_FLOAT_EQ(outH.raw(3), 0.125);
7020 }
7021 
7022 TEST_P(OperatorTest, Rsqrt_Int8QTy) {
7023   CHECK_IF_ENABLED();
7024   auto *inp =
7025       mod_.createPlaceholder(ElemKind::Int8QTy, {4}, 1.0, 0, "inp", false);
7026   bindings_.allocate(inp)->getHandle<int8_t>() = {1, 4, 16, 64};
7027   auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {4}, 1.0 / 8.0, 0);
7028   auto *node = F_->createRsqrt("rsqrt", outTy, inp);
7029   auto *save = F_->createSave("save", node);
7030   auto *outT = bindings_.allocate(save->getPlaceholder());
7031   EE_.compile(CompilationMode::Infer);
7032   EE_.run(bindings_);
7033   auto outH = outT->getHandle<int8_t>();
7034   EXPECT_EQ(outH.size(), 4);
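  // rsqrt of {1, 4, 16, 64} is {1, 0.5, 0.25, 0.125}; with an output scale of
  // 1/8 these quantize to {8, 4, 2, 1}.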
7035   EXPECT_EQ(outH.raw(0), 8);
7036   EXPECT_EQ(outH.raw(1), 4);
7037   EXPECT_EQ(outH.raw(2), 2);
7038   EXPECT_EQ(outH.raw(3), 1);
7039 }
7040 
7041 TEST_P(OperatorTest, Reciprocal_FloatTy) {
7042   CHECK_IF_ENABLED();
7043   auto *inp = mod_.createPlaceholder(ElemKind::FloatTy, {4}, "inp", false);
7044   bindings_.allocate(inp)->getHandle<float>() = {1.0, 2.0, 4.0, 8.0};
7045   auto *node = F_->createReciprocal("reciprocal", inp);
7046   auto *save = F_->createSave("save", node);
7047   auto *outT = bindings_.allocate(save->getPlaceholder());
7048   EE_.compile(CompilationMode::Infer);
7049   EE_.run(bindings_);
7050   auto outH = outT->getHandle<float>();
7051   EXPECT_EQ(outH.size(), 4);
7052   EXPECT_FLOAT_EQ(outH.raw(0), 1.0);
7053   EXPECT_FLOAT_EQ(outH.raw(1), 0.5);
7054   EXPECT_FLOAT_EQ(outH.raw(2), 0.25);
7055   EXPECT_FLOAT_EQ(outH.raw(3), 0.125);
7056 }
7057 
7058 TEST_P(OperatorTest, Reciprocal_Int8QTy) {
7059   CHECK_IF_ENABLED();
7060   auto *inp =
7061       mod_.createPlaceholder(ElemKind::Int8QTy, {4}, 1.0, 0, "inp", false);
7062   bindings_.allocate(inp)->getHandle<int8_t>() = {1, 2, 4, 8};
7063   auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {4}, 1.0 / 8.0, 0);
7064   auto *node = F_->createReciprocal("reciprocal", outTy, inp);
7065   auto *save = F_->createSave("save", node);
7066   auto *outT = bindings_.allocate(save->getPlaceholder());
7067   EE_.compile(CompilationMode::Infer);
7068   EE_.run(bindings_);
7069   auto outH = outT->getHandle<int8_t>();
7070   EXPECT_EQ(outH.size(), 4);
7071   EXPECT_EQ(outH.raw(0), 8);
7072   EXPECT_EQ(outH.raw(1), 4);
7073   EXPECT_EQ(outH.raw(2), 2);
7074   EXPECT_EQ(outH.raw(3), 1);
7075 }
7076 
7077 TEST_P(OperatorTest, Sin_FloatTy) {
7078   CHECK_IF_ENABLED();
7079   auto *inp = mod_.createPlaceholder(ElemKind::FloatTy, {4}, "inp", false);
7080   bindings_.allocate(inp)->getHandle<float>() = {-1.0, 0.0, 1.0, 2.0};
7081   auto *node = F_->createSin("sin", inp);
7082   auto *save = F_->createSave("save", node);
7083   auto *outT = bindings_.allocate(save->getPlaceholder());
7084   EE_.compile(CompilationMode::Infer);
7085   EE_.run(bindings_);
7086   auto outH = outT->getHandle<float>();
7087   EXPECT_EQ(outH.size(), 4);
7088   EXPECT_FLOAT_EQ(outH.raw(0), std::sin(-1.0));
7089   EXPECT_FLOAT_EQ(outH.raw(1), std::sin(0.0));
7090   EXPECT_FLOAT_EQ(outH.raw(2), std::sin(1.0));
7091   EXPECT_FLOAT_EQ(outH.raw(3), std::sin(2.0));
7092 }
7093 
7094 TEST_P(OperatorTest, Sin_Int8QTy) {
7095   CHECK_IF_ENABLED();
7096   auto *inp =
7097       mod_.createPlaceholder(ElemKind::Int8QTy, {4}, 1.0, 0, "inp", false);
7098   bindings_.allocate(inp)->getHandle<int8_t>() = {-1, 0, 1, 2};
7099   auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {4}, 1.0 / 127.0, 0);
7100   auto *node = F_->createSin("sin", outTy, inp);
7101   auto *save = F_->createSave("save", node);
7102   auto *outT = bindings_.allocate(save->getPlaceholder());
7103   EE_.compile(CompilationMode::Infer);
7104   EE_.run(bindings_);
7105   auto outH = outT->getHandle<int8_t>();
7106   EXPECT_EQ(outH.size(), 4);
7107   EXPECT_EQ(outH.raw(0), static_cast<int8_t>(std::round(std::sin(-1) * 127)));
7108   EXPECT_EQ(outH.raw(1), static_cast<int8_t>(std::round(std::sin(0) * 127)));
7109   EXPECT_EQ(outH.raw(2), static_cast<int8_t>(std::round(std::sin(1) * 127)));
7110   EXPECT_EQ(outH.raw(3), static_cast<int8_t>(std::round(std::sin(2) * 127)));
7111 }
7112 
7113 TEST_P(OperatorTest, Cos_FloatTy) {
7114   CHECK_IF_ENABLED();
7115   auto *inp = mod_.createPlaceholder(ElemKind::FloatTy, {4}, "inp", false);
7116   bindings_.allocate(inp)->getHandle<float>() = {-1.0, 0.0, 1.0, 2.0};
7117   auto *node = F_->createCos("cos", inp);
7118   auto *save = F_->createSave("save", node);
7119   auto *outT = bindings_.allocate(save->getPlaceholder());
7120   EE_.compile(CompilationMode::Infer);
7121   EE_.run(bindings_);
7122   auto outH = outT->getHandle<float>();
7123   EXPECT_EQ(outH.size(), 4);
7124   EXPECT_FLOAT_EQ(outH.raw(0), std::cos(-1.0));
7125   EXPECT_FLOAT_EQ(outH.raw(1), std::cos(0.0));
7126   EXPECT_FLOAT_EQ(outH.raw(2), std::cos(1.0));
7127   EXPECT_FLOAT_EQ(outH.raw(3), std::cos(2.0));
7128 }
7129 
7130 TEST_P(OperatorTest, Cos_Int8QTy) {
7131   CHECK_IF_ENABLED();
7132   auto *inp =
7133       mod_.createPlaceholder(ElemKind::Int8QTy, {4}, 1.0, 0, "inp", false);
7134   bindings_.allocate(inp)->getHandle<int8_t>() = {-1, 0, 1, 2};
7135   auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {4}, 1.0 / 127.0, 0);
7136   auto *node = F_->createCos("cos", outTy, inp);
7137   auto *save = F_->createSave("save", node);
7138   auto *outT = bindings_.allocate(save->getPlaceholder());
7139   EE_.compile(CompilationMode::Infer);
7140   EE_.run(bindings_);
7141   auto outH = outT->getHandle<int8_t>();
7142   EXPECT_EQ(outH.size(), 4);
7143   EXPECT_EQ(outH.raw(0), static_cast<int8_t>(std::round(std::cos(-1) * 127)));
7144   EXPECT_EQ(outH.raw(1), static_cast<int8_t>(std::round(std::cos(0) * 127)));
7145   EXPECT_EQ(outH.raw(2), static_cast<int8_t>(std::round(std::cos(1) * 127)));
7146   EXPECT_EQ(outH.raw(3), static_cast<int8_t>(std::round(std::cos(2) * 127)));
7147 }
7148 
7149 /// Helper to test CmpNEQ using \p elemKind.
7150 template <typename ElemType>
7151 static void testCmpNEQ(glow::PlaceholderBindings &bindings, glow::Module &mod,
7152                        glow::Function *F, glow::ExecutionEngine &EE,
7153                        ElemKind elemKind) {
7154   auto *LHS =
7155       createPlaceholderConditionallyQuantized(mod, elemKind, {2}, "LHS", false);
7156   auto *RHS =
7157       createPlaceholderConditionallyQuantized(mod, elemKind, {2}, "RHS", false);
7158   bindings.allocate(LHS)->getHandle<ElemType>() = {1, 1};
7159   bindings.allocate(RHS)->getHandle<ElemType>() = {1, 2};
7160   auto *node = F->createCmpNEQ("cmpNEQ", LHS, RHS);
7161   auto *save = F->createSave("save", node);
7162   auto *outT = bindings.allocate(save->getPlaceholder());
7163   EE.compile(CompilationMode::Infer);
7164   EE.run(bindings);
7165   auto outH = outT->getHandle<bool>();
7166   EXPECT_EQ(outH.size(), 2);
7167   EXPECT_EQ(outH.raw(0), false);
7168   EXPECT_EQ(outH.raw(1), true);
7169 }
7170 
7171 TEST_P(OperatorTest, CmpNEQ_FloatTy) {
7172   CHECK_IF_ENABLED();
7173   testCmpNEQ<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
7174 }
7175 
7176 TEST_P(OperatorTest, CmpNEQ_Int8QTy) {
7177   CHECK_IF_ENABLED();
7178   testCmpNEQ<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
7179 }
7180 
7181 TEST_P(OperatorTest, CmpNEQ_Int32ITy) {
7182   CHECK_IF_ENABLED();
7183   testCmpNEQ<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy);
7184 }
7185 
7186 TEST_P(OperatorTest, CmpNEQ_Int64ITy) {
7187   CHECK_IF_ENABLED();
7188   testCmpNEQ<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy);
7189 }
7190 
7191 /// Helper to test CmpGT using \p elemKind.
7192 template <typename ElemType>
7193 static void testCmpGT(glow::PlaceholderBindings &bindings, glow::Module &mod,
7194                       glow::Function *F, glow::ExecutionEngine &EE,
7195                       ElemKind elemKind) {
7196   auto *LHS =
7197       createPlaceholderConditionallyQuantized(mod, elemKind, {3}, "LHS", false);
7198   auto *RHS =
7199       createPlaceholderConditionallyQuantized(mod, elemKind, {3}, "RHS", false);
7200   bindings.allocate(LHS)->getHandle<ElemType>() = {1, 1, 2};
7201   bindings.allocate(RHS)->getHandle<ElemType>() = {1, 2, 1};
7202   auto *node = F->createCmpGT("cmpGT", LHS, RHS);
7203   auto *save = F->createSave("save", node);
7204   auto *outT = bindings.allocate(save->getPlaceholder());
7205   EE.compile(CompilationMode::Infer);
7206   EE.run(bindings);
7207   auto outH = outT->getHandle<bool>();
7208   EXPECT_EQ(outH.size(), 3);
7209   EXPECT_EQ(outH.raw(0), false);
7210   EXPECT_EQ(outH.raw(1), false);
7211   EXPECT_EQ(outH.raw(2), true);
7212 }
7213 
7214 TEST_P(OperatorTest, CmpGT_FloatTy) {
7215   CHECK_IF_ENABLED();
7216   testCmpGT<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
7217 }
7218 
7219 TEST_P(OperatorTest, CmpGT_Int8QTy) {
7220   CHECK_IF_ENABLED();
7221   testCmpGT<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
7222 }
7223 
7224 TEST_P(OperatorTest, CmpGT_Int32ITy) {
7225   CHECK_IF_ENABLED();
7226   testCmpGT<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy);
7227 }
7228 
7229 TEST_P(OperatorTest, CmpGT_Int64ITy) {
7230   CHECK_IF_ENABLED();
7231   testCmpGT<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy);
7232 }
7233 
7234 /// Helper to test CmpGTE using \p elemKind.
7235 template <typename ElemType>
7236 static void testCmpGTE(glow::PlaceholderBindings &bindings, glow::Module &mod,
7237                        glow::Function *F, glow::ExecutionEngine &EE,
7238                        ElemKind elemKind) {
7239   auto *LHS =
7240       createPlaceholderConditionallyQuantized(mod, elemKind, {3}, "LHS", false);
7241   auto *RHS =
7242       createPlaceholderConditionallyQuantized(mod, elemKind, {3}, "RHS", false);
7243   bindings.allocate(LHS)->getHandle<ElemType>() = {1, 1, 2};
7244   bindings.allocate(RHS)->getHandle<ElemType>() = {1, 2, 1};
7245   auto *node = F->createCmpGTE("cmpGTE", LHS, RHS);
7246   auto *save = F->createSave("save", node);
7247   auto *outT = bindings.allocate(save->getPlaceholder());
7248   EE.compile(CompilationMode::Infer);
7249   EE.run(bindings);
7250   auto outH = outT->getHandle<bool>();
7251   EXPECT_EQ(outH.size(), 3);
7252   EXPECT_EQ(outH.raw(0), true);
7253   EXPECT_EQ(outH.raw(1), false);
7254   EXPECT_EQ(outH.raw(2), true);
7255 }
7256 
7257 TEST_P(OperatorTest, CmpGTE_FloatTy) {
7258   CHECK_IF_ENABLED();
7259   testCmpGTE<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
7260 }
7261 
7262 TEST_P(OperatorTest, CmpGTE_Int8QTy) {
7263   CHECK_IF_ENABLED();
7264   testCmpGTE<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
7265 }
7266 
7267 TEST_P(OperatorTest, CmpGTE_Int32ITy) {
7268   CHECK_IF_ENABLED();
7269   testCmpGTE<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy);
7270 }
7271 
7272 TEST_P(OperatorTest, CmpGTE_Int64ITy) {
7273   CHECK_IF_ENABLED();
7274   testCmpGTE<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy);
7275 }
7276 
7277 TEST_P(OperatorTest, simpleCmpSelectPredication) {
7278   CHECK_IF_ENABLED();
7279 
7280   // A simple test that checks predication of some values using the
7281   // compare-select pair of instructions. Keep doubling some values
7282   // until some condition is met.
7283 
7284   auto *inputs =
7285       mod_.createPlaceholder(ElemKind::FloatTy, {10}, "inputs", false);
7286   auto *counters =
7287       mod_.createPlaceholder(ElemKind::FloatTy, {10}, "counters", false);
7288 
7289   bindings_.allocate(counters)->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
7290   bindings_.allocate(inputs)->getHandle().clear(1);
7291 
7292   Node *cnt = counters;
7293   NodeValue data = inputs;
7294   Node *const1 = F_->createSplat("const1", counters->getType(), 1.0);
7295   Node *const0 = F_->createSplat("const0", counters->getType(), 0.0);
7296 
7297   for (int i = 0; i < 10; i++) {
7298     cnt = F_->createSub("sub1", cnt, const1);
7299     Node *pred = F_->createCmpLTE("cmp", const0, cnt);
7300 
7301     Node *const2 = F_->createSplat("const2", data.getType(), 2.0);
7302     Node *newData = F_->createMul("mul2x", data, const2);
7303 
7304     data = F_->createSelect("select", pred, newData, data);
7305   }
7306 
7307   auto *SN = F_->createSave("ret", data);
7308   bindings_.allocate(SN->getPlaceholder());
7309 
7310   EE_.compile(CompilationMode::Infer);
7311   EE_.run(bindings_);
7312 
7313   auto H = bindings_.get(SN->getPlaceholder())->getHandle();
7314   ASSERT_NEAR(H.at(0), 1, 0.001);
7315   ASSERT_NEAR(H.at(1), 2, 0.001);
7316   ASSERT_NEAR(H.at(2), 4, 0.001);
7317   ASSERT_NEAR(H.at(3), 8, 0.001);
7318   ASSERT_NEAR(H.at(4), 16, 0.001);
7319   ASSERT_NEAR(H.at(5), 32, 0.001);
7320   ASSERT_NEAR(H.at(6), 64, 0.001);
7321   ASSERT_NEAR(H.at(7), 128, 0.001);
7322   ASSERT_NEAR(H.at(8), 256, 0.001);
7323   ASSERT_NEAR(H.at(9), 512, 0.001);
7324 }
7325 
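/// Test that a stack of FullyConnected + ReLU nodes with predicates attached
/// compiles and runs.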
7326 TEST_P(OperatorTest, simplePredication) {
7327   CHECK_IF_ENABLED();
7328 
7329   auto *inputs =
7330       mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "inputs", false);
7331   auto *counters =
7332       mod_.createPlaceholder(ElemKind::FloatTy, {10}, "counters", false);
7333 
7334   bindings_.allocate(counters)->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
7335   bindings_.allocate(inputs)->getHandle().randomize(-10, 10, mod_.getPRNG());
7336 
7337   Node *C5 = F_->createSplat("C5", counters->getType(), 5.0);
7338   Node *pred = F_->createCmpLTE("cmp", C5, counters);
7339 
7340   auto *FC0 = F_->createFullyConnected(bindings_, "FC0", inputs, 128);
7341   auto *RL0 = F_->createRELU("RL0", FC0);
7342   auto *FC1 = F_->createFullyConnected(bindings_, "FC1", RL0, 64);
7343   auto *RL1 = F_->createRELU("RL1", FC1);
7344   auto *FC2 = F_->createFullyConnected(bindings_, "FC2", RL1, 32);
7345   auto *RL2 = F_->createRELU("RL2", FC2);
7346 
7347   auto *save = F_->createSave("ret", RL2);
7348   bindings_.allocate(save->getPlaceholder());
7349 
7350   FC0->setPredicate(pred);
7351   FC1->setPredicate(pred);
7352   FC2->setPredicate(pred);
7353 
7354   ::glow::convertPlaceholdersToConstants(
7355       F_, bindings_, {inputs, counters, save->getPlaceholder()});
7356   EE_.compile(CompilationMode::Infer);
7357   EE_.run(bindings_);
7358 }
7359 
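/// Test ChannelShuffle: 12 channels split into 3 groups along axis 1 are
/// interleaved, emitting one channel from each group in turn.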
7360 TEST_P(OperatorTest, ChannelShuffle) {
7361   CHECK_IF_ENABLED();
7362 
7363   auto *inputs =
7364       mod_.createPlaceholder(ElemKind::FloatTy, {1, 12, 1, 1}, "inputs", false);
7365   bindings_.allocate(inputs)->getHandle() = {1, 2, 3, 4,  5,  6,
7366                                              7, 8, 9, 10, 11, 12};
7367 
7368   Node *CS = F_->createChannelShuffle("CS", inputs, 3, 1);
7369   SaveNode *S = F_->createSave("save", CS);
7370   bindings_.allocate(S->getPlaceholder());
7371 
7372   EE_.compile(CompilationMode::Infer);
7373   EE_.run(bindings_);
7374 
7375   auto results = bindings_.get(S->getPlaceholder())->getHandle();
7376 
7377   EXPECT_EQ(results.size(), 12);
7378   std::vector<float> expected = {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12};
7379   for (dim_t i = 0; i < expected.size(); i++)
7380     EXPECT_FLOAT_EQ(results.at({0, i, 0, 0}), expected[i]);
7381 }
7382 
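/// Test Squeeze of a single axis: dropping axis 0 reshapes {1, 2, 1, 5} to
/// {2, 1, 5}.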
7383 TEST_P(OperatorTest, SqueezeOneAxis) {
7384   CHECK_IF_ENABLED();
7385 
7386   auto *inputs =
7387       mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 1, 5}, "inputs", false);
7388   bindings_.allocate(inputs)->getHandle() = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
7389 
7390   std::vector<float> expectedValues = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
7391 
7392   std::vector<dim_t> axes = {0};
7393   Node *SQZ = F_->createSqueeze("SQZ", inputs, axes);
7394   SaveNode *S = F_->createSave("save", SQZ);
7395   bindings_.allocate(S->getPlaceholder());
7396 
7397   EE_.compile(CompilationMode::Infer);
7398   EE_.run(bindings_);
7399 
7400   auto results = bindings_.get(S->getPlaceholder())->getHandle();
7401   std::vector<dim_t> expectedDims = {2, 1, 5};
7402   EXPECT_TRUE(results.dims().vec() == expectedDims);
7403   for (size_t i = 0; i < 10; i++)
7404     EXPECT_FLOAT_EQ(results.raw(i), expectedValues[i]);
7405 }
7406 
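/// Test Squeeze of multiple axes: dropping axes 0 and 2 reshapes {1, 2, 1, 5}
/// to {2, 5}.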
7407 TEST_P(OperatorTest, SqueezeTwoAxes) {
7408   CHECK_IF_ENABLED();
7409 
7410   auto mod = &EE_.getModule();
7411   auto *inputs =
7412       mod->createPlaceholder(ElemKind::FloatTy, {1, 2, 1, 5}, "inputs", false);
7413   bindings_.allocate(inputs)->getHandle() = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
7414 
7415   std::vector<float> expectedValues = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
7416 
7417   std::vector<dim_t> axes = {0, 2, 2};
7418   Node *SQZ = F_->createSqueeze("SQZ", inputs, axes);
7419   SaveNode *S = F_->createSave("save", SQZ);
7420   bindings_.allocate(S->getPlaceholder());
7421 
7422   EE_.compile(CompilationMode::Infer);
7423   EE_.run(bindings_);
7424 
7425   auto results = bindings_.get(S->getPlaceholder())->getHandle();
7426   std::vector<dim_t> expectedDims = {2, 5};
7427   EXPECT_TRUE(results.dims().vec() == expectedDims);
7428   for (size_t i = 0; i < 10; i++)
7429     EXPECT_FLOAT_EQ(results.raw(i), expectedValues[i]);
7430 }
7431 
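/// Test Squeeze down to a scalar followed by ExpandDims back to a
/// single-element tensor.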
7432 TEST_P(OperatorTest, SqueezeExpand) {
7433   CHECK_IF_ENABLED();
7434 
7435   auto mod = &EE_.getModule();
7436   auto *inputs =
7437       mod->createPlaceholder(ElemKind::FloatTy, {1, 2, 1, 5}, "inputs", false);
7438   bindings_.allocate(inputs)->getHandle() = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
7439   auto *emptyInput =
7440       mod->createPlaceholder(ElemKind::FloatTy, {1}, "emptyInput", false);
7441   bindings_.allocate(emptyInput)->getHandle() = {42.0};
7442 
7443   std::vector<float> expectedValues = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
7444 
7445   std::vector<dim_t> axes = {0};
7446   Node *SQZ = F_->createSqueeze("SQZ", emptyInput, axes);
7447   SaveNode *S1 = F_->createSave("save", SQZ);
7448   Node *UnSQZ = F_->createExpandDims("UnSQZ", SQZ, axes);
7449   SaveNode *S2 = F_->createSave("save", UnSQZ);
7450 
7451   bindings_.allocate(S1->getPlaceholder());
7452   bindings_.allocate(S2->getPlaceholder());
7453 
7454   EE_.compile(CompilationMode::Infer);
7455   EE_.run(bindings_);
7456 
7457   auto res1 = bindings_.get(S1->getPlaceholder())->getHandle();
7458   EXPECT_TRUE(res1.dims().vec() == std::vector<dim_t>());
7459   EXPECT_FLOAT_EQ(res1.raw(0), 42.0);
7460   auto res2 = bindings_.get(S2->getPlaceholder())->getHandle();
7461   EXPECT_TRUE(res2.dims().vec() == std::vector<dim_t>(1, 1));
7462   EXPECT_FLOAT_EQ(res2.raw(0), 42.0);
7463 }
7464 
7465 /// Helper to test ExpandDims using \p DTy.
7466 template <typename DataType>
7467 static void testExpandDims(glow::PlaceholderBindings &bindings,
7468                            glow::Module &mod, glow::Function *F,
7469                            glow::ExecutionEngine &EE, ElemKind DTy) {
7470   auto *inputs = createPlaceholderConditionallyQuantized(mod, DTy, {2, 2},
7471                                                          "inputs", false);
7472   auto IH = bindings.allocate(inputs)->getHandle<DataType>();
7473   IH = {1, 2, 3, 4};
7474 
7475   // This should be uniqued and sorted, so should become {0, 1, 3, 5}.
7476   std::vector<dim_t> axes = {3, 0, 5, 1, 3};
7477   Node *EDN = F->createExpandDims("expand", inputs, axes);
7478   SaveNode *S = F->createSave("save", EDN);
7479   bindings.allocate(S->getPlaceholder());
7480 
7481   EE.compile(CompilationMode::Infer);
7482   EE.run(bindings);
7483 
7484   // Expected dims based on the axes above: a new dimension of 1 is inserted
7485   // at every unique axis location, giving the output tensor shape below.
7486   std::vector<dim_t> expectedDims = {1, 1, 2, 1, 2, 1};
7487   auto results = bindings.get(S->getPlaceholder())->getHandle<DataType>();
7488   EXPECT_TRUE(results.dims().vec() == expectedDims);
7489 
7490   // The data should be the same, as this was just a reshape.
7491   for (size_t i = 0; i < 4; i++) {
7492     EXPECT_FLOAT_EQ(results.raw(i), IH.raw(i));
7493   }
7494 }
7495 
7496 /// Check that the expand dims operator works, which is implemented with a
7497 /// reshape, in FloatTy.
7498 TEST_P(OperatorTest, ExpandDims_Float) {
7499   CHECK_IF_ENABLED();
7500   testExpandDims<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
7501 }
7502 
7503 /// Check that the expand dims operator works, which is implemented with a
7504 /// reshape, in Float16Ty.
7505 TEST_P(OperatorTest, ExpandDims_Float16) {
7506   CHECK_IF_ENABLED();
7507   testExpandDims<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
7508 }
7509 
7510 /// Check that the expand dims operator works, which is implemented with a
7511 /// reshape, in BFloat16Ty.
7512 TEST_P(OperatorTest, ExpandDims_BFloat16) {
7513   CHECK_IF_ENABLED();
7514   testExpandDims<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty);
7515 }
7516 
7517 /// Check that the expand dims operator works, which is implemented with a
7518 /// reshape, in Int8QTy.
7519 TEST_P(OperatorTest, ExpandDims_Int8) {
7520   CHECK_IF_ENABLED();
7521   testExpandDims<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
7522 }
7523 
7524 /// Helper to test Split using \p DTy.
7525 template <typename DataType>
7526 static void testSplit(glow::PlaceholderBindings &bindings, glow::Module &mod,
7527                       glow::Function *F, glow::ExecutionEngine &EE,
7528                       ElemKind DTy) {
7529   auto *inputs = createPlaceholderConditionallyQuantized(mod, DTy, {1, 2, 6},
7530                                                          "inputs", false);
7531   bindings.allocate(inputs)->getHandle<DataType>() = {1, 2, 3, 4,  5,  6,
7532                                                       7, 8, 9, 10, 11, 12};
7533 
7534   std::vector<SliceNode *> outputs1;
7535   F->createSplit("Split1", inputs, /*outputNum = */ 2, /*axis = */ 2,
7536                  /*split = */ {}, outputs1);
7537   std::vector<SliceNode *> outputs2;
7538   F->createSplit("Split2", inputs, /*outputNum = */ 2, /*axis = */ 2,
7539                  /*split = */ {2, 4}, outputs2);
7540   auto *S1 = F->createSave("save1", outputs1[0]);
7541   auto *S2 = F->createSave("save2", outputs1[1]);
7542   auto *S3 = F->createSave("save3", outputs2[0]);
7543   auto *S4 = F->createSave("save4", outputs2[1]);
7544 
7545   auto *result1 = bindings.allocate(S1->getPlaceholder());
7546   auto *result2 = bindings.allocate(S2->getPlaceholder());
7547   auto *result3 = bindings.allocate(S3->getPlaceholder());
7548   auto *result4 = bindings.allocate(S4->getPlaceholder());
7549 
7550   EE.compile(CompilationMode::Infer);
7551   EE.run(bindings);
7552 
7553   Tensor expected1 = createTensorConditionallyQuantized(DTy, {1, 2, 3});
7554   expected1.getHandle<DataType>() = {1, 2, 3, 7, 8, 9};
7555   EXPECT_TRUE(result1->isEqual(expected1));
7556 
7557   Tensor expected2 = createTensorConditionallyQuantized(DTy, {1, 2, 3});
7558   expected2.getHandle<DataType>() = {4, 5, 6, 10, 11, 12};
7559   EXPECT_TRUE(result2->isEqual(expected2));
7560 
7561   Tensor expected3 = createTensorConditionallyQuantized(DTy, {1, 2, 2});
7562   expected3.getHandle<DataType>() = {1, 2, 7, 8};
7563   EXPECT_TRUE(result3->isEqual(expected3));
7564 
7565   Tensor expected4 = createTensorConditionallyQuantized(DTy, {1, 2, 4});
7566   expected4.getHandle<DataType>() = {3, 4, 5, 6, 9, 10, 11, 12};
7567   EXPECT_TRUE(result4->isEqual(expected4));
7568 }
7569 
7570 /// Test that Split is correctly supported in FloatTy.
7571 TEST_P(OperatorTest, Split_Float) {
7572   CHECK_IF_ENABLED();
7573   testSplit<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
7574 }
7575 
7576 /// Test that Split is correctly supported in Float16Ty.
7577 TEST_P(OperatorTest, Split_Float16) {
7578   CHECK_IF_ENABLED();
7579   testSplit<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
7580 }
7581 
7582 /// Test that Split is correctly supported in BFloat16Ty.
7583 TEST_P(OperatorTest, Split_BFloat16) {
7584   CHECK_IF_ENABLED();
7585   testSplit<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty);
7586 }
7587 
7588 /// Test that Split is correctly supported in Int8QTy.
7589 TEST_P(OperatorTest, Split_Int8) {
7590   CHECK_IF_ENABLED();
7591   testSplit<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
7592 }
7593 
7594 /// Test Relu with Int8QTy.
7595 TEST_P(OperatorTest, Relu_Int8) {
7596   CHECK_IF_ENABLED();
7597 
7598   std::vector<float> inputVals = {-2.0, -1.0, 0.0, 1.0, 2.0};
7599   dim_t size = inputVals.size();
7600   const float inputScale = 1.0;
7601   const int32_t inputOffset = 5;
7602   const float outputScale = 0.5;
7603   const int32_t outputOffset = -128;
7604 
7605   auto *inputTy =
7606       mod_.uniqueType(ElemKind::Int8QTy, {size}, inputScale, inputOffset);
7607   auto *outputTy =
7608       mod_.uniqueType(ElemKind::Int8QTy, {size}, outputScale, outputOffset);
7609   auto *input = mod_.createPlaceholder(inputTy, "input", false);
7610   auto *relu = F_->createRELU("relu", input, outputTy);
7611   auto *dequantize =
7612       F_->createDequantize("dequantize", relu, ElemKind::FloatTy);
7613   auto *save = F_->createSave("save", dequantize);
7614   bindings_.allocate(mod_.getPlaceholders());
7615 
7616   auto inputH = bindings_.get(input)->getHandle<int8_t>();
7617   for (dim_t idx = 0; idx < size; idx++) {
7618     inputH.raw(idx) =
7619         quantization::quantize(inputVals[idx], {inputScale, inputOffset});
7620   }
7621 
7622   EE_.compile(CompilationMode::Infer);
7623   EE_.run(bindings_);
7624 
7625   auto outputH = bindings_.get(save->getPlaceholder())->getHandle();
7626   for (dim_t idx = 0; idx < size; idx++) {
7627     float expectedValue = std::max(0.0f, inputVals[idx]);
7628     EXPECT_EQ(expectedValue, outputH.raw(idx));
7629   }
7630 }
7631 
7632 /// Test Clip with Int8QTy.
7633 TEST_P(OperatorTest, Clip_Int8) {
7634   CHECK_IF_ENABLED();
7635 
7636   std::vector<float> inputVals = {-3, -2, -1, 0, 1, 2, 3, 4};
7637   float clipMin = -2.0;
7638   float clipMax = 3.0;
7639   dim_t size = inputVals.size();
7640   const float inputScale = 1.0;
7641   const int32_t inputOffset = 5;
7642   const float outputScale = 0.5;
7643   const int32_t outputOffset = -3;
7644 
7645   auto *inputTy =
7646       mod_.uniqueType(ElemKind::Int8QTy, {size}, inputScale, inputOffset);
7647   auto *outputTy =
7648       mod_.uniqueType(ElemKind::Int8QTy, {size}, outputScale, outputOffset);
7649   auto *input = mod_.createPlaceholder(inputTy, "input", false);
7650   auto *relu = F_->createClip("clip", input, outputTy, clipMin, clipMax);
7651   auto *dequantize =
7652       F_->createDequantize("dequantize", relu, ElemKind::FloatTy);
7653   auto *save = F_->createSave("save", dequantize);
7654   bindings_.allocate(mod_.getPlaceholders());
7655 
7656   auto inputH = bindings_.get(input)->getHandle<int8_t>();
7657   for (dim_t idx = 0; idx < size; idx++) {
7658     inputH.raw(idx) =
7659         quantization::quantize(inputVals[idx], {inputScale, inputOffset});
7660   }
7661 
7662   EE_.compile(CompilationMode::Infer);
7663   EE_.run(bindings_);
7664 
7665   auto outputH = bindings_.get(save->getPlaceholder())->getHandle();
7666   for (dim_t idx = 0; idx < size; idx++) {
7667     float expectedValue = std::min(clipMax, std::max(clipMin, inputVals[idx]));
7668     EXPECT_EQ(expectedValue, outputH.raw(idx));
7669   }
7670 }
7671 
7672 /// Verify quantized splats work correctly (add 0 to it to ensure constant
7673 /// folding doesn't make this test meaningless).
7674 TEST_P(OperatorTest, IntSplat) {
7675   CHECK_IF_ENABLED();
7676 
7677   const float splatValue = 10;
7678   const float scale = 1.0;
7679   const int32_t offset = 5;
7680   const dim_t size = 3;
7681 
7682   auto *in = mod_.createPlaceholder(ElemKind::Int8QTy, {size}, scale, offset,
7683                                     "in", false);
7684   auto splatTy = mod_.uniqueType(ElemKind::Int8QTy, {size}, scale, offset);
7685   auto *splat = F_->createSplat("splat", splatTy, splatValue);
7686   auto *add = F_->createAdd("add", in, splat);
7687   auto *dequantize = F_->createDequantize("dequantize", add, ElemKind::FloatTy);
7688   auto *save = F_->createSave("save", dequantize);
7689 
7690   bindings_.allocate(in)->zero();
7691   auto resultH = bindings_.allocate(save->getPlaceholder())->getHandle();
7692 
7693   EE_.compile(CompilationMode::Infer);
7694   EE_.run(bindings_);
7695 
7696   for (dim_t i = 0; i < resultH.size(); i++) {
7697     EXPECT_EQ(splatValue, resultH.raw(i));
7698   }
7699 }
7700 
7701 /// Verify fp16 splats work correctly (add 0 to it to ensure constant
7702 /// folding doesn't make this test meaningless).
7703 TEST_P(OperatorTest, Fp16Splat) {
7704   CHECK_IF_ENABLED();
7705 
7706   const float splatValue = 10;
7707   const dim_t size = 3;
7708 
7709   auto *in = mod_.createPlaceholder(ElemKind::Float16Ty, {size}, "in", false);
7710   auto splatTy = mod_.uniqueType(ElemKind::Float16Ty, {size});
7711   auto *splat = F_->createSplat("splat", splatTy, splatValue);
7712   auto *add = F_->createAdd("add", in, splat);
7713   auto *save = F_->createSave("save", add);
7714 
7715   bindings_.allocate(in)->zero();
7716   auto resultH =
7717       bindings_.allocate(save->getPlaceholder())->getHandle<float16_t>();
7718 
7719   EE_.compile(CompilationMode::Infer);
7720   EE_.run(bindings_);
7721 
7722   for (dim_t i = 0; i < resultH.size(); i++) {
7723     EXPECT_EQ(float16_t(splatValue), resultH.raw(i));
7724   }
7725 }
7726 
7727 /// Verify bfloat16 splats work correctly (add 0 to it to ensure constant
7728 /// folding doesn't make this test meaningless).
7729 TEST_P(OperatorTest, BFloat16Splat) {
7730   CHECK_IF_ENABLED();
7731 
7732   const float splatValue = 10;
7733   const dim_t size = 3;
7734 
7735   auto *in = mod_.createPlaceholder(ElemKind::BFloat16Ty, {size}, "in", false);
7736   auto splatTy = mod_.uniqueType(ElemKind::BFloat16Ty, {size});
7737   auto *splat = F_->createSplat("splat", splatTy, splatValue);
7738   auto *add = F_->createAdd("add", in, splat);
7739   auto *save = F_->createSave("save", add);
7740 
7741   bindings_.allocate(in)->zero();
7742   auto resultH =
7743       bindings_.allocate(save->getPlaceholder())->getHandle<bfloat16_t>();
7744 
7745   EE_.compile(CompilationMode::Infer);
7746   EE_.run(bindings_);
7747 
7748   for (dim_t i = 0; i < resultH.size(); i++) {
7749     EXPECT_EQ(bfloat16_t(splatValue), resultH.raw(i));
7750   }
7751 }
7752 
7753 /// Simple ConvTranspose sanity test: symmetric kernel, no padding, unit strides, single input channel.
7754 TEST_P(OperatorTest, sanityConvTranspose) {
7755   CHECK_IF_ENABLED();
7756 
7757   float biasVal[2] = {1.1, 2.2};
7758   auto *input =
7759       mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 2, 1}, "input", false);
7760   bindings_.allocate(input)->getHandle() = {2., 3., 4., 5.};
7761 
7762   auto *filter =
7763       mod_.createPlaceholder(ElemKind::FloatTy, {2, 3, 3, 1}, "filter", false);
7764   bindings_.allocate(filter)->getHandle() = {2., 3., 4.,  5., 6.,  7.,
7765                                              8., 9., 10., 3., 4.,  5.,
7766                                              6., 7., 8.,  9., 10., 11.};
7767 
7768   auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {2}, "bias", false);
7769   bindings_.allocate(bias)->getHandle() = biasVal;
7770 
7771   std::pair<dim_t, dim_t> outWH =
7772       calculateConvTransposeOutputDims(2, 2, {3, 3}, {1, 1}, {0, 0, 0, 0});
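  // For unit strides and no padding this follows the usual transposed
  // convolution size formula out = (in - 1) * stride + kernel - pads,
  // i.e. (2 - 1) * 1 + 3 - 0 = 4, giving a 4 x 4 output.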
7773   auto outTy =
7774       mod_.uniqueType(ElemKind::FloatTy, {1, outWH.first, outWH.second, 2});
7775 
7776   ConvTransposeNode *CN =
7777       F_->createConvTranspose("ConvTranspose", input, filter, bias, outTy,
7778                               {3, 3}, {1, 1}, {0, 0, 0, 0}, 1);
7779 
7780   SaveNode *S = F_->createSave("save", CN);
7781   bindings_.allocate(S->getPlaceholder());
7782 
7783   ::glow::convertPlaceholdersToConstants(F_, bindings_,
7784                                          {input, S->getPlaceholder()});
7785   EE_.compile(CompilationMode::Infer);
7786   EE_.run(bindings_);
7787 
7788   auto result = bindings_.get(S->getPlaceholder())->getHandle();
7789   std::vector<dim_t> expectedDims = {1, 4, 4, 2};
7790   ASSERT_TRUE(result.dims().vec() == expectedDims);
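  // Spot-check of the first entries: output (0, 0) receives only the single
  // tap input(0, 0) = 2 against the first filter weight of each output
  // channel, i.e. 2 * 2 = 4 for channel 0 and 2 * 3 = 6 for channel 1, plus
  // the per-channel bias.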
7791   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 0}), 4 + biasVal[0]);
7792   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 0}), 12 + biasVal[0]);
7793   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 0}), 17 + biasVal[0]);
7794   EXPECT_FLOAT_EQ(result.at({0, 0, 3, 0}), 12 + biasVal[0]);
7795   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 0}), 18 + biasVal[0]);
7796   EXPECT_FLOAT_EQ(result.at({0, 1, 1, 0}), 49 + biasVal[0]);
7797   EXPECT_FLOAT_EQ(result.at({0, 1, 2, 0}), 63 + biasVal[0]);
7798   EXPECT_FLOAT_EQ(result.at({0, 1, 3, 0}), 41 + biasVal[0]);
7799   EXPECT_FLOAT_EQ(result.at({0, 2, 0, 0}), 36 + biasVal[0]);
7800   EXPECT_FLOAT_EQ(result.at({0, 2, 1, 0}), 91 + biasVal[0]);
7801   EXPECT_FLOAT_EQ(result.at({0, 2, 2, 0}), 105 + biasVal[0]);
7802   EXPECT_FLOAT_EQ(result.at({0, 2, 3, 0}), 65 + biasVal[0]);
7803   EXPECT_FLOAT_EQ(result.at({0, 3, 0, 0}), 32 + biasVal[0]);
7804   EXPECT_FLOAT_EQ(result.at({0, 3, 1, 0}), 76 + biasVal[0]);
7805   EXPECT_FLOAT_EQ(result.at({0, 3, 2, 0}), 85 + biasVal[0]);
7806   EXPECT_FLOAT_EQ(result.at({0, 3, 3, 0}), 50 + biasVal[0]);
7807 
7808   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 1}), 6 + biasVal[1]);
7809   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 1}), 17 + biasVal[1]);
7810   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 1}), 22 + biasVal[1]);
7811   EXPECT_FLOAT_EQ(result.at({0, 0, 3, 1}), 15 + biasVal[1]);
7812   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 1}), 24 + biasVal[1]);
7813   EXPECT_FLOAT_EQ(result.at({0, 1, 1, 1}), 63 + biasVal[1]);
7814   EXPECT_FLOAT_EQ(result.at({0, 1, 2, 1}), 77 + biasVal[1]);
7815   EXPECT_FLOAT_EQ(result.at({0, 1, 3, 1}), 49 + biasVal[1]);
7816   EXPECT_FLOAT_EQ(result.at({0, 2, 0, 1}), 42 + biasVal[1]);
7817   EXPECT_FLOAT_EQ(result.at({0, 2, 1, 1}), 105 + biasVal[1]);
7818   EXPECT_FLOAT_EQ(result.at({0, 2, 2, 1}), 119 + biasVal[1]);
7819   EXPECT_FLOAT_EQ(result.at({0, 2, 3, 1}), 73 + biasVal[1]);
7820   EXPECT_FLOAT_EQ(result.at({0, 3, 0, 1}), 36 + biasVal[1]);
7821   EXPECT_FLOAT_EQ(result.at({0, 3, 1, 1}), 85 + biasVal[1]);
7822   EXPECT_FLOAT_EQ(result.at({0, 3, 2, 1}), 94 + biasVal[1]);
7823   EXPECT_FLOAT_EQ(result.at({0, 3, 3, 1}), 55 + biasVal[1]);
7824 }
7825 
7826 /// ConvTranspose with multi-channel input/output and asymmetric kernel,
7827 /// strides, pads.
7828 TEST_P(OperatorTest, ConvTransposedAsymmetric) {
7829 
7830   CHECK_IF_ENABLED();
7831 
7832   float biasVal[2] = {1.1, 2.2};
7833   auto bias = mod_.createPlaceholder(ElemKind::FloatTy, {2}, "bias", false);
7834   bindings_.allocate(bias)->getHandle() = biasVal;
7835 
7836   auto *input =
7837       mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 3}, "input", false);
7838   auto IH = bindings_.allocate(input)->getHandle();
7839   for (dim_t i = 0; i < IH.size(); i++) {
7840     IH.raw(i) = i;
7841   }
7842 
7843   auto filter =
7844       mod_.createPlaceholder(ElemKind::FloatTy, {2, 3, 2, 3}, "filter", false);
7845   auto FH = bindings_.allocate(filter)->getHandle();
7846   for (dim_t i = 0; i < FH.size(); i++) {
7847     FH.raw(i) = i * 2;
7848   }
7849 
7850   std::pair<dim_t, dim_t> outWH =
7851       calculateConvTransposeOutputDims(4, 4, {3, 2}, {1, 2}, {0, 3, 1, 2});
7852   auto outTy =
7853       mod_.uniqueType(ElemKind::FloatTy, {1, outWH.first, outWH.second, 2});
7854 
7855   ConvTransposeNode *CN =
7856       F_->createConvTranspose("ConvTranspose", input, filter, bias, outTy,
7857                               {3, 2}, {1, 2}, {0, 3, 1, 2}, 1);
7858 
7859   SaveNode *S = F_->createSave("save", CN);
7860   bindings_.allocate(S->getPlaceholder());
7861 
7862   ::glow::convertPlaceholdersToConstants(F_, bindings_,
7863                                          {input, S->getPlaceholder()});
7864   EE_.compile(CompilationMode::Infer);
7865   EE_.run(bindings_);
7866   auto result = bindings_.get(S->getPlaceholder())->getHandle();
7867   std::vector<dim_t> expectedDims = {1, 5, 3, 2};
7868   ASSERT_TRUE(result.dims().vec() == expectedDims);
7869   // Expected values come from onnxruntime without bias, so bias is added when comparing.
7870   std::vector<float> expected = {
7871       100,  532,   46,   802,   172,  928,   632,  2792,  416,  3224,
7872       884,  3692,  2028, 7212,  1542, 7698,  2568, 8724,  4188, 13260,
7873       3054, 13098, 4728, 14772, 5096, 12440, 4232, 12224, 5564, 13556};
7874   for (dim_t i = 0; i < result.size(); i++) {
7875     float exp = expected[i] + biasVal[i % 2];
7876     EXPECT_FLOAT_EQ(result.raw(i), exp);
7877   }
7878 }
7879 
7880 /// Compare ConvTranspose with equivalent Convolution, no strides.
7881 /// TODO - need version with Strides (dilate input).
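/// Background for the helper below (a sketch of the identity being checked):
/// for stride 1, a ConvTranspose with kernel K and padding p is numerically
/// equivalent to a regular Convolution with the spatially flipped filter and
/// padding K - 1 - p. The helper therefore feeds the ConvTranspose a flipped
/// copy of the Convolution's filter, pads the Convolution with
/// kernel - pad - 1, and expects elementwise identical outputs.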
7882 template <unsigned_t kernel, unsigned_t stride, unsigned_t pad, unsigned_t idim>
7883 static void convTransposeConvCompare(glow::PlaceholderBindings &bindings,
7884                                      glow::Module &mod, glow::Function *F,
7885                                      glow::ExecutionEngine &EE) {
7886   unsigned_t Cpad = kernel - pad - 1;
7887   llvm::SmallVector<unsigned_t, 4> pads = {pad, pad, pad, pad};
7888   llvm::SmallVector<unsigned_t, 4> Cpads = {Cpad, Cpad, Cpad, Cpad};
7889   llvm::SmallVector<unsigned_t, 2> kernels = {kernel, kernel};
7890   llvm::SmallVector<unsigned_t, 2> strides = {stride, stride};
7891 
7892   auto *input = mod.createPlaceholder(ElemKind::FloatTy, {1, idim, idim, 1},
7893                                       "input", false);
7894   bindings.allocate(input)->getHandle().randomize(-10.0, 10.0, mod.getPRNG());
7895 
7896   auto *filterConv = mod.createPlaceholder(
7897       ElemKind::FloatTy, {1, kernel, kernel, 1}, "filterC", false);
7898   bindings.allocate(filterConv)
7899       ->getHandle()
7900       .randomize(-10.0, 10.0, mod.getPRNG());
7901   auto FCH = bindings.get(filterConv)->getHandle();
7902 
7903   auto *filterConvTr = mod.createPlaceholder(
7904       ElemKind::FloatTy, {1, kernel, kernel, 1}, "filterD", false);
7905   auto FDH = bindings.allocate(filterConvTr)->getHandle();
7906   for (dim_t i = 0; i < kernel * kernel; i++) {
7907     FDH.raw(i) = FCH.raw(kernel * kernel - i - 1);
7908   }
7909 
7910   auto *bias = mod.createPlaceholder(ElemKind::FloatTy, {1}, "bias", false);
7911   bindings.allocate(bias)->zero();
7912 
7913   std::pair<dim_t, dim_t> outHW =
7914       calculateConvTransposeOutputDims(idim, idim, kernels, strides, pads);
7915 
7916   auto outTy =
7917       mod.uniqueType(ElemKind::FloatTy, {1, outHW.first, outHW.second, 1});
7918 
7919   ConvolutionNode *CN = F->createConv("conv", input, filterConv, bias, outTy,
7920                                       kernels, strides, Cpads, 1, 1);
7921   ConvTransposeNode *DN =
7922       F->createConvTranspose("ConvTranspose", input, filterConvTr, bias, outTy,
7923                              kernels, strides, pads, 1, 1);
7924 
7925   SaveNode *SC = F->createSave("saveC", CN);
7926   bindings.allocate(SC->getPlaceholder());
7927 
7928   SaveNode *SD = F->createSave("saveD", DN);
7929   bindings.allocate(SD->getPlaceholder());
7930 
7931   ::glow::convertPlaceholdersToConstants(
7932       F, bindings, {input, SC->getPlaceholder(), SD->getPlaceholder()});
7933   EE.compile(CompilationMode::Infer);
7934   EE.run(bindings);
7935 
7936   outHW = calculateConvPoolOutputDims(idim, idim, kernels, strides, Cpads, 1);
7937 
7938   auto resultConv = bindings.get(SC->getPlaceholder())->getHandle();
7939   auto resultConvTranspose = bindings.get(SD->getPlaceholder())->getHandle();
7940 
7941   std::vector<dim_t> expectedDims = {1, outHW.first, outHW.second, 1};
7942   ASSERT_TRUE(resultConv.dims().vec() == expectedDims);
7943   ASSERT_TRUE(resultConvTranspose.dims().vec() == expectedDims);
7944 
7945   for (dim_t i = 0; i < outHW.first; i++) {
7946     for (dim_t j = 0; j < outHW.second; j++) {
7947       EXPECT_FLOAT_EQ(static_cast<float>(resultConv.at({0, i, j, 0})),
7948                       static_cast<float>(resultConvTranspose.at({0, i, j, 0})));
7949     }
7950   }
7951 }
7952 
7953 TEST_P(OperatorTest, ConvTransposeConvolutionCompareSimpleK8S1P0I3) {
7954   ENABLED_BACKENDS("Interpreter", "CPU");
7955   convTransposeConvCompare<8, 1, 0, 3>(bindings_, mod_, F_, EE_);
7956 }
7957 
7958 TEST_P(OperatorTest, ConvTransposeConvolutionCompareSimpleK6S1P1I4) {
7959   ENABLED_BACKENDS("Interpreter", "CPU");
7960   convTransposeConvCompare<6, 1, 1, 4>(bindings_, mod_, F_, EE_);
7961 }
7962 
7963 TEST_P(OperatorTest, ConvTransposeConvolutionCompareSimpleK5S1P2I3) {
7964   ENABLED_BACKENDS("Interpreter", "CPU");
7965   convTransposeConvCompare<5, 1, 2, 3>(bindings_, mod_, F_, EE_);
7966 }
7967 
7968 TEST_P(OperatorTest, GroupConvolution) {
7969   CHECK_IF_ENABLED();
7970 
7971   auto *input =
7972       mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 1, 8}, "input", false);
7973   auto IH = bindings_.allocate(input)->getHandle();
7974   for (dim_t i = 0; i < 2 * 8; i++) {
7975     IH.raw(i) = i + 1;
7976   }
7977 
7978   auto filter =
7979       mod_.createPlaceholder(ElemKind::FloatTy, {6, 1, 1, 4}, "filter", false);
7980   auto FH = bindings_.allocate(filter)->getHandle();
7981   for (dim_t i = 0; i < 6; i++)
7982     for (dim_t j = 0; j < 4; j++) {
7983       FH.at({i, 0, 0, j}) = pow(10.0, i);
7984     }
7985 
7986   auto *zeroBias =
7987       mod_.createPlaceholder(ElemKind::FloatTy, {6}, "bias", false);
7988   bindings_.allocate(zeroBias)->zero();
7989 
7990   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 2, 1, 6});
7991 
7992   ConvolutionNode *CN =
7993       F_->createConv("Conv", input, filter, zeroBias, outTy, 1, 1, 0, 2);
7994   SaveNode *S = F_->createSave("save", CN);
7995   bindings_.allocate(S->getPlaceholder());
7996 
7997   ::glow::convertPlaceholdersToConstants(F_, bindings_,
7998                                          {input, S->getPlaceholder()});
7999   EE_.compile(CompilationMode::Infer);
8000   EE_.run(bindings_);
8001 
8002   auto result = bindings_.get(S->getPlaceholder())->getHandle();
8003 
8004   std::vector<dim_t> expectedDims = {1, 2, 1, 6};
8005   ASSERT_TRUE(result.dims().vec() == expectedDims);
8006   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 0}), 1 + 2 + 3 + 4);
8007   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 1}), (1 + 2 + 3 + 4) * 10);
8008   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 2}), (1 + 2 + 3 + 4) * 100);
8009   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 3}), (5 + 6 + 7 + 8) * 1000);
8010   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 4}), (5 + 6 + 7 + 8) * 10000);
8011   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 5}), (5 + 6 + 7 + 8) * 100000);
8012   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 0}), 9 + 10 + 11 + 12);
8013   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 1}), (9 + 10 + 11 + 12) * 10);
8014   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 2}), (9 + 10 + 11 + 12) * 100);
8015   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 3}), (13 + 14 + 15 + 16) * 1000);
8016   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 4}), (13 + 14 + 15 + 16) * 10000);
8017   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 5}), (13 + 14 + 15 + 16) * 100000);
8018 }
8019 
8020 /// Utility function to numerically test ChannelwiseQuantizedConvolution2D
8021 /// against a floating-point Convolution for different parameters.
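/// The elemQKind / biasElemQKind arguments select the quantized precision of
/// the filter and bias. The three boolean flags choose what is handed to
/// createChannelwiseQuantizedConv: filterFloat and biasFloat pass the float
/// constants (so the node quantizes them itself), while biasScalesExplicit
/// passes precomputed biasScales / biasOffsets instead of letting them be
/// derived implicitly. The TEST_CWQCONV instantiations further below encode
/// these three flags in their trailing _XXX suffix (e.g. _FFT).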
8022 static void testChannelwiseQuantizedConv2D(
8023     glow::PlaceholderBindings &bindings, glow::Module &mod, glow::Function *F,
8024     glow::ExecutionEngine &EE, quantization::Schema schema, ElemKind elemQKind,
8025     ElemKind biasElemQKind, bool filterFloat, bool biasFloat,
8026     bool biasScalesExplicit) {
8027 
8028   std::vector<dim_t> inputDims = {5, 10, 10, 4};
8029   std::vector<dim_t> filterDims = {8, 3, 3, 2};
8030   std::vector<dim_t> biasDims = {8};
8031   std::vector<dim_t> outputDims = {5, 6, 6, 8};
8032   std::vector<unsigned_t> kernels = {3, 3};
8033   std::vector<unsigned_t> strides = {1, 1};
8034   std::vector<unsigned_t> pads = {0, 0, 0, 0};
8035   dim_t group = 2;
8036   dim_t dilation = 2;
8037   dim_t qDim = 0;
8038   dim_t qStep = 1;
8039 
8040   // Create input placeholder.
8041   auto *inputF =
8042       mod.createPlaceholder(ElemKind::FloatTy, inputDims, "inputF", false);
8043   bindings.allocate(inputF)->getHandle<float>().randomize(-1.0, 1.0,
8044                                                           mod.getPRNG());
8045 
8046   // Quantize input.
8047   auto inputTQP =
8048       quantization::chooseQuantizationParams({-1.0, 1.0}, schema, elemQKind);
8049   auto *inputQTy =
8050       mod.uniqueType(elemQKind, inputDims, inputTQP.scale, inputTQP.offset);
8051   auto *inputQ = F->createQuantize("inputQ", inputF, inputQTy);
8052 
8053   // Create float filter constant.
8054   auto *filterF = mod.createConstant(ElemKind::FloatTy, filterDims, "filterF");
8055   filterF->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
8056                                                             mod.getPRNG());
8057 
8058   // Create float bias constant.
8059   auto *biasF = mod.createConstant(ElemKind::FloatTy, biasDims, "biasF");
8060   biasF->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
8061                                                           mod.getPRNG());
8062 
8063   // Create quantized filter and filterScales/filterOffsets constants for
8064   // ChannelwiseQuantizedConvolution.
8065   dim_t numChannels = outputDims[3];
8066   Constant *filterQ =
8067       mod.createConstant(elemQKind, filterDims, 1.0, 0, "filterQ");
8068   Constant *filterScales =
8069       mod.createConstant(ElemKind::FloatTy, {numChannels}, "filterScales");
8070   Constant *filterOffsets =
8071       mod.createConstant(ElemKind::Int32ITy, {numChannels}, "filterOffsets");
8072   quantization::getTensorQuantizationParams(
8073       filterF->getPayload(), filterScales->getPayloadMutable(),
8074       filterOffsets->getPayloadMutable(), schema, elemQKind, qDim, qStep);
8075   filterQ->getPayloadMutable() = quantization::quantizeTensor(
8076       filterF->getPayload(), filterScales->getPayload(),
8077       filterOffsets->getPayload(), elemQKind, qDim, qStep);
8078 
8079   // Create quantized bias and biasScales/biasOffsets constants for
8080   // ChannelwiseQuantizedConvolution.
8081   Constant *biasQ =
8082       mod.createConstant(biasElemQKind, {numChannels}, 1.0, 0, "biasQ");
8083   Constant *biasScales =
8084       mod.createConstant(ElemKind::FloatTy, {numChannels}, "biasScales");
8085   Constant *biasOffsets =
8086       mod.createConstant(ElemKind::Int32ITy, {numChannels}, "biasOffsets");
8087   auto biasScalesH = biasScales->getPayload().getHandle<float>();
8088   auto biasOffsetsH = biasOffsets->getPayload().getHandle<int32_t>();
8089   auto filterScalesH = filterScales->getPayload().getHandle<float>();
8090   auto filterOffsetsH = filterOffsets->getPayload().getHandle<int32_t>();
8091   auto inputScale = inputQ->getResult().getType()->getScale();
8092   auto inputOffset = inputQ->getResult().getType()->getOffset();
8093   if (biasScalesExplicit) {
8094     quantization::getTensorQuantizationParams(
8095         biasF->getPayload(), biasScales->getPayloadMutable(),
8096         biasOffsets->getPayloadMutable(), schema, biasElemQKind, qDim, qStep);
8097     for (dim_t idx = 0; idx < numChannels; idx++) {
8098       auto biasTQPNew = specializeBiasQuantizationParams(
8099           {biasScalesH.raw(idx), biasOffsetsH.raw(idx)},
8100           {inputScale, inputOffset},
8101           {filterScalesH.raw(idx), filterOffsetsH.raw(idx)}, schema,
8102           biasElemQKind);
8103       biasScalesH.raw(idx) = biasTQPNew.scale;
8104       biasOffsetsH.raw(idx) = biasTQPNew.offset;
8105     }
8106   } else {
8107     for (dim_t idx = 0; idx < numChannels; idx++) {
8108       float filterScale = filterScalesH.raw(idx);
8109       biasScalesH.raw(idx) = inputScale * filterScale;
8110       biasOffsetsH.raw(idx) = 0;
8111     }
8112   }
8113   biasQ->getPayloadMutable() = quantization::quantizeTensor(
8114       biasF->getPayload(), biasScales->getPayload(), biasOffsets->getPayload(),
8115       biasElemQKind, qDim, qStep);
8116 
8117   // Get optimal output TQP based on inspecting the output range for the
8118   // particular values of the convolution parameters. If the convolution
8119   // sizes are changed then these parameters must be adjusted.
8120   auto outputTQP =
8121       quantization::chooseQuantizationParams({-6.0, 6.0}, schema, elemQKind);
8122   auto *outQTy =
8123       mod.uniqueType(elemQKind, outputDims, outputTQP.scale, outputTQP.offset);
8124 
8125   // Prepare parameters for ChannelwiseQuantizedConvolutionNode.
8126   Constant *filterCWQ = nullptr;
8127   Constant *filterScalesCWQ = nullptr;
8128   Constant *filterOffsetsCWQ = nullptr;
8129   if (filterFloat) {
8130     filterCWQ = filterF;
8131   } else {
8132     filterCWQ = filterQ;
8133     filterScalesCWQ = filterScales;
8134     filterOffsetsCWQ = filterOffsets;
8135   }
8136   Constant *biasCWQ = nullptr;
8137   Constant *biasScalesCWQ = nullptr;
8138   Constant *biasOffsetsCWQ = nullptr;
8139   if (biasFloat) {
8140     biasCWQ = biasF;
8141   } else {
8142     biasCWQ = biasQ;
8143   }
8144   if (biasScalesExplicit) {
8145     biasScalesCWQ = biasScales;
8146     biasOffsetsCWQ = biasOffsets;
8147   }
8148 
8149   // Create ChannelwiseQuantizedConvolution and Dequantize.
8150   ChannelwiseQuantizedConvolutionNode *outQ = F->createChannelwiseQuantizedConv(
8151       "CWQConv", inputQ, filterCWQ, biasCWQ, filterScalesCWQ, filterOffsetsCWQ,
8152       biasScalesCWQ, biasOffsetsCWQ, outQTy, kernels, strides, pads, group,
8153       dilation, /* quantizeFilter */ true, /* quantizeBias */ true, schema,
8154       elemQKind, biasElemQKind);
8155   DequantizeNode *out =
8156       F->createDequantize("dequantize", outQ, ElemKind::FloatTy);
8157   SaveNode *saveOut = F->createSave("saveOut", out);
8158   bindings.allocate(saveOut->getPlaceholder());
8159 
8160   // Create reference floating-point Convolution.
8161   auto *refTy = mod.uniqueType(ElemKind::FloatTy, outputDims);
8162   ConvolutionNode *ref = F->createConv("Conv", inputF, filterF, biasF, refTy,
8163                                        kernels, strides, pads, group, dilation);
8164   SaveNode *saveRef = F->createSave("saveRef", ref);
8165   bindings.allocate(saveRef->getPlaceholder());
8166 
8167   // Compile and run.
8168   EE.compile(CompilationMode::Infer);
8169   EE.run(bindings);
8170 
8171   // Extra validations.
8172   EXPECT_EQ(F->getNodes().size(), 6);
8173   EXPECT_EQ(outQ->getFilter().getElementType(), elemQKind);
8174   EXPECT_EQ(outQ->getBias().getElementType(), biasElemQKind);
8175 
8176   // Check error. If bias is carefully quantized then the bias precision does
8177   // not matter and so the error tolerance is the same.
8178   auto outH = bindings.get(saveOut->getPlaceholder())->getHandle();
8179   auto refH = bindings.get(saveRef->getPlaceholder())->getHandle();
8180   for (dim_t idx = 0; idx < refH.size(); idx++) {
8181     float errVal = std::abs(refH.raw(idx) - outH.raw(idx));
8182     EXPECT_TRUE(errVal < 0.05);
8183   }
8184 }
8185 
8186 #define TEST_CWQCONV(testName, ...)                                            \
8187   TEST_P(OperatorTest, testName) {                                             \
8188     CHECK_IF_ENABLED();                                                        \
8189     testChannelwiseQuantizedConv2D(bindings_, mod_, F_, EE_,                   \
8190                                    quantization::Schema::Asymmetric,           \
8191                                    __VA_ARGS__);                               \
8192   }
8193 
8194 /// These unit tests prove that bias quantization for low precision (Int8)
8195 /// requires special handling: providing a quantized bias with the implicit
8196 /// quantization parameters biasScales[i] = inputScale * filterScales[i] and
8197 /// biasOffsets[i] = 0 does not work numerically because the bias data
8198 /// saturates. Therefore the *_*FF combinations are not instantiated below.
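/// A rough numeric sketch of the saturation, using the ranges from the helper
/// above: inputs and per-channel filters are randomized in [-1, 1], so
/// asymmetric Int8 quantization yields scales of roughly 2 / 255 ~ 0.0078.
/// The implicit bias scale inputScale * filterScale is then ~6e-5, and an
/// Int8 bias can only represent about +/-127 * 6e-5 ~ +/-0.008, far smaller
/// than the actual biases drawn from [-1, 1]. An Int32 bias with the same
/// scale spans roughly +/-1.3e5, which is why the Int32 variants further
/// below tolerate the implicit parameters.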
8199 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt8_FFT, ElemKind::Int8QTy,
8200              ElemKind::Int8QTy, false, false, true)
8201 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt8_FTF, ElemKind::Int8QTy,
8202              ElemKind::Int8QTy, false, true, false)
8203 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt8_FTT, ElemKind::Int8QTy,
8204              ElemKind::Int8QTy, false, true, true)
8205 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt8_TFT, ElemKind::Int8QTy,
8206              ElemKind::Int8QTy, true, false, true)
8207 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt8_TTF, ElemKind::Int8QTy,
8208              ElemKind::Int8QTy, true, true, false)
8209 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt8_TTT, ElemKind::Int8QTy,
8210              ElemKind::Int8QTy, true, true, true)
8211 
8212 /// These unit tests prove that bias quantization for high precision (Int32)
8213 /// works without special handling (implicit quantization parameters suffice).
8214 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt32_FFF, ElemKind::Int8QTy,
8215              ElemKind::Int32QTy, false, false, false)
8216 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt32_FFT, ElemKind::Int8QTy,
8217              ElemKind::Int32QTy, false, false, true)
8218 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt32_FTF, ElemKind::Int8QTy,
8219              ElemKind::Int32QTy, false, true, false)
8220 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt32_FTT, ElemKind::Int8QTy,
8221              ElemKind::Int32QTy, false, true, true)
8222 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt32_TFF, ElemKind::Int8QTy,
8223              ElemKind::Int32QTy, true, false, false)
8224 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt32_TFT, ElemKind::Int8QTy,
8225              ElemKind::Int32QTy, true, false, true)
8226 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt32_TTF, ElemKind::Int8QTy,
8227              ElemKind::Int32QTy, true, true, false)
8228 TEST_CWQCONV(ChannelwiseQuantizedConv2D_Int8_BiasInt32_TTT, ElemKind::Int8QTy,
8229              ElemKind::Int32QTy, true, true, true)
8230 #undef TEST_CWQCONV
8231 
8232 /// Utility function to numerically test ChannelwiseQuantizedConvolution2D
8233 /// against the Interpreter implementation.
8234 static FunctionTensorPair
8235 createAndInitBasicChannelwiseConv2DTest(glow::PlaceholderBindings &bindings,
8236                                         glow::ExecutionEngine &EE) {
8237 
8238   auto &mod = EE.getModule();
8239   Function *F = mod.createFunction("main");
8240 
8241   std::vector<dim_t> inputDims = {5, 10, 10, 4};
8242   std::vector<dim_t> filterDims = {8, 3, 3, 2};
8243   std::vector<dim_t> biasDims = {8};
8244   std::vector<dim_t> outputDims = {5, 6, 6, 8};
8245   std::vector<unsigned_t> kernels = {3, 3};
8246   std::vector<unsigned_t> strides = {1, 1};
8247   std::vector<unsigned_t> pads = {0, 0, 0, 0};
8248   dim_t group = 2;
8249   dim_t dilation = 2;
8250 
8251   // Create input placeholder.
8252   auto *input =
8253       mod.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
8254   bindings.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
8255                                                          mod.getPRNG());
8256 
8257   // Create filter constant.
8258   auto *filter = mod.createConstant(ElemKind::FloatTy, filterDims, "filter");
8259   filter->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
8260                                                            mod.getPRNG());
8261 
8262   // Create bias constant.
8263   auto *bias = mod.createConstant(ElemKind::FloatTy, biasDims, "bias");
8264   bias->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
8265                                                          mod.getPRNG());
8266 
8267   // Create Convolution.
8268   auto *outTy = mod.uniqueType(ElemKind::FloatTy, outputDims);
8269   ConvolutionNode *conv =
8270       F->createConv("Conv", input, filter, bias, outTy, kernels, strides, pads,
8271                     group, dilation);
8272   SaveNode *save = F->createSave("save", conv);
8273   auto *outputTensor = bindings.allocate(save->getPlaceholder());
8274   return std::make_pair(F, outputTensor);
8275 }
8276 
8277 /// Test Int8 ChannelwiseQuantizedConvolution2D with Int8 bias.
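/// Note: compareAgainstInterpreter (from BackendTestUtils) is expected to run
/// the model from createAndInitBasicChannelwiseConv2DTest once in float on the
/// Interpreter and once on the backend under test with the requested
/// channelwise Int8 conversion, comparing the two outputs within the 0.05
/// tolerance passed here.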
8278 TEST_P(OperatorStatelessTest, ChannelwiseQuantizedConv2D_Int8_BiasInt8) {
8279   CHECK_IF_ENABLED();
8280   compareAgainstInterpreter(
8281       getBackendName(), createAndInitBasicChannelwiseConv2DTest,
8282       ElemKind::FloatTy, ElemKind::Int8QTy, 0.05f, parCloneCountOpt,
8283       /* convertToRowwiseQuantization */ false,
8284       quantization::Schema::Asymmetric, ElemKind::Int8QTy,
8285       /* forceFP16AccumSLS */ false,
8286       PrecisionConfiguration::Float16Format::None,
8287       /* convertToChannelwiseQuantization */ true);
8288 }
8289 
8290 /// Test Int8 ChannelwiseQuantizedConvolution2D with Int32 bias.
8291 TEST_P(OperatorStatelessTest, ChannelwiseQuantizedConv2D_Int8_BiasInt32) {
8292   CHECK_IF_ENABLED();
8293   compareAgainstInterpreter(
8294       getBackendName(), createAndInitBasicChannelwiseConv2DTest,
8295       ElemKind::FloatTy, ElemKind::Int8QTy, 0.05f, parCloneCountOpt,
8296       /* convertToRowwiseQuantization */ false,
8297       quantization::Schema::Asymmetric, ElemKind::Int32QTy,
8298       /* forceFP16AccumSLS */ false,
8299       PrecisionConfiguration::Float16Format::None,
8300       /* convertToChannelwiseQuantization */ true);
8301 }
8302 
8303 /// Test the functionality of channelwise quantized group convolution using
8304 /// ChannelwiseQuantizedConvNode.
8305 TEST_P(OperatorTest, ChannelwiseQuantizedConv2D) {
8306   CHECK_IF_ENABLED();
8307 
8308   constexpr size_t groups = 2;
8309   constexpr dim_t output_channel = 4;
8310 
8311   auto *input =
8312       mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 2}, "input", false);
8313   auto IH = bindings_.allocate(input)->getHandle<float>();
8314   for (size_t i = 0; i < 2 * 3 * 2; i++) {
8315     IH.raw(i) = i + 1;
8316   }
8317 
8318   auto *qInTy = mod_.uniqueType(ElemKind::Int8QTy, {1, 2, 3, 2}, 1.0, 0);
8319   auto *qInput = F_->createQuantize("qInput", input, qInTy);
8320 
8321   auto filterT = Tensor(ElemKind::Int8QTy, {4, 2, 1, 1}, 1.0, 0);
8322   for (dim_t i = 0; i < 4; i++) {
8323     for (dim_t j = 0; j < 2; j++) {
8324       for (dim_t k = 0; k < 1; k++) {
8325         for (dim_t l = 0; l < 1; l++) {
8326           filterT.getHandle<int8_t>().at({i, j, k, l}) = j + 1;
8327         }
8328       }
8329     }
8330   }
8331   auto *filter = mod_.createConstant("filter", std::move(filterT));
8332 
8333   auto biasT = Tensor(ElemKind::FloatTy, {4});
8334   biasT.zero();
8335   auto *bias = mod_.createConstant("bias", std::move(biasT));
8336 
8337   auto filterScalesT = Tensor(ElemKind::FloatTy, {output_channel});
8338   for (size_t i = 0; i < filterScalesT.size(); i++) {
8339     filterScalesT.getHandle<float>().raw(i) = 1;
8340   }
8341   auto *filterScales =
8342       mod_.createConstant("filterScales", std::move(filterScalesT));
8343 
8344   auto filterOffsetsT = Tensor(ElemKind::Int32ITy, {output_channel});
8345   filterOffsetsT.zero();
8346   auto *filterOffsets =
8347       mod_.createConstant("filterOffsets", std::move(filterOffsetsT));
8348 
8349   auto *outTy = mod_.uniqueType(ElemKind::Int8QTy, {1, 1, 3, 4}, 1.0, 0);
8350   ChannelwiseQuantizedConvolutionNode *CQC = F_->createChannelwiseQuantizedConv(
8351       "channelwiseQuantizedConv", qInput, filter, bias, filterScales,
8352       filterOffsets, /* biasScales */ nullptr, /* biasOffsets */ nullptr, outTy,
8353       {2, 1}, {1, 1}, {0, 0, 0, 0}, groups);
8354 
8355   DequantizeNode *dq =
8356       F_->createDequantize("dequantize", CQC, ElemKind::FloatTy);
8357   SaveNode *S = F_->createSave("save", dq);
8358   bindings_.allocate(S->getPlaceholder());
8359 
8360   ::glow::convertPlaceholdersToConstants(F_, bindings_,
8361                                          {input, S->getPlaceholder()});
8362 
8363   EE_.compile(CompilationMode::Infer);
8364   EE_.run(bindings_);
8365 
8366   auto result = bindings_.get(S->getPlaceholder())->getHandle();
8367 
8368   std::vector<dim_t> expectedDims = {1, 1, 3, 4};
8369   ASSERT_TRUE(result.dims().vec() == expectedDims);
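  // Spot-check: with groups = 2, output channels {0, 1} convolve input
  // channel 0 and channels {2, 3} convolve input channel 1. The 2x1 kernel
  // has weights {1, 2}, so e.g. result(0, 0, 0, 0) = in(0, 0, 0, 0) * 1 +
  // in(0, 1, 0, 0) * 2 = 1 + 2 * 7 = 15; all scales are 1.0 with offset 0,
  // so the quantized arithmetic is exact.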
8370   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 0}), 15);
8371   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 1}), 15);
8372   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 2}), 18);
8373   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 3}), 18);
8374   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 0}), 21);
8375   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 1}), 21);
8376 
8377   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 2}), 24);
8378   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 3}), 24);
8379   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 0}), 27);
8380   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 1}), 27);
8381   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 2}), 30);
8382   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 3}), 30);
8383 }
8384 
8385 /// Test the functionality of channelwise quantized group 3D convolution
8386 /// using ChannelwiseQuantizedConvNode.
8387 TEST_P(OperatorTest, ChannelwiseQuantizedConv3D) {
8388   CHECK_IF_ENABLED();
8389 
8390   constexpr size_t groups = 2;
8391   constexpr dim_t output_channel = 4;
8392   constexpr dim_t input_channel = 2;
8393 
8394   auto *input = mod_.createPlaceholder(
8395       ElemKind::FloatTy, {1, input_channel, 2, 3, 2}, "input", false);
8396   auto IH = bindings_.allocate(input)->getHandle<float>();
8397   for (size_t i = 0; i < input_channel * 2 * 3 * 2; i++) {
8398     IH.raw(i) = i + 1;
8399   }
8400 
8401   auto *qInTy =
8402       mod_.uniqueType(ElemKind::Int8QTy, {1, input_channel, 2, 3, 2}, 1.0, 0);
8403   auto *qInput = F_->createQuantize("qInput", input, qInTy);
8404 
8405   auto filterT = Tensor(
8406       ElemKind::Int8QTy,
8407       {output_channel / groups, input_channel / groups, 1, 1, 1}, 1.0, 0);
8408   for (dim_t i = 0; i < output_channel / groups; i++) {
8409     for (dim_t j = 0; j < input_channel / groups; j++) {
8410       for (dim_t t = 0; t < 1; t++) {
8411         for (dim_t k = 0; k < 1; k++) {
8412           for (dim_t l = 0; l < 1; l++) {
8413             filterT.getHandle<int8_t>().at({i, j, t, k, l}) = j + 1;
8414           }
8415         }
8416       }
8417     }
8418   }
8419   auto *filter = mod_.createConstant("filter", std::move(filterT));
8420 
8421   auto biasT = Tensor(ElemKind::FloatTy, {output_channel / groups});
8422   biasT.zero();
8423   auto *bias = mod_.createConstant("bias", std::move(biasT));
8424 
8425   auto scalesT = Tensor(ElemKind::FloatTy, {output_channel / groups});
8426   for (size_t i = 0; i < scalesT.size(); i++) {
8427     scalesT.getHandle<float>().raw(i) = 1;
8428   }
8429   auto *scales = mod_.createConstant("scales", std::move(scalesT));
8430 
8431   auto offsetsT = Tensor(ElemKind::Int32ITy, {output_channel / groups});
8432   offsetsT.zero();
8433   auto *offsets = mod_.createConstant("offsets", std::move(offsetsT));
8434 
8435   auto *outTy = mod_.uniqueType(ElemKind::Int8QTy,
8436                                 {1, output_channel / groups, 2, 3, 2}, 1.0, 0);
8437   ChannelwiseQuantizedConvolutionNode *CQC = F_->createChannelwiseQuantizedConv(
8438       "channelwiseQuantizedConv", qInput, filter, bias, scales, offsets,
8439       /* biasScales */ nullptr, /* biasOffsets */ nullptr, outTy, {1, 1, 1},
8440       {1, 1, 1}, {0, 0, 0, 0, 0, 0}, groups);
8441 
8442   DequantizeNode *dq =
8443       F_->createDequantize("dequantize", CQC, ElemKind::FloatTy);
8444   SaveNode *S = F_->createSave("save", dq);
8445   bindings_.allocate(S->getPlaceholder());
8446 
8447   ::glow::convertPlaceholdersToConstants(F_, bindings_,
8448                                          {input, S->getPlaceholder()});
8449 
8450   EE_.compile(CompilationMode::Infer);
8451   EE_.run(bindings_);
8452 
8453   auto result = bindings_.get(S->getPlaceholder())->getHandle();
8454 
8455   std::vector<dim_t> expectedDims = {1, output_channel / groups, 2, 3, 2};
8456   ASSERT_TRUE(result.dims().vec() == expectedDims);
8457 
8458   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 0}), 1);
8459   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 1}), 3);
8460   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 2}), 5);
8461   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 3}), 7);
8462   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 0}), 7);
8463   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 1}), 9);
8464 
8465   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 2}), 11);
8466   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 3}), 13);
8467   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 0}), 13);
8468   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 1}), 15);
8469   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 2}), 17);
8470   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 3}), 19);
8471 }
8472 
8473 TEST_P(OperatorTest, DilatedConvolution) {
8474   CHECK_IF_ENABLED();
8475 
8476   auto *input =
8477       mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 1, 1}, "input", false);
8478   auto IH = bindings_.allocate(input)->getHandle();
8479   for (size_t i = 0; i < 4; i++) {
8480     IH.raw(i) = i + 1;
8481   }
8482 
8483   auto filter =
8484       mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 3, 1}, "filter", false);
8485   auto FH = bindings_.allocate(filter)->getHandle();
8486   for (dim_t i = 0; i < 3; i++)
8487     for (dim_t j = 0; j < 3; j++) {
8488       FH.at({0, i, j, 0}) = 1;
8489     }
8490   FH.at({0, 1, 1, 0}) = 0;
8491 
8492   auto *zeroBias =
8493       mod_.createPlaceholder(ElemKind::FloatTy, {1}, "bias", false);
8494   bindings_.allocate(zeroBias)->zero();
8495 
8496   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 4, 1, 1});
8497 
8498   ConvolutionNode *CN =
8499       F_->createConv("Conv", input, filter, zeroBias, outTy, 3, 1, 2, 1, 2);
8500   SaveNode *S = F_->createSave("save", CN);
8501   bindings_.allocate(S->getPlaceholder());
8502 
8503   ::glow::convertPlaceholdersToConstants(F_, bindings_,
8504                                          {input, S->getPlaceholder()});
8505   EE_.compile(CompilationMode::Infer);
8506   EE_.run(bindings_);
8507 
8508   auto result = bindings_.get(S->getPlaceholder())->getHandle();
8509 
8510   std::vector<dim_t> expectedDims = {1, 4, 1, 1};
8511   EXPECT_TRUE(result.dims().vec() == expectedDims);
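  // Spot-check: with a 3x3 kernel, dilation 2 and padding 2, output row 0
  // samples input rows {-2, 0, 2} (only column 0 is in bounds). The center
  // tap has weight 0 and lands on row 0, the out-of-bounds taps contribute
  // nothing, so result(0, 0, 0, 0) = input row 2 = 3.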
8512   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 0}), 3);
8513   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 0}), 4);
8514   EXPECT_FLOAT_EQ(result.at({0, 2, 0, 0}), 1);
8515   EXPECT_FLOAT_EQ(result.at({0, 3, 0, 0}), 2);
8516 }
8517 
8518 /// Test the functionality of channelwise quantized group convolution using
8519 /// ChannelwiseQuantizedConvNode with non-zero offsets and biases.
8520 void testChannelwiseQuantizedConv2DNonZero(glow::PlaceholderBindings &bindings,
8521                                            glow::Module &mod, glow::Function *F,
8522                                            glow::ExecutionEngine &EE,
8523                                            bool quantizeBias) {
8524   constexpr size_t groups = 2;
8525   constexpr dim_t output_channel = 4;
8526 
8527   auto *input =
8528       mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 2}, "input", false);
8529   auto IH = bindings.allocate(input)->getHandle<float>();
8530   for (size_t i = 0; i < 2 * 3 * 2; i++) {
8531     IH.raw(i) = i + 1;
8532   }
8533 
8534   auto *qInTy = mod.uniqueType(ElemKind::Int8QTy, {1, 2, 3, 2}, 2.5, 3);
8535   auto *qInput = F->createQuantize("qInput", input, qInTy);
8536 
8537   auto filterT = Tensor(ElemKind::Int8QTy, {4, 2, 1, 1}, 1.0, 0);
8538   for (dim_t i = 0; i < 4; i++) {
8539     for (dim_t j = 0; j < 2; j++) {
8540       for (dim_t k = 0; k < 1; k++) {
8541         for (dim_t l = 0; l < 1; l++) {
8542           filterT.getHandle<int8_t>().at({i, j, k, l}) = j + 1;
8543         }
8544       }
8545     }
8546   }
8547   auto *filter = mod.createConstant("filter", std::move(filterT));
8548 
8549   auto biasT = Tensor(ElemKind::FloatTy, {4});
8550   for (dim_t i = 0; i < 4; i++) {
8551     biasT.getHandle<float>().raw(i) = i + 1;
8552   }
8553   auto *bias = mod.createConstant("bias", std::move(biasT));
8554 
8555   auto filterScalesT = Tensor(ElemKind::FloatTy, {output_channel});
8556   for (size_t i = 0; i < filterScalesT.size(); i++) {
8557     filterScalesT.getHandle<float>().raw(i) = 1;
8558   }
8559   auto *filterScales =
8560       mod.createConstant("filterScales", std::move(filterScalesT));
8561 
8562   auto filterOffsetsT = Tensor(ElemKind::Int32ITy, {output_channel});
8563   filterOffsetsT.zero();
8564 
8565   auto *filterOffsets =
8566       mod.createConstant("filterOffsets", std::move(filterOffsetsT));
8567 
8568   auto *outTy = mod.uniqueType(ElemKind::Int8QTy, {1, 1, 3, 4}, 2, 2);
8569   ChannelwiseQuantizedConvolutionNode *CQC = F->createChannelwiseQuantizedConv(
8570       "channelwiseQuantizedConv", qInput, filter, bias, filterScales,
8571       filterOffsets, /* biasScales */ nullptr, /* biasOffsets */ nullptr, outTy,
8572       {2, 1}, {1, 1}, {0, 0, 0, 0}, groups, /* dilation */ 1,
8573       /* quantizeFilter */ false, quantizeBias);
8574 
8575   DequantizeNode *dq =
8576       F->createDequantize("dequantize", CQC, ElemKind::FloatTy);
8577   SaveNode *S = F->createSave("save", dq);
8578   bindings.allocate(S->getPlaceholder());
8579 
8580   ::glow::convertPlaceholdersToConstants(F, bindings,
8581                                          {input, S->getPlaceholder()});
8582 
8583   EE.compile(CompilationMode::Infer);
8584   EE.run(bindings);
8585 
8586   auto result = bindings.get(S->getPlaceholder())->getHandle();
8587 
8588   std::vector<dim_t> expectedDims = {1, 1, 3, 4};
8589   ASSERT_TRUE(result.dims().vec() == expectedDims);
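  // Spot-check (assuming nearest-rounding quantization): the input is
  // quantized with scale 2.5 and offset 3, so the value 1 round-trips to 0.0
  // and the value 7 to 7.5. Output channel 0 then sees
  // 0.0 * 1 + 7.5 * 2 + bias(1) = 16, which the output type (scale 2,
  // offset 2) represents exactly.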
8590   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 0}), 16);
8591   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 1}), 18);
8592   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 2}), 20);
8593   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 3}), 22);
8594   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 0}), 22);
8595   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 1}), 26);
8596 
8597   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 2}), 28);
8598   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 3}), 30);
8599   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 0}), 26);
8600   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 1}), 28);
8601   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 2}), 32);
8602   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 3}), 36);
8603 }
8604 
8605 TEST_P(OperatorTest, ChannelwiseQuantizedConv2D_NonZero_FloatBias) {
8606   CHECK_IF_ENABLED();
8607   testChannelwiseQuantizedConv2DNonZero(bindings_, mod_, F_, EE_,
8608                                         /* quantizeBias */ false);
8609 }
8610 
8611 TEST_P(OperatorTest, ChannelwiseQuantizedConv2D_NonZero_QuantizedBias) {
8612   CHECK_IF_ENABLED();
8613   testChannelwiseQuantizedConv2DNonZero(bindings_, mod_, F_, EE_,
8614                                         /* quantizeBias */ true);
8615 }
8616 
8617 TEST_P(OperatorTest, GroupDilatedConvolution) {
8618   CHECK_IF_ENABLED();
8619 
8620   auto *input =
8621       mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 2}, "input", false);
8622   auto IH = bindings_.allocate(input)->getHandle();
8623   for (dim_t i = 0; i < 4 * 4 * 2; i++) {
8624     IH.raw(i) = i;
8625   }
8626 
8627   auto filter =
8628       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 2, 1}, "filter", false);
8629   auto FH = bindings_.allocate(filter)->getHandle();
8630   for (dim_t i = 0; i < 2; i++)
8631     for (dim_t j = 0; j < 2; j++) {
8632       for (dim_t k = 0; k < 2; k++) {
8633         FH.at({i, j, k, 0}) = 1;
8634       }
8635     }
8636 
8637   auto *zeroBias =
8638       mod_.createPlaceholder(ElemKind::FloatTy, {2}, "bias", false);
8639   bindings_.allocate(zeroBias)->zero();
8640 
8641   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 4, 4, 2});
8642 
8643   ConvolutionNode *CN =
8644       F_->createConv("Conv", input, filter, zeroBias, outTy, 2, 1, 1, 2, 2);
8645   SaveNode *S = F_->createSave("save", CN);
8646   bindings_.allocate(S->getPlaceholder());
8647 
8648   ::glow::convertPlaceholdersToConstants(F_, bindings_,
8649                                          {input, S->getPlaceholder()});
8650   EE_.compile(CompilationMode::Infer);
8651   EE_.run(bindings_);
8652 
8653   auto result = bindings_.get(S->getPlaceholder())->getHandle();
8654 
8655   std::vector<dim_t> expectedDims = {1, 4, 4, 2};
8656   ASSERT_TRUE(result.dims().vec() == expectedDims);
8657 
8658   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 0}), 10);
8659   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 1}), 11);
8660   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 0}), 20);
8661   EXPECT_FLOAT_EQ(result.at({0, 0, 1, 1}), 22);
8662   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 0}), 24);
8663   EXPECT_FLOAT_EQ(result.at({0, 0, 2, 1}), 26);
8664   EXPECT_FLOAT_EQ(result.at({0, 0, 3, 0}), 12);
8665   EXPECT_FLOAT_EQ(result.at({0, 0, 3, 1}), 13);
8666 
8667   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 0}), 20);
8668   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 1}), 22);
8669   EXPECT_FLOAT_EQ(result.at({0, 1, 1, 0}), 40);
8670   EXPECT_FLOAT_EQ(result.at({0, 1, 1, 1}), 44);
8671   EXPECT_FLOAT_EQ(result.at({0, 1, 2, 0}), 48);
8672   EXPECT_FLOAT_EQ(result.at({0, 1, 2, 1}), 52);
8673   EXPECT_FLOAT_EQ(result.at({0, 1, 3, 0}), 24);
8674   EXPECT_FLOAT_EQ(result.at({0, 1, 3, 1}), 26);
8675 
8676   EXPECT_FLOAT_EQ(result.at({0, 2, 0, 0}), 36);
8677   EXPECT_FLOAT_EQ(result.at({0, 2, 0, 1}), 38);
8678   EXPECT_FLOAT_EQ(result.at({0, 2, 1, 0}), 72);
8679   EXPECT_FLOAT_EQ(result.at({0, 2, 1, 1}), 76);
8680   EXPECT_FLOAT_EQ(result.at({0, 2, 2, 0}), 80);
8681   EXPECT_FLOAT_EQ(result.at({0, 2, 2, 1}), 84);
8682   EXPECT_FLOAT_EQ(result.at({0, 2, 3, 0}), 40);
8683   EXPECT_FLOAT_EQ(result.at({0, 2, 3, 1}), 42);
8684 
8685   EXPECT_FLOAT_EQ(result.at({0, 3, 0, 0}), 18);
8686   EXPECT_FLOAT_EQ(result.at({0, 3, 0, 1}), 19);
8687   EXPECT_FLOAT_EQ(result.at({0, 3, 1, 0}), 36);
8688   EXPECT_FLOAT_EQ(result.at({0, 3, 1, 1}), 38);
8689   EXPECT_FLOAT_EQ(result.at({0, 3, 2, 0}), 40);
8690   EXPECT_FLOAT_EQ(result.at({0, 3, 2, 1}), 42);
8691   EXPECT_FLOAT_EQ(result.at({0, 3, 3, 0}), 20);
8692   EXPECT_FLOAT_EQ(result.at({0, 3, 3, 1}), 21);
8693 }
8694 
8695 /// Test Conv3D with group size of 2 to make sure that group 3d convolution
8696 /// works as expected.
8697 TEST_P(OperatorTest, GroupConv3D) {
8698   CHECK_IF_ENABLED();
8699 
8700   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 1, 2, 8},
8701                                        "input", false);
8702   auto IH = bindings_.allocate(input)->getHandle();
8703   for (size_t i = 0; i < input->getType()->size(); i++) {
8704     IH.raw(i) = i + 1;
8705   }
8706 
8707   auto *filter = mod_.createPlaceholder(ElemKind::FloatTy, {6, 1, 1, 1, 4},
8708                                         "filter", false);
8709   auto FH = bindings_.allocate(filter)->getHandle();
8710   for (dim_t i = 0; i < 6; i++)
8711     for (dim_t j = 0; j < 4; j++) {
8712       FH.at({i, 0, 0, 0, j}) = pow(10.0, i);
8713     }
8714 
8715   auto *zeroBias =
8716       mod_.createPlaceholder(ElemKind::FloatTy, {6}, "bias", false);
8717   bindings_.allocate(zeroBias)->zero();
8718 
8719   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 2, 1, 2, 6});
8720 
8721   Convolution3DNode *CN =
8722       F_->createConv3D("Conv3D", input, filter, zeroBias, outTy, 1, 1, 0, 2);
8723   SaveNode *S = F_->createSave("save", CN);
8724   bindings_.allocate(S->getPlaceholder());
8725 
8726   ::glow::convertPlaceholdersToConstants(F_, bindings_,
8727                                          {input, S->getPlaceholder()});
8728   EE_.compile(CompilationMode::Infer);
8729   EE_.run(bindings_);
8730 
8731   auto result = bindings_.get(S->getPlaceholder())->getHandle();
8732 
8733   std::vector<dim_t> expectedDims = {1, 2, 1, 2, 6};
8734   ASSERT_TRUE(result.dims().vec() == expectedDims);
8735 
8736   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 0, 0}), 1 + 2 + 3 + 4);
8737   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 0, 1}), (1 + 2 + 3 + 4) * 10);
8738   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 0, 2}), (1 + 2 + 3 + 4) * 100);
8739   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 0, 3}), (5 + 6 + 7 + 8) * 1000);
8740   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 0, 4}), (5 + 6 + 7 + 8) * 10000);
8741   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 0, 5}), (5 + 6 + 7 + 8) * 100000);
8742 
8743   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 1, 0}), 9 + 10 + 11 + 12);
8744   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 1, 1}), (9 + 10 + 11 + 12) * 10);
8745   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 1, 2}), (9 + 10 + 11 + 12) * 100);
8746   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 1, 3}), (13 + 14 + 15 + 16) * 1000);
8747   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 1, 4}), (13 + 14 + 15 + 16) * 10000);
8748   EXPECT_FLOAT_EQ(result.at({0, 0, 0, 1, 5}), (13 + 14 + 15 + 16) * 100000);
8749 
8750   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 0, 0}), 17 + 18 + 19 + 20);
8751   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 0, 1}), (17 + 18 + 19 + 20) * 10);
8752   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 0, 2}), (17 + 18 + 19 + 20) * 100);
8753   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 0, 3}), (21 + 22 + 23 + 24) * 1000);
8754   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 0, 4}), (21 + 22 + 23 + 24) * 10000);
8755   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 0, 5}), (21 + 22 + 23 + 24) * 100000);
8756 
8757   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 1, 0}), 25 + 26 + 27 + 28);
8758   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 1, 1}), (25 + 26 + 27 + 28) * 10);
8759   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 1, 2}), (25 + 26 + 27 + 28) * 100);
8760   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 1, 3}), (29 + 30 + 31 + 32) * 1000);
8761   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 1, 4}), (29 + 30 + 31 + 32) * 10000);
8762   EXPECT_FLOAT_EQ(result.at({0, 1, 0, 1, 5}), (29 + 30 + 31 + 32) * 100000);
8763 }
8764 
8765 /// Check non-square padding for convolution. The first conv has non-square
8766 /// padding, while the second one has zero padding. The second conv's input is
8767 /// the same as the first one's after-padding input. All other parameters of
8768 /// the two convs are the same.
8769 TEST_P(OperatorTest, NonSquarePaddingConvolution) {
8770   CHECK_IF_ENABLED();
8771 
8772   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 1}, "input",
8773                                        false, "NHWC");
8774   auto IH = bindings_.allocate(input)->getHandle();
8775   for (dim_t i = 0; i < 4 * 4; i++) {
8776     IH.raw(i) = i + 1;
8777   }
8778 
8779   auto filter = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 2, 1},
8780                                        "filter", false, "NHWC");
8781   auto FH = bindings_.allocate(filter)->getHandle();
8782   for (dim_t i = 0; i < 2 * 2 * 2; i++) {
8783     FH.raw(i) = pow(2.0, i);
8784   }
8785   auto *zeroBias =
8786       mod_.createPlaceholder(ElemKind::FloatTy, {2}, "bias", false);
8787   bindings_.allocate(zeroBias)->zero();
8788 
8789   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 4, 8, 2});
8790 
8791   ConvolutionNode *CN = F_->createConv("Conv", input, filter, zeroBias, outTy,
8792                                        {2, 2}, {1, 1}, {0, 2, 1, 3}, 1);
8793   SaveNode *S = F_->createSave("save", CN);
8794   bindings_.allocate(S->getPlaceholder());
8795 
8796   ::glow::convertPlaceholdersToConstants(F_, bindings_,
8797                                          {input, S->getPlaceholder()});
8798 
8799   Tensor &result = *bindings_.get(S->getPlaceholder());
8800 
8801   // Create the reference conv operator whose input is the same as the
8802   // after-padding-input above.
8803   auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 9, 1},
8804                                         "input1", false, "NHWC");
8805   bindings_.allocate(input1)->zero();
8806   auto IH1 = bindings_.get(input1)->getHandle();
8807   for (dim_t i = 0; i < 4; i++)
8808     for (dim_t j = 2; j < 6; j++) {
8809       IH1.at({0, i, j, 0}) = i * 4 + j - 2 + 1;
8810     }
8811 
8812   Function *refF = mod_.createFunction("mainRef");
8813   CN = refF->createConv("Conv1", input1, filter, zeroBias, outTy, {2, 2},
8814                         {1, 1}, {0, 0, 0, 0}, 1);
8815   S = refF->createSave("save1", CN);
8816   bindings_.allocate(S->getPlaceholder());
8817 
8818   ::glow::convertPlaceholdersToConstants(refF, bindings_,
8819                                          {input, input1, S->getPlaceholder()});
8820   EE_.compile(CompilationMode::Infer);
8821   EE_.run(bindings_, "main");
8822   EE_.run(bindings_, "mainRef");
8823   Tensor &result1 = *bindings_.get(S->getPlaceholder());
8824 
8825   EXPECT_TRUE(result.isEqual(result1));
8826 }
8827 
8828 /// Check non-cubic padding for conv3D. The first conv3D has non-cubic
8829 /// padding, while the second one has zero padding. The second conv3D's input
8830 /// is the same as the first one's after-padding input. All other parameters
8831 /// of the two conv3Ds are the same.
8832 TEST_P(OperatorTest, NonCubicPaddingConv3D) {
8833   CHECK_IF_ENABLED();
8834 
8835   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 4, 1},
8836                                        "input", false);
8837   auto IH = bindings_.allocate(input)->getHandle();
8838   int nextVal = 1;
8839   for (dim_t i = 0; i < 4; i++) {
8840     for (dim_t j = 0; j < 4; j++) {
8841       for (dim_t k = 0; k < 4; k++) {
8842         IH.at({0, i, j, k, 0}) = static_cast<float>(nextVal++);
8843       } // W
8844     }   // H
8845   }     // T
8846 
8847   auto *filter = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 2, 2, 1},
8848                                         "filter", false);
8849   auto FH = bindings_.allocate(filter)->getHandle();
8850   for (size_t i = 0; i < filter->getType()->size(); i++) {
8851     FH.raw(i) = pow(2.0, i);
8852   }
8853   auto *zeroBias =
8854       mod_.createPlaceholder(ElemKind::FloatTy, {2}, "bias", false);
8855   bindings_.allocate(zeroBias)->zero();
8856 
8857   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 12, 4, 8, 2});
8858 
8859   Convolution3DNode *CN =
8860       F_->createConv3D("Conv3D", input, filter, zeroBias, outTy, {2, 2, 2},
8861                        {1, 1, 1}, // {0, 2, 5, 1, 3, 4},
8862                        {5, 4, 0, 1, 2, 3}, 1);
8863   SaveNode *S = F_->createSave("save", CN);
8864   bindings_.allocate(S->getPlaceholder());
8865 
8866   ::glow::convertPlaceholdersToConstants(F_, bindings_,
8867                                          {input, S->getPlaceholder()});
8868 
8869   Tensor &result = *bindings_.get(S->getPlaceholder());
8870 
8871   // Create the reference conv3D operator whose input is the same as the
8872   // after-padding-input above.
8873   auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 13, 5, 9, 1},
8874                                         "input1", false);
8875   bindings_.allocate(input1)->zero();
8876   auto IH1 = bindings_.get(input1)->getHandle();
8877   nextVal = 1;
8878   for (dim_t i = 5; i < 9; i++) {
8879     for (dim_t j = 0; j < 4; j++) {
8880       for (dim_t k = 2; k < 6; k++) {
8881         IH1.at({0, i, j, k, 0}) = static_cast<float>(nextVal++);
8882       } // W
8883     }   // H
8884   }     // T
8885 
8886   Function *refF = mod_.createFunction("mainRef");
8887   CN = refF->createConv3D("Conv3D_1", input1, filter, zeroBias, outTy,
8888                           {2, 2, 2}, {1, 1, 1}, {0, 0, 0, 0, 0, 0}, 1);
8889   S = refF->createSave("save1", CN);
8890   bindings_.allocate(S->getPlaceholder());
8891 
8892   ::glow::convertPlaceholdersToConstants(refF, bindings_,
8893                                          {input, input1, S->getPlaceholder()});
8894   EE_.compile(CompilationMode::Infer);
8895   EE_.run(bindings_, "main");
8896   EE_.run(bindings_, "mainRef");
8897   Tensor &result1 = *bindings_.get(S->getPlaceholder());
8898 
8899   EXPECT_TRUE(result.isEqual(result1));
8900 }
8901 
8902 /// Check non-square padding for AveragePool. The first pool op has non-square
8903 /// padding, while the second one has zero padding. The second pool op's input
8904 /// is the same as the first one's after-padding input. All other parameters
8905 /// of the two pool ops are the same.
8906 TEST_P(OperatorTest, NonSquarePaddingAveragePool) {
8907   CHECK_IF_ENABLED();
8908 
8909   auto *input =
8910       mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 1}, "input", false);
8911   auto IH = bindings_.allocate(input)->getHandle();
8912   for (size_t i = 0; i < 4 * 4; i++) {
8913     IH.raw(i) = i + 1;
8914   }
8915   auto *Pool = F_->createAvgPool("pool", input, {2, 2}, {1, 1}, {0, 2, 1, 3});
8916   auto *S = F_->createSave("save", Pool);
8917   bindings_.allocate(S->getPlaceholder());
8918 
8919   Tensor &result = *bindings_.get(S->getPlaceholder());
8920 
8921   auto *input1 =
8922       mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 9, 1}, "input1", false);
8923   bindings_.allocate(input1)->zero();
8924   auto IH1 = bindings_.get(input1)->getHandle();
8925   for (dim_t i = 0; i < 4; i++)
8926     for (dim_t j = 2; j < 6; j++) {
8927       IH1.at({0, i, j, 0}) = i * 4 + j - 2 + 1;
8928     }
8929 
8930   Function *refF = mod_.createFunction("mainRef");
8931   Pool = refF->createAvgPool("pool1", input1, 2, 1, 0);
8932   S = refF->createSave("save1", Pool);
8933   bindings_.allocate(S->getPlaceholder());
8934   EE_.compile(CompilationMode::Infer);
8935   EE_.run(bindings_, "main");
8936   EE_.run(bindings_, "mainRef");
8937   Tensor &result1 = *bindings_.get(S->getPlaceholder());
8938 
8939   EXPECT_TRUE(result.isEqual(result1));
8940 }
8941 
8942 /// Check non-square padding for MaxPool. The first pool op has non-square
8943 /// padding, while the second one has zero padding. The second pool-op's input
8944 /// is the same as the first one's after-padding input. All other parameters
8945 /// of the two pool ops are the same.
8946 TEST_P(OperatorTest, NonSquarePaddingMaxPool) {
8947   CHECK_IF_ENABLED();
8948 
8949   auto *input =
8950       mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 1}, "input", false);
8951   auto IH = bindings_.allocate(input)->getHandle();
8952   for (size_t i = 0; i < 4 * 4; i++) {
8953     IH.raw(i) = i + 1;
8954   }
8955   auto *Pool = F_->createMaxPool("pool", input, {2, 2}, {1, 1}, {0, 2, 1, 3});
8956   auto *S = F_->createSave("save", Pool->getResult());
8957   bindings_.allocate(S->getPlaceholder());
8958 
8959   Tensor &result = *bindings_.get(S->getPlaceholder());
8960 
8961   auto *input1 =
8962       mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 9, 1}, "input1", false);
8963   bindings_.allocate(input1)->zero();
8964   auto IH1 = bindings_.get(input1)->getHandle();
8965   for (dim_t i = 0; i < 4; i++)
8966     for (dim_t j = 2; j < 6; j++) {
8967       IH1.at({0, i, j, 0}) = i * 4 + j - 2 + 1;
8968     }
8969 
8970   Function *refF = mod_.createFunction("mainRef");
8971   Pool = refF->createMaxPool("pool1", input1, 2, 1, 0);
8972   S = refF->createSave("save1", Pool->getResult());
8973   bindings_.allocate(S->getPlaceholder());
8974 
8975   EE_.compile(CompilationMode::Infer);
8976   EE_.run(bindings_, "main");
8977   EE_.run(bindings_, "mainRef");
8978 
8979   Tensor &result1 = *bindings_.get(S->getPlaceholder());
8980 
8981   EXPECT_TRUE(result.isEqual(result1));
8982 }
8983 
8984 TEST_P(OperatorTest, FP16AvgPool) {
8985   CHECK_IF_ENABLED();
8986 
8987   auto *input =
8988       mod_.createPlaceholder(ElemKind::Float16Ty, {1, 3, 3, 1}, "input", false);
8989   bindings_.allocate(input)->getHandle<float16_t>() = {0., 1., 2., 3., 4.,
8990                                                        5., 6., 7., 8.};
8991   auto *Pool = F_->createAvgPool("pool", input, {2, 2}, {1, 1}, {0, 0, 0, 0});
8992   auto *S = F_->createSave("save", Pool);
8993   bindings_.allocate(S->getPlaceholder());
8994 
8995   EE_.compile(CompilationMode::Infer);
8996   EE_.run(bindings_);
8997 
8998   auto *result = bindings_.get(S->getPlaceholder());
8999   Tensor out(ElemKind::Float16Ty, {1, 2, 2, 1});
9000   out.getHandle<float16_t>() = {2., 3., 5., 6.};
9001   EXPECT_TRUE(out.isEqual(*result));
9002 }
9003 
9004 TEST_P(OperatorTest, BFloat16AvgPool) {
9005   CHECK_IF_ENABLED();
9006 
9007   auto *input = mod_.createPlaceholder(ElemKind::BFloat16Ty, {1, 3, 3, 1},
9008                                        "input", false);
9009   bindings_.allocate(input)->getHandle<bfloat16_t>() = {0., 1., 2., 3., 4.,
9010                                                         5., 6., 7., 8.};
9011   auto *Pool = F_->createAvgPool("pool", input, {2, 2}, {1, 1}, {0, 0, 0, 0});
9012   auto *S = F_->createSave("save", Pool);
9013   bindings_.allocate(S->getPlaceholder());
9014 
9015   EE_.compile(CompilationMode::Infer);
9016   EE_.run(bindings_);
9017 
9018   auto *result = bindings_.get(S->getPlaceholder());
9019   Tensor out(ElemKind::BFloat16Ty, {1, 2, 2, 1});
9020   out.getHandle<bfloat16_t>() = {2., 3., 5., 6.};
9021   EXPECT_TRUE(out.isEqual(*result));
9022 }
9023 
9024 /// Verify that the AvgPool operator works correctly.
9025 TEST_P(OperatorTest, AvgPool) {
9026   CHECK_IF_ENABLED();
9027 
9028   auto *input =
9029       mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 3, 1}, "input", false);
9030   bindings_.allocate(input)->getHandle() = {0., 1., 2., 3., 4., 5., 6., 7., 8.};
9031   auto *Pool = F_->createAvgPool("pool", input, {2, 2}, {1, 1}, {0, 0, 0, 0});
9032   auto *S = F_->createSave("save", Pool);
9033   bindings_.allocate(S->getPlaceholder());
9034 
9035   EE_.compile(CompilationMode::Infer);
9036   EE_.run(bindings_);
9037 
9038   auto *result = bindings_.get(S->getPlaceholder());
9039   Tensor out(ElemKind::FloatTy, {1, 2, 2, 1});
9040   out.getHandle() = {2., 3., 5., 6.};
9041   EXPECT_TRUE(out.isEqual(*result));
9042 }
9043 
9044 TEST_P(OperatorTest, Int8AvgPool) {
9045   CHECK_IF_ENABLED();
9046 
9047   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {1, 3, 3, 1}, 1, 0,
9048                                        "input", false);
9049   bindings_.allocate(input)->getHandle<int8_t>() = {0, 1, 2, 3, 4, 5, 6, 7, 8};
9050   auto *Pool = F_->createAvgPool("pool", input, {2, 2}, {1, 1}, {0, 0, 0, 0});
9051   auto *S = F_->createSave("save", Pool);
9052   bindings_.allocate(S->getPlaceholder());
9053 
9054   EE_.compile(CompilationMode::Infer);
9055   EE_.run(bindings_);
9056 
9057   auto result = bindings_.get(S->getPlaceholder())->getHandle<int8_t>();
9058   Tensor out(ElemKind::Int8QTy, {1, 2, 2, 1}, 1, 0);
9059   out.getHandle<int8_t>() = {2, 3, 5, 6};
9060   for (size_t i = 0; i < 2 * 2; i++) {
9061     EXPECT_EQ(result.raw(i), out.getHandle<int8_t>().raw(i));
9062   }
9063 }
9064 
9065 TEST_P(OperatorTest, FP16AvgPool3D) {
9066   CHECK_IF_ENABLED();
9067 
9068   auto *input =
9069       mod_.createPlaceholder(ElemKind::Float16Ty, {1, 1, 3, 3, 3}, // NCTHW
9070                              "input", false);
9071   bindings_.allocate(input)->getHandle<float16_t>() = {
9072       0., 1., 2., 3., 4., 5., 6., 7., 8., 0., 1., 2., 3., 4.,
9073       5., 6., 7., 8., 0., 1., 2., 3., 4., 5., 6., 7., 8.};
9074   auto *inputNTHWC =
9075       F_->createTranspose("avgpool3d_input_NCTHW2NTHWC", input, NCTHW2NTHWC);
9076   auto *Pool = F_->createAvgPool("pool", inputNTHWC, {2, 2, 2}, // kernel
9077                                  {1, 1, 1},                     // stride
9078                                  {0, 0, 0, 0, 0, 0},            // padding
9079                                  NTHWC);
9080   auto *outputNCTHW =
9081       F_->createTranspose("avgpool3d_output_NTHWC2NCTHW", Pool, NTHWC2NCTHW);
9082   auto *S = F_->createSave("save", outputNCTHW);
9083   bindings_.allocate(S->getPlaceholder());
9084 
9085   EE_.compile(CompilationMode::Infer);
9086   EE_.run(bindings_);
9087 
9088   auto *result = bindings_.get(S->getPlaceholder());
9089   Tensor out(ElemKind::Float16Ty, {1, 1, 2, 2, 2});
9090   out.getHandle<float16_t>() = {2., 3., 5., 6., 2., 3., 5., 6.};
9091   EXPECT_TRUE(out.isEqual(*result));
9092 }
9093 
9094 TEST_P(OperatorTest, BFloat16AvgPool3D) {
9095   CHECK_IF_ENABLED();
9096 
9097   auto *input =
9098       mod_.createPlaceholder(ElemKind::BFloat16Ty, {1, 1, 3, 3, 3}, // NCTHW
9099                              "input", false);
9100   bindings_.allocate(input)->getHandle<bfloat16_t>() = {
9101       0., 1., 2., 3., 4., 5., 6., 7., 8., 0., 1., 2., 3., 4.,
9102       5., 6., 7., 8., 0., 1., 2., 3., 4., 5., 6., 7., 8.};
9103   auto *inputNTHWC =
9104       F_->createTranspose("avgpool3d_input_NCTHW2NTHWC", input, NCTHW2NTHWC);
9105   auto *Pool = F_->createAvgPool("pool", inputNTHWC, {2, 2, 2}, // kernel
9106                                  {1, 1, 1},                     // stride
9107                                  {0, 0, 0, 0, 0, 0},            // padding
9108                                  NTHWC);
9109   auto *outputNCTHW =
9110       F_->createTranspose("avgpool3d_output_NTHWC2NCTHW", Pool, NTHWC2NCTHW);
9111   auto *S = F_->createSave("save", outputNCTHW);
9112   bindings_.allocate(S->getPlaceholder());
9113 
9114   EE_.compile(CompilationMode::Infer);
9115   EE_.run(bindings_);
9116 
9117   auto *result = bindings_.get(S->getPlaceholder());
9118   Tensor out(ElemKind::BFloat16Ty, {1, 1, 2, 2, 2});
9119   out.getHandle<bfloat16_t>() = {2., 3., 5., 6., 2., 3., 5., 6.};
9120   EXPECT_TRUE(out.isEqual(*result));
9121 }
9122 
9123 TEST_P(OperatorTest, Int8AvgPool3D) {
9124   CHECK_IF_ENABLED();
9125 
9126   auto *input =
9127       mod_.createPlaceholder(ElemKind::Int8QTy, {1, 1, 3, 3, 3}, // NCTHW
9128                              1, 0, // scale, offset
9129                              "input", false);
9130   bindings_.allocate(input)->getHandle<int8_t>() = {0, 1, 2, 3, 4, 5, 6, 7, 8,
9131                                                     0, 1, 2, 3, 4, 5, 6, 7, 8,
9132                                                     0, 1, 2, 3, 4, 5, 6, 7, 8};
9133   auto *inputNTHWC =
9134       F_->createTranspose("avgpool3d_input_NCTHW2NTHWC", input, NCTHW2NTHWC);
9135   auto *Pool = F_->createAvgPool("avgpool3d", inputNTHWC, {2, 2, 2}, // kernel
9136                                  {1, 1, 1},                          // stride
9137                                  {0, 0, 0, 0, 0, 0},                 // padding
9138                                  NTHWC);
9139   auto *outputNCTHW =
9140       F_->createTranspose("avgpool3d_output_NTHWC2NCTHW", Pool, NTHWC2NCTHW);
9141   auto *S = F_->createSave("save", outputNCTHW);
9142   bindings_.allocate(S->getPlaceholder());
9143 
9144   EE_.compile(CompilationMode::Infer);
9145   EE_.run(bindings_);
9146 
9147   auto result = bindings_.get(S->getPlaceholder())->getHandle<int8_t>();
9148   Tensor out(ElemKind::Int8QTy, {1, 1, 2, 2, 2}, 1, 0);
9149   out.getHandle<int8_t>() = {
9150       2, 3, 5, 6, 2, 3, 5, 6,
9151   };
9152   for (size_t i = 0; i < 2 * 2 * 2; i++) {
9153     EXPECT_EQ(result.raw(i), out.getHandle<int8_t>().raw(i));
9154   }
9155 }
9156 
9157 /// Verify that the AdaptiveAvgPool operator works correctly.
9158 TEST_P(OperatorTest, AdaptiveAvgPool) {
9159   CHECK_IF_ENABLED();
9160   auto *input =
9161       mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 1}, "input", false);
9162   bindings_.allocate(input)->getHandle() = {
9163       0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.};
9164 
9165   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 3, 3, 1});
9166   auto *pool = F_->createAdaptiveAvgPool("pool", input, outTy);
9167   auto *S = F_->createSave("save", pool);
9168   bindings_.allocate(S->getPlaceholder());
9169 
9170   EE_.compile(CompilationMode::Infer);
9171   EE_.run(bindings_);
9172 
9173   auto *result = bindings_.get(S->getPlaceholder());
9174   Tensor out(ElemKind::FloatTy, {1, 3, 3, 1});
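  // An adaptive 4x4 -> 3x3 average pool uses overlapping 2x2 windows; e.g.
  // out(0, 0) averages input rows 0-1, cols 0-1: (0 + 1 + 4 + 5) / 4 = 2.5.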
9175   out.getHandle() = {2.5, 3.5, 4.5, 6.5, 7.5, 8.5, 10.5, 11.5, 12.5};
9176   EXPECT_TRUE(out.isEqual(*result));
9177 }
9178 
9179 /// Verify that the AdaptiveAvgPool operator works correctly with fp16.
9180 TEST_P(OperatorTest, FP16AdaptiveAvgPool) {
9181   CHECK_IF_ENABLED();
9182   auto *input =
9183       mod_.createPlaceholder(ElemKind::Float16Ty, {1, 4, 4, 1}, "input", false);
9184   bindings_.allocate(input)->getHandle<float16_t>() = {
9185       0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.};
9186   auto outTy = mod_.uniqueType(ElemKind::Float16Ty, {1, 3, 3, 1});
9187   auto *pool = F_->createAdaptiveAvgPool("pool", input, outTy);
9188   auto *S = F_->createSave("save", pool);
9189   bindings_.allocate(S->getPlaceholder());
9190 
9191   EE_.compile(CompilationMode::Infer);
9192   EE_.run(bindings_);
9193 
9194   auto *result = bindings_.get(S->getPlaceholder());
9195   Tensor out(ElemKind::Float16Ty, {1, 3, 3, 1});
9196   out.getHandle<float16_t>() = {2.5, 3.5, 4.5, 6.5, 7.5, 8.5, 10.5, 11.5, 12.5};
9197   EXPECT_TRUE(out.isEqual(*result));
9198 }
9199 
9200 /// Verify that the AdaptiveAvgPool operator works correctly with bfloat16.
9201 TEST_P(OperatorTest, BFloat16AdaptiveAvgPool) {
9202   CHECK_IF_ENABLED();
9203   auto *input = mod_.createPlaceholder(ElemKind::BFloat16Ty, {1, 4, 4, 1},
9204                                        "input", false);
9205   bindings_.allocate(input)->getHandle<bfloat16_t>() = {
9206       0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.};
9207   auto outTy = mod_.uniqueType(ElemKind::BFloat16Ty, {1, 3, 3, 1});
9208   auto *pool = F_->createAdaptiveAvgPool("pool", input, outTy);
9209   auto *S = F_->createSave("save", pool);
9210   bindings_.allocate(S->getPlaceholder());
9211 
9212   EE_.compile(CompilationMode::Infer);
9213   EE_.run(bindings_);
9214 
9215   auto *result = bindings_.get(S->getPlaceholder());
9216   Tensor out(ElemKind::BFloat16Ty, {1, 3, 3, 1});
9217   out.getHandle<bfloat16_t>() = {2.5, 3.5,  4.5,  6.5, 7.5,
9218                                  8.5, 10.5, 11.5, 12.5};
9219   EXPECT_TRUE(out.isEqual(*result));
9220 }
9221 
9222 /// Verify that the AdaptiveAvgPool operator works correctly with int8.
9223 TEST_P(OperatorTest, Int8AdaptiveAvgPool) {
9224   CHECK_IF_ENABLED();
9225   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {1, 4, 4, 1}, 1, 0,
9226                                        "input", false);
9227   bindings_.allocate(input)->getHandle<int8_t>() = {
9228       0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
9229   auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {1, 3, 3, 1}, 1, 0);
9230   auto *pool = F_->createAdaptiveAvgPool("pool", input, outTy);
9231   auto *S = F_->createSave("save", pool);
9232   bindings_.allocate(S->getPlaceholder());
9233 
9234   EE_.compile(CompilationMode::Infer);
9235   EE_.run(bindings_);
9236 
9237   auto *result = bindings_.get(S->getPlaceholder());
9238   Tensor out(ElemKind::Int8QTy, {1, 3, 3, 1}, 1, 0);
9239   out.getHandle<int8_t>() = {3, 4, 5, 7, 8, 9, 11, 12, 13};
9240   EXPECT_TRUE(out.isEqual(*result));
9241 }
9242 
9243 /// Verify that the AdaptiveAvgPool operator works correctly with non-square
9244 /// inputs and outputs.
9245 TEST_P(OperatorTest, AdaptiveAvgPoolNonSquare) {
9246   CHECK_IF_ENABLED();
9247   auto *input =
9248       mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 3, 1}, "input", false);
9249   bindings_.allocate(input)->getHandle() = {0., 1., 2.,  3.,  4.,  5.,  6., 7.,
9250                                             8., 9., 10., 11., 12., 13., 14.};
9251 
9252   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 3, 2, 1});
9253   auto *pool = F_->createAdaptiveAvgPool("pool", input, outTy);
9254   auto *S = F_->createSave("save", pool);
9255   bindings_.allocate(S->getPlaceholder());
9256 
9257   EE_.compile(CompilationMode::Infer);
9258   EE_.run(bindings_);
9259 
9260   auto *result = bindings_.get(S->getPlaceholder());
9261   Tensor out(ElemKind::FloatTy, {1, 3, 2, 1});
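  // With a 5x3 -> 3x2 adaptive pool the window sizes vary; e.g. out(1, 0)
  // averages rows 1-3, cols 0-1: (3 + 4 + 6 + 7 + 9 + 10) / 6 = 6.5.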
9262   out.getHandle() = {2, 3, 6.5, 7.5, 11, 12};
9263   EXPECT_TRUE(out.isEqual(*result));
9264 }
9265 
9266 TEST_P(OperatorTest, MaxPool) {
9267   CHECK_IF_ENABLED();
9268 
9269   auto *input =
9270       mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 3, 1}, "input", false);
9271   bindings_.allocate(input)->getHandle() = {0., 1., 2., 3., 4., 5., 6., 7., 8.};
9272   auto *pool = F_->createMaxPool("pool", input, {2, 2}, {1, 1}, {0, 0, 0, 0});
9273   auto *S = F_->createSave("save", pool->getResult());
9274   bindings_.allocate(S->getPlaceholder());
9275 
9276   EE_.compile(CompilationMode::Infer);
9277   EE_.run(bindings_);
9278 
9279   auto result = bindings_.get(S->getPlaceholder());
9280   Tensor out(ElemKind::FloatTy, {1, 2, 2, 1});
9281   out.getHandle() = {4., 5., 7., 8.};
9282   EXPECT_TRUE(out.isEqual(*result));
9283 }
9284 
9285 TEST_P(OperatorTest, FP16MaxPool) {
9286   CHECK_IF_ENABLED();
9287 
9288   auto *input =
9289       mod_.createPlaceholder(ElemKind::Float16Ty, {1, 3, 3, 1}, "input", false);
9290   bindings_.allocate(input)->getHandle<float16_t>() = {0., 1., 2., 3., 4.,
9291                                                        5., 6., 7., 8.};
9292   auto *pool = F_->createMaxPool("pool", input, {2, 2}, {1, 1}, {0, 0, 0, 0});
9293   auto *S = F_->createSave("save", pool->getResult());
9294   bindings_.allocate(S->getPlaceholder());
9295 
9296   EE_.compile(CompilationMode::Infer);
9297   EE_.run(bindings_);
9298 
9299   auto result = bindings_.get(S->getPlaceholder());
9300   Tensor out(ElemKind::Float16Ty, {1, 2, 2, 1});
9301   out.getHandle<float16_t>() = {4., 5., 7., 8.};
9302   EXPECT_TRUE(out.isEqual(*result));
9303 }
9304 
9305 TEST_P(OperatorTest, BFloat16MaxPool) {
9306   CHECK_IF_ENABLED();
9307 
9308   auto *input = mod_.createPlaceholder(ElemKind::BFloat16Ty, {1, 3, 3, 1},
9309                                        "input", false);
9310   bindings_.allocate(input)->getHandle<bfloat16_t>() = {0., 1., 2., 3., 4.,
9311                                                         5., 6., 7., 8.};
9312   auto *pool = F_->createMaxPool("pool", input, {2, 2}, {1, 1}, {0, 0, 0, 0});
9313   auto *S = F_->createSave("save", pool->getResult());
9314   bindings_.allocate(S->getPlaceholder());
9315 
9316   EE_.compile(CompilationMode::Infer);
9317   EE_.run(bindings_);
9318 
9319   auto result = bindings_.get(S->getPlaceholder());
9320   Tensor out(ElemKind::BFloat16Ty, {1, 2, 2, 1});
9321   out.getHandle<bfloat16_t>() = {4., 5., 7., 8.};
9322   EXPECT_TRUE(out.isEqual(*result));
9323 }
9324 
9325 TEST_P(OperatorTest, Int8MaxPool) {
9326   CHECK_IF_ENABLED();
9327 
9328   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {1, 3, 3, 1}, 1, 0,
9329                                        "input", false);
9330   bindings_.allocate(input)->getHandle<int8_t>() = {0, 1, 2, 3, 4, 5, 6, 7, 8};
9331   auto *Pool = F_->createMaxPool("pool", input, {2, 2}, {1, 1}, {0, 0, 0, 0});
9332   auto *S = F_->createSave("save", Pool->getResult());
9333   bindings_.allocate(S->getPlaceholder());
9334 
9335   EE_.compile(CompilationMode::Infer);
9336   EE_.run(bindings_);
9337 
9338   auto result = bindings_.get(S->getPlaceholder())->getHandle<int8_t>();
9339   Tensor out(ElemKind::Int8QTy, {1, 2, 2, 1}, 1, 0);
9340   out.getHandle<int8_t>() = {4, 5, 7, 8};
9341   for (size_t i = 0; i < 2 * 2; i++) {
9342     EXPECT_EQ(result.raw(i), out.getHandle<int8_t>().raw(i));
9343   }
9344 }
9345 
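/// Defines a createAndInitBasic<Op>Test helper that builds a single-<Op>
/// function over LEN random floats in [LOW, HIGH]; the stateless tests below
/// feed these helpers to compareAgainstInterpreter.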
9346 #define COMPARE_UNARY_OP_FUN(_OP_NAME_, LEN, LOW, HIGH)                        \
9347   static FunctionTensorPair createAndInitBasic##_OP_NAME_##Test(               \
9348       glow::PlaceholderBindings &bindings, glow::ExecutionEngine &EE) {        \
9349     auto &mod = EE.getModule();                                                \
9350     Function *F = mod.createFunction("main");                                  \
9351                                                                                \
9352     auto *input =                                                              \
9353         mod.createPlaceholder(ElemKind::FloatTy, {LEN}, "input", false);       \
9354     bindings.allocate(input)->getHandle().randomize(LOW, HIGH, mod.getPRNG()); \
9355     auto *op = F->create##_OP_NAME_(#_OP_NAME_, input);                        \
9356     auto *save = F->createSave("Save", op);                                    \
9357     auto *resultTensor = bindings.allocate(save->getPlaceholder());            \
9358     return std::make_pair(F, resultTensor);                                    \
9359   }
9360 COMPARE_UNARY_OP_FUN(Exp, 10, -1.0F, 1.0F)
9361 COMPARE_UNARY_OP_FUN(Tanh, 10, -10.0F, 10.0F)
9362 COMPARE_UNARY_OP_FUN(Log, 1000, 1.0F, 100.0F)
9363 COMPARE_UNARY_OP_FUN(Sigmoid, 10, -10.0F, 10.0F)
9364 #undef COMPARE_UNARY_OP_FUN
9365 
9366 /// Reference ideal sigmoid implementation. Computes an fp32 sigmoid
9367 /// and casts the result to FP16.
9368 static float16_t refSigmoidFp16(float x) {
9369   float res = 1 / (1 + exp(-x));
9370 
9371   return (float16_t)res;
9372 }
9373 
9374 /// Reference ideal sigmoid implementation. Computes an fp32 sigmoid
9375 /// and casts the result to BFloat16.
9376 static bfloat16_t refSigmoidBFloat16(float x) {
9377   float res = 1 / (1 + exp(-x));
9378 
9379   return (bfloat16_t)res;
9380 }
9381 
9382 /// Test to verify that the backend sigmoid implementation (e.g. the NNPI
9383 /// mirrored-LUT one) stays close to the ideal sigmoid computed in fp32 and
9384 /// cast to fp16. Sweeps the range [-20, 20], prints the backend output next
9385 /// to the ideal reference, and requires every point to fall within the
9386 /// relaxed absolute/relative tolerances used below.
9387 static void testSigmoidFp16Sweep(glow::PlaceholderBindings &bindings,
9388                                  glow::Module &mod, glow::Function *F,
9389                                  glow::ExecutionEngine &EE) {
9390   constexpr dim_t N = 100;
9391   auto *input = mod.createPlaceholder(ElemKind::FloatTy, {N}, "input", false);
9392   auto inputH = bindings.allocate(input)->getHandle();
9393 
9394   constexpr float rangeStart = -20;
9395   constexpr float rangeEnd = 20;
9396   constexpr float delta = (rangeEnd - rangeStart) / N;
9397 
9398   for (dim_t i = 0; i < N; i++) {
9399     inputH.raw(i) = rangeStart + i * delta;
9400   }
9401 
9402   auto *sigmoid = F->createSigmoid("Sigmoid", input);
9403   auto *save = F->createSave("Save", sigmoid);
9404   auto *resultTensor = bindings.allocate(save->getPlaceholder());
9405 
9406   CompilationContext cctx;
9407   cctx.precisionConfig.convertToFP16 = true;
9408   cctx.precisionConfig.convertFusedToFP16 = true;
9409   cctx.precisionConfig.float16Format =
9410       PrecisionConfiguration::Float16Format::FP16;
9411 
9412   EE.compile(cctx);
9413   EE.run(bindings);
9414 
9415   auto resultH = resultTensor->getHandle();
9416   int numDiffs = 0;
9417 
9418   for (dim_t i = 0; i < N; i++) {
9419     float inputV = inputH.at({i});
9420     float refIdeal = refSigmoidFp16(inputV);
9421     float output = resultH.at({i});
9422     float absDiff = fabs(output - refIdeal);
9423     float relDiff = fabs(absDiff / (refIdeal + 1e-8));
9424 
9425     bool failed = false;
9426     // Relative error should be 2^-11 but we are relaxing this constraint
9427     // due to linear interpolation
9428     // Absolute error can remain 1e-5 for now
9429     if (absDiff > 1e-5 && relDiff > 2e-3) {
9430       numDiffs++;
9431       failed = true;
9432     }
9433 
9434     llvm::outs() << "Sigmoid " << i << " " << inputV << " Backend:" << output
9435                  << " ref_ideal:" << refIdeal << " relDiff:" << relDiff
9436                  << " absDiff:" << absDiff << " failed:" << failed << "\n";
9437   }
9438   llvm::outs() << "Number of diffs: " << numDiffs << "\n";
9439   llvm::outs().flush();
9440 
9441   EXPECT_EQ(numDiffs, 0);
9442 }
9443 
9444 /// Test to verify that the backend sigmoid implementation (e.g. the NNPI
9445 /// mirrored-LUT one) stays close to the ideal sigmoid computed in fp32 and
9446 /// cast to bfloat16. Sweeps the range [-20, 20], prints the backend output
9447 /// next to the ideal reference, and requires every point to fall within the
9448 /// relaxed absolute/relative tolerances used below.
9449 static void testSigmoidBFloat16Sweep(glow::PlaceholderBindings &bindings,
9450                                      glow::Module &mod, glow::Function *F,
9451                                      glow::ExecutionEngine &EE) {
9452   constexpr dim_t N = 100;
9453   auto *input = mod.createPlaceholder(ElemKind::FloatTy, {N}, "input", false);
9454   auto inputH = bindings.allocate(input)->getHandle();
9455 
9456   constexpr float rangeStart = -20;
9457   constexpr float rangeEnd = 20;
9458   constexpr float delta = (rangeEnd - rangeStart) / N;
9459 
9460   for (dim_t i = 0; i < N; i++) {
9461     inputH.raw(i) = rangeStart + i * delta;
9462   }
9463 
9464   auto *sigmoid = F->createSigmoid("Sigmoid", input);
9465   auto *save = F->createSave("Save", sigmoid);
9466   auto *resultTensor = bindings.allocate(save->getPlaceholder());
9467 
9468   CompilationContext cctx;
9469   cctx.precisionConfig.convertToFP16 = true;
9470   cctx.precisionConfig.convertFusedToFP16 = true;
9471   cctx.precisionConfig.float16Format =
9472       PrecisionConfiguration::Float16Format::BFloat16;
9473 
9474   EE.compile(cctx);
9475   EE.run(bindings);
9476 
9477   auto resultH = resultTensor->getHandle();
9478   int numDiffs = 0;
9479 
9480   for (dim_t i = 0; i < N; i++) {
9481     float inputV = inputH.at({i});
9482     float refIdeal = refSigmoidBFloat16(inputV);
9483     float output = resultH.at({i});
9484     float absDiff = fabs(output - refIdeal);
9485     float relDiff = fabs(absDiff / (refIdeal + 1e-8));
9486 
9487     bool failed = false;
9488     // Relative error should be 2^-11 but we are relaxing this constraint
9489     // due to linear interpolation. The thresholds are relaxed further for
9490     // bfloat16 (1e-3 absolute, 2e-2 relative) given its lower precision.
9491     if (absDiff > 1e-3 && relDiff > 2e-2) {
9492       numDiffs++;
9493       failed = true;
9494     }
9495 
9496     llvm::outs() << "Sigmoid " << i << " " << inputV << " Backend:" << output
9497                  << " ref_ideal:" << refIdeal << " relDiff:" << relDiff
9498                  << " absDiff:" << absDiff << " failed:" << failed << "\n";
9499   }
9500   llvm::outs() << "Number of diffs: " << numDiffs << "\n";
9501   llvm::outs().flush();
9502 
9503   EXPECT_EQ(numDiffs, 0);
9504 }
9505 
9506 TEST_P(OperatorTest, SigmoidSweep_Float16) {
9507   CHECK_IF_ENABLED();
9508 
9509   testSigmoidFp16Sweep(bindings_, mod_, F_, EE_);
9510 }
9511 
9512 TEST_P(OperatorTest, SigmoidSweep_BFloat16) {
9513   CHECK_IF_ENABLED();
9514 
9515   testSigmoidBFloat16Sweep(bindings_, mod_, F_, EE_);
9516 }
9517 
9518 /// Reference ideal tanh implementation. Computes an fp32 tanh
9519 /// and casts the result to FP16, no denorms
9520 static float16_t refTanHFp16(float x) {
9521   float res = (exp(2 * x) - 1) / (exp(2 * x) + 1);
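  // Flush results below ~6e-5 (roughly the FP16 minimum normal, 2^-14) to
  // zero, mimicking an implementation without denormal support.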
9522   if (fabs(res) < 6e-5) {
9523     res = 0.0;
9524   }
9525   return (float16_t)res;
9526 }
9527 
9528 /// Reference ideal tanh implementation. Computes an fp32 tanh
9529 /// and casts the result to BFloat16, no denorms
9530 static bfloat16_t refTanHBFloat16(float x) {
9531   float res = (exp(2 * x) - 1) / (exp(2 * x) + 1);
9532   if (fabs(res) < 6e-5) {
9533     res = 0.0;
9534   }
9535   return (bfloat16_t)res;
9536 }
9537 
9538 /// Test to verify that the tanh implementation is close to the ideal one
9539 /// Does a sweep of -15,15 and prints the outputs of the NNPI implementation
9540 /// compared to the ideal tanh in fp16.
9541 static void testTanHFp16Sweep(glow::PlaceholderBindings &bindings,
9542                               glow::Module &mod, glow::Function *F,
9543                               glow::ExecutionEngine &EE) {
9544   constexpr dim_t N = 100;
9545   auto *input = mod.createPlaceholder(ElemKind::FloatTy, {N}, "input", false);
9546   auto inputH = bindings.allocate(input)->getHandle();
9547 
9548   constexpr float rangeStart = -15;
9549   constexpr float rangeEnd = 15;
9550   constexpr float delta = (rangeEnd - rangeStart) / N;
9551 
9552   for (dim_t i = 0; i < N; i++) {
9553     inputH.raw(i) = rangeStart + i * delta;
9554   }
9555 
9556   auto *tanh = F->createTanh("TanH", input);
9557   auto *save = F->createSave("Save", tanh);
9558   auto *resultTensor = bindings.allocate(save->getPlaceholder());
9559 
9560   CompilationContext cctx;
9561   cctx.precisionConfig.convertToFP16 = true;
9562   cctx.precisionConfig.convertFusedToFP16 = true;
9563   cctx.precisionConfig.float16Format =
9564       PrecisionConfiguration::Float16Format::FP16;
9565 
9566   EE.compile(cctx);
9567   EE.run(bindings);
9568 
9569   auto resultH = resultTensor->getHandle();
9570   int count = 0;
9571 
9572   for (dim_t i = 0; i < N; i++) {
9573     float inputV = inputH.at({i});
9574     float refIdeal = refTanHFp16(inputV);
9575     float output = resultH.at({i});
9576     float diff = fabs(output - refIdeal);
9577 
9578     if (diff > 1e-6) {
9579       count++;
9580     }
9581 
9582     llvm::outs() << "TanH " << i << " " << inputV << " Backend:" << output
9583                  << " ref_ideal:" << refIdeal << " diff:" << diff << "\n";
9584   }
9585   llvm::outs().flush();
9586 
9587   EXPECT_EQ(count, 0);
9588 }
9589 
9590 /// Test to verify that the tanh implementation is close to the ideal one
9591 /// Does a sweep of -15,15 and prints the outputs of the NNPI implementation
9592 /// compared to the ideal tanh in bfloat16.
9593 static void testTanHBFloat16Sweep(glow::PlaceholderBindings &bindings,
9594                                   glow::Module &mod, glow::Function *F,
9595                                   glow::ExecutionEngine &EE) {
9596   constexpr dim_t N = 100;
9597   auto *input = mod.createPlaceholder(ElemKind::FloatTy, {N}, "input", false);
9598   auto inputH = bindings.allocate(input)->getHandle();
9599 
9600   constexpr float rangeStart = -15;
9601   constexpr float rangeEnd = 15;
9602   constexpr float delta = (rangeEnd - rangeStart) / N;
9603 
9604   for (dim_t i = 0; i < N; i++) {
9605     inputH.raw(i) = rangeStart + i * delta;
9606   }
9607 
9608   auto *tanh = F->createTanh("TanH", input);
9609   auto *save = F->createSave("Save", tanh);
9610   auto *resultTensor = bindings.allocate(save->getPlaceholder());
9611 
9612   CompilationContext cctx;
9613   cctx.precisionConfig.convertToFP16 = true;
9614   cctx.precisionConfig.convertFusedToFP16 = true;
9615   cctx.precisionConfig.float16Format =
9616       PrecisionConfiguration::Float16Format::BFloat16;
9617 
9618   EE.compile(cctx);
9619   EE.run(bindings);
9620 
9621   auto resultH = resultTensor->getHandle();
9622   int count = 0;
9623 
9624   for (dim_t i = 0; i < N; i++) {
9625     float inputV = inputH.at({i});
9626     float refIdeal = refTanHBFloat16(inputV);
9627     float output = resultH.at({i});
9628     float diff = fabs(output - refIdeal);
9629 
9630     if (diff > 1e-2) {
9631       count++;
9632     }
9633 
9634     llvm::outs() << "TanH " << i << " " << inputV << " Backend:" << output
9635                  << " ref_ideal:" << refIdeal << " diff:" << diff << "\n";
9636   }
9637   llvm::outs().flush();
9638 
9639   EXPECT_EQ(count, 0);
9640 }
9641 
9642 TEST_P(OperatorTest, TanHSweep_Float16) {
9643   CHECK_IF_ENABLED();
9644 
9645   testTanHFp16Sweep(bindings_, mod_, F_, EE_);
9646 }
9647 
9648 TEST_P(OperatorTest, TanHSweep_BFloat16) {
9649   CHECK_IF_ENABLED();
9650 
9651   testTanHBFloat16Sweep(bindings_, mod_, F_, EE_);
9652 }
9653 
9654 template <typename DataType>
9655 static void testMaxPoolWithArgmax(glow::PlaceholderBindings &bindings,
9656                                   glow::Module &mod, glow::Function *F,
9657                                   glow::ExecutionEngine &EE, ElemKind DTy) {
9658   auto *input = createPlaceholderConditionallyQuantized(mod, DTy, {1, 3, 3, 1},
9659                                                         "input", false, "NHWC");
9660   bindings.allocate(input)->getHandle<DataType>() = {0, 3, 7, 6, 5, 1, 2, 8, 4};
9661   auto *pool = F->createMaxPool("pool", input, {2, 2}, {1, 1}, {0, 0, 0, 0});
9662   auto *SResult = F->createSave("save_result", pool->getResult());
9663   auto *SArgmax = F->createSave("save_argmax", pool->getArgmax());
9664   bindings.allocate(SResult->getPlaceholder());
9665   bindings.allocate(SArgmax->getPlaceholder());
9666 
9667   EE.compile(CompilationMode::Infer);
9668   EE.run(bindings);
9669 
9670   auto result = bindings.get(SResult->getPlaceholder());
9671   auto argmax = bindings.get(SArgmax->getPlaceholder());
9672   Tensor out1 = createTensorConditionallyQuantized(DTy, {1, 2, 2, 1});
9673   out1.getHandle<DataType>() = {6, 7, 8, 8};
9674   EXPECT_TRUE(out1.isEqual(*result));
9675 
9676   Tensor out2(ElemKind::Int64ITy, {1, 2, 2, 1});
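  // Argmax holds the flattened index of each window's max over the input
  // {0, 3, 7, 6, 5, 1, 2, 8, 4}; the 8 at raw index 7 wins both bottom
  // windows.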
9677   out2.getHandle<int64_t>() = {3, 2, 7, 7};
9678   EXPECT_TRUE(out2.isEqual(*argmax));
9679 }
9680 
9681 TEST_P(OperatorTest, FloatMaxPoolWithArgmax) {
9682   CHECK_IF_ENABLED();
9683   testMaxPoolWithArgmax<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
9684 }
9685 
9686 TEST_P(OperatorTest, QuantizedMaxPoolWithArgmax) {
9687   CHECK_IF_ENABLED();
9688   testMaxPoolWithArgmax<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
9689 }
9690 
9691 template <typename DataType>
9692 static void
9693 testMaxPoolWithArgmaxTransposed(glow::PlaceholderBindings &bindings,
9694                                 glow::Module &mod, glow::Function *F,
9695                                 glow::ExecutionEngine &EE, ElemKind DTy) {
9696   // Show that sequence Tensor(NCHW) -> Transpose(NCHWtoNHWC) ->
9697   // MaxPoolWithArgmax -> Transpose(NHWCtoNCHW) produces correct
9698   // linearization.
9699   auto *inputNCHW = createPlaceholderConditionallyQuantized(
9700       mod, DTy, {1, 3, 4, 4}, "input", false, "NCHW");
9701   auto inHandle = bindings.allocate(inputNCHW)->getHandle<DataType>();
9702   inHandle.clear(0.);
9703   inHandle.at({0, 0, 2, 2}) = 11;
9704   inHandle.at({0, 1, 2, 2}) = 22;
9705   inHandle.at({0, 2, 2, 2}) = 33;
9706 
9707   // Input NCHW to NHWC conversion.
9708   auto *inputNHWC =
9709       F->createTranspose("transposeInput", inputNCHW, {0, 2, 3, 1}, "NHWC");
9710   auto *pool =
9711       F->createMaxPool("pool", inputNHWC, {4, 4}, {4, 4}, {0, 0, 0, 0});
9712 
9713   // NHWC to NCHW conversion.
9714   auto *resultNCHW = F->createTranspose("transposeRes", pool->getResult(),
9715                                         {0, 3, 1, 2}, "NCHW");
9716   auto *argmaxNCHW = F->createTranspose("transposeArgmax", pool->getArgmax(),
9717                                         {0, 3, 1, 2}, "NCHW");
9718 
9719   auto *SResult = F->createSave("save_result", resultNCHW);
9720   auto *SArgmax = F->createSave("save_argmax", argmaxNCHW);
9721   bindings.allocate(SResult->getPlaceholder());
9722   bindings.allocate(SArgmax->getPlaceholder());
9723 
9724   EE.compile(CompilationMode::Infer);
9725   EE.run(bindings);
9726 
9727   auto result = bindings.get(SResult->getPlaceholder());
9728   auto argmax = bindings.get(SArgmax->getPlaceholder());
9729   Tensor out1 = createTensorConditionallyQuantized(DTy, {1, 3, 1, 1});
9730   out1.getHandle<DataType>() = {11, 22, 33};
9731   EXPECT_TRUE(out1.isEqual(*result));
9732 
9733   Tensor out2(ElemKind::Int64ITy, {1, 3, 1, 1});
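  // The argmax indices are linearized over the pool's NHWC input:
  // idx = h * (W * C) + w * C + c with W = 4, C = 3, and the max at
  // (h, w) = (2, 2).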
9734   out2.getHandle<int64_t>() = {0 + 2 * 3 + 2 * 12, 1 + 2 * 3 + 2 * 12,
9735                                2 + 2 * 3 + 2 * 12};
9736   EXPECT_TRUE(out2.isEqual(*argmax));
9737 }
9738 
9739 TEST_P(OperatorTest, FloatMaxPoolWithArgmaxTransposed) {
9740   CHECK_IF_ENABLED();
9741   testMaxPoolWithArgmaxTransposed<float>(bindings_, mod_, F_, EE_,
9742                                          ElemKind::FloatTy);
9743 }
9744 
9745 TEST_P(OperatorTest, QuantizedMaxPoolWithArgmaxTransposed) {
9746   CHECK_IF_ENABLED();
9747   testMaxPoolWithArgmaxTransposed<int8_t>(bindings_, mod_, F_, EE_,
9748                                           ElemKind::Int8QTy);
9749 }
9750 
9751 TEST_P(OperatorStatelessTest, Int8Tanh) {
9752   CHECK_IF_ENABLED();
9753   compareAgainstInterpreter(getBackendName(), createAndInitBasicTanhTest,
9754                             ElemKind::FloatTy, ElemKind::Int8QTy, 0.005f,
9755                             parCloneCountOpt);
9756 }
9757 
9758 TEST_P(OperatorStatelessTest, Tanh_Float16) {
9759   CHECK_IF_ENABLED();
9760   compareAgainstInterpreter(getBackendName(), createAndInitBasicTanhTest,
9761                             ElemKind::FloatTy, ElemKind::Float16Ty, 0.001f,
9762                             parCloneCountOpt);
9763 }
9764 
9765 TEST_P(OperatorStatelessTest, Tanh_BFloat16) {
9766   CHECK_IF_ENABLED();
9767   compareAgainstInterpreter(getBackendName(), createAndInitBasicTanhTest,
9768                             ElemKind::FloatTy, ElemKind::BFloat16Ty, 0.001f,
9769                             parCloneCountOpt);
9770 }
9771 
9772 /// Verify that the Tanh operator works correctly.
9773 TEST_P(OperatorTest, Tanh) {
9774   CHECK_IF_ENABLED();
9775 
9776   constexpr dim_t size = 10;
9777   auto *input =
9778       mod_.createPlaceholder(ElemKind::FloatTy, {size}, "input", false);
9779   bindings_.allocate(input)->getHandle().randomize(-10.0, 10.0, mod_.getPRNG());
9780 
9781   auto *tanh = F_->createTanh("Tanh", input);
9782   auto *save = F_->createSave("Save", tanh);
9783   bindings_.allocate(save->getPlaceholder());
9784 
9785   EE_.compile(CompilationMode::Infer);
9786   EE_.run(bindings_);
9787 
9788   auto resultH = bindings_.get(save->getPlaceholder())->getHandle();
9789   auto inputH = bindings_.get(input)->getHandle();
9790 
9791   for (dim_t i = 0; i < size; i++) {
9792     EXPECT_NEAR(resultH.at({i}), std::tanh(inputH.at({i})), 0.001);
9793   }
9794 }
9795 
9796 TEST_P(OperatorStatelessTest, Exp_Float16) {
9797   CHECK_IF_ENABLED();
9798   compareAgainstInterpreter(getBackendName(), createAndInitBasicExpTest,
9799                             ElemKind::FloatTy, ElemKind::Float16Ty, 0.005f,
9800                             parCloneCountOpt);
9801 }
9802 
9803 TEST_P(OperatorStatelessTest, Exp_BFloat16) {
9804   CHECK_IF_ENABLED();
9805   compareAgainstInterpreter(getBackendName(), createAndInitBasicExpTest,
9806                             ElemKind::FloatTy, ElemKind::BFloat16Ty, 0.005f,
9807                             parCloneCountOpt);
9808 }
9809 
9810 /// Verify that the Exp operator works correctly.
9811 TEST_P(OperatorTest, Exp) {
9812   CHECK_IF_ENABLED();
9813   constexpr dim_t size = 10;
9814   auto *input =
9815       mod_.createPlaceholder(ElemKind::FloatTy, {size}, "input", false);
9816   bindings_.allocate(input)->getHandle().randomize(-10.0, 10.0, mod_.getPRNG());
9817 
9818   auto *expn = F_->createExp("Exp", input);
9819   auto *save = F_->createSave("Save", expn);
9820   bindings_.allocate(save->getPlaceholder());
9821 
9822   EE_.compile(CompilationMode::Infer);
9823   EE_.run(bindings_);
9824 
9825   auto resultH = bindings_.get(save->getPlaceholder())->getHandle();
9826   auto inputH = bindings_.get(input)->getHandle();
9827 
9828   for (dim_t i = 0; i < size; i++) {
9829     EXPECT_NEAR(resultH.at({i}), std::exp(inputH.at({i})), 0.001);
9830   }
9831 }
9832 
9833 /// Verify that a quantized Log works correctly.
9834 TEST_P(OperatorStatelessTest, Int8Log) {
9835   CHECK_IF_ENABLED();
9836   compareAgainstInterpreter(getBackendName(), createAndInitBasicLogTest,
9837                             ElemKind::FloatTy, ElemKind::Int8QTy, 0.1f,
9838                             parCloneCountOpt);
9839 }
9840 
9841 /// Check Non-square kernel for conv.
9842 TEST_P(OperatorTest, NonSquareKernelConvolution) {
9843   CHECK_IF_ENABLED();
9844 
9845   auto *input =
9846       mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 1}, "input", false);
9847   auto IH = bindings_.allocate(input)->getHandle();
9848   for (size_t i = 0; i < 4 * 4; i++) {
9849     IH.raw(i) = i + 1;
9850   }
9851 
9852   auto filter =
9853       mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 1}, "filter", false);
9854   auto FH = bindings_.allocate(filter)->getHandle();
9855   for (size_t i = 0; i < 1 * 2 * 3; i++) {
9856     FH.raw(i) = i + 1;
9857   }
9858 
9859   auto *zeroBias =
9860       mod_.createPlaceholder(ElemKind::FloatTy, {1}, "bias", false);
9861   bindings_.allocate(zeroBias)->zero();
9862 
9863   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 3, 2, 1});
9864   ConvolutionNode *CN = F_->createConv("Conv", input, filter, zeroBias, outTy,
9865                                        {2, 3}, {1, 1}, {0, 0, 0, 0}, 1);
9866   SaveNode *S = F_->createSave("save", CN);
9867   bindings_.allocate(S->getPlaceholder());
9868 
9869   ::glow::convertPlaceholdersToConstants(F_, bindings_,
9870                                          {input, S->getPlaceholder()});
9871   EE_.compile(CompilationMode::Infer);
9872   EE_.run(bindings_);
9873   Tensor &result = *bindings_.get(S->getPlaceholder());
9874 
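  // Each output applies the 2x3 filter {1..6} to a 2x3 input window; e.g.
  // ref[0] = 1*1 + 2*2 + 3*3 + 4*5 + 5*6 + 6*7 = 106.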
9875   static const float ref[] = {106, 127, 190, 211, 274, 295};
9876   for (size_t i = 0; i < 6; i++)
9877     EXPECT_EQ(result.getHandle().raw(i), ref[i]);
9878 }
9879 
9880 /// Check Non-cubic kernel for conv3D.
9881 TEST_P(OperatorTest, NonCubicKernelConv3D) {
9882   CHECK_IF_ENABLED();
9883 
9884   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 4, 1},
9885                                        "input", false);
9886   auto IH = bindings_.allocate(input)->getHandle();
9887   int nextVal = 1;
9888   for (dim_t i = 0; i < 4; i++) {
9889     for (dim_t j = 0; j < 4; j++) {
9890       for (dim_t k = 0; k < 4; k++) {
9891         IH.at({0, i, j, k, 0}) = static_cast<float>(nextVal++);
9892       } // W
9893     }   // H
9894   }     // T
9895 
9896   auto *filter = mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 2, 3, 1},
9897                                         "filter", false);
9898   auto FH = bindings_.allocate(filter)->getHandle();
9899   nextVal = 1;
9900   for (dim_t i = 0; i < 1; i++) {
9901     for (dim_t j = 0; j < 2; j++) {
9902       for (dim_t k = 0; k < 3; k++) {
9903         FH.at({0, i, j, k, 0}) = static_cast<float>(nextVal++);
9904       } // W
9905     }   // H
9906   }     // T
9907 
9908   auto *zeroBias =
9909       mod_.createPlaceholder(ElemKind::FloatTy, {1}, "bias", false);
9910   bindings_.allocate(zeroBias)->zero();
9911 
9912   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 4, 3, 2, 1});
9913 
9914   Convolution3DNode *CN =
9915       F_->createConv3D("Conv3D", input, filter, zeroBias, outTy, {1, 2, 3},
9916                        {1, 1, 1}, {0, 0, 0, 0, 0, 0}, 1);
9917   SaveNode *S = F_->createSave("save", CN);
9918   bindings_.allocate(S->getPlaceholder());
9919 
9920   ::glow::convertPlaceholdersToConstants(F_, bindings_,
9921                                          {input, S->getPlaceholder()});
9922   EE_.compile(CompilationMode::Infer);
9923   EE_.run(bindings_);
9924   Tensor &result = *bindings_.get(S->getPlaceholder());
9925 
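  // The 1x2x3 filter {1..6} sees the same top-left 2x3 window as the 2-D test
  // above, so ref[0] = 106; stepping once along T adds the filter sum (21)
  // times 16, i.e. ref[6] = 106 + 336 = 442.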
9926   static const float ref[] = {106, 127, 190,  211,  274,  295,  442,  463,
9927                               526, 547, 610,  631,  778,  799,  862,  883,
9928                               946, 967, 1114, 1135, 1198, 1219, 1282, 1303};
9929   for (size_t i = 0; i < 4 * 3 * 2; i++) {
9930     EXPECT_EQ(result.getHandle().raw(i), ref[i]);
9931   }
9932 }
9933 
9934 /// Check Non-cubic kernel for conv3D with quantized input, filters, and bias.
9935 TEST_P(OperatorTest, NonCubicKernelConv3DQuantized) {
9936   CHECK_IF_ENABLED();
9937 
9938   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 4, 1},
9939                                        "input", false);
9940   auto IH = bindings_.allocate(input)->getHandle();
9941   int nextVal = 1;
9942   for (dim_t i = 0; i < 4; i++) {
9943     for (dim_t j = 0; j < 4; j++) {
9944       for (dim_t k = 0; k < 4; k++) {
9945         IH.at({0, i, j, k, 0}) = static_cast<float>(nextVal++);
9946       } // W
9947     }   // H
9948   }     // T
9949 
9950   auto qInType = mod_.uniqueType(ElemKind::Int16QTy, {1, 4, 4, 4, 1}, 0.1, 0);
9951   QuantizeNode *qInput = F_->createQuantize("q_input", input, qInType);
9952 
9953   auto *filter = mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 2, 3, 1},
9954                                         "filter", false);
9955   auto FH = bindings_.allocate(filter)->getHandle();
9956   nextVal = 1;
9957   for (dim_t i = 0; i < 1; i++) {
9958     for (dim_t j = 0; j < 2; j++) {
9959       for (dim_t k = 0; k < 3; k++) {
9960         FH.at({0, i, j, k, 0}) = static_cast<float>(nextVal++);
9961       } // D
9962     }   // W
9963   }     // H
9964 
9965   auto qFilterType =
9966       mod_.uniqueType(ElemKind::Int16QTy, {1, 1, 2, 3, 1}, 0.1, 0);
9967   QuantizeNode *qFilter = F_->createQuantize("q_filter", filter, qFilterType);
9968 
9969   auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "bias", false);
9970   bindings_.allocate(bias)->zero();
9971 
9972   auto qBiasType = mod_.uniqueType(ElemKind::Int32QTy, {1}, 0.1, 0);
9973   QuantizeNode *qBias = F_->createQuantize("q_bias", bias, qBiasType);
9974 
9975   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 4, 3, 2, 1});
9976 
9977   Convolution3DNode *CN =
9978       F_->createConv3D("Conv3D", input, filter, bias, outTy, {1, 2, 3},
9979                        {1, 1, 1}, {0, 0, 0, 0, 0, 0}, 1);
9980 
9981   auto qOutTy = mod_.uniqueType(ElemKind::Int16QTy, {1, 4, 3, 2, 1}, 0.1, 0);
9982 
9983   Convolution3DNode *qCN =
9984       F_->createConv3D("q_Conv3D", qInput, qFilter, qBias, qOutTy, {1, 2, 3},
9985                        {1, 1, 1}, {0, 0, 0, 0, 0, 0}, 1);
9986 
9987   SaveNode *S = F_->createSave("save", CN);
9988 
9989   DequantizeNode *deQ =
9990       F_->createDequantize("deQ_result", qCN, ElemKind::FloatTy);
9991   SaveNode *qS = F_->createSave("save_q", deQ);
9992 
9993   bindings_.allocate(S->getPlaceholder());
9994 
9995   ::glow::convertPlaceholdersToConstants(F_, bindings_,
9996                                          {input, S->getPlaceholder()});
9997   bindings_.allocate(mod_.getPlaceholders());
9998   EE_.compile(CompilationMode::Infer);
9999   EE_.run(bindings_);
10000 
10001   Tensor &result = *bindings_.get(S->getPlaceholder());
10002   Tensor &qResult = *bindings_.get(qS->getPlaceholder());
10003 
10004   for (size_t i = 0; i < 4 * 3 * 2; i++) {
10005     EXPECT_NEAR(qResult.getHandle().raw(i), result.getHandle().raw(i), 0.5);
10006   }
10007 }
10008 
10009 /// Test for quantized Convolution3D.
10010 static void Conv3DQuantizedTest(glow::PlaceholderBindings &bindings,
10011                                 glow::Module &mod, glow::Function *F,
10012                                 glow::ExecutionEngine &EE, ElemKind elemKind,
10013                                 ElemKind biaselemKind) {
10014   // Create floating-point network.
10015   auto *input =
10016       mod.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 4, 1}, "input", false);
10017   auto *filter = mod.createPlaceholder(ElemKind::FloatTy, {1, 1, 2, 3, 1},
10018                                        "filter", false);
10019   auto *bias = mod.createPlaceholder(ElemKind::FloatTy, {1}, "bias", false);
10020   auto outTy = mod.uniqueType(ElemKind::FloatTy, {1, 4, 3, 2, 1});
10021   Convolution3DNode *conv3d =
10022       F->createConv3D("Conv3D", input, filter, bias, outTy, {1, 2, 3},
10023                       {1, 1, 1}, {0, 0, 0, 0, 0, 0}, 1);
10024   SaveNode *save = F->createSave("save", conv3d);
10025 
10026   // Quantized types.
10027   auto inputTQP = quantization::chooseQuantizationParams(
10028       {-1.0, 1.0}, quantization::Schema::Asymmetric, elemKind);
10029   auto filterTQP = quantization::chooseQuantizationParams(
10030       {-1.0, 1.0}, quantization::Schema::Asymmetric, elemKind);
10031   auto outputTQP = quantization::chooseQuantizationParams(
10032       {-4.0, 4.0}, quantization::Schema::Asymmetric, elemKind);
10033 
10034   // Create quantized network.
10035   auto inputQTy = mod.uniqueType(elemKind, {1, 4, 4, 4, 1}, inputTQP.scale,
10036                                  inputTQP.offset);
10037   auto filterQTy = mod.uniqueType(elemKind, {1, 1, 2, 3, 1}, filterTQP.scale,
10038                                   filterTQP.offset);
10039   auto outQTy = mod.uniqueType(elemKind, {1, 4, 3, 2, 1}, outputTQP.scale,
10040                                outputTQP.offset);
10041   QuantizeNode *inputQ = F->createQuantize("inputQ", input, inputQTy);
10042   QuantizeNode *filterQ = F->createQuantize("filterQ", filter, filterQTy);
10043   Convolution3DNode *conv3dQ = nullptr;
10044   if (biaselemKind == ElemKind::FloatTy) {
10045     conv3dQ = F->createConv3D("Conv3DQ", inputQ, filterQ, bias, outQTy,
10046                               {1, 2, 3}, {1, 1, 1}, {0, 0, 0, 0, 0, 0}, 1);
10047   } else {
10048     auto biasTQP = quantization::chooseQuantizationParams(
10049         {-1.0, 1.0}, quantization::Schema::Asymmetric, biaselemKind);
10050     auto biasQTy =
10051         mod.uniqueType(biaselemKind, {1}, biasTQP.scale, biasTQP.offset);
10052     QuantizeNode *biasQ = F->createQuantize("biasQ", bias, biasQTy);
10053     conv3dQ = F->createConv3D("Conv3DQ", inputQ, filterQ, biasQ, outQTy,
10054                               {1, 2, 3}, {1, 1, 1}, {0, 0, 0, 0, 0, 0}, 1);
10055   }
10056   DequantizeNode *deQ = F->createDequantize("deQ", conv3dQ, ElemKind::FloatTy);
10057   SaveNode *saveQ = F->createSave("saveQ", deQ);
10058 
10059   // Allocate placeholders.
10060   bindings.allocate(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
10061   bindings.allocate(filter)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
10062   bindings.allocate(bias)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
10063   bindings.allocate(save->getPlaceholder());
10064   bindings.allocate(saveQ->getPlaceholder());
10065 
10066   // Run network.
10067   ::glow::convertPlaceholdersToConstants(
10068       F, bindings, {input, save->getPlaceholder(), saveQ->getPlaceholder()});
10069   EE.compile(CompilationMode::Infer);
10070   EE.run(bindings);
10071 
10072   // Compare.
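  // The 0.03 tolerance is roughly one Int8 quantization step of the [-4, 4]
  // output range (8 / 255 ~= 0.031); the Int16 variants fit well within it.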
10073   Tensor &res = *bindings.get(save->getPlaceholder());
10074   Tensor &resQ = *bindings.get(saveQ->getPlaceholder());
10075   for (size_t i = 0; i < res.size(); i++) {
10076     EXPECT_NEAR(res.getHandle().raw(i), resQ.getHandle().raw(i), 0.03);
10077   }
10078 }
10079 
10080 /// Test Int8 Conv3D with Int8 bias.
10081 TEST_P(OperatorTest, Conv3DQuantizedTest_Int8_BiasInt8) {
10082   ENABLED_BACKENDS("Interpreter");
10083   Conv3DQuantizedTest(bindings_, mod_, F_, EE_, ElemKind::Int8QTy,
10084                       ElemKind::Int8QTy);
10085 }
10086 
10087 /// Test Int8 Conv3D with Int32 bias.
10088 TEST_P(OperatorTest, Conv3DQuantizedTest_Int8_BiasInt32) {
10089   ENABLED_BACKENDS("Interpreter", "NNPI");
10090   Conv3DQuantizedTest(bindings_, mod_, F_, EE_, ElemKind::Int8QTy,
10091                       ElemKind::Int32QTy);
10092 }
10093 
10094 /// Test Int8 Conv3D with Float32 bias.
10095 TEST_P(OperatorTest, Conv3DQuantizedTest_Int8_BiasFloat) {
10096   ENABLED_BACKENDS("Interpreter", "NNPI");
10097   Conv3DQuantizedTest(bindings_, mod_, F_, EE_, ElemKind::Int8QTy,
10098                       ElemKind::FloatTy);
10099 }
10100 
10101 /// Test Int16 Conv3D with Int16 bias.
10102 TEST_P(OperatorTest, Conv3DQuantizedTest_Int16_BiasInt16) {
10103   ENABLED_BACKENDS("Interpreter");
10104   Conv3DQuantizedTest(bindings_, mod_, F_, EE_, ElemKind::Int16QTy,
10105                       ElemKind::Int16QTy);
10106 }
10107 
10108 /// Test Int16 Conv3D with Int32 bias.
10109 TEST_P(OperatorTest, Conv3DQuantizedTest_Int16_BiasInt32) {
10110   ENABLED_BACKENDS("Interpreter");
10111   Conv3DQuantizedTest(bindings_, mod_, F_, EE_, ElemKind::Int16QTy,
10112                       ElemKind::Int32QTy);
10113 }
10114 
10115 /// Check Non-square kernel for AveragePool.
10116 TEST_P(OperatorTest, NonSquareKernelAveragePool) {
10117   CHECK_IF_ENABLED();
10118 
10119   auto *input =
10120       mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 1}, "input", false);
10121   auto IH = bindings_.allocate(input)->getHandle();
10122   for (size_t i = 0; i < 4 * 4; i++) {
10123     IH.raw(i) = i + 1;
10124   }
10125   auto *Pool = F_->createAvgPool("pool", input, {2, 3}, {1, 1}, {0, 0, 0, 0});
10126   auto *S = F_->createSave("save", Pool);
10127   bindings_.allocate(S->getPlaceholder());
10128 
10129   EE_.compile(CompilationMode::Infer);
10130   EE_.run(bindings_);
10131   Tensor &result = *bindings_.get(S->getPlaceholder());
10132 
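  // Each 2x3 window averages six elements; e.g. ref[0] =
  // (1 + 2 + 3 + 5 + 6 + 7) / 6 = 4.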
10133   static const float ref[] = {4, 5, 8, 9, 12, 13};
10134   for (size_t i = 0; i < 6; i++)
10135     EXPECT_EQ(result.getHandle().raw(i), ref[i]);
10136 }
10137 
10138 /// Check Non-square kernel for MaxPool.
10139 TEST_P(OperatorTest, NonSquareKernelMaxPool) {
10140   CHECK_IF_ENABLED();
10141 
10142   auto *input =
10143       mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 1}, "input", false);
10144   auto IH = bindings_.allocate(input)->getHandle();
10145   for (size_t i = 0; i < 4 * 4; i++) {
10146     IH.raw(i) = i + 1;
10147   }
10148   auto *Pool = F_->createMaxPool("pool", input, {2, 3}, {1, 1}, {0, 0, 0, 0});
10149   auto *S = F_->createSave("save", Pool->getResult());
10150   bindings_.allocate(S->getPlaceholder());
10151 
10152   EE_.compile(CompilationMode::Infer);
10153   EE_.run(bindings_);
10154   Tensor &result = *bindings_.get(S->getPlaceholder());
10155 
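  // Same 2x3 windows as the AveragePool test above, but taking the maximum;
  // the first window covers {1, 2, 3, 5, 6, 7}, whose maximum is 7.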
10156   static const float ref[] = {7, 8, 11, 12, 15, 16};
10157   for (size_t i = 0; i < 6; i++)
10158     EXPECT_EQ(result.getHandle().raw(i), ref[i]);
10159 }
10160 
10161 /// Check Non-square stride for conv.
10162 TEST_P(OperatorTest, NonSquareStrideConvolution) {
10163   CHECK_IF_ENABLED();
10164 
10165   auto *input =
10166       mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 1}, "input", false);
10167   auto IH = bindings_.allocate(input)->getHandle();
10168   for (size_t i = 0; i < 4 * 4; i++) {
10169     IH.raw(i) = i + 1;
10170   }
10171 
10172   auto filter =
10173       mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 2, 1}, "filter", false);
10174   auto FH = bindings_.allocate(filter)->getHandle();
10175   for (size_t i = 0; i < 1 * 2 * 2; i++) {
10176     FH.raw(i) = i + 1;
10177   }
10178 
10179   auto *zeroBias =
10180       mod_.createPlaceholder(ElemKind::FloatTy, {1}, "bias", false);
10181   bindings_.allocate(zeroBias)->zero();
10182 
10183   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 2, 2, 1});
10184   ConvolutionNode *CN = F_->createConv("Conv", input, filter, zeroBias, outTy,
10185                                        {2, 2}, {3, 2}, {0, 0, 1, 1}, 1);
10186   SaveNode *S = F_->createSave("save", CN);
10187   bindings_.allocate(S->getPlaceholder());
10188 
10189   ::glow::convertPlaceholdersToConstants(F_, bindings_,
10190                                          {input, S->getPlaceholder()});
10191   EE_.compile(CompilationMode::Infer);
10192   EE_.run(bindings_);
10193   Tensor &result = *bindings_.get(S->getPlaceholder());
10194 
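  // The filter holds {1, 2, 3, 4}. The first output is
  // 1*1 + 2*2 + 5*3 + 6*4 = 44; the second output row reads into the zero
  // padding added below the input, which is why its values are smaller.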
10195   static const float ref[] = {44, 64, 41, 47};
10196   for (size_t i = 0; i < 4; i++)
10197     EXPECT_EQ(result.getHandle().raw(i), ref[i]);
10198 }
10199 
10200 /// Check Non-cubic stride for conv3D.
10201 TEST_P(OperatorTest, NonCubicStrideConv3D) {
10202   CHECK_IF_ENABLED();
10203 
10204   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 4, 1},
10205                                        "input", false);
10206   auto IH = bindings_.allocate(input)->getHandle();
10207   int nextVal = 1;
10208   for (dim_t i = 0; i < 4; i++) {
10209     for (dim_t j = 0; j < 4; j++) {
10210       for (dim_t k = 0; k < 4; k++) {
10211         IH.at({0, i, j, k, 0}) = static_cast<float>(nextVal++);
10212       } // W
10213     }   // H
10214   }     // T
10215 
10216   auto *filter = mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 2, 2, 1},
10217                                         "filter", false);
10218   auto FH = bindings_.allocate(filter)->getHandle();
10219   nextVal = 1;
10220   for (dim_t i = 0; i < 2; i++) {
10221     for (dim_t j = 0; j < 2; j++) {
10222       for (dim_t k = 0; k < 2; k++) {
10223         FH.at({0, i, j, k, 0}) = static_cast<float>(nextVal++);
10224       } // W
10225     }   // H
10226   }     // T
10227 
10228   auto *zeroBias =
10229       mod_.createPlaceholder(ElemKind::FloatTy, {1}, "bias", false);
10230   bindings_.allocate(zeroBias)->zero();
10231 
10232   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 2, 2, 2, 1});
10233 
10234   Convolution3DNode *CN =
10235       F_->createConv3D("Conv3D", input, filter, zeroBias, outTy, {2, 2, 2},
10236                        {3, 3, 2},
10237                        {0, 1, 0, 1, 0, 1}, 1);
10238   SaveNode *S = F_->createSave("save", CN);
10239   bindings_.allocate(S->getPlaceholder());
10240 
10241   ::glow::convertPlaceholdersToConstants(F_, bindings_,
10242                                          {input, S->getPlaceholder()});
10243   EE_.compile(CompilationMode::Infer);
10244   EE_.run(bindings_);
10245   Tensor &result = *bindings_.get(S->getPlaceholder());
10246 
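  // Reference values correlate the 2x2x2 filter (values 1..8) with the
  // corresponding windows of the 1..64 input volume; windows that extend into
  // the zero padding contribute correspondingly smaller sums.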
10247   static const float ref[] = {560, 632, 366, 394, 524, 544, 185, 191};
10248   for (size_t i = 0; i < 8; i++) {
10249     EXPECT_EQ(result.getHandle().raw(i), ref[i]);
10250   }
10251 }
10252 
10253 /// Check Non-square stride for AveragePool.
10254 TEST_P(OperatorTest, NonSquareStrideAveragePool) {
10255   CHECK_IF_ENABLED();
10256 
10257   auto *input =
10258       mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 1}, "input", false);
10259   auto IH = bindings_.allocate(input)->getHandle();
10260   for (size_t i = 0; i < 4 * 4; i++) {
10261     IH.raw(i) = i + 1;
10262   }
10263   auto *Pool = F_->createAvgPool("pool", input, {2, 2}, {3, 2}, {0, 0, 1, 1});
10264   auto *S = F_->createSave("save", Pool);
10265   bindings_.allocate(S->getPlaceholder());
10266 
10267   EE_.compile(CompilationMode::Infer);
10268   EE_.run(bindings_);
10269   Tensor &result = *bindings_.get(S->getPlaceholder());
10270 
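  // The first 2x2 window averages {1, 2, 5, 6} = 3.5. Padded windows still
  // divide by the full kernel area, so the third output is (13 + 14) / 4 =
  // 6.75.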
10271   static const float ref[] = {3.5, 5.5, 6.75, 7.75};
10272   for (size_t i = 0; i < 4; i++)
10273     EXPECT_EQ(result.getHandle().raw(i), ref[i]);
10274 }
10275 
10276 /// Check Non-square stride for MaxPool.
10277 TEST_P(OperatorTest, NonSquareStrideMaxPool) {
10278   CHECK_IF_ENABLED();
10279 
10280   auto *input =
10281       mod_.createPlaceholder(ElemKind::FloatTy, {1, 4, 4, 1}, "input", false);
10282   auto IH = bindings_.allocate(input)->getHandle();
10283   for (size_t i = 0; i < 4 * 4; i++) {
10284     IH.raw(i) = i + 1;
10285   }
10286   auto *Pool = F_->createMaxPool("pool", input, {2, 2}, {3, 2}, {0, 0, 1, 1});
10287   auto *S = F_->createSave("save", Pool->getResult());
10288   bindings_.allocate(S->getPlaceholder());
10289 
10290   EE_.compile(CompilationMode::Infer);
10291   EE_.run(bindings_);
10292   Tensor &result = *bindings_.get(S->getPlaceholder());
10293 
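  // The first 2x2 window covers {1, 2, 5, 6}, so its max is 6; the padded
  // bottom-row windows reduce to {13, 14} and {15, 16}.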
10294   static const float ref[] = {6, 8, 14, 16};
10295   for (size_t i = 0; i < 4; i++)
10296     EXPECT_EQ(result.getHandle().raw(i), ref[i]);
10297 }
10298 
10299 TEST_P(OperatorTest, SigmoidOverflow) {
10300   CHECK_IF_ENABLED();
10301 
10302   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {2}, "input", false);
10303   auto IH = bindings_.allocate(input)->getHandle();
10304   IH.raw(0) = 1000;
10305   IH.raw(1) = -1000;
10306 
10307   auto *fpSigmoid = F_->createSigmoid("fpSigmoid", input);
10308   auto *S = F_->createSave("fpSave", fpSigmoid);
10309   bindings_.allocate(S->getPlaceholder());
10310   EE_.compile(CompilationMode::Infer);
10311   EE_.run(bindings_);
10312   Tensor &result = *bindings_.get(S->getPlaceholder());
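  // Sigmoid saturates: +1000 must map to exactly 1 and -1000 to exactly 0,
  // without the overflow in exp() producing NaN or Inf.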
10313   static const float ref[] = {1, 0};
10314   for (size_t i = 0; i < 2; i++) {
10315     EXPECT_EQ(result.getHandle().raw(i), ref[i]);
10316   }
10317 }
10318 
10319 /// This unit test exposes a problem with the CPU Sigmoid implementation when
10320 /// stacking a large number of operations: extreme input values produce NaNs.
10321 TEST_P(OperatorTest, SigmoidOverflowCPUStacking) {
10322   CHECK_IF_ENABLED();
10323   dim_t size = 20;
10324   auto *input =
10325       mod_.createPlaceholder(ElemKind::FloatTy, {size}, "input", false);
10326   auto IH = bindings_.allocate(input)->getHandle();
10327   IH = {
10328       -1588.409912109375,  -460.55999755859375, -1176.9149169921875,
10329       -1655.9249267578125, -1580.1217041015625, -1680.279541015625,
10330       -1750.2677001953125, -1762.1697998046875, -1616.599365234375,
10331       -1725.301025390625,  +1588.409912109375,  +460.55999755859375,
10332       +1176.9149169921875, +1655.9249267578125, +1580.1217041015625,
10333       +1680.279541015625,  +1750.2677001953125, +1762.1697998046875,
10334       +1616.599365234375,  +1725.301025390625,
10335   };
10336   auto *fpSigmoid = F_->createSigmoid("fpSigmoid", input);
10337   auto *S = F_->createSave("fpSave", fpSigmoid);
10338   bindings_.allocate(S->getPlaceholder());
10339   EE_.compile(CompilationMode::Infer);
10340   EE_.run(bindings_);
10341   Tensor &result = *bindings_.get(S->getPlaceholder());
10342   for (size_t i = 0; i < size; i++) {
10343     float ref = IH.raw(i) > 0 ? 1 : 0;
10344     EXPECT_NEAR(result.getHandle().raw(i), ref, 1E-6);
10345   }
10346 }
10347 
10348 /// This unit test exposes a problem with the CPU Tanh implementation when
10349 /// stacking a large number of operations: extreme input values produce NaNs.
10350 TEST_P(OperatorTest, TanhOverflowCPUStacking) {
10351   CHECK_IF_ENABLED();
10352   dim_t size = 20;
10353   auto *input =
10354       mod_.createPlaceholder(ElemKind::FloatTy, {size}, "input", false);
10355   auto IH = bindings_.allocate(input)->getHandle();
10356   IH = {
10357       -1588.409912109375,  -460.55999755859375, -1176.9149169921875,
10358       -1655.9249267578125, -1580.1217041015625, -1680.279541015625,
10359       -1750.2677001953125, -1762.1697998046875, -1616.599365234375,
10360       -1725.301025390625,  +1588.409912109375,  +460.55999755859375,
10361       +1176.9149169921875, +1655.9249267578125, +1580.1217041015625,
10362       +1680.279541015625,  +1750.2677001953125, +1762.1697998046875,
10363       +1616.599365234375,  +1725.301025390625,
10364   };
10365   auto *fpTanh = F_->createTanh("fpTanh", input);
10366   auto *S = F_->createSave("fpSave", fpTanh);
10367   bindings_.allocate(S->getPlaceholder());
10368   EE_.compile(CompilationMode::Infer);
10369   EE_.run(bindings_);
10370   Tensor &result = *bindings_.get(S->getPlaceholder());
10371   for (size_t i = 0; i < size; i++) {
10372     float ref = IH.raw(i) > 0 ? 1 : -1;
10373     EXPECT_NEAR(result.getHandle().raw(i), ref, 1E-6);
10374   }
10375 }
10376 
10377 TEST_P(OperatorStatelessTest, Int8Sigmoid) {
10378   CHECK_IF_ENABLED();
10379   compareAgainstInterpreter(getBackendName(), createAndInitBasicSigmoidTest,
10380                             ElemKind::FloatTy, ElemKind::Int8QTy, 0.005f,
10381                             parCloneCountOpt);
10382 }
10383 
10384 /// Check that the batch add operator works properly.
10385 TEST_P(OperatorTest, BatchAdd) {
10386   CHECK_IF_ENABLED();
10387 
10388   PseudoRNG PRNG;
10389 
10390   auto *input =
10391       mod_.createPlaceholder(ElemKind::FloatTy, {13, 3, 3}, "A", false);
10392   bindings_.allocate(input)->getHandle<float>().randomize(-3.0, 3.0, PRNG);
10393   auto *slice =
10394       mod_.createPlaceholder(ElemKind::FloatTy, {3, 3}, "slice", false);
10395   bindings_.allocate(slice)->getHandle<float>().randomize(-3.0, 3.0, PRNG);
10396   auto *batchAdd = F_->createBatchedAdd("batchAdd", input, slice);
10397   auto *S = F_->createSave("save", batchAdd);
10398   bindings_.allocate(S->getPlaceholder());
10399 
10400   EE_.compile(CompilationMode::Infer);
10401   EE_.run(bindings_);
10402 
10403   auto result = bindings_.get(S->getPlaceholder())->getHandle<float>();
10404   auto handleInput = bindings_.get(input)->getHandle<float>();
10405   auto handleSlice = bindings_.get(slice)->getHandle<float>();
10406   ASSERT_EQ(result.size(), handleInput.size());
10407   for (size_t idx = 0, end = result.size(); idx != end; ++idx) {
10408     EXPECT_EQ(result.raw(idx),
10409               handleInput.raw(idx) + handleSlice.raw(idx % handleSlice.size()));
10410   }
10411 }
10412 
10413 /// Check that the batch add operator works properly for FP16.
10414 TEST_P(OperatorTest, FP16BatchAdd) {
10415   CHECK_IF_ENABLED();
10416 
10417   PseudoRNG PRNG;
10418 
10419   auto *input =
10420       mod_.createPlaceholder(ElemKind::Float16Ty, {13, 3, 3}, "A", false);
10421   bindings_.allocate(input)->getHandle<float16_t>().randomize(-3.0, 3.0, PRNG);
10422   auto *slice =
10423       mod_.createPlaceholder(ElemKind::Float16Ty, {3, 3}, "slice", false);
10424   bindings_.allocate(slice)->getHandle<float16_t>().randomize(-3.0, 3.0, PRNG);
10425   auto *batchAdd = F_->createBatchedAdd("batchAdd", input, slice);
10426   auto *S = F_->createSave("save", batchAdd);
10427   bindings_.allocate(S->getPlaceholder());
10428 
10429   EE_.compile(CompilationMode::Infer);
10430   EE_.run(bindings_);
10431 
10432   auto result = bindings_.get(S->getPlaceholder())->getHandle<float16_t>();
10433   auto handleInput = bindings_.get(input)->getHandle<float16_t>();
10434   auto handleSlice = bindings_.get(slice)->getHandle<float16_t>();
10435   ASSERT_EQ(result.size(), handleInput.size());
10436   for (size_t idx = 0, end = result.size(); idx != end; ++idx) {
10437     EXPECT_EQ(result.raw(idx),
10438               handleInput.raw(idx) + handleSlice.raw(idx % handleSlice.size()));
10439   }
10440 }
10441 
10442 /// Check that the batch add operator works properly for BFloat16.
10443 TEST_P(OperatorTest, BFloat16BatchAdd) {
10444   CHECK_IF_ENABLED();
10445 
10446   PseudoRNG PRNG;
10447 
10448   auto *input =
10449       mod_.createPlaceholder(ElemKind::BFloat16Ty, {13, 3, 3}, "A", false);
10450   bindings_.allocate(input)->getHandle<bfloat16_t>().randomize(-3.0, 3.0, PRNG);
10451   auto *slice =
10452       mod_.createPlaceholder(ElemKind::BFloat16Ty, {3, 3}, "slice", false);
10453   bindings_.allocate(slice)->getHandle<bfloat16_t>().randomize(-3.0, 3.0, PRNG);
10454   auto *batchAdd = F_->createBatchedAdd("batchAdd", input, slice);
10455   auto *S = F_->createSave("save", batchAdd);
10456   bindings_.allocate(S->getPlaceholder());
10457 
10458   EE_.compile(CompilationMode::Infer);
10459   EE_.run(bindings_);
10460 
10461   auto result = bindings_.get(S->getPlaceholder())->getHandle<bfloat16_t>();
10462   auto handleInput = bindings_.get(input)->getHandle<bfloat16_t>();
10463   auto handleSlice = bindings_.get(slice)->getHandle<bfloat16_t>();
10464   ASSERT_EQ(result.size(), handleInput.size());
10465   for (size_t idx = 0, end = result.size(); idx != end; ++idx) {
10466     EXPECT_EQ(result.raw(idx),
10467               handleInput.raw(idx) + handleSlice.raw(idx % handleSlice.size()));
10468   }
10469 }
10470 
10471 TEST_P(OperatorTest, BroadcastAdd2x) {
10472   CHECK_IF_ENABLED();
10473 
10474   auto *input =
10475       mod_.createPlaceholder(ElemKind::FloatTy, {10, 1}, "input", false);
10476   auto *bias = mod_.createConstant(ElemKind::FloatTy, {1, 1}, "bias");
10477   bias->getPayloadMutable().getHandle() = {42};
10478   auto *tile = F_->createTile("tile", bias, 10, 0);
10479   auto *add = F_->createAdd("add", input, tile);
10480   auto *save = F_->createSave("save", add);
10481   auto *output = save->getPlaceholder();
10482   bindings_.allocate(input)->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
10483   bindings_.allocate(output);
10484   EE_.compile(CompilationMode::Infer);
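  // Compile once and run twice to check that repeated executions of the same
  // compiled function produce identical results.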
10485   for (int i = 0; i < 2; i++) {
10486     Tensor expected(ElemKind::FloatTy, {10, 1});
10487     expected.getHandle() = {42, 43, 44, 45, 46, 47, 48, 49, 50, 51};
10488     EE_.run(bindings_);
10489     EXPECT_TRUE(bindings_.get(output)->isEqual(expected));
10490   }
10491 }
10492 
10493 /// Helper to test Sigmoid using \p DTy.
10494 template <typename DataType>
10495 static void testSigmoid(glow::PlaceholderBindings &bindings, glow::Module &mod,
10496                         glow::Function *F, glow::ExecutionEngine &EE,
10497                         ElemKind DTy, float allowedError = 0.001f) {
10498   constexpr dim_t size = 10;
10499   auto *input = mod.createPlaceholder(DTy, {size}, "input", false);
10500   bindings.allocate(input)->getHandle<DataType>().randomize(-10.0, 10.0,
10501                                                             mod.getPRNG());
10502 
10503   auto *sigmoid = F->createSigmoid("sigmoid", input);
10504   auto *save = F->createSave("Save", sigmoid);
10505   bindings.allocate(save->getPlaceholder());
10506 
10507   EE.compile(CompilationMode::Infer);
10508   EE.run(bindings);
10509 
10510   auto RH = bindings.get(save->getPlaceholder())->getHandle<DataType>();
10511   auto inH = bindings.get(input)->getHandle<DataType>();
10512 
10513   for (dim_t i = 0; i < size; i++) {
10514     float val = 1 / (1 + std::exp(-(float)inH.at({i})));
10515     EXPECT_NEAR(RH.at({i}), val, allowedError);
10516   }
10517 }
10518 
10519 /// Verify that the Sigmoid operator works correctly with FloatTy.
10520 TEST_P(OperatorTest, Sigmoid_Float) {
10521   CHECK_IF_ENABLED();
10522   testSigmoid<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
10523 }
10524 
10525 /// Verify that the Sigmoid operator works correctly with Float16Ty.
10526 TEST_P(OperatorTest, Sigmoid_Float16) {
10527   CHECK_IF_ENABLED();
10528   testSigmoid<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
10529 }
10530 
10531 /// Verify that the Sigmoid operator works correctly with BFloat16Ty.
10532 TEST_P(OperatorTest, Sigmoid_BFloat16) {
10533   CHECK_IF_ENABLED();
10534   testSigmoid<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
10535                           0.01f);
10536 }
10537 
10538 /// Helper to test Swish using \p DTy.
10539 template <typename DataType>
10540 static void testSwish(glow::PlaceholderBindings &bindings, glow::Module &mod,
10541                       glow::Function *F, glow::ExecutionEngine &EE,
10542                       ElemKind DTy, float allowedError = 0.002f) {
10543   constexpr dim_t size = 10;
10544   auto *input = mod.createPlaceholder(DTy, {size}, "input", false);
10545   bindings.allocate(input)->getHandle<DataType>().randomize(-5.0, 5.0,
10546                                                             mod.getPRNG());
10547 
10548   auto *swish = F->createSwish("swish", input);
10549   auto *save = F->createSave("Save", swish);
10550   bindings.allocate(save->getPlaceholder());
10551 
10552   EE.compile(CompilationMode::Infer);
10553   EE.run(bindings);
10554 
10555   auto RH = bindings.get(save->getPlaceholder())->getHandle<DataType>();
10556   auto inH = bindings.get(input)->getHandle<DataType>();
10557 
10558   for (dim_t i = 0; i < size; i++) {
10559     float x = (float)inH.at({i});
10560     float val = x / (1 + std::exp(-x));
10561     EXPECT_NEAR(RH.at({i}), val, allowedError);
10562   }
10563 }
10564 
10565 /// Verify that the Swish operator works correctly with FloatTy.
10566 TEST_P(OperatorTest, Swish_Float) {
10567   CHECK_IF_ENABLED();
10568   testSwish<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
10569 }
10570 
10571 /// Verify that the Swish operator works correctly with Float16Ty.
10572 TEST_P(OperatorTest, Swish_Float16) {
10573   CHECK_IF_ENABLED();
10574   testSwish<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
10575 }
10576 
10577 /// Verify that the Swish operator works correctly with BFloat16Ty.
10578 TEST_P(OperatorTest, Swish_BFloat16) {
10579   CHECK_IF_ENABLED();
10580   testSwish<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty, 0.2f);
10581 }
10582 
10583 TEST_P(OperatorTest, IntLookupTable) {
10584   CHECK_IF_ENABLED();
10585 
10586   constexpr dim_t size = 6;
10587   auto *input =
10588       mod_.createPlaceholder(ElemKind::Int8QTy, {size}, 1, 0, "input", false);
10589   bindings_.allocate(input)->getHandle<int8_t>() = {0, 1, 2, 3, 4, 5};
10590 
10591   auto outTy = mod_.uniqueType(ElemKind::Int8QTy, {size}, 3, 3);
10592 
10593   // Mapping i -> i.
10594   std::vector<int8_t> initValues(256);
10595   for (size_t i = 0; i < 256; ++i) {
10596     initValues[i] = i - 128;
10597   }
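  // The 256 entries cover the int8 range in order, so entry i holds i - 128
  // and the table maps every quantized value to itself; the saved raw values
  // should therefore match the raw inputs 0..5.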
10598 
10599   auto *lookupTable =
10600       F_->createIntLookupTable("lookupTable", input, initValues, outTy);
10601   auto *save = F_->createSave("save", lookupTable);
10602   bindings_.allocate(save->getPlaceholder());
10603 
10604   EE_.compile(CompilationMode::Infer);
10605   EE_.run(bindings_);
10606 
10607   auto result = bindings_.get(save->getPlaceholder())->getHandle<int8_t>();
10608   for (size_t i = 0; i < size; ++i) {
10609     EXPECT_EQ(result.raw(i), i);
10610   }
10611 }
10612 
10613 /// Helper to test BatchAdd using \p DTy.
10614 template <typename DataType>
10615 static void testBatchAdd(glow::PlaceholderBindings &bindings, glow::Module &mod,
10616                          glow::Function *F, glow::ExecutionEngine &EE,
10617                          ElemKind DTy) {
10618   unsigned numSlices = 10;
10619   auto *input = mod.createPlaceholder(DTy, {numSlices, 10, 10}, "input", false);
10620   auto *slice = mod.createPlaceholder(DTy, {10, 10}, "slice", false);
10621 
10622   bindings.allocate(input)->getHandle<DataType>().randomize(-10.0, 10.0,
10623                                                             mod.getPRNG());
10624   bindings.allocate(slice)->getHandle<DataType>().randomize(-10.0, 10.0,
10625                                                             mod.getPRNG());
10626 
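  // Split the batch into single-slice extracts, add the same 2D slice to each,
  // and concatenate the results back along the batch dimension.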
10627   std::vector<NodeValue> adds;
10628   for (dim_t i = 0; i < numSlices; i++) {
10629     auto *ex = F->createSlice("slice", input, {i, 0, 0}, {i + 1, 10, 10});
10630     auto *ba = F->createBatchedAdd("add", ex, slice);
10631     adds.push_back(ba);
10632   }
10633 
10634   auto *cc = F->createConcat("concat", adds, 0);
10635 
10636   // Remove the reference to the graph nodes to allow DCE to remove them.
10637   adds.clear();
10638 
10639   auto *result = F->createSave("save", cc);
10640   bindings.allocate(result->getPlaceholder());
10641 
10642   EE.compile(CompilationMode::Infer);
10643   EE.run(bindings);
10644 
10645   auto RH = bindings.get(result->getPlaceholder())->getHandle<DataType>();
10646   auto IH = bindings.get(input)->getHandle<DataType>();
10647   auto SH = bindings.get(slice)->getHandle<DataType>();
10648 
10649   // Check that batched add works as expected.
10650   for (dim_t i = 0; i < numSlices; i++) {
10651     for (dim_t j = 0; j < 10; j++) {
10652       for (dim_t k = 0; k < 10; k++) {
10653         EXPECT_NEAR(IH.at({i, j, k}) + SH.at({j, k}), RH.at({i, j, k}),
10654                     0.00001);
10655       }
10656     }
10657   }
10658 }
10659 
10660 /// Check that the sequence of extract-batchedadd-concat works.
10661 TEST_P(OperatorTest, testBatchAdd_Float) {
10662   CHECK_IF_ENABLED();
10663   testBatchAdd<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
10664 }
10665 
10666 /// Check that the sequence of extract-batchedadd-concat works.
10667 TEST_P(OperatorTest, testBatchAdd_Float16) {
10668   CHECK_IF_ENABLED();
10669   testBatchAdd<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
10670 }
10671 
10672 /// Check that the sequence of extract-batchedadd-concat works.
10673 TEST_P(OperatorTest, testBatchAdd_BFloat16) {
10674   CHECK_IF_ENABLED();
10675   testBatchAdd<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty);
10676 }
10677 
10678 static void quantizedBatchAdd(ExecutionEngine &EE, Function *F,
10679                               PlaceholderBindings &bindings, ElemKind Ty) {
10680   auto &mod = EE.getModule();
10681   unsigned numSlices = 10;
10682   auto *input = mod.createPlaceholder(ElemKind::FloatTy, {numSlices, 10, 10},
10683                                       "input", false);
10684   auto *slice =
10685       mod.createPlaceholder(ElemKind::FloatTy, {10, 10}, "slice", false);
10686 
10687   bindings.allocate(input)->getHandle().randomize(-5.0, 5.0, mod.getPRNG());
10688   bindings.allocate(slice)->getHandle().randomize(-5.0, 5.0, mod.getPRNG());
10689 
10690   // Scale the numbers in the range (-5. .. 5.) to (-50 .. 50).
10691   auto qInType = mod.uniqueType(ElemKind::Int8QTy, {numSlices, 10, 10}, .1, 0);
10692   auto qSliceType2 = mod.uniqueType(Ty, {10, 10}, .1, 0);
10693   auto qSliceType3 = mod.uniqueType(ElemKind::Int8QTy, {1, 10, 10}, .1, 0);
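  // The batch and the extracted slices stay Int8; the slice being added uses
  // the element kind under test (Int8 or Int32).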
10694 
10695   auto *intInput = F->createQuantize("qinput", input, qInType);
10696   auto *intSlice = F->createQuantize("qslice", slice, qSliceType2);
10697 
10698   std::vector<NodeValue> adds;
10699   for (dim_t i = 0; i < numSlices; i++) {
10700     auto *ex = F->createSlice("slice", intInput, {i, 0, 0}, qSliceType3);
10701     auto *ba = F->createBatchedAdd("add", ex, intSlice);
10702     adds.push_back(ba);
10703   }
10704 
10705   Node *cc = F->createConcat("concat", adds, 0, qInType);
10706   cc = F->createDequantize("dq", cc, ElemKind::FloatTy);
10707   auto *result = F->createSave("save", cc);
10708   bindings.allocate(result->getPlaceholder());
10709 
10710   // Remove the reference to the graph nodes to allow DCE to remove them.
10711   adds.clear();
10712 
10713   EE.compile(CompilationMode::Infer);
10714   EE.run(bindings);
10715 
10716   auto RH = bindings.get(result->getPlaceholder())->getHandle();
10717   auto IH = bindings.get(input)->getHandle();
10718   auto SH = bindings.get(slice)->getHandle();
10719 
10720   // Check that batched add works as expected.
10721   for (dim_t i = 0; i < numSlices; i++) {
10722     for (dim_t j = 0; j < 10; j++) {
10723       for (dim_t k = 0; k < 10; k++) {
10724         EXPECT_NEAR(IH.at({i, j, k}) + SH.at({j, k}), RH.at({i, j, k}), 0.1);
10725       }
10726     }
10727   }
10728 }
10729 
10730 /// Tests quantized batched-add arithmetic on Int8QTy.
10731 TEST_P(OperatorTest, testQuantizedBatchAdd_Int8) {
10732   CHECK_IF_ENABLED();
10733 
10734   quantizedBatchAdd(EE_, F_, bindings_, ElemKind::Int8QTy);
10735 }
10736 
10737 /// Tests quantized batched-add arithmetic on Int32QTy.
10738 TEST_P(OperatorTest, testQuantizedBatchAdd_Int32) {
10739   CHECK_IF_ENABLED();
10740 
10741   quantizedBatchAdd(EE_, F_, bindings_, ElemKind::Int32QTy);
10742 }
10743 
10744 template <typename DataType>
10745 static Tensor *testCumSum(glow::PlaceholderBindings &bindings,
10746                           glow::Module &mod, glow::Function *F,
10747                           glow::ExecutionEngine &EE, ElemKind DTy,
10748                           bool exclusive, bool reverse) {
10749   auto *data = mod.createPlaceholder(DTy, {4}, "data", false);
10750   bindings.allocate(data)->getHandle<DataType>() = {1, 2, 3, 4};
10751 
10752   auto *CS = F->createCumSum("CumSum", data, exclusive, reverse);
10753   auto *S = F->createSave("save", CS);
10754   bindings.allocate(S->getPlaceholder());
10755 
10756   EE.compile(CompilationMode::Infer);
10757   EE.run(bindings);
10758   return bindings.get(S->getPlaceholder());
10759 }
10760 
10761 TEST_P(OperatorTest, CumSum_Float) {
10762   CHECK_IF_ENABLED();
10763   /*
10764     DATA  = [1, 2, 3, 4]
10765     OUTPUT = [1, 3, 6, 10]
10766   */
10767 
10768   Tensor *result =
10769       testCumSum<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
10770                         /*exclusive*/ false, /*reverse*/ false);
10771   Tensor expected(result->getType());
10772   expected.getHandle<float>() = {1, 3, 6, 10};
10773 
10774   EXPECT_TRUE(expected.isEqual(*result));
10775 }
10776 
10777 TEST_P(OperatorTest, CumSum_Float16) {
10778   CHECK_IF_ENABLED();
10779   /*
10780     DATA  = [1, 2, 3, 4]
10781     OUTPUT = [1, 3, 6, 10]
10782   */
10783 
10784   Tensor *result =
10785       testCumSum<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
10786                             /*exclusive*/ false, /*reverse*/ false);
10787   Tensor expected(result->getType());
10788   expected.getHandle<float16_t>() = {1, 3, 6, 10};
10789 
10790   EXPECT_TRUE(expected.isEqual(*result));
10791 }
10792 
10793 TEST_P(OperatorTest, CumSum_BFloat16) {
10794   CHECK_IF_ENABLED();
10795   /*
10796     DATA  = [1, 2, 3, 4]
10797     OUTPUT = [1, 3, 6, 10]
10798   */
10799 
10800   Tensor *result =
10801       testCumSum<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
10802                              /*exclusive*/ false, /*reverse*/ false);
10803   Tensor expected(result->getType());
10804   expected.getHandle<bfloat16_t>() = {1, 3, 6, 10};
10805 
10806   EXPECT_TRUE(expected.isEqual(*result));
10807 }
10808 
10809 TEST_P(OperatorTest, CumSum_Int32) {
10810   CHECK_IF_ENABLED();
10811   /*
10812     DATA  = [1, 2, 3, 4]
10813     OUTPUT = [1, 3, 6, 10]
10814   */
10815 
10816   Tensor *result =
10817       testCumSum<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy,
10818                           /*exclusive*/ false, /*reverse*/ false);
10819   Tensor expected(result->getType());
10820   expected.getHandle<int32_t>() = {1, 3, 6, 10};
10821 
10822   EXPECT_TRUE(expected.isEqual(*result));
10823 }
10824 
10825 TEST_P(OperatorTest, CumSum_Int64) {
10826   CHECK_IF_ENABLED();
10827   /*
10828     DATA  = [1, 2, 3, 4]
10829     OUTPUT = [1, 3, 6, 10]
10830   */
10831 
10832   Tensor *result =
10833       testCumSum<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy,
10834                           /*exclusive*/ false, /*reverse*/ false);
10835   Tensor expected(result->getType());
10836   expected.getHandle<int64_t>() = {1, 3, 6, 10};
10837 
10838   EXPECT_TRUE(expected.isEqual(*result));
10839 }
10840 
10841 TEST_P(OperatorTest, CumSum_Exclusive) {
10842   CHECK_IF_ENABLED();
10843   /*
10844     DATA  = [1, 2, 3, 4]
10845     OUTPUT = [0, 1, 3, 6]
10846   */
10847 
10848   Tensor *result =
10849       testCumSum<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
10850                         /*exclusive*/ true, /*reverse*/ false);
10851   Tensor expected(result->getType());
10852   expected.getHandle<float>() = {0, 1, 3, 6};
10853 
10854   EXPECT_TRUE(expected.isEqual(*result));
10855 }
10856 
10857 TEST_P(OperatorTest, CumSum_Reverse) {
10858   CHECK_IF_ENABLED();
10859   /*
10860     DATA  = [1, 2, 3, 4]
10861     OUTPUT = [10, 9, 7, 4]
10862   */
10863 
10864   Tensor *result =
10865       testCumSum<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
10866                             /*exclusive*/ false, /*reverse*/ true);
10867   Tensor expected(result->getType());
10868   expected.getHandle<float16_t>() = {10, 9, 7, 4};
10869 
10870   EXPECT_TRUE(expected.isEqual(*result));
10871 }
10872 
10873 TEST_P(OperatorTest, CumSum_Reverse_BFloat16) {
10874   CHECK_IF_ENABLED();
10875   /*
10876     DATA  = [1, 2, 3, 4]
10877     OUTPUT = [10, 9, 7, 4]
10878   */
10879 
10880   Tensor *result =
10881       testCumSum<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
10882                              /*exclusive*/ false, /*reverse*/ true);
10883   Tensor expected(result->getType());
10884   expected.getHandle<bfloat16_t>() = {10, 9, 7, 4};
10885 
10886   EXPECT_TRUE(expected.isEqual(*result));
10887 }
10888 
10889 TEST_P(OperatorTest, CumSum_ExclusiveReverse) {
10890   CHECK_IF_ENABLED();
10891   /*
10892     DATA  = [1, 2, 3, 4]
10893     OUTPUT = [9, 7, 4, 0]
10894   */
10895 
10896   Tensor *result =
10897       testCumSum<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy,
10898                           /*exclusive*/ true, /*reverse*/ true);
10899   Tensor expected(result->getType());
10900   expected.getHandle<int32_t>() = {9, 7, 4, 0};
10901 
10902   EXPECT_TRUE(expected.isEqual(*result));
10903 }
10904 
10905 TEST_P(OperatorTest, CumSum_WithZeroes) {
10906   CHECK_IF_ENABLED();
10907   /*
10908     DATA  = [0, 0, 1, 0, 0, 2, 0, 0, 3]
10909     OUTPUT = [0, 0, 1, 1, 1, 3, 3, 3, 6]
10910   */
10911 
10912   auto *data = mod_.createPlaceholder(ElemKind::Int64ITy, {9}, "data", false);
10913   bindings_.allocate(data)->getHandle<int64_t>() = {0, 0, 1, 0, 0, 2, 0, 0, 3};
10914 
10915   auto *CS = F_->createCumSum("CumSum", data);
10916   auto *S = F_->createSave("save", CS);
10917   bindings_.allocate(S->getPlaceholder());
10918 
10919   EE_.compile(CompilationMode::Infer);
10920   EE_.run(bindings_);
10921   Tensor *result = bindings_.get(S->getPlaceholder());
10922   Tensor expected(result->getType());
10923   expected.getHandle<int64_t>() = {0, 0, 1, 1, 1, 3, 3, 3, 6};
10924 
10925   EXPECT_TRUE(expected.isEqual(*result));
10926 }
10927 
10928 TEST_P(OperatorTest, LengthsSum) {
10929   CHECK_IF_ENABLED();
10930 
10931   /*
10932     DATA  = [
10933         [1.0, 1.2],
10934         [2.3, 3.4],
10935         [4.5, 3.7],
10936         [3.0, 2.9],
10937         [1.1, 1.4],
10938         [2.8, 8.4],
10939     ]
10940     LENGTHS = [2, 0, 3, 1]
10941     OUTPUT = [
10942         [3.3, 4.6],
10943         [0.0, 0.0],
10944         [8.6, 8.0],
10945         [2.8, 8.4],
10946     ]
10947   */
10948   auto *data = mod_.createPlaceholder(ElemKind::FloatTy, {6, 2}, "data", false);
10949   auto *lengths =
10950       mod_.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths", false);
10951 
10952   bindings_.allocate(data)->getHandle() = {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 3.7f,
10953                                            3.0f, 2.9f, 1.1f, 1.4f, 2.8f, 8.4f};
10954   bindings_.allocate(lengths)->getHandle<int32_t>() = {2, 0, 3, 1};
10955 
10956   auto *R = F_->createLengthsSum("LS", data, lengths);
10957   auto *S = F_->createSave("save", R);
10958   bindings_.allocate(S->getPlaceholder());
10959 
10960   EE_.compile(CompilationMode::Infer);
10961   EE_.run(bindings_);
10962 
10963   Tensor &result = *bindings_.get(S->getPlaceholder());
10964   Tensor expected(ElemKind::FloatTy, {4, 2});
10965   expected.getHandle() = {3.3f, 4.6f, 0.0f, 0.0f, 8.6f, 8.0f, 2.8f, 8.4f};
10966 
10967   EXPECT_TRUE(expected.isEqual(result));
10968 }
10969 
10970 /// Helper to test SLS using \p DTy.
10971 template <typename DataType, typename IndexType>
10972 static void testSLS(glow::PlaceholderBindings &bindings, glow::Module &mod,
10973                     glow::Function *F, glow::ExecutionEngine &EE, ElemKind DTy,
10974                     ElemKind ITy, float allowedError) {
10975   /*
10976     DATA  = [
10977         [1.0, 1.2],
10978         [2.3, 3.4],
10979         [4.5, 5.7],
10980     ]
10981     INDICES = [2, 0, 1, 2, 0, 0, 0, 0]
10982     LENGTHS = [2, 0, 2, 1, 3]
10983     OUTPUT = [
10984         [5.5, 6.9],
10985         [0.0, 0.0],
10986         [6.8, 9.1],
10987         [1.0, 1.2],
10988         [3.0, 3.6],
10989     ]
10990   */
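  // Each output row i accumulates the DATA rows addressed by the next
  // LENGTHS[i] entries of INDICES; e.g. the first row is DATA[2] + DATA[0] =
  // [5.5, 6.9].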
10991   auto *data = mod.createPlaceholder(DTy, {3, 2}, "data", false);
10992   auto *indices = mod.createPlaceholder(ITy, {8}, "indices", false);
10993   auto *lengths =
10994       mod.createPlaceholder(ElemKind::Int32ITy, {5}, "lengths", false);
10995 
10996   bindings.allocate(data)->getHandle<DataType>() = {
10997       1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f,
10998   };
10999   bindings.allocate(indices)->getHandle<IndexType>() = {
11000       2, 0, 1, 2, 0, 0, 0, 0,
11001   };
11002   bindings.allocate(lengths)->getHandle<int32_t>() = {
11003       2, 0, 2, 1, 3,
11004   };
11005 
11006   auto *R = F->createSparseLengthsSum("SLS", data, indices, lengths);
11007 
11008   auto *S = F->createSave("save", R);
11009   bindings.allocate(S->getPlaceholder());
11010 
11011   EE.compile(CompilationMode::Infer);
11012   EE.run(bindings);
11013 
11014   Tensor &result = *bindings.get(S->getPlaceholder());
11015   Tensor expected(DTy, {5, 2});
11016   expected.getHandle<DataType>() = {
11017       5.5f, 6.9f, 0.0f, 0.0f, 6.8f, 9.1f, 1.0f, 1.2f, 3.0f, 3.6f,
11018   };
11019 
11020   EXPECT_TRUE(expected.isEqual(result, allowedError));
11021 }
11022 
11023 /// Test that SLS is correctly supported in FloatTy with int64 indices.
11024 TEST_P(OperatorTest, SparseLengthsSum_Float) {
11025   CHECK_IF_ENABLED();
11026   testSLS<float, int64_t>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
11027                           ElemKind::Int64ITy, 0.0001);
11028 }
11029 
11030 /// Test that SLS is correctly supported in FloatTy with int32 indices.
11031 TEST_P(OperatorTest, SparseLengthsSum_Float_Int32) {
11032   CHECK_IF_ENABLED();
11033   testSLS<float, int32_t>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
11034                           ElemKind::Int32ITy, 0.0001);
11035 }
11036 
11037 /// Test that SLS is correctly supported in Float16Ty with int64 indices.
11038 TEST_P(OperatorTest, SparseLengthsSum_Float16) {
11039   CHECK_IF_ENABLED();
11040   testSLS<float16_t, int64_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
11041                               ElemKind::Int64ITy, 0.002);
11042 }
11043 
11044 /// Test that SLS is correctly supported in BFloat16Ty with int64 indices.
11045 TEST_P(OperatorTest, SparseLengthsSum_BFloat16) {
11046   CHECK_IF_ENABLED();
11047   testSLS<bfloat16_t, int64_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
11048                                ElemKind::Int64ITy, 0.05);
11049 }
11050 
11051 /// Test that SLS is correctly supported in Float16Ty with int32 indices.
11052 TEST_P(OperatorTest, SparseLengthsSum_Float16_Int32) {
11053   CHECK_IF_ENABLED();
11054   testSLS<float16_t, int32_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
11055                               ElemKind::Int32ITy, 0.05);
11056 }
11057 
11058 /// Test that SLS is correctly supported in BFloat16Ty with int32 indices.
11059 TEST_P(OperatorTest, SparseLengthsSum_BFloat16_Int32) {
11060   CHECK_IF_ENABLED();
11061   testSLS<bfloat16_t, int32_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
11062                                ElemKind::Int32ITy, 0.05);
11063 }
11064 
11065 TEST_P(OperatorTest, SparseLengthsSumI8) {
11066   CHECK_IF_ENABLED();
11067 
11068   /*
11069     DATA  = [
11070         [11, 13],
11071         [24, 35],
11072         [46, 58],
11073     ]
11074     INDICES = [2, 0, 1, 2, 0, 0, 0, 0]
11075     LENGTHS = [2, 0, 2, 1, 3]
11076     OUTPUT = [
11077         [56, 70],
11078         [ 1,  1],
11079         [69, 92],
11080         [11, 13],
11081         [31, 37],
11082     ]
11083   */
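  // With scale 0.1 and offset 1, the empty segment produces the quantized
  // zero (raw value 1), and the first output element is
  // quantize(4.5 + 1.0) = 56.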
11084   auto *data =
11085       mod_.createPlaceholder(ElemKind::Int8QTy, {3, 2}, 0.1f, 1, "data", false);
11086   auto *indices =
11087       mod_.createPlaceholder(ElemKind::Int64ITy, {8}, "indices", false);
11088   auto *lengths =
11089       mod_.createPlaceholder(ElemKind::Int32ITy, {5}, "lengths", false);
11090 
11091   bindings_.allocate(data)->getHandle<int8_t>() = {
11092       11, 13, 24, 35, 46, 58,
11093   };
11094   bindings_.allocate(indices)->getHandle<int64_t>() = {
11095       2, 0, 1, 2, 0, 0, 0, 0,
11096   };
11097   bindings_.allocate(lengths)->getHandle<int32_t>() = {
11098       2, 0, 2, 1, 3,
11099   };
11100 
11101   auto *R = F_->createSparseLengthsSum("SLS", data, indices, lengths);
11102   auto *S = F_->createSave("save", R);
11103   bindings_.allocate(S->getPlaceholder());
11104 
11105   EE_.compile(CompilationMode::Infer);
11106   EE_.run(bindings_);
11107 
11108   Tensor &result = *bindings_.get(S->getPlaceholder());
11109   Tensor expected(ElemKind::Int8QTy, {5, 2}, 0.1f, 1);
11110   expected.getHandle<int8_t>() = {
11111       56, 70, 1, 1, 69, 92, 11, 13, 31, 37,
11112   };
11113   EXPECT_TRUE(expected.isEqual(result));
11114 }
11115 
11116 /// Test SparseLengthsWeightedSum with an N-dimension embedding table.
11117 template <typename DataType>
11118 static void testSLWS(glow::PlaceholderBindings &bindings, glow::Module &mod,
11119                      glow::Function *F, glow::ExecutionEngine &EE, ElemKind DTy,
11120                      float allowedError, size_t ndims) {
11121   /*
11122     DATA  =   [[2.0, -0.5, 13]]
11123     WEIGHTS = [3, 1, 0, 0, 0, 0, 2, -0.5]
11124     INDICES = [1, 0, 2, 0, 1, 2, 2, 0]
11125     LENGTHS = [3, 0, 3, 2]
11126     OUTPUT =  [0.5, 0, 0, 25]
11127   */
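  // E.g. the first output is 3 * DATA[1] + 1 * DATA[0] + 0 * DATA[2]
  // = 3 * (-0.5) + 2.0 = 0.5, and the zero-length segment produces 0.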
11128   ShapeVector idims(ndims, 1);
11129   ShapeVector odims(ndims, 1);
11130   idims[0] = 3;
11131   odims[0] = 4;
11132 
11133   auto *data = mod.createPlaceholder(DTy, idims, "data", false);
11134   auto *weights = mod.createPlaceholder(DTy, {8}, "weights", false);
11135   auto *indices =
11136       mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices", false);
11137   auto *lengths =
11138       mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths", false);
11139 
11140   bindings.allocate(data)->getHandle<DataType>() = {
11141       2.0,
11142       -0.5,
11143       13,
11144   };
11145   bindings.allocate(weights)->getHandle<DataType>() = {
11146       3, 1, 0, 0, 0, 0, 2, -0.5,
11147   };
11148   bindings.allocate(indices)->getHandle<int64_t>() = {
11149       1, 0, 2, 0, 1, 2, 2, 0,
11150   };
11151   bindings.allocate(lengths)->getHandle<int32_t>() = {
11152       3,
11153       0,
11154       3,
11155       2,
11156   };
11157 
11158   auto *R = F->createSparseLengthsWeightedSum("SLWS", data, weights, indices,
11159                                               lengths);
11160   auto *S = F->createSave("save", R);
11161   bindings.allocate(S->getPlaceholder());
11162 
11163   EE.compile(CompilationMode::Infer);
11164   EE.run(bindings);
11165 
11166   Tensor &result = *bindings.get(S->getPlaceholder());
11167   Tensor expected(DTy, odims);
11168   expected.getHandle<DataType>() = {
11169       0.5,
11170       0,
11171       0,
11172       25,
11173   };
11174 
11175   EXPECT_TRUE(expected.isEqual(result));
11176 }
11177 
11178 /// Test that SLWS is correctly supported in FloatTy in 1D.
11179 TEST_P(OperatorTest, SparseLengthsWeightedSum_1D_Float) {
11180   CHECK_IF_ENABLED();
11181   testSLWS<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, 0.0001,
11182                   /* ndims */ 1);
11183 }
11184 
11185 /// Test that SLWS is correctly supported in FloatTy in 2D.
11186 TEST_P(OperatorTest, SparseLengthsWeightedSum_2D_Float) {
11187   CHECK_IF_ENABLED();
11188   testSLWS<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, 0.0001,
11189                   /* ndims */ 2);
11190 }
11191 
11192 /// Test that SLWS is correctly supported in Float16Ty in 1D.
11193 TEST_P(OperatorTest, SparseLengthsWeightedSum_1D_Float16) {
11194   CHECK_IF_ENABLED();
11195   testSLWS<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty, 0.0001,
11196                       /* ndims */ 1);
11197 }
11198 
11199 /// Test that SLWS is correctly supported in BFloat16Ty in 1D.
11200 TEST_P(OperatorTest, SparseLengthsWeightedSum_1D_BFloat16) {
11201   CHECK_IF_ENABLED();
11202   testSLWS<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty, 0.0001,
11203                        /* ndims */ 1);
11204 }
11205 
11206 /// Test that SLWS is correctly supported in Float16Ty in 2D.
11207 TEST_P(OperatorTest, SparseLengthsWeightedSum_2D_Float16) {
11208   CHECK_IF_ENABLED();
11209   testSLWS<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty, 0.0001,
11210                       /* ndims */ 2);
11211 }
11212 
11213 /// Test that SLWS is correctly supported in BFloat16Ty in 2D.
11214 TEST_P(OperatorTest, SparseLengthsWeightedSum_2D_BFloat16) {
11215   CHECK_IF_ENABLED();
11216   testSLWS<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty, 0.0001,
11217                        /* ndims */ 2);
11218 }
11219 
11220 TEST_P(OperatorTest, SparseLengthsWeightedSumI8) {
11221   CHECK_IF_ENABLED();
11222 
11223   /*
11224     DATA  =   [4, -1, 26]
11225     WEIGHTS = [6, 2, 0, 0, 0, 0, 4, -1]
11226     INDICES = [1, 0, 2, 0, 1, 2, 2, 0]
11227     LENGTHS = [3, 0, 3, 2]
11228     OUTPUT =  [1, 0, 0, 50]
11229   */
11230   auto *data =
11231       mod_.createPlaceholder(ElemKind::Int8QTy, {3}, 0.5, 0, "data", false);
11232   auto *weights =
11233       mod_.createPlaceholder(ElemKind::Int8QTy, {8}, 0.5, 0, "weights", false);
11234   auto *indices =
11235       mod_.createPlaceholder(ElemKind::Int64ITy, {8}, "indices", false);
11236   auto *lengths =
11237       mod_.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths", false);
11238 
11239   bindings_.allocate(data)->getHandle<int8_t>() = {
11240       4,
11241       -1,
11242       26,
11243   };
11244   bindings_.allocate(weights)->getHandle<int8_t>() = {
11245       6, 2, 0, 0, 0, 0, 4, -1,
11246   };
11247   bindings_.allocate(indices)->getHandle<int64_t>() = {
11248       1, 0, 2, 0, 1, 2, 2, 0,
11249   };
11250   bindings_.allocate(lengths)->getHandle<int32_t>() = {
11251       3,
11252       0,
11253       3,
11254       2,
11255   };
11256 
11257   auto *R = F_->createSparseLengthsWeightedSum("SLWS", data, weights, indices,
11258                                                lengths);
11259   auto *S = F_->createSave("save", R);
11260   bindings_.allocate(S->getPlaceholder());
11261 
11262   EE_.compile(CompilationMode::Infer);
11263   EE_.run(bindings_);
11264 
11265   Tensor &result = *bindings_.get(S->getPlaceholder());
11266   Tensor expected(ElemKind::Int8QTy, {4}, 0.5, 0);
11267   expected.getHandle<int8_t>() = {
11268       1,
11269       0,
11270       0,
11271       50,
11272   };
11273 
11274   EXPECT_TRUE(expected.isEqual(result));
11275 }
11276 
11277 /// Helper function to construct indices/offsets pair for EmbeddingBag
11278 /// and EmbeddingBagByteRowwiseOffsets
11279 template <typename DataType>
11280 static void addEmbeddingBagPartialInputs(
11281     glow::PlaceholderBindings &bindings, glow::Module &mod, ElemKind DTy,
11282     Placeholder *&weights, Placeholder *&indices, Placeholder *&offsets,
11283     bool hasEndOffset, bool partialInput = false) {
11284 
11285   if (hasEndOffset) {
11286     Tensor weightsTensorReal(DTy, {10});
11287     Tensor indicesTensorReal(ElemKind::Int64ITy, {10});
11288     Tensor offsetsTensorReal(ElemKind::Int64ITy, {5});
11289 
11290     weightsTensorReal.getHandle<DataType>() = {
11291         3, 1, 0, 0, 0, 0, 2, -0.5, 42.0, 42.0,
11292     };
11293     indicesTensorReal.getHandle<int64_t>() = {
11294         1, 0, 2, 0, 1, 2, 2, 0, 13, 10,
11295     };
11296     offsetsTensorReal.getHandle<int64_t>() = {
11297         0, 3, 3, 6,
11298         8, // extra end offset
11299     };
11300 
11301     if (partialInput) {
11302       weights = mod.createPlaceholder(DTy, {20}, "weights", false);
11303       indices =
11304           mod.createPlaceholder(ElemKind::Int64ITy, {20}, "indices", false);
11305       offsets =
11306           mod.createPlaceholder(ElemKind::Int64ITy, {6}, "offsets", false);
11307 
11308       // Using a partial weights tensor would cause problems when it is added
11309       // as a Constant, so we pad it with zeros to the full placeholder size.
11310       Tensor weightsTensorPadded(weights->getType());
11311       memcpy(weightsTensorPadded.getUnsafePtr(),
11312              weightsTensorReal.getUnsafePtr(),
11313              weightsTensorReal.getSizeInBytes());
11314       memset(weightsTensorPadded.getUnsafePtr() +
11315                  weightsTensorReal.getSizeInBytes(),
11316              0,
11317              weightsTensorPadded.getSizeInBytes() -
11318                  weightsTensorReal.getSizeInBytes());
11319 
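      // Indices and offsets are bound as partial tensors: the placeholders
      // declare the larger shapes, but the backing data only covers the valid
      // prefix.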
11320       Tensor indicesTensorPartial(indicesTensorReal.getUnsafePtr(),
11321                                   indices->getType(),
11322                                   indicesTensorReal.getSizeInBytes());
11323       Tensor offsetsTensorPartial(offsetsTensorReal.getUnsafePtr(),
11324                                   offsets->getType(),
11325                                   offsetsTensorReal.getSizeInBytes());
11326       bindings.insert(weights, std::move(weightsTensorPadded));
11327       bindings.insert(indices, indicesTensorPartial.clone());
11328       bindings.insert(offsets, offsetsTensorPartial.clone());
11329     } else {
11330       weights = mod.createPlaceholder(DTy, {10}, "weights", false);
11331       indices =
11332           mod.createPlaceholder(ElemKind::Int64ITy, {10}, "indices", false);
11333       offsets =
11334           mod.createPlaceholder(ElemKind::Int64ITy, {5}, "offsets", false);
11335 
11336       bindings.insert(weights, std::move(weightsTensorReal));
11337       bindings.insert(indices, std::move(indicesTensorReal));
11338       bindings.insert(offsets, std::move(offsetsTensorReal));
11339     }
11340   } else {
11341     // We assume no partial inputs will be used if hasEndOffset is false
11342     Tensor weightsTensorReal(DTy, {8});
11343     Tensor indicesTensorReal(ElemKind::Int64ITy, {8});
11344     Tensor offsetsTensorReal(ElemKind::Int64ITy, {4});
11345 
11346     weightsTensorReal.getHandle<DataType>() = {
11347         3, 1, 0, 0, 0, 0, 2, -0.5,
11348     };
11349     indicesTensorReal.getHandle<int64_t>() = {
11350         1, 0, 2, 0, 1, 2, 2, 0,
11351     };
11352     offsetsTensorReal.getHandle<int64_t>() = {0, 3, 3, 6};
11353 
11354     weights = mod.createPlaceholder(DTy, {8}, "weights", false);
11355     indices = mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices", false);
11356     offsets = mod.createPlaceholder(ElemKind::Int64ITy, {4}, "offsets", false);
11357 
11358     bindings.insert(weights, std::move(weightsTensorReal));
11359     bindings.insert(indices, std::move(indicesTensorReal));
11360     bindings.insert(offsets, std::move(offsetsTensorReal));
11361   }
11362 }
11363 
11364 /// Test EmbeddingBag with an N-dimension embedding table.
11365 template <typename DataType>
11366 static void testEmbeddingBag(glow::PlaceholderBindings &bindings,
11367                              glow::Module &mod, glow::Function *F,
11368                              glow::ExecutionEngine &EE, ElemKind DTy,
11369                              float allowedError, dim_t ndims, bool hasEndOffset,
11370                              bool partialInput = false) {
11371   /*
11372     DATA  =   [[2.0, -0.5, 13]]
11373     WEIGHTS = [3, 1, 0, 0, 0, 0, 2, -0.5]
11374     INDICES = [1, 0, 2, 0, 1, 2, 2, 0]
11375     OFFSETS = [0, 3, 3, 6]
11376     OUTPUT =  [0.5, 0, 0, 25]
11377   */
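  // OFFSETS give the start of each bag within INDICES/WEIGHTS. When
  // hasEndOffset is true, the final offset marks the end of the valid data, so
  // the junk entries appended by addEmbeddingBagPartialInputs are never read.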
11378   ShapeVector idims(ndims, 1);
11379   ShapeVector odims(ndims, 1);
11380   idims[0] = 3;
11381   odims[0] = partialInput ? 5 : 4;
11382 
11383   auto *data = mod.createPlaceholder(DTy, idims, "data", false);
11384 
11385   bindings.allocate(data)->getHandle<DataType>() = {
11386       2.0,
11387       -0.5,
11388       13,
11389   };
11390 
11391   // If hasEndOffset then add some additional junk to the end of indices and
11392   // weights and an extra offset to offsets.
11393   Placeholder *weights;
11394   Placeholder *indices;
11395   Placeholder *offsets;
11396 
11397   addEmbeddingBagPartialInputs<DataType>(bindings, mod, DTy, weights, indices,
11398                                          offsets, hasEndOffset, partialInput);
11399 
11400   auto *R = F->createEmbeddingBag("EB", data, weights, indices, offsets,
11401                                   hasEndOffset);
11402   auto *S = F->createSave("save", R);
11403   bindings.allocate(S->getPlaceholder());
11404 
11405   EE.compile(CompilationMode::Infer);
11406   EE.run(bindings);
11407 
11408   Tensor &result = *bindings.get(S->getPlaceholder());
11409   Tensor expected(DTy, odims);
11410   if (partialInput) {
11411     expected.getHandle<DataType>() = {
11412         0.5, 0, 0, 25, 0,
11413     };
11414   } else {
11415     expected.getHandle<DataType>() = {
11416         0.5,
11417         0,
11418         0,
11419         25,
11420     };
11421   }
11422 
11423   EXPECT_TRUE(expected.isEqual(result, allowedError));
11424 }
11425 
11426 /// Test that EB is correctly supported in FloatTy in 1D.
11427 TEST_P(OperatorTest, EmbeddingBag_1D_Float) {
11428   CHECK_IF_ENABLED();
11429   testEmbeddingBag<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, 0.0001,
11430                           /* ndims */ 1, /* hasEndOffset */ false);
11431 }
11432 
11433 /// Test that EB is correctly supported in FloatTy in 1D with an end offset.
11434 TEST_P(OperatorTest, EmbeddingBag_1D_Float_End_Offset) {
11435   CHECK_IF_ENABLED();
11436   testEmbeddingBag<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, 0.0001,
11437                           /* ndims */ 1, /* hasEndOffset */ true);
11438 }
11439 
11440 /// Test that EB is correctly supported in FloatTy in 2D.
11441 TEST_P(OperatorTest, EmbeddingBag_2D_Float) {
11442   CHECK_IF_ENABLED();
11443   testEmbeddingBag<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, 0.0001,
11444                           /* ndims */ 2, /* hasEndOffset */ false);
11445 }
11446 
11447 /// Test that EB is correctly supported in FloatTy in 2D with an end offset.
11448 TEST_P(OperatorTest, EmbeddingBag_2D_Float_End_Offset) {
11449   CHECK_IF_ENABLED();
11450   testEmbeddingBag<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, 0.0001,
11451                           /* ndims */ 2, /* hasEndOffset */ true);
11452 }
11453 
11454 /// Test that EB is correctly supported in Float16Ty in 1D.
11455 TEST_P(OperatorTest, EmbeddingBag_1D_Float16) {
11456   CHECK_IF_ENABLED();
11457   testEmbeddingBag<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
11458                               0.0001,
11459                               /* ndims */ 1, /* hasEndOffset */ false);
11460 }
11461 
11462 /// Test that EB is correctly supported in BFloat16Ty in 1D.
11463 TEST_P(OperatorTest, EmbeddingBag_1D_BFloat16) {
11464   CHECK_IF_ENABLED();
11465   testEmbeddingBag<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
11466                                0.0001,
11467                                /* ndims */ 1, /* hasEndOffset */ false);
11468 }
11469 
11470 /// Test that EB is correctly supported in Float16Ty in 1D with an end offset.
11471 TEST_P(OperatorTest, EmbeddingBag_1D_Float16_End_Offset) {
11472   CHECK_IF_ENABLED();
11473   testEmbeddingBag<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
11474                               0.0001,
11475                               /* ndims */ 1, /* hasEndOffset */ true);
11476 }
11477 
11478 /// Test that EB is correctly supported in BFloat16Ty in 1D with an end offset.
11479 TEST_P(OperatorTest, EmbeddingBag_1D_BFloat16_End_Offset) {
11480   CHECK_IF_ENABLED();
11481   testEmbeddingBag<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
11482                                0.0001,
11483                                /* ndims */ 1, /* hasEndOffset */ true);
11484 }
11485 
11486 /// Test that EB is correctly supported in Float16Ty in 2D.
11487 TEST_P(OperatorTest, EmbeddingBag_2D_Float16) {
11488   CHECK_IF_ENABLED();
11489   testEmbeddingBag<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
11490                               0.0001,
11491                               /* ndims */ 2, /* hasEndOffset */ false);
11492 }
11493 
11494 /// Test that EB is correctly supported in BFloat16Ty in 2D.
11495 TEST_P(OperatorTest, EmbeddingBag_2D_BFloat16) {
11496   CHECK_IF_ENABLED();
11497   testEmbeddingBag<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
11498                                0.0001,
11499                                /* ndims */ 2, /* hasEndOffset */ false);
11500 }
11501 
11502 /// Test that EB is correctly supported in Float16Ty in 2D with an end offset.
11503 TEST_P(OperatorTest, EmbeddingBag_2D_Float16_End_Offset) {
11504   CHECK_IF_ENABLED();
11505   testEmbeddingBag<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
11506                               0.0001,
11507                               /* ndims */ 2, /* hasEndOffset */ true);
11508 }
11509 
11510 /// Test that EB is correctly supported in BFloat16Ty in 2D with an end offset.
11511 TEST_P(OperatorTest, EmbeddingBag_2D_BFloat16_End_Offset) {
11512   CHECK_IF_ENABLED();
11513   testEmbeddingBag<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
11514                                0.0001,
11515                                /* ndims */ 2, /* hasEndOffset */ true);
11516 }
11517 
11518 /// Test that EB is correctly supported in FloatTy in 1D with an end offset and
11519 /// partial inputs.
11520 TEST_P(OperatorTest, EmbeddingBag_1D_Float_End_Offset_Partial) {
11521   CHECK_IF_ENABLED();
11522   ASSERT_TRUE(EE_.getBackend(getBackendName()).supportsPartialTensors());
11523   testEmbeddingBag<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, 0.0001,
11524                           /* ndims */ 1, /* hasEndOffset */ true,
11525                           /* partialInput */ true);
11526 }
11527 
11528 /// Test that EB is correctly supported in FloatTy in 2D with an end offset
11529 /// and partial inputs.
11530 TEST_P(OperatorTest, EmbeddingBag_2D_Float_End_Offset_Partial) {
11531   CHECK_IF_ENABLED();
11532   ASSERT_TRUE(EE_.getBackend(getBackendName()).supportsPartialTensors());
11533   testEmbeddingBag<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, 0.0001,
11534                           /* ndims */ 2, /* hasEndOffset */ true,
11535                           /* partialInput */ true);
11536 }
11537 
11538 /// Helper to test EmbeddingBagByteRowwiseOffsets using \p DTy.
11539 template <typename DataType>
11540 static void testEmbeddingBagByteRowwiseOffsets(
11541     glow::PlaceholderBindings &bindings, glow::Module &mod, glow::Function *F,
11542     glow::ExecutionEngine &EE, ElemKind fusedDTy, float allowedError,
11543     bool useFP16Accumulation, bool hasEndOffset, bool partialInput = false) {
11544   /*
11545     DATA  =   [[2.0, -0.5, 13]]
11546     WEIGHTS = [3, 1, 0, 0, 0, 0, 2, -0.5]
11547     INDICES = [1, 0, 2, 0, 1, 2, 2, 0]
11548     OFFSETS = [0, 3, 3, 6]
11549     OUTPUT =  [0.5, 0, 0, 25]
11550   */
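  // For reference, the expected OUTPUT above is the weighted sum of the DATA
  // rows selected by INDICES within each OFFSETS range:
  //   bag 0 (indices 1, 0, 2; weights 3, 1, 0):  3*(-0.5) + 1*2.0 + 0*13 = 0.5
  //   bag 1 (empty range [3, 3)):                0
  //   bag 2 (indices 0, 1, 2; weights 0, 0, 0):  0
  //   bag 3 (indices 2, 0; weights 2, -0.5):     2*13 + (-0.5)*2.0 = 25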
11551   const bool fusedData = isFusedQuantizedElemKind(fusedDTy);
11552   const ElemKind DTy =
11553       fusedData ? getScaleOffsetElemKindFromFused(fusedDTy) : fusedDTy;
11554   Tensor data(ElemKind::FloatTy, {3, 1});
11555   data.getHandle() = {
11556       2.0,
11557       -0.5,
11558       13,
11559   };
11560 
11561   // If hasEndOffset then add some additional junk to the end of indices and
11562   // weights and an extra offset to offsets.
11563   // Note that the weights input needs to be a Constant instead of a
11564   // Placeholder for EmbeddingBagByteRowwiseOffsets, so we convert it later on.
11565   Placeholder *weights;
11566   Placeholder *indices;
11567   Placeholder *offsets;
11568 
11569   addEmbeddingBagPartialInputs<DataType>(bindings, mod, DTy, weights, indices,
11570                                          offsets, hasEndOffset, partialInput);
11571 
11572   auto *R = F->createEmbeddingBagByteRowwiseOffsets(
11573       "EBBRO", data, weights, indices, offsets, fusedDTy, useFP16Accumulation,
11574       hasEndOffset);
11575   SaveNode *S = F->createSave("save", R);
11576   bindings.allocate(S->getPlaceholder());
11577 
11578   ::glow::convertPlaceholdersToConstants(
11579       F, bindings, {indices, offsets, S->getPlaceholder()});
11580 
11581   EE.compile(CompilationMode::Infer);
11582   EE.run(bindings);
11583 
11584   Tensor &result = *bindings.get(S->getPlaceholder());
11585   ShapeVector odims(2, 1);
11586   odims[0] = partialInput ? 5 : 4;
11587   Tensor expected(DTy, odims);
11588   if (partialInput) {
11589     expected.getHandle<DataType>() = {
11590         0.5, 0, 0, 25, 0,
11591     };
11592   } else {
11593     expected.getHandle<DataType>() = {
11594         0.5,
11595         0,
11596         0,
11597         25,
11598     };
11599   }
11600 
11601   EXPECT_TRUE(expected.isEqual(result, allowedError));
11602 }
11603 
11604 /// Test EmbeddingBagByteRowwiseOffsets in Float.
11605 TEST_P(OperatorTest, EmbeddingBagByteRowwiseOffsets_Float) {
11606   CHECK_IF_ENABLED();
11607   testEmbeddingBagByteRowwiseOffsets<float>(
11608       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedQTy, 0.0001,
11609       /* useFP16Accumulation */ false, /* hasEndOffset */ false);
11610 }
11611 
11612 /// Test EmbeddingBagByteRowwiseOffsets in Float with end offset.
11613 TEST_P(OperatorTest, EmbeddingBagByteRowwiseOffsets_Float_End_Offset) {
11614   CHECK_IF_ENABLED();
11615   testEmbeddingBagByteRowwiseOffsets<float>(
11616       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedQTy, 0.0001,
11617       /* useFP16Accumulation */ false, /* hasEndOffset */ true);
11618 }
11619 
11620 /// Test EmbeddingBagByteRowwiseOffsets in Float with end offset and partial
11621 /// inputs.
11622 TEST_P(OperatorTest, EmbeddingBagByteRowwiseOffsets_Float_End_Offset_Partial) {
11623   CHECK_IF_ENABLED();
11624   ASSERT_TRUE(EE_.getBackend(getBackendName()).supportsPartialTensors());
11625   testEmbeddingBagByteRowwiseOffsets<float>(
11626       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedQTy, 0.0001,
11627       /* useFP16Accumulation */ false, /* hasEndOffset */ true,
11628       /* partialInput */ true);
11629 }
11630 
11631 /// Test EmbeddingBagByteRowwiseOffsets in Float16. Uses Float accumulation.
11632 TEST_P(OperatorTest, EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat) {
11633   CHECK_IF_ENABLED();
11634   testEmbeddingBagByteRowwiseOffsets<float16_t>(
11635       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, 0.0001,
11636       /* useFP16Accumulation */ false, /* hasEndOffset */ false);
11637 }
11638 
11639 /// Test EmbeddingBagByteRowwiseOffsets in Float16. Uses Float accumulation.
11640 /// Has end offset.
11641 TEST_P(OperatorTest,
11642        EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat_End_Offset) {
11643   CHECK_IF_ENABLED();
11644   testEmbeddingBagByteRowwiseOffsets<float16_t>(
11645       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, 0.0001,
11646       /* useFP16Accumulation */ false, /* hasEndOffset */ true);
11647 }
11648 
11649 /// Test EmbeddingBagByteRowwiseOffsets in Float16. Uses Float accumulation.
11650 /// Has end offset and uses partial inputs.
11651 TEST_P(OperatorTest,
11652        EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat_End_Offset_Partial) {
11653   CHECK_IF_ENABLED();
11654   ASSERT_TRUE(EE_.getBackend(getBackendName()).supportsPartialTensors());
11655   testEmbeddingBagByteRowwiseOffsets<float16_t>(
11656       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, 0.0001,
11657       /* useFP16Accumulation */ false, /* hasEndOffset */ true,
11658       /* partialInput */ true);
11659 }
11660 
11661 /// Test EmbeddingBagByteRowwiseOffsets in Float16. Uses Float16 accumulation.
11662 TEST_P(OperatorTest, EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat16) {
11663   CHECK_IF_ENABLED();
11664   testEmbeddingBagByteRowwiseOffsets<float16_t>(
11665       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, 0.0001,
11666       /* useFP16Accumulation */ true, /* hasEndOffset */ false);
11667 }
11668 
11669 /// Test EmbeddingBagByteRowwiseOffsets in Float16. Uses Float16 accumulation.
11670 /// Has end offset.
11671 TEST_P(OperatorTest,
11672        EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat16_End_Offset) {
11673   CHECK_IF_ENABLED();
11674   testEmbeddingBagByteRowwiseOffsets<float16_t>(
11675       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, 0.0001,
11676       /* useFP16Accumulation */ true, /* hasEndOffset */ true);
11677 }
11678 
11679 /// Test EmbeddingBagByteRowwiseOffsets in Float16. Uses Float16 accumulation.
11680 /// Has end offset and uses partial inputs.
11681 TEST_P(OperatorTest,
11682        EmbeddingBagByteRowwiseOffsets_Float16_AccumFloat16_End_Offset_Partial) {
11683   CHECK_IF_ENABLED();
11684   ASSERT_TRUE(EE_.getBackend(getBackendName()).supportsPartialTensors());
11685   testEmbeddingBagByteRowwiseOffsets<float16_t>(
11686       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, 0.0001,
11687       /* useFP16Accumulation */ false, /* hasEndOffset */ true,
11688       /* partialInput */ true);
11689 }
11690 
11691 /// Helper to test EmbeddingBag4BitRowwiseOffsets.
11692 template <typename DataType>
11693 static void testEmbeddingBag4BitRowwiseOffsets(
11694     glow::PlaceholderBindings &bindings, glow::Module &mod, glow::Function *F,
11695     glow::ExecutionEngine &EE, bool useFP16Accumulation, bool hasEndOffset,
11696     float allowedError) {
11697   /*
11698     DATA  =   [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3], // First Slice.
11699                [-3, -2, -1., 0], [0, -1, -2, -3],  // Second Slice.
11700                [2, 2, 2, 2,], [2, 2, 2, 2]  // Third Slice.
11701                ]
11702     WEIGHTS = [1, 2, 3, 2, 0.5, -0.5, 2]
11703     INDICES = [0, 1, 2, 4, 3, 5, 6]
11704     OFFSETS = [
11705         0, // This slice contains numbers >= 0.
11706         3, // This slice contains numbers <= 0.
11707         5, // This slice contains numbers which are all the same.
11708         7, // Empty slice.
11709     ]
11710     OUTPUT =  [[0, 6, 12, 18], // Output row per slice.
11711                [-1.5, -3, -4.5, -6],
11712                [3, 3, 3, 3]
11713                [0, 0, 0, 0]]
11714   */
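  // For reference, the expected OUTPUT rows above work out as:
  //   slice 0 (indices 0, 1, 2; weights 1, 2, 3): (1 + 2 + 3) * [0, 1, 2, 3] = [0, 6, 12, 18]
  //   slice 1 (indices 4, 3; weights 2, 0.5):     2*[0, -1, -2, -3] + 0.5*[-3, -2, -1, 0] = [-1.5, -3, -4.5, -6]
  //   slice 2 (indices 5, 6; weights -0.5, 2):    (-0.5 + 2) * [2, 2, 2, 2] = [3, 3, 3, 3]
  //   slice 3 (empty):                            [0, 0, 0, 0]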
11715   Tensor data(ElemKind::FloatTy, {7, 4});
11716   data.getHandle() = {
11717       0.,  1., 2., 3.,  0.,  1.,  2., 3., 0., 1., 2., 3., -3., -2.,
11718       -1., 0., 0., -1., -2., -3., 2., 2., 2., 2., 2., 2., 2.,  2.,
11719   };
11720 
11721   // If hasEndOffset then add some additional junk to the end of indices and
11722   // weights and an extra offset to offsets.
11723   Constant *weights;
11724   Placeholder *indices;
11725   Placeholder *offsets;
11726   if (hasEndOffset) {
11727     weights = mod.createConstant(ElemKind::Float16Ty, {9}, "weights");
11728     weights->getPayloadMutable().getHandle<DataType>() = {
11729         1.,
11730         2.,
11731         3.,
11732         2,
11733         0.5,
11734         -0.5,
11735         2,
11736         -42.0 /* A dummy weight for end offset. */,
11737         42.0 /* A dummy weight for end offset. */,
11738     };
11739 
11740     indices = mod.createPlaceholder(ElemKind::Int64ITy, {9}, "indices",
11741                                     /* isTrainable */ false);
11742     offsets = mod.createPlaceholder(ElemKind::Int64ITy, {5}, "offsets",
11743                                     /* isTrainable */ false);
11744 
11745     bindings.allocate(indices)->getHandle<int64_t>() = {
11746         0,
11747         1,
11748         2,
11749         4,
11750         3,
11751         5,
11752         6,
11753         100 /* A dummy index for end offset. */,
11754         200 /* A dummy index for end offset. */,
11755     };
11756 
11757     bindings.allocate(offsets)->getHandle<int64_t>() = {
11758         0, // This slice contains numbers >= 0.
11759         3, // This slice contains numbers <= 0.
11760         5, // This slice contains numbers which are all the same.
11761         7, // Empty slice.
11762         7, // Dummy end offset.
11763     };
11764 
11765   } else {
11766     weights = mod.createConstant(ElemKind::Float16Ty, {7}, "weights");
11767     weights->getPayloadMutable().getHandle<DataType>() = {
11768         1., 2., 3., 2, 0.5, -0.5, 2,
11769     };
11770 
11771     indices = mod.createPlaceholder(ElemKind::Int64ITy, {7}, "indices",
11772                                     /* isTrainable */ false);
11773     offsets = mod.createPlaceholder(ElemKind::Int64ITy, {4}, "offsets",
11774                                     /* isTrainable */ false);
11775 
11776     bindings.allocate(indices)->getHandle<int64_t>() = {
11777         0, 1, 2, 4, 3, 5, 6,
11778     };
11779     bindings.allocate(offsets)->getHandle<int64_t>() = {
11780         0, // This slice contains numbers >= 0.
11781         3, // This slice contains numbers <= 0.
11782         5, // This slice contains numbers which are all the same.
11783         7, // Empty slice.
11784     };
11785   }
11786 
11787   auto *R = F->createEmbeddingBagByteRowwiseOffsets(
11788       "EBBRO", data, weights, indices, offsets, ElemKind::UInt4FusedFP16QTy,
11789       useFP16Accumulation, hasEndOffset);
11790   SaveNode *S = F->createSave("save", R);
11791   bindings.allocate(S->getPlaceholder());
11792 
11793   EE.compile(CompilationMode::Infer);
11794   EE.run(bindings);
11795 
11796   Tensor &result = *bindings.get(S->getPlaceholder());
11797   Tensor expected(ElemKind::Float16Ty, {4, 4});
11798   expected.getHandle<DataType>() = {0., 6., 12., 18., -1.5, -3., -4.5, -6,
11799                                     3., 3., 3.,  3.,  0.,   0.,  0.,   0.};
11800 
11801   EXPECT_TRUE(expected.isEqual(result, allowedError));
11802 }
11803 
11804 TEST_P(OperatorTest, EmbeddingBag4BitRowwiseOffsets_Float16) {
11805   CHECK_IF_ENABLED();
11806   testEmbeddingBag4BitRowwiseOffsets<float16_t>(
11807       bindings_, mod_, F_, EE_,
11808       /* useFP16Accumulation */ false, /* hasEndOffset */ false, 0.005);
11809 }
11810 
11811 TEST_P(OperatorTest, EmbeddingBag4BitRowwiseOffsets_Float16_AccumFloat) {
11812   CHECK_IF_ENABLED();
11813   testEmbeddingBag4BitRowwiseOffsets<float16_t>(
11814       bindings_, mod_, F_, EE_,
11815       /* useFP16Accumulation */ true, /* hasEndOffset */ false, 0.005);
11816 }
11817 
11818 TEST_P(OperatorTest, EmbeddingBag4BitRowwiseOffsets_Float16_HasEndOffset) {
11819   CHECK_IF_ENABLED();
11820   testEmbeddingBag4BitRowwiseOffsets<float16_t>(bindings_, mod_, F_, EE_,
11821                                                 /* useFP16Accumulation */ false,
11822                                                 /* hasEndOffset */ true, 0.005);
11823 }
11824 
11825 TEST_P(OperatorTest,
11826        EmbeddingBag4BitRowwiseOffsets_Float16_HasEndOffset_AccumFloat) {
11827   CHECK_IF_ENABLED();
11828   testEmbeddingBag4BitRowwiseOffsets<float16_t>(bindings_, mod_, F_, EE_,
11829                                                 /* useFP16Accumulation */ true,
11830                                                 /* hasEndOffset */ true, 0.005);
11831 }
11832 
11833 /// Helper to test RowwiseQuantizedSparseLengthsWeightedSum using \p DTy.
11834 template <typename DataType, typename IndexType>
11835 static void testRowwiseQuantizedSparseLengthsWeightedSum(
11836     glow::PlaceholderBindings &bindings, glow::Module &mod, glow::Function *F,
11837     glow::ExecutionEngine &EE, ElemKind DTy, ElemKind ITy, float allowedError,
11838     bool useFP16Accumulation = false) {
11839   /*
11840     DATA  =   [2.0, -0.5, 13]
11841     WEIGHTS = [3, 1, 0, 0, 0, 0, 2, -0.5]
11842     INDICES = [1, 0, 2, 0, 1, 2, 2, 0]
11843     LENGTHS = [3, 0, 3, 2]
11844     OUTPUT =  [0.5, 0, 0, 25]
11845   */
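  // For reference, the expected OUTPUT above is the LENGTHS-partitioned
  // weighted sum of the DATA entries selected by INDICES:
  //   bag 0 (3 entries: indices 1, 0, 2; weights 3, 1, 0): 3*(-0.5) + 1*2.0 = 0.5
  //   bag 1 (0 entries):                                    0
  //   bag 2 (3 entries: indices 0, 1, 2; weights 0, 0, 0):  0
  //   bag 3 (2 entries: indices 2, 0; weights 2, -0.5):     2*13 - 0.5*2.0 = 25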
11846   Tensor data(ElemKind::FloatTy, {3});
11847   data.getHandle<float>() = {
11848       2.0,
11849       -0.5,
11850       13,
11851   };
11852 
11853   Constant *weights = mod.createConstant(DTy, {8}, "weights");
11854   weights->getPayloadMutable().getHandle<DataType>() = {
11855       3., 1., 0., 0., 0., 0., 2., -0.5,
11856   };
11857 
11858   Placeholder *indices = mod.createPlaceholder(ITy, {8}, "indices",
11859                                                /* isTrainable */ false);
11860   Placeholder *lengths =
11861       mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
11862                             /* isTrainable */ false);
11863 
11864   bindings.allocate(indices)->getHandle<IndexType>() = {
11865       1, 0, 2, 0, 1, 2, 2, 0,
11866   };
11867   bindings.allocate(lengths)->getHandle<int32_t>() = {
11868       3,
11869       0,
11870       3,
11871       2,
11872   };
11873 
11874   auto *R = F->createRowwiseQuantizedSparseLengthsWeightedSum(
11875       "RQSLWS", data, weights, indices, lengths,
11876       quantization::Schema::Asymmetric, DTy, useFP16Accumulation);
11877   SaveNode *S = F->createSave("save", R);
11878   bindings.allocate(S->getPlaceholder());
11879 
11880   EE.compile(CompilationMode::Infer);
11881   EE.run(bindings);
11882 
11883   Tensor &result = *bindings.get(S->getPlaceholder());
11884   Tensor expected(DTy, {4});
11885   expected.getHandle<DataType>() = {
11886       0.5,
11887       0,
11888       0,
11889       25,
11890   };
11891 
11892   EXPECT_TRUE(expected.isEqual(result, allowedError));
11893 }
11894 
11895 /// Test RWQ-SLWS with Float Weights, Scales, Offsets, and Output.
11896 TEST_P(OperatorTest, RowwiseQuantizedSparseLengthsWeightedSum_Float) {
11897   CHECK_IF_ENABLED();
11898   testRowwiseQuantizedSparseLengthsWeightedSum<float, int64_t>(
11899       bindings_, mod_, F_, EE_, ElemKind::FloatTy, ElemKind::Int64ITy, 0.0001);
11900 }
11901 
11902 /// Test RWQ-SLWS with Float16 Weights, Scales, Offsets, and Output. Uses
11903 /// Float accumulation.
11904 TEST_P(OperatorTest,
11905        RowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat) {
11906   CHECK_IF_ENABLED();
11907   testRowwiseQuantizedSparseLengthsWeightedSum<float16_t, int64_t>(
11908       bindings_, mod_, F_, EE_, ElemKind::Float16Ty, ElemKind::Int64ITy, 0.0001,
11909       /* useFP16Accumulation */ false);
11910 }
11911 
11912 /// Test RWQ-SLWS with Float16 Weights, Scales, Offsets, and Output. Uses
11913 /// Float16 accumulation.
11914 TEST_P(OperatorTest,
11915        RowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat16) {
11916   CHECK_IF_ENABLED();
11917   testRowwiseQuantizedSparseLengthsWeightedSum<float16_t, int64_t>(
11918       bindings_, mod_, F_, EE_, ElemKind::Float16Ty, ElemKind::Int64ITy, 0.0001,
11919       /* useFP16Accumulation */ true);
11920 }
11921 
11922 /// Test RWQ-SLWS with Float Weights, Scales, Offsets, and Output. Int32
11923 /// indices.
11924 TEST_P(OperatorTest, RowwiseQuantizedSparseLengthsWeightedSum_Float_Int32) {
11925   CHECK_IF_ENABLED();
11926   testRowwiseQuantizedSparseLengthsWeightedSum<float, int32_t>(
11927       bindings_, mod_, F_, EE_, ElemKind::FloatTy, ElemKind::Int32ITy, 0.0001);
11928 }
11929 
11930 /// Test RWQ-SLWS with Float16 Weights, Scales, Offsets, and Output. Uses
11931 /// Float accumulation. Int32 indices.
11932 TEST_P(OperatorTest,
11933        RowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat_Int32) {
11934   CHECK_IF_ENABLED();
11935   testRowwiseQuantizedSparseLengthsWeightedSum<float16_t, int32_t>(
11936       bindings_, mod_, F_, EE_, ElemKind::Float16Ty, ElemKind::Int32ITy, 0.0001,
11937       /* useFP16Accumulation */ false);
11938 }
11939 
11940 /// Test RWQ-SLWS with Float16 Weights, Scales, Offsets, and Output. Uses
11941 /// Float16 accumulation. Int32 indices.
11942 TEST_P(OperatorTest,
11943        RowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat16_Int32) {
11944   CHECK_IF_ENABLED();
11945   testRowwiseQuantizedSparseLengthsWeightedSum<float16_t, int32_t>(
11946       bindings_, mod_, F_, EE_, ElemKind::Float16Ty, ElemKind::Int32ITy, 0.0001,
11947       /* useFP16Accumulation */ true);
11948 }
11949 
11950 static FunctionTensorPair
11951 createAndInitRWQSLWSAllSame(glow::PlaceholderBindings &bindings,
11952                             glow::ExecutionEngine &EE) {
11953   auto &mod = EE.getModule();
11954   Function *F = mod.createFunction("main");
11955 
11956   Tensor data(ElemKind::FloatTy, {20, 2});
11957   data.getHandle<float>() = {
11958       0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
11959       0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
11960       0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,
11961   };
11962 
11963   Constant *weights = mod.createConstant(ElemKind::FloatTy, {21}, "weights");
11964   weights->getPayloadMutable().getHandle<float>() = {
11965       0.44419134, 0.3419154,  0.28775468, 0.47224975, 0.05422213, 0.14346851,
11966       0.05846643, 0.3750175,  0.09190885, 0.3335992,  0.09665264, 0.4560224,
11967       0.2244578,  0.44881952, 0.42696562, 0.33007848, 0.4511249,  0.11568925,
11968       0.02629679, 0.33864713, 0.42614424};
11969 
11970   Placeholder *indices =
11971       mod.createPlaceholder(ElemKind::Int64ITy, {21}, "indices",
11972                             /* isTrainable */ false);
11973   Placeholder *lengths =
11974       mod.createPlaceholder(ElemKind::Int32ITy, {2}, "lengths",
11975                             /* isTrainable */ false);
11976 
11977   bindings.allocate(indices)->getHandle<int64_t>() = {
11978       11, 8, 19, 8, 4, 11, 4, 19, 6, 18, 2, 6, 15, 5, 14, 14, 15, 13, 4, 6, 5,
11979   };
11980   bindings.allocate(lengths)->getHandle<int32_t>() = {15, 6};
11981 
11982   auto *R = F->createRowwiseQuantizedSparseLengthsWeightedSum(
11983       "RQSLWS", data, weights, indices, lengths,
11984       quantization::Schema::Asymmetric, ElemKind::FloatTy,
11985       /* useFP16Accumulation */ false);
11986   SaveNode *S = F->createSave("save", R);
11987   Tensor *resultT = bindings.allocate(S->getPlaceholder());
11988 
11989   return std::make_pair(F, resultT);
11990 }
11991 
11992 TEST_P(OperatorStatelessTest, RWQSLWSAllSame_Float16_AccumFP16) {
11993   CHECK_IF_ENABLED();
11994   compareAgainstInterpreter(
11995       getBackendName(), createAndInitRWQSLWSAllSame, ElemKind::Float16Ty,
11996       ElemKind::Float16Ty, 0.0005, parCloneCountOpt,
11997       /* convertToRowwiseQuantization */ false,
11998       /*schema */ quantization::Schema::Asymmetric,
11999       /* biasElemKind */ ElemKind::Int32QTy, /* forceFP16AccumSLS */ true);
12000 }
12001 
12002 TEST_P(OperatorStatelessTest, RWQSLWSAllSame_Float16_AccumFP32) {
12003   CHECK_IF_ENABLED();
12004   compareAgainstInterpreter(
12005       getBackendName(), createAndInitRWQSLWSAllSame, ElemKind::Float16Ty,
12006       ElemKind::Float16Ty, 1e-6, parCloneCountOpt,
12007       /* convertToRowwiseQuantization */ false,
12008       /*schema */ quantization::Schema::Asymmetric,
12009       /* biasElemKind */ ElemKind::Int32QTy, /* forceFP16AccumSLS */ false);
12010 }
12011 
12012 /// Helper to test RowwiseQuantizedSparseLengthsSum using \p DTy.
12013 template <typename DataType>
12014 static void testRowwiseQuantizedSparseLengthsSum(
12015     glow::PlaceholderBindings &bindings, glow::Module &mod, glow::Function *F,
12016     glow::ExecutionEngine &EE, ElemKind DTy, float allowedError,
12017     bool useFP16Accumulation = false) {
12018   /*
12019     DATA  = [
12020         [1.0, 1.2],
12021         [2.3, 3.4],
12022         [4.5, 5.7],
12023     ]
12024     INDICES = [2, 0, 1, 2, 0, 0, 0, 0]
12025     LENGTHS = [2, 0, 2, 1, 3]
12026     OUTPUT = [
12027         [5.5, 6.9],
12028         [0.0, 0.0],
12029         [6.8, 9.1],
12030         [1.0, 1.2],
12031         [3.0, 3.6],
12032     ]
12033   */
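  // For reference, each OUTPUT row above is the plain (unweighted) sum of the
  // DATA rows selected by the corresponding LENGTHS segment of INDICES, e.g.
  // row 0 = data[2] + data[0] = [5.5, 6.9] and row 4 = 3 * data[0] = [3.0, 3.6].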
12034   Tensor data(ElemKind::FloatTy, {3, 2});
12035   data.getHandle() = {
12036       1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f,
12037   };
12038 
12039   Placeholder *indices =
12040       mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
12041                             /* isTrainable */ false);
12042   Placeholder *lengths = mod.createPlaceholder(
12043       ElemKind::Int32ITy, {5}, "lengths", /* isTrainable */ false);
12044 
12045   bindings.allocate(indices)->getHandle<int64_t>() = {
12046       2, 0, 1, 2, 0, 0, 0, 0,
12047   };
12048   bindings.allocate(lengths)->getHandle<int32_t>() = {
12049       2, 0, 2, 1, 3,
12050   };
12051 
12052   auto *R = F->createRowwiseQuantizedSparseLengthsSum(
12053       "RQSLWS", data, indices, lengths, quantization::Schema::Asymmetric, DTy,
12054       useFP16Accumulation);
12055   SaveNode *S = F->createSave("save", R);
12056   bindings.allocate(S->getPlaceholder());
12057 
12058   EE.compile(CompilationMode::Infer);
12059   EE.run(bindings);
12060 
12061   Tensor &result = *bindings.get(S->getPlaceholder());
12062   Tensor expected(DTy, {5, 2});
12063   expected.getHandle<DataType>() = {
12064       5.5f, 6.9f, 0.0f, 0.0f, 6.8f, 9.1f, 1.0f, 1.2f, 3.0f, 3.6f,
12065   };
12066 
12067   EXPECT_TRUE(expected.isEqual(result, allowedError));
12068 }
12069 
12070 /// Test RWQ-SLS with Float Weights, Scales, Offsets, and Output.
12071 TEST_P(OperatorTest, RowwiseQuantizedSparseLengthsSum_Float) {
12072   CHECK_IF_ENABLED();
12073   testRowwiseQuantizedSparseLengthsSum<float>(bindings_, mod_, F_, EE_,
12074                                               ElemKind::FloatTy, 0.015);
12075 }
12076 
12077 /// Test RWQ-SLS with Float16 Weights, Scales, Offsets, and Output. Uses
12078 /// Float accumulation.
12079 TEST_P(OperatorTest, RowwiseQuantizedSparseLengthsSum_Float16_AccumFloat) {
12080   CHECK_IF_ENABLED();
12081   testRowwiseQuantizedSparseLengthsSum<float16_t>(
12082       bindings_, mod_, F_, EE_, ElemKind::Float16Ty, 0.02,
12083       /* useFP16Accumulation */ false);
12084 }
12085 
12086 /// Test RWQ-SLS with Float16 Weights, Scales, Offsets, and Output. Uses
12087 /// Float16 accumulation.
12088 TEST_P(OperatorTest, RowwiseQuantizedSparseLengthsSum_Float16_AccumFloat16) {
12089   CHECK_IF_ENABLED();
12090   testRowwiseQuantizedSparseLengthsSum<float16_t>(
12091       bindings_, mod_, F_, EE_, ElemKind::Float16Ty, 0.02,
12092       /* useFP16Accumulation */ true);
12093 }
12094 
12095 TEST_P(OperatorTest, RepeatedSLSWithPartialTensors) {
12096   CHECK_IF_ENABLED();
12097 
12098   // This test is only meaningful if the backend supports partial tensors.
12099   ASSERT_TRUE(EE_.getBackend(getBackendName()).supportsPartialTensors());
12100 
12101   constexpr dim_t embeddingRows = 1275;
12102   constexpr dim_t numLengths = 20;
12103   constexpr dim_t maxIndices = 20000;
12104   constexpr dim_t numIndices = 20; // Must be less than sum(lengths).
12105   constexpr dim_t iterations = 33;
12106 
12107   auto *data =
12108       mod_.createConstant(ElemKind::FloatTy, {embeddingRows, 1}, "data");
12109   data->getPayloadMutable().getHandle<float>().randomize(-1.0, 1.0,
12110                                                          mod_.getPRNG());
12111   auto *indices = mod_.createPlaceholder(ElemKind::Int64ITy, {maxIndices},
12112                                          "indices", false);
12113   auto *lengths = mod_.createPlaceholder(ElemKind::Int32ITy, {numLengths},
12114                                          "lengths", false);
12115   auto *SLS = F_->createSparseLengthsSum("SLS", data, indices, lengths);
12116   auto *save = F_->createSave("save", SLS);
12117   auto *outPH = save->getPlaceholder();
12118   EE_.compile(CompilationMode::Infer);
12119 
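  // Build two views of the same index data: a "partial" tensor that aliases
  // indicesReal's memory but is typed with the full placeholder shape (only
  // the first numIndices entries are backed by real data), and a "padded"
  // tensor of the full shape whose tail is zero-filled. Running both and
  // comparing the results checks that partial-tensor handling matches the
  // fully padded equivalent.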
12120   Tensor indicesReal(ElemKind::Int64ITy, {numIndices});
12121   indicesReal.getHandle<int64_t>().randomize(0, embeddingRows - 1,
12122                                              mod_.getPRNG());
12123   Tensor indicesPartial(indicesReal.getUnsafePtr(), indices->getType(),
12124                         indicesReal.getSizeInBytes());
12125   Tensor indicesPadded(indices->getType());
12126   indicesPadded.zero();
12127   memcpy(indicesPadded.getUnsafePtr(), indicesReal.getUnsafePtr(),
12128          numIndices * sizeof(int64_t));
12129 
12130   Tensor lengthsReal(ElemKind::Int32ITy, {numLengths});
12131   lengthsReal.getHandle<int32_t>().clear(1);
12132   Tensor lengthsPartial(lengthsReal.getUnsafePtr(), lengths->getType(),
12133                         lengthsReal.getSizeInBytes());
12134   Tensor lengthsPadded(ElemKind::Int32ITy, {numLengths});
12135   lengthsPadded.assign(&lengthsReal);
12136 
12137   bindings_.insert(indices, std::move(indicesPartial));
12138   bindings_.insert(lengths, std::move(lengthsPartial));
12139   bindings_.allocate(outPH);
12140 
12141   PlaceholderBindings paddedBindings;
12142   paddedBindings.insert(indices, std::move(indicesPadded));
12143   paddedBindings.insert(lengths, std::move(lengthsPadded));
12144   paddedBindings.allocate(outPH);
12145 
12146   for (dim_t i = 0; i < iterations; i++) {
12147     EE_.run(bindings_);
12148     EE_.run(paddedBindings);
12149     ASSERT_TRUE(bindings_.get(outPH)->isEqual(*paddedBindings.get(outPH)));
12150   }
12151 
12152   // Keep these around so their memory is not freed at the end of the
12153   // test/scope. This is so that inside TearDown during import/export testing
12154   // the data is still around.
12155   unownedTensors_.push_back(std::move(indicesReal));
12156   unownedTensors_.push_back(std::move(lengthsReal));
12157 }
12158 
12159 /// Helper to test Gather with partial inputs, using index type \p ITy.
12160 template <typename IndicesType>
12161 static void
12162 testPartialGather(glow::PlaceholderBindings &bindings, glow::Module &mod,
12163                   glow::Function *F, glow::ExecutionEngine &EE,
12164                   std::vector<Tensor> &unownedTensors, ElemKind ITy) {
12165   /*
12166     The actual input we care about has the following shape/result:
12167 
12168     DATA  = [1.0, 2.3, 4.5]
12169     INDICES = [0, 1, 0, 1, 2, 0]
12170     OUTPUT = [1.0, 2.3, 1.0, 2.3, 4.5, 1.0]
12171 
12172     However, we are going to create a larger INDICES input that is only
12173     partially filled, and expect a larger OUTPUT that we expect will have data
12174     we do not care about.
12175   */
12176 
12177   Placeholder *data = mod.createPlaceholder(ElemKind::FloatTy, {3}, "data",
12178                                             /* isTrainable */ false);
12179   Placeholder *indices =
12180       mod.createPlaceholder(ITy, {10000}, "indices", /* isTrainable */ false);
12181 
12182   bindings.allocate(data)->getHandle<float>() = {1.0f, 2.3f, 4.5f};
12183 
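  // As in the SLS test above, create a small "real" index tensor and wrap it
  // in an unowned partial tensor typed with the full {10000} placeholder
  // shape, so only the first 6 entries are backed by real data.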
12184   Tensor indicesReal(ITy, {6});
12185   indicesReal.getHandle<IndicesType>() = {0, 1, 0, 1, 2, 0};
12186   Tensor indicesPartial(indicesReal.getUnsafePtr(), indices->getType(),
12187                         indicesReal.getSizeInBytes());
12188   bindings.insert(indices, std::move(indicesPartial));
12189 
12190   auto *R = F->createGather("gather", data, indices);
12191 
12192   auto *result = F->createSave("save", R);
12193   Tensor *resultT = bindings.allocate(result->getPlaceholder());
12194 
12195   // The result should have 10000 elements, even though we only care about
12196   // the first 6 results.
12197   EXPECT_EQ(resultT->getType().dims().size(), 1);
12198   EXPECT_EQ(resultT->getType().dims()[0], 10000);
12199 
12200   EE.compile(CompilationMode::Infer);
12201   EE.run(bindings);
12202 
12203   Tensor expectedT(ElemKind::FloatTy, {6});
12204   auto expectedH = expectedT.getHandle<float>();
12205   expectedH = {1.0, 2.3, 1.0, 2.3, 4.5, 1.0};
12206   auto resultH = resultT->getHandle<float>();
12207 
12208   for (dim_t i = 0; i < 6; ++i) {
12209     EXPECT_EQ(expectedH.at({i}), resultH.at({i}));
12210   }
12211 
12212   // Keep this around so its memory is not freed at the end of the
12213   // test/scope. This is so that inside TearDown during import/export testing
12214   // the data is still around.
12215   unownedTensors.push_back(std::move(indicesReal));
12216 }
12217 
12218 TEST_P(OperatorTest, GatherWithInt64PartialTensors) {
12219   CHECK_IF_ENABLED();
12220   // This test is only meaningful if the backend supports partial tensors.
12221   ASSERT_TRUE(EE_.getBackend(getBackendName()).supportsPartialTensors());
12222   testPartialGather<int64_t>(bindings_, mod_, F_, EE_, unownedTensors_,
12223                              ElemKind::Int64ITy);
12224 }
12225 
12226 TEST_P(OperatorTest, GatherWithInt32PartialTensors) {
12227   CHECK_IF_ENABLED();
12228   // This test is only meaningful if the backend supports partial tensors.
12229   ASSERT_TRUE(EE_.getBackend(getBackendName()).supportsPartialTensors());
12230   testPartialGather<int32_t>(bindings_, mod_, F_, EE_, unownedTensors_,
12231                              ElemKind::Int32ITy);
12232 }
12233 
12234 /// Helper to test FusedRowwiseQuantizedSparseLengthsWeightedSum using \p DTy.
12235 template <typename DataType, typename IndexType>
12236 static void testFusedRowwiseQuantizedSparseLengthsWeightedSum(
12237     glow::PlaceholderBindings &bindings, glow::Module &mod, glow::Function *F,
12238     glow::ExecutionEngine &EE, ElemKind fusedDTy, ElemKind ITy,
12239     float allowedError, bool useFP16Accumulation = false) {
12240   /*
12241     DATA  =   [[2.0, -0.5, 13]]
12242     WEIGHTS = [3, 1, 0, 0, 0, 0, 2, -0.5]
12243     INDICES = [1, 0, 2, 0, 1, 2, 2, 0]
12244     LENGTHS = [3, 0, 3, 2]
12245     OUTPUT =  [[0.5, 0, 0, 25]]
12246   */
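  // The expected OUTPUT matches the per-bag weighted sums worked out for the
  // non-fused RWQ-SLWS test above (0.5, 0, 0, 25); here DATA is additionally
  // row-wise quantized into the fused format (fusedDTy) when the node is
  // created.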
12247   const bool fusedData = isFusedQuantizedElemKind(fusedDTy);
12248   const ElemKind DTy =
12249       fusedData ? getScaleOffsetElemKindFromFused(fusedDTy) : fusedDTy;
12250   Tensor data(ElemKind::FloatTy, {3, 1});
12251   data.getHandle() = {
12252       2.0,
12253       -0.5,
12254       13,
12255   };
12256 
12257   Constant *weights = mod.createConstant(DTy, {8}, "weights");
12258   weights->getPayloadMutable().getHandle<DataType>() = {
12259       3., 1., 0., 0., 0., 0., 2., -0.5,
12260   };
12261 
12262   Placeholder *indices = mod.createPlaceholder(ITy, {8}, "indices",
12263                                                /* isTrainable */ false);
12264   Placeholder *lengths =
12265       mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
12266                             /* isTrainable */ false);
12267 
12268   bindings.allocate(indices)->getHandle<IndexType>() = {
12269       1, 0, 2, 0, 1, 2, 2, 0,
12270   };
12271   bindings.allocate(lengths)->getHandle<int32_t>() = {
12272       3,
12273       0,
12274       3,
12275       2,
12276   };
12277 
12278   auto *R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
12279       "RQSLWS", data, weights, indices, lengths, fusedDTy, useFP16Accumulation);
12280   SaveNode *S = F->createSave("save", R);
12281   bindings.allocate(S->getPlaceholder());
12282 
12283   EE.compile(CompilationMode::Infer);
12284   EE.run(bindings);
12285 
12286   Tensor &result = *bindings.get(S->getPlaceholder());
12287   Tensor expected(DTy, {4, 1});
12288   expected.getHandle<DataType>() = {
12289       0.5,
12290       0,
12291       0,
12292       25,
12293   };
12294 
12295   EXPECT_TRUE(expected.isEqual(result, allowedError));
12296 }
12297 
12298 /// Test Fused-RWQ-SLWS in Float.
12299 TEST_P(OperatorTest, FusedRowwiseQuantizedSparseLengthsWeightedSum_Float) {
12300   CHECK_IF_ENABLED();
12301   testFusedRowwiseQuantizedSparseLengthsWeightedSum<float, int64_t>(
12302       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedQTy, ElemKind::Int64ITy,
12303       0.0001);
12304 }
12305 
12306 /// Test Fused-RWQ-SLWS in Float16. Uses Float accumulation.
12307 TEST_P(OperatorTest,
12308        FusedRowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat) {
12309   CHECK_IF_ENABLED();
12310   testFusedRowwiseQuantizedSparseLengthsWeightedSum<float16_t, int64_t>(
12311       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, ElemKind::Int64ITy,
12312       0.0001,
12313       /* useFP16Accumulation */ false);
12314 }
12315 
12316 /// Test Fused-RWQ-SLWS in Float16. Uses Float16 accumulation.
12317 TEST_P(OperatorTest,
12318        FusedRowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat16) {
12319   CHECK_IF_ENABLED();
12320   testFusedRowwiseQuantizedSparseLengthsWeightedSum<float16_t, int64_t>(
12321       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, ElemKind::Int64ITy,
12322       0.0001,
12323       /* useFP16Accumulation */ true);
12324 }
12325 
12326 /// Test Fused-RWQ-SLWS in Float. Int32 indices.
12327 TEST_P(OperatorTest,
12328        FusedRowwiseQuantizedSparseLengthsWeightedSum_Float_Int32) {
12329   CHECK_IF_ENABLED();
12330   testFusedRowwiseQuantizedSparseLengthsWeightedSum<float, int32_t>(
12331       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedQTy, ElemKind::Int32ITy,
12332       0.0001);
12333 }
12334 
12335 /// Test Fused-RWQ-SLWS in Float16. Uses Float accumulation. Int32 indices.
12336 TEST_P(OperatorTest,
12337        FusedRowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat_Int32) {
12338   CHECK_IF_ENABLED();
12339   testFusedRowwiseQuantizedSparseLengthsWeightedSum<float16_t, int32_t>(
12340       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, ElemKind::Int32ITy,
12341       0.0001,
12342       /* useFP16Accumulation */ false);
12343 }
12344 
12345 /// Test Fused-RWQ-SLWS in Float16. Uses Float16 accumulation. Int32 indices.
12346 TEST_P(
12347     OperatorTest,
12348     FusedRowwiseQuantizedSparseLengthsWeightedSum_Float16_AccumFloat16_Int32) {
12349   CHECK_IF_ENABLED();
12350   testFusedRowwiseQuantizedSparseLengthsWeightedSum<float16_t, int32_t>(
12351       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, ElemKind::Int32ITy,
12352       0.0001,
12353       /* useFP16Accumulation */ true);
12354 }
12355 
12356 static void testRowwiseQuantizedSparseLengthsSum_ConvertedFloat16(
12357     glow::PlaceholderBindings &bindings, glow::Module &mod, glow::Function *F,
12358     glow::ExecutionEngine &EE, float allowedError, bool convertFusedToFP16,
12359     bool useFP16AccumSLS) {
12360   CHECK_IF_ENABLED();
12361   /*
12362     DATA  =   [[2.0, -0.5, 13]]
12363     WEIGHTS = [3, 1, 0, 0, 0, 0, 2, -0.5]
12364     INDICES = [1, 0, 2, 0, 1, 2, 2, 0]
12365     LENGTHS = [3, 0, 3, 2]
12366     OUTPUT =  [[0.5, 0, 0, 25]]
12367   */
12368   Tensor data(ElemKind::FloatTy, {3, 1});
12369   data.getHandle() = {
12370       2.0,
12371       -0.5,
12372       13,
12373   };
12374 
12375   Constant *weights = mod.createConstant(ElemKind::FloatTy, {8}, "weights");
12376   weights->getPayloadMutable().getHandle<float>() = {
12377       3., 1., 0., 0., 0., 0., 2., -0.5,
12378   };
12379 
12380   Placeholder *indices =
12381       mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
12382                             /* isTrainable */ false);
12383   Placeholder *lengths =
12384       mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
12385                             /* isTrainable */ false);
12386 
12387   bindings.allocate(indices)->getHandle<int64_t>() = {
12388       1, 0, 2, 0, 1, 2, 2, 0,
12389   };
12390   bindings.allocate(lengths)->getHandle<int32_t>() = {
12391       3,
12392       0,
12393       3,
12394       2,
12395   };
12396 
12397   auto *R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
12398       "RQSLWS", data, weights, indices, lengths);
12399   SaveNode *S = F->createSave("save", R);
12400   bindings.allocate(S->getPlaceholder());
12401 
12402   CompilationContext cctx;
12403   cctx.precisionConfig.convertToFP16 = true;
12404   cctx.precisionConfig.convertFusedToFP16 = convertFusedToFP16;
12405   cctx.precisionConfig.forceFP16AccumSLS = useFP16AccumSLS;
12406   cctx.precisionConfig.float16Format =
12407       PrecisionConfiguration::Float16Format::FP16;
12408 
12409   EE.compile(cctx);
12410   EE.run(bindings);
12411 
12412   Tensor &result = *bindings.get(S->getPlaceholder());
12413   Tensor expected(ElemKind::FloatTy, {4, 1});
12414   expected.getHandle<float>() = {
12415       0.5,
12416       0,
12417       0,
12418       25,
12419   };
12420 
12421   EXPECT_TRUE(expected.isEqual(result, allowedError));
12422 }
12423 
12424 /// Test Fused-RWQ-SLWS where the weights are in FP16 and the data
12425 /// inputs are UInt8FusedQTy.
12426 TEST_P(
12427     OperatorTest,
12428     FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_NoFusedConvert) {
12429   CHECK_IF_ENABLED();
12430   return testRowwiseQuantizedSparseLengthsSum_ConvertedFloat16(
12431       bindings_, mod_, F_, EE_, 0.02,
12432       /* convertFusedToFP16*/ false, /* useFP16AccumSLS */ true);
12433 }
12434 
12435 TEST_P(
12436     OperatorTest,
12437     FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_NoFusedConvert_FP32Accum) {
12438   CHECK_IF_ENABLED();
12439   return testRowwiseQuantizedSparseLengthsSum_ConvertedFloat16(
12440       bindings_, mod_, F_, EE_, 0.02,
12441       /* convertFusedToFP16*/ false, /* useFP16AccumSLS */ false);
12442 }
12443 
12444 TEST_P(OperatorTest,
12445        FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16) {
12446   CHECK_IF_ENABLED();
12447   return testRowwiseQuantizedSparseLengthsSum_ConvertedFloat16(
12448       bindings_, mod_, F_, EE_, 0.02,
12449       /* convertFusedToFP16*/ true, /* useFP16AccumSLS */ true);
12450 }
12451 
12452 TEST_P(
12453     OperatorTest,
12454     FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_back_to_back) {
12455   CHECK_IF_ENABLED();
12456   /*
12457     DATA  =   [[2.0, -0.5, 13]]
12458     WEIGHTS = [1]
12459     INDICES = [0]
12460     LENGTHS = [0, 0, 0, 1] and then [1, 0, 0, 0]
12461     OUTPUT =  [[0, 0, 0, 2.0]] and then [[2.0, 0, 0, 0]]
12462   */
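  // First run: only the last bag has one entry (index 0, weight 1), so the
  // output is [0, 0, 0, 2.0]. Second run: the single entry moves to the first
  // bag, so the output becomes [2.0, 0, 0, 0]. This checks that state from the
  // first inference does not leak into the second.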
12463   Tensor data(ElemKind::FloatTy, {3, 1});
12464   data.getHandle() = {
12465       2.0,
12466       -0.5,
12467       13,
12468   };
12469 
12470   Constant *weights = mod_.createConstant(ElemKind::FloatTy, {1}, "weights");
12471   weights->getPayloadMutable().getHandle<float>() = {1.};
12472 
12473   Placeholder *indices =
12474       mod_.createPlaceholder(ElemKind::Int64ITy, {1}, "indices",
12475                              /* isTrainable */ false);
12476   Placeholder *lengths =
12477       mod_.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
12478                              /* isTrainable */ false);
12479 
12480   bindings_.allocate(indices)->getHandle<int64_t>() = {
12481       0,
12482   };
12483   bindings_.allocate(lengths)->getHandle<int32_t>() = {
12484       0,
12485       0,
12486       0,
12487       1,
12488   };
12489 
12490   auto *R = F_->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
12491       "RQSLWS", data, weights, indices, lengths);
12492   SaveNode *S = F_->createSave("save", R);
12493   bindings_.allocate(S->getPlaceholder());
12494 
12495   CompilationContext cctx;
12496   cctx.precisionConfig.convertToFP16 = true;
12497   cctx.precisionConfig.convertFusedToFP16 = true;
12498   cctx.precisionConfig.float16Format =
12499       PrecisionConfiguration::Float16Format::FP16;
12500 
12501   EE_.compile(cctx);
12502   EE_.run(bindings_);
12503 
12504   Tensor &result = *bindings_.get(S->getPlaceholder());
12505   Tensor expected(ElemKind::FloatTy, {4, 1});
12506   expected.getHandle<float>() = {
12507       0,
12508       0,
12509       0,
12510       2.0,
12511   };
12512 
12513   EXPECT_TRUE(expected.isEqual(result, 0.02));
12514 
12515   // Send another inference
12516   bindings_.get(lengths)->getHandle<int32_t>() = {
12517       1,
12518       0,
12519       0,
12520       0,
12521   };
12522   EE_.run(bindings_);
12523 
12524   Tensor &result1 = *bindings_.get(S->getPlaceholder());
12525   Tensor expected1(ElemKind::FloatTy, {4, 1});
12526   expected1.getHandle<float>() = {
12527       2.0,
12528       0,
12529       0,
12530       0,
12531   };
12532   EXPECT_TRUE(expected1.isEqual(result1, 0.02));
12533 }
12534 
12535 TEST_P(
12536     OperatorTest,
12537     FusedRowwiseQuantizedSparseLengthsWeightedSum_ConvertedFloat16_back_to_back2) {
12538   CHECK_IF_ENABLED();
12539 
12540   Tensor data(ElemKind::FloatTy, {10000, 64});
12541   data.getHandle().randomize(-1, 1, mod_.getPRNG());
12542 
12543   Placeholder *weights =
12544       mod_.createPlaceholder(ElemKind::FloatTy, {10000}, "weights",
12545                              /* isTrainable */ false);
12546 
12547   Placeholder *indices =
12548       mod_.createPlaceholder(ElemKind::Int64ITy, {10000}, "indices",
12549                              /* isTrainable */ false);
12550   Placeholder *lengths =
12551       mod_.createPlaceholder(ElemKind::Int32ITy, {32}, "lengths",
12552                              /* isTrainable */ false);
12553 
12554   Tensor *wT = bindings_.allocate(weights);
12555   wT->zero();
12556   wT->getHandle<float>().at({0}) = 4.18067;
12557 
12558   Tensor *iT = bindings_.allocate(indices);
12559   iT->zero();
12560   iT->getHandle<int64_t>().at({0}) = 4124;
12561 
12562   bindings_.allocate(lengths)->getHandle<int32_t>() = {
12563       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
12564       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
12565 
12566   auto *R = F_->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
12567       "RQSLWS", data, weights, indices, lengths);
12568   SaveNode *S = F_->createSave("save", R);
12569   bindings_.allocate(S->getPlaceholder());
12570 
12571   CompilationContext cctx;
12572   cctx.precisionConfig.convertToFP16 = true;
12573   cctx.precisionConfig.convertFusedToFP16 = true;
12574   cctx.precisionConfig.float16Format =
12575       PrecisionConfiguration::Float16Format::FP16;
12576 
12577   EE_.compile(cctx);
12578   EE_.run(bindings_);
12579 
12580   // This is the result for the first inference. We expect the result in the
12581   // second-to-last row, i.e. raw locations 30 * 64 to 31 * 64 - 1. The rest
12582   // of the rows should be all 0.
12583   Tensor &result = *bindings_.get(S->getPlaceholder());
12584 
12585   // Send another inference
12586   result.zero();
12587   // set new indices.
12588   iT = bindings_.get(indices);
12589   iT->zero();
12590   iT->getHandle<int64_t>().at({0}) = 1256;
12591   // set new lengths.
12592   bindings_.get(lengths)->getHandle<int32_t>() = {
12593       0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
12594       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
12595 
12596   };
12597   EE_.run(bindings_);
12598 
12599   // We now expect the second-to-last row to be all 0.
12600   Tensor &result1 = *bindings_.get(S->getPlaceholder());
12601   float *d = reinterpret_cast<float *>(result1.getUnsafePtr());
12602   for (size_t i = 30 * 64; i < 31 * 64; ++i) {
12603     EXPECT_EQ(0, d[i]);
12604   }
12605 }
12606 
12607 /// Helper to test FusedRowwiseQuantizedSparseLengthsSum using \p fusedDTy.
12608 template <typename DataType>
12609 static void testFusedRowwiseQuantizedSparseLengthsSum(
12610     glow::PlaceholderBindings &bindings, glow::Module &mod, glow::Function *F,
12611     glow::ExecutionEngine &EE, ElemKind fusedDTy, float allowedError,
12612     bool useFP16Accumulation = false) {
12613   /*
12614     DATA  = [
12615         [1.0, 1.2],
12616         [2.3, 3.4],
12617         [4.5, 5.7],
12618     ]
12619     INDICES = [2, 0, 1, 2, 0, 0, 0, 0]
12620     LENGTHS = [2, 0, 2, 1, 3]
12621     OUTPUT = [
12622         [5.5, 6.9],
12623         [0.0, 0.0],
12624         [6.8, 9.1],
12625         [1.0, 1.2],
12626         [3.0, 3.6],
12627     ]
12628   */
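  // Same expected arithmetic as the non-fused RWQ-SLS test above: each OUTPUT
  // row is the unweighted sum of the DATA rows in its LENGTHS segment, with
  // DATA row-wise quantized to fusedDTy.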
12629   const bool fusedData = isFusedQuantizedElemKind(fusedDTy);
12630   const ElemKind DTy =
12631       fusedData ? getScaleOffsetElemKindFromFused(fusedDTy) : fusedDTy;
12632 
12633   Tensor data(ElemKind::FloatTy, {3, 2});
12634   data.getHandle() = {
12635       1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f,
12636   };
12637 
12638   Placeholder *indices =
12639       mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
12640                             /* isTrainable */ false);
12641   Placeholder *lengths = mod.createPlaceholder(
12642       ElemKind::Int32ITy, {5}, "lengths", /* isTrainable */ false);
12643 
12644   bindings.allocate(indices)->getHandle<int64_t>() = {
12645       2, 0, 1, 2, 0, 0, 0, 0,
12646   };
12647   bindings.allocate(lengths)->getHandle<int32_t>() = {
12648       2, 0, 2, 1, 3,
12649   };
12650 
12651   auto *R = F->createFusedRowwiseQuantizedSparseLengthsSum(
12652       "RQSLWS", data, indices, lengths, fusedDTy, useFP16Accumulation);
12653   SaveNode *S = F->createSave("save", R);
12654   bindings.allocate(S->getPlaceholder());
12655 
12656   EE.compile(CompilationMode::Infer);
12657   EE.run(bindings);
12658 
12659   Tensor &result = *bindings.get(S->getPlaceholder());
12660   Tensor expected(DTy, {5, 2});
12661   expected.getHandle<DataType>() = {
12662       5.5f, 6.9f, 0.0f, 0.0f, 6.8f, 9.1f, 1.0f, 1.2f, 3.0f, 3.6f,
12663   };
12664 
12665   EXPECT_TRUE(expected.isEqual(result, allowedError));
12666 }
12667 
12668 /// Test Fused-RWQ-SLS in Float.
12669 TEST_P(OperatorTest, FusedRowwiseQuantizedSparseLengthsSum_Float) {
12670   CHECK_IF_ENABLED();
12671   testFusedRowwiseQuantizedSparseLengthsSum<float>(
12672       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedQTy, 0.015);
12673 }
12674 
12675 /// Test Fused-RWQ-SLS in Float16. Uses Float accumulation.
12676 TEST_P(OperatorTest, FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat) {
12677   CHECK_IF_ENABLED();
12678   testFusedRowwiseQuantizedSparseLengthsSum<float16_t>(
12679       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, 0.02,
12680       /* useFP16Accumulation */ false);
12681 }
12682 
12683 /// Test Fused-RWQ-SLS in Float16. Uses Float16 accumulation.
12684 TEST_P(OperatorTest,
12685        FusedRowwiseQuantizedSparseLengthsSum_Float16_AccumFloat16) {
12686   CHECK_IF_ENABLED();
12687   testFusedRowwiseQuantizedSparseLengthsSum<float16_t>(
12688       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, 0.02,
12689       /* useFP16Accumulation */ true);
12690 }
12691 
12692 /// Test Fused-RWQ-SLS in Float16 with 4-bit quantization for the embedding.
12693 /// Uses Float16 accumulation.
12694 TEST_P(OperatorTest,
12695        FusedRowwiseQuantizedSparseLengthsSum_Fused4Bit_Float16_AccumFloat16) {
12696   CHECK_IF_ENABLED();
12697   testFusedRowwiseQuantizedSparseLengthsSum<float16_t>(
12698       bindings_, mod_, F_, EE_, ElemKind::UInt4FusedFP16QTy, 0.15,
12699       /* useFP16Accumulation */ true);
12700 }
12701 
12702 /// Helper to test all variants of SLWS on two-column data, with the data
12703 /// precision given by \p dataDTy.
12704 template <typename DataType>
12705 static void testSLWSTwoColumn(glow::PlaceholderBindings &bindings,
12706                               glow::Module &mod, glow::Function *F,
12707                               glow::ExecutionEngine &EE, ElemKind dataDTy,
12708                               float allowedError,
12709                               bool useFP16Accumulation = false) {
12710   /*
12711     DATA  = [
12712         [1.0, 1.2],
12713         [2.3, 3.4],
12714         [4.5, 5.7],
12715     ]
12716     INDICES = [2, 0, 1, 2, 0, 0, 0, 0]
12717     LENGTHS = [2, 0, 2, 1, 3]
12718     WEIGHTS = [1, -1, 1.5, 0.5, -1.5, 2, -2, -0.5]
12719     OUTPUT = [
12720         [3.5, 4.5],
12721         [0.0, 0.0],
12722         [5.7, 7.95],
12723         [-1.5, -1.8],
12724         [-0.5, -0.6],
12725     ]
12726   */
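  // For example, the first output row above is the first segment (LENGTHS[0]
  // is 2): 1 * DATA[2] + (-1) * DATA[0] = [4.5 - 1.0, 5.7 - 1.2] = [3.5, 4.5].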
12727   const bool fusedData = isFusedQuantizedElemKind(dataDTy);
12728   const ElemKind DTy =
12729       fusedData ? getScaleOffsetElemKindFromFused(dataDTy) : dataDTy;
12730 
12731   Tensor data(fusedData ? ElemKind::FloatTy : DTy, {3, 2});
12732 #define floatData                                                              \
12733   { 1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f, }
12734   if (fusedData) {
12735     data.getHandle<float>() = floatData;
12736   } else {
12737     data.getHandle<DataType>() = floatData;
12738   }
12739 
12740   Placeholder *indices =
12741       mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
12742                             /* isTrainable */ false);
12743   Placeholder *lengths = mod.createPlaceholder(
12744       ElemKind::Int32ITy, {5}, "lengths", /* isTrainable */ false);
12745   Placeholder *weights =
12746       mod.createPlaceholder(DTy, {8}, "weights", /* isTrainable */ false);
12747 
12748   bindings.allocate(indices)->getHandle<int64_t>() = {
12749       2, 0, 1, 2, 0, 0, 0, 0,
12750   };
12751   bindings.allocate(lengths)->getHandle<int32_t>() = {
12752       2, 0, 2, 1, 3,
12753   };
12754   bindings.allocate(weights)->getHandle<DataType>() = {
12755       1, -1, 1.5, 0.5, -1.5, 2, -2, -0.5,
12756   };
12757 
12758   Node *SLWS = nullptr;
12759   if (fusedData) {
12760     SLWS = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
12761         "RQSLWS", data, weights, indices, lengths, dataDTy,
12762         useFP16Accumulation);
12763   } else {
12764     Placeholder *dataP = mod.createPlaceholder(&data.getType(), "data",
12765                                                /* isTrainable */ false);
12766     bindings.insert(dataP, std::move(data));
12767     SLWS = F->createSparseLengthsWeightedSum("SLWS", dataP, weights, indices,
12768                                              lengths);
12769   }
12770   SaveNode *S = F->createSave("save", SLWS);
12771   bindings.allocate(S->getPlaceholder());
12772 
12773   EE.compile(CompilationMode::Infer);
12774   EE.run(bindings);
12775 
12776   Tensor &result = *bindings.get(S->getPlaceholder());
12777   Tensor expected(DTy, {5, 2});
12778   expected.getHandle<DataType>() = {
12779       3.5, 4.5, 0.0, 0.0, 5.7, 7.95, -1.5, -1.8, -0.5, -0.6,
12780   };
12781 
12782   EXPECT_TRUE(expected.isEqual(result, allowedError));
12783 }
12784 
12785 /// Test SLWS in Float.
12786 TEST_P(OperatorTest, SLWSTwoColumn_Float) {
12787   CHECK_IF_ENABLED();
12788   testSLWSTwoColumn<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, 0.0001);
12789 }
12790 
12791 /// Test SLWS in Float16.
12792 TEST_P(OperatorTest, SLWSTwoColumn_Float16_AccumFloat) {
12793   CHECK_IF_ENABLED();
12794   testSLWSTwoColumn<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
12795                                0.005,
12796                                /* useFP16Accumulation */ false);
12797 }
12798 
12799 /// Test Fused-RWQ-SLWS in Float.
12800 TEST_P(OperatorTest, FusedRowwiseQuantizedSLWSTwoColumn_Float) {
12801   CHECK_IF_ENABLED();
12802   testSLWSTwoColumn<float>(bindings_, mod_, F_, EE_, ElemKind::UInt8FusedQTy,
12803                            0.015);
12804 }
12805 
12806 /// Test Fused-RWQ-SLWS in Float16. Uses Float accumulation.
12807 TEST_P(OperatorTest, FusedRowwiseQuantizedSLWSTwoColumn_Float16_AccumFloat) {
12808   CHECK_IF_ENABLED();
12809   testSLWSTwoColumn<float16_t>(bindings_, mod_, F_, EE_,
12810                                ElemKind::UInt8FusedFP16QTy, 0.015,
12811                                /* useFP16Accumulation */ false);
12812 }
12813 
12814 /// Test Fused-RWQ-SLWS in Float16. Uses Float16 accumulation.
12815 TEST_P(OperatorTest, FusedRowwiseQuantizedSLWSTwoColumn_Float16_AccumFloat16) {
12816   CHECK_IF_ENABLED();
12817   testSLWSTwoColumn<float16_t>(bindings_, mod_, F_, EE_,
12818                                ElemKind::UInt8FusedFP16QTy, 0.015,
12819                                /* useFP16Accumulation */ true);
12820 }
12821 
12822 /// Test Fused-RWQ-SLWS in Float16 with 4-bit quantization for the embedding.
12823 /// Uses Float16 accumulation.
12824 TEST_P(OperatorTest,
12825        FusedRowwiseQuantizedSLWSTwoColumn_Fused4Bit_Float16_AccumFloat16) {
12826   CHECK_IF_ENABLED();
12827   testSLWSTwoColumn<float16_t>(bindings_, mod_, F_, EE_,
12828                                ElemKind::UInt4FusedFP16QTy, 0.1,
12829                                /* useFP16Accumulation */ true);
12830 }
12831 
12832 /// Helper to test SLWS with different lengths modes, with precision \p DTy,
12833 /// and precision for data \p dataDTy.
12834 template <typename DataType>
12835 static void testSLWSLengthsMode(glow::PlaceholderBindings &bindings,
12836                                 glow::Module &mod, glow::Function *F,
12837                                 glow::ExecutionEngine &EE, ElemKind dataDTy,
12838                                 float allowedError, bool useFP16Accumulation,
12839                                 LengthsMode lengthsMode) {
12840   /*
12841     DATA  = [
12842         [1.0, 1.2],
12843         [2.3, 3.4],
12844         [4.5, 5.7],
12845     ]
12846     INDICES = [2, 0, 1, 2, 0]
12847     LENGTHS = [1, 1, 1, 1, 1]
12848     WEIGHTS = [1, -1, 1.5, 0.5, -1.5]
12849     OUTPUT = [
12850         [4.5, 5.7],
12851         [-1.0, -1.2],
12852         [3.45, 5.1],
12853         [2.25, 2.85],
12854         [-1.5, -1.8],
12855     ]
12856   */
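  // With every length equal to one, each output row is a single weighted data
  // row, e.g. OUTPUT[1] = WEIGHTS[1] * DATA[INDICES[1]] = -1 * [1.0, 1.2].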
12857   const bool fusedData = isFusedQuantizedElemKind(dataDTy);
12858   const ElemKind DTy =
12859       fusedData ? getScaleOffsetElemKindFromFused(dataDTy) : dataDTy;
12860 
12861   Tensor data(fusedData ? ElemKind::FloatTy : DTy, {3, 2});
12862 #define floatData                                                              \
12863   { 1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f, }
12864   if (fusedData) {
12865     data.getHandle<float>() = floatData;
12866   } else {
12867     data.getHandle<DataType>() = floatData;
12868   }
12869 
12870   Placeholder *indices =
12871       mod.createPlaceholder(ElemKind::Int64ITy, {5}, "indices",
12872                             /* isTrainable */ false);
12873   Placeholder *lengths = mod.createPlaceholder(
12874       ElemKind::Int32ITy, {5}, "lengths", /* isTrainable */ false);
12875   Placeholder *weights =
12876       mod.createPlaceholder(DTy, {5}, "weights", /* isTrainable */ false);
12877 
12878   bindings.allocate(indices)->getHandle<int64_t>() = {
12879       2, 0, 1, 2, 0,
12880   };
12881   auto LH = bindings.allocate(lengths)->getHandle<int32_t>();
12882   Tensor expected(DTy, {5, 2});
12883   LH = {1, 1, 1, 1, 1};
12884   expected.getHandle<DataType>() = {
12885       4.5, 5.7, -1.0, -1.2, 3.45, 5.1, 2.25, 2.85, -1.5, -1.8,
12886   };
12887   bindings.allocate(weights)->getHandle<DataType>() = {
12888       1, -1, 1.5, 0.5, -1.5,
12889   };
12890 
12891   Node *SLWS = nullptr;
12892   if (fusedData) {
12893     SLWS = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
12894         "RQSLWS", data, weights, indices, lengths, dataDTy, useFP16Accumulation,
12895         lengthsMode);
12896   } else {
12897     Placeholder *dataP = mod.createPlaceholder(&data.getType(), "data",
12898                                                /* isTrainable */ false);
12899     bindings.insert(dataP, std::move(data));
12900     SLWS = F->createSparseLengthsWeightedSum("SLWS", dataP, weights, indices,
12901                                              lengths, lengthsMode);
12902   }
12903   SaveNode *S = F->createSave("save", SLWS);
12904   bindings.allocate(S->getPlaceholder());
12905 
12906   EE.compile(CompilationMode::Infer);
12907   EE.run(bindings);
12908 
12909   Tensor &result = *bindings.get(S->getPlaceholder());
12910 
12911   EXPECT_TRUE(expected.isEqual(result, allowedError));
12912 }
12913 
12914 /// Test SLWS in Float with all lengths one.
12915 TEST_P(OperatorTest, SLWSAllLengthsOne_Float) {
12916   CHECK_IF_ENABLED();
12917   testSLWSLengthsMode<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy,
12918                              0.0001, /* useFP16Accumulation */ false,
12919                              LengthsMode::AllOne);
12920 }
12921 
12922 /// Test SLWS in Float16 with all lengths one.
12923 TEST_P(OperatorTest, SLWSAllLengthsOne_Float16_AccumFloat) {
12924   CHECK_IF_ENABLED();
12925   testSLWSLengthsMode<float16_t>(
12926       bindings_, mod_, F_, EE_, ElemKind::Float16Ty, 0.005,
12927       /* useFP16Accumulation */ false, LengthsMode::AllOne);
12928 }
12929 
12930 /// Test Fused-RWQ-SLWS in Float with all lengths one.
12931 TEST_P(OperatorTest, FusedRowwiseQuantizedSLWSAllLengthsOne_Float) {
12932   CHECK_IF_ENABLED();
12933   testSLWSLengthsMode<float>(
12934       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedQTy, 0.015,
12935       /* useFP16Accumulation */ false, LengthsMode::AllOne);
12936 }
12937 
12938 /// Test Fused-RWQ-SLWS in Float16 with all lengths one. Uses Float accumulation.
12939 TEST_P(OperatorTest,
12940        FusedRowwiseQuantizedSLWSAllLengthsOne_Float16_AccumFloat) {
12941   CHECK_IF_ENABLED();
12942   testSLWSLengthsMode<float16_t>(
12943       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, 0.015,
12944       /* useFP16Accumulation */ false, LengthsMode::AllOne);
12945 }
12946 
12947 /// Test Fused-RWQ-SLWS in Float16 with all lengths one. Uses Float16 accumulation.
12948 TEST_P(OperatorTest,
12949        FusedRowwiseQuantizedSLWSAllLengthsOne_Float16_AccumFloat16) {
12950   CHECK_IF_ENABLED();
12951   testSLWSLengthsMode<float16_t>(
12952       bindings_, mod_, F_, EE_, ElemKind::UInt8FusedFP16QTy, 0.015,
12953       /* useFP16Accumulation */ true, LengthsMode::AllOne);
12954 }
12955 
12956 /// Test Fused-RWQ-SLWS in Float16 with 4-bit quantization for the embedding
12957 /// and all lengths one. Uses Float16 accumulation.
12958 TEST_P(OperatorTest,
12959        FusedRowwiseQuantizedSLWSAllLengthsOne_Fused4Bit_Float16_AccumFloat16) {
12960   CHECK_IF_ENABLED();
12961   testSLWSLengthsMode<float16_t>(
12962       bindings_, mod_, F_, EE_, ElemKind::UInt4FusedFP16QTy, 0.1,
12963       /* useFP16Accumulation */ true, LengthsMode::AllOne);
12964 }
12965 
12966 /// Test SLS when some input tensors are constants.
12967 TEST_P(OperatorTest, ConstantSLS) {
12968   CHECK_IF_ENABLED();
12969 
12970   auto *data = mod_.createConstant(ElemKind::FloatTy, {1024, 32}, "data");
12971   auto *indices =
12972       mod_.createPlaceholder(ElemKind::Int64ITy, {314}, "indices", false);
12973   auto *lengths = mod_.createConstant(ElemKind::Int32ITy, {20}, "lengths");
12974 
12975   // data
12976   auto DH = data->getPayload().getHandle();
12977   for (dim_t i = 0; i < 1024; i++) {
12978     for (dim_t j = 0; j < 32; j++) {
12979       DH.at({i, j}) = (float)i;
12980     }
12981   }
12982 
12983   // indices
12984   auto IH = bindings_.allocate(indices)->getHandle<int64_t>();
12985   std::iota(IH.begin(), IH.end(), 0);
12986 
12987   // lengths
12988   auto LH = lengths->getHandle<int32_t>();
12989   LH.clear(16);
12990   for (dim_t ldx : {1, 2, 6, 13, 14, 19}) {
12991     LH.at({ldx}) = 15;
12992   }
12993 
12994   auto *R = F_->createSparseLengthsSum("SLS", data, indices, lengths);
12995   auto *S = F_->createSave("save", R);
12996   auto *out = bindings_.allocate(S->getPlaceholder());
12997 
12998   EE_.compile(CompilationMode::Infer);
12999   EE_.run(bindings_);
13000 
13001   std::vector<float> expected = {120,  345,  570,  856,  1112, 1368, 1515,
13002                                  1864, 2120, 2376, 2632, 2888, 3144, 3180,
13003                                  3405, 3880, 4136, 4392, 4648, 4590};
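  // Since DATA[i][j] == i, each expected value is the sum of the row indices
  // gathered for that segment, e.g. the first segment (length 16) sums indices
  // 0..15 = 120, and the second segment (length 15) sums indices 16..30 = 345.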
13004   auto OH = out->getHandle();
13005   for (dim_t i = 0; i < 20; i++) {
13006     for (dim_t j = 0; j < 32; j++) {
13007       EXPECT_EQ(OH.at({i, j}), expected[i]);
13008     }
13009   }
13010 }
13011 
13012 /// Test SLS when some "lengths" inputs are zero.
13013 TEST_P(OperatorStatelessTest, SLSWithZeroLengths) {
13014   CHECK_IF_ENABLED();
13015 
13016   compareAgainstInterpreter(
13017       getBackendName(),
13018       [](PlaceholderBindings &bindings, ExecutionEngine &EE) {
13019         auto &mod = EE.getModule();
13020         auto *F = mod.createFunction("main");
13021         constexpr dim_t embedWidth = 1000;
13022         Tensor data(ElemKind::FloatTy, {embedWidth, 8});
13023         data.getHandle().randomize(-1, 1, mod.getPRNG());
13024         Constant *weights =
13025             mod.createConstant(ElemKind::FloatTy, {3000}, "weights");
13026         weights->getPayloadMutable().getHandle().clear(1.0f);
13027         auto *indices =
13028             mod.createPlaceholder(ElemKind::Int64ITy, {3000}, "indices", false);
13029         auto *lengths =
13030             mod.createPlaceholder(ElemKind::Int32ITy, {1000}, "lengths", false);
13031         bindings.allocate(indices)->getHandle<int64_t>().randomize(
13032             0, embedWidth - 1, mod.getPRNG());
13033         auto LH = bindings.allocate(lengths)->getHandle<int32_t>();
13034         LH.clear(0);
13035         auto it = LH.begin();
13036         for (int i = 0; i < 13; ++i, ++it) {
13037           *it = 20;
13038         }
13039 
13040         auto *R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
13041             "RQSLWS", data, weights, indices, lengths);
13042         auto *S = F->createSave("save", R);
13043         auto *res = bindings.allocate(S->getPlaceholder());
13044         return std::make_pair(F, res);
13045       },
13046       ElemKind::FloatTy, ElemKind::FloatTy);
13047 }
13048 
13049 /// Helper to create an SLS test with all zero lengths, with and without fused
13050 /// rowwise quantization based on \p convertToRowwiseQuantization.
13051 static FunctionTensorPair
13052 createAndInitZeroLengthsSLSTest(glow::PlaceholderBindings &bindings,
13053                                 glow::ExecutionEngine &EE,
13054                                 bool convertToRowwiseQuantization) {
13055   auto &mod = EE.getModule();
13056   auto *F = mod.createFunction("main");
13057   constexpr dim_t embedWidth = 1000;
13058   auto dataTy = mod.uniqueType(ElemKind::FloatTy, {embedWidth, 8});
13059   Tensor data(dataTy);
13060   data.getHandle().randomize(-1, 1, mod.getPRNG());
13061   Constant *weights = mod.createConstant(ElemKind::FloatTy, {3000}, "weights");
13062   weights->getPayloadMutable().getHandle().clear(1.0f);
13063   auto *indices =
13064       mod.createPlaceholder(ElemKind::Int64ITy, {3000}, "indices", false);
13065   auto *lengths =
13066       mod.createPlaceholder(ElemKind::Int32ITy, {1000}, "lengths", false);
13067   bindings.allocate(indices)->getHandle<int64_t>().randomize(0, embedWidth - 1,
13068                                                              mod.getPRNG());
13069   auto LH = bindings.allocate(lengths)->getHandle<int32_t>();
13070   LH.clear(0);
13071 
13072   Node *R = nullptr;
13073   if (convertToRowwiseQuantization) {
13074     R = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
13075         "RQSLWS", data, weights, indices, lengths);
13076   } else {
13077     Placeholder *dataP =
13078         mod.createPlaceholder(dataTy, "data", /* isTrainable */ false);
13079     bindings.insert(dataP, std::move(data));
13080     R = F->createSparseLengthsWeightedSum("SLWS", dataP, weights, indices,
13081                                           lengths);
13082   }
13083   auto *S = F->createSave("save", R);
13084   auto *res = bindings.allocate(S->getPlaceholder());
13085   return std::make_pair(F, res);
13086 }
13087 
13088 /// Test Fused RWQ-SLS when all "lengths" inputs are zero in FloatTy.
13089 TEST_P(OperatorStatelessTest, FusedRWQSLSAllZeroLengths_Float) {
13090   CHECK_IF_ENABLED();
13091 
13092   compareAgainstInterpreter(getBackendName(),
13093                             std::bind(createAndInitZeroLengthsSLSTest,
13094                                       std::placeholders::_1,
13095                                       std::placeholders::_2,
13096                                       /* convertToRowwiseQuantization */ true),
13097                             ElemKind::FloatTy, ElemKind::FloatTy);
13098 }
13099 
13100 /// Test Fused RWQ-SLS when all "lengths" inputs are zero in Float16Ty.
13101 TEST_P(OperatorStatelessTest, FusedRWQSLSAllZeroLengths_Float16) {
13102   CHECK_IF_ENABLED();
13103 
13104   compareAgainstInterpreter(getBackendName(),
13105                             std::bind(createAndInitZeroLengthsSLSTest,
13106                                       std::placeholders::_1,
13107                                       std::placeholders::_2,
13108                                       /* convertToRowwiseQuantization */ true),
13109 
13110                             ElemKind::Float16Ty, ElemKind::Float16Ty);
13111 }
13112 
13113 /// Test SLS when all "lengths" inputs are zero in FloatTy.
13114 TEST_P(OperatorStatelessTest, SLSAllZeroLengths_Float) {
13115   CHECK_IF_ENABLED();
13116 
13117   compareAgainstInterpreter(getBackendName(),
13118                             std::bind(createAndInitZeroLengthsSLSTest,
13119                                       std::placeholders::_1,
13120                                       std::placeholders::_2,
13121                                       /* convertToRowwiseQuantization */ false),
13122                             ElemKind::FloatTy, ElemKind::FloatTy);
13123 }
13124 
13125 /// Test SLS when all "lengths" inputs are zero in Float16Ty.
13126 TEST_P(OperatorStatelessTest, SLSAllZeroLengths_Float16) {
13127   CHECK_IF_ENABLED();
13128 
13129   compareAgainstInterpreter(getBackendName(),
13130                             std::bind(createAndInitZeroLengthsSLSTest,
13131                                       std::placeholders::_1,
13132                                       std::placeholders::_2,
13133                                       /* convertToRowwiseQuantization */ false),
13134 
13135                             ElemKind::Float16Ty, ElemKind::Float16Ty);
13136 }
13137 
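/// Helper to test SparseToDense using \p DTy. Indexed 3D value slices are
/// scattered into a zero-initialized output; duplicate indices accumulate.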
13138 template <typename DataType>
13139 static void testSparseToDense(glow::PlaceholderBindings &bindings,
13140                               glow::Module &mod, glow::Function *F,
13141                               glow::ExecutionEngine &EE, ElemKind DTy) {
13142 
13143   // Create and initialize inputs. Make input 3D to make sure
13144   // multidimensional values are handled properly.
13145   constexpr dim_t kNumIndices = 4;
13146   constexpr dim_t kRows = 10;
13147   constexpr dim_t kCols = 5;
13148   constexpr dim_t kMaxIndex = 10;
13149 
13150   auto *indices = mod.createPlaceholder(ElemKind::Int64ITy, {kNumIndices},
13151                                         "indices", false);
13152   auto *values =
13153       mod.createPlaceholder(DTy, {kNumIndices, kRows, kCols}, "data", false);
13154   auto *dataToInferDim = mod.createPlaceholder(ElemKind::FloatTy, {kMaxIndex},
13155                                                "dataToInferDim", false);
13156 
13157   auto IH = bindings.allocate(indices)->getHandle<int64_t>();
13158   auto VH = bindings.allocate(values)->getHandle<DataType>();
13159 
13160   // Duplicate one index to test that the corresponding values are added.
13161   IH = {1, 3, 1, 9};
13162   VH.randomize(-3.0, 3.0, mod.getPRNG());
13163 
13164   auto *STDN = F->createSparseToDense("STDN", indices, values, dataToInferDim);
13165   auto *S = F->createSave("save", STDN);
13166   bindings.allocate(S->getPlaceholder());
13167 
13168   EE.compile(CompilationMode::Infer);
13169   EE.run(bindings);
13170 
13171   Tensor &result = *bindings.get(S->getPlaceholder());
13172 
13173   // Compute expected output.
13174   Tensor expected(DTy, {kMaxIndex, kRows, kCols});
13175   auto EH = expected.getHandle<DataType>();
13176 
13177   expected.zero();
13178   for (dim_t i = 0; i < kNumIndices; ++i) {
13179     dim_t idx = IH.at({i});
13180     for (dim_t j = 0; j < kRows; ++j) {
13181       for (dim_t k = 0; k < kCols; ++k) {
13182         EH.at({idx, j, k}) += VH.at({i, j, k});
13183       }
13184     }
13185   }
13186 
13187   EXPECT_TRUE(expected.isEqual(result));
13188 }
13189 
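/// Test SparseToDense in FloatTy.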
13190 TEST_P(OperatorTest, SparseToDense_Float) {
13191   CHECK_IF_ENABLED();
13192   testSparseToDense<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
13193 }
13194 
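/// Test SparseToDense in Int64ITy.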
13195 TEST_P(OperatorTest, SparseToDense_Int64) {
13196   CHECK_IF_ENABLED();
13197   testSparseToDense<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy);
13198 }
13199 
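/// Test SparseToDenseMask with 1D values and a scalar default value.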
13200 TEST_P(OperatorTest, SparseToDenseMask1) {
13201   CHECK_IF_ENABLED();
13202 
13203   /*
13204     INDICES = [4, 42, 13, 0, 100, 13]
13205     VALUES = [-5.5, 0.7, 11, 1e6, 2, 3.5]
13206     DEFAULTVALUE = 1.1
13207     LENGTHS = [4, 2]
13208     MASK = [2, 1, 0, 13, 42, 43]
13209     OUTPUT =  [[1.1, 1.1, 1e6, 11, 0.7, 1.1], [1.1, 1.1, 1.1, 3.5, 1.1, 1.1]]
13210   */
13211   auto *indices =
13212       mod_.createPlaceholder(ElemKind::Int64ITy, {6}, "indices", false);
13213   auto *values =
13214       mod_.createPlaceholder(ElemKind::FloatTy, {6}, "values", false);
13215   auto *defaultValue =
13216       mod_.createPlaceholder(ElemKind::FloatTy, {}, "default_value", false);
13217   auto *lengths =
13218       mod_.createPlaceholder(ElemKind::Int32ITy, {2}, "lengths", false);
13219   std::vector<dim_t> mask{2, 1, 0, 13, 42, 43};
13220 
13221   bindings_.allocate(indices)->getHandle<int64_t>() = {4, 42, 13, 0, 100, 13};
13222   bindings_.allocate(values)->getHandle<float>() = {-5.5, 0.7, 11, 1e6, 2, 3.5};
13223   bindings_.allocate(defaultValue)->getHandle<float>().raw(0) = 1.1;
13224   bindings_.allocate(lengths)->getHandle<int32_t>() = {4, 2};
13225 
13226   auto *R = F_->createSparseToDenseMask("STDM", indices, values, defaultValue,
13227                                         lengths, mask);
13228   auto *S = F_->createSave("save", R);
13229   bindings_.allocate(S->getPlaceholder());
13230 
13231   EE_.compile(CompilationMode::Infer);
13232   EE_.run(bindings_);
13233 
13234   Tensor &result = *bindings_.get(S->getPlaceholder());
13235   Tensor expected(ElemKind::FloatTy, {2, 6});
13236   expected.getHandle<float>() = {
13237       1.1, 1.1, 1e6, 11, 0.7, 1.1, 1.1, 1.1, 1.1, 3.5, 1.1, 1.1,
13238   };
13239 
13240   EXPECT_TRUE(expected.isEqual(result));
13241 }
13242 
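/// Test SparseToDenseMask with 2x2 value slices and a 2x2 default value.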
13243 TEST_P(OperatorTest, SparseToDenseMask2) {
13244   CHECK_IF_ENABLED();
13245 
13246   /*
13247     INDICES = [300, 100, 101, 299]
13248     VALUES = [[[-0.1, -0.2], [-0.3, -0.4]], [[2, -2], [2, 9]],
13249               [[15, 4.2], [10.3, 30.4]], [[0, 2], [3, 4.4]]]
13250     DEFAULTVALUE = [[0.1, 0.2], [0.3, 0.4]]
13251     LENGTHS = []
13252     MASK = [100, 300, 1]
13253     OUTPUT =  [[[2, -2], [2, 9]], [[-0.1, -0.2], [-0.3, -0.4]],
13254                [[0.1, 0.2], [0.3, 0.4]]]
13255   */
13256   auto *indices =
13257       mod_.createPlaceholder(ElemKind::Int64ITy, {4}, "indices", false);
13258   auto *values =
13259       mod_.createPlaceholder(ElemKind::FloatTy, {4, 2, 2}, "values", false);
13260   auto *defaultValue =
13261       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "default_value", false);
13262   auto *lengths =
13263       mod_.createPlaceholder(ElemKind::Int32ITy, {}, "lengths", false);
13264   std::vector<dim_t> mask{100, 300, 1};
13265 
13266   bindings_.allocate(indices)->getHandle<int64_t>() = {300, 100, 101, 299};
13267   bindings_.allocate(values)->getHandle<float>() = {
13268       -0.1, -0.2, -0.3, -0.4, 2, -2, 2, 9, 15, 4.2, 10.3, 30.4, 0, 2, 3, 4.4};
13269   bindings_.allocate(defaultValue)->getHandle<float>() = {0.1, 0.2, 0.3, 0.4};
13270   bindings_.allocate(lengths)->getHandle<int32_t>() = {4};
13271 
13272   auto *R = F_->createSparseToDenseMask("STDM", indices, values, defaultValue,
13273                                         lengths, mask);
13274   auto *S = F_->createSave("save", R);
13275   bindings_.allocate(S->getPlaceholder());
13276 
13277   EE_.compile(CompilationMode::Infer);
13278   EE_.run(bindings_);
13279 
13280   Tensor &result = *bindings_.get(S->getPlaceholder());
13281   Tensor expected(ElemKind::FloatTy, {3, 2, 2});
13282   expected.getHandle<float>() = {
13283       2, -2, 2, 9, -0.1, -0.2, -0.3, -0.4, 0.1, 0.2, 0.3, 0.4,
13284   };
13285 
13286   EXPECT_TRUE(expected.isEqual(result));
13287 }
13288 
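/// Verify that Reshape preserves element order for Float16Ty data.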
13289 TEST_P(OperatorTest, FP16Reshape) {
13290   CHECK_IF_ENABLED();
13291 
13292   auto *A = mod_.createPlaceholder(ElemKind::Float16Ty, {20, 13}, "A", false);
13293   auto inputHandle = bindings_.allocate(A)->getHandle<float16_t>();
13294   inputHandle.randomize(-3.0, 3.0, mod_.getPRNG());
13295 
13296   auto *tr = F_->createReshape("tr", A, {13, 20, 1});
13297   auto *result = F_->createSave("saveTranspose", tr);
13298   bindings_.allocate(result->getPlaceholder());
13299 
13300   EE_.compile(CompilationMode::Infer);
13301   EE_.run(bindings_);
13302 
13303   auto outputHandle =
13304       bindings_.get(result->getPlaceholder())->getHandle<float16_t>();
13305   ASSERT_EQ(outputHandle.size(), inputHandle.size());
13306   for (size_t idx = 0, end = inputHandle.size(); idx != end; ++idx) {
13307     EXPECT_EQ(inputHandle.raw(idx), outputHandle.raw(idx));
13308   }
13309 }
13310 
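/// Verify that Reshape preserves element order for BFloat16Ty data.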
13311 TEST_P(OperatorTest, BFloat16Reshape) {
13312   CHECK_IF_ENABLED();
13313 
13314   auto *A = mod_.createPlaceholder(ElemKind::BFloat16Ty, {20, 13}, "A", false);
13315   auto inputHandle = bindings_.allocate(A)->getHandle<bfloat16_t>();
13316   inputHandle.randomize(-3.0, 3.0, mod_.getPRNG());
13317 
13318   auto *tr = F_->createReshape("tr", A, {13, 20, 1});
13319   auto *result = F_->createSave("saveTranspose", tr);
13320   bindings_.allocate(result->getPlaceholder());
13321 
13322   EE_.compile(CompilationMode::Infer);
13323   EE_.run(bindings_);
13324 
13325   auto outputHandle =
13326       bindings_.get(result->getPlaceholder())->getHandle<bfloat16_t>();
13327   ASSERT_EQ(outputHandle.size(), inputHandle.size());
13328   for (size_t idx = 0, end = inputHandle.size(); idx != end; ++idx) {
13329     EXPECT_EQ(inputHandle.raw(idx), outputHandle.raw(idx));
13330   }
13331 }
13332 
13333 /// Verify that the Reshape operator works correctly.
13334 TEST_P(OperatorTest, Reshape) {
13335   CHECK_IF_ENABLED();
13336 
13337   auto *A = mod_.createPlaceholder(ElemKind::FloatTy, {5, 7}, "A", false);
13338   auto inputHandle = bindings_.allocate(A)->getHandle();
13339   inputHandle.randomize(-3.0, 3.0, mod_.getPRNG());
13340 
13341   auto *RN = F_->createReshape("reshape", A, {7, 5, 1});
13342   auto *result = F_->createSave("saveReshape", RN);
13343   bindings_.allocate(result->getPlaceholder());
13344 
13345   EE_.compile(CompilationMode::Infer);
13346   EE_.run(bindings_);
13347 
13348   auto outputHandle = bindings_.get(result->getPlaceholder())->getHandle();
13349   ASSERT_EQ(outputHandle.size(), inputHandle.size());
13350   ASSERT_EQ(outputHandle.dims().size(), 3);
13351   EXPECT_EQ(outputHandle.dims()[0], 7);
13352   EXPECT_EQ(outputHandle.dims()[1], 5);
13353   EXPECT_EQ(outputHandle.dims()[2], 1);
13354 
13355   // Check values are still in the same order.
13356   for (size_t idx = 0, end = inputHandle.size(); idx != end; ++idx) {
13357     EXPECT_EQ(inputHandle.raw(idx), outputHandle.raw(idx));
13358   }
13359 }
13360 
13361 /// Verify that the Reshape operator works correctly with Int64ITy.
13362 TEST_P(OperatorTest, ReshapeInt) {
13363   CHECK_IF_ENABLED();
13364 
13365   auto *A = mod_.createPlaceholder(ElemKind::Int64ITy, {5, 7}, "A", false);
13366   auto inputHandle = bindings_.allocate(A)->getHandle<int64_t>();
13367   inputHandle.randomize<int64_t>(0, 100, mod_.getPRNG());
13368 
13369   auto *RN = F_->createReshape("reshape", A, {7, 5, 1});
13370   auto *result = F_->createSave("saveReshape", RN);
13371   bindings_.allocate(result->getPlaceholder());
13372 
13373   EE_.compile(CompilationMode::Infer);
13374   EE_.run(bindings_);
13375 
13376   auto outputHandle =
13377       bindings_.get(result->getPlaceholder())->getHandle<int64_t>();
13378   ASSERT_EQ(outputHandle.size(), inputHandle.size());
13379   ASSERT_EQ(outputHandle.dims().size(), 3);
13380   EXPECT_EQ(outputHandle.dims()[0], 7);
13381   EXPECT_EQ(outputHandle.dims()[1], 5);
13382   EXPECT_EQ(outputHandle.dims()[2], 1);
13383 
13384   // Check values are still in the same order.
13385   for (size_t idx = 0, end = inputHandle.size(); idx != end; ++idx) {
13386     EXPECT_EQ(inputHandle.raw(idx), outputHandle.raw(idx));
13387   }
13388 }
13389 
13390 /// Verify that the Select operator works correctly.
13391 TEST_P(OperatorTest, Select) {
13392   CHECK_IF_ENABLED();
13393 
13394   auto *A = mod_.createPlaceholder(ElemKind::BoolTy, {5}, "A", false);
13395   bindings_.allocate(A)->getHandle<bool>() = {false, true, true, false, false};
13396 
13397   auto SNTy = mod_.uniqueType(ElemKind::FloatTy, {5});
13398   SplatNode *SN10 = F_->createSplat("zero", SNTy, 10.0);
13399   SplatNode *SN20 = F_->createSplat("zero", SNTy, 20.0);
13400 
13401   auto *SN = F_->createSelect("select", A, SN10, SN20);
13402   auto *result = F_->createSave("saveSelect", SN);
13403   bindings_.allocate(result->getPlaceholder());
13404 
13405   EE_.compile(CompilationMode::Infer);
13406   EE_.run(bindings_);
13407 
13408   auto resH = bindings_.get(result->getPlaceholder())->getHandle();
13409   EXPECT_EQ(resH.at({0}), 20.0);
13410   EXPECT_EQ(resH.at({1}), 10.0);
13411   EXPECT_EQ(resH.at({2}), 10.0);
13412   EXPECT_EQ(resH.at({3}), 20.0);
13413   EXPECT_EQ(resH.at({4}), 20.0);
13414 }
13415 
13416 /// Verify that the CmpLTE operator works correctly.
13417 TEST_P(OperatorTest, CmpLTE) {
13418   CHECK_IF_ENABLED();
13419 
13420   Placeholder *A = mod_.createPlaceholder(ElemKind::FloatTy, {5}, "A", false);
13421   Placeholder *B = mod_.createPlaceholder(ElemKind::FloatTy, {5}, "B", false);
13422   bindings_.allocate(A)->getHandle<float>() = {0.0, 1.0, 2.0, 3.0, 4.0};
13423   bindings_.allocate(B)->getHandle<float>() = {0.0, 1.1, 1.5, 10.1, -1.0};
13424 
13425   auto *CMPLTE = F_->createCmpLTE("select", A, B);
13426   auto *result = F_->createSave("saveCMPLTE", CMPLTE);
13427   Tensor *resultT = bindings_.allocate(result->getPlaceholder());
13428 
13429   EE_.compile(CompilationMode::Infer);
13430   EE_.run(bindings_);
13431 
13432   auto resH = resultT->getHandle<bool>();
13433   EXPECT_TRUE(resH.at({0}));
13434   EXPECT_TRUE(resH.at({1}));
13435   EXPECT_FALSE(resH.at({2}));
13436   EXPECT_TRUE(resH.at({3}));
13437   EXPECT_FALSE(resH.at({4}));
13438 }
13439 
13440 /// Helper to test SliceReshape using \p DTy.
13441 template <typename DataType>
13442 static void testSliceReshape(glow::PlaceholderBindings &bindings,
13443                              glow::Module &mod, glow::Function *F,
13444                              glow::ExecutionEngine &EE, ElemKind DTy) {
13445   auto *X =
13446       createPlaceholderConditionallyQuantized(mod, DTy, {3, 3}, "X", false);
13447 
13448   auto XH = bindings.allocate(X)->getHandle<DataType>();
13449   for (dim_t i = 0; i < 3; i++) {
13450     for (dim_t j = 0; j < 3; j++) {
13451       XH.at({i, j}) = i * 3 + j;
13452     }
13453   }
13454 
13455   // Do an assortment of slices/reshapes stacked on top of each other.
13456   auto *SX = F->createSlice("sliceX", X, {2, 0}, {3, 3});
13457   auto *RSX = F->createReshape("reshapeSX", SX, {3});
13458   auto *SSX = F->createSlice("sliceSliceX", SX, {0, 2}, {1, 3});
13459   auto *RSSX = F->createReshape("reshapeSliceSliceX", SSX, {1});
13460 
13461   auto *resultSX = F->createSave("saveSX", SX);
13462   auto *resultRSX = F->createSave("saveRSX", RSX);
13463   auto *resultSSX = F->createSave("saveSSX", SSX);
13464   auto *resultRSSX = F->createSave("saveRSSX", RSSX);
13465 
13466   bindings.allocate(resultSX->getPlaceholder());
13467   bindings.allocate(resultRSX->getPlaceholder());
13468   bindings.allocate(resultSSX->getPlaceholder());
13469   bindings.allocate(resultRSSX->getPlaceholder());
13470 
13471   EE.compile(CompilationMode::Infer);
13472 
13473   EE.run(bindings);
13474 
13475   // Verify the slice has the same data as the original X.
13476   auto SXH = bindings.get(resultSX->getPlaceholder())->getHandle<DataType>();
13477   for (dim_t i = 0; i < 3; i++) {
13478     EXPECT_NEAR(SXH.at({0, i}), XH.at({2, i}), 1E-5);
13479   }
13480 
13481   // Verify the reshaped slice has the same data as the slice.
13482   auto RSXH = bindings.get(resultRSX->getPlaceholder())->getHandle<DataType>();
13483   for (dim_t i = 0; i < 3; i++) {
13484     EXPECT_NEAR(SXH.at({0, i}), RSXH.at({i}), 1E-5);
13485   }
13486 
13487   // Verify the slice of the slice has the same data as the slice.
13488   auto SSXH = bindings.get(resultSSX->getPlaceholder())->getHandle<DataType>();
13489   EXPECT_NEAR(SXH.at({0, 2}), SSXH.at({0, 0}), 1E-5);
13490 
13491   // Verify the reshape of the slice of the slice has the same data as the
13492   // slice of the slice.
13493   auto RSSXH =
13494       bindings.get(resultRSSX->getPlaceholder())->getHandle<DataType>();
13495   EXPECT_NEAR(RSSXH.at({0}), SSXH.at({0, 0}), 1E-5);
13496 }
13497 
13498 /// Stack many slices/reshapes together. Some of these may be turned into
13499 /// tensor views stacked onto each other. Test in FloatTy.
13500 TEST_P(OperatorTest, sliceReshape_Float) {
13501   CHECK_IF_ENABLED();
13502 
13503   testSliceReshape<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
13504 }
13505 
13506 /// Stack many slices/reshapes together. Some of these may be turned into
13507 /// tensor views stacked onto each other. Test in Float16Ty.
13508 TEST_P(OperatorTest, sliceReshape_Float16) {
13509   CHECK_IF_ENABLED();
13510   testSliceReshape<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
13511 }
13512 
13513 /// Stack many slices/reshapes together. Some of these may be turned into
13514 /// tensor views stacked onto each other. Test in BFloat16Ty.
13515 TEST_P(OperatorTest, sliceReshape_BFloat16) {
13516   CHECK_IF_ENABLED();
13517   testSliceReshape<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty);
13518 }
13519 
13520 /// Stack many slices/reshapes together. Some of these may be turned into
13521 /// tensor views stacked onto each other. Test in Int8QTy.
13522 TEST_P(OperatorTest, sliceReshape_Int8) {
13523   CHECK_IF_ENABLED();
13524   testSliceReshape<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
13525 }
13526 
13527 /// Stack many slices/reshapes together. Some of these may be turned into
13528 /// tensor views stacked onto each other. Test in Int32QTy.
13529 TEST_P(OperatorTest, sliceReshape_Int32) {
13530   CHECK_IF_ENABLED();
13531   testSliceReshape<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32QTy);
13532 }
13533 
13534 /// Helper to test Flatten using \p DTy.
13535 template <typename DataType>
13536 static void testFlatten(glow::PlaceholderBindings &bindings, glow::Module &mod,
13537                         glow::Function *F, glow::ExecutionEngine &EE,
13538                         ElemKind DTy) {
13539   auto *tensor4D = createPlaceholderConditionallyQuantized(
13540       mod, DTy, {3, 2, 4, 3}, "4D", false, "NHWC");
13541   bindings.allocate(tensor4D)->getHandle<DataType>().randomize(0, 100,
13542                                                                mod.getPRNG());
13543 
13544   NodeValue reshape4Dto2DAxis1 = F->createFlatten("flat4Dto2Da1", tensor4D, 1);
13545   EXPECT_EQ(reshape4Dto2DAxis1.dims().size(), 2);
13546   EXPECT_EQ(reshape4Dto2DAxis1.dims()[0], 3);
13547   EXPECT_EQ(reshape4Dto2DAxis1.dims()[1], 24);
13548 
13549   NodeValue reshape4Dto2DAxis2 = F->createFlatten("flat4Dto2Da2", tensor4D, 2);
13550   EXPECT_EQ(reshape4Dto2DAxis2.dims().size(), 2);
13551   EXPECT_EQ(reshape4Dto2DAxis2.dims()[0], 6);
13552   EXPECT_EQ(reshape4Dto2DAxis2.dims()[1], 12);
13553 
13554   NodeValue reshape4Dto2DAxis3 = F->createFlatten("flat4Dto2Da3", tensor4D, 3);
13555   EXPECT_EQ(reshape4Dto2DAxis3.dims().size(), 2);
13556   EXPECT_EQ(reshape4Dto2DAxis3.dims()[0], 24);
13557   EXPECT_EQ(reshape4Dto2DAxis3.dims()[1], 3);
13558 
13559   // Now, let us do the fifth (4) axis.
13560   // This comes straight from caffe2 because flattening is
13561   // supported for every axis up to and including the rank of a tensor.
13562   // The rank of this tensor is 4, so axis 4 is fine.
13563   NodeValue reshape4Dto2DAxis4 = F->createFlatten("flat4Dto2Da4", tensor4D, 4);
13564   EXPECT_EQ(reshape4Dto2DAxis4.dims().size(), 2);
13565   EXPECT_EQ(reshape4Dto2DAxis4.dims()[0], 72);
13566   EXPECT_EQ(reshape4Dto2DAxis4.dims()[1], 1);
13567 
13568   // This one is weird because we flatten something that is already flat, but
13569   // again because flattening is supported for every axis up to and including
13570   // the rank of a tensor, a 1D vector means we can flatten it on axis 1.
13571   auto *tensor1D =
13572       createPlaceholderConditionallyQuantized(mod, DTy, {15}, "1D", false, "N");
13573   bindings.allocate(tensor1D)->getHandle<DataType>().randomize(0, 100,
13574                                                                mod.getPRNG());
13575 
13576   NodeValue reshape1Dto2DAxis1 = F->createFlatten("flat1Dto2D", tensor1D, 1);
13577   EXPECT_EQ(reshape1Dto2DAxis1.dims().size(), 2);
13578   EXPECT_EQ(reshape1Dto2DAxis1.dims()[0], 15);
13579   EXPECT_EQ(reshape1Dto2DAxis1.dims()[1], 1);
13580 
13581   // Save all the reshapes so that the optimizations won't kill the network.
13582   auto *save1Dto2D = F->createSave("save1Dto2D", reshape1Dto2DAxis1);
13583   auto *save4Dto2Da1 = F->createSave("save4Dto2Da1", reshape4Dto2DAxis1);
13584   auto *save4Dto2Da2 = F->createSave("save4Dto2Da2", reshape4Dto2DAxis2);
13585   auto *save4Dto2Da3 = F->createSave("save4Dto2Da3", reshape4Dto2DAxis3);
13586   auto *save4Dto2Da4 = F->createSave("save4Dto2Da4", reshape4Dto2DAxis4);
13587 
13588   bindings.allocate(save1Dto2D->getPlaceholder());
13589   bindings.allocate(save4Dto2Da1->getPlaceholder());
13590   bindings.allocate(save4Dto2Da2->getPlaceholder());
13591   bindings.allocate(save4Dto2Da3->getPlaceholder());
13592   bindings.allocate(save4Dto2Da4->getPlaceholder());
13593 
13594   EE.compile(CompilationMode::Infer);
13595 
13596   EE.run(bindings);
13597 
13598   // Verify the reshapes have the same data as the original value.
13599   auto tensor4DH = bindings.get(tensor4D)->getHandle<DataType>();
13600   auto save4Dto2Da1H =
13601       bindings.get(save4Dto2Da1->getPlaceholder())->getHandle<DataType>();
13602   for (size_t i = 0; i < 72; i++) {
13603     EXPECT_NEAR(tensor4DH.raw(i), save4Dto2Da1H.raw(i), 1E-5);
13604   }
13605 
13606   auto save4Dto2Da2H =
13607       bindings.get(save4Dto2Da2->getPlaceholder())->getHandle<DataType>();
13608   for (size_t i = 0; i < 72; i++) {
13609     EXPECT_NEAR(tensor4DH.raw(i), save4Dto2Da2H.raw(i), 1E-5);
13610   }
13611 
13612   auto save4Dto2Da3H =
13613       bindings.get(save4Dto2Da3->getPlaceholder())->getHandle<DataType>();
13614   for (size_t i = 0; i < 72; i++) {
13615     EXPECT_NEAR(tensor4DH.raw(i), save4Dto2Da3H.raw(i), 1E-5);
13616   }
13617 
13618   auto save4Dto2Da4H =
13619       bindings.get(save4Dto2Da4->getPlaceholder())->getHandle<DataType>();
13620   for (size_t i = 0; i < 72; i++) {
13621     EXPECT_NEAR(tensor4DH.raw(i), save4Dto2Da4H.raw(i), 1E-5);
13622   }
13623 
13624   auto tensor1DH = bindings.get(tensor1D)->getHandle<DataType>();
13625   auto save1Dto2DH =
13626       bindings.get(save1Dto2D->getPlaceholder())->getHandle<DataType>();
13627   for (size_t i = 0; i < 15; i++) {
13628     EXPECT_NEAR(tensor1DH.raw(i), save1Dto2DH.raw(i), 1E-5);
13629   }
13630 }
13631 
13632 /// Check that the flatten operator produces 2D tensors of the right
13633 /// dimensions, using FloatTy.
13634 TEST_P(OperatorTest, Flatten_FloatTy) {
13635   CHECK_IF_ENABLED();
13636   testFlatten<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
13637 }
13638 
13639 /// Check that the flatten operator produces 2D tensors of the right
13640 /// dimensions, using Float16Ty.
13641 TEST_P(OperatorTest, Flatten_Float16Ty) {
13642   CHECK_IF_ENABLED();
13643   testFlatten<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
13644 }
13645 
13646 /// Check that the flatten operator produces 2D tensors of the right
13647 /// dimensions, using BFloat16Ty.
13648 TEST_P(OperatorTest, Flatten_BFloat16Ty) {
13649   CHECK_IF_ENABLED();
13650   testFlatten<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty);
13651 }
13652 
13653 /// Check that the flatten operator produces 2D tensors of the right
13654 /// dimensions, using Int8QTy.
13655 TEST_P(OperatorTest, Flatten_Int8) {
13656   CHECK_IF_ENABLED();
13657   testFlatten<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
13658 }
13659 
13660 /// Check that div on Int64ITy/size_t works.
13661 TEST_P(OperatorTest, DivSizeT) {
13662   CHECK_IF_ENABLED();
13663 
13664   auto *LHS = mod_.createPlaceholder(ElemKind::Int64ITy, {3, 2}, "LHS", false);
13665   auto *RHS = mod_.createPlaceholder(ElemKind::Int64ITy, {3, 2}, "RHS", false);
13666   auto LHSH = bindings_.allocate(LHS)->getHandle<int64_t>();
13667   auto RHSH = bindings_.allocate(RHS)->getHandle<int64_t>();
13668 
13669   LHSH = {10, 20, 30, 40, 50, 60};
13670   RHSH = {2, 20, 100, 41, 3, 59};
13671 
13672   auto *R = F_->createDiv("div", LHS, RHS);
13673 
13674   auto *result = F_->createSave("save", R);
13675   bindings_.allocate(result->getPlaceholder());
13676 
13677   CompilationContext cctx;
13678   cctx.compMode = CompilationMode::Infer;
13679   // Disabling this so that division of Int64ITy/size_t can be tested.
13680   cctx.optimizationOpts.enableTypeDemotion = false;
13681   EE_.compile(cctx);
13682   EE_.run(bindings_);
13683 
13684   auto H = bindings_.get(result->getPlaceholder())->getHandle<int64_t>();
13685 
13686   for (dim_t i = 0; i < 3; i++) {
13687     for (dim_t j = 0; j < 2; j++) {
13688       EXPECT_EQ(LHSH.at({i, j}) / RHSH.at({i, j}), H.at({i, j}));
13689     }
13690   }
13691 }
13692 
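/// Verify SigmoidCrossEntropyWithLogits against precomputed expected losses.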
13693 TEST_P(OperatorTest, SigmoidCrossEntropyWithLogits) {
13694   CHECK_IF_ENABLED();
13695 
13696   /*
13697     LOGITS  = [
13698       [
13699         [1.0, 1.2, -0.5],
13700         [0.1, 0.6, 0.5],
13701       ],
13702       [
13703         [-0.1, -2., 0.3],
13704         [1, 2, 3],
13705       ],
13706     ]
13707     TARGETS = [
13708       [
13709         [0.7, 0.7, 0.7],
13710         [-0.7, -0.99, 1.0],
13711       ],
13712       [
13713         [0, 0, 0],
13714         [1, 2, 3],
13715       ],
13716     ]
13717     OUTPUT = [
13718       [ 0.68687367,  0.97332054],
13719       [ 0.5418933,  -2.50374103],
13720     ]
13721   */
13722   auto *logits =
13723       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 3}, "logits", false);
13724   auto *targets =
13725       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 3}, "targets", false);
13726 
13727   bindings_.allocate(logits)->getHandle() = {
13728       1.0f, 1.2f, -0.5f, 0.1f, 0.6f, 0.5f, -0.1f, -2.f, 0.3f, 1.f, 2.f, 3.f};
13729   bindings_.allocate(targets)->getHandle() = {
13730       0.7f, 0.7f, 0.7f, -0.7f, -0.99f, 1.0f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f};
13731 
13732   auto *R = F_->createSigmoidCrossEntropyWithLogits("SCEL", logits, targets);
13733 
13734   auto *result = F_->createSave("save", R);
13735   bindings_.allocate(result->getPlaceholder());
13736 
13737   EE_.compile(CompilationMode::Infer);
13738   EE_.run(bindings_);
13739 
13740   Tensor expected(ElemKind::FloatTy, {2, 2});
13741   expected.getHandle() = {
13742       0.68687367f,
13743       0.97332054f,
13744       0.5418933f,
13745       -2.50374103f,
13746   };
13747 
13748   EXPECT_TRUE(expected.isEqual(*bindings_.get(result->getPlaceholder())));
13749 }
13750 
13751 /// Test the InsertTensor node works correctly.
13752 TEST_P(OperatorTest, insertTensorTest) {
13753   CHECK_IF_ENABLED();
13754 
13755   // 0 0 0 0 0 0
13756   // 0 0 0 0 0 0
13757   // 0 0 0 0 0 0
13758   // 0 0 0 0 0 0
13759   auto *SN0 = mod_.createPlaceholder(ElemKind::FloatTy, {4, 6}, "SN0", false);
13760   bindings_.allocate(SN0)->init(Tensor::InitKind::Broadcast, 0, mod_.getPRNG());
13761 
13762   // 1 1
13763   // 1 1
13764   auto *SN1 = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "SN1", false);
13765   bindings_.allocate(SN1)->init(Tensor::InitKind::Broadcast, 1, mod_.getPRNG());
13766 
13767   // 0 0 0 0 0 0
13768   // 0 1 1 1 1 0
13769   // 0 1 1 1 1 0
13770   // 0 0 0 0 0 0
13771   Node *IN = F_->createInsertTensor("insert", SN0, SN1, /* start */ {1, 1},
13772                                     /* count */ 2, /* axis */ 1);
13773   SaveNode *result = F_->createSave("result", IN);
13774   bindings_.allocate(result->getPlaceholder());
13775 
13776   EE_.compile(CompilationMode::Infer);
13777 
13778   EE_.run(bindings_);
13779 
13780   // Verify the output looks as expected (pictured above).
13781   auto resultH = bindings_.get(result->getPlaceholder())->getHandle<float>();
13782   for (dim_t i = 0; i < 4; i++) {
13783     for (dim_t j = 0; j < 6; j++) {
13784       int64_t expected = 1;
13785       if (i == 0 || i == 3 || j == 0 || j == 5)
13786         expected = 0;
13787       EXPECT_EQ(resultH.at({i, j}), expected);
13788     }
13789   }
13790 }
13791 
13792 /// Test the InsertTensor node works correctly for 3 dimensions.
13793 TEST_P(OperatorTest, insertTensorTest3D) {
13794   CHECK_IF_ENABLED();
13795 
13796   // 0 0 0 0 0 0 | 0 0 0 0 0 0
13797   // 0 0 0 0 0 0 | 0 0 0 0 0 0
13798   // 0 0 0 0 0 0 | 0 0 0 0 0 0
13799   // 0 0 0 0 0 0 | 0 0 0 0 0 0
13800   auto *SN0 =
13801       mod_.createPlaceholder(ElemKind::FloatTy, {2, 4, 6}, "SN0", false);
13802   bindings_.allocate(SN0)->init(Tensor::InitKind::Broadcast, 0, mod_.getPRNG());
13803 
13804   // 1 1 | 1 1
13805   // 1 1 | 1 1
13806   auto *SN1 =
13807       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 2}, "SN1", false);
13808   bindings_.allocate(SN1)->init(Tensor::InitKind::Broadcast, 1, mod_.getPRNG());
13809 
13810   // 0 0 0 0 0 0 | 0 0 0 0 0 0
13811   // 0 1 1 1 1 0 | 0 1 1 1 1 0
13812   // 0 1 1 1 1 0 | 0 1 1 1 1 0
13813   // 0 0 0 0 0 0 | 0 0 0 0 0 0
13814   Node *IN = F_->createInsertTensor("insert", SN0, SN1, /* start */ {0, 1, 1},
13815                                     /* count */ 2, /* axis */ 2);
13816   SaveNode *result = F_->createSave("result", IN);
13817   bindings_.allocate(result->getPlaceholder());
13818 
13819   EE_.compile(CompilationMode::Infer);
13820 
13821   EE_.run(bindings_);
13822 
13823   // Verify the output looks as expected (pictured above).
13824   auto resultH = bindings_.get(result->getPlaceholder())->getHandle<float>();
13825   for (dim_t i = 0; i < 2; i++) {
13826     for (dim_t j = 0; j < 4; j++) {
13827       for (dim_t k = 0; k < 6; k++) {
13828         int64_t expected = 1;
13829         if (j == 0 || j == 3 || k == 0 || k == 5)
13830           expected = 0;
13831         EXPECT_EQ(resultH.at({i, j, k}), expected);
13832       }
13833     }
13834   }
13835 }
13836 
13837 /// Test that the InsertTensor operator works correctly when crossing outer
13838 /// dimensions.
13839 TEST_P(OperatorTest, insertTensorCrossDimensions) {
13840   CHECK_IF_ENABLED();
13841 
13842   // 0 0 0 0 0
13843   // 0 0 0 0 0
13844   // 0 0 0 0 0
13845   // 0 0 0 0 0
13846   // 0 0 0 0 0
13847   // 0 0 0 0 0
13848   auto *SN0 =
13849       mod_.createPlaceholder(ElemKind::FloatTy, {3, 2, 5}, "SN0", false);
13850   bindings_.allocate(SN0)->init(Tensor::InitKind::Broadcast, 0, mod_.getPRNG());
13851 
13852   // 1 1 1 1 1 1 (T)
13853   auto *SN1 =
13854       mod_.createPlaceholder(ElemKind::FloatTy, {3, 2, 1}, "SN1", false);
13855   bindings_.allocate(SN1)->init(Tensor::InitKind::Broadcast, 1, mod_.getPRNG());
13856 
13857   // 2 2 | 2 2
13858   // 2 2 | 2 2
13859   // 2 2 | 2 2
13860   auto *SN2 =
13861       mod_.createPlaceholder(ElemKind::FloatTy, {3, 2, 2}, "SN2", false);
13862   bindings_.allocate(SN2)->init(Tensor::InitKind::Broadcast, 2, mod_.getPRNG());
13863 
13864   // 1 0 2 2 0
13865   // 1 0 2 2 0
13866   // 1 0 2 2 0
13867   // 1 0 2 2 0
13868   // 1 0 2 2 0
13869   // 1 0 2 2 0
13870   Node *IN = F_->createInsertTensor("insert", SN0, SN1, /* start */ {0, 0, 0},
13871                                     /* count */ 1, /* axis */ 2);
13872   Node *IN2 = F_->createInsertTensor("insert", IN, SN2, /* start */ {0, 0, 2},
13873                                      /* count */ 1, /* axis */ 2);
13874   SaveNode *result = F_->createSave("result", IN2);
13875   bindings_.allocate(result->getPlaceholder());
13876 
13877   EE_.compile(CompilationMode::Infer);
13878 
13879   EE_.run(bindings_);
13880 
13881   // Verify the output looks as expected (pictured above).
13882   auto resultH = bindings_.get(result->getPlaceholder())->getHandle<float>();
13883   for (dim_t i = 0; i < 3; i++) {
13884     for (dim_t j = 0; j < 2; j++) {
13885       for (dim_t k = 0; k < 5; k++) {
13886         int64_t expected = 0;
13887         if (k == 0)
13888           expected = 1;
13889         if (k == 2 || k == 3)
13890           expected = 2;
13891         EXPECT_EQ(resultH.at({i, j, k}), expected);
13892       }
13893     }
13894   }
13895 }
13896 
13897 /// Test the InsertTensor operator works correctly when inserting across an
13898 /// outer dimension where the inner dimensions have different sizes.
13899 TEST_P(OperatorTest, insertTensorPartialSliceInnerDim) {
13900   CHECK_IF_ENABLED();
13901 
13902   // 0 0 0 0 0
13903   // 0 0 0 0 0
13904   // 0 0 0 0 0
13905   // 0 0 0 0 0
13906   // 0 0 0 0 0
13907   // 0 0 0 0 0
13908   // 0 0 0 0 0
13909   // 0 0 0 0 0
13910   // 0 0 0 0 0
13911   auto *SN0 =
13912       mod_.createPlaceholder(ElemKind::FloatTy, {3, 3, 5}, "SN0", false);
13913   bindings_.allocate(SN0)->init(Tensor::InitKind::Broadcast, 0, mod_.getPRNG());
13914 
13915   // 1 1
13916   // 1 1
13917   // 1 1
13918   auto *SN1 =
13919       mod_.createPlaceholder(ElemKind::FloatTy, {3, 1, 2}, "SN1", false);
13920   bindings_.allocate(SN1)->init(Tensor::InitKind::Broadcast, 1, mod_.getPRNG());
13921 
13922   // 2 2 2
13923   // 2 2 2
13924   // 2 2 2
13925   auto *SN2 =
13926       mod_.createPlaceholder(ElemKind::FloatTy, {3, 1, 3}, "SN2", false);
13927   bindings_.allocate(SN2)->init(Tensor::InitKind::Broadcast, 2, mod_.getPRNG());
13928 
13929   // 1 1 0 0 0
13930   // 0 2 2 2 0
13931   // 0 0 0 0 0
13932   // 1 1 0 0 0
13933   // 0 2 2 2 0
13934   // 0 0 0 0 0
13935   // 1 1 0 0 0
13936   // 0 2 2 2 0
13937   // 0 0 0 0 0
13938   Node *IN = F_->createInsertTensor("insert", SN0, SN1, /* start */ {0, 0, 0},
13939                                     /* count */ 1, /* axis */ 2);
13940   Node *IN2 = F_->createInsertTensor("insert", IN, SN2, /* start */ {0, 1, 1},
13941                                      /* count */ 1, /* axis */ 2);
13942   SaveNode *result = F_->createSave("result", IN2);
13943   bindings_.allocate(result->getPlaceholder());
13944 
13945   EE_.compile(CompilationMode::Infer);
13946 
13947   EE_.run(bindings_);
13948   // Verify the output looks as expected (pictured above).
13949   auto resultH = bindings_.get(result->getPlaceholder())->getHandle<float>();
13950   for (dim_t i = 0; i < 3; i++) {
13951     for (dim_t j = 0; j < 3; j++) {
13952       for (dim_t k = 0; k < 5; k++) {
13953         int64_t expected = 0;
13954         if (j == 0 && k <= 1)
13955           expected = 1;
13956         if (j == 1 && k >= 1 && k <= 3)
13957           expected = 2;
13958         EXPECT_EQ(resultH.at({i, j, k}), expected);
13959       }
13960     }
13961   }
13962 }
13963 
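/// Helper to create and initialize a small FC graph whose row-wise quantized
/// execution is compared against the floating-point interpreter reference.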
13964 static FunctionTensorPair
13965 createAndInitBasicRowwiseFCTest(glow::PlaceholderBindings &bindings,
13966                                 glow::ExecutionEngine &EE) {
13967   auto &mod = EE.getModule();
13968   Function *F = mod.createFunction("main");
13969 
13970   // In this test we subtract the outputs of a row-wise quantized FC and a
13971   // floating-point FC and ensure that the error is below some low value.
13972   auto *input = mod.createPlaceholder(ElemKind::FloatTy, {2, 100}, "in", false);
13973   auto *fc = F->createFullyConnected(bindings, "FC", input, 5);
13974 
13975   auto *weights = llvm::cast<Placeholder>(fc->getWeights());
13976   auto *bias = llvm::cast<Placeholder>(fc->getBias());
13977 
13978   bindings.allocate(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
13979   bindings.get(bias)->getHandle().randomize(0, 0.1, mod.getPRNG());
13980   bindings.get(weights)->getHandle().randomize(-1.1, 1.1, mod.getPRNG());
13981 
13982   auto *res = F->createSave("save", fc);
13983   ::glow::convertPlaceholdersToConstants(F, bindings,
13984                                          {input, res->getPlaceholder()});
13985   auto *resultTensor = bindings.allocate(res->getPlaceholder());
13986 
13987   return std::make_pair(F, resultTensor);
13988 }
13989 
13990 /// Test Int8 RowwiseQuantizedFullyConnected Node with Int8 bias.
13991 TEST_P(OperatorStatelessTest, rowwiseQuantizedFCTest_Int8_BiasInt8) {
13992   ENABLED_BACKENDS("Interpreter", "CPU");
13993   compareAgainstInterpreter(
13994       getBackendName(), createAndInitBasicRowwiseFCTest, ElemKind::FloatTy,
13995       ElemKind::Int8QTy, 0.06f, parCloneCountOpt,
13996       /* convertToRowwiseQuantization */ true, quantization::Schema::Asymmetric,
13997       ElemKind::Int8QTy);
13998 }
13999 
14000 /// Test Int8 RowwiseQuantizedFullyConnected Node with Int32 bias.
14001 TEST_P(OperatorStatelessTest, rowwiseQuantizedFCTest_Int8_BiasInt32) {
14002   ENABLED_BACKENDS("Interpreter", "CPU");
14003   compareAgainstInterpreter(
14004       getBackendName(), createAndInitBasicRowwiseFCTest, ElemKind::FloatTy,
14005       ElemKind::Int8QTy, 0.06f, parCloneCountOpt,
14006       /* convertToRowwiseQuantization */ true, quantization::Schema::Asymmetric,
14007       ElemKind::Int32QTy);
14008 }
14009 
14010 /// Test RowwiseQuantizedFullyConnected Node with Symmetric quantization.
14011 TEST_P(OperatorStatelessTest, rowwiseQuantizedFCTestSymmetric) {
14012   CHECK_IF_ENABLED();
14013   compareAgainstInterpreter(
14014       getBackendName(), createAndInitBasicRowwiseFCTest, ElemKind::FloatTy,
14015       ElemKind::Int8QTy, 0.07f, parCloneCountOpt,
14016       /* convertToRowwiseQuantization */ true, quantization::Schema::Symmetric);
14017 }
14018 
14019 TEST_P(OperatorStatelessTest,
14020        rowwiseQuantizedFCTestSymmetric_Int8_BiasFloat32) {
14021   CHECK_IF_ENABLED();
14022   compareAgainstInterpreter(
14023       getBackendName(), createAndInitBasicRowwiseFCTest, ElemKind::FloatTy,
14024       ElemKind::Int8QTy, 0.07f, parCloneCountOpt,
14025       /* convertToRowwiseQuantization */ true, quantization::Schema::Symmetric,
14026       /*biasElemKind*/ ElemKind::Int32QTy,
14027       /*forceFP16AccumSLS*/ false, PrecisionConfiguration::Float16Format::None,
14028       /*convertToChannelwiseQuantization*/ false,
14029       /*skipQuantizeFCBias*/ true);
14030 }
14031 
14032 TEST_P(OperatorStatelessTest,
14033        rowwiseQuantizedFCTestAsymmetric_Int8_BiasFloat32) {
14034   CHECK_IF_ENABLED();
14035   compareAgainstInterpreter(
14036       getBackendName(), createAndInitBasicRowwiseFCTest, ElemKind::FloatTy,
14037       ElemKind::Int8QTy, 0.06f, parCloneCountOpt,
14038       /* convertToRowwiseQuantization */ true, quantization::Schema::Asymmetric,
14039       /*biasElemKind*/ ElemKind::Int32QTy,
14040       /*forceFP16AccumSLS*/ false, PrecisionConfiguration::Float16Format::None,
14041       /*convertToChannelwiseQuantization*/ false,
14042       /*skipQuantizeFCBias*/ true);
14043 }
14044 
14045 static FunctionTensorPair
14046 createAndInitBasicSLWSTest(glow::PlaceholderBindings &bindings,
14047                            glow::ExecutionEngine &EE) {
14048   auto &mod = EE.getModule();
14049   Function *F = mod.createFunction("main");
14050 
14051   /*
14052     DATA  =   [2.0, -0.5, 13]
14053     WEIGHTS = [3, 1, 0, 0, 0, 0, 2, -0.5]
14054     INDICES = [1, 0, 2, 0, 1, 2, 2, 0]
14055     LENGTHS = [3, 0, 3, 2]
14056     OUTPUT =  [0.5, 0, 0, 25]
14057   */
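  // A quick sanity check of the expected values above: the first segment
  // (LENGTHS[0] == 3) gathers indices {1, 0, 2} with weights {3, 1, 0}, giving
  // 3 * (-0.5) + 1 * 2.0 + 0 * 13 = 0.5, and the last segment gathers indices
  // {2, 0} with weights {2, -0.5}, giving 2 * 13 + (-0.5) * 2.0 = 25.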
14058   auto *data = mod.createPlaceholder(ElemKind::FloatTy, {3}, "data", false);
14059   auto *weights =
14060       mod.createPlaceholder(ElemKind::FloatTy, {8}, "weights", false);
14061   auto *indices =
14062       mod.createPlaceholder(ElemKind::Int64ITy, {8}, "indices", false);
14063   auto *lengths =
14064       mod.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths", false);
14065 
14066   bindings.allocate(data)->getHandle() = {
14067       2.0,
14068       -0.5,
14069       13,
14070   };
14071   bindings.allocate(weights)->getHandle() = {
14072       3, 1, 0, 0, 0, 0, 2, -0.5,
14073   };
14074   bindings.allocate(indices)->getHandle<int64_t>() = {
14075       1, 0, 2, 0, 1, 2, 2, 0,
14076   };
14077   bindings.allocate(lengths)->getHandle<int32_t>() = {
14078       3,
14079       0,
14080       3,
14081       2,
14082   };
14083 
14084   auto *SLWS = F->createSparseLengthsWeightedSum("SLWS", data, weights, indices,
14085                                                  lengths);
14086   auto *res = F->createSave("save", SLWS);
14087   ::glow::convertPlaceholdersToConstants(
14088       F, bindings, {indices, lengths, res->getPlaceholder()});
14089   auto *resultTensor = bindings.allocate(res->getPlaceholder());
14090 
14091   return std::make_pair(F, resultTensor);
14092 }
14093 
14094 /// Test RowwiseQuantizedSLWS Node.
14095 TEST_P(OperatorStatelessTest, rowwiseQuantizedSLWSTest) {
14096   CHECK_IF_ENABLED();
14097   compareAgainstInterpreter(getBackendName(), createAndInitBasicSLWSTest,
14098                             ElemKind::FloatTy, ElemKind::Int8QTy, 0.01f,
14099                             parCloneCountOpt,
14100                             /* convertToRowwiseQuantization */ true);
14101 }
14102 
14103 static SaveNode *setupBucketNode(Function *F, PlaceholderBindings &bindings,
14104                                  Placeholder *input,
14105                                  const std::string &suffix) {
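  // Bucketize assigns each element the count of boundaries lying below it:
  // with boundaries {0.1, 2.5}, values under 0.1 land in bucket 0, values
  // between the boundaries in bucket 1, and larger values in bucket 2. The
  // behaviour for values exactly equal to a boundary is not exercised here.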
14106   std::vector<float> boundaries = {0.1, 2.5};
14107 
14108   auto *bucketize =
14109       F->createBucketizeNode("bucketize" + suffix, input, boundaries);
14110   auto *save = F->createSave("save" + suffix, bucketize);
14111   bindings.allocate(save->getPlaceholder());
14112   return save;
14113 }
14114 
14115 /// Check the correctness of the bucketize operator.
14116 TEST_P(OperatorTest, Bucketize) {
14117   CHECK_IF_ENABLED();
14118 
14119   auto *input1 =
14120       mod_.createPlaceholder(ElemKind::FloatTy, {3}, "input1", false);
14121   bindings_.allocate(input1)->getHandle<float>() = {2.0, 4.0, 1.0};
14122   auto *save1 =
14123       setupBucketNode(F_, bindings_, input1, /* suffix */ std::to_string(1));
14124 
14125   auto *input2 =
14126       mod_.createPlaceholder(ElemKind::FloatTy, {3, 2}, "input2", false);
14127   bindings_.allocate(input2)->getHandle<float>() = {2.0, 3.0, 4.0,
14128                                                     1.0, 2.0, 5.0};
14129   auto *save2 =
14130       setupBucketNode(F_, bindings_, input2, /* suffix */ std::to_string(2));
14131 
14132   EE_.compile(CompilationMode::Infer);
14133   EE_.run(bindings_);
14134 
14135   // Check the result of the first op:
14136   Tensor *result1 = bindings_.get(save1->getPlaceholder());
14137   Tensor expected1(ElemKind::Int32ITy, {3});
14138   expected1.getHandle<int32_t>() = {1, 2, 1};
14139   EXPECT_TRUE(expected1.isEqual(*result1));
14140 
14141   // Check the result of the second op:
14142   Tensor *result2 = bindings_.get(save2->getPlaceholder());
14143   Tensor expected2(ElemKind::Int32ITy, {3, 2});
14144   expected2.getHandle<int32_t>() = {1, 2, 2, 1, 1, 2};
14145   EXPECT_TRUE(expected2.isEqual(*result2));
14146 }
14147 
14148 /// Check the correctness of the SoftMax operator.
14149 /// The semantics of SoftMax are
14150 /// res_i = exp(input_i) / (exp(input_0) + ... + exp(input_N)).
14151 TEST_P(OperatorTest, SoftMax) {
14152   CHECK_IF_ENABLED();
14153 
14154   auto *input =
14155       mod_.createPlaceholder(ElemKind::FloatTy, {1, 6}, "input", false);
14156   bindings_.allocate(input)->getHandle<float>() = {1., 3., 2.5, 5., 4., 2.};
14157   auto *selected =
14158       mod_.createPlaceholder(ElemKind::Int64ITy, {1, 1}, "expected", false);
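  // Note: the "selected" labels feed only the SoftMax gradient during
  // training; they should not influence the inference result checked below.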
14159   auto *Pool = F_->createSoftMax("pool", input, selected);
14160   auto *S = F_->createSave("save", Pool);
14161   bindings_.allocate(S->getPlaceholder());
14162 
14163   EE_.compile(CompilationMode::Infer);
14164   EE_.run(bindings_);
14165 
14166   auto result = bindings_.get(S->getPlaceholder());
14167   Tensor out(ElemKind::FloatTy, {1, 6});
14168   // Expected results are:
14169   // sum = exp(input_0) + ... + exp(input_N) = ~245.387
14170   // res_0 = exp(1) / sum = ~0.011
14171   // res_1 = exp(3) / sum = ~0.082
14172   // And so on.
14173   out.getHandle<float>() = {0.011f, 0.082f, 0.05f, 0.605f, 0.222f, 0.03f};
14174   EXPECT_TRUE(out.isEqual(*result, 0.001));
14175 }
14176 
14177 /// Check that the softmax operator works properly with FP16.
14178 /// See the test that checks the SoftMax operator for more details.
14179 TEST_P(OperatorTest, FP16SoftMax) {
14180   CHECK_IF_ENABLED();
14181 
14182   auto *input =
14183       mod_.createPlaceholder(ElemKind::Float16Ty, {1, 6}, "input", false);
14184   bindings_.allocate(input)->getHandle<float16_t>() = {1., 3., 2.5, 5., 4., 2.};
14185   auto *selected =
14186       mod_.createPlaceholder(ElemKind::Int64ITy, {1, 1}, "expected", false);
14187   auto *Pool = F_->createSoftMax("pool", input, selected);
14188   auto *S = F_->createSave("save", Pool);
14189   bindings_.allocate(S->getPlaceholder());
14190 
14191   EE_.compile(CompilationMode::Infer);
14192   EE_.run(bindings_);
14193 
14194   auto result = bindings_.get(S->getPlaceholder());
14195   Tensor out(ElemKind::Float16Ty, {1, 6});
14196   out.getHandle<float16_t>() = {0.011f, 0.082f, 0.05f, 0.605f, 0.222f, 0.03f};
14197   EXPECT_TRUE(out.isEqual(*result, 0.001));
14198 }
14199 
14200 /// Check that the softmax operator works properly with BFloat16.
14201 /// See the test that checks the SoftMax operator for more details.
14202 TEST_P(OperatorTest, BFloat16SoftMax) {
14203   CHECK_IF_ENABLED();
14204 
14205   auto *input =
14206       mod_.createPlaceholder(ElemKind::BFloat16Ty, {1, 6}, "input", false);
14207   bindings_.allocate(input)->getHandle<bfloat16_t>() = {1., 3., 2.5,
14208                                                         5., 4., 2.};
14209   auto *selected =
14210       mod_.createPlaceholder(ElemKind::Int64ITy, {1, 1}, "expected", false);
14211   auto *Pool = F_->createSoftMax("pool", input, selected);
14212   auto *S = F_->createSave("save", Pool);
14213   bindings_.allocate(S->getPlaceholder());
14214 
14215   EE_.compile(CompilationMode::Infer);
14216   EE_.run(bindings_);
14217 
14218   auto result = bindings_.get(S->getPlaceholder());
14219   Tensor out(ElemKind::BFloat16Ty, {1, 6});
14220   out.getHandle<bfloat16_t>() = {0.011f, 0.082f, 0.05f, 0.605f, 0.222f, 0.03f};
14221   EXPECT_TRUE(out.isEqual(*result, 0.001));
14222 }
14223 
14224 /// Verify that Quantize, Rescale, Dequantize work correctly together.
14225 static void quantizeSimpleTest(glow::PlaceholderBindings &bindings_,
14226                                glow::Module &mod_, glow::Function *F_,
14227                                glow::ExecutionEngine &EE_, ElemKind QTy) {
14228   auto *input =
14229       mod_.createPlaceholder(ElemKind::FloatTy, {1, 1}, "input", true);
14230   bindings_.allocate(input)->init(Tensor::InitKind::Broadcast, 21,
14231                                   mod_.getPRNG());
14232 
14233   auto *Q =
14234       F_->createQuantize("quant", input, mod_.uniqueType(QTy, {1, 1}, 0.25, 4));
14235   auto *RS = F_->createRescaleQuantized("rescale", Q,
14236                                         mod_.uniqueType(QTy, {1, 1}, 0.5, 11));
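  // Rough expected arithmetic, assuming the usual affine mapping
  // q = round(x / scale) + offset: 21 quantizes to 21 / 0.25 + 4 = 88, the
  // rescale re-expresses it as (88 - 4) * 0.25 / 0.5 + 11 = 53, and
  // dequantizing gives (53 - 11) * 0.5 = 21, which the final check expects.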
14237   auto *D = F_->createDequantize("dequantize", RS, ElemKind::FloatTy);
14238   auto *save = F_->createSave("ret", D);
14239   auto *result = bindings_.allocate(save->getPlaceholder());
14240 
14241   EXPECT_EQ(F_->getNodes().size(), 4);
14242   EE_.compile(CompilationMode::Infer);
14243 
14244   EE_.run(bindings_);
14245   EXPECT_EQ(F_->getNodes().size(), 1);
14246 
14247   auto RH = result->getHandle();
14248   EXPECT_NEAR(RH.at({0, 0}), 21.0, 0.001);
14249 }
14250 
14251 TEST_P(OperatorTest, QuantizeSimpleInt8) {
14252   CHECK_IF_ENABLED();
14253   quantizeSimpleTest(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
14254 }
14255 TEST_P(OperatorTest, QuantizeSimpleInt16) {
14256   CHECK_IF_ENABLED();
14257   quantizeSimpleTest(bindings_, mod_, F_, EE_, ElemKind::Int16QTy);
14258 }
14259 TEST_P(OperatorTest, QuantizeSimpleInt32) {
14260   CHECK_IF_ENABLED();
14261   quantizeSimpleTest(bindings_, mod_, F_, EE_, ElemKind::Int32QTy);
14262 }
14263 
14264 TEST_P(OperatorTest, LengthsToRanges) {
14265   CHECK_IF_ENABLED();
14266 
14267   /*
14268     LENGTHS = [1, 3, 0, 2]
14269     OUTPUT =  [[0, 1], [1, 3], [4, 0], [4, 2]]
14270   */
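  // Each length becomes an [offset, length] pair whose offset is the running
  // sum of the preceding lengths: 0, 0 + 1, 1 + 3, and 4 + 0 above.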
14271   auto *lengths =
14272       mod_.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths", false);
14273 
14274   bindings_.allocate(lengths)->getHandle<int32_t>() = {1, 3, 0, 2};
14275 
14276   auto *R = F_->createLengthsToRanges("LTR", lengths);
14277   auto *S = F_->createSave("save", R);
14278   bindings_.allocate(S->getPlaceholder());
14279 
14280   EE_.compile(CompilationMode::Infer);
14281   EE_.run(bindings_);
14282 
14283   Tensor &result = *bindings_.get(S->getPlaceholder());
14284   Tensor expected(ElemKind::Int32ITy, {4, 2});
14285   expected.getHandle<int32_t>() = {
14286       0, 1, 1, 3, 4, 0, 4, 2,
14287   };
14288 
14289   EXPECT_TRUE(expected.isEqual(result));
14290 }
14291 
14292 /// Test that LengthsRangeFill works.
14293 TEST_P(OperatorTest, LengthsRangeFill) {
14294   CHECK_IF_ENABLED();
14295 
14296   /*
14297     LENGTHS = [4, 3, 1]
14298     OUTPUT =  [0, 1, 2, 3, 0, 1, 2, 0]
14299   */
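  // Each length L contributes the sequence 0 .. L-1; the sum of the lengths
  // (4 + 3 + 1) matches the requested maxOutputSize of 8 below.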
14300   auto *lengths =
14301       mod_.createPlaceholder(ElemKind::Int32ITy, {3}, "lengths", false);
14302 
14303   bindings_.allocate(lengths)->getHandle<int32_t>() = {4, 3, 1};
14304 
14305   auto *LRF = F_->createLengthsRangeFill("LRF", lengths, /* maxOutputSize */ 8);
14306   auto *S = F_->createSave("save", LRF);
14307   bindings_.allocate(S->getPlaceholder());
14308 
14309   EE_.compile(CompilationMode::Infer);
14310   EE_.run(bindings_);
14311 
14312   Tensor &result = *bindings_.get(S->getPlaceholder());
14313   Tensor expected(ElemKind::Int32ITy, {8});
14314   expected.getHandle<int32_t>() = {0, 1, 2, 3, 0, 1, 2, 0};
14315 
14316   EXPECT_TRUE(expected.isEqual(result));
14317 }
14318 
14319 /// Helper for testing BatchOneHot with different \p DTy.
14320 template <typename DataType>
14321 void batchOneHotTest(glow::PlaceholderBindings &bindings, glow::Module &mod,
14322                      glow::Function *F, glow::ExecutionEngine &EE,
14323                      ElemKind DTy) {
14324   /*
14325     DATA = [[5, 0], [11, 3], [0, 5]]
14326     LENGTHS = [4, 2]
14327     VALUES = [5, 0, 11, 0, 5, 0]
14328     OUTPUT =  [[1, 0, 0, 0, 0, 1], [0, 0, 1, 0, 0, 0], [0, 1, 0, 1, 1, 0]]
14329   */
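  // LENGTHS splits VALUES per input column: column 0 is matched against
  // {5, 0, 11, 0} and column 1 against {5, 0}; an output entry is 1 exactly
  // when the data element equals that candidate value. For example, row 0
  // ([5, 0]) matches 5 in column 0 and 0 in column 1, giving [1, 0, 0, 0, 0, 1].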
14330   auto *data =
14331       createPlaceholderConditionallyQuantized(mod, DTy, {3, 2}, "data", false);
14332   auto *lengths =
14333       mod.createPlaceholder(ElemKind::Int32ITy, {2}, "lengths", false, "N");
14334   auto *values = createPlaceholderConditionallyQuantized(mod, DTy, {6},
14335                                                          "values", false, "N");
14336 
14337   bindings.allocate(data)->getHandle<DataType>() = {5, 0, 11, 3, 0, 5};
14338   bindings.allocate(lengths)->getHandle<int32_t>() = {4, 2};
14339   bindings.allocate(values)->getHandle<DataType>() = {5, 0, 11, 0, 5, 0};
14340 
14341   auto *R = F->createBatchOneHot("BOH", data, lengths, values);
14342   auto *S = F->createSave("save", R);
14343   bindings.allocate(S->getPlaceholder());
14344 
14345   EE.compile(CompilationMode::Infer);
14346   EE.run(bindings);
14347 
14348   Tensor &result = *bindings.get(S->getPlaceholder());
14349   auto expected = createTensorConditionallyQuantized(DTy, {3, 6});
14350   expected.getHandle<DataType>() = {
14351       1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0,
14352   };
14353 
14354   EXPECT_TRUE(expected.isEqual(result));
14355 }
14356 
14357 /// Test BatchOneHot with Float data and Int32 Lengths.
14358 TEST_P(OperatorTest, BatchOneHotDataFloat) {
14359   CHECK_IF_ENABLED();
14360   batchOneHotTest<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
14361 }
14362 
14363 /// Test BatchOneHot with Float16 data and Int32 Lengths.
14364 TEST_P(OperatorTest, BatchOneHotDataFloat16) {
14365   CHECK_IF_ENABLED();
14366   batchOneHotTest<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
14367 }
14368 
14369 /// Test BatchOneHot with BFloat16 data and Int32 Lengths.
14370 TEST_P(OperatorTest, BatchOneHotDataBFloat16) {
14371   CHECK_IF_ENABLED();
14372   batchOneHotTest<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty);
14373 }
14374 
14375 /// Test BatchOneHot with Int64 data and Int32 Lengths.
14376 TEST_P(OperatorTest, BatchOneHotDataInt64) {
14377   CHECK_IF_ENABLED();
14378   batchOneHotTest<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy);
14379 }
14380 
14381 /// Test BatchOneHot with Int32 data and Int32 Lengths.
14382 TEST_P(OperatorTest, BatchOneHotDataInt32) {
14383   CHECK_IF_ENABLED();
14384   batchOneHotTest<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy);
14385 }
14386 
14387 /// Test BatchOneHot with Int8 data and Int32 Lengths.
14388 TEST_P(OperatorTest, BatchOneHotDataInt8) {
14389   CHECK_IF_ENABLED();
14390   batchOneHotTest<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
14391 }
14392 
14393 /// Modulo with Int64 Tensors with SignFollowDivisor off.
14394 TEST_P(OperatorTest, ModuloInt64NoSignFollow) {
14395   CHECK_IF_ENABLED();
14396 
14397   auto *src = mod_.createPlaceholder(ElemKind::Int64ITy, {3, 5}, "src", false);
14398   auto srcH = bindings_.allocate(src)->getHandle<int64_t>();
14399 
14400   srcH = {-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
14401 
14402   int64_t divisor = 3;
14403   bool signFollowDivisor = false;
14404 
14405   auto *modulo = F_->createModulo("mod", src, divisor, signFollowDivisor);
14406   auto *result = F_->createSave("save", modulo);
14407   bindings_.allocate(result->getPlaceholder());
14408 
14409   EE_.compile(CompilationMode::Infer);
14410   EE_.run(bindings_);
14411 
14412   auto resultH = bindings_.get(result->getPlaceholder())->getHandle<int64_t>();
14413 
14414   std::vector<int64_t> expectedResults = {-1, 0, -2, -1, 0, -2, -1, 0,
14415                                           1,  2, 0,  1,  2, 0,  1};
14416   ASSERT_EQ(expectedResults.size(), resultH.size());
14417 
14418   for (size_t i = 0, end = expectedResults.size(); i < end; ++i) {
14419     EXPECT_EQ(resultH.raw(i), expectedResults.at(i));
14420   }
14421 }
14422 
14423 /// Modulo with Int64 Tensors with SignFollowDivisor on.
14424 TEST_P(OperatorTest, ModuloInt64SignFollow) {
14425   CHECK_IF_ENABLED();
14426 
14427   auto *src = mod_.createPlaceholder(ElemKind::Int64ITy, {3, 5}, "src", false);
14428   auto srcH = bindings_.allocate(src)->getHandle<int64_t>();
14429 
14430   srcH = {-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
14431 
14432   int64_t divisor = 3;
14433   bool signFollowDivisor = true;
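  // With signFollowDivisor enabled the remainder takes the sign of the
  // divisor (as in Python's %), so -7 % 3 yields 2 here instead of the
  // truncated -1 produced in the NoSignFollow variant above.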
14434 
14435   auto *modulo = F_->createModulo("mod", src, divisor, signFollowDivisor);
14436   auto *result = F_->createSave("save", modulo);
14437   bindings_.allocate(result->getPlaceholder());
14438 
14439   EE_.compile(CompilationMode::Infer);
14440   EE_.run(bindings_);
14441 
14442   auto resultH = bindings_.get(result->getPlaceholder())->getHandle<int64_t>();
14443 
14444   std::vector<int64_t> expectedResults = {2, 0, 1, 2, 0, 1, 2, 0,
14445                                           1, 2, 0, 1, 2, 0, 1};
14446   ASSERT_EQ(expectedResults.size(), resultH.size());
14447 
14448   for (size_t i = 0, end = expectedResults.size(); i < end; ++i) {
14449     EXPECT_EQ(resultH.raw(i), expectedResults.at(i));
14450   }
14451 }
14452 
14453 /// Modulo with Int32 Tensors with SignFollowDivisor off.
14454 TEST_P(OperatorTest, ModuloInt32NoSignFollow) {
14455   CHECK_IF_ENABLED();
14456 #define TENSORTYPE int32_t
14457   auto *src = mod_.createPlaceholder(ElemKind::Int32ITy, {3, 5}, "src", false);
14458   auto srcH = bindings_.allocate(src)->getHandle<int32_t>();
14459 
14460   srcH = {-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
14461 
14462   int64_t divisor = 3;
14463   bool signFollowDivisor = false;
14464 
14465   auto *modulo = F_->createModulo("mod", src, divisor, signFollowDivisor);
14466   auto *result = F_->createSave("save", modulo);
14467   bindings_.allocate(result->getPlaceholder());
14468 
14469   EE_.compile(CompilationMode::Infer);
14470   EE_.run(bindings_);
14471 
14472   auto resultH = bindings_.get(result->getPlaceholder())->getHandle<int32_t>();
14473 
14474   std::vector<int32_t> expectedResults = {-1, 0, -2, -1, 0, -2, -1, 0,
14475                                           1,  2, 0,  1,  2, 0,  1};
14476   ASSERT_EQ(expectedResults.size(), resultH.size());
14477 
14478   for (size_t i = 0, end = expectedResults.size(); i < end; ++i) {
14479     EXPECT_EQ(resultH.raw(i), expectedResults.at(i));
14480   }
14481 }
14482 
14483 /// Modulo with Int32 Tensors with SignFollowDivisor on.
14484 TEST_P(OperatorTest, ModuloInt32SignFollow) {
14485   CHECK_IF_ENABLED();
14486 
14487   auto *src = mod_.createPlaceholder(ElemKind::Int32ITy, {3, 5}, "src", false);
14488   auto srcH = bindings_.allocate(src)->getHandle<int32_t>();
14489 
14490   srcH = {-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7};
14491 
14492   int64_t divisor = 3;
14493   bool signFollowDivisor = true;
14494 
14495   auto *modulo = F_->createModulo("mod", src, divisor, signFollowDivisor);
14496   auto *result = F_->createSave("save", modulo);
14497   bindings_.allocate(result->getPlaceholder());
14498 
14499   EE_.compile(CompilationMode::Infer);
14500   EE_.run(bindings_);
14501 
14502   auto resultH = bindings_.get(result->getPlaceholder())->getHandle<int32_t>();
14503 
14504   std::vector<int32_t> expectedResults = {2, 0, 1, 2, 0, 1, 2, 0,
14505                                           1, 2, 0, 1, 2, 0, 1};
14506   ASSERT_EQ(expectedResults.size(), resultH.size());
14507 
14508   for (size_t i = 0, end = expectedResults.size(); i < end; ++i) {
14509     EXPECT_EQ(resultH.raw(i), expectedResults.at(i));
14510   }
14511 }
14512 
14513 /// Helper to test DotProduct1D using \p DTy.
14514 template <typename DataType>
14515 static void testDotProduct1D(glow::PlaceholderBindings &bindings,
14516                              glow::Module &mod, glow::Function *F,
14517                              glow::ExecutionEngine &EE, ElemKind DTy) {
14518   // Input tensors.
14519   constexpr dim_t kDataSize = 10;
14520   auto *X = createPlaceholderConditionallyQuantized(mod, DTy, {kDataSize}, "X",
14521                                                     false, "N");
14522   auto *Y = createPlaceholderConditionallyQuantized(mod, DTy, {kDataSize}, "Y",
14523                                                     false, "N");
14524   auto XH = bindings.allocate(X)->getHandle<DataType>();
14525   auto YH = bindings.allocate(Y)->getHandle<DataType>();
14526 
14527   // Fill inputs with random values.
14528   XH.randomize(-10.0, 10.0, mod.getPRNG());
14529   YH.randomize(-10.0, 10.0, mod.getPRNG());
14530 
14531   // Compute expected output.
14532   auto expected = createTensorConditionallyQuantized(DTy, {kDataSize});
14533   auto expectedH = expected.getHandle<DataType>();
14534 
14535   for (dim_t i = 0; i < kDataSize; ++i) {
14536     expectedH.at({i}) = XH.at({i}) * YH.at({i});
14537   }
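  // Note: as the expected tensor above encodes, a DotProduct of 1D inputs
  // keeps the input shape and holds the per-element products.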
14538 
14539   // Compile and run the model.
14540   auto *dotProduct = F->createDotProduct("prod", X, Y);
14541   auto *result = F->createSave("save", dotProduct);
14542   bindings.allocate(result->getPlaceholder());
14543 
14544   EE.compile(CompilationMode::Infer);
14545   EE.run(bindings);
14546 
14547   auto actualH = bindings.get(result->getPlaceholder())->getHandle<DataType>();
14548 
14549   // Check that the output tensor is the same as the expected output.
14550   EXPECT_EQ(actualH.size(), expectedH.size());
14551   for (std::size_t i = 0; i < actualH.size(); ++i) {
14552     EXPECT_NEAR(actualH.raw(i), expectedH.raw(i), 0.00001);
14553   }
14554 }
14555 
14556 /// Test a DotProduct operator with 1D inputs, using FloatTy.
14557 TEST_P(OperatorTest, dotProduct1D_Float) {
14558   CHECK_IF_ENABLED();
14559   testDotProduct1D<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
14560 }
14561 
14562 /// Test a DotProduct operator with 1D inputs, using Float16Ty.
14563 TEST_P(OperatorTest, dotProduct1D_Float16) {
14564   CHECK_IF_ENABLED();
14565   testDotProduct1D<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
14566 }
14567 
14568 /// Test a DotProduct operator with 1D inputs, using BFloat16Ty.
14569 TEST_P(OperatorTest, dotProduct1D_BFloat16) {
14570   CHECK_IF_ENABLED();
14571   testDotProduct1D<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty);
14572 }
14573 
14574 /// Test a DotProduct operator with 1D inputs, using Int8Ty.
14575 TEST_P(OperatorTest, dotProduct1D_Int8) {
14576   CHECK_IF_ENABLED();
14577   testDotProduct1D<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
14578 }
14579 
14580 // Test a BatchedPairwiseDotProduct operator.
14581 TEST_P(OperatorTest, batchedPairwiseDotProduct) {
14582   CHECK_IF_ENABLED();
14583 
14584   // Input tensors.
14585   constexpr dim_t kBatchSize = 2;
14586   constexpr dim_t kVectorSize = 6;
14587 
14588   auto *W = createPlaceholderConditionallyQuantized(
14589       mod_, ElemKind::FloatTy, {kBatchSize, kVectorSize}, "W", false);
14590   auto *X = createPlaceholderConditionallyQuantized(
14591       mod_, ElemKind::FloatTy, {kBatchSize, kVectorSize}, "X", false);
14592   auto *Y = createPlaceholderConditionallyQuantized(
14593       mod_, ElemKind::FloatTy, {kBatchSize, kVectorSize}, "Y", false);
14594   auto *Z = createPlaceholderConditionallyQuantized(
14595       mod_, ElemKind::FloatTy, {kBatchSize, kVectorSize}, "Z", false);
14596   auto WH = bindings_.allocate(W)->getHandle();
14597   auto XH = bindings_.allocate(X)->getHandle();
14598   auto YH = bindings_.allocate(Y)->getHandle();
14599   auto ZH = bindings_.allocate(Z)->getHandle();
14600 
14601   // Fill inputs with fixed, easily verifiable values.
14602 
14603   WH = {1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2};
14604   XH = {2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3};
14605   YH = {3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4};
14606   ZH = {4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5};
14607 
14608   // Compute expected output.
14609   auto expected =
14610       createTensorConditionallyQuantized(ElemKind::FloatTy, {kBatchSize, 6});
14611   auto expectedH = expected.getHandle();
14612 
14613   expectedH = {12, 18, 36, 24, 48, 72, 36, 48, 72, 60, 90, 120};
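  // The values above are the dot products of every input pair, in the order
  // (W,X), (W,Y), (X,Y), (W,Z), (X,Z), (Y,Z). For batch 0 the vectors are
  // constant 1, 2, 3 and 4, so e.g. W.X = 6 * (1 * 2) = 12 and
  // Y.Z = 6 * (3 * 4) = 72.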
14614 
14615   // Compile and run the model.
14616   auto *pairwiseDotProduct =
14617       F_->createBatchedPairwiseDotProduct("prod", {W, X, Y, Z});
14618   auto *result = F_->createSave("save", pairwiseDotProduct);
14619   bindings_.allocate(result->getPlaceholder());
14620 
14621   EE_.compile(CompilationMode::Infer);
14622   EE_.run(bindings_);
14623 
14624   auto actualH = bindings_.get(result->getPlaceholder())->getHandle();
14625 
14626   // Check that the output tensor is the same as the expected output.
14627   EXPECT_TRUE(actualH.size() == expectedH.size());
14628   EXPECT_TRUE(actualH.getType().isEqual(expectedH.getType()));
14629   for (std::size_t i = 0; i < actualH.size(); ++i) {
14630     EXPECT_NEAR(actualH.raw(i), expectedH.raw(i), 0.00001);
14631   }
14632 }
14633 
14634 // Test an ElementwiseLinear operator with both axis = 0 and axis = 1
14635 // arguments.
14636 TEST_P(OperatorTest, elementwiseLinear) {
14637   CHECK_IF_ENABLED();
14638 
14639   constexpr dim_t kRows = 10;
14640   constexpr dim_t kCols = 20;
14641 
14642   // Create and allocate input placeholders.
14643   auto *X =
14644       mod_.createPlaceholder(ElemKind::FloatTy, {kCols, kRows}, "X", false);
14645   auto *w = mod_.createPlaceholder(ElemKind::FloatTy, {kCols}, "w", false);
14646   auto *b = mod_.createPlaceholder(ElemKind::FloatTy, {kCols}, "b", false);
14647 
14648   auto XH = bindings_.allocate(X)->getHandle();
14649   auto wH = bindings_.allocate(w)->getHandle();
14650   auto bH = bindings_.allocate(b)->getHandle();
14651 
14652   // Fill inputs with random values.
14653   XH.randomize(-3.0, 3.0, mod_.getPRNG());
14654   wH.randomize(-3.0, 3.0, mod_.getPRNG());
14655   bH.randomize(-3.0, 3.0, mod_.getPRNG());
14656 
14657   // Create two separate models to test behaviour when axis = 0 and axis = 1.
14658   // For the test with axis = 0, the 0th dimension of X, w, and b must match.
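  // With axis = 0 the op computes out[i][j] = X[i][j] * w[i] + b[i], i.e. w
  // and b are broadcast along dimension 1; the expected values recomputed at
  // the end of this test use exactly this formula.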
14659   auto *elementwiseLinearAxisZero =
14660       F_->createElementwiseLinear("elAxisZero", X, w, b, /*axis=*/0);
14661   auto *resultAxisZero =
14662       F_->createSave("saveAxisZero", elementwiseLinearAxisZero);
14663   bindings_.allocate(resultAxisZero->getPlaceholder());
14664 
14665   // For the test with axis = 1, the 1st dimension of X must match the 0th
14666   // dimension of w and b, so a transpose of X is needed.
14667   auto *XT = F_->createTranspose("XT", X, {1, 0});
14668   auto *elementwiseLinearAxisOne =
14669       F_->createElementwiseLinear("elAxisOne", XT, w, b, /*axis=*/1);
14670   auto *resultAxisOne = F_->createSave("saveAxisOne", elementwiseLinearAxisOne);
14671   bindings_.allocate(resultAxisOne->getPlaceholder());
14672 
14673   // Compile and run the model.
14674   EE_.compile(CompilationMode::Infer);
14675   EE_.run(bindings_);
14676 
14677   auto resAxisZeroH =
14678       bindings_.get(resultAxisZero->getPlaceholder())->getHandle();
14679   auto resAxisOneH =
14680       bindings_.get(resultAxisOne->getPlaceholder())->getHandle();
14681 
14682   // Results should be the same shape as X/XT.
14683   ASSERT_EQ(resAxisZeroH.size(), XH.size());
14684   ASSERT_EQ(resAxisOneH.size(), (XT->getResult().getType())->size());
14685 
14686   // Compute the expected output and check that the model outputs match.
14687   for (dim_t i = 0; i < resAxisZeroH.dims()[0]; ++i) {
14688     for (dim_t j = 0; j < resAxisZeroH.dims()[1]; ++j) {
14689       float expected = (XH.at({i, j}) * wH.at({i})) + bH.at({i});
14690       EXPECT_NEAR(resAxisZeroH.at({i, j}), expected, 0.00001);
14691       EXPECT_NEAR(resAxisOneH.at({j, i}), expected, 0.00001);
14692     }
14693   }
14694 }
14695 
14696 /// Helper to test DotProduct2D using \p DTy.
14697 template <typename DataType>
14698 static void testDotProduct2D(glow::PlaceholderBindings &bindings,
14699                              glow::Module &mod, glow::Function *F,
14700                              glow::ExecutionEngine &EE, ElemKind DTy) {
14701   // Input tensors.
14702   constexpr dim_t kRows = 10;
14703   constexpr dim_t kCols = 14;
14704   auto *X = createPlaceholderConditionallyQuantized(mod, DTy, {kRows, kCols},
14705                                                     "X", false);
14706   auto *Y = createPlaceholderConditionallyQuantized(mod, DTy, {kRows, kCols},
14707                                                     "Y", false);
14708   auto XH = bindings.allocate(X)->getHandle<DataType>();
14709   auto YH = bindings.allocate(Y)->getHandle<DataType>();
14710 
14711   // Fill inputs with random values.
14712   XH.randomize(-3.0, 3.0, mod.getPRNG());
14713   YH.randomize(-3.0, 3.0, mod.getPRNG());
14714 
14715   // Compute expected output.
14716   auto expected = createTensorConditionallyQuantized(DTy, {kRows});
14717   auto expectedH = expected.getHandle<DataType>();
14718 
14719   for (dim_t i = 0; i < kRows; ++i) {
14720     DataType dotProduct = 0.0f;
14721 
14722     // Compute dot product of the i-th row of X and Y.
14723     for (dim_t j = 0; j < kCols; ++j) {
14724       dotProduct += (XH.at({i, j}) * YH.at({i, j}));
14725     }
14726 
14727     expectedH.at({i}) = dotProduct;
14728   }
14729 
14730   // Compile and run the model.
14731   auto *dotProduct = F->createDotProduct("prod", X, Y);
14732   auto *result = F->createSave("save", dotProduct);
14733   bindings.allocate(result->getPlaceholder());
14734 
14735   EE.compile(CompilationMode::Infer);
14736   EE.run(bindings);
14737 
14738   auto actualH = bindings.get(result->getPlaceholder())->getHandle<DataType>();
14739 
14740   // Check that the output tensor is the same as the expected output.
14741   EXPECT_EQ(actualH.size(), expectedH.size());
14742   for (std::size_t i = 0; i < actualH.size(); ++i) {
14743     EXPECT_NEAR(actualH.raw(i), expectedH.raw(i), 0.00001);
14744   }
14745 }
14746 
14747 // Test a DotProduct operator with 2D inputs, using FloatTy.
14748 TEST_P(OperatorTest, dotProduct2D_Float) {
14749   CHECK_IF_ENABLED();
14750   testDotProduct2D<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
14751 }
14752 
14753 // Test a DotProduct operator with 2D inputs, using Float16Ty.
14754 TEST_P(OperatorTest, dotProduct2D_Float16) {
14755   CHECK_IF_ENABLED();
14756   testDotProduct2D<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
14757 }
14758 
14759 // Test a DotProduct operator with 2D inputs, using BFloat16Ty.
14760 TEST_P(OperatorTest, dotProduct2D_BFloat16) {
14761   CHECK_IF_ENABLED();
14762   testDotProduct2D<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty);
14763 }
14764 
14765 // Test a DotProduct operator with 2D inputs, using Int8QTy.
14766 TEST_P(OperatorTest, dotProduct2D_Int8) {
14767   CHECK_IF_ENABLED();
14768   testDotProduct2D<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
14769 }
14770 
14771 /// Helper to test BatchBoxCox using \p DTy.
14772 template <typename DataType>
14773 static void testBatchBoxCox(glow::PlaceholderBindings &bindings,
14774                             glow::Module &mod, glow::Function *F,
14775                             glow::ExecutionEngine &EE, ElemKind DTy,
14776                             float allowedError = 0.0001f, float maxRange = 5.0f,
14777                             float maxLambda2 = 2.0f) {
14778   // Input tensors.
14779   const dim_t kRows = 10;
14780   const dim_t kCols = 5;
14781   auto *data = mod.createPlaceholder(DTy, {kRows, kCols}, "data",
14782                                      /* isTrainable */ false);
14783   auto *lambda1 = mod.createPlaceholder(DTy, {kCols}, "lambda1",
14784                                         /* isTrainable */ false);
14785   auto *lambda2 = mod.createPlaceholder(DTy, {kCols}, "lambda2",
14786                                         /* isTrainable */ false);
14787   auto dataH = bindings.allocate(data)->getHandle<DataType>();
14788   auto lambda1H = bindings.allocate(lambda1)->getHandle<DataType>();
14789   auto lambda2H = bindings.allocate(lambda2)->getHandle<DataType>();
14790 
14791   // Fill inputs with random values.
14792   dataH.randomize(0.0, maxRange, mod.getPRNG());
14793   lambda1H.randomize(1.0, 2.0, mod.getPRNG());
14794   lambda2H.randomize(1.0, maxLambda2, mod.getPRNG());
14795 
14796   // Zero out every other element of lambda1 to test that case of the transform.
14797   for (dim_t i = 0; i < kCols; i += 2) {
14798     lambda1H.at({i}) = 0;
14799   }
14800 
14801   const float epsilon = std::is_same<float, DataType>::value
14802                             ? std::numeric_limits<float>::min()
14803                             : 1e-6f;
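  // The reference Box-Cox transform computed on the host below is
  // log(d + lambda2) when lambda1 == 0, and ((d + lambda2)^lambda1 - 1) / lambda1
  // otherwise, with the base clipped away from zero so log and pow stay finite.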
14804 
14805   // Construct the graph for the backend to run.
14806   auto *BBC = F->createBatchBoxCox("bbc", data, lambda1, lambda2, epsilon);
14807   auto *save = F->createSave("save", BBC);
14808   auto resultH =
14809       bindings.allocate(save->getPlaceholder())->getHandle<DataType>();
14810 
14811   // Compile and run the model, setting results in tensor backed by resultH.
14812   EE.compile(CompilationMode::Infer);
14813   EE.run(bindings);
14814 
14815   // Compute expected output here on the host to compare results.
14816   Tensor expected(DTy, {kRows, kCols});
14817   auto expectedH = expected.getHandle<DataType>();
14818 
14819   for (dim_t i = 0; i < kRows; ++i) {
14820     for (dim_t j = 0; j < kCols; ++j) {
14821       float d = dataH.at({i, j});
14822       float l1 = lambda1H.at({j});
14823       float l2 = lambda2H.at({j});
14824 
14825       // Compute elementwise Box-Cox transform.
14826       float tmp = std::max(d + l2, 1e-6f);
14827       if (l1 == 0) {
14828         // Clip argument to log and pow at 1e-6 to avoid saturation.
14829         expectedH.at({i, j}) = std::log(tmp);
14830       } else {
14831         expectedH.at({i, j}) = (std::pow(tmp, l1) - 1) / l1;
14832       }
14833     }
14834   }
14835 
14836   // Check that the output tensor is the same as the expected output.
14837   for (size_t i = 0; i < resultH.size(); ++i) {
14838     EXPECT_NEAR(resultH.raw(i), expectedH.raw(i), allowedError);
14839   }
14840 }
14841 
14842 /// Test that the BatchBoxCox operator works as expected in FloatTy.
14843 TEST_P(OperatorTest, BatchBoxCox_Float) {
14844   CHECK_IF_ENABLED();
14845   testBatchBoxCox<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, 0.001f);
14846 }
14847 
14848 /// Test that the BatchBoxCox operator works as expected in Float16Ty.
14849 TEST_P(OperatorTest, BatchBoxCox_Large_Float16) {
14850   CHECK_IF_ENABLED();
14851   testBatchBoxCox<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
14852                              0.032f, 5.0f);
14853 }
14854 TEST_P(OperatorTest, BatchBoxCox_Medium_Float16) {
14855   CHECK_IF_ENABLED();
14856   testBatchBoxCox<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
14857                              0.016f, 3.0f);
14858 }
14859 TEST_P(OperatorTest, BatchBoxCox_Small_Float16) {
14860   CHECK_IF_ENABLED();
14861   testBatchBoxCox<float16_t>(bindings_, mod_, F_, EE_, ElemKind::Float16Ty,
14862                              0.003f, 1.0f, 1.001f);
14863 }
14864 
14865 /// Test that the BatchBoxCox operator works as expected in BFloat16Ty.
14866 TEST_P(OperatorTest, BatchBoxCox_Large_BFloat16) {
14867   CHECK_IF_ENABLED();
14868   testBatchBoxCox<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
14869                               0.32f, 5.0f);
14870 }
14871 TEST_P(OperatorTest, BatchBoxCox_Medium_BFloat16) {
14872   CHECK_IF_ENABLED();
14873   testBatchBoxCox<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
14874                               0.16f, 3.0f);
14875 }
14876 TEST_P(OperatorTest, BatchBoxCox_Small_BFloat16) {
14877   CHECK_IF_ENABLED();
14878   testBatchBoxCox<bfloat16_t>(bindings_, mod_, F_, EE_, ElemKind::BFloat16Ty,
14879                               0.03f, 1.0f, 1.001f);
14880 }
14881 
14882 /// Test that Arithmetic ops work.
14883 #define TEST_ARITH_OP_FLOAT(OP_NAME_, OP_)                                     \
14884   TEST_P(OperatorTest, OP_NAME_##ArithFloatTest) {                             \
14885     CHECK_IF_ENABLED();                                                        \
14886     constexpr dim_t size = 50;                                                 \
14887     auto *A = mod_.createPlaceholder(ElemKind::FloatTy, {size}, "A", false);   \
14888     auto *B = mod_.createPlaceholder(ElemKind::FloatTy, {size}, "B", false);   \
14889     auto *AT = bindings_.allocate(A);                                          \
14890     auto *BT = bindings_.allocate(B);                                          \
14891     auto AH = AT->getHandle();                                                 \
14892     auto BH = BT->getHandle();                                                 \
14893     AH.randomize(-10.0f, 10.0f, mod_.getPRNG());                               \
14894     BH.randomize(0.01f, 10.0f, mod_.getPRNG());                                \
14895                                                                                \
14896     auto *N = F_->create##OP_NAME_("op", A, B);                                \
14897     auto *save = F_->createSave("save", N);                                    \
14898     auto resultH = bindings_.allocate(save->getPlaceholder())->getHandle();    \
14899                                                                                \
14900     EE_.compile(CompilationMode::Infer);                                       \
14901     EE_.run(bindings_);                                                        \
14902                                                                                \
14903     for (size_t i = 0; i < size; i++) {                                        \
14904       EXPECT_FLOAT_EQ(resultH.raw(i), OP_(AH.raw(i), BH.raw(i)));              \
14905     }                                                                          \
14906   }
14907 
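/// Instantiate the macro above once per binary arithmetic node, using a host
/// lambda as the reference implementation; B is drawn from [0.01, 10) so the
/// Div case never divides by zero.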
14908 TEST_ARITH_OP_FLOAT(Add, [](float a, float b) { return a + b; })
14909 TEST_ARITH_OP_FLOAT(Sub, [](float a, float b) { return a - b; })
14910 TEST_ARITH_OP_FLOAT(Mul, [](float a, float b) { return a * b; })
14911 TEST_ARITH_OP_FLOAT(Div, [](float a, float b) { return a / b; })
14912 TEST_ARITH_OP_FLOAT(Min, [](float a, float b) { return std::min(a, b); })
14913 TEST_ARITH_OP_FLOAT(Max, [](float a, float b) { return std::max(a, b); })
14914 
14915 /// Helper to test ConvertTo casting from \p STy to \p DTy.
14916 template <typename SourceType, typename DestType>
14917 static void testConvertTo(glow::PlaceholderBindings &bindings_,
14918                           glow::Module &mod_, glow::Function *F_,
14919                           glow::ExecutionEngine &EE_, ElemKind STy,
14920                           ElemKind DTy) {
14921   // Input tensor in source type.
14922   dim_t shape[] = {5, 3, 20};
14923   auto *data = mod_.createPlaceholder(STy, shape, "data",
14924                                       /* isTrainable */ false);
14925   auto dataH = bindings_.allocate(data)->getHandle<SourceType>();
14926   if (STy == ElemKind::BoolTy) {
14927     for (dim_t i = 0; i < dataH.size(); i++) {
14928       dataH.raw(i) = static_cast<bool>(i % 2 == 0);
14929     }
14930   } else {
14931     dataH.randomize(-1000, 1000, mod_.getPRNG());
14932   }
14933 
14934   // Construct the graph for the backend to run, converting to dest type.
14935   auto OT = mod_.uniqueType(DTy, shape);
14936   auto *convert = F_->createConvertTo("convert", data, OT);
14937   auto *save = F_->createSave("save", convert);
14938   auto resultH =
14939       bindings_.allocate(save->getPlaceholder())->getHandle<DestType>();
14940 
14941   // Compile and run the model, setting results in tensor backed by resultH.
14942   EE_.compile(CompilationMode::Infer);
14943   EE_.run(bindings_);
14944 
14945   // Compute expected output here on the host to compare results.
14946   Tensor expected(DTy, shape);
14947   auto expectedH = expected.getHandle<DestType>();
14948   for (size_t i = 0, e = expectedH.size(); i < e; ++i) {
14949     expectedH.raw(i) = static_cast<DestType>(dataH.raw(i));
14950   }
14951 
14952   // Check that the output tensor is the same as the expected output.
14953   for (size_t i = 0, e = resultH.size(); i < e; i++) {
14954     const DestType exp = expectedH.raw(i);
14955     const DestType res = resultH.raw(i);
14956     if (DTy == ElemKind::FloatTy) {
14957       EXPECT_FLOAT_EQ(exp, res);
14958     } else {
14959       EXPECT_EQ(exp, res);
14960     }
14961   }
14962 }
14963 
14964 /// Test that ConvertTo operator casts correctly from one type to another.
14965 #define TEST_CONVERT_TO(T_FROM, T_TO, DTY_FROM, DTY_TO)                        \
14966   TEST_P(OperatorTest, ConvertFrom_##DTY_FROM##_To_##DTY_TO) {                 \
14967     CHECK_IF_ENABLED();                                                        \
14968     testConvertTo<T_FROM, T_TO>(bindings_, mod_, F_, EE_, ElemKind::DTY_FROM,  \
14969                                 ElemKind::DTY_TO);                             \
14970   }
14971 TEST_CONVERT_TO(float, float, FloatTy, FloatTy)
14972 TEST_CONVERT_TO(float, float16_t, FloatTy, Float16Ty)
14973 TEST_CONVERT_TO(float, bfloat16_t, FloatTy, BFloat16Ty)
14974 TEST_CONVERT_TO(float, int32_t, FloatTy, Int32ITy)
14975 TEST_CONVERT_TO(float, int64_t, FloatTy, Int64ITy)
14976 TEST_CONVERT_TO(float, bool, FloatTy, BoolTy)
14977 TEST_CONVERT_TO(float16_t, float, Float16Ty, FloatTy)
14978 TEST_CONVERT_TO(float16_t, float16_t, Float16Ty, Float16Ty)
14979 TEST_CONVERT_TO(float16_t, bfloat16_t, Float16Ty, BFloat16Ty)
14980 TEST_CONVERT_TO(float16_t, int32_t, Float16Ty, Int32ITy)
14981 TEST_CONVERT_TO(float16_t, int64_t, Float16Ty, Int64ITy)
14982 TEST_CONVERT_TO(bfloat16_t, float, BFloat16Ty, FloatTy)
14983 TEST_CONVERT_TO(bfloat16_t, float16_t, BFloat16Ty, Float16Ty)
14984 TEST_CONVERT_TO(bfloat16_t, bfloat16_t, BFloat16Ty, BFloat16Ty)
14985 TEST_CONVERT_TO(bfloat16_t, int32_t, BFloat16Ty, Int32ITy)
14986 TEST_CONVERT_TO(bfloat16_t, int64_t, BFloat16Ty, Int64ITy)
14987 TEST_CONVERT_TO(int32_t, float, Int32ITy, FloatTy)
14988 TEST_CONVERT_TO(int32_t, float16_t, Int32ITy, Float16Ty)
14989 TEST_CONVERT_TO(int32_t, bfloat16_t, Int32ITy, BFloat16Ty)
14990 TEST_CONVERT_TO(int32_t, int32_t, Int32ITy, Int32ITy)
14991 TEST_CONVERT_TO(int32_t, int64_t, Int32ITy, Int64ITy)
14992 TEST_CONVERT_TO(int64_t, float, Int64ITy, FloatTy)
14993 TEST_CONVERT_TO(int64_t, float16_t, Int64ITy, Float16Ty)
14994 TEST_CONVERT_TO(int64_t, bfloat16_t, Int64ITy, BFloat16Ty)
14995 TEST_CONVERT_TO(int64_t, int32_t, Int64ITy, Int32ITy)
14996 TEST_CONVERT_TO(int64_t, int64_t, Int64ITy, Int64ITy)
14997 TEST_CONVERT_TO(bool, float, BoolTy, FloatTy)
14998 TEST_CONVERT_TO(bool, float16_t, BoolTy, Float16Ty)
14999 TEST_CONVERT_TO(bool, bfloat16_t, BoolTy, BFloat16Ty)
15000 
15001 #undef TEST_CONVERT_TO
15002 
15003 /// Helper to test ConvertTo casting from \p STy to \p DTy and back.
15004 template <typename SourceType, typename DestType>
15005 static void testConvertToAndBack(glow::PlaceholderBindings &bindings_,
15006                                  glow::Module &mod_, glow::Function *F_,
15007                                  glow::ExecutionEngine &EE_, ElemKind STy,
15008                                  ElemKind DTy, bool castIsNoOp) {
15009   // Input tensor in source type.
15010   dim_t shape[] = {5, 3, 20};
15011   auto *data = mod_.createPlaceholder(STy, shape, "data",
15012                                       /* isTrainable */ false);
15013   auto dataH = bindings_.allocate(data)->getHandle<SourceType>();
15014   dataH.randomize(-1000, 1000, mod_.getPRNG());
15015 
15016   // Construct the graph for the backend to run, converting to dest type and
15017   // back.
15018   auto IT = mod_.uniqueType(STy, shape);
15019   auto OT = mod_.uniqueType(DTy, shape);
15020   auto *convert = F_->createConvertTo("convert_forth", data, OT);
15021   auto *convertBack = F_->createConvertTo("convert_back", convert, IT);
15022   auto *save = F_->createSave("save", convertBack);
15023   auto resultH =
15024       bindings_.allocate(save->getPlaceholder())->getHandle<SourceType>();
15025 
15026   // Compile and run the model, setting results in tensor backed by resultH.
15027   EXPECT_EQ(F_->getNodes().size(), 3);
15028   EE_.compile(CompilationMode::Infer);
15029   EE_.run(bindings_);
15030   EXPECT_EQ(F_->getNodes().size(), size_t(castIsNoOp ? 1 : 3));
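  // When the round trip cannot lose information, the graph optimizer is
  // expected to fold both ConvertTo nodes away, leaving only the SaveNode;
  // otherwise all three nodes must survive so the numeric effect is preserved.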
15031 
15032   for (size_t i = 0, e = resultH.size(); i < e; i++) {
15033     const SourceType res = resultH.raw(i);
15034     const SourceType expected =
15035         static_cast<SourceType>(static_cast<DestType>(dataH.raw(i)));
15036     EXPECT_EQ(res, expected);
15037   }
15038 }
15039 
15040 /// Test that ConvertTo operator casts correctly from one type to another.
15041 #define TEST_CAST_2WAYS(T_FROM, T_TO, DTY_FROM, DTY_TO, NOOP_CAST)             \
15042   TEST_P(OperatorTest, ConvertFrom_##DTY_FROM##_To_##DTY_TO##_AndBack) {       \
15043     CHECK_IF_ENABLED();                                                        \
15044     testConvertToAndBack<T_FROM, T_TO>(bindings_, mod_, F_, EE_,               \
15045                                        ElemKind::DTY_FROM, ElemKind::DTY_TO,   \
15046                                        NOOP_CAST);                             \
15047   }
15048 TEST_CAST_2WAYS(float, float, FloatTy, FloatTy, /* castIsNoOp */ true)
15049 TEST_CAST_2WAYS(float, float16_t, FloatTy, Float16Ty, /* castIsNoOp */ false)
15050 // FIXME: Should this test succeed?
15051 TEST_CAST_2WAYS(float, bfloat16_t, FloatTy, BFloat16Ty, /* castIsNoOp */ false)
15052 TEST_CAST_2WAYS(float, int32_t, FloatTy, Int32ITy, /* castIsNoOp */ false)
15053 TEST_CAST_2WAYS(float, int64_t, FloatTy, Int64ITy, /* castIsNoOp */ false)
15054 TEST_CAST_2WAYS(float16_t, float, Float16Ty, FloatTy, /* castIsNoOp */ true)
15055 TEST_CAST_2WAYS(float16_t, float16_t, Float16Ty, Float16Ty,
15056                 /* castIsNoOp */ true)
15057 TEST_CAST_2WAYS(float16_t, bfloat16_t, Float16Ty, BFloat16Ty,
15058                 /* castIsNoOp */ false)
15059 TEST_CAST_2WAYS(float16_t, int32_t, Float16Ty, Int32ITy,
15060                 /* castIsNoOp */ false)
15061 TEST_CAST_2WAYS(float16_t, int64_t, Float16Ty, Int64ITy,
15062                 /* castIsNoOp */ false)
15063 TEST_CAST_2WAYS(bfloat16_t, float, BFloat16Ty, FloatTy, /* castIsNoOp */ true)
15064 TEST_CAST_2WAYS(bfloat16_t, float16_t, BFloat16Ty, Float16Ty,
15065                 /* castIsNoOp */ true)
15066 TEST_CAST_2WAYS(bfloat16_t, bfloat16_t, BFloat16Ty, BFloat16Ty,
15067                 /* castIsNoOp */ true)
15068 TEST_CAST_2WAYS(bfloat16_t, int32_t, BFloat16Ty, Int32ITy,
15069                 /* castIsNoOp */ false)
15070 TEST_CAST_2WAYS(bfloat16_t, int64_t, BFloat16Ty, Int64ITy,
15071                 /* castIsNoOp */ false)
15072 TEST_CAST_2WAYS(int32_t, float, Int32ITy, FloatTy, /* castIsNoOp */ false)
15073 TEST_CAST_2WAYS(int32_t, float16_t, Int32ITy, Float16Ty,
15074                 /* castIsNoOp */ false)
15075 TEST_CAST_2WAYS(int32_t, bfloat16_t, Int32ITy, BFloat16Ty,
15076                 /* castIsNoOp */ false)
15077 TEST_CAST_2WAYS(int32_t, int32_t, Int32ITy, Int32ITy, /* castIsNoOp */ true)
15078 TEST_CAST_2WAYS(int32_t, int64_t, Int32ITy, Int64ITy, /* castIsNoOp */ true)
15079 TEST_CAST_2WAYS(int64_t, float, Int64ITy, FloatTy, /* castIsNoOp */ false)
15080 TEST_CAST_2WAYS(int64_t, float16_t, Int64ITy, Float16Ty,
15081                 /* castIsNoOp */ false)
15082 TEST_CAST_2WAYS(int64_t, bfloat16_t, Int64ITy, BFloat16Ty,
15083                 /* castIsNoOp */ false)
15084 TEST_CAST_2WAYS(int64_t, int32_t, Int64ITy, Int32ITy, /* castIsNoOp */ false)
15085 TEST_CAST_2WAYS(int64_t, int64_t, Int64ITy, Int64ITy, /* castIsNoOp */ true)
15086 
15087 #undef TEST_CAST_2WAYS
15088 
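/// Helper that builds a single Mul node over \p lhsValues and \p rhsValues of
/// element type \p DTy, runs it, and returns a handle to the saved result.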
15089 template <typename DataType>
15090 glow::Handle<DataType>
15091 mulHelper(glow::PlaceholderBindings &bindings, glow::Module &mod,
15092           glow::Function *F, glow::ExecutionEngine &EE, ElemKind DTy,
15093           llvm::ArrayRef<DataType> lhsValues,
15094           llvm::ArrayRef<DataType> rhsValues, llvm::ArrayRef<dim_t> lhsDims,
15095           llvm::ArrayRef<dim_t> rhsDims) {
15096   auto *lhs = mod.createPlaceholder(DTy, lhsDims, "lhs", false);
15097   auto *rhs = mod.createPlaceholder(DTy, rhsDims, "rhs", false);
15098   bindings.allocate(lhs)->getHandle<DataType>() = lhsValues;
15099   bindings.allocate(rhs)->getHandle<DataType>() = rhsValues;
15100 
15101   auto *N = F->createMul("Mul", lhs, rhs);
15102   auto *save = F->createSave("save", N);
15103   auto *saveTensor = bindings.allocate(save->getPlaceholder());
15104 
15105   EE.compile(CompilationMode::Infer);
15106   EE.run(bindings);
15107 
15108   return saveTensor->getHandle<DataType>();
15109 }
15110 
15111 /// Check that the Mul operator behaves correctly with int32.
15112 TEST_P(OperatorTest, mul_int32) {
15113   CHECK_IF_ENABLED();
15114 
15115   llvm::SmallVector<int32_t, 16> xValues = {
15116       3, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
15117 
15118       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
15119 
15120       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
15121 
15122       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1};
15123 
15124   llvm::SmallVector<int32_t, 16> yValues = {
15125       3, 4, 5, 7, 2, 5, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
15126 
15127       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
15128 
15129       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
15130 
15131       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6};
15132 
15133   llvm::SmallVector<dim_t, 4> xDims = {2, 2, 4, 4};
15134   llvm::SmallVector<dim_t, 4> yDims = {2, 2, 4, 4};
15135 
15136   Handle<int32_t> saveH =
15137       mulHelper<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy, xValues,
15138                          yValues, xDims, yDims);
15139 
15140   int counter = 0;
15141   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
15142     for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
15143       for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
15144         for (dim_t f = 0; f < saveH.dims()[3]; ++f) {
15145           EXPECT_EQ(xValues[counter] * yValues[counter],
15146                     saveH.at({i, j, k, f}));
15147           ++counter;
15148         }
15149       }
15150     }
15151   }
15152 }
15153 
15154 /// Check that the Mul operator behaves correctly with int64.
15155 TEST_P(OperatorTest, mul_int64) {
15156   CHECK_IF_ENABLED();
15157 
15158   llvm::SmallVector<int64_t, 16> xValues = {
15159       3, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
15160 
15161       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
15162 
15163       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
15164 
15165       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1};
15166 
15167   llvm::SmallVector<int64_t, 16> yValues = {
15168       3, 4, 5, 7, 2, 5, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
15169 
15170       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
15171 
15172       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
15173 
15174       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6};
15175 
15176   llvm::SmallVector<dim_t, 4> xDims = {2, 2, 4, 4};
15177   llvm::SmallVector<dim_t, 4> yDims = {2, 2, 4, 4};
15178 
15179   Handle<int64_t> saveH =
15180       mulHelper<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy, xValues,
15181                          yValues, xDims, yDims);
15182 
15183   int counter = 0;
15184   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
15185     for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
15186       for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
15187         for (dim_t f = 0; f < saveH.dims()[3]; ++f) {
15188           EXPECT_EQ(xValues[counter] * yValues[counter],
15189                     saveH.at({i, j, k, f}));
15190           ++counter;
15191         }
15192       }
15193     }
15194   }
15195 }
15196 /// Check that the Mul operator behaves correctly with float.
15197 TEST_P(OperatorTest, mul_float) {
15198   CHECK_IF_ENABLED();
15199 
15200   llvm::SmallVector<float, 16> xValues = {
15201       3, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
15202 
15203       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
15204 
15205       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,
15206 
15207       1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1};
15208 
15209   llvm::SmallVector<float, 16> yValues = {
15210       3, 4, 5, 7, 2, 5, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
15211 
15212       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
15213 
15214       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,
15215 
15216       3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6};
15217 
15218   llvm::SmallVector<dim_t, 4> xDims = {2, 2, 4, 4};
15219   llvm::SmallVector<dim_t, 4> yDims = {2, 2, 4, 4};
15220 
15221   Handle<float> saveH =
15222       mulHelper<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, xValues,
15223                        yValues, xDims, yDims);
15224 
15225   int counter = 0;
15226   for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
15227     for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
15228       for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
15229         for (dim_t f = 0; f < saveH.dims()[3]; ++f) {
15230           EXPECT_FLOAT_EQ(xValues[counter] * yValues[counter],
15231                           saveH.at({i, j, k, f}));
15232           ++counter;
15233         }
15234       }
15235     }
15236   }
15237 }
15238 
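/// Helper to create and run an Add node with element type \p DTy, given
/// input values \p lhsValues / \p rhsValues with dims \p lhsDims / \p rhsDims.
/// Returns a Handle to the saved result tensor.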
template <typename DataType>
glow::Handle<DataType>
addHelper(glow::PlaceholderBindings &bindings, glow::Module &mod,
          glow::Function *F, glow::ExecutionEngine &EE, ElemKind DTy,
          llvm::ArrayRef<DataType> lhsValues,
          llvm::ArrayRef<DataType> rhsValues, llvm::ArrayRef<dim_t> lhsDims,
          llvm::ArrayRef<dim_t> rhsDims) {
  auto *lhs = mod.createPlaceholder(DTy, lhsDims, "lhs", false);
  auto *rhs = mod.createPlaceholder(DTy, rhsDims, "rhs", false);
  bindings.allocate(lhs)->getHandle<DataType>() = lhsValues;
  bindings.allocate(rhs)->getHandle<DataType>() = rhsValues;

  auto *N = F->createAdd("Add", lhs, rhs);
  auto *save = F->createSave("save", N);
  auto *saveTensor = bindings.allocate(save->getPlaceholder());

  EE.compile(CompilationMode::Infer);
  EE.run(bindings);

  return saveTensor->getHandle<DataType>();
}

/// Check that the Add operator behaves correctly with int32.
TEST_P(OperatorTest, add_int32) {
  CHECK_IF_ENABLED();

  llvm::SmallVector<int32_t, 16> xValues = {
      3, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,

      1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,

      1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,

      1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1};

  llvm::SmallVector<int32_t, 16> yValues = {
      3, 4, 5, 7, 2, 5, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,

      3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,

      3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,

      3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6};

  llvm::SmallVector<dim_t, 4> xDims = {2, 2, 4, 4};
  llvm::SmallVector<dim_t, 4> yDims = {2, 2, 4, 4};

  Handle<int32_t> saveH =
      addHelper<int32_t>(bindings_, mod_, F_, EE_, ElemKind::Int32ITy, xValues,
                         yValues, xDims, yDims);

  int counter = 0;
  for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
    for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
      for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
        for (dim_t f = 0; f < saveH.dims()[3]; ++f) {
          EXPECT_EQ(xValues[counter] + yValues[counter],
                    saveH.at({i, j, k, f}));
          ++counter;
        }
      }
    }
  }
}

/// Check that the Add operator behaves correctly with int64.
TEST_P(OperatorTest, add_int64) {
  CHECK_IF_ENABLED();

  llvm::SmallVector<int64_t, 16> xValues = {
      3, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,

      1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,

      1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,

      1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1};

  llvm::SmallVector<int64_t, 16> yValues = {
      3, 4, 5, 7, 2, 5, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,

      3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,

      3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,

      3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6};

  llvm::SmallVector<dim_t, 4> xDims = {2, 2, 4, 4};
  llvm::SmallVector<dim_t, 4> yDims = {2, 2, 4, 4};

  Handle<int64_t> saveH =
      addHelper<int64_t>(bindings_, mod_, F_, EE_, ElemKind::Int64ITy, xValues,
                         yValues, xDims, yDims);

  int counter = 0;
  for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
    for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
      for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
        for (dim_t f = 0; f < saveH.dims()[3]; ++f) {
          EXPECT_EQ(xValues[counter] + yValues[counter],
                    saveH.at({i, j, k, f}));
          ++counter;
        }
      }
    }
  }
}

/// Check that the Add operator behaves correctly with float.
TEST_P(OperatorTest, add_float) {
  CHECK_IF_ENABLED();

  llvm::SmallVector<float, 16> xValues = {
      3, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,

      1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,

      1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1,

      1, 2, 3, 6, 4, 5, 6, 3, 7, 8, 9, 2, 3, 5, 7, 1};

  llvm::SmallVector<float, 16> yValues = {
      3, 4, 5, 7, 2, 5, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,

      3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,

      3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6,

      3, 4, 5, 7, 2, 1, 0, 6, 4, 2, 1, 8, 5, 9, 2, 6};

  llvm::SmallVector<dim_t, 4> xDims = {2, 2, 4, 4};
  llvm::SmallVector<dim_t, 4> yDims = {2, 2, 4, 4};

  Handle<float> saveH =
      addHelper<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy, xValues,
                       yValues, xDims, yDims);

  int counter = 0;
  for (dim_t i = 0; i < saveH.dims()[0]; ++i) {
    for (dim_t j = 0; j < saveH.dims()[1]; ++j) {
      for (dim_t k = 0; k < saveH.dims()[2]; ++k) {
        for (dim_t f = 0; f < saveH.dims()[3]; ++f) {
          EXPECT_FLOAT_EQ(xValues[counter] + yValues[counter],
                          saveH.at({i, j, k, f}));
          ++counter;
        }
      }
    }
  }
}

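/// Helper to create and initialize a LayerNormalization test over a
/// {1, 4, 5, 5} float input with randomized {5, 5} scale and bias Constants.
/// Returns the Function and its result Tensor for comparison against the
/// Interpreter.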
static FunctionTensorPair
createAndInitLayerNormTest(glow::PlaceholderBindings &bindings,
                           glow::ExecutionEngine &EE) {
  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");

  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {1, 4, 5, 5}, "in", false);

  Tensor scaleT(ElemKind::FloatTy, {5, 5});
  scaleT.getHandle().randomize(0.0f, 1.0f, mod.getPRNG());
  Constant *scaleC = mod.createConstant("scale", std::move(scaleT));
  Tensor biasT(ElemKind::FloatTy, {5, 5});
  biasT.getHandle().randomize(0.0f, 1.0f, mod.getPRNG());
  Constant *biasC = mod.createConstant("bias", std::move(biasT));

  LayerNormalizationNode *LNN =
      F->createLayerNormalization("LN", input, scaleC, biasC, 1e-5);

  bindings.allocate(input)->getHandle().randomize(0.0f, 1.0f, mod.getPRNG());

  auto *res = F->createSave("save", LNN);
  ::glow::convertPlaceholdersToConstants(F, bindings,
                                         {input, res->getPlaceholder()});
  auto *resultTensor = bindings.allocate(res->getPlaceholder());

  return std::make_pair(F, resultTensor);
}

/// Test LayerNorm with FloatTy.
TEST_P(OperatorStatelessTest, LayerNorm_Float) {
  CHECK_IF_ENABLED();
  compareAgainstInterpreter(getBackendName(), createAndInitLayerNormTest,
                            ElemKind::FloatTy, ElemKind::FloatTy, 0.0001f,
                            parCloneCountOpt);
}

/// Test LayerNorm with Float16Ty.
TEST_P(OperatorStatelessTest, LayerNorm_Float16) {
  CHECK_IF_ENABLED();
  compareAgainstInterpreter(getBackendName(), createAndInitLayerNormTest,
                            ElemKind::FloatTy, ElemKind::Float16Ty, 0.01f,
                            parCloneCountOpt);
}

/// Test LayerNorm with BFloat16Ty.
TEST_P(OperatorStatelessTest, LayerNorm_BFloat16) {
  CHECK_IF_ENABLED();
  compareAgainstInterpreter(getBackendName(), createAndInitLayerNormTest,
                            ElemKind::FloatTy, ElemKind::BFloat16Ty, 0.01f,
                            parCloneCountOpt);
}

/// Test LayerNorm with Int8Ty.
TEST_P(OperatorStatelessTest, LayerNorm_Int8) {
  CHECK_IF_ENABLED();
  compareAgainstInterpreter(getBackendName(), createAndInitLayerNormTest,
                            ElemKind::FloatTy, ElemKind::Int8QTy, 0.04f,
                            parCloneCountOpt);
}

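/// Helper to test Dequantize of a fused rowwise-quantized (UInt8FusedQTy)
/// input into \p destTy, comparing the dequantized result against the
/// original float data.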
static void testDequantizeFRWQ(glow::PlaceholderBindings &bindings,
                               glow::Module &mod, glow::Function *F,
                               glow::ExecutionEngine &EE, ElemKind destTy) {
  Tensor FT(ElemKind::FloatTy, {10, 20});
  FT.getHandle().randomize(-0.5, 0.5, mod.getPRNG());
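  // Each fused row stores the quantized data followed by its per-row scale
  // and offset as floats, hence the extra 2 * sizeof(float) columns.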
  TypeRef RWQTy = mod.uniqueType(ElemKind::UInt8FusedQTy,
                                 {10, 20 + 2 * sizeof(float)}, 1.0, 0);
  Tensor RWQT(RWQTy);
  quantization::tensorFusedRowwiseQuantization<float>(FT, RWQT);

  auto *input = mod.createPlaceholder(RWQTy, "input", false);
  bindings.insert(input, std::move(RWQT));

  auto *D = F->createDequantize("dequantize", input, destTy);
  auto *save = F->createSave("ret", D);
  auto *result = bindings.allocate(save->getPlaceholder());

  EE.compile(CompilationMode::Infer);
  EE.run(bindings);

  if (destTy == ElemKind::Float16Ty) {
    FT.convertToType(destTy);
  }
  EXPECT_TRUE(FT.isEqual(*result, 0.002f));
}

TEST_P(OperatorTest, DequantizeFRWQ_Float) {
  CHECK_IF_ENABLED();
  testDequantizeFRWQ(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
}

TEST_P(OperatorTest, DequantizeFRWQ_Float16) {
  CHECK_IF_ENABLED();
  testDequantizeFRWQ(bindings_, mod_, F_, EE_, ElemKind::Float16Ty);
}

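/// Helper to test a 2x Upsample over the three innermost dimensions of a 4D
/// input of type \p DTy: each input element must be replicated into a
/// 2x2x2 block of the output.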
template <typename DataType>
static void testUpsample3D(glow::PlaceholderBindings &bindings,
                           glow::Module &mod, glow::Function *F,
                           glow::ExecutionEngine &EE, ElemKind DTy) {
  constexpr dim_t size[4] = {3, 2, 3, 4};
  auto *input =
      createPlaceholderConditionallyQuantized(mod, DTy, size, "input", false);
  bindings.allocate(input)->getHandle<DataType>().randomize(-10.0, 10.0,
                                                            mod.getPRNG());

  auto *output = F->createUpsample("Upsample", input, 3);
  auto *save = F->createSave("Save", output);
  bindings.allocate(save->getPlaceholder());

  EE.compile(CompilationMode::Infer);
  EE.run(bindings);

  auto resultH = bindings.get(save->getPlaceholder())->getHandle<DataType>();
  auto inputH = bindings.get(input)->getHandle<DataType>();

  EXPECT_EQ(resultH.dims()[0], inputH.dims()[0]);
  EXPECT_EQ(resultH.dims()[1], 2 * inputH.dims()[1]);
  EXPECT_EQ(resultH.dims()[2], 2 * inputH.dims()[2]);
  EXPECT_EQ(resultH.dims()[3], 2 * inputH.dims()[3]);
  for (dim_t m = 0; m < size[0]; m++) {
    for (dim_t i = 0; i < size[1]; i++) {
      for (dim_t j = 0; j < size[2]; j++) {
        for (dim_t k = 0; k < size[3]; k++) {
          EXPECT_EQ(resultH.at({m, 2 * i + 0, 2 * j + 0, 2 * k + 0}),
                    static_cast<DataType>(inputH.at({m, i, j, k})));
          EXPECT_EQ(resultH.at({m, 2 * i + 0, 2 * j + 0, 2 * k + 1}),
                    static_cast<DataType>(inputH.at({m, i, j, k})));
          EXPECT_EQ(resultH.at({m, 2 * i + 0, 2 * j + 1, 2 * k + 0}),
                    static_cast<DataType>(inputH.at({m, i, j, k})));
          EXPECT_EQ(resultH.at({m, 2 * i + 0, 2 * j + 1, 2 * k + 1}),
                    static_cast<DataType>(inputH.at({m, i, j, k})));
          EXPECT_EQ(resultH.at({m, 2 * i + 1, 2 * j + 0, 2 * k + 0}),
                    static_cast<DataType>(inputH.at({m, i, j, k})));
          EXPECT_EQ(resultH.at({m, 2 * i + 1, 2 * j + 0, 2 * k + 1}),
                    static_cast<DataType>(inputH.at({m, i, j, k})));
          EXPECT_EQ(resultH.at({m, 2 * i + 1, 2 * j + 1, 2 * k + 0}),
                    static_cast<DataType>(inputH.at({m, i, j, k})));
          EXPECT_EQ(resultH.at({m, 2 * i + 1, 2 * j + 1, 2 * k + 1}),
                    static_cast<DataType>(inputH.at({m, i, j, k})));
        }
      }
    }
  }
}

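/// Helper to test a 2x Upsample over the two innermost dimensions of a 3D
/// input of type \p DTy: each input element must be replicated into a
/// 2x2 block of the output.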
template <typename DataType>
static void testUpsample2D(glow::PlaceholderBindings &bindings,
                           glow::Module &mod, glow::Function *F,
                           glow::ExecutionEngine &EE, ElemKind DTy) {
  constexpr dim_t size[3] = {2, 3, 4};
  auto *input =
      createPlaceholderConditionallyQuantized(mod, DTy, size, "input", false);
  bindings.allocate(input)->getHandle<DataType>().randomize(-10.0, 10.0,
                                                            mod.getPRNG());

  auto *output = F->createUpsample("Upsample", input, 2);
  auto *save = F->createSave("Save", output);
  bindings.allocate(save->getPlaceholder());

  EE.compile(CompilationMode::Infer);
  EE.run(bindings);

  auto resultH = bindings.get(save->getPlaceholder())->getHandle<DataType>();
  auto inputH = bindings.get(input)->getHandle<DataType>();

  EXPECT_EQ(resultH.dims()[0], inputH.dims()[0]);
  EXPECT_EQ(resultH.dims()[1], 2 * inputH.dims()[1]);
  EXPECT_EQ(resultH.dims()[2], 2 * inputH.dims()[2]);
  for (dim_t m = 0; m < size[0]; m++) {
    for (dim_t i = 0; i < size[1]; i++) {
      for (dim_t j = 0; j < size[2]; j++) {
        EXPECT_EQ(resultH.at({m, 2 * i + 0, 2 * j + 0}),
                  static_cast<DataType>(inputH.at({m, i, j})));
        EXPECT_EQ(resultH.at({m, 2 * i + 0, 2 * j + 1}),
                  static_cast<DataType>(inputH.at({m, i, j})));
        EXPECT_EQ(resultH.at({m, 2 * i + 1, 2 * j + 0}),
                  static_cast<DataType>(inputH.at({m, i, j})));
        EXPECT_EQ(resultH.at({m, 2 * i + 1, 2 * j + 1}),
                  static_cast<DataType>(inputH.at({m, i, j})));
      }
    }
  }
}

TEST_P(OperatorTest, Upsample3D_Float) {
  CHECK_IF_ENABLED();
  testUpsample3D<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
}

TEST_P(OperatorTest, Upsample3D_Int8) {
  CHECK_IF_ENABLED();
  testUpsample3D<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
}

TEST_P(OperatorTest, Upsample2D_Float) {
  CHECK_IF_ENABLED();
  testUpsample2D<float>(bindings_, mod_, F_, EE_, ElemKind::FloatTy);
}

TEST_P(OperatorTest, Upsample2D_Int8) {
  CHECK_IF_ENABLED();
  testUpsample2D<int8_t>(bindings_, mod_, F_, EE_, ElemKind::Int8QTy);
}

INSTANTIATE_BACKEND_TEST(OperatorStatelessTest);
INSTANTIATE_BACKEND_TEST(OperatorTest);