1 use spirv_cross::{hlsl as lang, spirv};
2
3 mod common;
4 use crate::common::words_from_bytes;
5
/// Parsing `simple.vert` should report exactly one entry point, named `main`.
#[test]
fn ast_gets_entry_points() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let entry_points = ast.get_entry_points().unwrap();
    assert_eq!(entry_points.len(), 1);

    let only_entry = &entry_points[0];
    assert_eq!(only_entry.name, "main");
}
18
/// Reflects `simple.vert` and checks the size of every resource category returned
/// by `get_shader_resources`, plus the names of the uniform block and the stage
/// inputs/outputs.
#[test]
fn ast_gets_shader_resources() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    // Pull every inspected category out as its own local up front.
    let spirv::ShaderResources {
        uniform_buffers,
        storage_buffers,
        stage_inputs,
        stage_outputs,
        subpass_inputs,
        storage_images,
        sampled_images,
        atomic_counters,
        push_constant_buffers,
        separate_images,
        separate_samplers,
        ..
    } = ast.get_shader_resources().unwrap();

    // A single uniform block and no storage buffers.
    assert_eq!(uniform_buffers.len(), 1);
    assert_eq!(uniform_buffers[0].name, "uniform_buffer_object");
    assert_eq!(storage_buffers.len(), 0);

    // Two stage inputs, matched by name rather than by reflection order.
    assert_eq!(stage_inputs.len(), 2);
    let has_input = |wanted: &str| stage_inputs.iter().any(|input| input.name == wanted);
    assert!(has_input("a_normal"));
    assert!(has_input("a_position"));

    // One stage output.
    assert_eq!(stage_outputs.len(), 1);
    assert!(stage_outputs.iter().any(|output| output.name == "v_normal"));

    // Every other category is expected to be empty for this shader.
    assert_eq!(subpass_inputs.len(), 0);
    assert_eq!(storage_images.len(), 0);
    assert_eq!(sampled_images.len(), 0);
    assert_eq!(atomic_counters.len(), 0);
    assert_eq!(push_constant_buffers.len(), 0);
    assert_eq!(separate_images.len(), 0);
    assert_eq!(separate_samplers.len(), 0);
}
57
/// Reads the `DescriptorSet` decoration of the first stage input of `simple.vert`
/// and expects it to be 0.
///
/// NOTE(review): stage inputs do not normally carry a `DescriptorSet` decoration,
/// so this presumably exercises the value reported for an unset decoration — confirm.
#[test]
fn ast_gets_decoration() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let resources = ast.get_shader_resources().unwrap();
    let first_input_id = resources.stage_inputs[0].id;

    let descriptor_set = ast
        .get_decoration(first_input_id, spirv::Decoration::DescriptorSet)
        .unwrap();
    assert_eq!(descriptor_set, 0);
}
70
/// Overwrites the `DescriptorSet` decoration of the first stage input and checks
/// that reading it back returns the value just written.
#[test]
fn ast_sets_decoration() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let mut ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let target_id = ast.get_shader_resources().unwrap().stage_inputs[0].id;
    let new_set = 3;

    ast.set_decoration(target_id, spirv::Decoration::DescriptorSet, new_set)
        .unwrap();

    let read_back = ast
        .get_decoration(target_id, spirv::Decoration::DescriptorSet)
        .unwrap();
    assert_eq!(read_back, new_set);
}
91
/// The uniform block of `simple.vert` should reflect as a struct type with two
/// members and no array dimensions.
#[test]
fn ast_gets_type_member_types_and_array() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let ubo_type_id = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;

    let is_struct = if let spirv::Type::Struct {
        member_types,
        array,
    } = ast.get_type(ubo_type_id).unwrap()
    {
        assert_eq!(member_types.len(), 2);
        assert_eq!(array.len(), 0);
        true
    } else {
        false
    };

    assert!(is_struct);
}
114
/// In `array.vert`, the uniform block should be a struct with three members, and
/// its third member should be a float type with a single array dimension of 3.
#[test]
fn ast_gets_array_dimensions() {
    let words = words_from_bytes(include_bytes!("shaders/array.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let ubo_type_id = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;

    let is_struct = if let spirv::Type::Struct { member_types, .. } =
        ast.get_type(ubo_type_id).unwrap()
    {
        assert_eq!(member_types.len(), 3);

        let third_member_id = member_types[2];
        let is_float = if let spirv::Type::Float { array } =
            ast.get_type(third_member_id).unwrap()
        {
            assert_eq!(array.len(), 1);
            assert_eq!(array[0], 3);
            true
        } else {
            false
        };
        assert!(is_float);

        true
    } else {
        false
    };

    assert!(is_struct);
}
142
/// Checks the declared byte sizes of the `simple.vert` uniform block and of its two
/// members. The expected sizes are those of a `mat4` (64 bytes) and a `float`
/// (4 bytes), with no padding in the total.
#[test]
fn ast_gets_declared_struct_size_and_struct_member_size() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let ubo_type_id = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;

    // A 4x4 matrix is sixteen 4-byte floats.
    let float_size = 4;
    let mat4_size = 16 * float_size;

    let total_size = ast.get_declared_struct_size(ubo_type_id).unwrap();
    assert_eq!(total_size, mat4_size + float_size);

    // (member index, expected declared size) for each member, in order.
    for &(member_index, expected_size) in &[(0, mat4_size), (1, float_size)] {
        assert_eq!(
            ast.get_declared_struct_member_size(ubo_type_id, member_index)
                .unwrap(),
            expected_size
        );
    }
}
167
/// The first member of the `simple.vert` uniform block should be named
/// `u_model_view_projection`.
#[test]
fn ast_gets_member_name() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let ubo_type_id = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;

    let first_member_name = ast.get_member_name(ubo_type_id, 0).unwrap();
    assert_eq!(first_member_name, "u_model_view_projection");
}
182
/// The `Offset` decoration of member 1 of the `simple.vert` uniform block is
/// expected to be 64 bytes.
#[test]
fn ast_gets_member_decoration() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let ubo_type_id = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;

    let second_member_offset = ast
        .get_member_decoration(ubo_type_id, 1, spirv::Decoration::Offset)
        .unwrap();
    assert_eq!(second_member_offset, 64);
}
201
/// Rewrites the `Offset` decoration of member 1 of the `simple.vert` uniform block
/// and checks that reading it back returns the new offset.
#[test]
fn ast_sets_member_decoration() {
    let words = words_from_bytes(include_bytes!("shaders/simple.vert.spv"));
    let module = spirv::Module::from_words(words);
    let mut ast = spirv::Ast::<lang::Target>::parse(&module).unwrap();

    let ubo_type_id = ast.get_shader_resources().unwrap().uniform_buffers[0].base_type_id;
    let new_offset = 128;

    ast.set_member_decoration(ubo_type_id, 1, spirv::Decoration::Offset, new_offset)
        .unwrap();

    let read_back = ast
        .get_member_decoration(ubo_type_id, 1, spirv::Decoration::Offset)
        .unwrap();
    assert_eq!(read_back, new_offset);
}
230
/// The first specialization constant reported for `specialization.comp` should have
/// `constant_id` 10.
#[test]
fn as_gets_specialization_constants() {
    let words = words_from_bytes(include_bytes!("shaders/specialization.comp.spv"));
    let comp = spirv::Module::from_words(words);
    let comp_ast = spirv::Ast::<lang::Target>::parse(&comp).unwrap();

    let constants = comp_ast.get_specialization_constants().unwrap();
    let first_constant = &constants[0];
    assert_eq!(first_constant.constant_id, 10);
}
240
/// The work-group size of `workgroup.comp` should be driven by three specialization
/// constants. Each one is checked by both its SPIR-V result id (`id`) and its
/// `constant_id`.
#[test]
fn as_gets_work_group_size_specialization_constants() {
    let words = words_from_bytes(include_bytes!("shaders/workgroup.comp.spv"));
    let comp = spirv::Module::from_words(words);
    let comp_ast = spirv::Ast::<lang::Target>::parse(&comp).unwrap();

    let actual = comp_ast
        .get_work_group_size_specialization_constants()
        .unwrap();

    let expected = spirv::WorkGroupSizeSpecializationConstants {
        x: spirv::SpecializationConstant {
            id: 7,
            constant_id: 5,
        },
        y: spirv::SpecializationConstant {
            id: 8,
            constant_id: 10,
        },
        z: spirv::SpecializationConstant {
            id: 9,
            constant_id: 15,
        },
    };

    assert_eq!(actual, expected);
}
268