1 /**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "BackendTestUtils.h"
18
19 #include "glow/ExecutionEngine/ExecutionEngine.h"
20 #include "glow/Graph/Graph.h"
21 #include "glow/Graph/PlaceholderBindings.h"
22 #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
23 #include "glow/Support/Random.h"
24
25 #include "gtest/gtest.h"
26
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Signals.h"
30
31 #include <functional>
32
33 using namespace glow;
34 using llvm::cast;
35
/// Parameter type for sweeps with three integer knobs: the name of the backend
/// to compare the Interpreter against, plus a tuple of three ints. This matches
/// the signature produced by passing three generators via a single
/// ::testing::Combine() into
/// GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST().
using ThreeIntTupleConfig = std::tuple<std::string, std::tuple<int, int, int>>;
/// Same as ThreeIntTupleConfig, but for sweeps with four integer knobs.
using FourIntTupleConfig =
    std::tuple<std::string, std::tuple<int, int, int, int>>;

/// Unpack a ThreeIntTupleConfig \p CONFIG into the already-declared variables
/// \p BACKEND_NAME, \p PARAM1, \p PARAM2 and \p PARAM3. \p CONFIG is evaluated
/// exactly once. The expansion is a braced compound statement. It therefore
/// introduces no names into the caller's scope and may be used more than once
/// per scope. A trailing semicolon at the call site is optional.
#define SET_BACKEND_KIND_AND_THREE_INT_PARAMS(CONFIG, BACKEND_NAME, PARAM1,    \
                                              PARAM2, PARAM3)                  \
  {                                                                            \
    const auto &sweepConfig_ = (CONFIG);                                       \
    BACKEND_NAME = std::get<0>(sweepConfig_);                                  \
    std::tie(PARAM1, PARAM2, PARAM3) = std::get<1>(sweepConfig_);              \
  }

/// Unpack a FourIntTupleConfig \p CONFIG into the already-declared variables
/// \p BACKEND_NAME, \p PARAM1, \p PARAM2, \p PARAM3 and \p PARAM4. It has the
/// same single-evaluation and scoping guarantees as
/// SET_BACKEND_KIND_AND_THREE_INT_PARAMS.
#define SET_BACKEND_KIND_AND_FOUR_INT_PARAMS(CONFIG, BACKEND_NAME, PARAM1,     \
                                             PARAM2, PARAM3, PARAM4)           \
  {                                                                            \
    const auto &sweepConfig_ = (CONFIG);                                       \
    BACKEND_NAME = std::get<0>(sweepConfig_);                                  \
    std::tie(PARAM1, PARAM2, PARAM3, PARAM4) = std::get<1>(sweepConfig_);      \
  }
54
55 //===--------------------------------------------------------------------===//
56 // Convolution Parameter Sweep Tests
57 //===--------------------------------------------------------------------===//
58
59 /// Create a simple network that has a single fp convolution.
60 static FunctionTensorPair
createAndInitConvNet(glow::PlaceholderBindings & bindings,glow::ExecutionEngine & EE,dim_t size,dim_t convDepth,dim_t kernel,dim_t stride,dim_t pad)61 createAndInitConvNet(glow::PlaceholderBindings &bindings,
62 glow::ExecutionEngine &EE, dim_t size, dim_t convDepth,
63 dim_t kernel, dim_t stride, dim_t pad) {
64 PseudoRNG PRNG;
65 auto &mod = EE.getModule();
66 Function *F = mod.createFunction("main");
67 auto *var = mod.createPlaceholder(ElemKind::FloatTy,
68 {1, size, size, convDepth}, "var", false);
69 bindings.allocate(var)->getHandle().initXavier(1, PRNG);
70
71 auto *conv =
72 F->createConv(bindings, "conv", var, convDepth, kernel, stride, pad, 1);
73 bindings.get(cast<Placeholder>(conv->getFilter()))->getHandle().clear(0.1);
74 bindings.get(cast<Placeholder>(conv->getBias()))->getHandle().clear(0.1);
75 auto *result = F->createSave("ret", conv);
76 auto *resultTensor = bindings.allocate(result->getPlaceholder());
77 convertPlaceholdersToConstants(F, bindings, {var, result->getPlaceholder()});
78
79 return std::make_pair(F, resultTensor);
80 }
81
82 /// Helper to test sweeping across a variety of configurations of a convolution
83 /// by comparing the results to the Interpreter given some \p allowedError.
84 /// \p config contains the backend to compare the Interpreter against, plus the
85 /// specific configuration to run for this test. \p interpElemKind and \p
86 /// backendElemKind are the element kinds to use for the Interpreter and
87 /// backend, respectively.
testParamSweepConv(ThreeIntTupleConfig config,ElemKind interpElemKind,ElemKind backendElemKind,float allowedError)88 static void testParamSweepConv(ThreeIntTupleConfig config,
89 ElemKind interpElemKind,
90 ElemKind backendElemKind, float allowedError) {
91 std::string backend;
92 size_t size, depth, kernel;
93 SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, size, depth, kernel)
94
95 LOG(INFO) << "Testing Conv with size: " << size << "; depth: " << depth
96 << "; kernel: " << kernel << "\n";
97
98 auto boundF = std::bind(createAndInitConvNet, std::placeholders::_1,
99 std::placeholders::_2, size, depth, kernel,
100 /* stride */ 1, /* pad */ 0);
101 compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
102 allowedError, parCloneCountOpt);
103 }
104
// Parameterized fixture for the convolution sweep. Each test parameter is a
// ThreeIntTupleConfig of {backend name, {size, depth, kernel}}.
DECLARE_STATELESS_BACKEND_TEST(ConvSweepTest, ThreeIntTupleConfig);

// Instantiate ConvSweepTest over the Cartesian product of the values below
// (3 x 2 x 2 = 12 configurations). This presumably happens once per backend
// under test; see GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST.
GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
    SweepTest, ConvSweepTest,
    ::testing::Combine(/* size */ ::testing::Values(5, 7, 15),
                       /* depth */ ::testing::Values(8, 64),
                       /* kernel */ ::testing::Values(1, 3)));
112
113 /// Compare backend against the interpreter in Float.
TEST_P(ConvSweepTest, ConvTest_Float) {
  CHECK_IF_ENABLED();
  // Both the reference and the backend run in Float.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::FloatTy;
  constexpr float kTolerance = 0.0001f;
  testParamSweepConv(GetParam(), kInterpKind, kBackendKind, kTolerance);
}
118
119 /// Compare backend against the interpreter in Int8.
TEST_P(ConvSweepTest, ConvTest_Int8) {
  CHECK_IF_ENABLED();
  // Float reference vs. Int8-quantized backend execution.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::Int8QTy;
  constexpr float kTolerance = 0.045f;
  testParamSweepConv(GetParam(), kInterpKind, kBackendKind, kTolerance);
}
124
125 /// Compare backend against the interpreter in FP16.
TEST_P(ConvSweepTest, ConvTest_Float16) {
  CHECK_IF_ENABLED();
  // Float reference vs. Float16 backend execution.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::Float16Ty;
  constexpr float kTolerance = 0.005f;
  testParamSweepConv(GetParam(), kInterpKind, kBackendKind, kTolerance);
}
131
/// Compare backend against the interpreter in BFloat16.
TEST_P(ConvSweepTest, ConvTest_BFloat16) {
  CHECK_IF_ENABLED();
  // Float reference vs. BFloat16 backend execution.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::BFloat16Ty;
  constexpr float kTolerance = 0.005f;
  testParamSweepConv(GetParam(), kInterpKind, kBackendKind, kTolerance);
}
138
139 //===--------------------------------------------------------------------===//
140 // BatchMatMul Parameter Sweep Tests
141 //===--------------------------------------------------------------------===//
142
143 /// Create a simple network that has a single fp batch mat mul.
144 static FunctionTensorPair
createAndInitBatchMatMulNet(glow::PlaceholderBindings & bindings,glow::ExecutionEngine & EE,dim_t N,dim_t A,dim_t Z,dim_t B)145 createAndInitBatchMatMulNet(glow::PlaceholderBindings &bindings,
146 glow::ExecutionEngine &EE, dim_t N, dim_t A,
147 dim_t Z, dim_t B) {
148 PseudoRNG PRNG;
149 auto &mod = EE.getModule();
150 Function *F = mod.createFunction("main");
151 auto *LHS = mod.createPlaceholder(ElemKind::FloatTy, {N, A, Z}, "LHS", false);
152 auto *RHS = mod.createPlaceholder(ElemKind::FloatTy, {N, Z, B}, "RHS", false);
153 bindings.allocate(LHS)->getHandle().initXavier(10, PRNG);
154 bindings.allocate(RHS)->getHandle().initXavier(10, PRNG);
155
156 auto *R = F->createBatchMatMul("BMM", LHS, RHS);
157
158 auto *save = F->createSave("save", R);
159 auto *resultTensor = bindings.allocate(save->getPlaceholder());
160
161 return std::make_pair(F, resultTensor);
162 }
163
164 /// Helper to test sweeping across a variety of configurations of a BatchMatMul
165 /// by comparing the results to the Interpreter given some \p allowedError.
166 /// \p config contains the backend to compare the Interpreter against, plus the
167 /// specific configuration to run for this test. \p interpElemKind and \p
168 /// backendElemKind are the element kinds to use for the Interpreter and
169 /// backend, respectively.
testParamSweepBatchMatMul(ThreeIntTupleConfig config,ElemKind interpElemKind,ElemKind backendElemKind,float allowedError)170 static void testParamSweepBatchMatMul(ThreeIntTupleConfig config,
171 ElemKind interpElemKind,
172 ElemKind backendElemKind,
173 float allowedError) {
174 std::string backend;
175 size_t N, A, Z;
176 SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, N, A, Z);
177 size_t B = A;
178
179 LOG(INFO) << "\n\tTesting BatchMatMul with N: " << N << "; A: " << A
180 << "; Z: " << Z << "; B: " << B << "\n";
181
182 // Multiplying LHS {N, A, Z} by RHS {N, Z, B} to get result {N, A, B}.
183 auto boundF = std::bind(createAndInitBatchMatMulNet, std::placeholders::_1,
184 std::placeholders::_2, N, A, Z, B);
185 compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
186 allowedError, parCloneCountOpt);
187 }
188
// Parameterized fixture for the BatchMatMul sweep. Each test parameter is a
// ThreeIntTupleConfig of {backend name, {N, A, Z}}.
DECLARE_STATELESS_BACKEND_TEST(BatchMatMulSweepTest, ThreeIntTupleConfig);

// Cartesian product of the values below. ::testing::Range(10, 16) excludes its
// end, so A takes 10..15, giving 4 x 6 x 4 = 96 configurations. This presumably
// happens once per backend under test.
GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
    SweepTest, BatchMatMulSweepTest,
    ::testing::Combine(/* N */ ::testing::Values(1, 4, 16, 24),
                       /* A */ ::testing::Range(10, 16),
                       /* Z */ ::testing::Values(32, 64, 128, 256)));
196
197 /// Compare backend against the interpreter in Float.
TEST_P(BatchMatMulSweepTest, BatchMatMulTest_Float) {
  CHECK_IF_ENABLED();
  // Both the reference and the backend run in Float.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::FloatTy;
  constexpr float kTolerance = 0.0001f;
  testParamSweepBatchMatMul(GetParam(), kInterpKind, kBackendKind,
                            kTolerance);
}
203
204 /// Compare backend against the interpreter in Int8.
TEST_P(BatchMatMulSweepTest, BatchMatMulTest_Int8) {
  CHECK_IF_ENABLED();
  // Float reference vs. Int8-quantized backend execution.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::Int8QTy;
  constexpr float kTolerance = 0.06f;
  testParamSweepBatchMatMul(GetParam(), kInterpKind, kBackendKind,
                            kTolerance);
}
210
211 /// Compare backend against the interpreter in FP16.
TEST_P(BatchMatMulSweepTest, BatchMatMulTest_Float16) {
  CHECK_IF_ENABLED();
  // Float reference vs. Float16 backend execution.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::Float16Ty;
  constexpr float kTolerance = 0.005f;
  testParamSweepBatchMatMul(GetParam(), kInterpKind, kBackendKind,
                            kTolerance);
}
217
/// Compare backend against the interpreter in BFloat16.
TEST_P(BatchMatMulSweepTest, BatchMatMulTest_BFloat16) {
  CHECK_IF_ENABLED();
  // Float reference vs. BFloat16 backend execution.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::BFloat16Ty;
  constexpr float kTolerance = 0.005f;
  testParamSweepBatchMatMul(GetParam(), kInterpKind, kBackendKind,
                            kTolerance);
}
224
225 //===--------------------------------------------------------------------===//
226 // FullyConnected Parameter Sweep Tests
227 //===--------------------------------------------------------------------===//
228
229 /// Create a simple network that has a single fp FC.
230 static FunctionTensorPair
createAndInitFCNet(glow::PlaceholderBindings & bindings,glow::ExecutionEngine & EE,dim_t A,dim_t Z,dim_t B)231 createAndInitFCNet(glow::PlaceholderBindings &bindings,
232 glow::ExecutionEngine &EE, dim_t A, dim_t Z, dim_t B) {
233 PseudoRNG PRNG;
234 auto &mod = EE.getModule();
235 Function *F = mod.createFunction("main");
236 auto *IP = mod.createPlaceholder(ElemKind::FloatTy, {A, Z}, "input", false);
237 auto *WC = mod.createConstant(ElemKind::FloatTy, {Z, B}, "weights");
238 auto *BC = mod.createConstant(ElemKind::FloatTy, {B}, "bias");
239 bindings.allocate(IP)->getHandle().randomize(-0.2, 0.2, mod.getPRNG());
240 BC->getPayloadMutable().getHandle().randomize(0, 0.000005, mod.getPRNG());
241 WC->getPayloadMutable().getHandle().randomize(-0.4, 0.4, mod.getPRNG());
242
243 auto *FC = F->createFullyConnected("FC", IP, WC, BC);
244 auto *save = F->createSave("save", FC);
245 auto *resultTensor = bindings.allocate(save->getPlaceholder());
246
247 return std::make_pair(F, resultTensor);
248 }
249
250 /// Helper to test sweeping across a variety of configurations of a FC by
251 /// comparing the results to the Interpreter given some \p allowedError.
252 /// \p config contains the backend to compare the Interpreter against, plus the
253 /// specific configuration to run for this test. \p interpElemKind and \p
254 /// backendElemKind are the element kinds to use for the Interpreter and
255 /// backend, respectively.
testParamSweepFC(ThreeIntTupleConfig config,ElemKind interpElemKind,ElemKind backendElemKind,float allowedError)256 static void testParamSweepFC(ThreeIntTupleConfig config,
257 ElemKind interpElemKind, ElemKind backendElemKind,
258 float allowedError) {
259 std::string backend;
260 size_t A, Z, B;
261 SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, A, Z, B);
262
263 LOG(INFO) << "\n\tTesting FC with A: " << A << "; Z: " << Z << "; B: " << B
264 << "\n";
265
266 auto boundF = std::bind(createAndInitFCNet, std::placeholders::_1,
267 std::placeholders::_2, A, Z, B);
268 compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
269 allowedError, parCloneCountOpt);
270 }
271
// Parameterized fixture for the FullyConnected sweep. Each test parameter is a
// ThreeIntTupleConfig of {backend name, {A, Z, B}}.
DECLARE_STATELESS_BACKEND_TEST(FCSweepTest, ThreeIntTupleConfig);

// Cartesian product of the values below (4 x 7 x 5 = 140 configurations).
// This presumably happens once per backend under test.
GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
    SweepTest, FCSweepTest,
    ::testing::Combine(
        /* A */ ::testing::Values(1, 4, 16, 64),
        /* Z */ ::testing::Values(16, 128, 256, 512, 1024, 2048, 4096),
        /* B */ ::testing::Values(1, 48, 64, 256, 1024)));
280
281 /// Compare backend against the interpreter in Float.
TEST_P(FCSweepTest, FCTest_Float) {
  CHECK_IF_ENABLED();
  // Both the reference and the backend run in Float.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::FloatTy;
  constexpr float kTolerance = 0.0001f;
  testParamSweepFC(GetParam(), kInterpKind, kBackendKind, kTolerance);
}
286
287 /// Compare backend against the interpreter in Int8.
TEST_P(FCSweepTest, FCTest_Int8) {
  CHECK_IF_ENABLED();
  // Float reference vs. Int8-quantized backend execution.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::Int8QTy;
  constexpr float kTolerance = 0.065f;
  testParamSweepFC(GetParam(), kInterpKind, kBackendKind, kTolerance);
}
292
293 /// Compare backend against the interpreter in FP16.
TEST_P(FCSweepTest, FCTest_Float16) {
  CHECK_IF_ENABLED();
  // Float reference vs. Float16 backend execution.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::Float16Ty;
  constexpr float kTolerance = 0.005f;
  testParamSweepFC(GetParam(), kInterpKind, kBackendKind, kTolerance);
}
298
299 /// Compare backend against the interpreter in BFloat16.
TEST_P(FCSweepTest, FCTest_BFloat16) {
  CHECK_IF_ENABLED();
  // Float reference vs. BFloat16 backend execution.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::BFloat16Ty;
  constexpr float kTolerance = 0.005f;
  testParamSweepFC(GetParam(), kInterpKind, kBackendKind, kTolerance);
}
304
305 //===--------------------------------------------------------------------===//
306 // Concat Parameter Sweep Tests
307 //===--------------------------------------------------------------------===//
308
309 /// Create a simple network that has a single fp Concat.
310 static FunctionTensorPair
createAndInitConcatNet(glow::PlaceholderBindings & bindings,glow::ExecutionEngine & EE,size_t numInputs,size_t numDims,size_t maxLength,size_t axis)311 createAndInitConcatNet(glow::PlaceholderBindings &bindings,
312 glow::ExecutionEngine &EE, size_t numInputs,
313 size_t numDims, size_t maxLength, size_t axis) {
314 PseudoRNG PRNG;
315 auto &mod = EE.getModule();
316 Function *F = mod.createFunction("main");
317
318 // Make leading dimensions smaller than trailing. Reduces size of tests and is
319 // also in line with typical tests.
320 std::vector<dim_t> dims(numDims, maxLength);
321 for (size_t i = 0; i < numDims; i++) {
322 dims[numDims - 1 - i] /= std::pow(2, i);
323 }
324
325 std::vector<NodeValue> inputs(numInputs);
326 for (size_t i = 0; i < numInputs; i++) {
327 auto *IP = mod.createPlaceholder(ElemKind::FloatTy, dims, "input", false);
328 bindings.allocate(IP)->getHandle().randomize(-0.2, 0.2, mod.getPRNG());
329 assert(IP);
330 inputs[i] = IP->getOutput();
331 }
332
333 auto *concat = F->createConcat("concat", inputs, axis);
334 auto *save = F->createSave("save", concat);
335 auto *resultTensor = bindings.allocate(save->getPlaceholder());
336
337 return std::make_pair(F, resultTensor);
338 }
339
340 /// Helper to test sweeping across a variety of configurations of a Concat by
341 /// comparing the results to the Interpreter given some \p allowedError.
342 /// \p config contains the backend to compare the Interpreter against, plus the
343 /// specific configuration to run for this test. \p interpElemKind and \p
344 /// backendElemKind are the element kinds to use for the Interpreter and
345 /// backend, respectively.
testParamSweepConcat(FourIntTupleConfig config,ElemKind interpElemKind,ElemKind backendElemKind,float allowedError)346 static void testParamSweepConcat(FourIntTupleConfig config,
347 ElemKind interpElemKind,
348 ElemKind backendElemKind, float allowedError) {
349 std::string backend;
350 size_t numInputs, numDims, maxLength, axis;
351 SET_BACKEND_KIND_AND_FOUR_INT_PARAMS(config, backend, numInputs, numDims,
352 maxLength, axis);
353 // Exit if axis outside of numDims.
354 if (axis >= numDims) {
355 return;
356 }
357
358 LOG(INFO) << "\n\tTesting Concat with numInputs: " << numInputs
359 << "; numDims: " << numDims << "; maxLength: " << maxLength
360 << "; axis: " << axis << "\n";
361
362 auto boundF =
363 std::bind(createAndInitConcatNet, std::placeholders::_1,
364 std::placeholders::_2, numInputs, numDims, maxLength, axis);
365 compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
366 allowedError, parCloneCountOpt);
367 }
368
// Parameterized fixture for the Concat sweep. Each test parameter is a
// FourIntTupleConfig of {backend name, {numInputs, numDims, maxLength, axis}}.
DECLARE_STATELESS_BACKEND_TEST(ConcatSweepTest, FourIntTupleConfig);

// Cartesian product of the values below. ::testing::Range excludes its end, so
// numDims takes 1..3 and axis takes 0..2, giving 10 x 3 x 4 x 3 = 360
// configurations. Combinations with axis >= numDims return early in
// testParamSweepConcat.
GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
    SweepTest, ConcatSweepTest,
    ::testing::Combine(/* numInputs */ ::testing::Values(1, 2, 4, 8, 16, 32, 64,
                                                         128, 192, 256),
                       /* numDims */ ::testing::Range(1, 4),
                       /* maxLength */ ::testing::Values(16, 32, 64, 128),
                       /* axis */ ::testing::Range(0, 3)));
378
379 /// Compare backend against the interpreter in Float.
TEST_P(ConcatSweepTest, ConcatTest_Float) {
  CHECK_IF_ENABLED();
  // Pure data movement in Float on both sides: zero tolerance.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::FloatTy;
  constexpr float kTolerance = 0.0f;
  testParamSweepConcat(GetParam(), kInterpKind, kBackendKind, kTolerance);
}
384
385 /// Compare backend against the interpreter in Int8. Note that we do not use the
386 /// same ElemKind for the Interpreter; this is because the backend will
387 /// quantize/dequantize the input/result anyway, so the comparison wouldn't be
388 /// purely on data movement.
TEST_P(ConcatSweepTest, ConcatTest_Int8) {
  CHECK_IF_ENABLED();
  // Float reference vs. Int8-quantized backend execution.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::Int8QTy;
  constexpr float kTolerance = 0.002f;
  testParamSweepConcat(GetParam(), kInterpKind, kBackendKind, kTolerance);
}
394
395 /// Compare backend against the interpreter in Float16. Note that we do not use
396 /// the same ElemKind for the Interpreter; this is because the backend will
397 /// down/up convert the input/result anyway, so the comparison wouldn't be
398 /// purely on data movement.
TEST_P(ConcatSweepTest, ConcatTest_Float16) {
  CHECK_IF_ENABLED();
  // Float reference vs. Float16 backend execution.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::Float16Ty;
  constexpr float kTolerance = 0.0001f;
  testParamSweepConcat(GetParam(), kInterpKind, kBackendKind, kTolerance);
}
404
405 /// Compare backend against the interpreter in BFloat16. Note that we do not use
406 /// the same ElemKind for the Interpreter; this is because the backend will
407 /// down/up convert the input/result anyway, so the comparison wouldn't be
408 /// purely on data movement.
TEST_P(ConcatSweepTest, ConcatTest_BFloat16) {
  CHECK_IF_ENABLED();
  // Float reference vs. BFloat16 backend execution.
  constexpr ElemKind kInterpKind = ElemKind::FloatTy;
  constexpr ElemKind kBackendKind = ElemKind::BFloat16Ty;
  constexpr float kTolerance = 0.0001f;
  testParamSweepConcat(GetParam(), kInterpKind, kBackendKind, kTolerance);
}
414
415 //===--------------------------------------------------------------------===//
416 // SLWS Parameter Sweep Tests
417 //===--------------------------------------------------------------------===//
418
419 /// Create a simple network that has a single fp SLWS.
420 static FunctionTensorPair
createAndInitSLWSNet(glow::PlaceholderBindings & bindings,glow::ExecutionEngine & EE,dim_t embeddingRows,dim_t embeddingDim,dim_t numLengths,bool rowwiseQuantize,bool fused,bool FP16,bool accumFP16)421 createAndInitSLWSNet(glow::PlaceholderBindings &bindings,
422 glow::ExecutionEngine &EE, dim_t embeddingRows,
423 dim_t embeddingDim, dim_t numLengths, bool rowwiseQuantize,
424 bool fused, bool FP16, bool accumFP16) {
425 PseudoRNG PRNG;
426 auto &mod = EE.getModule();
427 Function *F = mod.createFunction("main");
428
429 // Initialize lengths according to the number provided by the test. Note that
430 // we arbitrarily set them between [80,120].
431 auto *lengths =
432 mod.createPlaceholder(ElemKind::Int32ITy, {numLengths}, "lengths", false);
433 auto LH = bindings.allocate(lengths)->getHandle<int32_t>();
434 LH.randomize(80, 120, mod.getPRNG());
435
436 // Get the sum of the lengths to then use as the size for indices and weights.
437 dim_t sumOfLengths = 0;
438 for (const int32_t &e : LH) {
439 sumOfLengths += e;
440 }
441
442 // Initialize indices to size of sum of lengths. Randomly set them to point
443 // somewhere inside the embedding.
444 auto *indices = mod.createPlaceholder(ElemKind::Int64ITy, {sumOfLengths},
445 "indices", false);
446 bindings.allocate(indices)->getHandle<int64_t>().randomize(
447 0, embeddingRows - 1, mod.getPRNG());
448
449 // Xavier initialize the weights with the correct data type.
450 Constant *weights;
451 if (FP16) {
452 weights =
453 mod.createConstant(ElemKind::Float16Ty, {sumOfLengths}, "weights");
454 weights->getPayloadMutable().getHandle<float16_t>().initXavier(
455 weights->getType()->size() * 2, mod.getPRNG());
456 } else {
457 weights = mod.createConstant(ElemKind::FloatTy, {sumOfLengths}, "weights");
458 weights->getPayloadMutable().getHandle<float>().initXavier(
459 weights->getType()->size() * 2, mod.getPRNG());
460 }
461
462 // Create the embedding; non-RWQ versions will simply create a Constant with
463 // it, while RWQ versions will use its data to create a RWQ Constant
464 // internally in the Node constructor.
465 Tensor embeddingT(ElemKind::FloatTy, {embeddingRows, embeddingDim});
466 embeddingT.getHandle().initXavier(embeddingT.size() * 2, mod.getPRNG());
467
468 // Create the SLWS based on provided options.
469 Node *SLWS;
470 if (!rowwiseQuantize) {
471 auto *embeddingC = mod.createConstant("embedding", std::move(embeddingT));
472 SLWS = F->createSparseLengthsWeightedSum("SLWS", embeddingC, weights,
473 indices, lengths);
474 } else {
475 if (fused) {
476 const ElemKind precision =
477 FP16 ? ElemKind::UInt8FusedFP16QTy : ElemKind::UInt8FusedQTy;
478 SLWS = F->createFusedRowwiseQuantizedSparseLengthsWeightedSum(
479 "FRQSLWS", embeddingT, weights, indices, lengths, precision,
480 accumFP16);
481 } else {
482 const ElemKind precision = FP16 ? ElemKind::Float16Ty : ElemKind::FloatTy;
483 SLWS = F->createRowwiseQuantizedSparseLengthsWeightedSum(
484 "RQSLWS", embeddingT, weights, indices, lengths,
485 quantization::Schema::Asymmetric, precision, accumFP16);
486 }
487 }
488 auto *save = F->createSave("save", SLWS);
489 auto *resultTensor = bindings.allocate(save->getPlaceholder());
490
491 return std::make_pair(F, resultTensor);
492 }
493
494 /// Helper to test sweeping across a variety of configurations of a SLWS by
495 /// comparing the results to the Interpreter given some \p allowedError.
496 /// \p config contains the backend to compare the Interpreter against, plus the
497 /// specific configuration to run for this test. \p interpElemKind and \p
498 /// backendElemKind are the element kinds to use for the Interpreter and
499 /// backend, respectively. Pass in options for the test \p rowwiseQuantize,
500 /// \p fused, \p FP16, and \p accumFP16.
testParamSweepSLWS(ThreeIntTupleConfig config,ElemKind interpElemKind,ElemKind backendElemKind,float allowedError,bool rowwiseQuantize,bool fused,bool FP16,bool accumFP16)501 static void testParamSweepSLWS(ThreeIntTupleConfig config,
502 ElemKind interpElemKind,
503 ElemKind backendElemKind, float allowedError,
504 bool rowwiseQuantize, bool fused, bool FP16,
505 bool accumFP16) {
506 std::string backend;
507 size_t embeddingRows, embeddingDim, numLengths;
508 SET_BACKEND_KIND_AND_THREE_INT_PARAMS(config, backend, embeddingRows,
509 embeddingDim, numLengths);
510
511 LOG(INFO) << "\n\tTesting SLWS with embeddingRows: " << embeddingRows
512 << "; embeddingDim: " << embeddingDim
513 << "; numLengths: " << numLengths << "\n";
514
515 auto boundF = std::bind(createAndInitSLWSNet, std::placeholders::_1,
516 std::placeholders::_2, embeddingRows, embeddingDim,
517 numLengths, rowwiseQuantize, fused, FP16, accumFP16);
518 compareAgainstInterpreter(backend, boundF, interpElemKind, backendElemKind,
519 allowedError, parCloneCountOpt);
520 }
521
// Parameterized fixture for the SLWS sweep. Each test parameter is a
// ThreeIntTupleConfig of {backend name, {embeddingRows, embeddingDim,
// numLengths}}.
DECLARE_STATELESS_BACKEND_TEST(SLWSSweepTest, ThreeIntTupleConfig);

// Cartesian product of the values below (4 x 4 x 5 = 80 configurations).
// This presumably happens once per backend under test.
GLOW_INSTANTIATE_TEST_SUITE_P_FOR_BACKEND_COMBINED_TEST(
    SweepTest, SLWSSweepTest,
    ::testing::Combine(
        /* embeddingRows */ ::testing::Values(100, 1000, 10000, 100000),
        /* embeddingDim */ ::testing::Values(32, 64, 96, 128),
        /* numLengths */ ::testing::Values(16, 32, 64, 128, 256)));
530
531 /// Compare backend against the interpreter.
TEST_P(SLWSSweepTest, SLWS_Float) {
  CHECK_IF_ENABLED();
  // Plain (non-quantized) SLWS with Float weights.
  constexpr float kTolerance = 0.000001f;
  constexpr bool kRowwiseQuantize = false;
  constexpr bool kFused = false;
  constexpr bool kFP16 = false;
  constexpr bool kAccumFP16 = false;
  testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
                     kTolerance, kRowwiseQuantize, kFused, kFP16, kAccumFP16);
}
540
541 /// Compare backend against the interpreter in Float.
TEST_P(SLWSSweepTest, RWQSLWS_Float) {
  CHECK_IF_ENABLED();
  // Non-fused rowwise-quantized SLWS with Float weights and precision.
  constexpr float kTolerance = 0.000001f;
  constexpr bool kRowwiseQuantize = true;
  constexpr bool kFused = false;
  constexpr bool kFP16 = false;
  constexpr bool kAccumFP16 = false;
  testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
                     kTolerance, kRowwiseQuantize, kFused, kFP16, kAccumFP16);
}
550
551 /// Compare backend against the interpreter in Float.
TEST_P(SLWSSweepTest, FRWQSLWS_Float) {
  CHECK_IF_ENABLED();
  // Fused rowwise-quantized SLWS (UInt8FusedQTy) with Float weights.
  constexpr float kTolerance = 0.000001f;
  constexpr bool kRowwiseQuantize = true;
  constexpr bool kFused = true;
  constexpr bool kFP16 = false;
  constexpr bool kAccumFP16 = false;
  testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
                     kTolerance, kRowwiseQuantize, kFused, kFP16, kAccumFP16);
}
560
/// Compare backend against the interpreter in Float, using the non-fused
/// rowwise-quantized SLWS with FP16 weights and FP16 precision.
TEST_P(SLWSSweepTest, RWQSLWS_Float16) {
  // Note: not currently enabled for any open-source backends, as only the
  // Interpreter supports this.
  CHECK_IF_ENABLED();
  constexpr float kTolerance = 0.000001f;
  constexpr bool kRowwiseQuantize = true;
  constexpr bool kFused = false;
  constexpr bool kFP16 = true;
  constexpr bool kAccumFP16 = false;
  testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
                     kTolerance, kRowwiseQuantize, kFused, kFP16, kAccumFP16);
}
572
/// Compare backend against the interpreter in Float, using the fused
/// rowwise-quantized SLWS (UInt8FusedFP16QTy) with FP16 weights.
TEST_P(SLWSSweepTest, FRWQSLWS_Float16) {
  // Note: not currently enabled for any open-source backends, as only the
  // Interpreter supports this.
  CHECK_IF_ENABLED();
  constexpr float kTolerance = 0.000001f;
  constexpr bool kRowwiseQuantize = true;
  constexpr bool kFused = true;
  constexpr bool kFP16 = true;
  constexpr bool kAccumFP16 = false;
  testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
                     kTolerance, kRowwiseQuantize, kFused, kFP16, kAccumFP16);
}
584
/// Compare backend against the interpreter in Float, using the non-fused
/// rowwise-quantized SLWS with FP16 weights, FP16 precision, and accumFP16 set.
TEST_P(SLWSSweepTest, RWQSLWS_Float16_AccumFloat16) {
  // Note: not currently enabled for any open-source backends, as only the
  // Interpreter supports this.
  CHECK_IF_ENABLED();
  constexpr float kTolerance = 0.000001f;
  constexpr bool kRowwiseQuantize = true;
  constexpr bool kFused = false;
  constexpr bool kFP16 = true;
  constexpr bool kAccumFP16 = true;
  testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
                     kTolerance, kRowwiseQuantize, kFused, kFP16, kAccumFP16);
}
596
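/// Compare backend against the interpreter in Float, using the fused
/// rowwise-quantized SLWS (UInt8FusedFP16QTy) with FP16 weights and accumFP16
/// set.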
TEST_P(SLWSSweepTest, FRWQSLWS_Float16_AccumFloat16) {
  // Note: not currently enabled for any open-source backends, as only the
  // Interpreter supports this.
  CHECK_IF_ENABLED();
  constexpr float kTolerance = 0.000001f;
  constexpr bool kRowwiseQuantize = true;
  constexpr bool kFused = true;
  constexpr bool kFP16 = true;
  constexpr bool kAccumFP16 = true;
  testParamSweepSLWS(GetParam(), ElemKind::FloatTy, ElemKind::FloatTy,
                     kTolerance, kRowwiseQuantize, kFused, kFP16, kAccumFP16);
}
608
main(int argc,char ** argv)609 int main(int argc, char **argv) {
610 ::testing::InitGoogleTest(&argc, argv);
611 llvm::sys::PrintStackTraceOnErrorSignal(argv[0]);
612 llvm::cl::ParseCommandLineOptions(argc, argv);
613 return RUN_ALL_TESTS();
614 }
615