//! Processing of push constant ranges.
//!
//! This module provides utility functions to make push constants compatible with
//! root signatures. Root constants cannot overlap, so the push constant ranges
//! passed at pipeline layout creation need to be [`split`] into disjoint ranges,
//! which can then be converted into root signature entries. Push constant ranges
//! are given in bytes; the resulting root constant ranges are in 32-bit values.
7 
8 use hal::pso;
9 use std::{cmp::Ordering, ops::Range};
10 
11 #[derive(Clone, Debug, PartialEq, Eq, Hash)]
12 pub struct RootConstant {
13     pub stages: pso::ShaderStageFlags,
14     pub range: Range<u32>,
15 }
16 
17 impl RootConstant {
is_empty(&self) -> bool18     fn is_empty(&self) -> bool {
19         self.range.end <= self.range.start
20     }
21 
22     // Divide a root constant into two separate ranges depending on the overlap
23     // with another root constant.
divide(self, other: &RootConstant) -> (RootConstant, RootConstant)24     fn divide(self, other: &RootConstant) -> (RootConstant, RootConstant) {
25         assert!(self.range.start <= other.range.start);
26         let left = RootConstant {
27             stages: self.stages,
28             range: self.range.start..other.range.start,
29         };
30 
31         let right = RootConstant {
32             stages: self.stages,
33             range: other.range.start..self.range.end,
34         };
35 
36         (left, right)
37     }
38 }
39 
40 impl PartialOrd for RootConstant {
partial_cmp(&self, other: &RootConstant) -> Option<Ordering>41     fn partial_cmp(&self, other: &RootConstant) -> Option<Ordering> {
42         Some(
43             self.range
44                 .start
45                 .cmp(&other.range.start)
46                 .then(self.range.end.cmp(&other.range.end))
47                 .then(self.stages.cmp(&other.stages)),
48         )
49     }
50 }
51 
52 impl Ord for RootConstant {
cmp(&self, other: &RootConstant) -> Ordering53     fn cmp(&self, other: &RootConstant) -> Ordering {
54         self.partial_cmp(other).unwrap()
55     }
56 }
57 
split<I>(ranges: I) -> Vec<RootConstant> where I: IntoIterator<Item = (pso::ShaderStageFlags, Range<u32>)>,58 pub fn split<I>(ranges: I) -> Vec<RootConstant>
59 where
60     I: IntoIterator<Item = (pso::ShaderStageFlags, Range<u32>)>,
61 {
62     // Frontier of unexplored root constant ranges, sorted descending
63     // (less element shifting for Vec) regarding to the start of ranges.
64     let mut ranges = into_vec(ranges);
65     ranges.sort_by(|a, b| b.cmp(a));
66 
67     // Storing resulting disjunct root constant ranges.
68     let mut disjunct = Vec::with_capacity(ranges.len());
69 
70     while let Some(cur) = ranges.pop() {
71         // Run trough all unexplored ranges. After each run the frontier will be
72         // resorted!
73         //
74         // Case 1: Single element remaining
75         //      Push is to the disjunct list, done.
76         // Case 2: At least two ranges, which possibly overlap
77         //      Divide the first range into a left set and right set, depending
78         //      on the overlap of the two ranges:
79         //      Range 1: |---- left ---||--- right ---|
80         //      Range 2:                |--------...
81         if let Some(mut next) = ranges.pop() {
82             let (left, mut right) = cur.divide(&next);
83             if !left.is_empty() {
84                 // The left part is, by definition, disjunct to all other ranges.
85                 // Push all remaining pieces in the frontier, handled by the next
86                 // iteration.
87                 disjunct.push(left);
88                 ranges.push(next);
89                 if !right.is_empty() {
90                     ranges.push(right);
91                 }
92             } else if !right.is_empty() {
93                 // If the left part is empty this means that both ranges have the
94                 // same start value. The right segment is a candidate for a disjunct
95                 // segment but we haven't checked against other ranges so far.
96                 // Therefore, we push is on the frontier again, but added the
97                 // stage flags from the overlapping segment.
98                 // The second range will be shrunken to be disjunct with the pushed
99                 // segment as we have already processed it.
100                 // In the next iteration we will look again at the push right
101                 // segment and compare it to other elements on the list until we
102                 // have a small enough disjunct segment, which doesn't overlap
103                 // with any part of the frontier.
104                 right.stages |= next.stages;
105                 next.range.start = right.range.end;
106                 ranges.push(right);
107                 if !next.is_empty() {
108                     ranges.push(next);
109                 }
110             }
111         } else {
112             disjunct.push(cur);
113         }
114         ranges.sort_by(|a, b| b.cmp(a));
115     }
116 
117     disjunct
118 }
119 
into_vec<I>(ranges: I) -> Vec<RootConstant> where I: IntoIterator<Item = (pso::ShaderStageFlags, Range<u32>)>,120 fn into_vec<I>(ranges: I) -> Vec<RootConstant>
121 where
122     I: IntoIterator<Item = (pso::ShaderStageFlags, Range<u32>)>,
123 {
124     ranges
125         .into_iter()
126         .map(|(stages, ref range)| {
127             debug_assert_eq!(range.start % 4, 0);
128             debug_assert_eq!(range.end % 4, 0);
129             RootConstant {
130                 stages,
131                 range: range.start / 4..range.end / 4,
132             }
133         })
134         .collect()
135 }
136 
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an expected root constant (range in 32-bit values).
    fn rc(stages: pso::ShaderStageFlags, range: Range<u32>) -> RootConstant {
        RootConstant { stages, range }
    }

    #[test]
    fn test_single() {
        // A lone range only gets its byte offsets converted.
        //TODO: use movable fixed-size ranges when available in Rust
        let input = Some((pso::ShaderStageFlags::VERTEX, 0..12));
        assert_eq!(into_vec(input.clone()), split(input));
    }

    #[test]
    fn test_overlap_1() {
        // Case:
        //      |----------|
        //          |------------|
        let (vs, fs) = (pso::ShaderStageFlags::VERTEX, pso::ShaderStageFlags::FRAGMENT);
        let input = vec![(vs, 0..12), (fs, 8..16)];
        let expected = vec![rc(vs, 0..2), rc(vs | fs, 2..3), rc(fs, 3..4)];
        assert_eq!(expected, split(input));
    }

    #[test]
    fn test_overlap_2() {
        // Case:
        //      |-------------------|
        //          |------------|
        let (vs, fs) = (pso::ShaderStageFlags::VERTEX, pso::ShaderStageFlags::FRAGMENT);
        let input = vec![(vs, 0..20), (fs, 8..16)];
        let expected = vec![rc(vs, 0..2), rc(vs | fs, 2..4), rc(vs, 4..5)];
        assert_eq!(expected, split(input));
    }

    #[test]
    fn test_overlap_4() {
        // Case:
        //      |--------------|
        //      |------------|
        let (vs, fs) = (pso::ShaderStageFlags::VERTEX, pso::ShaderStageFlags::FRAGMENT);
        let input = vec![(vs, 0..20), (fs, 0..16)];
        let expected = vec![rc(vs | fs, 0..4), rc(vs, 4..5)];
        assert_eq!(expected, split(input));
    }

    #[test]
    fn test_equal() {
        // Case:
        //      |-----|
        //      |-----|
        let (vs, fs) = (pso::ShaderStageFlags::VERTEX, pso::ShaderStageFlags::FRAGMENT);
        let input = vec![(vs, 0..16), (fs, 0..16)];
        let expected = vec![rc(vs | fs, 0..4)];
        assert_eq!(expected, split(input));
    }

    #[test]
    fn test_disjunct() {
        // Case:
        //      |------|
        //               |------------|
        let (vs, fs) = (pso::ShaderStageFlags::VERTEX, pso::ShaderStageFlags::FRAGMENT);
        let input = vec![(vs, 0..12), (fs, 12..16)];
        assert_eq!(into_vec(input.clone()), split(input));
    }

    #[test]
    fn test_complex() {
        let vs = pso::ShaderStageFlags::VERTEX;
        let fs = pso::ShaderStageFlags::FRAGMENT;
        let gs = pso::ShaderStageFlags::GEOMETRY;
        let hs = pso::ShaderStageFlags::HULL;
        let input = vec![(vs, 8..40), (fs, 0..20), (gs, 24..40), (hs, 16..28)];

        let expected = vec![
            rc(fs, 0..2),
            rc(vs | fs, 2..4),
            rc(vs | fs | hs, 4..5),
            rc(vs | hs, 5..6),
            rc(vs | gs | hs, 6..7),
            rc(vs | gs, 7..10),
        ];

        assert_eq!(expected, split(input));
    }
}
297