1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #include "registry.h"
21 #include <tvm/api_registry.h>
22 
23 namespace tvm {
24 namespace datatype {
25 
// FFI entry point: register a custom datatype.
//   args[0] (string): name of the custom type.
//   args[1] (int):    type code to associate with the name.
// The code is received as an int; it is validated to fit in uint8_t before the
// cast so that an out-of-range value is reported instead of silently wrapping
// around to a different (possibly valid) type code.
TVM_REGISTER_GLOBAL("_datatype_register").set_body([](TVMArgs args, TVMRetValue* ret) {
  const int raw_code = args[1].operator int();
  const uint8_t type_code = static_cast<uint8_t>(raw_code);
  // Round-trip comparison: equal only when raw_code is representable as uint8_t.
  CHECK(type_code == raw_code) << "Type code " << raw_code << " does not fit in uint8_t";
  datatype::Registry::Global()->Register(args[0], type_code);
});
29 
// FFI entry point: look up the type code registered under a type name.
//   args[0] (string): name of the custom type.
// Returns the code through `ret`; GetTypeCode fails a CHECK when the name is
// not registered.
TVM_REGISTER_GLOBAL("_datatype_get_type_code").set_body([](TVMArgs args, TVMRetValue* ret) {
  const std::string type_name = args[0].operator std::string();
  auto* registry = datatype::Registry::Global();
  *ret = registry->GetTypeCode(type_name);
});
33 
// FFI entry point: look up the type name registered under a type code.
//   args[0] (int): type code to look up.
// Returns the name through `ret`. The code arrives as an int and is validated
// to fit in uint8_t first, so an out-of-range value cannot wrap around and
// resolve to an unrelated registered type. GetTypeName itself fails a CHECK
// when the code is not registered.
TVM_REGISTER_GLOBAL("_datatype_get_type_name").set_body([](TVMArgs args, TVMRetValue* ret) {
  const int raw_code = args[0].operator int();
  const uint8_t type_code = static_cast<uint8_t>(raw_code);
  // Round-trip comparison: equal only when raw_code is representable as uint8_t.
  CHECK(type_code == raw_code) << "Type code " << raw_code << " does not fit in uint8_t";
  *ret = Registry::Global()->GetTypeName(type_code);
});
37 
// FFI entry point: query whether a type code has been registered.
//   args[0] (int): type code to query.
// The result of Registry::GetTypeRegistered is returned through `ret`.
TVM_REGISTER_GLOBAL("_datatype_get_type_registered").set_body([](TVMArgs args, TVMRetValue* ret) {
  const int type_code = args[0].operator int();
  auto* registry = Registry::Global();
  *ret = registry->GetTypeRegistered(type_code);
});
41 
Global()42 Registry* Registry::Global() {
43   static Registry inst;
44   return &inst;
45 }
46 
Register(const std::string & type_name,uint8_t type_code)47 void Registry::Register(const std::string& type_name, uint8_t type_code) {
48   CHECK(type_code >= kCustomBegin) << "Please choose a type code >= kCustomBegin for custom types";
49   code_to_name_[type_code] = type_name;
50   name_to_code_[type_name] = type_code;
51 }
52 
GetTypeCode(const std::string & type_name)53 uint8_t Registry::GetTypeCode(const std::string& type_name) {
54   CHECK(name_to_code_.find(type_name) != name_to_code_.end())
55       << "Type name " << type_name << " not registered";
56   return name_to_code_[type_name];
57 }
58 
GetTypeName(uint8_t type_code)59 std::string Registry::GetTypeName(uint8_t type_code) {
60   CHECK(code_to_name_.find(type_code) != code_to_name_.end())
61       << "Type code " << static_cast<unsigned>(type_code) << " not registered";
62   return code_to_name_[type_code];
63 }
64 
GetCastLowerFunc(const std::string & target,uint8_t type_code,uint8_t src_type_code)65 const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t type_code,
66                                             uint8_t src_type_code) {
67   std::ostringstream ss;
68   ss << "tvm.datatype.lower.";
69   ss << target << ".";
70   ss << "Cast"
71      << ".";
72 
73   if (datatype::Registry::Global()->GetTypeRegistered(type_code)) {
74     ss << datatype::Registry::Global()->GetTypeName(type_code);
75   } else {
76     ss << runtime::TypeCode2Str(type_code);
77   }
78 
79   ss << ".";
80 
81   if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) {
82     ss << datatype::Registry::Global()->GetTypeName(src_type_code);
83   } else {
84     ss << runtime::TypeCode2Str(src_type_code);
85   }
86 
87   return runtime::Registry::Get(ss.str());
88 }
89 
GetFloatImmLowerFunc(const std::string & target,uint8_t type_code)90 const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code) {
91   std::ostringstream ss;
92   ss << "tvm.datatype.lower.";
93   ss << target;
94   ss << ".FloatImm.";
95   ss << datatype::Registry::Global()->GetTypeName(type_code);
96   return runtime::Registry::Get(ss.str());
97 }
98 
ConvertConstScalar(uint8_t type_code,double value)99 uint64_t ConvertConstScalar(uint8_t type_code, double value) {
100   std::ostringstream ss;
101   ss << "tvm.datatype.convertconstscalar.float.";
102   ss << datatype::Registry::Global()->GetTypeName(type_code);
103   auto make_const_scalar_func = runtime::Registry::Get(ss.str());
104   return (*make_const_scalar_func)(value).operator uint64_t();
105 }
106 
107 }  // namespace datatype
108 }  // namespace tvm
109