1 use spirv_cross_internal::{hlsl as lang, spirv};
2
3 mod common;
4 use crate::common::words_from_bytes;
5
/// A module declaring two entry points reports exactly those two, by name.
#[test]
fn ast_gets_multiple_entry_points() {
    let bytes = include_bytes!("shaders/multiple_entry_points.cl.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let entry_points = spirv::Ast::<lang::Target>::parse(&module)
        .unwrap()
        .get_entry_points()
        .unwrap();

    assert_eq!(entry_points.len(), 2);

    // True when some reported entry point carries the given name.
    let has_entry = |wanted: &str| entry_points.iter().any(|entry| entry.name == wanted);
    assert!(has_entry("entry_1"));
    assert!(has_entry("entry_2"));
}
20
/// Reflecting `simple.vert` yields one uniform block, two stage inputs,
/// one stage output, and no resources of any other category.
#[test]
fn ast_gets_shader_resources() {
    let bytes = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let resources = spirv::Ast::<lang::Target>::parse(&module)
        .unwrap()
        .get_shader_resources()
        .unwrap();

    // Categories the shader populates.
    assert_eq!(resources.uniform_buffers.len(), 1);
    assert_eq!(resources.uniform_buffers[0].name, "uniform_buffer_object");
    assert_eq!(resources.storage_buffers.len(), 0);
    assert_eq!(resources.stage_inputs.len(), 2);
    for wanted in &["a_normal", "a_position"] {
        assert!(resources
            .stage_inputs
            .iter()
            .any(|input| input.name == *wanted));
    }
    assert_eq!(resources.stage_outputs.len(), 1);
    assert!(resources
        .stage_outputs
        .iter()
        .any(|output| output.name == "v_normal"));

    // Every remaining category is expected to be empty.
    assert_eq!(resources.subpass_inputs.len(), 0);
    assert_eq!(resources.storage_images.len(), 0);
    assert_eq!(resources.sampled_images.len(), 0);
    assert_eq!(resources.atomic_counters.len(), 0);
    assert_eq!(resources.push_constant_buffers.len(), 0);
    assert_eq!(resources.separate_images.len(), 0);
    assert_eq!(resources.separate_samplers.len(), 0);
}
59
/// Querying the `DescriptorSet` decoration of the first stage input returns 0.
#[test]
fn ast_gets_decoration() {
    let bytes = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let first_input_id = ast.get_shader_resources().unwrap().stage_inputs[0].id;
    let descriptor_set = ast
        .get_decoration(first_input_id, spirv::Decoration::DescriptorSet)
        .unwrap();

    assert_eq!(descriptor_set, 0);
}
72
/// A `DescriptorSet` decoration written via `set_decoration` is read back
/// unchanged by `get_decoration` on the same AST.
#[test]
fn ast_sets_decoration() {
    let bytes = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let mut ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let target_id = ast.get_shader_resources().unwrap().stage_inputs[0].id;
    let descriptor_set = 3;

    ast.set_decoration(target_id, spirv::Decoration::DescriptorSet, descriptor_set)
        .unwrap();

    let read_back = ast
        .get_decoration(target_id, spirv::Decoration::DescriptorSet)
        .unwrap();
    assert_eq!(read_back, descriptor_set);
}
93
/// The base type of the single uniform block in `simple.vert` is a
/// non-array struct with two members.
#[test]
fn ast_gets_type_member_types_and_array() {
    let module =
        spirv::Module::from_words(words_from_bytes(include_bytes!("shaders/simple.vert.spv")));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let uniform_buffers = ast.get_shader_resources().unwrap().uniform_buffers;

    // Match directly instead of recording a bool and asserting it afterwards.
    // A wrong variant then fails with a message naming the broken expectation,
    // rather than a bare `assertion failed: is_struct`.
    match ast.get_type(uniform_buffers[0].base_type_id).unwrap() {
        spirv::Type::Struct {
            member_types,
            array,
        } => {
            assert_eq!(member_types.len(), 2);
            assert_eq!(array.len(), 0);
        }
        _ => panic!("uniform buffer base type should be a struct"),
    }
}
116
/// In `array.vert`, the uniform block is a three-member struct. Its third
/// member is a float type with `vecsize` 3, `columns` 1, and one array
/// dimension of length 3.
#[test]
fn ast_gets_array_dimensions() {
    let module =
        spirv::Module::from_words(words_from_bytes(include_bytes!("shaders/array.vert.spv")));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let uniform_buffers = ast.get_shader_resources().unwrap().uniform_buffers;

    // Extract the struct's member types, failing loudly if the base type is
    // not a struct. This replaces the nested flag-and-assert pattern.
    let member_types = match ast.get_type(uniform_buffers[0].base_type_id).unwrap() {
        spirv::Type::Struct { member_types, .. } => member_types,
        _ => panic!("uniform buffer base type should be a struct"),
    };
    assert_eq!(member_types.len(), 3);

    match ast.get_type(member_types[2]).unwrap() {
        spirv::Type::Float {
            vecsize,
            columns,
            array,
        } => {
            assert_eq!(vecsize, 3);
            assert_eq!(columns, 1);
            assert_eq!(array.len(), 1);
            assert_eq!(array[0], 3);
        }
        _ => panic!("third uniform buffer member should be a float type"),
    }
}
150
/// The uniform block of `simple.vert` has a declared size equal to a mat4
/// plus a float. Member 0 has the mat4 size and member 1 the float size.
#[test]
fn ast_gets_declared_struct_size_and_struct_member_size() {
    let bytes = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();
    let ubo_type = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;

    // A mat4 is 16 four-byte floats.
    let float_size = 4;
    let mat4_size = 16 * float_size;

    let struct_size = ast.get_declared_struct_size(ubo_type).unwrap();
    assert_eq!(struct_size, mat4_size + float_size);

    let matrix_member = ast.get_declared_struct_member_size(ubo_type, 0).unwrap();
    assert_eq!(matrix_member, mat4_size);

    let scalar_member = ast.get_declared_struct_member_size(ubo_type, 1).unwrap();
    assert_eq!(scalar_member, float_size);
}
175
/// Member 0 of the uniform block in `simple.vert` is reported as
/// `u_model_view_projection`.
#[test]
fn ast_gets_member_name() {
    let bytes = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();
    let ubo_type = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;

    let first_member = ast.get_member_name(ubo_type, 0).unwrap();
    assert_eq!(first_member, "u_model_view_projection");
}
190
/// Member 1 of the uniform block in `simple.vert` has an `Offset`
/// decoration of 64.
#[test]
fn ast_gets_member_decoration() {
    let bytes = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();
    let ubo_type = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;

    let offset = ast
        .get_member_decoration(ubo_type, 1, spirv::Decoration::Offset)
        .unwrap();
    assert_eq!(offset, 64);
}
209
/// An `Offset` member decoration written via `set_member_decoration` is read
/// back unchanged by `get_member_decoration`.
#[test]
fn ast_sets_member_decoration() {
    let bytes = include_bytes!("shaders/simple.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let mut ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();
    let ubo_type = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;

    let member_index = 1;
    let new_offset = 128;

    ast.set_member_decoration(ubo_type, member_index, spirv::Decoration::Offset, new_offset)
        .unwrap();

    let read_back = ast
        .get_member_decoration(ubo_type, member_index, spirv::Decoration::Offset)
        .unwrap();
    assert_eq!(read_back, new_offset);
}
238
/// The first specialization constant reported for `specialization.comp`
/// has `constant_id` 10.
#[test]
fn ast_gets_specialization_constants() {
    let bytes = include_bytes!("shaders/specialization.comp.spv");
    let comp = spirv::Module::from_words(words_from_bytes(bytes));
    let comp_ast = spirv::Ast::<lang::Target>::parse(&comp).unwrap();

    let constants = comp_ast.get_specialization_constants().unwrap();
    let first = &constants[0];
    assert_eq!(first.constant_id, 10);
}
248
/// The work-group-size specialization constants of `workgroup.comp` match
/// the expected `(id, constant_id)` pairs for x, y and z.
#[test]
fn ast_gets_work_group_size_specialization_constants() {
    let bytes = include_bytes!("shaders/workgroup.comp.spv");
    let comp = spirv::Module::from_words(words_from_bytes(bytes));
    let comp_ast = spirv::Ast::<lang::Target>::parse(&comp).unwrap();

    let actual = comp_ast
        .get_work_group_size_specialization_constants()
        .unwrap();

    // Shorthand constructor keeps the expected value compact.
    let spec = |id, constant_id| spirv::SpecializationConstant { id, constant_id };
    let expected = spirv::WorkGroupSizeSpecializationConstants {
        x: spec(7, 5),
        y: spec(8, 10),
        z: spec(9, 15),
    };

    assert_eq!(actual, expected);
}
276
/// `two_ubo.vert` declares two uniform blocks. Each reports the expected
/// active `(index, offset, range)` triples.
#[test]
fn ast_gets_active_buffer_ranges() {
    let bytes = include_bytes!("shaders/two_ubo.vert.spv");
    let module = spirv::Module::from_words(words_from_bytes(bytes));
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let uniform_buffers = ast.get_shader_resources().unwrap().uniform_buffers;
    assert_eq!(uniform_buffers.len(), 2);

    // Shorthand constructor: (member index, byte offset, byte range).
    let buffer_range = |index, offset, range| spirv::BufferRange {
        index,
        offset,
        range,
    };

    let first_ubo = ast.get_active_buffer_ranges(uniform_buffers[0].id).unwrap();
    assert_eq!(
        first_ubo,
        [
            buffer_range(0, 0, 64),
            buffer_range(1, 64, 16),
            buffer_range(2, 80, 32),
        ]
    );

    let second_ubo = ast.get_active_buffer_ranges(uniform_buffers[1].id).unwrap();
    assert_eq!(
        second_ubo,
        [
            buffer_range(0, 0, 16),
            buffer_range(1, 16, 16),
            buffer_range(2, 32, 12),
        ]
    );
}
330