1 // TODO: temp
2 #![allow(dead_code)]
3 use super::Error;
4 use crate::{
5     back::{binary_operation_str, vector_size_str, wgsl::keywords::RESERVED},
6     proc::{EntryPointIndex, NameKey, Namer, TypeResolution},
7     valid::{FunctionInfo, ModuleInfo},
8     Arena, ArraySize, Binding, Constant, Expression, FastHashMap, Function, GlobalVariable, Handle,
9     ImageClass, ImageDimension, Interpolation, Module, Sampling, ScalarKind, ScalarValue,
10     ShaderStage, Statement, StorageFormat, StructLevel, StructMember, Type, TypeInner,
11 };
12 use bit_set::BitSet;
13 use std::fmt::Write;
14 
/// One level of indentation in the generated WGSL source (four spaces).
const INDENT: &str = "    ";
/// Vector component names, in component order.
const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
/// Prefix of the `let` names given to baked expressions (`_e` followed by the handle index).
const BAKE_PREFIX: &str = "_e";
18 
19 /// Shorthand result used internally by the backend
20 type BackendResult = Result<(), Error>;
21 
22 /// WGSL attribute
23 /// https://gpuweb.github.io/gpuweb/wgsl/#attributes
24 enum Attribute {
25     Access(crate::StorageAccess),
26     Binding(u32),
27     Block,
28     BuiltIn(crate::BuiltIn),
29     Group(u32),
30     Interpolate(Option<Interpolation>, Option<Sampling>),
31     Location(u32),
32     Stage(ShaderStage),
33     Stride(u32),
34     WorkGroupSize([u32; 3]),
35 }
36 
37 /// Stores the current function type (either a regular function or an entry point)
38 ///
39 /// Also stores data needed to identify it (handle for a regular function or index for an entry point)
40 // TODO: copy-paste from glsl-out
41 enum FunctionType {
42     /// A regular function and it's handle
43     Function(Handle<Function>),
44     /// A entry point and it's index
45     EntryPoint(EntryPointIndex),
46 }
47 
48 /// Helper structure that stores data needed when writing the function
49 // TODO: copy-paste from glsl-out
50 struct FunctionCtx<'a> {
51     /// The current function type being written
52     ty: FunctionType,
53     /// Analysis about the function
54     info: &'a FunctionInfo,
55     /// The expression arena of the current function being written
56     expressions: &'a Arena<Expression>,
57 }
58 
59 pub struct Writer<W> {
60     out: W,
61     names: FastHashMap<NameKey, String>,
62     namer: Namer,
63     named_expressions: BitSet,
64 }
65 
66 impl<W: Write> Writer<W> {
new(out: W) -> Self67     pub fn new(out: W) -> Self {
68         Writer {
69             out,
70             names: FastHashMap::default(),
71             namer: Namer::default(),
72             named_expressions: BitSet::new(),
73         }
74     }
75 
reset(&mut self, module: &Module)76     fn reset(&mut self, module: &Module) {
77         self.names.clear();
78         self.namer.reset(module, RESERVED, &[], &mut self.names);
79         self.named_expressions.clear();
80     }
81 
write(&mut self, module: &Module, info: &ModuleInfo) -> BackendResult82     pub fn write(&mut self, module: &Module, info: &ModuleInfo) -> BackendResult {
83         self.reset(module);
84 
85         // Write all structs
86         for (handle, ty) in module.types.iter() {
87             if let TypeInner::Struct {
88                 level, ref members, ..
89             } = ty.inner
90             {
91                 let block = level == StructLevel::Root;
92                 self.write_struct(module, handle, block, members)?;
93                 writeln!(self.out)?;
94             }
95         }
96 
97         // Write all constants
98         for (handle, constant) in module.constants.iter() {
99             if constant.name.is_some() {
100                 self.write_global_constant(&constant, handle)?;
101             }
102         }
103 
104         // Write all globals
105         for (ty, global) in module.global_variables.iter() {
106             self.write_global(&module, &global, ty)?;
107         }
108 
109         if !module.global_variables.is_empty() {
110             // Add extra newline for readability
111             writeln!(self.out)?;
112         }
113 
114         // Write all regular functions
115         for (handle, function) in module.functions.iter() {
116             let fun_info = &info[handle];
117 
118             let func_ctx = FunctionCtx {
119                 ty: FunctionType::Function(handle),
120                 info: fun_info,
121                 expressions: &function.expressions,
122             };
123 
124             // Write the function
125             self.write_function(&module, &function, &func_ctx)?;
126 
127             writeln!(self.out)?;
128         }
129 
130         // Write all entry points
131         for (index, ep) in module.entry_points.iter().enumerate() {
132             let attributes = match ep.stage {
133                 ShaderStage::Vertex | ShaderStage::Fragment => vec![Attribute::Stage(ep.stage)],
134                 ShaderStage::Compute => vec![
135                     Attribute::Stage(ShaderStage::Compute),
136                     Attribute::WorkGroupSize(ep.workgroup_size),
137                 ],
138             };
139 
140             self.write_attributes(&attributes, false)?;
141             // Add a newline after attribute
142             writeln!(self.out)?;
143 
144             let func_ctx = FunctionCtx {
145                 ty: FunctionType::EntryPoint(index as u16),
146                 info: &info.get_entry_point(index),
147                 expressions: &ep.function.expressions,
148             };
149             self.write_function(&module, &ep.function, &func_ctx)?;
150 
151             if index < module.entry_points.len() - 1 {
152                 writeln!(self.out)?;
153             }
154         }
155 
156         Ok(())
157     }
158 
159     /// Helper method used to write [`ScalarValue`](ScalarValue)
160     ///
161     /// # Notes
162     /// Adds no trailing or leading whitespace
write_scalar_value(&mut self, value: ScalarValue) -> BackendResult163     fn write_scalar_value(&mut self, value: ScalarValue) -> BackendResult {
164         match value {
165             ScalarValue::Sint(value) => write!(self.out, "{}", value)?,
166             ScalarValue::Uint(value) => write!(self.out, "{}", value)?,
167             // Floats are written using `Debug` instead of `Display` because it always appends the
168             // decimal part even it's zero
169             ScalarValue::Float(value) => write!(self.out, "{:?}", value)?,
170             ScalarValue::Bool(value) => write!(self.out, "{}", value)?,
171         }
172 
173         Ok(())
174     }
175 
176     /// Helper method used to write structs
177     /// https://gpuweb.github.io/gpuweb/wgsl/#functions
178     ///
179     /// # Notes
180     /// Ends in a newline
write_function( &mut self, module: &Module, func: &Function, func_ctx: &FunctionCtx<'_>, ) -> BackendResult181     fn write_function(
182         &mut self,
183         module: &Module,
184         func: &Function,
185         func_ctx: &FunctionCtx<'_>,
186     ) -> BackendResult {
187         let func_name = match func_ctx.ty {
188             FunctionType::EntryPoint(index) => self.names[&NameKey::EntryPoint(index)].clone(),
189             FunctionType::Function(handle) => self.names[&NameKey::Function(handle)].clone(),
190         };
191 
192         // Write function name
193         write!(self.out, "fn {}(", func_name)?;
194 
195         // Write function arguments
196         for (index, arg) in func.arguments.iter().enumerate() {
197             // Write argument attribute if a binding is present
198             if let Some(ref binding) = arg.binding {
199                 self.write_attributes(&map_binding_to_attribute(binding), false)?;
200                 write!(self.out, " ")?;
201             }
202             // Write argument name
203             let argument_name = match func_ctx.ty {
204                 FunctionType::Function(handle) => {
205                     self.names[&NameKey::FunctionArgument(handle, index as u32)].clone()
206                 }
207                 FunctionType::EntryPoint(ep_index) => {
208                     self.names[&NameKey::EntryPointArgument(ep_index, index as u32)].clone()
209                 }
210             };
211 
212             write!(self.out, "{}: ", argument_name)?;
213             // Write argument type
214             self.write_type(module, arg.ty)?;
215             if index < func.arguments.len() - 1 {
216                 // Add a separator between args
217                 write!(self.out, ", ")?;
218             }
219         }
220 
221         write!(self.out, ")")?;
222 
223         // Write function return type
224         if let Some(ref result) = func.result {
225             if let Some(ref binding) = result.binding {
226                 write!(self.out, " -> ")?;
227                 self.write_attributes(&map_binding_to_attribute(binding), true)?;
228                 self.write_type(module, result.ty)?;
229             } else {
230                 let struct_name = &self.names[&NameKey::Type(result.ty)].clone();
231                 write!(self.out, " -> {}", struct_name)?;
232             }
233         }
234 
235         write!(self.out, " {{")?;
236         writeln!(self.out)?;
237 
238         // Write function local variables
239         for (handle, local) in func.local_variables.iter() {
240             // Write indentation (only for readability)
241             write!(self.out, "{}", INDENT)?;
242 
243             // Write the local name
244             // The leading space is important
245             let name_key = match func_ctx.ty {
246                 FunctionType::Function(func_handle) => NameKey::FunctionLocal(func_handle, handle),
247                 FunctionType::EntryPoint(idx) => NameKey::EntryPointLocal(idx, handle),
248             };
249             write!(self.out, "var {}: ", self.names[&name_key])?;
250 
251             // Write the local type
252             self.write_type(&module, local.ty)?;
253 
254             // Write the local initializer if needed
255             if let Some(init) = local.init {
256                 // Put the equal signal only if there's a initializer
257                 // The leading and trailing spaces aren't needed but help with readability
258                 write!(self.out, " = ")?;
259 
260                 // Write the constant
261                 // `write_constant` adds no trailing or leading space/newline
262                 self.write_constant(module, init)?;
263             }
264 
265             // Finish the local with `;` and add a newline (only for readability)
266             writeln!(self.out, ";")?
267         }
268 
269         if !func.local_variables.is_empty() {
270             writeln!(self.out)?;
271         }
272 
273         // Write the function body (statement list)
274         for sta in func.body.iter() {
275             // The indentation should always be 1 when writing the function body
276             self.write_stmt(&module, sta, &func_ctx, 1)?;
277         }
278 
279         writeln!(self.out, "}}")?;
280 
281         self.named_expressions.clear();
282 
283         Ok(())
284     }
285 
286     /// Helper method to write a attribute
287     ///
288     /// # Notes
289     /// Adds an extra space if required
write_attributes(&mut self, attributes: &[Attribute], extra_space: bool) -> BackendResult290     fn write_attributes(&mut self, attributes: &[Attribute], extra_space: bool) -> BackendResult {
291         let mut attributes_str = String::new();
292         for (index, attribute) in attributes.iter().enumerate() {
293             let attribute_str = match *attribute {
294                 Attribute::Access(access) => {
295                     let access_str = if access.is_all() {
296                         "read_write"
297                     } else if access.contains(crate::StorageAccess::LOAD) {
298                         "read"
299                     } else {
300                         "write"
301                     };
302                     format!("access({})", access_str)
303                 }
304                 Attribute::Block => String::from("block"),
305                 Attribute::Location(id) => format!("location({})", id),
306                 Attribute::BuiltIn(builtin_attrib) => {
307                     let builtin_str = builtin_str(builtin_attrib);
308                     if let Some(builtin) = builtin_str {
309                         format!("builtin({})", builtin)
310                     } else {
311                         log::warn!("Unsupported builtin attribute: {:?}", builtin_attrib);
312                         String::from("")
313                     }
314                 }
315                 Attribute::Stage(shader_stage) => match shader_stage {
316                     ShaderStage::Vertex => String::from("stage(vertex)"),
317                     ShaderStage::Fragment => String::from("stage(fragment)"),
318                     ShaderStage::Compute => String::from("stage(compute)"),
319                 },
320                 Attribute::Stride(stride) => format!("stride({})", stride),
321                 Attribute::WorkGroupSize(size) => {
322                     format!("workgroup_size({}, {}, {})", size[0], size[1], size[2])
323                 }
324                 Attribute::Binding(id) => format!("binding({})", id),
325                 Attribute::Group(id) => format!("group({})", id),
326                 Attribute::Interpolate(interpolation, sampling) => {
327                     if interpolation.is_some() || sampling.is_some() {
328                         let interpolation_str = if let Some(interpolation) = interpolation {
329                             interpolation_str(interpolation)
330                         } else {
331                             ""
332                         };
333                         let sampling_str = if let Some(sampling) = sampling {
334                             // Center sampling is the default
335                             if sampling == Sampling::Center {
336                                 String::from("")
337                             } else {
338                                 format!(",{}", sampling_str(sampling))
339                             }
340                         } else {
341                             String::from("")
342                         };
343                         format!("interpolate({}{})", interpolation_str, sampling_str)
344                     } else {
345                         String::from("")
346                     }
347                 }
348             };
349             if !attribute_str.is_empty() {
350                 // Add a separator between args
351                 let separator = if index < attributes.len() - 1 {
352                     ", "
353                 } else {
354                     ""
355                 };
356                 attributes_str = format!("{}{}{}", attributes_str, attribute_str, separator);
357             }
358         }
359         if !attributes_str.is_empty() {
360             //TODO: looks ugly
361             if attributes_str.ends_with(", ") {
362                 attributes_str = attributes_str[0..attributes_str.len() - 2].to_string();
363             }
364             let extra_space_str = if extra_space { " " } else { "" };
365             write!(self.out, "[[{}]]{}", attributes_str, extra_space_str)?;
366         }
367 
368         Ok(())
369     }
370 
371     /// Helper method used to write structs
372     ///
373     /// # Notes
374     /// Ends in a newline
write_struct( &mut self, module: &Module, handle: Handle<Type>, block: bool, members: &[StructMember], ) -> BackendResult375     fn write_struct(
376         &mut self,
377         module: &Module,
378         handle: Handle<Type>,
379         block: bool,
380         members: &[StructMember],
381     ) -> BackendResult {
382         if block {
383             self.write_attributes(&[Attribute::Block], false)?;
384             writeln!(self.out)?;
385         }
386         let name = &self.names[&NameKey::Type(handle)].clone();
387         write!(self.out, "struct {} {{", name)?;
388         writeln!(self.out)?;
389         for (index, member) in members.iter().enumerate() {
390             // Skip struct member with unsupported built in
391             if let Some(Binding::BuiltIn(builtin)) = member.binding {
392                 if builtin_str(builtin).is_none() {
393                     log::warn!("Skip member with unsupported builtin {:?}", builtin);
394                     continue;
395                 }
396             }
397 
398             // The indentation is only for readability
399             write!(self.out, "{}", INDENT)?;
400             if let Some(ref binding) = member.binding {
401                 self.write_attributes(&map_binding_to_attribute(binding), true)?;
402             }
403             // Write struct member name and type
404             let member_name = &self.names[&NameKey::StructMember(handle, index as u32)];
405             write!(self.out, "{}: ", member_name)?;
406             // Write stride attribute for array struct member
407             if let TypeInner::Array {
408                 base: _,
409                 size: _,
410                 stride,
411             } = module.types[member.ty].inner
412             {
413                 self.write_attributes(&[Attribute::Stride(stride)], true)?;
414             }
415             self.write_type(module, member.ty)?;
416             write!(self.out, ";")?;
417             writeln!(self.out)?;
418         }
419 
420         write!(self.out, "}};")?;
421 
422         writeln!(self.out)?;
423 
424         Ok(())
425     }
426 
427     /// Helper method used to write non image/sampler types
428     ///
429     /// # Notes
430     /// Adds no trailing or leading whitespace
write_type(&mut self, module: &Module, ty: Handle<Type>) -> BackendResult431     fn write_type(&mut self, module: &Module, ty: Handle<Type>) -> BackendResult {
432         let inner = &module.types[ty].inner;
433         match *inner {
434             TypeInner::Struct { .. } => {
435                 // Get the struct name
436                 let name = &self.names[&NameKey::Type(ty)];
437                 write!(self.out, "{}", name)?;
438                 return Ok(());
439             }
440             ref other => self.write_value_type(module, other)?,
441         }
442 
443         Ok(())
444     }
445 
446     /// Helper method used to write value types
447     ///
448     /// # Notes
449     /// Adds no trailing or leading whitespace
write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult450     fn write_value_type(&mut self, module: &Module, inner: &TypeInner) -> BackendResult {
451         match *inner {
452             TypeInner::Vector { size, kind, .. } => write!(
453                 self.out,
454                 "{}",
455                 format!("vec{}<{}>", vector_size_str(size), scalar_kind_str(kind),)
456             )?,
457             TypeInner::Sampler { comparison: false } => {
458                 write!(self.out, "sampler")?;
459             }
460             TypeInner::Sampler { comparison: true } => {
461                 write!(self.out, "sampler_comparison")?;
462             }
463             TypeInner::Image {
464                 dim,
465                 arrayed,
466                 class,
467             } => {
468                 // More about texture types: https://gpuweb.github.io/gpuweb/wgsl/#sampled-texture-type
469                 let dim_str = image_dimension_str(dim);
470                 let arrayed_str = if arrayed { "_array" } else { "" };
471                 let (class_str, multisampled_str, scalar_str) = match class {
472                     ImageClass::Sampled { kind, multi } => (
473                         "",
474                         if multi { "multisampled" } else { "" },
475                         format!("<{}>", scalar_kind_str(kind)),
476                     ),
477                     ImageClass::Depth => ("depth", "", String::from("")),
478                     ImageClass::Storage(storage_format) => (
479                         "storage_",
480                         "",
481                         format!("<{}>", storage_format_str(storage_format)),
482                     ),
483                 };
484                 let ty_str = format!(
485                     "texture_{}{}{}{}{}",
486                     class_str, multisampled_str, dim_str, arrayed_str, scalar_str
487                 );
488                 write!(self.out, "{}", ty_str)?;
489             }
490             TypeInner::Scalar { kind, .. } => {
491                 write!(self.out, "{}", scalar_kind_str(kind))?;
492             }
493             TypeInner::Array { base, size, .. } => {
494                 // More info https://gpuweb.github.io/gpuweb/wgsl/#array-types
495                 // array<A, 3> -- Constant array
496                 // array<A> -- Dynamic array
497                 write!(self.out, "array<")?;
498                 match size {
499                     ArraySize::Constant(handle) => {
500                         self.write_type(module, base)?;
501                         write!(self.out, ",")?;
502                         self.write_constant(module, handle)?;
503                     }
504                     ArraySize::Dynamic => {
505                         self.write_type(module, base)?;
506                     }
507                 }
508                 write!(self.out, ">")?;
509             }
510             TypeInner::Matrix {
511                 columns,
512                 rows,
513                 width: _,
514             } => {
515                 write!(
516                     self.out,
517                     //TODO: Can matrix be other than f32?
518                     "mat{}x{}<f32>",
519                     vector_size_str(columns),
520                     vector_size_str(rows),
521                 )?;
522             }
523             _ => {
524                 return Err(Error::Unimplemented(format!(
525                     "write_value_type {:?}",
526                     inner
527                 )));
528             }
529         }
530 
531         Ok(())
532     }
533     /// Helper method used to write statements
534     ///
535     /// # Notes
536     /// Always adds a newline
write_stmt( &mut self, module: &Module, stmt: &Statement, func_ctx: &FunctionCtx<'_>, indent: usize, ) -> BackendResult537     fn write_stmt(
538         &mut self,
539         module: &Module,
540         stmt: &Statement,
541         func_ctx: &FunctionCtx<'_>,
542         indent: usize,
543     ) -> BackendResult {
544         match *stmt {
545             Statement::Emit(ref range) => {
546                 for handle in range.clone() {
547                     let min_ref_count = func_ctx.expressions[handle].bake_ref_count();
548                     if min_ref_count <= func_ctx.info[handle].ref_count {
549                         write!(self.out, "{}", INDENT.repeat(indent))?;
550                         self.start_baking_expr(module, handle, &func_ctx)?;
551                         self.write_expr(module, handle, &func_ctx)?;
552                         writeln!(self.out, ";")?;
553                         self.named_expressions.insert(handle.index());
554                     }
555                 }
556             }
557             // TODO: copy-paste from glsl-out
558             Statement::If {
559                 condition,
560                 ref accept,
561                 ref reject,
562             } => {
563                 write!(self.out, "{}", INDENT.repeat(indent))?;
564                 write!(self.out, "if (")?;
565                 self.write_expr(module, condition, func_ctx)?;
566                 writeln!(self.out, ") {{")?;
567 
568                 for sta in accept {
569                     // Increase indentation to help with readability
570                     self.write_stmt(module, sta, func_ctx, indent + 1)?;
571                 }
572 
573                 // If there are no statements in the reject block we skip writing it
574                 // This is only for readability
575                 if !reject.is_empty() {
576                     writeln!(self.out, "{}}} else {{", INDENT.repeat(indent))?;
577 
578                     for sta in reject {
579                         // Increase indentation to help with readability
580                         self.write_stmt(module, sta, func_ctx, indent + 1)?;
581                     }
582                 }
583 
584                 writeln!(self.out, "{}}}", INDENT.repeat(indent))?
585             }
586             Statement::Return { value } => {
587                 write!(self.out, "{}", INDENT.repeat(indent))?;
588                 write!(self.out, "return")?;
589                 if let Some(return_value) = value {
590                     // The leading space is important
591                     write!(self.out, " ")?;
592                     self.write_expr(module, return_value, &func_ctx)?;
593                 }
594                 writeln!(self.out, ";")?;
595             }
596             // TODO: copy-paste from glsl-out
597             Statement::Kill => {
598                 write!(self.out, "{}", INDENT.repeat(indent))?;
599                 writeln!(self.out, "discard;")?
600             }
601             // TODO: copy-paste from glsl-out
602             Statement::Store { pointer, value } => {
603                 write!(self.out, "{}", INDENT.repeat(indent))?;
604                 self.write_expr(module, pointer, func_ctx)?;
605                 write!(self.out, " = ")?;
606                 self.write_expr(module, value, func_ctx)?;
607                 writeln!(self.out, ";")?
608             }
609             crate::Statement::Call {
610                 function,
611                 ref arguments,
612                 result,
613             } => {
614                 write!(self.out, "{}", INDENT.repeat(indent))?;
615                 if let Some(expr) = result {
616                     self.start_baking_expr(module, expr, &func_ctx)?;
617                     self.named_expressions.insert(expr.index());
618                 }
619                 let func_name = &self.names[&NameKey::Function(function)];
620                 write!(self.out, "{}(", func_name)?;
621                 for (index, argument) in arguments.iter().enumerate() {
622                     self.write_expr(module, *argument, func_ctx)?;
623                     // Only write a comma if isn't the last element
624                     if index != arguments.len().saturating_sub(1) {
625                         // The leading space is for readability only
626                         write!(self.out, ", ")?;
627                     }
628                 }
629                 writeln!(self.out, ");")?
630             }
631             _ => {
632                 return Err(Error::Unimplemented(format!("write_stmt {:?}", stmt)));
633             }
634         }
635 
636         Ok(())
637     }
638 
start_baking_expr( &mut self, module: &Module, handle: Handle<Expression>, context: &FunctionCtx, ) -> BackendResult639     fn start_baking_expr(
640         &mut self,
641         module: &Module,
642         handle: Handle<Expression>,
643         context: &FunctionCtx,
644     ) -> BackendResult {
645         // Write variable name
646         write!(self.out, "let {}{}: ", BAKE_PREFIX, handle.index())?;
647         let ty = &context.info[handle].ty;
648         // Write variable type
649         match *ty {
650             TypeResolution::Handle(ty_handle) => {
651                 self.write_type(module, ty_handle)?;
652             }
653             TypeResolution::Value(crate::TypeInner::Scalar { kind, .. }) => {
654                 write!(self.out, "{}", scalar_kind_str(kind))?;
655             }
656             TypeResolution::Value(crate::TypeInner::Vector { size, kind, .. }) => {
657                 write!(
658                     self.out,
659                     "vec{}<{}>",
660                     vector_size_str(size),
661                     scalar_kind_str(kind),
662                 )?;
663             }
664             _ => {
665                 return Err(Error::Unimplemented(format!("start_baking_expr {:?}", ty)));
666             }
667         }
668 
669         write!(self.out, " = ")?;
670         Ok(())
671     }
672 
673     /// Helper method to write expressions
674     ///
675     /// # Notes
676     /// Doesn't add any newlines or leading/trailing spaces
write_expr( &mut self, module: &Module, expr: Handle<Expression>, func_ctx: &FunctionCtx<'_>, ) -> BackendResult677     fn write_expr(
678         &mut self,
679         module: &Module,
680         expr: Handle<Expression>,
681         func_ctx: &FunctionCtx<'_>,
682     ) -> BackendResult {
683         let expression = &func_ctx.expressions[expr];
684 
685         if self.named_expressions.contains(expr.index()) {
686             write!(self.out, "{}{}", BAKE_PREFIX, expr.index())?;
687             return Ok(());
688         }
689 
690         match *expression {
691             Expression::Constant(constant) => self.write_constant(module, constant)?,
692             Expression::Compose { ty, ref components } => {
693                 self.write_type(&module, ty)?;
694                 write!(self.out, "(")?;
695                 // !spv-in specific notes!
696                 // WGSL does not support all SPIR-V builtins and we should skip it in generated shaders.
697                 // We already skip them when we generate struct type.
698                 // Now we need to find components that used struct with ignored builtins.
699 
700                 // So, why we can't just return the error to a user?
701                 // We can, but otherwise, we can't generate WGSL shader from any glslang SPIR-V shaders.
702                 // glslang generates gl_PerVertex struct with gl_CullDistance, gl_ClipDistance and gl_PointSize builtin inside by default.
703                 // All of them are not supported by WGSL.
704 
705                 // We need to copy components to another vec because we don't know which of them we should write.
706                 let mut components_to_write = Vec::with_capacity(components.len());
707                 for component in components {
708                     let mut skip_component = false;
709                     if let Expression::Load { pointer } = func_ctx.expressions[*component] {
710                         if let Expression::AccessIndex {
711                             base,
712                             index: access_index,
713                         } = func_ctx.expressions[pointer]
714                         {
715                             let base_ty_res = &func_ctx.info[base].ty;
716                             let resolved = base_ty_res.inner_with(&module.types);
717                             if let TypeInner::Pointer {
718                                 base: pointer_base_handle,
719                                 ..
720                             } = *resolved
721                             {
722                                 // Let's check that we try to access a struct member with unsupported built-in and skip it.
723                                 if let TypeInner::Struct { ref members, .. } =
724                                     module.types[pointer_base_handle].inner
725                                 {
726                                     if let Some(Binding::BuiltIn(builtin)) =
727                                         members[access_index as usize].binding
728                                     {
729                                         if builtin_str(builtin).is_none() {
730                                             // glslang why you did this with us...
731                                             log::warn!(
732                                                 "Skip component with unsupported builtin {:?}",
733                                                 builtin
734                                             );
735                                             skip_component = true;
736                                         }
737                                     }
738                                 }
739                             }
740                         }
741                     }
742                     if skip_component {
743                         continue;
744                     } else {
745                         components_to_write.push(*component);
746                     }
747                 }
748 
749                 // non spv-in specific notes!
750                 // Real `Expression::Compose` logic generates here.
751                 for (index, component) in components_to_write.iter().enumerate() {
752                     self.write_expr(module, *component, &func_ctx)?;
753                     // Only write a comma if isn't the last element
754                     if index != components_to_write.len().saturating_sub(1) {
755                         // The leading space is for readability only
756                         write!(self.out, ", ")?;
757                     }
758                 }
759                 write!(self.out, ")")?
760             }
761             Expression::FunctionArgument(pos) => {
762                 let name_key = match func_ctx.ty {
763                     FunctionType::Function(handle) => NameKey::FunctionArgument(handle, pos),
764                     FunctionType::EntryPoint(ep_index) => {
765                         NameKey::EntryPointArgument(ep_index, pos)
766                     }
767                 };
768                 let name = &self.names[&name_key];
769                 write!(self.out, "{}", name)?;
770             }
771             Expression::Binary { op, left, right } => {
772                 self.write_expr(module, left, func_ctx)?;
773 
774                 write!(self.out, " {} ", binary_operation_str(op),)?;
775 
776                 self.write_expr(module, right, func_ctx)?;
777             }
778             // TODO: copy-paste from glsl-out
779             Expression::Access { base, index } => {
780                 self.write_expr(module, base, func_ctx)?;
781                 write!(self.out, "[")?;
782                 self.write_expr(module, index, func_ctx)?;
783                 write!(self.out, "]")?
784             }
785             // TODO: copy-paste from glsl-out
786             Expression::AccessIndex { base, index } => {
787                 self.write_expr(module, base, func_ctx)?;
788 
789                 let base_ty_res = &func_ctx.info[base].ty;
790                 let mut resolved = base_ty_res.inner_with(&module.types);
791                 let base_ty_handle = match *resolved {
792                     TypeInner::Pointer { base, class: _ } => {
793                         resolved = &module.types[base].inner;
794                         Some(base)
795                     }
796                     _ => base_ty_res.handle(),
797                 };
798 
799                 match *resolved {
800                     TypeInner::Vector { .. }
801                     | TypeInner::Matrix { .. }
802                     | TypeInner::Array { .. }
803                     | TypeInner::ValuePointer { .. } => write!(self.out, "[{}]", index)?,
804                     TypeInner::Struct { .. } => {
805                         // This will never panic in case the type is a `Struct`, this is not true
806                         // for other types so we can only check while inside this match arm
807                         let ty = base_ty_handle.unwrap();
808 
809                         write!(
810                             self.out,
811                             ".{}",
812                             &self.names[&NameKey::StructMember(ty, index)]
813                         )?
814                     }
815                     ref other => return Err(Error::Custom(format!("Cannot index {:?}", other))),
816                 }
817             }
818             Expression::ImageSample {
819                 image,
820                 sampler,
821                 coordinate,
822                 array_index: _,
823                 offset: _,
824                 level,
825                 depth_ref: _,
826             } => {
827                 // TODO: other texture functions
828                 // TODO: comments
829                 let fun_name = match level {
830                     crate::SampleLevel::Auto => "textureSample",
831                     _ => {
832                         return Err(Error::Unimplemented(format!(
833                             "expression_imagesample_level {:?}",
834                             level
835                         )));
836                     }
837                 };
838                 write!(self.out, "{}(", fun_name)?;
839                 self.write_expr(module, image, func_ctx)?;
840                 write!(self.out, ", ")?;
841                 self.write_expr(module, sampler, func_ctx)?;
842                 write!(self.out, ", ")?;
843                 self.write_expr(module, coordinate, func_ctx)?;
844                 write!(self.out, ")")?;
845             }
846             // TODO: copy-paste from msl-out
847             Expression::GlobalVariable(handle) => {
848                 let name = &self.names[&NameKey::GlobalVariable(handle)];
849                 write!(self.out, "{}", name)?;
850             }
851             Expression::As {
852                 expr,
853                 kind,
854                 convert: _, //TODO:
855             } => {
856                 let inner = func_ctx.info[expr].ty.inner_with(&module.types);
857                 let op = match *inner {
858                     TypeInner::Matrix { columns, rows, .. } => {
859                         format!("mat{}x{}", vector_size_str(columns), vector_size_str(rows))
860                     }
861                     TypeInner::Vector { size, .. } => format!("vec{}", vector_size_str(size)),
862                     TypeInner::Scalar { kind, .. } => String::from(scalar_kind_str(kind)),
863                     _ => {
864                         return Err(Error::Unimplemented(format!(
865                             "write_expr expression::as {:?}",
866                             inner
867                         )));
868                     }
869                 };
870                 let scalar = scalar_kind_str(kind);
871                 write!(self.out, "{}<{}>(", op, scalar)?;
872                 self.write_expr(module, expr, func_ctx)?;
873                 write!(self.out, ")")?;
874             }
875             Expression::Splat { size, value } => {
876                 let inner = func_ctx.info[value].ty.inner_with(&module.types);
877                 let scalar_kind = match *inner {
878                     crate::TypeInner::Scalar { kind, .. } => kind,
879                     _ => {
880                         return Err(Error::Unimplemented(format!(
881                             "write_expr expression::splat {:?}",
882                             inner
883                         )));
884                     }
885                 };
886                 let scalar = scalar_kind_str(scalar_kind);
887                 let size = vector_size_str(size);
888 
889                 write!(self.out, "vec{}<{}>(", size, scalar)?;
890                 self.write_expr(module, value, func_ctx)?;
891                 write!(self.out, ")")?;
892             }
893             //TODO: add pointer logic
894             Expression::Load { pointer } => self.write_expr(module, pointer, func_ctx)?,
895             Expression::LocalVariable(handle) => {
896                 let name_key = match func_ctx.ty {
897                     FunctionType::Function(func_handle) => {
898                         NameKey::FunctionLocal(func_handle, handle)
899                     }
900                     FunctionType::EntryPoint(idx) => NameKey::EntryPointLocal(idx, handle),
901                 };
902                 write!(self.out, "{}", self.names[&name_key])?
903             }
904             Expression::ArrayLength(expr) => {
905                 write!(self.out, "arrayLength(")?;
906                 self.write_expr(module, expr, func_ctx)?;
907                 write!(self.out, ")")?;
908             }
909             Expression::Math {
910                 fun,
911                 arg,
912                 arg1,
913                 arg2,
914             } => {
915                 use crate::MathFunction as Mf;
916 
917                 let fun_name = match fun {
918                     Mf::Length => "length",
919                     Mf::Mix => "mix",
920                     _ => {
921                         return Err(Error::Unimplemented(format!(
922                             "write_expr Math func {:?}",
923                             fun
924                         )));
925                     }
926                 };
927 
928                 write!(self.out, "{}(", fun_name)?;
929                 self.write_expr(module, arg, func_ctx)?;
930                 if let Some(arg) = arg1 {
931                     write!(self.out, ", ")?;
932                     self.write_expr(module, arg, func_ctx)?;
933                 }
934                 if let Some(arg) = arg2 {
935                     write!(self.out, ", ")?;
936                     self.write_expr(module, arg, func_ctx)?;
937                 }
938                 write!(self.out, ")")?
939             }
940             Expression::Swizzle {
941                 size,
942                 vector,
943                 pattern,
944             } => {
945                 self.write_expr(module, vector, func_ctx)?;
946                 write!(self.out, ".")?;
947                 for &sc in pattern[..size as usize].iter() {
948                     self.out.write_char(COMPONENTS[sc as usize])?;
949                 }
950             }
951             _ => {
952                 return Err(Error::Unimplemented(format!("write_expr {:?}", expression)));
953             }
954         }
955 
956         Ok(())
957     }
958 
959     /// Helper method used to write global variables
960     /// # Notes
961     /// Always adds a newline
write_global( &mut self, module: &Module, global: &GlobalVariable, handle: Handle<GlobalVariable>, ) -> BackendResult962     fn write_global(
963         &mut self,
964         module: &Module,
965         global: &GlobalVariable,
966         handle: Handle<GlobalVariable>,
967     ) -> BackendResult {
968         let name = self.names[&NameKey::GlobalVariable(handle)].clone();
969         // Write group and dinding attributes if present
970         if let Some(ref binding) = global.binding {
971             self.write_attributes(
972                 &[
973                     Attribute::Group(binding.group),
974                     Attribute::Binding(binding.binding),
975                 ],
976                 false,
977             )?;
978             writeln!(self.out)?;
979         }
980 
981         // First write only global name
982         write!(self.out, "var {}: ", name)?;
983         // Write access attribute if present
984         if !global.storage_access.is_empty() {
985             self.write_attributes(&[Attribute::Access(global.storage_access)], true)?;
986         }
987         // Write global type
988         self.write_type(module, global.ty)?;
989         // End with semicolon
990         writeln!(self.out, ";")?;
991 
992         Ok(())
993     }
994 
995     /// Helper method used to write constants
996     ///
997     /// # Notes
998     /// Doesn't add any newlines or leading/trailing spaces
write_constant(&mut self, module: &Module, handle: Handle<Constant>) -> BackendResult999     fn write_constant(&mut self, module: &Module, handle: Handle<Constant>) -> BackendResult {
1000         let constant = &module.constants[handle];
1001         match constant.inner {
1002             crate::ConstantInner::Scalar {
1003                 width: _,
1004                 ref value,
1005             } => {
1006                 if let Some(ref name) = constant.name {
1007                     write!(self.out, "{}", name)?;
1008                 } else {
1009                     self.write_scalar_value(*value)?;
1010                 }
1011             }
1012             crate::ConstantInner::Composite { ty, ref components } => {
1013                 self.write_type(module, ty)?;
1014                 write!(self.out, "(")?;
1015 
1016                 // Write the comma separated constants
1017                 for (index, constant) in components.iter().enumerate() {
1018                     self.write_constant(module, *constant)?;
1019                     // Only write a comma if isn't the last element
1020                     if index != components.len().saturating_sub(1) {
1021                         // The leading space is for readability only
1022                         write!(self.out, ", ")?;
1023                     }
1024                 }
1025                 write!(self.out, ")")?
1026             }
1027         }
1028 
1029         Ok(())
1030     }
1031 
1032     /// Helper method used to write global constants
1033     ///
1034     /// # Notes
1035     /// Ends in a newline
write_global_constant( &mut self, constant: &Constant, handle: Handle<Constant>, ) -> BackendResult1036     fn write_global_constant(
1037         &mut self,
1038         constant: &Constant,
1039         handle: Handle<Constant>,
1040     ) -> BackendResult {
1041         match constant.inner {
1042             crate::ConstantInner::Scalar {
1043                 width: _,
1044                 ref value,
1045             } => {
1046                 let name = self.names[&NameKey::Constant(handle)].clone();
1047                 // First write only constant name
1048                 write!(self.out, "let {}: ", name)?;
1049                 // Next write constant type and value
1050                 match *value {
1051                     crate::ScalarValue::Sint(value) => {
1052                         write!(self.out, "i32 = {}", value)?;
1053                     }
1054                     crate::ScalarValue::Uint(value) => {
1055                         write!(self.out, "u32 = {}", value)?;
1056                     }
1057                     crate::ScalarValue::Float(value) => {
1058                         // Floats are written using `Debug` instead of `Display` because it always appends the
1059                         // decimal part even it's zero
1060                         write!(self.out, "f32 = {:?}", value)?;
1061                     }
1062                     crate::ScalarValue::Bool(value) => {
1063                         write!(self.out, "bool = {}", value)?;
1064                     }
1065                 };
1066                 // End with semicolon and extra newline for readability
1067                 writeln!(self.out, ";")?;
1068                 writeln!(self.out)?;
1069             }
1070             _ => {
1071                 return Err(Error::Unimplemented(format!(
1072                     "write_global_constant {:?}",
1073                     constant.inner
1074                 )));
1075             }
1076         }
1077 
1078         Ok(())
1079     }
1080 
finish(self) -> W1081     pub fn finish(self) -> W {
1082         self.out
1083     }
1084 }
1085 
builtin_str(built_in: crate::BuiltIn) -> Option<&'static str>1086 fn builtin_str(built_in: crate::BuiltIn) -> Option<&'static str> {
1087     use crate::BuiltIn;
1088     match built_in {
1089         BuiltIn::VertexIndex => Some("vertex_index"),
1090         BuiltIn::InstanceIndex => Some("instance_index"),
1091         BuiltIn::Position => Some("position"),
1092         BuiltIn::FrontFacing => Some("front_facing"),
1093         BuiltIn::FragDepth => Some("frag_depth"),
1094         BuiltIn::LocalInvocationId => Some("local_invocation_id"),
1095         BuiltIn::LocalInvocationIndex => Some("local_invocation_index"),
1096         BuiltIn::GlobalInvocationId => Some("global_invocation_id"),
1097         BuiltIn::WorkGroupId => Some("workgroup_id"),
1098         BuiltIn::WorkGroupSize => Some("workgroup_size"),
1099         BuiltIn::SampleIndex => Some("sample_index"),
1100         BuiltIn::SampleMask => Some("sample_mask"),
1101         _ => None,
1102     }
1103 }
1104 
image_dimension_str(dim: ImageDimension) -> &'static str1105 fn image_dimension_str(dim: ImageDimension) -> &'static str {
1106     match dim {
1107         ImageDimension::D1 => "1d",
1108         ImageDimension::D2 => "2d",
1109         ImageDimension::D3 => "3d",
1110         ImageDimension::Cube => "cube",
1111     }
1112 }
1113 
scalar_kind_str(kind: ScalarKind) -> &'static str1114 fn scalar_kind_str(kind: ScalarKind) -> &'static str {
1115     match kind {
1116         crate::ScalarKind::Float => "f32",
1117         crate::ScalarKind::Sint => "i32",
1118         crate::ScalarKind::Uint => "u32",
1119         crate::ScalarKind::Bool => "bool",
1120     }
1121 }
1122 
storage_format_str(format: StorageFormat) -> &'static str1123 fn storage_format_str(format: StorageFormat) -> &'static str {
1124     match format {
1125         StorageFormat::R8Unorm => "r8unorm",
1126         StorageFormat::R8Snorm => "r8snorm",
1127         StorageFormat::R8Uint => "r8uint",
1128         StorageFormat::R8Sint => "r8sint",
1129         StorageFormat::R16Uint => "r16uint",
1130         StorageFormat::R16Sint => "r16sint",
1131         StorageFormat::R16Float => "r16float",
1132         StorageFormat::Rg8Unorm => "rg8unorm",
1133         StorageFormat::Rg8Snorm => "rg8snorm",
1134         StorageFormat::Rg8Uint => "rg8uint",
1135         StorageFormat::Rg8Sint => "rg8sint",
1136         StorageFormat::R32Uint => "r32uint",
1137         StorageFormat::R32Sint => "r32sint",
1138         StorageFormat::R32Float => "r32float",
1139         StorageFormat::Rg16Uint => "rg16uint",
1140         StorageFormat::Rg16Sint => "rg16sint",
1141         StorageFormat::Rg16Float => "rg16float",
1142         StorageFormat::Rgba8Unorm => "rgba8unorm",
1143         StorageFormat::Rgba8Snorm => "rgba8snorm",
1144         StorageFormat::Rgba8Uint => "rgba8uint",
1145         StorageFormat::Rgba8Sint => "rgba8sint",
1146         StorageFormat::Rgb10a2Unorm => "rgb10a2unorm",
1147         StorageFormat::Rg11b10Float => "rg11b10float",
1148         StorageFormat::Rg32Uint => "rg32uint",
1149         StorageFormat::Rg32Sint => "rg32sint",
1150         StorageFormat::Rg32Float => "rg32float",
1151         StorageFormat::Rgba16Uint => "rgba16uint",
1152         StorageFormat::Rgba16Sint => "rgba16sint",
1153         StorageFormat::Rgba16Float => "rgba16float",
1154         StorageFormat::Rgba32Uint => "rgba32uint",
1155         StorageFormat::Rgba32Sint => "rgba32sint",
1156         StorageFormat::Rgba32Float => "rgba32float",
1157     }
1158 }
1159 
1160 /// Helper function that returns the string corresponding to the WGSL interpolation qualifier
interpolation_str(interpolation: Interpolation) -> &'static str1161 fn interpolation_str(interpolation: Interpolation) -> &'static str {
1162     match interpolation {
1163         Interpolation::Perspective => "perspective",
1164         Interpolation::Linear => "linear",
1165         Interpolation::Flat => "flat",
1166     }
1167 }
1168 
1169 /// Return the WGSL auxiliary qualifier for the given sampling value.
sampling_str(sampling: Sampling) -> &'static str1170 fn sampling_str(sampling: Sampling) -> &'static str {
1171     match sampling {
1172         Sampling::Center => "",
1173         Sampling::Centroid => "centroid",
1174         Sampling::Sample => "sample",
1175     }
1176 }
1177 
map_binding_to_attribute(binding: &Binding) -> Vec<Attribute>1178 fn map_binding_to_attribute(binding: &Binding) -> Vec<Attribute> {
1179     match *binding {
1180         Binding::BuiltIn(built_in) => vec![Attribute::BuiltIn(built_in)],
1181         Binding::Location {
1182             location,
1183             interpolation,
1184             sampling,
1185         } => vec![
1186             Attribute::Location(location),
1187             Attribute::Interpolate(interpolation, sampling),
1188         ],
1189     }
1190 }
1191