1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 #include "registry.h"
20 
21 #include <tvm/runtime/registry.h>
22 
23 namespace tvm {
24 namespace datatype {
25 
26 using runtime::TVMArgs;
27 using runtime::TVMRetValue;
28 
/*!
 * \brief Python-facing entry: register custom type name args[0] under type code args[1].
 *
 * The code arrives as an int. It is range-checked before narrowing to uint8_t; a bare
 * static_cast would wrap, e.g. 256 + c -> c. A wrapped value could then land inside the
 * custom range and pass Registry::Register's kCustomBegin check unnoticed.
 */
TVM_REGISTER_GLOBAL("runtime._datatype_register").set_body([](TVMArgs args, TVMRetValue* ret) {
  const int type_code = args[1].operator int();
  CHECK(type_code >= 0 && type_code <= 255)
      << "Type code " << type_code << " does not fit in an 8-bit type code";
  datatype::Registry::Global()->Register(args[0], static_cast<uint8_t>(type_code));
});
32 
/*!
 * \brief Python-facing entry: look up the type code registered for name args[0].
 *
 * Delegates to Registry::GetTypeCode, which CHECK-fails on an unknown name.
 */
TVM_REGISTER_GLOBAL("runtime._datatype_get_type_code").set_body([](TVMArgs args, TVMRetValue* ret) {
  const std::string type_name = args[0].operator std::string();
  const uint8_t type_code = datatype::Registry::Global()->GetTypeCode(type_name);
  *ret = type_code;
});
36 
/*!
 * \brief Python-facing entry: look up the registered name for type code args[0].
 *
 * The int is range-checked before narrowing to uint8_t. Otherwise an out-of-range value
 * would alias onto a valid code and silently return another type's name.
 */
TVM_REGISTER_GLOBAL("runtime._datatype_get_type_name").set_body([](TVMArgs args, TVMRetValue* ret) {
  const int type_code = args[0].operator int();
  CHECK(type_code >= 0 && type_code <= 255) << "Type code " << type_code << " not registered";
  *ret = Registry::Global()->GetTypeName(static_cast<uint8_t>(type_code));
});
40 
41 TVM_REGISTER_GLOBAL("runtime._datatype_get_type_registered")
__anon15bfe8050402(TVMArgs args, TVMRetValue* ret) 42     .set_body([](TVMArgs args, TVMRetValue* ret) {
43       *ret = Registry::Global()->GetTypeRegistered(args[0].operator int());
44     });
45 
Global()46 Registry* Registry::Global() {
47   static Registry inst;
48   return &inst;
49 }
50 
Register(const std::string & type_name,uint8_t type_code)51 void Registry::Register(const std::string& type_name, uint8_t type_code) {
52   CHECK(type_code >= DataType::kCustomBegin)
53       << "Please choose a type code >= DataType::kCustomBegin for custom types";
54   code_to_name_[type_code] = type_name;
55   name_to_code_[type_name] = type_code;
56 }
57 
GetTypeCode(const std::string & type_name)58 uint8_t Registry::GetTypeCode(const std::string& type_name) {
59   CHECK(name_to_code_.find(type_name) != name_to_code_.end())
60       << "Type name " << type_name << " not registered";
61   return name_to_code_[type_name];
62 }
63 
GetTypeName(uint8_t type_code)64 std::string Registry::GetTypeName(uint8_t type_code) {
65   CHECK(code_to_name_.find(type_code) != code_to_name_.end())
66       << "Type code " << static_cast<unsigned>(type_code) << " not registered";
67   return code_to_name_[type_code];
68 }
69 
GetCastLowerFunc(const std::string & target,uint8_t type_code,uint8_t src_type_code)70 const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t type_code,
71                                             uint8_t src_type_code) {
72   std::ostringstream ss;
73   ss << "tvm.datatype.lower.";
74   ss << target << ".";
75   ss << "Cast"
76      << ".";
77 
78   if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
79     ss << datatype::Registry::Global()->GetTypeName(type_code);
80   } else {
81     ss << runtime::DLDataTypeCode2Str(static_cast<DLDataTypeCode>(type_code));
82   }
83 
84   ss << ".";
85 
86   if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) {
87     ss << datatype::Registry::Global()->GetTypeName(src_type_code);
88   } else {
89     ss << runtime::DLDataTypeCode2Str(static_cast<DLDataTypeCode>(src_type_code));
90   }
91   return runtime::Registry::Get(ss.str());
92 }
93 
GetMinFunc(uint8_t type_code)94 const runtime::PackedFunc* GetMinFunc(uint8_t type_code) {
95   std::ostringstream ss;
96   ss << "tvm.datatype.min.";
97   ss << datatype::Registry::Global()->GetTypeName(type_code);
98   return runtime::Registry::Get(ss.str());
99 }
100 
GetFloatImmLowerFunc(const std::string & target,uint8_t type_code)101 const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code) {
102   std::ostringstream ss;
103   ss << "tvm.datatype.lower.";
104   ss << target;
105   ss << ".FloatImm.";
106   ss << datatype::Registry::Global()->GetTypeName(type_code);
107   return runtime::Registry::Get(ss.str());
108 }
109 
GetIntrinLowerFunc(const std::string & target,const std::string & name,uint8_t type_code)110 const runtime::PackedFunc* GetIntrinLowerFunc(const std::string& target, const std::string& name,
111                                               uint8_t type_code) {
112   std::ostringstream ss;
113   ss << "tvm.datatype.lower.";
114   ss << target;
115   ss << ".Call.intrin.";
116   ss << name;
117   ss << ".";
118   ss << datatype::Registry::Global()->GetTypeName(type_code);
119   return runtime::Registry::Get(ss.str());
120 }
121 
ConvertConstScalar(uint8_t type_code,double value)122 uint64_t ConvertConstScalar(uint8_t type_code, double value) {
123   std::ostringstream ss;
124   ss << "tvm.datatype.convertconstscalar.float.";
125   ss << datatype::Registry::Global()->GetTypeName(type_code);
126   auto make_const_scalar_func = runtime::Registry::Get(ss.str());
127   return (*make_const_scalar_func)(value).operator uint64_t();
128 }
129 
130 }  // namespace datatype
131 }  // namespace tvm
132