1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 // LINT_C_FILE
21 
22 /*!
23  * \file tvm/runtime/crt/func_registry.c
24  * \brief Defines implementations of generic string-based function lookup structs
25  */
26 
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <tvm/runtime/crt/func_registry.h>
30 
/*!
 * \brief Compare the registry string under a cursor against name, then move the cursor to the
 * end of that string.
 *
 * Whatever the result, on return *cursor points at the \0 that terminates the string it pointed
 * to on entry, so the caller can step one byte past it to reach the next packed string.
 *
 * \param cursor In/out: pointer to the cursor addressing the first string to compare.
 * \param name Reference string.
 * \return 0 if the string under the cursor equals name; otherwise the difference between the
 *     first pair of characters that differ (a terminating \0 counts as a character), each
 *     converted from plain char to int.
 */
int strcmp_cursor(const char** cursor, const char* name) {
  const char* scan = *cursor;
  const char* ref = name;
  int diff;

  // Walk both strings in lockstep until a mismatch or either terminator is reached.
  for (;;) {
    diff = (int)*scan - (int)*ref;
    if (diff != 0 || *scan == '\0' || *ref == '\0') {
      break;
    }
    scan++;
    ref++;
  }

  // Skip whatever remains of the cursor's string so it rests on its terminator.
  while (*scan != '\0') {
    scan++;
  }
  *cursor = scan;

  return diff;
}
62 
TVMFuncRegistry_Lookup(const TVMFuncRegistry * reg,const char * name,tvm_function_index_t * function_index)63 tvm_crt_error_t TVMFuncRegistry_Lookup(const TVMFuncRegistry* reg, const char* name,
64                                        tvm_function_index_t* function_index) {
65   tvm_function_index_t idx;
66   const char* reg_name_ptr;
67 
68   idx = 0;
69   // NOTE: reg_name_ptr starts at index 1 to skip num_funcs.
70   for (reg_name_ptr = reg->names + 1; *reg_name_ptr != '\0'; reg_name_ptr++) {
71     if (!strcmp_cursor(&reg_name_ptr, name)) {
72       *function_index = idx;
73       return kTvmErrorNoError;
74     }
75 
76     idx++;
77   }
78 
79   return kTvmErrorFunctionNameNotFound;
80 }
81 
TVMFuncRegistry_GetByIndex(const TVMFuncRegistry * reg,tvm_function_index_t function_index,TVMBackendPackedCFunc * out_func)82 tvm_crt_error_t TVMFuncRegistry_GetByIndex(const TVMFuncRegistry* reg,
83                                            tvm_function_index_t function_index,
84                                            TVMBackendPackedCFunc* out_func) {
85   uint8_t num_funcs;
86 
87   num_funcs = reg->names[0];
88   if (function_index >= num_funcs) {
89     return kTvmErrorFunctionIndexInvalid;
90   }
91 
92   *out_func = reg->funcs[function_index];
93   return kTvmErrorNoError;
94 }
95 
TVMMutableFuncRegistry_Create(TVMMutableFuncRegistry * reg,uint8_t * buffer,size_t buffer_size_bytes)96 tvm_crt_error_t TVMMutableFuncRegistry_Create(TVMMutableFuncRegistry* reg, uint8_t* buffer,
97                                               size_t buffer_size_bytes) {
98   if (buffer_size_bytes < kTvmAverageFuncEntrySizeBytes) {
99     return kTvmErrorBufferTooSmall;
100   }
101 
102   memset(reg, 0, sizeof(*reg));
103   reg->registry.names = (const char*)buffer;
104   buffer[0] = 0;  // number of functions present in buffer.
105   buffer[1] = 0;  // end of names list marker.
106 
107   // compute a guess of the average size of one entry:
108   //  - assume average function name is around ~10 bytes
109   //  - 1 byte for \0
110   //  - size of 1 function pointer
111   reg->max_functions = buffer_size_bytes / kTvmAverageFuncEntrySizeBytes;
112   reg->registry.funcs =
113       (TVMBackendPackedCFunc*)(buffer + buffer_size_bytes - reg->max_functions * sizeof(void*));
114 
115   return kTvmErrorNoError;
116 }
117 
TVMMutableFuncRegistry_Set(TVMMutableFuncRegistry * reg,const char * name,TVMBackendPackedCFunc func,int override)118 tvm_crt_error_t TVMMutableFuncRegistry_Set(TVMMutableFuncRegistry* reg, const char* name,
119                                            TVMBackendPackedCFunc func, int override) {
120   size_t idx;
121   char* reg_name_ptr;
122 
123   idx = 0;
124   // NOTE: safe to discard const qualifier here, since reg->registry.names was set from
125   // TVMMutableFuncRegistry_Create above.
126   // NOTE: reg_name_ptr starts at index 1 to skip num_funcs.
127   for (reg_name_ptr = (char*)reg->registry.names + 1; *reg_name_ptr != 0; reg_name_ptr++) {
128     if (!strcmp_cursor((const char**)&reg_name_ptr, name)) {
129       if (override == 0) {
130         return kTvmErrorFunctionAlreadyDefined;
131       }
132       ((TVMBackendPackedCFunc*)reg->registry.funcs)[idx] = func;
133       return kTvmErrorNoError;
134     }
135 
136     idx++;
137   }
138 
139   size_t name_len = strlen(name);
140   ssize_t names_bytes_remaining = ((const char*)reg->registry.funcs) - reg_name_ptr;
141   if (idx >= reg->max_functions || name_len + 1 > names_bytes_remaining) {
142     return kTvmErrorFunctionRegistryFull;
143   }
144 
145   memcpy(reg_name_ptr, name, name_len + 1);
146   reg_name_ptr += name_len + 1;
147   *reg_name_ptr = 0;
148   ((TVMBackendPackedCFunc*)reg->registry.funcs)[idx] = func;
149   ((char*)reg->registry.names)[0]++;  // increment num_funcs.
150 
151   return kTvmErrorNoError;
152 }
153