1 /**
2  * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "BackendTestUtils.h"
17 
18 #include "glow/Graph/Graph.h"
19 #include "glow/Graph/Node.h"
20 #include "glow/Graph/Nodes.h"
21 #include "glow/Graph/PlaceholderBindings.h"
22 #include "glow/IR/IR.h"
23 #include "glow/Optimizer/GraphOptimizer/FunctionPassPipeline.h"
24 #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
25 #include "glow/Optimizer/Lower/Lower.h"
26 
27 #include "gtest/gtest.h"
28 
29 using namespace glow;
30 
31 class GraphFold : public GraphOptz {};
32 
33 /// A helper predicate to check if the provided node has the same address as a
34 /// pre-defined address provided in the constructor. This is useful if you need to
35 /// check that a given node is still in the graph. In general, it is not safe to
36 /// use the std::find(begin_it, end_it, value) and compare the nodes by value,
37 /// because the node provided as the last parameter of std::find (i.e. the value
38 /// reference) may have been removed by some optimizations and cannot be
39 /// dereferenced anymore. But comparing the addresses of the nodes should be
40 /// fine. Thus, one can use the following form instead:
41 /// std::find_if(begin_it, end_it, IsSameNodeAddress(node_address))
42 struct IsSameNodeAddress {
43   const Node *nodeAddress_;
44   IsSameNodeAddress(const Node *nodeAddress) : nodeAddress_(nodeAddress) {}
45   bool operator()(const Node &n) const { return &n == nodeAddress_; }
46 };
47 
48 /// \returns true if the Function \p F contains the Node \p N.
49 static bool functionContainsNode(const Function *F, const Node *N) {
50   return std::find_if(F->getNodes().begin(), F->getNodes().end(),
51                       IsSameNodeAddress(N)) != F->getNodes().end();
52 }
53 
54 /// Optimize the function \p F with \p cctx. \returns the optimized function.
55 /// If \p configs is empty then the whole default optimization pipeline is run.
56 /// Otherwise only the passes listed in \p configs are used.
57 static Function *
58 optimizeFunction(Function *F,
59                  std::initializer_list<FunctionPassConfig> configs = {},
60                  const CompilationContext cctx = CompilationContext()) {
61   auto *G = F->clone(F->getName().str() + "_optimized");
62   if (configs.size() == 0) {
63     ::glow::optimize(G, CompilationMode::Infer);
64     return G;
65   }
66   FunctionPassManager FPM("TestFPM", configs);
67   FPM.run(G, cctx);
68   return G;
69 }
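
// A usage sketch for the helper above (assuming FunctionPassID::DCE names the
// DCE pass declared via FunctionPassPipeline.h): run only that pass on a clone
// of F_ instead of the whole default pipeline:
//   optimizedF_ = optimizeFunction(F_, {FunctionPassConfig(FunctionPassID::DCE)});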
70 
71 /// \returns the first node in the function \p F which has the specified name.
72 template <typename NodeT = Node>
73 static const NodeT *findFunctionNodeByName(const Function *F,
74                                            const llvm::StringRef name) {
75   return llvm::dyn_cast<NodeT>(
76       std::find_if(F->getNodes().begin(), F->getNodes().end(),
77                    [=](auto &N) { return N.getName() == name; }));
78 }
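
// A usage sketch (mirroring how the tests below use it): look up the cloned
// Save node in the optimized function by the name it had in the original, e.g.
//   auto *optSave = findFunctionNodeByName<SaveNode>(optimizedF_, res->getName());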
79 
80 TEST_F(GraphOptz, OptimizeClipFunnel) {
81   auto *A =
82       mod_.createPlaceholder(ElemKind::FloatTy, {100, 16}, "input", false);
83   Node *K = A;
84   float min = 0.0;
85   float max = 1000.0;
86   for (int i = 0; i < 10; ++i) {
87     min += 1.0;
88     max -= 1.0;
89     K = F_->createClip("clip", K, min, max);
90   }
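  // After the loop the tightest range is the last one created: min == 10 and
  // max == 990, so the whole funnel of Clips should collapse into a single
  // Clip with exactly these bounds (checked below).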
91   F_->createSave("ret", K);
92 
93   EXPECT_EQ(F_->getNodes().size(), 11);
94 
95   optimizedF_ = optimizeFunction(F_);
96   EXPECT_EQ(optimizedF_->getNodes().size(), 2);
97 
98   // Find clip node in the optimized graph.
99   Node *newClip = A;
100   for (auto &N : optimizedF_->getNodes()) {
101     if (N.getKind() == Kinded::Kind::ClipNodeKind) {
102       newClip = llvm::dyn_cast<ClipNode>(&N);
103     }
104   }
105   EXPECT_TRUE(llvm::isa<ClipNode>(newClip));
106   ClipNode *c = llvm::dyn_cast<ClipNode>(newClip);
107   EXPECT_EQ(min, c->getMin());
108   EXPECT_EQ(max, c->getMax());
109 
110   bindings_.allocate(mod_.getPlaceholders());
111   bindings_.get(A)->getHandle().randomize(-1000, 1000, mod_.getPRNG());
112   bindings_.get(A)->getHandle().raw(0) = -1000;
113   checkNumericalEquivalence();
114 }
115 
116 TEST_F(GraphOptz, DCE) {
117   Node *K = mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input",
118                                    false);
119 
120   for (int i = 0; i < 40; i++) {
121     K = F_->createRELU("relu", K);
122     // Add a graph structure that diverges and converges, to catch algorithms
123     // that perform a naive recursive scan.
124     K = F_->createAdd("arith", K, K);
125   }
126 
127   // Check that we know how many nodes we've created.
128   EXPECT_EQ(F_->getNodes().size(), 80);
129 
130   // Optimize all of the dead code.
131   ::glow::optimize(F_, CompilationMode::Infer);
132 
133   //  All of the nodes are gone.
134   EXPECT_EQ(F_->getNodes().size(), 0);
135   EXPECT_EQ(mod_.getConstants().size(), 0);
136 }
137 
138 /// Check that predicated instructions are DCE'ed like
139 /// regular instructions.
140 TEST_F(GraphOptz, DCEwithPredicate) {
141   Node *K = mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input",
142                                    false);
143   Node *predicatedBatch =
144       mod_.createPlaceholder(ElemKind::FloatTy, {4}, "predicate", true);
145   for (int i = 0; i < 40; i++) {
146     K = F_->createRELU("relu", K);
147     K->setPredicate(predicatedBatch);
148     // Add a graph structure that diverges and converges, to catch algorithms
149     // that perform a naive recursive scan.
150     K = F_->createAdd("arith", K, K);
151     K->setPredicate(predicatedBatch);
152   }
153 
154   // Check that we know how many nodes we've created.
155   EXPECT_EQ(F_->getNodes().size(), 80);
156 
157   // Optimize all of the dead code.
158   ::glow::optimize(F_, CompilationMode::Infer);
159 
160   //  All of the nodes are gone.
161   EXPECT_EQ(F_->getNodes().size(), 0);
162   EXPECT_EQ(mod_.getConstants().size(), 0);
163 }
164 
165 TEST_F(GraphOptz, liveCodeNotEliminated) {
166   Node *K = mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input",
167                                    false);
168   auto *Ex = mod_.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "Ex", false);
169 
170   for (int i = 0; i < 40; i++) {
171     K = F_->createRELU("relu", K);
172     K = F_->createAdd("arith", K, K);
173   }
174   K = F_->createSoftMax("Regression", K, Ex);
175   F_->createSave("ret", K);
176 
177   // Check that we know how many nodes we've created.
178   EXPECT_EQ(F_->getNodes().size(), 82);
179 
180   // This should not optimize code because none is dead.
181   ::glow::optimize(F_, CompilationMode::Infer);
182 
183   //  Nothing got optimized.
184   EXPECT_EQ(F_->getNodes().size(), 82);
185   EXPECT_EQ(mod_.getPlaceholders().size(), 3);
186 }
187 
188 /// Skip Reshape sinking below BatchNorm when inapplicable.
189 TEST_F(GraphOptz, SkipReshapeSinkBatchNorm) {
190   auto *A = mod_.createPlaceholder(ElemKind::FloatTy, {32, 64}, "A", false);
191   Node *RS = F_->createReshape("reshape", A, {32, 64, 1});
192   Node *BN =
193       F_->createBatchNormalization(bindings_, "batch", RS, 1, 0.0001, 0.9);
194   F_->createSave("ret", BN);
195 
196   optimizedF_ = optimizeFunction(F_);
197   EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false, /* skipName */ true),
198             optimizedF_->toString(/* skipUsersForStorage */ false,
199                                   /* skipName */ true));
200 }
201 
202 // Conv->Reshape->BatchNorm is optimized to Conv->Reshape after sinking Reshape
203 // below BatchNorm. Reshape transforms [N][H][W][C] to [N][W][H][C].
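// The fold is still legal here because the Reshape keeps the channel dimension
// (C == 16) last, which is the channel index given to the BatchNorm.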
204 TEST_F(GraphOptz, optimizeBatchNormAfterConvAndReshapeNHWC) {
205   auto *A =
206       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
207   Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
208   Node *RS = F_->createReshape("reshape", CV, {1, 20, 10, 16});
209   Node *BN =
210       F_->createBatchNormalization(bindings_, "batch", RS, 3, 0.0001, 0.9);
211   F_->createSave("ret", BN);
212 
213   EXPECT_EQ(F_->getNodes().size(), 4);
214   ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
215   optimizedF_ = optimizeFunction(F_);
216   EXPECT_EQ(optimizedF_->getNodes().size(), 3);
217 
218   ASSERT_EQ(A->getNumUsers(), 2);
219   Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
220                                  [CV](auto &it) { return it.getUser() == CV; })
221                     ->getUser();
222   EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
223   ASSERT_EQ(newCV->getNumUsers(), 1);
224   Node *reshape = newCV->getUsers().begin()->getUser();
225   EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));
226 
227   bindings_.allocate(mod_.getPlaceholders());
228   bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
229   checkNumericalEquivalence();
230 }
231 
232 // Conv->Reshape->BatchNorm is optimized to Conv->Reshape after sinking Reshape
233 // below BatchNorm. Reshape flattens [N][H][W][C] to [N][HxW][C].
234 TEST_F(GraphOptz, optimizeBatchNormAfterConvAndReshapeNHWC2) {
235   auto *A =
236       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
237   Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
238   Node *RS = F_->createReshape("reshape", CV, {1, 200, 16});
239   Node *BN =
240       F_->createBatchNormalization(bindings_, "batch", RS, 2, 0.0001, 0.9);
241   F_->createSave("ret", BN);
242 
243   EXPECT_EQ(F_->getNodes().size(), 4);
244   ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
245   optimizedF_ = optimizeFunction(F_);
246   EXPECT_EQ(optimizedF_->getNodes().size(), 3);
247 
248   ASSERT_EQ(A->getNumUsers(), 2);
249   Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
250                                  [CV](auto &it) { return it.getUser() == CV; })
251                     ->getUser();
252   EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
253   ASSERT_EQ(newCV->getNumUsers(), 1);
254   Node *reshape = newCV->getUsers().begin()->getUser();
255   EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));
256 
257   bindings_.allocate(mod_.getPlaceholders());
258   bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
259   checkNumericalEquivalence();
260 }
261 
262 // BatchNorm is not folded into Conv. The Reshape changes the channel
263 // dimension, which prevents the optimization. Reshape transforms
264 // [N][H][W][C] to [N][H][W/2][C*2].
265 TEST_F(GraphOptz, optimizeBatchNormAfterConvAndReshapeNHWCneg) {
266   auto *A =
267       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
268   Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
269   Node *RS = F_->createReshape("reshape", CV, {1, 10, 10, 32});
270   Node *BN =
271       F_->createBatchNormalization(bindings_, "batch", RS, 3, 0.0001, 0.9);
272   F_->createSave("ret", BN);
273 
274   EXPECT_EQ(F_->getNodes().size(), 4);
275   ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
276   optimizedF_ = optimizeFunction(F_);
277   EXPECT_EQ(optimizedF_->getNodes().size(), 4);
278 
279   ASSERT_EQ(A->getNumUsers(), 2);
280   Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
281                                  [CV](auto &it) { return it.getUser() == CV; })
282                     ->getUser();
283   EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
284   ASSERT_EQ(newCV->getNumUsers(), 1);
285   Node *reshape = newCV->getUsers().begin()->getUser();
286   EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));
287   Node *bn = reshape->getUsers().begin()->getUser();
288   EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(bn));
289 
290   bindings_.allocate(mod_.getPlaceholders());
291   bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
292   checkNumericalEquivalence();
293 }
294 
295 // Conv->Reshape->BatchNorm. Sink Reshape below BatchNorm. Check that BatchNorm
296 // does not fold into Conv.
297 TEST_F(GraphOptz, sinkReshapeBelowBatchNormAndDoNotFuseConvBatchNorm) {
298   auto *A =
299       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
300   Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1,
301                             ConvolutionLayout::NCHW);
302   Node *RS = F_->createReshape("reshape", CV, {1, 10, 16, 20});
303   Node *BN =
304       F_->createBatchNormalization(bindings_, "batch", RS, 1, 0.0001, 0.9);
305   F_->createSave("ret", BN);
306 
307   EXPECT_EQ(F_->getNodes().size(), 4);
308   ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
309   optimizedF_ = optimizeFunction(F_);
310   EXPECT_EQ(optimizedF_->getNodes().size(), 4);
311 
312   ASSERT_EQ(A->getNumUsers(), 2);
313   Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
314                                  [CV](auto &it) { return it.getUser() == CV; })
315                     ->getUser();
316 
317   EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
318   ASSERT_EQ(newCV->getNumUsers(), 1);
319   Node *bn = newCV->getUsers().begin()->getUser();
320   EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(bn));
321   Node *reshape = bn->getUsers().begin()->getUser();
322   EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));
323 
324   bindings_.allocate(mod_.getPlaceholders());
325   bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
326   checkNumericalEquivalence();
327 }
328 
329 TEST_F(GraphOptz, optimizeBatchNormAfterConv) {
330   auto *A =
331       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
332   Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
333   Node *BN =
334       F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
335   F_->createSave("ret", BN);
336 
337   EXPECT_EQ(F_->getNodes().size(), 3);
338   ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
339   optimizedF_ = optimizeFunction(F_);
340   EXPECT_EQ(optimizedF_->getNodes().size(), 2);
341 
342   ASSERT_EQ(A->getNumUsers(), 2);
343   Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
344                                  [CV](auto &it) { return it.getUser() == CV; })
345                     ->getUser();
346   EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
347   ASSERT_EQ(newCV->getNumUsers(), 1);
348   Node *save = newCV->getUsers().begin()->getUser();
349   EXPECT_TRUE(llvm::isa<SaveNode>(save));
350 
351   bindings_.allocate(mod_.getPlaceholders());
352   bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
353   checkNumericalEquivalence();
354 }
355 
356 /// Verify that the Conv-BatchNorm merging optimization is not impacted by
357 /// multiple users on the filter/bias.
358 TEST_F(GraphOptz, optimizeBatchNormAfterConvMultiple) {
359   Placeholder *A =
360       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
361   ConvolutionNode *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
362   BatchNormalizationNode *BN =
363       F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
364   F_->createSave("ret", BN);
365 
366   // Adding these saves means the filter and bias have multiple uses. This
367   // should not impact the Conv-BatchNorm merging optimization.
368   F_->createSave("saveFilter", CV->getFilter());
369   F_->createSave("saveBias", CV->getBias());
370 
371   // Three Saves, one Conv, and one BatchNorm.
372   EXPECT_EQ(F_->getNodes().size(), 5);
373 
374   ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
375 
376   // Conv's Filter and Bias, plus BN's Scale, Bias, Mean, and Var.
377   EXPECT_EQ(mod_.getConstants().size(), 6);
378 
379   optimizedF_ = optimizeFunction(F_);
380 
381   // BatchNorm should have been merged into the Conv.
382   EXPECT_EQ(optimizedF_->getNodes().size(), 4);
383 
384   // Filter and Bias should have been duplicated so that the Conv-BN
385   // optimization does not modify the filter/bias being saved, equaling 4
386   // Constants. Additionally, the BN's Scale, Bias, Mean, and Var should be
387   // eliminated due to the optimization.
388   EXPECT_EQ(mod_.getConstants().size(), 8);
389 
390   ASSERT_EQ(A->getNumUsers(), 2);
391   Node *newCV = A->getUsers().back().getUser();
392   EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
393   ASSERT_EQ(newCV->getNumUsers(), 1);
394   Node *save = newCV->getUsers().begin()->getUser();
395   EXPECT_TRUE(llvm::isa<SaveNode>(save));
396 
397   EXPECT_EQ(
398       countNodeKind(optimizedF_, Kinded::Kind::BatchNormalizationNodeKind), 0);
399 
400   bindings_.allocate(mod_.getPlaceholders());
401   bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
402   checkNumericalEquivalence();
403 }
404 
405 TEST_F(GraphOptz, optimizeBatchNormAfterConvFP16) {
406   auto *A =
407       mod_.createPlaceholder(ElemKind::Float16Ty, {1, 10, 20, 3}, "A", false);
408   Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
409   Node *BN =
410       F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
411   F_->createSave("ret", BN);
412 
413   EXPECT_EQ(F_->getNodes().size(), 3);
414 
415   ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
416   optimizedF_ = optimizeFunction(F_);
417 
418   EXPECT_EQ(optimizedF_->getNodes().size(), 2);
419 
420   ASSERT_EQ(A->getNumUsers(), 2);
421 
422   bool optimizedPathExists{false};
423   for (const auto &path : A->getUsers()) {
424     auto cv = path.getUser();
425     EXPECT_TRUE(llvm::isa<ConvolutionNode>(cv));
426     ASSERT_EQ(cv->getNumUsers(), 1);
427     auto next = cv->getUsers().begin()->getUser();
428     optimizedPathExists |= llvm::isa<SaveNode>(next);
429   }
430 
431   EXPECT_TRUE(optimizedPathExists);
432 
433   bindings_.allocate(A)->getHandle<float16_t>().randomize(-1.0, 1.0,
434                                                           mod_.getPRNG());
435 
436   checkNumericalEquivalence();
437 }
438 
439 /// Check that transpose constant folding is done before BatchNorm optimization,
440 /// which allows merging BatchNorm into the Convolution with transposed weights.
441 TEST_F(GraphOptz, optimizeBatchNormAfterConvWithTransposedWeights) {
442   auto *input =
443       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "input", false);
444   auto *filter =
445       mod_.createPlaceholder(ElemKind::FloatTy, {16, 3, 5, 5}, "filter", false);
446   auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {16}, "bias", false);
447 
448   auto *TN = F_->createTranspose("transpose", filter, NCHW2NHWC);
449   auto *CV = F_->createConv("conv", input, TN, bias,
450                             mod_.uniqueType(ElemKind::FloatTy, {1, 10, 20, 16}),
451                             5, 1, 2, 1);
452   auto *BN =
453       F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
454   F_->createSave("ret", BN);
455 
456   // Initialize to ensure that constant tensors are not optimized out.
457   bindings_.allocate(filter)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
458   bindings_.allocate(bias)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
459 
460   EXPECT_EQ(F_->getNodes().size(), 4);
461   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchNormalizationNodeKind), 1);
462 
463   ::glow::convertPlaceholdersToConstants(F_, bindings_, {input});
464   optimizedF_ = optimizeFunction(F_);
465 
466   EXPECT_EQ(optimizedF_->getNodes().size(), 2);
467   EXPECT_EQ(
468       countNodeKind(optimizedF_, Kinded::Kind::BatchNormalizationNodeKind), 0);
469 
470   bindings_.allocate(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
471   checkNumericalEquivalence();
472 }
473 
474 /// Check that reshape constant folding is done before BatchNorm optimization,
475 /// where the Reshape is the result of the Transpose-to-Reshape optimization,
476 /// which allows merging BatchNorm into the Convolution with transposed weights.
477 TEST_F(GraphOptz, optimizeBatchNormAfterConvWithReshapeConst) {
478   auto *input =
479       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "input", false);
480   auto *filter =
481       mod_.createPlaceholder(ElemKind::FloatTy, {5, 5, 3, 1}, "filter", false);
482   auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "bias", false);
483 
484   auto *TN = F_->createTranspose("transpose", filter, HWCN2NHWC);
485   auto *CV = F_->createConv("conv", input, TN, bias,
486                             mod_.uniqueType(ElemKind::FloatTy, {1, 10, 20, 1}),
487                             5, 1, 2, 1);
488   auto *BN =
489       F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
490   F_->createSave("ret", BN);
491 
492   // Initialize to ensure that constant tensors are not optimized out.
493   bindings_.allocate(filter)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
494   bindings_.allocate(bias)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
495 
496   EXPECT_EQ(F_->getNodes().size(), 4);
497   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchNormalizationNodeKind), 1);
498 
499   ::glow::convertPlaceholdersToConstants(F_, bindings_, {input});
500   optimizedF_ = optimizeFunction(F_);
501 
502   EXPECT_EQ(optimizedF_->getNodes().size(), 2);
503   EXPECT_EQ(
504       countNodeKind(optimizedF_, Kinded::Kind::BatchNormalizationNodeKind), 0);
505 
506   bindings_.allocate(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
507   checkNumericalEquivalence();
508 }
509 
510 /// Check that the batch normalization optimization is
511 /// not blocked by predicates and that it preserves them.
512 TEST_F(GraphOptz, optimizeBatchNormAfterConvWithPred) {
513   Node *A =
514       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
515   Node *pred1 =
516       mod_.createPlaceholder(ElemKind::FloatTy, {1}, "predicate", false);
517   Node *pred2 =
518       mod_.createPlaceholder(ElemKind::FloatTy, {1}, "predicate", false);
519   Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
520   CV->setPredicate(pred1);
521   Node *BN =
522       F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
523   BN->setPredicate(pred2);
524   F_->createSave("ret", BN);
525 
526   EXPECT_EQ(F_->getNodes().size(), 3);
527 
528   ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
529   ::glow::optimize(F_, CompilationMode::Infer);
530   EXPECT_EQ(F_->getNodes().size(), 2);
531 
532   ASSERT_EQ(A->getNumUsers(), 1);
533   Node *newCV = A->getUsers().begin()->getUser();
534   EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
535   ASSERT_TRUE(newCV->hasPredicate());
536   EXPECT_EQ(newCV->getPredicate().getNode(), pred2);
537   ASSERT_EQ(newCV->getNumUsers(), 1);
538   Node *save = newCV->getUsers().begin()->getUser();
539   EXPECT_TRUE(llvm::isa<SaveNode>(save));
540 }
541 
542 /// Testing merge of a single-user arithmetic operation chain (Sub, Mul, Add,
543 /// Div) into a BatchNorm.
544 TEST_F(GraphOptz, MergeBatchNormalizationWithArithmeticChainTest) {
545   // Inputs.
546   auto *input =
547       mod_.createPlaceholder(ElemKind::FloatTy, {3, 2, 2, 4}, "input", false);
548   auto *var = mod_.createConstant(ElemKind::FloatTy, {4}, "var");
549   auto *mean = mod_.createConstant(ElemKind::FloatTy, {4}, "mean");
550   auto *beta = mod_.createConstant(ElemKind::FloatTy, {4}, "beta");
551   auto *gamma = mod_.createConstant(ElemKind::FloatTy, {4}, "gamma");
552 
553   Node *subC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "subC");
554   Node *mulC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "mulC");
555   Node *addC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "addC");
556   Node *divC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "divC");
557 
558   // Fill tensors to check boundary values after the transformation.
559   std::vector<float> betaV = {1., 2., 3., 7.};
560   std::vector<float> gammaV = {4., 5., 6., 7.};
561 
562   var->getPayloadMutable().getHandle<float>() = {1., 1., 1., 1.};
563   mean->getPayloadMutable().getHandle<float>() = {0., 0., 0., 0.};
564   beta->getPayloadMutable().getHandle<float>() = betaV;
565   gamma->getPayloadMutable().getHandle<float>() = gammaV;
566 
567   // For at least one node (sub) make the values within a channel different,
568   // to test the folding better.
569   const std::vector<float> subV = {1, 2., 3., 4.};
570   const float mulV = 4., addV = 3., divV = 2.;
571   auto subH = llvm::cast<Constant>(subC)->getHandle<float>();
572   subH = {1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.,
573           1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.,
574           1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.};
575 
576   llvm::cast<Constant>(mulC)->getHandle<float>().clear(mulV);
577   llvm::cast<Constant>(addC)->getHandle<float>().clear(addV);
578   llvm::cast<Constant>(divC)->getHandle<float>().clear(divV);
579 
580   BatchNormalizationNode *bn =
581       F_->createBatchNormalization("batch", input, beta, gamma, mean, var, 3);
582 
583   auto *sub = F_->createSub("sub", bn, subC);
584   auto *mul = F_->createMul("mul", sub, mulC);
585   auto *add = F_->createAdd("add", addC, mul);
586   auto *div = F_->createDiv("div", add, divC);
587   auto *res = F_->createSave("save", div);
588 
589   // Compile.
590   EXPECT_EQ(F_->getNodes().size(), 6);
591   ::glow::convertPlaceholdersToConstants(F_, bindings_, {input});
592   optimizedF_ = optimizeFunction(F_);
593   EXPECT_EQ(optimizedF_->getNodes().size(), 2);
594 
595   Constant *cs, *cb;
596 
597   auto *opt_res = findFunctionNodeByName<SaveNode>(optimizedF_, res->getName());
598 
599   auto *newBn = llvm::dyn_cast<BatchNormalizationNode>(opt_res->getInput());
600   ASSERT_TRUE(newBn);
601 
602   cs = llvm::dyn_cast<Constant>(newBn->getScale());
603   cb = llvm::dyn_cast<Constant>(newBn->getBias());
604   ASSERT_TRUE(cs);
605   ASSERT_TRUE(cb);
606   ASSERT_TRUE(cs->getType()->isFPType());
607   ASSERT_TRUE(cb->getType()->isFPType());
608 
609   auto hs = cs->getHandle<float>();
610   auto hb = cb->getHandle<float>();
611 
612   // Verify that scale and offset are computed correctly.
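  // A sketch of the algebra being verified, writing the BatchNorm output as
  // BN(x) = gamma * norm(x) + beta, where norm(x) = (x - mean) / sqrt(var + eps):
  //   ((BN(x) - subC) * mulC + addC) / divC
  //     = (gamma * mulV / divV) * norm(x) + ((beta - subV) * mulV + addV) / divV
  // so the whole chain folds into new per-channel scale and bias constants.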
613   for (dim_t i = 0; i < 4; i++) {
614     const float expScale = gammaV[i] * mulV / divV;
615     const float expBias = ((betaV[i] - subV[i]) * mulV + addV) / divV;
616     EXPECT_EQ(expScale, hs.raw(i));
617     EXPECT_EQ(expBias, hb.raw(i));
618   }
619 
620   bindings_.allocate(mod_.getPlaceholders());
621   bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
622   checkNumericalEquivalence();
623 }
624 
625 /// Testing folding of a single-user arithmetic operation chain (Sub, Mul,
626 /// Add, Div) after a Conv into a BatchNorm.
627 TEST_F(GraphOptz, FoldArithmeticChainAfterConvIntoBatchNorm) {
628   Node *subC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "subC");
629   Node *mulC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "mulC");
630   Node *addC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "addC");
631   Node *divC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "divC");
632 
633   // Start with identity values.
634   std::vector<float> betaV = {0., 0., 0.};
635   std::vector<float> gammaV = {1., 1., 1.};
636 
637   // For at least one node make the values within a channel different, to
638   // test the folding better (ideally all nodes should have different values).
639   const std::vector<float> subV = {1, 2., 3.};
640   const float mulV = 4., addV = 3., divV = 2.;
641   llvm::cast<Constant>(mulC)->getHandle<float>().clear(mulV);
642   llvm::cast<Constant>(addC)->getHandle<float>().clear(addV);
643   llvm::cast<Constant>(divC)->getHandle<float>().clear(divV);
644   auto subH = llvm::cast<Constant>(subC)->getHandle<float>();
645   subH = {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3,
646           1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3,
647           1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3};
648 
649   auto *input =
650       mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 2, 3}, "input", false);
651   auto filter =
652       mod_.createPlaceholder(ElemKind::FloatTy, {3, 2, 2, 3}, "filter", false);
653   auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {3}, "bias", false);
654   bindings_.allocate(bias)->zero();
655 
656   ConvolutionNode *CV = F_->createConv(
657       "Conv", input, filter, bias,
658       mod_.uniqueType(ElemKind::FloatTy, {2, 3, 3, 3}), 2, 1, 1, 1);
659 
660   auto *sub = F_->createSub("sub", CV, subC);
661   auto *mul = F_->createMul("mul", sub, mulC);
662   auto *add = F_->createAdd("add", addC, mul);
663   auto *div = F_->createDiv("div", add, divC);
664   auto *res = F_->createSave("save", div);
665 
666   // Compile.
667   EXPECT_EQ(F_->getNodes().size(), 6);
668   ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
669   optimizedF_ = optimizeFunction(F_);
670   EXPECT_EQ(optimizedF_->getNodes().size(), 3);
671 
672   auto *opt_res = findFunctionNodeByName<SaveNode>(optimizedF_, res->getName());
673 
674   Constant *cs, *cb;
675 
676   auto *bn = llvm::dyn_cast<BatchNormalizationNode>(opt_res->getInput());
677   ASSERT_TRUE(bn);
678 
679   cs = llvm::dyn_cast<Constant>(bn->getScale());
680   cb = llvm::dyn_cast<Constant>(bn->getBias());
681 
682   ASSERT_TRUE(cs);
683   ASSERT_TRUE(cb);
684   ASSERT_TRUE(cs->getType()->isFPType());
685   ASSERT_TRUE(cb->getType()->isFPType());
686 
687   auto hs = cs->getHandle<float>();
688   auto hb = cb->getHandle<float>();
689 
690   // Verify that scale and offset are computed correctly.
691   for (dim_t i = 0; i < 3; i++) {
692     const float expectedScale = gammaV[i] * (mulV / divV);
693     const float expectedBias = ((betaV[i] - subV[i]) * mulV + addV) / divV;
694     EXPECT_EQ(expectedScale, hs.raw(i));
695     EXPECT_EQ(expectedBias, hb.raw(i));
696   }
697   bindings_.allocate(mod_.getPlaceholders());
698   bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
699   bindings_.get(filter)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
700   bindings_.get(bias)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
701   checkNumericalEquivalence();
702 }
703 
704 /// Check CSE will not merge two nodes that have all the same inputs but
705 /// different predicates.
706 TEST_F(GraphOptz, cseRespectsPredicates) {
707   Placeholder *in = mod_.createPlaceholder(ElemKind::FloatTy, {5}, "in", false);
708   Placeholder *pred1 =
709       mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
710   Placeholder *pred2 =
711       mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
712 
713   Node *RN1 = F_->createRELU("relu1", in);
714   RN1->setPredicate(pred1);
715   SaveNode *save1 = F_->createSave("save1", RN1);
716   save1->setPredicate(pred1);
717 
718   Node *RN2 = F_->createRELU("relu2", in);
719   RN2->setPredicate(pred2);
720   SaveNode *save2 = F_->createSave("save2", RN2);
721   save2->setPredicate(pred2);
722 
723   // Two RELUS and two Saves.
724   EXPECT_EQ(F_->getNodes().size(), 4);
725   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 2);
726   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);
727 
728   ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
729   optimizedF_ = optimizeFunction(F_);
730 
731   // Two RELUS and two Saves should still be there.
732   EXPECT_EQ(F_->getNodes().size(), 4);
733   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 2);
734   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);
735 
736   bindings_.allocate(mod_.getPlaceholders());
737   bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
738   checkNumericalEquivalence();
739 }
740 
741 TEST_F(GraphOptz, optimizeBatchNormAfterConvButConvReused) {
742   Placeholder *A =
743       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
744   Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
745   Node *BN =
746       F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
747   F_->createSave("ret", BN);
748   F_->createSave("convSave", CV);
749 
750   EXPECT_EQ(F_->getNodes().size(), 4);
751   optimizedF_ = optimizeFunction(F_);
752   // Make sure the structure of the graph did not change, since the convolution
753   // node is used more than once.
754   EXPECT_EQ(optimizedF_->getNodes().size(), 4);
755   auto convIt =
756       std::find_if(optimizedF_->getNodes().begin(),
757                    optimizedF_->getNodes().end(), [](const Node &node) -> bool {
758                      return llvm::isa<ConvolutionNode>(node);
759                    });
760   ASSERT_NE(convIt, optimizedF_->getNodes().end());
761   auto batchNormIt =
762       std::find_if(optimizedF_->getNodes().begin(),
763                    optimizedF_->getNodes().end(), [](const Node &node) -> bool {
764                      return (llvm::isa<BatchNormalizationNode>(node));
765                    });
766   ConvolutionNode *conv = llvm::dyn_cast<ConvolutionNode>(convIt);
767   BatchNormalizationNode *batchNorm =
768       llvm::dyn_cast<BatchNormalizationNode>(batchNormIt);
769 
770   EXPECT_EQ(*conv, *CV);
771   EXPECT_EQ(batchNorm->getInput().getNode(), conv);
772   EXPECT_EQ(conv->getInput().getNode(), A);
773 
774   bindings_.allocate(mod_.getPlaceholders());
775   bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
776   checkNumericalEquivalence();
777 }
778 
779 TEST_F(GraphOptz, optimizeBatchNormAfterConvButVarReused) {
780   auto *A =
781       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
782 
783   ConvolutionNode *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
784   Node *BN =
785       F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
786   auto *retSaveNode = F_->createSave("ret", BN);
787   auto *filterSaveNode = F_->createSave("filter", CV->getFilter());
788 
789   EXPECT_EQ(F_->getNodes().size(), 4);
790   optimizedF_ = optimizeFunction(F_);
791   ASSERT_EQ(A->getNumUsers(), 2);
792 
793   auto *optimizedF_ret =
794       findFunctionNodeByName<SaveNode>(optimizedF_, retSaveNode->getName());
795   auto *optimizedF_filterSave =
796       findFunctionNodeByName<SaveNode>(optimizedF_, filterSaveNode->getName());
797 
798   // Make sure the structure of the graph did not change.
799   EXPECT_EQ(optimizedF_->getNodes().size(), 4);
800   EXPECT_TRUE(llvm::isa<Placeholder>(optimizedF_filterSave->getInput()));
801   auto *varFilter =
802       llvm::dyn_cast<Placeholder>(optimizedF_filterSave->getInput());
803   EXPECT_EQ(varFilter, CV->getFilter());
804   EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(optimizedF_ret->getInput()));
805 
806   BatchNormalizationNode *batchNorm =
807       llvm::dyn_cast<BatchNormalizationNode>(optimizedF_ret->getInput());
808   ASSERT_TRUE(batchNorm);
809   auto *newCVNode =
810       llvm::dyn_cast<ConvolutionNode>(batchNorm->getInput().getNode());
811   ASSERT_TRUE(newCVNode);
812   EXPECT_EQ(newCVNode->getInput().getNode(), CV->getInput().getNode());
813   EXPECT_EQ(newCVNode->getInput().getNode(), A);
814 
815   bindings_.allocate(mod_.getPlaceholders());
816   bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
817   checkNumericalEquivalence();
818 }
819 
820 TEST_F(GraphOptz, transposeConstant) {
821   auto *A =
822       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
823   bindings_.allocate(A)->getHandle().randomize(-7.0, 12.0, mod_.getPRNG());
824   Tensor transposedA;
825   bindings_.get(A)->transpose(&transposedA, {0, 3, 1, 2});
826   Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
827   SaveNode *save = F_->createSave("ret", T);
828   EXPECT_EQ(F_->getNodes().size(), 2);
829 
830   ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
831   ::glow::optimize(F_, CompilationMode::Infer);
832   ASSERT_EQ(F_->getNodes().size(), 1);
833   EXPECT_EQ(&*F_->getNodes().begin(), save);
834   Constant *optimizedA = llvm::dyn_cast<Constant>(save->getInput().getNode());
835   ASSERT_NE(optimizedA, nullptr);
836   // Check that A has been properly transposed.
837   EXPECT_TRUE(optimizedA->getPayload().isEqual(transposedA));
838 }
839 
840 /// Check that the Transpose is merged with Constant in a sequence
841 /// Transpose(Quantize(Constant)).
842 TEST_F(GraphOptz, transposeQuantizeConstant) {
843   auto *qTy = mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 3}, 0.2, 0);
844   auto *input = F_->getParent()->createConstant(ElemKind::FloatTy,
845                                                 {1, 10, 20, 3}, "input");
846   auto *Q = F_->createQuantize("quantize", input, qTy);
847   auto *T = F_->createTranspose("transpose", Q, NHWC2NCHW);
848   auto *S = F_->createSave("save", T);
849 
850   // Skip ConstantFolding as it would have the same result as this opt.
851   CompilationContext cctx;
852   cctx.optimizationOpts.enableConstantFolding = false;
853 
854   EXPECT_EQ(F_->getNodes().size(), 3);
855   ::glow::optimize(F_, cctx);
856   EXPECT_EQ(F_->getNodes().size(), 2);
857 
858   // Constant and Quantize should have new shape.
859   auto *newQ = llvm::dyn_cast<QuantizeNode>(S->getInput());
860   ASSERT_TRUE(newQ);
861   EXPECT_TRUE(newQ->getResult().dims().equals({1, 3, 10, 20}));
862   auto *newC = llvm::dyn_cast<Constant>(newQ->getInput());
863   ASSERT_TRUE(newC);
864   EXPECT_TRUE(newC->getType()->dims().equals({1, 3, 10, 20}));
865 }
866 
867 /// Check that the removing of transposes still happens when
868 /// predicates are involved.
869 TEST_F(GraphOptz, transposeConstantWithPredicate) {
870   auto *A =
871       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
872   auto *pred = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
873   bindings_.allocate(A)->getHandle().randomize(-7.0, 12.0, mod_.getPRNG());
874   Tensor transposedA;
875   bindings_.get(A)->transpose(&transposedA, {0, 3, 1, 2});
876   // Arguably, if the transpose doesn't happen because the predicate is false
877   // the value of A should be unchanged. However, the semantics of our
878   // predicates are that they can be ignored and the program would still
879   // be correct, thus this optimization is still legal.
880   Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
881   T->setPredicate(pred);
882   SaveNode *save = F_->createSave("ret", T);
883   save->setPredicate(pred);
884   EXPECT_EQ(F_->getNodes().size(), 2);
885 
886   ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
887   ::glow::optimize(F_, CompilationMode::Infer);
888   ASSERT_EQ(F_->getNodes().size(), 1);
889   EXPECT_EQ(&*F_->getNodes().begin(), save);
890   // We should have kept the predicate on the save node.
891   ASSERT_EQ(pred->getNumUsers(), 1);
892   EXPECT_EQ(pred->getUsers().begin()->getUser(), save);
893   Constant *optimizedA = llvm::dyn_cast<Constant>(save->getInput().getNode());
894   ASSERT_NE(optimizedA, nullptr);
895   // Check that A has been properly transposed.
896   EXPECT_TRUE(optimizedA->getPayload().isEqual(transposedA));
897 }
898 
899 TEST_F(GraphOptz, BatchNormAfterConvNotOptimizeForTrain) {
900   Placeholder *A =
901       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
902   Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
903   Node *BN =
904       F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
905   F_->createSave("ret", BN);
906 
907   EXPECT_EQ(F_->getNodes().size(), 3);
908 
909   optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
910   ::glow::optimize(optimizedF_, CompilationMode::Train);
911   EXPECT_EQ(optimizedF_->getNodes().size(), 3);
912 
913   ASSERT_EQ(A->getNumUsers(), 2);
914   Node *curCV = A->getUsers().begin()->getUser();
915   EXPECT_EQ(curCV, CV);
916   ASSERT_EQ(curCV->getNumUsers(), 1);
917   Node *curBN = curCV->getUsers().begin()->getUser();
918   EXPECT_EQ(curBN, BN);
919   ASSERT_EQ(curBN->getNumUsers(), 1);
920   Node *save = curBN->getUsers().begin()->getUser();
921   EXPECT_TRUE(llvm::isa<SaveNode>(save));
922 
923   bindings_.allocate(mod_.getPlaceholders());
924   bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
925   checkNumericalEquivalence();
926 }
927 
928 TEST_F(GraphOptz, batchNormAfterConvNotOptimizeWhenMoreThanOneUseOfConv) {
929   Node *A =
930       mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
931 
932   Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
933   Node *BN =
934       F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
935   SaveNode *convSave = F_->createSave("ret", CV);
936   SaveNode *ret = F_->createSave("ret", BN);
937 
938   EXPECT_EQ(F_->getNodes().size(), 4);
939 
940   ::glow::optimize(F_, CompilationMode::Infer);
941   // Make sure the structure of the graph did not change, since the convolution
942   // node is used more than once.
943   EXPECT_EQ(F_->getNodes().size(), 4);
944   ASSERT_TRUE(llvm::isa<ConvolutionNode>(convSave->getInput()));
945   ConvolutionNode *conv = llvm::dyn_cast<ConvolutionNode>(convSave->getInput());
946   EXPECT_EQ(conv, CV);
947   EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(ret->getInput()));
948   BatchNormalizationNode *batchNorm =
949       llvm::dyn_cast<BatchNormalizationNode>(ret->getInput());
950   EXPECT_EQ(batchNorm, BN);
951   EXPECT_EQ(batchNorm->getInput().getNode(), CV);
952   EXPECT_EQ(conv->getInput().getNode(), A);
953 }
954 
955 enum class TestSinkTransposeNodesKind {
956   BatchNormalization,
957   Relu,
958   Clip,
959   Sigmoid,
960   Tanh,
961   Quantize,
962 };
963 
964 class GraphOptzSinkTransposeBelowParametrized
965     : public GraphOptz,
966       public ::testing::WithParamInterface<TestSinkTransposeNodesKind> {
967 public:
968   NodeValue getNodeFromInput(TestSinkTransposeNodesKind testNode, Node *T) {
969     switch (testNode) {
970     case TestSinkTransposeNodesKind::BatchNormalization: {
971       return F_->createBatchNormalization(bindings_, "batch", T, 3, 0.0001, 0.9)
972           ->getResult();
973     }
974     case TestSinkTransposeNodesKind::Relu: {
975       return F_->createRELU("relu", T)->getResult();
976     }
977     case TestSinkTransposeNodesKind::Clip: {
978       return F_->createClip("clip", T, 0.0, 6.0)->getResult();
979     }
980     case TestSinkTransposeNodesKind::Sigmoid: {
981       return F_->createSigmoid("sigmoid", T)->getResult();
982     }
983     case TestSinkTransposeNodesKind::Tanh: {
984       return F_->createTanh("tanh", T)->getResult();
985     }
986     case TestSinkTransposeNodesKind::Quantize: {
987       return F_
988           ->createQuantize(
989               "quantize", T,
990               mod_.uniqueType(ElemKind::Int8QTy, T->dims(0), 0.03, 5))
991           ->getResult();
992     }
993     }
994     LOG(DFATAL) << "Cannot reach here.";
995   }
996 };
997 
998 TEST_P(GraphOptzSinkTransposeBelowParametrized,
999        TestSinkTransposeForDifferentCases) {
1000   const dim_t origDims[] = {1, 5, 10, 15};
1001   const dim_t transposedDims[] = {1, 15, 5, 10};
1002   auto *A = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1003   Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
1004   auto IN = getNodeFromInput(GetParam(), T);
1005   SaveNode *O = F_->createSave("ret", IN);
1006 
1007   EXPECT_EQ(F_->getNodes().size(), 3);
1008   EXPECT_EQ(IN.dims(), llvm::makeArrayRef(transposedDims));
1009 
1010   optimizedF_ = optimizeFunction(F_);
1011   O = llvm::dyn_cast<SaveNode>(std::find_if(
1012       optimizedF_->getNodes().begin(), optimizedF_->getNodes().end(),
1013       [](const auto &N) { return N.getKind() == Kinded::Kind::SaveNodeKind; }));
1014 
1015   // Expecting Transpose->Output rather than N->Output.
1016   auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1017   ASSERT_NE(transpose, nullptr);
1018   Node *N = transpose->getInput();
1019   ASSERT_TRUE(N);
1020   // Test correct input.
1021   if (GetParam() == TestSinkTransposeNodesKind::BatchNormalization) {
1022     ASSERT_EQ(BatchNormalizationNode::InputIdx, 0);
1023   } else {
1024     ASSERT_EQ(N->getNumInputs(), 1);
1025   }
1026   // Check that the dimensions of the input and output have been
1027   // updated to compensate for the absence of the transpose.
1028   EXPECT_EQ(transpose->getInput().dims(), llvm::makeArrayRef(origDims));
1029   EXPECT_EQ(N->getNthInput(0).dims(), llvm::makeArrayRef(origDims));
1030   EXPECT_EQ(F_->getNodes().size(), 3);
1031 
1032   bindings_.allocate(mod_.getPlaceholders());
1033   bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1034   checkNumericalEquivalence();
1035 }
1036 
1037 TEST_P(GraphOptzSinkTransposeBelowParametrized,
1038        TestSinkTransposeWithPredicateForDifferentCases) {
1039   if (GetParam() == TestSinkTransposeNodesKind::Quantize) {
1040     // Quantize does not work with generic test for predicates.
1041     return;
1042   }
1043   const dim_t origDims[] = {1, 5, 10, 15};
1044   const dim_t transposedDims[] = {1, 15, 5, 10};
1045   Node *A = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1046   Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1047   Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1048   Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1049   Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
1050   T->setPredicate(pred1);
1051   Node *IN = getNodeFromInput(GetParam(), T);
1052   IN->setPredicate(pred2);
1053   SaveNode *O = F_->createSave("ret", IN);
1054   O->setPredicate(pred3);
1055 
1056   EXPECT_EQ(F_->getNodes().size(), 3);
1057   EXPECT_EQ(IN->getNthResult(0).dims(), llvm::makeArrayRef(transposedDims));
1058 
1059   ::glow::optimize(F_, CompilationMode::Infer);
1060 
1061   EXPECT_EQ(O->getPredicate().getNode(), pred3);
1062   // Expecting Transpose->Output rather than N->Output.
1063   auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1064   ASSERT_NE(transpose, nullptr);
1065   EXPECT_EQ(transpose->getPredicate().getNode(), pred2);
1066   Node *N = transpose->getInput();
1067   ASSERT_TRUE(N);
1068   EXPECT_EQ(N->getPredicate().getNode(), pred2);
1069 
1070   // Test correct input.
1071   if (GetParam() == TestSinkTransposeNodesKind::BatchNormalization) {
1072     ASSERT_EQ(BatchNormalizationNode::InputIdx, 0);
1073   } else {
1074     ASSERT_EQ(N->getNumInputs(), 1);
1075   }
1076 
1077   // Check that the dimensions of the input and output have been
1078   // updated to compensate for the absence of the transpose.
1079   EXPECT_EQ(transpose->getInput().dims(), llvm::makeArrayRef(origDims));
1080   EXPECT_EQ(N->getNthInput(0).dims(), llvm::makeArrayRef(origDims));
1081   EXPECT_EQ(F_->getNodes().size(), 3);
1082 }
1083 
1084 GLOW_INSTANTIATE_TEST_SUITE_P(
1085     TestSinkTranspose, GraphOptzSinkTransposeBelowParametrized,
1086     ::testing::Values(TestSinkTransposeNodesKind::BatchNormalization,
1087                       TestSinkTransposeNodesKind::Relu,
1088                       TestSinkTransposeNodesKind::Clip,
1089                       TestSinkTransposeNodesKind::Sigmoid,
1090                       TestSinkTransposeNodesKind::Tanh,
1091                       TestSinkTransposeNodesKind::Quantize));
1092 
1093 TEST_F(GraphOptz, SinkTransposeBelowDequantize) {
1094   auto *in =
1095       mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input", false);
1096   auto *quantize = F_->createQuantize(
1097       "quantize", in, mod_.uniqueType(ElemKind::Int8QTy, in->dims(), 0.01, 2));
1098   auto *tile = F_->createTile("tile", quantize, 3, 0);
1099   auto *transpose = F_->createTranspose("transpose", tile, NHWC2NCHW);
1100   auto *deq = F_->createDequantize("dequantize", transpose, ElemKind::FloatTy);
1101   SaveNode *O = F_->createSave("out", deq);
1102 
1103   optimizedF_ = optimizeFunction(F_);
1104 
1105   EXPECT_EQ(F_->getNodes().size(), 5);
1106   EXPECT_EQ(optimizedF_->getNodes().size(), 5);
1107 
1108   auto *optOut = findFunctionNodeByName<SaveNode>(optimizedF_, O->getName());
1109   EXPECT_TRUE(llvm::isa<TransposeNode>(optOut->getInput().getNode()));
1110 
1111   bindings_.allocate(mod_.getPlaceholders());
1112   bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1113   checkNumericalEquivalence();
1114 }
1115 
1116 /// Sinking Transpose below Rescale enables, for example, folding Rescale into Convolution.
1117 TEST_F(GraphOptz, sinkTransposeBelowRescale) {
1118   // Inputs.
1119   const dim_t origDims[] = {1, 5, 10, 15};
1120   const dim_t transposedDims[] = {1, 15, 5, 10};
1121   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, origDims, 0.1, 0,
1122                                        "input", false);
1123   auto *filter = mod_.createPlaceholder(ElemKind::Int8QTy, {15, 1, 1, 15}, 0.1,
1124                                         0, "filter", false);
1125   auto *bias =
1126       mod_.createPlaceholder(ElemKind::Int32QTy, {15}, 0.01, 0, "bias", false);
1127 
1128   // Graph.
1129   ConvolutionNode *conv =
1130       F_->createConv("conv", input, filter, bias, input->getType(), {1, 1},
1131                      {1, 1}, {0, 0, 0, 0}, 1);
1132 
1133   auto *T = F_->createTranspose("transpose", conv, NHWC2NCHW);
1134   auto *RT = mod_.uniqueType(ElemKind::Int8QTy, T->getResult().dims(), 0.2, 0);
1135   auto *R = F_->createRescaleQuantized("rescale", T, RT);
1136   SaveNode *O = F_->createSave("ret", R);
1137 
1138   EXPECT_EQ(F_->getNodes().size(), 4);
1139   EXPECT_EQ(RT->dims(), llvm::makeArrayRef(transposedDims));
1140 
1141   ::glow::optimize(F_, CompilationMode::Infer);
1142 
1143   // Expecting Transpose->Output rather than Rescale->Output.
1144   auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1145   ASSERT_NE(transpose, nullptr);
1146   ASSERT_TRUE(llvm::isa<ConvolutionNode>(transpose->getInput()));
1147   auto &convTRInput = transpose->getInput();
1148   // Check that the dimensions of the input and output have been
1149   // updated to compensate for the absence of the transpose.
1150   EXPECT_EQ(convTRInput.dims(), llvm::makeArrayRef(origDims));
1151   EXPECT_EQ(convTRInput.getNode()->getNthInput(0).dims(),
1152             llvm::makeArrayRef(origDims));
1153   EXPECT_EQ(F_->getNodes().size(), 3);
1154 }
1155 
1156 TEST_F(GraphOptz, cancelTwoTransposes) {
1157   const dim_t origDims[] = {1, 5, 10, 15};
1158   Placeholder *A =
1159       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1160   Node *T1 = F_->createTranspose("transpose", A, NCHW2NHWC);
1161   Node *T2 = F_->createTranspose("transpose", T1, NHWC2NCHW);
1162   ReluNode *K = F_->createRELU("relu", T2);
1163   SaveNode *save = F_->createSave("ret", K);
1164 
1165   EXPECT_EQ(K->getInput().dims(), llvm::makeArrayRef(origDims));
1166   EXPECT_EQ(F_->getNodes().size(), 4);
1167 
1168   optimizedF_ = optimizeFunction(F_);
1169 
1170   EXPECT_EQ(optimizedF_->getNodes().size(), 2);
1171 
1172   for (auto &N : optimizedF_->getNodes()) {
1173     if (N.getKind() == Kinded::Kind::SaveNodeKind) {
1174       save = llvm::dyn_cast<SaveNode>(&N);
1175     }
1176   }
1177 
1178   ReluNode *relu = llvm::dyn_cast<ReluNode>(save->getInput());
1179   ASSERT_TRUE(relu);
1180   EXPECT_EQ(relu->getResult().dims(), llvm::makeArrayRef(origDims));
1181   EXPECT_EQ(relu->getInput().getNode(), A);
1182 
1183   bindings_.allocate(mod_.getPlaceholders());
1184   bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1185 
1186   checkNumericalEquivalence();
1187 }
1188 
1189 /// Make sure the predicates don't get in the way of the
1190 /// transpose(transpose) => identity and that they are
1191 /// preserved.
1192 TEST_F(GraphOptz, cancelTwoTransposesWithPredicate) {
1193   const dim_t origDims[] = {1, 5, 10, 15};
1194   Node *A = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1195   Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1196   Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1197   Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1198   Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1199   Node *T1 = F_->createTranspose("transpose", A, NCHW2NHWC);
1200   T1->setPredicate(pred1);
1201   Node *T2 = F_->createTranspose("transpose", T1, NHWC2NCHW);
1202   T2->setPredicate(pred2);
1203   ReluNode *K = F_->createRELU("relu", T2);
1204   K->setPredicate(pred3);
1205   SaveNode *save = F_->createSave("ret", K);
1206   save->setPredicate(pred4);
1207 
1208   EXPECT_EQ(K->getInput().dims(), llvm::makeArrayRef(origDims));
1209   EXPECT_EQ(F_->getNodes().size(), 4);
1210 
1211   ::glow::optimize(F_, CompilationMode::Infer);
1212 
1213   EXPECT_EQ(F_->getNodes().size(), 2);
1214   EXPECT_EQ(save->getPredicate().getNode(), pred4);
1215   ReluNode *relu = llvm::dyn_cast<ReluNode>(save->getInput());
1216   ASSERT_TRUE(relu);
1217   EXPECT_EQ(relu->getPredicate().getNode(), pred3);
1218   EXPECT_EQ(relu->getResult().dims(), llvm::makeArrayRef(origDims));
1219   EXPECT_EQ(relu->getInput().getNode(), A);
1220 }
1221 
1222 TEST_F(GraphOptz, removeIdentityTranspose) {
1223   const dim_t origDims[] = {1, 5, 10, 15};
1224   Placeholder *A =
1225       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1226   TransposeNode *T = F_->createTranspose("transpose", A, {0, 1, 2, 3});
1227   ReluNode *K = F_->createRELU("relu", T);
1228   F_->createSave("ret", K);
1229 
1230   EXPECT_EQ(F_->getNodes().size(), 3);
1231   EXPECT_EQ(K->getInput().getNode(), T);
1232 
1233   ::glow::optimize(F_, CompilationMode::Infer);
1234 
1235   EXPECT_EQ(F_->getNodes().size(), 2);
1236   EXPECT_EQ(K->getInput().getNode(), A);
1237   // Make sure we didn't mess up the dimensions of the
1238   // variable while eliminating the transpose.
1239   EXPECT_EQ(A->dims(), llvm::makeArrayRef(origDims));
1240 }
1241 
1242 /// Check that the predicates don't get in the way of
1243 /// the identity transpose removal, while still being
1244 /// preserved.
1245 TEST_F(GraphOptz, removeIdentityTransposeWithPredicate) {
1246   const dim_t origDims[] = {1, 5, 10, 15};
1247   Placeholder *A =
1248       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1249   Placeholder *pred1 =
1250       mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1251   Placeholder *pred2 =
1252       mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1253   Placeholder *pred3 =
1254       mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1255   TransposeNode *T = F_->createTranspose("transpose", A, {0, 1, 2, 3});
1256   T->setPredicate(pred1);
1257   ReluNode *K = F_->createRELU("relu", T);
1258   K->setPredicate(pred2);
1259   SaveNode *save = F_->createSave("ret", K);
1260   save->setPredicate(pred3);
1261 
1262   EXPECT_EQ(F_->getNodes().size(), 3);
1263   EXPECT_EQ(K->getInput().getNode(), T);
1264 
1265   ::glow::optimize(F_, CompilationMode::Infer);
1266   EXPECT_EQ(F_->getNodes().size(), 2);
1267   EXPECT_EQ(save->getPredicate().getNode(), pred3);
1268   EXPECT_EQ(save->getInput().getNode(), K);
1269   EXPECT_EQ(K->getInput().getNode(), A);
1270   EXPECT_EQ(K->getPredicate().getNode(), pred2);
1271   // Make sure we didn't mess up the dimensions of the variable while
1272   // eliminating the transpose.
1273   EXPECT_EQ(A->dims(), llvm::makeArrayRef(origDims));
1274 }
1275 
1276 /// Check that consecutive non-inverse transposes are merged
1277 /// into an equivalent single transpose node.
1278 TEST_F(GraphOptz, mergeNonInverseTransposes) {
1279   const dim_t origDims[] = {1, 5, 10, 15};
1280   const dim_t finalDims[] = {5, 1, 15, 10};
1281 
1282   Placeholder *A =
1283       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1284   TransposeNode *T1 = F_->createTranspose("transpose", A, {0, 3, 2, 1});
1285   TransposeNode *T2 = F_->createTranspose("transpose", T1, {0, 2, 3, 1});
1286   TransposeNode *T3 = F_->createTranspose("transpose", T2, {1, 0, 3, 2});
1287   TransposeNode *T4 = F_->createTranspose("transpose", T3, {3, 1, 2, 0});
1288 
1289   // Intermediate dims after each transpose:
1290   // Initial : {1, 5, 10, 15}
1291   // After T1: {1, 15, 10, 5}
1292   // After T2: {1, 10, 5, 15}
1293   // After T3: {10, 1, 15, 5}
1294   // After T4: {5, 1, 15, 10}
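  // Note: composing the four shuffles yields the single permutation
  // {1, 0, 3, 2}, which maps {1, 5, 10, 15} directly to {5, 1, 15, 10}; the
  // merged transpose checked below is expected to apply this combined shuffle.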
1295 
1296   SaveNode *save = F_->createSave("ret", T4);
1297 
1298   EXPECT_EQ(F_->getNodes().size(), 5);
1299 
1300   optimizedF_ = optimizeFunction(F_);
1301   // Find save node in the optimized graph.
1302   for (auto &N : optimizedF_->getNodes()) {
1303     if (N.getKind() == Kinded::Kind::SaveNodeKind) {
1304       save = llvm::dyn_cast<SaveNode>(&N);
1305     }
1306   }
1307   // Get the last transpose node in the optimized graph.
1308   auto *TR = llvm::dyn_cast<TransposeNode>(save->getInput());
1309   ASSERT_NE(TR, nullptr);
1310 
1311   EXPECT_EQ(optimizedF_->getNodes().size(), 2);
1312   EXPECT_EQ(TR->getResult().dims(), llvm::makeArrayRef(finalDims));
1313   EXPECT_EQ(A->getNthResult(0).dims(), llvm::makeArrayRef(origDims));
1314   EXPECT_EQ(TR->getInput().getNode(), A);
1315 
1316   bindings_.allocate(mod_.getPlaceholders());
1317   bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1318   checkNumericalEquivalence();
1319 }
1320 
1321 TEST_F(GraphOptz, sinkTransposeBelowArithmeticNodes) {
1322   const dim_t origDims[] = {1, 5, 10, 15};
1323   Node *A1 =
1324       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1325   Node *A2 =
1326       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1327   Node *T1 = F_->createTranspose("transpose1", A1, NHWC2NCHW);
1328   Node *T2 = F_->createTranspose("transpose2", A2, NHWC2NCHW);
1329   Node *K = F_->createAdd("arith", T1, T2);
1330   SaveNode *O = F_->createSave("ret", K);
1331 
1332   EXPECT_EQ(F_->getNodes().size(), 4);
1333 
1334   ::glow::optimize(F_, CompilationMode::Infer);
1335 
1336   // Expecting Transpose->Output rather than Add->Output.
1337   auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1338   ASSERT_NE(transpose, nullptr);
1339   auto *add = llvm::dyn_cast<AddNode>(transpose->getInput());
1340   ASSERT_TRUE(add);
1341   // Check that the dimensions of the input and output have been
1342   // updated to compensate for the absence of the transpose.
1343   EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims));
1344   EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims));
1345   EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims));
1346   EXPECT_EQ(add->getLHS().getNode(), A1);
1347   EXPECT_EQ(add->getRHS().getNode(), A2);
1348 
1349   EXPECT_EQ(F_->getNodes().size(), 3);
1350 }
1351 
1352 /// Check that Transpose node is sunk below arithmetic nodes when one of the
1353 /// operands is a Constant.
1354 TEST_F(GraphOptz, sinkTransposeBelowArithmeticNodesWithConstantOperand) {
1355   const dim_t origDims[] = {1, 5, 10, 15};
1356   const dim_t transposedDims[] = {1, 15, 5, 10};
1357 
1358   // Create one subgraph in which the Constant is the LHS operand of the Add.
1359   Constant *C1 = mod_.createConstant(ElemKind::FloatTy, transposedDims, "C1");
1360   // Initialize the payload before optimization so that it can be copied to the
1361   // new Constant that will be created by the GraphOptimizer.
1362   C1->getHandle().randomize(-1, 1, mod_.getPRNG());
1363 
1364   auto *P1 = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "P1", false);
1365   auto *T1 = F_->createTranspose("T1", P1, NHWC2NCHW);
1366   auto *A1 = F_->createAdd("A1", C1, T1);
1367   SaveNode *S1 = F_->createSave("S1", A1);
1368 
1369   // Create one subgraph in which the Constant is the RHS operand of the Add.
1370   Constant *C2 = mod_.createConstant(ElemKind::FloatTy, transposedDims, "C2");
1371   // Initialize the payload before optimization so that it can be copied to the
1372   // new Constant that will be created by the GraphOptimizer.
1373   C2->getHandle().randomize(-1, 1, mod_.getPRNG());
1374 
1375   auto *P2 = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "P2", false);
1376   auto *T2 = F_->createTranspose("T2", P2, NHWC2NCHW);
1377   auto *A2 = F_->createAdd("A2", T2, C2);
1378   SaveNode *S2 = F_->createSave("S2", A2);
1379 
1380   EXPECT_EQ(F_->getNodes().size(), 6);
1381 
1382   optimizedF_ = optimizeFunction(F_);
1383 
1384   // Find the SaveNodes of the optimized graph.
1385   for (auto &N : optimizedF_->getNodes()) {
1386     if (N.getKind() == Kinded::Kind::SaveNodeKind) {
1387       if (N.getName() == S1->getName()) {
1388         S1 = llvm::dyn_cast<SaveNode>(&N);
1389       }
1390 
1391       if (N.getName() == S2->getName()) {
1392         S2 = llvm::dyn_cast<SaveNode>(&N);
1393       }
1394     }
1395   }
1396 
1397   // Expecting Transpose->Output rather than Add->Output.
1398   auto *transpose = llvm::dyn_cast<TransposeNode>(S1->getInput());
1399   ASSERT_NE(transpose, nullptr);
1400   auto *add = llvm::dyn_cast<AddNode>(transpose->getInput());
1401   ASSERT_TRUE(add);
1402   // Check that the dimensions of the input and output of the add have been
1403   // updated to compensate for the absence of the transpose.
1404   EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims));
1405   EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims));
1406   EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims));
1407   EXPECT_EQ(add->getRHS().getNode(), P1);
1408 
1409   // Repeat checks for other subgraph.
1410   transpose = llvm::dyn_cast<TransposeNode>(S2->getInput());
1411   ASSERT_NE(transpose, nullptr);
1412   add = llvm::dyn_cast<AddNode>(transpose->getInput());
1413   ASSERT_TRUE(add);
1414   EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims));
1415   EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims));
1416   EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims));
1417   EXPECT_EQ(add->getLHS().getNode(), P2);
1418 
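  // Each subgraph is expected to keep three nodes: an Add fed by the
  // Placeholder and by a transposed copy of the Constant, a Transpose sunk
  // below the Add, and a Save.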
1419   EXPECT_EQ(optimizedF_->getNodes().size(), 6);
1420 
1421   // Check that the original and optimized functions are numerically equivalent.
1422   // This indirectly checks that the Constant has been transposed properly.
1423   bindings_.allocate(mod_.getPlaceholders());
1424   bindings_.get(P1)->getHandle().randomize(-1, 1, mod_.getPRNG());
1425   bindings_.get(P2)->getHandle().randomize(-1, 1, mod_.getPRNG());
1426 
1427   checkNumericalEquivalence();
1428 }
1429 
1430 /// Check that the predicates are properly preserved while doing
1431 /// the add(transpose, transpose) => transpose(add) transformation.
1432 TEST_F(GraphOptz, sinkTransposeBelowArithmeticNodesWithPredicate) {
1433   const dim_t origDims[] = {1, 5, 10, 15};
1434   Node *A1 =
1435       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1436   Node *A2 =
1437       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1438   Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1439   Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1440   Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1441   Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1442   Node *T1 = F_->createTranspose("transpose1", A1, NHWC2NCHW);
1443   T1->setPredicate(pred1);
1444   Node *T2 = F_->createTranspose("transpose2", A2, NHWC2NCHW);
1445   T2->setPredicate(pred2);
1446   Node *K = F_->createAdd("arith", T1, T2);
1447   K->setPredicate(pred3);
1448   SaveNode *O = F_->createSave("ret", K);
1449   O->setPredicate(pred4);
1450 
1451   EXPECT_EQ(F_->getNodes().size(), 4);
1452 
1453   ::glow::optimize(F_, CompilationMode::Infer);
1454 
1455   EXPECT_EQ(O->getPredicate().getNode(), pred4);
1456   // Expecting Transpose->Output rather than Add->Output.
1457   auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1458   ASSERT_NE(transpose, nullptr);
1459   EXPECT_EQ(transpose->getPredicate().getNode(), pred3);
1460   auto *add = llvm::dyn_cast<AddNode>(transpose->getInput());
1461   ASSERT_TRUE(add);
1462   EXPECT_EQ(add->getPredicate().getNode(), pred3);
1463   // Check that the dimensions of the input and output have been
1464   // updated to compensate for the absence of the transpose.
1465   EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims));
1466   EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims));
1467   EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims));
1468   EXPECT_EQ(add->getLHS().getNode(), A1);
1469   EXPECT_EQ(add->getRHS().getNode(), A2);
1470 
1471   EXPECT_EQ(F_->getNodes().size(), 3);
1472 }
1473 
1474 TEST_F(GraphOptz, sinkReluBelowConcatNodes) {
1475   const dim_t origDims[] = {1, 5, 10, 15};
1476   const dim_t origDimsConcat[] = {1, 10, 10, 15};
1477   Node *A1 =
1478       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1479   Node *A2 =
1480       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1481   Node *R1 = F_->createRELU("relu1", A1);
1482   Node *R2 = F_->createRELU("relu2", A2);
1483   Node *CN = F_->createConcat("concat", {R1, R2}, 1);
1484   SaveNode *O = F_->createSave("ret", CN);
1485 
1486   EXPECT_EQ(F_->getNodes().size(), 4);
1487 
1488   ::glow::optimize(F_, CompilationMode::Infer);
1489 
1490   // Expecting RELU->Output rather than Concat->Output.
1491   auto *relu = llvm::dyn_cast<ReluNode>(O->getInput());
1492   ASSERT_NE(relu, nullptr);
1493   auto *concat = llvm::dyn_cast<ConcatNode>(relu->getInput());
1494   ASSERT_TRUE(concat);
1495   // Check that the dimensions of the concat's inputs and output are as
1496   // expected after sinking the relu below the concat.
1497   EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat));
1498   EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims));
1499   EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims));
1500   EXPECT_EQ(concat->getInputs()[0].getNode(), A1);
1501   EXPECT_EQ(concat->getInputs()[1].getNode(), A2);
1502 
1503   EXPECT_EQ(F_->getNodes().size(), 3);
1504 }
1505 
1506 /// Check that the predicates are properly preserved while
1507 /// sinking relu nodes below a concat.
1508 TEST_F(GraphOptz, sinkReluBelowConcatNodesWithPredicate) {
1509   const dim_t origDims[] = {1, 5, 10, 15};
1510   const dim_t origDimsConcat[] = {1, 10, 10, 15};
1511   Node *A1 =
1512       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1513   Node *A2 =
1514       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1515   Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1516   Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1517   Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1518   Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1519   Node *R1 = F_->createRELU("relu1", A1);
1520   R1->setPredicate(pred1);
1521   Node *R2 = F_->createRELU("relu2", A2);
1522   R2->setPredicate(pred2);
1523   Node *CN = F_->createConcat("concat", {R1, R2}, 1);
1524   CN->setPredicate(pred3);
1525   SaveNode *O = F_->createSave("ret", CN);
1526   O->setPredicate(pred4);
1527 
1528   EXPECT_EQ(F_->getNodes().size(), 4);
1529 
1530   ::glow::optimize(F_, CompilationMode::Infer);
1531 
1532   // Expecting RELU->Output rather than Concat->Output.
1533   EXPECT_EQ(O->getPredicate().getNode(), pred4);
1534   auto *relu = llvm::dyn_cast<ReluNode>(O->getInput());
1535   ASSERT_NE(relu, nullptr);
1536   EXPECT_EQ(relu->getPredicate().getNode(), pred3);
1537   auto *concat = llvm::dyn_cast<ConcatNode>(relu->getInput());
1538   ASSERT_TRUE(concat);
1539   EXPECT_EQ(concat->getPredicate().getNode(), pred3);
1540   // Check that the dimensions of the concat's inputs and output are as
1541   // expected after sinking the relu below the concat.
1542   EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat));
1543   EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims));
1544   EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims));
1545   EXPECT_EQ(concat->getInputs()[0].getNode(), A1);
1546   EXPECT_EQ(concat->getInputs()[1].getNode(), A2);
1547 
1548   EXPECT_EQ(F_->getNodes().size(), 3);
1549 }
1550 
1551 TEST_F(GraphOptz, sinkTransposeBelowConcatNodes) {
1552   const dim_t origDims[] = {1, 5, 10, 15};
1553   const dim_t origDimsConcat[] = {1, 5, 20, 15};
1554   Node *A1 =
1555       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1556   Node *A2 =
1557       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1558   Node *T1 = F_->createTranspose("transpose", A1, NCHW2NHWC);
1559   Node *T2 = F_->createTranspose("transpose", A2, NCHW2NHWC);
1560   Node *CN = F_->createConcat("concat", {T1, T2}, 1);
1561   SaveNode *O = F_->createSave("ret", CN);
1562 
1563   EXPECT_EQ(F_->getNodes().size(), 4);
1564 
1565   ::glow::optimize(F_, CompilationMode::Infer);
1566 
1567   // Expecting Transpose->Output rather than Concat->Output.
1568   auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1569   ASSERT_NE(transpose, nullptr);
1570   auto *concat = llvm::dyn_cast<ConcatNode>(transpose->getInput());
1571   ASSERT_TRUE(concat);
1572   // Check that the dimensions of the input and output have been
1573   // updated to compensate for the absence of the transpose.
1574   EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat));
1575   EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims));
1576   EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims));
1577   EXPECT_EQ(concat->getInputs()[0].getNode(), A1);
1578   EXPECT_EQ(concat->getInputs()[1].getNode(), A2);
1579 
1580   EXPECT_EQ(F_->getNodes().size(), 3);
1581 }
1582 
1583 /// Check that the predicates are properly preserved while doing
1584 /// the concat(transpose, transpose) => transpose(concat) transformation.
1585 TEST_F(GraphOptz, sinkTransposeBelowConcatWithPredicate) {
1586   const dim_t origDims[] = {1, 5, 10, 15};
1587   const dim_t origDimsConcat[] = {1, 5, 20, 15};
1588   Node *A1 =
1589       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1590   Node *A2 =
1591       mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1592   Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1593   Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1594   Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1595   Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1596   Node *T1 = F_->createTranspose("transpose", A1, NCHW2NHWC);
1597   T1->setPredicate(pred1);
1598   Node *T2 = F_->createTranspose("transpose", A2, NCHW2NHWC);
1599   T2->setPredicate(pred2);
1600   Node *CN = F_->createConcat("concat", {T1, T2}, 1);
1601   CN->setPredicate(pred3);
1602   SaveNode *O = F_->createSave("ret", CN);
1603   O->setPredicate(pred4);
1604 
1605   EXPECT_EQ(F_->getNodes().size(), 4);
1606 
1607   ::glow::optimize(F_, CompilationMode::Infer);
1608 
1609   EXPECT_EQ(O->getPredicate().getNode(), pred4);
1610   // Expecting Transpose->Output rather than Concat->Output.
1611   auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1612   ASSERT_NE(transpose, nullptr);
1613   EXPECT_EQ(transpose->getPredicate().getNode(), pred3);
1614   auto *concat = llvm::dyn_cast<ConcatNode>(transpose->getInput());
1615   ASSERT_TRUE(concat);
1616   EXPECT_EQ(concat->getPredicate().getNode(), pred3);
1617   // Check that the dimensions of the input and output have been
1618   // updated to compensate for the absence of the transpose.
1619   EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat));
1620   EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims));
1621   EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims));
1622   EXPECT_EQ(concat->getInputs()[0].getNode(), A1);
1623   EXPECT_EQ(concat->getInputs()[1].getNode(), A2);
1624 
1625   EXPECT_EQ(F_->getNodes().size(), 3);
1626 }
1627 
1628 TEST_F(GraphOptz, sinkTransposeBelowPad) {
1629   // The shape of the graph before the optimization.
1630   const dim_t inputDims[] = {1, 5, 10, 15};
1631   const dim_t outTransposeDims[] = {1, 10, 15, 5};
1632   const dim_t outPadDims[] = {5, 18, 25, 11};
1633   // Padding before the optimization.
1634   int pads[] = {0, 2, 3, 1, 4, 6, 7, 5};
1635 
1636   // The shape of the graph after the optimization.
1637   const dim_t outPadDimsAfterOptim[] = {5, 11, 18, 25};
1638   const dim_t outTransposeDimsAfterOptims[] = {5, 18, 25, 11};
1639   // Padding after the optimization.
1640   int padsAfterOptim[] = {0, 1, 2, 3, 4, 5, 6, 7};
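  // Note: the per-dimension (begin, end) pad pairs are unchanged; sinking the
  // transpose only moves them from the transposed NHWC positions back to the
  // original NCHW positions.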
1641 
1642   // Create the initial graph.
1643   Node *A =
1644       mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
1645   auto outTy = mod_.uniqueType(ElemKind::FloatTy, outPadDims);
1646   TransposeNode *T = F_->createTranspose("transpose", A, NCHW2NHWC);
1647   Node *P = F_->createPad("pad", T, outTy, PaddingMode::CONSTANT, pads, 23.f);
1648   EXPECT_EQ(T->getResult().dims(), llvm::makeArrayRef(outTransposeDims));
1649   SaveNode *O = F_->createSave("ret", P);
1650 
1651   EXPECT_EQ(F_->getNodes().size(), 3);
1652 
1653   ::glow::optimize(F_, CompilationMode::Infer);
1654 
1655   // Check the graph structure and additional properties after optimization.
1656   auto *trans = llvm::dyn_cast<TransposeNode>(O->getInput());
1657   ASSERT_NE(trans, nullptr);
1658   EXPECT_EQ(trans->getResult().dims(),
1659             llvm::makeArrayRef(outTransposeDimsAfterOptims));
1660   auto *pad = llvm::dyn_cast<PadNode>(trans->getInput().getNode());
1661   ASSERT_NE(pad, nullptr);
1662 
1663   EXPECT_EQ(pad->getPads(), llvm::makeArrayRef(padsAfterOptim));
1664   EXPECT_EQ(pad->getResult().dims(), llvm::makeArrayRef(outPadDimsAfterOptim));
1665 
1666   EXPECT_EQ(F_->getNodes().size(), 3);
1667 }
1668 
1669 TEST_F(GraphOptz, sinkTransposeBelowRelu) {
1670   // Define a type with custom alignments.
1671   Type typeWithAlignments(ElemKind::FloatTy, {2, 3, 4, 5}, {1, 1, 32, 1});
1672   Type transposedTypeWithAlignments(ElemKind::FloatTy, {2, 4, 5, 3},
1673                                     {1, 1, 32, 1});
1674   auto modTyWithAlignments = mod_.uniqueType(typeWithAlignments);
1675   auto modTransposedTyWithAlignments =
1676       mod_.uniqueType(transposedTypeWithAlignments);
1677   auto *A1 = mod_.createPlaceholder(modTyWithAlignments, "input1", false);
1678   auto *T1 = F_->createTranspose("transpose", A1, NCHW2NHWC);
1679   T1->setType(0, modTransposedTyWithAlignments);
1680   auto *RN = F_->createRELU("relu", T1);
1681   SaveNode *O = F_->createSave("ret", RN);
1682 
1683   EXPECT_EQ(F_->getNodes().size(), 3);
1684 
1685   ::glow::optimize(F_, CompilationMode::Infer);
1686 
1687   // Expecting Transpose->Output rather than Relu->Output, because the
1688   // Transpose was sunk below the Relu.
1689   auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1690   ASSERT_NE(transpose, nullptr);
1691   auto *relu = llvm::dyn_cast<ReluNode>(transpose->getInput());
1692   ASSERT_TRUE(relu);
1693   // Check that alignments are preserved by optimizations.
1694   ASSERT_TRUE(relu->getInput().getType()->isEqual(modTyWithAlignments));
1695   ASSERT_TRUE(transpose->getInput().getType()->isEqual(modTyWithAlignments));
1696   ASSERT_TRUE(
1697       transpose->getResult().getType()->isEqual(modTransposedTyWithAlignments));
1698 
1699   EXPECT_EQ(F_->getNodes().size(), 3);
1700   ASSERT_TRUE(F_->verify());
1701 }
1702 
1703 TEST_F(GraphOptz, mergeConcatNodes) {
1704   Node *A1 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input1",
1705                                     false);
1706   Node *A2 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input2",
1707                                     false);
1708   Node *A3 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input3",
1709                                     false);
1710   Node *A4 =
1711       mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 5, 15}, "input4", false);
1712   Node *A5 =
1713       mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 5, 15}, "input5", false);
1714 
1715   Node *CN1 = F_->createConcat("concat1", {A1, A2}, 1);
1716   Node *CN2 = F_->createConcat("concat2", {A1, CN1}, 1);
1717   Node *CN3 = F_->createConcat("concat3", {A4, A5}, 2);
1718   Node *CN4 = F_->createConcat("concat4", {A3, CN2, CN3}, 1);
1719   Node *O = F_->createSave("ret", CN4);
1720 
1721   EXPECT_EQ(F_->getNodes().size(), 5);
1722 
1723   ::glow::optimize(F_, CompilationMode::Train);
1724 
1725   // It is expected that the optimization transforms
1726   // concat4(1, A3, concat2(1, A1, concat1(1, A1, A2)), concat3(2, A4, A5))
1727   // into
1728   // concat4(1, A3, A1, A1, A2, concat3(2, A4, A5))
1729 
1730   EXPECT_TRUE(llvm::isa<SaveNode>(O));
1731 
1732   auto *CN =
1733       llvm::dyn_cast<ConcatNode>(llvm::dyn_cast<SaveNode>(O)->getInput());
1734   EXPECT_TRUE(CN);
1735 
1736   // The merged ConcatNode should have 5 inputs.
1737   EXPECT_EQ(CN->getInputs().size(), 5);
1738 
1739   // CN1 should be merged into a new CN2 and later into a new CN4 and removed by
1740   // the optimizations.
1741   EXPECT_FALSE(functionContainsNode(F_, CN1));
1742 
1743   // CN2 should be merged into a new CN4 and removed by the optimizations.
1744   EXPECT_FALSE(functionContainsNode(F_, CN2));
1745 
1746   // CN3 should not be merged into CN4 and should not be removed,
1747   // because CN4 and CN3 have a different dimension parameter.
1748   EXPECT_TRUE(functionContainsNode(F_, CN3));
1749 
1750   // The CN4 concat node should be replaced by a merged concat node.
1751   EXPECT_FALSE(functionContainsNode(F_, CN4));
1752 
1753   EXPECT_EQ(F_->getNodes().size(), 3);
1754 }
1755 
1756 TEST_F(GraphOptz, CSE) {
1757   Node *A1 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input1",
1758                                     false);
1759   Node *A2 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input2",
1760                                     false);
1761 
1762   Node *CN1 = F_->createConcat("concat1", {A1, A2}, 1);
1763   Node *CN2 = F_->createConcat("concat2", {A1, A2}, 1);
1764   Node *CN3 = F_->createConcat("concat3", {CN1, CN2}, 2);
1765   Node *O = F_->createSave("ret", CN3);
1766 
1767   EXPECT_EQ(F_->getNodes().size(), 4);
1768 
1769   ::glow::optimize(F_, CompilationMode::Train);
1770 
1771   EXPECT_TRUE(llvm::isa<SaveNode>(O));
1772 
1773   auto *CN =
1774       llvm::dyn_cast<ConcatNode>(llvm::dyn_cast<SaveNode>(O)->getInput());
1775   EXPECT_TRUE(CN);
1776 
1777   // The merged ConcatNode should have 2 inputs.
1778   EXPECT_EQ(CN->getInputs().size(), 2);
1779 
1780   // CN1 should not be removed.
1781   EXPECT_TRUE(functionContainsNode(F_, CN1));
1782 
1783   // CSE should replace CN2 by CN1 and remove CN2.
1784   EXPECT_FALSE(functionContainsNode(F_, CN2));
1785 
1786   EXPECT_EQ(F_->getNodes().size(), 3);
1787 }
1788 
1789 TEST_F(GraphOptz, SliceOfSplatNode) {
1790   Type t(ElemKind::FloatTy, {1000, 1000, 1000});
1791   Node *Z = F_->createSplat("zero", &t, 0.);
1792   Node *S = F_->createSlice("slice", Z, {5, 15, 42}, {99, 88, 77});
1793   Node *O = F_->createSave("ret", S);
1794 
1795   EXPECT_EQ(F_->getNodes().size(), 3);
1796 
1797   ::glow::optimize(F_, CompilationMode::Train);
1798 
1799   EXPECT_EQ(F_->getNodes().size(), 2);
1800 
1801   EXPECT_TRUE(llvm::isa<SaveNode>(O));
1802 
1803   auto *CN = llvm::dyn_cast<SplatNode>(llvm::dyn_cast<SaveNode>(O)->getInput());
1804   EXPECT_TRUE(CN);
1805 
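  // The new splat's shape should equal the slice extent:
  // {99 - 5, 88 - 15, 77 - 42} = {94, 73, 35}.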
1806   EXPECT_TRUE(CN->getResult().getType()->dims().equals({94, 73, 35}));
1807 }
1808 
1809 /// Test Clip(Splat(args)) -> Splat(args').
1810 TEST_F(GraphOptz, ClipOfSplatNode) {
1811   Type T(ElemKind::FloatTy, {10, 10});
1812   SplatNode *splat = F_->createSplat("zero", &T, 5);
1813   ClipNode *clipMin = F_->createClip("clip", splat, 10, 15);
1814   ClipNode *clipMax = F_->createClip("clip", splat, 0, 2);
1815   ClipNode *clipSame = F_->createClip("clip", splat, 0, 10);
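  // The splat value 5 lies below clipMin's range [10, 15], above clipMax's
  // range [0, 2], and inside clipSame's range [0, 10], so the clips are
  // expected to fold to splats of 10 and 2 and to the original splat,
  // respectively.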
1816   SaveNode *saveMin = F_->createSave("saveMin", clipMin);
1817   SaveNode *saveMax = F_->createSave("saveMax", clipMax);
1818   SaveNode *saveSame = F_->createSave("saveSame", clipSame);
1819 
1820   // Start with one splat, three clips, three saves.
1821   EXPECT_EQ(F_->getNodes().size(), 7);
1822 
1823   ::glow::optimize(F_, CompilationMode::Infer);
1824 
1825   // We will end up with three Splats and three saves.
1826   EXPECT_EQ(F_->getNodes().size(), 6);
1827 
1828   SplatNode *splatMin = llvm::dyn_cast<SplatNode>(saveMin->getInput());
1829   ASSERT_TRUE(splatMin);
1830   EXPECT_EQ(splatMin->getValue(), 10);
1831 
1832   SplatNode *splatMax = llvm::dyn_cast<SplatNode>(saveMax->getInput());
1833   ASSERT_TRUE(splatMax);
1834   EXPECT_EQ(splatMax->getValue(), 2);
1835 
1836   ASSERT_EQ(saveSame->getInput().getNode(), splat);
1837   EXPECT_EQ(splat->getValue(), 5);
1838 }
1839 
1840 TEST_F(GraphOptz, ZeroArithmetic) {
1841   // Tests the identities: [0 + X = X] [0 * X = 0] [0 / X = 0] [ X - 0 = X]
1842 
1843   auto *input =
1844       mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input", true);
1845 
1846   // This builds the expression: ((0 / I) + (0 + I) + (0 * I)) - 0
1847 
1848   auto *zero = F_->createSplat("zero", input->getType(), 0.);
1849 
1850   auto *div = F_->createDiv("div", zero, input); // -> zero
1851 
1852   auto *add = F_->createAdd("add", zero, input); // -> input
1853 
1854   auto *mul = F_->createMul("mul", zero, input); // -> zero
1855 
1856   auto *add3 = F_->createAdd("add", div, add);
1857 
1858   add3 = F_->createAdd("add", add3, mul);
1859 
1860   auto *sub = F_->createSub("sub", add3, zero); // -> input
1861 
1862   SaveNode *O = F_->createSave("ret", sub);
1863 
1864   // The expression evaluates to "I".
1865 
1866   EXPECT_EQ(F_->getNodes().size(), 8);
1867 
1868   ::glow::optimize(F_, CompilationMode::Infer);
1869 
1870   EXPECT_EQ(F_->getNodes().size(), 1);
1871 
1872   EXPECT_EQ(O->getInput().getNode(), input);
1873 
1874   optimizedF_ = optimizeFunction(F_);
1875 
1876   bindings_.allocate(mod_.getPlaceholders());
1877   bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1878 
1879   checkNumericalEquivalence();
1880 }
1881 
1882 // Similar to ZeroArithmetic, but tests that nodes with multiple results are
1883 // correctly handled (i.e. that the correct output is selected after optimizing
1884 // away an arithmetic identity).
1885 TEST_F(GraphOptz, ZeroArithmeticMultiResNode) {
1886   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {10}, "input", true);
1887   auto *topK = F_->createTopK("topK", input, /*k=*/5);
1888   auto *zero = F_->createSplat("zero", topK->getValues().getType(), 0.);
1889   auto *add = F_->createAdd("add", topK->getValues(), zero);
1890   auto *sub = F_->createSub("sub", topK->getValues(), zero);
1891 
1892   SaveNode *AS = F_->createSave("ret", add);
1893   SaveNode *SS = F_->createSave("ret", sub);
1894 
1895   // There should be 6 nodes: 2 Saves, Add, Sub, Splat and TopK.
1896   EXPECT_EQ(F_->getNodes().size(), 6);
1897 
1898   optimizedF_ = optimizeFunction(F_);
1899 
1900   // Now there should only be 3 nodes: TopK and 2 Saves.
1901   EXPECT_EQ(optimizedF_->getNodes().size(), 3);
1902 
1903   auto *OAS = findFunctionNodeByName<SaveNode>(optimizedF_, AS->getName());
1904   auto *OSS = findFunctionNodeByName<SaveNode>(optimizedF_, SS->getName());
1905   auto *OTopK = findFunctionNodeByName<TopKNode>(optimizedF_, topK->getName());
1906 
1907   // Since the operations represented by the arithmetic nodes are no-ops,
1908   // the input to both SaveNodes should be the Values result of the TopKNode.
1909   EXPECT_EQ(OAS->getInput(), OTopK->getValues());
1910   EXPECT_EQ(OSS->getInput(), OTopK->getValues());
1911 
1912   // Check numerical equivalence.
1913   bindings_.allocate(mod_.getPlaceholders());
1914   bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1915 
1916   checkNumericalEquivalence();
1917 }
1918 
1919 /// A test that verifies that arithmetic simplification works correctly when
1920 /// the parents need to be simplified prior to the node itself.
1921 TEST_F(GraphOptz, ZeroArithmeticParentsMustBeSimplifiedFirst) {
1922   auto *input1 =
1923       mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input1", true);
1924   auto *input2 =
1925       mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input2", true);
1926 
1927   // This builds the expression: ((0 * I1) * (0 * I2)) = 0
1928   // It should be simplified to simply the splat zero node being saved.
1929 
1930   SplatNode *zero = F_->createSplat("zero", input1->getType(), 0.);
1931 
1932   MulNode *mul1 = F_->createMul("mul1", zero, input1); // -> 0
1933   MulNode *mul2 = F_->createMul("mul2", zero, input2); // -> 0
1934 
1935   MulNode *mul3 = F_->createMul("mul3", mul1, mul2); // -> 0
1936 
1937   SaveNode *O = F_->createSave("ret", mul3);
1938 
1939   // Expect 1 splat, 3 muls, 1 save.
1940   EXPECT_EQ(F_->getNodes().size(), 5);
1941 
1942   ::glow::optimize(F_, CompilationMode::Infer);
1943 
1944   // Expect all muls to be optimized away, with 1 splat and 1 save left.
1945   EXPECT_EQ(F_->getNodes().size(), 2);
1946   EXPECT_TRUE(functionContainsNode(F_, O));
1947   EXPECT_TRUE(functionContainsNode(F_, zero));
1948   EXPECT_EQ(O->getInput().getNode(), zero);
1949 }
1950 
1951 /// Tests opts for the identities: [1 * X = X] [X / 1 = X]
1952 TEST_F(GraphOptz, ArithmeticIdentitiesOne) {
1953   auto *input =
1954       mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input", true);
1955 
1956   // This builds the expression: (I / 1) * 1:
1957   SplatNode *one = F_->createSplat("one", input->getType(), 1.);
1958   DivNode *div = F_->createDiv("div", input, one);
1959   MulNode *mul = F_->createMul("mul", div, one);
1960   SaveNode *save = F_->createSave("ret", mul);
1961 
1962   // Splat, Div, Mul, Save.
1963   EXPECT_EQ(F_->getNodes().size(), 4);
1964   // Save the optimized function for future comparison.
1965   optimizedF_ = optimizeFunction(F_);
1966 
1967   // The expression evaluates to "I", so the Save is the only node left.
1968   EXPECT_EQ(optimizedF_->getNodes().size(), 1);
1969   SaveNode *SN =
1970       llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName()));
1971   ASSERT_TRUE(functionContainsNode(optimizedF_, SN));
1972   ASSERT_NE(SN, nullptr);
1973 
1974   // Save node should just save the input.
1975   EXPECT_TRUE(SN->getInput().getNode() == input);
1976 
1977   bindings_.allocate(mod_.getPlaceholders());
1978   bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1979 
1980   checkNumericalEquivalence();
1981 }
1982 
1983 /// Reverse the intrusive list of nodes. This custom implementation is required,
1984 /// because std::reverse cannot be used with LLVM's intrusive lists.
1985 static void reverse(NodesList &L) {
1986   if (L.empty())
1987     return;
1988   // Last element of the list before reversal.
1989   auto &last = L.back();
1990   // Take element from the beginning and move it right after the old last
1991   // element. Do it until the old last element becomes the first element.
1992   while (true) {
1993     auto &first = L.front();
1994     // Finish when the old last element becomes the new front element.
1995     if (&first == &last) {
1996       break;
1997     }
1998     L.remove(first);
1999     L.insert(++last.getIterator(), &first);
2000   }
2001 }
2002 
2003 TEST(GraphOptzTest, SliceOfSplatNodeChain) {
2004   for (int shouldReverse = 0; shouldReverse <= 1; shouldReverse++) {
2005     Module mod;
2006     Function *F = mod.createFunction("foo");
2007 
2008     Type t(ElemKind::FloatTy, {1000, 1000, 1000});
2009     Node *Z = F->createSplat("zero", &t, 0.);
2010     Node *S1 = F->createSlice("slice1", Z, {5, 15, 42}, {99, 88, 77});
2011     Node *S2 = F->createSlice("slice2", S1, {1, 1, 1}, {2, 3, 4});
2012     F->createSave("ret", S2);
2013 
2014     if (shouldReverse) {
2015       auto &nodes = F->getNodes();
2016       reverse(nodes);
2017     }
2018 
2019     EXPECT_EQ(F->getNodes().size(), 4);
2020 
2021     CompilationContext cctx;
2022     cctx.compMode = CompilationMode::Train;
2023     // Do not perform any compile-time constant folding.
2024     cctx.optimizationOpts.enableConstantFolding = false;
2025     ::glow::optimize(F, cctx);
2026 
2027     // This test illustrates some inconsistency in the optimization.
2028     // Chains of slices over a splat are not guaranteed to be fully optimized.
2029     EXPECT_EQ(F->getNodes().size(), shouldReverse ? 3 : 2);
2030   }
2031 }
2032 
2033 TEST_F(GraphOptz, ReshapeNoop) {
2034   const dim_t shape[] = {10, 20, 30};
2035   Type t(ElemKind::FloatTy, shape);
2036   auto *Z = F_->createSplat("zero", &t, 0.);
2037   auto *R = F_->createReshape("reshape", Z, shape);
2038   auto *O = F_->createSave("ret", R);
2039 
2040   EXPECT_EQ(F_->getNodes().size(), 3);
2041 
2042   ::glow::optimize(F_, CompilationMode::Train);
2043 
2044   EXPECT_EQ(F_->getNodes().size(), 2);
2045 
2046   auto *SN = llvm::dyn_cast<SplatNode>(llvm::dyn_cast<SaveNode>(O)->getInput());
2047   EXPECT_TRUE(SN);
2048 
2049   EXPECT_TRUE(SN->getResult().getType()->dims().equals(shape));
2050 }
2051 
2052 /// Test the Reshape(Splat(args)) -> Splat(args') transformation.
2053 /// Including a positive and a negative test case. In the positive case,
2054 /// the optimization will take place for the splat node (Z2) that has only one
2055 /// use. In the negative case, the optimization will not happen as the splat
2056 /// node (Z1) has more than one use.
2057 TEST_F(GraphOptz, ReshapeAfterSplat) {
2058   const dim_t shape[] = {10, 20, 30};
2059   const dim_t reshape[] = {1, 6000};
2060   Type t1(ElemKind::FloatTy, shape);
2061   Type t2(ElemKind::FloatTy, reshape);
2062   Node *input = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
2063                                                    "input", true);
2064   auto *Z1 = F_->createSplat("zero1", &t1, 1.5);
2065   auto *A1 = F_->createAdd("add1", Z1->getResult().getType(), input, Z1);
2066   auto *R1 = F_->createReshape("reshape1", Z1, reshape);
2067   // Z1 is used by R1 and A1.
2068   // The reshape optimization will thus NOT be able to remove this reshape node
2069   // (R1).
2070   auto *R2 = F_->createReshape("reshape2", A1, reshape);
2071   auto *A2 = F_->createAdd("add", R1->getResult().getType(), R1, R2);
2072   auto *Z2 = F_->createSplat("zero2", &t1, 2.5);
2073   auto *R3 = F_->createReshape("reshape3", Z2, reshape);
2074   // Z2 is only used by R3.
2075   // The Z2,R3 nodes will be replaced by a new splat node with the shape of R3.
2076   auto *A3 = F_->createAdd("add", A2->getResult().getType(), A2, R3);
2077   auto *O = F_->createSave("ret", A3);
2078 
2079   // Before optimization, we have 9 nodes in the graph.
2080   EXPECT_EQ(F_->getNodes().size(), 9);
2081 
2082   cctx_.compMode = CompilationMode::Infer;
2083   // Do not perform any compile-time constant folding.
2084   cctx_.optimizationOpts.enableConstantFolding = false;
2085   ::glow::optimize(F_, cctx_);
2086 
2087   // After optimization, we expect to see only 8 nodes, as Z2,R3 would be
2088   // replaced by a new splat node.
2089   EXPECT_EQ(F_->getNodes().size(), 8);
2090 
2091   // The second input of A3 should be a splat node with the shape of R3.
2092   auto *newA3 = llvm::dyn_cast<AddNode>(O->getInput());
2093   ASSERT_TRUE(newA3);
2094   auto *SN = llvm::dyn_cast<SplatNode>(newA3->getRHS());
2095   EXPECT_TRUE(SN);
2096   EXPECT_TRUE(SN->getResult().getType()->dims().equals(reshape));
2097 
2098   // R1 should still be in the graph.
2099   EXPECT_TRUE(functionContainsNode(F_, R1));
2100 
2101   // R3 and Z2 should not be in the graph any more.
2102   EXPECT_FALSE(functionContainsNode(F_, R3));
2103   EXPECT_FALSE(functionContainsNode(F_, Z2));
2104 }
2105 
2106 /// Test the Reshape(Reshape(x)) -> Reshape(x) transformation.
2107 TEST_F(GraphOptz, ReshapeReshapeOpt) {
2108   const dim_t shape[] = {10, 20};
2109   const dim_t reshape1[] = {200, 1};
2110   const dim_t reshape2[] = {200};
2111   Node *input = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
2112                                                    "input", true);
2113   auto *R1 = F_->createReshape("reshape1", input, reshape1);
2114   auto *R2 = F_->createReshape("reshape2", R1, reshape2);
2115   auto *O = F_->createSave("ret", R2);
2116 
2117   // Before optimization, we have 2 Reshapes and a Save.
2118   EXPECT_EQ(F_->getNodes().size(), 3);
2119 
2120   ::glow::optimize(F_, CompilationMode::Infer);
2121 
2122   // After optimization, we expect to see only 1 Reshape and a Save.
2123   EXPECT_EQ(F_->getNodes().size(), 2);
2124 
2125   // Save should have the new Reshape as input.
2126   auto *RN = llvm::dyn_cast<ReshapeNode>(O->getInput());
2127   ASSERT_TRUE(RN);
2128   // The new Reshape should have the same shape as the original second Reshape.
2129   EXPECT_TRUE(RN->getResult().getType()->dims().equals(reshape2));
2130 
2131   // R1 and R2 should not be in the graph any more; they were replaced by a
2132   // single new reshape.
2133   EXPECT_FALSE(functionContainsNode(F_, R1));
2134   EXPECT_FALSE(functionContainsNode(F_, R2));
2135 }
2136 
2137 TEST_F(GraphOptz, DCEPublicVars) {
2138   mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
2139 
2140   EXPECT_EQ(mod_.getPlaceholders().size(), 1);
2141 
2142   // Optimize all of the dead code.
2143   ::glow::optimize(F_, CompilationMode::Infer);
2144 
2145   //  Public nodes should not be deleted.
2146   EXPECT_EQ(mod_.getPlaceholders().size(), 1);
2147 }
2148 
2149 TEST_F(GraphOptz, foldQuantizeIntoConstant) {
2150   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {4}, "input", true);
2151   *bindings_.allocate(input) = {10, 10, 10, 10};
2152   auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2153 
2154   auto *Q = F_->createQuantize("quantize", input, qType);
2155   auto *S = F_->createSave("save", Q);
2156 
2157   EXPECT_EQ(2, F_->getNodes().size());
2158   ::glow::convertPlaceholdersToConstants(F_, bindings_, {S->getPlaceholder()});
2159 
2160   // 'optimize' doesn't merge quantize nodes into Constant.
2161   ::glow::optimize(F_, CompilationMode::Infer);
2162   EXPECT_EQ(2, F_->getNodes().size());
2163 
2164   // 'convertQuantizedConstants' merges quantize nodes into Constant
2165   CompilationContext cctx;
2166   ::glow::convertQuantizedConstants(F_, cctx);
2167   EXPECT_EQ(1, F_->getNodes().size());
2168 
2169   auto quantizedInput = llvm::cast<Constant>(S->getInput());
2170   auto quantizedValues = quantizedInput->getHandle<int8_t>();
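  // With scale == 2 and offset == 0, the float value 10 quantizes to
  // 10 / 2 + 0 = 5.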
2171   for (unsigned i = 0; i < 4; ++i) {
2172     EXPECT_EQ(5, quantizedValues.raw(i));
2173   }
2174 }
2175 
2176 TEST_F(GraphOptz, foldQuantizeIntoConstantMultipleUsages) {
2177   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {4}, "input", true);
2178   *bindings_.allocate(input) = {10, 10, 10, 10};
2179   auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2180 
2181   auto *Q = F_->createQuantize("quantize", input, qType);
2182   F_->createSave("save", Q);
2183   auto clonedF = F_->clone("cloned");
2184 
2185   EXPECT_EQ(2, clonedF->getNodes().size());
2186   ::glow::convertPlaceholdersToConstants(clonedF, bindings_, {});
2187   CompilationContext cctx;
2188   ::glow::convertQuantizedConstants(clonedF, cctx);
2189 
2190   // F_ function should not be affected.
2191   EXPECT_EQ(2, F_->getNodes().size());
2192 
2193   // Check original var.
2194   for (unsigned i = 0; i < 4; ++i) {
2195     EXPECT_EQ(10, bindings_.get(input)->getHandle().raw(i));
2196   }
2197 
2198   // Quantization node was merged into input var.
2199   EXPECT_EQ(1, clonedF->getNodes().size());
2200   auto *save = llvm::dyn_cast<SaveNode>(&clonedF->getNodes().front());
2201   ASSERT_TRUE(save);
2202   auto quantizedInput = llvm::cast<Constant>(save->getInput());
2203   auto quantizedValues = quantizedInput->getHandle<int8_t>();
2204   for (unsigned i = 0; i < 4; ++i) {
2205     EXPECT_EQ(5, quantizedValues.raw(i));
2206   }
2207 }
2208 
2209 /// Search for a unique Save node in the input graph \p F and return it.
2210 /// Fails if there is no Save node or if more than one is detected.
2211 static SaveNode *getUniqueSaveNode(Function *F) {
2212   SaveNode *foundSaveNode = nullptr;
2213   for (auto &node : F->getNodes()) {
2214     if (auto *s = llvm::dyn_cast<SaveNode>(&node)) {
2215       EXPECT_EQ(foundSaveNode, nullptr);
2216       foundSaveNode = s;
2217     }
2218   }
2219   EXPECT_NE(foundSaveNode, nullptr);
2220   return foundSaveNode;
2221 }
2222 
2223 /// Mock backend that requests the pre-quantization of constants.
2224 class MockBackendPrequantizeConst : public MockBackend {
2225   bool shouldPreQuantizeConstants() const override { return true; }
2226   bool isOpSupported(const NodeInfo &) const override { return true; }
2227   Expected<bool>
2228   transformPostLowering(Function *F, CompilationContext &,
2229                         const glow::runtime::DeviceInfo *) const override {
2230     // Check the IR.
2231     EXPECT_EQ(F->getNodes().size(), 1);
2232     auto *save = getUniqueSaveNode(F);
2233     EXPECT_TRUE(llvm::isa<Constant>(save->getInput()));
2234 
2235     return false;
2236   }
2237 };
2238 /// Mock backend that requests that constants are not pre-quantized.
2239 class MockBackendNotPrequantizeConst : public MockBackend {
2240   bool shouldPreQuantizeConstants() const override { return false; }
2241   bool isOpSupported(const NodeInfo &) const override { return true; }
2242   Expected<bool>
2243   transformPostLowering(Function *F, CompilationContext &,
2244                         const glow::runtime::DeviceInfo *) const override {
2245     // Check the IR.
2246     EXPECT_EQ(F->getNodes().size(), 2);
2247     auto *save = getUniqueSaveNode(F);
2248     auto *quant = llvm::dyn_cast<QuantizeNode>(save->getInput());
2249     EXPECT_TRUE(quant);
2250     EXPECT_TRUE(llvm::isa<Constant>(quant->getInput()));
2251 
2252     return false;
2253   }
2254 };
2255 
2256 /// Test the actual constant quantization for backends.
2257 template <typename Backend>
2258 void testFoldQuantizeIntoConstant(Module &mod_, Function *F_) {
2259   auto *input = mod_.createConstant(ElemKind::FloatTy, {4}, "input");
2260   input->getHandle<float>() = {10, 10, 10, 10};
2261   auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2262   auto *Q = F_->createQuantize("quantize", input, qType);
2263   auto *save = F_->createSave("save", Q);
2264 
2265   CompilationContext cctx;
2266   auto B = Backend();
2267   // Note: whether the Quantize is folded into the Constant before
2268   // post-lowering is checked in <backend>::transformPostLowering().
2269   EXIT_ON_ERR(::glow::optimizeFunction(F_, B, cctx));
2270 
2271   // Check the IR (the constant must have been quantized).
2272   EXPECT_EQ(F_->getNodes().size(), 1);
2273   EXPECT_TRUE(llvm::isa<Constant>(save->getInput()));
2274 }
2275 
2276 /// Check that the backend's constant quantization is done before post-lowering.
2277 TEST_F(GraphOptz, foldQuantizeIntoConstantBeforePostLowering) {
2278   testFoldQuantizeIntoConstant<MockBackendPrequantizeConst>(mod_, F_);
2279 }
2280 
2281 /// Check that the backend's constant quantization is done after post-lowering.
2282 TEST_F(GraphOptz, foldQuantizeIntoConstantAfterPostLowering) {
2283   testFoldQuantizeIntoConstant<MockBackendNotPrequantizeConst>(mod_, F_);
2284 }
2285 
2286 /// Check that the Quantize(Splat) -> Splat' optimization works.
2287 TEST_F(GraphOptz, foldQuantizeIntoSplat) {
2288   TypeRef fType = mod_.uniqueType(ElemKind::FloatTy, {4});
2289   TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2290 
2291   const float splatVal = 6.0;
2292   SplatNode *SN = F_->createSplat("splat", fType, splatVal);
2293 
2294   QuantizeNode *Q = F_->createQuantize("quantize", SN, qType);
2295   SaveNode *S = F_->createSave("save", Q);
2296 
2297   // Splat, quantize, save.
2298   EXPECT_EQ(3, F_->getNodes().size());
2299 
2300   ::glow::optimize(F_, CompilationMode::Infer);
2301 
2302   // Quantization node was merged into input splat.
2303   EXPECT_EQ(2, F_->getNodes().size());
2304 
2305   // New quantized splat should exist with same value.
2306   SplatNode *newSN = llvm::dyn_cast<SplatNode>(S->getInput());
2307   ASSERT_TRUE(newSN);
2308   EXPECT_EQ(splatVal, newSN->getValue());
2309   EXPECT_EQ(qType, newSN->getResult().getType());
2310 }
2311 
2312 /// Check that the Dequantize(Splat) -> Splat' optimization works.
2313 TEST_F(GraphOptz, foldDequantizeIntoSplat) {
2314   TypeRef fType = mod_.uniqueType(ElemKind::FloatTy, {4});
2315   TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2316 
2317   const float splatVal = 6.0;
2318   SplatNode *SN = F_->createSplat("splat", qType, splatVal);
2319 
2320   DequantizeNode *Q = F_->createDequantize("dequantize", SN, ElemKind::FloatTy);
2321   SaveNode *S = F_->createSave("save", Q);
2322 
2323   // Splat, dequantize, save.
2324   EXPECT_EQ(3, F_->getNodes().size());
2325 
2326   ::glow::optimize(F_, CompilationMode::Infer);
2327 
2328   // Dequantization node was merged into input splat.
2329   EXPECT_EQ(2, F_->getNodes().size());
2330 
2331   // A new float (dequantized) splat should exist with the same value.
2332   SplatNode *newSN = llvm::dyn_cast<SplatNode>(S->getInput());
2333   ASSERT_TRUE(newSN);
2334   EXPECT_EQ(splatVal, newSN->getValue());
2335   EXPECT_EQ(fType, newSN->getResult().getType());
2336 }
2337 
2338 /// Check that the Quantize(Splat) -> Splat' optimization works when the Splat
2339 /// has multiple users.
2340 TEST_F(GraphOptz, foldQuantizeIntoSplatMultipleUsers) {
2341   TypeRef fType = mod_.uniqueType(ElemKind::FloatTy, {4});
2342   TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2343 
2344   SplatNode *SN = F_->createSplat("splat", fType, 6.0);
2345 
2346   QuantizeNode *Q = F_->createQuantize("quantize", SN, qType);
2347   SaveNode *SQ = F_->createSave("saveQ", Q);
2348   SaveNode *SF = F_->createSave("saveF", SN);
2349 
2350   // Splat, quantize, 2 saves.
2351   EXPECT_EQ(4, F_->getNodes().size());
2352 
2353   ::glow::optimize(F_, CompilationMode::Infer);
2354 
2355   // Quantization node was merged into input splat creating a new quantized
2356   // splat, but the original float splat still exists.
2357   EXPECT_EQ(4, F_->getNodes().size());
2358 
2359   // New quantized splat should exist with same value.
2360   SplatNode *newSN = llvm::dyn_cast<SplatNode>(SQ->getInput());
2361   ASSERT_TRUE(newSN);
2362   EXPECT_EQ(SN->getValue(), newSN->getValue());
2363   EXPECT_EQ(qType, newSN->getResult().getType());
2364 
2365   // Original float splat should still exist.
2366   EXPECT_EQ(llvm::dyn_cast<SplatNode>(SF->getInput()), SN);
2367 }
2368 
2369 /// Check that an unnecessary rescale gets removed.
2370 TEST_F(GraphOptz, removeUnnecessaryRescale) {
2371   TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.03f, 5);
2372   Placeholder *input =
2373       mod_.createPlaceholder(qType, "input", /* isTrainable */ true);
2374   RescaleQuantizedNode *RQ =
2375       F_->createRescaleQuantized("rescale", input, qType);
2376   SaveNode *save = F_->createSave("ret", RQ);
2377 
2378   // RescaleQuantized and Save.
2379   EXPECT_EQ(F_->getNodes().size(), 2);
2380 
2381   ::glow::optimize(F_, CompilationMode::Infer);
2382 
2383   // Only Save should be left, which saves the Placeholder directly with
2384   // unchanged quantization parameters.
2385   EXPECT_EQ(F_->getNodes().size(), 1);
2386   EXPECT_EQ(save->getInput().getNode(), input);
2387   EXPECT_EQ(save->getInput().getType(), qType);
2388 }
2389 
2390 /// Check that a Rescale gets correctly merged into a following Dequantize node.
2391 TEST_F(GraphOptz, mergeRescaleIntoDequantize) {
2392   // Check that the Rescale is merged into the following Dequantize.
2393   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
2394                                        "input", true);
2395   auto *qType = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.03f, 5);
2396   auto *R = F_->createRescaleQuantized("rescale", input, qType);
2397   auto *D = F_->createDequantize("dequantize", R, ElemKind::FloatTy);
2398   F_->createSave("ret", D);
2399 
2400   EXPECT_EQ(F_->getNodes().size(), 3);
2401   ::glow::optimize(F_, CompilationMode::Infer);
2402 
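  // Dequantize(Rescale(x)) represents the same values as Dequantize(x), so the
  // Rescale is expected to be dropped and the Dequantize applied to the
  // quantized input directly.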
2403   // Only 2 nodes should remain (Dequantize -> Save)
2404   EXPECT_EQ(F_->getNodes().size(), 2);
2405 
2406   // Check the graph structure
2407   auto *SN = F_->getNodeByName("ret_save");
2408   EXPECT_NE(nullptr, SN);
2409   auto *S = llvm::dyn_cast<SaveNode>(SN);
2410   EXPECT_NE(nullptr, S);
2411   auto *newDN = S->getInput().getNode();
2412   EXPECT_NE(nullptr, newDN);
2413   EXPECT_NE(nullptr, llvm::dyn_cast<DequantizeNode>(newDN));
2414 }
2415 
2416 TEST_F(GraphOptz, quantizeToRescale) {
2417   // Check that we are combining quantization-dequantization pairs.
2418   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
2419                                        "input", true);
2420 
2421   auto *D = F_->createDequantize("dequantize", input, ElemKind::FloatTy);
2422 
2423   auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.03, 5);
2424   auto *Q = F_->createQuantize("quantize", D, qType);
2425 
2426   F_->createSave("ret", Q);
2427 
2428   EXPECT_EQ(F_->getNodes().size(), 3);
2429 
2430   ::glow::optimize(F_, CompilationMode::Infer);
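  // The Dequantize/Quantize pair is expected to collapse into a single
  // RescaleQuantized, leaving Rescale -> Save.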
2431   EXPECT_EQ(F_->getNodes().size(), 2);
2432 }
2433 
2434 TEST_F(GraphOptz, MaxOfQuantizedSplat) {
2435   const dim_t size = 5;
2436   const float scale = 1;
2437   // offset == -128 guarantees that the floating-point range of this type
2438   // contains no values below 0.
2439   const int32_t offset = -128;
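  // The 0.0 splat is therefore the smallest representable value, so
  // max(splat, input) == input and both the Splat and the Max can be folded
  // away.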
2440 
2441   auto splatTy = mod_.uniqueType(ElemKind::Int8QTy, {size}, scale, offset);
2442   auto *splat = F_->createSplat("splat", splatTy, 0.0);
2443 
2444   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {size}, scale, offset,
2445                                        "input", true);
2446 
2447   auto *max = F_->createMax("max", splat, input);
2448   F_->createSave("save", max);
2449   EXPECT_EQ(F_->getNodes().size(), 3);
2450 
2451   ::glow::optimize(F_, CompilationMode::Infer);
2452   // Splat and Max should be gone.
2453   EXPECT_EQ(F_->getNodes().size(), 1);
2454 }
2455 
2456 TEST_F(GraphOptz, FuseRescaleIntoArithmetic) {
2457   // This test ensures that the fusing of the rescales is done.
2458   auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 1, 0);
2459   auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 2, 1);
2460 
2461   Placeholder *LHS =
2462       mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.4, 0, "LHS", true);
2463   Placeholder *RHS =
2464       mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.3, 0, "RHS", true);
2465 
2466   AddNode *add = F_->createAdd("qAdd", opOutTy, LHS, RHS);
2467   RescaleQuantizedNode *rescaleAdd =
2468       F_->createRescaleQuantized("rsAdd", add, rescaleOutTy);
2469   SaveNode *addSave = F_->createSave("saveAdd", rescaleAdd);
2470 
2471   SubNode *sub = F_->createSub("qSub", opOutTy, LHS, RHS);
2472   RescaleQuantizedNode *rescaleSub =
2473       F_->createRescaleQuantized("rsSub", sub, rescaleOutTy);
2474   SaveNode *subSave = F_->createSave("saveSub", rescaleSub);
2475 
2476   DivNode *div = F_->createDiv("qDiv", opOutTy, LHS, RHS);
2477   RescaleQuantizedNode *rescaleDiv =
2478       F_->createRescaleQuantized("rsDiv", div, rescaleOutTy);
2479   SaveNode *divSave = F_->createSave("saveDiv", rescaleDiv);
2480 
2481   MulNode *mul = F_->createMul("qMul", opOutTy, LHS, RHS);
2482   RescaleQuantizedNode *rescaleMul =
2483       F_->createRescaleQuantized("rsMul", mul, rescaleOutTy);
2484   SaveNode *mulSave = F_->createSave("saveMul", rescaleMul);
2485 
2486   MinNode *min = F_->createMin("qMin", opOutTy, LHS, RHS);
2487   RescaleQuantizedNode *rescaleMin =
2488       F_->createRescaleQuantized("rsMin", min, rescaleOutTy);
2489   SaveNode *minSave = F_->createSave("saveMin", rescaleMin);
2490 
2491   MaxNode *max = F_->createMax("qMax", opOutTy, LHS, RHS);
2492   RescaleQuantizedNode *rescaleMax =
2493       F_->createRescaleQuantized("rsMax", max, rescaleOutTy);
2494   SaveNode *maxSave = F_->createSave("saveMax", rescaleMax);
2495 
2496   // All rescales must be fused into arithmetic operations above.
2497   ::glow::optimize(F_, CompilationMode::Infer);
2498 
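  // Six arithmetic nodes and six Saves should remain; all six Rescales have
  // been fused into the preceding operations.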
2499   EXPECT_EQ(F_->getNodes().size(), 12);
2500 
2501   EXPECT_EQ(addSave->getInput().getType(), rescaleOutTy);
2502   EXPECT_EQ(subSave->getInput().getType(), rescaleOutTy);
2503   EXPECT_EQ(mulSave->getInput().getType(), rescaleOutTy);
2504   EXPECT_EQ(divSave->getInput().getType(), rescaleOutTy);
2505   EXPECT_EQ(minSave->getInput().getType(), rescaleOutTy);
2506   EXPECT_EQ(maxSave->getInput().getType(), rescaleOutTy);
2507 }
2508 
2509 /// Check that the Rescale(MatMul) -> MatMul' optimization works correctly.
2510 TEST_F(GraphOptz, FuseRescaleUpIntoMatMul) {
2511   // This test ensures the fact that fusing of rescale is done.
2512   auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 1, 0);
2513   auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 2, 1);
2514 
2515   Placeholder *LHS = mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.4, 0,
2516                                             "LHS", /* isTrainable */ false);
2517   Placeholder *RHS = mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.3, 0,
2518                                             "RHS", /* isTrainable */ false);
2519 
2520   MatMulNode *MMN = F_->createMatMul("matmul", opOutTy, LHS, RHS);
2521   RescaleQuantizedNode *rescaleMMN =
2522       F_->createRescaleQuantized("rsMMN", MMN, rescaleOutTy);
2523   SaveNode *saveMMN = F_->createSave("saveMMN", rescaleMMN);
2524 
2525   // MatMul, Rescale, Save.
2526   EXPECT_EQ(F_->getNodes().size(), 3);
2527 
2528   // The Rescale must be fused up into the MatMul above.
2529   ::glow::optimize(F_, CompilationMode::Infer);
2530 
2531   // Rescale merged up into the MatMul.
2532   EXPECT_EQ(F_->getNodes().size(), 2);
2533 
2534   MatMulNode *newMMN = llvm::dyn_cast<MatMulNode>(saveMMN->getInput());
2535   ASSERT_TRUE(newMMN);
2536   EXPECT_EQ(newMMN->getResult().getType(), rescaleOutTy);
2537 }
2538 
2539 /// Check that the Rescale(SparseLengthsWeightedSum) ->
2540 /// SparseLengthsWeightedSum' optimization works correctly.
2541 TEST_F(GraphOptz, FuseRescaleUpIntoSparseLengthsWeightedSum) {
2542   // This test ensures that the Rescale is fused up into the SparseLengthsWeightedSum.
2543   TypeRef rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 1);
2544 
2545   Placeholder *data =
2546       mod_.createPlaceholder(ElemKind::Int8QTy, {3}, 0.5, 0, "data",
2547                              /* isTrainable */ false);
2548   Placeholder *weights = mod_.createPlaceholder(
2549       ElemKind::Int8QTy, {8}, 0.5, 0, "weights", /* isTrainable */ false);
2550   Placeholder *indices =
2551       mod_.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
2552                              /* isTrainable */ false);
2553   Placeholder *lengths =
2554       mod_.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
2555                              /* isTrainable */ false);
2556 
2557   SparseLengthsWeightedSumNode *SLWS = F_->createSparseLengthsWeightedSum(
2558       "SLWS", data, weights, indices, lengths);
2559   RescaleQuantizedNode *rescaleSLWS =
2560       F_->createRescaleQuantized("rsSLWS", SLWS, rescaleOutTy);
2561   SaveNode *saveSLWS = F_->createSave("saveSLWS", rescaleSLWS);
2562 
2563   // SparseLengthsWeightedSum, Rescale, Save.
2564   EXPECT_EQ(F_->getNodes().size(), 3);
2565 
2566   // The Rescale must be fused up into the SparseLengthsWeightedSum above.
2567   ::glow::optimize(F_, CompilationMode::Infer);
2568 
2569   // Rescale merged up into the SparseLengthsWeightedSum.
2570   EXPECT_EQ(F_->getNodes().size(), 2);
2571 
2572   SparseLengthsWeightedSumNode *newSLWS =
2573       llvm::dyn_cast<SparseLengthsWeightedSumNode>(saveSLWS->getInput());
2574   ASSERT_TRUE(newSLWS);
2575   EXPECT_EQ(newSLWS->getResult().getType(), rescaleOutTy);
2576 }
2577 
2578 TEST_F(GraphOptz, fuseRescaleIntoConv) {
2579   // This test ensures that the Rescales are fused into the Convolution.
2580   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {1, 10, 20, 3}, 0.5,
2581                                        10, "input", true);
2582   auto *filter = mod_.createPlaceholder(ElemKind::Int8QTy, {16, 5, 5, 3}, 0.5,
2583                                         10, "filter", true);
2584   auto *bias =
2585       mod_.createPlaceholder(ElemKind::Int8QTy, {16}, 0.5, 10, "bias", true);
2586 
2587   auto *rInput = F_->createRescaleQuantized(
2588       "rescale", input,
2589       mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 3}, 0.1, -25));
2590   auto *rFilter = F_->createRescaleQuantized(
2591       "rescale", filter,
2592       mod_.uniqueType(ElemKind::Int8QTy, {16, 5, 5, 3}, 0.2, 0));
2593   auto *rBias = F_->createRescaleQuantized(
2594       "rescale", bias, mod_.uniqueType(ElemKind::Int8QTy, {16}, 0.3, 25));
2595   auto *CV = F_->createConv(
2596       "conv", rInput, rFilter, rBias,
2597       mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 16}, 0.7, -3), 5, 1, 2, 1);
2598   auto *rCV = F_->createRescaleQuantized(
2599       "rescale", CV,
2600       mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 16}, 0.4, 37));
2601   F_->createSave("save", rCV);
2602 
2603   // All rescales must be fused into convolution.
2604   EXPECT_EQ(F_->getNodes().size(), 6);
2605   ::glow::optimize(F_, CompilationMode::Infer);
2606   EXPECT_EQ(F_->getNodes().size(), 2);
2607 }
2608 
2609 /// This test ensures that if there is a Pad node as input of a Convolution
2610 /// node, the Pad gets merged into the Convolution.
2611 /// Note that the Pad is merged into the convolution only when it is compatible with
2612 /// the convolution padding:
2613 /// - Resulting padding after merge is positive
2614 /// - Padding only concerns spatial dimensions
2615 /// - Padding has mode 'constant' with value 0.f
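/// A small worked example (mirroring the expectedPads computation below): with
/// Pad pads = {0, 1, 2, 0, 0, 3, 4, 0} (NHWC begin then end) and conv pads
/// {0, 0, 0, 0}, the merged convolution pads become {1, 2, 3, 4}
/// (top, left, bottom, right).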
2616 void fusePadIntoConvTest(glow::Module &mod_, glow::Function *F_,
2617                          llvm::ArrayRef<dim_t> inputDims,
2618                          llvm::ArrayRef<int> pads, unsigned_t convKernelSize,
2619                          llvm::ArrayRef<unsigned_t> convPads,
2620                          unsigned_t convStride, unsigned_t convNumKernels) {
2621   auto *input =
2622       mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", true);
2623 
2624   // Pad
2625   dim_t inputWithPadDims[4];
2626   for (int i = 0; i < 4; i++) {
2627     inputWithPadDims[i] = dim_t(ssize_t(inputDims[i]) + pads[i] + pads[4 + i]);
2628   }
2629   dim_t outputConvDims[4] = {
2630       inputWithPadDims[0],
2631       inputWithPadDims[1] + convPads[0] + convPads[2] - (convKernelSize - 1),
2632       inputWithPadDims[2] + convPads[1] + convPads[3] - (convKernelSize - 1),
2633       convNumKernels};
2634 
2635   auto outTy = mod_.uniqueType(ElemKind::FloatTy, inputWithPadDims);
2636   Node *P =
2637       F_->createPad("pad", input, outTy, PaddingMode::CONSTANT, pads, 0.f);
2638 
2639   // Convolution
2640   dim_t filterDims[] = {convNumKernels, convKernelSize, convKernelSize,
2641                         inputDims[3]};
2642   auto *F =
2643       mod_.createPlaceholder(ElemKind::FloatTy, filterDims, "filter", true);
2644   auto *B =
2645       mod_.createPlaceholder(ElemKind::FloatTy, {convNumKernels}, "bias", true);
2646   auto *CV = F_->createConv(
2647       "conv", P, F, B, mod_.uniqueType(ElemKind::FloatTy, outputConvDims),
2648       {convKernelSize, convKernelSize}, {convStride, convStride}, convPads, 1);
2649 
2650   SaveNode *O = F_->createSave("save", CV);
2651 
2652   // The pad node must be merged into convolution.
2653   EXPECT_EQ(F_->getNodes().size(), 3);
2654   ::glow::optimize(F_, CompilationMode::Infer);
2655   EXPECT_EQ(F_->getNodes().size(), 2);
2656 
2657   // Check the graph structure and additional properties after optimization.
2658   auto *conv = llvm::dyn_cast<ConvolutionNode>(O->getInput());
2659   ASSERT_NE(conv, nullptr);
2660   EXPECT_EQ(conv->getResult().dims(), llvm::ArrayRef<dim_t>(outputConvDims));
2661   unsigned_t expectedPads[4];
2662   for (int i = 0; i < 2; i++) {
2663     for (int j = 0; j < 2; j++) {
2664       expectedPads[2 * i + j] =
2665           unsigned_t(int(convPads[2 * i + j]) + pads[4 * i + (1 + j)]);
2666     }
2667   }
2668   EXPECT_EQ(conv->getPads(), llvm::makeArrayRef(expectedPads));
2669 }
2670 
2671 TEST_F(GraphOptz, fusePadIntoConv) {
2672   fusePadIntoConvTest(mod_, F_, {1, 6, 14, 3} /* inputDims */,
2673                       {0, 1, 2, 0, 0, 3, 4, 0} /* pads */,
2674                       5 /* convKernelSize */, {0, 0, 0, 0} /* convPads */,
2675                       1 /* convStride */, 16 /* convNumKernels */);
2676 }
2677 
2678 TEST_F(GraphOptz, fusePadIntoConvNeg1) {
2679   fusePadIntoConvTest(mod_, F_, {1, 6, 14, 3} /* inputDims */,
2680                       {0, -1, 2, 0, 0, 3, -2, 0} /* pads */,
2681                       5 /* convKernelSize */, {3, 0, 2, 5} /* convPads */,
2682                       1 /* convStride */, 16 /* convNumKernels */);
2683 }
2684 
2685 TEST_F(GraphOptz, fusePadIntoConvNeg2) {
2686   fusePadIntoConvTest(mod_, F_, {1, 6, 14, 3} /* inputDims */,
2687                       {0, 1, -2, 0, 0, -3, 4, 0} /* pads */,
2688                       5 /* convKernelSize */, {0, 2, 5, 7} /* convPads */,
2689                       1 /* convStride */, 16 /* convNumKernels */);
2690 }
2691 
2692 /// This test checks that a lowered LeakyRelu is correctly folded:
2693 /// Max(A, Mult(A, Splat)) -> PRelu(Splat)
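/// The identity behind the fold: for 0 < alpha < 1, max(x, alpha * x) equals x
/// for x >= 0 and alpha * x for x < 0, which is exactly LeakyRelu and is
/// represented here as a PRelu whose slope input is a Splat of alpha.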
2694 TEST_F(GraphFold, foldLeakyReluFromSplat) {
2695   std::vector<dim_t> dims = {5, 2};
2696 
2697   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", true);
2698 
2699   const float leakyAlpha = 0.05f;
2700   auto OutTy = mod_.uniqueType(ElemKind::FloatTy, dims);
2701   SplatNode *splatNode = F_->createSplat("splat", OutTy, leakyAlpha);
2702   MulNode *mulNode = F_->createMul("mul", input, splatNode);
2703   MaxNode *maxNode = F_->createMax("max", input, mulNode);
2704   SaveNode *output = F_->createSave("save", maxNode);
2705 
2706   EXPECT_EQ(4, F_->getNodes().size());
2707 
2708   CompilationContext cctx;
2709   ::glow::fold(F_, cctx);
2710 
2711   // Check the resulting graph after folding.
2712   EXPECT_EQ(3, F_->getNodes().size());
2713   auto *newPReluNode = llvm::dyn_cast<PReluNode>(output->getInput());
2714   ASSERT_TRUE(newPReluNode);
2715   auto *newSplatNode = llvm::dyn_cast<SplatNode>(newPReluNode->getSlope());
2716   ASSERT_TRUE(newSplatNode);
2717   EXPECT_EQ(leakyAlpha, newSplatNode->getValue());
2718   EXPECT_EQ(input, newPReluNode->getInput());
2719 }
2720 
2721 /// This test checks that a lowered LeakyRelu is correctly folded:
2722 /// Max(A, Mult(A, broadcasted Const)) -> PRelu(Splat)
2723 TEST_F(GraphFold, foldLeakyReluFromConst) {
2724   std::vector<dim_t> dims = {5, 2};
2725   auto *input = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", true);
2726 
2727   const float leakyAlpha = 0.99f;
2728   auto *alphaConst = mod_.createConstant(ElemKind::FloatTy, {1}, "alphaConst");
2729   alphaConst->getHandle() = {leakyAlpha};
2730   ReshapeNode *reshapeNode = F_->createReshape("reshape", alphaConst, {1, 1});
2731   TileNode *tileNode1 = F_->createTile("tile1", reshapeNode, 2, 1);
2732   TileNode *tileNode2 = F_->createTile("tile2", tileNode1, 5, 0);
2733   MulNode *mulNode = F_->createMul("mul", input, tileNode2);
2734   MaxNode *maxNode = F_->createMax("max", input, mulNode);
2735   SaveNode *output = F_->createSave("save", maxNode);
2736 
2737   EXPECT_EQ(6, F_->getNodes().size());
2738 
2739   CompilationContext cctx;
2740   ::glow::fold(F_, cctx);
2741 
2742   // Check the resulting graph after folding. Reshape must have been merged into
2743   // the constant and LeakyRelu must have been folded.
2744   EXPECT_EQ(3, F_->getNodes().size());
2745   auto *newPReluNode = llvm::dyn_cast<PReluNode>(output->getInput());
2746   ASSERT_TRUE(newPReluNode);
2747   auto *newSplatNode = llvm::dyn_cast<SplatNode>(newPReluNode->getSlope());
2748   ASSERT_TRUE(newSplatNode);
2749   EXPECT_EQ(leakyAlpha, newSplatNode->getValue());
2750   EXPECT_EQ(input, newPReluNode->getInput());
2751 }
2752 
2753 /// Testing folding of Reshape->Transpose->Reshape into ChannelShuffle.
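/// The decomposition below is the usual channel-shuffle pattern: reshape
/// {3, 136, 28, 28} to {3, 4, 34, 28, 28}, swap the two channel sub-dimensions
/// with a transpose, and reshape back; the folder should recover group == 4
/// with the shuffled (kernel) dimension at position 1.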
2754 TEST_F(GraphFold, foldChannelShuffle) {
2755   const dim_t inputDims[] = {3, 136, 28, 28};
2756 
2757   Node *K =
2758       mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
2759   K = F_->createReshape("CS_reshape1", K, {3, 4, 34, 28, 28});
2760   K = F_->createTranspose("CS_transpose", K, {0, 2, 1, 3, 4});
2761   K = F_->createReshape("CS_reshape2", K, {3, 136, 28, 28});
2762   auto *save = F_->createSave("ret", K);
2763 
2764   EXPECT_EQ(F_->getNodes().size(), 4);
2765 
2766   // Fold RN->TR->RN into ChannelShuffle
2767   CompilationContext cctx;
2768   ::glow::fold(F_, cctx);
2769 
2770   ASSERT_EQ(F_->getNodes().size(), 2);
2771 
2772   // Check for ChannelShuffle node.
2773   auto *CS = llvm::dyn_cast<ChannelShuffleNode>(save->getInput().getNode());
2774   ASSERT_NE(nullptr, CS);
2775 
2776   // Ensure ChannelShuffle node has the same dimensions as the input.
2777   EXPECT_EQ(CS->getResult().dims(), llvm::makeArrayRef(inputDims));
2778 
2779   // Ensure Group and Kernel are as expected.
2780   EXPECT_EQ(CS->getGroup(), 4);
2781   EXPECT_EQ(CS->getKernel(), 1);
2782 }
2783 
2784 TEST_F(GraphFold, NoFoldChannelShuffle) {
2785   auto Float = ElemKind::FloatTy;
2786   auto *P = mod_.createPlaceholder(Float, {10, 8928}, "P", false);
2787   auto *R1 = F_->createReshape("R1", P, {10, 186, 48});
2788   auto *TR = F_->createTranspose("TR", R1, {0, 2, 1});
2789   auto *R2 = F_->createReshape("R2", TR, {480, 186});
2790   auto *save = F_->createSave("save", R2);
2791 
2792   EXPECT_EQ(F_->getNodes().size(), 4);
2793 
2794   CompilationContext cctx;
2795   ::glow::fold(F_, cctx);
2796 
2797   EXPECT_EQ(F_->getNodes().size(), 4);
2798   EXPECT_FALSE(llvm::isa<ChannelShuffleNode>(save->getInput()));
2799 }
2800 
2801 class MockBackendWithFusion : public MockBackend {
2802   bool supportsFusedActivation(Node *parent, Node *activation) const override {
2803     switch (parent->getKind()) {
2804     case Kinded::Kind::ConvolutionNodeKind:
2805       switch (activation->getKind()) {
2806       case Kinded::Kind::ReluNodeKind:
2807       case Kinded::Kind::SigmoidNodeKind:
2808       case Kinded::Kind::TanhNodeKind:
2809         return true;
2810       default:
2811         return false;
2812       }
2813     default:
2814       return false;
2815     }
2816   }
2817 };
2818 
2819 #define CONV_ACTIVATION_TEST(ACTIVATION_, CREATOR_)                            \
2820   TEST_F(GraphFold, FoldConv##ACTIVATION_##Activation) {                       \
2821     auto *A =                                                                  \
2822         mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false); \
2823     ConvolutionNode *CV =                                                      \
2824         F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);                  \
2825     auto *AN = F_->CREATOR_(#ACTIVATION_, CV);                                 \
2826     SaveNode *SN = F_->createSave("ret", AN);                                  \
2827                                                                                \
2828     EXPECT_EQ(F_->getNodes().size(), 3);                                       \
2829                                                                                \
2830     CompilationContext cctx;                                                   \
2831     auto B = MockBackendWithFusion();                                          \
2832     ::glow::fold(F_, cctx, &B);                                                \
2833                                                                                \
2834     ConvolutionNode *fusedCV =                                                 \
2835         llvm::dyn_cast<ConvolutionNode>(SN->getInput());                       \
2836     ASSERT_TRUE(fusedCV);                                                      \
2837     EXPECT_EQ(fusedCV->getFusedActivation(), FusedActivation::ACTIVATION_);    \
2838   }
2839 
2840 CONV_ACTIVATION_TEST(RELU, createRELU);
2841 CONV_ACTIVATION_TEST(SIGMOID, createSigmoid);
2842 CONV_ACTIVATION_TEST(TANH, createTanh);
2843 
2844 #undef CONV_ACTIVATION_TEST
2845 
2846 /// This test ensures that if there is a RescaleNode whose input has multiple
2847 /// users, the input is not cloned, as cloning would duplicate the node.
2848 TEST_F(GraphOptz, MultipleUsersRescaleCombineNoOpt) {
2849   auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 1, 0);
2850   auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 2, 1);
2851 
2852   Node *LHS =
2853       mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.4, 0, "LHS", true);
2854   Node *RHS =
2855       mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.3, 0, "RHS", true);
2856 
2857   AddNode *AN = F_->createAdd("qAdd", opOutTy, LHS, RHS);
2858   RescaleQuantizedNode *RQN =
2859       F_->createRescaleQuantized("rsAdd", AN, rescaleOutTy);
2860   SaveNode *saveRQN = F_->createSave("saveRQN", RQN);
2861   SaveNode *saveAN = F_->createSave("saveAN", AN);
2862 
2863   EXPECT_EQ(F_->getNodes().size(), 4);
2864 
2865   ::glow::optimize(F_, CompilationMode::Infer);
2866 
2867   // The graph should be unchanged.
2868   EXPECT_EQ(F_->getNodes().size(), 4);
2869   EXPECT_EQ(saveRQN->getInput().getNode(), RQN);
2870   EXPECT_EQ(RQN->getInput().getNode(), AN);
2871   EXPECT_EQ(saveAN->getInput().getNode(), AN);
2872   EXPECT_EQ(AN->getLHS().getNode(), LHS);
2873   EXPECT_EQ(AN->getRHS().getNode(), RHS);
2874 }
2875 
2876 /// This test ensures that fusing of rescale into MatMul is done.
2877 TEST_F(GraphOptz, FuseRescaleIntoMatMul) {
2878   auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 1, 0);
2879   auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 2, 1);
2880 
2881   Placeholder *LHS =
2882       mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.4, 0, "LHS", true);
2883   Placeholder *RHS =
2884       mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.3, 0, "RHS", true);
2885 
2886   RescaleQuantizedNode *LHSR =
2887       F_->createRescaleQuantized("rs1", LHS, rescaleOutTy);
2888   RescaleQuantizedNode *RHSR =
2889       F_->createRescaleQuantized("rs2", RHS, rescaleOutTy);
2890   MatMulNode *MN = F_->createMatMul("qMatMul", opOutTy, LHSR, RHSR);
2891   SaveNode *SN = F_->createSave("save", MN);
2892 
2893   // All rescales must be fused into arithmetic operations above.
2894   ::glow::optimize(F_, CompilationMode::Infer);
2895 
2896   // Only the MatMul and Save should be left.
2897   EXPECT_EQ(F_->getNodes().size(), 2);
2898   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::RescaleQuantizedNodeKind), 0);
2899 
2900   MatMulNode *newMN = llvm::dyn_cast<MatMulNode>(SN->getInput());
2901   ASSERT_TRUE(newMN);
2902   Placeholder *LPH = llvm::dyn_cast<Placeholder>(newMN->getLHS());
2903   EXPECT_EQ(LPH, LHS);
2904   Placeholder *RPH = llvm::dyn_cast<Placeholder>(newMN->getRHS());
2905   EXPECT_EQ(RPH, RHS);
2906 }
2907 
2908 TEST_F(GraphOptz, sinkRescaledQuantizedNode) {
2909   // Check that we eliminate rescale nodes by sinking them into other
2910   // operators.
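  // Rough expectation: the two Rescales below sit among shape-only ops (Slice,
  // Reshape, Transpose) and a MaxPool; sinking lets them travel toward the
  // output and combine, so a single Rescale should end up feeding the Save
  // (verified at the end of this test).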
2911   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
2912                                        "input", true);
2913 
2914   // slice -> rescale -> reshape -> rescale -> transpose -> maxpool -> save.
2915   auto *slice = F_->createSlice("slice", input, {0, 0}, {2, 4});
2916   auto *rescale = F_->createRescaleQuantized(
2917       "rescale", slice, mod_.uniqueType(ElemKind::Int8QTy, {2, 4}, 0.4, 10));
2918   auto *reshape = F_->createReshape("reshape", rescale, {1, 2, 2, 2});
2919   auto *rescale2 = F_->createRescaleQuantized(
2920       "rescale", reshape,
2921       mod_.uniqueType(ElemKind::Int8QTy, {1, 2, 2, 2}, 0.3, 9));
2922   auto *transpose = F_->createTranspose("transpose", rescale2, {0, 2, 3, 1});
2923   auto *maxpool =
2924       F_->createMaxPool("maxpool", transpose, {2, 2}, {1, 1}, {0, 0, 0, 0});
2925   auto *save = F_->createSave("ret", maxpool->getResult());
2926 
2927   EXPECT_EQ(F_->getNodes().size(), 7);
2928   ::glow::optimize(F_, CompilationMode::Infer);
2929   EXPECT_EQ(F_->getNodes().size(), 6);
2930   // Check that rescale sank all the way down to the save node.
2931   EXPECT_TRUE(llvm::dyn_cast<RescaleQuantizedNode>(save->getInput()));
2932 }
2933 
2934 TEST_F(GraphOptz, mergeRescaleWithArithmeticNode) {
2935   // Check that Arithmetic operations can be merged with the Rescale.
2936   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
2937                                        "input", true);
2938 
2939   auto *rescale1 = F_->createRescaleQuantized(
2940       "rescale", input, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.4, 11));
2941   auto *add = F_->createAdd("add", rescale1, rescale1);
2942   auto *rescale2 = F_->createRescaleQuantized(
2943       "rescale", add, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.3, 11));
2944   auto *sub = F_->createSub("sub", rescale2, rescale2);
2945   auto *rescale3 = F_->createRescaleQuantized(
2946       "rescale", sub, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.2, 11));
2947   auto *mul = F_->createMul("mul", rescale3, rescale3);
2948   auto *rescale4 = F_->createRescaleQuantized(
2949       "rescale", mul, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.1, 11));
2950   auto *div = F_->createDiv("div", rescale4, rescale4);
2951   F_->createSave("save", div);
2952 
2953   EXPECT_EQ(F_->getNodes().size(), 9);
2954   ::glow::optimize(F_, CompilationMode::Infer);
2955   EXPECT_EQ(F_->getNodes().size(), 5);
2956 }
2957 
2958 /// Check that Relu can be merged with Rescale.
2959 TEST_F(GraphOptz, mergeRescaleWithRelu) {
2960   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
2961                                        "input", false);
2962 
2963   auto *rescale1 = F_->createRescaleQuantized(
2964       "rescale", input, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.4, 11));
2965   auto *relu = F_->createRELU("relu", rescale1);
2966   F_->createSave("save", relu);
2967 
2968   // Rescale, RELU, Save nodes.
2969   EXPECT_EQ(F_->getNodes().size(), 3);
2970 
2971   ::glow::optimize(F_, CompilationMode::Infer);
2972 
2973   // RELU, Save nodes left; Rescale merged into RELU.
2974   EXPECT_EQ(F_->getNodes().size(), 2);
2975   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::RescaleQuantizedNodeKind), 0);
2976   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 1);
2977 }
2978 
2979 // Check that we are able to merge some small matmuls into a larger one.
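// Sketch of why the merge is legal: every small MatMul shares the same RHS, so
//   Concat_0(MatMul(S_0, W), ..., MatMul(S_9, W)) == MatMul(Concat_0(S_0, ..., S_9), W),
// letting the optimizer emit one large MatMul instead of ten.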
2980 TEST_F(GraphOptz, mergeMatMulNodes) {
2981   Node *input =
2982       mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input", true);
2983   Node *weight =
2984       mod_.createPlaceholder(ElemKind::FloatTy, {10, 10}, "weight", true);
2985 
2986   // Split the input to a bunch of small slices.
2987   std::vector<NodeValue> inputs;
2988   for (dim_t i = 0; i < 10; i++) {
2989     auto *K = F_->createSlice("extract", input, {i, 0, 0}, {i + 1, 10, 10});
2990     auto *R = F_->createReshape("reshape", K, {10, 10});
2991     auto *MM = F_->createMatMul("mm", R, weight);
2992     inputs.push_back(MM);
2993   }
2994 
2995   auto *cc = F_->createConcat("merge", inputs, 0);
2996   F_->createSave("save", cc);
2997 
2998   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 10);
2999   ::glow::optimize(F_, CompilationMode::Infer);
3000 
3001   // Check that all of the matmuls are merged into a single matmul node.
3002   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 1);
3003 }
3004 
3005 // Check that we are able to merge batched adds.
3006 TEST_F(GraphOptz, mergeBANodes) {
3007   Node *input =
3008       mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input", true);
3009   Node *slice =
3010       mod_.createPlaceholder(ElemKind::FloatTy, {10, 10}, "weight", true);
3011 
3012   // Split the input to a bunch of small slices.
3013   std::vector<NodeValue> inputs;
3014   for (dim_t i = 0; i < 10; i++) {
3015     auto *K = F_->createSlice("extract", input, {i, 0, 0}, {i + 1, 10, 10});
3016     auto *MM = F_->createBatchedAdd("BA", K, slice);
3017     inputs.push_back(MM);
3018   }
3019 
3020   auto *cc = F_->createConcat("merge", inputs, 0);
3021   F_->createSave("save", cc);
3022 
3023   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 10);
3024   ::glow::optimize(F_, CompilationMode::Infer);
3025 
3026   // Check that all of the batched-adds are merged into a single node.
3027   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 1);
3028 }
3029 
3030 /// Check that EliminateNoop optimization pass removes nodes which don't do
3031 /// anything useful.
3032 TEST_F(GraphOptz, eliminateNoop) {
3033   std::vector<dim_t> shape = {1, 2, 2, 3};
3034   Placeholder *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, shape, 0.004,
3035                                                0, "input", false);
3036   Placeholder *input2 = mod_.createPlaceholder(ElemKind::Int8QTy, shape, 0.004,
3037                                                0, "input", false);
3038   auto *cond = mod_.createConstant(ElemKind::BoolTy, shape, "input1");
3039   cond->getHandle<bool>() = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
3040 
3041   auto *select = F_->createSelect("select", cond, input1, input2);
3042   auto *slice = F_->createSlice("slice", select, {0, 0, 0, 0}, shape);
3043   auto *tile = F_->createTile("tile", slice, 1, 1);
3044   auto *pad = F_->createPad("pad", tile, tile->getResult().getType(), 0,
3045                             {0, 0, 0, 0, 0, 0, 0, 0}, 0);
3046   auto *avgPool = F_->createAvgPool("avgpool", pad, 1, 1, 0);
3047   auto *maxPool = F_->createMaxPool("maxpool", avgPool, 1, 1, 0);
3048 
3049   F_->createSave("save", maxPool->getResult());
3050 
3051   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SelectNodeKind), 1);
3052   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 1);
3053   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TileNodeKind), 1);
3054   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::PadNodeKind), 1);
3055   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AvgPoolNodeKind), 1);
3056   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MaxPoolNodeKind), 1);
3057 
3058   optimizedF_ = optimizeFunction(F_);
3059 
3060   // Check that all nodes except for Save are eliminated.
3061   EXPECT_EQ(optimizedF_->getNodes().size(), 1);
3062 
3063   bindings_.allocate(mod_.getPlaceholders());
3064   bindings_.get(input1)->getHandle<int8_t>().randomize(-1.0, 1.0,
3065                                                        mod_.getPRNG());
3066   bindings_.get(input2)->getHandle<int8_t>().randomize(-1.0, 1.0,
3067                                                        mod_.getPRNG());
3068 
3069   checkNumericalEquivalence();
3070 }
3071 
3072 // Check that we are able to replace
3073 // Add(I, Tile(B)) with BatchedAdd(I, B).
3074 TEST_F(GraphOptz, FoldTileAddIntoBatchedAdd) {
3075   auto *batch =
3076       mod_.createPlaceholder(ElemKind::FloatTy, {3, 1, 2}, "batch", false);
3077   auto *added = mod_.createConstant(ElemKind::FloatTy, {1, 1, 2}, "added");
3078   auto *addedTiled = F_->createTile("addedTiled", added, 3, 0);
3079   auto *add = F_->createAdd("add", batch, addedTiled);
3080   auto *save = F_->createSave("save", add);
3081   auto *output = save->getPlaceholder();
3082 
3083   bindings_.allocate(batch)->getHandle() = {2, 2, 3, 3, 4, 4};
3084   added->getPayloadMutable().getHandle() = {1, 1};
3085   bindings_.allocate(output);
3086 
3087   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TileNodeKind), 1);
3088   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 1);
3089   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 0);
3090 
3091   ASSERT_TRUE(F_->verify());
3092 
3093   // Currently the FoldTileAddIntoBatchedAdd opt which we're testing here is not
3094   // part of the default optimization pipeline. Create a local version of the
3095   // pipeline with that pass included.
3096   auto p = createDefaultGraphOptimizationPassPipeline();
3097   p->pushFront({FunctionPassID::FoldTileAddIntoBatchedAdd});
3098   FunctionPassManager FPM("opt", std::move(p));
3099   FPM.run(F_, CompilationContext());
3100   ASSERT_TRUE(F_->verify());
3101 
3102   // Check that the Tile node and the Add node are replaced by
3103   // a BatchedAdd node.
3104   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TileNodeKind), 0);
3105   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 0);
3106   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 1);
3107 
3108   // Verify the correctness of the input to BatchedAdd operator.
3109   // The correctness of BatchedAdd operator itself is verified
3110   // by operator's unit tests.
3111   Tensor expectedBatch(ElemKind::FloatTy, {3, 1, 2});
3112   expectedBatch.getHandle() = {2, 2, 3, 3, 4, 4};
3113   Tensor expectedSlice(ElemKind::FloatTy, {1, 2});
3114   expectedSlice.getHandle() = {1, 1};
3115   for (auto &node : F_->getNodes()) {
3116     auto *recvdBANode = llvm::dyn_cast<BatchedAddNode>(&node);
3117     if (!recvdBANode) {
3118       continue;
3119     }
3120     auto *recvdBatch = llvm::dyn_cast<Placeholder>(recvdBANode->getBatch());
3121     ASSERT_TRUE(recvdBatch);
3122     auto *recvdSlice = llvm::dyn_cast<Constant>(recvdBANode->getSlice());
3123     ASSERT_TRUE(recvdSlice);
3124     EXPECT_TRUE(recvdBatch->dims().equals({3, 1, 2}));
3125     EXPECT_TRUE(recvdSlice->dims().equals({1, 2}));
3126     EXPECT_TRUE(bindings_.get(recvdBatch)->isEqual(expectedBatch));
3127     EXPECT_TRUE(recvdSlice->getPayload().isEqual(expectedSlice));
3128     break;
3129   }
3130 }
3131 
3132 /// Test Concat(Slice, ..., Slice) opt works correctly. If \p reverseOrder then
3133 /// the optimization is inapplicable and should not occur.
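/// The underlying identity: Concat_0(Slice_0(X), ..., Slice_9(X)) == X when the
/// slices cover X contiguously and in order along the concat axis; with
/// \p reverseOrder the slices arrive out of order, so the rewrite must not fire.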
3134 static void testConcatElim(Module &mod, Function *F, Function *&optimizedF,
3135                            PlaceholderBindings &bindings, bool reverseOrder) {
3136   Placeholder *input =
3137       mod.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input", true);
3138   bindings.allocate(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
3139 
3140   // Split the input to a bunch of small slices.
3141   std::array<NodeValue, 10> inputs;
3142   for (dim_t i = 0; i < 10; i++) {
3143     dim_t idx = reverseOrder ? 9 - i : i;
3144     inputs[i] =
3145         F->createSlice("extract", input, {idx, 0, 0}, {idx + 1, 10, 10});
3146   }
3147 
3148   auto *cc = F->createConcat("merge", inputs, 0);
3149   F->createSave("save", cc);
3150 
3151   EXPECT_EQ(countNodeKind(F, Kinded::Kind::SliceNodeKind), 10);
3152 
3153   optimizedF = optimizeFunction(F);
3154 
3155   // Check that either the concat and slices are gone if the optimization was
3156   // applicable, or otherwise that they're still there.
3157   EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::ConcatNodeKind),
3158             reverseOrder ? 1 : 0);
3159   EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::SliceNodeKind),
3160             reverseOrder ? 10 : 0);
3161 }
3162 
3163 // Check that we are able to eliminate concat nodes.
3164 TEST_F(GraphOptz, concatElim) {
3165   testConcatElim(mod_, F_, optimizedF_, bindings_, /* reverseOrder */ false);
3166   checkNumericalEquivalence(0.0f);
3167 }
3168 
3169 // Check that when the order of the Slices is reversed no optimization kicks in.
3170 TEST_F(GraphOptz, concatElimReverseOrder) {
3171   testConcatElim(mod_, F_, optimizedF_, bindings_, /* reverseOrder */ true);
3172   checkNumericalEquivalence(0.0f);
3173 }
3174 
3175 /// Check that we are able to eliminate concat nodes with redundant arithmetic
3176 /// ops in the way.
3177 TEST_F(GraphOptz, concatArithElim) {
3178   auto *input =
3179       mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input", true);
3180   bindings_.allocate(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3181 
3182   Type t(ElemKind::FloatTy, {1, 10, 10});
3183   Node *one = F_->createSplat("one", &t, 1.0);
3184   Node *zero = F_->createSplat("zero", &t, 0.0);
3185 
3186   // Split the input to a bunch of small slices.
3187   std::vector<NodeValue> inputs;
3188   for (dim_t i = 0; i < 10; i++) {
3189     auto *K = F_->createSlice("extract", input, {i, 0, 0}, {i + 1, 10, 10});
3190     // Insert the nodes in reverse order to make sure that we can catch
3191     // non-consecutive graph-order slices.
3192     Node *N = K;
3193     switch (i) {
3194     case 0:
3195       N = F_->createAdd("add0", K, zero);
3196       break;
3197     case 1:
3198       N = F_->createSub("sub0", K, zero);
3199       break;
3200     case 2:
3201       N = F_->createAdd("add_0", zero, K);
3202       break;
3203     case 3:
3204       N = F_->createMul("mul1", K, one);
3205       break;
3206     case 4:
3207       N = F_->createDiv("div1", K, one);
3208       break;
3209     case 5:
3210       N = F_->createMul("mul_1", one, K);
3211       break;
3212     default:
3213       break;
3214     }
3215     inputs.push_back(N);
3216   }
3217 
3218   auto *cc = F_->createConcat("merge", inputs, 0);
3219   F_->createSave("save", cc);
3220   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 10);
3221   optimizedF_ = optimizeFunction(F_);
3222 
3223   // Check that the concat node is gone.
3224   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 0);
3225   checkNumericalEquivalence(0.0f);
3226 }
3227 
3228 /// Check that we are able to eliminate concat followed by slices on axis
3229 /// \p dim under certain conditions.
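/// The rewrite replaces each Slice(Concat(A_0, ..., A_{N-1})) with the matching
/// input A_i when the slice exactly covers that input's extent along \p dim; the
/// start/end offsets built below are the prefix sums of the input sizes 1..N.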
3230 static void testConcatSliceElim(Module &mod, Function *F, Function *&optimizedF,
3231                                 PlaceholderBindings &bindings, size_t dim) {
3232   constexpr size_t N = 5;
3233   std::array<NodeValue, N> inputs;
3234   std::vector<dim_t> inShape = {10, 20};
3235   inShape.insert(inShape.begin() + dim, 0);
3236   for (dim_t i = 0; i < N; i++) {
3237     inShape[dim] = 1 + i;
3238     auto *P = mod.createPlaceholder(ElemKind::FloatTy, inShape, "in", true);
3239     bindings.allocate(P)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
3240     inputs[i] = P;
3241   }
3242   auto *CN = F->createConcat("merge", inputs, dim);
3243 
3244   // Split the concat to a bunch of slices of the same shape as the concat
3245   // inputs and on the same axis.
3246   std::vector<dim_t> startShape = {0, 0, 0};
3247   std::vector<dim_t> endShape = {10, 20};
3248   endShape.insert(endShape.begin() + dim, 0);
3249   for (dim_t i = 0; i < N; i++) {
3250     startShape[dim] = (i * (i + 1)) / 2;
3251     endShape[dim] = ((i + 1) * (i + 2)) / 2;
3252     auto *SN = F->createSlice("extract", CN, startShape, endShape);
3253     F->createSave("save", SN);
3254   }
3255 
3256   // We created a concat followed by N slices of its results.
3257   EXPECT_EQ(countNodeKind(F, Kinded::Kind::SliceNodeKind), N);
3258   EXPECT_EQ(countNodeKind(F, Kinded::Kind::ConcatNodeKind), 1);
3259 
3260   optimizedF = optimizeFunction(F);
3261 
3262   // Check that the concat and slices are gone.
3263   EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::ConcatNodeKind), 0);
3264   EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::SliceNodeKind), 0);
3265 }
3266 
3267 TEST_F(GraphOptz, concatSliceElimInnerDim) {
3268   testConcatSliceElim(mod_, F_, optimizedF_, bindings_, 0);
3269   checkNumericalEquivalence(0.0f);
3270 }
3271 
3272 TEST_F(GraphOptz, concatSliceElimMiddleDim) {
3273   testConcatSliceElim(mod_, F_, optimizedF_, bindings_, 1);
3274   checkNumericalEquivalence(0.0f);
3275 }
3276 
3277 TEST_F(GraphOptz, concatSliceElimOuterDim) {
3278   testConcatSliceElim(mod_, F_, optimizedF_, bindings_, 2);
3279   checkNumericalEquivalence(0.0f);
3280 }
3281 
3282 /// Check the interaction between Slices(Concat) and Concat(Slices) optimizations
3283 /// to make sure they work nicely together. Builds Concat(Slices(Concat)) and
3284 /// expects a single Concat after optimizations.
3285 TEST_F(GraphOptz, concatSliceElimMultiConcat) {
3286   std::array<NodeValue, 4> inputs;
3287   for (size_t i = 0; i < 4; i++) {
3288     auto *P = mod_.createPlaceholder(ElemKind::FloatTy, {2, 4},
3289                                      "in_" + std::to_string(i), false);
3290     bindings_.allocate(P)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3291     inputs[i] = P;
3292   }
3293   auto *CN0 = F_->createConcat("merge0", inputs, /* axis */ 1);
3294 
3295   auto *SN0 = F_->createSlice("slice0", CN0, {0, 0}, {2, 4});
3296   auto *SN1 = F_->createSlice("slice1", CN0, {0, 4}, {2, 8});
3297   auto *SN2 = F_->createSlice("slice2", CN0, {0, 8}, {2, 12});
3298   auto *SN3 = F_->createSlice("slice3", CN0, {0, 12}, {2, 16});
3299 
3300   auto *CN1 = F_->createConcat("merge1", {SN1, SN0, SN3, SN2}, /* axis */ 1);
3301   F_->createSave("save", CN1);
3302 
3303   // We created a concat followed by 4 slices of its results followed by another
3304   // concat.
3305   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 2);
3306   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 4);
3307 
3308   optimizedF_ = optimizeFunction(F_);
3309 
3310   // Check that one of the two concats and all of the slices are gone.
3311   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
3312   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SliceNodeKind), 0);
3313 
3314   checkNumericalEquivalence(0.0f);
3315 }
3316 
3317 // Check the transformation Concat(Reshape(x) * N) -> Reshape(Concat(x * N)).
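// Roughly, the rewrite applies when all Reshape inputs share a shape and the
// concat dimension of the reshaped tensors maps onto a single dimension of the
// pre-reshape shape (the leading/trailing sizes discussed in the comments
// below); the N input reshapes then collapse into one Reshape after the concat.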
3318 TEST_F(GraphOptz, concatReshapes) {
3319   const dim_t shape1[] = {2, 5, 2, 1, 20};
3320   const dim_t shape2[] = {10, 2, 2, 10};
3321   const dim_t shape3[] = {5, 80};
3322   llvm::SmallVector<NodeValue, 10> inputs1;
3323   llvm::SmallVector<NodeValue, 10> inputs2;
3324   for (size_t i = 0; i < 10; i++) {
3325     // 10 reshape nodes that transform from {2,5,2,1,20} to {10,2,2,10}.
3326     // And a ConcatNode concatenates the outputs of reshape at 2nd dim.
3327     // The optimization would kick in, as the size of trailing dimensions of
3328     // original ConcatNode (before opt) is 20, and the size of leading
3329     // dimensions of original ConcatNode (before opt) is 10.
3330     Node *var = F_->getParent()->createPlaceholder(
3331         ElemKind::FloatTy, shape1, "input" + std::to_string(i), true);
3332     auto *RN = F_->createReshape("reshape" + std::to_string(i), var, shape2);
3333     inputs1.push_back(RN);
3334   }
3335   auto *concatNode1 = F_->createConcat("concat", inputs1, 1);
3336   for (size_t i = 0; i < 10; i++) {
3337     // 10 reshape nodes that transform from {5,80} to {10,2,2,10}.
3338     // And a ConcatNode concatenates the outputs of reshape at 2nd dim.
3339     // The optimization would NOT kick in, as we cannot find the dim that
3340     // makes the leading/trailing dims same as in the case of the original
3341     // concat node.
3342     Node *var = F_->getParent()->createPlaceholder(
3343         ElemKind::FloatTy, shape3, "input" + std::to_string(i), true);
3344     auto *RN = F_->createReshape("reshape" + std::to_string(i), var, shape2);
3345     inputs2.push_back(RN);
3346   }
3347   auto *concatNode2 = F_->createConcat("concat", inputs2, 1);
3348   auto outputShape = concatNode1->getResult().dims();
3349   // Need to drop the references held in the input vectors, otherwise the user
3350   // count of those nodes would always be positive, preventing DCE from removing them.
3351   inputs1.clear();
3352   inputs2.clear();
3353 
3354   auto *addNode = F_->createAdd("add", concatNode1, concatNode2);
3355   auto *O = F_->createSave("ret", addNode);
3356 
3357   EXPECT_EQ(F_->getNodes().size(), 24);
3358 
3359   ::glow::optimize(F_, CompilationMode::Infer);
3360 
3361   // After optimization, we expect to see only 15 nodes. All 10 of the
3362   // reshapes that were the inputs to the first original concat node
3363   // (concatNode1) are removed, and a single new reshape is added after the
3364   // new concat.
3365   EXPECT_EQ(F_->getNodes().size(), 15);
3366 
3367   // concatNode1 should not exist any more.
3368   EXPECT_FALSE(functionContainsNode(F_, concatNode1));
3369   // concatNode2 should still exist.
3370   EXPECT_TRUE(functionContainsNode(F_, concatNode2));
3371 
3372   // The first input of addNode should be a Reshape node now, with the same
3373   // result shape of concatNode1.
3374   auto *newAddNode = llvm::dyn_cast<AddNode>(O->getInput());
3375   ASSERT_TRUE(newAddNode);
3376   auto *newRN = llvm::dyn_cast<ReshapeNode>(newAddNode->getLHS());
3377   ASSERT_TRUE(newRN);
3378   EXPECT_TRUE(newRN->getResult().getType()->dims().equals(outputShape));
3379 
3380   // The input of newRN should be a ConcatNode now.
3381   auto *newCN = llvm::dyn_cast<ConcatNode>(newRN->getInput());
3382   ASSERT_TRUE(newCN);
3383 }
3384 
3385 // Make sure we do not try to optimize concat2(dim1, concat1(dim2, X, Y),
3386 // Z)
3387 // -> concat(dim1, X, Y, Z) when concat1 has multiple users.
3388 TEST_F(GraphOptz, ConcatSimplificationNegative) {
3389   const dim_t dim1[] = {1, 4, 4, 4};
3390   const dim_t dim2[] = {1, 4, 4, 8};
3391   auto *in1 = mod_.createPlaceholder(ElemKind::FloatTy, dim1, "in1", false);
3392   auto *in2 = mod_.createPlaceholder(ElemKind::FloatTy, dim1, "in2", false);
3393   auto *in3 = mod_.createPlaceholder(ElemKind::FloatTy, dim2, "in3", false);
3394 
3395   auto *cnc1 = F_->createConcat("cnc1", {in1, in2}, 3);
3396   auto *add1 = F_->createAdd("add1", in3, cnc1);
3397   auto *cnc2 = F_->createConcat("cnc2", {add1, cnc1}, 3);
3398   F_->createSave("ret", cnc2);
3399   EXPECT_EQ(F_->getNodes().size(), 4);
3400   ::glow::optimize(F_, CompilationMode::Infer);
3401   EXPECT_EQ(F_->getNodes().size(), 4);
3402   for (auto &n : F_->getNodes()) {
3403     if (auto *tcnc = llvm::dyn_cast<ConcatNode>(&n)) {
3404       EXPECT_EQ(tcnc->getNumInputs(), 2);
3405     }
3406   }
3407 }
3408 
3409 /// Check that Variable CSE works correctly, combining small Variables that
3410 /// have the same data.
3411 TEST_F(GraphOptz, VarsCSE) {
3412   // Create three variables that are Private, are not trainable, and have no
3413   // writers. The first two variables have the same data, and so should be
3414   // combined via variable CSE. The third variable differs by the last value,
3415   // and so should not be combined.
3416   auto *input1 = mod_.createConstant(ElemKind::FloatTy, {10}, "input1");
3417   auto *input2 = mod_.createConstant(ElemKind::FloatTy, {10}, "input2");
3418   auto *input3 = mod_.createConstant(ElemKind::FloatTy, {10}, "input3");
3419   input1->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
3420   input2->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
3421   input3->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, -1};
3422 
3423   // Input them each to different nodes, so node CSE does not change them.
3424   auto *TN = F_->createTanh("tanh", input1);
3425   auto *SN = F_->createSigmoid("sigmoid", input2);
3426   auto *RN = F_->createRELU("relu", input3);
3427   auto *CN = F_->createConcat("concat", {TN, SN, RN}, /* axis */ 0);
3428   F_->createSave("ret", CN);
3429 
3430   // Initially there are three variables: inputs 1, 2, and 3 (the save uses a
3431   // placeholder).
3432   EXPECT_EQ(mod_.getConstants().size(), 3);
3433 
3434   cctx_.compMode = CompilationMode::Infer;
3435   // Do not perform any compile-time constant folding.
3436   cctx_.optimizationOpts.enableConstantFolding = false;
3437   ::glow::optimize(F_, cctx_);
3438 
3439   // Now only two variables are left; input1 and input2 have been combined,
3440   // but input3 has not.
3441   EXPECT_EQ(mod_.getConstants().size(), 2);
3442 
3443   // Verify that only one of input1 and input2 exists, and that input3 still
3444   // exists.
3445   Constant *varOneOrTwo = nullptr;
3446   bool foundVarThree = false;
3447   for (auto *V : mod_.getConstants()) {
3448     if (V == input1 || V == input2) {
3449       EXPECT_TRUE(varOneOrTwo == nullptr);
3450       varOneOrTwo = V;
3451     } else if (V == input3) {
3452       foundVarThree = true;
3453     }
3454   }
3455   EXPECT_TRUE(varOneOrTwo != nullptr);
3456   EXPECT_TRUE(foundVarThree);
3457 
3458   // Verify that the users of the inputs are updated correctly.
3459   EXPECT_TRUE(TN->getInput().getNode() == varOneOrTwo);
3460   EXPECT_TRUE(SN->getInput().getNode() == varOneOrTwo);
3461   EXPECT_TRUE(RN->getInput().getNode() == input3);
3462 
3463   // Verify that whichever input1/input2 is left over has two users TN and SN.
3464   EXPECT_TRUE(varOneOrTwo->getUsers().size() == 2);
3465   for (auto &U : varOneOrTwo->getUsers()) {
3466     auto *N = U.getUser();
3467     EXPECT_TRUE(N == TN || N == SN);
3468   }
3469 
3470   // Verify that input3 only has a single user RN.
3471   ASSERT_TRUE(input3->getUsers().size() == 1);
3472   EXPECT_TRUE(input3->getUsers().begin()->getUser() == RN);
3473 }
3474 
3475 TEST_F(GraphOptz, VarsCSENaN) {
3476   // Create two variables that are Private, are not trainable, have no writers
3477   // and include NaNs. The two variables have the same data, and so should
3478   // be combined via variable CSE. In particular, the NaN constants should not
3479   // prevent the variables from being combined.
3480   auto *input1 = mod_.createConstant(ElemKind::FloatTy, {5}, "input1");
3481   auto *input2 = mod_.createConstant(ElemKind::FloatTy, {5}, "input2");
3482   input1->getHandle() = {0, NAN, 2, NAN, 4};
3483   input2->getHandle() = {0, NAN, 2, NAN, 4};
3484 
3485   // Input them each to different nodes, so node CSE does not change them.
3486   auto *TN = F_->createTanh("tanh", input1);
3487   auto *SN = F_->createSigmoid("sigmoid", input2);
3488   auto *CN = F_->createConcat("concat", {TN, SN}, /* axis */ 0);
3489   F_->createSave("ret", CN);
3490 
3491   // Initially there are two variables: inputs 1 and 2 (the save uses a
3492   // placeholder).
3493   EXPECT_EQ(mod_.getConstants().size(), 2);
3494 
3495   cctx_.compMode = CompilationMode::Infer;
3496   // Do not perform any compile-time constant folding.
3497   cctx_.optimizationOpts.enableConstantFolding = false;
3498   ::glow::optimize(F_, cctx_);
3499 
3500   // Now only one variable is left; input1 and input2 have been combined.
3501   EXPECT_EQ(mod_.getConstants().size(), 1);
3502 
3503   // Verify that only one of input1 and input2 exists.
3504   Constant *varOneOrTwo = nullptr;
3505   for (auto *V : mod_.getConstants()) {
3506     if (V == input1 || V == input2) {
3507       EXPECT_TRUE(varOneOrTwo == nullptr);
3508       varOneOrTwo = V;
3509     }
3510   }
3511   EXPECT_TRUE(varOneOrTwo != nullptr);
3512 
3513   // Verify that the users of the inputs are updated correctly.
3514   EXPECT_TRUE(TN->getInput().getNode() == varOneOrTwo);
3515   EXPECT_TRUE(SN->getInput().getNode() == varOneOrTwo);
3516 
3517   // Verify that whichever input1/input2 is left over has two users TN and SN.
3518   EXPECT_TRUE(varOneOrTwo->getUsers().size() == 2);
3519   for (auto &U : varOneOrTwo->getUsers()) {
3520     auto *N = U.getUser();
3521     EXPECT_TRUE(N == TN || N == SN);
3522   }
3523 }
3524 
3525 // Verify that constant input canonicalization works correctly when the
3526 // arithmetic nodes have multiple users.
3527 TEST_F(GraphOptz, simplifyArithmeticMultipleUsers) {
3528   Node *I1 =
3529       mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input1", false);
3530 
3531   Type t(ElemKind::FloatTy, {10, 10, 10});
3532   Node *SN = F_->createSplat("one", &t, 1.0);
3533 
3534   // The splat is a constant input to add1 and add2, and is their LHS input.
3535   // We expect canonicalization to occur during optimization, moving the splat
3536   // to the RHS for both. Note that add1 has multiple users: add2 and save1.
3537   Node *AN1 = F_->createAdd("add1", SN, I1);
3538   Node *AN2 = F_->createAdd("add2", SN, AN1);
3539   SaveNode *SN1 = F_->createSave("save1", AN1);
3540   SaveNode *SN2 = F_->createSave("save2", AN2);
3541 
3542   // Five nodes in total: one splat, two adds, and two saves.
3543   EXPECT_EQ(F_->getNodes().size(), 5);
3544   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SplatNodeKind), 1);
3545   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 2);
3546   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);
3547 
3548   // input1 has a single user before optimization.
3549   EXPECT_EQ(I1->getUsers().size(), 1);
3550 
3551   // Simplify nodes will canonicalize add1 and add2, and should replace all
3552   // their users, without otherwise adding new nodes to the graph/changing the
3553   // overall structure.
3554   ::glow::optimize(F_, CompilationMode::Infer);
3555 
3556   // We should have the same five nodes: one splat, two adds, and two saves.
3557   EXPECT_EQ(F_->getNodes().size(), 5);
3558   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SplatNodeKind), 1);
3559   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 2);
3560   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);
3561 
3562   // Verify that both add nodes were canonicalized, and that the graph's shape
3563   // is the same as prior to optimization other than canonicalization.
3564   AddNode *newAN1 = llvm::dyn_cast<AddNode>(SN1->getInput().getNode());
3565   ASSERT_TRUE(newAN1 != nullptr);
3566   EXPECT_TRUE(llvm::isa<Placeholder>(newAN1->getLHS()));
3567   EXPECT_TRUE(llvm::isa<SplatNode>(newAN1->getRHS()));
3568 
3569   AddNode *newAN2 = llvm::dyn_cast<AddNode>(SN2->getInput().getNode());
3570   ASSERT_TRUE(newAN2 != nullptr);
3571   EXPECT_TRUE(llvm::isa<AddNode>(newAN2->getLHS()));
3572   EXPECT_TRUE(llvm::isa<SplatNode>(newAN2->getRHS()));
3573 
3574   EXPECT_EQ(newAN1, newAN2->getLHS());
3575 
3576   // input1 should still have a single user after optimization.
3577   EXPECT_EQ(I1->getUsers().size(), 1);
3578 }
3579 
3580 /// Test that a concat with a single input is replaced by the input.
3581 TEST_F(GraphOptz, eliminateSingleConcat) {
3582   Node *input = mod_.createPlaceholder(ElemKind::FloatTy, {10}, "input", false);
3583 
3584   ConcatNode *CN = F_->createConcat("concat1", {input}, 0);
3585   SaveNode *SN = F_->createSave("ret", CN);
3586 
3587   // The ConcatNode and SaveNode.
3588   EXPECT_EQ(F_->getNodes().size(), 2);
3589 
3590   ::glow::optimize(F_, CompilationMode::Infer);
3591 
3592   // Just the SaveNode should be left.
3593   EXPECT_EQ(F_->getNodes().size(), 1);
3594   ASSERT_TRUE(functionContainsNode(F_, SN));
3595 
3596   // Save node should just save the input.
3597   EXPECT_TRUE(SN->getInput().getNode() == input);
3598 }
3599 
3600 /// Test that a reshape of a private variable with one use has the reshape
3601 /// merged into the variable.
3602 TEST_F(GraphOptz, ReshapeConstantOneUse) {
3603   const dim_t shape[] = {10, 20};
3604   const dim_t reshape1[] = {200, 1};
3605   const dim_t reshape2[] = {200};
3606   Constant *input =
3607       F_->getParent()->createConstant(ElemKind::FloatTy, shape, "input");
3608   input->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3609 
3610   auto *R1 = F_->createReshape("reshape1", input, reshape1);
3611   auto *R2 = F_->createReshape("reshape2", R1, reshape2);
3612   auto *O = F_->createSave("ret", R2);
3613 
3614   // Before optimization, we have 2 Reshapes and a Save.
3615   EXPECT_EQ(F_->getNodes().size(), 3);
3616 
3617   // Skip ConstantFolding as it would have the same result as this opt.
3618   cctx_.optimizationOpts.enableConstantFolding = false;
3619   ::glow::optimize(F_, cctx_);
3620 
3621   // After optimization, we expect to see just a Save.
3622   EXPECT_EQ(F_->getNodes().size(), 1);
3623 
3624   // Save should have the new Variable as input.
3625   auto *V = llvm::dyn_cast<Constant>(O->getInput());
3626   ASSERT_TRUE(V);
3627   // The new Variable should have the same shape as the original second
3628   // Reshape.
3629   EXPECT_TRUE(V->getType()->dims().equals(reshape2));
3630 }
3631 
3632 /// Test that reshape node is merged into Constant in a sequence
3633 /// Reshape(Quantize(Constant)).
3634 TEST_F(GraphOptz, ReshapeQuantizeConstant) {
3635   const dim_t shape[] = {10, 20};
3636   const dim_t newShape[] = {200, 1};
3637 
3638   auto *qTy = mod_.uniqueType(ElemKind::Int8QTy, shape, 0.2, 0);
3639 
3640   auto *input =
3641       F_->getParent()->createConstant(ElemKind::FloatTy, shape, "input");
3642   auto *Q = F_->createQuantize("quantize", input, qTy);
3643   auto *R = F_->createReshape("reshape", Q, newShape);
3644   auto *S = F_->createSave("ret", R);
3645 
3646   // Skip ConstantFolding as it would have the same result as this opt.
3647   CompilationContext cctx;
3648   cctx.optimizationOpts.enableConstantFolding = false;
3649 
3650   EXPECT_EQ(F_->getNodes().size(), 3);
3651   ::glow::optimize(F_, cctx);
3652   EXPECT_EQ(F_->getNodes().size(), 2);
3653 
3654   // Constant and Quantize should have new shape.
3655   auto *newQ = llvm::dyn_cast<QuantizeNode>(S->getInput());
3656   ASSERT_TRUE(newQ);
3657   EXPECT_TRUE(newQ->getResult().dims().equals(newShape));
3658   auto *newC = llvm::dyn_cast<Constant>(newQ->getInput());
3659   ASSERT_TRUE(newC);
3660   EXPECT_TRUE(newC->getType()->dims().equals(newShape));
3661 }
3662 
3663 /// Test that Transpose is optimized into Reshape when it moves no data.
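/// Here the {1, 2, 0, 3} shuffle of a {1, 3, 2, 4} tensor only relocates the
/// size-1 batch dimension, so the linear element order is unchanged and the
/// transpose can be expressed as a plain reshape to {3, 2, 1, 4}.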
3664 TEST_F(GraphOptz, transposeIntoReshapeOptim) {
3665   auto *batch =
3666       mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 2, 4}, "batch", false);
3667   Node *T = F_->createTranspose("transpose", batch, {1, 2, 0, 3});
3668   SaveNode *O = F_->createSave("ret", T);
3669 
3670   EXPECT_EQ(F_->getNodes().size(), 2);
3671 
3672   ::glow::optimize(F_, CompilationMode::Infer);
3673   EXPECT_EQ(F_->getNodes().size(), 2);
3674 
3675   // The TransposeNode is optimized into a ReshapeNode.
3676   auto *reshape = llvm::dyn_cast<ReshapeNode>(O->getInput().getNode());
3677   ASSERT_NE(reshape, nullptr);
3678 }
3679 
3680 /// Test that transpose is merged into matmul.
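/// Sketch: MatMul(Reshape(Transpose(X)), W) is rewritten so that the transpose
/// is absorbed into the constant weights instead, i.e. the rows of W are
/// permuted to match the NHWC -> NCHW reordering; newWeightsRef below spells
/// out that permutation of {0, 1, 2, 3, 4, 5}.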
3681 TEST_F(GraphOptz, mergeTransposeIntoMatMul) {
3682   auto *input =
3683       mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 2, 3}, "input", false);
3684   auto *weights =
3685       F_->getParent()->createConstant(ElemKind::FloatTy, {6, 1}, "weights");
3686 
3687   weights->getHandle() = {0, 1, 2, 3, 4, 5};
3688   float newWeightsRef[] = {0, 2, 4, 1, 3, 5};
3689 
3690   auto *TN = F_->createTranspose("transpose", input, NHWC2NCHW);
3691   auto *RN = F_->createReshape("reshape", TN, {1, 6});
3692   auto *MMN = F_->createMatMul("matmul", RN, weights);
3693   auto *SN = F_->createSave("ret", MMN);
3694 
3695   // Transpose + Reshape + MatMul + Save.
3696   EXPECT_EQ(F_->getNodes().size(), 4);
3697 
3698   ::glow::optimize(F_, CompilationMode::Infer);
3699 
3700   // Reshape + MatMul + Save.
3701   EXPECT_EQ(F_->getNodes().size(), 3);
3702 
3703   // Check reordered weights.
3704   auto *newMMN = llvm::dyn_cast<MatMulNode>(SN->getInput());
3705   ASSERT_TRUE(newMMN != nullptr);
3706   auto *newW = llvm::dyn_cast<Constant>(newMMN->getRHS());
3707   ASSERT_TRUE(newW != nullptr);
3708   for (unsigned i = 0; i < 6; ++i) {
3709     EXPECT_EQ(newWeightsRef[i], newW->getHandle().raw(i));
3710   }
3711 }
3712 
3713 /// Test that transpose is merged into FullyConnected.
3714 TEST_F(GraphOptz, mergeTransposeIntoFC) {
3715   auto *input =
3716       mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 2, 3}, "input", false);
3717   auto *weights =
3718       F_->getParent()->createConstant(ElemKind::FloatTy, {6, 1}, "weights");
3719   auto *bias = F_->getParent()->createConstant(ElemKind::FloatTy, {1}, "bias");
3720 
3721   weights->getHandle() = {0, 1, 2, 3, 4, 5};
3722   float newWeightsRef[] = {0, 2, 4, 1, 3, 5};
3723 
3724   auto *TN = F_->createTranspose("transpose", input, NHWC2NCHW);
3725   auto *RN = F_->createReshape("reshape", TN, {1, 6});
3726   auto *FCN = F_->createFullyConnected("fc", RN, weights, bias);
3727   auto *SN = F_->createSave("ret", FCN);
3728 
3729   // Transpose + Reshape + FC + Save.
3730   EXPECT_EQ(F_->getNodes().size(), 4);
3731 
3732   ::glow::optimize(F_, CompilationMode::Infer);
3733 
3734   // Reshape + FC + Save.
3735   EXPECT_EQ(F_->getNodes().size(), 3);
3736 
3737   // Check reordered weights.
3738   auto *newFCN = llvm::dyn_cast<FullyConnectedNode>(SN->getInput());
3739   ASSERT_TRUE(newFCN != nullptr);
3740   auto *newW = llvm::dyn_cast<Constant>(newFCN->getWeights());
3741   ASSERT_TRUE(newW != nullptr);
3742   for (unsigned i = 0; i < 6; ++i) {
3743     EXPECT_EQ(newWeightsRef[i], newW->getHandle().raw(i));
3744   }
3745 }
3746 
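/// Test that convertPlaceholdersToConstants converts only the Placeholders
/// that have backing tensors and are not in the exclusion set, leaving all
/// other Placeholders untouched.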
TEST_F(GraphOptz,ConvertPlaceholdersToConstants)3747 TEST_F(GraphOptz, ConvertPlaceholdersToConstants) {
3748   auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input1", true);
3749   auto *input2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input2", true);
3750   auto *input3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input3", true);
3751   auto *save1 = F_->createSave("save1", input1);
3752   auto *save2 = F_->createSave("save2", input2);
3753   auto *save3 = F_->createSave("save3", input3);
3754 
3755   // No variables, six PHs (3 inputs, 3 saves).
3756   EXPECT_EQ(mod_.getConstants().size(), 0);
3757   EXPECT_EQ(mod_.getPlaceholders().size(), 6);
3758 
3759   // Allocate two of the three inputs, but explicitly mark input2 as
3760   // non-constant.
3761   bindings_.allocate(input1);
3762   bindings_.allocate(input2);
3763   // Don't allocate input3; keep it as a placeholder instead.
3764   ::glow::convertPlaceholdersToConstants(F_, bindings_, {input2});
3765 
3766   // input1 becomes a Constant.
3767   EXPECT_EQ(mod_.getConstants().size(), 1);
3768   EXPECT_EQ(mod_.getPlaceholders().size(), 6);
3769 
3770   EXPECT_TRUE(llvm::isa<Constant>(save1->getInput()));
3771   EXPECT_TRUE(llvm::isa<Placeholder>(save2->getInput()));
3772   EXPECT_TRUE(llvm::isa<Placeholder>(save3->getInput()));
3773 }
3774 
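/// Test that a round-trip conversion int32 -> int64 -> int32 is optimized
/// away entirely, leaving only the Save of the original Placeholder.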
TEST_F(GraphOptz,optimizeConversion_i32_i64_i32)3775 TEST_F(GraphOptz, optimizeConversion_i32_i64_i32) {
3776   auto *i32 = mod_.uniqueType(ElemKind::Int32ITy, {1});
3777   auto *i64 = mod_.uniqueType(ElemKind::Int64ITy, {1});
3778 
3779   auto *A = mod_.createPlaceholder(i32, "A", false);
3780   auto *B = F_->createConvertTo("B", A, i64);
3781   auto *C = F_->createConvertTo("C", B, i32);
3782   auto *S = F_->createSave("S", C);
3783 
3784   ::glow::optimize(F_, CompilationMode::Infer);
3785 
3786   // All casting is optimized away, only left with Save of Placeholder.
3787   EXPECT_EQ(F_->getNodes().size(), 1);
3788   EXPECT_TRUE(llvm::isa<Placeholder>(S->getInput()));
3789 }
3790 
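/// Test that a ConvertTo whose input and output types are identical is
/// eliminated, while a ConvertTo to a different type is preserved.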
TEST_F(GraphOptz,optimizeSameTypeConversions)3791 TEST_F(GraphOptz, optimizeSameTypeConversions) {
3792   auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input1", true);
3793   auto *input2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input2", true);
3794   auto *conv1 = F_->createConvertTo("cast1", input1, ElemKind::FloatTy);
3795   auto *conv2 = F_->createConvertTo("cast2", input2, ElemKind::Float16Ty);
3796   auto *save1 = F_->createSave("save1", conv1);
3797   auto *save2 = F_->createSave("save1", conv2);
3798 
3799   // convert_to1 + save1 + convert_to2 + save2 nodes.
3800   EXPECT_EQ(F_->getNodes().size(), 4);
3801   EXPECT_TRUE(llvm::isa<ConvertToNode>(save1->getInput()));
3802 
3803   ::glow::optimize(F_, CompilationMode::Infer);
3804 
3805   // save1 + convert_to2 + save2 nodes.
3806   EXPECT_EQ(F_->getNodes().size(), 3);
3807   // convert_to1 node should be eliminated, because it converts the node into
3808   // the same type.
3809   EXPECT_TRUE(llvm::isa<Placeholder>(save1->getInput()));
3810   // convert_to2 node should not be eliminated, because it converts the node
3811   // into a different type.
3812   EXPECT_TRUE(llvm::isa<ConvertToNode>(save2->getInput()));
3813   EXPECT_EQ(save2->getInput(), NodeValue(conv2));
3814 }
3815 
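/// Test that a ConvertTo between fused rowwise-quantized types is folded into
/// the Constant it converts, leaving only the Save of the converted Constant.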
TEST_F(GraphOptz,optimizeConvertingBetweenFused)3816 TEST_F(GraphOptz, optimizeConvertingBetweenFused) {
3817   // Call with dims {5, 2}, which actually creates a constant with dims {5, 10},
3818   // since 4 bytes of scale and 4 bytes of offset are appended to each row.
3819   Constant *C = createRandomFusedRowwiseQuantizedConstant(
3820       mod_, {5, 2}, "fused", /* useFusedFP16 */ false);
3821   // Converting to fused FP16 means the scale/offset take 4 fewer bytes in
3822   // total per row, so we move from {5, 10} to {5, 6}.
3823   auto newOT = mod_.uniqueType(ElemKind::UInt8FusedFP16QTy, {5, 6}, 1.0, 0);
3824   auto *CN = F_->createConvertTo("convert", C, newOT);
3825   auto *SN = F_->createSave("save", CN);
3826 
3827   ::glow::optimize(F_, CompilationMode::Infer);
3828 
3829   // Convert should be eliminated and just the save of the Constant left.
3830   EXPECT_EQ(F_->getNodes().size(), 1);
3831   Constant *convertedC = llvm::dyn_cast<Constant>(SN->getInput());
3832   ASSERT_TRUE(convertedC);
3833   EXPECT_EQ(convertedC->getElementType(), ElemKind::UInt8FusedFP16QTy);
3834 }
3835 
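/// Test that the unused Add is removed by DCE and the Transpose of a Constant
/// is folded away, leaving only a Save of the pre-transposed Constant.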
TEST_F(GraphOptz,dceBeforeOptimizeTranpose)3836 TEST_F(GraphOptz, dceBeforeOptimizeTranpose) {
3837   auto *input1 = mod_.createConstant(ElemKind::FloatTy, {5, 10}, "input1");
3838   // Create an unused node.
3839   F_->createAdd("add", input1, input1);
3840   auto *transposedInput1 = F_->createTranspose("transpose", input1, {1, 0});
3841   auto *save1 = F_->createSave("save1", transposedInput1);
3842 
3843   // add + transpose + save.
3844   EXPECT_EQ(F_->getNodes().size(), 3);
3845 
3846   ::glow::optimize(F_, CompilationMode::Infer);
3847 
3848   // A single node: save.
3849   EXPECT_EQ(F_->getNodes().size(), 1);
3850   // transpose should be eliminated and replaced by the transposed constant.
3851   EXPECT_TRUE(llvm::isa<Constant>(save1->getInput()));
3852 }
3853 
3854 /// Test that Transpose is sunk below ChannelShuffle and cancels with an
3855 /// inverse transpose below the ChannelShuffle. This test models a pattern
3856 /// that has been observed in shufflenet during graph optimization.
TEST_F(GraphOptz,sinkTransposeBelowChannelShuffleNodesAndEliminate)3857 TEST_F(GraphOptz, sinkTransposeBelowChannelShuffleNodesAndEliminate) {
3858   const dim_t inputDims[] = {3, 28, 28, 136};
3859 
3860   Node *K =
3861       mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
3862   K = F_->createTranspose("unnecessary_transpose_1", K, {0, 3, 1, 2});
3863   K = F_->createChannelShuffle("channel_shuffle", K, 4, 1);
3864   K = F_->createTranspose("unnecessary_transpose_2", K, {0, 2, 3, 1});
3865   auto *save = F_->createSave("ret", K);
3866 
3867   EXPECT_EQ(F_->getNodes().size(), 4);
3868 
3869   // Optimize away the unnecessary transposes.
3870   optimize(F_, CompilationMode::Infer);
3871 
3872   // Ensure the two unnecessary transposes are gone.
3873   ASSERT_EQ(F_->getNodes().size(), 2);
3874 
3875   // Check that the channel shuffle node is still there.
3876   auto *CSN = llvm::dyn_cast<ChannelShuffleNode>(save->getInput().getNode());
3877   ASSERT_NE(nullptr, CSN);
3878 
3879   // Ensure ChannelShuffle node has the same dimensions as the input.
3880   EXPECT_EQ(CSN->getResult().dims(), llvm::makeArrayRef(inputDims));
3881 
3882   // Ensure Group and Kernel are as expected.
3883   EXPECT_EQ(CSN->getGroup(), 4);
3884   EXPECT_EQ(CSN->getKernel(), 3);
3885 }
3886 
3887 /// Test BatchNorm sinking below Slice.
TEST_F(GraphOptz,sinkBatchNormBelowSlice)3888 TEST_F(GraphOptz, sinkBatchNormBelowSlice) {
3889   auto *inputTy = mod_.uniqueType(ElemKind::FloatTy, {1, 10, 10, 3});
3890   auto *slicedTy1 = mod_.uniqueType(ElemKind::FloatTy, {1, 8, 8, 3});
3891   auto *slicedTy2 = mod_.uniqueType(ElemKind::FloatTy, {1, 6, 6, 1});
3892 
3893   auto *input = mod_.createPlaceholder(inputTy, "input", false);
3894   auto *BN = F_->createBatchNormalization(bindings_, "batchnorm", input, 3,
3895                                           0.0001, 0.9);
3896   auto *SN1 = F_->createSlice("slice1", BN, {0, 1, 1, 0}, slicedTy1);
3897   auto *SN2 = F_->createSlice("slice2", SN1, {0, 1, 1, 1}, slicedTy2);
3898   auto *save = F_->createSave("save", SN2);
3899 
3900   EXPECT_EQ(F_->getNodes().size(), 4);
3901   ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
3902   optimizedF_ = optimizeFunction(F_);
3903   EXPECT_EQ(optimizedF_->getNodes().size(), 4);
3904 
3905   // BatchNorm should have sunk below the first Slice, but not the second one,
3906   // as it changes the channel dimension.
3907   auto *newSave =
3908       findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
3909   ASSERT_TRUE(newSave);
3910   auto *newSN2 = llvm::dyn_cast<SliceNode>(newSave->getInput());
3911   ASSERT_TRUE(newSN2);
3912   auto *newBN = llvm::dyn_cast<BatchNormalizationNode>(newSN2->getInput());
3913   ASSERT_TRUE(newBN);
3914   ASSERT_EQ(newBN->getResult().dims(), slicedTy1->dims());
3915   ASSERT_TRUE(llvm::isa<SliceNode>(newBN->getInput()));
3916 
3917   bindings_.allocate(mod_.getPlaceholders());
3918   bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3919   checkNumericalEquivalence();
3920 }
3921 
3922 /// Test that convertPlaceholdersToConstants works properly with quantized
3923 /// types.
TEST_F(GraphOptz,QuantizedFC)3924 TEST_F(GraphOptz, QuantizedFC) {
3925   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0,
3926                                        "input", false);
3927   auto *weights = mod_.createPlaceholder(ElemKind::Int8QTy, {32, 32}, 1.0, 0,
3928                                          "weights", false);
3929   auto *bias =
3930       mod_.createPlaceholder(ElemKind::Int32QTy, {32}, 1.0, 0, "bias", false);
3931   auto *output = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0,
3932                                         "output", false);
3933 
3934   auto *fc = F_->createFullyConnected("fc", input, weights, bias);
3935   F_->createSave("save", fc, output);
3936 
3937   bindings_.allocate(input);
3938   bindings_.allocate(weights);
3939   bindings_.allocate(bias);
3940   bindings_.allocate(output);
3941 
3942   glow::convertPlaceholdersToConstants(F_, bindings_, {input, output});
3943   // Two constants: weight and bias
3944   EXPECT_EQ(mod_.getConstants().size(), 2);
3945   // All four placeholders still exist in the module.  The old weight and bias
3946   // placeholders just aren't hooked up to the Graph F_.
3947   EXPECT_EQ(mod_.getPlaceholders().size(), 4);
3948 }
3949 
3950 /// Test batchedReduceMean optimization using AvgPool.
TEST_F(GraphOptz,convertReduceMean2AvgPool)3951 TEST_F(GraphOptz, convertReduceMean2AvgPool) {
3952   const dim_t dims[] = {2, 2, 2, 2};
3953 
3954   Node *A = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", false);
3955   Node *R = F_->createBatchedReduceMean("reduce.mean", A, {2, 3});
3956 
3957   SaveNode *O = F_->createSave("ret", R);
3958 
3959   EXPECT_EQ(F_->getNodes().size(), 2);
3960 
3961   ::glow::optimize(F_, CompilationMode::Infer);
3962 
3963   // Optimization adds 2 transpose nodes and one reshape node.
3964   EXPECT_EQ(F_->getNodes().size(), 5);
3965 
3966   // Expecting reshape output rather than ReduceMean.
3967   auto *RN = llvm::dyn_cast<ReshapeNode>(O->getInput());
3968   ASSERT_NE(RN, nullptr);
3969 
3970   // Expecting Transpose node before Reshape node.
3971   auto *TN = llvm::dyn_cast<TransposeNode>(RN->getInput());
3972   ASSERT_NE(TN, nullptr);
3973 
3974   // Expecting AvgPool node before Transpose node.
3975   auto *APN = llvm::dyn_cast<AvgPoolNode>(TN->getInput());
3976   ASSERT_NE(APN, nullptr);
3977 }
3978 
3979 /// Test Broadcasted RHS BatchMatMul is converted correctly to a single MatMul.
TEST_F(GraphOptz,convertBroadcastedBatchMatMulToMatMul)3980 TEST_F(GraphOptz, convertBroadcastedBatchMatMulToMatMul) {
3981   auto *lhs =
3982       mod_.createPlaceholder(ElemKind::FloatTy, {2, 3, 2}, "lhs", false);
3983   auto *rhs = mod_.createPlaceholder(ElemKind::FloatTy, {2, 1}, "rhs", false);
3984   auto *BMMN = F_->createBatchMatMul("BMM", lhs, rhs);
3985   F_->createSave("save", BMMN);
3986 
3987   // Start with a BatchMatMul, not a MatMul.
3988   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchMatMulNodeKind), 1);
3989   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 0);
3990 
3991   ::glow::optimize(F_, CompilationMode::Infer);
3992 
3993   // Optimization should replace the BatchMatMul with a single MatMul.
3994   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 1);
3995   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchMatMulNodeKind), 0);
3996 }
3997 
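/// Test that a chain of RescaleQuantized nodes following a quantized Add is
/// optimized away, leaving only two nodes (the Add and the Save).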
TEST_F(GraphOptz,dceQuantization)3998 TEST_F(GraphOptz, dceQuantization) {
3999   auto *lhs =
4000       mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5}, 0.3, 15, "lhs", false);
4001   auto *weights =
4002       mod_.createConstant(ElemKind::Int8QTy, {3, 5}, 0.3, 15, "weights");
4003 
4004   auto *add = F_->createAdd("add", lhs, weights);
4005   auto *t1 = mod_.uniqueType(ElemKind::Int8QTy, {3, 5}, 0.2, 0);
4006   auto *rs1 = F_->createRescaleQuantized("rs1", add, t1);
4007   auto *t2 = mod_.uniqueType(ElemKind::Int8QTy, {3, 5}, 0.1, 1);
4008   auto *rs2 = F_->createRescaleQuantized("rs2", rs1, t2);
4009   F_->createSave("save", rs2);
4010 
4011   ::glow::optimize(F_, CompilationMode::Infer);
4012 
4013   EXPECT_EQ(F_->getNodes().size(), 2);
4014 }
4015 
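/// Test that a Relu on an int8 tensor whose quantized range is entirely
/// non-negative (offset -128) is a no-op and is removed.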
TEST_F(GraphOptz,nopRelu)4016 TEST_F(GraphOptz, nopRelu) {
4017   auto *in = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5}, 0.3, -128, "lhs",
4018                                     false);
4019 
4020   auto *relu = F_->createRELU("relu", in);
4021   F_->createSave("save", relu);
4022 
4023   optimizedF_ = optimizeFunction(F_);
4024 
4025   EXPECT_EQ(optimizedF_->getNodes().size(), 1);
4026 
4027   bindings_.allocate(mod_.getPlaceholders());
4028   bindings_.get(in)->getHandle<int8_t>().randomize(-4, 4, mod_.getPRNG());
4029 
4030   checkNumericalEquivalence();
4031 }
4032 
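/// Helper that fills the payload of Constant \p C with \p value.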
4033 template <typename ElemTy>
setConstValue(Constant * C,ElemTy value)4034 static void setConstValue(Constant *C, ElemTy value) {
4035   Handle<ElemTy> TH = C->getPayload().getHandle<ElemTy>();
4036   TH.clear(value);
4037 }
4038 
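/// Test constant folding of a single node: calling constantFold() on the
/// constant subexpression feeding one Save folds it into a Constant, while a
/// Save that depends on a Placeholder is left untouched.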
TEST_F(GraphOptz,constantFoldSingleNode)4039 TEST_F(GraphOptz, constantFoldSingleNode) {
4040   auto *const1 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const1");
4041   auto *const2 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const2");
4042   auto *ph1 = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "input1",
4043                                      /* isTrainable */ false);
4044   setConstValue(const1, 1.0f);
4045   setConstValue(const2, 2.0f);
4046   auto *splat2 = F_->createSplat(
4047       "splat2", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 2.0f);
4048   auto *splat3 = F_->createSplat(
4049       "splat3", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 3.0f);
4050 
4051   auto *add1 = F_->createAdd("add", const1, const2);
4052   auto *mul1 = F_->createMul("mul1", add1, splat2);
4053   auto *mul2 = F_->createMul("mul2", mul1, splat3);
4054   auto *SN1 = F_->createSave("save", mul2);
4055   auto *add3 = F_->createAdd("add", const1, ph1);
4056   auto *SN2 = F_->createSave("save", add3);
4057 
4058   // Perform constant folding for a specific node.
4059   std::vector<Constant *> constResults =
4060       constantFold(SN1->getInput().getNode());
4061 
4062   EXPECT_EQ(constResults.size(), 1);
4063   SN1->getInput().replaceAllUsesOfWith(constResults[0]);
4064   // Second save should be unaffected.
4065   EXPECT_FALSE(llvm::isa<Constant>(SN2->getInput()));
4066   // First save should have been constant folded.
4067   EXPECT_TRUE(llvm::isa<Constant>(SN1->getInput()));
4068   Constant *C = llvm::dyn_cast<Constant>(SN1->getInput());
4069   auto CH = C->getHandle();
4070   // The expected result should be: (1 + 2) * 2 * 3 = 18.
4071   EXPECT_EQ(CH.at({0, 0}), 18.0f);
4072   EXPECT_EQ(CH.at({0, 1}), 18.0f);
4073   EXPECT_EQ(CH.at({1, 0}), 18.0f);
4074   EXPECT_EQ(CH.at({1, 1}), 18.0f);
4075 }
4076 
4077 /// Test that we correctly record a single constant folding subgraph that has a
4078 /// single output.
TEST_F(GraphOptz,constantFoldWithRecordSingleChain)4079 TEST_F(GraphOptz, constantFoldWithRecordSingleChain) {
4080   Placeholder *I =
4081       mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input",
4082                              /* isTrainable */ false);
4083   Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight");
4084   ClipNode *clipW = F_->createClip("clip", W, -5.f, 5.f);
4085   ConvertToNode *convertW =
4086       F_->createConvertTo("conv", clipW, ElemKind::Float16Ty);
4087   TransposeNode *transposeW =
4088       F_->createTranspose("transpose", convertW, {1, 0});
4089   MatMulNode *MM = F_->createMatMul("matmul", I, transposeW);
4090   SaveNode *save = F_->createSave("save", MM);
4091   Placeholder *O = save->getPlaceholder();
4092   bindings_.allocate(O);
4093 
4094   ASSERT_TRUE(F_->verify());
4095 
4096   Tensor *IT = bindings_.allocate(I);
4097   IT->getHandle<float16_t>().randomize(-10, 10, mod_.getPRNG());
4098   W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG());
4099 
4100   optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4101 
4102   ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_);
4103 
4104   runDCEPass(optimizedF_, cctx_);
4105 
4106   ASSERT_EQ(record.size(), 1);
4107   SaveNode *SN = record.begin()->second;
4108   Function *constFoldF = SN->getParent();
4109 
4110   // Expect to find a chain of Nodes based on Nodes above. Note that the clip is
4111   // lowered for the Interpreter backend which performs constant folding.
4112   EXPECT_EQ(2, countNodeKind(constFoldF, Kinded::Kind::SplatNodeKind));
4113   EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::MaxNodeKind));
4114   EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::MinNodeKind));
4115   EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::ConvertToNodeKind));
4116   EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::TransposeNodeKind));
4117 
4118   // Skip optimizations -- we just want to run them as is (otherwise we'll
4119   // constant fold them inside the optimization pipeline).
4120   cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldF);
4121   cctx_.optimizationOpts.onlyLowerFuns.insert(F_);
4122   cctx_.optimizationOpts.onlyLowerFuns.insert(optimizedF_);
4123 
4124   // Don't strip the module as we want to compare the Constant values below.
4125   EE_.setSkipModuleStrip(true);
4126 
4127   EE_.compile(cctx_);
4128   alreadyCompiled_ = true;
4129 
4130   bindings_.allocate(mod_.getPlaceholders());
4131 
4132   // Run the constant folding chain to check that we have the same constant used
4133   // by the optimized Function.
4134   EE_.run(bindings_, constFoldF->getName());
4135   Tensor *rerunT = bindings_.get(SN->getPlaceholder());
4136   ASSERT_TRUE(rerunT);
4137   auto optimizedConstants = optimizedF_->findConstants();
4138   ASSERT_EQ(optimizedConstants.size(), 1);
4139   EXPECT_TRUE(
4140       (*optimizedConstants.begin())->getPayload().isEqual(*rerunT, 0.f));
4141 
4142   // Remove the temporary constant folding Functions and their Placeholders.
4143   cleanupConstantFolding(mod_, record, &bindings_);
4144 
4145   // Now compile/run/compare F_ and optimizedF_.
4146   checkNumericalEquivalence(0.f);
4147 }
4148 
4149 /// Test that we correctly record two constant folding subgraphs, each with
4150 /// a single output.
TEST_F(GraphOptz,constantFoldWithRecordMultiChain)4151 TEST_F(GraphOptz, constantFoldWithRecordMultiChain) {
4152   Placeholder *I =
4153       mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input",
4154                              /* isTrainable */ false);
4155   Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight");
4156   ClipNode *clipW = F_->createClip("clip", W, -5.f, 5.f);
4157   ConvertToNode *convertW =
4158       F_->createConvertTo("conv", clipW, ElemKind::Float16Ty);
4159   TransposeNode *transposeW =
4160       F_->createTranspose("transpose", convertW, {1, 0});
4161   MatMulNode *MM = F_->createMatMul("matmul", I, transposeW);
4162   SaveNode *saveMM = F_->createSave("save_mm", MM);
4163   Placeholder *MMP = saveMM->getPlaceholder();
4164   bindings_.allocate(MMP);
4165 
4166   SigmoidNode *sigmoidW = F_->createSigmoid("sig", convertW);
4167   SaveNode *saveSig = F_->createSave("save_sig", sigmoidW);
4168   Placeholder *sigP = saveSig->getPlaceholder();
4169   bindings_.allocate(sigP);
4170 
4171   ASSERT_TRUE(F_->verify());
4172 
4173   Tensor *IT = bindings_.allocate(I);
4174   IT->getHandle<float16_t>().randomize(-10, 10, mod_.getPRNG());
4175   W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG());
4176 
4177   optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4178 
4179   ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_);
4180 
4181   runDCEPass(optimizedF_, cctx_);
4182 
4183   ASSERT_EQ(record.size(), 2);
4184   SaveNode *sigSN = record.begin()->second;
4185   SaveNode *transSN = std::next(record.begin())->second;
4186   if (llvm::isa<SigmoidNode>(transSN->getInput())) {
4187     std::swap(sigSN, transSN);
4188   }
4189 
4190   Function *constFoldSig = sigSN->getParent();
4191   Function *constFoldTrans = transSN->getParent();
4192 
4193   // Expect to find a chain of Nodes based on Nodes above. Note that the clip is
4194   // lowered for the Interpreter backend which performs constant folding.
4195   EXPECT_EQ(2, countNodeKind(constFoldTrans, Kinded::Kind::SplatNodeKind));
4196   EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::MaxNodeKind));
4197   EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::MinNodeKind));
4198   EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::ConvertToNodeKind));
4199   EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::TransposeNodeKind));
4200 
4201   EXPECT_EQ(2, countNodeKind(constFoldSig, Kinded::Kind::SplatNodeKind));
4202   EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::MaxNodeKind));
4203   EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::MinNodeKind));
4204   EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::ConvertToNodeKind));
4205   EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::SigmoidNodeKind));
4206 
4207   // Skip optimizations -- we just want to run them as is (otherwise we'll
4208   // constant fold them inside the optimization pipeline).
4209   cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldTrans);
4210   cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldSig);
4211   cctx_.optimizationOpts.onlyLowerFuns.insert(F_);
4212   cctx_.optimizationOpts.onlyLowerFuns.insert(optimizedF_);
4213 
4214   // Don't strip the module as we want to compare the Constant values below.
4215   EE_.setSkipModuleStrip(true);
4216 
4217   EE_.compile(cctx_);
4218   alreadyCompiled_ = true;
4219 
4220   bindings_.allocate(mod_.getPlaceholders());
4221 
4222   // Run the constant folding chain to check that we have the same constant used
4223   // by the optimized Function.
4224   EE_.run(bindings_, constFoldTrans->getName());
4225   EE_.run(bindings_, constFoldSig->getName());
4226 
4227   // Find the correct PHs for each of the constant folds we performed.
4228   Tensor *rerunTransT = bindings_.get(transSN->getPlaceholder());
4229   Tensor *rerunSigT = bindings_.get(sigSN->getPlaceholder());
4230   ASSERT_TRUE(rerunTransT);
4231   ASSERT_TRUE(rerunSigT);
4232 
4233   auto optimizedConstants = optimizedF_->findConstants();
4234   ASSERT_EQ(optimizedConstants.size(), 2);
4235   Constant *transC = *optimizedConstants.begin();
4236   Constant *sigC = *std::next(optimizedConstants.begin());
4237   // If we have the constants backwards then swap them. Note that we know
4238   // sigC must be directly saved, while transC is input to a MatMulNode.
4239   ASSERT_EQ(transC->getNumUsers(), 1);
4240   if (llvm::isa<SaveNode>(transC->getUsers().begin()->getUser())) {
4241     std::swap(transC, sigC);
4242   }
4243   EXPECT_TRUE(transC->getPayload().isEqual(*rerunTransT, 0.f));
4244   EXPECT_TRUE(sigC->getPayload().isEqual(*rerunSigT, 0.f));
4245 
4246   // Remove the temporary constant folding Functions and their Placeholders.
4247   cleanupConstantFolding(mod_, record, &bindings_);
4248 
4249   // Now compile/run/compare F_ and optimizedF_.
4250   checkNumericalEquivalence(0.f);
4251 }
4252 
4253 /// Test that we correctly record a single constant folding subgraph that has
4254 /// two outputs.
TEST_F(GraphOptz,constantFoldWithRecordSingleChainMultiOutput)4255 TEST_F(GraphOptz, constantFoldWithRecordSingleChainMultiOutput) {
4256   Constant *W = mod_.createConstant(ElemKind::FloatTy, {100}, "weight");
4257   SigmoidNode *sigmoidW = F_->createSigmoid("sig", W);
4258   ConvertToNode *convertW =
4259       F_->createConvertTo("conv", sigmoidW, ElemKind::Float16Ty);
4260   TopKNode *TK = F_->createTopK("topk", convertW, 5);
4261 
4262   SaveNode *indicesSave = F_->createSave("save_indices", TK->getIndices());
4263   Placeholder *indicesP = indicesSave->getPlaceholder();
4264   bindings_.allocate(indicesP);
4265 
4266   Placeholder *I = mod_.createPlaceholder(ElemKind::Float16Ty, {5}, "input",
4267                                           /* isTrainable */ false);
4268   AddNode *add = F_->createAdd("add", I, TK->getValues());
4269   SaveNode *addSave = F_->createSave("save_add", add);
4270   Placeholder *addP = addSave->getPlaceholder();
4271   bindings_.allocate(addP);
4272 
4273   ASSERT_TRUE(F_->verify());
4274 
4275   Tensor *IT = bindings_.allocate(I);
4276   IT->getHandle<float16_t>().randomize(-10, 10, mod_.getPRNG());
4277   W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG());
4278 
4279   optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4280 
4281   ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_);
4282 
4283   runDCEPass(optimizedF_, cctx_);
4284 
4285   ASSERT_EQ(record.size(), 2);
4286   SaveNode *indicesSN = record.begin()->second;
4287   SaveNode *addSN = std::next(record.begin())->second;
4288 
4289   // Find the correct PHs for each of the constant folds we performed.
4290   if (indicesSN->getInput().getResNo() != TopKNode::IndicesIdx) {
4291     std::swap(indicesSN, addSN);
4292   }
4293 
4294   // Expect that the two constants that we folded are from the same Function,
4295   // and that the two saves use the two different outputs from a topk.
4296   EXPECT_EQ(indicesSN->getParent(), addSN->getParent());
4297   ASSERT_TRUE(llvm::isa<TopKNode>(addSN->getInput()));
4298   ASSERT_TRUE(llvm::isa<TopKNode>(indicesSN->getInput()));
4299   EXPECT_EQ(addSN->getInput().getNode(), indicesSN->getInput().getNode());
4300 
4301   Function *constFoldF = addSN->getParent();
4302 
4303   // Expect to find a chain of Nodes based on Nodes above.
4304   EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::TopKNodeKind));
4305   EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::SigmoidNodeKind));
4306   EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::ConvertToNodeKind));
4307 
4308   // Skip optimizations -- we just want to run them as is (otherwise we'll
4309   // constant fold them inside the optimization pipeline).
4310   cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldF);
4311   cctx_.optimizationOpts.onlyLowerFuns.insert(F_);
4312   cctx_.optimizationOpts.onlyLowerFuns.insert(optimizedF_);
4313 
4314   // Don't strip the module as we want to compare the Constant values below.
4315   EE_.setSkipModuleStrip(true);
4316 
4317   EE_.compile(cctx_);
4318   alreadyCompiled_ = true;
4319 
4320   bindings_.allocate(mod_.getPlaceholders());
4321 
4322   // Run the constant folding chain to check that we have the same constant used
4323   // by the optimized Function.
4324   EE_.run(bindings_, constFoldF->getName());
4325 
4326   Tensor *rerunAddT = bindings_.get(addSN->getPlaceholder());
4327   Tensor *rerunIndicesT = bindings_.get(indicesSN->getPlaceholder());
4328   ASSERT_TRUE(rerunAddT);
4329   ASSERT_TRUE(rerunIndicesT);
4330 
4331   auto optimizedConstants = optimizedF_->findConstants();
4332   ASSERT_EQ(optimizedConstants.size(), 2);
4333   Constant *addC = *optimizedConstants.begin();
4334   Constant *indicesC = *std::next(optimizedConstants.begin());
4335 
4336   // If we have the constants backwards then swap them. Note that we know
4337   // indicesC must be directly saved, while addC is input to an AddNode.
4338   ASSERT_EQ(addC->getNumUsers(), 1);
4339   if (llvm::isa<SaveNode>(addC->getUsers().begin()->getUser())) {
4340     std::swap(addC, indicesC);
4341   }
4342   EXPECT_TRUE(addC->getPayload().isEqual(*rerunAddT, 0.f));
4343   EXPECT_TRUE(indicesC->getPayload().isEqual(*rerunIndicesT, 0.f));
4344 
4345   // Remove the temporary constant folding Functions and their Placeholders.
4346   cleanupConstantFolding(mod_, record, &bindings_);
4347 
4348   // Now compile/run/compare F_ and optimizedF_.
4349   checkNumericalEquivalence(0.f);
4350 }
4351 
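/// Test constant folding of a whole Function via the optimization pipeline,
/// including folding of a node (TopK) that has multiple results.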
TEST_F(GraphOptz,constantFoldWholeFunction)4352 TEST_F(GraphOptz, constantFoldWholeFunction) {
4353   auto *const1 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const1");
4354   auto *const2 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const2");
4355   auto *const3 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const3");
4356   auto *const4 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const4");
4357   auto *ph1 = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "input1",
4358                                      /* isTrainable */ false);
4359   setConstValue(const1, 1.0f);
4360   setConstValue(const2, 2.0f);
4361   setConstValue(const3, 3.0f);
4362   setConstValue(const4, 4.0f);
4363   auto *splat2 = F_->createSplat(
4364       "splat2", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 2.0f);
4365   auto *splat3 = F_->createSplat(
4366       "splat2", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 3.0f);
4367   auto *splat4 = F_->createSplat(
4368       "splat2", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 4.0f);
4369 
4370   auto *add1 = F_->createAdd("add", const1, const2);
4371   auto *mul1 = F_->createMul("mul1", add1, splat2);
4372   auto *mul2 = F_->createMul("mul2", mul1, splat3);
4373   auto *sub = F_->createSub("sub", mul2, const3);
4374   auto *add2 = F_->createAdd("add2", sub, const4);
4375   auto *mul3 = F_->createMul("mul3", add2, splat4);
4376   // Check compile-time constant folding for nodes with multiple results.
4377   auto *topK = F_->createTopK("topK", mul3, 2);
4378   auto *SN1_0 = F_->createSave("save", topK->getValues());
4379   auto *SN1_1 = F_->createSave("save", topK->getIndices());
4380   auto *add3 = F_->createAdd("add", const1, ph1);
4381   auto *SN2 = F_->createSave("save", add3);
4382 
4383   // Perform constant folding for a whole function.
4384   ::glow::optimize(F_, CompilationMode::Infer);
4385 
4386   EXPECT_EQ(F_->getNodes().size(), 4);
4387   // Second save should be unaffected, as its input depends on a Placeholder.
4388   EXPECT_FALSE(llvm::isa<Constant>(SN2->getInput()));
4389   // First save should have been constant folded.
4390   EXPECT_TRUE(llvm::isa<Constant>(SN1_0->getInput()));
4391   EXPECT_TRUE(llvm::isa<Constant>(SN1_1->getInput()));
4392   Constant *C = llvm::dyn_cast<Constant>(SN1_0->getInput());
4393   auto CH = C->getHandle();
4394   // The expected result should be: (((1+2) * 2 * 3 - 3) + 4) * 4 = 76
4395   EXPECT_EQ(CH.at({0, 0}), 76.0f);
4396   EXPECT_EQ(CH.at({0, 1}), 76.0f);
4397   EXPECT_EQ(CH.at({1, 0}), 76.0f);
4398   EXPECT_EQ(CH.at({1, 1}), 76.0f);
4399 }
4400 
4401 /// Test constant folding for operators which are lowered in the Interpreter
4402 /// backend.
TEST_F(GraphOptz,constantFoldWithLowering)4403 TEST_F(GraphOptz, constantFoldWithLowering) {
4404   auto *input = mod_.createConstant(ElemKind::FloatTy, {1, 6}, "input");
4405   input->getHandle() = {5, 4, 3, 2, 1, 0};
4406   auto *TN = F_->createTile("tile", input, 5, 0);
4407   auto *SN = F_->createSave("ret", TN);
4408 
4409   // Perform constant folding.
4410   EXPECT_EQ(F_->getNodes().size(), 2);
4411   ::glow::optimize(F_, CompilationMode::Infer);
4412 
4413   // Tile with its input should be folded into a single Constant node.
4414   EXPECT_EQ(F_->getNodes().size(), 1);
4415   ASSERT_TRUE(llvm::isa<Constant>(SN->getInput()));
4416 }
4417 
4418 /// Test Splitting FC into multiple FCs.
TEST_F(GraphOptz,SplitFCIntoMultipleOps)4419 TEST_F(GraphOptz, SplitFCIntoMultipleOps) {
4420   auto *input =
4421       mod_.createPlaceholder(ElemKind::FloatTy, {2, 32}, "input", false);
4422   bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
4423                                                           mod_.getPRNG());
4424   auto *weights = mod_.createConstant(ElemKind::FloatTy, {32, 850}, "weights");
4425   weights->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
4426   auto *bias = mod_.createConstant(ElemKind::FloatTy, {850}, "bias");
4427   bias->getHandle().randomize(0.0, 0.5, mod_.getPRNG());
4428   auto *output =
4429       mod_.createPlaceholder(ElemKind::FloatTy, {2, 850}, "output", false);
4430   bindings_.allocate(output);
4431 
4432   auto *fc = F_->createFullyConnected("fc", input, weights, bias);
4433   auto *save = F_->createSave("save", fc, output);
4434 
4435   ::glow::optimize(F_, CompilationMode::Infer);
4436 
4437   // This is F_ but without the vertical FC weights split performed below.
4438   optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4439 
4440   EXPECT_TRUE(::glow::executeVerticalFCWeightsSplit(F_,
4441                                                     /*numOfChunks*/ 12,
4442                                                     /*minKToSplit*/ 800));
4443   runDCEPass(F_, cctx_);
4444 
4445   // 24 Slices: 12 from bias and 12 from weights.
4446   EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
4447 
4448   EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
4449 
4450   // 12 newly created FCs.
4451   EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
4452 
4453   auto *concatNode = llvm::dyn_cast<ConcatNode>(save->getInput());
4454   ASSERT_TRUE(concatNode);
4455   // 12 FCs are connected to the concat node.
4456   EXPECT_EQ(12, concatNode->getInputs().size());
4457 
4458   // Check all of the split FCs.
4459   for (unsigned i = 0; i < 12; ++i) {
4460     auto *fc = llvm::dyn_cast<FullyConnectedNode>(concatNode->getNthInput(i));
4461     ASSERT_TRUE(fc);
4462     // The first 11 FCs have output dims {2, 71}; the last one has {2, 69}.
4463     if (i == 11) {
4464       EXPECT_TRUE(fc->getResult().dims().equals({2, 69}));
4465       EXPECT_TRUE(fc->getBias().dims().equals({69}));
4466       EXPECT_TRUE(fc->getWeights().dims().equals({32, 69}));
4467     } else {
4468       EXPECT_TRUE(fc->getResult().dims().equals({2, 71}));
4469       EXPECT_TRUE(fc->getBias().dims().equals({71}));
4470       EXPECT_TRUE(fc->getWeights().dims().equals({32, 71}));
4471     }
4472   }
4473 
4474   checkNumericalEquivalence();
4475 }
4476 
4477 /// Test model-parallel splitting of FC and Relu nodes into multiple chunks.
TEST_F(GraphOptz,ParallelizeGraph_FC_ModelParallel)4478 TEST_F(GraphOptz, ParallelizeGraph_FC_ModelParallel) {
4479   auto *input =
4480       mod_.createPlaceholder(ElemKind::FloatTy, {8, 32}, "input", false);
4481   bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
4482                                                           mod_.getPRNG());
4483   auto *weights1 = mod_.createConstant(ElemKind::FloatTy, {32, 150}, "weights");
4484   weights1->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
4485   auto *bias1 = mod_.createConstant(ElemKind::FloatTy, {150}, "bias");
4486   bias1->getHandle().randomize(0.0, 0.5, mod_.getPRNG());
4487   auto *weights2 =
4488       mod_.createConstant(ElemKind::FloatTy, {150, 150}, "weights");
4489   weights2->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
4490   auto *bias2 = mod_.createConstant(ElemKind::FloatTy, {150}, "bias");
4491   bias2->getHandle().randomize(0.0, 0.5, mod_.getPRNG());
4492   auto *output =
4493       mod_.createPlaceholder(ElemKind::FloatTy, {8, 150}, "output", false);
4494   bindings_.allocate(output);
4495 
4496   auto *fc1 = F_->createFullyConnected("fc1", input, weights1, bias1);
4497   auto *relu1 = F_->createRELU("relu1", fc1);
4498 
4499   auto *fc2 = F_->createFullyConnected("fc2", relu1, weights2, bias2);
4500   auto *relu2 = F_->createRELU("relu2", fc2);
4501   F_->createSave("save", relu2, output);
4502 
4503   ::glow::optimize(F_, CompilationMode::Infer);
4504 
4505   // This is F_ but without the parallel transformation below.
4506   optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4507 
4508   // Perform parallel transformation on F_.
4509   llvm::DenseMap<Node *, size_t> numChunks;
4510   llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
4511   numChunks[fc1] = 2;
4512   numChunks[relu1] = 2;
4513   numChunks[fc2] = 2;
4514   numChunks[relu2] = 2;
4515   parOpts[fc1] = ParallelTransformKind::Model;
4516   parOpts[relu1] = ParallelTransformKind::Model;
4517   parOpts[fc2] = ParallelTransformKind::Model;
4518   parOpts[relu2] = ParallelTransformKind::Model;
4519   std::unordered_map<Node *, ConcatNode *> replacedMap;
4520   ASSIGN_VALUE_OR_FAIL_TEST(replacedMap,
4521                             ::glow::parallelizeOps(F_, numChunks, parOpts));
4522   EXPECT_EQ(replacedMap.size(), parOpts.size());
4523 
4524   runDCEPass(F_, cctx_);
4525 
4526   EXPECT_EQ(4, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
4527   EXPECT_EQ(4, countNodeKind(F_, Kinded::Kind::ReluNodeKind));
4528 
4529   checkNumericalEquivalence();
4530 }
4531 
4532 /// Test Splitting Add into multiple Adds.
TEST_F(GraphOptz,ParallelizeGraph_Add)4533 TEST_F(GraphOptz, ParallelizeGraph_Add) {
4534   auto *input1 =
4535       mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input1", false);
4536   bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
4537                                                            mod_.getPRNG());
4538   auto *input2 =
4539       mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input2", false);
4540   bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0,
4541                                                            mod_.getPRNG());
4542   auto *output =
4543       mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output", false);
4544   bindings_.allocate(output);
4545 
4546   auto *add1 = F_->createAdd("add1", input1, input2);
4547   auto *add2 = F_->createAdd("add2", add1, add1);
4548   F_->createSave("save", add2, output);
4549 
4550   ::glow::optimize(F_, CompilationMode::Infer);
4551 
4552   // This is F_ but without the parallel transformation below.
4553   optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4554 
4555   llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
4556   parOpts[add1] = ParallelTransformKind::Data;
4557 
4558   std::unordered_map<Node *, ConcatNode *> replacedMap;
4559   ASSIGN_VALUE_OR_FAIL_TEST(
4560       replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
4561                                           parOpts, 12));
4562   EXPECT_EQ(replacedMap.size(), parOpts.size());
4563   runDCEPass(F_, cctx_);
4564 
4565   // We now have 12 Adds from add1, as well as the original add2 which is
4566   // unchanged.
4567   EXPECT_EQ(13, countNodeKind(F_, Kinded::Kind::AddNodeKind));
4568 
4569   // Each input of the 12 Adds is sliced.
4570   EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
4571 
4572   // One concat to bring all of the parallelized sliced Adds together.
4573   EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
4574 
4575   checkNumericalEquivalence();
4576 }
4577 
4578 /// Test Splitting Transpose into multiple Transposes.
TEST_F(GraphOptz,ParallelizeGraph_Transpose)4579 TEST_F(GraphOptz, ParallelizeGraph_Transpose) {
4580   auto *input =
4581       mod_.createPlaceholder(ElemKind::FloatTy, {32, 151, 64}, "input", false);
4582   bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
4583                                                           mod_.getPRNG());
4584   auto *output =
4585       mod_.createPlaceholder(ElemKind::FloatTy, {32, 64, 151}, "output", false);
4586   bindings_.allocate(output);
4587 
4588   auto *trans1 = F_->createTranspose("trans1", input, {0, 2, 1});
4589   F_->createSave("save", trans1, output);
4590 
4591   ::glow::optimize(F_, CompilationMode::Infer);
4592 
4593   // This is F_ but without the parallel transformation below.
4594   optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4595 
4596   llvm::DenseMap<Node *, size_t> numChunks;
4597   llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
4598   numChunks[trans1] = 2;
4599   parOpts[trans1] = ParallelTransformKind::Data;
4600   std::unordered_map<Node *, ConcatNode *> replacedMap;
4601   ASSIGN_VALUE_OR_FAIL_TEST(replacedMap,
4602                             ::glow::parallelizeOps(F_, numChunks, parOpts));
4603   EXPECT_EQ(replacedMap.size(), parOpts.size());
4604 
4605   runDCEPass(F_, cctx_);
4606 
4607   EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::TransposeNodeKind));
4608 
4609   checkNumericalEquivalence();
4610 }
4611 
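/// Test that a Clip is sunk below a Reshape so that the Reshape feeds the
/// Clip, keeping the same number of nodes and the same numerical result.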
TEST_F(GraphOptz,SinkClipBelowReshape)4612 TEST_F(GraphOptz, SinkClipBelowReshape) {
4613   Placeholder *in =
4614       mod_.createPlaceholder(ElemKind::FloatTy, {10}, "input", false);
4615   ClipNode *clip = F_->createClip("clip", in, 0.2, 0.8);
4616   ReshapeNode *reshape = F_->createReshape("reshape", clip, {2, 5});
4617   SaveNode *save = F_->createSave("save", reshape);
4618 
4619   optimizedF_ = optimizeFunction(F_);
4620 
4621   // Same number of nodes, just swapped order.
4622   EXPECT_EQ(F_->getNodes().size(), 3);
4623   EXPECT_EQ(optimizedF_->getNodes().size(), 3);
4624 
4625   const SaveNode *optSave =
4626       findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
4627   ASSERT_TRUE(optSave);
4628   ClipNode *newClip = llvm::dyn_cast<ClipNode>(optSave->getInput());
4629   ASSERT_TRUE(newClip);
4630   ReshapeNode *newReshape = llvm::dyn_cast<ReshapeNode>(newClip->getInput());
4631   ASSERT_TRUE(newReshape);
4632   EXPECT_EQ(newReshape->getResult().dims(), reshape->getResult().dims());
4633 
4634   bindings_.allocate(mod_.getPlaceholders());
4635   bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
4636   checkNumericalEquivalence();
4637 }
4638 
4639 /// Test that Add after ConvTranspose is folded into the bias add when the
4640 /// actual Add is a broadcast of the bias. Test \p RnL (right or left) side add.
foldConvTransposeAddIntoBiasAdd(PlaceholderBindings & bindings,Module & mod,Function * F,Function * & optF,bool RnL)4641 static void foldConvTransposeAddIntoBiasAdd(PlaceholderBindings &bindings,
4642                                             Module &mod, Function *F,
4643                                             Function *&optF, bool RnL) {
4644   dim_t batch = 2;
4645   dim_t inC = 2;
4646   dim_t outC = 5;
4647   dim_t inH = 3;
4648   dim_t inW = 3;
4649   unsigned_t kernel = 3;
4650   std::vector<uint32_t> pads = {0, 0, 0, 0};
4651   std::vector<uint32_t> stride = {1, 1};
4652 
4653   auto *input = mod.createPlaceholder(ElemKind::FloatTy, {2, inH, inW, inC},
4654                                       "input", false);
4655   auto *filter = mod.createPlaceholder(
4656       ElemKind::FloatTy, {outC, kernel, kernel, inC}, "filter", false);
4657 
4658   auto *bias = mod.createConstant(ElemKind::FloatTy, {outC}, "bias");
4659   bias->getPayloadMutable().getHandle<float>() = {1, 3, 5, 7, 9};
4660 
4661   std::pair<dim_t, dim_t> outHW = calculateConvTransposeOutputDims(
4662       inH, inW, {kernel, kernel}, stride, pads);
4663   auto outTy = mod.uniqueType(ElemKind::FloatTy,
4664                               {batch, outHW.first, outHW.second, outC});
4665 
4666   ConvTransposeNode *CTN =
4667       F->createConvTranspose("ConvTranspose", input, filter, bias, outTy,
4668                              {kernel, kernel}, stride, {0, 0, 0, 0}, 1);
4669 
4670   auto *CN = mod.createConstant(ElemKind::FloatTy,
4671                                 {batch, outHW.first, outHW.second, outC}, "c1");
4672   auto *AN = RnL ? F->createAdd("add", CN, CTN) : F->createAdd("add", CTN, CN);
4673 
4674   CN->getPayloadMutable().getHandle<float>() = {
4675       1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3,
4676       4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1,
4677       2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4,
4678       5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2,
4679       3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5,
4680       1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3,
4681       4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1,
4682       2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4,
4683       5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2,
4684       3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5,
4685       1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5};
4686 
4687   SaveNode *save = F->createSave("save", AN);
4688   bindings.allocate(save->getPlaceholder());
4689 
4690   EXPECT_EQ(F->getNodes().size(), 3);
4691   optF = optimizeFunction(F);
4692   EXPECT_EQ(optF->getNodes().size(), 2);
4693 
4694   const SaveNode *optSave =
4695       findFunctionNodeByName<SaveNode>(optF, save->getName());
4696 
4697   ConvTransposeNode *optCN =
4698       llvm::dyn_cast<ConvTransposeNode>(optSave->getInput());
4699   EXPECT_TRUE(optCN);
4700 
4701   Constant *optBias = llvm::dyn_cast<Constant>(optCN->getBias());
4702   EXPECT_TRUE(optBias);
4703 
4704   auto BH = optBias->getPayload().getHandle();
4705   EXPECT_EQ(BH.raw(0), 1 + 1);
4706   EXPECT_EQ(BH.raw(1), 2 + 3);
4707   EXPECT_EQ(BH.raw(2), 3 + 5);
4708   EXPECT_EQ(BH.raw(3), 4 + 7);
4709   EXPECT_EQ(BH.raw(4), 5 + 9);
4710 
4711   bindings.allocate(mod.getPlaceholders());
4712   bindings.get(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
4713   bindings.get(filter)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
4714 }
4715 
4716 /// Test that Add after ConvTranspose is folded into the bias add when the
4717 /// actual Add is a broadcast of the bias.
TEST_F(GraphOptz,FoldConvTransposeAddIntoBiasAddRHS)4718 TEST_F(GraphOptz, FoldConvTransposeAddIntoBiasAddRHS) {
4719   foldConvTransposeAddIntoBiasAdd(bindings_, mod_, F_, optimizedF_, false);
4720   checkNumericalEquivalence();
4721 }
TEST_F(GraphOptz,FoldConvTransposeAddIntoBiasAddLHS)4722 TEST_F(GraphOptz, FoldConvTransposeAddIntoBiasAddLHS) {
4723   foldConvTransposeAddIntoBiasAdd(bindings_, mod_, F_, optimizedF_, true);
4724   checkNumericalEquivalence();
4725 }
4726 
4727 /// Test that MatMul + Add is folded into FullyConnected.
TEST_F(GraphOptz,FoldMatMulAddIntoFullyConnected)4728 TEST_F(GraphOptz, FoldMatMulAddIntoFullyConnected) {
4729 
4730   auto *input =
4731       mod_.createPlaceholder(ElemKind::FloatTy, {1, 3}, "input", false);
4732   auto *weights =
4733       mod_.createPlaceholder(ElemKind::FloatTy, {3, 5}, "weights", false);
4734   auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5}, "bias", false);
4735 
4736   MatMulNode *matmul = F_->createMatMul("matmul", input, weights);
4737   AddNode *add = F_->createAdd("add", matmul, bias);
4738   F_->createSave("save", add);
4739   EXPECT_EQ(3, F_->getNodes().size());
4740 
4741   // The folding should replace the MatMul + Add with a FullyConnected and a
4742   // Reshape to 1D for the Bias.
4743   CompilationContext cctx;
4744   ::glow::fold(F_, cctx);
4745   EXPECT_EQ(3, F_->getNodes().size());
4746   EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::AddNodeKind));
4747   EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::MatMulNodeKind));
4748   EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
4749   EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReshapeNodeKind));
4750 }
4751 
4752 /// Test that batched MatMul + Add is folded into batched FullyConnected.
4753 /// This optimization takes place only if the Bias is constant and the
4754 /// bias data repeats for all the batches.
TEST_F(GraphOptz,FoldMatMulAddIntoFullyConnectedBatched)4755 TEST_F(GraphOptz, FoldMatMulAddIntoFullyConnectedBatched) {
4756 
4757   auto *input =
4758       mod_.createPlaceholder(ElemKind::FloatTy, {2, 3}, "input", false);
4759   auto *weights =
4760       mod_.createPlaceholder(ElemKind::FloatTy, {3, 5}, "weights", false);
4761   auto *bias = mod_.createConstant(ElemKind::FloatTy, {2, 5}, "bias");
4762   auto biasH = bias->getPayloadMutable().getHandle<float>();
4763   biasH = {1, 2, 3, 4, 5, 1, 2, 3, 4, 5};
4764 
4765   MatMulNode *matmul = F_->createMatMul("matmul", input, weights);
4766   AddNode *add = F_->createAdd("add", matmul, bias);
4767   F_->createSave("save", add);
4768   EXPECT_EQ(3, F_->getNodes().size());
4769 
4770   // The folding should replace the MatMul + Add with a FullyConnected, plus a
4771   // Slice and a Reshape to produce the 1D Bias.
4772   CompilationContext cctx;
4773   ::glow::fold(F_, cctx);
4774   EXPECT_EQ(4, F_->getNodes().size());
4775   EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::AddNodeKind));
4776   EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::MatMulNodeKind));
4777   EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
4778   EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
4779   EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReshapeNodeKind));
4780 }
4781 
4782 /// Test that FoldSlicesIntoConstants pass works as expected.
TEST_F(GraphOptz,FoldSlicesIntoConstantsTest)4783 TEST_F(GraphOptz, FoldSlicesIntoConstantsTest) {
4784   Constant *C = mod_.createConstant(ElemKind::FloatTy, {3, 4}, "C");
4785   auto CH = C->getPayloadMutable().getHandle<float>();
4786   CH = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
4787 
4788   SliceNode *S1 = F_->createSlice("s1", C, {0, 0}, {3, 2});
4789   SliceNode *S2 = F_->createSlice("s2", C, {0, 2}, {3, 4});
4790   SaveNode *SN1 = F_->createSave("save1", S1);
4791   SaveNode *SN2 = F_->createSave("save2", S2);
4792 
4793   optimizedF_ = optimizeFunction(
4794       F_, {FunctionPassID::FoldSlicesIntoConstants, getDCEPassConfig()});
4795 
4796   SaveNode *optSN1 =
4797       llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN1->getName()));
4798   SaveNode *optSN2 =
4799       llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN2->getName()));
4800   ASSERT_TRUE(optSN1);
4801   ASSERT_TRUE(optSN2);
4802 
4803   Constant *C1 = llvm::dyn_cast<Constant>(optSN1->getInput());
4804   ASSERT_TRUE(C1);
4805   auto H1 = C1->getPayloadMutable().getHandle();
4806   Constant *C2 = llvm::dyn_cast<Constant>(optSN2->getInput());
4807   ASSERT_TRUE(C2);
4808   auto H2 = C2->getPayloadMutable().getHandle();
4809   for (dim_t i = 0, e = 3; i < e; i++) {
4810     for (dim_t j = 0, e = 2; j < e; j++) {
4811       EXPECT_EQ(H1.at({i, j}), CH.at({i, j}));
4812       EXPECT_EQ(H2.at({i, j}), CH.at({i, j + 2}));
4813     }
4814   }
4815 }
4816 
4817 /// Test that RaiseClipsAboveShapeNodes pass works as expected.
TEST_F(GraphOptz,RaiseClipsAboveShapeNodesTest)4818 TEST_F(GraphOptz, RaiseClipsAboveShapeNodesTest) {
4819   Placeholder *input =
4820       mod_.createPlaceholder(ElemKind::FloatTy, {256, 64}, "input", false);
4821 
4822   ReshapeNode *RN1 = F_->createReshape("reshape1", input, {4, 128, 32});
4823   ReshapeNode *RN2 = F_->createReshape("reshape2", RN1, {64, 256});
4824   TransposeNode *TN = F_->createTranspose("transpose", RN2, {1, 0});
4825   SliceNode *SN = F_->createSlice("slice", TN, {64, 0}, {256, 64});
4826   ClipNode *CN = F_->createClip("clip", SN, -0.1, 0.1);
4827   SaveNode *save1 = F_->createSave("save1", RN1);
4828   SaveNode *save2 = F_->createSave("save2", CN);
4829 
4830   optimizedF_ =
4831       optimizeFunction(F_, {FunctionPassID::RaiseClipsAboveShapeNodes});
4832 
4833   SaveNode *optSave1 =
4834       llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save1->getName()));
4835   ASSERT_TRUE(optSave1);
4836   SaveNode *optSave2 =
4837       llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save2->getName()));
4838   ASSERT_TRUE(optSave2);
4839 
4840   // save1 should keep the untouched Reshape RN1 as its only input, with the
4841   // Placeholder input feeding RN1, because RN1 has multiple users.
4842   ReshapeNode *optRN1 =
4843       llvm::dyn_cast<ReshapeNode>(optSave1->getInput().getNode());
4844   ASSERT_TRUE(optRN1);
4845   EXPECT_EQ(input, optRN1->getInput().getNode());
4846 
4847   // save2 should have CN it originally saved pushed up above SN, TN, and RN2.
4848   SliceNode *newSN = llvm::dyn_cast<SliceNode>(optSave2->getInput());
4849   ASSERT_TRUE(newSN);
4850   EXPECT_EQ(newSN->getStart(), SN->getStart());
4851   TransposeNode *newTN = llvm::dyn_cast<TransposeNode>(newSN->getInput());
4852   ASSERT_TRUE(newTN);
4853   EXPECT_EQ(newTN->getShuffle(), TN->getShuffle());
4854   ReshapeNode *newRN2 = llvm::dyn_cast<ReshapeNode>(newTN->getInput());
4855   ASSERT_TRUE(newRN2);
4856   ClipNode *newCN = llvm::dyn_cast<ClipNode>(newRN2->getInput());
4857   ASSERT_TRUE(newCN);
4858   EXPECT_EQ(newCN->getMin(), CN->getMin());
4859   EXPECT_EQ(newCN->getMax(), CN->getMax());
4860 
4861   bindings_.allocate(mod_.getPlaceholders());
4862   bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
4863   checkNumericalEquivalence();
4864 }
4865 
testOptimizeDequantizeClip(PlaceholderBindings & bindings,Module & mod,Function * F,Function * & optF,bool enableQuantParamChanges)4866 static void testOptimizeDequantizeClip(PlaceholderBindings &bindings,
4867                                        Module &mod, Function *F,
4868                                        Function *&optF,
4869                                        bool enableQuantParamChanges) {
4870   Placeholder *input =
4871       mod.createPlaceholder(ElemKind::FloatTy, {20, 20}, "input", false);
4872 
4873   const auto qParams = quantization::chooseQuantizationParams({-0.1, 0.1});
4874 
4875   QuantizeNode *QN =
4876       F->createQuantize("quantize", input,
4877                         mod.uniqueType(ElemKind::Int8QTy, {20, 20},
4878                                        qParams.scale, qParams.offset));
4879   DequantizeNode *DN = F->createDequantize("dequantize", QN, ElemKind::FloatTy);
4880   ClipNode *CN =
4881       F->createClip("clip", DN, enableQuantParamChanges ? 0 : -100, 100);
4882   SaveNode *SN = F->createSave("save", CN);
4883 
4884   CompilationContext cctx;
4885   cctx.optimizationOpts.enableQuantParamChanges = true;
4886   optF = optimizeFunction(
4887       F, {FunctionPassID::OptimizeQuantizeClip, getDCEPassConfig()}, cctx);
4888 
4889   EXPECT_EQ(countNodeKind(optF, Kinded::Kind::ClipNodeKind), 0);
4890 
4891   SaveNode *optSN =
4892       llvm::dyn_cast<SaveNode>(optF->getNodeByName(SN->getName()));
4893   ASSERT_TRUE(optSN);
4894 
4895   // Now check that the quantization params have been correctly updated for QN,
4896   // and that CN has been eliminated.
4897   DequantizeNode *optDN =
4898       llvm::dyn_cast<DequantizeNode>(optSN->getInput().getNode());
4899   ASSERT_TRUE(optDN);
4900   const auto qMinMax = optDN->getInput().getType()->getQuantizedValueRange();
4901   // Min is either from Clip or Quant range depending on enableQuantParamChanges
4902   EXPECT_NEAR(qMinMax.first, enableQuantParamChanges ? 0 : -0.1, 1E-3);
4903   EXPECT_NEAR(qMinMax.second, 0.1, 1E-3); // Max from Quant range
4904 
4905   bindings.allocate(mod.getPlaceholders());
4906   bindings.get(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
4907 }
4908 
4909 /// Test that OptimizeQuantizeClip pass works as expected for Clip(Dequantize)
4910 /// when the quantization parameters are allowed to change.
TEST_F(GraphOptz,OptimizeDequantizeClipTest_QuantParamChanges)4911 TEST_F(GraphOptz, OptimizeDequantizeClipTest_QuantParamChanges) {
4912   testOptimizeDequantizeClip(bindings_, mod_, F_, optimizedF_,
4913                              /* enableQuantParamChanges */ true);
4914   checkNumericalEquivalence(0.0005);
4915 }
4916 
4917 /// Test that OptimizeQuantizeClip pass works as expected for Clip(Dequantize)
4918 /// when the quantization parameters are not allowed to change.
TEST_F(GraphOptz,OptimizeDequantizeClipTest_NoQuantParamChanges)4919 TEST_F(GraphOptz, OptimizeDequantizeClipTest_NoQuantParamChanges) {
4920   testOptimizeDequantizeClip(bindings_, mod_, F_, optimizedF_,
4921                              /* enableQuantParamChanges */ false);
4922   checkNumericalEquivalence();
4923 }
4924 
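/// Helper for the OptimizeClipQuantize tests below: builds a
/// Clip -> Quantize -> Dequantize -> Save chain in \p F, runs the
/// OptimizeQuantizeClip and DCE passes into \p optF, and checks that the Clip
/// is folded into the Quantize's quantization parameters.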
4925 static void testOptimizeClipQuantize(PlaceholderBindings &bindings, Module &mod,
4926                                      Function *F, Function *&optF,
4927                                      bool enableQuantParamChanges) {
4928   Placeholder *input =
4929       mod.createPlaceholder(ElemKind::FloatTy, {20, 20}, "input", false);
4930 
4931   const auto qParams = quantization::chooseQuantizationParams({-0.1, 0.1});
4932 
4933   ClipNode *CN =
4934       F->createClip("clip", input, enableQuantParamChanges ? 0 : -100, 100);
4935   QuantizeNode *QN =
4936       F->createQuantize("quantize", CN,
4937                         mod.uniqueType(ElemKind::Int8QTy, {20, 20},
4938                                        qParams.scale, qParams.offset));
4939   DequantizeNode *DN = F->createDequantize("dequantize", QN, ElemKind::FloatTy);
4940   SaveNode *SN = F->createSave("save", DN);
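  // As above: with quant param changes enabled the Clip's min (0) narrows the
  // quantized range; otherwise the Clip range [-100, 100] is a no-op and the
  // Clip is simply removed.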
4941 
4942   CompilationContext cctx;
4943   cctx.optimizationOpts.enableQuantParamChanges = enableQuantParamChanges;
4944   optF = optimizeFunction(
4945       F, {FunctionPassID::OptimizeQuantizeClip, getDCEPassConfig()}, cctx);
4946 
4947   EXPECT_EQ(countNodeKind(optF, Kinded::Kind::ClipNodeKind), 0);
4948 
4949   SaveNode *optSN =
4950       llvm::dyn_cast<SaveNode>(optF->getNodeByName(SN->getName()));
4951   ASSERT_TRUE(optSN);
4952 
4953   // Now check that the quantization params have been correctly updated for QN,
4954   // and that CN has been eliminated.
4955   DequantizeNode *optDN =
4956       llvm::dyn_cast<DequantizeNode>(optSN->getInput().getNode());
4957   ASSERT_TRUE(optDN);
4958   const auto qMinMax = optDN->getInput().getType()->getQuantizedValueRange();
4959   // Min is either from Clip or Quant range depending on enableQuantParamChanges
4960   EXPECT_NEAR(qMinMax.first, enableQuantParamChanges ? 0 : -0.1, 1E-3);
4961   EXPECT_NEAR(qMinMax.second, 0.1, 1E-3); // Max always from Quant range
4962 
4963   bindings.allocate(mod.getPlaceholders());
4964   bindings.get(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
4965 }
4966 
4967 /// Test that OptimizeQuantizeClip pass works as expected for Clip(Quantize)
4968 /// when the quantization parameters are allowed to change.
4969 TEST_F(GraphOptz, OptimizeClipQuantizeTest_QuantParamChanges) {
4970   testOptimizeClipQuantize(bindings_, mod_, F_, optimizedF_,
4971                            /* enableQuantParamChanges */ true);
4972   checkNumericalEquivalence(0.0005);
4973 }
4974 
4975 /// Test that OptimizeQuantizeClip pass works as expected for Clip(Quantize)
4976 /// when the quantization parameters are not allowed to change.
4977 TEST_F(GraphOptz, OptimizeClipQuantizeTest_NoQuantParamChanges) {
4978   testOptimizeClipQuantize(bindings_, mod_, F_, optimizedF_,
4979                            /* enableQuantParamChanges */ false);
4980   checkNumericalEquivalence();
4981 }
4982 
4983 /// Test Quantize(ConvertTo(Node)) -> Quantize(Node), where Quantize is int8.
4984 TEST_F(GraphOptz, OptimizeOutIntermediateConversionsTest) {
4985   Placeholder *input =
4986       mod_.createPlaceholder(ElemKind::FloatTy, {20, 20}, "input", false);
4987 
4988   const auto qParams = quantization::chooseQuantizationParams({-0.1, 0.1});
4989 
4990   ConvertToNode *CN = F_->createConvertTo("conv", input, ElemKind::Float16Ty);
4991   QuantizeNode *QN =
4992       F_->createQuantize("quantize", CN,
4993                          mod_.uniqueType(ElemKind::Int8QTy, {20, 20},
4994                                          qParams.scale, qParams.offset));
4995   DequantizeNode *DN =
4996       F_->createDequantize("dequantize", QN, ElemKind::FloatTy);
4997   F_->createSave("save", DN);
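  // The Float16 ConvertTo feeding the int8 Quantize is an intermediate
  // conversion that the pass should remove, quantizing directly from float.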
4998 
4999   optimizedF_ =
5000       optimizeFunction(F_, {FunctionPassID::OptimizeOutIntermediateConversions,
5001                             getDCEPassConfig()});
5002 
5003   // Now check that the ConvertToNode has been eliminated.
5004   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConvertToNodeKind), 0);
5005 
5006   bindings_.allocate(mod_.getPlaceholders());
5007   bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
5008   checkNumericalEquivalence();
5009 }
5010 
5011 /// Test Clip(Relu(Clip)) -> Clip'.
5012 TEST_F(GraphOptz, ClipReluClipElimTest) {
5013   Placeholder *input =
5014       mod_.createPlaceholder(ElemKind::FloatTy, {64, 64}, "input", false);
5015   ClipNode *CN1 = F_->createClip("CN1", input, -10, 30);
5016   ReluNode *RN = F_->createRELU("RN", CN1);
5017   ClipNode *CN2 = F_->createClip("CN2", RN, -5, 20);
5018   SaveNode *SN = F_->createSave("save", CN2);
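  // Expected folding: Relu raises the effective min of Clip(-10, 30) to 0, and
  // combining with Clip(-5, 20) yields a single Clip(0, 20).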
5019 
5020   // Start with 2 clips, a relu, and a save.
5021   EXPECT_EQ(F_->getNodes().size(), 4);
5022   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ClipNodeKind), 2);
5023   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 1);
5024 
5025   optimizedF_ = optimizeFunction(F_);
5026 
5027   // Remove one of the clips and the relu.
5028   EXPECT_EQ(optimizedF_->getNodes().size(), 2);
5029   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ClipNodeKind), 1);
5030   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ReluNodeKind), 0);
5031 
5032   SaveNode *optSN =
5033       llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5034   ASSERT_TRUE(optSN);
5035 
5036   // We combined all of the ranges into the single Clip.
5037   ClipNode *optCN = llvm::dyn_cast<ClipNode>(optSN->getInput());
5038   ASSERT_TRUE(optCN);
5039   EXPECT_EQ(optCN->getMin(), 0);
5040   EXPECT_EQ(optCN->getMax(), 20);
5041 
5042   bindings_.allocate(input)->getHandle().randomize(-50.0, 5.0, mod_.getPRNG());
5043   checkNumericalEquivalence();
5044 }
5045 
5046 /// Test that we can find a non-quantized relu and fuse it up into a quant FC.
5047 TEST_F(GraphOptz, OptimizeQuantFCFloatReluTest) {
5048   auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0,
5049                                        "input", false);
5050   auto *weights =
5051       mod_.createConstant(ElemKind::Int8QTy, {32, 32}, 1.0, 0, "weights");
5052   auto *bias = mod_.createConstant(ElemKind::Int32QTy, {32}, 1.0, 0, "bias");
5053 
5054   auto *FC = F_->createFullyConnected("fc", input, weights, bias);
5055   auto *DN = F_->createDequantize("dq", FC, ElemKind::FloatTy);
5056   auto *RN = F_->createRELU("relu", DN);
5057   auto *SN = F_->createSave("save", RN);
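  // The float Relu should be pulled above the Dequantize so that it runs
  // quantized on the FC's output; the checks below verify the new Relu's
  // range has a min of 0 and shares its max with the FC.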
5058 
5059   optimizedF_ = optimizeFunction(
5060       F_, {FunctionPassID::OptimizeQuantFCFloatRelu, getDCEPassConfig()});
5061 
5062   SaveNode *optSN =
5063       llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5064   ASSERT_TRUE(optSN);
5065 
5066   DequantizeNode *optDN = llvm::dyn_cast<DequantizeNode>(optSN->getInput());
5067   ASSERT_TRUE(optDN);
5068   ReluNode *optRN = llvm::dyn_cast<ReluNode>(optDN->getInput());
5069   ASSERT_TRUE(optRN);
5070   auto rangeRN = optRN->getResult().getType()->getQuantizedValueRange();
5071   EXPECT_EQ(rangeRN.first, 0.0f);
5072   FullyConnectedNode *optFC =
5073       llvm::dyn_cast<FullyConnectedNode>(optRN->getInput());
5074   ASSERT_TRUE(optFC);
5075   auto rangeFC = optFC->getResult().getType()->getQuantizedValueRange();
5076   EXPECT_EQ(rangeRN.second, rangeFC.second);
5077 
5078   bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127,
5079                                                            mod_.getPRNG());
5080   weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127,
5081                                                              mod_.getPRNG());
5082   bias->getPayloadMutable().getHandle<int32_t>().randomize(-128, 127,
5083                                                            mod_.getPRNG());
5084   checkNumericalEquivalence();
5085 }
5086 
5087 /// Test that we can find a non-quantized relu and fuse it up into a series of
5088 /// concatenated quant FCs.
5089 TEST_F(GraphOptz, OptimizeConcatQuantFCFloatReluTest) {
5090   std::array<NodeValue, 5> DQs;
5091   for (size_t i = 0; i < 5; i++) {
5092     auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32},
5093                                          1.0 / (i + 1), 0, "input", false);
5094     auto *weights =
5095         mod_.createConstant(ElemKind::Int8QTy, {32, 32}, 1.0, 0, "weights");
5096     auto *bias = mod_.createConstant(ElemKind::Int32QTy, {32}, 1.0, 0, "bias");
5097 
5098     auto *FC = F_->createFullyConnected("fc", input, weights, bias);
5099     DQs[i] = F_->createDequantize("dq", FC, ElemKind::FloatTy)->getResult();
5100 
5101     bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127,
5102                                                              mod_.getPRNG());
5103     weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127,
5104                                                                mod_.getPRNG());
5105     bias->getPayloadMutable().getHandle<int32_t>().randomize(-128, 127,
5106                                                              mod_.getPRNG());
5107   }
5108 
5109   auto *CN = F_->createConcat("concat", DQs, 0);
5110   auto *RN = F_->createRELU("relu", CN);
5111   auto *SN = F_->createSave("save", RN);
5112 
5113   optimizedF_ = optimizeFunction(
5114       F_, {FunctionPassID::OptimizeQuantFCFloatRelu, getDCEPassConfig()});
5115 
5116   SaveNode *optSN =
5117       llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5118   ASSERT_TRUE(optSN);
5119   ConcatNode *optCN = llvm::dyn_cast<ConcatNode>(optSN->getInput());
5120   ASSERT_TRUE(optCN);
5121   EXPECT_EQ(optCN->getInputs().size(), 5);
5122 
5123   for (const NodeValue NV : optCN->getInputs()) {
5124     DequantizeNode *optDN = llvm::dyn_cast<DequantizeNode>(NV);
5125     ASSERT_TRUE(optDN);
5126     ReluNode *optRN = llvm::dyn_cast<ReluNode>(optDN->getInput());
5127     ASSERT_TRUE(optRN);
5128     auto rangeRN = optRN->getResult().getType()->getQuantizedValueRange();
5129     EXPECT_EQ(rangeRN.first, 0.0f);
5130     FullyConnectedNode *optFC =
5131         llvm::dyn_cast<FullyConnectedNode>(optRN->getInput());
5132     ASSERT_TRUE(optFC);
5133     auto rangeFC = optFC->getResult().getType()->getQuantizedValueRange();
5134     EXPECT_EQ(rangeRN.second, rangeFC.second);
5135   }
5136 
5137   checkNumericalEquivalence();
5138 }
5139 
5140 /// Test that we can find a concat with all dequantize inputs and a quantize at
5141 /// its output, and then replace quant/dequants with rescales.
5142 TEST_F(GraphOptz, OptimizeDequantConcatQuant) {
5143   std::array<NodeValue, 5> DQs;
5144   std::array<Placeholder *, 5> inputs;
5145   for (size_t i = 0; i < 5; i++) {
5146     inputs[i] = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32},
5147                                        0.3 / (i + 1), 5, "input", false);
5148     DQs[i] =
5149         F_->createDequantize("dq", inputs[i], ElemKind::FloatTy)->getResult();
5150 
5151     bindings_.allocate(inputs[i])->getHandle<int8_t>().randomize(
5152         -128, 127, mod_.getPRNG());
5153   }
5154 
5155   auto *CN = F_->createConcat("concat", DQs, 0);
5156   constexpr float scale = 0.3;
5157   constexpr int32_t offset = 5;
5158   auto *RN = F_->createQuantize("quantize", CN,
5159                                 mod_.uniqueType(ElemKind::Int8QTy,
5160                                                 CN->getResult().dims(), scale,
5161                                                 offset));
5162   auto *SN = F_->createSave("save", RN);
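  // The Quantize's scale/offset exactly match input 0's type, so that input
  // should hook up to the Concat directly; the other inputs get
  // RescaleQuantized nodes to the Quantize's scale/offset.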
5163 
5164   optimizedF_ = optimizeFunction(
5165       F_, {FunctionPassID::OptimizeConcatQuantization, getDCEPassConfig()});
5166 
5167   SaveNode *optSN =
5168       llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5169   ASSERT_TRUE(optSN);
5170   ConcatNode *optCN = llvm::dyn_cast<ConcatNode>(optSN->getInput());
5171   ASSERT_TRUE(optCN);
5172   EXPECT_EQ(optCN->getInputs().size(), 5);
5173 
5174   for (size_t i = 0, e = optCN->getInputs().size(); i < e; i++) {
5175     const NodeValue NV = optCN->getInputs()[i];
5176     if (i == 0) {
5177       EXPECT_EQ(inputs[i], NV.getNode());
5178       EXPECT_EQ(inputs[i]->getOutput().getType()->getScale(), scale);
5179       EXPECT_EQ(inputs[i]->getOutput().getType()->getOffset(), offset);
5180     } else {
5181       RescaleQuantizedNode *optRN = llvm::dyn_cast<RescaleQuantizedNode>(NV);
5182       ASSERT_TRUE(optRN);
5183       EXPECT_EQ(optRN->getResult().getType()->getScale(), scale);
5184       EXPECT_EQ(optRN->getResult().getType()->getOffset(), offset);
5185       EXPECT_EQ(inputs[i], optRN->getInput().getNode());
5186     }
5187   }
5188   checkNumericalEquivalence();
5189 }
5190 
5191 /// Test that if we have a Concat with all Dequantize inputs with the same
5192 /// scale/offset/kind that we can sink the Dequantizes below the Concat.
5193 TEST_F(GraphOptz, SinkDequantizeBelowConcatTest) {
5194   const float scale = 0.06;
5195   const int32_t offset = -15;
5196   std::array<NodeValue, 5> inputs;
5197   for (dim_t i = 0; i < 5; i++) {
5198     Placeholder *input = mod_.createPlaceholder(ElemKind::Int8QTy, {i + 1, 100},
5199                                                 scale, offset, "input", false);
5200     bindings_.allocate(input)->getHandle<int8_t>().randomize(-100, 100,
5201                                                              mod_.getPRNG());
5202     DequantizeNode *dequantize =
5203         F_->createDequantize("dequantize", input, ElemKind::Float16Ty);
5204     inputs[i] = dequantize->getResult();
5205   }
5206   ConcatNode *concat = F_->createConcat("concat", inputs, 0);
5207   SaveNode *SN = F_->createSave("ret", concat);
5208 
5209   optimizedF_ = optimizeFunction(
5210       F_, {FunctionPassID::SinkConversions, getDCEPassConfig()});
5211 
5212   // Concat, dequantize, save.
5213   EXPECT_EQ(optimizedF_->getNodes().size(), 3);
5214   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::DequantizeNodeKind), 1);
5215   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
5216   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1);
5217 
5218   SaveNode *optSN =
5219       llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5220   ASSERT_TRUE(optSN);
5221   DequantizeNode *optDequantize =
5222       llvm::dyn_cast<DequantizeNode>(optSN->getInput());
5223   ASSERT_TRUE(optDequantize);
5224   NodeValue input = optDequantize->getInput();
5225   EXPECT_EQ(scale, input.getType()->getScale());
5226   EXPECT_EQ(offset, input.getType()->getOffset());
5227   EXPECT_EQ(ElemKind::Int8QTy, input.getType()->getElementType());
5228 
5229   // Verify the optimized function is numerically equivalent to the original.
5230   checkNumericalEquivalence();
5231 }
5232 
5233 /// Test that if we have a Concat with all Quantize inputs with the same
5234 /// scale/offset/kind that we can sink the Quantizes below the Concat.
5235 TEST_F(GraphOptz, SinkQuantizeBelowConcatTest) {
5236   const float scale = 0.06;
5237   const int32_t offset = -15;
5238   std::array<NodeValue, 5> inputs;
5239   for (dim_t i = 0; i < 5; i++) {
5240     Placeholder *input = mod_.createPlaceholder(ElemKind::Float16Ty,
5241                                                 {i + 1, 100}, "input", false);
5242     bindings_.allocate(input)->getHandle<float16_t>().randomize(-100, 100,
5243                                                                 mod_.getPRNG());
5244     const TypeRef QTy = mod_.uniqueType(
5245         ElemKind::Int8QTy, input->getOutput().dims(), scale, offset);
5246     QuantizeNode *quantize = F_->createQuantize("quantize", input, QTy);
5247     inputs[i] = quantize->getResult();
5248   }
5249   ConcatNode *concat = F_->createConcat("concat", inputs, 0);
5250   SaveNode *SN = F_->createSave("ret", concat);
5251 
5252   optimizedF_ = optimizeFunction(
5253       F_, {FunctionPassID::SinkConversions, getDCEPassConfig()});
5254 
5255   // Concat, quantize, save.
5256   EXPECT_EQ(optimizedF_->getNodes().size(), 3);
5257   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::QuantizeNodeKind), 1);
5258   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
5259   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1);
5260 
5261   SaveNode *optSN =
5262       llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5263   ASSERT_TRUE(optSN);
5264   QuantizeNode *optQuantize = llvm::dyn_cast<QuantizeNode>(optSN->getInput());
5265   ASSERT_TRUE(optQuantize);
5266   EXPECT_EQ(scale, optQuantize->getResult().getType()->getScale());
5267   EXPECT_EQ(offset, optQuantize->getResult().getType()->getOffset());
5268   EXPECT_EQ(ElemKind::Int8QTy,
5269             optQuantize->getResult().getType()->getElementType());
5270 
5271   // Verify the optimized function is numerically equivalent to the original.
5272   checkNumericalEquivalence();
5273 }
5274 
5275 /// Test Clip(Relu) -> Clip'.
5276 TEST_F(GraphOptz, ClipReluTest) {
5277   Placeholder *input =
5278       mod_.createPlaceholder(ElemKind::Float16Ty, {64, 64}, "input", false);
5279   ReluNode *RN = F_->createRELU("RN", input);
5280   ClipNode *CN = F_->createClip("CN", RN, -5, 20);
5281   SaveNode *SN = F_->createSave("save", CN);
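  // Expected folding: the Relu's implicit min of 0 merges into the Clip,
  // giving a single Clip(0, 20).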
5282 
5283   // Start with a clip, a relu, and a save.
5284   EXPECT_EQ(F_->getNodes().size(), 3);
5285   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ClipNodeKind), 1);
5286   EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 1);
5287 
5288   optimizedF_ = optimizeFunction(F_);
5289 
5290   // Removed the relu.
5291   EXPECT_EQ(optimizedF_->getNodes().size(), 2);
5292   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ClipNodeKind), 1);
5293   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ReluNodeKind), 0);
5294 
5295   SaveNode *optSN =
5296       llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5297   ASSERT_TRUE(optSN);
5298 
5299   // We have the same max for clip as before, but 0 for min due to the Relu.
5300   ClipNode *optCN = llvm::dyn_cast<ClipNode>(optSN->getInput());
5301   ASSERT_TRUE(optCN);
5302   EXPECT_EQ(optCN->getMin(), 0);
5303   EXPECT_EQ(optCN->getMax(), 20);
5304 
5305   bindings_.allocate(input)->getHandle<float16_t>().randomize(-50.0, 5.0,
5306                                                               mod_.getPRNG());
5307   checkNumericalEquivalence();
5308 }
5309 
5310 /// Test that if we have a Concat with some Dequantize inputs and a Quantize
5311 /// after the Concat, we can move the Quantize above the Concat and eliminate
5312 /// the Dequantizes.
5313 TEST_F(GraphOptz, SinkConcatBelowQuantize) {
5314   const float scale = 0.06;
5315   const int32_t offset = -15;
5316   std::array<NodeValue, 3> inputs;
5317 
5318   // Concat input 0: Dequant(PH)
5319   const TypeRef in0QTy =
5320       mod_.uniqueType(ElemKind::Int8QTy, {1, 3}, scale, offset);
5321   Placeholder *input0 = mod_.createPlaceholder(in0QTy, "input", false);
5322   inputs[0] =
5323       F_->createDequantize("deq", input0, ElemKind::Float16Ty)->getResult();
5324 
5325   // Concat input 1: Dequant(Add(PH, PH))
5326   const TypeRef in1QTy =
5327       mod_.uniqueType(ElemKind::Int8QTy, {5, 3}, scale, offset + 1);
5328   Placeholder *input1 = mod_.createPlaceholder(in1QTy, "input", false);
5329   AddNode *add = F_->createAdd("add", input1, input1);
5330   inputs[1] =
5331       F_->createDequantize("deq", add, ElemKind::Float16Ty)->getResult();
5332 
5333   // Concat input 2: PH
5334   Placeholder *input2 =
5335       mod_.createPlaceholder(ElemKind::Float16Ty, {10, 3}, "input_fp", false);
5336   inputs[2] = input2->getOutput();
5337 
5338   // Concat all 3 together, all FP16.
5339   ConcatNode *concat = F_->createConcat("concat", inputs, 0);
5340 
5341   // Now quantize the result of the concat.
5342   const TypeRef QTy = mod_.uniqueType(
5343       ElemKind::Int8QTy, concat->getResult().dims(), scale, offset);
5344   QuantizeNode *QN = F_->createQuantize("quantize", concat, QTy);
5345   SaveNode *SN = F_->createSave("ret", QN);
5346 
5347   optimizedF_ = optimizeFunction(F_, {FunctionPassID::SinkConcatBelowQuantize,
5348                                       {FunctionPassID::OptimizeQuantization,
5349                                        ConvergenceMode::UntilFixedPoint},
5350                                       getDCEPassConfig()});
5351 
5352   EXPECT_EQ(optimizedF_->getNodes().size(), 4);
5353   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
5354   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind), 1);
5355   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::QuantizeNodeKind), 1);
5356   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1);
5357 
5358   SaveNode *optSN =
5359       llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5360   ASSERT_TRUE(optSN);
5361 
5362   // Concat should be directly connected to save, with same quantization
5363   // parameters as the quantize which used to follow it.
5364   ConcatNode *optCN = llvm::dyn_cast<ConcatNode>(optSN->getInput());
5365   ASSERT_TRUE(optCN);
5366   ASSERT_EQ(ElemKind::Int8QTy, optCN->getResult().getType()->getElementType());
5367   EXPECT_EQ(scale, optCN->getResult().getType()->getScale());
5368   EXPECT_EQ(offset, optCN->getResult().getType()->getOffset());
5369 
5370   ASSERT_EQ(optCN->getInputs().size(), 3);
5371 
5372   // No rescale is needed for this PH since its scale/offset already match
5373   // those of the Quantize, so it connects to the Concat directly.
5374   EXPECT_EQ(optCN->getInputs()[0], input0->getOutput());
5375 
5376   // No rescale here because it is fused into optAN; check that the Add's
5377   // result uses the Quantize's scale/offset.
5378   AddNode *optAN = llvm::dyn_cast<AddNode>(optCN->getInputs()[1]);
5379   ASSERT_TRUE(optAN);
5380   ASSERT_EQ(ElemKind::Int8QTy, optAN->getResult().getType()->getElementType());
5381   EXPECT_EQ(scale, optAN->getResult().getType()->getScale());
5382   EXPECT_EQ(offset, optAN->getResult().getType()->getOffset());
5383   EXPECT_EQ(optAN->getLHS(), input1->getOutput());
5384   EXPECT_EQ(optAN->getRHS(), input1->getOutput());
5385 
5386   // Must quantize this input since the PH is float16.
5387   QuantizeNode *optQN = llvm::dyn_cast<QuantizeNode>(optCN->getInputs()[2]);
5388   ASSERT_TRUE(optQN);
5389   ASSERT_EQ(ElemKind::Int8QTy, optQN->getResult().getType()->getElementType());
5390   EXPECT_EQ(scale, optQN->getResult().getType()->getScale());
5391   EXPECT_EQ(offset, optQN->getResult().getType()->getOffset());
5392   EXPECT_EQ(optQN->getInput(), input2->getOutput());
5393 
5394   bindings_.allocate(input0)->getHandle<int8_t>().randomize(-50, 50,
5395                                                             mod_.getPRNG());
5396   bindings_.allocate(input1)->getHandle<int8_t>().randomize(-50, 50,
5397                                                             mod_.getPRNG());
5398   bindings_.allocate(input2)->getHandle<float16_t>().randomize(-10, 10,
5399                                                                mod_.getPRNG());
5400 }
5401 
5402 TEST_F(GraphOptz, EliminateSliceConcatTest) {
5403   auto *src1 =
5404       mod_.createPlaceholder(ElemKind::FloatTy, {10, 70}, "src1", false);
5405   auto *src2 =
5406       mod_.createPlaceholder(ElemKind::FloatTy, {10, 80}, "src2", false);
5407   auto *A = F_->createSlice("A", src1, {0, 0}, {10, 10});
5408   auto *B = F_->createSlice("B", src1, {0, 10}, {10, 20});
5409   auto *C = F_->createSlice("C", src1, {0, 20}, {10, 30});
5410   // Interleaved Slices with different sources should not be merged.
5411   auto *E = F_->createSlice("E", src1, {0, 30}, {10, 40});
5412   auto *F = F_->createSlice("F", src2, {0, 30}, {10, 40});
5413   auto *G = F_->createSlice("G", src1, {0, 40}, {10, 50});
5414   auto *H = F_->createSlice("H", src2, {0, 40}, {10, 50});
5415 
5416   auto *D = mod_.createPlaceholder(ElemKind::FloatTy, {10, 50}, "D", false);
5417   auto *R = F_->createRELU("Relu", C);
5418   auto *CN = F_->createConcat("Concat", {A, B, D, E, F, G, H}, 1);
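  // A and B are contiguous slices of src1 and should merge into one slice;
  // E/G (src1) and F/H (src2) are interleaved in the Concat, so they must stay
  // separate. C feeds only the Relu and D is a plain Placeholder.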
5419   F_->createSave("save1", CN);
5420   F_->createSave("save2", R);
5421 
5422   EXPECT_EQ(F_->getNodes().size(), 11);
5423 
5424   optimizedF_ = optimizeFunction(
5425       F_, {FunctionPassID::EliminateSliceConcat, getDCEPassConfig()});
5426 
5427   EXPECT_EQ(optimizedF_->getNodes().size(), 10);
5428 
5429   int numSlicesToConcat = 0;
5430   for (const auto &node : optimizedF_->getNodes()) {
5431     auto *newCN = llvm::dyn_cast<ConcatNode>(&node);
5432     if (!newCN) {
5433       continue;
5434     }
5435     EXPECT_EQ(newCN->getInputs().size(), 6);
5436     for (const auto &concatInput : newCN->getInputs()) {
5437       auto *SN = llvm::dyn_cast<SliceNode>(concatInput.getNode());
5438       if (SN) {
5439         numSlicesToConcat++;
5440       }
5441     }
5442   }
5443   EXPECT_EQ(numSlicesToConcat, 5);
5444 
5445   bindings_.allocate(src1)->getHandle<float>().randomize(-10.0, 10.0,
5446                                                          mod_.getPRNG());
5447   bindings_.allocate(src2)->getHandle<float>().randomize(-10.0, 10.0,
5448                                                          mod_.getPRNG());
5449   bindings_.allocate(D)->getHandle<float>().randomize(-10.0, 10.0,
5450                                                       mod_.getPRNG());
5451   checkNumericalEquivalence();
5452 }
5453 
5454 /// Verify that when we want to prevent constant folding it doesn't occur.
5455 TEST_F(GraphOptz, constantFoldPreventedNoop) {
5456   auto *const1 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const1");
5457   auto *const2 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const2");
5458   auto *ph1 = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "input1",
5459                                      /* isTrainable */ false);
5460   setConstValue(const1, 1.0f);
5461   setConstValue(const2, 2.0f);
5462   auto *splat2 = F_->createSplat(
5463       "splat2", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 2.0f);
5464   auto *splat3 = F_->createSplat(
5465       "splat3", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 3.0f);
5466 
5467   auto *add1 = F_->createAdd("add", const1, const2);
5468   auto *mul1 = F_->createMul("mul1", add1, splat2);
5469   auto *mul2 = F_->createMul("mul2", mul1, splat3);
5470   F_->createSave("save", mul2);
5471   auto *add3 = F_->createAdd("add", const1, ph1);
5472   F_->createSave("save", add3);
5473 
5474   ConstantModificationPreventer constModPreventer(mod_);
5475   constModPreventer.activate();
5476 
5477   // Check that both Constants are protected and no change is made to the
5478   // Function during optimization.
5479   EXPECT_EQ(constModPreventer.getMapping().size(), 2);
5480   optimizedF_ = optimizeFunction(F_);
5481   EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false,
5482                          /* skipName */ true),
5483             optimizedF_->toString(/* skipUsersForStorage */ false,
5484                                   /* skipName */ true));
5485 
5486   // Now deactivate the constModPreventer and check we can const fold still.
5487   constModPreventer.deactivateAndCleanup();
5488   mod_.eraseFunction(optimizedF_);
5489   optimizedF_ = optimizeFunction(F_);
5490 
5491   // After constant folding, left with just two Saves, one Add.
5492   EXPECT_EQ(optimizedF_->getNodes().size(), 3);
5493   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind), 1);
5494   EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 2);
5495 
5496   bindings_.allocate(ph1)->getHandle<float>().randomize(-10.0, 10.0,
5497                                                         mod_.getPRNG());
5498   checkNumericalEquivalence();
5499 }
5500 
5501 /// Test that a Conv2D is correctly lowered to FC for single batch.
5502 TEST_F(GraphOptz, lowerConv2DToFCSingleBatch) {
5503   Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4},
5504                                               "input", /* isTrainable */ false);
5505   bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
5506                                                           mod_.getPRNG());
5507 
5508   Constant *filter =
5509       mod_.createConstant(ElemKind::FloatTy, {8, 1, 1, 4}, "filter");
5510   filter->getPayloadMutable().getHandle<float>().randomize(-10, 10,
5511                                                            mod_.getPRNG());
5512 
5513   Constant *bias = mod_.createConstant(ElemKind::FloatTy, {8}, "bias");
5514   bias->getPayloadMutable().getHandle<float>().randomize(-10, 10,
5515                                                          mod_.getPRNG());
5516 
5517   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 2, 3, 8});
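  // A 1x1 convolution with stride 1 and no padding is equivalent to applying a
  // FullyConnected layer at each spatial position, so lowering should produce
  // an FC node.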
5518   auto *conv = F_->createConv("conv", input, filter, bias, outTy, {1, 1},
5519                               {1, 1}, {0, 0, 0, 0}, 1, 1);
5520   SaveNode *save = F_->createSave("save", conv);
5521   bindings_.allocate(save->getPlaceholder());
5522 
5523   // Backup function in optimizedF_.
5524   optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5525 
5526   // Lower Convolution.
5527   EXPECT_TRUE(isConvolutionSameAsFullyConnected(conv));
5528   EXPECT_TRUE(glow::lowerNode(F_, conv, cctx_));
5529   runDCEPass(F_, cctx_);
5530   EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind));
5531   EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
5532 
5533   // Now compile/run/compare F_ and optimizedF_.
5534   checkNumericalEquivalence(1e-6);
5535 }
5536 
5537 /// Test that a Conv2D is correctly lowered to FC for multi batch.
5538 TEST_F(GraphOptz, lowerConv2DToFCMultiBatch) {
5539   Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 3, 4},
5540                                               "input", /* isTrainable */ false);
5541   bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
5542                                                           mod_.getPRNG());
5543 
5544   Constant *filter =
5545       mod_.createConstant(ElemKind::FloatTy, {8, 1, 1, 4}, "filter");
5546   filter->getPayloadMutable().getHandle<float>().randomize(-10, 10,
5547                                                            mod_.getPRNG());
5548 
5549   Constant *bias = mod_.createConstant(ElemKind::FloatTy, {8}, "bias");
5550   bias->getPayloadMutable().getHandle<float>().randomize(-10, 10,
5551                                                          mod_.getPRNG());
5552 
5553   auto outTy = mod_.uniqueType(ElemKind::FloatTy, {2, 2, 3, 8});
5554   auto *conv = F_->createConv("conv", input, filter, bias, outTy, {1, 1},
5555                               {1, 1}, {0, 0, 0, 0}, 1, 1);
5556   SaveNode *save = F_->createSave("save", conv);
5557   bindings_.allocate(save->getPlaceholder());
5558 
5559   // Backup function in optimizedF_.
5560   optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5561 
5562   // Lower Convolution.
5563   EXPECT_TRUE(isConvolutionSameAsFullyConnected(conv));
5564   EXPECT_TRUE(glow::lowerNode(F_, conv, cctx_));
5565   runDCEPass(F_, cctx_);
5566   EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind));
5567   EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
5568 
5569   // Now compile/run/compare F_ and optimizedF_.
5570   checkNumericalEquivalence(1e-6);
5571 }
5572 
5573 /// Test that Mul and Add can be folded into LayerNorm.
5574 TEST_F(GraphOptz, foldMulAddIntoLayerNorm) {
5575   auto *input =
5576       mod_.createPlaceholder(ElemKind::FloatTy, {2, 4, 10, 20}, "in", false);
5577 
5578   Tensor scaleT(ElemKind::FloatTy, {10, 20});
5579   scaleT.getHandle().randomize(0.0f, 1.0f, mod_.getPRNG());
5580   Constant *scaleC = mod_.createConstant("scale", std::move(scaleT));
5581   SplatNode *biasS = F_->createSplat("bias", scaleC->getType(), 1.5f);
5582 
5583   auto *LN = F_->createLayerNormalization("LN", input, scaleC, biasS, 1e-5);
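  // Scale by a broadcast splat and add a broadcast constant below the
  // LayerNorm; both should fold into the LN's scale and bias inputs.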
5584 
5585   SplatNode *splat = F_->createSplat("splat", scaleC->getType(), 0.5f);
5586   MulNode *MN =
5587       F_->createNodeWithBroadcast<MulNode>("mul", /* axis */ -1, LN, splat);
5588 
5589   Tensor addT(ElemKind::FloatTy, {1, 1, 10, 20});
5590   addT.getHandle().randomize(-1.0f, 1.0f, mod_.getPRNG());
5591   Constant *addC = mod_.createConstant("addC", std::move(addT));
5592   AddNode *AN =
5593       F_->createNodeWithBroadcast<AddNode>("add", /* axis */ -1, MN, addC);
5594   F_->createSave("save", AN);
5595 
5596   optimizedF_ = optimizeFunction(F_);
5597 
5598   // Because Mul and Add are folded in, they should not exist anymore, nor
5599   // should tiles that expand them to match the output of LN.
5600   EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MulNodeKind));
5601   EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind));
5602   EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::TileNodeKind));
5603 
5604   // Now compile/run/compare F_ and optimizedF_.
5605   bindings_.allocate(input)->getHandle().randomize(0.0f, 1.0f, mod_.getPRNG());
5606   checkNumericalEquivalence(1e-6);
5607 }
5608 
5609 /// Test that Mul and Add can be folded into LayerNorm when the leading dims are
5610 /// all one.
5611 TEST_F(GraphOptz, foldMulAddIntoLayerNormNoBatch) {
5612   auto *input =
5613       mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 10, 20}, "in", false);
5614 
5615   Tensor scaleT(ElemKind::FloatTy, {10, 20});
5616   scaleT.getHandle().randomize(0.0f, 1.0f, mod_.getPRNG());
5617   Constant *scaleC = mod_.createConstant("scale", std::move(scaleT));
5618   SplatNode *biasS = F_->createSplat("bias", scaleC->getType(), 1.5f);
5619 
5620   auto *LN = F_->createLayerNormalization("LN", input, scaleC, biasS, 1e-5);
5621 
5622   SplatNode *splat = F_->createSplat("splat", scaleC->getType(), 0.5f);
5623   MulNode *MN =
5624       F_->createNodeWithBroadcast<MulNode>("mul", /* axis */ -1, LN, splat);
5625 
5626   Tensor addT(ElemKind::FloatTy, {1, 1, 10, 20});
5627   addT.getHandle().randomize(-1.0f, 1.0f, mod_.getPRNG());
5628   Constant *addC = mod_.createConstant("addC", std::move(addT));
5629   AddNode *AN =
5630       F_->createNodeWithBroadcast<AddNode>("add", /* axis */ -1, MN, addC);
5631   F_->createSave("save", AN);
5632 
5633   optimizedF_ = optimizeFunction(F_);
5634 
5635   // Because Mul and Add are folded in, they should not exist anymore, nor
5636   // should tiles that expand them to match the output of LN.
5637   EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MulNodeKind));
5638   EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind));
5639   EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::TileNodeKind));
5640 
5641   // Now compile/run/compare F_ and optimizedF_.
5642   bindings_.allocate(input)->getHandle().randomize(0.0f, 1.0f, mod_.getPRNG());
5643   checkNumericalEquivalence(1e-6);
5644 }
5645 
5646 TEST_F(GraphOptz, transposeQuantizeConstantWithAlignment) {
5647   // Define a type with custom alignments.
5648   Type typeWithAlignments(ElemKind::FloatTy, {2, 3, 4, 5}, {1, 1, 32, 1});
5649   Type quantTypeWithAlignments(ElemKind::Int8QTy, {2, 3, 4, 5}, {1, 1, 32, 1},
5650                                1.0, 0);
5651   Type transposedQuantTypeWithAlignments(ElemKind::Int8QTy, {2, 4, 5, 3},
5652                                          {1, 1, 32, 1}, 1.0, 0);
5653   auto modTyWithAlignments = mod_.uniqueType(typeWithAlignments);
5654   auto modQuantTransposedTyWithAlignments =
5655       mod_.uniqueType(transposedQuantTypeWithAlignments);
5656   auto modQuantTyWithAlignments = mod_.uniqueType(quantTypeWithAlignments);
5657   auto *I = mod_.createConstant(modTyWithAlignments, "input1");
5658   auto *Q = F_->createQuantize("quantize", I, modQuantTyWithAlignments);
5659   auto *T = F_->createTranspose("transpose", Q, NCHW2NHWC);
5660   T->setType(TransposeNode::ResultIdx, modQuantTransposedTyWithAlignments);
5661   SaveNode *S = F_->createSave("ret", T);
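  // The Transpose should be folded into the Constant (sinking the data
  // permutation into the payload), leaving Constant -> Quantize -> Save with
  // transposed shapes and the custom alignments preserved.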
5662 
5663   // Skip ConstantFolding as it would have the same result as this opt.
5664   CompilationContext cctx;
5665   cctx.optimizationOpts.enableConstantFolding = false;
5666 
5667   EXPECT_EQ(F_->getNodes().size(), 3);
5668   ::glow::optimize(F_, cctx);
5669   EXPECT_EQ(F_->getNodes().size(), 2);
5670 
5671   // Constant and Quantize should have new shape.
5672   auto *newQ = llvm::dyn_cast<QuantizeNode>(S->getInput());
5673   ASSERT_TRUE(newQ);
5674   EXPECT_TRUE(newQ->getResult().dims().equals({2, 4, 5, 3}));
5675   auto *newC = llvm::dyn_cast<Constant>(newQ->getInput());
5676   ASSERT_TRUE(newC);
5677   EXPECT_TRUE(newC->getType()->dims().equals({2, 4, 5, 3}));
5678 
5679   // Check that alignments are preserved by optimizations.
5680   auto expectedNewTy = mod_.uniqueTypeWithNewShape(
5681       modTyWithAlignments, modQuantTransposedTyWithAlignments);
5682   EXPECT_TRUE(newQ->getInput().getType()->isEqual(expectedNewTy));
5683 
5684   EXPECT_TRUE(F_->verify());
5685 }
5686