1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return, too-many-arguments, too-many-locals, too-many-statements, no-member, too-many-branches, too-many-boolean-expressions
18"""conv2d schedule on Intel Graphics"""
19
20from __future__ import absolute_import as _abs
21
22import tvm
23
24from tvm import autotvm
25from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
26from tvm.autotvm.task.topi_integration import deserialize_args
27from tvm.autotvm.task import get_config
28from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, conv2d_infer_layout
29from ..nn.util import get_pad_tuple
30from ..nn.depthwise_conv2d import depthwise_conv2d_nchw
31from ..nn import pad
32from .. import tag
33from .. import generic
34from .. import util
35from ..util import simplify, get_const_tuple
36
37
def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False):
    """Fill ``cfg`` with the hand-picked fallback schedule for Intel Graphics.

    Sets the channel split entities ``tile_ic`` / ``tile_oc`` and the output
    block knobs ``block_oh`` / ``block_ow`` in place.

    Parameters
    ----------
    cfg : ConfigEntity
        Fallback config to populate.
    data : tvm.Tensor
        4-D input, unpacked as (batch, in_channel, height, width).
    kernel : tvm.Tensor
        4-D weight, unpacked as (out_channel, in_channel, kernel_h, kernel_w).
    strides : int or a list/tuple of ints
        Convolution stride; only the height stride affects the choice.
    padding : unused; kept for signature compatibility.
    out_dtype : unused; kept for signature compatibility.
    is_depthwise : bool
        Depthwise convolution is not supported.

    Raises
    ------
    RuntimeError
        If ``is_depthwise`` is true.
    """
    if is_depthwise:
        raise RuntimeError("Depthwise not supported for intel graphics.")

    _, in_channel, _, _ = get_const_tuple(data.shape)
    out_channel, _, hkernel, _ = get_const_tuple(kernel.shape)
    # Accept a scalar stride as well as a (stride_h, stride_w) pair, like
    # _create_schedule_template does.
    HSTR = strides[0] if isinstance(strides, (tuple, list)) else strides

    ic_bn = 1
    # Largest divisor of out_channel that does not exceed 16.
    oc_bn_upper = 16
    oc_bn = next((i for i in range(oc_bn_upper, 0, -1) if out_channel % i == 0),
                 oc_bn_upper)

    # Hand-tuned output block sizes keyed on stride, kernel height and channels.
    if HSTR == 2:
        block_oh, block_ow = (4, 4) if out_channel + hkernel == 515 else (4, 5)
    elif hkernel == 3:
        block_oh, block_ow = (2, 7) if out_channel == 512 else (2, 14)
    else:
        block_oh, block_ow = 1, 16

    cfg["tile_ic"] = SplitEntity([in_channel // ic_bn, ic_bn])
    cfg["tile_oc"] = SplitEntity([out_channel // oc_bn, oc_bn])
    cfg["block_oh"] = OtherOptionEntity(block_oh)
    cfg["block_ow"] = OtherOptionEntity(block_ow)
74
75
def _create_schedule_template(cfg, data, kernel, strides, padding, dilation, layout):
    """Create schedule configuration from input arguments.

    Defines four knobs on ``cfg``:

    * ``tile_ic`` / ``tile_oc``: channel block sizes that divide the channel
      counts (bounded to <= 32 and to [min(oc, 8), 64] respectively).
    * ``block_oh`` / ``block_ow``: output tile height / width candidates.

    Parameters
    ----------
    cfg : ConfigSpace
        Config space to define knobs on.
    data : tvm.Tensor
        4-D NCHW input.
    kernel : tvm.Tensor
        4-D OIHW weight.
    strides, padding : int or a list/tuple of two ints
    dilation : unused here.
    layout : str
        Only 'NCHW' is supported.

    Raises
    ------
    ValueError
        If ``layout`` is not 'NCHW'.
    """
    dshape = get_const_tuple(data.shape)
    kshape = get_const_tuple(kernel.shape)
    if layout == 'NCHW':
        n, ic, h, w = dshape
        oc, _, kh, kw = kshape
    else:
        raise ValueError("Not support this layout {} with "
                         "schedule template.".format(layout))
    ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
    sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
    oh = (h - kh + 2 * ph) // sh + 1
    ow = (w - kw + 2 * pw) // sw + 1

    ic_bn_upper = 32
    oc_bn_upper = 64
    oc_bn_lower = min(oc, 8)
    ic_bn_candidates = [i for i in range(1, ic + 1) if ic % i == 0 and i <= ic_bn_upper]
    if not ic_bn_candidates:
        ic_bn_candidates = [1, ic]
    oc_bn_candidates = [i for i in range(1, oc + 1)
                        if oc % i == 0 and oc_bn_lower <= i <= oc_bn_upper]
    if not oc_bn_candidates:
        oc_bn_candidates = [1, oc]

    blk_upper = 16
    blk_candidates_low_limits = 5
    # Scan each output dimension independently. Zipping the two ranges
    # together stops at min(oh, ow), which drops the small divisors of the
    # larger dimension whenever the output is not square.
    blk_oh_list = [i for i in range(oh, 0, -1) if i <= blk_upper and oh % i == 0]
    blk_ow_list = [j for j in range(ow, 0, -1) if j <= blk_upper and ow % j == 0]

    # Top up short candidate lists with non-divisor block sizes; the compute
    # rounds the output extent up to a multiple of the block.
    if len(blk_oh_list) < blk_candidates_low_limits:
        for i in range(2, oh):
            if i not in blk_oh_list:
                blk_oh_list.append(i)
                if len(blk_oh_list) >= blk_candidates_low_limits:
                    break

    if len(blk_ow_list) < blk_candidates_low_limits:
        for i in range(min(ow - 1, blk_upper), 1, -1):
            if i not in blk_ow_list:
                blk_ow_list.append(i)
                if len(blk_ow_list) >= blk_candidates_low_limits:
                    break

    # Create schedule config
    cfg.define_knob("tile_ic", ic_bn_candidates)
    cfg.define_knob("tile_oc", oc_bn_candidates)
    cfg.define_knob("block_oh", blk_oh_list)
    cfg.define_knob("block_ow", blk_ow_list)
135
136
137##### SCHEDULE UTILITIES #####
def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None):
    """Split three axes of ``tensor`` and bind them to GPU block/thread indices.

    Each of ``z``, ``y``, ``x`` is split by its factor. Outer parts are bound to
    ``blockIdx.{z,y,x}`` and inner parts to ``threadIdx.{z,y,x}``. A falsy
    ``y_factor`` falls back to ``z_factor``, and a falsy ``x_factor`` falls
    back to the resolved ``y_factor``.

    Returns
    -------
    tuple
        ``(xi, thread_z, thread_y, thread_x)``: the inner x loop and the three
        thread axes that were bound.
    """
    if not y_factor:
        y_factor = z_factor
    if not x_factor:
        x_factor = y_factor

    stage = s[tensor]
    z_outer, z_inner = stage.split(z, z_factor)
    y_outer, y_inner = stage.split(y, y_factor)
    x_outer, x_inner = stage.split(x, x_factor)
    stage.reorder(z_outer, y_outer, x_outer, z_inner, y_inner, x_inner)

    thread_z = tvm.thread_axis((0, z_factor), "threadIdx.z")
    thread_y = tvm.thread_axis((0, y_factor), "threadIdx.y")
    thread_x = tvm.thread_axis((0, x_factor), "threadIdx.x")
    bindings = (
        (z_outer, "blockIdx.z", z_inner, thread_z),
        (y_outer, "blockIdx.y", y_inner, thread_y),
        (x_outer, "blockIdx.x", x_inner, thread_x),
    )
    for outer, block_name, inner, thread in bindings:
        stage.bind(outer, tvm.thread_axis(block_name))
        stage.bind(inner, thread)
    return x_inner, thread_z, thread_y, thread_x
157
158# Define template function for autotvm task
159# We define schedule template in this function instead of
160# declaration function since actual input arguments need
161# to be altered by the schedule selected.
@autotvm.task.register("topi_intel_graphics_conv2d_NCHWc")
def __topi_nn_conv2d_NCHWc(*args, **kwargs):
    """AutoTVM template for conv2d_NCHWc on Intel Graphics.

    ``args`` are serialized conv2d arguments with NCHW data and OIHW kernel.
    The tuning space is defined first. Then packed placeholders are built
    whose block sizes come from the selected config, because the packed input
    shapes depend on the chosen schedule.

    Returns
    -------
    (Schedule, [new_data, new_kernel, output])
    """
    assert not kwargs, "Do not support kwargs in template function call"
    data, kernel, strides, padding, dilation, layout, dtype = deserialize_args(args)
    raw_data_shape = get_const_tuple(data.shape)
    raw_kernel_shape = get_const_tuple(kernel.shape)

    cfg = get_config()
    # Raises for non-NCHW layouts and guarantees 4-D shapes below.
    _create_schedule_template(cfg, data, kernel, strides, padding, dilation, layout)
    cfg.add_flop(1)

    def _block_size(knob):
        # Knobs carry ``.val``; split entities carry ``.size``.
        entity = cfg[knob]
        return entity.val if hasattr(entity, "val") else entity.size[-1]

    ic_bn = _block_size("tile_ic")
    oc_bn = _block_size("tile_oc")

    batch, in_c, in_h, in_w = raw_data_shape
    out_c, k_in_c, k_h, k_w = raw_kernel_shape
    new_data = tvm.placeholder((batch, in_c // ic_bn, in_h, in_w, ic_bn), data.dtype)
    new_kernel = tvm.placeholder(
        (out_c // oc_bn, k_in_c // ic_bn, k_h, k_w, ic_bn, oc_bn), kernel.dtype)

    out = _decl_cl_spatialpack_NCHWc(cfg, new_data, new_kernel, strides, padding,
                                     dilation, dtype)
    sch = _schedule_conv2d_NCHWc(cfg, [out])
    return sch, [new_data, new_kernel, out]
189
@conv2d_alter_layout.register(["intel_graphics"])
def _alter_conv2d_layout(attrs, inputs, tinfo, F):
    """Rewrite an NCHW conv2d into conv2d_NCHWc for Intel Graphics.

    Returns None (keep the original operator) for non-NCHW layouts, grouped
    convolutions and depthwise convolutions. Otherwise it does three things:

    * queries (or falls back to a default) schedule config;
    * records that config under the packed NCHWc workload in the dispatch
      context;
    * returns the NCHWc operator built with ``F`` (nnvm symbol or relay op).
    """
    import nnvm.symbol as sym

    copy_inputs = list(inputs)
    new_attrs = {key: attrs[key] for key in attrs.keys()}

    if F.__name__ == 'tvm.relay.op':
        # Derive channels for frontends (e.g ONNX) that miss "channel" field.
        kernel_type_shape = inputs[1].checked_type.shape
        new_attrs["channels"] = kernel_type_shape[attrs['kernel_layout'].index('O')]

    data, kernel = tinfo[0], tinfo[1]
    batch_size, in_channel, height, width = get_const_tuple(data.shape)

    groups = attrs.get_int("groups")
    out_channel = attrs.get_int("channels")
    padding = attrs.get_int_tuple("padding")
    strides = attrs.get_int_tuple("strides")
    dilation = attrs.get_int_tuple("dilation")
    out_dtype = attrs["out_dtype"]

    layout_name = 'layout' if F == sym else 'data_layout'
    layout = attrs[layout_name]
    kh, kw = attrs.get_int_tuple("kernel_size")

    if out_dtype in ("same", ""):
        out_dtype = data.dtype
    is_depthwise = groups == in_channel and groups == out_channel

    # Only plain (groups == 1) or depthwise NCHW convolutions are candidates.
    if layout != 'NCHW' or (groups != 1 and not is_depthwise):
        return None

    dispatch_ctx = autotvm.task.DispatchContext.current
    target = tvm.target.current_target()

    if is_depthwise:
        # NOTE(review): this workload is built and then discarded; depthwise
        # has no NCHWc path on this backend.
        workload = autotvm.task.args_to_workload(
            [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw)
        return None
    workload = autotvm.task.args_to_workload(
        [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)

    # Query the tuned schedule, falling back to the default heuristic.
    cfg = dispatch_ctx.query(target, workload)
    if cfg.is_fallback:
        _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)

    def _block_size(knob):
        # Knobs carry ``.val``; split entities carry ``.size``.
        entity = cfg[knob]
        return entity.val if hasattr(entity, "val") else entity.size[-1]

    ic_bn = _block_size("tile_ic")
    oc_bn = _block_size("tile_oc")

    new_attrs[layout_name] = 'NCHW%dc' % ic_bn
    new_attrs['out_layout'] = 'NCHW%dc' % oc_bn

    packed_data_shape = (batch_size, in_channel // ic_bn, height, width, ic_bn)
    new_data = tvm.placeholder(packed_data_shape, dtype=data.dtype)

    out_channel, _, kh, kw = get_const_tuple(kernel.shape)
    # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
    new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)

    # Store altered operator's config
    packed_kernel_shape = (out_channel // oc_bn, in_channel // ic_bn, kh, kw, ic_bn, oc_bn)
    new_kernel = tvm.placeholder(packed_kernel_shape, dtype=kernel.dtype)
    new_workload = autotvm.task.args_to_workload(
        [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
         new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
    dispatch_ctx.update(target, new_workload, cfg)

    if F == sym:
        return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
    return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
264
@autotvm.register_topi_compute(conv2d_NCHWc, 'intel_graphics', 'direct')
def _decl_conv2d(cfg, data, kernel, strides, padding, dilation,
                 layout, out_layout, out_dtype='float32'):
    """Conv2D_NCHWc operator for Intel Graphics backend.

    Parameters
    ----------
    cfg : ConfigEntity
        Schedule config; filled by ``_get_default_config`` when it is a fallback.

    data : tvm.Tensor
        5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]

    kernel : tvm.Tensor
        6-D with shape [num_filter_chunk, in_channel_chunk, filter_height,
        filter_width, in_channel_block, num_filter_block]

    strides : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]

    dilation : int or a list/tuple of two ints
        Only 1 (no dilation) is supported.

    layout : str
        layout of data

    out_layout : str
        layout of the output

    out_dtype : str
        output data type

    Returns
    -------
    output : tvm.Tensor
        5-D with shape [batch, num_filter_chunk, out_height, out_width, num_filter_block]
    """
    dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
    assert (dh, dw) == (1, 1), "Does not support dilation"

    n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
    oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
    in_channel = ic_chunk * ic_bn
    num_filter = oc_chunk * oc_bn
    if cfg.is_fallback:
        # The default heuristic reasons about the unpacked NCHW / OIHW shapes.
        _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype),
                            tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width),
                                            dtype=kernel.dtype),
                            strides, padding, out_dtype)

    return _decl_cl_spatialpack_NCHWc(cfg, data, kernel, strides, padding, dilation, out_dtype)
306
307
@conv2d_infer_layout.register("intel_graphics")
def _conv2d_infer_layout(workload, cfg):
    """Infer the packed NCHWc input/output layouts selected by ``cfg``.

    Parameters
    ----------
    workload : tuple
        ``(name, data, kernel, strides, padding, dilation, layout, dtype)``.
        ``data`` and ``kernel`` are tuples; their last entry is dropped and the
        remaining four are unpacked as the NCHW / OIHW shape.
    cfg : ConfigEntity
        Provides ``tile_ic`` / ``tile_oc``. Tuned configs define them as knobs
        (``.val``); the fallback config sets split entities (``.size``).

    Returns
    -------
    tuple
        ``(((in_shape, in_layout),), ((out_shape, out_layout),))``
    """
    _, data, kernel, strides, padding, dilation, layout, dtype = workload
    batch_size, in_channel, in_height, in_width = data[:-1]
    out_channel, _, k_height, k_width = kernel[:-1]
    out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1
    out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1

    def _block_size(knob):
        # Reading ``.size`` alone fails for tuned (knob) configs; use the same
        # val-or-size lookup as the rest of this file.
        entity = cfg[knob]
        return entity.val if hasattr(entity, "val") else entity.size[-1]

    tile_ic, tile_oc = _block_size("tile_ic"), _block_size("tile_oc")
    in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic)
    in_layout = "NCHW%dc" % tile_ic
    out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc)
    out_layout = "NCHW%dc" % tile_oc
    return ((in_shape, in_layout),), ((out_shape, out_layout),)
321
322
@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc, 'intel_graphics', ['direct'])
def _schedule_conv2d_NCHWc(cfg, outs):
    """Schedule for conv2d_NCHWc on Intel Graphics.

    Walks the graph from the first output. Injective stages that are not
    graph outputs are inlined, and every op whose tag contains "conv" is
    handed to ``_schedule_cl_spatialpack_NCHWc``.

    Parameters
    ----------
    cfg : ConfigEntity
        Schedule config passed through to the conv scheduler.
    outs : tvm.Tensor or Array of Tensor
        The computation graph outputs.

    Returns
    -------
    s : Schedule
        The computation schedule.
    """
    if isinstance(outs, tvm.tensor.Tensor):
        outs = [outs]
    sch = tvm.create_schedule([out.op for out in outs])
    visited = []

    def _visit(node):
        """Inline injective non-output stages; schedule conv-tagged ops."""
        if tag.is_injective(node.tag):
            if node not in sch.outputs:
                sch[node].compute_inline()
            for producer in node.input_tensors:
                if producer.op.input_tensors and producer.op not in visited:
                    _visit(producer.op)
        if "conv" in node.tag:
            _schedule_cl_spatialpack_NCHWc(cfg, sch, node)
        visited.append(node)

    _visit(outs[0].op)
    return sch
358
def _decl_cl_spatialpack_NCHWc(cfg, data, kernel, strides, padding, dilation, out_dtype='float16'):
    """Declare the spatial-pack conv2d compute for NCHWc data on Intel Graphics.

    The convolution is computed into a ``conv`` tensor whose height and width
    are rounded up to multiples of the configured output block (``block_oh`` /
    ``block_ow``). When rounding was needed, an ``output_unpack`` stage crops
    the result back to the true output size.

    Parameters
    ----------
    cfg : ConfigEntity
        Supplies the ``block_oh`` and ``block_ow`` knob values.
    data : tvm.Tensor
        5-D [batch, in_channel_chunk, in_height, in_width, in_channel_block]
    kernel : tvm.Tensor
        6-D [out_channel_chunk, in_channel_chunk, filter_height, filter_width,
        in_channel_block, out_channel_block]
    strides : int or a list/tuple of two ints
    padding : int or a list/tuple of two ints
    dilation : unused here.
    out_dtype : str
        Dtype the operands are cast to before multiply-accumulate.

    Returns
    -------
    output : tvm.Tensor
        5-D [batch, out_channel_chunk, out_height, out_width, out_channel_block]
    """
    batch, in_channel, in_height, in_width, vc = [util.get_const_int(x) for x in data.shape]
    in_channel *= vc
    num_filter, channel, kernel_h, kernel_w, ci, co = [util.get_const_int(x) for x in kernel.shape]
    num_filter *= co
    # Pass the spatial kernel size (kh, kw), not the 6-D kernel tensor.
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))

    ic_bn = vc
    assert vc == ci

    if isinstance(strides, (tuple, list)):
        stride_h, stride_w = strides
    else:
        stride_h, stride_w = strides, strides

    out_channel = num_filter
    out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
    oshape = (batch, out_channel // co, out_height, out_width, co)

    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')

    block_h = cfg["block_oh"].val
    block_w = cfg["block_ow"].val

    # Round the computed extent up to a whole number of output blocks.
    c_h = out_height
    c_w = out_width

    if out_height % block_h != 0:
        c_h = (out_height // block_h + 1) * block_h

    if out_width % block_w != 0:
        c_w = (out_width // block_w + 1) * block_w

    cshape = (batch, out_channel // co, c_h, c_w, co)

    # Extra bottom/right input rows/cols so that the rounded-up output rows
    # and columns stay inside the padded input. For stride 1 this equals
    # c_h - out_height; larger strides can need more.
    extra_h = max(c_h - out_height,
                  (c_h - 1) * stride_h + kernel_h - (in_height + pad_top + pad_down))
    extra_w = max(c_w - out_width,
                  (c_w - 1) * stride_w + kernel_w - (in_width + pad_left + pad_right))

    pad_before = [0, 0, pad_top, pad_left, 0]
    pad_after = [0, 0, pad_down + extra_h, pad_right + extra_w, 0]
    DOPAD = any(p != 0 for p in pad_before + pad_after)
    DOUNPACK = (c_h - out_height != 0 or c_w - out_width != 0)
    if DOPAD:
        temp = pad(data, pad_before, pad_after, name="pad_temp")
    else:
        temp = data

    conv = tvm.compute(
        cshape,
        lambda nn, ff, yy, xx, ff_v: \
            tvm.sum(
                temp[nn, rc//ic_bn, yy * stride_h + ry, xx * stride_w + rx, rc%ic_bn]. \
                        astype(out_dtype) *
                kernel[ff, rc//ic_bn, ry, rx, rc%ic_bn, ff_v].astype(out_dtype),
                axis=[rc, ry, rx]), tag="conv", name='conv')

    if DOUNPACK:
        # Crop the block-aligned result back to the real output size.
        output = tvm.compute(
            oshape,
            lambda nn, ff, yy, xx, ff_v:
            conv[nn][ff][yy][xx][ff_v],
            name='output_unpack', tag="conv_unpack")
    else:
        output = conv

    return output
428
429
def _schedule_cl_spatialpack_NCHWc(cfg, s, op):
    """Schedule the conv stage produced by ``_decl_cl_spatialpack_NCHWc``.

    Called from ``_schedule_conv2d_NCHWc`` for an op whose tag contains
    "conv". That op is either the ``conv`` compute itself or the
    ``output_unpack`` stage that reads from it. ``s`` is modified in place.

    Parameters
    ----------
    cfg : ConfigEntity
        Provides the ``block_oh`` / ``block_ow`` knob values.
    s : Schedule
        The schedule to modify.
    op : Operation
        The conv-tagged operation to schedule.
    """
    output = op.output(0)
    conv = op.input_tensors[0]
    if conv.op.name == "conv":
        # ``op`` is the unpack stage; its first input is the packed conv.
        temp = s[conv].op.input_tensors[0]
        kernel = s[conv].op.input_tensors[1]
        # Stage the (possibly padded) input in "warp" scope and compute conv
        # through a "local" cache.
        temp_W = s.cache_read(temp, "warp", [conv])
        conv_L = s.cache_write(conv, "local")
        # The unpack stage itself is tiled at the end of this function.
        SCHEDULE_OUTPUT = True
    else:
        # ``op`` is the conv compute itself; its inputs are (temp, kernel).
        temp = op.input_tensors[0]
        kernel = op.input_tensors[1]
        temp_W = s.cache_read(temp, "warp", [output])
        conv_L = s.cache_write(output, "local")
        if output.op in s.outputs:
            conv = output
        else:
            # Not a graph output: inline this stage and tile the first graph
            # output in its place.
            s[output].compute_inline()
            conv = s.outputs[0]
        SCHEDULE_OUTPUT = False
    kernel_L = s.cache_read(kernel, "local", [conv_L])

    OUTPUT_BLOCK_HEIGHT = cfg["block_oh"].val
    OUTPUT_BLOCK_WIDTH = cfg["block_ow"].val

    # schedule conv
    # Thread extents: 1 along z and y, 16 along x (the channel-block axis).
    z_factor = 1
    y_factor = 1
    x_factor = 16
    thread_z = tvm.thread_axis((0, z_factor), "threadIdx.z")
    thread_y = tvm.thread_axis((0, y_factor), "threadIdx.y")
    thread_x = tvm.thread_axis((0, x_factor), "threadIdx.x")
    _, co, oh, ow, vc = s[conv].op.axis
    # ioh/iow stay as inner loops below the bound axes, so each thread loops
    # over an OUTPUT_BLOCK_HEIGHT x OUTPUT_BLOCK_WIDTH tile.
    ooh, ioh = s[conv].split(oh, factor=OUTPUT_BLOCK_HEIGHT)
    oow, iow = s[conv].split(ow, factor=OUTPUT_BLOCK_WIDTH)
    s[conv].reorder(_, co, ooh, oow, vc, ioh, iow)
    # nparts=1: coo has extent 1 and all channel chunks end up in coi.
    coo, coi = s[conv].split(co, nparts=1)
    ooho, oohi = s[conv].split(ooh, factor=z_factor)
    oowo, oowi = s[conv].split(oow, factor=y_factor)
    vco, vci = s[conv].split(vc, factor=x_factor)
    s[conv].reorder(_, coo, vco, ooho, oowo, coi, oohi, oowi, vci, ioh, iow)
    s[conv].bind(oohi, thread_z)
    s[conv].bind(oowi, thread_y)
    s[conv].bind(vci, thread_x)
    s[conv].bind(ooho, tvm.thread_axis("blockIdx.z"))
    s[conv].bind(oowo, tvm.thread_axis("blockIdx.y"))
    s[conv].bind(coi, tvm.thread_axis("blockIdx.x"))

    # schedule conv_L
    # Compute each thread's tile in the local cache, with the reduction
    # loops placed outside the output-block loops.
    s[conv_L].compute_at(s[conv], vci)
    i, oc, h, w, vc = s[conv_L].op.axis
    rc, ry, rx = s[conv_L].op.reduce_axis
    s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
    # Re-stage the input once per reduction-channel iteration.
    s[temp_W].compute_at(s[conv_L], rc)
    # kernel.shape[3] is the filter width of the 6-D packed kernel. ry/rx are
    # unrolled except for width-7 filters (reason not visible here).
    if kernel.shape[3].value != 7:
        s[conv_L].unroll(ry)
        s[conv_L].unroll(rx)

    # schedule temp
    # Only present when the compute materialized a padding stage.
    if temp.op.name == "pad_temp":
        _, ci, h, w, vci = s[temp].op.axis
        tile_and_bind3d(s, temp, ci, h, w, 1, 16, 16)

    # schedule temp_W
    # Cooperative load: 16 threads along w; storage_align constrains the
    # stride of axis 2 (h) to a multiple of 16.
    _, ci, h, w, vci = s[temp_W].op.axis
    zo, zi = s[temp_W].split(vci, 1)
    yo, yi = s[temp_W].split(h, 1)
    xo, xi = s[temp_W].split(w, 16)
    s[temp_W].reorder(zo, yo, xo, zi, yi, xi)
    s[temp_W].bind(zi, thread_z)
    s[temp_W].bind(yi, thread_y)
    s[temp_W].bind(xi, thread_x)
    s[temp_W].storage_align(s[temp_W].op.axis[2], 16, 0)

    # schedule kernel_L
    # For the 2x14 block shape, load kernel values once per ry instead of once per rx.
    if OUTPUT_BLOCK_HEIGHT == 2 and OUTPUT_BLOCK_WIDTH == 14:
        s[kernel_L].compute_at(s[conv_L], ry)
    else:
        s[kernel_L].compute_at(s[conv_L], rx)

    # schedule output
    if SCHEDULE_OUTPUT:
        if output.op in s.outputs:
            out = output
        else:
            s[output].compute_inline()
            out = s.outputs[0]

        # Map w -> z (factor 4), h -> y (factor 8), vc -> x (factor 8).
        _, co, h, w, vc = s[out].op.axis
        tile_and_bind3d(s, out, w, h, vc, 4, 8, 8)
520
521
def conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype):
    """Build the 'conv2d' workload tuple for the given convolution arguments.

    A kernel with a rank other than 4 is taken to be packed by
    alter_op_layout. For it, a 4-D placeholder is rebuilt whose first
    dimension is ``shape[0] * shape[4]``, so the workload always describes the
    unpacked kernel.
    """
    raw_kernel = kernel
    if len(kernel.shape) != 4:
        dims = get_const_tuple(kernel.shape)
        unpacked_shape = (dims[0] * dims[4], dims[1], dims[2], dims[3])
        raw_kernel = tvm.placeholder(unpacked_shape, dtype=kernel.dtype)
    workload_args = [data, raw_kernel, strides, padding, layout, out_dtype]
    return ('conv2d', ) + autotvm.task.args_to_workload(workload_args)
532
@autotvm.register_topi_compute(conv2d, 'intel_graphics', 'direct')
def decl_conv2d(cfg, data, kernel, stride, padding, dilation, layout='NCHW', out_dtype='float32'):
    """Conv2D operator for Intel Graphics backend.

    Parameters
    ----------
    cfg : ConfigEntity
        Schedule config (forwarded to the compute declaration).
    data : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]
    kernel : tvm.Tensor
        4-D with shape [num_filter, in_channel, filter_height, filter_width]
    stride : int or a list/tuple of two ints
        stride size, or [stride_height, stride_width]
    padding : int or a list/tuple of two ints
        padding size, or [pad_height, pad_width]
    dilation : int or a list/tuple of two ints
        Only 1 (no dilation) is supported.
    layout : str
        layout of data; only 'NCHW' is supported
    out_dtype : str
        output data type
    Returns
    -------
    output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    assert layout == 'NCHW', "only support NCHW convolution on intel gpu"
    assert data.shape[0].value == 1, "only support batch size=1 convolution on intel gpu"
    assert data.dtype == kernel.dtype, "Do not support inputs with different data types now."
    # _decl_cl_spatialpack has no dilation parameter. Reject dilation
    # explicitly rather than silently computing an undilated convolution.
    dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
    assert (dh, dw) == (1, 1), "Does not support dilation"

    return _decl_cl_spatialpack(cfg, data, kernel, stride, padding, layout, out_dtype)
559
@autotvm.task.register_topi_schedule(generic.schedule_conv2d_nchw, 'intel_graphics', ['direct'])
def schedule_conv2d_nchw(cfg, outs):
    """Schedule for conv2d_nchw for Intel Graphics.

    Walks the graph from the first output. Broadcast stages that are not
    graph outputs are inlined, and every op whose tag contains 'conv2d' is
    handed to ``_schedule_cl_spatialpack``.

    Parameters
    ----------
    cfg : ConfigEntity
        Schedule config passed through to the conv scheduler.
    outs : tvm.Tensor or Array of Tensor
        The computation graph outputs.

    Returns
    -------
    s : Schedule
        The computation schedule for conv2d_nchw.
    """
    if isinstance(outs, tvm.tensor.Tensor):
        outs = [outs]
    sch = tvm.create_schedule([out.op for out in outs])
    visited = []

    def _visit(node):
        """Inline broadcast non-output stages; schedule conv2d-tagged ops."""
        if tag.is_broadcast(node.tag):
            if node not in sch.outputs:
                sch[node].compute_inline()
            for producer in node.input_tensors:
                if producer.op.input_tensors and producer.op not in visited:
                    _visit(producer.op)
        if 'conv2d' in node.tag:
            _schedule_cl_spatialpack(cfg, sch, node)
        visited.append(node)

    _visit(outs[0].op)
    return sch
593
def _decl_cl_spatialpack(cfg, data, kernel, stride, padding, layout, out_dtype='float16'):
    """Declare the NCHW spatial-pack conv2d compute for Intel Graphics.

    The kernel is repacked into 16-wide output-channel vectors (``kernel_vec``).
    The convolution is computed into a 5-D ``conv`` tensor whose spatial
    extent is rounded up to multiples of the output block. It is then
    unpacked into the 4-D NCHW output.

    Parameters
    ----------
    cfg : ConfigEntity
        Unused; block sizes are chosen by the heuristic below and stored in
        the ``conv`` op's attrs.
    data : tvm.Tensor
        4-D [batch, in_channel, in_height, in_width]
    kernel : tvm.Tensor
        4-D [num_filter, in_channel, filter_height, filter_width]
    stride : int or a list/tuple of two ints
    padding : int or a list/tuple of two ints
    layout : str
        Recorded in the output's workload attribute.
    out_dtype : str
        Dtype the operands are cast to before multiply-accumulate.

    Returns
    -------
    output : tvm.Tensor
        4-D [batch, num_filter, out_height, out_width], tagged 'conv2d'.
    """
    batch, in_channel, in_height, in_width = [util.get_const_int(x) for x in data.shape]
    num_filter, channel, kernel_h, kernel_w = [util.get_const_int(x) for x in kernel.shape]
    # Pass the spatial kernel size (kh, kw), not the kernel tensor.
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))

    if isinstance(stride, (tuple, list)):
        stride_h, stride_w = stride
    else:
        stride_h, stride_w = stride, stride

    out_channel = num_filter
    out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
    oshape = (batch, out_channel, out_height, out_width)

    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')

    # Hand-tuned output block sizes. The 7x7 case compares the normalized
    # stride/padding so that tuple arguments such as (1, 1) / (3, 3) match too.
    if stride_h == 2:
        if num_filter + kernel_h == 515:
            block_h, block_w = 4, 4
        else:
            block_h, block_w = 4, 5
    elif kernel_h == 3:
        if num_filter == 512:
            block_h, block_w = 2, 7
        else:
            block_h, block_w = 2, 14
    elif kernel_h == 7 and (pad_top, pad_left, pad_down, pad_right) == (3, 3, 3, 3) \
            and (stride_h, stride_w) == (1, 1):
        block_h, block_w = 3, 4
    else:
        block_h, block_w = 1, 16
    attrs = {'block_h': block_h, 'block_w' : block_w}

    # Round the computed extent up to a whole number of output blocks.
    c_h = out_height
    c_w = out_width

    if out_height % block_h != 0:
        c_h = (out_height // block_h + 1) * block_h

    if out_width % block_w != 0:
        c_w = (out_width // block_w + 1) * block_w

    # Keep the original (c - block) extra padding, but never less than what
    # the last rounded-up output row/column actually reads. Otherwise a block
    # larger than the output reads past the end of the padded input.
    extra_h = max(c_h - block_h,
                  (c_h - 1) * stride_h + kernel_h - (in_height + pad_top + pad_down))
    extra_w = max(c_w - block_w,
                  (c_w - 1) * stride_w + kernel_w - (in_width + pad_left + pad_right))
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_down + extra_h, pad_right + extra_w]
    temp = pad(data, pad_before, pad_after, name="pad_temp")

    # Round the number of filters up to a multiple of the vector width nv.
    raw_num_filter = num_filter
    nv = 16
    if num_filter % nv != 0:
        num_filter = (num_filter // nv + 1) * nv
        out_channel = num_filter

    cshape = (batch, out_channel // nv, c_h, c_w, nv)
    kvshape = (num_filter // nv, channel, kernel_h, kernel_w, nv)

    if num_filter != raw_num_filter:
        # Filters beyond raw_num_filter exist only for vector alignment; read
        # zeros for them instead of indexing past the end of ``kernel``.
        kernel_vec = tvm.compute(
            kvshape,
            lambda co, ci, kh, kw, vc: tvm.if_then_else(
                co*nv + vc < raw_num_filter,
                kernel[co*nv + vc][ci][kh][kw],
                tvm.const(0, kernel.dtype)),
            name='kernel_vec')
    else:
        kernel_vec = tvm.compute(
            kvshape,
            lambda co, ci, kh, kw, vc:
            kernel[co*nv + vc][ci][kh][kw], name='kernel_vec')

    conv = tvm.compute(
        cshape,
        lambda nn, ff, yy, xx, vc: \
            tvm.sum(
                temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype) *
                kernel_vec[ff, rc, ry, rx, vc].astype(out_dtype),
                axis=[rc, ry, rx]), name='conv', attrs=attrs)

    # Unpack to NCHW; padded filters and padded rows/cols are dropped here.
    output = tvm.compute(
        oshape,
        lambda nn, ff, yy, xx:
        conv[nn][ff//nv][yy][xx][ff%nv],
        name='output_unpack', tag='conv2d',
        attrs={'workload': conv_arg_to_workload(data, kernel, stride, padding,
                                                layout, out_dtype)})

    return output
679
def _schedule_cl_spatialpack(cfg, s, op):
    """Schedule the NCHW spatial-pack conv2d produced by ``_decl_cl_spatialpack``.

    ``op`` is the conv2d-tagged unpack stage. Its first input is the packed
    ``conv`` tensor, whose inputs are (pad_temp, kernel_vec). ``s`` is
    modified in place.

    Parameters
    ----------
    cfg : ConfigEntity
        Not referenced here; block sizes come from the ``conv`` op's attrs.
    s : Schedule
        The schedule to modify.
    op : Operation
        The conv2d-tagged output operation.
    """
    output = op.output(0)
    _, _, out_height, out_width = [util.get_const_int(x) for x in output.shape]

    conv = op.input_tensors[0]
    temp = s[conv].op.input_tensors[0]
    kernel_vec = s[conv].op.input_tensors[1]
    kernel = s[kernel_vec].op.input_tensors[0]
    # Stage the padded input in "warp" scope and compute conv via a "local" cache.
    temp_W = s.cache_read(temp, "warp", [conv])
    conv_L = s.cache_write(conv, "local")

    kernel_L = s.cache_read(kernel_vec, "local", [conv_L])
    _, in_channel, temp_h, temp_w = [util.get_const_int(x) for x in temp.shape]

    # Block sizes chosen by _decl_cl_spatialpack are carried in the op attrs.
    attrs = s[conv].op.attrs
    OUTPUT_BLOCK_HEIGHT = attrs['block_h']
    OUTPUT_BLOCK_WIDTH = attrs['block_w']

    # schedule conv
    # Thread extents: 1 along z and y, 16 along x (the nv vector axis).
    z_factor = 1
    y_factor = 1
    x_factor = 16
    thread_z = tvm.thread_axis((0, z_factor), "threadIdx.z")
    thread_y = tvm.thread_axis((0, y_factor), "threadIdx.y")
    thread_x = tvm.thread_axis((0, x_factor), "threadIdx.x")
    _, co, oh, ow, vc = s[conv].op.axis
    # ioh/iow remain inner loops below the bound axes, so each thread loops
    # over an OUTPUT_BLOCK_HEIGHT x OUTPUT_BLOCK_WIDTH tile.
    ooh, ioh = s[conv].split(oh, factor=OUTPUT_BLOCK_HEIGHT)
    oow, iow = s[conv].split(ow, factor=OUTPUT_BLOCK_WIDTH)
    s[conv].reorder(_, co, ooh, oow, vc, ioh, iow)
    # nparts=1: coo has extent 1 and all channel chunks end up in coi.
    coo, coi = s[conv].split(co, nparts=1)
    ooho, oohi = s[conv].split(ooh, factor=z_factor)
    oowo, oowi = s[conv].split(oow, factor=y_factor)
    vco, vci = s[conv].split(vc, factor=x_factor)
    s[conv].reorder(_, coo, vco, ooho, oowo, coi, oohi, oowi, vci, ioh, iow)
    s[conv].bind(oohi, thread_z)
    s[conv].bind(oowi, thread_y)
    s[conv].bind(vci, thread_x)
    s[conv].bind(ooho, tvm.thread_axis("blockIdx.z"))
    s[conv].bind(oowo, tvm.thread_axis("blockIdx.y"))
    s[conv].bind(coi, tvm.thread_axis("blockIdx.x"))

    # schedule conv_L
    # Compute each thread's tile in the local cache, reduction loops outermost.
    s[conv_L].compute_at(s[conv], vci)
    i, oc, h, w, vc = s[conv_L].op.axis
    rc, ry, rx = s[conv_L].op.reduce_axis
    s[conv_L].reorder(i, oc, rc, ry, rx, vc, h, w)
    # Re-stage the input once per reduction-channel iteration.
    s[temp_W].compute_at(s[conv_L], rc)
    # kernel.shape[3] is the filter width of the 4-D kernel. ry/rx are
    # unrolled except for width-7 filters (reason not visible here).
    if kernel.shape[3].value != 7:
        s[conv_L].unroll(ry)
        s[conv_L].unroll(rx)

    # schedule temp
    _, ci, h, w = s[temp].op.axis
    tile_and_bind3d(s, temp, ci, h, w, 1, 16, 16)

    # schedule temp_W
    # Cooperative load: 16 threads along w; storage_align constrains the
    # stride of axis 2 (h) to a multiple of 16.
    _, ci, h, w = s[temp_W].op.axis
    zo, zi = s[temp_W].split(ci, 1)
    yo, yi = s[temp_W].split(h, 1)
    xo, xi = s[temp_W].split(w, 16)
    s[temp_W].reorder(zo, yo, xo, zi, yi, xi)
    s[temp_W].bind(zi, thread_z)
    s[temp_W].bind(yi, thread_y)
    s[temp_W].bind(xi, thread_x)
    s[temp_W].storage_align(s[temp_W].op.axis[2], 16, 0)

    # The repacking stage is folded into its consumer (kernel_L).
    s[kernel_vec].compute_inline()

    # schedule kernel_L
    # NOTE(review): _decl_cl_spatialpack creates ``conv`` without a tag, so
    # this "2_14" test appears never to match and kernel_L is always placed at
    # rx. The NCHWc variant checks the block sizes instead; confirm intent.
    if "2_14" in s[conv].op.tag:
        s[kernel_L].compute_at(s[conv_L], ry)
    else:
        s[kernel_L].compute_at(s[conv_L], rx)

    # schedule output
    if output.op in s.outputs:
        out = output
    else:
        # Not a graph output: inline it and tile the first graph output instead.
        s[output].compute_inline()
        out = s.outputs[0]

    # Map w -> z (factor 4), h -> y (factor 8), co -> x (factor 8).
    _, co, h, w = s[out].op.axis
    tile_and_bind3d(s, out, w, h, co, 4, 8, 8)
763