1 //! Implementations for `BlockContext` methods.
2
3 use super::{
4 index::BoundsCheckResult, make_local, selection::Selection, Block, BlockContext, Dimension,
5 Error, Instruction, LocalType, LookupType, LoopContext, ResultMember, Writer, WriterFlags,
6 };
7 use crate::{arena::Handle, proc::TypeResolution};
8 use spirv::Word;
9
get_dimension(type_inner: &crate::TypeInner) -> Dimension10 fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
11 match *type_inner {
12 crate::TypeInner::Scalar { .. } => Dimension::Scalar,
13 crate::TypeInner::Vector { .. } => Dimension::Vector,
14 crate::TypeInner::Matrix { .. } => Dimension::Matrix,
15 _ => unreachable!(),
16 }
17 }
18
/// The result of emitting code for a pointer (left-hand-side) expression.
///
/// `write_expression_pointer` returns one of these on success. The `Load`
/// lowering in `cache_expression_value` either loads through the pointer
/// directly (`Ready`) or wraps the access in a bounds-checked conditional
/// load (`Conditional`).
enum ExpressionPointer {
    /// The pointer is available now: `pointer_id` is the id of the
    /// instruction whose value is the pointer to the expression's value.
    Ready { pointer_id: Word },

    /// The access must be conditional on `condition`.
    ///
    /// `condition` is the id of a boolean expression that is true if all
    /// indices are in bounds. `access` is an `OpAccessChain` instruction, not
    /// yet pushed to any block, that computes a pointer to the expression's
    /// value. It must only be executed when `condition` is true; executing it
    /// when `condition` is false would be undefined behavior.
    Conditional {
        condition: Word,
        access: Instruction,
    },
}
37
38 impl Writer {
39 // Flip Y coordinate to adjust for coordinate space difference
40 // between SPIR-V and our IR.
41 // The `position_id` argument is a pointer to a `vecN<f32>`,
42 // whose `y` component we will negate.
write_epilogue_position_y_flip( &mut self, position_id: Word, body: &mut Vec<Instruction>, ) -> Result<(), Error>43 fn write_epilogue_position_y_flip(
44 &mut self,
45 position_id: Word,
46 body: &mut Vec<Instruction>,
47 ) -> Result<(), Error> {
48 let float_ptr_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
49 vector_size: None,
50 kind: crate::ScalarKind::Float,
51 width: 4,
52 pointer_class: Some(spirv::StorageClass::Output),
53 }));
54 let index_y_id = self.get_index_constant(1);
55 let access_id = self.id_gen.next();
56 body.push(Instruction::access_chain(
57 float_ptr_type_id,
58 access_id,
59 position_id,
60 &[index_y_id],
61 ));
62
63 let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
64 vector_size: None,
65 kind: crate::ScalarKind::Float,
66 width: 4,
67 pointer_class: None,
68 }));
69 let load_id = self.id_gen.next();
70 body.push(Instruction::load(float_type_id, load_id, access_id, None));
71
72 let neg_id = self.id_gen.next();
73 body.push(Instruction::unary(
74 spirv::Op::FNegate,
75 float_type_id,
76 neg_id,
77 load_id,
78 ));
79
80 body.push(Instruction::store(access_id, neg_id, None));
81 Ok(())
82 }
83
84 // Clamp fragment depth between 0 and 1.
write_epilogue_frag_depth_clamp( &mut self, frag_depth_id: Word, body: &mut Vec<Instruction>, ) -> Result<(), Error>85 fn write_epilogue_frag_depth_clamp(
86 &mut self,
87 frag_depth_id: Word,
88 body: &mut Vec<Instruction>,
89 ) -> Result<(), Error> {
90 let float_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
91 vector_size: None,
92 kind: crate::ScalarKind::Float,
93 width: 4,
94 pointer_class: None,
95 }));
96 let value0_id = self.get_constant_scalar(crate::ScalarValue::Float(0.0), 4);
97 let value1_id = self.get_constant_scalar(crate::ScalarValue::Float(1.0), 4);
98
99 let original_id = self.id_gen.next();
100 body.push(Instruction::load(
101 float_type_id,
102 original_id,
103 frag_depth_id,
104 None,
105 ));
106
107 let clamp_id = self.id_gen.next();
108 body.push(Instruction::ext_inst(
109 self.gl450_ext_inst_id,
110 spirv::GLOp::FClamp,
111 float_type_id,
112 clamp_id,
113 &[original_id, value0_id, value1_id],
114 ));
115
116 body.push(Instruction::store(frag_depth_id, clamp_id, None));
117 Ok(())
118 }
119
write_entry_point_return( &mut self, value_id: Word, ir_result: &crate::FunctionResult, result_members: &[ResultMember], body: &mut Vec<Instruction>, ) -> Result<(), Error>120 fn write_entry_point_return(
121 &mut self,
122 value_id: Word,
123 ir_result: &crate::FunctionResult,
124 result_members: &[ResultMember],
125 body: &mut Vec<Instruction>,
126 ) -> Result<(), Error> {
127 for (index, res_member) in result_members.iter().enumerate() {
128 let member_value_id = match ir_result.binding {
129 Some(_) => value_id,
130 None => {
131 let member_value_id = self.id_gen.next();
132 body.push(Instruction::composite_extract(
133 res_member.type_id,
134 member_value_id,
135 value_id,
136 &[index as u32],
137 ));
138 member_value_id
139 }
140 };
141
142 body.push(Instruction::store(res_member.id, member_value_id, None));
143
144 match res_member.built_in {
145 Some(crate::BuiltIn::Position)
146 if self.flags.contains(WriterFlags::ADJUST_COORDINATE_SPACE) =>
147 {
148 self.write_epilogue_position_y_flip(res_member.id, body)?;
149 }
150 Some(crate::BuiltIn::FragDepth)
151 if self.flags.contains(WriterFlags::CLAMP_FRAG_DEPTH) =>
152 {
153 self.write_epilogue_frag_depth_clamp(res_member.id, body)?;
154 }
155 _ => {}
156 }
157 }
158 Ok(())
159 }
160 }
161
162 impl<'w> BlockContext<'w> {
163 /// Decide whether to put off emitting instructions for `expr_handle`.
164 ///
165 /// We would like to gather together chains of `Access` and `AccessIndex`
166 /// Naga expressions into a single `OpAccessChain` SPIR-V instruction. To do
167 /// this, we don't generate instructions for these exprs when we first
168 /// encounter them. Their ids in `self.writer.cached.ids` are left as zero. Then,
169 /// once we encounter a `Load` or `Store` expression that actually needs the
170 /// chain's value, we call `write_expression_pointer` to handle the whole
171 /// thing in one fell swoop.
is_intermediate(&self, expr_handle: Handle<crate::Expression>) -> bool172 fn is_intermediate(&self, expr_handle: Handle<crate::Expression>) -> bool {
173 match self.ir_function.expressions[expr_handle] {
174 crate::Expression::GlobalVariable(_) | crate::Expression::LocalVariable(_) => true,
175 crate::Expression::FunctionArgument(index) => {
176 let arg = &self.ir_function.arguments[index as usize];
177 self.ir_module.types[arg.ty].inner.pointer_class().is_some()
178 }
179
180 // The chain rule: if this `Access...`'s `base` operand was
181 // previously omitted, then omit this one, too.
182 _ => self.cached.ids[expr_handle.index()] == 0,
183 }
184 }
185
186 /// Cache an expression for a value.
cache_expression_value( &mut self, expr_handle: Handle<crate::Expression>, block: &mut Block, ) -> Result<(), Error>187 pub(super) fn cache_expression_value(
188 &mut self,
189 expr_handle: Handle<crate::Expression>,
190 block: &mut Block,
191 ) -> Result<(), Error> {
192 let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
193
194 let id = match self.ir_function.expressions[expr_handle] {
195 crate::Expression::Access { base, index: _ } if self.is_intermediate(base) => {
196 // See `is_intermediate`; we'll handle this later in
197 // `write_expression_pointer`.
198 0
199 }
200 crate::Expression::Access { base, index } => {
201 let base_ty = self.fun_info[base].ty.inner_with(&self.ir_module.types);
202 match *base_ty {
203 crate::TypeInner::Vector { .. } => (),
204 ref other => {
205 log::error!(
206 "Unable to access base {:?} of type {:?}",
207 self.ir_function.expressions[base],
208 other
209 );
210 return Err(Error::Validation(
211 "only vectors may be dynamically indexed by value",
212 ));
213 }
214 };
215
216 self.write_vector_access(expr_handle, base, index, block)?
217 }
218 crate::Expression::AccessIndex { base, index: _ } if self.is_intermediate(base) => {
219 // See `is_intermediate`; we'll handle this later in
220 // `write_expression_pointer`.
221 0
222 }
223 crate::Expression::AccessIndex { base, index } => {
224 match *self.fun_info[base].ty.inner_with(&self.ir_module.types) {
225 crate::TypeInner::Vector { .. }
226 | crate::TypeInner::Matrix { .. }
227 | crate::TypeInner::Array { .. }
228 | crate::TypeInner::Struct { .. } => {
229 // We never need bounds checks here: dynamically sized arrays can
230 // only appear behind pointers, and are thus handled by the
231 // `is_intermediate` case above. Everything else's size is
232 // statically known and checked in validation.
233 let id = self.gen_id();
234 let base_id = self.cached[base];
235 block.body.push(Instruction::composite_extract(
236 result_type_id,
237 id,
238 base_id,
239 &[index],
240 ));
241 id
242 }
243 ref other => {
244 log::error!("Unable to access index of {:?}", other);
245 return Err(Error::FeatureNotImplemented("access index for type"));
246 }
247 }
248 }
249 crate::Expression::GlobalVariable(handle) => {
250 self.writer.global_variables[handle.index()].access_id
251 }
252 crate::Expression::Constant(handle) => self.writer.constant_ids[handle.index()],
253 crate::Expression::Splat { size, value } => {
254 let value_id = self.cached[value];
255 let components = [value_id; 4];
256 let id = self.gen_id();
257 block.body.push(Instruction::composite_construct(
258 result_type_id,
259 id,
260 &components[..size as usize],
261 ));
262 id
263 }
264 crate::Expression::Swizzle {
265 size,
266 vector,
267 pattern,
268 } => {
269 let vector_id = self.cached[vector];
270 self.temp_list.clear();
271 for &sc in pattern[..size as usize].iter() {
272 self.temp_list.push(sc as Word);
273 }
274 let id = self.gen_id();
275 block.body.push(Instruction::vector_shuffle(
276 result_type_id,
277 id,
278 vector_id,
279 vector_id,
280 &self.temp_list,
281 ));
282 id
283 }
284 crate::Expression::Compose {
285 ty: _,
286 ref components,
287 } => {
288 self.temp_list.clear();
289 for &component in components {
290 self.temp_list.push(self.cached[component]);
291 }
292
293 let id = self.gen_id();
294 block.body.push(Instruction::composite_construct(
295 result_type_id,
296 id,
297 &self.temp_list,
298 ));
299 id
300 }
301 crate::Expression::Unary { op, expr } => {
302 let id = self.gen_id();
303 let expr_id = self.cached[expr];
304 let expr_ty_inner = self.fun_info[expr].ty.inner_with(&self.ir_module.types);
305
306 let spirv_op = match op {
307 crate::UnaryOperator::Negate => match expr_ty_inner.scalar_kind() {
308 Some(crate::ScalarKind::Float) => spirv::Op::FNegate,
309 Some(crate::ScalarKind::Sint) => spirv::Op::SNegate,
310 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNot,
311 Some(crate::ScalarKind::Uint) | None => {
312 log::error!("Unable to negate {:?}", expr_ty_inner);
313 return Err(Error::FeatureNotImplemented("negation"));
314 }
315 },
316 crate::UnaryOperator::Not => match expr_ty_inner.scalar_kind() {
317 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNot,
318 _ => spirv::Op::Not,
319 },
320 };
321
322 block
323 .body
324 .push(Instruction::unary(spirv_op, result_type_id, id, expr_id));
325 id
326 }
327 crate::Expression::Binary { op, left, right } => {
328 let id = self.gen_id();
329 let left_id = self.cached[left];
330 let right_id = self.cached[right];
331
332 let left_ty_inner = self.fun_info[left].ty.inner_with(&self.ir_module.types);
333 let right_ty_inner = self.fun_info[right].ty.inner_with(&self.ir_module.types);
334
335 let left_dimension = get_dimension(left_ty_inner);
336 let right_dimension = get_dimension(right_ty_inner);
337
338 let mut preserve_order = true;
339
340 let spirv_op = match op {
341 crate::BinaryOperator::Add => match *left_ty_inner {
342 crate::TypeInner::Scalar { kind, .. }
343 | crate::TypeInner::Vector { kind, .. } => match kind {
344 crate::ScalarKind::Float => spirv::Op::FAdd,
345 _ => spirv::Op::IAdd,
346 },
347 _ => unimplemented!(),
348 },
349 crate::BinaryOperator::Subtract => match *left_ty_inner {
350 crate::TypeInner::Scalar { kind, .. }
351 | crate::TypeInner::Vector { kind, .. } => match kind {
352 crate::ScalarKind::Float => spirv::Op::FSub,
353 _ => spirv::Op::ISub,
354 },
355 _ => unimplemented!(),
356 },
357 crate::BinaryOperator::Multiply => match (left_dimension, right_dimension) {
358 (Dimension::Scalar, Dimension::Vector { .. }) => {
359 preserve_order = false;
360 spirv::Op::VectorTimesScalar
361 }
362 (Dimension::Vector, Dimension::Scalar { .. }) => {
363 spirv::Op::VectorTimesScalar
364 }
365 (Dimension::Vector, Dimension::Matrix) => spirv::Op::VectorTimesMatrix,
366 (Dimension::Matrix, Dimension::Scalar { .. }) => {
367 spirv::Op::MatrixTimesScalar
368 }
369 (Dimension::Scalar, Dimension::Matrix { .. }) => {
370 preserve_order = false;
371 spirv::Op::MatrixTimesScalar
372 }
373 (Dimension::Matrix, Dimension::Vector) => spirv::Op::MatrixTimesVector,
374 (Dimension::Matrix, Dimension::Matrix) => spirv::Op::MatrixTimesMatrix,
375 (Dimension::Vector, Dimension::Vector)
376 | (Dimension::Scalar, Dimension::Scalar)
377 if left_ty_inner.scalar_kind() == Some(crate::ScalarKind::Float) =>
378 {
379 spirv::Op::FMul
380 }
381 (Dimension::Vector, Dimension::Vector)
382 | (Dimension::Scalar, Dimension::Scalar) => spirv::Op::IMul,
383 },
384 crate::BinaryOperator::Divide => match left_ty_inner.scalar_kind() {
385 Some(crate::ScalarKind::Sint) => spirv::Op::SDiv,
386 Some(crate::ScalarKind::Uint) => spirv::Op::UDiv,
387 Some(crate::ScalarKind::Float) => spirv::Op::FDiv,
388 _ => unimplemented!(),
389 },
390 crate::BinaryOperator::Modulo => match left_ty_inner.scalar_kind() {
391 Some(crate::ScalarKind::Sint) => spirv::Op::SMod,
392 Some(crate::ScalarKind::Uint) => spirv::Op::UMod,
393 Some(crate::ScalarKind::Float) => spirv::Op::FRem,
394 _ => unimplemented!(),
395 },
396 crate::BinaryOperator::Equal => match left_ty_inner.scalar_kind() {
397 Some(crate::ScalarKind::Sint) | Some(crate::ScalarKind::Uint) => {
398 spirv::Op::IEqual
399 }
400 Some(crate::ScalarKind::Float) => spirv::Op::FOrdEqual,
401 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalEqual,
402 _ => unimplemented!(),
403 },
404 crate::BinaryOperator::NotEqual => match left_ty_inner.scalar_kind() {
405 Some(crate::ScalarKind::Sint) | Some(crate::ScalarKind::Uint) => {
406 spirv::Op::INotEqual
407 }
408 Some(crate::ScalarKind::Float) => spirv::Op::FOrdNotEqual,
409 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalNotEqual,
410 _ => unimplemented!(),
411 },
412 crate::BinaryOperator::Less => match left_ty_inner.scalar_kind() {
413 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThan,
414 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThan,
415 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThan,
416 _ => unimplemented!(),
417 },
418 crate::BinaryOperator::LessEqual => match left_ty_inner.scalar_kind() {
419 Some(crate::ScalarKind::Sint) => spirv::Op::SLessThanEqual,
420 Some(crate::ScalarKind::Uint) => spirv::Op::ULessThanEqual,
421 Some(crate::ScalarKind::Float) => spirv::Op::FOrdLessThanEqual,
422 _ => unimplemented!(),
423 },
424 crate::BinaryOperator::Greater => match left_ty_inner.scalar_kind() {
425 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThan,
426 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThan,
427 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThan,
428 _ => unimplemented!(),
429 },
430 crate::BinaryOperator::GreaterEqual => match left_ty_inner.scalar_kind() {
431 Some(crate::ScalarKind::Sint) => spirv::Op::SGreaterThanEqual,
432 Some(crate::ScalarKind::Uint) => spirv::Op::UGreaterThanEqual,
433 Some(crate::ScalarKind::Float) => spirv::Op::FOrdGreaterThanEqual,
434 _ => unimplemented!(),
435 },
436 crate::BinaryOperator::And => match left_ty_inner.scalar_kind() {
437 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalAnd,
438 _ => spirv::Op::BitwiseAnd,
439 },
440 crate::BinaryOperator::ExclusiveOr => spirv::Op::BitwiseXor,
441 crate::BinaryOperator::InclusiveOr => match left_ty_inner.scalar_kind() {
442 Some(crate::ScalarKind::Bool) => spirv::Op::LogicalOr,
443 _ => spirv::Op::BitwiseOr,
444 },
445 crate::BinaryOperator::LogicalAnd => spirv::Op::LogicalAnd,
446 crate::BinaryOperator::LogicalOr => spirv::Op::LogicalOr,
447 crate::BinaryOperator::ShiftLeft => spirv::Op::ShiftLeftLogical,
448 crate::BinaryOperator::ShiftRight => match left_ty_inner.scalar_kind() {
449 Some(crate::ScalarKind::Sint) => spirv::Op::ShiftRightArithmetic,
450 Some(crate::ScalarKind::Uint) => spirv::Op::ShiftRightLogical,
451 _ => unimplemented!(),
452 },
453 };
454
455 block.body.push(Instruction::binary(
456 spirv_op,
457 result_type_id,
458 id,
459 if preserve_order { left_id } else { right_id },
460 if preserve_order { right_id } else { left_id },
461 ));
462 id
463 }
464 crate::Expression::Math {
465 fun,
466 arg,
467 arg1,
468 arg2,
469 arg3,
470 } => {
471 use crate::MathFunction as Mf;
472 enum MathOp {
473 Ext(spirv::GLOp),
474 Custom(Instruction),
475 }
476
477 let arg0_id = self.cached[arg];
478 let arg_ty = self.fun_info[arg].ty.inner_with(&self.ir_module.types);
479 let arg_scalar_kind = arg_ty.scalar_kind();
480 let arg1_id = match arg1 {
481 Some(handle) => self.cached[handle],
482 None => 0,
483 };
484 let arg2_id = match arg2 {
485 Some(handle) => self.cached[handle],
486 None => 0,
487 };
488 let arg3_id = match arg3 {
489 Some(handle) => self.cached[handle],
490 None => 0,
491 };
492
493 let id = self.gen_id();
494 let math_op = match fun {
495 // comparison
496 Mf::Abs => {
497 match arg_scalar_kind {
498 Some(crate::ScalarKind::Float) => MathOp::Ext(spirv::GLOp::FAbs),
499 Some(crate::ScalarKind::Sint) => MathOp::Ext(spirv::GLOp::SAbs),
500 Some(crate::ScalarKind::Uint) => {
501 MathOp::Custom(Instruction::unary(
502 spirv::Op::CopyObject, // do nothing
503 result_type_id,
504 id,
505 arg0_id,
506 ))
507 }
508 other => unimplemented!("Unexpected abs({:?})", other),
509 }
510 }
511 Mf::Min => MathOp::Ext(match arg_scalar_kind {
512 Some(crate::ScalarKind::Float) => spirv::GLOp::FMin,
513 Some(crate::ScalarKind::Sint) => spirv::GLOp::SMin,
514 Some(crate::ScalarKind::Uint) => spirv::GLOp::UMin,
515 other => unimplemented!("Unexpected min({:?})", other),
516 }),
517 Mf::Max => MathOp::Ext(match arg_scalar_kind {
518 Some(crate::ScalarKind::Float) => spirv::GLOp::FMax,
519 Some(crate::ScalarKind::Sint) => spirv::GLOp::SMax,
520 Some(crate::ScalarKind::Uint) => spirv::GLOp::UMax,
521 other => unimplemented!("Unexpected max({:?})", other),
522 }),
523 Mf::Clamp => MathOp::Ext(match arg_scalar_kind {
524 Some(crate::ScalarKind::Float) => spirv::GLOp::FClamp,
525 Some(crate::ScalarKind::Sint) => spirv::GLOp::SClamp,
526 Some(crate::ScalarKind::Uint) => spirv::GLOp::UClamp,
527 other => unimplemented!("Unexpected max({:?})", other),
528 }),
529 // trigonometry
530 Mf::Sin => MathOp::Ext(spirv::GLOp::Sin),
531 Mf::Sinh => MathOp::Ext(spirv::GLOp::Sinh),
532 Mf::Asin => MathOp::Ext(spirv::GLOp::Asin),
533 Mf::Cos => MathOp::Ext(spirv::GLOp::Cos),
534 Mf::Cosh => MathOp::Ext(spirv::GLOp::Cosh),
535 Mf::Acos => MathOp::Ext(spirv::GLOp::Acos),
536 Mf::Tan => MathOp::Ext(spirv::GLOp::Tan),
537 Mf::Tanh => MathOp::Ext(spirv::GLOp::Tanh),
538 Mf::Atan => MathOp::Ext(spirv::GLOp::Atan),
539 Mf::Atan2 => MathOp::Ext(spirv::GLOp::Atan2),
540 Mf::Asinh => MathOp::Ext(spirv::GLOp::Asinh),
541 Mf::Acosh => MathOp::Ext(spirv::GLOp::Acosh),
542 Mf::Atanh => MathOp::Ext(spirv::GLOp::Atanh),
543 Mf::Radians => MathOp::Ext(spirv::GLOp::Radians),
544 Mf::Degrees => MathOp::Ext(spirv::GLOp::Degrees),
545 // decomposition
546 Mf::Ceil => MathOp::Ext(spirv::GLOp::Ceil),
547 Mf::Round => MathOp::Ext(spirv::GLOp::RoundEven),
548 Mf::Floor => MathOp::Ext(spirv::GLOp::Floor),
549 Mf::Fract => MathOp::Ext(spirv::GLOp::Fract),
550 Mf::Trunc => MathOp::Ext(spirv::GLOp::Trunc),
551 Mf::Modf => MathOp::Ext(spirv::GLOp::Modf),
552 Mf::Frexp => MathOp::Ext(spirv::GLOp::Frexp),
553 Mf::Ldexp => MathOp::Ext(spirv::GLOp::Ldexp),
554 // geometry
555 Mf::Dot => MathOp::Custom(Instruction::binary(
556 spirv::Op::Dot,
557 result_type_id,
558 id,
559 arg0_id,
560 arg1_id,
561 )),
562 Mf::Outer => MathOp::Custom(Instruction::binary(
563 spirv::Op::OuterProduct,
564 result_type_id,
565 id,
566 arg0_id,
567 arg1_id,
568 )),
569 Mf::Cross => MathOp::Ext(spirv::GLOp::Cross),
570 Mf::Distance => MathOp::Ext(spirv::GLOp::Distance),
571 Mf::Length => MathOp::Ext(spirv::GLOp::Length),
572 Mf::Normalize => MathOp::Ext(spirv::GLOp::Normalize),
573 Mf::FaceForward => MathOp::Ext(spirv::GLOp::FaceForward),
574 Mf::Reflect => MathOp::Ext(spirv::GLOp::Reflect),
575 Mf::Refract => MathOp::Ext(spirv::GLOp::Refract),
576 // exponent
577 Mf::Exp => MathOp::Ext(spirv::GLOp::Exp),
578 Mf::Exp2 => MathOp::Ext(spirv::GLOp::Exp2),
579 Mf::Log => MathOp::Ext(spirv::GLOp::Log),
580 Mf::Log2 => MathOp::Ext(spirv::GLOp::Log2),
581 Mf::Pow => MathOp::Ext(spirv::GLOp::Pow),
582 // computational
583 Mf::Sign => MathOp::Ext(match arg_scalar_kind {
584 Some(crate::ScalarKind::Float) => spirv::GLOp::FSign,
585 Some(crate::ScalarKind::Sint) => spirv::GLOp::SSign,
586 other => unimplemented!("Unexpected sign({:?})", other),
587 }),
588 Mf::Fma => MathOp::Ext(spirv::GLOp::Fma),
589 Mf::Mix => {
590 let selector = arg2.unwrap();
591 let selector_ty =
592 self.fun_info[selector].ty.inner_with(&self.ir_module.types);
593 match (arg_ty, selector_ty) {
594 // if the selector is a scalar, we need to splat it
595 (
596 &crate::TypeInner::Vector { size, .. },
597 &crate::TypeInner::Scalar { kind, width },
598 ) => {
599 let selector_type_id =
600 self.get_type_id(LookupType::Local(LocalType::Value {
601 vector_size: Some(size),
602 kind,
603 width,
604 pointer_class: None,
605 }));
606 self.temp_list.clear();
607 self.temp_list.resize(size as usize, arg2_id);
608
609 let selector_id = self.gen_id();
610 block.body.push(Instruction::composite_construct(
611 selector_type_id,
612 selector_id,
613 &self.temp_list,
614 ));
615
616 MathOp::Custom(Instruction::ext_inst(
617 self.writer.gl450_ext_inst_id,
618 spirv::GLOp::FMix,
619 result_type_id,
620 id,
621 &[arg0_id, arg1_id, selector_id],
622 ))
623 }
624 _ => MathOp::Ext(spirv::GLOp::FMix),
625 }
626 }
627 Mf::Step => MathOp::Ext(spirv::GLOp::Step),
628 Mf::SmoothStep => MathOp::Ext(spirv::GLOp::SmoothStep),
629 Mf::Sqrt => MathOp::Ext(spirv::GLOp::Sqrt),
630 Mf::InverseSqrt => MathOp::Ext(spirv::GLOp::InverseSqrt),
631 Mf::Inverse => MathOp::Ext(spirv::GLOp::MatrixInverse),
632 Mf::Transpose => MathOp::Custom(Instruction::unary(
633 spirv::Op::Transpose,
634 result_type_id,
635 id,
636 arg0_id,
637 )),
638 Mf::Determinant => MathOp::Ext(spirv::GLOp::Determinant),
639 Mf::ReverseBits | Mf::CountOneBits => {
640 log::error!("unimplemented math function {:?}", fun);
641 return Err(Error::FeatureNotImplemented("math function"));
642 }
643 Mf::ExtractBits => {
644 let op = match arg_scalar_kind {
645 Some(crate::ScalarKind::Uint) => spirv::Op::BitFieldUExtract,
646 Some(crate::ScalarKind::Sint) => spirv::Op::BitFieldSExtract,
647 other => unimplemented!("Unexpected sign({:?})", other),
648 };
649 MathOp::Custom(Instruction::ternary(
650 op,
651 result_type_id,
652 id,
653 arg0_id,
654 arg1_id,
655 arg2_id,
656 ))
657 }
658 Mf::InsertBits => MathOp::Custom(Instruction::quaternary(
659 spirv::Op::BitFieldInsert,
660 result_type_id,
661 id,
662 arg0_id,
663 arg1_id,
664 arg2_id,
665 arg3_id,
666 )),
667 Mf::FindLsb => MathOp::Ext(spirv::GLOp::FindILsb),
668 Mf::FindMsb => MathOp::Ext(match arg_scalar_kind {
669 Some(crate::ScalarKind::Uint) => spirv::GLOp::FindUMsb,
670 Some(crate::ScalarKind::Sint) => spirv::GLOp::FindSMsb,
671 other => unimplemented!("Unexpected findMSB({:?})", other),
672 }),
673 Mf::Pack4x8unorm => MathOp::Ext(spirv::GLOp::PackUnorm4x8),
674 Mf::Pack4x8snorm => MathOp::Ext(spirv::GLOp::PackSnorm4x8),
675 Mf::Pack2x16float => MathOp::Ext(spirv::GLOp::PackHalf2x16),
676 Mf::Pack2x16unorm => MathOp::Ext(spirv::GLOp::PackUnorm2x16),
677 Mf::Pack2x16snorm => MathOp::Ext(spirv::GLOp::PackSnorm2x16),
678 Mf::Unpack4x8unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm4x8),
679 Mf::Unpack4x8snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm4x8),
680 Mf::Unpack2x16float => MathOp::Ext(spirv::GLOp::UnpackHalf2x16),
681 Mf::Unpack2x16unorm => MathOp::Ext(spirv::GLOp::UnpackUnorm2x16),
682 Mf::Unpack2x16snorm => MathOp::Ext(spirv::GLOp::UnpackSnorm2x16),
683 };
684
685 block.body.push(match math_op {
686 MathOp::Ext(op) => Instruction::ext_inst(
687 self.writer.gl450_ext_inst_id,
688 op,
689 result_type_id,
690 id,
691 &[arg0_id, arg1_id, arg2_id, arg3_id][..fun.argument_count()],
692 ),
693 MathOp::Custom(inst) => inst,
694 });
695 id
696 }
697 crate::Expression::LocalVariable(variable) => self.function.variables[&variable].id,
698 crate::Expression::Load { pointer } => {
699 match self.write_expression_pointer(pointer, block)? {
700 ExpressionPointer::Ready { pointer_id } => {
701 let id = self.gen_id();
702 let atomic_class =
703 match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
704 crate::TypeInner::Pointer { base, class } => {
705 match self.ir_module.types[base].inner {
706 crate::TypeInner::Atomic { .. } => Some(class),
707 _ => None,
708 }
709 }
710 _ => None,
711 };
712 let instruction = if let Some(class) = atomic_class {
713 let (semantics, scope) = class.to_spirv_semantics_and_scope();
714 let scope_constant_id = self.get_scope_constant(scope as u32);
715 let semantics_id = self.get_index_constant(semantics.bits());
716 Instruction::atomic_load(
717 result_type_id,
718 id,
719 pointer_id,
720 scope_constant_id,
721 semantics_id,
722 )
723 } else {
724 Instruction::load(result_type_id, id, pointer_id, None)
725 };
726 block.body.push(instruction);
727 id
728 }
729 ExpressionPointer::Conditional { condition, access } => {
730 //TODO: support atomics?
731 self.write_conditional_indexed_load(
732 result_type_id,
733 condition,
734 block,
735 move |id_gen, block| {
736 // The in-bounds path. Perform the access and the load.
737 let pointer_id = access.result_id.unwrap();
738 let value_id = id_gen.next();
739 block.body.push(access);
740 block.body.push(Instruction::load(
741 result_type_id,
742 value_id,
743 pointer_id,
744 None,
745 ));
746 value_id
747 },
748 )
749 }
750 }
751 }
752 crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
753 crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } => {
754 self.cached[expr_handle]
755 }
756 crate::Expression::As {
757 expr,
758 kind,
759 convert,
760 } => {
761 use crate::ScalarKind as Sk;
762
763 let expr_id = self.cached[expr];
764 let (src_kind, src_size, src_width) =
765 match *self.fun_info[expr].ty.inner_with(&self.ir_module.types) {
766 crate::TypeInner::Scalar { kind, width } => (kind, None, width),
767 crate::TypeInner::Vector { kind, width, size } => (kind, Some(size), width),
768 ref other => {
769 log::error!("As source {:?}", other);
770 return Err(Error::Validation("Unexpected Expression::As source"));
771 }
772 };
773
774 enum Cast {
775 Unary(spirv::Op),
776 Binary(spirv::Op, Word),
777 Ternary(spirv::Op, Word, Word),
778 }
779
780 let cast = match (src_kind, kind, convert) {
781 (_, _, None) | (Sk::Bool, Sk::Bool, Some(_)) => Cast::Unary(spirv::Op::Bitcast),
782 // casting to a bool - generate `OpXxxNotEqual`
783 (_, Sk::Bool, Some(_)) => {
784 let (op, value) = match src_kind {
785 Sk::Sint => (spirv::Op::INotEqual, crate::ScalarValue::Sint(0)),
786 Sk::Uint => (spirv::Op::INotEqual, crate::ScalarValue::Uint(0)),
787 Sk::Float => {
788 (spirv::Op::FUnordNotEqual, crate::ScalarValue::Float(0.0))
789 }
790 Sk::Bool => unreachable!(),
791 };
792 let zero_scalar_id = self.writer.get_constant_scalar(value, src_width);
793 let zero_id = match src_size {
794 Some(size) => {
795 let vector_type_id =
796 self.get_type_id(LookupType::Local(LocalType::Value {
797 vector_size: Some(size),
798 kind: src_kind,
799 width: src_width,
800 pointer_class: None,
801 }));
802 let components = [zero_scalar_id; 4];
803
804 let zero_id = self.gen_id();
805 block.body.push(Instruction::composite_construct(
806 vector_type_id,
807 zero_id,
808 &components[..size as usize],
809 ));
810 zero_id
811 }
812 None => zero_scalar_id,
813 };
814
815 Cast::Binary(op, zero_id)
816 }
817 // casting from a bool - generate `OpSelect`
818 (Sk::Bool, _, Some(dst_width)) => {
819 let (val0, val1) = match kind {
820 Sk::Sint => (crate::ScalarValue::Sint(0), crate::ScalarValue::Sint(1)),
821 Sk::Uint => (crate::ScalarValue::Uint(0), crate::ScalarValue::Uint(1)),
822 Sk::Float => (
823 crate::ScalarValue::Float(0.0),
824 crate::ScalarValue::Float(1.0),
825 ),
826 Sk::Bool => unreachable!(),
827 };
828 let scalar0_id = self.writer.get_constant_scalar(val0, dst_width);
829 let scalar1_id = self.writer.get_constant_scalar(val1, dst_width);
830 let (accept_id, reject_id) = match src_size {
831 Some(size) => {
832 let vector_type_id =
833 self.get_type_id(LookupType::Local(LocalType::Value {
834 vector_size: Some(size),
835 kind,
836 width: dst_width,
837 pointer_class: None,
838 }));
839 let components0 = [scalar0_id; 4];
840 let components1 = [scalar1_id; 4];
841
842 let vec0_id = self.gen_id();
843 block.body.push(Instruction::composite_construct(
844 vector_type_id,
845 vec0_id,
846 &components0[..size as usize],
847 ));
848 let vec1_id = self.gen_id();
849 block.body.push(Instruction::composite_construct(
850 vector_type_id,
851 vec1_id,
852 &components1[..size as usize],
853 ));
854 (vec1_id, vec0_id)
855 }
856 None => (scalar1_id, scalar0_id),
857 };
858
859 Cast::Ternary(spirv::Op::Select, accept_id, reject_id)
860 }
861 (Sk::Float, Sk::Uint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToU),
862 (Sk::Float, Sk::Sint, Some(_)) => Cast::Unary(spirv::Op::ConvertFToS),
863 (Sk::Float, Sk::Float, Some(dst_width)) if src_width != dst_width => {
864 Cast::Unary(spirv::Op::FConvert)
865 }
866 (Sk::Sint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertSToF),
867 (Sk::Sint, Sk::Sint, Some(dst_width)) if src_width != dst_width => {
868 Cast::Unary(spirv::Op::SConvert)
869 }
870 (Sk::Uint, Sk::Float, Some(_)) => Cast::Unary(spirv::Op::ConvertUToF),
871 (Sk::Uint, Sk::Uint, Some(dst_width)) if src_width != dst_width => {
872 Cast::Unary(spirv::Op::UConvert)
873 }
874 // We assume it's either an identity cast, or int-uint.
875 _ => Cast::Unary(spirv::Op::Bitcast),
876 };
877
878 let id = self.gen_id();
879 let instruction = match cast {
880 Cast::Unary(op) => Instruction::unary(op, result_type_id, id, expr_id),
881 Cast::Binary(op, operand) => {
882 Instruction::binary(op, result_type_id, id, expr_id, operand)
883 }
884 Cast::Ternary(op, op1, op2) => {
885 Instruction::ternary(op, result_type_id, id, expr_id, op1, op2)
886 }
887 };
888 block.body.push(instruction);
889 id
890 }
891 crate::Expression::ImageLoad {
892 image,
893 coordinate,
894 array_index,
895 index,
896 } => {
897 self.write_image_load(result_type_id, image, coordinate, array_index, index, block)?
898 }
899 crate::Expression::ImageSample {
900 image,
901 sampler,
902 gather,
903 coordinate,
904 array_index,
905 offset,
906 level,
907 depth_ref,
908 } => self.write_image_sample(
909 result_type_id,
910 image,
911 sampler,
912 gather,
913 coordinate,
914 array_index,
915 offset,
916 level,
917 depth_ref,
918 block,
919 )?,
920 crate::Expression::Select {
921 condition,
922 accept,
923 reject,
924 } => {
925 let id = self.gen_id();
926 let mut condition_id = self.cached[condition];
927 let accept_id = self.cached[accept];
928 let reject_id = self.cached[reject];
929
930 let condition_ty = self.fun_info[condition]
931 .ty
932 .inner_with(&self.ir_module.types);
933 let object_ty = self.fun_info[accept].ty.inner_with(&self.ir_module.types);
934
935 if let (
936 &crate::TypeInner::Scalar {
937 kind: crate::ScalarKind::Bool,
938 width,
939 },
940 &crate::TypeInner::Vector { size, .. },
941 ) = (condition_ty, object_ty)
942 {
943 self.temp_list.clear();
944 self.temp_list.resize(size as usize, condition_id);
945
946 let bool_vector_type_id =
947 self.get_type_id(LookupType::Local(LocalType::Value {
948 vector_size: Some(size),
949 kind: crate::ScalarKind::Bool,
950 width,
951 pointer_class: None,
952 }));
953
954 let id = self.gen_id();
955 block.body.push(Instruction::composite_construct(
956 bool_vector_type_id,
957 id,
958 &self.temp_list,
959 ));
960 condition_id = id
961 }
962
963 let instruction =
964 Instruction::select(result_type_id, id, condition_id, accept_id, reject_id);
965 block.body.push(instruction);
966 id
967 }
968 crate::Expression::Derivative { axis, expr } => {
969 use crate::DerivativeAxis as Da;
970
971 let id = self.gen_id();
972 let expr_id = self.cached[expr];
973 let op = match axis {
974 Da::X => spirv::Op::DPdx,
975 Da::Y => spirv::Op::DPdy,
976 Da::Width => spirv::Op::Fwidth,
977 };
978 block
979 .body
980 .push(Instruction::derivative(op, result_type_id, id, expr_id));
981 id
982 }
983 crate::Expression::ImageQuery { image, query } => {
984 self.write_image_query(result_type_id, image, query, block)?
985 }
986 crate::Expression::Relational { fun, argument } => {
987 use crate::RelationalFunction as Rf;
988 let arg_id = self.cached[argument];
989 let op = match fun {
990 Rf::All => spirv::Op::All,
991 Rf::Any => spirv::Op::Any,
992 Rf::IsNan => spirv::Op::IsNan,
993 Rf::IsInf => spirv::Op::IsInf,
994 //TODO: these require Kernel capability
995 Rf::IsFinite | Rf::IsNormal => {
996 return Err(Error::FeatureNotImplemented("is finite/normal"))
997 }
998 };
999 let id = self.gen_id();
1000 block
1001 .body
1002 .push(Instruction::relational(op, result_type_id, id, arg_id));
1003 id
1004 }
1005 crate::Expression::ArrayLength(expr) => self.write_runtime_array_length(expr, block)?,
1006 };
1007
1008 self.cached[expr_handle] = id;
1009 Ok(())
1010 }
1011
1012 /// Build an `OpAccessChain` instruction.
1013 ///
1014 /// Emit any needed bounds-checking expressions to `block`.
1015 ///
1016 /// On success, the return value is an [`ExpressionPointer`] value; see the
1017 /// documentation for that type.
write_expression_pointer( &mut self, mut expr_handle: Handle<crate::Expression>, block: &mut Block, ) -> Result<ExpressionPointer, Error>1018 fn write_expression_pointer(
1019 &mut self,
1020 mut expr_handle: Handle<crate::Expression>,
1021 block: &mut Block,
1022 ) -> Result<ExpressionPointer, Error> {
1023 let result_lookup_ty = match self.fun_info[expr_handle].ty {
1024 TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
1025 TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()),
1026 };
1027 let result_type_id = self.get_type_id(result_lookup_ty);
1028
1029 // The id of the boolean `and` of all dynamic bounds checks up to this point. If
1030 // `None`, then we haven't done any dynamic bounds checks yet.
1031 //
1032 // When we have a chain of bounds checks, we combine them with `OpLogicalAnd`, not
1033 // a short-circuit branch. This means we might do comparisons we don't need to,
1034 // but we expect these checks to almost always succeed, and keeping branches to a
1035 // minimum is essential.
1036 let mut accumulated_checks = None;
1037
1038 self.temp_list.clear();
1039 let root_id = loop {
1040 expr_handle = match self.ir_function.expressions[expr_handle] {
1041 crate::Expression::Access { base, index } => {
1042 let index_id = match self.write_bounds_check(base, index, block)? {
1043 BoundsCheckResult::KnownInBounds(known_index) => {
1044 // Even if the index is known, `OpAccessIndex`
1045 // requires expression operands, not literals.
1046 let scalar = crate::ScalarValue::Uint(known_index as u64);
1047 self.writer.get_constant_scalar(scalar, 4)
1048 }
1049 BoundsCheckResult::Computed(computed_index_id) => computed_index_id,
1050 BoundsCheckResult::Conditional(comparison_id) => {
1051 match accumulated_checks {
1052 Some(prior_checks) => {
1053 let combined = self.gen_id();
1054 block.body.push(Instruction::binary(
1055 spirv::Op::LogicalAnd,
1056 self.writer.get_bool_type_id(),
1057 combined,
1058 prior_checks,
1059 comparison_id,
1060 ));
1061 accumulated_checks = Some(combined);
1062 }
1063 None => {
1064 // Start a fresh chain of checks.
1065 accumulated_checks = Some(comparison_id);
1066 }
1067 }
1068
1069 // Either way, the index to use is unchanged.
1070 self.cached[index]
1071 }
1072 };
1073 self.temp_list.push(index_id);
1074
1075 base
1076 }
1077 crate::Expression::AccessIndex { base, index } => {
1078 let const_id = self.get_index_constant(index);
1079 self.temp_list.push(const_id);
1080 base
1081 }
1082 crate::Expression::GlobalVariable(handle) => {
1083 let gv = &self.writer.global_variables[handle.index()];
1084 break gv.access_id;
1085 }
1086 crate::Expression::LocalVariable(variable) => {
1087 let local_var = &self.function.variables[&variable];
1088 break local_var.id;
1089 }
1090 crate::Expression::FunctionArgument(index) => {
1091 break self.function.parameter_id(index);
1092 }
1093 ref other => unimplemented!("Unexpected pointer expression {:?}", other),
1094 }
1095 };
1096
1097 let pointer = if self.temp_list.is_empty() {
1098 ExpressionPointer::Ready {
1099 pointer_id: root_id,
1100 }
1101 } else {
1102 self.temp_list.reverse();
1103 let pointer_id = self.gen_id();
1104 let access =
1105 Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list);
1106
1107 // If we generated some bounds checks, we need to leave it to our
1108 // caller to generate the branch, the access, the load or store, and
1109 // the zero value (for loads). Otherwise, we can emit the access
1110 // ourselves, and just hand them the id of the pointer.
1111 match accumulated_checks {
1112 Some(condition) => ExpressionPointer::Conditional { condition, access },
1113 None => {
1114 block.body.push(access);
1115 ExpressionPointer::Ready { pointer_id }
1116 }
1117 }
1118 };
1119
1120 Ok(pointer)
1121 }
1122
write_block( &mut self, label_id: Word, statements: &[crate::Statement], exit_id: Option<Word>, loop_context: LoopContext, ) -> Result<(), Error>1123 pub(super) fn write_block(
1124 &mut self,
1125 label_id: Word,
1126 statements: &[crate::Statement],
1127 exit_id: Option<Word>,
1128 loop_context: LoopContext,
1129 ) -> Result<(), Error> {
1130 let mut block = Block::new(label_id);
1131
1132 for statement in statements {
1133 match *statement {
1134 crate::Statement::Emit(ref range) => {
1135 for handle in range.clone() {
1136 self.cache_expression_value(handle, &mut block)?;
1137 }
1138 }
1139 crate::Statement::Block(ref block_statements) => {
1140 let scope_id = self.gen_id();
1141 self.function.consume(block, Instruction::branch(scope_id));
1142
1143 let merge_id = self.gen_id();
1144 self.write_block(scope_id, block_statements, Some(merge_id), loop_context)?;
1145
1146 block = Block::new(merge_id);
1147 }
1148 crate::Statement::If {
1149 condition,
1150 ref accept,
1151 ref reject,
1152 } => {
1153 let condition_id = self.cached[condition];
1154
1155 let merge_id = self.gen_id();
1156 block.body.push(Instruction::selection_merge(
1157 merge_id,
1158 spirv::SelectionControl::NONE,
1159 ));
1160
1161 let accept_id = if accept.is_empty() {
1162 None
1163 } else {
1164 Some(self.gen_id())
1165 };
1166 let reject_id = if reject.is_empty() {
1167 None
1168 } else {
1169 Some(self.gen_id())
1170 };
1171
1172 self.function.consume(
1173 block,
1174 Instruction::branch_conditional(
1175 condition_id,
1176 accept_id.unwrap_or(merge_id),
1177 reject_id.unwrap_or(merge_id),
1178 ),
1179 );
1180
1181 if let Some(block_id) = accept_id {
1182 self.write_block(block_id, accept, Some(merge_id), loop_context)?;
1183 }
1184 if let Some(block_id) = reject_id {
1185 self.write_block(block_id, reject, Some(merge_id), loop_context)?;
1186 }
1187
1188 block = Block::new(merge_id);
1189 }
1190 crate::Statement::Switch {
1191 selector,
1192 ref cases,
1193 } => {
1194 let selector_id = self.cached[selector];
1195
1196 let merge_id = self.gen_id();
1197 block.body.push(Instruction::selection_merge(
1198 merge_id,
1199 spirv::SelectionControl::NONE,
1200 ));
1201
1202 let default_id = self.gen_id();
1203
1204 let mut reached_default = false;
1205 let mut raw_cases = Vec::with_capacity(cases.len());
1206 let mut case_ids = Vec::with_capacity(cases.len());
1207 for case in cases.iter() {
1208 match case.value {
1209 crate::SwitchValue::Integer(value) => {
1210 let label_id = self.gen_id();
1211 // No cases should be added after the default case is encountered
1212 // since the default case catches all
1213 if !reached_default {
1214 raw_cases.push(super::instructions::Case {
1215 value: value as Word,
1216 label_id,
1217 });
1218 }
1219 case_ids.push(label_id);
1220 }
1221 crate::SwitchValue::Default => {
1222 case_ids.push(default_id);
1223 reached_default = true;
1224 }
1225 }
1226 }
1227
1228 self.function.consume(
1229 block,
1230 Instruction::switch(selector_id, default_id, &raw_cases),
1231 );
1232
1233 let inner_context = LoopContext {
1234 break_id: Some(merge_id),
1235 ..loop_context
1236 };
1237
1238 for (i, (case, label_id)) in cases.iter().zip(case_ids.iter()).enumerate() {
1239 let case_finish_id = if case.fall_through {
1240 case_ids[i + 1]
1241 } else {
1242 merge_id
1243 };
1244 self.write_block(
1245 *label_id,
1246 &case.body,
1247 Some(case_finish_id),
1248 inner_context,
1249 )?;
1250 }
1251
1252 // If no default was encountered write a empty block to satisfy the presence of
1253 // a block the default label
1254 if !reached_default {
1255 self.write_block(default_id, &[], Some(merge_id), inner_context)?;
1256 }
1257
1258 block = Block::new(merge_id);
1259 }
1260 crate::Statement::Loop {
1261 ref body,
1262 ref continuing,
1263 } => {
1264 let preamble_id = self.gen_id();
1265 self.function
1266 .consume(block, Instruction::branch(preamble_id));
1267
1268 let merge_id = self.gen_id();
1269 let body_id = self.gen_id();
1270 let continuing_id = self.gen_id();
1271
1272 // SPIR-V requires the continuing to the `OpLoopMerge`,
1273 // so we have to start a new block with it.
1274 block = Block::new(preamble_id);
1275 block.body.push(Instruction::loop_merge(
1276 merge_id,
1277 continuing_id,
1278 spirv::SelectionControl::NONE,
1279 ));
1280 self.function.consume(block, Instruction::branch(body_id));
1281
1282 self.write_block(
1283 body_id,
1284 body,
1285 Some(continuing_id),
1286 LoopContext {
1287 continuing_id: Some(continuing_id),
1288 break_id: Some(merge_id),
1289 },
1290 )?;
1291
1292 self.write_block(
1293 continuing_id,
1294 continuing,
1295 Some(preamble_id),
1296 LoopContext {
1297 continuing_id: None,
1298 break_id: Some(merge_id),
1299 },
1300 )?;
1301
1302 block = Block::new(merge_id);
1303 }
1304 crate::Statement::Break => {
1305 self.function
1306 .consume(block, Instruction::branch(loop_context.break_id.unwrap()));
1307 return Ok(());
1308 }
1309 crate::Statement::Continue => {
1310 self.function.consume(
1311 block,
1312 Instruction::branch(loop_context.continuing_id.unwrap()),
1313 );
1314 return Ok(());
1315 }
1316 crate::Statement::Return { value: Some(value) } => {
1317 let value_id = self.cached[value];
1318 let instruction = match self.function.entry_point_context {
1319 // If this is an entry point, and we need to return anything,
1320 // let's instead store the output variables and return `void`.
1321 Some(ref context) => {
1322 self.writer.write_entry_point_return(
1323 value_id,
1324 self.ir_function.result.as_ref().unwrap(),
1325 &context.results,
1326 &mut block.body,
1327 )?;
1328 Instruction::return_void()
1329 }
1330 None => Instruction::return_value(value_id),
1331 };
1332 self.function.consume(block, instruction);
1333 return Ok(());
1334 }
1335 crate::Statement::Return { value: None } => {
1336 self.function.consume(block, Instruction::return_void());
1337 return Ok(());
1338 }
1339 crate::Statement::Kill => {
1340 self.function.consume(block, Instruction::kill());
1341 return Ok(());
1342 }
1343 crate::Statement::Barrier(flags) => {
1344 let memory_scope = if flags.contains(crate::Barrier::STORAGE) {
1345 spirv::Scope::Device
1346 } else {
1347 spirv::Scope::Workgroup
1348 };
1349 let mut semantics = spirv::MemorySemantics::ACQUIRE_RELEASE;
1350 semantics.set(
1351 spirv::MemorySemantics::UNIFORM_MEMORY,
1352 flags.contains(crate::Barrier::STORAGE),
1353 );
1354 semantics.set(
1355 spirv::MemorySemantics::WORKGROUP_MEMORY,
1356 flags.contains(crate::Barrier::WORK_GROUP),
1357 );
1358 let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32);
1359 let mem_scope_id = self.get_index_constant(memory_scope as u32);
1360 let semantics_id = self.get_index_constant(semantics.bits());
1361 block.body.push(Instruction::control_barrier(
1362 exec_scope_id,
1363 mem_scope_id,
1364 semantics_id,
1365 ));
1366 }
1367 crate::Statement::Store { pointer, value } => {
1368 let value_id = self.cached[value];
1369 match self.write_expression_pointer(pointer, &mut block)? {
1370 ExpressionPointer::Ready { pointer_id } => {
1371 let atomic_class = match *self.fun_info[pointer]
1372 .ty
1373 .inner_with(&self.ir_module.types)
1374 {
1375 crate::TypeInner::Pointer { base, class } => {
1376 match self.ir_module.types[base].inner {
1377 crate::TypeInner::Atomic { .. } => Some(class),
1378 _ => None,
1379 }
1380 }
1381 _ => None,
1382 };
1383 let instruction = if let Some(class) = atomic_class {
1384 let (semantics, scope) = class.to_spirv_semantics_and_scope();
1385 let scope_constant_id = self.get_scope_constant(scope as u32);
1386 let semantics_id = self.get_index_constant(semantics.bits());
1387 Instruction::atomic_store(
1388 pointer_id,
1389 scope_constant_id,
1390 semantics_id,
1391 value_id,
1392 )
1393 } else {
1394 Instruction::store(pointer_id, value_id, None)
1395 };
1396 block.body.push(instruction);
1397 }
1398 ExpressionPointer::Conditional { condition, access } => {
1399 let mut selection = Selection::start(&mut block, ());
1400 selection.if_true(self, condition, ());
1401
1402 // The in-bounds path. Perform the access and the store.
1403 let pointer_id = access.result_id.unwrap();
1404 selection.block().body.push(access);
1405 selection
1406 .block()
1407 .body
1408 .push(Instruction::store(pointer_id, value_id, None));
1409
1410 // Finish the in-bounds block and start the merge block. This
1411 // is the block we'll leave current on return.
1412 selection.finish(self, ());
1413 }
1414 };
1415 }
1416 crate::Statement::ImageStore {
1417 image,
1418 coordinate,
1419 array_index,
1420 value,
1421 } => self.write_image_store(image, coordinate, array_index, value, &mut block)?,
1422 crate::Statement::Call {
1423 function: local_function,
1424 ref arguments,
1425 result,
1426 } => {
1427 let id = self.gen_id();
1428 self.temp_list.clear();
1429 for &argument in arguments {
1430 self.temp_list.push(self.cached[argument]);
1431 }
1432
1433 let type_id = match result {
1434 Some(expr) => {
1435 self.cached[expr] = id;
1436 self.get_expression_type_id(&self.fun_info[expr].ty)
1437 }
1438 None => self.writer.void_type,
1439 };
1440
1441 block.body.push(Instruction::function_call(
1442 type_id,
1443 id,
1444 self.writer.lookup_function[&local_function],
1445 &self.temp_list,
1446 ));
1447 }
1448 crate::Statement::Atomic {
1449 pointer,
1450 ref fun,
1451 value,
1452 result,
1453 } => {
1454 let id = self.gen_id();
1455 let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
1456
1457 self.cached[result] = id;
1458
1459 let pointer_id = match self.write_expression_pointer(pointer, &mut block)? {
1460 ExpressionPointer::Ready { pointer_id } => pointer_id,
1461 ExpressionPointer::Conditional { .. } => {
1462 return Err(Error::FeatureNotImplemented(
1463 "Atomics out-of-bounds handling",
1464 ));
1465 }
1466 };
1467
1468 let class = match *self.fun_info[pointer].ty.inner_with(&self.ir_module.types) {
1469 crate::TypeInner::Pointer { base: _, class } => class,
1470 _ => unimplemented!(),
1471 };
1472 let (semantics, scope) = class.to_spirv_semantics_and_scope();
1473 let scope_constant_id = self.get_scope_constant(scope as u32);
1474 let semantics_id = self.get_index_constant(semantics.bits());
1475 let value_id = self.cached[value];
1476 let value_inner = self.fun_info[value].ty.inner_with(&self.ir_module.types);
1477
1478 let instruction = match *fun {
1479 crate::AtomicFunction::Add => Instruction::atomic_binary(
1480 spirv::Op::AtomicIAdd,
1481 result_type_id,
1482 id,
1483 pointer_id,
1484 scope_constant_id,
1485 semantics_id,
1486 value_id,
1487 ),
1488 crate::AtomicFunction::Subtract => Instruction::atomic_binary(
1489 spirv::Op::AtomicISub,
1490 result_type_id,
1491 id,
1492 pointer_id,
1493 scope_constant_id,
1494 semantics_id,
1495 value_id,
1496 ),
1497 crate::AtomicFunction::And => Instruction::atomic_binary(
1498 spirv::Op::AtomicAnd,
1499 result_type_id,
1500 id,
1501 pointer_id,
1502 scope_constant_id,
1503 semantics_id,
1504 value_id,
1505 ),
1506 crate::AtomicFunction::InclusiveOr => Instruction::atomic_binary(
1507 spirv::Op::AtomicOr,
1508 result_type_id,
1509 id,
1510 pointer_id,
1511 scope_constant_id,
1512 semantics_id,
1513 value_id,
1514 ),
1515 crate::AtomicFunction::ExclusiveOr => Instruction::atomic_binary(
1516 spirv::Op::AtomicXor,
1517 result_type_id,
1518 id,
1519 pointer_id,
1520 scope_constant_id,
1521 semantics_id,
1522 value_id,
1523 ),
1524 crate::AtomicFunction::Min => {
1525 let spirv_op = match *value_inner {
1526 crate::TypeInner::Scalar {
1527 kind: crate::ScalarKind::Sint,
1528 width: _,
1529 } => spirv::Op::AtomicSMin,
1530 crate::TypeInner::Scalar {
1531 kind: crate::ScalarKind::Uint,
1532 width: _,
1533 } => spirv::Op::AtomicUMin,
1534 _ => unimplemented!(),
1535 };
1536 Instruction::atomic_binary(
1537 spirv_op,
1538 result_type_id,
1539 id,
1540 pointer_id,
1541 scope_constant_id,
1542 semantics_id,
1543 value_id,
1544 )
1545 }
1546 crate::AtomicFunction::Max => {
1547 let spirv_op = match *value_inner {
1548 crate::TypeInner::Scalar {
1549 kind: crate::ScalarKind::Sint,
1550 width: _,
1551 } => spirv::Op::AtomicSMax,
1552 crate::TypeInner::Scalar {
1553 kind: crate::ScalarKind::Uint,
1554 width: _,
1555 } => spirv::Op::AtomicUMax,
1556 _ => unimplemented!(),
1557 };
1558 Instruction::atomic_binary(
1559 spirv_op,
1560 result_type_id,
1561 id,
1562 pointer_id,
1563 scope_constant_id,
1564 semantics_id,
1565 value_id,
1566 )
1567 }
1568 crate::AtomicFunction::Exchange { compare: None } => {
1569 Instruction::atomic_binary(
1570 spirv::Op::AtomicExchange,
1571 result_type_id,
1572 id,
1573 pointer_id,
1574 scope_constant_id,
1575 semantics_id,
1576 value_id,
1577 )
1578 }
1579 crate::AtomicFunction::Exchange { compare: Some(_) } => {
1580 return Err(Error::FeatureNotImplemented("atomic CompareExchange"));
1581 }
1582 };
1583
1584 block.body.push(instruction);
1585 }
1586 }
1587 }
1588
1589 let termination = match exit_id {
1590 Some(id) => Instruction::branch(id),
1591 // This can happen if the last branch had all the paths
1592 // leading out of the graph (i.e. returning).
1593 // Or it may be the end of the self.function.
1594 None => match self.ir_function.result {
1595 Some(ref result) if self.function.entry_point_context.is_none() => {
1596 let type_id = self.get_type_id(LookupType::Handle(result.ty));
1597 let null_id = self.writer.write_constant_null(type_id);
1598 Instruction::return_value(null_id)
1599 }
1600 _ => Instruction::return_void(),
1601 },
1602 };
1603
1604 self.function.consume(block, termination);
1605 Ok(())
1606 }
1607 }
1608