# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of ROCm operator strategy."""
# pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
from tvm import topi
from .generic import *
from .. import op as _op


@schedule_lrn.register("rocm")
def schedule_lrn_rocm(attrs, outs, target):
    """schedule LRN for rocm"""
    with target:
        return topi.rocm.schedule_lrn(outs)


@conv2d_strategy.register("rocm")
def conv2d_strategy_rocm(attrs, inputs, out_type, target):
    """conv2d rocm strategy"""
    strategy = _op.OpStrategy()
    data, kernel = inputs
    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
    groups = attrs.groups
    layout = attrs.data_layout
    stride_h, stride_w = attrs.get_int_tuple("strides")
    kernel_layout = attrs.kernel_layout
    padding = attrs.get_int_tuple("padding")
    if dilation_h < 1 or dilation_w < 1:
        raise ValueError("dilation should be a positive value")

    if groups == 1:
        if layout == "NCHW":
            # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
            assert kernel_layout == "OIHW"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.cuda.conv2d_nchw),
                wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
                name="conv2d_nchw.cuda",
            )
            # Register Winograd at a lower priority (plevel=5) so the direct
            # schedule remains the default unless tuning shows Winograd is faster.
            _, _, kh, kw = get_const_tuple(kernel.shape)
            if (
                2 < kh < 8
                and 2 < kw < 8
                and kh == kw
                and stride_h == 1
                and stride_w == 1
                and dilation_h == 1
                and dilation_w == 1
            ):
                strategy.add_implementation(
                    wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
                    name="conv2d_nchw_winograd.cuda",
                    plevel=5,
                )
        elif layout == "HWCN":
            assert kernel_layout == "HWIO"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
                wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
                name="conv2d_hwcn.cuda",
            )
        # TODO(@alexgl-github): Re-enable this after fixing conv2d_nhwc for cuda
        # elif layout == "NHWC":
        #     assert kernel_layout == "HWIO"
        #     strategy.add_implementation(
        #         wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
        #         wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc),
        #         name="conv2d_nhwc.cuda")
        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
            assert kernel_layout == "OIHW4o4i"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
                wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
                name="conv2d_NCHWc_int8.cuda",
            )
        else:
            raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
        # Add the MIOpen implementation at plevel=15 so it is preferred when
        # the library is enabled via target.libs and the padding is symmetric.
        if (
            "miopen" in target.libs
            and layout == "NCHW"
            and padding[0] == padding[2]
            and padding[1] == padding[3]
        ):
            strategy.add_implementation(
                wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
                wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
                name="conv2d_nchw_miopen.rocm",
                plevel=15,
            )
    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
        if layout == "NCHW":
            assert kernel_layout == "OIHW"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
                name="depthwise_conv2d_nchw.cuda",
            )
        elif layout == "NHWC":
            assert kernel_layout == "HWOI"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
                name="depthwise_conv2d_nhwc.cuda",
            )
        else:
            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
    else:  # group_conv2d
        if layout == "NCHW":
            # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
            assert kernel_layout == "OIHW"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
                name="group_conv2d_nchw.cuda",
            )
        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
            assert kernel_layout == "OIHW4o4i"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
                name="group_conv2d_NCHWc_int8.cuda",
            )
        else:
            raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
    return strategy


@dense_strategy.register("rocm")
def dense_strategy_rocm(attrs, inputs, out_type, target):
    """Dense strategy for ROCM"""
    assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only 2-dim dense is supported"
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_dense(topi.rocm.dense),
        wrap_topi_schedule(topi.rocm.schedule_dense),
        name="dense.rocm",
    )
    # Prefer rocBLAS (plevel=15) when it is enabled via target.libs.
    if target.kind.name == "rocm" and "rocblas" in target.libs:
        assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
        strategy.add_implementation(
            wrap_compute_dense(topi.rocm.dense_rocblas),
            wrap_topi_schedule(topi.rocm.schedule_dense_rocblas),
            name="dense_rocblas.rocm",
            plevel=15,
        )
    return strategy


@batch_matmul_strategy.register("rocm")
def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
    """Batch matmul strategy for ROCM"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
        wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
        name="batch_matmul.cuda",
        plevel=10,
    )
    # rocBLAS at plevel=12 takes precedence over the TOPI schedule (plevel=10).
    if target.kind.name == "rocm" and "rocblas" in target.libs:
        assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
        strategy.add_implementation(
            wrap_compute_batch_matmul(topi.rocm.batch_matmul_rocblas),
            wrap_topi_schedule(topi.rocm.schedule_batch_matmul_rocblas),
            name="batch_matmul_rocblas.rocm",
            plevel=12,
        )
    return strategy
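

# ---------------------------------------------------------------------------
# Illustrative usage: a minimal sketch, not part of the strategy registry.
# The decorated functions above are never called directly; Relay's operator
# dispatch looks them up when compiling for a ROCm target. The demo below
# assumes a working ROCm toolchain, and the variable names and shapes are
# arbitrary examples chosen for illustration.
if __name__ == "__main__":
    import tvm
    from tvm import relay

    # A single NCHW conv2d, which conv2d_strategy_rocm above will handle.
    x = relay.var("x", shape=(1, 3, 224, 224), dtype="float32")
    w = relay.var("w", shape=(64, 3, 7, 7), dtype="float32")
    y = relay.nn.conv2d(x, w, strides=(2, 2), padding=(3, 3, 3, 3))
    mod = tvm.IRModule.from_expr(relay.Function([x, w], y))
    # Compiling with target="rocm" routes nn.conv2d through the strategy above;
    # target="rocm -libs=miopen" would make the MIOpen implementation win
    # (plevel=15 beats the default plevel=10).
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="rocm")
    print(lib)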