1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
18"""Conv2D schedule on x86"""
19
20import logging
21import re
22
23import tvm
24from tvm import autotvm
25from tvm.autotvm.task.topi_integration import deserialize_args
26from tvm.autotvm.task import get_config
27from .. import generic, tag
28from .. import nn
29from ..nn.conv2d import conv2d, conv2d_NCHWc, \
30    conv2d_infer_layout, _get_workload as _get_conv2d_workload
31from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
32from ..nn.pad import pad
33from ..util import get_const_tuple
34
35from . import conv2d_avx_1x1, conv2d_avx_common
36
# Logger obtained under the shared 'topi' name (not this module's __name__).
logger = logging.getLogger('topi')
38
def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
                        layout='NCHW'):
    """Populate ``cfg`` with the fallback (untuned) schedule for a convolution.

    Any symbolic ``tvm.expr.Var`` extent in ``data``'s shape is replaced by 1 so
    that a static workload description can be built from the shapes.

    Parameters
    ----------
    cfg : autotvm config
        Config object handed to the fallback-schedule helper.
    data, kernel : tvm.Tensor
        Convolution input and weight.
    strides, padding :
        Forwarded unchanged to the workload helper.
    out_dtype : str
        Output data type, forwarded to the workload helper.
    is_depthwise : bool
        When True, the depthwise workload and fallback helpers are used.
    layout : str
        Data layout; only used for the regular (non-depthwise) workload.
    """
    concrete_shape = [1 if isinstance(extent, tvm.expr.Var) else extent
                      for extent in get_const_tuple(data.shape)]
    data = tvm.placeholder(concrete_shape, dtype=data.dtype)

    if not is_depthwise:
        wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
        # Spatially 1x1 kernels have their own fallback schedule.
        if wkl.hkernel == 1 and wkl.wkernel == 1:
            fallback_module = conv2d_avx_1x1
        else:
            fallback_module = conv2d_avx_common
        fallback_module._fallback_schedule(cfg, wkl)
        return

    wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
    from .depthwise_conv2d import _fallback_schedule
    _fallback_schedule(cfg, wkl)
62
def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
    """Define the x86 conv2d tuning knobs on ``cfg`` from the input arguments.

    Knobs defined: ``tile_ic``, ``tile_oc`` and ``tile_ow`` (2-way splits;
    ``tile_ow`` inner factor limited to 64), plus ``tile_oh`` for 1x1 kernels
    or ``unroll_kw`` otherwise.

    Parameters
    ----------
    cfg : autotvm config
        Config space/entity on which the knobs are defined.
    data, kernel : tvm.Tensor
        Convolution input and weight. Shapes are read according to ``layout``:
        NCHW data with (oc, ic, kh, kw) kernel; NHWC data with
        (kh, kw, oc, ic) kernel; blocked NCHW<n>c data with
        (oc_chunk, ic_chunk, kh, kw, ic_bn, oc_bn) kernel.
    strides, padding, dilation : int or pair of ints
        A scalar applies to both spatial axes. ``dilation=None`` is treated
        as 1.
    layout : str
        'NCHW', 'NHWC' or a blocked layout such as 'NCHW8c' / 'NCHW16c'.

    Raises
    ------
    ValueError
        If ``layout`` is not supported.
    """
    dshape = get_const_tuple(data.shape)
    kshape = get_const_tuple(kernel.shape)
    # Blocked layouts (NCHW<n>c). ``.*`` rather than ``.+`` so that a
    # single-digit block size such as NCHW8c is also recognised.
    blocked_layout = re.compile(r'NCHW.*\d+c')
    if layout == 'NCHW':
        n, ic, h, w = dshape
        oc, _, kh, kw = kshape
    elif layout == 'NHWC':
        n, h, w, ic = dshape
        kh, kw, oc, _ = kshape
    elif blocked_layout.match(layout) is not None:
        n, ic_chunk, h, w, ic_bn = dshape
        # Blocked layouts require an active target (raises otherwise).
        tvm.target.current_target(allow_none=False)
        oc_chunk, k_ic_chunk, kh, kw, k_ic_bn, oc_bn = kshape
        assert ic_chunk == k_ic_chunk
        assert ic_bn == k_ic_bn
        ic = ic_chunk * ic_bn
        oc = oc_chunk * oc_bn
    else:
        raise ValueError("Not support this layout {} with "
                         "schedule template.".format(layout))

    is_kernel_1x1 = kh == 1 and kw == 1
    if dilation is None:
        dilation = 1
    ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
    sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
    # Use the dilated kernel extent so the output extents (and hence the
    # tile_ow split space) match the shape the compute actually produces.
    dilated_kh = (kh - 1) * dh + 1
    dilated_kw = (kw - 1) * dw + 1
    oh = (h - dilated_kh + 2 * ph) // sh + 1
    ow = (w - dilated_kw + 2 * pw) // sw + 1

    # Create schedule config
    cfg.define_split("tile_ic", ic, num_outputs=2)
    cfg.define_split("tile_oc", oc, num_outputs=2)
    cfg.define_split("tile_ow", ow, num_outputs=2, filter=lambda y: y.size[-1] <= 64)
    if is_kernel_1x1:
        cfg.define_knob("tile_oh", [1, 2] if oh > 1 else [1])
    else:
        cfg.define_knob("unroll_kw", [True, False])
100
101
@autotvm.register_topi_compute(conv2d, 'cpu', ['direct'])
def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
    """Declare conv2d for CPU, dispatching on the data ``layout``.

    * ``NCHW``: defines the tuning space, fills a fallback config when needed
      and returns the channel-packed compute from ``_declaration_conv_impl``.
    * ``HWCN``: generic ``nn.conv2d_hwcn``.
    * ``NHWC``: generic ``nn.conv2d_nhwc``.

    ``strides``, ``padding`` and ``dilation`` may be scalars or pairs; a
    scalar ``v`` is expanded to ``(v, v)``. ``out_dtype=None`` means the
    output uses ``data.dtype``.

    Raises
    ------
    ValueError
        For any other layout.
    """
    def as_pair(value):
        # Sequences pass through untouched; scalars apply to both axes.
        return value if isinstance(value, (tuple, list)) else (value, value)

    if out_dtype is None:
        out_dtype = data.dtype
    padding = as_pair(padding)
    strides = as_pair(strides)
    dilation = as_pair(dilation)

    if layout == 'NCHW':
        _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout)
        if cfg.is_fallback:
            _get_default_config(cfg, data, kernel, strides, padding, out_dtype)
        return _declaration_conv_impl(cfg, data, kernel, strides,
                                      padding, dilation, layout, out_dtype)

    # HWOI kernel layout is for NHWC and HWCN.
    # kh/kw are only consumed by the disabled NHWC int8 path below, but the
    # unpacking still requires a 4-D kernel.
    kh, kw, _, _ = get_const_tuple(kernel.shape)
    if layout == 'HWCN':
        return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype)
    # FIXME - https://github.com/apache/incubator-tvm/issues/4122
    # _declaration_conv_nhwc_pack expects kernel layout to be HWOI. However, the tests use HWIO
    # layout. Commenting until we have clarity about the nhwc_pack implementation from the author.
    # elif layout == 'NHWC' and kh == 1 and kw == 1 and kernel.dtype == "int8":
    #     if cfg.is_fallback:
    #         _get_default_config(cfg, data, kernel, strides, padding, out_dtype, False, layout)
    #     # specialize for INT8 1X1 conv on X86
    #     return conv2d_avx_1x1._declaration_conv_nhwc_pack(cfg, data, kernel, strides,
    #                                                       padding, dilation, out_dtype)
    if layout == 'NHWC':
        return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype)
    raise ValueError("not support this layout {} yet".format(layout))
132
133
def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
    """Declare an NCHW conv2d that computes on channel-blocked copies.

    The (optionally padded) input is repacked to
    ``(N, C // ic_bn, H_pad, ic_bn, W_pad)``. The OIHW kernel is repacked to
    ``(O // oc_bn, I // ic_bn, KH, KW, ic_bn, oc_bn)``. The blocked convolution
    ``(N, O // oc_bn, OH, OW, oc_bn)`` is then unpacked back to NCHW. Block
    sizes are the inner factors of ``cfg["tile_ic"]`` / ``cfg["tile_oc"]``.

    NOTE(review): the packed shapes use floor division, so this presumably
    relies on the channel counts being divisible by the block sizes — confirm.

    Parameters
    ----------
    cfg : autotvm config
        Must define the ``tile_ic`` and ``tile_oc`` splits.
    data, kernel : tvm.Tensor
        NCHW input and OIHW-indexed weight.
    strides, padding : int or pair of ints
        A scalar applies to both spatial axes.
    dilation : int or pair of ints
        A scalar applies to both spatial axes.
    layout : str
        Must be 'NCHW'.
    out_dtype : str or None
        Accumulation/output dtype; ``None`` means ``data.dtype``.

    Returns
    -------
    tvm.Tensor
        The NCHW result, named 'output_unpack' and tagged 'conv2d_nchw'.
    """
    out_dtype = data.dtype if out_dtype is None else out_dtype
    assert layout == 'NCHW', "only support NCHW convolution for AVX"

    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(dilation, int):
        # A scalar dilation applies to both spatial axes (an int cannot be
        # unpacked into two values).
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)

    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
    num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD

    # Effective kernel extent once dilation gaps are included.
    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
    out_height = (in_height + 2 * HPAD - dilated_kernel_h) // HSTR + 1
    out_width = (in_width + 2 * WPAD - dilated_kernel_w) // WSTR + 1

    # pack data
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    # fetch schedule
    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]

    shape = (batch_size, in_channel // ic_bn, pad_height, ic_bn, pad_width)
    data_vec = tvm.compute(shape,
                           lambda n, C, h, c, w: data_pad[n, C * ic_bn + c, h, w],
                           name='data_vec')

    # pack kernel
    shape = (num_filter//oc_bn, in_channel//ic_bn,
             kernel_height, kernel_width, ic_bn, oc_bn)
    kernel_vec = tvm.compute(shape,
                             lambda CO, CI, h, w, ci, co:
                             kernel[CO * oc_bn + co, CI * ic_bn + ci, h, w],
                             name='kernel_vec')

    # convolution
    oshape = (batch_size, num_filter//oc_bn, out_height, out_width, oc_bn)
    unpack_shape = (batch_size, num_filter, out_height, out_width)

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')
    idxmod = tvm.indexmod
    idxdiv = tvm.indexdiv

    # The flat input-channel reduction axis is split back into
    # (block index, offset within block) to address the packed tensors.
    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                       tvm.sum(data_vec[n, idxdiv(ic, ic_bn), oh*HSTR+kh*dilation_h,
                                        idxmod(ic, ic_bn),
                                        ow*WSTR+kw*dilation_w].astype(out_dtype) *
                               kernel_vec[oc_chunk, idxdiv(ic, ic_bn), kh, kw,
                                          idxmod(ic, ic_bn),
                                          oc_block].astype(out_dtype),
                               axis=[ic, kh, kw]), name='conv')

    unpack = tvm.compute(unpack_shape,
                         lambda n, c, h, w: conv[n, idxdiv(c, oc_bn), h, w, idxmod(c, oc_bn)]
                         .astype(out_dtype),
                         name='output_unpack',
                         tag='conv2d_nchw')
    return unpack
206
207
@autotvm.register_topi_schedule(generic.schedule_conv2d_nchw, 'cpu', ['direct'])
def schedule_conv2d(cfg, outs):
    """Create the x86 schedule for NCHW conv2d.

    The graph is walked from ``outs[0]``. Broadcast stages that are not
    outputs are inlined. Every stage tagged ``conv2d_nchw`` is handed to the
    AVX schedule, together with its packed producers (``conv``, ``kernel_vec``,
    ``data_vec`` and an optional pad stage). The 1x1 variant is used when the
    kernel is spatially 1x1, and the common variant otherwise.
    """
    sch = tvm.create_schedule([out.op for out in outs])
    visited = []

    def schedule_conv_stage(op):
        """Locate the packed producers of ``op`` and apply the AVX schedule."""
        output = op.output(0)
        conv_out = op.input_tensors[0]
        kernel_vec = conv_out.op.input_tensors[1]
        kernel = kernel_vec.op.input_tensors[0]
        # A dilated kernel is a separate compute stage; fold it into its consumer.
        if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
            sch[kernel].compute_inline()
        data_vec = conv_out.op.input_tensors[0]
        data = data_vec.op.input_tensors[0]
        data_pad = None
        if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
            data_pad, data = data, data.op.input_tensors[0]

        _, _, kernel_h, kernel_w = get_const_tuple(kernel.shape)
        if kernel_h == 1 and kernel_w == 1:
            avx_schedule = conv2d_avx_1x1._schedule_conv
        else:
            avx_schedule = conv2d_avx_common._schedule_conv
        avx_schedule(sch, cfg, data, data_pad, data_vec, kernel_vec,
                     conv_out, output, outs[0])

    def visit(op):
        """Schedule ``op`` and recurse into its not-yet-visited compute producers."""
        if tag.is_broadcast(op.tag):
            # Elementwise stages are inlined unless they are graph outputs.
            if op not in sch.outputs:
                sch[op].compute_inline()
            for producer in op.input_tensors:
                is_compute = isinstance(producer.op, tvm.tensor.ComputeOp)
                if is_compute and producer.op not in visited:
                    visit(producer.op)

        if 'conv2d_nchw' in op.tag:
            schedule_conv_stage(op)

        visited.append(op)

    visit(outs[0].op)
    return sch
250
@generic.schedule_conv2d_nhwc.register("cpu")
def schedule_conv2d_nhwc(outs):
    """Create the x86 schedule for NHWC conv2d.

    Non-output broadcast stages are inlined. A 4-D broadcast output stage
    (e.g. bias + bn + relu) is parallelised over fused N/H/W and vectorised
    over C. The ``conv2d_nhwc`` stage is vectorised over C and either computed
    at the output's innermost axis (when a fused epilogue follows) or
    parallelised over fused N/H/W (when it is itself the output). A pad stage
    feeding the conv, if present, is parallelised over fused N/H.

    Parameters
    ----------
    outs : list of tvm.Tensor
        Output tensors; ``outs[0]`` is treated as the final stage.

    Returns
    -------
    s : tvm.Schedule
        The created schedule.
    """
    s = tvm.create_schedule([x.op for x in outs])
    output_op = outs[0].op
    scheduled_ops = []

    def traverse(op):
        """Traverse operators from computation graph"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            elif len(op.axis) == 4:
                # inject custom schedule for the output (bias + bn + relu)
                n, h, w, c = op.axis
                fused = s[op].fuse(n, h, w)
                s[op].parallel(fused)
                s[op].vectorize(c)
            for tensor in op.input_tensors:
                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                    traverse(tensor.op)

        if 'conv2d_nhwc' in op.tag:
            conv = op.output(0)
            kernel = op.input_tensors[1]
            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
                s[kernel].compute_inline()

            data = op.input_tensors[0]
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                # Only an explicit pad stage can be scheduled here; an unpadded
                # input has no such stage (data_pad would be None).
                data_pad = data
                n_pad, h_pad, _, _ = data_pad.op.axis
                pad_fused = s[data_pad].fuse(n_pad, h_pad)
                s[data_pad].parallel(pad_fused)

            n, h, w, c = conv.op.axis
            s[conv].vectorize(c)
            if op != output_op:
                # fuse bias + bn + relu into conv
                s[conv].compute_at(s[output_op], output_op.axis[-1])
            else:
                fused = s[conv].fuse(n, h, w)
                s[conv].parallel(fused)

        scheduled_ops.append(op)

    traverse(output_op)
    return s
304
305
# Define template function for autotvm task
# We define schedule template in this function instead of
# declaration function since actual input arguments need
# to be altered by the schedule selected.
@autotvm.task.register("topi_x86_conv2d_NCHWc")
def _topi_nn_conv2d_NCHWc(*args, **kwargs):
    """AutoTVM template for x86 conv2d_NCHWc.

    ``args`` are the serialized task arguments, either
    ``(data, kernel, strides, padding, dilation, origin_layout, dtype)`` or the
    same with an extra ``out_layout`` before ``dtype``. A supplied
    ``out_layout`` is ignored; it is recomputed from the selected block size.

    The template defines the tuning space on the current config. It then
    re-creates the inputs as blocked placeholders, ``NCHW[ic_bn]c`` data and
    ``(O//oc_bn, I//ic_bn, KH, KW, ic_bn, oc_bn)`` kernel, using the block
    sizes from ``tile_ic`` / ``tile_oc``.

    NOTE(review): the repacking indexes channels at data dim 1 and kernel
    dims 0/1, i.e. it assumes NCHW data and OIHW kernel — confirm for other
    ``origin_layout`` values.

    Returns
    -------
    (tvm.Schedule, list of tvm.Tensor)
        The schedule and ``[new_data, new_kernel, output]``.
    """
    assert not kwargs, "Do not support kwargs in template function call"
    args = deserialize_args(args)

    if len(args) == 7:
        data, kernel, strides, padding, dilation, origin_layout, dtype = args
    else:
        assert len(args) == 8
        data, kernel, strides, padding, dilation, origin_layout, _, dtype = args

    raw_data_shape = get_const_tuple(data.shape)
    raw_kernel_shape = get_const_tuple(kernel.shape)

    # get config here
    cfg = get_config()
    _create_tuning_space(cfg, data, kernel, strides, padding, dilation, origin_layout)

    idxdiv = tvm.indexdiv
    # change shape with the value in config
    ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
    new_data_shape = (raw_data_shape[0], idxdiv(raw_data_shape[1], ic_bn),
                      raw_data_shape[2], raw_data_shape[3], ic_bn)
    data_layout = "NCHW%dc" % ic_bn
    out_layout = "NCHW%dc" % oc_bn
    new_kernel_shape = (idxdiv(raw_kernel_shape[0], oc_bn),
                        idxdiv(raw_kernel_shape[1], ic_bn),
                        raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
    new_data = tvm.placeholder(new_data_shape, data.dtype)
    new_kernel = tvm.placeholder(new_kernel_shape, kernel.dtype)

    C = _declaration_conv_NCHWc(cfg, new_data, new_kernel, strides, padding, dilation,
                                data_layout, out_layout, dtype)
    s = _schedule_conv2d_NCHWc(cfg, [C])
    return s, [new_data, new_kernel, C]
347
348
@conv2d_infer_layout.register("cpu")
def _conv2d_infer_layout(workload, cfg):
    """Infer blocked NCHW[x]c input/output shapes and layouts for a conv2d workload.

    ``workload`` is unpacked as
    ``(name, data, kernel, strides, padding, dilation, layout, dtype)``.
    The last element of ``data`` / ``kernel`` is dropped to obtain the NCHW /
    OIHW shape (presumably that element is the dtype — TODO confirm).
    ``strides``, ``padding`` and ``dilation`` may be scalars or sequences;
    for sequences the first two entries give the (H, W) values.
    ``dilation=None`` is treated as 1.

    Returns
    -------
    tuple
        ``((in_shape, in_layout),), ((out_shape, out_layout),)`` where the block
        sizes are the inner factors of ``cfg["tile_ic"]`` / ``cfg["tile_oc"]``.
    """
    _, data, kernel, strides, padding, dilation, layout, dtype = workload
    batch_size, in_channel, in_height, in_width = data[:-1]
    out_channel, _, k_height, k_width = kernel[:-1]
    idxdiv = tvm.indexdiv

    def _hw(value):
        # (H, W) values: first two entries of a sequence, or a scalar for both.
        if isinstance(value, (tuple, list)):
            return value[0], value[1]
        return value, value

    pad_h, pad_w = _hw(padding)
    stride_h, stride_w = _hw(strides)
    dilation_h, dilation_w = _hw(1 if dilation is None else dilation)
    # Effective kernel extent including dilation gaps; without it the inferred
    # output extent is wrong for dilated convolutions.
    dilated_kernel_h = (k_height - 1) * dilation_h + 1
    dilated_kernel_w = (k_width - 1) * dilation_w + 1

    out_height = idxdiv(in_height + 2 * pad_h - dilated_kernel_h, stride_h) + 1
    out_width = idxdiv(in_width + 2 * pad_w - dilated_kernel_w, stride_w) + 1
    tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
    in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
    in_layout = "NCHW%dc" % tile_ic
    out_shape = (batch_size, idxdiv(out_channel, tile_oc), out_height, out_width, tile_oc)
    out_layout = "NCHW%dc" % tile_oc
    return ((in_shape, in_layout),), ((out_shape, out_layout),)
364
365
@autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct')
def _declaration_conv_NCHWc(cfg, data, kernel, strides,
                            padding, dilation, layout, out_layout, out_dtype):
    """Declare the blocked NCHW[x]c conv2d compute for CPU.

    ``data`` is read as ``(N, C_chunk, H, W, C_block)`` and ``kernel`` as
    ``(O_chunk, I_chunk, KH, KW, I_block, O_block)``. When ``cfg`` is a fallback
    config, it is filled with the default NCHW schedule. That schedule is derived
    from plain NCHW / OIHW placeholders built from the unblocked channel counts.
    The compute itself comes from ``nn.conv2d_NCHWc_compute``.

    ``layout`` and ``out_layout`` are only forwarded; they are kept in the
    signature for debug convenience when dumping the autotvm workload.
    """
    batch, in_chunks, in_h, in_w, in_block = get_const_tuple(data.shape)
    out_chunks, _, k_h, k_w, _, out_block = get_const_tuple(kernel.shape)
    in_channel = in_chunks * in_block
    num_filter = out_chunks * out_block

    if cfg.is_fallback:
        # No tuned config available: derive one from the unblocked NCHW workload.
        nchw_data = tvm.placeholder((batch, in_channel, in_h, in_w), dtype=data.dtype)
        oihw_kernel = tvm.placeholder((num_filter, in_channel, k_h, k_w),
                                      dtype=kernel.dtype)
        _get_default_config(cfg, nchw_data, oihw_kernel, strides, padding, out_dtype)

    return nn.conv2d_NCHWc_compute(data, kernel, strides, padding, dilation,
                                   layout, out_layout, out_dtype)
392
393
@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc, 'cpu', ['direct'])
def _schedule_conv2d_NCHWc(cfg, outs):
    """Create the x86 schedule for conv2d_NCHWc.

    Non-output broadcast stages are inlined. Every stage tagged
    ``conv2d_NCHWc`` is handed, together with its data input, to the AVX
    ``_schedule_conv_NCHWc``. The 1x1 variant is used when the kernel is
    spatially 1x1, and the common variant otherwise.

    Parameters
    ----------
    cfg : autotvm config
        Schedule configuration forwarded to the AVX schedule.
    outs : list of tvm.Tensor
        Output tensors; traversal starts from ``outs[0]``.

    Returns
    -------
    s : tvm.Schedule
        The created schedule.
    """
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []

    def traverse(op):
        """Traverse operators from computation graph"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
                    traverse(tensor.op)

        if 'conv2d_NCHWc' in op.tag:
            conv_out = op.output(0)
            kernel = conv_out.op.input_tensors[1]
            # The AVX schedules take the conv's data input directly (padded or
            # not); its producers do not need to be located here.
            data_vec = conv_out.op.input_tensors[0]

            args = [s, cfg, data_vec, conv_out, outs[0]]
            # Keep requiring an active target (raises when none is set).
            tvm.target.current_target(allow_none=False)
            _, _, kh, kw, _, _, = get_const_tuple(kernel.shape)
            if kh == 1 and kw == 1:
                conv2d_avx_1x1._schedule_conv_NCHWc(*args)
            else:
                conv2d_avx_common._schedule_conv_NCHWc(*args)

        scheduled_ops.append(op)

    traverse(outs[0].op)
    return s
433