1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file Use external miopen utils function
22  */
23 #include <tvm/runtime/registry.h>
24 #include <tvm/runtime/util.h>
25 #include <tvm/runtime/device_api.h>
26 #include "miopen_utils.h"
27 
28 namespace tvm {
29 namespace contrib {
30 namespace miopen {
31 
32 using namespace runtime;
33 
34 TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
__anone1005ce40102(TVMArgs args, TVMRetValue *ret) 35 .set_body([](TVMArgs args, TVMRetValue *ret) {
36   const int mode = args[0];
37   const int dtype = args[1];
38   const int pad_h = args[2];
39   const int pad_w = args[3];
40   const int stride_h = args[4];
41   const int stride_w = args[5];
42   const int dilation_h = args[6];
43   const int dilation_w = args[7];
44   const int x_dim0 = args[8];
45   const int x_dim1 = args[9];
46   const int x_dim2 = args[10];
47   const int x_dim3 = args[11];
48   const int w_dim0 = args[12];
49   const int w_dim1 = args[13];
50   const int w_dim2 = args[14];
51   const int w_dim3 = args[15];
52   const int n_group = args[16];
53   void *out_shape = args[17];
54 
55   MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
56   assert(n_group > 0 && "Group Size > 0 is expected");
57   if (n_group > 1)
58     assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1");
59   // Set Mode
60   entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
61   // Set Ctx
62   entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0};
63   // Set Data Type
64   entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
65       dtype);  // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), int32, int8 at
66                // this moment.
67   // Set Desc
68   MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
69                                               entry_ptr->conv_entry.mode,
70                                               pad_h,
71                                               pad_w,
72                                               stride_h,
73                                               stride_w,
74                                               dilation_h,
75                                               dilation_w));
76   if (n_group > 1)
77     MIOPEN_CALL(miopenSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, n_group));
78   // Set Filter
79   MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc,
80                                           entry_ptr->conv_entry.data_type,
81                                           w_dim0,
82                                           w_dim1/n_group,
83                                           w_dim2,
84                                           w_dim3));
85   // Set Input
86   MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc,
87                                           entry_ptr->conv_entry.data_type,
88                                           x_dim0,
89                                           x_dim1,
90                                           x_dim2,
91                                           x_dim3));
92 
93   // Set Output shape
94   MIOPEN_CALL(miopenGetConvolutionForwardOutputDim(entry_ptr->conv_entry.conv_desc,
95                                                    entry_ptr->conv_entry.input_desc,
96                                                    entry_ptr->conv_entry.filter_desc,
97                                                    static_cast<int*>(out_shape),
98                                                    static_cast<int*>(out_shape) + 1,
99                                                    static_cast<int*>(out_shape) + 2,
100                                                    static_cast<int*>(out_shape) + 3));
101 
102   const int *oshape = static_cast<int*>(out_shape);
103   // Set Output
104   MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc,
105                                           entry_ptr->conv_entry.data_type,
106                                           oshape[0],
107                                           oshape[1],
108                                           oshape[2],
109                                           oshape[3]));
110 
111   // Set workspace
112   size_t workspace_size = 0;
113   MIOPEN_CALL(miopenConvolutionForwardGetWorkSpaceSize(entry_ptr->handle,
114                                                        entry_ptr->conv_entry.filter_desc,
115                                                        entry_ptr->conv_entry.input_desc,
116                                                        entry_ptr->conv_entry.conv_desc,
117                                                        entry_ptr->conv_entry.output_desc,
118                                                        &workspace_size));
119   entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
120 
121   const size_t input_size = x_dim0 * x_dim1 * x_dim2 * x_dim3;
122   const size_t filter_size = w_dim0 * w_dim1 * w_dim2 * w_dim3;
123   const size_t output_size = oshape[0] * oshape[1] * oshape[2] * oshape[3];
124 
125   runtime::DeviceAPI* rocm_api = entry_ptr->conv_entry.rocm_api;
126   float* input_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
127                                                                   input_size * sizeof(float)));
128   float* filter_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
129                                                                    filter_size * sizeof(float)));
130   float* output_buf = static_cast<float*>(rocm_api->AllocWorkspace(entry_ptr->conv_entry.ctx,
131                                                                    output_size * sizeof(float)));
132 
133   const int request_algo_count = 4;
134   const bool exhaustive_search = false;
135   void* workspace = entry_ptr->conv_entry.workspace;
136   if (workspace_size == 0) workspace = nullptr;
137   int returned_algo_count = 0;
138   miopenConvAlgoPerf_t perfs[4];
139 
140   MIOPEN_CALL(miopenFindConvolutionForwardAlgorithm(entry_ptr->handle,
141                                                     entry_ptr->conv_entry.input_desc,
142                                                     input_buf,
143                                                     entry_ptr->conv_entry.filter_desc,
144                                                     filter_buf,
145                                                     entry_ptr->conv_entry.conv_desc,
146                                                     entry_ptr->conv_entry.output_desc,
147                                                     output_buf,
148                                                     request_algo_count,
149                                                     &returned_algo_count,
150                                                     perfs,
151                                                     workspace,
152                                                     workspace_size,
153                                                     exhaustive_search));
154 
155   rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, input_buf);
156   rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, filter_buf);
157   rocm_api->FreeWorkspace(entry_ptr->conv_entry.ctx, output_buf);
158 
159   const std::vector<std::string> fwd_algo_names{
160       "miopenConvolutionFwdAlgoGEMM",
161       "miopenConvolutionFwdAlgoDirect",
162       "miopenConvolutionFwdAlgoFFT",
163       "miopenConvolutionFwdAlgoWinograd",
164   };
165   const auto best_algo = perfs[0].fwd_algo;
166   LOG(INFO) << "\tMIOpen Found " << returned_algo_count
167             << " fwd algorithms, choosing " << fwd_algo_names[best_algo];
168   for (int i = 0; i < returned_algo_count; ++i) {
169     LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perfs[i].fwd_algo]
170               << " - time: " << perfs[i].time << " ms"
171               << ", Memory: " << perfs[i].memory;
172   }
173   // Set Algo
174   ret[0] = static_cast<int>(best_algo);
175 });
176 
177 
178 TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.forward")
__anone1005ce40202(TVMArgs args, TVMRetValue *ret) 179 .set_body([](TVMArgs args, TVMRetValue *ret) {
180   const int mode = args[0];
181   const int dtype = args[1];
182   const int pad_h = args[2];
183   const int pad_w = args[3];
184   const int stride_h = args[4];
185   const int stride_w = args[5];
186   const int dilation_h = args[6];
187   const int dilation_w = args[7];
188   const int algo = args[8];
189   const DLTensor *x = args[9];
190   const DLTensor *w = args[10];
191   const DLTensor *y = args[11];
192 
193   MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
194   entry_ptr->conv_entry.fwd_algo = static_cast<miopenConvFwdAlgorithm_t>(algo);
195   // Set Mode
196   entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
197   // Set Ctx
198   entry_ptr->conv_entry.ctx = x->ctx;
199   // Set Data Type
200   entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
201       dtype);  // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at
202                // this moment.
203   // Set Desc
204   MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
205                                               entry_ptr->conv_entry.mode,
206                                               pad_h,
207                                               pad_w,
208                                               stride_h,
209                                               stride_w,
210                                               dilation_h,
211                                               dilation_w));
212   // Set Filter
213   MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc,
214                                           entry_ptr->conv_entry.data_type,
215                                           w->shape[0],
216                                           w->shape[1],
217                                           w->shape[2],
218                                           w->shape[3]));
219   // Set Input
220   MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.input_desc,
221                                           entry_ptr->conv_entry.data_type,
222                                           x->shape[0],
223                                           x->shape[1],
224                                           x->shape[2],
225                                           x->shape[3]));
226   // Set Output
227   MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.output_desc,
228                                           entry_ptr->conv_entry.data_type,
229                                           y->shape[0],
230                                           y->shape[1],
231                                           y->shape[2],
232                                           y->shape[3]));
233 
234   const float alpha = 1.f;
235   const float beta = 0.f;
236   MIOPEN_CALL(miopenConvolutionForward(entry_ptr->handle,
237                                        &alpha,
238                                        entry_ptr->conv_entry.input_desc,
239                                        x->data,
240                                        entry_ptr->conv_entry.filter_desc,
241                                        w->data,
242                                        entry_ptr->conv_entry.conv_desc,
243                                        entry_ptr->conv_entry.fwd_algo,
244                                        &beta,
245                                        entry_ptr->conv_entry.output_desc,
246                                        y->data,
247                                        entry_ptr->conv_entry.workspace,
248                                        entry_ptr->conv_entry.workspace_size));
249 });
250 
251 }  // namespace miopen
252 }  // namespace contrib
253 }  // namespace tvm
254