1 use spirv_cross::{hlsl as lang, spirv};
2 
3 mod common;
4 use crate::common::words_from_bytes;
5 
/// Parses `shaders/simple.vert.spv` and checks that the AST reports exactly
/// one entry point, named `main`.
#[test]
fn ast_gets_entry_points() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let entry_points = ast.get_entry_points().unwrap();

    // Check the count first so a wrong count is reported as such rather than
    // as an index-out-of-bounds panic on the next line.
    assert_eq!(entry_points.len(), 1);
    assert_eq!(entry_points[0].name, "main");
}
18 
/// Parses `shaders/simple.vert.spv` and checks the reflected shader
/// resources: one uniform buffer, two stage inputs, one stage output, and
/// every other resource category inspected here empty.
#[test]
fn ast_gets_shader_resources() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();
    let resources = ast.get_shader_resources().unwrap();

    // Borrow the three populated categories; the rest are read in place.
    let ubos = &resources.uniform_buffers;
    let inputs = &resources.stage_inputs;
    let outputs = &resources.stage_outputs;

    // Name lookups that do not depend on the order resources are reported in.
    let has_input = |name: &str| inputs.iter().any(|resource| resource.name == name);
    let has_output = |name: &str| outputs.iter().any(|resource| resource.name == name);

    assert_eq!(ubos.len(), 1);
    assert_eq!(ubos[0].name, "uniform_buffer_object");
    assert_eq!(resources.storage_buffers.len(), 0);

    assert_eq!(inputs.len(), 2);
    assert!(has_input("a_normal"));
    assert!(has_input("a_position"));

    assert_eq!(outputs.len(), 1);
    assert!(has_output("v_normal"));

    assert_eq!(resources.subpass_inputs.len(), 0);
    assert_eq!(resources.storage_images.len(), 0);
    assert_eq!(resources.sampled_images.len(), 0);
    assert_eq!(resources.atomic_counters.len(), 0);
    assert_eq!(resources.push_constant_buffers.len(), 0);
    assert_eq!(resources.separate_images.len(), 0);
    assert_eq!(resources.separate_samplers.len(), 0);
}
57 
/// Parses `shaders/simple.vert.spv` and checks that reading the
/// `DescriptorSet` decoration of the first stage input yields `0`.
#[test]
fn ast_gets_decoration() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let resources = ast.get_shader_resources().unwrap();
    let input_id = resources.stage_inputs[0].id;

    let descriptor_set = ast
        .get_decoration(input_id, spirv::Decoration::DescriptorSet)
        .unwrap();
    assert_eq!(descriptor_set, 0);
}
70 
/// Parses `shaders/simple.vert.spv`, overwrites the `DescriptorSet`
/// decoration of the first stage input, and checks that reading it back
/// returns the value that was written.
#[test]
fn ast_sets_decoration() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let mut ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let input_id = ast.get_shader_resources().unwrap().stage_inputs[0].id;
    let updated_value = 3;

    ast.set_decoration(input_id, spirv::Decoration::DescriptorSet, updated_value)
        .unwrap();

    let read_back = ast
        .get_decoration(input_id, spirv::Decoration::DescriptorSet)
        .unwrap();
    assert_eq!(read_back, updated_value);
}
91 
/// Parses `shaders/simple.vert.spv` and checks that the base type of the
/// first uniform buffer is a struct with two members and no array
/// dimensions.
#[test]
fn ast_gets_type_member_types_and_array() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let resources = ast.get_shader_resources().unwrap();
    let ubo_type = ast
        .get_type(resources.uniform_buffers[0].base_type_id)
        .unwrap();

    // Any non-struct type falls through to `false` and fails the final assert.
    let is_struct = if let spirv::Type::Struct {
        member_types,
        array,
    } = ubo_type
    {
        assert_eq!(member_types.len(), 2);
        assert_eq!(array.len(), 0);
        true
    } else {
        false
    };

    assert!(is_struct);
}
114 
/// Parses `shaders/array.vert.spv` and checks that the third member of the
/// first uniform buffer's struct is a float type with a single array
/// dimension of length 3.
#[test]
fn ast_gets_array_dimensions() {
    let words = words_from_bytes(include_bytes!("shaders/array.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let resources = ast.get_shader_resources().unwrap();
    let ubo_type = ast
        .get_type(resources.uniform_buffers[0].base_type_id)
        .unwrap();

    let is_struct = if let spirv::Type::Struct { member_types, .. } = ubo_type {
        assert_eq!(member_types.len(), 3);

        // Inspect the last member; it must be a float array.
        let member_type = ast.get_type(member_types[2]).unwrap();
        let is_float = if let spirv::Type::Float { array } = member_type {
            assert_eq!(array.len(), 1);
            assert_eq!(array[0], 3);
            true
        } else {
            false
        };
        assert!(is_float);

        true
    } else {
        false
    };

    assert!(is_struct);
}
142 
/// Parses `shaders/simple.vert.spv` and checks the declared size of the
/// first uniform buffer's struct (64 + 4 bytes) as well as the declared
/// sizes of its two members (64 and 4 bytes).
#[test]
fn ast_gets_declared_struct_size_and_struct_member_size() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let resources = ast.get_shader_resources().unwrap();
    let struct_id = resources.uniform_buffers[0].base_type_id;

    // Expected sizes in bytes, named as in the assertions below.
    let mat4_size = 4 * 16;
    let float_size = 4;

    let struct_size = ast.get_declared_struct_size(struct_id).unwrap();
    assert_eq!(struct_size, mat4_size + float_size);

    let first_member_size = ast.get_declared_struct_member_size(struct_id, 0).unwrap();
    assert_eq!(first_member_size, mat4_size);

    let second_member_size = ast.get_declared_struct_member_size(struct_id, 1).unwrap();
    assert_eq!(second_member_size, float_size);
}
167 
/// Parses `shaders/simple.vert.spv` and checks that member 0 of the first
/// uniform buffer's struct is named `u_model_view_projection`.
#[test]
fn ast_gets_member_name() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let resources = ast.get_shader_resources().unwrap();
    let struct_id = resources.uniform_buffers[0].base_type_id;

    let member_name = ast.get_member_name(struct_id, 0).unwrap();
    assert_eq!(member_name, "u_model_view_projection");
}
182 
/// Parses `shaders/simple.vert.spv` and checks that member 1 of the first
/// uniform buffer's struct carries an `Offset` decoration of 64.
#[test]
fn ast_gets_member_decoration() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let resources = ast.get_shader_resources().unwrap();
    let struct_id = resources.uniform_buffers[0].base_type_id;

    let offset = ast
        .get_member_decoration(struct_id, 1, spirv::Decoration::Offset)
        .unwrap();
    assert_eq!(offset, 64);
}
201 
/// Parses `shaders/simple.vert.spv`, overwrites the `Offset` decoration on
/// member 1 of the first uniform buffer's struct, and checks that reading it
/// back returns the value that was written.
#[test]
fn ast_sets_member_decoration() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let mut ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let struct_id = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;
    let new_offset = 128;

    ast.set_member_decoration(struct_id, 1, spirv::Decoration::Offset, new_offset)
        .unwrap();

    let read_back = ast
        .get_member_decoration(struct_id, 1, spirv::Decoration::Offset)
        .unwrap();
    assert_eq!(read_back, new_offset);
}
230 
/// Parses `shaders/specialization.comp.spv` and checks that the first
/// reported specialization constant has `constant_id` 10.
#[test]
fn as_gets_specialization_constants() {
    let words = words_from_bytes(include_bytes!("shaders/specialization.comp.spv"));
    let comp = spirv::Module::from_words(words);
    let comp_ast = spirv::Ast::<lang::Target>::parse(&comp).unwrap();

    let constants = comp_ast.get_specialization_constants().unwrap();
    let first = &constants[0];
    assert_eq!(first.constant_id, 10);
}
240 
/// Parses `shaders/workgroup.comp.spv` and checks the specialization
/// constants reported for the x, y and z work group size components.
#[test]
fn as_gets_work_group_size_specialization_constants() {
    let words = words_from_bytes(include_bytes!("shaders/workgroup.comp.spv"));
    let comp = spirv::Module::from_words(words);
    let comp_ast = spirv::Ast::<lang::Target>::parse(&comp).unwrap();

    let actual = comp_ast
        .get_work_group_size_specialization_constants()
        .unwrap();

    // Expected (id, constant_id) pair for each dimension.
    let x = spirv::SpecializationConstant {
        id: 7,
        constant_id: 5,
    };
    let y = spirv::SpecializationConstant {
        id: 8,
        constant_id: 10,
    };
    let z = spirv::SpecializationConstant {
        id: 9,
        constant_id: 15,
    };
    let expected = spirv::WorkGroupSizeSpecializationConstants { x, y, z };

    assert_eq!(actual, expected);
}
268