// TODO: move this to a binary target once Cargo supports
// binary-specific dependencies.
3 
4 use std::{fs, path::PathBuf};
5 
/// Directory (relative to `CARGO_MANIFEST_DIR`) that receives the generated snapshots.
const DIR_OUT: &str = "tests/out";
/// Directory (relative to `CARGO_MANIFEST_DIR`) holding test inputs and `*.param.ron` files.
const DIR_IN: &str = "tests/in";
8 
9 bitflags::bitflags! {
10     struct Targets: u32 {
11         const IR = 0x1;
12         const ANALYSIS = 0x2;
13         const SPIRV = 0x4;
14         const METAL = 0x8;
15         const GLSL = 0x10;
16         const DOT = 0x20;
17         const HLSL = 0x40;
18         const WGSL = 0x80;
19     }
20 }
21 
/// Per-test options, loaded from `tests/in/<name>.param.ron` when that file exists
/// (otherwise `Parameters::default()` is used).
///
/// Fields without `#[serde(default)]` must be spelled out in any param file that is given.
/// The `allow(dead_code)` attributes silence warnings for fields whose consuming backend
/// feature is disabled.
#[derive(Default, serde::Deserialize)]
struct Parameters {
    /// Validate with `Capabilities::all()` instead of `Capabilities::empty()`.
    #[serde(default)]
    god_mode: bool,
    /// Passed through as `spv::Options::lang_version`.
    #[cfg_attr(not(feature = "spv-out"), allow(dead_code))]
    spv_version: (u8, u8),
    /// Passed through (cloned) as `spv::Options::capabilities`.
    #[cfg_attr(not(feature = "spv-out"), allow(dead_code))]
    spv_capabilities: naga::FastHashSet<spirv::Capability>,
    /// Enables `spv::WriterFlags::DEBUG`.
    #[cfg_attr(not(feature = "spv-out"), allow(dead_code))]
    #[serde(default)]
    spv_debug: bool,
    /// Enables `spv::WriterFlags::ADJUST_COORDINATE_SPACE`.
    #[cfg_attr(not(feature = "spv-out"), allow(dead_code))]
    #[serde(default)]
    spv_adjust_coordinate_space: bool,
    /// MSL writer options; only deserializable when naga's `deserialize` feature is on.
    #[cfg(all(feature = "deserialize", feature = "msl-out"))]
    #[serde(default)]
    msl: naga::back::msl::Options,
    /// Without `deserialize`, marks a test that needs custom MSL options, so the MSL
    /// snapshot is skipped.
    #[cfg(all(not(feature = "deserialize"), feature = "msl-out"))]
    #[serde(default)]
    msl_custom: bool,
    /// `Some(v)` selects desktop GLSL version `v`; `None` selects `Version::Embedded(310)`.
    #[cfg_attr(not(feature = "glsl-out"), allow(dead_code))]
    #[serde(default)]
    glsl_desktop_version: Option<u16>,
}
46 
47 #[allow(dead_code, unused_variables)]
check_targets(module: &naga::Module, name: &str, targets: Targets)48 fn check_targets(module: &naga::Module, name: &str, targets: Targets) {
49     let root = env!("CARGO_MANIFEST_DIR");
50     let params = match fs::read_to_string(format!("{}/{}/{}.param.ron", root, DIR_IN, name)) {
51         Ok(string) => ron::de::from_str(&string).expect("Couldn't find param file"),
52         Err(_) => Parameters::default(),
53     };
54     let capabilities = if params.god_mode {
55         naga::valid::Capabilities::all()
56     } else {
57         naga::valid::Capabilities::empty()
58     };
59     let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities)
60         .validate(module)
61         .unwrap();
62 
63     let dest = PathBuf::from(root).join(DIR_OUT).join(name);
64 
65     #[cfg(feature = "serialize")]
66     {
67         if targets.contains(Targets::IR) {
68             let config = ron::ser::PrettyConfig::default().with_new_line("\n".to_string());
69             let string = ron::ser::to_string_pretty(module, config).unwrap();
70             fs::write(dest.with_extension("ron"), string).unwrap();
71         }
72         if targets.contains(Targets::ANALYSIS) {
73             let config = ron::ser::PrettyConfig::default().with_new_line("\n".to_string());
74             let string = ron::ser::to_string_pretty(&info, config).unwrap();
75             fs::write(dest.with_extension("info.ron"), string).unwrap();
76         }
77     }
78 
79     #[cfg(feature = "spv-out")]
80     {
81         if targets.contains(Targets::SPIRV) {
82             check_output_spv(module, &info, &dest, &params);
83         }
84     }
85     #[cfg(feature = "msl-out")]
86     {
87         if targets.contains(Targets::METAL) {
88             check_output_msl(module, &info, &dest, &params);
89         }
90     }
91     #[cfg(feature = "glsl-out")]
92     {
93         if targets.contains(Targets::GLSL) {
94             for ep in module.entry_points.iter() {
95                 check_output_glsl(module, &info, &dest, ep.stage, &ep.name, &params);
96             }
97         }
98     }
99     #[cfg(feature = "dot-out")]
100     {
101         if targets.contains(Targets::DOT) {
102             let string = naga::back::dot::write(module, Some(&info)).unwrap();
103             fs::write(dest.with_extension("dot"), string).unwrap();
104         }
105     }
106     #[cfg(feature = "hlsl-out")]
107     {
108         if targets.contains(Targets::HLSL) {
109             for ep in module.entry_points.iter() {
110                 check_output_hlsl(module, &dest, ep.stage);
111             }
112         }
113     }
114     #[cfg(feature = "wgsl-out")]
115     {
116         if targets.contains(Targets::WGSL) {
117             check_output_wgsl(module, &info, &dest);
118         }
119     }
120 }
121 
/// Writes the SPIR-V translation of `module` to `<destination>.spvasm`, as text disassembly.
///
/// Language version, capabilities and writer flags are taken from `params`.
///
/// # Panics
///
/// Panics if the SPIR-V backend fails, if `rspirv` rejects the produced words, or if the
/// output file cannot be written.
#[cfg(feature = "spv-out")]
fn check_output_spv(
    module: &naga::Module,
    info: &naga::valid::ModuleInfo,
    // `&Path` rather than `&PathBuf`: callers passing `&PathBuf` still coerce.
    destination: &std::path::Path,
    params: &Parameters,
) {
    use naga::back::spv;
    use rspirv::binary::Disassemble;

    let mut flags = spv::WriterFlags::empty();
    if params.spv_debug {
        flags |= spv::WriterFlags::DEBUG;
    }
    if params.spv_adjust_coordinate_space {
        flags |= spv::WriterFlags::ADJUST_COORDINATE_SPACE;
    }
    let options = spv::Options {
        lang_version: params.spv_version,
        flags,
        capabilities: Some(params.spv_capabilities.clone()),
    };

    let spv = spv::write_vec(module, info, &options).unwrap();

    // Round-tripping through rspirv both checks the binary loads and gives a diffable snapshot.
    let dis = rspirv::dr::load_words(spv)
        .expect("Produced invalid SPIR-V")
        .disassemble();

    fs::write(destination.with_extension("spvasm"), dis).unwrap();
}
153 
/// Writes the Metal Shading Language translation of `module` to `<destination>.msl`.
///
/// With the `deserialize` feature, writer options come from `params.msl`. Without it, default
/// options are used, and tests that set `msl_custom` are skipped (a message is printed).
///
/// # Panics
///
/// Panics if the MSL backend fails, if any entry point fails to translate, or if the output
/// file cannot be written.
#[cfg(feature = "msl-out")]
fn check_output_msl(
    module: &naga::Module,
    info: &naga::valid::ModuleInfo,
    // `&Path` rather than `&PathBuf`: callers passing `&PathBuf` still coerce.
    destination: &std::path::Path,
    params: &Parameters,
) {
    use naga::back::msl;

    #[cfg(feature = "deserialize")]
    let options = &params.msl;

    // Deferred initialization: the defaults are only built when they are actually used, and
    // the binding outlives the `options` reference taken from it.
    #[cfg(not(feature = "deserialize"))]
    let default_options;
    #[cfg(not(feature = "deserialize"))]
    let options = if params.msl_custom {
        println!("Skipping {}", destination.display());
        return;
    } else {
        default_options = msl::Options::default();
        &default_options
    };

    let pipeline_options = msl::PipelineOptions {
        allow_point_size: true,
    };

    let (string, tr_info) = msl::write_string(module, info, options, &pipeline_options).unwrap();

    // Per-entry-point failures are reported separately from the overall write result.
    for (ep, result) in module.entry_points.iter().zip(tr_info.entry_point_names) {
        if let Err(error) = result {
            panic!("Failed to translate '{}': {}", ep.name, error);
        }
    }

    fs::write(destination.with_extension("msl"), string).unwrap();
}
189 
/// Writes the GLSL translation of entry point `ep_name` (of stage `stage`) to
/// `<destination>.<Stage>.glsl`.
///
/// Uses `glsl::Version::Desktop(v)` when `params.glsl_desktop_version` is `Some(v)`, and
/// `glsl::Version::Embedded(310)` otherwise.
///
/// NOTE: the file name depends only on `stage`, so two entry points with the same stage
/// write to the same file and the last one wins.
///
/// # Panics
///
/// Panics if the GLSL writer cannot be created or fails, or if the file cannot be written.
#[cfg(feature = "glsl-out")]
fn check_output_glsl(
    module: &naga::Module,
    info: &naga::valid::ModuleInfo,
    // `&Path` rather than `&PathBuf`: callers passing `&PathBuf` still coerce.
    destination: &std::path::Path,
    stage: naga::ShaderStage,
    ep_name: &str,
    params: &Parameters,
) {
    use naga::back::glsl;

    let options = glsl::Options {
        version: match params.glsl_desktop_version {
            Some(v) => glsl::Version::Desktop(v),
            None => glsl::Version::Embedded(310),
        },
        shader_stage: stage,
        entry_point: ep_name.to_string(),
    };

    let mut buffer = String::new();
    let mut writer = glsl::Writer::new(&mut buffer, module, info, &options).unwrap();
    writer.write().unwrap();

    let ext = format!("{:?}.glsl", stage);
    fs::write(destination.with_extension(&ext), buffer).unwrap();
}
217 
/// Writes the HLSL translation of `module` to `<destination>.<Stage>.hlsl`.
///
/// `stage` only selects the file name; `hlsl::write_string` receives the whole module.
///
/// # Panics
///
/// Panics if the HLSL backend fails or the output file cannot be written.
#[cfg(feature = "hlsl-out")]
fn check_output_hlsl(
    module: &naga::Module,
    // `&Path` rather than `&PathBuf`: callers passing `&PathBuf` still coerce.
    destination: &std::path::Path,
    stage: naga::ShaderStage,
) {
    use naga::back::hlsl;

    let string = hlsl::write_string(module).unwrap();

    let ext = format!("{:?}.hlsl", stage);
    fs::write(destination.with_extension(&ext), string).unwrap();
}
227 
/// Writes the WGSL translation of `module` to `<destination>.wgsl`.
///
/// # Panics
///
/// Panics if the WGSL backend fails or the output file cannot be written.
#[cfg(feature = "wgsl-out")]
fn check_output_wgsl(
    module: &naga::Module,
    info: &naga::valid::ModuleInfo,
    // `&Path` rather than `&PathBuf`: callers passing `&PathBuf` still coerce.
    destination: &std::path::Path,
) {
    use naga::back::wgsl;

    let string = wgsl::write_string(module, info).unwrap();

    fs::write(destination.with_extension("wgsl"), string).unwrap();
}
236 
/// Parses each listed WGSL input from `tests/in` and produces the requested snapshots.
#[cfg(feature = "wgsl-in")]
#[test]
fn convert_wgsl() {
    let root = env!("CARGO_MANIFEST_DIR");
    // (input file stem, snapshots to produce)
    let inputs: &[(&str, Targets)] = &[
        (
            "empty",
            Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL,
        ),
        (
            "quad",
            Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::DOT | Targets::WGSL,
        ),
        ("boids", Targets::SPIRV | Targets::METAL | Targets::GLSL),
        ("skybox", Targets::SPIRV | Targets::METAL | Targets::GLSL),
        (
            "collatz",
            Targets::SPIRV | Targets::METAL | Targets::IR | Targets::ANALYSIS,
        ),
        ("shadow", Targets::SPIRV | Targets::METAL | Targets::GLSL),
        ("image", Targets::SPIRV | Targets::METAL),
        ("extra", Targets::SPIRV | Targets::METAL),
        ("operators", Targets::SPIRV | Targets::METAL | Targets::GLSL),
        (
            "interpolate",
            Targets::SPIRV | Targets::METAL | Targets::GLSL,
        ),
        ("access", Targets::SPIRV | Targets::METAL),
        (
            "control-flow",
            Targets::SPIRV | Targets::METAL | Targets::GLSL,
        ),
    ];

    for &(name, targets) in inputs {
        println!("Processing '{}'", name);
        let input_path = format!("{}/{}/{}.wgsl", root, DIR_IN, name);
        let source = fs::read_to_string(input_path).expect("Couldn't find wgsl file");
        let module =
            naga::front::wgsl::parse_str(&source).unwrap_or_else(|e| panic!("{}", e));
        check_targets(&module, name, targets);
    }
}
281 
/// Parses `tests/in/<name>.spv` with the SPIR-V frontend, produces the requested snapshots,
/// and then validates the module once more with `Capabilities::empty()`.
#[cfg(feature = "spv-in")]
fn convert_spv(name: &str, adjust_coordinate_space: bool, targets: Targets) {
    let root = env!("CARGO_MANIFEST_DIR");
    let input_path = format!("{}/{}/{}.spv", root, DIR_IN, name);
    let bytes = fs::read(input_path).expect("Couldn't find spv file");
    let frontend_options = naga::front::spv::Options {
        adjust_coordinate_space,
        strict_capabilities: false,
        flow_graph_dump_prefix: None,
    };
    let module = naga::front::spv::parse_u8_slice(&bytes, &frontend_options).unwrap();

    check_targets(&module, name, targets);

    // NOTE(review): `check_targets` already validated `module` with the param-file
    // capabilities; this second pass always uses empty capabilities — confirm intended.
    naga::valid::Validator::new(
        naga::valid::ValidationFlags::all(),
        naga::valid::Capabilities::empty(),
    )
    .validate(&module)
    .unwrap();
}
302 
/// Snapshot test for `tests/in/quad-vert.spv`, without coordinate-space adjustment.
#[cfg(feature = "spv-in")]
#[test]
fn convert_spv_quad_vert() {
    let targets = Targets::METAL | Targets::GLSL | Targets::WGSL;
    convert_spv("quad-vert", false, targets);
}
312 
/// Snapshot test for `tests/in/shadow.spv`, with coordinate-space adjustment enabled.
#[cfg(feature = "spv-in")]
#[test]
fn convert_spv_shadow() {
    let adjust_coordinate_space = true;
    convert_spv(
        "shadow",
        adjust_coordinate_space,
        Targets::IR | Targets::ANALYSIS,
    );
}
318 
/// Parses `tests/in/<name>.glsl` with the GLSL frontend using the given entry points.
///
/// Currently this only checks that parsing succeeds; `_targets` is accepted so callers
/// match the other frontends, but no snapshots are produced.
#[cfg(feature = "glsl-in")]
fn convert_glsl(
    name: &str,
    entry_points: naga::FastHashMap<String, naga::ShaderStage>,
    _targets: Targets,
) {
    let root = env!("CARGO_MANIFEST_DIR");
    let source = fs::read_to_string(format!("{}/{}/{}.glsl", root, DIR_IN, name))
        .expect("Couldn't find glsl file");
    let frontend_options = naga::front::glsl::Options {
        entry_points,
        defines: Default::default(),
    };
    let _module = naga::front::glsl::parse_str(&source, &frontend_options).unwrap();
    // TODO: feed the module to `check_targets(&module, name, targets)` like the other
    // frontends do.
}
338 
/// Parses `tests/in/quad-glsl.glsl` with one vertex and one fragment entry point.
#[cfg(feature = "glsl-in")]
#[test]
fn convert_glsl_quad() {
    let entry_points: naga::FastHashMap<String, naga::ShaderStage> = [
        ("vert_main", naga::ShaderStage::Vertex),
        ("frag_main", naga::ShaderStage::Fragment),
    ]
    .iter()
    .map(|&(ep_name, stage)| (ep_name.to_string(), stage))
    .collect();
    convert_glsl("quad-glsl", entry_points, Targets::SPIRV | Targets::IR);
}
347