1 /**
2  * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "glow/Graph/Graph.h"
18 #include "BackendTestUtils.h"
19 #include "glow/ExecutionEngine/ExecutionEngine.h"
20 #include "glow/Graph/Hook.h"
21 #include "glow/Graph/Node.h"
22 #include "glow/Graph/Nodes.h"
23 #include "glow/Graph/Utils.h"
24 #include "glow/IR/IR.h"
25 #include "glow/IR/Instrs.h"
26 #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
27 
28 #include "llvm/ADT/SmallPtrSet.h"
29 #include "llvm/Support/FileSystem.h"
30 
31 #include "gtest/gtest.h"
32 
33 using namespace glow;
34 
35 // Helper to find a node in the Function by name
nodeByName(const Function * F,const std::string & name)36 static const Node *nodeByName(const Function *F, const std::string &name) {
37   for (auto &n : F->getNodes()) {
38     if (n.getName().str() == name) {
39       return &n;
40     }
41   }
42   return nullptr;
43 }
44 
45 /// Mock backend that does lower FC nodes.
46 class MockBackendNoLowerConv3D : public MockBackend {
shouldLower(const Node * N) const47   bool shouldLower(const Node *N) const override {
48     if (N->getKind() == Kinded::Kind::Convolution3DNodeKind) {
49       return false;
50     } else {
51       return true;
52     }
53   }
54 };
55 
TEST(Graph, testVariableErasure) {
  Module mod;
  auto &consts = mod.getConstants();
  // A fresh module holds no constants.
  EXPECT_EQ(consts.size(), 0);
  EXPECT_EQ(std::distance(consts.begin(), consts.end()), consts.size());

  // Creating a constant must show up in the list.
  Constant *dummy = mod.createConstant(ElemKind::FloatTy, {1, 1}, "dummy");
  EXPECT_EQ(consts.size(), 1);
  EXPECT_EQ(std::distance(consts.begin(), consts.end()), consts.size());

  // Erasing it must bring the module back to the empty state.
  mod.eraseConstant(dummy);
  EXPECT_EQ(consts.size(), 0);
  EXPECT_EQ(std::distance(consts.begin(), consts.end()), consts.size());
}
70 
71 /// Check that the clear method completely reset a module.
TEST(Graph,clear)72 TEST(Graph, clear) {
73   Module M;
74 
75   // Check that the module is initially empty.
76   EXPECT_EQ(M.getConstants().size(), 0);
77   EXPECT_EQ(M.getPlaceholders().size(), 0);
78   EXPECT_EQ(M.getFunctions().size(), 0);
79 
80   // Create a few things.
81   M.createFunction("main");
82   M.createPlaceholder(ElemKind::FloatTy, {1}, "placeholder", true);
83   M.createConstant(ElemKind::FloatTy, {1}, "var");
84 
85   EXPECT_EQ(M.getConstants().size(), 1);
86   EXPECT_EQ(M.getPlaceholders().size(), 1);
87   EXPECT_EQ(M.getFunctions().size(), 1);
88 
89   // Check that clearing the module makes it completely free of any kind of
90   // objects.
91   M.clear();
92   EXPECT_EQ(M.getConstants().size(), 0);
93   EXPECT_EQ(M.getPlaceholders().size(), 0);
94   EXPECT_EQ(M.getFunctions().size(), 0);
95 }
96 
97 /// Check that the clear method works as expected.
TEST(Graph,clearFunctions)98 TEST(Graph, clearFunctions) {
99   Module M;
100 
101   // Check that the module is initially empty.
102   EXPECT_EQ(M.getConstants().size(), 0);
103   EXPECT_EQ(M.getPlaceholders().size(), 0);
104   EXPECT_EQ(M.getFunctions().size(), 0);
105 
106   // Create a few things.
107   Function *F = M.createFunction("main");
108   auto *PH = M.createPlaceholder(ElemKind::FloatTy, {1}, "placeholder", true);
109   auto *C = M.createConstant(ElemKind::FloatTy, {1}, "var");
110   auto *AN = F->createAdd("add", PH, C);
111   F->createSave("save", AN);
112 
113   EXPECT_EQ(M.getConstants().size(), 1);
114   EXPECT_EQ(M.getPlaceholders().size(), 2); // Input PH and PH for Save
115   EXPECT_EQ(M.getFunctions().size(), 1);
116   EXPECT_EQ(F->getNodes().size(), 2); // Add, Save
117 
118   M.clearFunctions();
119   EXPECT_EQ(M.getConstants().size(), 1);
120   EXPECT_EQ(M.getPlaceholders().size(), 2);
121   ASSERT_EQ(M.getFunctions().size(), 1);
122   // Same Function ptr should exist, just nothing left in them.
123   EXPECT_EQ(*M.getFunctions().begin(), F);
124   EXPECT_EQ(F->getNodes().size(), 0);
125 }
126 
127 /// Test the graph nodes names and utilities.
TEST(Graph,testGraphNames)128 TEST(Graph, testGraphNames) {
129   Module MD;
130   Function *F = MD.createFunction("F");
131 
132   Node *op1 = MD.createPlaceholder(ElemKind::FloatTy, {1, 10}, "op1",
133                                    false /*isTrainable*/);
134   Node *op2 = MD.createConstant(ElemKind::FloatTy, {1, 10}, "op2");
135   Node *add = F->createAdd("add", op1, op2);
136   auto *top = F->createTopK("top", add, 5);
137   Node *save = F->createSave("out", top->getValues());
138 
139   EXPECT_TRUE(MD.getPlaceholderByNameSlow("op1"));
140   EXPECT_TRUE(MD.getConstantByName("op2"));
141   EXPECT_TRUE(F->getNodeByName("add"));
142   EXPECT_TRUE(F->getNodeByName("top"));
143   EXPECT_TRUE(F->getNodeByName("out_save"));
144 
145   NodeValue op1Res = op1->getNthResult(0);
146   NodeValue op2Res = op2->getNthResult(0);
147   NodeValue addRes = add->getNthResult(0);
148   EXPECT_TRUE(top->getNumResults() == 2);
149   NodeValue topValRes = top->getNthResult(0);
150   NodeValue topIndRes = top->getNthResult(1);
151 
152   auto op1ResName =
153       op1Res.generateNodeOutputName(false /*stripResNoFor0thInput*/);
154   auto op2ResName =
155       op2Res.generateNodeOutputName(false /*stripResNoFor0thInput*/);
156   auto addResName =
157       addRes.generateNodeOutputName(true /*stripResNoFor0thInput*/);
158   auto topValResName =
159       topValRes.generateNodeOutputName(false /*stripResNoFor0thInput*/);
160   auto topIndResName =
161       topIndRes.generateNodeOutputName(false /*stripResNoFor0thInput*/);
162 
163   EXPECT_EQ(op1ResName, "op1:0");
164   EXPECT_EQ(op2ResName, "op2:0");
165   EXPECT_EQ(addResName, "add");
166   EXPECT_EQ(topValResName, "top:0");
167   EXPECT_EQ(topIndResName, "top:1");
168 
169   EXPECT_EQ(F->getNodeValueByName(op1ResName), op1Res);
170   EXPECT_EQ(F->getNodeValueByName(op2ResName), op2Res);
171   EXPECT_EQ(F->getNodeValueByName(addResName), addRes);
172   EXPECT_EQ(F->getNodeValueByName(topValResName), topValRes);
173   EXPECT_EQ(F->getNodeValueByName(topIndResName), topIndRes);
174 
175   EXPECT_EQ(F->getNodeValueByName("op1"), op1Res);
176   EXPECT_EQ(F->getNodeValueByName("op2"), op2Res);
177   EXPECT_EQ(F->getNodeValueByName("add:0"), addRes);
178 
179   // Verify the node value is invalid for the SaveNode which has no outputs.
180   EXPECT_EQ(F->getNodeValueByName(save->getName()).getNode(), nullptr);
181 }
182 
183 /// Check node names.
TEST(Graph,testNodeNames)184 TEST(Graph, testNodeNames) {
185   Module MD;
186   Function *F = MD.createFunction("F");
187   IRFunction M(F);
188   PlaceholderBindings bindings;
189   Node *K =
190       MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
191   Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
192 
193   K = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, 1);
194   K = F->createRELU("Relu", K);
195   K = F->createSoftMax("SoftMax", K, S);
196   F->createSave("Save", K);
197   F->dump();
198   auto filePath = F->dumpDAG();
199   auto backend = MockBackend();
200   CompilationContext cctx;
201   lower(F, cctx, &backend);
202   ::optimize(F, CompilationMode::Train);
203   M.generateIR(backend);
204   M.dump();
205   EXPECT_GT(M.getInstrs().size(), 0);
206   llvm::sys::fs::remove(filePath);
207 }
208 
209 /// Check that a createConv3D can be run.
TEST(Graph,simpleTestConv3D)210 TEST(Graph, simpleTestConv3D) {
211   Module MD;
212   Function *F = MD.createFunction("F");
213   IRFunction M(F);
214   PlaceholderBindings bindings;
215   Node *K = MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 100, 3},
216                                  "input", true);
217   K = F->createConv3D(bindings, /* name */ "Conv3D", /* input */ K,
218                       /* outChannels */ 16, /* kernel */ 3, /* stride */ 2,
219                       /* pad */ 3, /* group */ 1);
220   K = F->createRELU("Relu", K);
221   F->createSave("Save", K);
222   F->dump();
223   auto filePath = F->dumpDAG();
224   auto backend = MockBackend();
225   CompilationContext cctx;
226   lower(F, cctx, &backend);
227   ::optimize(F, CompilationMode::Train);
228   M.generateIR(backend);
229   M.dump();
230   EXPECT_GT(M.getInstrs().size(), 0);
231   llvm::sys::fs::remove(filePath);
232 }
233 
234 /// Tests custom lowering from Node to Instruction IR
TEST(Graph,simpleTestConvCustomLower)235 TEST(Graph, simpleTestConvCustomLower) {
236   Module MD;
237   Function *F = MD.createFunction("F");
238   IRFunction M(F);
239   PlaceholderBindings bindings;
240   Node *K =
241       MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
242   Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
243 
244   K = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, 1);
245   K = F->createRELU("Relu", K);
246   K = F->createSoftMax("SoftMax", K, S);
247   F->createSave("Save", K);
248   F->dump();
249   auto filePath = F->dumpDAG();
250   auto backend = MockBackendCustomIRGen();
251   CompilationContext cctx;
252   lower(F, cctx, &backend);
253   ::optimize(F, CompilationMode::Train);
254   M.generateIR(MockBackendCustomIRGen());
255   M.dump();
256   auto &instrList = M.getInstrs();
257   bool customHappened = false;
258   for (auto begin = instrList.begin(); begin != instrList.end(); ++begin) {
259     if (begin->getName().equals("CustomConvolutionInstruction")) {
260       customHappened = true;
261       break;
262     }
263   }
264 
265   EXPECT_EQ(customHappened, true);
266   llvm::sys::fs::remove(filePath);
267 }
268 
269 /// Check that we can create convolution with float16.
TEST(Graph,float16Conv)270 TEST(Graph, float16Conv) {
271   Module MD;
272   Function *F = MD.createFunction("F");
273   PlaceholderBindings bindings;
274   Node *K = MD.createConstant(ElemKind::Float16Ty, {4, 320, 200, 3}, "input");
275 
276   auto *conv = F->createConv(bindings, "Conv", K, 16, 3, 2, 3, 1);
277   F->createSave("Save", conv);
278   EXPECT_TRUE(conv->verify());
279   EXPECT_EQ(conv->getResult().getElementType(), ElemKind::Float16Ty);
280   EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::Float16Ty);
281   EXPECT_EQ(conv->getBias().getElementType(), ElemKind::Float16Ty);
282 
283   auto backend = MockBackend();
284   CompilationContext cctx;
285   lower(F, cctx, &backend);
286 
287   IRFunction M(F);
288 
289   M.generateIR(backend);
290   EXPECT_GT(M.getInstrs().size(), 0);
291   auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
292                              [](const Instruction &inst) -> bool {
293                                return llvm::isa<ConvolutionInst>(inst);
294                              });
295   ASSERT_TRUE(convIt != M.getInstrs().end());
296   const auto *convInst = llvm::cast<ConvolutionInst>(&*convIt);
297   EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::Float16Ty);
298   EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::Float16Ty);
299   EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::Float16Ty);
300 }
301 
302 /// Check that we can create conv3D with float16.
TEST(Graph,float16Conv3DLower)303 TEST(Graph, float16Conv3DLower) {
304   Module MD;
305   Function *F = MD.createFunction("F");
306   PlaceholderBindings bindings;
307   Node *K =
308       MD.createConstant(ElemKind::Float16Ty, {4, 320, 200, 200, 3}, "input");
309 
310   auto *conv = F->createConv3D(bindings, "Conv3D", K, 16, 3, 2, 3, 1);
311   F->createSave("Save", conv);
312   EXPECT_TRUE(conv->verify());
313   EXPECT_EQ(conv->getResult().getElementType(), ElemKind::Float16Ty);
314   EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::Float16Ty);
315   EXPECT_EQ(conv->getBias().getElementType(), ElemKind::Float16Ty);
316 
317   auto backend = MockBackend();
318   CompilationContext cctx;
319   lower(F, cctx, &backend);
320 
321   IRFunction M(F);
322 
323   M.generateIR(backend);
324   EXPECT_GT(M.getInstrs().size(), 0);
325   auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
326                              [](const Instruction &inst) -> bool {
327                                return llvm::isa<Convolution3DInst>(inst);
328                              });
329   ASSERT_TRUE(convIt == M.getInstrs().end());
330 }
331 
332 /// Check that we can create conv3D with float16.
TEST(Graph,float16Conv3DNoLower)333 TEST(Graph, float16Conv3DNoLower) {
334   Module MD;
335   Function *F = MD.createFunction("F");
336   PlaceholderBindings bindings;
337   Node *K =
338       MD.createConstant(ElemKind::Float16Ty, {4, 320, 200, 200, 3}, "input");
339 
340   auto *conv = F->createConv3D(bindings, "Conv3D", K, 16, 3, 2, 3, 1);
341   F->createSave("Save", conv);
342   EXPECT_TRUE(conv->verify());
343   EXPECT_EQ(conv->getResult().getElementType(), ElemKind::Float16Ty);
344   EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::Float16Ty);
345   EXPECT_EQ(conv->getBias().getElementType(), ElemKind::Float16Ty);
346 
347   auto backend = MockBackendNoLowerConv3D();
348   CompilationContext cctx;
349   lower(F, cctx, &backend);
350 
351   IRFunction M(F);
352 
353   M.generateIR(backend);
354   EXPECT_GT(M.getInstrs().size(), 0);
355   auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
356                              [](const Instruction &inst) -> bool {
357                                return llvm::isa<Convolution3DInst>(inst);
358                              });
359   ASSERT_TRUE(convIt != M.getInstrs().end());
360   const auto *convInst = llvm::cast<Convolution3DInst>(&*convIt);
361   EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::Float16Ty);
362   EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::Float16Ty);
363   EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::Float16Ty);
364 }
365 
366 /// Check that we can create batchNorm with float16.
TEST(Graph,float16BatchNorm)367 TEST(Graph, float16BatchNorm) {
368   Module MD;
369   Function *F = MD.createFunction("F");
370   PlaceholderBindings bindings;
371   auto *input =
372       MD.createPlaceholder(ElemKind::Float16Ty, {1, 10, 20, 3}, "input", false);
373   BatchNormalizationNode *BN =
374       F->createBatchNormalization(bindings, "batch", input, 3, 0.0001, 0.9);
375 
376   EXPECT_TRUE(BN->verify());
377   EXPECT_EQ(BN->getResult().getElementType(), ElemKind::Float16Ty);
378   EXPECT_EQ(BN->getScale().getElementType(), ElemKind::Float16Ty);
379   EXPECT_EQ(BN->getBias().getElementType(), ElemKind::Float16Ty);
380   EXPECT_EQ(BN->getMean().getElementType(), ElemKind::Float16Ty);
381   EXPECT_EQ(BN->getVar().getElementType(), ElemKind::Float16Ty);
382 
383   auto backend = MockBackend();
384   CompilationContext cctx;
385   lower(F, cctx, &backend);
386 
387   EXPECT_TRUE(std::all_of(
388       F->getNodes().begin(), F->getNodes().end(), [](const Node &node) -> bool {
389         for (unsigned idx = 0, end = node.getNumResults(); idx != end; ++idx) {
390           if (node.getType(idx)->getElementType() != ElemKind::Float16Ty) {
391             return false;
392           }
393         }
394         return true;
395       }));
396 }
397 
398 /// Check that we can create convolution with bfloat16.
TEST(Graph,bfloat16Conv)399 TEST(Graph, bfloat16Conv) {
400   Module MD;
401   Function *F = MD.createFunction("F");
402   PlaceholderBindings bindings;
403   Node *K = MD.createConstant(ElemKind::BFloat16Ty, {4, 320, 200, 3}, "input");
404 
405   auto *conv = F->createConv(bindings, "Conv", K, 16, 3, 2, 3, 1);
406   F->createSave("Save", conv);
407   EXPECT_TRUE(conv->verify());
408   EXPECT_EQ(conv->getResult().getElementType(), ElemKind::BFloat16Ty);
409   EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::BFloat16Ty);
410   EXPECT_EQ(conv->getBias().getElementType(), ElemKind::BFloat16Ty);
411 
412   auto backend = MockBackend();
413   CompilationContext cctx;
414   lower(F, cctx, &backend);
415 
416   IRFunction M(F);
417 
418   M.generateIR(backend);
419   EXPECT_GT(M.getInstrs().size(), 0);
420   auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
421                              [](const Instruction &inst) -> bool {
422                                return llvm::isa<ConvolutionInst>(inst);
423                              });
424   ASSERT_TRUE(convIt != M.getInstrs().end());
425   const auto *convInst = llvm::cast<ConvolutionInst>(&*convIt);
426   EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::BFloat16Ty);
427   EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::BFloat16Ty);
428   EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::BFloat16Ty);
429 }
430 
431 /// Check that we can create conv3D with bfloat16.
TEST(Graph,bfloat16Conv3DLower)432 TEST(Graph, bfloat16Conv3DLower) {
433   Module MD;
434   Function *F = MD.createFunction("F");
435   PlaceholderBindings bindings;
436   Node *K =
437       MD.createConstant(ElemKind::BFloat16Ty, {4, 320, 200, 200, 3}, "input");
438 
439   auto *conv = F->createConv3D(bindings, "Conv3D", K, 16, 3, 2, 3, 1);
440   F->createSave("Save", conv);
441   EXPECT_TRUE(conv->verify());
442   EXPECT_EQ(conv->getResult().getElementType(), ElemKind::BFloat16Ty);
443   EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::BFloat16Ty);
444   EXPECT_EQ(conv->getBias().getElementType(), ElemKind::BFloat16Ty);
445 
446   auto backend = MockBackend();
447   CompilationContext cctx;
448   lower(F, cctx, &backend);
449 
450   IRFunction M(F);
451 
452   M.generateIR(backend);
453   EXPECT_GT(M.getInstrs().size(), 0);
454   auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
455                              [](const Instruction &inst) -> bool {
456                                return llvm::isa<Convolution3DInst>(inst);
457                              });
458   ASSERT_TRUE(convIt == M.getInstrs().end());
459 }
460 
461 /// Check that we can create conv3D with bfloat16.
TEST(Graph,bfloat16Conv3DNoLower)462 TEST(Graph, bfloat16Conv3DNoLower) {
463   Module MD;
464   Function *F = MD.createFunction("F");
465   PlaceholderBindings bindings;
466   Node *K =
467       MD.createConstant(ElemKind::BFloat16Ty, {4, 320, 200, 200, 3}, "input");
468 
469   auto *conv = F->createConv3D(bindings, "Conv3D", K, 16, 3, 2, 3, 1);
470   F->createSave("Save", conv);
471   EXPECT_TRUE(conv->verify());
472   EXPECT_EQ(conv->getResult().getElementType(), ElemKind::BFloat16Ty);
473   EXPECT_EQ(conv->getFilter().getElementType(), ElemKind::BFloat16Ty);
474   EXPECT_EQ(conv->getBias().getElementType(), ElemKind::BFloat16Ty);
475 
476   auto backend = MockBackendNoLowerConv3D();
477   CompilationContext cctx;
478   lower(F, cctx, &backend);
479 
480   IRFunction M(F);
481 
482   M.generateIR(backend);
483   EXPECT_GT(M.getInstrs().size(), 0);
484   auto convIt = std::find_if(M.getInstrs().begin(), M.getInstrs().end(),
485                              [](const Instruction &inst) -> bool {
486                                return llvm::isa<Convolution3DInst>(inst);
487                              });
488   ASSERT_TRUE(convIt != M.getInstrs().end());
489   const auto *convInst = llvm::cast<Convolution3DInst>(&*convIt);
490   EXPECT_EQ(convInst->getSrc()->getElementType(), ElemKind::BFloat16Ty);
491   EXPECT_EQ(convInst->getFilter()->getElementType(), ElemKind::BFloat16Ty);
492   EXPECT_EQ(convInst->getBias()->getElementType(), ElemKind::BFloat16Ty);
493 }
494 
495 /// Check that we can create batchNorm with float16.
TEST(Graph,bfloat16BatchNorm)496 TEST(Graph, bfloat16BatchNorm) {
497   Module MD;
498   Function *F = MD.createFunction("F");
499   PlaceholderBindings bindings;
500   auto *input = MD.createPlaceholder(ElemKind::BFloat16Ty, {1, 10, 20, 3},
501                                      "input", false);
502   BatchNormalizationNode *BN =
503       F->createBatchNormalization(bindings, "batch", input, 3, 0.0001, 0.9);
504 
505   EXPECT_TRUE(BN->verify());
506   EXPECT_EQ(BN->getResult().getElementType(), ElemKind::BFloat16Ty);
507   EXPECT_EQ(BN->getScale().getElementType(), ElemKind::BFloat16Ty);
508   EXPECT_EQ(BN->getBias().getElementType(), ElemKind::BFloat16Ty);
509   EXPECT_EQ(BN->getMean().getElementType(), ElemKind::BFloat16Ty);
510   EXPECT_EQ(BN->getVar().getElementType(), ElemKind::BFloat16Ty);
511 
512   auto backend = MockBackend();
513   CompilationContext cctx;
514   lower(F, cctx, &backend);
515 
516   EXPECT_TRUE(std::all_of(
517       F->getNodes().begin(), F->getNodes().end(), [](const Node &node) -> bool {
518         for (unsigned idx = 0, end = node.getNumResults(); idx != end; ++idx) {
519           if (node.getType(idx)->getElementType() != ElemKind::BFloat16Ty) {
520             return false;
521           }
522         }
523         return true;
524       }));
525 }
526 
/// Test that our use lists are correctly reflecting the state of the IR
/// and in particular that it is not polluted by temporary variable.
TEST(Graph, useList) {
  Module MD;
  Function *F = MD.createFunction("F");
  IRFunction M(F);
  PlaceholderBindings bindings;
  auto *K =
      MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);

  // A freshly created placeholder has no users yet.
  EXPECT_EQ(K->getNumUsers(), 0);

  ConvolutionNode *conv = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, 1);

  // The conv consumes K; nothing consumes the conv yet.
  EXPECT_TRUE(K->hasOneUse());
  EXPECT_EQ(K->getNumUsers(), 1);
  EXPECT_EQ(conv->getNumUsers(), 0);

  // Although the filter of the convolution is only used by the convolution
  // node, calling getFilter creates a temporary NodeValue that messes up
  // with the actual use list.
  // Therefore those checks are currently inverted but should be
  // fixed eventually.
  // Test with implicit temporary NodeValue.
  EXPECT_TRUE(conv->getFilter().getNode()->hasOneUse());
  EXPECT_EQ(conv->getFilter().getNode()->getNumUsers(), 1);

  // Test with explicit temporary NodeValue.
  Node *nodeFilter;
  {
    NodeValue tmp = conv->getFilter();
    EXPECT_TRUE(tmp.getNode()->hasOneUse());
    EXPECT_EQ(tmp.getNode()->getNumUsers(), 1);
    nodeFilter = tmp.getNode();
    // Test with NodeValue still around.
    EXPECT_TRUE(nodeFilter->hasOneUse());
    EXPECT_EQ(nodeFilter->getNumUsers(), 1);
  }

  // Test with NodeValue took out.
  EXPECT_TRUE(nodeFilter->hasOneUse());
  EXPECT_EQ(nodeFilter->getNumUsers(), 1);

  // Same kind of test but with the convolution node itself: constructing a
  // temporary NodeValue over conv's result must not add a use.
  {
    NodeValue tmpConvRes(conv, 0);
    EXPECT_EQ(conv->getNumUsers(), 0);
    EXPECT_EQ(tmpConvRes.getNode()->getNumUsers(), 0);
  }

  // Add a couple of uses to conv and make sure it reflects on its use list.
  F->createSave("Save", conv, K);

  // The save consumes both conv (data) and K (output), so K now has two
  // users and conv has one.
  EXPECT_FALSE(K->hasOneUse());
  EXPECT_EQ(K->getNumUsers(), 2);
  EXPECT_EQ(conv->getNumUsers(), 1);
  EXPECT_TRUE(conv->hasOneUse());

  {
    NodeValue tmpConvRes(conv, 0);
    EXPECT_TRUE(tmpConvRes.getNode()->hasOneUse());
    EXPECT_TRUE(conv->hasOneUse());
    EXPECT_EQ(conv->getNumUsers(), 1);
    EXPECT_EQ(tmpConvRes.getNode()->getNumUsers(), 1);
  }

  // A second save adds one more use to both conv and K.
  F->createSave("Save", conv, K);

  EXPECT_FALSE(K->hasOneUse());
  EXPECT_EQ(K->getNumUsers(), 3);
  EXPECT_EQ(conv->getNumUsers(), 2);
  EXPECT_FALSE(conv->hasOneUse());

  {
    NodeValue tmpConvRes(conv, 0);
    EXPECT_FALSE(tmpConvRes.getNode()->hasOneUse());
    EXPECT_FALSE(conv->hasOneUse());
    EXPECT_EQ(conv->getNumUsers(), 2);
    EXPECT_EQ(tmpConvRes.getNode()->getNumUsers(), 2);
  }
}
608 
/// Check that iterating a node's use list visits the users in creation order.
TEST(Graph, useListIteration) {
  Module mod;
  Function *F = mod.createFunction("F");
  IRFunction irF(F);
  Node *in =
      mod.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);

  EXPECT_EQ(in->getNumUsers(), 0);

  PlaceholderBindings bindings;
  ConvolutionNode *conv1 = F->createConv(bindings, "Conv1", in, 16, 3, 2, 3, 1);
  ConvolutionNode *conv2 = F->createConv(bindings, "Conv2", in, 16, 3, 2, 3, 1);
  // Check the number of users for different nodes.
  EXPECT_EQ(in->getNumUsers(), 2);
  EXPECT_EQ(conv1->getNumUsers(), 0);
  EXPECT_TRUE(conv2->getFilter().getNode()->hasOneUse());
  EXPECT_EQ(conv1->getFilter().getNode()->getNumUsers(), 1);
  // The first user of the input is conv1, the second is conv2.
  auto userIt = in->getUsers().begin();
  EXPECT_EQ(userIt->getUser(), conv1);
  ++userIt;
  EXPECT_EQ(userIt->getUser(), conv2);
}
631 
/// Build a small FC/ReLU/regression network and run it through lowering,
/// optimization and IR generation.
TEST(Graph, simpleTestFC) {
  unsigned numInputs = 10;
  Module mod;
  Function *F = mod.createFunction("F");
  IRFunction irF(F);

  auto *in = mod.createPlaceholder(ElemKind::FloatTy, {numInputs, 2}, "A", true);
  auto *expected =
      mod.createPlaceholder(ElemKind::FloatTy, {numInputs, 1}, "Ex", true);

  PlaceholderBindings bindings;
  // Two FC+ReLU stages followed by a regression against the expected values.
  Node *out = F->createFullyConnected(bindings, "FC1", in, 6);
  out = F->createRELU("RELU1", out);
  out = F->createFullyConnected(bindings, "FC2", out, 1);
  out = F->createRELU("RELU2", out);
  out = F->createRegression("Regression", out, expected);
  F->createSave("Save", out);
  F->dump();
  const auto dagFile = F->dumpDAG();
  MockBackend backend;
  CompilationContext cctx;
  lower(F, cctx, &backend);
  ::optimize(F, CompilationMode::Train);
  irF.generateIR(backend);
  irF.dump();
  EXPECT_GT(irF.getInstrs().size(), 0);
  llvm::sys::fs::remove(dagFile);
}
660 
/// Check that profiling instrumentation only adds QuantizationProfile nodes
/// for float values, and only one per distinct value.
TEST(Graph, QuantizationProfileNodes) {
  unsigned numInputs = 10;
  Module MD;
  Function *F = MD.createFunction("F");
  IRFunction M(F);

  auto *A = MD.createPlaceholder(ElemKind::FloatTy, {numInputs, 2}, "A", true);

  // Add non float operation, which should not be profiled.
  auto *outQTy = F->getParent()->uniqueType(glow::ElemKind::Int8QTy,
                                            {numInputs, 2}, 1.5, 6);
  auto *quantize = F->createQuantize("quantize", A, outQTy);
  // Make sure that quantize is not optimized away.
  PlaceholderBindings bindings;
  F->createSave("save", quantize);

  // Multiple nodes read from the same variable.
  // Only one Quantization Profile node should be created for the output
  // from the variable.
  Node *O = F->createFullyConnected(bindings, "FC1", A, 6);
  Node *C = F->createFullyConnected(bindings, "FC2", A, 6);
  O = F->createRELU("RELU1", O);
  F->createSave("save", O);
  F->createSave("save", C);

  // Optimize in profiling mode so the instrumentation pass runs.
  LoweredInfoMap loweredMapForProf;
  CompilationContext cctx{&bindings, &loweredMapForProf};
  cctx.precisionConfig.quantMode = QuantizationMode::Profile;
  std::unique_ptr<Backend> backend(createBackend("Interpreter"));
  EXIT_ON_ERR(::optimizeFunction(F, *backend, cctx));

  // Count the profile nodes that were inserted.
  size_t numberOfProfileNodes =
      std::count_if(F->getNodes().begin(), F->getNodes().end(), [](Node &node) {
        return llvm::isa<QuantizationProfileNode>(&node);
      });

  // 1 from A
  // 8 from two lowered FCs: MM, BA, weight PH, bias PH
  // 2 from RELU (lowered to Max+Splat)
  EXPECT_EQ(11, numberOfProfileNodes);
}
702 
/// Build a small fully-quantized conv + FC network and check it compiles.
TEST(Graph, simpleQuant) {
  ExecutionEngine EE;
  auto &MD = EE.getModule();
  auto *F = MD.createFunction("main");

  // Convolution geometry: 16 output channels, 5x5 kernels, no padding,
  // unit stride over a 224x224x3 input.
  unsigned depth = 16;
  llvm::SmallVector<unsigned_t, 2> kernels = {5, 5};
  llvm::SmallVector<unsigned_t, 4> pads = {0, 0, 0, 0};
  llvm::SmallVector<unsigned_t, 2> steps = {1, 1};
  unsigned width = 224;

  // Int8-quantized input with scale 0.4 and offset 2.
  auto *input = MD.createPlaceholder(ElemKind::Int8QTy, {1, width, width, 3},
                                     0.4, 2, "Input", true);

  // Calculate the size and allocate the output buffer.
  std::array<dim_t, 4> filterDim = {{depth, kernels[0], kernels[1], 3}};
  auto *filter =
      MD.createPlaceholder(ElemKind::Int8QTy, filterDim, 3.3, 4, "F", true);
  // Bias is Int32-quantized, as is conventional for quantized conv.
  auto *bias =
      MD.createPlaceholder(ElemKind::Int32QTy, {depth}, 1.3, 5, "B", true);

  // Calculate the size and allocate the output buffer.
  auto outSz = calculateConvPoolOutputDims(width, width, kernels, steps, pads);
  std::array<dim_t, 4> outDims = {{1, outSz.first, outSz.second, 16}};
  auto t = F->getParent()->uniqueType(glow::ElemKind::Int8QTy, outDims, 1.5, 6);

  auto *conv =
      F->createConv("conv", input, filter, bias, t, kernels, steps, pads, 1);

  // Feed the conv output into a quantized fully-connected layer.
  auto s = conv->getResult().getType()->size();
  auto *fcFilter =
      MD.createPlaceholder(ElemKind::Int8QTy, {s, 6}, 0.4, 2, "F", true);
  auto *fcBias =
      MD.createPlaceholder(ElemKind::Int32QTy, {6}, 0.4, 2, "B", true);
  Node *O = F->createFullyConnected("fc1", conv, fcFilter, fcBias);
  PlaceholderBindings bindings;
  F->createSave("ret", O);
  // Compilation of the fully-quantized graph must succeed.
  EE.compile(CompilationMode::Infer);
}
742 
/// Check that a quantize -> rescale -> dequantize chain builds and compiles.
TEST(Graph, quantizeDequantizeNodes) {
  ExecutionEngine EE;
  auto &MD = EE.getModule();
  // Use 'auto *' for the returned pointer, consistent with the rest of
  // this file.
  auto *F = MD.createFunction("main");

  auto *input = MD.createPlaceholder(ElemKind::FloatTy, {1, 3}, "Input", true);
  auto qType = F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 0.3, 5);

  // Quantize the float input into int8 (scale 0.3, offset 5).
  auto *Q = F->createQuantize("quantize", input, qType);

  // Rescale to a different scale/offset pair (1.4, 3).
  auto transform =
      F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 1.4, 3);
  auto *A = F->createRescaleQuantized("rescale", Q, transform);

  // Convert back to float and save the result.
  auto *D = F->createDequantize("dequantize", A, ElemKind::FloatTy);
  PlaceholderBindings bindings;
  F->createSave("ret", D);
  EE.compile(CompilationMode::Infer);
}
762 
/// Check that a gather over a quantized tensor compiles.
TEST(Graph, quantizeGather) {
  ExecutionEngine EE;
  auto &mod = EE.getModule();
  auto *F = mod.createFunction("main");
  // Quantized data tensor plus the integer indices to gather with.
  auto *data =
      mod.createPlaceholder(ElemKind::Int8QTy, {2, 2}, 0.4, 2, "input", true);
  auto *indices = mod.createPlaceholder(ElemKind::Int64ITy, {1}, "index", true);
  auto *gather = F->createGather("gather", data, indices);
  PlaceholderBindings bindings;
  F->createSave("ret", gather);
  EE.compile(CompilationMode::Infer);
}
775 
/// Check that cloned nodes are distinct objects that still compare equal to
/// their originals.
TEST(Graph, cloneTest) {
  Module mod;
  PlaceholderBindings bindings;

  Function *F = mod.createFunction("main");
  Node *in =
      mod.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
  Node *sel = mod.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
  Node *conv = F->createConv(bindings, "Conv1", in, 16, 3, 2, 3, 1);
  Node *relu = F->createRELU("Relu", conv);
  Node *SM = F->createSoftMax("SoftMax", relu, sel);
  F->createSave("Save", SM);

  auto *newConv = F->addNode(conv->clone());
  auto *newRelu = F->addNode(relu->clone());
  auto *newSM = F->addNode(SM->clone());

  // Clones must be new objects but semantically equal to the originals.
  EXPECT_TRUE(newConv != conv && conv->isEqual(*newConv));
  EXPECT_TRUE(newRelu != relu && relu->isEqual(*newRelu));
  EXPECT_TRUE(newSM != SM && SM->isEqual(*newSM));
}
797 
/// Check module-level function registration and name lookup.
TEST(Graph, moduleTest) {
  Module mod;
  // Register two functions and two placeholders.
  mod.createFunction("one");
  mod.createFunction("two");
  mod.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V1", true);
  mod.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V2", true);
  // Lookup by name only finds what was actually created.
  EXPECT_TRUE(mod.hasFunction("one"));
  EXPECT_TRUE(mod.hasFunction("two"));
  EXPECT_FALSE(mod.hasFunction("four"));
  mod.dumpDAG();
}
809 
/// Build two functions that share placeholders of one module and check that
/// the module still reports its functions correctly.
TEST(Graph, functionDependenciesTest) {
  Module M;
  auto *F1 = M.createFunction("one");
  auto *F2 = M.createFunction("two");
  auto *V1 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V1", true);
  auto *V2 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V2", true);
  auto *V3 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V3", true);
  // V4 is created but intentionally left unused by either function.
  M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "V4", true);

  PlaceholderBindings bindings;
  // Note: despite the variable name, this node computes V1 - V2.
  auto sum = F1->createSub("1_sub_2", V1, V2);
  // F1 writes its result back into V1; F2 copies V3 into V2.
  F1->createSave("sv", sum, V1);
  F2->createSave("sv", V3, V2);

  EXPECT_TRUE(M.hasFunction("one"));
  EXPECT_TRUE(M.hasFunction("two"));
  EXPECT_FALSE(M.hasFunction("four"));
  M.dumpDAG();
}
832 
/// Cloning a whole function yields a verified copy with the same node count
/// that lives in the same parent module.
TEST(Graph, functionCloneTest) {
  Module mod;
  PlaceholderBindings ctx;

  auto *origF = mod.createFunction("main");
  Node *in =
      mod.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
  Node *sel = mod.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);
  Node *convNode = origF->createConv(ctx, "Conv", in, 16, 3, 2, 3, 1);
  Node *reluNode = origF->createRELU("Relu", convNode);
  // The concat reuses the same input three times to exercise shared edges.
  Node *catNode = origF->createConcat("concat", {reluNode, reluNode, reluNode}, 0);
  Node *smNode = origF->createSoftMax("SoftMax", catNode, sel);
  origF->createSave("Save", smNode);

  auto *copyF = origF->clone("new_main");

  EXPECT_TRUE(copyF->verify());
  EXPECT_EQ(copyF->getNodes().size(), origF->getNodes().size());
  EXPECT_EQ(copyF->getParent(), origF->getParent());
}
854 
855 /// Compile the module \p M inside the execution engine \p EE and then run it
856 /// using the provided \p bindings. Use the provided \p inputName and \p
857 /// outputName.
static void compileAndRun(ExecutionEngine &EE, PlaceholderBindings &bindings,
                          Module &M, llvm::StringRef inputName,
                          llvm::StringRef outputName) {
  EE.compile(glow::CompilationMode::Infer);
  // Allocate storage for placeholders and initialize inputs.
  // The input tensor is filled with the constant 2.0.
  bindings.allocate(M.getPlaceholderByNameSlow(inputName))
      ->getHandle()
      .clear(2.0);
  bindings.allocate(M.getPlaceholderByNameSlow(outputName));
  EE.run(bindings);
}
869 
870 /// Check the module cloning functionality.
TEST(Graph, moduleCloneTest) {
  // State related to the cloned module and its execution.
  ExecutionEngine clonedEE("Interpreter");
  Module &clonedM = clonedEE.getModule();
  PlaceholderBindings clonedBindings;
  Tensor clonedResult;
  // State related to the original module and its execution.
  PlaceholderBindings originalBindings;
  Tensor originalResult;
  // Name of the placeholder holding the results of executions.
  std::string resultName;
  {
    // Define the original execution engine and module.
    ExecutionEngine originalEE("Interpreter");
    Module &originalM = originalEE.getModule();

    // Create a function.
    auto *F = originalM.createFunction("main");
    auto *input1 = originalM.createPlaceholder(ElemKind::FloatTy,
                                               {4, 10, 10, 3}, "input", true);

    auto *add = F->createAdd("add", input1, input1);
    auto *relu = F->createRELU("Relu", add);
    auto *concat = F->createConcat("concat", {relu, relu, relu}, 0);
    // Constant filled with ones; exercises cloning of constant payloads.
    auto *C = originalM.createConstant(concat->getResult().getType(), "C");
    C->getPayloadMutable().getHandle().clear(1.0f);
    auto *SM = F->createAdd("add", concat, C);
    auto *SN = F->createSave("Save", SM);
    resultName = SN->getPlaceholder()->getName();

    // Clone the original module into the cloned module.
    originalM.clone(&clonedM);
    // The cloned module should have the same number of types, functions,
    // constants and placeholders.
    EXPECT_EQ(originalM.getFunctions().size(), clonedM.getFunctions().size());
    EXPECT_EQ(originalM.getPlaceholders().size(),
              clonedM.getPlaceholders().size());
    EXPECT_EQ(originalM.getConstants().size(), clonedM.getConstants().size());
    EXPECT_EQ(originalM.getTypes().size(), clonedM.getTypes().size());
    // String representations of the original and cloned modules should be the
    // same.
    EXPECT_EQ(originalM.toString(), clonedM.toString());
    for (auto *originalF : originalM.getFunctions()) {
      EXPECT_EQ(originalF->toString(),
                clonedM.getFunction(originalF->getName())->toString());
    }

    // Compile and run the original module.
    compileAndRun(originalEE, originalBindings, originalM, "input", resultName);
    // Store the result of running the original module.
    originalResult.assign(originalBindings.get(
        originalBindings.getPlaceholderByNameSlow(resultName)));
    // The old module should be removed when this scope ends. Thus, if the
    // cloned module newM refers to any deleted nodes from the original module,
    // it would result in a dangling reference and most likely in a crash.
  }
  // Check that the cloned module is still alive and valid after the original
  // module was deleted.
  EXPECT_TRUE(clonedM.verify());
  // Compile and run the cloned model.
  compileAndRun(clonedEE, clonedBindings, clonedM, "input", resultName);
  // Store the result of running the cloned module.
  clonedResult.assign(
      clonedBindings.get(clonedBindings.getPlaceholderByNameSlow(resultName)));
  // The results of execution should be exactly the same in both cases.
  EXPECT_TRUE(originalResult.isEqual(clonedResult, 0));
}
938 
/// Cloning a function must share predicates that live outside the function
/// and clone predicates that are produced inside it.
TEST(Graph, cloneWithPredicates) {
  Module M;
  PlaceholderBindings bindings;

  auto *F = M.createFunction("main");
  auto *input =
      M.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", false);
  auto *counters =
      M.createPlaceholder(ElemKind::FloatTy, {10}, "counters", false);
  auto *reluExt = F->createRELU("reluExt", input);
  auto *reluInt = F->createRELU("reluInt", input);
  // Predicate that lives outside the function (a module placeholder).
  auto *externalPredicate =
      M.createPlaceholder(ElemKind::Int64ITy, {1}, "predicate", false);
  auto *C10 = F->createSplat("C10", counters->getType(), 10.0);
  // Predicate produced by a node inside the function itself.
  auto *internalPredicate = F->createCmpLTE("lte", C10, counters);

  reluExt->setPredicate(externalPredicate);
  reluInt->setPredicate(internalPredicate);

  auto *newF = F->clone("new_main");

  EXPECT_TRUE(newF->verify());
  EXPECT_EQ(newF->getNodes().size(), F->getNodes().size());
  EXPECT_EQ(newF->getParent(), F->getParent());

  // Original predicates are not changed
  EXPECT_EQ(reluExt->getPredicate().getNode(), externalPredicate);
  EXPECT_EQ(reluInt->getPredicate().getNode(), internalPredicate);
  // Clone of predicate that points to a node outside the graph
  // points to the same node (predicate is shared)
  EXPECT_EQ(nodeByName(newF, "reluExt")->getPredicate().getNode(),
            externalPredicate);
  // Clone of predicate that points to a node that belongs to the graph
  // points to the predicate clone
  EXPECT_EQ(nodeByName(newF, "reluInt")->getPredicate().getNode(),
            nodeByName(newF, "lte"));
}
976 
/// Chain additions through a reassigned NodeValue handle: three doublings of
/// the broadcast input 3.0 must yield 24.
TEST(Graph, NodeValue) {
  ExecutionEngine engine;
  auto &module = engine.getModule();
  Function *func = module.createFunction("main");
  PlaceholderBindings ctx;
  auto *in = module.createPlaceholder(ElemKind::FloatTy, {1}, "input", true);
  ctx.allocate(in)->init(Tensor::InitKind::Broadcast, 3.0, module.getPRNG());

  NodeValue val = func->createAdd("x2", in, in);
  val = func->createAdd("x4", val, val);
  val = func->createAdd("x8", val, val);
  auto *saveNode = func->createSave("Save", val);
  auto *out = ctx.allocate(saveNode->getPlaceholder());

  engine.compile(CompilationMode::Infer);

  engine.run(ctx);

  // 3.0 * 2 * 2 * 2 == 24.
  EXPECT_EQ(out->getHandle().raw(0), 24);
}
998 
/// Check that deleting one function reduces the number of uses of the
/// variables referenced by that function by one.
TEST(Graph, deleteFunction) {
  ExecutionEngine engine;
  auto &module = engine.getModule();
  Function *first = module.createFunction("f1");
  auto *in = module.createPlaceholder(ElemKind::FloatTy, {1}, "input", true);
  first->createLog("log1", in);
  Function *second = module.createFunction("f2");
  second->createLog("log2", in);
  // Both functions reference the placeholder, so it has exactly two users.
  EXPECT_EQ(in->getNumUsers(), 2);
  // Erasing one function must drop the corresponding use, leaving one user.
  module.eraseFunction(first);
  EXPECT_EQ(in->getNumUsers(), 1);
}
1017 
/// A graph whose leading nodes carry a predicate must still compile and run
/// end to end.
TEST(Graph, nodesWithPredicates) {
  ExecutionEngine EE;

  Tensor inputs(ElemKind::FloatTy, {1, 32, 32, 3});

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  F->setName("interpret");
  PlaceholderBindings bindings;
  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {1, 32, 32, 3}, "input", true);
  auto *ex = mod.createPlaceholder(ElemKind::Int64ITy, {1, 1}, "exp", true);
  auto *pred =
      mod.createPlaceholder(ElemKind::Int64ITy, {1}, "predicate", false);
  bindings.allocate(input);
  bindings.allocate(ex);
  bindings.allocate(pred);

  auto *CV0 = F->createConv(bindings, "conv1", input, 16, 5, 1, 2, 1);
  auto *RL0 = F->createRELU("relu1", CV0);
  auto *MP0 = F->createMaxPool("pool1", RL0, 2, 2, 0);

  // Guard the first conv/relu/pool chain with the same predicate placeholder.
  CV0->setPredicate(pred);
  RL0->setPredicate(pred);
  MP0->setPredicate(pred);

  auto *FCL1 = F->createFullyConnected(bindings, "fc", MP0->getResult(), 10);
  auto *RL3 = F->createRELU("relu4", FCL1);
  auto *SM = F->createSoftMax("sm", RL3, ex);
  auto *save = F->createSave("ret", SM);
  bindings.allocate(save->getPlaceholder());

  EE.compile(CompilationMode::Infer);

  updateInputPlaceholders(bindings, {input}, {&inputs});
  EE.run(bindings);
}
1055 
1056 // Return the number of ConvolutionNode after lower.
getConvNodeSize(llvm::StringRef kind)1057 unsigned getConvNodeSize(llvm::StringRef kind) {
1058   Module mod;
1059   Function *F = mod.createFunction("main");
1060   IRFunction M(F);
1061   PlaceholderBindings bindings;
1062   auto *input =
1063       mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 1, 32}, "input", true);
1064   ConvolutionNode *CN = F->createConv(bindings, "conv", input, 6, 1, 1, 0, 2);
1065   F->createSave("save", CN);
1066 
1067   std::unique_ptr<Backend> backend(createBackend(kind));
1068   CompilationContext cctx;
1069   lower(F, cctx, backend.get());
1070 
1071   unsigned count = 0;
1072   for (auto &n : F->getNodes()) {
1073     if (n.getKind() == Kinded::Kind::ConvolutionNodeKind) {
1074       count++;
1075     }
1076   }
1077 
1078   if (kind == "Interpreter") {
1079     EXPECT_EQ(count, 1);
1080   }
1081 
1082   return count;
1083 }
1084 
1085 // Check the unrolling grouped convolution opt status:
1086 // -- disabled for Interpreter, CPU and OpenCL backend,
TEST(Graph, disableUnrollingGroupConv) {
  unsigned numberOfNodesInterpreter = getConvNodeSize("Interpreter");
  // Silence the unused-variable warning when no other backend is compiled in.
  (void)numberOfNodesInterpreter;

#ifdef GLOW_WITH_CPU
  // The CPU backend must keep the grouped conv whole, like the Interpreter.
  unsigned numberOfNodesCPU = getConvNodeSize("CPU");
  EXPECT_EQ(numberOfNodesCPU, numberOfNodesInterpreter);
#endif // GLOW_WITH_CPU

#ifdef GLOW_WITH_OPENCL
  // Same expectation for the OpenCL backend.
  unsigned numberOfNodesOpenCL = getConvNodeSize("OpenCL");
  EXPECT_EQ(numberOfNodesOpenCL, numberOfNodesInterpreter);
#endif // GLOW_WITH_OPENCL
}
1101 
/// Check that save nodes are properly scheduled.
/// That is, they happen after the last use of the related variable.
/// In this test, the order of creation of the nodes gives a valid schedule.
TEST(Graph, schedulingOfSavesOrderProvided) {
  ExecutionEngine EE;

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  auto *A = mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "A", true);
  auto *B = mod.createPlaceholder(A->getType(), "B", true);
  auto *zero = mod.createPlaceholder(A->getType(), "zero", true);

  PlaceholderBindings bindings;
  bindings.allocate(A)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
  bindings.allocate(B)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
  bindings.allocate(zero)->init(Tensor::InitKind::Broadcast, 0.0,
                                mod.getPRNG());

  auto *addAB = F->createAdd("addAB", A, B);

  auto *saveNode = F->createSave("ret", addAB);
  auto *savePH = saveNode->getPlaceholder();
  bindings.allocate(savePH);
  // Overwrite A with zero; this save must be scheduled after addAB reads A.
  F->createSave("resetA", zero, A);

  // Copy the value of A.
  Tensor AOrig = bindings.get(A)->clone();

  EE.compile(CompilationMode::Infer);

  EE.run(bindings);
  auto *ret = bindings.get(savePH);
  auto handleAOrig = AOrig.getHandle<>();
  auto handleB = bindings.get(B)->getHandle<>();
  auto handleRet = ret->getHandle<>();
  // ret must hold the *original* A plus B, element by element.
  bool allEqual = true;
  for (unsigned row = 0; row != 3; ++row) {
    for (unsigned column = 0; column != 32; ++column) {
      allEqual &= handleAOrig.at({row, column}) + handleB.at({row, column}) ==
                  handleRet.at({row, column});
    }
  }
  // After the run, A must have been reset to zero.
  EXPECT_TRUE(bindings.get(A)->isEqual(*bindings.get(zero), 0.0));
  EXPECT_TRUE(allEqual);
}
1147 
/// Same as schedulingOfSavesOrderProvided except the order in which the nodes
/// are added to the function doesn't form a valid schedule.
/// In other words, the scheduler won't get away with scheduling
/// using only the order of the nodes in the list of nodes.
TEST(Graph, schedulingOfSaves) {
  ExecutionEngine EE;
  PlaceholderBindings bindings;

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  auto *A = mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "A", true);
  auto *B = mod.createPlaceholder(A->getType(), "B", true);
  auto *zero = mod.createPlaceholder(A->getType(), "zero", true);
  // The reset of A is created *before* the node that reads A (addAB below),
  // so the scheduler cannot simply follow node creation order here.
  F->createSave("resetA", zero, A);

  bindings.allocate(A)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
  bindings.allocate(B)->init(Tensor::InitKind::Xavier, 1.0, mod.getPRNG());
  bindings.allocate(zero)->init(Tensor::InitKind::Broadcast, 0.0,
                                mod.getPRNG());

  auto *addAB = F->createAdd("addAB", A, B);

  auto *saveNode = F->createSave("ret", addAB);
  bindings.allocate(saveNode->getPlaceholder());

  // Copy the value of A.
  Tensor AOrig = bindings.get(A)->clone();
  auto *ret = saveNode->getPlaceholder();
  EE.compile(CompilationMode::Infer);

  EE.run(bindings);

  auto handleAOrig = AOrig.getHandle<>();
  auto handleB = bindings.get(B)->getHandle<>();
  auto handleRet = bindings.get(ret)->getHandle<>();
  // ret must hold the *original* A plus B, element by element.
  bool allEqual = true;
  for (unsigned row = 0; row != 3; ++row) {
    for (unsigned column = 0; column != 32; ++column) {
      allEqual &= handleAOrig.at({row, column}) + handleB.at({row, column}) ==
                  handleRet.at({row, column});
    }
  }
  // After the run, A must have been reset to zero.
  EXPECT_TRUE(bindings.get(A)->isEqual(*bindings.get(zero), 0.0));
  EXPECT_TRUE(allEqual);
}
1193 
1194 /// Check that the parent link is properly updated while tweaking
1195 /// nodes and their function.
TEST(Graph, parentLink) {
  ExecutionEngine EE;

  auto &mod = EE.getModule();
  // Deliberately constructed with 'new' (not via the module) to test the
  // parent link of a free-standing Constant.
  Constant *V =
      new Constant("V", mod.uniqueType(ElemKind::FloatTy, {3, 32}), ANY_LAYOUT);

  // Variables don't belong to any function...
  EXPECT_EQ(V->getParent(), nullptr);
  // Even when we create them from a module...
  Constant *V2 = mod.createConstant(V->getType(), "V2");
  EXPECT_EQ(V2->getParent(), nullptr);

  Function *F = mod.createFunction("main");

  // Nodes created with function helper belong to the related function.
  auto *addNode = F->createAdd("addnode", V, V2);
  EXPECT_EQ(addNode->getParent(), F);

  // Nodes created directly don't belong to any function.
  auto *addNode2 = new AddNode("addnode2", V->getType(), addNode, addNode);
  EXPECT_EQ(addNode2->getParent(), nullptr);

  // Nodes added to a function belong to that function.
  F->addNode(addNode2);
  EXPECT_EQ(addNode2->getParent(), F);

  // Cloned nodes don't belong to anything.
  auto *clonedAddNode = addNode->clone();
  EXPECT_EQ(clonedAddNode->getParent(), nullptr);

  // Check that the setter properly sets things.
  clonedAddNode->setParent(F);
  EXPECT_EQ(clonedAddNode->getParent(), F);
  clonedAddNode->setParent(nullptr);
  EXPECT_EQ(clonedAddNode->getParent(), nullptr);

  // Add the cloned node to F so that the memory is properly
  // cleaned at the end of the test.
  F->addNode(clonedAddNode);
  EXPECT_EQ(clonedAddNode->getParent(), F);

  // V was never registered with the module, so it must be freed manually.
  delete V;
}
1240 
1241 /// Check that verification can detect that Storage nodes are being used by
1242 /// Functions in a Module that doesn't own the Storage nodes.
/// A function may only use Storage owned by its own module: FB illegally
/// references modA's Constant and Placeholder, so modB must fail
/// verification while modA passes.
TEST(Graph, moduleLink) {
  ExecutionEngine engineA, engineB;

  auto &modA = engineA.getModule();
  auto &modB = engineB.getModule();

  auto *funcA = modA.createFunction("FA");
  auto *funcB = modB.createFunction("FB");

  auto *constA = modA.createConstant(ElemKind::FloatTy, {1}, "C");
  auto *phA = modA.createPlaceholder(ElemKind::FloatTy, {1}, "P", false);

  auto *addA = funcA->createAdd("AA", constA, phA);
  funcA->createSave("SA", addA);

  // Illegal: these nodes use Storage nodes that reside in modA.
  auto *addB = funcB->createAdd("AB", constA, phA);
  funcB->createSave("SB", addB);

  EXPECT_TRUE(modA.verify());
  // Module::verify calls Function::verify on all functions within the
  // module, so the cross-module references in FB make modB fail.
  EXPECT_FALSE(modB.verify());
}
1267 
1268 /// Check that Cmp nodes are created with proper output types.
TEST(Graph, cmpOutputTypes) {
  ExecutionEngine EE;

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  // Define two different quantized types.
  auto qType1 = F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 0.3, 5);
  auto qType2 = F->getParent()->uniqueType(ElemKind::Int8QTy, {1, 3}, 0.4, 5);
  // Define two variables of quantized types.
  auto *qv1 = mod.createPlaceholder(qType1, "V1", true);
  auto *qv2 = mod.createPlaceholder(qType2, "V2", true);
  // Create cmp nodes using quantized inputs.
  auto *cmpNode1 = F->createCmpEQ("cmpeq", qv1, qv2);
  auto *cmpNode2 = F->createCmpLTE("cmplte", qv1, qv2);
  // Check that the output type of cmp nodes is BoolTy.
  EXPECT_TRUE(cmpNode1->getResult().getElementType() == ElemKind::BoolTy);
  EXPECT_TRUE(cmpNode2->getResult().getElementType() == ElemKind::BoolTy);

  // Define a non-quantized type.
  auto nqType3 = F->getParent()->uniqueType(ElemKind::FloatTy, {1, 3});
  // Define two variables of non-quantized types.
  auto *nqv3 = mod.createPlaceholder(nqType3, "V3", true);
  auto *nqv4 = mod.createPlaceholder(nqType3, "V4", true);
  // Create cmp nodes using non-quantized inputs.
  auto *cmpNode3 = F->createCmpEQ("cmpeq", nqv3, nqv4);
  auto *cmpNode4 = F->createCmpLTE("cmplte", nqv3, nqv4);
  // Check that the output type of cmp nodes is BoolTy.
  EXPECT_TRUE(cmpNode3->getResult().getElementType() == ElemKind::BoolTy);
  EXPECT_TRUE(cmpNode4->getResult().getElementType() == ElemKind::BoolTy);
}
1299 
1300 /// Check that the users of value are equal to expectedUsers.
1301 static bool
hasAllTheseUses(const llvm::SmallPtrSetImpl<const Node * > & expectedUsers,const NodeValue & value)1302 hasAllTheseUses(const llvm::SmallPtrSetImpl<const Node *> &expectedUsers,
1303                 const NodeValue &value) {
1304   llvm::SmallPtrSet<const Node *, 4> uses;
1305   for (const NodeUse &use : value.getUsers()) {
1306     const Node *user = use.getUser();
1307     if (!expectedUsers.count(user)) {
1308       // We found a user that wasn't on the list.
1309       return false;
1310     }
1311     uses.insert(user);
1312   }
1313   return expectedUsers.size() == uses.size();
1314 }
1315 
1316 /// Check that our uses lists are correct for nodes with multiple results.
TEST(Graph, usesListsWithSeveralResult) {
  ExecutionEngine EE;
  PlaceholderBindings bindings;

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "input", true);
  // TopK has two results (Values and Indices) with independent use lists.
  auto *topK = F->createTopK("topK", input, 12);
  EXPECT_EQ(topK->getNumUsers(), 0);

  NodeValue values = topK->getValues();
  NodeValue indices = topK->getIndices();
  llvm::SmallPtrSet<const Node *, 4> savesOfValues;
  llvm::SmallPtrSet<const Node *, 4> savesOfIndices;

  // Initially neither result has any user.
  EXPECT_EQ(indices.getNumUsers(), 0);
  EXPECT_EQ(values.getNumUsers(), 0);

  EXPECT_FALSE(indices.hasOneUse());
  EXPECT_FALSE(values.hasOneUse());

  EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices));
  EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));

  // Now add a user to only one result of the topK node.
  savesOfValues.insert(F->createSave("saveValues1", values));

  // The whole node should inherit the uses of each of its results.
  EXPECT_EQ(topK->getNumUsers(), 1);

  // Each result should have its own use list.
  EXPECT_EQ(indices.getNumUsers(), 0);
  EXPECT_EQ(values.getNumUsers(), 1);

  EXPECT_FALSE(indices.hasOneUse());
  EXPECT_TRUE(values.hasOneUse());

  EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices));
  EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));

  // Add a user to the other result of the topK node.
  savesOfIndices.insert(F->createSave("saveIndices1", indices));

  // The whole node should inherit the uses of each of its results.
  EXPECT_EQ(topK->getNumUsers(), 2);

  // Each result should have its own use list.
  EXPECT_EQ(indices.getNumUsers(), 1);
  EXPECT_EQ(values.getNumUsers(), 1);

  EXPECT_TRUE(indices.hasOneUse());
  EXPECT_TRUE(values.hasOneUse());

  EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices));
  EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));

  // Add a couple more users of values and indices.
  // Interleaves the insertions in the uses list for both values and indices.
  savesOfValues.insert(F->createSave("saveValues2", values));
  savesOfValues.insert(F->createSave("saveValues3", values));
  savesOfIndices.insert(F->createSave("saveIndices2", indices));

  EXPECT_EQ(topK->getNumUsers(), 5);

  EXPECT_EQ(indices.getNumUsers(), 2);
  EXPECT_EQ(values.getNumUsers(), 3);

  EXPECT_FALSE(indices.hasOneUse());
  EXPECT_FALSE(values.hasOneUse());

  EXPECT_TRUE(hasAllTheseUses(savesOfIndices, indices));
  EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
}
1391 
1392 /// Check that our uses lists are correct when accessed through
1393 /// NodeValue.
TEST(Graph, usesListsThroughNodeValues) {
  ExecutionEngine EE;
  PlaceholderBindings bindings;

  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {3, 32}, "input", true);
  auto *reLU = F->createRELU("reLU", input);
  EXPECT_EQ(reLU->getNumUsers(), 0);

  // Grab the single result of the ReLU node as a NodeValue.
  NodeValue values = reLU->getResult();
  llvm::SmallPtrSet<const Node *, 4> savesOfValues;

  EXPECT_EQ(values.getNumUsers(), 0);

  EXPECT_FALSE(values.hasOneUse());

  EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));

  // Now add a user to only one result of the reLU node.
  savesOfValues.insert(F->createSave("saveValues1", values));

  // The whole node should inherit the uses of each of its results.
  EXPECT_EQ(reLU->getNumUsers(), 1);

  // The NodeValue should match.
  EXPECT_EQ(values.getNumUsers(), 1);
  EXPECT_TRUE(values.hasOneUse());
  EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));

  // Add one more use.
  savesOfValues.insert(F->createSave("saveValues2", values));

  // The whole node should inherit the uses of each of its results.
  EXPECT_EQ(reLU->getNumUsers(), 2);

  EXPECT_EQ(values.getNumUsers(), 2);
  EXPECT_FALSE(values.hasOneUse());
  EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));

  // Add a couple more users.
  savesOfValues.insert(F->createSave("saveValues3", values));
  savesOfValues.insert(F->createSave("saveValues4", values));

  EXPECT_EQ(reLU->getNumUsers(), 4);

  EXPECT_EQ(values.getNumUsers(), 4);
  EXPECT_FALSE(values.hasOneUse());
  EXPECT_TRUE(hasAllTheseUses(savesOfValues, values));
}
1445 
1446 /// Verify that the pre-order visitor works correctly.
TEST(Graph, PreOrderTest) {
  Module M;
  PlaceholderBindings bindings;
  auto *F = M.createFunction("main");

  auto *input1 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input1", true);
  auto *input2 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input2", true);
  SplatNode *zero = F->createSplat("zero", input1->getType(), 0.);
  MulNode *mul1 = F->createMul("mul1", zero, input1);
  MulNode *mul2 = F->createMul("mul2", zero, input2);
  MulNode *mul3 = F->createMul("mul3", mul1, mul2);
  SaveNode *ret1 = F->createSave("ret1", mul3);

  SplatNode *one = F->createSplat("one", input2->getType(), 1.0);
  AddNode *add1 = F->createAdd("add1", input2, one);
  AddNode *add2 = F->createAdd("add2", add1, one);
  AddNode *add3 = F->createAdd("add3", add2, one);
  // Note: ret2 saves add2; add3 has no users and is visited as a root
  // (see order[8] below).
  SaveNode *ret2 = F->createSave("ret2", add2);

  GraphPreOrderVisitor visitor(*F);
  auto order = visitor.getPreOrder();

  // Pre-order: each node appears before its operands; shared nodes (zero,
  // one, the save placeholders) appear only once.
  ASSERT_EQ(order.size(), 14);
  EXPECT_EQ(order[0], ret1);
  EXPECT_EQ(order[1], mul3);
  EXPECT_EQ(order[2], mul1);
  EXPECT_EQ(order[3], zero);
  EXPECT_EQ(order[4], input1);
  EXPECT_EQ(order[5], mul2);
  EXPECT_EQ(order[6], input2);
  EXPECT_EQ(order[7], ret1->getOutput());
  EXPECT_EQ(order[8], add3);
  EXPECT_EQ(order[9], add2);
  EXPECT_EQ(order[10], add1);
  EXPECT_EQ(order[11], one);
  EXPECT_EQ(order[12], ret2);
  EXPECT_EQ(order[13], ret2->getOutput());
}
1487 
1488 /// Verify that the post-order visitor works correctly.
TEST(Graph, PostOrderTest) {
  Module M;
  PlaceholderBindings bindings;
  auto *F = M.createFunction("main");

  auto *input1 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input1", true);
  auto *input2 =
      M.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input2", true);
  SplatNode *zero = F->createSplat("zero", input1->getType(), 0.);
  MulNode *mul1 = F->createMul("mul1", zero, input1);
  MulNode *mul2 = F->createMul("mul2", zero, input2);
  MulNode *mul3 = F->createMul("mul3", mul1, mul2);
  SaveNode *ret1 = F->createSave("ret1", mul3);

  SplatNode *one = F->createSplat("one", input2->getType(), 1.0);
  AddNode *add1 = F->createAdd("add1", input2, one);
  AddNode *add2 = F->createAdd("add2", add1, one);
  AddNode *add3 = F->createAdd("add3", add2, one);
  // Note: ret2 saves add2; add3 has no users and is visited as a root
  // (see order[11] below).
  SaveNode *ret2 = F->createSave("ret2", add2);

  GraphPostOrderVisitor visitor(*F);
  auto order = visitor.getPostOrder();

  // Post-order: each node appears after its operands; shared nodes appear
  // only once.
  ASSERT_EQ(order.size(), 14);
  EXPECT_EQ(order[0], zero);
  EXPECT_EQ(order[1], input1);
  EXPECT_EQ(order[2], mul1);
  EXPECT_EQ(order[3], input2);
  EXPECT_EQ(order[4], mul2);
  EXPECT_EQ(order[5], mul3);
  EXPECT_EQ(order[6], ret1->getOutput());
  EXPECT_EQ(order[7], ret1);
  EXPECT_EQ(order[8], one);
  EXPECT_EQ(order[9], add1);
  EXPECT_EQ(order[10], add2);
  EXPECT_EQ(order[11], add3);
  EXPECT_EQ(order[12], ret2->getOutput());
  EXPECT_EQ(order[13], ret2);
}
1529 
/// Build a small FC -> ReLU -> SoftMax graph on top of non-trainable
/// placeholders; the test passes if construction succeeds.
TEST(Graph, placeholder) {
  Module mod;
  PlaceholderBindings ctx;
  Function *F = mod.createFunction("F");
  IRFunction irFunc(F);
  Node *in =
      mod.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", false);
  Node *sel = mod.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", false);

  Node *fc = F->createFullyConnected(ctx, "FC", in, 10);
  Node *relu = F->createRELU("Relu", fc);
  Node *sm = F->createSoftMax("SoftMax", relu, sel);
  F->createSave("Save", sm);
}
1544 
1545 /// Check that the setType API allows to change the type of the
1546 /// related result and only the related result.
TEST(Graph, setType) {
  Module M;
  auto *F = M.createFunction("main");

  const dim_t inputDims[] = {4, 10};
  const dim_t top5Dims[] = {4, 5};
  auto *input =
      M.createPlaceholder(ElemKind::FloatTy, inputDims, "input", true);
  // TopK produces two results: Values (float) and Indices (int64).
  TopKNode *topK = F->createTopK("add", input, 5);
  TypeRef origTopKRes0 = M.uniqueType(ElemKind::FloatTy, top5Dims);
  TypeRef origTopKRes1 = M.uniqueType(ElemKind::Int64ITy, top5Dims);

  EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), origTopKRes0);
  EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), origTopKRes1);

  // Modify the type of result 0 and make sure type 1 is not
  // affected. Similarly the input shouldn't be affected.
  TypeRef inputTy = M.uniqueType(ElemKind::FloatTy, inputDims);
  TypeRef topKRes0 = M.uniqueType(ElemKind::Float16Ty, top5Dims);
  topK->setType(TopKNode::ValuesIdx, topKRes0);
  EXPECT_EQ(input->getType(), inputTy);
  EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), topKRes0);
  EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), origTopKRes1);

  // Make sure the NodeValue API works the same way
  // as the Node::setType API.
  NodeValue valRes1 = topK->getNthResult(TopKNode::IndicesIdx);
  valRes1.setType(topKRes0);
  EXPECT_EQ(input->getType(), inputTy);
  EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), topKRes0);
  EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), topKRes0);
  EXPECT_EQ(valRes1.getType(), topKRes0);

  // Now restore sane types.
  NodeValue valRes0 = topK->getNthResult(TopKNode::ValuesIdx);
  valRes0.setType(origTopKRes0);
  topK->setType(TopKNode::IndicesIdx, origTopKRes1);
  EXPECT_EQ(input->getType(), inputTy);
  EXPECT_EQ(topK->getType(TopKNode::ValuesIdx), origTopKRes0);
  EXPECT_EQ(valRes0.getType(), origTopKRes0);
  EXPECT_EQ(topK->getType(TopKNode::IndicesIdx), origTopKRes1);
  EXPECT_EQ(valRes1.getType(), origTopKRes1);
}
1590 
/// Check that we fixed the bug with Function::eraseNode. This method used to
/// erase a node that was equal to the node we wanted to delete, which may be
/// two different entities.
/// To see this bug in action, we create a bunch of nodes with the same value.
/// Then we erase them in reverse order. This reverse ordering was actually
/// freeing the nodes in the original order, thus at some point we would try to
/// delete a node that had already been deleted, and an assert (debug mode) or
/// segmentation fault (release mode) would occur.
/// Note: Which node is actually freed depends on the implementation of
/// std::find, thus we cannot really predict when the bug occurs.
TEST(Graph, eraseNodeBug) {
  Module mod;
  auto *func = mod.createFunction("main");

  auto *in = mod.createPlaceholder(ElemKind::FloatTy, {3, 2}, "input", true);

  // Build several identical ReLU nodes all fed by the same input.
  std::vector<Node *> relus;
  relus.reserve(5);
  while (relus.size() < 5) {
    relus.push_back(func->createRELU("relu", in));
  }

  // Erase the nodes in reverse creation order; this used to trip the bug.
  for (auto it = relus.rbegin(), e = relus.rend(); it != e; ++it) {
    func->eraseNode(*it);
  }
  EXPECT_EQ(func->getNodes().size(), 0);
}
1617 
1618 /// Verify that two Nodes with different predicates but the same inputs are not
1619 /// considered equal.
TEST(Graph, nodeEqualityWithDifferentPredicates) {
  Module mod;
  auto *func = mod.createFunction("main");

  Node *value = mod.createPlaceholder(ElemKind::FloatTy, {5}, "in", false);
  Node *predA = mod.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
  Node *predB = mod.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);

  // Two ReLUs over the same input, each guarded by a different predicate.
  Node *reluA = func->createRELU("relu1", value);
  reluA->setPredicate(predA);
  Node *reluB = func->createRELU("relu2", value);
  reluB->setPredicate(predB);

  // Equal inputs but distinct predicates => the nodes must not compare equal.
  EXPECT_FALSE(reluA->isEqual(*reluB));
}
1636 
1637 /// Check that verify doesn't allow for multiple writers to the same node.
TEST(Graph, verifyOneWriter) {
  Module mod;
  auto *func = mod.createFunction("main");

  auto *src = mod.createPlaceholder(ElemKind::FloatTy, {5}, "input", false);
  auto *dst = mod.createPlaceholder(ElemKind::FloatTy, {5}, "output", false);

  // Two saves targeting the same placeholder: the module must fail to verify.
  func->createSave("Save1", src, dst);
  func->createSave("Save2", src, dst);

  EXPECT_FALSE(mod.verify());
}
1649 
1650 /// Check that verify doesn't allow for Constants to be written to. Note that
1651 /// createSave() cannot do this as the API only accepts Placeholders to write
1652 /// to, however it could happen during graph transformations, e.g. via
1653 /// replaceAllUsesOfWith() as shown here.
TEST(Graph, verifyConstantNoWriters) {
  Module mod;
  auto *func = mod.createFunction("main");

  auto *src = mod.createPlaceholder(ElemKind::FloatTy, {5}, "input", false);
  auto *phOut = mod.createPlaceholder(ElemKind::FloatTy, {5}, "outPH", false);
  func->createSave("save", src, phOut);

  // Redirect the save destination from the Placeholder to a Constant via
  // replaceAllUsesOfWith; a written-to Constant must fail verification.
  auto *constOut = mod.createConstant(ElemKind::FloatTy, {5}, "outC");
  NodeValue(phOut).replaceAllUsesOfWith(constOut);

  EXPECT_FALSE(mod.verify());
}
1669 
TEST(Graph, typeUnsafeReplaceAllUsesOfWith) {
  Module mod;
  auto *func = mod.createFunction("main");

  auto *lhs = mod.createPlaceholder(ElemKind::FloatTy, {3, 4}, "A", false);
  auto *rhs = mod.createPlaceholder(ElemKind::FloatTy, {4, 5}, "B", false);
  auto *matMul = func->createMatMul("fc", lhs, rhs);
  func->createSave("save", matMul);

  // The type-unsafe variant accepts a replacement whose shape differs from
  // the node it replaces.
  auto *wider = mod.createPlaceholder(ElemKind::FloatTy, {10, 10}, "A", false);
  lhs->getOutput().typeUnsafeReplaceAllUsesOfWith(wider);
}
1682 
1683 /// Check that the verifier will complain if a constant and its
1684 /// underlying tensor have mismatching types.
1685 /// Here the constant is updated but not the tensor.
TEST(Graph, verifyConstantTensorTypeMatchesConstantTypeChanged) {
  Module mod;

  auto *c = mod.createConstant(ElemKind::FloatTy, {5}, "input");
  // A freshly created constant is internally consistent.
  EXPECT_TRUE(c->verify());

  // Change the Constant's declared type while leaving the payload tensor
  // untouched; the mismatch must be flagged.
  c->setType(Storage::OutputIdx, mod.uniqueType(ElemKind::Float16Ty, {5}));
  EXPECT_FALSE(c->verify());
}
1697 
1698 /// Check that the verifier will complain if a constant and its
1699 /// underlying tensor have mismatching types.
1700 /// Here the tensor is updated but not the constant.
TEST(Graph, verifyConstantTensorTypeMatchesTensorTypeChanged) {
  Module mod;

  auto *c = mod.createConstant(ElemKind::FloatTy, {5}, "input");
  // A freshly created constant is internally consistent.
  EXPECT_TRUE(c->verify());

  // Convert the payload tensor without updating the Constant's declared
  // type; the mismatch must be flagged.
  c->getPayloadMutable().convertToType(ElemKind::Float16Ty);
  EXPECT_FALSE(c->verify());
}
1711 
1712 /// Check that Constants backed by unowned Tensors are in fact unowned until
1713 /// a mutable reference to their payload is obtained at which point the backing
1714 /// Tensor is copied and becomes owned.
TEST(Graph, verifyConstantWithUnownedTensorCopiesOnWrite) {
  Module M;

  // originalT owns its storage; unownedT is a view aliasing that storage.
  Tensor originalT(ElemKind::FloatTy, {3});
  Tensor unownedT = originalT.getUnowned({3});

  auto originalH = originalT.getHandle();

  // Fill the original tensor with 0, 1, 2.
  for (size_t i = 0; i < originalT.size(); i++) {
    originalH.raw(i) = i;
  }

  // Both Tensors should have the same underlying memory because unownedT shares
  // originalT's memory.
  EXPECT_EQ(originalT.getUnsafePtr(), unownedT.getUnsafePtr());

  Constant *originalC = M.createConstant("original", std::move(originalT));
  Constant *unownedC = M.createConstant("unowned", std::move(unownedT));

  // Read-only payload access must not trigger a copy of the unowned tensor.
  const Tensor &originalCT = originalC->getPayload();
  const Tensor &unownedCT = unownedC->getPayload();

  const auto originalCTH = originalCT.getHandle();
  const auto unownedCTH = unownedCT.getHandle();

  ASSERT_EQ(originalCTH.size(), unownedCTH.size());

  // Both Constants should have the same values because their Tensors have the
  // same underlying memory.
  for (size_t i = 0; i < originalCTH.size(); i++) {
    EXPECT_EQ(i, originalCTH.raw(i));
    EXPECT_EQ(i, unownedCTH.raw(i));
  }

  Tensor &originalCTM = originalC->getPayloadMutable();
  auto originalCTMH = originalCTM.getHandle();

  // Bump up the value in the original Constant, this change should be
  // reflected in the unowned Constant as well.
  for (size_t i = 0; i < originalCTMH.size(); i++) {
    originalCTMH.raw(i) += 1;
  }

  // After changing the values in the original Constant, we should see an update
  // in the values of the unowned Constant because they share the same
  // underlying memory.
  for (size_t i = 0; i < unownedCTH.size(); i++) {
    EXPECT_EQ(unownedCTH.raw(i), i + 1);
  }

  // Mutable access to the unowned payload is the copy-on-write trigger: from
  // here on, unownedC owns a private copy of the data.
  Tensor &unownedCTM = unownedC->getPayloadMutable();
  auto unownedCTMH = unownedCTM.getHandle();

  ASSERT_EQ(originalCTH.size(), unownedCTMH.size());

  // After getting a mutable reference to the unowned Constant's payload, the
  // underlying memory should have been copied but should still contain the same
  // values as it did previously at this point.
  EXPECT_NE(unownedCTM.getUnsafePtr(), originalCT.getUnsafePtr());
  for (size_t i = 0; i < unownedCTMH.size(); i++) {
    EXPECT_EQ(unownedCTMH.raw(i), i + 1);
  }

  // Bump up the value in the original Constant again, this change should not be
  // reflected in the unowned Constant now because at this point, after a
  // mutable reference to its payload has been obtained, it should have its own
  // memory.
  for (size_t i = 0; i < originalCTMH.size(); i++) {
    originalCTMH.raw(i) += 1;
  }

  // Now that the unowned Constant's payload has been obtained as mutable, it
  // should have been copied and thus have its own memory and changes to the
  // original constant should not be reflected in the unowned Constant.
  for (size_t i = 0; i < unownedCTMH.size(); i++) {
    EXPECT_EQ(unownedCTMH.raw(i), i + 1);
  }
}
1793 
1794 /// Check that hooking an intermediate node works.
TEST(Graph, hookTest) {
  Module mod;
  auto *func = mod.createFunction("main");
  auto *input = mod.createPlaceholder(ElemKind::FloatTy, {1}, "in", false);
  auto *firstRelu = func->createRELU("relu1", input);
  auto *secondRelu = func->createRELU("relu2", firstRelu);
  func->createSave("save", secondRelu);
  EXPECT_EQ(func->getNodes().size(), 3);
  EXPECT_EQ(mod.getPlaceholders().size(), 2);

  // Hook the first relu; the hooked function should contain a copy of the
  // chain up to relu1 plus a Save into a newly created placeholder.
  auto hooked = glow::hookNode(func, firstRelu);
  auto const &hookedNodes = hooked.function->getNodes();
  ASSERT_EQ(mod.getPlaceholders().size(), 3);
  ASSERT_EQ(hookedNodes.size(), 2);

  // Walk backwards from the hook's save: Save <- Relu <- original input.
  auto const *hookSave = *hooked.outputSaves.begin();
  ASSERT_TRUE(hookSave);
  auto *savedRelu = llvm::dyn_cast<ReluNode>(hookSave->getInput());
  ASSERT_TRUE(savedRelu);
  auto *reluSource = llvm::dyn_cast<Placeholder>(savedRelu->getInput());
  ASSERT_TRUE(reluSource);
  ASSERT_EQ(reluSource, input);
}
1818 
1819 /// Check that getConstantsSize returns the correct size of constants.
TEST(Graph, moduleSize) {
  Module mod;

  // An empty module owns no constant storage.
  EXPECT_EQ(mod.getConstantsSize(), 0);

  auto *cons1 = mod.createConstant(ElemKind::FloatTy, {1}, "var");
  EXPECT_EQ(mod.getConstantsSize(), sizeof(float) * cons1->getPayload().size());

  auto *cons2 = mod.createConstant(ElemKind::FloatTy, {1, 32, 32, 16}, "var2");
  // The total is the sum of both payload sizes. Computing cons1's
  // contribution from its payload (rather than a bare sizeof(float)) keeps
  // the expectation correct even if the constant's shape changes.
  EXPECT_EQ(mod.getConstantsSize(),
            sizeof(float) * cons1->getPayload().size() +
                sizeof(float) * cons2->getPayload().size());
}
1832 
1833 /// Check that getDataSize() returns the correct size of backing tensors.
TEST(Graph, contextSize) {
  Module mod;
  PlaceholderBindings bindings;

  Placeholder *input =
      mod.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);

  // Nothing allocated yet, so the bindings hold no data.
  EXPECT_EQ(bindings.getDataSize(), 0);

  // Allocating the placeholder creates a float tensor of the full shape.
  bindings.allocate(input);
  Tensor *backing = bindings.get(input);
  EXPECT_EQ(backing->size(), 4 * 320 * 200 * 3);
  EXPECT_EQ(bindings.getDataSize(), sizeof(float) * backing->size());
}
1846 
1847 /// Check that clones of the context are distinct and share no references back
1848 /// to the original object.
TEST(Graph, clonePlaceholderBindings) {
  Module mod;

  Placeholder *PH1 =
      mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "PH1", false);

  PlaceholderBindings original;
  original.allocate(PH1);

  PlaceholderBindings cloned = original.clone();

  // Both bindings map PH1, but to distinct backing tensors.
  Tensor *origT = original.get(PH1);
  Tensor *clonedT = cloned.get(PH1);
  EXPECT_NE(origT, nullptr);
  EXPECT_NE(clonedT, nullptr);
  EXPECT_NE(origT, clonedT);

  // The clone holds no references back to the original: an allocation made
  // in the clone is invisible to the original.
  Placeholder *PH2 =
      mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "PH2", false);
  cloned.allocate(PH2);
  EXPECT_EQ(original.get(PH2), nullptr);
  EXPECT_NE(cloned.get(PH2), nullptr);

  // Likewise, clearing the original leaves the clone untouched.
  original.clear();
  EXPECT_EQ(original.count(PH1), 0);
  EXPECT_EQ(cloned.count(PH1), 1);

  // Independent allocations of the same placeholder are distinct tensors.
  Placeholder *PH3 =
      mod.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4}, "PH3", false);
  original.allocate(PH3);
  cloned.allocate(PH3);
  EXPECT_NE(original.get(PH3), nullptr);
  EXPECT_NE(cloned.get(PH3), nullptr);
  EXPECT_NE(original.get(PH3), cloned.get(PH3));
}
1891 
1892 /// Check that running a function multiple times on cloned PlaceholderBindingss
1893 /// have distinct outputs.
TEST(Graph, clonePlaceholderBindingsRuns) {
  ExecutionEngine EE;
  PseudoRNG PRNG;

  Tensor inputs(ElemKind::FloatTy, {1, 32, 32, 3});
  auto &mod = EE.getModule();
  Function *F = mod.createFunction("main");
  PlaceholderBindings bindings;
  auto *input =
      mod.createPlaceholder(ElemKind::FloatTy, {1, 32, 32, 3}, "input", true);

  bindings.allocate(input);

  // Small FC + ReLU network whose result is saved into savePH.
  auto *FCL1 = F->createFullyConnected(bindings, "fc", input, 10);
  auto *RL3 = F->createRELU("relu4", FCL1);
  auto *save = F->createSave("ret", RL3);
  auto *savePH = save->getPlaceholder();

  bindings.allocate(save->getPlaceholder());

  // Compile once.
  EE.compile(CompilationMode::Infer);

  // Run with random inputs.
  inputs.getHandle<>().randomize(-3.0, 3.0, PRNG);
  updateInputPlaceholders(bindings, {input}, {&inputs});
  EE.run(bindings);

  // Clone the context.
  PlaceholderBindings bindings2 = bindings.clone();

  // PlaceholderBindingss are identical: the clone's output tensor is a
  // distinct object holding the same contents.
  Tensor *saveBacking1, *saveBacking2;
  saveBacking1 = bindings.get(savePH);
  saveBacking2 = bindings2.get(savePH);
  EXPECT_NE(saveBacking1, saveBacking2);
  EXPECT_EQ(saveBacking1->size(), saveBacking2->size());
  EXPECT_TRUE(saveBacking1->isEqual(*saveBacking2));

  // Run again with different random inputs using the cloned context.
  Tensor inputs2(ElemKind::FloatTy, {1, 32, 32, 3});
  inputs2.getHandle<>().randomize(-3.0, 3.0, PRNG);
  updateInputPlaceholders(bindings2, {input}, {&inputs2});
  EE.run(bindings2);

  // PlaceholderBindingss are no longer identical: only the clone's output
  // tensor was overwritten by the second run.
  EXPECT_EQ(saveBacking1->size(), saveBacking2->size());
  EXPECT_FALSE(saveBacking1->isEqual(*saveBacking2));
}
1943 
1944 /// Check that using the indices enums in nodes works correctly, with
1945 /// multi-input, multi-output, and single-input/output nodes.
TEST(Graph, TestNodeEnums) {
  Module MD;
  Function *F = MD.createFunction("F");
  PlaceholderBindings bindings;
  Placeholder *I =
      MD.createPlaceholder(ElemKind::FloatTy, {10, 10}, "input", true);
  Placeholder *O = MD.createPlaceholder(ElemKind::FloatTy, {3}, "output", true);

  // Chain: input -> TopK -> Gather -> Tanh -> Save(output).
  TopKNode *TKN = F->createTopK("topk", I, 3);
  GatherNode *GN =
      F->createGather("gather", TKN->getValues(), TKN->getIndices());
  TanhNode *TN = F->createTanh("tanh", GN);
  SaveNode *SN = F->createSave("save", TN, O);

  // Check structure of Placeholders.
  EXPECT_EQ(I->getNthResult(Storage::OutputIdx), I->getOutput());
  EXPECT_EQ(O->getNthResult(Storage::OutputIdx), O->getOutput());

  // Check structure of TopK (multi-output node).
  EXPECT_EQ(TKN->getInput(), TKN->getNthInput(TopKNode::InputIdx));
  EXPECT_EQ(TKN->getNthResult(TopKNode::ValuesIdx), TKN->getValues());
  EXPECT_EQ(TKN->getNthResult(TopKNode::IndicesIdx), TKN->getIndices());

  // Check structure of Gather (multi-input node).
  EXPECT_EQ(GN->getNthInput(GatherNode::DataIdx), GN->getData());
  EXPECT_EQ(GN->getNthInput(GatherNode::IndicesIdx), GN->getIndices());
  EXPECT_EQ(GN->getNthResult(GatherNode::ResultIdx), GN->getResult());

  // Check structure of Tanh (single-input/output node).
  EXPECT_EQ(TN->getNthInput(TanhNode::InputIdx), TN->getInput());
  EXPECT_EQ(TN->getNthResult(TanhNode::ResultIdx), TN->getResult());

  // Check structure of Save.
  EXPECT_EQ(SN->getNthInput(SaveNode::InputIdx), SN->getInput());
  EXPECT_EQ(SN->getNthInput(SaveNode::OutputIdx), SN->getOutput());

  // Check connection between Placeholder and TopK.
  EXPECT_EQ(TKN->getNthInput(TopKNode::InputIdx), I->getOutput());

  // Check connections between TopK and Gather.
  EXPECT_EQ(TKN->getNthResult(TopKNode::ValuesIdx),
            GN->getNthInput(GatherNode::DataIdx));
  EXPECT_EQ(TKN->getNthResult(TopKNode::IndicesIdx),
            GN->getNthInput(GatherNode::IndicesIdx));

  // Check connection between Gather and Tanh.
  EXPECT_EQ(GN->getNthResult(GatherNode::ResultIdx),
            TN->getNthInput(TanhNode::InputIdx));

  // Check connection between Tanh and Save.
  EXPECT_EQ(TN->getNthResult(TanhNode::ResultIdx),
            SN->getNthInput(SaveNode::InputIdx));

  // Check connection between Save and the output Placeholder.
  EXPECT_EQ(SN->getNthInput(SaveNode::OutputIdx), O->getOutput());
}
2002 
/// Searches \p F for a single instance of a node of Kind T. If more than one
/// is found, \returns nullptr; otherwise returns the single instance.
findSingleInstanceOfNode(Function * F)2005 template <class T> static T *findSingleInstanceOfNode(Function *F) {
2006   T *found = nullptr;
2007   for (auto &n : F->getNodes()) {
2008     if (auto *currNode = llvm::dyn_cast<T>(&n)) {
2009       if (found != nullptr) {
2010         return nullptr;
2011       }
2012       found = currNode;
2013     }
2014   }
2015   return found;
2016 }
2017 
2018 /// Check that group Conv is not lowered when specified to lower by backend if
2019 /// doNotLowerKinds contains Conv.
TEST(Graph, GroupTestConvNoLower) {
  Module MD;
  Function *F = MD.createFunction("F");
  IRFunction M(F);
  PlaceholderBindings bindings;
  Node *K =
      MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 8}, "input", true);
  Node *S = MD.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "select", true);

  // Build Conv(group=8) -> ReLU -> SoftMax -> Save.
  K = F->createConv(bindings, "Conv1", K, 16, 3, 2, 3, /* group */ 8);
  K = F->createRELU("Relu", K);
  K = F->createSoftMax("SoftMax", K, S);
  F->createSave("Save", K);
  F->dump();
  auto filePath = F->dumpDAG();
  auto backend = MockBackend();

  // Finds the single Conv node in F; on failure, removes the dumped DAG file
  // so the test doesn't leak it. Shared by the pre- and post-lowering checks.
  auto findConvOrCleanup = [&]() -> ConvolutionNode * {
    ConvolutionNode *CN = findSingleInstanceOfNode<ConvolutionNode>(F);
    if (!CN) {
      llvm::sys::fs::remove(filePath);
    }
    return CN;
  };

  {
    // Before we lower, we should have a single Conv node with group = 8.
    ConvolutionNode *CN = findConvOrCleanup();
    ASSERT_TRUE(CN);
    EXPECT_EQ(CN->getGroup(), 8);
  }

  // Now lower, but prevent ConvolutionNodeKinds from being lowered.
  KindSet doNotLower;
  doNotLower.insert(Kinded::Kind::ConvolutionNodeKind);
  CompilationContext cctx;
  lower(F, cctx, &backend, doNotLower);

  {
    // Now have lowered but should still have a single Conv node with group = 8.
    ConvolutionNode *CN = findConvOrCleanup();
    ASSERT_TRUE(CN);
    EXPECT_EQ(CN->getGroup(), 8);
  }
}
2063 
2064 /// Check that getOutputSave returns SaveNode object for the correct Placeholder
2065 /// and nullptr in other cases.
TEST(Graph, GetOutputSaveTest) {
  Module MD;
  Function *F = MD.createFunction("F");
  PlaceholderBindings bindings;
  Placeholder *I =
      MD.createPlaceholder(ElemKind::FloatTy, {10, 10}, "input", true);
  Placeholder *O = MD.createPlaceholder(ElemKind::FloatTy, {3}, "output", true);
  // Chain: input -> TopK -> Gather -> Tanh -> Save(O).
  TopKNode *TKN = F->createTopK("topk", I, 3);
  GatherNode *GN =
      F->createGather("gather", TKN->getValues(), TKN->getIndices());
  TanhNode *TN = F->createTanh("tanh", GN);
  SaveNode *SN = F->createSave("save", TN, O);

  // Check the return value of getOutputSave method.
  // Placeholder parent is null: lookup still finds the save in F.
  auto *FoundNode = glow::getOutputSave(F, O);
  EXPECT_NE(nullptr, FoundNode);
  EXPECT_EQ(SN, FoundNode);

  // Placeholder parent is set to the correct value: same result.
  O->setParent(F);
  EXPECT_EQ(F, O->getParent());
  FoundNode = glow::getOutputSave(F, O);
  EXPECT_NE(nullptr, FoundNode);
  EXPECT_EQ(SN, FoundNode);

  // Invalid placeholder type is provided (I is an input, never saved to).
  EXPECT_EQ(nullptr, glow::getOutputSave(F, I));

  // Save belongs to a different function: build a parallel chain in F2 that
  // also saves into O.
  Function *F2 = MD.createFunction("F2");
  TopKNode *TKN2 = F2->createTopK("topk", I, 3);
  GatherNode *GN2 =
      F2->createGather("gather", TKN2->getValues(), TKN2->getIndices());
  TanhNode *TN2 = F2->createTanh("tanh", GN2);
  SaveNode *SN2 = F2->createSave("save", TN2, O);

  // Looking up in F still returns F's save even though O's parent is F.
  FoundNode = glow::getOutputSave(F, O);
  EXPECT_NE(nullptr, FoundNode);
  EXPECT_EQ(SN, FoundNode);

  // After re-parenting O to F2, lookup in F2 returns F2's save.
  O->setParent(F2);
  FoundNode = glow::getOutputSave(F2, O);
  EXPECT_NE(nullptr, FoundNode);
  EXPECT_EQ(SN2, FoundNode);
}
2112 
2113 /// Check if dump functions work for Node, Function and Module.
TEST(Graph, testDumpStructure) {
  Module MD;
  Function *F = MD.createFunction("F");
  IRFunction M(F);
  PlaceholderBindings bindings;
  Node *K = MD.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 100, 3},
                                 "input", true);
  // Test Node: dump(os), toString(), and operator<< must all produce the
  // same golden text.
  std::string storageN1;
  llvm::raw_string_ostream osN1(storageN1);
  K->dump(osN1);
  std::string mesN = K->toString();
  std::string expectMes = R"(Placeholder
name : "input"
layout : *
output : float<4 x 320 x 200 x 100 x 3>
trainable : 1
static : 0
users : 0
)";
  EXPECT_EQ(mesN, expectMes);
  EXPECT_EQ(mesN, osN1.str());
  std::string storageN2;
  llvm::raw_string_ostream osN2(storageN2);
  osN2 << K;
  EXPECT_EQ(mesN, osN2.str());
  // Test Function: a second placeholder (auto-renamed "input__1") feeding a
  // TopK inside a fresh function F2.
  Placeholder *I =
      MD.createPlaceholder(ElemKind::FloatTy, {10, 10}, "input", true);
  I->setStatic(true);
  Function *F2 = MD.createFunction("F2");
  F2->createTopK("topk", I, 3);
  std::string storageF1;
  llvm::raw_string_ostream osF1(storageF1);
  F2->dump(osF1);
  std::string mesF = F2->toString();
  std::string expectMesF = R"(Graph structure F2:
TopK
name : topk
Input : float<10 x 10>
K : 3
users : 0
Values : float<10 x 3>
Indices : index64<10 x 3>
Placeholder
name : "input__1"
layout : *
output : float<10 x 10>
trainable : 1
static : 1
users : 1
)";
  EXPECT_EQ(mesF, expectMesF);
  EXPECT_EQ(mesF, osF1.str());
  std::string storageF2;
  llvm::raw_string_ostream osF2(storageF2);
  osF2 << F2;
  EXPECT_EQ(mesF, osF2.str());
  // With skipUsersForStorage the "users : N" line is omitted for storage
  // nodes (the Placeholder), but kept for regular nodes (the TopK).
  storageF1.clear();
  F2->dump(osF1, /* skipUsersForStorage */ true);
  mesF = F2->toString(/* skipUsersForStorage */ true);
  expectMesF = R"(Graph structure F2:
TopK
name : topk
Input : float<10 x 10>
K : 3
users : 0
Values : float<10 x 3>
Indices : index64<10 x 3>
Placeholder
name : "input__1"
layout : *
output : float<10 x 10>
trainable : 1
static : 1
)";
  EXPECT_EQ(mesF, expectMesF);
  EXPECT_EQ(mesF, osF1.str());
  // Test Module: constants, placeholders, and function names are all listed.
  MD.createConstant(ElemKind::FloatTy, {1, 1}, "dummy");
  std::string storageM1;
  llvm::raw_string_ostream osM1(storageM1);
  MD.dump(osM1);
  std::string mesM = MD.toString();
  std::string expectMesM = R"(Module structure:
Constant
name : "dummy"
layout : *
output : float<1 x 1>
users : 0

Placeholder
name : "input__1"
layout : *
output : float<10 x 10>
trainable : 1
static : 1
users : 1

Placeholder
name : "input"
layout : *
output : float<4 x 320 x 200 x 100 x 3>
trainable : 1
static : 0
users : 0

Function : F2
Function : F
)";
  EXPECT_EQ(mesM, expectMesM);
  EXPECT_EQ(mesM, osM1.str());
  std::string storageM2;
  llvm::raw_string_ostream osM2(storageM2);
  osM2 << MD;
  EXPECT_EQ(mesM, osM2.str());
}
2231 
2232 /// Initialize tensor payload for testing purposes. The value at index i is set
2233 /// to i.
initTensor(Tensor & T)2234 template <typename ElemTy> static void initTensor(Tensor &T) {
2235   Handle<ElemTy> handle = T.getHandle<ElemTy>();
2236   float val = 0;
2237   for (auto &elem : handle) {
2238     elem = val;
2239     val += 1.0;
2240   }
2241 }
2242 
2243 // Test that randomizing Constants in a Function works.
TEST(Graph, testRandomizeConstants) {
  Module MD;
  Function *F = MD.createFunction("F");

  // Create tensors to be used in Constants, one per supported element kind,
  // each deterministically initialized so randomization is detectable.
  Tensor floatT(ElemKind::FloatTy, {10});
  initTensor<float>(floatT);

  Tensor halfT(ElemKind::Float16Ty, {10});
  initTensor<float16_t>(halfT);

  Tensor bfloat16T(ElemKind::BFloat16Ty, {10});
  initTensor<bfloat16_t>(bfloat16T);

  Tensor int8QT(ElemKind::Int8QTy, {10}, 1.0, 0);
  initTensor<int8_t>(int8QT);

  Tensor uint8QT(ElemKind::UInt8QTy, {10}, 1.0, 0);
  initTensor<uint8_t>(uint8QT);

  Tensor int16QT(ElemKind::Int16QTy, {10}, 1.0, 0);
  initTensor<int16_t>(int16QT);

  Tensor int32QT(ElemKind::Int32QTy, {10}, 1.0, 0);
  initTensor<int32_t>(int32QT);

  Tensor int32IT(ElemKind::Int32ITy, {10});
  initTensor<int32_t>(int32IT);

  Tensor int64IT(ElemKind::Int64ITy, {10});
  initTensor<int64_t>(int64IT);

  // Fused-quantized tensors store scale/offset per row, hence the 2D shape.
  Tensor uint8FusedQT(ElemKind::UInt8FusedQTy, {16, 16}, 1.0, 0);
  initTensor<uint8_t>(uint8FusedQT);

  Tensor uint8FusedFP16QT(ElemKind::UInt8FusedFP16QTy, {16, 16}, 1.0, 0);
  initTensor<uint8_t>(uint8FusedFP16QT);

  Tensor uint4FusedFP16QT(ElemKind::UInt4FusedFP16QTy, {16, 16}, 1.0, 0);
  initTensor<uint8_t>(uint4FusedFP16QT);

  Tensor boolT(ElemKind::BoolTy, {10});
  initTensor<bool>(boolT);

  // Create Constants and use them in F; only Constants with users in F are
  // subject to randomizeConstants().
  auto *floatC = MD.createConstant("floatC", floatT);
  F->createAdd("add", floatC, floatC);

  auto *halfC = MD.createConstant("halfC", halfT);
  F->createAdd("add", halfC, halfC);

  auto *bfloat16C = MD.createConstant("bloat16C", bfloat16T);
  F->createAdd("add", bfloat16C, bfloat16C);

  auto *int8QC = MD.createConstant("int8QC", int8QT);
  F->createAdd("add", int8QC, int8QC);

  auto *uint8QC = MD.createConstant("uint8QC", uint8QT);
  F->createAdd("add", uint8QC, uint8QC);

  auto *int16QC = MD.createConstant("int16QC", int16QT);
  F->createAdd("add", int16QC, int16QC);

  auto *int32QC = MD.createConstant("int32QC", int32QT);
  F->createAdd("add", int32QC, int32QC);

  auto *int32IC = MD.createConstant("int32IC", int32IT);
  F->createAdd("add", int32IC, int32IC);

  auto *int64IC = MD.createConstant("int64IC", int64IT);
  F->createAdd("add", int64IC, int64IC);

  auto *uint8FusedQC = MD.createConstant("uint8FusedQC", uint8FusedQT);
  F->createAdd("add", uint8FusedQC, uint8FusedQC);

  auto *uint8FusedFP16QC =
      MD.createConstant("uint8FusedFP16QC", uint8FusedFP16QT);
  F->createAdd("add", uint8FusedFP16QC, uint8FusedFP16QC);

  auto *uint4FusedFP16QC =
      MD.createConstant("uint4FusedFP16QC", uint4FusedFP16QT);
  F->createAdd("add", uint4FusedFP16QC, uint4FusedFP16QC);

  auto *boolC = MD.createConstant("boolC", boolT);
  F->createAdd("add", boolC, boolC);

  // Randomize Constants in F
  F->randomizeConstants();

  // Check that no Constant is the same as what it started as (the original
  // tensors were moved-from copies, so compare against the saved initializers).
  EXPECT_FALSE(floatT.isEqual(floatC->getPayload()));
  EXPECT_FALSE(halfT.isEqual(halfC->getPayload()));
  EXPECT_FALSE(bfloat16T.isEqual(bfloat16C->getPayload()));
  EXPECT_FALSE(int8QT.isEqual(int8QC->getPayload()));
  EXPECT_FALSE(uint8QT.isEqual(uint8QC->getPayload()));
  EXPECT_FALSE(int16QT.isEqual(int16QC->getPayload()));
  EXPECT_FALSE(int32QT.isEqual(int32QC->getPayload()));
  EXPECT_FALSE(int32IT.isEqual(int32IC->getPayload()));
  EXPECT_FALSE(int64IT.isEqual(int64IC->getPayload()));
  EXPECT_FALSE(uint8FusedQT.isEqual(uint8FusedQC->getPayload()));
  EXPECT_FALSE(uint8FusedFP16QT.isEqual(uint8FusedFP16QC->getPayload()));
  EXPECT_FALSE(uint4FusedFP16QT.isEqual(uint4FusedFP16QC->getPayload()));
  EXPECT_FALSE(boolT.isEqual(boolC->getPayload()));
}
2348