1 /**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "BackendTestUtils.h"
17
18 #include "glow/Graph/Graph.h"
19 #include "glow/Graph/Node.h"
20 #include "glow/Graph/Nodes.h"
21 #include "glow/Graph/PlaceholderBindings.h"
22 #include "glow/IR/IR.h"
23 #include "glow/Optimizer/GraphOptimizer/FunctionPassPipeline.h"
24 #include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
25 #include "glow/Optimizer/Lower/Lower.h"
26
27 #include "gtest/gtest.h"
28
29 using namespace glow;
30
31 class GraphFold : public GraphOptz {};
32
33 /// A helper predicate to check if the provided node has the same address as a
34 /// pre-defined address provided in constructor. This is useful if you need to
35 /// check that a given node is still in the graph. In general, it is not safe to
36 /// use the std::find(begin_it, end_it, value) and compare the nodes by value,
37 /// because the node provided as the last parameter of std::find (i.e. the value
38 /// reference) may have been removed by some optimizations and cannot be
39 /// dereferenced anymore. But comparing the addresses of the nodes should be
40 /// fine. Thus, one can use the following form instead:
41 /// std::find_if(begin_it, end_it, IsSameNodeAddress(node_address))
42 struct IsSameNodeAddress {
43 const Node *nodeAddress_;
IsSameNodeAddressIsSameNodeAddress44 IsSameNodeAddress(const Node *nodeAddress) : nodeAddress_(nodeAddress) {}
operator ()IsSameNodeAddress45 bool operator()(const Node &n) const { return &n == nodeAddress_; }
46 };
47
48 /// \returns true if the Function \p F contains the Node \p N.
functionContainsNode(const Function * F,const Node * N)49 static bool functionContainsNode(const Function *F, const Node *N) {
50 return std::find_if(F->getNodes().begin(), F->getNodes().end(),
51 IsSameNodeAddress(N)) != F->getNodes().end();
52 }
53
54 /// Optimize the function \p F with \p cctx. \returns the optimized function. If
55 /// \p pass is empty then the whole default optimization pipeline is run.
56 /// Otherwise only \p pipeline is used.
57 static Function *
optimizeFunction(Function * F,std::initializer_list<FunctionPassConfig> configs={},const CompilationContext cctx=CompilationContext ())58 optimizeFunction(Function *F,
59 std::initializer_list<FunctionPassConfig> configs = {},
60 const CompilationContext cctx = CompilationContext()) {
61 auto *G = F->clone(F->getName().str() + "_optimized");
62 if (configs.size() == 0) {
63 ::glow::optimize(G, CompilationMode::Infer);
64 return G;
65 }
66 FunctionPassManager FPM("TestFPM", configs);
67 FPM.run(G, cctx);
68 return G;
69 }
70
71 /// \returns the first node in a function which has the specificied name.
72 template <typename NodeT = Node>
findFunctionNodeByName(const Function * F,const llvm::StringRef name)73 static const NodeT *findFunctionNodeByName(const Function *F,
74 const llvm::StringRef name) {
75 return llvm::dyn_cast<NodeT>(
76 std::find_if(F->getNodes().begin(), F->getNodes().end(),
77 [=](auto &N) { return N.getName() == name; }));
78 }
79
TEST_F(GraphOptz,OptimizeClipFunnel)80 TEST_F(GraphOptz, OptimizeClipFunnel) {
81 auto *A =
82 mod_.createPlaceholder(ElemKind::FloatTy, {100, 16}, "input", false);
83 Node *K = A;
84 float min = 0.0;
85 float max = 1000.0;
86 for (int i = 0; i < 10; ++i) {
87 min += 1.0;
88 max -= 1.0;
89 K = F_->createClip("clip", K, min, max);
90 }
91 F_->createSave("ret", K);
92
93 EXPECT_EQ(F_->getNodes().size(), 11);
94
95 optimizedF_ = optimizeFunction(F_);
96 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
97
98 // Find clip node in the optimized graph.
99 Node *newClip = A;
100 for (auto &N : optimizedF_->getNodes()) {
101 if (N.getKind() == Kinded::Kind::ClipNodeKind) {
102 newClip = llvm::dyn_cast<ClipNode>(&N);
103 }
104 }
105 EXPECT_TRUE(llvm::isa<ClipNode>(newClip));
106 ClipNode *c = llvm::dyn_cast<ClipNode>(newClip);
107 EXPECT_EQ(min, c->getMin());
108 EXPECT_EQ(max, c->getMax());
109
110 bindings_.allocate(mod_.getPlaceholders());
111 bindings_.get(A)->getHandle().randomize(-1000, 1000, mod_.getPRNG());
112 bindings_.get(A)->getHandle().raw(0) = -1000;
113 checkNumericalEquivalence();
114 }
115
TEST_F(GraphOptz,DCE)116 TEST_F(GraphOptz, DCE) {
117 Node *K = mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input",
118 false);
119
120 for (int i = 0; i < 40; i++) {
121 K = F_->createRELU("relu", K);
122 // Add a graph structure that diverges and converges, to catch algorithms
123 // that perform a dump recursive scan.
124 K = F_->createAdd("arith", K, K);
125 }
126
127 // Check that we know how many nodes we've created.
128 EXPECT_EQ(F_->getNodes().size(), 80);
129
130 // Optimize all of the dead code.
131 ::glow::optimize(F_, CompilationMode::Infer);
132
133 // All of the nodes are gone.
134 EXPECT_EQ(F_->getNodes().size(), 0);
135 EXPECT_EQ(mod_.getConstants().size(), 0);
136 }
137
138 /// Check that predicated instructions are DCE'ed like
139 /// regular instructions.
TEST_F(GraphOptz,DCEwithPredicate)140 TEST_F(GraphOptz, DCEwithPredicate) {
141 Node *K = mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input",
142 false);
143 Node *predicatedBatch =
144 mod_.createPlaceholder(ElemKind::FloatTy, {4}, "predicate", true);
145 for (int i = 0; i < 40; i++) {
146 K = F_->createRELU("relu", K);
147 K->setPredicate(predicatedBatch);
148 // Add a graph structure that diverges and converges, to catch algorithms
149 // that perform a dump recursive scan.
150 K = F_->createAdd("arith", K, K);
151 K->setPredicate(predicatedBatch);
152 }
153
154 // Check that we know how many nodes we've created.
155 EXPECT_EQ(F_->getNodes().size(), 80);
156
157 // Optimize all of the dead code.
158 ::glow::optimize(F_, CompilationMode::Infer);
159
160 // All of the nodes are gone.
161 EXPECT_EQ(F_->getNodes().size(), 0);
162 EXPECT_EQ(mod_.getConstants().size(), 0);
163 }
164
TEST_F(GraphOptz,liveCodeNotEliminated)165 TEST_F(GraphOptz, liveCodeNotEliminated) {
166 Node *K = mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input",
167 false);
168 auto *Ex = mod_.createPlaceholder(ElemKind::Int64ITy, {4, 1}, "Ex", false);
169
170 for (int i = 0; i < 40; i++) {
171 K = F_->createRELU("relu", K);
172 K = F_->createAdd("arith", K, K);
173 }
174 K = F_->createSoftMax("Regression", K, Ex);
175 F_->createSave("ret", K);
176
177 // Check that we know how many nodes we've created.
178 EXPECT_EQ(F_->getNodes().size(), 82);
179
180 // This should not optimize code because none is dead.
181 ::glow::optimize(F_, CompilationMode::Infer);
182
183 // Nothing got optimized.
184 EXPECT_EQ(F_->getNodes().size(), 82);
185 EXPECT_EQ(mod_.getPlaceholders().size(), 3);
186 }
187
188 /// Skip Reshape sinking below BatchNorm when inapplicable.
TEST_F(GraphOptz,SkipReshapeSinkBatchNorm)189 TEST_F(GraphOptz, SkipReshapeSinkBatchNorm) {
190 auto *A = mod_.createPlaceholder(ElemKind::FloatTy, {32, 64}, "A", false);
191 Node *RS = F_->createReshape("reshape", A, {32, 64, 1});
192 Node *BN =
193 F_->createBatchNormalization(bindings_, "batch", RS, 1, 0.0001, 0.9);
194 F_->createSave("ret", BN);
195
196 optimizedF_ = optimizeFunction(F_);
197 EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false, /* skipName */ true),
198 optimizedF_->toString(/* skipUsersForStorage */ false,
199 /* skipName */ true));
200 }
201
202 // Conv->Reshape->BatchNorm is optimized to Conv->Reshape after sinking Reshape
203 // below BatchNorm. Reshape transforms [N][H][W][C] to [N][W][H][C].
TEST_F(GraphOptz,optimizeBatchNormAfterConvAndReshapeNHWC)204 TEST_F(GraphOptz, optimizeBatchNormAfterConvAndReshapeNHWC) {
205 auto *A =
206 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
207 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
208 Node *RS = F_->createReshape("reshape", CV, {1, 20, 10, 16});
209 Node *BN =
210 F_->createBatchNormalization(bindings_, "batch", RS, 3, 0.0001, 0.9);
211 F_->createSave("ret", BN);
212
213 EXPECT_EQ(F_->getNodes().size(), 4);
214 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
215 optimizedF_ = optimizeFunction(F_);
216 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
217
218 ASSERT_EQ(A->getNumUsers(), 2);
219 Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
220 [CV](auto &it) { return it.getUser() == CV; })
221 ->getUser();
222 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
223 ASSERT_EQ(newCV->getNumUsers(), 1);
224 Node *reshape = newCV->getUsers().begin()->getUser();
225 EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));
226
227 bindings_.allocate(mod_.getPlaceholders());
228 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
229 checkNumericalEquivalence();
230 }
231
232 // Conv->Reshape->BatchNorm is optimized to Conv->Reshape after sinking Reshape
233 // below BatchNorm. Reshape flattens [N][H][W][C] to [N][HxW][C].
TEST_F(GraphOptz,optimizeBatchNormAfterConvAndReshapeNHWC2)234 TEST_F(GraphOptz, optimizeBatchNormAfterConvAndReshapeNHWC2) {
235 auto *A =
236 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
237 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
238 Node *RS = F_->createReshape("reshape", CV, {1, 200, 16});
239 Node *BN =
240 F_->createBatchNormalization(bindings_, "batch", RS, 2, 0.0001, 0.9);
241 F_->createSave("ret", BN);
242
243 EXPECT_EQ(F_->getNodes().size(), 4);
244 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
245 optimizedF_ = optimizeFunction(F_);
246 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
247
248 ASSERT_EQ(A->getNumUsers(), 2);
249 Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
250 [CV](auto &it) { return it.getUser() == CV; })
251 ->getUser();
252 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
253 ASSERT_EQ(newCV->getNumUsers(), 1);
254 Node *reshape = newCV->getUsers().begin()->getUser();
255 EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));
256
257 bindings_.allocate(mod_.getPlaceholders());
258 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
259 checkNumericalEquivalence();
260 }
261
262 // BatchNorm is not folded into Conv. Reshape changes Channel Index dimensions
263 // and it prevents optimization. Reshape transforms [N][H][W][C] to
264 // [N][H][W/2][C*2].
TEST_F(GraphOptz,optimizeBatchNormAfterConvAndReshapeNHWCneg)265 TEST_F(GraphOptz, optimizeBatchNormAfterConvAndReshapeNHWCneg) {
266 auto *A =
267 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
268 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
269 Node *RS = F_->createReshape("reshape", CV, {1, 10, 10, 32});
270 Node *BN =
271 F_->createBatchNormalization(bindings_, "batch", RS, 3, 0.0001, 0.9);
272 F_->createSave("ret", BN);
273
274 EXPECT_EQ(F_->getNodes().size(), 4);
275 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
276 optimizedF_ = optimizeFunction(F_);
277 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
278
279 ASSERT_EQ(A->getNumUsers(), 2);
280 Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
281 [CV](auto &it) { return it.getUser() == CV; })
282 ->getUser();
283 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
284 ASSERT_EQ(newCV->getNumUsers(), 1);
285 Node *reshape = newCV->getUsers().begin()->getUser();
286 EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));
287 Node *bn = reshape->getUsers().begin()->getUser();
288 EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(bn));
289
290 bindings_.allocate(mod_.getPlaceholders());
291 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
292 checkNumericalEquivalence();
293 }
294
295 // Conv->Reshape->BatchNorm. Sink Reshape below BatchNorm. Check that BatchNorm
296 // does not fold in to Conv.
TEST_F(GraphOptz,sinkReshapeBelowBatchNormAndDoNotFuseConvBatchNorm)297 TEST_F(GraphOptz, sinkReshapeBelowBatchNormAndDoNotFuseConvBatchNorm) {
298 auto *A =
299 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
300 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1,
301 ConvolutionLayout::NCHW);
302 Node *RS = F_->createReshape("reshape", CV, {1, 10, 16, 20});
303 Node *BN =
304 F_->createBatchNormalization(bindings_, "batch", RS, 1, 0.0001, 0.9);
305 F_->createSave("ret", BN);
306
307 EXPECT_EQ(F_->getNodes().size(), 4);
308 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
309 optimizedF_ = optimizeFunction(F_);
310 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
311
312 ASSERT_EQ(A->getNumUsers(), 2);
313 Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
314 [CV](auto &it) { return it.getUser() == CV; })
315 ->getUser();
316
317 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
318 ASSERT_EQ(newCV->getNumUsers(), 1);
319 Node *bn = newCV->getUsers().begin()->getUser();
320 EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(bn));
321 Node *reshape = bn->getUsers().begin()->getUser();
322 EXPECT_TRUE(llvm::isa<ReshapeNode>(reshape));
323
324 bindings_.allocate(mod_.getPlaceholders());
325 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
326 checkNumericalEquivalence();
327 }
328
TEST_F(GraphOptz,optimizeBatchNormAfterConv)329 TEST_F(GraphOptz, optimizeBatchNormAfterConv) {
330 auto *A =
331 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
332 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
333 Node *BN =
334 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
335 F_->createSave("ret", BN);
336
337 EXPECT_EQ(F_->getNodes().size(), 3);
338 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
339 optimizedF_ = optimizeFunction(F_);
340 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
341
342 ASSERT_EQ(A->getNumUsers(), 2);
343 Node *newCV = std::find_if_not(A->getUsers().begin(), A->getUsers().end(),
344 [CV](auto &it) { return it.getUser() == CV; })
345 ->getUser();
346 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
347 ASSERT_EQ(newCV->getNumUsers(), 1);
348 Node *save = newCV->getUsers().begin()->getUser();
349 EXPECT_TRUE(llvm::isa<SaveNode>(save));
350
351 bindings_.allocate(mod_.getPlaceholders());
352 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
353 checkNumericalEquivalence();
354 }
355
356 /// Verify that the Conv-BatchNorm merging optimization is not impacted by
357 /// multiple users on the filter/bias.
TEST_F(GraphOptz,optimizeBatchNormAfterConvMultiple)358 TEST_F(GraphOptz, optimizeBatchNormAfterConvMultiple) {
359 Placeholder *A =
360 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
361 ConvolutionNode *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
362 BatchNormalizationNode *BN =
363 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
364 F_->createSave("ret", BN);
365
366 // Adding these saves means the filter and bias have multiple uses. This
367 // should not impact the Conv-BatchNorm merging optimization.
368 F_->createSave("saveFilter", CV->getFilter());
369 F_->createSave("saveBias", CV->getBias());
370
371 // Three Saves, one Conv, and one BatchNorm.
372 EXPECT_EQ(F_->getNodes().size(), 5);
373
374 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
375
376 // Conv's Filter and Bias, plus BN's Scale, Bias, Mean, and Var.
377 EXPECT_EQ(mod_.getConstants().size(), 6);
378
379 optimizedF_ = optimizeFunction(F_);
380
381 // BatchNorm should have been merged into the Conv.
382 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
383
384 // Filter and Bias should have been duplicated so that the Conv-BN
385 // optimization does not modify the filter/bias being saved, equaling 4
386 // Constants. Additionally, the BN's Scale, Bias, Mean, and Var should be
387 // eliminated due to the opti.
388 EXPECT_EQ(mod_.getConstants().size(), 8);
389
390 ASSERT_EQ(A->getNumUsers(), 2);
391 Node *newCV = A->getUsers().back().getUser();
392 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
393 ASSERT_EQ(newCV->getNumUsers(), 1);
394 Node *save = newCV->getUsers().begin()->getUser();
395 EXPECT_TRUE(llvm::isa<SaveNode>(save));
396
397 EXPECT_EQ(
398 countNodeKind(optimizedF_, Kinded::Kind::BatchNormalizationNodeKind), 0);
399
400 bindings_.allocate(mod_.getPlaceholders());
401 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
402 checkNumericalEquivalence();
403 }
404
TEST_F(GraphOptz,optimizeBatchNormAfterConvFP16)405 TEST_F(GraphOptz, optimizeBatchNormAfterConvFP16) {
406 auto *A =
407 mod_.createPlaceholder(ElemKind::Float16Ty, {1, 10, 20, 3}, "A", false);
408 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
409 Node *BN =
410 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
411 F_->createSave("ret", BN);
412
413 EXPECT_EQ(F_->getNodes().size(), 3);
414
415 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
416 optimizedF_ = optimizeFunction(F_);
417
418 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
419
420 ASSERT_EQ(A->getNumUsers(), 2);
421
422 bool optimizedPathExists{false};
423 for (const auto &path : A->getUsers()) {
424 auto cv = path.getUser();
425 EXPECT_TRUE(llvm::isa<ConvolutionNode>(cv));
426 ASSERT_EQ(cv->getNumUsers(), 1);
427 auto next = cv->getUsers().begin()->getUser();
428 optimizedPathExists |= llvm::isa<SaveNode>(next);
429 }
430
431 EXPECT_TRUE(optimizedPathExists);
432
433 bindings_.allocate(A)->getHandle<float16_t>().randomize(-1.0, 1.0,
434 mod_.getPRNG());
435
436 checkNumericalEquivalence();
437 }
438
439 /// Check that transpose constant folding is done before BatchNorm optimization,
440 /// which allows to merge BatchNorm into Convolution with transposed weights.
TEST_F(GraphOptz,optimizeBatchNormAfterConvWithTransposedWeights)441 TEST_F(GraphOptz, optimizeBatchNormAfterConvWithTransposedWeights) {
442 auto *input =
443 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "input", false);
444 auto *filter =
445 mod_.createPlaceholder(ElemKind::FloatTy, {16, 3, 5, 5}, "filter", false);
446 auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {16}, "bias", false);
447
448 auto *TN = F_->createTranspose("transpose", filter, NCHW2NHWC);
449 auto *CV = F_->createConv("conv", input, TN, bias,
450 mod_.uniqueType(ElemKind::FloatTy, {1, 10, 20, 16}),
451 5, 1, 2, 1);
452 auto *BN =
453 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
454 F_->createSave("ret", BN);
455
456 // Initialize to ensure that constant tensors are not optimized out.
457 bindings_.allocate(filter)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
458 bindings_.allocate(bias)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
459
460 EXPECT_EQ(F_->getNodes().size(), 4);
461 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchNormalizationNodeKind), 1);
462
463 ::glow::convertPlaceholdersToConstants(F_, bindings_, {input});
464 optimizedF_ = optimizeFunction(F_);
465
466 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
467 EXPECT_EQ(
468 countNodeKind(optimizedF_, Kinded::Kind::BatchNormalizationNodeKind), 0);
469
470 bindings_.allocate(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
471 checkNumericalEquivalence();
472 }
473
474 /// Check that reshape constant folding is done before BatchNorm optimization,
475 /// where Reshape is a result of Transpose 2 Reshape optimization,
476 /// which allows to merge BatchNorm into Convolution with transposed weights.
TEST_F(GraphOptz,optimizeBatchNormAfterConvWithReshapeConst)477 TEST_F(GraphOptz, optimizeBatchNormAfterConvWithReshapeConst) {
478 auto *input =
479 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "input", false);
480 auto *filter =
481 mod_.createPlaceholder(ElemKind::FloatTy, {5, 5, 3, 1}, "filter", false);
482 auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "bias", false);
483
484 auto *TN = F_->createTranspose("transpose", filter, HWCN2NHWC);
485 auto *CV = F_->createConv("conv", input, TN, bias,
486 mod_.uniqueType(ElemKind::FloatTy, {1, 10, 20, 1}),
487 5, 1, 2, 1);
488 auto *BN =
489 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
490 F_->createSave("ret", BN);
491
492 // Initialize to ensure that constant tensors are not optimized out.
493 bindings_.allocate(filter)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
494 bindings_.allocate(bias)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
495
496 EXPECT_EQ(F_->getNodes().size(), 4);
497 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchNormalizationNodeKind), 1);
498
499 ::glow::convertPlaceholdersToConstants(F_, bindings_, {input});
500 optimizedF_ = optimizeFunction(F_);
501
502 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
503 EXPECT_EQ(
504 countNodeKind(optimizedF_, Kinded::Kind::BatchNormalizationNodeKind), 0);
505
506 bindings_.allocate(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
507 checkNumericalEquivalence();
508 }
509
510 /// Check that the batch normalization optimization is
511 /// not blocked by predicates and that it preserves them.
TEST_F(GraphOptz,optimizeBatchNormAfterConvWithPred)512 TEST_F(GraphOptz, optimizeBatchNormAfterConvWithPred) {
513 Node *A =
514 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
515 Node *pred1 =
516 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "predicate", false);
517 Node *pred2 =
518 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "predicate", false);
519 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
520 CV->setPredicate(pred1);
521 Node *BN =
522 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
523 BN->setPredicate(pred2);
524 F_->createSave("ret", BN);
525
526 EXPECT_EQ(F_->getNodes().size(), 3);
527
528 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
529 ::glow::optimize(F_, CompilationMode::Infer);
530 EXPECT_EQ(F_->getNodes().size(), 2);
531
532 ASSERT_EQ(A->getNumUsers(), 1);
533 Node *newCV = A->getUsers().begin()->getUser();
534 EXPECT_TRUE(llvm::isa<ConvolutionNode>(newCV));
535 ASSERT_TRUE(newCV->hasPredicate());
536 EXPECT_EQ(newCV->getPredicate().getNode(), pred2);
537 ASSERT_EQ(newCV->getNumUsers(), 1);
538 Node *save = newCV->getUsers().begin()->getUser();
539 EXPECT_TRUE(llvm::isa<SaveNode>(save));
540 }
541
542 /// Testing merge of single-user arithmetic operation chain (Sub, Mul, Add)
543 /// into a BatchNorm.
TEST_F(GraphOptz,MergeBatchNormalizationWithArithmeticChainTest)544 TEST_F(GraphOptz, MergeBatchNormalizationWithArithmeticChainTest) {
545 // Inputs.
546 auto *input =
547 mod_.createPlaceholder(ElemKind::FloatTy, {3, 2, 2, 4}, "input", false);
548 auto *var = mod_.createConstant(ElemKind::FloatTy, {4}, "var");
549 auto *mean = mod_.createConstant(ElemKind::FloatTy, {4}, "mean");
550 auto *beta = mod_.createConstant(ElemKind::FloatTy, {4}, "beta");
551 auto *gamma = mod_.createConstant(ElemKind::FloatTy, {4}, "gamma");
552
553 Node *subC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "subC");
554 Node *mulC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "mulC");
555 Node *addC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "addC");
556 Node *divC = mod_.createConstant(ElemKind::FloatTy, {3, 2, 2, 4}, "divC");
557
558 // Fill tensors to check boundary values after the transformation.
559 std::vector<float> betaV = {1., 2., 3., 7.};
560 std::vector<float> gammaV = {4., 5., 6., 7.};
561
562 var->getPayloadMutable().getHandle<float>() = {1., 1., 1., 1.};
563 mean->getPayloadMutable().getHandle<float>() = {0., 0., 0., 0.};
564 beta->getPayloadMutable().getHandle<float>() = betaV;
565 gamma->getPayloadMutable().getHandle<float>() = gammaV;
566
567 // For at least one node (sub) make values within channel different, to test
568 // folding better.
569 const std::vector<float> subV = {1, 2., 3., 4.};
570 const float mulV = 4., addV = 3., divV = 2.;
571 auto subH = llvm::cast<Constant>(subC)->getHandle<float>();
572 subH = {1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.,
573 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.,
574 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.};
575
576 llvm::cast<Constant>(mulC)->getHandle<float>().clear(mulV);
577 llvm::cast<Constant>(addC)->getHandle<float>().clear(addV);
578 llvm::cast<Constant>(divC)->getHandle<float>().clear(divV);
579
580 BatchNormalizationNode *bn =
581 F_->createBatchNormalization("batch", input, beta, gamma, mean, var, 3);
582
583 auto *sub = F_->createSub("sub", bn, subC);
584 auto *mul = F_->createMul("mul", sub, mulC);
585 auto *add = F_->createAdd("add", addC, mul);
586 auto *div = F_->createDiv("div", add, divC);
587 auto *res = F_->createSave("save", div);
588
589 // Compile.
590 EXPECT_EQ(F_->getNodes().size(), 6);
591 ::glow::convertPlaceholdersToConstants(F_, bindings_, {input});
592 optimizedF_ = optimizeFunction(F_);
593 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
594
595 Constant *cs, *cb;
596
597 auto *opt_res = findFunctionNodeByName<SaveNode>(optimizedF_, res->getName());
598
599 auto *newBn = llvm::dyn_cast<BatchNormalizationNode>(opt_res->getInput());
600 ASSERT_TRUE(newBn);
601
602 cs = llvm::dyn_cast<Constant>(newBn->getScale());
603 cb = llvm::dyn_cast<Constant>(newBn->getBias());
604 ASSERT_TRUE(cs);
605 ASSERT_TRUE(cb);
606 ASSERT_TRUE(cs->getType()->isFPType());
607 ASSERT_TRUE(cb->getType()->isFPType());
608
609 auto hs = cs->getHandle<float>();
610 auto hb = cb->getHandle<float>();
611
612 // Verify that scale and offset are computed correctly.
613 for (dim_t i = 0; i < 4; i++) {
614 const float expScale = gammaV[i] * mulV / divV;
615 const float expBias = ((betaV[i] - subV[i]) * mulV + addV) / divV;
616 EXPECT_EQ(expScale, hs.raw(i));
617 EXPECT_EQ(expBias, hb.raw(i));
618 }
619
620 bindings_.allocate(mod_.getPlaceholders());
621 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
622 checkNumericalEquivalence();
623 }
624
625 /// Testing merge of single-user arithmetic operation chain (Sub, Mul, Add)
626 /// into a BatchNorm.
TEST_F(GraphOptz,FoldArithmeticChainAfterConvIntoBatchNorm)627 TEST_F(GraphOptz, FoldArithmeticChainAfterConvIntoBatchNorm) {
628 Node *subC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "subC");
629 Node *mulC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "mulC");
630 Node *addC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "addC");
631 Node *divC = mod_.createConstant(ElemKind::FloatTy, {2, 3, 3, 3}, "divC");
632
633 // Start with identity values.
634 std::vector<float> betaV = {0., 0., 0.};
635 std::vector<float> gammaV = {1., 1., 1.};
636
637 // For at least one node make values within channel different, to test
638 // the folding better (ideally all should have different values).
639 const std::vector<float> subV = {1, 2., 3.};
640 const float mulV = 4., addV = 3., divV = 2.;
641 llvm::cast<Constant>(mulC)->getHandle<float>().clear(mulV);
642 llvm::cast<Constant>(addC)->getHandle<float>().clear(addV);
643 llvm::cast<Constant>(divC)->getHandle<float>().clear(divV);
644 auto subH = llvm::cast<Constant>(subC)->getHandle<float>();
645 subH = {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3,
646 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3,
647 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3};
648
649 auto *input =
650 mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 2, 3}, "input", false);
651 auto filter =
652 mod_.createPlaceholder(ElemKind::FloatTy, {3, 2, 2, 3}, "filter", false);
653 auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {3}, "bias", false);
654 bindings_.allocate(bias)->zero();
655
656 ConvolutionNode *CV = F_->createConv(
657 "Conv", input, filter, bias,
658 mod_.uniqueType(ElemKind::FloatTy, {2, 3, 3, 3}), 2, 1, 1, 1);
659
660 auto *sub = F_->createSub("sub", CV, subC);
661 auto *mul = F_->createMul("mul", sub, mulC);
662 auto *add = F_->createAdd("add", addC, mul);
663 auto *div = F_->createDiv("div", add, divC);
664 auto *res = F_->createSave("save", div);
665
666 // Compile.
667 EXPECT_EQ(F_->getNodes().size(), 6);
668 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
669 optimizedF_ = optimizeFunction(F_);
670 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
671
672 auto *opt_res = findFunctionNodeByName<SaveNode>(optimizedF_, res->getName());
673
674 Constant *cs, *cb;
675
676 auto *bn = llvm::dyn_cast<BatchNormalizationNode>(opt_res->getInput());
677 ASSERT_TRUE(bn);
678
679 cs = llvm::dyn_cast<Constant>(bn->getScale());
680 cb = llvm::dyn_cast<Constant>(bn->getBias());
681
682 ASSERT_TRUE(cs);
683 ASSERT_TRUE(cb);
684 ASSERT_TRUE(cs->getType()->isFPType());
685 ASSERT_TRUE(cb->getType()->isFPType());
686
687 auto hs = cs->getHandle<float>();
688 auto hb = cb->getHandle<float>();
689
690 // Verify that scale and offset are computed correctly.
691 for (dim_t i = 0; i < 3; i++) {
692 const float expectedScale = gammaV[i] * (mulV / divV);
693 const float expectedBias = ((betaV[i] - subV[i]) * mulV + addV) / divV;
694 EXPECT_EQ(expectedScale, hs.raw(i));
695 EXPECT_EQ(expectedBias, hb.raw(i));
696 }
697 bindings_.allocate(mod_.getPlaceholders());
698 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
699 bindings_.get(filter)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
700 bindings_.get(bias)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
701 checkNumericalEquivalence();
702 }
703
704 /// Check CSE will not merge two nodes that have all the same inputs but
705 /// different predicates.
TEST_F(GraphOptz,cseRespectsPredicates)706 TEST_F(GraphOptz, cseRespectsPredicates) {
707 Placeholder *in = mod_.createPlaceholder(ElemKind::FloatTy, {5}, "in", false);
708 Placeholder *pred1 =
709 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
710 Placeholder *pred2 =
711 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
712
713 Node *RN1 = F_->createRELU("relu1", in);
714 RN1->setPredicate(pred1);
715 SaveNode *save1 = F_->createSave("save1", RN1);
716 save1->setPredicate(pred1);
717
718 Node *RN2 = F_->createRELU("relu2", in);
719 RN2->setPredicate(pred2);
720 SaveNode *save2 = F_->createSave("save2", RN2);
721 save2->setPredicate(pred2);
722
723 // Two RELUS and two Saves.
724 EXPECT_EQ(F_->getNodes().size(), 4);
725 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 2);
726 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);
727
728 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
729 optimizedF_ = optimizeFunction(F_);
730
731 // Two RELUS and two Saves should still be there.
732 EXPECT_EQ(F_->getNodes().size(), 4);
733 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 2);
734 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);
735
736 bindings_.allocate(mod_.getPlaceholders());
737 bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
738 checkNumericalEquivalence();
739 }
740
TEST_F(GraphOptz,optimizeBatchNormAfterConvButConvReused)741 TEST_F(GraphOptz, optimizeBatchNormAfterConvButConvReused) {
742 Placeholder *A =
743 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
744 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
745 Node *BN =
746 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
747 F_->createSave("ret", BN);
748 F_->createSave("convSave", CV);
749
750 EXPECT_EQ(F_->getNodes().size(), 4);
751 optimizedF_ = optimizeFunction(F_);
752 // Make sure the structure of the graph did not change, since the convolution
753 // node is used more than once.
754 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
755 auto convIt =
756 std::find_if(optimizedF_->getNodes().begin(),
757 optimizedF_->getNodes().end(), [](const Node &node) -> bool {
758 return llvm::isa<ConvolutionNode>(node);
759 });
760 ASSERT_NE(convIt, optimizedF_->getNodes().end());
761 auto batchNormIt =
762 std::find_if(optimizedF_->getNodes().begin(),
763 optimizedF_->getNodes().end(), [](const Node &node) -> bool {
764 return (llvm::isa<BatchNormalizationNode>(node));
765 });
766 ConvolutionNode *conv = llvm::dyn_cast<ConvolutionNode>(convIt);
767 BatchNormalizationNode *batchNorm =
768 llvm::dyn_cast<BatchNormalizationNode>(batchNormIt);
769
770 EXPECT_EQ(*conv, *CV);
771 EXPECT_EQ(batchNorm->getInput().getNode(), conv);
772 EXPECT_EQ(conv->getInput().getNode(), A);
773
774 bindings_.allocate(mod_.getPlaceholders());
775 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
776 checkNumericalEquivalence();
777 }
778
TEST_F(GraphOptz,optimizeBatchNormAfterConvButVarReused)779 TEST_F(GraphOptz, optimizeBatchNormAfterConvButVarReused) {
780 auto *A =
781 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
782
783 ConvolutionNode *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
784 Node *BN =
785 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
786 auto *retSaveNode = F_->createSave("ret", BN);
787 auto *filterSaveNode = F_->createSave("filter", CV->getFilter());
788
789 EXPECT_EQ(F_->getNodes().size(), 4);
790 optimizedF_ = optimizeFunction(F_);
791 ASSERT_EQ(A->getNumUsers(), 2);
792
793 auto *optimizedF_ret =
794 findFunctionNodeByName<SaveNode>(optimizedF_, retSaveNode->getName());
795 auto *optimizedF_filterSave =
796 findFunctionNodeByName<SaveNode>(optimizedF_, filterSaveNode->getName());
797
798 // Make sure the structure of the graph did not change.
799 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
800 EXPECT_TRUE(llvm::isa<Placeholder>(optimizedF_filterSave->getInput()));
801 auto *varFilter =
802 llvm::dyn_cast<Placeholder>(optimizedF_filterSave->getInput());
803 EXPECT_EQ(varFilter, CV->getFilter());
804 EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(optimizedF_ret->getInput()));
805
806 BatchNormalizationNode *batchNorm =
807 llvm::dyn_cast<BatchNormalizationNode>(optimizedF_ret->getInput());
808 ASSERT_TRUE(batchNorm);
809 auto *newCVNode =
810 llvm::dyn_cast<ConvolutionNode>(batchNorm->getInput().getNode());
811 ASSERT_TRUE(newCVNode);
812 EXPECT_EQ(newCVNode->getInput().getNode(), CV->getInput().getNode());
813 EXPECT_EQ(newCVNode->getInput().getNode(), A);
814
815 bindings_.allocate(mod_.getPlaceholders());
816 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
817 checkNumericalEquivalence();
818 }
819
TEST_F(GraphOptz,transposeConstant)820 TEST_F(GraphOptz, transposeConstant) {
821 auto *A =
822 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
823 bindings_.allocate(A)->getHandle().randomize(-7.0, 12.0, mod_.getPRNG());
824 Tensor transposedA;
825 bindings_.get(A)->transpose(&transposedA, {0, 3, 1, 2});
826 Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
827 SaveNode *save = F_->createSave("ret", T);
828 EXPECT_EQ(F_->getNodes().size(), 2);
829
830 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
831 ::glow::optimize(F_, CompilationMode::Infer);
832 ASSERT_EQ(F_->getNodes().size(), 1);
833 EXPECT_EQ(&*F_->getNodes().begin(), save);
834 Constant *optimizedA = llvm::dyn_cast<Constant>(save->getInput().getNode());
835 ASSERT_NE(optimizedA, nullptr);
836 // Check that A has been properly transposed.
837 EXPECT_TRUE(optimizedA->getPayload().isEqual(transposedA));
838 }
839
840 /// Check that the Transpose is merged with Constant in a sequence
841 /// Transpose(Quantize(Constant)).
TEST_F(GraphOptz,transposeQuantizeConstant)842 TEST_F(GraphOptz, transposeQuantizeConstant) {
843 auto *qTy = mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 3}, 0.2, 0);
844 auto *input = F_->getParent()->createConstant(ElemKind::FloatTy,
845 {1, 10, 20, 3}, "input");
846 auto *Q = F_->createQuantize("quantize", input, qTy);
847 auto *T = F_->createTranspose("transpose", Q, NHWC2NCHW);
848 auto *S = F_->createSave("save", T);
849
850 // Skip ConstantFolding as it would have the same result as this opt.
851 CompilationContext cctx;
852 cctx.optimizationOpts.enableConstantFolding = false;
853
854 EXPECT_EQ(F_->getNodes().size(), 3);
855 ::glow::optimize(F_, cctx);
856 EXPECT_EQ(F_->getNodes().size(), 2);
857
858 // Constant and Quantize should have new shape.
859 auto *newQ = llvm::dyn_cast<QuantizeNode>(S->getInput());
860 ASSERT_TRUE(newQ);
861 EXPECT_TRUE(newQ->getResult().dims().equals({1, 3, 10, 20}));
862 auto *newC = llvm::dyn_cast<Constant>(newQ->getInput());
863 ASSERT_TRUE(newC);
864 EXPECT_TRUE(newC->getType()->dims().equals({1, 3, 10, 20}));
865 }
866
867 /// Check that the removing of transposes still happens when
868 /// predicates are involved.
TEST_F(GraphOptz,transposeConstantWithPredicate)869 TEST_F(GraphOptz, transposeConstantWithPredicate) {
870 auto *A =
871 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
872 auto *pred = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
873 bindings_.allocate(A)->getHandle().randomize(-7.0, 12.0, mod_.getPRNG());
874 Tensor transposedA;
875 bindings_.get(A)->transpose(&transposedA, {0, 3, 1, 2});
876 // Arguably, if the transpose doesn't happen because the predicate is false
877 // the value of A should be unchanged. However, the semantic of our
878 // predicate is that they can be ignored and the program would still
879 // be correct, thus this optimization is still legal.
880 Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
881 T->setPredicate(pred);
882 SaveNode *save = F_->createSave("ret", T);
883 save->setPredicate(pred);
884 EXPECT_EQ(F_->getNodes().size(), 2);
885
886 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
887 ::glow::optimize(F_, CompilationMode::Infer);
888 ASSERT_EQ(F_->getNodes().size(), 1);
889 EXPECT_EQ(&*F_->getNodes().begin(), save);
890 // We should have kept the predicate on the save node.
891 ASSERT_EQ(pred->getNumUsers(), 1);
892 EXPECT_EQ(pred->getUsers().begin()->getUser(), save);
893 Constant *optimizedA = llvm::dyn_cast<Constant>(save->getInput().getNode());
894 ASSERT_NE(optimizedA, nullptr);
895 // Check that A has been properly transposed.
896 EXPECT_TRUE(optimizedA->getPayload().isEqual(transposedA));
897 }
898
TEST_F(GraphOptz,BatchNormAfterConvNotOptimizeForTrain)899 TEST_F(GraphOptz, BatchNormAfterConvNotOptimizeForTrain) {
900 Placeholder *A =
901 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
902 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
903 Node *BN =
904 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
905 F_->createSave("ret", BN);
906
907 EXPECT_EQ(F_->getNodes().size(), 3);
908
909 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
910 ::glow::optimize(optimizedF_, CompilationMode::Train);
911 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
912
913 ASSERT_EQ(A->getNumUsers(), 2);
914 Node *curCV = A->getUsers().begin()->getUser();
915 EXPECT_EQ(curCV, CV);
916 ASSERT_EQ(curCV->getNumUsers(), 1);
917 Node *curBN = curCV->getUsers().begin()->getUser();
918 EXPECT_EQ(curBN, BN);
919 ASSERT_EQ(curBN->getNumUsers(), 1);
920 Node *save = curBN->getUsers().begin()->getUser();
921 EXPECT_TRUE(llvm::isa<SaveNode>(save));
922
923 bindings_.allocate(mod_.getPlaceholders());
924 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
925 checkNumericalEquivalence();
926 }
927
TEST_F(GraphOptz,batchNormAfterConvNotOptimizeWhenMoreThanOneUseOfConv)928 TEST_F(GraphOptz, batchNormAfterConvNotOptimizeWhenMoreThanOneUseOfConv) {
929 Node *A =
930 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false);
931
932 Node *CV = F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1);
933 Node *BN =
934 F_->createBatchNormalization(bindings_, "batch", CV, 3, 0.0001, 0.9);
935 SaveNode *convSave = F_->createSave("ret", CV);
936 SaveNode *ret = F_->createSave("ret", BN);
937
938 EXPECT_EQ(F_->getNodes().size(), 4);
939
940 ::glow::optimize(F_, CompilationMode::Infer);
941 // Make sure the structure of the graph did not change, since the convolution
942 // node is used more than once.
943 EXPECT_EQ(F_->getNodes().size(), 4);
944 ASSERT_TRUE(llvm::isa<ConvolutionNode>(convSave->getInput()));
945 ConvolutionNode *conv = llvm::dyn_cast<ConvolutionNode>(convSave->getInput());
946 EXPECT_EQ(conv, CV);
947 EXPECT_TRUE(llvm::isa<BatchNormalizationNode>(ret->getInput()));
948 BatchNormalizationNode *batchNorm =
949 llvm::dyn_cast<BatchNormalizationNode>(ret->getInput());
950 EXPECT_EQ(batchNorm, BN);
951 EXPECT_EQ(batchNorm->getInput().getNode(), CV);
952 EXPECT_EQ(conv->getInput().getNode(), A);
953 }
954
955 enum class TestSinkTransposeNodesKind {
956 BatchNormalization,
957 Relu,
958 Clip,
959 Sigmoid,
960 Tanh,
961 Quantize,
962 };
963
964 class GraphOptzSinkTransposeBelowParametrized
965 : public GraphOptz,
966 public ::testing::WithParamInterface<TestSinkTransposeNodesKind> {
967 public:
getNodeFromInput(TestSinkTransposeNodesKind testNode,Node * T)968 NodeValue getNodeFromInput(TestSinkTransposeNodesKind testNode, Node *T) {
969 switch (testNode) {
970 case TestSinkTransposeNodesKind::BatchNormalization: {
971 return F_->createBatchNormalization(bindings_, "batch", T, 3, 0.0001, 0.9)
972 ->getResult();
973 }
974 case TestSinkTransposeNodesKind::Relu: {
975 return F_->createRELU("relu", T)->getResult();
976 }
977 case TestSinkTransposeNodesKind::Clip: {
978 return F_->createClip("clip", T, 0.0, 6.0)->getResult();
979 }
980 case TestSinkTransposeNodesKind::Sigmoid: {
981 return F_->createSigmoid("sigmoid", T)->getResult();
982 }
983 case TestSinkTransposeNodesKind::Tanh: {
984 return F_->createTanh("tanh", T)->getResult();
985 }
986 case TestSinkTransposeNodesKind::Quantize: {
987 return F_
988 ->createQuantize(
989 "quantize", T,
990 mod_.uniqueType(ElemKind::Int8QTy, T->dims(0), 0.03, 5))
991 ->getResult();
992 }
993 }
994 LOG(DFATAL) << "Cannot reach here.";
995 }
996 };
997
TEST_P(GraphOptzSinkTransposeBelowParametrized,TestSinkTransposeForDifferentCases)998 TEST_P(GraphOptzSinkTransposeBelowParametrized,
999 TestSinkTransposeForDifferentCases) {
1000 const dim_t origDims[] = {1, 5, 10, 15};
1001 const dim_t transposedDims[] = {1, 15, 5, 10};
1002 auto *A = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1003 Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
1004 auto IN = getNodeFromInput(GetParam(), T);
1005 SaveNode *O = F_->createSave("ret", IN);
1006
1007 EXPECT_EQ(F_->getNodes().size(), 3);
1008 EXPECT_EQ(IN.dims(), llvm::makeArrayRef(transposedDims));
1009
1010 optimizedF_ = optimizeFunction(F_);
1011 O = llvm::dyn_cast<SaveNode>(std::find_if(
1012 optimizedF_->getNodes().begin(), optimizedF_->getNodes().end(),
1013 [](const auto &N) { return N.getKind() == Kinded::Kind::SaveNodeKind; }));
1014
1015 // Expecting Transpose->Output rather than N->Output.
1016 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1017 ASSERT_NE(transpose, nullptr);
1018 Node *N = transpose->getInput();
1019 ASSERT_TRUE(N);
1020 // Test correct input.
1021 if (GetParam() == TestSinkTransposeNodesKind::BatchNormalization) {
1022 ASSERT_EQ(BatchNormalizationNode::InputIdx, 0);
1023 } else {
1024 ASSERT_EQ(N->getNumInputs(), 1);
1025 }
1026 // Check that the dimensions of the input and output have been
1027 // updated to compensate the absence of transpose.
1028 EXPECT_EQ(transpose->getInput().dims(), llvm::makeArrayRef(origDims));
1029 EXPECT_EQ(N->getNthInput(0).dims(), llvm::makeArrayRef(origDims));
1030 EXPECT_EQ(F_->getNodes().size(), 3);
1031
1032 bindings_.allocate(mod_.getPlaceholders());
1033 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1034 checkNumericalEquivalence();
1035 }
1036
TEST_P(GraphOptzSinkTransposeBelowParametrized,TestSinkTransposeWithPredicateForDifferentCases)1037 TEST_P(GraphOptzSinkTransposeBelowParametrized,
1038 TestSinkTransposeWithPredicateForDifferentCases) {
1039 if (GetParam() == TestSinkTransposeNodesKind::Quantize) {
1040 // Quantize does not work with generic test for predicates.
1041 return;
1042 }
1043 const dim_t origDims[] = {1, 5, 10, 15};
1044 const dim_t transposedDims[] = {1, 15, 5, 10};
1045 Node *A = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1046 Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1047 Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1048 Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1049 Node *T = F_->createTranspose("transpose", A, NHWC2NCHW);
1050 T->setPredicate(pred1);
1051 Node *IN = getNodeFromInput(GetParam(), T);
1052 IN->setPredicate(pred2);
1053 SaveNode *O = F_->createSave("ret", IN);
1054 O->setPredicate(pred3);
1055
1056 EXPECT_EQ(F_->getNodes().size(), 3);
1057 EXPECT_EQ(IN->getNthResult(0).dims(), llvm::makeArrayRef(transposedDims));
1058
1059 ::glow::optimize(F_, CompilationMode::Infer);
1060
1061 EXPECT_EQ(O->getPredicate().getNode(), pred3);
1062 // Expecting Transpose->Output rather than N->Output.
1063 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1064 ASSERT_NE(transpose, nullptr);
1065 EXPECT_EQ(transpose->getPredicate().getNode(), pred2);
1066 Node *N = transpose->getInput();
1067 ASSERT_TRUE(N);
1068 EXPECT_EQ(N->getPredicate().getNode(), pred2);
1069
1070 // Test correct input.
1071 if (GetParam() == TestSinkTransposeNodesKind::BatchNormalization) {
1072 ASSERT_EQ(BatchNormalizationNode::InputIdx, 0);
1073 } else {
1074 ASSERT_EQ(N->getNumInputs(), 1);
1075 }
1076
1077 // Check that the dimensions of the input and output have been
1078 // updated to compensate the absence of transpose.
1079 EXPECT_EQ(transpose->getInput().dims(), llvm::makeArrayRef(origDims));
1080 EXPECT_EQ(N->getNthInput(0).dims(), llvm::makeArrayRef(origDims));
1081 EXPECT_EQ(F_->getNodes().size(), 3);
1082 }
1083
1084 GLOW_INSTANTIATE_TEST_SUITE_P(
1085 TestSinkTranspose, GraphOptzSinkTransposeBelowParametrized,
1086 ::testing::Values(TestSinkTransposeNodesKind::BatchNormalization,
1087 TestSinkTransposeNodesKind::Relu,
1088 TestSinkTransposeNodesKind::Clip,
1089 TestSinkTransposeNodesKind::Sigmoid,
1090 TestSinkTransposeNodesKind::Tanh,
1091 TestSinkTransposeNodesKind::Quantize));
1092
TEST_F(GraphOptz,SinkTransposeBelowDequantize)1093 TEST_F(GraphOptz, SinkTransposeBelowDequantize) {
1094 auto *in =
1095 mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input", false);
1096 auto *quantize = F_->createQuantize(
1097 "quantize", in, mod_.uniqueType(ElemKind::Int8QTy, in->dims(), 0.01, 2));
1098 auto *tile = F_->createTile("tile", quantize, 3, 0);
1099 auto *transpose = F_->createTranspose("transpose", tile, NHWC2NCHW);
1100 auto *deq = F_->createDequantize("dequantize", transpose, ElemKind::FloatTy);
1101 SaveNode *O = F_->createSave("out", deq);
1102
1103 optimizedF_ = optimizeFunction(F_);
1104
1105 EXPECT_EQ(F_->getNodes().size(), 5);
1106 EXPECT_EQ(optimizedF_->getNodes().size(), 5);
1107
1108 auto *optOut = findFunctionNodeByName<SaveNode>(optimizedF_, O->getName());
1109 EXPECT_TRUE(llvm::isa<TransposeNode>(optOut->getInput().getNode()));
1110
1111 bindings_.allocate(mod_.getPlaceholders());
1112 bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1113 checkNumericalEquivalence();
1114 }
1115
1116 /// For example folding Rescale in to Convolution.
TEST_F(GraphOptz,sinkTransposeBelowRescale)1117 TEST_F(GraphOptz, sinkTransposeBelowRescale) {
1118 // Inputs.
1119 const dim_t origDims[] = {1, 5, 10, 15};
1120 const dim_t transposedDims[] = {1, 15, 5, 10};
1121 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, origDims, 0.1, 0,
1122 "input", false);
1123 auto *filter = mod_.createPlaceholder(ElemKind::Int8QTy, {15, 1, 1, 15}, 0.1,
1124 0, "filter", false);
1125 auto *bias =
1126 mod_.createPlaceholder(ElemKind::Int32QTy, {15}, 0.01, 0, "bias", false);
1127
1128 // Graph.
1129 ConvolutionNode *conv =
1130 F_->createConv("conv", input, filter, bias, input->getType(), {1, 1},
1131 {1, 1}, {0, 0, 0, 0}, 1);
1132
1133 auto *T = F_->createTranspose("transpose", conv, NHWC2NCHW);
1134 auto *RT = mod_.uniqueType(ElemKind::Int8QTy, T->getResult().dims(), 0.2, 0);
1135 auto *R = F_->createRescaleQuantized("rescale", T, RT);
1136 SaveNode *O = F_->createSave("ret", R);
1137
1138 EXPECT_EQ(F_->getNodes().size(), 4);
1139 EXPECT_EQ(RT->dims(), llvm::makeArrayRef(transposedDims));
1140
1141 ::glow::optimize(F_, CompilationMode::Infer);
1142
1143 // Expecting Transpose->Output rather than Rescale->Output.
1144 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1145 ASSERT_NE(transpose, nullptr);
1146 ASSERT_TRUE(llvm::isa<ConvolutionNode>(transpose->getInput()));
1147 auto &convTRInput = transpose->getInput();
1148 // Check that the dimensions of the input and output have been
1149 // updated to compensate the absence of transpose.
1150 EXPECT_EQ(convTRInput.dims(), llvm::makeArrayRef(origDims));
1151 EXPECT_EQ(convTRInput.getNode()->getNthInput(0).dims(),
1152 llvm::makeArrayRef(origDims));
1153 EXPECT_EQ(F_->getNodes().size(), 3);
1154 }
1155
TEST_F(GraphOptz,cancelTwoTransposes)1156 TEST_F(GraphOptz, cancelTwoTransposes) {
1157 const dim_t origDims[] = {1, 5, 10, 15};
1158 Placeholder *A =
1159 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1160 Node *T1 = F_->createTranspose("transpose", A, NCHW2NHWC);
1161 Node *T2 = F_->createTranspose("transpose", T1, NHWC2NCHW);
1162 ReluNode *K = F_->createRELU("relu", T2);
1163 SaveNode *save = F_->createSave("ret", K);
1164
1165 EXPECT_EQ(K->getInput().dims(), llvm::makeArrayRef(origDims));
1166 EXPECT_EQ(F_->getNodes().size(), 4);
1167
1168 optimizedF_ = optimizeFunction(F_);
1169
1170 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
1171
1172 for (auto &N : optimizedF_->getNodes()) {
1173 if (N.getKind() == Kinded::Kind::SaveNodeKind) {
1174 save = llvm::dyn_cast<SaveNode>(&N);
1175 }
1176 }
1177
1178 ReluNode *relu = llvm::dyn_cast<ReluNode>(save->getInput());
1179 ASSERT_TRUE(relu);
1180 EXPECT_EQ(relu->getResult().dims(), llvm::makeArrayRef(origDims));
1181 EXPECT_EQ(relu->getInput().getNode(), A);
1182
1183 bindings_.allocate(mod_.getPlaceholders());
1184 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1185
1186 checkNumericalEquivalence();
1187 }
1188
1189 /// Make sure the predicates don't get in the way of the
1190 /// transpose(transpose) => identity and that they are
1191 /// preserved.
TEST_F(GraphOptz,cancelTwoTransposesWithPredicate)1192 TEST_F(GraphOptz, cancelTwoTransposesWithPredicate) {
1193 const dim_t origDims[] = {1, 5, 10, 15};
1194 Node *A = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1195 Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1196 Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1197 Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1198 Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1199 Node *T1 = F_->createTranspose("transpose", A, NCHW2NHWC);
1200 T1->setPredicate(pred1);
1201 Node *T2 = F_->createTranspose("transpose", T1, NHWC2NCHW);
1202 T2->setPredicate(pred2);
1203 ReluNode *K = F_->createRELU("relu", T2);
1204 K->setPredicate(pred3);
1205 SaveNode *save = F_->createSave("ret", K);
1206 save->setPredicate(pred4);
1207
1208 EXPECT_EQ(K->getInput().dims(), llvm::makeArrayRef(origDims));
1209 EXPECT_EQ(F_->getNodes().size(), 4);
1210
1211 ::glow::optimize(F_, CompilationMode::Infer);
1212
1213 EXPECT_EQ(F_->getNodes().size(), 2);
1214 EXPECT_EQ(save->getPredicate().getNode(), pred4);
1215 ReluNode *relu = llvm::dyn_cast<ReluNode>(save->getInput());
1216 ASSERT_TRUE(relu);
1217 EXPECT_EQ(relu->getPredicate().getNode(), pred3);
1218 EXPECT_EQ(relu->getResult().dims(), llvm::makeArrayRef(origDims));
1219 EXPECT_EQ(relu->getInput().getNode(), A);
1220 }
1221
TEST_F(GraphOptz,removeIdentityTranspose)1222 TEST_F(GraphOptz, removeIdentityTranspose) {
1223 const dim_t origDims[] = {1, 5, 10, 15};
1224 Placeholder *A =
1225 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1226 TransposeNode *T = F_->createTranspose("transpose", A, {0, 1, 2, 3});
1227 ReluNode *K = F_->createRELU("relu", T);
1228 F_->createSave("ret", K);
1229
1230 EXPECT_EQ(F_->getNodes().size(), 3);
1231 EXPECT_EQ(K->getInput().getNode(), T);
1232
1233 ::glow::optimize(F_, CompilationMode::Infer);
1234
1235 EXPECT_EQ(F_->getNodes().size(), 2);
1236 EXPECT_EQ(K->getInput().getNode(), A);
1237 // Make sure we didn't mess up with the dimensions of the
1238 // variable while eliminating the transpose.
1239 EXPECT_EQ(A->dims(), llvm::makeArrayRef(origDims));
1240 }
1241
1242 /// Check that the predicates don't get in the way of
1243 /// the identity transpose removal, while still being
1244 /// preserved.
TEST_F(GraphOptz,removeIdentityTransposeWithPredicate)1245 TEST_F(GraphOptz, removeIdentityTransposeWithPredicate) {
1246 const dim_t origDims[] = {1, 5, 10, 15};
1247 Placeholder *A =
1248 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1249 Placeholder *pred1 =
1250 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1251 Placeholder *pred2 =
1252 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1253 Placeholder *pred3 =
1254 mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1255 TransposeNode *T = F_->createTranspose("transpose", A, {0, 1, 2, 3});
1256 T->setPredicate(pred1);
1257 ReluNode *K = F_->createRELU("relu", T);
1258 K->setPredicate(pred2);
1259 SaveNode *save = F_->createSave("ret", K);
1260 save->setPredicate(pred3);
1261
1262 EXPECT_EQ(F_->getNodes().size(), 3);
1263 EXPECT_EQ(K->getInput().getNode(), T);
1264
1265 ::glow::optimize(F_, CompilationMode::Infer);
1266 EXPECT_EQ(F_->getNodes().size(), 2);
1267 EXPECT_EQ(save->getPredicate().getNode(), pred3);
1268 EXPECT_EQ(save->getInput().getNode(), K);
1269 EXPECT_EQ(K->getInput().getNode(), A);
1270 EXPECT_EQ(K->getPredicate().getNode(), pred2);
1271 // Make sure we didn't mess up with the dimensions of the
1272 // variable while eliminating the transpose.
1273 EXPECT_EQ(A->dims(), llvm::makeArrayRef(origDims));
1274 }
1275
1276 /// Check that consecutive non-inverse transposes are merged
1277 /// into an equivalent single transpose node.
TEST_F(GraphOptz,mergeNonInverseTransposes)1278 TEST_F(GraphOptz, mergeNonInverseTransposes) {
1279 const dim_t origDims[] = {1, 5, 10, 15};
1280 const dim_t finalDims[] = {5, 1, 15, 10};
1281
1282 Placeholder *A =
1283 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input", false);
1284 TransposeNode *T1 = F_->createTranspose("transpose", A, {0, 3, 2, 1});
1285 TransposeNode *T2 = F_->createTranspose("transpose", T1, {0, 2, 3, 1});
1286 TransposeNode *T3 = F_->createTranspose("transpose", T2, {1, 0, 3, 2});
1287 TransposeNode *T4 = F_->createTranspose("transpose", T3, {3, 1, 2, 0});
1288
1289 // Intermediate dims after each tranpose
1290 // Initial : {1, 5, 10, 15}
1291 // After T1: {1, 15, 10, 5}
1292 // After T2: {1, 10, 5, 15}
1293 // After T3: {10, 1, 15, 5}
1294 // After T4: {5, 1, 15, 10}
1295
1296 SaveNode *save = F_->createSave("ret", T4);
1297
1298 EXPECT_EQ(F_->getNodes().size(), 5);
1299
1300 optimizedF_ = optimizeFunction(F_);
1301 // Find save node in the optimized graph.
1302 for (auto &N : optimizedF_->getNodes()) {
1303 if (N.getKind() == Kinded::Kind::SaveNodeKind) {
1304 save = llvm::dyn_cast<SaveNode>(&N);
1305 }
1306 }
1307 // Get the last transpose node in the optimized graph.
1308 auto *TR = llvm::dyn_cast<TransposeNode>(save->getInput());
1309 ASSERT_NE(TR, nullptr);
1310
1311 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
1312 EXPECT_EQ(TR->getResult().dims(), llvm::makeArrayRef(finalDims));
1313 EXPECT_EQ(A->getNthResult(0).dims(), llvm::makeArrayRef(origDims));
1314 EXPECT_EQ(TR->getInput().getNode(), A);
1315
1316 bindings_.allocate(mod_.getPlaceholders());
1317 bindings_.get(A)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1318 checkNumericalEquivalence();
1319 }
1320
TEST_F(GraphOptz,sinkTransposeBelowArithmeticNodes)1321 TEST_F(GraphOptz, sinkTransposeBelowArithmeticNodes) {
1322 const dim_t origDims[] = {1, 5, 10, 15};
1323 Node *A1 =
1324 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1325 Node *A2 =
1326 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1327 Node *T1 = F_->createTranspose("transpose1", A1, NHWC2NCHW);
1328 Node *T2 = F_->createTranspose("transpose2", A2, NHWC2NCHW);
1329 Node *K = F_->createAdd("arith", T1, T2);
1330 SaveNode *O = F_->createSave("ret", K);
1331
1332 EXPECT_EQ(F_->getNodes().size(), 4);
1333
1334 ::glow::optimize(F_, CompilationMode::Infer);
1335
1336 // Expecting Transpose->Output rather than Add->Output.
1337 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1338 ASSERT_NE(transpose, nullptr);
1339 auto *add = llvm::dyn_cast<AddNode>(transpose->getInput());
1340 ASSERT_TRUE(add);
1341 // Check that the dimensions of the input and output have been
1342 // updated to compensate the absence of transpose.
1343 EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims));
1344 EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims));
1345 EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims));
1346 EXPECT_EQ(add->getLHS().getNode(), A1);
1347 EXPECT_EQ(add->getRHS().getNode(), A2);
1348
1349 EXPECT_EQ(F_->getNodes().size(), 3);
1350 }
1351
1352 /// Check that Transpose node is sunk below arithmetic nodes when one of the
1353 /// operands is a Constant.
TEST_F(GraphOptz,sinkTransposeBelowArithmeticNodesWithConstantOperand)1354 TEST_F(GraphOptz, sinkTransposeBelowArithmeticNodesWithConstantOperand) {
1355 const dim_t origDims[] = {1, 5, 10, 15};
1356 const dim_t transposedDims[] = {1, 15, 5, 10};
1357
1358 // Create one subgraph in which the Constant is the LHS operand of the Add.
1359 Constant *C1 = mod_.createConstant(ElemKind::FloatTy, transposedDims, "C1");
1360 // Initialize the payload before optimization so that it can be copied to the
1361 // new Constant that will be created by the GraphOptimizer.
1362 C1->getHandle().randomize(-1, 1, mod_.getPRNG());
1363
1364 auto *P1 = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "P1", false);
1365 auto *T1 = F_->createTranspose("T1", P1, NHWC2NCHW);
1366 auto *A1 = F_->createAdd("A1", C1, T1);
1367 SaveNode *S1 = F_->createSave("S1", A1);
1368
1369 // Create one subgraph in which the Constnat is the RHS operand of the Add.
1370 Constant *C2 = mod_.createConstant(ElemKind::FloatTy, transposedDims, "C2");
1371 // Initialize the payload before optimization so that it can be copied to the
1372 // new Constant that will be created by the GraphOptimizer.
1373 C2->getHandle().randomize(-1, 1, mod_.getPRNG());
1374
1375 auto *P2 = mod_.createPlaceholder(ElemKind::FloatTy, origDims, "P2", false);
1376 auto *T2 = F_->createTranspose("T2", P2, NHWC2NCHW);
1377 auto *A2 = F_->createAdd("A2", T2, C2);
1378 SaveNode *S2 = F_->createSave("S2", A2);
1379
1380 EXPECT_EQ(F_->getNodes().size(), 6);
1381
1382 optimizedF_ = optimizeFunction(F_);
1383
1384 // Find the SaveNodes of the optimized graph.
1385 for (auto &N : optimizedF_->getNodes()) {
1386 if (N.getKind() == Kinded::Kind::SaveNodeKind) {
1387 if (N.getName() == S1->getName()) {
1388 S1 = llvm::dyn_cast<SaveNode>(&N);
1389 }
1390
1391 if (N.getName() == S2->getName()) {
1392 S2 = llvm::dyn_cast<SaveNode>(&N);
1393 }
1394 }
1395 }
1396
1397 // Expecting Transpose->Output rather than Add->Output.
1398 auto *transpose = llvm::dyn_cast<TransposeNode>(S1->getInput());
1399 ASSERT_NE(transpose, nullptr);
1400 auto *add = llvm::dyn_cast<AddNode>(transpose->getInput());
1401 ASSERT_TRUE(add);
1402 // Check that the dimensions of the input and output of the add have been
1403 // updated to compensate the absence of transpose.
1404 EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims));
1405 EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims));
1406 EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims));
1407 EXPECT_EQ(add->getRHS().getNode(), P1);
1408
1409 // Repeat checks for other subgraph.
1410 transpose = llvm::dyn_cast<TransposeNode>(S2->getInput());
1411 ASSERT_NE(transpose, nullptr);
1412 add = llvm::dyn_cast<AddNode>(transpose->getInput());
1413 ASSERT_TRUE(add);
1414 EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims));
1415 EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims));
1416 EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims));
1417 EXPECT_EQ(add->getLHS().getNode(), P2);
1418
1419 EXPECT_EQ(optimizedF_->getNodes().size(), 6);
1420
1421 // Check that the original and optimized functions are numerically equivalent.
1422 // This indirectly checks that the Constant has been transposed properly.
1423 bindings_.allocate(mod_.getPlaceholders());
1424 bindings_.get(P1)->getHandle().randomize(-1, 1, mod_.getPRNG());
1425 bindings_.get(P2)->getHandle().randomize(-1, 1, mod_.getPRNG());
1426
1427 checkNumericalEquivalence();
1428 }
1429
1430 /// Check that the predicates are properly preserved while doing
1431 /// the add(transpose, transpose) => transpose(add).
TEST_F(GraphOptz,sinkTransposeBelowArithmeticNodesWithPredicate)1432 TEST_F(GraphOptz, sinkTransposeBelowArithmeticNodesWithPredicate) {
1433 const dim_t origDims[] = {1, 5, 10, 15};
1434 Node *A1 =
1435 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1436 Node *A2 =
1437 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1438 Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1439 Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1440 Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1441 Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1442 Node *T1 = F_->createTranspose("transpose1", A1, NHWC2NCHW);
1443 T1->setPredicate(pred1);
1444 Node *T2 = F_->createTranspose("transpose2", A2, NHWC2NCHW);
1445 T2->setPredicate(pred2);
1446 Node *K = F_->createAdd("arith", T1, T2);
1447 K->setPredicate(pred3);
1448 SaveNode *O = F_->createSave("ret", K);
1449 O->setPredicate(pred4);
1450
1451 EXPECT_EQ(F_->getNodes().size(), 4);
1452
1453 ::glow::optimize(F_, CompilationMode::Infer);
1454
1455 EXPECT_EQ(O->getPredicate().getNode(), pred4);
1456 // Expecting Transpose->Output rather than Add->Output.
1457 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1458 ASSERT_NE(transpose, nullptr);
1459 EXPECT_EQ(transpose->getPredicate().getNode(), pred3);
1460 auto *add = llvm::dyn_cast<AddNode>(transpose->getInput());
1461 ASSERT_TRUE(add);
1462 EXPECT_EQ(add->getPredicate().getNode(), pred3);
1463 // Check that the dimensions of the input and output have been
1464 // updated to compensate the absence of transpose.
1465 EXPECT_EQ(add->getResult().dims(), llvm::makeArrayRef(origDims));
1466 EXPECT_EQ(add->getRHS().dims(), llvm::makeArrayRef(origDims));
1467 EXPECT_EQ(add->getLHS().dims(), llvm::makeArrayRef(origDims));
1468 EXPECT_EQ(add->getLHS().getNode(), A1);
1469 EXPECT_EQ(add->getRHS().getNode(), A2);
1470
1471 EXPECT_EQ(F_->getNodes().size(), 3);
1472 }
1473
TEST_F(GraphOptz,sinkReluBelowConcatNodes)1474 TEST_F(GraphOptz, sinkReluBelowConcatNodes) {
1475 const dim_t origDims[] = {1, 5, 10, 15};
1476 const dim_t origDimsConcat[] = {1, 10, 10, 15};
1477 Node *A1 =
1478 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1479 Node *A2 =
1480 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1481 Node *R1 = F_->createRELU("relu1", A1);
1482 Node *R2 = F_->createRELU("relu2", A2);
1483 Node *CN = F_->createConcat("concat", {R1, R2}, 1);
1484 SaveNode *O = F_->createSave("ret", CN);
1485
1486 EXPECT_EQ(F_->getNodes().size(), 4);
1487
1488 ::glow::optimize(F_, CompilationMode::Infer);
1489
1490 // Expecting RELU->Output rather than Concat->Output.
1491 auto *relu = llvm::dyn_cast<ReluNode>(O->getInput());
1492 ASSERT_NE(relu, nullptr);
1493 auto *concat = llvm::dyn_cast<ConcatNode>(relu->getInput());
1494 ASSERT_TRUE(concat);
1495 // Check that the dimensions of the input and output have been
1496 // updated to compensate the absence of transpose.
1497 EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat));
1498 EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims));
1499 EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims));
1500 EXPECT_EQ(concat->getInputs()[0].getNode(), A1);
1501 EXPECT_EQ(concat->getInputs()[1].getNode(), A2);
1502
1503 EXPECT_EQ(F_->getNodes().size(), 3);
1504 }
1505
1506 /// Check that the predicates are properly preserved while doing
1507 /// the sinking of relu nodes.
TEST_F(GraphOptz,sinkReluBelowConcatNodesWithPredicate)1508 TEST_F(GraphOptz, sinkReluBelowConcatNodesWithPredicate) {
1509 const dim_t origDims[] = {1, 5, 10, 15};
1510 const dim_t origDimsConcat[] = {1, 10, 10, 15};
1511 Node *A1 =
1512 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1513 Node *A2 =
1514 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1515 Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1516 Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1517 Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1518 Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1519 Node *R1 = F_->createRELU("relu1", A1);
1520 R1->setPredicate(pred1);
1521 Node *R2 = F_->createRELU("relu2", A2);
1522 R2->setPredicate(pred2);
1523 Node *CN = F_->createConcat("concat", {R1, R2}, 1);
1524 CN->setPredicate(pred3);
1525 SaveNode *O = F_->createSave("ret", CN);
1526 O->setPredicate(pred4);
1527
1528 EXPECT_EQ(F_->getNodes().size(), 4);
1529
1530 ::glow::optimize(F_, CompilationMode::Infer);
1531
1532 // Expecting RELU->Output rather than Concat->Output.
1533 EXPECT_EQ(O->getPredicate().getNode(), pred4);
1534 auto *relu = llvm::dyn_cast<ReluNode>(O->getInput());
1535 ASSERT_NE(relu, nullptr);
1536 EXPECT_EQ(relu->getPredicate().getNode(), pred3);
1537 auto *concat = llvm::dyn_cast<ConcatNode>(relu->getInput());
1538 ASSERT_TRUE(concat);
1539 EXPECT_EQ(concat->getPredicate().getNode(), pred3);
1540 // Check that the dimensions of the input and output have been
1541 // updated to compensate the absence of transpose.
1542 EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat));
1543 EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims));
1544 EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims));
1545 EXPECT_EQ(concat->getInputs()[0].getNode(), A1);
1546 EXPECT_EQ(concat->getInputs()[1].getNode(), A2);
1547
1548 EXPECT_EQ(F_->getNodes().size(), 3);
1549 }
1550
TEST_F(GraphOptz,sinkTransposeBelowConcatNodes)1551 TEST_F(GraphOptz, sinkTransposeBelowConcatNodes) {
1552 const dim_t origDims[] = {1, 5, 10, 15};
1553 const dim_t origDimsConcat[] = {1, 5, 20, 15};
1554 Node *A1 =
1555 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1556 Node *A2 =
1557 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1558 Node *T1 = F_->createTranspose("transpose", A1, NCHW2NHWC);
1559 Node *T2 = F_->createTranspose("transpose", A2, NCHW2NHWC);
1560 Node *CN = F_->createConcat("concat", {T1, T2}, 1);
1561 SaveNode *O = F_->createSave("ret", CN);
1562
1563 EXPECT_EQ(F_->getNodes().size(), 4);
1564
1565 ::glow::optimize(F_, CompilationMode::Infer);
1566
1567 // Expecting Transpose->Output rather than Add->Output.
1568 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1569 ASSERT_NE(transpose, nullptr);
1570 auto *concat = llvm::dyn_cast<ConcatNode>(transpose->getInput());
1571 ASSERT_TRUE(concat);
1572 // Check that the dimensions of the input and output have been
1573 // updated to compensate the absence of transpose.
1574 EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat));
1575 EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims));
1576 EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims));
1577 EXPECT_EQ(concat->getInputs()[0].getNode(), A1);
1578 EXPECT_EQ(concat->getInputs()[1].getNode(), A2);
1579
1580 EXPECT_EQ(F_->getNodes().size(), 3);
1581 }
1582
1583 /// Check that the predicates are properly preserved while doing
1584 /// the concat(transpose, transpose) => transpose(add).
TEST_F(GraphOptz,sinkTransposeBelowConcatWithPredicate)1585 TEST_F(GraphOptz, sinkTransposeBelowConcatWithPredicate) {
1586 const dim_t origDims[] = {1, 5, 10, 15};
1587 const dim_t origDimsConcat[] = {1, 5, 20, 15};
1588 Node *A1 =
1589 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input1", false);
1590 Node *A2 =
1591 mod_.createPlaceholder(ElemKind::FloatTy, origDims, "input2", false);
1592 Node *pred1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1593 Node *pred2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1594 Node *pred3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1595 Node *pred4 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "pred", false);
1596 Node *T1 = F_->createTranspose("transpose", A1, NCHW2NHWC);
1597 T1->setPredicate(pred1);
1598 Node *T2 = F_->createTranspose("transpose", A2, NCHW2NHWC);
1599 T2->setPredicate(pred2);
1600 Node *CN = F_->createConcat("concat", {T1, T2}, 1);
1601 CN->setPredicate(pred3);
1602 SaveNode *O = F_->createSave("ret", CN);
1603 O->setPredicate(pred4);
1604
1605 EXPECT_EQ(F_->getNodes().size(), 4);
1606
1607 ::glow::optimize(F_, CompilationMode::Infer);
1608
1609 EXPECT_EQ(O->getPredicate().getNode(), pred4);
1610 // Expecting Transpose->Output rather than Add->Output.
1611 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1612 ASSERT_NE(transpose, nullptr);
1613 EXPECT_EQ(transpose->getPredicate().getNode(), pred3);
1614 auto *concat = llvm::dyn_cast<ConcatNode>(transpose->getInput());
1615 ASSERT_TRUE(concat);
1616 EXPECT_EQ(concat->getPredicate().getNode(), pred3);
1617 // Check that the dimensions of the input and output have been
1618 // updated to compensate the absence of transpose.
1619 EXPECT_EQ(concat->getResult().dims(), llvm::makeArrayRef(origDimsConcat));
1620 EXPECT_EQ(concat->getInputs()[0].dims(), llvm::makeArrayRef(origDims));
1621 EXPECT_EQ(concat->getInputs()[1].dims(), llvm::makeArrayRef(origDims));
1622 EXPECT_EQ(concat->getInputs()[0].getNode(), A1);
1623 EXPECT_EQ(concat->getInputs()[1].getNode(), A2);
1624
1625 EXPECT_EQ(F_->getNodes().size(), 3);
1626 }
1627
TEST_F(GraphOptz,sinkTransposeBelowPad)1628 TEST_F(GraphOptz, sinkTransposeBelowPad) {
1629 // The shape of the graph before the optimization.
1630 const dim_t inputDims[] = {1, 5, 10, 15};
1631 const dim_t outTransposeDims[] = {1, 10, 15, 5};
1632 const dim_t outPadDims[] = {5, 18, 25, 11};
1633 // Padding before the optimization.
1634 int pads[] = {0, 2, 3, 1, 4, 6, 7, 5};
1635
1636 // The shape of the graph after the optimization.
1637 const dim_t outPadDimsAfterOptim[] = {5, 11, 18, 25};
1638 const dim_t outTransposeDimsAfterOptims[] = {5, 18, 25, 11};
1639 // Padding after the optimization.
1640 int padsAfterOptim[] = {0, 1, 2, 3, 4, 5, 6, 7};
1641
1642 // Create the initial graph.
1643 Node *A =
1644 mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
1645 auto outTy = mod_.uniqueType(ElemKind::FloatTy, outPadDims);
1646 TransposeNode *T = F_->createTranspose("transpose", A, NCHW2NHWC);
1647 Node *P = F_->createPad("pad", T, outTy, PaddingMode::CONSTANT, pads, 23.f);
1648 EXPECT_EQ(T->getResult().dims(), llvm::makeArrayRef(outTransposeDims));
1649 SaveNode *O = F_->createSave("ret", P);
1650
1651 EXPECT_EQ(F_->getNodes().size(), 3);
1652
1653 ::glow::optimize(F_, CompilationMode::Infer);
1654
1655 // Check the graph structure and additional properties after optimization.
1656 auto *trans = llvm::dyn_cast<TransposeNode>(O->getInput());
1657 ASSERT_NE(trans, nullptr);
1658 EXPECT_EQ(trans->getResult().dims(),
1659 llvm::makeArrayRef(outTransposeDimsAfterOptims));
1660 auto *pad = llvm::dyn_cast<PadNode>(trans->getInput().getNode());
1661 ASSERT_NE(pad, nullptr);
1662
1663 EXPECT_EQ(pad->getPads(), llvm::makeArrayRef(padsAfterOptim));
1664 EXPECT_EQ(pad->getResult().dims(), llvm::makeArrayRef(outPadDimsAfterOptim));
1665
1666 EXPECT_EQ(F_->getNodes().size(), 3);
1667 }
1668
TEST_F(GraphOptz,sinkTransposeBelowRelu)1669 TEST_F(GraphOptz, sinkTransposeBelowRelu) {
1670 // Define a type with custom alignments.
1671 Type typeWithAlignments(ElemKind::FloatTy, {2, 3, 4, 5}, {1, 1, 32, 1});
1672 Type transposedTypeWithAlignments(ElemKind::FloatTy, {2, 4, 5, 3},
1673 {1, 1, 32, 1});
1674 auto modTyWithAlignments = mod_.uniqueType(typeWithAlignments);
1675 auto modTransposedTyWithAlignments =
1676 mod_.uniqueType(transposedTypeWithAlignments);
1677 auto *A1 = mod_.createPlaceholder(modTyWithAlignments, "input1", false);
1678 auto *T1 = F_->createTranspose("transpose", A1, NCHW2NHWC);
1679 T1->setType(0, modTransposedTyWithAlignments);
1680 auto *RN = F_->createRELU("relu", T1);
1681 SaveNode *O = F_->createSave("ret", RN);
1682
1683 EXPECT_EQ(F_->getNodes().size(), 3);
1684
1685 ::glow::optimize(F_, CompilationMode::Infer);
1686
1687 // Expecting Transpose->Output rather than Relu->Output, because Transpose was
1688 // sinked.
1689 auto *transpose = llvm::dyn_cast<TransposeNode>(O->getInput());
1690 ASSERT_NE(transpose, nullptr);
1691 auto *relu = llvm::dyn_cast<ReluNode>(transpose->getInput());
1692 ASSERT_TRUE(relu);
1693 // Check that alignments are preserved by optimizations.
1694 ASSERT_TRUE(relu->getInput().getType()->isEqual(modTyWithAlignments));
1695 ASSERT_TRUE(transpose->getInput().getType()->isEqual(modTyWithAlignments));
1696 ASSERT_TRUE(
1697 transpose->getResult().getType()->isEqual(modTransposedTyWithAlignments));
1698
1699 EXPECT_EQ(F_->getNodes().size(), 3);
1700 ASSERT_TRUE(F_->verify());
1701 }
1702
TEST_F(GraphOptz,mergeConcatNodes)1703 TEST_F(GraphOptz, mergeConcatNodes) {
1704 Node *A1 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input1",
1705 false);
1706 Node *A2 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input2",
1707 false);
1708 Node *A3 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input3",
1709 false);
1710 Node *A4 =
1711 mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 5, 15}, "input4", false);
1712 Node *A5 =
1713 mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 5, 15}, "input5", false);
1714
1715 Node *CN1 = F_->createConcat("concat1", {A1, A2}, 1);
1716 Node *CN2 = F_->createConcat("concat2", {A1, CN1}, 1);
1717 Node *CN3 = F_->createConcat("concat3", {A4, A5}, 2);
1718 Node *CN4 = F_->createConcat("concat4", {A3, CN2, CN3}, 1);
1719 Node *O = F_->createSave("ret", CN4);
1720
1721 EXPECT_EQ(F_->getNodes().size(), 5);
1722
1723 ::glow::optimize(F_, CompilationMode::Train);
1724
1725 // It is expected that the optimization transforms
1726 // concat4(1, A3, concat2(1, A1, concat1(1, A1, A2)), concat3(2, A4, A5))
1727 // into
1728 // concat4(1, A3, A1, A1, A2, concat3(2, A4, A5))
1729
1730 EXPECT_TRUE(llvm::isa<SaveNode>(O));
1731
1732 auto *CN =
1733 llvm::dyn_cast<ConcatNode>(llvm::dyn_cast<SaveNode>(O)->getInput());
1734 EXPECT_TRUE(CN);
1735
1736 // The merged ConcatNode should have 5 inputs.
1737 EXPECT_EQ(CN->getInputs().size(), 5);
1738
1739 // CN1 should be merged into a new CN2 and later into a new CN4 and removed by
1740 // the optimizations.
1741 EXPECT_FALSE(functionContainsNode(F_, CN1));
1742
1743 // CN2 should be merged into a new CN4 and removed by the optimizations.
1744 EXPECT_FALSE(functionContainsNode(F_, CN2));
1745
1746 // CN3 should not be merged into CN4 and should not be removed,
1747 // because CN4 and CN3 have a different dimension parameter.
1748 EXPECT_TRUE(functionContainsNode(F_, CN3));
1749
1750 // The CN4 concat node should be replaced by a merged concat node.
1751 EXPECT_FALSE(functionContainsNode(F_, CN4));
1752
1753 EXPECT_EQ(F_->getNodes().size(), 3);
1754 }
1755
TEST_F(GraphOptz,CSE)1756 TEST_F(GraphOptz, CSE) {
1757 Node *A1 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input1",
1758 false);
1759 Node *A2 = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5, 10, 15}, "input2",
1760 false);
1761
1762 Node *CN1 = F_->createConcat("concat1", {A1, A2}, 1);
1763 Node *CN2 = F_->createConcat("concat2", {A1, A2}, 1);
1764 Node *CN3 = F_->createConcat("concat3", {CN1, CN2}, 2);
1765 Node *O = F_->createSave("ret", CN3);
1766
1767 EXPECT_EQ(F_->getNodes().size(), 4);
1768
1769 ::glow::optimize(F_, CompilationMode::Train);
1770
1771 EXPECT_TRUE(llvm::isa<SaveNode>(O));
1772
1773 auto *CN =
1774 llvm::dyn_cast<ConcatNode>(llvm::dyn_cast<SaveNode>(O)->getInput());
1775 EXPECT_TRUE(CN);
1776
1777 // The merged ConcatNode should have 2 inputs.
1778 EXPECT_EQ(CN->getInputs().size(), 2);
1779
1780 // CN1 should not be removed.
1781 EXPECT_TRUE(functionContainsNode(F_, CN1));
1782
1783 // CSE should replace CN2 by CN1 and remove CN2.
1784 EXPECT_FALSE(functionContainsNode(F_, CN2));
1785
1786 EXPECT_EQ(F_->getNodes().size(), 3);
1787 }
1788
TEST_F(GraphOptz,SliceOfSplatNode)1789 TEST_F(GraphOptz, SliceOfSplatNode) {
1790 Type t(ElemKind::FloatTy, {1000, 1000, 1000});
1791 Node *Z = F_->createSplat("zero", &t, 0.);
1792 Node *S = F_->createSlice("slice", Z, {5, 15, 42}, {99, 88, 77});
1793 Node *O = F_->createSave("ret", S);
1794
1795 EXPECT_EQ(F_->getNodes().size(), 3);
1796
1797 ::glow::optimize(F_, CompilationMode::Train);
1798
1799 EXPECT_EQ(F_->getNodes().size(), 2);
1800
1801 EXPECT_TRUE(llvm::isa<SaveNode>(O));
1802
1803 auto *CN = llvm::dyn_cast<SplatNode>(llvm::dyn_cast<SaveNode>(O)->getInput());
1804 EXPECT_TRUE(CN);
1805
1806 EXPECT_TRUE(CN->getResult().getType()->dims().equals({94, 73, 35}));
1807 }
1808
1809 /// Test Clip(Splat(args)) -> Splat(args').
TEST_F(GraphOptz,ClipOfSplatNode)1810 TEST_F(GraphOptz, ClipOfSplatNode) {
1811 Type T(ElemKind::FloatTy, {10, 10});
1812 SplatNode *splat = F_->createSplat("zero", &T, 5);
1813 ClipNode *clipMin = F_->createClip("clip", splat, 10, 15);
1814 ClipNode *clipMax = F_->createClip("clip", splat, 0, 2);
1815 ClipNode *clipSame = F_->createClip("clip", splat, 0, 10);
1816 SaveNode *saveMin = F_->createSave("saveMin", clipMin);
1817 SaveNode *saveMax = F_->createSave("saveMax", clipMax);
1818 SaveNode *saveSame = F_->createSave("saveSame", clipSame);
1819
1820 // Start with one splat, three clips, three saves.
1821 EXPECT_EQ(F_->getNodes().size(), 7);
1822
1823 ::glow::optimize(F_, CompilationMode::Infer);
1824
1825 // We will end up with three Splats and three saves.
1826 EXPECT_EQ(F_->getNodes().size(), 6);
1827
1828 SplatNode *splatMin = llvm::dyn_cast<SplatNode>(saveMin->getInput());
1829 ASSERT_TRUE(splatMin);
1830 EXPECT_EQ(splatMin->getValue(), 10);
1831
1832 SplatNode *splatMax = llvm::dyn_cast<SplatNode>(saveMax->getInput());
1833 ASSERT_TRUE(splatMax);
1834 EXPECT_EQ(splatMax->getValue(), 2);
1835
1836 ASSERT_EQ(saveSame->getInput().getNode(), splat);
1837 EXPECT_EQ(splat->getValue(), 5);
1838 }
1839
TEST_F(GraphOptz,ZeroArithmetic)1840 TEST_F(GraphOptz, ZeroArithmetic) {
1841 // Tests the identities: [0 + X = X] [0 * X = 0] [0 / X = 0] [ X - 0 = X]
1842
1843 auto *input =
1844 mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input", true);
1845
1846 // This builds the expression: ((0 / I) + (0 + I) + (0 * I)) - 0
1847
1848 auto *zero = F_->createSplat("zero", input->getType(), 0.);
1849
1850 auto *div = F_->createDiv("div", zero, input); // -> zero
1851
1852 auto *add = F_->createAdd("add", zero, input); // -> input
1853
1854 auto *mul = F_->createMul("mul", zero, input); // -> zero
1855
1856 auto *add3 = F_->createAdd("add", div, add);
1857
1858 add3 = F_->createAdd("add", add3, mul);
1859
1860 auto *sub = F_->createSub("sub", add3, zero); // -> input
1861
1862 SaveNode *O = F_->createSave("ret", sub);
1863
1864 // The expression evaluates to "I".
1865
1866 EXPECT_EQ(F_->getNodes().size(), 8);
1867
1868 ::glow::optimize(F_, CompilationMode::Infer);
1869
1870 EXPECT_EQ(F_->getNodes().size(), 1);
1871
1872 EXPECT_EQ(O->getInput().getNode(), input);
1873
1874 optimizedF_ = optimizeFunction(F_);
1875
1876 bindings_.allocate(mod_.getPlaceholders());
1877 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1878
1879 checkNumericalEquivalence();
1880 }
1881
1882 // Similar to ZeroArithmetic, but tests that nodes with multiple results are
1883 // correctly handled (i.e. that the correct output is selected after optimising
1884 // away an arithmetic identity).
TEST_F(GraphOptz,ZeroArithmeticMultiResNode)1885 TEST_F(GraphOptz, ZeroArithmeticMultiResNode) {
1886 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {10}, "input", true);
1887 auto *topK = F_->createTopK("topK", input, /*k=*/5);
1888 auto *zero = F_->createSplat("zero", topK->getValues().getType(), 0.);
1889 auto *add = F_->createAdd("add", topK->getValues(), zero);
1890 auto *sub = F_->createSub("sub", topK->getValues(), zero);
1891
1892 SaveNode *AS = F_->createSave("ret", add);
1893 SaveNode *SS = F_->createSave("ret", sub);
1894
1895 // There should be 6 nodes: 2 Saves, Add, Sub, Splat and TopK.
1896 EXPECT_EQ(F_->getNodes().size(), 6);
1897
1898 optimizedF_ = optimizeFunction(F_);
1899
1900 // Now there should only be 3 nodes: TopK and 2 Saves.
1901 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
1902
1903 auto *OAS = findFunctionNodeByName<SaveNode>(optimizedF_, AS->getName());
1904 auto *OSS = findFunctionNodeByName<SaveNode>(optimizedF_, SS->getName());
1905 auto *OTopK = findFunctionNodeByName<TopKNode>(optimizedF_, topK->getName());
1906
1907 // Since the operations reprsented by the arithmetic nodes are no-ops,
1908 // the input to both SaveNodes should be the Values result of TopKNode.
1909 EXPECT_EQ(OAS->getInput(), OTopK->getValues());
1910 EXPECT_EQ(OSS->getInput(), OTopK->getValues());
1911
1912 // Check numerical equivalence.
1913 bindings_.allocate(mod_.getPlaceholders());
1914 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1915
1916 checkNumericalEquivalence();
1917 }
1918
1919 /// A test that verifies that arithmetic simplification works correctly when
1920 /// the parents need to be simplified prior to the node itself.
TEST_F(GraphOptz,ZeroArithmeticParentsMustBeSimplifiedFirst)1921 TEST_F(GraphOptz, ZeroArithmeticParentsMustBeSimplifiedFirst) {
1922 auto *input1 =
1923 mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input1", true);
1924 auto *input2 =
1925 mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input2", true);
1926
1927 // This builds the expression: ((0 * I1) * (0 * I2)) = 0
1928 // It should be simplified to simply the splat zero node being saved.
1929
1930 SplatNode *zero = F_->createSplat("zero", input1->getType(), 0.);
1931
1932 MulNode *mul1 = F_->createMul("mul1", zero, input1); // -> 0
1933 MulNode *mul2 = F_->createMul("mul2", zero, input2); // -> 0
1934
1935 MulNode *mul3 = F_->createMul("mul3", mul1, mul2); // -> 0
1936
1937 SaveNode *O = F_->createSave("ret", mul3);
1938
1939 // Expect 1 splat, 3 muls, 1 save.
1940 EXPECT_EQ(F_->getNodes().size(), 5);
1941
1942 ::glow::optimize(F_, CompilationMode::Infer);
1943
1944 // Expect all muls to be optimized away, with 1 splat and 1 save left.
1945 EXPECT_EQ(F_->getNodes().size(), 2);
1946 EXPECT_TRUE(functionContainsNode(F_, O));
1947 EXPECT_TRUE(functionContainsNode(F_, zero));
1948 EXPECT_EQ(O->getInput().getNode(), zero);
1949 }
1950
1951 /// Tests opts for the identities: [1 * X = X] [X / 1 = X]
TEST_F(GraphOptz,ArithmeticIdentitiesOne)1952 TEST_F(GraphOptz, ArithmeticIdentitiesOne) {
1953 auto *input =
1954 mod_.createPlaceholder(ElemKind::FloatTy, {4, 10}, "input", true);
1955
1956 // This builds the expression: (I / 1) * 1:
1957 SplatNode *one = F_->createSplat("one", input->getType(), 1.);
1958 DivNode *div = F_->createDiv("div", input, one);
1959 MulNode *mul = F_->createMul("mul", div, one);
1960 SaveNode *save = F_->createSave("ret", mul);
1961
1962 // Splat, Div, Mul, Save.
1963 EXPECT_EQ(F_->getNodes().size(), 4);
1964 // Save optimized function for future comparision
1965 optimizedF_ = optimizeFunction(F_);
1966
1967 // The expression evaluates to "I", so Save is only node left.
1968 EXPECT_EQ(optimizedF_->getNodes().size(), 1);
1969 SaveNode *SN =
1970 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save->getName()));
1971 ASSERT_TRUE(functionContainsNode(optimizedF_, SN));
1972 ASSERT_NE(SN, nullptr);
1973
1974 // Save node should just save the input.
1975 EXPECT_TRUE(SN->getInput().getNode() == input);
1976
1977 bindings_.allocate(mod_.getPlaceholders());
1978 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
1979
1980 checkNumericalEquivalence();
1981 }
1982
1983 /// Reverse the intrusive list of nodes. This custom implementation is required,
1984 /// because std::reverse cannot be used with LLVM's intrusive lists.
reverse(NodesList & L)1985 static void reverse(NodesList &L) {
1986 if (L.empty())
1987 return;
1988 // Last element of the list before reversal.
1989 auto &last = L.back();
1990 // Take element from the beginning and move it right after the old last
1991 // element. Do it until the old last element becomes the first element.
1992 while (true) {
1993 auto &first = L.front();
1994 // Finish when the old last element becomes the new front element.
1995 if (&first == &last) {
1996 break;
1997 }
1998 L.remove(first);
1999 L.insert(++last.getIterator(), &first);
2000 }
2001 }
2002
TEST(GraphOptzTest,SliceOfSplatNodeChain)2003 TEST(GraphOptzTest, SliceOfSplatNodeChain) {
2004 for (int shouldReverse = 0; shouldReverse <= 1; shouldReverse++) {
2005 Module mod;
2006 Function *F = mod.createFunction("foo");
2007
2008 Type t(ElemKind::FloatTy, {1000, 1000, 1000});
2009 Node *Z = F->createSplat("zero", &t, 0.);
2010 Node *S1 = F->createSlice("slice1", Z, {5, 15, 42}, {99, 88, 77});
2011 Node *S2 = F->createSlice("slice2", S1, {1, 1, 1}, {2, 3, 4});
2012 F->createSave("ret", S2);
2013
2014 if (shouldReverse) {
2015 auto &nodes = F->getNodes();
2016 reverse(nodes);
2017 }
2018
2019 EXPECT_EQ(F->getNodes().size(), 4);
2020
2021 CompilationContext cctx;
2022 cctx.compMode = CompilationMode::Train;
2023 // Do not perform any compile-time constant folding.
2024 cctx.optimizationOpts.enableConstantFolding = false;
2025 ::glow::optimize(F, cctx);
2026
2027 // This test illustrates some inconsistency in the optimization.
2028 // Chain splats are not guaranteed to be optimized.
2029 EXPECT_EQ(F->getNodes().size(), shouldReverse ? 3 : 2);
2030 }
2031 }
2032
TEST_F(GraphOptz,ReshapeNoop)2033 TEST_F(GraphOptz, ReshapeNoop) {
2034 const dim_t shape[] = {10, 20, 30};
2035 Type t(ElemKind::FloatTy, shape);
2036 auto *Z = F_->createSplat("zero", &t, 0.);
2037 auto *R = F_->createReshape("reshape", Z, shape);
2038 auto *O = F_->createSave("ret", R);
2039
2040 EXPECT_EQ(F_->getNodes().size(), 3);
2041
2042 ::glow::optimize(F_, CompilationMode::Train);
2043
2044 EXPECT_EQ(F_->getNodes().size(), 2);
2045
2046 auto *SN = llvm::dyn_cast<SplatNode>(llvm::dyn_cast<SaveNode>(O)->getInput());
2047 EXPECT_TRUE(SN);
2048
2049 EXPECT_TRUE(SN->getResult().getType()->dims().equals(shape));
2050 }
2051
2052 /// Test the Reshape(Splat(args)) -> Splat(args') transformation.
2053 /// Including a positive and a negative test case. In the positive case,
2054 /// the optimization will take place for the splat node (Z2) that has only one
2055 /// use. In the negative case, the optimization will not happen as the splat
2056 /// node (Z1) has more than one use.
TEST_F(GraphOptz,ReshapeAfterSplat)2057 TEST_F(GraphOptz, ReshapeAfterSplat) {
2058 const dim_t shape[] = {10, 20, 30};
2059 const dim_t reshape[] = {1, 6000};
2060 Type t1(ElemKind::FloatTy, shape);
2061 Type t2(ElemKind::FloatTy, reshape);
2062 Node *input = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
2063 "input", true);
2064 auto *Z1 = F_->createSplat("zero1", &t1, 1.5);
2065 auto *A1 = F_->createAdd("add1", Z1->getResult().getType(), input, Z1);
2066 auto *R1 = F_->createReshape("reshape1", Z1, reshape);
2067 // Z1 is used by R1 and A1.
2068 // The reshape optimization will thus NOT be able to remove this reshape node
2069 // (R1).
2070 auto *R2 = F_->createReshape("reshape2", A1, reshape);
2071 auto *A2 = F_->createAdd("add", R1->getResult().getType(), R1, R2);
2072 auto *Z2 = F_->createSplat("zero2", &t1, 2.5);
2073 auto *R3 = F_->createReshape("reshape3", Z2, reshape);
2074 // Z2 is only used by R3.
2075 // The Z2,R3 nodes will be replaced by a new splat node with the shape of R3.
2076 auto *A3 = F_->createAdd("add", A2->getResult().getType(), A2, R3);
2077 auto *O = F_->createSave("ret", A3);
2078
2079 // Before optimization, we have 9 nodes in the graph.
2080 EXPECT_EQ(F_->getNodes().size(), 9);
2081
2082 cctx_.compMode = CompilationMode::Infer;
2083 // Do not perform any compile-time constant folding.
2084 cctx_.optimizationOpts.enableConstantFolding = false;
2085 ::glow::optimize(F_, cctx_);
2086
2087 // After optimization, we expect to see only 8 nodes, as Z2,R2 would be
2088 // replace by a new splat node.
2089 EXPECT_EQ(F_->getNodes().size(), 8);
2090
2091 // The second input of A3 shoule be a splat node with a shape of R3.
2092 auto *newA3 = llvm::dyn_cast<AddNode>(O->getInput());
2093 ASSERT_TRUE(newA3);
2094 auto *SN = llvm::dyn_cast<SplatNode>(newA3->getRHS());
2095 EXPECT_TRUE(SN);
2096 EXPECT_TRUE(SN->getResult().getType()->dims().equals(reshape));
2097
2098 // R1 should still be in the graph.
2099 EXPECT_TRUE(functionContainsNode(F_, R1));
2100
2101 // R3 and Z2 should not be in the graph any more.
2102 EXPECT_FALSE(functionContainsNode(F_, R3));
2103 EXPECT_FALSE(functionContainsNode(F_, Z2));
2104 }
2105
2106 /// Test the Reshape(Reshape(x)) -> Reshape(x) transformation.
TEST_F(GraphOptz,ReshapeReshapeOpt)2107 TEST_F(GraphOptz, ReshapeReshapeOpt) {
2108 const dim_t shape[] = {10, 20};
2109 const dim_t reshape1[] = {200, 1};
2110 const dim_t reshape2[] = {200};
2111 Node *input = F_->getParent()->createPlaceholder(ElemKind::FloatTy, shape,
2112 "input", true);
2113 auto *R1 = F_->createReshape("reshape1", input, reshape1);
2114 auto *R2 = F_->createReshape("reshape2", R1, reshape2);
2115 auto *O = F_->createSave("ret", R2);
2116
2117 // Before optimization, we have 2 Reshapes and a Save.
2118 EXPECT_EQ(F_->getNodes().size(), 3);
2119
2120 ::glow::optimize(F_, CompilationMode::Infer);
2121
2122 // After optimization, we expect to see only 1 Reshape and a Save.
2123 EXPECT_EQ(F_->getNodes().size(), 2);
2124
2125 // Save should have the new Reshape as input.
2126 auto *RN = llvm::dyn_cast<ReshapeNode>(O->getInput());
2127 ASSERT_TRUE(RN);
2128 // The new Reshape should have the same shape as the original second Reshape.
2129 EXPECT_TRUE(RN->getResult().getType()->dims().equals(reshape2));
2130
2131 // R1 and R2 should not be in the graph any more; they were replaced by a
2132 // single new reshape.
2133 EXPECT_FALSE(functionContainsNode(F_, R1));
2134 EXPECT_FALSE(functionContainsNode(F_, R2));
2135 }
2136
TEST_F(GraphOptz,DCEPublicVars)2137 TEST_F(GraphOptz, DCEPublicVars) {
2138 mod_.createPlaceholder(ElemKind::FloatTy, {4, 320, 200, 3}, "input", true);
2139
2140 EXPECT_EQ(mod_.getPlaceholders().size(), 1);
2141
2142 // Optimize all of the dead code.
2143 ::glow::optimize(F_, CompilationMode::Infer);
2144
2145 // Public nodes should not be deleted.
2146 EXPECT_EQ(mod_.getPlaceholders().size(), 1);
2147 }
2148
TEST_F(GraphOptz,foldQuantizeIntoConstant)2149 TEST_F(GraphOptz, foldQuantizeIntoConstant) {
2150 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {4}, "input", true);
2151 *bindings_.allocate(input) = {10, 10, 10, 10};
2152 auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2153
2154 auto *Q = F_->createQuantize("quantize", input, qType);
2155 auto *S = F_->createSave("save", Q);
2156
2157 EXPECT_EQ(2, F_->getNodes().size());
2158 ::glow::convertPlaceholdersToConstants(F_, bindings_, {S->getPlaceholder()});
2159
2160 // 'optimize' doesn't merge quantize nodes into Constant.
2161 ::glow::optimize(F_, CompilationMode::Infer);
2162 EXPECT_EQ(2, F_->getNodes().size());
2163
2164 // 'convertQuantizedConstants' merges quantize nodes into Constant
2165 CompilationContext cctx;
2166 ::glow::convertQuantizedConstants(F_, cctx);
2167 EXPECT_EQ(1, F_->getNodes().size());
2168
2169 auto quantizedInput = llvm::cast<Constant>(S->getInput());
2170 auto quantizedValues = quantizedInput->getHandle<int8_t>();
2171 for (unsigned i = 0; i < 4; ++i) {
2172 EXPECT_EQ(5, quantizedValues.raw(i));
2173 }
2174 }
2175
TEST_F(GraphOptz,foldQuantizeIntoConstantMultipleUsages)2176 TEST_F(GraphOptz, foldQuantizeIntoConstantMultipleUsages) {
2177 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, {4}, "input", true);
2178 *bindings_.allocate(input) = {10, 10, 10, 10};
2179 auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2180
2181 auto *Q = F_->createQuantize("quantize", input, qType);
2182 F_->createSave("save", Q);
2183 auto clonedF = F_->clone("cloned");
2184
2185 EXPECT_EQ(2, clonedF->getNodes().size());
2186 ::glow::convertPlaceholdersToConstants(clonedF, bindings_, {});
2187 CompilationContext cctx;
2188 ::glow::convertQuantizedConstants(clonedF, cctx);
2189
2190 // F_ function should not be affected.
2191 EXPECT_EQ(2, F_->getNodes().size());
2192
2193 // Check original var.
2194 for (unsigned i = 0; i < 4; ++i) {
2195 EXPECT_EQ(10, bindings_.get(input)->getHandle().raw(i));
2196 }
2197
2198 // Quantization node was merged into input var.
2199 EXPECT_EQ(1, clonedF->getNodes().size());
2200 auto *save = llvm::dyn_cast<SaveNode>(&clonedF->getNodes().front());
2201 ASSERT_TRUE(save);
2202 auto quantizedInput = llvm::cast<Constant>(save->getInput());
2203 auto quantizedValues = quantizedInput->getHandle<int8_t>();
2204 for (unsigned i = 0; i < 4; ++i) {
2205 EXPECT_EQ(5, quantizedValues.raw(i));
2206 }
2207 }
2208
2209 /// Search for a unique Save node in input graph \p F and return it.
2210 /// Fails in case there is no Save node or more than one detected.
getUniqueSaveNode(Function * F)2211 static SaveNode *getUniqueSaveNode(Function *F) {
2212 SaveNode *foundSaveNode = nullptr;
2213 for (auto &node : F->getNodes()) {
2214 if (auto *s = llvm::dyn_cast<SaveNode>(&node)) {
2215 EXPECT_EQ(foundSaveNode, nullptr);
2216 foundSaveNode = s;
2217 }
2218 }
2219 EXPECT_NE(foundSaveNode, nullptr);
2220 return foundSaveNode;
2221 }
2222
2223 /// Mock backend that requests the pre-quantization of constants.
2224 class MockBackendPrequantizeConst : public MockBackend {
shouldPreQuantizeConstants() const2225 bool shouldPreQuantizeConstants() const override { return true; }
isOpSupported(const NodeInfo &) const2226 bool isOpSupported(const NodeInfo &) const override { return true; }
2227 Expected<bool>
transformPostLowering(Function * F,CompilationContext &,const glow::runtime::DeviceInfo *) const2228 transformPostLowering(Function *F, CompilationContext &,
2229 const glow::runtime::DeviceInfo *) const override {
2230 // Check the IR.
2231 EXPECT_EQ(F->getNodes().size(), 1);
2232 auto *save = getUniqueSaveNode(F);
2233 EXPECT_TRUE(llvm::isa<Constant>(save->getInput()));
2234
2235 return false;
2236 }
2237 };
2238 /// Mock backend that requests the non pre-quantization of constants.
2239 class MockBackendNotPrequantizeConst : public MockBackend {
shouldPreQuantizeConstants() const2240 bool shouldPreQuantizeConstants() const override { return false; }
isOpSupported(const NodeInfo &) const2241 bool isOpSupported(const NodeInfo &) const override { return true; }
2242 Expected<bool>
transformPostLowering(Function * F,CompilationContext &,const glow::runtime::DeviceInfo *) const2243 transformPostLowering(Function *F, CompilationContext &,
2244 const glow::runtime::DeviceInfo *) const override {
2245 // Check the IR.
2246 EXPECT_EQ(F->getNodes().size(), 2);
2247 auto *save = getUniqueSaveNode(F);
2248 auto *quant = llvm::dyn_cast<QuantizeNode>(save->getInput());
2249 EXPECT_TRUE(quant);
2250 EXPECT_TRUE(llvm::isa<Constant>(quant->getInput()));
2251
2252 return false;
2253 }
2254 };
2255
2256 /// Test the actual constant quantization for backends.
2257 template <typename Backend>
testFoldQuantizeIntoConstant(Module & mod_,Function * F_)2258 void testFoldQuantizeIntoConstant(Module &mod_, Function *F_) {
2259 auto *input = mod_.createConstant(ElemKind::FloatTy, {4}, "input");
2260 input->getHandle<float>() = {10, 10, 10, 10};
2261 auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2262 auto *Q = F_->createQuantize("quantize", input, qType);
2263 auto *save = F_->createSave("save", Q);
2264
2265 CompilationContext cctx;
2266 auto B = Backend();
2267 // Note: the check that Quantize is or not folded into Constant before
2268 // post-lowering is done in <backend>::transformPostLowering()
2269 EXIT_ON_ERR(::glow::optimizeFunction(F_, B, cctx));
2270
2271 // Check the IR (the constant must have been quantized).
2272 EXPECT_EQ(F_->getNodes().size(), 1);
2273 EXPECT_TRUE(llvm::isa<Constant>(save->getInput()));
2274 }
2275
2276 /// Check the backend actual constant quantization is done before post-lowering.
TEST_F(GraphOptz,foldQuantizeIntoConstantBeforePostLowering)2277 TEST_F(GraphOptz, foldQuantizeIntoConstantBeforePostLowering) {
2278 testFoldQuantizeIntoConstant<MockBackendPrequantizeConst>(mod_, F_);
2279 }
2280
2281 /// Check the backend actual constant quantization is done after post-lowering.
TEST_F(GraphOptz,foldQuantizeIntoConstantAfterPostLowering)2282 TEST_F(GraphOptz, foldQuantizeIntoConstantAfterPostLowering) {
2283 testFoldQuantizeIntoConstant<MockBackendNotPrequantizeConst>(mod_, F_);
2284 }
2285
2286 /// Check that the Quantize(Splat) -> Splat' optimization works.
TEST_F(GraphOptz,foldQuantizeIntoSplat)2287 TEST_F(GraphOptz, foldQuantizeIntoSplat) {
2288 TypeRef fType = mod_.uniqueType(ElemKind::FloatTy, {4});
2289 TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2290
2291 const float splatVal = 6.0;
2292 SplatNode *SN = F_->createSplat("splat", fType, splatVal);
2293
2294 QuantizeNode *Q = F_->createQuantize("quantize", SN, qType);
2295 SaveNode *S = F_->createSave("save", Q);
2296
2297 // Splat, quantize, save.
2298 EXPECT_EQ(3, F_->getNodes().size());
2299
2300 ::glow::optimize(F_, CompilationMode::Infer);
2301
2302 // Quantization node was merged into input splat.
2303 EXPECT_EQ(2, F_->getNodes().size());
2304
2305 // New quantized splat should exist with same value.
2306 SplatNode *newSN = llvm::dyn_cast<SplatNode>(S->getInput());
2307 ASSERT_TRUE(newSN);
2308 EXPECT_EQ(splatVal, newSN->getValue());
2309 EXPECT_EQ(qType, newSN->getResult().getType());
2310 }
2311
2312 /// Check that the Dequantize(Splat) -> Splat' optimization works.
TEST_F(GraphOptz,foldDequantizeIntoSplat)2313 TEST_F(GraphOptz, foldDequantizeIntoSplat) {
2314 TypeRef fType = mod_.uniqueType(ElemKind::FloatTy, {4});
2315 TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2316
2317 const float splatVal = 6.0;
2318 SplatNode *SN = F_->createSplat("splat", qType, splatVal);
2319
2320 DequantizeNode *Q = F_->createDequantize("dequantize", SN, ElemKind::FloatTy);
2321 SaveNode *S = F_->createSave("save", Q);
2322
2323 // Splat, dequantize, save.
2324 EXPECT_EQ(3, F_->getNodes().size());
2325
2326 ::glow::optimize(F_, CompilationMode::Infer);
2327
2328 // Dequantization node was merged into input splat.
2329 EXPECT_EQ(2, F_->getNodes().size());
2330
2331 // New quantized splat should exist with same value.
2332 SplatNode *newSN = llvm::dyn_cast<SplatNode>(S->getInput());
2333 ASSERT_TRUE(newSN);
2334 EXPECT_EQ(splatVal, newSN->getValue());
2335 EXPECT_EQ(fType, newSN->getResult().getType());
2336 }
2337
2338 /// Check that the Quantize(Splat) -> Splat' optimization works when the Splat
2339 /// has multiple users.
TEST_F(GraphOptz,foldQuantizeIntoSplatMultipleUsers)2340 TEST_F(GraphOptz, foldQuantizeIntoSplatMultipleUsers) {
2341 TypeRef fType = mod_.uniqueType(ElemKind::FloatTy, {4});
2342 TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 0);
2343
2344 SplatNode *SN = F_->createSplat("splat", fType, 6.0);
2345
2346 QuantizeNode *Q = F_->createQuantize("quantize", SN, qType);
2347 SaveNode *SQ = F_->createSave("saveQ", Q);
2348 SaveNode *SF = F_->createSave("saveF", SN);
2349
2350 // Splat, quantize, 2 saves.
2351 EXPECT_EQ(4, F_->getNodes().size());
2352
2353 ::glow::optimize(F_, CompilationMode::Infer);
2354
2355 // Quantization node was merged into input splat creating a new quantized
2356 // splat, but the original float splat still exists.
2357 EXPECT_EQ(4, F_->getNodes().size());
2358
2359 // New quantized splat should exist with same value.
2360 SplatNode *newSN = llvm::dyn_cast<SplatNode>(SQ->getInput());
2361 ASSERT_TRUE(newSN);
2362 EXPECT_EQ(SN->getValue(), newSN->getValue());
2363 EXPECT_EQ(qType, newSN->getResult().getType());
2364
2365 // Original float splat should still exist.
2366 EXPECT_EQ(llvm::dyn_cast<SplatNode>(SF->getInput()), SN);
2367 }
2368
2369 /// Check that an unnecessary rescale gets removed.
TEST_F(GraphOptz,removeUnnecessaryRescale)2370 TEST_F(GraphOptz, removeUnnecessaryRescale) {
2371 TypeRef qType = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.03f, 5);
2372 Placeholder *input =
2373 mod_.createPlaceholder(qType, "input", /* isTrainable */ true);
2374 RescaleQuantizedNode *RQ =
2375 F_->createRescaleQuantized("rescale", input, qType);
2376 SaveNode *save = F_->createSave("ret", RQ);
2377
2378 // RescaleQuantized and Save.
2379 EXPECT_EQ(F_->getNodes().size(), 2);
2380
2381 ::glow::optimize(F_, CompilationMode::Infer);
2382
2383 // Only Save should be left, which saves the Placeholder directly with
2384 // unchanged quantization parameters.
2385 EXPECT_EQ(F_->getNodes().size(), 1);
2386 EXPECT_EQ(save->getInput().getNode(), input);
2387 EXPECT_EQ(save->getInput().getType(), qType);
2388 }
2389
2390 /// Check that rescale gets correctly merged into a following dequantize node
TEST_F(GraphOptz,mergeRescaleIntoDequantize)2391 TEST_F(GraphOptz, mergeRescaleIntoDequantize) {
2392 // Check that we are combining quantization-dequantization pairs.
2393 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
2394 "input", true);
2395 auto *qType = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.03f, 5);
2396 auto *R = F_->createRescaleQuantized("rescale", input, qType);
2397 auto *D = F_->createDequantize("dequantize", R, ElemKind::FloatTy);
2398 F_->createSave("ret", D);
2399
2400 EXPECT_EQ(F_->getNodes().size(), 3);
2401 ::glow::optimize(F_, CompilationMode::Infer);
2402
2403 // Only 2 nodes should remain (Dequantize -> Save)
2404 EXPECT_EQ(F_->getNodes().size(), 2);
2405
2406 // Check the graph structure
2407 auto *SN = F_->getNodeByName("ret_save");
2408 EXPECT_NE(nullptr, SN);
2409 auto *S = llvm::dyn_cast<SaveNode>(SN);
2410 EXPECT_NE(nullptr, S);
2411 auto *newDN = S->getInput().getNode();
2412 EXPECT_NE(nullptr, newDN);
2413 EXPECT_NE(nullptr, llvm::dyn_cast<DequantizeNode>(newDN));
2414 }
2415
TEST_F(GraphOptz,quantizeToRescale)2416 TEST_F(GraphOptz, quantizeToRescale) {
2417 // Check that we are combining quantization-dequantization pairs.
2418 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
2419 "input", true);
2420
2421 auto *D = F_->createDequantize("dequantize", input, ElemKind::FloatTy);
2422
2423 auto qType = mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.03, 5);
2424 auto *Q = F_->createQuantize("quantize", D, qType);
2425
2426 F_->createSave("ret", Q);
2427
2428 EXPECT_EQ(F_->getNodes().size(), 3);
2429
2430 ::glow::optimize(F_, CompilationMode::Infer);
2431 EXPECT_EQ(F_->getNodes().size(), 2);
2432 }
2433
TEST_F(GraphOptz,MaxOfQuantizedSplat)2434 TEST_F(GraphOptz, MaxOfQuantizedSplat) {
2435 const dim_t size = 5;
2436 const float scale = 1;
2437 // offset == -128 guarantees that fp range has values which are not less than
2438 // 0.
2439 const int32_t offset = -128;
2440
2441 auto splatTy = mod_.uniqueType(ElemKind::Int8QTy, {size}, scale, offset);
2442 auto *splat = F_->createSplat("splat", splatTy, 0.0);
2443
2444 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {size}, scale, offset,
2445 "input", true);
2446
2447 auto *max = F_->createMax("max", splat, input);
2448 F_->createSave("save", max);
2449 EXPECT_EQ(F_->getNodes().size(), 3);
2450
2451 ::glow::optimize(F_, CompilationMode::Infer);
2452 // Splat and Max should be gone.
2453 EXPECT_EQ(F_->getNodes().size(), 1);
2454 }
2455
TEST_F(GraphOptz,FuseRescaleIntoArithmetic)2456 TEST_F(GraphOptz, FuseRescaleIntoArithmetic) {
2457 // This test ensures the fact that fusing of rescale is done.
2458 auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 1, 0);
2459 auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 2, 1);
2460
2461 Placeholder *LHS =
2462 mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.4, 0, "LHS", true);
2463 Placeholder *RHS =
2464 mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.3, 0, "RHS", true);
2465
2466 AddNode *add = F_->createAdd("qAdd", opOutTy, LHS, RHS);
2467 RescaleQuantizedNode *rescaleAdd =
2468 F_->createRescaleQuantized("rsAdd", add, rescaleOutTy);
2469 SaveNode *addSave = F_->createSave("saveAdd", rescaleAdd);
2470
2471 SubNode *sub = F_->createSub("qSub", opOutTy, LHS, RHS);
2472 RescaleQuantizedNode *rescaleSub =
2473 F_->createRescaleQuantized("rsSub", sub, rescaleOutTy);
2474 SaveNode *subSave = F_->createSave("saveSub", rescaleSub);
2475
2476 DivNode *div = F_->createDiv("qDiv", opOutTy, LHS, RHS);
2477 RescaleQuantizedNode *rescaleDiv =
2478 F_->createRescaleQuantized("rsDiv", div, rescaleOutTy);
2479 SaveNode *divSave = F_->createSave("saveDiv", rescaleDiv);
2480
2481 MulNode *mul = F_->createMul("qMul", opOutTy, LHS, RHS);
2482 RescaleQuantizedNode *rescaleMul =
2483 F_->createRescaleQuantized("rsMul", mul, rescaleOutTy);
2484 SaveNode *mulSave = F_->createSave("saveMul", rescaleMul);
2485
2486 MinNode *min = F_->createMin("qMin", opOutTy, LHS, RHS);
2487 RescaleQuantizedNode *rescaleMin =
2488 F_->createRescaleQuantized("rsMin", min, rescaleOutTy);
2489 SaveNode *minSave = F_->createSave("saveMin", rescaleMin);
2490
2491 MaxNode *max = F_->createMax("qMax", opOutTy, LHS, RHS);
2492 RescaleQuantizedNode *rescaleMax =
2493 F_->createRescaleQuantized("rsMax", max, rescaleOutTy);
2494 SaveNode *maxSave = F_->createSave("saveMax", rescaleMax);
2495
2496 // All rescales must be fused into arithmetic operations above.
2497 ::glow::optimize(F_, CompilationMode::Infer);
2498
2499 EXPECT_EQ(F_->getNodes().size(), 12);
2500
2501 EXPECT_EQ(addSave->getInput().getType(), rescaleOutTy);
2502 EXPECT_EQ(subSave->getInput().getType(), rescaleOutTy);
2503 EXPECT_EQ(mulSave->getInput().getType(), rescaleOutTy);
2504 EXPECT_EQ(divSave->getInput().getType(), rescaleOutTy);
2505 EXPECT_EQ(minSave->getInput().getType(), rescaleOutTy);
2506 EXPECT_EQ(maxSave->getInput().getType(), rescaleOutTy);
2507 }
2508
2509 /// Check that the Rescale(MatMul) -> MatMul' optimization works correctly.
TEST_F(GraphOptz,FuseRescaleUpIntoMatMul)2510 TEST_F(GraphOptz, FuseRescaleUpIntoMatMul) {
2511 // This test ensures the fact that fusing of rescale is done.
2512 auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 1, 0);
2513 auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 2, 1);
2514
2515 Placeholder *LHS = mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.4, 0,
2516 "LHS", /* isTrainable */ false);
2517 Placeholder *RHS = mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.3, 0,
2518 "RHS", /* isTrainable */ false);
2519
2520 MatMulNode *MMN = F_->createMatMul("matmul", opOutTy, LHS, RHS);
2521 RescaleQuantizedNode *rescaleMMN =
2522 F_->createRescaleQuantized("rsMMN", MMN, rescaleOutTy);
2523 SaveNode *saveMMN = F_->createSave("saveMMN", rescaleMMN);
2524
2525 // MatMul, Rescale, Save.
2526 EXPECT_EQ(F_->getNodes().size(), 3);
2527
2528 // All rescales must be fused into arithmetic operations above.
2529 ::glow::optimize(F_, CompilationMode::Infer);
2530
2531 // Rescale merged up into the MatMul.
2532 EXPECT_EQ(F_->getNodes().size(), 2);
2533
2534 MatMulNode *newMMN = llvm::dyn_cast<MatMulNode>(saveMMN->getInput());
2535 ASSERT_TRUE(newMMN);
2536 EXPECT_EQ(newMMN->getResult().getType(), rescaleOutTy);
2537 }
2538
2539 /// Check that the Rescale(SparseLengthsWeightedSum) ->
2540 /// SparseLengthsWeightedSum' optimization works correctly.
TEST_F(GraphOptz,FuseRescaleUpIntoSparseLengthsWeightedSum)2541 TEST_F(GraphOptz, FuseRescaleUpIntoSparseLengthsWeightedSum) {
2542 // This test ensures the fact that fusing of rescale is done.
2543 TypeRef rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {4}, 2, 1);
2544
2545 Placeholder *data =
2546 mod_.createPlaceholder(ElemKind::Int8QTy, {3}, 0.5, 0, "data",
2547 /* isTrainable */ false);
2548 Placeholder *weights = mod_.createPlaceholder(
2549 ElemKind::Int8QTy, {8}, 0.5, 0, "weights", /* isTrainable */ false);
2550 Placeholder *indices =
2551 mod_.createPlaceholder(ElemKind::Int64ITy, {8}, "indices",
2552 /* isTrainable */ false);
2553 Placeholder *lengths =
2554 mod_.createPlaceholder(ElemKind::Int32ITy, {4}, "lengths",
2555 /* isTrainable */ false);
2556
2557 SparseLengthsWeightedSumNode *SLWS = F_->createSparseLengthsWeightedSum(
2558 "SLWS", data, weights, indices, lengths);
2559 RescaleQuantizedNode *rescaleSLWS =
2560 F_->createRescaleQuantized("rsSLWS", SLWS, rescaleOutTy);
2561 SaveNode *saveSLWS = F_->createSave("saveSLWS", rescaleSLWS);
2562
2563 // SparseLengthsWeightedSum, Rescale, Save.
2564 EXPECT_EQ(F_->getNodes().size(), 3);
2565
2566 // All rescales must be fused into arithmetic operations above.
2567 ::glow::optimize(F_, CompilationMode::Infer);
2568
2569 // Rescale merged up into the SparseLengthsWeightedSum.
2570 EXPECT_EQ(F_->getNodes().size(), 2);
2571
2572 SparseLengthsWeightedSumNode *newSLWS =
2573 llvm::dyn_cast<SparseLengthsWeightedSumNode>(saveSLWS->getInput());
2574 ASSERT_TRUE(newSLWS);
2575 EXPECT_EQ(newSLWS->getResult().getType(), rescaleOutTy);
2576 }
2577
TEST_F(GraphOptz,fuseRescaleIntoConv)2578 TEST_F(GraphOptz, fuseRescaleIntoConv) {
2579 // This test ensures the fact that fusing of rescale is done.
2580 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {1, 10, 20, 3}, 0.5,
2581 10, "input", true);
2582 auto *filter = mod_.createPlaceholder(ElemKind::Int8QTy, {16, 5, 5, 3}, 0.5,
2583 10, "filter", true);
2584 auto *bias =
2585 mod_.createPlaceholder(ElemKind::Int8QTy, {16}, 0.5, 10, "bias", true);
2586
2587 auto *rInput = F_->createRescaleQuantized(
2588 "rescale", input,
2589 mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 3}, 0.1, -25));
2590 auto *rFilter = F_->createRescaleQuantized(
2591 "rescale", filter,
2592 mod_.uniqueType(ElemKind::Int8QTy, {16, 5, 5, 3}, 0.2, 0));
2593 auto *rBias = F_->createRescaleQuantized(
2594 "rescale", bias, mod_.uniqueType(ElemKind::Int8QTy, {16}, 0.3, 25));
2595 auto *CV = F_->createConv(
2596 "conv", rInput, rFilter, rBias,
2597 mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 16}, 0.7, -3), 5, 1, 2, 1);
2598 auto *rCV = F_->createRescaleQuantized(
2599 "rescale", CV,
2600 mod_.uniqueType(ElemKind::Int8QTy, {1, 10, 20, 16}, 0.4, 37));
2601 F_->createSave("save", rCV);
2602
2603 // All rescales must be fused into convolution.
2604 EXPECT_EQ(F_->getNodes().size(), 6);
2605 ::glow::optimize(F_, CompilationMode::Infer);
2606 EXPECT_EQ(F_->getNodes().size(), 2);
2607 }
2608
2609 /// This test ensures that if there is a Pad node as input of a Convolution
2610 /// node, Pad gets merges into Convolution.
2611 /// Note that Pads is merged into convolution only when it is compatible with
2612 /// the convolution padding:
2613 /// - Resulting padding after merge is positive
2614 /// - Padding only concerns spatial dimensions
2615 /// - Padding has mode 'constant' with value 0.f
fusePadIntoConvTest(glow::Module & mod_,glow::Function * F_,llvm::ArrayRef<dim_t> inputDims,llvm::ArrayRef<int> pads,unsigned_t convKernelSize,llvm::ArrayRef<unsigned_t> convPads,unsigned_t convStride,unsigned_t convNumKernels)2616 void fusePadIntoConvTest(glow::Module &mod_, glow::Function *F_,
2617 llvm::ArrayRef<dim_t> inputDims,
2618 llvm::ArrayRef<int> pads, unsigned_t convKernelSize,
2619 llvm::ArrayRef<unsigned_t> convPads,
2620 unsigned_t convStride, unsigned_t convNumKernels) {
2621 auto *input =
2622 mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", true);
2623
2624 // Pad
2625 dim_t inputWithPadDims[4];
2626 for (int i = 0; i < 4; i++) {
2627 inputWithPadDims[i] = dim_t(ssize_t(inputDims[i]) + pads[i] + pads[4 + i]);
2628 }
2629 dim_t outputConvDims[4] = {
2630 inputWithPadDims[0],
2631 inputWithPadDims[1] + convPads[0] + convPads[2] - (convKernelSize - 1),
2632 inputWithPadDims[2] + convPads[1] + convPads[3] - (convKernelSize - 1),
2633 convNumKernels};
2634
2635 auto outTy = mod_.uniqueType(ElemKind::FloatTy, inputWithPadDims);
2636 Node *P =
2637 F_->createPad("pad", input, outTy, PaddingMode::CONSTANT, pads, 0.f);
2638
2639 // Convolution
2640 dim_t filterDims[] = {convNumKernels, convKernelSize, convKernelSize,
2641 inputDims[3]};
2642 auto *F =
2643 mod_.createPlaceholder(ElemKind::FloatTy, filterDims, "filter", true);
2644 auto *B =
2645 mod_.createPlaceholder(ElemKind::FloatTy, {convNumKernels}, "bias", true);
2646 auto *CV = F_->createConv(
2647 "conv", P, F, B, mod_.uniqueType(ElemKind::FloatTy, outputConvDims),
2648 {convKernelSize, convKernelSize}, {convStride, convStride}, convPads, 1);
2649
2650 SaveNode *O = F_->createSave("save", CV);
2651
2652 // The pad node must be merged into convolution.
2653 EXPECT_EQ(F_->getNodes().size(), 3);
2654 ::glow::optimize(F_, CompilationMode::Infer);
2655 EXPECT_EQ(F_->getNodes().size(), 2);
2656
2657 // Check the graph structure and additional properties after optimization.
2658 auto *conv = llvm::dyn_cast<ConvolutionNode>(O->getInput());
2659 ASSERT_NE(conv, nullptr);
2660 EXPECT_EQ(conv->getResult().dims(), llvm::ArrayRef<dim_t>(outputConvDims));
2661 unsigned_t expectedPads[4];
2662 for (int i = 0; i < 2; i++) {
2663 for (int j = 0; j < 2; j++) {
2664 expectedPads[2 * i + j] =
2665 unsigned_t(int(convPads[2 * i + j]) + pads[4 * i + (1 + j)]);
2666 }
2667 }
2668 EXPECT_EQ(conv->getPads(), llvm::makeArrayRef(expectedPads));
2669 }
2670
TEST_F(GraphOptz,fusePadIntoConv)2671 TEST_F(GraphOptz, fusePadIntoConv) {
2672 fusePadIntoConvTest(mod_, F_, {1, 6, 14, 3} /* inputDims */,
2673 {0, 1, 2, 0, 0, 3, 4, 0} /* pads */,
2674 5 /* convKernelSize */, {0, 0, 0, 0} /* convPads */,
2675 1 /* convStride */, 16 /* convNumKernels */);
2676 }
2677
TEST_F(GraphOptz,fusePadIntoConvNeg1)2678 TEST_F(GraphOptz, fusePadIntoConvNeg1) {
2679 fusePadIntoConvTest(mod_, F_, {1, 6, 14, 3} /* inputDims */,
2680 {0, -1, 2, 0, 0, 3, -2, 0} /* pads */,
2681 5 /* convKernelSize */, {3, 0, 2, 5} /* convPads */,
2682 1 /* convStride */, 16 /* convNumKernels */);
2683 }
2684
TEST_F(GraphOptz,fusePadIntoConvNeg2)2685 TEST_F(GraphOptz, fusePadIntoConvNeg2) {
2686 fusePadIntoConvTest(mod_, F_, {1, 6, 14, 3} /* inputDims */,
2687 {0, 1, -2, 0, 0, -3, 4, 0} /* pads */,
2688 5 /* convKernelSize */, {0, 2, 5, 7} /* convPads */,
2689 1 /* convStride */, 16 /* convNumKernels */);
2690 }
2691
2692 /// This test checks that a lowered LeakyRelu is corrected folded:
2693 /// Max(A, Mult(A, Splat)) -> PRelu(Splat)
TEST_F(GraphFold,foldLeakyReluFromSplat)2694 TEST_F(GraphFold, foldLeakyReluFromSplat) {
2695 std::vector<dim_t> dims = {5, 2};
2696
2697 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", true);
2698
2699 const float leakyAlpha = 0.05f;
2700 auto OutTy = mod_.uniqueType(ElemKind::FloatTy, dims);
2701 SplatNode *splatNode = F_->createSplat("splat", OutTy, leakyAlpha);
2702 MulNode *mulNode = F_->createMul("mul", input, splatNode);
2703 MaxNode *maxNode = F_->createMax("max", input, mulNode);
2704 SaveNode *output = F_->createSave("save", maxNode);
2705
2706 EXPECT_EQ(4, F_->getNodes().size());
2707
2708 CompilationContext cctx;
2709 ::glow::fold(F_, cctx);
2710
2711 // Check the resulting graph after folding.
2712 EXPECT_EQ(3, F_->getNodes().size());
2713 auto *newPReluNode = llvm::dyn_cast<PReluNode>(output->getInput());
2714 ASSERT_TRUE(newPReluNode);
2715 auto *newSplatNode = llvm::dyn_cast<SplatNode>(newPReluNode->getSlope());
2716 ASSERT_TRUE(newSplatNode);
2717 EXPECT_EQ(leakyAlpha, newSplatNode->getValue());
2718 EXPECT_EQ(input, newPReluNode->getInput());
2719 }
2720
2721 /// This test checks that a lowered LeakyRelu is corrected folded:
2722 /// Max(A, Mult(A, broadcasted Const)) -> PRelu(Splat)
TEST_F(GraphFold,foldLeakyReluFromConst)2723 TEST_F(GraphFold, foldLeakyReluFromConst) {
2724 std::vector<dim_t> dims = {5, 2};
2725 auto *input = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", true);
2726
2727 const float leakyAlpha = 0.99f;
2728 auto *alphaConst = mod_.createConstant(ElemKind::FloatTy, {1}, "alphaConst");
2729 alphaConst->getHandle() = {leakyAlpha};
2730 ReshapeNode *reshapeNode = F_->createReshape("reshape", alphaConst, {1, 1});
2731 TileNode *tileNode1 = F_->createTile("tile1", reshapeNode, 2, 1);
2732 TileNode *tileNode2 = F_->createTile("tile2", tileNode1, 5, 0);
2733 MulNode *mulNode = F_->createMul("mul", input, tileNode2);
2734 MaxNode *maxNode = F_->createMax("max", input, mulNode);
2735 SaveNode *output = F_->createSave("save", maxNode);
2736
2737 EXPECT_EQ(6, F_->getNodes().size());
2738
2739 CompilationContext cctx;
2740 ::glow::fold(F_, cctx);
2741
2742 // Check the resulting graph after folding. Reshape must have been merged into
2743 // the constant and LeakyRelu must have been folded.
2744 EXPECT_EQ(3, F_->getNodes().size());
2745 auto *newPReluNode = llvm::dyn_cast<PReluNode>(output->getInput());
2746 ASSERT_TRUE(newPReluNode);
2747 auto *newSplatNode = llvm::dyn_cast<SplatNode>(newPReluNode->getSlope());
2748 ASSERT_TRUE(newSplatNode);
2749 EXPECT_EQ(leakyAlpha, newSplatNode->getValue());
2750 EXPECT_EQ(input, newPReluNode->getInput());
2751 }
2752
2753 /// Testing folding of Reshape->Transpose->Reshape into ChannelShuffle.
TEST_F(GraphFold,foldChannelShuffle)2754 TEST_F(GraphFold, foldChannelShuffle) {
2755 const dim_t inputDims[] = {3, 136, 28, 28};
2756
2757 Node *K =
2758 mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
2759 K = F_->createReshape("CS_reshape1", K, {3, 4, 34, 28, 28});
2760 K = F_->createTranspose("CS_transpose", K, {0, 2, 1, 3, 4});
2761 K = F_->createReshape("CS_reshape2", K, {3, 136, 28, 28});
2762 auto *save = F_->createSave("ret", K);
2763
2764 EXPECT_EQ(F_->getNodes().size(), 4);
2765
2766 // Fold RN->TR->RN into ChannelShuffle
2767 CompilationContext cctx;
2768 ::glow::fold(F_, cctx);
2769
2770 ASSERT_EQ(F_->getNodes().size(), 2);
2771
2772 // Check for ChannelShuffle node.
2773 auto *CS = llvm::dyn_cast<ChannelShuffleNode>(save->getInput().getNode());
2774 ASSERT_NE(nullptr, CS);
2775
2776 // Ensure ChannelShuffle node has the same dimensions as the input.
2777 EXPECT_EQ(CS->getResult().dims(), llvm::makeArrayRef(inputDims));
2778
2779 // Ensure Group and Kernel are as expected.
2780 EXPECT_EQ(CS->getGroup(), 4);
2781 EXPECT_EQ(CS->getKernel(), 1);
2782 }
2783
TEST_F(GraphFold,NoFoldChannelShuffle)2784 TEST_F(GraphFold, NoFoldChannelShuffle) {
2785 auto Float = ElemKind::FloatTy;
2786 auto *P = mod_.createPlaceholder(Float, {10, 8928}, "P", false);
2787 auto *R1 = F_->createReshape("R1", P, {10, 186, 48});
2788 auto *TR = F_->createTranspose("TR", R1, {0, 2, 1});
2789 auto *R2 = F_->createReshape("R2", TR, {480, 186});
2790 auto *save = F_->createSave("save", R2);
2791
2792 EXPECT_EQ(F_->getNodes().size(), 4);
2793
2794 CompilationContext cctx;
2795 ::glow::fold(F_, cctx);
2796
2797 EXPECT_EQ(F_->getNodes().size(), 4);
2798 EXPECT_FALSE(llvm::isa<ChannelShuffleNode>(save->getInput()));
2799 }
2800
2801 class MockBackendWithFusion : public MockBackend {
supportsFusedActivation(Node * parent,Node * activation) const2802 bool supportsFusedActivation(Node *parent, Node *activation) const override {
2803 switch (parent->getKind()) {
2804 case Kinded::Kind::ConvolutionNodeKind:
2805 switch (activation->getKind()) {
2806 case Kinded::Kind::ReluNodeKind:
2807 case Kinded::Kind::SigmoidNodeKind:
2808 case Kinded::Kind::TanhNodeKind:
2809 return true;
2810 default:
2811 return false;
2812 }
2813 default:
2814 return false;
2815 }
2816 }
2817 };
2818
2819 #define CONV_ACTIVATION_TEST(ACTIVATION_, CREATOR_) \
2820 TEST_F(GraphFold, FoldConv##ACTIVATION_##Activation) { \
2821 auto *A = \
2822 mod_.createPlaceholder(ElemKind::FloatTy, {1, 10, 20, 3}, "A", false); \
2823 ConvolutionNode *CV = \
2824 F_->createConv(bindings_, "conv", A, 16, 5, 1, 2, 1); \
2825 auto *AN = F_->CREATOR_(#ACTIVATION_, CV); \
2826 SaveNode *SN = F_->createSave("ret", AN); \
2827 \
2828 EXPECT_EQ(F_->getNodes().size(), 3); \
2829 \
2830 CompilationContext cctx; \
2831 auto B = MockBackendWithFusion(); \
2832 ::glow::fold(F_, cctx, &B); \
2833 \
2834 ConvolutionNode *fusedCV = \
2835 llvm::dyn_cast<ConvolutionNode>(SN->getInput()); \
2836 ASSERT_TRUE(fusedCV); \
2837 EXPECT_EQ(fusedCV->getFusedActivation(), FusedActivation::ACTIVATION_); \
2838 }
2839
2840 CONV_ACTIVATION_TEST(RELU, createRELU);
2841 CONV_ACTIVATION_TEST(SIGMOID, createSigmoid);
2842 CONV_ACTIVATION_TEST(TANH, createTanh);
2843
2844 #undef CONV_ACTIVATION_TEST
2845
2846 /// This test ensures that if there is a RescaleNode whose input has multiple
2847 /// users that the input is not cloned, as this duplicates the node.
TEST_F(GraphOptz,MultipleUsersRescaleCombineNoOpt)2848 TEST_F(GraphOptz, MultipleUsersRescaleCombineNoOpt) {
2849 auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 1, 0);
2850 auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10}, 2, 1);
2851
2852 Node *LHS =
2853 mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.4, 0, "LHS", true);
2854 Node *RHS =
2855 mod_.createPlaceholder(ElemKind::Int8QTy, {10}, 0.3, 0, "RHS", true);
2856
2857 AddNode *AN = F_->createAdd("qAdd", opOutTy, LHS, RHS);
2858 RescaleQuantizedNode *RQN =
2859 F_->createRescaleQuantized("rsAdd", AN, rescaleOutTy);
2860 SaveNode *saveRQN = F_->createSave("saveRQN", RQN);
2861 SaveNode *saveAN = F_->createSave("saveAN", AN);
2862
2863 EXPECT_EQ(F_->getNodes().size(), 4);
2864
2865 ::glow::optimize(F_, CompilationMode::Infer);
2866
2867 // The graph should be unchanged.
2868 EXPECT_EQ(F_->getNodes().size(), 4);
2869 EXPECT_EQ(saveRQN->getInput().getNode(), RQN);
2870 EXPECT_EQ(RQN->getInput().getNode(), AN);
2871 EXPECT_EQ(saveAN->getInput().getNode(), AN);
2872 EXPECT_EQ(AN->getLHS().getNode(), LHS);
2873 EXPECT_EQ(AN->getRHS().getNode(), RHS);
2874 }
2875
2876 /// This test ensures that fusing of rescale into MatMul is done.
TEST_F(GraphOptz,FuseRescaleIntoMatMul)2877 TEST_F(GraphOptz, FuseRescaleIntoMatMul) {
2878 auto opOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 1, 0);
2879 auto rescaleOutTy = mod_.uniqueType(ElemKind::Int8QTy, {10, 10}, 2, 1);
2880
2881 Placeholder *LHS =
2882 mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.4, 0, "LHS", true);
2883 Placeholder *RHS =
2884 mod_.createPlaceholder(ElemKind::Int8QTy, {10, 10}, 0.3, 0, "RHS", true);
2885
2886 RescaleQuantizedNode *LHSR =
2887 F_->createRescaleQuantized("rs1", LHS, rescaleOutTy);
2888 RescaleQuantizedNode *RHSR =
2889 F_->createRescaleQuantized("rs2", RHS, rescaleOutTy);
2890 MatMulNode *MN = F_->createMatMul("qMatMul", opOutTy, LHSR, RHSR);
2891 SaveNode *SN = F_->createSave("save", MN);
2892
2893 // All rescales must be fused into arithmetic operations above.
2894 ::glow::optimize(F_, CompilationMode::Infer);
2895
2896 // Only the MatMul and Save should be left.
2897 EXPECT_EQ(F_->getNodes().size(), 2);
2898 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::RescaleQuantizedNodeKind), 0);
2899
2900 MatMulNode *newMN = llvm::dyn_cast<MatMulNode>(SN->getInput());
2901 ASSERT_TRUE(newMN);
2902 Placeholder *LPH = llvm::dyn_cast<Placeholder>(newMN->getLHS());
2903 EXPECT_EQ(LPH, LHS);
2904 Placeholder *RPH = llvm::dyn_cast<Placeholder>(newMN->getRHS());
2905 EXPECT_EQ(RPH, RHS);
2906 }
2907
TEST_F(GraphOptz,sinkRescaledQuantizedNode)2908 TEST_F(GraphOptz, sinkRescaledQuantizedNode) {
2909 // Check that we eliminate rescale nodes by sinking them into other
2910 // operators.
2911 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
2912 "input", true);
2913
2914 // slice -> rescale -> reshape -> rescale -> transpose -> maxpool -> save.
2915 auto *slice = F_->createSlice("slice", input, {0, 0}, {2, 4});
2916 auto *rescale = F_->createRescaleQuantized(
2917 "rescale", slice, mod_.uniqueType(ElemKind::Int8QTy, {2, 4}, 0.4, 10));
2918 auto *reshape = F_->createReshape("reshape", rescale, {1, 2, 2, 2});
2919 auto *rescale2 = F_->createRescaleQuantized(
2920 "rescale", reshape,
2921 mod_.uniqueType(ElemKind::Int8QTy, {1, 2, 2, 2}, 0.3, 9));
2922 auto *transpose = F_->createTranspose("transpose", rescale2, {0, 2, 3, 1});
2923 auto *maxpool =
2924 F_->createMaxPool("maxpool", transpose, {2, 2}, {1, 1}, {0, 0, 0, 0});
2925 auto *save = F_->createSave("ret", maxpool->getResult());
2926
2927 EXPECT_EQ(F_->getNodes().size(), 7);
2928 ::glow::optimize(F_, CompilationMode::Infer);
2929 EXPECT_EQ(F_->getNodes().size(), 6);
2930 // Check that rescale sank all the way down to the save node.
2931 EXPECT_TRUE(llvm::dyn_cast<RescaleQuantizedNode>(save->getInput()));
2932 }
2933
TEST_F(GraphOptz,mergeRescaleWithArithmeticNode)2934 TEST_F(GraphOptz, mergeRescaleWithArithmeticNode) {
2935 // Check that Arithmetic operations can be merged with the Rescale.
2936 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
2937 "input", true);
2938
2939 auto *rescale1 = F_->createRescaleQuantized(
2940 "rescale", input, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.4, 11));
2941 auto *add = F_->createAdd("add", rescale1, rescale1);
2942 auto *rescale2 = F_->createRescaleQuantized(
2943 "rescale", add, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.3, 11));
2944 auto *sub = F_->createSub("sub", rescale2, rescale2);
2945 auto *rescale3 = F_->createRescaleQuantized(
2946 "rescale", sub, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.2, 11));
2947 auto *mul = F_->createMul("mul", rescale3, rescale3);
2948 auto *rescale4 = F_->createRescaleQuantized(
2949 "rescale", mul, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.1, 11));
2950 auto *div = F_->createDiv("div", rescale4, rescale4);
2951 F_->createSave("save", div);
2952
2953 EXPECT_EQ(F_->getNodes().size(), 9);
2954 ::glow::optimize(F_, CompilationMode::Infer);
2955 EXPECT_EQ(F_->getNodes().size(), 5);
2956 }
2957
2958 /// Check that Relu can be merged with Rescale.
TEST_F(GraphOptz,mergeRescaleWithRelu)2959 TEST_F(GraphOptz, mergeRescaleWithRelu) {
2960 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {4, 10}, 0.5, 11,
2961 "input", false);
2962
2963 auto *rescale1 = F_->createRescaleQuantized(
2964 "rescale", input, mod_.uniqueType(ElemKind::Int8QTy, {4, 10}, 0.4, 11));
2965 auto *relu = F_->createRELU("relu", rescale1);
2966 F_->createSave("save", relu);
2967
2968 // Rescale, RELU, Save nodes.
2969 EXPECT_EQ(F_->getNodes().size(), 3);
2970
2971 ::glow::optimize(F_, CompilationMode::Infer);
2972
2973 // RELU, Save nodes left; Rescale merged into RELU.
2974 EXPECT_EQ(F_->getNodes().size(), 2);
2975 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::RescaleQuantizedNodeKind), 0);
2976 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 1);
2977 }
2978
2979 // Check that we are able to merge some small matmuls into a larger one.
TEST_F(GraphOptz,mergeMatMulNodes)2980 TEST_F(GraphOptz, mergeMatMulNodes) {
2981 Node *input =
2982 mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input", true);
2983 Node *weight =
2984 mod_.createPlaceholder(ElemKind::FloatTy, {10, 10}, "weight", true);
2985
2986 // Split the input to a bunch of small slices.
2987 std::vector<NodeValue> inputs;
2988 for (dim_t i = 0; i < 10; i++) {
2989 auto *K = F_->createSlice("extract", input, {i, 0, 0}, {i + 1, 10, 10});
2990 auto *R = F_->createReshape("reshape", K, {10, 10});
2991 auto *MM = F_->createMatMul("mm", R, weight);
2992 inputs.push_back(MM);
2993 }
2994
2995 auto *cc = F_->createConcat("merge", inputs, 0);
2996 F_->createSave("save", cc);
2997
2998 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 10);
2999 ::glow::optimize(F_, CompilationMode::Infer);
3000
3001 // Check that all of the matmuls are merged into a single matmul node.
3002 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 1);
3003 }
3004
3005 // Check that we are able to merge batched adds.
TEST_F(GraphOptz,mergeBANodes)3006 TEST_F(GraphOptz, mergeBANodes) {
3007 Node *input =
3008 mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input", true);
3009 Node *slice =
3010 mod_.createPlaceholder(ElemKind::FloatTy, {10, 10}, "weight", true);
3011
3012 // Split the input to a bunch of small slices.
3013 std::vector<NodeValue> inputs;
3014 for (dim_t i = 0; i < 10; i++) {
3015 auto *K = F_->createSlice("extract", input, {i, 0, 0}, {i + 1, 10, 10});
3016 auto *MM = F_->createBatchedAdd("BA", K, slice);
3017 inputs.push_back(MM);
3018 }
3019
3020 auto *cc = F_->createConcat("merge", inputs, 0);
3021 F_->createSave("save", cc);
3022
3023 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 10);
3024 ::glow::optimize(F_, CompilationMode::Infer);
3025
3026 // Check that all of the batched-adds are merged into a single node.
3027 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 1);
3028 }
3029
3030 /// Check that EliminateNoop optimization pass removes nodes which don't do
3031 /// anything useful.
TEST_F(GraphOptz,eliminateNoop)3032 TEST_F(GraphOptz, eliminateNoop) {
3033 std::vector<dim_t> shape = {1, 2, 2, 3};
3034 Placeholder *input1 = mod_.createPlaceholder(ElemKind::Int8QTy, shape, 0.004,
3035 0, "input", false);
3036 Placeholder *input2 = mod_.createPlaceholder(ElemKind::Int8QTy, shape, 0.004,
3037 0, "input", false);
3038 auto *cond = mod_.createConstant(ElemKind::BoolTy, shape, "input1");
3039 cond->getHandle<bool>() = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
3040
3041 auto *select = F_->createSelect("select", cond, input1, input2);
3042 auto *slice = F_->createSlice("slice", select, {0, 0, 0, 0}, shape);
3043 auto *tile = F_->createTile("tile", slice, 1, 1);
3044 auto *pad = F_->createPad("pad", tile, tile->getResult().getType(), 0,
3045 {0, 0, 0, 0, 0, 0, 0, 0}, 0);
3046 auto *avgPool = F_->createAvgPool("avgpool", pad, 1, 1, 0);
3047 auto *maxPool = F_->createMaxPool("maxpool", avgPool, 1, 1, 0);
3048
3049 F_->createSave("save", maxPool->getResult());
3050
3051 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SelectNodeKind), 1);
3052 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 1);
3053 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TileNodeKind), 1);
3054 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::PadNodeKind), 1);
3055 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AvgPoolNodeKind), 1);
3056 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MaxPoolNodeKind), 1);
3057
3058 optimizedF_ = optimizeFunction(F_);
3059
3060 // Check that all nodes except for Save are eliminated.
3061 EXPECT_EQ(optimizedF_->getNodes().size(), 1);
3062
3063 bindings_.allocate(mod_.getPlaceholders());
3064 bindings_.get(input1)->getHandle<int8_t>().randomize(-1.0, 1.0,
3065 mod_.getPRNG());
3066 bindings_.get(input2)->getHandle<int8_t>().randomize(-1.0, 1.0,
3067 mod_.getPRNG());
3068
3069 checkNumericalEquivalence();
3070 }
3071
3072 // Check that we are able to replace
3073 // Add(I, tile(B)) with -> BatchedAdd(I, B).
TEST_F(GraphOptz,FoldTileAddIntoBatchedAdd)3074 TEST_F(GraphOptz, FoldTileAddIntoBatchedAdd) {
3075 auto *batch =
3076 mod_.createPlaceholder(ElemKind::FloatTy, {3, 1, 2}, "batch", false);
3077 auto *added = mod_.createConstant(ElemKind::FloatTy, {1, 1, 2}, "added");
3078 auto *addedTiled = F_->createTile("addedTiled", added, 3, 0);
3079 auto *add = F_->createAdd("add", batch, addedTiled);
3080 auto *save = F_->createSave("save", add);
3081 auto *output = save->getPlaceholder();
3082
3083 bindings_.allocate(batch)->getHandle() = {2, 2, 3, 3, 4, 4};
3084 added->getPayloadMutable().getHandle() = {1, 1};
3085 bindings_.allocate(output);
3086
3087 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TileNodeKind), 1);
3088 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 1);
3089 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 0);
3090
3091 ASSERT_TRUE(F_->verify());
3092
3093 // Currently the FoldTileAddIntoBatchedAdd opt which we're testing here is not
3094 // part of the default optimization pipeline. Create a local version of the
3095 // pipeline with that pass included.
3096 auto p = createDefaultGraphOptimizationPassPipeline();
3097 p->pushFront({FunctionPassID::FoldTileAddIntoBatchedAdd});
3098 FunctionPassManager FPM("opt", std::move(p));
3099 FPM.run(F_, CompilationContext());
3100 ASSERT_TRUE(F_->verify());
3101
3102 // Check that the Tile node and the Add node is replaced by
3103 // a BatchedAdd node.
3104 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::TileNodeKind), 0);
3105 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 0);
3106 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchedAddNodeKind), 1);
3107
3108 // Verify the correctness of the input to BatchedAdd operator.
3109 // The correctness of BatchedAdd operator itself is verified
3110 // by operator's unit tests.
3111 Tensor expectedBatch(ElemKind::FloatTy, {3, 1, 2});
3112 expectedBatch.getHandle() = {2, 2, 3, 3, 4, 4};
3113 Tensor expectedSlice(ElemKind::FloatTy, {1, 2});
3114 expectedSlice.getHandle() = {1, 1};
3115 for (auto &node : F_->getNodes()) {
3116 auto *recvdBANode = llvm::dyn_cast<BatchedAddNode>(&node);
3117 if (!recvdBANode) {
3118 continue;
3119 }
3120 auto *recvdBatch = llvm::dyn_cast<Placeholder>(recvdBANode->getBatch());
3121 ASSERT_TRUE(recvdBatch);
3122 auto *recvdSlice = llvm::dyn_cast<Constant>(recvdBANode->getSlice());
3123 ASSERT_TRUE(recvdSlice);
3124 EXPECT_TRUE(recvdBatch->dims().equals({3, 1, 2}));
3125 EXPECT_TRUE(recvdSlice->dims().equals({1, 2}));
3126 EXPECT_TRUE(bindings_.get(recvdBatch)->isEqual(expectedBatch));
3127 EXPECT_TRUE(recvdSlice->getPayload().isEqual(expectedSlice));
3128 break;
3129 }
3130 }
3131
3132 /// Test Concat(Slice, ..., Slice) opt works correctly. If \p reverseOrder then
3133 /// the optimization is inapplicable and should not occur.
testConcatElim(Module & mod,Function * F,Function * & optimizedF,PlaceholderBindings & bindings,bool reverseOrder)3134 static void testConcatElim(Module &mod, Function *F, Function *&optimizedF,
3135 PlaceholderBindings &bindings, bool reverseOrder) {
3136 Placeholder *input =
3137 mod.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input", true);
3138 bindings.allocate(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
3139
3140 // Split the input to a bunch of small slices.
3141 std::array<NodeValue, 10> inputs;
3142 for (dim_t i = 0; i < 10; i++) {
3143 dim_t idx = reverseOrder ? 9 - i : i;
3144 inputs[i] =
3145 F->createSlice("extract", input, {idx, 0, 0}, {idx + 1, 10, 10});
3146 }
3147
3148 auto *cc = F->createConcat("merge", inputs, 0);
3149 F->createSave("save", cc);
3150
3151 EXPECT_EQ(countNodeKind(F, Kinded::Kind::SliceNodeKind), 10);
3152
3153 optimizedF = optimizeFunction(F);
3154
3155 // Check that either the concat and slices are gone if the optimization was
3156 // applicable, or otherwise that they're still there.
3157 EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::ConcatNodeKind),
3158 reverseOrder ? 1 : 0);
3159 EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::SliceNodeKind),
3160 reverseOrder ? 10 : 0);
3161 }
3162
3163 // Check that we are able to eliminate concat nodes.
TEST_F(GraphOptz,concatElim)3164 TEST_F(GraphOptz, concatElim) {
3165 testConcatElim(mod_, F_, optimizedF_, bindings_, /* reverseOrder */ false);
3166 checkNumericalEquivalence(0.0f);
3167 }
3168
3169 // Check that when the order of the Slices is reversed no optimization kicks in.
TEST_F(GraphOptz,concatElimReverseOrder)3170 TEST_F(GraphOptz, concatElimReverseOrder) {
3171 testConcatElim(mod_, F_, optimizedF_, bindings_, /* reverseOrder */ true);
3172 checkNumericalEquivalence(0.0f);
3173 }
3174
3175 /// Check that we are able to eliminate concat nodes with redundant arithmetic
3176 /// ops in way.
TEST_F(GraphOptz,concatArithElim)3177 TEST_F(GraphOptz, concatArithElim) {
3178 auto *input =
3179 mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input", true);
3180 bindings_.allocate(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3181
3182 Type t(ElemKind::FloatTy, {1, 10, 10});
3183 Node *one = F_->createSplat("one", &t, 1.0);
3184 Node *zero = F_->createSplat("zero", &t, 0.0);
3185
3186 // Split the input to a bunch of small slices.
3187 std::vector<NodeValue> inputs;
3188 for (dim_t i = 0; i < 10; i++) {
3189 auto *K = F_->createSlice("extract", input, {i, 0, 0}, {i + 1, 10, 10});
3190 // Insert the nodes in reverse order to make sure that we can catch
3191 // non-consecutive graph-order slices.
3192 Node *N = K;
3193 switch (i) {
3194 case 0:
3195 N = F_->createAdd("add0", K, zero);
3196 break;
3197 case 1:
3198 N = F_->createSub("sub0", K, zero);
3199 break;
3200 case 2:
3201 N = F_->createAdd("add_0", zero, K);
3202 break;
3203 case 3:
3204 N = F_->createMul("mul1", K, one);
3205 break;
3206 case 4:
3207 N = F_->createDiv("div1", K, one);
3208 break;
3209 case 5:
3210 N = F_->createMul("mul_1", one, K);
3211 break;
3212 default:
3213 break;
3214 }
3215 inputs.push_back(N);
3216 }
3217
3218 auto *cc = F_->createConcat("merge", inputs, 0);
3219 F_->createSave("save", cc);
3220 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 10);
3221 optimizedF_ = optimizeFunction(F_);
3222
3223 // Check that the concat node is gone.
3224 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 0);
3225 checkNumericalEquivalence(0.0f);
3226 }
3227
3228 /// Check that we are able to eliminate concat followed by slices on axis
3229 /// \p dim under certain conditions.
testConcatSliceElim(Module & mod,Function * F,Function * & optimizedF,PlaceholderBindings & bindings,size_t dim)3230 static void testConcatSliceElim(Module &mod, Function *F, Function *&optimizedF,
3231 PlaceholderBindings &bindings, size_t dim) {
3232 constexpr size_t N = 5;
3233 std::array<NodeValue, N> inputs;
3234 std::vector<dim_t> inShape = {10, 20};
3235 inShape.insert(inShape.begin() + dim, 0);
3236 for (dim_t i = 0; i < N; i++) {
3237 inShape[dim] = 1 + i;
3238 auto *P = mod.createPlaceholder(ElemKind::FloatTy, inShape, "in", true);
3239 bindings.allocate(P)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
3240 inputs[i] = P;
3241 }
3242 auto *CN = F->createConcat("merge", inputs, dim);
3243
3244 // Split the concat to a bunch of slices of the same shape as the concat
3245 // inputs and on the same axis.
3246 std::vector<dim_t> startShape = {0, 0, 0};
3247 std::vector<dim_t> endShape = {10, 20};
3248 endShape.insert(endShape.begin() + dim, 0);
3249 for (dim_t i = 0; i < N; i++) {
3250 startShape[dim] = (i * (i + 1)) / 2;
3251 endShape[dim] = ((i + 1) * (i + 2)) / 2;
3252 auto *SN = F->createSlice("extract", CN, startShape, endShape);
3253 F->createSave("save", SN);
3254 }
3255
3256 // We created a concat followed by N slices of its results.
3257 EXPECT_EQ(countNodeKind(F, Kinded::Kind::SliceNodeKind), N);
3258 EXPECT_EQ(countNodeKind(F, Kinded::Kind::ConcatNodeKind), 1);
3259
3260 optimizedF = optimizeFunction(F);
3261
3262 // Check that the concat and slices are gone.
3263 EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::ConcatNodeKind), 0);
3264 EXPECT_EQ(countNodeKind(optimizedF, Kinded::Kind::SliceNodeKind), 0);
3265 }
3266
TEST_F(GraphOptz,concatSliceElimInnerDim)3267 TEST_F(GraphOptz, concatSliceElimInnerDim) {
3268 testConcatSliceElim(mod_, F_, optimizedF_, bindings_, 0);
3269 checkNumericalEquivalence(0.0f);
3270 }
3271
TEST_F(GraphOptz,concatSliceElimMiddleDim)3272 TEST_F(GraphOptz, concatSliceElimMiddleDim) {
3273 testConcatSliceElim(mod_, F_, optimizedF_, bindings_, 1);
3274 checkNumericalEquivalence(0.0f);
3275 }
3276
TEST_F(GraphOptz,concatSliceElimOuterDim)3277 TEST_F(GraphOptz, concatSliceElimOuterDim) {
3278 testConcatSliceElim(mod_, F_, optimizedF_, bindings_, 2);
3279 checkNumericalEquivalence(0.0f);
3280 }
3281
3282 /// Check the interaction between Sices(Concat) and Concat(Slices) optimizations
3283 /// to make sure they work nicely together. Builds Concat(Slices(Concat)) and
3284 /// expected a single Concat after optimizations.
TEST_F(GraphOptz,concatSliceElimMultiConcat)3285 TEST_F(GraphOptz, concatSliceElimMultiConcat) {
3286 std::array<NodeValue, 4> inputs;
3287 for (size_t i = 0; i < 4; i++) {
3288 auto *P = mod_.createPlaceholder(ElemKind::FloatTy, {2, 4},
3289 "in_" + std::to_string(i), false);
3290 bindings_.allocate(P)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3291 inputs[i] = P;
3292 }
3293 auto *CN0 = F_->createConcat("merge0", inputs, /* axis */ 1);
3294
3295 auto *SN0 = F_->createSlice("slice0", CN0, {0, 0}, {2, 4});
3296 auto *SN1 = F_->createSlice("slice1", CN0, {0, 4}, {2, 8});
3297 auto *SN2 = F_->createSlice("slice2", CN0, {0, 8}, {2, 12});
3298 auto *SN3 = F_->createSlice("slice3", CN0, {0, 12}, {2, 16});
3299
3300 auto *CN1 = F_->createConcat("merge1", {SN1, SN0, SN3, SN2}, /* axis */ 1);
3301 F_->createSave("save", CN1);
3302
3303 // We created a concat followed by 4 slices of its results followed by another
3304 // concat.
3305 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 2);
3306 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), 4);
3307
3308 optimizedF_ = optimizeFunction(F_);
3309
3310 // Check that one concat and slices are gone.
3311 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
3312 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SliceNodeKind), 0);
3313
3314 checkNumericalEquivalence(0.0f);
3315 }
3316
3317 // Check the transformation Concat(Reshape(x) * N) -> Reshape(Concat(x * N)).
TEST_F(GraphOptz,concatReshapes)3318 TEST_F(GraphOptz, concatReshapes) {
3319 const dim_t shape1[] = {2, 5, 2, 1, 20};
3320 const dim_t shape2[] = {10, 2, 2, 10};
3321 const dim_t shape3[] = {5, 80};
3322 llvm::SmallVector<NodeValue, 10> inputs1;
3323 llvm::SmallVector<NodeValue, 10> inputs2;
3324 for (size_t i = 0; i < 10; i++) {
3325 // 10 reshape nodes that transform from {2,5,2,1,20} to {10,2,2,10}.
3326 // And a ConcatNode concatenates the outputs of reshape at 2nd dim.
3327 // The optimization would kick in, as the size of trailing dimensions of
3328 // original ConcatNode (before opt) is 20, and the size of leading
3329 // dimensions of original ConcatNode (before opt) is 10.
3330 Node *var = F_->getParent()->createPlaceholder(
3331 ElemKind::FloatTy, shape1, "input" + std::to_string(i), true);
3332 auto *RN = F_->createReshape("reshape" + std::to_string(i), var, shape2);
3333 inputs1.push_back(RN);
3334 }
3335 auto *concatNode1 = F_->createConcat("concat", inputs1, 1);
3336 for (size_t i = 0; i < 10; i++) {
3337 // 10 reshape nodes that transform from {5,80} to {10,1,2,10}.
3338 // And a ConcatNode concatenates the outputs of reshape at 2nd dim.
3339 // The optimization would NOT kick in, as we cannot find the dim that
3340 // makes the leading/trailing dims same as in the case of the original
3341 // concat node.
3342 Node *var = F_->getParent()->createPlaceholder(
3343 ElemKind::FloatTy, shape3, "input" + std::to_string(i), true);
3344 auto *RN = F_->createReshape("reshape" + std::to_string(i), var, shape2);
3345 inputs2.push_back(RN);
3346 }
3347 auto *concatNode2 = F_->createConcat("concat", inputs2, 1);
3348 auto outputShape = concatNode1->getResult().dims();
3349 // Need to dereference the RN vectors, otherwise the user number of those
3350 // nodes would always be positive, making them unable to be removed by DCE.
3351 inputs1.clear();
3352 inputs2.clear();
3353
3354 auto *addNode = F_->createAdd("add", concatNode1, concatNode2);
3355 auto *O = F_->createSave("ret", addNode);
3356
3357 EXPECT_EQ(F_->getNodes().size(), 24);
3358
3359 ::glow::optimize(F_, CompilationMode::Infer);
3360
3361 // After optimization, we expect to see only 15 nodes. All 10 of the
3362 // reshapes that were the inputs to the first original concat node
3363 // (concatNode1) are removed, and a single new reshape is added after the
3364 // new concat.
3365 EXPECT_EQ(F_->getNodes().size(), 15);
3366
3367 // concatNode1 should not exist any more.
3368 EXPECT_FALSE(functionContainsNode(F_, concatNode1));
3369 // concatNode2 should still exist.
3370 EXPECT_TRUE(functionContainsNode(F_, concatNode2));
3371
3372 // The first input of addNode should be a Reshape node now, with the same
3373 // result shape of concatNode1.
3374 auto *newAddNode = llvm::dyn_cast<AddNode>(O->getInput());
3375 ASSERT_TRUE(newAddNode);
3376 auto *newRN = llvm::dyn_cast<ReshapeNode>(newAddNode->getLHS());
3377 ASSERT_TRUE(newRN);
3378 EXPECT_TRUE(newRN->getResult().getType()->dims().equals(outputShape));
3379
3380 // The input of newRN should be a ConcatNode now.
3381 auto *newCN = llvm::dyn_cast<ConcatNode>(newRN->getInput());
3382 ASSERT_TRUE(newCN);
3383 }
3384
3385 // Making sure we do not try to to optimize concat2(dim1, concat1(dim2, X, Y),
3386 // Z)
3387 // -> concat(dim1, X, Y, Z) when concat1 has multiple users.
TEST_F(GraphOptz,ConcatSimplificationNegative)3388 TEST_F(GraphOptz, ConcatSimplificationNegative) {
3389 const dim_t dim1[] = {1, 4, 4, 4};
3390 const dim_t dim2[] = {1, 4, 4, 8};
3391 auto *in1 = mod_.createPlaceholder(ElemKind::FloatTy, dim1, "in1", false);
3392 auto *in2 = mod_.createPlaceholder(ElemKind::FloatTy, dim1, "in2", false);
3393 auto *in3 = mod_.createPlaceholder(ElemKind::FloatTy, dim2, "in3", false);
3394
3395 auto *cnc1 = F_->createConcat("cnc1", {in1, in2}, 3);
3396 auto *add1 = F_->createAdd("add1", in3, cnc1);
3397 auto *cnc2 = F_->createConcat("cnc2", {add1, cnc1}, 3);
3398 F_->createSave("ret", cnc2);
3399 EXPECT_EQ(F_->getNodes().size(), 4);
3400 ::glow::optimize(F_, CompilationMode::Infer);
3401 EXPECT_EQ(F_->getNodes().size(), 4);
3402 for (auto &n : F_->getNodes()) {
3403 if (auto *tcnc = llvm::dyn_cast<ConcatNode>(&n)) {
3404 EXPECT_EQ(tcnc->getNumInputs(), 2);
3405 }
3406 }
3407 }
3408
3409 /// Check that Variable CSE works correctly, combining small Variables that
3410 /// have the same data.
TEST_F(GraphOptz,VarsCSE)3411 TEST_F(GraphOptz, VarsCSE) {
3412 // Create three variables that are Private, are not trainable, and have no
3413 // writers. The first two variables have the same data, and so should be
3414 // combined via variable CSE. The third variable differs by the last value,
3415 // and so should not be combined.
3416 auto *input1 = mod_.createConstant(ElemKind::FloatTy, {10}, "input1");
3417 auto *input2 = mod_.createConstant(ElemKind::FloatTy, {10}, "input2");
3418 auto *input3 = mod_.createConstant(ElemKind::FloatTy, {10}, "input3");
3419 input1->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
3420 input2->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
3421 input3->getHandle() = {0, 1, 2, 3, 4, 5, 6, 7, 8, -1};
3422
3423 // Input them each to different nodes, so node CSE does not change them.
3424 auto *TN = F_->createTanh("tanh", input1);
3425 auto *SN = F_->createSigmoid("sigmoid", input2);
3426 auto *RN = F_->createRELU("relu", input3);
3427 auto *CN = F_->createConcat("concat", {TN, SN, RN}, /* axis */ 0);
3428 F_->createSave("ret", CN);
3429
3430 // Initially there are three variables: inputs 1, 2, and 3 (the save uses a
3431 // placeholder).
3432 EXPECT_EQ(mod_.getConstants().size(), 3);
3433
3434 cctx_.compMode = CompilationMode::Infer;
3435 // Do not perform any compile-time constant folding.
3436 cctx_.optimizationOpts.enableConstantFolding = false;
3437 ::glow::optimize(F_, cctx_);
3438
3439 // Now only two variables are left; input1 and input2 have been combined,
3440 // but input3 has not.
3441 EXPECT_EQ(mod_.getConstants().size(), 2);
3442
3443 // Verify that only one of input1 and input2 exists, and that input3 still
3444 // exists.
3445 Constant *varOneOrTwo = nullptr;
3446 bool foundVarThree = false;
3447 for (auto *V : mod_.getConstants()) {
3448 if (V == input1 || V == input2) {
3449 EXPECT_TRUE(varOneOrTwo == nullptr);
3450 varOneOrTwo = V;
3451 } else if (V == input3) {
3452 foundVarThree = true;
3453 }
3454 }
3455 EXPECT_TRUE(varOneOrTwo != nullptr);
3456 EXPECT_TRUE(foundVarThree);
3457
3458 // Verify that the users of the inputs are updated correctly.
3459 EXPECT_TRUE(TN->getInput().getNode() == varOneOrTwo);
3460 EXPECT_TRUE(SN->getInput().getNode() == varOneOrTwo);
3461 EXPECT_TRUE(RN->getInput().getNode() == input3);
3462
3463 // Verify that whichever input1/input2 is left over has two users TN and SN.
3464 EXPECT_TRUE(varOneOrTwo->getUsers().size() == 2);
3465 for (auto &U : varOneOrTwo->getUsers()) {
3466 auto *N = U.getUser();
3467 EXPECT_TRUE(N == TN || N == SN);
3468 }
3469
3470 // Verify that input3 only has a single user RN.
3471 ASSERT_TRUE(input3->getUsers().size() == 1);
3472 EXPECT_TRUE(input3->getUsers().begin()->getUser() == RN);
3473 }
3474
TEST_F(GraphOptz,VarsCSENaN)3475 TEST_F(GraphOptz, VarsCSENaN) {
3476 // Create two variables that are Private, are not trainable, have no writers
3477 // and include NaNs. The first two variables have the same data, and so should
3478 // be combined via variable CSE. In particular, the NaN constants should not
3479 // prevent the variables from being combine.
3480 auto *input1 = mod_.createConstant(ElemKind::FloatTy, {5}, "input1");
3481 auto *input2 = mod_.createConstant(ElemKind::FloatTy, {5}, "input2");
3482 input1->getHandle() = {0, NAN, 2, NAN, 4};
3483 input2->getHandle() = {0, NAN, 2, NAN, 4};
3484
3485 // Input them each to different nodes, so node CSE does not change them.
3486 auto *TN = F_->createTanh("tanh", input1);
3487 auto *SN = F_->createSigmoid("sigmoid", input2);
3488 auto *CN = F_->createConcat("concat", {TN, SN}, /* axis */ 0);
3489 F_->createSave("ret", CN);
3490
3491 // Initially there are two variables: inputs 1 and 2 (the save uses a
3492 // placeholder).
3493 EXPECT_EQ(mod_.getConstants().size(), 2);
3494
3495 cctx_.compMode = CompilationMode::Infer;
3496 // Do not perform any compile-time constant folding.
3497 cctx_.optimizationOpts.enableConstantFolding = false;
3498 ::glow::optimize(F_, cctx_);
3499
3500 // Now only one variables is left; input1 and input2 have been combined.
3501 EXPECT_EQ(mod_.getConstants().size(), 1);
3502
3503 // Verify that only one of input1 and input2 exists.
3504 Constant *varOneOrTwo = nullptr;
3505 for (auto *V : mod_.getConstants()) {
3506 if (V == input1 || V == input2) {
3507 EXPECT_TRUE(varOneOrTwo == nullptr);
3508 varOneOrTwo = V;
3509 }
3510 }
3511 EXPECT_TRUE(varOneOrTwo != nullptr);
3512
3513 // Verify that the users of the inputs are updated correctly.
3514 EXPECT_TRUE(TN->getInput().getNode() == varOneOrTwo);
3515 EXPECT_TRUE(SN->getInput().getNode() == varOneOrTwo);
3516
3517 // Verify that whichever input1/input2 is left over has two users TN and SN.
3518 EXPECT_TRUE(varOneOrTwo->getUsers().size() == 2);
3519 for (auto &U : varOneOrTwo->getUsers()) {
3520 auto *N = U.getUser();
3521 EXPECT_TRUE(N == TN || N == SN);
3522 }
3523 }
3524
3525 // Verify that constant input canonicalization works correctly when the
3526 // arithmetic nodes have multiple users.
TEST_F(GraphOptz,simplifyArithmeticMultipleUsers)3527 TEST_F(GraphOptz, simplifyArithmeticMultipleUsers) {
3528 Node *I1 =
3529 mod_.createPlaceholder(ElemKind::FloatTy, {10, 10, 10}, "input1", false);
3530
3531 Type t(ElemKind::FloatTy, {10, 10, 10});
3532 Node *SN = F_->createSplat("one", &t, 1.0);
3533
3534 // The splat is a constant input to add1 and add2, and is their LHS input.
3535 // We expect canonicalization to occur during optimization, moving the splat
3536 // to the RHS for both. Note that add1 has multiple users: add2 and save1.
3537 Node *AN1 = F_->createAdd("add1", SN, I1);
3538 Node *AN2 = F_->createAdd("add2", SN, AN1);
3539 SaveNode *SN1 = F_->createSave("save1", AN1);
3540 SaveNode *SN2 = F_->createSave("save2", AN2);
3541
3542 // Five nodes in total: one splat, two adds, and two saves.
3543 EXPECT_EQ(F_->getNodes().size(), 5);
3544 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SplatNodeKind), 1);
3545 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 2);
3546 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);
3547
3548 // input1 has a single user before optimization.
3549 EXPECT_EQ(I1->getUsers().size(), 1);
3550
3551 // Simplify nodes will canonicalize add1 and add2, and should replace all
3552 // their users, without otherwise adding new nodes to the graph/changing the
3553 // overall structure.
3554 ::glow::optimize(F_, CompilationMode::Infer);
3555
3556 // We should have the same five nodes: one splat, two adds, and two saves.
3557 EXPECT_EQ(F_->getNodes().size(), 5);
3558 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SplatNodeKind), 1);
3559 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::AddNodeKind), 2);
3560 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SaveNodeKind), 2);
3561
3562 // Verify that both add nodes were canonicalized, and that the graph's shape
3563 // is the same as prior to optimization other than canonicalization.
3564 AddNode *newAN1 = llvm::dyn_cast<AddNode>(SN1->getInput().getNode());
3565 ASSERT_TRUE(newAN1 != nullptr);
3566 EXPECT_TRUE(llvm::isa<Placeholder>(newAN1->getLHS()));
3567 EXPECT_TRUE(llvm::isa<SplatNode>(newAN1->getRHS()));
3568
3569 AddNode *newAN2 = llvm::dyn_cast<AddNode>(SN2->getInput().getNode());
3570 ASSERT_TRUE(newAN2 != nullptr);
3571 EXPECT_TRUE(llvm::isa<AddNode>(newAN2->getLHS()));
3572 EXPECT_TRUE(llvm::isa<SplatNode>(newAN2->getRHS()));
3573
3574 EXPECT_EQ(newAN1, newAN2->getLHS());
3575
3576 // input1 should still have a single user after optimization.
3577 EXPECT_EQ(I1->getUsers().size(), 1);
3578 }
3579
3580 /// Test that a concat with a single input is replaced by the input.
TEST_F(GraphOptz,eliminateSingleConcat)3581 TEST_F(GraphOptz, eliminateSingleConcat) {
3582 Node *input = mod_.createPlaceholder(ElemKind::FloatTy, {10}, "input", false);
3583
3584 ConcatNode *CN = F_->createConcat("concat1", {input}, 0);
3585 SaveNode *SN = F_->createSave("ret", CN);
3586
3587 // The ConcatNode and SaveNode.
3588 EXPECT_EQ(F_->getNodes().size(), 2);
3589
3590 ::glow::optimize(F_, CompilationMode::Infer);
3591
3592 // Just the SaveNode should be left.
3593 EXPECT_EQ(F_->getNodes().size(), 1);
3594 ASSERT_TRUE(functionContainsNode(F_, SN));
3595
3596 // Save node should just save the input.
3597 EXPECT_TRUE(SN->getInput().getNode() == input);
3598 }
3599
3600 /// Test that a reshape of a private variable with one use has the reshape
3601 /// merged into the variable.
TEST_F(GraphOptz,ReshapeConstantOneUse)3602 TEST_F(GraphOptz, ReshapeConstantOneUse) {
3603 const dim_t shape[] = {10, 20};
3604 const dim_t reshape1[] = {200, 1};
3605 const dim_t reshape2[] = {200};
3606 Constant *input =
3607 F_->getParent()->createConstant(ElemKind::FloatTy, shape, "input");
3608 input->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3609
3610 auto *R1 = F_->createReshape("reshape1", input, reshape1);
3611 auto *R2 = F_->createReshape("reshape2", R1, reshape2);
3612 auto *O = F_->createSave("ret", R2);
3613
3614 // Before optimization, we have 2 Reshapes and a Save.
3615 EXPECT_EQ(F_->getNodes().size(), 3);
3616
3617 // Skip ConstantFolding as it would have the same result as this opt.
3618 cctx_.optimizationOpts.enableConstantFolding = false;
3619 ::glow::optimize(F_, cctx_);
3620
3621 // After optimization, we expect to see just a Save.
3622 EXPECT_EQ(F_->getNodes().size(), 1);
3623
3624 // Save should have the new Variable as input.
3625 auto *V = llvm::dyn_cast<Constant>(O->getInput());
3626 ASSERT_TRUE(V);
3627 // The new Variable should have the same shape as the original second
3628 // Reshape.
3629 EXPECT_TRUE(V->getType()->dims().equals(reshape2));
3630 }
3631
3632 /// Test that reshape node is merged into Constant in a sequence
3633 /// Reshape(Quantize(Constant)).
TEST_F(GraphOptz,ReshapeQuantizeConstant)3634 TEST_F(GraphOptz, ReshapeQuantizeConstant) {
3635 const dim_t shape[] = {10, 20};
3636 const dim_t newShape[] = {200, 1};
3637
3638 auto *qTy = mod_.uniqueType(ElemKind::Int8QTy, shape, 0.2, 0);
3639
3640 auto *input =
3641 F_->getParent()->createConstant(ElemKind::FloatTy, shape, "input");
3642 auto *Q = F_->createQuantize("quantize", input, qTy);
3643 auto *R = F_->createReshape("reshape", Q, newShape);
3644 auto *S = F_->createSave("ret", R);
3645
3646 // Skip ConstantFolding as it would have the same result as this opt.
3647 CompilationContext cctx;
3648 cctx.optimizationOpts.enableConstantFolding = false;
3649
3650 EXPECT_EQ(F_->getNodes().size(), 3);
3651 ::glow::optimize(F_, cctx);
3652 EXPECT_EQ(F_->getNodes().size(), 2);
3653
3654 // Constant and Quantize should have new shape.
3655 auto *newQ = llvm::dyn_cast<QuantizeNode>(S->getInput());
3656 ASSERT_TRUE(newQ);
3657 EXPECT_TRUE(newQ->getResult().dims().equals(newShape));
3658 auto *newC = llvm::dyn_cast<Constant>(newQ->getInput());
3659 ASSERT_TRUE(newC);
3660 EXPECT_TRUE(newC->getType()->dims().equals(newShape));
3661 }
3662
3663 /// Test that Transpose is optimized into Reshape when it moves no data.
TEST_F(GraphOptz,transposeIntoReshapeOptim)3664 TEST_F(GraphOptz, transposeIntoReshapeOptim) {
3665 auto *batch =
3666 mod_.createPlaceholder(ElemKind::FloatTy, {1, 3, 2, 4}, "batch", false);
3667 Node *T = F_->createTranspose("transpose", batch, {1, 2, 0, 3});
3668 SaveNode *O = F_->createSave("ret", T);
3669
3670 EXPECT_EQ(F_->getNodes().size(), 2);
3671
3672 ::glow::optimize(F_, CompilationMode::Infer);
3673 EXPECT_EQ(F_->getNodes().size(), 2);
3674
3675 // TransposeNode is Optimized into ReshapeNode.
3676 auto *reshape = llvm::dyn_cast<ReshapeNode>(O->getInput().getNode());
3677 ASSERT_NE(reshape, nullptr);
3678 }
3679
3680 /// Test that transpose is merged into matmul.
TEST_F(GraphOptz,mergeTransposeIntoMatMul)3681 TEST_F(GraphOptz, mergeTransposeIntoMatMul) {
3682 auto *input =
3683 mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 2, 3}, "input", false);
3684 auto *weights =
3685 F_->getParent()->createConstant(ElemKind::FloatTy, {6, 1}, "weights");
3686
3687 weights->getHandle() = {0, 1, 2, 3, 4, 5};
3688 float newWeightsRef[] = {0, 2, 4, 1, 3, 5};
3689
3690 auto *TN = F_->createTranspose("transpose", input, NHWC2NCHW);
3691 auto *RN = F_->createReshape("reshape", TN, {1, 6});
3692 auto *MMN = F_->createMatMul("matmul", RN, weights);
3693 auto *SN = F_->createSave("ret", MMN);
3694
3695 // Transpose + Reshape + MatMul + Save.
3696 EXPECT_EQ(F_->getNodes().size(), 4);
3697
3698 ::glow::optimize(F_, CompilationMode::Infer);
3699
3700 // Reshape + MatMul + Save.
3701 EXPECT_EQ(F_->getNodes().size(), 3);
3702
3703 // Check reordered weights.
3704 auto *newMMN = llvm::dyn_cast<MatMulNode>(SN->getInput());
3705 ASSERT_TRUE(newMMN != nullptr);
3706 auto *newW = llvm::dyn_cast<Constant>(newMMN->getRHS());
3707 ASSERT_TRUE(newW != nullptr);
3708 for (unsigned i = 0; i < 6; ++i) {
3709 EXPECT_EQ(newWeightsRef[i], newW->getHandle().raw(i));
3710 }
3711 }
3712
3713 /// Test that transpose is merged into FullyConnected.
TEST_F(GraphOptz,mergeTransposeIntoFC)3714 TEST_F(GraphOptz, mergeTransposeIntoFC) {
3715 auto *input =
3716 mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 2, 3}, "input", false);
3717 auto *weights =
3718 F_->getParent()->createConstant(ElemKind::FloatTy, {6, 1}, "weights");
3719 auto *bias = F_->getParent()->createConstant(ElemKind::FloatTy, {1}, "bias");
3720
3721 weights->getHandle() = {0, 1, 2, 3, 4, 5};
3722 float newWeightsRef[] = {0, 2, 4, 1, 3, 5};
3723
3724 auto *TN = F_->createTranspose("transpose", input, NHWC2NCHW);
3725 auto *RN = F_->createReshape("reshape", TN, {1, 6});
3726 auto *FCN = F_->createFullyConnected("fc", RN, weights, bias);
3727 auto *SN = F_->createSave("ret", FCN);
3728
3729 // Transpose + Reshape + FC + Save.
3730 EXPECT_EQ(F_->getNodes().size(), 4);
3731
3732 ::glow::optimize(F_, CompilationMode::Infer);
3733
3734 // Reshape + FC + Save.
3735 EXPECT_EQ(F_->getNodes().size(), 3);
3736
3737 // Check reordered weights.
3738 auto *newFCN = llvm::dyn_cast<FullyConnectedNode>(SN->getInput());
3739 ASSERT_TRUE(newFCN != nullptr);
3740 auto *newW = llvm::dyn_cast<Constant>(newFCN->getWeights());
3741 ASSERT_TRUE(newW != nullptr);
3742 for (unsigned i = 0; i < 6; ++i) {
3743 EXPECT_EQ(newWeightsRef[i], newW->getHandle().raw(i));
3744 }
3745 }
3746
TEST_F(GraphOptz,ConvertPlaceholdersToConstants)3747 TEST_F(GraphOptz, ConvertPlaceholdersToConstants) {
3748 auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input1", true);
3749 auto *input2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input2", true);
3750 auto *input3 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input3", true);
3751 auto *save1 = F_->createSave("save1", input1);
3752 auto *save2 = F_->createSave("save2", input2);
3753 auto *save3 = F_->createSave("save3", input3);
3754
3755 // No variables, six PHs (3 inputs, 3 saves).
3756 EXPECT_EQ(mod_.getConstants().size(), 0);
3757 EXPECT_EQ(mod_.getPlaceholders().size(), 6);
3758
3759 // Allocate two of the three inputs, but mark input2 of them as
3760 // non-constant.
3761 bindings_.allocate(input1);
3762 bindings_.allocate(input2);
3763 // Don't allocate input3; keep it as a placeholder instead.
3764 ::glow::convertPlaceholdersToConstants(F_, bindings_, {input2});
3765
3766 // input1 becomes a variable.
3767 EXPECT_EQ(mod_.getConstants().size(), 1);
3768 EXPECT_EQ(mod_.getPlaceholders().size(), 6);
3769
3770 EXPECT_TRUE(llvm::isa<Constant>(save1->getInput()));
3771 EXPECT_TRUE(llvm::isa<Placeholder>(save2->getInput()));
3772 EXPECT_TRUE(llvm::isa<Placeholder>(save3->getInput()));
3773 }
3774
TEST_F(GraphOptz,optimizeConversion_i32_i64_i32)3775 TEST_F(GraphOptz, optimizeConversion_i32_i64_i32) {
3776 auto *i32 = mod_.uniqueType(ElemKind::Int32ITy, {1});
3777 auto *i64 = mod_.uniqueType(ElemKind::Int64ITy, {1});
3778
3779 auto *A = mod_.createPlaceholder(i32, "A", false);
3780 auto *B = F_->createConvertTo("B", A, i64);
3781 auto *C = F_->createConvertTo("C", B, i32);
3782 auto *S = F_->createSave("S", C);
3783
3784 ::glow::optimize(F_, CompilationMode::Infer);
3785
3786 // All casting is optimized away, only left with Save of Placeholder.
3787 EXPECT_EQ(F_->getNodes().size(), 1);
3788 EXPECT_TRUE(llvm::isa<Placeholder>(S->getInput()));
3789 }
3790
TEST_F(GraphOptz,optimizeSameTypeConversions)3791 TEST_F(GraphOptz, optimizeSameTypeConversions) {
3792 auto *input1 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input1", true);
3793 auto *input2 = mod_.createPlaceholder(ElemKind::FloatTy, {1}, "input2", true);
3794 auto *conv1 = F_->createConvertTo("cast1", input1, ElemKind::FloatTy);
3795 auto *conv2 = F_->createConvertTo("cast2", input2, ElemKind::Float16Ty);
3796 auto *save1 = F_->createSave("save1", conv1);
3797 auto *save2 = F_->createSave("save1", conv2);
3798
3799 // convert_to1 + save1 + convert_to2 + save2 nodes.
3800 EXPECT_EQ(F_->getNodes().size(), 4);
3801 EXPECT_TRUE(llvm::isa<ConvertToNode>(save1->getInput()));
3802
3803 ::glow::optimize(F_, CompilationMode::Infer);
3804
3805 // save1 + convert_to2 + save2 nodes.
3806 EXPECT_EQ(F_->getNodes().size(), 3);
3807 // convert_to1 node should be eliminated, because it converts the node into
3808 // the same type.
3809 EXPECT_TRUE(llvm::isa<Placeholder>(save1->getInput()));
3810 // convert_to1 node should not be eliminated, because it converts the node
3811 // into a different type.
3812 EXPECT_TRUE(llvm::isa<ConvertToNode>(save2->getInput()));
3813 EXPECT_EQ(save2->getInput(), NodeValue(conv2));
3814 }
3815
TEST_F(GraphOptz,optimizeConvertingBetweenFused)3816 TEST_F(GraphOptz, optimizeConvertingBetweenFused) {
3817 // Call with dims {5, 2}, which will actually create a constant with {5, 10}
3818 // for scale/offset per row.
3819 Constant *C = createRandomFusedRowwiseQuantizedConstant(
3820 mod_, {5, 2}, "fused", /* useFusedFP16 */ false);
3821 // Converting to fused FP16 means we have 4 total less bytes for scale/offset,
3822 // so we move to {5, 10} from {5, 6}.
3823 auto newOT = mod_.uniqueType(ElemKind::UInt8FusedFP16QTy, {5, 6}, 1.0, 0);
3824 auto *CN = F_->createConvertTo("convert", C, newOT);
3825 auto *SN = F_->createSave("save", CN);
3826
3827 ::glow::optimize(F_, CompilationMode::Infer);
3828
3829 // Convert should be eliminated and just the save of the Constant left.
3830 EXPECT_EQ(F_->getNodes().size(), 1);
3831 Constant *convertedC = llvm::dyn_cast<Constant>(SN->getInput());
3832 ASSERT_TRUE(convertedC);
3833 EXPECT_EQ(convertedC->getElementType(), ElemKind::UInt8FusedFP16QTy);
3834 }
3835
TEST_F(GraphOptz,dceBeforeOptimizeTranpose)3836 TEST_F(GraphOptz, dceBeforeOptimizeTranpose) {
3837 auto *input1 = mod_.createConstant(ElemKind::FloatTy, {5, 10}, "input1");
3838 // Create an unused node.
3839 F_->createAdd("add", input1, input1);
3840 auto *transposedInput1 = F_->createTranspose("transpose", input1, {1, 0});
3841 auto *save1 = F_->createSave("save1", transposedInput1);
3842
3843 // add + transpose + save.
3844 EXPECT_EQ(F_->getNodes().size(), 3);
3845
3846 ::glow::optimize(F_, CompilationMode::Infer);
3847
3848 // A single node: save.
3849 EXPECT_EQ(F_->getNodes().size(), 1);
3850 // transpose should be eliminated and replaced by the transposed constant.
3851 EXPECT_TRUE(llvm::isa<Constant>(save1->getInput()));
3852 }
3853
3854 /// Test that Transpose is sunk below ChannelShuffle and cancels with an
3855 /// inverse transpose below the ChannelShuffle. This test models a pattern
3856 /// that has has been observed in shufflenet during graph optimization.
TEST_F(GraphOptz,sinkTransposeBelowChannelShuffleNodesAndEliminate)3857 TEST_F(GraphOptz, sinkTransposeBelowChannelShuffleNodesAndEliminate) {
3858 const dim_t inputDims[] = {3, 28, 28, 136};
3859
3860 Node *K =
3861 mod_.createPlaceholder(ElemKind::FloatTy, inputDims, "input", false);
3862 K = F_->createTranspose("unnecessary_transpose_1", K, {0, 3, 1, 2});
3863 K = F_->createChannelShuffle("channel_shuffle", K, 4, 1);
3864 K = F_->createTranspose("unnecessary_transpose_2", K, {0, 2, 3, 1});
3865 auto *save = F_->createSave("ret", K);
3866
3867 EXPECT_EQ(F_->getNodes().size(), 4);
3868
3869 // Optimize away the unnecessary transposes.
3870 optimize(F_, CompilationMode::Infer);
3871
3872 // Ensure the two unnecessary transposes are gone.
3873 ASSERT_EQ(F_->getNodes().size(), 2);
3874
3875 // Check that the channel shuffle node is still there.
3876 auto *CSN = llvm::dyn_cast<ChannelShuffleNode>(save->getInput().getNode());
3877 ASSERT_NE(nullptr, CSN);
3878
3879 // Ensure ChannelShuffle node has the same dimensions as the input.
3880 EXPECT_EQ(CSN->getResult().dims(), llvm::makeArrayRef(inputDims));
3881
3882 // Ensure Group and Kernel are as expected.
3883 EXPECT_EQ(CSN->getGroup(), 4);
3884 EXPECT_EQ(CSN->getKernel(), 3);
3885 }
3886
3887 /// Test BatchNorm sinking below Slice.
TEST_F(GraphOptz,sinkBatchNormBelowSlice)3888 TEST_F(GraphOptz, sinkBatchNormBelowSlice) {
3889 auto *inputTy = mod_.uniqueType(ElemKind::FloatTy, {1, 10, 10, 3});
3890 auto *slicedTy1 = mod_.uniqueType(ElemKind::FloatTy, {1, 8, 8, 3});
3891 auto *slicedTy2 = mod_.uniqueType(ElemKind::FloatTy, {1, 6, 6, 1});
3892
3893 auto *input = mod_.createPlaceholder(inputTy, "input", false);
3894 auto *BN = F_->createBatchNormalization(bindings_, "batchnorm", input, 3,
3895 0.0001, 0.9);
3896 auto *SN1 = F_->createSlice("slice1", BN, {0, 1, 1, 0}, slicedTy1);
3897 auto *SN2 = F_->createSlice("slice2", SN1, {0, 1, 1, 1}, slicedTy2);
3898 auto *save = F_->createSave("save", SN2);
3899
3900 EXPECT_EQ(F_->getNodes().size(), 4);
3901 ::glow::convertPlaceholdersToConstants(F_, bindings_, {});
3902 optimizedF_ = optimizeFunction(F_);
3903 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
3904
3905 // BatchNorm should have sunk below the first Slice, but not the second one,
3906 // as it changes channel dimmension.
3907 auto *newSave =
3908 findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
3909 ASSERT_TRUE(newSave);
3910 auto *newSN2 = llvm::dyn_cast<SliceNode>(newSave->getInput());
3911 ASSERT_TRUE(newSN2);
3912 auto *newBN = llvm::dyn_cast<BatchNormalizationNode>(newSN2->getInput());
3913 ASSERT_TRUE(newBN);
3914 ASSERT_EQ(newBN->getResult().dims(), slicedTy1->dims());
3915 ASSERT_TRUE(llvm::isa<SliceNode>(newBN->getInput()));
3916
3917 bindings_.allocate(mod_.getPlaceholders());
3918 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
3919 checkNumericalEquivalence();
3920 }
3921
3922 /// Test that convertPlaceholdersToConstants works properly with quantized
3923 /// types.
TEST_F(GraphOptz,QuantizedFC)3924 TEST_F(GraphOptz, QuantizedFC) {
3925 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0,
3926 "input", false);
3927 auto *weights = mod_.createPlaceholder(ElemKind::Int8QTy, {32, 32}, 1.0, 0,
3928 "weights", false);
3929 auto *bias =
3930 mod_.createPlaceholder(ElemKind::Int32QTy, {32}, 1.0, 0, "bias", false);
3931 auto *output = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0,
3932 "output", false);
3933
3934 auto *fc = F_->createFullyConnected("fc", input, weights, bias);
3935 F_->createSave("save", fc, output);
3936
3937 bindings_.allocate(input);
3938 bindings_.allocate(weights);
3939 bindings_.allocate(bias);
3940 bindings_.allocate(output);
3941
3942 glow::convertPlaceholdersToConstants(F_, bindings_, {input, output});
3943 // Two constants: weight and bias
3944 EXPECT_EQ(mod_.getConstants().size(), 2);
3945 // All four placeholders still exist in the module. The old weight and bias
3946 // placeholders just aren't hooked up the the Graph F_.
3947 EXPECT_EQ(mod_.getPlaceholders().size(), 4);
3948 }
3949
3950 /// Test batchedReduceMean optimization using AvgPool.
TEST_F(GraphOptz,convertReduceMean2AvgPool)3951 TEST_F(GraphOptz, convertReduceMean2AvgPool) {
3952 const dim_t dims[] = {2, 2, 2, 2};
3953
3954 Node *A = mod_.createPlaceholder(ElemKind::FloatTy, dims, "input", false);
3955 Node *R = F_->createBatchedReduceMean("reduce.mean", A, {2, 3});
3956
3957 SaveNode *O = F_->createSave("ret", R);
3958
3959 EXPECT_EQ(F_->getNodes().size(), 2);
3960
3961 ::glow::optimize(F_, CompilationMode::Infer);
3962
3963 // Optimization adds 2 transpose nodes and one reshape node.
3964 EXPECT_EQ(F_->getNodes().size(), 5);
3965
3966 // Expecting reshape output rather than ReduceMean.
3967 auto *RN = llvm::dyn_cast<ReshapeNode>(O->getInput());
3968 ASSERT_NE(RN, nullptr);
3969
3970 // Expecting Transpose node before Reshape node.
3971 auto *TN = llvm::dyn_cast<TransposeNode>(RN->getInput());
3972 ASSERT_NE(TN, nullptr);
3973
3974 // Expecting AvgPool node before Transpose node.
3975 auto *APN = llvm::dyn_cast<AvgPoolNode>(TN->getInput());
3976 ASSERT_NE(APN, nullptr);
3977 }
3978
3979 /// Test Broadcasted RHS BatchMatMul is converted correctly to a single MatMul.
TEST_F(GraphOptz,convertBroadcastedBatchMatMulToMatMul)3980 TEST_F(GraphOptz, convertBroadcastedBatchMatMulToMatMul) {
3981 auto *lhs =
3982 mod_.createPlaceholder(ElemKind::FloatTy, {2, 3, 2}, "lhs", false);
3983 auto *rhs = mod_.createPlaceholder(ElemKind::FloatTy, {2, 1}, "rhs", false);
3984 auto *BMMN = F_->createBatchMatMul("BMM", lhs, rhs);
3985 F_->createSave("save", BMMN);
3986
3987 // Start with a BatchMatMul, not a MatMul.
3988 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchMatMulNodeKind), 1);
3989 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 0);
3990
3991 ::glow::optimize(F_, CompilationMode::Infer);
3992
3993 // Optimization should replace the BatchMatMul with a single MatMul.
3994 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::MatMulNodeKind), 1);
3995 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::BatchMatMulNodeKind), 0);
3996 }
3997
TEST_F(GraphOptz,dceQuantization)3998 TEST_F(GraphOptz, dceQuantization) {
3999 auto *lhs =
4000 mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5}, 0.3, 15, "lhs", false);
4001 auto *weights =
4002 mod_.createConstant(ElemKind::Int8QTy, {3, 5}, 0.3, 15, "weights");
4003
4004 auto *add = F_->createAdd("add", lhs, weights);
4005 auto *t1 = mod_.uniqueType(ElemKind::Int8QTy, {3, 5}, 0.2, 0);
4006 auto *rs1 = F_->createRescaleQuantized("rs1", add, t1);
4007 auto *t2 = mod_.uniqueType(ElemKind::Int8QTy, {3, 5}, 0.1, 1);
4008 auto *rs2 = F_->createRescaleQuantized("rs2", rs1, t2);
4009 F_->createSave("save", rs2);
4010
4011 ::glow::optimize(F_, CompilationMode::Infer);
4012
4013 EXPECT_EQ(F_->getNodes().size(), 2);
4014 }
4015
TEST_F(GraphOptz,nopRelu)4016 TEST_F(GraphOptz, nopRelu) {
4017 auto *in = mod_.createPlaceholder(ElemKind::Int8QTy, {3, 5}, 0.3, -128, "lhs",
4018 false);
4019
4020 auto *relu = F_->createRELU("relu", in);
4021 F_->createSave("save", relu);
4022
4023 optimizedF_ = optimizeFunction(F_);
4024
4025 EXPECT_EQ(optimizedF_->getNodes().size(), 1);
4026
4027 bindings_.allocate(mod_.getPlaceholders());
4028 bindings_.get(in)->getHandle<int8_t>().randomize(-4, 4, mod_.getPRNG());
4029
4030 checkNumericalEquivalence();
4031 }
4032
4033 template <typename ElemTy>
setConstValue(Constant * C,ElemTy value)4034 static void setConstValue(Constant *C, ElemTy value) {
4035 Handle<ElemTy> TH = C->getPayload().getHandle<ElemTy>();
4036 TH.clear(value);
4037 }
4038
TEST_F(GraphOptz,constantFoldSingleNode)4039 TEST_F(GraphOptz, constantFoldSingleNode) {
4040 auto *const1 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const1");
4041 auto *const2 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const2");
4042 auto *ph1 = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "input1",
4043 /* isTrainable */ false);
4044 setConstValue(const1, 1.0f);
4045 setConstValue(const2, 2.0f);
4046 auto *splat2 = F_->createSplat(
4047 "splat2", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 2.0f);
4048 auto *splat3 = F_->createSplat(
4049 "splat3", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 3.0f);
4050
4051 auto *add1 = F_->createAdd("add", const1, const2);
4052 auto *mul1 = F_->createMul("mul1", add1, splat2);
4053 auto *mul2 = F_->createMul("mul2", mul1, splat3);
4054 auto *SN1 = F_->createSave("save", mul2);
4055 auto *add3 = F_->createAdd("add", const1, ph1);
4056 auto *SN2 = F_->createSave("save", add3);
4057
4058 // Perform constant folding for a specific node.
4059 std::vector<Constant *> constResults =
4060 constantFold(SN1->getInput().getNode());
4061
4062 EXPECT_EQ(constResults.size(), 1);
4063 SN1->getInput().replaceAllUsesOfWith(constResults[0]);
4064 // Second save should be unaffected.
4065 EXPECT_FALSE(llvm::isa<Constant>(SN2->getInput()));
4066 // First save should have been constant folded.
4067 EXPECT_TRUE(llvm::isa<Constant>(SN1->getInput()));
4068 Constant *C = llvm::dyn_cast<Constant>(SN1->getInput());
4069 auto CH = C->getHandle();
4070 // The expected result should be: (((1+2) * 2 * 3) = 18
4071 EXPECT_EQ(CH.at({0, 0}), 18.0f);
4072 EXPECT_EQ(CH.at({0, 1}), 18.0f);
4073 EXPECT_EQ(CH.at({1, 0}), 18.0f);
4074 EXPECT_EQ(CH.at({1, 1}), 18.0f);
4075 }
4076
4077 /// Test that we correctly record a single constant folding subgraph that has a
4078 /// single output.
TEST_F(GraphOptz,constantFoldWithRecordSingleChain)4079 TEST_F(GraphOptz, constantFoldWithRecordSingleChain) {
4080 Placeholder *I =
4081 mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input",
4082 /* isTrainable */ false);
4083 Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight");
4084 ClipNode *clipW = F_->createClip("clip", W, -5.f, 5.f);
4085 ConvertToNode *convertW =
4086 F_->createConvertTo("conv", clipW, ElemKind::Float16Ty);
4087 TransposeNode *transposeW =
4088 F_->createTranspose("transpose", convertW, {1, 0});
4089 MatMulNode *MM = F_->createMatMul("matmul", I, transposeW);
4090 SaveNode *save = F_->createSave("save", MM);
4091 Placeholder *O = save->getPlaceholder();
4092 bindings_.allocate(O);
4093
4094 ASSERT_TRUE(F_->verify());
4095
4096 Tensor *IT = bindings_.allocate(I);
4097 IT->getHandle<float16_t>().randomize(-10, 10, mod_.getPRNG());
4098 W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG());
4099
4100 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4101
4102 ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_);
4103
4104 runDCEPass(optimizedF_, cctx_);
4105
4106 ASSERT_EQ(record.size(), 1);
4107 SaveNode *SN = record.begin()->second;
4108 Function *constFoldF = SN->getParent();
4109
4110 // Expect to find a chain of Nodes based on Nodes above. Note that the clip is
4111 // lowered for the Interpreter backend which performs constant folding.
4112 EXPECT_EQ(2, countNodeKind(constFoldF, Kinded::Kind::SplatNodeKind));
4113 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::MaxNodeKind));
4114 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::MinNodeKind));
4115 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::ConvertToNodeKind));
4116 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::TransposeNodeKind));
4117
4118 // Skip optimizations -- we just want to run them as is (otherwise we'll
4119 // constant fold them inside the optimization pipeline).
4120 cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldF);
4121 cctx_.optimizationOpts.onlyLowerFuns.insert(F_);
4122 cctx_.optimizationOpts.onlyLowerFuns.insert(optimizedF_);
4123
4124 // Don't strip the module as we want to compare the Constant values below.
4125 EE_.setSkipModuleStrip(true);
4126
4127 EE_.compile(cctx_);
4128 alreadyCompiled_ = true;
4129
4130 bindings_.allocate(mod_.getPlaceholders());
4131
4132 // Run the constant folding chain to check that we have the same constant used
4133 // by the optimized Function.
4134 EE_.run(bindings_, constFoldF->getName());
4135 Tensor *rerunT = bindings_.get(SN->getPlaceholder());
4136 ASSERT_TRUE(rerunT);
4137 auto optimizedConstants = optimizedF_->findConstants();
4138 ASSERT_EQ(optimizedConstants.size(), 1);
4139 EXPECT_TRUE(
4140 (*optimizedConstants.begin())->getPayload().isEqual(*rerunT, 0.f));
4141
4142 // Remove the temporary constant folding Functions and their Placeholders.
4143 cleanupConstantFolding(mod_, record, &bindings_);
4144
4145 // Now compile/run/compare F_ and optimizedF_.
4146 checkNumericalEquivalence(0.f);
4147 }
4148
4149 /// Test that we correctly record two constant folding subgraphs, with each with
4150 /// a single output.
TEST_F(GraphOptz,constantFoldWithRecordMultiChain)4151 TEST_F(GraphOptz, constantFoldWithRecordMultiChain) {
4152 Placeholder *I =
4153 mod_.createPlaceholder(ElemKind::Float16Ty, {2, 100}, "input",
4154 /* isTrainable */ false);
4155 Constant *W = mod_.createConstant(ElemKind::FloatTy, {10, 100}, "weight");
4156 ClipNode *clipW = F_->createClip("clip", W, -5.f, 5.f);
4157 ConvertToNode *convertW =
4158 F_->createConvertTo("conv", clipW, ElemKind::Float16Ty);
4159 TransposeNode *transposeW =
4160 F_->createTranspose("transpose", convertW, {1, 0});
4161 MatMulNode *MM = F_->createMatMul("matmul", I, transposeW);
4162 SaveNode *saveMM = F_->createSave("save_mm", MM);
4163 Placeholder *MMP = saveMM->getPlaceholder();
4164 bindings_.allocate(MMP);
4165
4166 SigmoidNode *sigmoidW = F_->createSigmoid("sig", convertW);
4167 SaveNode *saveSig = F_->createSave("save_sig", sigmoidW);
4168 Placeholder *sigP = saveSig->getPlaceholder();
4169 bindings_.allocate(sigP);
4170
4171 ASSERT_TRUE(F_->verify());
4172
4173 Tensor *IT = bindings_.allocate(I);
4174 IT->getHandle<float16_t>().randomize(-10, 10, mod_.getPRNG());
4175 W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG());
4176
4177 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4178
4179 ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_);
4180
4181 runDCEPass(optimizedF_, cctx_);
4182
4183 ASSERT_EQ(record.size(), 2);
4184 SaveNode *sigSN = record.begin()->second;
4185 SaveNode *transSN = std::next(record.begin())->second;
4186 if (llvm::isa<SigmoidNode>(transSN->getInput())) {
4187 std::swap(sigSN, transSN);
4188 }
4189
4190 Function *constFoldSig = sigSN->getParent();
4191 Function *constFoldTrans = transSN->getParent();
4192
4193 // Expect to find a chain of Nodes based on Nodes above. Note that the clip is
4194 // lowered for the Interpreter backend which performs constant folding.
4195 EXPECT_EQ(2, countNodeKind(constFoldTrans, Kinded::Kind::SplatNodeKind));
4196 EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::MaxNodeKind));
4197 EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::MinNodeKind));
4198 EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::ConvertToNodeKind));
4199 EXPECT_EQ(1, countNodeKind(constFoldTrans, Kinded::Kind::TransposeNodeKind));
4200
4201 EXPECT_EQ(2, countNodeKind(constFoldSig, Kinded::Kind::SplatNodeKind));
4202 EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::MaxNodeKind));
4203 EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::MinNodeKind));
4204 EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::ConvertToNodeKind));
4205 EXPECT_EQ(1, countNodeKind(constFoldSig, Kinded::Kind::SigmoidNodeKind));
4206
4207 // Skip optimizations -- we just want to run them as is (otherwise we'll
4208 // constant fold them inside the optimization pipeline).
4209 cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldTrans);
4210 cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldSig);
4211 cctx_.optimizationOpts.onlyLowerFuns.insert(F_);
4212 cctx_.optimizationOpts.onlyLowerFuns.insert(optimizedF_);
4213
4214 // Don't strip the module as we want to compare the Constant values below.
4215 EE_.setSkipModuleStrip(true);
4216
4217 EE_.compile(cctx_);
4218 alreadyCompiled_ = true;
4219
4220 bindings_.allocate(mod_.getPlaceholders());
4221
4222 // Run the constant folding chain to check that we have the same constant used
4223 // by the optimized Function.
4224 EE_.run(bindings_, constFoldTrans->getName());
4225 EE_.run(bindings_, constFoldSig->getName());
4226
4227 // Find the correct PHs for each of the constant folding we do.
4228 Tensor *rerunTransT = bindings_.get(transSN->getPlaceholder());
4229 Tensor *rerunSigT = bindings_.get(sigSN->getPlaceholder());
4230 ASSERT_TRUE(rerunTransT);
4231 ASSERT_TRUE(rerunSigT);
4232
4233 auto optimizedConstants = optimizedF_->findConstants();
4234 ASSERT_EQ(optimizedConstants.size(), 2);
4235 Constant *transC = *optimizedConstants.begin();
4236 Constant *sigC = *std::next(optimizedConstants.begin());
4237 // If we have the constants backwards then swap them. Note that we know
4238 // sigC must be directly saved, while transC is input to a MatMulNode.
4239 ASSERT_EQ(transC->getNumUsers(), 1);
4240 if (llvm::isa<SaveNode>(transC->getUsers().begin()->getUser())) {
4241 std::swap(transC, sigC);
4242 }
4243 EXPECT_TRUE(transC->getPayload().isEqual(*rerunTransT, 0.f));
4244 EXPECT_TRUE(sigC->getPayload().isEqual(*rerunSigT, 0.f));
4245
4246 // Remove the temporary constant folding Functions and their Placeholders.
4247 cleanupConstantFolding(mod_, record, &bindings_);
4248
4249 // Now compile/run/compare F_ and optimizedF_.
4250 checkNumericalEquivalence(0.f);
4251 }
4252
4253 /// Test that we correctly record a single constant folding subgraph that has
4254 /// two outputs.
TEST_F(GraphOptz,constantFoldWithRecordSingleChainMultiOutput)4255 TEST_F(GraphOptz, constantFoldWithRecordSingleChainMultiOutput) {
4256 Constant *W = mod_.createConstant(ElemKind::FloatTy, {100}, "weight");
4257 SigmoidNode *sigmoidW = F_->createSigmoid("sig", W);
4258 ConvertToNode *convertW =
4259 F_->createConvertTo("conv", sigmoidW, ElemKind::Float16Ty);
4260 TopKNode *TK = F_->createTopK("topk", convertW, 5);
4261
4262 SaveNode *indicesSave = F_->createSave("save_indices", TK->getIndices());
4263 Placeholder *indicesP = indicesSave->getPlaceholder();
4264 bindings_.allocate(indicesP);
4265
4266 Placeholder *I = mod_.createPlaceholder(ElemKind::Float16Ty, {5}, "input",
4267 /* isTrainable */ false);
4268 AddNode *add = F_->createAdd("add", I, TK->getValues());
4269 SaveNode *addSave = F_->createSave("save_add", add);
4270 Placeholder *addP = addSave->getPlaceholder();
4271 bindings_.allocate(addP);
4272
4273 ASSERT_TRUE(F_->verify());
4274
4275 Tensor *IT = bindings_.allocate(I);
4276 IT->getHandle<float16_t>().randomize(-10, 10, mod_.getPRNG());
4277 W->getPayloadMutable().getHandle<float>().randomize(-10, 10, mod_.getPRNG());
4278
4279 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4280
4281 ConstantFoldingRecordMap record = constantFoldAndRecord(optimizedF_, cctx_);
4282
4283 runDCEPass(optimizedF_, cctx_);
4284
4285 ASSERT_EQ(record.size(), 2);
4286 SaveNode *indicesSN = record.begin()->second;
4287 SaveNode *addSN = std::next(record.begin())->second;
4288
4289 // Find the correct PHs for each of the constant folding we do.
4290 if (indicesSN->getInput().getResNo() != TopKNode::IndicesIdx) {
4291 std::swap(indicesSN, addSN);
4292 }
4293
4294 // Expect that the two constants that we folded are from the same Function,
4295 // and that the two saves use the two different outputs from a topk.
4296 EXPECT_EQ(indicesSN->getParent(), addSN->getParent());
4297 ASSERT_TRUE(llvm::isa<TopKNode>(addSN->getInput()));
4298 ASSERT_TRUE(llvm::isa<TopKNode>(indicesSN->getInput()));
4299 EXPECT_EQ(addSN->getInput().getNode(), indicesSN->getInput().getNode());
4300
4301 Function *constFoldF = addSN->getParent();
4302
4303 // Expect to find a chain of Nodes based on Nodes above.
4304 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::TopKNodeKind));
4305 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::SigmoidNodeKind));
4306 EXPECT_EQ(1, countNodeKind(constFoldF, Kinded::Kind::ConvertToNodeKind));
4307
4308 // Skip optimizations -- we just want to run them as is (otherwise we'll
4309 // constant fold them inside the optimization pipeline).
4310 cctx_.optimizationOpts.onlyLowerFuns.insert(constFoldF);
4311 cctx_.optimizationOpts.onlyLowerFuns.insert(F_);
4312 cctx_.optimizationOpts.onlyLowerFuns.insert(optimizedF_);
4313
4314 // Don't strip the module as we want to compare the Constant values below.
4315 EE_.setSkipModuleStrip(true);
4316
4317 EE_.compile(cctx_);
4318 alreadyCompiled_ = true;
4319
4320 bindings_.allocate(mod_.getPlaceholders());
4321
4322 // Run the constant folding chain to check that we have the same constant used
4323 // by the optimized Function.
4324 EE_.run(bindings_, constFoldF->getName());
4325
4326 Tensor *rerunAddT = bindings_.get(addSN->getPlaceholder());
4327 Tensor *rerunIndicesT = bindings_.get(indicesSN->getPlaceholder());
4328 ASSERT_TRUE(rerunAddT);
4329 ASSERT_TRUE(rerunIndicesT);
4330
4331 auto optimizedConstants = optimizedF_->findConstants();
4332 ASSERT_EQ(optimizedConstants.size(), 2);
4333 Constant *addC = *optimizedConstants.begin();
4334 Constant *indicesC = *std::next(optimizedConstants.begin());
4335
4336 // If we have the constants backwards then swap them. Note that we know
4337 // indicesC must be directly saved, while addC is input to an AddNode.
4338 ASSERT_EQ(addC->getNumUsers(), 1);
4339 if (llvm::isa<SaveNode>(addC->getUsers().begin()->getUser())) {
4340 std::swap(addC, indicesC);
4341 }
4342 EXPECT_TRUE(addC->getPayload().isEqual(*rerunAddT, 0.f));
4343 EXPECT_TRUE(indicesC->getPayload().isEqual(*rerunIndicesT, 0.f));
4344
4345 // Remove the temporary constant folding Functions and their Placeholders.
4346 cleanupConstantFolding(mod_, record, &bindings_);
4347
4348 // Now compile/run/compare F_ and optimizedF_.
4349 checkNumericalEquivalence(0.f);
4350 }
4351
TEST_F(GraphOptz,constantFoldWholeFunction)4352 TEST_F(GraphOptz, constantFoldWholeFunction) {
4353 auto *const1 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const1");
4354 auto *const2 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const2");
4355 auto *const3 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const3");
4356 auto *const4 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const4");
4357 auto *ph1 = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "input1",
4358 /* isTrainable */ false);
4359 setConstValue(const1, 1.0f);
4360 setConstValue(const2, 2.0f);
4361 setConstValue(const3, 3.0f);
4362 setConstValue(const4, 4.0f);
4363 auto *splat2 = F_->createSplat(
4364 "splat2", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 2.0f);
4365 auto *splat3 = F_->createSplat(
4366 "splat2", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 3.0f);
4367 auto *splat4 = F_->createSplat(
4368 "splat2", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 4.0f);
4369
4370 auto *add1 = F_->createAdd("add", const1, const2);
4371 auto *mul1 = F_->createMul("mul1", add1, splat2);
4372 auto *mul2 = F_->createMul("mul2", mul1, splat3);
4373 auto *sub = F_->createSub("sub", mul2, const3);
4374 auto *add2 = F_->createAdd("add2", sub, const4);
4375 auto *mul3 = F_->createMul("mul3", add2, splat4);
4376 // Check compile-time constant folding for nodes with multiple results.
4377 auto *topK = F_->createTopK("topK", mul3, 2);
4378 auto *SN1_0 = F_->createSave("save", topK->getValues());
4379 auto *SN1_1 = F_->createSave("save", topK->getIndices());
4380 auto *add3 = F_->createAdd("add", const1, ph1);
4381 auto *SN2 = F_->createSave("save", add3);
4382
4383 // Perform constant folding for a whole function.
4384 ::glow::optimize(F_, CompilationMode::Infer);
4385
4386 EXPECT_EQ(F_->getNodes().size(), 4);
4387 // Second save should be unaffected, as its value is not a constant operation.
4388 EXPECT_FALSE(llvm::isa<Constant>(SN2->getInput()));
4389 // First save should have been constant folded.
4390 EXPECT_TRUE(llvm::isa<Constant>(SN1_0->getInput()));
4391 EXPECT_TRUE(llvm::isa<Constant>(SN1_1->getInput()));
4392 Constant *C = llvm::dyn_cast<Constant>(SN1_0->getInput());
4393 auto CH = C->getHandle();
4394 // The expected result should be: (((1+2) * 2 * 3 - 3) + 4) * 4 = 76
4395 EXPECT_EQ(CH.at({0, 0}), 76.0f);
4396 EXPECT_EQ(CH.at({0, 1}), 76.0f);
4397 EXPECT_EQ(CH.at({1, 0}), 76.0f);
4398 EXPECT_EQ(CH.at({1, 1}), 76.0f);
4399 }
4400
4401 /// Test constant folding for operators which are lowered in Interpreter
4402 /// backend.
TEST_F(GraphOptz,constantFoldWithLowering)4403 TEST_F(GraphOptz, constantFoldWithLowering) {
4404 auto *input = mod_.createConstant(ElemKind::FloatTy, {1, 6}, "input");
4405 input->getHandle() = {5, 4, 3, 2, 1, 0};
4406 auto *TN = F_->createTile("tile", input, 5, 0);
4407 auto *SN = F_->createSave("ret", TN);
4408
4409 // Perform constant folding.
4410 EXPECT_EQ(F_->getNodes().size(), 2);
4411 ::glow::optimize(F_, CompilationMode::Infer);
4412
4413 // Tile with its input should be folded into a single Constant node.
4414 EXPECT_EQ(F_->getNodes().size(), 1);
4415 ASSERT_TRUE(llvm::isa<Constant>(SN->getInput()));
4416 }
4417
4418 /// Test Splitting FC into multiple FCs.
TEST_F(GraphOptz,SplitFCIntoMultipleOps)4419 TEST_F(GraphOptz, SplitFCIntoMultipleOps) {
4420 auto *input =
4421 mod_.createPlaceholder(ElemKind::FloatTy, {2, 32}, "input", false);
4422 bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
4423 mod_.getPRNG());
4424 auto *weights = mod_.createConstant(ElemKind::FloatTy, {32, 850}, "weights");
4425 weights->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
4426 auto *bias = mod_.createConstant(ElemKind::FloatTy, {850}, "bias");
4427 bias->getHandle().randomize(0.0, 0.5, mod_.getPRNG());
4428 auto *output =
4429 mod_.createPlaceholder(ElemKind::FloatTy, {2, 850}, "output", false);
4430 bindings_.allocate(output);
4431
4432 auto *fc = F_->createFullyConnected("fc", input, weights, bias);
4433 auto *save = F_->createSave("save", fc, output);
4434
4435 ::glow::optimize(F_, CompilationMode::Infer);
4436
4437 // This is F_ but without the parallel transformation below.
4438 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4439
4440 EXPECT_TRUE(::glow::executeVerticalFCWeightsSplit(F_,
4441 /*numOfChunks*/ 12,
4442 /*minKToSplit*/ 800));
4443 runDCEPass(F_, cctx_);
4444
4445 // 24 Slices: 12 from bias and 12 from weights.
4446 EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
4447
4448 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
4449
4450 // 12 newly created FCs.
4451 EXPECT_EQ(12, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
4452
4453 auto *concatNode = llvm::dyn_cast<ConcatNode>(save->getInput());
4454 ASSERT_TRUE(concatNode);
4455 // 12 FCs are connected to the concat node.
4456 EXPECT_EQ(12, concatNode->getInputs().size());
4457
4458 // Check all splitted FCs.
4459 for (unsigned i = 0; i < 12; ++i) {
4460 auto *fc = llvm::dyn_cast<FullyConnectedNode>(concatNode->getNthInput(i));
4461 ASSERT_TRUE(fc);
4462 // 2 * 71 for first 11 FCs and last 2 * 69
4463 if (i == 11) {
4464 EXPECT_TRUE(fc->getResult().dims().equals({2, 69}));
4465 EXPECT_TRUE(fc->getBias().dims().equals({69}));
4466 EXPECT_TRUE(fc->getWeights().dims().equals({32, 69}));
4467 } else {
4468 EXPECT_TRUE(fc->getResult().dims().equals({2, 71}));
4469 EXPECT_TRUE(fc->getBias().dims().equals({71}));
4470 EXPECT_TRUE(fc->getWeights().dims().equals({32, 71}));
4471 }
4472 }
4473
4474 checkNumericalEquivalence();
4475 }
4476
4477 /// Test Splitting FC into multiple FCs.
TEST_F(GraphOptz,ParallelizeGraph_FC_ModelParallel)4478 TEST_F(GraphOptz, ParallelizeGraph_FC_ModelParallel) {
4479 auto *input =
4480 mod_.createPlaceholder(ElemKind::FloatTy, {8, 32}, "input", false);
4481 bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
4482 mod_.getPRNG());
4483 auto *weights1 = mod_.createConstant(ElemKind::FloatTy, {32, 150}, "weights");
4484 weights1->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
4485 auto *bias1 = mod_.createConstant(ElemKind::FloatTy, {150}, "bias");
4486 bias1->getHandle().randomize(0.0, 0.5, mod_.getPRNG());
4487 auto *weights2 =
4488 mod_.createConstant(ElemKind::FloatTy, {150, 150}, "weights");
4489 weights2->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
4490 auto *bias2 = mod_.createConstant(ElemKind::FloatTy, {150}, "bias");
4491 bias2->getHandle().randomize(0.0, 0.5, mod_.getPRNG());
4492 auto *output =
4493 mod_.createPlaceholder(ElemKind::FloatTy, {8, 150}, "output", false);
4494 bindings_.allocate(output);
4495
4496 auto *fc1 = F_->createFullyConnected("fc1", input, weights1, bias1);
4497 auto *relu1 = F_->createRELU("relu1", fc1);
4498
4499 auto *fc2 = F_->createFullyConnected("fc2", relu1, weights2, bias2);
4500 auto *relu2 = F_->createRELU("relu2", fc2);
4501 F_->createSave("save", relu2, output);
4502
4503 ::glow::optimize(F_, CompilationMode::Infer);
4504
4505 // This is F_ but without the parallel transformation below.
4506 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4507
4508 // Perform parallel transformation on F_.
4509 llvm::DenseMap<Node *, size_t> numChunks;
4510 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
4511 numChunks[fc1] = 2;
4512 numChunks[relu1] = 2;
4513 numChunks[fc2] = 2;
4514 numChunks[relu2] = 2;
4515 parOpts[fc1] = ParallelTransformKind::Model;
4516 parOpts[relu1] = ParallelTransformKind::Model;
4517 parOpts[fc2] = ParallelTransformKind::Model;
4518 parOpts[relu2] = ParallelTransformKind::Model;
4519 std::unordered_map<Node *, ConcatNode *> replacedMap;
4520 ASSIGN_VALUE_OR_FAIL_TEST(replacedMap,
4521 ::glow::parallelizeOps(F_, numChunks, parOpts));
4522 EXPECT_EQ(replacedMap.size(), parOpts.size());
4523
4524 runDCEPass(F_, cctx_);
4525
4526 EXPECT_EQ(4, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
4527 EXPECT_EQ(4, countNodeKind(F_, Kinded::Kind::ReluNodeKind));
4528
4529 checkNumericalEquivalence();
4530 }
4531
4532 /// Test Splitting Add into multiple Adds.
TEST_F(GraphOptz,ParallelizeGraph_Add)4533 TEST_F(GraphOptz, ParallelizeGraph_Add) {
4534 auto *input1 =
4535 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input1", false);
4536 bindings_.allocate(input1)->getHandle<float>().randomize(-1.0, 1.0,
4537 mod_.getPRNG());
4538 auto *input2 =
4539 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "input2", false);
4540 bindings_.allocate(input2)->getHandle<float>().randomize(-1.0, 1.0,
4541 mod_.getPRNG());
4542 auto *output =
4543 mod_.createPlaceholder(ElemKind::FloatTy, {32, 2048}, "output", false);
4544 bindings_.allocate(output);
4545
4546 auto *add1 = F_->createAdd("add1", input1, input2);
4547 auto *add2 = F_->createAdd("add2", add1, add1);
4548 F_->createSave("save", add2, output);
4549
4550 ::glow::optimize(F_, CompilationMode::Infer);
4551
4552 // This is F_ but without the parallel transformation below.
4553 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4554
4555 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
4556 parOpts[add1] = ParallelTransformKind::Data;
4557
4558 std::unordered_map<Node *, ConcatNode *> replacedMap;
4559 ASSIGN_VALUE_OR_FAIL_TEST(
4560 replacedMap, ::glow::parallelizeOps(F_, llvm::DenseMap<Node *, size_t>(),
4561 parOpts, 12));
4562 EXPECT_EQ(replacedMap.size(), parOpts.size());
4563 runDCEPass(F_, cctx_);
4564
4565 // We now have 12 Adds from add1, as well as the original add2 which is
4566 // unchanged.
4567 EXPECT_EQ(13, countNodeKind(F_, Kinded::Kind::AddNodeKind));
4568
4569 // Each input of the 12 Adds are sliced.
4570 EXPECT_EQ(24, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
4571
4572 // One concat to bring all of the parallelized sliced Adds together.
4573 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ConcatNodeKind));
4574
4575 checkNumericalEquivalence();
4576 }
4577
4578 /// Test Splitting Transpose into multiple Transposes.
TEST_F(GraphOptz,ParallelizeGraph_Transpose)4579 TEST_F(GraphOptz, ParallelizeGraph_Transpose) {
4580 auto *input =
4581 mod_.createPlaceholder(ElemKind::FloatTy, {32, 151, 64}, "input", false);
4582 bindings_.allocate(input)->getHandle<float>().randomize(-1.0, 1.0,
4583 mod_.getPRNG());
4584 auto *output =
4585 mod_.createPlaceholder(ElemKind::FloatTy, {32, 64, 151}, "output", false);
4586 bindings_.allocate(output);
4587
4588 auto *trans1 = F_->createTranspose("trans1", input, {0, 2, 1});
4589 F_->createSave("save", trans1, output);
4590
4591 ::glow::optimize(F_, CompilationMode::Infer);
4592
4593 // This is F_ but without the parallel transformation below.
4594 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
4595
4596 llvm::DenseMap<Node *, size_t> numChunks;
4597 llvm::DenseMap<Node *, ParallelTransformKind> parOpts;
4598 numChunks[trans1] = 2;
4599 parOpts[trans1] = ParallelTransformKind::Data;
4600 std::unordered_map<Node *, ConcatNode *> replacedMap;
4601 ASSIGN_VALUE_OR_FAIL_TEST(replacedMap,
4602 ::glow::parallelizeOps(F_, numChunks, parOpts));
4603 EXPECT_EQ(replacedMap.size(), parOpts.size());
4604
4605 runDCEPass(F_, cctx_);
4606
4607 EXPECT_EQ(2, countNodeKind(F_, Kinded::Kind::TransposeNodeKind));
4608
4609 checkNumericalEquivalence();
4610 }
4611
TEST_F(GraphOptz,SinkClipBelowReshape)4612 TEST_F(GraphOptz, SinkClipBelowReshape) {
4613 Placeholder *in =
4614 mod_.createPlaceholder(ElemKind::FloatTy, {10}, "input", false);
4615 ClipNode *clip = F_->createClip("clip", in, 0.2, 0.8);
4616 ReshapeNode *reshape = F_->createReshape("reshape", clip, {2, 5});
4617 SaveNode *save = F_->createSave("save", reshape);
4618
4619 optimizedF_ = optimizeFunction(F_);
4620
4621 // Same number of nodes, just swapped order.
4622 EXPECT_EQ(F_->getNodes().size(), 3);
4623 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
4624
4625 const SaveNode *optSave =
4626 findFunctionNodeByName<SaveNode>(optimizedF_, save->getName());
4627 ASSERT_TRUE(optSave);
4628 ClipNode *newClip = llvm::dyn_cast<ClipNode>(optSave->getInput());
4629 ASSERT_TRUE(newClip);
4630 ReshapeNode *newReshape = llvm::dyn_cast<ReshapeNode>(newClip->getInput());
4631 ASSERT_TRUE(newReshape);
4632 EXPECT_EQ(newReshape->getResult().dims(), reshape->getResult().dims());
4633
4634 bindings_.allocate(mod_.getPlaceholders());
4635 bindings_.get(in)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
4636 checkNumericalEquivalence();
4637 }
4638
4639 /// Test that Add after ConvTranspose is folded into Bias add when the actual
4640 /// Add is is a broadcast of the bias. Test \p RnL (right of left) side add.
foldConvTransposeAddIntoBiasAdd(PlaceholderBindings & bindings,Module & mod,Function * F,Function * & optF,bool RnL)4641 static void foldConvTransposeAddIntoBiasAdd(PlaceholderBindings &bindings,
4642 Module &mod, Function *F,
4643 Function *&optF, bool RnL) {
4644 dim_t batch = 2;
4645 dim_t inC = 2;
4646 dim_t outC = 5;
4647 dim_t inH = 3;
4648 dim_t inW = 3;
4649 unsigned_t kernel = 3;
4650 std::vector<uint32_t> pads = {0, 0, 0, 0};
4651 std::vector<uint32_t> stride = {1, 1};
4652
4653 auto *input = mod.createPlaceholder(ElemKind::FloatTy, {2, inH, inW, inC},
4654 "input", false);
4655 auto *filter = mod.createPlaceholder(
4656 ElemKind::FloatTy, {outC, kernel, kernel, inC}, "filter", false);
4657
4658 auto *bias = mod.createConstant(ElemKind::FloatTy, {outC}, "bias");
4659 bias->getPayloadMutable().getHandle<float>() = {1, 3, 5, 7, 9};
4660
4661 std::pair<dim_t, dim_t> outHW = calculateConvTransposeOutputDims(
4662 inH, inW, {kernel, kernel}, stride, pads);
4663 auto outTy = mod.uniqueType(ElemKind::FloatTy,
4664 {batch, outHW.first, outHW.second, outC});
4665
4666 ConvTransposeNode *CTN =
4667 F->createConvTranspose("ConvTranspose", input, filter, bias, outTy,
4668 {kernel, kernel}, stride, {0, 0, 0, 0}, 1);
4669
4670 auto *CN = mod.createConstant(ElemKind::FloatTy,
4671 {batch, outHW.first, outHW.second, outC}, "c1");
4672 auto *AN = RnL ? F->createAdd("add", CN, CTN) : F->createAdd("add", CTN, CN);
4673
4674 CN->getPayloadMutable().getHandle<float>() = {
4675 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3,
4676 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1,
4677 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4,
4678 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2,
4679 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5,
4680 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3,
4681 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1,
4682 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4,
4683 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2,
4684 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5,
4685 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5};
4686
4687 SaveNode *save = F->createSave("save", AN);
4688 bindings.allocate(save->getPlaceholder());
4689
4690 EXPECT_EQ(F->getNodes().size(), 3);
4691 optF = optimizeFunction(F);
4692 EXPECT_EQ(optF->getNodes().size(), 2);
4693
4694 const SaveNode *optSave =
4695 findFunctionNodeByName<SaveNode>(optF, save->getName());
4696
4697 ConvTransposeNode *optCN =
4698 llvm::dyn_cast<ConvTransposeNode>(optSave->getInput());
4699 EXPECT_TRUE(optCN);
4700
4701 Constant *optBias = llvm::dyn_cast<Constant>(optCN->getBias());
4702 EXPECT_TRUE(optBias);
4703
4704 auto BH = optBias->getPayload().getHandle();
4705 EXPECT_EQ(BH.raw(0), 1 + 1);
4706 EXPECT_EQ(BH.raw(1), 2 + 3);
4707 EXPECT_EQ(BH.raw(2), 3 + 5);
4708 EXPECT_EQ(BH.raw(3), 4 + 7);
4709 EXPECT_EQ(BH.raw(4), 5 + 9);
4710
4711 bindings.allocate(mod.getPlaceholders());
4712 bindings.get(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
4713 bindings.get(filter)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
4714 }
4715
4716 /// Test that Add after ConvTranspose is folded into Bias add when the actual
4717 /// Add is is a broadcast of the bias.
TEST_F(GraphOptz,FoldConvTransposeAddIntoBiasAddRHS)4718 TEST_F(GraphOptz, FoldConvTransposeAddIntoBiasAddRHS) {
4719 foldConvTransposeAddIntoBiasAdd(bindings_, mod_, F_, optimizedF_, false);
4720 checkNumericalEquivalence();
4721 }
TEST_F(GraphOptz,FoldConvTransposeAddIntoBiasAddLHS)4722 TEST_F(GraphOptz, FoldConvTransposeAddIntoBiasAddLHS) {
4723 foldConvTransposeAddIntoBiasAdd(bindings_, mod_, F_, optimizedF_, true);
4724 checkNumericalEquivalence();
4725 }
4726
4727 /// Test that MatMul + Add is folded into FullyConnected.
TEST_F(GraphOptz,FoldMatMulAddIntoFullyConnected)4728 TEST_F(GraphOptz, FoldMatMulAddIntoFullyConnected) {
4729
4730 auto *input =
4731 mod_.createPlaceholder(ElemKind::FloatTy, {1, 3}, "input", false);
4732 auto *weights =
4733 mod_.createPlaceholder(ElemKind::FloatTy, {3, 5}, "weights", false);
4734 auto *bias = mod_.createPlaceholder(ElemKind::FloatTy, {1, 5}, "bias", false);
4735
4736 MatMulNode *matmul = F_->createMatMul("matmul", input, weights);
4737 AddNode *add = F_->createAdd("add", matmul, bias);
4738 F_->createSave("save", add);
4739 EXPECT_EQ(3, F_->getNodes().size());
4740
4741 // The folding should replace the MatMul + Add into a FullyConnected and a
4742 // Reshape to 1D for the Bias.
4743 CompilationContext cctx;
4744 ::glow::fold(F_, cctx);
4745 EXPECT_EQ(3, F_->getNodes().size());
4746 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::AddNodeKind));
4747 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::MatMulNodeKind));
4748 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
4749 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReshapeNodeKind));
4750 }
4751
4752 /// Test that batched MatMul + Add is folded into batched FullyConnected.
4753 /// This optimization takes place only if the Bias is constant and the
4754 /// bias data repeats for all the batches.
TEST_F(GraphOptz,FoldMatMulAddIntoFullyConnectedBatched)4755 TEST_F(GraphOptz, FoldMatMulAddIntoFullyConnectedBatched) {
4756
4757 auto *input =
4758 mod_.createPlaceholder(ElemKind::FloatTy, {2, 3}, "input", false);
4759 auto *weights =
4760 mod_.createPlaceholder(ElemKind::FloatTy, {3, 5}, "weights", false);
4761 auto *bias = mod_.createConstant(ElemKind::FloatTy, {2, 5}, "bias");
4762 auto biasH = bias->getPayloadMutable().getHandle<float>();
4763 biasH = {1, 2, 3, 4, 5, 1, 2, 3, 4, 5};
4764
4765 MatMulNode *matmul = F_->createMatMul("matmul", input, weights);
4766 AddNode *add = F_->createAdd("add", matmul, bias);
4767 F_->createSave("save", add);
4768 EXPECT_EQ(3, F_->getNodes().size());
4769
4770 // The folding should replace the MatMul + Add into a FullyConnected and a
4771 // Reshape to 1D for the Bias.
4772 CompilationContext cctx;
4773 ::glow::fold(F_, cctx);
4774 EXPECT_EQ(4, F_->getNodes().size());
4775 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::AddNodeKind));
4776 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::MatMulNodeKind));
4777 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
4778 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::SliceNodeKind));
4779 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::ReshapeNodeKind));
4780 }
4781
4782 /// Test that FoldSlicesIntoConstants pass works as expected.
TEST_F(GraphOptz,FoldSlicesIntoConstantsTest)4783 TEST_F(GraphOptz, FoldSlicesIntoConstantsTest) {
4784 Constant *C = mod_.createConstant(ElemKind::FloatTy, {3, 4}, "C");
4785 auto CH = C->getPayloadMutable().getHandle<float>();
4786 CH = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
4787
4788 SliceNode *S1 = F_->createSlice("s1", C, {0, 0}, {3, 2});
4789 SliceNode *S2 = F_->createSlice("s2", C, {0, 2}, {3, 4});
4790 SaveNode *SN1 = F_->createSave("save1", S1);
4791 SaveNode *SN2 = F_->createSave("save2", S2);
4792
4793 optimizedF_ = optimizeFunction(
4794 F_, {FunctionPassID::FoldSlicesIntoConstants, getDCEPassConfig()});
4795
4796 SaveNode *optSN1 =
4797 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN1->getName()));
4798 SaveNode *optSN2 =
4799 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN2->getName()));
4800 ASSERT_TRUE(optSN1);
4801 ASSERT_TRUE(optSN2);
4802
4803 Constant *C1 = llvm::dyn_cast<Constant>(optSN1->getInput());
4804 ASSERT_TRUE(C1);
4805 auto H1 = C1->getPayloadMutable().getHandle();
4806 Constant *C2 = llvm::dyn_cast<Constant>(optSN2->getInput());
4807 ASSERT_TRUE(C2);
4808 auto H2 = C2->getPayloadMutable().getHandle();
4809 for (dim_t i = 0, e = 3; i < e; i++) {
4810 for (dim_t j = 0, e = 2; j < e; j++) {
4811 EXPECT_EQ(H1.at({i, j}), CH.at({i, j}));
4812 EXPECT_EQ(H2.at({i, j}), CH.at({i, j + 2}));
4813 }
4814 }
4815 }
4816
4817 /// Test that RaiseClipsAboveShapeNodes pass works as expected.
TEST_F(GraphOptz,RaiseClipsAboveShapeNodesTest)4818 TEST_F(GraphOptz, RaiseClipsAboveShapeNodesTest) {
4819 Placeholder *input =
4820 mod_.createPlaceholder(ElemKind::FloatTy, {256, 64}, "input", false);
4821
4822 ReshapeNode *RN1 = F_->createReshape("reshape1", input, {4, 128, 32});
4823 ReshapeNode *RN2 = F_->createReshape("reshape2", RN1, {64, 256});
4824 TransposeNode *TN = F_->createTranspose("transpose", RN2, {1, 0});
4825 SliceNode *SN = F_->createSlice("slice", TN, {64, 0}, {256, 64});
4826 ClipNode *CN = F_->createClip("clip", SN, -0.1, 0.1);
4827 SaveNode *save1 = F_->createSave("save1", RN1);
4828 SaveNode *save2 = F_->createSave("save2", CN);
4829
4830 optimizedF_ =
4831 optimizeFunction(F_, {FunctionPassID::RaiseClipsAboveShapeNodes});
4832
4833 SaveNode *optSave1 =
4834 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save1->getName()));
4835 ASSERT_TRUE(optSave1);
4836 SaveNode *optSave2 =
4837 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(save2->getName()));
4838 ASSERT_TRUE(optSave2);
4839
4840 // save1 should only have a single untouched Reshape RN1 input which has input
4841 // input into it, because RN1 has multiple users.
4842 ReshapeNode *optRN1 =
4843 llvm::dyn_cast<ReshapeNode>(optSave1->getInput().getNode());
4844 ASSERT_TRUE(optRN1);
4845 EXPECT_EQ(input, optRN1->getInput().getNode());
4846
4847 // save2 should have CN it originally saved pushed up above SN, TN, and RN2.
4848 SliceNode *newSN = llvm::dyn_cast<SliceNode>(optSave2->getInput());
4849 ASSERT_TRUE(newSN);
4850 EXPECT_EQ(newSN->getStart(), SN->getStart());
4851 TransposeNode *newTN = llvm::dyn_cast<TransposeNode>(newSN->getInput());
4852 ASSERT_TRUE(newTN);
4853 EXPECT_EQ(newTN->getShuffle(), TN->getShuffle());
4854 ReshapeNode *newRN2 = llvm::dyn_cast<ReshapeNode>(newTN->getInput());
4855 ASSERT_TRUE(newRN2);
4856 ClipNode *newCN = llvm::dyn_cast<ClipNode>(newRN2->getInput());
4857 ASSERT_TRUE(newCN);
4858 EXPECT_EQ(newCN->getMin(), CN->getMin());
4859 EXPECT_EQ(newCN->getMax(), CN->getMax());
4860
4861 bindings_.allocate(mod_.getPlaceholders());
4862 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
4863 checkNumericalEquivalence();
4864 }
4865
testOptimizeDequantizeClip(PlaceholderBindings & bindings,Module & mod,Function * F,Function * & optF,bool enableQuantParamChanges)4866 static void testOptimizeDequantizeClip(PlaceholderBindings &bindings,
4867 Module &mod, Function *F,
4868 Function *&optF,
4869 bool enableQuantParamChanges) {
4870 Placeholder *input =
4871 mod.createPlaceholder(ElemKind::FloatTy, {20, 20}, "input", false);
4872
4873 const auto qParams = quantization::chooseQuantizationParams({-0.1, 0.1});
4874
4875 QuantizeNode *QN =
4876 F->createQuantize("quantize", input,
4877 mod.uniqueType(ElemKind::Int8QTy, {20, 20},
4878 qParams.scale, qParams.offset));
4879 DequantizeNode *DN = F->createDequantize("dequantize", QN, ElemKind::FloatTy);
4880 ClipNode *CN =
4881 F->createClip("clip", DN, enableQuantParamChanges ? 0 : -100, 100);
4882 SaveNode *SN = F->createSave("save", CN);
4883
4884 CompilationContext cctx;
4885 cctx.optimizationOpts.enableQuantParamChanges = true;
4886 optF = optimizeFunction(
4887 F, {FunctionPassID::OptimizeQuantizeClip, getDCEPassConfig()}, cctx);
4888
4889 EXPECT_EQ(countNodeKind(optF, Kinded::Kind::ClipNodeKind), 0);
4890
4891 SaveNode *optSN =
4892 llvm::dyn_cast<SaveNode>(optF->getNodeByName(SN->getName()));
4893 ASSERT_TRUE(optSN);
4894
4895 // Now check that the quantization params have been correctly updated for QN,
4896 // and that CN has been eliminated.
4897 DequantizeNode *optDN =
4898 llvm::dyn_cast<DequantizeNode>(optSN->getInput().getNode());
4899 ASSERT_TRUE(optDN);
4900 const auto qMinMax = optDN->getInput().getType()->getQuantizedValueRange();
4901 // Min is either from Clip or Quant range depending on enableQuantParamChanges
4902 EXPECT_NEAR(qMinMax.first, enableQuantParamChanges ? 0 : -0.1, 1E-3);
4903 EXPECT_NEAR(qMinMax.second, 0.1, 1E-3); // Max from Quant range
4904
4905 bindings.allocate(mod.getPlaceholders());
4906 bindings.get(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
4907 }
4908
4909 /// Test that OptimizeQuantizeClip pass works as expected for Clip(Dequantize)
4910 /// when the quantization parameters are allowed to change.
TEST_F(GraphOptz,OptimizeDequantizeClipTest_QuantParamChanges)4911 TEST_F(GraphOptz, OptimizeDequantizeClipTest_QuantParamChanges) {
4912 testOptimizeDequantizeClip(bindings_, mod_, F_, optimizedF_,
4913 /* enableQuantParamChanges */ true);
4914 checkNumericalEquivalence(0.0005);
4915 }
4916
4917 /// Test that OptimizeQuantizeClip pass works as expected for Clip(Dequantize)
4918 /// when the quantization parameters are not allowed to change.
TEST_F(GraphOptz,OptimizeDequantizeClipTest_NoQuantParamChanges)4919 TEST_F(GraphOptz, OptimizeDequantizeClipTest_NoQuantParamChanges) {
4920 testOptimizeDequantizeClip(bindings_, mod_, F_, optimizedF_,
4921 /* enableQuantParamChanges */ false);
4922 checkNumericalEquivalence();
4923 }
4924
testOptimizeClipQuantize(PlaceholderBindings & bindings,Module & mod,Function * F,Function * & optF,bool enableQuantParamChanges)4925 static void testOptimizeClipQuantize(PlaceholderBindings &bindings, Module &mod,
4926 Function *F, Function *&optF,
4927 bool enableQuantParamChanges) {
4928 Placeholder *input =
4929 mod.createPlaceholder(ElemKind::FloatTy, {20, 20}, "input", false);
4930
4931 const auto qParams = quantization::chooseQuantizationParams({-0.1, 0.1});
4932
4933 ClipNode *CN =
4934 F->createClip("clip", input, enableQuantParamChanges ? 0 : -100, 100);
4935 QuantizeNode *QN =
4936 F->createQuantize("quantize", CN,
4937 mod.uniqueType(ElemKind::Int8QTy, {20, 20},
4938 qParams.scale, qParams.offset));
4939 DequantizeNode *DN = F->createDequantize("dequantize", QN, ElemKind::FloatTy);
4940 SaveNode *SN = F->createSave("save", DN);
4941
4942 CompilationContext cctx;
4943 cctx.optimizationOpts.enableQuantParamChanges = enableQuantParamChanges;
4944 optF = optimizeFunction(
4945 F, {FunctionPassID::OptimizeQuantizeClip, getDCEPassConfig()}, cctx);
4946
4947 EXPECT_EQ(countNodeKind(optF, Kinded::Kind::ClipNodeKind), 0);
4948
4949 SaveNode *optSN =
4950 llvm::dyn_cast<SaveNode>(optF->getNodeByName(SN->getName()));
4951 ASSERT_TRUE(optSN);
4952
4953 // Now check that the quantization params have been correctly updated for QN,
4954 // and that CN has been eliminated.
4955 DequantizeNode *optDN =
4956 llvm::dyn_cast<DequantizeNode>(optSN->getInput().getNode());
4957 ASSERT_TRUE(optDN);
4958 const auto qMinMax = optDN->getInput().getType()->getQuantizedValueRange();
4959 // Min is either from Clip or Quant range depending on enableQuantParamChanges
4960 EXPECT_NEAR(qMinMax.first, enableQuantParamChanges ? 0 : -0.1, 1E-3);
4961 EXPECT_NEAR(qMinMax.second, 0.1, 1E-3); // Max always from Quant range
4962
4963 bindings.allocate(mod.getPlaceholders());
4964 bindings.get(input)->getHandle().randomize(-1.0, 1.0, mod.getPRNG());
4965 }
4966
4967 /// Test that OptimizeQuantizeClip pass works as expected for Clip(Quantize)
4968 /// when the quantization parameters are allowed to change.
TEST_F(GraphOptz,OptimizeClipQuantizeTest_QuantParamChanges)4969 TEST_F(GraphOptz, OptimizeClipQuantizeTest_QuantParamChanges) {
4970 testOptimizeClipQuantize(bindings_, mod_, F_, optimizedF_,
4971 /* enableQuantParamChanges */ true);
4972 checkNumericalEquivalence(0.0005);
4973 }
4974
4975 /// Test that OptimizeQuantizeClip pass works as expected for Clip(Quantize)
4976 /// when the quantization parameters are not allowed to change.
TEST_F(GraphOptz,OptimizeClipQuantizeTest_NoQuantParamChanges)4977 TEST_F(GraphOptz, OptimizeClipQuantizeTest_NoQuantParamChanges) {
4978 testOptimizeClipQuantize(bindings_, mod_, F_, optimizedF_,
4979 /* enableQuantParamChanges */ false);
4980 checkNumericalEquivalence();
4981 }
4982
4983 /// Test Quantize(ConvertTo(Node)) -> Quantize(Node), where Quantize is int8.
TEST_F(GraphOptz,OptimizeOutIntermediateConversionsTest)4984 TEST_F(GraphOptz, OptimizeOutIntermediateConversionsTest) {
4985 Placeholder *input =
4986 mod_.createPlaceholder(ElemKind::FloatTy, {20, 20}, "input", false);
4987
4988 const auto qParams = quantization::chooseQuantizationParams({-0.1, 0.1});
4989
4990 ConvertToNode *CN = F_->createConvertTo("conv", input, ElemKind::Float16Ty);
4991 QuantizeNode *QN =
4992 F_->createQuantize("quantize", CN,
4993 mod_.uniqueType(ElemKind::Int8QTy, {20, 20},
4994 qParams.scale, qParams.offset));
4995 DequantizeNode *DN =
4996 F_->createDequantize("dequantize", QN, ElemKind::FloatTy);
4997 F_->createSave("save", DN);
4998
4999 optimizedF_ =
5000 optimizeFunction(F_, {FunctionPassID::OptimizeOutIntermediateConversions,
5001 getDCEPassConfig()});
5002
5003 // Now check that the ConvertToNode has been eliminated.
5004 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConvertToNodeKind), 0);
5005
5006 bindings_.allocate(mod_.getPlaceholders());
5007 bindings_.get(input)->getHandle().randomize(-1.0, 1.0, mod_.getPRNG());
5008 checkNumericalEquivalence();
5009 }
5010
5011 /// Test Clip(Relu(Clip)) -> Clip'.
TEST_F(GraphOptz,ClipReluClipElimTest)5012 TEST_F(GraphOptz, ClipReluClipElimTest) {
5013 Placeholder *input =
5014 mod_.createPlaceholder(ElemKind::FloatTy, {64, 64}, "input", false);
5015 ClipNode *CN1 = F_->createClip("CN1", input, -10, 30);
5016 ReluNode *RN = F_->createRELU("RN", CN1);
5017 ClipNode *CN2 = F_->createClip("CN2", RN, -5, 20);
5018 SaveNode *SN = F_->createSave("save", CN2);
5019
5020 // Start with 2 clips, a relu, and a save.
5021 EXPECT_EQ(F_->getNodes().size(), 4);
5022 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ClipNodeKind), 2);
5023 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 1);
5024
5025 optimizedF_ = optimizeFunction(F_);
5026
5027 // Remove one of the clips and the relu.
5028 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
5029 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ClipNodeKind), 1);
5030 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ReluNodeKind), 0);
5031
5032 SaveNode *optSN =
5033 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5034 ASSERT_TRUE(optSN);
5035
5036 // We combined all of the ranges into the single Clip.
5037 ClipNode *optCN = llvm::dyn_cast<ClipNode>(optSN->getInput());
5038 ASSERT_TRUE(optCN);
5039 EXPECT_EQ(optCN->getMin(), 0);
5040 EXPECT_EQ(optCN->getMax(), 20);
5041
5042 bindings_.allocate(input)->getHandle().randomize(-50.0, 5.0, mod_.getPRNG());
5043 checkNumericalEquivalence();
5044 }
5045
5046 /// Test that we can find a non-quantized relu and fuse it up into a quant FC.
TEST_F(GraphOptz,OptimizeQuantFCFloatReluTest)5047 TEST_F(GraphOptz, OptimizeQuantFCFloatReluTest) {
5048 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32}, 1.0, 0,
5049 "input", false);
5050 auto *weights =
5051 mod_.createConstant(ElemKind::Int8QTy, {32, 32}, 1.0, 0, "weights");
5052 auto *bias = mod_.createConstant(ElemKind::Int32QTy, {32}, 1.0, 0, "bias");
5053
5054 auto *FC = F_->createFullyConnected("fc", input, weights, bias);
5055 auto *DN = F_->createDequantize("dq", FC, ElemKind::FloatTy);
5056 auto *RN = F_->createRELU("relu", DN);
5057 auto *SN = F_->createSave("save", RN);
5058
5059 optimizedF_ = optimizeFunction(
5060 F_, {FunctionPassID::OptimizeQuantFCFloatRelu, getDCEPassConfig()});
5061
5062 SaveNode *optSN =
5063 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5064 ASSERT_TRUE(optSN);
5065
5066 DequantizeNode *optDN = llvm::dyn_cast<DequantizeNode>(optSN->getInput());
5067 ASSERT_TRUE(optDN);
5068 ReluNode *optRN = llvm::dyn_cast<ReluNode>(optDN->getInput());
5069 ASSERT_TRUE(optRN);
5070 auto rangeRN = optRN->getResult().getType()->getQuantizedValueRange();
5071 EXPECT_EQ(rangeRN.first, 0.0f);
5072 FullyConnectedNode *optFC =
5073 llvm::dyn_cast<FullyConnectedNode>(optRN->getInput());
5074 ASSERT_TRUE(optFC);
5075 auto rangeFC = optFC->getResult().getType()->getQuantizedValueRange();
5076 EXPECT_EQ(rangeRN.second, rangeFC.second);
5077
5078 bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127,
5079 mod_.getPRNG());
5080 weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127,
5081 mod_.getPRNG());
5082 bias->getPayloadMutable().getHandle<int32_t>().randomize(-128, 127,
5083 mod_.getPRNG());
5084 checkNumericalEquivalence();
5085 }
5086
5087 /// Test that we can find a non-quantized relu and fuse it up into a series of
5088 /// concatenated quant FCs.
TEST_F(GraphOptz,OptimizeConcatQuantFCFloatReluTest)5089 TEST_F(GraphOptz, OptimizeConcatQuantFCFloatReluTest) {
5090 std::array<NodeValue, 5> DQs;
5091 for (size_t i = 0; i < 5; i++) {
5092 auto *input = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32},
5093 1.0 / (i + 1), 0, "input", false);
5094 auto *weights =
5095 mod_.createConstant(ElemKind::Int8QTy, {32, 32}, 1.0, 0, "weights");
5096 auto *bias = mod_.createConstant(ElemKind::Int32QTy, {32}, 1.0, 0, "bias");
5097
5098 auto *FC = F_->createFullyConnected("fc", input, weights, bias);
5099 DQs[i] = F_->createDequantize("dq", FC, ElemKind::FloatTy)->getResult();
5100
5101 bindings_.allocate(input)->getHandle<int8_t>().randomize(-128, 127,
5102 mod_.getPRNG());
5103 weights->getPayloadMutable().getHandle<int8_t>().randomize(-128, 127,
5104 mod_.getPRNG());
5105 bias->getPayloadMutable().getHandle<int32_t>().randomize(-128, 127,
5106 mod_.getPRNG());
5107 }
5108
5109 auto *CN = F_->createConcat("concat", DQs, 0);
5110 auto *RN = F_->createRELU("relu", CN);
5111 auto *SN = F_->createSave("save", RN);
5112
5113 optimizedF_ = optimizeFunction(
5114 F_, {FunctionPassID::OptimizeQuantFCFloatRelu, getDCEPassConfig()});
5115
5116 SaveNode *optSN =
5117 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5118 ASSERT_TRUE(optSN);
5119 ConcatNode *optCN = llvm::dyn_cast<ConcatNode>(optSN->getInput());
5120 ASSERT_TRUE(optCN);
5121 EXPECT_EQ(optCN->getInputs().size(), 5);
5122
5123 for (const NodeValue NV : optCN->getInputs()) {
5124 DequantizeNode *optDN = llvm::dyn_cast<DequantizeNode>(NV);
5125 ASSERT_TRUE(optDN);
5126 ReluNode *optRN = llvm::dyn_cast<ReluNode>(optDN->getInput());
5127 ASSERT_TRUE(optRN);
5128 auto rangeRN = optRN->getResult().getType()->getQuantizedValueRange();
5129 EXPECT_EQ(rangeRN.first, 0.0f);
5130 FullyConnectedNode *optFC =
5131 llvm::dyn_cast<FullyConnectedNode>(optRN->getInput());
5132 ASSERT_TRUE(optFC);
5133 auto rangeFC = optFC->getResult().getType()->getQuantizedValueRange();
5134 EXPECT_EQ(rangeRN.second, rangeFC.second);
5135 }
5136
5137 checkNumericalEquivalence();
5138 }
5139
5140 /// Test that we can find a concat with all dequantize inputs and a quantize at
5141 /// its output, and then replace quant/dequants with rescales.
TEST_F(GraphOptz,OptimizeDequantConcatQuant)5142 TEST_F(GraphOptz, OptimizeDequantConcatQuant) {
5143 std::array<NodeValue, 5> DQs;
5144 std::array<Placeholder *, 5> inputs;
5145 for (size_t i = 0; i < 5; i++) {
5146 inputs[i] = mod_.createPlaceholder(ElemKind::Int8QTy, {2, 32},
5147 0.3 / (i + 1), 5, "input", false);
5148 DQs[i] =
5149 F_->createDequantize("dq", inputs[i], ElemKind::FloatTy)->getResult();
5150
5151 bindings_.allocate(inputs[i])->getHandle<int8_t>().randomize(
5152 -128, 127, mod_.getPRNG());
5153 }
5154
5155 auto *CN = F_->createConcat("concat", DQs, 0);
5156 constexpr float scale = 0.3;
5157 constexpr int32_t offset = 5;
5158 auto *RN = F_->createQuantize("quantize", CN,
5159 mod_.uniqueType(ElemKind::Int8QTy,
5160 CN->getResult().dims(), scale,
5161 offset));
5162 auto *SN = F_->createSave("save", RN);
5163
5164 optimizedF_ = optimizeFunction(
5165 F_, {FunctionPassID::OptimizeConcatQuantization, getDCEPassConfig()});
5166
5167 SaveNode *optSN =
5168 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5169 ASSERT_TRUE(optSN);
5170 ConcatNode *optCN = llvm::dyn_cast<ConcatNode>(optSN->getInput());
5171 ASSERT_TRUE(optCN);
5172 EXPECT_EQ(optCN->getInputs().size(), 5);
5173
5174 for (size_t i = 0, e = optCN->getInputs().size(); i < e; i++) {
5175 const NodeValue NV = optCN->getInputs()[i];
5176 if (i == 0) {
5177 EXPECT_EQ(inputs[i], NV.getNode());
5178 EXPECT_EQ(inputs[i]->getOutput().getType()->getScale(), scale);
5179 EXPECT_EQ(inputs[i]->getOutput().getType()->getOffset(), offset);
5180 } else {
5181 RescaleQuantizedNode *optRN = llvm::dyn_cast<RescaleQuantizedNode>(NV);
5182 ASSERT_TRUE(optRN);
5183 EXPECT_EQ(optRN->getResult().getType()->getScale(), scale);
5184 EXPECT_EQ(optRN->getResult().getType()->getOffset(), offset);
5185 EXPECT_EQ(inputs[i], optRN->getInput().getNode());
5186 }
5187 }
5188 checkNumericalEquivalence();
5189 }
5190
5191 /// Test that if we have a Concat with all Dequantize inputs with the same
5192 /// scale/offset/kind that we can sink the Dequantizes below the Concat.
TEST_F(GraphOptz,SinkDequantizeBelowConcatTest)5193 TEST_F(GraphOptz, SinkDequantizeBelowConcatTest) {
5194 const float scale = 0.06;
5195 const int32_t offset = -15;
5196 std::array<NodeValue, 5> inputs;
5197 for (dim_t i = 0; i < 5; i++) {
5198 Placeholder *input = mod_.createPlaceholder(ElemKind::Int8QTy, {i + 1, 100},
5199 scale, offset, "input", false);
5200 bindings_.allocate(input)->getHandle<int8_t>().randomize(-100, 100,
5201 mod_.getPRNG());
5202 DequantizeNode *dequantize =
5203 F_->createDequantize("dequantize", input, ElemKind::Float16Ty);
5204 inputs[i] = dequantize->getResult();
5205 }
5206 ConcatNode *concat = F_->createConcat("concat", inputs, 0);
5207 SaveNode *SN = F_->createSave("ret", concat);
5208
5209 optimizedF_ = optimizeFunction(
5210 F_, {FunctionPassID::SinkConversions, getDCEPassConfig()});
5211
5212 // Concat, dequantize, save.
5213 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
5214 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::DequantizeNodeKind), 1);
5215 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
5216 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1);
5217
5218 SaveNode *optSN =
5219 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5220 ASSERT_TRUE(optSN);
5221 DequantizeNode *optDequantize =
5222 llvm::dyn_cast<DequantizeNode>(optSN->getInput());
5223 ASSERT_TRUE(optDequantize);
5224 NodeValue input = optDequantize->getInput();
5225 EXPECT_EQ(scale, input.getType()->getScale());
5226 EXPECT_EQ(offset, input.getType()->getOffset());
5227 EXPECT_EQ(ElemKind::Int8QTy, input.getType()->getElementType());
5228
5229 // Find dequantize node in the optimized graph.
5230 checkNumericalEquivalence();
5231 }
5232
5233 /// Test that if we have a Concat with all Quantize inputs with the same
5234 /// scale/offset/kind that we can sink the Dequantizes below the Concat.
TEST_F(GraphOptz,SinkQuantizeBelowConcatTest)5235 TEST_F(GraphOptz, SinkQuantizeBelowConcatTest) {
5236 const float scale = 0.06;
5237 const int32_t offset = -15;
5238 std::array<NodeValue, 5> inputs;
5239 for (dim_t i = 0; i < 5; i++) {
5240 Placeholder *input = mod_.createPlaceholder(ElemKind::Float16Ty,
5241 {i + 1, 100}, "input", false);
5242 bindings_.allocate(input)->getHandle<float16_t>().randomize(-100, 100,
5243 mod_.getPRNG());
5244 const TypeRef QTy = mod_.uniqueType(
5245 ElemKind::Int8QTy, input->getOutput().dims(), scale, offset);
5246 QuantizeNode *quantize = F_->createQuantize("quantize", input, QTy);
5247 inputs[i] = quantize->getResult();
5248 }
5249 ConcatNode *concat = F_->createConcat("concat", inputs, 0);
5250 SaveNode *SN = F_->createSave("ret", concat);
5251
5252 optimizedF_ = optimizeFunction(
5253 F_, {FunctionPassID::SinkConversions, getDCEPassConfig()});
5254
5255 // Concat, quantize, save.
5256 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
5257 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::QuantizeNodeKind), 1);
5258 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
5259 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1);
5260
5261 SaveNode *optSN =
5262 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5263 ASSERT_TRUE(optSN);
5264 QuantizeNode *optQuantize = llvm::dyn_cast<QuantizeNode>(optSN->getInput());
5265 ASSERT_TRUE(optQuantize);
5266 EXPECT_EQ(scale, optQuantize->getResult().getType()->getScale());
5267 EXPECT_EQ(offset, optQuantize->getResult().getType()->getOffset());
5268 EXPECT_EQ(ElemKind::Int8QTy,
5269 optQuantize->getResult().getType()->getElementType());
5270
5271 // Find quantize node in the optimized graph.
5272 checkNumericalEquivalence();
5273 }
5274
5275 /// Test Clip(Relu) -> Clip'.
TEST_F(GraphOptz,ClipReluTest)5276 TEST_F(GraphOptz, ClipReluTest) {
5277 Placeholder *input =
5278 mod_.createPlaceholder(ElemKind::Float16Ty, {64, 64}, "input", false);
5279 ReluNode *RN = F_->createRELU("RN", input);
5280 ClipNode *CN = F_->createClip("CN", RN, -5, 20);
5281 SaveNode *SN = F_->createSave("save", CN);
5282
5283 // Start with a clip, a relu, and a save.
5284 EXPECT_EQ(F_->getNodes().size(), 3);
5285 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ClipNodeKind), 1);
5286 EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ReluNodeKind), 1);
5287
5288 optimizedF_ = optimizeFunction(F_);
5289
5290 // Removed the relu
5291 EXPECT_EQ(optimizedF_->getNodes().size(), 2);
5292 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ClipNodeKind), 1);
5293 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ReluNodeKind), 0);
5294
5295 SaveNode *optSN =
5296 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5297 ASSERT_TRUE(optSN);
5298
5299 // We have the same max for clip as before, but 0 for min due to the Relu.
5300 ClipNode *optCN = llvm::dyn_cast<ClipNode>(optSN->getInput());
5301 ASSERT_TRUE(optCN);
5302 EXPECT_EQ(optCN->getMin(), 0);
5303 EXPECT_EQ(optCN->getMax(), 20);
5304
5305 bindings_.allocate(input)->getHandle<float16_t>().randomize(-50.0, 5.0,
5306 mod_.getPRNG());
5307 checkNumericalEquivalence();
5308 }
5309
5310 /// Test that if we have a concat with some dequantize inputs that are
5311 /// concatenated together, and then a quantize after the concat, that we can
5312 /// move the quantize above the concat and eliminate the dequantizes.
TEST_F(GraphOptz,SinkConcatBelowQuantize)5313 TEST_F(GraphOptz, SinkConcatBelowQuantize) {
5314 const float scale = 0.06;
5315 const int32_t offset = -15;
5316 std::array<NodeValue, 3> inputs;
5317
5318 // Concat input 0: Dequant(PH)
5319 const TypeRef in0QTy =
5320 mod_.uniqueType(ElemKind::Int8QTy, {1, 3}, scale, offset);
5321 Placeholder *input0 = mod_.createPlaceholder(in0QTy, "input", false);
5322 inputs[0] =
5323 F_->createDequantize("deq", input0, ElemKind::Float16Ty)->getResult();
5324
5325 // Concat input 1: Dequant(Add(PH, PH))
5326 const TypeRef in1QTy =
5327 mod_.uniqueType(ElemKind::Int8QTy, {5, 3}, scale, offset + 1);
5328 Placeholder *input1 = mod_.createPlaceholder(in1QTy, "input", false);
5329 AddNode *add = F_->createAdd("add", input1, input1);
5330 inputs[1] =
5331 F_->createDequantize("deq", add, ElemKind::Float16Ty)->getResult();
5332
5333 // Concat input 2: PH
5334 Placeholder *input2 =
5335 mod_.createPlaceholder(ElemKind::Float16Ty, {10, 3}, "input_fp", false);
5336 inputs[2] = input2->getOutput();
5337
5338 // Concat all 3 together, all FP16.
5339 ConcatNode *concat = F_->createConcat("concat", inputs, 0);
5340
5341 // Now quantize the result of the concat.
5342 const TypeRef QTy = mod_.uniqueType(
5343 ElemKind::Int8QTy, concat->getResult().dims(), scale, offset);
5344 QuantizeNode *QN = F_->createQuantize("quantize", concat, QTy);
5345 SaveNode *SN = F_->createSave("ret", QN);
5346
5347 optimizedF_ = optimizeFunction(F_, {FunctionPassID::SinkConcatBelowQuantize,
5348 {FunctionPassID::OptimizeQuantization,
5349 ConvergenceMode::UntilFixedPoint},
5350 getDCEPassConfig()});
5351
5352 EXPECT_EQ(optimizedF_->getNodes().size(), 4);
5353 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 1);
5354 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind), 1);
5355 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::QuantizeNodeKind), 1);
5356 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 1);
5357
5358 SaveNode *optSN =
5359 llvm::dyn_cast<SaveNode>(optimizedF_->getNodeByName(SN->getName()));
5360 ASSERT_TRUE(optSN);
5361
5362 // Concat should be directly connected to save, with same quantization
5363 // parameters as the quantize which used to follow it.
5364 ConcatNode *optCN = llvm::dyn_cast<ConcatNode>(optSN->getInput());
5365 ASSERT_TRUE(optCN);
5366 ASSERT_EQ(ElemKind::Int8QTy, optCN->getResult().getType()->getElementType());
5367 EXPECT_EQ(scale, optCN->getResult().getType()->getScale());
5368 EXPECT_EQ(offset, optCN->getResult().getType()->getOffset());
5369
5370 ASSERT_EQ(optCN->getInputs().size(), 3);
5371
5372 // No rescale here for the PH since its scale/offset match the PH and so
5373 // are optimized away.
5374 EXPECT_EQ(optCN->getInputs()[0], input0->getOutput());
5375
5376 // No rescale here because it should be fused into optAN. Check the
5377 // scale/offset use that scale/offset.
5378 AddNode *optAN = llvm::dyn_cast<AddNode>(optCN->getInputs()[1]);
5379 ASSERT_TRUE(optAN);
5380 ASSERT_EQ(ElemKind::Int8QTy, optAN->getResult().getType()->getElementType());
5381 EXPECT_EQ(scale, optAN->getResult().getType()->getScale());
5382 EXPECT_EQ(offset, optAN->getResult().getType()->getOffset());
5383 EXPECT_EQ(optAN->getLHS(), input1->getOutput());
5384 EXPECT_EQ(optAN->getRHS(), input1->getOutput());
5385
5386 // Must quantize this input since the PH is float16.
5387 QuantizeNode *optQN = llvm::dyn_cast<QuantizeNode>(optCN->getInputs()[2]);
5388 ASSERT_TRUE(optQN);
5389 ASSERT_EQ(ElemKind::Int8QTy, optQN->getResult().getType()->getElementType());
5390 EXPECT_EQ(scale, optQN->getResult().getType()->getScale());
5391 EXPECT_EQ(offset, optQN->getResult().getType()->getOffset());
5392 EXPECT_EQ(optQN->getInput(), input2->getOutput());
5393
5394 bindings_.allocate(input0)->getHandle<int8_t>().randomize(-50, 50,
5395 mod_.getPRNG());
5396 bindings_.allocate(input1)->getHandle<int8_t>().randomize(-50, 50,
5397 mod_.getPRNG());
5398 bindings_.allocate(input2)->getHandle<float16_t>().randomize(-10, 10,
5399 mod_.getPRNG());
5400 }
5401
TEST_F(GraphOptz,EliminateSliceConcatTest)5402 TEST_F(GraphOptz, EliminateSliceConcatTest) {
5403 auto *src1 =
5404 mod_.createPlaceholder(ElemKind::FloatTy, {10, 70}, "src1", false);
5405 auto *src2 =
5406 mod_.createPlaceholder(ElemKind::FloatTy, {10, 80}, "src2", false);
5407 auto *A = F_->createSlice("A", src1, {0, 0}, {10, 10});
5408 auto *B = F_->createSlice("B", src1, {0, 10}, {10, 20});
5409 auto *C = F_->createSlice("C", src1, {0, 20}, {10, 30});
5410 // interleaved Slices with different sources shouldn't merge
5411 auto *E = F_->createSlice("E", src1, {0, 30}, {10, 40});
5412 auto *F = F_->createSlice("F", src2, {0, 30}, {10, 40});
5413 auto *G = F_->createSlice("G", src1, {0, 40}, {10, 50});
5414 auto *H = F_->createSlice("H", src2, {0, 40}, {10, 50});
5415
5416 auto *D = mod_.createPlaceholder(ElemKind::FloatTy, {10, 50}, "D", false);
5417 auto *R = F_->createRELU("Relu", C);
5418 auto *CN = F_->createConcat("Concat", {A, B, D, E, F, G, H}, 1);
5419 F_->createSave("save1", CN);
5420 F_->createSave("save2", R);
5421
5422 EXPECT_EQ(F_->getNodes().size(), 11);
5423
5424 optimizedF_ = optimizeFunction(
5425 F_, {FunctionPassID::EliminateSliceConcat, getDCEPassConfig()});
5426
5427 EXPECT_EQ(optimizedF_->getNodes().size(), 10);
5428
5429 int numSlicesToConcat = 0;
5430 for (const auto &node : optimizedF_->getNodes()) {
5431 auto *newCN = llvm::dyn_cast<ConcatNode>(&node);
5432 if (!newCN) {
5433 continue;
5434 }
5435 EXPECT_EQ(newCN->getInputs().size(), 6);
5436 for (const auto &concatInput : newCN->getInputs()) {
5437 auto *SN = llvm::dyn_cast<SliceNode>(concatInput.getNode());
5438 if (SN) {
5439 numSlicesToConcat++;
5440 }
5441 }
5442 }
5443 EXPECT_EQ(numSlicesToConcat, 5);
5444
5445 bindings_.allocate(src1)->getHandle<float>().randomize(-10.0, 10.0,
5446 mod_.getPRNG());
5447 bindings_.allocate(src2)->getHandle<float>().randomize(-10.0, 10.0,
5448 mod_.getPRNG());
5449 bindings_.allocate(D)->getHandle<float>().randomize(-10.0, 10.0,
5450 mod_.getPRNG());
5451 checkNumericalEquivalence();
5452 }
5453
5454 /// Verify that when we want to prevent constant folding it doesn't occur.
TEST_F(GraphOptz,constantFoldPreventedNoop)5455 TEST_F(GraphOptz, constantFoldPreventedNoop) {
5456 auto *const1 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const1");
5457 auto *const2 = mod_.createConstant(ElemKind::FloatTy, {2, 2}, "const2");
5458 auto *ph1 = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2}, "input1",
5459 /* isTrainable */ false);
5460 setConstValue(const1, 1.0f);
5461 setConstValue(const2, 2.0f);
5462 auto *splat2 = F_->createSplat(
5463 "splat2", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 2.0f);
5464 auto *splat3 = F_->createSplat(
5465 "splat3", mod_.uniqueType(ElemKind::FloatTy, {2, 2}), 3.0f);
5466
5467 auto *add1 = F_->createAdd("add", const1, const2);
5468 auto *mul1 = F_->createMul("mul1", add1, splat2);
5469 auto *mul2 = F_->createMul("mul2", mul1, splat3);
5470 F_->createSave("save", mul2);
5471 auto *add3 = F_->createAdd("add", const1, ph1);
5472 F_->createSave("save", add3);
5473
5474 ConstantModificationPreventer constModPreventer(mod_);
5475 constModPreventer.activate();
5476
5477 // Check that both Constants are protected and no change is made to the
5478 // Function during optimization.
5479 EXPECT_EQ(constModPreventer.getMapping().size(), 2);
5480 optimizedF_ = optimizeFunction(F_);
5481 EXPECT_EQ(F_->toString(/* skipUsersForStorage */ false,
5482 /* skipName */ true),
5483 optimizedF_->toString(/* skipUsersForStorage */ false,
5484 /* skipName */ true));
5485
5486 // Now deactivate the constModPreventer and check we can const fold still.
5487 constModPreventer.deactivateAndCleanup();
5488 mod_.eraseFunction(optimizedF_);
5489 optimizedF_ = optimizeFunction(F_);
5490
5491 // After constant folding, left with just two Saves, one Add.
5492 EXPECT_EQ(optimizedF_->getNodes().size(), 3);
5493 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind), 1);
5494 EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SaveNodeKind), 2);
5495
5496 bindings_.allocate(ph1)->getHandle<float>().randomize(-10.0, 10.0,
5497 mod_.getPRNG());
5498 checkNumericalEquivalence();
5499 }
5500
5501 /// Test that a Conv2D is correctly lowered to FC for single batch.
TEST_F(GraphOptz,lowerConv2DToFCSingleBatch)5502 TEST_F(GraphOptz, lowerConv2DToFCSingleBatch) {
5503 Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {1, 2, 3, 4},
5504 "input", /* isTrainable */ false);
5505 bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
5506 mod_.getPRNG());
5507
5508 Constant *filter =
5509 mod_.createConstant(ElemKind::FloatTy, {8, 1, 1, 4}, "filter");
5510 filter->getPayloadMutable().getHandle<float>().randomize(-10, 10,
5511 mod_.getPRNG());
5512
5513 Constant *bias = mod_.createConstant(ElemKind::FloatTy, {8}, "bias");
5514 bias->getPayloadMutable().getHandle<float>().randomize(-10, 10,
5515 mod_.getPRNG());
5516
5517 auto outTy = mod_.uniqueType(ElemKind::FloatTy, {1, 2, 3, 8});
5518 auto *conv = F_->createConv("conv", input, filter, bias, outTy, {1, 1},
5519 {1, 1}, {0, 0, 0, 0}, 1, 1);
5520 SaveNode *save = F_->createSave("save", conv);
5521 bindings_.allocate(save->getPlaceholder());
5522
5523 // Backup function in optimizedF_.
5524 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5525
5526 // Lower Convolution.
5527 EXPECT_TRUE(isConvolutionSameAsFullyConnected(conv));
5528 EXPECT_TRUE(glow::lowerNode(F_, conv, cctx_));
5529 runDCEPass(F_, cctx_);
5530 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind));
5531 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
5532
5533 // Now compile/run/compare F_ and optimizedF_.
5534 checkNumericalEquivalence(1e-6);
5535 }
5536
5537 /// Test that a Conv2D is correctly lowered to FC for multi batch.
TEST_F(GraphOptz,lowerConv2DToFCMultiBatch)5538 TEST_F(GraphOptz, lowerConv2DToFCMultiBatch) {
5539 Placeholder *input = mod_.createPlaceholder(ElemKind::FloatTy, {2, 2, 3, 4},
5540 "input", /* isTrainable */ false);
5541 bindings_.allocate(input)->getHandle<float>().randomize(-10, 10,
5542 mod_.getPRNG());
5543
5544 Constant *filter =
5545 mod_.createConstant(ElemKind::FloatTy, {8, 1, 1, 4}, "filter");
5546 filter->getPayloadMutable().getHandle<float>().randomize(-10, 10,
5547 mod_.getPRNG());
5548
5549 Constant *bias = mod_.createConstant(ElemKind::FloatTy, {8}, "bias");
5550 bias->getPayloadMutable().getHandle<float>().randomize(-10, 10,
5551 mod_.getPRNG());
5552
5553 auto outTy = mod_.uniqueType(ElemKind::FloatTy, {2, 2, 3, 8});
5554 auto *conv = F_->createConv("conv", input, filter, bias, outTy, {1, 1},
5555 {1, 1}, {0, 0, 0, 0}, 1, 1);
5556 SaveNode *save = F_->createSave("save", conv);
5557 bindings_.allocate(save->getPlaceholder());
5558
5559 // Backup function in optimizedF_.
5560 optimizedF_ = F_->clone(F_->getName().str() + "_optimized");
5561
5562 // Lower Convolution.
5563 EXPECT_TRUE(isConvolutionSameAsFullyConnected(conv));
5564 EXPECT_TRUE(glow::lowerNode(F_, conv, cctx_));
5565 runDCEPass(F_, cctx_);
5566 EXPECT_EQ(0, countNodeKind(F_, Kinded::Kind::ConvolutionNodeKind));
5567 EXPECT_EQ(1, countNodeKind(F_, Kinded::Kind::FullyConnectedNodeKind));
5568
5569 // Now compile/run/compare F_ and optimizedF_.
5570 checkNumericalEquivalence(1e-6);
5571 }
5572
5573 /// Test that Mul and Add can be folded into LayerNorm.
TEST_F(GraphOptz,foldMulAddIntoLayerNorm)5574 TEST_F(GraphOptz, foldMulAddIntoLayerNorm) {
5575 auto *input =
5576 mod_.createPlaceholder(ElemKind::FloatTy, {2, 4, 10, 20}, "in", false);
5577
5578 Tensor scaleT(ElemKind::FloatTy, {10, 20});
5579 scaleT.getHandle().randomize(0.0f, 1.0f, mod_.getPRNG());
5580 Constant *scaleC = mod_.createConstant("scale", std::move(scaleT));
5581 SplatNode *biasS = F_->createSplat("bias", scaleC->getType(), 1.5f);
5582
5583 auto *LN = F_->createLayerNormalization("LN", input, scaleC, biasS, 1e-5);
5584
5585 SplatNode *splat = F_->createSplat("splat", scaleC->getType(), 0.5f);
5586 MulNode *MN =
5587 F_->createNodeWithBroadcast<MulNode>("mul", /* axis */ -1, LN, splat);
5588
5589 Tensor addT(ElemKind::FloatTy, {1, 1, 10, 20});
5590 addT.getHandle().randomize(-1.0f, 1.0f, mod_.getPRNG());
5591 Constant *addC = mod_.createConstant("addC", std::move(addT));
5592 AddNode *AN =
5593 F_->createNodeWithBroadcast<AddNode>("add", /* axis */ -1, MN, addC);
5594 F_->createSave("save", AN);
5595
5596 optimizedF_ = optimizeFunction(F_);
5597
5598 // Because Mul and Add are folded in, they should not exist anymore, nor
5599 // should tiles that expand them to match the output of LN.
5600 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MulNodeKind));
5601 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind));
5602 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::TileNodeKind));
5603
5604 // Now compile/run/compare F_ and optimizedF_.
5605 bindings_.allocate(input)->getHandle().randomize(0.0f, 1.0f, mod_.getPRNG());
5606 checkNumericalEquivalence(1e-6);
5607 }
5608
5609 /// Test that Mul and Add can be folded into LayerNorm when the leading dims are
5610 /// all one.
TEST_F(GraphOptz,foldMulAddIntoLayerNormNoBatch)5611 TEST_F(GraphOptz, foldMulAddIntoLayerNormNoBatch) {
5612 auto *input =
5613 mod_.createPlaceholder(ElemKind::FloatTy, {1, 1, 10, 20}, "in", false);
5614
5615 Tensor scaleT(ElemKind::FloatTy, {10, 20});
5616 scaleT.getHandle().randomize(0.0f, 1.0f, mod_.getPRNG());
5617 Constant *scaleC = mod_.createConstant("scale", std::move(scaleT));
5618 SplatNode *biasS = F_->createSplat("bias", scaleC->getType(), 1.5f);
5619
5620 auto *LN = F_->createLayerNormalization("LN", input, scaleC, biasS, 1e-5);
5621
5622 SplatNode *splat = F_->createSplat("splat", scaleC->getType(), 0.5f);
5623 MulNode *MN =
5624 F_->createNodeWithBroadcast<MulNode>("mul", /* axis */ -1, LN, splat);
5625
5626 Tensor addT(ElemKind::FloatTy, {1, 1, 10, 20});
5627 addT.getHandle().randomize(-1.0f, 1.0f, mod_.getPRNG());
5628 Constant *addC = mod_.createConstant("addC", std::move(addT));
5629 AddNode *AN =
5630 F_->createNodeWithBroadcast<AddNode>("add", /* axis */ -1, MN, addC);
5631 F_->createSave("save", AN);
5632
5633 optimizedF_ = optimizeFunction(F_);
5634
5635 // Because Mul and Add are folded in, they should not exist anymore, nor
5636 // should tiles that expand them to match the output of LN.
5637 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::MulNodeKind));
5638 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::AddNodeKind));
5639 EXPECT_EQ(0, countNodeKind(optimizedF_, Kinded::Kind::TileNodeKind));
5640
5641 // Now compile/run/compare F_ and optimizedF_.
5642 bindings_.allocate(input)->getHandle().randomize(0.0f, 1.0f, mod_.getPRNG());
5643 checkNumericalEquivalence(1e-6);
5644 }
5645
TEST_F(GraphOptz,transposeQuantizeConstantWithAlignment)5646 TEST_F(GraphOptz, transposeQuantizeConstantWithAlignment) {
5647 // Define a type with custom alignments.
5648 Type typeWithAlignments(ElemKind::FloatTy, {2, 3, 4, 5}, {1, 1, 32, 1});
5649 Type quantTypeWithAlignments(ElemKind::Int8QTy, {2, 3, 4, 5}, {1, 1, 32, 1},
5650 1.0, 0);
5651 Type transposedQuantTypeWithAlignments(ElemKind::Int8QTy, {2, 4, 5, 3},
5652 {1, 1, 32, 1}, 1.0, 0);
5653 auto modTyWithAlignments = mod_.uniqueType(typeWithAlignments);
5654 auto modQuantTransposedTyWithAlignments =
5655 mod_.uniqueType(transposedQuantTypeWithAlignments);
5656 auto modQuantTyWithAlignments = mod_.uniqueType(quantTypeWithAlignments);
5657 auto *I = mod_.createConstant(modTyWithAlignments, "input1");
5658 auto *Q = F_->createQuantize("quantize", I, modQuantTyWithAlignments);
5659 auto *T = F_->createTranspose("transpose", Q, NCHW2NHWC);
5660 T->setType(TransposeNode::ResultIdx, modQuantTransposedTyWithAlignments);
5661 SaveNode *S = F_->createSave("ret", T);
5662
5663 // Skip ConstantFolding as it would have the same result as this opt.
5664 CompilationContext cctx;
5665 cctx.optimizationOpts.enableConstantFolding = false;
5666
5667 EXPECT_EQ(F_->getNodes().size(), 3);
5668 ::glow::optimize(F_, cctx);
5669 EXPECT_EQ(F_->getNodes().size(), 2);
5670
5671 // Constant and Quantize should have new shape.
5672 auto *newQ = llvm::dyn_cast<QuantizeNode>(S->getInput());
5673 ASSERT_TRUE(newQ);
5674 EXPECT_TRUE(newQ->getResult().dims().equals({2, 4, 5, 3}));
5675 auto *newC = llvm::dyn_cast<Constant>(newQ->getInput());
5676 ASSERT_TRUE(newC);
5677 EXPECT_TRUE(newC->getType()->dims().equals({2, 4, 5, 3}));
5678
5679 // Check that alignments are preserved by optimizations.
5680 auto expectedNewTy = mod_.uniqueTypeWithNewShape(
5681 modTyWithAlignments, modQuantTransposedTyWithAlignments);
5682 EXPECT_TRUE(newQ->getInput().getType()->isEqual(expectedNewTy));
5683
5684 EXPECT_TRUE(F_->verify());
5685 }
5686