1 //! Bounds-checking for SPIR-V output.
2 
3 use super::{selection::Selection, Block, BlockContext, Error, IdGenerator, Instruction, Word};
4 use crate::{arena::Handle, proc::BoundsCheckPolicy};
5 
6 /// The results of performing a bounds check.
7 ///
8 /// On success, `write_bounds_check` returns a value of this type.
9 pub(super) enum BoundsCheckResult {
10     /// The index is statically known and in bounds, with the given value.
11     KnownInBounds(u32),
12 
13     /// The given instruction computes the index to be used.
14     Computed(Word),
15 
16     /// The given instruction computes a boolean condition which is true
17     /// if the index is in bounds.
18     Conditional(Word),
19 }
20 
21 /// A value that we either know at translation time, or need to compute at runtime.
22 pub(super) enum MaybeKnown<T> {
23     /// The value is known at shader translation time.
24     Known(T),
25 
26     /// The value is computed by the instruction with the given id.
27     Computed(Word),
28 }
29 
30 impl<'w> BlockContext<'w> {
31     /// Emit code to compute the length of a run-time array.
32     ///
33     /// Given `array`, an expression referring to the final member of a struct,
34     /// where the member in question is a runtime-sized array, return the
35     /// instruction id for the array's length.
write_runtime_array_length( &mut self, array: Handle<crate::Expression>, block: &mut Block, ) -> Result<Word, Error>36     pub(super) fn write_runtime_array_length(
37         &mut self,
38         array: Handle<crate::Expression>,
39         block: &mut Block,
40     ) -> Result<Word, Error> {
41         // Look into the expression to find the value and type of the struct
42         // holding the dynamically-sized array.
43         let (structure_id, last_member_index) = match self.ir_function.expressions[array] {
44             crate::Expression::AccessIndex { base, index } => {
45                 match self.ir_function.expressions[base] {
46                     crate::Expression::GlobalVariable(handle) => (
47                         self.writer.global_variables[handle.index()].access_id,
48                         index,
49                     ),
50                     _ => return Err(Error::Validation("array length expression")),
51                 }
52             }
53             _ => return Err(Error::Validation("array length expression")),
54         };
55 
56         let length_id = self.gen_id();
57         block.body.push(Instruction::array_length(
58             self.writer.get_uint_type_id(),
59             length_id,
60             structure_id,
61             last_member_index,
62         ));
63 
64         Ok(length_id)
65     }
66 
67     /// Compute the length of a subscriptable value.
68     ///
69     /// Given `sequence`, an expression referring to some indexable type, return
70     /// its length. The result may either be computed by SPIR-V instructions, or
71     /// known at shader translation time.
72     ///
73     /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
74     /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
75     /// sized, or use a specializable constant as its length.
write_sequence_length( &mut self, sequence: Handle<crate::Expression>, block: &mut Block, ) -> Result<MaybeKnown<u32>, Error>76     fn write_sequence_length(
77         &mut self,
78         sequence: Handle<crate::Expression>,
79         block: &mut Block,
80     ) -> Result<MaybeKnown<u32>, Error> {
81         let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types);
82         match sequence_ty.indexable_length(self.ir_module)? {
83             crate::proc::IndexableLength::Known(known_length) => {
84                 Ok(MaybeKnown::Known(known_length))
85             }
86             crate::proc::IndexableLength::Dynamic => {
87                 let length_id = self.write_runtime_array_length(sequence, block)?;
88                 Ok(MaybeKnown::Computed(length_id))
89             }
90         }
91     }
92 
93     /// Compute the maximum valid index of a subscriptable value.
94     ///
95     /// Given `sequence`, an expression referring to some indexable type, return
96     /// its maximum valid index - one less than its length. The result may
97     /// either be computed, or known at shader translation time.
98     ///
99     /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any
100     /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically
101     /// sized, or use a specializable constant as its length.
write_sequence_max_index( &mut self, sequence: Handle<crate::Expression>, block: &mut Block, ) -> Result<MaybeKnown<u32>, Error>102     fn write_sequence_max_index(
103         &mut self,
104         sequence: Handle<crate::Expression>,
105         block: &mut Block,
106     ) -> Result<MaybeKnown<u32>, Error> {
107         match self.write_sequence_length(sequence, block)? {
108             MaybeKnown::Known(known_length) => {
109                 // We should have thrown out all attempts to subscript zero-length
110                 // sequences during validation, so the following subtraction should never
111                 // underflow.
112                 assert!(known_length > 0);
113                 // Compute the max index from the length now.
114                 Ok(MaybeKnown::Known(known_length - 1))
115             }
116             MaybeKnown::Computed(length_id) => {
117                 // Emit code to compute the max index from the length.
118                 let const_one_id = self.get_index_constant(1);
119                 let max_index_id = self.gen_id();
120                 block.body.push(Instruction::binary(
121                     spirv::Op::ISub,
122                     self.writer.get_uint_type_id(),
123                     max_index_id,
124                     length_id,
125                     const_one_id,
126                 ));
127                 Ok(MaybeKnown::Computed(max_index_id))
128             }
129         }
130     }
131 
132     /// Restrict an index to be in range for a vector, matrix, or array.
133     ///
134     /// This is used to implement `BoundsCheckPolicy::Restrict`. An in-bounds
135     /// index is left unchanged. An out-of-bounds index is replaced with some
136     /// arbitrary in-bounds index. Note,this is not necessarily clamping; for
137     /// example, negative indices might be changed to refer to the last element
138     /// of the sequence, not the first, as clamping would do.
139     ///
140     /// Either return the restricted index value, if known, or add instructions
141     /// to `block` to compute it, and return the id of the result. See the
142     /// documentation for `BoundsCheckResult` for details.
143     ///
144     /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
145     /// `Pointer` to any of those, or a `ValuePointer`. An array may be
146     /// fixed-size, dynamically sized, or use a specializable constant as its
147     /// length.
write_restricted_index( &mut self, sequence: Handle<crate::Expression>, index: Handle<crate::Expression>, block: &mut Block, ) -> Result<BoundsCheckResult, Error>148     pub(super) fn write_restricted_index(
149         &mut self,
150         sequence: Handle<crate::Expression>,
151         index: Handle<crate::Expression>,
152         block: &mut Block,
153     ) -> Result<BoundsCheckResult, Error> {
154         let index_id = self.cached[index];
155 
156         // Get the sequence's maximum valid index. Return early if we've already
157         // done the bounds check.
158         let max_index_id = match self.write_sequence_max_index(sequence, block)? {
159             MaybeKnown::Known(known_max_index) => {
160                 if let crate::Expression::Constant(index_k) = self.ir_function.expressions[index] {
161                     if let Some(known_index) = self.ir_module.constants[index_k].to_array_length() {
162                         // Both the index and length are known at compile time.
163                         //
164                         // In strict WGSL compliance mode, out-of-bounds indices cannot be
165                         // reported at shader translation time, and must be replaced with
166                         // in-bounds indices at run time. So we cannot assume that
167                         // validation ensured the index was in bounds. Restrict now.
168                         let restricted = std::cmp::min(known_index, known_max_index);
169                         return Ok(BoundsCheckResult::KnownInBounds(restricted));
170                     }
171                 }
172 
173                 self.get_index_constant(known_max_index)
174             }
175             MaybeKnown::Computed(max_index_id) => max_index_id,
176         };
177 
178         // One or the other of the index or length is dynamic, so emit code for
179         // BoundsCheckPolicy::Restrict.
180         let restricted_index_id = self.gen_id();
181         block.body.push(Instruction::ext_inst(
182             self.writer.gl450_ext_inst_id,
183             spirv::GLOp::UMin,
184             self.writer.get_uint_type_id(),
185             restricted_index_id,
186             &[index_id, max_index_id],
187         ));
188         Ok(BoundsCheckResult::Computed(restricted_index_id))
189     }
190 
191     /// Write an index bounds comparison to `block`, if needed.
192     ///
193     /// If we're able to determine statically that `index` is in bounds for
194     /// `sequence`, return `KnownInBounds(value)`, where `value` is the actual
195     /// value of the index. (In principle, one could know that the index is in
196     /// bounds without knowing its specific value, but in our simple-minded
197     /// situation, we always know it.)
198     ///
199     /// If instead we must generate code to perform the comparison at run time,
200     /// return `Conditional(comparison_id)`, where `comparison_id` is an
201     /// instruction producing a boolean value that is true if `index` is in
202     /// bounds for `sequence`.
203     ///
204     /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a
205     /// `Pointer` to any of those, or a `ValuePointer`. An array may be
206     /// fixed-size, dynamically sized, or use a specializable constant as its
207     /// length.
write_index_comparison( &mut self, sequence: Handle<crate::Expression>, index: Handle<crate::Expression>, block: &mut Block, ) -> Result<BoundsCheckResult, Error>208     fn write_index_comparison(
209         &mut self,
210         sequence: Handle<crate::Expression>,
211         index: Handle<crate::Expression>,
212         block: &mut Block,
213     ) -> Result<BoundsCheckResult, Error> {
214         let index_id = self.cached[index];
215 
216         // Get the sequence's length. Return early if we've already done the
217         // bounds check.
218         let length_id = match self.write_sequence_length(sequence, block)? {
219             MaybeKnown::Known(known_length) => {
220                 if let crate::Expression::Constant(index_k) = self.ir_function.expressions[index] {
221                     if let Some(known_index) = self.ir_module.constants[index_k].to_array_length() {
222                         // Both the index and length are known at compile time.
223                         //
224                         // It would be nice to assume that, since we are using the
225                         // `ReadZeroSkipWrite` policy, we are not in strict WGSL
226                         // compliance mode, and thus we can count on the validator to have
227                         // rejected any programs with known out-of-bounds indices, and
228                         // thus just return `KnownInBounds` here without actually
229                         // checking.
230                         //
231                         // But it's also reasonable to expect that bounds check policies
232                         // and error reporting policies should be able to vary
233                         // independently without introducing security holes. So, we should
234                         // support the case where bad indices do not cause validation
235                         // errors, and are handled via `ReadZeroSkipWrite`.
236                         //
237                         // In theory, when `known_index` is bad, we could return a new
238                         // `KnownOutOfBounds` variant here. But it's simpler just to fall
239                         // through and let the bounds check take place. The shader is
240                         // broken anyway, so it doesn't make sense to invest in emitting
241                         // the ideal code for it.
242                         if known_index < known_length {
243                             return Ok(BoundsCheckResult::KnownInBounds(known_index));
244                         }
245                     }
246                 }
247 
248                 self.get_index_constant(known_length)
249             }
250             MaybeKnown::Computed(length_id) => length_id,
251         };
252 
253         // Compare the index against the length.
254         let condition_id = self.gen_id();
255         block.body.push(Instruction::binary(
256             spirv::Op::ULessThan,
257             self.writer.get_bool_type_id(),
258             condition_id,
259             index_id,
260             length_id,
261         ));
262 
263         // Indicate that we did generate the check.
264         Ok(BoundsCheckResult::Conditional(condition_id))
265     }
266 
267     /// Emit a conditional load for `BoundsCheckPolicy::ReadZeroSkipWrite`.
268     ///
269     /// Generate code to load a value of `result_type` if `condition` is true,
270     /// and generate a null value of that type if it is false. Call `emit_load`
271     /// to emit the instructions to perform the load. Return the id of the
272     /// merged value of the two branches.
write_conditional_indexed_load<F>( &mut self, result_type: Word, condition: Word, block: &mut Block, emit_load: F, ) -> Word where F: FnOnce(&mut IdGenerator, &mut Block) -> Word,273     pub(super) fn write_conditional_indexed_load<F>(
274         &mut self,
275         result_type: Word,
276         condition: Word,
277         block: &mut Block,
278         emit_load: F,
279     ) -> Word
280     where
281         F: FnOnce(&mut IdGenerator, &mut Block) -> Word,
282     {
283         // For the out-of-bounds case, we produce a zero value.
284         let null_id = self.writer.write_constant_null(result_type);
285 
286         let mut selection = Selection::start(block, result_type);
287 
288         // As it turns out, we don't actually need a full 'if-then-else'
289         // structure for this: SPIR-V constants are declared up front, so the
290         // 'else' block would have no instructions. Instead we emit something
291         // like this:
292         //
293         //     result = zero;
294         //     if in_bounds {
295         //         result = do the load;
296         //     }
297         //     use result;
298 
299         // Continue only if the index was in bounds. Otherwise, branch to the
300         // merge block.
301         selection.if_true(self, condition, null_id);
302 
303         // The in-bounds path. Perform the access and the load.
304         let loaded_value = emit_load(&mut self.writer.id_gen, selection.block());
305 
306         selection.finish(self, loaded_value)
307     }
308 
309     /// Emit code for bounds checks for an array, vector, or matrix access.
310     ///
311     /// This implements either `index_bounds_check_policy` or
312     /// `buffer_bounds_check_policy`, depending on the storage class of the
313     /// pointer being accessed.
314     ///
315     /// Return a `BoundsCheckResult` indicating how the index should be
316     /// consumed. See that type's documentation for details.
write_bounds_check( &mut self, base: Handle<crate::Expression>, index: Handle<crate::Expression>, block: &mut Block, ) -> Result<BoundsCheckResult, Error>317     pub(super) fn write_bounds_check(
318         &mut self,
319         base: Handle<crate::Expression>,
320         index: Handle<crate::Expression>,
321         block: &mut Block,
322     ) -> Result<BoundsCheckResult, Error> {
323         let policy = self.writer.bounds_check_policies.choose_policy(
324             base,
325             &self.ir_module.types,
326             self.fun_info,
327         );
328 
329         Ok(match policy {
330             BoundsCheckPolicy::Restrict => self.write_restricted_index(base, index, block)?,
331             BoundsCheckPolicy::ReadZeroSkipWrite => {
332                 self.write_index_comparison(base, index, block)?
333             }
334             BoundsCheckPolicy::Unchecked => BoundsCheckResult::Computed(self.cached[index]),
335         })
336     }
337 
338     /// Emit code to subscript a vector by value with a computed index.
339     ///
340     /// Return the id of the element value.
write_vector_access( &mut self, expr_handle: Handle<crate::Expression>, base: Handle<crate::Expression>, index: Handle<crate::Expression>, block: &mut Block, ) -> Result<Word, Error>341     pub(super) fn write_vector_access(
342         &mut self,
343         expr_handle: Handle<crate::Expression>,
344         base: Handle<crate::Expression>,
345         index: Handle<crate::Expression>,
346         block: &mut Block,
347     ) -> Result<Word, Error> {
348         let result_type_id = self.get_expression_type_id(&self.fun_info[expr_handle].ty);
349 
350         let base_id = self.cached[base];
351         let index_id = self.cached[index];
352 
353         let result_id = match self.write_bounds_check(base, index, block)? {
354             BoundsCheckResult::KnownInBounds(known_index) => {
355                 let result_id = self.gen_id();
356                 block.body.push(Instruction::composite_extract(
357                     result_type_id,
358                     result_id,
359                     base_id,
360                     &[known_index],
361                 ));
362                 result_id
363             }
364             BoundsCheckResult::Computed(computed_index_id) => {
365                 let result_id = self.gen_id();
366                 block.body.push(Instruction::vector_extract_dynamic(
367                     result_type_id,
368                     result_id,
369                     base_id,
370                     computed_index_id,
371                 ));
372                 result_id
373             }
374             BoundsCheckResult::Conditional(comparison_id) => {
375                 // Run-time bounds checks were required. Emit
376                 // conditional load.
377                 self.write_conditional_indexed_load(
378                     result_type_id,
379                     comparison_id,
380                     block,
381                     |id_gen, block| {
382                         // The in-bounds path. Generate the access.
383                         let element_id = id_gen.next();
384                         block.body.push(Instruction::vector_extract_dynamic(
385                             result_type_id,
386                             element_id,
387                             base_id,
388                             index_id,
389                         ));
390                         element_id
391                     },
392                 )
393             }
394         };
395 
396         Ok(result_id)
397     }
398 }
399