1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file Use external miopen utils function
22 */
23 #include "miopen_utils.h"
24 #include <dmlc/thread_local.h>
25 #include <tvm/runtime/registry.h>
26 #include <vector>
27 #include <string>
28
29 namespace tvm {
30 namespace contrib {
31 namespace miopen {
32
/*!
 * \brief Map a MIOpen status code to a human-readable name.
 * \param error_code Integer value of a miopenStatus_t.
 * \return For codes 0-7, the matching entry of the table below (each entry
 *         keeps its trailing space). For any other value, a fallback string
 *         that embeds the raw code.
 */
std::string miopenGetErrorString(int error_code) {
  // Built once and shared across calls. Index i is assumed to correspond to
  // miopenStatus_t value i -- TODO confirm against the MIOpen header in use.
  static const std::vector<std::string> mio_err{
      "StatusSuccess ",        "StatusNotInitialized ", "StatusInvalidValue ",
      "StatusBadParm ",        "StatusAllocFailed ",    "StatusInternalError ",
      "StatusNotImplemented ", "StatusUnknownError "};
  // MIOpen may report status codes beyond this table (newer releases add
  // more). Indexing past the end, or with a negative code, would be
  // undefined behavior, so report the raw value instead.
  if (error_code < 0 || static_cast<size_t>(error_code) >= mio_err.size()) {
    return "StatusUnrecognized(" + std::to_string(error_code) + ") ";
  }
  return mio_err[static_cast<size_t>(error_code)];
}
40
41 // MiopenThreadEntry
MIOpenThreadEntry()42 MIOpenThreadEntry::MIOpenThreadEntry() {
43 auto stream = runtime::ROCMThreadEntry::ThreadLocal()->stream;
44 auto func = runtime::Registry::Get("device_api.rocm");
45 void *ret = (*func)();
46 rocm_api = static_cast<runtime::DeviceAPI*>(ret);
47 MIOPEN_CALL(miopenCreate(&handle));
48 MIOPEN_CALL(miopenSetStream(handle, stream));
49 conv_entry.rocm_api = rocm_api;
50 }
51
~MIOpenThreadEntry()52 MIOpenThreadEntry::~MIOpenThreadEntry() {
53 MIOPEN_CALL(miopenDestroy(handle));
54 }
55
56 typedef dmlc::ThreadLocalStore<MIOpenThreadEntry> MIOpenThreadStore;
57
ThreadLocal()58 MIOpenThreadEntry* MIOpenThreadEntry::ThreadLocal() {
59 return MIOpenThreadStore::Get();
60 }
61
62 // ConvEntry
63
ConvEntry()64 ConvEntry::ConvEntry() {
65 MIOPEN_CALL(miopenCreateConvolutionDescriptor(&conv_desc));
66 MIOPEN_CALL(miopenCreateTensorDescriptor(&filter_desc));
67 MIOPEN_CALL(miopenCreateTensorDescriptor(&input_desc));
68 MIOPEN_CALL(miopenCreateTensorDescriptor(&output_desc));
69 }
70
~ConvEntry()71 ConvEntry::~ConvEntry() {
72 MIOPEN_CALL(miopenDestroyConvolutionDescriptor(conv_desc));
73 MIOPEN_CALL(miopenDestroyTensorDescriptor(filter_desc));
74 MIOPEN_CALL(miopenDestroyTensorDescriptor(input_desc));
75 MIOPEN_CALL(miopenDestroyTensorDescriptor(output_desc));
76 CleanWorkspace();
77 }
78
UpdateWorkspace(const size_t wsize)79 void ConvEntry::UpdateWorkspace(const size_t wsize) {
80 if (workspace_size < wsize) {
81 if (workspace != nullptr) {
82 CleanWorkspace();
83 }
84 workspace_size = wsize;
85 workspace = rocm_api->AllocWorkspace(ctx, workspace_size);
86 }
87 }
88
CleanWorkspace()89 void ConvEntry::CleanWorkspace() {
90 if (workspace) rocm_api->FreeWorkspace(ctx, workspace);
91 workspace_size = 0;
92 }
93
94 } // namespace miopen
95 } // namespace contrib
96 } // namespace tvm
97