1 /*
2 * Copyright 2016-2019 The Brenwill Workshop Ltd.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "spirv_msl.hpp"
18 #include "GLSL.std.450.h"
19
20 #include <algorithm>
21 #include <assert.h>
22 #include <numeric>
23
24 #ifdef RARCH_INTERNAL
25 #include <retro_miscellaneous.h>
26 #endif
27
28 using namespace spv;
29 using namespace SPIRV_CROSS_NAMESPACE;
30 using namespace std;
31
32 static const uint32_t k_unknown_location = ~0u;
33 static const uint32_t k_unknown_component = ~0u;
34
35 CompilerMSL::CompilerMSL(std::vector<uint32_t> spirv_)
36 : CompilerGLSL(move(spirv_))
37 {
38 }
39
40 CompilerMSL::CompilerMSL(const uint32_t *ir_, size_t word_count)
41 : CompilerGLSL(ir_, word_count)
42 {
43 }
44
45 CompilerMSL::CompilerMSL(const ParsedIR &ir_)
46 : CompilerGLSL(ir_)
47 {
48 }
49
50 CompilerMSL::CompilerMSL(ParsedIR &&ir_)
51 : CompilerGLSL(std::move(ir_))
52 {
53 }
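
// Illustrative usage sketch (not part of this file; assumes the public API declared in
// spirv_msl.hpp and the default spirv_cross namespace): construct the compiler from raw
// SPIR-V words, then generate MSL source. The SPIR-V loader is a hypothetical helper.
//
//   std::vector<uint32_t> spirv_words = load_spirv_words("shader.spv"); // hypothetical
//   spirv_cross::CompilerMSL msl(std::move(spirv_words));
//   std::string msl_source = msl.compile();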
54
55 void CompilerMSL::add_msl_vertex_attribute(const MSLVertexAttr &va)
56 {
57 vtx_attrs_by_location[va.location] = va;
58 if (va.builtin != BuiltInMax && !vtx_attrs_by_builtin.count(va.builtin))
59 vtx_attrs_by_builtin[va.builtin] = va;
60 }
61
62 void CompilerMSL::add_msl_resource_binding(const MSLResourceBinding &binding)
63 {
64 StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
65 resource_bindings[tuple] = { binding, false };
66 }
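
// Hedged example of registering an explicit resource remap before compile(). Only the
// fields referenced above (stage, desc_set, binding) are shown; the MSL index fields
// declared in spirv_msl.hpp would also be set. Values are illustrative.
//
//   MSLResourceBinding rb;
//   rb.stage = spv::ExecutionModelFragment;
//   rb.desc_set = 0;
//   rb.binding = 1;
//   compiler.add_msl_resource_binding(rb);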
67
68 void CompilerMSL::add_discrete_descriptor_set(uint32_t desc_set)
69 {
70 if (desc_set < kMaxArgumentBuffers)
71 argument_buffer_discrete_mask |= 1u << desc_set;
72 }
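
// For example, add_discrete_descriptor_set(1) sets bit 1 of argument_buffer_discrete_mask,
// so descriptor set 1 keeps discrete resource bindings even when argument buffers are
// enabled; sets at or above kMaxArgumentBuffers are ignored.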
73
74 bool CompilerMSL::is_msl_vertex_attribute_used(uint32_t location)
75 {
76 return vtx_attrs_in_use.count(location) != 0;
77 }
78
79 bool CompilerMSL::is_msl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding)
80 {
81 StageSetBinding tuple = { model, desc_set, binding };
82 auto itr = resource_bindings.find(tuple);
83 return itr != end(resource_bindings) && itr->second.second;
84 }
85
86 uint32_t CompilerMSL::get_automatic_msl_resource_binding(uint32_t id) const
87 {
88 return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexPrimary);
89 }
90
91 uint32_t CompilerMSL::get_automatic_msl_resource_binding_secondary(uint32_t id) const
92 {
93 return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexSecondary);
94 }
95
96 void CompilerMSL::set_fragment_output_components(uint32_t location, uint32_t components)
97 {
98 fragment_output_components[location] = components;
99 }
100
101 void CompilerMSL::build_implicit_builtins()
102 {
103 bool need_sample_pos = active_input_builtins.get(BuiltInSamplePosition);
104 bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex;
105 bool need_tesc_params = get_execution_model() == ExecutionModelTessellationControl;
106 bool need_subgroup_mask =
107 active_input_builtins.get(BuiltInSubgroupEqMask) || active_input_builtins.get(BuiltInSubgroupGeMask) ||
108 active_input_builtins.get(BuiltInSubgroupGtMask) || active_input_builtins.get(BuiltInSubgroupLeMask) ||
109 active_input_builtins.get(BuiltInSubgroupLtMask);
110 bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(BuiltInSubgroupGeMask) ||
111 active_input_builtins.get(BuiltInSubgroupGtMask));
112 bool need_multiview = get_execution_model() == ExecutionModelVertex &&
113 (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex));
114 if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
115 need_multiview || needs_subgroup_invocation_id)
116 {
117 bool has_frag_coord = false;
118 bool has_sample_id = false;
119 bool has_vertex_idx = false;
120 bool has_base_vertex = false;
121 bool has_instance_idx = false;
122 bool has_base_instance = false;
123 bool has_invocation_id = false;
124 bool has_primitive_id = false;
125 bool has_subgroup_invocation_id = false;
126 bool has_subgroup_size = false;
127 bool has_view_idx = false;
128
129 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
130 if (var.storage != StorageClassInput || !ir.meta[var.self].decoration.builtin)
131 return;
132
133 BuiltIn builtin = ir.meta[var.self].decoration.builtin_type;
134 if (need_subpass_input && builtin == BuiltInFragCoord)
135 {
136 builtin_frag_coord_id = var.self;
137 has_frag_coord = true;
138 }
139
140 if (need_sample_pos && builtin == BuiltInSampleId)
141 {
142 builtin_sample_id_id = var.self;
143 has_sample_id = true;
144 }
145
146 if (need_vertex_params)
147 {
148 switch (builtin)
149 {
150 case BuiltInVertexIndex:
151 builtin_vertex_idx_id = var.self;
152 has_vertex_idx = true;
153 break;
154 case BuiltInBaseVertex:
155 builtin_base_vertex_id = var.self;
156 has_base_vertex = true;
157 break;
158 case BuiltInInstanceIndex:
159 builtin_instance_idx_id = var.self;
160 has_instance_idx = true;
161 break;
162 case BuiltInBaseInstance:
163 builtin_base_instance_id = var.self;
164 has_base_instance = true;
165 break;
166 default:
167 break;
168 }
169 }
170
171 if (need_tesc_params)
172 {
173 switch (builtin)
174 {
175 case BuiltInInvocationId:
176 builtin_invocation_id_id = var.self;
177 has_invocation_id = true;
178 break;
179 case BuiltInPrimitiveId:
180 builtin_primitive_id_id = var.self;
181 has_primitive_id = true;
182 break;
183 default:
184 break;
185 }
186 }
187
188 if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId)
189 {
190 builtin_subgroup_invocation_id_id = var.self;
191 has_subgroup_invocation_id = true;
192 }
193
194 if (need_subgroup_ge_mask && builtin == BuiltInSubgroupSize)
195 {
196 builtin_subgroup_size_id = var.self;
197 has_subgroup_size = true;
198 }
199
200 if (need_multiview)
201 {
202 if (builtin == BuiltInInstanceIndex)
203 {
204 // The view index here is derived from the instance index.
205 builtin_instance_idx_id = var.self;
206 has_instance_idx = true;
207 }
208
209 if (builtin == BuiltInViewIndex)
210 {
211 builtin_view_idx_id = var.self;
212 has_view_idx = true;
213 }
214 }
215 });
216
217 if (!has_frag_coord && need_subpass_input)
218 {
219 uint32_t offset = ir.increase_bound_by(3);
220 uint32_t type_id = offset;
221 uint32_t type_ptr_id = offset + 1;
222 uint32_t var_id = offset + 2;
223
224 // Create gl_FragCoord.
225 SPIRType vec4_type;
226 vec4_type.basetype = SPIRType::Float;
227 vec4_type.width = 32;
228 vec4_type.vecsize = 4;
229 set<SPIRType>(type_id, vec4_type);
230
231 SPIRType vec4_type_ptr;
232 vec4_type_ptr = vec4_type;
233 vec4_type_ptr.pointer = true;
234 vec4_type_ptr.parent_type = type_id;
235 vec4_type_ptr.storage = StorageClassInput;
236 auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
237 ptr_type.self = type_id;
238
239 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
240 set_decoration(var_id, DecorationBuiltIn, BuiltInFragCoord);
241 builtin_frag_coord_id = var_id;
242 mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var_id);
243 }
244
245 if (!has_sample_id && need_sample_pos)
246 {
247 uint32_t offset = ir.increase_bound_by(3);
248 uint32_t type_id = offset;
249 uint32_t type_ptr_id = offset + 1;
250 uint32_t var_id = offset + 2;
251
252 // Create gl_SampleID.
253 SPIRType uint_type;
254 uint_type.basetype = SPIRType::UInt;
255 uint_type.width = 32;
256 set<SPIRType>(type_id, uint_type);
257
258 SPIRType uint_type_ptr;
259 uint_type_ptr = uint_type;
260 uint_type_ptr.pointer = true;
261 uint_type_ptr.parent_type = type_id;
262 uint_type_ptr.storage = StorageClassInput;
263 auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
264 ptr_type.self = type_id;
265
266 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
267 set_decoration(var_id, DecorationBuiltIn, BuiltInSampleId);
268 builtin_sample_id_id = var_id;
269 mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var_id);
270 }
271
272 if ((need_vertex_params && (!has_vertex_idx || !has_base_vertex || !has_instance_idx || !has_base_instance)) ||
273 (need_multiview && (!has_instance_idx || !has_view_idx)))
274 {
275 uint32_t offset = ir.increase_bound_by(2);
276 uint32_t type_id = offset;
277 uint32_t type_ptr_id = offset + 1;
278
279 SPIRType uint_type;
280 uint_type.basetype = SPIRType::UInt;
281 uint_type.width = 32;
282 set<SPIRType>(type_id, uint_type);
283
284 SPIRType uint_type_ptr;
285 uint_type_ptr = uint_type;
286 uint_type_ptr.pointer = true;
287 uint_type_ptr.parent_type = type_id;
288 uint_type_ptr.storage = StorageClassInput;
289 auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
290 ptr_type.self = type_id;
291
292 if (need_vertex_params && !has_vertex_idx)
293 {
294 uint32_t var_id = ir.increase_bound_by(1);
295
296 // Create gl_VertexIndex.
297 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
298 set_decoration(var_id, DecorationBuiltIn, BuiltInVertexIndex);
299 builtin_vertex_idx_id = var_id;
300 mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var_id);
301 }
302
303 if (need_vertex_params && !has_base_vertex)
304 {
305 uint32_t var_id = ir.increase_bound_by(1);
306
307 // Create gl_BaseVertex.
308 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
309 set_decoration(var_id, DecorationBuiltIn, BuiltInBaseVertex);
310 builtin_base_vertex_id = var_id;
311 mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var_id);
312 }
313
314 if (!has_instance_idx) // Needed by both multiview and tessellation
315 {
316 uint32_t var_id = ir.increase_bound_by(1);
317
318 // Create gl_InstanceIndex.
319 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
320 set_decoration(var_id, DecorationBuiltIn, BuiltInInstanceIndex);
321 builtin_instance_idx_id = var_id;
322 mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var_id);
323
324 if (need_multiview)
325 {
326 // Multiview shaders are not allowed to write to gl_Layer, ostensibly because
327 // it is implicitly written from gl_ViewIndex, but we have to do that explicitly.
328 // Note that we can't just abuse gl_ViewIndex for this purpose: it's an input, but
329 // gl_Layer is an output in vertex-pipeline shaders.
330 uint32_t type_ptr_out_id = ir.increase_bound_by(2);
331 SPIRType uint_type_ptr_out;
332 uint_type_ptr_out = uint_type;
333 uint_type_ptr_out.pointer = true;
334 uint_type_ptr_out.parent_type = type_id;
335 uint_type_ptr_out.storage = StorageClassOutput;
336 auto &ptr_out_type = set<SPIRType>(type_ptr_out_id, uint_type_ptr_out);
337 ptr_out_type.self = type_id;
338 var_id = type_ptr_out_id + 1;
339 set<SPIRVariable>(var_id, type_ptr_out_id, StorageClassOutput);
340 set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
341 builtin_layer_id = var_id;
342 mark_implicit_builtin(StorageClassOutput, BuiltInLayer, var_id);
343 }
344 }
345
346 if (need_vertex_params && !has_base_instance)
347 {
348 uint32_t var_id = ir.increase_bound_by(1);
349
350 // Create gl_BaseInstance.
351 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
352 set_decoration(var_id, DecorationBuiltIn, BuiltInBaseInstance);
353 builtin_base_instance_id = var_id;
354 mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var_id);
355 }
356
357 if (need_multiview && !has_view_idx)
358 {
359 uint32_t var_id = ir.increase_bound_by(1);
360
361 // Create gl_ViewIndex.
362 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
363 set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
364 builtin_view_idx_id = var_id;
365 mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
366 }
367 }
368
369 if (need_tesc_params && (!has_invocation_id || !has_primitive_id))
370 {
371 uint32_t offset = ir.increase_bound_by(2);
372 uint32_t type_id = offset;
373 uint32_t type_ptr_id = offset + 1;
374
375 SPIRType uint_type;
376 uint_type.basetype = SPIRType::UInt;
377 uint_type.width = 32;
378 set<SPIRType>(type_id, uint_type);
379
380 SPIRType uint_type_ptr;
381 uint_type_ptr = uint_type;
382 uint_type_ptr.pointer = true;
383 uint_type_ptr.parent_type = type_id;
384 uint_type_ptr.storage = StorageClassInput;
385 auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
386 ptr_type.self = type_id;
387
388 if (!has_invocation_id)
389 {
390 uint32_t var_id = ir.increase_bound_by(1);
391
392 // Create gl_InvocationID.
393 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
394 set_decoration(var_id, DecorationBuiltIn, BuiltInInvocationId);
395 builtin_invocation_id_id = var_id;
396 mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var_id);
397 }
398
399 if (!has_primitive_id)
400 {
401 uint32_t var_id = ir.increase_bound_by(1);
402
403 // Create gl_PrimitiveID.
404 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
405 set_decoration(var_id, DecorationBuiltIn, BuiltInPrimitiveId);
406 builtin_primitive_id_id = var_id;
407 mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var_id);
408 }
409 }
410
411 if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id))
412 {
413 uint32_t offset = ir.increase_bound_by(3);
414 uint32_t type_id = offset;
415 uint32_t type_ptr_id = offset + 1;
416 uint32_t var_id = offset + 2;
417
418 // Create gl_SubgroupInvocationID.
419 SPIRType uint_type;
420 uint_type.basetype = SPIRType::UInt;
421 uint_type.width = 32;
422 set<SPIRType>(type_id, uint_type);
423
424 SPIRType uint_type_ptr;
425 uint_type_ptr = uint_type;
426 uint_type_ptr.pointer = true;
427 uint_type_ptr.parent_type = type_id;
428 uint_type_ptr.storage = StorageClassInput;
429 auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
430 ptr_type.self = type_id;
431
432 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
433 set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupLocalInvocationId);
434 builtin_subgroup_invocation_id_id = var_id;
435 mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var_id);
436 }
437
438 if (!has_subgroup_size && need_subgroup_ge_mask)
439 {
440 uint32_t offset = ir.increase_bound_by(3);
441 uint32_t type_id = offset;
442 uint32_t type_ptr_id = offset + 1;
443 uint32_t var_id = offset + 2;
444
445 // Create gl_SubgroupSize.
446 SPIRType uint_type;
447 uint_type.basetype = SPIRType::UInt;
448 uint_type.width = 32;
449 set<SPIRType>(type_id, uint_type);
450
451 SPIRType uint_type_ptr;
452 uint_type_ptr = uint_type;
453 uint_type_ptr.pointer = true;
454 uint_type_ptr.parent_type = type_id;
455 uint_type_ptr.storage = StorageClassInput;
456 auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
457 ptr_type.self = type_id;
458
459 set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
460 set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupSize);
461 builtin_subgroup_size_id = var_id;
462 mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id);
463 }
464 }
465
466 if (needs_swizzle_buffer_def)
467 {
468 uint32_t var_id = build_constant_uint_array_pointer();
469 set_name(var_id, "spvSwizzleConstants");
470 // This should never match anything.
471 set_decoration(var_id, DecorationDescriptorSet, kSwizzleBufferBinding);
472 set_decoration(var_id, DecorationBinding, msl_options.swizzle_buffer_index);
473 set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.swizzle_buffer_index);
474 swizzle_buffer_id = var_id;
475 }
476
477 if (!buffers_requiring_array_length.empty())
478 {
479 uint32_t var_id = build_constant_uint_array_pointer();
480 set_name(var_id, "spvBufferSizeConstants");
481 // This should never match anything.
482 set_decoration(var_id, DecorationDescriptorSet, kBufferSizeBufferBinding);
483 set_decoration(var_id, DecorationBinding, msl_options.buffer_size_buffer_index);
484 set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.buffer_size_buffer_index);
485 buffer_size_buffer_id = var_id;
486 }
487
488 if (needs_view_mask_buffer())
489 {
490 uint32_t var_id = build_constant_uint_array_pointer();
491 set_name(var_id, "spvViewMask");
492 // This should never match anything.
493 set_decoration(var_id, DecorationDescriptorSet, ~(4u));
494 set_decoration(var_id, DecorationBinding, msl_options.view_mask_buffer_index);
495 set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.view_mask_buffer_index);
496 view_mask_buffer_id = var_id;
497 }
498 }
499
500 void CompilerMSL::mark_implicit_builtin(StorageClass storage, BuiltIn builtin, uint32_t id)
501 {
502 Bitset *active_builtins = nullptr;
503 switch (storage)
504 {
505 case StorageClassInput:
506 active_builtins = &active_input_builtins;
507 break;
508
509 case StorageClassOutput:
510 active_builtins = &active_output_builtins;
511 break;
512
513 default:
514 break;
515 }
516
517 assert(active_builtins != nullptr);
518 active_builtins->set(builtin);
519 get_entry_point().interface_variables.push_back(id);
520 }
521
522 uint32_t CompilerMSL::build_constant_uint_array_pointer()
523 {
524 uint32_t offset = ir.increase_bound_by(4);
525 uint32_t type_id = offset;
526 uint32_t type_ptr_id = offset + 1;
527 uint32_t type_ptr_ptr_id = offset + 2;
528 uint32_t var_id = offset + 3;
529
530 // Create a buffer to hold extra data, including the swizzle constants.
531 SPIRType uint_type;
532 uint_type.basetype = SPIRType::UInt;
533 uint_type.width = 32;
534 set<SPIRType>(type_id, uint_type);
535
536 SPIRType uint_type_pointer = uint_type;
537 uint_type_pointer.pointer = true;
538 uint_type_pointer.pointer_depth = 1;
539 uint_type_pointer.parent_type = type_id;
540 uint_type_pointer.storage = StorageClassUniform;
541 set<SPIRType>(type_ptr_id, uint_type_pointer);
542 set_decoration(type_ptr_id, DecorationArrayStride, 4);
543
544 SPIRType uint_type_pointer2 = uint_type_pointer;
545 uint_type_pointer2.pointer_depth++;
546 uint_type_pointer2.parent_type = type_ptr_id;
547 set<SPIRType>(type_ptr_ptr_id, uint_type_pointer2);
548
549 set<SPIRVariable>(var_id, type_ptr_ptr_id, StorageClassUniformConstant);
550 return var_id;
551 }
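
// Rough shape of what this variable becomes in the final MSL (an assumption based on the
// callers in build_implicit_builtins() above): an entry-point argument along the lines of
// "constant uint* spvSwizzleConstants [[buffer(N)]]", where N is the configured index.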
552
553 static string create_sampler_address(const char *prefix, MSLSamplerAddress addr)
554 {
555 switch (addr)
556 {
557 case MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE:
558 return join(prefix, "address::clamp_to_edge");
559 case MSL_SAMPLER_ADDRESS_CLAMP_TO_ZERO:
560 return join(prefix, "address::clamp_to_zero");
561 case MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER:
562 return join(prefix, "address::clamp_to_border");
563 case MSL_SAMPLER_ADDRESS_REPEAT:
564 return join(prefix, "address::repeat");
565 case MSL_SAMPLER_ADDRESS_MIRRORED_REPEAT:
566 return join(prefix, "address::mirrored_repeat");
567 default:
568 SPIRV_CROSS_THROW("Invalid sampler addressing mode.");
569 }
570 }
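
// For reference, the helper simply prepends the prefix to the MSL address mode, e.g.
// create_sampler_address("s_", MSL_SAMPLER_ADDRESS_REPEAT) returns "s_address::repeat",
// and an empty prefix gives "address::repeat".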
571
572 SPIRType &CompilerMSL::get_stage_in_struct_type()
573 {
574 auto &si_var = get<SPIRVariable>(stage_in_var_id);
575 return get_variable_data_type(si_var);
576 }
577
578 SPIRType &CompilerMSL::get_stage_out_struct_type()
579 {
580 auto &so_var = get<SPIRVariable>(stage_out_var_id);
581 return get_variable_data_type(so_var);
582 }
583
584 SPIRType &CompilerMSL::get_patch_stage_in_struct_type()
585 {
586 auto &si_var = get<SPIRVariable>(patch_stage_in_var_id);
587 return get_variable_data_type(si_var);
588 }
589
590 SPIRType &CompilerMSL::get_patch_stage_out_struct_type()
591 {
592 auto &so_var = get<SPIRVariable>(patch_stage_out_var_id);
593 return get_variable_data_type(so_var);
594 }
595
596 std::string CompilerMSL::get_tess_factor_struct_name()
597 {
598 if (get_entry_point().flags.get(ExecutionModeTriangles))
599 return "MTLTriangleTessellationFactorsHalf";
600 return "MTLQuadTessellationFactorsHalf";
601 }
602
603 void CompilerMSL::emit_entry_point_declarations()
604 {
605 // FIXME: Get test coverage here ...
606
607 // Emit constexpr samplers here.
608 for (auto &samp : constexpr_samplers_by_id)
609 {
610 auto &var = get<SPIRVariable>(samp.first);
611 auto &type = get<SPIRType>(var.basetype);
612 if (type.basetype == SPIRType::Sampler)
613 add_resource_name(samp.first);
614
615 SmallVector<string> args;
616 auto &s = samp.second;
617
618 if (s.coord != MSL_SAMPLER_COORD_NORMALIZED)
619 args.push_back("coord::pixel");
620
621 if (s.min_filter == s.mag_filter)
622 {
623 if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
624 args.push_back("filter::linear");
625 }
626 else
627 {
628 if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
629 args.push_back("min_filter::linear");
630 if (s.mag_filter != MSL_SAMPLER_FILTER_NEAREST)
631 args.push_back("mag_filter::linear");
632 }
633
634 switch (s.mip_filter)
635 {
636 case MSL_SAMPLER_MIP_FILTER_NONE:
637 // Default
638 break;
639 case MSL_SAMPLER_MIP_FILTER_NEAREST:
640 args.push_back("mip_filter::nearest");
641 break;
642 case MSL_SAMPLER_MIP_FILTER_LINEAR:
643 args.push_back("mip_filter::linear");
644 break;
645 default:
646 SPIRV_CROSS_THROW("Invalid mip filter.");
647 }
648
649 if (s.s_address == s.t_address && s.s_address == s.r_address)
650 {
651 if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
652 args.push_back(create_sampler_address("", s.s_address));
653 }
654 else
655 {
656 if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
657 args.push_back(create_sampler_address("s_", s.s_address));
658 if (s.t_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
659 args.push_back(create_sampler_address("t_", s.t_address));
660 if (s.r_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
661 args.push_back(create_sampler_address("r_", s.r_address));
662 }
663
664 if (s.compare_enable)
665 {
666 switch (s.compare_func)
667 {
668 case MSL_SAMPLER_COMPARE_FUNC_ALWAYS:
669 args.push_back("compare_func::always");
670 break;
671 case MSL_SAMPLER_COMPARE_FUNC_NEVER:
672 args.push_back("compare_func::never");
673 break;
674 case MSL_SAMPLER_COMPARE_FUNC_EQUAL:
675 args.push_back("compare_func::equal");
676 break;
677 case MSL_SAMPLER_COMPARE_FUNC_NOT_EQUAL:
678 args.push_back("compare_func::not_equal");
679 break;
680 case MSL_SAMPLER_COMPARE_FUNC_LESS:
681 args.push_back("compare_func::less");
682 break;
683 case MSL_SAMPLER_COMPARE_FUNC_LESS_EQUAL:
684 args.push_back("compare_func::less_equal");
685 break;
686 case MSL_SAMPLER_COMPARE_FUNC_GREATER:
687 args.push_back("compare_func::greater");
688 break;
689 case MSL_SAMPLER_COMPARE_FUNC_GREATER_EQUAL:
690 args.push_back("compare_func::greater_equal");
691 break;
692 default:
693 SPIRV_CROSS_THROW("Invalid sampler compare function.");
694 }
695 }
696
697 if (s.s_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER || s.t_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER ||
698 s.r_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER)
699 {
700 switch (s.border_color)
701 {
702 case MSL_SAMPLER_BORDER_COLOR_OPAQUE_BLACK:
703 args.push_back("border_color::opaque_black");
704 break;
705 case MSL_SAMPLER_BORDER_COLOR_OPAQUE_WHITE:
706 args.push_back("border_color::opaque_white");
707 break;
708 case MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK:
709 args.push_back("border_color::transparent_black");
710 break;
711 default:
712 SPIRV_CROSS_THROW("Invalid sampler border color.");
713 }
714 }
715
716 if (s.anisotropy_enable)
717 args.push_back(join("max_anisotropy(", s.max_anisotropy, ")"));
718 if (s.lod_clamp_enable)
719 {
720 args.push_back(join("lod_clamp(", convert_to_string(s.lod_clamp_min, current_locale_radix_character), ", ",
721 convert_to_string(s.lod_clamp_max, current_locale_radix_character), ")"));
722 }
723
724 statement("constexpr sampler ",
725 type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
726 "(", merge(args), ");");
727 }
728
729 // Emit buffer arrays here.
730 for (uint32_t array_id : buffer_arrays)
731 {
732 const auto &var = get<SPIRVariable>(array_id);
733 const auto &type = get_variable_data_type(var);
734 string name = to_name(array_id);
735 statement(get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + name + "[] =");
736 begin_scope();
737 for (uint32_t i = 0; i < type.array[0]; ++i)
738 statement(name + "_" + convert_to_string(i) + ",");
739 end_scope_decl();
740 statement_no_indent("");
741 }
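
// Sketch of the declaration emitted by the loop above for a two-element buffer array
// (the names and the "SSBO" block type are illustrative):
//
//   device SSBO* u_buffers[] =
//   {
//       u_buffers_0,
//       u_buffers_1,
//   };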
742 // For some reason, without this, we end up emitting the arrays twice.
743 buffer_arrays.clear();
744 }
745
746 string CompilerMSL::compile()
747 {
748 // Do not deal with GLES-isms like precision, older extensions and such.
749 options.vulkan_semantics = true;
750 options.es = false;
751 options.version = 450;
752 backend.null_pointer_literal = "nullptr";
753 backend.float_literal_suffix = false;
754 backend.uint32_t_literal_suffix = true;
755 backend.int16_t_literal_suffix = "";
756 backend.uint16_t_literal_suffix = "u";
757 backend.basic_int_type = "int";
758 backend.basic_uint_type = "uint";
759 backend.basic_int8_type = "char";
760 backend.basic_uint8_type = "uchar";
761 backend.basic_int16_type = "short";
762 backend.basic_uint16_type = "ushort";
763 backend.discard_literal = "discard_fragment()";
764 backend.swizzle_is_function = false;
765 backend.shared_is_implied = false;
766 backend.use_initializer_list = true;
767 backend.use_typed_initializer_list = true;
768 backend.native_row_major_matrix = false;
769 backend.unsized_array_supported = false;
770 backend.can_declare_arrays_inline = false;
771 backend.can_return_array = false;
772 backend.boolean_mix_support = false;
773 backend.allow_truncated_access_chain = true;
774 backend.array_is_value_type = false;
775 backend.comparison_image_samples_scalar = true;
776 backend.native_pointers = true;
777 backend.nonuniform_qualifier = "";
778 backend.support_small_type_sampling_result = true;
779
780 capture_output_to_buffer = msl_options.capture_output_to_buffer;
781 is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;
782
783 // Initialize array here rather than constructor, MSVC 2013 workaround.
784 for (auto &id : next_metal_resource_ids)
785 id = 0;
786
787 fixup_type_alias();
788 replace_illegal_names();
789
790 struct_member_padding.clear();
791
792 build_function_control_flow_graphs_and_analyze();
793 update_active_builtins();
794 analyze_image_and_sampler_usage();
795 analyze_sampled_image_usage();
796 preprocess_op_codes();
797 build_implicit_builtins();
798
799 fixup_image_load_store_access();
800
801 set_enabled_interface_variables(get_active_interface_variables());
802 if (swizzle_buffer_id)
803 active_interface_variables.insert(swizzle_buffer_id);
804 if (buffer_size_buffer_id)
805 active_interface_variables.insert(buffer_size_buffer_id);
806 if (view_mask_buffer_id)
807 active_interface_variables.insert(view_mask_buffer_id);
808 if (builtin_layer_id)
809 active_interface_variables.insert(builtin_layer_id);
810
811 // Create structs to hold input, output and uniform variables.
812 // Do output first to ensure out. is declared at top of entry function.
813 qual_pos_var_name = "";
814 stage_out_var_id = add_interface_block(StorageClassOutput);
815 patch_stage_out_var_id = add_interface_block(StorageClassOutput, true);
816 stage_in_var_id = add_interface_block(StorageClassInput);
817 if (get_execution_model() == ExecutionModelTessellationEvaluation)
818 patch_stage_in_var_id = add_interface_block(StorageClassInput, true);
819
820 if (get_execution_model() == ExecutionModelTessellationControl)
821 stage_out_ptr_var_id = add_interface_block_pointer(stage_out_var_id, StorageClassOutput);
822 if (is_tessellation_shader())
823 stage_in_ptr_var_id = add_interface_block_pointer(stage_in_var_id, StorageClassInput);
824
825 // Metal vertex functions that define no output must disable rasterization and return void.
826 if (!stage_out_var_id)
827 is_rasterization_disabled = true;
828
829 // Convert the use of global variables to recursively-passed function parameters
830 localize_global_variables();
831 extract_global_variables_from_functions();
832
833 // Mark any non-stage-in structs to be tightly packed.
834 mark_packable_structs();
835 reorder_type_alias();
836
837 // Add fixup hooks required by shader inputs and outputs. This needs to happen before
838 // the loop, so the hooks aren't added multiple times.
839 fix_up_shader_inputs_outputs();
840
841 // If we are using argument buffers, we create argument buffer structures for them here.
842 // These buffers will be used in the entry point, not the individual resources.
843 if (msl_options.argument_buffers)
844 {
845 if (!msl_options.supports_msl_version(2, 0))
846 SPIRV_CROSS_THROW("Argument buffers can only be used with MSL 2.0 and up.");
847 analyze_argument_buffers();
848 }
849
850 uint32_t pass_count = 0;
851 do
852 {
853 if (pass_count >= 3)
854 SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");
855
856 reset();
857
858 // Start bindings at zero.
859 next_metal_resource_index_buffer = 0;
860 next_metal_resource_index_texture = 0;
861 next_metal_resource_index_sampler = 0;
862 for (auto &id : next_metal_resource_ids)
863 id = 0;
864
865 // Move constructor for this type is broken on GCC 4.9 ...
866 buffer.reset();
867
868 emit_header();
869 emit_specialization_constants_and_structs();
870 emit_resources();
871 emit_custom_functions();
872 emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());
873
874 pass_count++;
875 } while (is_forcing_recompilation());
876
877 return buffer.str();
878 }
879
880 // Register the need to output any custom functions.
881 void CompilerMSL::preprocess_op_codes()
882 {
883 OpCodePreprocessor preproc(*this);
884 traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), preproc);
885
886 suppress_missing_prototypes = preproc.suppress_missing_prototypes;
887
888 if (preproc.uses_atomics)
889 {
890 add_header_line("#include <metal_atomic>");
891 add_pragma_line("#pragma clang diagnostic ignored \"-Wunused-variable\"");
892 }
893
894 // Metal vertex functions that write to resources must disable rasterization and return void.
895 if (preproc.uses_resource_write)
896 is_rasterization_disabled = true;
897
898 // Tessellation control shaders are run as compute functions in Metal, and so
899 // must capture their output to a buffer.
900 if (get_execution_model() == ExecutionModelTessellationControl)
901 {
902 is_rasterization_disabled = true;
903 capture_output_to_buffer = true;
904 }
905
906 if (preproc.needs_subgroup_invocation_id)
907 needs_subgroup_invocation_id = true;
908 }
909
910 // Move the Private and Workgroup global variables to the entry function.
911 // Non-constant variables cannot have global scope in Metal.
912 void CompilerMSL::localize_global_variables()
913 {
914 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
915 auto iter = global_variables.begin();
916 while (iter != global_variables.end())
917 {
918 uint32_t v_id = *iter;
919 auto &var = get<SPIRVariable>(v_id);
920 if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup)
921 {
922 if (!variable_is_lut(var))
923 entry_func.add_local_variable(v_id);
924 iter = global_variables.erase(iter);
925 }
926 else
927 iter++;
928 }
929 }
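
// Effect sketch (illustrative): a module-scope Private variable such as "float4 accum;"
// is removed from the global list and instead declared as a local variable inside the
// entry function, unless variable_is_lut() identifies it as a constant lookup table.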
930
931 // For any global variable accessed directly by a function,
932 // extract that variable and add it as an argument to that function.
933 void CompilerMSL::extract_global_variables_from_functions()
934 {
935 // Uniforms
936 unordered_set<uint32_t> global_var_ids;
937 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
938 if (var.storage == StorageClassInput || var.storage == StorageClassOutput ||
939 var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
940 var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)
941 {
942 global_var_ids.insert(var.self);
943 }
944 });
945
946 // Local vars that are declared in the main function and accessed directly by a function
947 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
948 for (auto &var : entry_func.local_variables)
949 if (get<SPIRVariable>(var).storage != StorageClassFunction)
950 global_var_ids.insert(var);
951
952 std::set<uint32_t> added_arg_ids;
953 unordered_set<uint32_t> processed_func_ids;
954 extract_global_variables_from_function(ir.default_entry_point, added_arg_ids, global_var_ids, processed_func_ids);
955 }
956
957 // MSL does not support the use of global variables for shader input content.
958 // For any global variable accessed directly by the specified function, extract that variable,
959 // add it as an argument to that function, and add the arg to the added_arg_ids collection.
960 void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::set<uint32_t> &added_arg_ids,
961 unordered_set<uint32_t> &global_var_ids,
962 unordered_set<uint32_t> &processed_func_ids)
963 {
964 // Avoid processing a function more than once
965 if (processed_func_ids.find(func_id) != processed_func_ids.end())
966 {
967 // Return function global variables
968 added_arg_ids = function_global_vars[func_id];
969 return;
970 }
971
972 processed_func_ids.insert(func_id);
973
974 auto &func = get<SPIRFunction>(func_id);
975
976 // Recursively establish global args added to functions on which we depend.
977 for (auto block : func.blocks)
978 {
979 auto &b = get<SPIRBlock>(block);
980 for (auto &i : b.ops)
981 {
982 auto ops = stream(i);
983 auto op = static_cast<Op>(i.op);
984
985 switch (op)
986 {
987 case OpLoad:
988 case OpInBoundsAccessChain:
989 case OpAccessChain:
990 case OpPtrAccessChain:
991 case OpArrayLength:
992 {
993 uint32_t base_id = ops[2];
994 if (global_var_ids.find(base_id) != global_var_ids.end())
995 added_arg_ids.insert(base_id);
996
997 auto &type = get<SPIRType>(ops[0]);
998 if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData)
999 {
1000 // Implicitly reads gl_FragCoord.
1001 assert(builtin_frag_coord_id != 0);
1002 added_arg_ids.insert(builtin_frag_coord_id);
1003 }
1004
1005 break;
1006 }
1007
1008 case OpFunctionCall:
1009 {
1010 // First see if any of the function call args are globals
1011 for (uint32_t arg_idx = 3; arg_idx < i.length; arg_idx++)
1012 {
1013 uint32_t arg_id = ops[arg_idx];
1014 if (global_var_ids.find(arg_id) != global_var_ids.end())
1015 added_arg_ids.insert(arg_id);
1016 }
1017
1018 // Then recurse into the function itself to extract globals used internally in the function
1019 uint32_t inner_func_id = ops[2];
1020 std::set<uint32_t> inner_func_args;
1021 extract_global_variables_from_function(inner_func_id, inner_func_args, global_var_ids,
1022 processed_func_ids);
1023 added_arg_ids.insert(inner_func_args.begin(), inner_func_args.end());
1024 break;
1025 }
1026
1027 case OpStore:
1028 {
1029 uint32_t base_id = ops[0];
1030 if (global_var_ids.find(base_id) != global_var_ids.end())
1031 added_arg_ids.insert(base_id);
1032 break;
1033 }
1034
1035 case OpSelect:
1036 {
1037 uint32_t base_id = ops[3];
1038 if (global_var_ids.find(base_id) != global_var_ids.end())
1039 added_arg_ids.insert(base_id);
1040 base_id = ops[4];
1041 if (global_var_ids.find(base_id) != global_var_ids.end())
1042 added_arg_ids.insert(base_id);
1043 break;
1044 }
1045
1046 default:
1047 break;
1048 }
1049
1050 // TODO: Add all other operations which can affect memory.
1051 // We should consider a more unified system here to reduce boiler-plate.
1052 // This kind of analysis is done in several places ...
1053 }
1054 }
1055
1056 function_global_vars[func_id] = added_arg_ids;
1057
1058 // Add the global variables as arguments to the function
1059 if (func_id != ir.default_entry_point)
1060 {
1061 bool added_in = false;
1062 bool added_out = false;
1063 for (uint32_t arg_id : added_arg_ids)
1064 {
1065 auto &var = get<SPIRVariable>(arg_id);
1066 uint32_t type_id = var.basetype;
1067 auto *p_type = &get<SPIRType>(type_id);
1068 BuiltIn bi_type = BuiltIn(get_decoration(arg_id, DecorationBuiltIn));
1069
1070 if (((is_tessellation_shader() && var.storage == StorageClassInput) ||
1071 (get_execution_model() == ExecutionModelTessellationControl && var.storage == StorageClassOutput)) &&
1072 !(has_decoration(arg_id, DecorationPatch) || is_patch_block(*p_type)) &&
1073 (!is_builtin_variable(var) || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
1074 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
1075 p_type->basetype == SPIRType::Struct))
1076 {
1077 // Tessellation control shaders see inputs and per-vertex outputs as arrays.
1078 // Similarly, tessellation evaluation shaders see per-vertex inputs as arrays.
1079 // We collected them into a structure; we must pass the array of this
1080 // structure to the function.
1081 std::string name;
1082 if (var.storage == StorageClassInput)
1083 {
1084 if (added_in)
1085 continue;
1086 name = input_wg_var_name;
1087 arg_id = stage_in_ptr_var_id;
1088 added_in = true;
1089 }
1090 else if (var.storage == StorageClassOutput)
1091 {
1092 if (added_out)
1093 continue;
1094 name = "gl_out";
1095 arg_id = stage_out_ptr_var_id;
1096 added_out = true;
1097 }
1098 type_id = get<SPIRVariable>(arg_id).basetype;
1099 uint32_t next_id = ir.increase_bound_by(1);
1100 func.add_parameter(type_id, next_id, true);
1101 set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);
1102
1103 set_name(next_id, name);
1104 }
1105 else if (is_builtin_variable(var) && p_type->basetype == SPIRType::Struct)
1106 {
1107 // Get the pointee type
1108 type_id = get_pointee_type_id(type_id);
1109 p_type = &get<SPIRType>(type_id);
1110
1111 uint32_t mbr_idx = 0;
1112 for (auto &mbr_type_id : p_type->member_types)
1113 {
1114 BuiltIn builtin = BuiltInMax;
1115 bool is_builtin = is_member_builtin(*p_type, mbr_idx, &builtin);
1116 if (is_builtin && has_active_builtin(builtin, var.storage))
1117 {
1118 // Add an argument variable with the same type and decorations as the member
1119 uint32_t next_ids = ir.increase_bound_by(2);
1120 uint32_t ptr_type_id = next_ids + 0;
1121 uint32_t var_id = next_ids + 1;
1122
1123 // Make sure we have an actual pointer type,
1124 // so that we will get the appropriate address space when declaring these builtins.
1125 auto &ptr = set<SPIRType>(ptr_type_id, get<SPIRType>(mbr_type_id));
1126 ptr.self = mbr_type_id;
1127 ptr.storage = var.storage;
1128 ptr.pointer = true;
1129 ptr.parent_type = mbr_type_id;
1130
1131 func.add_parameter(mbr_type_id, var_id, true);
1132 set<SPIRVariable>(var_id, ptr_type_id, StorageClassFunction);
1133 ir.meta[var_id].decoration = ir.meta[type_id].members[mbr_idx];
1134 }
1135 mbr_idx++;
1136 }
1137 }
1138 else
1139 {
1140 uint32_t next_id = ir.increase_bound_by(1);
1141 func.add_parameter(type_id, next_id, true);
1142 set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);
1143
1144 // Ensure the existing variable has a valid name and the new variable has all the same meta info
1145 set_name(arg_id, ensure_valid_name(to_name(arg_id), "v"));
1146 ir.meta[next_id] = ir.meta[arg_id];
1147 }
1148 }
1149 }
1150 }
1151
1152 // For all variables that are some form of non-input-output interface block, mark that all the structs
1153 // that are recursively contained within the type referenced by that variable should be packed tightly.
1154 void CompilerMSL::mark_packable_structs()
1155 {
1156 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1157 if (var.storage != StorageClassFunction && !is_hidden_variable(var))
1158 {
1159 auto &type = this->get<SPIRType>(var.basetype);
1160 if (type.pointer &&
1161 (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant ||
1162 type.storage == StorageClassPushConstant || type.storage == StorageClassStorageBuffer) &&
1163 (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
1164 mark_as_packable(type);
1165 }
1166 });
1167 }
1168
1169 // If the specified type is a struct, it and any nested structs
1170 // are marked as packable with the SPIRVCrossDecorationPacked decoration.
1171 void CompilerMSL::mark_as_packable(SPIRType &type)
1172 {
1173 // If this is not the base type (e.g. it's a pointer or array), tunnel down
1174 if (type.parent_type)
1175 {
1176 mark_as_packable(get<SPIRType>(type.parent_type));
1177 return;
1178 }
1179
1180 if (type.basetype == SPIRType::Struct)
1181 {
1182 set_extended_decoration(type.self, SPIRVCrossDecorationPacked);
1183
1184 // Recurse
1185 size_t mbr_cnt = type.member_types.size();
1186 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
1187 {
1188 uint32_t mbr_type_id = type.member_types[mbr_idx];
1189 auto &mbr_type = get<SPIRType>(mbr_type_id);
1190 mark_as_packable(mbr_type);
1191 if (mbr_type.type_alias)
1192 {
1193 auto &mbr_type_alias = get<SPIRType>(mbr_type.type_alias);
1194 mark_as_packable(mbr_type_alias);
1195 }
1196 }
1197 }
1198 }
1199
1200 // If a vertex attribute exists at the location, it is marked as being used by this shader
1201 void CompilerMSL::mark_location_as_used_by_shader(uint32_t location, StorageClass storage)
1202 {
1203 if ((get_execution_model() == ExecutionModelVertex || is_tessellation_shader()) && (storage == StorageClassInput))
1204 vtx_attrs_in_use.insert(location);
1205 }
1206
1207 uint32_t CompilerMSL::get_target_components_for_fragment_location(uint32_t location) const
1208 {
1209 auto itr = fragment_output_components.find(location);
1210 if (itr == end(fragment_output_components))
1211 return 4;
1212 else
1213 return itr->second;
1214 }
1215
1216 uint32_t CompilerMSL::build_extended_vector_type(uint32_t type_id, uint32_t components)
1217 {
1218 uint32_t new_type_id = ir.increase_bound_by(1);
1219 auto &type = set<SPIRType>(new_type_id, get<SPIRType>(type_id));
1220 type.vecsize = components;
1221 type.self = new_type_id;
1222 type.parent_type = type_id;
1223 type.pointer = false;
1224
1225 return new_type_id;
1226 }
1227
1228 void CompilerMSL::add_plain_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
1229 SPIRType &ib_type, SPIRVariable &var, bool strip_array)
1230 {
1231 bool is_builtin = is_builtin_variable(var);
1232 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
1233 bool is_flat = has_decoration(var.self, DecorationFlat);
1234 bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
1235 bool is_centroid = has_decoration(var.self, DecorationCentroid);
1236 bool is_sample = has_decoration(var.self, DecorationSample);
1237
1238 // Add a reference to the variable type to the interface struct.
1239 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1240 uint32_t type_id = ensure_correct_builtin_type(var.basetype, builtin);
1241 var.basetype = type_id;
1242
1243 type_id = get_pointee_type_id(var.basetype);
1244 if (strip_array && is_array(get<SPIRType>(type_id)))
1245 type_id = get<SPIRType>(type_id).parent_type;
1246 auto &type = get<SPIRType>(type_id);
1247 uint32_t target_components = 0;
1248 uint32_t type_components = type.vecsize;
1249 bool padded_output = false;
1250
1251 // Check if we need to pad fragment output to match a certain number of components.
1252 if (get_decoration_bitset(var.self).get(DecorationLocation) && msl_options.pad_fragment_output_components &&
1253 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput)
1254 {
1255 uint32_t locn = get_decoration(var.self, DecorationLocation);
1256 target_components = get_target_components_for_fragment_location(locn);
1257 if (type_components < target_components)
1258 {
1259 // Make a new type here.
1260 type_id = build_extended_vector_type(type_id, target_components);
1261 padded_output = true;
1262 }
1263 }
1264
1265 ib_type.member_types.push_back(type_id);
1266
1267 // Give the member a name
1268 string mbr_name = ensure_valid_name(to_expression(var.self), "m");
1269 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1270
1271 // Update the original variable reference to include the structure reference
1272 string qual_var_name = ib_var_ref + "." + mbr_name;
1273 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1274
1275 if (padded_output)
1276 {
1277 entry_func.add_local_variable(var.self);
1278 vars_needing_early_declaration.push_back(var.self);
1279
1280 entry_func.fixup_hooks_out.push_back([=, &var]() {
1281 SPIRType &padded_type = this->get<SPIRType>(type_id);
1282 statement(qual_var_name, " = ", remap_swizzle(padded_type, type_components, to_name(var.self)), ";");
1283 });
1284 }
1285 else if (!strip_array)
1286 ir.meta[var.self].decoration.qualified_alias = qual_var_name;
1287
1288 if (var.storage == StorageClassOutput && var.initializer != 0)
1289 {
1290 entry_func.fixup_hooks_in.push_back(
1291 [=, &var]() { statement(qual_var_name, " = ", to_expression(var.initializer), ";"); });
1292 }
1293
1294 // Copy the variable location from the original variable to the member
1295 if (get_decoration_bitset(var.self).get(DecorationLocation))
1296 {
1297 uint32_t locn = get_decoration(var.self, DecorationLocation);
1298 if (storage == StorageClassInput && (get_execution_model() == ExecutionModelVertex || is_tessellation_shader()))
1299 {
1300 type_id = ensure_correct_attribute_type(var.basetype, locn);
1301 var.basetype = type_id;
1302 type_id = get_pointee_type_id(type_id);
1303 if (strip_array && is_array(get<SPIRType>(type_id)))
1304 type_id = get<SPIRType>(type_id).parent_type;
1305 ib_type.member_types[ib_mbr_idx] = type_id;
1306 }
1307 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1308 mark_location_as_used_by_shader(locn, storage);
1309 }
1310 else if (is_builtin && is_tessellation_shader() && vtx_attrs_by_builtin.count(builtin))
1311 {
1312 uint32_t locn = vtx_attrs_by_builtin[builtin].location;
1313 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1314 mark_location_as_used_by_shader(locn, storage);
1315 }
1316
1317 if (get_decoration_bitset(var.self).get(DecorationComponent))
1318 {
1319 uint32_t comp = get_decoration(var.self, DecorationComponent);
1320 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
1321 }
1322
1323 if (get_decoration_bitset(var.self).get(DecorationIndex))
1324 {
1325 uint32_t index = get_decoration(var.self, DecorationIndex);
1326 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
1327 }
1328
1329 // Mark the member as builtin if needed
1330 if (is_builtin)
1331 {
1332 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
1333 if (builtin == BuiltInPosition && storage == StorageClassOutput)
1334 qual_pos_var_name = qual_var_name;
1335 }
1336
1337 // Copy interpolation decorations if needed
1338 if (is_flat)
1339 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
1340 if (is_noperspective)
1341 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
1342 if (is_centroid)
1343 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
1344 if (is_sample)
1345 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
1346
1347 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
1348 }
1349
1350 void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
1351 SPIRType &ib_type, SPIRVariable &var, bool strip_array)
1352 {
1353 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1354 auto &var_type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
1355 uint32_t elem_cnt = 0;
1356
1357 if (is_matrix(var_type))
1358 {
1359 if (is_array(var_type))
1360 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
1361
1362 elem_cnt = var_type.columns;
1363 }
1364 else if (is_array(var_type))
1365 {
1366 if (var_type.array.size() != 1)
1367 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
1368
1369 elem_cnt = to_array_size_literal(var_type);
1370 }
1371
1372 bool is_builtin = is_builtin_variable(var);
1373 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
1374 bool is_flat = has_decoration(var.self, DecorationFlat);
1375 bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
1376 bool is_centroid = has_decoration(var.self, DecorationCentroid);
1377 bool is_sample = has_decoration(var.self, DecorationSample);
1378
1379 auto *usable_type = &var_type;
1380 if (usable_type->pointer)
1381 usable_type = &get<SPIRType>(usable_type->parent_type);
1382 while (is_array(*usable_type) || is_matrix(*usable_type))
1383 usable_type = &get<SPIRType>(usable_type->parent_type);
1384
1385 // If a builtin, force it to have the proper name.
1386 if (is_builtin)
1387 set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));
1388
1389 entry_func.add_local_variable(var.self);
1390
1391 // We need to declare the variable early and at entry-point scope.
1392 vars_needing_early_declaration.push_back(var.self);
1393
1394 for (uint32_t i = 0; i < elem_cnt; i++)
1395 {
1396 // Add a reference to the variable type to the interface struct.
1397 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1398
1399 uint32_t target_components = 0;
1400 bool padded_output = false;
1401 uint32_t type_id = usable_type->self;
1402
1403 // Check if we need to pad fragment output to match a certain number of components.
1404 if (get_decoration_bitset(var.self).get(DecorationLocation) && msl_options.pad_fragment_output_components &&
1405 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput)
1406 {
1407 uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
1408 target_components = get_target_components_for_fragment_location(locn);
1409 if (usable_type->vecsize < target_components)
1410 {
1411 // Make a new type here.
1412 type_id = build_extended_vector_type(usable_type->self, target_components);
1413 padded_output = true;
1414 }
1415 }
1416
1417 ib_type.member_types.push_back(get_pointee_type_id(type_id));
1418
1419 // Give the member a name
1420 string mbr_name = ensure_valid_name(join(to_expression(var.self), "_", i), "m");
1421 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1422
1423 // There is no qualified alias since we need to flatten the internal array on return.
1424 if (get_decoration_bitset(var.self).get(DecorationLocation))
1425 {
1426 uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
1427 if (storage == StorageClassInput &&
1428 (get_execution_model() == ExecutionModelVertex || is_tessellation_shader()))
1429 {
1430 var.basetype = ensure_correct_attribute_type(var.basetype, locn);
1431 uint32_t mbr_type_id = ensure_correct_attribute_type(usable_type->self, locn);
1432 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
1433 }
1434 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1435 mark_location_as_used_by_shader(locn, storage);
1436 }
1437 else if (is_builtin && is_tessellation_shader() && vtx_attrs_by_builtin.count(builtin))
1438 {
1439 uint32_t locn = vtx_attrs_by_builtin[builtin].location + i;
1440 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1441 mark_location_as_used_by_shader(locn, storage);
1442 }
1443
1444 if (get_decoration_bitset(var.self).get(DecorationIndex))
1445 {
1446 uint32_t index = get_decoration(var.self, DecorationIndex);
1447 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
1448 }
1449
1450 // Copy interpolation decorations if needed
1451 if (is_flat)
1452 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
1453 if (is_noperspective)
1454 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
1455 if (is_centroid)
1456 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
1457 if (is_sample)
1458 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
1459
1460 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
1461
1462 if (!strip_array)
1463 {
1464 switch (storage)
1465 {
1466 case StorageClassInput:
1467 entry_func.fixup_hooks_in.push_back(
1468 [=, &var]() { statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, ";"); });
1469 break;
1470
1471 case StorageClassOutput:
1472 entry_func.fixup_hooks_out.push_back([=, &var]() {
1473 if (padded_output)
1474 {
1475 auto &padded_type = this->get<SPIRType>(type_id);
1476 statement(
1477 ib_var_ref, ".", mbr_name, " = ",
1478 remap_swizzle(padded_type, usable_type->vecsize, join(to_name(var.self), "[", i, "]")),
1479 ";");
1480 }
1481 else
1482 statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), "[", i, "];");
1483 });
1484 break;
1485
1486 default:
1487 break;
1488 }
1489 }
1490 }
1491 }
1492
1493 uint32_t CompilerMSL::get_accumulated_member_location(const SPIRVariable &var, uint32_t mbr_idx, bool strip_array)
1494 {
1495 auto &type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
1496 uint32_t location = get_decoration(var.self, DecorationLocation);
1497
1498 for (uint32_t i = 0; i < mbr_idx; i++)
1499 {
1500 auto &mbr_type = get<SPIRType>(type.member_types[i]);
1501
1502 // Start counting from any place we have a new location decoration.
1503 if (has_member_decoration(type.self, i, DecorationLocation))
1504 location = get_member_decoration(type.self, i, DecorationLocation);
1505
1506 uint32_t location_count = 1;
1507
1508 if (mbr_type.columns > 1)
1509 location_count = mbr_type.columns;
1510
1511 if (!mbr_type.array.empty())
1512 for (uint32_t j = 0; j < uint32_t(mbr_type.array.size()); j++)
1513 location_count *= to_array_size_literal(mbr_type, j);
1514
1515 location += location_count;
1516 }
1517
1518 return location;
1519 }
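
// Worked example of the accumulation above (illustrative): for block members
// { vec4 a; mat4 b; float c[3]; vec2 d; } with the block itself at location 0,
// member "d" (mbr_idx 3) resolves to 0 + 1 (a) + 4 (b: 4 columns) + 3 (c: array size) = 8.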
1520
1521 void CompilerMSL::add_composite_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
1522 SPIRType &ib_type, SPIRVariable &var,
1523 uint32_t mbr_idx, bool strip_array)
1524 {
1525 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1526 auto &var_type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
1527
1528 BuiltIn builtin;
1529 bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
1530 bool is_flat =
1531 has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
1532 bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
1533 has_decoration(var.self, DecorationNoPerspective);
1534 bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
1535 has_decoration(var.self, DecorationCentroid);
1536 bool is_sample =
1537 has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
1538
1539 uint32_t mbr_type_id = var_type.member_types[mbr_idx];
1540 auto &mbr_type = get<SPIRType>(mbr_type_id);
1541 uint32_t elem_cnt = 0;
1542
1543 if (is_matrix(mbr_type))
1544 {
1545 if (is_array(mbr_type))
1546 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
1547
1548 elem_cnt = mbr_type.columns;
1549 }
1550 else if (is_array(mbr_type))
1551 {
1552 if (mbr_type.array.size() != 1)
1553 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
1554
1555 elem_cnt = to_array_size_literal(mbr_type);
1556 }
1557
1558 auto *usable_type = &mbr_type;
1559 if (usable_type->pointer)
1560 usable_type = &get<SPIRType>(usable_type->parent_type);
1561 while (is_array(*usable_type) || is_matrix(*usable_type))
1562 usable_type = &get<SPIRType>(usable_type->parent_type);
1563
1564 for (uint32_t i = 0; i < elem_cnt; i++)
1565 {
1566 // Add a reference to the variable type to the interface struct.
1567 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1568 ib_type.member_types.push_back(usable_type->self);
1569
1570 // Give the member a name
1571 string mbr_name = ensure_valid_name(join(to_qualified_member_name(var_type, mbr_idx), "_", i), "m");
1572 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1573
1574 if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
1575 {
1576 uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation) + i;
1577 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1578 mark_location_as_used_by_shader(locn, storage);
1579 }
1580 else if (has_decoration(var.self, DecorationLocation))
1581 {
1582 uint32_t locn = get_accumulated_member_location(var, mbr_idx, strip_array) + i;
1583 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1584 mark_location_as_used_by_shader(locn, storage);
1585 }
1586 else if (is_builtin && is_tessellation_shader() && vtx_attrs_by_builtin.count(builtin))
1587 {
1588 uint32_t locn = vtx_attrs_by_builtin[builtin].location + i;
1589 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1590 mark_location_as_used_by_shader(locn, storage);
1591 }
1592
1593 if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
1594 SPIRV_CROSS_THROW("DecorationComponent on matrices and arrays makes little sense.");
1595
1596 // Copy interpolation decorations if needed
1597 if (is_flat)
1598 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
1599 if (is_noperspective)
1600 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
1601 if (is_centroid)
1602 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
1603 if (is_sample)
1604 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
1605
1606 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
1607 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
1608
1609 // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
1610 if (!strip_array)
1611 {
1612 switch (storage)
1613 {
1614 case StorageClassInput:
1615 entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
1616 statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
1617 ".", mbr_name, ";");
1618 });
1619 break;
1620
1621 case StorageClassOutput:
1622 entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
1623 statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), ".",
1624 to_member_name(var_type, mbr_idx), "[", i, "];");
1625 });
1626 break;
1627
1628 default:
1629 break;
1630 }
1631 }
1632 }
1633 }
1634
1635 void CompilerMSL::add_plain_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
1636 SPIRType &ib_type, SPIRVariable &var, uint32_t mbr_idx,
1637 bool strip_array)
1638 {
1639 auto &var_type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
1640 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1641
1642 BuiltIn builtin = BuiltInMax;
1643 bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
1644 bool is_flat =
1645 has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
1646 bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
1647 has_decoration(var.self, DecorationNoPerspective);
1648 bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
1649 has_decoration(var.self, DecorationCentroid);
1650 bool is_sample =
1651 has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
1652
1653 // Add a reference to the member to the interface struct.
1654 uint32_t mbr_type_id = var_type.member_types[mbr_idx];
1655 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1656 mbr_type_id = ensure_correct_builtin_type(mbr_type_id, builtin);
1657 var_type.member_types[mbr_idx] = mbr_type_id;
1658 ib_type.member_types.push_back(mbr_type_id);
1659
1660 // Give the member a name
1661 string mbr_name = ensure_valid_name(to_qualified_member_name(var_type, mbr_idx), "m");
1662 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1663
1664 // Update the original variable reference to include the structure reference
1665 string qual_var_name = ib_var_ref + "." + mbr_name;
1666
1667 if (is_builtin && !strip_array)
1668 {
1669 // For the builtin gl_PerVertex, we cannot treat it as a block anyways,
1670 // so redirect to qualified name.
1671 set_member_qualified_name(var_type.self, mbr_idx, qual_var_name);
1672 }
1673 else if (!strip_array)
1674 {
1675 // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
1676 switch (storage)
1677 {
1678 case StorageClassInput:
1679 entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
1680 statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), " = ", qual_var_name, ";");
1681 });
1682 break;
1683
1684 case StorageClassOutput:
1685 entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
1686 statement(qual_var_name, " = ", to_name(var.self), ".", to_member_name(var_type, mbr_idx), ";");
1687 });
1688 break;
1689
1690 default:
1691 break;
1692 }
1693 }
1694
1695 // Copy the variable location from the original variable to the member
1696 if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
1697 {
1698 uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation);
1699 if (storage == StorageClassInput && (get_execution_model() == ExecutionModelVertex || is_tessellation_shader()))
1700 {
1701 mbr_type_id = ensure_correct_attribute_type(mbr_type_id, locn);
1702 var_type.member_types[mbr_idx] = mbr_type_id;
1703 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
1704 }
1705 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1706 mark_location_as_used_by_shader(locn, storage);
1707 }
1708 else if (has_decoration(var.self, DecorationLocation))
1709 {
1710 // The block itself might have a location and in this case, all members of the block
1711 // receive incrementing locations.
1712 uint32_t locn = get_accumulated_member_location(var, mbr_idx, strip_array);
1713 if (storage == StorageClassInput && (get_execution_model() == ExecutionModelVertex || is_tessellation_shader()))
1714 {
1715 mbr_type_id = ensure_correct_attribute_type(mbr_type_id, locn);
1716 var_type.member_types[mbr_idx] = mbr_type_id;
1717 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
1718 }
1719 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1720 mark_location_as_used_by_shader(locn, storage);
1721 }
1722 else if (is_builtin && is_tessellation_shader() && vtx_attrs_by_builtin.count(builtin))
1723 {
1724 uint32_t locn = 0;
1725 auto builtin_itr = vtx_attrs_by_builtin.find(builtin);
1726 if (builtin_itr != end(vtx_attrs_by_builtin))
1727 locn = builtin_itr->second.location;
1728 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1729 mark_location_as_used_by_shader(locn, storage);
1730 }
1731
1732 // Copy the component location, if present.
1733 if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
1734 {
1735 uint32_t comp = get_member_decoration(var_type.self, mbr_idx, DecorationComponent);
1736 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
1737 }
1738
1739 // Mark the member as builtin if needed
1740 if (is_builtin)
1741 {
1742 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
1743 if (builtin == BuiltInPosition && storage == StorageClassOutput)
1744 qual_pos_var_name = qual_var_name;
1745 }
1746
1747 // Copy interpolation decorations if needed
1748 if (is_flat)
1749 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
1750 if (is_noperspective)
1751 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
1752 if (is_centroid)
1753 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
1754 if (is_sample)
1755 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
1756
1757 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
1758 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
1759 }
1760
1761 // In Metal, the tessellation levels are stored as tightly packed half-precision floating point values.
1762 // But, stage-in attribute offsets and strides must be multiples of four, so we can't pass the levels
1763 // individually. Therefore, we must pass them as vectors. Triangles get a single float4, with the outer
1764 // levels in 'xyz' and the inner level in 'w'. Quads get a float4 containing the outer levels and a
1765 // float2 containing the inner levels.
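// Illustrative sketch of the result (names and attribute index are examples, not emitted verbatim):
//     float4 gl_TessLevel [[attribute(N)]];   // triangles: outer levels in .xyz, inner level in .w
// with entry-point fixups roughly of the form
//     gl_TessLevelOuter[0] = in.gl_TessLevel.x;
//     gl_TessLevelOuter[1] = in.gl_TessLevel.y;
//     gl_TessLevelOuter[2] = in.gl_TessLevel.z;
//     gl_TessLevelInner[0] = in.gl_TessLevel.w;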
1766 void CompilerMSL::add_tess_level_input_to_interface_block(const std::string &ib_var_ref, SPIRType &ib_type,
1767 SPIRVariable &var)
1768 {
1769 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1770 auto &var_type = get_variable_element_type(var);
1771
1772 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
1773
1774 // Force the variable to have the proper name.
1775 set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));
1776
1777 if (get_entry_point().flags.get(ExecutionModeTriangles))
1778 {
1779 // Triangles are tricky, because we want only one member in the struct.
1780
1781 // We need to declare the variable early and at entry-point scope.
1782 entry_func.add_local_variable(var.self);
1783 vars_needing_early_declaration.push_back(var.self);
1784
1785 string mbr_name = "gl_TessLevel";
1786
1787 // If we already added the other one, we can skip this step.
1788 if (!added_builtin_tess_level)
1789 {
1790 // Add a reference to the variable type to the interface struct.
1791 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1792
1793 uint32_t type_id = build_extended_vector_type(var_type.self, 4);
1794
1795 ib_type.member_types.push_back(type_id);
1796
1797 // Give the member a name
1798 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1799
1800 // There is no qualified alias since we need to flatten the internal array on return.
1801 if (get_decoration_bitset(var.self).get(DecorationLocation))
1802 {
1803 uint32_t locn = get_decoration(var.self, DecorationLocation);
1804 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1805 mark_location_as_used_by_shader(locn, StorageClassInput);
1806 }
1807 else if (vtx_attrs_by_builtin.count(builtin))
1808 {
1809 uint32_t locn = vtx_attrs_by_builtin[builtin].location;
1810 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1811 mark_location_as_used_by_shader(locn, StorageClassInput);
1812 }
1813
1814 added_builtin_tess_level = true;
1815 }
1816
1817 switch (builtin)
1818 {
1819 case BuiltInTessLevelOuter:
1820 entry_func.fixup_hooks_in.push_back([=, &var]() {
1821 statement(to_name(var.self), "[0] = ", ib_var_ref, ".", mbr_name, ".x;");
1822 statement(to_name(var.self), "[1] = ", ib_var_ref, ".", mbr_name, ".y;");
1823 statement(to_name(var.self), "[2] = ", ib_var_ref, ".", mbr_name, ".z;");
1824 });
1825 break;
1826
1827 case BuiltInTessLevelInner:
1828 entry_func.fixup_hooks_in.push_back(
1829 [=, &var]() { statement(to_name(var.self), "[0] = ", ib_var_ref, ".", mbr_name, ".w;"); });
1830 break;
1831
1832 default:
1833 assert(false);
1834 break;
1835 }
1836 }
1837 else
1838 {
1839 // Add a reference to the variable type to the interface struct.
1840 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1841
1842 uint32_t type_id = build_extended_vector_type(var_type.self, builtin == BuiltInTessLevelOuter ? 4 : 2);
1843 // Change the type of the variable, too.
1844 uint32_t ptr_type_id = ir.increase_bound_by(1);
1845 auto &new_var_type = set<SPIRType>(ptr_type_id, get<SPIRType>(type_id));
1846 new_var_type.pointer = true;
1847 new_var_type.storage = StorageClassInput;
1848 new_var_type.parent_type = type_id;
1849 var.basetype = ptr_type_id;
1850
1851 ib_type.member_types.push_back(type_id);
1852
1853 // Give the member a name
1854 string mbr_name = to_expression(var.self);
1855 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1856
1857 // Since vectors can be indexed like arrays, there is no need to unpack this. We can
1858 // just refer to the vector directly. So give it a qualified alias.
1859 string qual_var_name = ib_var_ref + "." + mbr_name;
1860 ir.meta[var.self].decoration.qualified_alias = qual_var_name;
1861
1862 if (get_decoration_bitset(var.self).get(DecorationLocation))
1863 {
1864 uint32_t locn = get_decoration(var.self, DecorationLocation);
1865 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1866 mark_location_as_used_by_shader(locn, StorageClassInput);
1867 }
1868 else if (vtx_attrs_by_builtin.count(builtin))
1869 {
1870 uint32_t locn = vtx_attrs_by_builtin[builtin].location;
1871 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1872 mark_location_as_used_by_shader(locn, StorageClassInput);
1873 }
1874 }
1875 }
1876
1877 void CompilerMSL::add_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, SPIRType &ib_type,
1878 SPIRVariable &var, bool strip_array)
1879 {
1880 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1881 // Tessellation control I/O variables and tessellation evaluation per-point inputs are
1882 // usually declared as arrays. In these cases, we want to add the element type to the
1883 // interface block, since in Metal it's the interface block itself which is arrayed.
1884 auto &var_type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
1885 bool is_builtin = is_builtin_variable(var);
1886 auto builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
1887
1888 if (var_type.basetype == SPIRType::Struct)
1889 {
1890 if (!is_builtin_type(var_type) && (!capture_output_to_buffer || storage == StorageClassInput) && !strip_array)
1891 {
1892 // For I/O blocks or structs, we will need to pass the block itself around
1893 // to functions if they are used globally in leaf functions.
1894 // Rather than passing down member by member,
1895 // we unflatten I/O blocks while running the shader,
1896 // and pass the actual struct type down to leaf functions.
1897 // We then unflatten inputs, and flatten outputs in the "fixup" stages.
1898 entry_func.add_local_variable(var.self);
1899 vars_needing_early_declaration.push_back(var.self);
1900 }
1901
1902 if (capture_output_to_buffer && storage != StorageClassInput && !has_decoration(var_type.self, DecorationBlock))
1903 {
1904 // In Metal tessellation shaders, the interface block itself is arrayed. This makes things
1905 // very complicated, since stage-in structures in MSL don't support nested structures.
1906 // Luckily, for stage-out when capturing output, we can avoid this and just add
1907 // composite members directly, because the stage-out structure is stored to a buffer,
1908 // not returned.
1909 add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, strip_array);
1910 }
1911 else
1912 {
1913 // Flatten the struct members into the interface struct
1914 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
1915 {
1916 builtin = BuiltInMax;
1917 is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
1918 auto &mbr_type = get<SPIRType>(var_type.member_types[mbr_idx]);
1919
1920 if (!is_builtin || has_active_builtin(builtin, storage))
1921 {
1922 if ((!is_builtin ||
1923 (storage == StorageClassInput && get_execution_model() != ExecutionModelFragment)) &&
1924 (storage == StorageClassInput || storage == StorageClassOutput) &&
1925 (is_matrix(mbr_type) || is_array(mbr_type)))
1926 {
1927 add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx,
1928 strip_array);
1929 }
1930 else
1931 {
1932 add_plain_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx,
1933 strip_array);
1934 }
1935 }
1936 }
1937 }
1938 }
1939 else if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput &&
1940 !strip_array && is_builtin && (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner))
1941 {
1942 add_tess_level_input_to_interface_block(ib_var_ref, ib_type, var);
1943 }
1944 else if (var_type.basetype == SPIRType::Boolean || var_type.basetype == SPIRType::Char ||
1945 type_is_integral(var_type) || type_is_floating_point(var_type))
1946 {
1947 if (!is_builtin || has_active_builtin(builtin, storage))
1948 {
1949 // MSL does not allow matrices or arrays in input or output variables, so we need to handle them specially.
1950 if ((!is_builtin || (storage == StorageClassInput && get_execution_model() != ExecutionModelFragment)) &&
1951 (storage == StorageClassInput || (storage == StorageClassOutput && !capture_output_to_buffer)) &&
1952 (is_matrix(var_type) || is_array(var_type)))
1953 {
1954 add_composite_variable_to_interface_block(storage, ib_var_ref, ib_type, var, strip_array);
1955 }
1956 else
1957 {
1958 add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, strip_array);
1959 }
1960 }
1961 }
1962 }
1963
1964 // Fix up the mapping of variables to interface member indices, which is used to compile access chains
1965 // for per-vertex variables in a tessellation control shader.
1966 void CompilerMSL::fix_up_interface_member_indices(StorageClass storage, uint32_t ib_type_id)
1967 {
1968 // Only needed for tessellation shaders.
1969 if (get_execution_model() != ExecutionModelTessellationControl &&
1970 !(get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput))
1971 return;
1972
1973 bool in_array = false;
1974 for (uint32_t i = 0; i < ir.meta[ib_type_id].members.size(); i++)
1975 {
1976 auto &mbr_dec = ir.meta[ib_type_id].members[i];
1977 uint32_t var_id = mbr_dec.extended.ib_orig_id;
1978 if (!var_id)
1979 continue;
1980 auto &var = get<SPIRVariable>(var_id);
1981
1982 // Unfortunately, all this complexity is needed to handle flattened structs and/or
1983 // arrays.
1984 if (storage == StorageClassInput)
1985 {
1986 auto &type = get_variable_element_type(var);
1987 if (is_array(type) || is_matrix(type))
1988 {
1989 if (in_array)
1990 continue;
1991 in_array = true;
1992 set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i);
1993 }
1994 else
1995 {
1996 if (type.basetype == SPIRType::Struct)
1997 {
1998 uint32_t mbr_idx =
1999 get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceMemberIndex);
2000 auto &mbr_type = get<SPIRType>(type.member_types[mbr_idx]);
2001
2002 if (is_array(mbr_type) || is_matrix(mbr_type))
2003 {
2004 if (in_array)
2005 continue;
2006 in_array = true;
2007 set_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, i);
2008 }
2009 else
2010 {
2011 in_array = false;
2012 set_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, i);
2013 }
2014 }
2015 else
2016 {
2017 in_array = false;
2018 set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i);
2019 }
2020 }
2021 }
2022 else
2023 set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i);
2024 }
2025 }
2026
2027 // Add an interface structure for the type of storage, which is either StorageClassInput or StorageClassOutput.
2028 // Returns the ID of the newly added variable, or zero if no variable was added.
2029 uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch)
2030 {
2031 // Accumulate the variables that should appear in the interface struct.
2032 SmallVector<SPIRVariable *> vars;
2033 bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader();
2034 bool has_seen_barycentric = false;
2035
2036 ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
2037 if (var.storage != storage)
2038 return;
2039
2040 auto &type = this->get<SPIRType>(var.basetype);
2041
2042 bool is_builtin = is_builtin_variable(var);
2043 auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
2044
2045 // These builtins are part of the stage in/out structs.
2046 bool is_interface_block_builtin =
2047 (bi_type == BuiltInPosition || bi_type == BuiltInPointSize || bi_type == BuiltInClipDistance ||
2048 bi_type == BuiltInCullDistance || bi_type == BuiltInLayer || bi_type == BuiltInViewportIndex ||
2049 bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV || bi_type == BuiltInFragDepth ||
2050 bi_type == BuiltInFragStencilRefEXT || bi_type == BuiltInSampleMask) ||
2051 (get_execution_model() == ExecutionModelTessellationEvaluation &&
2052 (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner));
2053
2054 bool is_active = interface_variable_exists_in_entry_point(var.self);
2055 if (is_builtin && is_active)
2056 {
2057 // Only emit the builtin if it's active in this entry point. Interface variable list might lie.
2058 is_active = has_active_builtin(bi_type, storage);
2059 }
2060
2061 bool filter_patch_decoration = (has_decoration(var_id, DecorationPatch) || is_patch_block(type)) == patch;
2062
2063 bool hidden = is_hidden_variable(var, incl_builtins);
2064 // Barycentric inputs must be emitted in stage-in, because they can have interpolation arguments.
2065 if (is_active && (bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV))
2066 {
2067 if (has_seen_barycentric)
2068 SPIRV_CROSS_THROW("Cannot declare both BaryCoordNV and BaryCoordNoPerspNV in same shader in MSL.");
2069 has_seen_barycentric = true;
2070 hidden = false;
2071 }
2072
2073 if (is_active && !hidden && type.pointer && filter_patch_decoration &&
2074 (!is_builtin || is_interface_block_builtin))
2075 {
2076 vars.push_back(&var);
2077 }
2078 });
2079
2080 // If no variables qualify, leave.
2081 // For patch input in a tessellation evaluation shader, the per-vertex stage inputs
2082 // are included in a special patch control point array.
2083 if (vars.empty() && !(storage == StorageClassInput && patch && stage_in_var_id))
2084 return 0;
2085
2086 // Add a new typed variable for this interface structure.
2087 // The initializer expression is allocated here, but populated when the function
2088 // declaration is emitted, because it is cleared after each compilation pass.
2089 uint32_t next_id = ir.increase_bound_by(3);
2090 uint32_t ib_type_id = next_id++;
2091 auto &ib_type = set<SPIRType>(ib_type_id);
2092 ib_type.basetype = SPIRType::Struct;
2093 ib_type.storage = storage;
2094 set_decoration(ib_type_id, DecorationBlock);
2095
2096 uint32_t ib_var_id = next_id++;
2097 auto &var = set<SPIRVariable>(ib_var_id, ib_type_id, storage, 0);
2098 var.initializer = next_id++;
2099
2100 string ib_var_ref;
2101 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2102 switch (storage)
2103 {
2104 case StorageClassInput:
2105 ib_var_ref = patch ? patch_stage_in_var_name : stage_in_var_name;
2106 if (get_execution_model() == ExecutionModelTessellationControl)
2107 {
2108 // Add a hook to populate the shared workgroup memory containing
2109 // the gl_in array.
2110 entry_func.fixup_hooks_in.push_back([=]() {
2111 // Can't use PatchVertices yet; the hook for that may not have run yet.
2112 statement("if (", to_expression(builtin_invocation_id_id), " < ", "spvIndirectParams[0])");
2113 statement(" ", input_wg_var_name, "[", to_expression(builtin_invocation_id_id), "] = ", ib_var_ref,
2114 ";");
2115 statement("threadgroup_barrier(mem_flags::mem_threadgroup);");
2116 statement("if (", to_expression(builtin_invocation_id_id), " >= ", get_entry_point().output_vertices,
2117 ")");
2118 statement(" return;");
2119 });
2120 }
2121 break;
2122
2123 case StorageClassOutput:
2124 {
2125 ib_var_ref = patch ? patch_stage_out_var_name : stage_out_var_name;
2126
2127 // Add the output interface struct as a local variable to the entry function.
2128 // If the entry point should return the output struct, set the entry function
2129 // to return the output interface struct, otherwise to return nothing.
2130 // Indicate the output var requires early initialization.
2131 bool ep_should_return_output = !get_is_rasterization_disabled();
2132 uint32_t rtn_id = ep_should_return_output ? ib_var_id : 0;
2133 if (!capture_output_to_buffer)
2134 {
2135 entry_func.add_local_variable(ib_var_id);
2136 for (auto &blk_id : entry_func.blocks)
2137 {
2138 auto &blk = get<SPIRBlock>(blk_id);
2139 if (blk.terminator == SPIRBlock::Return)
2140 blk.return_value = rtn_id;
2141 }
2142 vars_needing_early_declaration.push_back(ib_var_id);
2143 }
2144 else
2145 {
2146 switch (get_execution_model())
2147 {
2148 case ExecutionModelVertex:
2149 case ExecutionModelTessellationEvaluation:
2150 // Instead of declaring a struct variable to hold the output and then
2151 // copying that to the output buffer, we'll declare the output variable
2152 // as a reference to the final output element in the buffer. Then we can
2153 // avoid the extra copy.
2154 entry_func.fixup_hooks_in.push_back([=]() {
2155 if (stage_out_var_id)
2156 {
2157 // The first member of the indirect buffer is always the number of vertices
2158 // to draw.
2159 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref, " = ",
2160 output_buffer_var_name, "[(", to_expression(builtin_instance_idx_id), " - ",
2161 to_expression(builtin_base_instance_id), ") * spvIndirectParams[0] + ",
2162 to_expression(builtin_vertex_idx_id), " - ", to_expression(builtin_base_vertex_id),
2163 "];");
2164 }
2165 });
2166 break;
2167 case ExecutionModelTessellationControl:
2168 if (patch)
2169 entry_func.fixup_hooks_in.push_back([=]() {
2170 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref, " = ",
2171 patch_output_buffer_var_name, "[", to_expression(builtin_primitive_id_id), "];");
2172 });
2173 else
2174 entry_func.fixup_hooks_in.push_back([=]() {
2175 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
2176 output_buffer_var_name, "[", to_expression(builtin_primitive_id_id), " * ",
2177 get_entry_point().output_vertices, "];");
2178 });
2179 break;
2180 default:
2181 break;
2182 }
2183 }
2184 break;
2185 }
2186
2187 default:
2188 break;
2189 }
2190
2191 set_name(ib_type_id, to_name(ir.default_entry_point) + "_" + ib_var_ref);
2192 set_name(ib_var_id, ib_var_ref);
2193
2194 for (auto *p_var : vars)
2195 {
2196 bool strip_array =
2197 (get_execution_model() == ExecutionModelTessellationControl ||
2198 (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput)) &&
2199 !patch;
2200 add_variable_to_interface_block(storage, ib_var_ref, ib_type, *p_var, strip_array);
2201 }
2202
2203 // Sort the members of the structure by their locations.
2204 MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Location);
2205 member_sorter.sort();
2206
2207 // The member indices were saved to the original variables, but after the members
2208 // were sorted, those indices are now likely incorrect. Fix those up now.
2209 if (!patch)
2210 fix_up_interface_member_indices(storage, ib_type_id);
2211
2212 // For patch inputs, add one more member, holding the array of control point data.
2213 if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput && patch &&
2214 stage_in_var_id)
2215 {
2216 uint32_t pcp_type_id = ir.increase_bound_by(1);
2217 auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
2218 pcp_type.basetype = SPIRType::ControlPointArray;
2219 pcp_type.parent_type = pcp_type.type_alias = get_stage_in_struct_type().self;
2220 pcp_type.storage = storage;
2221 ir.meta[pcp_type_id] = ir.meta[ib_type.self];
2222 uint32_t mbr_idx = uint32_t(ib_type.member_types.size());
2223 ib_type.member_types.push_back(pcp_type_id);
2224 set_member_name(ib_type.self, mbr_idx, "gl_in");
2225 }
2226
2227 return ib_var_id;
2228 }
2229
2230 uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageClass storage)
2231 {
2232 if (!ib_var_id)
2233 return 0;
2234
2235 uint32_t ib_ptr_var_id;
2236 uint32_t next_id = ir.increase_bound_by(3);
2237 auto &ib_type = expression_type(ib_var_id);
2238 if (get_execution_model() == ExecutionModelTessellationControl)
2239 {
2240 // Tessellation control per-vertex I/O is presented as an array, so we must
2241 // do the same with our struct here.
2242 uint32_t ib_ptr_type_id = next_id++;
2243 auto &ib_ptr_type = set<SPIRType>(ib_ptr_type_id, ib_type);
2244 ib_ptr_type.parent_type = ib_ptr_type.type_alias = ib_type.self;
2245 ib_ptr_type.pointer = true;
2246 ib_ptr_type.storage = storage == StorageClassInput ? StorageClassWorkgroup : StorageClassStorageBuffer;
2247 ir.meta[ib_ptr_type_id] = ir.meta[ib_type.self];
2248 // To ensure that get_variable_data_type() doesn't strip off the pointer,
2249 // which we need, use another pointer.
2250 uint32_t ib_ptr_ptr_type_id = next_id++;
2251 auto &ib_ptr_ptr_type = set<SPIRType>(ib_ptr_ptr_type_id, ib_ptr_type);
2252 ib_ptr_ptr_type.parent_type = ib_ptr_type_id;
2253 ib_ptr_ptr_type.type_alias = ib_type.self;
2254 ib_ptr_ptr_type.storage = StorageClassFunction;
2255 ir.meta[ib_ptr_ptr_type_id] = ir.meta[ib_type.self];
2256
2257 ib_ptr_var_id = next_id;
2258 set<SPIRVariable>(ib_ptr_var_id, ib_ptr_ptr_type_id, StorageClassFunction, 0);
2259 set_name(ib_ptr_var_id, storage == StorageClassInput ? input_wg_var_name : "gl_out");
2260 }
2261 else
2262 {
2263 // Tessellation evaluation per-vertex inputs are also presented as arrays.
2264 // But, in Metal, this array uses a very special type, 'patch_control_point<T>',
2265 // which is a container that can be used to access the control point data.
2266 // To represent this, a special 'ControlPointArray' type has been added to the
2267 // SPIRV-Cross type system. It should only be generated by and seen in the MSL
2268 // backend (i.e. this one).
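// In the emitted MSL this typically surfaces as a member along the lines of
//     patch_control_point<main0_in> gl_in;
// inside the patch stage-in struct, where main0_in stands for the per-vertex input struct
// (the exact names depend on the entry point).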
2269 uint32_t pcp_type_id = next_id++;
2270 auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
2271 pcp_type.basetype = SPIRType::ControlPointArray;
2272 pcp_type.parent_type = pcp_type.type_alias = ib_type.self;
2273 pcp_type.storage = storage;
2274 ir.meta[pcp_type_id] = ir.meta[ib_type.self];
2275
2276 ib_ptr_var_id = next_id;
2277 set<SPIRVariable>(ib_ptr_var_id, pcp_type_id, storage, 0);
2278 set_name(ib_ptr_var_id, "gl_in");
2279 ir.meta[ib_ptr_var_id].decoration.qualified_alias = join(patch_stage_in_var_name, ".gl_in");
2280 }
2281 return ib_ptr_var_id;
2282 }
2283
2284 // Ensure that the type is compatible with the builtin.
2285 // If it is, simply return the given type ID.
2286 // Otherwise, create a new type, and return its ID.
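// For example, a gl_Layer or gl_ViewportIndex variable declared with a signed int type in
// SPIR-V is redeclared here with a 32-bit uint base type (plus a matching pointer type if needed).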
2287 uint32_t CompilerMSL::ensure_correct_builtin_type(uint32_t type_id, BuiltIn builtin)
2288 {
2289 auto &type = get<SPIRType>(type_id);
2290
2291 if ((builtin == BuiltInSampleMask && is_array(type)) ||
2292 ((builtin == BuiltInLayer || builtin == BuiltInViewportIndex || builtin == BuiltInFragStencilRefEXT) &&
2293 type.basetype != SPIRType::UInt))
2294 {
2295 uint32_t next_id = ir.increase_bound_by(type.pointer ? 2 : 1);
2296 uint32_t base_type_id = next_id++;
2297 auto &base_type = set<SPIRType>(base_type_id);
2298 base_type.basetype = SPIRType::UInt;
2299 base_type.width = 32;
2300
2301 if (!type.pointer)
2302 return base_type_id;
2303
2304 uint32_t ptr_type_id = next_id++;
2305 auto &ptr_type = set<SPIRType>(ptr_type_id);
2306 ptr_type = base_type;
2307 ptr_type.pointer = true;
2308 ptr_type.storage = type.storage;
2309 ptr_type.parent_type = base_type_id;
2310 return ptr_type_id;
2311 }
2312
2313 return type_id;
2314 }
2315
2316 // Ensure that the type is compatible with the vertex attribute.
2317 // If it is, simply return the given type ID.
2318 // Otherwise, create a new type, and return its ID.
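// For example, if the host declares a location as MSL_VERTEX_FORMAT_UINT8 but the shader reads it
// as a signed short or int, the member type is rewritten to ushort or uint so the fetched
// attribute data keeps a consistent interpretation.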
2319 uint32_t CompilerMSL::ensure_correct_attribute_type(uint32_t type_id, uint32_t location)
2320 {
2321 auto &type = get<SPIRType>(type_id);
2322
2323 auto p_va = vtx_attrs_by_location.find(location);
2324 if (p_va == end(vtx_attrs_by_location))
2325 return type_id;
2326
2327 switch (p_va->second.format)
2328 {
2329 case MSL_VERTEX_FORMAT_UINT8:
2330 {
2331 switch (type.basetype)
2332 {
2333 case SPIRType::UByte:
2334 case SPIRType::UShort:
2335 case SPIRType::UInt:
2336 return type_id;
2337 case SPIRType::Short:
2338 case SPIRType::Int:
2339 break;
2340 default:
2341 SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
2342 }
2343 uint32_t next_id = ir.increase_bound_by(type.pointer ? 2 : 1);
2344 uint32_t base_type_id = next_id++;
2345 auto &base_type = set<SPIRType>(base_type_id);
2346 base_type = type;
2347 base_type.basetype = type.basetype == SPIRType::Short ? SPIRType::UShort : SPIRType::UInt;
2348 base_type.pointer = false;
2349
2350 if (!type.pointer)
2351 return base_type_id;
2352
2353 uint32_t ptr_type_id = next_id++;
2354 auto &ptr_type = set<SPIRType>(ptr_type_id);
2355 ptr_type = base_type;
2356 ptr_type.pointer = true;
2357 ptr_type.storage = type.storage;
2358 ptr_type.parent_type = base_type_id;
2359 return ptr_type_id;
2360 }
2361
2362 case MSL_VERTEX_FORMAT_UINT16:
2363 {
2364 switch (type.basetype)
2365 {
2366 case SPIRType::UShort:
2367 case SPIRType::UInt:
2368 return type_id;
2369 case SPIRType::Int:
2370 break;
2371 default:
2372 SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
2373 }
2374 uint32_t next_id = ir.increase_bound_by(type.pointer ? 2 : 1);
2375 uint32_t base_type_id = next_id++;
2376 auto &base_type = set<SPIRType>(base_type_id);
2377 base_type = type;
2378 base_type.basetype = SPIRType::UInt;
2379 base_type.pointer = false;
2380
2381 if (!type.pointer)
2382 return base_type_id;
2383
2384 uint32_t ptr_type_id = next_id++;
2385 auto &ptr_type = set<SPIRType>(ptr_type_id);
2386 ptr_type = base_type;
2387 ptr_type.pointer = true;
2388 ptr_type.storage = type.storage;
2389 ptr_type.parent_type = base_type_id;
2390 return ptr_type_id;
2391 }
2392
2393 default:
2394 case MSL_VERTEX_FORMAT_OTHER:
2395 break;
2396 }
2397
2398 return type_id;
2399 }
2400
2401 // Sort the members of the struct type by offset, and pack and then pad members where needed
2402 // to align MSL members with SPIR-V offsets. The struct members are iterated twice. Packing
2403 // occurs first, followed by padding, because packing a member reduces both its size and its
2404 // natural alignment, possibly requiring a padding member to be added ahead of it.
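// Illustrative example (assuming std140-style offsets): a float3 at offset 0 followed by a float
// at offset 12 forces the float3 to be emitted as packed_float3, whereas a member whose declared
// offset is farther away than MSL's natural layout allows gets an inert char-array padding member
// inserted ahead of it.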
2405 void CompilerMSL::align_struct(SPIRType &ib_type)
2406 {
2407 uint32_t &ib_type_id = ib_type.self;
2408
2409 // Sort the members of the interface structure by their offset.
2410 // They should already be sorted per SPIR-V spec anyway.
2411 MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Offset);
2412 member_sorter.sort();
2413
2414 uint32_t mbr_cnt = uint32_t(ib_type.member_types.size());
2415
2416 // Test the alignment of each member, and if a member should be closer to the previous
2417 // member than the default spacing expects, it is likely that the previous member is in
2418 // a packed format. If so, and the previous member is packable, pack it.
2419 // For example, this applies to any 3-element vector that is followed by a scalar.
2420 uint32_t curr_offset = 0;
2421 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
2422 {
2423 if (is_member_packable(ib_type, mbr_idx))
2424 {
2425 set_extended_member_decoration(ib_type_id, mbr_idx, SPIRVCrossDecorationPacked);
2426 set_extended_member_decoration(ib_type_id, mbr_idx, SPIRVCrossDecorationPackedType,
2427 get_member_packed_type(ib_type, mbr_idx));
2428 }
2429
2430 // Align current offset to the current member's default alignment.
2431 size_t align_mask = get_declared_struct_member_alignment(ib_type, mbr_idx) - 1;
2432 uint32_t aligned_curr_offset = uint32_t((curr_offset + align_mask) & ~align_mask);
2433
2434 // Fetch the member offset as declared in the SPIRV.
2435 uint32_t mbr_offset = get_member_decoration(ib_type_id, mbr_idx, DecorationOffset);
2436 if (mbr_offset > aligned_curr_offset)
2437 {
2438 // Since MSL and SPIR-V have slightly different struct member alignment and
2439 // size rules, we'll pad to standard C-packing rules. If the member is farther
2440 // away than C-packing expects, add an inert padding member before the member.
2441 MSLStructMemberKey key = get_struct_member_key(ib_type_id, mbr_idx);
2442 struct_member_padding[key] = mbr_offset - curr_offset;
2443 }
2444
2445 // Increment the current offset to be positioned immediately after the current member.
2446 // Don't do this for the last member since it can be unsized, and it is not relevant for padding purposes here.
2447 if (mbr_idx + 1 < mbr_cnt)
2448 curr_offset = mbr_offset + uint32_t(get_declared_struct_member_size_msl(ib_type, mbr_idx));
2449 }
2450 }
2451
2452 // Returns whether the specified struct member supports a packable type
2453 // variation that is smaller than the unpacked variation of that type.
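// For example, a float[] or float2[] member with an std140 array stride of four components is
// reported as packable here even though it is really a padding case; both cases reuse the same
// packed decoration downstream.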
2454 bool CompilerMSL::is_member_packable(SPIRType &ib_type, uint32_t index, uint32_t base_offset)
2455 {
2456 // We've already marked it as packable
2457 if (has_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPacked))
2458 return true;
2459
2460 auto &mbr_type = get<SPIRType>(ib_type.member_types[index]);
2461
2462 uint32_t component_size = mbr_type.width / 8;
2463 uint32_t unpacked_mbr_size;
2464 if (mbr_type.vecsize == 3)
2465 unpacked_mbr_size = component_size * (mbr_type.vecsize + 1) * mbr_type.columns;
2466 else
2467 unpacked_mbr_size = component_size * mbr_type.vecsize * mbr_type.columns;
2468
2469 // Special case for packing. Check for float[] or vec2[] in std140 layout. Here we actually need to pad out instead,
2470 // but we will use the same mechanism.
2471 if (is_array(mbr_type) && (is_scalar(mbr_type) || is_vector(mbr_type)) && mbr_type.vecsize <= 2 &&
2472 type_struct_member_array_stride(ib_type, index) == 4 * component_size)
2473 {
2474 return true;
2475 }
2476
2477 uint32_t mbr_offset_curr = base_offset + get_member_decoration(ib_type.self, index, DecorationOffset);
2478 if (mbr_type.basetype == SPIRType::Struct)
2479 {
2480 // If this is a struct type, check if any of its members need packing.
2481 for (uint32_t i = 0; i < mbr_type.member_types.size(); i++)
2482 {
2483 if (is_member_packable(mbr_type, i, mbr_offset_curr))
2484 {
2485 set_extended_member_decoration(mbr_type.self, i, SPIRVCrossDecorationPacked);
2486 set_extended_member_decoration(mbr_type.self, i, SPIRVCrossDecorationPackedType,
2487 get_member_packed_type(mbr_type, i));
2488 }
2489 }
2490 size_t declared_struct_size = get_declared_struct_size(mbr_type);
2491 size_t alignment = get_declared_struct_member_alignment(ib_type, index);
2492 declared_struct_size = (declared_struct_size + alignment - 1) & ~(alignment - 1);
2493 // Check for array of struct, where the SPIR-V declares an array stride which is larger than the struct itself.
2494 // This can happen for struct A { float a }; A a[]; in std140 layout.
2495 // TODO: Emit a padded struct which can be used for this purpose.
2496 if (is_array(mbr_type))
2497 {
2498 size_t array_stride = type_struct_member_array_stride(ib_type, index);
2499 if (array_stride > declared_struct_size)
2500 return true;
2501 if (array_stride < declared_struct_size)
2502 {
2503 // If the stride is *less* (i.e. more tightly packed), then
2504 // we need to pack the members of the struct itself.
2505 for (uint32_t i = 0; i < mbr_type.member_types.size(); i++)
2506 {
2507 if (is_member_packable(mbr_type, i, mbr_offset_curr + array_stride))
2508 {
2509 set_extended_member_decoration(mbr_type.self, i, SPIRVCrossDecorationPacked);
2510 set_extended_member_decoration(mbr_type.self, i, SPIRVCrossDecorationPackedType,
2511 get_member_packed_type(mbr_type, i));
2512 }
2513 }
2514 }
2515 }
2516 else
2517 {
2518 // Pack if there is not enough space between this member and next.
2519 if (index < ib_type.member_types.size() - 1)
2520 {
2521 uint32_t mbr_offset_next =
2522 base_offset + get_member_decoration(ib_type.self, index + 1, DecorationOffset);
2523 if (declared_struct_size > mbr_offset_next - mbr_offset_curr)
2524 {
2525 for (uint32_t i = 0; i < mbr_type.member_types.size(); i++)
2526 {
2527 if (is_member_packable(mbr_type, i, mbr_offset_next))
2528 {
2529 set_extended_member_decoration(mbr_type.self, i, SPIRVCrossDecorationPacked);
2530 set_extended_member_decoration(mbr_type.self, i, SPIRVCrossDecorationPackedType,
2531 get_member_packed_type(mbr_type, i));
2532 }
2533 }
2534 }
2535 }
2536 }
2537 }
2538
2539 // TODO: Another sanity check for matrices. We currently do not support std140 matrices which need to be padded out per column.
2540 //if (is_matrix(mbr_type) && mbr_type.vecsize <= 2 && type_struct_member_matrix_stride(ib_type, index) == 16)
2541 // SPIRV_CROSS_THROW("Currently cannot support matrices with small vector size in std140 layout.");
2542
2543 // Pack if the member's offset doesn't conform to the type's usual
2544 // alignment. For example, a float3 at offset 4.
2545 if (mbr_offset_curr % get_declared_struct_member_alignment(ib_type, index))
2546 return true;
2547
2548 // Only vectors or 3-row matrices need to be packed.
2549 if (mbr_type.vecsize == 1 || (is_matrix(mbr_type) && mbr_type.vecsize != 3))
2550 return false;
2551
2552 if (is_array(mbr_type))
2553 {
2554 // If the member is an array, and the array stride is larger than the type needs, don't pack it.
2555 // Take into consideration multi-dimensional arrays.
2556 uint32_t md_elem_cnt = 1;
2557 size_t last_elem_idx = mbr_type.array.size() - 1;
2558 for (uint32_t i = 0; i < last_elem_idx; i++)
2559 #ifdef RARCH_INTERNAL
2560 md_elem_cnt *= MAX(to_array_size_literal(mbr_type, i), 1u);
2561 #else
2562 md_elem_cnt *= max(to_array_size_literal(mbr_type, i), 1u);
2563 #endif
2564
2565 uint32_t unpacked_array_stride = unpacked_mbr_size * md_elem_cnt;
2566 uint32_t array_stride = type_struct_member_array_stride(ib_type, index);
2567 return unpacked_array_stride > array_stride;
2568 }
2569 else
2570 {
2571 // Pack if there is not enough space between this member and next.
2572 // If last member, only pack if it's a row-major matrix.
2573 if (index < ib_type.member_types.size() - 1)
2574 {
2575 uint32_t mbr_offset_next = base_offset + get_member_decoration(ib_type.self, index + 1, DecorationOffset);
2576 return unpacked_mbr_size > mbr_offset_next - mbr_offset_curr;
2577 }
2578 else
2579 return is_matrix(mbr_type);
2580 }
2581 }
2582
2583 uint32_t CompilerMSL::get_member_packed_type(SPIRType &type, uint32_t index)
2584 {
2585 auto &mbr_type = get<SPIRType>(type.member_types[index]);
2586 if (is_matrix(mbr_type) && has_member_decoration(type.self, index, DecorationRowMajor))
2587 {
2588 // Packed row-major matrices are stored transposed. But, we don't know if
2589 // we're dealing with a row-major matrix at the time we need to load it.
2590 // So, we'll set a packed type with the columns and rows transposed, so we'll
2591 // know to use the correct constructor.
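// E.g. (illustrative): a row-major member with N columns and M rows is recorded as an M-column,
// N-row packed type, so the later unpacking constructor is built with vectors of the right size.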
2592 uint32_t new_type_id = ir.increase_bound_by(1);
2593 auto &transpose_type = set<SPIRType>(new_type_id);
2594 transpose_type = mbr_type;
2595 transpose_type.vecsize = mbr_type.columns;
2596 transpose_type.columns = mbr_type.vecsize;
2597 return new_type_id;
2598 }
2599 return type.member_types[index];
2600 }
2601
2602 // Returns a combination of type ID and member index for use as hash key
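// For example, type_id 42 combined with member index 3 yields the 64-bit key 0x0000002A00000003.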
2603 MSLStructMemberKey CompilerMSL::get_struct_member_key(uint32_t type_id, uint32_t index)
2604 {
2605 MSLStructMemberKey k = type_id;
2606 k <<= 32;
2607 k += index;
2608 return k;
2609 }
2610
2611 void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression)
2612 {
2613 if (!has_extended_decoration(lhs_expression, SPIRVCrossDecorationPacked) ||
2614 get_extended_decoration(lhs_expression, SPIRVCrossDecorationPackedType) == 0)
2615 {
2616 CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
2617 }
2618 else
2619 {
2620 // Special handling when storing to a float[] or float2[] in std140 layout.
2621
2622 uint32_t type_id = get_extended_decoration(lhs_expression, SPIRVCrossDecorationPackedType);
2623 auto &type = get<SPIRType>(type_id);
2624 string lhs = to_dereferenced_expression(lhs_expression);
2625 string rhs = to_pointer_expression(rhs_expression);
2626 uint32_t stride = get_decoration(type_id, DecorationArrayStride);
2627
2628 if (is_matrix(type))
2629 {
2630 // Packed matrices are stored as arrays of packed vectors, so we need
2631 // to assign the vectors one at a time.
2632 // For row-major matrices, we need to transpose the *right-hand* side,
2633 // not the left-hand side. Otherwise, the changes will be lost.
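// Illustrative result: a store to a packed 4-column matrix member "m" expands to
//     obj.m[0] = rhs[0]; obj.m[1] = rhs[1]; obj.m[2] = rhs[2]; obj.m[3] = rhs[3];
// instead of a single matrix assignment ("obj" and "rhs" are stand-ins for the real expressions).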
2634 auto *lhs_e = maybe_get<SPIRExpression>(lhs_expression);
2635 auto *rhs_e = maybe_get<SPIRExpression>(rhs_expression);
2636 bool transpose = lhs_e && lhs_e->need_transpose;
2637 if (transpose)
2638 {
2639 lhs_e->need_transpose = false;
2640 if (rhs_e) rhs_e->need_transpose = !rhs_e->need_transpose;
2641 lhs = to_dereferenced_expression(lhs_expression);
2642 rhs = to_pointer_expression(rhs_expression);
2643 }
2644 for (uint32_t i = 0; i < type.columns; i++)
2645 statement(enclose_expression(lhs), "[", i, "] = ", enclose_expression(rhs), "[", i, "];");
2646 if (transpose)
2647 {
2648 lhs_e->need_transpose = true;
2649 if (rhs_e) rhs_e->need_transpose = !rhs_e->need_transpose;
2650 }
2651 }
2652 else if (is_array(type) && stride == 4 * type.width / 8)
2653 {
2654 // Unpack the expression so we can store to it with a float or float2.
2655 // It's still an l-value, so it's fine. Most other unpacking of expressions turn them into r-values instead.
2656 if (is_scalar(type))
2657 lhs = enclose_expression(lhs) + ".x";
2658 else if (is_vector(type) && type.vecsize == 2)
2659 lhs = enclose_expression(lhs) + ".xy";
2660 }
2661
2662 if (!is_matrix(type))
2663 {
2664 if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
2665 statement(lhs, " = ", rhs, ";");
2666 }
2667 register_write(lhs_expression);
2668 }
2669 }
2670
2671 // Converts the format of the current expression from packed to unpacked,
2672 // by wrapping the expression in a constructor of the appropriate type.
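// Illustrative examples: a packed_float3 expression "v" unpacks as "float3(v)", and a packed
// float2x2 member "m" unpacks column-by-column as "float2x2(float2(m[0]), float2(m[1]))".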
2673 string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type, uint32_t packed_type_id)
2674 {
2675 const SPIRType *packed_type = nullptr;
2676 uint32_t stride = 0;
2677 if (packed_type_id)
2678 {
2679 packed_type = &get<SPIRType>(packed_type_id);
2680 stride = get_decoration(packed_type_id, DecorationArrayStride);
2681 }
2682
2683 // float[] and float2[] cases are really just padding, so directly swizzle from the backing float4 instead.
2684 if (packed_type && is_array(*packed_type) && is_scalar(*packed_type) && stride == 4 * packed_type->width / 8)
2685 return enclose_expression(expr_str) + ".x";
2686 else if (packed_type && is_array(*packed_type) && is_vector(*packed_type) && packed_type->vecsize == 2 &&
2687 stride == 4 * packed_type->width / 8)
2688 return enclose_expression(expr_str) + ".xy";
2689 else if (is_matrix(type))
2690 {
2691 // Packed matrices are stored as arrays of packed vectors. Unfortunately,
2692 // we can't just pass the array straight to the matrix constructor. We have to
2693 // pass each vector individually, so that they can be unpacked to normal vectors.
2694 if (!packed_type)
2695 packed_type = &type;
2696 const char *base_type = packed_type->width == 16 ? "half" : "float";
2697 string unpack_expr = join(type_to_glsl(*packed_type), "(");
2698 for (uint32_t i = 0; i < packed_type->columns; i++)
2699 {
2700 if (i > 0)
2701 unpack_expr += ", ";
2702 unpack_expr += join(base_type, packed_type->vecsize, "(", expr_str, "[", i, "])");
2703 }
2704 unpack_expr += ")";
2705 return unpack_expr;
2706 }
2707 else
2708 return join(type_to_glsl(type), "(", expr_str, ")");
2709 }
2710
2711 // Emits the file header info
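// Typical emitted preamble (assuming no extra pragma, header, or typedef lines are registered):
//     #include <metal_stdlib>
//     #include <simd/simd.h>
//
//     using namespace metal;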
2712 void CompilerMSL::emit_header()
2713 {
2714 // This particular line can be overridden during compilation, so make it a flag and not a pragma line.
2715 if (suppress_missing_prototypes)
2716 statement("#pragma clang diagnostic ignored \"-Wmissing-prototypes\"");
2717 for (auto &pragma : pragma_lines)
2718 statement(pragma);
2719
2720 if (!pragma_lines.empty() || suppress_missing_prototypes)
2721 statement("");
2722
2723 statement("#include <metal_stdlib>");
2724 statement("#include <simd/simd.h>");
2725
2726 for (auto &header : header_lines)
2727 statement(header);
2728
2729 statement("");
2730 statement("using namespace metal;");
2731 statement("");
2732
2733 for (auto &td : typedef_lines)
2734 statement(td);
2735
2736 if (!typedef_lines.empty())
2737 statement("");
2738 }
2739
2740 void CompilerMSL::add_pragma_line(const string &line)
2741 {
2742 auto rslt = pragma_lines.insert(line);
2743 if (rslt.second)
2744 force_recompile();
2745 }
2746
2747 void CompilerMSL::add_typedef_line(const string &line)
2748 {
2749 auto rslt = typedef_lines.insert(line);
2750 if (rslt.second)
2751 force_recompile();
2752 }
2753
2754 // Emits any needed custom function bodies.
2755 void CompilerMSL::emit_custom_functions()
2756 {
2757 for (uint32_t i = SPVFuncImplArrayCopyMultidimMax; i >= 2; i--)
2758 if (spv_function_implementations.count(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i)))
2759 spv_function_implementations.insert(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i - 1));
2760
2761 for (auto &spv_func : spv_function_implementations)
2762 {
2763 switch (spv_func)
2764 {
2765 case SPVFuncImplMod:
2766 statement("// Implementation of the GLSL mod() function, which is slightly different than Metal fmod()");
2767 statement("template<typename Tx, typename Ty>");
2768 statement("Tx mod(Tx x, Ty y)");
2769 begin_scope();
2770 statement("return x - y * floor(x / y);");
2771 end_scope();
2772 statement("");
2773 break;
2774
2775 case SPVFuncImplRadians:
2776 statement("// Implementation of the GLSL radians() function");
2777 statement("template<typename T>");
2778 statement("T radians(T d)");
2779 begin_scope();
2780 statement("return d * T(0.01745329251);");
2781 end_scope();
2782 statement("");
2783 break;
2784
2785 case SPVFuncImplDegrees:
2786 statement("// Implementation of the GLSL degrees() function");
2787 statement("template<typename T>");
2788 statement("T degrees(T r)");
2789 begin_scope();
2790 statement("return r * T(57.2957795131);");
2791 end_scope();
2792 statement("");
2793 break;
2794
2795 case SPVFuncImplFindILsb:
2796 statement("// Implementation of the GLSL findLSB() function");
2797 statement("template<typename T>");
2798 statement("T findLSB(T x)");
2799 begin_scope();
2800 statement("return select(ctz(x), T(-1), x == T(0));");
2801 end_scope();
2802 statement("");
2803 break;
2804
2805 case SPVFuncImplFindUMsb:
2806 statement("// Implementation of the unsigned GLSL findMSB() function");
2807 statement("template<typename T>");
2808 statement("T findUMSB(T x)");
2809 begin_scope();
2810 statement("return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));");
2811 end_scope();
2812 statement("");
2813 break;
2814
2815 case SPVFuncImplFindSMsb:
2816 statement("// Implementation of the signed GLSL findMSB() function");
2817 statement("template<typename T>");
2818 statement("T findSMSB(T x)");
2819 begin_scope();
2820 statement("T v = select(x, T(-1) - x, x < T(0));");
2821 statement("return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));");
2822 end_scope();
2823 statement("");
2824 break;
2825
2826 case SPVFuncImplSSign:
2827 statement("// Implementation of the GLSL sign() function for integer types");
2828 statement("template<typename T, typename E = typename enable_if<is_integral<T>::value>::type>");
2829 statement("T sign(T x)");
2830 begin_scope();
2831 statement("return select(select(select(x, T(0), x == T(0)), T(1), x > T(0)), T(-1), x < T(0));");
2832 end_scope();
2833 statement("");
2834 break;
2835
2836 case SPVFuncImplArrayCopy:
2837 statement("// Implementation of an array copy function to cover GLSL's ability to copy an array via "
2838 "assignment.");
2839 statement("template<typename T, uint N>");
2840 statement("void spvArrayCopyFromStack1(thread T (&dst)[N], thread const T (&src)[N])");
2841 begin_scope();
2842 statement("for (uint i = 0; i < N; dst[i] = src[i], i++);");
2843 end_scope();
2844 statement("");
2845
2846 statement("template<typename T, uint N>");
2847 statement("void spvArrayCopyFromConstant1(thread T (&dst)[N], constant T (&src)[N])");
2848 begin_scope();
2849 statement("for (uint i = 0; i < N; dst[i] = src[i], i++);");
2850 end_scope();
2851 statement("");
2852 break;
2853
2854 case SPVFuncImplArrayOfArrayCopy2Dim:
2855 case SPVFuncImplArrayOfArrayCopy3Dim:
2856 case SPVFuncImplArrayOfArrayCopy4Dim:
2857 case SPVFuncImplArrayOfArrayCopy5Dim:
2858 case SPVFuncImplArrayOfArrayCopy6Dim:
2859 {
2860 static const char *function_name_tags[] = {
2861 "FromStack",
2862 "FromConstant",
2863 };
2864
2865 static const char *src_address_space[] = {
2866 "thread const",
2867 "constant",
2868 };
2869
2870 for (uint32_t variant = 0; variant < 2; variant++)
2871 {
2872 uint32_t dimensions = spv_func - SPVFuncImplArrayCopyMultidimBase;
2873 string tmp = "template<typename T";
2874 for (uint8_t i = 0; i < dimensions; i++)
2875 {
2876 tmp += ", uint ";
2877 tmp += 'A' + i;
2878 }
2879 tmp += ">";
2880 statement(tmp);
2881
2882 string array_arg;
2883 for (uint8_t i = 0; i < dimensions; i++)
2884 {
2885 array_arg += "[";
2886 array_arg += 'A' + i;
2887 array_arg += "]";
2888 }
2889
2890 statement("void spvArrayCopy", function_name_tags[variant], dimensions, "(thread T (&dst)", array_arg,
2891 ", ", src_address_space[variant], " T (&src)", array_arg, ")");
2892
2893 begin_scope();
2894 statement("for (uint i = 0; i < A; i++)");
2895 begin_scope();
2896 statement("spvArrayCopy", function_name_tags[variant], dimensions - 1, "(dst[i], src[i]);");
2897 end_scope();
2898 end_scope();
2899 statement("");
2900 }
2901 break;
2902 }
2903
2904 case SPVFuncImplTexelBufferCoords:
2905 {
2906 string tex_width_str = convert_to_string(msl_options.texel_buffer_texture_width);
2907 statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
2908 statement("uint2 spvTexelBufferCoord(uint tc)");
2909 begin_scope();
2910 statement(join("return uint2(tc % ", tex_width_str, ", tc / ", tex_width_str, ");"));
2911 end_scope();
2912 statement("");
2913 break;
2914 }
2915
2916 case SPVFuncImplInverse4x4:
2917 statement("// Returns the determinant of a 2x2 matrix.");
2918 statement("inline float spvDet2x2(float a1, float a2, float b1, float b2)");
2919 begin_scope();
2920 statement("return a1 * b2 - b1 * a2;");
2921 end_scope();
2922 statement("");
2923
2924 statement("// Returns the determinant of a 3x3 matrix.");
2925 statement("inline float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
2926 "float c2, float c3)");
2927 begin_scope();
2928 statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * spvDet2x2(a2, a3, "
2929 "b2, b3);");
2930 end_scope();
2931 statement("");
2932 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
2933 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
2934 statement("float4x4 spvInverse4x4(float4x4 m)");
2935 begin_scope();
2936 statement("float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)");
2937 statement_no_indent("");
2938 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
2939 statement("adj[0][0] = spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
2940 "m[3][3]);");
2941 statement("adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
2942 "m[3][3]);");
2943 statement("adj[0][2] = spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
2944 "m[3][3]);");
2945 statement("adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
2946 "m[2][3]);");
2947 statement_no_indent("");
2948 statement("adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
2949 "m[3][3]);");
2950 statement("adj[1][1] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
2951 "m[3][3]);");
2952 statement("adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
2953 "m[3][3]);");
2954 statement("adj[1][3] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
2955 "m[2][3]);");
2956 statement_no_indent("");
2957 statement("adj[2][0] = spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
2958 "m[3][3]);");
2959 statement("adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
2960 "m[3][3]);");
2961 statement("adj[2][2] = spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
2962 "m[3][3]);");
2963 statement("adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
2964 "m[2][3]);");
2965 statement_no_indent("");
2966 statement("adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
2967 "m[3][2]);");
2968 statement("adj[3][1] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
2969 "m[3][2]);");
2970 statement("adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
2971 "m[3][2]);");
2972 statement("adj[3][3] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
2973 "m[2][2]);");
2974 statement_no_indent("");
2975 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
2976 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
2977 "* m[3][0]);");
2978 statement_no_indent("");
2979 statement("// Divide the classical adjoint matrix by the determinant.");
2980 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
2981 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
2982 end_scope();
2983 statement("");
2984 break;
2985
2986 case SPVFuncImplInverse3x3:
2987 if (spv_function_implementations.count(SPVFuncImplInverse4x4) == 0)
2988 {
2989 statement("// Returns the determinant of a 2x2 matrix.");
2990 statement("inline float spvDet2x2(float a1, float a2, float b1, float b2)");
2991 begin_scope();
2992 statement("return a1 * b2 - b1 * a2;");
2993 end_scope();
2994 statement("");
2995 }
2996
2997 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
2998 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
2999 statement("float3x3 spvInverse3x3(float3x3 m)");
3000 begin_scope();
3001 statement("float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)");
3002 statement_no_indent("");
3003 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
3004 statement("adj[0][0] = spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
3005 statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
3006 statement("adj[0][2] = spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
3007 statement_no_indent("");
3008 statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
3009 statement("adj[1][1] = spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
3010 statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
3011 statement_no_indent("");
3012 statement("adj[2][0] = spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
3013 statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
3014 statement("adj[2][2] = spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
3015 statement_no_indent("");
3016 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
3017 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
3018 statement_no_indent("");
3019 statement("// Divide the classical adjoint matrix by the determinant.");
3020 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
3021 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
3022 end_scope();
3023 statement("");
3024 break;
3025
3026 case SPVFuncImplInverse2x2:
3027 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
3028 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
3029 statement("float2x2 spvInverse2x2(float2x2 m)");
3030 begin_scope();
3031 statement("float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)");
3032 statement_no_indent("");
3033 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
3034 statement("adj[0][0] = m[1][1];");
3035 statement("adj[0][1] = -m[0][1];");
3036 statement_no_indent("");
3037 statement("adj[1][0] = -m[1][0];");
3038 statement("adj[1][1] = m[0][0];");
3039 statement_no_indent("");
3040 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
3041 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
3042 statement_no_indent("");
3043 statement("// Divide the classical adjoint matrix by the determinant.");
3044 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
3045 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
3046 end_scope();
3047 statement("");
3048 break;
3049
3050 case SPVFuncImplRowMajor2x3:
3051 statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
3052 statement("float2x3 spvConvertFromRowMajor2x3(float2x3 m)");
3053 begin_scope();
3054 statement("return float2x3(float3(m[0][0], m[0][2], m[1][1]), float3(m[0][1], m[1][0], m[1][2]));");
3055 end_scope();
3056 statement("");
3057 break;
3058
3059 case SPVFuncImplRowMajor2x4:
3060 statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
3061 statement("float2x4 spvConvertFromRowMajor2x4(float2x4 m)");
3062 begin_scope();
3063 statement("return float2x4(float4(m[0][0], m[0][2], m[1][0], m[1][2]), float4(m[0][1], m[0][3], m[1][1], "
3064 "m[1][3]));");
3065 end_scope();
3066 statement("");
3067 break;
3068
3069 case SPVFuncImplRowMajor3x2:
3070 statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
3071 statement("float3x2 spvConvertFromRowMajor3x2(float3x2 m)");
3072 begin_scope();
3073 statement("return float3x2(float2(m[0][0], m[1][1]), float2(m[0][1], m[2][0]), float2(m[1][0], m[2][1]));");
3074 end_scope();
3075 statement("");
3076 break;
3077
3078 case SPVFuncImplRowMajor3x4:
3079 statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
3080 statement("float3x4 spvConvertFromRowMajor3x4(float3x4 m)");
3081 begin_scope();
3082 statement("return float3x4(float4(m[0][0], m[0][3], m[1][2], m[2][1]), float4(m[0][1], m[1][0], m[1][3], "
3083 "m[2][2]), float4(m[0][2], m[1][1], m[2][0], m[2][3]));");
3084 end_scope();
3085 statement("");
3086 break;
3087
3088 case SPVFuncImplRowMajor4x2:
3089 statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
3090 statement("float4x2 spvConvertFromRowMajor4x2(float4x2 m)");
3091 begin_scope();
3092 statement("return float4x2(float2(m[0][0], m[2][0]), float2(m[0][1], m[2][1]), float2(m[1][0], m[3][0]), "
3093 "float2(m[1][1], m[3][1]));");
3094 end_scope();
3095 statement("");
3096 break;
3097
3098 case SPVFuncImplRowMajor4x3:
3099 statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
3100 statement("float4x3 spvConvertFromRowMajor4x3(float4x3 m)");
3101 begin_scope();
3102 statement("return float4x3(float3(m[0][0], m[1][1], m[2][2]), float3(m[0][1], m[1][2], m[3][0]), "
3103 "float3(m[0][2], m[2][0], m[3][1]), float3(m[1][0], m[2][1], m[3][2]));");
3104 end_scope();
3105 statement("");
3106 break;
3107
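// The swizzle constant 's'/'sw' used by the helpers below packs four spvSwizzle values,
// one per byte: bits 0-7 select the red channel's source, 8-15 green, 16-23 blue and
// 24-31 alpha. A value of 0 (spvSwizzle::none in every byte) is the identity swizzle,
// which is why spvTextureSwizzle can return the sample unchanged when !s.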
3108 case SPVFuncImplTextureSwizzle:
3109 statement("enum class spvSwizzle : uint");
3110 begin_scope();
3111 statement("none = 0,");
3112 statement("zero,");
3113 statement("one,");
3114 statement("red,");
3115 statement("green,");
3116 statement("blue,");
3117 statement("alpha");
3118 end_scope_decl();
3119 statement("");
3120 statement("template<typename T> struct spvRemoveReference { typedef T type; };");
3121 statement("template<typename T> struct spvRemoveReference<thread T&> { typedef T type; };");
3122 statement("template<typename T> struct spvRemoveReference<thread T&&> { typedef T type; };");
3123 statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
3124 "spvRemoveReference<T>::type& x)");
3125 begin_scope();
3126 statement("return static_cast<thread T&&>(x);");
3127 end_scope();
3128 statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
3129 "spvRemoveReference<T>::type&& x)");
3130 begin_scope();
3131 statement("return static_cast<thread T&&>(x);");
3132 end_scope();
3133 statement("");
3134 statement("template<typename T>");
3135 statement("inline T spvGetSwizzle(vec<T, 4> x, T c, spvSwizzle s)");
3136 begin_scope();
3137 statement("switch (s)");
3138 begin_scope();
3139 statement("case spvSwizzle::none:");
3140 statement(" return c;");
3141 statement("case spvSwizzle::zero:");
3142 statement(" return 0;");
3143 statement("case spvSwizzle::one:");
3144 statement(" return 1;");
3145 statement("case spvSwizzle::red:");
3146 statement(" return x.r;");
3147 statement("case spvSwizzle::green:");
3148 statement(" return x.g;");
3149 statement("case spvSwizzle::blue:");
3150 statement(" return x.b;");
3151 statement("case spvSwizzle::alpha:");
3152 statement(" return x.a;");
3153 end_scope();
3154 end_scope();
3155 statement("");
3156 statement("// Wrapper function that swizzles texture samples and fetches.");
3157 statement("template<typename T>");
3158 statement("inline vec<T, 4> spvTextureSwizzle(vec<T, 4> x, uint s)");
3159 begin_scope();
3160 statement("if (!s)");
3161 statement(" return x;");
3162 statement("return vec<T, 4>(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), "
3163 "spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) "
3164 "& 0xFF)), "
3165 "spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));");
3166 end_scope();
3167 statement("");
3168 statement("template<typename T>");
3169 statement("inline T spvTextureSwizzle(T x, uint s)");
3170 begin_scope();
3171 statement("return spvTextureSwizzle(vec<T, 4>(x, 0, 0, 1), s).x;");
3172 end_scope();
3173 statement("");
3174 statement("// Wrapper function that swizzles texture gathers.");
3175 statement("template<typename T, typename Tex, typename... Ts>");
3176 statement(
3177 "inline vec<T, 4> spvGatherSwizzle(sampler s, const thread Tex& t, Ts... params, component c, uint sw) "
3178 "METAL_CONST_ARG(c)");
3179 begin_scope();
3180 statement("if (sw)");
3181 begin_scope();
3182 statement("switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))");
3183 begin_scope();
3184 statement("case spvSwizzle::none:");
3185 statement(" break;");
3186 statement("case spvSwizzle::zero:");
3187 statement(" return vec<T, 4>(0, 0, 0, 0);");
3188 statement("case spvSwizzle::one:");
3189 statement(" return vec<T, 4>(1, 1, 1, 1);");
3190 statement("case spvSwizzle::red:");
3191 statement(" return t.gather(s, spvForward<Ts>(params)..., component::x);");
3192 statement("case spvSwizzle::green:");
3193 statement(" return t.gather(s, spvForward<Ts>(params)..., component::y);");
3194 statement("case spvSwizzle::blue:");
3195 statement(" return t.gather(s, spvForward<Ts>(params)..., component::z);");
3196 statement("case spvSwizzle::alpha:");
3197 statement(" return t.gather(s, spvForward<Ts>(params)..., component::w);");
3198 end_scope();
3199 end_scope();
3200 // texture::gather insists on its component parameter being a constant
3201 // expression, so we need this silly workaround just to compile the shader.
3202 statement("switch (c)");
3203 begin_scope();
3204 statement("case component::x:");
3205 statement(" return t.gather(s, spvForward<Ts>(params)..., component::x);");
3206 statement("case component::y:");
3207 statement(" return t.gather(s, spvForward<Ts>(params)..., component::y);");
3208 statement("case component::z:");
3209 statement(" return t.gather(s, spvForward<Ts>(params)..., component::z);");
3210 statement("case component::w:");
3211 statement(" return t.gather(s, spvForward<Ts>(params)..., component::w);");
3212 end_scope();
3213 end_scope();
3214 statement("");
3215 statement("// Wrapper function that swizzles depth texture gathers.");
3216 statement("template<typename T, typename Tex, typename... Ts>");
3217 statement(
3218 "inline vec<T, 4> spvGatherCompareSwizzle(sampler s, const thread Tex& t, Ts... params, uint sw) ");
3219 begin_scope();
3220 statement("if (sw)");
3221 begin_scope();
3222 statement("switch (spvSwizzle(sw & 0xFF))");
3223 begin_scope();
3224 statement("case spvSwizzle::none:");
3225 statement("case spvSwizzle::red:");
3226 statement(" break;");
3227 statement("case spvSwizzle::zero:");
3228 statement("case spvSwizzle::green:");
3229 statement("case spvSwizzle::blue:");
3230 statement("case spvSwizzle::alpha:");
3231 statement(" return vec<T, 4>(0, 0, 0, 0);");
3232 statement("case spvSwizzle::one:");
3233 statement(" return vec<T, 4>(1, 1, 1, 1);");
3234 end_scope();
3235 end_scope();
3236 statement("return t.gather_compare(s, spvForward<Ts>(params)...);");
3237 end_scope();
3238 statement("");
3239 break;
3240
3241 case SPVFuncImplSubgroupBallot:
3242 statement("inline uint4 spvSubgroupBallot(bool value)");
3243 begin_scope();
3244 statement("simd_vote vote = simd_ballot(value);");
3245 statement("// simd_ballot() returns a 64-bit integer-like object, but");
3246 statement("// SPIR-V callers expect a uint4. We must convert.");
3247 statement("// FIXME: This won't include higher bits if Apple ever supports");
3248 statement("// 128 lanes in an SIMD-group.");
3249 statement("return uint4((uint)((simd_vote::vote_t)vote & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)vote >> "
3250 "32) & 0xFFFFFFFF), 0, 0);");
3251 end_scope();
3252 statement("");
3253 break;
3254
3255 case SPVFuncImplSubgroupBallotBitExtract:
3256 statement("inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)");
3257 begin_scope();
3258 statement("return !!extract_bits(ballot[bit / 32], bit % 32, 1);");
3259 end_scope();
3260 statement("");
3261 break;
3262
3263 case SPVFuncImplSubgroupBallotFindLSB:
3264 statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot)");
3265 begin_scope();
3266 statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
3267 "ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
3268 end_scope();
3269 statement("");
3270 break;
3271
3272 case SPVFuncImplSubgroupBallotFindMSB:
3273 statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot)");
3274 begin_scope();
3275 statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
3276 "(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
3277 "ballot.z == 0), ballot.w == 0);");
3278 end_scope();
3279 statement("");
3280 break;
3281
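// For the inclusive/exclusive variants below, the mask built from extract_bits() keeps
// only the ballot bits at or below (inclusive) or strictly below (exclusive) the calling
// lane. For example, with gl_SubgroupInvocationID == 5 the inclusive mask is
// uint4(0x3F, 0, 0, 0), so only lanes 0-5 contribute to the count.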
3282 case SPVFuncImplSubgroupBallotBitCount:
3283 statement("inline uint spvSubgroupBallotBitCount(uint4 ballot)");
3284 begin_scope();
3285 statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
3286 end_scope();
3287 statement("");
3288 statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
3289 begin_scope();
3290 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
3291 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
3292 "uint2(0));");
3293 statement("return spvSubgroupBallotBitCount(ballot & mask);");
3294 end_scope();
3295 statement("");
3296 statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
3297 begin_scope();
3298 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
3299 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
3300 statement("return spvSubgroupBallotBitCount(ballot & mask);");
3301 end_scope();
3302 statement("");
3303 break;
3304
3305 case SPVFuncImplSubgroupAllEqual:
3306 // Metal doesn't provide a function to evaluate this directly. But, we can
3307 // implement this by comparing every thread's value to one thread's value
3308 // (in this case, the value of the first active thread). Then, by the transitive
3309 // property of equality, if all comparisons return true, then they are all equal.
3310 statement("template<typename T>");
3311 statement("inline bool spvSubgroupAllEqual(T value)");
3312 begin_scope();
3313 statement("return simd_all(value == simd_broadcast_first(value));");
3314 end_scope();
3315 statement("");
3316 statement("template<>");
3317 statement("inline bool spvSubgroupAllEqual(bool value)");
3318 begin_scope();
3319 statement("return simd_all(value) || !simd_any(value);");
3320 end_scope();
3321 statement("");
3322 break;
3323
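// GLSL defines reflect(I, N) as I - 2 * dot(N, I) * N; for scalars dot(N, I) is simply
// n * i, which is why the emitted expression reads i - T(2) * i * n * n.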
3324 case SPVFuncImplReflectScalar:
3325 // Metal does not support scalar versions of these functions.
3326 statement("template<typename T>");
3327 statement("inline T spvReflect(T i, T n)");
3328 begin_scope();
3329 statement("return i - T(2) * i * n * n;");
3330 end_scope();
3331 statement("");
3332 break;
3333
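// Scalar refract follows the GLSL definition: k = 1 - eta^2 * (1 - dot(N, I)^2);
// if k < 0 the result is 0, otherwise eta * I - (eta * dot(N, I) + sqrt(k)) * N.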
3334 case SPVFuncImplRefractScalar:
3335 // Metal does not support scalar versions of these functions.
3336 statement("template<typename T>");
3337 statement("inline T spvRefract(T i, T n, T eta)");
3338 begin_scope();
3339 statement("T NoI = n * i;");
3340 statement("T NoI2 = NoI * NoI;");
3341 statement("T k = T(1) - eta * eta * (T(1) - NoI2);");
3342 statement("if (k < T(0))");
3343 begin_scope();
3344 statement("return T(0);");
3345 end_scope();
3346 statement("else");
3347 begin_scope();
3348 statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
3349 end_scope();
3350 end_scope();
3351 statement("");
3352 break;
3353
3354 default:
3355 break;
3356 }
3357 }
3358 }
3359
3360 // Undefined global memory is not allowed in MSL.
3361 // Declare constant and init to zeros. Use {}, as global constructors can break Metal.
3362 void CompilerMSL::declare_undefined_values()
3363 {
3364 bool emitted = false;
3365 ir.for_each_typed_id<SPIRUndef>([&](uint32_t, SPIRUndef &undef) {
3366 auto &type = this->get<SPIRType>(undef.basetype);
3367 statement("constant ", variable_decl(type, to_name(undef.self), undef.self), " = {};");
3368 emitted = true;
3369 });
3370
3371 if (emitted)
3372 statement("");
3373 }
3374
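// As a rough illustration (names and values here are made up), a float[3] constant in the
// SPIR-V ends up emitted at file scope as something like:
//
//     constant float _42[3] = { 1.0, 2.0, 3.0 };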
3375 void CompilerMSL::declare_constant_arrays()
3376 {
3377 // MSL cannot declare arrays inline (except when declaring a variable), so we must move them out to
3378 // global constants, which lets us use those constants as variable expressions.
3379 bool emitted = false;
3380
3381 ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
3382 if (c.specialization)
3383 return;
3384
3385 auto &type = this->get<SPIRType>(c.constant_type);
3386 if (!type.array.empty())
3387 {
3388 auto name = to_name(c.self);
3389 statement("constant ", variable_decl(type, name), " = ", constant_expression(c), ";");
3390 emitted = true;
3391 }
3392 });
3393
3394 if (emitted)
3395 statement("");
3396 }
3397
3398 void CompilerMSL::emit_resources()
3399 {
3400 declare_constant_arrays();
3401 declare_undefined_values();
3402
3403 // Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created.
3404 emit_interface_block(stage_out_var_id);
3405 emit_interface_block(patch_stage_out_var_id);
3406 emit_interface_block(stage_in_var_id);
3407 emit_interface_block(patch_stage_in_var_id);
3408 }
3409
3410 // Emit declarations for the specialization Metal function constants
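// For a scalar specialization constant with SpecId 10 on MSL 1.2+, the generated code looks
// roughly like this (names and the default value are illustrative):
//
//     constant float scale_tmp [[function_constant(10)]];
//     constant float scale = is_function_constant_defined(scale_tmp) ? scale_tmp : 1.0;
//
// Older MSL versions and constants used as array lengths fall back to #define-based macros.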
3411 void CompilerMSL::emit_specialization_constants_and_structs()
3412 {
3413 SpecializationConstant wg_x, wg_y, wg_z;
3414 uint32_t workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
3415 bool emitted = false;
3416
3417 unordered_set<uint32_t> declared_structs;
3418
3419 for (auto &id_ : ir.ids_for_constant_or_type)
3420 {
3421 auto &id = ir.ids[id_];
3422
3423 if (id.get_type() == TypeConstant)
3424 {
3425 auto &c = id.get<SPIRConstant>();
3426
3427 if (c.self == workgroup_size_id)
3428 {
3429 // TODO: This can be expressed as a [[threads_per_threadgroup]] input semantic, but we need to know
3430 // the work group size at compile time in SPIR-V, and [[threads_per_threadgroup]] would need to be passed around as a global.
3431 // The work group size may be a specialization constant.
3432 statement("constant uint3 ", builtin_to_glsl(BuiltInWorkgroupSize, StorageClassWorkgroup),
3433 " [[maybe_unused]] = ", constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
3434 emitted = true;
3435 }
3436 else if (c.specialization)
3437 {
3438 auto &type = get<SPIRType>(c.constant_type);
3439 string sc_type_name = type_to_glsl(type);
3440 string sc_name = to_name(c.self);
3441 string sc_tmp_name = sc_name + "_tmp";
3442
3443 // Function constants are only supported in MSL 1.2 and later.
3444 // If we don't support it, just declare the "default" directly.
3445 // This "default" value can be overridden with the true specialization constant by the API user.
3446 // Specialization constants which are used as array length expressions cannot be function constants in MSL,
3447 // so just fall back to macros.
3448 if (msl_options.supports_msl_version(1, 2) && has_decoration(c.self, DecorationSpecId) &&
3449 !c.is_used_as_array_length)
3450 {
3451 uint32_t constant_id = get_decoration(c.self, DecorationSpecId);
3452 // Only scalar, non-composite values can be function constants.
3453 statement("constant ", sc_type_name, " ", sc_tmp_name, " [[function_constant(", constant_id,
3454 ")]];");
3455 statement("constant ", sc_type_name, " ", sc_name, " = is_function_constant_defined(", sc_tmp_name,
3456 ") ? ", sc_tmp_name, " : ", constant_expression(c), ";");
3457 }
3458 else if (has_decoration(c.self, DecorationSpecId))
3459 {
3460 // Fallback to macro overrides.
3461 c.specialization_constant_macro_name =
3462 constant_value_macro_name(get_decoration(c.self, DecorationSpecId));
3463
3464 statement("#ifndef ", c.specialization_constant_macro_name);
3465 statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
3466 statement("#endif");
3467 statement("constant ", sc_type_name, " ", sc_name, " = ", c.specialization_constant_macro_name,
3468 ";");
3469 }
3470 else
3471 {
3472 // Composite specialization constants must be built from other specialization constants.
3473 statement("constant ", sc_type_name, " ", sc_name, " = ", constant_expression(c), ";");
3474 }
3475 emitted = true;
3476 }
3477 }
3478 else if (id.get_type() == TypeConstantOp)
3479 {
3480 auto &c = id.get<SPIRConstantOp>();
3481 auto &type = get<SPIRType>(c.basetype);
3482 auto name = to_name(c.self);
3483 statement("constant ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
3484 emitted = true;
3485 }
3486 else if (id.get_type() == TypeType)
3487 {
3488 // Output non-builtin interface structs. These include local function structs
3489 // and structs nested within uniform and read-write buffers.
3490 auto &type = id.get<SPIRType>();
3491 uint32_t type_id = type.self;
3492
3493 bool is_struct = (type.basetype == SPIRType::Struct) && type.array.empty();
3494 bool is_block =
3495 has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
3496
3497 bool is_builtin_block = is_block && is_builtin_type(type);
3498 bool is_declarable_struct = is_struct && !is_builtin_block;
3499
3500 // We'll declare this later.
3501 if (stage_out_var_id && get_stage_out_struct_type().self == type_id)
3502 is_declarable_struct = false;
3503 if (patch_stage_out_var_id && get_patch_stage_out_struct_type().self == type_id)
3504 is_declarable_struct = false;
3505 if (stage_in_var_id && get_stage_in_struct_type().self == type_id)
3506 is_declarable_struct = false;
3507 if (patch_stage_in_var_id && get_patch_stage_in_struct_type().self == type_id)
3508 is_declarable_struct = false;
3509
3510 // Align and emit declarable structs...but avoid declaring each more than once.
3511 if (is_declarable_struct && declared_structs.count(type_id) == 0)
3512 {
3513 if (emitted)
3514 statement("");
3515 emitted = false;
3516
3517 declared_structs.insert(type_id);
3518
3519 if (has_extended_decoration(type_id, SPIRVCrossDecorationPacked))
3520 align_struct(type);
3521
3522 // Make sure we declare the underlying struct type, and not the "decorated" type with pointers, etc.
3523 emit_struct(get<SPIRType>(type_id));
3524 }
3525 }
3526 }
3527
3528 if (emitted)
3529 statement("");
3530 }
3531
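// Emits an "unordered" floating-point comparison: the result is true if either operand is
// NaN or if the ordered comparison holds. For example, OpFUnordGreaterThan on a and b
// becomes roughly (isunordered(a, b) || a > b).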
3532 void CompilerMSL::emit_binary_unord_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
3533 const char *op)
3534 {
3535 bool forward = should_forward(op0) && should_forward(op1);
3536 emit_op(result_type, result_id,
3537 join("(isunordered(", to_enclosed_unpacked_expression(op0), ", ", to_enclosed_unpacked_expression(op1),
3538 ") || ", to_enclosed_unpacked_expression(op0), " ", op, " ", to_enclosed_unpacked_expression(op1),
3539 ")"),
3540 forward);
3541
3542 inherit_expression_dependencies(result_id, op0);
3543 inherit_expression_dependencies(result_id, op1);
3544 }
3545
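// Handles access chains into tessellation-stage I/O: per-vertex inputs (and control-point
// outputs) are remapped onto the stage-in/stage-out array buffers, and the inner
// tessellation level is special-cased in Triangles mode, where Metal exposes only a single
// element. Returns true if the access chain was fully handled here; otherwise the caller
// falls back to the regular GLSL handling.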
3546 bool CompilerMSL::emit_tessellation_access_chain(const uint32_t *ops, uint32_t length)
3547 {
3548 // If this is a per-vertex output, remap it to the I/O array buffer.
3549 auto *var = maybe_get<SPIRVariable>(ops[2]);
3550 BuiltIn bi_type = BuiltIn(get_decoration(ops[2], DecorationBuiltIn));
3551 if (var &&
3552 (var->storage == StorageClassInput ||
3553 (get_execution_model() == ExecutionModelTessellationControl && var->storage == StorageClassOutput)) &&
3554 !(has_decoration(ops[2], DecorationPatch) || is_patch_block(get_variable_data_type(*var))) &&
3555 (!is_builtin_variable(*var) || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
3556 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
3557 get_variable_data_type(*var).basetype == SPIRType::Struct))
3558 {
3559 AccessChainMeta meta;
3560 SmallVector<uint32_t> indices;
3561 uint32_t next_id = ir.increase_bound_by(2);
3562
3563 indices.reserve(length - 3 + 1);
3564 uint32_t type_id = next_id++;
3565 SPIRType new_uint_type;
3566 new_uint_type.basetype = SPIRType::UInt;
3567 new_uint_type.width = 32;
3568 set<SPIRType>(type_id, new_uint_type);
3569
3570 indices.push_back(ops[3]);
3571
3572 uint32_t const_mbr_id = next_id++;
3573 uint32_t index = get_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex);
3574 uint32_t ptr = var->storage == StorageClassInput ? stage_in_ptr_var_id : stage_out_ptr_var_id;
3575 if (var->storage == StorageClassInput || has_decoration(get_variable_element_type(*var).self, DecorationBlock))
3576 {
3577 uint32_t i = 4;
3578 auto *type = &get_variable_element_type(*var);
3579 if (index == uint32_t(-1) && length >= 5)
3580 {
3581 // Maybe this is a struct type in the Input storage class, in which case
3582 // we put it as a decoration on the corresponding member.
3583 index = get_extended_member_decoration(ops[2], get_constant(ops[4]).scalar(),
3584 SPIRVCrossDecorationInterfaceMemberIndex);
3585 assert(index != uint32_t(-1));
3586 i++;
3587 type = &get<SPIRType>(type->member_types[get_constant(ops[4]).scalar()]);
3588 }
3589 // In this case, we flattened structures and arrays, so now we have to
3590 // combine the following indices. If we encounter a non-constant index,
3591 // we're hosed.
3592 for (; i < length; ++i)
3593 {
3594 if (!is_array(*type) && !is_matrix(*type) && type->basetype != SPIRType::Struct)
3595 break;
3596
3597 auto &c = get_constant(ops[i]);
3598 index += c.scalar();
3599 if (type->parent_type)
3600 type = &get<SPIRType>(type->parent_type);
3601 else if (type->basetype == SPIRType::Struct)
3602 type = &get<SPIRType>(type->member_types[c.scalar()]);
3603 }
3604 // If the access chain terminates at a composite type, the composite
3605 // itself might be copied. In that case, we must unflatten it.
3606 if (is_matrix(*type) || is_array(*type) || type->basetype == SPIRType::Struct)
3607 {
3608 std::string temp_name = join(to_name(var->self), "_", ops[1]);
3609 statement(variable_decl(*type, temp_name, var->self), ";");
3610 // Set up the initializer for this temporary variable.
3611 indices.push_back(const_mbr_id);
3612 if (type->basetype == SPIRType::Struct)
3613 {
3614 for (uint32_t j = 0; j < type->member_types.size(); j++)
3615 {
3616 index = get_extended_member_decoration(ops[2], j, SPIRVCrossDecorationInterfaceMemberIndex);
3617 const auto &mbr_type = get<SPIRType>(type->member_types[j]);
3618 if (is_matrix(mbr_type))
3619 {
3620 for (uint32_t k = 0; k < mbr_type.columns; k++, index++)
3621 {
3622 set<SPIRConstant>(const_mbr_id, type_id, index, false);
3623 auto e = access_chain(ptr, indices.data(), uint32_t(indices.size()), mbr_type, nullptr,
3624 true);
3625 statement(temp_name, ".", to_member_name(*type, j), "[", k, "] = ", e, ";");
3626 }
3627 }
3628 else if (is_array(mbr_type))
3629 {
3630 for (uint32_t k = 0; k < mbr_type.array[0]; k++, index++)
3631 {
3632 set<SPIRConstant>(const_mbr_id, type_id, index, false);
3633 auto e = access_chain(ptr, indices.data(), uint32_t(indices.size()), mbr_type, nullptr,
3634 true);
3635 statement(temp_name, ".", to_member_name(*type, j), "[", k, "] = ", e, ";");
3636 }
3637 }
3638 else
3639 {
3640 set<SPIRConstant>(const_mbr_id, type_id, index, false);
3641 auto e =
3642 access_chain(ptr, indices.data(), uint32_t(indices.size()), mbr_type, nullptr, true);
3643 statement(temp_name, ".", to_member_name(*type, j), " = ", e, ";");
3644 }
3645 }
3646 }
3647 else if (is_matrix(*type))
3648 {
3649 for (uint32_t j = 0; j < type->columns; j++, index++)
3650 {
3651 set<SPIRConstant>(const_mbr_id, type_id, index, false);
3652 auto e = access_chain(ptr, indices.data(), uint32_t(indices.size()), *type, nullptr, true);
3653 statement(temp_name, "[", j, "] = ", e, ";");
3654 }
3655 }
3656 else // Must be an array
3657 {
3658 assert(is_array(*type));
3659 for (uint32_t j = 0; j < type->array[0]; j++, index++)
3660 {
3661 set<SPIRConstant>(const_mbr_id, type_id, index, false);
3662 auto e = access_chain(ptr, indices.data(), uint32_t(indices.size()), *type, nullptr, true);
3663 statement(temp_name, "[", j, "] = ", e, ";");
3664 }
3665 }
3666
3667 // This needs to be a variable instead of an expression so we don't
3668 // try to dereference this as a variable pointer.
3669 set<SPIRVariable>(ops[1], ops[0], var->storage);
3670 ir.meta[ops[1]] = ir.meta[ops[2]];
3671 set_name(ops[1], temp_name);
3672 if (has_decoration(var->self, DecorationInvariant))
3673 set_decoration(ops[1], DecorationInvariant);
3674 for (uint32_t j = 2; j < length; j++)
3675 inherit_expression_dependencies(ops[1], ops[j]);
3676 return true;
3677 }
3678 else
3679 {
3680 set<SPIRConstant>(const_mbr_id, type_id, index, false);
3681 indices.push_back(const_mbr_id);
3682
3683 if (i < length)
3684 indices.insert(indices.end(), ops + i, ops + length);
3685 }
3686 }
3687 else
3688 {
3689 assert(index != uint32_t(-1));
3690 set<SPIRConstant>(const_mbr_id, type_id, index, false);
3691 indices.push_back(const_mbr_id);
3692
3693 indices.insert(indices.end(), ops + 4, ops + length);
3694 }
3695
3696 // We use the pointer to the base of the input/output array here,
3697 // so this is always a pointer chain.
3698 auto e = access_chain(ptr, indices.data(), uint32_t(indices.size()), get<SPIRType>(ops[0]), &meta, true);
3699 auto &expr = set<SPIRExpression>(ops[1], move(e), ops[0], should_forward(ops[2]));
3700 expr.loaded_from = var->self;
3701 expr.need_transpose = meta.need_transpose;
3702 expr.access_chain = true;
3703
3704 // Mark the result as being packed if necessary.
3705 if (meta.storage_is_packed)
3706 set_extended_decoration(ops[1], SPIRVCrossDecorationPacked);
3707 if (meta.storage_packed_type != 0)
3708 set_extended_decoration(ops[1], SPIRVCrossDecorationPackedType, meta.storage_packed_type);
3709 if (meta.storage_is_invariant)
3710 set_decoration(ops[1], DecorationInvariant);
3711
3712 for (uint32_t i = 2; i < length; i++)
3713 {
3714 inherit_expression_dependencies(ops[1], ops[i]);
3715 add_implied_read_expression(expr, ops[i]);
3716 }
3717
3718 return true;
3719 }
3720
3721 // If this is the inner tessellation level, and we're tessellating triangles,
3722 // drop the last index. It isn't an array in this case, so we can't have an
3723 // array reference here. We need to make this ID a variable instead of an
3724 // expression so we don't try to dereference it as a variable pointer.
3725 // Don't do this if the index is a constant 1, though. We need to drop stores
3726 // to that one.
3727 auto *m = ir.find_meta(var ? var->self : 0);
3728 if (get_execution_model() == ExecutionModelTessellationControl && var && m &&
3729 m->decoration.builtin_type == BuiltInTessLevelInner && get_entry_point().flags.get(ExecutionModeTriangles))
3730 {
3731 auto *c = maybe_get<SPIRConstant>(ops[3]);
3732 if (c && c->scalar() == 1)
3733 return false;
3734 auto &dest_var = set<SPIRVariable>(ops[1], *var);
3735 dest_var.basetype = ops[0];
3736 ir.meta[ops[1]] = ir.meta[ops[2]];
3737 inherit_expression_dependencies(ops[1], ops[2]);
3738 return true;
3739 }
3740
3741 return false;
3742 }
3743
3744 bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
3745 {
3746 if (!get_entry_point().flags.get(ExecutionModeTriangles))
3747 return false;
3748
3749 // In SPIR-V, TessLevelInner always has two elements and TessLevelOuter always has
3750 // four. This is true even if we are tessellating triangles. This allows clients
3751 // to use a single tessellation control shader with multiple tessellation evaluation
3752 // shaders.
3753 // In Metal, however, only the first element of TessLevelInner and the first three
3754 // of TessLevelOuter are accessible. This stems from how in Metal, the tessellation
3755 // levels must be stored to a dedicated buffer in a particular format that depends
3756 // on the patch type. Therefore, in Triangles mode, any access to the second
3757 // inner level or the fourth outer level must be dropped.
3758 const auto *e = maybe_get<SPIRExpression>(id_lhs);
3759 if (!e || !e->access_chain)
3760 return false;
3761 BuiltIn builtin = BuiltIn(get_decoration(e->loaded_from, DecorationBuiltIn));
3762 if (builtin != BuiltInTessLevelInner && builtin != BuiltInTessLevelOuter)
3763 return false;
3764 auto *c = maybe_get<SPIRConstant>(e->implied_read_expressions[1]);
3765 if (!c)
3766 return false;
3767 return (builtin == BuiltInTessLevelInner && c->scalar() == 1) ||
3768 (builtin == BuiltInTessLevelOuter && c->scalar() == 3);
3769 }
3770
3771 // Override for MSL-specific syntax instructions
3772 void CompilerMSL::emit_instruction(const Instruction &instruction)
3773 {
3774 #define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
3775 #define MSL_BOP_CAST(op, type) \
3776 emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
3777 #define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
3778 #define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
3779 #define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
3780 #define MSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
3781 #define MSL_BFOP_CAST(op, type) \
3782 emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
3783 #define MSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
3784 #define MSL_UNORD_BOP(op) emit_binary_unord_op(ops[0], ops[1], ops[2], ops[3], #op)
3785
3786 auto ops = stream(instruction);
3787 auto opcode = static_cast<Op>(instruction.op);
3788
3789 // If we need to do implicit bitcasts, make sure we do it with the correct type.
3790 uint32_t integer_width = get_integer_width_for_instruction(instruction);
3791 auto int_type = to_signed_basetype(integer_width);
3792 auto uint_type = to_unsigned_basetype(integer_width);
3793
3794 switch (opcode)
3795 {
3796
3797 // Comparisons
3798 case OpIEqual:
3799 MSL_BOP_CAST(==, int_type);
3800 break;
3801
3802 case OpLogicalEqual:
3803 case OpFOrdEqual:
3804 MSL_BOP(==);
3805 break;
3806
3807 case OpINotEqual:
3808 MSL_BOP_CAST(!=, int_type);
3809 break;
3810
3811 case OpLogicalNotEqual:
3812 case OpFOrdNotEqual:
3813 MSL_BOP(!=);
3814 break;
3815
3816 case OpUGreaterThan:
3817 MSL_BOP_CAST(>, uint_type);
3818 break;
3819
3820 case OpSGreaterThan:
3821 MSL_BOP_CAST(>, int_type);
3822 break;
3823
3824 case OpFOrdGreaterThan:
3825 MSL_BOP(>);
3826 break;
3827
3828 case OpUGreaterThanEqual:
3829 MSL_BOP_CAST(>=, uint_type);
3830 break;
3831
3832 case OpSGreaterThanEqual:
3833 MSL_BOP_CAST(>=, int_type);
3834 break;
3835
3836 case OpFOrdGreaterThanEqual:
3837 MSL_BOP(>=);
3838 break;
3839
3840 case OpULessThan:
3841 MSL_BOP_CAST(<, uint_type);
3842 break;
3843
3844 case OpSLessThan:
3845 MSL_BOP_CAST(<, int_type);
3846 break;
3847
3848 case OpFOrdLessThan:
3849 MSL_BOP(<);
3850 break;
3851
3852 case OpULessThanEqual:
3853 MSL_BOP_CAST(<=, uint_type);
3854 break;
3855
3856 case OpSLessThanEqual:
3857 MSL_BOP_CAST(<=, int_type);
3858 break;
3859
3860 case OpFOrdLessThanEqual:
3861 MSL_BOP(<=);
3862 break;
3863
3864 case OpFUnordEqual:
3865 MSL_UNORD_BOP(==);
3866 break;
3867
3868 case OpFUnordNotEqual:
3869 MSL_UNORD_BOP(!=);
3870 break;
3871
3872 case OpFUnordGreaterThan:
3873 MSL_UNORD_BOP(>);
3874 break;
3875
3876 case OpFUnordGreaterThanEqual:
3877 MSL_UNORD_BOP(>=);
3878 break;
3879
3880 case OpFUnordLessThan:
3881 MSL_UNORD_BOP(<);
3882 break;
3883
3884 case OpFUnordLessThanEqual:
3885 MSL_UNORD_BOP(<=);
3886 break;
3887
3888 // Derivatives
3889 case OpDPdx:
3890 case OpDPdxFine:
3891 case OpDPdxCoarse:
3892 MSL_UFOP(dfdx);
3893 register_control_dependent_expression(ops[1]);
3894 break;
3895
3896 case OpDPdy:
3897 case OpDPdyFine:
3898 case OpDPdyCoarse:
3899 MSL_UFOP(dfdy);
3900 register_control_dependent_expression(ops[1]);
3901 break;
3902
3903 case OpFwidth:
3904 case OpFwidthCoarse:
3905 case OpFwidthFine:
3906 MSL_UFOP(fwidth);
3907 register_control_dependent_expression(ops[1]);
3908 break;
3909
3910 // Bitfield
3911 case OpBitFieldInsert:
3912 MSL_QFOP(insert_bits);
3913 break;
3914
3915 case OpBitFieldSExtract:
3916 case OpBitFieldUExtract:
3917 MSL_TFOP(extract_bits);
3918 break;
3919
3920 case OpBitReverse:
3921 MSL_UFOP(reverse_bits);
3922 break;
3923
3924 case OpBitCount:
3925 MSL_UFOP(popcount);
3926 break;
3927
3928 case OpFRem:
3929 MSL_BFOP(fmod);
3930 break;
3931
3932 // Atomics
3933 case OpAtomicExchange:
3934 {
3935 uint32_t result_type = ops[0];
3936 uint32_t id = ops[1];
3937 uint32_t ptr = ops[2];
3938 uint32_t mem_sem = ops[4];
3939 uint32_t val = ops[5];
3940 emit_atomic_func_op(result_type, id, "atomic_exchange_explicit", mem_sem, mem_sem, false, ptr, val);
3941 break;
3942 }
3943
3944 case OpAtomicCompareExchange:
3945 {
3946 uint32_t result_type = ops[0];
3947 uint32_t id = ops[1];
3948 uint32_t ptr = ops[2];
3949 uint32_t mem_sem_pass = ops[4];
3950 uint32_t mem_sem_fail = ops[5];
3951 uint32_t val = ops[6];
3952 uint32_t comp = ops[7];
3953 emit_atomic_func_op(result_type, id, "atomic_compare_exchange_weak_explicit", mem_sem_pass, mem_sem_fail, true,
3954 ptr, comp, true, false, val);
3955 break;
3956 }
3957
3958 case OpAtomicCompareExchangeWeak:
3959 SPIRV_CROSS_THROW("OpAtomicCompareExchangeWeak is only supported in kernel profile.");
3960
3961 case OpAtomicLoad:
3962 {
3963 uint32_t result_type = ops[0];
3964 uint32_t id = ops[1];
3965 uint32_t ptr = ops[2];
3966 uint32_t mem_sem = ops[4];
3967 emit_atomic_func_op(result_type, id, "atomic_load_explicit", mem_sem, mem_sem, false, ptr, 0);
3968 break;
3969 }
3970
3971 case OpAtomicStore:
3972 {
3973 uint32_t result_type = expression_type(ops[0]).self;
3974 uint32_t id = ops[0];
3975 uint32_t ptr = ops[0];
3976 uint32_t mem_sem = ops[2];
3977 uint32_t val = ops[3];
3978 emit_atomic_func_op(result_type, id, "atomic_store_explicit", mem_sem, mem_sem, false, ptr, val);
3979 break;
3980 }
3981
3982 #define MSL_AFMO_IMPL(op, valsrc, valconst) \
3983 do \
3984 { \
3985 uint32_t result_type = ops[0]; \
3986 uint32_t id = ops[1]; \
3987 uint32_t ptr = ops[2]; \
3988 uint32_t mem_sem = ops[4]; \
3989 uint32_t val = valsrc; \
3990 emit_atomic_func_op(result_type, id, "atomic_fetch_" #op "_explicit", mem_sem, mem_sem, false, ptr, val, \
3991 false, valconst); \
3992 } while (false)
3993
3994 #define MSL_AFMO(op) MSL_AFMO_IMPL(op, ops[5], false)
3995 #define MSL_AFMIO(op) MSL_AFMO_IMPL(op, 1, true)
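// For reference, the operand layout of the fetch-style atomics handled below is
// { result type, result id, pointer, scope, semantics, value }, so MSL_AFMO(add) for
// OpAtomicIAdd ultimately emits a call to atomic_fetch_add_explicit(). The
// increment/decrement forms (MSL_AFMIO) pass a constant 1 instead of ops[5].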
3996
3997 case OpAtomicIIncrement:
3998 MSL_AFMIO(add);
3999 break;
4000
4001 case OpAtomicIDecrement:
4002 MSL_AFMIO(sub);
4003 break;
4004
4005 case OpAtomicIAdd:
4006 MSL_AFMO(add);
4007 break;
4008
4009 case OpAtomicISub:
4010 MSL_AFMO(sub);
4011 break;
4012
4013 case OpAtomicSMin:
4014 case OpAtomicUMin:
4015 MSL_AFMO(min);
4016 break;
4017
4018 case OpAtomicSMax:
4019 case OpAtomicUMax:
4020 MSL_AFMO(max);
4021 break;
4022
4023 case OpAtomicAnd:
4024 MSL_AFMO(and);
4025 break;
4026
4027 case OpAtomicOr:
4028 MSL_AFMO(or);
4029 break;
4030
4031 case OpAtomicXor:
4032 MSL_AFMO(xor);
4033 break;
4034
4035 // Images
4036
4037 // Reads == Fetches in Metal
4038 case OpImageRead:
4039 {
4040 // Mark that this shader reads from this image
4041 uint32_t img_id = ops[2];
4042 auto &type = expression_type(img_id);
4043 if (type.image.dim != DimSubpassData)
4044 {
4045 auto *p_var = maybe_get_backing_variable(img_id);
4046 if (p_var && has_decoration(p_var->self, DecorationNonReadable))
4047 {
4048 unset_decoration(p_var->self, DecorationNonReadable);
4049 force_recompile();
4050 }
4051 }
4052
4053 emit_texture_op(instruction);
4054 break;
4055 }
4056
4057 case OpImageWrite:
4058 {
4059 uint32_t img_id = ops[0];
4060 uint32_t coord_id = ops[1];
4061 uint32_t texel_id = ops[2];
4062 const uint32_t *opt = &ops[3];
4063 uint32_t length = instruction.length - 3;
4064
4065 // Bypass pointers because we need the real image struct
4066 auto &type = expression_type(img_id);
4067 auto &img_type = get<SPIRType>(type.self);
4068
4069 // Ensure this image has been marked as being written to and force a
4070 // recompile so that the image type output will include write access
4071 auto *p_var = maybe_get_backing_variable(img_id);
4072 if (p_var && has_decoration(p_var->self, DecorationNonWritable))
4073 {
4074 unset_decoration(p_var->self, DecorationNonWritable);
4075 force_recompile();
4076 }
4077
4078 bool forward = false;
4079 uint32_t bias = 0;
4080 uint32_t lod = 0;
4081 uint32_t flags = 0;
4082
4083 if (length)
4084 {
4085 flags = *opt++;
4086 length--;
4087 }
4088
4089 auto test = [&](uint32_t &v, uint32_t flag) {
4090 if (length && (flags & flag))
4091 {
4092 v = *opt++;
4093 length--;
4094 }
4095 };
4096
4097 test(bias, ImageOperandsBiasMask);
4098 test(lod, ImageOperandsLodMask);
4099
4100 auto &texel_type = expression_type(texel_id);
4101 auto store_type = texel_type;
4102 store_type.vecsize = 4;
4103
4104 statement(join(to_expression(img_id), ".write(",
4105 remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)), ", ",
4106 to_function_args(img_id, img_type, true, false, false, coord_id, 0, 0, 0, 0, lod, 0, 0, 0, 0, 0,
4107 0, &forward),
4108 ");"));
4109
4110 if (p_var && variable_storage_is_aliased(*p_var))
4111 flush_all_aliased_variables();
4112
4113 break;
4114 }
4115
4116 case OpImageQuerySize:
4117 case OpImageQuerySizeLod:
4118 {
4119 uint32_t rslt_type_id = ops[0];
4120 auto &rslt_type = get<SPIRType>(rslt_type_id);
4121
4122 uint32_t id = ops[1];
4123
4124 uint32_t img_id = ops[2];
4125 string img_exp = to_expression(img_id);
4126 auto &img_type = expression_type(img_id);
4127 Dim img_dim = img_type.image.dim;
4128 bool img_is_array = img_type.image.arrayed;
4129
4130 if (img_type.basetype != SPIRType::Image)
4131 SPIRV_CROSS_THROW("Invalid type for OpImageQuerySize.");
4132
4133 string lod;
4134 if (opcode == OpImageQuerySizeLod)
4135 {
4136 // LOD index defaults to zero, so don't bother outputting level zero index
4137 string decl_lod = to_expression(ops[3]);
4138 if (decl_lod != "0")
4139 lod = decl_lod;
4140 }
4141
4142 string expr = type_to_glsl(rslt_type) + "(";
4143 expr += img_exp + ".get_width(" + lod + ")";
4144
4145 if (img_dim == Dim2D || img_dim == DimCube || img_dim == Dim3D)
4146 expr += ", " + img_exp + ".get_height(" + lod + ")";
4147
4148 if (img_dim == Dim3D)
4149 expr += ", " + img_exp + ".get_depth(" + lod + ")";
4150
4151 if (img_is_array)
4152 expr += ", " + img_exp + ".get_array_size()";
4153
4154 expr += ")";
4155
4156 emit_op(rslt_type_id, id, expr, should_forward(img_id));
4157
4158 break;
4159 }
4160
4161 case OpImageQueryLod:
4162 {
4163 if (!msl_options.supports_msl_version(2, 2))
4164 SPIRV_CROSS_THROW("ImageQueryLod is only supported on MSL 2.2 and up.");
4165 uint32_t result_type = ops[0];
4166 uint32_t id = ops[1];
4167 uint32_t image_id = ops[2];
4168 uint32_t coord_id = ops[3];
4169 emit_uninitialized_temporary_expression(result_type, id);
4170
4171 auto sampler_expr = to_sampler_expression(image_id);
4172 auto *combined = maybe_get<SPIRCombinedImageSampler>(image_id);
4173 auto image_expr = combined ? to_expression(combined->image) : to_expression(image_id);
4174
4175 // TODO: It is unclear if calculate_clamped_lod also conditionally rounds
4176 // the reported LOD based on the sampler. NEAREST miplevel should
4177 // round the LOD, but LINEAR miplevel should not round.
4178 // Let's hope this does not become an issue ...
4179 statement(to_expression(id), ".x = ", image_expr, ".calculate_clamped_lod(", sampler_expr, ", ",
4180 to_expression(coord_id), ");");
4181 statement(to_expression(id), ".y = ", image_expr, ".calculate_unclamped_lod(", sampler_expr, ", ",
4182 to_expression(coord_id), ");");
4183 register_control_dependent_expression(id);
4184 break;
4185 }
4186
4187 #define MSL_ImgQry(qrytype) \
4188 do \
4189 { \
4190 uint32_t rslt_type_id = ops[0]; \
4191 auto &rslt_type = get<SPIRType>(rslt_type_id); \
4192 uint32_t id = ops[1]; \
4193 uint32_t img_id = ops[2]; \
4194 string img_exp = to_expression(img_id); \
4195 string expr = type_to_glsl(rslt_type) + "(" + img_exp + ".get_num_" #qrytype "())"; \
4196 emit_op(rslt_type_id, id, expr, should_forward(img_id)); \
4197 } while (false)
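// As an illustration, MSL_ImgQry(mip_levels) on an image expression "tex" with a uint
// result type emits roughly: uint(tex.get_num_mip_levels()).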
4198
4199 case OpImageQueryLevels:
4200 MSL_ImgQry(mip_levels);
4201 break;
4202
4203 case OpImageQuerySamples:
4204 MSL_ImgQry(samples);
4205 break;
4206
4207 case OpImage:
4208 {
4209 uint32_t result_type = ops[0];
4210 uint32_t id = ops[1];
4211 auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);
4212
4213 if (combined)
4214 {
4215 auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
4216 auto *var = maybe_get_backing_variable(combined->image);
4217 if (var)
4218 e.loaded_from = var->self;
4219 }
4220 else
4221 {
4222 auto &e = emit_op(result_type, id, to_expression(ops[2]), true, true);
4223 auto *var = maybe_get_backing_variable(ops[2]);
4224 if (var)
4225 e.loaded_from = var->self;
4226 }
4227 break;
4228 }
4229
4230 case OpImageTexelPointer:
4231 SPIRV_CROSS_THROW("MSL does not support atomic operations on images or texel buffers.");
4232
4233 // Casting
4234 case OpQuantizeToF16:
4235 {
4236 uint32_t result_type = ops[0];
4237 uint32_t id = ops[1];
4238 uint32_t arg = ops[2];
4239
4240 string exp;
4241 auto &type = get<SPIRType>(result_type);
4242
4243 switch (type.vecsize)
4244 {
4245 case 1:
4246 exp = join("float(half(", to_expression(arg), "))");
4247 break;
4248 case 2:
4249 exp = join("float2(half2(", to_expression(arg), "))");
4250 break;
4251 case 3:
4252 exp = join("float3(half3(", to_expression(arg), "))");
4253 break;
4254 case 4:
4255 exp = join("float4(half4(", to_expression(arg), "))");
4256 break;
4257 default:
4258 SPIRV_CROSS_THROW("Illegal argument to OpQuantizeToF16.");
4259 }
4260
4261 emit_op(result_type, id, exp, should_forward(arg));
4262 break;
4263 }
4264
4265 case OpInBoundsAccessChain:
4266 case OpAccessChain:
4267 case OpPtrAccessChain:
4268 if (is_tessellation_shader())
4269 {
4270 if (!emit_tessellation_access_chain(ops, instruction.length))
4271 CompilerGLSL::emit_instruction(instruction);
4272 }
4273 else
4274 CompilerGLSL::emit_instruction(instruction);
4275 break;
4276
4277 case OpStore:
4278 if (is_out_of_bounds_tessellation_level(ops[0]))
4279 break;
4280
4281 if (maybe_emit_array_assignment(ops[0], ops[1]))
4282 break;
4283
4284 CompilerGLSL::emit_instruction(instruction);
4285 break;
4286
4287 // Compute barriers
4288 case OpMemoryBarrier:
4289 emit_barrier(0, ops[0], ops[1]);
4290 break;
4291
4292 case OpControlBarrier:
4293 // In GLSL a memory barrier is often followed by a control barrier.
4294 // But in MSL, memory barriers are also control barriers, so don't
4295 // emit a simple control barrier if a memory barrier has just been emitted.
4296 if (previous_instruction_opcode != OpMemoryBarrier)
4297 emit_barrier(ops[0], ops[1], ops[2]);
4298 break;
4299
4300 case OpVectorTimesMatrix:
4301 case OpMatrixTimesVector:
4302 {
4303 // If the matrix needs transpose and it is square or packed, just flip the multiply order.
4304 uint32_t mtx_id = ops[opcode == OpMatrixTimesVector ? 2 : 3];
4305 auto *e = maybe_get<SPIRExpression>(mtx_id);
4306 auto &t = expression_type(mtx_id);
4307 bool is_packed = has_extended_decoration(mtx_id, SPIRVCrossDecorationPacked);
4308 if (e && e->need_transpose && (t.columns == t.vecsize || is_packed))
4309 {
4310 e->need_transpose = false;
4311 emit_binary_op(ops[0], ops[1], ops[3], ops[2], "*");
4312 e->need_transpose = true;
4313 }
4314 else
4315 MSL_BOP(*);
4316 break;
4317 }
4318
4319 case OpOuterProduct:
4320 {
4321 uint32_t result_type = ops[0];
4322 uint32_t id = ops[1];
4323 uint32_t a = ops[2];
4324 uint32_t b = ops[3];
4325
4326 auto &type = get<SPIRType>(result_type);
4327 string expr = type_to_glsl_constructor(type);
4328 expr += "(";
4329 for (uint32_t col = 0; col < type.columns; col++)
4330 {
4331 expr += to_enclosed_expression(a);
4332 expr += " * ";
4333 expr += to_extract_component_expression(b, col);
4334 if (col + 1 < type.columns)
4335 expr += ", ";
4336 }
4337 expr += ")";
4338 emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
4339 inherit_expression_dependencies(id, a);
4340 inherit_expression_dependencies(id, b);
4341 break;
4342 }
4343
4344 case OpIAddCarry:
4345 case OpISubBorrow:
4346 {
4347 uint32_t result_type = ops[0];
4348 uint32_t result_id = ops[1];
4349 uint32_t op0 = ops[2];
4350 uint32_t op1 = ops[3];
4351 forced_temporaries.insert(result_id);
4352 auto &type = get<SPIRType>(result_type);
4353 statement(variable_decl(type, to_name(result_id)), ";");
4354 set<SPIRExpression>(result_id, to_name(result_id), result_type, true);
4355
4356 auto &res_type = get<SPIRType>(type.member_types[1]);
4357 if (opcode == OpIAddCarry)
4358 {
4359 statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " + ",
4360 to_enclosed_expression(op1), ";");
4361 statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
4362 "(1), ", type_to_glsl(res_type), "(0), ", to_expression(result_id), ".", to_member_name(type, 0),
4363 " >= max(", to_expression(op0), ", ", to_expression(op1), "));");
4364 }
4365 else
4366 {
4367 statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " - ",
4368 to_enclosed_expression(op1), ";");
4369 statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
4370 "(1), ", type_to_glsl(res_type), "(0), ", to_enclosed_expression(op0),
4371 " >= ", to_enclosed_expression(op1), ");");
4372 }
4373 break;
4374 }
4375
4376 case OpUMulExtended:
4377 case OpSMulExtended:
4378 {
4379 uint32_t result_type = ops[0];
4380 uint32_t result_id = ops[1];
4381 uint32_t op0 = ops[2];
4382 uint32_t op1 = ops[3];
4383 forced_temporaries.insert(result_id);
4384 auto &type = get<SPIRType>(result_type);
4385 statement(variable_decl(type, to_name(result_id)), ";");
4386 set<SPIRExpression>(result_id, to_name(result_id), result_type, true);
4387
4388 statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " * ",
4389 to_enclosed_expression(op1), ";");
4390 statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(", to_expression(op0), ", ",
4391 to_expression(op1), ");");
4392 break;
4393 }
4394
4395 case OpArrayLength:
4396 {
4397 auto &type = expression_type(ops[2]);
4398 uint32_t offset = type_struct_member_offset(type, ops[3]);
4399 uint32_t stride = type_struct_member_array_stride(type, ops[3]);
4400
4401 auto expr = join("(", to_buffer_size_expression(ops[2]), " - ", offset, ") / ", stride);
4402 emit_op(ops[0], ops[1], expr, true);
4403 break;
4404 }
4405
4406 default:
4407 CompilerGLSL::emit_instruction(instruction);
4408 break;
4409 }
4410
4411 previous_instruction_opcode = opcode;
4412 }
4413
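// Emits an MSL barrier for OpControlBarrier/OpMemoryBarrier. The wider of the execution and
// memory scopes picks between threadgroup_barrier() and simdgroup_barrier() (where the MSL
// version supports the latter), and the SPIR-V memory semantics are translated into
// mem_flags (device, threadgroup, texture, or none).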
4414 void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem)
4415 {
4416 if (get_execution_model() != ExecutionModelGLCompute && get_execution_model() != ExecutionModelTessellationControl)
4417 return;
4418
4419 uint32_t exe_scope = id_exe_scope ? get<SPIRConstant>(id_exe_scope).scalar() : uint32_t(ScopeInvocation);
4420 uint32_t mem_scope = id_mem_scope ? get<SPIRConstant>(id_mem_scope).scalar() : uint32_t(ScopeInvocation);
4421 // Use the wider of the two scopes (smaller value)
4422 exe_scope = min(exe_scope, mem_scope);
4423
4424 string bar_stmt;
4425 if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
4426 bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
4427 else
4428 bar_stmt = "threadgroup_barrier";
4429 bar_stmt += "(";
4430
4431 uint32_t mem_sem = id_mem_sem ? get<SPIRConstant>(id_mem_sem).scalar() : uint32_t(MemorySemanticsMaskNone);
4432
4433 // Use the | operator to combine flags if we can.
4434 if (msl_options.supports_msl_version(1, 2))
4435 {
4436 string mem_flags = "";
4437 // For tesc shaders, this also affects objects in the Output storage class.
4438 // Since in Metal, these are placed in a device buffer, we have to sync device memory here.
4439 if (get_execution_model() == ExecutionModelTessellationControl ||
4440 (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)))
4441 mem_flags += "mem_flags::mem_device";
4442 if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
4443 MemorySemanticsAtomicCounterMemoryMask))
4444 {
4445 if (!mem_flags.empty())
4446 mem_flags += " | ";
4447 mem_flags += "mem_flags::mem_threadgroup";
4448 }
4449 if (mem_sem & MemorySemanticsImageMemoryMask)
4450 {
4451 if (!mem_flags.empty())
4452 mem_flags += " | ";
4453 mem_flags += "mem_flags::mem_texture";
4454 }
4455
4456 if (mem_flags.empty())
4457 mem_flags = "mem_flags::mem_none";
4458
4459 bar_stmt += mem_flags;
4460 }
4461 else
4462 {
4463 if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) &&
4464 (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
4465 MemorySemanticsAtomicCounterMemoryMask)))
4466 bar_stmt += "mem_flags::mem_device_and_threadgroup";
4467 else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))
4468 bar_stmt += "mem_flags::mem_device";
4469 else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
4470 MemorySemanticsAtomicCounterMemoryMask))
4471 bar_stmt += "mem_flags::mem_threadgroup";
4472 else if (mem_sem & MemorySemanticsImageMemoryMask)
4473 bar_stmt += "mem_flags::mem_texture";
4474 else
4475 bar_stmt += "mem_flags::mem_none";
4476 }
4477
4478 if (msl_options.is_ios() && (msl_options.supports_msl_version(2) && !msl_options.supports_msl_version(2, 1)))
4479 {
4480 bar_stmt += ", ";
4481
4482 switch (mem_scope)
4483 {
4484 case ScopeCrossDevice:
4485 case ScopeDevice:
4486 bar_stmt += "memory_scope_device";
4487 break;
4488
4489 case ScopeSubgroup:
4490 case ScopeInvocation:
4491 bar_stmt += "memory_scope_simdgroup";
4492 break;
4493
4494 case ScopeWorkgroup:
4495 default:
4496 bar_stmt += "memory_scope_threadgroup";
4497 break;
4498 }
4499 }
4500
4501 bar_stmt += ");";
4502
4503 statement(bar_stmt);
4504
4505 assert(current_emitting_block);
4506 flush_control_dependent_expressions(current_emitting_block->self);
4507 flush_all_active_variables();
4508 }
4509
4510 void CompilerMSL::emit_array_copy(const string &lhs, uint32_t rhs_id)
4511 {
4512 // Assignment from an array initializer is fine.
4513 auto &type = expression_type(rhs_id);
4514 auto *var = maybe_get_backing_variable(rhs_id);
4515
4516 // Unfortunately, we cannot template on address space in MSL,
4517 // so explicit address space redirection it is ...
4518 bool is_constant = false;
4519 if (ir.ids[rhs_id].get_type() == TypeConstant)
4520 {
4521 is_constant = true;
4522 }
4523 else if (var && var->remapped_variable && var->statically_assigned &&
4524 ir.ids[var->static_expression].get_type() == TypeConstant)
4525 {
4526 is_constant = true;
4527 }
4528
4529 // For the case where we have OpLoad triggering an array copy,
4530 // we cannot easily detect this case ahead of time since it's
4531 // context dependent. We might have to force a recompile here
4532 // if this is the only use of array copies in our shader.
4533 if (type.array.size() > 1)
4534 {
4535 if (type.array.size() > SPVFuncImplArrayCopyMultidimMax)
4536 SPIRV_CROSS_THROW("Cannot support this many dimensions for arrays of arrays.");
4537 auto func = static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + type.array.size());
4538 if (spv_function_implementations.count(func) == 0)
4539 {
4540 spv_function_implementations.insert(func);
4541 suppress_missing_prototypes = true;
4542 force_recompile();
4543 }
4544 }
4545 else if (spv_function_implementations.count(SPVFuncImplArrayCopy) == 0)
4546 {
4547 spv_function_implementations.insert(SPVFuncImplArrayCopy);
4548 suppress_missing_prototypes = true;
4549 force_recompile();
4550 }
4551
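	// Illustrative emitted call for a one-dimensional stack-to-stack copy (argument names are
	// placeholders):
	//   spvArrayCopyFromStack1(dst, src);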
4552 const char *tag = is_constant ? "FromConstant" : "FromStack";
4553 statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ");");
4554 }
4555
4556 // Since MSL does not allow arrays to be copied via simple variable assignment,
4557 // if the LHS and RHS represent an assignment of an entire array, it must be
4558 // implemented by calling an array copy function.
4559 // Returns whether the struct assignment was emitted.
4560 bool CompilerMSL::maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs)
4561 {
4562 // We only care about assignments of an entire array
4563 auto &type = expression_type(id_rhs);
4564 if (type.array.size() == 0)
4565 return false;
4566
4567 auto *var = maybe_get<SPIRVariable>(id_lhs);
4568
4569 // Is this a remapped, static constant? Don't do anything.
4570 if (var && var->remapped_variable && var->statically_assigned)
4571 return true;
4572
4573 if (ir.ids[id_rhs].get_type() == TypeConstant && var && var->deferred_declaration)
4574 {
4575 // Special case, if we end up declaring a variable when assigning the constant array,
4576 // we can avoid the copy by directly assigning the constant expression.
4577 // This is likely necessary to be able to use a variable as a true look-up table, as it is unlikely
4578 // the compiler will be able to optimize the spvArrayCopy() into a constant LUT.
4579 // After a variable has been declared, we can no longer assign constant arrays in MSL unfortunately.
4580 statement(to_expression(id_lhs), " = ", constant_expression(get<SPIRConstant>(id_rhs)), ";");
4581 return true;
4582 }
4583
4584 // Ensure the LHS variable has been declared
4585 auto *p_v_lhs = maybe_get_backing_variable(id_lhs);
4586 if (p_v_lhs)
4587 flush_variable_declaration(p_v_lhs->self);
4588
4589 emit_array_copy(to_expression(id_lhs), id_rhs);
4590 register_write(id_lhs);
4591
4592 return true;
4593 }
4594
4595 // Emits one of the atomic functions. In MSL, the atomic functions operate on pointers, so the
// operand is cast to a volatile atomic pointer in the matching address space.
4596 void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, uint32_t mem_order_1,
4597 uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
4598 bool op1_is_pointer, bool op1_is_literal, uint32_t op2)
4599 {
4600 forced_temporaries.insert(result_id);
4601
4602 string exp = string(op) + "(";
4603
4604 auto &type = get_pointee_type(expression_type(obj));
4605 exp += "(volatile ";
4606 auto *var = maybe_get_backing_variable(obj);
4607 if (!var)
4608 SPIRV_CROSS_THROW("No backing variable for atomic operation.");
4609 exp += get_argument_address_space(*var);
4610 exp += " atomic_";
4611 exp += type_to_glsl(type);
4612 exp += "*)";
4613
4614 exp += "&";
4615 exp += to_enclosed_expression(obj);
4616
4617 bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;
4618
4619 if (is_atomic_compare_exchange_strong)
4620 {
4621 assert(strcmp(op, "atomic_compare_exchange_weak_explicit") == 0);
4622 assert(op2);
4623 assert(has_mem_order_2);
4624 exp += ", &";
4625 exp += to_name(result_id);
4626 exp += ", ";
4627 exp += to_expression(op2);
4628 exp += ", ";
4629 exp += get_memory_order(mem_order_1);
4630 exp += ", ";
4631 exp += get_memory_order(mem_order_2);
4632 exp += ")";
4633
4634 // MSL only supports the weak atomic compare exchange, so emit a CAS loop here.
4635 // The MSL function returns false if the atomic write fails OR the comparison test fails,
4636 // so we must validate that it wasn't the comparison test that failed before continuing
4637 // the CAS loop, otherwise it will loop infinitely, with the comparison test always failing.
4638 	// The function updates the comparator value from the memory value, so the additional
4639 // comparison test evaluates the memory value against the expected value.
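		// Illustrative emitted loop, assuming a device uint object 'obj' with expected value
		// 'expected' and desired value 'desired' (all names are placeholders):
		//   uint _res;
		//   do
		//   {
		//       _res = expected;
		//   } while (!atomic_compare_exchange_weak_explicit((volatile device atomic_uint*)&obj, &_res, desired,
		//            memory_order_relaxed, memory_order_relaxed) && _res == expected);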
4640 statement(variable_decl(type, to_name(result_id)), ";");
4641 statement("do");
4642 begin_scope();
4643 statement(to_name(result_id), " = ", to_expression(op1), ";");
4644 end_scope_decl(join("while (!", exp, " && ", to_name(result_id), " == ", to_enclosed_expression(op1), ")"));
4645 set<SPIRExpression>(result_id, to_name(result_id), result_type, true);
4646 }
4647 else
4648 {
4649 assert(strcmp(op, "atomic_compare_exchange_weak_explicit") != 0);
4650 if (op1)
4651 {
4652 if (op1_is_literal)
4653 exp += join(", ", op1);
4654 else
4655 exp += ", " + to_expression(op1);
4656 }
4657 if (op2)
4658 exp += ", " + to_expression(op2);
4659
4660 exp += string(", ") + get_memory_order(mem_order_1);
4661 if (has_mem_order_2)
4662 exp += string(", ") + get_memory_order(mem_order_2);
4663
4664 exp += ")";
4665 emit_op(result_type, result_id, exp, false);
4666 }
4667
4668 flush_all_atomic_capable_variables();
4669 }
4670
4671 // Metal only supports relaxed memory order for now
4672 const char *CompilerMSL::get_memory_order(uint32_t)
4673 {
4674 return "memory_order_relaxed";
4675 }
4676
4677 // Override for MSL-specific extension syntax instructions
4678 void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
4679 {
4680 auto op = static_cast<GLSLstd450>(eop);
4681
4682 // If we need to do implicit bitcasts, make sure we do it with the correct type.
4683 uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
4684 auto int_type = to_signed_basetype(integer_width);
4685 auto uint_type = to_unsigned_basetype(integer_width);
4686
4687 switch (op)
4688 {
4689 case GLSLstd450Atan2:
4690 emit_binary_func_op(result_type, id, args[0], args[1], "atan2");
4691 break;
4692 case GLSLstd450InverseSqrt:
4693 emit_unary_func_op(result_type, id, args[0], "rsqrt");
4694 break;
4695 case GLSLstd450RoundEven:
4696 emit_unary_func_op(result_type, id, args[0], "rint");
4697 break;
4698
4699 case GLSLstd450FindSMsb:
4700 emit_unary_func_op_cast(result_type, id, args[0], "findSMSB", int_type, int_type);
4701 break;
4702
4703 case GLSLstd450FindUMsb:
4704 emit_unary_func_op_cast(result_type, id, args[0], "findUMSB", uint_type, uint_type);
4705 break;
4706
4707 case GLSLstd450PackSnorm4x8:
4708 emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm4x8");
4709 break;
4710 case GLSLstd450PackUnorm4x8:
4711 emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm4x8");
4712 break;
4713 case GLSLstd450PackSnorm2x16:
4714 emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm2x16");
4715 break;
4716 case GLSLstd450PackUnorm2x16:
4717 emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm2x16");
4718 break;
4719
4720 case GLSLstd450PackHalf2x16:
4721 {
4722 auto expr = join("as_type<uint>(half2(", to_expression(args[0]), "))");
4723 emit_op(result_type, id, expr, should_forward(args[0]));
4724 inherit_expression_dependencies(id, args[0]);
4725 break;
4726 }
4727
4728 case GLSLstd450UnpackSnorm4x8:
4729 emit_unary_func_op(result_type, id, args[0], "unpack_snorm4x8_to_float");
4730 break;
4731 case GLSLstd450UnpackUnorm4x8:
4732 emit_unary_func_op(result_type, id, args[0], "unpack_unorm4x8_to_float");
4733 break;
4734 case GLSLstd450UnpackSnorm2x16:
4735 emit_unary_func_op(result_type, id, args[0], "unpack_snorm2x16_to_float");
4736 break;
4737 case GLSLstd450UnpackUnorm2x16:
4738 emit_unary_func_op(result_type, id, args[0], "unpack_unorm2x16_to_float");
4739 break;
4740
4741 case GLSLstd450UnpackHalf2x16:
4742 {
4743 auto expr = join("float2(as_type<half2>(", to_expression(args[0]), "))");
4744 emit_op(result_type, id, expr, should_forward(args[0]));
4745 inherit_expression_dependencies(id, args[0]);
4746 break;
4747 }
4748
4749 case GLSLstd450PackDouble2x32:
4750 emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450PackDouble2x32"); // Currently unsupported
4751 break;
4752 case GLSLstd450UnpackDouble2x32:
4753 emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450UnpackDouble2x32"); // Currently unsupported
4754 break;
4755
4756 case GLSLstd450MatrixInverse:
4757 {
4758 auto &mat_type = get<SPIRType>(result_type);
4759 switch (mat_type.columns)
4760 {
4761 case 2:
4762 emit_unary_func_op(result_type, id, args[0], "spvInverse2x2");
4763 break;
4764 case 3:
4765 emit_unary_func_op(result_type, id, args[0], "spvInverse3x3");
4766 break;
4767 case 4:
4768 emit_unary_func_op(result_type, id, args[0], "spvInverse4x4");
4769 break;
4770 default:
4771 break;
4772 }
4773 break;
4774 }
4775
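	// FMin/FMax/FClamp map to the fast:: variants for float, while the NaN-aware
	// NMin/NMax/NClamp below map to the precise:: variants.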
4776 case GLSLstd450FMin:
4777 // If the result type isn't float, don't bother calling the specific
4778 // precise::/fast:: version. Metal doesn't have those for half and
4779 // double types.
4780 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
4781 emit_binary_func_op(result_type, id, args[0], args[1], "min");
4782 else
4783 emit_binary_func_op(result_type, id, args[0], args[1], "fast::min");
4784 break;
4785
4786 case GLSLstd450FMax:
4787 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
4788 emit_binary_func_op(result_type, id, args[0], args[1], "max");
4789 else
4790 emit_binary_func_op(result_type, id, args[0], args[1], "fast::max");
4791 break;
4792
4793 case GLSLstd450FClamp:
4794 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
4795 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
4796 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
4797 else
4798 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "fast::clamp");
4799 break;
4800
4801 case GLSLstd450NMin:
4802 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
4803 emit_binary_func_op(result_type, id, args[0], args[1], "min");
4804 else
4805 emit_binary_func_op(result_type, id, args[0], args[1], "precise::min");
4806 break;
4807
4808 case GLSLstd450NMax:
4809 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
4810 emit_binary_func_op(result_type, id, args[0], args[1], "max");
4811 else
4812 emit_binary_func_op(result_type, id, args[0], args[1], "precise::max");
4813 break;
4814
4815 case GLSLstd450NClamp:
4816 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
4817 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
4818 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
4819 else
4820 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "precise::clamp");
4821 break;
4822
4823 // TODO:
4824 // GLSLstd450InterpolateAtCentroid (centroid_no_perspective qualifier)
4825 // GLSLstd450InterpolateAtSample (sample_no_perspective qualifier)
4826 // GLSLstd450InterpolateAtOffset
4827
4828 case GLSLstd450Distance:
4829 // MSL does not support scalar versions here.
4830 if (expression_type(args[0]).vecsize == 1)
4831 {
4832 // Equivalent to length(a - b) -> abs(a - b).
4833 emit_op(result_type, id,
4834 join("abs(", to_unpacked_expression(args[0]), " - ", to_unpacked_expression(args[1]), ")"),
4835 should_forward(args[0]) && should_forward(args[1]));
4836 inherit_expression_dependencies(id, args[0]);
4837 inherit_expression_dependencies(id, args[1]);
4838 }
4839 else
4840 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
4841 break;
4842
4843 case GLSLstd450Length:
4844 // MSL does not support scalar versions here.
4845 if (expression_type(args[0]).vecsize == 1)
4846 {
4847 // Equivalent to abs().
4848 emit_unary_func_op(result_type, id, args[0], "abs");
4849 }
4850 else
4851 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
4852 break;
4853
4854 case GLSLstd450Normalize:
4855 // MSL does not support scalar versions here.
4856 if (expression_type(args[0]).vecsize == 1)
4857 {
4858 // Returns -1 or 1 for valid input, sign() does the job.
4859 emit_unary_func_op(result_type, id, args[0], "sign");
4860 }
4861 else
4862 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
4863 break;
4864
4865 case GLSLstd450Reflect:
4866 if (get<SPIRType>(result_type).vecsize == 1)
4867 emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
4868 else
4869 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
4870 break;
4871
4872 case GLSLstd450Refract:
4873 if (get<SPIRType>(result_type).vecsize == 1)
4874 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
4875 else
4876 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
4877 break;
4878
4879 default:
4880 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
4881 break;
4882 }
4883 }
4884
4885 // Emit a structure declaration for the specified interface variable.
4886 void CompilerMSL::emit_interface_block(uint32_t ib_var_id)
4887 {
4888 if (ib_var_id)
4889 {
4890 auto &ib_var = get<SPIRVariable>(ib_var_id);
4891 auto &ib_type = get_variable_data_type(ib_var);
4892 assert(ib_type.basetype == SPIRType::Struct && !ib_type.member_types.empty());
4893 emit_struct(ib_type);
4894 }
4895 }
4896
4897 // Emits the declaration signature of the specified function.
4898 // If this is the entry point function, Metal-specific return value and function arguments are added.
4899 void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &)
4900 {
4901 if (func.self != ir.default_entry_point)
4902 add_function_overload(func);
4903
4904 local_variable_names = resource_names;
4905 string decl;
4906
4907 processing_entry_point = (func.self == ir.default_entry_point);
4908
4909 auto &type = get<SPIRType>(func.return_type);
4910
4911 if (type.array.empty())
4912 {
4913 decl += func_type_decl(type);
4914 }
4915 else
4916 {
4917 // We cannot return arrays in MSL, so "return" through an out variable.
4918 decl = "void";
4919 }
4920
4921 decl += " ";
4922 decl += to_name(func.self);
4923 decl += "(";
4924
4925 if (!type.array.empty())
4926 {
4927 		// Fake array returns by writing to an out array instead.
4928 decl += "thread ";
4929 decl += type_to_glsl(type);
4930 decl += " (&SPIRV_Cross_return_value)";
4931 decl += type_to_array_glsl(type);
4932 if (!func.arguments.empty())
4933 decl += ", ";
4934 }
4935
4936 if (processing_entry_point)
4937 {
4938 if (msl_options.argument_buffers)
4939 decl += entry_point_args_argument_buffer(!func.arguments.empty());
4940 else
4941 decl += entry_point_args_classic(!func.arguments.empty());
4942
4943 // If entry point function has variables that require early declaration,
4944 // ensure they each have an empty initializer, creating one if needed.
4945 // This is done at this late stage because the initialization expression
4946 // is cleared after each compilation pass.
4947 for (auto var_id : vars_needing_early_declaration)
4948 {
4949 auto &ed_var = get<SPIRVariable>(var_id);
4950 uint32_t &initializer = ed_var.initializer;
4951 if (!initializer)
4952 initializer = ir.increase_bound_by(1);
4953
4954 // Do not override proper initializers.
4955 if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression)
4956 set<SPIRExpression>(ed_var.initializer, "{}", ed_var.basetype, true);
4957 }
4958 }
4959
4960 for (auto &arg : func.arguments)
4961 {
4962 uint32_t name_id = arg.id;
4963
4964 auto *var = maybe_get<SPIRVariable>(arg.id);
4965 if (var)
4966 {
4967 // If we need to modify the name of the variable, make sure we modify the original variable.
4968 // Our alias is just a shadow variable.
4969 if (arg.alias_global_variable && var->basevariable)
4970 name_id = var->basevariable;
4971
4972 var->parameter = &arg; // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
4973 }
4974
4975 add_local_variable_name(name_id);
4976
4977 decl += argument_decl(arg);
4978
4979 // Manufacture automatic sampler arg for SampledImage texture
4980 auto &arg_type = get<SPIRType>(arg.type);
4981 if (arg_type.basetype == SPIRType::SampledImage && arg_type.image.dim != DimBuffer)
4982 decl += join(", thread const ", sampler_type(arg_type), " ", to_sampler_expression(arg.id));
4983
4984 // Manufacture automatic swizzle arg.
4985 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(arg_type))
4986 {
4987 bool arg_is_array = !arg_type.array.empty();
4988 decl += join(", constant uint", arg_is_array ? "* " : "& ", to_swizzle_expression(arg.id));
4989 }
4990
4991 if (buffers_requiring_array_length.count(name_id))
4992 {
4993 bool arg_is_array = !arg_type.array.empty();
4994 decl += join(", constant uint", arg_is_array ? "* " : "& ", to_buffer_size_expression(name_id));
4995 }
4996
4997 if (&arg != &func.arguments.back())
4998 decl += ", ";
4999 }
5000
5001 decl += ")";
5002 statement(decl);
5003 }
5004
5005 // Returns the texture sampling function string for the specified image and sampling characteristics.
5006 string CompilerMSL::to_function_name(uint32_t img, const SPIRType &imgtype, bool is_fetch, bool is_gather, bool, bool,
5007 bool has_offset, bool, bool has_dref, uint32_t, uint32_t)
5008 {
5009 // Special-case gather. We have to alter the component being looked up
5010 // in the swizzle case.
5011 if (msl_options.swizzle_texture_samples && is_gather)
5012 {
5013 string fname = imgtype.image.depth ? "spvGatherCompareSwizzle" : "spvGatherSwizzle";
5014 fname += "<" + type_to_glsl(get<SPIRType>(imgtype.image.type)) + ", metal::" + type_to_glsl(imgtype);
5015 // Add the arg types ourselves. Yes, this sucks, but Clang can't
5016 // deduce template pack parameters in the middle of an argument list.
5017 switch (imgtype.image.dim)
5018 {
5019 case Dim2D:
5020 fname += ", float2";
5021 if (imgtype.image.arrayed)
5022 fname += ", uint";
5023 if (imgtype.image.depth)
5024 fname += ", float";
5025 if (!imgtype.image.depth || has_offset)
5026 fname += ", int2";
5027 break;
5028 case DimCube:
5029 fname += ", float3";
5030 if (imgtype.image.arrayed)
5031 fname += ", uint";
5032 if (imgtype.image.depth)
5033 fname += ", float";
5034 break;
5035 default:
5036 SPIRV_CROSS_THROW("Invalid texture dimension for gather op.");
5037 }
5038 fname += ">";
5039 return fname;
5040 }
5041
5042 auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
5043
5044 // Texture reference
5045 string fname = to_expression(combined ? combined->image : img) + ".";
5046 if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype))
5047 fname = "spvTextureSwizzle(" + fname;
5048
5049 // Texture function and sampler
5050 if (is_fetch)
5051 fname += "read";
5052 else if (is_gather)
5053 fname += "gather";
5054 else
5055 fname += "sample";
5056
5057 if (has_dref)
5058 fname += "_compare";
5059
5060 return fname;
5061 }
5062
5063 string CompilerMSL::convert_to_f32(const string &expr, uint32_t components)
5064 {
5065 SPIRType t;
5066 t.basetype = SPIRType::Float;
5067 t.vecsize = components;
5068 t.columns = 1;
5069 return join(type_to_glsl_constructor(t), "(", expr, ")");
5070 }
5071
5072 static inline bool sampling_type_needs_f32_conversion(const SPIRType &type)
5073 {
5074 	// Double is not supported to begin with, but it doesn't hurt to check for completeness.
5075 return type.basetype == SPIRType::Half || type.basetype == SPIRType::Double;
5076 }
5077
5078 // Returns the function args for a texture sampling function for the specified image and sampling characteristics.
5079 string CompilerMSL::to_function_args(uint32_t img, const SPIRType &imgtype, bool is_fetch, bool is_gather, bool is_proj,
5080 uint32_t coord, uint32_t, uint32_t dref, uint32_t grad_x, uint32_t grad_y,
5081 uint32_t lod, uint32_t coffset, uint32_t offset, uint32_t bias, uint32_t comp,
5082 uint32_t sample, uint32_t minlod, bool *p_forward)
5083 {
5084 string farg_str;
5085 if (!is_fetch)
5086 farg_str += to_sampler_expression(img);
5087
5088 if (msl_options.swizzle_texture_samples && is_gather)
5089 {
5090 if (!farg_str.empty())
5091 farg_str += ", ";
5092
5093 auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
5094 farg_str += to_expression(combined ? combined->image : img);
5095 }
5096
5097 // Texture coordinates
5098 bool forward = should_forward(coord);
5099 auto coord_expr = to_enclosed_expression(coord);
5100 auto &coord_type = expression_type(coord);
5101 bool coord_is_fp = type_is_floating_point(coord_type);
5102 bool is_cube_fetch = false;
5103
5104 string tex_coords = coord_expr;
5105 uint32_t alt_coord_component = 0;
5106
5107 switch (imgtype.image.dim)
5108 {
5109
5110 case Dim1D:
5111 if (coord_type.vecsize > 1)
5112 tex_coords = enclose_expression(tex_coords) + ".x";
5113
5114 if (is_fetch)
5115 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
5116 else if (sampling_type_needs_f32_conversion(coord_type))
5117 tex_coords = convert_to_f32(tex_coords, 1);
5118
5119 alt_coord_component = 1;
5120 break;
5121
5122 case DimBuffer:
5123 if (coord_type.vecsize > 1)
5124 tex_coords = enclose_expression(tex_coords) + ".x";
5125
5126 if (msl_options.texture_buffer_native)
5127 {
5128 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
5129 }
5130 else
5131 {
5132 // Metal texel buffer textures are 2D, so convert 1D coord to 2D.
5133 if (is_fetch)
5134 tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
5135 }
5136
5137 alt_coord_component = 1;
5138 break;
5139
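	// Subpass inputs carry no explicit coordinate in MSL; they are read at the current
	// fragment position (and at LOD 0 for non-multisampled attachments).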
5140 case DimSubpassData:
5141 if (imgtype.image.ms)
5142 tex_coords = "uint2(gl_FragCoord.xy)";
5143 else
5144 tex_coords = join("uint2(gl_FragCoord.xy), 0");
5145 break;
5146
5147 case Dim2D:
5148 if (coord_type.vecsize > 2)
5149 tex_coords = enclose_expression(tex_coords) + ".xy";
5150
5151 if (is_fetch)
5152 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
5153 else if (sampling_type_needs_f32_conversion(coord_type))
5154 tex_coords = convert_to_f32(tex_coords, 2);
5155
5156 alt_coord_component = 2;
5157 break;
5158
5159 case Dim3D:
5160 if (coord_type.vecsize > 3)
5161 tex_coords = enclose_expression(tex_coords) + ".xyz";
5162
5163 if (is_fetch)
5164 tex_coords = "uint3(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
5165 else if (sampling_type_needs_f32_conversion(coord_type))
5166 tex_coords = convert_to_f32(tex_coords, 3);
5167
5168 alt_coord_component = 3;
5169 break;
5170
5171 case DimCube:
5172 if (is_fetch)
5173 {
5174 is_cube_fetch = true;
5175 tex_coords += ".xy";
5176 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
5177 }
5178 else
5179 {
5180 if (coord_type.vecsize > 3)
5181 tex_coords = enclose_expression(tex_coords) + ".xyz";
5182 }
5183
5184 if (sampling_type_needs_f32_conversion(coord_type))
5185 tex_coords = convert_to_f32(tex_coords, 3);
5186
5187 alt_coord_component = 3;
5188 break;
5189
5190 default:
5191 break;
5192 }
5193
5194 if (is_fetch && offset)
5195 {
5196 // Fetch offsets must be applied directly to the coordinate.
5197 forward = forward && should_forward(offset);
5198 auto &type = expression_type(offset);
5199 if (type.basetype != SPIRType::UInt)
5200 tex_coords += " + " + bitcast_expression(SPIRType::UInt, offset);
5201 else
5202 tex_coords += " + " + to_enclosed_expression(offset);
5203 }
5204 else if (is_fetch && coffset)
5205 {
5206 // Fetch offsets must be applied directly to the coordinate.
5207 forward = forward && should_forward(coffset);
5208 auto &type = expression_type(coffset);
5209 if (type.basetype != SPIRType::UInt)
5210 tex_coords += " + " + bitcast_expression(SPIRType::UInt, coffset);
5211 else
5212 tex_coords += " + " + to_enclosed_expression(coffset);
5213 }
5214
5215 // If projection, use alt coord as divisor
5216 if (is_proj)
5217 {
5218 if (sampling_type_needs_f32_conversion(coord_type))
5219 tex_coords += " / " + convert_to_f32(to_extract_component_expression(coord, alt_coord_component), 1);
5220 else
5221 tex_coords += " / " + to_extract_component_expression(coord, alt_coord_component);
5222 }
5223
5224 if (!farg_str.empty())
5225 farg_str += ", ";
5226 farg_str += tex_coords;
5227
5228 // If fetch from cube, add face explicitly
5229 if (is_cube_fetch)
5230 {
5231 // Special case for cube arrays, face and layer are packed in one dimension.
5232 if (imgtype.image.arrayed)
5233 farg_str += ", uint(" + to_extract_component_expression(coord, 2) + ") % 6u";
5234 else
5235 farg_str += ", uint(" + round_fp_tex_coords(to_extract_component_expression(coord, 2), coord_is_fp) + ")";
5236 }
5237
5238 // If array, use alt coord
5239 if (imgtype.image.arrayed)
5240 {
5241 // Special case for cube arrays, face and layer are packed in one dimension.
5242 if (imgtype.image.dim == DimCube && is_fetch)
5243 farg_str += ", uint(" + to_extract_component_expression(coord, 2) + ") / 6u";
5244 else
5245 farg_str += ", uint(" +
5246 round_fp_tex_coords(to_extract_component_expression(coord, alt_coord_component), coord_is_fp) +
5247 ")";
5248 }
5249
5250 // Depth compare reference value
5251 if (dref)
5252 {
5253 forward = forward && should_forward(dref);
5254 farg_str += ", ";
5255
5256 auto &dref_type = expression_type(dref);
5257
5258 string dref_expr;
5259 if (is_proj)
5260 dref_expr =
5261 join(to_enclosed_expression(dref), " / ", to_extract_component_expression(coord, alt_coord_component));
5262 else
5263 dref_expr = to_expression(dref);
5264
5265 if (sampling_type_needs_f32_conversion(dref_type))
5266 dref_expr = convert_to_f32(dref_expr, 1);
5267
5268 farg_str += dref_expr;
5269
5270 if (msl_options.is_macos() && (grad_x || grad_y))
5271 {
5272 			// For sample compare, MSL does not support gradient2d() on all targets (apparently only on iOS, according to the docs).
5273 // However, the most common case here is to have a constant gradient of 0, as that is the only way to express
5274 // LOD == 0 in GLSL with sampler2DArrayShadow (cascaded shadow mapping).
5275 // We will detect a compile-time constant 0 value for gradient and promote that to level(0) on MSL.
5276 bool constant_zero_x = !grad_x || expression_is_constant_null(grad_x);
5277 bool constant_zero_y = !grad_y || expression_is_constant_null(grad_y);
5278 if (constant_zero_x && constant_zero_y)
5279 {
5280 lod = 0;
5281 grad_x = 0;
5282 grad_y = 0;
5283 farg_str += ", level(0)";
5284 }
5285 else
5286 {
5287 SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not "
5288 "supported in MSL macOS.");
5289 }
5290 }
5291
5292 if (msl_options.is_macos() && bias)
5293 {
5294 // Bias is not supported either on macOS with sample_compare.
5295 // Verify it is compile-time zero, and drop the argument.
5296 if (expression_is_constant_null(bias))
5297 {
5298 bias = 0;
5299 }
5300 else
5301 {
5302 SPIRV_CROSS_THROW(
5303 "Using non-constant 0.0 bias() qualifier for sample_compare. This is not supported in MSL macOS.");
5304 }
5305 }
5306 }
5307
5308 // LOD Options
5309 // Metal does not support LOD for 1D textures.
5310 if (bias && imgtype.image.dim != Dim1D)
5311 {
5312 forward = forward && should_forward(bias);
5313 farg_str += ", bias(" + to_expression(bias) + ")";
5314 }
5315
5316 // Metal does not support LOD for 1D textures.
5317 if (lod && imgtype.image.dim != Dim1D)
5318 {
5319 forward = forward && should_forward(lod);
5320 if (is_fetch)
5321 {
5322 farg_str += ", " + to_expression(lod);
5323 }
5324 else
5325 {
5326 farg_str += ", level(" + to_expression(lod) + ")";
5327 }
5328 }
5329 else if (is_fetch && !lod && imgtype.image.dim != Dim1D && imgtype.image.dim != DimBuffer && !imgtype.image.ms &&
5330 imgtype.image.sampled != 2)
5331 {
5332 		// The Lod argument is optional in OpImageFetch, but we require a LOD value, so pick 0 as the default.
5333 // Check for sampled type as well, because is_fetch is also used for OpImageRead in MSL.
5334 farg_str += ", 0";
5335 }
5336
5337 // Metal does not support LOD for 1D textures.
5338 if ((grad_x || grad_y) && imgtype.image.dim != Dim1D)
5339 {
5340 forward = forward && should_forward(grad_x);
5341 forward = forward && should_forward(grad_y);
5342 string grad_opt;
5343 switch (imgtype.image.dim)
5344 {
5345 case Dim2D:
5346 grad_opt = "2d";
5347 break;
5348 case Dim3D:
5349 grad_opt = "3d";
5350 break;
5351 case DimCube:
5352 grad_opt = "cube";
5353 break;
5354 default:
5355 grad_opt = "unsupported_gradient_dimension";
5356 break;
5357 }
5358 farg_str += ", gradient" + grad_opt + "(" + to_expression(grad_x) + ", " + to_expression(grad_y) + ")";
5359 }
5360
5361 if (minlod)
5362 {
5363 if (msl_options.is_macos())
5364 {
5365 if (!msl_options.supports_msl_version(2, 2))
5366 			SPIRV_CROSS_THROW("min_lod_clamp() is only supported in MSL 2.2 and up on macOS.");
5367 }
5368 else if (msl_options.is_ios())
5369 SPIRV_CROSS_THROW("min_lod_clamp() is not supported on iOS.");
5370
5371 forward = forward && should_forward(minlod);
5372 farg_str += ", min_lod_clamp(" + to_expression(minlod) + ")";
5373 }
5374
5375 // Add offsets
5376 string offset_expr;
5377 if (coffset && !is_fetch)
5378 {
5379 forward = forward && should_forward(coffset);
5380 offset_expr = to_expression(coffset);
5381 }
5382 else if (offset && !is_fetch)
5383 {
5384 forward = forward && should_forward(offset);
5385 offset_expr = to_expression(offset);
5386 }
5387
5388 if (!offset_expr.empty())
5389 {
5390 switch (imgtype.image.dim)
5391 {
5392 case Dim2D:
5393 if (coord_type.vecsize > 2)
5394 offset_expr = enclose_expression(offset_expr) + ".xy";
5395
5396 farg_str += ", " + offset_expr;
5397 break;
5398
5399 case Dim3D:
5400 if (coord_type.vecsize > 3)
5401 offset_expr = enclose_expression(offset_expr) + ".xyz";
5402
5403 farg_str += ", " + offset_expr;
5404 break;
5405
5406 default:
5407 break;
5408 }
5409 }
5410
5411 if (comp)
5412 {
5413 // If 2D has gather component, ensure it also has an offset arg
5414 if (imgtype.image.dim == Dim2D && offset_expr.empty())
5415 farg_str += ", int2(0)";
5416
5417 forward = forward && should_forward(comp);
5418 farg_str += ", " + to_component_argument(comp);
5419 }
5420
5421 if (sample)
5422 {
5423 forward = forward && should_forward(sample);
5424 farg_str += ", ";
5425 farg_str += to_expression(sample);
5426 }
5427
5428 if (msl_options.swizzle_texture_samples && is_sampled_image_type(imgtype))
5429 {
5430 // Add the swizzle constant from the swizzle buffer.
5431 if (!is_gather)
5432 farg_str += ")";
5433 farg_str += ", " + to_swizzle_expression(img);
5434 used_swizzle_buffer = true;
5435 }
5436
5437 *p_forward = forward;
5438
5439 return farg_str;
5440 }
5441
5442 // If the texture coordinates are floating point, invokes MSL round() function to round them.
5443 string CompilerMSL::round_fp_tex_coords(string tex_coords, bool coord_is_fp)
5444 {
5445 return coord_is_fp ? ("round(" + tex_coords + ")") : tex_coords;
5446 }
5447
5448 // Returns a string to use in an image sampling function argument.
5449 // The ID must be a scalar constant.
5450 string CompilerMSL::to_component_argument(uint32_t id)
5451 {
5452 if (ir.ids[id].get_type() != TypeConstant)
5453 {
5454 SPIRV_CROSS_THROW("ID " + to_string(id) + " is not an OpConstant.");
5455 return "component::x";
5456 }
5457
5458 uint32_t component_index = get<SPIRConstant>(id).scalar();
5459 switch (component_index)
5460 {
5461 case 0:
5462 return "component::x";
5463 case 1:
5464 return "component::y";
5465 case 2:
5466 return "component::z";
5467 case 3:
5468 return "component::w";
5469
5470 default:
5471 SPIRV_CROSS_THROW("The value (" + to_string(component_index) + ") of OpConstant ID " + to_string(id) +
5472 " is not a valid Component index, which must be one of 0, 1, 2, or 3.");
5473 return "component::x";
5474 }
5475 }
5476
5477 // Establish sampled image as expression object and assign the sampler to it.
5478 void CompilerMSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
5479 {
5480 set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
5481 }
5482
5483 // Returns a string representation of the ID, usable as a function arg.
5484 // Manufacture automatic sampler arg for SampledImage texture.
5485 string CompilerMSL::to_func_call_arg(uint32_t id)
5486 {
5487 string arg_str;
5488
5489 auto *c = maybe_get<SPIRConstant>(id);
5490 if (c && !get<SPIRType>(c->constant_type).array.empty())
5491 {
5492 // If we are passing a constant array directly to a function for some reason,
5493 // the callee will expect an argument in thread const address space
5494 // (since we can only bind to arrays with references in MSL).
5495 // To resolve this, we must emit a copy in this address space.
5496 // This kind of code gen should be rare enough that performance is not a real concern.
5497 // Inline the SPIR-V to avoid this kind of suboptimal codegen.
5498 //
5499 // We risk calling this inside a continue block (invalid code),
5500 // so just create a thread local copy in the current function.
5501 arg_str = join("_", id, "_array_copy");
5502 auto &constants = current_function->constant_arrays_needed_on_stack;
5503 auto itr = find(begin(constants), end(constants), id);
5504 if (itr == end(constants))
5505 {
5506 force_recompile();
5507 constants.push_back(id);
5508 }
5509 }
5510 else
5511 arg_str = CompilerGLSL::to_func_call_arg(id);
5512
5513 // Manufacture automatic sampler arg if the arg is a SampledImage texture.
5514 auto &type = expression_type(id);
5515 if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
5516 {
5517 // Need to check the base variable in case we need to apply a qualified alias.
5518 uint32_t var_id = 0;
5519 auto *sampler_var = maybe_get<SPIRVariable>(id);
5520 if (sampler_var)
5521 var_id = sampler_var->basevariable;
5522
5523 arg_str += ", " + to_sampler_expression(var_id ? var_id : id);
5524 }
5525
5526 uint32_t var_id = 0;
5527 auto *var = maybe_get<SPIRVariable>(id);
5528 if (var)
5529 var_id = var->basevariable;
5530
5531 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
5532 {
5533 // Need to check the base variable in case we need to apply a qualified alias.
5534 arg_str += ", " + to_swizzle_expression(var_id ? var_id : id);
5535 }
5536
5537 if (buffers_requiring_array_length.count(var_id))
5538 arg_str += ", " + to_buffer_size_expression(var_id ? var_id : id);
5539
5540 return arg_str;
5541 }
5542
5543 // If the ID represents a sampled image that has been assigned a sampler already,
5544 // generate an expression for the sampler, otherwise generate a fake sampler name
5545 // by appending a suffix to the expression constructed from the ID.
5546 string CompilerMSL::to_sampler_expression(uint32_t id)
5547 {
5548 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
5549 auto expr = to_expression(combined ? combined->image : id);
5550 auto index = expr.find_first_of('[');
5551
5552 uint32_t samp_id = 0;
5553 if (combined)
5554 samp_id = combined->sampler;
5555
5556 if (index == string::npos)
5557 return samp_id ? to_expression(samp_id) : expr + sampler_name_suffix;
5558 else
5559 {
5560 auto image_expr = expr.substr(0, index);
5561 auto array_expr = expr.substr(index);
5562 return samp_id ? to_expression(samp_id) : (image_expr + sampler_name_suffix + array_expr);
5563 }
5564 }
5565
5566 string CompilerMSL::to_swizzle_expression(uint32_t id)
5567 {
5568 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
5569
5570 auto expr = to_expression(combined ? combined->image : id);
5571 auto index = expr.find_first_of('[');
5572
5573 // If an image is part of an argument buffer translate this to a legal identifier.
5574 for (auto &c : expr)
5575 if (c == '.')
5576 c = '_';
5577
5578 if (index == string::npos)
5579 return expr + swizzle_name_suffix;
5580 else
5581 {
5582 auto image_expr = expr.substr(0, index);
5583 auto array_expr = expr.substr(index);
5584 return image_expr + swizzle_name_suffix + array_expr;
5585 }
5586 }
5587
5588 string CompilerMSL::to_buffer_size_expression(uint32_t id)
5589 {
5590 auto expr = to_expression(id);
5591 auto index = expr.find_first_of('[');
5592
5593 // This is quite crude, but we need to translate the reference name (*spvDescriptorSetN.name) to
5594 // the pointer expression spvDescriptorSetN.name to make a reasonable expression here.
5595 // This only happens if we have argument buffers and we are using OpArrayLength on a lone SSBO in that set.
5596 if (expr.size() >= 3 && expr[0] == '(' && expr[1] == '*')
5597 expr = address_of_expression(expr);
5598
5599 // If a buffer is part of an argument buffer translate this to a legal identifier.
5600 for (auto &c : expr)
5601 if (c == '.')
5602 c = '_';
5603
5604 if (index == string::npos)
5605 return expr + buffer_size_name_suffix;
5606 else
5607 {
5608 auto buffer_expr = expr.substr(0, index);
5609 auto array_expr = expr.substr(index);
5610 return buffer_expr + buffer_size_name_suffix + array_expr;
5611 }
5612 }
5613
5614 // Checks whether the type is a Block all of whose members have DecorationPatch.
5615 bool CompilerMSL::is_patch_block(const SPIRType &type)
5616 {
5617 if (!has_decoration(type.self, DecorationBlock))
5618 return false;
5619
5620 for (uint32_t i = 0; i < type.member_types.size(); i++)
5621 {
5622 if (!has_member_decoration(type.self, i, DecorationPatch))
5623 return false;
5624 }
5625
5626 return true;
5627 }
5628
5629 // Checks whether the ID is a row_major matrix that requires conversion before use
5630 bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id)
5631 {
5632 // Natively supported row-major matrices do not need to be converted.
5633 if (backend.native_row_major_matrix)
5634 return false;
5635
5636 // Non-matrix or column-major matrix types do not need to be converted.
5637 if (!has_decoration(id, DecorationRowMajor))
5638 return false;
5639
5640 // Generate a function that will swap matrix elements from row-major to column-major.
5641 // Packed row-matrix should just use transpose() function.
5642 if (!has_extended_decoration(id, SPIRVCrossDecorationPacked))
5643 {
5644 const auto type = expression_type(id);
5645 add_convert_row_major_matrix_function(type.columns, type.vecsize);
5646 }
5647
5648 return true;
5649 }
5650
5651 // Checks whether the member is a row_major matrix that requires conversion before use
5652 bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index)
5653 {
5654 // Natively supported row-major matrices do not need to be converted.
5655 if (backend.native_row_major_matrix)
5656 return false;
5657
5658 // Non-matrix or column-major matrix types do not need to be converted.
5659 if (!has_member_decoration(type.self, index, DecorationRowMajor))
5660 return false;
5661
5662 // Generate a function that will swap matrix elements from row-major to column-major.
5663 // Packed row-matrix should just use transpose() function.
5664 if (!has_extended_member_decoration(type.self, index, SPIRVCrossDecorationPacked))
5665 {
5666 const auto mbr_type = get<SPIRType>(type.member_types[index]);
5667 add_convert_row_major_matrix_function(mbr_type.columns, mbr_type.vecsize);
5668 }
5669
5670 return true;
5671 }
5672
5673 // Adds a function suitable for converting a non-square row-major matrix to a column-major matrix.
5674 void CompilerMSL::add_convert_row_major_matrix_function(uint32_t cols, uint32_t rows)
5675 {
5676 SPVFuncImpl spv_func;
5677 if (cols == rows) // Square matrix...just use transpose() function
5678 return;
5679 else if (cols == 2 && rows == 3)
5680 spv_func = SPVFuncImplRowMajor2x3;
5681 else if (cols == 2 && rows == 4)
5682 spv_func = SPVFuncImplRowMajor2x4;
5683 else if (cols == 3 && rows == 2)
5684 spv_func = SPVFuncImplRowMajor3x2;
5685 else if (cols == 3 && rows == 4)
5686 spv_func = SPVFuncImplRowMajor3x4;
5687 else if (cols == 4 && rows == 2)
5688 spv_func = SPVFuncImplRowMajor4x2;
5689 else if (cols == 4 && rows == 3)
5690 spv_func = SPVFuncImplRowMajor4x3;
5691 else
5692 SPIRV_CROSS_THROW("Could not convert row-major matrix.");
5693
5694 auto rslt = spv_function_implementations.insert(spv_func);
5695 if (rslt.second)
5696 {
5697 suppress_missing_prototypes = true;
5698 force_recompile();
5699 }
5700 }
5701
5702 // Wraps the expression string in a function call that converts the
5703 // row_major matrix result of the expression to a column_major matrix.
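// For example, a non-square 2x3 row-major matrix expression "m" becomes
// "spvConvertFromRowMajor2x3(m)", while square or packed matrices simply use "transpose(m)".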
5704 string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type, bool is_packed)
5705 {
5706 strip_enclosed_expression(exp_str);
5707
5708 string func_name;
5709
5710 // Square and packed matrices can just use transpose
5711 if (exp_type.columns == exp_type.vecsize || is_packed)
5712 func_name = "transpose";
5713 else
5714 func_name = string("spvConvertFromRowMajor") + to_string(exp_type.columns) + "x" + to_string(exp_type.vecsize);
5715
5716 return join(func_name, "(", exp_str, ")");
5717 }
5718
5719 // Called automatically at the end of the entry point function
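// fixup_clipspace remaps the GL clip-space depth range [-w, w] to Metal's [0, w] via
// z' = (z + w) * 0.5, and flip_vert_y negates y to account for Metal's inverted viewport Y axis.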
5720 void CompilerMSL::emit_fixup()
5721 {
5722 if ((get_execution_model() == ExecutionModelVertex ||
5723 get_execution_model() == ExecutionModelTessellationEvaluation) &&
5724 stage_out_var_id && !qual_pos_var_name.empty() && !capture_output_to_buffer)
5725 {
5726 if (options.vertex.fixup_clipspace)
5727 statement(qual_pos_var_name, ".z = (", qual_pos_var_name, ".z + ", qual_pos_var_name,
5728 ".w) * 0.5; // Adjust clip-space for Metal");
5729
5730 if (options.vertex.flip_vert_y)
5731 statement(qual_pos_var_name, ".y = -(", qual_pos_var_name, ".y);", " // Invert Y-axis for Metal");
5732 }
5733 }
5734
5735 // Return a string defining a structure member, with padding and packing.
5736 string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
5737 const string &qualifier)
5738 {
5739 auto &membertype = get<SPIRType>(member_type_id);
5740
5741 // If this member requires padding to maintain alignment, emit a dummy padding member.
5742 MSLStructMemberKey key = get_struct_member_key(type.self, index);
5743 uint32_t pad_len = struct_member_padding[key];
5744 if (pad_len > 0)
5745 statement("char _m", index, "_pad", "[", to_string(pad_len), "];");
5746
5747 // If this member is packed, mark it as so.
5748 string pack_pfx = "";
5749
5750 const SPIRType *effective_membertype = &membertype;
5751 SPIRType override_type;
5752
5753 uint32_t orig_id = 0;
5754 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID))
5755 orig_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID);
5756
5757 if (member_is_packed_type(type, index))
5758 {
5759 // If we're packing a matrix, output an appropriate typedef
5760 if (membertype.basetype == SPIRType::Struct)
5761 {
5762 pack_pfx = "/* FIXME: A padded struct is needed here. If you see this message, file a bug! */ ";
5763 }
5764 else if (membertype.vecsize > 1 && membertype.columns > 1)
5765 {
5766 uint32_t rows = membertype.vecsize;
5767 uint32_t cols = membertype.columns;
5768 pack_pfx = "packed_";
5769 if (has_member_decoration(type.self, index, DecorationRowMajor))
5770 {
5771 // These are stored transposed.
5772 rows = membertype.columns;
5773 cols = membertype.vecsize;
5774 pack_pfx = "packed_rm_";
5775 }
5776 string base_type = membertype.width == 16 ? "half" : "float";
5777 string td_line = "typedef ";
5778 td_line += "packed_" + base_type + to_string(rows);
5779 td_line += " " + pack_pfx;
5780 // Use the actual matrix size here.
5781 td_line += base_type + to_string(membertype.columns) + "x" + to_string(membertype.vecsize);
5782 td_line += "[" + to_string(cols) + "]";
5783 td_line += ";";
5784 add_typedef_line(td_line);
5785 }
5786 else if (is_array(membertype) && membertype.vecsize <= 2 && membertype.basetype != SPIRType::Struct &&
5787 type_struct_member_array_stride(type, index) == 4 * membertype.width / 8)
5788 {
5789 // A "packed" float array, but we pad here instead to 4-vector.
5790 override_type = membertype;
5791 override_type.vecsize = 4;
5792 effective_membertype = &override_type;
5793 }
5794 else
5795 pack_pfx = "packed_";
5796 }
5797
5798 	// Very specifically, image load-store in argument buffers is disallowed in MSL on iOS.
5799 if (msl_options.is_ios() && membertype.basetype == SPIRType::Image && membertype.image.sampled == 2)
5800 {
5801 if (!has_decoration(orig_id, DecorationNonWritable))
5802 SPIRV_CROSS_THROW("Writable images are not allowed in argument buffers on iOS.");
5803 }
5804
5805 // Array information is baked into these types.
5806 string array_type;
5807 if (membertype.basetype != SPIRType::Image && membertype.basetype != SPIRType::Sampler &&
5808 membertype.basetype != SPIRType::SampledImage)
5809 {
5810 array_type = type_to_array_glsl(membertype);
5811 }
5812
5813 return join(pack_pfx, type_to_glsl(*effective_membertype, orig_id), " ", qualifier, to_member_name(type, index),
5814 member_attribute_qualifier(type, index), array_type, ";");
5815 }
5816
5817 // Emit a structure member, with padding and packing to maintain the correct member alignments.
5818 void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
5819 const string &qualifier, uint32_t)
5820 {
5821 statement(to_struct_member(type, member_type_id, index, qualifier));
5822 }
5823
5824 // Return a MSL qualifier for the specified function attribute member
5825 string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t index)
5826 {
5827 auto &execution = get_entry_point();
5828
5829 uint32_t mbr_type_id = type.member_types[index];
5830 auto &mbr_type = get<SPIRType>(mbr_type_id);
5831
5832 BuiltIn builtin = BuiltInMax;
5833 bool is_builtin = is_member_builtin(type, index, &builtin);
5834
5835 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
5836 return join(" [[id(",
5837 get_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary), ")]]");
5838
5839 // Vertex function inputs
5840 if (execution.model == ExecutionModelVertex && type.storage == StorageClassInput)
5841 {
5842 if (is_builtin)
5843 {
5844 switch (builtin)
5845 {
5846 case BuiltInVertexId:
5847 case BuiltInVertexIndex:
5848 case BuiltInBaseVertex:
5849 case BuiltInInstanceId:
5850 case BuiltInInstanceIndex:
5851 case BuiltInBaseInstance:
5852 return string(" [[") + builtin_qualifier(builtin) + "]]";
5853
5854 case BuiltInDrawIndex:
5855 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
5856
5857 default:
5858 return "";
5859 }
5860 }
5861 uint32_t locn = get_ordered_member_location(type.self, index);
5862 if (locn != k_unknown_location)
5863 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
5864 }
5865
5866 // Vertex and tessellation evaluation function outputs
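	// Non-builtin outputs are decorated with [[user(locnN)]] (or [[user(locnN_M)]] when a
	// component index is present) so the fragment stage can match them by location.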
5867 if ((execution.model == ExecutionModelVertex || execution.model == ExecutionModelTessellationEvaluation) &&
5868 type.storage == StorageClassOutput)
5869 {
5870 if (is_builtin)
5871 {
5872 switch (builtin)
5873 {
5874 case BuiltInPointSize:
5875 // Only mark the PointSize builtin if really rendering points.
5876 // Some shaders may include a PointSize builtin even when used to render
5877 // non-point topologies, and Metal will reject this builtin when compiling
5878 // the shader into a render pipeline that uses a non-point topology.
5879 return msl_options.enable_point_size_builtin ? (string(" [[") + builtin_qualifier(builtin) + "]]") : "";
5880
5881 case BuiltInViewportIndex:
5882 if (!msl_options.supports_msl_version(2, 0))
5883 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
5884 /* fallthrough */
5885 case BuiltInPosition:
5886 case BuiltInLayer:
5887 case BuiltInClipDistance:
5888 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
5889
5890 default:
5891 return "";
5892 }
5893 }
5894 uint32_t comp;
5895 uint32_t locn = get_ordered_member_location(type.self, index, &comp);
5896 if (locn != k_unknown_location)
5897 {
5898 if (comp != k_unknown_component)
5899 return string(" [[user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")]]";
5900 else
5901 return string(" [[user(locn") + convert_to_string(locn) + ")]]";
5902 }
5903 }
5904
5905 // Tessellation control function inputs
5906 if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassInput)
5907 {
5908 if (is_builtin)
5909 {
5910 switch (builtin)
5911 {
5912 case BuiltInInvocationId:
5913 case BuiltInPrimitiveId:
5914 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
5915 case BuiltInSubgroupSize: // FIXME: Should work in any stage
5916 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
5917 case BuiltInPatchVertices:
5918 return "";
5919 // Others come from stage input.
5920 default:
5921 break;
5922 }
5923 }
5924 uint32_t locn = get_ordered_member_location(type.self, index);
5925 if (locn != k_unknown_location)
5926 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
5927 }
5928
5929 // Tessellation control function outputs
5930 if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassOutput)
5931 {
5932 // For this type of shader, we always arrange for it to capture its
5933 // output to a buffer. For this reason, qualifiers are irrelevant here.
5934 return "";
5935 }
5936
5937 // Tessellation evaluation function inputs
5938 if (execution.model == ExecutionModelTessellationEvaluation && type.storage == StorageClassInput)
5939 {
5940 if (is_builtin)
5941 {
5942 switch (builtin)
5943 {
5944 case BuiltInPrimitiveId:
5945 case BuiltInTessCoord:
5946 return string(" [[") + builtin_qualifier(builtin) + "]]";
5947 case BuiltInPatchVertices:
5948 return "";
5949 // Others come from stage input.
5950 default:
5951 break;
5952 }
5953 }
5954 // The special control point array must not be marked with an attribute.
5955 if (get_type(type.member_types[index]).basetype == SPIRType::ControlPointArray)
5956 return "";
5957 uint32_t locn = get_ordered_member_location(type.self, index);
5958 if (locn != k_unknown_location)
5959 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
5960 }
5961
5962 // Tessellation evaluation function outputs were handled above.
5963
5964 // Fragment function inputs
5965 if (execution.model == ExecutionModelFragment && type.storage == StorageClassInput)
5966 {
5967 string quals;
5968 if (is_builtin)
5969 {
5970 switch (builtin)
5971 {
5972 case BuiltInViewIndex:
5973 if (!msl_options.multiview)
5974 break;
5975 /* fallthrough */
5976 case BuiltInFrontFacing:
5977 case BuiltInPointCoord:
5978 case BuiltInFragCoord:
5979 case BuiltInSampleId:
5980 case BuiltInSampleMask:
5981 case BuiltInLayer:
5982 case BuiltInBaryCoordNV:
5983 case BuiltInBaryCoordNoPerspNV:
5984 quals = builtin_qualifier(builtin);
5985 break;
5986
5987 default:
5988 break;
5989 }
5990 }
5991 else
5992 {
5993 uint32_t comp;
5994 uint32_t locn = get_ordered_member_location(type.self, index, &comp);
5995 if (locn != k_unknown_location)
5996 {
5997 if (comp != k_unknown_component)
5998 quals = string("user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")";
5999 else
6000 quals = string("user(locn") + convert_to_string(locn) + ")";
6001 }
6002 }
6003
6004 if (builtin == BuiltInBaryCoordNV || builtin == BuiltInBaryCoordNoPerspNV)
6005 {
6006 if (has_member_decoration(type.self, index, DecorationFlat) ||
6007 has_member_decoration(type.self, index, DecorationCentroid) ||
6008 has_member_decoration(type.self, index, DecorationSample) ||
6009 has_member_decoration(type.self, index, DecorationNoPerspective))
6010 {
6011 // NoPerspective is baked into the builtin type.
6012 SPIRV_CROSS_THROW(
6013 "Flat, Centroid, Sample, NoPerspective decorations are not supported for BaryCoord inputs.");
6014 }
6015 }
6016
6017 // Don't bother decorating integers with the 'flat' attribute; it's
6018 // the default (in fact, the only option). Also don't bother with the
6019 // FragCoord builtin; it's always noperspective on Metal.
6020 if (!type_is_integral(mbr_type) && (!is_builtin || builtin != BuiltInFragCoord))
6021 {
6022 if (has_member_decoration(type.self, index, DecorationFlat))
6023 {
6024 if (!quals.empty())
6025 quals += ", ";
6026 quals += "flat";
6027 }
6028 else if (has_member_decoration(type.self, index, DecorationCentroid))
6029 {
6030 if (!quals.empty())
6031 quals += ", ";
6032 if (has_member_decoration(type.self, index, DecorationNoPerspective))
6033 quals += "centroid_no_perspective";
6034 else
6035 quals += "centroid_perspective";
6036 }
6037 else if (has_member_decoration(type.self, index, DecorationSample))
6038 {
6039 if (!quals.empty())
6040 quals += ", ";
6041 if (has_member_decoration(type.self, index, DecorationNoPerspective))
6042 quals += "sample_no_perspective";
6043 else
6044 quals += "sample_perspective";
6045 }
6046 else if (has_member_decoration(type.self, index, DecorationNoPerspective))
6047 {
6048 if (!quals.empty())
6049 quals += ", ";
6050 quals += "center_no_perspective";
6051 }
6052 }
6053
6054 if (!quals.empty())
6055 return " [[" + quals + "]]";
6056 }
6057
6058 // Fragment function outputs
6059 if (execution.model == ExecutionModelFragment && type.storage == StorageClassOutput)
6060 {
6061 if (is_builtin)
6062 {
6063 switch (builtin)
6064 {
6065 case BuiltInFragStencilRefEXT:
6066 if (!msl_options.supports_msl_version(2, 1))
6067 SPIRV_CROSS_THROW("Stencil export only supported in MSL 2.1 and up.");
6068 return string(" [[") + builtin_qualifier(builtin) + "]]";
6069
6070 case BuiltInSampleMask:
6071 case BuiltInFragDepth:
6072 return string(" [[") + builtin_qualifier(builtin) + "]]";
6073
6074 default:
6075 return "";
6076 }
6077 }
6078 uint32_t locn = get_ordered_member_location(type.self, index);
6079 if (locn != k_unknown_location && has_member_decoration(type.self, index, DecorationIndex))
6080 return join(" [[color(", locn, "), index(", get_member_decoration(type.self, index, DecorationIndex),
6081 ")]]");
6082 else if (locn != k_unknown_location)
6083 return join(" [[color(", locn, ")]]");
6084 else if (has_member_decoration(type.self, index, DecorationIndex))
6085 return join(" [[index(", get_member_decoration(type.self, index, DecorationIndex), ")]]");
6086 else
6087 return "";
6088 }
6089
6090 // Compute function inputs
6091 if (execution.model == ExecutionModelGLCompute && type.storage == StorageClassInput)
6092 {
6093 if (is_builtin)
6094 {
6095 switch (builtin)
6096 {
6097 case BuiltInGlobalInvocationId:
6098 case BuiltInWorkgroupId:
6099 case BuiltInNumWorkgroups:
6100 case BuiltInLocalInvocationId:
6101 case BuiltInLocalInvocationIndex:
6102 case BuiltInNumSubgroups:
6103 case BuiltInSubgroupId:
6104 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
6105 case BuiltInSubgroupSize: // FIXME: Should work in any stage
6106 return string(" [[") + builtin_qualifier(builtin) + "]]";
6107
6108 default:
6109 return "";
6110 }
6111 }
6112 }
6113
6114 return "";
6115 }
6116
6117 // Returns the location decoration of the member with the specified index in the specified type.
6118 // If the location of the member has been explicitly set, that location is used. If not, this
6119 // function assumes the members are ordered in their location order, and simply returns the
6120 // index as the location.
6121 uint32_t CompilerMSL::get_ordered_member_location(uint32_t type_id, uint32_t index, uint32_t *comp)
6122 {
6123 auto &m = ir.meta[type_id];
6124 if (index < m.members.size())
6125 {
6126 auto &dec = m.members[index];
6127 if (comp)
6128 {
6129 if (dec.decoration_flags.get(DecorationComponent))
6130 *comp = dec.component;
6131 else
6132 *comp = k_unknown_component;
6133 }
6134 if (dec.decoration_flags.get(DecorationLocation))
6135 return dec.location;
6136 }
6137
6138 return index;
6139 }
6140
6141 // Returns the type declaration for a function, including the
6142 // entry type if the current function is the entry point function
6143 string CompilerMSL::func_type_decl(SPIRType &type)
6144 {
6145 // The regular function return type. If not processing the entry point function, that's all we need
6146 string return_type = type_to_glsl(type) + type_to_array_glsl(type);
6147 if (!processing_entry_point)
6148 return return_type;
6149
6150 // If an outgoing interface block has been defined, and it should be returned, override the entry point return type
6151 bool ep_should_return_output = !get_is_rasterization_disabled();
6152 if (stage_out_var_id && ep_should_return_output)
6153 return_type = type_to_glsl(get_stage_out_struct_type()) + type_to_array_glsl(type);
6154
6155 // Prepend an entry type, based on the execution model
6156 string entry_type;
6157 auto &execution = get_entry_point();
6158 switch (execution.model)
6159 {
6160 case ExecutionModelVertex:
6161 entry_type = "vertex";
6162 break;
6163 case ExecutionModelTessellationEvaluation:
6164 if (!msl_options.supports_msl_version(1, 2))
6165 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
6166 if (execution.flags.get(ExecutionModeIsolines))
6167 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
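// For example, a quad domain on macOS yields an entry type such as
// "[[ patch(quad, 4) ]] vertex"; iOS omits the output vertex count: "[[ patch(quad) ]] vertex".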
6168 if (msl_options.is_ios())
6169 entry_type =
6170 join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ") ]] vertex");
6171 else
6172 entry_type = join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ", ",
6173 execution.output_vertices, ") ]] vertex");
6174 break;
6175 case ExecutionModelFragment:
6176 entry_type =
6177 execution.flags.get(ExecutionModeEarlyFragmentTests) ? "[[ early_fragment_tests ]] fragment" : "fragment";
6178 break;
6179 case ExecutionModelTessellationControl:
6180 if (!msl_options.supports_msl_version(1, 2))
6181 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
6182 if (execution.flags.get(ExecutionModeIsolines))
6183 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
6184 /* fallthrough */
6185 case ExecutionModelGLCompute:
6186 case ExecutionModelKernel:
6187 entry_type = "kernel";
6188 break;
6189 default:
6190 entry_type = "unknown";
6191 break;
6192 }
6193
6194 return entry_type + " " + return_type;
6195 }
6196
6197 // In MSL, address space qualifiers are required for all pointer or reference variables
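// For example, read-only SSBO arguments map to "const device", writable SSBOs to "device",
// UBOs and push constants to "constant", and workgroup-shared variables to "threadgroup".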
6198 string CompilerMSL::get_argument_address_space(const SPIRVariable &argument)
6199 {
6200 const auto &type = get<SPIRType>(argument.basetype);
6201
6202 switch (type.storage)
6203 {
6204 case StorageClassWorkgroup:
6205 return "threadgroup";
6206
6207 case StorageClassStorageBuffer:
6208 {
6209 // For arguments from variable pointers, we use the write count deduction, so
6210 // we should not assume any constness here. Only for global SSBOs.
6211 bool readonly = false;
6212 if (has_decoration(type.self, DecorationBlock))
6213 readonly = ir.get_buffer_block_flags(argument).get(DecorationNonWritable);
6214
6215 return readonly ? "const device" : "device";
6216 }
6217
6218 case StorageClassUniform:
6219 case StorageClassUniformConstant:
6220 case StorageClassPushConstant:
6221 if (type.basetype == SPIRType::Struct)
6222 {
6223 bool ssbo = has_decoration(type.self, DecorationBufferBlock);
6224 if (ssbo)
6225 {
6226 bool readonly = ir.get_buffer_block_flags(argument).get(DecorationNonWritable);
6227 return readonly ? "const device" : "device";
6228 }
6229 else
6230 return "constant";
6231 }
6232 break;
6233
6234 case StorageClassFunction:
6235 case StorageClassGeneric:
6236 // No address space for plain values.
6237 return type.pointer ? "thread" : "";
6238
6239 case StorageClassInput:
6240 if (get_execution_model() == ExecutionModelTessellationControl && argument.basevariable == stage_in_ptr_var_id)
6241 return "threadgroup";
6242 break;
6243
6244 case StorageClassOutput:
6245 if (capture_output_to_buffer)
6246 return "device";
6247 break;
6248
6249 default:
6250 break;
6251 }
6252
6253 return "thread";
6254 }
6255
6256 string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id)
6257 {
6258 switch (type.storage)
6259 {
6260 case StorageClassWorkgroup:
6261 return "threadgroup";
6262
6263 case StorageClassStorageBuffer:
6264 {
6265 // This can be called for variable pointer contexts as well, so be very careful about which method we choose.
6266 Bitset flags;
6267 if (ir.ids[id].get_type() == TypeVariable && has_decoration(type.self, DecorationBlock))
6268 flags = get_buffer_block_flags(id);
6269 else
6270 flags = get_decoration_bitset(id);
6271
6272 return flags.get(DecorationNonWritable) ? "const device" : "device";
6273 }
6274
6275 case StorageClassUniform:
6276 case StorageClassUniformConstant:
6277 case StorageClassPushConstant:
6278 if (type.basetype == SPIRType::Struct)
6279 {
6280 bool ssbo = has_decoration(type.self, DecorationBufferBlock);
6281 if (ssbo)
6282 {
6283 // This can be called for variable pointer contexts as well, so be very careful about which method we choose.
6284 Bitset flags;
6285 if (ir.ids[id].get_type() == TypeVariable && has_decoration(type.self, DecorationBlock))
6286 flags = get_buffer_block_flags(id);
6287 else
6288 flags = get_decoration_bitset(id);
6289
6290 return flags.get(DecorationNonWritable) ? "const device" : "device";
6291 }
6292 else
6293 return "constant";
6294 }
6295 else
6296 return "constant";
6297
6298 case StorageClassFunction:
6299 case StorageClassGeneric:
6300 // No address space for plain values.
6301 return type.pointer ? "thread" : "";
6302
6303 case StorageClassOutput:
6304 if (capture_output_to_buffer)
6305 return "device";
6306 break;
6307
6308 default:
6309 break;
6310 }
6311
6312 return "thread";
6313 }
6314
6315 string CompilerMSL::entry_point_arg_stage_in()
6316 {
6317 string decl;
6318
6319 // Stage-in structure
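// The resulting declaration has the form "<struct> <name> [[stage_in]]",
// e.g. "main0_in in [[stage_in]]" (names are illustrative).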
6320 uint32_t stage_in_id;
6321 if (get_execution_model() == ExecutionModelTessellationEvaluation)
6322 stage_in_id = patch_stage_in_var_id;
6323 else
6324 stage_in_id = stage_in_var_id;
6325
6326 if (stage_in_id)
6327 {
6328 auto &var = get<SPIRVariable>(stage_in_id);
6329 auto &type = get_variable_data_type(var);
6330
6331 add_resource_name(var.self);
6332 decl = join(type_to_glsl(type), " ", to_name(var.self), " [[stage_in]]");
6333 }
6334
6335 return decl;
6336 }
6337
6338 void CompilerMSL::entry_point_args_builtin(string &ep_args)
6339 {
6340 // Builtin variables
6341 ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
6342 auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
6343
6344 // Don't emit SamplePosition as a separate parameter. In the entry
6345 // point, we get that by calling get_sample_position() on the sample ID.
6346 if (var.storage == StorageClassInput && is_builtin_variable(var) &&
6347 get_variable_data_type(var).basetype != SPIRType::Struct &&
6348 get_variable_data_type(var).basetype != SPIRType::ControlPointArray)
6349 {
6350 // If the builtin is not part of the active input builtin set, don't emit it.
6351 // Relevant for multiple entry-point modules which might declare unused builtins.
6352 if (!active_input_builtins.get(bi_type) || !interface_variable_exists_in_entry_point(var_id))
6353 return;
6354
6355 // These builtins are emitted specially. If we pass this branch, the builtin directly matches
6356 // a MSL builtin.
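// For example, an active FragCoord input would be emitted as
// "float4 gl_FragCoord [[position]]" (the variable name may differ).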
6357 if (bi_type != BuiltInSamplePosition && bi_type != BuiltInHelperInvocation &&
6358 bi_type != BuiltInPatchVertices && bi_type != BuiltInTessLevelInner &&
6359 bi_type != BuiltInTessLevelOuter && bi_type != BuiltInPosition && bi_type != BuiltInPointSize &&
6360 bi_type != BuiltInClipDistance && bi_type != BuiltInCullDistance && bi_type != BuiltInSubgroupEqMask &&
6361 bi_type != BuiltInBaryCoordNV && bi_type != BuiltInBaryCoordNoPerspNV &&
6362 bi_type != BuiltInSubgroupGeMask && bi_type != BuiltInSubgroupGtMask &&
6363 bi_type != BuiltInSubgroupLeMask && bi_type != BuiltInSubgroupLtMask &&
6364 ((get_execution_model() == ExecutionModelFragment && msl_options.multiview) ||
6365 bi_type != BuiltInViewIndex) &&
6366 (get_execution_model() == ExecutionModelGLCompute ||
6367 (get_execution_model() == ExecutionModelFragment && msl_options.supports_msl_version(2, 2)) ||
6368 (bi_type != BuiltInSubgroupLocalInvocationId && bi_type != BuiltInSubgroupSize)))
6369 {
6370 if (!ep_args.empty())
6371 ep_args += ", ";
6372
6373 ep_args += builtin_type_decl(bi_type, var_id) + " " + to_expression(var_id);
6374 ep_args += " [[" + builtin_qualifier(bi_type) + "]]";
6375 }
6376 }
6377 });
6378
6379 // Vertex and instance index built-ins
6380 if (needs_vertex_idx_arg)
6381 ep_args += built_in_func_arg(BuiltInVertexIndex, !ep_args.empty());
6382
6383 if (needs_instance_idx_arg)
6384 ep_args += built_in_func_arg(BuiltInInstanceIndex, !ep_args.empty());
6385
6386 if (capture_output_to_buffer)
6387 {
6388 // Add parameters to hold the indirect draw parameters and the shader output. This has to be handled
6389 // specially because it needs to be a pointer, not a reference.
6390 if (stage_out_var_id)
6391 {
6392 if (!ep_args.empty())
6393 ep_args += ", ";
6394 ep_args += join("device ", type_to_glsl(get_stage_out_struct_type()), "* ", output_buffer_var_name,
6395 " [[buffer(", msl_options.shader_output_buffer_index, ")]]");
6396 }
6397
6398 if (get_execution_model() == ExecutionModelTessellationControl)
6399 {
6400 if (!ep_args.empty())
6401 ep_args += ", ";
6402 ep_args +=
6403 join("constant uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
6404 }
6405 else if (stage_out_var_id)
6406 {
6407 if (!ep_args.empty())
6408 ep_args += ", ";
6409 ep_args +=
6410 join("device uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
6411 }
6412
6413 // Tessellation control shaders get three additional parameters:
6414 // a buffer to hold the per-patch data, a buffer to hold the per-patch
6415 // tessellation levels, and a block of workgroup memory to hold the
6416 // input control point data.
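// For a triangle patch these typically look like
// "device MTLTriangleTessellationFactorsHalf* spvTessLevel [[buffer(N)]]" and
// "threadgroup main0_in* gl_in [[threadgroup(N)]]" (names are illustrative).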
6417 if (get_execution_model() == ExecutionModelTessellationControl)
6418 {
6419 if (patch_stage_out_var_id)
6420 {
6421 if (!ep_args.empty())
6422 ep_args += ", ";
6423 ep_args +=
6424 join("device ", type_to_glsl(get_patch_stage_out_struct_type()), "* ", patch_output_buffer_var_name,
6425 " [[buffer(", convert_to_string(msl_options.shader_patch_output_buffer_index), ")]]");
6426 }
6427 if (!ep_args.empty())
6428 ep_args += ", ";
6429 ep_args += join("device ", get_tess_factor_struct_name(), "* ", tess_factor_buffer_var_name, " [[buffer(",
6430 convert_to_string(msl_options.shader_tess_factor_buffer_index), ")]]");
6431 if (stage_in_var_id)
6432 {
6433 if (!ep_args.empty())
6434 ep_args += ", ";
6435 ep_args += join("threadgroup ", type_to_glsl(get_stage_in_struct_type()), "* ", input_wg_var_name,
6436 " [[threadgroup(", convert_to_string(msl_options.shader_input_wg_index), ")]]");
6437 }
6438 }
6439 }
6440 }
6441
6442 string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
6443 {
6444 string ep_args = entry_point_arg_stage_in();
6445 Bitset claimed_bindings;
6446
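// Each argument buffer becomes a single struct reference bound to one [[buffer(n)]] slot,
// e.g. "device spvDescriptorSetBuffer0& spvDescriptorSet0 [[buffer(0)]]" (names are illustrative).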
6447 for (uint32_t i = 0; i < kMaxArgumentBuffers; i++)
6448 {
6449 uint32_t id = argument_buffer_ids[i];
6450 if (id == 0)
6451 continue;
6452
6453 add_resource_name(id);
6454 auto &var = get<SPIRVariable>(id);
6455 auto &type = get_variable_data_type(var);
6456
6457 if (!ep_args.empty())
6458 ep_args += ", ";
6459
6460 // Check if the argument buffer binding itself has been remapped.
6461 uint32_t buffer_binding;
6462 auto itr = resource_bindings.find({ get_entry_point().model, i, kArgumentBufferBinding });
6463 if (itr != end(resource_bindings))
6464 {
6465 buffer_binding = itr->second.first.msl_buffer;
6466 itr->second.second = true;
6467 }
6468 else
6469 {
6470 // As a fallback, directly map desc set <-> binding.
6471 // If that was taken, take the next buffer binding.
6472 if (claimed_bindings.get(i))
6473 buffer_binding = next_metal_resource_index_buffer;
6474 else
6475 buffer_binding = i;
6476 }
6477
6478 claimed_bindings.set(buffer_binding);
6479
6480 ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_name(id);
6481 ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";
6482
6483 next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
6484 }
6485
6486 entry_point_args_discrete_descriptors(ep_args);
6487 entry_point_args_builtin(ep_args);
6488
6489 if (!ep_args.empty() && append_comma)
6490 ep_args += ", ";
6491
6492 return ep_args;
6493 }
6494
6495 const MSLConstexprSampler *CompilerMSL::find_constexpr_sampler(uint32_t id) const
6496 {
6497 // Try by ID.
6498 {
6499 auto itr = constexpr_samplers_by_id.find(id);
6500 if (itr != end(constexpr_samplers_by_id))
6501 return &itr->second;
6502 }
6503
6504 // Try by binding.
6505 {
6506 uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
6507 uint32_t binding = get_decoration(id, DecorationBinding);
6508
6509 auto itr = constexpr_samplers_by_binding.find({ desc_set, binding });
6510 if (itr != end(constexpr_samplers_by_binding))
6511 return &itr->second;
6512 }
6513
6514 return nullptr;
6515 }
6516
6517 void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
6518 {
6519 // Output resources, sorted by resource index & type
6520 // We need to sort to work around a bug on macOS 10.13 with NVidia drivers where switching between shaders
6521 // with different order of buffers can result in issues with buffer assignments inside the driver.
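// Each discrete resource becomes its own entry-point argument, e.g.
// "texture2d<float> uTex [[texture(0)]], sampler uTexSmplr [[sampler(0)]],
// constant UBO& ubo [[buffer(0)]]" (names are illustrative).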
6522 struct Resource
6523 {
6524 SPIRVariable *var;
6525 string name;
6526 SPIRType::BaseType basetype;
6527 uint32_t index;
6528 };
6529
6530 SmallVector<Resource> resources;
6531
6532 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
6533 if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
6534 var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer) &&
6535 !is_hidden_variable(var))
6536 {
6537 auto &type = get_variable_data_type(var);
6538 uint32_t var_id = var.self;
6539
6540 if (var.storage != StorageClassPushConstant)
6541 {
6542 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
6543 if (descriptor_set_is_argument_buffer(desc_set))
6544 return;
6545 }
6546
6547 const MSLConstexprSampler *constexpr_sampler = nullptr;
6548 if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
6549 {
6550 constexpr_sampler = find_constexpr_sampler(var_id);
6551 if (constexpr_sampler)
6552 {
6553 // Mark this ID as a constexpr sampler for later in case it came from set/bindings.
6554 constexpr_samplers_by_id[var_id] = *constexpr_sampler;
6555 }
6556 }
6557
6558 if (type.basetype == SPIRType::SampledImage)
6559 {
6560 add_resource_name(var_id);
6561 resources.push_back(
6562 { &var, to_name(var_id), SPIRType::Image, get_metal_resource_index(var, SPIRType::Image) });
6563
6564 if (type.image.dim != DimBuffer && !constexpr_sampler)
6565 {
6566 resources.push_back({ &var, to_sampler_expression(var_id), SPIRType::Sampler,
6567 get_metal_resource_index(var, SPIRType::Sampler) });
6568 }
6569 }
6570 else if (!constexpr_sampler)
6571 {
6572 // constexpr samplers are not declared as resources.
6573 add_resource_name(var_id);
6574 resources.push_back(
6575 { &var, to_name(var_id), type.basetype, get_metal_resource_index(var, type.basetype) });
6576 }
6577 }
6578 });
6579
6580 sort(resources.begin(), resources.end(), [](const Resource &lhs, const Resource &rhs) {
6581 return tie(lhs.basetype, lhs.index) < tie(rhs.basetype, rhs.index);
6582 });
6583
6584 for (auto &r : resources)
6585 {
6586 auto &var = *r.var;
6587 auto &type = get_variable_data_type(var);
6588
6589 uint32_t var_id = var.self;
6590
6591 switch (r.basetype)
6592 {
6593 case SPIRType::Struct:
6594 {
6595 auto &m = ir.meta[type.self];
6596 if (m.members.size() == 0)
6597 break;
6598 if (!type.array.empty())
6599 {
6600 if (type.array.size() > 1)
6601 SPIRV_CROSS_THROW("Arrays of arrays of buffers are not supported.");
6602
6603 // Metal doesn't directly support this, so we must expand the
6604 // array. We'll declare a local array to hold these elements
6605 // later.
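// For example, an SSBO array of size 3 starting at buffer index 5 is emitted as
// "device SSBO* ssbo_0 [[buffer(5)]]" through "device SSBO* ssbo_2 [[buffer(7)]]"
// (names are illustrative).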
6606 uint32_t array_size = to_array_size_literal(type);
6607
6608 if (array_size == 0)
6609 SPIRV_CROSS_THROW("Unsized arrays of buffers are not supported in MSL.");
6610
6611 buffer_arrays.push_back(var_id);
6612 for (uint32_t i = 0; i < array_size; ++i)
6613 {
6614 if (!ep_args.empty())
6615 ep_args += ", ";
6616 ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + r.name + "_" +
6617 convert_to_string(i);
6618 ep_args += " [[buffer(" + convert_to_string(r.index + i) + ")]]";
6619 }
6620 }
6621 else
6622 {
6623 if (!ep_args.empty())
6624 ep_args += ", ";
6625 ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + r.name;
6626 ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]";
6627 }
6628 break;
6629 }
6630 case SPIRType::Sampler:
6631 if (!ep_args.empty())
6632 ep_args += ", ";
6633 ep_args += sampler_type(type) + " " + r.name;
6634 ep_args += " [[sampler(" + convert_to_string(r.index) + ")]]";
6635 break;
6636 case SPIRType::Image:
6637 if (!ep_args.empty())
6638 ep_args += ", ";
6639 ep_args += image_type_glsl(type, var_id) + " " + r.name;
6640 ep_args += " [[texture(" + convert_to_string(r.index) + ")]]";
6641 break;
6642 default:
6643 if (!ep_args.empty())
6644 ep_args += ", ";
6645 ep_args += type_to_glsl(type, var_id) + " " + r.name;
6646 ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]";
6647 break;
6648 }
6649 }
6650 }
6651
6652 // Returns a string containing a comma-delimited list of args for the entry point function
6653 // This is the "classic" method of MSL 1 when we don't have argument buffer support.
6654 string CompilerMSL::entry_point_args_classic(bool append_comma)
6655 {
6656 string ep_args = entry_point_arg_stage_in();
6657 entry_point_args_discrete_descriptors(ep_args);
6658 entry_point_args_builtin(ep_args);
6659
6660 if (!ep_args.empty() && append_comma)
6661 ep_args += ", ";
6662
6663 return ep_args;
6664 }
6665
6666 void CompilerMSL::fix_up_shader_inputs_outputs()
6667 {
6668 // Look for sampled images and buffers. Add hooks to set up the swizzle constants or array lengths.
6669 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
6670 auto &type = get_variable_data_type(var);
6671 uint32_t var_id = var.self;
6672 bool ssbo = has_decoration(type.self, DecorationBufferBlock);
6673
6674 if (var.storage == StorageClassUniformConstant && !is_hidden_variable(var))
6675 {
6676 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
6677 {
6678 auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
6679 entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
6680 bool is_array_type = !type.array.empty();
6681
6682 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
6683 if (descriptor_set_is_argument_buffer(desc_set))
6684 {
6685 statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
6686 is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
6687 ".spvSwizzleConstants", "[",
6688 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
6689 }
6690 else
6691 {
6692 // If we have an array of images, we need to be able to index into it, so take a pointer instead.
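// The hook emits something like "constant uint& uTexSwzl = spvSwizzleConstants[2];"
// for a single texture (names are illustrative).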
6693 statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
6694 is_array_type ? " = &" : " = ", to_name(swizzle_buffer_id), "[",
6695 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
6696 }
6697 });
6698 }
6699 }
6700 else if ((var.storage == StorageClassStorageBuffer || (var.storage == StorageClassUniform && ssbo)) &&
6701 !is_hidden_variable(var))
6702 {
6703 if (buffers_requiring_array_length.count(var.self))
6704 {
6705 auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
6706 entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
6707 bool is_array_type = !type.array.empty();
6708
6709 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
6710 if (descriptor_set_is_argument_buffer(desc_set))
6711 {
6712 statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
6713 is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
6714 ".spvBufferSizeConstants", "[",
6715 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
6716 }
6717 else
6718 {
6719 // If we have an array of buffers, we need to be able to index into it, so take a pointer instead.
6720 statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
6721 is_array_type ? " = &" : " = ", to_name(buffer_size_buffer_id), "[",
6722 convert_to_string(get_metal_resource_index(var, type.basetype)), "];");
6723 }
6724 });
6725 }
6726 }
6727 });
6728
6729 // Builtin variables
6730 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
6731 uint32_t var_id = var.self;
6732 BuiltIn bi_type = ir.meta[var_id].decoration.builtin_type;
6733
6734 if (var.storage == StorageClassInput && is_builtin_variable(var))
6735 {
6736 auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
6737 switch (bi_type)
6738 {
6739 case BuiltInSamplePosition:
6740 entry_func.fixup_hooks_in.push_back([=]() {
6741 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = get_sample_position(",
6742 to_expression(builtin_sample_id_id), ");");
6743 });
6744 break;
6745 case BuiltInHelperInvocation:
6746 if (msl_options.is_ios())
6747 SPIRV_CROSS_THROW("simd_is_helper_thread() is only supported on macOS.");
6748 else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
6749 SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
6750
6751 entry_func.fixup_hooks_in.push_back([=]() {
6752 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_is_helper_thread();");
6753 });
6754 break;
6755 case BuiltInPatchVertices:
6756 if (get_execution_model() == ExecutionModelTessellationEvaluation)
6757 entry_func.fixup_hooks_in.push_back([=]() {
6758 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
6759 to_expression(patch_stage_in_var_id), ".gl_in.size();");
6760 });
6761 else
6762 entry_func.fixup_hooks_in.push_back([=]() {
6763 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = spvIndirectParams[0];");
6764 });
6765 break;
6766 case BuiltInTessCoord:
6767 // Emit a fixup to account for the shifted domain. Don't do this for triangles;
6768 // MoltenVK will just reverse the winding order instead.
6769 if (msl_options.tess_domain_origin_lower_left && !get_entry_point().flags.get(ExecutionModeTriangles))
6770 {
6771 string tc = to_expression(var_id);
6772 entry_func.fixup_hooks_in.push_back([=]() { statement(tc, ".y = 1.0 - ", tc, ".y;"); });
6773 }
6774 break;
6775 case BuiltInSubgroupLocalInvocationId:
6776 // This is natively supported in compute shaders.
6777 if (get_execution_model() == ExecutionModelGLCompute)
6778 break;
6779
6780 // This is natively supported in fragment shaders in MSL 2.2.
6781 if (get_execution_model() == ExecutionModelFragment && msl_options.supports_msl_version(2, 2))
6782 break;
6783
6784 if (msl_options.is_ios())
6785 SPIRV_CROSS_THROW(
6786 "SubgroupLocalInvocationId cannot be used outside of compute shaders before MSL 2.2 on iOS.");
6787
6788 if (!msl_options.supports_msl_version(2, 1))
6789 SPIRV_CROSS_THROW(
6790 "SubgroupLocalInvocationId cannot be used outside of compute shaders before MSL 2.1.");
6791
6792 // Shaders other than compute shaders don't support the SIMD-group
6793 // builtins directly, but we can emulate them using the SIMD-group
6794 // functions. This might break if some invocations of the subgroup
6795 // terminated before reaching the entry point.
6796 entry_func.fixup_hooks_in.push_back([=]() {
6797 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
6798 " = simd_prefix_exclusive_sum(1);");
6799 });
6800 break;
6801 case BuiltInSubgroupSize:
6802 // This is natively supported in compute shaders.
6803 if (get_execution_model() == ExecutionModelGLCompute)
6804 break;
6805
6806 // This is natively supported in fragment shaders in MSL 2.2.
6807 if (get_execution_model() == ExecutionModelFragment && msl_options.supports_msl_version(2, 2))
6808 break;
6809
6810 if (msl_options.is_ios())
6811 SPIRV_CROSS_THROW("SubgroupSize cannot be used outside of compute shaders on iOS.");
6812
6813 if (!msl_options.supports_msl_version(2, 1))
6814 SPIRV_CROSS_THROW("SubgroupSize cannot be used outside of compute shaders before Metal 2.1.");
6815
6816 entry_func.fixup_hooks_in.push_back(
6817 [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_sum(1);"); });
6818 break;
6819 case BuiltInSubgroupEqMask:
6820 if (msl_options.is_ios())
6821 SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
6822 if (!msl_options.supports_msl_version(2, 1))
6823 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
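// gl_SubgroupEqMask is a 128-bit mask carried in a uint4 with exactly one bit set at the
// invocation's index; indices 32..63 set a bit in the second component, the rest stay zero.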
6824 entry_func.fixup_hooks_in.push_back([=]() {
6825 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
6826 to_expression(builtin_subgroup_invocation_id_id), " >= 32 ? uint4(0, (1 << (",
6827 to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
6828 to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
6829 });
6830 break;
6831 case BuiltInSubgroupGeMask:
6832 if (msl_options.is_ios())
6833 SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
6834 if (!msl_options.supports_msl_version(2, 1))
6835 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
6836 entry_func.fixup_hooks_in.push_back([=]() {
6837 // Case where index < 32, size < 32:
6838 // mask0 = bfe(0xFFFFFFFF, index, size - index);
6839 // mask1 = bfe(0xFFFFFFFF, 0, 0); // Gives 0
6840 // Case where index < 32 but size >= 32:
6841 // mask0 = bfe(0xFFFFFFFF, index, 32 - index);
6842 // mask1 = bfe(0xFFFFFFFF, 0, size - 32);
6843 // Case where index >= 32:
6844 // mask0 = bfe(0xFFFFFFFF, 32, 0); // Gives 0
6845 // mask1 = bfe(0xFFFFFFFF, index - 32, size - index);
6846 // This is expressed without branches to avoid divergent
6847 // control flow--hence the complicated min/max expressions.
6848 // This is further complicated by the fact that if you attempt
6849 // to bfe out-of-bounds on Metal, undefined behavior is the
6850 // result.
6851 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
6852 " = uint4(extract_bits(0xFFFFFFFF, min(",
6853 to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
6854 to_expression(builtin_subgroup_size_id), ", 32) - (int)",
6855 to_expression(builtin_subgroup_invocation_id_id),
6856 ", 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
6857 to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
6858 to_expression(builtin_subgroup_size_id), " - (int)max(",
6859 to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
6860 });
6861 break;
6862 case BuiltInSubgroupGtMask:
6863 if (msl_options.is_ios())
6864 SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
6865 if (!msl_options.supports_msl_version(2, 1))
6866 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
6867 entry_func.fixup_hooks_in.push_back([=]() {
6868 // The same logic applies here, except now the index is one
6869 // more than the subgroup invocation ID.
6870 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
6871 " = uint4(extract_bits(0xFFFFFFFF, min(",
6872 to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
6873 to_expression(builtin_subgroup_size_id), ", 32) - (int)",
6874 to_expression(builtin_subgroup_invocation_id_id),
6875 " - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
6876 to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
6877 to_expression(builtin_subgroup_size_id), " - (int)max(",
6878 to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");
6879 });
6880 break;
6881 case BuiltInSubgroupLeMask:
6882 if (msl_options.is_ios())
6883 SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
6884 if (!msl_options.supports_msl_version(2, 1))
6885 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
6886 entry_func.fixup_hooks_in.push_back([=]() {
6887 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
6888 " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
6889 to_expression(builtin_subgroup_invocation_id_id),
6890 " + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
6891 to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0)), uint2(0));");
6892 });
6893 break;
6894 case BuiltInSubgroupLtMask:
6895 if (msl_options.is_ios())
6896 SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
6897 if (!msl_options.supports_msl_version(2, 1))
6898 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
6899 entry_func.fixup_hooks_in.push_back([=]() {
6900 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
6901 " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
6902 to_expression(builtin_subgroup_invocation_id_id),
6903 ", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
6904 to_expression(builtin_subgroup_invocation_id_id), " - 32, 0)), uint2(0));");
6905 });
6906 break;
6907 case BuiltInViewIndex:
6908 if (!msl_options.multiview)
6909 {
6910 // According to the Vulkan spec, when not running under a multiview
6911 // render pass, ViewIndex is 0.
6912 entry_func.fixup_hooks_in.push_back([=]() {
6913 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;");
6914 });
6915 }
6916 else if (get_execution_model() == ExecutionModelFragment)
6917 {
6918 // Because we adjusted the view index in the vertex shader, we have to
6919 // adjust it back here.
6920 entry_func.fixup_hooks_in.push_back([=]() {
6921 statement(to_expression(var_id), " += ", to_expression(view_mask_buffer_id), "[0];");
6922 });
6923 }
6924 else if (get_execution_model() == ExecutionModelVertex)
6925 {
6926 // Metal provides no special support for multiview, so we smuggle
6927 // the view index in the instance index.
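// Conceptually: ViewIndex = spvViewMask[0] + InstanceIndex % spvViewMask[1], then
// InstanceIndex /= spvViewMask[1] (buffer name is illustrative).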
6928 entry_func.fixup_hooks_in.push_back([=]() {
6929 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
6930 to_expression(view_mask_buffer_id), "[0] + ", to_expression(builtin_instance_idx_id),
6931 " % ", to_expression(view_mask_buffer_id), "[1];");
6932 statement(to_expression(builtin_instance_idx_id), " /= ", to_expression(view_mask_buffer_id),
6933 "[1];");
6934 });
6935 // In addition to setting the variable itself, we also need to
6936 // set the render_target_array_index with it on output. We have to
6937 // offset this by the base view index, because Metal isn't in on
6938 // our little game here.
6939 entry_func.fixup_hooks_out.push_back([=]() {
6940 statement(to_expression(builtin_layer_id), " = ", to_expression(var_id), " - ",
6941 to_expression(view_mask_buffer_id), "[0];");
6942 });
6943 }
6944 break;
6945 default:
6946 break;
6947 }
6948 }
6949 });
6950 }
6951
6952 // Returns the Metal index of the resource of the specified type as used by the specified variable.
6953 uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype)
6954 {
6955 auto &execution = get_entry_point();
6956 auto &var_dec = ir.meta[var.self].decoration;
6957 auto &var_type = get<SPIRType>(var.basetype);
6958 uint32_t var_desc_set = (var.storage == StorageClassPushConstant) ? kPushConstDescSet : var_dec.set;
6959 uint32_t var_binding = (var.storage == StorageClassPushConstant) ? kPushConstBinding : var_dec.binding;
6960
6961 // If a matching binding has been specified, find and use it.
6962 auto itr = resource_bindings.find({ execution.model, var_desc_set, var_binding });
6963
6964 auto resource_decoration = var_type.basetype == SPIRType::SampledImage && basetype == SPIRType::Sampler ?
6965 SPIRVCrossDecorationResourceIndexSecondary :
6966 SPIRVCrossDecorationResourceIndexPrimary;
6967
6968 if (itr != end(resource_bindings))
6969 {
6970 auto &remap = itr->second;
6971 remap.second = true;
6972 switch (basetype)
6973 {
6974 case SPIRType::Image:
6975 set_extended_decoration(var.self, resource_decoration, remap.first.msl_texture);
6976 return remap.first.msl_texture;
6977 case SPIRType::Sampler:
6978 set_extended_decoration(var.self, resource_decoration, remap.first.msl_sampler);
6979 return remap.first.msl_sampler;
6980 default:
6981 set_extended_decoration(var.self, resource_decoration, remap.first.msl_buffer);
6982 return remap.first.msl_buffer;
6983 }
6984 }
6985
6986 // If we have already allocated an index, keep using it.
6987 if (has_extended_decoration(var.self, resource_decoration))
6988 return get_extended_decoration(var.self, resource_decoration);
6989
6990 // If we did not explicitly remap, allocate bindings on demand.
6991 // We cannot reliably use Binding decorations since SPIR-V and MSL's binding models are very different.
6992
6993 uint32_t binding_stride = 1;
6994 auto &type = get<SPIRType>(var.basetype);
6995 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
6996 binding_stride *= type.array_size_literal[i] ? type.array[i] : get<SPIRConstant>(type.array[i]).scalar();
6997
6998 assert(binding_stride != 0);
6999
7000 // If a binding has not been specified, revert to incrementing resource indices.
7001 uint32_t resource_index;
7002
7003 bool allocate_argument_buffer_ids = false;
7004 uint32_t desc_set = 0;
7005
7006 if (var.storage != StorageClassPushConstant)
7007 {
7008 desc_set = get_decoration(var.self, DecorationDescriptorSet);
7009 allocate_argument_buffer_ids = descriptor_set_is_argument_buffer(desc_set);
7010 }
7011
7012 if (allocate_argument_buffer_ids)
7013 {
7014 // Allocate from a flat ID binding space.
7015 resource_index = next_metal_resource_ids[desc_set];
7016 next_metal_resource_ids[desc_set] += binding_stride;
7017 }
7018 else
7019 {
7020 // Allocate from plain bindings which are allocated per resource type.
7021 switch (basetype)
7022 {
7023 case SPIRType::Image:
7024 resource_index = next_metal_resource_index_texture;
7025 next_metal_resource_index_texture += binding_stride;
7026 break;
7027 case SPIRType::Sampler:
7028 resource_index = next_metal_resource_index_sampler;
7029 next_metal_resource_index_sampler += binding_stride;
7030 break;
7031 default:
7032 resource_index = next_metal_resource_index_buffer;
7033 next_metal_resource_index_buffer += binding_stride;
7034 break;
7035 }
7036 }
7037
7038 set_extended_decoration(var.self, resource_decoration, resource_index);
7039 return resource_index;
7040 }
7041
7042 string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
7043 {
7044 auto &var = get<SPIRVariable>(arg.id);
7045 auto &type = get_variable_data_type(var);
7046 auto &var_type = get<SPIRType>(arg.type);
7047 StorageClass storage = var_type.storage;
7048 bool is_pointer = var_type.pointer;
7049
7050 // If we need to modify the name of the variable, make sure we use the original variable.
7051 // Our alias is just a shadow variable.
7052 uint32_t name_id = var.self;
7053 if (arg.alias_global_variable && var.basevariable)
7054 name_id = var.basevariable;
7055
7056 bool constref = !arg.alias_global_variable && is_pointer && arg.write_count == 0;
7057
7058 bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage ||
7059 type.basetype == SPIRType::Sampler;
7060
7061 // Arrays of images/samplers in MSL are always const.
7062 if (!type.array.empty() && type_is_image)
7063 constref = true;
7064
7065 string decl;
7066 if (constref)
7067 decl += "const ";
7068
7069 bool builtin = is_builtin_variable(var);
7070 if (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id)
7071 decl += type_to_glsl(type, arg.id);
7072 else if (builtin)
7073 decl += builtin_type_decl(static_cast<BuiltIn>(get_decoration(arg.id, DecorationBuiltIn)), arg.id);
7074 else if ((storage == StorageClassUniform || storage == StorageClassStorageBuffer) && is_array(type))
7075 decl += join(type_to_glsl(type, arg.id), "*");
7076 else
7077 decl += type_to_glsl(type, arg.id);
7078
7079 bool opaque_handle = storage == StorageClassUniformConstant;
7080
7081 string address_space = get_argument_address_space(var);
7082
7083 if (!builtin && !opaque_handle && !is_pointer &&
7084 (storage == StorageClassFunction || storage == StorageClassGeneric))
7085 {
7086 // If the argument is a pure value and not an opaque type, we will pass by value.
7087 if (is_array(type))
7088 {
7089 // We are receiving an array by value. This is problematic.
7090 // We cannot be sure of the target address space since we are supposed to receive a copy,
7091 // but this is not possible with MSL without some extra work.
7092 // We will have to assume we're getting a reference in thread address space.
7093 // If we happen to get a reference in constant address space, the caller must emit a copy and pass that.
7094 // Thread const therefore becomes the only logical choice, since we cannot "create" a constant array from
7095 // non-constant arrays, but we can create thread const from constant.
7096 decl = string("thread const ") + decl;
7097 decl += " (&";
7098 decl += to_expression(name_id);
7099 decl += ")";
7100 decl += type_to_array_glsl(type);
7101 }
7102 else
7103 {
7104 if (!address_space.empty())
7105 decl = join(address_space, " ", decl);
7106 decl += " ";
7107 decl += to_expression(name_id);
7108 }
7109 }
7110 else if (is_array(type) && !type_is_image)
7111 {
7112 // Arrays of images and samplers are special cased.
7113 if (!address_space.empty())
7114 decl = join(address_space, " ", decl);
7115
7116 if (msl_options.argument_buffers)
7117 {
7118 uint32_t desc_set = get_decoration(name_id, DecorationDescriptorSet);
7119 if ((storage == StorageClassUniform || storage == StorageClassStorageBuffer) &&
7120 descriptor_set_is_argument_buffer(desc_set))
7121 {
7122 // An awkward case where we need to emit *more* address space declarations (yay!).
7123 // An example is where we pass down an array of buffer pointers to leaf functions.
7124 // It's a constant array containing pointers to constants.
7125 // The pointer array is always constant however. E.g.
7126 // device SSBO * constant (&array)[N].
7127 // const device SSBO * constant (&array)[N].
7128 // constant SSBO * constant (&array)[N].
7129 // However, this only matters for argument buffers, since for MSL 1.0 style codegen,
7130 // we emit the buffer array on stack instead, and that seems to work just fine apparently.
7131 decl += " constant";
7132 }
7133 }
7134
7135 decl += " (&";
7136 decl += to_expression(name_id);
7137 decl += ")";
7138 decl += type_to_array_glsl(type);
7139 }
7140 else if (!opaque_handle)
7141 {
7142 // If this is going to be a reference to a variable pointer, the address space
7143 // for the reference has to go before the '&', but after the '*'.
7144 if (!address_space.empty())
7145 {
7146 if (decl.back() == '*')
7147 decl += join(" ", address_space, " ");
7148 else
7149 decl = join(address_space, " ", decl);
7150 }
7151 decl += "&";
7152 decl += " ";
7153 decl += to_expression(name_id);
7154 }
7155 else
7156 {
7157 if (!address_space.empty())
7158 decl = join(address_space, " ", decl);
7159 decl += " ";
7160 decl += to_expression(name_id);
7161 }
7162
7163 return decl;
7164 }
7165
7166 // If we're currently in the entry point function, and the object
7167 // has a qualified name, use it, otherwise use the standard name.
7168 string CompilerMSL::to_name(uint32_t id, bool allow_alias) const
7169 {
7170 if (current_function && (current_function->self == ir.default_entry_point))
7171 {
7172 auto *m = ir.find_meta(id);
7173 if (m && !m->decoration.qualified_alias.empty())
7174 return m->decoration.qualified_alias;
7175 }
7176 return Compiler::to_name(id, allow_alias);
7177 }
7178
7179 // Returns a name that combines the name of the struct with the name of the member, except for Builtins
7180 string CompilerMSL::to_qualified_member_name(const SPIRType &type, uint32_t index)
7181 {
7182 // Don't qualify Builtin names because they are unique and are treated as such when building expressions
7183 BuiltIn builtin = BuiltInMax;
7184 if (is_member_builtin(type, index, &builtin))
7185 return builtin_to_glsl(builtin, type.storage);
7186
7187 // Strip any underscore prefix from member name
7188 string mbr_name = to_member_name(type, index);
7189 size_t startPos = mbr_name.find_first_not_of("_");
7190 mbr_name = (startPos != string::npos) ? mbr_name.substr(startPos) : "";
7191 return join(to_name(type.self), "_", mbr_name);
7192 }
7193
7194 // Ensures that the specified name is permanently usable by prepending a prefix
7195 // if the first chars are _ and a digit, which indicate a transient name.
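// For example, a transient name like "_9" becomes "m_9" when pfx is "m" (prefix value is illustrative).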
7196 string CompilerMSL::ensure_valid_name(string name, string pfx)
7197 {
7198 return (name.size() >= 2 && name[0] == '_' && isdigit(name[1])) ? (pfx + name) : name;
7199 }
7200
7201 // Replace all names that match MSL keywords or Metal Standard Library functions.
7202 void CompilerMSL::replace_illegal_names()
7203 {
7204 // FIXME: MSL and GLSL are doing two different things here.
7205 // Agree on convention and remove this override.
7206 static const unordered_set<string> keywords = {
7207 "kernel",
7208 "vertex",
7209 "fragment",
7210 "compute",
7211 "bias",
7212 "assert",
7213 "VARIABLE_TRACEPOINT",
7214 "STATIC_DATA_TRACEPOINT",
7215 "STATIC_DATA_TRACEPOINT_V",
7216 "METAL_ALIGN",
7217 "METAL_ASM",
7218 "METAL_CONST",
7219 "METAL_DEPRECATED",
7220 "METAL_ENABLE_IF",
7221 "METAL_FUNC",
7222 "METAL_INTERNAL",
7223 "METAL_NON_NULL_RETURN",
7224 "METAL_NORETURN",
7225 "METAL_NOTHROW",
7226 "METAL_PURE",
7227 "METAL_UNAVAILABLE",
7228 "METAL_IMPLICIT",
7229 "METAL_EXPLICIT",
7230 "METAL_CONST_ARG",
7231 "METAL_ARG_UNIFORM",
7232 "METAL_ZERO_ARG",
7233 "METAL_VALID_LOD_ARG",
7234 "METAL_VALID_LEVEL_ARG",
7235 "METAL_VALID_STORE_ORDER",
7236 "METAL_VALID_LOAD_ORDER",
7237 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
7238 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
7239 "METAL_VALID_RENDER_TARGET",
7240 "is_function_constant_defined",
7241 "CHAR_BIT",
7242 "SCHAR_MAX",
7243 "SCHAR_MIN",
7244 "UCHAR_MAX",
7245 "CHAR_MAX",
7246 "CHAR_MIN",
7247 "USHRT_MAX",
7248 "SHRT_MAX",
7249 "SHRT_MIN",
7250 "UINT_MAX",
7251 "INT_MAX",
7252 "INT_MIN",
7253 "FLT_DIG",
7254 "FLT_MANT_DIG",
7255 "FLT_MAX_10_EXP",
7256 "FLT_MAX_EXP",
7257 "FLT_MIN_10_EXP",
7258 "FLT_MIN_EXP",
7259 "FLT_RADIX",
7260 "FLT_MAX",
7261 "FLT_MIN",
7262 "FLT_EPSILON",
7263 "FP_ILOGB0",
7264 "FP_ILOGBNAN",
7265 "MAXFLOAT",
7266 "HUGE_VALF",
7267 "INFINITY",
7268 "NAN",
7269 "M_E_F",
7270 "M_LOG2E_F",
7271 "M_LOG10E_F",
7272 "M_LN2_F",
7273 "M_LN10_F",
7274 "M_PI_F",
7275 "M_PI_2_F",
7276 "M_PI_4_F",
7277 "M_1_PI_F",
7278 "M_2_PI_F",
7279 "M_2_SQRTPI_F",
7280 "M_SQRT2_F",
7281 "M_SQRT1_2_F",
7282 "HALF_DIG",
7283 "HALF_MANT_DIG",
7284 "HALF_MAX_10_EXP",
7285 "HALF_MAX_EXP",
7286 "HALF_MIN_10_EXP",
7287 "HALF_MIN_EXP",
7288 "HALF_RADIX",
7289 "HALF_MAX",
7290 "HALF_MIN",
7291 "HALF_EPSILON",
7292 "MAXHALF",
7293 "HUGE_VALH",
7294 "M_E_H",
7295 "M_LOG2E_H",
7296 "M_LOG10E_H",
7297 "M_LN2_H",
7298 "M_LN10_H",
7299 "M_PI_H",
7300 "M_PI_2_H",
7301 "M_PI_4_H",
7302 "M_1_PI_H",
7303 "M_2_PI_H",
7304 "M_2_SQRTPI_H",
7305 "M_SQRT2_H",
7306 "M_SQRT1_2_H",
7307 "DBL_DIG",
7308 "DBL_MANT_DIG",
7309 "DBL_MAX_10_EXP",
7310 "DBL_MAX_EXP",
7311 "DBL_MIN_10_EXP",
7312 "DBL_MIN_EXP",
7313 "DBL_RADIX",
7314 "DBL_MAX",
7315 "DBL_MIN",
7316 "DBL_EPSILON",
7317 "HUGE_VAL",
7318 "M_E",
7319 "M_LOG2E",
7320 "M_LOG10E",
7321 "M_LN2",
7322 "M_LN10",
7323 "M_PI",
7324 "M_PI_2",
7325 "M_PI_4",
7326 "M_1_PI",
7327 "M_2_PI",
7328 "M_2_SQRTPI",
7329 "M_SQRT2",
7330 "M_SQRT1_2",
7331 "quad_broadcast",
7332 };
7333
7334 static const unordered_set<string> illegal_func_names = {
7335 "main",
7336 "saturate",
7337 "assert",
7338 "VARIABLE_TRACEPOINT",
7339 "STATIC_DATA_TRACEPOINT",
7340 "STATIC_DATA_TRACEPOINT_V",
7341 "METAL_ALIGN",
7342 "METAL_ASM",
7343 "METAL_CONST",
7344 "METAL_DEPRECATED",
7345 "METAL_ENABLE_IF",
7346 "METAL_FUNC",
7347 "METAL_INTERNAL",
7348 "METAL_NON_NULL_RETURN",
7349 "METAL_NORETURN",
7350 "METAL_NOTHROW",
7351 "METAL_PURE",
7352 "METAL_UNAVAILABLE",
7353 "METAL_IMPLICIT",
7354 "METAL_EXPLICIT",
7355 "METAL_CONST_ARG",
7356 "METAL_ARG_UNIFORM",
7357 "METAL_ZERO_ARG",
7358 "METAL_VALID_LOD_ARG",
7359 "METAL_VALID_LEVEL_ARG",
7360 "METAL_VALID_STORE_ORDER",
7361 "METAL_VALID_LOAD_ORDER",
7362 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
7363 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
7364 "METAL_VALID_RENDER_TARGET",
7365 "is_function_constant_defined",
7366 "CHAR_BIT",
7367 "SCHAR_MAX",
7368 "SCHAR_MIN",
7369 "UCHAR_MAX",
7370 "CHAR_MAX",
7371 "CHAR_MIN",
7372 "USHRT_MAX",
7373 "SHRT_MAX",
7374 "SHRT_MIN",
7375 "UINT_MAX",
7376 "INT_MAX",
7377 "INT_MIN",
7378 "FLT_DIG",
7379 "FLT_MANT_DIG",
7380 "FLT_MAX_10_EXP",
7381 "FLT_MAX_EXP",
7382 "FLT_MIN_10_EXP",
7383 "FLT_MIN_EXP",
7384 "FLT_RADIX",
7385 "FLT_MAX",
7386 "FLT_MIN",
7387 "FLT_EPSILON",
7388 "FP_ILOGB0",
7389 "FP_ILOGBNAN",
7390 "MAXFLOAT",
7391 "HUGE_VALF",
7392 "INFINITY",
7393 "NAN",
7394 "M_E_F",
7395 "M_LOG2E_F",
7396 "M_LOG10E_F",
7397 "M_LN2_F",
7398 "M_LN10_F",
7399 "M_PI_F",
7400 "M_PI_2_F",
7401 "M_PI_4_F",
7402 "M_1_PI_F",
7403 "M_2_PI_F",
7404 "M_2_SQRTPI_F",
7405 "M_SQRT2_F",
7406 "M_SQRT1_2_F",
7407 "HALF_DIG",
7408 "HALF_MANT_DIG",
7409 "HALF_MAX_10_EXP",
7410 "HALF_MAX_EXP",
7411 "HALF_MIN_10_EXP",
7412 "HALF_MIN_EXP",
7413 "HALF_RADIX",
7414 "HALF_MAX",
7415 "HALF_MIN",
7416 "HALF_EPSILON",
7417 "MAXHALF",
7418 "HUGE_VALH",
7419 "M_E_H",
7420 "M_LOG2E_H",
7421 "M_LOG10E_H",
7422 "M_LN2_H",
7423 "M_LN10_H",
7424 "M_PI_H",
7425 "M_PI_2_H",
7426 "M_PI_4_H",
7427 "M_1_PI_H",
7428 "M_2_PI_H",
7429 "M_2_SQRTPI_H",
7430 "M_SQRT2_H",
7431 "M_SQRT1_2_H",
7432 "DBL_DIG",
7433 "DBL_MANT_DIG",
7434 "DBL_MAX_10_EXP",
7435 "DBL_MAX_EXP",
7436 "DBL_MIN_10_EXP",
7437 "DBL_MIN_EXP",
7438 "DBL_RADIX",
7439 "DBL_MAX",
7440 "DBL_MIN",
7441 "DBL_EPSILON",
7442 "HUGE_VAL",
7443 "M_E",
7444 "M_LOG2E",
7445 "M_LOG10E",
7446 "M_LN2",
7447 "M_LN10",
7448 "M_PI",
7449 "M_PI_2",
7450 "M_PI_4",
7451 "M_1_PI",
7452 "M_2_PI",
7453 "M_2_SQRTPI",
7454 "M_SQRT2",
7455 "M_SQRT1_2",
7456 };
7457
7458 ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &) {
7459 auto &dec = ir.meta[self].decoration;
7460 if (keywords.find(dec.alias) != end(keywords))
7461 dec.alias += "0";
7462 });
7463
7464 ir.for_each_typed_id<SPIRFunction>([&](uint32_t self, SPIRFunction &) {
7465 auto &dec = ir.meta[self].decoration;
7466 if (illegal_func_names.find(dec.alias) != end(illegal_func_names))
7467 dec.alias += "0";
7468 });
7469
7470 ir.for_each_typed_id<SPIRType>([&](uint32_t self, SPIRType &) {
7471 for (auto &mbr_dec : ir.meta[self].members)
7472 if (keywords.find(mbr_dec.alias) != end(keywords))
7473 mbr_dec.alias += "0";
7474 });
7475
7476 for (auto &entry : ir.entry_points)
7477 {
7478 // Change both the entry point name and the alias, to keep them synced.
7479 string &ep_name = entry.second.name;
7480 if (illegal_func_names.find(ep_name) != end(illegal_func_names))
7481 ep_name += "0";
7482
7483 // Always write this because entry point might have been renamed earlier.
7484 ir.meta[entry.first].decoration.alias = ep_name;
7485 }
7486
7487 CompilerGLSL::replace_illegal_names();
7488 }
7489
7490 string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain)
7491 {
7492 auto *var = maybe_get<SPIRVariable>(base);
7493 // If this is a buffer array, we have to dereference the buffer pointers.
7494 // Otherwise, if this is a pointer expression, dereference it.
7495
7496 bool declared_as_pointer = false;
7497
7498 if (var)
7499 {
7500 bool is_buffer_variable = var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer;
7501 declared_as_pointer = is_buffer_variable && is_array(get<SPIRType>(var->basetype));
7502 }
7503
7504 if (declared_as_pointer || (!ptr_chain && should_dereference(base)))
7505 return join("->", to_member_name(type, index));
7506 else
7507 return join(".", to_member_name(type, index));
7508 }
7509
7510 string CompilerMSL::to_qualifiers_glsl(uint32_t id)
7511 {
7512 string quals;
7513
7514 auto &type = expression_type(id);
7515 if (type.storage == StorageClassWorkgroup)
7516 quals += "threadgroup ";
7517
7518 return quals;
7519 }
7520
7521 // The optional id parameter indicates the object whose type we are trying
7522 // to find the description for. Most type descriptions do not depend on a
7523 // specific object's use of that type.
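// For example, a 4x4 float matrix maps to "float4x4", and a pointer to a uniform
// block struct maps to "constant Foo*" (struct name is illustrative).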
7524 string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
7525 {
7526 string type_name;
7527
7528 // Pointer?
7529 if (type.pointer)
7530 {
7531 type_name = join(get_type_address_space(type, id), " ", type_to_glsl(get<SPIRType>(type.parent_type), id));
7532 switch (type.basetype)
7533 {
7534 case SPIRType::Image:
7535 case SPIRType::SampledImage:
7536 case SPIRType::Sampler:
7537 // These are handles.
7538 break;
7539 default:
7540 // Anything else can be a raw pointer.
7541 type_name += "*";
7542 break;
7543 }
7544 return type_name;
7545 }
7546
7547 switch (type.basetype)
7548 {
7549 case SPIRType::Struct:
7550 // Need OpName lookup here to get a "sensible" name for a struct.
7551 return to_name(type.self);
7552
7553 case SPIRType::Image:
7554 case SPIRType::SampledImage:
7555 return image_type_glsl(type, id);
7556
7557 case SPIRType::Sampler:
7558 return sampler_type(type);
7559
7560 case SPIRType::Void:
7561 return "void";
7562
7563 case SPIRType::AtomicCounter:
7564 return "atomic_uint";
7565
7566 case SPIRType::ControlPointArray:
7567 return join("patch_control_point<", type_to_glsl(get<SPIRType>(type.parent_type), id), ">");
7568
7569 // Scalars
7570 case SPIRType::Boolean:
7571 type_name = "bool";
7572 break;
7573 case SPIRType::Char:
7574 case SPIRType::SByte:
7575 type_name = "char";
7576 break;
7577 case SPIRType::UByte:
7578 type_name = "uchar";
7579 break;
7580 case SPIRType::Short:
7581 type_name = "short";
7582 break;
7583 case SPIRType::UShort:
7584 type_name = "ushort";
7585 break;
7586 case SPIRType::Int:
7587 type_name = "int";
7588 break;
7589 case SPIRType::UInt:
7590 type_name = "uint";
7591 break;
7592 case SPIRType::Int64:
7593 if (!msl_options.supports_msl_version(2, 2))
7594 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
7595 type_name = "long";
7596 break;
7597 case SPIRType::UInt64:
7598 if (!msl_options.supports_msl_version(2, 2))
7599 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
7600 type_name = "ulong";
7601 break;
7602 case SPIRType::Half:
7603 type_name = "half";
7604 break;
7605 case SPIRType::Float:
7606 type_name = "float";
7607 break;
7608 case SPIRType::Double:
7609 type_name = "double"; // Currently unsupported
7610 break;
7611
7612 default:
7613 return "unknown_type";
7614 }
7615
7616 // Matrix?
7617 if (type.columns > 1)
7618 type_name += to_string(type.columns) + "x";
7619
7620 // Vector or Matrix?
7621 if (type.vecsize > 1)
7622 type_name += to_string(type.vecsize);
7623
7624 return type_name;
7625 }
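// Illustrative examples of the mapping above: a 32-bit float maps to "float", a
// 3-component float vector to "float3", and a matrix with 4 columns of 3-component
// floats to "float4x3" (column count first, then vector size).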
7626
7627 std::string CompilerMSL::sampler_type(const SPIRType &type)
7628 {
7629 if (!type.array.empty())
7630 {
7631 if (!msl_options.supports_msl_version(2))
7632 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of samplers.");
7633
7634 if (type.array.size() > 1)
7635 SPIRV_CROSS_THROW("Arrays of arrays of samplers are not supported in MSL.");
7636
7637 		// Arrays of samplers in MSL must be declared with a special array<T, N> syntax a la C++11 std::array.
7638 uint32_t array_size = to_array_size_literal(type);
7639 if (array_size == 0)
7640 SPIRV_CROSS_THROW("Unsized array of samplers is not supported in MSL.");
7641
7642 auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
7643 return join("array<", sampler_type(parent), ", ", array_size, ">");
7644 }
7645 else
7646 return "sampler";
7647 }
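// For example, a one-dimensional array of 4 samplers is emitted as "array<sampler, 4>",
// which requires MSL 2.0 as checked above.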
7648
7649 // Returns an MSL string describing the SPIR-V image type
7650 string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id)
7651 {
7652 auto *var = maybe_get<SPIRVariable>(id);
7653 if (var && var->basevariable)
7654 {
7655 // For comparison images, check against the base variable,
7656 // and not the fake ID which might have been generated for this variable.
7657 id = var->basevariable;
7658 }
7659
7660 if (!type.array.empty())
7661 {
7662 uint32_t major = 2, minor = 0;
7663 if (msl_options.is_ios())
7664 {
7665 major = 1;
7666 minor = 2;
7667 }
7668 if (!msl_options.supports_msl_version(major, minor))
7669 {
7670 if (msl_options.is_ios())
7671 SPIRV_CROSS_THROW("MSL 1.2 or greater is required for arrays of textures.");
7672 else
7673 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of textures.");
7674 }
7675
7676 if (type.array.size() > 1)
7677 SPIRV_CROSS_THROW("Arrays of arrays of textures are not supported in MSL.");
7678
7679 		// Arrays of images in MSL must be declared with a special array<T, N> syntax a la C++11 std::array.
7680 uint32_t array_size = to_array_size_literal(type);
7681 if (array_size == 0)
7682 SPIRV_CROSS_THROW("Unsized array of images is not supported in MSL.");
7683
7684 auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
7685 return join("array<", image_type_glsl(parent, id), ", ", array_size, ">");
7686 }
7687
7688 string img_type_name;
7689
7690 // Bypass pointers because we need the real image struct
7691 auto &img_type = get<SPIRType>(type.self).image;
7692 if (image_is_comparison(type, id))
7693 {
7694 switch (img_type.dim)
7695 {
7696 case Dim1D:
7697 img_type_name += "depth1d_unsupported_by_metal";
7698 break;
7699 case Dim2D:
7700 if (img_type.ms && img_type.arrayed)
7701 {
7702 if (!msl_options.supports_msl_version(2, 1))
7703 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
7704 img_type_name += "depth2d_ms_array";
7705 }
7706 else if (img_type.ms)
7707 img_type_name += "depth2d_ms";
7708 else if (img_type.arrayed)
7709 img_type_name += "depth2d_array";
7710 else
7711 img_type_name += "depth2d";
7712 break;
7713 case Dim3D:
7714 img_type_name += "depth3d_unsupported_by_metal";
7715 break;
7716 case DimCube:
7717 img_type_name += (img_type.arrayed ? "depthcube_array" : "depthcube");
7718 break;
7719 default:
7720 img_type_name += "unknown_depth_texture_type";
7721 break;
7722 }
7723 }
7724 else
7725 {
7726 switch (img_type.dim)
7727 {
7728 case Dim1D:
7729 img_type_name += (img_type.arrayed ? "texture1d_array" : "texture1d");
7730 break;
7731 case DimBuffer:
7732 if (img_type.ms || img_type.arrayed)
7733 SPIRV_CROSS_THROW("Cannot use texel buffers with multisampling or array layers.");
7734
7735 if (msl_options.texture_buffer_native)
7736 {
7737 if (!msl_options.supports_msl_version(2, 1))
7738 SPIRV_CROSS_THROW("Native texture_buffer type is only supported in MSL 2.1.");
7739 img_type_name = "texture_buffer";
7740 }
7741 else
7742 img_type_name += "texture2d";
7743 break;
7744 case Dim2D:
7745 case DimSubpassData:
7746 if (img_type.ms && img_type.arrayed)
7747 {
7748 if (!msl_options.supports_msl_version(2, 1))
7749 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
7750 img_type_name += "texture2d_ms_array";
7751 }
7752 else if (img_type.ms)
7753 img_type_name += "texture2d_ms";
7754 else if (img_type.arrayed)
7755 img_type_name += "texture2d_array";
7756 else
7757 img_type_name += "texture2d";
7758 break;
7759 case Dim3D:
7760 img_type_name += "texture3d";
7761 break;
7762 case DimCube:
7763 img_type_name += (img_type.arrayed ? "texturecube_array" : "texturecube");
7764 break;
7765 default:
7766 img_type_name += "unknown_texture_type";
7767 break;
7768 }
7769 }
7770
7771 // Append the pixel type
7772 img_type_name += "<";
7773 img_type_name += type_to_glsl(get<SPIRType>(img_type.type));
7774
7775 // For unsampled images, append the sample/read/write access qualifier.
7776 	// For kernel images, the access qualifier may be supplied directly by SPIR-V.
7777 // Otherwise it may be set based on whether the image is read from or written to within the shader.
7778 if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData)
7779 {
7780 switch (img_type.access)
7781 {
7782 case AccessQualifierReadOnly:
7783 img_type_name += ", access::read";
7784 break;
7785
7786 case AccessQualifierWriteOnly:
7787 img_type_name += ", access::write";
7788 break;
7789
7790 case AccessQualifierReadWrite:
7791 img_type_name += ", access::read_write";
7792 break;
7793
7794 default:
7795 {
7796 auto *p_var = maybe_get_backing_variable(id);
7797 if (p_var && p_var->basevariable)
7798 p_var = maybe_get<SPIRVariable>(p_var->basevariable);
7799 if (p_var && !has_decoration(p_var->self, DecorationNonWritable))
7800 {
7801 img_type_name += ", access::";
7802
7803 if (!has_decoration(p_var->self, DecorationNonReadable))
7804 img_type_name += "read_";
7805
7806 img_type_name += "write";
7807 }
7808 break;
7809 }
7810 }
7811 }
7812
7813 img_type_name += ">";
7814
7815 return img_type_name;
7816 }
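// Illustrative examples of the output: a sampled, arrayed 2D depth-comparison image
// becomes "depth2d_array<float>", while a storage image that is both read and written
// becomes e.g. "texture2d<float, access::read_write>".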
7817
7818 void CompilerMSL::emit_subgroup_op(const Instruction &i)
7819 {
7820 const uint32_t *ops = stream(i);
7821 auto op = static_cast<Op>(i.op);
7822
7823 // Metal 2.0 is required. iOS only supports quad ops. macOS only supports
7824 // broadcast and shuffle on 10.13 (2.0), with full support in 10.14 (2.1).
7825 // Note that iOS makes no distinction between a quad-group and a subgroup;
7826 // all subgroups are quad-groups there.
7827 if (!msl_options.supports_msl_version(2))
7828 SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up.");
7829
7830 if (msl_options.is_ios())
7831 {
7832 switch (op)
7833 {
7834 default:
7835 SPIRV_CROSS_THROW("iOS only supports quad-group operations.");
7836 case OpGroupNonUniformBroadcast:
7837 case OpGroupNonUniformShuffle:
7838 case OpGroupNonUniformShuffleXor:
7839 case OpGroupNonUniformShuffleUp:
7840 case OpGroupNonUniformShuffleDown:
7841 case OpGroupNonUniformQuadSwap:
7842 case OpGroupNonUniformQuadBroadcast:
7843 break;
7844 }
7845 }
7846
7847 if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
7848 {
7849 switch (op)
7850 {
7851 default:
7852 SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.0 and up.");
7853 case OpGroupNonUniformBroadcast:
7854 case OpGroupNonUniformShuffle:
7855 case OpGroupNonUniformShuffleXor:
7856 case OpGroupNonUniformShuffleUp:
7857 case OpGroupNonUniformShuffleDown:
7858 break;
7859 }
7860 }
7861
7862 uint32_t result_type = ops[0];
7863 uint32_t id = ops[1];
7864
7865 auto scope = static_cast<Scope>(get<SPIRConstant>(ops[2]).scalar());
7866 if (scope != ScopeSubgroup)
7867 SPIRV_CROSS_THROW("Only subgroup scope is supported.");
7868
7869 switch (op)
7870 {
7871 case OpGroupNonUniformElect:
7872 emit_op(result_type, id, "simd_is_first()", true);
7873 break;
7874
7875 case OpGroupNonUniformBroadcast:
7876 emit_binary_func_op(result_type, id, ops[3], ops[4],
7877 msl_options.is_ios() ? "quad_broadcast" : "simd_broadcast");
7878 break;
7879
7880 case OpGroupNonUniformBroadcastFirst:
7881 emit_unary_func_op(result_type, id, ops[3], "simd_broadcast_first");
7882 break;
7883
7884 case OpGroupNonUniformBallot:
7885 emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallot");
7886 break;
7887
7888 case OpGroupNonUniformInverseBallot:
7889 emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract");
7890 break;
7891
7892 case OpGroupNonUniformBallotBitExtract:
7893 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBallotBitExtract");
7894 break;
7895
7896 case OpGroupNonUniformBallotFindLSB:
7897 emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindLSB");
7898 break;
7899
7900 case OpGroupNonUniformBallotFindMSB:
7901 emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindMSB");
7902 break;
7903
7904 case OpGroupNonUniformBallotBitCount:
7905 {
7906 auto operation = static_cast<GroupOperation>(ops[3]);
7907 if (operation == GroupOperationReduce)
7908 emit_unary_func_op(result_type, id, ops[4], "spvSubgroupBallotBitCount");
7909 else if (operation == GroupOperationInclusiveScan)
7910 emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
7911 "spvSubgroupBallotInclusiveBitCount");
7912 else if (operation == GroupOperationExclusiveScan)
7913 emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
7914 "spvSubgroupBallotExclusiveBitCount");
7915 else
7916 SPIRV_CROSS_THROW("Invalid BitCount operation.");
7917 break;
7918 }
7919
7920 case OpGroupNonUniformShuffle:
7921 emit_binary_func_op(result_type, id, ops[3], ops[4], msl_options.is_ios() ? "quad_shuffle" : "simd_shuffle");
7922 break;
7923
7924 case OpGroupNonUniformShuffleXor:
7925 emit_binary_func_op(result_type, id, ops[3], ops[4],
7926 msl_options.is_ios() ? "quad_shuffle_xor" : "simd_shuffle_xor");
7927 break;
7928
7929 case OpGroupNonUniformShuffleUp:
7930 emit_binary_func_op(result_type, id, ops[3], ops[4],
7931 msl_options.is_ios() ? "quad_shuffle_up" : "simd_shuffle_up");
7932 break;
7933
7934 case OpGroupNonUniformShuffleDown:
7935 emit_binary_func_op(result_type, id, ops[3], ops[4],
7936 msl_options.is_ios() ? "quad_shuffle_down" : "simd_shuffle_down");
7937 break;
7938
7939 case OpGroupNonUniformAll:
7940 emit_unary_func_op(result_type, id, ops[3], "simd_all");
7941 break;
7942
7943 case OpGroupNonUniformAny:
7944 emit_unary_func_op(result_type, id, ops[3], "simd_any");
7945 break;
7946
7947 case OpGroupNonUniformAllEqual:
7948 emit_unary_func_op(result_type, id, ops[3], "spvSubgroupAllEqual");
7949 break;
7950
7951 // clang-format off
7952 #define MSL_GROUP_OP(op, msl_op) \
7953 case OpGroupNonUniform##op: \
7954 { \
7955 auto operation = static_cast<GroupOperation>(ops[3]); \
7956 if (operation == GroupOperationReduce) \
7957 emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
7958 else if (operation == GroupOperationInclusiveScan) \
7959 emit_unary_func_op(result_type, id, ops[4], "simd_prefix_inclusive_" #msl_op); \
7960 else if (operation == GroupOperationExclusiveScan) \
7961 emit_unary_func_op(result_type, id, ops[4], "simd_prefix_exclusive_" #msl_op); \
7962 else if (operation == GroupOperationClusteredReduce) \
7963 { \
7964 /* Only cluster sizes of 4 are supported. */ \
7965 uint32_t cluster_size = get<SPIRConstant>(ops[5]).scalar(); \
7966 if (cluster_size != 4) \
7967 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
7968 emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
7969 } \
7970 else \
7971 SPIRV_CROSS_THROW("Invalid group operation."); \
7972 break; \
7973 }
7974 MSL_GROUP_OP(FAdd, sum)
7975 MSL_GROUP_OP(FMul, product)
7976 MSL_GROUP_OP(IAdd, sum)
7977 MSL_GROUP_OP(IMul, product)
7978 #undef MSL_GROUP_OP
7979 // The others, unfortunately, don't support InclusiveScan or ExclusiveScan.
7980 #define MSL_GROUP_OP(op, msl_op) \
7981 case OpGroupNonUniform##op: \
7982 { \
7983 auto operation = static_cast<GroupOperation>(ops[3]); \
7984 if (operation == GroupOperationReduce) \
7985 emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
7986 else if (operation == GroupOperationInclusiveScan) \
7987 SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
7988 else if (operation == GroupOperationExclusiveScan) \
7989 SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
7990 else if (operation == GroupOperationClusteredReduce) \
7991 { \
7992 /* Only cluster sizes of 4 are supported. */ \
7993 uint32_t cluster_size = get<SPIRConstant>(ops[5]).scalar(); \
7994 if (cluster_size != 4) \
7995 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
7996 emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
7997 } \
7998 else \
7999 SPIRV_CROSS_THROW("Invalid group operation."); \
8000 break; \
8001 }
8002 MSL_GROUP_OP(FMin, min)
8003 MSL_GROUP_OP(FMax, max)
8004 MSL_GROUP_OP(SMin, min)
8005 MSL_GROUP_OP(SMax, max)
8006 MSL_GROUP_OP(UMin, min)
8007 MSL_GROUP_OP(UMax, max)
8008 MSL_GROUP_OP(BitwiseAnd, and)
8009 MSL_GROUP_OP(BitwiseOr, or)
8010 MSL_GROUP_OP(BitwiseXor, xor)
8011 MSL_GROUP_OP(LogicalAnd, and)
8012 MSL_GROUP_OP(LogicalOr, or)
8013 MSL_GROUP_OP(LogicalXor, xor)
8014 // clang-format on
8015
8016 case OpGroupNonUniformQuadSwap:
8017 {
8018 // We can implement this easily based on the following table giving
8019 // the target lane ID from the direction and current lane ID:
8020 // Direction
8021 // | 0 | 1 | 2 |
8022 // ---+---+---+---+
8023 // L 0 | 1 2 3
8024 // a 1 | 0 3 2
8025 // n 2 | 3 0 1
8026 // e 3 | 2 1 0
8027 // Notice that target = source ^ (direction + 1).
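		// For example, a vertical swap (direction 1) from lane 2 gives 2 ^ (1 + 1) = 0,
		// matching the table above, so the swap lowers to quad_shuffle_xor with that mask.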
8028 uint32_t mask = get<SPIRConstant>(ops[4]).scalar() + 1;
8029 uint32_t mask_id = ir.increase_bound_by(1);
8030 set<SPIRConstant>(mask_id, expression_type_id(ops[4]), mask, false);
8031 emit_binary_func_op(result_type, id, ops[3], mask_id, "quad_shuffle_xor");
8032 break;
8033 }
8034
8035 case OpGroupNonUniformQuadBroadcast:
8036 emit_binary_func_op(result_type, id, ops[3], ops[4], "quad_broadcast");
8037 break;
8038
8039 default:
8040 SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
8041 }
8042
8043 register_control_dependent_expression(id);
8044 }
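// Illustrative mapping performed above: OpGroupNonUniformBroadcast lowers to
// simd_broadcast on macOS and quad_broadcast on iOS, while ballot-style ops are
// routed through the emitted spvSubgroupBallot* helper functions.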
8045
8046 string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
8047 {
8048 if (out_type.basetype == in_type.basetype)
8049 return "";
8050
8051 assert(out_type.basetype != SPIRType::Boolean);
8052 assert(in_type.basetype != SPIRType::Boolean);
8053
8054 bool integral_cast = type_is_integral(out_type) && type_is_integral(in_type);
8055 bool same_size_cast = out_type.width == in_type.width;
8056
8057 if (integral_cast && same_size_cast)
8058 {
8059 // Trivial bitcast case, casts between integers.
8060 return type_to_glsl(out_type);
8061 }
8062 else
8063 {
8064 // Fall back to the catch-all bitcast in MSL.
8065 return "as_type<" + type_to_glsl(out_type) + ">";
8066 }
8067 }
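// Illustrative examples: a same-width uint -> int bitcast is emitted as a plain
// "int(...)" functional cast, while a float -> uint bitcast falls back to "as_type<uint>(...)".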
8068
8069 // Returns an MSL string identifying the name of a SPIR-V builtin.
8070 // Output builtins are qualified with the name of the stage out structure.
8071 string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
8072 {
8073 switch (builtin)
8074 {
8075
8076 // Override GLSL compiler strictness
8077 case BuiltInVertexId:
8078 return "gl_VertexID";
8079 case BuiltInInstanceId:
8080 return "gl_InstanceID";
8081 case BuiltInVertexIndex:
8082 return "gl_VertexIndex";
8083 case BuiltInInstanceIndex:
8084 return "gl_InstanceIndex";
8085 case BuiltInBaseVertex:
8086 return "gl_BaseVertex";
8087 case BuiltInBaseInstance:
8088 return "gl_BaseInstance";
8089 case BuiltInDrawIndex:
8090 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
8091
8092 	// When used in the entry function, output builtins are qualified with the output struct name.
8093 	// We test that the storage class is NOT Input, since output builtins might be part of a generic type.
8094 // Also don't do this for tessellation control shaders.
8095 case BuiltInViewportIndex:
8096 if (!msl_options.supports_msl_version(2, 0))
8097 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
8098 /* fallthrough */
8099 case BuiltInPosition:
8100 case BuiltInPointSize:
8101 case BuiltInClipDistance:
8102 case BuiltInCullDistance:
8103 case BuiltInLayer:
8104 case BuiltInFragDepth:
8105 case BuiltInFragStencilRefEXT:
8106 case BuiltInSampleMask:
8107 if (get_execution_model() == ExecutionModelTessellationControl)
8108 break;
8109 if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
8110 return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
8111
8112 break;
8113
8114 case BuiltInBaryCoordNV:
8115 case BuiltInBaryCoordNoPerspNV:
8116 if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
8117 return stage_in_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
8118 break;
8119
8120 case BuiltInTessLevelOuter:
8121 if (get_execution_model() == ExecutionModelTessellationEvaluation)
8122 {
8123 if (storage != StorageClassOutput && !get_entry_point().flags.get(ExecutionModeTriangles) &&
8124 current_function && (current_function->self == ir.default_entry_point))
8125 return join(patch_stage_in_var_name, ".", CompilerGLSL::builtin_to_glsl(builtin, storage));
8126 else
8127 break;
8128 }
8129 if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
8130 return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
8131 "].edgeTessellationFactor");
8132 break;
8133
8134 case BuiltInTessLevelInner:
8135 if (get_execution_model() == ExecutionModelTessellationEvaluation)
8136 {
8137 if (storage != StorageClassOutput && !get_entry_point().flags.get(ExecutionModeTriangles) &&
8138 current_function && (current_function->self == ir.default_entry_point))
8139 return join(patch_stage_in_var_name, ".", CompilerGLSL::builtin_to_glsl(builtin, storage));
8140 else
8141 break;
8142 }
8143 if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
8144 return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
8145 "].insideTessellationFactor");
8146 break;
8147
8148 default:
8149 break;
8150 }
8151
8152 return CompilerGLSL::builtin_to_glsl(builtin, storage);
8153 }
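// For example, when gl_Position is written inside a vertex shader's entry point,
// the reference above is qualified with the stage-out struct name, typically yielding "out.gl_Position".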
8154
8155 // Returns an MSL string attribute qualifier for a SPIR-V builtin
8156 string CompilerMSL::builtin_qualifier(BuiltIn builtin)
8157 {
8158 auto &execution = get_entry_point();
8159
8160 switch (builtin)
8161 {
8162 // Vertex function in
8163 case BuiltInVertexId:
8164 return "vertex_id";
8165 case BuiltInVertexIndex:
8166 return "vertex_id";
8167 case BuiltInBaseVertex:
8168 return "base_vertex";
8169 case BuiltInInstanceId:
8170 return "instance_id";
8171 case BuiltInInstanceIndex:
8172 return "instance_id";
8173 case BuiltInBaseInstance:
8174 return "base_instance";
8175 case BuiltInDrawIndex:
8176 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
8177
8178 // Vertex function out
8179 case BuiltInClipDistance:
8180 return "clip_distance";
8181 case BuiltInPointSize:
8182 return "point_size";
8183 case BuiltInPosition:
8184 if (position_invariant)
8185 {
8186 if (!msl_options.supports_msl_version(2, 1))
8187 SPIRV_CROSS_THROW("Invariant position is only supported on MSL 2.1 and up.");
8188 return "position, invariant";
8189 }
8190 else
8191 return "position";
8192 case BuiltInLayer:
8193 return "render_target_array_index";
8194 case BuiltInViewportIndex:
8195 if (!msl_options.supports_msl_version(2, 0))
8196 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
8197 return "viewport_array_index";
8198
8199 // Tess. control function in
8200 case BuiltInInvocationId:
8201 return "thread_index_in_threadgroup";
8202 case BuiltInPatchVertices:
8203 // Shouldn't be reached.
8204 SPIRV_CROSS_THROW("PatchVertices is derived from the auxiliary buffer in MSL.");
8205 case BuiltInPrimitiveId:
8206 switch (execution.model)
8207 {
8208 case ExecutionModelTessellationControl:
8209 return "threadgroup_position_in_grid";
8210 case ExecutionModelTessellationEvaluation:
8211 return "patch_id";
8212 case ExecutionModelFragment:
8213 if (msl_options.is_ios())
8214 SPIRV_CROSS_THROW("PrimitiveId is not supported in fragment on iOS.");
8215 else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 2))
8216 SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2.");
8217 return "primitive_id";
8218 default:
8219 SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model.");
8220 }
8221
8222 // Tess. control function out
8223 case BuiltInTessLevelOuter:
8224 case BuiltInTessLevelInner:
8225 // Shouldn't be reached.
8226 SPIRV_CROSS_THROW("Tessellation levels are handled specially in MSL.");
8227
8228 // Tess. evaluation function in
8229 case BuiltInTessCoord:
8230 return "position_in_patch";
8231
8232 // Fragment function in
8233 case BuiltInFrontFacing:
8234 return "front_facing";
8235 case BuiltInPointCoord:
8236 return "point_coord";
8237 case BuiltInFragCoord:
8238 return "position";
8239 case BuiltInSampleId:
8240 return "sample_id";
8241 case BuiltInSampleMask:
8242 return "sample_mask";
8243 case BuiltInSamplePosition:
8244 // Shouldn't be reached.
8245 SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL.");
8246 case BuiltInViewIndex:
8247 if (execution.model != ExecutionModelFragment)
8248 SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders.");
8249 // The ViewIndex was implicitly used in the prior stages to set the render_target_array_index,
8250 // so we can get it from there.
8251 return "render_target_array_index";
8252
8253 // Fragment function out
8254 case BuiltInFragDepth:
8255 if (execution.flags.get(ExecutionModeDepthGreater))
8256 return "depth(greater)";
8257 else if (execution.flags.get(ExecutionModeDepthLess))
8258 return "depth(less)";
8259 else
8260 return "depth(any)";
8261
8262 case BuiltInFragStencilRefEXT:
8263 return "stencil";
8264
8265 // Compute function in
8266 case BuiltInGlobalInvocationId:
8267 return "thread_position_in_grid";
8268
8269 case BuiltInWorkgroupId:
8270 return "threadgroup_position_in_grid";
8271
8272 case BuiltInNumWorkgroups:
8273 return "threadgroups_per_grid";
8274
8275 case BuiltInLocalInvocationId:
8276 return "thread_position_in_threadgroup";
8277
8278 case BuiltInLocalInvocationIndex:
8279 return "thread_index_in_threadgroup";
8280
8281 case BuiltInSubgroupSize:
8282 if (execution.model == ExecutionModelFragment)
8283 {
8284 if (!msl_options.supports_msl_version(2, 2))
8285 SPIRV_CROSS_THROW("threads_per_simdgroup requires Metal 2.2 in fragment shaders.");
8286 return "threads_per_simdgroup";
8287 }
8288 else
8289 {
8290 			// thread_execution_width is an alias for threads_per_simdgroup, and it has been
8291 			// available since MSL 1.0, but it cannot be used in fragment functions.
8292 return "thread_execution_width";
8293 }
8294
8295 case BuiltInNumSubgroups:
8296 if (!msl_options.supports_msl_version(2))
8297 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
8298 return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
8299
8300 case BuiltInSubgroupId:
8301 if (!msl_options.supports_msl_version(2))
8302 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
8303 return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
8304
8305 case BuiltInSubgroupLocalInvocationId:
8306 if (execution.model == ExecutionModelFragment)
8307 {
8308 if (!msl_options.supports_msl_version(2, 2))
8309 SPIRV_CROSS_THROW("thread_index_in_simdgroup requires Metal 2.2 in fragment shaders.");
8310 return "thread_index_in_simdgroup";
8311 }
8312 else
8313 {
8314 if (!msl_options.supports_msl_version(2))
8315 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
8316 return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
8317 }
8318
8319 case BuiltInSubgroupEqMask:
8320 case BuiltInSubgroupGeMask:
8321 case BuiltInSubgroupGtMask:
8322 case BuiltInSubgroupLeMask:
8323 case BuiltInSubgroupLtMask:
8324 // Shouldn't be reached.
8325 SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL.");
8326
8327 case BuiltInBaryCoordNV:
8328 // TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
8329 if (msl_options.is_ios())
8330 SPIRV_CROSS_THROW("Barycentrics not supported on iOS.");
8331 else if (!msl_options.supports_msl_version(2, 2))
8332 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
8333 return "barycentric_coord, center_perspective";
8334
8335 case BuiltInBaryCoordNoPerspNV:
8336 // TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
8337 if (msl_options.is_ios())
8338 SPIRV_CROSS_THROW("Barycentrics not supported on iOS.");
8339 else if (!msl_options.supports_msl_version(2, 2))
8340 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
8341 return "barycentric_coord, center_no_perspective";
8342
8343 default:
8344 return "unsupported-built-in";
8345 }
8346 }
8347
8348 // Returns an MSL string type declaration for a SPIR-V builtin
8349 string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id)
8350 {
8351 const SPIREntryPoint &execution = get_entry_point();
8352 switch (builtin)
8353 {
8354 // Vertex function in
8355 case BuiltInVertexId:
8356 return "uint";
8357 case BuiltInVertexIndex:
8358 return "uint";
8359 case BuiltInBaseVertex:
8360 return "uint";
8361 case BuiltInInstanceId:
8362 return "uint";
8363 case BuiltInInstanceIndex:
8364 return "uint";
8365 case BuiltInBaseInstance:
8366 return "uint";
8367 case BuiltInDrawIndex:
8368 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
8369
8370 // Vertex function out
8371 case BuiltInClipDistance:
8372 return "float";
8373 case BuiltInPointSize:
8374 return "float";
8375 case BuiltInPosition:
8376 return "float4";
8377 case BuiltInLayer:
8378 return "uint";
8379 case BuiltInViewportIndex:
8380 if (!msl_options.supports_msl_version(2, 0))
8381 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
8382 return "uint";
8383
8384 // Tess. control function in
8385 case BuiltInInvocationId:
8386 return "uint";
8387 case BuiltInPatchVertices:
8388 return "uint";
8389 case BuiltInPrimitiveId:
8390 return "uint";
8391
8392 // Tess. control function out
8393 case BuiltInTessLevelInner:
8394 if (execution.model == ExecutionModelTessellationEvaluation)
8395 return !execution.flags.get(ExecutionModeTriangles) ? "float2" : "float";
8396 return "half";
8397 case BuiltInTessLevelOuter:
8398 if (execution.model == ExecutionModelTessellationEvaluation)
8399 return !execution.flags.get(ExecutionModeTriangles) ? "float4" : "float";
8400 return "half";
8401
8402 // Tess. evaluation function in
8403 case BuiltInTessCoord:
8404 return execution.flags.get(ExecutionModeTriangles) ? "float3" : "float2";
8405
8406 // Fragment function in
8407 case BuiltInFrontFacing:
8408 return "bool";
8409 case BuiltInPointCoord:
8410 return "float2";
8411 case BuiltInFragCoord:
8412 return "float4";
8413 case BuiltInSampleId:
8414 return "uint";
8415 case BuiltInSampleMask:
8416 return "uint";
8417 case BuiltInSamplePosition:
8418 return "float2";
8419 case BuiltInViewIndex:
8420 return "uint";
8421
8422 // Fragment function out
8423 case BuiltInFragDepth:
8424 return "float";
8425
8426 case BuiltInFragStencilRefEXT:
8427 return "uint";
8428
8429 // Compute function in
8430 case BuiltInGlobalInvocationId:
8431 case BuiltInLocalInvocationId:
8432 case BuiltInNumWorkgroups:
8433 case BuiltInWorkgroupId:
8434 return "uint3";
8435 case BuiltInLocalInvocationIndex:
8436 case BuiltInNumSubgroups:
8437 case BuiltInSubgroupId:
8438 case BuiltInSubgroupSize:
8439 case BuiltInSubgroupLocalInvocationId:
8440 return "uint";
8441 case BuiltInSubgroupEqMask:
8442 case BuiltInSubgroupGeMask:
8443 case BuiltInSubgroupGtMask:
8444 case BuiltInSubgroupLeMask:
8445 case BuiltInSubgroupLtMask:
8446 return "uint4";
8447
8448 case BuiltInHelperInvocation:
8449 return "bool";
8450
8451 case BuiltInBaryCoordNV:
8452 case BuiltInBaryCoordNoPerspNV:
8453 // Use the type as declared, can be 1, 2 or 3 components.
8454 return type_to_glsl(get_variable_data_type(get<SPIRVariable>(id)));
8455
8456 default:
8457 return "unsupported-built-in-type";
8458 }
8459 }
8460
8461 // Returns the declaration of a built-in argument to a function
8462 string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma)
8463 {
8464 string bi_arg;
8465 if (prefix_comma)
8466 bi_arg += ", ";
8467
8468 bi_arg += builtin_type_decl(builtin);
8469 bi_arg += " " + builtin_to_glsl(builtin, StorageClassInput);
8470 bi_arg += " [[" + builtin_qualifier(builtin) + "]]";
8471
8472 return bi_arg;
8473 }
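// For example, for BuiltInSampleId this produces ", uint gl_SampleID [[sample_id]]"
// when prefix_comma is true.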
8474
8475 // Returns the byte size of a struct member.
8476 size_t CompilerMSL::get_declared_struct_member_size_msl(const SPIRType &struct_type, uint32_t index) const
8477 {
8478 auto &type = get<SPIRType>(struct_type.member_types[index]);
8479
8480 switch (type.basetype)
8481 {
8482 case SPIRType::Unknown:
8483 case SPIRType::Void:
8484 case SPIRType::AtomicCounter:
8485 case SPIRType::Image:
8486 case SPIRType::SampledImage:
8487 case SPIRType::Sampler:
8488 SPIRV_CROSS_THROW("Querying size of opaque object.");
8489
8490 default:
8491 {
8492 // For arrays, we can use ArrayStride to get an easy check.
8493 // Runtime arrays will have zero size so force to min of one.
8494 if (!type.array.empty())
8495 {
8496 uint32_t array_size = to_array_size_literal(type);
8497 #ifdef RARCH_INTERNAL
8498 return type_struct_member_array_stride(struct_type, index) * MAX(array_size, 1u);
8499 #else
8500 return type_struct_member_array_stride(struct_type, index) * max(array_size, 1u);
8501 #endif
8502 }
8503
8504 if (type.basetype == SPIRType::Struct)
8505 {
8506 // The size of a struct in Metal is aligned up to its natural alignment.
8507 auto size = get_declared_struct_size(type);
8508 auto alignment = get_declared_struct_member_alignment(struct_type, index);
8509 return (size + alignment - 1) & ~(alignment - 1);
8510 }
8511
8512 uint32_t component_size = type.width / 8;
8513 uint32_t vecsize = type.vecsize;
8514 uint32_t columns = type.columns;
8515
8516 		// An unpacked 3-element vector or matrix column has the same memory size as a 4-element one.
8517 if (vecsize == 3 && !has_extended_member_decoration(struct_type.self, index, SPIRVCrossDecorationPacked))
8518 vecsize = 4;
8519
8520 return component_size * vecsize * columns;
8521 }
8522 }
8523 }
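// Worked example: an unpacked float3 member reports 16 bytes (vecsize promoted to 4),
// while a member decorated as packed (packed_float3) reports 12 bytes.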
8524
8525 // Returns the byte alignment of a struct member.
8526 size_t CompilerMSL::get_declared_struct_member_alignment(const SPIRType &struct_type, uint32_t index) const
8527 {
8528 auto &type = get<SPIRType>(struct_type.member_types[index]);
8529
8530 switch (type.basetype)
8531 {
8532 case SPIRType::Unknown:
8533 case SPIRType::Void:
8534 case SPIRType::AtomicCounter:
8535 case SPIRType::Image:
8536 case SPIRType::SampledImage:
8537 case SPIRType::Sampler:
8538 SPIRV_CROSS_THROW("Querying alignment of opaque object.");
8539
8540 case SPIRType::Int64:
8541 SPIRV_CROSS_THROW("long types are not supported in buffers in MSL.");
8542 case SPIRType::UInt64:
8543 SPIRV_CROSS_THROW("ulong types are not supported in buffers in MSL.");
8544 case SPIRType::Double:
8545 SPIRV_CROSS_THROW("double types are not supported in buffers in MSL.");
8546
8547 case SPIRType::Struct:
8548 {
8549 // In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
8550 uint32_t alignment = 1;
8551 for (uint32_t i = 0; i < type.member_types.size(); i++)
8552 alignment = max(alignment, uint32_t(get_declared_struct_member_alignment(type, i)));
8553 return alignment;
8554 }
8555
8556 default:
8557 {
8558 		// Alignment of a packed type is the same as the underlying component or column size.
8559 		// Alignment of an unpacked type is the same as the vector size.
8560 		// Alignment of a 3-element vector is the same as a 4-element one (including packed using column).
8561 if (member_is_packed_type(struct_type, index))
8562 {
8563 // This is getting pretty complicated.
8564 // The special case of array of float/float2 needs to be handled here.
8565 uint32_t packed_type_id =
8566 get_extended_member_decoration(struct_type.self, index, SPIRVCrossDecorationPackedType);
8567 const SPIRType *packed_type = packed_type_id != 0 ? &get<SPIRType>(packed_type_id) : nullptr;
8568 if (packed_type && is_array(*packed_type) && !is_matrix(*packed_type) &&
8569 packed_type->basetype != SPIRType::Struct)
8570 {
8571 uint32_t stride = type_struct_member_array_stride(struct_type, index);
8572 if (stride == (packed_type->width / 8) * 4)
8573 return stride;
8574 else
8575 return packed_type->width / 8;
8576 }
8577 else
8578 return type.width / 8;
8579 }
8580 else
8581 return (type.width / 8) * (type.vecsize == 3 ? 4 : type.vecsize);
8582 }
8583 }
8584 }
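// Worked example: an unpacked float3 member is 16-byte aligned (treated like float4),
// whereas a packed_float3 member is aligned to its 4-byte component size.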
8585
8586 bool CompilerMSL::skip_argument(uint32_t) const
8587 {
8588 return false;
8589 }
8590
8591 void CompilerMSL::analyze_sampled_image_usage()
8592 {
8593 if (msl_options.swizzle_texture_samples)
8594 {
8595 SampledImageScanner scanner(*this);
8596 traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), scanner);
8597 }
8598 }
8599
8600 bool CompilerMSL::SampledImageScanner::handle(spv::Op opcode, const uint32_t *args, uint32_t length)
8601 {
8602 switch (opcode)
8603 {
8604 case OpLoad:
8605 case OpImage:
8606 case OpSampledImage:
8607 {
8608 if (length < 3)
8609 return false;
8610
8611 uint32_t result_type = args[0];
8612 auto &type = compiler.get<SPIRType>(result_type);
8613 if ((type.basetype != SPIRType::Image && type.basetype != SPIRType::SampledImage) || type.image.sampled != 1)
8614 return true;
8615
8616 uint32_t id = args[1];
8617 compiler.set<SPIRExpression>(id, "", result_type, true);
8618 break;
8619 }
8620 case OpImageSampleExplicitLod:
8621 case OpImageSampleProjExplicitLod:
8622 case OpImageSampleDrefExplicitLod:
8623 case OpImageSampleProjDrefExplicitLod:
8624 case OpImageSampleImplicitLod:
8625 case OpImageSampleProjImplicitLod:
8626 case OpImageSampleDrefImplicitLod:
8627 case OpImageSampleProjDrefImplicitLod:
8628 case OpImageFetch:
8629 case OpImageGather:
8630 case OpImageDrefGather:
8631 compiler.has_sampled_images =
8632 compiler.has_sampled_images || compiler.is_sampled_image_type(compiler.expression_type(args[2]));
8633 compiler.needs_swizzle_buffer_def = compiler.needs_swizzle_buffer_def || compiler.has_sampled_images;
8634 break;
8635 default:
8636 break;
8637 }
8638 return true;
8639 }
8640
8641 bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, uint32_t length)
8642 {
8643 // Since MSL exists in a single execution scope, function prototype declarations are not
8644 // needed, and clutter the output. If secondary functions are output (either as a SPIR-V
8645 // function implementation or as indicated by the presence of OpFunctionCall), then set
8646 // suppress_missing_prototypes to suppress compiler warnings of missing function prototypes.
8647
8648 // Mark if the input requires the implementation of an SPIR-V function that does not exist in Metal.
8649 SPVFuncImpl spv_func = get_spv_func_impl(opcode, args);
8650 if (spv_func != SPVFuncImplNone)
8651 {
8652 compiler.spv_function_implementations.insert(spv_func);
8653 suppress_missing_prototypes = true;
8654 }
8655
8656 switch (opcode)
8657 {
8658
8659 case OpFunctionCall:
8660 suppress_missing_prototypes = true;
8661 break;
8662
8663 case OpImageWrite:
8664 uses_resource_write = true;
8665 break;
8666
8667 case OpStore:
8668 check_resource_write(args[0]);
8669 break;
8670
8671 case OpAtomicExchange:
8672 case OpAtomicCompareExchange:
8673 case OpAtomicCompareExchangeWeak:
8674 case OpAtomicIIncrement:
8675 case OpAtomicIDecrement:
8676 case OpAtomicIAdd:
8677 case OpAtomicISub:
8678 case OpAtomicSMin:
8679 case OpAtomicUMin:
8680 case OpAtomicSMax:
8681 case OpAtomicUMax:
8682 case OpAtomicAnd:
8683 case OpAtomicOr:
8684 case OpAtomicXor:
8685 uses_atomics = true;
8686 check_resource_write(args[2]);
8687 break;
8688
8689 case OpAtomicLoad:
8690 uses_atomics = true;
8691 break;
8692
8693 case OpGroupNonUniformInverseBallot:
8694 needs_subgroup_invocation_id = true;
8695 break;
8696
8697 case OpGroupNonUniformBallotBitCount:
8698 if (args[3] != GroupOperationReduce)
8699 needs_subgroup_invocation_id = true;
8700 break;
8701
8702 case OpArrayLength:
8703 {
8704 auto *var = compiler.maybe_get_backing_variable(args[2]);
8705 if (var)
8706 compiler.buffers_requiring_array_length.insert(var->self);
8707 break;
8708 }
8709
8710 case OpInBoundsAccessChain:
8711 case OpAccessChain:
8712 case OpPtrAccessChain:
8713 {
8714 		// OpArrayLength might need to know whether we are taking the ArrayLength of an array of SSBOs.
8715 uint32_t result_type = args[0];
8716 uint32_t id = args[1];
8717 uint32_t ptr = args[2];
8718 compiler.set<SPIRExpression>(id, "", result_type, true);
8719 compiler.register_read(id, ptr, true);
8720 compiler.ir.ids[id].set_allow_type_rewrite();
8721 break;
8722 }
8723
8724 default:
8725 break;
8726 }
8727
8728 // If it has one, keep track of the instruction's result type, mapped by ID
8729 uint32_t result_type, result_id;
8730 if (compiler.instruction_to_result_type(result_type, result_id, opcode, args, length))
8731 result_types[result_id] = result_type;
8732
8733 return true;
8734 }
8735
8736 // If the variable is a Uniform or StorageBuffer, mark that a resource has been written to.
8737 void CompilerMSL::OpCodePreprocessor::check_resource_write(uint32_t var_id)
8738 {
8739 auto *p_var = compiler.maybe_get_backing_variable(var_id);
8740 StorageClass sc = p_var ? p_var->storage : StorageClassMax;
8741 if (sc == StorageClassUniform || sc == StorageClassStorageBuffer)
8742 uses_resource_write = true;
8743 }
8744
8745 // Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes.
8746 CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op opcode, const uint32_t *args)
8747 {
8748 switch (opcode)
8749 {
8750 case OpFMod:
8751 return SPVFuncImplMod;
8752
8753 case OpFunctionCall:
8754 {
8755 auto &return_type = compiler.get<SPIRType>(args[0]);
8756 if (return_type.array.size() > 1)
8757 {
8758 if (return_type.array.size() > SPVFuncImplArrayCopyMultidimMax)
8759 SPIRV_CROSS_THROW("Cannot support this many dimensions for arrays of arrays.");
8760 return static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + return_type.array.size());
8761 }
8762 else if (return_type.array.size() > 0)
8763 return SPVFuncImplArrayCopy;
8764
8765 break;
8766 }
8767
8768 case OpStore:
8769 {
8770 // Get the result type of the RHS. Since this is run as a pre-processing stage,
8771 // we must extract the result type directly from the Instruction, rather than the ID.
8772 uint32_t id_lhs = args[0];
8773 uint32_t id_rhs = args[1];
8774
8775 const SPIRType *type = nullptr;
8776 if (compiler.ir.ids[id_rhs].get_type() != TypeNone)
8777 {
8778 // Could be a constant, or similar.
8779 type = &compiler.expression_type(id_rhs);
8780 }
8781 else
8782 {
8783 // Or ... an expression.
8784 uint32_t tid = result_types[id_rhs];
8785 if (tid)
8786 type = &compiler.get<SPIRType>(tid);
8787 }
8788
8789 auto *var = compiler.maybe_get<SPIRVariable>(id_lhs);
8790
8791 // Are we simply assigning to a statically assigned variable which takes a constant?
8792 // Don't bother emitting this function.
8793 bool static_expression_lhs =
8794 var && var->storage == StorageClassFunction && var->statically_assigned && var->remapped_variable;
8795 if (type && compiler.is_array(*type) && !static_expression_lhs)
8796 {
8797 if (type->array.size() > 1)
8798 {
8799 if (type->array.size() > SPVFuncImplArrayCopyMultidimMax)
8800 SPIRV_CROSS_THROW("Cannot support this many dimensions for arrays of arrays.");
8801 return static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + type->array.size());
8802 }
8803 else
8804 return SPVFuncImplArrayCopy;
8805 }
8806
8807 break;
8808 }
8809
8810 case OpImageFetch:
8811 case OpImageRead:
8812 case OpImageWrite:
8813 {
8814 // Retrieve the image type, and if it's a Buffer, emit a texel coordinate function
8815 uint32_t tid = result_types[args[opcode == OpImageWrite ? 0 : 2]];
8816 if (tid && compiler.get<SPIRType>(tid).image.dim == DimBuffer && !compiler.msl_options.texture_buffer_native)
8817 return SPVFuncImplTexelBufferCoords;
8818
8819 if (opcode == OpImageFetch && compiler.msl_options.swizzle_texture_samples)
8820 return SPVFuncImplTextureSwizzle;
8821
8822 break;
8823 }
8824
8825 case OpImageSampleExplicitLod:
8826 case OpImageSampleProjExplicitLod:
8827 case OpImageSampleDrefExplicitLod:
8828 case OpImageSampleProjDrefExplicitLod:
8829 case OpImageSampleImplicitLod:
8830 case OpImageSampleProjImplicitLod:
8831 case OpImageSampleDrefImplicitLod:
8832 case OpImageSampleProjDrefImplicitLod:
8833 case OpImageGather:
8834 case OpImageDrefGather:
8835 if (compiler.msl_options.swizzle_texture_samples)
8836 return SPVFuncImplTextureSwizzle;
8837 break;
8838
8839 case OpExtInst:
8840 {
8841 uint32_t extension_set = args[2];
8842 if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
8843 {
8844 auto op_450 = static_cast<GLSLstd450>(args[3]);
8845 switch (op_450)
8846 {
8847 case GLSLstd450Radians:
8848 return SPVFuncImplRadians;
8849 case GLSLstd450Degrees:
8850 return SPVFuncImplDegrees;
8851 case GLSLstd450FindILsb:
8852 return SPVFuncImplFindILsb;
8853 case GLSLstd450FindSMsb:
8854 return SPVFuncImplFindSMsb;
8855 case GLSLstd450FindUMsb:
8856 return SPVFuncImplFindUMsb;
8857 case GLSLstd450SSign:
8858 return SPVFuncImplSSign;
8859 case GLSLstd450Reflect:
8860 {
8861 auto &type = compiler.get<SPIRType>(args[0]);
8862 if (type.vecsize == 1)
8863 return SPVFuncImplReflectScalar;
8864 else
8865 return SPVFuncImplNone;
8866 }
8867 case GLSLstd450Refract:
8868 {
8869 auto &type = compiler.get<SPIRType>(args[0]);
8870 if (type.vecsize == 1)
8871 return SPVFuncImplRefractScalar;
8872 else
8873 return SPVFuncImplNone;
8874 }
8875 case GLSLstd450MatrixInverse:
8876 {
8877 auto &mat_type = compiler.get<SPIRType>(args[0]);
8878 switch (mat_type.columns)
8879 {
8880 case 2:
8881 return SPVFuncImplInverse2x2;
8882 case 3:
8883 return SPVFuncImplInverse3x3;
8884 case 4:
8885 return SPVFuncImplInverse4x4;
8886 default:
8887 break;
8888 }
8889 break;
8890 }
8891 default:
8892 break;
8893 }
8894 }
8895 break;
8896 }
8897
8898 case OpGroupNonUniformBallot:
8899 return SPVFuncImplSubgroupBallot;
8900
8901 case OpGroupNonUniformInverseBallot:
8902 case OpGroupNonUniformBallotBitExtract:
8903 return SPVFuncImplSubgroupBallotBitExtract;
8904
8905 case OpGroupNonUniformBallotFindLSB:
8906 return SPVFuncImplSubgroupBallotFindLSB;
8907
8908 case OpGroupNonUniformBallotFindMSB:
8909 return SPVFuncImplSubgroupBallotFindMSB;
8910
8911 case OpGroupNonUniformBallotBitCount:
8912 return SPVFuncImplSubgroupBallotBitCount;
8913
8914 case OpGroupNonUniformAllEqual:
8915 return SPVFuncImplSubgroupAllEqual;
8916
8917 default:
8918 break;
8919 }
8920 return SPVFuncImplNone;
8921 }
8922
8923 // Sort both type and meta member content based on builtin status (put builtins at end),
8924 // then by the required sorting aspect.
8925 void CompilerMSL::MemberSorter::sort()
8926 {
8927 // Create a temporary array of consecutive member indices and sort it based on how
8928 // the members should be reordered, based on builtin and sorting aspect meta info.
8929 size_t mbr_cnt = type.member_types.size();
8930 SmallVector<uint32_t> mbr_idxs(mbr_cnt);
8931 iota(mbr_idxs.begin(), mbr_idxs.end(), 0); // Fill with consecutive indices
8932 std::sort(mbr_idxs.begin(), mbr_idxs.end(), *this); // Sort member indices based on sorting aspect
8933
8934 // Move type and meta member info to the order defined by the sorted member indices.
8935 // This is done by creating temporary copies of both member types and meta, and then
8936 // copying back to the original content at the sorted indices.
8937 auto mbr_types_cpy = type.member_types;
8938 auto mbr_meta_cpy = meta.members;
8939 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
8940 {
8941 type.member_types[mbr_idx] = mbr_types_cpy[mbr_idxs[mbr_idx]];
8942 meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]];
8943 }
8944 }
8945
8946 // Sort first by builtin status (put builtins at end), then by the sorting aspect.
8947 bool CompilerMSL::MemberSorter::operator()(uint32_t mbr_idx1, uint32_t mbr_idx2)
8948 {
8949 auto &mbr_meta1 = meta.members[mbr_idx1];
8950 auto &mbr_meta2 = meta.members[mbr_idx2];
8951 if (mbr_meta1.builtin != mbr_meta2.builtin)
8952 return mbr_meta2.builtin;
8953 else
8954 switch (sort_aspect)
8955 {
8956 case Location:
8957 return mbr_meta1.location < mbr_meta2.location;
8958 case LocationReverse:
8959 return mbr_meta1.location > mbr_meta2.location;
8960 case Offset:
8961 return mbr_meta1.offset < mbr_meta2.offset;
8962 case OffsetThenLocationReverse:
8963 return (mbr_meta1.offset < mbr_meta2.offset) ||
8964 ((mbr_meta1.offset == mbr_meta2.offset) && (mbr_meta1.location > mbr_meta2.location));
8965 case Alphabetical:
8966 return mbr_meta1.alias < mbr_meta2.alias;
8967 default:
8968 return false;
8969 }
8970 }
8971
8972 CompilerMSL::MemberSorter::MemberSorter(SPIRType &t, Meta &m, SortAspect sa)
8973 : type(t)
8974 , meta(m)
8975 , sort_aspect(sa)
8976 {
8977 // Ensure enough meta info is available
8978 meta.members.resize(max(type.member_types.size(), meta.members.size()));
8979 }
8980
8981 void CompilerMSL::remap_constexpr_sampler(uint32_t id, const MSLConstexprSampler &sampler)
8982 {
8983 auto &type = get<SPIRType>(get<SPIRVariable>(id).basetype);
8984 if (type.basetype != SPIRType::SampledImage && type.basetype != SPIRType::Sampler)
8985 SPIRV_CROSS_THROW("Can only remap SampledImage and Sampler type.");
8986 if (!type.array.empty())
8987 SPIRV_CROSS_THROW("Can not remap array of samplers.");
8988 constexpr_samplers_by_id[id] = sampler;
8989 }
8990
8991 void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t binding,
8992 const MSLConstexprSampler &sampler)
8993 {
8994 constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
8995 }
8996
8997 void CompilerMSL::bitcast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
8998 {
8999 auto *var = maybe_get_backing_variable(source_id);
9000 if (var)
9001 source_id = var->self;
9002
9003 // Only interested in standalone builtin variables.
9004 if (!has_decoration(source_id, DecorationBuiltIn))
9005 return;
9006
9007 auto builtin = static_cast<BuiltIn>(get_decoration(source_id, DecorationBuiltIn));
9008 auto expected_type = expr_type.basetype;
9009 switch (builtin)
9010 {
9011 case BuiltInGlobalInvocationId:
9012 case BuiltInLocalInvocationId:
9013 case BuiltInWorkgroupId:
9014 case BuiltInLocalInvocationIndex:
9015 case BuiltInWorkgroupSize:
9016 case BuiltInNumWorkgroups:
9017 case BuiltInLayer:
9018 case BuiltInViewportIndex:
9019 case BuiltInFragStencilRefEXT:
9020 case BuiltInPrimitiveId:
9021 case BuiltInSubgroupSize:
9022 case BuiltInSubgroupLocalInvocationId:
9023 case BuiltInViewIndex:
9024 expected_type = SPIRType::UInt;
9025 break;
9026
9027 case BuiltInTessLevelInner:
9028 case BuiltInTessLevelOuter:
9029 if (get_execution_model() == ExecutionModelTessellationControl)
9030 expected_type = SPIRType::Half;
9031 break;
9032
9033 default:
9034 break;
9035 }
9036
9037 if (expected_type != expr_type.basetype)
9038 expr = bitcast_expression(expr_type, expected_type, expr);
9039
9040 if (builtin == BuiltInTessCoord && get_entry_point().flags.get(ExecutionModeQuads) && expr_type.vecsize == 3)
9041 {
9042 // In SPIR-V, this is always a vec3, even for quads. In Metal, though, it's a float2 for quads.
9043 // The code is expecting a float3, so we need to widen this.
9044 expr = join("float3(", expr, ", 0)");
9045 }
9046 }
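// For example, if a shader declares gl_ViewIndex as a signed int, the loaded value is
// cast from the uint the builtin actually carries before it is used in the expression.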
9047
9048 void CompilerMSL::bitcast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
9049 {
9050 auto *var = maybe_get_backing_variable(target_id);
9051 if (var)
9052 target_id = var->self;
9053
9054 // Only interested in standalone builtin variables.
9055 if (!has_decoration(target_id, DecorationBuiltIn))
9056 return;
9057
9058 auto builtin = static_cast<BuiltIn>(get_decoration(target_id, DecorationBuiltIn));
9059 auto expected_type = expr_type.basetype;
9060 switch (builtin)
9061 {
9062 case BuiltInLayer:
9063 case BuiltInViewportIndex:
9064 case BuiltInFragStencilRefEXT:
9065 case BuiltInPrimitiveId:
9066 case BuiltInViewIndex:
9067 expected_type = SPIRType::UInt;
9068 break;
9069
9070 case BuiltInTessLevelInner:
9071 case BuiltInTessLevelOuter:
9072 expected_type = SPIRType::Half;
9073 break;
9074
9075 default:
9076 break;
9077 }
9078
9079 if (expected_type != expr_type.basetype)
9080 {
9081 if (expected_type == SPIRType::Half && expr_type.basetype == SPIRType::Float)
9082 {
9083 // These are of different widths, so we cannot do a straight bitcast.
9084 expr = join("half(", expr, ")");
9085 }
9086 else
9087 {
9088 auto type = expr_type;
9089 type.basetype = expected_type;
9090 expr = bitcast_expression(type, expr_type.basetype, expr);
9091 }
9092 }
9093 }
9094
9095 std::string CompilerMSL::to_initializer_expression(const SPIRVariable &var)
9096 {
9097 	// With MSL, we risk getting an array initializer here if we have an array.
9098 // FIXME: We cannot handle non-constant arrays being initialized.
9099 // We will need to inject spvArrayCopy here somehow ...
9100 auto &type = get<SPIRType>(var.basetype);
9101 if (ir.ids[var.initializer].get_type() == TypeConstant &&
9102 (!type.array.empty() || type.basetype == SPIRType::Struct))
9103 return constant_expression(get<SPIRConstant>(var.initializer));
9104 else
9105 return CompilerGLSL::to_initializer_expression(var);
9106 }
9107
9108 bool CompilerMSL::descriptor_set_is_argument_buffer(uint32_t desc_set) const
9109 {
9110 if (!msl_options.argument_buffers)
9111 return false;
9112 if (desc_set >= kMaxArgumentBuffers)
9113 return false;
9114
9115 return (argument_buffer_discrete_mask & (1u << desc_set)) == 0;
9116 }
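// Usage note: descriptor sets marked as discrete via argument_buffer_discrete_mask keep
// classic [[buffer]]/[[texture]]/[[sampler]] bindings, while all other sets below
// kMaxArgumentBuffers are emitted as argument buffers when msl_options.argument_buffers is enabled.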
9117
9118 void CompilerMSL::analyze_argument_buffers()
9119 {
9120 // Gather all used resources and sort them out into argument buffers.
9121 // Each argument buffer corresponds to a descriptor set in SPIR-V.
9122 // The [[id(N)]] values used correspond to the resource mapping we have for MSL.
9123 	// Otherwise, the binding number is used, but this is generally not safe for some types like
9124 // combined image samplers and arrays of resources. Metal needs different indices here,
9125 // while SPIR-V can have one descriptor set binding. To use argument buffers in practice,
9126 // you will need to use the remapping from the API.
9127 for (auto &id : argument_buffer_ids)
9128 id = 0;
9129
9130 // Output resources, sorted by resource index & type.
9131 struct Resource
9132 {
9133 SPIRVariable *var;
9134 string name;
9135 SPIRType::BaseType basetype;
9136 uint32_t index;
9137 };
9138 SmallVector<Resource> resources_in_set[kMaxArgumentBuffers];
9139
9140 bool set_needs_swizzle_buffer[kMaxArgumentBuffers] = {};
9141 bool set_needs_buffer_sizes[kMaxArgumentBuffers] = {};
9142 bool needs_buffer_sizes = false;
9143
9144 ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &var) {
9145 if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
9146 var.storage == StorageClassStorageBuffer) &&
9147 !is_hidden_variable(var))
9148 {
9149 uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
9150 // Ignore if it's part of a push descriptor set.
9151 if (!descriptor_set_is_argument_buffer(desc_set))
9152 return;
9153
9154 uint32_t var_id = var.self;
9155 auto &type = get_variable_data_type(var);
9156
9157 if (desc_set >= kMaxArgumentBuffers)
9158 SPIRV_CROSS_THROW("Descriptor set index is out of range.");
9159
9160 const MSLConstexprSampler *constexpr_sampler = nullptr;
9161 if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
9162 {
9163 constexpr_sampler = find_constexpr_sampler(var_id);
9164 if (constexpr_sampler)
9165 {
9166 // Mark this ID as a constexpr sampler for later in case it came from set/bindings.
9167 constexpr_samplers_by_id[var_id] = *constexpr_sampler;
9168 }
9169 }
9170
9171 if (type.basetype == SPIRType::SampledImage)
9172 {
9173 add_resource_name(var_id);
9174
9175 uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image);
9176 uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler);
9177
9178 resources_in_set[desc_set].push_back({ &var, to_name(var_id), SPIRType::Image, image_resource_index });
9179
9180 if (type.image.dim != DimBuffer && !constexpr_sampler)
9181 {
9182 resources_in_set[desc_set].push_back(
9183 { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index });
9184 }
9185 }
9186 else if (!constexpr_sampler)
9187 {
9188 // constexpr samplers are not declared as resources.
9189 add_resource_name(var_id);
9190 resources_in_set[desc_set].push_back(
9191 { &var, to_name(var_id), type.basetype, get_metal_resource_index(var, type.basetype) });
9192 }
9193
9194 // Check if this descriptor set needs a swizzle buffer.
9195 if (needs_swizzle_buffer_def && is_sampled_image_type(type))
9196 set_needs_swizzle_buffer[desc_set] = true;
9197 else if (buffers_requiring_array_length.count(var_id) != 0)
9198 {
9199 set_needs_buffer_sizes[desc_set] = true;
9200 needs_buffer_sizes = true;
9201 }
9202 }
9203 });
9204
9205 if (needs_swizzle_buffer_def || needs_buffer_sizes)
9206 {
9207 uint32_t uint_ptr_type_id = 0;
9208
9209 // We might have to add a swizzle buffer resource to the set.
9210 for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
9211 {
9212 if (!set_needs_swizzle_buffer[desc_set] && !set_needs_buffer_sizes[desc_set])
9213 continue;
9214
9215 if (uint_ptr_type_id == 0)
9216 {
9217 uint32_t offset = ir.increase_bound_by(2);
9218 uint32_t type_id = offset;
9219 uint_ptr_type_id = offset + 1;
9220
9221 // Create a buffer to hold extra data, including the swizzle constants.
9222 SPIRType uint_type;
9223 uint_type.basetype = SPIRType::UInt;
9224 uint_type.width = 32;
9225 set<SPIRType>(type_id, uint_type);
9226
9227 SPIRType uint_type_pointer = uint_type;
9228 uint_type_pointer.pointer = true;
9229 uint_type_pointer.pointer_depth = 1;
9230 uint_type_pointer.parent_type = type_id;
9231 uint_type_pointer.storage = StorageClassUniform;
9232 set<SPIRType>(uint_ptr_type_id, uint_type_pointer);
9233 set_decoration(uint_ptr_type_id, DecorationArrayStride, 4);
9234 }
9235
9236 if (set_needs_swizzle_buffer[desc_set])
9237 {
9238 uint32_t var_id = ir.increase_bound_by(1);
9239 auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
9240 set_name(var_id, "spvSwizzleConstants");
9241 set_decoration(var_id, DecorationDescriptorSet, desc_set);
9242 set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding);
9243 resources_in_set[desc_set].push_back(
9244 { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt) });
9245 }
9246
9247 if (set_needs_buffer_sizes[desc_set])
9248 {
9249 uint32_t var_id = ir.increase_bound_by(1);
9250 auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
9251 set_name(var_id, "spvBufferSizeConstants");
9252 set_decoration(var_id, DecorationDescriptorSet, desc_set);
9253 set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding);
9254 resources_in_set[desc_set].push_back(
9255 { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt) });
9256 }
9257 }
9258 }
9259
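// Final pass: each argument-buffer set becomes a synthesized struct type
// (spvDescriptorSetBufferN) plus a pointer variable (spvDescriptorSetN), with one struct
// member per resource gathered above.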
9260 for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
9261 {
9262 auto &resources = resources_in_set[desc_set];
9263 if (resources.empty())
9264 continue;
9265
9266 assert(descriptor_set_is_argument_buffer(desc_set));
9267
9268 uint32_t next_id = ir.increase_bound_by(3);
9269 uint32_t type_id = next_id + 1;
9270 uint32_t ptr_type_id = next_id + 2;
9271 argument_buffer_ids[desc_set] = next_id;
9272
9273 auto &buffer_type = set<SPIRType>(type_id);
9274 buffer_type.storage = StorageClassUniform;
9275 buffer_type.basetype = SPIRType::Struct;
9276 set_name(type_id, join("spvDescriptorSetBuffer", desc_set));
9277
9278 auto &ptr_type = set<SPIRType>(ptr_type_id);
9279 ptr_type = buffer_type;
9280 ptr_type.pointer = true;
9281 ptr_type.pointer_depth = 1;
9282 ptr_type.parent_type = type_id;
9283
9284 uint32_t buffer_variable_id = next_id;
9285 set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
9286 set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));
9287
9288 		// Members must be emitted in ascending [[id(n)]] order, so sort by resource index (base type breaks ties).
9289 sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
9290 return tie(lhs.index, lhs.basetype) < tie(rhs.index, rhs.basetype);
9291 });
9292
9293 uint32_t member_index = 0;
9294 for (auto &resource : resources)
9295 {
9296 auto &var = *resource.var;
9297 auto &type = get_variable_data_type(var);
9298 string mbr_name = ensure_valid_name(resource.name, "m");
9299 set_member_name(buffer_type.self, member_index, mbr_name);
9300
9301 if (resource.basetype == SPIRType::Sampler && type.basetype != SPIRType::Sampler)
9302 {
9303 				// The sampler half of a combined image sampler has no SPIR-V type of its own, so synthesize one here.
9304
9305 bool type_is_array = !type.array.empty();
9306 uint32_t sampler_type_id = ir.increase_bound_by(type_is_array ? 2 : 1);
9307 auto &new_sampler_type = set<SPIRType>(sampler_type_id);
9308 new_sampler_type.basetype = SPIRType::Sampler;
9309 new_sampler_type.storage = StorageClassUniformConstant;
9310
9311 if (type_is_array)
9312 {
9313 uint32_t sampler_type_array_id = sampler_type_id + 1;
9314 auto &sampler_type_array = set<SPIRType>(sampler_type_array_id);
9315 sampler_type_array = new_sampler_type;
9316 sampler_type_array.array = type.array;
9317 sampler_type_array.array_size_literal = type.array_size_literal;
9318 sampler_type_array.parent_type = sampler_type_id;
9319 buffer_type.member_types.push_back(sampler_type_array_id);
9320 }
9321 else
9322 buffer_type.member_types.push_back(sampler_type_id);
9323 }
9324 else
9325 {
9326 if (resource.basetype == SPIRType::Image || resource.basetype == SPIRType::Sampler ||
9327 resource.basetype == SPIRType::SampledImage)
9328 {
9329 // Drop pointer information when we emit the resources into a struct.
9330 buffer_type.member_types.push_back(get_variable_data_type_id(var));
9331 set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
9332 }
9333 else
9334 {
9335 					// Buffers are declared as pointers, not references, so automatically dereference where appropriate.
9336 buffer_type.member_types.push_back(var.basetype);
9337 if (type.array.empty())
9338 set_qualified_name(var.self, join("(*", to_name(buffer_variable_id), ".", mbr_name, ")"));
9339 else
9340 set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
9341 }
9342 }
9343
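// The resource index recorded on the member is what later drives the [[id(n)]] attribute in
// the emitted argument buffer, and InterfaceOrigID ties the member back to the original
// variable so its accesses can be redirected through the struct.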
9344 set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationResourceIndexPrimary,
9345 resource.index);
9346 set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationInterfaceOrigID,
9347 var.self);
9348 member_index++;
9349 }
9350 }
9351 }
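// For illustration only: the struct types synthesized above are ultimately emitted as Metal
// argument buffers roughly along these lines (member names here are hypothetical; the exact
// [[id(n)]] values, address spaces and names depend on the resources and compiler options):
//
//   struct spvDescriptorSetBuffer0
//   {
//       texture2d<float> uTexture [[id(0)]];
//       sampler uTextureSmplr [[id(1)]];
//       constant UBO* uUBO [[id(2)]];
//   };
//
//   fragment float4 main0(constant spvDescriptorSetBuffer0& spvDescriptorSet0 [[buffer(0)]], ...);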
9352
9353 bool CompilerMSL::SetBindingPair::operator==(const SetBindingPair &other) const
9354 {
9355 return desc_set == other.desc_set && binding == other.binding;
9356 }
9357
9358 bool CompilerMSL::StageSetBinding::operator==(const StageSetBinding &other) const
9359 {
9360 return model == other.model && desc_set == other.desc_set && binding == other.binding;
9361 }
9362
9363 size_t CompilerMSL::InternalHasher::operator()(const SetBindingPair &value) const
9364 {
9365 // Quality of hash doesn't really matter here.
9366 auto hash_set = std::hash<uint32_t>()(value.desc_set);
9367 auto hash_binding = std::hash<uint32_t>()(value.binding);
9368 return (hash_set * 0x10001b31) ^ hash_binding;
9369 }
9370
9371 size_t CompilerMSL::InternalHasher::operator()(const StageSetBinding &value) const
9372 {
9373 // Quality of hash doesn't really matter here.
9374 auto hash_model = std::hash<uint32_t>()(value.model);
9375 auto hash_set = std::hash<uint32_t>()(value.desc_set);
9376 auto tmp_hash = (hash_model * 0x10001b31) ^ hash_set;
9377 return (tmp_hash * 0x10001b31) ^ value.binding;
9378 }
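// These hashers exist so that the set/binding aggregates above can be used as keys in
// unordered associative containers. A minimal sketch, assuming access from inside CompilerMSL
// (hypothetical usage, not part of this file):
//
//   std::unordered_map<StageSetBinding, uint32_t, InternalHasher> index_remap;
//   index_remap[{ spv::ExecutionModelVertex, /*desc_set*/ 0, /*binding*/ 1 }] = 7;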
9379