1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name,unused-variable
18"""Depthwise convolution schedule for ARM CPU"""
19
20import tvm
21from tvm import autotvm
22
23from ..generic import schedule_depthwise_conv2d_nchw
24from ..nn import depthwise_conv2d_nchw, pad
25from ..util import traverse_inline, get_const_tuple, get_const_int
26from ..nn.util import get_pad_tuple
27
# Register the default compute (fdefault) of depthwise_conv2d_nchw as the
# 'direct' template for arm_cpu: only the schedule below is customized,
# the compute declaration itself is reused unchanged.
autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct',
                              depthwise_conv2d_nchw.fdefault)
31
# register customized schedule for arm cpu; handles both the 'direct'
# and 'contrib_spatial_pack' templates.
@autotvm.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu',
                                ['direct', 'contrib_spatial_pack'])
def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
    """Schedule depthwise conv2d for ARM CPU.

    Dispatches on the op tag: the 'direct' template is scheduled by the
    inner ``_schedule`` helper, while the 'contrib_spatial_pack' template
    is handed off to ``_schedule_spatial_pack``.

    Parameters
    ----------
    cfg: ConfigEntity
        The configuration of this template
    outs: Array of Tensor
        The computation graph description of depthwise convolution2d
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for depthwise_conv2d nchw.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])

    def _schedule(cfg, s, data, data_pad, kernel, output):
        # Schedule for the 'direct' template: split the output
        # channel/height/width axes into outer and inner (vector) tiles,
        # pack data and kernel accordingly, and unroll/vectorize the
        # innermost tiles per the tuned annotation.
        A, B, C = data, kernel, output
        s[data_pad].compute_inline()

        ##### space definition begin #####
        n, c, h, w = s[output].op.axis
        _, vc = cfg.define_split('tile_c', c, num_outputs=2)
        _, vh = cfg.define_split('tile_h', h, num_outputs=2)
        _, vw = cfg.define_split('tile_w', w, num_outputs=2)
        cfg.define_annotate('ann', [vh, vw, vc], policy='try_unroll_vec')

        # fallback support: when no tuned config exists, borrow a
        # reference config tuned for rk3399
        if cfg.is_fallback:
            ref_log = autotvm.tophub.load_reference_log(
                'arm_cpu', 'rk3399', 'depthwise_conv2d_nchw', 'direct')
            cfg.fallback_with_reference_log(ref_log)
        ##### space definition end #####

        # pack data to vector form  [n, c, h, w] -> [n, C, h, w, VC]
        A0 = s.cache_read(data_pad, "global", C)
        n, c, h, w = s[A0].op.axis
        c, vc = cfg['tile_c'].apply(s, A0, c)
        s[A0].reorder(n, c, h, w, vc)
        A1 = s.cache_write(A0, 'global')
        s[A0].compute_inline()

        # pack kernel to vector form  [co, ci, kh, kw] -> [CO, ci, kh, kw, VC]
        B0 = s.cache_read(B, "global", C)
        c, m, h, w = s[B0].op.axis
        c, vc, = cfg['tile_c'].apply(s, B0, c)
        s[B0].reorder(c, m, h, w, vc)
        B1 = s.cache_write(B0, 'global')
        s[B0].compute_inline()

        n, c, h, w = s[C].op.axis
        c, vc, = cfg['tile_c'].apply(s, C, c)
        s[C].reorder(n, c, h, w, vc)

        # depthwise conv
        C0 = s.cache_write(C, 'global')
        _, c, h, w, vc = s[C0].op.axis
        dh, dw = s[C0].op.reduce_axis
        oh, ih = cfg['tile_h'].apply(s, C0, h)
        ow, iw = cfg['tile_w'].apply(s, C0, w)
        s[C0].reorder(c, oh, ow, dh, dw, ih, iw, vc)
        # materialize the packed data at the outer height tile of the conv
        s[A1].compute_at(s[C0], oh)

        # try unroll and vectorization on the inner tiles
        cfg['ann'].apply(s, C0, [ih, iw, vc],
                         axis_lens=[cfg['tile_h'].size[-1],
                                    cfg['tile_w'].size[-1],
                                    cfg['tile_c'].size[-1]],
                         max_unroll=16,
                         cfg=cfg)

        # fusion: if the conv output is consumed by a fused op, inline it
        if C.op not in s.outputs:
            s[C].compute_inline()

        # mark parallel over the channel axis of each remaining stage
        last = outs[0]
        n, c, h, w = s[last].op.axis
        s[last].parallel(c)

        n, c, h, w, vc = s[C0].op.axis
        s[C0].parallel(c)

        c, m, h, w, vc = s[B1].op.axis
        s[B1].parallel(c)

        return s

    def _callback(op):
        # Dispatch on the tag attached by the compute declaration.
        if op.tag == 'depthwise_conv2d_nchw':
            output = op.output(0)
            kernel = op.input_tensors[1]
            data = op.input_tensors[0]
            data_pad = None
            # Peel off the padding stage (if any) so _schedule can inline it.
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]
            _schedule(cfg, s, data, data_pad, kernel, output)

        if op.tag == 'spatial_depthwise_conv2d_nchw_output':
            output = op.output(0)
            conv = op.input_tensors[0]
            data_vec = conv.op.input_tensors[0]
            kernel_vec = conv.op.input_tensors[1]
            # kernel_vec is either a packing stage over the raw kernel or
            # the (pre-packed) kernel itself
            if kernel_vec.op.name == 'kernel_vec':
                kernel = kernel_vec.op.input_tensors[0]
            else:
                kernel = kernel_vec
            # inline a dilation stage of the kernel, if one exists
            if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
                s[kernel].compute_inline()

            _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, outs[0])

    traverse_inline(s, outs[0].op, _callback)
    return s
153
@autotvm.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', ['contrib_spatial_pack'])
def depthwise_conv2d_arm_cpu(cfg, data, kernel, strides, padding, dilation, out_dtype):
    """TOPI compute callback for depthwise_conv2d nchw

    Thin wrapper that delegates to the spatial-pack declaration with the
    arm-cpu tiling variant.

    Parameters
    ----------
    cfg: ConfigEntity
        The config for this template

    data : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    kernel : tvm.Tensor
        4-D with shape [num_filter, multiplier, filter_height, filter_width] or
        pre-packed 5-D with shape [num_filter_chunk, multiplier, filter_height,
        filter_width, num_filter_block]

    strides : list of two ints
        [stride_height, stride_width]

    padding : list of two ints
        [pad_height, pad_width]

    dilation : list of two ints
        [dilation_height, dilation_width]

    out_dtype: str
        The output type. This is used for mixed precision.

    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    # num_tile=2 selects the two-level (outer/vector) tiling used by the
    # arm cpu schedule.
    packed_output = _decl_spatial_pack(cfg, data, kernel, strides, padding,
                                       dilation, out_dtype, num_tile=2)
    return packed_output
190
191
def _decl_spatial_pack(cfg, data, kernel, strides, padding, dilation, out_dtype, num_tile):
    """Declare the spatial-pack depthwise conv2d (NCHW) compute.

    Packs the (padded) input and the kernel into tiled, vectorizable
    layouts, performs the depthwise convolution on the packed tensors,
    and unpacks the result back to NCHW.  ``num_tile`` selects the tiling
    variant; only 2 (the arm cpu variant) is currently supported.
    """
    out_dtype = out_dtype or data.dtype

    N, C, IH, IW = get_const_tuple(data.shape)

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    # The kernel may arrive pre-packed (5-D) from an earlier alter-layout
    # pass; recover the logical channel count in that case.
    if len(kernel.shape) == 4:
        pre_packed = False
        C, M, KH, KW = get_const_tuple(kernel.shape)
    else:  # kernel tensor is pre packed
        pre_packed = True
        C, M, KH, KW, VC = get_const_tuple(kernel.shape)
        C = C * VC

    # effective receptive field of the dilated kernel
    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1

    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
    # pack data: pad only when some padding is actually required
    HPAD = pad_top + pad_down
    WPAD = pad_left + pad_right
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right),
                       name="data_pad")
    else:
        data_pad = data

    # fallback support: when no tuned config exists, borrow a reference
    # config tuned for rk3399
    if cfg.is_fallback:
        ref_log = autotvm.tophub.load_reference_log('arm_cpu', 'rk3399', 'depthwise_conv2d_nchw',
                                                    'contrib_spatial_pack')
        cfg.fallback_with_reference_log(ref_log)

    # ==================== define configuration space ====================
    n, c, oh, ow = cfg.axis(N), cfg.axis(C), cfg.axis(OH), cfg.axis(OW)
    kh, kw = cfg.reduce_axis(KH), cfg.reduce_axis(KW)

    # Currently, Mali schedule doesn't use it like conv2d.
    # Leave num_tile for possible future use of Mali schedule
    if num_tile == 2:     # for arm cpu
        co, vc = cfg.define_split('tile_co', c, num_outputs=2)
        oh, vh = cfg.define_split('tile_oh', oh, num_outputs=2)
        ow, vw = cfg.define_split('tile_ow', ow, num_outputs=2)
    else:
        raise RuntimeError("Invalid num_tile")

    cfg.define_reorder("reorder_0",
                       [n, co, oh, ow, kh, kw, vh, vw, vc],
                       policy='candidate', candidate=[
                           [n, co, oh, ow, kh, kw, vh, vw, vc],
                           [n, co, oh, ow, kh, kw, vc, vh, vw]])

    cfg.define_reorder("reorder_1",
                       [n, co, oh, ow, vh, vw, vc],
                       policy='candidate', candidate=[
                           [n, co, oh, ow, vh, vw, vc],
                           [n, co, oh, ow, vc, vh, vw],
                           [n, co, oh, ow, vh, vc, vw]])

    cfg.define_annotate("ann_reduce", [kh, kw], policy='try_unroll')
    cfg.define_annotate("ann_spatial", [vh, vw, vc], policy='try_unroll_vec')
    # ====================================================================

    # inner tile sizes chosen by the config (candidate splits guarantee
    # they divide the corresponding extents evenly)
    VC = cfg["tile_co"].size[-1]
    VH = cfg["tile_oh"].size[-1]
    VW = cfg["tile_ow"].size[-1]

    kvshape = (C // VC, M, KH, KW, VC)
    ovshape = (N, C * M // VC, OH // VH, OW // VW, VH, VW, VC)
    oshape = (N, C * M, OH, OW)

    if dilation_h != 1 or dilation_w != 1:
        # undilate input data: gather the dilated taps explicitly, so the
        # packed data carries explicit (kh, kw) axes
        dvshape = (N, OH // VH, OW // VW, C, KH, KW, VH, VW)
        data_vec = tvm.compute(dvshape, lambda n, h, w, c, kh, kw, vh, vw:
                               data_pad[n][c][(h * VH + vh) * HSTR + kh * dilation_h]
                               [(w*VW+vw)*WSTR+kw*dilation_w],
                               name='data_vec_undilated')
    else:
        # pack overlapping input tiles of size (VH*HSTR+KH-1, VW*WSTR+KW-1)
        dvshape = (N, OH // VH, OW // VW, C, VH*HSTR + KH-1, VW*WSTR + KW-1)
        data_vec = tvm.compute(dvshape, lambda n, h, w, c, vh, vw:
                               data_pad[n][c][h * VH * HSTR + vh][w * VW * WSTR + vw],
                               name='data_vec')

    if pre_packed:
        kernel_vec = kernel
    else:
        kernel_vec = tvm.compute(kvshape, lambda co, m, kh, kw, vc:
                                 kernel[co*VC+vc][m][kh][kw],
                                 name='kernel_vec')

    kh = tvm.reduce_axis((0, KH), name='kh')
    kw = tvm.reduce_axis((0, KW), name='kw')

    idxdiv = tvm.indexdiv
    idxmod = tvm.indexmod

    # depthwise conv on the packed layouts; output channel co maps to
    # input channel co // M and multiplier co % M
    if dilation_h != 1 or dilation_w != 1:
        conv = tvm.compute(
            ovshape, lambda n, co, h, w, vh, vw, vc: \
            tvm.sum(data_vec[n, h, w, idxdiv(co * VC + vc, M), kh, kw, vh, vw]
                    .astype(out_dtype) *
                    kernel_vec[idxdiv(co, M), idxmod(co, M), kh, kw, vc].astype(out_dtype),
                    axis=[kh, kw]), name='depthwise_conv')
    else:
        conv = tvm.compute(ovshape, lambda n, co, h, w, vh, vw, vc: \
                           tvm.sum(data_vec[n, h, w, idxdiv((co * VC + vc), M), vh * HSTR + kh,
                                            vw * WSTR + kw].astype(out_dtype) *
                                   kernel_vec[idxdiv(co, M),
                                              idxmod(co, M),
                                              kh, kw, vc].astype(out_dtype),
                                   axis=[kh, kw]), name='depthwise_conv')

    # unpack [N, CO, OH//VH, OW//VW, VH, VW, VC] back to NCHW
    output = tvm.compute(oshape, lambda n, co, h, w:
                         conv[n,
                              idxdiv(co, VC), idxdiv(h, VH), idxdiv(w, VW),
                              idxmod(h, VH), idxmod(w, VW), idxmod(co, VC)],
                         name='output_unpack', tag='spatial_depthwise_conv2d_nchw_output')
    return output
321
def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
                           conv, output, last):
    """Schedule the 'contrib_spatial_pack' template.

    Applies the tuned reorder/unroll/vectorize annotations to the packed
    data, packed kernel, conv and unpack stages, and uses tuning knobs to
    decide where padding, data packing and the conv itself are computed.
    """
    n, co, oh, ow, vh, vw, vc = s[conv].op.axis
    kh, kw = s[conv].op.reduce_axis

    # the undilated packing stage carries two extra (kh, kw) axes
    if data_vec.op.name == 'data_vec_undilated':
        _, dv_oh, dv_ow, dv_c, _, _, dv_vh, dv_vw = s[data_vec].op.axis
    else:
        _, dv_oh, dv_ow, dv_c, dv_vh, dv_vw = s[data_vec].op.axis

    # data_vec reads either a pad stage or the raw placeholder directly
    data_pad = data_vec.op.input_tensors[0]
    if data_pad.op.name == "data_pad":
        assert isinstance(data_pad.op, tvm.tensor.ComputeOp)
        has_padding = True
    else:
        assert isinstance(data_pad.op, tvm.tensor.PlaceholderOp)
        has_padding = False

    # knob: how/where the padded data is materialized
    # 0 = leave as-is, 1 = inline, 2 = vectorize,
    # 3/4 = vectorize and compute at data_vec's outer h / outer w tile
    cfg.define_knob('data_pad_inline', [0, 1, 2, 3, 4])

    if cfg['data_pad_inline'].val == 1 and has_padding:
        s[data_pad].compute_inline()
    if cfg['data_pad_inline'].val == 2 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
    if cfg['data_pad_inline'].val == 3 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
        s[data_pad].compute_at(s[data_vec], dv_oh)
    if cfg['data_pad_inline'].val == 4 and has_padding:
        s[data_pad].vectorize(list(s[data_pad].op.axis)[-1])
        s[data_pad].compute_at(s[data_vec], dv_ow)

    # knob: where the packed data is computed relative to the conv
    # (0 = standalone, 1/2/3 = at conv's oh / ow / co axis)
    cfg.define_knob('data_vec_inline', [0, 1, 2, 3])
    if cfg['data_vec_inline'].val == 1:
        s[data_vec].compute_at(s[conv], oh)
    if cfg['data_vec_inline'].val == 2:
        s[data_vec].compute_at(s[conv], ow)
    if cfg['data_vec_inline'].val == 3:
        s[data_vec].compute_at(s[conv], co)

    # schedule conv: tuned loop order, reduction unroll, spatial unroll/vec
    cfg["reorder_0"].apply(s, conv, [n, co, oh, ow, kh, kw, vh, vw, vc])
    cfg["ann_reduce"].apply(s, conv, [kh, kw],
                            axis_lens=[get_const_int(kh.dom.extent),
                                       get_const_int(kw.dom.extent)],
                            max_unroll=16,
                            cfg=cfg)
    cfg["ann_spatial"].apply(s, conv, [vh, vw, vc],
                             axis_lens=[cfg['tile_oh'].size[-1],
                                        cfg['tile_ow'].size[-1],
                                        cfg['tile_co'].size[-1]],
                             max_unroll=16,
                             cfg=cfg)

    # schedule fusion: tile the last (possibly fused) stage the same way
    n, co, h, w = s[last].op.axis
    co, vc = cfg['tile_co'].apply(s, last, co)
    oh, vh = cfg['tile_oh'].apply(s, last, h)
    ow, vw = cfg['tile_ow'].apply(s, last, w)
    cfg["reorder_1"].apply(s, last, [n, co, oh, ow, vh, vw, vc])
    if last != output:
        # the unpack stage is fused into a later op: inline it and apply
        # the spatial annotation to the fused stage instead
        s[output].compute_inline()
        cfg["ann_spatial"].apply(s, last, [vh, vw, vc],
                                 axis_lens=[cfg['tile_oh'].size[-1],
                                            cfg['tile_ow'].size[-1],
                                            cfg['tile_co'].size[-1]],
                                 max_unroll=16,
                                 cfg=cfg)
    else:
        s[last].vectorize(vw)
    # knob: where the conv is computed relative to the last stage
    # (0 = standalone, 1/2/3 = at last's ow / oh / co axis)
    cfg.define_knob('conv_inline', [0, 1, 2, 3])
    if cfg['conv_inline'].val == 1:
        s[conv].compute_at(s[last], ow)
    if cfg['conv_inline'].val == 2:
        s[conv].compute_at(s[last], oh)
    if cfg['conv_inline'].val == 3:
        s[conv].compute_at(s[last], co)

    # mark parallel
    s[last].parallel(co)

    if data_vec.op.name == 'data_vec_undilated':
        _, h, _, _, _, _, _, _ = s[data_vec].op.axis
    else:
        _, h, _, _, _, _ = s[data_vec].op.axis
    s[data_vec].parallel(h)

    if kernel_vec.op.name == 'kernel_vec':
        co, _, _, _, _ = s[kernel_vec].op.axis
        if autotvm.GLOBAL_SCOPE.in_tuning:
            # kernel packing will be pre-computed during compilation, so we skip
            # this part to make tuning records correct
            s[kernel_vec].pragma(co, 'debug_skip_region')
        else:
            s[kernel_vec].parallel(co)

    return s
419