1 /*
2  Copyright 2017-2018 Google Inc.
3 
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License.
6  You may obtain a copy of the License at
7 
8  http://www.apache.org/licenses/LICENSE-2.0
9 
10  Unless required by applicable law or agreed to in writing, software
11  distributed under the License is distributed on an "AS IS" BASIS,
12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  See the License for the specific language governing permissions and
14  limitations under the License.
15 */
16 
#include "spirv_reflect.h"
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <string.h>

#if defined(WIN32)
  #define _CRTDBG_MAP_ALLOC
  #include <stdlib.h>
  #include <crtdbg.h>
#else
  #include <stdlib.h>
#endif
29 
// Local definitions of SPIR-V opcodes and decorations used by this file that
// the spirv.h header it is built against may not provide yet (temporary,
// until these make it into the SPIR-V/Vulkan headers). The values are the
// SPIR-V enumerants of the instructions/decorations named in the comments.
// clang-format off
enum {
  SpvReflectOpDecorateId                      = 332,   // OpDecorateId
  SpvReflectOpDecorateStringGOOGLE            = 5632,  // OpDecorateStringGOOGLE
  SpvReflectOpMemberDecorateStringGOOGLE      = 5633,  // OpMemberDecorateStringGOOGLE
  SpvReflectDecorationHlslCounterBufferGOOGLE = 5634,  // HlslCounterBufferGOOGLE decoration
  SpvReflectDecorationHlslSemanticGOOGLE      = 5635   // HlslSemanticGOOGLE decoration
};
// clang-format on
40 
// Layout constants for the SPIR-V word stream.
// clang-format off
enum {
  // Index of the first instruction word. ParseNodes starts scanning here,
  // after the 5-word SPIR-V module header.
  SPIRV_STARTING_WORD_INDEX       = 5,
  // Size in bytes of one SPIR-V word.
  SPIRV_WORD_SIZE                 = sizeof(uint32_t),
  // Presumably bits per byte; not referenced in the code shown here.
  SPIRV_BYTE_WIDTH                = 8,
  // Smallest code size in bytes that CreateParser accepts (header only).
  SPIRV_MINIMUM_FILE_SIZE         = SPIRV_STARTING_WORD_INDEX * SPIRV_WORD_SIZE,
  // 16 bytes; presumably an alignment for data blocks -- not referenced in
  // the code shown here.
  SPIRV_DATA_ALIGNMENT            = 4 * SPIRV_WORD_SIZE, // 16
  // Words preceding the first index operand of OpAccessChain:
  // [opcode/word count, result type id, result id, base id].
  SPIRV_ACCESS_CHAIN_INDEX_OFFSET = 4,
};
// clang-format on
51 
// clang-format off
enum {
  // "Not set" sentinel. ParseNodes presets opcodes, storage classes,
  // several decoration values, source_file_id and its function-tracking
  // index to this value.
  // NOTE(review): 0xFFFFFFFF is not representable as an int, which ISO C
  // requires of enumeration constants; compilers commonly accept it as an
  // extension.
  INVALID_VALUE  = 0xFFFFFFFF,
};
// clang-format on

// clang-format off
enum {
  // Presumably a limit on node name length; not referenced in the code
  // shown here.
  MAX_NODE_NAME_LENGTH      = 1024,
};
// clang-format on

// clang-format off
// Values of the OpTypeImage "Sampled" operand (ImageTraits::sampled):
// 1 = used with a sampler, 2 = used without a sampler (storage image).
enum {
  IMAGE_SAMPLED = 1,
  IMAGE_STORAGE = 2
};
// clang-format on
70 
// Operands of OpTypeArray / OpTypeRuntimeArray, as read by ParseNodes.
// clang-format off
typedef struct ArrayTraits {
  uint32_t              element_type_id;  // Element Type <id>
  uint32_t              length_id;        // Length <id> (OpTypeArray only; 0 otherwise)
} ArrayTraits;
// clang-format on

// Operands of OpTypeImage, in instruction order, as read by ParseNodes.
// clang-format off
typedef struct ImageTraits {
  uint32_t              sampled_type_id;  // Sampled Type <id>
  SpvDim                dim;
  uint32_t              depth;
  uint32_t              arrayed;
  uint32_t              ms;               // multisampled operand
  uint32_t              sampled;          // compare with IMAGE_SAMPLED / IMAGE_STORAGE
  SpvImageFormat        image_format;
} ImageTraits;
// clang-format on
89 
// A numeric decoration value plus the location it came from.
// clang-format off
typedef struct NumberDecoration {
  // Presumably the word index of the decoration operand in the SPIR-V code;
  // not assigned in the code shown here -- TODO confirm.
  uint32_t              word_offset;
  // Decoration value. For node decorations, ParseNodes presets set,
  // binding, location, offset and uav_counter_buffer to INVALID_VALUE.
  uint32_t              value;
} NumberDecoration;
// clang-format on

// A string decoration value plus the location it came from.
// clang-format off
typedef struct StringDecoration {
  // Presumably the word index of the decoration operand (see NumberDecoration).
  uint32_t              word_offset;
  // Presumably points at the literal inside the SPIR-V code -- TODO confirm.
  const char*           value;
} StringDecoration;
// clang-format on
103 
// Decorations collected for one id (Node::decorations) or one struct member
// (Node::member_decorations). For node decorations, ParseNodes presets set,
// binding, location, offset, uav_counter_buffer and built_in to
// INVALID_VALUE; every other field starts zeroed by calloc.
// clang-format off
typedef struct Decorations {
  // Boolean decorations; ApplyDecorations maps each one to the matching
  // SPV_REFLECT_DECORATION_* flag.
  bool                  is_block;
  bool                  is_buffer_block;
  bool                  is_row_major;
  bool                  is_column_major;
  bool                  is_built_in;
  bool                  is_noperspective;
  bool                  is_flat;
  bool                  is_non_writable;
  NumberDecoration      set;
  NumberDecoration      binding;
  NumberDecoration      input_attachment_index;  // not preset to INVALID_VALUE (starts at 0)
  NumberDecoration      location;
  NumberDecoration      offset;
  NumberDecoration      uav_counter_buffer;      // presumably from HlslCounterBufferGOOGLE
  StringDecoration      semantic;                // presumably from HlslSemanticGOOGLE
  uint32_t              array_stride;
  uint32_t              matrix_stride;
  SpvBuiltIn            built_in;
} Decorations;
// clang-format on
126 
// One SPIR-V instruction, as recorded by ParseNodes (one Node per
// instruction). Fields that do not apply to the opcode stay zero from
// calloc, except those ParseNodes presets to INVALID_VALUE (op,
// storage_class and several decorations).
// clang-format off
typedef struct Node {
  uint32_t              result_id;       // Result <id>, for opcodes that have one
  SpvOp                 op;              // low 16 bits of the first instruction word
  uint32_t              result_type_id;  // Result Type <id> (constants, spec constants, OpLoad)
  uint32_t              type_id;         // OpTypePointer pointee type; OpVariable Result Type
  SpvStorageClass       storage_class;   // OpTypePointer / OpTypeForwardPointer / OpVariable
  uint32_t              word_offset;     // index of the instruction's first word in spirv_code
  uint32_t              word_count;      // instruction length in words
  bool                  is_type;         // set for type-declaring opcodes

  ArrayTraits           array_traits;    // OpTypeArray / OpTypeRuntimeArray
  ImageTraits           image_traits;    // OpTypeImage
  uint32_t              image_type_id;   // OpTypeSampledImage: Image Type <id>

  // For OpName/OpMemberName nodes: points at the name literal inside
  // spirv_code (not a copy). Other passes may set it as well -- not shown.
  const char*           name;
  Decorations           decorations;
  uint32_t              member_count;    // OpTypeStruct: word_count - 2
  const char**          member_names;    // heap array, freed by DestroyParser
  Decorations*          member_decorations; // heap array, freed by DestroyParser
} Node;
// clang-format on
149 
// An OpString instruction, as collected by ParseStrings.
// clang-format off
typedef struct String {
  uint32_t              result_id;  // Result <id> of the OpString
  const char*           string;     // literal inside spirv_code (not a copy)
} String;
// clang-format on
156 
// Data for one function definition. Presumably filled by ParseFunction
// (only its first lines are visible here). The three arrays are heap
// allocated; DestroyParser frees them.
// clang-format off
typedef struct Function {
  uint32_t              id;                 // result id of the OpFunction node
  uint32_t              callee_count;
  uint32_t*             callees;            // presumably ids of called functions
  struct Function**     callee_ptrs;        // presumably resolved entries for callees
  uint32_t              accessed_ptr_count;
  uint32_t*             accessed_ptrs;      // presumably ids of pointers the function accesses
} Function;
// clang-format on
167 
// An OpAccessChain instruction, as recorded by ParseNodes.
// clang-format off
typedef struct AccessChain {
  uint32_t              result_id;
  uint32_t              result_type_id;
  //
  // Points to the base of a composite object; typically the id of a
  // descriptor block variable.
  uint32_t              base_id;
  //
  // From spec:
  //   The first index in Indexes will select the
  //   top-level member/element/component/element
  //   of the base composite
  uint32_t              index_count;
  // Heap array of literal index values (freed by DestroyParser). ParseNodes
  // resolves each index id to its OpConstant's value; an index whose id is
  // not an OpConstant is left as 0.
  uint32_t*             indexes;
} AccessChain;
// clang-format on
185 
// State shared by the parsing passes over one SPIR-V module.
// clang-format off
typedef struct Parser {
  size_t                spirv_word_count;        // length of spirv_code in words
  uint32_t*             spirv_code;              // caller's buffer, set by CreateParser (not freed)
  uint32_t              string_count;            // number of OpString instructions
  String*               strings;                 // heap array built by ParseStrings
  SpvSourceLanguage     source_language;         // from OpSource
  uint32_t              source_language_version; // from OpSource
  uint32_t              source_file_id;          // OpSource File <id>, or INVALID_VALUE
  String                source_embedded;         // not assigned in the code shown here
  size_t                node_count;              // one node per instruction
  Node*                 nodes;                   // heap array built by ParseNodes
  uint32_t              entry_point_count;       // number of OpEntryPoint instructions
  uint32_t              function_count;          // definitions: OpFunction followed by OpLabel
  Function*             functions;
  uint32_t              access_chain_count;      // number of OpAccessChain instructions
  AccessChain*          access_chains;           // heap array built by ParseNodes

  uint32_t              type_count;              // nodes with is_type set
  uint32_t              descriptor_count;        // not assigned in the code shown here
  uint32_t              push_constant_count;     // not assigned in the code shown here
} Parser;
// clang-format on
209 
// Returns the larger of two unsigned 32-bit values.
static uint32_t Max(uint32_t a, uint32_t b)
{
  if (b > a) {
    return b;
  }
  return a;
}
214 
// Rounds `value` up to the next multiple of `multiple`, which must be a
// non-zero power of two (checked by assert). The arithmetic is unsigned, so
// a value within (multiple - 1) of UINT32_MAX wraps around to a small result.
static uint32_t RoundUp(uint32_t value, uint32_t multiple)
{
  const uint32_t mask = multiple - 1;
  assert(multiple != 0 && (multiple & mask) == 0);
  return (value + mask) & ~mask;
}
220 
// Null-pointer tests. The argument is parenthesized so that an expression
// with lower precedence than == / != is compared as a whole.
#define IsNull(ptr) \
  ((ptr) == NULL)

#define IsNotNull(ptr) \
  ((ptr) != NULL)

// Frees `ptr` and resets it to NULL, so a second SafeFree on the same
// pointer is harmless. `ptr` must be an assignable lvalue and is evaluated
// twice. free(NULL) is a no-op, so no NULL test is needed. The body is
// wrapped in do { } while (0) so the macro behaves as a single statement,
// e.g. in `if (c) SafeFree(p); else ...`.
#define SafeFree(ptr)     \
  do {                    \
    free((void*)(ptr));   \
    (ptr) = NULL;         \
  } while (0)
234 
// qsort() comparator that orders uint32_t values ascending as unsigned
// numbers. Returns a negative value, zero or a positive value when *a is
// less than, equal to or greater than *b.
//
// The values are compared rather than subtracted. Casting values above
// INT_MAX to int and subtracting them both mis-orders them and can overflow
// a signed int (undefined behavior). That ordering would also disagree with
// the unsigned comparisons used by SearchSortedUint32 and
// IntersectSortedUint32.
static int SortCompareUint32(const void* a, const void* b)
{
  const uint32_t lhs = *(const uint32_t*)a;
  const uint32_t rhs = *(const uint32_t*)b;
  return (lhs > rhs) - (lhs < rhs);
}
242 
//
// Removes adjacent duplicates from arr in place and returns the number of
// distinct values kept; they occupy arr[0 .. result).
//
// The array does not strictly need to be sorted, only grouped into "runs"
// in which equal values are adjacent. arr is not accessed when size is 0.
//
static size_t DedupSortedUint32(uint32_t* arr, size_t size)
{
  if (size == 0) {
    return 0;
  }
  // arr[0 .. unique_count) is the de-duplicated prefix; the first element
  // always survives.
  size_t unique_count = 1;
  for (size_t read_idx = 1; read_idx < size; ++read_idx) {
    const uint32_t candidate = arr[read_idx];
    if (candidate != arr[unique_count - 1]) {
      arr[unique_count] = candidate;
      ++unique_count;
    }
  }
  return unique_count;
}
264 
// Binary search for `target` in arr[0 .. size), which must be sorted in
// ascending order. Returns true if it is present. arr is not read when size
// is 0.
static bool SearchSortedUint32(const uint32_t* arr, size_t size, uint32_t target)
{
  size_t low = 0;
  size_t high = size;  // exclusive
  while (high > low) {
    // Computed as low + half-width so that (low + high) cannot overflow.
    const size_t mid = low + (high - low) / 2;
    const uint32_t probe = arr[mid];
    if (probe < target) {
      low = mid + 1;
    } else if (probe > target) {
      high = mid;
    } else {
      return true;
    }
  }
  return false;
}
281 
IntersectSortedUint32(const uint32_t * p_arr0,size_t arr0_size,const uint32_t * p_arr1,size_t arr1_size,uint32_t ** pp_res,size_t * res_size)282 static SpvReflectResult IntersectSortedUint32(
283   const uint32_t* p_arr0,
284   size_t          arr0_size,
285   const uint32_t* p_arr1,
286   size_t          arr1_size,
287   uint32_t**      pp_res,
288   size_t*         res_size
289 )
290 {
291   *res_size = 0;
292   const uint32_t* arr0_end = p_arr0 + arr0_size;
293   const uint32_t* arr1_end = p_arr1 + arr1_size;
294 
295   const uint32_t* idx0 = p_arr0;
296   const uint32_t* idx1 = p_arr1;
297   while (idx0 != arr0_end && idx1 != arr1_end) {
298     if (*idx0 < *idx1) {
299       ++idx0;
300     } else if (*idx0 > *idx1) {
301       ++idx1;
302     } else {
303       ++*res_size;
304       ++idx0;
305       ++idx1;
306     }
307   }
308 
309   *pp_res = NULL;
310   if (*res_size > 0) {
311     *pp_res = (uint32_t*)calloc(*res_size, sizeof(**pp_res));
312     if (IsNull(*pp_res)) {
313       return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
314     }
315     uint32_t* idxr = *pp_res;
316     idx0 = p_arr0;
317     idx1 = p_arr1;
318     while (idx0 != arr0_end && idx1 != arr1_end) {
319       if (*idx0 < *idx1) {
320         ++idx0;
321       } else if (*idx0 > *idx1) {
322         ++idx1;
323       } else {
324         *(idxr++) = *idx0;
325         ++idx0;
326         ++idx1;
327       }
328     }
329   }
330   return SPV_REFLECT_RESULT_SUCCESS;
331 }
332 
333 
InRange(const Parser * p_parser,uint32_t index)334 static bool InRange(const Parser* p_parser, uint32_t index)
335 {
336   bool in_range = false;
337   if (IsNotNull(p_parser)) {
338     in_range = (index < p_parser->spirv_word_count);
339   }
340   return in_range;
341 }
342 
ReadU32(Parser * p_parser,uint32_t word_offset,uint32_t * p_value)343 static SpvReflectResult ReadU32(Parser* p_parser, uint32_t word_offset, uint32_t* p_value)
344 {
345   assert(IsNotNull(p_parser));
346   assert(IsNotNull(p_parser->spirv_code));
347   assert(InRange(p_parser, word_offset));
348   SpvReflectResult result = SPV_REFLECT_RESULT_ERROR_SPIRV_UNEXPECTED_EOF;
349   if (IsNotNull(p_parser) && IsNotNull(p_parser->spirv_code) && InRange(p_parser, word_offset)) {
350     *p_value = *(p_parser->spirv_code + word_offset);
351     result = SPV_REFLECT_RESULT_SUCCESS;
352   }
353   return result;
354 }
355 
// Helpers that read one SPIR-V word through ReadU32().
//
// Every macro expands to a single statement (do { } while (0)), so it
// composes safely with if/else. Arguments are parenthesized. `value` is
// assigned only when the read succeeds.

// Reads the word at word_offset into `value` (written through a uint32_t
// pointer, so `value` should be a 32-bit object). On failure, returns the
// error from the *enclosing* function.
#define CHECKED_READU32(parser, word_offset, value)                          \
  do {                                                                       \
    SpvReflectResult checked_readu32_result = ReadU32((parser),              \
                                                      (word_offset),         \
                                                      (uint32_t*)&(value));  \
    if (checked_readu32_result != SPV_REFLECT_RESULT_SUCCESS) {              \
      return checked_readu32_result;                                         \
    }                                                                        \
  } while (0)

// Reads the word at word_offset and assigns it to `value` converted to
// cast_to_type. On failure, returns the error from the enclosing function.
#define CHECKED_READU32_CAST(parser, word_offset, cast_to_type, value)       \
  do {                                                                       \
    uint32_t checked_readu32_cast_u32 = UINT32_MAX;                          \
    SpvReflectResult checked_readu32_cast_result = ReadU32((parser),         \
                                      (word_offset),                         \
                                      &checked_readu32_cast_u32);            \
    if (checked_readu32_cast_result != SPV_REFLECT_RESULT_SUCCESS) {         \
      return checked_readu32_cast_result;                                    \
    }                                                                        \
    (value) = (cast_to_type)checked_readu32_cast_u32;                        \
  } while (0)

// If `result` is still SPV_REFLECT_RESULT_SUCCESS, reads the word at
// word_offset into `value` and stores the ReadU32 status in `result`.
// Lets a sequence of reads stop at the first failure without returning.
#define IF_READU32(result, parser, word_offset, value)                       \
  do {                                                                       \
    if ((result) == SPV_REFLECT_RESULT_SUCCESS) {                            \
      (result) = ReadU32((parser), (word_offset), (uint32_t*)&(value));      \
    }                                                                        \
  } while (0)

// Like IF_READU32, but converts the word to cast_to_type before assigning.
#define IF_READU32_CAST(result, parser, word_offset, cast_to_type, value)    \
  do {                                                                       \
    if ((result) == SPV_REFLECT_RESULT_SUCCESS) {                            \
      uint32_t if_readu32_cast_u32 = UINT32_MAX;                             \
      (result) = ReadU32((parser), (word_offset), &if_readu32_cast_u32);     \
      if ((result) == SPV_REFLECT_RESULT_SUCCESS) {                          \
        (value) = (cast_to_type)if_readu32_cast_u32;                         \
      }                                                                      \
    }                                                                        \
  } while (0)
390 
ReadStr(Parser * p_parser,uint32_t word_offset,uint32_t word_index,uint32_t word_count,uint32_t * p_buf_size,char * p_buf)391 static SpvReflectResult ReadStr(
392   Parser*   p_parser,
393   uint32_t  word_offset,
394   uint32_t  word_index,
395   uint32_t  word_count,
396   uint32_t* p_buf_size,
397   char*     p_buf
398 )
399 {
400   uint32_t limit = (word_offset + word_count);
401   assert(IsNotNull(p_parser));
402   assert(IsNotNull(p_parser->spirv_code));
403   assert(InRange(p_parser, limit));
404   SpvReflectResult result = SPV_REFLECT_RESULT_ERROR_SPIRV_UNEXPECTED_EOF;
405   if (IsNotNull(p_parser) && IsNotNull(p_parser->spirv_code) && InRange(p_parser, limit)) {
406     const char* c_str = (const char*)(p_parser->spirv_code + word_offset + word_index);
407     uint32_t n = word_count * SPIRV_WORD_SIZE;
408     uint32_t length_with_terminator = 0;
409     for (uint32_t i = 0; i < n; ++i) {
410       char c = *(c_str + i);
411       if (c == 0) {
412         length_with_terminator = i + 1;
413         break;
414       }
415     }
416 
417     if (length_with_terminator > 0) {
418       result = SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
419       if (IsNotNull(p_buf_size) && IsNotNull(p_buf)) {
420         result = SPV_REFLECT_RESULT_ERROR_RANGE_EXCEEDED;
421         if (length_with_terminator <= *p_buf_size) {
422           memset(p_buf, 0, *p_buf_size);
423           memcpy(p_buf, c_str, length_with_terminator);
424           result = SPV_REFLECT_RESULT_SUCCESS;
425         }
426       }
427       else {
428         if (IsNotNull(p_buf_size)) {
429           *p_buf_size = length_with_terminator;
430           result = SPV_REFLECT_RESULT_SUCCESS;
431         }
432       }
433     }
434   }
435   return result;
436 }
437 
ApplyDecorations(const Decorations * p_decoration_fields)438 static SpvReflectDecorationFlags ApplyDecorations(const Decorations* p_decoration_fields)
439 {
440   SpvReflectDecorationFlags decorations = SPV_REFLECT_DECORATION_NONE;
441   if (p_decoration_fields->is_block) {
442     decorations |= SPV_REFLECT_DECORATION_BLOCK;
443   }
444   if (p_decoration_fields->is_buffer_block) {
445     decorations |= SPV_REFLECT_DECORATION_BUFFER_BLOCK;
446   }
447   if (p_decoration_fields->is_row_major) {
448     decorations |= SPV_REFLECT_DECORATION_ROW_MAJOR;
449   }
450   if (p_decoration_fields->is_column_major) {
451     decorations |= SPV_REFLECT_DECORATION_COLUMN_MAJOR;
452   }
453   if (p_decoration_fields->is_built_in) {
454     decorations |= SPV_REFLECT_DECORATION_BUILT_IN;
455   }
456   if (p_decoration_fields->is_noperspective) {
457     decorations |= SPV_REFLECT_DECORATION_NOPERSPECTIVE;
458   }
459   if (p_decoration_fields->is_flat) {
460     decorations |= SPV_REFLECT_DECORATION_FLAT;
461   }
462   if (p_decoration_fields->is_non_writable) {
463     decorations |= SPV_REFLECT_DECORATION_NON_WRITABLE;
464   }
465   return decorations;
466 }
467 
ApplyNumericTraits(const SpvReflectTypeDescription * p_type,SpvReflectNumericTraits * p_numeric_traits)468 static void ApplyNumericTraits(const SpvReflectTypeDescription* p_type, SpvReflectNumericTraits* p_numeric_traits)
469 {
470   memcpy(p_numeric_traits, &p_type->traits.numeric, sizeof(p_type->traits.numeric));
471 }
472 
ApplyArrayTraits(const SpvReflectTypeDescription * p_type,SpvReflectArrayTraits * p_array_traits)473 static void ApplyArrayTraits(const SpvReflectTypeDescription* p_type, SpvReflectArrayTraits* p_array_traits)
474 {
475   memcpy(p_array_traits, &p_type->traits.array, sizeof(p_type->traits.array));
476 }
477 
FindNode(Parser * p_parser,uint32_t result_id)478 static Node* FindNode(Parser* p_parser, uint32_t result_id)
479 {
480   Node* p_node = NULL;
481   for (size_t i = 0; i < p_parser->node_count; ++i) {
482     Node* p_elem = &(p_parser->nodes[i]);
483     if (p_elem->result_id == result_id) {
484       p_node = p_elem;
485       break;
486     }
487   }
488   return p_node;
489 }
490 
FindType(SpvReflectShaderModule * p_module,uint32_t type_id)491 static SpvReflectTypeDescription* FindType(SpvReflectShaderModule* p_module, uint32_t type_id)
492 {
493   SpvReflectTypeDescription* p_type = NULL;
494   for (size_t i = 0; i < p_module->_internal->type_description_count; ++i) {
495     SpvReflectTypeDescription* p_elem = &(p_module->_internal->type_descriptions[i]);
496     if (p_elem->id == type_id) {
497       p_type = p_elem;
498       break;
499     }
500   }
501   return p_type;
502 }
503 
CreateParser(size_t size,void * p_code,Parser * p_parser)504 static SpvReflectResult CreateParser(size_t size, void* p_code, Parser* p_parser)
505 {
506   if (p_code == NULL) {
507     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
508   }
509 
510   if (size < SPIRV_MINIMUM_FILE_SIZE) {
511     return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_CODE_SIZE;
512   }
513   if ((size % 4) != 0) {
514     return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_CODE_SIZE;
515   }
516 
517   p_parser->spirv_word_count = size / SPIRV_WORD_SIZE;
518   p_parser->spirv_code = (uint32_t*)p_code;
519 
520   if (p_parser->spirv_code[0] != SpvMagicNumber) {
521     return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_MAGIC_NUMBER;
522   }
523 
524   return SPV_REFLECT_RESULT_SUCCESS;
525 }
526 
// Releases every heap allocation owned by the parser and resets the freed
// pointers to NULL and node_count to 0, so calling it again is harmless.
// This covers per-node member arrays, per-function callee/access arrays,
// per-access-chain index arrays, and the node/string/function/access-chain
// arrays themselves. The spirv_code buffer belongs to the caller and is
// not freed.
//
// Each array is checked for NULL before its elements are visited. A count
// can be non-zero while its array is NULL when that array's allocation
// failed; for example, ParseNodes sets access_chain_count before it
// allocates access_chains. Indexing the NULL array in that case would crash.
static void DestroyParser(Parser* p_parser)
{
  // Free per-node member arrays.
  if (IsNotNull(p_parser->nodes)) {
    for (size_t i = 0; i < p_parser->node_count; ++i) {
      Node* p_node = &(p_parser->nodes[i]);
      SafeFree(p_node->member_names);
      SafeFree(p_node->member_decorations);
    }
  }

  // Free per-function arrays.
  if (IsNotNull(p_parser->functions)) {
    for (size_t i = 0; i < p_parser->function_count; ++i) {
      SafeFree(p_parser->functions[i].callees);
      SafeFree(p_parser->functions[i].callee_ptrs);
      SafeFree(p_parser->functions[i].accessed_ptrs);
    }
  }

  // Free per-access-chain index arrays.
  if (IsNotNull(p_parser->access_chains)) {
    for (uint32_t i = 0; i < p_parser->access_chain_count; ++i) {
      SafeFree(p_parser->access_chains[i].indexes);
    }
  }

  SafeFree(p_parser->nodes);
  SafeFree(p_parser->strings);
  SafeFree(p_parser->functions);
  SafeFree(p_parser->access_chains);
  p_parser->node_count = 0;
}
560 
ParseNodes(Parser * p_parser)561 static SpvReflectResult ParseNodes(Parser* p_parser)
562 {
563   assert(IsNotNull(p_parser));
564   assert(IsNotNull(p_parser->spirv_code));
565 
566   uint32_t* p_spirv = p_parser->spirv_code;
567   uint32_t spirv_word_index = SPIRV_STARTING_WORD_INDEX;
568 
569   // Count nodes
570   uint32_t node_count = 0;
571   while (spirv_word_index < p_parser->spirv_word_count) {
572     uint32_t word = p_spirv[spirv_word_index];
573     SpvOp op = (SpvOp)(word & 0xFFFF);
574     uint32_t node_word_count = (word >> 16) & 0xFFFF;
575     if (node_word_count == 0) {
576       return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_INSTRUCTION;
577     }
578     if (op == SpvOpAccessChain) {
579       ++(p_parser->access_chain_count);
580     }
581     spirv_word_index += node_word_count;
582     ++node_count;
583   }
584 
585   if (node_count == 0) {
586     return SPV_REFLECT_RESULT_ERROR_SPIRV_UNEXPECTED_EOF;
587   }
588 
589   // Allocate nodes
590   p_parser->node_count = node_count;
591   p_parser->nodes = (Node*)calloc(p_parser->node_count, sizeof(*(p_parser->nodes)));
592   if (IsNull(p_parser->nodes)) {
593     return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
594   }
595   // Mark all nodes with an invalid state
596   for (uint32_t i = 0; i < node_count; ++i) {
597     p_parser->nodes[i].op = (SpvOp)INVALID_VALUE;
598     p_parser->nodes[i].storage_class = (SpvStorageClass)INVALID_VALUE;
599     p_parser->nodes[i].decorations.set.value = (uint32_t)INVALID_VALUE;
600     p_parser->nodes[i].decorations.binding.value = (uint32_t)INVALID_VALUE;
601     p_parser->nodes[i].decorations.location.value = (uint32_t)INVALID_VALUE;
602     p_parser->nodes[i].decorations.offset.value = (uint32_t)INVALID_VALUE;
603     p_parser->nodes[i].decorations.uav_counter_buffer.value = (uint32_t)INVALID_VALUE;
604     p_parser->nodes[i].decorations.built_in = (SpvBuiltIn)INVALID_VALUE;
605   }
606   // Mark source file id node
607   p_parser->source_file_id = (uint32_t)INVALID_VALUE;
608 
609   // Function node
610   uint32_t function_node = (uint32_t)INVALID_VALUE;
611 
612   // Allocate access chain
613   if (p_parser->access_chain_count > 0) {
614     p_parser->access_chains = (AccessChain*)calloc(p_parser->access_chain_count, sizeof(*(p_parser->access_chains)));
615     if (IsNull(p_parser->access_chains)) {
616       return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
617     }
618   }
619 
620   // Parse nodes
621   uint32_t node_index = 0;
622   uint32_t access_chain_index = 0;
623   spirv_word_index = SPIRV_STARTING_WORD_INDEX;
624   while (spirv_word_index < p_parser->spirv_word_count) {
625     uint32_t word = p_spirv[spirv_word_index];
626     SpvOp op = (SpvOp)(word & 0xFFFF);
627     uint32_t node_word_count = (word >> 16) & 0xFFFF;
628 
629     Node* p_node = &(p_parser->nodes[node_index]);
630     p_node->op = op;
631     p_node->word_offset = spirv_word_index;
632     p_node->word_count = node_word_count;
633 
634     switch (p_node->op) {
635       default: break;
636 
637       case SpvOpString: {
638         ++(p_parser->string_count);
639       }
640       break;
641 
642       case SpvOpSource: {
643         CHECKED_READU32_CAST(p_parser, p_node->word_offset + 1, SpvSourceLanguage, p_parser->source_language);
644         CHECKED_READU32(p_parser, p_node->word_offset + 2, p_parser->source_language_version);
645         if (p_node->word_count >= 4) {
646           CHECKED_READU32(p_parser, p_node->word_offset + 3, p_parser->source_file_id);
647         }
648       }
649       break;
650 
651       case SpvOpEntryPoint: {
652         ++(p_parser->entry_point_count);
653       }
654       break;
655 
656       case SpvOpName:
657       case SpvOpMemberName:
658       {
659         uint32_t member_offset = (p_node->op == SpvOpMemberName) ? 1 : 0;
660         uint32_t name_start = p_node->word_offset + member_offset + 2;
661         p_node->name = (const char*)(p_parser->spirv_code + name_start);
662       }
663       break;
664 
665       case SpvOpTypeStruct:
666       {
667         p_node->member_count = p_node->word_count - 2;
668       } // Fall through
669       case SpvOpTypeVoid:
670       case SpvOpTypeBool:
671       case SpvOpTypeInt:
672       case SpvOpTypeFloat:
673       case SpvOpTypeVector:
674       case SpvOpTypeMatrix:
675       case SpvOpTypeSampler:
676       case SpvOpTypeOpaque:
677       case SpvOpTypeFunction:
678       case SpvOpTypeEvent:
679       case SpvOpTypeDeviceEvent:
680       case SpvOpTypeReserveId:
681       case SpvOpTypeQueue:
682       case SpvOpTypePipe:
683       {
684         CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_id);
685         p_node->is_type = true;
686       }
687       break;
688 
689       case SpvOpTypeImage: {
690         CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_id);
691         CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->image_traits.sampled_type_id);
692         CHECKED_READU32(p_parser, p_node->word_offset + 3, p_node->image_traits.dim);
693         CHECKED_READU32(p_parser, p_node->word_offset + 4, p_node->image_traits.depth);
694         CHECKED_READU32(p_parser, p_node->word_offset + 5, p_node->image_traits.arrayed);
695         CHECKED_READU32(p_parser, p_node->word_offset + 6, p_node->image_traits.ms);
696         CHECKED_READU32(p_parser, p_node->word_offset + 7, p_node->image_traits.sampled);
697         CHECKED_READU32(p_parser, p_node->word_offset + 8, p_node->image_traits.image_format);
698         p_node->is_type = true;
699       }
700       break;
701 
702       case SpvOpTypeSampledImage: {
703         CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_id);
704         CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->image_type_id);
705         p_node->is_type = true;
706       }
707       break;
708 
709       case SpvOpTypeArray:  {
710         CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_id);
711         CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->array_traits.element_type_id);
712         CHECKED_READU32(p_parser, p_node->word_offset + 3, p_node->array_traits.length_id);
713         p_node->is_type = true;
714       }
715       break;
716 
717       case SpvOpTypeRuntimeArray:  {
718         CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_id);
719         CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->array_traits.element_type_id);
720         p_node->is_type = true;
721       }
722       break;
723 
724       case SpvOpTypePointer: {
725         CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_id);
726         CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->storage_class);
727         CHECKED_READU32(p_parser, p_node->word_offset + 3, p_node->type_id);
728         p_node->is_type = true;
729       }
730       break;
731 
732       case SpvOpTypeForwardPointer:
733       {
734         CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_id);
735         CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->storage_class);
736         p_node->is_type = true;
737       }
738       break;
739 
740       case SpvOpConstantTrue:
741       case SpvOpConstantFalse:
742       case SpvOpConstant:
743       case SpvOpConstantComposite:
744       case SpvOpConstantSampler:
745       case SpvOpConstantNull: {
746         CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_type_id);
747         CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
748       }
749       break;
750 
751       case SpvOpSpecConstantTrue:
752       case SpvOpSpecConstantFalse:
753       case SpvOpSpecConstant:
754       case SpvOpSpecConstantComposite:
755       case SpvOpSpecConstantOp: {
756         CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_type_id);
757         CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
758       }
759       break;
760 
761       case SpvOpVariable:
762       {
763         CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->type_id);
764         CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
765         CHECKED_READU32(p_parser, p_node->word_offset + 3, p_node->storage_class);
766       }
767       break;
768 
769       case SpvOpLoad:
770       {
771         // Only load enough so OpDecorate can reference the node, skip the remaining operands.
772         CHECKED_READU32(p_parser, p_node->word_offset + 1, p_node->result_type_id);
773         CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
774       }
775       break;
776 
777       case SpvOpAccessChain:
778       {
779         AccessChain* p_access_chain = &(p_parser->access_chains[access_chain_index]);
780         CHECKED_READU32(p_parser, p_node->word_offset + 1, p_access_chain->result_type_id);
781         CHECKED_READU32(p_parser, p_node->word_offset + 2, p_access_chain->result_id);
782         CHECKED_READU32(p_parser, p_node->word_offset + 3, p_access_chain->base_id);
783         //
784         // SPIRV_ACCESS_CHAIN_INDEX_OFFSET (4) is the number of words up until the first index:
785         //   [Node, Result Type Id, Result Id, Base Id, <Indexes>]
786         //
787         p_access_chain->index_count = (node_word_count - SPIRV_ACCESS_CHAIN_INDEX_OFFSET);
788         if (p_access_chain->index_count > 0) {
789           p_access_chain->indexes = (uint32_t*)calloc(p_access_chain->index_count, sizeof(*(p_access_chain->indexes)));
790           if (IsNull( p_access_chain->indexes)) {
791             return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
792           }
793           // Parse any index values for access chain
794           for (uint32_t index_index = 0; index_index < p_access_chain->index_count; ++index_index) {
795             // Read index id
796             uint32_t index_id = 0;
797             CHECKED_READU32(p_parser, p_node->word_offset + SPIRV_ACCESS_CHAIN_INDEX_OFFSET + index_index, index_id);
798             // Find OpConstant node that contains index value
799             Node* p_index_value_node = FindNode(p_parser, index_id);
800             if ((p_index_value_node != NULL) && (p_index_value_node->op == SpvOpConstant)) {
801               // Read index value
802               uint32_t index_value = UINT32_MAX;
803               CHECKED_READU32(p_parser, p_index_value_node->word_offset + 3, index_value);
804               assert(index_value != UINT32_MAX);
805               // Write index value to array
806               p_access_chain->indexes[index_index] = index_value;
807             }
808           }
809         }
810         ++access_chain_index;
811       }
812       break;
813 
814       case SpvOpFunction:
815       {
816         CHECKED_READU32(p_parser, p_node->word_offset + 2, p_node->result_id);
817         // Count function definitions, not function declarations.  To determine
818         // the difference, set an in-function variable, and then if an OpLabel
819         // is reached before the end of the function increment the function
820         // count.
821         function_node = node_index;
822       }
823       break;
824 
825       case SpvOpLabel:
826       {
827         if (function_node != (uint32_t)INVALID_VALUE) {
828           Node* p_func_node = &(p_parser->nodes[function_node]);
829           CHECKED_READU32(p_parser, p_func_node->word_offset + 2, p_func_node->result_id);
830           ++(p_parser->function_count);
831         }
832       } // Fall through
833 
834       case SpvOpFunctionEnd:
835       {
836         function_node = (uint32_t)INVALID_VALUE;
837       }
838       break;
839     }
840 
841     if (p_node->is_type) {
842       ++(p_parser->type_count);
843     }
844 
845     spirv_word_index += node_word_count;
846     ++node_index;
847   }
848 
849   return SPV_REFLECT_RESULT_SUCCESS;
850 }
851 
ParseStrings(Parser * p_parser)852 static SpvReflectResult ParseStrings(Parser* p_parser)
853 {
854   assert(IsNotNull(p_parser));
855   assert(IsNotNull(p_parser->spirv_code));
856   assert(IsNotNull(p_parser->nodes));
857 
858   // Early out
859   if (p_parser->string_count == 0) {
860     return SPV_REFLECT_RESULT_SUCCESS;
861   }
862 
863   if (IsNotNull(p_parser) && IsNotNull(p_parser->spirv_code) && IsNotNull(p_parser->nodes)) {
864     // Allocate string storage
865     p_parser->strings = (String*)calloc(p_parser->string_count, sizeof(*(p_parser->strings)));
866 
867     uint32_t string_index = 0;
868     for (size_t i = 0; i < p_parser->node_count; ++i) {
869       Node* p_node = &(p_parser->nodes[i]);
870       if (p_node->op != SpvOpString) {
871         continue;
872       }
873 
874       // Paranoid check against string count
875       assert(string_index < p_parser->string_count);
876       if (string_index >= p_parser->string_count) {
877         return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
878       }
879 
880       // Result id
881       String* p_string = &(p_parser->strings[string_index]);
882       CHECKED_READU32(p_parser, p_node->word_offset + 1, p_string->result_id);
883 
884       // String
885       uint32_t string_start = p_node->word_offset + 2;
886       p_string->string = (const char*)(p_parser->spirv_code + string_start);
887 
888       // Increment string index
889       ++string_index;
890     }
891   }
892 
893   return SPV_REFLECT_RESULT_SUCCESS;
894 }
895 
ParseSource(Parser * p_parser,SpvReflectShaderModule * p_module)896 static SpvReflectResult ParseSource(Parser* p_parser, SpvReflectShaderModule* p_module)
897 {
898   assert(IsNotNull(p_parser));
899   assert(IsNotNull(p_parser->spirv_code));
900 
901   if (IsNotNull(p_parser) && IsNotNull(p_parser->spirv_code)) {
902     // Source file
903     if (IsNotNull(p_parser->strings)) {
904       for (uint32_t i = 0; i < p_parser->string_count; ++i) {
905         String* p_string = &(p_parser->strings[i]);
906         if (p_string->result_id == p_parser->source_file_id) {
907           p_module->source_file = p_string->string;
908           break;
909         }
910       }
911     }
912   }
913 
914   return SPV_REFLECT_RESULT_SUCCESS;
915 }
916 
ParseFunction(Parser * p_parser,Node * p_func_node,Function * p_func,size_t first_label_index)917 static SpvReflectResult ParseFunction(Parser* p_parser, Node* p_func_node, Function* p_func, size_t first_label_index)
918 {
919   p_func->id = p_func_node->result_id;
920 
921   p_func->callee_count = 0;
922   p_func->accessed_ptr_count = 0;
923 
924   for (size_t i = first_label_index; i < p_parser->node_count; ++i) {
925     Node* p_node = &(p_parser->nodes[i]);
926     if (p_node->op == SpvOpFunctionEnd) {
927       break;
928     }
929     switch (p_node->op) {
930       case SpvOpFunctionCall: {
931         ++(p_func->callee_count);
932       }
933       break;
934       case SpvOpLoad:
935       case SpvOpAccessChain:
936       case SpvOpInBoundsAccessChain:
937       case SpvOpPtrAccessChain:
938       case SpvOpArrayLength:
939       case SpvOpGenericPtrMemSemantics:
940       case SpvOpInBoundsPtrAccessChain:
941       case SpvOpStore:
942       {
943         ++(p_func->accessed_ptr_count);
944       }
945       break;
946       case SpvOpCopyMemory:
947       case SpvOpCopyMemorySized:
948       {
949         p_func->accessed_ptr_count += 2;
950       }
951       break;
952       default: break;
953     }
954   }
955 
956   if (p_func->callee_count > 0) {
957     p_func->callees = (uint32_t*)calloc(p_func->callee_count,
958                                         sizeof(*(p_func->callees)));
959     if (IsNull(p_func->callees)) {
960       return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
961     }
962   }
963 
964   if (p_func->accessed_ptr_count > 0) {
965     p_func->accessed_ptrs = (uint32_t*)calloc(p_func->accessed_ptr_count,
966                                               sizeof(*(p_func->accessed_ptrs)));
967     if (IsNull(p_func->accessed_ptrs)) {
968       return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
969     }
970   }
971 
972   p_func->callee_count = 0;
973   p_func->accessed_ptr_count = 0;
974   for (size_t i = first_label_index; i < p_parser->node_count; ++i) {
975     Node* p_node = &(p_parser->nodes[i]);
976     if (p_node->op == SpvOpFunctionEnd) {
977       break;
978     }
979     switch (p_node->op) {
980       case SpvOpFunctionCall: {
981         CHECKED_READU32(p_parser, p_node->word_offset + 3,
982                         p_func->callees[p_func->callee_count]);
983         (++p_func->callee_count);
984       }
985       break;
986       case SpvOpLoad:
987       case SpvOpAccessChain:
988       case SpvOpInBoundsAccessChain:
989       case SpvOpPtrAccessChain:
990       case SpvOpArrayLength:
991       case SpvOpGenericPtrMemSemantics:
992       case SpvOpInBoundsPtrAccessChain:
993       {
994         CHECKED_READU32(p_parser, p_node->word_offset + 3,
995                         p_func->accessed_ptrs[p_func->accessed_ptr_count]);
996         (++p_func->accessed_ptr_count);
997       }
998       break;
999       case SpvOpStore:
1000       {
1001         CHECKED_READU32(p_parser, p_node->word_offset + 2,
1002                         p_func->accessed_ptrs[p_func->accessed_ptr_count]);
1003         (++p_func->accessed_ptr_count);
1004       }
1005       break;
1006       case SpvOpCopyMemory:
1007       case SpvOpCopyMemorySized:
1008       {
1009         CHECKED_READU32(p_parser, p_node->word_offset + 2,
1010                         p_func->accessed_ptrs[p_func->accessed_ptr_count]);
1011         (++p_func->accessed_ptr_count);
1012         CHECKED_READU32(p_parser, p_node->word_offset + 3,
1013                         p_func->accessed_ptrs[p_func->accessed_ptr_count]);
1014         (++p_func->accessed_ptr_count);
1015       }
1016       break;
1017       default: break;
1018     }
1019   }
1020 
1021   if (p_func->callee_count > 0) {
1022     qsort(p_func->callees, p_func->callee_count,
1023           sizeof(*(p_func->callees)), SortCompareUint32);
1024   }
1025   p_func->callee_count = (uint32_t)DedupSortedUint32(p_func->callees,
1026                                                      p_func->callee_count);
1027 
1028   if (p_func->accessed_ptr_count > 0) {
1029     qsort(p_func->accessed_ptrs, p_func->accessed_ptr_count,
1030           sizeof(*(p_func->accessed_ptrs)), SortCompareUint32);
1031   }
1032   p_func->accessed_ptr_count = (uint32_t)DedupSortedUint32(p_func->accessed_ptrs,
1033                                                            p_func->accessed_ptr_count);
1034 
1035   return SPV_REFLECT_RESULT_SUCCESS;
1036 }
1037 
SortCompareFunctions(const void * a,const void * b)1038 static int SortCompareFunctions(const void* a, const void* b)
1039 {
1040   const Function* af = (const Function*)a;
1041   const Function* bf = (const Function*)b;
1042   return (int)af->id - (int)bf->id;
1043 }
1044 
ParseFunctions(Parser * p_parser)1045 static SpvReflectResult ParseFunctions(Parser* p_parser)
1046 {
1047   assert(IsNotNull(p_parser));
1048   assert(IsNotNull(p_parser->spirv_code));
1049   assert(IsNotNull(p_parser->nodes));
1050 
1051   if (IsNotNull(p_parser) && IsNotNull(p_parser->spirv_code) && IsNotNull(p_parser->nodes)) {
1052     if (p_parser->function_count == 0) {
1053       return SPV_REFLECT_RESULT_SUCCESS;
1054     }
1055 
1056     p_parser->functions = (Function*)calloc(p_parser->function_count,
1057                                             sizeof(*(p_parser->functions)));
1058     if (IsNull(p_parser->functions)) {
1059       return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
1060     }
1061 
1062     size_t function_index = 0;
1063     for (size_t i = 0; i < p_parser->node_count; ++i) {
1064       Node* p_node = &(p_parser->nodes[i]);
1065       if (p_node->op != SpvOpFunction) {
1066         continue;
1067       }
1068 
1069       // Skip over function declarations that aren't definitions
1070       bool func_definition = false;
1071       // Intentionally reuse i to avoid iterating over these nodes more than
1072       // once
1073       for (; i < p_parser->node_count; ++i) {
1074         if (p_parser->nodes[i].op == SpvOpLabel) {
1075           func_definition = true;
1076           break;
1077         }
1078         if (p_parser->nodes[i].op == SpvOpFunctionEnd) {
1079           break;
1080         }
1081       }
1082       if (!func_definition) {
1083         continue;
1084       }
1085 
1086       Function* p_function = &(p_parser->functions[function_index]);
1087 
1088       SpvReflectResult result = ParseFunction(p_parser, p_node, p_function, i);
1089       if (result != SPV_REFLECT_RESULT_SUCCESS) {
1090         return result;
1091       }
1092 
1093       ++function_index;
1094     }
1095 
1096     qsort(p_parser->functions, p_parser->function_count,
1097           sizeof(*(p_parser->functions)), SortCompareFunctions);
1098 
1099     // Once they're sorted, link the functions with pointers to improve graph
1100     // traversal efficiency
1101     for (size_t i = 0; i < p_parser->function_count; ++i) {
1102       Function* p_func = &(p_parser->functions[i]);
1103       if (p_func->callee_count == 0) {
1104         continue;
1105       }
1106       p_func->callee_ptrs = (Function**)calloc(p_func->callee_count,
1107                                                sizeof(*(p_func->callee_ptrs)));
1108       for (size_t j = 0, k = 0; j < p_func->callee_count; ++j) {
1109         while (p_parser->functions[k].id != p_func->callees[j]) {
1110           ++k;
1111           if (k >= p_parser->function_count) {
1112             // Invalid called function ID somewhere
1113             return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1114           }
1115         }
1116         p_func->callee_ptrs[j] = &(p_parser->functions[k]);
1117       }
1118     }
1119   }
1120 
1121   return SPV_REFLECT_RESULT_SUCCESS;
1122 }
1123 
ParseMemberCounts(Parser * p_parser)1124 static SpvReflectResult ParseMemberCounts(Parser* p_parser)
1125 {
1126   assert(IsNotNull(p_parser));
1127   assert(IsNotNull(p_parser->spirv_code));
1128   assert(IsNotNull(p_parser->nodes));
1129 
1130   if (IsNotNull(p_parser) && IsNotNull(p_parser->spirv_code) && IsNotNull(p_parser->nodes)) {
1131     for (size_t i = 0; i < p_parser->node_count; ++i) {
1132       Node* p_node = &(p_parser->nodes[i]);
1133       if ((p_node->op != SpvOpMemberName) && (p_node->op != SpvOpMemberDecorate)) {
1134         continue;
1135       }
1136 
1137       uint32_t target_id = 0;
1138       uint32_t member_index = (uint32_t)INVALID_VALUE;
1139       CHECKED_READU32(p_parser, p_node->word_offset + 1, target_id);
1140       CHECKED_READU32(p_parser, p_node->word_offset + 2, member_index);
1141       Node* p_target_node = FindNode(p_parser, target_id);
1142       // Not all nodes get parsed, so FindNode returning NULL is expected.
1143       if (IsNull(p_target_node)) {
1144         continue;
1145       }
1146 
1147       if (member_index == INVALID_VALUE) {
1148         return SPV_REFLECT_RESULT_ERROR_RANGE_EXCEEDED;
1149       }
1150 
1151       p_target_node->member_count = Max(p_target_node->member_count, member_index + 1);
1152     }
1153 
1154     for (uint32_t i = 0; i < p_parser->node_count; ++i) {
1155       Node* p_node = &(p_parser->nodes[i]);
1156       if (p_node->member_count == 0) {
1157         continue;
1158       }
1159 
1160       p_node->member_names = (const char **)calloc(p_node->member_count, sizeof(*(p_node->member_names)));
1161       if (IsNull(p_node->member_names)) {
1162         return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
1163       }
1164 
1165       p_node->member_decorations = (Decorations*)calloc(p_node->member_count, sizeof(*(p_node->member_decorations)));
1166       if (IsNull(p_node->member_decorations)) {
1167         return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
1168       }
1169     }
1170   }
1171   return SPV_REFLECT_RESULT_SUCCESS;
1172 }
1173 
ParseNames(Parser * p_parser)1174 static SpvReflectResult ParseNames(Parser* p_parser)
1175 {
1176   assert(IsNotNull(p_parser));
1177   assert(IsNotNull(p_parser->spirv_code));
1178   assert(IsNotNull(p_parser->nodes));
1179 
1180   if (IsNotNull(p_parser) && IsNotNull(p_parser->spirv_code) && IsNotNull(p_parser->nodes)) {
1181     for (size_t i = 0; i < p_parser->node_count; ++i) {
1182       Node* p_node = &(p_parser->nodes[i]);
1183       if ((p_node->op != SpvOpName) && (p_node->op != SpvOpMemberName)) {
1184         continue;
1185       }
1186 
1187       uint32_t target_id = 0;
1188       CHECKED_READU32(p_parser, p_node->word_offset + 1, target_id);
1189       Node* p_target_node = FindNode(p_parser, target_id);
1190       // Not all nodes get parsed, so FindNode returning NULL is expected.
1191       if (IsNull(p_target_node)) {
1192         continue;
1193       }
1194 
1195       const char** pp_target_name = &(p_target_node->name);
1196       if (p_node->op == SpvOpMemberName) {
1197         uint32_t member_index = UINT32_MAX;
1198         CHECKED_READU32(p_parser, p_node->word_offset + 2, member_index);
1199         pp_target_name = &(p_target_node->member_names[member_index]);
1200       }
1201 
1202       *pp_target_name = p_node->name;
1203     }
1204   }
1205   return SPV_REFLECT_RESULT_SUCCESS;
1206 }
1207 
ParseDecorations(Parser * p_parser)1208 static SpvReflectResult ParseDecorations(Parser* p_parser)
1209 {
1210   for (uint32_t i = 0; i < p_parser->node_count; ++i) {
1211     Node* p_node = &(p_parser->nodes[i]);
1212 
1213     if (((uint32_t)p_node->op != (uint32_t)SpvOpDecorate) &&
1214         ((uint32_t)p_node->op != (uint32_t)SpvOpMemberDecorate) &&
1215         ((uint32_t)p_node->op != (uint32_t)SpvReflectOpDecorateId) &&
1216         ((uint32_t)p_node->op != (uint32_t)SpvReflectOpDecorateStringGOOGLE) &&
1217         ((uint32_t)p_node->op != (uint32_t)SpvReflectOpMemberDecorateStringGOOGLE))
1218     {
1219      continue;
1220     }
1221 
1222     // Need to adjust the read offset if this is a member decoration
1223     uint32_t member_offset = 0;
1224     if (p_node->op == SpvOpMemberDecorate) {
1225       member_offset = 1;
1226     }
1227 
1228     // Get decoration
1229     uint32_t decoration = (uint32_t)INVALID_VALUE;
1230     CHECKED_READU32(p_parser, p_node->word_offset + member_offset + 2, decoration);
1231 
1232     // Filter out the decoration that do not affect reflection, otherwise
1233     // there will be random crashes because the nodes aren't found.
1234     bool skip = false;
1235     switch (decoration) {
1236       default: {
1237         skip = true;
1238       }
1239       break;
1240       case SpvDecorationBlock:
1241       case SpvDecorationBufferBlock:
1242       case SpvDecorationColMajor:
1243       case SpvDecorationRowMajor:
1244       case SpvDecorationArrayStride:
1245       case SpvDecorationMatrixStride:
1246       case SpvDecorationBuiltIn:
1247       case SpvDecorationNoPerspective:
1248       case SpvDecorationFlat:
1249       case SpvDecorationNonWritable:
1250       case SpvDecorationLocation:
1251       case SpvDecorationBinding:
1252       case SpvDecorationDescriptorSet:
1253       case SpvDecorationOffset:
1254       case SpvDecorationInputAttachmentIndex:
1255       case SpvReflectDecorationHlslCounterBufferGOOGLE:
1256       case SpvReflectDecorationHlslSemanticGOOGLE: {
1257         skip = false;
1258       }
1259       break;
1260     }
1261     if (skip) {
1262       continue;
1263     }
1264 
1265     // Find target target node
1266     uint32_t target_id = 0;
1267     CHECKED_READU32(p_parser, p_node->word_offset + 1, target_id);
1268     Node* p_target_node = FindNode(p_parser, target_id);
1269     if (IsNull(p_target_node)) {
1270       return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1271     }
1272     // Get decorations
1273     Decorations* p_target_decorations = &(p_target_node->decorations);
1274     // Update pointer if this is a member member decoration
1275     if (p_node->op == SpvOpMemberDecorate) {
1276       uint32_t member_index = (uint32_t)INVALID_VALUE;
1277       CHECKED_READU32(p_parser, p_node->word_offset + 2, member_index);
1278       p_target_decorations = &(p_target_node->member_decorations[member_index]);
1279     }
1280 
1281     switch (decoration) {
1282       default: break;
1283 
1284       case SpvDecorationBlock: {
1285         p_target_decorations->is_block = true;
1286       }
1287       break;
1288 
1289       case SpvDecorationBufferBlock: {
1290         p_target_decorations->is_buffer_block = true;
1291       }
1292       break;
1293 
1294       case SpvDecorationColMajor: {
1295         p_target_decorations->is_column_major = true;
1296       }
1297       break;
1298 
1299       case SpvDecorationRowMajor: {
1300         p_target_decorations->is_row_major = true;
1301       }
1302       break;
1303 
1304       case SpvDecorationArrayStride: {
1305         uint32_t word_offset = p_node->word_offset + member_offset + 3;
1306         CHECKED_READU32(p_parser, word_offset, p_target_decorations->array_stride);
1307       }
1308       break;
1309 
1310       case SpvDecorationMatrixStride: {
1311         uint32_t word_offset = p_node->word_offset + member_offset + 3;
1312         CHECKED_READU32(p_parser, word_offset, p_target_decorations->matrix_stride);
1313       }
1314       break;
1315 
1316       case SpvDecorationBuiltIn: {
1317         p_target_decorations->is_built_in = true;
1318         uint32_t word_offset = p_node->word_offset + member_offset + 3;
1319         CHECKED_READU32_CAST(p_parser, word_offset, SpvBuiltIn, p_target_decorations->built_in);
1320       }
1321       break;
1322 
1323       case SpvDecorationNoPerspective: {
1324         p_target_decorations->is_noperspective = true;
1325       }
1326       break;
1327 
1328       case SpvDecorationFlat: {
1329         p_target_decorations->is_flat = true;
1330       }
1331       break;
1332 
1333       case SpvDecorationNonWritable: {
1334         p_target_decorations->is_non_writable = true;
1335       }
1336       break;
1337 
1338       case SpvDecorationLocation: {
1339         uint32_t word_offset = p_node->word_offset + member_offset + 3;
1340         CHECKED_READU32(p_parser, word_offset, p_target_decorations->location.value);
1341         p_target_decorations->location.word_offset = word_offset;
1342       }
1343       break;
1344 
1345       case SpvDecorationBinding: {
1346         uint32_t word_offset = p_node->word_offset + member_offset+ 3;
1347         CHECKED_READU32(p_parser, word_offset, p_target_decorations->binding.value);
1348         p_target_decorations->binding.word_offset = word_offset;
1349       }
1350       break;
1351 
1352       case SpvDecorationDescriptorSet: {
1353         uint32_t word_offset = p_node->word_offset + member_offset+ 3;
1354         CHECKED_READU32(p_parser, word_offset, p_target_decorations->set.value);
1355         p_target_decorations->set.word_offset = word_offset;
1356       }
1357       break;
1358 
1359       case SpvDecorationOffset: {
1360         uint32_t word_offset = p_node->word_offset + member_offset+ 3;
1361         CHECKED_READU32(p_parser, word_offset, p_target_decorations->offset.value);
1362         p_target_decorations->offset.word_offset = word_offset;
1363       }
1364       break;
1365 
1366       case SpvDecorationInputAttachmentIndex: {
1367         uint32_t word_offset = p_node->word_offset + member_offset+ 3;
1368         CHECKED_READU32(p_parser, word_offset, p_target_decorations->input_attachment_index.value);
1369         p_target_decorations->input_attachment_index.word_offset = word_offset;
1370       }
1371       break;
1372 
1373       case SpvReflectDecorationHlslCounterBufferGOOGLE: {
1374         uint32_t word_offset = p_node->word_offset + member_offset+ 3;
1375         CHECKED_READU32(p_parser, word_offset, p_target_decorations->uav_counter_buffer.value);
1376         p_target_decorations->uav_counter_buffer.word_offset = word_offset;
1377       }
1378       break;
1379 
1380       case SpvReflectDecorationHlslSemanticGOOGLE: {
1381         uint32_t word_offset = p_node->word_offset + member_offset + 3;
1382         p_target_decorations->semantic.value = (const char*)(p_parser->spirv_code + word_offset);
1383         p_target_decorations->semantic.word_offset = word_offset;
1384       }
1385       break;
1386     }
1387   }
1388   return SPV_REFLECT_RESULT_SUCCESS;
1389 }
1390 
EnumerateAllUniforms(SpvReflectShaderModule * p_module,size_t * p_uniform_count,uint32_t ** pp_uniforms)1391 static SpvReflectResult EnumerateAllUniforms(
1392   SpvReflectShaderModule* p_module,
1393   size_t*                 p_uniform_count,
1394   uint32_t**              pp_uniforms
1395 )
1396 {
1397   *p_uniform_count = p_module->descriptor_binding_count;
1398   if (*p_uniform_count == 0) {
1399     return SPV_REFLECT_RESULT_SUCCESS;
1400   }
1401   *pp_uniforms = (uint32_t*)calloc(*p_uniform_count, sizeof(**pp_uniforms));
1402 
1403   if (IsNull(*pp_uniforms)) {
1404     return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
1405   }
1406 
1407   for (size_t i = 0; i < *p_uniform_count; ++i) {
1408     (*pp_uniforms)[i] = p_module->descriptor_bindings[i].spirv_id;
1409   }
1410   qsort(*pp_uniforms, *p_uniform_count, sizeof(**pp_uniforms),
1411         SortCompareUint32);
1412   return SPV_REFLECT_RESULT_SUCCESS;
1413 }
1414 
ParseType(Parser * p_parser,Node * p_node,Decorations * p_struct_member_decorations,SpvReflectShaderModule * p_module,SpvReflectTypeDescription * p_type)1415 static SpvReflectResult ParseType(
1416   Parser*                     p_parser,
1417   Node*                       p_node,
1418   Decorations*                p_struct_member_decorations,
1419   SpvReflectShaderModule*     p_module,
1420   SpvReflectTypeDescription*  p_type
1421 )
1422 {
1423   SpvReflectResult result = SPV_REFLECT_RESULT_SUCCESS;
1424 
1425   if (p_node->member_count > 0) {
1426     p_type->member_count = p_node->member_count;
1427     p_type->members = (SpvReflectTypeDescription*)calloc(p_type->member_count, sizeof(*(p_type->members)));
1428     if (IsNotNull(p_type->members)) {
1429       // Mark all members types with an invalid state
1430       for (size_t i = 0; i < p_type->members->member_count; ++i) {
1431         SpvReflectTypeDescription* p_member_type = &(p_type->members[i]);
1432         p_member_type->id = (uint32_t)INVALID_VALUE;
1433         p_member_type->op = (SpvOp)INVALID_VALUE;
1434         p_member_type->storage_class = (SpvStorageClass)INVALID_VALUE;
1435       }
1436     }
1437     else {
1438       result = SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
1439     }
1440   }
1441 
1442   if (result == SPV_REFLECT_RESULT_SUCCESS) {
1443     // Since the parse descends on type information, these will get overwritten
1444     // if not guarded against assignment. Only assign if the id is invalid.
1445     if (p_type->id == INVALID_VALUE) {
1446       p_type->id = p_node->result_id;
1447       p_type->op = p_node->op;
1448       p_type->decoration_flags = 0;
1449     }
1450     // Top level types need to pick up decorations from all types below it.
1451     // Issue and fix here: https://github.com/chaoticbob/SPIRV-Reflect/issues/64
1452     p_type->decoration_flags = ApplyDecorations(&p_node->decorations);
1453 
1454     switch (p_node->op) {
1455       default: break;
1456       case SpvOpTypeVoid:
1457         p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_VOID;
1458         break;
1459 
1460       case SpvOpTypeBool:
1461         p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_BOOL;
1462         break;
1463 
1464       case SpvOpTypeInt: {
1465         p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_INT;
1466         IF_READU32(result, p_parser, p_node->word_offset + 2, p_type->traits.numeric.scalar.width);
1467         IF_READU32(result, p_parser, p_node->word_offset + 3, p_type->traits.numeric.scalar.signedness);
1468       }
1469       break;
1470 
1471       case SpvOpTypeFloat: {
1472         p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_FLOAT;
1473         IF_READU32(result, p_parser, p_node->word_offset + 2, p_type->traits.numeric.scalar.width);
1474       }
1475       break;
1476 
1477       case SpvOpTypeVector: {
1478         p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_VECTOR;
1479         uint32_t component_type_id = (uint32_t)INVALID_VALUE;
1480         IF_READU32(result, p_parser, p_node->word_offset + 2, component_type_id);
1481         IF_READU32(result, p_parser, p_node->word_offset + 3, p_type->traits.numeric.vector.component_count);
1482         // Parse component type
1483         Node* p_next_node = FindNode(p_parser, component_type_id);
1484         if (IsNotNull(p_next_node)) {
1485           result = ParseType(p_parser, p_next_node, NULL, p_module, p_type);
1486         }
1487         else {
1488           result = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1489         }
1490       }
1491       break;
1492 
1493       case SpvOpTypeMatrix: {
1494         p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_MATRIX;
1495         uint32_t column_type_id = (uint32_t)INVALID_VALUE;
1496         IF_READU32(result, p_parser, p_node->word_offset + 2, column_type_id);
1497         IF_READU32(result, p_parser, p_node->word_offset + 3, p_type->traits.numeric.matrix.column_count);
1498         Node* p_next_node = FindNode(p_parser, column_type_id);
1499         if (IsNotNull(p_next_node)) {
1500           result = ParseType(p_parser, p_next_node, NULL, p_module, p_type);
1501         }
1502         else {
1503           result = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1504         }
1505         p_type->traits.numeric.matrix.row_count = p_type->traits.numeric.vector.component_count;
1506         p_type->traits.numeric.matrix.stride = p_node->decorations.matrix_stride;
1507         // NOTE: Matrix stride is decorated using OpMemberDecoreate - not OpDecoreate.
1508         if (IsNotNull(p_struct_member_decorations)) {
1509           p_type->traits.numeric.matrix.stride = p_struct_member_decorations->matrix_stride;
1510         }
1511       }
1512       break;
1513 
1514       case SpvOpTypeImage: {
1515         p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_EXTERNAL_IMAGE;
1516         IF_READU32_CAST(result, p_parser, p_node->word_offset + 3, SpvDim, p_type->traits.image.dim);
1517         IF_READU32(result, p_parser, p_node->word_offset + 4, p_type->traits.image.depth);
1518         IF_READU32(result, p_parser, p_node->word_offset + 5, p_type->traits.image.arrayed);
1519         IF_READU32(result, p_parser, p_node->word_offset + 6, p_type->traits.image.ms);
1520         IF_READU32(result, p_parser, p_node->word_offset + 7, p_type->traits.image.sampled);
1521         IF_READU32_CAST(result, p_parser, p_node->word_offset + 8, SpvImageFormat, p_type->traits.image.image_format);
1522       }
1523       break;
1524 
1525       case SpvOpTypeSampler: {
1526         p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_EXTERNAL_SAMPLER;
1527       }
1528       break;
1529 
1530       case SpvOpTypeSampledImage: {
1531         p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_EXTERNAL_SAMPLED_IMAGE;
1532         uint32_t image_type_id = (uint32_t)INVALID_VALUE;
1533         IF_READU32(result, p_parser, p_node->word_offset + 2, image_type_id);
1534         Node* p_next_node = FindNode(p_parser, image_type_id);
1535         if (IsNotNull(p_next_node)) {
1536           result = ParseType(p_parser, p_next_node, NULL, p_module, p_type);
1537         }
1538         else {
1539           result = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1540         }
1541       }
1542       break;
1543 
1544       case SpvOpTypeArray: {
1545         p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_ARRAY;
1546         if (result == SPV_REFLECT_RESULT_SUCCESS) {
1547           uint32_t element_type_id = (uint32_t)INVALID_VALUE;
1548           uint32_t length_id = (uint32_t)INVALID_VALUE;
1549           IF_READU32(result, p_parser, p_node->word_offset + 2, element_type_id);
1550           IF_READU32(result, p_parser, p_node->word_offset + 3, length_id);
1551           // NOTE: Array stride is decorated using OpDecorate instead of
1552           //       OpMemberDecorate, even if the array is apart of a struct.
1553           p_type->traits.array.stride = p_node->decorations.array_stride;
1554           // Get length for current dimension
1555           Node* p_length_node = FindNode(p_parser, length_id);
1556           if (IsNotNull(p_length_node)) {
1557             if (p_length_node->op == SpvOpSpecConstant ||
1558                 p_length_node->op == SpvOpSpecConstantOp) {
1559               p_type->traits.array.dims[p_type->traits.array.dims_count] = 0xFFFFFFFF;
1560               p_type->traits.array.dims_count += 1;
1561             } else {
1562               uint32_t length = 0;
1563               IF_READU32(result, p_parser, p_length_node->word_offset + 3, length);
1564               if (result == SPV_REFLECT_RESULT_SUCCESS) {
1565                 // Write the array dim and increment the count and offset
1566                 p_type->traits.array.dims[p_type->traits.array.dims_count] = length;
1567                 p_type->traits.array.dims_count += 1;
1568               } else {
1569                 result = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1570               }
1571             }
1572             // Parse next dimension or element type
1573             Node* p_next_node = FindNode(p_parser, element_type_id);
1574             if (IsNotNull(p_next_node)) {
1575               result = ParseType(p_parser, p_next_node, NULL, p_module, p_type);
1576             }
1577           }
1578           else {
1579             result = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1580           }
1581         }
1582       }
1583       break;
1584 
1585       case SpvOpTypeRuntimeArray: {
1586         uint32_t element_type_id = (uint32_t)INVALID_VALUE;
1587         IF_READU32(result, p_parser, p_node->word_offset + 2, element_type_id);
1588         // Parse next dimension or element type
1589         Node* p_next_node = FindNode(p_parser, element_type_id);
1590         if (IsNotNull(p_next_node)) {
1591           result = ParseType(p_parser, p_next_node, NULL, p_module, p_type);
1592         }
1593         else {
1594           result = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1595         }
1596       }
1597       break;
1598 
1599       case SpvOpTypeStruct: {
1600         p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_STRUCT;
1601         p_type->type_flags |= SPV_REFLECT_TYPE_FLAG_EXTERNAL_BLOCK;
1602         uint32_t word_index = 2;
1603         uint32_t member_index = 0;
1604         for (; word_index < p_node->word_count; ++word_index, ++member_index) {
1605           uint32_t member_id = (uint32_t)INVALID_VALUE;
1606           IF_READU32(result, p_parser, p_node->word_offset + word_index, member_id);
1607           // Find member node
1608           Node* p_member_node = FindNode(p_parser, member_id);
1609           if (IsNull(p_member_node)) {
1610             result = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1611             break;
1612           }
1613 
1614           // Member decorations
1615           Decorations* p_member_decorations = &p_node->member_decorations[member_index];
1616 
1617           assert(member_index < p_type->member_count);
1618           // Parse member type
1619           SpvReflectTypeDescription* p_member_type = &(p_type->members[member_index]);
1620           p_member_type->id = member_id;
1621           p_member_type->op = p_member_node->op;
1622           result = ParseType(p_parser, p_member_node, p_member_decorations, p_module, p_member_type);
1623           if (result != SPV_REFLECT_RESULT_SUCCESS) {
1624             break;
1625           }
1626           // This looks wrong
1627           //p_member_type->type_name = p_member_node->name;
1628           p_member_type->struct_member_name = p_node->member_names[member_index];
1629         }
1630       }
1631       break;
1632 
1633       case SpvOpTypeOpaque: break;
1634 
1635       case SpvOpTypePointer: {
1636         IF_READU32_CAST(result, p_parser, p_node->word_offset + 2, SpvStorageClass, p_type->storage_class);
1637         uint32_t type_id = (uint32_t)INVALID_VALUE;
1638         IF_READU32(result, p_parser, p_node->word_offset + 3, type_id);
1639         // Parse type
1640         Node* p_next_node = FindNode(p_parser, type_id);
1641         if (IsNotNull(p_next_node)) {
1642           result = ParseType(p_parser, p_next_node, NULL, p_module, p_type);
1643         }
1644         else {
1645           result = SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1646         }
1647       }
1648       break;
1649     }
1650 
1651     if (result == SPV_REFLECT_RESULT_SUCCESS) {
1652       // Names get assigned on the way down. Guard against names
1653       // get overwritten on the way up.
1654       if (IsNull(p_type->type_name)) {
1655         p_type->type_name = p_node->name;
1656       }
1657     }
1658   }
1659 
1660   return result;
1661 }
1662 
ParseTypes(Parser * p_parser,SpvReflectShaderModule * p_module)1663 static SpvReflectResult ParseTypes(Parser* p_parser, SpvReflectShaderModule* p_module)
1664 {
1665   if (p_parser->type_count == 0) {
1666     return SPV_REFLECT_RESULT_SUCCESS;
1667   }
1668 
1669   p_module->_internal->type_description_count = p_parser->type_count;
1670   p_module->_internal->type_descriptions = (SpvReflectTypeDescription*)calloc(p_module->_internal->type_description_count,
1671                                                                               sizeof(*(p_module->_internal->type_descriptions)));
1672   if (IsNull(p_module->_internal->type_descriptions)) {
1673     return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
1674   }
1675 
1676   // Mark all types with an invalid state
1677   for (size_t i = 0; i < p_module->_internal->type_description_count; ++i) {
1678     SpvReflectTypeDescription* p_type = &(p_module->_internal->type_descriptions[i]);
1679     p_type->id = (uint32_t)INVALID_VALUE;
1680     p_type->op = (SpvOp)INVALID_VALUE;
1681     p_type->storage_class = (SpvStorageClass)INVALID_VALUE;
1682   }
1683 
1684   size_t type_index = 0;
1685   for (size_t i = 0; i < p_parser->node_count; ++i) {
1686     Node* p_node = &(p_parser->nodes[i]);
1687     if (! p_node->is_type) {
1688       continue;
1689     }
1690 
1691     SpvReflectTypeDescription* p_type = &(p_module->_internal->type_descriptions[type_index]);
1692     SpvReflectResult result = ParseType(p_parser, p_node, NULL, p_module, p_type);
1693     if (result != SPV_REFLECT_RESULT_SUCCESS) {
1694       return result;
1695     }
1696     ++type_index;
1697   }
1698   return SPV_REFLECT_RESULT_SUCCESS;
1699 }
1700 
SortCompareDescriptorBinding(const void * a,const void * b)1701 static int SortCompareDescriptorBinding(const void* a, const void* b)
1702 {
1703   const SpvReflectDescriptorBinding* p_elem_a = (const SpvReflectDescriptorBinding*)a;
1704   const SpvReflectDescriptorBinding* p_elem_b = (const SpvReflectDescriptorBinding*)b;
1705   int value = (int)(p_elem_a->binding) - (int)(p_elem_b->binding);
1706   if (value == 0) {
1707     // use spirv-id as a tiebreaker to ensure a stable ordering, as they're guaranteed
1708     // unique.
1709     assert(p_elem_a->spirv_id != p_elem_b->spirv_id);
1710     value = (int)(p_elem_a->spirv_id) - (int)(p_elem_b->spirv_id);
1711   }
1712   return value;
1713 }
1714 
ParseDescriptorBindings(Parser * p_parser,SpvReflectShaderModule * p_module)1715 static SpvReflectResult ParseDescriptorBindings(Parser* p_parser, SpvReflectShaderModule* p_module)
1716 {
1717   p_module->descriptor_binding_count = 0;
1718   for (size_t i = 0; i < p_parser->node_count; ++i) {
1719     Node* p_node = &(p_parser->nodes[i]);
1720     if ((p_node->op != SpvOpVariable) ||
1721         ((p_node->storage_class != SpvStorageClassUniform) && (p_node->storage_class != SpvStorageClassUniformConstant)))
1722     {
1723       continue;
1724     }
1725     if ((p_node->decorations.set.value == INVALID_VALUE) || (p_node->decorations.binding.value == INVALID_VALUE)) {
1726       continue;
1727     }
1728 
1729     p_module->descriptor_binding_count += 1;
1730   }
1731 
1732   if (p_module->descriptor_binding_count == 0) {
1733     return SPV_REFLECT_RESULT_SUCCESS;
1734   }
1735 
1736   p_module->descriptor_bindings = (SpvReflectDescriptorBinding*)calloc(p_module->descriptor_binding_count, sizeof(*(p_module->descriptor_bindings)));
1737   if (IsNull(p_module->descriptor_bindings)) {
1738     return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
1739   }
1740 
1741   // Mark all types with an invalid state
1742   for (uint32_t descriptor_index = 0; descriptor_index < p_module->descriptor_binding_count; ++descriptor_index) {
1743     SpvReflectDescriptorBinding* p_descriptor = &(p_module->descriptor_bindings[descriptor_index]);
1744     p_descriptor->binding = (uint32_t)INVALID_VALUE;
1745     p_descriptor->input_attachment_index = (uint32_t)INVALID_VALUE;
1746     p_descriptor->set = (uint32_t)INVALID_VALUE;
1747     p_descriptor->descriptor_type = (SpvReflectDescriptorType)INVALID_VALUE;
1748     p_descriptor->uav_counter_id = (uint32_t)INVALID_VALUE;
1749   }
1750 
1751   size_t descriptor_index = 0;
1752   for (size_t i = 0; i < p_parser->node_count; ++i) {
1753     Node* p_node = &(p_parser->nodes[i]);
1754     if ((p_node->op != SpvOpVariable) ||
1755         ((p_node->storage_class != SpvStorageClassUniform) && (p_node->storage_class != SpvStorageClassUniformConstant)))\
1756     {
1757       continue;
1758     }
1759     if ((p_node->decorations.set.value == INVALID_VALUE) || (p_node->decorations.binding.value == INVALID_VALUE)) {
1760       continue;
1761     }
1762 
1763     SpvReflectTypeDescription* p_type = FindType(p_module, p_node->type_id);
1764     if (IsNull(p_type)) {
1765       return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1766     }
1767     // If the type is a pointer, resolve it
1768     if (p_type->op == SpvOpTypePointer) {
1769       // Find the type's node
1770       Node* p_type_node = FindNode(p_parser, p_type->id);
1771       if (IsNull(p_type_node)) {
1772         return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1773       }
1774       // Should be the resolved type
1775       p_type = FindType(p_module, p_type_node->type_id);
1776       if (IsNull(p_type)) {
1777         return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1778       }
1779     }
1780 
1781     SpvReflectDescriptorBinding* p_descriptor = &p_module->descriptor_bindings[descriptor_index];
1782     p_descriptor->spirv_id = p_node->result_id;
1783     p_descriptor->name = p_node->name;
1784     p_descriptor->binding = p_node->decorations.binding.value;
1785     p_descriptor->input_attachment_index = p_node->decorations.input_attachment_index.value;
1786     p_descriptor->set = p_node->decorations.set.value;
1787     p_descriptor->count = 1;
1788     p_descriptor->uav_counter_id = p_node->decorations.uav_counter_buffer.value;
1789     p_descriptor->type_description = p_type;
1790 
1791     // Copy image traits
1792     if ((p_type->type_flags & SPV_REFLECT_TYPE_FLAG_EXTERNAL_MASK) == SPV_REFLECT_TYPE_FLAG_EXTERNAL_IMAGE) {
1793       memcpy(&p_descriptor->image, &p_type->traits.image, sizeof(p_descriptor->image));
1794     }
1795 
1796     // This is a workaround for: https://github.com/KhronosGroup/glslang/issues/1096
1797     {
1798       const uint32_t resource_mask = SPV_REFLECT_TYPE_FLAG_EXTERNAL_SAMPLED_IMAGE | SPV_REFLECT_TYPE_FLAG_EXTERNAL_IMAGE;
1799       if ((p_type->type_flags & resource_mask) == resource_mask) {
1800         memcpy(&p_descriptor->image, &p_type->traits.image, sizeof(p_descriptor->image));
1801       }
1802     }
1803 
1804     // Copy array traits
1805     if (p_type->traits.array.dims_count > 0) {
1806       p_descriptor->array.dims_count = p_type->traits.array.dims_count;
1807       for (uint32_t dim_index = 0; dim_index < p_type->traits.array.dims_count; ++dim_index) {
1808         uint32_t dim_value = p_type->traits.array.dims[dim_index];
1809         p_descriptor->array.dims[dim_index] = dim_value;
1810         p_descriptor->count *= dim_value;
1811       }
1812     }
1813 
1814     // Count
1815 
1816 
1817     p_descriptor->word_offset.binding = p_node->decorations.binding.word_offset;
1818     p_descriptor->word_offset.set = p_node->decorations.set.word_offset;
1819 
1820     ++descriptor_index;
1821   }
1822 
1823   if (p_module->descriptor_binding_count > 0) {
1824     qsort(p_module->descriptor_bindings,
1825           p_module->descriptor_binding_count,
1826           sizeof(*(p_module->descriptor_bindings)),
1827           SortCompareDescriptorBinding);
1828   }
1829 
1830   return SPV_REFLECT_RESULT_SUCCESS;
1831 }
1832 
ParseDescriptorType(SpvReflectShaderModule * p_module)1833 static SpvReflectResult ParseDescriptorType(SpvReflectShaderModule* p_module)
1834 {
1835   if (p_module->descriptor_binding_count == 0) {
1836     return SPV_REFLECT_RESULT_SUCCESS;
1837   }
1838 
1839   for (uint32_t descriptor_index = 0; descriptor_index < p_module->descriptor_binding_count; ++descriptor_index) {
1840     SpvReflectDescriptorBinding* p_descriptor = &(p_module->descriptor_bindings[descriptor_index]);
1841     SpvReflectTypeDescription* p_type = p_descriptor->type_description;
1842 
1843     switch (p_type->type_flags & SPV_REFLECT_TYPE_FLAG_EXTERNAL_MASK) {
1844       default: assert(false && "unknown type flag"); break;
1845 
1846       case SPV_REFLECT_TYPE_FLAG_EXTERNAL_IMAGE: {
1847         if (p_descriptor->image.dim == SpvDimBuffer) {
1848           switch (p_descriptor->image.sampled) {
1849             default: assert(false && "unknown texel buffer sampled value"); break;
1850             case IMAGE_SAMPLED: p_descriptor->descriptor_type = SPV_REFLECT_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER; break;
1851             case IMAGE_STORAGE: p_descriptor->descriptor_type = SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER; break;
1852           }
1853         }
1854         else if(p_descriptor->image.dim == SpvDimSubpassData) {
1855           p_descriptor->descriptor_type = SPV_REFLECT_DESCRIPTOR_TYPE_INPUT_ATTACHMENT;
1856         }
1857         else {
1858           switch (p_descriptor->image.sampled) {
1859             default: assert(false && "unknown image sampled value"); break;
1860             case IMAGE_SAMPLED: p_descriptor->descriptor_type = SPV_REFLECT_DESCRIPTOR_TYPE_SAMPLED_IMAGE; break;
1861             case IMAGE_STORAGE: p_descriptor->descriptor_type = SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_IMAGE; break;
1862           }
1863         }
1864       }
1865       break;
1866 
1867       case SPV_REFLECT_TYPE_FLAG_EXTERNAL_SAMPLER: {
1868         p_descriptor->descriptor_type = SPV_REFLECT_DESCRIPTOR_TYPE_SAMPLER;
1869       }
1870       break;
1871 
1872       case (SPV_REFLECT_TYPE_FLAG_EXTERNAL_SAMPLED_IMAGE | SPV_REFLECT_TYPE_FLAG_EXTERNAL_IMAGE): {
1873         // This is a workaround for: https://github.com/KhronosGroup/glslang/issues/1096
1874         if (p_descriptor->image.dim == SpvDimBuffer) {
1875           switch (p_descriptor->image.sampled) {
1876             default: assert(false && "unknown texel buffer sampled value"); break;
1877             case IMAGE_SAMPLED: p_descriptor->descriptor_type = SPV_REFLECT_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER; break;
1878             case IMAGE_STORAGE: p_descriptor->descriptor_type = SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER; break;
1879           }
1880         }
1881         else {
1882           p_descriptor->descriptor_type = SPV_REFLECT_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER;
1883         }
1884       }
1885       break;
1886 
1887       case SPV_REFLECT_TYPE_FLAG_EXTERNAL_BLOCK: {
1888         if (p_type->decoration_flags & SPV_REFLECT_DECORATION_BLOCK) {
1889           p_descriptor->descriptor_type = SPV_REFLECT_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
1890         }
1891         else if (p_type->decoration_flags & SPV_REFLECT_DECORATION_BUFFER_BLOCK) {
1892           p_descriptor->descriptor_type = SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_BUFFER;
1893         }
1894         else {
1895           assert(false && "unknown struct");
1896         }
1897       }
1898       break;
1899     }
1900 
1901     switch (p_descriptor->descriptor_type) {
1902       case SPV_REFLECT_DESCRIPTOR_TYPE_SAMPLER                : p_descriptor->resource_type = SPV_REFLECT_RESOURCE_FLAG_SAMPLER; break;
1903       case SPV_REFLECT_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER : p_descriptor->resource_type = (SpvReflectResourceType)(SPV_REFLECT_RESOURCE_FLAG_SAMPLER | SPV_REFLECT_RESOURCE_FLAG_SRV); break;
1904       case SPV_REFLECT_DESCRIPTOR_TYPE_SAMPLED_IMAGE          : p_descriptor->resource_type = SPV_REFLECT_RESOURCE_FLAG_SRV; break;
1905       case SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_IMAGE          : p_descriptor->resource_type = SPV_REFLECT_RESOURCE_FLAG_UAV; break;
1906       case SPV_REFLECT_DESCRIPTOR_TYPE_UNIFORM_TEXEL_BUFFER   : p_descriptor->resource_type = SPV_REFLECT_RESOURCE_FLAG_SRV; break;
1907       case SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_TEXEL_BUFFER   : p_descriptor->resource_type = SPV_REFLECT_RESOURCE_FLAG_UAV; break;
1908       case SPV_REFLECT_DESCRIPTOR_TYPE_UNIFORM_BUFFER         : p_descriptor->resource_type = SPV_REFLECT_RESOURCE_FLAG_CBV; break;
1909       case SPV_REFLECT_DESCRIPTOR_TYPE_UNIFORM_BUFFER_DYNAMIC : p_descriptor->resource_type = SPV_REFLECT_RESOURCE_FLAG_CBV; break;
1910       case SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_BUFFER         : p_descriptor->resource_type = SPV_REFLECT_RESOURCE_FLAG_UAV; break;
1911       case SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC : p_descriptor->resource_type = SPV_REFLECT_RESOURCE_FLAG_UAV; break;
1912 
1913       case SPV_REFLECT_DESCRIPTOR_TYPE_INPUT_ATTACHMENT:
1914         break;
1915     }
1916   }
1917 
1918   return SPV_REFLECT_RESULT_SUCCESS;
1919 }
1920 
ParseUAVCounterBindings(SpvReflectShaderModule * p_module)1921 static SpvReflectResult ParseUAVCounterBindings(SpvReflectShaderModule* p_module)
1922 {
1923   char name[MAX_NODE_NAME_LENGTH];
1924   const char* k_count_tag = "@count";
1925 
1926   for (uint32_t descriptor_index = 0; descriptor_index < p_module->descriptor_binding_count; ++descriptor_index) {
1927     SpvReflectDescriptorBinding* p_descriptor = &(p_module->descriptor_bindings[descriptor_index]);
1928 
1929     if (p_descriptor->descriptor_type != SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_BUFFER) {
1930       continue;
1931     }
1932 
1933     SpvReflectDescriptorBinding* p_counter_descriptor = NULL;
1934     // Use UAV counter buffer id if present...
1935     if (p_descriptor->uav_counter_id != UINT32_MAX) {
1936       for (uint32_t counter_descriptor_index = 0; counter_descriptor_index < p_module->descriptor_binding_count; ++counter_descriptor_index) {
1937         SpvReflectDescriptorBinding* p_test_counter_descriptor = &(p_module->descriptor_bindings[counter_descriptor_index]);
1938         if (p_test_counter_descriptor->descriptor_type != SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_BUFFER) {
1939           continue;
1940         }
1941         if (p_descriptor->uav_counter_id == p_test_counter_descriptor->spirv_id) {
1942           p_counter_descriptor = p_test_counter_descriptor;
1943           break;
1944         }
1945       }
1946     }
1947     // ...otherwise use old @count convention.
1948     else {
1949       const size_t descriptor_name_length = p_descriptor->name? strlen(p_descriptor->name): 0;
1950 
1951       memset(name, 0, MAX_NODE_NAME_LENGTH);
1952       memcpy(name, p_descriptor->name, descriptor_name_length);
1953 #if defined(WIN32)
1954       strcat_s(name, MAX_NODE_NAME_LENGTH, k_count_tag);
1955 #else
1956       strcat(name, k_count_tag);
1957 #endif
1958 
1959       for (uint32_t counter_descriptor_index = 0; counter_descriptor_index < p_module->descriptor_binding_count; ++counter_descriptor_index) {
1960         SpvReflectDescriptorBinding* p_test_counter_descriptor = &(p_module->descriptor_bindings[counter_descriptor_index]);
1961         if (p_test_counter_descriptor->descriptor_type != SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_BUFFER) {
1962           continue;
1963         }
1964         if (p_test_counter_descriptor->name && strcmp(name, p_test_counter_descriptor->name) == 0) {
1965           p_counter_descriptor = p_test_counter_descriptor;
1966           break;
1967         }
1968       }
1969     }
1970 
1971     if (p_counter_descriptor != NULL) {
1972       p_descriptor->uav_counter_binding = p_counter_descriptor;
1973     }
1974   }
1975 
1976   return SPV_REFLECT_RESULT_SUCCESS;
1977 }
1978 
ParseDescriptorBlockVariable(Parser * p_parser,SpvReflectShaderModule * p_module,SpvReflectTypeDescription * p_type,SpvReflectBlockVariable * p_var)1979 static SpvReflectResult ParseDescriptorBlockVariable(
1980   Parser*                     p_parser,
1981   SpvReflectShaderModule*     p_module,
1982   SpvReflectTypeDescription*  p_type,
1983   SpvReflectBlockVariable*    p_var
1984 )
1985 {
1986   bool has_non_writable = false;
1987 
1988   if (IsNotNull(p_type->members) && (p_type->member_count > 0)) {
1989     p_var->member_count = p_type->member_count;
1990     p_var->members = (SpvReflectBlockVariable*)calloc(p_var->member_count, sizeof(*p_var->members));
1991     if (IsNull(p_var->members)) {
1992       return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
1993     }
1994 
1995     Node* p_type_node = FindNode(p_parser, p_type->id);
1996     if (IsNull(p_type_node)) {
1997       return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
1998     }
1999     // Resolve to element type if current type is array or run time array
2000     if (p_type_node->op == SpvOpTypeArray) {
2001       while (p_type_node->op == SpvOpTypeArray) {
2002         p_type_node = FindNode(p_parser, p_type_node->array_traits.element_type_id);
2003         if (IsNull(p_type_node)) {
2004           return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2005         }
2006       }
2007     }
2008     else if(p_type_node->op == SpvOpTypeRuntimeArray) {
2009       // Element type description
2010       p_type = FindType(p_module, p_type_node->array_traits.element_type_id);
2011       if (IsNull(p_type)) {
2012         return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2013       }
2014       // Element type node
2015       p_type_node = FindNode(p_parser, p_type->id);
2016       if (IsNull(p_type_node)) {
2017         return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2018       }
2019     }
2020 
2021     // Parse members
2022     for (uint32_t member_index = 0; member_index < p_type->member_count; ++member_index) {
2023       SpvReflectTypeDescription* p_member_type = &p_type->members[member_index];
2024       SpvReflectBlockVariable* p_member_var = &p_var->members[member_index];
2025       bool is_struct = (p_member_type->type_flags & SPV_REFLECT_TYPE_FLAG_STRUCT) == SPV_REFLECT_TYPE_FLAG_STRUCT;
2026       if (is_struct) {
2027         SpvReflectResult result = ParseDescriptorBlockVariable(p_parser, p_module, p_member_type, p_member_var);
2028         if (result != SPV_REFLECT_RESULT_SUCCESS) {
2029           return result;
2030         }
2031       }
2032 
2033       p_member_var->name = p_type_node->member_names[member_index];
2034       p_member_var->offset = p_type_node->member_decorations[member_index].offset.value;
2035       p_member_var->decoration_flags = ApplyDecorations(&p_type_node->member_decorations[member_index]);
2036       p_member_var->flags |= SPV_REFLECT_VARIABLE_FLAGS_UNUSED;
2037       if (!has_non_writable && (p_member_var->decoration_flags & SPV_REFLECT_DECORATION_NON_WRITABLE)) {
2038         has_non_writable = true;
2039       }
2040       ApplyNumericTraits(p_member_type, &p_member_var->numeric);
2041       if (p_member_type->op == SpvOpTypeArray) {
2042         ApplyArrayTraits(p_member_type, &p_member_var->array);
2043       }
2044 
2045       p_member_var->type_description = p_member_type;
2046     }
2047   }
2048 
2049   p_var->name = p_type->type_name;
2050   p_var->type_description = p_type;
2051   if (has_non_writable) {
2052     p_var->decoration_flags |= SPV_REFLECT_DECORATION_NON_WRITABLE;
2053   }
2054 
2055   return SPV_REFLECT_RESULT_SUCCESS;
2056 }
2057 
/*
 * Computes absolute_offset, size and padded_size for every member of p_var,
 * then derives p_var's own size and padded_size from its last member.
 * Recursion goes into struct members, and into array or runtime-array
 * members whose type carries SPV_REFLECT_TYPE_FLAG_STRUCT.
 *
 * p_parser, p_module: only forwarded to the recursive calls.
 * is_parent_root:     member absolute_offset is set to the member's own offset.
 * is_parent_aos:      when not root, member absolute_offset is set to 0
 *                     instead of offset + p_var->absolute_offset. It is passed
 *                     as true when recursing through (runtime) arrays of structs.
 * is_parent_rta:      after the padded size is derived, each member's
 *                     padded_size is overwritten with its (clamped) size.
 *
 * Returns SPV_REFLECT_RESULT_SUCCESS (also when p_var has no members, in
 * which case nothing is changed) or the first error from a recursive call.
 */
static SpvReflectResult ParseDescriptorBlockVariableSizes(
  Parser*                   p_parser,
  SpvReflectShaderModule*   p_module,
  bool                      is_parent_root,
  bool                      is_parent_aos,
  bool                      is_parent_rta,
  SpvReflectBlockVariable*  p_var
)
{
  if (p_var->member_count == 0) {
    return SPV_REFLECT_RESULT_SUCCESS;
  }

  // Absolute offsets
  for (uint32_t member_index = 0; member_index < p_var->member_count; ++member_index) {
    SpvReflectBlockVariable* p_member_var = &p_var->members[member_index];
    if (is_parent_root) {
      p_member_var->absolute_offset = p_member_var->offset;
    }
    else {
      p_member_var->absolute_offset = is_parent_aos ? 0 : p_member_var->offset + p_var->absolute_offset;
    }
  }

  // Size
  for (uint32_t member_index = 0; member_index < p_var->member_count; ++member_index) {
    SpvReflectBlockVariable* p_member_var = &p_var->members[member_index];
    SpvReflectTypeDescription* p_member_type = p_member_var->type_description;

    switch (p_member_type->op) {
      // Booleans occupy one 32-bit word.
      case SpvOpTypeBool: {
        p_member_var->size = SPIRV_WORD_SIZE;
      }
      break;

      // Scalars: bit width converted to bytes.
      case SpvOpTypeInt:
      case SpvOpTypeFloat: {
        p_member_var->size = p_member_type->traits.numeric.scalar.width / SPIRV_BYTE_WIDTH;
      }
      break;

      // Vectors: component count times component size in bytes.
      case SpvOpTypeVector: {
        uint32_t size = p_member_type->traits.numeric.vector.component_count *
                        (p_member_type->traits.numeric.scalar.width / SPIRV_BYTE_WIDTH);
        p_member_var->size = size;
      }
      break;

      // Matrices: size is only assigned when the member carries a
      // ColumnMajor or RowMajor decoration; otherwise it is left unchanged.
      case SpvOpTypeMatrix: {
        if (p_member_var->decoration_flags & SPV_REFLECT_DECORATION_COLUMN_MAJOR) {
          p_member_var->size = p_member_var->numeric.matrix.column_count * p_member_var->numeric.matrix.stride;
        }
        else if (p_member_var->decoration_flags & SPV_REFLECT_DECORATION_ROW_MAJOR) {
          p_member_var->size = p_member_var->numeric.matrix.row_count * p_member_var->numeric.matrix.stride;
        }
      }
      break;

      case SpvOpTypeArray: {
        // If array of structs, parse members first...
        bool is_struct = (p_member_type->type_flags & SPV_REFLECT_TYPE_FLAG_STRUCT) == SPV_REFLECT_TYPE_FLAG_STRUCT;
        if (is_struct) {
          SpvReflectResult result = ParseDescriptorBlockVariableSizes(p_parser, p_module, false, true, is_parent_rta, p_member_var);
          if (result != SPV_REFLECT_RESULT_SUCCESS) {
            return result;
          }
        }
        // ...then array: product of all dimensions (0 when there are none)
        // times the array stride.
        uint32_t element_count = (p_member_var->array.dims_count > 0 ? 1 : 0);
        for (uint32_t i = 0; i < p_member_var->array.dims_count; ++i) {
          element_count *= p_member_var->array.dims[i];
        }
        p_member_var->size = element_count * p_member_var->array.stride;
      }
      break;

      // Runtime arrays: only struct elements are recursed into; the member's
      // own size is not assigned here.
      case SpvOpTypeRuntimeArray: {
        bool is_struct = (p_member_type->type_flags & SPV_REFLECT_TYPE_FLAG_STRUCT) == SPV_REFLECT_TYPE_FLAG_STRUCT;
        if (is_struct) {
          SpvReflectResult result = ParseDescriptorBlockVariableSizes(p_parser, p_module, false, true, true, p_member_var);
          if (result != SPV_REFLECT_RESULT_SUCCESS) {
            return result;
          }
        }
      }
      break;

      // Nested structs: sized recursively (their size is set by the
      // recursive call's final assignment to p_var->size).
      case SpvOpTypeStruct: {
        SpvReflectResult result = ParseDescriptorBlockVariableSizes(p_parser, p_module, false, is_parent_aos, is_parent_rta, p_member_var);
        if (result != SPV_REFLECT_RESULT_SUCCESS) {
          return result;
        }
      }
      break;

      default:
        break;
    }
  }

  // Parse padded size using offset difference for all member except for the last entry...
  // NOTE(review): this assumes members are stored in increasing offset
  // order; otherwise the unsigned subtraction below wraps around.
  for (uint32_t member_index = 0; member_index < (p_var->member_count - 1); ++member_index) {
    SpvReflectBlockVariable* p_member_var = &p_var->members[member_index];
    SpvReflectBlockVariable* p_next_member_var = &p_var->members[member_index + 1];
    p_member_var->padded_size = p_next_member_var->offset - p_member_var->offset;
    // A member can never be reported larger than the space it occupies.
    if (p_member_var->size > p_member_var->padded_size) {
      p_member_var->size = p_member_var->padded_size;
    }
    if (is_parent_rta) {
      p_member_var->padded_size = p_member_var->size;
    }
  }
  // ...last entry just gets rounded up to the nearest multiple of
  // SPIRV_DATA_ALIGNMENT, which is 16, minus its offset.
  if (p_var->member_count > 0) {
    SpvReflectBlockVariable* p_member_var = &p_var->members[p_var->member_count - 1];
    p_member_var->padded_size = RoundUp(p_member_var->offset  + p_member_var->size, SPIRV_DATA_ALIGNMENT) - p_member_var->offset;
    if (p_member_var->size > p_member_var->padded_size) {
      p_member_var->size = p_member_var->padded_size;
    }
    if (is_parent_rta) {
      p_member_var->padded_size = p_member_var->size;
    }
  }

  // The block's size is the end of its last member's padded extent.
  // @TODO validate this with assertion
  p_var->size = p_var->members[p_var->member_count - 1].offset +
                p_var->members[p_var->member_count - 1].padded_size;
  p_var->padded_size = p_var->size;

  return SPV_REFLECT_RESULT_SUCCESS;
}
2190 
ParseDescriptorBlockVariableUsage(Parser * p_parser,SpvReflectShaderModule * p_module,AccessChain * p_access_chain,uint32_t index_index,SpvOp override_op_type,SpvReflectBlockVariable * p_var)2191 static SpvReflectResult ParseDescriptorBlockVariableUsage(
2192   Parser*                  p_parser,
2193   SpvReflectShaderModule*  p_module,
2194   AccessChain*             p_access_chain,
2195   uint32_t                 index_index,
2196   SpvOp                    override_op_type,
2197   SpvReflectBlockVariable* p_var
2198 )
2199 {
2200   (void)p_parser;
2201   (void)p_access_chain;
2202   (void)p_var;
2203 
2204   // Clear the current variable's USED flag
2205   p_var->flags &= ~SPV_REFLECT_VARIABLE_FLAGS_UNUSED;
2206 
2207   // Parsing arrays requires overriding the op type for
2208   // for the lowest dim's element type.
2209   SpvOp op_type = p_var->type_description->op;
2210   if (override_op_type != (SpvOp)INVALID_VALUE) {
2211     op_type = override_op_type;
2212   }
2213 
2214   switch (op_type) {
2215     default: break;
2216 
2217     case SpvOpTypeArray: {
2218       // Parse through array's type hierarchy to find the actual/non-array element type
2219       SpvReflectTypeDescription* p_type = p_var->type_description;
2220       while ((p_type->op == SpvOpTypeArray) && (index_index < p_access_chain->index_count)) {
2221         // Find the array element type id
2222         Node* p_node = FindNode(p_parser, p_type->id);
2223         if (p_node == NULL) {
2224           return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2225         }
2226         uint32_t element_type_id = p_node->array_traits.element_type_id;
2227         // Get the array element type
2228         p_type = FindType(p_module, element_type_id);
2229         if (p_type == NULL) {
2230           return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2231         }
2232         // Next access index
2233         index_index += 1;
2234       }
2235       // Parse current var again with a type override and advanced index index
2236       SpvReflectResult result = ParseDescriptorBlockVariableUsage(
2237         p_parser,
2238         p_module,
2239         p_access_chain,
2240         index_index,
2241         p_type->op,
2242         p_var);
2243       if (result != SPV_REFLECT_RESULT_SUCCESS) {
2244         return result;
2245       }
2246     }
2247     break;
2248 
2249     case SpvOpTypeStruct: {
2250       assert(p_var->member_count > 0);
2251       if (p_var->member_count == 0) {
2252         return SPV_REFLECT_RESULT_ERROR_SPIRV_UNEXPECTED_BLOCK_DATA;
2253       }
2254 
2255       uint32_t index = p_access_chain->indexes[index_index];
2256 
2257       if (index >= p_var->member_count) {
2258         return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_BLOCK_MEMBER_REFERENCE;
2259       }
2260 
2261       SpvReflectBlockVariable* p_member_var = &p_var->members[index];
2262       if (index_index < p_access_chain->index_count) {
2263         SpvReflectResult result = ParseDescriptorBlockVariableUsage(
2264           p_parser,
2265           p_module,
2266           p_access_chain,
2267           index_index + 1,
2268           (SpvOp)INVALID_VALUE,
2269           p_member_var);
2270         if (result != SPV_REFLECT_RESULT_SUCCESS) {
2271           return result;
2272         }
2273       }
2274     }
2275     break;
2276   }
2277 
2278   return SPV_REFLECT_RESULT_SUCCESS;
2279 }
2280 
ParseDescriptorBlocks(Parser * p_parser,SpvReflectShaderModule * p_module)2281 static SpvReflectResult ParseDescriptorBlocks(Parser* p_parser, SpvReflectShaderModule* p_module)
2282 {
2283   if (p_module->descriptor_binding_count == 0) {
2284     return SPV_REFLECT_RESULT_SUCCESS;
2285   }
2286 
2287   for (uint32_t descriptor_index = 0; descriptor_index < p_module->descriptor_binding_count; ++descriptor_index) {
2288     SpvReflectDescriptorBinding* p_descriptor = &(p_module->descriptor_bindings[descriptor_index]);
2289     SpvReflectTypeDescription* p_type = p_descriptor->type_description;
2290     if ((p_descriptor->descriptor_type != SPV_REFLECT_DESCRIPTOR_TYPE_UNIFORM_BUFFER) &&
2291         (p_descriptor->descriptor_type != SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_BUFFER) )
2292     {
2293       continue;
2294     }
2295 
2296     // Mark UNUSED
2297     p_descriptor->block.flags |= SPV_REFLECT_VARIABLE_FLAGS_UNUSED;
2298     // Parse descriptor block
2299     SpvReflectResult result = ParseDescriptorBlockVariable(p_parser, p_module, p_type, &p_descriptor->block);
2300     if (result != SPV_REFLECT_RESULT_SUCCESS) {
2301       return result;
2302     }
2303 
2304     for (uint32_t access_chain_index = 0; access_chain_index < p_parser->access_chain_count; ++access_chain_index) {
2305       AccessChain* p_access_chain = &(p_parser->access_chains[access_chain_index]);
2306       // Skip any access chains that aren't touching this descriptor block
2307       if (p_descriptor->spirv_id != p_access_chain->base_id) {
2308         continue;
2309       }
2310       result = ParseDescriptorBlockVariableUsage(
2311         p_parser,
2312         p_module,
2313         p_access_chain,
2314         0,
2315         (SpvOp)INVALID_VALUE,
2316         &p_descriptor->block);
2317       if (result != SPV_REFLECT_RESULT_SUCCESS) {
2318         return result;
2319       }
2320     }
2321 
2322     p_descriptor->block.name = p_descriptor->name;
2323 
2324     bool is_parent_rta = (p_descriptor->descriptor_type == SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_BUFFER);
2325     result = ParseDescriptorBlockVariableSizes(p_parser, p_module, true, false, is_parent_rta, &p_descriptor->block);
2326     if (result != SPV_REFLECT_RESULT_SUCCESS) {
2327       return result;
2328     }
2329 
2330     if (is_parent_rta) {
2331       p_descriptor->block.size = 0;
2332       p_descriptor->block.padded_size = 0;
2333     }
2334   }
2335 
2336   return SPV_REFLECT_RESULT_SUCCESS;
2337 }
2338 
ParseFormat(const SpvReflectTypeDescription * p_type,SpvReflectFormat * p_format)2339 static SpvReflectResult ParseFormat(
2340   const SpvReflectTypeDescription*  p_type,
2341   SpvReflectFormat*                 p_format
2342 )
2343 {
2344   SpvReflectResult result = SPV_REFLECT_RESULT_ERROR_INTERNAL_ERROR;
2345   bool signedness = (p_type->traits.numeric.scalar.signedness != 0);
2346   if (p_type->type_flags & SPV_REFLECT_TYPE_FLAG_VECTOR) {
2347     uint32_t component_count = p_type->traits.numeric.vector.component_count;
2348     if (p_type->type_flags & SPV_REFLECT_TYPE_FLAG_FLOAT) {
2349       switch (component_count) {
2350         case 2: *p_format = SPV_REFLECT_FORMAT_R32G32_SFLOAT; break;
2351         case 3: *p_format = SPV_REFLECT_FORMAT_R32G32B32_SFLOAT; break;
2352         case 4: *p_format = SPV_REFLECT_FORMAT_R32G32B32A32_SFLOAT; break;
2353       }
2354       result = SPV_REFLECT_RESULT_SUCCESS;
2355     }
2356     else if (p_type->type_flags & (SPV_REFLECT_TYPE_FLAG_INT | SPV_REFLECT_TYPE_FLAG_BOOL)) {
2357       switch (component_count) {
2358         case 2: *p_format = signedness ? SPV_REFLECT_FORMAT_R32G32_SINT : SPV_REFLECT_FORMAT_R32G32_UINT; break;
2359         case 3: *p_format = signedness ? SPV_REFLECT_FORMAT_R32G32B32_SINT : SPV_REFLECT_FORMAT_R32G32B32_UINT; break;
2360         case 4: *p_format = signedness ? SPV_REFLECT_FORMAT_R32G32B32A32_SINT : SPV_REFLECT_FORMAT_R32G32B32A32_UINT; break;
2361       }
2362       result = SPV_REFLECT_RESULT_SUCCESS;
2363     }
2364   }
2365   else if (p_type->type_flags & SPV_REFLECT_TYPE_FLAG_FLOAT) {
2366     *p_format = SPV_REFLECT_FORMAT_R32_SFLOAT;
2367     result = SPV_REFLECT_RESULT_SUCCESS;
2368   }
2369   else if (p_type->type_flags & (SPV_REFLECT_TYPE_FLAG_INT | SPV_REFLECT_TYPE_FLAG_BOOL)) {
2370     if (signedness) {
2371       *p_format = SPV_REFLECT_FORMAT_R32_SINT;
2372       result = SPV_REFLECT_RESULT_SUCCESS;
2373     }
2374     else {
2375       *p_format = SPV_REFLECT_FORMAT_R32_UINT;
2376       result = SPV_REFLECT_RESULT_SUCCESS;
2377     }
2378   }
2379   else if (p_type->type_flags & SPV_REFLECT_TYPE_FLAG_STRUCT) {
2380     *p_format = SPV_REFLECT_FORMAT_UNDEFINED;
2381     result = SPV_REFLECT_RESULT_SUCCESS;
2382   }
2383   return result;
2384 }
2385 
ParseInterfaceVariable(Parser * p_parser,const Decorations * p_type_node_decorations,SpvReflectShaderModule * p_module,SpvReflectTypeDescription * p_type,SpvReflectInterfaceVariable * p_var,bool * p_has_built_in)2386 static SpvReflectResult ParseInterfaceVariable(
2387   Parser*                      p_parser,
2388   const Decorations*           p_type_node_decorations,
2389   SpvReflectShaderModule*      p_module,
2390   SpvReflectTypeDescription*   p_type,
2391   SpvReflectInterfaceVariable* p_var,
2392   bool*                        p_has_built_in
2393 )
2394 {
2395   Node* p_type_node = FindNode(p_parser, p_type->id);
2396   if (IsNull(p_type_node)) {
2397     return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2398   }
2399 
2400   if (p_type->member_count > 0) {
2401     p_var->member_count = p_type->member_count;
2402     p_var->members = (SpvReflectInterfaceVariable*)calloc(p_var->member_count, sizeof(*p_var->members));
2403     if (IsNull(p_var->members)) {
2404       return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
2405     }
2406 
2407     for (uint32_t member_index = 0; member_index < p_type_node->member_count; ++member_index) {
2408       Decorations* p_member_decorations = &p_type_node->member_decorations[member_index];
2409       SpvReflectTypeDescription* p_member_type = &p_type->members[member_index];
2410       SpvReflectInterfaceVariable* p_member_var = &p_var->members[member_index];
2411       SpvReflectResult result = ParseInterfaceVariable(p_parser, p_member_decorations, p_module, p_member_type, p_member_var, p_has_built_in);
2412       if (result != SPV_REFLECT_RESULT_SUCCESS) {
2413         return result;
2414       }
2415     }
2416   }
2417 
2418   p_var->name = p_type_node->name;
2419   p_var->decoration_flags = ApplyDecorations(p_type_node_decorations);
2420   p_var->built_in = p_type_node_decorations->built_in;
2421   ApplyNumericTraits(p_type, &p_var->numeric);
2422   if (p_type->op == SpvOpTypeArray) {
2423     ApplyArrayTraits(p_type, &p_var->array);
2424   }
2425 
2426   p_var->type_description = p_type;
2427 
2428   *p_has_built_in |= p_type_node_decorations->is_built_in;
2429 
2430   SpvReflectResult result = ParseFormat(p_var->type_description, &p_var->format);
2431   if (result != SPV_REFLECT_RESULT_SUCCESS) {
2432     return result;
2433   }
2434 
2435   return SPV_REFLECT_RESULT_SUCCESS;
2436 }
2437 
ParseInterfaceVariables(Parser * p_parser,SpvReflectShaderModule * p_module,SpvReflectEntryPoint * p_entry,size_t io_var_count,uint32_t * io_vars)2438 static SpvReflectResult ParseInterfaceVariables(
2439   Parser*                 p_parser,
2440   SpvReflectShaderModule* p_module,
2441   SpvReflectEntryPoint*   p_entry,
2442   size_t                  io_var_count,
2443   uint32_t*               io_vars
2444 )
2445 {
2446   if (io_var_count == 0) {
2447     return SPV_REFLECT_RESULT_SUCCESS;
2448   }
2449 
2450   p_entry->input_variable_count = 0;
2451   p_entry->output_variable_count = 0;
2452   for (size_t i = 0; i < io_var_count; ++i) {
2453     uint32_t var_result_id = *(io_vars + i);
2454     Node* p_node = FindNode(p_parser, var_result_id);
2455     if (IsNull(p_node)) {
2456       return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2457     }
2458 
2459     if (p_node->storage_class == SpvStorageClassInput) {
2460       p_entry->input_variable_count += 1;
2461     }
2462     else if (p_node->storage_class == SpvStorageClassOutput) {
2463       p_entry->output_variable_count += 1;
2464     }
2465   }
2466 
2467   if (p_entry->input_variable_count > 0) {
2468     p_entry->input_variables = (SpvReflectInterfaceVariable*)calloc(p_entry->input_variable_count, sizeof(*(p_entry->input_variables)));
2469     if (IsNull(p_entry->input_variables)) {
2470       return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
2471     }
2472   }
2473 
2474 
2475   if (p_entry->output_variable_count > 0) {
2476     p_entry->output_variables = (SpvReflectInterfaceVariable*)calloc(p_entry->output_variable_count, sizeof(*(p_entry->output_variables)));
2477     if (IsNull(p_entry->output_variables)) {
2478       return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
2479     }
2480   }
2481 
2482   size_t input_index = 0;
2483   size_t output_index = 0;
2484   for (size_t i = 0; i < io_var_count; ++i) {
2485     uint32_t var_result_id = *(io_vars + i);
2486     Node* p_node = FindNode(p_parser, var_result_id);
2487     if (IsNull(p_node)) {
2488       return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2489     }
2490 
2491     SpvReflectTypeDescription* p_type = FindType(p_module, p_node->type_id);
2492     if (IsNull(p_node)) {
2493       return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2494     }
2495     // If the type is a pointer, resolve it
2496     if (p_type->op == SpvOpTypePointer) {
2497       // Find the type's node
2498       Node* p_type_node = FindNode(p_parser, p_type->id);
2499       if (IsNull(p_type_node)) {
2500         return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2501       }
2502       // Should be the resolved type
2503       p_type = FindType(p_module, p_type_node->type_id);
2504       if (IsNull(p_type)) {
2505         return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2506       }
2507     }
2508 
2509     Node* p_type_node = FindNode(p_parser, p_type->id);
2510     if (IsNull(p_type_node)) {
2511       return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2512     }
2513 
2514     SpvReflectInterfaceVariable* p_var = NULL;
2515     if (p_node->storage_class == SpvStorageClassInput) {
2516      p_var = &(p_entry->input_variables[input_index]);
2517      p_var->storage_class = SpvStorageClassInput;
2518       ++input_index;
2519     }
2520     else if (p_node->storage_class == SpvStorageClassOutput) {
2521       p_var = &(p_entry->output_variables[output_index]);
2522       p_var->storage_class = SpvStorageClassOutput;
2523       ++output_index;
2524     } else {
2525       // interface variables can only have input or output storage classes;
2526       // anything else is either a new addition or an error.
2527       assert(false && "Unsupported storage class for interface variable");
2528       return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_STORAGE_CLASS;
2529     }
2530 
2531     bool has_built_in = p_node->decorations.is_built_in;
2532     SpvReflectResult result = ParseInterfaceVariable(
2533       p_parser,
2534       &p_type_node->decorations,
2535       p_module,
2536       p_type,
2537       p_var,
2538       &has_built_in);
2539     if (result != SPV_REFLECT_RESULT_SUCCESS) {
2540       return result;
2541     }
2542 
2543     // SPIR-V result id
2544     p_var->spirv_id = p_node->result_id;
2545     // Name
2546     p_var->name = p_node->name;
2547     // Semantic
2548     p_var->semantic = p_node->decorations.semantic.value;
2549 
2550     // Decorate with built-in if any member is built-in
2551     if (has_built_in) {
2552       p_var->decoration_flags |= SPV_REFLECT_DECORATION_BUILT_IN;
2553     }
2554 
2555     // Location is decorated on OpVariable node, not the type node.
2556     p_var->location = p_node->decorations.location.value;
2557     p_var->word_offset.location = p_node->decorations.location.word_offset;
2558 
2559     // Built in
2560     if (p_node->decorations.is_built_in) {
2561       p_var->built_in = p_node->decorations.built_in;
2562     }
2563   }
2564 
2565   return SPV_REFLECT_RESULT_SUCCESS;
2566 }
2567 
EnumerateAllPushConstants(SpvReflectShaderModule * p_module,size_t * p_push_constant_count,uint32_t ** p_push_constants)2568 static SpvReflectResult EnumerateAllPushConstants(
2569   SpvReflectShaderModule* p_module,
2570   size_t*                 p_push_constant_count,
2571   uint32_t**              p_push_constants
2572 )
2573 {
2574   *p_push_constant_count = p_module->push_constant_block_count;
2575   if (*p_push_constant_count == 0) {
2576     return SPV_REFLECT_RESULT_SUCCESS;
2577   }
2578   *p_push_constants = (uint32_t*)calloc(*p_push_constant_count, sizeof(**p_push_constants));
2579 
2580   if (IsNull(*p_push_constants)) {
2581     return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
2582   }
2583 
2584   for (size_t i = 0; i < *p_push_constant_count; ++i) {
2585     (*p_push_constants)[i] = p_module->push_constant_blocks[i].spirv_id;
2586   }
2587   qsort(*p_push_constants, *p_push_constant_count, sizeof(**p_push_constants),
2588         SortCompareUint32);
2589   return SPV_REFLECT_RESULT_SUCCESS;
2590 }
2591 
TraverseCallGraph(Parser * p_parser,Function * p_func,size_t * p_func_count,uint32_t * p_func_ids,uint32_t depth)2592 static SpvReflectResult TraverseCallGraph(
2593   Parser*   p_parser,
2594   Function* p_func,
2595   size_t*   p_func_count,
2596   uint32_t* p_func_ids,
2597   uint32_t  depth
2598 )
2599 {
2600   if (depth > p_parser->function_count) {
2601     // Vulkan does not permit recursion (Vulkan spec Appendix A):
2602     //   "Recursion: The static function-call graph for an entry point must not
2603     //    contain cycles."
2604     return SPV_REFLECT_RESULT_ERROR_SPIRV_RECURSION;
2605   }
2606   if (IsNotNull(p_func_ids)) {
2607     p_func_ids[(*p_func_count)++] = p_func->id;
2608   } else {
2609     ++*p_func_count;
2610   }
2611   for (size_t i = 0; i < p_func->callee_count; ++i) {
2612     SpvReflectResult result = TraverseCallGraph(
2613         p_parser, p_func->callee_ptrs[i], p_func_count, p_func_ids, depth + 1);
2614     if (result != SPV_REFLECT_RESULT_SUCCESS) {
2615       return result;
2616     }
2617   }
2618   return SPV_REFLECT_RESULT_SUCCESS;
2619 }
2620 
ParseStaticallyUsedResources(Parser * p_parser,SpvReflectShaderModule * p_module,SpvReflectEntryPoint * p_entry,size_t uniform_count,uint32_t * uniforms,size_t push_constant_count,uint32_t * push_constants)2621 static SpvReflectResult ParseStaticallyUsedResources(
2622   Parser*                 p_parser,
2623   SpvReflectShaderModule* p_module,
2624   SpvReflectEntryPoint*   p_entry,
2625   size_t                  uniform_count,
2626   uint32_t*               uniforms,
2627   size_t                  push_constant_count,
2628   uint32_t*               push_constants
2629 )
2630 {
2631   // Find function with the right id
2632   Function* p_func = NULL;
2633   for (size_t i = 0; i < p_parser->function_count; ++i) {
2634     if (p_parser->functions[i].id == p_entry->id) {
2635       p_func = &(p_parser->functions[i]);
2636       break;
2637     }
2638   }
2639   if (p_func == NULL) {
2640     return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2641   }
2642 
2643   size_t called_function_count = 0;
2644   SpvReflectResult result = TraverseCallGraph(
2645     p_parser,
2646     p_func,
2647     &called_function_count,
2648     NULL,
2649     0);
2650   if (result != SPV_REFLECT_RESULT_SUCCESS) {
2651     return result;
2652   }
2653 
2654   uint32_t* p_called_functions = NULL;
2655   if (called_function_count > 0) {
2656     p_called_functions = (uint32_t*)calloc(called_function_count, sizeof(*p_called_functions));
2657     if (IsNull(p_called_functions)) {
2658       return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
2659     }
2660   }
2661 
2662   called_function_count = 0;
2663   result = TraverseCallGraph(
2664     p_parser,
2665     p_func,
2666     &called_function_count,
2667     p_called_functions,
2668     0);
2669   if (result != SPV_REFLECT_RESULT_SUCCESS) {
2670     return result;
2671   }
2672 
2673   if (called_function_count > 0) {
2674     qsort(
2675       p_called_functions,
2676       called_function_count,
2677       sizeof(*p_called_functions),
2678       SortCompareUint32);
2679   }
2680   called_function_count = DedupSortedUint32(p_called_functions, called_function_count);
2681 
2682   uint32_t used_variable_count = 0;
2683   for (size_t i = 0, j = 0; i < called_function_count; ++i) {
2684     // No need to bounds check j because a missing ID issue would have been
2685     // found during TraverseCallGraph
2686     while (p_parser->functions[j].id != p_called_functions[i]) {
2687       ++j;
2688     }
2689     used_variable_count += p_parser->functions[j].accessed_ptr_count;
2690   }
2691   uint32_t* used_variables = NULL;
2692   if (used_variable_count > 0) {
2693     used_variables = (uint32_t*)calloc(used_variable_count,
2694                                        sizeof(*used_variables));
2695     if (IsNull(used_variables)) {
2696       SafeFree(p_called_functions);
2697       return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
2698     }
2699   }
2700   used_variable_count = 0;
2701   for (size_t i = 0, j = 0; i < called_function_count; ++i) {
2702     while (p_parser->functions[j].id != p_called_functions[i]) {
2703       ++j;
2704     }
2705 
2706     memcpy(&used_variables[used_variable_count],
2707            p_parser->functions[j].accessed_ptrs,
2708            p_parser->functions[j].accessed_ptr_count * sizeof(*used_variables));
2709     used_variable_count += p_parser->functions[j].accessed_ptr_count;
2710   }
2711   SafeFree(p_called_functions);
2712 
2713   if (used_variable_count > 0) {
2714     qsort(used_variables, used_variable_count, sizeof(*used_variables),
2715           SortCompareUint32);
2716   }
2717   used_variable_count = (uint32_t)DedupSortedUint32(used_variables,
2718                                                     used_variable_count);
2719 
2720   // Do set intersection to find the used uniform and push constants
2721   size_t used_uniform_count = 0;
2722   //
2723   SpvReflectResult result0 = IntersectSortedUint32(
2724     used_variables,
2725     used_variable_count,
2726     uniforms,
2727     uniform_count,
2728     &p_entry->used_uniforms,
2729     &used_uniform_count);
2730 
2731   size_t used_push_constant_count = 0;
2732   //
2733   SpvReflectResult result1 = IntersectSortedUint32(
2734     used_variables,
2735     used_variable_count,
2736     push_constants,
2737     push_constant_count,
2738     &p_entry->used_push_constants,
2739     &used_push_constant_count);
2740 
2741   for (uint32_t j = 0; j < p_module->descriptor_binding_count; ++j) {
2742     SpvReflectDescriptorBinding* p_binding = &p_module->descriptor_bindings[j];
2743     bool found = SearchSortedUint32(
2744       used_variables,
2745       used_variable_count,
2746       p_binding->spirv_id);
2747     if (found) {
2748       p_binding->accessed = 1;
2749     }
2750   }
2751 
2752   SafeFree(used_variables);
2753   if (result0 != SPV_REFLECT_RESULT_SUCCESS) {
2754     return result0;
2755   }
2756   if (result1 != SPV_REFLECT_RESULT_SUCCESS) {
2757     return result1;
2758   }
2759 
2760   p_entry->used_uniform_count = (uint32_t)used_uniform_count;
2761   p_entry->used_push_constant_count = (uint32_t)used_push_constant_count;
2762 
2763   return SPV_REFLECT_RESULT_SUCCESS;
2764 }
2765 
ParseEntryPoints(Parser * p_parser,SpvReflectShaderModule * p_module)2766 static SpvReflectResult ParseEntryPoints(Parser* p_parser, SpvReflectShaderModule* p_module)
2767 {
2768   if (p_parser->entry_point_count == 0) {
2769     return SPV_REFLECT_RESULT_SUCCESS;
2770   }
2771 
2772   p_module->entry_point_count = p_parser->entry_point_count;
2773   p_module->entry_points = (SpvReflectEntryPoint*)calloc(p_module->entry_point_count,
2774                                                          sizeof(*(p_module->entry_points)));
2775   if (IsNull(p_module->entry_points)) {
2776     return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
2777   }
2778 
2779   SpvReflectResult result;
2780   size_t uniform_count = 0;
2781   uint32_t* uniforms = NULL;
2782   if ((result = EnumerateAllUniforms(p_module, &uniform_count, &uniforms)) !=
2783        SPV_REFLECT_RESULT_SUCCESS) {
2784     return result;
2785   }
2786   size_t push_constant_count = 0;
2787   uint32_t* push_constants = NULL;
2788   if ((result = EnumerateAllPushConstants(p_module, &push_constant_count, &push_constants)) !=
2789        SPV_REFLECT_RESULT_SUCCESS) {
2790     return result;
2791   }
2792 
2793   size_t entry_point_index = 0;
2794   for (size_t i = 0; entry_point_index < p_parser->entry_point_count && i < p_parser->node_count; ++i) {
2795     Node* p_node = &(p_parser->nodes[i]);
2796     if (p_node->op != SpvOpEntryPoint) {
2797       continue;
2798     }
2799 
2800     SpvReflectEntryPoint* p_entry_point = &(p_module->entry_points[entry_point_index]);
2801     CHECKED_READU32_CAST(p_parser, p_node->word_offset + 1, SpvExecutionModel, p_entry_point->spirv_execution_model);
2802     CHECKED_READU32(p_parser, p_node->word_offset + 2, p_entry_point->id);
2803 
2804     switch (p_entry_point->spirv_execution_model) {
2805       default: break;
2806       case SpvExecutionModelVertex                 : p_entry_point->shader_stage = SPV_REFLECT_SHADER_STAGE_VERTEX_BIT; break;
2807       case SpvExecutionModelTessellationControl    : p_entry_point->shader_stage = SPV_REFLECT_SHADER_STAGE_TESSELLATION_CONTROL_BIT; break;
2808       case SpvExecutionModelTessellationEvaluation : p_entry_point->shader_stage = SPV_REFLECT_SHADER_STAGE_TESSELLATION_EVALUATION_BIT; break;
2809       case SpvExecutionModelGeometry               : p_entry_point->shader_stage = SPV_REFLECT_SHADER_STAGE_GEOMETRY_BIT; break;
2810       case SpvExecutionModelFragment               : p_entry_point->shader_stage = SPV_REFLECT_SHADER_STAGE_FRAGMENT_BIT; break;
2811       case SpvExecutionModelGLCompute              : p_entry_point->shader_stage = SPV_REFLECT_SHADER_STAGE_COMPUTE_BIT; break;
2812     }
2813 
2814     ++entry_point_index;
2815 
2816     // Name length is required to calculate next operand
2817     uint32_t name_start_word_offset = 3;
2818     uint32_t name_length_with_terminator = 0;
2819     result = ReadStr(p_parser, p_node->word_offset + name_start_word_offset, 0, p_node->word_count, &name_length_with_terminator, NULL);
2820     if (result != SPV_REFLECT_RESULT_SUCCESS) {
2821       return result;
2822     }
2823     p_entry_point->name = (const char*)(p_parser->spirv_code + p_node->word_offset + name_start_word_offset);
2824 
2825     uint32_t name_word_count = RoundUp(name_length_with_terminator, SPIRV_WORD_SIZE) / SPIRV_WORD_SIZE;
2826     size_t interface_variable_count = (p_node->word_count - (name_start_word_offset + name_word_count));
2827     uint32_t* interface_variables = NULL;
2828     if (interface_variable_count > 0) {
2829       interface_variables = (uint32_t*)calloc(interface_variable_count, sizeof(*(interface_variables)));
2830       if (IsNull(interface_variables)) {
2831         return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
2832       }
2833     }
2834 
2835     for (uint32_t var_index = 0; var_index < interface_variable_count; ++var_index) {
2836       uint32_t var_result_id = (uint32_t)INVALID_VALUE;
2837       uint32_t offset = name_start_word_offset + name_word_count + var_index;
2838       CHECKED_READU32(p_parser, p_node->word_offset + offset, var_result_id);
2839       interface_variables[var_index] = var_result_id;
2840     }
2841 
2842     result = ParseInterfaceVariables(
2843       p_parser,
2844       p_module,
2845       p_entry_point,
2846       interface_variable_count,
2847       interface_variables);
2848     if (result != SPV_REFLECT_RESULT_SUCCESS) {
2849       return result;
2850     }
2851     SafeFree(interface_variables);
2852 
2853     result = ParseStaticallyUsedResources(
2854       p_parser,
2855       p_module,
2856       p_entry_point,
2857       uniform_count,
2858       uniforms,
2859       push_constant_count,
2860       push_constants);
2861     if (result != SPV_REFLECT_RESULT_SUCCESS) {
2862       return result;
2863     }
2864   }
2865 
2866   SafeFree(uniforms);
2867   SafeFree(push_constants);
2868 
2869   return SPV_REFLECT_RESULT_SUCCESS;
2870 }
2871 
ParseExecutionModes(Parser * p_parser,SpvReflectShaderModule * p_module)2872 static SpvReflectResult ParseExecutionModes(Parser* p_parser, SpvReflectShaderModule* p_module)
2873 {
2874   assert(IsNotNull(p_parser));
2875   assert(IsNotNull(p_parser->nodes));
2876   assert(IsNotNull(p_module));
2877 
2878   if (IsNotNull(p_parser) && IsNotNull(p_parser->spirv_code) && IsNotNull(p_parser->nodes)) {
2879     for (size_t node_idx = 0; node_idx < p_parser->node_count; ++node_idx) {
2880       Node* p_node = &(p_parser->nodes[node_idx]);
2881       if (p_node->op != SpvOpExecutionMode) {
2882         continue;
2883       }
2884 
2885       // Read entry point id
2886       uint32_t entry_point_id = 0;
2887       CHECKED_READU32(p_parser, p_node->word_offset + 1, entry_point_id);
2888 
2889       // Find entry point
2890       SpvReflectEntryPoint* p_entry_point = NULL;
2891       for (size_t entry_point_idx = 0; entry_point_idx < p_module->entry_point_count; ++entry_point_idx) {
2892         if (p_module->entry_points[entry_point_idx].id == entry_point_id) {
2893           p_entry_point = &p_module->entry_points[entry_point_idx];
2894           break;
2895         }
2896       }
2897       // Bail if entry point is null
2898       if (IsNull(p_entry_point)) {
2899         return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ENTRY_POINT;
2900       }
2901 
2902       // Read execution mode
2903       uint32_t execution_mode = (uint32_t)INVALID_VALUE;
2904       CHECKED_READU32(p_parser, p_node->word_offset + 2, execution_mode);
2905 
2906       // Parse execution mode
2907       switch (execution_mode) {
2908         default: {
2909           return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_EXECUTION_MODE;
2910         }
2911         break;
2912 
2913         case SpvExecutionModeInvocations:
2914         case SpvExecutionModeSpacingEqual:
2915         case SpvExecutionModeSpacingFractionalEven:
2916         case SpvExecutionModeSpacingFractionalOdd:
2917         case SpvExecutionModeVertexOrderCw:
2918         case SpvExecutionModeVertexOrderCcw:
2919         case SpvExecutionModePixelCenterInteger:
2920         case SpvExecutionModeOriginUpperLeft:
2921         case SpvExecutionModeOriginLowerLeft:
2922         case SpvExecutionModeEarlyFragmentTests:
2923         case SpvExecutionModePointMode:
2924         case SpvExecutionModeXfb:
2925         case SpvExecutionModeDepthReplacing:
2926         case SpvExecutionModeDepthGreater:
2927         case SpvExecutionModeDepthLess:
2928         case SpvExecutionModeDepthUnchanged:
2929           break;
2930 
2931         case SpvExecutionModeLocalSize: {
2932           CHECKED_READU32(p_parser, p_node->word_offset + 3, p_entry_point->local_size.x);
2933           CHECKED_READU32(p_parser, p_node->word_offset + 4, p_entry_point->local_size.y);
2934           CHECKED_READU32(p_parser, p_node->word_offset + 5, p_entry_point->local_size.z);
2935         }
2936         break;
2937 
2938         case SpvExecutionModeLocalSizeHint:
2939         case SpvExecutionModeInputPoints:
2940         case SpvExecutionModeInputLines:
2941         case SpvExecutionModeInputLinesAdjacency:
2942         case SpvExecutionModeTriangles:
2943         case SpvExecutionModeInputTrianglesAdjacency:
2944         case SpvExecutionModeQuads:
2945         case SpvExecutionModeIsolines:
2946         case SpvExecutionModeOutputVertices:
2947         case SpvExecutionModeOutputPoints:
2948         case SpvExecutionModeOutputLineStrip:
2949         case SpvExecutionModeOutputTriangleStrip:
2950         case SpvExecutionModeVecTypeHint:
2951         case SpvExecutionModeContractionOff:
2952         case SpvExecutionModeInitializer:
2953         case SpvExecutionModeFinalizer:
2954         case SpvExecutionModeSubgroupSize:
2955         case SpvExecutionModeSubgroupsPerWorkgroup:
2956         case SpvExecutionModeSubgroupsPerWorkgroupId:
2957         case SpvExecutionModeLocalSizeId:
2958         case SpvExecutionModeLocalSizeHintId:
2959         case SpvExecutionModePostDepthCoverage:
2960         case SpvExecutionModeStencilRefReplacingEXT:
2961           break;
2962       }
2963     }
2964   }
2965   return SPV_REFLECT_RESULT_SUCCESS;
2966 }
2967 
ParsePushConstantBlocks(Parser * p_parser,SpvReflectShaderModule * p_module)2968 static SpvReflectResult ParsePushConstantBlocks(Parser* p_parser, SpvReflectShaderModule* p_module)
2969 {
2970   for (size_t i = 0; i < p_parser->node_count; ++i) {
2971     Node* p_node = &(p_parser->nodes[i]);
2972     if ((p_node->op != SpvOpVariable) || (p_node->storage_class != SpvStorageClassPushConstant)) {
2973       continue;
2974     }
2975 
2976     p_module->push_constant_block_count += 1;
2977   }
2978 
2979   if (p_module->push_constant_block_count == 0) {
2980     return SPV_REFLECT_RESULT_SUCCESS;
2981   }
2982 
2983   p_module->push_constant_blocks = (SpvReflectBlockVariable*)calloc(p_module->push_constant_block_count, sizeof(*p_module->push_constant_blocks));
2984   if (IsNull(p_module->push_constant_blocks)) {
2985     return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
2986   }
2987 
2988   uint32_t push_constant_index = 0;
2989   for (size_t i = 0; i < p_parser->node_count; ++i) {
2990     Node* p_node = &(p_parser->nodes[i]);
2991     if ((p_node->op != SpvOpVariable) || (p_node->storage_class != SpvStorageClassPushConstant)) {
2992       continue;
2993     }
2994 
2995     SpvReflectTypeDescription* p_type = FindType(p_module, p_node->type_id);
2996     if (IsNull(p_node)) {
2997       return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
2998     }
2999     // If the type is a pointer, resolve it
3000     if (p_type->op == SpvOpTypePointer) {
3001       // Find the type's node
3002       Node* p_type_node = FindNode(p_parser, p_type->id);
3003       if (IsNull(p_type_node)) {
3004         return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
3005       }
3006       // Should be the resolved type
3007       p_type = FindType(p_module, p_type_node->type_id);
3008       if (IsNull(p_type)) {
3009         return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
3010       }
3011     }
3012 
3013     Node* p_type_node = FindNode(p_parser, p_type->id);
3014     if (IsNull(p_type_node)) {
3015       return SPV_REFLECT_RESULT_ERROR_SPIRV_INVALID_ID_REFERENCE;
3016     }
3017 
3018     SpvReflectBlockVariable* p_push_constant = &p_module->push_constant_blocks[push_constant_index];
3019     p_push_constant->spirv_id = p_node->result_id;
3020     SpvReflectResult result = ParseDescriptorBlockVariable(p_parser, p_module, p_type, p_push_constant);
3021     if (result != SPV_REFLECT_RESULT_SUCCESS) {
3022       return result;
3023     }
3024     result = ParseDescriptorBlockVariableSizes(p_parser, p_module, true, false, false, p_push_constant);
3025     if (result != SPV_REFLECT_RESULT_SUCCESS) {
3026       return result;
3027     }
3028 
3029     ++push_constant_index;
3030   }
3031 
3032   return SPV_REFLECT_RESULT_SUCCESS;
3033 }
3034 
SortCompareDescriptorSet(const void * a,const void * b)3035 static int SortCompareDescriptorSet(const void* a, const void* b)
3036 {
3037   const SpvReflectDescriptorSet* p_elem_a = (const SpvReflectDescriptorSet*)a;
3038   const SpvReflectDescriptorSet* p_elem_b = (const SpvReflectDescriptorSet*)b;
3039   int value = (int)(p_elem_a->set) - (int)(p_elem_b->set);
3040   // We should never see duplicate descriptor set numbers in a shader; if so, a tiebreaker
3041   // would be needed here.
3042   assert(value != 0);
3043   return value;
3044 }
3045 
ParseEntrypointDescriptorSets(SpvReflectShaderModule * p_module)3046 static SpvReflectResult ParseEntrypointDescriptorSets(SpvReflectShaderModule* p_module) {
3047   // Update the entry point's sets
3048   for (uint32_t i = 0; i < p_module->entry_point_count; ++i) {
3049     SpvReflectEntryPoint* p_entry = &p_module->entry_points[i];
3050     for (uint32_t j = 0; j < p_entry->descriptor_set_count; ++j) {
3051       SafeFree(p_entry->descriptor_sets[j].bindings);
3052     }
3053     SafeFree(p_entry->descriptor_sets);
3054     p_entry->descriptor_set_count = 0;
3055     for (uint32_t j = 0; j < p_module->descriptor_set_count; ++j) {
3056       const SpvReflectDescriptorSet* p_set = &p_module->descriptor_sets[j];
3057       for (uint32_t k = 0; k < p_set->binding_count; ++k) {
3058         bool found = SearchSortedUint32(
3059           p_entry->used_uniforms,
3060           p_entry->used_uniform_count,
3061           p_set->bindings[k]->spirv_id);
3062         if (found) {
3063           ++p_entry->descriptor_set_count;
3064           break;
3065         }
3066       }
3067     }
3068 
3069     p_entry->descriptor_sets = NULL;
3070     if (p_entry->descriptor_set_count > 0) {
3071       p_entry->descriptor_sets = (SpvReflectDescriptorSet*)calloc(p_entry->descriptor_set_count,
3072                                                                   sizeof(*p_entry->descriptor_sets));
3073       if (IsNull(p_entry->descriptor_sets)) {
3074         return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
3075       }
3076     }
3077     p_entry->descriptor_set_count = 0;
3078     for (uint32_t j = 0; j < p_module->descriptor_set_count; ++j) {
3079       const SpvReflectDescriptorSet* p_set = &p_module->descriptor_sets[j];
3080       uint32_t count = 0;
3081       for (uint32_t k = 0; k < p_set->binding_count; ++k) {
3082         bool found = SearchSortedUint32(
3083           p_entry->used_uniforms,
3084           p_entry->used_uniform_count,
3085           p_set->bindings[k]->spirv_id);
3086         if (found) {
3087           ++count;
3088         }
3089       }
3090       if (count == 0) {
3091         continue;
3092       }
3093       SpvReflectDescriptorSet* p_entry_set = &p_entry->descriptor_sets[
3094           p_entry->descriptor_set_count++];
3095       p_entry_set->set = p_set->set;
3096       p_entry_set->bindings = (SpvReflectDescriptorBinding**)calloc(count,
3097                                                                     sizeof(*p_entry_set->bindings));
3098       if (IsNull(p_entry_set->bindings)) {
3099         return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
3100       }
3101       for (uint32_t k = 0; k < p_set->binding_count; ++k) {
3102         bool found = SearchSortedUint32(
3103           p_entry->used_uniforms,
3104           p_entry->used_uniform_count,
3105           p_set->bindings[k]->spirv_id);
3106         if (found) {
3107           p_entry_set->bindings[p_entry_set->binding_count++] = p_set->bindings[k];
3108         }
3109       }
3110     }
3111   }
3112 
3113   return SPV_REFLECT_RESULT_SUCCESS;
3114 }
3115 
ParseDescriptorSets(SpvReflectShaderModule * p_module)3116 static SpvReflectResult ParseDescriptorSets(SpvReflectShaderModule* p_module)
3117 {
3118   // Count the descriptors in each set
3119   for (uint32_t i = 0; i < p_module->descriptor_binding_count; ++i) {
3120     SpvReflectDescriptorBinding* p_descriptor = &(p_module->descriptor_bindings[i]);
3121 
3122     // Look for a target set using the descriptor's set number
3123     SpvReflectDescriptorSet* p_target_set = NULL;
3124     for (uint32_t j = 0; j < SPV_REFLECT_MAX_DESCRIPTOR_SETS; ++j) {
3125       SpvReflectDescriptorSet* p_set = &p_module->descriptor_sets[j];
3126       if (p_set->set == p_descriptor->set) {
3127         p_target_set = p_set;
3128         break;
3129       }
3130     }
3131 
3132     // If a target set isn't found, find the first available one.
3133     if (IsNull(p_target_set)) {
3134       for (uint32_t j = 0; j < SPV_REFLECT_MAX_DESCRIPTOR_SETS; ++j) {
3135         SpvReflectDescriptorSet* p_set = &p_module->descriptor_sets[j];
3136         if (p_set->set == (uint32_t)INVALID_VALUE) {
3137           p_target_set = p_set;
3138           p_target_set->set = p_descriptor->set;
3139           break;
3140         }
3141       }
3142     }
3143 
3144     if (IsNull(p_target_set)) {
3145       return SPV_REFLECT_RESULT_ERROR_INTERNAL_ERROR;
3146     }
3147 
3148     p_target_set->binding_count += 1;
3149   }
3150 
3151   // Count the descriptor sets
3152   for (uint32_t i = 0; i < SPV_REFLECT_MAX_DESCRIPTOR_SETS; ++i) {
3153     const SpvReflectDescriptorSet* p_set = &p_module->descriptor_sets[i];
3154     if (p_set->set != (uint32_t)INVALID_VALUE) {
3155       p_module->descriptor_set_count += 1;
3156     }
3157   }
3158 
3159   // Sort the descriptor sets based on numbers
3160   if (p_module->descriptor_set_count > 0) {
3161     qsort(p_module->descriptor_sets,
3162           p_module->descriptor_set_count,
3163           sizeof(*(p_module->descriptor_sets)),
3164           SortCompareDescriptorSet);
3165   }
3166 
3167   // Build descriptor pointer array
3168   for (uint32_t i = 0; i <p_module->descriptor_set_count; ++i) {
3169     SpvReflectDescriptorSet* p_set = &(p_module->descriptor_sets[i]);
3170     p_set->bindings = (SpvReflectDescriptorBinding **)calloc(p_set->binding_count, sizeof(*(p_set->bindings)));
3171 
3172     uint32_t descriptor_index = 0;
3173     for (uint32_t j = 0; j < p_module->descriptor_binding_count; ++j) {
3174       SpvReflectDescriptorBinding* p_descriptor = &(p_module->descriptor_bindings[j]);
3175       if (p_descriptor->set == p_set->set) {
3176         assert(descriptor_index < p_set->binding_count);
3177         p_set->bindings[descriptor_index] = p_descriptor;
3178         ++descriptor_index;
3179       }
3180     }
3181   }
3182 
3183   return ParseEntrypointDescriptorSets(p_module);
3184 }
3185 
DisambiguateStorageBufferSrvUav(SpvReflectShaderModule * p_module)3186 static SpvReflectResult DisambiguateStorageBufferSrvUav(SpvReflectShaderModule* p_module)
3187 {
3188   if (p_module->descriptor_binding_count == 0) {
3189     return SPV_REFLECT_RESULT_SUCCESS;
3190   }
3191 
3192   for (uint32_t descriptor_index = 0; descriptor_index < p_module->descriptor_binding_count; ++descriptor_index) {
3193     SpvReflectDescriptorBinding* p_descriptor = &(p_module->descriptor_bindings[descriptor_index]);
3194     // Skip everything that isn't a STORAGE_BUFFER descriptor
3195     if (p_descriptor->descriptor_type != SPV_REFLECT_DESCRIPTOR_TYPE_STORAGE_BUFFER) {
3196       continue;
3197     }
3198 
3199     //
3200     // Vulkan doesn't disambiguate between SRVs and UAVs so they
3201     // come back as STORAGE_BUFFER. The block parsing process will
3202     // mark a block as non-writable should any member of the block
3203     // or its descendants are non-writable.
3204     //
3205     if (p_descriptor->block.decoration_flags & SPV_REFLECT_DECORATION_NON_WRITABLE) {
3206       p_descriptor->resource_type = SPV_REFLECT_RESOURCE_FLAG_SRV;
3207     }
3208   }
3209 
3210   return SPV_REFLECT_RESULT_SUCCESS;
3211 }
3212 
SynchronizeDescriptorSets(SpvReflectShaderModule * p_module)3213 static SpvReflectResult SynchronizeDescriptorSets(SpvReflectShaderModule* p_module)
3214 {
3215   // Free and reset all descriptor set numbers
3216   for (uint32_t i = 0; i < SPV_REFLECT_MAX_DESCRIPTOR_SETS; ++i) {
3217     SpvReflectDescriptorSet* p_set = &p_module->descriptor_sets[i];
3218     SafeFree(p_set->bindings);
3219     p_set->binding_count = 0;
3220     p_set->set = (uint32_t)INVALID_VALUE;
3221   }
3222   // Set descriptor set count to zero
3223   p_module->descriptor_set_count = 0;
3224 
3225   SpvReflectResult result = ParseDescriptorSets(p_module);
3226   return result;
3227 }
3228 
spvReflectGetShaderModule(size_t size,const void * p_code,SpvReflectShaderModule * p_module)3229 SpvReflectResult spvReflectGetShaderModule(
3230   size_t                   size,
3231   const void*              p_code,
3232   SpvReflectShaderModule*  p_module
3233 )
3234 {
3235   return spvReflectCreateShaderModule(size, p_code, p_module);
3236 }
3237 
spvReflectCreateShaderModule(size_t size,const void * p_code,SpvReflectShaderModule * p_module)3238 SpvReflectResult spvReflectCreateShaderModule(
3239   size_t                   size,
3240   const void*              p_code,
3241   SpvReflectShaderModule*  p_module
3242 )
3243 {
3244   // Initialize all module fields to zero
3245   memset(p_module, 0, sizeof(*p_module));
3246 
3247   // Allocate module internals
3248 #ifdef __cplusplus
3249   p_module->_internal = (SpvReflectShaderModule::Internal*)calloc(1, sizeof(*(p_module->_internal)));
3250 #else
3251   p_module->_internal = calloc(1, sizeof(*(p_module->_internal)));
3252 #endif
3253   if (IsNull(p_module->_internal)) {
3254     return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
3255   }
3256   // Allocate SPIR-V code storage
3257   p_module->_internal->spirv_size = size;
3258   p_module->_internal->spirv_code = (uint32_t*)calloc(1, p_module->_internal->spirv_size);
3259   p_module->_internal->spirv_word_count = (uint32_t)(size / SPIRV_WORD_SIZE);
3260   if (IsNull(p_module->_internal->spirv_code)) {
3261     SafeFree(p_module->_internal);
3262     return SPV_REFLECT_RESULT_ERROR_ALLOC_FAILED;
3263   }
3264   memcpy(p_module->_internal->spirv_code, p_code, size);
3265 
3266   Parser parser = { 0 };
3267   SpvReflectResult result = CreateParser(p_module->_internal->spirv_size,
3268                                          p_module->_internal->spirv_code,
3269                                          &parser);
3270 
3271   // Generator
3272   {
3273     const uint32_t* p_ptr = (const uint32_t*)p_module->_internal->spirv_code;
3274     p_module->generator = (SpvReflectGenerator)((*(p_ptr + 2) & 0xFFFF0000) >> 16);
3275   }
3276 
3277   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3278     result = ParseNodes(&parser);
3279   }
3280   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3281     result = ParseStrings(&parser);
3282   }
3283   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3284     result = ParseSource(&parser, p_module);
3285   }
3286   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3287     result = ParseFunctions(&parser);
3288   }
3289   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3290     result = ParseMemberCounts(&parser);
3291   }
3292   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3293     result = ParseNames(&parser);
3294   }
3295   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3296     result = ParseDecorations(&parser);
3297   }
3298 
3299   // Start of reflection data parsing
3300   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3301     p_module->source_language = parser.source_language;
3302     p_module->source_language_version = parser.source_language_version;
3303 
3304     // Zero out descriptor set data
3305     p_module->descriptor_set_count = 0;
3306     memset(p_module->descriptor_sets, 0, SPV_REFLECT_MAX_DESCRIPTOR_SETS * sizeof(*p_module->descriptor_sets));
3307     // Initialize descriptor set numbers
3308     for (uint32_t set_number = 0; set_number < SPV_REFLECT_MAX_DESCRIPTOR_SETS; ++set_number) {
3309       p_module->descriptor_sets[set_number].set = (uint32_t)INVALID_VALUE;
3310     }
3311   }
3312   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3313     result = ParseTypes(&parser, p_module);
3314   }
3315   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3316     result = ParseDescriptorBindings(&parser, p_module);
3317   }
3318   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3319     result = ParseDescriptorType(p_module);
3320   }
3321   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3322     result = ParseUAVCounterBindings(p_module);
3323   }
3324   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3325     result = ParseDescriptorBlocks(&parser, p_module);
3326   }
3327   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3328     result = ParsePushConstantBlocks(&parser, p_module);
3329   }
3330   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3331     result = ParseEntryPoints(&parser, p_module);
3332   }
3333   if (result == SPV_REFLECT_RESULT_SUCCESS && p_module->entry_point_count > 0) {
3334     SpvReflectEntryPoint* p_entry = &(p_module->entry_points[0]);
3335     p_module->entry_point_name = p_entry->name;
3336     p_module->entry_point_id = p_entry->id;
3337     p_module->spirv_execution_model = p_entry->spirv_execution_model;
3338     p_module->shader_stage = p_entry->shader_stage;
3339     p_module->input_variable_count = p_entry->input_variable_count;
3340     p_module->input_variables = p_entry->input_variables;
3341     p_module->output_variable_count = p_entry->output_variable_count;
3342     p_module->output_variables = p_entry->output_variables;
3343   }
3344   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3345     result = DisambiguateStorageBufferSrvUav(p_module);
3346   }
3347   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3348     result = SynchronizeDescriptorSets(p_module);
3349   }
3350   if (result == SPV_REFLECT_RESULT_SUCCESS) {
3351     result = ParseExecutionModes(&parser, p_module);
3352   }
3353 
3354   // Destroy module if parse was not successful
3355   if (result != SPV_REFLECT_RESULT_SUCCESS) {
3356     spvReflectDestroyShaderModule(p_module);
3357   }
3358 
3359   DestroyParser(&parser);
3360 
3361   return result;
3362 }
3363 
/*
 * Recursively releases the member array owned by a type description.
 * Only the members array (and, depth-first, each member's own array) is
 * freed; the description struct itself is not. Afterwards
 * p_type->members is NULL. A NULL p_type, or one without members, is a no-op.
 */
static void SafeFreeTypes(SpvReflectTypeDescription* p_type)
{
  if (IsNull(p_type) || IsNull(p_type->members)) {
    return;
  }

  for (size_t idx = 0; idx < p_type->member_count; ++idx) {
    SafeFreeTypes(&p_type->members[idx]);
  }

  SafeFree(p_type->members);
  p_type->members = NULL;
}
3380 
/*
 * Recursively releases the member array owned by a block variable.
 * Members are released depth-first, then the block's own members array; the
 * block struct itself is not freed. Afterwards p_block->members is NULL.
 * A NULL p_block, or one without members, is a no-op.
 */
static void SafeFreeBlockVariables(SpvReflectBlockVariable* p_block)
{
  if (IsNull(p_block) || IsNull(p_block->members)) {
    return;
  }

  for (size_t idx = 0; idx < p_block->member_count; ++idx) {
    SafeFreeBlockVariables(&p_block->members[idx]);
  }

  SafeFree(p_block->members);
  p_block->members = NULL;
}
3397 
/*
 * Recursively releases the member array owned by an interface variable.
 * Members are released depth-first, then the variable's own members array;
 * the variable struct itself is not freed. Afterwards p_interface->members is
 * NULL. A NULL p_interface, or one without members, is a no-op.
 */
static void SafeFreeInterfaceVariable(SpvReflectInterfaceVariable* p_interface)
{
  if (IsNull(p_interface) || IsNull(p_interface->members)) {
    return;
  }

  for (size_t idx = 0; idx < p_interface->member_count; ++idx) {
    SafeFreeInterfaceVariable(&p_interface->members[idx]);
  }

  SafeFree(p_interface->members);
  p_interface->members = NULL;
}
3414 
/*
 * Releases every allocation owned by a module filled in by
 * spvReflectCreateShaderModule.
 *
 * Calling it with a NULL module, or with a module whose _internal is NULL
 * (e.g. one that spvReflectCreateShaderModule zeroed and then failed to
 * allocate internals for), does nothing.
 *
 * The module-level input_variables/output_variables pointers alias the first
 * entry point's arrays, so they are released with the entry points rather
 * than separately.
 */
void spvReflectDestroyShaderModule(SpvReflectShaderModule* p_module)
{
  if (IsNull(p_module) || IsNull(p_module->_internal)) {
    return;
  }

  // Descriptor set binding pointer arrays. The bindings they point at live in
  // p_module->descriptor_bindings and are released below.
  for (size_t i = 0; i < p_module->descriptor_set_count; ++i) {
    SpvReflectDescriptorSet* p_set = &p_module->descriptor_sets[i];
    SafeFree(p_set->bindings);
  }

  // Descriptor binding blocks
  for (size_t i = 0; i < p_module->descriptor_binding_count; ++i) {
    SpvReflectDescriptorBinding* p_descriptor = &p_module->descriptor_bindings[i];
    SafeFreeBlockVariables(&p_descriptor->block);
  }
  SafeFree(p_module->descriptor_bindings);

  // Entry points
  for (size_t i = 0; i < p_module->entry_point_count; ++i) {
    SpvReflectEntryPoint* p_entry = &p_module->entry_points[i];
    for (size_t j = 0; j < p_entry->input_variable_count; j++) {
      SafeFreeInterfaceVariable(&p_entry->input_variables[j]);
    }
    for (size_t j = 0; j < p_entry->output_variable_count; j++) {
      SafeFreeInterfaceVariable(&p_entry->output_variables[j]);
    }
    for (uint32_t j = 0; j < p_entry->descriptor_set_count; ++j) {
      SafeFree(p_entry->descriptor_sets[j].bindings);
    }
    SafeFree(p_entry->descriptor_sets);
    SafeFree(p_entry->input_variables);
    SafeFree(p_entry->output_variables);
    SafeFree(p_entry->used_uniforms);
    SafeFree(p_entry->used_push_constants);
  }
  SafeFree(p_module->entry_points);

  // Push constants
  for (size_t i = 0; i < p_module->push_constant_block_count; ++i) {
    SafeFreeBlockVariables(&p_module->push_constant_blocks[i]);
  }
  SafeFree(p_module->push_constant_blocks);

  // Type infos. SafeFreeTypes releases each member array recursively and
  // resets it to NULL, and it ignores descriptions without members.
  for (size_t i = 0; i < p_module->_internal->type_description_count; ++i) {
    SafeFreeTypes(&p_module->_internal->type_descriptions[i]);
  }
  SafeFree(p_module->_internal->type_descriptions);

  // Free SPIR-V code
  SafeFree(p_module->_internal->spirv_code);
  // Free internal
  SafeFree(p_module->_internal);
}
3475 
spvReflectGetCodeSize(const SpvReflectShaderModule * p_module)3476 uint32_t spvReflectGetCodeSize(const SpvReflectShaderModule* p_module)
3477 {
3478   if (IsNull(p_module)) {
3479     return 0;
3480   }
3481 
3482   return (uint32_t)(p_module->_internal->spirv_size);
3483 }
3484 
spvReflectGetCode(const SpvReflectShaderModule * p_module)3485 const uint32_t* spvReflectGetCode(const SpvReflectShaderModule* p_module)
3486 {
3487   if (IsNull(p_module)) {
3488     return NULL;
3489   }
3490 
3491   return p_module->_internal->spirv_code;
3492 }
3493 
spvReflectGetEntryPoint(const SpvReflectShaderModule * p_module,const char * entry_point)3494 const SpvReflectEntryPoint* spvReflectGetEntryPoint(
3495   const SpvReflectShaderModule* p_module,
3496   const char*                   entry_point
3497 ) {
3498   if (IsNull(p_module) || IsNull(entry_point)) {
3499     return NULL;
3500   }
3501 
3502   for (uint32_t i = 0; i < p_module->entry_point_count; ++i) {
3503     if (strcmp(p_module->entry_points[i].name, entry_point) == 0) {
3504       return &p_module->entry_points[i];
3505     }
3506   }
3507   return NULL;
3508 }
3509 
spvReflectEnumerateDescriptorBindings(const SpvReflectShaderModule * p_module,uint32_t * p_count,SpvReflectDescriptorBinding ** pp_bindings)3510 SpvReflectResult spvReflectEnumerateDescriptorBindings(
3511   const SpvReflectShaderModule*  p_module,
3512   uint32_t*                      p_count,
3513   SpvReflectDescriptorBinding**  pp_bindings
3514 )
3515 {
3516   if (IsNull(p_module)) {
3517     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3518   }
3519   if (IsNull(p_count)) {
3520     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3521   }
3522 
3523   if (IsNotNull(pp_bindings)) {
3524     if (*p_count != p_module->descriptor_binding_count) {
3525       return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
3526     }
3527 
3528     for (uint32_t index = 0; index < *p_count; ++index) {
3529       SpvReflectDescriptorBinding* p_bindings = (SpvReflectDescriptorBinding*)&p_module->descriptor_bindings[index];
3530       pp_bindings[index] = p_bindings;
3531     }
3532   }
3533   else {
3534     *p_count = p_module->descriptor_binding_count;
3535   }
3536 
3537   return SPV_REFLECT_RESULT_SUCCESS;
3538 }
3539 
spvReflectEnumerateEntryPointDescriptorBindings(const SpvReflectShaderModule * p_module,const char * entry_point,uint32_t * p_count,SpvReflectDescriptorBinding ** pp_bindings)3540 SpvReflectResult spvReflectEnumerateEntryPointDescriptorBindings(
3541   const SpvReflectShaderModule*  p_module,
3542   const char*                    entry_point,
3543   uint32_t*                      p_count,
3544   SpvReflectDescriptorBinding**  pp_bindings
3545 )
3546 {
3547   if (IsNull(p_module)) {
3548     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3549   }
3550   if (IsNull(p_count)) {
3551     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3552   }
3553 
3554   const SpvReflectEntryPoint* p_entry =
3555       spvReflectGetEntryPoint(p_module, entry_point);
3556   if (IsNull(p_entry)) {
3557     return SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
3558   }
3559 
3560   uint32_t count = 0;
3561   for (uint32_t i = 0; i < p_module->descriptor_binding_count; ++i) {
3562     bool found = SearchSortedUint32(
3563       p_entry->used_uniforms,
3564       p_entry->used_uniform_count,
3565       p_module->descriptor_bindings[i].spirv_id);
3566     if (found) {
3567       if (IsNotNull(pp_bindings)) {
3568         if (count >= *p_count) {
3569           return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
3570         }
3571         pp_bindings[count++] = (SpvReflectDescriptorBinding*)&p_module->descriptor_bindings[i];
3572       } else {
3573         ++count;
3574       }
3575     }
3576   }
3577   if (IsNotNull(pp_bindings)) {
3578     if (count != *p_count) {
3579       return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
3580     }
3581   } else {
3582     *p_count = count;
3583   }
3584   return SPV_REFLECT_RESULT_SUCCESS;
3585 }
3586 
spvReflectEnumerateDescriptorSets(const SpvReflectShaderModule * p_module,uint32_t * p_count,SpvReflectDescriptorSet ** pp_sets)3587 SpvReflectResult spvReflectEnumerateDescriptorSets(
3588   const SpvReflectShaderModule* p_module,
3589   uint32_t*                     p_count,
3590   SpvReflectDescriptorSet**     pp_sets
3591 )
3592 {
3593   if (IsNull(p_module)) {
3594     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3595   }
3596   if (IsNull(p_count)) {
3597     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3598   }
3599 
3600   if (IsNotNull(pp_sets)) {
3601     if (*p_count != p_module->descriptor_set_count) {
3602       return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
3603     }
3604 
3605     for (uint32_t index = 0; index < *p_count; ++index) {
3606       SpvReflectDescriptorSet* p_set = (SpvReflectDescriptorSet*)&p_module->descriptor_sets[index];
3607       pp_sets[index] = p_set;
3608     }
3609   }
3610   else {
3611     *p_count = p_module->descriptor_set_count;
3612   }
3613 
3614   return SPV_REFLECT_RESULT_SUCCESS;
3615 }
3616 
spvReflectEnumerateEntryPointDescriptorSets(const SpvReflectShaderModule * p_module,const char * entry_point,uint32_t * p_count,SpvReflectDescriptorSet ** pp_sets)3617 SpvReflectResult spvReflectEnumerateEntryPointDescriptorSets(
3618   const SpvReflectShaderModule* p_module,
3619   const char*                   entry_point,
3620   uint32_t*                     p_count,
3621   SpvReflectDescriptorSet**     pp_sets
3622 )
3623 {
3624   if (IsNull(p_module)) {
3625     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3626   }
3627   if (IsNull(p_count)) {
3628     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3629   }
3630 
3631   const SpvReflectEntryPoint* p_entry =
3632       spvReflectGetEntryPoint(p_module, entry_point);
3633   if (IsNull(p_entry)) {
3634     return SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
3635   }
3636 
3637   if (IsNotNull(pp_sets)) {
3638     if (*p_count != p_entry->descriptor_set_count) {
3639       return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
3640     }
3641 
3642     for (uint32_t index = 0; index < *p_count; ++index) {
3643       SpvReflectDescriptorSet* p_set = (SpvReflectDescriptorSet*)&p_entry->descriptor_sets[index];
3644       pp_sets[index] = p_set;
3645     }
3646   }
3647   else {
3648     *p_count = p_entry->descriptor_set_count;
3649   }
3650 
3651   return SPV_REFLECT_RESULT_SUCCESS;
3652 }
3653 
spvReflectEnumerateInputVariables(const SpvReflectShaderModule * p_module,uint32_t * p_count,SpvReflectInterfaceVariable ** pp_variables)3654 SpvReflectResult spvReflectEnumerateInputVariables(
3655   const SpvReflectShaderModule* p_module,
3656   uint32_t*                     p_count,
3657   SpvReflectInterfaceVariable** pp_variables
3658 )
3659 {
3660   if (IsNull(p_module)) {
3661     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3662   }
3663   if (IsNull(p_count)) {
3664     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3665   }
3666 
3667   if (IsNotNull(pp_variables)) {
3668     if (*p_count != p_module->input_variable_count) {
3669       return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
3670     }
3671 
3672     for (uint32_t index = 0; index < *p_count; ++index) {
3673       SpvReflectInterfaceVariable* p_var = (SpvReflectInterfaceVariable*)&p_module->input_variables[index];
3674       pp_variables[index] = p_var;
3675     }
3676   }
3677   else {
3678     *p_count = p_module->input_variable_count;
3679   }
3680 
3681   return SPV_REFLECT_RESULT_SUCCESS;
3682 }
3683 
spvReflectEnumerateEntryPointInputVariables(const SpvReflectShaderModule * p_module,const char * entry_point,uint32_t * p_count,SpvReflectInterfaceVariable ** pp_variables)3684 SpvReflectResult spvReflectEnumerateEntryPointInputVariables(
3685   const SpvReflectShaderModule* p_module,
3686   const char*                   entry_point,
3687   uint32_t*                     p_count,
3688   SpvReflectInterfaceVariable** pp_variables
3689 )
3690 {
3691   if (IsNull(p_module)) {
3692     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3693   }
3694   if (IsNull(p_count)) {
3695     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3696   }
3697 
3698   const SpvReflectEntryPoint* p_entry =
3699       spvReflectGetEntryPoint(p_module, entry_point);
3700   if (IsNull(p_entry)) {
3701     return SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
3702   }
3703 
3704   if (IsNotNull(pp_variables)) {
3705     if (*p_count != p_entry->input_variable_count) {
3706       return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
3707     }
3708 
3709     for (uint32_t index = 0; index < *p_count; ++index) {
3710       SpvReflectInterfaceVariable* p_var = (SpvReflectInterfaceVariable*)&p_entry->input_variables[index];
3711       pp_variables[index] = p_var;
3712     }
3713   }
3714   else {
3715     *p_count = p_entry->input_variable_count;
3716   }
3717 
3718   return SPV_REFLECT_RESULT_SUCCESS;
3719 }
3720 
spvReflectEnumerateOutputVariables(const SpvReflectShaderModule * p_module,uint32_t * p_count,SpvReflectInterfaceVariable ** pp_variables)3721 SpvReflectResult spvReflectEnumerateOutputVariables(
3722   const SpvReflectShaderModule* p_module,
3723   uint32_t*                     p_count,
3724   SpvReflectInterfaceVariable** pp_variables
3725 )
3726 {
3727   if (IsNull(p_module)) {
3728     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3729   }
3730   if (IsNull(p_count)) {
3731     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3732   }
3733 
3734   if (IsNotNull(pp_variables)) {
3735     if (*p_count != p_module->output_variable_count) {
3736       return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
3737     }
3738 
3739     for (uint32_t index = 0; index < *p_count; ++index) {
3740       SpvReflectInterfaceVariable* p_var = (SpvReflectInterfaceVariable*)&p_module->output_variables[index];
3741       pp_variables[index] = p_var;
3742     }
3743   }
3744   else {
3745     *p_count = p_module->output_variable_count;
3746   }
3747 
3748   return SPV_REFLECT_RESULT_SUCCESS;
3749 }
3750 
spvReflectEnumerateEntryPointOutputVariables(const SpvReflectShaderModule * p_module,const char * entry_point,uint32_t * p_count,SpvReflectInterfaceVariable ** pp_variables)3751 SpvReflectResult spvReflectEnumerateEntryPointOutputVariables(
3752   const SpvReflectShaderModule* p_module,
3753   const char*                   entry_point,
3754   uint32_t*                     p_count,
3755   SpvReflectInterfaceVariable** pp_variables
3756 )
3757 {
3758   if (IsNull(p_module)) {
3759     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3760   }
3761   if (IsNull(p_count)) {
3762     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3763   }
3764 
3765   const SpvReflectEntryPoint* p_entry =
3766       spvReflectGetEntryPoint(p_module, entry_point);
3767   if (IsNull(p_entry)) {
3768     return SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
3769   }
3770 
3771   if (IsNotNull(pp_variables)) {
3772     if (*p_count != p_entry->output_variable_count) {
3773       return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
3774     }
3775 
3776     for (uint32_t index = 0; index < *p_count; ++index) {
3777       SpvReflectInterfaceVariable* p_var = (SpvReflectInterfaceVariable*)&p_entry->output_variables[index];
3778       pp_variables[index] = p_var;
3779     }
3780   }
3781   else {
3782     *p_count = p_entry->output_variable_count;
3783   }
3784 
3785   return SPV_REFLECT_RESULT_SUCCESS;
3786 }
3787 
spvReflectEnumeratePushConstantBlocks(const SpvReflectShaderModule * p_module,uint32_t * p_count,SpvReflectBlockVariable ** pp_blocks)3788 SpvReflectResult spvReflectEnumeratePushConstantBlocks(
3789   const SpvReflectShaderModule* p_module,
3790   uint32_t*                     p_count,
3791   SpvReflectBlockVariable**     pp_blocks
3792 )
3793 {
3794   if (IsNull(p_module)) {
3795     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3796   }
3797   if (IsNull(p_count)) {
3798     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3799   }
3800 
3801   if (pp_blocks != NULL) {
3802     if (*p_count != p_module->push_constant_block_count) {
3803       return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
3804     }
3805 
3806     for (uint32_t index = 0; index < *p_count; ++index) {
3807       SpvReflectBlockVariable* p_push_constant_blocks = (SpvReflectBlockVariable*)&p_module->push_constant_blocks[index];
3808       pp_blocks[index] = p_push_constant_blocks;
3809     }
3810   }
3811   else {
3812     *p_count = p_module->push_constant_block_count;
3813   }
3814 
3815   return SPV_REFLECT_RESULT_SUCCESS;
3816 }
spvReflectEnumeratePushConstants(const SpvReflectShaderModule * p_module,uint32_t * p_count,SpvReflectBlockVariable ** pp_blocks)3817 SpvReflectResult spvReflectEnumeratePushConstants(
3818   const SpvReflectShaderModule* p_module,
3819   uint32_t*                     p_count,
3820   SpvReflectBlockVariable**     pp_blocks
3821 )
3822 {
3823   return spvReflectEnumeratePushConstantBlocks(p_module, p_count, pp_blocks);
3824 }
3825 
spvReflectEnumerateEntryPointPushConstantBlocks(const SpvReflectShaderModule * p_module,const char * entry_point,uint32_t * p_count,SpvReflectBlockVariable ** pp_blocks)3826 SpvReflectResult spvReflectEnumerateEntryPointPushConstantBlocks(
3827   const SpvReflectShaderModule* p_module,
3828   const char*                   entry_point,
3829   uint32_t*                     p_count,
3830   SpvReflectBlockVariable**     pp_blocks
3831 )
3832 {
3833   if (IsNull(p_module)) {
3834     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3835   }
3836   if (IsNull(p_count)) {
3837     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
3838   }
3839 
3840 
3841   const SpvReflectEntryPoint* p_entry =
3842       spvReflectGetEntryPoint(p_module, entry_point);
3843   if (IsNull(p_entry)) {
3844     return SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
3845   }
3846 
3847   uint32_t count = 0;
3848   for (uint32_t i = 0; i < p_module->push_constant_block_count; ++i) {
3849     bool found = SearchSortedUint32(p_entry->used_push_constants,
3850                            p_entry->used_push_constant_count,
3851                            p_module->push_constant_blocks[i].spirv_id);
3852     if (found) {
3853       if (IsNotNull(pp_blocks)) {
3854         if (count >= *p_count) {
3855           return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
3856         }
3857         pp_blocks[count++] = (SpvReflectBlockVariable*)&p_module->push_constant_blocks[i];
3858       } else {
3859         ++count;
3860       }
3861     }
3862   }
3863   if (IsNotNull(pp_blocks)) {
3864     if (count != *p_count) {
3865       return SPV_REFLECT_RESULT_ERROR_COUNT_MISMATCH;
3866     }
3867   } else {
3868     *p_count = count;
3869   }
3870   return SPV_REFLECT_RESULT_SUCCESS;
3871 }
3872 
spvReflectGetDescriptorBinding(const SpvReflectShaderModule * p_module,uint32_t binding_number,uint32_t set_number,SpvReflectResult * p_result)3873 const SpvReflectDescriptorBinding* spvReflectGetDescriptorBinding(
3874   const SpvReflectShaderModule* p_module,
3875   uint32_t                      binding_number,
3876   uint32_t                      set_number,
3877   SpvReflectResult*             p_result
3878 )
3879 {
3880   const SpvReflectDescriptorBinding* p_descriptor = NULL;
3881   if (IsNotNull(p_module)) {
3882     for (uint32_t index = 0; index < p_module->descriptor_binding_count; ++index) {
3883       const SpvReflectDescriptorBinding* p_potential = &p_module->descriptor_bindings[index];
3884       if ((p_potential->binding == binding_number) && (p_potential->set == set_number)) {
3885         p_descriptor = p_potential;
3886         break;
3887       }
3888     }
3889   }
3890   if (IsNotNull(p_result)) {
3891     *p_result = IsNotNull(p_descriptor)
3892         ?  SPV_REFLECT_RESULT_SUCCESS
3893         : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
3894                             : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
3895   }
3896   return p_descriptor;
3897 }
3898 
spvReflectGetEntryPointDescriptorBinding(const SpvReflectShaderModule * p_module,const char * entry_point,uint32_t binding_number,uint32_t set_number,SpvReflectResult * p_result)3899 const SpvReflectDescriptorBinding* spvReflectGetEntryPointDescriptorBinding(
3900   const SpvReflectShaderModule* p_module,
3901   const char*                   entry_point,
3902   uint32_t                      binding_number,
3903   uint32_t                      set_number,
3904   SpvReflectResult*             p_result
3905 )
3906 {
3907   const SpvReflectEntryPoint* p_entry =
3908       spvReflectGetEntryPoint(p_module, entry_point);
3909   if (IsNull(p_entry)) {
3910     if (IsNotNull(p_result)) {
3911       *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
3912     }
3913     return NULL;
3914   }
3915   const SpvReflectDescriptorBinding* p_descriptor = NULL;
3916   if (IsNotNull(p_module)) {
3917     for (uint32_t index = 0; index < p_module->descriptor_binding_count; ++index) {
3918       const SpvReflectDescriptorBinding* p_potential = &p_module->descriptor_bindings[index];
3919       bool found = SearchSortedUint32(
3920         p_entry->used_uniforms,
3921         p_entry->used_uniform_count,
3922         p_potential->spirv_id);
3923       if ((p_potential->binding == binding_number) && (p_potential->set == set_number) && found) {
3924         p_descriptor = p_potential;
3925         break;
3926       }
3927     }
3928   }
3929   if (IsNotNull(p_result)) {
3930     *p_result = IsNotNull(p_descriptor)
3931         ?  SPV_REFLECT_RESULT_SUCCESS
3932         : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
3933                             : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
3934   }
3935   return p_descriptor;
3936 }
3937 
spvReflectGetDescriptorSet(const SpvReflectShaderModule * p_module,uint32_t set_number,SpvReflectResult * p_result)3938 const SpvReflectDescriptorSet* spvReflectGetDescriptorSet(
3939   const SpvReflectShaderModule* p_module,
3940   uint32_t                      set_number,
3941   SpvReflectResult*             p_result
3942 )
3943 {
3944   const SpvReflectDescriptorSet* p_set = NULL;
3945   if (IsNotNull(p_module)) {
3946     for (uint32_t index = 0; index < p_module->descriptor_set_count; ++index) {
3947       const SpvReflectDescriptorSet* p_potential = &p_module->descriptor_sets[index];
3948       if (p_potential->set == set_number) {
3949         p_set = p_potential;
3950       }
3951     }
3952   }
3953   if (IsNotNull(p_result)) {
3954     *p_result = IsNotNull(p_set)
3955         ?  SPV_REFLECT_RESULT_SUCCESS
3956         : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
3957                             : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
3958   }
3959   return p_set;
3960 }
3961 
spvReflectGetEntryPointDescriptorSet(const SpvReflectShaderModule * p_module,const char * entry_point,uint32_t set_number,SpvReflectResult * p_result)3962 const SpvReflectDescriptorSet* spvReflectGetEntryPointDescriptorSet(
3963   const SpvReflectShaderModule* p_module,
3964   const char*                   entry_point,
3965   uint32_t                      set_number,
3966   SpvReflectResult*             p_result)
3967 {
3968   const SpvReflectDescriptorSet* p_set = NULL;
3969   if (IsNotNull(p_module)) {
3970     const SpvReflectEntryPoint* p_entry = spvReflectGetEntryPoint(p_module, entry_point);
3971     if (IsNull(p_entry)) {
3972       if (IsNotNull(p_result)) {
3973         *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
3974       }
3975       return NULL;
3976     }
3977     for (uint32_t index = 0; index < p_entry->descriptor_set_count; ++index) {
3978       const SpvReflectDescriptorSet* p_potential = &p_entry->descriptor_sets[index];
3979       if (p_potential->set == set_number) {
3980         p_set = p_potential;
3981       }
3982     }
3983   }
3984   if (IsNotNull(p_result)) {
3985     *p_result = IsNotNull(p_set)
3986         ?  SPV_REFLECT_RESULT_SUCCESS
3987         : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
3988                             : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
3989   }
3990   return p_set;
3991 }
3992 
3993 
spvReflectGetInputVariableByLocation(const SpvReflectShaderModule * p_module,uint32_t location,SpvReflectResult * p_result)3994 const SpvReflectInterfaceVariable* spvReflectGetInputVariableByLocation(
3995   const SpvReflectShaderModule* p_module,
3996   uint32_t                      location,
3997   SpvReflectResult*             p_result
3998 )
3999 {
4000   if (location == INVALID_VALUE) {
4001     if (IsNotNull(p_result)) {
4002       *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4003     }
4004     return NULL;
4005   }
4006   const SpvReflectInterfaceVariable* p_var = NULL;
4007   if (IsNotNull(p_module)) {
4008     for (uint32_t index = 0; index < p_module->input_variable_count; ++index) {
4009       const SpvReflectInterfaceVariable* p_potential = &p_module->input_variables[index];
4010       if (p_potential->location == location) {
4011         p_var = p_potential;
4012       }
4013     }
4014   }
4015   if (IsNotNull(p_result)) {
4016     *p_result = IsNotNull(p_var)
4017         ?  SPV_REFLECT_RESULT_SUCCESS
4018         : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
4019                             : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
4020   }
4021   return p_var;
4022 }
spvReflectGetInputVariable(const SpvReflectShaderModule * p_module,uint32_t location,SpvReflectResult * p_result)4023 const SpvReflectInterfaceVariable* spvReflectGetInputVariable(
4024   const SpvReflectShaderModule* p_module,
4025   uint32_t                      location,
4026   SpvReflectResult*             p_result
4027 )
4028 {
4029   return spvReflectGetInputVariableByLocation(p_module, location, p_result);
4030 }
4031 
spvReflectGetEntryPointInputVariableByLocation(const SpvReflectShaderModule * p_module,const char * entry_point,uint32_t location,SpvReflectResult * p_result)4032 const SpvReflectInterfaceVariable* spvReflectGetEntryPointInputVariableByLocation(
4033   const SpvReflectShaderModule* p_module,
4034   const char*                   entry_point,
4035   uint32_t                      location,
4036   SpvReflectResult*             p_result
4037 )
4038 {
4039   if (location == INVALID_VALUE) {
4040     if (IsNotNull(p_result)) {
4041       *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4042     }
4043     return NULL;
4044   }
4045 
4046   const SpvReflectInterfaceVariable* p_var = NULL;
4047   if (IsNotNull(p_module)) {
4048     const SpvReflectEntryPoint* p_entry =
4049         spvReflectGetEntryPoint(p_module, entry_point);
4050     if (IsNull(p_entry)) {
4051       if (IsNotNull(p_result)) {
4052         *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4053       }
4054       return NULL;
4055     }
4056     for (uint32_t index = 0; index < p_entry->input_variable_count; ++index) {
4057       const SpvReflectInterfaceVariable* p_potential = &p_entry->input_variables[index];
4058       if (p_potential->location == location) {
4059         p_var = p_potential;
4060       }
4061     }
4062   }
4063   if (IsNotNull(p_result)) {
4064     *p_result = IsNotNull(p_var)
4065         ?  SPV_REFLECT_RESULT_SUCCESS
4066         : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
4067                             : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
4068   }
4069   return p_var;
4070 }
4071 
spvReflectGetInputVariableBySemantic(const SpvReflectShaderModule * p_module,const char * semantic,SpvReflectResult * p_result)4072 const SpvReflectInterfaceVariable* spvReflectGetInputVariableBySemantic(
4073   const SpvReflectShaderModule* p_module,
4074   const char*                   semantic,
4075   SpvReflectResult*             p_result
4076 )
4077 {
4078   if (IsNull(semantic)) {
4079     if (IsNotNull(p_result)) {
4080       *p_result = SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
4081     }
4082     return NULL;
4083   }
4084   if (semantic[0] == '\0') {
4085     if (IsNotNull(p_result)) {
4086       *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4087     }
4088     return NULL;
4089   }
4090   const SpvReflectInterfaceVariable* p_var = NULL;
4091   if (IsNotNull(p_module)) {
4092     for (uint32_t index = 0; index < p_module->input_variable_count; ++index) {
4093       const SpvReflectInterfaceVariable* p_potential = &p_module->input_variables[index];
4094       if (p_potential->semantic != NULL && strcmp(p_potential->semantic, semantic) == 0) {
4095         p_var = p_potential;
4096       }
4097     }
4098   }
4099   if (IsNotNull(p_result)) {
4100     *p_result = IsNotNull(p_var)
4101       ?  SPV_REFLECT_RESULT_SUCCESS
4102       : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
4103         : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
4104   }
4105   return p_var;
4106 }
4107 
spvReflectGetEntryPointInputVariableBySemantic(const SpvReflectShaderModule * p_module,const char * entry_point,const char * semantic,SpvReflectResult * p_result)4108 const SpvReflectInterfaceVariable* spvReflectGetEntryPointInputVariableBySemantic(
4109   const SpvReflectShaderModule* p_module,
4110   const char*                   entry_point,
4111   const char*                   semantic,
4112   SpvReflectResult*             p_result
4113 )
4114 {
4115   if (IsNull(semantic)) {
4116     if (IsNotNull(p_result)) {
4117       *p_result = SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
4118     }
4119     return NULL;
4120   }
4121   if (semantic[0] == '\0') {
4122     if (IsNotNull(p_result)) {
4123       *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4124     }
4125     return NULL;
4126   }
4127   const SpvReflectInterfaceVariable* p_var = NULL;
4128   if (IsNotNull(p_module)) {
4129     const SpvReflectEntryPoint* p_entry = spvReflectGetEntryPoint(p_module, entry_point);
4130     if (IsNull(p_entry)) {
4131       if (IsNotNull(p_result)) {
4132         *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4133       }
4134       return NULL;
4135     }
4136     for (uint32_t index = 0; index < p_entry->input_variable_count; ++index) {
4137       const SpvReflectInterfaceVariable* p_potential = &p_entry->input_variables[index];
4138       if (p_potential->semantic != NULL && strcmp(p_potential->semantic, semantic) == 0) {
4139         p_var = p_potential;
4140       }
4141     }
4142   }
4143   if (IsNotNull(p_result)) {
4144     *p_result = IsNotNull(p_var)
4145       ?  SPV_REFLECT_RESULT_SUCCESS
4146       : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
4147         : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
4148   }
4149   return p_var;
4150 }
4151 
spvReflectGetOutputVariableByLocation(const SpvReflectShaderModule * p_module,uint32_t location,SpvReflectResult * p_result)4152 const SpvReflectInterfaceVariable* spvReflectGetOutputVariableByLocation(
4153   const SpvReflectShaderModule*  p_module,
4154   uint32_t                       location,
4155   SpvReflectResult*              p_result
4156 )
4157 {
4158   if (location == INVALID_VALUE) {
4159     if (IsNotNull(p_result)) {
4160       *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4161     }
4162     return NULL;
4163   }
4164   const SpvReflectInterfaceVariable* p_var = NULL;
4165   if (IsNotNull(p_module)) {
4166     for (uint32_t index = 0; index < p_module->output_variable_count; ++index) {
4167       const SpvReflectInterfaceVariable* p_potential = &p_module->output_variables[index];
4168       if (p_potential->location == location) {
4169         p_var = p_potential;
4170       }
4171     }
4172   }
4173   if (IsNotNull(p_result)) {
4174     *p_result = IsNotNull(p_var)
4175         ?  SPV_REFLECT_RESULT_SUCCESS
4176         : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
4177                             : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
4178   }
4179   return p_var;
4180 }
spvReflectGetOutputVariable(const SpvReflectShaderModule * p_module,uint32_t location,SpvReflectResult * p_result)4181 const SpvReflectInterfaceVariable* spvReflectGetOutputVariable(
4182   const SpvReflectShaderModule*  p_module,
4183   uint32_t                       location,
4184   SpvReflectResult*              p_result
4185 )
4186 {
4187   return spvReflectGetOutputVariableByLocation(p_module, location, p_result);
4188 }
4189 
spvReflectGetEntryPointOutputVariableByLocation(const SpvReflectShaderModule * p_module,const char * entry_point,uint32_t location,SpvReflectResult * p_result)4190 const SpvReflectInterfaceVariable* spvReflectGetEntryPointOutputVariableByLocation(
4191   const SpvReflectShaderModule* p_module,
4192   const char*                   entry_point,
4193   uint32_t                      location,
4194   SpvReflectResult*             p_result
4195 )
4196 {
4197   if (location == INVALID_VALUE) {
4198     if (IsNotNull(p_result)) {
4199       *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4200     }
4201     return NULL;
4202   }
4203 
4204   const SpvReflectInterfaceVariable* p_var = NULL;
4205   if (IsNotNull(p_module)) {
4206     const SpvReflectEntryPoint* p_entry = spvReflectGetEntryPoint(p_module, entry_point);
4207     if (IsNull(p_entry)) {
4208       if (IsNotNull(p_result)) {
4209         *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4210       }
4211       return NULL;
4212     }
4213     for (uint32_t index = 0; index < p_entry->output_variable_count; ++index) {
4214       const SpvReflectInterfaceVariable* p_potential = &p_entry->output_variables[index];
4215       if (p_potential->location == location) {
4216         p_var = p_potential;
4217       }
4218     }
4219   }
4220   if (IsNotNull(p_result)) {
4221     *p_result = IsNotNull(p_var)
4222         ?  SPV_REFLECT_RESULT_SUCCESS
4223         : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
4224                             : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
4225   }
4226   return p_var;
4227 }
4228 
spvReflectGetOutputVariableBySemantic(const SpvReflectShaderModule * p_module,const char * semantic,SpvReflectResult * p_result)4229 const SpvReflectInterfaceVariable* spvReflectGetOutputVariableBySemantic(
4230   const SpvReflectShaderModule*  p_module,
4231   const char*                    semantic,
4232   SpvReflectResult*              p_result
4233 )
4234 {
4235   if (IsNull(semantic)) {
4236     if (IsNotNull(p_result)) {
4237       *p_result = SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
4238     }
4239     return NULL;
4240   }
4241   if (semantic[0] == '\0') {
4242     if (IsNotNull(p_result)) {
4243       *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4244     }
4245     return NULL;
4246   }
4247   const SpvReflectInterfaceVariable* p_var = NULL;
4248   if (IsNotNull(p_module)) {
4249     for (uint32_t index = 0; index < p_module->output_variable_count; ++index) {
4250       const SpvReflectInterfaceVariable* p_potential = &p_module->output_variables[index];
4251       if (p_potential->semantic != NULL && strcmp(p_potential->semantic, semantic) == 0) {
4252         p_var = p_potential;
4253       }
4254     }
4255   }
4256   if (IsNotNull(p_result)) {
4257     *p_result = IsNotNull(p_var)
4258         ?  SPV_REFLECT_RESULT_SUCCESS
4259         : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
4260                             : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
4261   }
4262   return p_var;
4263 }
4264 
spvReflectGetEntryPointOutputVariableBySemantic(const SpvReflectShaderModule * p_module,const char * entry_point,const char * semantic,SpvReflectResult * p_result)4265 const SpvReflectInterfaceVariable* spvReflectGetEntryPointOutputVariableBySemantic(
4266   const SpvReflectShaderModule* p_module,
4267   const char*                   entry_point,
4268   const char*                   semantic,
4269   SpvReflectResult*             p_result)
4270 {
4271   if (IsNull(semantic)) {
4272     if (IsNotNull(p_result)) {
4273       *p_result = SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
4274     }
4275     return NULL;
4276   }
4277   if (semantic[0] == '\0') {
4278     if (IsNotNull(p_result)) {
4279       *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4280     }
4281     return NULL;
4282   }
4283   const SpvReflectInterfaceVariable* p_var = NULL;
4284   if (IsNotNull(p_module)) {
4285     const SpvReflectEntryPoint* p_entry = spvReflectGetEntryPoint(p_module, entry_point);
4286     if (IsNull(p_entry)) {
4287       if (IsNotNull(p_result)) {
4288         *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4289       }
4290       return NULL;
4291     }
4292     for (uint32_t index = 0; index < p_entry->output_variable_count; ++index) {
4293       const SpvReflectInterfaceVariable* p_potential = &p_entry->output_variables[index];
4294       if (p_potential->semantic != NULL && strcmp(p_potential->semantic, semantic) == 0) {
4295         p_var = p_potential;
4296       }
4297     }
4298   }
4299   if (IsNotNull(p_result)) {
4300     *p_result = IsNotNull(p_var)
4301       ?  SPV_REFLECT_RESULT_SUCCESS
4302       : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
4303         : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
4304   }
4305   return p_var;
4306 }
4307 
spvReflectGetPushConstantBlock(const SpvReflectShaderModule * p_module,uint32_t index,SpvReflectResult * p_result)4308 const SpvReflectBlockVariable* spvReflectGetPushConstantBlock(
4309   const SpvReflectShaderModule* p_module,
4310   uint32_t                      index,
4311   SpvReflectResult*             p_result
4312 )
4313 {
4314   const SpvReflectBlockVariable* p_push_constant = NULL;
4315   if (IsNotNull(p_module)) {
4316     if (index < p_module->push_constant_block_count) {
4317       p_push_constant = &p_module->push_constant_blocks[index];
4318     }
4319   }
4320   if (IsNotNull(p_result)) {
4321     *p_result = IsNotNull(p_push_constant)
4322         ?  SPV_REFLECT_RESULT_SUCCESS
4323         : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
4324                             : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
4325   }
4326   return p_push_constant;
4327 }
spvReflectGetPushConstant(const SpvReflectShaderModule * p_module,uint32_t index,SpvReflectResult * p_result)4328 const SpvReflectBlockVariable* spvReflectGetPushConstant(
4329   const SpvReflectShaderModule* p_module,
4330   uint32_t                      index,
4331   SpvReflectResult*             p_result
4332 )
4333 {
4334   return spvReflectGetPushConstantBlock(p_module, index, p_result);
4335 }
4336 
spvReflectGetEntryPointPushConstantBlock(const SpvReflectShaderModule * p_module,const char * entry_point,SpvReflectResult * p_result)4337 const SpvReflectBlockVariable* spvReflectGetEntryPointPushConstantBlock(
4338   const SpvReflectShaderModule*  p_module,
4339   const char*                    entry_point,
4340   SpvReflectResult*              p_result)
4341 {
4342   const SpvReflectBlockVariable* p_push_constant = NULL;
4343   if (IsNotNull(p_module)) {
4344     const SpvReflectEntryPoint* p_entry =
4345         spvReflectGetEntryPoint(p_module, entry_point);
4346     if (IsNull(p_entry)) {
4347       if (IsNotNull(p_result)) {
4348         *p_result = SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4349       }
4350       return NULL;
4351     }
4352     for (uint32_t i = 0; i < p_module->push_constant_block_count; ++i) {
4353       bool found = SearchSortedUint32(
4354         p_entry->used_push_constants,
4355         p_entry->used_push_constant_count,
4356         p_module->push_constant_blocks[i].spirv_id);
4357       if (found) {
4358         p_push_constant = &p_module->push_constant_blocks[i];
4359         break;
4360       }
4361     }
4362   }
4363   if (IsNotNull(p_result)) {
4364     *p_result = IsNotNull(p_push_constant)
4365         ?  SPV_REFLECT_RESULT_SUCCESS
4366         : (IsNull(p_module) ? SPV_REFLECT_RESULT_ERROR_NULL_POINTER
4367                             : SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND);
4368   }
4369   return p_push_constant;
4370 }
4371 
spvReflectChangeDescriptorBindingNumbers(SpvReflectShaderModule * p_module,const SpvReflectDescriptorBinding * p_binding,uint32_t new_binding_number,uint32_t new_set_binding)4372 SpvReflectResult spvReflectChangeDescriptorBindingNumbers(
4373   SpvReflectShaderModule*            p_module,
4374   const SpvReflectDescriptorBinding* p_binding,
4375   uint32_t                           new_binding_number,
4376   uint32_t                           new_set_binding
4377 )
4378 {
4379   if (IsNull(p_module)) {
4380     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
4381   }
4382   if (IsNull(p_binding)) {
4383     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
4384   }
4385 
4386   SpvReflectDescriptorBinding* p_target_descriptor = NULL;
4387   for (uint32_t index = 0; index < p_module->descriptor_binding_count; ++index) {
4388     if(&p_module->descriptor_bindings[index] == p_binding) {
4389       p_target_descriptor = &p_module->descriptor_bindings[index];
4390       break;
4391     }
4392   }
4393 
4394   if (IsNotNull(p_target_descriptor)) {
4395     if (p_target_descriptor->word_offset.binding > (p_module->_internal->spirv_word_count - 1)) {
4396       return SPV_REFLECT_RESULT_ERROR_RANGE_EXCEEDED;
4397     }
4398     // Binding number
4399     if (new_binding_number != (uint32_t)SPV_REFLECT_BINDING_NUMBER_DONT_CHANGE) {
4400       uint32_t* p_code = p_module->_internal->spirv_code + p_target_descriptor->word_offset.binding;
4401       *p_code = new_binding_number;
4402       p_target_descriptor->binding = new_binding_number;
4403     }
4404     // Set number
4405     if (new_set_binding != (uint32_t)SPV_REFLECT_SET_NUMBER_DONT_CHANGE) {
4406       uint32_t* p_code = p_module->_internal->spirv_code + p_target_descriptor->word_offset.set;
4407       *p_code = new_set_binding;
4408       p_target_descriptor->set = new_set_binding;
4409     }
4410   }
4411 
4412   SpvReflectResult result = SPV_REFLECT_RESULT_SUCCESS;
4413   if (new_set_binding != (uint32_t)SPV_REFLECT_SET_NUMBER_DONT_CHANGE) {
4414     result = SynchronizeDescriptorSets(p_module);
4415   }
4416   return result;
4417 }
spvReflectChangeDescriptorBindingNumber(SpvReflectShaderModule * p_module,const SpvReflectDescriptorBinding * p_descriptor_binding,uint32_t new_binding_number,uint32_t optional_new_set_number)4418 SpvReflectResult spvReflectChangeDescriptorBindingNumber(
4419   SpvReflectShaderModule*            p_module,
4420   const SpvReflectDescriptorBinding* p_descriptor_binding,
4421   uint32_t                           new_binding_number,
4422   uint32_t                           optional_new_set_number
4423 )
4424 {
4425   return spvReflectChangeDescriptorBindingNumbers(
4426     p_module,p_descriptor_binding,
4427     new_binding_number,
4428     optional_new_set_number);
4429 }
4430 
spvReflectChangeDescriptorSetNumber(SpvReflectShaderModule * p_module,const SpvReflectDescriptorSet * p_set,uint32_t new_set_number)4431 SpvReflectResult spvReflectChangeDescriptorSetNumber(
4432   SpvReflectShaderModule*        p_module,
4433   const SpvReflectDescriptorSet* p_set,
4434   uint32_t                       new_set_number
4435 )
4436 {
4437   if (IsNull(p_module)) {
4438     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
4439   }
4440   if (IsNull(p_set)) {
4441     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
4442   }
4443   SpvReflectDescriptorSet* p_target_set = NULL;
4444   for (uint32_t index = 0; index < SPV_REFLECT_MAX_DESCRIPTOR_SETS; ++index) {
4445     // The descriptor sets for specific entry points might not be in this set,
4446     // so just match on set index.
4447     if (p_module->descriptor_sets[index].set == p_set->set) {
4448       p_target_set = (SpvReflectDescriptorSet*)p_set;
4449       break;
4450     }
4451   }
4452 
4453   SpvReflectResult result = SPV_REFLECT_RESULT_SUCCESS;
4454   if (IsNotNull(p_target_set) && new_set_number != (uint32_t)SPV_REFLECT_SET_NUMBER_DONT_CHANGE) {
4455     for (uint32_t index = 0; index < p_target_set->binding_count; ++index) {
4456       SpvReflectDescriptorBinding* p_descriptor = p_target_set->bindings[index];
4457       if (p_descriptor->word_offset.set > (p_module->_internal->spirv_word_count - 1)) {
4458         return SPV_REFLECT_RESULT_ERROR_RANGE_EXCEEDED;
4459       }
4460 
4461       uint32_t* p_code = p_module->_internal->spirv_code + p_descriptor->word_offset.set;
4462       *p_code = new_set_number;
4463       p_descriptor->set = new_set_number;
4464     }
4465 
4466     result = SynchronizeDescriptorSets(p_module);
4467   }
4468 
4469   return result;
4470 }
4471 
ChangeVariableLocation(SpvReflectShaderModule * p_module,SpvReflectInterfaceVariable * p_variable,uint32_t new_location)4472 static SpvReflectResult ChangeVariableLocation(
4473   SpvReflectShaderModule*      p_module,
4474   SpvReflectInterfaceVariable* p_variable,
4475   uint32_t                     new_location
4476 )
4477 {
4478   if (p_variable->word_offset.location > (p_module->_internal->spirv_word_count - 1)) {
4479     return SPV_REFLECT_RESULT_ERROR_RANGE_EXCEEDED;
4480   }
4481   uint32_t* p_code = p_module->_internal->spirv_code + p_variable->word_offset.location;
4482   *p_code = new_location;
4483   p_variable->location = new_location;
4484   return SPV_REFLECT_RESULT_SUCCESS;
4485 }
4486 
spvReflectChangeInputVariableLocation(SpvReflectShaderModule * p_module,const SpvReflectInterfaceVariable * p_input_variable,uint32_t new_location)4487 SpvReflectResult spvReflectChangeInputVariableLocation(
4488   SpvReflectShaderModule*            p_module,
4489   const SpvReflectInterfaceVariable* p_input_variable,
4490   uint32_t                           new_location
4491 )
4492 {
4493   if (IsNull(p_module)) {
4494     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
4495   }
4496   if (IsNull(p_input_variable)) {
4497     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
4498   }
4499   for (uint32_t index = 0; index < p_module->input_variable_count; ++index) {
4500     if(&p_module->input_variables[index] == p_input_variable) {
4501       return ChangeVariableLocation(p_module, &p_module->input_variables[index], new_location);
4502     }
4503   }
4504   return SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4505 }
4506 
spvReflectChangeOutputVariableLocation(SpvReflectShaderModule * p_module,const SpvReflectInterfaceVariable * p_output_variable,uint32_t new_location)4507 SpvReflectResult spvReflectChangeOutputVariableLocation(
4508   SpvReflectShaderModule*             p_module,
4509   const SpvReflectInterfaceVariable*  p_output_variable,
4510   uint32_t                            new_location
4511 )
4512 {
4513   if (IsNull(p_module)) {
4514     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
4515   }
4516   if (IsNull(p_output_variable)) {
4517     return SPV_REFLECT_RESULT_ERROR_NULL_POINTER;
4518   }
4519   for (uint32_t index = 0; index < p_module->output_variable_count; ++index) {
4520     if(&p_module->output_variables[index] == p_output_variable) {
4521       return ChangeVariableLocation(p_module, &p_module->output_variables[index], new_location);
4522     }
4523   }
4524   return SPV_REFLECT_RESULT_ERROR_ELEMENT_NOT_FOUND;
4525 }
4526 
// Maps a SpvSourceLanguage value to a short display name. Any value without a
// dedicated case below (including SpvSourceLanguageMax) yields "".
const char* spvReflectSourceLanguage(SpvSourceLanguage source_lang)
{
  const char* p_name = "";
  switch (source_lang) {
    case SpvSourceLanguageUnknown:    p_name = "Unknown";    break;
    case SpvSourceLanguageESSL:       p_name = "ESSL";       break;
    case SpvSourceLanguageGLSL:       p_name = "GLSL";       break;
    case SpvSourceLanguageOpenCL_C:   p_name = "OpenCL_C";   break;
    case SpvSourceLanguageOpenCL_CPP: p_name = "OpenCL_CPP"; break;
    case SpvSourceLanguageHLSL:       p_name = "HLSL";       break;
    default:                                                 break;
  }
  return p_name;
}
4542