1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17# pylint: disable=invalid-name, unused-argument, too-many-arguments 18"""Backend compiler related feature registration""" 19from __future__ import absolute_import 20 21import topi 22from topi.util import get_const_tuple 23from .. import op as reg 24from ..op import OpPattern, schedule_injective 25from .._tensor import elemwise_shape_func 26from ....api import convert 27from ....hybrid import script 28 29# relu 30reg.register_schedule("nn.relu", schedule_injective) 31reg.register_pattern("nn.relu", OpPattern.ELEMWISE) 32 33# softmax 34@reg.register_schedule("nn.softmax") 35def schedule_softmax(_, outputs, target): 36 """Schedule definition of softmax""" 37 with target: 38 return topi.generic.schedule_softmax(outputs) 39 40 41reg.register_pattern("nn.softmax", OpPattern.OPAQUE) 42 43schedule_broadcast = schedule_injective 44 45 46@reg.register_schedule("nn.log_softmax") 47def schedule_log_softmax(_, outputs, target): 48 """Schedule definition of log_softmax""" 49 with target: 50 return topi.generic.schedule_softmax(outputs) 51 52 53reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE) 54 55 56# dense 57@reg.register_compute("nn.dense") 58def compute_dense(attrs, inputs, out_type, target): 59 """Compute definition of dense""" 60 out_dtype = attrs.out_dtype 61 out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype 62 return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)] 63 64 65@reg.register_schedule("nn.dense") 66def schedule_dense(attrs, outputs, target): 67 """Schedule definition of dense""" 68 with target: 69 return topi.generic.schedule_dense(outputs) 70 71 72reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) 73 74 75@reg.register_compute('nn.fifo_buffer') 76def compute_fifo_buffer(attrs, inputs, out_type, target): 77 return [topi.nn.fifo_buffer(inputs[0], inputs[1], axis=attrs.get_int('axis'))] 78 79 80@reg.register_schedule('nn.fifo_buffer') 81def schedule_fifo_buffer(attrs, outputs, target): 82 with target: 83 return topi.generic.schedule_injective(outputs) 84 85 86reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE) 87 88 89# batch_matmul 90@reg.register_compute("nn.batch_matmul") 91def compute_batch_matmul(attrs, inputs, out_type, target): 92 """Compute definition of batch_matmul""" 93 with target: 94 return [topi.nn.batch_matmul(inputs[0], inputs[1])] 95 96 97@reg.register_schedule("nn.batch_matmul") 98def schedule_batch_matmul(attrs, outputs, target): 99 """Schedule definition of batch_matmul""" 100 with target: 101 return topi.generic.schedule_batch_matmul(outputs) 102 103 104reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE) 105 106# sparse_dense 107@reg.register_compute("nn.sparse_dense") 108def compute_sparse_dense(attrs, inputs, out_type, target): 109 """Compute definition of sparse_dense""" 110 return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])] 111 112@reg.register_schedule("nn.sparse_dense") 113def schedule_sparse_dense(attrs, outputs, target): 114 """Schedule definition of batch_matmul""" 115 with target: 116 return topi.generic.schedule_sparse_dense(outputs) 117 118reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) 119 120# sparse_transpose 121@reg.register_compute("nn.sparse_transpose") 122def compute_sparse_transpose(attrs, inputs, out_type, target): 123 """Compute definition of sparse_transpose""" 124 return topi.nn.sparse_transpose(inputs[0], inputs[1], inputs[2]) 125 126@reg.register_schedule("nn.sparse_transpose") 127def schedule_sparse_transpose(attrs, outputs, target): 128 """Schedule definition of batch_matmul""" 129 with target: 130 return topi.generic.schedule_sparse_transpose(outputs) 131 132reg.register_pattern("nn.sparse_transpose", reg.OpPattern.OUT_ELEMWISE_FUSABLE) 133 134# conv2d 135def _find_conv2d_op(op): 136 """Find the op with conv2d in its tag by traversing.""" 137 if 'conv2d' in op.tag: 138 return op 139 for tensor in op.input_tensors: 140 op_ = _find_conv2d_op(tensor.op) 141 if op_ is not None: 142 return op_ 143 return None 144 145 146@reg.register_compute("nn.conv2d") 147def compute_conv2d(attrs, inputs, out_type, target): 148 """Compute definition of conv2d""" 149 padding = get_const_tuple(attrs.padding) 150 strides = get_const_tuple(attrs.strides) 151 dilation = get_const_tuple(attrs.dilation) 152 groups = attrs.groups 153 layout = attrs.data_layout 154 kernel_layout = attrs.kernel_layout 155 out_dtype = attrs.out_dtype 156 out_dtype = (inputs[0].dtype if out_dtype in ("same", "") 157 else out_dtype) 158 159 assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"] 160 (dilation_h, dilation_w) = dilation 161 if dilation_h < 1 or dilation_w < 1: 162 raise ValueError("dilation should be positive value") 163 164 def _get_out_depth(): 165 weight_shape = get_const_tuple(inputs[1].shape) 166 if kernel_layout.startswith("HW"): 167 return weight_shape[2] * weight_shape[3] 168 return weight_shape[0] * weight_shape[1] 169 170 if groups == 1: 171 out = topi.nn.conv2d( 172 inputs[0], inputs[1], strides, padding, 173 dilation, layout, out_dtype) 174 elif layout == "NCHW" and _get_out_depth() == groups: 175 out = topi.nn.depthwise_conv2d_nchw( 176 inputs[0], inputs[1], strides, padding, dilation, out_dtype) 177 elif layout == "NHWC" and kernel_layout == "HWOI" and _get_out_depth() == groups: 178 out = topi.nn.depthwise_conv2d_nhwc( 179 inputs[0], inputs[1], strides, padding, dilation, out_dtype) 180 elif layout in ['NCHW', 'NCHW4c']: 181 out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups, 182 out_dtype) 183 else: 184 raise ValueError("not support arbitrary group number for now") 185 return [out] 186 187 188@reg.register_schedule("nn.conv2d") 189def schedule_conv2d(attrs, outs, target): 190 """Schedule definition of conv2d""" 191 groups = attrs.groups 192 layout = attrs.data_layout 193 kernel_layout = attrs.kernel_layout 194 195 with target: 196 if groups == 1 and layout == "NCHW": 197 return topi.generic.schedule_conv2d_nchw(outs) 198 elif groups == 1 and layout == "NCHW4c": 199 return topi.generic.schedule_conv2d_nchw(outs) 200 elif groups == 1 and layout == "NHWC": 201 return topi.generic.schedule_conv2d_nhwc(outs) 202 elif groups == 1 and layout == "HWCN": 203 return topi.generic.schedule_conv2d_hwcn(outs) 204 elif groups != 1: 205 # collect in_channels to distinguish depthwise and group conv2d 206 op = _find_conv2d_op(outs[0].op) 207 assert op is not None 208 209 is_depthwise = 'depthwise' in op.tag 210 if is_depthwise: 211 if layout == "NCHW": 212 # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d. 213 return topi.generic.schedule_depthwise_conv2d_nchw(outs) 214 if layout == "NHWC" and kernel_layout == "HWOI": 215 return topi.generic.schedule_depthwise_conv2d_nhwc(outs) 216 else: 217 if layout in ["NCHW", "NCHW4c"]: 218 return topi.generic.schedule_group_conv2d_nchw(outs) 219 raise ValueError("No compatible schedule") 220 221 222@reg.register_alter_op_layout("nn.conv2d") 223def alter_op_layout_conv2d(attrs, inputs, tinfos): 224 """Alternate the layout of conv2d""" 225 from ... import op 226 return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op) 227 228@reg.register_legalize("nn.conv2d") 229def legalize_conv2d(attrs, inputs, types): 230 """Legalize conv2d op. 231 232 Parameters 233 ---------- 234 attrs : tvm.attrs.Attrs 235 Attributes of current convolution 236 inputs : list of tvm.relay.Expr 237 The args of the Relay expr to be legalized 238 types : list of types 239 List of input and output types 240 241 Returns 242 ------- 243 result : tvm.relay.Expr 244 The legalized expr 245 """ 246 return topi.nn.conv2d_legalize(attrs, inputs, types) 247 248reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) 249 250 251# conv2d_transpose 252@reg.register_compute("nn.conv2d_transpose") 253def compute_conv2d_transpose(attrs, inputs, out_dtype, target): 254 """Compute definition of conv2d_transpose""" 255 padding = get_const_tuple(attrs.padding) 256 strides = get_const_tuple(attrs.strides) 257 dilation = get_const_tuple(attrs.dilation) 258 groups = attrs.groups 259 layout = attrs.data_layout 260 out_dtype = attrs.out_dtype 261 out_dtype = (inputs[0].dtype if out_dtype in ("same", "") 262 else out_dtype) 263 assert layout == "NCHW", "only support nchw for now" 264 assert dilation == (1, 1), "not support dilate now" 265 assert groups == 1, "only support groups == 1 for now" 266 out = topi.nn.conv2d_transpose_nchw( 267 inputs[0], inputs[1], strides, padding, out_dtype) 268 output_padding = get_const_tuple(attrs.output_padding) 269 out = topi.nn.pad(out, 270 [0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]]) 271 return [out] 272 273 274@reg.register_schedule("nn.conv2d_transpose") 275def schedule_conv2d_transpose(attrs, outs, target): 276 """Schedule definition of conv2d_transpose""" 277 with target: 278 return topi.generic.schedule_conv2d_transpose_nchw(outs) 279 280 281@reg.register_legalize("nn.conv2d_transpose") 282def legalize_conv2d_transpose(attrs, inputs, types): 283 """Legalize conv2d_transpose op. 284 285 Parameters 286 ---------- 287 attrs : tvm.attrs.Attrs 288 Attributes of current Transposed convolution 289 inputs : list of tvm.relay.Expr 290 The args of the Relay expr to be legalized 291 types : list of types 292 List of input and output types 293 294 Returns 295 ------- 296 result : tvm.relay.Expr 297 The legalized expr 298 """ 299 return topi.nn.conv2d_transpose_legalize(attrs, inputs, types) 300 301reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE) 302 303# bias_add 304reg.register_schedule("nn.bias_add", schedule_injective) 305reg.register_pattern("nn.bias_add", OpPattern.BROADCAST) 306 307 308# max_pool2d 309@reg.register_schedule("nn.max_pool2d") 310def schedule_max_pool2d(attrs, outs, target): 311 """Schedule definition of max_pool2d""" 312 layout = attrs.layout 313 with target: 314 return topi.generic.schedule_pool(outs, layout) 315 316 317reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) 318 319 320# avg_pool2d 321@reg.register_schedule("nn.avg_pool2d") 322def schedule_avg_pool2d(attrs, outs, target): 323 """Schedule definition of avg_pool2d""" 324 layout = attrs.layout 325 with target: 326 return topi.generic.schedule_pool(outs, layout) 327 328 329reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) 330 331 332# max_pool2d_grad 333@reg.register_schedule("nn.max_pool2d_grad") 334def schedule_max_pool2d_grad(attrs, outs, target): 335 """Schedule definition of max_pool2d_grad""" 336 with target: 337 return topi.generic.schedule_pool_grad(outs) 338 339 340reg.register_pattern("nn.max_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE) 341 342 343# avg_pool2d_grad 344@reg.register_schedule("nn.avg_pool2d_grad") 345def schedule_avg_pool2d_grad(attrs, outs, target): 346 """Schedule definition of avg_pool2d_grad""" 347 with target: 348 return topi.generic.schedule_pool_grad(outs) 349 350 351reg.register_pattern("nn.avg_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE) 352 353 354# global_max_pool2d 355@reg.register_schedule("nn.global_max_pool2d") 356def schedule_global_max_pool2d(_, outs, target): 357 """Schedule definition of global_max_pool2d""" 358 with target: 359 return topi.generic.schedule_adaptive_pool(outs) 360 361 362reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) 363 364 365# global_avg_pool2d 366@reg.register_schedule("nn.global_avg_pool2d") 367def schedule_global_avg_pool2d(_, outs, target): 368 """Schedule definition of global_avg_pool2d""" 369 with target: 370 return topi.generic.schedule_adaptive_pool(outs) 371 372 373reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE) 374 375 376# leaky_relu 377reg.register_schedule("nn.leaky_relu", schedule_broadcast) 378reg.register_pattern("nn.leaky_relu", OpPattern.ELEMWISE) 379 380# prelu 381reg.register_schedule("nn.prelu", schedule_broadcast) 382reg.register_pattern("nn.prelu", OpPattern.BROADCAST) 383 384# flatten 385reg.register_schedule("nn.batch_flatten", schedule_broadcast) 386reg.register_pattern("nn.batch_flatten", OpPattern.INJECTIVE) 387 388 389# lrn 390@reg.register_compute("nn.lrn") 391def compute_lrn(attrs, inputs, out_dtype, target): 392 """Compute definition of lrn""" 393 assert len(inputs) == 1 394 return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis, 395 attrs.alpha, attrs.beta, attrs.bias)] 396 397 398@reg.register_schedule("nn.lrn") 399def schedule_lrn(attrs, outs, target): 400 """Schedule definition of lrn""" 401 with target: 402 return topi.generic.schedule_lrn(outs) 403 404 405reg.register_pattern("nn.lrn", OpPattern.OPAQUE) 406 407 408# l2_normalize 409@reg.register_compute("nn.l2_normalize") 410def compute_l2_normalize(attrs, inputs, out_dtype, target): 411 """Compute definition of l2 normalize""" 412 return [topi.nn.l2_normalize(inputs[0], attrs.eps, attrs.axis)] 413 414 415@reg.register_schedule("nn.l2_normalize") 416def schedule_l2_normalize(attrs, outs, target): 417 """Schedule definition of l2 normalize""" 418 with target: 419 return topi.generic.schedule_l2_normalize(outs) 420 421 422reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE) 423 424# upsampling 425reg.register_schedule("nn.upsampling", reg.schedule_injective) 426 427 428def schedule_upsampling(_, outs, target): 429 """Schedule definition of upsampling""" 430 with target: 431 return topi.generic.schedule_injective(outs) 432 433@reg.register_compute("nn.upsampling") 434def compute_upsampling(attrs, inputs, out_dtype, target): 435 scale_h = attrs.scale_h 436 scale_w = attrs.scale_w 437 layout = attrs.layout 438 method = attrs.method 439 align_corners = attrs.align_corners 440 return [topi.nn.upsampling(inputs[0], scale_h, scale_w, layout, method, align_corners)] 441 442# pad 443reg.register_schedule("nn.pad", schedule_broadcast) 444 445# mirror_pad 446reg.register_schedule("nn.mirror_pad", schedule_broadcast) 447 448@reg.register_compute("nn.mirror_pad") 449def compute_mirror_pad(attrs, inputs, out_dtype, target): 450 pad_before, pad_after = list(zip(*attrs.pad_width)) 451 mode = attrs.mode 452 out = topi.nn.mirror_pad(inputs[0], pad_before=pad_before, pad_after=pad_after, mode=mode) 453 return [out] 454 455# winograd related operators 456@reg.register_compute("nn.contrib_conv2d_winograd_without_weight_transform") 457def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_dtype, target): 458 """Compute definition of conv2d_winograd_without_weight_transform""" 459 # pylint: disable=assignment-from-no-return 460 padding = attrs.get_int_tuple("padding") 461 strides = attrs.get_int_tuple("strides") 462 dilation = attrs.get_int_tuple("dilation") 463 groups = attrs.get_int("groups") 464 data_layout = attrs.get_str("data_layout") 465 out_dtype = attrs.get_str("out_dtype") 466 tile_size = attrs.get_int("tile_size") 467 out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype 468 assert dilation == (1, 1), "Do not support dilate now" 469 assert groups == 1, "Do not supoort arbitrary group number" 470 471 out = topi.nn.conv2d_winograd_without_weight_transform( 472 inputs[0], inputs[1], strides, padding, dilation, data_layout, 473 out_dtype, tile_size) 474 475 return [out] 476 477 478@reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform") 479def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target): 480 """Schedule definition of conv2d_winograd_without_weight_transform""" 481 with target: 482 return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs) 483 484 485reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform", 486 OpPattern.OUT_ELEMWISE_FUSABLE) 487 488 489@reg.register_compute("nn.contrib_conv2d_winograd_weight_transform") 490def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype, target): 491 """Compute definition of contrib_conv2d_winograd_weight_transform""" 492 out = topi.nn.conv2d_winograd_weight_transform( 493 inputs[0], attrs.get_int('tile_size')) 494 return [out] 495 496 497@reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform") 498def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target): 499 """Schedule definition of contrib_conv2d_winograd_weight_transform""" 500 with target: 501 return topi.generic.schedule_conv2d_winograd_weight_transform(outs) 502 503 504reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform", 505 OpPattern.OUT_ELEMWISE_FUSABLE) 506 507 508# winograd nnpack related operators 509@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_without_weight_transform") 510def compute_contrib_conv2d_winograd_nnpack_without_weight_transform( 511 attrs, inputs, out_dtype, target): 512 """Compute definition of conv2d_winograd_nnpack_without_weight_transform""" 513 # pylint: disable=assignment-from-no-return 514 padding = attrs.get_int_tuple("padding") 515 strides = attrs.get_int_tuple("strides") 516 dilation = attrs.get_int_tuple("dilation") 517 groups = attrs.get_int("groups") 518 data_layout = attrs.get_str("data_layout") 519 out_dtype = attrs.get_str("out_dtype") 520 out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype 521 assert dilation == (1, 1), "Do not support dilate now" 522 assert groups == 1, "Do not supoort arbitrary group number" 523 524 # No bias 525 out = topi.nn.conv2d_winograd_nnpack_without_weight_transform( 526 inputs[0], inputs[1], None, strides, padding, dilation, data_layout, 527 out_dtype) 528 529 return [out] 530 531 532@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform") 533def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target): 534 """Schedule definition of conv2d_winograd_nnpack_without_weight_transform""" 535 with target: 536 return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs) 537 538 539reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform", 540 OpPattern.OPAQUE) 541 542 543@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_weight_transform") 544def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_dtype, target): 545 """Compute definition of contrib_conv2d_winograd_nnpack_weight_transform""" 546 convolution_algorithm = attrs.get_int('convolution_algorithm') 547 out = topi.nn.conv2d_winograd_nnpack_weight_transform( 548 inputs[0], convolution_algorithm, out_dtype) 549 return [out] 550 551 552@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform") 553def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target): 554 """Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform""" 555 with target: 556 return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs) 557 558 559reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform", 560 OpPattern.OPAQUE) 561 562 563@reg.register_compute("nn.contrib_conv2d_NCHWc") 564def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target): 565 """Compute definition of conv2d NCHWc""" 566 # pylint: disable=assignment-from-no-return 567 padding = attrs.get_int_tuple("padding") 568 strides = attrs.get_int_tuple("strides") 569 dilation = attrs.get_int_tuple("dilation") 570 data_layout = attrs.get_str("data_layout") 571 out_layout = attrs.get_str("out_layout") 572 out_dtype = attrs.get_str("out_dtype") 573 out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype 574 575 out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation, 576 data_layout, out_layout, out_dtype) 577 return [out] 578 579 580@reg.register_schedule("nn.contrib_conv2d_NCHWc") 581def schedule_contrib_conv2d_NCHWc(attrs, outs, target): 582 """Schedule definition of contrib_conv2d_NCHWc""" 583 with target: 584 return topi.generic.schedule_conv2d_NCHWc(outs) 585 586 587reg.register_pattern("nn.contrib_conv2d_NCHWc", 588 OpPattern.OUT_ELEMWISE_FUSABLE) 589 590 591@reg.register_compute("nn.contrib_conv2d_NCHWc_int8") 592def compute_contrib_conv2d_NCHWc_int8(attrs, inputs, out_dtype, target): 593 """Compute definition of conv2d NCHWc""" 594 # pylint: disable=assignment-from-no-return 595 padding = attrs.get_int_tuple("padding") 596 strides = attrs.get_int_tuple("strides") 597 dilation = attrs.get_int_tuple("dilation") 598 data_layout = attrs.get_str("data_layout") 599 out_layout = attrs.get_str("out_layout") 600 out_dtype = attrs.get_str("out_dtype") 601 out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype 602 603 out = topi.nn.conv2d_NCHWc_int8(inputs[0], inputs[1], strides, padding, dilation, 604 data_layout, out_layout, out_dtype) 605 return [out] 606 607 608@reg.register_schedule("nn.contrib_conv2d_NCHWc_int8") 609def schedule_contrib_conv2d_NCHWc_int8(attrs, outs, target): 610 """Schedule definition of contrib_conv2d_NCHWc_int8""" 611 with target: 612 return topi.generic.schedule_conv2d_NCHWc_int8(outs) 613 614 615reg.register_pattern("nn.contrib_conv2d_NCHWc_int8", 616 OpPattern.OUT_ELEMWISE_FUSABLE) 617 618 619@reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc") 620def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target): 621 """Compute definition of depthwise conv2d NCHWc""" 622 # pylint: disable=assignment-from-no-return 623 padding = attrs.get_int_tuple("padding") 624 strides = attrs.get_int_tuple("strides") 625 dilation = attrs.get_int_tuple("dilation") 626 data_layout = attrs.get_str("data_layout") 627 out_layout = attrs.get_str("out_layout") 628 out_dtype = attrs.get_str("out_dtype") 629 out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype 630 631 out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation, 632 data_layout, out_layout, out_dtype) 633 return [out] 634 635 636@reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc") 637def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target): 638 """Schedule definition of contrib_conv2d_NCHWc""" 639 with target: 640 return topi.generic.schedule_depthwise_conv2d_NCHWc(outs) 641 642 643reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc", 644 OpPattern.OUT_ELEMWISE_FUSABLE) 645 646 647@reg.register_compute("nn.deformable_conv2d") 648def compute_deformable_conv2d(attrs, inputs, out_dtype, target): 649 """Compute definition of deformable_conv2d""" 650 padding = get_const_tuple(attrs.padding) 651 strides = get_const_tuple(attrs.strides) 652 dilation = get_const_tuple(attrs.dilation) 653 deformable_groups = attrs.deformable_groups 654 groups = attrs.groups 655 out_dtype = attrs.out_dtype 656 out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype 657 with target: 658 out = topi.nn.deformable_conv2d_nchw(inputs[0], inputs[1], inputs[2], strides, padding, 659 dilation, deformable_groups, groups, out_dtype) 660 return [out] 661 662 663@reg.register_schedule("nn.deformable_conv2d") 664def schedule_deformable_conv2d(attrs, outs, target): 665 """Schedule definition of deformable_conv2d""" 666 with target: 667 return topi.generic.schedule_deformable_conv2d_nchw(outs) 668 669 670reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) 671 672 673@reg.register_compute("nn.bitpack") 674def compute_bitpack(attrs, inputs, out_dtype, target): 675 """Compute definition for bitpack""" 676 bits = attrs.bits 677 pack_axis = attrs.pack_axis 678 bit_axis = attrs.bit_axis 679 pack_type = attrs.pack_type 680 name = attrs.name 681 with target: 682 out = topi.nn.bitpack(inputs[0], bits, pack_axis, bit_axis, pack_type, 683 name) 684 return [out] 685 686@reg.register_schedule("nn.bitpack") 687def schedule_bitpack(attrs, outs, target): 688 with target: 689 return topi.generic.schedule_bitpack(outs) 690 691reg.register_pattern("nn.bitpack", OpPattern.INJECTIVE) 692 693 694@reg.register_compute("nn.bitserial_conv2d") 695def compute_bitserial_conv2d(attrs, inputs, out_dtype, target): 696 """Compute definition for bitserial conv2d.""" 697 padding = get_const_tuple(attrs.padding) 698 strides = get_const_tuple(attrs.strides) 699 activation_bits = attrs.activation_bits 700 weight_bits = attrs.weight_bits 701 layout = attrs.data_layout 702 pack_dtype = attrs.pack_dtype 703 out_dtype = attrs.out_dtype 704 unipolar = attrs.unipolar 705 if layout == 'NCHW': 706 with target: 707 out = topi.nn.bitserial_conv2d_nchw( 708 inputs[0], inputs[1], strides, padding, activation_bits, 709 weight_bits, pack_dtype, out_dtype, unipolar) 710 elif layout == 'NHWC': 711 with target: 712 out = topi.nn.bitserial_conv2d_nhwc( 713 inputs[0], inputs[1], strides, padding, activation_bits, 714 weight_bits, pack_dtype, out_dtype, unipolar) 715 else: 716 raise ValueError("Data layout not supported.") 717 718 return [out] 719 720 721@reg.register_schedule("nn.bitserial_conv2d") 722def schedule_bitserial_conv2d(attrs, outs, target): 723 """Schedule definition for bitserial conv2d.""" 724 layout = attrs.data_layout 725 if layout == 'NCHW': 726 with target: 727 return topi.generic.schedule_bitserial_conv2d_nchw(outs) 728 elif layout == 'NHWC': 729 with target: 730 return topi.generic.schedule_bitserial_conv2d_nhwc(outs) 731 else: 732 raise ValueError("Data layout not supported.") 733 734@reg.register_legalize("nn.bitserial_conv2d") 735def legalize_bitserial_conv2d(attrs, inputs, types): 736 """Legalize bitserial_conv2d op. 737 738 Parameters 739 ---------- 740 attrs : tvm.attrs.Attrs 741 Attributes of current convolution 742 inputs : list of tvm.relay.Expr 743 The args of the Relay expr to be legalized 744 types : list of types 745 List of input and output types 746 747 Returns 748 ------- 749 result : tvm.relay.Expr 750 The legalized expr 751 """ 752 return topi.nn.bitserial_conv2d_legalize(attrs, inputs, types) 753 754 755reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) 756 757 758# bitserial_dense 759@reg.register_compute("nn.bitserial_dense") 760def compute_bitserial_dense(attrs, inputs, out_type, target): 761 """Compute definition of bitserial_dense""" 762 data_bits = attrs.data_bits 763 weight_bits = attrs.weight_bits 764 pack_dtype = attrs.pack_dtype 765 out_dtype = attrs.out_dtype 766 out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype 767 unipolar = attrs.unipolar 768 return [ 769 topi.nn.bitserial_dense( 770 inputs[0], 771 inputs[1], 772 data_bits, 773 weight_bits, 774 pack_dtype, 775 out_dtype, 776 unipolar) 777 ] 778 779 780@reg.register_schedule("nn.bitserial_dense") 781def schedule_bitserial_dense(attrs, outputs, target): 782 """Schedule definition of bitserial_dense""" 783 with target: 784 return topi.generic.schedule_bitserial_dense(outputs) 785 786 787reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) 788 789 790reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE) 791 792@reg.register_compute("nn.cross_entropy") 793def compute_cross_entropy(attrs, inputs, out_dtype, target): 794 x, y = inputs 795 return [-topi.sum(topi.log(x) * y) / x.shape[0]] 796 797 798reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE) 799 800@reg.register_compute("nn.cross_entropy_with_logits") 801def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target): 802 x, y = inputs 803 return [-topi.sum(x * y) / x.shape[0]] 804 805# shape func 806@script 807def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn): 808 out = output_tensor((dshape.shape[0],), "int64") 809 ic_chunk = dshape[1] 810 height = dshape[2] 811 width = dshape[3] 812 ic_bn = dshape[4] 813 kheight = kshape[2] 814 kwidth = kshape[3] 815 dilated_kh = (kheight - 1) * dilation[0] + 1 816 dilated_kw = (kwidth - 1) * dilation[1] + 1 817 kflatten = int64(1) 818 for i in const_range(kshape.shape[0]): 819 kflatten *= kshape[i] 820 821 oc = kflatten // (kheight * kwidth * ic_chunk * ic_bn) 822 oc_chunk = oc // oc_bn 823 824 out_height = (height + 2 * padding[0] - dilated_kh) // strides[0] + 1 825 out_width = (width + 2 * padding[1] - dilated_kw) // strides[1] + 1 826 827 out[0] = dshape[0] 828 out[1] = oc_chunk 829 out[2] = out_height 830 out[3] = out_width 831 out[4] = int64(oc_bn) 832 return out 833 834@reg.register_shape_func("nn.contrib_conv2d_NCHWc", False) 835def conv2d_NCHWc_shape_func(attrs, inputs, _): 836 """ 837 Shape function for contrib_conv2d_NCHWc op. 838 """ 839 strides = get_const_tuple(attrs.strides) 840 padding = get_const_tuple(attrs.padding) 841 dilation = get_const_tuple(attrs.dilation) 842 out_layout = attrs.out_layout 843 oc_bn = int(out_layout[4:-1]) 844 845 return [_conv2d_NCHWc_shape_func(inputs[0], inputs[1], 846 convert(strides), convert(padding), 847 convert(dilation), convert(oc_bn))] 848 849@script 850def _pool2d_shape_func(data_shape, pool_size, strides, 851 padding, height_axis, width_axis): 852 out = output_tensor((data_shape.shape[0],), "int64") 853 for i in const_range(data_shape.shape[0]): 854 if i == height_axis: 855 out[i] = (data_shape[i] + padding[0] + padding[2] - pool_size[0]) // strides[0] + 1 856 elif i == width_axis: 857 out[i] = (data_shape[i] + padding[1] + padding[3] - pool_size[1]) // strides[1] + 1 858 else: 859 out[i] = data_shape[i] 860 861 return out 862 863def pool2d_shape_func(attrs, inputs, _): 864 """ 865 Shape function for pool2d op. 866 """ 867 pool_size = get_const_tuple(attrs.pool_size) 868 strides = get_const_tuple(attrs.strides) 869 padding = get_const_tuple(attrs.padding) 870 layout = attrs.layout 871 height_axis = layout.index("H") 872 width_axis = layout.index("W") 873 if len(padding) == 1: 874 padding = [padding[0]] * 4 875 elif len(padding) == 2: 876 padding = [padding[0], padding[1], padding[0], padding[1]] 877 878 return [_pool2d_shape_func(inputs[0], convert(pool_size), 879 convert(strides), convert(padding), 880 convert(height_axis), convert(width_axis))] 881 882reg.register_shape_func("nn.max_pool2d", False, pool2d_shape_func) 883reg.register_shape_func("nn.avg_pool2d", False, pool2d_shape_func) 884 885@script 886def _global_pool2d_shape_func(data_shape, height_axis, width_axis): 887 out = output_tensor((data_shape.shape[0],), "int64") 888 for i in const_range(out.shape[0]): 889 if i == height_axis or i == width_axis: 890 out[i] = int64(1) 891 else: 892 out[i] = data_shape[i] 893 894 return out 895 896def global_pool2d_shape_func(attrs, inputs, _): 897 """ 898 Shape function for global pool2d op. 899 """ 900 layout = attrs.layout 901 height_axis = width_axis = 1 902 for i, letter in enumerate(layout): 903 if letter == "H": 904 height_axis = i 905 if letter == "W": 906 width_axis = i 907 return [_global_pool2d_shape_func(inputs[0], convert(height_axis), convert(width_axis))] 908 909reg.register_shape_func("nn.global_max_pool2d", False, global_pool2d_shape_func) 910reg.register_shape_func("nn.global_avg_pool2d", False, global_pool2d_shape_func) 911 912@script 913def _batch_flatten_shape_func(data_shape): 914 out = output_tensor((2,), "int64") 915 out[0] = data_shape[0] 916 out[1] = int64(1) 917 for i in const_range(data_shape.shape[0] - 1): 918 out[1] *= data_shape[i + 1] 919 920 return out 921 922@reg.register_shape_func("nn.batch_flatten", False) 923def batch_flatten_shape_func(attrs, inputs, _): 924 """ 925 Shape function for batch_flatten op. 926 """ 927 return [_batch_flatten_shape_func(inputs[0])] 928 929@script 930def _dense_shape_func(data_shape, weight_shape): 931 out = output_tensor((data_shape.shape[0],), "int64") 932 for i in const_range(out.shape[0] - 1): 933 out[i] = data_shape[i] 934 out[out.shape[0] - 1] = weight_shape[0] 935 936 return out 937 938@reg.register_shape_func("nn.dense", False) 939def dense_shape_func(attrs, inputs, _): 940 """ 941 Shape function for dense op. 942 """ 943 ret = [_dense_shape_func(inputs[0], inputs[1])] 944 return ret 945 946@script 947def _pad_shape_func(data_shape, pad_width): 948 out = output_tensor((data_shape.shape[0],), "int64") 949 for i in const_range(out.shape[0]): 950 out[i] = data_shape[i] + pad_width[i][0] + pad_width[i][1] 951 952 return out 953 954@reg.register_shape_func("nn.pad", False) 955def pad_shape_func(attrs, inputs, _): 956 """ 957 Shape function for pad op. 958 """ 959 pad_width = [] 960 for pair in attrs.pad_width: 961 pad_width.append(get_const_tuple(pair)) 962 return [_pad_shape_func(inputs[0], convert(pad_width))] 963 964reg.register_shape_func("nn.bias_add", False, elemwise_shape_func) 965reg.register_shape_func("nn.softmax", False, elemwise_shape_func) 966reg.register_shape_func("nn.relu", False, elemwise_shape_func) 967