1 /**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "glow/Graph/Graph.h"
18 #include "BackendTestUtils.h"
19 #include "glow/ExecutionEngine/ExecutionEngine.h"
20 #include "glow/Graph/Hook.h"
21 #include "glow/Graph/Node.h"
22 #include "glow/Graph/Nodes.h"
23 #include "glow/Graph/Utils.h"
24 #include "glow/IR/IR.h"
25 #include "glow/IR/Instrs.h"
26 #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
27
28 #include "llvm/ADT/SmallPtrSet.h"
29 #include "llvm/Support/FileSystem.h"
30
31 #include "gtest/gtest.h"
32
33 using namespace glow;
34
35 // Helper to find a node in the Function by name
nodeByName(const Function * F,const std::string & name)36 static const Node *nodeByName(const Function *F, const std::string &name) {
37 for (auto &n : F->getNodes()) {
38 if (n.getName().str() == name) {
39 return &n;
40 }
41 }
42 return nullptr;
43 }
44
45 /// Mock backend that does lower FC nodes.
46 class MockBackendNoLowerConv3D : public MockBackend {
shouldLower(const Node * N) const47 bool shouldLower(const Node *N) const override {
48 if (N->getKind() == Kinded::Kind::Convolution3DNodeKind) {
49 return false;
50 } else {
51 return true;
52 }
53 }
54 };
55
/// Check that creating and then erasing a constant updates the module's
/// constant list, and that its reported size always matches the number of
/// elements reachable by iterating over it.
TEST(Graph, testVariableErasure) {
  Module mod;
  auto &constants = mod.getConstants();

  // Both the reported size and a manual walk must agree with \p expected.
  auto checkCount = [&constants](size_t expected) {
    EXPECT_EQ(constants.size(), expected);
    EXPECT_EQ(std::distance(constants.begin(), constants.end()),
              constants.size());
  };

  checkCount(0);

  Constant *dummy = mod.createConstant(ElemKind::FloatTy, {1, 1}, "dummy");
  checkCount(1);

  mod.eraseConstant(dummy);
  checkCount(0);
}
70
71 /// Check that the clear method completely reset a module.
TEST(Graph,clear)72 TEST(Graph, clear) {
73 Module M;
74
75 // Check that the module is initially empty.
76 EXPECT_EQ(M.getConstants().size(), 0);
77 EXPECT_EQ(M.getPlaceholders().size(), 0);
78 EXPECT_EQ(M.getFunctions().size(), 0);
79
80 // Create a few things.
81 M.createFunction("main");
82 M.createPlaceholder(ElemKind::FloatTy, {1}, "placeholder", true);
83 M.createConstant(ElemKind::FloatTy, {1}, "var");
84
85 EXPECT_EQ(M.getConstants().size(), 1);
86 EXPECT_EQ(M.getPlaceholders().size(), 1);
87 EXPECT_EQ(M.getFunctions().size(), 1);
88
89 // Check that clearing the module makes it completely free of any kind of
90 // objects.
91 M.clear();
92 EXPECT_EQ(M.getConstants().size(), 0);
93 EXPECT_EQ(M.getPlaceholders().size(), 0);
94 EXPECT_EQ(M.getFunctions().size(), 0);
95 }
96
97 /// Check that the clear method works as expected.
TEST(Graph,clearFunctions)98 TEST(Graph, clearFunctions) {
99 Module M;
100
101 // Check that the module is initially empty.
102 EXPECT_EQ(M.getConstants().size(), 0);
103 EXPECT_EQ(M.getPlaceholders().size(), 0);
104 EXPECT_EQ(M.getFunctions().size(), 0);
105
106 // Create a few things.
107 Function *F = M.createFunction("main");
108 auto *PH = M.createPlaceholder(ElemKind::FloatTy, {1}, "placeholder", true);
109 auto *C = M.createConstant(ElemKind::FloatTy, {1}, "var");
110 auto *AN = F->createAdd("add", PH, C);
111 F->createSave("save", AN);
112
113 EXPECT_EQ(M.getConstants().size(), 1);
114 EXPECT_EQ(M.getPlaceholders().size(), 2); // Input PH and PH for Save
115 EXPECT_EQ(M.getFunctions().size(), 1);
116 EXPECT_EQ(F->getNodes().size(), 2); // Add, Save
117
118 M.clearFunctions();
119 EXPECT_EQ(M.getConstants().size(), 1);
120 EXPECT_EQ(M.getPlaceholders().size(), 2);
121 ASSERT_EQ(M.getFunctions().size(), 1);
122 // Same Function ptr should exist, just nothing left in them.
123 EXPECT_EQ(*M.getFunctions().begin(), F);
124 EXPECT_EQ(F->getNodes().size(), 0);
125 }
126
127 /// Test the graph nodes names and utilities.
TEST(Graph,testGraphNames)128 TEST(Graph, testGraphNames) {
129 Module MD;
130 Function *F = MD.createFunction("F");
131
132 Node *op1 = MD.createPlaceholder(ElemKind::FloatTy, {1, 10}, "op1",
133 false /*isTrainable*/);
134 Node *op2 = MD.createConstant(ElemKind::FloatTy, {1, 10}, "op2");
135 Node *add = F->createAdd("add", op1, op2);
136 auto *top = F->createTopK("top", add, 5);
137 Node *save = F->createSave("out", top->getValues());
138
139 EXPECT_TRUE(MD.getPlaceholderByNameSlow("op1"));
140 EXPECT_TRUE(MD.getConstantByName("op2"));
141 EXPECT_TRUE(F->getNodeByName("add"));
142 EXPECT_TRUE(F->getNodeByName("top"));
143 EXPECT_TRUE(F->getNodeByName("out_save"));
144
145 NodeValue op1Res = op1->getNthResult(0);
146 NodeValue op2Res = op2->getNthResult(0);
147 NodeValue addRes = add->getNthResult(0);
148 EXPECT_TRUE(top->getNumResults() == 2);
149 NodeValue topValRes = top->getNthResult(0);
150 NodeValue topIndRes = top->getNthResult(1);
151
152 auto op1ResName =
153 op1Res.generateNodeOutputName(false /*stripResNoFor0thInput*/);
154 auto op2ResName =
155 op2Res.generateNodeOutputName(false /*stripResNoFor0thInput*/);
156 auto addResName =
157 addRes.generateNodeOutputName(true /*stripResNoFor0thInput*/);
158 auto topValResName =
159 topValRes.generateNodeOutputName(false /*stripResNoFor0thInput*/);
160 auto topIndResName =
161 topIndRes.generateNodeOutputName(false /*stripResNoFor0thInput*/);
162
163 EXPECT_EQ(op1ResName, "op1:0");
164 EXPECT_EQ(op2ResName, "op2:0");
165 EXPECT_EQ(addResName, "add");
166 EXPECT_EQ(topValResName, "top:0");
167 EXPECT_EQ(topIndResName, "top:1");
168
169 EXPECT_EQ(F->getNodeValueByName(op1ResName), op1Res);
170 EXPECT_EQ(F->getNodeValueByName(op2ResName), op2Res);
171 EXPECT_EQ(F->getNodeValueByName(addResName), addRes);
172 EXPECT_EQ(F->getNodeValueByName(topValResName), topValRes);
173 EXPECT_EQ(F->getNodeValueByName(topIndResName), topIndRes);
174
175 EXPECT_EQ(F->getNodeValueByName("op1"), op1Res);
176 EXPECT_EQ(F->getNodeValueByName("op2"), op2Res);
177 EXPECT_EQ(F->getNodeValueByName("add:0"), addRes);
178
179 // Verify the node value is invalid for the SaveNode which has no outputs.
180 EXPECT_EQ(F->getNodeValueByName(save->getName()).getNode(), nullptr);
181 }
182
183 /// Check node names.
TEST(Graph,testNodeNames)184 TEST(Graph, testNodeNames) {
185 Module MD;
186 Function *F = MD.createFunction("F");
187 IRFunction M(F);
188 PlaceholderBindings bindings;
189 Node *K =
190 MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
191 Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
192
193 K = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, 1);
194 K = F->createRELU("Relu", K);
195 K = F->createSoftMax("SoftMax", K, S);
196 F->createSave("Save", K);
197 F->dump();
198 auto filePath = F->dumpDAG();
199 auto backend = MockBackend();
200 CompilationContext cctx;
201 lower(F, cctx, &backend);
202 ::optimize(F, CompilationMode::Train);
203 M.generateIR(backend);
204 M.dump();
205 EXPECT_GT(M.getInstrs().size(), 0);
206 llvm::sys::fs::remove(filePath);
207 }
208
209 /// Check that a createConv3D can be run.
TEST(Graph,simpleTestConv3D)210 TEST(Graph, simpleTestConv3D) {
211 Module MD;
212 Function *F = MD.createFunction("F");
213 IRFunction M(F);
214 PlaceholderBindings bindings;
215 Node *K = MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 100, 3},
216 "input", true);
217 K = F->createConv3D(bindings, /* name */ "Conv3D", /* input */ K,
218 /* outChannels */ 16, /* kernel */ 3, /* stride */ 2,
219 /* pad */ 3, /* group */ 1);
220 K = F->createRELU("Relu", K);
221 F->createSave("Save", K);
222 F->dump();
223 auto filePath = F->dumpDAG();
224 auto backend = MockBackend();
225 CompilationContext cctx;
226 lower(F, cctx, &backend);
227 ::optimize(F, CompilationMode::Train);
228 M.generateIR(backend);
229 M.dump();
230 EXPECT_GT(M.getInstrs().size(), 0);
231 llvm::sys::fs::remove(filePath);
232 }
233
234 /// Tests custom lowering from Node to Instruction IR
TEST(Graph,simpleTestConvCustomLower)235 TEST(Graph, simpleTestConvCustomLower) {
236 Module MD;
237 Function *F = MD.createFunction("F");
238 IRFunction M(F);
239 PlaceholderBindings bindings;
240 Node *K =
241 MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
242 Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
243
244 K = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, 1);
245 K = F->createRELU("Relu", K);
246 K = F->createSoftMax("SoftMax", K, S);
247 F->createSave("Save", K);
248 F->dump();
249 auto filePath = F->dumpDAG();
250 auto backend = MockBackendCustomIRGen();
251 CompilationContext cctx;
252 lower(F, cctx, &backend);
253 ::optimize(F, CompilationMode::Train);
254 M.generateIR(MockBackendCustomIRGen());
255 M.dump();
256 auto &instrList = M.getInstrs();
257 bool customHappened = false;
258 for (auto begin = instrList.begin(); begin != instrList.end(); ++begin) {
259 if (begin->getName().equals("CustomConvolutionInstruction")) {
260 customHappened = true;
261 break;
262 }
263 }
264
265 EXPECT_EQ(customHappened, true);
266 llvm::sys::fs::remove(filePath);
267 }
268
269 /// Check that we can create convolution with float16.
TEST(Graph,float16Conv)270 TEST(Graph, float16Conv) {
271 Module MD;
272 Function *F = MD.createFunction("F");
273 PlaceholderBindings bindings;
274 Node *K = MD.createConstant(ElemKind::Float16Ty, {4, 320, 200, 3}, "input");
275
276 auto *conv = F->createConv(bindings, "Conv", K, 16, 3, 2, 3, 1);
277 F->createSave("Save", conv);
278 EXPECT_TRUE(conv->verify());
279 EXPECT_EQ(conv->getResult().getElementType(), ElemKind::Float16Ty);
280 EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::Float16Ty);
281 EXPECT_EQ(conv->getBias().getElementType(), ElemKind::Float16Ty);
282
283 auto backend = MockBackend();
284 CompilationContext cctx;
285 lower(F, cctx, &backend);
286
287 IRFunction M(F);
288
289 M.generateIR(backend);
290 EXPECT_GT(M.getInstrs().size(), 0);
291 auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
292 [](const Instruction &inst) -> bool {
293 return llvm::isa<ConvolutionInst>(inst);
294 });
295 ASSERT_TRUE(convIt != M.getInstrs().end());
296 const auto *convInst = llvm::cast<ConvolutionInst>(&*convIt);
297 EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::Float16Ty);
298 EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::Float16Ty);
299 EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::Float16Ty);
300 }
301
302 /// Check that we can create conv3D with float16.
TEST(Graph,float16Conv3DLower)303 TEST(Graph, float16Conv3DLower) {
304 Module MD;
305 Function *F = MD.createFunction("F");
306 PlaceholderBindings bindings;
307 Node *K =
308 MD.createConstant(ElemKind::Float16Ty, {4, 320, 200, 200, 3}, "input");
309
310 auto *conv = F->createConv3D(bindings, "Conv3D", K, 16, 3, 2, 3, 1);
311 F->createSave("Save", conv);
312 EXPECT_TRUE(conv->verify());
313 EXPECT_EQ(conv->getResult().getElementType(), ElemKind::Float16Ty);
314 EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::Float16Ty);
315 EXPECT_EQ(conv->getBias().getElementType(), ElemKind::Float16Ty);
316
317 auto backend = MockBackend();
318 CompilationContext cctx;
319 lower(F, cctx, &backend);
320
321 IRFunction M(F);
322
323 M.generateIR(backend);
324 EXPECT_GT(M.getInstrs().size(), 0);
325 auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
326 [](const Instruction &inst) -> bool {
327 return llvm::isa<Convolution3DInst>(inst);
328 });
329 ASSERT_TRUE(convIt == M.getInstrs().end());
330 }
331
332 /// Check that we can create conv3D with float16.
TEST(Graph,float16Conv3DNoLower)333 TEST(Graph, float16Conv3DNoLower) {
334 Module MD;
335 Function *F = MD.createFunction("F");
336 PlaceholderBindings bindings;
337 Node *K =
338 MD.createConstant(ElemKind::Float16Ty, {4, 320, 200, 200, 3}, "input");
339
340 auto *conv = F->createConv3D(bindings, "Conv3D", K, 16, 3, 2, 3, 1);
341 F->createSave("Save", conv);
342 EXPECT_TRUE(conv->verify());
343 EXPECT_EQ(conv->getResult().getElementType(), ElemKind::Float16Ty);
344 EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::Float16Ty);
345 EXPECT_EQ(conv->getBias().getElementType(), ElemKind::Float16Ty);
346
347 auto backend = MockBackendNoLowerConv3D();
348 CompilationContext cctx;
349 lower(F, cctx, &backend);
350
351 IRFunction M(F);
352
353 M.generateIR(backend);
354 EXPECT_GT(M.getInstrs().size(), 0);
355 auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
356 [](const Instruction &inst) -> bool {
357 return llvm::isa<Convolution3DInst>(inst);
358 });
359 ASSERT_TRUE(convIt != M.getInstrs().end());
360 const auto *convInst = llvm::cast<Convolution3DInst>(&*convIt);
361 EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::Float16Ty);
362 EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::Float16Ty);
363 EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::Float16Ty);
364 }
365
366 /// Check that we can create batchNorm with float16.
TEST(Graph,float16BatchNorm)367 TEST(Graph, float16BatchNorm) {
368 Module MD;
369 Function *F = MD.createFunction("F");
370 PlaceholderBindings bindings;
371 auto *input =
372 MD.createPlaceholder(ElemKind::Float16Ty, {1, 10, 20, 3}, "input", false);
373 BatchNormalizationNode *BN =
374 F->createBatchNormalization(bindings, "batch", input, 3, 0.0001, 0.9);
375
376 EXPECT_TRUE(BN->verify());
377 EXPECT_EQ(BN->getResult().getElementType(), ElemKind::Float16Ty);
378 EXPECT_EQ(BN->getScale().getElementType(), ElemKind::Float16Ty);
379 EXPECT_EQ(BN->getBias().getElementType(), ElemKind::Float16Ty);
380 EXPECT_EQ(BN->getMean().getElementType(), ElemKind::Float16Ty);
381 EXPECT_EQ(BN->getVar().getElementType(), ElemKind::Float16Ty);
382
383 auto backend = MockBackend();
384 CompilationContext cctx;
385 lower(F, cctx, &backend);
386
387 EXPECT_TRUE(std::all_of(
388 F->getNodes().begin(), F->getNodes().end(), [](const Node &node) -> bool {
389 for (unsigned idx = 0, end = node.getNumResults(); idx != end; ++idx) {
390 if (node.getType(idx)->getElementType() != ElemKind::Float16Ty) {
391 return false;
392 }
393 }
394 return true;
395 }));
396 }
397
398 /// Check that we can create convolution with bfloat16.
TEST(Graph,bfloat16Conv)399 TEST(Graph, bfloat16Conv) {
400 Module MD;
401 Function *F = MD.createFunction("F");
402 PlaceholderBindings bindings;
403 Node *K = MD.createConstant(ElemKind::BFloat16Ty, {4, 320, 200, 3}, "input");
404
405 auto *conv = F->createConv(bindings, "Conv", K, 16, 3, 2, 3, 1);
406 F->createSave("Save", conv);
407 EXPECT_TRUE(conv->verify());
408 EXPECT_EQ(conv->getResult().getElementType(), ElemKind::BFloat16Ty);
409 EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::BFloat16Ty);
410 EXPECT_EQ(conv->getBias().getElementType(), ElemKind::BFloat16Ty);
411
412 auto backend = MockBackend();
413 CompilationContext cctx;
414 lower(F, cctx, &backend);
415
416 IRFunction M(F);
417
418 M.generateIR(backend);
419 EXPECT_GT(M.getInstrs().size(), 0);
420 auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
421 [](const Instruction &inst) -> bool {
422 return llvm::isa<ConvolutionInst>(inst);
423 });
424 ASSERT_TRUE(convIt != M.getInstrs().end());
425 const auto *convInst = llvm::cast<ConvolutionInst>(&*convIt);
426 EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::BFloat16Ty);
427 EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::BFloat16Ty);
428 EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::BFloat16Ty);
429 }
430
431 /// Check that we can create conv3D with bfloat16.
TEST(Graph,bfloat16Conv3DLower)432 TEST(Graph, bfloat16Conv3DLower) {
433 Module MD;
434 Function *F = MD.createFunction("F");
435 PlaceholderBindings bindings;
436 Node *K =
437 MD.createConstant(ElemKind::BFloat16Ty, {4, 320, 200, 200, 3}, "input");
438
439 auto *conv = F->createConv3D(bindings, "Conv3D", K, 16, 3, 2, 3, 1);
440 F->createSave("Save", conv);
441 EXPECT_TRUE(conv->verify());
442 EXPECT_EQ(conv->getResult().getElementType(), ElemKind::BFloat16Ty);
443 EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::BFloat16Ty);
444 EXPECT_EQ(conv->getBias().getElementType(), ElemKind::BFloat16Ty);
445
446 auto backend = MockBackend();
447 CompilationContext cctx;
448 lower(F, cctx, &backend);
449
450 IRFunction M(F);
451
452 M.generateIR(backend);
453 EXPECT_GT(M.getInstrs().size(), 0);
454 auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
455 [](const Instruction &inst) -> bool {
456 return llvm::isa<Convolution3DInst>(inst);
457 });
458 ASSERT_TRUE(convIt == M.getInstrs().end());
459 }
460
461 /// Check that we can create conv3D with bfloat16.
TEST(Graph,bfloat16Conv3DNoLower)462 TEST(Graph, bfloat16Conv3DNoLower) {
463 Module MD;
464 Function *F = MD.createFunction("F");
465 PlaceholderBindings bindings;
466 Node *K =
467 MD.createConstant(ElemKind::BFloat16Ty, {4, 320, 200, 200, 3}, "input");
468
469 auto *conv = F->createConv3D(bindings, "Conv3D", K, 16, 3, 2, 3, 1);
470 F->createSave("Save", conv);
471 EXPECT_TRUE(conv->verify());
472 EXPECT_EQ(conv->getResult().getElementType(), ElemKind::BFloat16Ty);
473 EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::BFloat16Ty);
474 EXPECT_EQ(conv->getBias().getElementType(), ElemKind::BFloat16Ty);
475
476 auto backend = MockBackendNoLowerConv3D();
477 CompilationContext cctx;
478 lower(F, cctx, &backend);
479
480 IRFunction M(F);
481
482 M.generateIR(backend);
483 EXPECT_GT(M.getInstrs().size(), 0);
484 auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
485 [](const Instruction &inst) -> bool {
486 return llvm::isa<Convolution3DInst>(inst);
487 });
488 ASSERT_TRUE(convIt != M.getInstrs().end());
489 const auto *convInst = llvm::cast<Convolution3DInst>(&*convIt);
490 EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::BFloat16Ty);
491 EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::BFloat16Ty);
492 EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::BFloat16Ty);
493 }
494
495 /// Check that we can create batchNorm with float16.
TEST(Graph,bfloat16BatchNorm)496 TEST(Graph, bfloat16BatchNorm) {
497 Module MD;
498 Function *F = MD.createFunction("F");
499 PlaceholderBindings bindings;
500 auto *input = MD.createPlaceholder(ElemKind::BFloat16Ty, {1, 10, 20, 3},
501 "input", false);
502 BatchNormalizationNode *BN =
503 F->createBatchNormalization(bindings, "batch", input, 3, 0.0001, 0.9);
504
505 EXPECT_TRUE(BN->verify());
506 EXPECT_EQ(BN->getResult().getElementType(), ElemKind::BFloat16Ty);
507 EXPECT_EQ(BN->getScale().getElementType(), ElemKind::BFloat16Ty);
508 EXPECT_EQ(BN->getBias().getElementType(), ElemKind::BFloat16Ty);
509 EXPECT_EQ(BN->getMean().getElementType(), ElemKind::BFloat16Ty);
510 EXPECT_EQ(BN->getVar().getElementType(), ElemKind::BFloat16Ty);
511
512 auto backend = MockBackend();
513 CompilationContext cctx;
514 lower(F, cctx, &backend);
515
516 EXPECT_TRUE(std::all_of(
517 F->getNodes().begin(), F->getNodes().end(), [](const Node &node) -> bool {
518 for (unsigned idx = 0, end = node.getNumResults(); idx != end; ++idx) {
519 if (node.getType(idx)->getElementType() != ElemKind::BFloat16Ty) {
520 return false;
521 }
522 }
523 return true;
524 }));
525 }
526
527 /// Test that our use lists are correctly reflecting the state of the IR
528 /// and in particular that it is not polluted by temporary variable.
TEST(Graph,useList)529 TEST(Graph, useList) {
530 Module MD;
531 Function *F = MD.createFunction("F");
532 IRFunction M(F);
533 PlaceholderBindings bindings;
534 auto *K =
535 MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
536
537 EXPECT_EQ(K->getNumUsers(), 0);
538
539 ConvolutionNode *conv = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, 1);
540
541 EXPECT_TRUE(K->hasOneUse());
542 EXPECT_EQ(K->getNumUsers(), 1);
543 EXPECT_EQ(conv->getNumUsers(), 0);
544
545 // Although the filter of the convolution is only used by the convolution
546 // node, calling getFilter creates a temporary NodeValue that messes up
547 // with the actual use list.
548 // Therefore those checks are currently inverted but should be
549 // fixed eventually.
550 // Test with implicit temporary NodeValue.
551 EXPECT_TRUE(conv->getFilter().getNode()->hasOneUse());
552 EXPECT_EQ(conv->getFilter().getNode()->getNumUsers(), 1);
553
554 // Test with explicit temporary NodeValue.
555 Node *nodeFilter;
556 {
557 NodeValue tmp = conv->getFilter();
558 EXPECT_TRUE(tmp.getNode()->hasOneUse());
559 EXPECT_EQ(tmp.getNode()->getNumUsers(), 1);
560 nodeFilter = tmp.getNode();
561 // Test with NodeValue still around.
562 EXPECT_TRUE(nodeFilter->hasOneUse());
563 EXPECT_EQ(nodeFilter->getNumUsers(), 1);
564 }
565
566 // Test with NodeValue took out.
567 EXPECT_TRUE(nodeFilter->hasOneUse());
568 EXPECT_EQ(nodeFilter->getNumUsers(), 1);
569
570 // Same kind of test but with the convolution node itself.
571 {
572 NodeValue tmpConvRes(conv, 0);
573 EXPECT_EQ(conv->getNumUsers(), 0);
574 EXPECT_EQ(tmpConvRes.getNode()->getNumUsers(), 0);
575 }
576
577 // Add a couple of uses to conv and make sure it reflects on its use list.
578 F->createSave("Save", conv, K);
579
580 EXPECT_FALSE(K->hasOneUse());
581 EXPECT_EQ(K->getNumUsers(), 2);
582 EXPECT_EQ(conv->getNumUsers(), 1);
583 EXPECT_TRUE(conv->hasOneUse());
584
585 {
586 NodeValue tmpConvRes(conv, 0);
587 EXPECT_TRUE(tmpConvRes.getNode()->hasOneUse());
588 EXPECT_TRUE(conv->hasOneUse());
589 EXPECT_EQ(conv->getNumUsers(), 1);
590 EXPECT_EQ(tmpConvRes.getNode()->getNumUsers(), 1);
591 }
592
593 F->createSave("Save", conv, K);
594
595 EXPECT_FALSE(K->hasOneUse());
596 EXPECT_EQ(K->getNumUsers(), 3);
597 EXPECT_EQ(conv->getNumUsers(), 2);
598 EXPECT_FALSE(conv->hasOneUse());
599
600 {
601 NodeValue tmpConvRes(conv, 0);
602 EXPECT_FALSE(tmpConvRes.getNode()->hasOneUse());
603 EXPECT_FALSE(conv->hasOneUse());
604 EXPECT_EQ(conv->getNumUsers(), 2);
605 EXPECT_EQ(tmpConvRes.getNode()->getNumUsers(), 2);
606 }
607 }
608
/// Check user counts for a placeholder that feeds two convolutions. Also check
/// that iterating its use list visits the users in creation order.
TEST(Graph, useListIteration) {
  Module MD;
  Function *F = MD.createFunction("F");
  Node *K =
      MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);

  EXPECT_EQ(K->getNumUsers(), 0);

  PlaceholderBindings bindings;
  ConvolutionNode *conv1 = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, 1);
  ConvolutionNode *conv2 = F->createConv(bindings, "Conv2", K, 16, 3, 2, 3, 1);
  // Check the number of users for different nodes.
  EXPECT_EQ(K->getNumUsers(), 2);
  EXPECT_EQ(conv1->getNumUsers(), 0);
  EXPECT_TRUE(conv2->getFilter().getNode()->hasOneUse());
  EXPECT_EQ(conv1->getFilter().getNode()->getNumUsers(), 1);
  // Check that the first user of K is conv1.
  EXPECT_EQ(K->getUsers().begin()->getUser(), conv1);
  // Check that the second user of K is conv2.
  EXPECT_EQ((++K->getUsers().begin())->getUser(), conv2);
}
631
/// Check that a two-layer FullyConnected/RELU regression network can be
/// dumped, lowered, optimized for training and turned into a non-empty IR.
TEST(Graph, simpleTestFC) {
  const unsigned numInputs = 10;
  Module mod;
  Function *F = mod.createFunction("F");
  IRFunction IR(F);

  auto *A = mod.createPlaceholder(ElemKind::FloatTy, {numInputs, 2}, "A", true);
  auto *expected =
      mod.createPlaceholder(ElemKind::FloatTy, {numInputs, 1}, "Ex", true);

  PlaceholderBindings bindings;
  Node *fc1 = F->createFullyConnected(bindings, "FC1", A, 6);
  Node *relu1 = F->createRELU("RELU1", fc1);
  Node *fc2 = F->createFullyConnected(bindings, "FC2", relu1, 1);
  Node *relu2 = F->createRELU("RELU2", fc2);
  Node *regression = F->createRegression("Regression", relu2, expected);
  F->createSave("Save", regression);
  F->dump();
  const auto dagPath = F->dumpDAG();

  MockBackend backend{};
  CompilationContext cctx;
  lower(F, cctx, &backend);
  ::optimize(F, CompilationMode::Train);
  IR.generateIR(backend);
  IR.dump();
  EXPECT_GT(IR.getInstrs().size(), 0);
  // Clean up the DAG file written by dumpDAG().
  llvm::sys::fs::remove(dagPath);
}
660
/// Check how many QuantizationProfileNodes optimizeFunction() inserts in
/// Profile quantization mode. A float value read by several nodes must be
/// profiled only once, and the non-float quantize output must not be profiled.
TEST(Graph, QuantizationProfileNodes) {
  unsigned numInputs = 10;
  Module MD;
  Function *F = MD.createFunction("F");
  PlaceholderBindings bindings;

  auto *A = MD.createPlaceholder(ElemKind::FloatTy, {numInputs, 2}, "A", true);

  // Add non float operation, which should not be profiled.
  auto *outQTy = F->getParent()->uniqueType(glow::ElemKind::Int8QTy,
                                            {numInputs, 2}, 1.5, 6);
  auto *quantize = F->createQuantize("quantize", A, outQTy);
  // Save the quantized value so that quantize is not optimized away.
  F->createSave("save", quantize);

  // Multiple nodes read from the same variable.
  // Only one Quantization Profile node should be created for the output
  // from the variable.
  Node *O = F->createFullyConnected(bindings, "FC1", A, 6);
  Node *C = F->createFullyConnected(bindings, "FC2", A, 6);
  O = F->createRELU("RELU1", O);
  F->createSave("save", O);
  F->createSave("save", C);

  LoweredInfoMap loweredMapForProf;
  CompilationContext cctx{&bindings, &loweredMapForProf};
  cctx.precisionConfig.quantMode = QuantizationMode::Profile;
  std::unique_ptr<Backend> backend(createBackend("Interpreter"));
  EXIT_ON_ERR(::optimizeFunction(F, *backend, cctx));

  size_t numberOfProfileNodes =
      std::count_if(F->getNodes().begin(), F->getNodes().end(), [](Node &node) {
        return llvm::isa<QuantizationProfileNode>(&node);
      });

  // 1 from A
  // 8 from two lowered FCs: MM, BA, weight PH, bias PH
  // 2 from RELU (lowered to Max+Splat)
  EXPECT_EQ(11, numberOfProfileNodes);
}
702
/// Check that a quantized Conv -> FullyConnected -> Save graph compiles for
/// inference.
TEST(Graph, simpleQuant) {
  ExecutionEngine EE;
  auto &MD = EE.getModule();
  auto *F = MD.createFunction("main");

  unsigned depth = 16;
  llvm::SmallVector<unsigned_t, 2> kernels = {5, 5};
  llvm::SmallVector<unsigned_t, 4> pads = {0, 0, 0, 0};
  llvm::SmallVector<unsigned_t, 2> steps = {1, 1};
  unsigned width = 224;

  auto *input = MD.createPlaceholder(ElemKind::Int8QTy, {1, width, width, 3},
                                     0.4, 2, "Input", true);

  // Create the quantized filter and bias of the convolution.
  std::array<dim_t, 4> filterDim = {{depth, kernels[0], kernels[1], 3}};
  auto *filter =
      MD.createPlaceholder(ElemKind::Int8QTy, filterDim, 3.3, 4, "F", true);
  auto *bias =
      MD.createPlaceholder(ElemKind::Int32QTy, {depth}, 1.3, 5, "B", true);

  // Calculate the output dims. The output channel count must match the
  // filter depth, so reuse `depth` instead of repeating the literal.
  auto outSz = calculateConvPoolOutputDims(width, width, kernels, steps, pads);
  std::array<dim_t, 4> outDims = {{1, outSz.first, outSz.second, depth}};
  auto t = F->getParent()->uniqueType(glow::ElemKind::Int8QTy, outDims, 1.5, 6);

  auto *conv =
      F->createConv("conv", input, filter, bias, t, kernels, steps, pads, 1);

  // The FC consumes the whole conv output, so its filter has `s` rows.
  auto s = conv->getResult().getType()->size();
  auto *fcFilter =
      MD.createPlaceholder(ElemKind::Int8QTy, {s, 6}, 0.4, 2, "F", true);
  auto *fcBias =
      MD.createPlaceholder(ElemKind::Int32QTy, {6}, 0.4, 2, "B", true);
  Node *O = F->createFullyConnected("fc1", conv, fcFilter, fcBias);
  F->createSave("ret", O);
  EE.compile(CompilationMode::Infer);
}
742
/// Check that a Quantize -> RescaleQuantized -> Dequantize -> Save chain
/// compiles for inference.
TEST(Graph, quantizeDequantizeNodes) {
  ExecutionEngine EE;
  auto &MD = EE.getModule();
  auto *F = MD.createFunction("main");

  auto *input = MD.createPlaceholder(ElemKind::FloatTy, {1, 3}, "Input", true);
  auto *qType = F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 0.3, 5);

  auto *Q = F->createQuantize("quantize", input, qType);

  // Rescale to a different scale/offset of the same quantized element type.
  auto *transform =
      F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 1.4, 3);
  auto *A = F->createRescaleQuantized("rescale", Q, transform);

  auto *D = F->createDequantize("dequantize", A, ElemKind::FloatTy);
  F->createSave("ret", D);
  EE.compile(CompilationMode::Infer);
}
762
/// Check that a Gather over an Int8-quantized input compiles for inference.
TEST(Graph, quantizeGather) {
  ExecutionEngine EE;
  auto &mod = EE.getModule();
  auto *F = mod.createFunction("main");
  auto *input =
      mod.createPlaceholder(ElemKind::Int8QTy, {2, 2}, 0.4, 2, "input", true);
  auto *indices = mod.createPlaceholder(ElemKind::Int64ITy, {1}, "index", true);
  auto *gather = F->createGather("gather", input, indices);
  F->createSave("ret", gather);
  EE.compile(CompilationMode::Infer);
}
775
/// Check that cloning a node and adding the clone to the function yields a
/// distinct node that compares equal to the original.
TEST(Graph, cloneTest) {
  Module mod;
  PlaceholderBindings bindings;

  Function *F = mod.createFunction("main");
  Node *input =
      mod.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
  Node *selected =
      mod.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
  Node *conv = F->createConv(bindings, "Conv1", input, 16, 3, 2, 3, 1);
  Node *relu = F->createRELU("Relu", conv);
  Node *softMax = F->createSoftMax("SoftMax", relu, selected);
  F->createSave("Save", softMax);

  // Each clone still refers to the original operands, so the order in which
  // the clones are made does not matter.
  for (Node *original : {conv, relu, softMax}) {
    Node *copy = F->addNode(original->clone());
    EXPECT_TRUE(copy != original && original->isEqual(*copy));
  }
}
797
/// Check that a module finds the functions it contains by name and does not
/// report unknown names. Also exercise dumping the module as a DAG.
TEST(Graph, moduleTest) {
  Module mod;
  mod.createFunction("one");
  mod.createFunction("two");
  for (const char *phName : {"V1", "V2"}) {
    mod.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, phName, true);
  }

  EXPECT_TRUE(mod.hasFunction("one"));
  EXPECT_TRUE(mod.hasFunction("two"));
  EXPECT_FALSE(mod.hasFunction("four"));
  mod.dumpDAG();
}
809
/// Check a module whose two functions share placeholders. F1 computes V1 - V2
/// and saves it to V1; F2 saves V3 to V2. The module must still report both
/// functions, and dumping it as a DAG is exercised.
TEST(Graph, functionDependenciesTest) {
  Module M;
  auto *F1 = M.createFunction("one");
  auto *F2 = M.createFunction("two");
  auto *V1 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V1", true);
  auto *V2 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V2", true);
  auto *V3 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V3", true);
  // V4 is not used by either function.
  M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V4", true);

  auto *sum = F1->createSub("1_sub_2", V1, V2);
  F1->createSave("sv", sum, V1);
  F2->createSave("sv", V3, V2);

  EXPECT_TRUE(M.hasFunction("one"));
  EXPECT_TRUE(M.hasFunction("two"));
  EXPECT_FALSE(M.hasFunction("four"));
  M.dumpDAG();
}
832
/// Check that cloning a function produces a function that verifies, belongs
/// to the same module and has as many nodes as the original.
TEST(Graph, functionCloneTest) {
  Module mod;
  PlaceholderBindings bindings;

  auto *F = mod.createFunction("main");
  Node *input =
      mod.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
  Node *selected =
      mod.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
  Node *conv = F->createConv(bindings, "Conv", input, 16, 3, 2, 3, 1);
  Node *relu = F->createRELU("Relu", conv);
  Node *concat = F->createConcat("concat", {relu, relu, relu}, 0);
  Node *softMax = F->createSoftMax("SoftMax", concat, selected);
  F->createSave("Save", softMax);

  auto *cloned = F->clone("new_main");
  EXPECT_TRUE(cloned->verify());
  EXPECT_EQ(cloned->getNodes().size(), F->getNodes().size());
  EXPECT_EQ(cloned->getParent(), F->getParent());
}
854
855 /// Compile the module \p M inside the execution engine \p EE and then run it
856 /// using the provided \p bindings. Use the provided \p inputName and \p
857 /// outputName.
compileAndRun(ExecutionEngine & EE,PlaceholderBindings & bindings,Module & M,llvm::StringRef inputName,llvm::StringRef outputName)858 static void compileAndRun(ExecutionEngine &EE, PlaceholderBindings &bindings,
859 Module &M, llvm::StringRef inputName,
860 llvm::StringRef outputName) {
861 EE.compile(glow::CompilationMode::Infer);
862 // Allocate stprage for placeholders and initialize inputs.
863 bindings.allocate(M.getPlaceholderByNameSlow(inputName))
864 ->getHandle()
865 .clear(2.0);
866 bindings.allocate(M.getPlaceholderByNameSlow(outputName));
867 EE.run(bindings);
868 }
869
870 /// Check the module cloning functionality.
TEST(Graph,moduleCloneTest)871 TEST(Graph, moduleCloneTest) {
872 // State related to the cloned module and its execution.
873 ExecutionEngine clonedEE("Interpreter");
874 Module &clonedM = clonedEE.getModule();
875 PlaceholderBindings clonedBindings;
876 Tensor clonedResult;
877 // State related to the original module and its execution.
878 PlaceholderBindings originalBindings;
879 Tensor originalResult;
880 // Name of the placeholder holding the results of executions.
881 std::string resultName;
882 {
883 // Define the original execution engine and module.
884 ExecutionEngine originalEE("Interpreter");
885 Module &originalM = originalEE.getModule();
886
887 // Create a function.
888 auto *F = originalM.createFunction("main");
889 auto *input1 = originalM.createPlaceholder(ElemKind::FloatTy,
890 {4, 10, 10, 3}, "input", true);
891
892 auto *add = F->createAdd("add", input1, input1);
893 auto *relu = F->createRELU("Relu", add);
894 auto *concat = F->createConcat("concat", {relu, relu, relu}, 0);
895 auto *C = originalM.createConstant(concat->getResult().getType(), "C");
896 C->getPayloadMutable().getHandle().clear(1.0f);
897 auto *SM = F->createAdd("add", concat, C);
898 auto *SN = F->createSave("Save", SM);
899 resultName = SN->getPlaceholder()->getName();
900
901 // Clone the original module into the cloned module.
902 originalM.clone(&clonedM);
903 // The cloned module should have the same numer of types, functions,
904 // constants and placeholders.
905 EXPECT_EQ(originalM.getFunctions().size(), clonedM.getFunctions().size());
906 EXPECT_EQ(originalM.getPlaceholders().size(),
907 clonedM.getPlaceholders().size());
908 EXPECT_EQ(originalM.getConstants().size(), clonedM.getConstants().size());
909 EXPECT_EQ(originalM.getTypes().size(), clonedM.getTypes().size());
910 // String representations of the original and cloned modules should be the
911 // same.
912 EXPECT_EQ(originalM.toString(), clonedM.toString());
913 for (auto *originalF : originalM.getFunctions()) {
914 EXPECT_EQ(originalF->toString(),
915 clonedM.getFunction(originalF->getName())->toString());
916 }
917
918 // Compile and run the original module.
919 compileAndRun(originalEE, originalBindings, originalM, "input", resultName);
920 // Store the result of running the original module.
921 originalResult.assign(originalBindings.get(
922 originalBindings.getPlaceholderByNameSlow(resultName)));
923 // The old module should be removed when this scope ends. Thus, if the
924 // cloned module newM refers to any deleted nodes from the original module,
925 // it would result in a dangling reference and most likely in a crash.
926 }
927 // Check that the cloned module is still alive and valid after the original
928 // module was deleted.
929 EXPECT_TRUE(clonedM.verify());
930 // Compile and run the cloned model.
931 compileAndRun(clonedEE, clonedBindings, clonedM, "input", resultName);
932 // Store the result of running the cloned module.
933 clonedResult.assign(
934 clonedBindings.get(clonedBindings.getPlaceholderByNameSlow(resultName)));
935 // The results of execution should be exactly the same in both cases.
936 EXPECT_TRUE(originalResult.isEqual(clonedResult, 0));
937 }
938
/// Check how Function::clone handles predicates. A predicate produced outside
/// the function is shared by the clone. A predicate computed inside the
/// function is replaced by its cloned counterpart.
TEST(Graph, cloneWithPredicates) {
  Module M;

  auto *F = M.createFunction("main");
  auto *input =
      M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", false);
  auto *counters =
      M.createPlaceholder(ElemKind::FloatTy, {10}, "counters", false);
  auto *reluExt = F->createRELU("reluExt", input);
  auto *reluInt = F->createRELU("reluInt", input);
  auto *externalPredicate =
      M.createPlaceholder(ElemKind::Int64ITy, {1}, "predicate", false);
  auto *C10 = F->createSplat("C10", counters->getType(), 10.0);
  auto *internalPredicate = F->createCmpLTE("lte", C10, counters);

  reluExt->setPredicate(externalPredicate);
  reluInt->setPredicate(internalPredicate);

  auto *newF = F->clone("new_main");

  EXPECT_TRUE(newF->verify());
  EXPECT_EQ(newF->getNodes().size(), F->getNodes().size());
  EXPECT_EQ(newF->getParent(), F->getParent());

  // Original predicates are not changed.
  EXPECT_EQ(reluExt->getPredicate().getNode(), externalPredicate);
  EXPECT_EQ(reluInt->getPredicate().getNode(), internalPredicate);

  // Look up the cloned nodes once. A missing node fails the test instead of
  // being dereferenced as null.
  const Node *clonedReluExt = nodeByName(newF, "reluExt");
  const Node *clonedReluInt = nodeByName(newF, "reluInt");
  const Node *clonedLte = nodeByName(newF, "lte");
  ASSERT_NE(clonedReluExt, nullptr);
  ASSERT_NE(clonedReluInt, nullptr);
  ASSERT_NE(clonedLte, nullptr);

  // A cloned predicate that points to a node outside the graph still points
  // to that same node (the predicate is shared).
  EXPECT_EQ(clonedReluExt->getPredicate().getNode(), externalPredicate);
  // A cloned predicate that points to a node inside the graph points to that
  // node's clone.
  EXPECT_EQ(clonedReluInt->getPredicate().getNode(), clonedLte);
}
976
/// Check that a NodeValue can be reassigned to chain nodes. Doubling an input
/// of 3.0 three times must produce 24.
TEST(Graph, NodeValue) {
  ExecutionEngine EE;
  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  PlaceholderBindings bindings;
  auto *inputX = mod.createPlaceholder(ElemKind::FloatTy, {1}, "input", true);
  bindings.allocate(inputX)->init(Tensor::InitKind::Broadcast, 3.0,
                                  mod.getPRNG());

  // Each step adds the running value to itself.
  NodeValue acc = F->createAdd("x2", inputX, inputX);
  acc = F->createAdd("x4", acc, acc);
  acc = F->createAdd("x8", acc, acc);
  auto *result =
      bindings.allocate(F->createSave("Save", acc)->getPlaceholder());

  EE.compile(CompilationMode::Infer);
  EE.run(bindings);

  EXPECT_EQ(result->getHandle().raw(0), 24);
}
998
999 /// Check that by deleting one function, the variables that refernced
1000 /// by this function, will reduce its number of uses by one.
TEST(Graph,deleteFunction)1001 TEST(Graph, deleteFunction) {
1002 ExecutionEngine EE;
1003 auto &mod = EE.getModule();
1004 Function *F1 = mod.createFunction("f1");
1005 auto *inputX = mod.createPlaceholder(ElemKind::FloatTy, {1}, "input", true);
1006 F1->createLog("log1", inputX);
1007 Function *F2 = mod.createFunction("f2");
1008 F2->createLog("log2", inputX);
1009 // We check the number of user of inputX to be 2 as only F1 and F2 are
1010 // using it.
1011 EXPECT_EQ(inputX->getNumUsers(), 2);
1012 // Erase this function here to see if we can see the number of user of inputX
1013 // reduce to 1.
1014 mod.eraseFunction(F1);
1015 EXPECT_EQ(inputX->getNumUsers(), 1);
1016 }
1017
/// Build a conv -> relu -> maxpool chain whose three nodes share one
/// predicate. Feed it into FC -> relu -> softmax -> save, and check that the
/// function compiles and runs.
TEST(Graph, nodesWithPredicates) {
  ExecutionEngine EE;

  Tensor inputData(ElemKind::FloatTy, {1, 32, 32, 3});

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  F->setName("interpret");
  PlaceholderBindings bindings;
  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {1, 32, 32, 3}, "input", true);
  auto *selected =
      mod.createPlaceholder(ElemKind::Int64ITy, {1, 1}, "exp", true);
  auto *pred =
      mod.createPlaceholder(ElemKind::Int64ITy, {1}, "predicate", false);
  bindings.allocate(input);
  bindings.allocate(selected);
  bindings.allocate(pred);

  auto *conv = F->createConv(bindings, "conv1", input, 16, 5, 1, 2, 1);
  auto *relu = F->createRELU("relu1", conv);
  auto *pool = F->createMaxPool("pool1", relu, 2, 2, 0);

  // The whole convolutional front end is guarded by the same predicate.
  conv->setPredicate(pred);
  relu->setPredicate(pred);
  pool->setPredicate(pred);

  auto *fc = F->createFullyConnected(bindings, "fc", pool->getResult(), 10);
  auto *fcRelu = F->createRELU("relu4", fc);
  auto *softmax = F->createSoftMax("sm", fcRelu, selected);
  bindings.allocate(F->createSave("ret", softmax)->getPlaceholder());

  EE.compile(CompilationMode::Infer);

  updateInputPlaceholders(bindings, {input}, {&inputData});
  EE.run(bindings);
}
1055
1056 // Return the number of ConvolutionNode after lower.
getConvNodeSize(llvm::StringRef kind)1057 unsigned getConvNodeSize(llvm::StringRef kind) {
1058 Module mod;
1059 Function *F = mod.createFunction("main");
1060 IRFunction M(F);
1061 PlaceholderBindings bindings;
1062 auto *input =
1063 mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 1, 32}, "input", true);
1064 ConvolutionNode *CN = F->createConv(bindings, "conv", input, 6, 1, 1, 0, 2);
1065 F->createSave("save", CN);
1066
1067 std::unique_ptr<Backend> backend(createBackend(kind));
1068 CompilationContext cctx;
1069 lower(F, cctx, backend.get());
1070
1071 unsigned count = 0;
1072 for (auto &n : F->getNodes()) {
1073 if (n.getKind() == Kinded::Kind::ConvolutionNodeKind) {
1074 count++;
1075 }
1076 }
1077
1078 if (kind == "Interpreter") {
1079 EXPECT_EQ(count, 1);
1080 }
1081
1082 return count;
1083 }
1084
1085 // Check the unrolling grouped convolution opt status:
1086 // -- disabled for Interpreter, CPU and OpenCL backend,
TEST(Graph,disableUnrollingGroupConv)1087 TEST(Graph, disableUnrollingGroupConv) {
1088 unsigned numberOfNodesInterpreter = getConvNodeSize("Interpreter");
1089 (void)numberOfNodesInterpreter;
1090
1091 #ifdef GLOW_WITH_CPU
1092 unsigned numberOfNodesCPU = getConvNodeSize("CPU");
1093 EXPECT_EQ(numberOfNodesCPU, numberOfNodesInterpreter);
1094 #endif // GLOW_WITH_CPU
1095
1096 #ifdef GLOW_WITH_OPENCL
1097 unsigned numberOfNodesOpenCL = getConvNodeSize("OpenCL");
1098 EXPECT_EQ(numberOfNodesOpenCL, numberOfNodesInterpreter);
1099 #endif // GLOW_WITH_OPENCL
1100 }
1101
1102 /// Check that save nodes are properly scheduled.
1103 /// That is, they happen after the last use of the related variable.
1104 /// In that test, the order of the creation of the nodes give a valid schedule.
TEST(Graph,schedulingOfSavesOrderProvided)1105 TEST(Graph, schedulingOfSavesOrderProvided) {
1106 ExecutionEngine EE;
1107
1108 auto &mod = EE.getModule();
1109 Function *F = mod.createFunction("main");
1110 auto *A = mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "A", true);
1111 auto *B = mod.createPlaceholder(A->getType(), "B", true);
1112 auto *zero = mod.createPlaceholder(A->getType(), "zero", true);
1113
1114 PlaceholderBindings bindings;
1115 bindings.allocate(A)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
1116 bindings.allocate(B)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
1117 bindings.allocate(zero)->init(Tensor::InitKind::Broadcast, 0.0,
1118 mod.getPRNG());
1119
1120 auto *addAB = F->createAdd("addAB", A, B);
1121
1122 auto *saveNode = F->createSave("ret", addAB);
1123 auto *savePH = saveNode->getPlaceholder();
1124 bindings.allocate(savePH);
1125 F->createSave("resetA", zero, A);
1126
1127 // Copy the value of A.
1128 Tensor AOrig = bindings.get(A)->clone();
1129
1130 EE.compile(CompilationMode::Infer);
1131
1132 EE.run(bindings);
1133 auto *ret = bindings.get(savePH);
1134 auto handleAOrig = AOrig.getHandle<>();
1135 auto handleB = bindings.get(B)->getHandle<>();
1136 auto handleRet = ret->getHandle<>();
1137 bool allEqual = true;
1138 for (unsigned row = 0; row != 3; ++row) {
1139 for (unsigned column = 0; column != 32; ++column) {
1140 allEqual &= handleAOrig.at({row, column}) + handleB.at({row, column}) ==
1141 handleRet.at({row, column});
1142 }
1143 }
1144 EXPECT_TRUE(bindings.get(A)->isEqual(*bindings.get(zero), 0.0));
1145 EXPECT_TRUE(allEqual);
1146 }
1147
1148 /// Same as schedulingOfSavesOrderProvided except the order in which the nodes
1149 /// are added to the function don't form a valid schedule.
1150 /// In other words, the scheduler won't get away with scheduling
1151 /// using only the order of the nodes in the list of nodes.
TEST(Graph,schedulingOfSaves)1152 TEST(Graph, schedulingOfSaves) {
1153 ExecutionEngine EE;
1154 PlaceholderBindings bindings;
1155
1156 auto &mod = EE.getModule();
1157 Function *F = mod.createFunction("main");
1158 auto *A = mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "A", true);
1159 auto *B = mod.createPlaceholder(A->getType(), "B", true);
1160 auto *zero = mod.createPlaceholder(A->getType(), "zero", true);
1161 F->createSave("resetA", zero, A);
1162
1163 bindings.allocate(A)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
1164 bindings.allocate(B)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
1165 bindings.allocate(zero)->init(Tensor::InitKind::Broadcast, 0.0,
1166 mod.getPRNG());
1167
1168 auto *addAB = F->createAdd("addAB", A, B);
1169
1170 auto *saveNode = F->createSave("ret", addAB);
1171 bindings.allocate(saveNode->getPlaceholder());
1172
1173 // Copy the value of A.
1174 Tensor AOrig = bindings.get(A)->clone();
1175 auto *ret = saveNode->getPlaceholder();
1176 EE.compile(CompilationMode::Infer);
1177
1178 EE.run(bindings);
1179
1180 auto handleAOrig = AOrig.getHandle<>();
1181 auto handleB = bindings.get(B)->getHandle<>();
1182 auto handleRet = bindings.get(ret)->getHandle<>();
1183 bool allEqual = true;
1184 for (unsigned row = 0; row != 3; ++row) {
1185 for (unsigned column = 0; column != 32; ++column) {
1186 allEqual &= handleAOrig.at({row, column}) + handleB.at({row, column}) ==
1187 handleRet.at({row, column});
1188 }
1189 }
1190 EXPECT_TRUE(bindings.get(A)->isEqual(*bindings.get(zero), 0.0));
1191 EXPECT_TRUE(allEqual);
1192 }
1193
1194 /// Check that the parent link is properly updated while tweaking
1195 /// nodes and their function.
TEST(Graph,parentLink)1196 TEST(Graph, parentLink) {
1197 ExecutionEngine EE;
1198
1199 auto &mod = EE.getModule();
1200 Constant *V =
1201 new Constant("V", mod.uniqueType(ElemKind::FloatTy, {3, 32}), ANY_LAYOUT);
1202
1203 // Variables don't belong to any function...
1204 EXPECT_EQ(V->getParent(), nullptr);
1205 // Even when we create them from a module...
1206 Constant *V2 = mod.createConstant(V->getType(), "V2");
1207 EXPECT_EQ(V2->getParent(), nullptr);
1208
1209 Function *F = mod.createFunction("main");
1210
1211 // Nodes created with function helper belong to the related function.
1212 auto *addNode = F->createAdd("addnode", V, V2);
1213 EXPECT_EQ(addNode->getParent(), F);
1214
1215 // Nodes created directly don't belong to any function.
1216 auto *addNode2 = new AddNode("addnode2", V->getType(), addNode, addNode);
1217 EXPECT_EQ(addNode2->getParent(), nullptr);
1218
1219 // Nodes added to a function belong to that function.
1220 F->addNode(addNode2);
1221 EXPECT_EQ(addNode2->getParent(), F);
1222
1223 // Cloned nodes don't belong to anything.
1224 auto *clonedAddNode = addNode->clone();
1225 EXPECT_EQ(clonedAddNode->getParent(), nullptr);
1226
1227 // Check that the setter properly sets things.
1228 clonedAddNode->setParent(F);
1229 EXPECT_EQ(clonedAddNode->getParent(), F);
1230 clonedAddNode->setParent(nullptr);
1231 EXPECT_EQ(clonedAddNode->getParent(), nullptr);
1232
1233 // Add the cloned node to F so that the memory is properly
1234 // cleaned at the end of the test.
1235 F->addNode(clonedAddNode);
1236 EXPECT_EQ(clonedAddNode->getParent(), F);
1237
1238 delete V;
1239 }
1240
1241 /// Check that verification can detect that Storage nodes are being used by
1242 /// Functions in a Module that doesn't own the Storage nodes.
TEST(Graph,moduleLink)1243 TEST(Graph, moduleLink) {
1244 ExecutionEngine EEA, EEB;
1245
1246 auto &modA = EEA.getModule();
1247 auto &modB = EEB.getModule();
1248
1249 auto *FA = modA.createFunction("FA");
1250 auto *FB = modB.createFunction("FB");
1251
1252 auto *C = modA.createConstant(ElemKind::FloatTy, {1}, "C");
1253 auto *P = modA.createPlaceholder(ElemKind::FloatTy, {1}, "P", false);
1254
1255 auto *AA = FA->createAdd("AA", C, P);
1256 FA->createSave("SA", AA);
1257
1258 // These nodes use Storage nodes that reside in modA
1259 auto *AB = FB->createAdd("AB", C, P);
1260 FB->createSave("SB", AB);
1261
1262 EXPECT_TRUE(modA.verify());
1263 EXPECT_FALSE(
1264 modB.verify()); // Module::verify calls Function::verify on all functions
1265 // within the module, so this should fail
1266 }
1267
1268 /// Check that Cmp nodes are created with proper output types.
TEST(Graph,cmpOutputTypes)1269 TEST(Graph, cmpOutputTypes) {
1270 ExecutionEngine EE;
1271
1272 auto &mod = EE.getModule();
1273 Function *F = mod.createFunction("main");
1274 // Define two different quntized types.
1275 auto qType1 = F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 0.3, 5);
1276 auto qType2 = F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 0.4, 5);
1277 // Define two variables of quantized types.
1278 auto *qv1 = mod.createPlaceholder(qType1, "V1", true);
1279 auto *qv2 = mod.createPlaceholder(qType2, "V2", true);
1280 // Create cmp nodes using quantized inputs.
1281 auto *cmpNode1 = F->createCmpEQ("cmpeq", qv1, qv2);
1282 auto *cmpNode2 = F->createCmpLTE("cmplte", qv1, qv2);
1283 // Check that the output type of cmp nodes is BoolKind.
1284 EXPECT_TRUE(cmpNode1->getResult().getElementType() == ElemKind::BoolTy);
1285 EXPECT_TRUE(cmpNode2->getResult().getElementType() == ElemKind::BoolTy);
1286
1287 // Define a non-quantized type.
1288 auto nqType3 = F->getParent()->uniqueType(ElemKind::FloatTy, {1, 3});
1289 // Define two variables of non-quantized types.
1290 auto *nqv3 = mod.createPlaceholder(nqType3, "V3", true);
1291 auto *nqv4 = mod.createPlaceholder(nqType3, "V4", true);
1292 // Create cmp nodes using non-quantized inputs.
1293 auto *cmpNode3 = F->createCmpEQ("cmpeq", nqv3, nqv4);
1294 auto *cmpNode4 = F->createCmpLTE("cmplte", nqv3, nqv4);
1295 // Check that the output type of cmp nodes is BoolKind.
1296 EXPECT_TRUE(cmpNode3->getResult().getElementType() == ElemKind::BoolTy);
1297 EXPECT_TRUE(cmpNode4->getResult().getElementType() == ElemKind::BoolTy);
1298 }
1299
1300 /// Check that the users of value are equal to expectedUsers.
1301 static bool
hasAllTheseUses(const llvm::SmallPtrSetImpl<const Node * > & expectedUsers,const NodeValue & value)1302 hasAllTheseUses(const llvm::SmallPtrSetImpl<const Node *> &expectedUsers,
1303 const NodeValue &value) {
1304 llvm::SmallPtrSet<const Node *, 4> uses;
1305 for (const NodeUse &use : value.getUsers()) {
1306 const Node *user = use.getUser();
1307 if (!expectedUsers.count(user)) {
1308 // We found a user that wasn't on the list.
1309 return false;
1310 }
1311 uses.insert(user);
1312 }
1313 return expectedUsers.size() == uses.size();
1314 }
1315
1316 /// Check that our uses lists are correct for nodes with multiple results.
TEST(Graph,usesListsWithSeveralResult)1317 TEST(Graph, usesListsWithSeveralResult) {
1318 ExecutionEngine EE;
1319 PlaceholderBindings bindings;
1320
1321 auto &mod = EE.getModule();
1322 Function *F = mod.createFunction("main");
1323 auto *input =
1324 mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "input", true);
1325 auto *topK = F->createTopK("topK", input, 12);
1326 EXPECT_EQ(topK->getNumUsers(), 0);
1327
1328 NodeValue values = topK->getValues();
1329 NodeValue indices = topK->getIndices();
1330 llvm::SmallPtrSet<const Node *, 4> savesOfValues;
1331 llvm::SmallPtrSet<const Node *, 4> savesOfIndices;
1332
1333 EXPECT_EQ(indices.getNumUsers(), 0);
1334 EXPECT_EQ(values.getNumUsers(), 0);
1335
1336 EXPECT_FALSE(indices.hasOneUse());
1337 EXPECT_FALSE(values.hasOneUse());
1338
1339 EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices));
1340 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1341
1342 // Now add a user to only one result of the topK node.
1343 savesOfValues.insert(F->createSave("saveValues1", values));
1344
1345 // The whole node should inherit the uses of each of its results.
1346 EXPECT_EQ(topK->getNumUsers(), 1);
1347
1348 // Each result should have its own use list.
1349 EXPECT_EQ(indices.getNumUsers(), 0);
1350 EXPECT_EQ(values.getNumUsers(), 1);
1351
1352 EXPECT_FALSE(indices.hasOneUse());
1353 EXPECT_TRUE(values.hasOneUse());
1354
1355 EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices));
1356 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1357
1358 // Add a user to the other result of the topK node.
1359 savesOfIndices.insert(F->createSave("saveIndices1", indices));
1360
1361 // The whole node should inherit the uses of each of its results.
1362 EXPECT_EQ(topK->getNumUsers(), 2);
1363
1364 // Each result should have its own use list.
1365 EXPECT_EQ(indices.getNumUsers(), 1);
1366 EXPECT_EQ(values.getNumUsers(), 1);
1367
1368 EXPECT_TRUE(indices.hasOneUse());
1369 EXPECT_TRUE(values.hasOneUse());
1370
1371 EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices));
1372 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1373
1374 // Add a couple more users of values and indices.
1375 // Interleaves the insertions in the uses list for both values and indices.
1376 savesOfValues.insert(F->createSave("saveValues2", values));
1377 savesOfValues.insert(F->createSave("saveValues3", values));
1378 savesOfIndices.insert(F->createSave("saveIndices2", indices));
1379
1380 EXPECT_EQ(topK->getNumUsers(), 5);
1381
1382 EXPECT_EQ(indices.getNumUsers(), 2);
1383 EXPECT_EQ(values.getNumUsers(), 3);
1384
1385 EXPECT_FALSE(indices.hasOneUse());
1386 EXPECT_FALSE(values.hasOneUse());
1387
1388 EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices));
1389 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1390 }
1391
1392 /// Check that our uses lists are correct when accessed through
1393 /// NodeValue.
TEST(Graph,usesListsThroughNodeValues)1394 TEST(Graph, usesListsThroughNodeValues) {
1395 ExecutionEngine EE;
1396 PlaceholderBindings bindings;
1397
1398 auto &mod = EE.getModule();
1399 Function *F = mod.createFunction("main");
1400 auto *input =
1401 mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "input", true);
1402 auto *reLU = F->createRELU("reLU", input);
1403 EXPECT_EQ(reLU->getNumUsers(), 0);
1404
1405 NodeValue values = reLU->getResult();
1406 llvm::SmallPtrSet<const Node *, 4> savesOfValues;
1407
1408 EXPECT_EQ(values.getNumUsers(), 0);
1409
1410 EXPECT_FALSE(values.hasOneUse());
1411
1412 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1413
1414 // Now add a user to only one result of the reLU node.
1415 savesOfValues.insert(F->createSave("saveValues1", values));
1416
1417 // The whole node should inherit the uses of each of its results.
1418 EXPECT_EQ(reLU->getNumUsers(), 1);
1419
1420 // The NodeValue should match.
1421 EXPECT_EQ(values.getNumUsers(), 1);
1422 EXPECT_TRUE(values.hasOneUse());
1423 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1424
1425 // Add one more use.
1426 savesOfValues.insert(F->createSave("saveValues2", values));
1427
1428 // The whole node should inherit the uses of each of its results.
1429 EXPECT_EQ(reLU->getNumUsers(), 2);
1430
1431 EXPECT_EQ(values.getNumUsers(), 2);
1432 EXPECT_FALSE(values.hasOneUse());
1433 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1434
1435 // Add a couple more users.
1436 savesOfValues.insert(F->createSave("saveValues3", values));
1437 savesOfValues.insert(F->createSave("saveValues4", values));
1438
1439 EXPECT_EQ(reLU->getNumUsers(), 4);
1440
1441 EXPECT_EQ(values.getNumUsers(), 4);
1442 EXPECT_FALSE(values.hasOneUse());
1443 EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
1444 }
1445
1446 /// Verify that the pre-order visitor works correctly.
TEST(Graph,PreOrderTest)1447 TEST(Graph, PreOrderTest) {
1448 Module M;
1449 PlaceholderBindings bindings;
1450 auto *F = M.createFunction("main");
1451
1452 auto *input1 =
1453 M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input1", true);
1454 auto *input2 =
1455 M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input2", true);
1456 SplatNode *zero = F->createSplat("zero", input1->getType(), 0.);
1457 MulNode *mul1 = F->createMul("mul1", zero, input1);
1458 MulNode *mul2 = F->createMul("mul2", zero, input2);
1459 MulNode *mul3 = F->createMul("mul3", mul1, mul2);
1460 SaveNode *ret1 = F->createSave("ret1", mul3);
1461
1462 SplatNode *one = F->createSplat("one", input2->getType(), 1.0);
1463 AddNode *add1 = F->createAdd("add1", input2, one);
1464 AddNode *add2 = F->createAdd("add2", add1, one);
1465 AddNode *add3 = F->createAdd("add3", add2, one);
1466 SaveNode *ret2 = F->createSave("ret2", add2);
1467
1468 GraphPreOrderVisitor visitor(*F);
1469 auto order = visitor.getPreOrder();
1470
1471 ASSERT_EQ(order.size(), 14);
1472 EXPECT_EQ(order[0], ret1);
1473 EXPECT_EQ(order[1], mul3);
1474 EXPECT_EQ(order[2], mul1);
1475 EXPECT_EQ(order[3], zero);
1476 EXPECT_EQ(order[4], input1);
1477 EXPECT_EQ(order[5], mul2);
1478 EXPECT_EQ(order[6], input2);
1479 EXPECT_EQ(order[7], ret1->getOutput());
1480 EXPECT_EQ(order[8], add3);
1481 EXPECT_EQ(order[9], add2);
1482 EXPECT_EQ(order[10], add1);
1483 EXPECT_EQ(order[11], one);
1484 EXPECT_EQ(order[12], ret2);
1485 EXPECT_EQ(order[13], ret2->getOutput());
1486 }
1487
1488 /// Verify that the post-order visitor works correctly.
TEST(Graph,PostOrderTest)1489 TEST(Graph, PostOrderTest) {
1490 Module M;
1491 PlaceholderBindings bindings;
1492 auto *F = M.createFunction("main");
1493
1494 auto *input1 =
1495 M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input1", true);
1496 auto *input2 =
1497 M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input2", true);
1498 SplatNode *zero = F->createSplat("zero", input1->getType(), 0.);
1499 MulNode *mul1 = F->createMul("mul1", zero, input1);
1500 MulNode *mul2 = F->createMul("mul2", zero, input2);
1501 MulNode *mul3 = F->createMul("mul3", mul1, mul2);
1502 SaveNode *ret1 = F->createSave("ret1", mul3);
1503
1504 SplatNode *one = F->createSplat("one", input2->getType(), 1.0);
1505 AddNode *add1 = F->createAdd("add1", input2, one);
1506 AddNode *add2 = F->createAdd("add2", add1, one);
1507 AddNode *add3 = F->createAdd("add3", add2, one);
1508 SaveNode *ret2 = F->createSave("ret2", add2);
1509
1510 GraphPostOrderVisitor visitor(*F);
1511 auto order = visitor.getPostOrder();
1512
1513 ASSERT_EQ(order.size(), 14);
1514 EXPECT_EQ(order[0], zero);
1515 EXPECT_EQ(order[1], input1);
1516 EXPECT_EQ(order[2], mul1);
1517 EXPECT_EQ(order[3], input2);
1518 EXPECT_EQ(order[4], mul2);
1519 EXPECT_EQ(order[5], mul3);
1520 EXPECT_EQ(order[6], ret1->getOutput());
1521 EXPECT_EQ(order[7], ret1);
1522 EXPECT_EQ(order[8], one);
1523 EXPECT_EQ(order[9], add1);
1524 EXPECT_EQ(order[10], add2);
1525 EXPECT_EQ(order[11], add3);
1526 EXPECT_EQ(order[12], ret2->getOutput());
1527 EXPECT_EQ(order[13], ret2);
1528 }
1529
/// Check that placeholders can feed an FC -> ReLU -> SoftMax -> Save chain and
/// that the resulting function verifies.
TEST(Graph, placeholder) {
  Module MD;
  PlaceholderBindings bindings;
  Function *F = MD.createFunction("F");
  Node *K =
      MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", false);
  Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", false);

  K = F->createFullyConnected(bindings, "FC", K, 10);
  K = F->createRELU("Relu", K);
  K = F->createSoftMax("SoftMax", K, S);
  F->createSave("Save", K);

  // The test previously built the graph without checking anything.
  EXPECT_TRUE(F->verify());
}
1544
1545 /// Check that the setType API allows to change the type of the
1546 /// related result and only the related result.
TEST(Graph,setType)1547 TEST(Graph, setType) {
1548 Module M;
1549 auto *F = M.createFunction("main");
1550
1551 const dim_t inputDims[] = {4, 10};
1552 const dim_t top5Dims[] = {4, 5};
1553 auto *input =
1554 M.createPlaceholder(ElemKind::FloatTy, inputDims, "input", true);
1555 TopKNode *topK = F->createTopK("add", input, 5);
1556 TypeRef origTopKRes0 = M.uniqueType(ElemKind::FloatTy, top5Dims);
1557 TypeRef origTopKRes1 = M.uniqueType(ElemKind::Int64ITy, top5Dims);
1558
1559 EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), origTopKRes0);
1560 EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), origTopKRes1);
1561
1562 // Modify the type of result 0 and make sure type 1 is not
1563 // affected. Similarly the input shouldn't be affected.
1564 TypeRef inputTy = M.uniqueType(ElemKind::FloatTy, inputDims);
1565 TypeRef topKRes0 = M.uniqueType(ElemKind::Float16Ty, top5Dims);
1566 topK->setType(TopKNode::ValuesIdx, topKRes0);
1567 EXPECT_EQ(input->getType(), inputTy);
1568 EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), topKRes0);
1569 EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), origTopKRes1);
1570
1571 // Make sure the NodeValue API works the same way
1572 // as the Node::setType API.
1573 NodeValue valRes1 = topK->getNthResult(TopKNode::IndicesIdx);
1574 valRes1.setType(topKRes0);
1575 EXPECT_EQ(input->getType(), inputTy);
1576 EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), topKRes0);
1577 EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), topKRes0);
1578 EXPECT_EQ(valRes1.getType(), topKRes0);
1579
1580 // Now restore sane types.
1581 NodeValue valRes0 = topK->getNthResult(TopKNode::ValuesIdx);
1582 valRes0.setType(origTopKRes0);
1583 topK->setType(TopKNode::IndicesIdx, origTopKRes1);
1584 EXPECT_EQ(input->getType(), inputTy);
1585 EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), origTopKRes0);
1586 EXPECT_EQ(valRes0.getType(), origTopKRes0);
1587 EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), origTopKRes1);
1588 EXPECT_EQ(valRes1.getType(), origTopKRes1);
1589 }
1590
1591 /// Check that we fixed the bug with Function::eraseNode. This method used to
1592 /// erase a node that was equal to the node we wanted to delete, which may be
1593 /// two different entities.
1594 /// To see this bug in action, we create a bunch of nodes with the same value.
1595 /// Then we erase them in reserve order. This reserve ordering was actually
1596 /// freeing the node in the original order, thus at some point we try to delete
1597 /// a node that has already deleted and an assert (debug mode) or segmentation
1598 /// fault (release would occur).
1599 /// Note: Which node is actually freed depend on the implementation of
1600 /// std::find, thus we cannot really predict when the bug occurs.
TEST(Graph,eraseNodeBug)1601 TEST(Graph, eraseNodeBug) {
1602 Module M;
1603 auto *F = M.createFunction("main");
1604
1605 auto *input = M.createPlaceholder(ElemKind::FloatTy, {3, 2}, "input", true);
1606 std::vector<Node *> ReLUs;
1607 // Create a bunch of ReLUs.
1608 for (unsigned idx = 0; idx != 5; ++idx) {
1609 ReLUs.push_back(F->createRELU("relu", input));
1610 }
1611 // Check that we can erase all the nodes.
1612 for (int idx = 4; idx != -1; --idx) {
1613 F->eraseNode(ReLUs[idx]);
1614 }
1615 EXPECT_EQ(F->getNodes().size(), 0);
1616 }
1617
1618 /// Verify that two Nodes with different predicates but the same inputs are not
1619 /// considered equal.
TEST(Graph,nodeEqualityWithDifferentPredicates)1620 TEST(Graph, nodeEqualityWithDifferentPredicates) {
1621 Module M;
1622 auto *F = M.createFunction("main");
1623
1624 Node *in = M.createPlaceholder(ElemKind::FloatTy, {5}, "in", false);
1625 Node *pred1 = M.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1626 Node *pred2 = M.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1627
1628 Node *RN1 = F->createRELU("relu1", in);
1629 RN1->setPredicate(pred1);
1630
1631 Node *RN2 = F->createRELU("relu2", in);
1632 RN2->setPredicate(pred2);
1633
1634 EXPECT_FALSE(RN1->isEqual(*RN2));
1635 }
1636
1637 /// Check that verify doesn't allow for multiple writers to the same node.
TEST(Graph,verifyOneWriter)1638 TEST(Graph, verifyOneWriter) {
1639 Module M;
1640 auto *F = M.createFunction("main");
1641
1642 auto *input = M.createPlaceholder(ElemKind::FloatTy, {5}, "input", false);
1643 auto *output = M.createPlaceholder(ElemKind::FloatTy, {5}, "output", false);
1644 F->createSave("Save1", input, output);
1645 F->createSave("Save2", input, output);
1646
1647 EXPECT_FALSE(M.verify());
1648 }
1649
1650 /// Check that verify doesn't allow for Constants to be written to. Note that
1651 /// createSave() cannot do this as the API only accepts Placeholders to write
1652 /// to, however it could happen during graph transformations, e.g. via
1653 /// replaceAllUsesOfWith() as shown here.
TEST(Graph,verifyConstantNoWriters)1654 TEST(Graph, verifyConstantNoWriters) {
1655 Module M;
1656 auto *F = M.createFunction("main");
1657
1658 auto *input = M.createPlaceholder(ElemKind::FloatTy, {5}, "input", false);
1659 auto *outputPH = M.createPlaceholder(ElemKind::FloatTy, {5}, "outPH", false);
1660 F->createSave("save", input, outputPH);
1661
1662 // Replace the output Placeholder with a Constant. This should fail
1663 // verification.
1664 auto *outputC = M.createConstant(ElemKind::FloatTy, {5}, "outC");
1665 NodeValue(outputPH).replaceAllUsesOfWith(outputC);
1666
1667 EXPECT_FALSE(M.verify());
1668 }
1669
/// Check that typeUnsafeReplaceAllUsesOfWith() redirects every use of a value
/// to a replacement of a different type ({10, 10} instead of {3, 4}). Afterwards
/// the consumer reads the replacement and the original node has no users left.
TEST(Graph, typeUnsafeReplaceAllUsesOfWith) {
  Module M;
  auto *F = M.createFunction("main");

  auto *LHS = M.createPlaceholder(ElemKind::FloatTy, {3, 4}, "A", false);
  auto *RHS = M.createPlaceholder(ElemKind::FloatTy, {4, 5}, "B", false);
  auto *FC = F->createMatMul("fc", LHS, RHS);
  F->createSave("save", FC);

  // Deliberately mismatched shape relative to the original LHS.
  auto *newLHS = M.createPlaceholder(ElemKind::FloatTy, {10, 10}, "A", false);
  LHS->getOutput().typeUnsafeReplaceAllUsesOfWith(newLHS);

  // The MatMul now consumes the new Placeholder; the old one is unused.
  EXPECT_EQ(FC->getLHS().getNode(), newLHS);
  EXPECT_EQ(LHS->getNumUsers(), 0);
}
1682
1683 /// Check that the verifier will complain if a constant and its
1684 /// underlying tensor have mismatching types.
1685 /// Here the constant is updated but not the tensor.
TEST(Graph,verifyConstantTensorTypeMatchesConstantTypeChanged)1686 TEST(Graph, verifyConstantTensorTypeMatchesConstantTypeChanged) {
1687 Module M;
1688
1689 auto *input = M.createConstant(ElemKind::FloatTy, {5}, "input");
1690 // Fresh constant should verify just fine.
1691 EXPECT_TRUE(input->verify());
1692
1693 input->setType(Storage::OutputIdx, M.uniqueType(ElemKind::Float16Ty, {5}));
1694
1695 EXPECT_FALSE(input->verify());
1696 }
1697
1698 /// Check that the verifier will complain if a constant and its
1699 /// underlying tensor have mismatching types.
1700 /// Here the tensor is updated but not the constant.
TEST(Graph,verifyConstantTensorTypeMatchesTensorTypeChanged)1701 TEST(Graph, verifyConstantTensorTypeMatchesTensorTypeChanged) {
1702 Module M;
1703
1704 auto *input = M.createConstant(ElemKind::FloatTy, {5}, "input");
1705 // Fresh constant should verify just fine.
1706 EXPECT_TRUE(input->verify());
1707 input->getPayloadMutable().convertToType(ElemKind::Float16Ty);
1708
1709 EXPECT_FALSE(input->verify());
1710 }
1711
1712 /// Check that Constants backed by unowned Tensors are in fact unowned until
1713 /// a mutable reference to their payload is obtained at which point the backing
1714 /// Tensor is copied and becomes owned.
TEST(Graph,verifyConstantWithUnownedTensorCopiesOnWrite)1715 TEST(Graph, verifyConstantWithUnownedTensorCopiesOnWrite) {
1716 Module M;
1717
1718 Tensor originalT(ElemKind::FloatTy, {3});
1719 Tensor unownedT = originalT.getUnowned({3});
1720
1721 auto originalH = originalT.getHandle();
1722
1723 for (size_t i = 0; i < originalT.size(); i++) {
1724 originalH.raw(i) = i;
1725 }
1726
1727 // Both Tensors should have the same underlying memory because unownedT shares
1728 // originalT's memory.
1729 EXPECT_EQ(originalT.getUnsafePtr(), unownedT.getUnsafePtr());
1730
1731 Constant *originalC = M.createConstant("original", std::move(originalT));
1732 Constant *unownedC = M.createConstant("unowned", std::move(unownedT));
1733
1734 const Tensor &originalCT = originalC->getPayload();
1735 const Tensor &unownedCT = unownedC->getPayload();
1736
1737 const auto originalCTH = originalCT.getHandle();
1738 const auto unownedCTH = unownedCT.getHandle();
1739
1740 ASSERT_EQ(originalCTH.size(), unownedCTH.size());
1741
1742 // Both Constants should have the same values because their Tensors have the
1743 // same underlying memory.
1744 for (size_t i = 0; i < originalCTH.size(); i++) {
1745 EXPECT_EQ(i, originalCTH.raw(i));
1746 EXPECT_EQ(i, unownedCTH.raw(i));
1747 }
1748
1749 Tensor &originalCTM = originalC->getPayloadMutable();
1750 auto originalCTMH = originalCTM.getHandle();
1751
1752 // Bump up the value in the original Constant, this change should be
1753 // reflected in the unowned Constant as well.
1754 for (size_t i = 0; i < originalCTMH.size(); i++) {
1755 originalCTMH.raw(i) += 1;
1756 }
1757
1758 // After changing the values in the original Constant, we should see an update
1759 // in the values of the unowned Constant because they share the same
1760 // underlying memory.
1761 for (size_t i = 0; i < unownedCTH.size(); i++) {
1762 EXPECT_EQ(unownedCTH.raw(i), i + 1);
1763 }
1764
1765 Tensor &unownedCTM = unownedC->getPayloadMutable();
1766 auto unownedCTMH = unownedCTM.getHandle();
1767
1768 ASSERT_EQ(originalCTH.size(), unownedCTMH.size());
1769
1770 // After getting a mutable reference to the unowned Constant's payload, the
1771 // underlying memory should have been copied but should still contain the same
1772 // values as it did previously at this point.
1773 EXPECT_NE(unownedCTM.getUnsafePtr(), originalCT.getUnsafePtr());
1774 for (size_t i = 0; i < unownedCTMH.size(); i++) {
1775 EXPECT_EQ(unownedCTMH.raw(i), i + 1);
1776 }
1777
1778 // Bump up the value in the original Constant again, this change should not be
1779 // reflected in the unowned Constant now because at this point, after a
1780 // mutable reference to its payload has been obtained, it should have it's own
1781 // memory.
1782 for (size_t i = 0; i < originalCTMH.size(); i++) {
1783 originalCTMH.raw(i) += 1;
1784 }
1785
1786 // Now that the unowned Constant's payload has been obtained as mutable, it
1787 // should have been copied and thus have its own memory and changes to the
1788 // original constant should not be reflected in the unowned Constant.
1789 for (size_t i = 0; i < unownedCTMH.size(); i++) {
1790 EXPECT_EQ(unownedCTMH.raw(i), i + 1);
1791 }
1792 }
1793
1794 /// Check that hooking an intermediate node works.
TEST(Graph,hookTest)1795 TEST(Graph, hookTest) {
1796 Module mod;
1797 auto *F = mod.createFunction("main");
1798 auto *in = mod.createPlaceholder(ElemKind::FloatTy, {1}, "in", false);
1799 auto *relu1 = F->createRELU("relu1", in);
1800 auto *relu2 = F->createRELU("relu2", relu1);
1801 F->createSave("save", relu2);
1802 EXPECT_EQ(F->getNodes().size(), 3);
1803 EXPECT_EQ(mod.getPlaceholders().size(), 2);
1804
1805 // Hook the first relu and verify that the hooked graph looks right.
1806 auto hooked = glow::hookNode(F, relu1);
1807 auto const &nodes = hooked.function->getNodes();
1808 ASSERT_EQ(mod.getPlaceholders().size(), 3);
1809 ASSERT_EQ(nodes.size(), 2);
1810 auto const *hookSave = *hooked.outputSaves.begin();
1811 ASSERT_TRUE(hookSave);
1812 auto *inp = llvm::dyn_cast<ReluNode>(hookSave->getInput());
1813 ASSERT_TRUE(inp);
1814 auto *ph = llvm::dyn_cast<Placeholder>(inp->getInput());
1815 ASSERT_TRUE(ph);
1816 ASSERT_EQ(ph, in);
1817 }
1818
1819 /// Check that getConstantsSize returns the correct size of constants.
TEST(Graph,moduleSize)1820 TEST(Graph, moduleSize) {
1821 Module mod;
1822
1823 EXPECT_EQ(mod.getConstantsSize(), 0);
1824
1825 auto *cons1 = mod.createConstant(ElemKind::FloatTy, {1}, "var");
1826 EXPECT_EQ(mod.getConstantsSize(), sizeof(float) * cons1->getPayload().size());
1827
1828 auto *cons2 = mod.createConstant(ElemKind::FloatTy, {1, 32, 32, 16}, "var2");
1829 EXPECT_EQ(mod.getConstantsSize(),
1830 sizeof(float) + sizeof(float) * cons2->getPayload().size());
1831 }
1832
1833 /// Check that getDataSize() returns the correct size of backing tensors.
TEST(Graph,contextSize)1834 TEST(Graph, contextSize) {
1835 Module mod;
1836 PlaceholderBindings bindings;
1837
1838 Placeholder *PH =
1839 mod.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
1840
1841 EXPECT_EQ(bindings.getDataSize(), 0);
1842 bindings.allocate(PH);
1843 EXPECT_EQ(bindings.get(PH)->size(), 4 * 320 * 200 * 3);
1844 EXPECT_EQ(bindings.getDataSize(), sizeof(float) * bindings.get(PH)->size());
1845 }
1846
1847 /// Check that clones of the context are distinct and share no references back
1848 /// to the original object.
TEST(Graph,clonePlaceholderBindings)1849 TEST(Graph, clonePlaceholderBindings) {
1850 Module mod;
1851
1852 Placeholder *PH1 =
1853 mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "PH1", false);
1854
1855 PlaceholderBindings bindings1;
1856 bindings1.allocate(PH1);
1857
1858 PlaceholderBindings bindings2 = bindings1.clone();
1859
1860 Tensor *t1 = bindings1.get(PH1);
1861 Tensor *t2 = bindings2.get(PH1);
1862
1863 EXPECT_NE(t1, nullptr);
1864 EXPECT_NE(t2, nullptr);
1865 EXPECT_NE(t1, t2);
1866
1867 // The new PlaceholderBindings has no references back, and changing it does
1868 // not affect bindings1
1869 Placeholder *PH2 =
1870 mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "PH2", false);
1871
1872 bindings2.allocate(PH2);
1873 // now exists in bindings1 but not bindings2
1874 EXPECT_EQ(bindings1.get(PH2), nullptr);
1875 EXPECT_NE(bindings2.get(PH2), nullptr);
1876
1877 // Likewise changing bindings1 does not affect bindings2
1878 bindings1.clear();
1879 EXPECT_EQ(bindings1.count(PH1), 0);
1880 EXPECT_EQ(bindings2.count(PH1), 1);
1881
1882 // Adds are distinct
1883 Placeholder *PH3 =
1884 mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "PH3", false);
1885 bindings1.allocate(PH3);
1886 bindings2.allocate(PH3);
1887 EXPECT_NE(bindings1.get(PH3), nullptr);
1888 EXPECT_NE(bindings2.get(PH3), nullptr);
1889 EXPECT_NE(bindings1.get(PH3), bindings2.get(PH3));
1890 }
1891
1892 /// Check that running a function multiple times on cloned PlaceholderBindingss
1893 /// have distinct outputs.
TEST(Graph,clonePlaceholderBindingsRuns)1894 TEST(Graph, clonePlaceholderBindingsRuns) {
1895 ExecutionEngine EE;
1896 PseudoRNG PRNG;
1897
1898 Tensor inputs(ElemKind::FloatTy, {1, 32, 32, 3});
1899 auto &mod = EE.getModule();
1900 Function *F = mod.createFunction("main");
1901 PlaceholderBindings bindings;
1902 auto *input =
1903 mod.createPlaceholder(ElemKind::FloatTy, {1, 32, 32, 3}, "input", true);
1904
1905 bindings.allocate(input);
1906
1907 auto *FCL1 = F->createFullyConnected(bindings, "fc", input, 10);
1908 auto *RL3 = F->createRELU("relu4", FCL1);
1909 auto *save = F->createSave("ret", RL3);
1910 auto *savePH = save->getPlaceholder();
1911
1912 bindings.allocate(save->getPlaceholder());
1913
1914 // Compile once.
1915 EE.compile(CompilationMode::Infer);
1916
1917 // Run with random inputs.
1918 inputs.getHandle<>().randomize(-3.0, 3.0, PRNG);
1919 updateInputPlaceholders(bindings, {input}, {&inputs});
1920 EE.run(bindings);
1921
1922 // Clone the context.
1923 PlaceholderBindings bindings2 = bindings.clone();
1924
1925 // PlaceholderBindingss are identical.
1926 Tensor *saveBacking1, *saveBacking2;
1927 saveBacking1 = bindings.get(savePH);
1928 saveBacking2 = bindings2.get(savePH);
1929 EXPECT_NE(saveBacking1, saveBacking2);
1930 EXPECT_EQ(saveBacking1->size(), saveBacking2->size());
1931 EXPECT_TRUE(saveBacking1->isEqual(*saveBacking2));
1932
1933 // Run again with different random inputs using the cloned context.
1934 Tensor inputs2(ElemKind::FloatTy, {1, 32, 32, 3});
1935 inputs2.getHandle<>().randomize(-3.0, 3.0, PRNG);
1936 updateInputPlaceholders(bindings2, {input}, {&inputs2});
1937 EE.run(bindings2);
1938
1939 // PlaceholderBindingss are no longer identical.
1940 EXPECT_EQ(saveBacking1->size(), saveBacking2->size());
1941 EXPECT_FALSE(saveBacking1->isEqual(*saveBacking2));
1942 }
1943
1944 /// Check that using the indices enums in nodes works correctly, with
1945 /// multi-input, multi-output, and single-input/output nodes.
TEST(Graph,TestNodeEnums)1946 TEST(Graph, TestNodeEnums) {
1947 Module MD;
1948 Function *F = MD.createFunction("F");
1949 PlaceholderBindings bindings;
1950 Placeholder *I =
1951 MD.createPlaceholder(ElemKind::FloatTy, {10, 10}, "input", true);
1952 Placeholder *O = MD.createPlaceholder(ElemKind::FloatTy, {3}, "output", true);
1953
1954 TopKNode *TKN = F->createTopK("topk", I, 3);
1955 GatherNode *GN =
1956 F->createGather("gather", TKN->getValues(), TKN->getIndices());
1957 TanhNode *TN = F->createTanh("tanh", GN);
1958 SaveNode *SN = F->createSave("save", TN, O);
1959
1960 // Check structure of Placeholders.
1961 EXPECT_EQ(I->getNthResult(Storage::OutputIdx), I->getOutput());
1962 EXPECT_EQ(O->getNthResult(Storage::OutputIdx), O->getOutput());
1963
1964 // Check structure of TopK.
1965 EXPECT_EQ(TKN->getInput(), TKN->getNthInput(TopKNode::InputIdx));
1966 EXPECT_EQ(TKN->getNthResult(TopKNode::ValuesIdx), TKN->getValues());
1967 EXPECT_EQ(TKN->getNthResult(TopKNode::IndicesIdx), TKN->getIndices());
1968
1969 // Check structure of Gather.
1970 EXPECT_EQ(GN->getNthInput(GatherNode::DataIdx), GN->getData());
1971 EXPECT_EQ(GN->getNthInput(GatherNode::IndicesIdx), GN->getIndices());
1972 EXPECT_EQ(GN->getNthResult(GatherNode::ResultIdx), GN->getResult());
1973
1974 // Check structure of Tanh.
1975 EXPECT_EQ(TN->getNthInput(TanhNode::InputIdx), TN->getInput());
1976 EXPECT_EQ(TN->getNthResult(TanhNode::ResultIdx), TN->getResult());
1977
1978 // Check structure of Save.
1979 EXPECT_EQ(SN->getNthInput(SaveNode::InputIdx), SN->getInput());
1980 EXPECT_EQ(SN->getNthInput(SaveNode::OutputIdx), SN->getOutput());
1981
1982 // Check connection between Placeholder and TopK.
1983 EXPECT_EQ(TKN->getNthInput(TopKNode::InputIdx), I->getOutput());
1984
1985 // Check connections between TopK and Gather.
1986 EXPECT_EQ(TKN->getNthResult(TopKNode::ValuesIdx),
1987 GN->getNthInput(GatherNode::DataIdx));
1988 EXPECT_EQ(TKN->getNthResult(TopKNode::IndicesIdx),
1989 GN->getNthInput(GatherNode::IndicesIdx));
1990
1991 // Check connection between Gather and Tanh.
1992 EXPECT_EQ(GN->getNthResult(GatherNode::ResultIdx),
1993 TN->getNthInput(TanhNode::InputIdx));
1994
1995 // Check connection between Gather and Tanh.
1996 EXPECT_EQ(TN->getNthResult(TanhNode::ResultIdx),
1997 SN->getNthInput(SaveNode::InputIdx));
1998
1999 // Check connection between Gather and Tanh.
2000 EXPECT_EQ(SN->getNthInput(SaveNode::OutputIdx), O->getOutput());
2001 }
2002
2003 /// Searched \p F for a single instance of a node of Kind T. If more than one is
2004 /// found, \returns nullptr, otherwise returns the single instance.
findSingleInstanceOfNode(Function * F)2005 template <class T> static T *findSingleInstanceOfNode(Function *F) {
2006 T *found = nullptr;
2007 for (auto &n : F->getNodes()) {
2008 if (auto *currNode = llvm::dyn_cast<T>(&n)) {
2009 if (found != nullptr) {
2010 return nullptr;
2011 }
2012 found = currNode;
2013 }
2014 }
2015 return found;
2016 }
2017
2018 /// Check that group Conv is not lowered when specified to lower by backend if
2019 /// doNotLowerKinds contains Conv.
TEST(Graph,GroupTestConvNoLower)2020 TEST(Graph, GroupTestConvNoLower) {
2021 Module MD;
2022 Function *F = MD.createFunction("F");
2023 IRFunction M(F);
2024 PlaceholderBindings bindings;
2025 Node *K =
2026 MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 8}, "input", true);
2027 Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
2028
2029 K = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, /* group */ 8);
2030 K = F->createRELU("Relu", K);
2031 K = F->createSoftMax("SoftMax", K, S);
2032 F->createSave("Save", K);
2033 F->dump();
2034 auto filePath = F->dumpDAG();
2035 auto backend = MockBackend();
2036
2037 {
2038 // Before we lower, we should have a single Conv node with group = 8.
2039 ConvolutionNode *CN = findSingleInstanceOfNode<ConvolutionNode>(F);
2040 if (!CN) {
2041 llvm::sys::fs::remove(filePath);
2042 }
2043 ASSERT_TRUE(CN);
2044 EXPECT_EQ(CN->getGroup(), 8);
2045 }
2046
2047 // Now lower, but prevent ConvolutionNodeKinds from being lowered.
2048 KindSet doNotLower;
2049 doNotLower.insert(Kinded::Kind::ConvolutionNodeKind);
2050 CompilationContext cctx;
2051 lower(F, cctx, &backend, doNotLower);
2052
2053 {
2054 // Now have lowered but should still have a single Conv node with group = 8.
2055 ConvolutionNode *CN = findSingleInstanceOfNode<ConvolutionNode>(F);
2056 if (!CN) {
2057 llvm::sys::fs::remove(filePath);
2058 }
2059 ASSERT_TRUE(CN);
2060 EXPECT_EQ(CN->getGroup(), 8);
2061 }
2062 }
2063
2064 /// Check that getOutputSave returns SaveNode object for the correct Placeholder
2065 /// and nullptr in other cases.
TEST(Graph,GetOutputSaveTest)2066 TEST(Graph, GetOutputSaveTest) {
2067 Module MD;
2068 Function *F = MD.createFunction("F");
2069 PlaceholderBindings bindings;
2070 Placeholder *I =
2071 MD.createPlaceholder(ElemKind::FloatTy, {10, 10}, "input", true);
2072 Placeholder *O = MD.createPlaceholder(ElemKind::FloatTy, {3}, "output", true);
2073 TopKNode *TKN = F->createTopK("topk", I, 3);
2074 GatherNode *GN =
2075 F->createGather("gather", TKN->getValues(), TKN->getIndices());
2076 TanhNode *TN = F->createTanh("tanh", GN);
2077 SaveNode *SN = F->createSave("save", TN, O);
2078
2079 // Check the return value of getOutputSave method.
2080 // Placeholder parent is null.
2081 auto *FoundNode = glow::getOutputSave(F, O);
2082 EXPECT_NE(nullptr, FoundNode);
2083 EXPECT_EQ(SN, FoundNode);
2084
2085 // Placeholder parent is set to the correct value.
2086 O->setParent(F);
2087 EXPECT_EQ(F, O->getParent());
2088 FoundNode = glow::getOutputSave(F, O);
2089 EXPECT_NE(nullptr, FoundNode);
2090 EXPECT_EQ(SN, FoundNode);
2091
2092 // Invalid placeholder type is provided.
2093 EXPECT_EQ(nullptr, glow::getOutputSave(F, I));
2094
2095 // Save belongs to a different function
2096 Function *F2 = MD.createFunction("F2");
2097 TopKNode *TKN2 = F2->createTopK("topk", I, 3);
2098 GatherNode *GN2 =
2099 F2->createGather("gather", TKN2->getValues(), TKN2->getIndices());
2100 TanhNode *TN2 = F2->createTanh("tanh", GN2);
2101 SaveNode *SN2 = F2->createSave("save", TN2, O);
2102
2103 FoundNode = glow::getOutputSave(F, O);
2104 EXPECT_NE(nullptr, FoundNode);
2105 EXPECT_EQ(SN, FoundNode);
2106
2107 O->setParent(F2);
2108 FoundNode = glow::getOutputSave(F2, O);
2109 EXPECT_NE(nullptr, FoundNode);
2110 EXPECT_EQ(SN2, FoundNode);
2111 }
2112
/// Check if dump functions work for Node, Function and Module.
/// For each entity, the dump(os) output, toString(), and operator<< output are
/// compared with each other and with hard-coded expected text.
TEST(Graph, testDumpStructure) {
  Module MD;
  Function *F = MD.createFunction("F");
  IRFunction M(F);
  PlaceholderBindings bindings;
  Node *K = MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 100, 3},
                                 "input", true);
  // Test Node
  std::string storageN1;
  llvm::raw_string_ostream osN1(storageN1);
  K->dump(osN1);
  std::string mesN = K->toString();
  std::string expectMes = R"(Placeholder
name : "input"
layout : *
output : float<4 x 320 x 200 x 100 x 3>
trainable : 1
static : 0
users : 0
)";
  EXPECT_EQ(mesN, expectMes);
  EXPECT_EQ(mesN, osN1.str());
  std::string storageN2;
  llvm::raw_string_ostream osN2(storageN2);
  // Streaming the Node pointer must print the same text as toString().
  osN2 << K;
  EXPECT_EQ(mesN, osN2.str());
  // Test Function
  // This Placeholder requests the name "input" again; the expected dumps below
  // show it as "input__1", i.e. it ends up with a uniquified name.
  Placeholder *I =
      MD.createPlaceholder(ElemKind::FloatTy, {10, 10}, "input", true);
  I->setStatic(true);
  Function *F2 = MD.createFunction("F2");
  F2->createTopK("topk", I, 3);
  std::string storageF1;
  llvm::raw_string_ostream osF1(storageF1);
  F2->dump(osF1);
  std::string mesF = F2->toString();
  std::string expectMesF = R"(Graph structure F2:
TopK
name : topk
Input : float<10 x 10>
K : 3
users : 0
Values : float<10 x 3>
Indices : index64<10 x 3>
Placeholder
name : "input__1"
layout : *
output : float<10 x 10>
trainable : 1
static : 1
users : 1
)";
  EXPECT_EQ(mesF, expectMesF);
  EXPECT_EQ(mesF, osF1.str());
  std::string storageF2;
  llvm::raw_string_ostream osF2(storageF2);
  osF2 << F2;
  EXPECT_EQ(mesF, osF2.str());
  // Reuse osF1 for a second dump; clearing its backing string is intended to
  // discard the first dump's text (NOTE(review): relies on osF1 holding no
  // unflushed data at this point -- confirm). With skipUsersForStorage set,
  // the expected text drops the Placeholder's "users" line but keeps TopK's.
  storageF1.clear();
  F2->dump(osF1, /* skipUsersForStorage */ true);
  mesF = F2->toString(/* skipUsersForStorage */ true);
  expectMesF = R"(Graph structure F2:
TopK
name : topk
Input : float<10 x 10>
K : 3
users : 0
Values : float<10 x 3>
Indices : index64<10 x 3>
Placeholder
name : "input__1"
layout : *
output : float<10 x 10>
trainable : 1
static : 1
)";
  EXPECT_EQ(mesF, expectMesF);
  EXPECT_EQ(mesF, osF1.str());
  // Test Module
  // The expected Module dump lists the Constant, then both Placeholders, then
  // the Functions by name.
  MD.createConstant(ElemKind::FloatTy, {1, 1}, "dummy");
  std::string storageM1;
  llvm::raw_string_ostream osM1(storageM1);
  MD.dump(osM1);
  std::string mesM = MD.toString();
  std::string expectMesM = R"(Module structure:
Constant
name : "dummy"
layout : *
output : float<1 x 1>
users : 0

Placeholder
name : "input__1"
layout : *
output : float<10 x 10>
trainable : 1
static : 1
users : 1

Placeholder
name : "input"
layout : *
output : float<4 x 320 x 200 x 100 x 3>
trainable : 1
static : 0
users : 0

Function : F2
Function : F
)";
  EXPECT_EQ(mesM, expectMesM);
  EXPECT_EQ(mesM, osM1.str());
  std::string storageM2;
  llvm::raw_string_ostream osM2(storageM2);
  osM2 << MD;
  EXPECT_EQ(mesM, osM2.str());
}
2231
2232 /// Initialize tensor payload for testing purposes. The value at index i is set
2233 /// to i.
initTensor(Tensor & T)2234 template <typename ElemTy> static void initTensor(Tensor &T) {
2235 Handle<ElemTy> handle = T.getHandle<ElemTy>();
2236 float val = 0;
2237 for (auto &elem : handle) {
2238 elem = val;
2239 val += 1.0;
2240 }
2241 }
2242
2243 // Test that randomizing Constants in a Function works.
TEST(Graph,testRandomizeConstants)2244 TEST(Graph, testRandomizeConstants) {
2245 Module MD;
2246 Function *F = MD.createFunction("F");
2247
2248 // Create tensors to be used in Constants
2249 Tensor floatT(ElemKind::FloatTy, {10});
2250 initTensor<float>(floatT);
2251
2252 Tensor halfT(ElemKind::Float16Ty, {10});
2253 initTensor<float16_t>(halfT);
2254
2255 Tensor bfloat16T(ElemKind::BFloat16Ty, {10});
2256 initTensor<bfloat16_t>(bfloat16T);
2257
2258 Tensor int8QT(ElemKind::Int8QTy, {10}, 1.0, 0);
2259 initTensor<int8_t>(int8QT);
2260
2261 Tensor uint8QT(ElemKind::UInt8QTy, {10}, 1.0, 0);
2262 initTensor<uint8_t>(uint8QT);
2263
2264 Tensor int16QT(ElemKind::Int16QTy, {10}, 1.0, 0);
2265 initTensor<int16_t>(int16QT);
2266
2267 Tensor int32QT(ElemKind::Int32QTy, {10}, 1.0, 0);
2268 initTensor<int32_t>(int32QT);
2269
2270 Tensor int32IT(ElemKind::Int32ITy, {10});
2271 initTensor<int32_t>(int32IT);
2272
2273 Tensor int64IT(ElemKind::Int64ITy, {10});
2274 initTensor<int64_t>(int64IT);
2275
2276 Tensor uint8FusedQT(ElemKind::UInt8FusedQTy, {16, 16}, 1.0, 0);
2277 initTensor<uint8_t>(uint8FusedQT);
2278
2279 Tensor uint8FusedFP16QT(ElemKind::UInt8FusedFP16QTy, {16, 16}, 1.0, 0);
2280 initTensor<uint8_t>(uint8FusedFP16QT);
2281
2282 Tensor uint4FusedFP16QT(ElemKind::UInt4FusedFP16QTy, {16, 16}, 1.0, 0);
2283 initTensor<uint8_t>(uint4FusedFP16QT);
2284
2285 Tensor boolT(ElemKind::BoolTy, {10});
2286 initTensor<bool>(boolT);
2287
2288 // Create Constants and use them in F
2289 auto *floatC = MD.createConstant("floatC", floatT);
2290 F->createAdd("add", floatC, floatC);
2291
2292 auto *halfC = MD.createConstant("halfC", halfT);
2293 F->createAdd("add", halfC, halfC);
2294
2295 auto *bfloat16C = MD.createConstant("bloat16C", bfloat16T);
2296 F->createAdd("add", bfloat16C, bfloat16C);
2297
2298 auto *int8QC = MD.createConstant("int8QC", int8QT);
2299 F->createAdd("add", int8QC, int8QC);
2300
2301 auto *uint8QC = MD.createConstant("uint8QC", uint8QT);
2302 F->createAdd("add", uint8QC, uint8QC);
2303
2304 auto *int16QC = MD.createConstant("int16QC", int16QT);
2305 F->createAdd("add", int16QC, int16QC);
2306
2307 auto *int32QC = MD.createConstant("int32QC", int32QT);
2308 F->createAdd("add", int32QC, int32QC);
2309
2310 auto *int32IC = MD.createConstant("int32IC", int32IT);
2311 F->createAdd("add", int32IC, int32IC);
2312
2313 auto *int64IC = MD.createConstant("int64IC", int64IT);
2314 F->createAdd("add", int64IC, int64IC);
2315
2316 auto *uint8FusedQC = MD.createConstant("uint8FusedQC", uint8FusedQT);
2317 F->createAdd("add", uint8FusedQC, uint8FusedQC);
2318
2319 auto *uint8FusedFP16QC =
2320 MD.createConstant("uint8FusedFP16QC", uint8FusedFP16QT);
2321 F->createAdd("add", uint8FusedFP16QC, uint8FusedFP16QC);
2322
2323 auto *uint4FusedFP16QC =
2324 MD.createConstant("uint4FusedFP16QC", uint4FusedFP16QT);
2325 F->createAdd("add", uint4FusedFP16QC, uint4FusedFP16QC);
2326
2327 auto *boolC = MD.createConstant("boolC", boolT);
2328 F->createAdd("add", boolC, boolC);
2329
2330 // Randomize Constants in F
2331 F->randomizeConstants();
2332
2333 // Check that no Constant is the same as what it started as
2334 EXPECT_FALSE(floatT.isEqual(floatC->getPayload()));
2335 EXPECT_FALSE(halfT.isEqual(halfC->getPayload()));
2336 EXPECT_FALSE(bfloat16T.isEqual(bfloat16C->getPayload()));
2337 EXPECT_FALSE(int8QT.isEqual(int8QC->getPayload()));
2338 EXPECT_FALSE(uint8QT.isEqual(uint8QC->getPayload()));
2339 EXPECT_FALSE(int16QT.isEqual(int16QC->getPayload()));
2340 EXPECT_FALSE(int32QT.isEqual(int32QC->getPayload()));
2341 EXPECT_FALSE(int32IT.isEqual(int32IC->getPayload()));
2342 EXPECT_FALSE(int64IT.isEqual(int64IC->getPayload()));
2343 EXPECT_FALSE(uint8FusedQT.isEqual(uint8FusedQC->getPayload()));
2344 EXPECT_FALSE(uint8FusedFP16QT.isEqual(uint8FusedFP16QC->getPayload()));
2345 EXPECT_FALSE(uint4FusedFP16QT.isEqual(uint4FusedFP16QC->getPayload()));
2346 EXPECT_FALSE(boolT.isEqual(boolC->getPayload()));
2347 }
2348