1 /**
2  * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "BackendTestUtils.h"
18 
19 #include "glow/ExecutionEngine/ExecutionEngine.h"
20 #include "glow/Graph/Graph.h"
21 #include "glow/Graph/PlaceholderBindings.h"
22 #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
23 #include "glow/Support/Random.h"
24 
25 #include "gtest/gtest.h"
26 
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Signals.h"
30 
31 #include <functional>
32 
33 using namespace glow;
34 using llvm::cast;
35 
/// Parameter type for the sweep tests in this file that take three integer
/// parameters: a backend name plus a tuple of three ints, i.e. the shape
/// produced by passing a single three-generator ::testing::Combine() into
/// GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST().
using ThreeIntTupleConfig =
    std::tuple<std::string, std::tuple<int, int, int>>;

/// Same as ThreeIntTupleConfig, but for sweeps that combine four integer
/// parameters.
using FourIntTupleConfig =
    std::tuple<std::string, std::tuple<int, int, int, int>>;
42 
/// Unpacks a ThreeIntTupleConfig \p CONFIG into caller-declared variables:
/// the backend string goes to \p BACKEND_NAME and the three ints, in order, to
/// \p PARAM1, \p PARAM2 and \p PARAM3. The temporary tuple lives in its own
/// brace scope so the macro does not leak a local name into the caller and can
/// be expanded more than once in the same scope. A trailing semicolon after the
/// invocation is optional.
#define SET_BACKEND_KIND_AND_THREE_INT_PARAMS(CONFIG, BACKEND_NAME, PARAM1,    \
                                              PARAM2, PARAM3)                  \
  {                                                                            \
    std::tuple<int, int, int> threeIntTupleParams;                             \
    std::tie(BACKEND_NAME, threeIntTupleParams) = (CONFIG);                    \
    std::tie(PARAM1, PARAM2, PARAM3) = threeIntTupleParams;                    \
  }
48 
/// Unpacks a FourIntTupleConfig \p CONFIG into caller-declared variables:
/// the backend string goes to \p BACKEND_NAME and the four ints, in order, to
/// \p PARAM1 .. \p PARAM4. As with the three-int variant, the temporary tuple is
/// confined to a brace scope so it does not leak into the caller and the macro
/// can be expanded repeatedly in one scope. A trailing semicolon is optional.
#define SET_BACKEND_KIND_AND_FOUR_INT_PARAMS(CONFIG, BACKEND_NAME, PARAM1,     \
                                             PARAM2, PARAM3, PARAM4)           \
  {                                                                            \
    std::tuple<int, int, int, int> fourIntTupleParams;                         \
    std::tie(BACKEND_NAME, fourIntTupleParams) = (CONFIG);                     \
    std::tie(PARAM1, PARAM2, PARAM3, PARAM4) = fourIntTupleParams;             \
  }
54 
55 //===--------------------------------------------------------------------===//
56 //                   Convolution Parameter Sweep Tests
57 //===--------------------------------------------------------------------===//
58 
59 /// Create a simple network that has a single fp convolution.
60 static FunctionTensorPair
createAndInitConvNet(glow::PlaceholderBindings & bindings,glow::ExecutionEngine & EE,dim_t size,dim_t convDepth,dim_t kernel,dim_t stride,dim_t pad)61 createAndInitConvNet(glow::PlaceholderBindings &bindings,
62                      glow::ExecutionEngine &EE, dim_t size, dim_t convDepth,
63                      dim_t kernel, dim_t stride, dim_t pad) {
64   PseudoRNG PRNG;
65   auto &mod = EE.getModule();
66   Function *F = mod.createFunction("main");
67   auto *var = mod.createPlaceholder(ElemKind::FloatTy,
68                                     {1, size, size, convDepth}, "var", false);
69   bindings.allocate(var)->getHandle().initXavier(1, PRNG);
70 
71   auto *conv =
72       F->createConv(bindings, "conv", var, convDepth, kernel, stride, pad, 1);
73   bindings.get(cast<Placeholder>(conv->getFilter()))->getHandle().clear(0.1);
74   bindings.get(cast<Placeholder>(conv->getBias()))->getHandle().clear(0.1);
75   auto *result = F->createSave("ret", conv);
76   auto *resultTensor = bindings.allocate(result->getPlaceholder());
77   convertPlaceholdersToConstants(F, bindings, {var, result->getPlaceholder()});
78 
79   return std::make_pair(F, resultTensor);
80 }
81 
82 /// Helper to test sweeping across a variety of configurations of a convolution
83 /// by comparing the results to the Interpreter given some \p allowedError.
84 /// \p config contains the backend to compare the Interpreter against, plus the
85 /// specific configuration to run for this test. \p interpElemKind and \p
86 /// backendElemKind are the element kinds to use for the Interpreter and
87 /// backend, respectively.
testParamSweepConv(ThreeIntTupleConfig config,ElemKind interpElemKind,ElemKind backendElemKind,float allowedError)88 static void testParamSweepConv(ThreeIntTupleConfig config,
89                                ElemKind interpElemKind,
90                                ElemKind backendElemKind, float allowedError) {
91   std::string backend;
92   size_t size, depth, kernel;
93   SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, size, depth, kernel)
94 
95   LOG(INFO) << "Testing Conv with size: " << size << "; depth: " << depth
96             << "; kernel: " << kernel << "\n";
97 
98   auto boundF = std::bind(createAndInitConvNet, std::placeholders::_1,
99                           std::placeholders::_2, size, depth, kernel,
100                           /* stride */ 1, /* pad */ 0);
101   compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
102                             allowedError, parCloneCountOpt);
103 }
104 
// Fixture for the convolution sweep; each instance receives a
// ThreeIntTupleConfig (backend name plus three ints).
DECLARE_STATELESS_BACKEND_TEST(ConvSweepTest, ThreeIntTupleConfig);

// ::testing::Combine() takes the Cartesian product of the generators below
// (3 sizes x 2 depths x 2 kernels = 12 configurations). The element order must
// match the unpacking order in testParamSweepConv(): size, depth, kernel.
GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
    SweepTest, ConvSweepTest,
    ::testing::Combine(/* size */ ::testing::Values(5, 7, 15),
                       /* depth */ ::testing::Values(8, 64),
                       /* kernel */ ::testing::Values(1, 3)));
112 
113 /// Compare backend against the interpreter in Float.
TEST_P(ConvSweepTest,ConvTest_Float)114 TEST_P(ConvSweepTest, ConvTest_Float) {
115   CHECK_IF_ENABLED();
116   testParamSweepConv(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, 0.0001f);
117 }
118 
119 /// Compare backend against the interpreter in Int8.
TEST_P(ConvSweepTest,ConvTest_Int8)120 TEST_P(ConvSweepTest, ConvTest_Int8) {
121   CHECK_IF_ENABLED();
122   testParamSweepConv(GetParam(), ElemKind::FloatTy, ElemKind::Int8QTy, 0.045f);
123 }
124 
125 /// Compare backend against the interpreter in FP16.
TEST_P(ConvSweepTest,ConvTest_Float16)126 TEST_P(ConvSweepTest, ConvTest_Float16) {
127   CHECK_IF_ENABLED();
128   testParamSweepConv(GetParam(), ElemKind::FloatTy, ElemKind::Float16Ty,
129                      0.005f);
130 }
131 
132 /// Compare backend against the interpreter in FP16.
TEST_P(ConvSweepTest,ConvTest_BFloat16)133 TEST_P(ConvSweepTest, ConvTest_BFloat16) {
134   CHECK_IF_ENABLED();
135   testParamSweepConv(GetParam(), ElemKind::FloatTy, ElemKind::BFloat16Ty,
136                      0.005f);
137 }
138 
139 //===--------------------------------------------------------------------===//
140 //                   BatchMatMul Parameter Sweep Tests
141 //===--------------------------------------------------------------------===//
142 
143 /// Create a simple network that has a single fp batch mat mul.
144 static FunctionTensorPair
createAndInitBatchMatMulNet(glow::PlaceholderBindings & bindings,glow::ExecutionEngine & EE,dim_t N,dim_t A,dim_t Z,dim_t B)145 createAndInitBatchMatMulNet(glow::PlaceholderBindings &bindings,
146                             glow::ExecutionEngine &EE, dim_t N, dim_t A,
147                             dim_t Z, dim_t B) {
148   PseudoRNG PRNG;
149   auto &mod = EE.getModule();
150   Function *F = mod.createFunction("main");
151   auto *LHS = mod.createPlaceholder(ElemKind::FloatTy, {N, A, Z}, "LHS", false);
152   auto *RHS = mod.createPlaceholder(ElemKind::FloatTy, {N, Z, B}, "RHS", false);
153   bindings.allocate(LHS)->getHandle().initXavier(10, PRNG);
154   bindings.allocate(RHS)->getHandle().initXavier(10, PRNG);
155 
156   auto *R = F->createBatchMatMul("BMM", LHS, RHS);
157 
158   auto *save = F->createSave("save", R);
159   auto *resultTensor = bindings.allocate(save->getPlaceholder());
160 
161   return std::make_pair(F, resultTensor);
162 }
163 
164 /// Helper to test sweeping across a variety of configurations of a BatchMatMul
165 /// by comparing the results to the Interpreter given some \p allowedError.
166 /// \p config contains the backend to compare the Interpreter against, plus the
167 /// specific configuration to run for this test. \p interpElemKind and \p
168 /// backendElemKind are the element kinds to use for the Interpreter and
169 /// backend, respectively.
testParamSweepBatchMatMul(ThreeIntTupleConfig config,ElemKind interpElemKind,ElemKind backendElemKind,float allowedError)170 static void testParamSweepBatchMatMul(ThreeIntTupleConfig config,
171                                       ElemKind interpElemKind,
172                                       ElemKind backendElemKind,
173                                       float allowedError) {
174   std::string backend;
175   size_t N, A, Z;
176   SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, N, A, Z);
177   size_t B = A;
178 
179   LOG(INFO) << "\n\tTesting BatchMatMul with N: " << N << "; A: " << A
180             << "; Z: " << Z << "; B: " << B << "\n";
181 
182   // Multiplying LHS {N, A, Z} by RHS {N, Z, B} to get result {N, A, B}.
183   auto boundF = std::bind(createAndInitBatchMatMulNet, std::placeholders::_1,
184                           std::placeholders::_2, N, A, Z, B);
185   compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
186                             allowedError, parCloneCountOpt);
187 }
188 
// Fixture for the BatchMatMul sweep; parameters after the backend name are
// (N, A, Z).
DECLARE_STATELESS_BACKEND_TEST(BatchMatMulSweepTest, ThreeIntTupleConfig);

// Cartesian product of batch size N, LHS rows A (::testing::Range(10, 16) is
// end-exclusive, i.e. 10..15) and inner dimension Z: 4 x 6 x 4 = 96
// configurations. The RHS column count B is not a sweep axis;
// testParamSweepBatchMatMul() sets it equal to A.
GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
    SweepTest, BatchMatMulSweepTest,
    ::testing::Combine(/* N */ ::testing::Values(1, 4, 16, 24),
                       /* A */ ::testing::Range(10, 16),
                       /* Z */ ::testing::Values(32, 64, 128, 256)));
196 
197 /// Compare backend against the interpreter in Float.
TEST_P(BatchMatMulSweepTest,BatchMatMulTest_Float)198 TEST_P(BatchMatMulSweepTest, BatchMatMulTest_Float) {
199   CHECK_IF_ENABLED();
200   testParamSweepBatchMatMul(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
201                             0.0001f);
202 }
203 
204 /// Compare backend against the interpreter in Int8.
TEST_P(BatchMatMulSweepTest,BatchMatMulTest_Int8)205 TEST_P(BatchMatMulSweepTest, BatchMatMulTest_Int8) {
206   CHECK_IF_ENABLED();
207   testParamSweepBatchMatMul(GetParam(), ElemKind::FloatTy, ElemKind::Int8QTy,
208                             0.06f);
209 }
210 
211 /// Compare backend against the interpreter in FP16.
TEST_P(BatchMatMulSweepTest,BatchMatMulTest_Float16)212 TEST_P(BatchMatMulSweepTest, BatchMatMulTest_Float16) {
213   CHECK_IF_ENABLED();
214   testParamSweepBatchMatMul(GetParam(), ElemKind::FloatTy, ElemKind::Float16Ty,
215                             0.005f);
216 }
217 
218 /// Compare backend against the interpreter in FP16.
TEST_P(BatchMatMulSweepTest,BatchMatMulTest_BFloat16)219 TEST_P(BatchMatMulSweepTest, BatchMatMulTest_BFloat16) {
220   CHECK_IF_ENABLED();
221   testParamSweepBatchMatMul(GetParam(), ElemKind::FloatTy, ElemKind::BFloat16Ty,
222                             0.005f);
223 }
224 
225 //===--------------------------------------------------------------------===//
226 //                   FullyConnected Parameter Sweep Tests
227 //===--------------------------------------------------------------------===//
228 
229 /// Create a simple network that has a single fp FC.
230 static FunctionTensorPair
createAndInitFCNet(glow::PlaceholderBindings & bindings,glow::ExecutionEngine & EE,dim_t A,dim_t Z,dim_t B)231 createAndInitFCNet(glow::PlaceholderBindings &bindings,
232                    glow::ExecutionEngine &EE, dim_t A, dim_t Z, dim_t B) {
233   PseudoRNG PRNG;
234   auto &mod = EE.getModule();
235   Function *F = mod.createFunction("main");
236   auto *IP = mod.createPlaceholder(ElemKind::FloatTy, {A, Z}, "input", false);
237   auto *WC = mod.createConstant(ElemKind::FloatTy, {Z, B}, "weights");
238   auto *BC = mod.createConstant(ElemKind::FloatTy, {B}, "bias");
239   bindings.allocate(IP)->getHandle().randomize(-0.2, 0.2, mod.getPRNG());
240   BC->getPayloadMutable().getHandle().randomize(0, 0.000005, mod.getPRNG());
241   WC->getPayloadMutable().getHandle().randomize(-0.4, 0.4, mod.getPRNG());
242 
243   auto *FC = F->createFullyConnected("FC", IP, WC, BC);
244   auto *save = F->createSave("save", FC);
245   auto *resultTensor = bindings.allocate(save->getPlaceholder());
246 
247   return std::make_pair(F, resultTensor);
248 }
249 
250 /// Helper to test sweeping across a variety of configurations of a FC by
251 /// comparing the results to the Interpreter given some \p allowedError.
252 /// \p config contains the backend to compare the Interpreter against, plus the
253 /// specific configuration to run for this test. \p interpElemKind and \p
254 /// backendElemKind are the element kinds to use for the Interpreter and
255 /// backend, respectively.
testParamSweepFC(ThreeIntTupleConfig config,ElemKind interpElemKind,ElemKind backendElemKind,float allowedError)256 static void testParamSweepFC(ThreeIntTupleConfig config,
257                              ElemKind interpElemKind, ElemKind backendElemKind,
258                              float allowedError) {
259   std::string backend;
260   size_t A, Z, B;
261   SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, A, Z, B);
262 
263   LOG(INFO) << "\n\tTesting FC with A: " << A << "; Z: " << Z << "; B: " << B
264             << "\n";
265 
266   auto boundF = std::bind(createAndInitFCNet, std::placeholders::_1,
267                           std::placeholders::_2, A, Z, B);
268   compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
269                             allowedError, parCloneCountOpt);
270 }
271 
// Fixture for the FullyConnected sweep; parameters after the backend name are
// (A, Z, B): input {A, Z}, weights {Z, B}.
DECLARE_STATELESS_BACKEND_TEST(FCSweepTest, ThreeIntTupleConfig);

// Cartesian product of the three generators: 4 x 7 x 5 = 140 configurations.
GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
    SweepTest, FCSweepTest,
    ::testing::Combine(
        /* A */ ::testing::Values(1, 4, 16, 64),
        /* Z */ ::testing::Values(16, 128, 256, 512, 1024, 2048, 4096),
        /* B */ ::testing::Values(1, 48, 64, 256, 1024)));
280 
281 /// Compare backend against the interpreter in Float.
TEST_P(FCSweepTest,FCTest_Float)282 TEST_P(FCSweepTest, FCTest_Float) {
283   CHECK_IF_ENABLED();
284   testParamSweepFC(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, 0.0001f);
285 }
286 
287 /// Compare backend against the interpreter in Int8.
TEST_P(FCSweepTest,FCTest_Int8)288 TEST_P(FCSweepTest, FCTest_Int8) {
289   CHECK_IF_ENABLED();
290   testParamSweepFC(GetParam(), ElemKind::FloatTy, ElemKind::Int8QTy, 0.065f);
291 }
292 
293 /// Compare backend against the interpreter in FP16.
TEST_P(FCSweepTest,FCTest_Float16)294 TEST_P(FCSweepTest, FCTest_Float16) {
295   CHECK_IF_ENABLED();
296   testParamSweepFC(GetParam(), ElemKind::FloatTy, ElemKind::Float16Ty, 0.005f);
297 }
298 
299 /// Compare backend against the interpreter in BFloat16.
TEST_P(FCSweepTest,FCTest_BFloat16)300 TEST_P(FCSweepTest, FCTest_BFloat16) {
301   CHECK_IF_ENABLED();
302   testParamSweepFC(GetParam(), ElemKind::FloatTy, ElemKind::BFloat16Ty, 0.005f);
303 }
304 
305 //===--------------------------------------------------------------------===//
306 //                   Concat Parameter Sweep Tests
307 //===--------------------------------------------------------------------===//
308 
309 /// Create a simple network that has a single fp Concat.
310 static FunctionTensorPair
createAndInitConcatNet(glow::PlaceholderBindings & bindings,glow::ExecutionEngine & EE,size_t numInputs,size_t numDims,size_t maxLength,size_t axis)311 createAndInitConcatNet(glow::PlaceholderBindings &bindings,
312                        glow::ExecutionEngine &EE, size_t numInputs,
313                        size_t numDims, size_t maxLength, size_t axis) {
314   PseudoRNG PRNG;
315   auto &mod = EE.getModule();
316   Function *F = mod.createFunction("main");
317 
318   // Make leading dimensions smaller than trailing. Reduces size of tests and is
319   // also in line with typical tests.
320   std::vector<dim_t> dims(numDims, maxLength);
321   for (size_t i = 0; i < numDims; i++) {
322     dims[numDims - 1 - i] /= std::pow(2, i);
323   }
324 
325   std::vector<NodeValue> inputs(numInputs);
326   for (size_t i = 0; i < numInputs; i++) {
327     auto *IP = mod.createPlaceholder(ElemKind::FloatTy, dims, "input", false);
328     bindings.allocate(IP)->getHandle().randomize(-0.2, 0.2, mod.getPRNG());
329     assert(IP);
330     inputs[i] = IP->getOutput();
331   }
332 
333   auto *concat = F->createConcat("concat", inputs, axis);
334   auto *save = F->createSave("save", concat);
335   auto *resultTensor = bindings.allocate(save->getPlaceholder());
336 
337   return std::make_pair(F, resultTensor);
338 }
339 
340 /// Helper to test sweeping across a variety of configurations of a Concat by
341 /// comparing the results to the Interpreter given some \p allowedError.
342 /// \p config contains the backend to compare the Interpreter against, plus the
343 /// specific configuration to run for this test. \p interpElemKind and \p
344 /// backendElemKind are the element kinds to use for the Interpreter and
345 /// backend, respectively.
testParamSweepConcat(FourIntTupleConfig config,ElemKind interpElemKind,ElemKind backendElemKind,float allowedError)346 static void testParamSweepConcat(FourIntTupleConfig config,
347                                  ElemKind interpElemKind,
348                                  ElemKind backendElemKind, float allowedError) {
349   std::string backend;
350   size_t numInputs, numDims, maxLength, axis;
351   SET_BACKEND_KIND_AND_FOUR_INT_PARAMS(config, backend, numInputs, numDims,
352                                        maxLength, axis);
353   // Exit if axis outside of numDims.
354   if (axis >= numDims) {
355     return;
356   }
357 
358   LOG(INFO) << "\n\tTesting Concat with numInputs: " << numInputs
359             << "; numDims: " << numDims << "; maxLength: " << maxLength
360             << "; axis: " << axis << "\n";
361 
362   auto boundF =
363       std::bind(createAndInitConcatNet, std::placeholders::_1,
364                 std::placeholders::_2, numInputs, numDims, maxLength, axis);
365   compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
366                             allowedError, parCloneCountOpt);
367 }
368 
// Fixture for the Concat sweep; parameters after the backend name are
// (numInputs, numDims, maxLength, axis).
DECLARE_STATELESS_BACKEND_TEST(ConcatSweepTest, FourIntTupleConfig);

// Cartesian product of the four generators (::testing::Range is end-exclusive:
// numDims is 1..3 and axis is 0..2), giving 10 x 3 x 4 x 3 = 360
// configurations. Combinations with axis >= numDims are still instantiated but
// return early in testParamSweepConcat() without comparing anything.
GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
    SweepTest, ConcatSweepTest,
    ::testing::Combine(/* numInputs */ ::testing::Values(1, 2, 4, 8, 16, 32, 64,
                                                         128, 192, 256),
                       /* numDims */ ::testing::Range(1, 4),
                       /* maxLength */ ::testing::Values(16, 32, 64, 128),
                       /* axis */ ::testing::Range(0, 3)));
378 
379 /// Compare backend against the interpreter in Float.
TEST_P(ConcatSweepTest,ConcatTest_Float)380 TEST_P(ConcatSweepTest, ConcatTest_Float) {
381   CHECK_IF_ENABLED();
382   testParamSweepConcat(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy, 0.0f);
383 }
384 
385 /// Compare backend against the interpreter in Int8. Note that we do not use the
386 /// same ElemKind for the Interpreter; this is because the backend will
387 /// quantize/dequantize the input/result anyway, so the comparison wouldn't be
388 /// purely on data movement.
TEST_P(ConcatSweepTest,ConcatTest_Int8)389 TEST_P(ConcatSweepTest, ConcatTest_Int8) {
390   CHECK_IF_ENABLED();
391   testParamSweepConcat(GetParam(), ElemKind::FloatTy, ElemKind::Int8QTy,
392                        0.002f);
393 }
394 
395 /// Compare backend against the interpreter in Float16. Note that we do not use
396 /// the same ElemKind for the Interpreter; this is because the backend will
397 /// down/up convert the input/result anyway, so the comparison wouldn't be
398 /// purely on data movement.
TEST_P(ConcatSweepTest,ConcatTest_Float16)399 TEST_P(ConcatSweepTest, ConcatTest_Float16) {
400   CHECK_IF_ENABLED();
401   testParamSweepConcat(GetParam(), ElemKind::FloatTy, ElemKind::Float16Ty,
402                        0.0001f);
403 }
404 
405 /// Compare backend against the interpreter in BFloat16. Note that we do not use
406 /// the same ElemKind for the Interpreter; this is because the backend will
407 /// down/up convert the input/result anyway, so the comparison wouldn't be
408 /// purely on data movement.
TEST_P(ConcatSweepTest,ConcatTest_BFloat16)409 TEST_P(ConcatSweepTest, ConcatTest_BFloat16) {
410   CHECK_IF_ENABLED();
411   testParamSweepConcat(GetParam(), ElemKind::FloatTy, ElemKind::BFloat16Ty,
412                        0.0001f);
413 }
414 
415 //===--------------------------------------------------------------------===//
416 //                   SLWS Parameter Sweep Tests
417 //===--------------------------------------------------------------------===//
418 
419 /// Create a simple network that has a single fp SLWS.
420 static FunctionTensorPair
createAndInitSLWSNet(glow::PlaceholderBindings & bindings,glow::ExecutionEngine & EE,dim_t embeddingRows,dim_t embeddingDim,dim_t numLengths,bool rowwiseQuantize,bool fused,bool FP16,bool accumFP16)421 createAndInitSLWSNet(glow::PlaceholderBindings &bindings,
422                      glow::ExecutionEngine &EE, dim_t embeddingRows,
423                      dim_t embeddingDim, dim_t numLengths, bool rowwiseQuantize,
424                      bool fused, bool FP16, bool accumFP16) {
425   PseudoRNG PRNG;
426   auto &mod = EE.getModule();
427   Function *F = mod.createFunction("main");
428 
429   // Initialize lengths according to the number provided by the test. Note that
430   // we arbitrarily set them between [80,120].
431   auto *lengths =
432       mod.createPlaceholder(ElemKind::Int32ITy, {numLengths}, "lengths", false);
433   auto LH = bindings.allocate(lengths)->getHandle<int32_t>();
434   LH.randomize(80, 120, mod.getPRNG());
435 
436   // Get the sum of the lengths to then use as the size for indices and weights.
437   dim_t sumOfLengths = 0;
438   for (const int32_t &e : LH) {
439     sumOfLengths += e;
440   }
441 
442   // Initialize indices to size of sum of lengths. Randomly set them to point
443   // somewhere inside the embedding.
444   auto *indices = mod.createPlaceholder(ElemKind::Int64ITy, {sumOfLengths},
445                                         "indices", false);
446   bindings.allocate(indices)->getHandle<int64_t>().randomize(
447       0, embeddingRows - 1, mod.getPRNG());
448 
449   // Xavier initialize the weights with the correct data type.
450   Constant *weights;
451   if (FP16) {
452     weights =
453         mod.createConstant(ElemKind::Float16Ty, {sumOfLengths}, "weights");
454     weights->getPayloadMutable().getHandle<float16_t>().initXavier(
455         weights->getType()->size() * 2, mod.getPRNG());
456   } else {
457     weights = mod.createConstant(ElemKind::FloatTy, {sumOfLengths}, "weights");
458     weights->getPayloadMutable().getHandle<float>().initXavier(
459         weights->getType()->size() * 2, mod.getPRNG());
460   }
461 
462   // Create the embedding; non-RWQ versions will simply create a Constant with
463   // it, while RWQ versions will use its data to create a RWQ Constant
464   // internally in the Node constructor.
465   Tensor embeddingT(ElemKind::FloatTy, {embeddingRows, embeddingDim});
466   embeddingT.getHandle().initXavier(embeddingT.size() * 2, mod.getPRNG());
467 
468   // Create the SLWS based on provided options.
469   Node *SLWS;
470   if (!rowwiseQuantize) {
471     auto *embeddingC = mod.createConstant("embedding", std::move(embeddingT));
472     SLWS = F->createSparseLengthsWeightedSum("SLWS", embeddingC, weights,
473                                              indices, lengths);
474   } else {
475     if (fused) {
476       const ElemKind precision =
477           FP16 ? ElemKind::UInt8FusedFP16QTy : ElemKind::UInt8FusedQTy;
478       SLWS = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
479           "FRQSLWS", embeddingT, weights, indices, lengths, precision,
480           accumFP16);
481     } else {
482       const ElemKind precision = FP16 ? ElemKind::Float16Ty : ElemKind::FloatTy;
483       SLWS = F->createRowwiseQuantizedSparseLengthsWeightedSum(
484           "RQSLWS", embeddingT, weights, indices, lengths,
485           quantization::Schema::Asymmetric, precision, accumFP16);
486     }
487   }
488   auto *save = F->createSave("save", SLWS);
489   auto *resultTensor = bindings.allocate(save->getPlaceholder());
490 
491   return std::make_pair(F, resultTensor);
492 }
493 
494 /// Helper to test sweeping across a variety of configurations of a SLWS by
495 /// comparing the results to the Interpreter given some \p allowedError.
496 /// \p config contains the backend to compare the Interpreter against, plus the
497 /// specific configuration to run for this test. \p interpElemKind and \p
498 /// backendElemKind are the element kinds to use for the Interpreter and
499 /// backend, respectively. Pass in options for the test \p rowwiseQuantize,
500 /// \p fused, \p FP16, and \p accumFP16.
testParamSweepSLWS(ThreeIntTupleConfig config,ElemKind interpElemKind,ElemKind backendElemKind,float allowedError,bool rowwiseQuantize,bool fused,bool FP16,bool accumFP16)501 static void testParamSweepSLWS(ThreeIntTupleConfig config,
502                                ElemKind interpElemKind,
503                                ElemKind backendElemKind, float allowedError,
504                                bool rowwiseQuantize, bool fused, bool FP16,
505                                bool accumFP16) {
506   std::string backend;
507   size_t embeddingRows, embeddingDim, numLengths;
508   SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, embeddingRows,
509                                         embeddingDim, numLengths);
510 
511   LOG(INFO) << "\n\tTesting SLWS with embeddingRows: " << embeddingRows
512             << "; embeddingDim: " << embeddingDim
513             << "; numLengths: " << numLengths << "\n";
514 
515   auto boundF = std::bind(createAndInitSLWSNet, std::placeholders::_1,
516                           std::placeholders::_2, embeddingRows, embeddingDim,
517                           numLengths, rowwiseQuantize, fused, FP16, accumFP16);
518   compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
519                             allowedError, parCloneCountOpt);
520 }
521 
// Fixture for the SLWS sweep; parameters after the backend name are
// (embeddingRows, embeddingDim, numLengths).
DECLARE_STATELESS_BACKEND_TEST(SLWSSweepTest, ThreeIntTupleConfig);

// Cartesian product of the three generators: 4 x 4 x 5 = 80 configurations.
GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
    SweepTest, SLWSSweepTest,
    ::testing::Combine(
        /* embeddingRows */ ::testing::Values(100, 1000, 10000, 100000),
        /* embeddingDim */ ::testing::Values(32, 64, 96, 128),
        /* numLengths */ ::testing::Values(16, 32, 64, 128, 256)));
530 
531 /// Compare backend against the interpreter.
TEST_P(SLWSSweepTest,SLWS_Float)532 TEST_P(SLWSSweepTest, SLWS_Float) {
533   CHECK_IF_ENABLED();
534   testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
535                      0.000001f,
536                      /* rowwiseQuantize */ false,
537                      /* fused */ false, /* FP16 */ false,
538                      /* accumFP16 */ false);
539 }
540 
541 /// Compare backend against the interpreter in Float.
TEST_P(SLWSSweepTest,RWQSLWS_Float)542 TEST_P(SLWSSweepTest, RWQSLWS_Float) {
543   CHECK_IF_ENABLED();
544   testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
545                      0.000001f,
546                      /* rowwiseQuantize */ true,
547                      /* fused */ false, /* FP16 */ false,
548                      /* accumFP16 */ false);
549 }
550 
551 /// Compare backend against the interpreter in Float.
TEST_P(SLWSSweepTest,FRWQSLWS_Float)552 TEST_P(SLWSSweepTest, FRWQSLWS_Float) {
553   CHECK_IF_ENABLED();
554   testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
555                      0.000001f,
556                      /* rowwiseQuantize */ true,
557                      /* fused */ true, /* FP16 */ false,
558                      /* accumFP16 */ false);
559 }
560 
561 /// Compare backend against the interpreter in Float.
TEST_P(SLWSSweepTest,RWQSLWS_Float16)562 TEST_P(SLWSSweepTest, RWQSLWS_Float16) {
563   // Note: not currently enabled for any open-source backends, as only the
564   // Interpreter supports this.
565   CHECK_IF_ENABLED();
566   testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
567                      0.000001f,
568                      /* rowwiseQuantize */ true,
569                      /* fused */ false, /* FP16 */ true,
570                      /* accumFP16 */ false);
571 }
572 
573 /// Compare backend against the interpreter in Float.
TEST_P(SLWSSweepTest,FRWQSLWS_Float16)574 TEST_P(SLWSSweepTest, FRWQSLWS_Float16) {
575   // Note: not currently enabled for any open-source backends, as only the
576   // Interpreter supports this.
577   CHECK_IF_ENABLED();
578   testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
579                      0.000001f,
580                      /* rowwiseQuantize */ true,
581                      /* fused */ true, /* FP16 */ true,
582                      /* accumFP16 */ false);
583 }
584 
585 /// Compare backend against the interpreter in Float.
TEST_P(SLWSSweepTest,RWQSLWS_Float16_AccumFloat16)586 TEST_P(SLWSSweepTest, RWQSLWS_Float16_AccumFloat16) {
587   // Note: not currently enabled for any open-source backends, as only the
588   // Interpreter supports this.
589   CHECK_IF_ENABLED();
590   testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
591                      0.000001f,
592                      /* rowwiseQuantize */ true,
593                      /* fused */ false, /* FP16 */ true,
594                      /* accumFP16 */ true);
595 }
596 
597 /// Compare backend against the interpreter in Float.
TEST_P(SLWSSweepTest,FRWQSLWS_Float16_AccumFloat16)598 TEST_P(SLWSSweepTest, FRWQSLWS_Float16_AccumFloat16) {
599   // Note: not currently enabled for any open-source backends, as only the
600   // Interpreter supports this.
601   CHECK_IF_ENABLED();
602   testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
603                      0.000001f,
604                      /* rowwiseQuantize */ true,
605                      /* fused */ true, /* FP16 */ true,
606                      /* accumFP16 */ true);
607 }
608 
main(int argc,char ** argv)609 int main(int argc, char **argv) {
610   ::testing::InitGoogleTest(&argc, argv);
611   llvm::sys::PrintStackTraceOnErrorSignal(argv[0]);
612   llvm::cl::ParseCommandLineOptions(argc, argv);
613   return RUN_ALL_TESTS();
614 }
615