# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument, too-many-arguments
"""Backend compiler related feature registration"""
from __future__ import absolute_import

import topi
from topi.util import get_const_tuple
from .. import op as reg
from ..op import OpPattern, schedule_injective
from .._tensor import elemwise_shape_func
from ....api import convert
from ....hybrid import script

# relu
reg.register_schedule("nn.relu", schedule_injective)
reg.register_pattern("nn.relu", OpPattern.ELEMWISE)

# softmax
@reg.register_schedule("nn.softmax")
def schedule_softmax(_, outputs, target):
    """Schedule definition of softmax"""
    with target:
        return topi.generic.schedule_softmax(outputs)


reg.register_pattern("nn.softmax", OpPattern.OPAQUE)

schedule_broadcast = schedule_injective


@reg.register_schedule("nn.log_softmax")
def schedule_log_softmax(_, outputs, target):
    """Schedule definition of log_softmax"""
    with target:
        return topi.generic.schedule_softmax(outputs)


reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)


# dense
@reg.register_compute("nn.dense")
def compute_dense(attrs, inputs, out_type, target):
    """Compute definition of dense"""
    out_dtype = attrs.out_dtype
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
    return [topi.nn.dense(inputs[0], inputs[1], None, out_dtype)]


@reg.register_schedule("nn.dense")
def schedule_dense(attrs, outputs, target):
    """Schedule definition of dense"""
    with target:
        return topi.generic.schedule_dense(outputs)


reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# fifo_buffer
@reg.register_compute('nn.fifo_buffer')
def compute_fifo_buffer(attrs, inputs, out_type, target):
    """Compute definition of fifo_buffer"""
    return [topi.nn.fifo_buffer(inputs[0], inputs[1], axis=attrs.get_int('axis'))]


@reg.register_schedule('nn.fifo_buffer')
def schedule_fifo_buffer(attrs, outputs, target):
    """Schedule definition of fifo_buffer"""
    with target:
        return topi.generic.schedule_injective(outputs)


reg.register_pattern("nn.fifo_buffer", OpPattern.OPAQUE)


# batch_matmul
@reg.register_compute("nn.batch_matmul")
def compute_batch_matmul(attrs, inputs, out_type, target):
    """Compute definition of batch_matmul"""
    with target:
        return [topi.nn.batch_matmul(inputs[0], inputs[1])]


@reg.register_schedule("nn.batch_matmul")
def schedule_batch_matmul(attrs, outputs, target):
    """Schedule definition of batch_matmul"""
    with target:
        return topi.generic.schedule_batch_matmul(outputs)


reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)

# sparse_dense
@reg.register_compute("nn.sparse_dense")
def compute_sparse_dense(attrs, inputs, out_type, target):
    """Compute definition of sparse_dense"""
    return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]

@reg.register_schedule("nn.sparse_dense")
def schedule_sparse_dense(attrs, outputs, target):
    """Schedule definition of sparse_dense"""
    with target:
        return topi.generic.schedule_sparse_dense(outputs)

reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)

# sparse_transpose
@reg.register_compute("nn.sparse_transpose")
def compute_sparse_transpose(attrs, inputs, out_type, target):
    """Compute definition of sparse_transpose"""
    return topi.nn.sparse_transpose(inputs[0], inputs[1], inputs[2])

@reg.register_schedule("nn.sparse_transpose")
def schedule_sparse_transpose(attrs, outputs, target):
    """Schedule definition of sparse_transpose"""
    with target:
        return topi.generic.schedule_sparse_transpose(outputs)

reg.register_pattern("nn.sparse_transpose", reg.OpPattern.OUT_ELEMWISE_FUSABLE)

# conv2d
def _find_conv2d_op(op):
    """Find the op with conv2d in its tag by traversing."""
    if 'conv2d' in op.tag:
        return op
    for tensor in op.input_tensors:
        op_ = _find_conv2d_op(tensor.op)
        if op_ is not None:
            return op_
    return None


@reg.register_compute("nn.conv2d")
def compute_conv2d(attrs, inputs, out_type, target):
    """Compute definition of conv2d"""
    padding = get_const_tuple(attrs.padding)
    strides = get_const_tuple(attrs.strides)
    dilation = get_const_tuple(attrs.dilation)
    groups = attrs.groups
    layout = attrs.data_layout
    kernel_layout = attrs.kernel_layout
    out_dtype = attrs.out_dtype
    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                 else out_dtype)

    assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"]
    (dilation_h, dilation_w) = dilation
    if dilation_h < 1 or dilation_w < 1:
        raise ValueError("dilation should be a positive value")

    def _get_out_depth():
        weight_shape = get_const_tuple(inputs[1].shape)
        if kernel_layout.startswith("HW"):
            return weight_shape[2] * weight_shape[3]
        return weight_shape[0] * weight_shape[1]

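    # Dispatch to the matching TOPI compute:
    #   groups == 1                     -> regular convolution
    #   output depth == groups          -> depthwise convolution (NCHW, or NHWC with HWOI kernel)
    #   otherwise (NCHW/NCHW4c layouts) -> grouped convolution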
    if groups == 1:
        out = topi.nn.conv2d(
            inputs[0], inputs[1], strides, padding,
            dilation, layout, out_dtype)
    elif layout == "NCHW" and _get_out_depth() == groups:
        out = topi.nn.depthwise_conv2d_nchw(
            inputs[0], inputs[1], strides, padding, dilation, out_dtype)
    elif layout == "NHWC" and kernel_layout == "HWOI" and _get_out_depth() == groups:
        out = topi.nn.depthwise_conv2d_nhwc(
            inputs[0], inputs[1], strides, padding, dilation, out_dtype)
    elif layout in ['NCHW', 'NCHW4c']:
        out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
                                        out_dtype)
    else:
        raise ValueError("arbitrary group number is not supported for now")
    return [out]


@reg.register_schedule("nn.conv2d")
def schedule_conv2d(attrs, outs, target):
    """Schedule definition of conv2d"""
    groups = attrs.groups
    layout = attrs.data_layout
    kernel_layout = attrs.kernel_layout

    with target:
        if groups == 1 and layout == "NCHW":
            return topi.generic.schedule_conv2d_nchw(outs)
        elif groups == 1 and layout == "NCHW4c":
            return topi.generic.schedule_conv2d_nchw(outs)
        elif groups == 1 and layout == "NHWC":
            return topi.generic.schedule_conv2d_nhwc(outs)
        elif groups == 1 and layout == "HWCN":
            return topi.generic.schedule_conv2d_hwcn(outs)
        elif groups != 1:
            # collect in_channels to distinguish depthwise and group conv2d
            op = _find_conv2d_op(outs[0].op)
            assert op is not None

            is_depthwise = 'depthwise' in op.tag
            if is_depthwise:
                if layout == "NCHW":
                    # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
                    return topi.generic.schedule_depthwise_conv2d_nchw(outs)
                if layout == "NHWC" and kernel_layout == "HWOI":
                    return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
            else:
                if layout in ["NCHW", "NCHW4c"]:
                    return topi.generic.schedule_group_conv2d_nchw(outs)
    raise ValueError("No compatible schedule")


@reg.register_alter_op_layout("nn.conv2d")
def alter_op_layout_conv2d(attrs, inputs, tinfos):
    """Alternate the layout of conv2d"""
    from ... import op
    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos, op)

@reg.register_legalize("nn.conv2d")
def legalize_conv2d(attrs, inputs, types):
    """Legalize conv2d op.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    return topi.nn.conv2d_legalize(attrs, inputs, types)

reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# conv2d_transpose
@reg.register_compute("nn.conv2d_transpose")
def compute_conv2d_transpose(attrs, inputs, out_dtype, target):
    """Compute definition of conv2d_transpose"""
    padding = get_const_tuple(attrs.padding)
    strides = get_const_tuple(attrs.strides)
    dilation = get_const_tuple(attrs.dilation)
    groups = attrs.groups
    layout = attrs.data_layout
    out_dtype = attrs.out_dtype
    out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
                 else out_dtype)
    assert layout == "NCHW", "only the NCHW layout is supported for now"
    assert dilation == (1, 1), "dilation is not supported for now"
    assert groups == 1, "only groups == 1 is supported for now"
    out = topi.nn.conv2d_transpose_nchw(
        inputs[0], inputs[1], strides, padding, out_dtype)
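    # Apply output_padding by zero-padding the bottom/right of the result so
    # the output reaches the requested spatial size.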
    output_padding = get_const_tuple(attrs.output_padding)
    out = topi.nn.pad(out,
                      [0, 0, 0, 0], [0, 0, output_padding[0], output_padding[1]])
    return [out]


@reg.register_schedule("nn.conv2d_transpose")
def schedule_conv2d_transpose(attrs, outs, target):
    """Schedule definition of conv2d_transpose"""
    with target:
        return topi.generic.schedule_conv2d_transpose_nchw(outs)


@reg.register_legalize("nn.conv2d_transpose")
def legalize_conv2d_transpose(attrs, inputs, types):
    """Legalize conv2d_transpose op.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current transposed convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    return topi.nn.conv2d_transpose_legalize(attrs, inputs, types)

reg.register_pattern("nn.conv2d_transpose", OpPattern.OUT_ELEMWISE_FUSABLE)

# bias_add
reg.register_schedule("nn.bias_add", schedule_injective)
reg.register_pattern("nn.bias_add", OpPattern.BROADCAST)


# max_pool2d
@reg.register_schedule("nn.max_pool2d")
def schedule_max_pool2d(attrs, outs, target):
    """Schedule definition of max_pool2d"""
    layout = attrs.layout
    with target:
        return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# avg_pool2d
@reg.register_schedule("nn.avg_pool2d")
def schedule_avg_pool2d(attrs, outs, target):
    """Schedule definition of avg_pool2d"""
    layout = attrs.layout
    with target:
        return topi.generic.schedule_pool(outs, layout)


reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# max_pool2d_grad
@reg.register_schedule("nn.max_pool2d_grad")
def schedule_max_pool2d_grad(attrs, outs, target):
    """Schedule definition of max_pool2d_grad"""
    with target:
        return topi.generic.schedule_pool_grad(outs)


reg.register_pattern("nn.max_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)


# avg_pool2d_grad
@reg.register_schedule("nn.avg_pool2d_grad")
def schedule_avg_pool2d_grad(attrs, outs, target):
    """Schedule definition of avg_pool2d_grad"""
    with target:
        return topi.generic.schedule_pool_grad(outs)


reg.register_pattern("nn.avg_pool2d_grad", OpPattern.OUT_ELEMWISE_FUSABLE)


# global_max_pool2d
@reg.register_schedule("nn.global_max_pool2d")
def schedule_global_max_pool2d(_, outs, target):
    """Schedule definition of global_max_pool2d"""
    with target:
        return topi.generic.schedule_adaptive_pool(outs)


reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# global_avg_pool2d
@reg.register_schedule("nn.global_avg_pool2d")
def schedule_global_avg_pool2d(_, outs, target):
    """Schedule definition of global_avg_pool2d"""
    with target:
        return topi.generic.schedule_adaptive_pool(outs)


reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# leaky_relu
reg.register_schedule("nn.leaky_relu", schedule_broadcast)
reg.register_pattern("nn.leaky_relu", OpPattern.ELEMWISE)

# prelu
reg.register_schedule("nn.prelu", schedule_broadcast)
reg.register_pattern("nn.prelu", OpPattern.BROADCAST)

# flatten
reg.register_schedule("nn.batch_flatten", schedule_broadcast)
reg.register_pattern("nn.batch_flatten", OpPattern.INJECTIVE)


# lrn
@reg.register_compute("nn.lrn")
def compute_lrn(attrs, inputs, out_dtype, target):
    """Compute definition of lrn"""
    assert len(inputs) == 1
    return [topi.nn.lrn(inputs[0], attrs.size, attrs.axis,
                        attrs.alpha, attrs.beta, attrs.bias)]


@reg.register_schedule("nn.lrn")
def schedule_lrn(attrs, outs, target):
    """Schedule definition of lrn"""
    with target:
        return topi.generic.schedule_lrn(outs)


reg.register_pattern("nn.lrn", OpPattern.OPAQUE)


# l2_normalize
@reg.register_compute("nn.l2_normalize")
def compute_l2_normalize(attrs, inputs, out_dtype, target):
    """Compute definition of l2 normalize"""
    return [topi.nn.l2_normalize(inputs[0], attrs.eps, attrs.axis)]


@reg.register_schedule("nn.l2_normalize")
def schedule_l2_normalize(attrs, outs, target):
    """Schedule definition of l2 normalize"""
    with target:
        return topi.generic.schedule_l2_normalize(outs)


reg.register_pattern("nn.l2_normalize", OpPattern.OUT_ELEMWISE_FUSABLE)

# upsampling
reg.register_schedule("nn.upsampling", reg.schedule_injective)


def schedule_upsampling(_, outs, target):
    """Schedule definition of upsampling"""
    with target:
        return topi.generic.schedule_injective(outs)

@reg.register_compute("nn.upsampling")
def compute_upsampling(attrs, inputs, out_dtype, target):
    """Compute definition of upsampling"""
    scale_h = attrs.scale_h
    scale_w = attrs.scale_w
    layout = attrs.layout
    method = attrs.method
    align_corners = attrs.align_corners
    return [topi.nn.upsampling(inputs[0], scale_h, scale_w, layout, method, align_corners)]

# pad
reg.register_schedule("nn.pad", schedule_broadcast)

# mirror_pad
reg.register_schedule("nn.mirror_pad", schedule_broadcast)

@reg.register_compute("nn.mirror_pad")
def compute_mirror_pad(attrs, inputs, out_dtype, target):
    """Compute definition of mirror_pad"""
    pad_before, pad_after = list(zip(*attrs.pad_width))
    mode = attrs.mode
    out = topi.nn.mirror_pad(inputs[0], pad_before=pad_before, pad_after=pad_after, mode=mode)
    return [out]

# winograd related operators
@reg.register_compute("nn.contrib_conv2d_winograd_without_weight_transform")
def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_dtype, target):
    """Compute definition of conv2d_winograd_without_weight_transform"""
    # pylint: disable=assignment-from-no-return
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    groups = attrs.get_int("groups")
    data_layout = attrs.get_str("data_layout")
    out_dtype = attrs.get_str("out_dtype")
    tile_size = attrs.get_int("tile_size")
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
    assert dilation == (1, 1), "Do not support dilation for now"
    assert groups == 1, "Do not support arbitrary group number"

    out = topi.nn.conv2d_winograd_without_weight_transform(
        inputs[0], inputs[1], strides, padding, dilation, data_layout,
        out_dtype, tile_size)

    return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform")
def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
    """Schedule definition of conv2d_winograd_without_weight_transform"""
    with target:
        return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs)


reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
                     OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype, target):
    """Compute definition of contrib_conv2d_winograd_weight_transform"""
    out = topi.nn.conv2d_winograd_weight_transform(
        inputs[0], attrs.get_int('tile_size'))
    return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform")
def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
    """Schedule definition of contrib_conv2d_winograd_weight_transform"""
    with target:
        return topi.generic.schedule_conv2d_winograd_weight_transform(outs)


reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
                     OpPattern.OUT_ELEMWISE_FUSABLE)


# winograd nnpack related operators
@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(
        attrs, inputs, out_dtype, target):
    """Compute definition of conv2d_winograd_nnpack_without_weight_transform"""
    # pylint: disable=assignment-from-no-return
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    groups = attrs.get_int("groups")
    data_layout = attrs.get_str("data_layout")
    out_dtype = attrs.get_str("out_dtype")
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
    assert dilation == (1, 1), "Do not support dilation for now"
    assert groups == 1, "Do not support arbitrary group number"

    # No bias
    out = topi.nn.conv2d_winograd_nnpack_without_weight_transform(
        inputs[0], inputs[1], None, strides, padding, dilation, data_layout,
        out_dtype)

    return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
    """Schedule definition of conv2d_winograd_nnpack_without_weight_transform"""
    with target:
        return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)


reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform",
                     OpPattern.OPAQUE)


@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_dtype, target):
    """Compute definition of contrib_conv2d_winograd_nnpack_weight_transform"""
    convolution_algorithm = attrs.get_int('convolution_algorithm')
    out = topi.nn.conv2d_winograd_nnpack_weight_transform(
        inputs[0], convolution_algorithm, out_dtype)
    return [out]


@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
    """Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform"""
    with target:
        return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)


reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform",
                     OpPattern.OPAQUE)


@reg.register_compute("nn.contrib_conv2d_NCHWc")
def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target):
    """Compute definition of conv2d NCHWc"""
    # pylint: disable=assignment-from-no-return
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    data_layout = attrs.get_str("data_layout")
    out_layout = attrs.get_str("out_layout")
    out_dtype = attrs.get_str("out_dtype")
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype

    out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
                               data_layout, out_layout, out_dtype)
    return [out]


@reg.register_schedule("nn.contrib_conv2d_NCHWc")
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
    """Schedule definition of contrib_conv2d_NCHWc"""
    with target:
        return topi.generic.schedule_conv2d_NCHWc(outs)


reg.register_pattern("nn.contrib_conv2d_NCHWc",
                     OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.contrib_conv2d_NCHWc_int8")
def compute_contrib_conv2d_NCHWc_int8(attrs, inputs, out_dtype, target):
    """Compute definition of int8 conv2d NCHWc"""
    # pylint: disable=assignment-from-no-return
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    data_layout = attrs.get_str("data_layout")
    out_layout = attrs.get_str("out_layout")
    out_dtype = attrs.get_str("out_dtype")
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype

    out = topi.nn.conv2d_NCHWc_int8(inputs[0], inputs[1], strides, padding, dilation,
                                    data_layout, out_layout, out_dtype)
    return [out]


@reg.register_schedule("nn.contrib_conv2d_NCHWc_int8")
def schedule_contrib_conv2d_NCHWc_int8(attrs, outs, target):
    """Schedule definition of contrib_conv2d_NCHWc_int8"""
    with target:
        return topi.generic.schedule_conv2d_NCHWc_int8(outs)


reg.register_pattern("nn.contrib_conv2d_NCHWc_int8",
                     OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
    """Compute definition of depthwise conv2d NCHWc"""
    # pylint: disable=assignment-from-no-return
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    data_layout = attrs.get_str("data_layout")
    out_layout = attrs.get_str("out_layout")
    out_dtype = attrs.get_str("out_dtype")
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype

    out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
                                         data_layout, out_layout, out_dtype)
    return [out]


@reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
    """Schedule definition of contrib_depthwise_conv2d_NCHWc"""
    with target:
        return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)


reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
                     OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.deformable_conv2d")
def compute_deformable_conv2d(attrs, inputs, out_dtype, target):
    """Compute definition of deformable_conv2d"""
    padding = get_const_tuple(attrs.padding)
    strides = get_const_tuple(attrs.strides)
    dilation = get_const_tuple(attrs.dilation)
    deformable_groups = attrs.deformable_groups
    groups = attrs.groups
    out_dtype = attrs.out_dtype
    out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
    with target:
        out = topi.nn.deformable_conv2d_nchw(inputs[0], inputs[1], inputs[2], strides, padding,
                                             dilation, deformable_groups, groups, out_dtype)
    return [out]


@reg.register_schedule("nn.deformable_conv2d")
def schedule_deformable_conv2d(attrs, outs, target):
    """Schedule definition of deformable_conv2d"""
    with target:
        return topi.generic.schedule_deformable_conv2d_nchw(outs)


reg.register_pattern("nn.deformable_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.bitpack")
def compute_bitpack(attrs, inputs, out_dtype, target):
    """Compute definition for bitpack"""
    bits = attrs.bits
    pack_axis = attrs.pack_axis
    bit_axis = attrs.bit_axis
    pack_type = attrs.pack_type
    name = attrs.name
    with target:
        out = topi.nn.bitpack(inputs[0], bits, pack_axis, bit_axis, pack_type,
                              name)
    return [out]

@reg.register_schedule("nn.bitpack")
def schedule_bitpack(attrs, outs, target):
    """Schedule definition for bitpack"""
    with target:
        return topi.generic.schedule_bitpack(outs)

reg.register_pattern("nn.bitpack", OpPattern.INJECTIVE)


@reg.register_compute("nn.bitserial_conv2d")
def compute_bitserial_conv2d(attrs, inputs, out_dtype, target):
    """Compute definition for bitserial conv2d."""
    padding = get_const_tuple(attrs.padding)
    strides = get_const_tuple(attrs.strides)
    activation_bits = attrs.activation_bits
    weight_bits = attrs.weight_bits
    layout = attrs.data_layout
    pack_dtype = attrs.pack_dtype
    out_dtype = attrs.out_dtype
    unipolar = attrs.unipolar
    if layout == 'NCHW':
        with target:
            out = topi.nn.bitserial_conv2d_nchw(
                inputs[0], inputs[1], strides, padding, activation_bits,
                weight_bits, pack_dtype, out_dtype, unipolar)
    elif layout == 'NHWC':
        with target:
            out = topi.nn.bitserial_conv2d_nhwc(
                inputs[0], inputs[1], strides, padding, activation_bits,
                weight_bits, pack_dtype, out_dtype, unipolar)
    else:
        raise ValueError("Data layout not supported.")

    return [out]


@reg.register_schedule("nn.bitserial_conv2d")
def schedule_bitserial_conv2d(attrs, outs, target):
    """Schedule definition for bitserial conv2d."""
    layout = attrs.data_layout
    if layout == 'NCHW':
        with target:
            return topi.generic.schedule_bitserial_conv2d_nchw(outs)
    elif layout == 'NHWC':
        with target:
            return topi.generic.schedule_bitserial_conv2d_nhwc(outs)
    else:
        raise ValueError("Data layout not supported.")

@reg.register_legalize("nn.bitserial_conv2d")
def legalize_bitserial_conv2d(attrs, inputs, types):
    """Legalize bitserial_conv2d op.

    Parameters
    ----------
    attrs : tvm.attrs.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    types : list of types
        List of input and output types

    Returns
    -------
    result : tvm.relay.Expr
        The legalized expr
    """
    return topi.nn.bitserial_conv2d_legalize(attrs, inputs, types)


reg.register_pattern("nn.bitserial_conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# bitserial_dense
@reg.register_compute("nn.bitserial_dense")
def compute_bitserial_dense(attrs, inputs, out_type, target):
    """Compute definition of bitserial_dense"""
    data_bits = attrs.data_bits
    weight_bits = attrs.weight_bits
    pack_dtype = attrs.pack_dtype
    out_dtype = attrs.out_dtype
    out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
    unipolar = attrs.unipolar
    return [
        topi.nn.bitserial_dense(
            inputs[0],
            inputs[1],
            data_bits,
            weight_bits,
            pack_dtype,
            out_dtype,
            unipolar)
    ]


@reg.register_schedule("nn.bitserial_dense")
def schedule_bitserial_dense(attrs, outputs, target):
    """Schedule definition of bitserial_dense"""
    with target:
        return topi.generic.schedule_bitserial_dense(outputs)


reg.register_pattern("nn.bitserial_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)


# cross_entropy
reg.register_pattern("nn.cross_entropy", OpPattern.OPAQUE)

@reg.register_compute("nn.cross_entropy")
def compute_cross_entropy(attrs, inputs, out_dtype, target):
    """Compute definition of cross_entropy"""
    x, y = inputs
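    # Mean cross entropy over the batch: -sum(y * log(x)) / batch_size,
    # where x holds predicted probabilities and y the target distribution.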
    return [-topi.sum(topi.log(x) * y) / x.shape[0]]


# cross_entropy_with_logits
reg.register_pattern("nn.cross_entropy_with_logits", OpPattern.OPAQUE)

@reg.register_compute("nn.cross_entropy_with_logits")
def compute_cross_entropy_with_logits(attrs, inputs, out_dtype, target):
    """Compute definition of cross_entropy_with_logits"""
    x, y = inputs
    return [-topi.sum(x * y) / x.shape[0]]

# shape func
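# The hybrid-script functions below compute output shapes at runtime so these
# ops can run on inputs whose shapes are not known at compile time. The boolean
# passed to register_shape_func indicates whether the shape function needs the
# input data itself (True) or only the input shapes (False).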
@script
def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn):
    out = output_tensor((dshape.shape[0],), "int64")
    ic_chunk = dshape[1]
    height = dshape[2]
    width = dshape[3]
    ic_bn = dshape[4]
    kheight = kshape[2]
    kwidth = kshape[3]
    dilated_kh = (kheight - 1) * dilation[0] + 1
    dilated_kw = (kwidth - 1) * dilation[1] + 1
    kflatten = int64(1)
    for i in const_range(kshape.shape[0]):
        kflatten *= kshape[i]

    oc = kflatten // (kheight * kwidth * ic_chunk * ic_bn)
    oc_chunk = oc // oc_bn

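    # Standard convolution output size: (in + 2 * pad - dilated_kernel) // stride + 1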
    out_height = (height + 2 * padding[0] - dilated_kh) // strides[0] + 1
    out_width = (width + 2 * padding[1] - dilated_kw) // strides[1] + 1

    out[0] = dshape[0]
    out[1] = oc_chunk
    out[2] = out_height
    out[3] = out_width
    out[4] = int64(oc_bn)
    return out

@reg.register_shape_func("nn.contrib_conv2d_NCHWc", False)
def conv2d_NCHWc_shape_func(attrs, inputs, _):
    """
    Shape function for contrib_conv2d_NCHWc op.
    """
    strides = get_const_tuple(attrs.strides)
    padding = get_const_tuple(attrs.padding)
    dilation = get_const_tuple(attrs.dilation)
    out_layout = attrs.out_layout
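    # The digits between "NCHW" and the trailing "c" in out_layout give the
    # output-channel block size, e.g. "NCHW16c" -> oc_bn = 16.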
    oc_bn = int(out_layout[4:-1])

    return [_conv2d_NCHWc_shape_func(inputs[0], inputs[1],
                                     convert(strides), convert(padding),
                                     convert(dilation), convert(oc_bn))]

@script
def _pool2d_shape_func(data_shape, pool_size, strides,
                       padding, height_axis, width_axis):
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(data_shape.shape[0]):
        if i == height_axis:
            out[i] = (data_shape[i] + padding[0] + padding[2] - pool_size[0]) // strides[0] + 1
        elif i == width_axis:
            out[i] = (data_shape[i] + padding[1] + padding[3] - pool_size[1]) // strides[1] + 1
        else:
            out[i] = data_shape[i]

    return out

def pool2d_shape_func(attrs, inputs, _):
    """
    Shape function for pool2d op.
    """
    pool_size = get_const_tuple(attrs.pool_size)
    strides = get_const_tuple(attrs.strides)
    padding = get_const_tuple(attrs.padding)
    layout = attrs.layout
    height_axis = layout.index("H")
    width_axis = layout.index("W")
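    # Normalize padding to four values (top, left, bottom, right) so the shape
    # script can always read pad-begin and pad-end for each spatial axis.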
    if len(padding) == 1:
        padding = [padding[0]] * 4
    elif len(padding) == 2:
        padding = [padding[0], padding[1], padding[0], padding[1]]

    return [_pool2d_shape_func(inputs[0], convert(pool_size),
                               convert(strides), convert(padding),
                               convert(height_axis), convert(width_axis))]

reg.register_shape_func("nn.max_pool2d", False, pool2d_shape_func)
reg.register_shape_func("nn.avg_pool2d", False, pool2d_shape_func)

@script
def _global_pool2d_shape_func(data_shape, height_axis, width_axis):
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(out.shape[0]):
        if i == height_axis or i == width_axis:
            out[i] = int64(1)
        else:
            out[i] = data_shape[i]

    return out

def global_pool2d_shape_func(attrs, inputs, _):
    """
    Shape function for global pool2d op.
    """
    layout = attrs.layout
    height_axis = width_axis = 1
    for i, letter in enumerate(layout):
        if letter == "H":
            height_axis = i
        if letter == "W":
            width_axis = i
    return [_global_pool2d_shape_func(inputs[0], convert(height_axis), convert(width_axis))]

reg.register_shape_func("nn.global_max_pool2d", False, global_pool2d_shape_func)
reg.register_shape_func("nn.global_avg_pool2d", False, global_pool2d_shape_func)

@script
def _batch_flatten_shape_func(data_shape):
    out = output_tensor((2,), "int64")
    out[0] = data_shape[0]
    out[1] = int64(1)
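    # Keep the batch dimension and collapse every remaining dimension into one.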
    for i in const_range(data_shape.shape[0] - 1):
        out[1] *= data_shape[i + 1]

    return out

@reg.register_shape_func("nn.batch_flatten", False)
def batch_flatten_shape_func(attrs, inputs, _):
    """
    Shape function for batch_flatten op.
    """
    return [_batch_flatten_shape_func(inputs[0])]

@script
def _dense_shape_func(data_shape, weight_shape):
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(out.shape[0] - 1):
        out[i] = data_shape[i]
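    # The last output dimension is the number of units, i.e. the number of rows
    # of the (units, input_dim) weight matrix.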
    out[out.shape[0] - 1] = weight_shape[0]

    return out

@reg.register_shape_func("nn.dense", False)
def dense_shape_func(attrs, inputs, _):
    """
    Shape function for dense op.
    """
    ret = [_dense_shape_func(inputs[0], inputs[1])]
    return ret

@script
def _pad_shape_func(data_shape, pad_width):
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(out.shape[0]):
        out[i] = data_shape[i] + pad_width[i][0] + pad_width[i][1]

    return out

@reg.register_shape_func("nn.pad", False)
def pad_shape_func(attrs, inputs, _):
    """
    Shape function for pad op.
    """
    pad_width = []
    for pair in attrs.pad_width:
        pad_width.append(get_const_tuple(pair))
    return [_pad_shape_func(inputs[0], convert(pad_width))]

reg.register_shape_func("nn.bias_add", False, elemwise_shape_func)
reg.register_shape_func("nn.softmax", False, elemwise_shape_func)
reg.register_shape_func("nn.relu", False, elemwise_shape_func)