1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name
"""1x1 Conv2D schedule for Intel CPU"""
19from __future__ import absolute_import as _abs
20import tvm
21from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
22
23from ..nn.pad import pad
24from ..nn.util import infer_pad, get_pad_tuple
25from ..generic import conv2d as conv2d_generic
26from ..util import get_const_tuple, simplify
27from .tensor_intrin import dot_16x1x16_uint8_int8_int32
28from .util import get_fp32_len
29
def _fallback_schedule(cfg, wkl):
    """Populate ``cfg`` with a default schedule derived from the workload.

    Parameters
    ----------
    cfg : dict-like
        Configuration object. The entries "tile_ic", "tile_oc", "tile_oh"
        and "tile_ow" are assigned in place.
    wkl : object
        Workload description. The attributes ``height``, ``width``,
        ``in_filter``, ``out_filter``, ``hkernel``, ``wkernel``, ``hpad``,
        ``wpad``, ``hstride`` and ``wstride`` are read.

    Raises
    ------
    ValueError
        If no pair of output-height / output-width divisors has a product
        below 32.
    """
    vec_len = get_fp32_len()
    out_height = (wkl.height + 2 * wkl.hpad - wkl.hkernel) // wkl.hstride + 1
    out_width = (wkl.width + 2 * wkl.wpad - wkl.wkernel) // wkl.wstride + 1

    # Largest divisor of out_filter that does not exceed vec_len (1 if none).
    oc_bn = next((bn for bn in range(vec_len, 0, -1) if wkl.out_filter % bn == 0), 1)
    # Largest divisor of in_filter that does not exceed oc_bn (1 if none).
    ic_bn = next((bn for bn in range(oc_bn, 0, -1) if wkl.in_filter % bn == 0), 1)

    # Walk width divisors from largest to smallest; for each one, take the
    # largest height divisor that keeps the tile (oh * ow) under 32 elements.
    for ow_factor in range(out_width, 0, -1):
        if out_width % ow_factor != 0:
            continue
        for oh_factor in range(out_height, 0, -1):
            if out_height % oh_factor != 0 or ow_factor * oh_factor >= 32:
                continue
            cfg["tile_ic"] = SplitEntity([wkl.in_filter // ic_bn, ic_bn])
            cfg["tile_oc"] = SplitEntity([wkl.out_filter // oc_bn, oc_bn])
            cfg["tile_oh"] = OtherOptionEntity(oh_factor)
            cfg["tile_ow"] = SplitEntity([out_width // ow_factor, ow_factor])
            return
    raise ValueError("cannot decide default schedule for workload: {}".format(wkl))
59
60
def _schedule_conv(s, cfg, data, data_pad, data_vec, kernel_vec, conv_out, output, last):
    """Schedule the stages of a 1x1 convolution and return the schedule.

    Parameters
    ----------
    s : schedule
        Schedule object; it is indexed by tensor to obtain each stage and
        is mutated in place.
    cfg : dict-like
        Configuration providing "tile_ic", "tile_oc", "tile_oh", "tile_ow".
    data, data_pad : tensor
        Used to infer the padding; ``data_pad`` is inlined when the inferred
        padding is non-zero.
    data_vec : tensor
        Stage with five axes; parallelised over its first three, fused.
    kernel_vec : tensor
        Stage with six axes holding the packed kernel.
    conv_out : tensor
        Convolution stage with five axes and three reduce axes.
    output, last : tensor
        ``output`` is inlined when it differs from ``last``; ``last`` (four
        axes) is the stage that is finally split, vectorised and parallelised.

    Returns
    -------
    s : schedule
        The same schedule object that was passed in.
    """
    # Tiling factors selected by the configuration.
    ic_bn = cfg["tile_ic"].size[-1]
    oc_bn = cfg["tile_oc"].size[-1]
    oh_factor = cfg["tile_oh"].val
    ow_factor = cfg["tile_ow"].size[-1]

    # Stride and padding are not passed in, so padding is inferred from tensors.
    pad_h, pad_w = infer_pad(data, data_pad)
    has_padding = (pad_h != 0 or pad_w != 0)

    # --- data stages ---
    if has_padding:
        s[data_pad].compute_inline()
    d_batch, d_ic_chunk, d_ih, _, _ = s[data_vec].op.axis
    data_par = s[data_vec].fuse(d_batch, d_ic_chunk, d_ih)
    s[data_vec].parallel(data_par)

    # --- kernel packing stage ---
    k_oc_chunk, k_ic_chunk, k_oh, k_ow, k_ic_block, k_oc_block = s[kernel_vec].op.axis
    s[kernel_vec].reorder(k_oc_chunk, k_oh, k_ic_chunk, k_ow, k_ic_block, k_oc_block)
    if oc_bn > 1:
        s[kernel_vec].vectorize(k_oc_block)
    kernel_par = s[kernel_vec].fuse(k_oc_chunk, k_oh)
    s[kernel_vec].parallel(kernel_par)

    # --- convolution stage and its write cache ---
    conv_cache = s.cache_write(conv_out, 'global')

    _, _, c_oh, _, c_oc_block = s[conv_out].op.axis
    c_oh_outer, _ = s[conv_out].split(c_oh, factor=oh_factor)
    s[conv_out].vectorize(c_oc_block)

    s[conv_cache].compute_at(s[conv_out], c_oh_outer)
    _, cc_oc_chunk, cc_oh, cc_ow, cc_oc_block = s[conv_cache].op.axis
    cc_ic, _, _ = s[conv_cache].op.reduce_axis

    cc_ic_chunk, cc_ic_block = s[conv_cache].split(cc_ic, factor=ic_bn)

    cc_oh_outer, cc_oh_inner = s[conv_cache].split(cc_oh, factor=oh_factor)
    cc_ow_outer, cc_ow_inner = s[conv_cache].split(cc_ow, factor=ow_factor)

    s[conv_cache].reorder(cc_oc_chunk, cc_oh_outer, cc_ow_outer, cc_ic_chunk, cc_ic_block,
                          cc_oh_inner, cc_ow_inner, cc_oc_block)
    s[conv_cache].vectorize(cc_oc_block)

    s[conv_cache].unroll(cc_ow_inner)
    s[conv_cache].unroll(cc_oh_inner)

    # --- final output stage ---
    if output != last:
        s[output].compute_inline()
    o_batch, o_oc, o_oh, o_ow = s[last].op.axis

    o_oc_chunk, o_oc_block = s[last].split(o_oc, factor=oc_bn)
    o_oh_outer, o_oh_inner = s[last].split(o_oh, factor=oh_factor)
    o_ow_outer, o_ow_inner = s[last].split(o_ow, factor=ow_factor)
    s[last].reorder(o_oc_chunk, o_oh_outer, o_ow_outer, o_oh_inner, o_ow_inner, o_oc_block)

    out_par = s[last].fuse(o_batch, o_oc_chunk, o_oh_outer)
    s[conv_out].compute_at(s[last], out_par)
    s[last].vectorize(o_oc_block)

    s[last].parallel(out_par)

    return s
126
127
def _schedule_conv_NCHWc(s, cfg, data, conv_out, last):
    """Schedule a 1x1 convolution whose tensors have five axes and return ``s``.

    Parameters
    ----------
    s : schedule
        Schedule object; indexed by tensor to obtain stages, mutated in place.
    cfg : dict-like
        Configuration providing "tile_oh" and "tile_ow".
    data : tensor
        Five-dimensional input; its last dimension is used as the reduction
        split factor. Scheduled only when its op is a ``tvm.tensor.ComputeOp``.
    conv_out : tensor
        Convolution stage with five axes and three reduce axes.
    last : tensor
        Final stage. When it differs from ``conv_out`` it is tiled the same
        way and ``conv_out`` is computed at its fused parallel axis.

    Returns
    -------
    s : schedule
        The same schedule object that was passed in.
    """
    # Tiling factors from the config; reduction block size from the data shape.
    oh_factor = cfg["tile_oh"].val
    ow_factor = cfg["tile_ow"].size[-1]
    _, _, _, _, ic_bn = get_const_tuple(data.shape)

    # --- data stage (only when it is computed rather than a placeholder) ---
    if isinstance(s[data].op, tvm.tensor.ComputeOp):
        d_batch, d_ic_chunk, d_ih, _, _ = s[data].op.axis
        data_par = s[data].fuse(d_batch, d_ic_chunk, d_ih)
        s[data].parallel(data_par)

    # --- convolution stage and its write cache ---
    conv_cache = s.cache_write(conv_out, 'global')

    c_batch, c_oc_chunk, c_oh, c_ow, c_oc_block = s[conv_out].op.axis
    c_oh_outer, c_oh_inner = s[conv_out].split(c_oh, factor=oh_factor)
    c_ow_outer, c_ow_inner = s[conv_out].split(c_ow, factor=ow_factor)
    s[conv_out].reorder(c_oc_chunk, c_oh_outer, c_ow_outer, c_oh_inner, c_ow_inner, c_oc_block)
    s[conv_out].vectorize(c_oc_block)

    conv_par = s[conv_out].fuse(c_batch, c_oc_chunk, c_oh_outer)
    s[conv_cache].compute_at(s[conv_out], conv_par)
    if conv_out == last:
        # conv_out is itself the final stage, so it carries the parallel loop.
        s[conv_out].parallel(conv_par)

    _, cc_oc_chunk, cc_oh, cc_ow, cc_oc_block = s[conv_cache].op.axis
    cc_ic, _, _ = s[conv_cache].op.reduce_axis

    cc_ic_chunk, cc_ic_block = s[conv_cache].split(cc_ic, factor=ic_bn)

    cc_oh_outer, cc_oh_inner = s[conv_cache].split(cc_oh, factor=oh_factor)
    cc_ow_outer, cc_ow_inner = s[conv_cache].split(cc_ow, factor=ow_factor)

    s[conv_cache].reorder(cc_oc_chunk, cc_oh_outer, cc_ow_outer, cc_ic_chunk, cc_ic_block,
                          cc_oh_inner, cc_ow_inner, cc_oc_block)
    s[conv_cache].fuse(cc_oc_chunk, cc_oh_outer)
    s[conv_cache].vectorize(cc_oc_block)

    s[conv_cache].unroll(cc_ow_inner)
    s[conv_cache].unroll(cc_oh_inner)

    # --- separate final stage, tiled the same way as conv_out ---
    if conv_out != last:
        o_batch, o_oc_chunk, o_oh, o_ow, o_oc_block = s[last].op.axis
        o_oh_outer, o_oh_inner = s[last].split(o_oh, factor=oh_factor)
        o_ow_outer, o_ow_inner = s[last].split(o_ow, factor=ow_factor)
        s[last].reorder(o_oc_chunk, o_oh_outer, o_ow_outer, o_oh_inner, o_ow_inner, o_oc_block)

        out_par = s[last].fuse(o_batch, o_oc_chunk, o_oh_outer)
        s[conv_out].compute_at(s[last], out_par)
        s[last].vectorize(o_oc_block)
        s[last].parallel(out_par)

    return s
181
182
def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last):
    """Delegate to the generic 1x1 int8 schedule with 16 int32 lanes.

    All arguments are forwarded unchanged to
    ``conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8``, together with
    ``int32_lanes=16`` and the ``dot_16x1x16_uint8_int8_int32`` intrinsic.
    Returns whatever that function returns.
    """
    dot_intrin = dot_16x1x16_uint8_int8_int32()
    return conv2d_generic.schedule_conv_NCHWc_cpu_1x1_int8(
        s, cfg, data, conv_out, last, int32_lanes=16, intrin=dot_intrin)
187
188
def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype):
    """Declare a convolution whose filter is repacked before the reduction.

    Parameters
    ----------
    cfg : object
        Not used by this function.
    Input : tensor
        Its shape is unpacked as (batch, in_height, in_width, in_channel).
    Filter : tensor
        Its shape is unpacked as (kernel_h, kernel_w, num_filter, channel).
    stride : int or sequence of two
        Stride; an int is applied to both height and width.
    padding : object
        Forwarded to ``get_pad_tuple`` together with the dilated kernel size.
    dilation : int or sequence of two
        Dilation; an int is applied to both height and width.
    out_dtype : str
        Both operands are cast to this type before being multiplied.

    Returns
    -------
    Output : tensor
        Tensor of shape (batch, out_height, out_width, num_filter), named
        "Conv2d_1x1_Output_int8" and tagged "conv2d_nhwc_pack_int8".
    """
    # Shape sanity checks on the scalar-or-pair arguments.
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    stride_h, stride_w = (stride, stride) if isinstance(stride, int) else stride
    dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation

    batch, in_height, in_width, in_channel = Input.shape
    kernel_h, kernel_w, num_filter, channel = Filter.shape

    # Output spatial extent from the dilated kernel size and the padding.
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w))
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
    PaddedInput = pad(Input,
                      [0, pad_top, pad_left, 0],
                      [0, pad_down, pad_right, 0],
                      name="PaddedInput")
    # todo: padding filter to accommodate the intrinsic

    # Pack the Filter so that memory access is consecutive for the AVX512
    # intrinsic. Done in pre-compute stage.
    div = tvm.indexdiv
    mod = tvm.indexmod

    # NOTE(review): parameter names of the compute functions are kept exactly
    # as before; tvm.compute presumably derives axis names from them - confirm.
    def _pack_filter(a, b, c, d, e):
        return Filter[a, b,
                      c*16 + mod(d, 16),
                      div(d, 16) * 4 + e]

    PackW = tvm.compute(
        (kernel_h, kernel_w, div(num_filter, 16), 16 * div(channel, 4), 4),
        _pack_filter,
        name="packed_filter")

    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')

    def _conv(nn, yy, xx, ff):
        return tvm.sum(
            PaddedInput[nn, yy * stride_h + ry * dilation_h,
                        xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
            PackW[ry, rx, div(ff, 16),
                  div(rc, 4) * 16 + mod(ff, 16),
                  mod(rc, 4)].astype(out_dtype), axis=[ry, rx, rc])

    return tvm.compute(
        (batch, out_height, out_width, num_filter),
        _conv,
        name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8")
245
246
247def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last):
248    """
249    Defines the schedule for the int8 nhwc layout. For 1x1 conv, it
250    is a matrix-multiply operation by using nhwc layout. We will do
251    packing of weight to make the address access be friendly to int8
252    intrinsic
253    """
254    # FIXME - https://github.com/apache/incubator-tvm/issues/3598
255    # pylint: disable=unreachable
256    return s
257
258    int32_lanes = 16
259
260    # assertion to fail the unhandled case
261    _, _, _, ic_num = get_const_tuple(data.shape)
262    _, _, _, oc_num = get_const_tuple(conv_out.shape)
263    assert ic_num % 4 == 0
264    assert oc_num % 16 == 0
265
266    ic_factor, oc_factor = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
267    # schedule data
268    A = data
269    if isinstance(s[A].op, tvm.tensor.ComputeOp):
270        batch, ih, iw, ic = s[A].op.axis
271        d_ic_chunk, d_ic_block = s[A].split(ic, factor=4)
272        s[A].vectorize(d_ic_block)
273
274    C, O = conv_out, last
275
276    batch, oh, ow, oc = s[C].op.axis
277    kh, kw, ic = s[C].op.reduce_axis
278    # match the x86 intrinsic
279    ic_outer, ic_inner = s[C].split(ic, factor=4)
280    oc_outer, oc_inner = s[C].split(oc, factor=int32_lanes)
281
282    ic_f_outer, ic_s_outer = s[C].split(ic_outer, factor=ic_factor)
283    s[C].reorder(oc_outer, oh, ow, ic_f_outer, ic_s_outer, kh, kw, oc_inner, ic_inner)
284
285    pc = dot_16x1x16_uint8_int8_int32()
286    s[C].tensorize(oc_inner, pc)
287
288    if C != O:
289        batch, last_oh, last_ow, last_oc = s[O].op.axis
290        oc_chunk, oc_block = s[O].split(ochannel, 16)
291        # not saw perf improvement to split oh/ow here
292        s[O].vectorize(oc_block)
293
294    return s
295