1 /*
2 * Copyright 2016-2020 The Brenwill Workshop Ltd.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "spirv_msl.hpp"
18 #include "GLSL.std.450.h"
19
20 #include <algorithm>
21 #include <assert.h>
22 #include <numeric>
23
24 using namespace spv;
25 using namespace SPIRV_CROSS_NAMESPACE;
26 using namespace std;
27
// Sentinel values (all 32 bits set) used where no location / component has been assigned.
static const uint32_t k_unknown_location = 0xFFFFFFFFu;
static const uint32_t k_unknown_component = 0xFFFFFFFFu;

// Qualifier text for helper functions that should always be inlined.
static const char *force_inline = "static inline __attribute__((always_inline))";
31
CompilerMSL(std::vector<uint32_t> spirv_)32 CompilerMSL::CompilerMSL(std::vector<uint32_t> spirv_)
33 : CompilerGLSL(move(spirv_))
34 {
35 }
36
CompilerMSL(const uint32_t * ir_,size_t word_count)37 CompilerMSL::CompilerMSL(const uint32_t *ir_, size_t word_count)
38 : CompilerGLSL(ir_, word_count)
39 {
40 }
41
CompilerMSL(const ParsedIR & ir_)42 CompilerMSL::CompilerMSL(const ParsedIR &ir_)
43 : CompilerGLSL(ir_)
44 {
45 }
46
CompilerMSL(ParsedIR && ir_)47 CompilerMSL::CompilerMSL(ParsedIR &&ir_)
48 : CompilerGLSL(std::move(ir_))
49 {
50 }
51
add_msl_shader_input(const MSLShaderInput & si)52 void CompilerMSL::add_msl_shader_input(const MSLShaderInput &si)
53 {
54 inputs_by_location[si.location] = si;
55 if (si.builtin != BuiltInMax && !inputs_by_builtin.count(si.builtin))
56 inputs_by_builtin[si.builtin] = si;
57 }
58
add_msl_resource_binding(const MSLResourceBinding & binding)59 void CompilerMSL::add_msl_resource_binding(const MSLResourceBinding &binding)
60 {
61 StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
62 resource_bindings[tuple] = { binding, false };
63 }
64
add_dynamic_buffer(uint32_t desc_set,uint32_t binding,uint32_t index)65 void CompilerMSL::add_dynamic_buffer(uint32_t desc_set, uint32_t binding, uint32_t index)
66 {
67 SetBindingPair pair = { desc_set, binding };
68 buffers_requiring_dynamic_offset[pair] = { index, 0 };
69 }
70
add_inline_uniform_block(uint32_t desc_set,uint32_t binding)71 void CompilerMSL::add_inline_uniform_block(uint32_t desc_set, uint32_t binding)
72 {
73 SetBindingPair pair = { desc_set, binding };
74 inline_uniform_blocks.insert(pair);
75 }
76
add_discrete_descriptor_set(uint32_t desc_set)77 void CompilerMSL::add_discrete_descriptor_set(uint32_t desc_set)
78 {
79 if (desc_set < kMaxArgumentBuffers)
80 argument_buffer_discrete_mask |= 1u << desc_set;
81 }
82
set_argument_buffer_device_address_space(uint32_t desc_set,bool device_storage)83 void CompilerMSL::set_argument_buffer_device_address_space(uint32_t desc_set, bool device_storage)
84 {
85 if (desc_set < kMaxArgumentBuffers)
86 {
87 if (device_storage)
88 argument_buffer_device_storage_mask |= 1u << desc_set;
89 else
90 argument_buffer_device_storage_mask &= ~(1u << desc_set);
91 }
92 }
93
is_msl_shader_input_used(uint32_t location)94 bool CompilerMSL::is_msl_shader_input_used(uint32_t location)
95 {
96 return inputs_in_use.count(location) != 0;
97 }
98
is_msl_resource_binding_used(ExecutionModel model,uint32_t desc_set,uint32_t binding) const99 bool CompilerMSL::is_msl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
100 {
101 StageSetBinding tuple = { model, desc_set, binding };
102 auto itr = resource_bindings.find(tuple);
103 return itr != end(resource_bindings) && itr->second.second;
104 }
105
get_automatic_msl_resource_binding(uint32_t id) const106 uint32_t CompilerMSL::get_automatic_msl_resource_binding(uint32_t id) const
107 {
108 return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexPrimary);
109 }
110
get_automatic_msl_resource_binding_secondary(uint32_t id) const111 uint32_t CompilerMSL::get_automatic_msl_resource_binding_secondary(uint32_t id) const
112 {
113 return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexSecondary);
114 }
115
get_automatic_msl_resource_binding_tertiary(uint32_t id) const116 uint32_t CompilerMSL::get_automatic_msl_resource_binding_tertiary(uint32_t id) const
117 {
118 return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexTertiary);
119 }
120
get_automatic_msl_resource_binding_quaternary(uint32_t id) const121 uint32_t CompilerMSL::get_automatic_msl_resource_binding_quaternary(uint32_t id) const
122 {
123 return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexQuaternary);
124 }
125
set_fragment_output_components(uint32_t location,uint32_t components)126 void CompilerMSL::set_fragment_output_components(uint32_t location, uint32_t components)
127 {
128 fragment_output_components[location] = components;
129 }
130
builtin_translates_to_nonarray(spv::BuiltIn builtin) const131 bool CompilerMSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
132 {
133 return (builtin == BuiltInSampleMask);
134 }
135
build_implicit_builtins()136 void CompilerMSL::build_implicit_builtins()
137 {
138 bool need_sample_pos = active_input_builtins.get(BuiltInSamplePosition);
139 bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex &&
140 !msl_options.vertex_for_tessellation;
141 bool need_tesc_params = get_execution_model() == ExecutionModelTessellationControl;
142 bool need_subgroup_mask =
143 active_input_builtins.get(BuiltInSubgroupEqMask) || active_input_builtins.get(BuiltInSubgroupGeMask) ||
144 active_input_builtins.get(BuiltInSubgroupGtMask) || active_input_builtins.get(BuiltInSubgroupLeMask) ||
145 active_input_builtins.get(BuiltInSubgroupLtMask);
146 bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(BuiltInSubgroupGeMask) ||
147 active_input_builtins.get(BuiltInSubgroupGtMask));
148 bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index &&
149 msl_options.multiview_layered_rendering &&
150 (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex));
151 bool need_dispatch_base =
152 msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute &&
153 (active_input_builtins.get(BuiltInWorkgroupId) || active_input_builtins.get(BuiltInGlobalInvocationId));
154 bool need_grid_params = get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation;
155 bool need_vertex_base_params =
156 need_grid_params &&
157 (active_input_builtins.get(BuiltInVertexId) || active_input_builtins.get(BuiltInVertexIndex) ||
158 active_input_builtins.get(BuiltInBaseVertex) || active_input_builtins.get(BuiltInInstanceId) ||
159 active_input_builtins.get(BuiltInInstanceIndex) || active_input_builtins.get(BuiltInBaseInstance));
160 bool need_sample_mask = msl_options.additional_fixed_sample_mask != 0xffffffff;
161 if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
162 need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params ||
163 needs_subgroup_invocation_id || need_sample_mask)
164 {
165 bool has_frag_coord = false;
166 bool has_sample_id = false;
167 bool has_vertex_idx = false;
168 bool has_base_vertex = false;
169 bool has_instance_idx = false;
170 bool has_base_instance = false;
171 bool has_invocation_id = false;
172 bool has_primitive_id = false;
173 bool has_subgroup_invocation_id = false;
174 bool has_subgroup_size = false;
175 bool has_view_idx = false;
176 bool has_layer = false;
177 uint32_t workgroup_id_type = 0;
178
179 // FIXME: Investigate the fact that there are no checks for the entry point interface variables.
180 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
181 if (!ir.meta[var.self].decoration.builtin)
182 return;
183
184 // Use Metal's native frame-buffer fetch API for subpass inputs.
185 BuiltIn builtin = ir.meta[var.self].decoration.builtin_type;
186
187 if (var.storage == StorageClassOutput)
188 {
189 if (need_sample_mask && builtin == BuiltInSampleMask)
190 {
191 builtin_sample_mask_id = var.self;
192 mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var.self);
193 does_shader_write_sample_mask = true;
194 }
195 }
196
197 if (var.storage != StorageClassInput)
198 return;
199
200 if (need_subpass_input && (!msl_options.is_ios() || !msl_options.ios_use_framebuffer_fetch_subpasses))
201 {
202 switch (builtin)
203 {
204 case BuiltInFragCoord:
205 mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var.self);
206 builtin_frag_coord_id = var.self;
207 has_frag_coord = true;
208 break;
209 case BuiltInLayer:
210 if (!msl_options.arrayed_subpass_input || msl_options.multiview)
211 break;
212 mark_implicit_builtin(StorageClassInput, BuiltInLayer, var.self);
213 builtin_layer_id = var.self;
214 has_layer = true;
215 break;
216 case BuiltInViewIndex:
217 if (!msl_options.multiview)
218 break;
219 mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
220 builtin_view_idx_id = var.self;
221 has_view_idx = true;
222 break;
223 default:
224 break;
225 }
226 }
227
228 if (need_sample_pos && builtin == BuiltInSampleId)
229 {
230 builtin_sample_id_id = var.self;
231 mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var.self);
232 has_sample_id = true;
233 }
234
235 if (need_vertex_params)
236 {
237 switch (builtin)
238 {
239 case BuiltInVertexIndex:
240 builtin_vertex_idx_id = var.self;
241 mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var.self);
242 has_vertex_idx = true;
243 break;
244 case BuiltInBaseVertex:
245 builtin_base_vertex_id = var.self;
246 mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var.self);
247 has_base_vertex = true;
248 break;
249 case BuiltInInstanceIndex:
250 builtin_instance_idx_id = var.self;
251 mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
252 has_instance_idx = true;
253 break;
254 case BuiltInBaseInstance:
255 builtin_base_instance_id = var.self;
256 mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
257 has_base_instance = true;
258 break;
259 default:
260 break;
261 }
262 }
263
264 if (need_tesc_params)
265 {
266 switch (builtin)
267 {
268 case BuiltInInvocationId:
269 builtin_invocation_id_id = var.self;
270 mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var.self);
271 has_invocation_id = true;
272 break;
273 case BuiltInPrimitiveId:
274 builtin_primitive_id_id = var.self;
275 mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var.self);
276 has_primitive_id = true;
277 break;
278 default:
279 break;
280 }
281 }
282
283 if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId)
284 {
285 builtin_subgroup_invocation_id_id = var.self;
286 mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var.self);
287 has_subgroup_invocation_id = true;
288 }
289
290 if (need_subgroup_ge_mask && builtin == BuiltInSubgroupSize)
291 {
292 builtin_subgroup_size_id = var.self;
293 mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var.self);
294 has_subgroup_size = true;
295 }
296
297 if (need_multiview)
298 {
299 switch (builtin)
300 {
301 case BuiltInInstanceIndex:
302 // The view index here is derived from the instance index.
303 builtin_instance_idx_id = var.self;
304 mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
305 has_instance_idx = true;
306 break;
307 case BuiltInBaseInstance:
308 // If a non-zero base instance is used, we need to adjust for it when calculating the view index.
309 builtin_base_instance_id = var.self;
310 mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
311 has_base_instance = true;
312 break;
313 case BuiltInViewIndex:
314 builtin_view_idx_id = var.self;
315 mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
316 has_view_idx = true;
317 break;
318 default:
319 break;
320 }
321 }
322
323 // The base workgroup needs to have the same type and vector size
324 // as the workgroup or invocation ID, so keep track of the type that
325 // was used.
326 if (need_dispatch_base && workgroup_id_type == 0 &&
327 (builtin == BuiltInWorkgroupId || builtin == BuiltInGlobalInvocationId))
328 workgroup_id_type = var.basetype;
329 });
330
331 // Use Metal's native frame-buffer fetch API for subpass inputs.
332 if ((!has_frag_coord || (msl_options.multiview && !has_view_idx) ||
333 (msl_options.arrayed_subpass_input && !msl_options.multiview && !has_layer)) &&
334 (!msl_options.is_ios() || !msl_options.ios_use_framebuffer_fetch_subpasses) && need_subpass_input)
335 {
336 if (!has_frag_coord)
337 {
338 uint32_t offset = ir.increase_bound_by(3);
339 uint32_t type_id = offset;
340 uint32_t type_ptr_id = offset + 1;
341 uint32_t var_id = offset + 2;
342
343 // Create gl_FragCoord.
344 SPIRType vec4_type;
345 vec4_type.basetype = SPIRType::Float;
346 vec4_type.width = 32;
347 vec4_type.vecsize = 4;
348 set<SPIRType>(type_id, vec4_type);
349
350 SPIRType vec4_type_ptr;
351 vec4_type_ptr = vec4_type;
352 vec4_type_ptr.pointer = true;
353 vec4_type_ptr.parent_type = type_id;
354 vec4_type_ptr.storage = StorageClassInput;
355 auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
356 ptr_type.self = type_id;
357
358 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
359 set_decoration(var_id, DecorationBuiltIn, BuiltInFragCoord);
360 builtin_frag_coord_id = var_id;
361 mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var_id);
362 }
363
364 if (!has_layer && msl_options.arrayed_subpass_input && !msl_options.multiview)
365 {
366 uint32_t offset = ir.increase_bound_by(2);
367 uint32_t type_ptr_id = offset;
368 uint32_t var_id = offset + 1;
369
370 // Create gl_Layer.
371 SPIRType uint_type_ptr;
372 uint_type_ptr = get_uint_type();
373 uint_type_ptr.pointer = true;
374 uint_type_ptr.parent_type = get_uint_type_id();
375 uint_type_ptr.storage = StorageClassInput;
376 auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
377 ptr_type.self = get_uint_type_id();
378
379 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
380 set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
381 builtin_layer_id = var_id;
382 mark_implicit_builtin(StorageClassInput, BuiltInLayer, var_id);
383 }
384
385 if (!has_view_idx && msl_options.multiview)
386 {
387 uint32_t offset = ir.increase_bound_by(2);
388 uint32_t type_ptr_id = offset;
389 uint32_t var_id = offset + 1;
390
391 // Create gl_ViewIndex.
392 SPIRType uint_type_ptr;
393 uint_type_ptr = get_uint_type();
394 uint_type_ptr.pointer = true;
395 uint_type_ptr.parent_type = get_uint_type_id();
396 uint_type_ptr.storage = StorageClassInput;
397 auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
398 ptr_type.self = get_uint_type_id();
399
400 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
401 set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
402 builtin_view_idx_id = var_id;
403 mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
404 }
405 }
406
407 if (!has_sample_id && need_sample_pos)
408 {
409 uint32_t offset = ir.increase_bound_by(2);
410 uint32_t type_ptr_id = offset;
411 uint32_t var_id = offset + 1;
412
413 // Create gl_SampleID.
414 SPIRType uint_type_ptr;
415 uint_type_ptr = get_uint_type();
416 uint_type_ptr.pointer = true;
417 uint_type_ptr.parent_type = get_uint_type_id();
418 uint_type_ptr.storage = StorageClassInput;
419 auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
420 ptr_type.self = get_uint_type_id();
421
422 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
423 set_decoration(var_id, DecorationBuiltIn, BuiltInSampleId);
424 builtin_sample_id_id = var_id;
425 mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var_id);
426 }
427
428 if ((need_vertex_params && (!has_vertex_idx || !has_base_vertex || !has_instance_idx || !has_base_instance)) ||
429 (need_multiview && (!has_instance_idx || !has_base_instance || !has_view_idx)))
430 {
431 uint32_t type_ptr_id = ir.increase_bound_by(1);
432
433 SPIRType uint_type_ptr;
434 uint_type_ptr = get_uint_type();
435 uint_type_ptr.pointer = true;
436 uint_type_ptr.parent_type = get_uint_type_id();
437 uint_type_ptr.storage = StorageClassInput;
438 auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
439 ptr_type.self = get_uint_type_id();
440
441 if (need_vertex_params && !has_vertex_idx)
442 {
443 uint32_t var_id = ir.increase_bound_by(1);
444
445 // Create gl_VertexIndex.
446 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
447 set_decoration(var_id, DecorationBuiltIn, BuiltInVertexIndex);
448 builtin_vertex_idx_id = var_id;
449 mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var_id);
450 }
451
452 if (need_vertex_params && !has_base_vertex)
453 {
454 uint32_t var_id = ir.increase_bound_by(1);
455
456 // Create gl_BaseVertex.
457 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
458 set_decoration(var_id, DecorationBuiltIn, BuiltInBaseVertex);
459 builtin_base_vertex_id = var_id;
460 mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var_id);
461 }
462
463 if (!has_instance_idx) // Needed by both multiview and tessellation
464 {
465 uint32_t var_id = ir.increase_bound_by(1);
466
467 // Create gl_InstanceIndex.
468 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
469 set_decoration(var_id, DecorationBuiltIn, BuiltInInstanceIndex);
470 builtin_instance_idx_id = var_id;
471 mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var_id);
472 }
473
474 if (!has_base_instance) // Needed by both multiview and tessellation
475 {
476 uint32_t var_id = ir.increase_bound_by(1);
477
478 // Create gl_BaseInstance.
479 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
480 set_decoration(var_id, DecorationBuiltIn, BuiltInBaseInstance);
481 builtin_base_instance_id = var_id;
482 mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var_id);
483 }
484
485 if (need_multiview)
486 {
487 // Multiview shaders are not allowed to write to gl_Layer, ostensibly because
488 // it is implicitly written from gl_ViewIndex, but we have to do that explicitly.
489 // Note that we can't just abuse gl_ViewIndex for this purpose: it's an input, but
490 // gl_Layer is an output in vertex-pipeline shaders.
491 uint32_t type_ptr_out_id = ir.increase_bound_by(2);
492 SPIRType uint_type_ptr_out;
493 uint_type_ptr_out = get_uint_type();
494 uint_type_ptr_out.pointer = true;
495 uint_type_ptr_out.parent_type = get_uint_type_id();
496 uint_type_ptr_out.storage = StorageClassOutput;
497 auto &ptr_out_type = set<SPIRType>(type_ptr_out_id, uint_type_ptr_out);
498 ptr_out_type.self = get_uint_type_id();
499 uint32_t var_id = type_ptr_out_id + 1;
500 set<SPIRVariable>(var_id, type_ptr_out_id, StorageClassOutput);
501 set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
502 builtin_layer_id = var_id;
503 mark_implicit_builtin(StorageClassOutput, BuiltInLayer, var_id);
504 }
505
506 if (need_multiview && !has_view_idx)
507 {
508 uint32_t var_id = ir.increase_bound_by(1);
509
510 // Create gl_ViewIndex.
511 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
512 set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
513 builtin_view_idx_id = var_id;
514 mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
515 }
516 }
517
518 if ((need_tesc_params && (msl_options.multi_patch_workgroup || !has_invocation_id || !has_primitive_id)) ||
519 need_grid_params)
520 {
521 uint32_t type_ptr_id = ir.increase_bound_by(1);
522
523 SPIRType uint_type_ptr;
524 uint_type_ptr = get_uint_type();
525 uint_type_ptr.pointer = true;
526 uint_type_ptr.parent_type = get_uint_type_id();
527 uint_type_ptr.storage = StorageClassInput;
528 auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
529 ptr_type.self = get_uint_type_id();
530
531 if (msl_options.multi_patch_workgroup || need_grid_params)
532 {
533 uint32_t var_id = ir.increase_bound_by(1);
534
535 // Create gl_GlobalInvocationID.
536 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
537 set_decoration(var_id, DecorationBuiltIn, BuiltInGlobalInvocationId);
538 builtin_invocation_id_id = var_id;
539 mark_implicit_builtin(StorageClassInput, BuiltInGlobalInvocationId, var_id);
540 }
541 else if (need_tesc_params && !has_invocation_id)
542 {
543 uint32_t var_id = ir.increase_bound_by(1);
544
545 // Create gl_InvocationID.
546 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
547 set_decoration(var_id, DecorationBuiltIn, BuiltInInvocationId);
548 builtin_invocation_id_id = var_id;
549 mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var_id);
550 }
551
552 if (need_tesc_params && !has_primitive_id)
553 {
554 uint32_t var_id = ir.increase_bound_by(1);
555
556 // Create gl_PrimitiveID.
557 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
558 set_decoration(var_id, DecorationBuiltIn, BuiltInPrimitiveId);
559 builtin_primitive_id_id = var_id;
560 mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var_id);
561 }
562
563 if (need_grid_params)
564 {
565 uint32_t var_id = ir.increase_bound_by(1);
566
567 set<SPIRVariable>(var_id, build_extended_vector_type(get_uint_type_id(), 3), StorageClassInput);
568 set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize);
569 get_entry_point().interface_variables.push_back(var_id);
570 set_name(var_id, "spvStageInputSize");
571 builtin_stage_input_size_id = var_id;
572 }
573 }
574
575 if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id))
576 {
577 uint32_t offset = ir.increase_bound_by(2);
578 uint32_t type_ptr_id = offset;
579 uint32_t var_id = offset + 1;
580
581 // Create gl_SubgroupInvocationID.
582 SPIRType uint_type_ptr;
583 uint_type_ptr = get_uint_type();
584 uint_type_ptr.pointer = true;
585 uint_type_ptr.parent_type = get_uint_type_id();
586 uint_type_ptr.storage = StorageClassInput;
587 auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
588 ptr_type.self = get_uint_type_id();
589
590 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
591 set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupLocalInvocationId);
592 builtin_subgroup_invocation_id_id = var_id;
593 mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var_id);
594 }
595
596 if (!has_subgroup_size && need_subgroup_ge_mask)
597 {
598 uint32_t offset = ir.increase_bound_by(2);
599 uint32_t type_ptr_id = offset;
600 uint32_t var_id = offset + 1;
601
602 // Create gl_SubgroupSize.
603 SPIRType uint_type_ptr;
604 uint_type_ptr = get_uint_type();
605 uint_type_ptr.pointer = true;
606 uint_type_ptr.parent_type = get_uint_type_id();
607 uint_type_ptr.storage = StorageClassInput;
608 auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
609 ptr_type.self = get_uint_type_id();
610
611 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
612 set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupSize);
613 builtin_subgroup_size_id = var_id;
614 mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id);
615 }
616
617 if (need_dispatch_base || need_vertex_base_params)
618 {
619 if (workgroup_id_type == 0)
620 workgroup_id_type = build_extended_vector_type(get_uint_type_id(), 3);
621 uint32_t var_id;
622 if (msl_options.supports_msl_version(1, 2))
623 {
624 // If we have MSL 1.2, we can (ab)use the [[grid_origin]] builtin
625 // to convey this information and save a buffer slot.
626 uint32_t offset = ir.increase_bound_by(1);
627 var_id = offset;
628
629 set<SPIRVariable>(var_id, workgroup_id_type, StorageClassInput);
630 set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase);
631 get_entry_point().interface_variables.push_back(var_id);
632 }
633 else
634 {
635 // Otherwise, we need to fall back to a good ol' fashioned buffer.
636 uint32_t offset = ir.increase_bound_by(2);
637 var_id = offset;
638 uint32_t type_id = offset + 1;
639
640 SPIRType var_type = get<SPIRType>(workgroup_id_type);
641 var_type.storage = StorageClassUniform;
642 set<SPIRType>(type_id, var_type);
643
644 set<SPIRVariable>(var_id, type_id, StorageClassUniform);
645 // This should never match anything.
646 set_decoration(var_id, DecorationDescriptorSet, ~(5u));
647 set_decoration(var_id, DecorationBinding, msl_options.indirect_params_buffer_index);
648 set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
649 msl_options.indirect_params_buffer_index);
650 }
651 set_name(var_id, "spvDispatchBase");
652 builtin_dispatch_base_id = var_id;
653 }
654
655 if (need_sample_mask && !does_shader_write_sample_mask)
656 {
657 uint32_t offset = ir.increase_bound_by(2);
658 uint32_t var_id = offset + 1;
659
660 // Create gl_SampleMask.
661 SPIRType uint_type_ptr_out;
662 uint_type_ptr_out = get_uint_type();
663 uint_type_ptr_out.pointer = true;
664 uint_type_ptr_out.parent_type = get_uint_type_id();
665 uint_type_ptr_out.storage = StorageClassOutput;
666
667 auto &ptr_out_type = set<SPIRType>(offset, uint_type_ptr_out);
668 ptr_out_type.self = get_uint_type_id();
669 set<SPIRVariable>(var_id, offset, StorageClassOutput);
670 set_decoration(var_id, DecorationBuiltIn, BuiltInSampleMask);
671 builtin_sample_mask_id = var_id;
672 mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var_id);
673 }
674 }
675
676 if (needs_swizzle_buffer_def)
677 {
678 uint32_t var_id = build_constant_uint_array_pointer();
679 set_name(var_id, "spvSwizzleConstants");
680 // This should never match anything.
681 set_decoration(var_id, DecorationDescriptorSet, kSwizzleBufferBinding);
682 set_decoration(var_id, DecorationBinding, msl_options.swizzle_buffer_index);
683 set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.swizzle_buffer_index);
684 swizzle_buffer_id = var_id;
685 }
686
687 if (!buffers_requiring_array_length.empty())
688 {
689 uint32_t var_id = build_constant_uint_array_pointer();
690 set_name(var_id, "spvBufferSizeConstants");
691 // This should never match anything.
692 set_decoration(var_id, DecorationDescriptorSet, kBufferSizeBufferBinding);
693 set_decoration(var_id, DecorationBinding, msl_options.buffer_size_buffer_index);
694 set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.buffer_size_buffer_index);
695 buffer_size_buffer_id = var_id;
696 }
697
698 if (needs_view_mask_buffer())
699 {
700 uint32_t var_id = build_constant_uint_array_pointer();
701 set_name(var_id, "spvViewMask");
702 // This should never match anything.
703 set_decoration(var_id, DecorationDescriptorSet, ~(4u));
704 set_decoration(var_id, DecorationBinding, msl_options.view_mask_buffer_index);
705 set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.view_mask_buffer_index);
706 view_mask_buffer_id = var_id;
707 }
708
709 if (!buffers_requiring_dynamic_offset.empty())
710 {
711 uint32_t var_id = build_constant_uint_array_pointer();
712 set_name(var_id, "spvDynamicOffsets");
713 // This should never match anything.
714 set_decoration(var_id, DecorationDescriptorSet, ~(5u));
715 set_decoration(var_id, DecorationBinding, msl_options.dynamic_offsets_buffer_index);
716 set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
717 msl_options.dynamic_offsets_buffer_index);
718 dynamic_offsets_buffer_id = var_id;
719 }
720 }
721
722 // Checks if the specified builtin variable (e.g. gl_InstanceIndex) is marked as active.
723 // If not, it marks it as active and forces a recompilation.
724 // This might be used when the optimization of inactive builtins was too optimistic (e.g. when "spvOut" is emitted).
ensure_builtin(spv::StorageClass storage,spv::BuiltIn builtin)725 void CompilerMSL::ensure_builtin(spv::StorageClass storage, spv::BuiltIn builtin)
726 {
727 Bitset *active_builtins = nullptr;
728 switch (storage)
729 {
730 case StorageClassInput:
731 active_builtins = &active_input_builtins;
732 break;
733
734 case StorageClassOutput:
735 active_builtins = &active_output_builtins;
736 break;
737
738 default:
739 break;
740 }
741
742 // At this point, the specified builtin variable must have already been declared in the entry point.
743 // If not, mark as active and force recompile.
744 if (active_builtins != nullptr && !active_builtins->get(builtin))
745 {
746 active_builtins->set(builtin);
747 force_recompile();
748 }
749 }
750
mark_implicit_builtin(StorageClass storage,BuiltIn builtin,uint32_t id)751 void CompilerMSL::mark_implicit_builtin(StorageClass storage, BuiltIn builtin, uint32_t id)
752 {
753 Bitset *active_builtins = nullptr;
754 switch (storage)
755 {
756 case StorageClassInput:
757 active_builtins = &active_input_builtins;
758 break;
759
760 case StorageClassOutput:
761 active_builtins = &active_output_builtins;
762 break;
763
764 default:
765 break;
766 }
767
768 assert(active_builtins != nullptr);
769 active_builtins->set(builtin);
770
771 auto &var = get_entry_point().interface_variables;
772 if (find(begin(var), end(var), VariableID(id)) == end(var))
773 var.push_back(id);
774 }
775
build_constant_uint_array_pointer()776 uint32_t CompilerMSL::build_constant_uint_array_pointer()
777 {
778 uint32_t offset = ir.increase_bound_by(3);
779 uint32_t type_ptr_id = offset;
780 uint32_t type_ptr_ptr_id = offset + 1;
781 uint32_t var_id = offset + 2;
782
783 // Create a buffer to hold extra data, including the swizzle constants.
784 SPIRType uint_type_pointer = get_uint_type();
785 uint_type_pointer.pointer = true;
786 uint_type_pointer.pointer_depth = 1;
787 uint_type_pointer.parent_type = get_uint_type_id();
788 uint_type_pointer.storage = StorageClassUniform;
789 set<SPIRType>(type_ptr_id, uint_type_pointer);
790 set_decoration(type_ptr_id, DecorationArrayStride, 4);
791
792 SPIRType uint_type_pointer2 = uint_type_pointer;
793 uint_type_pointer2.pointer_depth++;
794 uint_type_pointer2.parent_type = type_ptr_id;
795 set<SPIRType>(type_ptr_ptr_id, uint_type_pointer2);
796
797 set<SPIRVariable>(var_id, type_ptr_ptr_id, StorageClassUniformConstant);
798 return var_id;
799 }
800
create_sampler_address(const char * prefix,MSLSamplerAddress addr)801 static string create_sampler_address(const char *prefix, MSLSamplerAddress addr)
802 {
803 switch (addr)
804 {
805 case MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE:
806 return join(prefix, "address::clamp_to_edge");
807 case MSL_SAMPLER_ADDRESS_CLAMP_TO_ZERO:
808 return join(prefix, "address::clamp_to_zero");
809 case MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER:
810 return join(prefix, "address::clamp_to_border");
811 case MSL_SAMPLER_ADDRESS_REPEAT:
812 return join(prefix, "address::repeat");
813 case MSL_SAMPLER_ADDRESS_MIRRORED_REPEAT:
814 return join(prefix, "address::mirrored_repeat");
815 default:
816 SPIRV_CROSS_THROW("Invalid sampler addressing mode.");
817 }
818 }
819
get_stage_in_struct_type()820 SPIRType &CompilerMSL::get_stage_in_struct_type()
821 {
822 auto &si_var = get<SPIRVariable>(stage_in_var_id);
823 return get_variable_data_type(si_var);
824 }
825
get_stage_out_struct_type()826 SPIRType &CompilerMSL::get_stage_out_struct_type()
827 {
828 auto &so_var = get<SPIRVariable>(stage_out_var_id);
829 return get_variable_data_type(so_var);
830 }
831
get_patch_stage_in_struct_type()832 SPIRType &CompilerMSL::get_patch_stage_in_struct_type()
833 {
834 auto &si_var = get<SPIRVariable>(patch_stage_in_var_id);
835 return get_variable_data_type(si_var);
836 }
837
get_patch_stage_out_struct_type()838 SPIRType &CompilerMSL::get_patch_stage_out_struct_type()
839 {
840 auto &so_var = get<SPIRVariable>(patch_stage_out_var_id);
841 return get_variable_data_type(so_var);
842 }
843
get_tess_factor_struct_name()844 std::string CompilerMSL::get_tess_factor_struct_name()
845 {
846 if (get_entry_point().flags.get(ExecutionModeTriangles))
847 return "MTLTriangleTessellationFactorsHalf";
848 return "MTLQuadTessellationFactorsHalf";
849 }
850
get_uint_type()851 SPIRType &CompilerMSL::get_uint_type()
852 {
853 return get<SPIRType>(get_uint_type_id());
854 }
855
get_uint_type_id()856 uint32_t CompilerMSL::get_uint_type_id()
857 {
858 if (uint_type_id != 0)
859 return uint_type_id;
860
861 uint_type_id = ir.increase_bound_by(1);
862
863 SPIRType type;
864 type.basetype = SPIRType::UInt;
865 type.width = 32;
866 set<SPIRType>(uint_type_id, type);
867 return uint_type_id;
868 }
869
// Emits, via statement(), the declarations handled here in this order:
//   1. complex constant arrays (declare_complex_constant_arrays()),
//   2. one "constexpr sampler" per entry in constexpr_samplers_by_id,
//   3. one pointer/reference per buffer in buffers_requiring_dynamic_offset,
//   4. one pointer array per entry in buffer_arrays (which is then cleared),
//   5. one local declaration per entry in disabled_frag_outputs.
// Throws on invalid sampler enum values and on runtime arrays with dynamic offsets.
void CompilerMSL::emit_entry_point_declarations()
{
	// FIXME: Get test coverage here ...
	// Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries
	declare_complex_constant_arrays();

	// Emit constexpr samplers here.
	for (auto &samp : constexpr_samplers_by_id)
	{
		auto &var = get<SPIRVariable>(samp.first);
		auto &type = get<SPIRType>(var.basetype);
		if (type.basetype == SPIRType::Sampler)
			add_resource_name(samp.first);

		// Constructor arguments for the sampler. Only settings that differ from the
		// values this code treats as defaults are appended.
		SmallVector<string> args;
		auto &s = samp.second;

		if (s.coord != MSL_SAMPLER_COORD_NORMALIZED)
			args.push_back("coord::pixel");

		// Nearest filtering is never emitted. If min and mag agree, a single
		// combined "filter::" argument is used instead of two separate ones.
		if (s.min_filter == s.mag_filter)
		{
			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("filter::linear");
		}
		else
		{
			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("min_filter::linear");
			if (s.mag_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("mag_filter::linear");
		}

		switch (s.mip_filter)
		{
		case MSL_SAMPLER_MIP_FILTER_NONE:
			// Default
			break;
		case MSL_SAMPLER_MIP_FILTER_NEAREST:
			args.push_back("mip_filter::nearest");
			break;
		case MSL_SAMPLER_MIP_FILTER_LINEAR:
			args.push_back("mip_filter::linear");
			break;
		default:
			SPIRV_CROSS_THROW("Invalid mip filter.");
		}

		// Clamp-to-edge is never emitted. If all three axes agree, one unprefixed
		// "address::" argument is used; otherwise each axis gets its own s_/t_/r_ argument.
		if (s.s_address == s.t_address && s.s_address == s.r_address)
		{
			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("", s.s_address));
		}
		else
		{
			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("s_", s.s_address));
			if (s.t_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("t_", s.t_address));
			if (s.r_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("r_", s.r_address));
		}

		// The compare function is only emitted when comparison is enabled.
		if (s.compare_enable)
		{
			switch (s.compare_func)
			{
			case MSL_SAMPLER_COMPARE_FUNC_ALWAYS:
				args.push_back("compare_func::always");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_NEVER:
				args.push_back("compare_func::never");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_EQUAL:
				args.push_back("compare_func::equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_NOT_EQUAL:
				args.push_back("compare_func::not_equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_LESS:
				args.push_back("compare_func::less");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_LESS_EQUAL:
				args.push_back("compare_func::less_equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_GREATER:
				args.push_back("compare_func::greater");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_GREATER_EQUAL:
				args.push_back("compare_func::greater_equal");
				break;
			default:
				SPIRV_CROSS_THROW("Invalid sampler compare function.");
			}
		}

		// The border color is only emitted when at least one axis clamps to border.
		if (s.s_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER || s.t_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER ||
		    s.r_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER)
		{
			switch (s.border_color)
			{
			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_BLACK:
				args.push_back("border_color::opaque_black");
				break;
			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_WHITE:
				args.push_back("border_color::opaque_white");
				break;
			case MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK:
				args.push_back("border_color::transparent_black");
				break;
			default:
				SPIRV_CROSS_THROW("Invalid sampler border color.");
			}
		}

		if (s.anisotropy_enable)
			args.push_back(join("max_anisotropy(", s.max_anisotropy, ")"));
		if (s.lod_clamp_enable)
		{
			args.push_back(join("lod_clamp(", convert_to_string(s.lod_clamp_min, current_locale_radix_character), ", ",
			                    convert_to_string(s.lod_clamp_max, current_locale_radix_character), ")"));
		}

		// If we would emit no arguments, then omit the parentheses entirely. Otherwise,
		// we'll wind up with a "most vexing parse" situation.
		// For SampledImage variables the name comes from to_sampler_expression(), otherwise from to_name().
		if (args.empty())
			statement("constexpr sampler ",
			          type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
			          ";");
		else
			statement("constexpr sampler ",
			          type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
			          "(", merge(args), ");");
	}

	// Emit dynamic buffers here.
	// For each entry, .second.first is used as the base index into the dynamic offsets
	// buffer and .second.second as the ID of the buffer variable (0 means unused).
	for (auto &dynamic_buffer : buffers_requiring_dynamic_offset)
	{
		if (!dynamic_buffer.second.second)
		{
			// Could happen if no buffer was used at requested binding point.
			continue;
		}

		const auto &var = get<SPIRVariable>(dynamic_buffer.second.second);
		uint32_t var_id = var.self;
		const auto &type = get_variable_data_type(var);
		string name = to_name(var.self);
		uint32_t desc_set = get_decoration(var.self, DecorationDescriptorSet);
		uint32_t arg_id = argument_buffer_ids[desc_set];
		uint32_t base_index = dynamic_buffer.second.first;

		if (!type.array.empty())
		{
			// This is complicated, because we need to support arrays of arrays.
			// And it's even worse if the outermost dimension is a runtime array, because now
			// all this complicated goop has to go into the shader itself. (FIXME)
			if (!type.array[type.array.size() - 1])
				SPIRV_CROSS_THROW("Runtime arrays with dynamic offsets are not supported yet.");
			else
			{
				is_using_builtin_array = true;
				statement(get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id), name,
				          type_to_array_glsl(type), " =");

				// `indices` holds the current index for each array dimension and `j` counts
				// the elements emitted so far (added to base_index to select the offset).
				// `dim` is the number of brace scopes that must be opened before the next
				// element: all of them initially, then as many as the carry loop below closed.
				// NOTE(review): to_array_size_literal(type) without an index is presumably the
				// size of the outermost (last) dimension, matching indices[size - 1] — confirm.
				uint32_t dim = uint32_t(type.array.size());
				uint32_t j = 0;
				for (SmallVector<uint32_t> indices(type.array.size());
				     indices[type.array.size() - 1] < to_array_size_literal(type); j++)
				{
					while (dim > 0)
					{
						begin_scope();
						--dim;
					}

					// Subscripts are written from the last array dimension down to the first.
					string arrays;
					for (uint32_t i = uint32_t(type.array.size()); i; --i)
						arrays += join("[", indices[i - 1], "]");
					statement("(", get_argument_address_space(var), " ", type_to_glsl(type), "* ",
					          to_restrict(var_id, false), ")((", get_argument_address_space(var), " char* ",
					          to_restrict(var_id, false), ")", to_name(arg_id), ".", ensure_valid_name(name, "m"),
					          arrays, " + ", to_name(dynamic_offsets_buffer_id), "[", base_index + j, "]),");

					// Advance the index of the current dimension. Whenever it reaches that
					// dimension's size (and it is not the last dimension), close one scope,
					// reset the index to zero and carry into the next dimension.
					while (++indices[dim] >= to_array_size_literal(type, dim) && dim < type.array.size() - 1)
					{
						end_scope(",");
						indices[dim++] = 0;
					}
				}
				end_scope_decl();
				statement_no_indent("");
				is_using_builtin_array = false;
			}
		}
		else
		{
			// Non-array buffer: emit a reference to the argument buffer member, displaced
			// by the byte offset read from the dynamic offsets buffer at base_index.
			statement(get_argument_address_space(var), " auto& ", to_restrict(var_id), name, " = *(",
			          get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, false), ")((",
			          get_argument_address_space(var), " char* ", to_restrict(var_id, false), ")", to_name(arg_id), ".",
			          ensure_valid_name(name, "m"), " + ", to_name(dynamic_offsets_buffer_id), "[", base_index, "]);");
		}
	}

	// Emit buffer arrays here.
	// Each array is initialized from names of the form "<name>_<i>", one per element.
	for (uint32_t array_id : buffer_arrays)
	{
		const auto &var = get<SPIRVariable>(array_id);
		const auto &type = get_variable_data_type(var);
		const auto &buffer_type = get_variable_element_type(var);
		string name = to_name(array_id);
		statement(get_argument_address_space(var), " ", type_to_glsl(buffer_type), "* ", to_restrict(array_id), name,
		          "[] =");
		begin_scope();
		for (uint32_t i = 0; i < to_array_size_literal(type); ++i)
			statement(name, "_", i, ",");
		end_scope_decl();
		statement_no_indent("");
	}
	// For some reason, without this, we end up emitting the arrays twice.
	buffer_arrays.clear();

	// Emit disabled fragment outputs.
	// They are sorted by ID first, then declared as plain local variables.
	std::sort(disabled_frag_outputs.begin(), disabled_frag_outputs.end());
	for (uint32_t var_id : disabled_frag_outputs)
	{
		auto &var = get<SPIRVariable>(var_id);
		add_local_variable_name(var_id);
		statement(variable_decl(var), ";");
		var.deferred_declaration = false;
	}
}
1102
// Compiles the parsed SPIR-V module and returns the generated source as a string.
// The function configures the backend for MSL, runs the analysis and preprocessing
// passes, builds the interface blocks, and then emits the shader repeatedly until
// is_forcing_recompilation() reports false.
// Throws if argument buffers are requested without MSL 2.0 support, or if more than
// three emission passes would be needed.
string CompilerMSL::compile()
{
	ir.fixup_reserved_names();

	// Do not deal with GLES-isms like precision, older extensions and such.
	options.vulkan_semantics = true;
	options.es = false;
	options.version = 450;
	backend.null_pointer_literal = "nullptr";
	backend.float_literal_suffix = false;
	backend.uint32_t_literal_suffix = true;
	backend.int16_t_literal_suffix = "";
	backend.uint16_t_literal_suffix = "";
	backend.basic_int_type = "int";
	backend.basic_uint_type = "uint";
	backend.basic_int8_type = "char";
	backend.basic_uint8_type = "uchar";
	backend.basic_int16_type = "short";
	backend.basic_uint16_type = "ushort";
	backend.discard_literal = "discard_fragment()";
	backend.demote_literal = "unsupported-demote";
	backend.boolean_mix_function = "select";
	backend.swizzle_is_function = false;
	backend.shared_is_implied = false;
	backend.use_initializer_list = true;
	backend.use_typed_initializer_list = true;
	backend.native_row_major_matrix = false;
	backend.unsized_array_supported = false;
	backend.can_declare_arrays_inline = false;
	backend.allow_truncated_access_chain = true;
	backend.comparison_image_samples_scalar = true;
	backend.native_pointers = true;
	backend.nonuniform_qualifier = "";
	backend.support_small_type_sampling_result = true;
	backend.supports_empty_struct = true;

	// Allow Metal to use the array<T> template unless we force it off.
	backend.can_return_array = !msl_options.force_native_arrays;
	backend.array_is_value_type = !msl_options.force_native_arrays;
	// Arrays which are part of buffer objects are never considered to be native arrays.
	backend.buffer_offset_array_is_value_type = false;

	// Capturing output to a buffer also disables rasterization.
	capture_output_to_buffer = msl_options.capture_output_to_buffer;
	is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;

	// Initialize array here rather than constructor, MSVC 2013 workaround.
	for (auto &id : next_metal_resource_ids)
		id = 0;

	fixup_type_alias();
	replace_illegal_names();

	// Analysis and preprocessing passes. These run once, before the emission loop.
	build_function_control_flow_graphs_and_analyze();
	update_active_builtins();
	analyze_image_and_sampler_usage();
	analyze_sampled_image_usage();
	analyze_interlocked_resource_usage();
	preprocess_op_codes();
	build_implicit_builtins();

	fixup_image_load_store_access();

	set_enabled_interface_variables(get_active_interface_variables());
	if (msl_options.force_active_argument_buffer_resources)
		activate_argument_buffer_resources();

	// Any of these helper variables that has a non-zero ID is added to the
	// set of active interface variables.
	if (swizzle_buffer_id)
		active_interface_variables.insert(swizzle_buffer_id);
	if (buffer_size_buffer_id)
		active_interface_variables.insert(buffer_size_buffer_id);
	if (view_mask_buffer_id)
		active_interface_variables.insert(view_mask_buffer_id);
	if (dynamic_offsets_buffer_id)
		active_interface_variables.insert(dynamic_offsets_buffer_id);
	if (builtin_layer_id)
		active_interface_variables.insert(builtin_layer_id);
	// The dispatch base builtin is only activated when MSL 1.2 is not supported.
	if (builtin_dispatch_base_id && !msl_options.supports_msl_version(1, 2))
		active_interface_variables.insert(builtin_dispatch_base_id);
	if (builtin_sample_mask_id)
		active_interface_variables.insert(builtin_sample_mask_id);

	// Create structs to hold input, output and uniform variables.
	// Do output first to ensure out. is declared at top of entry function.
	qual_pos_var_name = "";
	stage_out_var_id = add_interface_block(StorageClassOutput);
	patch_stage_out_var_id = add_interface_block(StorageClassOutput, true);
	stage_in_var_id = add_interface_block(StorageClassInput);
	if (get_execution_model() == ExecutionModelTessellationEvaluation)
		patch_stage_in_var_id = add_interface_block(StorageClassInput, true);

	// Pointer variants of the interface blocks: the output one only for tessellation
	// control shaders, the input one for any tessellation shader.
	if (get_execution_model() == ExecutionModelTessellationControl)
		stage_out_ptr_var_id = add_interface_block_pointer(stage_out_var_id, StorageClassOutput);
	if (is_tessellation_shader())
		stage_in_ptr_var_id = add_interface_block_pointer(stage_in_var_id, StorageClassInput);

	// Metal vertex functions that define no output must disable rasterization and return void.
	if (!stage_out_var_id)
		is_rasterization_disabled = true;

	// Convert the use of global variables to recursively-passed function parameters
	localize_global_variables();
	extract_global_variables_from_functions();

	// Mark any non-stage-in structs to be tightly packed.
	mark_packable_structs();
	reorder_type_alias();

	// Add fixup hooks required by shader inputs and outputs. This needs to happen before
	// the loop, so the hooks aren't added multiple times.
	fix_up_shader_inputs_outputs();

	// If we are using argument buffers, we create argument buffer structures for them here.
	// These buffers will be used in the entry point, not the individual resources.
	if (msl_options.argument_buffers)
	{
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("Argument buffers can only be used with MSL 2.0 and up.");
		analyze_argument_buffers();
	}

	// Emission loop: every pass resets the compiler state and the output buffer and
	// re-emits the whole shader. It repeats while recompilation is being forced.
	uint32_t pass_count = 0;
	do
	{
		if (pass_count >= 3)
			SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");

		reset();

		// Start bindings at zero.
		next_metal_resource_index_buffer = 0;
		next_metal_resource_index_texture = 0;
		next_metal_resource_index_sampler = 0;
		for (auto &id : next_metal_resource_ids)
			id = 0;

		// Move constructor for this type is broken on GCC 4.9 ...
		buffer.reset();

		emit_header();
		emit_custom_templates();
		emit_specialization_constants_and_structs();
		emit_resources();
		emit_custom_functions();
		emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());

		pass_count++;
	} while (is_forcing_recompilation());

	return buffer.str();
}
1253
1254 // Register the need to output any custom functions.
preprocess_op_codes()1255 void CompilerMSL::preprocess_op_codes()
1256 {
1257 OpCodePreprocessor preproc(*this);
1258 traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), preproc);
1259
1260 suppress_missing_prototypes = preproc.suppress_missing_prototypes;
1261
1262 if (preproc.uses_atomics)
1263 {
1264 add_header_line("#include <metal_atomic>");
1265 add_pragma_line("#pragma clang diagnostic ignored \"-Wunused-variable\"");
1266 }
1267
1268 // Metal vertex functions that write to resources must disable rasterization and return void.
1269 if (preproc.uses_resource_write)
1270 is_rasterization_disabled = true;
1271
1272 // Tessellation control shaders are run as compute functions in Metal, and so
1273 // must capture their output to a buffer.
1274 if (get_execution_model() == ExecutionModelTessellationControl ||
1275 (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
1276 {
1277 is_rasterization_disabled = true;
1278 capture_output_to_buffer = true;
1279 }
1280
1281 if (preproc.needs_subgroup_invocation_id)
1282 needs_subgroup_invocation_id = true;
1283 }
1284
1285 // Move the Private and Workgroup global variables to the entry function.
1286 // Non-constant variables cannot have global scope in Metal.
localize_global_variables()1287 void CompilerMSL::localize_global_variables()
1288 {
1289 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1290 auto iter = global_variables.begin();
1291 while (iter != global_variables.end())
1292 {
1293 uint32_t v_id = *iter;
1294 auto &var = get<SPIRVariable>(v_id);
1295 if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup)
1296 {
1297 if (!variable_is_lut(var))
1298 entry_func.add_local_variable(v_id);
1299 iter = global_variables.erase(iter);
1300 }
1301 else
1302 iter++;
1303 }
1304 }
1305
1306 // For any global variable accessed directly by a function,
1307 // extract that variable and add it as an argument to that function.
extract_global_variables_from_functions()1308 void CompilerMSL::extract_global_variables_from_functions()
1309 {
1310 // Uniforms
1311 unordered_set<uint32_t> global_var_ids;
1312 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1313 if (var.storage == StorageClassInput || var.storage == StorageClassOutput ||
1314 var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
1315 var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)
1316 {
1317 global_var_ids.insert(var.self);
1318 }
1319 });
1320
1321 // Local vars that are declared in the main function and accessed directly by a function
1322 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1323 for (auto &var : entry_func.local_variables)
1324 if (get<SPIRVariable>(var).storage != StorageClassFunction)
1325 global_var_ids.insert(var);
1326
1327 std::set<uint32_t> added_arg_ids;
1328 unordered_set<uint32_t> processed_func_ids;
1329 extract_global_variables_from_function(ir.default_entry_point, added_arg_ids, global_var_ids, processed_func_ids);
1330 }
1331
// MSL does not support the use of global variables for shader input content.
// For any global variable accessed directly by the specified function, extract that variable,
// add it as an argument to that function, and the arg to the added_arg_ids collection.
//
// func_id:            function to process.
// added_arg_ids:      receives the IDs of the globals used by this function and, through
//                     recursion, by the functions it calls. For a function that was already
//                     processed it is overwritten with the cached result.
// global_var_ids:     IDs that are considered global variables.
// processed_func_ids: functions visited so far; func_id is added to it.
//
// The result is cached in function_global_vars[func_id]. Parameters are only appended to
// functions other than the default entry point.
void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::set<uint32_t> &added_arg_ids,
                                                         unordered_set<uint32_t> &global_var_ids,
                                                         unordered_set<uint32_t> &processed_func_ids)
{
	// Avoid processing a function more than once
	if (processed_func_ids.find(func_id) != processed_func_ids.end())
	{
		// Return function global variables
		added_arg_ids = function_global_vars[func_id];
		return;
	}

	processed_func_ids.insert(func_id);

	auto &func = get<SPIRFunction>(func_id);

	// Recursively establish global args added to functions on which we depend.
	for (auto block : func.blocks)
	{
		auto &b = get<SPIRBlock>(block);
		for (auto &i : b.ops)
		{
			auto ops = stream(i);
			auto op = static_cast<Op>(i.op);

			switch (op)
			{
			case OpLoad:
			case OpInBoundsAccessChain:
			case OpAccessChain:
			case OpPtrAccessChain:
			case OpArrayLength:
			{
				// For these opcodes, ops[0] is used as the result type and ops[2] as the base ID.
				uint32_t base_id = ops[2];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);

				// Use Metal's native frame-buffer fetch API for subpass inputs.
				// The implicit builtin reads below are added when the result type is a subpass
				// input image and the iOS framebuffer-fetch option is not in effect.
				auto &type = get<SPIRType>(ops[0]);
				if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
				    (!msl_options.is_ios() || !msl_options.ios_use_framebuffer_fetch_subpasses))
				{
					// Implicitly reads gl_FragCoord.
					assert(builtin_frag_coord_id != 0);
					added_arg_ids.insert(builtin_frag_coord_id);
					if (msl_options.multiview)
					{
						// Implicitly reads gl_ViewIndex.
						assert(builtin_view_idx_id != 0);
						added_arg_ids.insert(builtin_view_idx_id);
					}
					else if (msl_options.arrayed_subpass_input)
					{
						// Implicitly reads gl_Layer.
						assert(builtin_layer_id != 0);
						added_arg_ids.insert(builtin_layer_id);
					}
				}

				break;
			}

			case OpFunctionCall:
			{
				// First see if any of the function call args are globals
				for (uint32_t arg_idx = 3; arg_idx < i.length; arg_idx++)
				{
					uint32_t arg_id = ops[arg_idx];
					if (global_var_ids.find(arg_id) != global_var_ids.end())
						added_arg_ids.insert(arg_id);
				}

				// Then recurse into the function itself to extract globals used internally in the function
				uint32_t inner_func_id = ops[2];
				std::set<uint32_t> inner_func_args;
				extract_global_variables_from_function(inner_func_id, inner_func_args, global_var_ids,
				                                       processed_func_ids);
				// Everything the callee needs must also be available to the caller.
				added_arg_ids.insert(inner_func_args.begin(), inner_func_args.end());
				break;
			}

			case OpStore:
			{
				// Both the store target (ops[0]) and the stored value (ops[1]) may be globals.
				uint32_t base_id = ops[0];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);

				uint32_t rvalue_id = ops[1];
				if (global_var_ids.find(rvalue_id) != global_var_ids.end())
					added_arg_ids.insert(rvalue_id);

				break;
			}

			case OpSelect:
			{
				// Check both selectable operands (ops[3] and ops[4]).
				uint32_t base_id = ops[3];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				base_id = ops[4];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				break;
			}

			// Emulate texture2D atomic operations
			case OpImageTexelPointer:
			{
				// When using the pointer, we need to know which variable it is actually loaded from.
				uint32_t base_id = ops[2];
				auto *var = maybe_get_backing_variable(base_id);
				if (var && atomic_image_vars.count(var->self))
				{
					if (global_var_ids.find(base_id) != global_var_ids.end())
						added_arg_ids.insert(base_id);
				}
				break;
			}

			default:
				break;
			}

			// TODO: Add all other operations which can affect memory.
			// We should consider a more unified system here to reduce boiler-plate.
			// This kind of analysis is done in several places ...
		}
	}

	// Cache the result so that later visits of this function can return it directly.
	function_global_vars[func_id] = added_arg_ids;

	// Add the global variables as arguments to the function
	if (func_id != ir.default_entry_point)
	{
		// Track whether the gl_in / gl_out parameter was already added to this function.
		bool added_in = false;
		bool added_out = false;
		for (uint32_t arg_id : added_arg_ids)
		{
			auto &var = get<SPIRVariable>(arg_id);
			uint32_t type_id = var.basetype;
			auto *p_type = &get<SPIRType>(type_id);
			BuiltIn bi_type = BuiltIn(get_decoration(arg_id, DecorationBuiltIn));

			if (((is_tessellation_shader() && var.storage == StorageClassInput) ||
			     (get_execution_model() == ExecutionModelTessellationControl && var.storage == StorageClassOutput)) &&
			    !(has_decoration(arg_id, DecorationPatch) || is_patch_block(*p_type)) &&
			    (!is_builtin_variable(var) || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
			     bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
			     p_type->basetype == SPIRType::Struct))
			{
				// Tessellation control shaders see inputs and per-vertex outputs as arrays.
				// Similarly, tessellation evaluation shaders see per-vertex inputs as arrays.
				// We collected them into a structure; we must pass the array of this
				// structure to the function.
				// Only one "gl_in" and one "gl_out" parameter is added per function; further
				// variables of the same storage class are skipped.
				std::string name;
				if (var.storage == StorageClassInput)
				{
					if (added_in)
						continue;
					name = "gl_in";
					arg_id = stage_in_ptr_var_id;
					added_in = true;
				}
				else if (var.storage == StorageClassOutput)
				{
					if (added_out)
						continue;
					name = "gl_out";
					arg_id = stage_out_ptr_var_id;
					added_out = true;
				}
				type_id = get<SPIRVariable>(arg_id).basetype;
				uint32_t next_id = ir.increase_bound_by(1);
				func.add_parameter(type_id, next_id, true);
				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);

				set_name(next_id, name);
			}
			else if (is_builtin_variable(var) && p_type->basetype == SPIRType::Struct)
			{
				// Builtin struct: add one parameter per member that is an active builtin.
				// Get the pointee type
				type_id = get_pointee_type_id(type_id);
				p_type = &get<SPIRType>(type_id);

				uint32_t mbr_idx = 0;
				for (auto &mbr_type_id : p_type->member_types)
				{
					BuiltIn builtin = BuiltInMax;
					bool is_builtin = is_member_builtin(*p_type, mbr_idx, &builtin);
					if (is_builtin && has_active_builtin(builtin, var.storage))
					{
						// Add a arg variable with the same type and decorations as the member
						uint32_t next_ids = ir.increase_bound_by(2);
						uint32_t ptr_type_id = next_ids + 0;
						uint32_t var_id = next_ids + 1;

						// Make sure we have an actual pointer type,
						// so that we will get the appropriate address space when declaring these builtins.
						auto &ptr = set<SPIRType>(ptr_type_id, get<SPIRType>(mbr_type_id));
						ptr.self = mbr_type_id;
						ptr.storage = var.storage;
						ptr.pointer = true;
						ptr.parent_type = mbr_type_id;

						func.add_parameter(mbr_type_id, var_id, true);
						set<SPIRVariable>(var_id, ptr_type_id, StorageClassFunction);
						ir.meta[var_id].decoration = ir.meta[type_id].members[mbr_idx];
					}
					mbr_idx++;
				}
			}
			else
			{
				// Any other global: add one parameter of the variable's own type.
				// NOTE(review): arg_id is passed as the trailing SPIRVariable constructor argument,
				// presumably to link the parameter back to the global — confirm against SPIRVariable.
				uint32_t next_id = ir.increase_bound_by(1);
				func.add_parameter(type_id, next_id, true);
				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);

				// Ensure the existing variable has a valid name and the new variable has all the same meta info
				set_name(arg_id, ensure_valid_name(to_name(arg_id), "v"));
				ir.meta[next_id] = ir.meta[arg_id];
			}
		}
	}
}
1559
1560 // For all variables that are some form of non-input-output interface block, mark that all the structs
1561 // that are recursively contained within the type referenced by that variable should be packed tightly.
mark_packable_structs()1562 void CompilerMSL::mark_packable_structs()
1563 {
1564 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1565 if (var.storage != StorageClassFunction && !is_hidden_variable(var))
1566 {
1567 auto &type = this->get<SPIRType>(var.basetype);
1568 if (type.pointer &&
1569 (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant ||
1570 type.storage == StorageClassPushConstant || type.storage == StorageClassStorageBuffer) &&
1571 (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
1572 mark_as_packable(type);
1573 }
1574 });
1575 }
1576
1577 // If the specified type is a struct, it and any nested structs
1578 // are marked as packable with the SPIRVCrossDecorationBufferBlockRepacked decoration,
mark_as_packable(SPIRType & type)1579 void CompilerMSL::mark_as_packable(SPIRType &type)
1580 {
1581 // If this is not the base type (eg. it's a pointer or array), tunnel down
1582 if (type.parent_type)
1583 {
1584 mark_as_packable(get<SPIRType>(type.parent_type));
1585 return;
1586 }
1587
1588 if (type.basetype == SPIRType::Struct)
1589 {
1590 set_extended_decoration(type.self, SPIRVCrossDecorationBufferBlockRepacked);
1591
1592 // Recurse
1593 uint32_t mbr_cnt = uint32_t(type.member_types.size());
1594 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
1595 {
1596 uint32_t mbr_type_id = type.member_types[mbr_idx];
1597 auto &mbr_type = get<SPIRType>(mbr_type_id);
1598 mark_as_packable(mbr_type);
1599 if (mbr_type.type_alias)
1600 {
1601 auto &mbr_type_alias = get<SPIRType>(mbr_type.type_alias);
1602 mark_as_packable(mbr_type_alias);
1603 }
1604 }
1605 }
1606 }
1607
1608 // If a shader input exists at the location, it is marked as being used by this shader
mark_location_as_used_by_shader(uint32_t location,const SPIRType & type,StorageClass storage)1609 void CompilerMSL::mark_location_as_used_by_shader(uint32_t location, const SPIRType &type, StorageClass storage)
1610 {
1611 if (storage != StorageClassInput)
1612 return;
1613 if (is_array(type))
1614 {
1615 uint32_t dim = 1;
1616 for (uint32_t i = 0; i < type.array.size(); i++)
1617 dim *= to_array_size_literal(type, i);
1618 for (uint32_t i = 0; i < dim; i++)
1619 {
1620 if (is_matrix(type))
1621 {
1622 for (uint32_t j = 0; j < type.columns; j++)
1623 inputs_in_use.insert(location++);
1624 }
1625 else
1626 inputs_in_use.insert(location++);
1627 }
1628 }
1629 else if (is_matrix(type))
1630 {
1631 for (uint32_t i = 0; i < type.columns; i++)
1632 inputs_in_use.insert(location + i);
1633 }
1634 else
1635 inputs_in_use.insert(location);
1636 }
1637
get_target_components_for_fragment_location(uint32_t location) const1638 uint32_t CompilerMSL::get_target_components_for_fragment_location(uint32_t location) const
1639 {
1640 auto itr = fragment_output_components.find(location);
1641 if (itr == end(fragment_output_components))
1642 return 4;
1643 else
1644 return itr->second;
1645 }
1646
build_extended_vector_type(uint32_t type_id,uint32_t components,SPIRType::BaseType basetype)1647 uint32_t CompilerMSL::build_extended_vector_type(uint32_t type_id, uint32_t components, SPIRType::BaseType basetype)
1648 {
1649 uint32_t new_type_id = ir.increase_bound_by(1);
1650 auto &old_type = get<SPIRType>(type_id);
1651 auto *type = &set<SPIRType>(new_type_id, old_type);
1652 type->vecsize = components;
1653 if (basetype != SPIRType::Unknown)
1654 type->basetype = basetype;
1655 type->self = new_type_id;
1656 type->parent_type = type_id;
1657 type->array.clear();
1658 type->array_size_literal.clear();
1659 type->pointer = false;
1660
1661 if (is_array(old_type))
1662 {
1663 uint32_t array_type_id = ir.increase_bound_by(1);
1664 type = &set<SPIRType>(array_type_id, *type);
1665 type->parent_type = new_type_id;
1666 type->array = old_type.array;
1667 type->array_size_literal = old_type.array_size_literal;
1668 new_type_id = array_type_id;
1669 }
1670
1671 if (old_type.pointer)
1672 {
1673 uint32_t ptr_type_id = ir.increase_bound_by(1);
1674 type = &set<SPIRType>(ptr_type_id, *type);
1675 type->self = new_type_id;
1676 type->parent_type = new_type_id;
1677 type->storage = old_type.storage;
1678 type->pointer = true;
1679 new_type_id = ptr_type_id;
1680 }
1681
1682 return new_type_id;
1683 }
1684
add_plain_variable_to_interface_block(StorageClass storage,const string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var,InterfaceBlockMeta & meta)1685 void CompilerMSL::add_plain_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
1686 SPIRType &ib_type, SPIRVariable &var, InterfaceBlockMeta &meta)
1687 {
1688 bool is_builtin = is_builtin_variable(var);
1689 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
1690 bool is_flat = has_decoration(var.self, DecorationFlat);
1691 bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
1692 bool is_centroid = has_decoration(var.self, DecorationCentroid);
1693 bool is_sample = has_decoration(var.self, DecorationSample);
1694
1695 // Add a reference to the variable type to the interface struct.
1696 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1697 uint32_t type_id = ensure_correct_builtin_type(var.basetype, builtin);
1698 var.basetype = type_id;
1699
1700 type_id = get_pointee_type_id(var.basetype);
1701 if (meta.strip_array && is_array(get<SPIRType>(type_id)))
1702 type_id = get<SPIRType>(type_id).parent_type;
1703 auto &type = get<SPIRType>(type_id);
1704 uint32_t target_components = 0;
1705 uint32_t type_components = type.vecsize;
1706
1707 bool padded_output = false;
1708 bool padded_input = false;
1709 uint32_t start_component = 0;
1710
1711 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1712
1713 // Deal with Component decorations.
1714 InterfaceBlockMeta::LocationMeta *location_meta = nullptr;
1715 if (has_decoration(var.self, DecorationLocation))
1716 {
1717 auto location_meta_itr = meta.location_meta.find(get_decoration(var.self, DecorationLocation));
1718 if (location_meta_itr != end(meta.location_meta))
1719 location_meta = &location_meta_itr->second;
1720 }
1721
1722 bool pad_fragment_output = has_decoration(var.self, DecorationLocation) &&
1723 msl_options.pad_fragment_output_components &&
1724 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;
1725
1726 // Check if we need to pad fragment output to match a certain number of components.
1727 if (location_meta)
1728 {
1729 start_component = get_decoration(var.self, DecorationComponent);
1730 uint32_t num_components = location_meta->num_components;
1731 if (pad_fragment_output)
1732 {
1733 uint32_t locn = get_decoration(var.self, DecorationLocation);
1734 num_components = std::max(num_components, get_target_components_for_fragment_location(locn));
1735 }
1736
1737 if (location_meta->ib_index != ~0u)
1738 {
1739 // We have already declared the variable. Just emit an early-declared variable and fixup as needed.
1740 entry_func.add_local_variable(var.self);
1741 vars_needing_early_declaration.push_back(var.self);
1742
1743 if (var.storage == StorageClassInput)
1744 {
1745 uint32_t ib_index = location_meta->ib_index;
1746 entry_func.fixup_hooks_in.push_back([=, &var]() {
1747 statement(to_name(var.self), " = ", ib_var_ref, ".", to_member_name(ib_type, ib_index),
1748 vector_swizzle(type_components, start_component), ";");
1749 });
1750 }
1751 else
1752 {
1753 uint32_t ib_index = location_meta->ib_index;
1754 entry_func.fixup_hooks_out.push_back([=, &var]() {
1755 statement(ib_var_ref, ".", to_member_name(ib_type, ib_index),
1756 vector_swizzle(type_components, start_component), " = ", to_name(var.self), ";");
1757 });
1758 }
1759 return;
1760 }
1761 else
1762 {
1763 location_meta->ib_index = uint32_t(ib_type.member_types.size());
1764 type_id = build_extended_vector_type(type_id, num_components);
1765 if (var.storage == StorageClassInput)
1766 padded_input = true;
1767 else
1768 padded_output = true;
1769 }
1770 }
1771 else if (pad_fragment_output)
1772 {
1773 uint32_t locn = get_decoration(var.self, DecorationLocation);
1774 target_components = get_target_components_for_fragment_location(locn);
1775 if (type_components < target_components)
1776 {
1777 // Make a new type here.
1778 type_id = build_extended_vector_type(type_id, target_components);
1779 padded_output = true;
1780 }
1781 }
1782
1783 ib_type.member_types.push_back(type_id);
1784
1785 // Give the member a name
1786 string mbr_name = ensure_valid_name(to_expression(var.self), "m");
1787 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1788
1789 // Update the original variable reference to include the structure reference
1790 string qual_var_name = ib_var_ref + "." + mbr_name;
1791
1792 if (padded_output || padded_input)
1793 {
1794 entry_func.add_local_variable(var.self);
1795 vars_needing_early_declaration.push_back(var.self);
1796
1797 if (padded_output)
1798 {
1799 entry_func.fixup_hooks_out.push_back([=, &var]() {
1800 statement(qual_var_name, vector_swizzle(type_components, start_component), " = ", to_name(var.self),
1801 ";");
1802 });
1803 }
1804 else
1805 {
1806 entry_func.fixup_hooks_in.push_back([=, &var]() {
1807 statement(to_name(var.self), " = ", qual_var_name, vector_swizzle(type_components, start_component),
1808 ";");
1809 });
1810 }
1811 }
1812 else if (!meta.strip_array)
1813 ir.meta[var.self].decoration.qualified_alias = qual_var_name;
1814
1815 if (var.storage == StorageClassOutput && var.initializer != ID(0))
1816 {
1817 if (padded_output || padded_input)
1818 {
1819 entry_func.fixup_hooks_in.push_back(
1820 [=, &var]() { statement(to_name(var.self), " = ", to_expression(var.initializer), ";"); });
1821 }
1822 else
1823 {
1824 entry_func.fixup_hooks_in.push_back(
1825 [=, &var]() { statement(qual_var_name, " = ", to_expression(var.initializer), ";"); });
1826 }
1827 }
1828
1829 // Copy the variable location from the original variable to the member
1830 if (get_decoration_bitset(var.self).get(DecorationLocation))
1831 {
1832 uint32_t locn = get_decoration(var.self, DecorationLocation);
1833 if (storage == StorageClassInput)
1834 {
1835 type_id = ensure_correct_input_type(var.basetype, locn, location_meta ? location_meta->num_components : 0);
1836 if (!location_meta)
1837 var.basetype = type_id;
1838
1839 type_id = get_pointee_type_id(type_id);
1840 if (meta.strip_array && is_array(get<SPIRType>(type_id)))
1841 type_id = get<SPIRType>(type_id).parent_type;
1842 ib_type.member_types[ib_mbr_idx] = type_id;
1843 }
1844 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1845 mark_location_as_used_by_shader(locn, get<SPIRType>(type_id), storage);
1846 }
1847 else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
1848 {
1849 uint32_t locn = inputs_by_builtin[builtin].location;
1850 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1851 mark_location_as_used_by_shader(locn, type, storage);
1852 }
1853
1854 if (!location_meta)
1855 {
1856 if (get_decoration_bitset(var.self).get(DecorationComponent))
1857 {
1858 uint32_t component = get_decoration(var.self, DecorationComponent);
1859 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, component);
1860 }
1861 }
1862
1863 if (get_decoration_bitset(var.self).get(DecorationIndex))
1864 {
1865 uint32_t index = get_decoration(var.self, DecorationIndex);
1866 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
1867 }
1868
1869 // Mark the member as builtin if needed
1870 if (is_builtin)
1871 {
1872 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
1873 if (builtin == BuiltInPosition && storage == StorageClassOutput)
1874 qual_pos_var_name = qual_var_name;
1875 }
1876
1877 // Copy interpolation decorations if needed
1878 if (is_flat)
1879 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
1880 if (is_noperspective)
1881 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
1882 if (is_centroid)
1883 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
1884 if (is_sample)
1885 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
1886
1887 // If we have location meta, there is no unique OrigID. We won't need it, since we flatten/unflatten
1888 // the variable to stack anyways here.
1889 if (!location_meta)
1890 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
1891 }
1892
add_composite_variable_to_interface_block(StorageClass storage,const string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var,InterfaceBlockMeta & meta)1893 void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
1894 SPIRType &ib_type, SPIRVariable &var,
1895 InterfaceBlockMeta &meta)
1896 {
1897 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1898 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
1899 uint32_t elem_cnt = 0;
1900
1901 if (is_matrix(var_type))
1902 {
1903 if (is_array(var_type))
1904 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
1905
1906 elem_cnt = var_type.columns;
1907 }
1908 else if (is_array(var_type))
1909 {
1910 if (var_type.array.size() != 1)
1911 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
1912
1913 elem_cnt = to_array_size_literal(var_type);
1914 }
1915
1916 bool is_builtin = is_builtin_variable(var);
1917 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
1918 bool is_flat = has_decoration(var.self, DecorationFlat);
1919 bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
1920 bool is_centroid = has_decoration(var.self, DecorationCentroid);
1921 bool is_sample = has_decoration(var.self, DecorationSample);
1922
1923 auto *usable_type = &var_type;
1924 if (usable_type->pointer)
1925 usable_type = &get<SPIRType>(usable_type->parent_type);
1926 while (is_array(*usable_type) || is_matrix(*usable_type))
1927 usable_type = &get<SPIRType>(usable_type->parent_type);
1928
1929 // If a builtin, force it to have the proper name.
1930 if (is_builtin)
1931 set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));
1932
1933 bool flatten_from_ib_var = false;
1934 string flatten_from_ib_mbr_name;
1935
1936 if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
1937 {
1938 // Also declare [[clip_distance]] attribute here.
1939 uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
1940 ib_type.member_types.push_back(get_variable_data_type_id(var));
1941 set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
1942
1943 flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
1944 set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);
1945
1946 // When we flatten, we flatten directly from the "out" struct,
1947 // not from a function variable.
1948 flatten_from_ib_var = true;
1949
1950 if (!msl_options.enable_clip_distance_user_varying)
1951 return;
1952 }
1953 else if (!meta.strip_array)
1954 {
1955 // Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
1956 entry_func.add_local_variable(var.self);
1957 // We need to declare the variable early and at entry-point scope.
1958 vars_needing_early_declaration.push_back(var.self);
1959 }
1960
1961 for (uint32_t i = 0; i < elem_cnt; i++)
1962 {
1963 // Add a reference to the variable type to the interface struct.
1964 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1965
1966 uint32_t target_components = 0;
1967 bool padded_output = false;
1968 uint32_t type_id = usable_type->self;
1969
1970 // Check if we need to pad fragment output to match a certain number of components.
1971 if (get_decoration_bitset(var.self).get(DecorationLocation) && msl_options.pad_fragment_output_components &&
1972 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput)
1973 {
1974 uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
1975 target_components = get_target_components_for_fragment_location(locn);
1976 if (usable_type->vecsize < target_components)
1977 {
1978 // Make a new type here.
1979 type_id = build_extended_vector_type(usable_type->self, target_components);
1980 padded_output = true;
1981 }
1982 }
1983
1984 ib_type.member_types.push_back(get_pointee_type_id(type_id));
1985
1986 // Give the member a name
1987 string mbr_name = ensure_valid_name(join(to_expression(var.self), "_", i), "m");
1988 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1989
1990 // There is no qualified alias since we need to flatten the internal array on return.
1991 if (get_decoration_bitset(var.self).get(DecorationLocation))
1992 {
1993 uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
1994 if (storage == StorageClassInput)
1995 {
1996 var.basetype = ensure_correct_input_type(var.basetype, locn);
1997 uint32_t mbr_type_id = ensure_correct_input_type(usable_type->self, locn);
1998 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
1999 }
2000 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2001 mark_location_as_used_by_shader(locn, *usable_type, storage);
2002 }
2003 else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2004 {
2005 uint32_t locn = inputs_by_builtin[builtin].location + i;
2006 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2007 mark_location_as_used_by_shader(locn, *usable_type, storage);
2008 }
2009 else if (is_builtin && builtin == BuiltInClipDistance)
2010 {
2011 // Declare the ClipDistance as [[user(clipN)]].
2012 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
2013 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, i);
2014 }
2015
2016 if (get_decoration_bitset(var.self).get(DecorationIndex))
2017 {
2018 uint32_t index = get_decoration(var.self, DecorationIndex);
2019 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
2020 }
2021
2022 // Copy interpolation decorations if needed
2023 if (is_flat)
2024 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2025 if (is_noperspective)
2026 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2027 if (is_centroid)
2028 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2029 if (is_sample)
2030 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2031
2032 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2033
2034 // Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
2035 if (!meta.strip_array)
2036 {
2037 switch (storage)
2038 {
2039 case StorageClassInput:
2040 entry_func.fixup_hooks_in.push_back(
2041 [=, &var]() { statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, ";"); });
2042 break;
2043
2044 case StorageClassOutput:
2045 entry_func.fixup_hooks_out.push_back([=, &var]() {
2046 if (padded_output)
2047 {
2048 auto &padded_type = this->get<SPIRType>(type_id);
2049 statement(
2050 ib_var_ref, ".", mbr_name, " = ",
2051 remap_swizzle(padded_type, usable_type->vecsize, join(to_name(var.self), "[", i, "]")),
2052 ";");
2053 }
2054 else if (flatten_from_ib_var)
2055 statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
2056 "];");
2057 else
2058 statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), "[", i, "];");
2059 });
2060 break;
2061
2062 default:
2063 break;
2064 }
2065 }
2066 }
2067 }
2068
get_accumulated_member_location(const SPIRVariable & var,uint32_t mbr_idx,bool strip_array)2069 uint32_t CompilerMSL::get_accumulated_member_location(const SPIRVariable &var, uint32_t mbr_idx, bool strip_array)
2070 {
2071 auto &type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2072 uint32_t location = get_decoration(var.self, DecorationLocation);
2073
2074 for (uint32_t i = 0; i < mbr_idx; i++)
2075 {
2076 auto &mbr_type = get<SPIRType>(type.member_types[i]);
2077
2078 // Start counting from any place we have a new location decoration.
2079 if (has_member_decoration(type.self, mbr_idx, DecorationLocation))
2080 location = get_member_decoration(type.self, mbr_idx, DecorationLocation);
2081
2082 uint32_t location_count = 1;
2083
2084 if (mbr_type.columns > 1)
2085 location_count = mbr_type.columns;
2086
2087 if (!mbr_type.array.empty())
2088 for (uint32_t j = 0; j < uint32_t(mbr_type.array.size()); j++)
2089 location_count *= to_array_size_literal(mbr_type, j);
2090
2091 location += location_count;
2092 }
2093
2094 return location;
2095 }
2096
add_composite_member_variable_to_interface_block(StorageClass storage,const string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var,uint32_t mbr_idx,InterfaceBlockMeta & meta)2097 void CompilerMSL::add_composite_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2098 SPIRType &ib_type, SPIRVariable &var,
2099 uint32_t mbr_idx, InterfaceBlockMeta &meta)
2100 {
2101 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2102 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2103
2104 BuiltIn builtin = BuiltInMax;
2105 bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
2106 bool is_flat =
2107 has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
2108 bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
2109 has_decoration(var.self, DecorationNoPerspective);
2110 bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
2111 has_decoration(var.self, DecorationCentroid);
2112 bool is_sample =
2113 has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
2114
2115 uint32_t mbr_type_id = var_type.member_types[mbr_idx];
2116 auto &mbr_type = get<SPIRType>(mbr_type_id);
2117 uint32_t elem_cnt = 0;
2118
2119 if (is_matrix(mbr_type))
2120 {
2121 if (is_array(mbr_type))
2122 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
2123
2124 elem_cnt = mbr_type.columns;
2125 }
2126 else if (is_array(mbr_type))
2127 {
2128 if (mbr_type.array.size() != 1)
2129 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
2130
2131 elem_cnt = to_array_size_literal(mbr_type);
2132 }
2133
2134 auto *usable_type = &mbr_type;
2135 if (usable_type->pointer)
2136 usable_type = &get<SPIRType>(usable_type->parent_type);
2137 while (is_array(*usable_type) || is_matrix(*usable_type))
2138 usable_type = &get<SPIRType>(usable_type->parent_type);
2139
2140 bool flatten_from_ib_var = false;
2141 string flatten_from_ib_mbr_name;
2142
2143 if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
2144 {
2145 // Also declare [[clip_distance]] attribute here.
2146 uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
2147 ib_type.member_types.push_back(mbr_type_id);
2148 set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
2149
2150 flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
2151 set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);
2152
2153 // When we flatten, we flatten directly from the "out" struct,
2154 // not from a function variable.
2155 flatten_from_ib_var = true;
2156
2157 if (!msl_options.enable_clip_distance_user_varying)
2158 return;
2159 }
2160
2161 for (uint32_t i = 0; i < elem_cnt; i++)
2162 {
2163 // Add a reference to the variable type to the interface struct.
2164 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2165 ib_type.member_types.push_back(usable_type->self);
2166
2167 // Give the member a name
2168 string mbr_name = ensure_valid_name(join(to_qualified_member_name(var_type, mbr_idx), "_", i), "m");
2169 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2170
2171 if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
2172 {
2173 uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation) + i;
2174 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2175 mark_location_as_used_by_shader(locn, *usable_type, storage);
2176 }
2177 else if (has_decoration(var.self, DecorationLocation))
2178 {
2179 uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array) + i;
2180 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2181 mark_location_as_used_by_shader(locn, *usable_type, storage);
2182 }
2183 else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2184 {
2185 uint32_t locn = inputs_by_builtin[builtin].location + i;
2186 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2187 mark_location_as_used_by_shader(locn, *usable_type, storage);
2188 }
2189 else if (is_builtin && builtin == BuiltInClipDistance)
2190 {
2191 // Declare the ClipDistance as [[user(clipN)]].
2192 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
2193 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, i);
2194 }
2195
2196 if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
2197 SPIRV_CROSS_THROW("DecorationComponent on matrices and arrays make little sense.");
2198
2199 // Copy interpolation decorations if needed
2200 if (is_flat)
2201 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2202 if (is_noperspective)
2203 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2204 if (is_centroid)
2205 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2206 if (is_sample)
2207 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2208
2209 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2210 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
2211
2212 // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
2213 if (!meta.strip_array)
2214 {
2215 switch (storage)
2216 {
2217 case StorageClassInput:
2218 entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
2219 statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
2220 ".", mbr_name, ";");
2221 });
2222 break;
2223
2224 case StorageClassOutput:
2225 entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
2226 if (flatten_from_ib_var)
2227 {
2228 statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
2229 "];");
2230 }
2231 else
2232 {
2233 statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), ".",
2234 to_member_name(var_type, mbr_idx), "[", i, "];");
2235 }
2236 });
2237 break;
2238
2239 default:
2240 break;
2241 }
2242 }
2243 }
2244 }
2245
add_plain_member_variable_to_interface_block(StorageClass storage,const string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var,uint32_t mbr_idx,InterfaceBlockMeta & meta)2246 void CompilerMSL::add_plain_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2247 SPIRType &ib_type, SPIRVariable &var, uint32_t mbr_idx,
2248 InterfaceBlockMeta &meta)
2249 {
2250 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2251 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2252
2253 BuiltIn builtin = BuiltInMax;
2254 bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
2255 bool is_flat =
2256 has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
2257 bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
2258 has_decoration(var.self, DecorationNoPerspective);
2259 bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
2260 has_decoration(var.self, DecorationCentroid);
2261 bool is_sample =
2262 has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
2263
2264 // Add a reference to the member to the interface struct.
2265 uint32_t mbr_type_id = var_type.member_types[mbr_idx];
2266 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2267 mbr_type_id = ensure_correct_builtin_type(mbr_type_id, builtin);
2268 var_type.member_types[mbr_idx] = mbr_type_id;
2269 ib_type.member_types.push_back(mbr_type_id);
2270
2271 // Give the member a name
2272 string mbr_name = ensure_valid_name(to_qualified_member_name(var_type, mbr_idx), "m");
2273 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2274
2275 // Update the original variable reference to include the structure reference
2276 string qual_var_name = ib_var_ref + "." + mbr_name;
2277
2278 if (is_builtin && !meta.strip_array)
2279 {
2280 // For the builtin gl_PerVertex, we cannot treat it as a block anyways,
2281 // so redirect to qualified name.
2282 set_member_qualified_name(var_type.self, mbr_idx, qual_var_name);
2283 }
2284 else if (!meta.strip_array)
2285 {
2286 // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
2287 switch (storage)
2288 {
2289 case StorageClassInput:
2290 entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
2291 statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), " = ", qual_var_name, ";");
2292 });
2293 break;
2294
2295 case StorageClassOutput:
2296 entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
2297 statement(qual_var_name, " = ", to_name(var.self), ".", to_member_name(var_type, mbr_idx), ";");
2298 });
2299 break;
2300
2301 default:
2302 break;
2303 }
2304 }
2305
2306 // Copy the variable location from the original variable to the member
2307 if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
2308 {
2309 uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation);
2310 if (storage == StorageClassInput)
2311 {
2312 mbr_type_id = ensure_correct_input_type(mbr_type_id, locn);
2313 var_type.member_types[mbr_idx] = mbr_type_id;
2314 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2315 }
2316 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2317 mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2318 }
2319 else if (has_decoration(var.self, DecorationLocation))
2320 {
2321 // The block itself might have a location and in this case, all members of the block
2322 // receive incrementing locations.
2323 uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array);
2324 if (storage == StorageClassInput)
2325 {
2326 mbr_type_id = ensure_correct_input_type(mbr_type_id, locn);
2327 var_type.member_types[mbr_idx] = mbr_type_id;
2328 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2329 }
2330 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2331 mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2332 }
2333 else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2334 {
2335 uint32_t locn = 0;
2336 auto builtin_itr = inputs_by_builtin.find(builtin);
2337 if (builtin_itr != end(inputs_by_builtin))
2338 locn = builtin_itr->second.location;
2339 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2340 mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2341 }
2342
2343 // Copy the component location, if present.
2344 if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
2345 {
2346 uint32_t comp = get_member_decoration(var_type.self, mbr_idx, DecorationComponent);
2347 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
2348 }
2349
2350 // Mark the member as builtin if needed
2351 if (is_builtin)
2352 {
2353 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2354 if (builtin == BuiltInPosition && storage == StorageClassOutput)
2355 qual_pos_var_name = qual_var_name;
2356 }
2357
2358 // Copy interpolation decorations if needed
2359 if (is_flat)
2360 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2361 if (is_noperspective)
2362 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2363 if (is_centroid)
2364 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2365 if (is_sample)
2366 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2367
2368 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2369 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
2370 }
2371
2372 // In Metal, the tessellation levels are stored as tightly packed half-precision floating point values.
2373 // But, stage-in attribute offsets and strides must be multiples of four, so we can't pass the levels
2374 // individually. Therefore, we must pass them as vectors. Triangles get a single float4, with the outer
2375 // levels in 'xyz' and the inner level in 'w'. Quads get a float4 containing the outer levels and a
2376 // float2 containing the inner levels.
add_tess_level_input_to_interface_block(const std::string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var)2377 void CompilerMSL::add_tess_level_input_to_interface_block(const std::string &ib_var_ref, SPIRType &ib_type,
2378 SPIRVariable &var)
2379 {
2380 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2381 auto &var_type = get_variable_element_type(var);
2382
2383 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
2384
2385 // Force the variable to have the proper name.
2386 set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));
2387
2388 if (get_entry_point().flags.get(ExecutionModeTriangles))
2389 {
2390 // Triangles are tricky, because we want only one member in the struct.
2391
2392 // We need to declare the variable early and at entry-point scope.
2393 entry_func.add_local_variable(var.self);
2394 vars_needing_early_declaration.push_back(var.self);
2395
2396 string mbr_name = "gl_TessLevel";
2397
2398 // If we already added the other one, we can skip this step.
2399 if (!added_builtin_tess_level)
2400 {
2401 // Add a reference to the variable type to the interface struct.
2402 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2403
2404 uint32_t type_id = build_extended_vector_type(var_type.self, 4);
2405
2406 ib_type.member_types.push_back(type_id);
2407
2408 // Give the member a name
2409 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2410
2411 // There is no qualified alias since we need to flatten the internal array on return.
2412 if (get_decoration_bitset(var.self).get(DecorationLocation))
2413 {
2414 uint32_t locn = get_decoration(var.self, DecorationLocation);
2415 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2416 mark_location_as_used_by_shader(locn, var_type, StorageClassInput);
2417 }
2418 else if (inputs_by_builtin.count(builtin))
2419 {
2420 uint32_t locn = inputs_by_builtin[builtin].location;
2421 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2422 mark_location_as_used_by_shader(locn, var_type, StorageClassInput);
2423 }
2424
2425 added_builtin_tess_level = true;
2426 }
2427
2428 switch (builtin)
2429 {
2430 case BuiltInTessLevelOuter:
2431 entry_func.fixup_hooks_in.push_back([=, &var]() {
2432 statement(to_name(var.self), "[0] = ", ib_var_ref, ".", mbr_name, ".x;");
2433 statement(to_name(var.self), "[1] = ", ib_var_ref, ".", mbr_name, ".y;");
2434 statement(to_name(var.self), "[2] = ", ib_var_ref, ".", mbr_name, ".z;");
2435 });
2436 break;
2437
2438 case BuiltInTessLevelInner:
2439 entry_func.fixup_hooks_in.push_back(
2440 [=, &var]() { statement(to_name(var.self), "[0] = ", ib_var_ref, ".", mbr_name, ".w;"); });
2441 break;
2442
2443 default:
2444 assert(false);
2445 break;
2446 }
2447 }
2448 else
2449 {
2450 // Add a reference to the variable type to the interface struct.
2451 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2452
2453 uint32_t type_id = build_extended_vector_type(var_type.self, builtin == BuiltInTessLevelOuter ? 4 : 2);
2454 // Change the type of the variable, too.
2455 uint32_t ptr_type_id = ir.increase_bound_by(1);
2456 auto &new_var_type = set<SPIRType>(ptr_type_id, get<SPIRType>(type_id));
2457 new_var_type.pointer = true;
2458 new_var_type.storage = StorageClassInput;
2459 new_var_type.parent_type = type_id;
2460 var.basetype = ptr_type_id;
2461
2462 ib_type.member_types.push_back(type_id);
2463
2464 // Give the member a name
2465 string mbr_name = to_expression(var.self);
2466 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2467
2468 // Since vectors can be indexed like arrays, there is no need to unpack this. We can
2469 // just refer to the vector directly. So give it a qualified alias.
2470 string qual_var_name = ib_var_ref + "." + mbr_name;
2471 ir.meta[var.self].decoration.qualified_alias = qual_var_name;
2472
2473 if (get_decoration_bitset(var.self).get(DecorationLocation))
2474 {
2475 uint32_t locn = get_decoration(var.self, DecorationLocation);
2476 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2477 mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
2478 }
2479 else if (inputs_by_builtin.count(builtin))
2480 {
2481 uint32_t locn = inputs_by_builtin[builtin].location;
2482 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2483 mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
2484 }
2485 }
2486 }
2487
add_variable_to_interface_block(StorageClass storage,const string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var,InterfaceBlockMeta & meta)2488 void CompilerMSL::add_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, SPIRType &ib_type,
2489 SPIRVariable &var, InterfaceBlockMeta &meta)
2490 {
2491 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2492 // Tessellation control I/O variables and tessellation evaluation per-point inputs are
2493 // usually declared as arrays. In these cases, we want to add the element type to the
2494 // interface block, since in Metal it's the interface block itself which is arrayed.
2495 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2496 bool is_builtin = is_builtin_variable(var);
2497 auto builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
2498
2499 if (var_type.basetype == SPIRType::Struct)
2500 {
2501 if (!is_builtin_type(var_type) && (!capture_output_to_buffer || storage == StorageClassInput) &&
2502 !meta.strip_array)
2503 {
2504 // For I/O blocks or structs, we will need to pass the block itself around
2505 // to functions if they are used globally in leaf functions.
2506 // Rather than passing down member by member,
2507 // we unflatten I/O blocks while running the shader,
2508 // and pass the actual struct type down to leaf functions.
2509 // We then unflatten inputs, and flatten outputs in the "fixup" stages.
2510 entry_func.add_local_variable(var.self);
2511 vars_needing_early_declaration.push_back(var.self);
2512 }
2513
2514 if (capture_output_to_buffer && storage != StorageClassInput && !has_decoration(var_type.self, DecorationBlock))
2515 {
2516 // In Metal tessellation shaders, the interface block itself is arrayed. This makes things
2517 // very complicated, since stage-in structures in MSL don't support nested structures.
2518 // Luckily, for stage-out when capturing output, we can avoid this and just add
2519 // composite members directly, because the stage-out structure is stored to a buffer,
2520 // not returned.
2521 add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
2522 }
2523 else
2524 {
2525 // Flatten the struct members into the interface struct
2526 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
2527 {
2528 builtin = BuiltInMax;
2529 is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
2530 auto &mbr_type = get<SPIRType>(var_type.member_types[mbr_idx]);
2531
2532 if (!is_builtin || has_active_builtin(builtin, storage))
2533 {
2534 bool is_composite_type = is_matrix(mbr_type) || is_array(mbr_type);
2535 bool attribute_load_store =
2536 storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
2537 bool storage_is_stage_io =
2538 (storage == StorageClassInput && !(get_execution_model() == ExecutionModelTessellationControl &&
2539 msl_options.multi_patch_workgroup)) ||
2540 storage == StorageClassOutput;
2541
2542 // ClipDistance always needs to be declared as user attributes.
2543 if (builtin == BuiltInClipDistance)
2544 is_builtin = false;
2545
2546 if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
2547 {
2548 add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx,
2549 meta);
2550 }
2551 else
2552 {
2553 add_plain_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx, meta);
2554 }
2555 }
2556 }
2557 }
2558 }
2559 else if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput &&
2560 !meta.strip_array && is_builtin && (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner))
2561 {
2562 add_tess_level_input_to_interface_block(ib_var_ref, ib_type, var);
2563 }
2564 else if (var_type.basetype == SPIRType::Boolean || var_type.basetype == SPIRType::Char ||
2565 type_is_integral(var_type) || type_is_floating_point(var_type) || var_type.basetype == SPIRType::Boolean)
2566 {
2567 if (!is_builtin || has_active_builtin(builtin, storage))
2568 {
2569 bool is_composite_type = is_matrix(var_type) || is_array(var_type);
2570 bool storage_is_stage_io =
2571 (storage == StorageClassInput &&
2572 !(get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup)) ||
2573 (storage == StorageClassOutput && !capture_output_to_buffer);
2574 bool attribute_load_store = storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
2575
2576 // ClipDistance always needs to be declared as user attributes.
2577 if (builtin == BuiltInClipDistance)
2578 is_builtin = false;
2579
2580 // MSL does not allow matrices or arrays in input or output variables, so need to handle it specially.
2581 if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
2582 {
2583 add_composite_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
2584 }
2585 else
2586 {
2587 add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
2588 }
2589 }
2590 }
2591 }
2592
2593 // Fix up the mapping of variables to interface member indices, which is used to compile access chains
2594 // for per-vertex variables in a tessellation control shader.
fix_up_interface_member_indices(StorageClass storage,uint32_t ib_type_id)2595 void CompilerMSL::fix_up_interface_member_indices(StorageClass storage, uint32_t ib_type_id)
2596 {
2597 // Only needed for tessellation shaders.
2598 // Need to redirect interface indices back to variables themselves.
2599 // For structs, each member of the struct need a separate instance.
2600 if (get_execution_model() != ExecutionModelTessellationControl &&
2601 !(get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput))
2602 return;
2603
2604 auto mbr_cnt = uint32_t(ir.meta[ib_type_id].members.size());
2605 for (uint32_t i = 0; i < mbr_cnt; i++)
2606 {
2607 uint32_t var_id = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceOrigID);
2608 if (!var_id)
2609 continue;
2610 auto &var = get<SPIRVariable>(var_id);
2611
2612 auto &type = get_variable_element_type(var);
2613 if (storage == StorageClassInput && type.basetype == SPIRType::Struct)
2614 {
2615 uint32_t mbr_idx = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceMemberIndex);
2616
2617 // Only set the lowest InterfaceMemberIndex for each variable member.
2618 // IB struct members will be emitted in-order w.r.t. interface member index.
2619 if (!has_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex))
2620 set_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, i);
2621 }
2622 else
2623 {
2624 // Only set the lowest InterfaceMemberIndex for each variable.
2625 // IB struct members will be emitted in-order w.r.t. interface member index.
2626 if (!has_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex))
2627 set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i);
2628 }
2629 }
2630 }
2631
2632 // Add an interface structure for the type of storage, which is either StorageClassInput or StorageClassOutput.
2633 // Returns the ID of the newly added variable, or zero if no variable was added.
add_interface_block(StorageClass storage,bool patch)2634 uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch)
2635 {
2636 // Accumulate the variables that should appear in the interface struct.
2637 SmallVector<SPIRVariable *> vars;
2638 bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader();
2639 bool has_seen_barycentric = false;
2640
2641 InterfaceBlockMeta meta;
2642
2643 // Varying interfaces between stages which use "user()" attribute can be dealt with
2644 // without explicit packing and unpacking of components. For any variables which link against the runtime
2645 // in some way (vertex attributes, fragment output, etc), we'll need to deal with it somehow.
2646 bool pack_components =
2647 (storage == StorageClassInput && get_execution_model() == ExecutionModelVertex) ||
2648 (storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment) ||
2649 (storage == StorageClassOutput && get_execution_model() == ExecutionModelVertex && capture_output_to_buffer);
2650
2651 ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
2652 if (var.storage != storage)
2653 return;
2654
2655 auto &type = this->get<SPIRType>(var.basetype);
2656
2657 bool is_builtin = is_builtin_variable(var);
2658 auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
2659 uint32_t location = get_decoration(var_id, DecorationLocation);
2660
2661 // These builtins are part of the stage in/out structs.
2662 bool is_interface_block_builtin =
2663 (bi_type == BuiltInPosition || bi_type == BuiltInPointSize || bi_type == BuiltInClipDistance ||
2664 bi_type == BuiltInCullDistance || bi_type == BuiltInLayer || bi_type == BuiltInViewportIndex ||
2665 bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV || bi_type == BuiltInFragDepth ||
2666 bi_type == BuiltInFragStencilRefEXT || bi_type == BuiltInSampleMask) ||
2667 (get_execution_model() == ExecutionModelTessellationEvaluation &&
2668 (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner));
2669
2670 bool is_active = interface_variable_exists_in_entry_point(var.self);
2671 if (is_builtin && is_active)
2672 {
2673 // Only emit the builtin if it's active in this entry point. Interface variable list might lie.
2674 is_active = has_active_builtin(bi_type, storage);
2675 }
2676
2677 bool filter_patch_decoration = (has_decoration(var_id, DecorationPatch) || is_patch_block(type)) == patch;
2678
2679 bool hidden = is_hidden_variable(var, incl_builtins);
2680
2681 // ClipDistance is never hidden, we need to emulate it when used as an input.
2682 if (bi_type == BuiltInClipDistance)
2683 hidden = false;
2684
2685 // It's not enough to simply avoid marking fragment outputs if the pipeline won't
2686 // accept them. We can't put them in the struct at all, or otherwise the compiler
2687 // complains that the outputs weren't explicitly marked.
2688 if (get_execution_model() == ExecutionModelFragment && storage == StorageClassOutput && !patch &&
2689 ((is_builtin && ((bi_type == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
2690 (bi_type == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))) ||
2691 (!is_builtin && !(msl_options.enable_frag_output_mask & (1 << location)))))
2692 {
2693 hidden = true;
2694 disabled_frag_outputs.push_back(var_id);
2695 // If a builtin, force it to have the proper name.
2696 if (is_builtin)
2697 set_name(var_id, builtin_to_glsl(bi_type, StorageClassFunction));
2698 }
2699
2700 // Barycentric inputs must be emitted in stage-in, because they can have interpolation arguments.
2701 if (is_active && (bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV))
2702 {
2703 if (has_seen_barycentric)
2704 SPIRV_CROSS_THROW("Cannot declare both BaryCoordNV and BaryCoordNoPerspNV in same shader in MSL.");
2705 has_seen_barycentric = true;
2706 hidden = false;
2707 }
2708
2709 if (is_active && !hidden && type.pointer && filter_patch_decoration &&
2710 (!is_builtin || is_interface_block_builtin))
2711 {
2712 vars.push_back(&var);
2713
2714 if (!is_builtin)
2715 {
2716 // Need to deal specially with DecorationComponent.
2717 // Multiple variables can alias the same Location, and try to make sure each location is declared only once.
2718 // We will swizzle data in and out to make this work.
2719 // We only need to consider plain variables here, not composites.
2720 // This is only relevant for vertex inputs and fragment outputs.
2721 // Technically tessellation as well, but it is too complicated to support.
2722 uint32_t component = get_decoration(var_id, DecorationComponent);
2723 if (component != 0)
2724 {
2725 if (is_tessellation_shader())
2726 SPIRV_CROSS_THROW("Component decoration is not supported in tessellation shaders.");
2727 else if (pack_components)
2728 {
2729 auto &location_meta = meta.location_meta[location];
2730 location_meta.num_components = std::max(location_meta.num_components, component + type.vecsize);
2731 }
2732 }
2733 }
2734 }
2735 });
2736
2737 // If no variables qualify, leave.
2738 // For patch input in a tessellation evaluation shader, the per-vertex stage inputs
2739 // are included in a special patch control point array.
2740 if (vars.empty() && !(storage == StorageClassInput && patch && stage_in_var_id))
2741 return 0;
2742
2743 // Add a new typed variable for this interface structure.
2744 // The initializer expression is allocated here, but populated when the function
2745 // declaraion is emitted, because it is cleared after each compilation pass.
2746 uint32_t next_id = ir.increase_bound_by(3);
2747 uint32_t ib_type_id = next_id++;
2748 auto &ib_type = set<SPIRType>(ib_type_id);
2749 ib_type.basetype = SPIRType::Struct;
2750 ib_type.storage = storage;
2751 set_decoration(ib_type_id, DecorationBlock);
2752
2753 uint32_t ib_var_id = next_id++;
2754 auto &var = set<SPIRVariable>(ib_var_id, ib_type_id, storage, 0);
2755 var.initializer = next_id++;
2756
2757 string ib_var_ref;
2758 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2759 switch (storage)
2760 {
2761 case StorageClassInput:
2762 ib_var_ref = patch ? patch_stage_in_var_name : stage_in_var_name;
2763 if (get_execution_model() == ExecutionModelTessellationControl)
2764 {
2765 // Add a hook to populate the shared workgroup memory containing the gl_in array.
2766 entry_func.fixup_hooks_in.push_back([=]() {
2767 // Can't use PatchVertices, PrimitiveId, or InvocationId yet; the hooks for those may not have run yet.
2768 if (msl_options.multi_patch_workgroup)
2769 {
2770 // n.b. builtin_invocation_id_id here is the dispatch global invocation ID,
2771 // not the TC invocation ID.
2772 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_in = &",
2773 input_buffer_var_name, "[min(", to_expression(builtin_invocation_id_id), ".x / ",
2774 get_entry_point().output_vertices,
2775 ", spvIndirectParams[1] - 1) * spvIndirectParams[0]];");
2776 }
2777 else
2778 {
2779 // It's safe to use InvocationId here because it's directly mapped to a
2780 // Metal builtin, and therefore doesn't need a hook.
2781 statement("if (", to_expression(builtin_invocation_id_id), " < spvIndirectParams[0])");
2782 statement(" ", input_wg_var_name, "[", to_expression(builtin_invocation_id_id),
2783 "] = ", ib_var_ref, ";");
2784 statement("threadgroup_barrier(mem_flags::mem_threadgroup);");
2785 statement("if (", to_expression(builtin_invocation_id_id),
2786 " >= ", get_entry_point().output_vertices, ")");
2787 statement(" return;");
2788 }
2789 });
2790 }
2791 break;
2792
2793 case StorageClassOutput:
2794 {
2795 ib_var_ref = patch ? patch_stage_out_var_name : stage_out_var_name;
2796
2797 // Add the output interface struct as a local variable to the entry function.
2798 // If the entry point should return the output struct, set the entry function
2799 // to return the output interface struct, otherwise to return nothing.
2800 // Indicate the output var requires early initialization.
2801 bool ep_should_return_output = !get_is_rasterization_disabled();
2802 uint32_t rtn_id = ep_should_return_output ? ib_var_id : 0;
2803 if (!capture_output_to_buffer)
2804 {
2805 entry_func.add_local_variable(ib_var_id);
2806 for (auto &blk_id : entry_func.blocks)
2807 {
2808 auto &blk = get<SPIRBlock>(blk_id);
2809 if (blk.terminator == SPIRBlock::Return)
2810 blk.return_value = rtn_id;
2811 }
2812 vars_needing_early_declaration.push_back(ib_var_id);
2813 }
2814 else
2815 {
2816 switch (get_execution_model())
2817 {
2818 case ExecutionModelVertex:
2819 case ExecutionModelTessellationEvaluation:
2820 // Instead of declaring a struct variable to hold the output and then
2821 // copying that to the output buffer, we'll declare the output variable
2822 // as a reference to the final output element in the buffer. Then we can
2823 // avoid the extra copy.
2824 entry_func.fixup_hooks_in.push_back([=]() {
2825 if (stage_out_var_id)
2826 {
2827 // The first member of the indirect buffer is always the number of vertices
2828 // to draw.
2829 // We zero-base the InstanceID & VertexID variables for HLSL emulation elsewhere, so don't do it twice
2830 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
2831 {
2832 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
2833 " = ", output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
2834 ".y * ", to_expression(builtin_stage_input_size_id), ".x + ",
2835 to_expression(builtin_invocation_id_id), ".x];");
2836 }
2837 else if (msl_options.enable_base_index_zero)
2838 {
2839 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
2840 " = ", output_buffer_var_name, "[", to_expression(builtin_instance_idx_id),
2841 " * spvIndirectParams[0] + ", to_expression(builtin_vertex_idx_id), "];");
2842 }
2843 else
2844 {
2845 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
2846 " = ", output_buffer_var_name, "[(", to_expression(builtin_instance_idx_id),
2847 " - ", to_expression(builtin_base_instance_id), ") * spvIndirectParams[0] + ",
2848 to_expression(builtin_vertex_idx_id), " - ",
2849 to_expression(builtin_base_vertex_id), "];");
2850 }
2851 }
2852 });
2853 break;
2854 case ExecutionModelTessellationControl:
2855 if (msl_options.multi_patch_workgroup)
2856 {
2857 // We cannot use PrimitiveId here, because the hook may not have run yet.
2858 if (patch)
2859 {
2860 entry_func.fixup_hooks_in.push_back([=]() {
2861 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
2862 " = ", patch_output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
2863 ".x / ", get_entry_point().output_vertices, "];");
2864 });
2865 }
2866 else
2867 {
2868 entry_func.fixup_hooks_in.push_back([=]() {
2869 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
2870 output_buffer_var_name, "[", to_expression(builtin_invocation_id_id), ".x - ",
2871 to_expression(builtin_invocation_id_id), ".x % ",
2872 get_entry_point().output_vertices, "];");
2873 });
2874 }
2875 }
2876 else
2877 {
2878 if (patch)
2879 {
2880 entry_func.fixup_hooks_in.push_back([=]() {
2881 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
2882 " = ", patch_output_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
2883 "];");
2884 });
2885 }
2886 else
2887 {
2888 entry_func.fixup_hooks_in.push_back([=]() {
2889 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
2890 output_buffer_var_name, "[", to_expression(builtin_primitive_id_id), " * ",
2891 get_entry_point().output_vertices, "];");
2892 });
2893 }
2894 }
2895 break;
2896 default:
2897 break;
2898 }
2899 }
2900 break;
2901 }
2902
2903 default:
2904 break;
2905 }
2906
2907 set_name(ib_type_id, to_name(ir.default_entry_point) + "_" + ib_var_ref);
2908 set_name(ib_var_id, ib_var_ref);
2909
2910 for (auto *p_var : vars)
2911 {
2912 bool strip_array =
2913 (get_execution_model() == ExecutionModelTessellationControl ||
2914 (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput)) &&
2915 !patch;
2916
2917 meta.strip_array = strip_array;
2918 add_variable_to_interface_block(storage, ib_var_ref, ib_type, *p_var, meta);
2919 }
2920
2921 if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup &&
2922 storage == StorageClassInput)
2923 {
2924 // For tessellation control inputs, add all outputs from the vertex shader to ensure
2925 // the struct containing them is the correct size and layout.
2926 for (auto &input : inputs_by_location)
2927 {
2928 if (is_msl_shader_input_used(input.first))
2929 continue;
2930
2931 // Create a fake variable to put at the location.
2932 uint32_t offset = ir.increase_bound_by(4);
2933 uint32_t type_id = offset;
2934 uint32_t array_type_id = offset + 1;
2935 uint32_t ptr_type_id = offset + 2;
2936 uint32_t var_id = offset + 3;
2937
2938 SPIRType type;
2939 switch (input.second.format)
2940 {
2941 case MSL_SHADER_INPUT_FORMAT_UINT16:
2942 case MSL_SHADER_INPUT_FORMAT_ANY16:
2943 type.basetype = SPIRType::UShort;
2944 type.width = 16;
2945 break;
2946 case MSL_SHADER_INPUT_FORMAT_ANY32:
2947 default:
2948 type.basetype = SPIRType::UInt;
2949 type.width = 32;
2950 break;
2951 }
2952 type.vecsize = input.second.vecsize;
2953 set<SPIRType>(type_id, type);
2954
2955 type.array.push_back(0);
2956 type.array_size_literal.push_back(true);
2957 type.parent_type = type_id;
2958 set<SPIRType>(array_type_id, type);
2959
2960 type.pointer = true;
2961 type.parent_type = array_type_id;
2962 type.storage = storage;
2963 auto &ptr_type = set<SPIRType>(ptr_type_id, type);
2964 ptr_type.self = array_type_id;
2965
2966 auto &fake_var = set<SPIRVariable>(var_id, ptr_type_id, storage);
2967 set_decoration(var_id, DecorationLocation, input.first);
2968 meta.strip_array = true;
2969 add_variable_to_interface_block(storage, ib_var_ref, ib_type, fake_var, meta);
2970 }
2971 }
2972
2973 // Sort the members of the structure by their locations.
2974 MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Location);
2975 member_sorter.sort();
2976
2977 // The member indices were saved to the original variables, but after the members
2978 // were sorted, those indices are now likely incorrect. Fix those up now.
2979 if (!patch)
2980 fix_up_interface_member_indices(storage, ib_type_id);
2981
2982 // For patch inputs, add one more member, holding the array of control point data.
2983 if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput && patch &&
2984 stage_in_var_id)
2985 {
2986 uint32_t pcp_type_id = ir.increase_bound_by(1);
2987 auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
2988 pcp_type.basetype = SPIRType::ControlPointArray;
2989 pcp_type.parent_type = pcp_type.type_alias = get_stage_in_struct_type().self;
2990 pcp_type.storage = storage;
2991 ir.meta[pcp_type_id] = ir.meta[ib_type.self];
2992 uint32_t mbr_idx = uint32_t(ib_type.member_types.size());
2993 ib_type.member_types.push_back(pcp_type_id);
2994 set_member_name(ib_type.self, mbr_idx, "gl_in");
2995 }
2996
2997 return ib_var_id;
2998 }
2999
add_interface_block_pointer(uint32_t ib_var_id,StorageClass storage)3000 uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageClass storage)
3001 {
3002 if (!ib_var_id)
3003 return 0;
3004
3005 uint32_t ib_ptr_var_id;
3006 uint32_t next_id = ir.increase_bound_by(3);
3007 auto &ib_type = expression_type(ib_var_id);
3008 if (get_execution_model() == ExecutionModelTessellationControl)
3009 {
3010 // Tessellation control per-vertex I/O is presented as an array, so we must
3011 // do the same with our struct here.
3012 uint32_t ib_ptr_type_id = next_id++;
3013 auto &ib_ptr_type = set<SPIRType>(ib_ptr_type_id, ib_type);
3014 ib_ptr_type.parent_type = ib_ptr_type.type_alias = ib_type.self;
3015 ib_ptr_type.pointer = true;
3016 ib_ptr_type.storage =
3017 storage == StorageClassInput ?
3018 (msl_options.multi_patch_workgroup ? StorageClassStorageBuffer : StorageClassWorkgroup) :
3019 StorageClassStorageBuffer;
3020 ir.meta[ib_ptr_type_id] = ir.meta[ib_type.self];
3021 // To ensure that get_variable_data_type() doesn't strip off the pointer,
3022 // which we need, use another pointer.
3023 uint32_t ib_ptr_ptr_type_id = next_id++;
3024 auto &ib_ptr_ptr_type = set<SPIRType>(ib_ptr_ptr_type_id, ib_ptr_type);
3025 ib_ptr_ptr_type.parent_type = ib_ptr_type_id;
3026 ib_ptr_ptr_type.type_alias = ib_type.self;
3027 ib_ptr_ptr_type.storage = StorageClassFunction;
3028 ir.meta[ib_ptr_ptr_type_id] = ir.meta[ib_type.self];
3029
3030 ib_ptr_var_id = next_id;
3031 set<SPIRVariable>(ib_ptr_var_id, ib_ptr_ptr_type_id, StorageClassFunction, 0);
3032 set_name(ib_ptr_var_id, storage == StorageClassInput ? "gl_in" : "gl_out");
3033 }
3034 else
3035 {
3036 // Tessellation evaluation per-vertex inputs are also presented as arrays.
3037 // But, in Metal, this array uses a very special type, 'patch_control_point<T>',
3038 // which is a container that can be used to access the control point data.
3039 // To represent this, a special 'ControlPointArray' type has been added to the
3040 // SPIRV-Cross type system. It should only be generated by and seen in the MSL
3041 // backend (i.e. this one).
3042 uint32_t pcp_type_id = next_id++;
3043 auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
3044 pcp_type.basetype = SPIRType::ControlPointArray;
3045 pcp_type.parent_type = pcp_type.type_alias = ib_type.self;
3046 pcp_type.storage = storage;
3047 ir.meta[pcp_type_id] = ir.meta[ib_type.self];
3048
3049 ib_ptr_var_id = next_id;
3050 set<SPIRVariable>(ib_ptr_var_id, pcp_type_id, storage, 0);
3051 set_name(ib_ptr_var_id, "gl_in");
3052 ir.meta[ib_ptr_var_id].decoration.qualified_alias = join(patch_stage_in_var_name, ".gl_in");
3053 }
3054 return ib_ptr_var_id;
3055 }
3056
3057 // Ensure that the type is compatible with the builtin.
3058 // If it is, simply return the given type ID.
3059 // Otherwise, create a new type, and return it's ID.
ensure_correct_builtin_type(uint32_t type_id,BuiltIn builtin)3060 uint32_t CompilerMSL::ensure_correct_builtin_type(uint32_t type_id, BuiltIn builtin)
3061 {
3062 auto &type = get<SPIRType>(type_id);
3063
3064 if ((builtin == BuiltInSampleMask && is_array(type)) ||
3065 ((builtin == BuiltInLayer || builtin == BuiltInViewportIndex || builtin == BuiltInFragStencilRefEXT) &&
3066 type.basetype != SPIRType::UInt))
3067 {
3068 uint32_t next_id = ir.increase_bound_by(type.pointer ? 2 : 1);
3069 uint32_t base_type_id = next_id++;
3070 auto &base_type = set<SPIRType>(base_type_id);
3071 base_type.basetype = SPIRType::UInt;
3072 base_type.width = 32;
3073
3074 if (!type.pointer)
3075 return base_type_id;
3076
3077 uint32_t ptr_type_id = next_id++;
3078 auto &ptr_type = set<SPIRType>(ptr_type_id);
3079 ptr_type = base_type;
3080 ptr_type.pointer = true;
3081 ptr_type.storage = type.storage;
3082 ptr_type.parent_type = base_type_id;
3083 return ptr_type_id;
3084 }
3085
3086 return type_id;
3087 }
3088
3089 // Ensure that the type is compatible with the shader input.
3090 // If it is, simply return the given type ID.
3091 // Otherwise, create a new type, and return its ID.
ensure_correct_input_type(uint32_t type_id,uint32_t location,uint32_t num_components)3092 uint32_t CompilerMSL::ensure_correct_input_type(uint32_t type_id, uint32_t location, uint32_t num_components)
3093 {
3094 auto &type = get<SPIRType>(type_id);
3095
3096 auto p_va = inputs_by_location.find(location);
3097 if (p_va == end(inputs_by_location))
3098 {
3099 if (num_components > type.vecsize)
3100 return build_extended_vector_type(type_id, num_components);
3101 else
3102 return type_id;
3103 }
3104
3105 if (num_components == 0)
3106 num_components = p_va->second.vecsize;
3107
3108 switch (p_va->second.format)
3109 {
3110 case MSL_SHADER_INPUT_FORMAT_UINT8:
3111 {
3112 switch (type.basetype)
3113 {
3114 case SPIRType::UByte:
3115 case SPIRType::UShort:
3116 case SPIRType::UInt:
3117 if (num_components > type.vecsize)
3118 return build_extended_vector_type(type_id, num_components);
3119 else
3120 return type_id;
3121
3122 case SPIRType::Short:
3123 return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3124 SPIRType::UShort);
3125 case SPIRType::Int:
3126 return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3127 SPIRType::UInt);
3128
3129 default:
3130 SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
3131 }
3132 }
3133
3134 case MSL_SHADER_INPUT_FORMAT_UINT16:
3135 {
3136 switch (type.basetype)
3137 {
3138 case SPIRType::UShort:
3139 case SPIRType::UInt:
3140 if (num_components > type.vecsize)
3141 return build_extended_vector_type(type_id, num_components);
3142 else
3143 return type_id;
3144
3145 case SPIRType::Int:
3146 return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3147 SPIRType::UInt);
3148
3149 default:
3150 SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
3151 }
3152 }
3153
3154 default:
3155 if (num_components > type.vecsize)
3156 type_id = build_extended_vector_type(type_id, num_components);
3157 break;
3158 }
3159
3160 return type_id;
3161 }
3162
mark_struct_members_packed(const SPIRType & type)3163 void CompilerMSL::mark_struct_members_packed(const SPIRType &type)
3164 {
3165 set_extended_decoration(type.self, SPIRVCrossDecorationPhysicalTypePacked);
3166
3167 // Problem case! Struct needs to be placed at an awkward alignment.
3168 // Mark every member of the child struct as packed.
3169 uint32_t mbr_cnt = uint32_t(type.member_types.size());
3170 for (uint32_t i = 0; i < mbr_cnt; i++)
3171 {
3172 auto &mbr_type = get<SPIRType>(type.member_types[i]);
3173 if (mbr_type.basetype == SPIRType::Struct)
3174 {
3175 // Recursively mark structs as packed.
3176 auto *struct_type = &mbr_type;
3177 while (!struct_type->array.empty())
3178 struct_type = &get<SPIRType>(struct_type->parent_type);
3179 mark_struct_members_packed(*struct_type);
3180 }
3181 else if (!is_scalar(mbr_type))
3182 set_extended_member_decoration(type.self, i, SPIRVCrossDecorationPhysicalTypePacked);
3183 }
3184 }
3185
mark_scalar_layout_structs(const SPIRType & type)3186 void CompilerMSL::mark_scalar_layout_structs(const SPIRType &type)
3187 {
3188 uint32_t mbr_cnt = uint32_t(type.member_types.size());
3189 for (uint32_t i = 0; i < mbr_cnt; i++)
3190 {
3191 auto &mbr_type = get<SPIRType>(type.member_types[i]);
3192 if (mbr_type.basetype == SPIRType::Struct)
3193 {
3194 auto *struct_type = &mbr_type;
3195 while (!struct_type->array.empty())
3196 struct_type = &get<SPIRType>(struct_type->parent_type);
3197
3198 if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPhysicalTypePacked))
3199 continue;
3200
3201 uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, i);
3202 uint32_t msl_size = get_declared_struct_member_size_msl(type, i);
3203 uint32_t spirv_offset = type_struct_member_offset(type, i);
3204 uint32_t spirv_offset_next;
3205 if (i + 1 < mbr_cnt)
3206 spirv_offset_next = type_struct_member_offset(type, i + 1);
3207 else
3208 spirv_offset_next = spirv_offset + msl_size;
3209
3210 // Both are complicated cases. In scalar layout, a struct of float3 might just consume 12 bytes,
3211 // and the next member will be placed at offset 12.
3212 bool struct_is_misaligned = (spirv_offset % msl_alignment) != 0;
3213 bool struct_is_too_large = spirv_offset + msl_size > spirv_offset_next;
3214 uint32_t array_stride = 0;
3215 bool struct_needs_explicit_padding = false;
3216
3217 // Verify that if a struct is used as an array that ArrayStride matches the effective size of the struct.
3218 if (!mbr_type.array.empty())
3219 {
3220 array_stride = type_struct_member_array_stride(type, i);
3221 uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
3222 for (uint32_t dim = 0; dim < dimensions; dim++)
3223 {
3224 uint32_t array_size = to_array_size_literal(mbr_type, dim);
3225 array_stride /= max(array_size, 1u);
3226 }
3227
3228 // Set expected struct size based on ArrayStride.
3229 struct_needs_explicit_padding = true;
3230
3231 // If struct size is larger than array stride, we might be able to fit, if we tightly pack.
3232 if (get_declared_struct_size_msl(*struct_type) > array_stride)
3233 struct_is_too_large = true;
3234 }
3235
3236 if (struct_is_misaligned || struct_is_too_large)
3237 mark_struct_members_packed(*struct_type);
3238 mark_scalar_layout_structs(*struct_type);
3239
3240 if (struct_needs_explicit_padding)
3241 {
3242 msl_size = get_declared_struct_size_msl(*struct_type, true, true);
3243 if (array_stride < msl_size)
3244 {
3245 SPIRV_CROSS_THROW("Cannot express an array stride smaller than size of struct type.");
3246 }
3247 else
3248 {
3249 if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
3250 {
3251 if (array_stride !=
3252 get_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
3253 SPIRV_CROSS_THROW(
3254 "A struct is used with different array strides. Cannot express this in MSL.");
3255 }
3256 else
3257 set_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget, array_stride);
3258 }
3259 }
3260 }
3261 }
3262 }
3263
3264 // Sort the members of the struct type by offset, and pack and then pad members where needed
3265 // to align MSL members with SPIR-V offsets. The struct members are iterated twice. Packing
3266 // occurs first, followed by padding, because packing a member reduces both its size and its
3267 // natural alignment, possibly requiring a padding member to be added ahead of it.
align_struct(SPIRType & ib_type,unordered_set<uint32_t> & aligned_structs)3268 void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set<uint32_t> &aligned_structs)
3269 {
3270 // We align structs recursively, so stop any redundant work.
3271 ID &ib_type_id = ib_type.self;
3272 if (aligned_structs.count(ib_type_id))
3273 return;
3274 aligned_structs.insert(ib_type_id);
3275
3276 // Sort the members of the interface structure by their offset.
3277 // They should already be sorted per SPIR-V spec anyway.
3278 MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Offset);
3279 member_sorter.sort();
3280
3281 auto mbr_cnt = uint32_t(ib_type.member_types.size());
3282
3283 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
3284 {
3285 // Pack any dependent struct types before we pack a parent struct.
3286 auto &mbr_type = get<SPIRType>(ib_type.member_types[mbr_idx]);
3287 if (mbr_type.basetype == SPIRType::Struct)
3288 align_struct(mbr_type, aligned_structs);
3289 }
3290
3291 // Test the alignment of each member, and if a member should be closer to the previous
3292 // member than the default spacing expects, it is likely that the previous member is in
3293 // a packed format. If so, and the previous member is packable, pack it.
3294 // For example ... this applies to any 3-element vector that is followed by a scalar.
3295 uint32_t msl_offset = 0;
3296 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
3297 {
3298 // This checks the member in isolation, if the member needs some kind of type remapping to conform to SPIR-V
3299 // offsets, array strides and matrix strides.
3300 ensure_member_packing_rules_msl(ib_type, mbr_idx);
3301
3302 // Align current offset to the current member's default alignment. If the member was packed, it will observe
3303 // the updated alignment here.
3304 uint32_t msl_align_mask = get_declared_struct_member_alignment_msl(ib_type, mbr_idx) - 1;
3305 uint32_t aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
3306
3307 // Fetch the member offset as declared in the SPIRV.
3308 uint32_t spirv_mbr_offset = get_member_decoration(ib_type_id, mbr_idx, DecorationOffset);
3309 if (spirv_mbr_offset > aligned_msl_offset)
3310 {
3311 // Since MSL and SPIR-V have slightly different struct member alignment and
3312 // size rules, we'll pad to standard C-packing rules with a char[] array. If the member is farther
3313 // away than C-packing, expects, add an inert padding member before the the member.
3314 uint32_t padding_bytes = spirv_mbr_offset - aligned_msl_offset;
3315 set_extended_member_decoration(ib_type_id, mbr_idx, SPIRVCrossDecorationPaddingTarget, padding_bytes);
3316
3317 // Re-align as a sanity check that aligning post-padding matches up.
3318 msl_offset += padding_bytes;
3319 aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
3320 }
3321 else if (spirv_mbr_offset < aligned_msl_offset)
3322 {
3323 // This should not happen, but deal with unexpected scenarios.
3324 // It *might* happen if a sub-struct has a larger alignment requirement in MSL than SPIR-V.
3325 SPIRV_CROSS_THROW("Cannot represent buffer block correctly in MSL.");
3326 }
3327
3328 assert(aligned_msl_offset == spirv_mbr_offset);
3329
3330 // Increment the current offset to be positioned immediately after the current member.
3331 // Don't do this for the last member since it can be unsized, and it is not relevant for padding purposes here.
3332 if (mbr_idx + 1 < mbr_cnt)
3333 msl_offset = aligned_msl_offset + get_declared_struct_member_size_msl(ib_type, mbr_idx);
3334 }
3335 }
3336
validate_member_packing_rules_msl(const SPIRType & type,uint32_t index) const3337 bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const
3338 {
3339 auto &mbr_type = get<SPIRType>(type.member_types[index]);
3340 uint32_t spirv_offset = get_member_decoration(type.self, index, DecorationOffset);
3341
3342 if (index + 1 < type.member_types.size())
3343 {
3344 // First, we will check offsets. If SPIR-V offset + MSL size > SPIR-V offset of next member,
3345 // we *must* perform some kind of remapping, no way getting around it.
3346 // We can always pad after this member if necessary, so that case is fine.
3347 uint32_t spirv_offset_next = get_member_decoration(type.self, index + 1, DecorationOffset);
3348 assert(spirv_offset_next >= spirv_offset);
3349 uint32_t maximum_size = spirv_offset_next - spirv_offset;
3350 uint32_t msl_mbr_size = get_declared_struct_member_size_msl(type, index);
3351 if (msl_mbr_size > maximum_size)
3352 return false;
3353 }
3354
3355 if (!mbr_type.array.empty())
3356 {
3357 // If we have an array type, array stride must match exactly with SPIR-V.
3358
3359 // An exception to this requirement is if we have one array element.
3360 // This comes from DX scalar layout workaround.
3361 // If app tries to be cheeky and access the member out of bounds, this will not work, but this is the best we can do.
3362 // In OpAccessChain with logical memory models, access chains must be in-bounds in SPIR-V specification.
3363 bool relax_array_stride = mbr_type.array.back() == 1 && mbr_type.array_size_literal.back();
3364
3365 if (!relax_array_stride)
3366 {
3367 uint32_t spirv_array_stride = type_struct_member_array_stride(type, index);
3368 uint32_t msl_array_stride = get_declared_struct_member_array_stride_msl(type, index);
3369 if (spirv_array_stride != msl_array_stride)
3370 return false;
3371 }
3372 }
3373
3374 if (is_matrix(mbr_type))
3375 {
3376 // Need to check MatrixStride as well.
3377 uint32_t spirv_matrix_stride = type_struct_member_matrix_stride(type, index);
3378 uint32_t msl_matrix_stride = get_declared_struct_member_matrix_stride_msl(type, index);
3379 if (spirv_matrix_stride != msl_matrix_stride)
3380 return false;
3381 }
3382
3383 // Now, we check alignment.
3384 uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, index);
3385 if ((spirv_offset % msl_alignment) != 0)
3386 return false;
3387
3388 // We're in the clear.
3389 return true;
3390 }
3391
3392 // Here we need to verify that the member type we declare conforms to Offset, ArrayStride or MatrixStride restrictions.
3393 // If there is a mismatch, we need to emit remapped types, either normal types, or "packed_X" types.
3394 // In odd cases we need to emit packed and remapped types, for e.g. weird matrices or arrays with weird array strides.
ensure_member_packing_rules_msl(SPIRType & ib_type,uint32_t index)3395 void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t index)
3396 {
3397 if (validate_member_packing_rules_msl(ib_type, index))
3398 return;
3399
3400 // We failed validation.
3401 // This case will be nightmare-ish to deal with. This could possibly happen if struct alignment does not quite
3402 // match up with what we want. Scalar block layout comes to mind here where we might have to work around the rule
3403 // that struct alignment == max alignment of all members and struct size depends on this alignment.
3404 auto &mbr_type = get<SPIRType>(ib_type.member_types[index]);
3405 if (mbr_type.basetype == SPIRType::Struct)
3406 SPIRV_CROSS_THROW("Cannot perform any repacking for structs when it is used as a member of another struct.");
3407
3408 // Perform remapping here.
3409 // There is nothing to be gained by using packed scalars, so don't attempt it.
3410 if (!is_scalar(ib_type))
3411 set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
3412
3413 // Try validating again, now with packed.
3414 if (validate_member_packing_rules_msl(ib_type, index))
3415 return;
3416
3417 // We're in deep trouble, and we need to create a new PhysicalType which matches up with what we expect.
3418 // A lot of work goes here ...
3419 // We will need remapping on Load and Store to translate the types between Logical and Physical.
3420
3421 // First, we check if we have small vector std140 array.
3422 // We detect this if we have an array of vectors, and array stride is greater than number of elements.
3423 if (!mbr_type.array.empty() && !is_matrix(mbr_type))
3424 {
3425 uint32_t array_stride = type_struct_member_array_stride(ib_type, index);
3426
3427 // Hack off array-of-arrays until we find the array stride per element we must have to make it work.
3428 uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
3429 for (uint32_t dim = 0; dim < dimensions; dim++)
3430 array_stride /= max(to_array_size_literal(mbr_type, dim), 1u);
3431
3432 uint32_t elems_per_stride = array_stride / (mbr_type.width / 8);
3433
3434 if (elems_per_stride == 3)
3435 SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
3436 else if (elems_per_stride > 4)
3437 SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
3438
3439 auto physical_type = mbr_type;
3440 physical_type.vecsize = elems_per_stride;
3441 physical_type.parent_type = 0;
3442 uint32_t type_id = ir.increase_bound_by(1);
3443 set<SPIRType>(type_id, physical_type);
3444 set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);
3445 set_decoration(type_id, DecorationArrayStride, array_stride);
3446
3447 // Remove packed_ for vectors of size 1, 2 and 4.
3448 if (has_extended_decoration(ib_type.self, SPIRVCrossDecorationPhysicalTypePacked))
3449 SPIRV_CROSS_THROW("Unable to remove packed decoration as entire struct must be fully packed. Do not mix "
3450 "scalar and std140 layout rules.");
3451 else
3452 unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
3453 }
3454 else if (is_matrix(mbr_type))
3455 {
3456 // MatrixStride might be std140-esque.
3457 uint32_t matrix_stride = type_struct_member_matrix_stride(ib_type, index);
3458
3459 uint32_t elems_per_stride = matrix_stride / (mbr_type.width / 8);
3460
3461 if (elems_per_stride == 3)
3462 SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
3463 else if (elems_per_stride > 4)
3464 SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
3465
3466 bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
3467
3468 auto physical_type = mbr_type;
3469 physical_type.parent_type = 0;
3470 if (row_major)
3471 physical_type.columns = elems_per_stride;
3472 else
3473 physical_type.vecsize = elems_per_stride;
3474 uint32_t type_id = ir.increase_bound_by(1);
3475 set<SPIRType>(type_id, physical_type);
3476 set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);
3477
3478 // Remove packed_ for vectors of size 1, 2 and 4.
3479 if (has_extended_decoration(ib_type.self, SPIRVCrossDecorationPhysicalTypePacked))
3480 SPIRV_CROSS_THROW("Unable to remove packed decoration as entire struct must be fully packed. Do not mix "
3481 "scalar and std140 layout rules.");
3482 else
3483 unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
3484 }
3485 else
3486 SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
3487
3488 // Try validating again, now with physical type remapping.
3489 if (validate_member_packing_rules_msl(ib_type, index))
3490 return;
3491
3492 // We might have a particular odd scalar layout case where the last element of an array
3493 // does not take up as much space as the ArrayStride or MatrixStride. This can happen with DX cbuffers.
3494 // The "proper" workaround for this is extremely painful and essentially impossible in the edge case of float3[],
3495 // so we hack around it by declaring the offending array or matrix with one less array size/col/row,
3496 // and rely on padding to get the correct value. We will technically access arrays out of bounds into the padding region,
3497 // but it should spill over gracefully without too much trouble. We rely on behavior like this for unsized arrays anyways.
3498
3499 // E.g. we might observe a physical layout of:
3500 // { float2 a[2]; float b; } in cbuffer layout where ArrayStride of a is 16, but offset of b is 24, packed right after a[1] ...
3501 uint32_t type_id = get_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
3502 auto &type = get<SPIRType>(type_id);
3503
3504 // Modify the physical type in-place. This is safe since each physical type workaround is a copy.
3505 if (is_array(type))
3506 {
3507 if (type.array.back() > 1)
3508 {
3509 if (!type.array_size_literal.back())
3510 SPIRV_CROSS_THROW("Cannot apply scalar layout workaround with spec constant array size.");
3511 type.array.back() -= 1;
3512 }
3513 else
3514 {
3515 // We have an array of size 1, so we cannot decrement that. Our only option now is to
3516 // force a packed layout instead, and drop the physical type remap since ArrayStride is meaningless now.
3517 unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
3518 set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
3519 }
3520 }
3521 else if (is_matrix(type))
3522 {
3523 bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
3524 if (!row_major)
3525 {
3526 // Slice off one column. If we only have 2 columns, this might turn the matrix into a vector with one array element instead.
3527 if (type.columns > 2)
3528 {
3529 type.columns--;
3530 }
3531 else if (type.columns == 2)
3532 {
3533 type.columns = 1;
3534 assert(type.array.empty());
3535 type.array.push_back(1);
3536 type.array_size_literal.push_back(true);
3537 }
3538 }
3539 else
3540 {
3541 // Slice off one row. If we only have 2 rows, this might turn the matrix into a vector with one array element instead.
3542 if (type.vecsize > 2)
3543 {
3544 type.vecsize--;
3545 }
3546 else if (type.vecsize == 2)
3547 {
3548 type.vecsize = type.columns;
3549 type.columns = 1;
3550 assert(type.array.empty());
3551 type.array.push_back(1);
3552 type.array_size_literal.push_back(true);
3553 }
3554 }
3555 }
3556
3557 // This better validate now, or we must fail gracefully.
3558 if (!validate_member_packing_rules_msl(ib_type, index))
3559 SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
3560 }
3561
// Emits the statement(s) which store `rhs_expression` into `lhs_expression`.
// Three situations are distinguished, based on the extended decorations found on the LHS:
//   - The LHS is neither remapped to a physical type nor packed: only the transpose state of the LHS
//     needs special care, otherwise the store is forwarded to CompilerGLSL.
//   - The LHS is packed but not remapped, the stored value is not a matrix and no transpose is needed:
//     the store is forwarded to CompilerGLSL.
//   - Anything else: the store is written out explicitly, column-by-column, row-by-row or
//     component-by-component, with casts on the LHS where the physical type differs.
// The need_transpose flags of the LHS/RHS expressions are temporarily cleared while building
// expression strings, and restored afterwards.
// NOTE(review): the casts emitted in the remapped/packed path hard-code the `device` address space -
// confirm that the LHS can never live in a different address space on this path.
void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression)
{
	// The logical type of the value being stored.
	auto &type = expression_type(rhs_expression);

	bool lhs_remapped_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID);
	bool lhs_packed_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypePacked);
	// Either side may not be a SPIRExpression, in which case these are null.
	auto *lhs_e = maybe_get<SPIRExpression>(lhs_expression);
	auto *rhs_e = maybe_get<SPIRExpression>(rhs_expression);

	bool transpose = lhs_e && lhs_e->need_transpose;

	// No physical type remapping, and no packed type, so can just emit a store directly.
	if (!lhs_remapped_type && !lhs_packed_type)
	{
		// We might not be dealing with remapped physical types or packed types,
		// but we might be doing a clean store to a row-major matrix.
		// In this case, we just flip transpose states, and emit the store, a transpose must be in the RHS expression, if any.
		if (is_matrix(type) && lhs_e && lhs_e->need_transpose)
		{
			if (!rhs_e)
				SPIRV_CROSS_THROW("Need to transpose right-side expression of a store to row-major matrix, but it is "
				                  "not a SPIRExpression.");
			// Clear the flag while building the LHS expression string; it is restored below.
			lhs_e->need_transpose = false;

			if (rhs_e && rhs_e->need_transpose)
			{
				// Direct copy, but might need to unpack RHS.
				// Skip the transpose, as we will transpose when writing to LHS and transpose(transpose(T)) == T.
				rhs_e->need_transpose = false;
				statement(to_expression(lhs_expression), " = ", to_unpacked_row_major_matrix_expression(rhs_expression),
				          ";");
				rhs_e->need_transpose = true;
			}
			else
				statement(to_expression(lhs_expression), " = transpose(", to_unpacked_expression(rhs_expression), ");");

			lhs_e->need_transpose = true;
			register_write(lhs_expression);
		}
		else if (lhs_e && lhs_e->need_transpose)
		{
			lhs_e->need_transpose = false;

			// Storing a column to a row-major matrix. Unroll the write.
			for (uint32_t c = 0; c < type.vecsize; c++)
			{
				auto lhs_expr = to_dereferenced_expression(lhs_expression);
				// Insert the component index ahead of the last subscript in the LHS string.
				// If the LHS string contains no '[', nothing is emitted for this component.
				auto column_index = lhs_expr.find_last_of('[');
				if (column_index != string::npos)
				{
					statement(lhs_expr.insert(column_index, join('[', c, ']')), " = ",
					          to_extract_component_expression(rhs_expression, c), ";");
				}
			}
			lhs_e->need_transpose = true;
			register_write(lhs_expression);
		}
		else
			CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
	}
	else if (!lhs_remapped_type && !is_matrix(type) && !transpose)
	{
		// Even if the target type is packed, we can directly store to it. We cannot store to packed matrices directly,
		// since they are declared as array of vectors instead, and we need the fallback path below.
		CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
	}
	else
	{
		// Special handling when storing to a remapped physical type.
		// This is mostly to deal with std140 padded matrices or vectors.

		// Without a remapped type, the physical type is taken to be the logical type itself.
		TypeID physical_type_id = lhs_remapped_type ?
		                              ID(get_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID)) :
		                              type.self;

		auto &physical_type = get<SPIRType>(physical_type_id);

		if (is_matrix(type))
		{
			const char *packed_pfx = lhs_packed_type ? "packed_" : "";

			// Packed matrices are stored as arrays of packed vectors, so we need
			// to assign the vectors one at a time.
			// For row-major matrices, we need to transpose the *right-hand* side,
			// not the left-hand side.

			// Lots of cases to cover here ...

			bool rhs_transpose = rhs_e && rhs_e->need_transpose;
			// Type of one vector written per statement; adjusted per case below.
			SPIRType write_type = type;
			// Optional reference cast placed in front of the LHS; stays empty when no cast is needed.
			string cast_expr;

			// We're dealing with transpose manually.
			if (rhs_transpose)
				rhs_e->need_transpose = false;

			if (transpose)
			{
				// We're dealing with transpose manually.
				lhs_e->need_transpose = false;
				write_type.vecsize = type.columns;
				write_type.columns = 1;

				// Only cast when the physical column count differs from the logical one.
				if (physical_type.columns != type.columns)
					cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");

				if (rhs_transpose)
				{
					// If RHS is also transposed, we can just copy row by row.
					for (uint32_t i = 0; i < type.vecsize; i++)
					{
						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
						          to_unpacked_row_major_matrix_expression(rhs_expression), "[", i, "];");
					}
				}
				else
				{
					auto vector_type = expression_type(rhs_expression);
					vector_type.vecsize = vector_type.columns;
					vector_type.columns = 1;

					// Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
					// so pick out individual components instead.
					for (uint32_t i = 0; i < type.vecsize; i++)
					{
						string rhs_row = type_to_glsl_constructor(vector_type) + "(";
						for (uint32_t j = 0; j < vector_type.vecsize; j++)
						{
							rhs_row += join(to_enclosed_unpacked_expression(rhs_expression), "[", j, "][", i, "]");
							if (j + 1 < vector_type.vecsize)
								rhs_row += ", ";
						}
						rhs_row += ")";

						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
					}
				}

				// We're dealing with transpose manually.
				lhs_e->need_transpose = true;
			}
			else
			{
				write_type.columns = 1;

				// Only cast when the physical vector size differs from the logical one.
				if (physical_type.vecsize != type.vecsize)
					cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");

				if (rhs_transpose)
				{
					auto vector_type = expression_type(rhs_expression);
					vector_type.columns = 1;

					// Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
					// so pick out individual components instead.
					for (uint32_t i = 0; i < type.columns; i++)
					{
						string rhs_row = type_to_glsl_constructor(vector_type) + "(";
						for (uint32_t j = 0; j < vector_type.vecsize; j++)
						{
							// Need to explicitly unpack expression since we've mucked with transpose state.
							auto unpacked_expr = to_unpacked_row_major_matrix_expression(rhs_expression);
							rhs_row += join(unpacked_expr, "[", j, "][", i, "]");
							if (j + 1 < vector_type.vecsize)
								rhs_row += ", ";
						}
						rhs_row += ")";

						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
					}
				}
				else
				{
					// Copy column-by-column.
					for (uint32_t i = 0; i < type.columns; i++)
					{
						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
						          to_enclosed_unpacked_expression(rhs_expression), "[", i, "];");
					}
				}
			}

			// We're dealing with transpose manually.
			if (rhs_transpose)
				rhs_e->need_transpose = true;
		}
		else if (transpose)
		{
			lhs_e->need_transpose = false;

			// Each component is written through a pointer to a single scalar of the stored type.
			SPIRType write_type = type;
			write_type.vecsize = 1;
			write_type.columns = 1;

			// Storing a column to a row-major matrix. Unroll the write.
			for (uint32_t c = 0; c < type.vecsize; c++)
			{
				auto lhs_expr = to_enclosed_expression(lhs_expression);
				// Insert the component index (and the closing parenthesis of the cast) ahead of the
				// last subscript. If the LHS string contains no '[', nothing is emitted.
				auto column_index = lhs_expr.find_last_of('[');
				if (column_index != string::npos)
				{
					statement("((device ", type_to_glsl(write_type), "*)&",
					          lhs_expr.insert(column_index, join('[', c, ']', ")")), " = ",
					          to_extract_component_expression(rhs_expression, c), ";");
				}
			}

			lhs_e->need_transpose = true;
		}
		else if ((is_matrix(physical_type) || is_array(physical_type)) && physical_type.vecsize > type.vecsize)
		{
			assert(type.vecsize >= 1 && type.vecsize <= 3);

			// If we have packed types, we cannot use swizzled stores.
			// We could technically unroll the store for each element if needed.
			// When remapping to a std140 physical type, we always get float4,
			// and the packed decoration should always be removed.
			assert(!lhs_packed_type);

			string lhs = to_dereferenced_expression(lhs_expression);
			string rhs = to_pointer_expression(rhs_expression);

			// Unpack the expression so we can store to it with a float or float2.
			// It's still an l-value, so it's fine. Most other unpacking of expressions turn them into r-values instead.
			lhs = join("(device ", type_to_glsl(type), "&)", enclose_expression(lhs));
			// Fall back to a plain assignment when the read-modify-write form is not used.
			if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
				statement(lhs, " = ", rhs, ";");
		}
		else if (!is_matrix(type))
		{
			string lhs = to_dereferenced_expression(lhs_expression);
			string rhs = to_pointer_expression(rhs_expression);
			if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
				statement(lhs, " = ", rhs, ";");
		}

		register_write(lhs_expression);
	}
}
3801
// Returns true if `expr_str` ends with the exact character sequence `ending`.
// An empty `ending` matches every string.
static bool expression_ends_with(const string &expr_str, const std::string &ending)
{
	const size_t expr_len = expr_str.length();
	const size_t ending_len = ending.length();

	// A suffix cannot be longer than the string it is supposed to terminate.
	if (ending_len > expr_len)
		return false;

	return expr_str.compare(expr_len - ending_len, ending_len, ending) == 0;
}
3809
3810 // Converts the format of the current expression from packed to unpacked,
3811 // by wrapping the expression in a constructor of the appropriate type.
3812 // Also, handle special physical ID remapping scenarios, similar to emit_store_statement().
unpack_expression_type(string expr_str,const SPIRType & type,uint32_t physical_type_id,bool packed,bool row_major)3813 string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type, uint32_t physical_type_id,
3814 bool packed, bool row_major)
3815 {
3816 // Trivial case, nothing to do.
3817 if (physical_type_id == 0 && !packed)
3818 return expr_str;
3819
3820 const SPIRType *physical_type = nullptr;
3821 if (physical_type_id)
3822 physical_type = &get<SPIRType>(physical_type_id);
3823
3824 static const char *swizzle_lut[] = {
3825 ".x",
3826 ".xy",
3827 ".xyz",
3828 };
3829
3830 if (physical_type && is_vector(*physical_type) && is_array(*physical_type) &&
3831 physical_type->vecsize > type.vecsize && !expression_ends_with(expr_str, swizzle_lut[type.vecsize - 1]))
3832 {
3833 // std140 array cases for vectors.
3834 assert(type.vecsize >= 1 && type.vecsize <= 3);
3835 return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
3836 }
3837 else if (physical_type && is_matrix(*physical_type) && is_vector(type) && physical_type->vecsize > type.vecsize)
3838 {
3839 // Extract column from padded matrix.
3840 assert(type.vecsize >= 1 && type.vecsize <= 3);
3841 return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
3842 }
3843 else if (is_matrix(type))
3844 {
3845 // Packed matrices are stored as arrays of packed vectors. Unfortunately,
3846 // we can't just pass the array straight to the matrix constructor. We have to
3847 // pass each vector individually, so that they can be unpacked to normal vectors.
3848 if (!physical_type)
3849 physical_type = &type;
3850
3851 uint32_t vecsize = type.vecsize;
3852 uint32_t columns = type.columns;
3853 if (row_major)
3854 swap(vecsize, columns);
3855
3856 uint32_t physical_vecsize = row_major ? physical_type->columns : physical_type->vecsize;
3857
3858 const char *base_type = type.width == 16 ? "half" : "float";
3859 string unpack_expr = join(base_type, columns, "x", vecsize, "(");
3860
3861 const char *load_swiz = "";
3862
3863 if (physical_vecsize != vecsize)
3864 load_swiz = swizzle_lut[vecsize - 1];
3865
3866 for (uint32_t i = 0; i < columns; i++)
3867 {
3868 if (i > 0)
3869 unpack_expr += ", ";
3870
3871 if (packed)
3872 unpack_expr += join(base_type, physical_vecsize, "(", expr_str, "[", i, "]", ")", load_swiz);
3873 else
3874 unpack_expr += join(expr_str, "[", i, "]", load_swiz);
3875 }
3876
3877 unpack_expr += ")";
3878 return unpack_expr;
3879 }
3880 else
3881 {
3882 return join(type_to_glsl(type), "(", expr_str, ")");
3883 }
3884 }
3885
3886 // Emits the file header info
emit_header()3887 void CompilerMSL::emit_header()
3888 {
3889 // This particular line can be overridden during compilation, so make it a flag and not a pragma line.
3890 if (suppress_missing_prototypes)
3891 statement("#pragma clang diagnostic ignored \"-Wmissing-prototypes\"");
3892
3893 // Disable warning about missing braces for array<T> template to make arrays a value type
3894 if (spv_function_implementations.count(SPVFuncImplUnsafeArray) != 0)
3895 statement("#pragma clang diagnostic ignored \"-Wmissing-braces\"");
3896
3897 for (auto &pragma : pragma_lines)
3898 statement(pragma);
3899
3900 if (!pragma_lines.empty() || suppress_missing_prototypes)
3901 statement("");
3902
3903 statement("#include <metal_stdlib>");
3904 statement("#include <simd/simd.h>");
3905
3906 for (auto &header : header_lines)
3907 statement(header);
3908
3909 statement("");
3910 statement("using namespace metal;");
3911 statement("");
3912
3913 for (auto &td : typedef_lines)
3914 statement(td);
3915
3916 if (!typedef_lines.empty())
3917 statement("");
3918 }
3919
add_pragma_line(const string & line)3920 void CompilerMSL::add_pragma_line(const string &line)
3921 {
3922 auto rslt = pragma_lines.insert(line);
3923 if (rslt.second)
3924 force_recompile();
3925 }
3926
add_typedef_line(const string & line)3927 void CompilerMSL::add_typedef_line(const string &line)
3928 {
3929 auto rslt = typedef_lines.insert(line);
3930 if (rslt.second)
3931 force_recompile();
3932 }
3933
3934 // Template struct like spvUnsafeArray<> need to be declared *before* any resources are declared
emit_custom_templates()3935 void CompilerMSL::emit_custom_templates()
3936 {
3937 for (const auto &spv_func : spv_function_implementations)
3938 {
3939 switch (spv_func)
3940 {
3941 case SPVFuncImplUnsafeArray:
3942 statement("template<typename T, size_t Num>");
3943 statement("struct spvUnsafeArray");
3944 begin_scope();
3945 statement("T elements[Num ? Num : 1];");
3946 statement("");
3947 statement("thread T& operator [] (size_t pos) thread");
3948 begin_scope();
3949 statement("return elements[pos];");
3950 end_scope();
3951 statement("constexpr const thread T& operator [] (size_t pos) const thread");
3952 begin_scope();
3953 statement("return elements[pos];");
3954 end_scope();
3955 statement("");
3956 statement("device T& operator [] (size_t pos) device");
3957 begin_scope();
3958 statement("return elements[pos];");
3959 end_scope();
3960 statement("constexpr const device T& operator [] (size_t pos) const device");
3961 begin_scope();
3962 statement("return elements[pos];");
3963 end_scope();
3964 statement("");
3965 statement("constexpr const constant T& operator [] (size_t pos) const constant");
3966 begin_scope();
3967 statement("return elements[pos];");
3968 end_scope();
3969 statement("");
3970 statement("threadgroup T& operator [] (size_t pos) threadgroup");
3971 begin_scope();
3972 statement("return elements[pos];");
3973 end_scope();
3974 statement("constexpr const threadgroup T& operator [] (size_t pos) const threadgroup");
3975 begin_scope();
3976 statement("return elements[pos];");
3977 end_scope();
3978 end_scope_decl();
3979 statement("");
3980 break;
3981
3982 default:
3983 break;
3984 }
3985 }
3986 }
3987
3988 // Emits any needed custom function bodies.
3989 // Metal helper functions must be static force-inline, i.e. static inline __attribute__((always_inline))
3990 // otherwise they will cause problems when linked together in a single Metallib.
emit_custom_functions()3991 void CompilerMSL::emit_custom_functions()
3992 {
3993 for (uint32_t i = kArrayCopyMultidimMax; i >= 2; i--)
3994 if (spv_function_implementations.count(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i)))
3995 spv_function_implementations.insert(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i - 1));
3996
3997 if (spv_function_implementations.count(SPVFuncImplDynamicImageSampler))
3998 {
3999 // Unfortunately, this one needs a lot of the other functions to compile OK.
4000 if (!msl_options.supports_msl_version(2))
4001 SPIRV_CROSS_THROW(
4002 "spvDynamicImageSampler requires default-constructible texture objects, which require MSL 2.0.");
4003 spv_function_implementations.insert(SPVFuncImplForwardArgs);
4004 spv_function_implementations.insert(SPVFuncImplTextureSwizzle);
4005 if (msl_options.swizzle_texture_samples)
4006 spv_function_implementations.insert(SPVFuncImplGatherSwizzle);
4007 for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
4008 i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
4009 spv_function_implementations.insert(static_cast<SPVFuncImpl>(i));
4010 spv_function_implementations.insert(SPVFuncImplExpandITUFullRange);
4011 spv_function_implementations.insert(SPVFuncImplExpandITUNarrowRange);
4012 spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT709);
4013 spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT601);
4014 spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT2020);
4015 }
4016
4017 for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
4018 i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
4019 if (spv_function_implementations.count(static_cast<SPVFuncImpl>(i)))
4020 spv_function_implementations.insert(SPVFuncImplForwardArgs);
4021
4022 if (spv_function_implementations.count(SPVFuncImplTextureSwizzle) ||
4023 spv_function_implementations.count(SPVFuncImplGatherSwizzle) ||
4024 spv_function_implementations.count(SPVFuncImplGatherCompareSwizzle))
4025 {
4026 spv_function_implementations.insert(SPVFuncImplForwardArgs);
4027 spv_function_implementations.insert(SPVFuncImplGetSwizzle);
4028 }
4029
4030 for (const auto &spv_func : spv_function_implementations)
4031 {
4032 switch (spv_func)
4033 {
4034 case SPVFuncImplMod:
4035 statement("// Implementation of the GLSL mod() function, which is slightly different than Metal fmod()");
4036 statement("template<typename Tx, typename Ty>");
4037 statement("inline Tx mod(Tx x, Ty y)");
4038 begin_scope();
4039 statement("return x - y * floor(x / y);");
4040 end_scope();
4041 statement("");
4042 break;
4043
4044 case SPVFuncImplRadians:
4045 statement("// Implementation of the GLSL radians() function");
4046 statement("template<typename T>");
4047 statement("inline T radians(T d)");
4048 begin_scope();
4049 statement("return d * T(0.01745329251);");
4050 end_scope();
4051 statement("");
4052 break;
4053
4054 case SPVFuncImplDegrees:
4055 statement("// Implementation of the GLSL degrees() function");
4056 statement("template<typename T>");
4057 statement("inline T degrees(T r)");
4058 begin_scope();
4059 statement("return r * T(57.2957795131);");
4060 end_scope();
4061 statement("");
4062 break;
4063
4064 case SPVFuncImplFindILsb:
4065 statement("// Implementation of the GLSL findLSB() function");
4066 statement("template<typename T>");
4067 statement("inline T spvFindLSB(T x)");
4068 begin_scope();
4069 statement("return select(ctz(x), T(-1), x == T(0));");
4070 end_scope();
4071 statement("");
4072 break;
4073
4074 case SPVFuncImplFindUMsb:
4075 statement("// Implementation of the unsigned GLSL findMSB() function");
4076 statement("template<typename T>");
4077 statement("inline T spvFindUMSB(T x)");
4078 begin_scope();
4079 statement("return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));");
4080 end_scope();
4081 statement("");
4082 break;
4083
4084 case SPVFuncImplFindSMsb:
4085 statement("// Implementation of the signed GLSL findMSB() function");
4086 statement("template<typename T>");
4087 statement("inline T spvFindSMSB(T x)");
4088 begin_scope();
4089 statement("T v = select(x, T(-1) - x, x < T(0));");
4090 statement("return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));");
4091 end_scope();
4092 statement("");
4093 break;
4094
4095 case SPVFuncImplSSign:
4096 statement("// Implementation of the GLSL sign() function for integer types");
4097 statement("template<typename T, typename E = typename enable_if<is_integral<T>::value>::type>");
4098 statement("inline T sign(T x)");
4099 begin_scope();
4100 statement("return select(select(select(x, T(0), x == T(0)), T(1), x > T(0)), T(-1), x < T(0));");
4101 end_scope();
4102 statement("");
4103 break;
4104
4105 case SPVFuncImplArrayCopy:
4106 case SPVFuncImplArrayOfArrayCopy2Dim:
4107 case SPVFuncImplArrayOfArrayCopy3Dim:
4108 case SPVFuncImplArrayOfArrayCopy4Dim:
4109 case SPVFuncImplArrayOfArrayCopy5Dim:
4110 case SPVFuncImplArrayOfArrayCopy6Dim:
4111 {
4112 // Unfortunately we cannot template on the address space, so combinatorial explosion it is.
4113 static const char *function_name_tags[] = {
4114 "FromConstantToStack", "FromConstantToThreadGroup", "FromStackToStack",
4115 "FromStackToThreadGroup", "FromThreadGroupToStack", "FromThreadGroupToThreadGroup",
4116 "FromDeviceToDevice", "FromConstantToDevice", "FromStackToDevice",
4117 "FromThreadGroupToDevice", "FromDeviceToStack", "FromDeviceToThreadGroup",
4118 };
4119
4120 static const char *src_address_space[] = {
4121 "constant", "constant", "thread const", "thread const",
4122 "threadgroup const", "threadgroup const", "device const", "constant",
4123 "thread const", "threadgroup const", "device const", "device const",
4124 };
4125
4126 static const char *dst_address_space[] = {
4127 "thread", "threadgroup", "thread", "threadgroup", "thread", "threadgroup",
4128 "device", "device", "device", "device", "thread", "threadgroup",
4129 };
4130
4131 for (uint32_t variant = 0; variant < 12; variant++)
4132 {
4133 uint32_t dimensions = spv_func - SPVFuncImplArrayCopyMultidimBase;
4134 string tmp = "template<typename T";
4135 for (uint8_t i = 0; i < dimensions; i++)
4136 {
4137 tmp += ", uint ";
4138 tmp += 'A' + i;
4139 }
4140 tmp += ">";
4141 statement(tmp);
4142
4143 string array_arg;
4144 for (uint8_t i = 0; i < dimensions; i++)
4145 {
4146 array_arg += "[";
4147 array_arg += 'A' + i;
4148 array_arg += "]";
4149 }
4150
4151 statement("inline void spvArrayCopy", function_name_tags[variant], dimensions, "(",
4152 dst_address_space[variant], " T (&dst)", array_arg, ", ", src_address_space[variant],
4153 " T (&src)", array_arg, ")");
4154
4155 begin_scope();
4156 statement("for (uint i = 0; i < A; i++)");
4157 begin_scope();
4158
4159 if (dimensions == 1)
4160 statement("dst[i] = src[i];");
4161 else
4162 statement("spvArrayCopy", function_name_tags[variant], dimensions - 1, "(dst[i], src[i]);");
4163 end_scope();
4164 end_scope();
4165 statement("");
4166 }
4167 break;
4168 }
4169
4170 // Support for Metal 2.1's new texture_buffer type.
4171 case SPVFuncImplTexelBufferCoords:
4172 {
4173 if (msl_options.texel_buffer_texture_width > 0)
4174 {
4175 string tex_width_str = convert_to_string(msl_options.texel_buffer_texture_width);
4176 statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
4177 statement(force_inline);
4178 statement("uint2 spvTexelBufferCoord(uint tc)");
4179 begin_scope();
4180 statement(join("return uint2(tc % ", tex_width_str, ", tc / ", tex_width_str, ");"));
4181 end_scope();
4182 statement("");
4183 }
4184 else
4185 {
4186 statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
4187 statement(
4188 "#define spvTexelBufferCoord(tc, tex) uint2((tc) % (tex).get_width(), (tc) / (tex).get_width())");
4189 statement("");
4190 }
4191 break;
4192 }
4193
4194 // Emulate texture2D atomic operations
4195 case SPVFuncImplImage2DAtomicCoords:
4196 {
4197 statement("// Returns buffer coords corresponding to 2D texture coords for emulating 2D texture atomics");
4198 statement("#define spvImage2DAtomicCoord(tc, tex) (((tex).get_width() * (tc).x) + (tc).y)");
4199 statement("");
4200 break;
4201 }
4202
4203 // "fadd" intrinsic support
4204 case SPVFuncImplFAdd:
4205 statement("template<typename T>");
4206 statement("T spvFAdd(T l, T r)");
4207 begin_scope();
4208 statement("return fma(T(1), l, r);");
4209 end_scope();
4210 statement("");
4211 break;
4212
4213 // "fmul" intrinsic support
4214 case SPVFuncImplFMul:
4215 statement("template<typename T>");
4216 statement("T spvFMul(T l, T r)");
4217 begin_scope();
4218 statement("return fma(l, r, T(0));");
4219 end_scope();
4220 statement("");
4221
4222 statement("template<typename T, int Cols, int Rows>");
4223 statement("vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)");
4224 begin_scope();
4225 statement("vec<T, Cols> res = vec<T, Cols>(0);");
4226 statement("for (uint i = Rows; i > 0; --i)");
4227 begin_scope();
4228 statement("vec<T, Cols> tmp(0);");
4229 statement("for (uint j = 0; j < Cols; ++j)");
4230 begin_scope();
4231 statement("tmp[j] = m[j][i - 1];");
4232 end_scope();
4233 statement("res = fma(tmp, vec<T, Cols>(v[i - 1]), res);");
4234 end_scope();
4235 statement("return res;");
4236 end_scope();
4237 statement("");
4238
4239 statement("template<typename T, int Cols, int Rows>");
4240 statement("vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)");
4241 begin_scope();
4242 statement("vec<T, Rows> res = vec<T, Rows>(0);");
4243 statement("for (uint i = Cols; i > 0; --i)");
4244 begin_scope();
4245 statement("res = fma(m[i - 1], vec<T, Rows>(v[i - 1]), res);");
4246 end_scope();
4247 statement("return res;");
4248 end_scope();
4249 statement("");
4250
4251 statement("template<typename T, int LCols, int LRows, int RCols, int RRows>");
4252 statement(
4253 "matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)");
4254 begin_scope();
4255 statement("matrix<T, RCols, LRows> res;");
4256 statement("for (uint i = 0; i < RCols; i++)");
4257 begin_scope();
4258 statement("vec<T, RCols> tmp(0);");
4259 statement("for (uint j = 0; j < LCols; j++)");
4260 begin_scope();
4261 statement("tmp = fma(vec<T, RCols>(r[i][j]), l[j], tmp);");
4262 end_scope();
4263 statement("res[i] = tmp;");
4264 end_scope();
4265 statement("return res;");
4266 end_scope();
4267 statement("");
4268 break;
4269
4270 // Emulate texturecube_array with texture2d_array for iOS where this type is not available
4271 case SPVFuncImplCubemapTo2DArrayFace:
4272 statement(force_inline);
4273 statement("float3 spvCubemapTo2DArrayFace(float3 P)");
4274 begin_scope();
4275 statement("float3 Coords = abs(P.xyz);");
4276 statement("float CubeFace = 0;");
4277 statement("float ProjectionAxis = 0;");
4278 statement("float u = 0;");
4279 statement("float v = 0;");
4280 statement("if (Coords.x >= Coords.y && Coords.x >= Coords.z)");
4281 begin_scope();
4282 statement("CubeFace = P.x >= 0 ? 0 : 1;");
4283 statement("ProjectionAxis = Coords.x;");
4284 statement("u = P.x >= 0 ? -P.z : P.z;");
4285 statement("v = -P.y;");
4286 end_scope();
4287 statement("else if (Coords.y >= Coords.x && Coords.y >= Coords.z)");
4288 begin_scope();
4289 statement("CubeFace = P.y >= 0 ? 2 : 3;");
4290 statement("ProjectionAxis = Coords.y;");
4291 statement("u = P.x;");
4292 statement("v = P.y >= 0 ? P.z : -P.z;");
4293 end_scope();
4294 statement("else");
4295 begin_scope();
4296 statement("CubeFace = P.z >= 0 ? 4 : 5;");
4297 statement("ProjectionAxis = Coords.z;");
4298 statement("u = P.z >= 0 ? P.x : -P.x;");
4299 statement("v = -P.y;");
4300 end_scope();
4301 statement("u = 0.5 * (u/ProjectionAxis + 1);");
4302 statement("v = 0.5 * (v/ProjectionAxis + 1);");
4303 statement("return float3(u, v, CubeFace);");
4304 end_scope();
4305 statement("");
4306 break;
4307
4308 case SPVFuncImplInverse4x4:
4309 statement("// Returns the determinant of a 2x2 matrix.");
4310 statement(force_inline);
4311 statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
4312 begin_scope();
4313 statement("return a1 * b2 - b1 * a2;");
4314 end_scope();
4315 statement("");
4316
4317 statement("// Returns the determinant of a 3x3 matrix.");
4318 statement(force_inline);
4319 statement("float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
4320 "float c2, float c3)");
4321 begin_scope();
4322 statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * spvDet2x2(a2, a3, "
4323 "b2, b3);");
4324 end_scope();
4325 statement("");
4326 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
4327 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
4328 statement(force_inline);
4329 statement("float4x4 spvInverse4x4(float4x4 m)");
4330 begin_scope();
4331 statement("float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)");
4332 statement_no_indent("");
4333 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
4334 statement("adj[0][0] = spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
4335 "m[3][3]);");
4336 statement("adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
4337 "m[3][3]);");
4338 statement("adj[0][2] = spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
4339 "m[3][3]);");
4340 statement("adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
4341 "m[2][3]);");
4342 statement_no_indent("");
4343 statement("adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
4344 "m[3][3]);");
4345 statement("adj[1][1] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
4346 "m[3][3]);");
4347 statement("adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
4348 "m[3][3]);");
4349 statement("adj[1][3] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
4350 "m[2][3]);");
4351 statement_no_indent("");
4352 statement("adj[2][0] = spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
4353 "m[3][3]);");
4354 statement("adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
4355 "m[3][3]);");
4356 statement("adj[2][2] = spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
4357 "m[3][3]);");
4358 statement("adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
4359 "m[2][3]);");
4360 statement_no_indent("");
4361 statement("adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
4362 "m[3][2]);");
4363 statement("adj[3][1] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
4364 "m[3][2]);");
4365 statement("adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
4366 "m[3][2]);");
4367 statement("adj[3][3] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
4368 "m[2][2]);");
4369 statement_no_indent("");
4370 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
4371 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
4372 "* m[3][0]);");
4373 statement_no_indent("");
4374 statement("// Divide the classical adjoint matrix by the determinant.");
4375 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
4376 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
4377 end_scope();
4378 statement("");
4379 break;
4380
4381 case SPVFuncImplInverse3x3:
4382 if (spv_function_implementations.count(SPVFuncImplInverse4x4) == 0)
4383 {
4384 statement("// Returns the determinant of a 2x2 matrix.");
4385 statement(force_inline);
4386 statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
4387 begin_scope();
4388 statement("return a1 * b2 - b1 * a2;");
4389 end_scope();
4390 statement("");
4391 }
4392
4393 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
4394 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
4395 statement(force_inline);
4396 statement("float3x3 spvInverse3x3(float3x3 m)");
4397 begin_scope();
4398 statement("float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)");
4399 statement_no_indent("");
4400 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
4401 statement("adj[0][0] = spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
4402 statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
4403 statement("adj[0][2] = spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
4404 statement_no_indent("");
4405 statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
4406 statement("adj[1][1] = spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
4407 statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
4408 statement_no_indent("");
4409 statement("adj[2][0] = spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
4410 statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
4411 statement("adj[2][2] = spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
4412 statement_no_indent("");
4413 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
4414 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
4415 statement_no_indent("");
4416 statement("// Divide the classical adjoint matrix by the determinant.");
4417 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
4418 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
4419 end_scope();
4420 statement("");
4421 break;
4422
4423 case SPVFuncImplInverse2x2:
4424 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
4425 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
4426 statement(force_inline);
4427 statement("float2x2 spvInverse2x2(float2x2 m)");
4428 begin_scope();
4429 statement("float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)");
4430 statement_no_indent("");
4431 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
4432 statement("adj[0][0] = m[1][1];");
4433 statement("adj[0][1] = -m[0][1];");
4434 statement_no_indent("");
4435 statement("adj[1][0] = -m[1][0];");
4436 statement("adj[1][1] = m[0][0];");
4437 statement_no_indent("");
4438 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
4439 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
4440 statement_no_indent("");
4441 statement("// Divide the classical adjoint matrix by the determinant.");
4442 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
4443 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
4444 end_scope();
4445 statement("");
4446 break;
4447
4448 case SPVFuncImplForwardArgs:
4449 statement("template<typename T> struct spvRemoveReference { typedef T type; };");
4450 statement("template<typename T> struct spvRemoveReference<thread T&> { typedef T type; };");
4451 statement("template<typename T> struct spvRemoveReference<thread T&&> { typedef T type; };");
4452 statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
4453 "spvRemoveReference<T>::type& x)");
4454 begin_scope();
4455 statement("return static_cast<thread T&&>(x);");
4456 end_scope();
4457 statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
4458 "spvRemoveReference<T>::type&& x)");
4459 begin_scope();
4460 statement("return static_cast<thread T&&>(x);");
4461 end_scope();
4462 statement("");
4463 break;
4464
4465 case SPVFuncImplGetSwizzle:
4466 statement("enum class spvSwizzle : uint");
4467 begin_scope();
4468 statement("none = 0,");
4469 statement("zero,");
4470 statement("one,");
4471 statement("red,");
4472 statement("green,");
4473 statement("blue,");
4474 statement("alpha");
4475 end_scope_decl();
4476 statement("");
4477 statement("template<typename T>");
4478 statement("inline T spvGetSwizzle(vec<T, 4> x, T c, spvSwizzle s)");
4479 begin_scope();
4480 statement("switch (s)");
4481 begin_scope();
4482 statement("case spvSwizzle::none:");
4483 statement(" return c;");
4484 statement("case spvSwizzle::zero:");
4485 statement(" return 0;");
4486 statement("case spvSwizzle::one:");
4487 statement(" return 1;");
4488 statement("case spvSwizzle::red:");
4489 statement(" return x.r;");
4490 statement("case spvSwizzle::green:");
4491 statement(" return x.g;");
4492 statement("case spvSwizzle::blue:");
4493 statement(" return x.b;");
4494 statement("case spvSwizzle::alpha:");
4495 statement(" return x.a;");
4496 end_scope();
4497 end_scope();
4498 statement("");
4499 break;
4500
4501 case SPVFuncImplTextureSwizzle:
4502 statement("// Wrapper function that swizzles texture samples and fetches.");
4503 statement("template<typename T>");
4504 statement("inline vec<T, 4> spvTextureSwizzle(vec<T, 4> x, uint s)");
4505 begin_scope();
4506 statement("if (!s)");
4507 statement(" return x;");
4508 statement("return vec<T, 4>(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), "
4509 "spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) "
4510 "& 0xFF)), "
4511 "spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));");
4512 end_scope();
4513 statement("");
4514 statement("template<typename T>");
4515 statement("inline T spvTextureSwizzle(T x, uint s)");
4516 begin_scope();
4517 statement("return spvTextureSwizzle(vec<T, 4>(x, 0, 0, 1), s).x;");
4518 end_scope();
4519 statement("");
4520 break;
4521
4522 case SPVFuncImplGatherSwizzle:
4523 statement("// Wrapper function that swizzles texture gathers.");
4524 statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
4525 "typename... Ts>");
4526 statement("inline vec<T, 4> spvGatherSwizzle(const thread Tex<T>& t, sampler s, "
4527 "uint sw, component c, Ts... params) METAL_CONST_ARG(c)");
4528 begin_scope();
4529 statement("if (sw)");
4530 begin_scope();
4531 statement("switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))");
4532 begin_scope();
4533 statement("case spvSwizzle::none:");
4534 statement(" break;");
4535 statement("case spvSwizzle::zero:");
4536 statement(" return vec<T, 4>(0, 0, 0, 0);");
4537 statement("case spvSwizzle::one:");
4538 statement(" return vec<T, 4>(1, 1, 1, 1);");
4539 statement("case spvSwizzle::red:");
4540 statement(" return t.gather(s, spvForward<Ts>(params)..., component::x);");
4541 statement("case spvSwizzle::green:");
4542 statement(" return t.gather(s, spvForward<Ts>(params)..., component::y);");
4543 statement("case spvSwizzle::blue:");
4544 statement(" return t.gather(s, spvForward<Ts>(params)..., component::z);");
4545 statement("case spvSwizzle::alpha:");
4546 statement(" return t.gather(s, spvForward<Ts>(params)..., component::w);");
4547 end_scope();
4548 end_scope();
4549 // texture::gather insists on its component parameter being a constant
4550 // expression, so we need this silly workaround just to compile the shader.
4551 statement("switch (c)");
4552 begin_scope();
4553 statement("case component::x:");
4554 statement(" return t.gather(s, spvForward<Ts>(params)..., component::x);");
4555 statement("case component::y:");
4556 statement(" return t.gather(s, spvForward<Ts>(params)..., component::y);");
4557 statement("case component::z:");
4558 statement(" return t.gather(s, spvForward<Ts>(params)..., component::z);");
4559 statement("case component::w:");
4560 statement(" return t.gather(s, spvForward<Ts>(params)..., component::w);");
4561 end_scope();
4562 end_scope();
4563 statement("");
4564 break;
4565
4566 case SPVFuncImplGatherCompareSwizzle:
4567 statement("// Wrapper function that swizzles depth texture gathers.");
4568 statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
4569 "typename... Ts>");
4570 statement("inline vec<T, 4> spvGatherCompareSwizzle(const thread Tex<T>& t, sampler "
4571 "s, uint sw, Ts... params) ");
4572 begin_scope();
4573 statement("if (sw)");
4574 begin_scope();
4575 statement("switch (spvSwizzle(sw & 0xFF))");
4576 begin_scope();
4577 statement("case spvSwizzle::none:");
4578 statement("case spvSwizzle::red:");
4579 statement(" break;");
4580 statement("case spvSwizzle::zero:");
4581 statement("case spvSwizzle::green:");
4582 statement("case spvSwizzle::blue:");
4583 statement("case spvSwizzle::alpha:");
4584 statement(" return vec<T, 4>(0, 0, 0, 0);");
4585 statement("case spvSwizzle::one:");
4586 statement(" return vec<T, 4>(1, 1, 1, 1);");
4587 end_scope();
4588 end_scope();
4589 statement("return t.gather_compare(s, spvForward<Ts>(params)...);");
4590 end_scope();
4591 statement("");
4592 break;
4593
4594 case SPVFuncImplSubgroupBallot:
4595 statement("inline uint4 spvSubgroupBallot(bool value)");
4596 begin_scope();
4597 statement("simd_vote vote = simd_ballot(value);");
4598 statement("// simd_ballot() returns a 64-bit integer-like object, but");
4599 statement("// SPIR-V callers expect a uint4. We must convert.");
4600 statement("// FIXME: This won't include higher bits if Apple ever supports");
4601 statement("// 128 lanes in an SIMD-group.");
4602 statement("return uint4((uint)((simd_vote::vote_t)vote & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)vote >> "
4603 "32) & 0xFFFFFFFF), 0, 0);");
4604 end_scope();
4605 statement("");
4606 break;
4607
4608 case SPVFuncImplSubgroupBallotBitExtract:
4609 statement("inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)");
4610 begin_scope();
4611 statement("return !!extract_bits(ballot[bit / 32], bit % 32, 1);");
4612 end_scope();
4613 statement("");
4614 break;
4615
4616 case SPVFuncImplSubgroupBallotFindLSB:
4617 statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot)");
4618 begin_scope();
4619 statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
4620 "ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
4621 end_scope();
4622 statement("");
4623 break;
4624
4625 case SPVFuncImplSubgroupBallotFindMSB:
4626 statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot)");
4627 begin_scope();
4628 statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
4629 "(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
4630 "ballot.z == 0), ballot.w == 0);");
4631 end_scope();
4632 statement("");
4633 break;
4634
4635 case SPVFuncImplSubgroupBallotBitCount:
4636 statement("inline uint spvSubgroupBallotBitCount(uint4 ballot)");
4637 begin_scope();
4638 statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
4639 end_scope();
4640 statement("");
4641 statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
4642 begin_scope();
4643 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
4644 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
4645 "uint2(0));");
4646 statement("return spvSubgroupBallotBitCount(ballot & mask);");
4647 end_scope();
4648 statement("");
4649 statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
4650 begin_scope();
4651 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
4652 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
4653 statement("return spvSubgroupBallotBitCount(ballot & mask);");
4654 end_scope();
4655 statement("");
4656 break;
4657
4658 case SPVFuncImplSubgroupAllEqual:
4659 // Metal doesn't provide a function to evaluate this directly. But, we can
4660 // implement this by comparing every thread's value to one thread's value
4661 // (in this case, the value of the first active thread). Then, by the transitive
4662 // property of equality, if all comparisons return true, then they are all equal.
4663 statement("template<typename T>");
4664 statement("inline bool spvSubgroupAllEqual(T value)");
4665 begin_scope();
4666 statement("return simd_all(value == simd_broadcast_first(value));");
4667 end_scope();
4668 statement("");
4669 statement("template<>");
4670 statement("inline bool spvSubgroupAllEqual(bool value)");
4671 begin_scope();
4672 statement("return simd_all(value) || !simd_any(value);");
4673 end_scope();
4674 statement("");
4675 break;
4676
4677 case SPVFuncImplReflectScalar:
4678 // Metal does not support scalar versions of these functions.
4679 statement("template<typename T>");
4680 statement("inline T spvReflect(T i, T n)");
4681 begin_scope();
4682 statement("return i - T(2) * i * n * n;");
4683 end_scope();
4684 statement("");
4685 break;
4686
4687 case SPVFuncImplRefractScalar:
4688 // Metal does not support scalar versions of these functions.
4689 statement("template<typename T>");
4690 statement("inline T spvRefract(T i, T n, T eta)");
4691 begin_scope();
4692 statement("T NoI = n * i;");
4693 statement("T NoI2 = NoI * NoI;");
4694 statement("T k = T(1) - eta * eta * (T(1) - NoI2);");
4695 statement("if (k < T(0))");
4696 begin_scope();
4697 statement("return T(0);");
4698 end_scope();
4699 statement("else");
4700 begin_scope();
4701 statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
4702 end_scope();
4703 end_scope();
4704 statement("");
4705 break;
4706
4707 case SPVFuncImplFaceForwardScalar:
4708 // Metal does not support scalar versions of these functions.
4709 statement("template<typename T>");
4710 statement("inline T spvFaceForward(T n, T i, T nref)");
4711 begin_scope();
4712 statement("return i * nref < T(0) ? n : -n;");
4713 end_scope();
4714 statement("");
4715 break;
4716
4717 case SPVFuncImplChromaReconstructNearest2Plane:
4718 statement("template<typename T, typename... LodOptions>");
4719 statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, sampler "
4720 "samp, float2 coord, LodOptions... options)");
4721 begin_scope();
4722 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4723 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4724 statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
4725 statement("return ycbcr;");
4726 end_scope();
4727 statement("");
4728 break;
4729
4730 case SPVFuncImplChromaReconstructNearest3Plane:
4731 statement("template<typename T, typename... LodOptions>");
4732 statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, "
4733 "texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
4734 begin_scope();
4735 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4736 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4737 statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4738 statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4739 statement("return ycbcr;");
4740 end_scope();
4741 statement("");
4742 break;
4743
// Emits the two-plane overload of spvChromaReconstructLinear422CositedEven.
// .g always comes from a plain plane0 sample. For .br the emitted code tests
// fract(coord.x * plane1.get_width()): when it is non-zero, the plane1 sample
// at coord is mixed 50/50 with the plane1 sample at texel offset int2(1, 0);
// when it is zero, the single plane1 sample at coord is used as-is.
4744 case SPVFuncImplChromaReconstructLinear422CositedEven2Plane:
4745 statement("template<typename T, typename... LodOptions>");
4746 statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
4747 "plane1, sampler samp, float2 coord, LodOptions... options)");
4748 begin_scope();
4749 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4750 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4751 statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
4752 begin_scope();
// The mix() result is explicitly converted back to vec<T, 2> before assignment.
4753 statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
4754 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).rg);");
4755 end_scope();
4756 statement("else");
4757 begin_scope();
4758 statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
4759 end_scope();
4760 statement("return ycbcr;");
4761 end_scope();
4762 statement("");
4763 break;
4764
// Emits the three-plane overload of spvChromaReconstructLinear422CositedEven.
// Same structure as the two-plane overload, except that .b is taken from the
// .r channel of plane1 and .r from the .r channel of plane2. The branch
// condition is computed from plane1's width only and is applied to both planes.
4765 case SPVFuncImplChromaReconstructLinear422CositedEven3Plane:
4766 statement("template<typename T, typename... LodOptions>");
4767 statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
4768 "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
4769 begin_scope();
4770 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4771 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4772 statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
4773 begin_scope();
// 50/50 mix of the sample at coord and the sample at texel offset int2(1, 0).
4774 statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
4775 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
4776 statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
4777 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
4778 end_scope();
4779 statement("else");
4780 begin_scope();
4781 statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4782 statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4783 end_scope();
4784 statement("return ycbcr;");
4785 end_scope();
4786 statement("");
4787 break;
4788
// Emits the two-plane overload of spvChromaReconstructLinear422Midpoint.
// The emitted code picks a horizontal texel offset `offs`: int2(1, 0) when
// fract(coord.x * plane1.get_width()) is non-zero, int2(-1, 0) otherwise.
// .br is then mix(plane1 at coord, plane1 at coord + offs, 0.25), i.e. the
// neighbouring sample gets a weight of 0.25. .g is a plain plane0 sample.
4789 case SPVFuncImplChromaReconstructLinear422Midpoint2Plane:
4790 statement("template<typename T, typename... LodOptions>");
4791 statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
4792 "plane1, sampler samp, float2 coord, LodOptions... options)");
4793 begin_scope();
4794 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4795 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4796 statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
4797 statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
4798 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).rg);");
4799 statement("return ycbcr;");
4800 end_scope();
4801 statement("");
4802 break;
4803
// Emits the three-plane overload of spvChromaReconstructLinear422Midpoint.
// Uses the same `offs` selection and 0.25 mix weight as the two-plane
// overload; .b is reconstructed from plane1's .r channel and .r from plane2's
// .r channel. `offs` is derived from plane1's width and reused for plane2.
4804 case SPVFuncImplChromaReconstructLinear422Midpoint3Plane:
4805 statement("template<typename T, typename... LodOptions>");
4806 statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
4807 "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
4808 begin_scope();
4809 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4810 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4811 statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
4812 statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
4813 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
4814 statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
4815 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
4816 statement("return ycbcr;");
4817 end_scope();
4818 statement("");
4819 break;
4820
// Emits the two-plane overload of
// spvChromaReconstructLinear420XCositedEvenYCositedEven.
// The emitted code computes weights
//   ab = fract(round(coord * plane0 dimensions) * 0.5)
// and blends four plane1 samples taken at texel offsets (0,0), (1,0), (0,1)
// and (1,1): each row pair is mixed by ab.x, then the two rows by ab.y. The
// .rg channels of the result become .br; .g is a plain plane0 sample.
4821 case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane:
4822 statement("template<typename T, typename... LodOptions>");
4823 statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
4824 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
4825 begin_scope();
4826 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4827 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4828 statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
4829 statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
4830 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
4831 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
4832 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
4833 statement("return ycbcr;");
4834 end_scope();
4835 statement("");
4836 break;
4837
// Emits the three-plane overload of
// spvChromaReconstructLinear420XCositedEvenYCositedEven.
// Weights are ab = fract(round(coord * plane0 dimensions) * 0.5). The same
// four-sample blend (offsets (0,0), (1,0), (0,1), (1,1); ab.x within a row,
// ab.y across rows) is applied separately to plane1 (-> .b) and plane2 (-> .r),
// keeping only the .r channel of each blended result.
4838 case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane:
4839 statement("template<typename T, typename... LodOptions>");
4840 statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
4841 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
4842 begin_scope();
4843 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4844 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4845 statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
4846 statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
4847 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
4848 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
4849 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
4850 statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
4851 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
4852 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
4853 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
4854 statement("return ycbcr;");
4855 end_scope();
4856 statement("");
4857 break;
4858
// Emits the two-plane overload of
// spvChromaReconstructLinear420XMidpointYCositedEven.
// Identical to the XCositedEvenYCositedEven helper except for the weights:
// float2(0.5, 0) is subtracted from the rounded texel coordinate before it is
// halved, i.e. only the x weight is shifted.
4859 case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane:
4860 statement("template<typename T, typename... LodOptions>");
4861 statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
4862 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
4863 begin_scope();
4864 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4865 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4866 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
4867 "0)) * 0.5);");
// Four plane1 samples: mixed by ab.x within each row, then by ab.y across rows.
4868 statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
4869 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
4870 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
4871 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
4872 statement("return ycbcr;");
4873 end_scope();
4874 statement("");
4875 break;
4876
// Emits the three-plane overload of
// spvChromaReconstructLinear420XMidpointYCositedEven.
// Weights: ab = fract((round(coord * plane0 dimensions) - float2(0.5, 0)) * 0.5).
// The four-sample blend is evaluated once on plane1 (-> .b) and once on
// plane2 (-> .r), keeping the .r channel of each result.
4877 case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane:
4878 statement("template<typename T, typename... LodOptions>");
4879 statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
4880 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
4881 begin_scope();
4882 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4883 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4884 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
4885 "0)) * 0.5);");
4886 statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
4887 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
4888 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
4889 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
4890 statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
4891 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
4892 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
4893 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
4894 statement("return ycbcr;");
4895 end_scope();
4896 statement("");
4897 break;
4898
// Emits the two-plane overload of
// spvChromaReconstructLinear420XCositedEvenYMidpoint.
// Identical to the XCositedEvenYCositedEven helper except for the weights:
// float2(0, 0.5) is subtracted from the rounded texel coordinate before it is
// halved, i.e. only the y weight is shifted.
4899 case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane:
4900 statement("template<typename T, typename... LodOptions>");
4901 statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
4902 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
4903 begin_scope();
4904 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4905 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4906 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
4907 "0.5)) * 0.5);");
// Four plane1 samples: mixed by ab.x within each row, then by ab.y across rows.
4908 statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
4909 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
4910 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
4911 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
4912 statement("return ycbcr;");
4913 end_scope();
4914 statement("");
4915 break;
4916
// Emits the three-plane overload of
// spvChromaReconstructLinear420XCositedEvenYMidpoint.
// Weights: ab = fract((round(coord * plane0 dimensions) - float2(0, 0.5)) * 0.5).
// The four-sample blend is evaluated once on plane1 (-> .b) and once on
// plane2 (-> .r), keeping the .r channel of each result.
4917 case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane:
4918 statement("template<typename T, typename... LodOptions>");
4919 statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
4920 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
4921 begin_scope();
4922 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4923 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4924 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
4925 "0.5)) * 0.5);");
4926 statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
4927 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
4928 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
4929 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
4930 statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
4931 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
4932 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
4933 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
4934 statement("return ycbcr;");
4935 end_scope();
4936 statement("");
4937 break;
4938
// Emits the two-plane overload of
// spvChromaReconstructLinear420XMidpointYMidpoint.
// Identical to the XCositedEvenYCositedEven helper except for the weights:
// float2(0.5, 0.5) is subtracted from the rounded texel coordinate before it
// is halved, i.e. both the x and y weights are shifted.
4939 case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane:
4940 statement("template<typename T, typename... LodOptions>");
4941 statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
4942 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
4943 begin_scope();
4944 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4945 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4946 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
4947 "0.5)) * 0.5);");
// Four plane1 samples: mixed by ab.x within each row, then by ab.y across rows.
4948 statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
4949 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
4950 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
4951 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
4952 statement("return ycbcr;");
4953 end_scope();
4954 statement("");
4955 break;
4956
// Emits the three-plane overload of
// spvChromaReconstructLinear420XMidpointYMidpoint.
// Weights: ab = fract((round(coord * plane0 dimensions) - float2(0.5, 0.5)) * 0.5).
// The four-sample blend is evaluated once on plane1 (-> .b) and once on
// plane2 (-> .r), keeping the .r channel of each result.
4957 case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane:
4958 statement("template<typename T, typename... LodOptions>");
4959 statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
4960 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
4961 begin_scope();
4962 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
4963 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
4964 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
4965 "0.5)) * 0.5);");
4966 statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
4967 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
4968 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
4969 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
4970 statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
4971 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
4972 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
4973 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
4974 statement("return ycbcr;");
4975 end_scope();
4976 statement("");
4977 break;
4978
// Emits the MSL helper spvExpandITUFullRange(ycbcr, n).
// The emitted function subtracts exp2(n-1) / (exp2(n) - 1) from the .b and .r
// channels and returns the vector; .g and .a are left untouched.
// NOTE(review): n is presumably the number of bits per component (the dynamic
// sampler emitted elsewhere in this switch passes ycbcr_samp.get_bpc()).
4979 case SPVFuncImplExpandITUFullRange:
4980 statement("template<typename T>");
4981 statement("inline vec<T, 4> spvExpandITUFullRange(vec<T, 4> ycbcr, int n)");
4982 begin_scope();
4983 statement("ycbcr.br -= exp2(T(n-1))/(exp2(T(n))-1);");
4984 statement("return ycbcr;");
4985 end_scope();
4986 statement("");
4987 break;
4988
// Emits the MSL helper spvExpandITUNarrowRange(ycbcr, n).
// With s = exp2(n) - 1 and k = 2^(n - 8) (written as ldexp(x, n - 8)), the
// emitted function rescales the channels as
//   g  = (g  * s -  16 * k) / (219 * k)
//   br = (br * s - 128 * k) / (224 * k)
// and returns the vector; .a is left untouched.
// NOTE(review): n is presumably the number of bits per component (the dynamic
// sampler emitted elsewhere in this switch passes ycbcr_samp.get_bpc()).
4989 case SPVFuncImplExpandITUNarrowRange:
4990 statement("template<typename T>");
4991 statement("inline vec<T, 4> spvExpandITUNarrowRange(vec<T, 4> ycbcr, int n)");
4992 begin_scope();
4993 statement("ycbcr.g = (ycbcr.g * (exp2(T(n)) - 1) - ldexp(T(16), n - 8))/ldexp(T(219), n - 8);");
4994 statement("ycbcr.br = (ycbcr.br * (exp2(T(n)) - 1) - ldexp(T(128), n - 8))/ldexp(T(224), n - 8);");
4995 statement("return ycbcr;");
4996 end_scope();
4997 statement("");
4998 break;
4999
// Emits the constant matrix spvBT709Factors followed by the MSL helper
// spvConvertYCbCrBT709(ycbcr). The helper multiplies the matrix by ycbcr.gbr,
// converts the product to vec<T, 3> for .rgb, and copies alpha through.
// The .gbr swizzle matches the channel placement used by the chroma
// reconstruction helpers in this switch (plane0 -> .g, chroma -> .b/.r).
5000 case SPVFuncImplConvertYCbCrBT709:
5001 statement("// cf. Khronos Data Format Specification, section 15.1.1");
5002 statement("constant float3x3 spvBT709Factors = {{1, 1, 1}, {0, -0.13397432/0.7152, 1.8556}, {1.5748, "
5003 "-0.33480248/0.7152, 0}};");
5004 statement("");
5005 statement("template<typename T>");
5006 statement("inline vec<T, 4> spvConvertYCbCrBT709(vec<T, 4> ycbcr)");
5007 begin_scope();
5008 statement("vec<T, 4> rgba;");
5009 statement("rgba.rgb = vec<T, 3>(spvBT709Factors * ycbcr.gbr);");
5010 statement("rgba.a = ycbcr.a;");
5011 statement("return rgba;");
5012 end_scope();
5013 statement("");
5014 break;
5015
// Emits the constant matrix spvBT601Factors followed by the MSL helper
// spvConvertYCbCrBT601(ycbcr). The helper multiplies the matrix by ycbcr.gbr,
// converts the product to vec<T, 3> for .rgb, and copies alpha through.
// Structurally identical to the BT.709 helper; only the coefficients differ.
5016 case SPVFuncImplConvertYCbCrBT601:
5017 statement("// cf. Khronos Data Format Specification, section 15.1.2");
5018 statement("constant float3x3 spvBT601Factors = {{1, 1, 1}, {0, -0.202008/0.587, 1.772}, {1.402, "
5019 "-0.419198/0.587, 0}};");
5020 statement("");
5021 statement("template<typename T>");
5022 statement("inline vec<T, 4> spvConvertYCbCrBT601(vec<T, 4> ycbcr)");
5023 begin_scope();
5024 statement("vec<T, 4> rgba;");
5025 statement("rgba.rgb = vec<T, 3>(spvBT601Factors * ycbcr.gbr);");
5026 statement("rgba.a = ycbcr.a;");
5027 statement("return rgba;");
5028 end_scope();
5029 statement("");
5030 break;
5031
// Emits the constant matrix spvBT2020Factors followed by the MSL helper
// spvConvertYCbCrBT2020(ycbcr). The helper multiplies the matrix by ycbcr.gbr,
// converts the product to vec<T, 3> for .rgb, and copies alpha through.
// Structurally identical to the BT.709 helper; only the coefficients differ.
5032 case SPVFuncImplConvertYCbCrBT2020:
5033 statement("// cf. Khronos Data Format Specification, section 15.1.3");
5034 statement("constant float3x3 spvBT2020Factors = {{1, 1, 1}, {0, -0.11156702/0.6780, 1.8814}, {1.4746, "
5035 "-0.38737742/0.6780, 0}};");
5036 statement("");
5037 statement("template<typename T>");
5038 statement("inline vec<T, 4> spvConvertYCbCrBT2020(vec<T, 4> ycbcr)");
5039 begin_scope();
5040 statement("vec<T, 4> rgba;");
5041 statement("rgba.rgb = vec<T, 3>(spvBT2020Factors * ycbcr.gbr);");
5042 statement("rgba.a = ycbcr.a;");
5043 statement("return rgba;");
5044 end_scope();
5045 statement("");
5046 break;
5047
5048 case SPVFuncImplDynamicImageSampler:
5049 statement("enum class spvFormatResolution");
5050 begin_scope();
5051 statement("_444 = 0,");
5052 statement("_422,");
5053 statement("_420");
5054 end_scope_decl();
5055 statement("");
5056 statement("enum class spvChromaFilter");
5057 begin_scope();
5058 statement("nearest = 0,");
5059 statement("linear");
5060 end_scope_decl();
5061 statement("");
5062 statement("enum class spvXChromaLocation");
5063 begin_scope();
5064 statement("cosited_even = 0,");
5065 statement("midpoint");
5066 end_scope_decl();
5067 statement("");
5068 statement("enum class spvYChromaLocation");
5069 begin_scope();
5070 statement("cosited_even = 0,");
5071 statement("midpoint");
5072 end_scope_decl();
5073 statement("");
5074 statement("enum class spvYCbCrModelConversion");
5075 begin_scope();
5076 statement("rgb_identity = 0,");
5077 statement("ycbcr_identity,");
5078 statement("ycbcr_bt_709,");
5079 statement("ycbcr_bt_601,");
5080 statement("ycbcr_bt_2020");
5081 end_scope_decl();
5082 statement("");
5083 statement("enum class spvYCbCrRange");
5084 begin_scope();
5085 statement("itu_full = 0,");
5086 statement("itu_narrow");
5087 end_scope_decl();
5088 statement("");
5089 statement("struct spvComponentBits");
5090 begin_scope();
5091 statement("constexpr explicit spvComponentBits(int v) thread : value(v) {}");
5092 statement("uchar value : 6;");
5093 end_scope_decl();
5094 statement("// A class corresponding to metal::sampler which holds sampler");
5095 statement("// Y'CbCr conversion info.");
5096 statement("struct spvYCbCrSampler");
5097 begin_scope();
5098 statement("constexpr spvYCbCrSampler() thread : val(build()) {}");
5099 statement("template<typename... Ts>");
5100 statement("constexpr spvYCbCrSampler(Ts... t) thread : val(build(t...)) {}");
5101 statement("constexpr spvYCbCrSampler(const thread spvYCbCrSampler& s) thread = default;");
5102 statement("");
5103 statement("spvFormatResolution get_resolution() const thread");
5104 begin_scope();
5105 statement("return spvFormatResolution((val & resolution_mask) >> resolution_base);");
5106 end_scope();
5107 statement("spvChromaFilter get_chroma_filter() const thread");
5108 begin_scope();
5109 statement("return spvChromaFilter((val & chroma_filter_mask) >> chroma_filter_base);");
5110 end_scope();
5111 statement("spvXChromaLocation get_x_chroma_offset() const thread");
5112 begin_scope();
5113 statement("return spvXChromaLocation((val & x_chroma_off_mask) >> x_chroma_off_base);");
5114 end_scope();
5115 statement("spvYChromaLocation get_y_chroma_offset() const thread");
5116 begin_scope();
5117 statement("return spvYChromaLocation((val & y_chroma_off_mask) >> y_chroma_off_base);");
5118 end_scope();
5119 statement("spvYCbCrModelConversion get_ycbcr_model() const thread");
5120 begin_scope();
5121 statement("return spvYCbCrModelConversion((val & ycbcr_model_mask) >> ycbcr_model_base);");
5122 end_scope();
5123 statement("spvYCbCrRange get_ycbcr_range() const thread");
5124 begin_scope();
5125 statement("return spvYCbCrRange((val & ycbcr_range_mask) >> ycbcr_range_base);");
5126 end_scope();
5127 statement("int get_bpc() const thread { return (val & bpc_mask) >> bpc_base; }");
5128 statement("");
5129 statement("private:");
5130 statement("ushort val;");
5131 statement("");
5132 statement("constexpr static constant ushort resolution_bits = 2;");
5133 statement("constexpr static constant ushort chroma_filter_bits = 2;");
5134 statement("constexpr static constant ushort x_chroma_off_bit = 1;");
5135 statement("constexpr static constant ushort y_chroma_off_bit = 1;");
5136 statement("constexpr static constant ushort ycbcr_model_bits = 3;");
5137 statement("constexpr static constant ushort ycbcr_range_bit = 1;");
5138 statement("constexpr static constant ushort bpc_bits = 6;");
5139 statement("");
5140 statement("constexpr static constant ushort resolution_base = 0;");
5141 statement("constexpr static constant ushort chroma_filter_base = 2;");
5142 statement("constexpr static constant ushort x_chroma_off_base = 4;");
5143 statement("constexpr static constant ushort y_chroma_off_base = 5;");
5144 statement("constexpr static constant ushort ycbcr_model_base = 6;");
5145 statement("constexpr static constant ushort ycbcr_range_base = 9;");
5146 statement("constexpr static constant ushort bpc_base = 10;");
5147 statement("");
5148 statement(
5149 "constexpr static constant ushort resolution_mask = ((1 << resolution_bits) - 1) << resolution_base;");
5150 statement("constexpr static constant ushort chroma_filter_mask = ((1 << chroma_filter_bits) - 1) << "
5151 "chroma_filter_base;");
5152 statement("constexpr static constant ushort x_chroma_off_mask = ((1 << x_chroma_off_bit) - 1) << "
5153 "x_chroma_off_base;");
5154 statement("constexpr static constant ushort y_chroma_off_mask = ((1 << y_chroma_off_bit) - 1) << "
5155 "y_chroma_off_base;");
5156 statement("constexpr static constant ushort ycbcr_model_mask = ((1 << ycbcr_model_bits) - 1) << "
5157 "ycbcr_model_base;");
5158 statement("constexpr static constant ushort ycbcr_range_mask = ((1 << ycbcr_range_bit) - 1) << "
5159 "ycbcr_range_base;");
5160 statement("constexpr static constant ushort bpc_mask = ((1 << bpc_bits) - 1) << bpc_base;");
5161 statement("");
5162 statement("static constexpr ushort build()");
5163 begin_scope();
5164 statement("return 0;");
5165 end_scope();
5166 statement("");
5167 statement("template<typename... Ts>");
5168 statement("static constexpr ushort build(spvFormatResolution res, Ts... t)");
5169 begin_scope();
5170 statement("return (ushort(res) << resolution_base) | (build(t...) & ~resolution_mask);");
5171 end_scope();
5172 statement("");
5173 statement("template<typename... Ts>");
5174 statement("static constexpr ushort build(spvChromaFilter filt, Ts... t)");
5175 begin_scope();
5176 statement("return (ushort(filt) << chroma_filter_base) | (build(t...) & ~chroma_filter_mask);");
5177 end_scope();
5178 statement("");
5179 statement("template<typename... Ts>");
5180 statement("static constexpr ushort build(spvXChromaLocation loc, Ts... t)");
5181 begin_scope();
5182 statement("return (ushort(loc) << x_chroma_off_base) | (build(t...) & ~x_chroma_off_mask);");
5183 end_scope();
5184 statement("");
5185 statement("template<typename... Ts>");
5186 statement("static constexpr ushort build(spvYChromaLocation loc, Ts... t)");
5187 begin_scope();
5188 statement("return (ushort(loc) << y_chroma_off_base) | (build(t...) & ~y_chroma_off_mask);");
5189 end_scope();
5190 statement("");
5191 statement("template<typename... Ts>");
5192 statement("static constexpr ushort build(spvYCbCrModelConversion model, Ts... t)");
5193 begin_scope();
5194 statement("return (ushort(model) << ycbcr_model_base) | (build(t...) & ~ycbcr_model_mask);");
5195 end_scope();
5196 statement("");
5197 statement("template<typename... Ts>");
5198 statement("static constexpr ushort build(spvYCbCrRange range, Ts... t)");
5199 begin_scope();
5200 statement("return (ushort(range) << ycbcr_range_base) | (build(t...) & ~ycbcr_range_mask);");
5201 end_scope();
5202 statement("");
5203 statement("template<typename... Ts>");
5204 statement("static constexpr ushort build(spvComponentBits bpc, Ts... t)");
5205 begin_scope();
5206 statement("return (ushort(bpc.value) << bpc_base) | (build(t...) & ~bpc_mask);");
5207 end_scope();
5208 end_scope_decl();
5209 statement("");
5210 statement("// A class which can hold up to three textures and a sampler, including");
5211 statement("// Y'CbCr conversion info, used to pass combined image-samplers");
5212 statement("// dynamically to functions.");
5213 statement("template<typename T>");
5214 statement("struct spvDynamicImageSampler");
5215 begin_scope();
5216 statement("texture2d<T> plane0;");
5217 statement("texture2d<T> plane1;");
5218 statement("texture2d<T> plane2;");
5219 statement("sampler samp;");
5220 statement("spvYCbCrSampler ycbcr_samp;");
5221 statement("uint swizzle = 0;");
5222 statement("");
5223 if (msl_options.swizzle_texture_samples)
5224 {
5225 statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, uint sw) thread :");
5226 statement(" plane0(tex), samp(samp), swizzle(sw) {}");
5227 }
5228 else
5229 {
5230 statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp) thread :");
5231 statement(" plane0(tex), samp(samp) {}");
5232 }
5233 statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, spvYCbCrSampler ycbcr_samp, "
5234 "uint sw) thread :");
5235 statement(" plane0(tex), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
5236 statement("constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1,");
5237 statement(" sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
5238 statement(" plane0(plane0), plane1(plane1), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
5239 statement(
5240 "constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1, texture2d<T> plane2,");
5241 statement(" sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
5242 statement(" plane0(plane0), plane1(plane1), plane2(plane2), samp(samp), ycbcr_samp(ycbcr_samp), "
5243 "swizzle(sw) {}");
5244 statement("");
// Emits the member template do_sample(coord, options...) of
// spvDynamicImageSampler. The emitted method dispatches as follows:
//   * plane1 is a null texture: plain plane0.sample(...).
//   * resolution is _444 or the chroma filter is nearest:
//     spvChromaReconstructNearest (3-plane overload if plane2 is non-null).
//   * otherwise it switches on resolution, then on the x chroma location, and
//     for _420 also on the y chroma location, and calls the matching
//     spvChromaReconstructLinear* helper (3-plane overload if plane2 is
//     non-null, 2-plane overload otherwise).
// Every begin_scope() below is closed by an end_scope() whose trailing comment
// names the emitted construct it terminates.
5245 // XXX This is really hard to follow... I've left comments to make it a bit easier.
5246 statement("template<typename... LodOptions>");
5247 statement("vec<T, 4> do_sample(float2 coord, LodOptions... options) const thread");
5248 begin_scope();
5249 statement("if (!is_null_texture(plane1))");
5250 begin_scope();
5251 statement("if (ycbcr_samp.get_resolution() == spvFormatResolution::_444 ||");
5252 statement(" ycbcr_samp.get_chroma_filter() == spvChromaFilter::nearest)");
5253 begin_scope();
5254 statement("if (!is_null_texture(plane2))");
5255 statement(" return spvChromaReconstructNearest(plane0, plane1, plane2, samp, coord,");
5256 statement(" spvForward<LodOptions>(options)...);");
5257 statement(
5258 "return spvChromaReconstructNearest(plane0, plane1, samp, coord, spvForward<LodOptions>(options)...);");
5259 end_scope(); // if (resolution == 444 || chroma_filter == nearest)
5260 statement("switch (ycbcr_samp.get_resolution())");
5261 begin_scope();
// _444 already returned in the branch above; presumably this case is listed
// only so that the emitted switch names every enumerator.
5262 statement("case spvFormatResolution::_444: break;");
5263 statement("case spvFormatResolution::_422:");
5264 begin_scope();
5265 statement("switch (ycbcr_samp.get_x_chroma_offset())");
5266 begin_scope();
5267 statement("case spvXChromaLocation::cosited_even:");
5268 statement(" if (!is_null_texture(plane2))");
5269 statement(" return spvChromaReconstructLinear422CositedEven(");
5270 statement(" plane0, plane1, plane2, samp,");
5271 statement(" coord, spvForward<LodOptions>(options)...);");
5272 statement(" return spvChromaReconstructLinear422CositedEven(");
5273 statement(" plane0, plane1, samp, coord,");
5274 statement(" spvForward<LodOptions>(options)...);");
5275 statement("case spvXChromaLocation::midpoint:");
5276 statement(" if (!is_null_texture(plane2))");
5277 statement(" return spvChromaReconstructLinear422Midpoint(");
5278 statement(" plane0, plane1, plane2, samp,");
5279 statement(" coord, spvForward<LodOptions>(options)...);");
5280 statement(" return spvChromaReconstructLinear422Midpoint(");
5281 statement(" plane0, plane1, samp, coord,");
5282 statement(" spvForward<LodOptions>(options)...);");
5283 end_scope(); // switch (x_chroma_offset)
5284 end_scope(); // case 422:
5285 statement("case spvFormatResolution::_420:");
5286 begin_scope();
5287 statement("switch (ycbcr_samp.get_x_chroma_offset())");
5288 begin_scope();
5289 statement("case spvXChromaLocation::cosited_even:");
5290 begin_scope();
5291 statement("switch (ycbcr_samp.get_y_chroma_offset())");
5292 begin_scope();
5293 statement("case spvYChromaLocation::cosited_even:");
5294 statement(" if (!is_null_texture(plane2))");
5295 statement(" return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
5296 statement(" plane0, plane1, plane2, samp,");
5297 statement(" coord, spvForward<LodOptions>(options)...);");
5298 statement(" return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
5299 statement(" plane0, plane1, samp, coord,");
5300 statement(" spvForward<LodOptions>(options)...);");
5301 statement("case spvYChromaLocation::midpoint:");
5302 statement(" if (!is_null_texture(plane2))");
5303 statement(" return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
5304 statement(" plane0, plane1, plane2, samp,");
5305 statement(" coord, spvForward<LodOptions>(options)...);");
5306 statement(" return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
5307 statement(" plane0, plane1, samp, coord,");
5308 statement(" spvForward<LodOptions>(options)...);");
5309 end_scope(); // switch (y_chroma_offset)
5310 end_scope(); // case x::cosited_even:
5311 statement("case spvXChromaLocation::midpoint:");
5312 begin_scope();
5313 statement("switch (ycbcr_samp.get_y_chroma_offset())");
5314 begin_scope();
5315 statement("case spvYChromaLocation::cosited_even:");
5316 statement(" if (!is_null_texture(plane2))");
5317 statement(" return spvChromaReconstructLinear420XMidpointYCositedEven(");
5318 statement(" plane0, plane1, plane2, samp,");
5319 statement(" coord, spvForward<LodOptions>(options)...);");
5320 statement(" return spvChromaReconstructLinear420XMidpointYCositedEven(");
5321 statement(" plane0, plane1, samp, coord,");
5322 statement(" spvForward<LodOptions>(options)...);");
5323 statement("case spvYChromaLocation::midpoint:");
5324 statement(" if (!is_null_texture(plane2))");
5325 statement(" return spvChromaReconstructLinear420XMidpointYMidpoint(");
5326 statement(" plane0, plane1, plane2, samp,");
5327 statement(" coord, spvForward<LodOptions>(options)...);");
5328 statement(" return spvChromaReconstructLinear420XMidpointYMidpoint(");
5329 statement(" plane0, plane1, samp, coord,");
5330 statement(" spvForward<LodOptions>(options)...);");
5331 end_scope(); // switch (y_chroma_offset)
5332 end_scope(); // case x::midpoint
5333 end_scope(); // switch (x_chroma_offset)
5334 end_scope(); // case 420:
5335 end_scope(); // switch (resolution)
5336 end_scope(); // if (multiplanar)
// Fallback: reached when plane1 is null, or when the resolution switch breaks out.
5337 statement("return plane0.sample(samp, coord, spvForward<LodOptions>(options)...);");
5338 end_scope(); // do_sample()
5339 statement("template <typename... LodOptions>");
5340 statement("vec<T, 4> sample(float2 coord, LodOptions... options) const thread");
5341 begin_scope();
5342 statement(
5343 "vec<T, 4> s = spvTextureSwizzle(do_sample(coord, spvForward<LodOptions>(options)...), swizzle);");
5344 statement("if (ycbcr_samp.get_ycbcr_model() == spvYCbCrModelConversion::rgb_identity)");
5345 statement(" return s;");
5346 statement("");
5347 statement("switch (ycbcr_samp.get_ycbcr_range())");
5348 begin_scope();
5349 statement("case spvYCbCrRange::itu_full:");
5350 statement(" s = spvExpandITUFullRange(s, ycbcr_samp.get_bpc());");
5351 statement(" break;");
5352 statement("case spvYCbCrRange::itu_narrow:");
5353 statement(" s = spvExpandITUNarrowRange(s, ycbcr_samp.get_bpc());");
5354 statement(" break;");
5355 end_scope();
5356 statement("");
5357 statement("switch (ycbcr_samp.get_ycbcr_model())");
5358 begin_scope();
5359 statement("case spvYCbCrModelConversion::rgb_identity:"); // Silence Clang warning
5360 statement("case spvYCbCrModelConversion::ycbcr_identity:");
5361 statement(" return s;");
5362 statement("case spvYCbCrModelConversion::ycbcr_bt_709:");
5363 statement(" return spvConvertYCbCrBT709(s);");
5364 statement("case spvYCbCrModelConversion::ycbcr_bt_601:");
5365 statement(" return spvConvertYCbCrBT601(s);");
5366 statement("case spvYCbCrModelConversion::ycbcr_bt_2020:");
5367 statement(" return spvConvertYCbCrBT2020(s);");
5368 end_scope();
5369 end_scope();
5370 statement("");
5371 // Sampler Y'CbCr conversion forbids offsets.
5372 statement("vec<T, 4> sample(float2 coord, int2 offset) const thread");
5373 begin_scope();
5374 if (msl_options.swizzle_texture_samples)
5375 statement("return spvTextureSwizzle(plane0.sample(samp, coord, offset), swizzle);");
5376 else
5377 statement("return plane0.sample(samp, coord, offset);");
5378 end_scope();
5379 statement("template<typename lod_options>");
5380 statement("vec<T, 4> sample(float2 coord, lod_options options, int2 offset) const thread");
5381 begin_scope();
5382 if (msl_options.swizzle_texture_samples)
5383 statement("return spvTextureSwizzle(plane0.sample(samp, coord, options, offset), swizzle);");
5384 else
5385 statement("return plane0.sample(samp, coord, options, offset);");
5386 end_scope();
5387 statement("#if __HAVE_MIN_LOD_CLAMP__");
5388 statement("vec<T, 4> sample(float2 coord, bias b, min_lod_clamp min_lod, int2 offset) const thread");
5389 begin_scope();
5390 statement("return plane0.sample(samp, coord, b, min_lod, offset);");
5391 end_scope();
5392 statement(
5393 "vec<T, 4> sample(float2 coord, gradient2d grad, min_lod_clamp min_lod, int2 offset) const thread");
5394 begin_scope();
5395 statement("return plane0.sample(samp, coord, grad, min_lod, offset);");
5396 end_scope();
5397 statement("#endif");
5398 statement("");
5399 // Y'CbCr conversion forbids all operations but sampling.
5400 statement("vec<T, 4> read(uint2 coord, uint lod = 0) const thread");
5401 begin_scope();
5402 statement("return plane0.read(coord, lod);");
5403 end_scope();
5404 statement("");
5405 statement("vec<T, 4> gather(float2 coord, int2 offset = int2(0), component c = component::x) const thread");
5406 begin_scope();
5407 if (msl_options.swizzle_texture_samples)
5408 statement("return spvGatherSwizzle(plane0, samp, swizzle, c, coord, offset);");
5409 else
5410 statement("return plane0.gather(samp, coord, offset, c);");
5411 end_scope();
5412 end_scope_decl();
5413 statement("");
5414
5415 default:
5416 break;
5417 }
5418 }
5419 }
5420
5421 // Undefined global memory is not allowed in MSL.
5422 // Declare constant and init to zeros. Use {}, as global constructors can break Metal.
declare_undefined_values()5423 void CompilerMSL::declare_undefined_values()
5424 {
5425 bool emitted = false;
5426 ir.for_each_typed_id<SPIRUndef>([&](uint32_t, SPIRUndef &undef) {
5427 auto &type = this->get<SPIRType>(undef.basetype);
5428 // OpUndef can be void for some reason ...
5429 if (type.basetype == SPIRType::Void)
5430 return;
5431
5432 statement("constant ", variable_decl(type, to_name(undef.self), undef.self), " = {};");
5433 emitted = true;
5434 });
5435
5436 if (emitted)
5437 statement("");
5438 }
5439
declare_constant_arrays()5440 void CompilerMSL::declare_constant_arrays()
5441 {
5442 bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
5443
5444 // MSL cannot declare arrays inline (except when declaring a variable), so we must move them out to
5445 // global constants directly, so we are able to use constants as variable expressions.
5446 bool emitted = false;
5447
5448 ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
5449 if (c.specialization)
5450 return;
5451
5452 auto &type = this->get<SPIRType>(c.constant_type);
5453 // Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries.
5454 // FIXME: However, hoisting constants to main() means we need to pass down constant arrays to leaf functions if they are used there.
5455 // If there are multiple functions in the module, drop this case to avoid breaking use cases which do not need to
5456 // link into Metal libraries. This is hacky.
5457 if (!type.array.empty() && (!fully_inlined || is_scalar(type) || is_vector(type)))
5458 {
5459 auto name = to_name(c.self);
5460 statement("constant ", variable_decl(type, name), " = ", constant_expression(c), ";");
5461 emitted = true;
5462 }
5463 });
5464
5465 if (emitted)
5466 statement("");
5467 }
5468
5469 // Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries
declare_complex_constant_arrays()5470 void CompilerMSL::declare_complex_constant_arrays()
5471 {
5472 // If we do not have a fully inlined module, we did not opt in to
5473 // declaring constant arrays of complex types. See CompilerMSL::declare_constant_arrays().
5474 bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
5475 if (!fully_inlined)
5476 return;
5477
5478 // MSL cannot declare arrays inline (except when declaring a variable), so we must move them out to
5479 // global constants directly, so we are able to use constants as variable expressions.
5480 bool emitted = false;
5481
5482 ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
5483 if (c.specialization)
5484 return;
5485
5486 auto &type = this->get<SPIRType>(c.constant_type);
5487 if (!type.array.empty() && !(is_scalar(type) || is_vector(type)))
5488 {
5489 auto name = to_name(c.self);
5490 statement("", variable_decl(type, name), " = ", constant_expression(c), ";");
5491 emitted = true;
5492 }
5493 });
5494
5495 if (emitted)
5496 statement("");
5497 }
5498
emit_resources()5499 void CompilerMSL::emit_resources()
5500 {
5501 declare_constant_arrays();
5502 declare_undefined_values();
5503
5504 // Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created.
5505 emit_interface_block(stage_out_var_id);
5506 emit_interface_block(patch_stage_out_var_id);
5507 emit_interface_block(stage_in_var_id);
5508 emit_interface_block(patch_stage_in_var_id);
5509 }
5510
5511 // Emit declarations for the specialization Metal function constants
emit_specialization_constants_and_structs()5512 void CompilerMSL::emit_specialization_constants_and_structs()
5513 {
5514 SpecializationConstant wg_x, wg_y, wg_z;
5515 ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
5516 bool emitted = false;
5517
5518 unordered_set<uint32_t> declared_structs;
5519 unordered_set<uint32_t> aligned_structs;
5520
5521 // First, we need to deal with scalar block layout.
5522 // It is possible that a struct may have to be placed at an alignment which does not match the innate alignment of the struct itself.
5523 // In that case, if such a case exists for a struct, we must force that all elements of the struct become packed_ types.
5524 // This makes the struct alignment as small as physically possible.
5525 // When we actually align the struct later, we can insert padding as necessary to make the packed members behave like normally aligned types.
5526 ir.for_each_typed_id<SPIRType>([&](uint32_t type_id, const SPIRType &type) {
5527 if (type.basetype == SPIRType::Struct &&
5528 has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
5529 mark_scalar_layout_structs(type);
5530 });
5531
5532 // Very particular use of the soft loop lock.
5533 // align_struct may need to create custom types on the fly, but we don't care about
5534 // these types for purpose of iterating over them in ir.ids_for_type and friends.
5535 auto loop_lock = ir.create_loop_soft_lock();
5536
5537 for (auto &id_ : ir.ids_for_constant_or_type)
5538 {
5539 auto &id = ir.ids[id_];
5540
5541 if (id.get_type() == TypeConstant)
5542 {
5543 auto &c = id.get<SPIRConstant>();
5544
5545 if (c.self == workgroup_size_id)
5546 {
5547 // TODO: This can be expressed as a [[threads_per_threadgroup]] input semantic, but we need to know
5548 // the work group size at compile time in SPIR-V, and [[threads_per_threadgroup]] would need to be passed around as a global.
5549 // The work group size may be a specialization constant.
5550 statement("constant uint3 ", builtin_to_glsl(BuiltInWorkgroupSize, StorageClassWorkgroup),
5551 " [[maybe_unused]] = ", constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
5552 emitted = true;
5553 }
5554 else if (c.specialization)
5555 {
5556 auto &type = get<SPIRType>(c.constant_type);
5557 string sc_type_name = type_to_glsl(type);
5558 string sc_name = to_name(c.self);
5559 string sc_tmp_name = sc_name + "_tmp";
5560
5561 // Function constants are only supported in MSL 1.2 and later.
5562 // If we don't support it just declare the "default" directly.
5563 // This "default" value can be overridden to the true specialization constant by the API user.
5564 // Specialization constants which are used as array length expressions cannot be function constants in MSL,
5565 // so just fall back to macros.
5566 if (msl_options.supports_msl_version(1, 2) && has_decoration(c.self, DecorationSpecId) &&
5567 !c.is_used_as_array_length)
5568 {
5569 uint32_t constant_id = get_decoration(c.self, DecorationSpecId);
5570 // Only scalar, non-composite values can be function constants.
5571 statement("constant ", sc_type_name, " ", sc_tmp_name, " [[function_constant(", constant_id,
5572 ")]];");
5573 statement("constant ", sc_type_name, " ", sc_name, " = is_function_constant_defined(", sc_tmp_name,
5574 ") ? ", sc_tmp_name, " : ", constant_expression(c), ";");
5575 }
5576 else if (has_decoration(c.self, DecorationSpecId))
5577 {
5578 // Fallback to macro overrides.
5579 c.specialization_constant_macro_name =
5580 constant_value_macro_name(get_decoration(c.self, DecorationSpecId));
5581
5582 statement("#ifndef ", c.specialization_constant_macro_name);
5583 statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
5584 statement("#endif");
5585 statement("constant ", sc_type_name, " ", sc_name, " = ", c.specialization_constant_macro_name,
5586 ";");
5587 }
5588 else
5589 {
5590 // Composite specialization constants must be built from other specialization constants.
5591 statement("constant ", sc_type_name, " ", sc_name, " = ", constant_expression(c), ";");
5592 }
5593 emitted = true;
5594 }
5595 }
5596 else if (id.get_type() == TypeConstantOp)
5597 {
5598 auto &c = id.get<SPIRConstantOp>();
5599 auto &type = get<SPIRType>(c.basetype);
5600 auto name = to_name(c.self);
5601 statement("constant ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
5602 emitted = true;
5603 }
5604 else if (id.get_type() == TypeType)
5605 {
5606 // Output non-builtin interface structs. These include local function structs
5607 // and structs nested within uniform and read-write buffers.
5608 auto &type = id.get<SPIRType>();
5609 TypeID type_id = type.self;
5610
5611 bool is_struct = (type.basetype == SPIRType::Struct) && type.array.empty() && !type.pointer;
5612 bool is_block =
5613 has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
5614
5615 bool is_builtin_block = is_block && is_builtin_type(type);
5616 bool is_declarable_struct = is_struct && !is_builtin_block;
5617
5618 // We'll declare this later.
5619 if (stage_out_var_id && get_stage_out_struct_type().self == type_id)
5620 is_declarable_struct = false;
5621 if (patch_stage_out_var_id && get_patch_stage_out_struct_type().self == type_id)
5622 is_declarable_struct = false;
5623 if (stage_in_var_id && get_stage_in_struct_type().self == type_id)
5624 is_declarable_struct = false;
5625 if (patch_stage_in_var_id && get_patch_stage_in_struct_type().self == type_id)
5626 is_declarable_struct = false;
5627
5628 // Align and emit declarable structs...but avoid declaring each more than once.
5629 if (is_declarable_struct && declared_structs.count(type_id) == 0)
5630 {
5631 if (emitted)
5632 statement("");
5633 emitted = false;
5634
5635 declared_structs.insert(type_id);
5636
5637 if (has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
5638 align_struct(type, aligned_structs);
5639
5640 // Make sure we declare the underlying struct type, and not the "decorated" type with pointers, etc.
5641 emit_struct(get<SPIRType>(type_id));
5642 }
5643 }
5644 }
5645
5646 if (emitted)
5647 statement("");
5648 }
5649
emit_binary_unord_op(uint32_t result_type,uint32_t result_id,uint32_t op0,uint32_t op1,const char * op)5650 void CompilerMSL::emit_binary_unord_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
5651 const char *op)
5652 {
5653 bool forward = should_forward(op0) && should_forward(op1);
5654 emit_op(result_type, result_id,
5655 join("(isunordered(", to_enclosed_unpacked_expression(op0), ", ", to_enclosed_unpacked_expression(op1),
5656 ") || ", to_enclosed_unpacked_expression(op0), " ", op, " ", to_enclosed_unpacked_expression(op1),
5657 ")"),
5658 forward);
5659
5660 inherit_expression_dependencies(result_id, op0);
5661 inherit_expression_dependencies(result_id, op1);
5662 }
5663
// Handles a load through a tessellation input/output pointer whose result has to be rebuilt
// ("unflattened") from separate interface members.
//
// result_type_id: ID of the type of the loaded value.
// id:             result ID the built expression is registered under.
// ptr:            ID of the pointer expression being loaded from.
//
// Returns false without emitting anything when this path does not apply:
//  - the pointer is in neither Input nor Output storage,
//  - the pointer is in Output storage and the shader is a tessellation evaluation shader,
//  - the result is not an array, not a struct loaded from Input, and not a matrix loaded from Input
//    (matrices are excluded for Input loads in multi-patch tessellation control),
//  - the pointer is decorated with Patch.
// Otherwise a constructor-style expression is built, emitted through emit_op() with forwarding
// disabled, a read of ptr is registered, and true is returned.
//
// Throws through SPIRV_CROSS_THROW for shapes that cannot be handled: more than two array dimensions,
// array-of-array / array-of-struct / array-of-matrix loads that do not come directly from an IO variable,
// array-of-array of composite type, and an unknown interface member index.
bool CompilerMSL::emit_tessellation_io_load(uint32_t result_type_id, uint32_t id, uint32_t ptr)
{
    auto &ptr_type = expression_type(ptr);
    auto &result_type = get<SPIRType>(result_type_id);
    if (ptr_type.storage != StorageClassInput && ptr_type.storage != StorageClassOutput)
        return false;
    if (ptr_type.storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationEvaluation)
        return false;

    // Classify the load. Matrices and structs only count as "flat" for Input storage;
    // arrays count regardless of storage class.
    bool multi_patch_tess_ctl = get_execution_model() == ExecutionModelTessellationControl &&
                                msl_options.multi_patch_workgroup && ptr_type.storage == StorageClassInput;
    bool flat_matrix = is_matrix(result_type) && ptr_type.storage == StorageClassInput && !multi_patch_tess_ctl;
    bool flat_struct = result_type.basetype == SPIRType::Struct && ptr_type.storage == StorageClassInput;
    bool flat_data_type = flat_matrix || is_array(result_type) || flat_struct;
    if (!flat_data_type)
        return false;

    if (has_decoration(ptr, DecorationPatch))
        return false;

    // Now, we must unflatten a composite type and take care of interleaving array access with gl_in/gl_out.
    // Lots of painful code duplication since we *really* should not unroll these kinds of loads in entry point fixup
    // unless we're forced to do this when the code is emitting suboptimal OpLoads.
    string expr;

    // uint32_t(-1) is treated below as "interface member index unknown".
    uint32_t interface_index = get_extended_decoration(ptr, SPIRVCrossDecorationInterfaceMemberIndex);
    auto *var = maybe_get_backing_variable(ptr);
    bool ptr_is_io_variable = ir.ids[ptr].get_type() == TypeVariable;
    auto &expr_type = get_pointee_type(ptr_type.self);

    const auto &iface_type = expression_type(stage_in_ptr_var_id);

    if (result_type.array.size() > 2)
    {
        SPIRV_CROSS_THROW("Cannot load tessellation IO variables with more than 2 dimensions.");
    }
    else if (result_type.array.size() == 2)
    {
        // Array-of-array: the outer dimension (index 1) iterates control points,
        // the inner dimension (index 0) iterates the per-control-point array.
        if (!ptr_is_io_variable)
            SPIRV_CROSS_THROW("Loading an array-of-array must be loaded directly from an IO variable.");
        if (interface_index == uint32_t(-1))
            SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
        if (result_type.basetype == SPIRType::Struct || flat_matrix)
            SPIRV_CROSS_THROW("Cannot load array-of-array of composite type in tessellation IO.");

        expr += type_to_glsl(result_type) + "({ ";
        uint32_t num_control_points = to_array_size_literal(result_type, 1);
        uint32_t base_interface_index = interface_index;

        auto &sub_type = get<SPIRType>(result_type.parent_type);

        for (uint32_t i = 0; i < num_control_points; i++)
        {
            expr += type_to_glsl(sub_type) + "({ ";
            // Every control point starts again from the same first member.
            interface_index = base_interface_index;
            uint32_t array_size = to_array_size_literal(result_type, 0);
            if (multi_patch_tess_ctl)
            {
                // The member index stays fixed here; the element index j is a third access chain index.
                for (uint32_t j = 0; j < array_size; j++)
                {
                    const uint32_t indices[3] = { i, interface_index, j };

                    AccessChainMeta meta;
                    expr +=
                        access_chain_internal(stage_in_ptr_var_id, indices, 3,
                                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
                    // If the expression has more vector components than the result type, insert
                    // a swizzle. This shouldn't happen normally on valid SPIR-V, but it might
                    // happen if we replace the type of an input variable.
                    if (!is_matrix(sub_type) && sub_type.basetype != SPIRType::Struct &&
                        expr_type.vecsize > sub_type.vecsize)
                        expr += vector_swizzle(sub_type.vecsize, 0);

                    if (j + 1 < array_size)
                        expr += ", ";
                }
            }
            else
            {
                // Each array element is addressed through its own member index, advanced once per element.
                for (uint32_t j = 0; j < array_size; j++, interface_index++)
                {
                    const uint32_t indices[2] = { i, interface_index };

                    AccessChainMeta meta;
                    expr +=
                        access_chain_internal(stage_in_ptr_var_id, indices, 2,
                                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
                    if (!is_matrix(sub_type) && sub_type.basetype != SPIRType::Struct &&
                        expr_type.vecsize > sub_type.vecsize)
                        expr += vector_swizzle(sub_type.vecsize, 0);

                    if (j + 1 < array_size)
                        expr += ", ";
                }
            }
            expr += " })";
            if (i + 1 < num_control_points)
                expr += ", ";
        }
        expr += " })";
    }
    else if (flat_struct)
    {
        // Struct, or array of struct with one struct per control point.
        bool is_array_of_struct = is_array(result_type);
        if (is_array_of_struct && !ptr_is_io_variable)
            SPIRV_CROSS_THROW("Loading array of struct from IO variable must come directly from IO variable.");

        // A single struct is handled as one iteration of the control point loop below.
        uint32_t num_control_points = 1;
        if (is_array_of_struct)
        {
            num_control_points = to_array_size_literal(result_type, 0);
            expr += type_to_glsl(result_type) + "({ ";
        }

        auto &struct_type = is_array_of_struct ? get<SPIRType>(result_type.parent_type) : result_type;
        assert(struct_type.array.empty());

        for (uint32_t i = 0; i < num_control_points; i++)
        {
            expr += type_to_glsl(struct_type) + "{ ";
            for (uint32_t j = 0; j < uint32_t(struct_type.member_types.size()); j++)
            {
                // The base interface index is stored per variable for structs.
                if (var)
                {
                    interface_index =
                        get_extended_member_decoration(var->self, j, SPIRVCrossDecorationInterfaceMemberIndex);
                }

                if (interface_index == uint32_t(-1))
                    SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");

                // mbr_type is the member type of the loaded result; expr_mbr_type is the matching member of
                // the pointee type. A swizzle is appended when the latter has more vector components.
                const auto &mbr_type = get<SPIRType>(struct_type.member_types[j]);
                const auto &expr_mbr_type = get<SPIRType>(expr_type.member_types[j]);
                if (is_matrix(mbr_type) && ptr_type.storage == StorageClassInput && !multi_patch_tess_ctl)
                {
                    // Matrix member: rebuilt from one member per column.
                    expr += type_to_glsl(mbr_type) + "(";
                    for (uint32_t k = 0; k < mbr_type.columns; k++, interface_index++)
                    {
                        if (is_array_of_struct)
                        {
                            const uint32_t indices[2] = { i, interface_index };
                            AccessChainMeta meta;
                            expr += access_chain_internal(
                                stage_in_ptr_var_id, indices, 2,
                                ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
                        }
                        else
                            expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
                        if (expr_mbr_type.vecsize > mbr_type.vecsize)
                            expr += vector_swizzle(mbr_type.vecsize, 0);

                        if (k + 1 < mbr_type.columns)
                            expr += ", ";
                    }
                    expr += ")";
                }
                else if (is_array(mbr_type))
                {
                    expr += type_to_glsl(mbr_type) + "({ ";
                    uint32_t array_size = to_array_size_literal(mbr_type, 0);
                    if (multi_patch_tess_ctl)
                    {
                        // The member index stays fixed; elements are selected with the extra index k.
                        for (uint32_t k = 0; k < array_size; k++)
                        {
                            if (is_array_of_struct)
                            {
                                const uint32_t indices[3] = { i, interface_index, k };
                                AccessChainMeta meta;
                                expr += access_chain_internal(
                                    stage_in_ptr_var_id, indices, 3,
                                    ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
                            }
                            else
                                expr += join(to_expression(ptr), ".", to_member_name(iface_type, interface_index), "[",
                                             k, "]");
                            if (expr_mbr_type.vecsize > mbr_type.vecsize)
                                expr += vector_swizzle(mbr_type.vecsize, 0);

                            if (k + 1 < array_size)
                                expr += ", ";
                        }
                    }
                    else
                    {
                        // One member per array element; the member index advances with k.
                        for (uint32_t k = 0; k < array_size; k++, interface_index++)
                        {
                            if (is_array_of_struct)
                            {
                                const uint32_t indices[2] = { i, interface_index };
                                AccessChainMeta meta;
                                expr += access_chain_internal(
                                    stage_in_ptr_var_id, indices, 2,
                                    ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
                            }
                            else
                                expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
                            if (expr_mbr_type.vecsize > mbr_type.vecsize)
                                expr += vector_swizzle(mbr_type.vecsize, 0);

                            if (k + 1 < array_size)
                                expr += ", ";
                        }
                    }
                    expr += " })";
                }
                else
                {
                    // Any other member is read from a single interface member.
                    if (is_array_of_struct)
                    {
                        const uint32_t indices[2] = { i, interface_index };
                        AccessChainMeta meta;
                        expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
                                                      ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT,
                                                      &meta);
                    }
                    else
                        expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
                    if (expr_mbr_type.vecsize > mbr_type.vecsize)
                        expr += vector_swizzle(mbr_type.vecsize, 0);
                }

                if (j + 1 < struct_type.member_types.size())
                    expr += ", ";
            }
            expr += " }";
            if (i + 1 < num_control_points)
                expr += ", ";
        }
        if (is_array_of_struct)
            expr += " })";
    }
    else if (flat_matrix)
    {
        bool is_array_of_matrix = is_array(result_type);
        if (is_array_of_matrix && !ptr_is_io_variable)
            SPIRV_CROSS_THROW("Loading array of matrix from IO variable must come directly from IO variable.");
        if (interface_index == uint32_t(-1))
            SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");

        if (is_array_of_matrix)
        {
            // Loading a matrix from each control point.
            uint32_t base_interface_index = interface_index;
            uint32_t num_control_points = to_array_size_literal(result_type, 0);
            expr += type_to_glsl(result_type) + "({ ";

            auto &matrix_type = get_variable_element_type(get<SPIRVariable>(ptr));

            for (uint32_t i = 0; i < num_control_points; i++)
            {
                // Each control point restarts from the first column's member.
                interface_index = base_interface_index;
                expr += type_to_glsl(matrix_type) + "(";
                for (uint32_t j = 0; j < result_type.columns; j++, interface_index++)
                {
                    const uint32_t indices[2] = { i, interface_index };

                    AccessChainMeta meta;
                    expr +=
                        access_chain_internal(stage_in_ptr_var_id, indices, 2,
                                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
                    if (expr_type.vecsize > result_type.vecsize)
                        expr += vector_swizzle(result_type.vecsize, 0);
                    if (j + 1 < result_type.columns)
                        expr += ", ";
                }
                expr += ")";
                if (i + 1 < num_control_points)
                    expr += ", ";
            }

            expr += " })";
        }
        else
        {
            // Single matrix: one member of the pointed-to object per column.
            expr += type_to_glsl(result_type) + "(";
            for (uint32_t i = 0; i < result_type.columns; i++, interface_index++)
            {
                expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
                if (expr_type.vecsize > result_type.vecsize)
                    expr += vector_swizzle(result_type.vecsize, 0);
                if (i + 1 < result_type.columns)
                    expr += ", ";
            }
            expr += ")";
        }
    }
    else if (ptr_is_io_variable)
    {
        assert(is_array(result_type));
        assert(result_type.array.size() == 1);
        if (interface_index == uint32_t(-1))
            SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");

        // We're loading an array directly from a global variable.
        // This means we're loading one member from each control point.
        expr += type_to_glsl(result_type) + "({ ";
        uint32_t num_control_points = to_array_size_literal(result_type, 0);

        for (uint32_t i = 0; i < num_control_points; i++)
        {
            const uint32_t indices[2] = { i, interface_index };

            AccessChainMeta meta;
            expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
                                          ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
            if (expr_type.vecsize > result_type.vecsize)
                expr += vector_swizzle(result_type.vecsize, 0);

            if (i + 1 < num_control_points)
                expr += ", ";
        }
        expr += " })";
    }
    else
    {
        // We're loading an array from a concrete control point.
        assert(is_array(result_type));
        assert(result_type.array.size() == 1);
        if (interface_index == uint32_t(-1))
            SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");

        expr += type_to_glsl(result_type) + "({ ";
        uint32_t array_size = to_array_size_literal(result_type, 0);
        for (uint32_t i = 0; i < array_size; i++, interface_index++)
        {
            expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
            if (expr_type.vecsize > result_type.vecsize)
                expr += vector_swizzle(result_type.vecsize, 0);
            if (i + 1 < array_size)
                expr += ", ";
        }
        expr += " })";
    }

    // The rebuilt composite is never forwarded (forwarding flag is false).
    emit_op(result_type_id, id, expr, false);
    register_read(id, ptr, false);
    return true;
}
6003
emit_tessellation_access_chain(const uint32_t * ops,uint32_t length)6004 bool CompilerMSL::emit_tessellation_access_chain(const uint32_t *ops, uint32_t length)
6005 {
6006 // If this is a per-vertex output, remap it to the I/O array buffer.
6007
6008 // Any object which did not go through IO flattening shenanigans will go there instead.
6009 // We will unflatten on-demand instead as needed, but not all possible cases can be supported, especially with arrays.
6010
6011 auto *var = maybe_get_backing_variable(ops[2]);
6012 bool patch = false;
6013 bool flat_data = false;
6014 bool ptr_is_chain = false;
6015 bool multi_patch = get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup;
6016
6017 if (var)
6018 {
6019 patch = has_decoration(ops[2], DecorationPatch) || is_patch_block(get_variable_data_type(*var));
6020
6021 // Should match strip_array in add_interface_block.
6022 flat_data = var->storage == StorageClassInput ||
6023 (var->storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationControl);
6024
6025 // We might have a chained access chain, where
6026 // we first take the access chain to the control point, and then we chain into a member or something similar.
6027 // In this case, we need to skip gl_in/gl_out remapping.
6028 ptr_is_chain = var->self != ID(ops[2]);
6029 }
6030
6031 BuiltIn bi_type = BuiltIn(get_decoration(ops[2], DecorationBuiltIn));
6032 if (var && flat_data && !patch &&
6033 (!is_builtin_variable(*var) || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
6034 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
6035 get_variable_data_type(*var).basetype == SPIRType::Struct))
6036 {
6037 AccessChainMeta meta;
6038 SmallVector<uint32_t> indices;
6039 uint32_t next_id = ir.increase_bound_by(1);
6040
6041 indices.reserve(length - 3 + 1);
6042
6043 uint32_t first_non_array_index = ptr_is_chain ? 3 : 4;
6044 VariableID stage_var_id = var->storage == StorageClassInput ? stage_in_ptr_var_id : stage_out_ptr_var_id;
6045 VariableID ptr = ptr_is_chain ? VariableID(ops[2]) : stage_var_id;
6046 if (!ptr_is_chain)
6047 {
6048 // Index into gl_in/gl_out with first array index.
6049 indices.push_back(ops[3]);
6050 }
6051
6052 auto &result_ptr_type = get<SPIRType>(ops[0]);
6053
6054 uint32_t const_mbr_id = next_id++;
6055 uint32_t index = get_extended_decoration(var->self, SPIRVCrossDecorationInterfaceMemberIndex);
6056 if (var->storage == StorageClassInput || has_decoration(get_variable_element_type(*var).self, DecorationBlock))
6057 {
6058 uint32_t i = first_non_array_index;
6059 auto *type = &get_variable_element_type(*var);
6060 if (index == uint32_t(-1) && length >= (first_non_array_index + 1))
6061 {
6062 // Maybe this is a struct type in the input class, in which case
6063 // we put it as a decoration on the corresponding member.
6064 index = get_extended_member_decoration(var->self, get_constant(ops[first_non_array_index]).scalar(),
6065 SPIRVCrossDecorationInterfaceMemberIndex);
6066 assert(index != uint32_t(-1));
6067 i++;
6068 type = &get<SPIRType>(type->member_types[get_constant(ops[first_non_array_index]).scalar()]);
6069 }
6070
6071 // In this case, we're poking into flattened structures and arrays, so now we have to
6072 // combine the following indices. If we encounter a non-constant index,
6073 // we're hosed.
6074 for (; i < length; ++i)
6075 {
6076 if ((multi_patch || (!is_array(*type) && !is_matrix(*type))) && type->basetype != SPIRType::Struct)
6077 break;
6078
6079 auto *c = maybe_get<SPIRConstant>(ops[i]);
6080 if (!c || c->specialization)
6081 SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable in tessellation. "
6082 "This is currently unsupported.");
6083
6084 // We're in flattened space, so just increment the member index into IO block.
6085 // We can only do this once in the current implementation, so either:
6086 // Struct, Matrix or 1-dimensional array for a control point.
6087 index += c->scalar();
6088
6089 if (type->parent_type)
6090 type = &get<SPIRType>(type->parent_type);
6091 else if (type->basetype == SPIRType::Struct)
6092 type = &get<SPIRType>(type->member_types[c->scalar()]);
6093 }
6094
6095 if ((!multi_patch && (is_matrix(result_ptr_type) || is_array(result_ptr_type))) ||
6096 result_ptr_type.basetype == SPIRType::Struct)
6097 {
6098 // We're not going to emit the actual member name, we let any further OpLoad take care of that.
6099 // Tag the access chain with the member index we're referencing.
6100 set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, index);
6101 }
6102 else
6103 {
6104 // Access the appropriate member of gl_in/gl_out.
6105 set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
6106 indices.push_back(const_mbr_id);
6107
6108 // Append any straggling access chain indices.
6109 if (i < length)
6110 indices.insert(indices.end(), ops + i, ops + length);
6111 }
6112 }
6113 else
6114 {
6115 assert(index != uint32_t(-1));
6116 set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
6117 indices.push_back(const_mbr_id);
6118
6119 indices.insert(indices.end(), ops + 4, ops + length);
6120 }
6121
6122 // We use the pointer to the base of the input/output array here,
6123 // so this is always a pointer chain.
6124 string e;
6125
6126 if (!ptr_is_chain)
6127 {
6128 // This is the start of an access chain, use ptr_chain to index into control point array.
6129 e = access_chain(ptr, indices.data(), uint32_t(indices.size()), result_ptr_type, &meta, true);
6130 }
6131 else
6132 {
6133 // If we're accessing a struct, we need to use member indices which are based on the IO block,
6134 // not actual struct type, so we have to use a split access chain here where
6135 // first path resolves the control point index, i.e. gl_in[index], and second half deals with
6136 // looking up flattened member name.
6137
6138 // However, it is possible that we partially accessed a struct,
6139 // by taking pointer to member inside the control-point array.
6140 // For this case, we fall back to a natural access chain since we have already dealt with remapping struct members.
6141 // One way to check this here is if we have 2 implied read expressions.
6142 // First one is the gl_in/gl_out struct itself, then an index into that array.
6143 // If we have traversed further, we use a normal access chain formulation.
6144 auto *ptr_expr = maybe_get<SPIRExpression>(ptr);
6145 if (ptr_expr && ptr_expr->implied_read_expressions.size() == 2)
6146 {
6147 e = join(to_expression(ptr),
6148 access_chain_internal(stage_var_id, indices.data(), uint32_t(indices.size()),
6149 ACCESS_CHAIN_CHAIN_ONLY_BIT, &meta));
6150 }
6151 else
6152 {
6153 e = access_chain_internal(ptr, indices.data(), uint32_t(indices.size()), 0, &meta);
6154 }
6155 }
6156
6157 // Get the actual type of the object that was accessed. If it's a vector type and we changed it,
6158 // then we'll need to add a swizzle.
6159 // For this, we can't necessarily rely on the type of the base expression, because it might be
6160 // another access chain, and it will therefore already have the "correct" type.
6161 auto *expr_type = &get_variable_data_type(*var);
6162 if (has_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID))
6163 expr_type = &get<SPIRType>(get_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID));
6164 for (uint32_t i = 3; i < length; i++)
6165 {
6166 if (!is_array(*expr_type) && expr_type->basetype == SPIRType::Struct)
6167 expr_type = &get<SPIRType>(expr_type->member_types[get<SPIRConstant>(ops[i]).scalar()]);
6168 else
6169 expr_type = &get<SPIRType>(expr_type->parent_type);
6170 }
6171 if (!is_array(*expr_type) && !is_matrix(*expr_type) && expr_type->basetype != SPIRType::Struct &&
6172 expr_type->vecsize > result_ptr_type.vecsize)
6173 e += vector_swizzle(result_ptr_type.vecsize, 0);
6174
6175 auto &expr = set<SPIRExpression>(ops[1], move(e), ops[0], should_forward(ops[2]));
6176 expr.loaded_from = var->self;
6177 expr.need_transpose = meta.need_transpose;
6178 expr.access_chain = true;
6179
6180 // Mark the result as being packed if necessary.
6181 if (meta.storage_is_packed)
6182 set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypePacked);
6183 if (meta.storage_physical_type != 0)
6184 set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypeID, meta.storage_physical_type);
6185 if (meta.storage_is_invariant)
6186 set_decoration(ops[1], DecorationInvariant);
6187 // Save the type we found in case the result is used in another access chain.
6188 set_extended_decoration(ops[1], SPIRVCrossDecorationTessIOOriginalInputTypeID, expr_type->self);
6189
6190 // If we have some expression dependencies in our access chain, this access chain is technically a forwarded
6191 // temporary which could be subject to invalidation.
	// Need to assume we're forwarded while calling inherit_expression_dependencies.
6193 forwarded_temporaries.insert(ops[1]);
6194 // The access chain itself is never forced to a temporary, but its dependencies might.
6195 suppressed_usage_tracking.insert(ops[1]);
6196
6197 for (uint32_t i = 2; i < length; i++)
6198 {
6199 inherit_expression_dependencies(ops[1], ops[i]);
6200 add_implied_read_expression(expr, ops[i]);
6201 }
6202
6203 // If we have no dependencies after all, i.e., all indices in the access chain are immutable temporaries,
6204 // we're not forwarded after all.
6205 if (expr.expression_dependencies.empty())
6206 forwarded_temporaries.erase(ops[1]);
6207
6208 return true;
6209 }
6210
6211 // If this is the inner tessellation level, and we're tessellating triangles,
6212 // drop the last index. It isn't an array in this case, so we can't have an
6213 // array reference here. We need to make this ID a variable instead of an
6214 // expression so we don't try to dereference it as a variable pointer.
6215 // Don't do this if the index is a constant 1, though. We need to drop stores
6216 // to that one.
6217 auto *m = ir.find_meta(var ? var->self : ID(0));
6218 if (get_execution_model() == ExecutionModelTessellationControl && var && m &&
6219 m->decoration.builtin_type == BuiltInTessLevelInner && get_entry_point().flags.get(ExecutionModeTriangles))
6220 {
6221 auto *c = maybe_get<SPIRConstant>(ops[3]);
6222 if (c && c->scalar() == 1)
6223 return false;
6224 auto &dest_var = set<SPIRVariable>(ops[1], *var);
6225 dest_var.basetype = ops[0];
6226 ir.meta[ops[1]] = ir.meta[ops[2]];
6227 inherit_expression_dependencies(ops[1], ops[2]);
6228 return true;
6229 }
6230
6231 return false;
6232 }
6233
is_out_of_bounds_tessellation_level(uint32_t id_lhs)6234 bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
6235 {
6236 if (!get_entry_point().flags.get(ExecutionModeTriangles))
6237 return false;
6238
6239 // In SPIR-V, TessLevelInner always has two elements and TessLevelOuter always has
6240 // four. This is true even if we are tessellating triangles. This allows clients
6241 // to use a single tessellation control shader with multiple tessellation evaluation
6242 // shaders.
6243 // In Metal, however, only the first element of TessLevelInner and the first three
6244 // of TessLevelOuter are accessible. This stems from how in Metal, the tessellation
6245 // levels must be stored to a dedicated buffer in a particular format that depends
6246 // on the patch type. Therefore, in Triangles mode, any access to the second
6247 // inner level or the fourth outer level must be dropped.
6248 const auto *e = maybe_get<SPIRExpression>(id_lhs);
6249 if (!e || !e->access_chain)
6250 return false;
6251 BuiltIn builtin = BuiltIn(get_decoration(e->loaded_from, DecorationBuiltIn));
6252 if (builtin != BuiltInTessLevelInner && builtin != BuiltInTessLevelOuter)
6253 return false;
6254 auto *c = maybe_get<SPIRConstant>(e->implied_read_expressions[1]);
6255 if (!c)
6256 return false;
6257 return (builtin == BuiltInTessLevelInner && c->scalar() == 1) ||
6258 (builtin == BuiltInTessLevelOuter && c->scalar() == 3);
6259 }
6260
prepare_access_chain_for_scalar_access(std::string & expr,const SPIRType & type,spv::StorageClass storage,bool & is_packed)6261 void CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
6262 spv::StorageClass storage, bool &is_packed)
6263 {
6264 // If there is any risk of writes happening with the access chain in question,
6265 // and there is a risk of concurrent write access to other components,
6266 // we must cast the access chain to a plain pointer to ensure we only access the exact scalars we expect.
6267 // The MSL compiler refuses to allow component-level access for any non-packed vector types.
6268 if (!is_packed && (storage == StorageClassStorageBuffer || storage == StorageClassWorkgroup))
6269 {
6270 const char *addr_space = storage == StorageClassWorkgroup ? "threadgroup" : "device";
6271 expr = join("((", addr_space, " ", type_to_glsl(type), "*)&", enclose_expression(expr), ")");
6272
6273 // Further indexing should happen with packed rules (array index, not swizzle).
6274 is_packed = true;
6275 }
6276 }
6277
// Override for MSL-specific syntax instructions.
// Dispatches on the SPIR-V opcode of `instruction`. Opcodes with a case below are emitted
// with MSL-specific syntax (or rejected with SPIRV_CROSS_THROW); every other opcode is
// delegated to CompilerGLSL::emit_instruction(). After the switch, the opcode is stored in
// previous_instruction_opcode so the next call can inspect it (see OpControlBarrier).
void CompilerMSL::emit_instruction(const Instruction &instruction)
{
// Shorthand wrappers around the emit_* helpers. They forward the leading operands of the
// current instruction (`ops`) and stringize the operator or function name passed as `op`.
#define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
#define MSL_BOP_CAST(op, type) \
	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
#define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
#define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
#define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
#define MSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
#define MSL_BFOP_CAST(op, type) \
	emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
#define MSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
#define MSL_UNORD_BOP(op) emit_binary_unord_op(ops[0], ops[1], ops[2], ops[3], #op)

	auto ops = stream(instruction);
	auto opcode = static_cast<Op>(instruction.op);

	// If we need to do implicit bitcasts, make sure we do it with the correct type.
	uint32_t integer_width = get_integer_width_for_instruction(instruction);
	auto int_type = to_signed_basetype(integer_width);
	auto uint_type = to_unsigned_basetype(integer_width);

	switch (opcode)
	{
	case OpLoad:
	{
		uint32_t id = ops[1];
		uint32_t ptr = ops[2];
		if (is_tessellation_shader())
		{
			// Give the tessellation I/O path the first chance; fall back to the
			// generic implementation when it reports the load was not handled.
			if (!emit_tessellation_io_load(ops[0], id, ptr))
				CompilerGLSL::emit_instruction(instruction);
		}
		else
		{
			// Sample mask input for Metal is not an array
			if (BuiltIn(get_decoration(ptr, DecorationBuiltIn)) == BuiltInSampleMask)
				set_decoration(id, DecorationBuiltIn, BuiltInSampleMask);
			CompilerGLSL::emit_instruction(instruction);
		}
		break;
	}

	// Comparisons
	// Integer comparisons go through the *_CAST variants with the signed or unsigned
	// base type chosen above; float and logical comparisons use the plain operator.
	case OpIEqual:
		MSL_BOP_CAST(==, int_type);
		break;

	case OpLogicalEqual:
	case OpFOrdEqual:
		MSL_BOP(==);
		break;

	case OpINotEqual:
		MSL_BOP_CAST(!=, int_type);
		break;

	case OpLogicalNotEqual:
	case OpFOrdNotEqual:
		MSL_BOP(!=);
		break;

	case OpUGreaterThan:
		MSL_BOP_CAST(>, uint_type);
		break;

	case OpSGreaterThan:
		MSL_BOP_CAST(>, int_type);
		break;

	case OpFOrdGreaterThan:
		MSL_BOP(>);
		break;

	case OpUGreaterThanEqual:
		MSL_BOP_CAST(>=, uint_type);
		break;

	case OpSGreaterThanEqual:
		MSL_BOP_CAST(>=, int_type);
		break;

	case OpFOrdGreaterThanEqual:
		MSL_BOP(>=);
		break;

	case OpULessThan:
		MSL_BOP_CAST(<, uint_type);
		break;

	case OpSLessThan:
		MSL_BOP_CAST(<, int_type);
		break;

	case OpFOrdLessThan:
		MSL_BOP(<);
		break;

	case OpULessThanEqual:
		MSL_BOP_CAST(<=, uint_type);
		break;

	case OpSLessThanEqual:
		MSL_BOP_CAST(<=, int_type);
		break;

	case OpFOrdLessThanEqual:
		MSL_BOP(<=);
		break;

	// Unordered float comparisons are emitted through emit_binary_unord_op.
	case OpFUnordEqual:
		MSL_UNORD_BOP(==);
		break;

	case OpFUnordNotEqual:
		MSL_UNORD_BOP(!=);
		break;

	case OpFUnordGreaterThan:
		MSL_UNORD_BOP(>);
		break;

	case OpFUnordGreaterThanEqual:
		MSL_UNORD_BOP(>=);
		break;

	case OpFUnordLessThan:
		MSL_UNORD_BOP(<);
		break;

	case OpFUnordLessThanEqual:
		MSL_UNORD_BOP(<=);
		break;

	// Derivatives
	// The Fine/Coarse variants map to the same function as the plain opcode, and each
	// result is registered as a control-dependent expression.
	case OpDPdx:
	case OpDPdxFine:
	case OpDPdxCoarse:
		MSL_UFOP(dfdx);
		register_control_dependent_expression(ops[1]);
		break;

	case OpDPdy:
	case OpDPdyFine:
	case OpDPdyCoarse:
		MSL_UFOP(dfdy);
		register_control_dependent_expression(ops[1]);
		break;

	case OpFwidth:
	case OpFwidthCoarse:
	case OpFwidthFine:
		MSL_UFOP(fwidth);
		register_control_dependent_expression(ops[1]);
		break;

	// Bitfield
	case OpBitFieldInsert:
	{
		emit_bitfield_insert_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], "insert_bits", SPIRType::UInt);
		break;
	}

	case OpBitFieldSExtract:
	{
		emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", int_type, int_type,
		                                SPIRType::UInt, SPIRType::UInt);
		break;
	}

	case OpBitFieldUExtract:
	{
		emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", uint_type, uint_type,
		                                SPIRType::UInt, SPIRType::UInt);
		break;
	}

	case OpBitReverse:
		// BitReverse does not have issues with sign since result type must match input type.
		MSL_UFOP(reverse_bits);
		break;

	case OpBitCount:
	{
		// Cast using the operand's own base type for both the input and the result.
		auto basetype = expression_type(ops[2]).basetype;
		emit_unary_func_op_cast(ops[0], ops[1], ops[2], "popcount", basetype, basetype);
		break;
	}

	case OpFRem:
		MSL_BFOP(fmod);
		break;

	// With invariant_float_math enabled, multiplication and addition are routed through
	// the spvFMul/spvFAdd functions instead of the plain operators.
	case OpFMul:
		if (msl_options.invariant_float_math)
			MSL_BFOP(spvFMul);
		else
			MSL_BOP(*);
		break;

	case OpFAdd:
		if (msl_options.invariant_float_math)
			MSL_BFOP(spvFAdd);
		else
			MSL_BOP(+);
		break;

	// Atomics
	case OpAtomicExchange:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t ptr = ops[2];
		uint32_t mem_sem = ops[4];
		uint32_t val = ops[5];
		emit_atomic_func_op(result_type, id, "atomic_exchange_explicit", mem_sem, mem_sem, false, ptr, val);
		break;
	}

	case OpAtomicCompareExchange:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t ptr = ops[2];
		// Separate memory semantics operands for the success and failure cases.
		uint32_t mem_sem_pass = ops[4];
		uint32_t mem_sem_fail = ops[5];
		uint32_t val = ops[6];
		uint32_t comp = ops[7];
		emit_atomic_func_op(result_type, id, "atomic_compare_exchange_weak_explicit", mem_sem_pass, mem_sem_fail, true,
		                    ptr, comp, true, false, val);
		break;
	}

	case OpAtomicCompareExchangeWeak:
		SPIRV_CROSS_THROW("OpAtomicCompareExchangeWeak is only supported in kernel profile.");

	case OpAtomicLoad:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t ptr = ops[2];
		uint32_t mem_sem = ops[4];
		emit_atomic_func_op(result_type, id, "atomic_load_explicit", mem_sem, mem_sem, false, ptr, 0);
		break;
	}

	case OpAtomicStore:
	{
		// This instruction is read with the pointer in ops[0]; the pointer doubles as the
		// id and supplies the result type passed to emit_atomic_func_op.
		uint32_t result_type = expression_type(ops[0]).self;
		uint32_t id = ops[0];
		uint32_t ptr = ops[0];
		uint32_t mem_sem = ops[2];
		uint32_t val = ops[3];
		emit_atomic_func_op(result_type, id, "atomic_store_explicit", mem_sem, mem_sem, false, ptr, val);
		break;
	}

// Shared body for the atomic fetch-and-modify opcodes. `valsrc` is the value operand and
// `valconst` is forwarded as the last argument of emit_atomic_func_op.
#define MSL_AFMO_IMPL(op, valsrc, valconst) \
	do \
	{ \
		uint32_t result_type = ops[0]; \
		uint32_t id = ops[1]; \
		uint32_t ptr = ops[2]; \
		uint32_t mem_sem = ops[4]; \
		uint32_t val = valsrc; \
		emit_atomic_func_op(result_type, id, "atomic_fetch_" #op "_explicit", mem_sem, mem_sem, false, ptr, val, \
		                    false, valconst); \
	} while (false)

// MSL_AFMO takes the value from ops[5]; MSL_AFMIO uses the literal 1 (increment/decrement).
#define MSL_AFMO(op) MSL_AFMO_IMPL(op, ops[5], false)
#define MSL_AFMIO(op) MSL_AFMO_IMPL(op, 1, true)

	case OpAtomicIIncrement:
		MSL_AFMIO(add);
		break;

	case OpAtomicIDecrement:
		MSL_AFMIO(sub);
		break;

	case OpAtomicIAdd:
		MSL_AFMO(add);
		break;

	case OpAtomicISub:
		MSL_AFMO(sub);
		break;

	case OpAtomicSMin:
	case OpAtomicUMin:
		MSL_AFMO(min);
		break;

	case OpAtomicSMax:
	case OpAtomicUMax:
		MSL_AFMO(max);
		break;

	case OpAtomicAnd:
		MSL_AFMO(and);
		break;

	case OpAtomicOr:
		MSL_AFMO(or);
		break;

	case OpAtomicXor:
		MSL_AFMO(xor);
		break;

	// Images

	// Reads == Fetches in Metal
	case OpImageRead:
	{
		// Mark that this shader reads from this image
		uint32_t img_id = ops[2];
		auto &type = expression_type(img_id);
		if (type.image.dim != DimSubpassData)
		{
			// Clear NonReadable on the backing variable and force a recompile
			// if the decoration was present.
			auto *p_var = maybe_get_backing_variable(img_id);
			if (p_var && has_decoration(p_var->self, DecorationNonReadable))
			{
				unset_decoration(p_var->self, DecorationNonReadable);
				force_recompile();
			}
		}

		emit_texture_op(instruction, false);
		break;
	}

	// Emulate texture2D atomic operations
	case OpImageTexelPointer:
	{
		// When using the pointer, we need to know which variable it is actually loaded from.
		auto *var = maybe_get_backing_variable(ops[2]);
		if (var && atomic_image_vars.count(var->self))
		{
			// The image is registered in atomic_image_vars: address its "_atomic"
			// companion by coordinate.
			uint32_t result_type = ops[0];
			uint32_t id = ops[1];

			std::string coord = to_expression(ops[3]);
			auto &type = expression_type(ops[2]);
			if (type.image.dim == Dim2D)
			{
				// 2D coordinates are passed through spvImage2DAtomicCoord together with the image.
				coord = join("spvImage2DAtomicCoord(", coord, ", ", to_expression(ops[2]), ")");
			}

			auto &e = set<SPIRExpression>(id, join(to_expression(ops[2]), "_atomic[", coord, "]"), result_type, true);
			e.loaded_from = var ? var->self : ID(0);
			inherit_expression_dependencies(id, ops[3]);
		}
		else
		{
			// Otherwise the expression is simply "<image>, <coord>".
			uint32_t result_type = ops[0];
			uint32_t id = ops[1];
			auto &e =
			    set<SPIRExpression>(id, join(to_expression(ops[2]), ", ", to_expression(ops[3])), result_type, true);

			// When using the pointer, we need to know which variable it is actually loaded from.
			e.loaded_from = var ? var->self : ID(0);
			inherit_expression_dependencies(id, ops[3]);
		}
		break;
	}

	case OpImageWrite:
	{
		uint32_t img_id = ops[0];
		uint32_t coord_id = ops[1];
		uint32_t texel_id = ops[2];
		// Remaining words, if any, are the image-operands mask followed by its arguments.
		const uint32_t *opt = &ops[3];
		uint32_t length = instruction.length - 3;

		// Bypass pointers because we need the real image struct
		auto &type = expression_type(img_id);
		auto &img_type = get<SPIRType>(type.self);

		// Ensure this image has been marked as being written to and force a
		// recompile so that the image type output will include write access
		auto *p_var = maybe_get_backing_variable(img_id);
		if (p_var && has_decoration(p_var->self, DecorationNonWritable))
		{
			unset_decoration(p_var->self, DecorationNonWritable);
			force_recompile();
		}

		bool forward = false;
		uint32_t bias = 0;
		uint32_t lod = 0;
		uint32_t flags = 0;

		if (length)
		{
			flags = *opt++;
			length--;
		}

		// Consumes the next optional word into `v` when `flag` is set in the mask.
		auto test = [&](uint32_t &v, uint32_t flag) {
			if (length && (flags & flag))
			{
				v = *opt++;
				length--;
			}
		};

		// Bias is consumed to keep the operand cursor in step; only lod is passed on below.
		test(bias, ImageOperandsBiasMask);
		test(lod, ImageOperandsLodMask);

		// The texel is remapped to a 4-component value of the same base type for the write.
		auto &texel_type = expression_type(texel_id);
		auto store_type = texel_type;
		store_type.vecsize = 4;

		TextureFunctionArguments args = {};
		args.base.img = img_id;
		args.base.imgtype = &img_type;
		args.base.is_fetch = true;
		args.coord = coord_id;
		args.lod = lod;
		statement(join(to_expression(img_id), ".write(",
		               remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)), ", ",
		               CompilerMSL::to_function_args(args, &forward), ");"));

		if (p_var && variable_storage_is_aliased(*p_var))
			flush_all_aliased_variables();

		break;
	}

	case OpImageQuerySize:
	case OpImageQuerySizeLod:
	{
		uint32_t rslt_type_id = ops[0];
		auto &rslt_type = get<SPIRType>(rslt_type_id);

		uint32_t id = ops[1];

		uint32_t img_id = ops[2];
		string img_exp = to_expression(img_id);
		auto &img_type = expression_type(img_id);
		Dim img_dim = img_type.image.dim;
		bool img_is_array = img_type.image.arrayed;

		if (img_type.basetype != SPIRType::Image)
			SPIRV_CROSS_THROW("Invalid type for OpImageQuerySize.");

		string lod;
		if (opcode == OpImageQuerySizeLod)
		{
			// LOD index defaults to zero, so don't bother outputting level zero index
			string decl_lod = to_expression(ops[3]);
			if (decl_lod != "0")
				lod = decl_lod;
		}

		// Build a constructor of the result type from width, then height/depth/array size
		// as the image dimensionality requires.
		string expr = type_to_glsl(rslt_type) + "(";
		expr += img_exp + ".get_width(" + lod + ")";

		if (img_dim == Dim2D || img_dim == DimCube || img_dim == Dim3D)
			expr += ", " + img_exp + ".get_height(" + lod + ")";

		if (img_dim == Dim3D)
			expr += ", " + img_exp + ".get_depth(" + lod + ")";

		if (img_is_array)
		{
			expr += ", " + img_exp + ".get_array_size()";
			// With emulate_cube_array, the reported array size of a cube image is divided by 6.
			if (img_dim == DimCube && msl_options.emulate_cube_array)
				expr += " / 6";
		}

		expr += ")";

		emit_op(rslt_type_id, id, expr, should_forward(img_id));

		break;
	}

	case OpImageQueryLod:
	{
		if (!msl_options.supports_msl_version(2, 2))
			SPIRV_CROSS_THROW("ImageQueryLod is only supported on MSL 2.2 and up.");
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t image_id = ops[2];
		uint32_t coord_id = ops[3];
		// The result is declared first, then its .x and .y are assigned in separate statements.
		emit_uninitialized_temporary_expression(result_type, id);

		auto sampler_expr = to_sampler_expression(image_id);
		auto *combined = maybe_get<SPIRCombinedImageSampler>(image_id);
		auto image_expr = combined ? to_expression(combined->image) : to_expression(image_id);

		// TODO: It is unclear if calculate_clamped_lod also conditionally rounds
		// the reported LOD based on the sampler. NEAREST miplevel should
		// round the LOD, but LINEAR miplevel should not round.
		// Let's hope this does not become an issue ...
		statement(to_expression(id), ".x = ", image_expr, ".calculate_clamped_lod(", sampler_expr, ", ",
		          to_expression(coord_id), ");");
		statement(to_expression(id), ".y = ", image_expr, ".calculate_unclamped_lod(", sampler_expr, ", ",
		          to_expression(coord_id), ");");
		register_control_dependent_expression(id);
		break;
	}

// Emits "<result type>(<image>.get_num_<qrytype>())" for the simple image count queries.
#define MSL_ImgQry(qrytype) \
	do \
	{ \
		uint32_t rslt_type_id = ops[0]; \
		auto &rslt_type = get<SPIRType>(rslt_type_id); \
		uint32_t id = ops[1]; \
		uint32_t img_id = ops[2]; \
		string img_exp = to_expression(img_id); \
		string expr = type_to_glsl(rslt_type) + "(" + img_exp + ".get_num_" #qrytype "())"; \
		emit_op(rslt_type_id, id, expr, should_forward(img_id)); \
	} while (false)

	case OpImageQueryLevels:
		MSL_ImgQry(mip_levels);
		break;

	case OpImageQuerySamples:
		MSL_ImgQry(samples);
		break;

	case OpImage:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);

		if (combined)
		{
			// Combined image-sampler: forward the expression of its image half.
			auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
			auto *var = maybe_get_backing_variable(combined->image);
			if (var)
				e.loaded_from = var->self;
		}
		else
		{
			auto *var = maybe_get_backing_variable(ops[2]);
			SPIRExpression *e;
			// Variables tagged as dynamic image samplers expose the image as ".plane0".
			if (var && has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler))
				e = &emit_op(result_type, id, join(to_expression(ops[2]), ".plane0"), true, true);
			else
				e = &emit_op(result_type, id, to_expression(ops[2]), true, true);
			if (var)
				e->loaded_from = var->self;
		}
		break;
	}

	// Casting
	case OpQuantizeToF16:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t arg = ops[2];

		string exp;
		auto &type = get<SPIRType>(result_type);

		// Round-trip through half of the matching vector width.
		switch (type.vecsize)
		{
		case 1:
			exp = join("float(half(", to_expression(arg), "))");
			break;
		case 2:
			exp = join("float2(half2(", to_expression(arg), "))");
			break;
		case 3:
			exp = join("float3(half3(", to_expression(arg), "))");
			break;
		case 4:
			exp = join("float4(half4(", to_expression(arg), "))");
			break;
		default:
			SPIRV_CROSS_THROW("Illegal argument to OpQuantizeToF16.");
		}

		emit_op(result_type, id, exp, should_forward(arg));
		break;
	}

	case OpInBoundsAccessChain:
	case OpAccessChain:
	case OpPtrAccessChain:
		if (is_tessellation_shader())
		{
			// Fall back to the generic implementation when the tessellation helper
			// reports the chain was not handled.
			if (!emit_tessellation_access_chain(ops, instruction.length))
				CompilerGLSL::emit_instruction(instruction);
		}
		else
			CompilerGLSL::emit_instruction(instruction);
		break;

	case OpStore:
		// Stores to tessellation level elements flagged as out of bounds are dropped.
		if (is_out_of_bounds_tessellation_level(ops[0]))
			break;

		// Array assignments may be emitted by a dedicated helper instead.
		if (maybe_emit_array_assignment(ops[0], ops[1]))
			break;

		CompilerGLSL::emit_instruction(instruction);
		break;

	// Compute barriers
	case OpMemoryBarrier:
		emit_barrier(0, ops[0], ops[1]);
		break;

	case OpControlBarrier:
		// In GLSL a memory barrier is often followed by a control barrier.
		// But in MSL, memory barriers are also control barriers, so don't
		// emit a simple control barrier if a memory barrier has just been emitted.
		if (previous_instruction_opcode != OpMemoryBarrier)
			emit_barrier(ops[0], ops[1], ops[2]);
		break;

	case OpOuterProduct:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t a = ops[2];
		uint32_t b = ops[3];

		// Construct the result column by column: column `col` is a * b[col].
		auto &type = get<SPIRType>(result_type);
		string expr = type_to_glsl_constructor(type);
		expr += "(";
		for (uint32_t col = 0; col < type.columns; col++)
		{
			expr += to_enclosed_expression(a);
			expr += " * ";
			expr += to_extract_component_expression(b, col);
			if (col + 1 < type.columns)
				expr += ", ";
		}
		expr += ")";
		emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
		inherit_expression_dependencies(id, a);
		inherit_expression_dependencies(id, b);
		break;
	}

	case OpVectorTimesMatrix:
	case OpMatrixTimesVector:
	{
		// Only the invariant-float-math mode needs MSL-specific handling here.
		if (!msl_options.invariant_float_math)
		{
			CompilerGLSL::emit_instruction(instruction);
			break;
		}

		// If the matrix needs transpose, just flip the multiply order.
		auto *e = maybe_get<SPIRExpression>(ops[opcode == OpMatrixTimesVector ? 2 : 3]);
		if (e && e->need_transpose)
		{
			// Temporarily clear the flag while the operand expressions are built;
			// it is restored after emit_op below.
			e->need_transpose = false;
			string expr;

			if (opcode == OpMatrixTimesVector)
			{
				expr = join("spvFMulVectorMatrix(", to_enclosed_unpacked_expression(ops[3]), ", ",
				            to_unpacked_row_major_matrix_expression(ops[2]), ")");
			}
			else
			{
				expr = join("spvFMulMatrixVector(", to_unpacked_row_major_matrix_expression(ops[3]), ", ",
				            to_enclosed_unpacked_expression(ops[2]), ")");
			}

			bool forward = should_forward(ops[2]) && should_forward(ops[3]);
			emit_op(ops[0], ops[1], expr, forward);
			e->need_transpose = true;
			inherit_expression_dependencies(ops[1], ops[2]);
			inherit_expression_dependencies(ops[1], ops[3]);
		}
		else
		{
			if (opcode == OpMatrixTimesVector)
				MSL_BFOP(spvFMulMatrixVector);
			else
				MSL_BFOP(spvFMulVectorMatrix);
		}
		break;
	}

	case OpMatrixTimesMatrix:
	{
		// Only the invariant-float-math mode needs MSL-specific handling here.
		if (!msl_options.invariant_float_math)
		{
			CompilerGLSL::emit_instruction(instruction);
			break;
		}

		auto *a = maybe_get<SPIRExpression>(ops[2]);
		auto *b = maybe_get<SPIRExpression>(ops[3]);

		// If both matrices need transpose, we can multiply in flipped order and tag the expression as transposed.
		// a^T * b^T = (b * a)^T.
		if (a && b && a->need_transpose && b->need_transpose)
		{
			// Temporarily clear the flags while the operand expressions are built;
			// they are restored after emit_op below.
			a->need_transpose = false;
			b->need_transpose = false;

			auto expr =
			    join("spvFMulMatrixMatrix(", enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), ", ",
			         enclose_expression(to_unpacked_row_major_matrix_expression(ops[2])), ")");

			bool forward = should_forward(ops[2]) && should_forward(ops[3]);
			auto &e = emit_op(ops[0], ops[1], expr, forward);
			e.need_transpose = true;
			a->need_transpose = true;
			b->need_transpose = true;
			inherit_expression_dependencies(ops[1], ops[2]);
			inherit_expression_dependencies(ops[1], ops[3]);
		}
		else
			MSL_BFOP(spvFMulMatrixMatrix);

		break;
	}

	case OpIAddCarry:
	case OpISubBorrow:
	{
		uint32_t result_type = ops[0];
		uint32_t result_id = ops[1];
		uint32_t op0 = ops[2];
		uint32_t op1 = ops[3];
		auto &type = get<SPIRType>(result_type);
		emit_uninitialized_temporary_expression(result_type, result_id);

		// Member 0 of the result struct receives the sum/difference; member 1 receives a
		// select()-derived flag whose type is that of member 1.
		auto &res_type = get<SPIRType>(type.member_types[1]);
		if (opcode == OpIAddCarry)
		{
			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " + ",
			          to_enclosed_expression(op1), ";");
			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
			          "(1), ", type_to_glsl(res_type), "(0), ", to_expression(result_id), ".", to_member_name(type, 0),
			          " >= max(", to_expression(op0), ", ", to_expression(op1), "));");
		}
		else
		{
			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " - ",
			          to_enclosed_expression(op1), ";");
			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
			          "(1), ", type_to_glsl(res_type), "(0), ", to_enclosed_expression(op0),
			          " >= ", to_enclosed_expression(op1), ");");
		}
		break;
	}

	case OpUMulExtended:
	case OpSMulExtended:
	{
		uint32_t result_type = ops[0];
		uint32_t result_id = ops[1];
		uint32_t op0 = ops[2];
		uint32_t op1 = ops[3];
		auto &type = get<SPIRType>(result_type);
		emit_uninitialized_temporary_expression(result_type, result_id);

		// Member 0 receives the plain product, member 1 the mulhi() of the operands.
		statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " * ",
		          to_enclosed_expression(op1), ";");
		statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(", to_expression(op0), ", ",
		          to_expression(op1), ");");
		break;
	}

	case OpArrayLength:
	{
		// Length = (buffer size - offset of the array member) / array stride.
		auto &type = expression_type(ops[2]);
		uint32_t offset = type_struct_member_offset(type, ops[3]);
		uint32_t stride = type_struct_member_array_stride(type, ops[3]);

		auto expr = join("(", to_buffer_size_expression(ops[2]), " - ", offset, ") / ", stride);
		emit_op(ops[0], ops[1], expr, true);
		break;
	}

	// SPV_INTEL_shader_integer_functions2
	case OpUCountLeadingZerosINTEL:
		MSL_UFOP(clz);
		break;

	case OpUCountTrailingZerosINTEL:
		MSL_UFOP(ctz);
		break;

	case OpAbsISubINTEL:
	case OpAbsUSubINTEL:
		MSL_BFOP(absdiff);
		break;

	case OpIAddSatINTEL:
	case OpUAddSatINTEL:
		MSL_BFOP(addsat);
		break;

	case OpIAverageINTEL:
	case OpUAverageINTEL:
		MSL_BFOP(hadd);
		break;

	case OpIAverageRoundedINTEL:
	case OpUAverageRoundedINTEL:
		MSL_BFOP(rhadd);
		break;

	case OpISubSatINTEL:
	case OpUSubSatINTEL:
		MSL_BFOP(subsat);
		break;

	case OpIMul32x16INTEL:
	{
		// Both operands are cast to short and then to int before multiplying.
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t a = ops[2], b = ops[3];
		bool forward = should_forward(a) && should_forward(b);
		emit_op(result_type, id, join("int(short(", to_expression(a), ")) * int(short(", to_expression(b), "))"),
		        forward);
		inherit_expression_dependencies(id, a);
		inherit_expression_dependencies(id, b);
		break;
	}

	case OpUMul32x16INTEL:
	{
		// Both operands are cast to ushort and then to uint before multiplying.
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t a = ops[2], b = ops[3];
		bool forward = should_forward(a) && should_forward(b);
		emit_op(result_type, id, join("uint(ushort(", to_expression(a), ")) * uint(ushort(", to_expression(b), "))"),
		        forward);
		inherit_expression_dependencies(id, a);
		inherit_expression_dependencies(id, b);
		break;
	}

	case OpIsHelperInvocationEXT:
		if (msl_options.is_ios())
			SPIRV_CROSS_THROW("simd_is_helper_thread() is only supported on macOS.");
		else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
			SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
		emit_op(ops[0], ops[1], "simd_is_helper_thread()", false);
		break;

	case OpBeginInvocationInterlockEXT:
	case OpEndInvocationInterlockEXT:
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("Raster order groups require MSL 2.0.");
		break; // Nothing to do in the body

	default:
		// Everything without MSL-specific handling uses the GLSL implementation.
		CompilerGLSL::emit_instruction(instruction);
		break;
	}

	// Remembered so the next instruction can check what preceded it.
	previous_instruction_opcode = opcode;
}
7141
emit_texture_op(const Instruction & i,bool sparse)7142 void CompilerMSL::emit_texture_op(const Instruction &i, bool sparse)
7143 {
7144 if (sparse)
7145 SPIRV_CROSS_THROW("Sparse feedback not yet supported in MSL.");
7146
7147 if (msl_options.is_ios() && msl_options.ios_use_framebuffer_fetch_subpasses)
7148 {
7149 auto *ops = stream(i);
7150
7151 uint32_t result_type_id = ops[0];
7152 uint32_t id = ops[1];
7153 uint32_t img = ops[2];
7154
7155 auto &type = expression_type(img);
7156 auto &imgtype = get<SPIRType>(type.self);
7157
7158 // Use Metal's native frame-buffer fetch API for subpass inputs.
7159 if (imgtype.image.dim == DimSubpassData)
7160 {
7161 // Subpass inputs cannot be invalidated,
7162 // so just forward the expression directly.
7163 string expr = to_expression(img);
7164 emit_op(result_type_id, id, expr, true);
7165 return;
7166 }
7167 }
7168
7169 // Fallback to default implementation
7170 CompilerGLSL::emit_texture_op(i, sparse);
7171 }
7172
emit_barrier(uint32_t id_exe_scope,uint32_t id_mem_scope,uint32_t id_mem_sem)7173 void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem)
7174 {
7175 if (get_execution_model() != ExecutionModelGLCompute && get_execution_model() != ExecutionModelTessellationControl)
7176 return;
7177
7178 uint32_t exe_scope = id_exe_scope ? evaluate_constant_u32(id_exe_scope) : uint32_t(ScopeInvocation);
7179 uint32_t mem_scope = id_mem_scope ? evaluate_constant_u32(id_mem_scope) : uint32_t(ScopeInvocation);
7180 // Use the wider of the two scopes (smaller value)
7181 exe_scope = min(exe_scope, mem_scope);
7182
7183 string bar_stmt;
7184 if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
7185 bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
7186 else
7187 bar_stmt = "threadgroup_barrier";
7188 bar_stmt += "(";
7189
7190 uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id_mem_sem) : uint32_t(MemorySemanticsMaskNone);
7191
7192 // Use the | operator to combine flags if we can.
7193 if (msl_options.supports_msl_version(1, 2))
7194 {
7195 string mem_flags = "";
7196 // For tesc shaders, this also affects objects in the Output storage class.
7197 // Since in Metal, these are placed in a device buffer, we have to sync device memory here.
7198 if (get_execution_model() == ExecutionModelTessellationControl ||
7199 (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)))
7200 mem_flags += "mem_flags::mem_device";
7201
7202 // Fix tessellation patch function processing
7203 if (get_execution_model() == ExecutionModelTessellationControl ||
7204 (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
7205 {
7206 if (!mem_flags.empty())
7207 mem_flags += " | ";
7208 mem_flags += "mem_flags::mem_threadgroup";
7209 }
7210 if (mem_sem & MemorySemanticsImageMemoryMask)
7211 {
7212 if (!mem_flags.empty())
7213 mem_flags += " | ";
7214 mem_flags += "mem_flags::mem_texture";
7215 }
7216
7217 if (mem_flags.empty())
7218 mem_flags = "mem_flags::mem_none";
7219
7220 bar_stmt += mem_flags;
7221 }
7222 else
7223 {
7224 if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) &&
7225 (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
7226 bar_stmt += "mem_flags::mem_device_and_threadgroup";
7227 else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))
7228 bar_stmt += "mem_flags::mem_device";
7229 else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask))
7230 bar_stmt += "mem_flags::mem_threadgroup";
7231 else if (mem_sem & MemorySemanticsImageMemoryMask)
7232 bar_stmt += "mem_flags::mem_texture";
7233 else
7234 bar_stmt += "mem_flags::mem_none";
7235 }
7236
7237 bar_stmt += ");";
7238
7239 statement(bar_stmt);
7240
7241 assert(current_emitting_block);
7242 flush_control_dependent_expressions(current_emitting_block->self);
7243 flush_all_active_variables();
7244 }
7245
emit_array_copy(const string & lhs,uint32_t rhs_id,StorageClass lhs_storage,StorageClass rhs_storage)7246 void CompilerMSL::emit_array_copy(const string &lhs, uint32_t rhs_id, StorageClass lhs_storage,
7247 StorageClass rhs_storage)
7248 {
7249 // Allow Metal to use the array<T> template to make arrays a value type.
7250 // This, however, cannot be used for threadgroup address specifiers, so consider the custom array copy as fallback.
7251 bool lhs_thread = (lhs_storage == StorageClassOutput || lhs_storage == StorageClassFunction ||
7252 lhs_storage == StorageClassGeneric || lhs_storage == StorageClassPrivate);
7253 bool rhs_thread = (rhs_storage == StorageClassInput || rhs_storage == StorageClassFunction ||
7254 rhs_storage == StorageClassGeneric || rhs_storage == StorageClassPrivate);
7255
7256 // If threadgroup storage qualifiers are *not* used:
7257 // Avoid spvCopy* wrapper functions; Otherwise, spvUnsafeArray<> template cannot be used with that storage qualifier.
7258 if (lhs_thread && rhs_thread && !using_builtin_array())
7259 {
7260 statement(lhs, " = ", to_expression(rhs_id), ";");
7261 }
7262 else
7263 {
7264 // Assignment from an array initializer is fine.
7265 auto &type = expression_type(rhs_id);
7266 auto *var = maybe_get_backing_variable(rhs_id);
7267
7268 // Unfortunately, we cannot template on address space in MSL,
7269 // so explicit address space redirection it is ...
7270 bool is_constant = false;
7271 if (ir.ids[rhs_id].get_type() == TypeConstant)
7272 {
7273 is_constant = true;
7274 }
7275 else if (var && var->remapped_variable && var->statically_assigned &&
7276 ir.ids[var->static_expression].get_type() == TypeConstant)
7277 {
7278 is_constant = true;
7279 }
7280 else if (rhs_storage == StorageClassUniform)
7281 {
7282 is_constant = true;
7283 }
7284
7285 // For the case where we have OpLoad triggering an array copy,
7286 // we cannot easily detect this case ahead of time since it's
7287 // context dependent. We might have to force a recompile here
7288 // if this is the only use of array copies in our shader.
7289 if (type.array.size() > 1)
7290 {
7291 if (type.array.size() > kArrayCopyMultidimMax)
7292 SPIRV_CROSS_THROW("Cannot support this many dimensions for arrays of arrays.");
7293 auto func = static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + type.array.size());
7294 add_spv_func_and_recompile(func);
7295 }
7296 else
7297 add_spv_func_and_recompile(SPVFuncImplArrayCopy);
7298
7299 const char *tag = nullptr;
7300 if (lhs_thread && is_constant)
7301 tag = "FromConstantToStack";
7302 else if (lhs_storage == StorageClassWorkgroup && is_constant)
7303 tag = "FromConstantToThreadGroup";
7304 else if (lhs_thread && rhs_thread)
7305 tag = "FromStackToStack";
7306 else if (lhs_storage == StorageClassWorkgroup && rhs_thread)
7307 tag = "FromStackToThreadGroup";
7308 else if (lhs_thread && rhs_storage == StorageClassWorkgroup)
7309 tag = "FromThreadGroupToStack";
7310 else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassWorkgroup)
7311 tag = "FromThreadGroupToThreadGroup";
7312 else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassStorageBuffer)
7313 tag = "FromDeviceToDevice";
7314 else if (lhs_storage == StorageClassStorageBuffer && is_constant)
7315 tag = "FromConstantToDevice";
7316 else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassWorkgroup)
7317 tag = "FromThreadGroupToDevice";
7318 else if (lhs_storage == StorageClassStorageBuffer && rhs_thread)
7319 tag = "FromStackToDevice";
7320 else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassStorageBuffer)
7321 tag = "FromDeviceToThreadGroup";
7322 else if (lhs_thread && rhs_storage == StorageClassStorageBuffer)
7323 tag = "FromDeviceToStack";
7324 else
7325 SPIRV_CROSS_THROW("Unknown storage class used for copying arrays.");
7326
7327 // Pass internal array of spvUnsafeArray<> into wrapper functions
7328 if (lhs_thread && !msl_options.force_native_arrays)
7329 statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ".elements, ", to_expression(rhs_id), ");");
7330 else if (rhs_thread && !msl_options.force_native_arrays)
7331 statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ".elements);");
7332 else
7333 statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ");");
7334 }
7335 }
7336
7337 // Since MSL does not allow arrays to be copied via simple variable assignment,
7338 // if the LHS and RHS represent an assignment of an entire array, it must be
7339 // implemented by calling an array copy function.
7340 // Returns whether the struct assignment was emitted.
maybe_emit_array_assignment(uint32_t id_lhs,uint32_t id_rhs)7341 bool CompilerMSL::maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs)
7342 {
7343 // We only care about assignments of an entire array
7344 auto &type = expression_type(id_rhs);
7345 if (type.array.size() == 0)
7346 return false;
7347
7348 auto *var = maybe_get<SPIRVariable>(id_lhs);
7349
7350 // Is this a remapped, static constant? Don't do anything.
7351 if (var && var->remapped_variable && var->statically_assigned)
7352 return true;
7353
7354 if (ir.ids[id_rhs].get_type() == TypeConstant && var && var->deferred_declaration)
7355 {
7356 // Special case, if we end up declaring a variable when assigning the constant array,
7357 // we can avoid the copy by directly assigning the constant expression.
7358 // This is likely necessary to be able to use a variable as a true look-up table, as it is unlikely
7359 // the compiler will be able to optimize the spvArrayCopy() into a constant LUT.
7360 // After a variable has been declared, we can no longer assign constant arrays in MSL unfortunately.
7361 statement(to_expression(id_lhs), " = ", constant_expression(get<SPIRConstant>(id_rhs)), ";");
7362 return true;
7363 }
7364
7365 // Ensure the LHS variable has been declared
7366 auto *p_v_lhs = maybe_get_backing_variable(id_lhs);
7367 if (p_v_lhs)
7368 flush_variable_declaration(p_v_lhs->self);
7369
7370 emit_array_copy(to_expression(id_lhs), id_rhs, get_expression_effective_storage_class(id_lhs),
7371 get_expression_effective_storage_class(id_rhs));
7372 register_write(id_lhs);
7373
7374 return true;
7375 }
7376
7377 // Emits one of the atomic functions. In MSL, the atomic functions operate on pointers
emit_atomic_func_op(uint32_t result_type,uint32_t result_id,const char * op,uint32_t mem_order_1,uint32_t mem_order_2,bool has_mem_order_2,uint32_t obj,uint32_t op1,bool op1_is_pointer,bool op1_is_literal,uint32_t op2)7378 void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, uint32_t mem_order_1,
7379 uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
7380 bool op1_is_pointer, bool op1_is_literal, uint32_t op2)
7381 {
7382 string exp = string(op) + "(";
7383
7384 auto &type = get_pointee_type(expression_type(obj));
7385 exp += "(";
7386 auto *var = maybe_get_backing_variable(obj);
7387 if (!var)
7388 SPIRV_CROSS_THROW("No backing variable for atomic operation.");
7389
7390 // Emulate texture2D atomic operations
7391 const auto &res_type = get<SPIRType>(var->basetype);
7392 if (res_type.storage == StorageClassUniformConstant && res_type.basetype == SPIRType::Image)
7393 {
7394 exp += "device";
7395 }
7396 else
7397 {
7398 exp += get_argument_address_space(*var);
7399 }
7400
7401 exp += " atomic_";
7402 exp += type_to_glsl(type);
7403 exp += "*)";
7404
7405 exp += "&";
7406 exp += to_enclosed_expression(obj);
7407
7408 bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;
7409
7410 if (is_atomic_compare_exchange_strong)
7411 {
7412 assert(strcmp(op, "atomic_compare_exchange_weak_explicit") == 0);
7413 assert(op2);
7414 assert(has_mem_order_2);
7415 exp += ", &";
7416 exp += to_name(result_id);
7417 exp += ", ";
7418 exp += to_expression(op2);
7419 exp += ", ";
7420 exp += get_memory_order(mem_order_1);
7421 exp += ", ";
7422 exp += get_memory_order(mem_order_2);
7423 exp += ")";
7424
7425 // MSL only supports the weak atomic compare exchange, so emit a CAS loop here.
7426 // The MSL function returns false if the atomic write fails OR the comparison test fails,
7427 // so we must validate that it wasn't the comparison test that failed before continuing
7428 // the CAS loop, otherwise it will loop infinitely, with the comparison test always failing.
7429 // The function updates the comparitor value from the memory value, so the additional
7430 // comparison test evaluates the memory value against the expected value.
7431 emit_uninitialized_temporary_expression(result_type, result_id);
7432 statement("do");
7433 begin_scope();
7434 statement(to_name(result_id), " = ", to_expression(op1), ";");
7435 end_scope_decl(join("while (!", exp, " && ", to_name(result_id), " == ", to_enclosed_expression(op1), ")"));
7436 }
7437 else
7438 {
7439 assert(strcmp(op, "atomic_compare_exchange_weak_explicit") != 0);
7440 if (op1)
7441 {
7442 if (op1_is_literal)
7443 exp += join(", ", op1);
7444 else
7445 exp += ", " + to_expression(op1);
7446 }
7447 if (op2)
7448 exp += ", " + to_expression(op2);
7449
7450 exp += string(", ") + get_memory_order(mem_order_1);
7451 if (has_mem_order_2)
7452 exp += string(", ") + get_memory_order(mem_order_2);
7453
7454 exp += ")";
7455
7456 if (strcmp(op, "atomic_store_explicit") != 0)
7457 emit_op(result_type, result_id, exp, false);
7458 else
7459 statement(exp, ";");
7460 }
7461
7462 flush_all_atomic_capable_variables();
7463 }
7464
7465 // Metal only supports relaxed memory order for now
get_memory_order(uint32_t)7466 const char *CompilerMSL::get_memory_order(uint32_t)
7467 {
7468 return "memory_order_relaxed";
7469 }
7470
7471 // Override for MSL-specific extension syntax instructions
emit_glsl_op(uint32_t result_type,uint32_t id,uint32_t eop,const uint32_t * args,uint32_t count)7472 void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
7473 {
7474 auto op = static_cast<GLSLstd450>(eop);
7475
7476 // If we need to do implicit bitcasts, make sure we do it with the correct type.
7477 uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
7478 auto int_type = to_signed_basetype(integer_width);
7479 auto uint_type = to_unsigned_basetype(integer_width);
7480
7481 switch (op)
7482 {
7483 case GLSLstd450Atan2:
7484 emit_binary_func_op(result_type, id, args[0], args[1], "atan2");
7485 break;
7486 case GLSLstd450InverseSqrt:
7487 emit_unary_func_op(result_type, id, args[0], "rsqrt");
7488 break;
7489 case GLSLstd450RoundEven:
7490 emit_unary_func_op(result_type, id, args[0], "rint");
7491 break;
7492
7493 case GLSLstd450FindILsb:
7494 {
7495 // In this template version of findLSB, we return T.
7496 auto basetype = expression_type(args[0]).basetype;
7497 emit_unary_func_op_cast(result_type, id, args[0], "spvFindLSB", basetype, basetype);
7498 break;
7499 }
7500
7501 case GLSLstd450FindSMsb:
7502 emit_unary_func_op_cast(result_type, id, args[0], "spvFindSMSB", int_type, int_type);
7503 break;
7504
7505 case GLSLstd450FindUMsb:
7506 emit_unary_func_op_cast(result_type, id, args[0], "spvFindUMSB", uint_type, uint_type);
7507 break;
7508
7509 case GLSLstd450PackSnorm4x8:
7510 emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm4x8");
7511 break;
7512 case GLSLstd450PackUnorm4x8:
7513 emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm4x8");
7514 break;
7515 case GLSLstd450PackSnorm2x16:
7516 emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm2x16");
7517 break;
7518 case GLSLstd450PackUnorm2x16:
7519 emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm2x16");
7520 break;
7521
7522 case GLSLstd450PackHalf2x16:
7523 {
7524 auto expr = join("as_type<uint>(half2(", to_expression(args[0]), "))");
7525 emit_op(result_type, id, expr, should_forward(args[0]));
7526 inherit_expression_dependencies(id, args[0]);
7527 break;
7528 }
7529
7530 case GLSLstd450UnpackSnorm4x8:
7531 emit_unary_func_op(result_type, id, args[0], "unpack_snorm4x8_to_float");
7532 break;
7533 case GLSLstd450UnpackUnorm4x8:
7534 emit_unary_func_op(result_type, id, args[0], "unpack_unorm4x8_to_float");
7535 break;
7536 case GLSLstd450UnpackSnorm2x16:
7537 emit_unary_func_op(result_type, id, args[0], "unpack_snorm2x16_to_float");
7538 break;
7539 case GLSLstd450UnpackUnorm2x16:
7540 emit_unary_func_op(result_type, id, args[0], "unpack_unorm2x16_to_float");
7541 break;
7542
7543 case GLSLstd450UnpackHalf2x16:
7544 {
7545 auto expr = join("float2(as_type<half2>(", to_expression(args[0]), "))");
7546 emit_op(result_type, id, expr, should_forward(args[0]));
7547 inherit_expression_dependencies(id, args[0]);
7548 break;
7549 }
7550
7551 case GLSLstd450PackDouble2x32:
7552 emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450PackDouble2x32"); // Currently unsupported
7553 break;
7554 case GLSLstd450UnpackDouble2x32:
7555 emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450UnpackDouble2x32"); // Currently unsupported
7556 break;
7557
7558 case GLSLstd450MatrixInverse:
7559 {
7560 auto &mat_type = get<SPIRType>(result_type);
7561 switch (mat_type.columns)
7562 {
7563 case 2:
7564 emit_unary_func_op(result_type, id, args[0], "spvInverse2x2");
7565 break;
7566 case 3:
7567 emit_unary_func_op(result_type, id, args[0], "spvInverse3x3");
7568 break;
7569 case 4:
7570 emit_unary_func_op(result_type, id, args[0], "spvInverse4x4");
7571 break;
7572 default:
7573 break;
7574 }
7575 break;
7576 }
7577
7578 case GLSLstd450FMin:
7579 // If the result type isn't float, don't bother calling the specific
7580 // precise::/fast:: version. Metal doesn't have those for half and
7581 // double types.
7582 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
7583 emit_binary_func_op(result_type, id, args[0], args[1], "min");
7584 else
7585 emit_binary_func_op(result_type, id, args[0], args[1], "fast::min");
7586 break;
7587
7588 case GLSLstd450FMax:
7589 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
7590 emit_binary_func_op(result_type, id, args[0], args[1], "max");
7591 else
7592 emit_binary_func_op(result_type, id, args[0], args[1], "fast::max");
7593 break;
7594
7595 case GLSLstd450FClamp:
7596 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
7597 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
7598 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
7599 else
7600 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "fast::clamp");
7601 break;
7602
7603 case GLSLstd450NMin:
7604 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
7605 emit_binary_func_op(result_type, id, args[0], args[1], "min");
7606 else
7607 emit_binary_func_op(result_type, id, args[0], args[1], "precise::min");
7608 break;
7609
7610 case GLSLstd450NMax:
7611 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
7612 emit_binary_func_op(result_type, id, args[0], args[1], "max");
7613 else
7614 emit_binary_func_op(result_type, id, args[0], args[1], "precise::max");
7615 break;
7616
7617 case GLSLstd450NClamp:
7618 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
7619 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
7620 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
7621 else
7622 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "precise::clamp");
7623 break;
7624
7625 // TODO:
7626 // GLSLstd450InterpolateAtCentroid (centroid_no_perspective qualifier)
7627 // GLSLstd450InterpolateAtSample (sample_no_perspective qualifier)
7628 // GLSLstd450InterpolateAtOffset
7629
7630 case GLSLstd450Distance:
7631 // MSL does not support scalar versions here.
7632 if (expression_type(args[0]).vecsize == 1)
7633 {
7634 // Equivalent to length(a - b) -> abs(a - b).
7635 emit_op(result_type, id,
7636 join("abs(", to_enclosed_unpacked_expression(args[0]), " - ",
7637 to_enclosed_unpacked_expression(args[1]), ")"),
7638 should_forward(args[0]) && should_forward(args[1]));
7639 inherit_expression_dependencies(id, args[0]);
7640 inherit_expression_dependencies(id, args[1]);
7641 }
7642 else
7643 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7644 break;
7645
7646 case GLSLstd450Length:
7647 // MSL does not support scalar versions here.
7648 if (expression_type(args[0]).vecsize == 1)
7649 {
7650 // Equivalent to abs().
7651 emit_unary_func_op(result_type, id, args[0], "abs");
7652 }
7653 else
7654 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7655 break;
7656
7657 case GLSLstd450Normalize:
7658 // MSL does not support scalar versions here.
7659 if (expression_type(args[0]).vecsize == 1)
7660 {
7661 // Returns -1 or 1 for valid input, sign() does the job.
7662 emit_unary_func_op(result_type, id, args[0], "sign");
7663 }
7664 else
7665 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7666 break;
7667
7668 case GLSLstd450Reflect:
7669 if (get<SPIRType>(result_type).vecsize == 1)
7670 emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
7671 else
7672 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7673 break;
7674
7675 case GLSLstd450Refract:
7676 if (get<SPIRType>(result_type).vecsize == 1)
7677 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
7678 else
7679 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7680 break;
7681
7682 case GLSLstd450FaceForward:
7683 if (get<SPIRType>(result_type).vecsize == 1)
7684 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvFaceForward");
7685 else
7686 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7687 break;
7688
7689 case GLSLstd450Modf:
7690 case GLSLstd450Frexp:
7691 {
7692 // Special case. If the variable is a scalar access chain, we cannot use it directly. We have to emit a temporary.
7693 auto *ptr = maybe_get<SPIRExpression>(args[1]);
7694 if (ptr && ptr->access_chain && is_scalar(expression_type(args[1])))
7695 {
7696 register_call_out_argument(args[1]);
7697 forced_temporaries.insert(id);
7698
7699 // Need to create temporaries and copy over to access chain after.
7700 // We cannot directly take the reference of a vector swizzle in MSL, even if it's scalar ...
7701 uint32_t &tmp_id = extra_sub_expressions[id];
7702 if (!tmp_id)
7703 tmp_id = ir.increase_bound_by(1);
7704
7705 uint32_t tmp_type_id = get_pointee_type_id(ptr->expression_type);
7706 emit_uninitialized_temporary_expression(tmp_type_id, tmp_id);
7707 emit_binary_func_op(result_type, id, args[0], tmp_id, eop == GLSLstd450Modf ? "modf" : "frexp");
7708 statement(to_expression(args[1]), " = ", to_expression(tmp_id), ";");
7709 }
7710 else
7711 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7712 break;
7713 }
7714
7715 default:
7716 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7717 break;
7718 }
7719 }
7720
emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type,uint32_t id,uint32_t eop,const uint32_t * args,uint32_t count)7721 void CompilerMSL::emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type, uint32_t id, uint32_t eop,
7722 const uint32_t *args, uint32_t count)
7723 {
7724 enum AMDShaderTrinaryMinMax
7725 {
7726 FMin3AMD = 1,
7727 UMin3AMD = 2,
7728 SMin3AMD = 3,
7729 FMax3AMD = 4,
7730 UMax3AMD = 5,
7731 SMax3AMD = 6,
7732 FMid3AMD = 7,
7733 UMid3AMD = 8,
7734 SMid3AMD = 9
7735 };
7736
7737 if (!msl_options.supports_msl_version(2, 1))
7738 SPIRV_CROSS_THROW("Trinary min/max functions require MSL 2.1.");
7739
7740 auto op = static_cast<AMDShaderTrinaryMinMax>(eop);
7741
7742 switch (op)
7743 {
7744 case FMid3AMD:
7745 case UMid3AMD:
7746 case SMid3AMD:
7747 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "median3");
7748 break;
7749 default:
7750 CompilerGLSL::emit_spv_amd_shader_trinary_minmax_op(result_type, id, eop, args, count);
7751 break;
7752 }
7753 }
7754
7755 // Emit a structure declaration for the specified interface variable.
emit_interface_block(uint32_t ib_var_id)7756 void CompilerMSL::emit_interface_block(uint32_t ib_var_id)
7757 {
7758 if (ib_var_id)
7759 {
7760 auto &ib_var = get<SPIRVariable>(ib_var_id);
7761 auto &ib_type = get_variable_data_type(ib_var);
7762 assert(ib_type.basetype == SPIRType::Struct && !ib_type.member_types.empty());
7763 emit_struct(ib_type);
7764 }
7765 }
7766
7767 // Emits the declaration signature of the specified function.
7768 // If this is the entry point function, Metal-specific return value and function arguments are added.
emit_function_prototype(SPIRFunction & func,const Bitset &)7769 void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &)
7770 {
7771 if (func.self != ir.default_entry_point)
7772 add_function_overload(func);
7773
7774 local_variable_names = resource_names;
7775 string decl;
7776
7777 processing_entry_point = func.self == ir.default_entry_point;
7778
7779 // Metal helper functions must be static force-inline otherwise they will cause problems when linked together in a single Metallib.
7780 if (!processing_entry_point)
7781 statement(force_inline);
7782
7783 auto &type = get<SPIRType>(func.return_type);
7784
7785 if (!type.array.empty() && msl_options.force_native_arrays)
7786 {
7787 // We cannot return native arrays in MSL, so "return" through an out variable.
7788 decl += "void";
7789 }
7790 else
7791 {
7792 decl += func_type_decl(type);
7793 }
7794
7795 decl += " ";
7796 decl += to_name(func.self);
7797 decl += "(";
7798
7799 if (!type.array.empty() && msl_options.force_native_arrays)
7800 {
7801 // Fake arrays returns by writing to an out array instead.
7802 decl += "thread ";
7803 decl += type_to_glsl(type);
7804 decl += " (&SPIRV_Cross_return_value)";
7805 decl += type_to_array_glsl(type);
7806 if (!func.arguments.empty())
7807 decl += ", ";
7808 }
7809
7810 if (processing_entry_point)
7811 {
7812 if (msl_options.argument_buffers)
7813 decl += entry_point_args_argument_buffer(!func.arguments.empty());
7814 else
7815 decl += entry_point_args_classic(!func.arguments.empty());
7816
7817 // If entry point function has variables that require early declaration,
7818 // ensure they each have an empty initializer, creating one if needed.
7819 // This is done at this late stage because the initialization expression
7820 // is cleared after each compilation pass.
7821 for (auto var_id : vars_needing_early_declaration)
7822 {
7823 auto &ed_var = get<SPIRVariable>(var_id);
7824 ID &initializer = ed_var.initializer;
7825 if (!initializer)
7826 initializer = ir.increase_bound_by(1);
7827
7828 // Do not override proper initializers.
7829 if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression)
7830 set<SPIRExpression>(ed_var.initializer, "{}", ed_var.basetype, true);
7831 }
7832 }
7833
7834 for (auto &arg : func.arguments)
7835 {
7836 uint32_t name_id = arg.id;
7837
7838 auto *var = maybe_get<SPIRVariable>(arg.id);
7839 if (var)
7840 {
7841 // If we need to modify the name of the variable, make sure we modify the original variable.
7842 // Our alias is just a shadow variable.
7843 if (arg.alias_global_variable && var->basevariable)
7844 name_id = var->basevariable;
7845
7846 var->parameter = &arg; // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
7847 }
7848
7849 add_local_variable_name(name_id);
7850
7851 decl += argument_decl(arg);
7852
7853 bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
7854
7855 auto &arg_type = get<SPIRType>(arg.type);
7856 if (arg_type.basetype == SPIRType::SampledImage && !is_dynamic_img_sampler)
7857 {
7858 // Manufacture automatic plane args for multiplanar texture
7859 uint32_t planes = 1;
7860 if (auto *constexpr_sampler = find_constexpr_sampler(name_id))
7861 if (constexpr_sampler->ycbcr_conversion_enable)
7862 planes = constexpr_sampler->planes;
7863 for (uint32_t i = 1; i < planes; i++)
7864 decl += join(", ", argument_decl(arg), plane_name_suffix, i);
7865
7866 // Manufacture automatic sampler arg for SampledImage texture
7867 if (arg_type.image.dim != DimBuffer)
7868 decl += join(", thread const ", sampler_type(arg_type), " ", to_sampler_expression(arg.id));
7869 }
7870
7871 // Manufacture automatic swizzle arg.
7872 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(arg_type) &&
7873 !is_dynamic_img_sampler)
7874 {
7875 bool arg_is_array = !arg_type.array.empty();
7876 decl += join(", constant uint", arg_is_array ? "* " : "& ", to_swizzle_expression(arg.id));
7877 }
7878
7879 if (buffers_requiring_array_length.count(name_id))
7880 {
7881 bool arg_is_array = !arg_type.array.empty();
7882 decl += join(", constant uint", arg_is_array ? "* " : "& ", to_buffer_size_expression(name_id));
7883 }
7884
7885 if (&arg != &func.arguments.back())
7886 decl += ", ";
7887 }
7888
7889 decl += ")";
7890 statement(decl);
7891 }
7892
needs_chroma_reconstruction(const MSLConstexprSampler * constexpr_sampler)7893 static bool needs_chroma_reconstruction(const MSLConstexprSampler *constexpr_sampler)
7894 {
7895 // For now, only multiplanar images need explicit reconstruction. GBGR and BGRG images
7896 // use implicit reconstruction.
7897 return constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && constexpr_sampler->planes > 1;
7898 }
7899
7900 // Returns the texture sampling function string for the specified image and sampling characteristics.
to_function_name(const TextureFunctionNameArguments & args)7901 string CompilerMSL::to_function_name(const TextureFunctionNameArguments &args)
7902 {
7903 VariableID img = args.base.img;
7904 auto &imgtype = *args.base.imgtype;
7905
7906 const MSLConstexprSampler *constexpr_sampler = nullptr;
7907 bool is_dynamic_img_sampler = false;
7908 if (auto *var = maybe_get_backing_variable(img))
7909 {
7910 constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
7911 is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
7912 }
7913
7914 // Special-case gather. We have to alter the component being looked up
7915 // in the swizzle case.
7916 if (msl_options.swizzle_texture_samples && args.base.is_gather && !is_dynamic_img_sampler &&
7917 (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable))
7918 {
7919 add_spv_func_and_recompile(imgtype.image.depth ? SPVFuncImplGatherCompareSwizzle : SPVFuncImplGatherSwizzle);
7920 return imgtype.image.depth ? "spvGatherCompareSwizzle" : "spvGatherSwizzle";
7921 }
7922
7923 auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
7924
7925 // Texture reference
7926 string fname;
7927 if (needs_chroma_reconstruction(constexpr_sampler) && !is_dynamic_img_sampler)
7928 {
7929 if (constexpr_sampler->planes != 2 && constexpr_sampler->planes != 3)
7930 SPIRV_CROSS_THROW("Unhandled number of color image planes!");
7931 // 444 images aren't downsampled, so we don't need to do linear filtering.
7932 if (constexpr_sampler->resolution == MSL_FORMAT_RESOLUTION_444 ||
7933 constexpr_sampler->chroma_filter == MSL_SAMPLER_FILTER_NEAREST)
7934 {
7935 if (constexpr_sampler->planes == 2)
7936 add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest2Plane);
7937 else
7938 add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest3Plane);
7939 fname = "spvChromaReconstructNearest";
7940 }
7941 else // Linear with a downsampled format
7942 {
7943 fname = "spvChromaReconstructLinear";
7944 switch (constexpr_sampler->resolution)
7945 {
7946 case MSL_FORMAT_RESOLUTION_444:
7947 assert(false);
7948 break; // not reached
7949 case MSL_FORMAT_RESOLUTION_422:
7950 switch (constexpr_sampler->x_chroma_offset)
7951 {
7952 case MSL_CHROMA_LOCATION_COSITED_EVEN:
7953 if (constexpr_sampler->planes == 2)
7954 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven2Plane);
7955 else
7956 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven3Plane);
7957 fname += "422CositedEven";
7958 break;
7959 case MSL_CHROMA_LOCATION_MIDPOINT:
7960 if (constexpr_sampler->planes == 2)
7961 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint2Plane);
7962 else
7963 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint3Plane);
7964 fname += "422Midpoint";
7965 break;
7966 default:
7967 SPIRV_CROSS_THROW("Invalid chroma location.");
7968 }
7969 break;
7970 case MSL_FORMAT_RESOLUTION_420:
7971 fname += "420";
7972 switch (constexpr_sampler->x_chroma_offset)
7973 {
7974 case MSL_CHROMA_LOCATION_COSITED_EVEN:
7975 switch (constexpr_sampler->y_chroma_offset)
7976 {
7977 case MSL_CHROMA_LOCATION_COSITED_EVEN:
7978 if (constexpr_sampler->planes == 2)
7979 add_spv_func_and_recompile(
7980 SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane);
7981 else
7982 add_spv_func_and_recompile(
7983 SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane);
7984 fname += "XCositedEvenYCositedEven";
7985 break;
7986 case MSL_CHROMA_LOCATION_MIDPOINT:
7987 if (constexpr_sampler->planes == 2)
7988 add_spv_func_and_recompile(
7989 SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane);
7990 else
7991 add_spv_func_and_recompile(
7992 SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane);
7993 fname += "XCositedEvenYMidpoint";
7994 break;
7995 default:
7996 SPIRV_CROSS_THROW("Invalid Y chroma location.");
7997 }
7998 break;
7999 case MSL_CHROMA_LOCATION_MIDPOINT:
8000 switch (constexpr_sampler->y_chroma_offset)
8001 {
8002 case MSL_CHROMA_LOCATION_COSITED_EVEN:
8003 if (constexpr_sampler->planes == 2)
8004 add_spv_func_and_recompile(
8005 SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane);
8006 else
8007 add_spv_func_and_recompile(
8008 SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane);
8009 fname += "XMidpointYCositedEven";
8010 break;
8011 case MSL_CHROMA_LOCATION_MIDPOINT:
8012 if (constexpr_sampler->planes == 2)
8013 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane);
8014 else
8015 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane);
8016 fname += "XMidpointYMidpoint";
8017 break;
8018 default:
8019 SPIRV_CROSS_THROW("Invalid Y chroma location.");
8020 }
8021 break;
8022 default:
8023 SPIRV_CROSS_THROW("Invalid X chroma location.");
8024 }
8025 break;
8026 default:
8027 SPIRV_CROSS_THROW("Invalid format resolution.");
8028 }
8029 }
8030 }
8031 else
8032 {
8033 fname = to_expression(combined ? combined->image : img) + ".";
8034
8035 // Texture function and sampler
8036 if (args.base.is_fetch)
8037 fname += "read";
8038 else if (args.base.is_gather)
8039 fname += "gather";
8040 else
8041 fname += "sample";
8042
8043 if (args.has_dref)
8044 fname += "_compare";
8045 }
8046
8047 return fname;
8048 }
8049
// Wraps expr in a constructor call for a SPIRType::Float type with `components` vector
// elements and a single column, e.g. to widen a coordinate before sampling.
string CompilerMSL::convert_to_f32(const string &expr, uint32_t components)
{
	SPIRType f32_type;
	f32_type.basetype = SPIRType::Float;
	f32_type.vecsize = components;
	f32_type.columns = 1;

	// <constructor>(<expr>)
	string converted = type_to_glsl_constructor(f32_type);
	converted += "(";
	converted += expr;
	converted += ")";
	return converted;
}
8058
sampling_type_needs_f32_conversion(const SPIRType & type)8059 static inline bool sampling_type_needs_f32_conversion(const SPIRType &type)
8060 {
8061 // Double is not supported to begin with, but doesn't hurt to check for completion.
8062 return type.basetype == SPIRType::Half || type.basetype == SPIRType::Double;
8063 }
8064
8065 // Returns the function args for a texture sampling function for the specified image and sampling characteristics.
to_function_args(const TextureFunctionArguments & args,bool * p_forward)8066 string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool *p_forward)
8067 {
8068 VariableID img = args.base.img;
8069 auto &imgtype = *args.base.imgtype;
8070 uint32_t lod = args.lod;
8071 uint32_t grad_x = args.grad_x;
8072 uint32_t grad_y = args.grad_y;
8073 uint32_t bias = args.bias;
8074
8075 const MSLConstexprSampler *constexpr_sampler = nullptr;
8076 bool is_dynamic_img_sampler = false;
8077 if (auto *var = maybe_get_backing_variable(img))
8078 {
8079 constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
8080 is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
8081 }
8082
8083 string farg_str;
8084 bool forward = true;
8085
8086 if (!is_dynamic_img_sampler)
8087 {
8088 // Texture reference (for some cases)
8089 if (needs_chroma_reconstruction(constexpr_sampler))
8090 {
8091 // Multiplanar images need two or three textures.
8092 farg_str += to_expression(img);
8093 for (uint32_t i = 1; i < constexpr_sampler->planes; i++)
8094 farg_str += join(", ", to_expression(img), plane_name_suffix, i);
8095 }
8096 else if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
8097 msl_options.swizzle_texture_samples && args.base.is_gather)
8098 {
8099 auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
8100 farg_str += to_expression(combined ? combined->image : img);
8101 }
8102
8103 // Sampler reference
8104 if (!args.base.is_fetch)
8105 {
8106 if (!farg_str.empty())
8107 farg_str += ", ";
8108 farg_str += to_sampler_expression(img);
8109 }
8110
8111 if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
8112 msl_options.swizzle_texture_samples && args.base.is_gather)
8113 {
8114 // Add the swizzle constant from the swizzle buffer.
8115 farg_str += ", " + to_swizzle_expression(img);
8116 used_swizzle_buffer = true;
8117 }
8118
8119 // Swizzled gather puts the component before the other args, to allow template
8120 // deduction to work.
8121 if (args.component && msl_options.swizzle_texture_samples)
8122 {
8123 forward = should_forward(args.component);
8124 farg_str += ", " + to_component_argument(args.component);
8125 }
8126 }
8127
8128 // Texture coordinates
8129 forward = forward && should_forward(args.coord);
8130 auto coord_expr = to_enclosed_expression(args.coord);
8131 auto &coord_type = expression_type(args.coord);
8132 bool coord_is_fp = type_is_floating_point(coord_type);
8133 bool is_cube_fetch = false;
8134
8135 string tex_coords = coord_expr;
8136 uint32_t alt_coord_component = 0;
8137
8138 switch (imgtype.image.dim)
8139 {
8140
8141 case Dim1D:
8142 if (coord_type.vecsize > 1)
8143 tex_coords = enclose_expression(tex_coords) + ".x";
8144
8145 if (args.base.is_fetch)
8146 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
8147 else if (sampling_type_needs_f32_conversion(coord_type))
8148 tex_coords = convert_to_f32(tex_coords, 1);
8149
8150 if (msl_options.texture_1D_as_2D)
8151 {
8152 if (args.base.is_fetch)
8153 tex_coords = "uint2(" + tex_coords + ", 0)";
8154 else
8155 tex_coords = "float2(" + tex_coords + ", 0.5)";
8156 }
8157
8158 alt_coord_component = 1;
8159 break;
8160
8161 case DimBuffer:
8162 if (coord_type.vecsize > 1)
8163 tex_coords = enclose_expression(tex_coords) + ".x";
8164
8165 if (msl_options.texture_buffer_native)
8166 {
8167 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
8168 }
8169 else
8170 {
8171 // Metal texel buffer textures are 2D, so convert 1D coord to 2D.
8172 // Support for Metal 2.1's new texture_buffer type.
8173 if (args.base.is_fetch)
8174 {
8175 if (msl_options.texel_buffer_texture_width > 0)
8176 {
8177 tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
8178 }
8179 else
8180 {
8181 tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ", " +
8182 to_expression(img) + ")";
8183 }
8184 }
8185 }
8186
8187 alt_coord_component = 1;
8188 break;
8189
8190 case DimSubpassData:
8191 // If we're using Metal's native frame-buffer fetch API for subpass inputs,
8192 // this path will not be hit.
8193 tex_coords = "uint2(gl_FragCoord.xy)";
8194 alt_coord_component = 2;
8195 break;
8196
8197 case Dim2D:
8198 if (coord_type.vecsize > 2)
8199 tex_coords = enclose_expression(tex_coords) + ".xy";
8200
8201 if (args.base.is_fetch)
8202 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
8203 else if (sampling_type_needs_f32_conversion(coord_type))
8204 tex_coords = convert_to_f32(tex_coords, 2);
8205
8206 alt_coord_component = 2;
8207 break;
8208
8209 case Dim3D:
8210 if (coord_type.vecsize > 3)
8211 tex_coords = enclose_expression(tex_coords) + ".xyz";
8212
8213 if (args.base.is_fetch)
8214 tex_coords = "uint3(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
8215 else if (sampling_type_needs_f32_conversion(coord_type))
8216 tex_coords = convert_to_f32(tex_coords, 3);
8217
8218 alt_coord_component = 3;
8219 break;
8220
8221 case DimCube:
8222 if (args.base.is_fetch)
8223 {
8224 is_cube_fetch = true;
8225 tex_coords += ".xy";
8226 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
8227 }
8228 else
8229 {
8230 if (coord_type.vecsize > 3)
8231 tex_coords = enclose_expression(tex_coords) + ".xyz";
8232 }
8233
8234 if (sampling_type_needs_f32_conversion(coord_type))
8235 tex_coords = convert_to_f32(tex_coords, 3);
8236
8237 alt_coord_component = 3;
8238 break;
8239
8240 default:
8241 break;
8242 }
8243
8244 if (args.base.is_fetch && args.offset)
8245 {
8246 // Fetch offsets must be applied directly to the coordinate.
8247 forward = forward && should_forward(args.offset);
8248 auto &type = expression_type(args.offset);
8249 if (type.basetype != SPIRType::UInt)
8250 tex_coords += " + " + bitcast_expression(SPIRType::UInt, args.offset);
8251 else
8252 tex_coords += " + " + to_enclosed_expression(args.offset);
8253 }
8254 else if (args.base.is_fetch && args.coffset)
8255 {
8256 // Fetch offsets must be applied directly to the coordinate.
8257 forward = forward && should_forward(args.coffset);
8258 auto &type = expression_type(args.coffset);
8259 if (type.basetype != SPIRType::UInt)
8260 tex_coords += " + " + bitcast_expression(SPIRType::UInt, args.coffset);
8261 else
8262 tex_coords += " + " + to_enclosed_expression(args.coffset);
8263 }
8264
8265 // If projection, use alt coord as divisor
8266 if (args.base.is_proj)
8267 {
8268 if (sampling_type_needs_f32_conversion(coord_type))
8269 tex_coords += " / " + convert_to_f32(to_extract_component_expression(args.coord, alt_coord_component), 1);
8270 else
8271 tex_coords += " / " + to_extract_component_expression(args.coord, alt_coord_component);
8272 }
8273
8274 if (!farg_str.empty())
8275 farg_str += ", ";
8276
8277 if (imgtype.image.dim == DimCube && imgtype.image.arrayed && msl_options.emulate_cube_array)
8278 {
8279 farg_str += "spvCubemapTo2DArrayFace(" + tex_coords + ").xy";
8280
8281 if (is_cube_fetch)
8282 farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ")";
8283 else
8284 farg_str +=
8285 ", uint(spvCubemapTo2DArrayFace(" + tex_coords + ").z) + (uint(" +
8286 round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
8287 ") * 6u)";
8288
8289 add_spv_func_and_recompile(SPVFuncImplCubemapTo2DArrayFace);
8290 }
8291 else
8292 {
8293 farg_str += tex_coords;
8294
8295 // If fetch from cube, add face explicitly
8296 if (is_cube_fetch)
8297 {
8298 // Special case for cube arrays, face and layer are packed in one dimension.
8299 if (imgtype.image.arrayed)
8300 farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") % 6u";
8301 else
8302 farg_str +=
8303 ", uint(" + round_fp_tex_coords(to_extract_component_expression(args.coord, 2), coord_is_fp) + ")";
8304 }
8305
8306 // If array, use alt coord
8307 if (imgtype.image.arrayed)
8308 {
8309 // Special case for cube arrays, face and layer are packed in one dimension.
8310 if (imgtype.image.dim == DimCube && args.base.is_fetch)
8311 {
8312 farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") / 6u";
8313 }
8314 else
8315 {
8316 farg_str +=
8317 ", uint(" +
8318 round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
8319 ")";
8320 if (imgtype.image.dim == DimSubpassData)
8321 {
8322 if (msl_options.multiview)
8323 farg_str += " + gl_ViewIndex";
8324 else if (msl_options.arrayed_subpass_input)
8325 farg_str += " + gl_Layer";
8326 }
8327 }
8328 }
8329 else if (imgtype.image.dim == DimSubpassData)
8330 {
8331 if (msl_options.multiview)
8332 farg_str += ", gl_ViewIndex";
8333 else if (msl_options.arrayed_subpass_input)
8334 farg_str += ", gl_Layer";
8335 }
8336 }
8337
8338 // Depth compare reference value
8339 if (args.dref)
8340 {
8341 forward = forward && should_forward(args.dref);
8342 farg_str += ", ";
8343
8344 auto &dref_type = expression_type(args.dref);
8345
8346 string dref_expr;
8347 if (args.base.is_proj)
8348 dref_expr = join(to_enclosed_expression(args.dref), " / ",
8349 to_extract_component_expression(args.coord, alt_coord_component));
8350 else
8351 dref_expr = to_expression(args.dref);
8352
8353 if (sampling_type_needs_f32_conversion(dref_type))
8354 dref_expr = convert_to_f32(dref_expr, 1);
8355
8356 farg_str += dref_expr;
8357
8358 if (msl_options.is_macos() && (grad_x || grad_y))
8359 {
8360 // For sample compare, MSL does not support gradient2d for all targets (only iOS apparently according to docs).
8361 // However, the most common case here is to have a constant gradient of 0, as that is the only way to express
8362 // LOD == 0 in GLSL with sampler2DArrayShadow (cascaded shadow mapping).
8363 // We will detect a compile-time constant 0 value for gradient and promote that to level(0) on MSL.
8364 bool constant_zero_x = !grad_x || expression_is_constant_null(grad_x);
8365 bool constant_zero_y = !grad_y || expression_is_constant_null(grad_y);
8366 if (constant_zero_x && constant_zero_y)
8367 {
8368 lod = 0;
8369 grad_x = 0;
8370 grad_y = 0;
8371 farg_str += ", level(0)";
8372 }
8373 else
8374 {
8375 SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not "
8376 "supported in MSL macOS.");
8377 }
8378 }
8379
8380 if (msl_options.is_macos() && bias)
8381 {
8382 // Bias is not supported either on macOS with sample_compare.
8383 // Verify it is compile-time zero, and drop the argument.
8384 if (expression_is_constant_null(bias))
8385 {
8386 bias = 0;
8387 }
8388 else
8389 {
8390 SPIRV_CROSS_THROW(
8391 "Using non-constant 0.0 bias() qualifier for sample_compare. This is not supported in MSL macOS.");
8392 }
8393 }
8394 }
8395
8396 // LOD Options
8397 // Metal does not support LOD for 1D textures.
8398 if (bias && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
8399 {
8400 forward = forward && should_forward(bias);
8401 farg_str += ", bias(" + to_expression(bias) + ")";
8402 }
8403
8404 // Metal does not support LOD for 1D textures.
8405 if (lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
8406 {
8407 forward = forward && should_forward(lod);
8408 if (args.base.is_fetch)
8409 {
8410 farg_str += ", " + to_expression(lod);
8411 }
8412 else
8413 {
8414 farg_str += ", level(" + to_expression(lod) + ")";
8415 }
8416 }
8417 else if (args.base.is_fetch && !lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D) &&
8418 imgtype.image.dim != DimBuffer && !imgtype.image.ms && imgtype.image.sampled != 2)
8419 {
8420 // Lod argument is optional in OpImageFetch, but we require a LOD value, pick 0 as the default.
8421 // Check for sampled type as well, because is_fetch is also used for OpImageRead in MSL.
8422 farg_str += ", 0";
8423 }
8424
8425 // Metal does not support LOD for 1D textures.
8426 if ((grad_x || grad_y) && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
8427 {
8428 forward = forward && should_forward(grad_x);
8429 forward = forward && should_forward(grad_y);
8430 string grad_opt;
8431 switch (imgtype.image.dim)
8432 {
8433 case Dim2D:
8434 grad_opt = "2d";
8435 break;
8436 case Dim3D:
8437 grad_opt = "3d";
8438 break;
8439 case DimCube:
8440 if (imgtype.image.arrayed && msl_options.emulate_cube_array)
8441 grad_opt = "2d";
8442 else
8443 grad_opt = "cube";
8444 break;
8445 default:
8446 grad_opt = "unsupported_gradient_dimension";
8447 break;
8448 }
8449 farg_str += ", gradient" + grad_opt + "(" + to_expression(grad_x) + ", " + to_expression(grad_y) + ")";
8450 }
8451
8452 if (args.min_lod)
8453 {
8454 if (msl_options.is_macos())
8455 {
8456 if (!msl_options.supports_msl_version(2, 2))
8457 SPIRV_CROSS_THROW("min_lod_clamp() is only supported in MSL 2.2+ and up on macOS.");
8458 }
8459 else if (msl_options.is_ios())
8460 SPIRV_CROSS_THROW("min_lod_clamp() is not supported on iOS.");
8461
8462 forward = forward && should_forward(args.min_lod);
8463 farg_str += ", min_lod_clamp(" + to_expression(args.min_lod) + ")";
8464 }
8465
8466 // Add offsets
8467 string offset_expr;
8468 if (args.coffset && !args.base.is_fetch)
8469 {
8470 forward = forward && should_forward(args.coffset);
8471 offset_expr = to_expression(args.coffset);
8472 }
8473 else if (args.offset && !args.base.is_fetch)
8474 {
8475 forward = forward && should_forward(args.offset);
8476 offset_expr = to_expression(args.offset);
8477 }
8478
8479 if (!offset_expr.empty())
8480 {
8481 switch (imgtype.image.dim)
8482 {
8483 case Dim2D:
8484 if (coord_type.vecsize > 2)
8485 offset_expr = enclose_expression(offset_expr) + ".xy";
8486
8487 farg_str += ", " + offset_expr;
8488 break;
8489
8490 case Dim3D:
8491 if (coord_type.vecsize > 3)
8492 offset_expr = enclose_expression(offset_expr) + ".xyz";
8493
8494 farg_str += ", " + offset_expr;
8495 break;
8496
8497 default:
8498 break;
8499 }
8500 }
8501
8502 if (args.component)
8503 {
8504 // If 2D has gather component, ensure it also has an offset arg
8505 if (imgtype.image.dim == Dim2D && offset_expr.empty())
8506 farg_str += ", int2(0)";
8507
8508 if (!msl_options.swizzle_texture_samples || is_dynamic_img_sampler)
8509 {
8510 forward = forward && should_forward(args.component);
8511 farg_str += ", " + to_component_argument(args.component);
8512 }
8513 }
8514
8515 if (args.sample)
8516 {
8517 forward = forward && should_forward(args.sample);
8518 farg_str += ", ";
8519 farg_str += to_expression(args.sample);
8520 }
8521
8522 *p_forward = forward;
8523
8524 return farg_str;
8525 }
8526
8527 // If the texture coordinates are floating point, invokes MSL round() function to round them.
round_fp_tex_coords(string tex_coords,bool coord_is_fp)8528 string CompilerMSL::round_fp_tex_coords(string tex_coords, bool coord_is_fp)
8529 {
8530 return coord_is_fp ? ("round(" + tex_coords + ")") : tex_coords;
8531 }
8532
8533 // Returns a string to use in an image sampling function argument.
8534 // The ID must be a scalar constant.
to_component_argument(uint32_t id)8535 string CompilerMSL::to_component_argument(uint32_t id)
8536 {
8537 uint32_t component_index = evaluate_constant_u32(id);
8538 switch (component_index)
8539 {
8540 case 0:
8541 return "component::x";
8542 case 1:
8543 return "component::y";
8544 case 2:
8545 return "component::z";
8546 case 3:
8547 return "component::w";
8548
8549 default:
8550 SPIRV_CROSS_THROW("The value (" + to_string(component_index) + ") of OpConstant ID " + to_string(id) +
8551 " is not a valid Component index, which must be one of 0, 1, 2, or 3.");
8552 return "component::x";
8553 }
8554 }
8555
8556 // Establish sampled image as expression object and assign the sampler to it.
emit_sampled_image_op(uint32_t result_type,uint32_t result_id,uint32_t image_id,uint32_t samp_id)8557 void CompilerMSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
8558 {
8559 set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
8560 }
8561
to_texture_op(const Instruction & i,bool sparse,bool * forward,SmallVector<uint32_t> & inherited_expressions)8562 string CompilerMSL::to_texture_op(const Instruction &i, bool sparse, bool *forward,
8563 SmallVector<uint32_t> &inherited_expressions)
8564 {
8565 auto *ops = stream(i);
8566 uint32_t result_type_id = ops[0];
8567 uint32_t img = ops[2];
8568 auto &result_type = get<SPIRType>(result_type_id);
8569 auto op = static_cast<Op>(i.op);
8570 bool is_gather = (op == OpImageGather || op == OpImageDrefGather);
8571
8572 // Bypass pointers because we need the real image struct
8573 auto &type = expression_type(img);
8574 auto &imgtype = get<SPIRType>(type.self);
8575
8576 const MSLConstexprSampler *constexpr_sampler = nullptr;
8577 bool is_dynamic_img_sampler = false;
8578 if (auto *var = maybe_get_backing_variable(img))
8579 {
8580 constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
8581 is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
8582 }
8583
8584 string expr;
8585 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
8586 {
8587 // If this needs sampler Y'CbCr conversion, we need to do some additional
8588 // processing.
8589 switch (constexpr_sampler->ycbcr_model)
8590 {
8591 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
8592 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
8593 // Default
8594 break;
8595 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
8596 add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT709);
8597 expr += "spvConvertYCbCrBT709(";
8598 break;
8599 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
8600 add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT601);
8601 expr += "spvConvertYCbCrBT601(";
8602 break;
8603 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
8604 add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT2020);
8605 expr += "spvConvertYCbCrBT2020(";
8606 break;
8607 default:
8608 SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
8609 }
8610
8611 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
8612 {
8613 switch (constexpr_sampler->ycbcr_range)
8614 {
8615 case MSL_SAMPLER_YCBCR_RANGE_ITU_FULL:
8616 add_spv_func_and_recompile(SPVFuncImplExpandITUFullRange);
8617 expr += "spvExpandITUFullRange(";
8618 break;
8619 case MSL_SAMPLER_YCBCR_RANGE_ITU_NARROW:
8620 add_spv_func_and_recompile(SPVFuncImplExpandITUNarrowRange);
8621 expr += "spvExpandITUNarrowRange(";
8622 break;
8623 default:
8624 SPIRV_CROSS_THROW("Invalid Y'CbCr range.");
8625 }
8626 }
8627 }
8628 else if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
8629 !is_dynamic_img_sampler)
8630 {
8631 add_spv_func_and_recompile(SPVFuncImplTextureSwizzle);
8632 expr += "spvTextureSwizzle(";
8633 }
8634
8635 string inner_expr = CompilerGLSL::to_texture_op(i, sparse, forward, inherited_expressions);
8636
8637 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
8638 {
8639 if (!constexpr_sampler->swizzle_is_identity())
8640 {
8641 static const char swizzle_names[] = "rgba";
8642 if (!constexpr_sampler->swizzle_has_one_or_zero())
8643 {
8644 // If we can, do it inline.
8645 expr += inner_expr + ".";
8646 for (uint32_t c = 0; c < 4; c++)
8647 {
8648 switch (constexpr_sampler->swizzle[c])
8649 {
8650 case MSL_COMPONENT_SWIZZLE_IDENTITY:
8651 expr += swizzle_names[c];
8652 break;
8653 case MSL_COMPONENT_SWIZZLE_R:
8654 case MSL_COMPONENT_SWIZZLE_G:
8655 case MSL_COMPONENT_SWIZZLE_B:
8656 case MSL_COMPONENT_SWIZZLE_A:
8657 expr += swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
8658 break;
8659 default:
8660 SPIRV_CROSS_THROW("Invalid component swizzle.");
8661 }
8662 }
8663 }
8664 else
8665 {
8666 // Otherwise, we need to emit a temporary and swizzle that.
8667 uint32_t temp_id = ir.increase_bound_by(1);
8668 emit_op(result_type_id, temp_id, inner_expr, false);
8669 for (auto &inherit : inherited_expressions)
8670 inherit_expression_dependencies(temp_id, inherit);
8671 inherited_expressions.clear();
8672 inherited_expressions.push_back(temp_id);
8673
8674 switch (op)
8675 {
8676 case OpImageSampleDrefImplicitLod:
8677 case OpImageSampleImplicitLod:
8678 case OpImageSampleProjImplicitLod:
8679 case OpImageSampleProjDrefImplicitLod:
8680 register_control_dependent_expression(temp_id);
8681 break;
8682
8683 default:
8684 break;
8685 }
8686 expr += type_to_glsl(result_type) + "(";
8687 for (uint32_t c = 0; c < 4; c++)
8688 {
8689 switch (constexpr_sampler->swizzle[c])
8690 {
8691 case MSL_COMPONENT_SWIZZLE_IDENTITY:
8692 expr += to_expression(temp_id) + "." + swizzle_names[c];
8693 break;
8694 case MSL_COMPONENT_SWIZZLE_ZERO:
8695 expr += "0";
8696 break;
8697 case MSL_COMPONENT_SWIZZLE_ONE:
8698 expr += "1";
8699 break;
8700 case MSL_COMPONENT_SWIZZLE_R:
8701 case MSL_COMPONENT_SWIZZLE_G:
8702 case MSL_COMPONENT_SWIZZLE_B:
8703 case MSL_COMPONENT_SWIZZLE_A:
8704 expr += to_expression(temp_id) + "." +
8705 swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
8706 break;
8707 default:
8708 SPIRV_CROSS_THROW("Invalid component swizzle.");
8709 }
8710 if (c < 3)
8711 expr += ", ";
8712 }
8713 expr += ")";
8714 }
8715 }
8716 else
8717 expr += inner_expr;
8718 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
8719 {
8720 expr += join(", ", constexpr_sampler->bpc, ")");
8721 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY)
8722 expr += ")";
8723 }
8724 }
8725 else
8726 {
8727 expr += inner_expr;
8728 if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
8729 !is_dynamic_img_sampler)
8730 {
8731 // Add the swizzle constant from the swizzle buffer.
8732 expr += ", " + to_swizzle_expression(img) + ")";
8733 used_swizzle_buffer = true;
8734 }
8735 }
8736
8737 return expr;
8738 }
8739
create_swizzle(MSLComponentSwizzle swizzle)8740 static string create_swizzle(MSLComponentSwizzle swizzle)
8741 {
8742 switch (swizzle)
8743 {
8744 case MSL_COMPONENT_SWIZZLE_IDENTITY:
8745 return "spvSwizzle::none";
8746 case MSL_COMPONENT_SWIZZLE_ZERO:
8747 return "spvSwizzle::zero";
8748 case MSL_COMPONENT_SWIZZLE_ONE:
8749 return "spvSwizzle::one";
8750 case MSL_COMPONENT_SWIZZLE_R:
8751 return "spvSwizzle::red";
8752 case MSL_COMPONENT_SWIZZLE_G:
8753 return "spvSwizzle::green";
8754 case MSL_COMPONENT_SWIZZLE_B:
8755 return "spvSwizzle::blue";
8756 case MSL_COMPONENT_SWIZZLE_A:
8757 return "spvSwizzle::alpha";
8758 default:
8759 SPIRV_CROSS_THROW("Invalid component swizzle.");
8760 return "";
8761 }
8762 }
8763
8764 // Returns a string representation of the ID, usable as a function arg.
8765 // Manufacture automatic sampler arg for SampledImage texture.
to_func_call_arg(const SPIRFunction::Parameter & arg,uint32_t id)8766 string CompilerMSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
8767 {
8768 string arg_str;
8769
8770 auto &type = expression_type(id);
8771 bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
8772 // If the argument *itself* is a "dynamic" combined-image sampler, then we can just pass that around.
8773 bool arg_is_dynamic_img_sampler = has_extended_decoration(id, SPIRVCrossDecorationDynamicImageSampler);
8774 if (is_dynamic_img_sampler && !arg_is_dynamic_img_sampler)
8775 arg_str = join("spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">(");
8776
8777 auto *c = maybe_get<SPIRConstant>(id);
8778 if (msl_options.force_native_arrays && c && !get<SPIRType>(c->constant_type).array.empty())
8779 {
8780 // If we are passing a constant array directly to a function for some reason,
8781 // the callee will expect an argument in thread const address space
8782 // (since we can only bind to arrays with references in MSL).
8783 // To resolve this, we must emit a copy in this address space.
8784 // This kind of code gen should be rare enough that performance is not a real concern.
8785 // Inline the SPIR-V to avoid this kind of suboptimal codegen.
8786 //
8787 // We risk calling this inside a continue block (invalid code),
8788 // so just create a thread local copy in the current function.
8789 arg_str = join("_", id, "_array_copy");
8790 auto &constants = current_function->constant_arrays_needed_on_stack;
8791 auto itr = find(begin(constants), end(constants), ID(id));
8792 if (itr == end(constants))
8793 {
8794 force_recompile();
8795 constants.push_back(id);
8796 }
8797 }
8798 else
8799 arg_str += CompilerGLSL::to_func_call_arg(arg, id);
8800
8801 // Need to check the base variable in case we need to apply a qualified alias.
8802 uint32_t var_id = 0;
8803 auto *var = maybe_get<SPIRVariable>(id);
8804 if (var)
8805 var_id = var->basevariable;
8806
8807 if (!arg_is_dynamic_img_sampler)
8808 {
8809 auto *constexpr_sampler = find_constexpr_sampler(var_id ? var_id : id);
8810 if (type.basetype == SPIRType::SampledImage)
8811 {
8812 // Manufacture automatic plane args for multiplanar texture
8813 uint32_t planes = 1;
8814 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
8815 {
8816 planes = constexpr_sampler->planes;
8817 // If this parameter isn't aliasing a global, then we need to use
8818 // the special "dynamic image-sampler" class to pass it--and we need
8819 // to use it for *every* non-alias parameter, in case a combined
8820 // image-sampler with a Y'CbCr conversion is passed. Hopefully, this
8821 // pathological case is so rare that it should never be hit in practice.
8822 if (!arg.alias_global_variable)
8823 add_spv_func_and_recompile(SPVFuncImplDynamicImageSampler);
8824 }
8825 for (uint32_t i = 1; i < planes; i++)
8826 arg_str += join(", ", CompilerGLSL::to_func_call_arg(arg, id), plane_name_suffix, i);
8827 // Manufacture automatic sampler arg if the arg is a SampledImage texture.
8828 if (type.image.dim != DimBuffer)
8829 arg_str += ", " + to_sampler_expression(var_id ? var_id : id);
8830
8831 // Add sampler Y'CbCr conversion info if we have it
8832 if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
8833 {
8834 SmallVector<string> samp_args;
8835
8836 switch (constexpr_sampler->resolution)
8837 {
8838 case MSL_FORMAT_RESOLUTION_444:
8839 // Default
8840 break;
8841 case MSL_FORMAT_RESOLUTION_422:
8842 samp_args.push_back("spvFormatResolution::_422");
8843 break;
8844 case MSL_FORMAT_RESOLUTION_420:
8845 samp_args.push_back("spvFormatResolution::_420");
8846 break;
8847 default:
8848 SPIRV_CROSS_THROW("Invalid format resolution.");
8849 }
8850
8851 if (constexpr_sampler->chroma_filter != MSL_SAMPLER_FILTER_NEAREST)
8852 samp_args.push_back("spvChromaFilter::linear");
8853
8854 if (constexpr_sampler->x_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
8855 samp_args.push_back("spvXChromaLocation::midpoint");
8856 if (constexpr_sampler->y_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
8857 samp_args.push_back("spvYChromaLocation::midpoint");
8858 switch (constexpr_sampler->ycbcr_model)
8859 {
8860 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
8861 // Default
8862 break;
8863 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
8864 samp_args.push_back("spvYCbCrModelConversion::ycbcr_identity");
8865 break;
8866 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
8867 samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_709");
8868 break;
8869 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
8870 samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_601");
8871 break;
8872 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
8873 samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_2020");
8874 break;
8875 default:
8876 SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
8877 }
8878 if (constexpr_sampler->ycbcr_range != MSL_SAMPLER_YCBCR_RANGE_ITU_FULL)
8879 samp_args.push_back("spvYCbCrRange::itu_narrow");
8880 samp_args.push_back(join("spvComponentBits(", constexpr_sampler->bpc, ")"));
8881 arg_str += join(", spvYCbCrSampler(", merge(samp_args), ")");
8882 }
8883 }
8884
8885 if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
8886 arg_str += join(", (uint(", create_swizzle(constexpr_sampler->swizzle[3]), ") << 24) | (uint(",
8887 create_swizzle(constexpr_sampler->swizzle[2]), ") << 16) | (uint(",
8888 create_swizzle(constexpr_sampler->swizzle[1]), ") << 8) | uint(",
8889 create_swizzle(constexpr_sampler->swizzle[0]), ")");
8890 else if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
8891 arg_str += ", " + to_swizzle_expression(var_id ? var_id : id);
8892
8893 if (buffers_requiring_array_length.count(var_id))
8894 arg_str += ", " + to_buffer_size_expression(var_id ? var_id : id);
8895
8896 if (is_dynamic_img_sampler)
8897 arg_str += ")";
8898 }
8899
8900 // Emulate texture2D atomic operations
8901 auto *backing_var = maybe_get_backing_variable(var_id);
8902 if (backing_var && atomic_image_vars.count(backing_var->self))
8903 {
8904 arg_str += ", " + to_expression(var_id) + "_atomic";
8905 }
8906
8907 return arg_str;
8908 }
8909
8910 // If the ID represents a sampled image that has been assigned a sampler already,
8911 // generate an expression for the sampler, otherwise generate a fake sampler name
8912 // by appending a suffix to the expression constructed from the ID.
to_sampler_expression(uint32_t id)8913 string CompilerMSL::to_sampler_expression(uint32_t id)
8914 {
8915 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
8916 auto expr = to_expression(combined ? combined->image : VariableID(id));
8917 auto index = expr.find_first_of('[');
8918
8919 uint32_t samp_id = 0;
8920 if (combined)
8921 samp_id = combined->sampler;
8922
8923 if (index == string::npos)
8924 return samp_id ? to_expression(samp_id) : expr + sampler_name_suffix;
8925 else
8926 {
8927 auto image_expr = expr.substr(0, index);
8928 auto array_expr = expr.substr(index);
8929 return samp_id ? to_expression(samp_id) : (image_expr + sampler_name_suffix + array_expr);
8930 }
8931 }
8932
to_swizzle_expression(uint32_t id)8933 string CompilerMSL::to_swizzle_expression(uint32_t id)
8934 {
8935 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
8936
8937 auto expr = to_expression(combined ? combined->image : VariableID(id));
8938 auto index = expr.find_first_of('[');
8939
8940 // If an image is part of an argument buffer translate this to a legal identifier.
8941 for (auto &c : expr)
8942 if (c == '.')
8943 c = '_';
8944
8945 if (index == string::npos)
8946 return expr + swizzle_name_suffix;
8947 else
8948 {
8949 auto image_expr = expr.substr(0, index);
8950 auto array_expr = expr.substr(index);
8951 return image_expr + swizzle_name_suffix + array_expr;
8952 }
8953 }
8954
to_buffer_size_expression(uint32_t id)8955 string CompilerMSL::to_buffer_size_expression(uint32_t id)
8956 {
8957 auto expr = to_expression(id);
8958 auto index = expr.find_first_of('[');
8959
8960 // This is quite crude, but we need to translate the reference name (*spvDescriptorSetN.name) to
8961 // the pointer expression spvDescriptorSetN.name to make a reasonable expression here.
8962 // This only happens if we have argument buffers and we are using OpArrayLength on a lone SSBO in that set.
8963 if (expr.size() >= 3 && expr[0] == '(' && expr[1] == '*')
8964 expr = address_of_expression(expr);
8965
8966 // If a buffer is part of an argument buffer translate this to a legal identifier.
8967 for (auto &c : expr)
8968 if (c == '.')
8969 c = '_';
8970
8971 if (index == string::npos)
8972 return expr + buffer_size_name_suffix;
8973 else
8974 {
8975 auto buffer_expr = expr.substr(0, index);
8976 auto array_expr = expr.substr(index);
8977 return buffer_expr + buffer_size_name_suffix + array_expr;
8978 }
8979 }
8980
8981 // Checks whether the type is a Block all of whose members have DecorationPatch.
is_patch_block(const SPIRType & type)8982 bool CompilerMSL::is_patch_block(const SPIRType &type)
8983 {
8984 if (!has_decoration(type.self, DecorationBlock))
8985 return false;
8986
8987 for (uint32_t i = 0; i < type.member_types.size(); i++)
8988 {
8989 if (!has_member_decoration(type.self, i, DecorationPatch))
8990 return false;
8991 }
8992
8993 return true;
8994 }
8995
8996 // Checks whether the ID is a row_major matrix that requires conversion before use
is_non_native_row_major_matrix(uint32_t id)8997 bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id)
8998 {
8999 auto *e = maybe_get<SPIRExpression>(id);
9000 if (e)
9001 return e->need_transpose;
9002 else
9003 return has_decoration(id, DecorationRowMajor);
9004 }
9005
9006 // Checks whether the member is a row_major matrix that requires conversion before use
member_is_non_native_row_major_matrix(const SPIRType & type,uint32_t index)9007 bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index)
9008 {
9009 return has_member_decoration(type.self, index, DecorationRowMajor);
9010 }
9011
convert_row_major_matrix(string exp_str,const SPIRType & exp_type,uint32_t physical_type_id,bool is_packed)9012 string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type, uint32_t physical_type_id,
9013 bool is_packed)
9014 {
9015 if (!is_matrix(exp_type))
9016 {
9017 return CompilerGLSL::convert_row_major_matrix(move(exp_str), exp_type, physical_type_id, is_packed);
9018 }
9019 else
9020 {
9021 strip_enclosed_expression(exp_str);
9022 if (physical_type_id != 0 || is_packed)
9023 exp_str = unpack_expression_type(exp_str, exp_type, physical_type_id, is_packed, true);
9024 return join("transpose(", exp_str, ")");
9025 }
9026 }
9027
9028 // Called automatically at the end of the entry point function
emit_fixup()9029 void CompilerMSL::emit_fixup()
9030 {
9031 if ((get_execution_model() == ExecutionModelVertex ||
9032 get_execution_model() == ExecutionModelTessellationEvaluation) &&
9033 stage_out_var_id && !qual_pos_var_name.empty() && !capture_output_to_buffer)
9034 {
9035 if (options.vertex.fixup_clipspace)
9036 statement(qual_pos_var_name, ".z = (", qual_pos_var_name, ".z + ", qual_pos_var_name,
9037 ".w) * 0.5; // Adjust clip-space for Metal");
9038
9039 if (options.vertex.flip_vert_y)
9040 statement(qual_pos_var_name, ".y = -(", qual_pos_var_name, ".y);", " // Invert Y-axis for Metal");
9041 }
9042 }
9043
9044 // Return a string defining a structure member, with padding and packing.
to_struct_member(const SPIRType & type,uint32_t member_type_id,uint32_t index,const string & qualifier)9045 string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
9046 const string &qualifier)
9047 {
9048 if (member_is_remapped_physical_type(type, index))
9049 member_type_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID);
9050 auto &physical_type = get<SPIRType>(member_type_id);
9051
9052 // If this member is packed, mark it as so.
9053 string pack_pfx;
9054
9055 // Allow Metal to use the array<T> template to make arrays a value type
9056 uint32_t orig_id = 0;
9057 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID))
9058 orig_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID);
9059
9060 bool row_major = false;
9061 if (is_matrix(physical_type))
9062 row_major = has_member_decoration(type.self, index, DecorationRowMajor);
9063
9064 SPIRType row_major_physical_type;
9065 const SPIRType *declared_type = &physical_type;
9066
9067 // If a struct is being declared with physical layout,
9068 // do not use array<T> wrappers.
9069 // This avoids a lot of complicated cases with packed vectors and matrices,
9070 // and generally we cannot copy full arrays in and out of buffers into Function
9071 // address space.
9072 // Array of resources should also be declared as builtin arrays.
9073 if (has_member_decoration(type.self, index, DecorationOffset))
9074 is_using_builtin_array = true;
9075 else if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
9076 is_using_builtin_array = true;
9077
9078 if (member_is_packed_physical_type(type, index))
9079 {
9080 // If we're packing a matrix, output an appropriate typedef
9081 if (physical_type.basetype == SPIRType::Struct)
9082 {
9083 SPIRV_CROSS_THROW("Cannot emit a packed struct currently.");
9084 }
9085 else if (is_matrix(physical_type))
9086 {
9087 uint32_t rows = physical_type.vecsize;
9088 uint32_t cols = physical_type.columns;
9089 pack_pfx = "packed_";
9090 if (row_major)
9091 {
9092 // These are stored transposed.
9093 rows = physical_type.columns;
9094 cols = physical_type.vecsize;
9095 pack_pfx = "packed_rm_";
9096 }
9097 string base_type = physical_type.width == 16 ? "half" : "float";
9098 string td_line = "typedef ";
9099 td_line += "packed_" + base_type + to_string(rows);
9100 td_line += " " + pack_pfx;
9101 // Use the actual matrix size here.
9102 td_line += base_type + to_string(physical_type.columns) + "x" + to_string(physical_type.vecsize);
9103 td_line += "[" + to_string(cols) + "]";
9104 td_line += ";";
9105 add_typedef_line(td_line);
9106 }
9107 else if (!is_scalar(physical_type)) // scalar type is already packed.
9108 pack_pfx = "packed_";
9109 }
9110 else if (row_major)
9111 {
9112 // Need to declare type with flipped vecsize/columns.
9113 row_major_physical_type = physical_type;
9114 swap(row_major_physical_type.vecsize, row_major_physical_type.columns);
9115 declared_type = &row_major_physical_type;
9116 }
9117
9118 // Very specifically, image load-store in argument buffers are disallowed on MSL on iOS.
9119 if (msl_options.is_ios() && physical_type.basetype == SPIRType::Image && physical_type.image.sampled == 2)
9120 {
9121 if (!has_decoration(orig_id, DecorationNonWritable))
9122 SPIRV_CROSS_THROW("Writable images are not allowed in argument buffers on iOS.");
9123 }
9124
9125 // Array information is baked into these types.
9126 string array_type;
9127 if (physical_type.basetype != SPIRType::Image && physical_type.basetype != SPIRType::Sampler &&
9128 physical_type.basetype != SPIRType::SampledImage)
9129 {
9130 BuiltIn builtin = BuiltInMax;
9131 if (is_member_builtin(type, index, &builtin))
9132 is_using_builtin_array = true;
9133 array_type = type_to_array_glsl(physical_type);
9134 }
9135
9136 auto result = join(pack_pfx, type_to_glsl(*declared_type, orig_id), " ", qualifier, to_member_name(type, index),
9137 member_attribute_qualifier(type, index), array_type, ";");
9138
9139 is_using_builtin_array = false;
9140 return result;
9141 }
9142
9143 // Emit a structure member, padding and packing to maintain the correct memeber alignments.
emit_struct_member(const SPIRType & type,uint32_t member_type_id,uint32_t index,const string & qualifier,uint32_t)9144 void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
9145 const string &qualifier, uint32_t)
9146 {
9147 // If this member requires padding to maintain its declared offset, emit a dummy padding member before it.
9148 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget))
9149 {
9150 uint32_t pad_len = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget);
9151 statement("char _m", index, "_pad", "[", pad_len, "];");
9152 }
9153
9154 // Handle HLSL-style 0-based vertex/instance index.
9155 builtin_declaration = true;
9156 statement(to_struct_member(type, member_type_id, index, qualifier));
9157 builtin_declaration = false;
9158 }
9159
emit_struct_padding_target(const SPIRType & type)9160 void CompilerMSL::emit_struct_padding_target(const SPIRType &type)
9161 {
9162 uint32_t struct_size = get_declared_struct_size_msl(type, true, true);
9163 uint32_t target_size = get_extended_decoration(type.self, SPIRVCrossDecorationPaddingTarget);
9164 if (target_size < struct_size)
9165 SPIRV_CROSS_THROW("Cannot pad with negative bytes.");
9166 else if (target_size > struct_size)
9167 statement("char _m0_final_padding[", target_size - struct_size, "];");
9168 }
9169
9170 // Return a MSL qualifier for the specified function attribute member
member_attribute_qualifier(const SPIRType & type,uint32_t index)9171 string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t index)
9172 {
9173 auto &execution = get_entry_point();
9174
9175 uint32_t mbr_type_id = type.member_types[index];
9176 auto &mbr_type = get<SPIRType>(mbr_type_id);
9177
9178 BuiltIn builtin = BuiltInMax;
9179 bool is_builtin = is_member_builtin(type, index, &builtin);
9180
9181 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
9182 {
9183 string quals = join(
9184 " [[id(", get_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary), ")");
9185 if (interlocked_resources.count(
9186 get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID)))
9187 quals += ", raster_order_group(0)";
9188 quals += "]]";
9189 return quals;
9190 }
9191
9192 // Vertex function inputs
9193 if (execution.model == ExecutionModelVertex && type.storage == StorageClassInput)
9194 {
9195 if (is_builtin)
9196 {
9197 switch (builtin)
9198 {
9199 case BuiltInVertexId:
9200 case BuiltInVertexIndex:
9201 case BuiltInBaseVertex:
9202 case BuiltInInstanceId:
9203 case BuiltInInstanceIndex:
9204 case BuiltInBaseInstance:
9205 if (msl_options.vertex_for_tessellation)
9206 return "";
9207 return string(" [[") + builtin_qualifier(builtin) + "]]";
9208
9209 case BuiltInDrawIndex:
9210 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
9211
9212 default:
9213 return "";
9214 }
9215 }
9216 uint32_t locn = get_ordered_member_location(type.self, index);
9217 if (locn != k_unknown_location)
9218 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
9219 }
9220
9221 // Vertex and tessellation evaluation function outputs
9222 if (((execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation) ||
9223 execution.model == ExecutionModelTessellationEvaluation) &&
9224 type.storage == StorageClassOutput)
9225 {
9226 if (is_builtin)
9227 {
9228 switch (builtin)
9229 {
9230 case BuiltInPointSize:
9231 // Only mark the PointSize builtin if really rendering points.
9232 // Some shaders may include a PointSize builtin even when used to render
9233 // non-point topologies, and Metal will reject this builtin when compiling
9234 // the shader into a render pipeline that uses a non-point topology.
9235 return msl_options.enable_point_size_builtin ? (string(" [[") + builtin_qualifier(builtin) + "]]") : "";
9236
9237 case BuiltInViewportIndex:
9238 if (!msl_options.supports_msl_version(2, 0))
9239 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
9240 /* fallthrough */
9241 case BuiltInPosition:
9242 case BuiltInLayer:
9243 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
9244
9245 case BuiltInClipDistance:
9246 if (has_member_decoration(type.self, index, DecorationLocation))
9247 return join(" [[user(clip", get_member_decoration(type.self, index, DecorationLocation), ")]]");
9248 else
9249 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
9250
9251 default:
9252 return "";
9253 }
9254 }
9255 uint32_t comp;
9256 uint32_t locn = get_ordered_member_location(type.self, index, &comp);
9257 if (locn != k_unknown_location)
9258 {
9259 if (comp != k_unknown_component)
9260 return string(" [[user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")]]";
9261 else
9262 return string(" [[user(locn") + convert_to_string(locn) + ")]]";
9263 }
9264 }
9265
9266 // Tessellation control function inputs
9267 if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassInput)
9268 {
9269 if (is_builtin)
9270 {
9271 switch (builtin)
9272 {
9273 case BuiltInInvocationId:
9274 case BuiltInPrimitiveId:
9275 if (msl_options.multi_patch_workgroup)
9276 return "";
9277 /* fallthrough */
9278 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
9279 case BuiltInSubgroupSize: // FIXME: Should work in any stage
9280 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
9281 case BuiltInPatchVertices:
9282 return "";
9283 // Others come from stage input.
9284 default:
9285 break;
9286 }
9287 }
9288 if (msl_options.multi_patch_workgroup)
9289 return "";
9290 uint32_t locn = get_ordered_member_location(type.self, index);
9291 if (locn != k_unknown_location)
9292 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
9293 }
9294
9295 // Tessellation control function outputs
9296 if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassOutput)
9297 {
9298 // For this type of shader, we always arrange for it to capture its
9299 // output to a buffer. For this reason, qualifiers are irrelevant here.
9300 return "";
9301 }
9302
9303 // Tessellation evaluation function inputs
9304 if (execution.model == ExecutionModelTessellationEvaluation && type.storage == StorageClassInput)
9305 {
9306 if (is_builtin)
9307 {
9308 switch (builtin)
9309 {
9310 case BuiltInPrimitiveId:
9311 case BuiltInTessCoord:
9312 return string(" [[") + builtin_qualifier(builtin) + "]]";
9313 case BuiltInPatchVertices:
9314 return "";
9315 // Others come from stage input.
9316 default:
9317 break;
9318 }
9319 }
9320 // The special control point array must not be marked with an attribute.
9321 if (get_type(type.member_types[index]).basetype == SPIRType::ControlPointArray)
9322 return "";
9323 uint32_t locn = get_ordered_member_location(type.self, index);
9324 if (locn != k_unknown_location)
9325 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
9326 }
9327
9328 // Tessellation evaluation function outputs were handled above.
9329
9330 // Fragment function inputs
9331 if (execution.model == ExecutionModelFragment && type.storage == StorageClassInput)
9332 {
9333 string quals;
9334 if (is_builtin)
9335 {
9336 switch (builtin)
9337 {
9338 case BuiltInViewIndex:
9339 if (!msl_options.multiview || !msl_options.multiview_layered_rendering)
9340 break;
9341 /* fallthrough */
9342 case BuiltInFrontFacing:
9343 case BuiltInPointCoord:
9344 case BuiltInFragCoord:
9345 case BuiltInSampleId:
9346 case BuiltInSampleMask:
9347 case BuiltInLayer:
9348 case BuiltInBaryCoordNV:
9349 case BuiltInBaryCoordNoPerspNV:
9350 quals = builtin_qualifier(builtin);
9351 break;
9352
9353 case BuiltInClipDistance:
9354 return join(" [[user(clip", get_member_decoration(type.self, index, DecorationLocation), ")]]");
9355
9356 default:
9357 break;
9358 }
9359 }
9360 else
9361 {
9362 uint32_t comp;
9363 uint32_t locn = get_ordered_member_location(type.self, index, &comp);
9364 if (locn != k_unknown_location)
9365 {
9366 // For user-defined attributes, this is fine. From Vulkan spec:
9367 // A user-defined output variable is considered to match an input variable in the subsequent stage if
9368 // the two variables are declared with the same Location and Component decoration and match in type
9369 // and decoration, except that interpolation decorations are not required to match. For the purposes
9370 // of interface matching, variables declared without a Component decoration are considered to have a
9371 // Component decoration of zero.
9372
9373 if (comp != k_unknown_component && comp != 0)
9374 quals = string("user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")";
9375 else
9376 quals = string("user(locn") + convert_to_string(locn) + ")";
9377 }
9378 }
9379
9380 if (builtin == BuiltInBaryCoordNV || builtin == BuiltInBaryCoordNoPerspNV)
9381 {
9382 if (has_member_decoration(type.self, index, DecorationFlat) ||
9383 has_member_decoration(type.self, index, DecorationCentroid) ||
9384 has_member_decoration(type.self, index, DecorationSample) ||
9385 has_member_decoration(type.self, index, DecorationNoPerspective))
9386 {
9387 // NoPerspective is baked into the builtin type.
9388 SPIRV_CROSS_THROW(
9389 "Flat, Centroid, Sample, NoPerspective decorations are not supported for BaryCoord inputs.");
9390 }
9391 }
9392
9393 // Don't bother decorating integers with the 'flat' attribute; it's
9394 // the default (in fact, the only option). Also don't bother with the
9395 // FragCoord builtin; it's always noperspective on Metal.
9396 if (!type_is_integral(mbr_type) && (!is_builtin || builtin != BuiltInFragCoord))
9397 {
9398 if (has_member_decoration(type.self, index, DecorationFlat))
9399 {
9400 if (!quals.empty())
9401 quals += ", ";
9402 quals += "flat";
9403 }
9404 else if (has_member_decoration(type.self, index, DecorationCentroid))
9405 {
9406 if (!quals.empty())
9407 quals += ", ";
9408 if (has_member_decoration(type.self, index, DecorationNoPerspective))
9409 quals += "centroid_no_perspective";
9410 else
9411 quals += "centroid_perspective";
9412 }
9413 else if (has_member_decoration(type.self, index, DecorationSample))
9414 {
9415 if (!quals.empty())
9416 quals += ", ";
9417 if (has_member_decoration(type.self, index, DecorationNoPerspective))
9418 quals += "sample_no_perspective";
9419 else
9420 quals += "sample_perspective";
9421 }
9422 else if (has_member_decoration(type.self, index, DecorationNoPerspective))
9423 {
9424 if (!quals.empty())
9425 quals += ", ";
9426 quals += "center_no_perspective";
9427 }
9428 }
9429
9430 if (!quals.empty())
9431 return " [[" + quals + "]]";
9432 }
9433
9434 // Fragment function outputs
9435 if (execution.model == ExecutionModelFragment && type.storage == StorageClassOutput)
9436 {
9437 if (is_builtin)
9438 {
9439 switch (builtin)
9440 {
9441 case BuiltInFragStencilRefEXT:
9442 // Similar to PointSize, only mark FragStencilRef if there's a stencil buffer.
9443 // Some shaders may include a FragStencilRef builtin even when used to render
9444 // without a stencil attachment, and Metal will reject this builtin
9445 // when compiling the shader into a render pipeline that does not set
9446 // stencilAttachmentPixelFormat.
9447 if (!msl_options.enable_frag_stencil_ref_builtin)
9448 return "";
9449 if (!msl_options.supports_msl_version(2, 1))
9450 SPIRV_CROSS_THROW("Stencil export only supported in MSL 2.1 and up.");
9451 return string(" [[") + builtin_qualifier(builtin) + "]]";
9452
9453 case BuiltInFragDepth:
9454 // Ditto FragDepth.
9455 if (!msl_options.enable_frag_depth_builtin)
9456 return "";
9457 /* fallthrough */
9458 case BuiltInSampleMask:
9459 return string(" [[") + builtin_qualifier(builtin) + "]]";
9460
9461 default:
9462 return "";
9463 }
9464 }
9465 uint32_t locn = get_ordered_member_location(type.self, index);
9466 // Metal will likely complain about missing color attachments, too.
9467 if (locn != k_unknown_location && !(msl_options.enable_frag_output_mask & (1 << locn)))
9468 return "";
9469 if (locn != k_unknown_location && has_member_decoration(type.self, index, DecorationIndex))
9470 return join(" [[color(", locn, "), index(", get_member_decoration(type.self, index, DecorationIndex),
9471 ")]]");
9472 else if (locn != k_unknown_location)
9473 return join(" [[color(", locn, ")]]");
9474 else if (has_member_decoration(type.self, index, DecorationIndex))
9475 return join(" [[index(", get_member_decoration(type.self, index, DecorationIndex), ")]]");
9476 else
9477 return "";
9478 }
9479
9480 // Compute function inputs
9481 if (execution.model == ExecutionModelGLCompute && type.storage == StorageClassInput)
9482 {
9483 if (is_builtin)
9484 {
9485 switch (builtin)
9486 {
9487 case BuiltInGlobalInvocationId:
9488 case BuiltInWorkgroupId:
9489 case BuiltInNumWorkgroups:
9490 case BuiltInLocalInvocationId:
9491 case BuiltInLocalInvocationIndex:
9492 case BuiltInNumSubgroups:
9493 case BuiltInSubgroupId:
9494 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
9495 case BuiltInSubgroupSize: // FIXME: Should work in any stage
9496 return string(" [[") + builtin_qualifier(builtin) + "]]";
9497
9498 default:
9499 return "";
9500 }
9501 }
9502 }
9503
9504 return "";
9505 }
9506
9507 // Returns the location decoration of the member with the specified index in the specified type.
9508 // If the location of the member has been explicitly set, that location is used. If not, this
9509 // function assumes the members are ordered in their location order, and simply returns the
9510 // index as the location.
get_ordered_member_location(uint32_t type_id,uint32_t index,uint32_t * comp)9511 uint32_t CompilerMSL::get_ordered_member_location(uint32_t type_id, uint32_t index, uint32_t *comp)
9512 {
9513 auto &m = ir.meta[type_id];
9514 if (index < m.members.size())
9515 {
9516 auto &dec = m.members[index];
9517 if (comp)
9518 {
9519 if (dec.decoration_flags.get(DecorationComponent))
9520 *comp = dec.component;
9521 else
9522 *comp = k_unknown_component;
9523 }
9524 if (dec.decoration_flags.get(DecorationLocation))
9525 return dec.location;
9526 }
9527
9528 return index;
9529 }
9530
9531 // Returns the type declaration for a function, including the
9532 // entry type if the current function is the entry point function
func_type_decl(SPIRType & type)9533 string CompilerMSL::func_type_decl(SPIRType &type)
9534 {
9535 // The regular function return type. If not processing the entry point function, that's all we need
9536 string return_type = type_to_glsl(type) + type_to_array_glsl(type);
9537 if (!processing_entry_point)
9538 return return_type;
9539
9540 // If an outgoing interface block has been defined, and it should be returned, override the entry point return type
9541 bool ep_should_return_output = !get_is_rasterization_disabled();
9542 if (stage_out_var_id && ep_should_return_output)
9543 return_type = type_to_glsl(get_stage_out_struct_type()) + type_to_array_glsl(type);
9544
9545 // Prepend a entry type, based on the execution model
9546 string entry_type;
9547 auto &execution = get_entry_point();
9548 switch (execution.model)
9549 {
9550 case ExecutionModelVertex:
9551 if (msl_options.vertex_for_tessellation && !msl_options.supports_msl_version(1, 2))
9552 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
9553 entry_type = msl_options.vertex_for_tessellation ? "kernel" : "vertex";
9554 break;
9555 case ExecutionModelTessellationEvaluation:
9556 if (!msl_options.supports_msl_version(1, 2))
9557 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
9558 if (execution.flags.get(ExecutionModeIsolines))
9559 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
9560 if (msl_options.is_ios())
9561 entry_type =
9562 join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ") ]] vertex");
9563 else
9564 entry_type = join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ", ",
9565 execution.output_vertices, ") ]] vertex");
9566 break;
9567 case ExecutionModelFragment:
9568 entry_type = execution.flags.get(ExecutionModeEarlyFragmentTests) ||
9569 execution.flags.get(ExecutionModePostDepthCoverage) ?
9570 "[[ early_fragment_tests ]] fragment" :
9571 "fragment";
9572 break;
9573 case ExecutionModelTessellationControl:
9574 if (!msl_options.supports_msl_version(1, 2))
9575 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
9576 if (execution.flags.get(ExecutionModeIsolines))
9577 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
9578 /* fallthrough */
9579 case ExecutionModelGLCompute:
9580 case ExecutionModelKernel:
9581 entry_type = "kernel";
9582 break;
9583 default:
9584 entry_type = "unknown";
9585 break;
9586 }
9587
9588 return entry_type + " " + return_type;
9589 }
9590
9591 // In MSL, address space qualifiers are required for all pointer or reference variables
get_argument_address_space(const SPIRVariable & argument)9592 string CompilerMSL::get_argument_address_space(const SPIRVariable &argument)
9593 {
9594 const auto &type = get<SPIRType>(argument.basetype);
9595 return get_type_address_space(type, argument.self, true);
9596 }
9597
get_type_address_space(const SPIRType & type,uint32_t id,bool argument)9598 string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument)
9599 {
9600 // This can be called for variable pointer contexts as well, so be very careful about which method we choose.
9601 Bitset flags;
9602 auto *var = maybe_get<SPIRVariable>(id);
9603 if (var && type.basetype == SPIRType::Struct &&
9604 (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
9605 flags = get_buffer_block_flags(id);
9606 else
9607 flags = get_decoration_bitset(id);
9608
9609 const char *addr_space = nullptr;
9610 switch (type.storage)
9611 {
9612 case StorageClassWorkgroup:
9613 addr_space = "threadgroup";
9614 break;
9615
9616 case StorageClassStorageBuffer:
9617 {
9618 // For arguments from variable pointers, we use the write count deduction, so
9619 // we should not assume any constness here. Only for global SSBOs.
9620 bool readonly = false;
9621 if (!var || has_decoration(type.self, DecorationBlock))
9622 readonly = flags.get(DecorationNonWritable);
9623
9624 addr_space = readonly ? "const device" : "device";
9625 break;
9626 }
9627
9628 case StorageClassUniform:
9629 case StorageClassUniformConstant:
9630 case StorageClassPushConstant:
9631 if (type.basetype == SPIRType::Struct)
9632 {
9633 bool ssbo = has_decoration(type.self, DecorationBufferBlock);
9634 if (ssbo)
9635 addr_space = flags.get(DecorationNonWritable) ? "const device" : "device";
9636 else
9637 addr_space = "constant";
9638 }
9639 else if (!argument)
9640 {
9641 addr_space = "constant";
9642 }
9643 else if (type_is_msl_framebuffer_fetch(type))
9644 {
9645 // Subpass inputs are passed around by value.
9646 addr_space = "";
9647 }
9648 break;
9649
9650 case StorageClassFunction:
9651 case StorageClassGeneric:
9652 break;
9653
9654 case StorageClassInput:
9655 if (get_execution_model() == ExecutionModelTessellationControl && var &&
9656 var->basevariable == stage_in_ptr_var_id)
9657 addr_space = msl_options.multi_patch_workgroup ? "constant" : "threadgroup";
9658 break;
9659
9660 case StorageClassOutput:
9661 if (capture_output_to_buffer)
9662 addr_space = "device";
9663 break;
9664
9665 default:
9666 break;
9667 }
9668
9669 if (!addr_space)
9670 // No address space for plain values.
9671 addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : "";
9672
9673 return join(flags.get(DecorationVolatile) || flags.get(DecorationCoherent) ? "volatile " : "", addr_space);
9674 }
9675
to_restrict(uint32_t id,bool space)9676 const char *CompilerMSL::to_restrict(uint32_t id, bool space)
9677 {
9678 // This can be called for variable pointer contexts as well, so be very careful about which method we choose.
9679 Bitset flags;
9680 if (ir.ids[id].get_type() == TypeVariable)
9681 {
9682 uint32_t type_id = expression_type_id(id);
9683 auto &type = expression_type(id);
9684 if (type.basetype == SPIRType::Struct &&
9685 (has_decoration(type_id, DecorationBlock) || has_decoration(type_id, DecorationBufferBlock)))
9686 flags = get_buffer_block_flags(id);
9687 else
9688 flags = get_decoration_bitset(id);
9689 }
9690 else
9691 flags = get_decoration_bitset(id);
9692
9693 return flags.get(DecorationRestrict) ? (space ? "restrict " : "restrict") : "";
9694 }
9695
entry_point_arg_stage_in()9696 string CompilerMSL::entry_point_arg_stage_in()
9697 {
9698 string decl;
9699
9700 if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup)
9701 return decl;
9702
9703 // Stage-in structure
9704 uint32_t stage_in_id;
9705 if (get_execution_model() == ExecutionModelTessellationEvaluation)
9706 stage_in_id = patch_stage_in_var_id;
9707 else
9708 stage_in_id = stage_in_var_id;
9709
9710 if (stage_in_id)
9711 {
9712 auto &var = get<SPIRVariable>(stage_in_id);
9713 auto &type = get_variable_data_type(var);
9714
9715 add_resource_name(var.self);
9716 decl = join(type_to_glsl(type), " ", to_name(var.self), " [[stage_in]]");
9717 }
9718
9719 return decl;
9720 }
9721
9722 // Returns true if this input builtin should be a direct parameter on a shader function parameter list,
9723 // and false for builtins that should be passed or calculated some other way.
is_direct_input_builtin(BuiltIn bi_type)9724 bool CompilerMSL::is_direct_input_builtin(BuiltIn bi_type)
9725 {
9726 switch (bi_type)
9727 {
9728 // Vertex function in
9729 case BuiltInVertexId:
9730 case BuiltInVertexIndex:
9731 case BuiltInBaseVertex:
9732 case BuiltInInstanceId:
9733 case BuiltInInstanceIndex:
9734 case BuiltInBaseInstance:
9735 return get_execution_model() != ExecutionModelVertex || !msl_options.vertex_for_tessellation;
9736 // Tess. control function in
9737 case BuiltInPosition:
9738 case BuiltInPointSize:
9739 case BuiltInClipDistance:
9740 case BuiltInCullDistance:
9741 case BuiltInPatchVertices:
9742 return false;
9743 case BuiltInInvocationId:
9744 case BuiltInPrimitiveId:
9745 return get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup;
9746 // Tess. evaluation function in
9747 case BuiltInTessLevelInner:
9748 case BuiltInTessLevelOuter:
9749 return false;
9750 // Fragment function in
9751 case BuiltInSamplePosition:
9752 case BuiltInHelperInvocation:
9753 case BuiltInBaryCoordNV:
9754 case BuiltInBaryCoordNoPerspNV:
9755 return false;
9756 case BuiltInViewIndex:
9757 return get_execution_model() == ExecutionModelFragment && msl_options.multiview &&
9758 msl_options.multiview_layered_rendering;
9759 // Any stage function in
9760 case BuiltInDeviceIndex:
9761 case BuiltInSubgroupEqMask:
9762 case BuiltInSubgroupGeMask:
9763 case BuiltInSubgroupGtMask:
9764 case BuiltInSubgroupLeMask:
9765 case BuiltInSubgroupLtMask:
9766 return false;
9767 case BuiltInSubgroupLocalInvocationId:
9768 case BuiltInSubgroupSize:
9769 return get_execution_model() == ExecutionModelGLCompute ||
9770 (get_execution_model() == ExecutionModelFragment && msl_options.supports_msl_version(2, 2));
9771 default:
9772 return true;
9773 }
9774 }
9775
entry_point_args_builtin(string & ep_args)9776 void CompilerMSL::entry_point_args_builtin(string &ep_args)
9777 {
9778 // Builtin variables
9779 SmallVector<pair<SPIRVariable *, BuiltIn>, 8> active_builtins;
9780 ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
9781 auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
9782
9783 // Don't emit SamplePosition as a separate parameter. In the entry
9784 // point, we get that by calling get_sample_position() on the sample ID.
9785 if (var.storage == StorageClassInput && is_builtin_variable(var) &&
9786 get_variable_data_type(var).basetype != SPIRType::Struct &&
9787 get_variable_data_type(var).basetype != SPIRType::ControlPointArray)
9788 {
9789 // If the builtin is not part of the active input builtin set, don't emit it.
9790 // Relevant for multiple entry-point modules which might declare unused builtins.
9791 if (!active_input_builtins.get(bi_type) || !interface_variable_exists_in_entry_point(var_id))
9792 return;
9793
9794 // Remember this variable. We may need to correct its type.
9795 active_builtins.push_back(make_pair(&var, bi_type));
9796
9797 if (is_direct_input_builtin(bi_type))
9798 {
9799 if (!ep_args.empty())
9800 ep_args += ", ";
9801
9802 // Handle HLSL-style 0-based vertex/instance index.
9803 builtin_declaration = true;
9804 ep_args += builtin_type_decl(bi_type, var_id) + " " + to_expression(var_id);
9805 ep_args += " [[" + builtin_qualifier(bi_type);
9806 if (bi_type == BuiltInSampleMask && get_entry_point().flags.get(ExecutionModePostDepthCoverage))
9807 {
9808 if (!msl_options.supports_msl_version(2))
9809 SPIRV_CROSS_THROW("Post-depth coverage requires Metal 2.0.");
9810 if (!msl_options.is_ios())
9811 SPIRV_CROSS_THROW("Post-depth coverage is only supported on iOS.");
9812 ep_args += ", post_depth_coverage";
9813 }
9814 ep_args += "]]";
9815 builtin_declaration = false;
9816 }
9817 }
9818
9819 if (var.storage == StorageClassInput &&
9820 has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase))
9821 {
9822 // This is a special implicit builtin, not corresponding to any SPIR-V builtin,
9823 // which holds the base that was passed to vkCmdDispatchBase() or vkCmdDrawIndexed(). If it's present,
9824 // assume we emitted it for a good reason.
9825 assert(msl_options.supports_msl_version(1, 2));
9826 if (!ep_args.empty())
9827 ep_args += ", ";
9828
9829 ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_origin]]";
9830 }
9831
9832 if (var.storage == StorageClassInput &&
9833 has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize))
9834 {
9835 // This is another special implicit builtin, not corresponding to any SPIR-V builtin,
9836 // which holds the number of vertices and instances to draw. If it's present,
9837 // assume we emitted it for a good reason.
9838 assert(msl_options.supports_msl_version(1, 2));
9839 if (!ep_args.empty())
9840 ep_args += ", ";
9841
9842 ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_size]]";
9843 }
9844 });
9845
9846 // Correct the types of all encountered active builtins. We couldn't do this before
9847 // because ensure_correct_builtin_type() may increase the bound, which isn't allowed
9848 // while iterating over IDs.
9849 for (auto &var : active_builtins)
9850 var.first->basetype = ensure_correct_builtin_type(var.first->basetype, var.second);
9851
9852 // Handle HLSL-style 0-based vertex/instance index.
9853 if (needs_base_vertex_arg == TriState::Yes)
9854 ep_args += built_in_func_arg(BuiltInBaseVertex, !ep_args.empty());
9855
9856 if (needs_base_instance_arg == TriState::Yes)
9857 ep_args += built_in_func_arg(BuiltInBaseInstance, !ep_args.empty());
9858
9859 if (capture_output_to_buffer)
9860 {
9861 // Add parameters to hold the indirect draw parameters and the shader output. This has to be handled
9862 // specially because it needs to be a pointer, not a reference.
9863 if (stage_out_var_id)
9864 {
9865 if (!ep_args.empty())
9866 ep_args += ", ";
9867 ep_args += join("device ", type_to_glsl(get_stage_out_struct_type()), "* ", output_buffer_var_name,
9868 " [[buffer(", msl_options.shader_output_buffer_index, ")]]");
9869 }
9870
9871 if (get_execution_model() == ExecutionModelTessellationControl)
9872 {
9873 if (!ep_args.empty())
9874 ep_args += ", ";
9875 ep_args +=
9876 join("constant uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
9877 }
9878 else if (stage_out_var_id &&
9879 !(get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
9880 {
9881 if (!ep_args.empty())
9882 ep_args += ", ";
9883 ep_args +=
9884 join("device uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
9885 }
9886
9887 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation &&
9888 (active_input_builtins.get(BuiltInVertexIndex) || active_input_builtins.get(BuiltInVertexId)) &&
9889 msl_options.vertex_index_type != Options::IndexType::None)
9890 {
9891 // Add the index buffer so we can set gl_VertexIndex correctly.
9892 if (!ep_args.empty())
9893 ep_args += ", ";
9894 switch (msl_options.vertex_index_type)
9895 {
9896 case Options::IndexType::None:
9897 break;
9898 case Options::IndexType::UInt16:
9899 ep_args += join("const device ushort* ", index_buffer_var_name, " [[buffer(",
9900 msl_options.shader_index_buffer_index, ")]]");
9901 break;
9902 case Options::IndexType::UInt32:
9903 ep_args += join("const device uint* ", index_buffer_var_name, " [[buffer(",
9904 msl_options.shader_index_buffer_index, ")]]");
9905 break;
9906 }
9907 }
9908
9909 // Tessellation control shaders get three additional parameters:
9910 // a buffer to hold the per-patch data, a buffer to hold the per-patch
9911 // tessellation levels, and a block of workgroup memory to hold the
9912 // input control point data.
9913 if (get_execution_model() == ExecutionModelTessellationControl)
9914 {
9915 if (patch_stage_out_var_id)
9916 {
9917 if (!ep_args.empty())
9918 ep_args += ", ";
9919 ep_args +=
9920 join("device ", type_to_glsl(get_patch_stage_out_struct_type()), "* ", patch_output_buffer_var_name,
9921 " [[buffer(", convert_to_string(msl_options.shader_patch_output_buffer_index), ")]]");
9922 }
9923 if (!ep_args.empty())
9924 ep_args += ", ";
9925 ep_args += join("device ", get_tess_factor_struct_name(), "* ", tess_factor_buffer_var_name, " [[buffer(",
9926 convert_to_string(msl_options.shader_tess_factor_buffer_index), ")]]");
9927 if (stage_in_var_id)
9928 {
9929 if (!ep_args.empty())
9930 ep_args += ", ";
9931 if (msl_options.multi_patch_workgroup)
9932 {
9933 ep_args += join("device ", type_to_glsl(get_stage_in_struct_type()), "* ", input_buffer_var_name,
9934 " [[buffer(", convert_to_string(msl_options.shader_input_buffer_index), ")]]");
9935 }
9936 else
9937 {
9938 ep_args += join("threadgroup ", type_to_glsl(get_stage_in_struct_type()), "* ", input_wg_var_name,
9939 " [[threadgroup(", convert_to_string(msl_options.shader_input_wg_index), ")]]");
9940 }
9941 }
9942 }
9943 }
9944 }
9945
entry_point_args_argument_buffer(bool append_comma)9946 string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
9947 {
9948 string ep_args = entry_point_arg_stage_in();
9949 Bitset claimed_bindings;
9950
9951 for (uint32_t i = 0; i < kMaxArgumentBuffers; i++)
9952 {
9953 uint32_t id = argument_buffer_ids[i];
9954 if (id == 0)
9955 continue;
9956
9957 add_resource_name(id);
9958 auto &var = get<SPIRVariable>(id);
9959 auto &type = get_variable_data_type(var);
9960
9961 if (!ep_args.empty())
9962 ep_args += ", ";
9963
9964 // Check if the argument buffer binding itself has been remapped.
9965 uint32_t buffer_binding;
9966 auto itr = resource_bindings.find({ get_entry_point().model, i, kArgumentBufferBinding });
9967 if (itr != end(resource_bindings))
9968 {
9969 buffer_binding = itr->second.first.msl_buffer;
9970 itr->second.second = true;
9971 }
9972 else
9973 {
9974 // As a fallback, directly map desc set <-> binding.
9975 // If that was taken, take the next buffer binding.
9976 if (claimed_bindings.get(i))
9977 buffer_binding = next_metal_resource_index_buffer;
9978 else
9979 buffer_binding = i;
9980 }
9981
9982 claimed_bindings.set(buffer_binding);
9983
9984 ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id) + to_name(id);
9985 ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";
9986
9987 next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
9988 }
9989
9990 entry_point_args_discrete_descriptors(ep_args);
9991 entry_point_args_builtin(ep_args);
9992
9993 if (!ep_args.empty() && append_comma)
9994 ep_args += ", ";
9995
9996 return ep_args;
9997 }
9998
find_constexpr_sampler(uint32_t id) const9999 const MSLConstexprSampler *CompilerMSL::find_constexpr_sampler(uint32_t id) const
10000 {
10001 // Try by ID.
10002 {
10003 auto itr = constexpr_samplers_by_id.find(id);
10004 if (itr != end(constexpr_samplers_by_id))
10005 return &itr->second;
10006 }
10007
10008 // Try by binding.
10009 {
10010 uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
10011 uint32_t binding = get_decoration(id, DecorationBinding);
10012
10013 auto itr = constexpr_samplers_by_binding.find({ desc_set, binding });
10014 if (itr != end(constexpr_samplers_by_binding))
10015 return &itr->second;
10016 }
10017
10018 return nullptr;
10019 }
10020
entry_point_args_discrete_descriptors(string & ep_args)10021 void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
10022 {
10023 // Output resources, sorted by resource index & type
10024 // We need to sort to work around a bug on macOS 10.13 with NVidia drivers where switching between shaders
10025 // with different order of buffers can result in issues with buffer assignments inside the driver.
10026 struct Resource
10027 {
10028 SPIRVariable *var;
10029 string name;
10030 SPIRType::BaseType basetype;
10031 uint32_t index;
10032 uint32_t plane;
10033 uint32_t secondary_index;
10034 };
10035
10036 SmallVector<Resource> resources;
10037
10038 ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
10039 if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
10040 var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer) &&
10041 !is_hidden_variable(var))
10042 {
10043 auto &type = get_variable_data_type(var);
10044
10045 // Very specifically, image load-store in argument buffers are disallowed on MSL on iOS.
10046 // But we won't know when the argument buffer is encoded whether this image will have
10047 // a NonWritable decoration. So just use discrete arguments for all storage images
10048 // on iOS.
10049 if (!(msl_options.is_ios() && type.basetype == SPIRType::Image && type.image.sampled == 2) &&
10050 var.storage != StorageClassPushConstant)
10051 {
10052 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
10053 if (descriptor_set_is_argument_buffer(desc_set))
10054 return;
10055 }
10056
10057 const MSLConstexprSampler *constexpr_sampler = nullptr;
10058 if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
10059 {
10060 constexpr_sampler = find_constexpr_sampler(var_id);
10061 if (constexpr_sampler)
10062 {
10063 // Mark this ID as a constexpr sampler for later in case it came from set/bindings.
10064 constexpr_samplers_by_id[var_id] = *constexpr_sampler;
10065 }
10066 }
10067
10068 // Emulate texture2D atomic operations
10069 uint32_t secondary_index = 0;
10070 if (atomic_image_vars.count(var.self))
10071 {
10072 secondary_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
10073 }
10074
10075 if (type.basetype == SPIRType::SampledImage)
10076 {
10077 add_resource_name(var_id);
10078
10079 uint32_t plane_count = 1;
10080 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
10081 plane_count = constexpr_sampler->planes;
10082
10083 for (uint32_t i = 0; i < plane_count; i++)
10084 resources.push_back({ &var, to_name(var_id), SPIRType::Image,
10085 get_metal_resource_index(var, SPIRType::Image, i), i, secondary_index });
10086
10087 if (type.image.dim != DimBuffer && !constexpr_sampler)
10088 {
10089 resources.push_back({ &var, to_sampler_expression(var_id), SPIRType::Sampler,
10090 get_metal_resource_index(var, SPIRType::Sampler), 0, 0 });
10091 }
10092 }
10093 else if (!constexpr_sampler)
10094 {
10095 // constexpr samplers are not declared as resources.
10096 add_resource_name(var_id);
10097 resources.push_back({ &var, to_name(var_id), type.basetype,
10098 get_metal_resource_index(var, type.basetype), 0, secondary_index });
10099 }
10100 }
10101 });
10102
10103 sort(resources.begin(), resources.end(), [](const Resource &lhs, const Resource &rhs) {
10104 return tie(lhs.basetype, lhs.index) < tie(rhs.basetype, rhs.index);
10105 });
10106
10107 for (auto &r : resources)
10108 {
10109 auto &var = *r.var;
10110 auto &type = get_variable_data_type(var);
10111
10112 uint32_t var_id = var.self;
10113
10114 switch (r.basetype)
10115 {
10116 case SPIRType::Struct:
10117 {
10118 auto &m = ir.meta[type.self];
10119 if (m.members.size() == 0)
10120 break;
10121 if (!type.array.empty())
10122 {
10123 if (type.array.size() > 1)
10124 SPIRV_CROSS_THROW("Arrays of arrays of buffers are not supported.");
10125
10126 // Metal doesn't directly support this, so we must expand the
10127 // array. We'll declare a local array to hold these elements
10128 // later.
10129 uint32_t array_size = to_array_size_literal(type);
10130
10131 if (array_size == 0)
10132 SPIRV_CROSS_THROW("Unsized arrays of buffers are not supported in MSL.");
10133
10134 // Allow Metal to use the array<T> template to make arrays a value type
10135 is_using_builtin_array = true;
10136 buffer_arrays.push_back(var_id);
10137 for (uint32_t i = 0; i < array_size; ++i)
10138 {
10139 if (!ep_args.empty())
10140 ep_args += ", ";
10141 ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + to_restrict(var_id) +
10142 r.name + "_" + convert_to_string(i);
10143 ep_args += " [[buffer(" + convert_to_string(r.index + i) + ")";
10144 if (interlocked_resources.count(var_id))
10145 ep_args += ", raster_order_group(0)";
10146 ep_args += "]]";
10147 }
10148 is_using_builtin_array = false;
10149 }
10150 else
10151 {
10152 if (!ep_args.empty())
10153 ep_args += ", ";
10154 ep_args +=
10155 get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(var_id) + r.name;
10156 ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
10157 if (interlocked_resources.count(var_id))
10158 ep_args += ", raster_order_group(0)";
10159 ep_args += "]]";
10160 }
10161 break;
10162 }
10163 case SPIRType::Sampler:
10164 if (!ep_args.empty())
10165 ep_args += ", ";
10166 ep_args += sampler_type(type) + " " + r.name;
10167 ep_args += " [[sampler(" + convert_to_string(r.index) + ")]]";
10168 break;
10169 case SPIRType::Image:
10170 {
10171 if (!ep_args.empty())
10172 ep_args += ", ";
10173
10174 // Use Metal's native frame-buffer fetch API for subpass inputs.
10175 const auto &basetype = get<SPIRType>(var.basetype);
10176 if (!type_is_msl_framebuffer_fetch(basetype))
10177 {
10178 ep_args += image_type_glsl(type, var_id) + " " + r.name;
10179 if (r.plane > 0)
10180 ep_args += join(plane_name_suffix, r.plane);
10181 ep_args += " [[texture(" + convert_to_string(r.index) + ")";
10182 if (interlocked_resources.count(var_id))
10183 ep_args += ", raster_order_group(0)";
10184 ep_args += "]]";
10185 }
10186 else
10187 {
10188 ep_args += image_type_glsl(type, var_id) + " " + r.name;
10189 ep_args += " [[color(" + convert_to_string(r.index) + ")]]";
10190 }
10191
10192 // Emulate texture2D atomic operations
10193 if (atomic_image_vars.count(var.self))
10194 {
10195 ep_args += ", device atomic_" + type_to_glsl(get<SPIRType>(basetype.image.type), 0);
10196 ep_args += "* " + r.name + "_atomic";
10197 ep_args += " [[buffer(" + convert_to_string(r.secondary_index) + ")]]";
10198 }
10199 break;
10200 }
10201 default:
10202 if (!ep_args.empty())
10203 ep_args += ", ";
10204 if (!type.pointer)
10205 ep_args += get_type_address_space(get<SPIRType>(var.basetype), var_id) + " " +
10206 type_to_glsl(type, var_id) + "& " + r.name;
10207 else
10208 ep_args += type_to_glsl(type, var_id) + " " + r.name;
10209 ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
10210 if (interlocked_resources.count(var_id))
10211 ep_args += ", raster_order_group(0)";
10212 ep_args += "]]";
10213 break;
10214 }
10215 }
10216 }
10217
10218 // Returns a string containing a comma-delimited list of args for the entry point function
10219 // This is the "classic" method of MSL 1 when we don't have argument buffer support.
entry_point_args_classic(bool append_comma)10220 string CompilerMSL::entry_point_args_classic(bool append_comma)
10221 {
10222 string ep_args = entry_point_arg_stage_in();
10223 entry_point_args_discrete_descriptors(ep_args);
10224 entry_point_args_builtin(ep_args);
10225
10226 if (!ep_args.empty() && append_comma)
10227 ep_args += ", ";
10228
10229 return ep_args;
10230 }
10231
fix_up_shader_inputs_outputs()10232 void CompilerMSL::fix_up_shader_inputs_outputs()
10233 {
10234 auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
10235
10236 // Emit a guard to ensure we don't execute beyond the last vertex.
10237 // Vertex shaders shouldn't have the problems with barriers in non-uniform control flow that
10238 // tessellation control shaders do, so early returns should be OK. We may need to revisit this
10239 // if it ever becomes possible to use barriers from a vertex shader.
10240 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
10241 {
10242 entry_func.fixup_hooks_in.push_back([this]() {
10243 statement("if (any(", to_expression(builtin_invocation_id_id),
10244 " >= ", to_expression(builtin_stage_input_size_id), "))");
10245 statement(" return;");
10246 });
10247 }
10248
10249 // Look for sampled images and buffer. Add hooks to set up the swizzle constants or array lengths.
10250 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
10251 auto &type = get_variable_data_type(var);
10252 uint32_t var_id = var.self;
10253 bool ssbo = has_decoration(type.self, DecorationBufferBlock);
10254
10255 if (var.storage == StorageClassUniformConstant && !is_hidden_variable(var))
10256 {
10257 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
10258 {
10259 entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
10260 bool is_array_type = !type.array.empty();
10261
10262 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
10263 if (descriptor_set_is_argument_buffer(desc_set))
10264 {
10265 statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
10266 is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
10267 ".spvSwizzleConstants", "[",
10268 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
10269 }
10270 else
10271 {
10272 // If we have an array of images, we need to be able to index into it, so take a pointer instead.
10273 statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
10274 is_array_type ? " = &" : " = ", to_name(swizzle_buffer_id), "[",
10275 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
10276 }
10277 });
10278 }
10279 }
10280 else if ((var.storage == StorageClassStorageBuffer || (var.storage == StorageClassUniform && ssbo)) &&
10281 !is_hidden_variable(var))
10282 {
10283 if (buffers_requiring_array_length.count(var.self))
10284 {
10285 entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
10286 bool is_array_type = !type.array.empty();
10287
10288 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
10289 if (descriptor_set_is_argument_buffer(desc_set))
10290 {
10291 statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
10292 is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
10293 ".spvBufferSizeConstants", "[",
10294 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
10295 }
10296 else
10297 {
10298 // If we have an array of images, we need to be able to index into it, so take a pointer instead.
10299 statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
10300 is_array_type ? " = &" : " = ", to_name(buffer_size_buffer_id), "[",
10301 convert_to_string(get_metal_resource_index(var, type.basetype)), "];");
10302 }
10303 });
10304 }
10305 }
10306 });
10307
10308 // Builtin variables
10309 ir.for_each_typed_id<SPIRVariable>([this, &entry_func](uint32_t, SPIRVariable &var) {
10310 uint32_t var_id = var.self;
10311 BuiltIn bi_type = ir.meta[var_id].decoration.builtin_type;
10312
10313 if (var.storage == StorageClassInput && is_builtin_variable(var))
10314 {
10315 switch (bi_type)
10316 {
10317 case BuiltInSamplePosition:
10318 entry_func.fixup_hooks_in.push_back([=]() {
10319 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = get_sample_position(",
10320 to_expression(builtin_sample_id_id), ");");
10321 });
10322 break;
10323 case BuiltInHelperInvocation:
10324 if (msl_options.is_ios())
10325 SPIRV_CROSS_THROW("simd_is_helper_thread() is only supported on macOS.");
10326 else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
10327 SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
10328
10329 entry_func.fixup_hooks_in.push_back([=]() {
10330 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_is_helper_thread();");
10331 });
10332 break;
10333 case BuiltInInvocationId:
10334 // This is direct-mapped without multi-patch workgroups.
10335 if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
10336 break;
10337
10338 entry_func.fixup_hooks_in.push_back([=]() {
10339 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10340 to_expression(builtin_invocation_id_id), ".x % ", this->get_entry_point().output_vertices,
10341 ";");
10342 });
10343 break;
10344 case BuiltInPrimitiveId:
10345 // This is natively supported by fragment and tessellation evaluation shaders.
10346 // In tessellation control shaders, this is direct-mapped without multi-patch workgroups.
10347 if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
10348 break;
10349
10350 entry_func.fixup_hooks_in.push_back([=]() {
10351 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = min(",
10352 to_expression(builtin_invocation_id_id), ".x / ", this->get_entry_point().output_vertices,
10353 ", spvIndirectParams[1]);");
10354 });
10355 break;
10356 case BuiltInPatchVertices:
10357 if (get_execution_model() == ExecutionModelTessellationEvaluation)
10358 entry_func.fixup_hooks_in.push_back([=]() {
10359 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10360 to_expression(patch_stage_in_var_id), ".gl_in.size();");
10361 });
10362 else
10363 entry_func.fixup_hooks_in.push_back([=]() {
10364 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = spvIndirectParams[0];");
10365 });
10366 break;
10367 case BuiltInTessCoord:
10368 // Emit a fixup to account for the shifted domain. Don't do this for triangles;
10369 // MoltenVK will just reverse the winding order instead.
10370 if (msl_options.tess_domain_origin_lower_left && !get_entry_point().flags.get(ExecutionModeTriangles))
10371 {
10372 string tc = to_expression(var_id);
10373 entry_func.fixup_hooks_in.push_back([=]() { statement(tc, ".y = 1.0 - ", tc, ".y;"); });
10374 }
10375 break;
10376 case BuiltInSubgroupLocalInvocationId:
10377 // This is natively supported in compute shaders.
10378 if (get_execution_model() == ExecutionModelGLCompute)
10379 break;
10380
10381 // This is natively supported in fragment shaders in MSL 2.2.
10382 if (get_execution_model() == ExecutionModelFragment && msl_options.supports_msl_version(2, 2))
10383 break;
10384
10385 if (msl_options.is_ios())
10386 SPIRV_CROSS_THROW(
10387 "SubgroupLocalInvocationId cannot be used outside of compute shaders before MSL 2.2 on iOS.");
10388
10389 if (!msl_options.supports_msl_version(2, 1))
10390 SPIRV_CROSS_THROW(
10391 "SubgroupLocalInvocationId cannot be used outside of compute shaders before MSL 2.1.");
10392
10393 // Shaders other than compute shaders don't support the SIMD-group
10394 // builtins directly, but we can emulate them using the SIMD-group
10395 // functions. This might break if some of the subgroup terminated
10396 // before reaching the entry point.
10397 entry_func.fixup_hooks_in.push_back([=]() {
10398 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
10399 " = simd_prefix_exclusive_sum(1);");
10400 });
10401 break;
10402 case BuiltInSubgroupSize:
10403 // This is natively supported in compute shaders.
10404 if (get_execution_model() == ExecutionModelGLCompute)
10405 break;
10406
10407 // This is natively supported in fragment shaders in MSL 2.2.
10408 if (get_execution_model() == ExecutionModelFragment && msl_options.supports_msl_version(2, 2))
10409 break;
10410
10411 if (msl_options.is_ios())
10412 SPIRV_CROSS_THROW("SubgroupSize cannot be used outside of compute shaders on iOS.");
10413
10414 if (!msl_options.supports_msl_version(2, 1))
10415 SPIRV_CROSS_THROW("SubgroupSize cannot be used outside of compute shaders before Metal 2.1.");
10416
10417 entry_func.fixup_hooks_in.push_back(
10418 [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_sum(1);"); });
10419 break;
10420 case BuiltInSubgroupEqMask:
10421 if (msl_options.is_ios())
10422 SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
10423 if (!msl_options.supports_msl_version(2, 1))
10424 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
10425 entry_func.fixup_hooks_in.push_back([=]() {
10426 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10427 to_expression(builtin_subgroup_invocation_id_id), " > 32 ? uint4(0, (1 << (",
10428 to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
10429 to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
10430 });
10431 break;
10432 case BuiltInSubgroupGeMask:
10433 if (msl_options.is_ios())
10434 SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
10435 if (!msl_options.supports_msl_version(2, 1))
10436 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
10437 entry_func.fixup_hooks_in.push_back([=]() {
10438 // Case where index < 32, size < 32:
10439 // mask0 = bfe(0xFFFFFFFF, index, size - index);
10440 // mask1 = bfe(0xFFFFFFFF, 0, 0); // Gives 0
10441 // Case where index < 32 but size >= 32:
10442 // mask0 = bfe(0xFFFFFFFF, index, 32 - index);
10443 // mask1 = bfe(0xFFFFFFFF, 0, size - 32);
10444 // Case where index >= 32:
10445 // mask0 = bfe(0xFFFFFFFF, 32, 0); // Gives 0
10446 // mask1 = bfe(0xFFFFFFFF, index - 32, size - index);
10447 // This is expressed without branches to avoid divergent
10448 // control flow--hence the complicated min/max expressions.
10449 // This is further complicated by the fact that if you attempt
10450 // to bfe out-of-bounds on Metal, undefined behavior is the
10451 // result.
10452 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
10453 " = uint4(extract_bits(0xFFFFFFFF, min(",
10454 to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
10455 to_expression(builtin_subgroup_size_id), ", 32) - (int)",
10456 to_expression(builtin_subgroup_invocation_id_id),
10457 ", 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
10458 to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
10459 to_expression(builtin_subgroup_size_id), " - (int)max(",
10460 to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
10461 });
10462 break;
10463 case BuiltInSubgroupGtMask:
10464 if (msl_options.is_ios())
10465 SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
10466 if (!msl_options.supports_msl_version(2, 1))
10467 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
10468 entry_func.fixup_hooks_in.push_back([=]() {
10469 // The same logic applies here, except now the index is one
10470 // more than the subgroup invocation ID.
10471 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
10472 " = uint4(extract_bits(0xFFFFFFFF, min(",
10473 to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
10474 to_expression(builtin_subgroup_size_id), ", 32) - (int)",
10475 to_expression(builtin_subgroup_invocation_id_id),
10476 " - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
10477 to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
10478 to_expression(builtin_subgroup_size_id), " - (int)max(",
10479 to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");
10480 });
10481 break;
10482 case BuiltInSubgroupLeMask:
10483 if (msl_options.is_ios())
10484 SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
10485 if (!msl_options.supports_msl_version(2, 1))
10486 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
10487 entry_func.fixup_hooks_in.push_back([=]() {
10488 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
10489 " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
10490 to_expression(builtin_subgroup_invocation_id_id),
10491 " + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
10492 to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0)), uint2(0));");
10493 });
10494 break;
10495 case BuiltInSubgroupLtMask:
10496 if (msl_options.is_ios())
10497 SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
10498 if (!msl_options.supports_msl_version(2, 1))
10499 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
10500 entry_func.fixup_hooks_in.push_back([=]() {
10501 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
10502 " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
10503 to_expression(builtin_subgroup_invocation_id_id),
10504 ", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
10505 to_expression(builtin_subgroup_invocation_id_id), " - 32, 0)), uint2(0));");
10506 });
10507 break;
10508 case BuiltInViewIndex:
10509 if (!msl_options.multiview)
10510 {
10511 // According to the Vulkan spec, when not running under a multiview
10512 // render pass, ViewIndex is 0.
10513 entry_func.fixup_hooks_in.push_back([=]() {
10514 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;");
10515 });
10516 }
10517 else if (msl_options.view_index_from_device_index)
10518 {
10519 // In this case, we take the view index from that of the device we're running on.
10520 entry_func.fixup_hooks_in.push_back([=]() {
10521 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10522 msl_options.device_index, ";");
10523 });
10524 // We actually don't want to set the render_target_array_index here.
10525 // Since every physical device is rendering a different view,
10526 // there's no need for layered rendering here.
10527 }
10528 else if (!msl_options.multiview_layered_rendering)
10529 {
10530 // In this case, the views are rendered one at a time. The view index, then,
10531 // is just the first part of the "view mask".
10532 entry_func.fixup_hooks_in.push_back([=]() {
10533 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10534 to_expression(view_mask_buffer_id), "[0];");
10535 });
10536 }
10537 else if (get_execution_model() == ExecutionModelFragment)
10538 {
10539 // Because we adjusted the view index in the vertex shader, we have to
10540 // adjust it back here.
10541 entry_func.fixup_hooks_in.push_back([=]() {
10542 statement(to_expression(var_id), " += ", to_expression(view_mask_buffer_id), "[0];");
10543 });
10544 }
10545 else if (get_execution_model() == ExecutionModelVertex)
10546 {
10547 // Metal provides no special support for multiview, so we smuggle
10548 // the view index in the instance index.
10549 entry_func.fixup_hooks_in.push_back([=]() {
10550 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10551 to_expression(view_mask_buffer_id), "[0] + (", to_expression(builtin_instance_idx_id),
10552 " - ", to_expression(builtin_base_instance_id), ") % ",
10553 to_expression(view_mask_buffer_id), "[1];");
10554 statement(to_expression(builtin_instance_idx_id), " = (",
10555 to_expression(builtin_instance_idx_id), " - ",
10556 to_expression(builtin_base_instance_id), ") / ", to_expression(view_mask_buffer_id),
10557 "[1] + ", to_expression(builtin_base_instance_id), ";");
10558 });
10559 // In addition to setting the variable itself, we also need to
10560 // set the render_target_array_index with it on output. We have to
10561 // offset this by the base view index, because Metal isn't in on
10562 // our little game here.
10563 entry_func.fixup_hooks_out.push_back([=]() {
10564 statement(to_expression(builtin_layer_id), " = ", to_expression(var_id), " - ",
10565 to_expression(view_mask_buffer_id), "[0];");
10566 });
10567 }
10568 break;
10569 case BuiltInDeviceIndex:
10570 // Metal pipelines belong to the devices which create them, so we'll
10571 // need to create a MTLPipelineState for every MTLDevice in a grouped
10572 // VkDevice. We can assume, then, that the device index is constant.
10573 entry_func.fixup_hooks_in.push_back([=]() {
10574 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10575 msl_options.device_index, ";");
10576 });
10577 break;
10578 case BuiltInWorkgroupId:
10579 if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInWorkgroupId))
10580 break;
10581
10582 // The vkCmdDispatchBase() command lets the client set the base value
10583 // of WorkgroupId. Metal has no direct equivalent; we must make this
10584 // adjustment ourselves.
10585 entry_func.fixup_hooks_in.push_back([=]() {
10586 statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), ";");
10587 });
10588 break;
10589 case BuiltInGlobalInvocationId:
10590 if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInGlobalInvocationId))
10591 break;
10592
10593 // GlobalInvocationId is defined as LocalInvocationId + WorkgroupId * WorkgroupSize.
10594 // This needs to be adjusted too.
10595 entry_func.fixup_hooks_in.push_back([=]() {
10596 auto &execution = this->get_entry_point();
10597 uint32_t workgroup_size_id = execution.workgroup_size.constant;
10598 if (workgroup_size_id)
10599 statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
10600 " * ", to_expression(workgroup_size_id), ";");
10601 else
10602 statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
10603 " * uint3(", execution.workgroup_size.x, ", ", execution.workgroup_size.y, ", ",
10604 execution.workgroup_size.z, ");");
10605 });
10606 break;
10607 case BuiltInVertexId:
10608 case BuiltInVertexIndex:
10609 // This is direct-mapped normally.
10610 if (!msl_options.vertex_for_tessellation)
10611 break;
10612
10613 entry_func.fixup_hooks_in.push_back([=]() {
10614 builtin_declaration = true;
10615 switch (msl_options.vertex_index_type)
10616 {
10617 case Options::IndexType::None:
10618 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10619 to_expression(builtin_invocation_id_id), ".x + ",
10620 to_expression(builtin_dispatch_base_id), ".x;");
10621 break;
10622 case Options::IndexType::UInt16:
10623 case Options::IndexType::UInt32:
10624 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", index_buffer_var_name,
10625 "[", to_expression(builtin_invocation_id_id), ".x] + ",
10626 to_expression(builtin_dispatch_base_id), ".x;");
10627 break;
10628 }
10629 builtin_declaration = false;
10630 });
10631 break;
10632 case BuiltInBaseVertex:
10633 // This is direct-mapped normally.
10634 if (!msl_options.vertex_for_tessellation)
10635 break;
10636
10637 entry_func.fixup_hooks_in.push_back([=]() {
10638 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10639 to_expression(builtin_dispatch_base_id), ".x;");
10640 });
10641 break;
10642 case BuiltInInstanceId:
10643 case BuiltInInstanceIndex:
10644 // This is direct-mapped normally.
10645 if (!msl_options.vertex_for_tessellation)
10646 break;
10647
10648 entry_func.fixup_hooks_in.push_back([=]() {
10649 builtin_declaration = true;
10650 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10651 to_expression(builtin_invocation_id_id), ".y + ", to_expression(builtin_dispatch_base_id),
10652 ".y;");
10653 builtin_declaration = false;
10654 });
10655 break;
10656 case BuiltInBaseInstance:
10657 // This is direct-mapped normally.
10658 if (!msl_options.vertex_for_tessellation)
10659 break;
10660
10661 entry_func.fixup_hooks_in.push_back([=]() {
10662 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10663 to_expression(builtin_dispatch_base_id), ".y;");
10664 });
10665 break;
10666 default:
10667 break;
10668 }
10669 }
10670 else if (var.storage == StorageClassOutput && is_builtin_variable(var))
10671 {
10672 if (bi_type == BuiltInSampleMask && get_execution_model() == ExecutionModelFragment &&
10673 msl_options.additional_fixed_sample_mask != 0xffffffff)
10674 {
10675 // If the additional fixed sample mask was set, we need to adjust the sample_mask
10676 // output to reflect that. If the shader outputs the sample_mask itself too, we need
10677 // to AND the two masks to get the final one.
10678 if (does_shader_write_sample_mask)
10679 {
10680 entry_func.fixup_hooks_out.push_back([=]() {
10681 statement(to_expression(builtin_sample_mask_id),
10682 " &= ", msl_options.additional_fixed_sample_mask, ";");
10683 });
10684 }
10685 else
10686 {
10687 entry_func.fixup_hooks_out.push_back([=]() {
10688 statement(to_expression(builtin_sample_mask_id), " = ",
10689 msl_options.additional_fixed_sample_mask, ";");
10690 });
10691 }
10692 }
10693 }
10694 });
10695 }
10696
10697 // Returns the Metal index of the resource of the specified type as used by the specified variable.
get_metal_resource_index(SPIRVariable & var,SPIRType::BaseType basetype,uint32_t plane)10698 uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype, uint32_t plane)
10699 {
10700 auto &execution = get_entry_point();
10701 auto &var_dec = ir.meta[var.self].decoration;
10702 auto &var_type = get<SPIRType>(var.basetype);
10703 uint32_t var_desc_set = (var.storage == StorageClassPushConstant) ? kPushConstDescSet : var_dec.set;
10704 uint32_t var_binding = (var.storage == StorageClassPushConstant) ? kPushConstBinding : var_dec.binding;
10705
10706 // If a matching binding has been specified, find and use it.
10707 auto itr = resource_bindings.find({ execution.model, var_desc_set, var_binding });
10708
10709 // Atomic helper buffers for image atomics need to use secondary bindings as well.
10710 bool use_secondary_binding = (var_type.basetype == SPIRType::SampledImage && basetype == SPIRType::Sampler) ||
10711 basetype == SPIRType::AtomicCounter;
10712
10713 auto resource_decoration =
10714 use_secondary_binding ? SPIRVCrossDecorationResourceIndexSecondary : SPIRVCrossDecorationResourceIndexPrimary;
10715
10716 if (plane == 1)
10717 resource_decoration = SPIRVCrossDecorationResourceIndexTertiary;
10718 if (plane == 2)
10719 resource_decoration = SPIRVCrossDecorationResourceIndexQuaternary;
10720
10721 if (itr != end(resource_bindings))
10722 {
10723 auto &remap = itr->second;
10724 remap.second = true;
10725 switch (basetype)
10726 {
10727 case SPIRType::Image:
10728 set_extended_decoration(var.self, resource_decoration, remap.first.msl_texture + plane);
10729 return remap.first.msl_texture + plane;
10730 case SPIRType::Sampler:
10731 set_extended_decoration(var.self, resource_decoration, remap.first.msl_sampler);
10732 return remap.first.msl_sampler;
10733 default:
10734 set_extended_decoration(var.self, resource_decoration, remap.first.msl_buffer);
10735 return remap.first.msl_buffer;
10736 }
10737 }
10738
10739 // If we have already allocated an index, keep using it.
10740 if (has_extended_decoration(var.self, resource_decoration))
10741 return get_extended_decoration(var.self, resource_decoration);
10742
10743 // Allow user to enable decoration binding
10744 if (msl_options.enable_decoration_binding)
10745 {
10746 // If there is no explicit mapping of bindings to MSL, use the declared binding.
10747 if (has_decoration(var.self, DecorationBinding))
10748 {
10749 var_binding = get_decoration(var.self, DecorationBinding);
10750 // Avoid emitting sentinel bindings.
10751 if (var_binding < 0x80000000u)
10752 return var_binding;
10753 }
10754 }
10755
10756 // If we did not explicitly remap, allocate bindings on demand.
10757 // We cannot reliably use Binding decorations since SPIR-V and MSL's binding models are very different.
10758
10759 bool allocate_argument_buffer_ids = false;
10760
10761 if (var.storage != StorageClassPushConstant)
10762 allocate_argument_buffer_ids = descriptor_set_is_argument_buffer(var_desc_set);
10763
10764 uint32_t binding_stride = 1;
10765 auto &type = get<SPIRType>(var.basetype);
10766 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
10767 binding_stride *= to_array_size_literal(type, i);
10768
10769 assert(binding_stride != 0);
10770
10771 // If a binding has not been specified, revert to incrementing resource indices.
10772 uint32_t resource_index;
10773
10774 if (type_is_msl_framebuffer_fetch(type))
10775 {
10776 // Frame-buffer fetch gets its fallback resource index from the input attachment index,
10777 // which is then treated as color index.
10778 resource_index = get_decoration(var.self, DecorationInputAttachmentIndex);
10779 }
10780 else if (allocate_argument_buffer_ids)
10781 {
10782 // Allocate from a flat ID binding space.
10783 resource_index = next_metal_resource_ids[var_desc_set];
10784 next_metal_resource_ids[var_desc_set] += binding_stride;
10785 }
10786 else
10787 {
10788 // Allocate from plain bindings which are allocated per resource type.
10789 switch (basetype)
10790 {
10791 case SPIRType::Image:
10792 resource_index = next_metal_resource_index_texture;
10793 next_metal_resource_index_texture += binding_stride;
10794 break;
10795 case SPIRType::Sampler:
10796 resource_index = next_metal_resource_index_sampler;
10797 next_metal_resource_index_sampler += binding_stride;
10798 break;
10799 default:
10800 resource_index = next_metal_resource_index_buffer;
10801 next_metal_resource_index_buffer += binding_stride;
10802 break;
10803 }
10804 }
10805
10806 set_extended_decoration(var.self, resource_decoration, resource_index);
10807 return resource_index;
10808 }
10809
type_is_msl_framebuffer_fetch(const SPIRType & type) const10810 bool CompilerMSL::type_is_msl_framebuffer_fetch(const SPIRType &type) const
10811 {
10812 return type.basetype == SPIRType::Image && type.image.dim == DimSubpassData && msl_options.is_ios() &&
10813 msl_options.ios_use_framebuffer_fetch_subpasses;
10814 }
10815
argument_decl(const SPIRFunction::Parameter & arg)10816 string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
10817 {
10818 auto &var = get<SPIRVariable>(arg.id);
10819 auto &type = get_variable_data_type(var);
10820 auto &var_type = get<SPIRType>(arg.type);
10821 StorageClass storage = var_type.storage;
10822 bool is_pointer = var_type.pointer;
10823
10824 // If we need to modify the name of the variable, make sure we use the original variable.
10825 // Our alias is just a shadow variable.
10826 uint32_t name_id = var.self;
10827 if (arg.alias_global_variable && var.basevariable)
10828 name_id = var.basevariable;
10829
10830 bool constref = !arg.alias_global_variable && is_pointer && arg.write_count == 0;
10831 // Framebuffer fetch is plain value, const looks out of place, but it is not wrong.
10832 if (type_is_msl_framebuffer_fetch(type))
10833 constref = false;
10834
10835 bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage ||
10836 type.basetype == SPIRType::Sampler;
10837
10838 // Arrays of images/samplers in MSL are always const.
10839 if (!type.array.empty() && type_is_image)
10840 constref = true;
10841
10842 string decl;
10843 if (constref)
10844 decl += "const ";
10845
10846 // If this is a combined image-sampler for a 2D image with floating-point type,
10847 // we emitted the 'spvDynamicImageSampler' type, and this is *not* an alias parameter
10848 // for a global, then we need to emit a "dynamic" combined image-sampler.
10849 // Unfortunately, this is necessary to properly support passing around
10850 // combined image-samplers with Y'CbCr conversions on them.
10851 bool is_dynamic_img_sampler = !arg.alias_global_variable && type.basetype == SPIRType::SampledImage &&
10852 type.image.dim == Dim2D && type_is_floating_point(get<SPIRType>(type.image.type)) &&
10853 spv_function_implementations.count(SPVFuncImplDynamicImageSampler);
10854
10855 // Allow Metal to use the array<T> template to make arrays a value type
10856 string address_space = get_argument_address_space(var);
10857 bool builtin = is_builtin_variable(var);
10858 is_using_builtin_array = builtin;
10859 if (address_space == "threadgroup")
10860 is_using_builtin_array = true;
10861
10862 if (var.basevariable && (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id))
10863 decl += type_to_glsl(type, arg.id);
10864 else if (builtin)
10865 decl += builtin_type_decl(static_cast<BuiltIn>(get_decoration(arg.id, DecorationBuiltIn)), arg.id);
10866 else if ((storage == StorageClassUniform || storage == StorageClassStorageBuffer) && is_array(type))
10867 {
10868 is_using_builtin_array = true;
10869 decl += join(type_to_glsl(type, arg.id), "*");
10870 }
10871 else if (is_dynamic_img_sampler)
10872 {
10873 decl += join("spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">");
10874 // Mark the variable so that we can handle passing it to another function.
10875 set_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
10876 }
10877 else
10878 decl += type_to_glsl(type, arg.id);
10879
10880 bool opaque_handle = storage == StorageClassUniformConstant;
10881
10882 if (!builtin && !opaque_handle && !is_pointer &&
10883 (storage == StorageClassFunction || storage == StorageClassGeneric))
10884 {
10885 // If the argument is a pure value and not an opaque type, we will pass by value.
10886 if (msl_options.force_native_arrays && is_array(type))
10887 {
10888 // We are receiving an array by value. This is problematic.
10889 // We cannot be sure of the target address space since we are supposed to receive a copy,
10890 // but this is not possible with MSL without some extra work.
10891 // We will have to assume we're getting a reference in thread address space.
10892 // If we happen to get a reference in constant address space, the caller must emit a copy and pass that.
10893 // Thread const therefore becomes the only logical choice, since we cannot "create" a constant array from
10894 // non-constant arrays, but we can create thread const from constant.
10895 decl = string("thread const ") + decl;
10896 decl += " (&";
10897 const char *restrict_kw = to_restrict(name_id);
10898 if (*restrict_kw)
10899 {
10900 decl += " ";
10901 decl += restrict_kw;
10902 }
10903 decl += to_expression(name_id);
10904 decl += ")";
10905 decl += type_to_array_glsl(type);
10906 }
10907 else
10908 {
10909 if (!address_space.empty())
10910 decl = join(address_space, " ", decl);
10911 decl += " ";
10912 decl += to_expression(name_id);
10913 }
10914 }
10915 else if (is_array(type) && !type_is_image)
10916 {
10917 // Arrays of images and samplers are special cased.
10918 if (!address_space.empty())
10919 decl = join(address_space, " ", decl);
10920
10921 if (msl_options.argument_buffers)
10922 {
10923 uint32_t desc_set = get_decoration(name_id, DecorationDescriptorSet);
10924 if ((storage == StorageClassUniform || storage == StorageClassStorageBuffer) &&
10925 descriptor_set_is_argument_buffer(desc_set))
10926 {
10927 // An awkward case where we need to emit *more* address space declarations (yay!).
10928 // An example is where we pass down an array of buffer pointers to leaf functions.
10929 // It's a constant array containing pointers to constants.
10930 // The pointer array is always constant however. E.g.
10931 // device SSBO * constant (&array)[N].
10932 // const device SSBO * constant (&array)[N].
10933 // constant SSBO * constant (&array)[N].
10934 // However, this only matters for argument buffers, since for MSL 1.0 style codegen,
10935 // we emit the buffer array on stack instead, and that seems to work just fine apparently.
10936
10937 // If the argument was marked as being in device address space, any pointer to member would
10938 // be const device, not constant.
10939 if (argument_buffer_device_storage_mask & (1u << desc_set))
10940 decl += " const device";
10941 else
10942 decl += " constant";
10943 }
10944 }
10945
10946 decl += " (&";
10947 const char *restrict_kw = to_restrict(name_id);
10948 if (*restrict_kw)
10949 {
10950 decl += " ";
10951 decl += restrict_kw;
10952 }
10953 decl += to_expression(name_id);
10954 decl += ")";
10955 decl += type_to_array_glsl(type);
10956 }
10957 else if (!opaque_handle)
10958 {
10959 // If this is going to be a reference to a variable pointer, the address space
10960 // for the reference has to go before the '&', but after the '*'.
10961 if (!address_space.empty())
10962 {
10963 if (decl.back() == '*')
10964 decl += join(" ", address_space, " ");
10965 else
10966 decl = join(address_space, " ", decl);
10967 }
10968 decl += "&";
10969 decl += " ";
10970 decl += to_restrict(name_id);
10971 decl += to_expression(name_id);
10972 }
10973 else
10974 {
10975 if (!address_space.empty())
10976 decl = join(address_space, " ", decl);
10977 decl += " ";
10978 decl += to_expression(name_id);
10979 }
10980
10981 // Emulate texture2D atomic operations
10982 auto *backing_var = maybe_get_backing_variable(name_id);
10983 if (backing_var && atomic_image_vars.count(backing_var->self))
10984 {
10985 decl += ", device atomic_" + type_to_glsl(get<SPIRType>(var_type.image.type), 0);
10986 decl += "* " + to_expression(name_id) + "_atomic";
10987 }
10988
10989 is_using_builtin_array = false;
10990
10991 return decl;
10992 }
10993
10994 // If we're currently in the entry point function, and the object
10995 // has a qualified name, use it, otherwise use the standard name.
to_name(uint32_t id,bool allow_alias) const10996 string CompilerMSL::to_name(uint32_t id, bool allow_alias) const
10997 {
10998 if (current_function && (current_function->self == ir.default_entry_point))
10999 {
11000 auto *m = ir.find_meta(id);
11001 if (m && !m->decoration.qualified_alias.empty())
11002 return m->decoration.qualified_alias;
11003 }
11004 return Compiler::to_name(id, allow_alias);
11005 }
11006
11007 // Returns a name that combines the name of the struct with the name of the member, except for Builtins
to_qualified_member_name(const SPIRType & type,uint32_t index)11008 string CompilerMSL::to_qualified_member_name(const SPIRType &type, uint32_t index)
11009 {
11010 // Don't qualify Builtin names because they are unique and are treated as such when building expressions
11011 BuiltIn builtin = BuiltInMax;
11012 if (is_member_builtin(type, index, &builtin))
11013 return builtin_to_glsl(builtin, type.storage);
11014
11015 // Strip any underscore prefix from member name
11016 string mbr_name = to_member_name(type, index);
11017 size_t startPos = mbr_name.find_first_not_of("_");
11018 mbr_name = (startPos != string::npos) ? mbr_name.substr(startPos) : "";
11019 return join(to_name(type.self), "_", mbr_name);
11020 }
11021
11022 // Ensures that the specified name is permanently usable by prepending a prefix
11023 // if the first chars are _ and a digit, which indicate a transient name.
ensure_valid_name(string name,string pfx)11024 string CompilerMSL::ensure_valid_name(string name, string pfx)
11025 {
11026 return (name.size() >= 2 && name[0] == '_' && isdigit(name[1])) ? (pfx + name) : name;
11027 }
11028
11029 // Replace all names that match MSL keywords or Metal Standard Library functions.
replace_illegal_names()11030 void CompilerMSL::replace_illegal_names()
11031 {
11032 // FIXME: MSL and GLSL are doing two different things here.
11033 // Agree on convention and remove this override.
11034 static const unordered_set<string> keywords = {
11035 "kernel",
11036 "vertex",
11037 "fragment",
11038 "compute",
11039 "bias",
11040 "assert",
11041 "VARIABLE_TRACEPOINT",
11042 "STATIC_DATA_TRACEPOINT",
11043 "STATIC_DATA_TRACEPOINT_V",
11044 "METAL_ALIGN",
11045 "METAL_ASM",
11046 "METAL_CONST",
11047 "METAL_DEPRECATED",
11048 "METAL_ENABLE_IF",
11049 "METAL_FUNC",
11050 "METAL_INTERNAL",
11051 "METAL_NON_NULL_RETURN",
11052 "METAL_NORETURN",
11053 "METAL_NOTHROW",
11054 "METAL_PURE",
11055 "METAL_UNAVAILABLE",
11056 "METAL_IMPLICIT",
11057 "METAL_EXPLICIT",
11058 "METAL_CONST_ARG",
11059 "METAL_ARG_UNIFORM",
11060 "METAL_ZERO_ARG",
11061 "METAL_VALID_LOD_ARG",
11062 "METAL_VALID_LEVEL_ARG",
11063 "METAL_VALID_STORE_ORDER",
11064 "METAL_VALID_LOAD_ORDER",
11065 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
11066 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
11067 "METAL_VALID_RENDER_TARGET",
11068 "is_function_constant_defined",
11069 "CHAR_BIT",
11070 "SCHAR_MAX",
11071 "SCHAR_MIN",
11072 "UCHAR_MAX",
11073 "CHAR_MAX",
11074 "CHAR_MIN",
11075 "USHRT_MAX",
11076 "SHRT_MAX",
11077 "SHRT_MIN",
11078 "UINT_MAX",
11079 "INT_MAX",
11080 "INT_MIN",
11081 "FLT_DIG",
11082 "FLT_MANT_DIG",
11083 "FLT_MAX_10_EXP",
11084 "FLT_MAX_EXP",
11085 "FLT_MIN_10_EXP",
11086 "FLT_MIN_EXP",
11087 "FLT_RADIX",
11088 "FLT_MAX",
11089 "FLT_MIN",
11090 "FLT_EPSILON",
11091 "FP_ILOGB0",
11092 "FP_ILOGBNAN",
11093 "MAXFLOAT",
11094 "HUGE_VALF",
11095 "INFINITY",
11096 "NAN",
11097 "M_E_F",
11098 "M_LOG2E_F",
11099 "M_LOG10E_F",
11100 "M_LN2_F",
11101 "M_LN10_F",
11102 "M_PI_F",
11103 "M_PI_2_F",
11104 "M_PI_4_F",
11105 "M_1_PI_F",
11106 "M_2_PI_F",
11107 "M_2_SQRTPI_F",
11108 "M_SQRT2_F",
11109 "M_SQRT1_2_F",
11110 "HALF_DIG",
11111 "HALF_MANT_DIG",
11112 "HALF_MAX_10_EXP",
11113 "HALF_MAX_EXP",
11114 "HALF_MIN_10_EXP",
11115 "HALF_MIN_EXP",
11116 "HALF_RADIX",
11117 "HALF_MAX",
11118 "HALF_MIN",
11119 "HALF_EPSILON",
11120 "MAXHALF",
11121 "HUGE_VALH",
11122 "M_E_H",
11123 "M_LOG2E_H",
11124 "M_LOG10E_H",
11125 "M_LN2_H",
11126 "M_LN10_H",
11127 "M_PI_H",
11128 "M_PI_2_H",
11129 "M_PI_4_H",
11130 "M_1_PI_H",
11131 "M_2_PI_H",
11132 "M_2_SQRTPI_H",
11133 "M_SQRT2_H",
11134 "M_SQRT1_2_H",
11135 "DBL_DIG",
11136 "DBL_MANT_DIG",
11137 "DBL_MAX_10_EXP",
11138 "DBL_MAX_EXP",
11139 "DBL_MIN_10_EXP",
11140 "DBL_MIN_EXP",
11141 "DBL_RADIX",
11142 "DBL_MAX",
11143 "DBL_MIN",
11144 "DBL_EPSILON",
11145 "HUGE_VAL",
11146 "M_E",
11147 "M_LOG2E",
11148 "M_LOG10E",
11149 "M_LN2",
11150 "M_LN10",
11151 "M_PI",
11152 "M_PI_2",
11153 "M_PI_4",
11154 "M_1_PI",
11155 "M_2_PI",
11156 "M_2_SQRTPI",
11157 "M_SQRT2",
11158 "M_SQRT1_2",
11159 "quad_broadcast",
11160 };
11161
11162 static const unordered_set<string> illegal_func_names = {
11163 "main",
11164 "saturate",
11165 "assert",
11166 "VARIABLE_TRACEPOINT",
11167 "STATIC_DATA_TRACEPOINT",
11168 "STATIC_DATA_TRACEPOINT_V",
11169 "METAL_ALIGN",
11170 "METAL_ASM",
11171 "METAL_CONST",
11172 "METAL_DEPRECATED",
11173 "METAL_ENABLE_IF",
11174 "METAL_FUNC",
11175 "METAL_INTERNAL",
11176 "METAL_NON_NULL_RETURN",
11177 "METAL_NORETURN",
11178 "METAL_NOTHROW",
11179 "METAL_PURE",
11180 "METAL_UNAVAILABLE",
11181 "METAL_IMPLICIT",
11182 "METAL_EXPLICIT",
11183 "METAL_CONST_ARG",
11184 "METAL_ARG_UNIFORM",
11185 "METAL_ZERO_ARG",
11186 "METAL_VALID_LOD_ARG",
11187 "METAL_VALID_LEVEL_ARG",
11188 "METAL_VALID_STORE_ORDER",
11189 "METAL_VALID_LOAD_ORDER",
11190 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
11191 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
11192 "METAL_VALID_RENDER_TARGET",
11193 "is_function_constant_defined",
11194 "CHAR_BIT",
11195 "SCHAR_MAX",
11196 "SCHAR_MIN",
11197 "UCHAR_MAX",
11198 "CHAR_MAX",
11199 "CHAR_MIN",
11200 "USHRT_MAX",
11201 "SHRT_MAX",
11202 "SHRT_MIN",
11203 "UINT_MAX",
11204 "INT_MAX",
11205 "INT_MIN",
11206 "FLT_DIG",
11207 "FLT_MANT_DIG",
11208 "FLT_MAX_10_EXP",
11209 "FLT_MAX_EXP",
11210 "FLT_MIN_10_EXP",
11211 "FLT_MIN_EXP",
11212 "FLT_RADIX",
11213 "FLT_MAX",
11214 "FLT_MIN",
11215 "FLT_EPSILON",
11216 "FP_ILOGB0",
11217 "FP_ILOGBNAN",
11218 "MAXFLOAT",
11219 "HUGE_VALF",
11220 "INFINITY",
11221 "NAN",
11222 "M_E_F",
11223 "M_LOG2E_F",
11224 "M_LOG10E_F",
11225 "M_LN2_F",
11226 "M_LN10_F",
11227 "M_PI_F",
11228 "M_PI_2_F",
11229 "M_PI_4_F",
11230 "M_1_PI_F",
11231 "M_2_PI_F",
11232 "M_2_SQRTPI_F",
11233 "M_SQRT2_F",
11234 "M_SQRT1_2_F",
11235 "HALF_DIG",
11236 "HALF_MANT_DIG",
11237 "HALF_MAX_10_EXP",
11238 "HALF_MAX_EXP",
11239 "HALF_MIN_10_EXP",
11240 "HALF_MIN_EXP",
11241 "HALF_RADIX",
11242 "HALF_MAX",
11243 "HALF_MIN",
11244 "HALF_EPSILON",
11245 "MAXHALF",
11246 "HUGE_VALH",
11247 "M_E_H",
11248 "M_LOG2E_H",
11249 "M_LOG10E_H",
11250 "M_LN2_H",
11251 "M_LN10_H",
11252 "M_PI_H",
11253 "M_PI_2_H",
11254 "M_PI_4_H",
11255 "M_1_PI_H",
11256 "M_2_PI_H",
11257 "M_2_SQRTPI_H",
11258 "M_SQRT2_H",
11259 "M_SQRT1_2_H",
11260 "DBL_DIG",
11261 "DBL_MANT_DIG",
11262 "DBL_MAX_10_EXP",
11263 "DBL_MAX_EXP",
11264 "DBL_MIN_10_EXP",
11265 "DBL_MIN_EXP",
11266 "DBL_RADIX",
11267 "DBL_MAX",
11268 "DBL_MIN",
11269 "DBL_EPSILON",
11270 "HUGE_VAL",
11271 "M_E",
11272 "M_LOG2E",
11273 "M_LOG10E",
11274 "M_LN2",
11275 "M_LN10",
11276 "M_PI",
11277 "M_PI_2",
11278 "M_PI_4",
11279 "M_1_PI",
11280 "M_2_PI",
11281 "M_2_SQRTPI",
11282 "M_SQRT2",
11283 "M_SQRT1_2",
11284 };
11285
11286 ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &) {
11287 auto *meta = ir.find_meta(self);
11288 if (!meta)
11289 return;
11290
11291 auto &dec = meta->decoration;
11292 if (keywords.find(dec.alias) != end(keywords))
11293 dec.alias += "0";
11294 });
11295
11296 ir.for_each_typed_id<SPIRFunction>([&](uint32_t self, SPIRFunction &) {
11297 auto *meta = ir.find_meta(self);
11298 if (!meta)
11299 return;
11300
11301 auto &dec = meta->decoration;
11302 if (illegal_func_names.find(dec.alias) != end(illegal_func_names))
11303 dec.alias += "0";
11304 });
11305
11306 ir.for_each_typed_id<SPIRType>([&](uint32_t self, SPIRType &) {
11307 auto *meta = ir.find_meta(self);
11308 if (!meta)
11309 return;
11310
11311 for (auto &mbr_dec : meta->members)
11312 if (keywords.find(mbr_dec.alias) != end(keywords))
11313 mbr_dec.alias += "0";
11314 });
11315
11316 for (auto &entry : ir.entry_points)
11317 {
11318 // Change both the entry point name and the alias, to keep them synced.
11319 string &ep_name = entry.second.name;
11320 if (illegal_func_names.find(ep_name) != end(illegal_func_names))
11321 ep_name += "0";
11322
11323 // Always write this because entry point might have been renamed earlier.
11324 ir.meta[entry.first].decoration.alias = ep_name;
11325 }
11326
11327 CompilerGLSL::replace_illegal_names();
11328 }
11329
to_member_reference(uint32_t base,const SPIRType & type,uint32_t index,bool ptr_chain)11330 string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain)
11331 {
11332 if (index < uint32_t(type.member_type_index_redirection.size()))
11333 index = type.member_type_index_redirection[index];
11334
11335 auto *var = maybe_get<SPIRVariable>(base);
11336 // If this is a buffer array, we have to dereference the buffer pointers.
11337 // Otherwise, if this is a pointer expression, dereference it.
11338
11339 bool declared_as_pointer = false;
11340
11341 if (var)
11342 {
11343 // Only allow -> dereference for block types. This is so we get expressions like
11344 // buffer[i]->first_member.second_member, rather than buffer[i]->first->second.
11345 bool is_block = has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
11346
11347 bool is_buffer_variable =
11348 is_block && (var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer);
11349 declared_as_pointer = is_buffer_variable && is_array(get<SPIRType>(var->basetype));
11350 }
11351
11352 if (declared_as_pointer || (!ptr_chain && should_dereference(base)))
11353 return join("->", to_member_name(type, index));
11354 else
11355 return join(".", to_member_name(type, index));
11356 }
11357
to_qualifiers_glsl(uint32_t id)11358 string CompilerMSL::to_qualifiers_glsl(uint32_t id)
11359 {
11360 string quals;
11361
11362 auto &type = expression_type(id);
11363 if (type.storage == StorageClassWorkgroup)
11364 quals += "threadgroup ";
11365
11366 return quals;
11367 }
11368
11369 // The optional id parameter indicates the object whose type we are trying
11370 // to find the description for. It is optional. Most type descriptions do not
11371 // depend on a specific object's use of that type.
type_to_glsl(const SPIRType & type,uint32_t id)11372 string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
11373 {
11374 string type_name;
11375
11376 // Pointer?
11377 if (type.pointer)
11378 {
11379 const char *restrict_kw;
11380 type_name = join(get_type_address_space(type, id), " ", type_to_glsl(get<SPIRType>(type.parent_type), id));
11381
11382 switch (type.basetype)
11383 {
11384 case SPIRType::Image:
11385 case SPIRType::SampledImage:
11386 case SPIRType::Sampler:
11387 // These are handles.
11388 break;
11389 default:
11390 // Anything else can be a raw pointer.
11391 type_name += "*";
11392 restrict_kw = to_restrict(id);
11393 if (*restrict_kw)
11394 {
11395 type_name += " ";
11396 type_name += restrict_kw;
11397 }
11398 break;
11399 }
11400 return type_name;
11401 }
11402
11403 switch (type.basetype)
11404 {
11405 case SPIRType::Struct:
11406 // Need OpName lookup here to get a "sensible" name for a struct.
11407 // Allow Metal to use the array<T> template to make arrays a value type
11408 type_name = to_name(type.self);
11409 break;
11410
11411 case SPIRType::Image:
11412 case SPIRType::SampledImage:
11413 return image_type_glsl(type, id);
11414
11415 case SPIRType::Sampler:
11416 return sampler_type(type);
11417
11418 case SPIRType::Void:
11419 return "void";
11420
11421 case SPIRType::AtomicCounter:
11422 return "atomic_uint";
11423
11424 case SPIRType::ControlPointArray:
11425 return join("patch_control_point<", type_to_glsl(get<SPIRType>(type.parent_type), id), ">");
11426
11427 // Scalars
11428 case SPIRType::Boolean:
11429 type_name = "bool";
11430 break;
11431 case SPIRType::Char:
11432 case SPIRType::SByte:
11433 type_name = "char";
11434 break;
11435 case SPIRType::UByte:
11436 type_name = "uchar";
11437 break;
11438 case SPIRType::Short:
11439 type_name = "short";
11440 break;
11441 case SPIRType::UShort:
11442 type_name = "ushort";
11443 break;
11444 case SPIRType::Int:
11445 type_name = "int";
11446 break;
11447 case SPIRType::UInt:
11448 type_name = "uint";
11449 break;
11450 case SPIRType::Int64:
11451 if (!msl_options.supports_msl_version(2, 2))
11452 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
11453 type_name = "long";
11454 break;
11455 case SPIRType::UInt64:
11456 if (!msl_options.supports_msl_version(2, 2))
11457 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
11458 type_name = "ulong";
11459 break;
11460 case SPIRType::Half:
11461 type_name = "half";
11462 break;
11463 case SPIRType::Float:
11464 type_name = "float";
11465 break;
11466 case SPIRType::Double:
11467 type_name = "double"; // Currently unsupported
11468 break;
11469
11470 default:
11471 return "unknown_type";
11472 }
11473
11474 // Matrix?
11475 if (type.columns > 1)
11476 type_name += to_string(type.columns) + "x";
11477
11478 // Vector or Matrix?
11479 if (type.vecsize > 1)
11480 type_name += to_string(type.vecsize);
11481
11482 if (type.array.empty() || using_builtin_array())
11483 {
11484 return type_name;
11485 }
11486 else
11487 {
11488 // Allow Metal to use the array<T> template to make arrays a value type
11489 add_spv_func_and_recompile(SPVFuncImplUnsafeArray);
11490 string res;
11491 string sizes;
11492
11493 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
11494 {
11495 res += "spvUnsafeArray<";
11496 sizes += ", ";
11497 sizes += to_array_size(type, i);
11498 sizes += ">";
11499 }
11500
11501 res += type_name + sizes;
11502 return res;
11503 }
11504 }
11505
type_to_array_glsl(const SPIRType & type)11506 string CompilerMSL::type_to_array_glsl(const SPIRType &type)
11507 {
11508 // Allow Metal to use the array<T> template to make arrays a value type
11509 switch (type.basetype)
11510 {
11511 case SPIRType::AtomicCounter:
11512 case SPIRType::ControlPointArray:
11513 {
11514 return CompilerGLSL::type_to_array_glsl(type);
11515 }
11516 default:
11517 {
11518 if (using_builtin_array())
11519 return CompilerGLSL::type_to_array_glsl(type);
11520 else
11521 return "";
11522 }
11523 }
11524 }
11525
11526 // Threadgroup arrays can't have a wrapper type
variable_decl(const SPIRVariable & variable)11527 std::string CompilerMSL::variable_decl(const SPIRVariable &variable)
11528 {
11529 if (variable.storage == StorageClassWorkgroup)
11530 {
11531 is_using_builtin_array = true;
11532 }
11533 std::string expr = CompilerGLSL::variable_decl(variable);
11534 if (variable.storage == StorageClassWorkgroup)
11535 {
11536 is_using_builtin_array = false;
11537 }
11538 return expr;
11539 }
11540
// GCC workaround of lambdas calling protected funcs.
// Forwards all three arguments unchanged to CompilerGLSL::variable_decl and
// returns its result; no MSL-specific handling is added here.
std::string CompilerMSL::variable_decl(const SPIRType &type, const std::string &name, uint32_t id)
{
	return CompilerGLSL::variable_decl(type, name, id);
}
11546
sampler_type(const SPIRType & type)11547 std::string CompilerMSL::sampler_type(const SPIRType &type)
11548 {
11549 if (!type.array.empty())
11550 {
11551 if (!msl_options.supports_msl_version(2))
11552 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of samplers.");
11553
11554 if (type.array.size() > 1)
11555 SPIRV_CROSS_THROW("Arrays of arrays of samplers are not supported in MSL.");
11556
11557 // Arrays of samplers in MSL must be declared with a special array<T, N> syntax ala C++11 std::array.
11558 uint32_t array_size = to_array_size_literal(type);
11559 if (array_size == 0)
11560 SPIRV_CROSS_THROW("Unsized array of samplers is not supported in MSL.");
11561
11562 auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
11563 return join("array<", sampler_type(parent), ", ", array_size, ">");
11564 }
11565 else
11566 return "sampler";
11567 }
11568
11569 // Returns an MSL string describing the SPIR-V image type
image_type_glsl(const SPIRType & type,uint32_t id)11570 string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id)
11571 {
11572 auto *var = maybe_get<SPIRVariable>(id);
11573 if (var && var->basevariable)
11574 {
11575 // For comparison images, check against the base variable,
11576 // and not the fake ID which might have been generated for this variable.
11577 id = var->basevariable;
11578 }
11579
11580 if (!type.array.empty())
11581 {
11582 uint32_t major = 2, minor = 0;
11583 if (msl_options.is_ios())
11584 {
11585 major = 1;
11586 minor = 2;
11587 }
11588 if (!msl_options.supports_msl_version(major, minor))
11589 {
11590 if (msl_options.is_ios())
11591 SPIRV_CROSS_THROW("MSL 1.2 or greater is required for arrays of textures.");
11592 else
11593 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of textures.");
11594 }
11595
11596 if (type.array.size() > 1)
11597 SPIRV_CROSS_THROW("Arrays of arrays of textures are not supported in MSL.");
11598
11599 // Arrays of images in MSL must be declared with a special array<T, N> syntax ala C++11 std::array.
11600 uint32_t array_size = to_array_size_literal(type);
11601 if (array_size == 0)
11602 SPIRV_CROSS_THROW("Unsized array of images is not supported in MSL.");
11603
11604 auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
11605 return join("array<", image_type_glsl(parent, id), ", ", array_size, ">");
11606 }
11607
11608 string img_type_name;
11609
11610 // Bypass pointers because we need the real image struct
11611 auto &img_type = get<SPIRType>(type.self).image;
11612 if (image_is_comparison(type, id))
11613 {
11614 switch (img_type.dim)
11615 {
11616 case Dim1D:
11617 case Dim2D:
11618 if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
11619 {
11620 // Use a native Metal 1D texture
11621 img_type_name += "depth1d_unsupported_by_metal";
11622 break;
11623 }
11624
11625 if (img_type.ms && img_type.arrayed)
11626 {
11627 if (!msl_options.supports_msl_version(2, 1))
11628 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
11629 img_type_name += "depth2d_ms_array";
11630 }
11631 else if (img_type.ms)
11632 img_type_name += "depth2d_ms";
11633 else if (img_type.arrayed)
11634 img_type_name += "depth2d_array";
11635 else
11636 img_type_name += "depth2d";
11637 break;
11638 case Dim3D:
11639 img_type_name += "depth3d_unsupported_by_metal";
11640 break;
11641 case DimCube:
11642 if (!msl_options.emulate_cube_array)
11643 img_type_name += (img_type.arrayed ? "depthcube_array" : "depthcube");
11644 else
11645 img_type_name += (img_type.arrayed ? "depth2d_array" : "depthcube");
11646 break;
11647 default:
11648 img_type_name += "unknown_depth_texture_type";
11649 break;
11650 }
11651 }
11652 else
11653 {
11654 switch (img_type.dim)
11655 {
11656 case DimBuffer:
11657 if (img_type.ms || img_type.arrayed)
11658 SPIRV_CROSS_THROW("Cannot use texel buffers with multisampling or array layers.");
11659
11660 if (msl_options.texture_buffer_native)
11661 {
11662 if (!msl_options.supports_msl_version(2, 1))
11663 SPIRV_CROSS_THROW("Native texture_buffer type is only supported in MSL 2.1.");
11664 img_type_name = "texture_buffer";
11665 }
11666 else
11667 img_type_name += "texture2d";
11668 break;
11669 case Dim1D:
11670 case Dim2D:
11671 case DimSubpassData:
11672 {
11673 bool subpass_array =
11674 img_type.dim == DimSubpassData && (msl_options.multiview || msl_options.arrayed_subpass_input);
11675 if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
11676 {
11677 // Use a native Metal 1D texture
11678 img_type_name += (img_type.arrayed ? "texture1d_array" : "texture1d");
11679 break;
11680 }
11681
11682 // Use Metal's native frame-buffer fetch API for subpass inputs.
11683 if (type_is_msl_framebuffer_fetch(type))
11684 {
11685 auto img_type_4 = get<SPIRType>(img_type.type);
11686 img_type_4.vecsize = 4;
11687 return type_to_glsl(img_type_4);
11688 }
11689 if (img_type.ms && (img_type.arrayed || subpass_array))
11690 {
11691 if (!msl_options.supports_msl_version(2, 1))
11692 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
11693 img_type_name += "texture2d_ms_array";
11694 }
11695 else if (img_type.ms)
11696 img_type_name += "texture2d_ms";
11697 else if (img_type.arrayed || subpass_array)
11698 img_type_name += "texture2d_array";
11699 else
11700 img_type_name += "texture2d";
11701 break;
11702 }
11703 case Dim3D:
11704 img_type_name += "texture3d";
11705 break;
11706 case DimCube:
11707 if (!msl_options.emulate_cube_array)
11708 img_type_name += (img_type.arrayed ? "texturecube_array" : "texturecube");
11709 else
11710 img_type_name += (img_type.arrayed ? "texture2d_array" : "texturecube");
11711 break;
11712 default:
11713 img_type_name += "unknown_texture_type";
11714 break;
11715 }
11716 }
11717
11718 // Append the pixel type
11719 img_type_name += "<";
11720 img_type_name += type_to_glsl(get<SPIRType>(img_type.type));
11721
11722 // For unsampled images, append the sample/read/write access qualifier.
11723 // For kernel images, the access qualifier my be supplied directly by SPIR-V.
11724 // Otherwise it may be set based on whether the image is read from or written to within the shader.
11725 if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData)
11726 {
11727 switch (img_type.access)
11728 {
11729 case AccessQualifierReadOnly:
11730 img_type_name += ", access::read";
11731 break;
11732
11733 case AccessQualifierWriteOnly:
11734 img_type_name += ", access::write";
11735 break;
11736
11737 case AccessQualifierReadWrite:
11738 img_type_name += ", access::read_write";
11739 break;
11740
11741 default:
11742 {
11743 auto *p_var = maybe_get_backing_variable(id);
11744 if (p_var && p_var->basevariable)
11745 p_var = maybe_get<SPIRVariable>(p_var->basevariable);
11746 if (p_var && !has_decoration(p_var->self, DecorationNonWritable))
11747 {
11748 img_type_name += ", access::";
11749
11750 if (!has_decoration(p_var->self, DecorationNonReadable))
11751 img_type_name += "read_";
11752
11753 img_type_name += "write";
11754 }
11755 break;
11756 }
11757 }
11758 }
11759
11760 img_type_name += ">";
11761
11762 return img_type_name;
11763 }
11764
// Emits the MSL expression for one SPIR-V OpGroupNonUniform* instruction.
// Operand layout as read below: ops[0] is the result type, ops[1] the result ID,
// ops[2] the scope (must evaluate to ScopeSubgroup); the remaining operands
// depend on the opcode.
// Throws when the target MSL version/platform cannot express the operation,
// when the scope is not Subgroup, or when the opcode or group operation is unknown.
void CompilerMSL::emit_subgroup_op(const Instruction &i)
{
	const uint32_t *ops = stream(i);
	auto op = static_cast<Op>(i.op);

	// Metal 2.0 is required. iOS only supports quad ops. macOS only supports
	// broadcast and shuffle on 10.13 (2.0), with full support in 10.14 (2.1).
	// Note that iOS makes no distinction between a quad-group and a subgroup;
	// all subgroups are quad-groups there.
	if (!msl_options.supports_msl_version(2))
		SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up.");

	// If we need to do implicit bitcasts, make sure we do it with the correct type.
	uint32_t integer_width = get_integer_width_for_instruction(i);
	auto int_type = to_signed_basetype(integer_width);
	auto uint_type = to_unsigned_basetype(integer_width);

	// Platform gate for iOS: anything not listed here is rejected.
	if (msl_options.is_ios())
	{
		switch (op)
		{
		default:
			SPIRV_CROSS_THROW("iOS only supports quad-group operations.");
		case OpGroupNonUniformBroadcast:
		case OpGroupNonUniformShuffle:
		case OpGroupNonUniformShuffleXor:
		case OpGroupNonUniformShuffleUp:
		case OpGroupNonUniformShuffleDown:
		case OpGroupNonUniformQuadSwap:
		case OpGroupNonUniformQuadBroadcast:
			break;
		}
	}

	// Platform gate for macOS below MSL 2.1: only broadcast and shuffles pass.
	if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
	{
		switch (op)
		{
		default:
			SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.1 and up.");
		case OpGroupNonUniformBroadcast:
		case OpGroupNonUniformShuffle:
		case OpGroupNonUniformShuffleXor:
		case OpGroupNonUniformShuffleUp:
		case OpGroupNonUniformShuffleDown:
			break;
		}
	}

	uint32_t result_type = ops[0];
	uint32_t id = ops[1];

	auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
	if (scope != ScopeSubgroup)
		SPIRV_CROSS_THROW("Only subgroup scope is supported.");

	switch (op)
	{
	case OpGroupNonUniformElect:
		emit_op(result_type, id, "simd_is_first()", true);
		break;

	// On iOS the quad_* functions are used in place of the simd_* ones.
	case OpGroupNonUniformBroadcast:
		emit_binary_func_op(result_type, id, ops[3], ops[4],
		                    msl_options.is_ios() ? "quad_broadcast" : "simd_broadcast");
		break;

	case OpGroupNonUniformBroadcastFirst:
		emit_unary_func_op(result_type, id, ops[3], "simd_broadcast_first");
		break;

	case OpGroupNonUniformBallot:
		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallot");
		break;

	// Inverse ballot is emitted as a bit extract at this invocation's own subgroup index.
	case OpGroupNonUniformInverseBallot:
		emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract");
		break;

	case OpGroupNonUniformBallotBitExtract:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBallotBitExtract");
		break;

	case OpGroupNonUniformBallotFindLSB:
		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindLSB");
		break;

	case OpGroupNonUniformBallotFindMSB:
		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindMSB");
		break;

	// For BallotBitCount, ops[3] is the group operation and ops[4] the ballot value.
	case OpGroupNonUniformBallotBitCount:
	{
		auto operation = static_cast<GroupOperation>(ops[3]);
		if (operation == GroupOperationReduce)
			emit_unary_func_op(result_type, id, ops[4], "spvSubgroupBallotBitCount");
		else if (operation == GroupOperationInclusiveScan)
			emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
			                    "spvSubgroupBallotInclusiveBitCount");
		else if (operation == GroupOperationExclusiveScan)
			emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
			                    "spvSubgroupBallotExclusiveBitCount");
		else
			SPIRV_CROSS_THROW("Invalid BitCount operation.");
		break;
	}

	case OpGroupNonUniformShuffle:
		emit_binary_func_op(result_type, id, ops[3], ops[4], msl_options.is_ios() ? "quad_shuffle" : "simd_shuffle");
		break;

	case OpGroupNonUniformShuffleXor:
		emit_binary_func_op(result_type, id, ops[3], ops[4],
		                    msl_options.is_ios() ? "quad_shuffle_xor" : "simd_shuffle_xor");
		break;

	case OpGroupNonUniformShuffleUp:
		emit_binary_func_op(result_type, id, ops[3], ops[4],
		                    msl_options.is_ios() ? "quad_shuffle_up" : "simd_shuffle_up");
		break;

	case OpGroupNonUniformShuffleDown:
		emit_binary_func_op(result_type, id, ops[3], ops[4],
		                    msl_options.is_ios() ? "quad_shuffle_down" : "simd_shuffle_down");
		break;

	case OpGroupNonUniformAll:
		emit_unary_func_op(result_type, id, ops[3], "simd_all");
		break;

	case OpGroupNonUniformAny:
		emit_unary_func_op(result_type, id, ops[3], "simd_any");
		break;

	case OpGroupNonUniformAllEqual:
		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupAllEqual");
		break;

		// The macros below each expand to one complete `case` of this switch.
		// In every expansion ops[3] is the group operation, ops[4] the value and,
		// for ClusteredReduce, ops[5] the cluster size.
		// This first variant supports Reduce, InclusiveScan, ExclusiveScan and quad ClusteredReduce.
		// clang-format off
#define MSL_GROUP_OP(op, msl_op) \
	case OpGroupNonUniform##op: \
	{ \
		auto operation = static_cast<GroupOperation>(ops[3]); \
		if (operation == GroupOperationReduce) \
			emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
		else if (operation == GroupOperationInclusiveScan) \
			emit_unary_func_op(result_type, id, ops[4], "simd_prefix_inclusive_" #msl_op); \
		else if (operation == GroupOperationExclusiveScan) \
			emit_unary_func_op(result_type, id, ops[4], "simd_prefix_exclusive_" #msl_op); \
		else if (operation == GroupOperationClusteredReduce) \
		{ \
			/* Only cluster sizes of 4 are supported. */ \
			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
			if (cluster_size != 4) \
				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
		} \
		else \
			SPIRV_CROSS_THROW("Invalid group operation."); \
		break; \
	}
	MSL_GROUP_OP(FAdd, sum)
	MSL_GROUP_OP(FMul, product)
	MSL_GROUP_OP(IAdd, sum)
	MSL_GROUP_OP(IMul, product)
#undef MSL_GROUP_OP
	// The others, unfortunately, don't support InclusiveScan or ExclusiveScan.

	// Redefinition of MSL_GROUP_OP: same shape, but both scan operations throw.
#define MSL_GROUP_OP(op, msl_op) \
	case OpGroupNonUniform##op: \
	{ \
		auto operation = static_cast<GroupOperation>(ops[3]); \
		if (operation == GroupOperationReduce) \
			emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
		else if (operation == GroupOperationInclusiveScan) \
			SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
		else if (operation == GroupOperationExclusiveScan) \
			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
		else if (operation == GroupOperationClusteredReduce) \
		{ \
			/* Only cluster sizes of 4 are supported. */ \
			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
			if (cluster_size != 4) \
				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
		} \
		else \
			SPIRV_CROSS_THROW("Invalid group operation."); \
		break; \
	}

	// Same as the scan-less MSL_GROUP_OP, but emits through emit_unary_func_op_cast
	// with `type` passed for both of its type arguments (int_type or uint_type below).
#define MSL_GROUP_OP_CAST(op, msl_op, type) \
	case OpGroupNonUniform##op: \
	{ \
		auto operation = static_cast<GroupOperation>(ops[3]); \
		if (operation == GroupOperationReduce) \
			emit_unary_func_op_cast(result_type, id, ops[4], "simd_" #msl_op, type, type); \
		else if (operation == GroupOperationInclusiveScan) \
			SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
		else if (operation == GroupOperationExclusiveScan) \
			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
		else if (operation == GroupOperationClusteredReduce) \
		{ \
			/* Only cluster sizes of 4 are supported. */ \
			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
			if (cluster_size != 4) \
				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
			emit_unary_func_op_cast(result_type, id, ops[4], "quad_" #msl_op, type, type); \
		} \
		else \
			SPIRV_CROSS_THROW("Invalid group operation."); \
		break; \
	}

	MSL_GROUP_OP(FMin, min)
	MSL_GROUP_OP(FMax, max)
	MSL_GROUP_OP_CAST(SMin, min, int_type)
	MSL_GROUP_OP_CAST(SMax, max, int_type)
	MSL_GROUP_OP_CAST(UMin, min, uint_type)
	MSL_GROUP_OP_CAST(UMax, max, uint_type)
	MSL_GROUP_OP(BitwiseAnd, and)
	MSL_GROUP_OP(BitwiseOr, or)
	MSL_GROUP_OP(BitwiseXor, xor)
	MSL_GROUP_OP(LogicalAnd, and)
	MSL_GROUP_OP(LogicalOr, or)
	MSL_GROUP_OP(LogicalXor, xor)
		// clang-format on
#undef MSL_GROUP_OP
#undef MSL_GROUP_OP_CAST

	case OpGroupNonUniformQuadSwap:
	{
		// We can implement this easily based on the following table giving
		// the target lane ID from the direction and current lane ID:
		//        Direction
		//      | 0 | 1 | 2 |
		//   ---+---+---+---+
		// L 0  | 1   2   3
		// a 1  | 0   3   2
		// n 2  | 3   0   1
		// e 3  | 2   1   0
		// Notice that target = source ^ (direction + 1).
		uint32_t mask = evaluate_constant_u32(ops[4]) + 1;
		// Materialize the mask as a new constant with the same type as the direction operand.
		uint32_t mask_id = ir.increase_bound_by(1);
		set<SPIRConstant>(mask_id, expression_type_id(ops[4]), mask, false);
		emit_binary_func_op(result_type, id, ops[3], mask_id, "quad_shuffle_xor");
		break;
	}

	case OpGroupNonUniformQuadBroadcast:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "quad_broadcast");
		break;

	default:
		SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
	}

	// Every successfully emitted subgroup result is registered as control dependent.
	register_control_dependent_expression(id);
}
12024
bitcast_glsl_op(const SPIRType & out_type,const SPIRType & in_type)12025 string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
12026 {
12027 if (out_type.basetype == in_type.basetype)
12028 return "";
12029
12030 assert(out_type.basetype != SPIRType::Boolean);
12031 assert(in_type.basetype != SPIRType::Boolean);
12032
12033 bool integral_cast = type_is_integral(out_type) && type_is_integral(in_type);
12034 bool same_size_cast = out_type.width == in_type.width;
12035
12036 if (integral_cast && same_size_cast)
12037 {
12038 // Trivial bitcast case, casts between integers.
12039 return type_to_glsl(out_type);
12040 }
12041 else
12042 {
12043 // Fall back to the catch-all bitcast in MSL.
12044 return "as_type<" + type_to_glsl(out_type) + ">";
12045 }
12046 }
12047
// Complex-bitcast hook for the MSL backend. All three arguments are ignored
// and the function always returns false.
// NOTE(review): false presumably tells the caller that nothing was emitted and
// the regular bitcast path should be used — confirm against CompilerGLSL.
bool CompilerMSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
{
	return false;
}
12052
// Returns an MSL string identifying the name of a SPIR-V builtin.
// Output builtins are qualified with the name of the stage out structure.
//
// builtin: the SPIR-V builtin being named.
// storage: the storage class the builtin is accessed through; it selects
//          between the stage-in, stage-out and unqualified spellings.
// Any builtin not handled by a returning branch below falls back to
// CompilerGLSL::builtin_to_glsl. Throws for builtins that the configured
// MSL version or platform cannot support.
string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
{
	switch (builtin)
	{

	// Handle HLSL-style 0-based vertex/instance index.
	// Override GLSL compiler strictness
	// The four index cases below share one pattern: when base-index-zero is
	// enabled and supported, the declaration itself uses the raw name (and
	// records that a base argument is needed unless that state is already No),
	// while every other use subtracts the corresponding base builtin.
	case BuiltInVertexId:
		ensure_builtin(StorageClassInput, BuiltInVertexId);
		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			if (builtin_declaration)
			{
				if (needs_base_vertex_arg != TriState::No)
					needs_base_vertex_arg = TriState::Yes;
				return "gl_VertexID";
			}
			else
			{
				ensure_builtin(StorageClassInput, BuiltInBaseVertex);
				return "(gl_VertexID - gl_BaseVertex)";
			}
		}
		else
		{
			return "gl_VertexID";
		}
	case BuiltInInstanceId:
		ensure_builtin(StorageClassInput, BuiltInInstanceId);
		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			if (builtin_declaration)
			{
				if (needs_base_instance_arg != TriState::No)
					needs_base_instance_arg = TriState::Yes;
				return "gl_InstanceID";
			}
			else
			{
				ensure_builtin(StorageClassInput, BuiltInBaseInstance);
				return "(gl_InstanceID - gl_BaseInstance)";
			}
		}
		else
		{
			return "gl_InstanceID";
		}
	case BuiltInVertexIndex:
		ensure_builtin(StorageClassInput, BuiltInVertexIndex);
		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			if (builtin_declaration)
			{
				if (needs_base_vertex_arg != TriState::No)
					needs_base_vertex_arg = TriState::Yes;
				return "gl_VertexIndex";
			}
			else
			{
				ensure_builtin(StorageClassInput, BuiltInBaseVertex);
				return "(gl_VertexIndex - gl_BaseVertex)";
			}
		}
		else
		{
			return "gl_VertexIndex";
		}
	case BuiltInInstanceIndex:
		ensure_builtin(StorageClassInput, BuiltInInstanceIndex);
		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			if (builtin_declaration)
			{
				if (needs_base_instance_arg != TriState::No)
					needs_base_instance_arg = TriState::Yes;
				return "gl_InstanceIndex";
			}
			else
			{
				ensure_builtin(StorageClassInput, BuiltInBaseInstance);
				return "(gl_InstanceIndex - gl_BaseInstance)";
			}
		}
		else
		{
			return "gl_InstanceIndex";
		}
	// Naming the base builtin directly sets the matching needs_base_*_arg state
	// to No, which the index cases above then leave untouched.
	// NOTE(review): presumably because the shader now declares the base builtin
	// itself, so no extra argument has to be synthesized — confirm at the use site.
	case BuiltInBaseVertex:
		if (msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			needs_base_vertex_arg = TriState::No;
			return "gl_BaseVertex";
		}
		else
		{
			SPIRV_CROSS_THROW("BaseVertex requires Metal 1.1 and Mac or Apple A9+ hardware.");
		}
	case BuiltInBaseInstance:
		if (msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			needs_base_instance_arg = TriState::No;
			return "gl_BaseInstance";
		}
		else
		{
			SPIRV_CROSS_THROW("BaseInstance requires Metal 1.1 and Mac or Apple A9+ hardware.");
		}
	case BuiltInDrawIndex:
		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");

	// When used in the entry function, output builtins are qualified with output struct name.
	// Test storage class as NOT Input, as output builtins might be part of generic type.
	// Also don't do this for tessellation control shaders.
	case BuiltInViewportIndex:
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
		/* fallthrough */
	case BuiltInFragDepth:
	case BuiltInFragStencilRefEXT:
		// FragDepth / FragStencilRef are only qualified when their option is
		// enabled; otherwise leave the switch and use the base-class name.
		if ((builtin == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
		    (builtin == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))
			break;
		/* fallthrough */
	case BuiltInPosition:
	case BuiltInPointSize:
	case BuiltInClipDistance:
	case BuiltInCullDistance:
	case BuiltInLayer:
	case BuiltInSampleMask:
		if (get_execution_model() == ExecutionModelTessellationControl)
			break;
		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
			return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);

		break;

	// Barycentric inputs read inside the entry function are qualified with the stage-in struct name.
	case BuiltInBaryCoordNV:
	case BuiltInBaryCoordNoPerspNV:
		if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
			return stage_in_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
		break;

	// Tessellation levels: in a tessellation evaluation entry function they are
	// read from the patch stage-in struct unless the entry point uses Triangles
	// mode or the storage is Output. In other stages, non-input accesses in the
	// entry function go to the tessellation factor buffer, indexed by primitive ID.
	case BuiltInTessLevelOuter:
		if (get_execution_model() == ExecutionModelTessellationEvaluation)
		{
			if (storage != StorageClassOutput && !get_entry_point().flags.get(ExecutionModeTriangles) &&
			    current_function && (current_function->self == ir.default_entry_point))
				return join(patch_stage_in_var_name, ".", CompilerGLSL::builtin_to_glsl(builtin, storage));
			else
				break;
		}
		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
			return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
			            "].edgeTessellationFactor");
		break;

	case BuiltInTessLevelInner:
		if (get_execution_model() == ExecutionModelTessellationEvaluation)
		{
			if (storage != StorageClassOutput && !get_entry_point().flags.get(ExecutionModeTriangles) &&
			    current_function && (current_function->self == ir.default_entry_point))
				return join(patch_stage_in_var_name, ".", CompilerGLSL::builtin_to_glsl(builtin, storage));
			else
				break;
		}
		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
			return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
			            "].insideTessellationFactor");
		break;

	default:
		break;
	}

	// Everything that broke out of the switch uses the base-class spelling.
	return CompilerGLSL::builtin_to_glsl(builtin, storage);
}
12237
// Returns an MSL string attribute qualifier for a SPIR-V builtin
builtin_qualifier(BuiltIn builtin)12239 string CompilerMSL::builtin_qualifier(BuiltIn builtin)
12240 {
12241 auto &execution = get_entry_point();
12242
12243 switch (builtin)
12244 {
12245 // Vertex function in
12246 case BuiltInVertexId:
12247 return "vertex_id";
12248 case BuiltInVertexIndex:
12249 return "vertex_id";
12250 case BuiltInBaseVertex:
12251 return "base_vertex";
12252 case BuiltInInstanceId:
12253 return "instance_id";
12254 case BuiltInInstanceIndex:
12255 return "instance_id";
12256 case BuiltInBaseInstance:
12257 return "base_instance";
12258 case BuiltInDrawIndex:
12259 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
12260
12261 // Vertex function out
12262 case BuiltInClipDistance:
12263 return "clip_distance";
12264 case BuiltInPointSize:
12265 return "point_size";
12266 case BuiltInPosition:
12267 if (position_invariant)
12268 {
12269 if (!msl_options.supports_msl_version(2, 1))
12270 SPIRV_CROSS_THROW("Invariant position is only supported on MSL 2.1 and up.");
12271 return "position, invariant";
12272 }
12273 else
12274 return "position";
12275 case BuiltInLayer:
12276 return "render_target_array_index";
12277 case BuiltInViewportIndex:
12278 if (!msl_options.supports_msl_version(2, 0))
12279 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
12280 return "viewport_array_index";
12281
12282 // Tess. control function in
12283 case BuiltInInvocationId:
12284 if (msl_options.multi_patch_workgroup)
12285 {
12286 // Shouldn't be reached.
12287 SPIRV_CROSS_THROW("InvocationId is computed manually with multi-patch workgroups in MSL.");
12288 }
12289 return "thread_index_in_threadgroup";
12290 case BuiltInPatchVertices:
12291 // Shouldn't be reached.
12292 SPIRV_CROSS_THROW("PatchVertices is derived from the auxiliary buffer in MSL.");
12293 case BuiltInPrimitiveId:
12294 switch (execution.model)
12295 {
12296 case ExecutionModelTessellationControl:
12297 if (msl_options.multi_patch_workgroup)
12298 {
12299 // Shouldn't be reached.
12300 SPIRV_CROSS_THROW("PrimitiveId is computed manually with multi-patch workgroups in MSL.");
12301 }
12302 return "threadgroup_position_in_grid";
12303 case ExecutionModelTessellationEvaluation:
12304 return "patch_id";
12305 case ExecutionModelFragment:
12306 if (msl_options.is_ios())
12307 SPIRV_CROSS_THROW("PrimitiveId is not supported in fragment on iOS.");
12308 else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 2))
12309 SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2.");
12310 return "primitive_id";
12311 default:
12312 SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model.");
12313 }
12314
12315 // Tess. control function out
12316 case BuiltInTessLevelOuter:
12317 case BuiltInTessLevelInner:
12318 // Shouldn't be reached.
12319 SPIRV_CROSS_THROW("Tessellation levels are handled specially in MSL.");
12320
12321 // Tess. evaluation function in
12322 case BuiltInTessCoord:
12323 return "position_in_patch";
12324
12325 // Fragment function in
12326 case BuiltInFrontFacing:
12327 return "front_facing";
12328 case BuiltInPointCoord:
12329 return "point_coord";
12330 case BuiltInFragCoord:
12331 return "position";
12332 case BuiltInSampleId:
12333 return "sample_id";
12334 case BuiltInSampleMask:
12335 return "sample_mask";
12336 case BuiltInSamplePosition:
12337 // Shouldn't be reached.
12338 SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL.");
12339 case BuiltInViewIndex:
12340 if (execution.model != ExecutionModelFragment)
12341 SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders.");
12342 // The ViewIndex was implicitly used in the prior stages to set the render_target_array_index,
12343 // so we can get it from there.
12344 return "render_target_array_index";
12345
12346 // Fragment function out
12347 case BuiltInFragDepth:
12348 if (execution.flags.get(ExecutionModeDepthGreater))
12349 return "depth(greater)";
12350 else if (execution.flags.get(ExecutionModeDepthLess))
12351 return "depth(less)";
12352 else
12353 return "depth(any)";
12354
12355 case BuiltInFragStencilRefEXT:
12356 return "stencil";
12357
12358 // Compute function in
12359 case BuiltInGlobalInvocationId:
12360 return "thread_position_in_grid";
12361
12362 case BuiltInWorkgroupId:
12363 return "threadgroup_position_in_grid";
12364
12365 case BuiltInNumWorkgroups:
12366 return "threadgroups_per_grid";
12367
12368 case BuiltInLocalInvocationId:
12369 return "thread_position_in_threadgroup";
12370
12371 case BuiltInLocalInvocationIndex:
12372 return "thread_index_in_threadgroup";
12373
12374 case BuiltInSubgroupSize:
12375 if (execution.model == ExecutionModelFragment)
12376 {
12377 if (!msl_options.supports_msl_version(2, 2))
12378 SPIRV_CROSS_THROW("threads_per_simdgroup requires Metal 2.2 in fragment shaders.");
12379 return "threads_per_simdgroup";
12380 }
12381 else
12382 {
12383 // thread_execution_width is an alias for threads_per_simdgroup, and it's only available since 1.0,
12384 // but not in fragment.
12385 return "thread_execution_width";
12386 }
12387
12388 case BuiltInNumSubgroups:
12389 if (!msl_options.supports_msl_version(2))
12390 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
12391 return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
12392
12393 case BuiltInSubgroupId:
12394 if (!msl_options.supports_msl_version(2))
12395 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
12396 return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
12397
12398 case BuiltInSubgroupLocalInvocationId:
12399 if (execution.model == ExecutionModelFragment)
12400 {
12401 if (!msl_options.supports_msl_version(2, 2))
12402 SPIRV_CROSS_THROW("thread_index_in_simdgroup requires Metal 2.2 in fragment shaders.");
12403 return "thread_index_in_simdgroup";
12404 }
12405 else
12406 {
12407 if (!msl_options.supports_msl_version(2))
12408 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
12409 return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
12410 }
12411
12412 case BuiltInSubgroupEqMask:
12413 case BuiltInSubgroupGeMask:
12414 case BuiltInSubgroupGtMask:
12415 case BuiltInSubgroupLeMask:
12416 case BuiltInSubgroupLtMask:
12417 // Shouldn't be reached.
12418 SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL.");
12419
12420 case BuiltInBaryCoordNV:
12421 // TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
12422 if (msl_options.is_ios())
12423 SPIRV_CROSS_THROW("Barycentrics not supported on iOS.");
12424 else if (!msl_options.supports_msl_version(2, 2))
12425 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
12426 return "barycentric_coord, center_perspective";
12427
12428 case BuiltInBaryCoordNoPerspNV:
12429 // TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
12430 if (msl_options.is_ios())
12431 SPIRV_CROSS_THROW("Barycentrics not supported on iOS.");
12432 else if (!msl_options.supports_msl_version(2, 2))
12433 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
12434 return "barycentric_coord, center_no_perspective";
12435
12436 default:
12437 return "unsupported-built-in";
12438 }
12439 }
12440
12441 // Returns an MSL string type declaration for a SPIR-V builtin
builtin_type_decl(BuiltIn builtin,uint32_t id)12442 string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id)
12443 {
12444 const SPIREntryPoint &execution = get_entry_point();
12445 switch (builtin)
12446 {
12447 // Vertex function in
12448 case BuiltInVertexId:
12449 return "uint";
12450 case BuiltInVertexIndex:
12451 return "uint";
12452 case BuiltInBaseVertex:
12453 return "uint";
12454 case BuiltInInstanceId:
12455 return "uint";
12456 case BuiltInInstanceIndex:
12457 return "uint";
12458 case BuiltInBaseInstance:
12459 return "uint";
12460 case BuiltInDrawIndex:
12461 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
12462
12463 // Vertex function out
12464 case BuiltInClipDistance:
12465 return "float";
12466 case BuiltInPointSize:
12467 return "float";
12468 case BuiltInPosition:
12469 return "float4";
12470 case BuiltInLayer:
12471 return "uint";
12472 case BuiltInViewportIndex:
12473 if (!msl_options.supports_msl_version(2, 0))
12474 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
12475 return "uint";
12476
12477 // Tess. control function in
12478 case BuiltInInvocationId:
12479 return "uint";
12480 case BuiltInPatchVertices:
12481 return "uint";
12482 case BuiltInPrimitiveId:
12483 return "uint";
12484
12485 // Tess. control function out
12486 case BuiltInTessLevelInner:
12487 if (execution.model == ExecutionModelTessellationEvaluation)
12488 return !execution.flags.get(ExecutionModeTriangles) ? "float2" : "float";
12489 return "half";
12490 case BuiltInTessLevelOuter:
12491 if (execution.model == ExecutionModelTessellationEvaluation)
12492 return !execution.flags.get(ExecutionModeTriangles) ? "float4" : "float";
12493 return "half";
12494
12495 // Tess. evaluation function in
12496 case BuiltInTessCoord:
12497 return execution.flags.get(ExecutionModeTriangles) ? "float3" : "float2";
12498
12499 // Fragment function in
12500 case BuiltInFrontFacing:
12501 return "bool";
12502 case BuiltInPointCoord:
12503 return "float2";
12504 case BuiltInFragCoord:
12505 return "float4";
12506 case BuiltInSampleId:
12507 return "uint";
12508 case BuiltInSampleMask:
12509 return "uint";
12510 case BuiltInSamplePosition:
12511 return "float2";
12512 case BuiltInViewIndex:
12513 return "uint";
12514
12515 case BuiltInHelperInvocation:
12516 return "bool";
12517
12518 case BuiltInBaryCoordNV:
12519 case BuiltInBaryCoordNoPerspNV:
12520 // Use the type as declared, can be 1, 2 or 3 components.
12521 return type_to_glsl(get_variable_data_type(get<SPIRVariable>(id)));
12522
12523 // Fragment function out
12524 case BuiltInFragDepth:
12525 return "float";
12526
12527 case BuiltInFragStencilRefEXT:
12528 return "uint";
12529
12530 // Compute function in
12531 case BuiltInGlobalInvocationId:
12532 case BuiltInLocalInvocationId:
12533 case BuiltInNumWorkgroups:
12534 case BuiltInWorkgroupId:
12535 return "uint3";
12536 case BuiltInLocalInvocationIndex:
12537 case BuiltInNumSubgroups:
12538 case BuiltInSubgroupId:
12539 case BuiltInSubgroupSize:
12540 case BuiltInSubgroupLocalInvocationId:
12541 return "uint";
12542 case BuiltInSubgroupEqMask:
12543 case BuiltInSubgroupGeMask:
12544 case BuiltInSubgroupGtMask:
12545 case BuiltInSubgroupLeMask:
12546 case BuiltInSubgroupLtMask:
12547 return "uint4";
12548
12549 case BuiltInDeviceIndex:
12550 return "int";
12551
12552 default:
12553 return "unsupported-built-in-type";
12554 }
12555 }
12556
12557 // Returns the declaration of a built-in argument to a function
built_in_func_arg(BuiltIn builtin,bool prefix_comma)12558 string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma)
12559 {
12560 string bi_arg;
12561 if (prefix_comma)
12562 bi_arg += ", ";
12563
12564 // Handle HLSL-style 0-based vertex/instance index.
12565 builtin_declaration = true;
12566 bi_arg += builtin_type_decl(builtin);
12567 bi_arg += " " + builtin_to_glsl(builtin, StorageClassInput);
12568 bi_arg += " [[" + builtin_qualifier(builtin) + "]]";
12569 builtin_declaration = false;
12570
12571 return bi_arg;
12572 }
12573
// Returns the type that physically backs member `index` of struct `type`:
// the remapped physical type when one is recorded, otherwise the declared member type.
const SPIRType &CompilerMSL::get_physical_member_type(const SPIRType &type, uint32_t index) const
{
	// No remapping recorded: the declared member type is the physical type.
	if (!member_is_remapped_physical_type(type, index))
		return get<SPIRType>(type.member_types[index]);

	return get<SPIRType>(get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID));
}
12581
// Returns a copy of the physical type of member `index` of the input block `ib_type`,
// with its vector size widened to the one registered for the member's location
// whenever the registered size is larger.
SPIRType CompilerMSL::get_presumed_input_type(const SPIRType &ib_type, uint32_t index) const
{
	SPIRType presumed = get_physical_member_type(ib_type, index);
	uint32_t location = get_member_decoration(ib_type.self, index, DecorationLocation);

	auto registered = inputs_by_location.find(location);
	if (registered != inputs_by_location.end() && registered->second.vecsize > presumed.vecsize)
		presumed.vecsize = registered->second.vecsize;

	return presumed;
}
12593
get_declared_type_array_stride_msl(const SPIRType & type,bool is_packed,bool row_major) const12594 uint32_t CompilerMSL::get_declared_type_array_stride_msl(const SPIRType &type, bool is_packed, bool row_major) const
12595 {
12596 // Array stride in MSL is always size * array_size. sizeof(float3) == 16,
12597 // unlike GLSL and HLSL where array stride would be 16 and size 12.
12598
12599 // We could use parent type here and recurse, but that makes creating physical type remappings
12600 // far more complicated. We'd rather just create the final type, and ignore having to create the entire type
12601 // hierarchy in order to compute this value, so make a temporary type on the stack.
12602
12603 auto basic_type = type;
12604 basic_type.array.clear();
12605 basic_type.array_size_literal.clear();
12606 uint32_t value_size = get_declared_type_size_msl(basic_type, is_packed, row_major);
12607
12608 uint32_t dimensions = uint32_t(type.array.size());
12609 assert(dimensions > 0);
12610 dimensions--;
12611
12612 // Multiply together every dimension, except the last one.
12613 for (uint32_t dim = 0; dim < dimensions; dim++)
12614 {
12615 uint32_t array_size = to_array_size_literal(type, dim);
12616 value_size *= max(array_size, 1u);
12617 }
12618
12619 return value_size;
12620 }
12621
get_declared_struct_member_array_stride_msl(const SPIRType & type,uint32_t index) const12622 uint32_t CompilerMSL::get_declared_struct_member_array_stride_msl(const SPIRType &type, uint32_t index) const
12623 {
12624 return get_declared_type_array_stride_msl(get_physical_member_type(type, index),
12625 member_is_packed_physical_type(type, index),
12626 has_member_decoration(type.self, index, DecorationRowMajor));
12627 }
12628
get_declared_input_array_stride_msl(const SPIRType & type,uint32_t index) const12629 uint32_t CompilerMSL::get_declared_input_array_stride_msl(const SPIRType &type, uint32_t index) const
12630 {
12631 return get_declared_type_array_stride_msl(get_presumed_input_type(type, index), false,
12632 has_member_decoration(type.self, index, DecorationRowMajor));
12633 }
12634
get_declared_type_matrix_stride_msl(const SPIRType & type,bool packed,bool row_major) const12635 uint32_t CompilerMSL::get_declared_type_matrix_stride_msl(const SPIRType &type, bool packed, bool row_major) const
12636 {
12637 // For packed matrices, we just use the size of the vector type.
12638 // Otherwise, MatrixStride == alignment, which is the size of the underlying vector type.
12639 if (packed)
12640 return (type.width / 8) * ((row_major && type.columns > 1) ? type.columns : type.vecsize);
12641 else
12642 return get_declared_type_alignment_msl(type, false, row_major);
12643 }
12644
get_declared_struct_member_matrix_stride_msl(const SPIRType & type,uint32_t index) const12645 uint32_t CompilerMSL::get_declared_struct_member_matrix_stride_msl(const SPIRType &type, uint32_t index) const
12646 {
12647 return get_declared_type_matrix_stride_msl(get_physical_member_type(type, index),
12648 member_is_packed_physical_type(type, index),
12649 has_member_decoration(type.self, index, DecorationRowMajor));
12650 }
12651
get_declared_input_matrix_stride_msl(const SPIRType & type,uint32_t index) const12652 uint32_t CompilerMSL::get_declared_input_matrix_stride_msl(const SPIRType &type, uint32_t index) const
12653 {
12654 return get_declared_type_matrix_stride_msl(get_presumed_input_type(type, index), false,
12655 has_member_decoration(type.self, index, DecorationRowMajor));
12656 }
12657
get_declared_struct_size_msl(const SPIRType & struct_type,bool ignore_alignment,bool ignore_padding) const12658 uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type, bool ignore_alignment,
12659 bool ignore_padding) const
12660 {
12661 // If we have a target size, that is the declared size as well.
12662 if (!ignore_padding && has_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget))
12663 return get_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget);
12664
12665 if (struct_type.member_types.empty())
12666 return 0;
12667
12668 uint32_t mbr_cnt = uint32_t(struct_type.member_types.size());
12669
12670 // In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
12671 uint32_t alignment = 1;
12672
12673 if (!ignore_alignment)
12674 {
12675 for (uint32_t i = 0; i < mbr_cnt; i++)
12676 {
12677 uint32_t mbr_alignment = get_declared_struct_member_alignment_msl(struct_type, i);
12678 alignment = max(alignment, mbr_alignment);
12679 }
12680 }
12681
12682 // Last member will always be matched to the final Offset decoration, but size of struct in MSL now depends
12683 // on physical size in MSL, and the size of the struct itself is then aligned to struct alignment.
12684 uint32_t spirv_offset = type_struct_member_offset(struct_type, mbr_cnt - 1);
12685 uint32_t msl_size = spirv_offset + get_declared_struct_member_size_msl(struct_type, mbr_cnt - 1);
12686 msl_size = (msl_size + alignment - 1) & ~(alignment - 1);
12687 return msl_size;
12688 }
12689
12690 // Returns the byte size of a struct member.
get_declared_type_size_msl(const SPIRType & type,bool is_packed,bool row_major) const12691 uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
12692 {
12693 switch (type.basetype)
12694 {
12695 case SPIRType::Unknown:
12696 case SPIRType::Void:
12697 case SPIRType::AtomicCounter:
12698 case SPIRType::Image:
12699 case SPIRType::SampledImage:
12700 case SPIRType::Sampler:
12701 SPIRV_CROSS_THROW("Querying size of opaque object.");
12702
12703 default:
12704 {
12705 if (!type.array.empty())
12706 {
12707 uint32_t array_size = to_array_size_literal(type);
12708 return get_declared_type_array_stride_msl(type, is_packed, row_major) * max(array_size, 1u);
12709 }
12710
12711 if (type.basetype == SPIRType::Struct)
12712 return get_declared_struct_size_msl(type);
12713
12714 if (is_packed)
12715 {
12716 return type.vecsize * type.columns * (type.width / 8);
12717 }
12718 else
12719 {
12720 // An unpacked 3-element vector or matrix column is the same memory size as a 4-element.
12721 uint32_t vecsize = type.vecsize;
12722 uint32_t columns = type.columns;
12723
12724 if (row_major && columns > 1)
12725 swap(vecsize, columns);
12726
12727 if (vecsize == 3)
12728 vecsize = 4;
12729
12730 return vecsize * columns * (type.width / 8);
12731 }
12732 }
12733 }
12734 }
12735
get_declared_struct_member_size_msl(const SPIRType & type,uint32_t index) const12736 uint32_t CompilerMSL::get_declared_struct_member_size_msl(const SPIRType &type, uint32_t index) const
12737 {
12738 return get_declared_type_size_msl(get_physical_member_type(type, index),
12739 member_is_packed_physical_type(type, index),
12740 has_member_decoration(type.self, index, DecorationRowMajor));
12741 }
12742
get_declared_input_size_msl(const SPIRType & type,uint32_t index) const12743 uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t index) const
12744 {
12745 return get_declared_type_size_msl(get_presumed_input_type(type, index), false,
12746 has_member_decoration(type.self, index, DecorationRowMajor));
12747 }
12748
12749 // Returns the byte alignment of a type.
get_declared_type_alignment_msl(const SPIRType & type,bool is_packed,bool row_major) const12750 uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const
12751 {
12752 switch (type.basetype)
12753 {
12754 case SPIRType::Unknown:
12755 case SPIRType::Void:
12756 case SPIRType::AtomicCounter:
12757 case SPIRType::Image:
12758 case SPIRType::SampledImage:
12759 case SPIRType::Sampler:
12760 SPIRV_CROSS_THROW("Querying alignment of opaque object.");
12761
12762 case SPIRType::Int64:
12763 SPIRV_CROSS_THROW("long types are not supported in buffers in MSL.");
12764 case SPIRType::UInt64:
12765 SPIRV_CROSS_THROW("ulong types are not supported in buffers in MSL.");
12766 case SPIRType::Double:
12767 SPIRV_CROSS_THROW("double types are not supported in buffers in MSL.");
12768
12769 case SPIRType::Struct:
12770 {
12771 // In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
12772 uint32_t alignment = 1;
12773 for (uint32_t i = 0; i < type.member_types.size(); i++)
12774 alignment = max(alignment, uint32_t(get_declared_struct_member_alignment_msl(type, i)));
12775 return alignment;
12776 }
12777
12778 default:
12779 {
12780 // Alignment of packed type is the same as the underlying component or column size.
12781 // Alignment of unpacked type is the same as the vector size.
12782 // Alignment of 3-elements vector is the same as 4-elements (including packed using column).
12783 if (is_packed)
12784 {
12785 // If we have packed_T and friends, the alignment is always scalar.
12786 return type.width / 8;
12787 }
12788 else
12789 {
12790 // This is the general rule for MSL. Size == alignment.
12791 uint32_t vecsize = (row_major && type.columns > 1) ? type.columns : type.vecsize;
12792 return (type.width / 8) * (vecsize == 3 ? 4 : vecsize);
12793 }
12794 }
12795 }
12796 }
12797
get_declared_struct_member_alignment_msl(const SPIRType & type,uint32_t index) const12798 uint32_t CompilerMSL::get_declared_struct_member_alignment_msl(const SPIRType &type, uint32_t index) const
12799 {
12800 return get_declared_type_alignment_msl(get_physical_member_type(type, index),
12801 member_is_packed_physical_type(type, index),
12802 has_member_decoration(type.self, index, DecorationRowMajor));
12803 }
12804
get_declared_input_alignment_msl(const SPIRType & type,uint32_t index) const12805 uint32_t CompilerMSL::get_declared_input_alignment_msl(const SPIRType &type, uint32_t index) const
12806 {
12807 return get_declared_type_alignment_msl(get_presumed_input_type(type, index), false,
12808 has_member_decoration(type.self, index, DecorationRowMajor));
12809 }
12810
skip_argument(uint32_t) const12811 bool CompilerMSL::skip_argument(uint32_t) const
12812 {
12813 return false;
12814 }
12815
analyze_sampled_image_usage()12816 void CompilerMSL::analyze_sampled_image_usage()
12817 {
12818 if (msl_options.swizzle_texture_samples)
12819 {
12820 SampledImageScanner scanner(*this);
12821 traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), scanner);
12822 }
12823 }
12824
handle(spv::Op opcode,const uint32_t * args,uint32_t length)12825 bool CompilerMSL::SampledImageScanner::handle(spv::Op opcode, const uint32_t *args, uint32_t length)
12826 {
12827 switch (opcode)
12828 {
12829 case OpLoad:
12830 case OpImage:
12831 case OpSampledImage:
12832 {
12833 if (length < 3)
12834 return false;
12835
12836 uint32_t result_type = args[0];
12837 auto &type = compiler.get<SPIRType>(result_type);
12838 if ((type.basetype != SPIRType::Image && type.basetype != SPIRType::SampledImage) || type.image.sampled != 1)
12839 return true;
12840
12841 uint32_t id = args[1];
12842 compiler.set<SPIRExpression>(id, "", result_type, true);
12843 break;
12844 }
12845 case OpImageSampleExplicitLod:
12846 case OpImageSampleProjExplicitLod:
12847 case OpImageSampleDrefExplicitLod:
12848 case OpImageSampleProjDrefExplicitLod:
12849 case OpImageSampleImplicitLod:
12850 case OpImageSampleProjImplicitLod:
12851 case OpImageSampleDrefImplicitLod:
12852 case OpImageSampleProjDrefImplicitLod:
12853 case OpImageFetch:
12854 case OpImageGather:
12855 case OpImageDrefGather:
12856 compiler.has_sampled_images =
12857 compiler.has_sampled_images || compiler.is_sampled_image_type(compiler.expression_type(args[2]));
12858 compiler.needs_swizzle_buffer_def = compiler.needs_swizzle_buffer_def || compiler.has_sampled_images;
12859 break;
12860 default:
12861 break;
12862 }
12863 return true;
12864 }
12865
12866 // If a needed custom function wasn't added before, add it and force a recompile.
add_spv_func_and_recompile(SPVFuncImpl spv_func)12867 void CompilerMSL::add_spv_func_and_recompile(SPVFuncImpl spv_func)
12868 {
12869 if (spv_function_implementations.count(spv_func) == 0)
12870 {
12871 spv_function_implementations.insert(spv_func);
12872 suppress_missing_prototypes = true;
12873 force_recompile();
12874 }
12875 }
12876
handle(Op opcode,const uint32_t * args,uint32_t length)12877 bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, uint32_t length)
12878 {
12879 // Since MSL exists in a single execution scope, function prototype declarations are not
12880 // needed, and clutter the output. If secondary functions are output (either as a SPIR-V
12881 // function implementation or as indicated by the presence of OpFunctionCall), then set
12882 // suppress_missing_prototypes to suppress compiler warnings of missing function prototypes.
12883
12884 // Mark if the input requires the implementation of an SPIR-V function that does not exist in Metal.
12885 SPVFuncImpl spv_func = get_spv_func_impl(opcode, args);
12886 if (spv_func != SPVFuncImplNone)
12887 {
12888 compiler.spv_function_implementations.insert(spv_func);
12889 suppress_missing_prototypes = true;
12890 }
12891
12892 switch (opcode)
12893 {
12894
12895 case OpFunctionCall:
12896 suppress_missing_prototypes = true;
12897 break;
12898
12899 // Emulate texture2D atomic operations
12900 case OpImageTexelPointer:
12901 {
12902 auto *var = compiler.maybe_get_backing_variable(args[2]);
12903 image_pointers[args[1]] = var ? var->self : ID(0);
12904 break;
12905 }
12906
12907 case OpImageWrite:
12908 uses_resource_write = true;
12909 break;
12910
12911 case OpStore:
12912 check_resource_write(args[0]);
12913 break;
12914
12915 // Emulate texture2D atomic operations
12916 case OpAtomicExchange:
12917 case OpAtomicCompareExchange:
12918 case OpAtomicCompareExchangeWeak:
12919 case OpAtomicIIncrement:
12920 case OpAtomicIDecrement:
12921 case OpAtomicIAdd:
12922 case OpAtomicISub:
12923 case OpAtomicSMin:
12924 case OpAtomicUMin:
12925 case OpAtomicSMax:
12926 case OpAtomicUMax:
12927 case OpAtomicAnd:
12928 case OpAtomicOr:
12929 case OpAtomicXor:
12930 {
12931 uses_atomics = true;
12932 auto it = image_pointers.find(args[2]);
12933 if (it != image_pointers.end())
12934 {
12935 compiler.atomic_image_vars.insert(it->second);
12936 }
12937 check_resource_write(args[2]);
12938 break;
12939 }
12940
12941 case OpAtomicStore:
12942 {
12943 uses_atomics = true;
12944 auto it = image_pointers.find(args[0]);
12945 if (it != image_pointers.end())
12946 {
12947 compiler.atomic_image_vars.insert(it->second);
12948 }
12949 check_resource_write(args[0]);
12950 break;
12951 }
12952
12953 case OpAtomicLoad:
12954 {
12955 uses_atomics = true;
12956 auto it = image_pointers.find(args[2]);
12957 if (it != image_pointers.end())
12958 {
12959 compiler.atomic_image_vars.insert(it->second);
12960 }
12961 break;
12962 }
12963
12964 case OpGroupNonUniformInverseBallot:
12965 needs_subgroup_invocation_id = true;
12966 break;
12967
12968 case OpGroupNonUniformBallotBitCount:
12969 if (args[3] != GroupOperationReduce)
12970 needs_subgroup_invocation_id = true;
12971 break;
12972
12973 case OpArrayLength:
12974 {
12975 auto *var = compiler.maybe_get_backing_variable(args[2]);
12976 if (var)
12977 compiler.buffers_requiring_array_length.insert(var->self);
12978 break;
12979 }
12980
12981 case OpInBoundsAccessChain:
12982 case OpAccessChain:
12983 case OpPtrAccessChain:
12984 {
12985 // OpArrayLength might want to know if taking ArrayLength of an array of SSBOs.
12986 uint32_t result_type = args[0];
12987 uint32_t id = args[1];
12988 uint32_t ptr = args[2];
12989
12990 compiler.set<SPIRExpression>(id, "", result_type, true);
12991 compiler.register_read(id, ptr, true);
12992 compiler.ir.ids[id].set_allow_type_rewrite();
12993 break;
12994 }
12995
12996 default:
12997 break;
12998 }
12999
13000 // If it has one, keep track of the instruction's result type, mapped by ID
13001 uint32_t result_type, result_id;
13002 if (compiler.instruction_to_result_type(result_type, result_id, opcode, args, length))
13003 result_types[result_id] = result_type;
13004
13005 return true;
13006 }
13007
13008 // If the variable is a Uniform or StorageBuffer, mark that a resource has been written to.
check_resource_write(uint32_t var_id)13009 void CompilerMSL::OpCodePreprocessor::check_resource_write(uint32_t var_id)
13010 {
13011 auto *p_var = compiler.maybe_get_backing_variable(var_id);
13012 StorageClass sc = p_var ? p_var->storage : StorageClassMax;
13013 if (sc == StorageClassUniform || sc == StorageClassStorageBuffer)
13014 uses_resource_write = true;
13015 }
13016
13017 // Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes.
get_spv_func_impl(Op opcode,const uint32_t * args)13018 CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op opcode, const uint32_t *args)
13019 {
13020 switch (opcode)
13021 {
13022 case OpFMod:
13023 return SPVFuncImplMod;
13024
13025 case OpFAdd:
13026 if (compiler.msl_options.invariant_float_math)
13027 {
13028 return SPVFuncImplFAdd;
13029 }
13030 break;
13031
13032 case OpFMul:
13033 case OpOuterProduct:
13034 case OpMatrixTimesVector:
13035 case OpVectorTimesMatrix:
13036 case OpMatrixTimesMatrix:
13037 if (compiler.msl_options.invariant_float_math)
13038 {
13039 return SPVFuncImplFMul;
13040 }
13041 break;
13042
13043 case OpTypeArray:
13044 {
13045 // Allow Metal to use the array<T> template to make arrays a value type
13046 return SPVFuncImplUnsafeArray;
13047 }
13048
13049 // Emulate texture2D atomic operations
13050 case OpAtomicExchange:
13051 case OpAtomicCompareExchange:
13052 case OpAtomicCompareExchangeWeak:
13053 case OpAtomicIIncrement:
13054 case OpAtomicIDecrement:
13055 case OpAtomicIAdd:
13056 case OpAtomicISub:
13057 case OpAtomicSMin:
13058 case OpAtomicUMin:
13059 case OpAtomicSMax:
13060 case OpAtomicUMax:
13061 case OpAtomicAnd:
13062 case OpAtomicOr:
13063 case OpAtomicXor:
13064 case OpAtomicLoad:
13065 case OpAtomicStore:
13066 {
13067 auto it = image_pointers.find(args[opcode == OpAtomicStore ? 0 : 2]);
13068 if (it != image_pointers.end())
13069 {
13070 uint32_t tid = compiler.get<SPIRVariable>(it->second).basetype;
13071 if (tid && compiler.get<SPIRType>(tid).image.dim == Dim2D)
13072 return SPVFuncImplImage2DAtomicCoords;
13073 }
13074 break;
13075 }
13076
13077 case OpImageFetch:
13078 case OpImageRead:
13079 case OpImageWrite:
13080 {
13081 // Retrieve the image type, and if it's a Buffer, emit a texel coordinate function
13082 uint32_t tid = result_types[args[opcode == OpImageWrite ? 0 : 2]];
13083 if (tid && compiler.get<SPIRType>(tid).image.dim == DimBuffer && !compiler.msl_options.texture_buffer_native)
13084 return SPVFuncImplTexelBufferCoords;
13085 break;
13086 }
13087
13088 case OpExtInst:
13089 {
13090 uint32_t extension_set = args[2];
13091 if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
13092 {
13093 auto op_450 = static_cast<GLSLstd450>(args[3]);
13094 switch (op_450)
13095 {
13096 case GLSLstd450Radians:
13097 return SPVFuncImplRadians;
13098 case GLSLstd450Degrees:
13099 return SPVFuncImplDegrees;
13100 case GLSLstd450FindILsb:
13101 return SPVFuncImplFindILsb;
13102 case GLSLstd450FindSMsb:
13103 return SPVFuncImplFindSMsb;
13104 case GLSLstd450FindUMsb:
13105 return SPVFuncImplFindUMsb;
13106 case GLSLstd450SSign:
13107 return SPVFuncImplSSign;
13108 case GLSLstd450Reflect:
13109 {
13110 auto &type = compiler.get<SPIRType>(args[0]);
13111 if (type.vecsize == 1)
13112 return SPVFuncImplReflectScalar;
13113 break;
13114 }
13115 case GLSLstd450Refract:
13116 {
13117 auto &type = compiler.get<SPIRType>(args[0]);
13118 if (type.vecsize == 1)
13119 return SPVFuncImplRefractScalar;
13120 break;
13121 }
13122 case GLSLstd450FaceForward:
13123 {
13124 auto &type = compiler.get<SPIRType>(args[0]);
13125 if (type.vecsize == 1)
13126 return SPVFuncImplFaceForwardScalar;
13127 break;
13128 }
13129 case GLSLstd450MatrixInverse:
13130 {
13131 auto &mat_type = compiler.get<SPIRType>(args[0]);
13132 switch (mat_type.columns)
13133 {
13134 case 2:
13135 return SPVFuncImplInverse2x2;
13136 case 3:
13137 return SPVFuncImplInverse3x3;
13138 case 4:
13139 return SPVFuncImplInverse4x4;
13140 default:
13141 break;
13142 }
13143 break;
13144 }
13145 default:
13146 break;
13147 }
13148 }
13149 break;
13150 }
13151
13152 case OpGroupNonUniformBallot:
13153 return SPVFuncImplSubgroupBallot;
13154
13155 case OpGroupNonUniformInverseBallot:
13156 case OpGroupNonUniformBallotBitExtract:
13157 return SPVFuncImplSubgroupBallotBitExtract;
13158
13159 case OpGroupNonUniformBallotFindLSB:
13160 return SPVFuncImplSubgroupBallotFindLSB;
13161
13162 case OpGroupNonUniformBallotFindMSB:
13163 return SPVFuncImplSubgroupBallotFindMSB;
13164
13165 case OpGroupNonUniformBallotBitCount:
13166 return SPVFuncImplSubgroupBallotBitCount;
13167
13168 case OpGroupNonUniformAllEqual:
13169 return SPVFuncImplSubgroupAllEqual;
13170
13171 default:
13172 break;
13173 }
13174 return SPVFuncImplNone;
13175 }
13176
13177 // Sort both type and meta member content based on builtin status (put builtins at end),
13178 // then by the required sorting aspect.
sort()13179 void CompilerMSL::MemberSorter::sort()
13180 {
13181 // Create a temporary array of consecutive member indices and sort it based on how
13182 // the members should be reordered, based on builtin and sorting aspect meta info.
13183 size_t mbr_cnt = type.member_types.size();
13184 SmallVector<uint32_t> mbr_idxs(mbr_cnt);
13185 std::iota(mbr_idxs.begin(), mbr_idxs.end(), 0); // Fill with consecutive indices
13186 std::stable_sort(mbr_idxs.begin(), mbr_idxs.end(), *this); // Sort member indices based on sorting aspect
13187
13188 bool sort_is_identity = true;
13189 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
13190 {
13191 if (mbr_idx != mbr_idxs[mbr_idx])
13192 {
13193 sort_is_identity = false;
13194 break;
13195 }
13196 }
13197
13198 if (sort_is_identity)
13199 return;
13200
13201 if (meta.members.size() < type.member_types.size())
13202 {
13203 // This should never trigger in normal circumstances, but to be safe.
13204 meta.members.resize(type.member_types.size());
13205 }
13206
13207 // Move type and meta member info to the order defined by the sorted member indices.
13208 // This is done by creating temporary copies of both member types and meta, and then
13209 // copying back to the original content at the sorted indices.
13210 auto mbr_types_cpy = type.member_types;
13211 auto mbr_meta_cpy = meta.members;
13212 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
13213 {
13214 type.member_types[mbr_idx] = mbr_types_cpy[mbr_idxs[mbr_idx]];
13215 meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]];
13216 }
13217
13218 if (sort_aspect == SortAspect::Offset)
13219 {
13220 // If we're sorting by Offset, this might affect user code which accesses a buffer block.
13221 // We will need to redirect member indices from one index to sorted index.
13222 type.member_type_index_redirection = std::move(mbr_idxs);
13223 }
13224 }
13225
13226 // Sort first by builtin status (put builtins at end), then by the sorting aspect.
operator ()(uint32_t mbr_idx1,uint32_t mbr_idx2)13227 bool CompilerMSL::MemberSorter::operator()(uint32_t mbr_idx1, uint32_t mbr_idx2)
13228 {
13229 auto &mbr_meta1 = meta.members[mbr_idx1];
13230 auto &mbr_meta2 = meta.members[mbr_idx2];
13231 if (mbr_meta1.builtin != mbr_meta2.builtin)
13232 return mbr_meta2.builtin;
13233 else
13234 switch (sort_aspect)
13235 {
13236 case Location:
13237 return mbr_meta1.location < mbr_meta2.location;
13238 case LocationReverse:
13239 return mbr_meta1.location > mbr_meta2.location;
13240 case Offset:
13241 return mbr_meta1.offset < mbr_meta2.offset;
13242 case OffsetThenLocationReverse:
13243 return (mbr_meta1.offset < mbr_meta2.offset) ||
13244 ((mbr_meta1.offset == mbr_meta2.offset) && (mbr_meta1.location > mbr_meta2.location));
13245 case Alphabetical:
13246 return mbr_meta1.alias < mbr_meta2.alias;
13247 default:
13248 return false;
13249 }
13250 }
13251
MemberSorter(SPIRType & t,Meta & m,SortAspect sa)13252 CompilerMSL::MemberSorter::MemberSorter(SPIRType &t, Meta &m, SortAspect sa)
13253 : type(t)
13254 , meta(m)
13255 , sort_aspect(sa)
13256 {
13257 // Ensure enough meta info is available
13258 meta.members.resize(max(type.member_types.size(), meta.members.size()));
13259 }
13260
remap_constexpr_sampler(VariableID id,const MSLConstexprSampler & sampler)13261 void CompilerMSL::remap_constexpr_sampler(VariableID id, const MSLConstexprSampler &sampler)
13262 {
13263 auto &type = get<SPIRType>(get<SPIRVariable>(id).basetype);
13264 if (type.basetype != SPIRType::SampledImage && type.basetype != SPIRType::Sampler)
13265 SPIRV_CROSS_THROW("Can only remap SampledImage and Sampler type.");
13266 if (!type.array.empty())
13267 SPIRV_CROSS_THROW("Can not remap array of samplers.");
13268 constexpr_samplers_by_id[id] = sampler;
13269 }
13270
remap_constexpr_sampler_by_binding(uint32_t desc_set,uint32_t binding,const MSLConstexprSampler & sampler)13271 void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t binding,
13272 const MSLConstexprSampler &sampler)
13273 {
13274 constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
13275 }
13276
bitcast_from_builtin_load(uint32_t source_id,std::string & expr,const SPIRType & expr_type)13277 void CompilerMSL::bitcast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
13278 {
13279 auto *var = maybe_get_backing_variable(source_id);
13280 if (var)
13281 source_id = var->self;
13282
13283 // Only interested in standalone builtin variables.
13284 if (!has_decoration(source_id, DecorationBuiltIn))
13285 return;
13286
13287 auto builtin = static_cast<BuiltIn>(get_decoration(source_id, DecorationBuiltIn));
13288 auto expected_type = expr_type.basetype;
13289 switch (builtin)
13290 {
13291 case BuiltInGlobalInvocationId:
13292 case BuiltInLocalInvocationId:
13293 case BuiltInWorkgroupId:
13294 case BuiltInLocalInvocationIndex:
13295 case BuiltInWorkgroupSize:
13296 case BuiltInNumWorkgroups:
13297 case BuiltInLayer:
13298 case BuiltInViewportIndex:
13299 case BuiltInFragStencilRefEXT:
13300 case BuiltInPrimitiveId:
13301 case BuiltInSubgroupSize:
13302 case BuiltInSubgroupLocalInvocationId:
13303 case BuiltInViewIndex:
13304 case BuiltInVertexIndex:
13305 case BuiltInInstanceIndex:
13306 case BuiltInBaseInstance:
13307 case BuiltInBaseVertex:
13308 expected_type = SPIRType::UInt;
13309 break;
13310
13311 case BuiltInTessLevelInner:
13312 case BuiltInTessLevelOuter:
13313 if (get_execution_model() == ExecutionModelTessellationControl)
13314 expected_type = SPIRType::Half;
13315 break;
13316
13317 default:
13318 break;
13319 }
13320
13321 if (expected_type != expr_type.basetype)
13322 expr = bitcast_expression(expr_type, expected_type, expr);
13323
13324 if (builtin == BuiltInTessCoord && get_entry_point().flags.get(ExecutionModeQuads) && expr_type.vecsize == 3)
13325 {
13326 // In SPIR-V, this is always a vec3, even for quads. In Metal, though, it's a float2 for quads.
13327 // The code is expecting a float3, so we need to widen this.
13328 expr = join("float3(", expr, ", 0)");
13329 }
13330 }
13331
bitcast_to_builtin_store(uint32_t target_id,std::string & expr,const SPIRType & expr_type)13332 void CompilerMSL::bitcast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
13333 {
13334 auto *var = maybe_get_backing_variable(target_id);
13335 if (var)
13336 target_id = var->self;
13337
13338 // Only interested in standalone builtin variables.
13339 if (!has_decoration(target_id, DecorationBuiltIn))
13340 return;
13341
13342 auto builtin = static_cast<BuiltIn>(get_decoration(target_id, DecorationBuiltIn));
13343 auto expected_type = expr_type.basetype;
13344 switch (builtin)
13345 {
13346 case BuiltInLayer:
13347 case BuiltInViewportIndex:
13348 case BuiltInFragStencilRefEXT:
13349 case BuiltInPrimitiveId:
13350 case BuiltInViewIndex:
13351 expected_type = SPIRType::UInt;
13352 break;
13353
13354 case BuiltInTessLevelInner:
13355 case BuiltInTessLevelOuter:
13356 expected_type = SPIRType::Half;
13357 break;
13358
13359 default:
13360 break;
13361 }
13362
13363 if (expected_type != expr_type.basetype)
13364 {
13365 if (expected_type == SPIRType::Half && expr_type.basetype == SPIRType::Float)
13366 {
13367 // These are of different widths, so we cannot do a straight bitcast.
13368 expr = join("half(", expr, ")");
13369 }
13370 else
13371 {
13372 auto type = expr_type;
13373 type.basetype = expected_type;
13374 expr = bitcast_expression(type, expr_type.basetype, expr);
13375 }
13376 }
13377 }
13378
to_initializer_expression(const SPIRVariable & var)13379 string CompilerMSL::to_initializer_expression(const SPIRVariable &var)
13380 {
13381 // We risk getting an array initializer here with MSL. If we have an array.
13382 // FIXME: We cannot handle non-constant arrays being initialized.
13383 // We will need to inject spvArrayCopy here somehow ...
13384 auto &type = get<SPIRType>(var.basetype);
13385 string expr;
13386 if (ir.ids[var.initializer].get_type() == TypeConstant &&
13387 (!type.array.empty() || type.basetype == SPIRType::Struct))
13388 expr = constant_expression(get<SPIRConstant>(var.initializer));
13389 else
13390 expr = CompilerGLSL::to_initializer_expression(var);
13391 // If the initializer has more vector components than the variable, add a swizzle.
13392 // FIXME: This can't handle arrays or structs.
13393 auto &init_type = expression_type(var.initializer);
13394 if (type.array.empty() && type.basetype != SPIRType::Struct && init_type.vecsize > type.vecsize)
13395 expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0));
13396 return expr;
13397 }
13398
to_zero_initialized_expression(uint32_t)13399 string CompilerMSL::to_zero_initialized_expression(uint32_t)
13400 {
13401 return "{}";
13402 }
13403
descriptor_set_is_argument_buffer(uint32_t desc_set) const13404 bool CompilerMSL::descriptor_set_is_argument_buffer(uint32_t desc_set) const
13405 {
13406 if (!msl_options.argument_buffers)
13407 return false;
13408 if (desc_set >= kMaxArgumentBuffers)
13409 return false;
13410
13411 return (argument_buffer_discrete_mask & (1u << desc_set)) == 0;
13412 }
13413
analyze_argument_buffers()13414 void CompilerMSL::analyze_argument_buffers()
13415 {
13416 // Gather all used resources and sort them out into argument buffers.
13417 // Each argument buffer corresponds to a descriptor set in SPIR-V.
13418 // The [[id(N)]] values used correspond to the resource mapping we have for MSL.
13419 // Otherwise, the binding number is used, but this is generally not safe some types like
13420 // combined image samplers and arrays of resources. Metal needs different indices here,
13421 // while SPIR-V can have one descriptor set binding. To use argument buffers in practice,
13422 // you will need to use the remapping from the API.
13423 for (auto &id : argument_buffer_ids)
13424 id = 0;
13425
13426 // Output resources, sorted by resource index & type.
13427 struct Resource
13428 {
13429 SPIRVariable *var;
13430 string name;
13431 SPIRType::BaseType basetype;
13432 uint32_t index;
13433 uint32_t plane;
13434 };
13435 SmallVector<Resource> resources_in_set[kMaxArgumentBuffers];
13436 SmallVector<uint32_t> inline_block_vars;
13437
13438 bool set_needs_swizzle_buffer[kMaxArgumentBuffers] = {};
13439 bool set_needs_buffer_sizes[kMaxArgumentBuffers] = {};
13440 bool needs_buffer_sizes = false;
13441
13442 ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &var) {
13443 if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
13444 var.storage == StorageClassStorageBuffer) &&
13445 !is_hidden_variable(var))
13446 {
13447 uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
13448 // Ignore if it's part of a push descriptor set.
13449 if (!descriptor_set_is_argument_buffer(desc_set))
13450 return;
13451
13452 uint32_t var_id = var.self;
13453 auto &type = get_variable_data_type(var);
13454
13455 if (desc_set >= kMaxArgumentBuffers)
13456 SPIRV_CROSS_THROW("Descriptor set index is out of range.");
13457
13458 const MSLConstexprSampler *constexpr_sampler = nullptr;
13459 if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
13460 {
13461 constexpr_sampler = find_constexpr_sampler(var_id);
13462 if (constexpr_sampler)
13463 {
13464 // Mark this ID as a constexpr sampler for later in case it came from set/bindings.
13465 constexpr_samplers_by_id[var_id] = *constexpr_sampler;
13466 }
13467 }
13468
13469 uint32_t binding = get_decoration(var_id, DecorationBinding);
13470 if (type.basetype == SPIRType::SampledImage)
13471 {
13472 add_resource_name(var_id);
13473
13474 uint32_t plane_count = 1;
13475 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
13476 plane_count = constexpr_sampler->planes;
13477
13478 for (uint32_t i = 0; i < plane_count; i++)
13479 {
13480 uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image, i);
13481 resources_in_set[desc_set].push_back(
13482 { &var, to_name(var_id), SPIRType::Image, image_resource_index, i });
13483 }
13484
13485 if (type.image.dim != DimBuffer && !constexpr_sampler)
13486 {
13487 uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler);
13488 resources_in_set[desc_set].push_back(
13489 { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 0 });
13490 }
13491 }
13492 else if (inline_uniform_blocks.count(SetBindingPair{ desc_set, binding }))
13493 {
13494 inline_block_vars.push_back(var_id);
13495 }
13496 else if (!constexpr_sampler)
13497 {
13498 // constexpr samplers are not declared as resources.
13499 // Inline uniform blocks are always emitted at the end.
13500 if (!msl_options.is_ios() || type.basetype != SPIRType::Image || type.image.sampled != 2)
13501 {
13502 add_resource_name(var_id);
13503 resources_in_set[desc_set].push_back(
13504 { &var, to_name(var_id), type.basetype, get_metal_resource_index(var, type.basetype), 0 });
13505 }
13506 }
13507
13508 // Check if this descriptor set needs a swizzle buffer.
13509 if (needs_swizzle_buffer_def && is_sampled_image_type(type))
13510 set_needs_swizzle_buffer[desc_set] = true;
13511 else if (buffers_requiring_array_length.count(var_id) != 0)
13512 {
13513 set_needs_buffer_sizes[desc_set] = true;
13514 needs_buffer_sizes = true;
13515 }
13516 }
13517 });
13518
13519 if (needs_swizzle_buffer_def || needs_buffer_sizes)
13520 {
13521 uint32_t uint_ptr_type_id = 0;
13522
13523 // We might have to add a swizzle buffer resource to the set.
13524 for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
13525 {
13526 if (!set_needs_swizzle_buffer[desc_set] && !set_needs_buffer_sizes[desc_set])
13527 continue;
13528
13529 if (uint_ptr_type_id == 0)
13530 {
13531 uint_ptr_type_id = ir.increase_bound_by(1);
13532
13533 // Create a buffer to hold extra data, including the swizzle constants.
13534 SPIRType uint_type_pointer = get_uint_type();
13535 uint_type_pointer.pointer = true;
13536 uint_type_pointer.pointer_depth = 1;
13537 uint_type_pointer.parent_type = get_uint_type_id();
13538 uint_type_pointer.storage = StorageClassUniform;
13539 set<SPIRType>(uint_ptr_type_id, uint_type_pointer);
13540 set_decoration(uint_ptr_type_id, DecorationArrayStride, 4);
13541 }
13542
13543 if (set_needs_swizzle_buffer[desc_set])
13544 {
13545 uint32_t var_id = ir.increase_bound_by(1);
13546 auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
13547 set_name(var_id, "spvSwizzleConstants");
13548 set_decoration(var_id, DecorationDescriptorSet, desc_set);
13549 set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding);
13550 resources_in_set[desc_set].push_back(
13551 { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
13552 }
13553
13554 if (set_needs_buffer_sizes[desc_set])
13555 {
13556 uint32_t var_id = ir.increase_bound_by(1);
13557 auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
13558 set_name(var_id, "spvBufferSizeConstants");
13559 set_decoration(var_id, DecorationDescriptorSet, desc_set);
13560 set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding);
13561 resources_in_set[desc_set].push_back(
13562 { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
13563 }
13564 }
13565 }
13566
13567 // Now add inline uniform blocks.
13568 for (uint32_t var_id : inline_block_vars)
13569 {
13570 auto &var = get<SPIRVariable>(var_id);
13571 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
13572 add_resource_name(var_id);
13573 resources_in_set[desc_set].push_back(
13574 { &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 0 });
13575 }
13576
13577 for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
13578 {
13579 auto &resources = resources_in_set[desc_set];
13580 if (resources.empty())
13581 continue;
13582
13583 assert(descriptor_set_is_argument_buffer(desc_set));
13584
13585 uint32_t next_id = ir.increase_bound_by(3);
13586 uint32_t type_id = next_id + 1;
13587 uint32_t ptr_type_id = next_id + 2;
13588 argument_buffer_ids[desc_set] = next_id;
13589
13590 auto &buffer_type = set<SPIRType>(type_id);
13591
13592 buffer_type.basetype = SPIRType::Struct;
13593
13594 if ((argument_buffer_device_storage_mask & (1u << desc_set)) != 0)
13595 {
13596 buffer_type.storage = StorageClassStorageBuffer;
13597 // Make sure the argument buffer gets marked as const device.
13598 set_decoration(next_id, DecorationNonWritable);
13599 // Need to mark the type as a Block to enable this.
13600 set_decoration(type_id, DecorationBlock);
13601 }
13602 else
13603 buffer_type.storage = StorageClassUniform;
13604
13605 set_name(type_id, join("spvDescriptorSetBuffer", desc_set));
13606
13607 auto &ptr_type = set<SPIRType>(ptr_type_id);
13608 ptr_type = buffer_type;
13609 ptr_type.pointer = true;
13610 ptr_type.pointer_depth = 1;
13611 ptr_type.parent_type = type_id;
13612
13613 uint32_t buffer_variable_id = next_id;
13614 set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
13615 set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));
13616
13617 // Ids must be emitted in ID order.
13618 sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
13619 return tie(lhs.index, lhs.basetype) < tie(rhs.index, rhs.basetype);
13620 });
13621
13622 uint32_t member_index = 0;
13623 for (auto &resource : resources)
13624 {
13625 auto &var = *resource.var;
13626 auto &type = get_variable_data_type(var);
13627 string mbr_name = ensure_valid_name(resource.name, "m");
13628 if (resource.plane > 0)
13629 mbr_name += join(plane_name_suffix, resource.plane);
13630 set_member_name(buffer_type.self, member_index, mbr_name);
13631
13632 if (resource.basetype == SPIRType::Sampler && type.basetype != SPIRType::Sampler)
13633 {
13634 // Have to synthesize a sampler type here.
13635
13636 bool type_is_array = !type.array.empty();
13637 uint32_t sampler_type_id = ir.increase_bound_by(type_is_array ? 2 : 1);
13638 auto &new_sampler_type = set<SPIRType>(sampler_type_id);
13639 new_sampler_type.basetype = SPIRType::Sampler;
13640 new_sampler_type.storage = StorageClassUniformConstant;
13641
13642 if (type_is_array)
13643 {
13644 uint32_t sampler_type_array_id = sampler_type_id + 1;
13645 auto &sampler_type_array = set<SPIRType>(sampler_type_array_id);
13646 sampler_type_array = new_sampler_type;
13647 sampler_type_array.array = type.array;
13648 sampler_type_array.array_size_literal = type.array_size_literal;
13649 sampler_type_array.parent_type = sampler_type_id;
13650 buffer_type.member_types.push_back(sampler_type_array_id);
13651 }
13652 else
13653 buffer_type.member_types.push_back(sampler_type_id);
13654 }
13655 else
13656 {
13657 uint32_t binding = get_decoration(var.self, DecorationBinding);
13658 SetBindingPair pair = { desc_set, binding };
13659
13660 if (resource.basetype == SPIRType::Image || resource.basetype == SPIRType::Sampler ||
13661 resource.basetype == SPIRType::SampledImage)
13662 {
13663 // Drop pointer information when we emit the resources into a struct.
13664 buffer_type.member_types.push_back(get_variable_data_type_id(var));
13665 if (resource.plane == 0)
13666 set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
13667 }
13668 else if (buffers_requiring_dynamic_offset.count(pair))
13669 {
13670 // Don't set the qualified name here; we'll define a variable holding the corrected buffer address later.
13671 buffer_type.member_types.push_back(var.basetype);
13672 buffers_requiring_dynamic_offset[pair].second = var.self;
13673 }
13674 else if (inline_uniform_blocks.count(pair))
13675 {
13676 // Put the buffer block itself into the argument buffer.
13677 buffer_type.member_types.push_back(get_variable_data_type_id(var));
13678 set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
13679 }
13680 else
13681 {
13682 // Resources will be declared as pointers not references, so automatically dereference as appropriate.
13683 buffer_type.member_types.push_back(var.basetype);
13684 if (type.array.empty())
13685 set_qualified_name(var.self, join("(*", to_name(buffer_variable_id), ".", mbr_name, ")"));
13686 else
13687 set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
13688 }
13689 }
13690
13691 set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationResourceIndexPrimary,
13692 resource.index);
13693 set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationInterfaceOrigID,
13694 var.self);
13695 member_index++;
13696 }
13697 }
13698 }
13699
activate_argument_buffer_resources()13700 void CompilerMSL::activate_argument_buffer_resources()
13701 {
13702 // For ABI compatibility, force-enable all resources which are part of argument buffers.
13703 ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, const SPIRVariable &) {
13704 if (!has_decoration(self, DecorationDescriptorSet))
13705 return;
13706
13707 uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
13708 if (descriptor_set_is_argument_buffer(desc_set))
13709 active_interface_variables.insert(self);
13710 });
13711 }
13712
using_builtin_array() const13713 bool CompilerMSL::using_builtin_array() const
13714 {
13715 return msl_options.force_native_arrays || is_using_builtin_array;
13716 }
13717