1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17import numpy as np 18 19import tvm 20from tvm import te 21from tvm import relay 22from tvm.relay import transform 23 24 25def _get_positive_scale(size): 26 return np.random.uniform(0.5, 1, size=size).astype("float32") 27 28 29def run_opt_pass(expr, opt_pass): 30 assert isinstance(opt_pass, tvm.transform.Pass) 31 mod = tvm.IRModule.from_expr(expr) 32 mod = opt_pass(mod) 33 entry = mod["main"] 34 return entry if isinstance(expr, relay.Function) else entry.body 35 36 37def test_fold_fwd_simple(): 38 """Simple testcase.""" 39 40 def before(x, conv_weight, in_bias, in_scale, channels, blocking): 41 args = [x, conv_weight, in_bias] 42 x = relay.multiply(x, in_scale) 43 x = relay.nn.relu(x) 44 x = relay.add(x, in_bias) 45 y = relay.nn.conv2d( 46 x, 47 conv_weight, 48 channels=channels, 49 kernel_size=(3, 3), 50 padding=(1, 1), 51 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 52 kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW", 53 ) 54 55 return relay.Function(args, y) 56 57 def expected(x, conv_weight, in_bias, in_scale, in_channels, channels, blocking): 58 # use a fixed order of args so alpha equal check can pass 59 args = [x, conv_weight, in_bias] 60 if blocking: 61 squeezed_scale = relay.squeeze(in_scale, axis=[0, 2, 3]) 62 x = relay.nn.relu(x) 63 in_bias = relay.divide( 64 in_bias, 65 relay.reshape(squeezed_scale, (1, in_channels // blocking[0], 1, 1, blocking[0])), 66 ) # NCHWc 67 x = relay.add(x, in_bias) 68 conv_weight = relay.multiply( 69 conv_weight, relay.reshape(squeezed_scale, (1, in_channels // 2, 1, 1, 2, 1)) 70 ) # OIHWio 71 else: 72 squeezed_scale = relay.squeeze(in_scale, axis=[1, 2]) 73 x = relay.nn.relu(x) 74 in_bias = relay.divide( 75 in_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2) 76 ) 77 x = relay.add(x, in_bias) 78 conv_weight = relay.multiply( 79 conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2) 80 ) 81 82 y = relay.nn.conv2d( 83 x, 84 conv_weight, 85 channels=channels, 86 kernel_size=(3, 3), 87 padding=(1, 1), 88 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 89 kernel_layout="OIHW2i{}o".format(blocking[1]) if blocking else "OIHW", 90 ) 91 return relay.Function(args, y) 92 93 def check(shape, channels, blocking): 94 x = relay.var("x", shape=shape) 95 weight = relay.var("weight") 96 if blocking: 97 in_channels = shape[1] * shape[4] 98 in_bias = relay.var("in_bias", shape=(1, in_channels // blocking[0], 1, 1, blocking[0])) 99 in_scale = relay.const( 100 _get_positive_scale((1, in_channels // blocking[0], 1, 1, blocking[0])) 101 ) 102 else: 103 in_channels = shape[1] 104 in_bias = relay.var("in_bias", shape=(in_channels, 1, 1)) 105 in_scale = relay.const(_get_positive_scale((in_channels, 1, 1))) 106 y1 = before(x, weight, in_bias, in_scale, channels, blocking) 107 y1 = run_opt_pass(y1, transform.InferType()) 108 type_dict = {x.name_hint: x.checked_type for x in y1.params} 109 weight = relay.var("weight", type_dict["weight"]) 110 y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) 111 y1_expected = expected(x, weight, in_bias, in_scale, in_channels, channels, blocking) 112 113 y1_folded = run_opt_pass(y1_folded, transform.InferType()) 114 y1_expected = run_opt_pass(y1_expected, transform.InferType()) 115 assert tvm.ir.structural_equal(y1_folded, y1_expected) 116 117 check((2, 4, 10, 10), 2, None) 118 check((2, 2, 10, 10, 2), 8, (2, 4)) 119 120 121def test_fold_fwd_dual_path(): 122 """scale axis being consumed by two consumers""" 123 124 def before(x, conv_weight, in_bias, in_scale, channels, blocking): 125 args = [x, conv_weight, in_bias] 126 x = relay.multiply(in_scale, x) 127 x = relay.nn.relu(x) 128 x = relay.subtract(x, in_bias) 129 y1 = relay.nn.conv2d( 130 x, 131 conv_weight, 132 channels=channels, 133 kernel_size=(3, 3), 134 data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", 135 kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", 136 groups=channels, 137 padding=(1, 1), 138 ) 139 y2 = relay.nn.conv2d( 140 x, 141 conv_weight, 142 channels=channels, 143 kernel_size=(3, 3), 144 data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", 145 kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", 146 groups=channels, 147 padding=(1, 1), 148 ) 149 z = relay.add(y1, y2) 150 return relay.Function(args, z) 151 152 def expected(x, conv_weight, in_bias, in_scale, channels, blocking): 153 args = [x, conv_weight, in_bias] 154 x = relay.nn.relu(x) 155 if blocking: 156 _in_scale = relay.reshape( 157 in_scale, (1, 1, 1, channels // blocking[0], blocking[0]) 158 ) # NHWCc 159 else: 160 _in_scale = in_scale 161 in_bias = relay.divide(in_bias, _in_scale) 162 x = relay.subtract(x, in_bias) 163 if blocking: 164 _in_scale = relay.reshape( 165 in_scale, (1, 1, 1, channels // blocking[0], 1, blocking[0]) 166 ) # HWIOio 167 y1 = relay.nn.conv2d( 168 x, 169 relay.multiply(conv_weight, _in_scale), 170 channels=channels, 171 kernel_size=(3, 3), 172 data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", 173 kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", 174 groups=channels, 175 padding=(1, 1), 176 ) 177 if blocking: 178 _in_scale = relay.reshape( 179 in_scale, (1, 1, 1, channels // blocking[0], 1, blocking[0]) 180 ) # HWIOio 181 y2 = relay.nn.conv2d( 182 x, 183 relay.multiply(conv_weight, _in_scale), 184 channels=channels, 185 kernel_size=(3, 3), 186 data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", 187 kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", 188 groups=channels, 189 padding=(1, 1), 190 ) 191 z = relay.add(y1, y2) 192 return relay.Function(args, z) 193 194 def check(dshape, channels, blocking): 195 x = relay.var("x", shape=dshape) 196 if blocking: 197 in_channels = dshape[3] * dshape[4] 198 wshape = (3, 3, 1, channels // blocking[1], 1, blocking[1]) # HWIOio 199 weight = relay.var("weight", shape=wshape) 200 in_bias = relay.var("in_bias", shape=(in_channels // blocking[0], blocking[0])) 201 in_scale = relay.const(_get_positive_scale((in_channels // blocking[0], blocking[0]))) 202 else: 203 in_channels = dshape[-1] 204 wshape = (3, 3, 1, channels) # HWIO 205 weight = relay.var("weight", shape=wshape) 206 in_bias = relay.var("in_bias", shape=(in_channels,)) 207 in_scale = relay.const( 208 _get_positive_scale( 209 in_channels, 210 ) 211 ) 212 213 # test depthwise 214 assert in_channels == channels 215 216 y1 = before(x, weight, in_bias, in_scale, channels, blocking) 217 y1 = run_opt_pass(y1, transform.InferType()) 218 y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) 219 type_dict = {x.name_hint: x.checked_type for x in y1.params} 220 weight = relay.var("weight", type_dict["weight"]) 221 y1_expected = expected(x, weight, in_bias, in_scale, channels, blocking) 222 y1_expected = run_opt_pass(y1_expected, transform.InferType()) 223 assert tvm.ir.structural_equal(y1_folded, y1_expected) 224 225 check((2, 4, 10, 3), 3, None) 226 check((2, 4, 10, 2, 2), 4, (2, 2)) 227 228 229def test_fold_fwd_fail(): 230 """testcase where we canont fold""" 231 232 def before(x, conv_weight, in_bias, in_scale, channels, blocking): 233 x = relay.multiply(x, in_scale) 234 xx = relay.nn.leaky_relu(x, alpha=0.1) 235 y1 = relay.nn.conv2d( 236 xx, 237 conv_weight, 238 channels=channels, 239 kernel_size=(3, 3), 240 data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", 241 kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", 242 padding=(1, 1), 243 ) 244 z = relay.add(y1, x) 245 return relay.Function(relay.analysis.free_vars(z), z) 246 247 def check(shape, channels, blocking): 248 x = relay.var("x", shape=shape) 249 if blocking: 250 in_channels = shape[3] * shape[4] 251 in_bias = relay.var("in_bias", shape=(in_channels // blocking[0], blocking[0])) 252 in_scale = relay.const(_get_positive_scale((in_channels // blocking[0], blocking[0]))) 253 else: 254 in_channels = shape[-1] 255 in_bias = relay.var("in_bias", shape=(in_channels,)) 256 in_scale = relay.const(_get_positive_scale(size=(in_channels,))) 257 # test depthwise 258 assert in_channels == channels 259 weight = relay.var("weight") 260 y1 = before(x, weight, in_bias, in_scale, channels, blocking) 261 y1 = run_opt_pass(y1, transform.InferType()) 262 y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) 263 assert tvm.ir.structural_equal(y1, y1_folded) 264 265 check((2, 11, 10, 4), 4, None) 266 check((2, 11, 10, 2, 2), 4, (2, 2)) 267 268 269def test_fold_fwd_relu_fail(): 270 """testcase where we canont fold because scale can not pass relu""" 271 272 def before(x, conv_weight, in_bias, in_scale, channels, blocking): 273 x = relay.multiply(x, in_scale) 274 xx = relay.nn.relu(x) 275 y1 = relay.nn.conv2d( 276 xx, 277 conv_weight, 278 channels=channels, 279 kernel_size=(3, 3), 280 data_layout="NHWC{}c".format(blocking[0]) if blocking else "NHWC", 281 kernel_layout="HWIO1i{}o".format(blocking[1]) if blocking else "HWIO", 282 padding=(1, 1), 283 ) 284 z = relay.add(y1, x) 285 return relay.Function(relay.analysis.free_vars(z), z) 286 287 def check(shape, channels, blocking, in_scale): 288 x = relay.var("x", shape=shape) 289 weight = relay.var("weight") 290 if blocking: 291 in_channels = shape[3] * shape[4] 292 in_bias = relay.var("in_bias", shape=(1, in_channels // blocking[0], 1, 1, blocking[0])) 293 else: 294 in_channels = shape[-1] 295 in_bias = relay.var("in_bias", shape=(in_channels,)) 296 297 assert in_channels == channels 298 y1 = before(x, weight, in_bias, in_scale, channels, blocking) 299 y1 = run_opt_pass(y1, transform.InferType()) 300 y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) 301 assert tvm.ir.structural_equal(y1, y1_folded) 302 303 in_scale = relay.var("in_scale", shape=(4,)) 304 check((2, 11, 10, 4), 4, None, in_scale) 305 in_scale = relay.const(-_get_positive_scale((4,))) 306 check((2, 11, 10, 4), 4, None, in_scale) 307 308 in_scale = relay.var("in_scale", shape=(1, 1, 1, 2, 2)) 309 check((2, 11, 10, 2, 2), 4, (2, 2), in_scale) 310 in_scale = relay.const(-_get_positive_scale((1, 1, 1, 2, 2))) 311 check((2, 11, 10, 2, 2), 4, (2, 2), in_scale) 312 313 314def test_fold_fwd_negative_scale(): 315 """Testcase of folding negative scale""" 316 317 def before(x, conv_weight, in_scale, channels, blocking): 318 args = [x, conv_weight] 319 x = relay.multiply(x, in_scale) 320 y = relay.nn.conv2d( 321 x, 322 conv_weight, 323 channels=channels, 324 kernel_size=(3, 3), 325 padding=(1, 1), 326 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 327 kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW", 328 ) 329 return relay.Function(args, y) 330 331 def expected(x, conv_weight, in_scale, in_channels, channels, blocking): 332 # use a fixed order of args so alpha equal check can pass 333 args = [x, conv_weight] 334 if blocking: 335 squeezed_scale = relay.squeeze(in_scale, axis=[0, 2, 3]) 336 conv_weight = relay.multiply( 337 conv_weight, relay.reshape(squeezed_scale, (1, in_channels // 4, 1, 1, 4, 1)) 338 ) 339 # blocking by "i" in OIHWio 340 else: 341 squeezed_scale = relay.squeeze(in_scale, axis=[1, 2]) 342 conv_weight = relay.multiply( 343 conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2) 344 ) 345 y = relay.nn.conv2d( 346 x, 347 conv_weight, 348 channels=channels, 349 kernel_size=(3, 3), 350 padding=(1, 1), 351 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 352 kernel_layout="OIHW4i{}o".format(blocking[1]) if blocking else "OIHW", 353 ) 354 return relay.Function(args, y) 355 356 def check(shape, channels, blocking): 357 x = relay.var("x", shape=shape) 358 if blocking: 359 in_channels = shape[1] * shape[4] 360 in_scale = relay.const(-_get_positive_scale((1, shape[1], 1, 1, shape[4]))) 361 else: 362 in_channels = shape[1] 363 in_scale = relay.const(-_get_positive_scale((in_channels, 1, 1))) 364 weight = relay.var("weight") 365 y1 = before(x, weight, in_scale, channels, blocking) 366 y1 = run_opt_pass(y1, transform.InferType()) 367 type_dict = {x.name_hint: x.checked_type for x in y1.params} 368 weight = relay.var("weight", type_dict["weight"]) 369 y1_folded = run_opt_pass(y1, transform.ForwardFoldScaleAxis()) 370 y1_expected = expected(x, weight, in_scale, in_channels, channels, blocking) 371 y1_expected = run_opt_pass(y1_expected, transform.InferType()) 372 assert tvm.ir.structural_equal(y1_folded, y1_expected) 373 374 check((2, 4, 10, 10), 4, None) 375 check((2, 2, 10, 10, 2), 8, (2, 2)) 376 377 378def test_fold_bwd_simple(): 379 """Simple testcase.""" 380 381 def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): 382 args = [x, conv_weight, out_bias] 383 if blocking: 384 out_bias = relay.reshape(out_bias, (1, channels // blocking[1], 1, 1, blocking[1])) 385 else: 386 out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) 387 y = relay.nn.conv2d( 388 x, 389 conv_weight, 390 channels=channels, 391 kernel_size=(3, 3), 392 padding=(1, 1), 393 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 394 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 395 ) 396 y = relay.add(y, out_bias) 397 y = relay.nn.relu(y) 398 if blocking: 399 out_scale = relay.reshape(out_scale, (1, channels // blocking[1], 1, 1, blocking[1])) 400 y = relay.multiply(y, out_scale) 401 return relay.Function(args, y) 402 403 def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): 404 # use a fixed order of args so alpha equal check can pass 405 args = [x, conv_weight, out_bias] 406 if blocking: 407 out_bias = relay.reshape(out_bias, (1, channels // blocking[1], 1, 1, blocking[1])) 408 out_scale = relay.reshape(out_scale, (1, channels // blocking[1], 1, 1, blocking[1])) 409 squeezed_scale = relay.squeeze(out_scale, axis=[0, 2, 3]) 410 conv_weight = relay.multiply( 411 conv_weight, 412 relay.reshape(squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, blocking[1])), 413 ) 414 else: 415 out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) 416 squeezed_scale = relay.squeeze(out_scale, axis=[1, 2]) 417 conv_weight = relay.multiply( 418 conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3) 419 ) 420 421 y = relay.nn.conv2d( 422 x, 423 conv_weight, 424 channels=channels, 425 kernel_size=(3, 3), 426 padding=(1, 1), 427 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 428 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 429 ) 430 if blocking: 431 out_bias = relay.multiply( 432 out_bias, 433 relay.reshape(squeezed_scale, (1, channels // blocking[1], 1, 1, blocking[1])), 434 ) 435 else: 436 out_bias = relay.multiply( 437 out_bias, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=2) 438 ) 439 y = relay.add(y, out_bias) 440 y = relay.nn.relu(y) 441 return relay.Function(args, y) 442 443 def check(shape, in_channels, channels, blocking): 444 x = relay.var("x", shape=shape) 445 weight = relay.var("weight") 446 out_bias = relay.var("out_bias", shape=(channels,)) 447 if blocking: 448 out_scale = relay.const(_get_positive_scale((channels,))) 449 else: 450 out_scale = relay.const(_get_positive_scale((channels, 1, 1))) 451 y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking) 452 y1 = run_opt_pass(y1, transform.InferType()) 453 type_dict = {x.name_hint: x.checked_type for x in y1.params} 454 weight = relay.var("weight", type_dict["weight"]) 455 y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) 456 y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking) 457 y1_expected = run_opt_pass(y1_expected, transform.InferType()) 458 assert tvm.ir.structural_equal(y1_folded, y1_expected) 459 460 check((2, 4, 10, 10), 4, 8, None) 461 check((2, 2, 10, 10, 16), 32, 64, (16, 16)) 462 463 464def test_fold_bwd_dual_path(): 465 """Dual path testcase.""" 466 467 def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): 468 args = [x, conv_weight, out_bias] 469 y1 = relay.nn.conv2d( 470 x, 471 conv_weight, 472 channels=channels, 473 kernel_size=(3, 3), 474 padding=(1, 1), 475 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 476 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 477 ) 478 y1 = relay.nn.relu(y1) 479 y2 = relay.nn.conv2d( 480 x, 481 conv_weight, 482 channels=channels, 483 kernel_size=(3, 3), 484 padding=(1, 1), 485 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 486 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 487 ) 488 y2 = relay.nn.relu(y2) 489 y = relay.add(y1, y2) 490 y = relay.multiply(y, out_scale) 491 return relay.Function(args, y) 492 493 def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): 494 # use a fixed order of args so alpha equal check can pass 495 args = [x, conv_weight, out_bias] 496 if not blocking: 497 out_bias = relay.expand_dims(out_bias, axis=1, num_newaxis=2) 498 squeezed_scale = relay.squeeze(out_scale, axis=[1, 2]) 499 500 def fold_conv_weight(): 501 if blocking: 502 return relay.multiply( 503 conv_weight, 504 relay.reshape( 505 squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, blocking[1]) 506 ), 507 ) 508 else: 509 return relay.multiply( 510 conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3) 511 ) 512 513 y1 = relay.nn.conv2d( 514 x, 515 fold_conv_weight(), 516 channels=channels, 517 kernel_size=(3, 3), 518 padding=(1, 1), 519 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 520 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 521 ) 522 y1 = relay.nn.relu(y1) 523 y2 = relay.nn.conv2d( 524 x, 525 fold_conv_weight(), 526 channels=channels, 527 kernel_size=(3, 3), 528 padding=(1, 1), 529 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 530 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 531 ) 532 y2 = relay.nn.relu(y2) 533 y = relay.add(y1, y2) 534 return relay.Function(args, y) 535 536 def check(shape, in_channels, channels, blocking): 537 x = relay.var("x", shape=shape) 538 weight = relay.var("weight") 539 if blocking: 540 out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1])) 541 out_scale = relay.const( 542 _get_positive_scale((channels // blocking[1], 1, 1, blocking[1])) 543 ) 544 else: 545 out_bias = relay.var("out_bias", shape=(channels,)) 546 out_scale = relay.const(_get_positive_scale((channels, 1, 1))) 547 548 y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking) 549 y1 = run_opt_pass(y1, transform.InferType()) 550 type_dict = {x.name_hint: x.checked_type for x in y1.params} 551 weight = relay.var("weight", type_dict["weight"]) 552 y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) 553 y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking) 554 y1_expected = run_opt_pass(y1_expected, transform.InferType()) 555 assert tvm.ir.structural_equal(y1_folded, y1_expected) 556 557 check((2, 4, 10, 10), 4, 8, None) 558 check((2, 2, 10, 10, 2), 4, 8, (2, 2)) 559 560 561def test_fold_bwd_dual_consumer(): 562 def before(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): 563 args = [x, conv_weight, out_bias] 564 y0 = relay.nn.conv2d( 565 x, 566 conv_weight, 567 channels=channels, 568 kernel_size=(3, 3), 569 padding=(1, 1), 570 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 571 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 572 ) 573 y0 = relay.multiply(y0, out_scale) 574 y0 = relay.nn.relu(y0) 575 576 y1 = relay.nn.conv2d( 577 y0, 578 conv_weight, 579 channels=channels, 580 kernel_size=(3, 3), 581 padding=(1, 1), 582 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 583 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 584 ) 585 y1 = relay.multiply(y1, out_scale) 586 y1 = relay.nn.relu(y1) 587 588 y2 = relay.nn.conv2d( 589 y0, 590 conv_weight, 591 channels=channels, 592 kernel_size=(3, 3), 593 padding=(1, 1), 594 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 595 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 596 ) 597 y2 = relay.multiply(y2, out_scale) 598 y2 = relay.nn.relu(y2) 599 600 y = relay.add(y1, y2) 601 return relay.Function(args, y) 602 603 def expected(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): 604 # use a fixed order of args so alpha equal check can pass 605 args = [x, conv_weight, out_bias] 606 607 def fold_conv_weight(): 608 squeezed_scale = relay.squeeze(out_scale, axis=[1, 2]) 609 if blocking: 610 return relay.multiply( 611 conv_weight, 612 relay.reshape( 613 squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, blocking[1]) 614 ), 615 ) 616 else: 617 return relay.multiply( 618 conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3) 619 ) 620 621 y0 = relay.nn.conv2d( 622 x, 623 fold_conv_weight(), 624 channels=channels, 625 kernel_size=(3, 3), 626 padding=(1, 1), 627 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 628 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 629 ) 630 y0 = relay.nn.relu(y0) 631 y1 = relay.nn.conv2d( 632 y0, 633 fold_conv_weight(), 634 channels=channels, 635 kernel_size=(3, 3), 636 padding=(1, 1), 637 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 638 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 639 ) 640 y1 = relay.nn.relu(y1) 641 y2 = relay.nn.conv2d( 642 y0, 643 fold_conv_weight(), 644 channels=channels, 645 kernel_size=(3, 3), 646 padding=(1, 1), 647 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 648 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 649 ) 650 y2 = relay.nn.relu(y2) 651 y = relay.add(y1, y2) 652 return relay.Function(args, y) 653 654 def check(shape, in_channels, channels, blocking): 655 x = relay.var("x", shape=shape) 656 weight = relay.var("weight") 657 if blocking: 658 out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1])) 659 out_scale = relay.const( 660 _get_positive_scale((channels // blocking[1], 1, 1, blocking[1])) 661 ) 662 else: 663 out_bias = relay.var("out_bias", shape=(channels,)) 664 out_scale = relay.const(_get_positive_scale((channels, 1, 1))) 665 666 y1 = before(x, weight, out_bias, out_scale, in_channels, channels, blocking) 667 y1 = run_opt_pass(y1, transform.InferType()) 668 type_dict = {x.name_hint: x.checked_type for x in y1.params} 669 weight = relay.var("weight", type_dict["weight"]) 670 y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) 671 y1_expected = expected(x, weight, out_bias, out_scale, in_channels, channels, blocking) 672 y1_expected = run_opt_pass(y1_expected, transform.InferType()) 673 assert tvm.ir.structural_equal(y1_folded, y1_expected) 674 675 check((2, 4, 10, 10), 4, 4, None) 676 check((2, 2, 10, 10, 2), 4, 4, (2, 2)) 677 678 679def test_fold_bwd_fail(): 680 """Dual path testcase.""" 681 682 def fail1(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): 683 args = [x, conv_weight, out_bias] 684 y1 = relay.nn.conv2d( 685 x, 686 conv_weight, 687 channels=channels, 688 kernel_size=(3, 3), 689 padding=(1, 1), 690 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 691 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 692 ) 693 y1 = relay.nn.relu(y1) 694 y2 = relay.nn.conv2d( 695 x, 696 conv_weight, 697 channels=channels, 698 kernel_size=(3, 3), 699 padding=(1, 1), 700 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 701 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 702 out_layout="CNHW{}c".format(blocking[1]) if blocking else "CNHW", 703 ) 704 # fold will fail because the axis from two path 705 # differs from each other. 706 y2 = relay.nn.relu(y2) 707 y = relay.add(y1, y2) 708 y = relay.multiply(y, out_scale) 709 return relay.Function(args, y) 710 711 def fail2(x, conv_weight, out_bias, out_scale, in_channels, channels, blocking): 712 args = [x, conv_weight, out_bias] 713 y1 = relay.nn.conv2d( 714 x, 715 conv_weight, 716 channels=channels, 717 kernel_size=(3, 3), 718 padding=(1, 1), 719 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 720 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 721 ) 722 y2 = relay.nn.relu(y1) 723 # fold will fail because y1 is referred also by y2 724 y1 = relay.multiply(y1, out_scale) 725 y = relay.add(y1, y2) 726 return relay.Function(args, y) 727 728 def check(shape, in_channels, channels, blocking, fbefore): 729 x = relay.var("x", shape=shape) 730 weight = relay.var("weight") 731 if blocking: 732 out_bias = relay.var("out_bias", shape=(channels // blocking[1], 1, 1, blocking[1])) 733 out_scale = relay.const( 734 _get_positive_scale((channels // blocking[1], 1, 1, blocking[1])) 735 ) 736 else: 737 out_bias = relay.var("out_bias", shape=(channels, 1, 1)) 738 out_scale = relay.const(_get_positive_scale((channels, 1, 1))) 739 y1 = fbefore(x, weight, out_bias, out_scale, in_channels, channels, blocking) 740 y1 = run_opt_pass(y1, transform.InferType()) 741 y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) 742 assert tvm.ir.structural_equal(y1_folded, y1) 743 744 check((4, 4, 10, 10), 4, 4, None, fail1) 745 check((2, 2, 10, 10, 2), 4, 4, (2, 2), fail1) 746 check((4, 4, 10, 10), 4, 4, None, fail2) 747 check((4, 2, 10, 10, 2), 4, 4, (2, 2), fail2) 748 749 750def test_fold_bwd_relu_fail(): 751 """testcase where we canont fold because scale can not pass relu""" 752 753 def before(x, conv_weight, out_scale, channels, blocking): 754 y = relay.nn.conv2d( 755 x, 756 conv_weight, 757 channels=channels, 758 kernel_size=(3, 3), 759 padding=(1, 1), 760 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 761 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 762 ) 763 y = relay.nn.relu(y) 764 y = relay.multiply(x, out_scale) 765 return relay.Function(relay.analysis.free_vars(y), y) 766 767 def check(shape, channels, blocking, out_scale): 768 x = relay.var("x", shape=shape) 769 in_channels = shape[1] 770 weight = relay.var("weight") 771 y1 = before(x, weight, out_scale, channels, blocking) 772 y1 = run_opt_pass(y1, transform.InferType()) 773 y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) 774 assert tvm.ir.structural_equal(y1, y1_folded) 775 776 out_scale = relay.var("in_scale", shape=(4, 1, 1)) 777 check((4, 4, 10, 10), 4, None, out_scale) 778 out_scale = relay.const(np.random.uniform(size=(4, 1, 1), low=-1.0, high=0.0)).astype("float32") 779 check((4, 4, 10, 10), 4, None, out_scale) 780 781 out_scale = relay.var("in_scale", shape=(1, 2, 1, 1, 2)) 782 check((4, 2, 10, 10, 2), 4, (2, 2), out_scale) 783 out_scale = relay.const(np.random.uniform(size=(1, 2, 1, 1, 2), low=-1.0, high=0.0)).astype( 784 "float32" 785 ) 786 check((4, 2, 10, 10, 2), 4, (2, 2), out_scale) 787 788 789def test_fold_bwd_negative_scale(): 790 """Testcase of folding negative scale""" 791 792 def before(x, conv_weight, out_scale, channels, blocking): 793 args = [x, conv_weight] 794 y = relay.nn.conv2d( 795 x, 796 conv_weight, 797 channels=channels, 798 kernel_size=(3, 3), 799 padding=(1, 1), 800 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 801 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 802 ) 803 y = relay.multiply(y, out_scale) 804 return relay.Function(args, y) 805 806 def expected(x, conv_weight, out_scale, channels, blocking): 807 # use a fixed order of args so alpha equal check can pass 808 args = [x, conv_weight] 809 if blocking: 810 squeezed_scale = relay.squeeze(out_scale, axis=[0, 2, 3]) 811 conv_weight = relay.multiply( 812 conv_weight, 813 relay.reshape(squeezed_scale, (channels // blocking[1], 1, 1, 1, 1, blocking[1])), 814 ) 815 else: 816 squeezed_scale = relay.squeeze(out_scale, axis=[1, 2]) 817 conv_weight = relay.multiply( 818 conv_weight, relay.expand_dims(squeezed_scale, axis=1, num_newaxis=3) 819 ) 820 y = relay.nn.conv2d( 821 x, 822 conv_weight, 823 channels=channels, 824 kernel_size=(3, 3), 825 padding=(1, 1), 826 data_layout="NCHW{}c".format(blocking[0]) if blocking else "NCHW", 827 kernel_layout="OIHW1i{}o".format(blocking[1]) if blocking else "OIHW", 828 ) 829 return relay.Function(args, y) 830 831 def check(shape, channels, blocking): 832 x = relay.var("x", shape=shape) 833 weight = relay.var("weight") 834 if blocking: 835 out_scale = relay.const( 836 -_get_positive_scale((1, channels // blocking[1], 1, 1, blocking[1])) 837 ) 838 else: 839 out_scale = relay.const(-_get_positive_scale((channels, 1, 1))) 840 y1 = before(x, weight, out_scale, channels, blocking) 841 y1 = run_opt_pass(y1, transform.InferType()) 842 type_dict = {x.name_hint: x.checked_type for x in y1.params} 843 weight = relay.var("weight", type_dict["weight"]) 844 y1_folded = run_opt_pass(y1, transform.BackwardFoldScaleAxis()) 845 y1_expected = expected(x, weight, out_scale, channels, blocking) 846 y1_expected = run_opt_pass(y1_expected, transform.InferType()) 847 assert tvm.ir.structural_equal(y1_folded, y1_expected) 848 849 check((2, 4, 10, 10), 8, None) 850 check((2, 2, 10, 10, 2), 8, (2, 2)) 851 852 853if __name__ == "__main__": 854 test_fold_fwd_simple() 855 test_fold_fwd_dual_path() 856 test_fold_fwd_fail() 857 test_fold_fwd_relu_fail() 858 test_fold_fwd_negative_scale() 859 test_fold_bwd_simple() 860 test_fold_bwd_dual_path() 861 test_fold_bwd_dual_consumer() 862 test_fold_bwd_fail() 863 test_fold_bwd_relu_fail() 864 test_fold_bwd_negative_scale() 865