1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Defines functions to work with TVMModule FuncRegistry."""
19
20import json
21
22
def graph_json_to_c_func_registry(graph_path, func_registry_path):
    """Convert a graph json file to a CRT-compatible FuncRegistry.

    Reads the graph JSON, collects the ``func_name`` attribute of every node whose
    ``op`` is ``"tvm_op"`` (each distinct name once, in first-occurrence order), and
    writes a C source file that declares those functions, lists them in a
    ``TVMFuncRegistry`` and exposes it via ``TVMSystemLibEntryPoint()``.

    Parameters
    ----------
    graph_path : str
        Path to the graph JSON file.

    func_registry_path : str
        Path to a .c file which will be written containing the function registry.

    Raises
    ------
    ValueError
        If the graph references more than 255 distinct functions, which cannot be
        represented by the single-character function count that prefixes the
        encoded name table. Nothing is written in that case.
    """
    with open(graph_path, encoding="utf-8") as json_f:
        graph = json.load(json_f)

    # Multiple nodes may reference the same function; register each name only once
    # (dict preserves insertion order, so first-occurrence order is kept).
    funcs = list(
        dict.fromkeys(
            node["attrs"]["func_name"] for node in graph["nodes"] if node["op"] == "tvm_op"
        )
    )

    # The name table is prefixed by the function count written as one 3-digit C octal
    # escape, i.e. a single char. Values above 0o377 (255) would produce an
    # out-of-range or mis-parsed escape, so reject them instead of emitting bad C.
    if len(funcs) > 0o377:
        raise ValueError(
            f"graph references {len(funcs)} distinct functions; at most 255 can be "
            "encoded in the function registry"
        )

    encoded_funcs = f"\\{len(funcs):03o}" + "\\0".join(funcs)
    lines = [
        "#include <tvm/runtime/c_runtime_api.h>",
        "#include <tvm/runtime/crt/module.h>",
        "#include <stdio.h>",
        "",
    ]

    lines.extend(
        f"extern int {func_name}(TVMValue* args, int* type_codes, int num_args, "
        "TVMValue* out_ret_value, int* out_ret_tcode, void* resource_handle);"
        for func_name in funcs
    )

    lines.append("static TVMBackendPackedCFunc funcs[] = {")
    if funcs:
        lines.extend(f"    (TVMBackendPackedCFunc) &{func_name}," for func_name in funcs)
    else:
        # ISO C (before C23) rejects an empty initializer list / zero-length array.
        # The placeholder should never be called since the encoded count is 0
        # (assumes the runtime honours that count - confirm against CRT lookup).
        lines.append("    (TVMBackendPackedCFunc) NULL,")

    lines += [
        "};",
        "static const TVMFuncRegistry system_lib_registry = {",
        f'    "{encoded_funcs}\\0",',
        "    funcs,",
        "};",
        "static const TVMModule system_lib = {",
        "    &system_lib_registry,",
        "};",
        "",
        "const TVMModule* TVMSystemLibEntryPoint(void) {",
        "    return &system_lib;",
        "}",
        "",  # blank line to end the file
    ]
    with open(func_registry_path, "w", encoding="utf-8") as wrapper_f:
        wrapper_f.write("\n".join(lines))
80