1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file Use external miopen utils function
22  */
23 #include "miopen_utils.h"
24 #include <dmlc/thread_local.h>
25 #include <tvm/runtime/registry.h>
26 #include <vector>
27 #include <string>
28 
29 namespace tvm {
30 namespace contrib {
31 namespace miopen {
32 
/*!
 * \brief Translate a numeric MIOpen status code into a readable name.
 *
 * Codes 0 through 7 index the fixed table below (presumably mirroring the
 * order of the miopenStatus_t enum -- confirm against the MIOpen headers).
 * The names keep their fixed-width trailing padding, so the text returned
 * for known codes is unchanged.
 *
 * \param error_code The status value to describe.
 * \return The table entry for a known code; "StatusUnrecognized(<code>)"
 *         for a negative code or one past the end of the table.
 */
std::string miopenGetErrorString(int error_code) {
  // Built once rather than on every call.
  static const char* const kStatusNames[] = {
      "StatusSuccess        ", "StatusNotInitialized ", "StatusInvalidValue   ",
      "StatusBadParm        ", "StatusAllocFailed    ", "StatusInternalError  ",
      "StatusNotImplemented ", "StatusUnknownError   "};
  constexpr int kNumStatusNames =
      static_cast<int>(sizeof(kStatusNames) / sizeof(kStatusNames[0]));

  // Indexing outside the table is undefined behaviour, so codes this table
  // does not know about get an explicit fallback string instead.
  if (error_code < 0 || error_code >= kNumStatusNames) {
    return "StatusUnrecognized(" + std::to_string(error_code) + ")";
  }
  return kStatusNames[error_code];
}
40 
// MIOpenThreadEntry
MIOpenThreadEntry()42 MIOpenThreadEntry::MIOpenThreadEntry() {
43   auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream;
44   auto func = runtime::Registry::Get("device_api.rocm");
45   void *ret = (*func)();
46   rocm_api = static_cast<runtime::DeviceAPI*>(ret);
47   MIOPEN_CALL(miopenCreate(&handle));
48   MIOPEN_CALL(miopenSetStream(handle, stream));
49   conv_entry.rocm_api = rocm_api;
50 }
51 
~MIOpenThreadEntry()52 MIOpenThreadEntry::~MIOpenThreadEntry() {
53   MIOPEN_CALL(miopenDestroy(handle));
54 }
55 
56 typedef dmlc::ThreadLocalStore<MIOpenThreadEntry> MIOpenThreadStore;
57 
ThreadLocal()58 MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() {
59   return MIOpenThreadStore::Get();
60 }
61 
62 // ConvEntry
63 
ConvEntry()64 ConvEntry::ConvEntry() {
65   MIOPEN_CALL(miopenCreateConvolutionDescriptor(&conv_desc));
66   MIOPEN_CALL(miopenCreateTensorDescriptor(&filter_desc));
67   MIOPEN_CALL(miopenCreateTensorDescriptor(&input_desc));
68   MIOPEN_CALL(miopenCreateTensorDescriptor(&output_desc));
69 }
70 
~ConvEntry()71 ConvEntry::~ConvEntry() {
72   MIOPEN_CALL(miopenDestroyConvolutionDescriptor(conv_desc));
73   MIOPEN_CALL(miopenDestroyTensorDescriptor(filter_desc));
74   MIOPEN_CALL(miopenDestroyTensorDescriptor(input_desc));
75   MIOPEN_CALL(miopenDestroyTensorDescriptor(output_desc));
76   CleanWorkspace();
77 }
78 
UpdateWorkspace(const size_t wsize)79 void ConvEntry::UpdateWorkspace(const size_t wsize) {
80   if (workspace_size < wsize) {
81     if (workspace != nullptr) {
82       CleanWorkspace();
83     }
84     workspace_size = wsize;
85     workspace = rocm_api->AllocWorkspace(ctx, workspace_size);
86   }
87 }
88 
CleanWorkspace()89 void ConvEntry::CleanWorkspace() {
90   if (workspace) rocm_api->FreeWorkspace(ctx, workspace);
91   workspace_size = 0;
92 }
93 
94 }  // namespace miopen
95 }  // namespace contrib
96 }  // namespace tvm
97