1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file Use external miopen utils function
22 */
23 #include <tvm/runtime/registry.h>
24 #include <tvm/runtime/util.h>
25 #include <tvm/runtime/device_api.h>
26 #include "miopen_utils.h"
27
28 namespace tvm {
29 namespace contrib {
30 namespace miopen {
31
32 using namespace runtime;
33
34 TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
__anone1005ce40102(TVMArgs args, TVMRetValue *ret) 35 .set_body([](TVMArgs args, TVMRetValue *ret) {
36 const int mode = args[0];
37 const int dtype = args[1];
38 const int pad_h = args[2];
39 const int pad_w = args[3];
40 const int stride_h = args[4];
41 const int stride_w = args[5];
42 const int dilation_h = args[6];
43 const int dilation_w = args[7];
44 const int x_dim0 = args[8];
45 const int x_dim1 = args[9];
46 const int x_dim2 = args[10];
47 const int x_dim3 = args[11];
48 const int w_dim0 = args[12];
49 const int w_dim1 = args[13];
50 const int w_dim2 = args[14];
51 const int w_dim3 = args[15];
52 const int n_group = args[16];
53 void *out_shape = args[17];
54
55 MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
56 assert(n_group > 0 && "Group Size > 0 is expected");
57 if (n_group > 1)
58 assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1");
59 // Set Mode
60 entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
61 // Set Ctx
62 entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0};
63 // Set Data Type
64 entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
65 dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), int32, int8 at
66 // this moment.
67 // Set Desc
68 MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
69 entry_ptr->conv_entry.mode,
70 pad_h,
71 pad_w,
72 stride_h,
73 stride_w,
74 dilation_h,
75 dilation_w));
76 if (n_group > 1)
77 MIOPEN_CALL(miopenSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, n_group));
78 // Set Filter
79 MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc,
80 entry_ptr->conv_entry.data_type,
81 w_dim0,
82 w_dim1/n_group,
83 w_dim2,
84 w_dim3));
85 // Set Input
86 MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc,
87 entry_ptr->conv_entry.data_type,
88 x_dim0,
89 x_dim1,
90 x_dim2,
91 x_dim3));
92
93 // Set Output shape
94 MIOPEN_CALL(miopenGetConvolutionForwardOutputDim(entry_ptr->conv_entry.conv_desc,
95 entry_ptr->conv_entry.input_desc,
96 entry_ptr->conv_entry.filter_desc,
97 static_cast<int*>(out_shape),
98 static_cast<int*>(out_shape) + 1,
99 static_cast<int*>(out_shape) + 2,
100 static_cast<int*>(out_shape) + 3));
101
102 const int *oshape = static_cast<int*>(out_shape);
103 // Set Output
104 MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc,
105 entry_ptr->conv_entry.data_type,
106 oshape[0],
107 oshape[1],
108 oshape[2],
109 oshape[3]));
110
111 // Set workspace
112 size_t workspace_size = 0;
113 MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize(entry_ptr->handle,
114 entry_ptr->conv_entry.filter_desc,
115 entry_ptr->conv_entry.input_desc,
116 entry_ptr->conv_entry.conv_desc,
117 entry_ptr->conv_entry.output_desc,
118 &workspace_size));
119 entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
120
121 const size_t input_size = x_dim0 * x_dim1 * x_dim2 * x_dim3;
122 const size_t filter_size = w_dim0 * w_dim1 * w_dim2 * w_dim3;
123 const size_t output_size = oshape[0] * oshape[1] * oshape[2] * oshape[3];
124
125 runtime::DeviceAPI* rocm_api = entry_ptr->conv_entry.rocm_api;
126 float* input_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
127 input_size * sizeof(float)));
128 float* filter_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
129 filter_size * sizeof(float)));
130 float* output_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
131 output_size * sizeof(float)));
132
133 const int request_algo_count = 4;
134 const bool exhaustive_search = false;
135 void* workspace = entry_ptr->conv_entry.workspace;
136 if (workspace_size == 0) workspace = nullptr;
137 int returned_algo_count = 0;
138 miopenConvAlgoPerf_t perfs[4];
139
140 MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm(entry_ptr->handle,
141 entry_ptr->conv_entry.input_desc,
142 input_buf,
143 entry_ptr->conv_entry.filter_desc,
144 filter_buf,
145 entry_ptr->conv_entry.conv_desc,
146 entry_ptr->conv_entry.output_desc,
147 output_buf,
148 request_algo_count,
149 &returned_algo_count,
150 perfs,
151 workspace,
152 workspace_size,
153 exhaustive_search));
154
155 rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, input_buf);
156 rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, filter_buf);
157 rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, output_buf);
158
159 const std::vector<std::string> fwd_algo_names{
160 "miopenConvolutionFwdAlgoGEMM",
161 "miopenConvolutionFwdAlgoDirect",
162 "miopenConvolutionFwdAlgoFFT",
163 "miopenConvolutionFwdAlgoWinograd",
164 };
165 const auto best_algo = perfs[0].fwd_algo;
166 LOG(INFO) << "\tMIOpen Found " << returned_algo_count
167 << " fwd algorithms, choosing " << fwd_algo_names[best_algo];
168 for (int i = 0; i < returned_algo_count; ++i) {
169 LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perfs[i].fwd_algo]
170 << " - time: " << perfs[i].time << " ms"
171 << ", Memory: " << perfs[i].memory;
172 }
173 // Set Algo
174 ret[0] = static_cast<int>(best_algo);
175 });
176
177
178 TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward")
__anone1005ce40202(TVMArgs args, TVMRetValue *ret) 179 .set_body([](TVMArgs args, TVMRetValue *ret) {
180 const int mode = args[0];
181 const int dtype = args[1];
182 const int pad_h = args[2];
183 const int pad_w = args[3];
184 const int stride_h = args[4];
185 const int stride_w = args[5];
186 const int dilation_h = args[6];
187 const int dilation_w = args[7];
188 const int algo = args[8];
189 const DLTensor *x = args[9];
190 const DLTensor *w = args[10];
191 const DLTensor *y = args[11];
192
193 MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
194 entry_ptr->conv_entry.fwd_algo = static_cast<miopenConvFwdAlgorithm_t>(algo);
195 // Set Mode
196 entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
197 // Set Ctx
198 entry_ptr->conv_entry.ctx = x->ctx;
199 // Set Data Type
200 entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
201 dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at
202 // this moment.
203 // Set Desc
204 MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
205 entry_ptr->conv_entry.mode,
206 pad_h,
207 pad_w,
208 stride_h,
209 stride_w,
210 dilation_h,
211 dilation_w));
212 // Set Filter
213 MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc,
214 entry_ptr->conv_entry.data_type,
215 w->shape[0],
216 w->shape[1],
217 w->shape[2],
218 w->shape[3]));
219 // Set Input
220 MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc,
221 entry_ptr->conv_entry.data_type,
222 x->shape[0],
223 x->shape[1],
224 x->shape[2],
225 x->shape[3]));
226 // Set Output
227 MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc,
228 entry_ptr->conv_entry.data_type,
229 y->shape[0],
230 y->shape[1],
231 y->shape[2],
232 y->shape[3]));
233
234 const float alpha = 1.f;
235 const float beta = 0.f;
236 MIOPEN_CALL(miopenConvolutionForward(entry_ptr->handle,
237 &alpha,
238 entry_ptr->conv_entry.input_desc,
239 x->data,
240 entry_ptr->conv_entry.filter_desc,
241 w->data,
242 entry_ptr->conv_entry.conv_desc,
243 entry_ptr->conv_entry.fwd_algo,
244 &beta,
245 entry_ptr->conv_entry.output_desc,
246 y->data,
247 entry_ptr->conv_entry.workspace,
248 entry_ptr->conv_entry.workspace_size));
249 });
250
251 } // namespace miopen
252 } // namespace contrib
253 } // namespace tvm
254