# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of ROCm operator strategy."""
# pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
from tvm import topi
from .generic import *
from .. import op as _op


@schedule_lrn.register("rocm")
def schedule_lrn_rocm(attrs, outs, target):
    """Schedule LRN for ROCm."""
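    # Create the schedule under the target scope so target-specific parameters are in effect.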
    with target:
        return topi.rocm.schedule_lrn(outs)


@conv2d_strategy.register("rocm")
def conv2d_strategy_rocm(attrs, inputs, out_type, target):
    """Conv2d strategy for ROCm."""
    strategy = _op.OpStrategy()
    data, kernel = inputs
    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
    groups = attrs.groups
    layout = attrs.data_layout
    stride_h, stride_w = attrs.get_int_tuple("strides")
    kernel_layout = attrs.kernel_layout
    padding = attrs.get_int_tuple("padding")
    if dilation_h < 1 or dilation_w < 1:
        raise ValueError("dilation should be a positive value")

    if groups == 1:
        if layout == "NCHW":
            # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
            assert kernel_layout == "OIHW"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.cuda.conv2d_nchw),
                wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
                name="conv2d_nchw.cuda",
            )
            _, _, kh, kw = get_const_tuple(kernel.shape)
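            # Winograd is only added for square 3x3-7x7 kernels with unit stride and dilation.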
            if (
                2 < kh < 8
                and 2 < kw < 8
                and kh == kw
                and stride_h == 1
                and stride_w == 1
                and dilation_h == 1
                and dilation_w == 1
            ):
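                # plevel=5: lower priority than the direct conv2d (default plevel is 10).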
                strategy.add_implementation(
                    wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
                    name="conv2d_nchw_winograd.cuda",
                    plevel=5,
                )
        elif layout == "HWCN":
            assert kernel_layout == "HWIO"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
                wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
                name="conv2d_hwcn.cuda",
            )
        # TODO(@alexgl-github): Re-enable this after fixing conv2d_nhwc for cuda
        # elif layout == "NHWC":
        #     assert kernel_layout == "HWIO"
        #     strategy.add_implementation(
        #         wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
        #         wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
        #         name="conv2d_nhwc.cuda")
        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
            assert kernel_layout == "OIHW4o4i"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
                wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
                name="conv2d_NCHWc_int8.cuda",
            )
        else:
            raise RuntimeError("Unsupported conv2d layout {} for ROCm".format(layout))
        # Add the MIOpen implementation when MIOpen is enabled in the target.
        if (
            "miopen" in target.libs
            and layout == "NCHW"
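            # MIOpen is only used with symmetric padding (top == bottom, left == right).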
            and padding[0] == padding[2]
            and padding[1] == padding[3]
        ):
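            # plevel=15 prefers MIOpen over the schedules added above (default plevel is 10).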
            strategy.add_implementation(
                wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
                wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
                name="conv2d_nchw_miopen.rocm",
                plevel=15,
            )
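    # Depthwise path: the group count equals the input (and output) channel count.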
    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
        if layout == "NCHW":
            assert kernel_layout == "OIHW"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
                name="depthwise_conv2d_nchw.cuda",
            )
        elif layout == "NHWC":
            assert kernel_layout == "HWOI"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
                name="depthwise_conv2d_nhwc.cuda",
            )
        else:
            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
    else:  # group_conv2d
        if layout == "NCHW":
            # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
            assert kernel_layout == "OIHW"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
                name="group_conv2d_nchw.cuda",
            )
        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
            assert kernel_layout == "OIHW4o4i"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
                name="group_conv2d_NCHWc_int8.cuda",
            )
        else:
            raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
    return strategy


@dense_strategy.register("rocm")
def dense_strategy_rocm(attrs, inputs, out_type, target):
    """Dense strategy for ROCm."""
    assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only 2-dim dense is supported"
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_dense(topi.rocm.dense),
        wrap_topi_schedule(topi.rocm.schedule_dense),
        name="dense.rocm",
    )
    if target.kind.name == "rocm" and "rocblas" in target.libs:
        assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
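        # plevel=15 prefers the rocBLAS kernel over dense.rocm (default plevel is 10).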
        strategy.add_implementation(
            wrap_compute_dense(topi.rocm.dense_rocblas),
            wrap_topi_schedule(topi.rocm.schedule_dense_rocblas),
            name="dense_rocblas.rocm",
            plevel=15,
        )
    return strategy


@batch_matmul_strategy.register("rocm")
def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
    """Batch matmul strategy for ROCm."""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
        wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
        name="batch_matmul.cuda",
        plevel=10,
    )
    if target.kind.name == "rocm" and "rocblas" in target.libs:
        assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
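        # plevel=12 prefers rocBLAS over the generic batch_matmul.cuda schedule (plevel=10).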
        strategy.add_implementation(
            wrap_compute_batch_matmul(topi.rocm.batch_matmul_rocblas),
            wrap_topi_schedule(topi.rocm.schedule_batch_matmul_rocblas),
            name="batch_matmul_rocblas.rocm",
            plevel=12,
        )
    return strategy