1#!/usr/bin/python3 -i
2#
3# Copyright (c) 2015-2021 The Khronos Group Inc.
4# Copyright (c) 2015-2021 Valve Corporation
5# Copyright (c) 2015-2021 LunarG, Inc.
6# Copyright (c) 2015-2021 Google Inc.
7#
8# Licensed under the Apache License, Version 2.0 (the "License");
9# you may not use this file except in compliance with the License.
10# You may obtain a copy of the License at
11#
12#     http://www.apache.org/licenses/LICENSE-2.0
13#
14# Unless required by applicable law or agreed to in writing, software
15# distributed under the License is distributed on an "AS IS" BASIS,
16# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17# See the License for the specific language governing permissions and
18# limitations under the License.
19#
20# Author: Mark Lobodzinski <mark@lunarg.com>
21
22import os,re,sys
23import xml.etree.ElementTree as etree
24from generator import *
25from collections import namedtuple
26from common_codegen import *
27
28#
29# DispatchTableHelperOutputGeneratorOptions - subclass of GeneratorOptions.
class DispatchTableHelperOutputGeneratorOptions(GeneratorOptions):
    """Options for DispatchTableHelperOutputGenerator.

    Registry-selection arguments (conventions through sortProcedure) are
    forwarded unchanged to GeneratorOptions. The C-declaration arguments
    (genFuncPointers, apicall, apientry, apientryp, alignFuncParam) and
    expandEnumerants are stored as attributes of the same name.

    prefixText is accepted for signature compatibility but is NOT honored:
    self.prefixText is always None. The generator's beginFile() writes its own
    file comment and copyright notice, and emits genOpts.prefixText only when
    it is truthy, so this keeps the generated header free of a second prefix.
    """
    def __init__(self,
                 conventions = None,
                 filename = None,
                 directory = '.',
                 genpath = None,
                 apiname = None,
                 profile = None,
                 versions = '.*',
                 emitversions = '.*',
                 defaultExtensions = None,
                 addExtensions = None,
                 removeExtensions = None,
                 emitExtensions = None,
                 emitSpirv = None,
                 sortProcedure = regSortFeatures,
                 prefixText = "",
                 genFuncPointers = True,
                 apicall = '',
                 apientry = '',
                 apientryp = '',
                 alignFuncParam = 0,
                 expandEnumerants = True):
        GeneratorOptions.__init__(self,
                conventions = conventions,
                filename = filename,
                directory = directory,
                genpath = genpath,
                apiname = apiname,
                profile = profile,
                versions = versions,
                emitversions = emitversions,
                defaultExtensions = defaultExtensions,
                addExtensions = addExtensions,
                removeExtensions = removeExtensions,
                emitExtensions = emitExtensions,
                emitSpirv = emitSpirv,
                sortProcedure = sortProcedure)
        # The prefixText argument is intentionally discarded (see class docstring);
        # the previous code assigned it and then immediately overwrote it with None.
        self.prefixText       = None
        self.genFuncPointers  = genFuncPointers
        self.apicall          = apicall
        self.apientry         = apientry
        self.apientryp        = apientryp
        self.alignFuncParam   = alignFuncParam
        # Previously accepted but dropped; stored so the caller's value is observable.
        self.expandEnumerants = expandEnumerants
75#
76# DispatchTableHelperOutputGenerator - subclass of OutputGenerator.
77# Generates dispatch table helper header files for LVL
class DispatchTableHelperOutputGenerator(OutputGenerator):
    """Generate dispatch table helper header based on XML element attributes.

    The header written to self.outFile contains, in order:
      * '#pragma once', genOpts.prefixText (if any), a file comment, a
        copyright notice and #includes (beginFile);
      * a no-op Stub<Name> function for every command whose first parameter
        is a handle type and which comes from a feature other than
        VK_VERSION_1_0, i.e. a later core version or an extension
        (AddCommandToDispatchList);
      * api_extension_map plus ApiParentExtensionEnabled()
        (OutputExtEnabledFunction);
      * layer_init_device_dispatch_table() and
        layer_init_instance_dispatch_table() (OutputDispatchTableHelper).
    """

    # Statement placed in a generated stub body, keyed by the stub's C return
    # type. Any other return type falls back to 'return 0;' so that a non-void
    # stub is never emitted without a return statement.
    _STUB_RETURN_STATEMENTS = {
        'void': '',
        'VkResult': 'return VK_SUCCESS;',
        'VkBool32': 'return VK_FALSE;',
    }
    _DEFAULT_STUB_RETURN_STATEMENT = 'return 0;'

    def __init__(self,
                 errFile = sys.stderr,
                 warnFile = sys.stderr,
                 diagFile = sys.stdout):
        OutputGenerator.__init__(self, errFile, warnFile, diagFile)
        # Internal state - accumulators for different inner block text
        self.instance_dispatch_list = []      # (command name, protect symbol or None) for the instance table
        self.device_dispatch_list = []        # (command name, protect symbol or None) for the device table
        self.dev_ext_stub_list = []           # Lines of generated stub functions, including #ifdef/#endif guards
        self.stub_list = []                   # [command name, feature name] for every command that has a stub
        self.extension_type = ''              # 'type' attribute of the feature/extension being processed

    def beginFile(self, genOpts):
        """Called once at the beginning of each run; writes the header preamble."""
        OutputGenerator.beginFile(self, genOpts)

        # Initialize members that require the tree
        self.handle_types = GetHandleTypes(self.registry.tree)

        write("#pragma once", file=self.outFile)
        # User-supplied prefix text, if any (list of strings)
        if genOpts.prefixText:
            for s in genOpts.prefixText:
                write(s, file=self.outFile)
        # File Comment
        file_comment = '// *** THIS FILE IS GENERATED - DO NOT EDIT ***\n'
        file_comment += '// See dispatch_helper_generator.py for modifications\n'
        write(file_comment, file=self.outFile)
        # Copyright Notice
        copyright_notice =  '/*\n'
        copyright_notice += ' * Copyright (c) 2015-2021 The Khronos Group Inc.\n'
        copyright_notice += ' * Copyright (c) 2015-2021 Valve Corporation\n'
        copyright_notice += ' * Copyright (c) 2015-2021 LunarG, Inc.\n'
        copyright_notice += ' *\n'
        copyright_notice += ' * Licensed under the Apache License, Version 2.0 (the "License");\n'
        copyright_notice += ' * you may not use this file except in compliance with the License.\n'
        copyright_notice += ' * You may obtain a copy of the License at\n'
        copyright_notice += ' *\n'
        copyright_notice += ' *     http://www.apache.org/licenses/LICENSE-2.0\n'
        copyright_notice += ' *\n'
        copyright_notice += ' * Unless required by applicable law or agreed to in writing, software\n'
        copyright_notice += ' * distributed under the License is distributed on an "AS IS" BASIS,\n'
        copyright_notice += ' * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n'
        copyright_notice += ' * See the License for the specific language governing permissions and\n'
        copyright_notice += ' * limitations under the License.\n'
        copyright_notice += ' *\n'
        copyright_notice += ' * Author: Courtney Goeltzenleuchter <courtney@LunarG.com>\n'
        copyright_notice += ' * Author: Jon Ashburn <jon@lunarg.com>\n'
        copyright_notice += ' * Author: Mark Lobodzinski <mark@lunarg.com>\n'
        copyright_notice += ' */\n'

        preamble = ''
        preamble += '#include <vulkan/vulkan.h>\n'
        preamble += '#include <vulkan/vk_layer.h>\n'
        preamble += '#include <cstring>\n'
        preamble += '#include <string>\n'
        preamble += '#include <unordered_set>\n'
        preamble += '#include <unordered_map>\n'
        preamble += '#include "vk_layer_dispatch_table.h"\n'
        preamble += '#include "vk_extension_helper.h"\n'

        write(copyright_notice, file=self.outFile)
        write(preamble, file=self.outFile)

    def endFile(self):
        """Generate the stubs, extension map and dispatch tables and write them out."""
        ext_enabled_fcn = self.OutputExtEnabledFunction()
        device_table = self.OutputDispatchTableHelper('device')
        instance_table = self.OutputDispatchTableHelper('instance')

        for stub in self.dev_ext_stub_list:
            write(stub, file=self.outFile)

        write("\n\n", file=self.outFile)
        write(ext_enabled_fcn, file=self.outFile)
        write("\n", file=self.outFile)
        write(device_table, file=self.outFile)
        write("\n", file=self.outFile)
        write(instance_table, file=self.outFile)

        # Finish processing in superclass
        OutputGenerator.endFile(self)

    def beginFeature(self, interface, emit):
        """Record the protect symbol and 'type' attribute of the feature/extension being processed."""
        OutputGenerator.beginFeature(self, interface, emit)
        # May be None (AddCommandToDispatchList only emits #ifdef guards when it is not)
        self.featureExtraProtect = GetFeatureProtect(interface)
        self.extension_type = interface.get('type')

    def genCmd(self, cmdinfo, name, alias):
        """Process one command, adding it to the appropriate dispatch list."""
        OutputGenerator.genCmd(self, cmdinfo, name, alias)

        # These commands are excluded from the generated dispatch tables
        avoid_entries = ['vkCreateInstance',
                         'vkCreateDevice']
        # The type of the first parameter decides which table the command goes in
        params = cmdinfo.elem.findall('param')
        first_param_type = self.getTypeNameTuple(params[0])[0]

        if name not in avoid_entries:
            self.AddCommandToDispatchList(name, first_param_type, self.featureExtraProtect, cmdinfo)

    def AddCommandToDispatchList(self, name, handle_type, protect, cmdinfo):
        """Add command `name` to the instance or device dispatch list.

        name        -- command name, e.g. 'vkCmdDraw'
        handle_type -- type name of the command's first parameter
        protect     -- preprocessor symbol guarding the command, or None
        cmdinfo     -- registry command info; cmdinfo.elem is its XML element

        Commands whose first parameter is not a handle type are ignored.
        Commands from any feature other than VK_VERSION_1_0 (later core
        versions and extensions) additionally get a generated stub function
        and an entry in self.stub_list.
        """
        if handle_type not in self.handle_types:
            return
        extension = "VK_VERSION" not in self.featureName
        promoted = not extension and "VK_VERSION_1_0" != self.featureName
        if promoted or extension:
            # We want feature written for all promoted entrypoints in addition to extensions
            self.stub_list.append([name, self.featureName])
            stub = self._BuildStubFunction(cmdinfo)
            if protect is not None:
                self.dev_ext_stub_list.append('#ifdef %s' % protect)
            self.dev_ext_stub_list.append(stub)
            if protect is not None:
                self.dev_ext_stub_list.append('#endif // %s' % protect)
        # vkGetInstanceProcAddr is always routed to the instance table
        if handle_type not in ('VkInstance', 'VkPhysicalDevice') and name != 'vkGetInstanceProcAddr':
            self.device_dispatch_list.append((name, protect))
        else:
            self.instance_dispatch_list.append((name, protect))

    def _BuildStubFunction(self, cmdinfo):
        """Return a C++ definition of a no-op stub with the command's signature.

        The stub is named 'Stub' followed by the command name without its 'vk'
        prefix. Its body holds the return statement chosen from
        _STUB_RETURN_STATEMENTS for the command's return type.
        """
        # makeCDecls()[1] is the PFN typedef. This parsing assumes it looks like
        #   'typedef VkResult (VKAPI_PTR *PFN_vkFoo)(VkDevice device, ...);'
        # (exact form depends on the apientryp generator option) and raises
        # ValueError if '*PFN_vk' is not present.
        typedef = self.makeCDecls(cmdinfo.elem)[1]
        pre_decl, decl = typedef.split('*PFN_vk')
        # 'typedef VkResult (VKAPI_PTR ' -> 'VkResult'
        return_type = pre_decl.replace('typedef ', '').split(' (')[0]
        return_stmt = self._STUB_RETURN_STATEMENTS.get(return_type, self._DEFAULT_STUB_RETURN_STATEMENT)
        # 'Foo)(VkDevice device, ...);' -> 'Foo(VkDevice device, ...);'
        decl = decl.replace(')(', '(')
        decl = 'static VKAPI_ATTR ' + return_type + ' VKAPI_CALL Stub' + decl
        # Turn the prototype's terminating ';' into a function body
        return decl.replace(';', ' { ' + return_stmt + ' };')

    def getTypeNameTuple(self, param):
        """Return (type, name) text of a <param> element; missing parts are ''."""
        param_type = ''
        param_name = ''
        for elem in param:
            if elem.tag == 'type':
                param_type = noneStr(elem.text)
            elif elem.tag == 'name':
                param_name = noneStr(elem.text)
        return (param_type, param_name)

    def _BuildApiExtensionMap(self):
        """Map command names to the first feature/extension that requires them.

        Visits every <feature> and then every <extensions/extension> element in
        registry order. It skips VK_VERSION_1_0, elements with
        supported="disabled", and commands whose name contains
        'EnumerateInstanceVersion'. A command is included only when the type of
        its first parameter is neither VkInstance nor VkPhysicalDevice. If the
        command's own element has no <param><type>, its alias target's is used.
        """
        tree = self.registry.tree
        api_ext = dict()
        features = tree.findall('feature') + tree.findall('extensions/extension')
        for feature in features:
            feature_name = feature.get('name')
            if 'VK_VERSION_1_0' == feature_name:
                continue
            # If feature is not yet supported, skip it
            if feature.get('supported') == 'disabled':
                continue
            for require_element in feature.findall('require'):
                for command in require_element.findall('command'):
                    command_name = command.get('name')
                    if 'EnumerateInstanceVersion' in command_name:
                        continue
                    disp_obj = tree.find("commands/command/[@name='%s']/param/type" % command_name)
                    if disp_obj is None:
                        # No parameter type found on this entry; fall back to the aliased command's
                        cmd_info = tree.find("commands/command/[@name='%s']" % command_name)
                        alias_name = cmd_info.get('alias')
                        if alias_name is not None:
                            disp_obj = tree.find("commands/command/[@name='%s']/param/type" % alias_name)
                    if disp_obj.text not in ('VkInstance', 'VkPhysicalDevice'):
                        # APIs belonging to multiple extensions keep the first one seen
                        api_ext.setdefault(command_name, feature_name)
        return api_ext

    def OutputExtEnabledFunction(self):
        """Return C++ for api_extension_map and ApiParentExtensionEnabled()."""
        api_ext = self._BuildApiExtensionMap()
        # First, write out our static data structure -- map of all APIs that are part of extensions to their extension.
        ext_fcn = 'const std::unordered_map<std::string, std::string> api_extension_map {\n'
        ext_fcn += ''.join('    {"%s", "%s"},\n' % (api, api_ext[api]) for api in sorted(api_ext))
        ext_fcn += '};\n\n'
        ext_fcn += '// Using the above code-generated map of APINames-to-parent extension names, this function will:\n'
        ext_fcn += '//   o  Determine if the API has an associated extension\n'
        ext_fcn += '//   o  If it does, determine if that extension name is present in the passed-in set of enabled_ext_names \n'
        ext_fcn += '//   If the APIname has no parent extension, OR its parent extension name is IN the set, return TRUE, else FALSE\n'
        ext_fcn += 'static inline bool ApiParentExtensionEnabled(const std::string api_name, const DeviceExtensions *device_extension_info) {\n'
        ext_fcn += '    auto has_ext = api_extension_map.find(api_name);\n'
        ext_fcn += '    // Is this API part of an extension or feature group?\n'
        ext_fcn += '    if (has_ext != api_extension_map.end()) {\n'
        ext_fcn += '        // Was the extension for this API enabled in the CreateDevice call?\n'
        ext_fcn += '        auto info = device_extension_info->get_info(has_ext->second.c_str());\n'
        ext_fcn += '        if ((!info.state) || (device_extension_info->*(info.state) != kEnabledByCreateinfo)) {\n'
        ext_fcn += '            return false;\n'
        ext_fcn += '        }\n'
        ext_fcn += '    }\n'
        ext_fcn += '    return true;\n'
        ext_fcn += '}\n'
        return ext_fcn

    def OutputDispatchTableHelper(self, table_type):
        """Return the C++ init function for the 'device' or instance dispatch table.

        Any table_type other than 'device' produces the instance table;
        table_type is also the name of the handle argument passed to gpa().
        """
        table = ''
        if table_type == 'device':
            entries = self.device_dispatch_list
            table += 'static inline void layer_init_device_dispatch_table(VkDevice device, VkLayerDispatchTable *table, PFN_vkGetDeviceProcAddr gpa) {\n'
            table += '    memset(table, 0, sizeof(*table));\n'
            table += '    // Device function pointers\n'
        else:
            entries = self.instance_dispatch_list
            table += 'static inline void layer_init_instance_dispatch_table(VkInstance instance, VkLayerInstanceDispatchTable *table, PFN_vkGetInstanceProcAddr gpa) {\n'
            table += '    memset(table, 0, sizeof(*table));\n'
            table += '    // Instance function pointers\n'

        stubbed_functions = dict(self.stub_list)
        for api_name, protect in entries:
            # Remove 'vk' from proto name
            base_name = api_name[2:]

            if protect is not None:
                table += '#ifdef %s\n' % protect

            # If we're looking for the proc we are passing in, just point the table to it.  This fixes the issue where
            # a layer overrides the function name for the loader.
            if 'device' in table_type and base_name == 'GetDeviceProcAddr':
                table += '    table->GetDeviceProcAddr = gpa;\n'
            elif 'device' not in table_type and base_name == 'GetInstanceProcAddr':
                table += '    table->GetInstanceProcAddr = gpa;\n'
            else:
                table += '    table->%s = (PFN_%s) gpa(%s, "%s");\n' % (base_name, api_name, table_type, api_name)
                # Fall back to the generated no-op stub when the lookup returns nullptr
                if api_name in stubbed_functions:
                    table += '    if (table->%s == nullptr) { table->%s = (PFN_%s)Stub%s; }\n' % (base_name, base_name, api_name, base_name)
            if protect is not None:
                table += '#endif // %s\n' % protect

        table += '}'
        return table
331