1 use spirv_cross_internal::{hlsl as lang, spirv};
2 
3 mod common;
4 use crate::common::words_from_bytes;
5 
/// Checks that `multiple_entry_points.cl.spv` reports exactly two entry
/// points, named `entry_1` and `entry_2`.
#[test]
fn ast_gets_multiple_entry_points() {
    let words = words_from_bytes(include_bytes!("shaders/multiple_entry_points.cl.spv"));
    let module = spirv::Module::from_words(words);
    let entry_points = spirv::Ast::<lang::Target>::parse(&module)
        .unwrap()
        .get_entry_points()
        .unwrap();

    // True when some reported entry point carries the given name.
    let has_entry = |wanted: &str| entry_points.iter().any(|entry| entry.name == wanted);

    assert_eq!(entry_points.len(), 2);
    assert!(has_entry("entry_1"));
    assert!(has_entry("entry_2"));
}
20 
/// Checks the resource inventory reported for `simple.vert.spv`:
/// - one uniform buffer named `uniform_buffer_object`;
/// - stage inputs `a_normal` and `a_position`;
/// - stage output `v_normal`;
/// - none of the other resource kinds inspected here.
#[test]
fn ast_gets_shader_resources() {
    let module =
        spirv::Module::from_words(words_from_bytes(include_bytes!("shaders/simple.vert.spv")));

    // Pull every inspected list out by name so all assertions read uniformly.
    let spirv::ShaderResources {
        uniform_buffers,
        storage_buffers,
        stage_inputs,
        stage_outputs,
        subpass_inputs,
        storage_images,
        sampled_images,
        atomic_counters,
        push_constant_buffers,
        separate_images,
        separate_samplers,
        ..
    } = spirv::Ast::<lang::Target>::parse(&module)
        .unwrap()
        .get_shader_resources()
        .unwrap();

    let has_input = |wanted: &str| stage_inputs.iter().any(|input| input.name == wanted);
    let has_output = |wanted: &str| stage_outputs.iter().any(|output| output.name == wanted);

    assert_eq!(uniform_buffers.len(), 1);
    assert_eq!(uniform_buffers[0].name, "uniform_buffer_object");
    assert_eq!(storage_buffers.len(), 0);
    assert_eq!(stage_inputs.len(), 2);
    assert!(has_input("a_normal"));
    assert!(has_input("a_position"));
    assert_eq!(stage_outputs.len(), 1);
    assert!(has_output("v_normal"));
    assert_eq!(subpass_inputs.len(), 0);
    assert_eq!(storage_images.len(), 0);
    assert_eq!(sampled_images.len(), 0);
    assert_eq!(atomic_counters.len(), 0);
    assert_eq!(push_constant_buffers.len(), 0);
    assert_eq!(separate_images.len(), 0);
    assert_eq!(separate_samplers.len(), 0);
}
59 
/// Reads the `DescriptorSet` decoration of the first stage input of
/// `simple.vert.spv` and expects the value 0.
#[test]
fn ast_gets_decoration() {
    let module =
        spirv::Module::from_words(words_from_bytes(include_bytes!("shaders/simple.vert.spv")));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let first_input_id = ast.get_shader_resources().unwrap().stage_inputs[0].id;

    // NOTE(review): stage inputs are usually decorated with `Location` rather
    // than `DescriptorSet`, so this presumably checks the undecorated default — confirm.
    let descriptor_set = ast
        .get_decoration(first_input_id, spirv::Decoration::DescriptorSet)
        .unwrap();
    assert_eq!(descriptor_set, 0);
}
72 
/// Writes a `DescriptorSet` decoration on the first stage input and checks
/// that reading it back returns the value just written.
#[test]
fn ast_sets_decoration() {
    let module =
        spirv::Module::from_words(words_from_bytes(include_bytes!("shaders/simple.vert.spv")));
    let mut ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let target_id = ast.get_shader_resources().unwrap().stage_inputs[0].id;
    let new_set = 3;

    ast.set_decoration(target_id, spirv::Decoration::DescriptorSet, new_set)
        .unwrap();

    let read_back = ast
        .get_decoration(target_id, spirv::Decoration::DescriptorSet)
        .unwrap();
    assert_eq!(read_back, new_set);
}
93 
/// Checks that the base type of the uniform buffer in `simple.vert.spv` is a
/// struct with two members and no array dimensions.
#[test]
fn ast_gets_type_member_types_and_array() {
    let module =
        spirv::Module::from_words(words_from_bytes(include_bytes!("shaders/simple.vert.spv")));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let uniform_buffers = ast.get_shader_resources().unwrap().uniform_buffers;

    // Fail directly with a descriptive message on a type mismatch, instead of
    // threading a bool through and reporting only `assertion failed: is_struct`.
    match ast.get_type(uniform_buffers[0].base_type_id).unwrap() {
        spirv::Type::Struct {
            member_types,
            array,
        } => {
            assert_eq!(member_types.len(), 2);
            assert_eq!(array.len(), 0);
        }
        _ => panic!("expected the uniform buffer's base type to be a struct"),
    }
}
116 
/// Checks array-dimension reporting on `array.vert.spv`. The uniform buffer's
/// base type must be a three-member struct. Its third member must be a float
/// type with `vecsize` 3, `columns` 1 and a single array dimension of 3.
#[test]
fn ast_gets_array_dimensions() {
    let module =
        spirv::Module::from_words(words_from_bytes(include_bytes!("shaders/array.vert.spv")));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let uniform_buffers = ast.get_shader_resources().unwrap().uniform_buffers;

    // Each unexpected variant panics with a message naming what was expected,
    // rather than surfacing as an opaque `assertion failed: is_struct/is_float`.
    match ast.get_type(uniform_buffers[0].base_type_id).unwrap() {
        spirv::Type::Struct { member_types, .. } => {
            assert_eq!(member_types.len(), 3);
            match ast.get_type(member_types[2]).unwrap() {
                spirv::Type::Float {
                    vecsize,
                    columns,
                    array,
                } => {
                    assert_eq!(vecsize, 3);
                    assert_eq!(columns, 1);
                    assert_eq!(array.len(), 1);
                    assert_eq!(array[0], 3);
                }
                _ => panic!("expected struct member 2 to be a float type"),
            }
        }
        _ => panic!("expected the uniform buffer's base type to be a struct"),
    }
}
150 
/// Checks the declared byte sizes of the uniform block in `simple.vert.spv`.
/// The whole struct must equal the sum of its two members' sizes: a
/// 64-byte `mat4` at index 0 and a 4-byte float at index 1.
#[test]
fn ast_gets_declared_struct_size_and_struct_member_size() {
    let module =
        spirv::Module::from_words(words_from_bytes(include_bytes!("shaders/simple.vert.spv")));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();
    let block_type = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;

    // 16 four-byte elements for the matrix, one four-byte float after it.
    let mat4_size = 4 * 16;
    let float_size = 4;

    let total = ast.get_declared_struct_size(block_type).unwrap();
    assert_eq!(total, mat4_size + float_size);

    let first_member = ast.get_declared_struct_member_size(block_type, 0).unwrap();
    assert_eq!(first_member, mat4_size);

    let second_member = ast.get_declared_struct_member_size(block_type, 1).unwrap();
    assert_eq!(second_member, float_size);
}
175 
/// Checks that member 0 of the uniform block in `simple.vert.spv` is named
/// `u_model_view_projection`.
#[test]
fn ast_gets_member_name() {
    let module =
        spirv::Module::from_words(words_from_bytes(include_bytes!("shaders/simple.vert.spv")));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let block_type = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;
    let first_member = ast.get_member_name(block_type, 0).unwrap();

    assert_eq!(first_member, "u_model_view_projection");
}
190 
/// Checks that member 1 of the uniform block in `simple.vert.spv` carries an
/// `Offset` decoration of 64 bytes.
#[test]
fn ast_gets_member_decoration() {
    let module =
        spirv::Module::from_words(words_from_bytes(include_bytes!("shaders/simple.vert.spv")));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let block_type = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;
    let offset = ast
        .get_member_decoration(block_type, 1, spirv::Decoration::Offset)
        .unwrap();

    assert_eq!(offset, 64);
}
209 
/// Overwrites the `Offset` decoration of member 1 of the uniform block and
/// checks that reading it back returns the new value.
#[test]
fn ast_sets_member_decoration() {
    let module =
        spirv::Module::from_words(words_from_bytes(include_bytes!("shaders/simple.vert.spv")));
    let mut ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let block_type = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;
    let new_offset = 128;

    ast.set_member_decoration(block_type, 1, spirv::Decoration::Offset, new_offset)
        .unwrap();

    let read_back = ast
        .get_member_decoration(block_type, 1, spirv::Decoration::Offset)
        .unwrap();
    assert_eq!(read_back, new_offset);
}
238 
/// Checks that the first specialization constant reported for
/// `specialization.comp.spv` has `constant_id` 10.
#[test]
fn ast_gets_specialization_constants() {
    let words = words_from_bytes(include_bytes!("shaders/specialization.comp.spv"));
    let comp = spirv::Module::from_words(words);
    let comp_ast = spirv::Ast::<lang::Target>::parse(&comp).unwrap();

    let constants = comp_ast.get_specialization_constants().unwrap();
    let first = &constants[0];

    assert_eq!(first.constant_id, 10);
}
248 
/// Checks the work-group-size specialization constants of
/// `workgroup.comp.spv`: x/y/z map to ids 7/8/9 with constant ids 5/10/15.
#[test]
fn ast_gets_work_group_size_specialization_constants() {
    let words = words_from_bytes(include_bytes!("shaders/workgroup.comp.spv"));
    let comp = spirv::Module::from_words(words);
    let comp_ast = spirv::Ast::<lang::Target>::parse(&comp).unwrap();

    // Shorthand for an expected (`id`, `constant_id`) pair.
    let spec = |id, constant_id| spirv::SpecializationConstant { id, constant_id };
    let expected = spirv::WorkGroupSizeSpecializationConstants {
        x: spec(7, 5),
        y: spec(8, 10),
        z: spec(9, 15),
    };

    let actual = comp_ast
        .get_work_group_size_specialization_constants()
        .unwrap();
    assert_eq!(actual, expected);
}
276 
/// Checks the active byte ranges reported for each of the two uniform buffers
/// in `two_ubo.vert.spv`. Each entry is a member index, its offset within
/// the buffer, and the number of bytes it spans.
#[test]
fn ast_gets_active_buffer_ranges() {
    let module =
        spirv::Module::from_words(words_from_bytes(include_bytes!("shaders/two_ubo.vert.spv")));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let uniform_buffers = ast.get_shader_resources().unwrap().uniform_buffers;
    assert_eq!(uniform_buffers.len(), 2);

    // Shorthand for an expected (index, offset, range) triple.
    let buffer_range = |index, offset, size| spirv::BufferRange {
        index,
        offset,
        range: size,
    };

    let first_ranges = ast.get_active_buffer_ranges(uniform_buffers[0].id).unwrap();
    assert_eq!(
        first_ranges,
        [
            buffer_range(0, 0, 64),
            buffer_range(1, 64, 16),
            buffer_range(2, 80, 32),
        ]
    );

    let second_ranges = ast.get_active_buffer_ranges(uniform_buffers[1].id).unwrap();
    assert_eq!(
        second_ranges,
        [
            buffer_range(0, 0, 16),
            buffer_range(1, 16, 16),
            buffer_range(2, 32, 12),
        ]
    );
}
330