/*
 * Copyright 2016-2019 The Brenwill Workshop Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "spirv_msl.hpp"
#include "GLSL.std.450.h"

#include <algorithm>
#include <assert.h>
#include <numeric>

#ifdef RARCH_INTERNAL
#include <retro_miscellaneous.h>
#endif

using namespace spv;
using namespace SPIRV_CROSS_NAMESPACE;
using namespace std;

static const uint32_t k_unknown_location = ~0u;
static const uint32_t k_unknown_component = ~0u;

CompilerMSL::CompilerMSL(std::vector<uint32_t> spirv_)
    : CompilerGLSL(move(spirv_))
{
}

CompilerMSL::CompilerMSL(const uint32_t *ir_, size_t word_count)
    : CompilerGLSL(ir_, word_count)
{
}

CompilerMSL::CompilerMSL(const ParsedIR &ir_)
    : CompilerGLSL(ir_)
{
}

CompilerMSL::CompilerMSL(ParsedIR &&ir_)
    : CompilerGLSL(std::move(ir_))
{
}

void CompilerMSL::add_msl_vertex_attribute(const MSLVertexAttr &va)
{
	vtx_attrs_by_location[va.location] = va;
	if (va.builtin != BuiltInMax && !vtx_attrs_by_builtin.count(va.builtin))
		vtx_attrs_by_builtin[va.builtin] = va;
}

void CompilerMSL::add_msl_resource_binding(const MSLResourceBinding &binding)
{
	StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
	resource_bindings[tuple] = { binding, false };
}
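// A minimal usage sketch (hypothetical client code; field names follow the
// MSLResourceBinding struct in this header revision):
//
//   MSLResourceBinding binding = {};
//   binding.stage = spv::ExecutionModelFragment;
//   binding.desc_set = 0;
//   binding.binding = 1;
//   binding.msl_texture = 0; // emitted as [[texture(0)]]
//   binding.msl_sampler = 0; // emitted as [[sampler(0)]]
//   compiler.add_msl_resource_binding(binding);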

void CompilerMSL::add_discrete_descriptor_set(uint32_t desc_set)
{
	if (desc_set < kMaxArgumentBuffers)
		argument_buffer_discrete_mask |= 1u << desc_set;
}
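// For example, add_discrete_descriptor_set(1) keeps descriptor set 1 as individually
// bound resources even when msl_options.argument_buffers is enabled. Only the first
// kMaxArgumentBuffers sets can be flagged in this bitmask.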

bool CompilerMSL::is_msl_vertex_attribute_used(uint32_t location)
{
	return vtx_attrs_in_use.count(location) != 0;
}

bool CompilerMSL::is_msl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding)
{
	StageSetBinding tuple = { model, desc_set, binding };
	auto itr = resource_bindings.find(tuple);
	return itr != end(resource_bindings) && itr->second.second;
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexPrimary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_secondary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexSecondary);
}

void CompilerMSL::set_fragment_output_components(uint32_t location, uint32_t components)
{
	fragment_output_components[location] = components;
}

void CompilerMSL::build_implicit_builtins()
{
	bool need_sample_pos = active_input_builtins.get(BuiltInSamplePosition);
	bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex;
	bool need_tesc_params = get_execution_model() == ExecutionModelTessellationControl;
	bool need_subgroup_mask =
	    active_input_builtins.get(BuiltInSubgroupEqMask) || active_input_builtins.get(BuiltInSubgroupGeMask) ||
	    active_input_builtins.get(BuiltInSubgroupGtMask) || active_input_builtins.get(BuiltInSubgroupLeMask) ||
	    active_input_builtins.get(BuiltInSubgroupLtMask);
	bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(BuiltInSubgroupGeMask) ||
	                                                       active_input_builtins.get(BuiltInSubgroupGtMask));
	bool need_multiview = get_execution_model() == ExecutionModelVertex &&
	                      (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex));
	if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
	    need_multiview || needs_subgroup_invocation_id)
	{
		bool has_frag_coord = false;
		bool has_sample_id = false;
		bool has_vertex_idx = false;
		bool has_base_vertex = false;
		bool has_instance_idx = false;
		bool has_base_instance = false;
		bool has_invocation_id = false;
		bool has_primitive_id = false;
		bool has_subgroup_invocation_id = false;
		bool has_subgroup_size = false;
		bool has_view_idx = false;

		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
			if (var.storage != StorageClassInput || !ir.meta[var.self].decoration.builtin)
				return;

			BuiltIn builtin = ir.meta[var.self].decoration.builtin_type;
			if (need_subpass_input && builtin == BuiltInFragCoord)
			{
				builtin_frag_coord_id = var.self;
				has_frag_coord = true;
			}

			if (need_sample_pos && builtin == BuiltInSampleId)
			{
				builtin_sample_id_id = var.self;
				has_sample_id = true;
			}

			if (need_vertex_params)
			{
				switch (builtin)
				{
				case BuiltInVertexIndex:
					builtin_vertex_idx_id = var.self;
					has_vertex_idx = true;
					break;
				case BuiltInBaseVertex:
					builtin_base_vertex_id = var.self;
					has_base_vertex = true;
					break;
				case BuiltInInstanceIndex:
					builtin_instance_idx_id = var.self;
					has_instance_idx = true;
					break;
				case BuiltInBaseInstance:
					builtin_base_instance_id = var.self;
					has_base_instance = true;
					break;
				default:
					break;
				}
			}

			if (need_tesc_params)
			{
				switch (builtin)
				{
				case BuiltInInvocationId:
					builtin_invocation_id_id = var.self;
					has_invocation_id = true;
					break;
				case BuiltInPrimitiveId:
					builtin_primitive_id_id = var.self;
					has_primitive_id = true;
					break;
				default:
					break;
				}
			}

			if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId)
			{
				builtin_subgroup_invocation_id_id = var.self;
				has_subgroup_invocation_id = true;
			}

			if (need_subgroup_ge_mask && builtin == BuiltInSubgroupSize)
			{
				builtin_subgroup_size_id = var.self;
				has_subgroup_size = true;
			}

			if (need_multiview)
			{
				if (builtin == BuiltInInstanceIndex)
				{
					// The view index here is derived from the instance index.
					builtin_instance_idx_id = var.self;
					has_instance_idx = true;
				}

				if (builtin == BuiltInViewIndex)
				{
					builtin_view_idx_id = var.self;
					has_view_idx = true;
				}
			}
		});

		if (!has_frag_coord && need_subpass_input)
		{
			uint32_t offset = ir.increase_bound_by(3);
			uint32_t type_id = offset;
			uint32_t type_ptr_id = offset + 1;
			uint32_t var_id = offset + 2;

			// Create gl_FragCoord.
			SPIRType vec4_type;
			vec4_type.basetype = SPIRType::Float;
			vec4_type.width = 32;
			vec4_type.vecsize = 4;
			set<SPIRType>(type_id, vec4_type);

			SPIRType vec4_type_ptr;
			vec4_type_ptr = vec4_type;
			vec4_type_ptr.pointer = true;
			vec4_type_ptr.parent_type = type_id;
			vec4_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
			ptr_type.self = type_id;

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInFragCoord);
			builtin_frag_coord_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var_id);
		}

		if (!has_sample_id && need_sample_pos)
		{
			uint32_t offset = ir.increase_bound_by(3);
			uint32_t type_id = offset;
			uint32_t type_ptr_id = offset + 1;
			uint32_t var_id = offset + 2;

			// Create gl_SampleID.
			SPIRType uint_type;
			uint_type.basetype = SPIRType::UInt;
			uint_type.width = 32;
			set<SPIRType>(type_id, uint_type);

			SPIRType uint_type_ptr;
			uint_type_ptr = uint_type;
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = type_id;
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = type_id;

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSampleId);
			builtin_sample_id_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var_id);
		}

		if ((need_vertex_params && (!has_vertex_idx || !has_base_vertex || !has_instance_idx || !has_base_instance)) ||
		    (need_multiview && (!has_instance_idx || !has_view_idx)))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_id = offset;
			uint32_t type_ptr_id = offset + 1;

			SPIRType uint_type;
			uint_type.basetype = SPIRType::UInt;
			uint_type.width = 32;
			set<SPIRType>(type_id, uint_type);

			SPIRType uint_type_ptr;
			uint_type_ptr = uint_type;
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = type_id;
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = type_id;

			if (need_vertex_params && !has_vertex_idx)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_VertexIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInVertexIndex);
				builtin_vertex_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var_id);
			}

			if (need_vertex_params && !has_base_vertex)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_BaseVertex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseVertex);
				builtin_base_vertex_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var_id);
			}

			if (!has_instance_idx) // Needed by both multiview and tessellation
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_InstanceIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInInstanceIndex);
				builtin_instance_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var_id);

				if (need_multiview)
				{
					// Multiview shaders are not allowed to write to gl_Layer, ostensibly because
					// it is implicitly written from gl_ViewIndex, but we have to do that explicitly.
					// Note that we can't just abuse gl_ViewIndex for this purpose: it's an input, but
					// gl_Layer is an output in vertex-pipeline shaders.
					uint32_t type_ptr_out_id = ir.increase_bound_by(2);
					SPIRType uint_type_ptr_out;
					uint_type_ptr_out = uint_type;
					uint_type_ptr_out.pointer = true;
					uint_type_ptr_out.parent_type = type_id;
					uint_type_ptr_out.storage = StorageClassOutput;
					auto &ptr_out_type = set<SPIRType>(type_ptr_out_id, uint_type_ptr_out);
					ptr_out_type.self = type_id;
					var_id = type_ptr_out_id + 1;
					set<SPIRVariable>(var_id, type_ptr_out_id, StorageClassOutput);
					set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
					builtin_layer_id = var_id;
					mark_implicit_builtin(StorageClassOutput, BuiltInLayer, var_id);
				}
			}

			if (need_vertex_params && !has_base_instance)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_BaseInstance.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseInstance);
				builtin_base_instance_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var_id);
			}

			if (need_multiview && !has_view_idx)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_ViewIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
				builtin_view_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
			}
		}

		if (need_tesc_params && (!has_invocation_id || !has_primitive_id))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_id = offset;
			uint32_t type_ptr_id = offset + 1;

			SPIRType uint_type;
			uint_type.basetype = SPIRType::UInt;
			uint_type.width = 32;
			set<SPIRType>(type_id, uint_type);

			SPIRType uint_type_ptr;
			uint_type_ptr = uint_type;
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = type_id;
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = type_id;

			if (!has_invocation_id)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_InvocationID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInInvocationId);
				builtin_invocation_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var_id);
			}

			if (!has_primitive_id)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_PrimitiveID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInPrimitiveId);
				builtin_primitive_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var_id);
			}
		}

		if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id))
		{
			uint32_t offset = ir.increase_bound_by(3);
			uint32_t type_id = offset;
			uint32_t type_ptr_id = offset + 1;
			uint32_t var_id = offset + 2;

			// Create gl_SubgroupInvocationID.
			SPIRType uint_type;
			uint_type.basetype = SPIRType::UInt;
			uint_type.width = 32;
			set<SPIRType>(type_id, uint_type);

			SPIRType uint_type_ptr;
			uint_type_ptr = uint_type;
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = type_id;
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = type_id;

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupLocalInvocationId);
			builtin_subgroup_invocation_id_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var_id);
		}

		if (!has_subgroup_size && need_subgroup_ge_mask)
		{
			uint32_t offset = ir.increase_bound_by(3);
			uint32_t type_id = offset;
			uint32_t type_ptr_id = offset + 1;
			uint32_t var_id = offset + 2;

			// Create gl_SubgroupSize.
			SPIRType uint_type;
			uint_type.basetype = SPIRType::UInt;
			uint_type.width = 32;
			set<SPIRType>(type_id, uint_type);

			SPIRType uint_type_ptr;
			uint_type_ptr = uint_type;
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = type_id;
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = type_id;

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupSize);
			builtin_subgroup_size_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id);
		}
	}

	if (needs_swizzle_buffer_def)
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvSwizzleConstants");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, kSwizzleBufferBinding);
		set_decoration(var_id, DecorationBinding, msl_options.swizzle_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.swizzle_buffer_index);
		swizzle_buffer_id = var_id;
	}

	if (!buffers_requiring_array_length.empty())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvBufferSizeConstants");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, kBufferSizeBufferBinding);
		set_decoration(var_id, DecorationBinding, msl_options.buffer_size_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.buffer_size_buffer_index);
		buffer_size_buffer_id = var_id;
	}

	if (needs_view_mask_buffer())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvViewMask");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, ~(4u));
		set_decoration(var_id, DecorationBinding, msl_options.view_mask_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.view_mask_buffer_index);
		view_mask_buffer_id = var_id;
	}
}

void CompilerMSL::mark_implicit_builtin(StorageClass storage, BuiltIn builtin, uint32_t id)
{
	Bitset *active_builtins = nullptr;
	switch (storage)
	{
	case StorageClassInput:
		active_builtins = &active_input_builtins;
		break;

	case StorageClassOutput:
		active_builtins = &active_output_builtins;
		break;

	default:
		break;
	}

	assert(active_builtins != nullptr);
	active_builtins->set(builtin);
	get_entry_point().interface_variables.push_back(id);
}
uint32_t CompilerMSL::build_constant_uint_array_pointer()
{
	uint32_t offset = ir.increase_bound_by(4);
	uint32_t type_id = offset;
	uint32_t type_ptr_id = offset + 1;
	uint32_t type_ptr_ptr_id = offset + 2;
	uint32_t var_id = offset + 3;

	// Create a buffer to hold extra data, including the swizzle constants.
	SPIRType uint_type;
	uint_type.basetype = SPIRType::UInt;
	uint_type.width = 32;
	set<SPIRType>(type_id, uint_type);

	SPIRType uint_type_pointer = uint_type;
	uint_type_pointer.pointer = true;
	uint_type_pointer.pointer_depth = 1;
	uint_type_pointer.parent_type = type_id;
	uint_type_pointer.storage = StorageClassUniform;
	set<SPIRType>(type_ptr_id, uint_type_pointer);
	set_decoration(type_ptr_id, DecorationArrayStride, 4);

	SPIRType uint_type_pointer2 = uint_type_pointer;
	uint_type_pointer2.pointer_depth++;
	uint_type_pointer2.parent_type = type_ptr_id;
	set<SPIRType>(type_ptr_ptr_id, uint_type_pointer2);

	set<SPIRVariable>(var_id, type_ptr_ptr_id, StorageClassUniformConstant);
	return var_id;
}
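// The variable built above is ultimately emitted as a `constant uint*` entry-point
// argument (e.g. spvSwizzleConstants or spvBufferSizeConstants). The pointer-to-pointer
// type mirrors the SPIR-V convention that a variable's type is a pointer to its data
// type; here the data type is itself a native pointer.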

static string create_sampler_address(const char *prefix, MSLSamplerAddress addr)
{
	switch (addr)
	{
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE:
		return join(prefix, "address::clamp_to_edge");
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_ZERO:
		return join(prefix, "address::clamp_to_zero");
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER:
		return join(prefix, "address::clamp_to_border");
	case MSL_SAMPLER_ADDRESS_REPEAT:
		return join(prefix, "address::repeat");
	case MSL_SAMPLER_ADDRESS_MIRRORED_REPEAT:
		return join(prefix, "address::mirrored_repeat");
	default:
		SPIRV_CROSS_THROW("Invalid sampler addressing mode.");
	}
}
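// For instance, create_sampler_address("s_", MSL_SAMPLER_ADDRESS_REPEAT) yields
// "s_address::repeat"; the empty prefix is used when all three axes share one mode,
// yielding e.g. "address::repeat".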

SPIRType &CompilerMSL::get_stage_in_struct_type()
{
	auto &si_var = get<SPIRVariable>(stage_in_var_id);
	return get_variable_data_type(si_var);
}

SPIRType &CompilerMSL::get_stage_out_struct_type()
{
	auto &so_var = get<SPIRVariable>(stage_out_var_id);
	return get_variable_data_type(so_var);
}

SPIRType &CompilerMSL::get_patch_stage_in_struct_type()
{
	auto &si_var = get<SPIRVariable>(patch_stage_in_var_id);
	return get_variable_data_type(si_var);
}

SPIRType &CompilerMSL::get_patch_stage_out_struct_type()
{
	auto &so_var = get<SPIRVariable>(patch_stage_out_var_id);
	return get_variable_data_type(so_var);
}

std::string CompilerMSL::get_tess_factor_struct_name()
{
	if (get_entry_point().flags.get(ExecutionModeTriangles))
		return "MTLTriangleTessellationFactorsHalf";
	return "MTLQuadTessellationFactorsHalf";
}

void CompilerMSL::emit_entry_point_declarations()
{
	// FIXME: Get test coverage here ...

	// Emit constexpr samplers here.
	for (auto &samp : constexpr_samplers_by_id)
	{
		auto &var = get<SPIRVariable>(samp.first);
		auto &type = get<SPIRType>(var.basetype);
		if (type.basetype == SPIRType::Sampler)
			add_resource_name(samp.first);

		SmallVector<string> args;
		auto &s = samp.second;

		if (s.coord != MSL_SAMPLER_COORD_NORMALIZED)
			args.push_back("coord::pixel");

		if (s.min_filter == s.mag_filter)
		{
			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("filter::linear");
		}
		else
		{
			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("min_filter::linear");
			if (s.mag_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("mag_filter::linear");
		}

		switch (s.mip_filter)
		{
		case MSL_SAMPLER_MIP_FILTER_NONE:
			// Default
			break;
		case MSL_SAMPLER_MIP_FILTER_NEAREST:
			args.push_back("mip_filter::nearest");
			break;
		case MSL_SAMPLER_MIP_FILTER_LINEAR:
			args.push_back("mip_filter::linear");
			break;
		default:
			SPIRV_CROSS_THROW("Invalid mip filter.");
		}

		if (s.s_address == s.t_address && s.s_address == s.r_address)
		{
			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("", s.s_address));
		}
		else
		{
			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("s_", s.s_address));
			if (s.t_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("t_", s.t_address));
			if (s.r_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("r_", s.r_address));
		}

		if (s.compare_enable)
		{
			switch (s.compare_func)
			{
			case MSL_SAMPLER_COMPARE_FUNC_ALWAYS:
				args.push_back("compare_func::always");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_NEVER:
				args.push_back("compare_func::never");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_EQUAL:
				args.push_back("compare_func::equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_NOT_EQUAL:
				args.push_back("compare_func::not_equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_LESS:
				args.push_back("compare_func::less");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_LESS_EQUAL:
				args.push_back("compare_func::less_equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_GREATER:
				args.push_back("compare_func::greater");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_GREATER_EQUAL:
				args.push_back("compare_func::greater_equal");
				break;
			default:
				SPIRV_CROSS_THROW("Invalid sampler compare function.");
			}
		}

		if (s.s_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER || s.t_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER ||
		    s.r_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER)
		{
			switch (s.border_color)
			{
			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_BLACK:
				args.push_back("border_color::opaque_black");
				break;
			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_WHITE:
				args.push_back("border_color::opaque_white");
				break;
			case MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK:
				args.push_back("border_color::transparent_black");
				break;
			default:
				SPIRV_CROSS_THROW("Invalid sampler border color.");
			}
		}

		if (s.anisotropy_enable)
			args.push_back(join("max_anisotropy(", s.max_anisotropy, ")"));
		if (s.lod_clamp_enable)
		{
			args.push_back(join("lod_clamp(", convert_to_string(s.lod_clamp_min, current_locale_radix_character), ", ",
			                    convert_to_string(s.lod_clamp_max, current_locale_radix_character), ")"));
		}

		statement("constexpr sampler ",
		          type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
		          "(", merge(args), ");");
	}
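	// A typical emitted declaration (illustrative names only) reads:
	//   constexpr sampler mySamp(filter::linear, mip_filter::linear, address::repeat);
	// Arguments matching Metal's defaults are deliberately omitted above.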

	// Emit buffer arrays here.
	for (uint32_t array_id : buffer_arrays)
	{
		const auto &var = get<SPIRVariable>(array_id);
		const auto &type = get_variable_data_type(var);
		string name = to_name(array_id);
		statement(get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + name + "[] =");
		begin_scope();
		for (uint32_t i = 0; i < type.array[0]; ++i)
			statement(name + "_" + convert_to_string(i) + ",");
		end_scope_decl();
		statement_no_indent("");
	}
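	// The loop above regroups individually passed buffer arguments into a local
	// array, e.g. (names illustrative):
	//   device MyType* bufs[] =
	//   {
	//       bufs_0,
	//       bufs_1,
	//   };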
	// For some reason, without this, we end up emitting the arrays twice.
	buffer_arrays.clear();
}

string CompilerMSL::compile()
{
	// Do not deal with GLES-isms like precision, older extensions and such.
	options.vulkan_semantics = true;
	options.es = false;
	options.version = 450;
	backend.null_pointer_literal = "nullptr";
	backend.float_literal_suffix = false;
	backend.uint32_t_literal_suffix = true;
	backend.int16_t_literal_suffix = "";
	backend.uint16_t_literal_suffix = "u";
	backend.basic_int_type = "int";
	backend.basic_uint_type = "uint";
	backend.basic_int8_type = "char";
	backend.basic_uint8_type = "uchar";
	backend.basic_int16_type = "short";
	backend.basic_uint16_type = "ushort";
	backend.discard_literal = "discard_fragment()";
	backend.swizzle_is_function = false;
	backend.shared_is_implied = false;
	backend.use_initializer_list = true;
	backend.use_typed_initializer_list = true;
	backend.native_row_major_matrix = false;
	backend.unsized_array_supported = false;
	backend.can_declare_arrays_inline = false;
	backend.can_return_array = false;
	backend.boolean_mix_support = false;
	backend.allow_truncated_access_chain = true;
	backend.array_is_value_type = false;
	backend.comparison_image_samples_scalar = true;
	backend.native_pointers = true;
	backend.nonuniform_qualifier = "";
	backend.support_small_type_sampling_result = true;

	capture_output_to_buffer = msl_options.capture_output_to_buffer;
	is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;

	// Initialize array here rather than constructor, MSVC 2013 workaround.
	for (auto &id : next_metal_resource_ids)
		id = 0;

	fixup_type_alias();
	replace_illegal_names();

	struct_member_padding.clear();

	build_function_control_flow_graphs_and_analyze();
	update_active_builtins();
	analyze_image_and_sampler_usage();
	analyze_sampled_image_usage();
	preprocess_op_codes();
	build_implicit_builtins();

	fixup_image_load_store_access();

	set_enabled_interface_variables(get_active_interface_variables());
	if (swizzle_buffer_id)
		active_interface_variables.insert(swizzle_buffer_id);
	if (buffer_size_buffer_id)
		active_interface_variables.insert(buffer_size_buffer_id);
	if (view_mask_buffer_id)
		active_interface_variables.insert(view_mask_buffer_id);
	if (builtin_layer_id)
		active_interface_variables.insert(builtin_layer_id);

	// Create structs to hold input, output and uniform variables.
	// Do output first to ensure out. is declared at top of entry function.
	qual_pos_var_name = "";
	stage_out_var_id = add_interface_block(StorageClassOutput);
	patch_stage_out_var_id = add_interface_block(StorageClassOutput, true);
	stage_in_var_id = add_interface_block(StorageClassInput);
	if (get_execution_model() == ExecutionModelTessellationEvaluation)
		patch_stage_in_var_id = add_interface_block(StorageClassInput, true);

	if (get_execution_model() == ExecutionModelTessellationControl)
		stage_out_ptr_var_id = add_interface_block_pointer(stage_out_var_id, StorageClassOutput);
	if (is_tessellation_shader())
		stage_in_ptr_var_id = add_interface_block_pointer(stage_in_var_id, StorageClassInput);

	// Metal vertex functions that define no output must disable rasterization and return void.
	if (!stage_out_var_id)
		is_rasterization_disabled = true;

	// Convert the use of global variables to recursively-passed function parameters
	localize_global_variables();
	extract_global_variables_from_functions();

	// Mark any non-stage-in structs to be tightly packed.
	mark_packable_structs();
	reorder_type_alias();

	// Add fixup hooks required by shader inputs and outputs. This needs to happen before
	// the loop, so the hooks aren't added multiple times.
	fix_up_shader_inputs_outputs();

	// If we are using argument buffers, we create argument buffer structures for them here.
	// These buffers will be used in the entry point, not the individual resources.
	if (msl_options.argument_buffers)
	{
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("Argument buffers can only be used with MSL 2.0 and up.");
		analyze_argument_buffers();
	}

	uint32_t pass_count = 0;
	do
	{
		if (pass_count >= 3)
			SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");

		reset();

		// Start bindings at zero.
		next_metal_resource_index_buffer = 0;
		next_metal_resource_index_texture = 0;
		next_metal_resource_index_sampler = 0;
		for (auto &id : next_metal_resource_ids)
			id = 0;

		// Move constructor for this type is broken on GCC 4.9 ...
		buffer.reset();

		emit_header();
		emit_specialization_constants_and_structs();
		emit_resources();
		emit_custom_functions();
		emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());

		pass_count++;
	} while (is_forcing_recompilation());

	return buffer.str();
}

// Register the need to output any custom functions.
void CompilerMSL::preprocess_op_codes()
{
	OpCodePreprocessor preproc(*this);
	traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), preproc);

	suppress_missing_prototypes = preproc.suppress_missing_prototypes;

	if (preproc.uses_atomics)
	{
		add_header_line("#include <metal_atomic>");
		add_pragma_line("#pragma clang diagnostic ignored \"-Wunused-variable\"");
	}

	// Metal vertex functions that write to resources must disable rasterization and return void.
	if (preproc.uses_resource_write)
		is_rasterization_disabled = true;

	// Tessellation control shaders are run as compute functions in Metal, and so
	// must capture their output to a buffer.
	if (get_execution_model() == ExecutionModelTessellationControl)
	{
		is_rasterization_disabled = true;
		capture_output_to_buffer = true;
	}

	if (preproc.needs_subgroup_invocation_id)
		needs_subgroup_invocation_id = true;
}

// Move the Private and Workgroup global variables to the entry function.
// Non-constant variables cannot have global scope in Metal.
void CompilerMSL::localize_global_variables()
{
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	auto iter = global_variables.begin();
	while (iter != global_variables.end())
	{
		uint32_t v_id = *iter;
		auto &var = get<SPIRVariable>(v_id);
		if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup)
		{
			if (!variable_is_lut(var))
				entry_func.add_local_variable(v_id);
			iter = global_variables.erase(iter);
		}
		else
			iter++;
	}
}

// For any global variable accessed directly by a function,
// extract that variable and add it as an argument to that function.
void CompilerMSL::extract_global_variables_from_functions()
{
	// Uniforms
	unordered_set<uint32_t> global_var_ids;
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
		if (var.storage == StorageClassInput || var.storage == StorageClassOutput ||
		    var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
		    var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)
		{
			global_var_ids.insert(var.self);
		}
	});

	// Local vars that are declared in the main function and accessed directly by a function
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	for (auto &var : entry_func.local_variables)
		if (get<SPIRVariable>(var).storage != StorageClassFunction)
			global_var_ids.insert(var);

	std::set<uint32_t> added_arg_ids;
	unordered_set<uint32_t> processed_func_ids;
	extract_global_variables_from_function(ir.default_entry_point, added_arg_ids, global_var_ids, processed_func_ids);
}

// MSL does not support the use of global variables for shader input content.
// For any global variable accessed directly by the specified function, extract that variable,
// add it as an argument to that function, and add the arg to the added_arg_ids collection.
void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::set<uint32_t> &added_arg_ids,
                                                         unordered_set<uint32_t> &global_var_ids,
                                                         unordered_set<uint32_t> &processed_func_ids)
{
	// Avoid processing a function more than once
	if (processed_func_ids.find(func_id) != processed_func_ids.end())
	{
		// Return function global variables
		added_arg_ids = function_global_vars[func_id];
		return;
	}

	processed_func_ids.insert(func_id);

	auto &func = get<SPIRFunction>(func_id);

	// Recursively establish global args added to functions on which we depend.
	for (auto block : func.blocks)
	{
		auto &b = get<SPIRBlock>(block);
		for (auto &i : b.ops)
		{
			auto ops = stream(i);
			auto op = static_cast<Op>(i.op);

			switch (op)
			{
			case OpLoad:
			case OpInBoundsAccessChain:
			case OpAccessChain:
			case OpPtrAccessChain:
			case OpArrayLength:
			{
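				// Operand layout for these opcodes: ops[0] = result type, ops[1] = result id,
				// ops[2] = the base pointer (for OpArrayLength, the structure operand).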
				uint32_t base_id = ops[2];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);

				auto &type = get<SPIRType>(ops[0]);
				if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData)
				{
					// Implicitly reads gl_FragCoord.
					assert(builtin_frag_coord_id != 0);
					added_arg_ids.insert(builtin_frag_coord_id);
				}

				break;
			}

			case OpFunctionCall:
			{
				// First see if any of the function call args are globals
				for (uint32_t arg_idx = 3; arg_idx < i.length; arg_idx++)
				{
					uint32_t arg_id = ops[arg_idx];
					if (global_var_ids.find(arg_id) != global_var_ids.end())
						added_arg_ids.insert(arg_id);
				}

				// Then recurse into the function itself to extract globals used internally in the function
				uint32_t inner_func_id = ops[2];
				std::set<uint32_t> inner_func_args;
				extract_global_variables_from_function(inner_func_id, inner_func_args, global_var_ids,
				                                       processed_func_ids);
				added_arg_ids.insert(inner_func_args.begin(), inner_func_args.end());
				break;
			}

			case OpStore:
			{
				uint32_t base_id = ops[0];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				break;
			}

			case OpSelect:
			{
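				// OpSelect operands: ops[2] is the condition; ops[3] and ops[4] are the
				// true/false objects, either of which may be a global pointer.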
				uint32_t base_id = ops[3];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				base_id = ops[4];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				break;
			}

			default:
				break;
			}

			// TODO: Add all other operations which can affect memory.
			// We should consider a more unified system here to reduce boiler-plate.
			// This kind of analysis is done in several places ...
		}
	}

	function_global_vars[func_id] = added_arg_ids;

	// Add the global variables as arguments to the function
	if (func_id != ir.default_entry_point)
	{
		bool added_in = false;
		bool added_out = false;
		for (uint32_t arg_id : added_arg_ids)
		{
			auto &var = get<SPIRVariable>(arg_id);
			uint32_t type_id = var.basetype;
			auto *p_type = &get<SPIRType>(type_id);
			BuiltIn bi_type = BuiltIn(get_decoration(arg_id, DecorationBuiltIn));

			if (((is_tessellation_shader() && var.storage == StorageClassInput) ||
			     (get_execution_model() == ExecutionModelTessellationControl && var.storage == StorageClassOutput)) &&
			    !(has_decoration(arg_id, DecorationPatch) || is_patch_block(*p_type)) &&
			    (!is_builtin_variable(var) || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
			     bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
			     p_type->basetype == SPIRType::Struct))
			{
				// Tessellation control shaders see inputs and per-vertex outputs as arrays.
				// Similarly, tessellation evaluation shaders see per-vertex inputs as arrays.
				// We collected them into a structure; we must pass the array of this
				// structure to the function.
				std::string name;
				if (var.storage == StorageClassInput)
				{
					if (added_in)
						continue;
					name = input_wg_var_name;
					arg_id = stage_in_ptr_var_id;
					added_in = true;
				}
				else if (var.storage == StorageClassOutput)
				{
					if (added_out)
						continue;
					name = "gl_out";
					arg_id = stage_out_ptr_var_id;
					added_out = true;
				}
				type_id = get<SPIRVariable>(arg_id).basetype;
				uint32_t next_id = ir.increase_bound_by(1);
				func.add_parameter(type_id, next_id, true);
				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);

				set_name(next_id, name);
			}
			else if (is_builtin_variable(var) && p_type->basetype == SPIRType::Struct)
			{
				// Get the pointee type
				type_id = get_pointee_type_id(type_id);
				p_type = &get<SPIRType>(type_id);

				uint32_t mbr_idx = 0;
				for (auto &mbr_type_id : p_type->member_types)
				{
					BuiltIn builtin = BuiltInMax;
					bool is_builtin = is_member_builtin(*p_type, mbr_idx, &builtin);
					if (is_builtin && has_active_builtin(builtin, var.storage))
					{
						// Add an arg variable with the same type and decorations as the member
						uint32_t next_ids = ir.increase_bound_by(2);
						uint32_t ptr_type_id = next_ids + 0;
						uint32_t var_id = next_ids + 1;

						// Make sure we have an actual pointer type,
						// so that we will get the appropriate address space when declaring these builtins.
						auto &ptr = set<SPIRType>(ptr_type_id, get<SPIRType>(mbr_type_id));
						ptr.self = mbr_type_id;
						ptr.storage = var.storage;
						ptr.pointer = true;
						ptr.parent_type = mbr_type_id;

						func.add_parameter(mbr_type_id, var_id, true);
						set<SPIRVariable>(var_id, ptr_type_id, StorageClassFunction);
						ir.meta[var_id].decoration = ir.meta[type_id].members[mbr_idx];
					}
					mbr_idx++;
				}
			}
			else
			{
				uint32_t next_id = ir.increase_bound_by(1);
				func.add_parameter(type_id, next_id, true);
				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);

				// Ensure the existing variable has a valid name and the new variable has all the same meta info
				set_name(arg_id, ensure_valid_name(to_name(arg_id), "v"));
				ir.meta[next_id] = ir.meta[arg_id];
			}
		}
	}
}

// For all variables that are some form of non-input-output interface block, mark that all the structs
// that are recursively contained within the type referenced by that variable should be packed tightly.
void CompilerMSL::mark_packable_structs()
{
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
		if (var.storage != StorageClassFunction && !is_hidden_variable(var))
		{
			auto &type = this->get<SPIRType>(var.basetype);
			if (type.pointer &&
			    (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant ||
			     type.storage == StorageClassPushConstant || type.storage == StorageClassStorageBuffer) &&
			    (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
				mark_as_packable(type);
		}
	});
}

// If the specified type is a struct, it and any nested structs
// are marked as packable with the SPIRVCrossDecorationPacked decoration.
void CompilerMSL::mark_as_packable(SPIRType &type)
{
	// If this is not the base type (eg. it's a pointer or array), tunnel down
	if (type.parent_type)
	{
		mark_as_packable(get<SPIRType>(type.parent_type));
		return;
	}

	if (type.basetype == SPIRType::Struct)
	{
		set_extended_decoration(type.self, SPIRVCrossDecorationPacked);

		// Recurse
		size_t mbr_cnt = type.member_types.size();
		for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
		{
			uint32_t mbr_type_id = type.member_types[mbr_idx];
			auto &mbr_type = get<SPIRType>(mbr_type_id);
			mark_as_packable(mbr_type);
			if (mbr_type.type_alias)
			{
				auto &mbr_type_alias = get<SPIRType>(mbr_type.type_alias);
				mark_as_packable(mbr_type_alias);
			}
		}
	}
}

// If a vertex attribute exists at the location, it is marked as being used by this shader
void CompilerMSL::mark_location_as_used_by_shader(uint32_t location, StorageClass storage)
{
	if ((get_execution_model() == ExecutionModelVertex || is_tessellation_shader()) && (storage == StorageClassInput))
		vtx_attrs_in_use.insert(location);
}

uint32_t CompilerMSL::get_target_components_for_fragment_location(uint32_t location) const
{
	auto itr = fragment_output_components.find(location);
	if (itr == end(fragment_output_components))
		return 4;
	else
		return itr->second;
}

uint32_t CompilerMSL::build_extended_vector_type(uint32_t type_id, uint32_t components)
{
	uint32_t new_type_id = ir.increase_bound_by(1);
	auto &type = set<SPIRType>(new_type_id, get<SPIRType>(type_id));
	type.vecsize = components;
	type.self = new_type_id;
	type.parent_type = type_id;
	type.pointer = false;

	return new_type_id;
}
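// Illustrative example: given the type id of a float2 and components = 4, this
// synthesizes a new float4 type whose parent_type points back at the original float2.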
1227 
add_plain_variable_to_interface_block(StorageClass storage,const string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var,bool strip_array)1228 void CompilerMSL::add_plain_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
1229                                                         SPIRType &ib_type, SPIRVariable &var, bool strip_array)
1230 {
1231 	bool is_builtin = is_builtin_variable(var);
1232 	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
1233 	bool is_flat = has_decoration(var.self, DecorationFlat);
1234 	bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
1235 	bool is_centroid = has_decoration(var.self, DecorationCentroid);
1236 	bool is_sample = has_decoration(var.self, DecorationSample);
1237 
1238 	// Add a reference to the variable type to the interface struct.
1239 	uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1240 	uint32_t type_id = ensure_correct_builtin_type(var.basetype, builtin);
1241 	var.basetype = type_id;
1242 
1243 	type_id = get_pointee_type_id(var.basetype);
1244 	if (strip_array && is_array(get<SPIRType>(type_id)))
1245 		type_id = get<SPIRType>(type_id).parent_type;
1246 	auto &type = get<SPIRType>(type_id);
1247 	uint32_t target_components = 0;
1248 	uint32_t type_components = type.vecsize;
1249 	bool padded_output = false;
1250 
1251 	// Check if we need to pad fragment output to match a certain number of components.
1252 	if (get_decoration_bitset(var.self).get(DecorationLocation) && msl_options.pad_fragment_output_components &&
1253 	    get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput)
1254 	{
1255 		uint32_t locn = get_decoration(var.self, DecorationLocation);
1256 		target_components = get_target_components_for_fragment_location(locn);
1257 		if (type_components < target_components)
1258 		{
1259 			// Make a new type here.
1260 			type_id = build_extended_vector_type(type_id, target_components);
1261 			padded_output = true;
1262 		}
1263 	}
1264 
1265 	ib_type.member_types.push_back(type_id);
1266 
1267 	// Give the member a name
1268 	string mbr_name = ensure_valid_name(to_expression(var.self), "m");
1269 	set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1270 
1271 	// Update the original variable reference to include the structure reference
1272 	string qual_var_name = ib_var_ref + "." + mbr_name;
1273 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1274 
1275 	if (padded_output)
1276 	{
1277 		entry_func.add_local_variable(var.self);
1278 		vars_needing_early_declaration.push_back(var.self);
1279 
1280 		entry_func.fixup_hooks_out.push_back([=, &var]() {
1281 			SPIRType &padded_type = this->get<SPIRType>(type_id);
1282 			statement(qual_var_name, " = ", remap_swizzle(padded_type, type_components, to_name(var.self)), ";");
1283 		});
1284 	}
1285 	else if (!strip_array)
1286 		ir.meta[var.self].decoration.qualified_alias = qual_var_name;
1287 
1288 	if (var.storage == StorageClassOutput && var.initializer != 0)
1289 	{
1290 		entry_func.fixup_hooks_in.push_back(
1291 		    [=, &var]() { statement(qual_var_name, " = ", to_expression(var.initializer), ";"); });
1292 	}
1293 
1294 	// Copy the variable location from the original variable to the member
1295 	if (get_decoration_bitset(var.self).get(DecorationLocation))
1296 	{
1297 		uint32_t locn = get_decoration(var.self, DecorationLocation);
1298 		if (storage == StorageClassInput && (get_execution_model() == ExecutionModelVertex || is_tessellation_shader()))
1299 		{
1300 			type_id = ensure_correct_attribute_type(var.basetype, locn);
1301 			var.basetype = type_id;
1302 			type_id = get_pointee_type_id(type_id);
1303 			if (strip_array && is_array(get<SPIRType>(type_id)))
1304 				type_id = get<SPIRType>(type_id).parent_type;
1305 			ib_type.member_types[ib_mbr_idx] = type_id;
1306 		}
1307 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1308 		mark_location_as_used_by_shader(locn, storage);
1309 	}
1310 	else if (is_builtin && is_tessellation_shader() && vtx_attrs_by_builtin.count(builtin))
1311 	{
1312 		uint32_t locn = vtx_attrs_by_builtin[builtin].location;
1313 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1314 		mark_location_as_used_by_shader(locn, storage);
1315 	}
1316 
1317 	if (get_decoration_bitset(var.self).get(DecorationComponent))
1318 	{
1319 		uint32_t comp = get_decoration(var.self, DecorationComponent);
1320 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
1321 	}
1322 
1323 	if (get_decoration_bitset(var.self).get(DecorationIndex))
1324 	{
1325 		uint32_t index = get_decoration(var.self, DecorationIndex);
1326 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
1327 	}
1328 
1329 	// Mark the member as builtin if needed
1330 	if (is_builtin)
1331 	{
1332 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
1333 		if (builtin == BuiltInPosition && storage == StorageClassOutput)
1334 			qual_pos_var_name = qual_var_name;
1335 	}
1336 
1337 	// Copy interpolation decorations if needed
1338 	if (is_flat)
1339 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
1340 	if (is_noperspective)
1341 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
1342 	if (is_centroid)
1343 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
1344 	if (is_sample)
1345 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
1346 
1347 	set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
1348 }
1349 
add_composite_variable_to_interface_block(StorageClass storage,const string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var,bool strip_array)1350 void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
1351                                                             SPIRType &ib_type, SPIRVariable &var, bool strip_array)
1352 {
1353 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1354 	auto &var_type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
1355 	uint32_t elem_cnt = 0;
1356 
1357 	if (is_matrix(var_type))
1358 	{
1359 		if (is_array(var_type))
1360 			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
1361 
1362 		elem_cnt = var_type.columns;
1363 	}
1364 	else if (is_array(var_type))
1365 	{
1366 		if (var_type.array.size() != 1)
1367 			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
1368 
1369 		elem_cnt = to_array_size_literal(var_type);
1370 	}
1371 
1372 	bool is_builtin = is_builtin_variable(var);
1373 	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
1374 	bool is_flat = has_decoration(var.self, DecorationFlat);
1375 	bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
1376 	bool is_centroid = has_decoration(var.self, DecorationCentroid);
1377 	bool is_sample = has_decoration(var.self, DecorationSample);
1378 
1379 	auto *usable_type = &var_type;
1380 	if (usable_type->pointer)
1381 		usable_type = &get<SPIRType>(usable_type->parent_type);
1382 	while (is_array(*usable_type) || is_matrix(*usable_type))
1383 		usable_type = &get<SPIRType>(usable_type->parent_type);
1384 
1385 	// If a builtin, force it to have the proper name.
1386 	if (is_builtin)
1387 		set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));
1388 
1389 	entry_func.add_local_variable(var.self);
1390 
1391 	// We need to declare the variable early and at entry-point scope.
1392 	vars_needing_early_declaration.push_back(var.self);
1393 
1394 	for (uint32_t i = 0; i < elem_cnt; i++)
1395 	{
1396 		// Add a reference to the variable type to the interface struct.
1397 		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1398 
1399 		uint32_t target_components = 0;
1400 		bool padded_output = false;
1401 		uint32_t type_id = usable_type->self;
1402 
1403 		// Check if we need to pad fragment output to match a certain number of components.
1404 		if (get_decoration_bitset(var.self).get(DecorationLocation) && msl_options.pad_fragment_output_components &&
1405 		    get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput)
1406 		{
1407 			uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
1408 			target_components = get_target_components_for_fragment_location(locn);
1409 			if (usable_type->vecsize < target_components)
1410 			{
1411 				// Make a new type here.
1412 				type_id = build_extended_vector_type(usable_type->self, target_components);
1413 				padded_output = true;
1414 			}
1415 		}
1416 
1417 		ib_type.member_types.push_back(get_pointee_type_id(type_id));
1418 
1419 		// Give the member a name
1420 		string mbr_name = ensure_valid_name(join(to_expression(var.self), "_", i), "m");
1421 		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1422 
1423 		// There is no qualified alias since we need to flatten the internal array on return.
1424 		if (get_decoration_bitset(var.self).get(DecorationLocation))
1425 		{
1426 			uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
1427 			if (storage == StorageClassInput &&
1428 			    (get_execution_model() == ExecutionModelVertex || is_tessellation_shader()))
1429 			{
1430 				var.basetype = ensure_correct_attribute_type(var.basetype, locn);
1431 				uint32_t mbr_type_id = ensure_correct_attribute_type(usable_type->self, locn);
1432 				ib_type.member_types[ib_mbr_idx] = mbr_type_id;
1433 			}
1434 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1435 			mark_location_as_used_by_shader(locn, storage);
1436 		}
1437 		else if (is_builtin && is_tessellation_shader() && vtx_attrs_by_builtin.count(builtin))
1438 		{
1439 			uint32_t locn = vtx_attrs_by_builtin[builtin].location + i;
1440 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1441 			mark_location_as_used_by_shader(locn, storage);
1442 		}
1443 
1444 		if (get_decoration_bitset(var.self).get(DecorationIndex))
1445 		{
1446 			uint32_t index = get_decoration(var.self, DecorationIndex);
1447 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
1448 		}
1449 
1450 		// Copy interpolation decorations if needed
1451 		if (is_flat)
1452 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
1453 		if (is_noperspective)
1454 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
1455 		if (is_centroid)
1456 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
1457 		if (is_sample)
1458 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
1459 
1460 		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
1461 
1462 		if (!strip_array)
1463 		{
1464 			switch (storage)
1465 			{
1466 			case StorageClassInput:
1467 				entry_func.fixup_hooks_in.push_back(
1468 				    [=, &var]() { statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, ";"); });
1469 				break;
1470 
1471 			case StorageClassOutput:
1472 				entry_func.fixup_hooks_out.push_back([=, &var]() {
1473 					if (padded_output)
1474 					{
1475 						auto &padded_type = this->get<SPIRType>(type_id);
1476 						statement(
1477 						    ib_var_ref, ".", mbr_name, " = ",
1478 						    remap_swizzle(padded_type, usable_type->vecsize, join(to_name(var.self), "[", i, "]")),
1479 						    ";");
1480 					}
1481 					else
1482 						statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), "[", i, "];");
1483 				});
1484 				break;
1485 
1486 			default:
1487 				break;
1488 			}
1489 		}
1490 	}
1491 }
1492 
1493 uint32_t CompilerMSL::get_accumulated_member_location(const SPIRVariable &var, uint32_t mbr_idx, bool strip_array)
1494 {
1495 	auto &type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
1496 	uint32_t location = get_decoration(var.self, DecorationLocation);
1497 
1498 	for (uint32_t i = 0; i < mbr_idx; i++)
1499 	{
1500 		auto &mbr_type = get<SPIRType>(type.member_types[i]);
1501 
1502 		// Start counting from any place we have a new location decoration.
1503 		if (has_member_decoration(type.self, i, DecorationLocation))
1504 			location = get_member_decoration(type.self, i, DecorationLocation);
1505 
1506 		uint32_t location_count = 1;
1507 
1508 		if (mbr_type.columns > 1)
1509 			location_count = mbr_type.columns;
1510 
1511 		if (!mbr_type.array.empty())
1512 			for (uint32_t j = 0; j < uint32_t(mbr_type.array.size()); j++)
1513 				location_count *= to_array_size_literal(mbr_type, j);
1514 
1515 		location += location_count;
1516 	}
1517 
1518 	return location;
1519 }
1520 
1521 void CompilerMSL::add_composite_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
1522                                                                    SPIRType &ib_type, SPIRVariable &var,
1523                                                                    uint32_t mbr_idx, bool strip_array)
1524 {
1525 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1526 	auto &var_type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
1527 
1528 	BuiltIn builtin = BuiltInMax;
1529 	bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
1530 	bool is_flat =
1531 	    has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
1532 	bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
1533 	                        has_decoration(var.self, DecorationNoPerspective);
1534 	bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
1535 	                   has_decoration(var.self, DecorationCentroid);
1536 	bool is_sample =
1537 	    has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
1538 
1539 	uint32_t mbr_type_id = var_type.member_types[mbr_idx];
1540 	auto &mbr_type = get<SPIRType>(mbr_type_id);
1541 	uint32_t elem_cnt = 0;
1542 
1543 	if (is_matrix(mbr_type))
1544 	{
1545 		if (is_array(mbr_type))
1546 			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
1547 
1548 		elem_cnt = mbr_type.columns;
1549 	}
1550 	else if (is_array(mbr_type))
1551 	{
1552 		if (mbr_type.array.size() != 1)
1553 			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
1554 
1555 		elem_cnt = to_array_size_literal(mbr_type);
1556 	}
1557 
1558 	auto *usable_type = &mbr_type;
1559 	if (usable_type->pointer)
1560 		usable_type = &get<SPIRType>(usable_type->parent_type);
1561 	while (is_array(*usable_type) || is_matrix(*usable_type))
1562 		usable_type = &get<SPIRType>(usable_type->parent_type);
1563 
1564 	for (uint32_t i = 0; i < elem_cnt; i++)
1565 	{
1566 		// Add a reference to the variable type to the interface struct.
1567 		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1568 		ib_type.member_types.push_back(usable_type->self);
1569 
1570 		// Give the member a name
1571 		string mbr_name = ensure_valid_name(join(to_qualified_member_name(var_type, mbr_idx), "_", i), "m");
1572 		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1573 
1574 		if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
1575 		{
1576 			uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation) + i;
1577 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1578 			mark_location_as_used_by_shader(locn, storage);
1579 		}
1580 		else if (has_decoration(var.self, DecorationLocation))
1581 		{
1582 			uint32_t locn = get_accumulated_member_location(var, mbr_idx, strip_array) + i;
1583 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1584 			mark_location_as_used_by_shader(locn, storage);
1585 		}
1586 		else if (is_builtin && is_tessellation_shader() && vtx_attrs_by_builtin.count(builtin))
1587 		{
1588 			uint32_t locn = vtx_attrs_by_builtin[builtin].location + i;
1589 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1590 			mark_location_as_used_by_shader(locn, storage);
1591 		}
1592 
1593 		if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
1594 			SPIRV_CROSS_THROW("DecorationComponent on matrices and arrays makes little sense.");
1595 
1596 		// Copy interpolation decorations if needed
1597 		if (is_flat)
1598 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
1599 		if (is_noperspective)
1600 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
1601 		if (is_centroid)
1602 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
1603 		if (is_sample)
1604 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
1605 
1606 		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
1607 		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
1608 
1609 		// Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
1610 		if (!strip_array)
1611 		{
1612 			switch (storage)
1613 			{
1614 			case StorageClassInput:
1615 				entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
1616 					statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
1617 					          ".", mbr_name, ";");
1618 				});
1619 				break;
1620 
1621 			case StorageClassOutput:
1622 				entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
1623 					statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), ".",
1624 					          to_member_name(var_type, mbr_idx), "[", i, "];");
1625 				});
1626 				break;
1627 
1628 			default:
1629 				break;
1630 			}
1631 		}
1632 	}
1633 }
1634 
1635 void CompilerMSL::add_plain_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
1636                                                                SPIRType &ib_type, SPIRVariable &var, uint32_t mbr_idx,
1637                                                                bool strip_array)
1638 {
1639 	auto &var_type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
1640 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1641 
1642 	BuiltIn builtin = BuiltInMax;
1643 	bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
1644 	bool is_flat =
1645 	    has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
1646 	bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
1647 	                        has_decoration(var.self, DecorationNoPerspective);
1648 	bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
1649 	                   has_decoration(var.self, DecorationCentroid);
1650 	bool is_sample =
1651 	    has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
1652 
1653 	// Add a reference to the member to the interface struct.
1654 	uint32_t mbr_type_id = var_type.member_types[mbr_idx];
1655 	uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1656 	mbr_type_id = ensure_correct_builtin_type(mbr_type_id, builtin);
1657 	var_type.member_types[mbr_idx] = mbr_type_id;
1658 	ib_type.member_types.push_back(mbr_type_id);
1659 
1660 	// Give the member a name
1661 	string mbr_name = ensure_valid_name(to_qualified_member_name(var_type, mbr_idx), "m");
1662 	set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1663 
1664 	// Update the original variable reference to include the structure reference
1665 	string qual_var_name = ib_var_ref + "." + mbr_name;
1666 
1667 	if (is_builtin && !strip_array)
1668 	{
1669 		// For the builtin gl_PerVertex, we cannot treat it as a block anyways,
1670 		// so redirect to qualified name.
1671 		set_member_qualified_name(var_type.self, mbr_idx, qual_var_name);
1672 	}
1673 	else if (!strip_array)
1674 	{
1675 		// Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
1676 		switch (storage)
1677 		{
1678 		case StorageClassInput:
1679 			entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
1680 				statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), " = ", qual_var_name, ";");
1681 			});
1682 			break;
1683 
1684 		case StorageClassOutput:
1685 			entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
1686 				statement(qual_var_name, " = ", to_name(var.self), ".", to_member_name(var_type, mbr_idx), ";");
1687 			});
1688 			break;
1689 
1690 		default:
1691 			break;
1692 		}
1693 	}
1694 
1695 	// Copy the variable location from the original variable to the member
1696 	if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
1697 	{
1698 		uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation);
1699 		if (storage == StorageClassInput && (get_execution_model() == ExecutionModelVertex || is_tessellation_shader()))
1700 		{
1701 			mbr_type_id = ensure_correct_attribute_type(mbr_type_id, locn);
1702 			var_type.member_types[mbr_idx] = mbr_type_id;
1703 			ib_type.member_types[ib_mbr_idx] = mbr_type_id;
1704 		}
1705 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1706 		mark_location_as_used_by_shader(locn, storage);
1707 	}
1708 	else if (has_decoration(var.self, DecorationLocation))
1709 	{
1710 		// The block itself might have a location and in this case, all members of the block
1711 		// receive incrementing locations.
1712 		uint32_t locn = get_accumulated_member_location(var, mbr_idx, strip_array);
1713 		if (storage == StorageClassInput && (get_execution_model() == ExecutionModelVertex || is_tessellation_shader()))
1714 		{
1715 			mbr_type_id = ensure_correct_attribute_type(mbr_type_id, locn);
1716 			var_type.member_types[mbr_idx] = mbr_type_id;
1717 			ib_type.member_types[ib_mbr_idx] = mbr_type_id;
1718 		}
1719 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1720 		mark_location_as_used_by_shader(locn, storage);
1721 	}
1722 	else if (is_builtin && is_tessellation_shader() && vtx_attrs_by_builtin.count(builtin))
1723 	{
1724 		uint32_t locn = 0;
1725 		auto builtin_itr = vtx_attrs_by_builtin.find(builtin);
1726 		if (builtin_itr != end(vtx_attrs_by_builtin))
1727 			locn = builtin_itr->second.location;
1728 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1729 		mark_location_as_used_by_shader(locn, storage);
1730 	}
1731 
1732 	// Copy the component location, if present.
1733 	if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
1734 	{
1735 		uint32_t comp = get_member_decoration(var_type.self, mbr_idx, DecorationComponent);
1736 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
1737 	}
1738 
1739 	// Mark the member as builtin if needed
1740 	if (is_builtin)
1741 	{
1742 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
1743 		if (builtin == BuiltInPosition && storage == StorageClassOutput)
1744 			qual_pos_var_name = qual_var_name;
1745 	}
1746 
1747 	// Copy interpolation decorations if needed
1748 	if (is_flat)
1749 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
1750 	if (is_noperspective)
1751 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
1752 	if (is_centroid)
1753 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
1754 	if (is_sample)
1755 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
1756 
1757 	set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
1758 	set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
1759 }
1760 
1761 // In Metal, the tessellation levels are stored as tightly packed half-precision floating point values.
1762 // But, stage-in attribute offsets and strides must be multiples of four, so we can't pass the levels
1763 // individually. Therefore, we must pass them as vectors. Triangles get a single float4, with the outer
1764 // levels in 'xyz' and the inner level in 'w'. Quads get a float4 containing the outer levels and a
1765 // float2 containing the inner levels.
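// As a sketch (struct name and attribute index are illustrative), a triangles
// patch input might reach the shader as:
//
//     struct main0_patchIn
//     {
//         float4 gl_TessLevel [[attribute(0)]];
//     };
//
// with gl_TessLevelOuter[0..2] unpacked from .xyz and gl_TessLevelInner[0]
// from .w by the fixup hooks added below.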
1766 void CompilerMSL::add_tess_level_input_to_interface_block(const std::string &ib_var_ref, SPIRType &ib_type,
1767                                                           SPIRVariable &var)
1768 {
1769 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1770 	auto &var_type = get_variable_element_type(var);
1771 
1772 	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
1773 
1774 	// Force the variable to have the proper name.
1775 	set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));
1776 
1777 	if (get_entry_point().flags.get(ExecutionModeTriangles))
1778 	{
1779 		// Triangles are tricky, because we want only one member in the struct.
1780 
1781 		// We need to declare the variable early and at entry-point scope.
1782 		entry_func.add_local_variable(var.self);
1783 		vars_needing_early_declaration.push_back(var.self);
1784 
1785 		string mbr_name = "gl_TessLevel";
1786 
1787 		// If we already added the other one, we can skip this step.
1788 		if (!added_builtin_tess_level)
1789 		{
1790 			// Add a reference to the variable type to the interface struct.
1791 			uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1792 
1793 			uint32_t type_id = build_extended_vector_type(var_type.self, 4);
1794 
1795 			ib_type.member_types.push_back(type_id);
1796 
1797 			// Give the member a name
1798 			set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1799 
1800 			// There is no qualified alias since we need to flatten the internal array on return.
1801 			if (get_decoration_bitset(var.self).get(DecorationLocation))
1802 			{
1803 				uint32_t locn = get_decoration(var.self, DecorationLocation);
1804 				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1805 				mark_location_as_used_by_shader(locn, StorageClassInput);
1806 			}
1807 			else if (vtx_attrs_by_builtin.count(builtin))
1808 			{
1809 				uint32_t locn = vtx_attrs_by_builtin[builtin].location;
1810 				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1811 				mark_location_as_used_by_shader(locn, StorageClassInput);
1812 			}
1813 
1814 			added_builtin_tess_level = true;
1815 		}
1816 
1817 		switch (builtin)
1818 		{
1819 		case BuiltInTessLevelOuter:
1820 			entry_func.fixup_hooks_in.push_back([=, &var]() {
1821 				statement(to_name(var.self), "[0] = ", ib_var_ref, ".", mbr_name, ".x;");
1822 				statement(to_name(var.self), "[1] = ", ib_var_ref, ".", mbr_name, ".y;");
1823 				statement(to_name(var.self), "[2] = ", ib_var_ref, ".", mbr_name, ".z;");
1824 			});
1825 			break;
1826 
1827 		case BuiltInTessLevelInner:
1828 			entry_func.fixup_hooks_in.push_back(
1829 			    [=, &var]() { statement(to_name(var.self), "[0] = ", ib_var_ref, ".", mbr_name, ".w;"); });
1830 			break;
1831 
1832 		default:
1833 			assert(false);
1834 			break;
1835 		}
1836 	}
1837 	else
1838 	{
1839 		// Add a reference to the variable type to the interface struct.
1840 		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1841 
1842 		uint32_t type_id = build_extended_vector_type(var_type.self, builtin == BuiltInTessLevelOuter ? 4 : 2);
1843 		// Change the type of the variable, too.
1844 		uint32_t ptr_type_id = ir.increase_bound_by(1);
1845 		auto &new_var_type = set<SPIRType>(ptr_type_id, get<SPIRType>(type_id));
1846 		new_var_type.pointer = true;
1847 		new_var_type.storage = StorageClassInput;
1848 		new_var_type.parent_type = type_id;
1849 		var.basetype = ptr_type_id;
1850 
1851 		ib_type.member_types.push_back(type_id);
1852 
1853 		// Give the member a name
1854 		string mbr_name = to_expression(var.self);
1855 		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1856 
1857 		// Since vectors can be indexed like arrays, there is no need to unpack this. We can
1858 		// just refer to the vector directly. So give it a qualified alias.
1859 		string qual_var_name = ib_var_ref + "." + mbr_name;
1860 		ir.meta[var.self].decoration.qualified_alias = qual_var_name;
1861 
1862 		if (get_decoration_bitset(var.self).get(DecorationLocation))
1863 		{
1864 			uint32_t locn = get_decoration(var.self, DecorationLocation);
1865 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1866 			mark_location_as_used_by_shader(locn, StorageClassInput);
1867 		}
1868 		else if (vtx_attrs_by_builtin.count(builtin))
1869 		{
1870 			uint32_t locn = vtx_attrs_by_builtin[builtin].location;
1871 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1872 			mark_location_as_used_by_shader(locn, StorageClassInput);
1873 		}
1874 	}
1875 }
1876 
1877 void CompilerMSL::add_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, SPIRType &ib_type,
1878                                                   SPIRVariable &var, bool strip_array)
1879 {
1880 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1881 	// Tessellation control I/O variables and tessellation evaluation per-point inputs are
1882 	// usually declared as arrays. In these cases, we want to add the element type to the
1883 	// interface block, since in Metal it's the interface block itself which is arrayed.
1884 	auto &var_type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
1885 	bool is_builtin = is_builtin_variable(var);
1886 	auto builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
1887 
1888 	if (var_type.basetype == SPIRType::Struct)
1889 	{
1890 		if (!is_builtin_type(var_type) && (!capture_output_to_buffer || storage == StorageClassInput) && !strip_array)
1891 		{
1892 			// For I/O blocks or structs, we will need to pass the block itself around
1893 			// to functions if they are used globally in leaf functions.
1894 			// Rather than passing down member by member,
1895 			// we unflatten I/O blocks while running the shader,
1896 			// and pass the actual struct type down to leaf functions.
1897 			// We then unflatten inputs, and flatten outputs in the "fixup" stages.
1898 			entry_func.add_local_variable(var.self);
1899 			vars_needing_early_declaration.push_back(var.self);
1900 		}
1901 
1902 		if (capture_output_to_buffer && storage != StorageClassInput && !has_decoration(var_type.self, DecorationBlock))
1903 		{
1904 			// In Metal tessellation shaders, the interface block itself is arrayed. This makes things
1905 			// very complicated, since stage-in structures in MSL don't support nested structures.
1906 			// Luckily, for stage-out when capturing output, we can avoid this and just add
1907 			// composite members directly, because the stage-out structure is stored to a buffer,
1908 			// not returned.
1909 			add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, strip_array);
1910 		}
1911 		else
1912 		{
1913 			// Flatten the struct members into the interface struct
1914 			for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
1915 			{
1916 				builtin = BuiltInMax;
1917 				is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
1918 				auto &mbr_type = get<SPIRType>(var_type.member_types[mbr_idx]);
1919 
1920 				if (!is_builtin || has_active_builtin(builtin, storage))
1921 				{
1922 					if ((!is_builtin ||
1923 					     (storage == StorageClassInput && get_execution_model() != ExecutionModelFragment)) &&
1924 					    (storage == StorageClassInput || storage == StorageClassOutput) &&
1925 					    (is_matrix(mbr_type) || is_array(mbr_type)))
1926 					{
1927 						add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx,
1928 						                                                 strip_array);
1929 					}
1930 					else
1931 					{
1932 						add_plain_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx,
1933 						                                             strip_array);
1934 					}
1935 				}
1936 			}
1937 		}
1938 	}
1939 	else if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput &&
1940 	         !strip_array && is_builtin && (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner))
1941 	{
1942 		add_tess_level_input_to_interface_block(ib_var_ref, ib_type, var);
1943 	}
1944 	else if (var_type.basetype == SPIRType::Boolean || var_type.basetype == SPIRType::Char ||
1945 	         type_is_integral(var_type) || type_is_floating_point(var_type))
1946 	{
1947 		if (!is_builtin || has_active_builtin(builtin, storage))
1948 		{
1949 			// MSL does not allow matrices or arrays in input or output variables, so need to handle it specially.
1950 			if ((!is_builtin || (storage == StorageClassInput && get_execution_model() != ExecutionModelFragment)) &&
1951 			    (storage == StorageClassInput || (storage == StorageClassOutput && !capture_output_to_buffer)) &&
1952 			    (is_matrix(var_type) || is_array(var_type)))
1953 			{
1954 				add_composite_variable_to_interface_block(storage, ib_var_ref, ib_type, var, strip_array);
1955 			}
1956 			else
1957 			{
1958 				add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, strip_array);
1959 			}
1960 		}
1961 	}
1962 }
1963 
1964 // Fix up the mapping of variables to interface member indices, which is used to compile access chains
1965 // for per-vertex variables in a tessellation control shader.
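// For example, a flattened per-vertex input declared as float foo[2] appears in
// the interface block as two members (foo_0 and foo_1); the access chain for foo
// must be rebased onto the interface index of the first of those members.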
1966 void CompilerMSL::fix_up_interface_member_indices(StorageClass storage, uint32_t ib_type_id)
1967 {
1968 	// Only needed for tessellation shaders.
1969 	if (get_execution_model() != ExecutionModelTessellationControl &&
1970 	    !(get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput))
1971 		return;
1972 
1973 	bool in_array = false;
1974 	for (uint32_t i = 0; i < ir.meta[ib_type_id].members.size(); i++)
1975 	{
1976 		auto &mbr_dec = ir.meta[ib_type_id].members[i];
1977 		uint32_t var_id = mbr_dec.extended.ib_orig_id;
1978 		if (!var_id)
1979 			continue;
1980 		auto &var = get<SPIRVariable>(var_id);
1981 
1982 		// Unfortunately, all this complexity is needed to handle flattened structs and/or
1983 		// arrays.
1984 		if (storage == StorageClassInput)
1985 		{
1986 			auto &type = get_variable_element_type(var);
1987 			if (is_array(type) || is_matrix(type))
1988 			{
1989 				if (in_array)
1990 					continue;
1991 				in_array = true;
1992 				set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i);
1993 			}
1994 			else
1995 			{
1996 				if (type.basetype == SPIRType::Struct)
1997 				{
1998 					uint32_t mbr_idx =
1999 					    get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceMemberIndex);
2000 					auto &mbr_type = get<SPIRType>(type.member_types[mbr_idx]);
2001 
2002 					if (is_array(mbr_type) || is_matrix(mbr_type))
2003 					{
2004 						if (in_array)
2005 							continue;
2006 						in_array = true;
2007 						set_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, i);
2008 					}
2009 					else
2010 					{
2011 						in_array = false;
2012 						set_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, i);
2013 					}
2014 				}
2015 				else
2016 				{
2017 					in_array = false;
2018 					set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i);
2019 				}
2020 			}
2021 		}
2022 		else
2023 			set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i);
2024 	}
2025 }
2026 
2027 // Add an interface structure for the type of storage, which is either StorageClassInput or StorageClassOutput.
2028 // Returns the ID of the newly added variable, or zero if no variable was added.
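// As a rough sketch (variable names are illustrative), a vertex shader with a
// single float4 input at location 0 yields a block like:
//
//     struct main0_in
//     {
//         float4 position [[attribute(0)]];
//     };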
2029 uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch)
2030 {
2031 	// Accumulate the variables that should appear in the interface struct.
2032 	SmallVector<SPIRVariable *> vars;
2033 	bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader();
2034 	bool has_seen_barycentric = false;
2035 
2036 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
2037 		if (var.storage != storage)
2038 			return;
2039 
2040 		auto &type = this->get<SPIRType>(var.basetype);
2041 
2042 		bool is_builtin = is_builtin_variable(var);
2043 		auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
2044 
2045 		// These builtins are part of the stage in/out structs.
2046 		bool is_interface_block_builtin =
2047 		    (bi_type == BuiltInPosition || bi_type == BuiltInPointSize || bi_type == BuiltInClipDistance ||
2048 		     bi_type == BuiltInCullDistance || bi_type == BuiltInLayer || bi_type == BuiltInViewportIndex ||
2049 		     bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV || bi_type == BuiltInFragDepth ||
2050 		     bi_type == BuiltInFragStencilRefEXT || bi_type == BuiltInSampleMask) ||
2051 		    (get_execution_model() == ExecutionModelTessellationEvaluation &&
2052 		     (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner));
2053 
2054 		bool is_active = interface_variable_exists_in_entry_point(var.self);
2055 		if (is_builtin && is_active)
2056 		{
2057 			// Only emit the builtin if it's active in this entry point. Interface variable list might lie.
2058 			is_active = has_active_builtin(bi_type, storage);
2059 		}
2060 
2061 		bool filter_patch_decoration = (has_decoration(var_id, DecorationPatch) || is_patch_block(type)) == patch;
2062 
2063 		bool hidden = is_hidden_variable(var, incl_builtins);
2064 		// Barycentric inputs must be emitted in stage-in, because they can have interpolation arguments.
2065 		if (is_active && (bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV))
2066 		{
2067 			if (has_seen_barycentric)
2068 				SPIRV_CROSS_THROW("Cannot declare both BaryCoordNV and BaryCoordNoPerspNV in same shader in MSL.");
2069 			has_seen_barycentric = true;
2070 			hidden = false;
2071 		}
2072 
2073 		if (is_active && !hidden && type.pointer && filter_patch_decoration &&
2074 		    (!is_builtin || is_interface_block_builtin))
2075 		{
2076 			vars.push_back(&var);
2077 		}
2078 	});
2079 
2080 	// If no variables qualify, leave.
2081 	// For patch input in a tessellation evaluation shader, the per-vertex stage inputs
2082 	// are included in a special patch control point array.
2083 	if (vars.empty() && !(storage == StorageClassInput && patch && stage_in_var_id))
2084 		return 0;
2085 
2086 	// Add a new typed variable for this interface structure.
2087 	// The initializer expression is allocated here, but populated when the function
2088 	// declaration is emitted, because it is cleared after each compilation pass.
2089 	uint32_t next_id = ir.increase_bound_by(3);
2090 	uint32_t ib_type_id = next_id++;
2091 	auto &ib_type = set<SPIRType>(ib_type_id);
2092 	ib_type.basetype = SPIRType::Struct;
2093 	ib_type.storage = storage;
2094 	set_decoration(ib_type_id, DecorationBlock);
2095 
2096 	uint32_t ib_var_id = next_id++;
2097 	auto &var = set<SPIRVariable>(ib_var_id, ib_type_id, storage, 0);
2098 	var.initializer = next_id++;
2099 
2100 	string ib_var_ref;
2101 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2102 	switch (storage)
2103 	{
2104 	case StorageClassInput:
2105 		ib_var_ref = patch ? patch_stage_in_var_name : stage_in_var_name;
2106 		if (get_execution_model() == ExecutionModelTessellationControl)
2107 		{
2108 			// Add a hook to populate the shared workgroup memory containing
2109 			// the gl_in array.
2110 			entry_func.fixup_hooks_in.push_back([=]() {
2111 				// Can't use PatchVertices yet; the hook for that may not have run yet.
2112 				statement("if (", to_expression(builtin_invocation_id_id), " < ", "spvIndirectParams[0])");
2113 				statement("    ", input_wg_var_name, "[", to_expression(builtin_invocation_id_id), "] = ", ib_var_ref,
2114 				          ";");
2115 				statement("threadgroup_barrier(mem_flags::mem_threadgroup);");
2116 				statement("if (", to_expression(builtin_invocation_id_id), " >= ", get_entry_point().output_vertices,
2117 				          ")");
2118 				statement("    return;");
2119 			});
2120 		}
2121 		break;
2122 
2123 	case StorageClassOutput:
2124 	{
2125 		ib_var_ref = patch ? patch_stage_out_var_name : stage_out_var_name;
2126 
2127 		// Add the output interface struct as a local variable to the entry function.
2128 		// If the entry point should return the output struct, set the entry function
2129 		// to return the output interface struct, otherwise to return nothing.
2130 		// Indicate the output var requires early initialization.
2131 		bool ep_should_return_output = !get_is_rasterization_disabled();
2132 		uint32_t rtn_id = ep_should_return_output ? ib_var_id : 0;
2133 		if (!capture_output_to_buffer)
2134 		{
2135 			entry_func.add_local_variable(ib_var_id);
2136 			for (auto &blk_id : entry_func.blocks)
2137 			{
2138 				auto &blk = get<SPIRBlock>(blk_id);
2139 				if (blk.terminator == SPIRBlock::Return)
2140 					blk.return_value = rtn_id;
2141 			}
2142 			vars_needing_early_declaration.push_back(ib_var_id);
2143 		}
2144 		else
2145 		{
2146 			switch (get_execution_model())
2147 			{
2148 			case ExecutionModelVertex:
2149 			case ExecutionModelTessellationEvaluation:
2150 				// Instead of declaring a struct variable to hold the output and then
2151 				// copying that to the output buffer, we'll declare the output variable
2152 				// as a reference to the final output element in the buffer. Then we can
2153 				// avoid the extra copy.
2154 				entry_func.fixup_hooks_in.push_back([=]() {
2155 					if (stage_out_var_id)
2156 					{
2157 						// The first member of the indirect buffer is always the number of vertices
2158 						// to draw.
2159 						statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref, " = ",
2160 						          output_buffer_var_name, "[(", to_expression(builtin_instance_idx_id), " - ",
2161 						          to_expression(builtin_base_instance_id), ") * spvIndirectParams[0] + ",
2162 						          to_expression(builtin_vertex_idx_id), " - ", to_expression(builtin_base_vertex_id),
2163 						          "];");
2164 					}
2165 				});
2166 				break;
2167 			case ExecutionModelTessellationControl:
2168 				if (patch)
2169 					entry_func.fixup_hooks_in.push_back([=]() {
2170 						statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref, " = ",
2171 						          patch_output_buffer_var_name, "[", to_expression(builtin_primitive_id_id), "];");
2172 					});
2173 				else
2174 					entry_func.fixup_hooks_in.push_back([=]() {
2175 						statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
2176 						          output_buffer_var_name, "[", to_expression(builtin_primitive_id_id), " * ",
2177 						          get_entry_point().output_vertices, "];");
2178 					});
2179 				break;
2180 			default:
2181 				break;
2182 			}
2183 		}
2184 		break;
2185 	}
2186 
2187 	default:
2188 		break;
2189 	}
2190 
2191 	set_name(ib_type_id, to_name(ir.default_entry_point) + "_" + ib_var_ref);
2192 	set_name(ib_var_id, ib_var_ref);
2193 
2194 	for (auto *p_var : vars)
2195 	{
2196 		bool strip_array =
2197 		    (get_execution_model() == ExecutionModelTessellationControl ||
2198 		     (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput)) &&
2199 		    !patch;
2200 		add_variable_to_interface_block(storage, ib_var_ref, ib_type, *p_var, strip_array);
2201 	}
2202 
2203 	// Sort the members of the structure by their locations.
2204 	MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Location);
2205 	member_sorter.sort();
2206 
2207 	// The member indices were saved to the original variables, but after the members
2208 	// were sorted, those indices are now likely incorrect. Fix those up now.
2209 	if (!patch)
2210 		fix_up_interface_member_indices(storage, ib_type_id);
2211 
2212 	// For patch inputs, add one more member, holding the array of control point data.
2213 	if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput && patch &&
2214 	    stage_in_var_id)
2215 	{
2216 		uint32_t pcp_type_id = ir.increase_bound_by(1);
2217 		auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
2218 		pcp_type.basetype = SPIRType::ControlPointArray;
2219 		pcp_type.parent_type = pcp_type.type_alias = get_stage_in_struct_type().self;
2220 		pcp_type.storage = storage;
2221 		ir.meta[pcp_type_id] = ir.meta[ib_type.self];
2222 		uint32_t mbr_idx = uint32_t(ib_type.member_types.size());
2223 		ib_type.member_types.push_back(pcp_type_id);
2224 		set_member_name(ib_type.self, mbr_idx, "gl_in");
2225 	}
2226 
2227 	return ib_var_id;
2228 }
2229 
2230 uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageClass storage)
2231 {
2232 	if (!ib_var_id)
2233 		return 0;
2234 
2235 	uint32_t ib_ptr_var_id;
2236 	uint32_t next_id = ir.increase_bound_by(3);
2237 	auto &ib_type = expression_type(ib_var_id);
2238 	if (get_execution_model() == ExecutionModelTessellationControl)
2239 	{
2240 		// Tessellation control per-vertex I/O is presented as an array, so we must
2241 		// do the same with our struct here.
2242 		uint32_t ib_ptr_type_id = next_id++;
2243 		auto &ib_ptr_type = set<SPIRType>(ib_ptr_type_id, ib_type);
2244 		ib_ptr_type.parent_type = ib_ptr_type.type_alias = ib_type.self;
2245 		ib_ptr_type.pointer = true;
2246 		ib_ptr_type.storage = storage == StorageClassInput ? StorageClassWorkgroup : StorageClassStorageBuffer;
2247 		ir.meta[ib_ptr_type_id] = ir.meta[ib_type.self];
2248 		// To ensure that get_variable_data_type() doesn't strip off the pointer,
2249 		// which we need, use another pointer.
2250 		uint32_t ib_ptr_ptr_type_id = next_id++;
2251 		auto &ib_ptr_ptr_type = set<SPIRType>(ib_ptr_ptr_type_id, ib_ptr_type);
2252 		ib_ptr_ptr_type.parent_type = ib_ptr_type_id;
2253 		ib_ptr_ptr_type.type_alias = ib_type.self;
2254 		ib_ptr_ptr_type.storage = StorageClassFunction;
2255 		ir.meta[ib_ptr_ptr_type_id] = ir.meta[ib_type.self];
2256 
2257 		ib_ptr_var_id = next_id;
2258 		set<SPIRVariable>(ib_ptr_var_id, ib_ptr_ptr_type_id, StorageClassFunction, 0);
2259 		set_name(ib_ptr_var_id, storage == StorageClassInput ? input_wg_var_name : "gl_out");
2260 	}
2261 	else
2262 	{
2263 		// Tessellation evaluation per-vertex inputs are also presented as arrays.
2264 		// But, in Metal, this array uses a very special type, 'patch_control_point<T>',
2265 		// which is a container that can be used to access the control point data.
2266 		// To represent this, a special 'ControlPointArray' type has been added to the
2267 		// SPIRV-Cross type system. It should only be generated by and seen in the MSL
2268 		// backend (i.e. this one).
2269 		uint32_t pcp_type_id = next_id++;
2270 		auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
2271 		pcp_type.basetype = SPIRType::ControlPointArray;
2272 		pcp_type.parent_type = pcp_type.type_alias = ib_type.self;
2273 		pcp_type.storage = storage;
2274 		ir.meta[pcp_type_id] = ir.meta[ib_type.self];
2275 
2276 		ib_ptr_var_id = next_id;
2277 		set<SPIRVariable>(ib_ptr_var_id, pcp_type_id, storage, 0);
2278 		set_name(ib_ptr_var_id, "gl_in");
2279 		ir.meta[ib_ptr_var_id].decoration.qualified_alias = join(patch_stage_in_var_name, ".gl_in");
2280 	}
2281 	return ib_ptr_var_id;
2282 }
2283 
2284 // Ensure that the type is compatible with the builtin.
2285 // If it is, simply return the given type ID.
2286 // Otherwise, create a new type, and return its ID.
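// For example, a shader declaring gl_Layer as a signed int gets a fresh uint
// type here, since Metal's [[render_target_array_index]] must be unsigned.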
2287 uint32_t CompilerMSL::ensure_correct_builtin_type(uint32_t type_id, BuiltIn builtin)
2288 {
2289 	auto &type = get<SPIRType>(type_id);
2290 
2291 	if ((builtin == BuiltInSampleMask && is_array(type)) ||
2292 	    ((builtin == BuiltInLayer || builtin == BuiltInViewportIndex || builtin == BuiltInFragStencilRefEXT) &&
2293 	     type.basetype != SPIRType::UInt))
2294 	{
2295 		uint32_t next_id = ir.increase_bound_by(type.pointer ? 2 : 1);
2296 		uint32_t base_type_id = next_id++;
2297 		auto &base_type = set<SPIRType>(base_type_id);
2298 		base_type.basetype = SPIRType::UInt;
2299 		base_type.width = 32;
2300 
2301 		if (!type.pointer)
2302 			return base_type_id;
2303 
2304 		uint32_t ptr_type_id = next_id++;
2305 		auto &ptr_type = set<SPIRType>(ptr_type_id);
2306 		ptr_type = base_type;
2307 		ptr_type.pointer = true;
2308 		ptr_type.storage = type.storage;
2309 		ptr_type.parent_type = base_type_id;
2310 		return ptr_type_id;
2311 	}
2312 
2313 	return type_id;
2314 }
2315 
2316 // Ensure that the type is compatible with the vertex attribute.
2317 // If it is, simply return the given type ID.
2318 // Otherwise, create a new type, and return its ID.
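// For example, if the host binds MSL_VERTEX_FORMAT_UINT8 data to an attribute
// the shader declares as int, the declared type is rewritten to uint so the
// unsigned source data is fetched correctly.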
2319 uint32_t CompilerMSL::ensure_correct_attribute_type(uint32_t type_id, uint32_t location)
2320 {
2321 	auto &type = get<SPIRType>(type_id);
2322 
2323 	auto p_va = vtx_attrs_by_location.find(location);
2324 	if (p_va == end(vtx_attrs_by_location))
2325 		return type_id;
2326 
2327 	switch (p_va->second.format)
2328 	{
2329 	case MSL_VERTEX_FORMAT_UINT8:
2330 	{
2331 		switch (type.basetype)
2332 		{
2333 		case SPIRType::UByte:
2334 		case SPIRType::UShort:
2335 		case SPIRType::UInt:
2336 			return type_id;
2337 		case SPIRType::Short:
2338 		case SPIRType::Int:
2339 			break;
2340 		default:
2341 			SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
2342 		}
2343 		uint32_t next_id = ir.increase_bound_by(type.pointer ? 2 : 1);
2344 		uint32_t base_type_id = next_id++;
2345 		auto &base_type = set<SPIRType>(base_type_id);
2346 		base_type = type;
2347 		base_type.basetype = type.basetype == SPIRType::Short ? SPIRType::UShort : SPIRType::UInt;
2348 		base_type.pointer = false;
2349 
2350 		if (!type.pointer)
2351 			return base_type_id;
2352 
2353 		uint32_t ptr_type_id = next_id++;
2354 		auto &ptr_type = set<SPIRType>(ptr_type_id);
2355 		ptr_type = base_type;
2356 		ptr_type.pointer = true;
2357 		ptr_type.storage = type.storage;
2358 		ptr_type.parent_type = base_type_id;
2359 		return ptr_type_id;
2360 	}
2361 
2362 	case MSL_VERTEX_FORMAT_UINT16:
2363 	{
2364 		switch (type.basetype)
2365 		{
2366 		case SPIRType::UShort:
2367 		case SPIRType::UInt:
2368 			return type_id;
2369 		case SPIRType::Int:
2370 			break;
2371 		default:
2372 			SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
2373 		}
2374 		uint32_t next_id = ir.increase_bound_by(type.pointer ? 2 : 1);
2375 		uint32_t base_type_id = next_id++;
2376 		auto &base_type = set<SPIRType>(base_type_id);
2377 		base_type = type;
2378 		base_type.basetype = SPIRType::UInt;
2379 		base_type.pointer = false;
2380 
2381 		if (!type.pointer)
2382 			return base_type_id;
2383 
2384 		uint32_t ptr_type_id = next_id++;
2385 		auto &ptr_type = set<SPIRType>(ptr_type_id);
2386 		ptr_type = base_type;
2387 		ptr_type.pointer = true;
2388 		ptr_type.storage = type.storage;
2389 		ptr_type.parent_type = base_type_id;
2390 		return ptr_type_id;
2391 	}
2392 
2393 	default:
2394 	case MSL_VERTEX_FORMAT_OTHER:
2395 		break;
2396 	}
2397 
2398 	return type_id;
2399 }
2400 
2401 // Sort the members of the struct type by offset, and pack and then pad members where needed
2402 // to align MSL members with SPIR-V offsets. The struct members are iterated twice. Packing
2403 // occurs first, followed by padding, because packing a member reduces both its size and its
2404 // natural alignment, possibly requiring a padding member to be added ahead of it.
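// A sketch of the intent (offsets from SPIR-V, member names illustrative):
//
//     { vec3 a @ 0; float b @ 12; }  ->  { packed_float3 a; float b; }
//     { float a @ 0; vec4 b @ 16; }  ->  { float a; char pad[12]; float4 b; }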
2405 void CompilerMSL::align_struct(SPIRType &ib_type)
2406 {
2407 	uint32_t &ib_type_id = ib_type.self;
2408 
2409 	// Sort the members of the interface structure by their offset.
2410 	// They should already be sorted per SPIR-V spec anyway.
2411 	MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Offset);
2412 	member_sorter.sort();
2413 
2414 	uint32_t mbr_cnt = uint32_t(ib_type.member_types.size());
2415 
2416 	// Test the alignment of each member, and if a member should be closer to the previous
2417 	// member than the default spacing expects, it is likely that the previous member is in
2418 	// a packed format. If so, and the previous member is packable, pack it.
2419 	// For example, this applies to any 3-element vector that is followed by a scalar.
2420 	uint32_t curr_offset = 0;
2421 	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
2422 	{
2423 		if (is_member_packable(ib_type, mbr_idx))
2424 		{
2425 			set_extended_member_decoration(ib_type_id, mbr_idx, SPIRVCrossDecorationPacked);
2426 			set_extended_member_decoration(ib_type_id, mbr_idx, SPIRVCrossDecorationPackedType,
2427 			                               get_member_packed_type(ib_type, mbr_idx));
2428 		}
2429 
2430 		// Align current offset to the current member's default alignment.
2431 		size_t align_mask = get_declared_struct_member_alignment(ib_type, mbr_idx) - 1;
2432 		uint32_t aligned_curr_offset = uint32_t((curr_offset + align_mask) & ~align_mask);
2433 
2434 		// Fetch the member offset as declared in the SPIRV.
2435 		uint32_t mbr_offset = get_member_decoration(ib_type_id, mbr_idx, DecorationOffset);
2436 		if (mbr_offset > aligned_curr_offset)
2437 		{
2438 			// Since MSL and SPIR-V have slightly different struct member alignment and
2439 			// size rules, we'll pad to standard C-packing rules. If the member is farther
2440 			// away than C-packing expects, add an inert padding member before the member.
2441 			MSLStructMemberKey key = get_struct_member_key(ib_type_id, mbr_idx);
2442 			struct_member_padding[key] = mbr_offset - curr_offset;
2443 		}
2444 
2445 		// Increment the current offset to be positioned immediately after the current member.
2446 		// Don't do this for the last member since it can be unsized, and it is not relevant for padding purposes here.
2447 		if (mbr_idx + 1 < mbr_cnt)
2448 			curr_offset = mbr_offset + uint32_t(get_declared_struct_member_size_msl(ib_type, mbr_idx));
2449 	}
2450 }
2451 
2452 // Returns whether the specified struct member supports a packable type
2453 // variation that is smaller than the unpacked variation of that type.
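// For example, a float3 followed by a float at offset 12 must be packed: a plain
// MSL float3 is 16 bytes in size and 16-byte aligned, which would push the
// following float out to offset 16.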
2454 bool CompilerMSL::is_member_packable(SPIRType &ib_type, uint32_t index, uint32_t base_offset)
2455 {
2456 	// We've already marked it as packable
2457 	if (has_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPacked))
2458 		return true;
2459 
2460 	auto &mbr_type = get<SPIRType>(ib_type.member_types[index]);
2461 
2462 	uint32_t component_size = mbr_type.width / 8;
2463 	uint32_t unpacked_mbr_size;
2464 	if (mbr_type.vecsize == 3)
2465 		unpacked_mbr_size = component_size * (mbr_type.vecsize + 1) * mbr_type.columns;
2466 	else
2467 		unpacked_mbr_size = component_size * mbr_type.vecsize * mbr_type.columns;
2468 
2469 	// Special case for packing. Check for float[] or vec2[] in std140 layout. Here we actually need to pad out instead,
2470 	// but we will use the same mechanism.
2471 	if (is_array(mbr_type) && (is_scalar(mbr_type) || is_vector(mbr_type)) && mbr_type.vecsize <= 2 &&
2472 	    type_struct_member_array_stride(ib_type, index) == 4 * component_size)
2473 	{
2474 		return true;
2475 	}
2476 
2477 	uint32_t mbr_offset_curr = base_offset + get_member_decoration(ib_type.self, index, DecorationOffset);
2478 	if (mbr_type.basetype == SPIRType::Struct)
2479 	{
2480 		// If this is a struct type, check if any of its members need packing.
2481 		for (uint32_t i = 0; i < mbr_type.member_types.size(); i++)
2482 		{
2483 			if (is_member_packable(mbr_type, i, mbr_offset_curr))
2484 			{
2485 				set_extended_member_decoration(mbr_type.self, i, SPIRVCrossDecorationPacked);
2486 				set_extended_member_decoration(mbr_type.self, i, SPIRVCrossDecorationPackedType,
2487 				                               get_member_packed_type(mbr_type, i));
2488 			}
2489 		}
2490 		size_t declared_struct_size = get_declared_struct_size(mbr_type);
2491 		size_t alignment = get_declared_struct_member_alignment(ib_type, index);
2492 		declared_struct_size = (declared_struct_size + alignment - 1) & ~(alignment - 1);
2493 		// Check for array of struct, where the SPIR-V declares an array stride which is larger than the struct itself.
2494 		// This can happen for struct A { float a }; A a[]; in std140 layout.
2495 		// TODO: Emit a padded struct which can be used for this purpose.
2496 		if (is_array(mbr_type))
2497 		{
2498 			size_t array_stride = type_struct_member_array_stride(ib_type, index);
2499 			if (array_stride > declared_struct_size)
2500 				return true;
2501 			if (array_stride < declared_struct_size)
2502 			{
2503 				// If the stride is *less* (i.e. more tightly packed), then
2504 				// we need to pack the members of the struct itself.
2505 				for (uint32_t i = 0; i < mbr_type.member_types.size(); i++)
2506 				{
2507 					if (is_member_packable(mbr_type, i, mbr_offset_curr + array_stride))
2508 					{
2509 						set_extended_member_decoration(mbr_type.self, i, SPIRVCrossDecorationPacked);
2510 						set_extended_member_decoration(mbr_type.self, i, SPIRVCrossDecorationPackedType,
2511 						                               get_member_packed_type(mbr_type, i));
2512 					}
2513 				}
2514 			}
2515 		}
2516 		else
2517 		{
2518 			// Pack if there is not enough space between this member and next.
2519 			if (index < ib_type.member_types.size() - 1)
2520 			{
2521 				uint32_t mbr_offset_next =
2522 				    base_offset + get_member_decoration(ib_type.self, index + 1, DecorationOffset);
2523 				if (declared_struct_size > mbr_offset_next - mbr_offset_curr)
2524 				{
2525 					for (uint32_t i = 0; i < mbr_type.member_types.size(); i++)
2526 					{
2527 						if (is_member_packable(mbr_type, i, mbr_offset_next))
2528 						{
2529 							set_extended_member_decoration(mbr_type.self, i, SPIRVCrossDecorationPacked);
2530 							set_extended_member_decoration(mbr_type.self, i, SPIRVCrossDecorationPackedType,
2531 							                               get_member_packed_type(mbr_type, i));
2532 						}
2533 					}
2534 				}
2535 			}
2536 		}
2537 	}
2538 
2539 	// TODO: Another sanity check for matrices. We currently do not support std140 matrices which need to be padded out per column.
2540 	//if (is_matrix(mbr_type) && mbr_type.vecsize <= 2 && type_struct_member_matrix_stride(ib_type, index) == 16)
2541 	//	SPIRV_CROSS_THROW("Currently cannot support matrices with small vector size in std140 layout.");
2542 
2543 	// Pack if the member's offset doesn't conform to the type's usual
2544 	// alignment. For example, a float3 at offset 4.
2545 	if (mbr_offset_curr % get_declared_struct_member_alignment(ib_type, index))
2546 		return true;
2547 
2548 	// Only vectors or 3-row matrices need to be packed.
2549 	if (mbr_type.vecsize == 1 || (is_matrix(mbr_type) && mbr_type.vecsize != 3))
2550 		return false;
2551 
2552 	if (is_array(mbr_type))
2553 	{
2554 		// If member is an array, and the array stride is larger than the type needs, don't pack it.
2555 		// Take into consideration multidimensional arrays.
2556 		uint32_t md_elem_cnt = 1;
2557 		size_t last_elem_idx = mbr_type.array.size() - 1;
2558 		for (uint32_t i = 0; i < last_elem_idx; i++)
2559 #ifdef RARCH_INTERNAL
2560 			md_elem_cnt *= MAX(to_array_size_literal(mbr_type, i), 1u);
2561 #else
2562 			md_elem_cnt *= max(to_array_size_literal(mbr_type, i), 1u);
2563 #endif
2564 
2565 		uint32_t unpacked_array_stride = unpacked_mbr_size * md_elem_cnt;
2566 		uint32_t array_stride = type_struct_member_array_stride(ib_type, index);
2567 		return unpacked_array_stride > array_stride;
2568 	}
2569 	else
2570 	{
2571 		// Pack if there is not enough space between this member and next.
2572 		// If last member, only pack if it's a row-major matrix.
2573 		if (index < ib_type.member_types.size() - 1)
2574 		{
2575 			uint32_t mbr_offset_next = base_offset + get_member_decoration(ib_type.self, index + 1, DecorationOffset);
2576 			return unpacked_mbr_size > mbr_offset_next - mbr_offset_curr;
2577 		}
2578 		else
2579 			return is_matrix(mbr_type);
2580 	}
2581 }
2582 
2583 uint32_t CompilerMSL::get_member_packed_type(SPIRType &type, uint32_t index)
2584 {
2585 	auto &mbr_type = get<SPIRType>(type.member_types[index]);
2586 	if (is_matrix(mbr_type) && has_member_decoration(type.self, index, DecorationRowMajor))
2587 	{
2588 		// Packed row-major matrices are stored transposed. But, we don't know if
2589 		// we're dealing with a row-major matrix at the time we need to load it.
2590 		// So, we'll set a packed type with the columns and rows transposed, so we'll
2591 		// know to use the correct constructor.
2592 		uint32_t new_type_id = ir.increase_bound_by(1);
2593 		auto &transpose_type = set<SPIRType>(new_type_id);
2594 		transpose_type = mbr_type;
2595 		transpose_type.vecsize = mbr_type.columns;
2596 		transpose_type.columns = mbr_type.vecsize;
2597 		return new_type_id;
2598 	}
2599 	return type.member_types[index];
2600 }
2601 
2602 // Returns a combination of type ID and member index for use as hash key
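// For example, type_id 5 and member index 2 combine to 0x0000000500000002.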
2603 MSLStructMemberKey CompilerMSL::get_struct_member_key(uint32_t type_id, uint32_t index)
2604 {
2605 	MSLStructMemberKey k = type_id;
2606 	k <<= 32;
2607 	k += index;
2608 	return k;
2609 }
2610 
2611 void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression)
2612 {
2613 	if (!has_extended_decoration(lhs_expression, SPIRVCrossDecorationPacked) ||
2614 	    get_extended_decoration(lhs_expression, SPIRVCrossDecorationPackedType) == 0)
2615 	{
2616 		CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
2617 	}
2618 	else
2619 	{
2620 		// Special handling when storing to a float[] or float2[] in std140 layout.
2621 
2622 		uint32_t type_id = get_extended_decoration(lhs_expression, SPIRVCrossDecorationPackedType);
2623 		auto &type = get<SPIRType>(type_id);
2624 		string lhs = to_dereferenced_expression(lhs_expression);
2625 		string rhs = to_pointer_expression(rhs_expression);
2626 		uint32_t stride = get_decoration(type_id, DecorationArrayStride);
2627 
2628 		if (is_matrix(type))
2629 		{
2630 			// Packed matrices are stored as arrays of packed vectors, so we need
2631 			// to assign the vectors one at a time.
2632 			// For row-major matrices, we need to transpose the *right-hand* side,
2633 			// not the left-hand side. Otherwise, the changes will be lost.
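			// For a packed float3x3 member 'm' (illustrative name), the loop below
			// emits one assignment per column: obj.m[0] = value[0]; and so on.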
2634 			auto *lhs_e = maybe_get<SPIRExpression>(lhs_expression);
2635 			auto *rhs_e = maybe_get<SPIRExpression>(rhs_expression);
2636 			bool transpose = lhs_e && lhs_e->need_transpose;
2637 			if (transpose)
2638 			{
2639 				lhs_e->need_transpose = false;
2640 				if (rhs_e) rhs_e->need_transpose = !rhs_e->need_transpose;
2641 				lhs = to_dereferenced_expression(lhs_expression);
2642 				rhs = to_pointer_expression(rhs_expression);
2643 			}
2644 			for (uint32_t i = 0; i < type.columns; i++)
2645 				statement(enclose_expression(lhs), "[", i, "] = ", enclose_expression(rhs), "[", i, "];");
2646 			if (transpose)
2647 			{
2648 				lhs_e->need_transpose = true;
2649 				if (rhs_e) rhs_e->need_transpose = !rhs_e->need_transpose;
2650 			}
2651 		}
2652 		else if (is_array(type) && stride == 4 * type.width / 8)
2653 		{
2654 			// Unpack the expression so we can store to it with a float or float2.
2655 			// It's still an l-value, so it's fine. Most other unpacking of expressions turns them into r-values instead.
2656 			if (is_scalar(type))
2657 				lhs = enclose_expression(lhs) + ".x";
2658 			else if (is_vector(type) && type.vecsize == 2)
2659 				lhs = enclose_expression(lhs) + ".xy";
2660 		}
2661 
2662 		if (!is_matrix(type))
2663 		{
2664 			if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
2665 				statement(lhs, " = ", rhs, ";");
2666 		}
2667 		register_write(lhs_expression);
2668 	}
2669 }
2670 
2671 // Converts the format of the current expression from packed to unpacked,
2672 // by wrapping the expression in a constructor of the appropriate type.
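// For example, a packed_float3 expression 'v' becomes float3(v), and a packed
// float2x3 'm' becomes float2x3(float3(m[0]), float3(m[1])).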
2673 string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type, uint32_t packed_type_id)
2674 {
2675 	const SPIRType *packed_type = nullptr;
2676 	uint32_t stride = 0;
2677 	if (packed_type_id)
2678 	{
2679 		packed_type = &get<SPIRType>(packed_type_id);
2680 		stride = get_decoration(packed_type_id, DecorationArrayStride);
2681 	}
2682 
2683 	// float[] and float2[] cases are really just padding, so directly swizzle from the backing float4 instead.
2684 	if (packed_type && is_array(*packed_type) && is_scalar(*packed_type) && stride == 4 * packed_type->width / 8)
2685 		return enclose_expression(expr_str) + ".x";
2686 	else if (packed_type && is_array(*packed_type) && is_vector(*packed_type) && packed_type->vecsize == 2 &&
2687 	         stride == 4 * packed_type->width / 8)
2688 		return enclose_expression(expr_str) + ".xy";
2689 	else if (is_matrix(type))
2690 	{
2691 		// Packed matrices are stored as arrays of packed vectors. Unfortunately,
2692 		// we can't just pass the array straight to the matrix constructor. We have to
2693 		// pass each vector individually, so that they can be unpacked to normal vectors.
2694 		if (!packed_type)
2695 			packed_type = &type;
2696 		const char *base_type = packed_type->width == 16 ? "half" : "float";
2697 		string unpack_expr = join(type_to_glsl(*packed_type), "(");
2698 		for (uint32_t i = 0; i < packed_type->columns; i++)
2699 		{
2700 			if (i > 0)
2701 				unpack_expr += ", ";
2702 			unpack_expr += join(base_type, packed_type->vecsize, "(", expr_str, "[", i, "])");
2703 		}
2704 		unpack_expr += ")";
2705 		return unpack_expr;
2706 	}
2707 	else
2708 		return join(type_to_glsl(type), "(", expr_str, ")");
2709 }
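// A sketch of what the constructor wrapping above produces (names are hypothetical):
// a packed float3x3 member `m` is rebuilt column by column as
//   float3x3(float3(obj.m[0]), float3(obj.m[1]), float3(obj.m[2]))
// while a std140 float[] with a 16-byte stride is simply swizzled in place: obj.arr[i].x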
2710 
2711 // Emits the file header info
2712 void CompilerMSL::emit_header()
2713 {
2714 	// This particular line can be overridden during compilation, so make it a flag and not a pragma line.
2715 	if (suppress_missing_prototypes)
2716 		statement("#pragma clang diagnostic ignored \"-Wmissing-prototypes\"");
2717 	for (auto &pragma : pragma_lines)
2718 		statement(pragma);
2719 
2720 	if (!pragma_lines.empty() || suppress_missing_prototypes)
2721 		statement("");
2722 
2723 	statement("#include <metal_stdlib>");
2724 	statement("#include <simd/simd.h>");
2725 
2726 	for (auto &header : header_lines)
2727 		statement(header);
2728 
2729 	statement("");
2730 	statement("using namespace metal;");
2731 	statement("");
2732 
2733 	for (auto &td : typedef_lines)
2734 		statement(td);
2735 
2736 	if (!typedef_lines.empty())
2737 		statement("");
2738 }
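// For reference, with no pragma, header or typedef lines registered, the emitted
// preamble is simply:
//   #include <metal_stdlib>
//   #include <simd/simd.h>
//
//   using namespace metal;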
2739 
2740 void CompilerMSL::add_pragma_line(const string &line)
2741 {
2742 	auto rslt = pragma_lines.insert(line);
2743 	if (rslt.second)
2744 		force_recompile();
2745 }
2746 
2747 void CompilerMSL::add_typedef_line(const string &line)
2748 {
2749 	auto rslt = typedef_lines.insert(line);
2750 	if (rslt.second)
2751 		force_recompile();
2752 }
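// Both helpers above force a recompile only when the line is newly inserted
// (rslt.second): by the time a pragma or typedef is requested, the header section
// may already have been emitted, so another pass is needed to include it.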
2753 
2754 // Emits any needed custom function bodies.
2755 void CompilerMSL::emit_custom_functions()
2756 {
2757 	for (uint32_t i = SPVFuncImplArrayCopyMultidimMax; i >= 2; i--)
2758 		if (spv_function_implementations.count(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i)))
2759 			spv_function_implementations.insert(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i - 1));
2760 
2761 	for (auto &spv_func : spv_function_implementations)
2762 	{
2763 		switch (spv_func)
2764 		{
2765 		case SPVFuncImplMod:
2766 			statement("// Implementation of the GLSL mod() function, which is slightly different from Metal fmod()");
2767 			statement("template<typename Tx, typename Ty>");
2768 			statement("Tx mod(Tx x, Ty y)");
2769 			begin_scope();
2770 			statement("return x - y * floor(x / y);");
2771 			end_scope();
2772 			statement("");
2773 			break;
2774 
2775 		case SPVFuncImplRadians:
2776 			statement("// Implementation of the GLSL radians() function");
2777 			statement("template<typename T>");
2778 			statement("T radians(T d)");
2779 			begin_scope();
2780 			statement("return d * T(0.01745329251);");
2781 			end_scope();
2782 			statement("");
2783 			break;
2784 
2785 		case SPVFuncImplDegrees:
2786 			statement("// Implementation of the GLSL degrees() function");
2787 			statement("template<typename T>");
2788 			statement("T degrees(T r)");
2789 			begin_scope();
2790 			statement("return r * T(57.2957795131);");
2791 			end_scope();
2792 			statement("");
2793 			break;
2794 
2795 		case SPVFuncImplFindILsb:
2796 			statement("// Implementation of the GLSL findLSB() function");
2797 			statement("template<typename T>");
2798 			statement("T findLSB(T x)");
2799 			begin_scope();
2800 			statement("return select(ctz(x), T(-1), x == T(0));");
2801 			end_scope();
2802 			statement("");
2803 			break;
2804 
2805 		case SPVFuncImplFindUMsb:
2806 			statement("// Implementation of the unsigned GLSL findMSB() function");
2807 			statement("template<typename T>");
2808 			statement("T findUMSB(T x)");
2809 			begin_scope();
2810 			statement("return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));");
2811 			end_scope();
2812 			statement("");
2813 			break;
2814 
2815 		case SPVFuncImplFindSMsb:
2816 			statement("// Implementation of the signed GLSL findMSB() function");
2817 			statement("template<typename T>");
2818 			statement("T findSMSB(T x)");
2819 			begin_scope();
2820 			statement("T v = select(x, T(-1) - x, x < T(0));");
2821 			statement("return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));");
2822 			end_scope();
2823 			statement("");
2824 			break;
2825 
2826 		case SPVFuncImplSSign:
2827 			statement("// Implementation of the GLSL sign() function for integer types");
2828 			statement("template<typename T, typename E = typename enable_if<is_integral<T>::value>::type>");
2829 			statement("T sign(T x)");
2830 			begin_scope();
2831 			statement("return select(select(select(x, T(0), x == T(0)), T(1), x > T(0)), T(-1), x < T(0));");
2832 			end_scope();
2833 			statement("");
2834 			break;
2835 
2836 		case SPVFuncImplArrayCopy:
2837 			statement("// Implementation of an array copy function to cover GLSL's ability to copy an array via "
2838 			          "assignment.");
2839 			statement("template<typename T, uint N>");
2840 			statement("void spvArrayCopyFromStack1(thread T (&dst)[N], thread const T (&src)[N])");
2841 			begin_scope();
2842 			statement("for (uint i = 0; i < N; dst[i] = src[i], i++);");
2843 			end_scope();
2844 			statement("");
2845 
2846 			statement("template<typename T, uint N>");
2847 			statement("void spvArrayCopyFromConstant1(thread T (&dst)[N], constant T (&src)[N])");
2848 			begin_scope();
2849 			statement("for (uint i = 0; i < N; dst[i] = src[i], i++);");
2850 			end_scope();
2851 			statement("");
2852 			break;
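			// A GLSL-style array assignment `dst = src` then lowers to a call such as
			//   spvArrayCopyFromStack1(dst, src);
			// with T and N deduced from the reference-to-array parameters.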
2853 
2854 		case SPVFuncImplArrayOfArrayCopy2Dim:
2855 		case SPVFuncImplArrayOfArrayCopy3Dim:
2856 		case SPVFuncImplArrayOfArrayCopy4Dim:
2857 		case SPVFuncImplArrayOfArrayCopy5Dim:
2858 		case SPVFuncImplArrayOfArrayCopy6Dim:
2859 		{
2860 			static const char *function_name_tags[] = {
2861 				"FromStack",
2862 				"FromConstant",
2863 			};
2864 
2865 			static const char *src_address_space[] = {
2866 				"thread const",
2867 				"constant",
2868 			};
2869 
2870 			for (uint32_t variant = 0; variant < 2; variant++)
2871 			{
2872 				uint32_t dimensions = spv_func - SPVFuncImplArrayCopyMultidimBase;
2873 				string tmp = "template<typename T";
2874 				for (uint8_t i = 0; i < dimensions; i++)
2875 				{
2876 					tmp += ", uint ";
2877 					tmp += 'A' + i;
2878 				}
2879 				tmp += ">";
2880 				statement(tmp);
2881 
2882 				string array_arg;
2883 				for (uint8_t i = 0; i < dimensions; i++)
2884 				{
2885 					array_arg += "[";
2886 					array_arg += 'A' + i;
2887 					array_arg += "]";
2888 				}
2889 
2890 				statement("void spvArrayCopy", function_name_tags[variant], dimensions, "(thread T (&dst)", array_arg,
2891 				          ", ", src_address_space[variant], " T (&src)", array_arg, ")");
2892 
2893 				begin_scope();
2894 				statement("for (uint i = 0; i < A; i++)");
2895 				begin_scope();
2896 				statement("spvArrayCopy", function_name_tags[variant], dimensions - 1, "(dst[i], src[i]);");
2897 				end_scope();
2898 				end_scope();
2899 				statement("");
2900 			}
2901 			break;
2902 		}
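		// For example, for dimensions == 2 the loop above emits (per address-space variant):
		//   template<typename T, uint A, uint B>
		//   void spvArrayCopyFromStack2(thread T (&dst)[A][B], thread const T (&src)[A][B])
		//   {
		//       for (uint i = 0; i < A; i++)
		//       {
		//           spvArrayCopyFromStack1(dst[i], src[i]);
		//       }
		//   }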
2903 
2904 		case SPVFuncImplTexelBufferCoords:
2905 		{
2906 			string tex_width_str = convert_to_string(msl_options.texel_buffer_texture_width);
2907 			statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
2908 			statement("uint2 spvTexelBufferCoord(uint tc)");
2909 			begin_scope();
2910 			statement(join("return uint2(tc % ", tex_width_str, ", tc / ", tex_width_str, ");"));
2911 			end_scope();
2912 			statement("");
2913 			break;
2914 		}
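		// Worked example: with texel_buffer_texture_width == 4096, texel index 5000
		// maps to uint2(5000 % 4096, 5000 / 4096) == uint2(904, 1).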
2915 
2916 		case SPVFuncImplInverse4x4:
2917 			statement("// Returns the determinant of a 2x2 matrix.");
2918 			statement("inline float spvDet2x2(float a1, float a2, float b1, float b2)");
2919 			begin_scope();
2920 			statement("return a1 * b2 - b1 * a2;");
2921 			end_scope();
2922 			statement("");
2923 
2924 			statement("// Returns the determinant of a 3x3 matrix.");
2925 			statement("inline float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
2926 			          "float c2, float c3)");
2927 			begin_scope();
2928 			statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * spvDet2x2(a2, a3, "
2929 			          "b2, b3);");
2930 			end_scope();
2931 			statement("");
2932 			statement("// Returns the inverse of a matrix, by calculating the classical adjoint");
2933 			statement("// and dividing by the determinant. The input matrix is not modified.");
2934 			statement("float4x4 spvInverse4x4(float4x4 m)");
2935 			begin_scope();
2936 			statement("float4x4 adj;	// The adjoint matrix (inverse after dividing by determinant)");
2937 			statement_no_indent("");
2938 			statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
2939 			statement("adj[0][0] =  spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
2940 			          "m[3][3]);");
2941 			statement("adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
2942 			          "m[3][3]);");
2943 			statement("adj[0][2] =  spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
2944 			          "m[3][3]);");
2945 			statement("adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
2946 			          "m[2][3]);");
2947 			statement_no_indent("");
2948 			statement("adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
2949 			          "m[3][3]);");
2950 			statement("adj[1][1] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
2951 			          "m[3][3]);");
2952 			statement("adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
2953 			          "m[3][3]);");
2954 			statement("adj[1][3] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
2955 			          "m[2][3]);");
2956 			statement_no_indent("");
2957 			statement("adj[2][0] =  spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
2958 			          "m[3][3]);");
2959 			statement("adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
2960 			          "m[3][3]);");
2961 			statement("adj[2][2] =  spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
2962 			          "m[3][3]);");
2963 			statement("adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
2964 			          "m[2][3]);");
2965 			statement_no_indent("");
2966 			statement("adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
2967 			          "m[3][2]);");
2968 			statement("adj[3][1] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
2969 			          "m[3][2]);");
2970 			statement("adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
2971 			          "m[3][2]);");
2972 			statement("adj[3][3] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
2973 			          "m[2][2]);");
2974 			statement_no_indent("");
2975 			statement("// Calculate the determinant as a combination of the cofactors of the first row.");
2976 			statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
2977 			          "* m[3][0]);");
2978 			statement_no_indent("");
2979 			statement("// Divide the classical adjoint matrix by the determinant.");
2980 			statement("// If the determinant is zero, the matrix is not invertible, so leave it unchanged.");
2981 			statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
2982 			end_scope();
2983 			statement("");
2984 			break;
2985 
2986 		case SPVFuncImplInverse3x3:
2987 			if (spv_function_implementations.count(SPVFuncImplInverse4x4) == 0)
2988 			{
2989 				statement("// Returns the determinant of a 2x2 matrix.");
2990 				statement("inline float spvDet2x2(float a1, float a2, float b1, float b2)");
2991 				begin_scope();
2992 				statement("return a1 * b2 - b1 * a2;");
2993 				end_scope();
2994 				statement("");
2995 			}
2996 
2997 			statement("// Returns the inverse of a matrix, by calculating the classical adjoint");
2998 			statement("// and dividing by the determinant. The input matrix is not modified.");
2999 			statement("float3x3 spvInverse3x3(float3x3 m)");
3000 			begin_scope();
3001 			statement("float3x3 adj;	// The adjoint matrix (inverse after dividing by determinant)");
3002 			statement_no_indent("");
3003 			statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
3004 			statement("adj[0][0] =  spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
3005 			statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
3006 			statement("adj[0][2] =  spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
3007 			statement_no_indent("");
3008 			statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
3009 			statement("adj[1][1] =  spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
3010 			statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
3011 			statement_no_indent("");
3012 			statement("adj[2][0] =  spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
3013 			statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
3014 			statement("adj[2][2] =  spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
3015 			statement_no_indent("");
3016 			statement("// Calculate the determinant as a combination of the cofactors of the first row.");
3017 			statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
3018 			statement_no_indent("");
3019 			statement("// Divide the classical adjoint matrix by the determinant.");
3020 			statement("// If the determinant is zero, the matrix is not invertible, so leave it unchanged.");
3021 			statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
3022 			end_scope();
3023 			statement("");
3024 			break;
3025 
3026 		case SPVFuncImplInverse2x2:
3027 			statement("// Returns the inverse of a matrix, by calculating the classical adjoint");
3028 			statement("// and dividing by the determinant. The input matrix is not modified.");
3029 			statement("float2x2 spvInverse2x2(float2x2 m)");
3030 			begin_scope();
3031 			statement("float2x2 adj;	// The adjoint matrix (inverse after dividing by determinant)");
3032 			statement_no_indent("");
3033 			statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
3034 			statement("adj[0][0] =  m[1][1];");
3035 			statement("adj[0][1] = -m[0][1];");
3036 			statement_no_indent("");
3037 			statement("adj[1][0] = -m[1][0];");
3038 			statement("adj[1][1] =  m[0][0];");
3039 			statement_no_indent("");
3040 			statement("// Calculate the determinant as a combination of the cofactors of the first row.");
3041 			statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
3042 			statement_no_indent("");
3043 			statement("// Divide the classical adjoint matrix by the determinant.");
3044 			statement("// If the determinant is zero, the matrix is not invertible, so leave it unchanged.");
3045 			statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
3046 			end_scope();
3047 			statement("");
3048 			break;
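			// Sanity check for the 2x2 case (column-major m, so m[col][row]): with
			// m[0] = (a, c) and m[1] = (b, d), the code above yields adj[0] = (d, -c),
			// adj[1] = (-b, a) and det = a*d - b*c, i.e. the familiar closed form
			// inverse(m) = 1/(ad - bc) * [[d, -b], [-c, a]] whenever det != 0.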
3049 
3050 		case SPVFuncImplRowMajor2x3:
3051 			statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
3052 			statement("float2x3 spvConvertFromRowMajor2x3(float2x3 m)");
3053 			begin_scope();
3054 			statement("return float2x3(float3(m[0][0], m[0][2], m[1][1]), float3(m[0][1], m[1][0], m[1][2]));");
3055 			end_scope();
3056 			statement("");
3057 			break;
3058 
3059 		case SPVFuncImplRowMajor2x4:
3060 			statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
3061 			statement("float2x4 spvConvertFromRowMajor2x4(float2x4 m)");
3062 			begin_scope();
3063 			statement("return float2x4(float4(m[0][0], m[0][2], m[1][0], m[1][2]), float4(m[0][1], m[0][3], m[1][1], "
3064 			          "m[1][3]));");
3065 			end_scope();
3066 			statement("");
3067 			break;
3068 
3069 		case SPVFuncImplRowMajor3x2:
3070 			statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
3071 			statement("float3x2 spvConvertFromRowMajor3x2(float3x2 m)");
3072 			begin_scope();
3073 			statement("return float3x2(float2(m[0][0], m[1][1]), float2(m[0][1], m[2][0]), float2(m[1][0], m[2][1]));");
3074 			end_scope();
3075 			statement("");
3076 			break;
3077 
3078 		case SPVFuncImplRowMajor3x4:
3079 			statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
3080 			statement("float3x4 spvConvertFromRowMajor3x4(float3x4 m)");
3081 			begin_scope();
3082 			statement("return float3x4(float4(m[0][0], m[0][3], m[1][2], m[2][1]), float4(m[0][1], m[1][0], m[1][3], "
3083 			          "m[2][2]), float4(m[0][2], m[1][1], m[2][0], m[2][3]));");
3084 			end_scope();
3085 			statement("");
3086 			break;
3087 
3088 		case SPVFuncImplRowMajor4x2:
3089 			statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
3090 			statement("float4x2 spvConvertFromRowMajor4x2(float4x2 m)");
3091 			begin_scope();
3092 			statement("return float4x2(float2(m[0][0], m[2][0]), float2(m[0][1], m[2][1]), float2(m[1][0], m[3][0]), "
3093 			          "float2(m[1][1], m[3][1]));");
3094 			end_scope();
3095 			statement("");
3096 			break;
3097 
3098 		case SPVFuncImplRowMajor4x3:
3099 			statement("// Implementation of a conversion of matrix content from RowMajor to ColumnMajor organization.");
3100 			statement("float4x3 spvConvertFromRowMajor4x3(float4x3 m)");
3101 			begin_scope();
3102 			statement("return float4x3(float3(m[0][0], m[1][1], m[2][2]), float3(m[0][1], m[1][2], m[3][0]), "
3103 			          "float3(m[0][2], m[2][0], m[3][1]), float3(m[1][0], m[2][1], m[3][2]));");
3104 			end_scope();
3105 			statement("");
3106 			break;
3107 
3108 		case SPVFuncImplTextureSwizzle:
3109 			statement("enum class spvSwizzle : uint");
3110 			begin_scope();
3111 			statement("none = 0,");
3112 			statement("zero,");
3113 			statement("one,");
3114 			statement("red,");
3115 			statement("green,");
3116 			statement("blue,");
3117 			statement("alpha");
3118 			end_scope_decl();
3119 			statement("");
3120 			statement("template<typename T> struct spvRemoveReference { typedef T type; };");
3121 			statement("template<typename T> struct spvRemoveReference<thread T&> { typedef T type; };");
3122 			statement("template<typename T> struct spvRemoveReference<thread T&&> { typedef T type; };");
3123 			statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
3124 			          "spvRemoveReference<T>::type& x)");
3125 			begin_scope();
3126 			statement("return static_cast<thread T&&>(x);");
3127 			end_scope();
3128 			statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
3129 			          "spvRemoveReference<T>::type&& x)");
3130 			begin_scope();
3131 			statement("return static_cast<thread T&&>(x);");
3132 			end_scope();
3133 			statement("");
3134 			statement("template<typename T>");
3135 			statement("inline T spvGetSwizzle(vec<T, 4> x, T c, spvSwizzle s)");
3136 			begin_scope();
3137 			statement("switch (s)");
3138 			begin_scope();
3139 			statement("case spvSwizzle::none:");
3140 			statement("    return c;");
3141 			statement("case spvSwizzle::zero:");
3142 			statement("    return 0;");
3143 			statement("case spvSwizzle::one:");
3144 			statement("    return 1;");
3145 			statement("case spvSwizzle::red:");
3146 			statement("    return x.r;");
3147 			statement("case spvSwizzle::green:");
3148 			statement("    return x.g;");
3149 			statement("case spvSwizzle::blue:");
3150 			statement("    return x.b;");
3151 			statement("case spvSwizzle::alpha:");
3152 			statement("    return x.a;");
3153 			end_scope();
3154 			end_scope();
3155 			statement("");
3156 			statement("// Wrapper function that swizzles texture samples and fetches.");
3157 			statement("template<typename T>");
3158 			statement("inline vec<T, 4> spvTextureSwizzle(vec<T, 4> x, uint s)");
3159 			begin_scope();
3160 			statement("if (!s)");
3161 			statement("    return x;");
3162 			statement("return vec<T, 4>(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), "
3163 			          "spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) "
3164 			          "& 0xFF)), "
3165 			          "spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));");
3166 			end_scope();
3167 			statement("");
3168 			statement("template<typename T>");
3169 			statement("inline T spvTextureSwizzle(T x, uint s)");
3170 			begin_scope();
3171 			statement("return spvTextureSwizzle(vec<T, 4>(x, 0, 0, 1), s).x;");
3172 			end_scope();
3173 			statement("");
3174 			statement("// Wrapper function that swizzles texture gathers.");
3175 			statement("template<typename T, typename Tex, typename... Ts>");
3176 			statement(
3177 			    "inline vec<T, 4> spvGatherSwizzle(sampler s, const thread Tex& t, Ts... params, component c, uint sw) "
3178 			    "METAL_CONST_ARG(c)");
3179 			begin_scope();
3180 			statement("if (sw)");
3181 			begin_scope();
3182 			statement("switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))");
3183 			begin_scope();
3184 			statement("case spvSwizzle::none:");
3185 			statement("    break;");
3186 			statement("case spvSwizzle::zero:");
3187 			statement("    return vec<T, 4>(0, 0, 0, 0);");
3188 			statement("case spvSwizzle::one:");
3189 			statement("    return vec<T, 4>(1, 1, 1, 1);");
3190 			statement("case spvSwizzle::red:");
3191 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::x);");
3192 			statement("case spvSwizzle::green:");
3193 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::y);");
3194 			statement("case spvSwizzle::blue:");
3195 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::z);");
3196 			statement("case spvSwizzle::alpha:");
3197 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::w);");
3198 			end_scope();
3199 			end_scope();
3200 			// texture::gather insists on its component parameter being a constant
3201 			// expression, so we need this silly workaround just to compile the shader.
3202 			statement("switch (c)");
3203 			begin_scope();
3204 			statement("case component::x:");
3205 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::x);");
3206 			statement("case component::y:");
3207 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::y);");
3208 			statement("case component::z:");
3209 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::z);");
3210 			statement("case component::w:");
3211 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::w);");
3212 			end_scope();
3213 			end_scope();
3214 			statement("");
3215 			statement("// Wrapper function that swizzles depth texture gathers.");
3216 			statement("template<typename T, typename Tex, typename... Ts>");
3217 			statement(
3218 			    "inline vec<T, 4> spvGatherCompareSwizzle(sampler s, const thread Tex& t, Ts... params, uint sw) ");
3219 			begin_scope();
3220 			statement("if (sw)");
3221 			begin_scope();
3222 			statement("switch (spvSwizzle(sw & 0xFF))");
3223 			begin_scope();
3224 			statement("case spvSwizzle::none:");
3225 			statement("case spvSwizzle::red:");
3226 			statement("    break;");
3227 			statement("case spvSwizzle::zero:");
3228 			statement("case spvSwizzle::green:");
3229 			statement("case spvSwizzle::blue:");
3230 			statement("case spvSwizzle::alpha:");
3231 			statement("    return vec<T, 4>(0, 0, 0, 0);");
3232 			statement("case spvSwizzle::one:");
3233 			statement("    return vec<T, 4>(1, 1, 1, 1);");
3234 			end_scope();
3235 			end_scope();
3236 			statement("return t.gather_compare(s, spvForward<Ts>(params)...);");
3237 			end_scope();
3238 			statement("");
3239 			break;
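			// The runtime swizzle word packs four spvSwizzle values, one byte per
			// component: bits 0-7 select r, 8-15 g, 16-23 b, 24-31 a. A word of 0
			// (spvSwizzle::none in every byte) is the identity mapping, which is why
			// the wrappers above can return the sample unchanged when !s.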
3240 
3241 		case SPVFuncImplSubgroupBallot:
3242 			statement("inline uint4 spvSubgroupBallot(bool value)");
3243 			begin_scope();
3244 			statement("simd_vote vote = simd_ballot(value);");
3245 			statement("// simd_ballot() returns a 64-bit integer-like object, but");
3246 			statement("// SPIR-V callers expect a uint4. We must convert.");
3247 			statement("// FIXME: This won't include higher bits if Apple ever supports");
3248 			statement("// 128 lanes in a SIMD-group.");
3249 			statement("return uint4((uint)((simd_vote::vote_t)vote & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)vote >> "
3250 			          "32) & 0xFFFFFFFF), 0, 0);");
3251 			end_scope();
3252 			statement("");
3253 			break;
3254 
3255 		case SPVFuncImplSubgroupBallotBitExtract:
3256 			statement("inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)");
3257 			begin_scope();
3258 			statement("return !!extract_bits(ballot[bit / 32], bit % 32, 1);");
3259 			end_scope();
3260 			statement("");
3261 			break;
3262 
3263 		case SPVFuncImplSubgroupBallotFindLSB:
3264 			statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot)");
3265 			begin_scope();
3266 			statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
3267 			          "ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
3268 			end_scope();
3269 			statement("");
3270 			break;
3271 
3272 		case SPVFuncImplSubgroupBallotFindMSB:
3273 			statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot)");
3274 			begin_scope();
3275 			statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
3276 			          "(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
3277 			          "ballot.z == 0), ballot.w == 0);");
3278 			end_scope();
3279 			statement("");
3280 			break;
3281 
3282 		case SPVFuncImplSubgroupBallotBitCount:
3283 			statement("inline uint spvSubgroupBallotBitCount(uint4 ballot)");
3284 			begin_scope();
3285 			statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
3286 			end_scope();
3287 			statement("");
3288 			statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
3289 			begin_scope();
3290 			statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
3291 			          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
3292 			          "uint2(0));");
3293 			statement("return spvSubgroupBallotBitCount(ballot & mask);");
3294 			end_scope();
3295 			statement("");
3296 			statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
3297 			begin_scope();
3298 			statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
3299 			          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
3300 			statement("return spvSubgroupBallotBitCount(ballot & mask);");
3301 			end_scope();
3302 			statement("");
3303 			break;
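			// Worked example: for gl_SubgroupInvocationID == 34 the inclusive mask is
			//   uint4(extract_bits(0xFFFFFFFF, 0, 32), extract_bits(0xFFFFFFFF, 0, 3), 0, 0)
			// == uint4(0xFFFFFFFF, 0x7, 0, 0), covering lanes 0 through 34.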
3304 
3305 		case SPVFuncImplSubgroupAllEqual:
3306 		// Metal doesn't provide a function to evaluate this directly, but we can
3307 			// implement this by comparing every thread's value to one thread's value
3308 			// (in this case, the value of the first active thread). Then, by the transitive
3309 			// property of equality, if all comparisons return true, then they are all equal.
3310 			statement("template<typename T>");
3311 			statement("inline bool spvSubgroupAllEqual(T value)");
3312 			begin_scope();
3313 			statement("return simd_all(value == simd_broadcast_first(value));");
3314 			end_scope();
3315 			statement("");
3316 			statement("template<>");
3317 			statement("inline bool spvSubgroupAllEqual(bool value)");
3318 			begin_scope();
3319 			statement("return simd_all(value) || !simd_any(value);");
3320 			end_scope();
3321 			statement("");
3322 			break;
3323 
3324 		case SPVFuncImplReflectScalar:
3325 			// Metal does not support scalar versions of these functions.
3326 			statement("template<typename T>");
3327 			statement("inline T spvReflect(T i, T n)");
3328 			begin_scope();
3329 			statement("return i - T(2) * i * n * n;");
3330 			end_scope();
3331 			statement("");
3332 			break;
3333 
3334 		case SPVFuncImplRefractScalar:
3335 			// Metal does not support scalar versions of these functions.
3336 			statement("template<typename T>");
3337 			statement("inline T spvRefract(T i, T n, T eta)");
3338 			begin_scope();
3339 			statement("T NoI = n * i;");
3340 			statement("T NoI2 = NoI * NoI;");
3341 			statement("T k = T(1) - eta * eta * (T(1) - NoI2);");
3342 			statement("if (k < T(0))");
3343 			begin_scope();
3344 			statement("return T(0);");
3345 			end_scope();
3346 			statement("else");
3347 			begin_scope();
3348 			statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
3349 			end_scope();
3350 			end_scope();
3351 			statement("");
3352 			break;
3353 
3354 		default:
3355 			break;
3356 		}
3357 	}
3358 }
3359 
3360 // Undefined global memory is not allowed in MSL.
3361 // Declare constant and init to zeros. Use {}, as global constructors can break Metal.
3362 void CompilerMSL::declare_undefined_values()
3363 {
3364 	bool emitted = false;
3365 	ir.for_each_typed_id<SPIRUndef>([&](uint32_t, SPIRUndef &undef) {
3366 		auto &type = this->get<SPIRType>(undef.basetype);
3367 		statement("constant ", variable_decl(type, to_name(undef.self), undef.self), " = {};");
3368 		emitted = true;
3369 	});
3370 
3371 	if (emitted)
3372 		statement("");
3373 }
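// E.g. an OpUndef of type float4 is emitted as (the name is the ID-derived fallback):
//   constant float4 _37 = {};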
3374 
3375 void CompilerMSL::declare_constant_arrays()
3376 {
3377 	// MSL cannot declare arrays inline (except when declaring a variable), so we must hoist them
3378 	// out to global constants in order to use them in variable expressions.
3379 	bool emitted = false;
3380 
3381 	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
3382 		if (c.specialization)
3383 			return;
3384 
3385 		auto &type = this->get<SPIRType>(c.constant_type);
3386 		if (!type.array.empty())
3387 		{
3388 			auto name = to_name(c.self);
3389 			statement("constant ", variable_decl(type, name), " = ", constant_expression(c), ";");
3390 			emitted = true;
3391 		}
3392 	});
3393 
3394 	if (emitted)
3395 		statement("");
3396 }
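// E.g. a GLSL `const float c[2] = float[](1.0, 2.0);` is hoisted to something like
//   constant float _42[2] = { 1.0, 2.0 };
// where _42 is the ID-derived fallback name.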
3397 
3398 void CompilerMSL::emit_resources()
3399 {
3400 	declare_constant_arrays();
3401 	declare_undefined_values();
3402 
3403 	// Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created.
3404 	emit_interface_block(stage_out_var_id);
3405 	emit_interface_block(patch_stage_out_var_id);
3406 	emit_interface_block(stage_in_var_id);
3407 	emit_interface_block(patch_stage_in_var_id);
3408 }
3409 
3410 // Emit declarations for the specialization Metal function constants
3411 void CompilerMSL::emit_specialization_constants_and_structs()
3412 {
3413 	SpecializationConstant wg_x, wg_y, wg_z;
3414 	uint32_t workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
3415 	bool emitted = false;
3416 
3417 	unordered_set<uint32_t> declared_structs;
3418 
3419 	for (auto &id_ : ir.ids_for_constant_or_type)
3420 	{
3421 		auto &id = ir.ids[id_];
3422 
3423 		if (id.get_type() == TypeConstant)
3424 		{
3425 			auto &c = id.get<SPIRConstant>();
3426 
3427 			if (c.self == workgroup_size_id)
3428 			{
3429 				// TODO: This can be expressed as a [[threads_per_threadgroup]] input semantic, but we need to know
3430 				// the work group size at compile time in SPIR-V, and [[threads_per_threadgroup]] would need to be passed around as a global.
3431 				// The work group size may be a specialization constant.
3432 				statement("constant uint3 ", builtin_to_glsl(BuiltInWorkgroupSize, StorageClassWorkgroup),
3433 				          " [[maybe_unused]] = ", constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
3434 				emitted = true;
3435 			}
3436 			else if (c.specialization)
3437 			{
3438 				auto &type = get<SPIRType>(c.constant_type);
3439 				string sc_type_name = type_to_glsl(type);
3440 				string sc_name = to_name(c.self);
3441 				string sc_tmp_name = sc_name + "_tmp";
3442 
3443 				// Function constants are only supported in MSL 1.2 and later.
3444 				// If we don't support it just declare the "default" directly.
3445 				// This "default" value can be overridden to the true specialization constant by the API user.
3446 				// Specialization constants which are used as array length expressions cannot be function constants in MSL,
3447 				// so just fall back to macros.
3448 				if (msl_options.supports_msl_version(1, 2) && has_decoration(c.self, DecorationSpecId) &&
3449 				    !c.is_used_as_array_length)
3450 				{
3451 					uint32_t constant_id = get_decoration(c.self, DecorationSpecId);
3452 					// Only scalar, non-composite values can be function constants.
3453 					statement("constant ", sc_type_name, " ", sc_tmp_name, " [[function_constant(", constant_id,
3454 					          ")]];");
3455 					statement("constant ", sc_type_name, " ", sc_name, " = is_function_constant_defined(", sc_tmp_name,
3456 					          ") ? ", sc_tmp_name, " : ", constant_expression(c), ";");
3457 				}
3458 				else if (has_decoration(c.self, DecorationSpecId))
3459 				{
3460 					// Fallback to macro overrides.
3461 					c.specialization_constant_macro_name =
3462 					    constant_value_macro_name(get_decoration(c.self, DecorationSpecId));
3463 
3464 					statement("#ifndef ", c.specialization_constant_macro_name);
3465 					statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
3466 					statement("#endif");
3467 					statement("constant ", sc_type_name, " ", sc_name, " = ", c.specialization_constant_macro_name,
3468 					          ";");
3469 				}
3470 				else
3471 				{
3472 					// Composite specialization constants must be built from other specialization constants.
3473 					statement("constant ", sc_type_name, " ", sc_name, " = ", constant_expression(c), ";");
3474 				}
3475 				emitted = true;
3476 			}
3477 		}
3478 		else if (id.get_type() == TypeConstantOp)
3479 		{
3480 			auto &c = id.get<SPIRConstantOp>();
3481 			auto &type = get<SPIRType>(c.basetype);
3482 			auto name = to_name(c.self);
3483 			statement("constant ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
3484 			emitted = true;
3485 		}
3486 		else if (id.get_type() == TypeType)
3487 		{
3488 			// Output non-builtin interface structs. These include local function structs
3489 			// and structs nested within uniform and read-write buffers.
3490 			auto &type = id.get<SPIRType>();
3491 			uint32_t type_id = type.self;
3492 
3493 			bool is_struct = (type.basetype == SPIRType::Struct) && type.array.empty();
3494 			bool is_block =
3495 			    has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
3496 
3497 			bool is_builtin_block = is_block && is_builtin_type(type);
3498 			bool is_declarable_struct = is_struct && !is_builtin_block;
3499 
3500 			// We'll declare this later.
3501 			if (stage_out_var_id && get_stage_out_struct_type().self == type_id)
3502 				is_declarable_struct = false;
3503 			if (patch_stage_out_var_id && get_patch_stage_out_struct_type().self == type_id)
3504 				is_declarable_struct = false;
3505 			if (stage_in_var_id && get_stage_in_struct_type().self == type_id)
3506 				is_declarable_struct = false;
3507 			if (patch_stage_in_var_id && get_patch_stage_in_struct_type().self == type_id)
3508 				is_declarable_struct = false;
3509 
3510 			// Align and emit declarable structs...but avoid declaring each more than once.
3511 			if (is_declarable_struct && declared_structs.count(type_id) == 0)
3512 			{
3513 				if (emitted)
3514 					statement("");
3515 				emitted = false;
3516 
3517 				declared_structs.insert(type_id);
3518 
3519 				if (has_extended_decoration(type_id, SPIRVCrossDecorationPacked))
3520 					align_struct(type);
3521 
3522 				// Make sure we declare the underlying struct type, and not the "decorated" type with pointers, etc.
3523 				emit_struct(get<SPIRType>(type_id));
3524 			}
3525 		}
3526 	}
3527 
3528 	if (emitted)
3529 		statement("");
3530 }
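// E.g. `layout(constant_id = 3) const int kFoo = 10;` (kFoo is a hypothetical name)
// becomes, on MSL 1.2 and up:
//   constant int kFoo_tmp [[function_constant(3)]];
//   constant int kFoo = is_function_constant_defined(kFoo_tmp) ? kFoo_tmp : 10;
// while on older versions it falls back to the macro path, with a macro name of the
// form SPIRV_CROSS_CONSTANT_ID_3:
//   #ifndef SPIRV_CROSS_CONSTANT_ID_3
//   #define SPIRV_CROSS_CONSTANT_ID_3 10
//   #endif
//   constant int kFoo = SPIRV_CROSS_CONSTANT_ID_3;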
3531 
3532 void CompilerMSL::emit_binary_unord_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
3533                                        const char *op)
3534 {
3535 	bool forward = should_forward(op0) && should_forward(op1);
3536 	emit_op(result_type, result_id,
3537 	        join("(isunordered(", to_enclosed_unpacked_expression(op0), ", ", to_enclosed_unpacked_expression(op1),
3538 	             ") || ", to_enclosed_unpacked_expression(op0), " ", op, " ", to_enclosed_unpacked_expression(op1),
3539 	             ")"),
3540 	        forward);
3541 
3542 	inherit_expression_dependencies(result_id, op0);
3543 	inherit_expression_dependencies(result_id, op1);
3544 }
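// E.g. OpFUnordLessThan on operands a and b becomes
//   (isunordered(a, b) || a < b)
// which matches SPIR-V's unordered-compare semantics: the result is true
// whenever either operand is NaN.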
3545 
3546 bool CompilerMSL::emit_tessellation_access_chain(const uint32_t *ops, uint32_t length)
3547 {
3548 	// If this is a per-vertex output, remap it to the I/O array buffer.
3549 	auto *var = maybe_get<SPIRVariable>(ops[2]);
3550 	BuiltIn bi_type = BuiltIn(get_decoration(ops[2], DecorationBuiltIn));
3551 	if (var &&
3552 	    (var->storage == StorageClassInput ||
3553 	     (get_execution_model() == ExecutionModelTessellationControl && var->storage == StorageClassOutput)) &&
3554 	    !(has_decoration(ops[2], DecorationPatch) || is_patch_block(get_variable_data_type(*var))) &&
3555 	    (!is_builtin_variable(*var) || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
3556 	     bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
3557 	     get_variable_data_type(*var).basetype == SPIRType::Struct))
3558 	{
3559 		AccessChainMeta meta;
3560 		SmallVector<uint32_t> indices;
3561 		uint32_t next_id = ir.increase_bound_by(2);
3562 
3563 		indices.reserve(length - 3 + 1);
3564 		uint32_t type_id = next_id++;
3565 		SPIRType new_uint_type;
3566 		new_uint_type.basetype = SPIRType::UInt;
3567 		new_uint_type.width = 32;
3568 		set<SPIRType>(type_id, new_uint_type);
3569 
3570 		indices.push_back(ops[3]);
3571 
3572 		uint32_t const_mbr_id = next_id++;
3573 		uint32_t index = get_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex);
3574 		uint32_t ptr = var->storage == StorageClassInput ? stage_in_ptr_var_id : stage_out_ptr_var_id;
3575 		if (var->storage == StorageClassInput || has_decoration(get_variable_element_type(*var).self, DecorationBlock))
3576 		{
3577 			uint32_t i = 4;
3578 			auto *type = &get_variable_element_type(*var);
3579 			if (index == uint32_t(-1) && length >= 5)
3580 			{
3581 				// Maybe this is a struct type in the input class, in which case
3582 				// we put it as a decoration on the corresponding member.
3583 				index = get_extended_member_decoration(ops[2], get_constant(ops[4]).scalar(),
3584 				                                       SPIRVCrossDecorationInterfaceMemberIndex);
3585 				assert(index != uint32_t(-1));
3586 				i++;
3587 				type = &get<SPIRType>(type->member_types[get_constant(ops[4]).scalar()]);
3588 			}
3589 			// In this case, we flattened structures and arrays, so now we have to
3590 			// combine the following indices. If we encounter a non-constant index,
3591 			// we're hosed.
3592 			for (; i < length; ++i)
3593 			{
3594 				if (!is_array(*type) && !is_matrix(*type) && type->basetype != SPIRType::Struct)
3595 					break;
3596 
3597 				auto &c = get_constant(ops[i]);
3598 				index += c.scalar();
3599 				if (type->parent_type)
3600 					type = &get<SPIRType>(type->parent_type);
3601 				else if (type->basetype == SPIRType::Struct)
3602 					type = &get<SPIRType>(type->member_types[c.scalar()]);
3603 			}
3604 			// If the access chain terminates at a composite type, the composite
3605 			// itself might be copied. In that case, we must unflatten it.
3606 			if (is_matrix(*type) || is_array(*type) || type->basetype == SPIRType::Struct)
3607 			{
3608 				std::string temp_name = join(to_name(var->self), "_", ops[1]);
3609 				statement(variable_decl(*type, temp_name, var->self), ";");
3610 				// Set up the initializer for this temporary variable.
3611 				indices.push_back(const_mbr_id);
3612 				if (type->basetype == SPIRType::Struct)
3613 				{
3614 					for (uint32_t j = 0; j < type->member_types.size(); j++)
3615 					{
3616 						index = get_extended_member_decoration(ops[2], j, SPIRVCrossDecorationInterfaceMemberIndex);
3617 						const auto &mbr_type = get<SPIRType>(type->member_types[j]);
3618 						if (is_matrix(mbr_type))
3619 						{
3620 							for (uint32_t k = 0; k < mbr_type.columns; k++, index++)
3621 							{
3622 								set<SPIRConstant>(const_mbr_id, type_id, index, false);
3623 								auto e = access_chain(ptr, indices.data(), uint32_t(indices.size()), mbr_type, nullptr,
3624 								                      true);
3625 								statement(temp_name, ".", to_member_name(*type, j), "[", k, "] = ", e, ";");
3626 							}
3627 						}
3628 						else if (is_array(mbr_type))
3629 						{
3630 							for (uint32_t k = 0; k < mbr_type.array[0]; k++, index++)
3631 							{
3632 								set<SPIRConstant>(const_mbr_id, type_id, index, false);
3633 								auto e = access_chain(ptr, indices.data(), uint32_t(indices.size()), mbr_type, nullptr,
3634 								                      true);
3635 								statement(temp_name, ".", to_member_name(*type, j), "[", k, "] = ", e, ";");
3636 							}
3637 						}
3638 						else
3639 						{
3640 							set<SPIRConstant>(const_mbr_id, type_id, index, false);
3641 							auto e =
3642 							    access_chain(ptr, indices.data(), uint32_t(indices.size()), mbr_type, nullptr, true);
3643 							statement(temp_name, ".", to_member_name(*type, j), " = ", e, ";");
3644 						}
3645 					}
3646 				}
3647 				else if (is_matrix(*type))
3648 				{
3649 					for (uint32_t j = 0; j < type->columns; j++, index++)
3650 					{
3651 						set<SPIRConstant>(const_mbr_id, type_id, index, false);
3652 						auto e = access_chain(ptr, indices.data(), uint32_t(indices.size()), *type, nullptr, true);
3653 						statement(temp_name, "[", j, "] = ", e, ";");
3654 					}
3655 				}
3656 				else // Must be an array
3657 				{
3658 					assert(is_array(*type));
3659 					for (uint32_t j = 0; j < type->array[0]; j++, index++)
3660 					{
3661 						set<SPIRConstant>(const_mbr_id, type_id, index, false);
3662 						auto e = access_chain(ptr, indices.data(), uint32_t(indices.size()), *type, nullptr, true);
3663 						statement(temp_name, "[", j, "] = ", e, ";");
3664 					}
3665 				}
3666 
3667 				// This needs to be a variable instead of an expression so we don't
3668 				// try to dereference this as a variable pointer.
3669 				set<SPIRVariable>(ops[1], ops[0], var->storage);
3670 				ir.meta[ops[1]] = ir.meta[ops[2]];
3671 				set_name(ops[1], temp_name);
3672 				if (has_decoration(var->self, DecorationInvariant))
3673 					set_decoration(ops[1], DecorationInvariant);
3674 				for (uint32_t j = 2; j < length; j++)
3675 					inherit_expression_dependencies(ops[1], ops[j]);
3676 				return true;
3677 			}
3678 			else
3679 			{
3680 				set<SPIRConstant>(const_mbr_id, type_id, index, false);
3681 				indices.push_back(const_mbr_id);
3682 
3683 				if (i < length)
3684 					indices.insert(indices.end(), ops + i, ops + length);
3685 			}
3686 		}
3687 		else
3688 		{
3689 			assert(index != uint32_t(-1));
3690 			set<SPIRConstant>(const_mbr_id, type_id, index, false);
3691 			indices.push_back(const_mbr_id);
3692 
3693 			indices.insert(indices.end(), ops + 4, ops + length);
3694 		}
3695 
3696 		// We use the pointer to the base of the input/output array here,
3697 		// so this is always a pointer chain.
3698 		auto e = access_chain(ptr, indices.data(), uint32_t(indices.size()), get<SPIRType>(ops[0]), &meta, true);
3699 		auto &expr = set<SPIRExpression>(ops[1], move(e), ops[0], should_forward(ops[2]));
3700 		expr.loaded_from = var->self;
3701 		expr.need_transpose = meta.need_transpose;
3702 		expr.access_chain = true;
3703 
3704 		// Mark the result as being packed if necessary.
3705 		if (meta.storage_is_packed)
3706 			set_extended_decoration(ops[1], SPIRVCrossDecorationPacked);
3707 		if (meta.storage_packed_type != 0)
3708 			set_extended_decoration(ops[1], SPIRVCrossDecorationPackedType, meta.storage_packed_type);
3709 		if (meta.storage_is_invariant)
3710 			set_decoration(ops[1], DecorationInvariant);
3711 
3712 		for (uint32_t i = 2; i < length; i++)
3713 		{
3714 			inherit_expression_dependencies(ops[1], ops[i]);
3715 			add_implied_read_expression(expr, ops[i]);
3716 		}
3717 
3718 		return true;
3719 	}
3720 
3721 	// If this is the inner tessellation level, and we're tessellating triangles,
3722 	// drop the last index. It isn't an array in this case, so we can't have an
3723 	// array reference here. We need to make this ID a variable instead of an
3724 	// expression so we don't try to dereference it as a variable pointer.
3725 	// Don't do this if the index is a constant 1, though. We need to drop stores
3726 	// to that one.
3727 	auto *m = ir.find_meta(var ? var->self : 0);
3728 	if (get_execution_model() == ExecutionModelTessellationControl && var && m &&
3729 	    m->decoration.builtin_type == BuiltInTessLevelInner && get_entry_point().flags.get(ExecutionModeTriangles))
3730 	{
3731 		auto *c = maybe_get<SPIRConstant>(ops[3]);
3732 		if (c && c->scalar() == 1)
3733 			return false;
3734 		auto &dest_var = set<SPIRVariable>(ops[1], *var);
3735 		dest_var.basetype = ops[0];
3736 		ir.meta[ops[1]] = ir.meta[ops[2]];
3737 		inherit_expression_dependencies(ops[1], ops[2]);
3738 		return true;
3739 	}
3740 
3741 	return false;
3742 }
3743 
3744 bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
3745 {
3746 	if (!get_entry_point().flags.get(ExecutionModeTriangles))
3747 		return false;
3748 
3749 	// In SPIR-V, TessLevelInner always has two elements and TessLevelOuter always has
3750 	// four. This is true even if we are tessellating triangles. This allows clients
3751 	// to use a single tessellation control shader with multiple tessellation evaluation
3752 	// shaders.
3753 	// In Metal, however, only the first element of TessLevelInner and the first three
3754 	// of TessLevelOuter are accessible. This stems from how in Metal, the tessellation
3755 	// levels must be stored to a dedicated buffer in a particular format that depends
3756 	// on the patch type. Therefore, in Triangles mode, any access to the second
3757 	// inner level or the fourth outer level must be dropped.
3758 	const auto *e = maybe_get<SPIRExpression>(id_lhs);
3759 	if (!e || !e->access_chain)
3760 		return false;
3761 	BuiltIn builtin = BuiltIn(get_decoration(e->loaded_from, DecorationBuiltIn));
3762 	if (builtin != BuiltInTessLevelInner && builtin != BuiltInTessLevelOuter)
3763 		return false;
3764 	auto *c = maybe_get<SPIRConstant>(e->implied_read_expressions[1]);
3765 	if (!c)
3766 		return false;
3767 	return (builtin == BuiltInTessLevelInner && c->scalar() == 1) ||
3768 	       (builtin == BuiltInTessLevelOuter && c->scalar() == 3);
3769 }
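// In practice this means stores to gl_TessLevelInner[1] and gl_TessLevelOuter[3]
// are silently dropped when the Triangles execution mode is active.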
3770 
3771 // Override for MSL-specific syntax instructions
3772 void CompilerMSL::emit_instruction(const Instruction &instruction)
3773 {
3774 #define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
3775 #define MSL_BOP_CAST(op, type) \
3776 	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
3777 #define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
3778 #define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
3779 #define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
3780 #define MSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
3781 #define MSL_BFOP_CAST(op, type) \
3782 	emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
3783 #define MSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
3784 #define MSL_UNORD_BOP(op) emit_binary_unord_op(ops[0], ops[1], ops[2], ops[3], #op)
3785 
3786 	auto ops = stream(instruction);
3787 	auto opcode = static_cast<Op>(instruction.op);
3788 
3789 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
3790 	uint32_t integer_width = get_integer_width_for_instruction(instruction);
3791 	auto int_type = to_signed_basetype(integer_width);
3792 	auto uint_type = to_unsigned_basetype(integer_width);
3793 
3794 	switch (opcode)
3795 	{
3796 
3797 	// Comparisons
3798 	case OpIEqual:
3799 		MSL_BOP_CAST(==, int_type);
3800 		break;
3801 
3802 	case OpLogicalEqual:
3803 	case OpFOrdEqual:
3804 		MSL_BOP(==);
3805 		break;
3806 
3807 	case OpINotEqual:
3808 		MSL_BOP_CAST(!=, int_type);
3809 		break;
3810 
3811 	case OpLogicalNotEqual:
3812 	case OpFOrdNotEqual:
3813 		MSL_BOP(!=);
3814 		break;
3815 
3816 	case OpUGreaterThan:
3817 		MSL_BOP_CAST(>, uint_type);
3818 		break;
3819 
3820 	case OpSGreaterThan:
3821 		MSL_BOP_CAST(>, int_type);
3822 		break;
3823 
3824 	case OpFOrdGreaterThan:
3825 		MSL_BOP(>);
3826 		break;
3827 
3828 	case OpUGreaterThanEqual:
3829 		MSL_BOP_CAST(>=, uint_type);
3830 		break;
3831 
3832 	case OpSGreaterThanEqual:
3833 		MSL_BOP_CAST(>=, int_type);
3834 		break;
3835 
3836 	case OpFOrdGreaterThanEqual:
3837 		MSL_BOP(>=);
3838 		break;
3839 
3840 	case OpULessThan:
3841 		MSL_BOP_CAST(<, uint_type);
3842 		break;
3843 
3844 	case OpSLessThan:
3845 		MSL_BOP_CAST(<, int_type);
3846 		break;
3847 
3848 	case OpFOrdLessThan:
3849 		MSL_BOP(<);
3850 		break;
3851 
3852 	case OpULessThanEqual:
3853 		MSL_BOP_CAST(<=, uint_type);
3854 		break;
3855 
3856 	case OpSLessThanEqual:
3857 		MSL_BOP_CAST(<=, int_type);
3858 		break;
3859 
3860 	case OpFOrdLessThanEqual:
3861 		MSL_BOP(<=);
3862 		break;
3863 
3864 	case OpFUnordEqual:
3865 		MSL_UNORD_BOP(==);
3866 		break;
3867 
3868 	case OpFUnordNotEqual:
3869 		MSL_UNORD_BOP(!=);
3870 		break;
3871 
3872 	case OpFUnordGreaterThan:
3873 		MSL_UNORD_BOP(>);
3874 		break;
3875 
3876 	case OpFUnordGreaterThanEqual:
3877 		MSL_UNORD_BOP(>=);
3878 		break;
3879 
3880 	case OpFUnordLessThan:
3881 		MSL_UNORD_BOP(<);
3882 		break;
3883 
3884 	case OpFUnordLessThanEqual:
3885 		MSL_UNORD_BOP(<=);
3886 		break;
3887 
3888 	// Derivatives
3889 	case OpDPdx:
3890 	case OpDPdxFine:
3891 	case OpDPdxCoarse:
3892 		MSL_UFOP(dfdx);
3893 		register_control_dependent_expression(ops[1]);
3894 		break;
3895 
3896 	case OpDPdy:
3897 	case OpDPdyFine:
3898 	case OpDPdyCoarse:
3899 		MSL_UFOP(dfdy);
3900 		register_control_dependent_expression(ops[1]);
3901 		break;
3902 
3903 	case OpFwidth:
3904 	case OpFwidthCoarse:
3905 	case OpFwidthFine:
3906 		MSL_UFOP(fwidth);
3907 		register_control_dependent_expression(ops[1]);
3908 		break;
3909 
3910 	// Bitfield
3911 	case OpBitFieldInsert:
3912 		MSL_QFOP(insert_bits);
3913 		break;
3914 
3915 	case OpBitFieldSExtract:
3916 	case OpBitFieldUExtract:
3917 		MSL_TFOP(extract_bits);
3918 		break;
3919 
3920 	case OpBitReverse:
3921 		MSL_UFOP(reverse_bits);
3922 		break;
3923 
3924 	case OpBitCount:
3925 		MSL_UFOP(popcount);
3926 		break;
3927 
3928 	case OpFRem:
3929 		MSL_BFOP(fmod);
3930 		break;
3931 
3932 	// Atomics
3933 	case OpAtomicExchange:
3934 	{
3935 		uint32_t result_type = ops[0];
3936 		uint32_t id = ops[1];
3937 		uint32_t ptr = ops[2];
3938 		uint32_t mem_sem = ops[4];
3939 		uint32_t val = ops[5];
3940 		emit_atomic_func_op(result_type, id, "atomic_exchange_explicit", mem_sem, mem_sem, false, ptr, val);
3941 		break;
3942 	}
3943 
3944 	case OpAtomicCompareExchange:
3945 	{
3946 		uint32_t result_type = ops[0];
3947 		uint32_t id = ops[1];
3948 		uint32_t ptr = ops[2];
3949 		uint32_t mem_sem_pass = ops[4];
3950 		uint32_t mem_sem_fail = ops[5];
3951 		uint32_t val = ops[6];
3952 		uint32_t comp = ops[7];
3953 		emit_atomic_func_op(result_type, id, "atomic_compare_exchange_weak_explicit", mem_sem_pass, mem_sem_fail, true,
3954 		                    ptr, comp, true, false, val);
3955 		break;
3956 	}
3957 
3958 	case OpAtomicCompareExchangeWeak:
3959 		SPIRV_CROSS_THROW("OpAtomicCompareExchangeWeak is only supported in kernel profile.");
3960 
3961 	case OpAtomicLoad:
3962 	{
3963 		uint32_t result_type = ops[0];
3964 		uint32_t id = ops[1];
3965 		uint32_t ptr = ops[2];
3966 		uint32_t mem_sem = ops[4];
3967 		emit_atomic_func_op(result_type, id, "atomic_load_explicit", mem_sem, mem_sem, false, ptr, 0);
3968 		break;
3969 	}
3970 
3971 	case OpAtomicStore:
3972 	{
3973 		uint32_t result_type = expression_type(ops[0]).self;
3974 		uint32_t id = ops[0];
3975 		uint32_t ptr = ops[0];
3976 		uint32_t mem_sem = ops[2];
3977 		uint32_t val = ops[3];
3978 		emit_atomic_func_op(result_type, id, "atomic_store_explicit", mem_sem, mem_sem, false, ptr, val);
3979 		break;
3980 	}
3981 
3982 #define MSL_AFMO_IMPL(op, valsrc, valconst)                                                                      \
3983 	do                                                                                                           \
3984 	{                                                                                                            \
3985 		uint32_t result_type = ops[0];                                                                           \
3986 		uint32_t id = ops[1];                                                                                    \
3987 		uint32_t ptr = ops[2];                                                                                   \
3988 		uint32_t mem_sem = ops[4];                                                                               \
3989 		uint32_t val = valsrc;                                                                                   \
3990 		emit_atomic_func_op(result_type, id, "atomic_fetch_" #op "_explicit", mem_sem, mem_sem, false, ptr, val, \
3991 		                    false, valconst);                                                                    \
3992 	} while (false)
3993 
3994 #define MSL_AFMO(op) MSL_AFMO_IMPL(op, ops[5], false)
3995 #define MSL_AFMIO(op) MSL_AFMO_IMPL(op, 1, true)
3996 
3997 	case OpAtomicIIncrement:
3998 		MSL_AFMIO(add);
3999 		break;
4000 
4001 	case OpAtomicIDecrement:
4002 		MSL_AFMIO(sub);
4003 		break;
4004 
4005 	case OpAtomicIAdd:
4006 		MSL_AFMO(add);
4007 		break;
4008 
4009 	case OpAtomicISub:
4010 		MSL_AFMO(sub);
4011 		break;
4012 
4013 	case OpAtomicSMin:
4014 	case OpAtomicUMin:
4015 		MSL_AFMO(min);
4016 		break;
4017 
4018 	case OpAtomicSMax:
4019 	case OpAtomicUMax:
4020 		MSL_AFMO(max);
4021 		break;
4022 
4023 	case OpAtomicAnd:
4024 		MSL_AFMO(and);
4025 		break;
4026 
4027 	case OpAtomicOr:
4028 		MSL_AFMO(or);
4029 		break;
4030 
4031 	case OpAtomicXor:
4032 		MSL_AFMO(xor);
4033 		break;
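	// E.g. OpAtomicIAdd expands via MSL_AFMO(add) to a call emitting
	// atomic_fetch_add_explicit(...), while OpAtomicIIncrement goes through
	// MSL_AFMIO(add) with a constant operand of 1.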
4034 
4035 	// Images
4036 
4037 	// Reads == Fetches in Metal
4038 	case OpImageRead:
4039 	{
4040 		// Mark that this shader reads from this image
4041 		uint32_t img_id = ops[2];
4042 		auto &type = expression_type(img_id);
4043 		if (type.image.dim != DimSubpassData)
4044 		{
4045 			auto *p_var = maybe_get_backing_variable(img_id);
4046 			if (p_var && has_decoration(p_var->self, DecorationNonReadable))
4047 			{
4048 				unset_decoration(p_var->self, DecorationNonReadable);
4049 				force_recompile();
4050 			}
4051 		}
4052 
4053 		emit_texture_op(instruction);
4054 		break;
4055 	}
4056 
4057 	case OpImageWrite:
4058 	{
4059 		uint32_t img_id = ops[0];
4060 		uint32_t coord_id = ops[1];
4061 		uint32_t texel_id = ops[2];
4062 		const uint32_t *opt = &ops[3];
4063 		uint32_t length = instruction.length - 3;
4064 
4065 		// Bypass pointers because we need the real image struct
4066 		auto &type = expression_type(img_id);
4067 		auto &img_type = get<SPIRType>(type.self);
4068 
4069 		// Ensure this image has been marked as being written to and force a
4070 		// recompile so that the image type output will include write access
4071 		auto *p_var = maybe_get_backing_variable(img_id);
4072 		if (p_var && has_decoration(p_var->self, DecorationNonWritable))
4073 		{
4074 			unset_decoration(p_var->self, DecorationNonWritable);
4075 			force_recompile();
4076 		}
4077 
4078 		bool forward = false;
4079 		uint32_t bias = 0;
4080 		uint32_t lod = 0;
4081 		uint32_t flags = 0;
4082 
4083 		if (length)
4084 		{
4085 			flags = *opt++;
4086 			length--;
4087 		}
4088 
4089 		auto test = [&](uint32_t &v, uint32_t flag) {
4090 			if (length && (flags & flag))
4091 			{
4092 				v = *opt++;
4093 				length--;
4094 			}
4095 		};
4096 
4097 		test(bias, ImageOperandsBiasMask);
4098 		test(lod, ImageOperandsLodMask);
4099 
4100 		auto &texel_type = expression_type(texel_id);
4101 		auto store_type = texel_type;
4102 		store_type.vecsize = 4;
4103 
4104 		statement(join(to_expression(img_id), ".write(",
4105 		               remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)), ", ",
4106 		               to_function_args(img_id, img_type, true, false, false, coord_id, 0, 0, 0, 0, lod, 0, 0, 0, 0, 0,
4107 		                                0, &forward),
4108 		               ");"));
4109 
4110 		if (p_var && variable_storage_is_aliased(*p_var))
4111 			flush_all_aliased_variables();
4112 
4113 		break;
4114 	}
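	// E.g. an OpImageWrite to a 2D storage texture emits along the lines of
	// tex.write(float4(texel), uint2(coord)); (a sketch -- the texel is widened
	// to four components via remap_swizzle above).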
4115 
4116 	case OpImageQuerySize:
4117 	case OpImageQuerySizeLod:
4118 	{
4119 		uint32_t rslt_type_id = ops[0];
4120 		auto &rslt_type = get<SPIRType>(rslt_type_id);
4121 
4122 		uint32_t id = ops[1];
4123 
4124 		uint32_t img_id = ops[2];
4125 		string img_exp = to_expression(img_id);
4126 		auto &img_type = expression_type(img_id);
4127 		Dim img_dim = img_type.image.dim;
4128 		bool img_is_array = img_type.image.arrayed;
4129 
4130 		if (img_type.basetype != SPIRType::Image)
4131 			SPIRV_CROSS_THROW("Invalid type for OpImageQuerySize.");
4132 
4133 		string lod;
4134 		if (opcode == OpImageQuerySizeLod)
4135 		{
4136 			// LOD index defaults to zero, so don't bother outputting a level-zero index
4137 			string decl_lod = to_expression(ops[3]);
4138 			if (decl_lod != "0")
4139 				lod = decl_lod;
4140 		}
4141 
4142 		string expr = type_to_glsl(rslt_type) + "(";
4143 		expr += img_exp + ".get_width(" + lod + ")";
4144 
4145 		if (img_dim == Dim2D || img_dim == DimCube || img_dim == Dim3D)
4146 			expr += ", " + img_exp + ".get_height(" + lod + ")";
4147 
4148 		if (img_dim == Dim3D)
4149 			expr += ", " + img_exp + ".get_depth(" + lod + ")";
4150 
4151 		if (img_is_array)
4152 			expr += ", " + img_exp + ".get_array_size()";
4153 
4154 		expr += ")";
4155 
4156 		emit_op(rslt_type_id, id, expr, should_forward(img_id));
4157 
4158 		break;
4159 	}
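	// E.g. OpImageQuerySize on an arrayed 2D texture builds an expression along
	// the lines of int3(tex.get_width(), tex.get_height(), tex.get_array_size())
	// (a sketch; the constructor follows the SPIR-V result type).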
4160 
4161 	case OpImageQueryLod:
4162 	{
4163 		if (!msl_options.supports_msl_version(2, 2))
4164 			SPIRV_CROSS_THROW("ImageQueryLod is only supported on MSL 2.2 and up.");
4165 		uint32_t result_type = ops[0];
4166 		uint32_t id = ops[1];
4167 		uint32_t image_id = ops[2];
4168 		uint32_t coord_id = ops[3];
4169 		emit_uninitialized_temporary_expression(result_type, id);
4170 
4171 		auto sampler_expr = to_sampler_expression(image_id);
4172 		auto *combined = maybe_get<SPIRCombinedImageSampler>(image_id);
4173 		auto image_expr = combined ? to_expression(combined->image) : to_expression(image_id);
4174 
4175 		// TODO: It is unclear if calculate_clamped_lod also conditionally rounds
4176 		// the reported LOD based on the sampler. NEAREST miplevel should
4177 		// round the LOD, but LINEAR miplevel should not round.
4178 		// Let's hope this does not become an issue ...
4179 		statement(to_expression(id), ".x = ", image_expr, ".calculate_clamped_lod(", sampler_expr, ", ",
4180 		          to_expression(coord_id), ");");
4181 		statement(to_expression(id), ".y = ", image_expr, ".calculate_unclamped_lod(", sampler_expr, ", ",
4182 		          to_expression(coord_id), ");");
4183 		register_control_dependent_expression(id);
4184 		break;
4185 	}
4186 
4187 #define MSL_ImgQry(qrytype)                                                                 \
4188 	do                                                                                      \
4189 	{                                                                                       \
4190 		uint32_t rslt_type_id = ops[0];                                                     \
4191 		auto &rslt_type = get<SPIRType>(rslt_type_id);                                      \
4192 		uint32_t id = ops[1];                                                               \
4193 		uint32_t img_id = ops[2];                                                           \
4194 		string img_exp = to_expression(img_id);                                             \
4195 		string expr = type_to_glsl(rslt_type) + "(" + img_exp + ".get_num_" #qrytype "())"; \
4196 		emit_op(rslt_type_id, id, expr, should_forward(img_id));                            \
4197 	} while (false)
4198 
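	// For instance, OpImageQueryLevels on a texture bound as 'tex' expands to
	// roughly int(tex.get_num_mip_levels()) (a sketch; the cast follows the
	// SPIR-V result type).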
4199 	case OpImageQueryLevels:
4200 		MSL_ImgQry(mip_levels);
4201 		break;
4202 
4203 	case OpImageQuerySamples:
4204 		MSL_ImgQry(samples);
4205 		break;
4206 
4207 	case OpImage:
4208 	{
4209 		uint32_t result_type = ops[0];
4210 		uint32_t id = ops[1];
4211 		auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);
4212 
4213 		if (combined)
4214 		{
4215 			auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
4216 			auto *var = maybe_get_backing_variable(combined->image);
4217 			if (var)
4218 				e.loaded_from = var->self;
4219 		}
4220 		else
4221 		{
4222 			auto &e = emit_op(result_type, id, to_expression(ops[2]), true, true);
4223 			auto *var = maybe_get_backing_variable(ops[2]);
4224 			if (var)
4225 				e.loaded_from = var->self;
4226 		}
4227 		break;
4228 	}
4229 
4230 	case OpImageTexelPointer:
4231 		SPIRV_CROSS_THROW("MSL does not support atomic operations on images or texel buffers.");
4232 
4233 	// Casting
4234 	case OpQuantizeToF16:
4235 	{
4236 		uint32_t result_type = ops[0];
4237 		uint32_t id = ops[1];
4238 		uint32_t arg = ops[2];
4239 
4240 		string exp;
4241 		auto &type = get<SPIRType>(result_type);
4242 
4243 		switch (type.vecsize)
4244 		{
4245 		case 1:
4246 			exp = join("float(half(", to_expression(arg), "))");
4247 			break;
4248 		case 2:
4249 			exp = join("float2(half2(", to_expression(arg), "))");
4250 			break;
4251 		case 3:
4252 			exp = join("float3(half3(", to_expression(arg), "))");
4253 			break;
4254 		case 4:
4255 			exp = join("float4(half4(", to_expression(arg), "))");
4256 			break;
4257 		default:
4258 			SPIRV_CROSS_THROW("Illegal argument to OpQuantizeToF16.");
4259 		}
4260 
4261 		emit_op(result_type, id, exp, should_forward(arg));
4262 		break;
4263 	}
4264 
4265 	case OpInBoundsAccessChain:
4266 	case OpAccessChain:
4267 	case OpPtrAccessChain:
4268 		if (is_tessellation_shader())
4269 		{
4270 			if (!emit_tessellation_access_chain(ops, instruction.length))
4271 				CompilerGLSL::emit_instruction(instruction);
4272 		}
4273 		else
4274 			CompilerGLSL::emit_instruction(instruction);
4275 		break;
4276 
4277 	case OpStore:
4278 		if (is_out_of_bounds_tessellation_level(ops[0]))
4279 			break;
4280 
4281 		if (maybe_emit_array_assignment(ops[0], ops[1]))
4282 			break;
4283 
4284 		CompilerGLSL::emit_instruction(instruction);
4285 		break;
4286 
4287 	// Compute barriers
4288 	case OpMemoryBarrier:
4289 		emit_barrier(0, ops[0], ops[1]);
4290 		break;
4291 
4292 	case OpControlBarrier:
4293 		// In GLSL a memory barrier is often followed by a control barrier.
4294 		// But in MSL, memory barriers are also control barriers, so don't
4295 		// emit a simple control barrier if a memory barrier has just been emitted.
4296 		if (previous_instruction_opcode != OpMemoryBarrier)
4297 			emit_barrier(ops[0], ops[1], ops[2]);
4298 		break;
4299 
4300 	case OpVectorTimesMatrix:
4301 	case OpMatrixTimesVector:
4302 	{
4303 		// If the matrix needs transpose and it is square or packed, just flip the multiply order.
4304 		uint32_t mtx_id = ops[opcode == OpMatrixTimesVector ? 2 : 3];
4305 		auto *e = maybe_get<SPIRExpression>(mtx_id);
4306 		auto &t = expression_type(mtx_id);
4307 		bool is_packed = has_extended_decoration(mtx_id, SPIRVCrossDecorationPacked);
4308 		if (e && e->need_transpose && (t.columns == t.vecsize || is_packed))
4309 		{
4310 			e->need_transpose = false;
4311 			emit_binary_op(ops[0], ops[1], ops[3], ops[2], "*");
4312 			e->need_transpose = true;
4313 		}
4314 		else
4315 			MSL_BOP(*);
4316 		break;
4317 	}
4318 
4319 	case OpOuterProduct:
4320 	{
4321 		uint32_t result_type = ops[0];
4322 		uint32_t id = ops[1];
4323 		uint32_t a = ops[2];
4324 		uint32_t b = ops[3];
4325 
4326 		auto &type = get<SPIRType>(result_type);
4327 		string expr = type_to_glsl_constructor(type);
4328 		expr += "(";
4329 		for (uint32_t col = 0; col < type.columns; col++)
4330 		{
4331 			expr += to_enclosed_expression(a);
4332 			expr += " * ";
4333 			expr += to_extract_component_expression(b, col);
4334 			if (col + 1 < type.columns)
4335 				expr += ", ";
4336 		}
4337 		expr += ")";
4338 		emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
4339 		inherit_expression_dependencies(id, a);
4340 		inherit_expression_dependencies(id, b);
4341 		break;
4342 	}
4343 
4344 	case OpIAddCarry:
4345 	case OpISubBorrow:
4346 	{
4347 		uint32_t result_type = ops[0];
4348 		uint32_t result_id = ops[1];
4349 		uint32_t op0 = ops[2];
4350 		uint32_t op1 = ops[3];
4351 		forced_temporaries.insert(result_id);
4352 		auto &type = get<SPIRType>(result_type);
4353 		statement(variable_decl(type, to_name(result_id)), ";");
4354 		set<SPIRExpression>(result_id, to_name(result_id), result_type, true);
4355 
4356 		auto &res_type = get<SPIRType>(type.member_types[1]);
4357 		if (opcode == OpIAddCarry)
4358 		{
4359 			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " + ",
4360 			          to_enclosed_expression(op1), ";");
4361 			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
4362 			          "(1), ", type_to_glsl(res_type), "(0), ", to_expression(result_id), ".", to_member_name(type, 0),
4363 			          " >= max(", to_expression(op0), ", ", to_expression(op1), "));");
4364 		}
4365 		else
4366 		{
4367 			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " - ",
4368 			          to_enclosed_expression(op1), ";");
4369 			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
4370 			          "(1), ", type_to_glsl(res_type), "(0), ", to_enclosed_expression(op0),
4371 			          " >= ", to_enclosed_expression(op1), ");");
4372 		}
4373 		break;
4374 	}
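	// Sketch of the emitted MSL for OpIAddCarry (member names are illustrative):
	//
	//     _res._m0 = a + b;
	//     _res._m1 = select(uint(1), uint(0), _res._m0 >= max(a, b));
	//
	// i.e. the carry is 1 exactly when the unsigned addition wrapped around.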
4375 
4376 	case OpUMulExtended:
4377 	case OpSMulExtended:
4378 	{
4379 		uint32_t result_type = ops[0];
4380 		uint32_t result_id = ops[1];
4381 		uint32_t op0 = ops[2];
4382 		uint32_t op1 = ops[3];
4383 		forced_temporaries.insert(result_id);
4384 		auto &type = get<SPIRType>(result_type);
4385 		statement(variable_decl(type, to_name(result_id)), ";");
4386 		set<SPIRExpression>(result_id, to_name(result_id), result_type, true);
4387 
4388 		statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " * ",
4389 		          to_enclosed_expression(op1), ";");
4390 		statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(", to_expression(op0), ", ",
4391 		          to_expression(op1), ");");
4392 		break;
4393 	}
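	// mulhi() yields the high half of the double-width product, so the two
	// struct members together hold the full result of the 32-bit multiply.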
4394 
4395 	case OpArrayLength:
4396 	{
4397 		auto &type = expression_type(ops[2]);
4398 		uint32_t offset = type_struct_member_offset(type, ops[3]);
4399 		uint32_t stride = type_struct_member_array_stride(type, ops[3]);
4400 
4401 		auto expr = join("(", to_buffer_size_expression(ops[2]), " - ", offset, ") / ", stride);
4402 		emit_op(ops[0], ops[1], expr, true);
4403 		break;
4404 	}
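	// E.g. a runtime array at offset 16 with a 4-byte stride yields an
	// expression like (size_expr - 16) / 4 (a sketch), where size_expr comes
	// from the auxiliary buffer-size argument via to_buffer_size_expression().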
4405 
4406 	default:
4407 		CompilerGLSL::emit_instruction(instruction);
4408 		break;
4409 	}
4410 
4411 	previous_instruction_opcode = opcode;
4412 }
4413 
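// A sketch of the common mapping below: OpControlBarrier with Workgroup
// execution scope and WorkgroupMemory semantics typically becomes
//
//     threadgroup_barrier(mem_flags::mem_threadgroup);
//
// while subgroup-scoped barriers on MSL 2 (or iOS MSL 1.2) map to
// simdgroup_barrier() instead.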
4414 void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem)
4415 {
4416 	if (get_execution_model() != ExecutionModelGLCompute && get_execution_model() != ExecutionModelTessellationControl)
4417 		return;
4418 
4419 	uint32_t exe_scope = id_exe_scope ? get<SPIRConstant>(id_exe_scope).scalar() : uint32_t(ScopeInvocation);
4420 	uint32_t mem_scope = id_mem_scope ? get<SPIRConstant>(id_mem_scope).scalar() : uint32_t(ScopeInvocation);
4421 	// Use the wider of the two scopes (smaller value)
4422 	exe_scope = min(exe_scope, mem_scope);
4423 
4424 	string bar_stmt;
4425 	if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
4426 		bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
4427 	else
4428 		bar_stmt = "threadgroup_barrier";
4429 	bar_stmt += "(";
4430 
4431 	uint32_t mem_sem = id_mem_sem ? get<SPIRConstant>(id_mem_sem).scalar() : uint32_t(MemorySemanticsMaskNone);
4432 
4433 	// Use the | operator to combine flags if we can.
4434 	if (msl_options.supports_msl_version(1, 2))
4435 	{
4436 		string mem_flags = "";
4437 		// For tesc shaders, this also affects objects in the Output storage class.
4438 		// Since in Metal, these are placed in a device buffer, we have to sync device memory here.
4439 		if (get_execution_model() == ExecutionModelTessellationControl ||
4440 		    (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)))
4441 			mem_flags += "mem_flags::mem_device";
4442 		if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
4443 		               MemorySemanticsAtomicCounterMemoryMask))
4444 		{
4445 			if (!mem_flags.empty())
4446 				mem_flags += " | ";
4447 			mem_flags += "mem_flags::mem_threadgroup";
4448 		}
4449 		if (mem_sem & MemorySemanticsImageMemoryMask)
4450 		{
4451 			if (!mem_flags.empty())
4452 				mem_flags += " | ";
4453 			mem_flags += "mem_flags::mem_texture";
4454 		}
4455 
4456 		if (mem_flags.empty())
4457 			mem_flags = "mem_flags::mem_none";
4458 
4459 		bar_stmt += mem_flags;
4460 	}
4461 	else
4462 	{
4463 		if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) &&
4464 		    (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
4465 		                MemorySemanticsAtomicCounterMemoryMask)))
4466 			bar_stmt += "mem_flags::mem_device_and_threadgroup";
4467 		else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))
4468 			bar_stmt += "mem_flags::mem_device";
4469 		else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask |
4470 		                    MemorySemanticsAtomicCounterMemoryMask))
4471 			bar_stmt += "mem_flags::mem_threadgroup";
4472 		else if (mem_sem & MemorySemanticsImageMemoryMask)
4473 			bar_stmt += "mem_flags::mem_texture";
4474 		else
4475 			bar_stmt += "mem_flags::mem_none";
4476 	}
4477 
4478 	if (msl_options.is_ios() && (msl_options.supports_msl_version(2) && !msl_options.supports_msl_version(2, 1)))
4479 	{
4480 		bar_stmt += ", ";
4481 
4482 		switch (mem_scope)
4483 		{
4484 		case ScopeCrossDevice:
4485 		case ScopeDevice:
4486 			bar_stmt += "memory_scope_device";
4487 			break;
4488 
4489 		case ScopeSubgroup:
4490 		case ScopeInvocation:
4491 			bar_stmt += "memory_scope_simdgroup";
4492 			break;
4493 
4494 		case ScopeWorkgroup:
4495 		default:
4496 			bar_stmt += "memory_scope_threadgroup";
4497 			break;
4498 		}
4499 	}
4500 
4501 	bar_stmt += ");";
4502 
4503 	statement(bar_stmt);
4504 
4505 	assert(current_emitting_block);
4506 	flush_control_dependent_expressions(current_emitting_block->self);
4507 	flush_all_active_variables();
4508 }
4509 
4510 void CompilerMSL::emit_array_copy(const string &lhs, uint32_t rhs_id)
4511 {
4512 	// Assignment from an array initializer is fine.
4513 	auto &type = expression_type(rhs_id);
4514 	auto *var = maybe_get_backing_variable(rhs_id);
4515 
4516 	// Unfortunately, we cannot template on address space in MSL,
4517 	// so explicit address space redirection it is ...
4518 	bool is_constant = false;
4519 	if (ir.ids[rhs_id].get_type() == TypeConstant)
4520 	{
4521 		is_constant = true;
4522 	}
4523 	else if (var && var->remapped_variable && var->statically_assigned &&
4524 	         ir.ids[var->static_expression].get_type() == TypeConstant)
4525 	{
4526 		is_constant = true;
4527 	}
4528 
4529 	// For the case where we have OpLoad triggering an array copy,
4530 	// we cannot easily detect it ahead of time since it's
4531 	// context dependent. We might have to force a recompile here
4532 	// if this is the only use of array copies in our shader.
4533 	if (type.array.size() > 1)
4534 	{
4535 		if (type.array.size() > SPVFuncImplArrayCopyMultidimMax)
4536 			SPIRV_CROSS_THROW("Cannot support this many dimensions for arrays of arrays.");
4537 		auto func = static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + type.array.size());
4538 		if (spv_function_implementations.count(func) == 0)
4539 		{
4540 			spv_function_implementations.insert(func);
4541 			suppress_missing_prototypes = true;
4542 			force_recompile();
4543 		}
4544 	}
4545 	else if (spv_function_implementations.count(SPVFuncImplArrayCopy) == 0)
4546 	{
4547 		spv_function_implementations.insert(SPVFuncImplArrayCopy);
4548 		suppress_missing_prototypes = true;
4549 		force_recompile();
4550 	}
4551 
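	// E.g. a two-dimensional array copied from a thread-local source emits
	// spvArrayCopyFromStack2(dst, src); (a sketch) -- the numeric suffix is the
	// dimension count and the tag reflects the source address space.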
4552 	const char *tag = is_constant ? "FromConstant" : "FromStack";
4553 	statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ");");
4554 }
4555 
4556 // Since MSL does not allow arrays to be copied via simple variable assignment,
4557 // if the LHS and RHS represent an assignment of an entire array, it must be
4558 // implemented by calling an array copy function.
4559 // Returns whether the array assignment was emitted.
4560 bool CompilerMSL::maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs)
4561 {
4562 	// We only care about assignments of an entire array
4563 	auto &type = expression_type(id_rhs);
4564 	if (type.array.size() == 0)
4565 		return false;
4566 
4567 	auto *var = maybe_get<SPIRVariable>(id_lhs);
4568 
4569 	// Is this a remapped, static constant? Don't do anything.
4570 	if (var && var->remapped_variable && var->statically_assigned)
4571 		return true;
4572 
4573 	if (ir.ids[id_rhs].get_type() == TypeConstant && var && var->deferred_declaration)
4574 	{
4575 		// Special case, if we end up declaring a variable when assigning the constant array,
4576 		// we can avoid the copy by directly assigning the constant expression.
4577 		// This is likely necessary to be able to use a variable as a true look-up table, as it is unlikely
4578 		// the compiler will be able to optimize the spvArrayCopy() into a constant LUT.
4579 		// After a variable has been declared, we can no longer assign constant arrays in MSL unfortunately.
4580 		statement(to_expression(id_lhs), " = ", constant_expression(get<SPIRConstant>(id_rhs)), ";");
4581 		return true;
4582 	}
4583 
4584 	// Ensure the LHS variable has been declared
4585 	auto *p_v_lhs = maybe_get_backing_variable(id_lhs);
4586 	if (p_v_lhs)
4587 		flush_variable_declaration(p_v_lhs->self);
4588 
4589 	emit_array_copy(to_expression(id_lhs), id_rhs);
4590 	register_write(id_lhs);
4591 
4592 	return true;
4593 }
4594 
4595 // Emits one of the atomic functions. In MSL, the atomic functions operate on pointers to the atomic object.
4596 void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, uint32_t mem_order_1,
4597                                       uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
4598                                       bool op1_is_pointer, bool op1_is_literal, uint32_t op2)
4599 {
4600 	forced_temporaries.insert(result_id);
4601 
4602 	string exp = string(op) + "(";
4603 
4604 	auto &type = get_pointee_type(expression_type(obj));
4605 	exp += "(volatile ";
4606 	auto *var = maybe_get_backing_variable(obj);
4607 	if (!var)
4608 		SPIRV_CROSS_THROW("No backing variable for atomic operation.");
4609 	exp += get_argument_address_space(*var);
4610 	exp += " atomic_";
4611 	exp += type_to_glsl(type);
4612 	exp += "*)";
4613 
4614 	exp += "&";
4615 	exp += to_enclosed_expression(obj);
4616 
4617 	bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;
4618 
4619 	if (is_atomic_compare_exchange_strong)
4620 	{
4621 		assert(strcmp(op, "atomic_compare_exchange_weak_explicit") == 0);
4622 		assert(op2);
4623 		assert(has_mem_order_2);
4624 		exp += ", &";
4625 		exp += to_name(result_id);
4626 		exp += ", ";
4627 		exp += to_expression(op2);
4628 		exp += ", ";
4629 		exp += get_memory_order(mem_order_1);
4630 		exp += ", ";
4631 		exp += get_memory_order(mem_order_2);
4632 		exp += ")";
4633 
4634 		// MSL only supports the weak atomic compare exchange, so emit a CAS loop here.
4635 		// The MSL function returns false if the atomic write fails OR the comparison test fails,
4636 		// so we must validate that it wasn't the comparison test that failed before continuing
4637 		// the CAS loop, otherwise it will loop infinitely, with the comparison test always failing.
4638 	// The function updates the comparator value from the memory value, so the additional
4639 		// comparison test evaluates the memory value against the expected value.
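	// Sketch of the emitted MSL (identifiers are illustrative):
	//
	//     uint _tmp;
	//     do
	//     {
	//         _tmp = expected;
	//     } while (!atomic_compare_exchange_weak_explicit(
	//                  (volatile device atomic_uint*)&buf.value, &_tmp, desired,
	//                  memory_order_relaxed, memory_order_relaxed) &&
	//              _tmp == expected);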
4640 		statement(variable_decl(type, to_name(result_id)), ";");
4641 		statement("do");
4642 		begin_scope();
4643 		statement(to_name(result_id), " = ", to_expression(op1), ";");
4644 		end_scope_decl(join("while (!", exp, " && ", to_name(result_id), " == ", to_enclosed_expression(op1), ")"));
4645 		set<SPIRExpression>(result_id, to_name(result_id), result_type, true);
4646 	}
4647 	else
4648 	{
4649 		assert(strcmp(op, "atomic_compare_exchange_weak_explicit") != 0);
4650 		if (op1)
4651 		{
4652 			if (op1_is_literal)
4653 				exp += join(", ", op1);
4654 			else
4655 				exp += ", " + to_expression(op1);
4656 		}
4657 		if (op2)
4658 			exp += ", " + to_expression(op2);
4659 
4660 		exp += string(", ") + get_memory_order(mem_order_1);
4661 		if (has_mem_order_2)
4662 			exp += string(", ") + get_memory_order(mem_order_2);
4663 
4664 		exp += ")";
4665 		emit_op(result_type, result_id, exp, false);
4666 	}
4667 
4668 	flush_all_atomic_capable_variables();
4669 }
4670 
4671 // Metal only supports relaxed memory order for now
4672 const char *CompilerMSL::get_memory_order(uint32_t)
4673 {
4674 	return "memory_order_relaxed";
4675 }
4676 
4677 // Override for MSL-specific extension syntax instructions
4678 void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
4679 {
4680 	auto op = static_cast<GLSLstd450>(eop);
4681 
4682 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
4683 	uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
4684 	auto int_type = to_signed_basetype(integer_width);
4685 	auto uint_type = to_unsigned_basetype(integer_width);
4686 
4687 	switch (op)
4688 	{
4689 	case GLSLstd450Atan2:
4690 		emit_binary_func_op(result_type, id, args[0], args[1], "atan2");
4691 		break;
4692 	case GLSLstd450InverseSqrt:
4693 		emit_unary_func_op(result_type, id, args[0], "rsqrt");
4694 		break;
4695 	case GLSLstd450RoundEven:
4696 		emit_unary_func_op(result_type, id, args[0], "rint");
4697 		break;
4698 
4699 	case GLSLstd450FindSMsb:
4700 		emit_unary_func_op_cast(result_type, id, args[0], "findSMSB", int_type, int_type);
4701 		break;
4702 
4703 	case GLSLstd450FindUMsb:
4704 		emit_unary_func_op_cast(result_type, id, args[0], "findUMSB", uint_type, uint_type);
4705 		break;
4706 
4707 	case GLSLstd450PackSnorm4x8:
4708 		emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm4x8");
4709 		break;
4710 	case GLSLstd450PackUnorm4x8:
4711 		emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm4x8");
4712 		break;
4713 	case GLSLstd450PackSnorm2x16:
4714 		emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm2x16");
4715 		break;
4716 	case GLSLstd450PackUnorm2x16:
4717 		emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm2x16");
4718 		break;
4719 
4720 	case GLSLstd450PackHalf2x16:
4721 	{
4722 		auto expr = join("as_type<uint>(half2(", to_expression(args[0]), "))");
4723 		emit_op(result_type, id, expr, should_forward(args[0]));
4724 		inherit_expression_dependencies(id, args[0]);
4725 		break;
4726 	}
4727 
4728 	case GLSLstd450UnpackSnorm4x8:
4729 		emit_unary_func_op(result_type, id, args[0], "unpack_snorm4x8_to_float");
4730 		break;
4731 	case GLSLstd450UnpackUnorm4x8:
4732 		emit_unary_func_op(result_type, id, args[0], "unpack_unorm4x8_to_float");
4733 		break;
4734 	case GLSLstd450UnpackSnorm2x16:
4735 		emit_unary_func_op(result_type, id, args[0], "unpack_snorm2x16_to_float");
4736 		break;
4737 	case GLSLstd450UnpackUnorm2x16:
4738 		emit_unary_func_op(result_type, id, args[0], "unpack_unorm2x16_to_float");
4739 		break;
4740 
4741 	case GLSLstd450UnpackHalf2x16:
4742 	{
4743 		auto expr = join("float2(as_type<half2>(", to_expression(args[0]), "))");
4744 		emit_op(result_type, id, expr, should_forward(args[0]));
4745 		inherit_expression_dependencies(id, args[0]);
4746 		break;
4747 	}
4748 
4749 	case GLSLstd450PackDouble2x32:
4750 		emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450PackDouble2x32"); // Currently unsupported
4751 		break;
4752 	case GLSLstd450UnpackDouble2x32:
4753 		emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450UnpackDouble2x32"); // Currently unsupported
4754 		break;
4755 
4756 	case GLSLstd450MatrixInverse:
4757 	{
4758 		auto &mat_type = get<SPIRType>(result_type);
4759 		switch (mat_type.columns)
4760 		{
4761 		case 2:
4762 			emit_unary_func_op(result_type, id, args[0], "spvInverse2x2");
4763 			break;
4764 		case 3:
4765 			emit_unary_func_op(result_type, id, args[0], "spvInverse3x3");
4766 			break;
4767 		case 4:
4768 			emit_unary_func_op(result_type, id, args[0], "spvInverse4x4");
4769 			break;
4770 		default:
4771 			break;
4772 		}
4773 		break;
4774 	}
4775 
4776 	case GLSLstd450FMin:
4777 		// If the result type isn't float, don't bother calling the specific
4778 		// precise::/fast:: version. Metal doesn't have those for half and
4779 		// double types.
4780 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
4781 			emit_binary_func_op(result_type, id, args[0], args[1], "min");
4782 		else
4783 			emit_binary_func_op(result_type, id, args[0], args[1], "fast::min");
4784 		break;
4785 
4786 	case GLSLstd450FMax:
4787 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
4788 			emit_binary_func_op(result_type, id, args[0], args[1], "max");
4789 		else
4790 			emit_binary_func_op(result_type, id, args[0], args[1], "fast::max");
4791 		break;
4792 
4793 	case GLSLstd450FClamp:
4794 		// TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
4795 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
4796 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
4797 		else
4798 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "fast::clamp");
4799 		break;
4800 
4801 	case GLSLstd450NMin:
4802 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
4803 			emit_binary_func_op(result_type, id, args[0], args[1], "min");
4804 		else
4805 			emit_binary_func_op(result_type, id, args[0], args[1], "precise::min");
4806 		break;
4807 
4808 	case GLSLstd450NMax:
4809 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
4810 			emit_binary_func_op(result_type, id, args[0], args[1], "max");
4811 		else
4812 			emit_binary_func_op(result_type, id, args[0], args[1], "precise::max");
4813 		break;
4814 
4815 	case GLSLstd450NClamp:
4816 		// TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
4817 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
4818 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
4819 		else
4820 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "precise::clamp");
4821 		break;
4822 
4823 		// TODO:
4824 		//        GLSLstd450InterpolateAtCentroid (centroid_no_perspective qualifier)
4825 		//        GLSLstd450InterpolateAtSample (sample_no_perspective qualifier)
4826 		//        GLSLstd450InterpolateAtOffset
4827 
4828 	case GLSLstd450Distance:
4829 		// MSL does not support scalar versions here.
4830 		if (expression_type(args[0]).vecsize == 1)
4831 		{
4832 			// Equivalent to length(a - b) -> abs(a - b).
4833 			emit_op(result_type, id,
4834 			        join("abs(", to_unpacked_expression(args[0]), " - ", to_unpacked_expression(args[1]), ")"),
4835 			        should_forward(args[0]) && should_forward(args[1]));
4836 			inherit_expression_dependencies(id, args[0]);
4837 			inherit_expression_dependencies(id, args[1]);
4838 		}
4839 		else
4840 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
4841 		break;
4842 
4843 	case GLSLstd450Length:
4844 		// MSL does not support scalar versions here.
4845 		if (expression_type(args[0]).vecsize == 1)
4846 		{
4847 			// Equivalent to abs().
4848 			emit_unary_func_op(result_type, id, args[0], "abs");
4849 		}
4850 		else
4851 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
4852 		break;
4853 
4854 	case GLSLstd450Normalize:
4855 		// MSL does not support scalar versions here.
4856 		if (expression_type(args[0]).vecsize == 1)
4857 		{
4858 			// Returns -1 or 1 for valid input, sign() does the job.
4859 			emit_unary_func_op(result_type, id, args[0], "sign");
4860 		}
4861 		else
4862 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
4863 		break;
4864 
4865 	case GLSLstd450Reflect:
4866 		if (get<SPIRType>(result_type).vecsize == 1)
4867 			emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
4868 		else
4869 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
4870 		break;
4871 
4872 	case GLSLstd450Refract:
4873 		if (get<SPIRType>(result_type).vecsize == 1)
4874 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
4875 		else
4876 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
4877 		break;
4878 
4879 	default:
4880 		CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
4881 		break;
4882 	}
4883 }
4884 
4885 // Emit a structure declaration for the specified interface variable.
4886 void CompilerMSL::emit_interface_block(uint32_t ib_var_id)
4887 {
4888 	if (ib_var_id)
4889 	{
4890 		auto &ib_var = get<SPIRVariable>(ib_var_id);
4891 		auto &ib_type = get_variable_data_type(ib_var);
4892 		assert(ib_type.basetype == SPIRType::Struct && !ib_type.member_types.empty());
4893 		emit_struct(ib_type);
4894 	}
4895 }
4896 
4897 // Emits the declaration signature of the specified function.
4898 // If this is the entry point function, Metal-specific return value and function arguments are added.
4899 void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &)
4900 {
4901 	if (func.self != ir.default_entry_point)
4902 		add_function_overload(func);
4903 
4904 	local_variable_names = resource_names;
4905 	string decl;
4906 
4907 	processing_entry_point = (func.self == ir.default_entry_point);
4908 
4909 	auto &type = get<SPIRType>(func.return_type);
4910 
4911 	if (type.array.empty())
4912 	{
4913 		decl += func_type_decl(type);
4914 	}
4915 	else
4916 	{
4917 		// We cannot return arrays in MSL, so "return" through an out variable.
4918 		decl = "void";
4919 	}
4920 
4921 	decl += " ";
4922 	decl += to_name(func.self);
4923 	decl += "(";
4924 
4925 	if (!type.array.empty())
4926 	{
4927 		// Fake array returns by writing to an out array instead.
4928 		decl += "thread ";
4929 		decl += type_to_glsl(type);
4930 		decl += " (&SPIRV_Cross_return_value)";
4931 		decl += type_to_array_glsl(type);
4932 		if (!func.arguments.empty())
4933 			decl += ", ";
4934 	}
4935 
4936 	if (processing_entry_point)
4937 	{
4938 		if (msl_options.argument_buffers)
4939 			decl += entry_point_args_argument_buffer(!func.arguments.empty());
4940 		else
4941 			decl += entry_point_args_classic(!func.arguments.empty());
4942 
4943 		// If entry point function has variables that require early declaration,
4944 		// ensure they each have an empty initializer, creating one if needed.
4945 		// This is done at this late stage because the initialization expression
4946 		// is cleared after each compilation pass.
4947 		for (auto var_id : vars_needing_early_declaration)
4948 		{
4949 			auto &ed_var = get<SPIRVariable>(var_id);
4950 			uint32_t &initializer = ed_var.initializer;
4951 			if (!initializer)
4952 				initializer = ir.increase_bound_by(1);
4953 
4954 			// Do not override proper initializers.
4955 			if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression)
4956 				set<SPIRExpression>(ed_var.initializer, "{}", ed_var.basetype, true);
4957 		}
4958 	}
4959 
4960 	for (auto &arg : func.arguments)
4961 	{
4962 		uint32_t name_id = arg.id;
4963 
4964 		auto *var = maybe_get<SPIRVariable>(arg.id);
4965 		if (var)
4966 		{
4967 			// If we need to modify the name of the variable, make sure we modify the original variable.
4968 			// Our alias is just a shadow variable.
4969 			if (arg.alias_global_variable && var->basevariable)
4970 				name_id = var->basevariable;
4971 
4972 			var->parameter = &arg; // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
4973 		}
4974 
4975 		add_local_variable_name(name_id);
4976 
4977 		decl += argument_decl(arg);
4978 
4979 		// Manufacture automatic sampler arg for SampledImage texture
4980 		auto &arg_type = get<SPIRType>(arg.type);
4981 		if (arg_type.basetype == SPIRType::SampledImage && arg_type.image.dim != DimBuffer)
4982 			decl += join(", thread const ", sampler_type(arg_type), " ", to_sampler_expression(arg.id));
4983 
4984 		// Manufacture automatic swizzle arg.
4985 		if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(arg_type))
4986 		{
4987 			bool arg_is_array = !arg_type.array.empty();
4988 			decl += join(", constant uint", arg_is_array ? "* " : "& ", to_swizzle_expression(arg.id));
4989 		}
4990 
4991 		if (buffers_requiring_array_length.count(name_id))
4992 		{
4993 			bool arg_is_array = !arg_type.array.empty();
4994 			decl += join(", constant uint", arg_is_array ? "* " : "& ", to_buffer_size_expression(name_id));
4995 		}
4996 
4997 		if (&arg != &func.arguments.back())
4998 			decl += ", ";
4999 	}
5000 
5001 	decl += ")";
5002 	statement(decl);
5003 }
5004 
5005 // Returns the texture sampling function string for the specified image and sampling characteristics.
5006 string CompilerMSL::to_function_name(uint32_t img, const SPIRType &imgtype, bool is_fetch, bool is_gather, bool, bool,
5007                                      bool has_offset, bool, bool has_dref, uint32_t, uint32_t)
5008 {
5009 	// Special-case gather. We have to alter the component being looked up
5010 	// in the swizzle case.
5011 	if (msl_options.swizzle_texture_samples && is_gather)
5012 	{
5013 		string fname = imgtype.image.depth ? "spvGatherCompareSwizzle" : "spvGatherSwizzle";
5014 		fname += "<" + type_to_glsl(get<SPIRType>(imgtype.image.type)) + ", metal::" + type_to_glsl(imgtype);
5015 		// Add the arg types ourselves. Yes, this sucks, but Clang can't
5016 		// deduce template pack parameters in the middle of an argument list.
5017 		switch (imgtype.image.dim)
5018 		{
5019 		case Dim2D:
5020 			fname += ", float2";
5021 			if (imgtype.image.arrayed)
5022 				fname += ", uint";
5023 			if (imgtype.image.depth)
5024 				fname += ", float";
5025 			if (!imgtype.image.depth || has_offset)
5026 				fname += ", int2";
5027 			break;
5028 		case DimCube:
5029 			fname += ", float3";
5030 			if (imgtype.image.arrayed)
5031 				fname += ", uint";
5032 			if (imgtype.image.depth)
5033 				fname += ", float";
5034 			break;
5035 		default:
5036 			SPIRV_CROSS_THROW("Invalid texture dimension for gather op.");
5037 		}
5038 		fname += ">";
5039 		return fname;
5040 	}
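	// E.g. gathering from a non-depth texture2d<float> yields a name like
	// spvGatherSwizzle<float, metal::texture2d<float>, float2, int2> (a sketch).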
5041 
5042 	auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
5043 
5044 	// Texture reference
5045 	string fname = to_expression(combined ? combined->image : img) + ".";
5046 	if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype))
5047 		fname = "spvTextureSwizzle(" + fname;
5048 
5049 	// Texture function and sampler
5050 	if (is_fetch)
5051 		fname += "read";
5052 	else if (is_gather)
5053 		fname += "gather";
5054 	else
5055 		fname += "sample";
5056 
5057 	if (has_dref)
5058 		fname += "_compare";
5059 
5060 	return fname;
5061 }
5062 
5063 string CompilerMSL::convert_to_f32(const string &expr, uint32_t components)
5064 {
5065 	SPIRType t;
5066 	t.basetype = SPIRType::Float;
5067 	t.vecsize = components;
5068 	t.columns = 1;
5069 	return join(type_to_glsl_constructor(t), "(", expr, ")");
5070 }
5071 
5072 static inline bool sampling_type_needs_f32_conversion(const SPIRType &type)
5073 {
5074 	// Double is not supported to begin with, but it doesn't hurt to check for completeness.
5075 	return type.basetype == SPIRType::Half || type.basetype == SPIRType::Double;
5076 }
5077 
5078 // Returns the function args for a texture sampling function for the specified image and sampling characteristics.
5079 string CompilerMSL::to_function_args(uint32_t img, const SPIRType &imgtype, bool is_fetch, bool is_gather, bool is_proj,
5080                                      uint32_t coord, uint32_t, uint32_t dref, uint32_t grad_x, uint32_t grad_y,
5081                                      uint32_t lod, uint32_t coffset, uint32_t offset, uint32_t bias, uint32_t comp,
5082                                      uint32_t sample, uint32_t minlod, bool *p_forward)
5083 {
5084 	string farg_str;
5085 	if (!is_fetch)
5086 		farg_str += to_sampler_expression(img);
5087 
5088 	if (msl_options.swizzle_texture_samples && is_gather)
5089 	{
5090 		if (!farg_str.empty())
5091 			farg_str += ", ";
5092 
5093 		auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
5094 		farg_str += to_expression(combined ? combined->image : img);
5095 	}
5096 
5097 	// Texture coordinates
5098 	bool forward = should_forward(coord);
5099 	auto coord_expr = to_enclosed_expression(coord);
5100 	auto &coord_type = expression_type(coord);
5101 	bool coord_is_fp = type_is_floating_point(coord_type);
5102 	bool is_cube_fetch = false;
5103 
5104 	string tex_coords = coord_expr;
5105 	uint32_t alt_coord_component = 0;
5106 
5107 	switch (imgtype.image.dim)
5108 	{
5109 
5110 	case Dim1D:
5111 		if (coord_type.vecsize > 1)
5112 			tex_coords = enclose_expression(tex_coords) + ".x";
5113 
5114 		if (is_fetch)
5115 			tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
5116 		else if (sampling_type_needs_f32_conversion(coord_type))
5117 			tex_coords = convert_to_f32(tex_coords, 1);
5118 
5119 		alt_coord_component = 1;
5120 		break;
5121 
5122 	case DimBuffer:
5123 		if (coord_type.vecsize > 1)
5124 			tex_coords = enclose_expression(tex_coords) + ".x";
5125 
5126 		if (msl_options.texture_buffer_native)
5127 		{
5128 			tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
5129 		}
5130 		else
5131 		{
5132 			// Metal texel buffer textures are 2D, so convert 1D coord to 2D.
5133 			if (is_fetch)
5134 				tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
5135 		}
5136 
5137 		alt_coord_component = 1;
5138 		break;
5139 
5140 	case DimSubpassData:
5141 		if (imgtype.image.ms)
5142 			tex_coords = "uint2(gl_FragCoord.xy)";
5143 		else
5144 			tex_coords = join("uint2(gl_FragCoord.xy), 0");
5145 		break;
5146 
5147 	case Dim2D:
5148 		if (coord_type.vecsize > 2)
5149 			tex_coords = enclose_expression(tex_coords) + ".xy";
5150 
5151 		if (is_fetch)
5152 			tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
5153 		else if (sampling_type_needs_f32_conversion(coord_type))
5154 			tex_coords = convert_to_f32(tex_coords, 2);
5155 
5156 		alt_coord_component = 2;
5157 		break;
5158 
5159 	case Dim3D:
5160 		if (coord_type.vecsize > 3)
5161 			tex_coords = enclose_expression(tex_coords) + ".xyz";
5162 
5163 		if (is_fetch)
5164 			tex_coords = "uint3(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
5165 		else if (sampling_type_needs_f32_conversion(coord_type))
5166 			tex_coords = convert_to_f32(tex_coords, 3);
5167 
5168 		alt_coord_component = 3;
5169 		break;
5170 
5171 	case DimCube:
5172 		if (is_fetch)
5173 		{
5174 			is_cube_fetch = true;
5175 			tex_coords += ".xy";
5176 			tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
5177 		}
5178 		else
5179 		{
5180 			if (coord_type.vecsize > 3)
5181 				tex_coords = enclose_expression(tex_coords) + ".xyz";
5182 		}
5183 
5184 		if (sampling_type_needs_f32_conversion(coord_type))
5185 			tex_coords = convert_to_f32(tex_coords, 3);
5186 
5187 		alt_coord_component = 3;
5188 		break;
5189 
5190 	default:
5191 		break;
5192 	}
5193 
5194 	if (is_fetch && offset)
5195 	{
5196 		// Fetch offsets must be applied directly to the coordinate.
5197 		forward = forward && should_forward(offset);
5198 		auto &type = expression_type(offset);
5199 		if (type.basetype != SPIRType::UInt)
5200 			tex_coords += " + " + bitcast_expression(SPIRType::UInt, offset);
5201 		else
5202 			tex_coords += " + " + to_enclosed_expression(offset);
5203 	}
5204 	else if (is_fetch && coffset)
5205 	{
5206 		// Fetch offsets must be applied directly to the coordinate.
5207 		forward = forward && should_forward(coffset);
5208 		auto &type = expression_type(coffset);
5209 		if (type.basetype != SPIRType::UInt)
5210 			tex_coords += " + " + bitcast_expression(SPIRType::UInt, coffset);
5211 		else
5212 			tex_coords += " + " + to_enclosed_expression(coffset);
5213 	}
5214 
5215 	// If projection, use alt coord as divisor
5216 	if (is_proj)
5217 	{
5218 		if (sampling_type_needs_f32_conversion(coord_type))
5219 			tex_coords += " / " + convert_to_f32(to_extract_component_expression(coord, alt_coord_component), 1);
5220 		else
5221 			tex_coords += " / " + to_extract_component_expression(coord, alt_coord_component);
5222 	}
5223 
5224 	if (!farg_str.empty())
5225 		farg_str += ", ";
5226 	farg_str += tex_coords;
5227 
5228 	// If fetch from cube, add face explicitly
5229 	if (is_cube_fetch)
5230 	{
5231 		// Special case for cube arrays, face and layer are packed in one dimension.
5232 		if (imgtype.image.arrayed)
5233 			farg_str += ", uint(" + to_extract_component_expression(coord, 2) + ") % 6u";
5234 		else
5235 			farg_str += ", uint(" + round_fp_tex_coords(to_extract_component_expression(coord, 2), coord_is_fp) + ")";
5236 	}
5237 
5238 	// If array, use alt coord
5239 	if (imgtype.image.arrayed)
5240 	{
5241 		// Special case for cube arrays, face and layer are packed in one dimension.
5242 		if (imgtype.image.dim == DimCube && is_fetch)
5243 			farg_str += ", uint(" + to_extract_component_expression(coord, 2) + ") / 6u";
5244 		else
5245 			farg_str += ", uint(" +
5246 			            round_fp_tex_coords(to_extract_component_expression(coord, alt_coord_component), coord_is_fp) +
5247 			            ")";
5248 	}
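	// Sketch for a fetch from a cube array: the face index is
	// uint(coord.z) % 6u and the layer is uint(coord.z) / 6u, since the packed
	// coordinate component carries both.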
5249 
5250 	// Depth compare reference value
5251 	if (dref)
5252 	{
5253 		forward = forward && should_forward(dref);
5254 		farg_str += ", ";
5255 
5256 		auto &dref_type = expression_type(dref);
5257 
5258 		string dref_expr;
5259 		if (is_proj)
5260 			dref_expr =
5261 			    join(to_enclosed_expression(dref), " / ", to_extract_component_expression(coord, alt_coord_component));
5262 		else
5263 			dref_expr = to_expression(dref);
5264 
5265 		if (sampling_type_needs_f32_conversion(dref_type))
5266 			dref_expr = convert_to_f32(dref_expr, 1);
5267 
5268 		farg_str += dref_expr;
5269 
5270 		if (msl_options.is_macos() && (grad_x || grad_y))
5271 		{
5272 			// For sample compare, MSL does not support gradient2d on all targets (only iOS, according to the docs).
5273 			// However, the most common case here is to have a constant gradient of 0, as that is the only way to express
5274 			// LOD == 0 in GLSL with sampler2DArrayShadow (cascaded shadow mapping).
5275 			// We will detect a compile-time constant 0 value for gradient and promote that to level(0) on MSL.
5276 			bool constant_zero_x = !grad_x || expression_is_constant_null(grad_x);
5277 			bool constant_zero_y = !grad_y || expression_is_constant_null(grad_y);
5278 			if (constant_zero_x && constant_zero_y)
5279 			{
5280 				lod = 0;
5281 				grad_x = 0;
5282 				grad_y = 0;
5283 				farg_str += ", level(0)";
5284 			}
5285 			else
5286 			{
5287 				SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not "
5288 				                  "supported in MSL macOS.");
5289 			}
5290 		}
5291 
5292 		if (msl_options.is_macos() && bias)
5293 		{
5294 			// Bias is not supported either on macOS with sample_compare.
5295 			// Verify it is compile-time zero, and drop the argument.
5296 			if (expression_is_constant_null(bias))
5297 			{
5298 				bias = 0;
5299 			}
5300 			else
5301 			{
5302 				SPIRV_CROSS_THROW(
5303 				    "Using non-constant 0.0 bias() qualifier for sample_compare. This is not supported in MSL macOS.");
5304 			}
5305 		}
5306 	}
5307 
5308 	// LOD Options
5309 	// Metal does not support LOD for 1D textures.
5310 	if (bias && imgtype.image.dim != Dim1D)
5311 	{
5312 		forward = forward && should_forward(bias);
5313 		farg_str += ", bias(" + to_expression(bias) + ")";
5314 	}
5315 
5316 	// Metal does not support LOD for 1D textures.
5317 	if (lod && imgtype.image.dim != Dim1D)
5318 	{
5319 		forward = forward && should_forward(lod);
5320 		if (is_fetch)
5321 		{
5322 			farg_str += ", " + to_expression(lod);
5323 		}
5324 		else
5325 		{
5326 			farg_str += ", level(" + to_expression(lod) + ")";
5327 		}
5328 	}
5329 	else if (is_fetch && !lod && imgtype.image.dim != Dim1D && imgtype.image.dim != DimBuffer && !imgtype.image.ms &&
5330 	         imgtype.image.sampled != 2)
5331 	{
5332 		// Lod argument is optional in OpImageFetch, but we require a LOD value, pick 0 as the default.
5333 		// Check for sampled type as well, because is_fetch is also used for OpImageRead in MSL.
5334 		farg_str += ", 0";
5335 	}
5336 
5337 	// Metal does not support LOD for 1D textures.
5338 	if ((grad_x || grad_y) && imgtype.image.dim != Dim1D)
5339 	{
5340 		forward = forward && should_forward(grad_x);
5341 		forward = forward && should_forward(grad_y);
5342 		string grad_opt;
5343 		switch (imgtype.image.dim)
5344 		{
5345 		case Dim2D:
5346 			grad_opt = "2d";
5347 			break;
5348 		case Dim3D:
5349 			grad_opt = "3d";
5350 			break;
5351 		case DimCube:
5352 			grad_opt = "cube";
5353 			break;
5354 		default:
5355 			grad_opt = "unsupported_gradient_dimension";
5356 			break;
5357 		}
5358 		farg_str += ", gradient" + grad_opt + "(" + to_expression(grad_x) + ", " + to_expression(grad_y) + ")";
5359 	}
5360 
5361 	if (minlod)
5362 	{
5363 		if (msl_options.is_macos())
5364 		{
5365 			if (!msl_options.supports_msl_version(2, 2))
5366 				SPIRV_CROSS_THROW("min_lod_clamp() is only supported on MSL 2.2 and up on macOS.");
5367 		}
5368 		else if (msl_options.is_ios())
5369 			SPIRV_CROSS_THROW("min_lod_clamp() is not supported on iOS.");
5370 
5371 		forward = forward && should_forward(minlod);
5372 		farg_str += ", min_lod_clamp(" + to_expression(minlod) + ")";
5373 	}
5374 
5375 	// Add offsets
5376 	string offset_expr;
5377 	if (coffset && !is_fetch)
5378 	{
5379 		forward = forward && should_forward(coffset);
5380 		offset_expr = to_expression(coffset);
5381 	}
5382 	else if (offset && !is_fetch)
5383 	{
5384 		forward = forward && should_forward(offset);
5385 		offset_expr = to_expression(offset);
5386 	}
5387 
5388 	if (!offset_expr.empty())
5389 	{
5390 		switch (imgtype.image.dim)
5391 		{
5392 		case Dim2D:
5393 			if (coord_type.vecsize > 2)
5394 				offset_expr = enclose_expression(offset_expr) + ".xy";
5395 
5396 			farg_str += ", " + offset_expr;
5397 			break;
5398 
5399 		case Dim3D:
5400 			if (coord_type.vecsize > 3)
5401 				offset_expr = enclose_expression(offset_expr) + ".xyz";
5402 
5403 			farg_str += ", " + offset_expr;
5404 			break;
5405 
5406 		default:
5407 			break;
5408 		}
5409 	}
5410 
5411 	if (comp)
5412 	{
5413 		// If 2D has gather component, ensure it also has an offset arg
5414 		if (imgtype.image.dim == Dim2D && offset_expr.empty())
5415 			farg_str += ", int2(0)";
5416 
5417 		forward = forward && should_forward(comp);
5418 		farg_str += ", " + to_component_argument(comp);
5419 	}
5420 
5421 	if (sample)
5422 	{
5423 		forward = forward && should_forward(sample);
5424 		farg_str += ", ";
5425 		farg_str += to_expression(sample);
5426 	}
5427 
5428 	if (msl_options.swizzle_texture_samples && is_sampled_image_type(imgtype))
5429 	{
5430 		// Add the swizzle constant from the swizzle buffer.
5431 		if (!is_gather)
5432 			farg_str += ")";
5433 		farg_str += ", " + to_swizzle_expression(img);
5434 		used_swizzle_buffer = true;
5435 	}
5436 
5437 	*p_forward = forward;
5438 
5439 	return farg_str;
5440 }
5441 
5442 // If the texture coordinates are floating point, invokes MSL round() function to round them.
5443 string CompilerMSL::round_fp_tex_coords(string tex_coords, bool coord_is_fp)
5444 {
5445 	return coord_is_fp ? ("round(" + tex_coords + ")") : tex_coords;
5446 }
5447 
5448 // Returns a string to use in an image sampling function argument.
5449 // The ID must be a scalar constant.
5450 string CompilerMSL::to_component_argument(uint32_t id)
5451 {
5452 	if (ir.ids[id].get_type() != TypeConstant)
5453 	{
5454 		SPIRV_CROSS_THROW("ID " + to_string(id) + " is not an OpConstant.");
5455 		return "component::x";
5456 	}
5457 
5458 	uint32_t component_index = get<SPIRConstant>(id).scalar();
5459 	switch (component_index)
5460 	{
5461 	case 0:
5462 		return "component::x";
5463 	case 1:
5464 		return "component::y";
5465 	case 2:
5466 		return "component::z";
5467 	case 3:
5468 		return "component::w";
5469 
5470 	default:
5471 		SPIRV_CROSS_THROW("The value (" + to_string(component_index) + ") of OpConstant ID " + to_string(id) +
5472 		                  " is not a valid Component index, which must be one of 0, 1, 2, or 3.");
5473 		return "component::x";
5474 	}
5475 }
5476 
5477 // Establish sampled image as expression object and assign the sampler to it.
5478 void CompilerMSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
5479 {
5480 	set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
5481 }
5482 
5483 // Returns a string representation of the ID, usable as a function arg.
5484 // Manufacture automatic sampler arg for SampledImage texture.
5485 string CompilerMSL::to_func_call_arg(uint32_t id)
5486 {
5487 	string arg_str;
5488 
5489 	auto *c = maybe_get<SPIRConstant>(id);
5490 	if (c && !get<SPIRType>(c->constant_type).array.empty())
5491 	{
5492 		// If we are passing a constant array directly to a function for some reason,
5493 		// the callee will expect an argument in thread const address space
5494 		// (since we can only bind to arrays with references in MSL).
5495 		// To resolve this, we must emit a copy in this address space.
5496 		// This kind of code gen should be rare enough that performance is not a real concern.
5497 		// Inline the SPIR-V to avoid this kind of suboptimal codegen.
5498 		//
5499 		// We risk calling this inside a continue block (invalid code),
5500 		// so just create a thread local copy in the current function.
5501 		arg_str = join("_", id, "_array_copy");
5502 		auto &constants = current_function->constant_arrays_needed_on_stack;
5503 		auto itr = find(begin(constants), end(constants), id);
5504 		if (itr == end(constants))
5505 		{
5506 			force_recompile();
5507 			constants.push_back(id);
5508 		}
5509 	}
5510 	else
5511 		arg_str = CompilerGLSL::to_func_call_arg(id);
5512 
5513 	// Manufacture automatic sampler arg if the arg is a SampledImage texture.
5514 	auto &type = expression_type(id);
5515 	if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
5516 	{
5517 		// Need to check the base variable in case we need to apply a qualified alias.
5518 		uint32_t var_id = 0;
5519 		auto *sampler_var = maybe_get<SPIRVariable>(id);
5520 		if (sampler_var)
5521 			var_id = sampler_var->basevariable;
5522 
5523 		arg_str += ", " + to_sampler_expression(var_id ? var_id : id);
5524 	}
5525 
5526 	uint32_t var_id = 0;
5527 	auto *var = maybe_get<SPIRVariable>(id);
5528 	if (var)
5529 		var_id = var->basevariable;
5530 
5531 	if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
5532 	{
5533 		// Need to check the base variable in case we need to apply a qualified alias.
5534 		arg_str += ", " + to_swizzle_expression(var_id ? var_id : id);
5535 	}
5536 
5537 	if (buffers_requiring_array_length.count(var_id))
5538 		arg_str += ", " + to_buffer_size_expression(var_id ? var_id : id);
5539 
5540 	return arg_str;
5541 }
5542 
5543 // If the ID represents a sampled image that has been assigned a sampler already,
5544 // generate an expression for the sampler, otherwise generate a fake sampler name
5545 // by appending a suffix to the expression constructed from the ID.
5546 string CompilerMSL::to_sampler_expression(uint32_t id)
5547 {
5548 	auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
5549 	auto expr = to_expression(combined ? combined->image : id);
5550 	auto index = expr.find_first_of('[');
5551 
5552 	uint32_t samp_id = 0;
5553 	if (combined)
5554 		samp_id = combined->sampler;
5555 
5556 	if (index == string::npos)
5557 		return samp_id ? to_expression(samp_id) : expr + sampler_name_suffix;
5558 	else
5559 	{
5560 		auto image_expr = expr.substr(0, index);
5561 		auto array_expr = expr.substr(index);
5562 		return samp_id ? to_expression(samp_id) : (image_expr + sampler_name_suffix + array_expr);
5563 	}
5564 }
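// For example, a texture expression 'uTex' with no assigned sampler yields
// 'uTexSmplr' (assuming the default sampler_name_suffix of "Smplr").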
5565 
5566 string CompilerMSL::to_swizzle_expression(uint32_t id)
5567 {
5568 	auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
5569 
5570 	auto expr = to_expression(combined ? combined->image : id);
5571 	auto index = expr.find_first_of('[');
5572 
5573 	// If an image is part of an argument buffer translate this to a legal identifier.
5574 	for (auto &c : expr)
5575 		if (c == '.')
5576 			c = '_';
5577 
5578 	if (index == string::npos)
5579 		return expr + swizzle_name_suffix;
5580 	else
5581 	{
5582 		auto image_expr = expr.substr(0, index);
5583 		auto array_expr = expr.substr(index);
5584 		return image_expr + swizzle_name_suffix + array_expr;
5585 	}
5586 }
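
// Illustrative results of the above (assuming the default swizzle_name_suffix,
// "Swzl"; identifier names are hypothetical):
//   "tex"                   -> "texSwzl"
//   "tex[i]"                -> "texSwzl[i]"
//   "spvDescriptorSet0.tex" -> "spvDescriptorSet0_texSwzl" (argument buffer member)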
5587 
5588 string CompilerMSL::to_buffer_size_expression(uint32_t id)
5589 {
5590 	auto expr = to_expression(id);
5591 	auto index = expr.find_first_of('[');
5592 
5593 	// This is quite crude, but we need to translate the reference name (*spvDescriptorSetN.name) to
5594 	// the pointer expression spvDescriptorSetN.name to make a reasonable expression here.
5595 	// This only happens if we have argument buffers and we are using OpArrayLength on a lone SSBO in that set.
5596 	if (expr.size() >= 3 && expr[0] == '(' && expr[1] == '*')
5597 		expr = address_of_expression(expr);
5598 
5599 	// If a buffer is part of an argument buffer, translate this to a legal identifier.
5600 	for (auto &c : expr)
5601 		if (c == '.')
5602 			c = '_';
5603 
5604 	if (index == string::npos)
5605 		return expr + buffer_size_name_suffix;
5606 	else
5607 	{
5608 		auto buffer_expr = expr.substr(0, index);
5609 		auto array_expr = expr.substr(index);
5610 		return buffer_expr + buffer_size_name_suffix + array_expr;
5611 	}
5612 }
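
// Illustrative results of the above (assuming the default
// buffer_size_name_suffix, "BufferSize"; identifier names are hypothetical):
//   "ssbo"                      -> "ssboBufferSize"
//   "(*spvDescriptorSet0.ssbo)" -> "spvDescriptorSet0_ssboBufferSize"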
5613 
5614 // Checks whether the type is a Block all of whose members have DecorationPatch.
5615 bool CompilerMSL::is_patch_block(const SPIRType &type)
5616 {
5617 	if (!has_decoration(type.self, DecorationBlock))
5618 		return false;
5619 
5620 	for (uint32_t i = 0; i < type.member_types.size(); i++)
5621 	{
5622 		if (!has_member_decoration(type.self, i, DecorationPatch))
5623 			return false;
5624 	}
5625 
5626 	return true;
5627 }
5628 
5629 // Checks whether the ID is a row_major matrix that requires conversion before use
5630 bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id)
5631 {
5632 	// Natively supported row-major matrices do not need to be converted.
5633 	if (backend.native_row_major_matrix)
5634 		return false;
5635 
5636 	// Non-matrix or column-major matrix types do not need to be converted.
5637 	if (!has_decoration(id, DecorationRowMajor))
5638 		return false;
5639 
5640 	// Generate a function that will swap matrix elements from row-major to column-major.
5641 	// A packed row-major matrix should just use the transpose() function.
5642 	if (!has_extended_decoration(id, SPIRVCrossDecorationPacked))
5643 	{
5644 		const auto type = expression_type(id);
5645 		add_convert_row_major_matrix_function(type.columns, type.vecsize);
5646 	}
5647 
5648 	return true;
5649 }
5650 
5651 // Checks whether the member is a row_major matrix that requires conversion before use
5652 bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index)
5653 {
5654 	// Natively supported row-major matrices do not need to be converted.
5655 	if (backend.native_row_major_matrix)
5656 		return false;
5657 
5658 	// Non-matrix or column-major matrix types do not need to be converted.
5659 	if (!has_member_decoration(type.self, index, DecorationRowMajor))
5660 		return false;
5661 
5662 	// Generate a function that will swap matrix elements from row-major to column-major.
5663 	// A packed row-major matrix should just use the transpose() function.
5664 	if (!has_extended_member_decoration(type.self, index, SPIRVCrossDecorationPacked))
5665 	{
5666 		const auto mbr_type = get<SPIRType>(type.member_types[index]);
5667 		add_convert_row_major_matrix_function(mbr_type.columns, mbr_type.vecsize);
5668 	}
5669 
5670 	return true;
5671 }
5672 
5673 // Adds a function suitable for converting a non-square row-major matrix to a column-major matrix.
5674 void CompilerMSL::add_convert_row_major_matrix_function(uint32_t cols, uint32_t rows)
5675 {
5676 	SPVFuncImpl spv_func;
5677 	if (cols == rows) // Square matrix...just use transpose() function
5678 		return;
5679 	else if (cols == 2 && rows == 3)
5680 		spv_func = SPVFuncImplRowMajor2x3;
5681 	else if (cols == 2 && rows == 4)
5682 		spv_func = SPVFuncImplRowMajor2x4;
5683 	else if (cols == 3 && rows == 2)
5684 		spv_func = SPVFuncImplRowMajor3x2;
5685 	else if (cols == 3 && rows == 4)
5686 		spv_func = SPVFuncImplRowMajor3x4;
5687 	else if (cols == 4 && rows == 2)
5688 		spv_func = SPVFuncImplRowMajor4x2;
5689 	else if (cols == 4 && rows == 3)
5690 		spv_func = SPVFuncImplRowMajor4x3;
5691 	else
5692 		SPIRV_CROSS_THROW("Could not convert row-major matrix.");
5693 
5694 	auto rslt = spv_function_implementations.insert(spv_func);
5695 	if (rslt.second)
5696 	{
5697 		suppress_missing_prototypes = true;
5698 		force_recompile();
5699 	}
5700 }
5701 
5702 // Wraps the expression string in a function call that converts the
5703 // row_major matrix result of the expression to a column_major matrix.
5704 string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type, bool is_packed)
5705 {
5706 	strip_enclosed_expression(exp_str);
5707 
5708 	string func_name;
5709 
5710 	// Square and packed matrices can just use transpose
5711 	if (exp_type.columns == exp_type.vecsize || is_packed)
5712 		func_name = "transpose";
5713 	else
5714 		func_name = string("spvConvertFromRowMajor") + to_string(exp_type.columns) + "x" + to_string(exp_type.vecsize);
5715 
5716 	return join(func_name, "(", exp_str, ")");
5717 }
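
// Illustrative wrappings produced above:
//   square or packed matrix expression "m"       -> "transpose(m)"
//   non-square row-major float2x3 expression "m" -> "spvConvertFromRowMajor2x3(m)"
// The spvConvertFromRowMajor* helpers are emitted on demand by
// add_convert_row_major_matrix_function() above.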
5718 
5719 // Called automatically at the end of the entry point function
5720 void CompilerMSL::emit_fixup()
5721 {
5722 	if ((get_execution_model() == ExecutionModelVertex ||
5723 	     get_execution_model() == ExecutionModelTessellationEvaluation) &&
5724 	    stage_out_var_id && !qual_pos_var_name.empty() && !capture_output_to_buffer)
5725 	{
5726 		if (options.vertex.fixup_clipspace)
5727 			statement(qual_pos_var_name, ".z = (", qual_pos_var_name, ".z + ", qual_pos_var_name,
5728 			          ".w) * 0.5;       // Adjust clip-space for Metal");
5729 
5730 		if (options.vertex.flip_vert_y)
5731 			statement(qual_pos_var_name, ".y = -(", qual_pos_var_name, ".y);", "    // Invert Y-axis for Metal");
5732 	}
5733 }
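
// With both vertex fixups enabled and a qual_pos_var_name of "out.gl_Position"
// (a hypothetical name for illustration), the statements emitted above are:
//   out.gl_Position.z = (out.gl_Position.z + out.gl_Position.w) * 0.5;       // Adjust clip-space for Metal
//   out.gl_Position.y = -(out.gl_Position.y);    // Invert Y-axis for Metal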
5734 
5735 // Return a string defining a structure member, with padding and packing.
5736 string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
5737                                      const string &qualifier)
5738 {
5739 	auto &membertype = get<SPIRType>(member_type_id);
5740 
5741 	// If this member requires padding to maintain alignment, emit a dummy padding member.
5742 	MSLStructMemberKey key = get_struct_member_key(type.self, index);
5743 	uint32_t pad_len = struct_member_padding[key];
5744 	if (pad_len > 0)
5745 		statement("char _m", index, "_pad", "[", to_string(pad_len), "];");
5746 
5747 	// If this member is packed, mark it as so.
5748 	string pack_pfx = "";
5749 
5750 	const SPIRType *effective_membertype = &membertype;
5751 	SPIRType override_type;
5752 
5753 	uint32_t orig_id = 0;
5754 	if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID))
5755 		orig_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID);
5756 
5757 	if (member_is_packed_type(type, index))
5758 	{
5759 		// If we're packing a matrix, output an appropriate typedef
5760 		if (membertype.basetype == SPIRType::Struct)
5761 		{
5762 			pack_pfx = "/* FIXME: A padded struct is needed here. If you see this message, file a bug! */ ";
5763 		}
5764 		else if (membertype.vecsize > 1 && membertype.columns > 1)
5765 		{
5766 			uint32_t rows = membertype.vecsize;
5767 			uint32_t cols = membertype.columns;
5768 			pack_pfx = "packed_";
5769 			if (has_member_decoration(type.self, index, DecorationRowMajor))
5770 			{
5771 				// These are stored transposed.
5772 				rows = membertype.columns;
5773 				cols = membertype.vecsize;
5774 				pack_pfx = "packed_rm_";
5775 			}
5776 			string base_type = membertype.width == 16 ? "half" : "float";
5777 			string td_line = "typedef ";
5778 			td_line += "packed_" + base_type + to_string(rows);
5779 			td_line += " " + pack_pfx;
5780 			// Use the actual matrix size here.
5781 			td_line += base_type + to_string(membertype.columns) + "x" + to_string(membertype.vecsize);
5782 			td_line += "[" + to_string(cols) + "]";
5783 			td_line += ";";
5784 			add_typedef_line(td_line);
5785 		}
5786 		else if (is_array(membertype) && membertype.vecsize <= 2 && membertype.basetype != SPIRType::Struct &&
5787 		         type_struct_member_array_stride(type, index) == 4 * membertype.width / 8)
5788 		{
5789 			// A "packed" float array, but we pad it to a 4-vector here instead.
5790 			override_type = membertype;
5791 			override_type.vecsize = 4;
5792 			effective_membertype = &override_type;
5793 		}
5794 		else
5795 			pack_pfx = "packed_";
5796 	}
5797 
5798 	// Very specifically, image load-store in argument buffers is disallowed on MSL on iOS.
5799 	if (msl_options.is_ios() && membertype.basetype == SPIRType::Image && membertype.image.sampled == 2)
5800 	{
5801 		if (!has_decoration(orig_id, DecorationNonWritable))
5802 			SPIRV_CROSS_THROW("Writable images are not allowed in argument buffers on iOS.");
5803 	}
5804 
5805 	// Array information is baked into these types.
5806 	string array_type;
5807 	if (membertype.basetype != SPIRType::Image && membertype.basetype != SPIRType::Sampler &&
5808 	    membertype.basetype != SPIRType::SampledImage)
5809 	{
5810 		array_type = type_to_array_glsl(membertype);
5811 	}
5812 
5813 	return join(pack_pfx, type_to_glsl(*effective_membertype, orig_id), " ", qualifier, to_member_name(type, index),
5814 	            member_attribute_qualifier(type, index), array_type, ";");
5815 }
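
// Illustrative output for a packed column-major float2x3 member (columns == 2,
// vecsize == 3): the typedef line added above is
//   typedef packed_float3 packed_float2x3[2];
// and the member is then declared as, e.g., "packed_float2x3 m0;" (member name
// hypothetical). A row-major packed member of the same type instead yields
//   typedef packed_float2 packed_rm_float2x3[3];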
5816 
5817 // Emit a structure member, padding and packing to maintain the correct member alignments.
5818 void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
5819                                      const string &qualifier, uint32_t)
5820 {
5821 	statement(to_struct_member(type, member_type_id, index, qualifier));
5822 }
5823 
5824 // Return a MSL qualifier for the specified function attribute member
5825 string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t index)
5826 {
5827 	auto &execution = get_entry_point();
5828 
5829 	uint32_t mbr_type_id = type.member_types[index];
5830 	auto &mbr_type = get<SPIRType>(mbr_type_id);
5831 
5832 	BuiltIn builtin = BuiltInMax;
5833 	bool is_builtin = is_member_builtin(type, index, &builtin);
5834 
5835 	if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
5836 		return join(" [[id(",
5837 		            get_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary), ")]]");
5838 
5839 	// Vertex function inputs
5840 	if (execution.model == ExecutionModelVertex && type.storage == StorageClassInput)
5841 	{
5842 		if (is_builtin)
5843 		{
5844 			switch (builtin)
5845 			{
5846 			case BuiltInVertexId:
5847 			case BuiltInVertexIndex:
5848 			case BuiltInBaseVertex:
5849 			case BuiltInInstanceId:
5850 			case BuiltInInstanceIndex:
5851 			case BuiltInBaseInstance:
5852 				return string(" [[") + builtin_qualifier(builtin) + "]]";
5853 
5854 			case BuiltInDrawIndex:
5855 				SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
5856 
5857 			default:
5858 				return "";
5859 			}
5860 		}
5861 		uint32_t locn = get_ordered_member_location(type.self, index);
5862 		if (locn != k_unknown_location)
5863 			return string(" [[attribute(") + convert_to_string(locn) + ")]]";
5864 	}
5865 
5866 	// Vertex and tessellation evaluation function outputs
5867 	if ((execution.model == ExecutionModelVertex || execution.model == ExecutionModelTessellationEvaluation) &&
5868 	    type.storage == StorageClassOutput)
5869 	{
5870 		if (is_builtin)
5871 		{
5872 			switch (builtin)
5873 			{
5874 			case BuiltInPointSize:
5875 				// Only mark the PointSize builtin if really rendering points.
5876 				// Some shaders may include a PointSize builtin even when used to render
5877 				// non-point topologies, and Metal will reject this builtin when compiling
5878 				// the shader into a render pipeline that uses a non-point topology.
5879 				return msl_options.enable_point_size_builtin ? (string(" [[") + builtin_qualifier(builtin) + "]]") : "";
5880 
5881 			case BuiltInViewportIndex:
5882 				if (!msl_options.supports_msl_version(2, 0))
5883 					SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
5884 				/* fallthrough */
5885 			case BuiltInPosition:
5886 			case BuiltInLayer:
5887 			case BuiltInClipDistance:
5888 				return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
5889 
5890 			default:
5891 				return "";
5892 			}
5893 		}
5894 		uint32_t comp;
5895 		uint32_t locn = get_ordered_member_location(type.self, index, &comp);
5896 		if (locn != k_unknown_location)
5897 		{
5898 			if (comp != k_unknown_component)
5899 				return string(" [[user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")]]";
5900 			else
5901 				return string(" [[user(locn") + convert_to_string(locn) + ")]]";
5902 		}
5903 	}
5904 
5905 	// Tessellation control function inputs
5906 	if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassInput)
5907 	{
5908 		if (is_builtin)
5909 		{
5910 			switch (builtin)
5911 			{
5912 			case BuiltInInvocationId:
5913 			case BuiltInPrimitiveId:
5914 			case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
5915 			case BuiltInSubgroupSize: // FIXME: Should work in any stage
5916 				return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
5917 			case BuiltInPatchVertices:
5918 				return "";
5919 			// Others come from stage input.
5920 			default:
5921 				break;
5922 			}
5923 		}
5924 		uint32_t locn = get_ordered_member_location(type.self, index);
5925 		if (locn != k_unknown_location)
5926 			return string(" [[attribute(") + convert_to_string(locn) + ")]]";
5927 	}
5928 
5929 	// Tessellation control function outputs
5930 	if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassOutput)
5931 	{
5932 		// For this type of shader, we always arrange for it to capture its
5933 		// output to a buffer. For this reason, qualifiers are irrelevant here.
5934 		return "";
5935 	}
5936 
5937 	// Tessellation evaluation function inputs
5938 	if (execution.model == ExecutionModelTessellationEvaluation && type.storage == StorageClassInput)
5939 	{
5940 		if (is_builtin)
5941 		{
5942 			switch (builtin)
5943 			{
5944 			case BuiltInPrimitiveId:
5945 			case BuiltInTessCoord:
5946 				return string(" [[") + builtin_qualifier(builtin) + "]]";
5947 			case BuiltInPatchVertices:
5948 				return "";
5949 			// Others come from stage input.
5950 			default:
5951 				break;
5952 			}
5953 		}
5954 		// The special control point array must not be marked with an attribute.
5955 		if (get_type(type.member_types[index]).basetype == SPIRType::ControlPointArray)
5956 			return "";
5957 		uint32_t locn = get_ordered_member_location(type.self, index);
5958 		if (locn != k_unknown_location)
5959 			return string(" [[attribute(") + convert_to_string(locn) + ")]]";
5960 	}
5961 
5962 	// Tessellation evaluation function outputs were handled above.
5963 
5964 	// Fragment function inputs
5965 	if (execution.model == ExecutionModelFragment && type.storage == StorageClassInput)
5966 	{
5967 		string quals;
5968 		if (is_builtin)
5969 		{
5970 			switch (builtin)
5971 			{
5972 			case BuiltInViewIndex:
5973 				if (!msl_options.multiview)
5974 					break;
5975 				/* fallthrough */
5976 			case BuiltInFrontFacing:
5977 			case BuiltInPointCoord:
5978 			case BuiltInFragCoord:
5979 			case BuiltInSampleId:
5980 			case BuiltInSampleMask:
5981 			case BuiltInLayer:
5982 			case BuiltInBaryCoordNV:
5983 			case BuiltInBaryCoordNoPerspNV:
5984 				quals = builtin_qualifier(builtin);
5985 				break;
5986 
5987 			default:
5988 				break;
5989 			}
5990 		}
5991 		else
5992 		{
5993 			uint32_t comp;
5994 			uint32_t locn = get_ordered_member_location(type.self, index, &comp);
5995 			if (locn != k_unknown_location)
5996 			{
5997 				if (comp != k_unknown_component)
5998 					quals = string("user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")";
5999 				else
6000 					quals = string("user(locn") + convert_to_string(locn) + ")";
6001 			}
6002 		}
6003 
6004 		if (builtin == BuiltInBaryCoordNV || builtin == BuiltInBaryCoordNoPerspNV)
6005 		{
6006 			if (has_member_decoration(type.self, index, DecorationFlat) ||
6007 			    has_member_decoration(type.self, index, DecorationCentroid) ||
6008 			    has_member_decoration(type.self, index, DecorationSample) ||
6009 			    has_member_decoration(type.self, index, DecorationNoPerspective))
6010 			{
6011 				// NoPerspective is baked into the builtin type.
6012 				SPIRV_CROSS_THROW(
6013 				    "Flat, Centroid, Sample, NoPerspective decorations are not supported for BaryCoord inputs.");
6014 			}
6015 		}
6016 
6017 		// Don't bother decorating integers with the 'flat' attribute; it's
6018 		// the default (in fact, the only option). Also don't bother with the
6019 		// FragCoord builtin; it's always noperspective on Metal.
6020 		if (!type_is_integral(mbr_type) && (!is_builtin || builtin != BuiltInFragCoord))
6021 		{
6022 			if (has_member_decoration(type.self, index, DecorationFlat))
6023 			{
6024 				if (!quals.empty())
6025 					quals += ", ";
6026 				quals += "flat";
6027 			}
6028 			else if (has_member_decoration(type.self, index, DecorationCentroid))
6029 			{
6030 				if (!quals.empty())
6031 					quals += ", ";
6032 				if (has_member_decoration(type.self, index, DecorationNoPerspective))
6033 					quals += "centroid_no_perspective";
6034 				else
6035 					quals += "centroid_perspective";
6036 			}
6037 			else if (has_member_decoration(type.self, index, DecorationSample))
6038 			{
6039 				if (!quals.empty())
6040 					quals += ", ";
6041 				if (has_member_decoration(type.self, index, DecorationNoPerspective))
6042 					quals += "sample_no_perspective";
6043 				else
6044 					quals += "sample_perspective";
6045 			}
6046 			else if (has_member_decoration(type.self, index, DecorationNoPerspective))
6047 			{
6048 				if (!quals.empty())
6049 					quals += ", ";
6050 				quals += "center_no_perspective";
6051 			}
6052 		}
6053 
6054 		if (!quals.empty())
6055 			return " [[" + quals + "]]";
6056 	}
6057 
6058 	// Fragment function outputs
6059 	if (execution.model == ExecutionModelFragment && type.storage == StorageClassOutput)
6060 	{
6061 		if (is_builtin)
6062 		{
6063 			switch (builtin)
6064 			{
6065 			case BuiltInFragStencilRefEXT:
6066 				if (!msl_options.supports_msl_version(2, 1))
6067 					SPIRV_CROSS_THROW("Stencil export only supported in MSL 2.1 and up.");
6068 				return string(" [[") + builtin_qualifier(builtin) + "]]";
6069 
6070 			case BuiltInSampleMask:
6071 			case BuiltInFragDepth:
6072 				return string(" [[") + builtin_qualifier(builtin) + "]]";
6073 
6074 			default:
6075 				return "";
6076 			}
6077 		}
6078 		uint32_t locn = get_ordered_member_location(type.self, index);
6079 		if (locn != k_unknown_location && has_member_decoration(type.self, index, DecorationIndex))
6080 			return join(" [[color(", locn, "), index(", get_member_decoration(type.self, index, DecorationIndex),
6081 			            ")]]");
6082 		else if (locn != k_unknown_location)
6083 			return join(" [[color(", locn, ")]]");
6084 		else if (has_member_decoration(type.self, index, DecorationIndex))
6085 			return join(" [[index(", get_member_decoration(type.self, index, DecorationIndex), ")]]");
6086 		else
6087 			return "";
6088 	}
6089 
6090 	// Compute function inputs
6091 	if (execution.model == ExecutionModelGLCompute && type.storage == StorageClassInput)
6092 	{
6093 		if (is_builtin)
6094 		{
6095 			switch (builtin)
6096 			{
6097 			case BuiltInGlobalInvocationId:
6098 			case BuiltInWorkgroupId:
6099 			case BuiltInNumWorkgroups:
6100 			case BuiltInLocalInvocationId:
6101 			case BuiltInLocalInvocationIndex:
6102 			case BuiltInNumSubgroups:
6103 			case BuiltInSubgroupId:
6104 			case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
6105 			case BuiltInSubgroupSize: // FIXME: Should work in any stage
6106 				return string(" [[") + builtin_qualifier(builtin) + "]]";
6107 
6108 			default:
6109 				return "";
6110 			}
6111 		}
6112 	}
6113 
6114 	return "";
6115 }
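
// Illustrative qualifiers returned above:
//   vertex input at location 3              -> " [[attribute(3)]]"
//   vertex output, location 1, component 2  -> " [[user(locn1_2)]]"
//   fragment output, location 0, index 1    -> " [[color(0), index(1)]]"
//   compute GlobalInvocationId builtin      -> " [[thread_position_in_grid]]"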
6116 
6117 // Returns the location decoration of the member with the specified index in the specified type.
6118 // If the location of the member has been explicitly set, that location is used. If not, this
6119 // function assumes the members are ordered in their location order, and simply returns the
6120 // index as the location.
6121 uint32_t CompilerMSL::get_ordered_member_location(uint32_t type_id, uint32_t index, uint32_t *comp)
6122 {
6123 	auto &m = ir.meta[type_id];
6124 	if (index < m.members.size())
6125 	{
6126 		auto &dec = m.members[index];
6127 		if (comp)
6128 		{
6129 			if (dec.decoration_flags.get(DecorationComponent))
6130 				*comp = dec.component;
6131 			else
6132 				*comp = k_unknown_component;
6133 		}
6134 		if (dec.decoration_flags.get(DecorationLocation))
6135 			return dec.location;
6136 	}
6137 
6138 	return index;
6139 }
6140 
6141 // Returns the type declaration for a function, including the
6142 // entry type if the current function is the entry point function
6143 string CompilerMSL::func_type_decl(SPIRType &type)
6144 {
6145 	// The regular function return type. If not processing the entry point function, that's all we need
6146 	string return_type = type_to_glsl(type) + type_to_array_glsl(type);
6147 	if (!processing_entry_point)
6148 		return return_type;
6149 
6150 	// If an outgoing interface block has been defined, and it should be returned, override the entry point return type
6151 	bool ep_should_return_output = !get_is_rasterization_disabled();
6152 	if (stage_out_var_id && ep_should_return_output)
6153 		return_type = type_to_glsl(get_stage_out_struct_type()) + type_to_array_glsl(type);
6154 
6155 	// Prepend an entry type, based on the execution model
6156 	string entry_type;
6157 	auto &execution = get_entry_point();
6158 	switch (execution.model)
6159 	{
6160 	case ExecutionModelVertex:
6161 		entry_type = "vertex";
6162 		break;
6163 	case ExecutionModelTessellationEvaluation:
6164 		if (!msl_options.supports_msl_version(1, 2))
6165 			SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
6166 		if (execution.flags.get(ExecutionModeIsolines))
6167 			SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
6168 		if (msl_options.is_ios())
6169 			entry_type =
6170 			    join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ") ]] vertex");
6171 		else
6172 			entry_type = join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ", ",
6173 			                  execution.output_vertices, ") ]] vertex");
6174 		break;
6175 	case ExecutionModelFragment:
6176 		entry_type =
6177 		    execution.flags.get(ExecutionModeEarlyFragmentTests) ? "[[ early_fragment_tests ]] fragment" : "fragment";
6178 		break;
6179 	case ExecutionModelTessellationControl:
6180 		if (!msl_options.supports_msl_version(1, 2))
6181 			SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
6182 		if (execution.flags.get(ExecutionModeIsolines))
6183 			SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
6184 		/* fallthrough */
6185 	case ExecutionModelGLCompute:
6186 	case ExecutionModelKernel:
6187 		entry_type = "kernel";
6188 		break;
6189 	default:
6190 		entry_type = "unknown";
6191 		break;
6192 	}
6193 
6194 	return entry_type + " " + return_type;
6195 }
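
// Illustrative declarations, assuming a stage-out struct named "main0_out"
// (the usual default, shown for illustration) and output_vertices == 4:
//   vertex stage                           -> "vertex main0_out"
//   macOS tess. eval., quad domain         -> "[[ patch(quad, 4) ]] vertex main0_out"
//   fragment with early fragment tests     -> "[[ early_fragment_tests ]] fragment main0_out"
//   compute or tessellation control stage  -> "kernel void"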
6196 
6197 // In MSL, address space qualifiers are required for all pointer or reference variables
6198 string CompilerMSL::get_argument_address_space(const SPIRVariable &argument)
6199 {
6200 	const auto &type = get<SPIRType>(argument.basetype);
6201 
6202 	switch (type.storage)
6203 	{
6204 	case StorageClassWorkgroup:
6205 		return "threadgroup";
6206 
6207 	case StorageClassStorageBuffer:
6208 	{
6209 		// For arguments from variable pointers, we use the write count deduction, so
6210 		// we should not assume any constness here; only global SSBOs may be const.
6211 		bool readonly = false;
6212 		if (has_decoration(type.self, DecorationBlock))
6213 			readonly = ir.get_buffer_block_flags(argument).get(DecorationNonWritable);
6214 
6215 		return readonly ? "const device" : "device";
6216 	}
6217 
6218 	case StorageClassUniform:
6219 	case StorageClassUniformConstant:
6220 	case StorageClassPushConstant:
6221 		if (type.basetype == SPIRType::Struct)
6222 		{
6223 			bool ssbo = has_decoration(type.self, DecorationBufferBlock);
6224 			if (ssbo)
6225 			{
6226 				bool readonly = ir.get_buffer_block_flags(argument).get(DecorationNonWritable);
6227 				return readonly ? "const device" : "device";
6228 			}
6229 			else
6230 				return "constant";
6231 		}
6232 		break;
6233 
6234 	case StorageClassFunction:
6235 	case StorageClassGeneric:
6236 		// No address space for plain values.
6237 		return type.pointer ? "thread" : "";
6238 
6239 	case StorageClassInput:
6240 		if (get_execution_model() == ExecutionModelTessellationControl && argument.basevariable == stage_in_ptr_var_id)
6241 			return "threadgroup";
6242 		break;
6243 
6244 	case StorageClassOutput:
6245 		if (capture_output_to_buffer)
6246 			return "device";
6247 		break;
6248 
6249 	default:
6250 		break;
6251 	}
6252 
6253 	return "thread";
6254 }
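
// Summary of the mapping above:
//   threadgroup  - Workgroup variables, and the TESC stage-in pointer
//   device       - writable SSBOs and captured stage output
//   const device - SSBOs decorated NonWritable
//   constant     - UBOs and push constants
//   thread       - everything else (plain values and function-local pointers)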
6255 
6256 string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id)
6257 {
6258 	switch (type.storage)
6259 	{
6260 	case StorageClassWorkgroup:
6261 		return "threadgroup";
6262 
6263 	case StorageClassStorageBuffer:
6264 	{
6265 		// This can be called for variable pointer contexts as well, so be very careful about which method we choose.
6266 		Bitset flags;
6267 		if (ir.ids[id].get_type() == TypeVariable && has_decoration(type.self, DecorationBlock))
6268 			flags = get_buffer_block_flags(id);
6269 		else
6270 			flags = get_decoration_bitset(id);
6271 
6272 		return flags.get(DecorationNonWritable) ? "const device" : "device";
6273 	}
6274 
6275 	case StorageClassUniform:
6276 	case StorageClassUniformConstant:
6277 	case StorageClassPushConstant:
6278 		if (type.basetype == SPIRType::Struct)
6279 		{
6280 			bool ssbo = has_decoration(type.self, DecorationBufferBlock);
6281 			if (ssbo)
6282 			{
6283 				// This can be called for variable pointer contexts as well, so be very careful about which method we choose.
6284 				Bitset flags;
6285 				if (ir.ids[id].get_type() == TypeVariable && has_decoration(type.self, DecorationBlock))
6286 					flags = get_buffer_block_flags(id);
6287 				else
6288 					flags = get_decoration_bitset(id);
6289 
6290 				return flags.get(DecorationNonWritable) ? "const device" : "device";
6291 			}
6292 			else
6293 				return "constant";
6294 		}
6295 		else
6296 			return "constant";
6297 
6298 	case StorageClassFunction:
6299 	case StorageClassGeneric:
6300 		// No address space for plain values.
6301 		return type.pointer ? "thread" : "";
6302 
6303 	case StorageClassOutput:
6304 		if (capture_output_to_buffer)
6305 			return "device";
6306 		break;
6307 
6308 	default:
6309 		break;
6310 	}
6311 
6312 	return "thread";
6313 }
6314 
6315 string CompilerMSL::entry_point_arg_stage_in()
6316 {
6317 	string decl;
6318 
6319 	// Stage-in structure
6320 	uint32_t stage_in_id;
6321 	if (get_execution_model() == ExecutionModelTessellationEvaluation)
6322 		stage_in_id = patch_stage_in_var_id;
6323 	else
6324 		stage_in_id = stage_in_var_id;
6325 
6326 	if (stage_in_id)
6327 	{
6328 		auto &var = get<SPIRVariable>(stage_in_id);
6329 		auto &type = get_variable_data_type(var);
6330 
6331 		add_resource_name(var.self);
6332 		decl = join(type_to_glsl(type), " ", to_name(var.self), " [[stage_in]]");
6333 	}
6334 
6335 	return decl;
6336 }
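
// The declaration built above typically reads as follows (the struct name is
// the usual default, shown for illustration):
//   main0_in in [[stage_in]]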
6337 
6338 void CompilerMSL::entry_point_args_builtin(string &ep_args)
6339 {
6340 	// Builtin variables
6341 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
6342 		auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
6343 
6344 		// Don't emit SamplePosition as a separate parameter. In the entry
6345 		// point, we get that by calling get_sample_position() on the sample ID.
6346 		if (var.storage == StorageClassInput && is_builtin_variable(var) &&
6347 		    get_variable_data_type(var).basetype != SPIRType::Struct &&
6348 		    get_variable_data_type(var).basetype != SPIRType::ControlPointArray)
6349 		{
6350 			// If the builtin is not part of the active input builtin set, don't emit it.
6351 			// Relevant for multiple entry-point modules which might declare unused builtins.
6352 			if (!active_input_builtins.get(bi_type) || !interface_variable_exists_in_entry_point(var_id))
6353 				return;
6354 
6355 			// These builtins are emitted specially. If we pass this branch, the builtin directly matches
6356 			// a MSL builtin.
6357 			if (bi_type != BuiltInSamplePosition && bi_type != BuiltInHelperInvocation &&
6358 			    bi_type != BuiltInPatchVertices && bi_type != BuiltInTessLevelInner &&
6359 			    bi_type != BuiltInTessLevelOuter && bi_type != BuiltInPosition && bi_type != BuiltInPointSize &&
6360 			    bi_type != BuiltInClipDistance && bi_type != BuiltInCullDistance && bi_type != BuiltInSubgroupEqMask &&
6361 			    bi_type != BuiltInBaryCoordNV && bi_type != BuiltInBaryCoordNoPerspNV &&
6362 			    bi_type != BuiltInSubgroupGeMask && bi_type != BuiltInSubgroupGtMask &&
6363 			    bi_type != BuiltInSubgroupLeMask && bi_type != BuiltInSubgroupLtMask &&
6364 			    ((get_execution_model() == ExecutionModelFragment && msl_options.multiview) ||
6365 			     bi_type != BuiltInViewIndex) &&
6366 			    (get_execution_model() == ExecutionModelGLCompute ||
6367 			     (get_execution_model() == ExecutionModelFragment && msl_options.supports_msl_version(2, 2)) ||
6368 			     (bi_type != BuiltInSubgroupLocalInvocationId && bi_type != BuiltInSubgroupSize)))
6369 			{
6370 				if (!ep_args.empty())
6371 					ep_args += ", ";
6372 
6373 				ep_args += builtin_type_decl(bi_type, var_id) + " " + to_expression(var_id);
6374 				ep_args += " [[" + builtin_qualifier(bi_type) + "]]";
6375 			}
6376 		}
6377 	});
6378 
6379 	// Vertex and instance index built-ins
6380 	if (needs_vertex_idx_arg)
6381 		ep_args += built_in_func_arg(BuiltInVertexIndex, !ep_args.empty());
6382 
6383 	if (needs_instance_idx_arg)
6384 		ep_args += built_in_func_arg(BuiltInInstanceIndex, !ep_args.empty());
6385 
6386 	if (capture_output_to_buffer)
6387 	{
6388 		// Add parameters to hold the indirect draw parameters and the shader output. This has to be handled
6389 		// specially because it needs to be a pointer, not a reference.
6390 		if (stage_out_var_id)
6391 		{
6392 			if (!ep_args.empty())
6393 				ep_args += ", ";
6394 			ep_args += join("device ", type_to_glsl(get_stage_out_struct_type()), "* ", output_buffer_var_name,
6395 			                " [[buffer(", msl_options.shader_output_buffer_index, ")]]");
6396 		}
6397 
6398 		if (get_execution_model() == ExecutionModelTessellationControl)
6399 		{
6400 			if (!ep_args.empty())
6401 				ep_args += ", ";
6402 			ep_args +=
6403 			    join("constant uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
6404 		}
6405 		else if (stage_out_var_id)
6406 		{
6407 			if (!ep_args.empty())
6408 				ep_args += ", ";
6409 			ep_args +=
6410 			    join("device uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
6411 		}
6412 
6413 		// Tessellation control shaders get three additional parameters:
6414 		// a buffer to hold the per-patch data, a buffer to hold the per-patch
6415 		// tessellation levels, and a block of workgroup memory to hold the
6416 		// input control point data.
6417 		if (get_execution_model() == ExecutionModelTessellationControl)
6418 		{
6419 			if (patch_stage_out_var_id)
6420 			{
6421 				if (!ep_args.empty())
6422 					ep_args += ", ";
6423 				ep_args +=
6424 				    join("device ", type_to_glsl(get_patch_stage_out_struct_type()), "* ", patch_output_buffer_var_name,
6425 				         " [[buffer(", convert_to_string(msl_options.shader_patch_output_buffer_index), ")]]");
6426 			}
6427 			if (!ep_args.empty())
6428 				ep_args += ", ";
6429 			ep_args += join("device ", get_tess_factor_struct_name(), "* ", tess_factor_buffer_var_name, " [[buffer(",
6430 			                convert_to_string(msl_options.shader_tess_factor_buffer_index), ")]]");
6431 			if (stage_in_var_id)
6432 			{
6433 				if (!ep_args.empty())
6434 					ep_args += ", ";
6435 				ep_args += join("threadgroup ", type_to_glsl(get_stage_in_struct_type()), "* ", input_wg_var_name,
6436 				                " [[threadgroup(", convert_to_string(msl_options.shader_input_wg_index), ")]]");
6437 			}
6438 		}
6439 	}
6440 }
6441 
6442 string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
6443 {
6444 	string ep_args = entry_point_arg_stage_in();
6445 	Bitset claimed_bindings;
6446 
6447 	for (uint32_t i = 0; i < kMaxArgumentBuffers; i++)
6448 	{
6449 		uint32_t id = argument_buffer_ids[i];
6450 		if (id == 0)
6451 			continue;
6452 
6453 		add_resource_name(id);
6454 		auto &var = get<SPIRVariable>(id);
6455 		auto &type = get_variable_data_type(var);
6456 
6457 		if (!ep_args.empty())
6458 			ep_args += ", ";
6459 
6460 		// Check if the argument buffer binding itself has been remapped.
6461 		uint32_t buffer_binding;
6462 		auto itr = resource_bindings.find({ get_entry_point().model, i, kArgumentBufferBinding });
6463 		if (itr != end(resource_bindings))
6464 		{
6465 			buffer_binding = itr->second.first.msl_buffer;
6466 			itr->second.second = true;
6467 		}
6468 		else
6469 		{
6470 			// As a fallback, directly map desc set <-> binding.
6471 			// If that was taken, take the next buffer binding.
6472 			if (claimed_bindings.get(i))
6473 				buffer_binding = next_metal_resource_index_buffer;
6474 			else
6475 				buffer_binding = i;
6476 		}
6477 
6478 		claimed_bindings.set(buffer_binding);
6479 
6480 		ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_name(id);
6481 		ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";
6482 
6483 		next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
6484 	}
6485 
6486 	entry_point_args_discrete_descriptors(ep_args);
6487 	entry_point_args_builtin(ep_args);
6488 
6489 	if (!ep_args.empty() && append_comma)
6490 		ep_args += ", ";
6491 
6492 	return ep_args;
6493 }
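
// An illustrative argument list produced here, assuming descriptor set 0 fell
// back to buffer binding 0 and the usual spvDescriptorSet* naming (the address
// space comes from get_argument_address_space()):
//   "main0_in in [[stage_in]], constant spvDescriptorSetBuffer0& spvDescriptorSet0 [[buffer(0)]]"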
6494 
6495 const MSLConstexprSampler *CompilerMSL::find_constexpr_sampler(uint32_t id) const
6496 {
6497 	// Try by ID.
6498 	{
6499 		auto itr = constexpr_samplers_by_id.find(id);
6500 		if (itr != end(constexpr_samplers_by_id))
6501 			return &itr->second;
6502 	}
6503 
6504 	// Try by binding.
6505 	{
6506 		uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
6507 		uint32_t binding = get_decoration(id, DecorationBinding);
6508 
6509 		auto itr = constexpr_samplers_by_binding.find({ desc_set, binding });
6510 		if (itr != end(constexpr_samplers_by_binding))
6511 			return &itr->second;
6512 	}
6513 
6514 	return nullptr;
6515 }
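
// Both lookup paths above are tried in order: first samplers registered by ID,
// then samplers registered by descriptor set and binding (see the
// remap_constexpr_sampler* entry points declared in spirv_msl.hpp).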
6516 
6517 void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
6518 {
6519 	// Output resources, sorted by resource index & type
6520 	// We need to sort to work around a bug on macOS 10.13 with NVidia drivers where switching between shaders
6521 	// with different order of buffers can result in issues with buffer assignments inside the driver.
6522 	struct Resource
6523 	{
6524 		SPIRVariable *var;
6525 		string name;
6526 		SPIRType::BaseType basetype;
6527 		uint32_t index;
6528 	};
6529 
6530 	SmallVector<Resource> resources;
6531 
6532 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
6533 		if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
6534 		     var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer) &&
6535 		    !is_hidden_variable(var))
6536 		{
6537 			auto &type = get_variable_data_type(var);
6538 			uint32_t var_id = var.self;
6539 
6540 			if (var.storage != StorageClassPushConstant)
6541 			{
6542 				uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
6543 				if (descriptor_set_is_argument_buffer(desc_set))
6544 					return;
6545 			}
6546 
6547 			const MSLConstexprSampler *constexpr_sampler = nullptr;
6548 			if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
6549 			{
6550 				constexpr_sampler = find_constexpr_sampler(var_id);
6551 				if (constexpr_sampler)
6552 				{
6553 					// Mark this ID as a constexpr sampler for later in case it came from set/bindings.
6554 					constexpr_samplers_by_id[var_id] = *constexpr_sampler;
6555 				}
6556 			}
6557 
6558 			if (type.basetype == SPIRType::SampledImage)
6559 			{
6560 				add_resource_name(var_id);
6561 				resources.push_back(
6562 				    { &var, to_name(var_id), SPIRType::Image, get_metal_resource_index(var, SPIRType::Image) });
6563 
6564 				if (type.image.dim != DimBuffer && !constexpr_sampler)
6565 				{
6566 					resources.push_back({ &var, to_sampler_expression(var_id), SPIRType::Sampler,
6567 					                      get_metal_resource_index(var, SPIRType::Sampler) });
6568 				}
6569 			}
6570 			else if (!constexpr_sampler)
6571 			{
6572 				// constexpr samplers are not declared as resources.
6573 				add_resource_name(var_id);
6574 				resources.push_back(
6575 				    { &var, to_name(var_id), type.basetype, get_metal_resource_index(var, type.basetype) });
6576 			}
6577 		}
6578 	});
6579 
6580 	sort(resources.begin(), resources.end(), [](const Resource &lhs, const Resource &rhs) {
6581 		return tie(lhs.basetype, lhs.index) < tie(rhs.basetype, rhs.index);
6582 	});
6583 
6584 	for (auto &r : resources)
6585 	{
6586 		auto &var = *r.var;
6587 		auto &type = get_variable_data_type(var);
6588 
6589 		uint32_t var_id = var.self;
6590 
6591 		switch (r.basetype)
6592 		{
6593 		case SPIRType::Struct:
6594 		{
6595 			auto &m = ir.meta[type.self];
6596 			if (m.members.size() == 0)
6597 				break;
6598 			if (!type.array.empty())
6599 			{
6600 				if (type.array.size() > 1)
6601 					SPIRV_CROSS_THROW("Arrays of arrays of buffers are not supported.");
6602 
6603 				// Metal doesn't directly support this, so we must expand the
6604 				// array. We'll declare a local array to hold these elements
6605 				// later.
6606 				uint32_t array_size = to_array_size_literal(type);
6607 
6608 				if (array_size == 0)
6609 					SPIRV_CROSS_THROW("Unsized arrays of buffers are not supported in MSL.");
6610 
6611 				buffer_arrays.push_back(var_id);
6612 				for (uint32_t i = 0; i < array_size; ++i)
6613 				{
6614 					if (!ep_args.empty())
6615 						ep_args += ", ";
6616 					ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + r.name + "_" +
6617 					           convert_to_string(i);
6618 					ep_args += " [[buffer(" + convert_to_string(r.index + i) + ")]]";
6619 				}
6620 			}
6621 			else
6622 			{
6623 				if (!ep_args.empty())
6624 					ep_args += ", ";
6625 				ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + r.name;
6626 				ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]";
6627 			}
6628 			break;
6629 		}
6630 		case SPIRType::Sampler:
6631 			if (!ep_args.empty())
6632 				ep_args += ", ";
6633 			ep_args += sampler_type(type) + " " + r.name;
6634 			ep_args += " [[sampler(" + convert_to_string(r.index) + ")]]";
6635 			break;
6636 		case SPIRType::Image:
6637 			if (!ep_args.empty())
6638 				ep_args += ", ";
6639 			ep_args += image_type_glsl(type, var_id) + " " + r.name;
6640 			ep_args += " [[texture(" + convert_to_string(r.index) + ")]]";
6641 			break;
6642 		default:
6643 			if (!ep_args.empty())
6644 				ep_args += ", ";
6645 			ep_args += type_to_glsl(type, var_id) + " " + r.name;
6646 			ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]";
6647 			break;
6648 		}
6649 	}
6650 }
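
// Illustrative arguments appended above (names and indices depend on the
// shader's resources):
//   "constant UBO& ubo [[buffer(0)]]"
//   "texture2d<float> tex [[texture(0)]], sampler texSmplr [[sampler(0)]]"
//   "device SSBO* ssbo_0 [[buffer(1)]], device SSBO* ssbo_1 [[buffer(2)]]"  (expanded array of buffers)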
6651 
6652 // Returns a string containing a comma-delimited list of args for the entry point function
6653 // This is the "classic" method of MSL 1 when we don't have argument buffer support.
6654 string CompilerMSL::entry_point_args_classic(bool append_comma)
6655 {
6656 	string ep_args = entry_point_arg_stage_in();
6657 	entry_point_args_discrete_descriptors(ep_args);
6658 	entry_point_args_builtin(ep_args);
6659 
6660 	if (!ep_args.empty() && append_comma)
6661 		ep_args += ", ";
6662 
6663 	return ep_args;
6664 }
6665 
6666 void CompilerMSL::fix_up_shader_inputs_outputs()
6667 {
6668 	// Look for sampled images and buffers. Add hooks to set up the swizzle constants or array lengths.
6669 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
6670 		auto &type = get_variable_data_type(var);
6671 		uint32_t var_id = var.self;
6672 		bool ssbo = has_decoration(type.self, DecorationBufferBlock);
6673 
6674 		if (var.storage == StorageClassUniformConstant && !is_hidden_variable(var))
6675 		{
6676 			if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
6677 			{
6678 				auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
6679 				entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
6680 					bool is_array_type = !type.array.empty();
6681 
6682 					uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
6683 					if (descriptor_set_is_argument_buffer(desc_set))
6684 					{
6685 						statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
6686 						          is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
6687 						          ".spvSwizzleConstants", "[",
6688 						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
6689 					}
6690 					else
6691 					{
6692 						// If we have an array of images, we need to be able to index into it, so take a pointer instead.
6693 						statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
6694 						          is_array_type ? " = &" : " = ", to_name(swizzle_buffer_id), "[",
6695 						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
6696 					}
6697 				});
6698 			}
6699 		}
6700 		else if ((var.storage == StorageClassStorageBuffer || (var.storage == StorageClassUniform && ssbo)) &&
6701 		         !is_hidden_variable(var))
6702 		{
6703 			if (buffers_requiring_array_length.count(var.self))
6704 			{
6705 				auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
6706 				entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
6707 					bool is_array_type = !type.array.empty();
6708 
6709 					uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
6710 					if (descriptor_set_is_argument_buffer(desc_set))
6711 					{
6712 						statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
6713 						          is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
6714 						          ".spvBufferSizeConstants", "[",
6715 						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
6716 					}
6717 					else
6718 					{
6719 						// If we have an array of images, we need to be able to index into it, so take a pointer instead.
6720 						statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
6721 						          is_array_type ? " = &" : " = ", to_name(buffer_size_buffer_id), "[",
6722 						          convert_to_string(get_metal_resource_index(var, type.basetype)), "];");
6723 					}
6724 				});
6725 			}
6726 		}
6727 	});
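
	// An illustrative swizzle hook emitted above for a lone (non-arrayed) texture
	// outside argument buffers, assuming the aux buffer is named spvSwizzleConstants:
	//   constant uint& texSwzl = spvSwizzleConstants[0];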
6728 
6729 	// Builtin variables
6730 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
6731 		uint32_t var_id = var.self;
6732 		BuiltIn bi_type = ir.meta[var_id].decoration.builtin_type;
6733 
6734 		if (var.storage == StorageClassInput && is_builtin_variable(var))
6735 		{
6736 			auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
6737 			switch (bi_type)
6738 			{
6739 			case BuiltInSamplePosition:
6740 				entry_func.fixup_hooks_in.push_back([=]() {
6741 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = get_sample_position(",
6742 					          to_expression(builtin_sample_id_id), ");");
6743 				});
6744 				break;
6745 			case BuiltInHelperInvocation:
6746 				if (msl_options.is_ios())
6747 					SPIRV_CROSS_THROW("simd_is_helper_thread() is only supported on macOS.");
6748 				else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
6749 					SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
6750 
6751 				entry_func.fixup_hooks_in.push_back([=]() {
6752 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_is_helper_thread();");
6753 				});
6754 				break;
6755 			case BuiltInPatchVertices:
6756 				if (get_execution_model() == ExecutionModelTessellationEvaluation)
6757 					entry_func.fixup_hooks_in.push_back([=]() {
6758 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
6759 						          to_expression(patch_stage_in_var_id), ".gl_in.size();");
6760 					});
6761 				else
6762 					entry_func.fixup_hooks_in.push_back([=]() {
6763 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = spvIndirectParams[0];");
6764 					});
6765 				break;
6766 			case BuiltInTessCoord:
6767 				// Emit a fixup to account for the shifted domain. Don't do this for triangles;
6768 				// MoltenVK will just reverse the winding order instead.
6769 				if (msl_options.tess_domain_origin_lower_left && !get_entry_point().flags.get(ExecutionModeTriangles))
6770 				{
6771 					string tc = to_expression(var_id);
6772 					entry_func.fixup_hooks_in.push_back([=]() { statement(tc, ".y = 1.0 - ", tc, ".y;"); });
6773 				}
6774 				break;
6775 			case BuiltInSubgroupLocalInvocationId:
6776 				// This is natively supported in compute shaders.
6777 				if (get_execution_model() == ExecutionModelGLCompute)
6778 					break;
6779 
6780 				// This is natively supported in fragment shaders in MSL 2.2.
6781 				if (get_execution_model() == ExecutionModelFragment && msl_options.supports_msl_version(2, 2))
6782 					break;
6783 
6784 				if (msl_options.is_ios())
6785 					SPIRV_CROSS_THROW(
6786 					    "SubgroupLocalInvocationId cannot be used outside of compute shaders before MSL 2.2 on iOS.");
6787 
6788 				if (!msl_options.supports_msl_version(2, 1))
6789 					SPIRV_CROSS_THROW(
6790 					    "SubgroupLocalInvocationId cannot be used outside of compute shaders before MSL 2.1.");
6791 
6792 				// Shaders other than compute shaders don't support the SIMD-group
6793 				// builtins directly, but we can emulate them using the SIMD-group
6794 				// functions. This might break if some invocations of the subgroup
6795 				// terminated before reaching the entry point.
6796 				entry_func.fixup_hooks_in.push_back([=]() {
6797 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
6798 					          " = simd_prefix_exclusive_sum(1);");
6799 				});
6800 				break;
6801 			case BuiltInSubgroupSize:
6802 				// This is natively supported in compute shaders.
6803 				if (get_execution_model() == ExecutionModelGLCompute)
6804 					break;
6805 
6806 				// This is natively supported in fragment shaders in MSL 2.2.
6807 				if (get_execution_model() == ExecutionModelFragment && msl_options.supports_msl_version(2, 2))
6808 					break;
6809 
6810 				if (msl_options.is_ios())
6811 					SPIRV_CROSS_THROW("SubgroupSize cannot be used outside of compute shaders on iOS.");
6812 
6813 				if (!msl_options.supports_msl_version(2, 1))
6814 					SPIRV_CROSS_THROW("SubgroupSize cannot be used outside of compute shaders before Metal 2.1.");
6815 
6816 				entry_func.fixup_hooks_in.push_back(
6817 				    [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_sum(1);"); });
6818 				break;
6819 			case BuiltInSubgroupEqMask:
6820 				if (msl_options.is_ios())
6821 					SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
6822 				if (!msl_options.supports_msl_version(2, 1))
6823 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
6824 				entry_func.fixup_hooks_in.push_back([=]() {
6825 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
6826 					          to_expression(builtin_subgroup_invocation_id_id), " > 32 ? uint4(0, (1 << (",
6827 					          to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
6828 					          to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
6829 				});
6830 				break;
6831 			case BuiltInSubgroupGeMask:
6832 				if (msl_options.is_ios())
6833 					SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
6834 				if (!msl_options.supports_msl_version(2, 1))
6835 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
6836 				entry_func.fixup_hooks_in.push_back([=]() {
6837 					// Case where index < 32, size < 32:
6838 					// mask0 = bfe(0xFFFFFFFF, index, size - index);
6839 					// mask1 = bfe(0xFFFFFFFF, 0, 0); // Gives 0
6840 					// Case where index < 32 but size >= 32:
6841 					// mask0 = bfe(0xFFFFFFFF, index, 32 - index);
6842 					// mask1 = bfe(0xFFFFFFFF, 0, size - 32);
6843 					// Case where index >= 32:
6844 					// mask0 = bfe(0xFFFFFFFF, 32, 0); // Gives 0
6845 					// mask1 = bfe(0xFFFFFFFF, index - 32, size - index);
6846 					// This is expressed without branches to avoid divergent
6847 					// control flow--hence the complicated min/max expressions.
6848 					// This is further complicated by the fact that if you attempt
6849 					// to bfe out-of-bounds on Metal, undefined behavior is the
6850 					// result.
6851 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
6852 					          " = uint4(extract_bits(0xFFFFFFFF, min(",
6853 					          to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
6854 					          to_expression(builtin_subgroup_size_id), ", 32) - (int)",
6855 					          to_expression(builtin_subgroup_invocation_id_id),
6856 					          ", 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
6857 					          to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
6858 					          to_expression(builtin_subgroup_size_id), " - (int)max(",
6859 					          to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
6860 				});
6861 				break;
6862 			case BuiltInSubgroupGtMask:
6863 				if (msl_options.is_ios())
6864 					SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
6865 				if (!msl_options.supports_msl_version(2, 1))
6866 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
6867 				entry_func.fixup_hooks_in.push_back([=]() {
6868 					// The same logic applies here, except now the index is one
6869 					// more than the subgroup invocation ID.
6870 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
6871 					          " = uint4(extract_bits(0xFFFFFFFF, min(",
6872 					          to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
6873 					          to_expression(builtin_subgroup_size_id), ", 32) - (int)",
6874 					          to_expression(builtin_subgroup_invocation_id_id),
6875 					          " - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
6876 					          to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
6877 					          to_expression(builtin_subgroup_size_id), " - (int)max(",
6878 					          to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");
6879 				});
6880 				break;
6881 			case BuiltInSubgroupLeMask:
6882 				if (msl_options.is_ios())
6883 					SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
6884 				if (!msl_options.supports_msl_version(2, 1))
6885 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
6886 				entry_func.fixup_hooks_in.push_back([=]() {
6887 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
6888 					          " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
6889 					          to_expression(builtin_subgroup_invocation_id_id),
6890 					          " + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
6891 					          to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0)), uint2(0));");
6892 				});
6893 				break;
6894 			case BuiltInSubgroupLtMask:
6895 				if (msl_options.is_ios())
6896 					SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
6897 				if (!msl_options.supports_msl_version(2, 1))
6898 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
6899 				entry_func.fixup_hooks_in.push_back([=]() {
6900 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
6901 					          " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
6902 					          to_expression(builtin_subgroup_invocation_id_id),
6903 					          ", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
6904 					          to_expression(builtin_subgroup_invocation_id_id), " - 32, 0)), uint2(0));");
6905 				});
6906 				break;
6907 			case BuiltInViewIndex:
6908 				if (!msl_options.multiview)
6909 				{
6910 					// According to the Vulkan spec, when not running under a multiview
6911 					// render pass, ViewIndex is 0.
6912 					entry_func.fixup_hooks_in.push_back([=]() {
6913 						statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;");
6914 					});
6915 				}
6916 				else if (get_execution_model() == ExecutionModelFragment)
6917 				{
6918 					// Because we adjusted the view index in the vertex shader, we have to
6919 					// adjust it back here.
6920 					entry_func.fixup_hooks_in.push_back([=]() {
6921 						statement(to_expression(var_id), " += ", to_expression(view_mask_buffer_id), "[0];");
6922 					});
6923 				}
6924 				else if (get_execution_model() == ExecutionModelVertex)
6925 				{
6926 					// Metal provides no special support for multiview, so we smuggle
6927 					// the view index in the instance index.
6928 					entry_func.fixup_hooks_in.push_back([=]() {
6929 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
6930 						          to_expression(view_mask_buffer_id), "[0] + ", to_expression(builtin_instance_idx_id),
6931 						          " % ", to_expression(view_mask_buffer_id), "[1];");
6932 						statement(to_expression(builtin_instance_idx_id), " /= ", to_expression(view_mask_buffer_id),
6933 						          "[1];");
6934 					});
6935 					// In addition to setting the variable itself, we also need to
6936 					// set the render_target_array_index with it on output. We have to
6937 					// offset this by the base view index, because Metal isn't in on
6938 					// our little game here.
6939 					entry_func.fixup_hooks_out.push_back([=]() {
6940 						statement(to_expression(builtin_layer_id), " = ", to_expression(var_id), " - ",
6941 						          to_expression(view_mask_buffer_id), "[0];");
6942 					});
6943 				}
6944 				break;
6945 			default:
6946 				break;
6947 			}
6948 		}
6949 	});
6950 }
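// Illustrative sketch of the multiview fixups above, assuming the compiler's
// default names (actual identifiers depend on the module): a multiview vertex
// shader gains roughly
//
//     uint gl_ViewIndex = spvViewMask[0] + gl_InstanceIndex % spvViewMask[1];
//     gl_InstanceIndex /= spvViewMask[1];
//     ...
//     out.gl_Layer = gl_ViewIndex - spvViewMask[0];
//
// so each view is drawn as a separate instance, and the render target array
// index is rebased against the first view in the mask.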
6951 
6952 // Returns the Metal index of the resource of the specified type as used by the specified variable.
6953 uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype)
6954 {
6955 	auto &execution = get_entry_point();
6956 	auto &var_dec = ir.meta[var.self].decoration;
6957 	auto &var_type = get<SPIRType>(var.basetype);
6958 	uint32_t var_desc_set = (var.storage == StorageClassPushConstant) ? kPushConstDescSet : var_dec.set;
6959 	uint32_t var_binding = (var.storage == StorageClassPushConstant) ? kPushConstBinding : var_dec.binding;
6960 
6961 	// If a matching binding has been specified, find and use it.
6962 	auto itr = resource_bindings.find({ execution.model, var_desc_set, var_binding });
6963 
6964 	auto resource_decoration = var_type.basetype == SPIRType::SampledImage && basetype == SPIRType::Sampler ?
6965 	                               SPIRVCrossDecorationResourceIndexSecondary :
6966 	                               SPIRVCrossDecorationResourceIndexPrimary;
6967 
6968 	if (itr != end(resource_bindings))
6969 	{
6970 		auto &remap = itr->second;
6971 		remap.second = true;
6972 		switch (basetype)
6973 		{
6974 		case SPIRType::Image:
6975 			set_extended_decoration(var.self, resource_decoration, remap.first.msl_texture);
6976 			return remap.first.msl_texture;
6977 		case SPIRType::Sampler:
6978 			set_extended_decoration(var.self, resource_decoration, remap.first.msl_sampler);
6979 			return remap.first.msl_sampler;
6980 		default:
6981 			set_extended_decoration(var.self, resource_decoration, remap.first.msl_buffer);
6982 			return remap.first.msl_buffer;
6983 		}
6984 	}
6985 
6986 	// If we have already allocated an index, keep using it.
6987 	if (has_extended_decoration(var.self, resource_decoration))
6988 		return get_extended_decoration(var.self, resource_decoration);
6989 
6990 	// If we did not explicitly remap, allocate bindings on demand.
6991 	// We cannot reliably use Binding decorations since SPIR-V and MSL's binding models are very different.
6992 
6993 	uint32_t binding_stride = 1;
6994 	auto &type = get<SPIRType>(var.basetype);
6995 	for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
6996 		binding_stride *= type.array_size_literal[i] ? type.array[i] : get<SPIRConstant>(type.array[i]).scalar();
6997 
6998 	assert(binding_stride != 0);
6999 
7000 	// If a binding has not been specified, revert to incrementing resource indices.
7001 	uint32_t resource_index;
7002 
7003 	bool allocate_argument_buffer_ids = false;
7004 	uint32_t desc_set = 0;
7005 
7006 	if (var.storage != StorageClassPushConstant)
7007 	{
7008 		desc_set = get_decoration(var.self, DecorationDescriptorSet);
7009 		allocate_argument_buffer_ids = descriptor_set_is_argument_buffer(desc_set);
7010 	}
7011 
7012 	if (allocate_argument_buffer_ids)
7013 	{
7014 		// Allocate from a flat ID binding space.
7015 		resource_index = next_metal_resource_ids[desc_set];
7016 		next_metal_resource_ids[desc_set] += binding_stride;
7017 	}
7018 	else
7019 	{
7020 		// Allocate from plain bindings which are allocated per resource type.
7021 		switch (basetype)
7022 		{
7023 		case SPIRType::Image:
7024 			resource_index = next_metal_resource_index_texture;
7025 			next_metal_resource_index_texture += binding_stride;
7026 			break;
7027 		case SPIRType::Sampler:
7028 			resource_index = next_metal_resource_index_sampler;
7029 			next_metal_resource_index_sampler += binding_stride;
7030 			break;
7031 		default:
7032 			resource_index = next_metal_resource_index_buffer;
7033 			next_metal_resource_index_buffer += binding_stride;
7034 			break;
7035 		}
7036 	}
7037 
7038 	set_extended_decoration(var.self, resource_decoration, resource_index);
7039 	return resource_index;
7040 }
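// Illustrative sketch (assumed starting state): with no explicit remap and no
// argument buffers, a shader resource declared as "texture2d<float> t[4]" has
// binding_stride == 4, so if next_metal_resource_index_texture is currently 2,
// the array is assigned [[texture(2)]] and indices 2..5 are reserved.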
7041 
7042 string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
7043 {
7044 	auto &var = get<SPIRVariable>(arg.id);
7045 	auto &type = get_variable_data_type(var);
7046 	auto &var_type = get<SPIRType>(arg.type);
7047 	StorageClass storage = var_type.storage;
7048 	bool is_pointer = var_type.pointer;
7049 
7050 	// If we need to modify the name of the variable, make sure we use the original variable.
7051 	// Our alias is just a shadow variable.
7052 	uint32_t name_id = var.self;
7053 	if (arg.alias_global_variable && var.basevariable)
7054 		name_id = var.basevariable;
7055 
7056 	bool constref = !arg.alias_global_variable && is_pointer && arg.write_count == 0;
7057 
7058 	bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage ||
7059 	                     type.basetype == SPIRType::Sampler;
7060 
7061 	// Arrays of images/samplers in MSL are always const.
7062 	if (!type.array.empty() && type_is_image)
7063 		constref = true;
7064 
7065 	string decl;
7066 	if (constref)
7067 		decl += "const ";
7068 
7069 	bool builtin = is_builtin_variable(var);
7070 	if (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id)
7071 		decl += type_to_glsl(type, arg.id);
7072 	else if (builtin)
7073 		decl += builtin_type_decl(static_cast<BuiltIn>(get_decoration(arg.id, DecorationBuiltIn)), arg.id);
7074 	else if ((storage == StorageClassUniform || storage == StorageClassStorageBuffer) && is_array(type))
7075 		decl += join(type_to_glsl(type, arg.id), "*");
7076 	else
7077 		decl += type_to_glsl(type, arg.id);
7078 
7079 	bool opaque_handle = storage == StorageClassUniformConstant;
7080 
7081 	string address_space = get_argument_address_space(var);
7082 
7083 	if (!builtin && !opaque_handle && !is_pointer &&
7084 	    (storage == StorageClassFunction || storage == StorageClassGeneric))
7085 	{
7086 		// If the argument is a pure value and not an opaque type, we will pass by value.
7087 		if (is_array(type))
7088 		{
7089 			// We are receiving an array by value. This is problematic.
7090 			// We cannot be sure of the target address space since we are supposed to receive a copy,
7091 			// but this is not possible with MSL without some extra work.
7092 			// We will have to assume we're getting a reference in thread address space.
7093 			// If we happen to get a reference in constant address space, the caller must emit a copy and pass that.
7094 			// Thread const therefore becomes the only logical choice, since we cannot "create" a constant array from
7095 			// non-constant arrays, but we can create thread const from constant.
7096 			decl = string("thread const ") + decl;
7097 			decl += " (&";
7098 			decl += to_expression(name_id);
7099 			decl += ")";
7100 			decl += type_to_array_glsl(type);
7101 		}
7102 		else
7103 		{
7104 			if (!address_space.empty())
7105 				decl = join(address_space, " ", decl);
7106 			decl += " ";
7107 			decl += to_expression(name_id);
7108 		}
7109 	}
7110 	else if (is_array(type) && !type_is_image)
7111 	{
7112 		// Arrays of images and samplers are special cased.
7113 		if (!address_space.empty())
7114 			decl = join(address_space, " ", decl);
7115 
7116 		if (msl_options.argument_buffers)
7117 		{
7118 			uint32_t desc_set = get_decoration(name_id, DecorationDescriptorSet);
7119 			if ((storage == StorageClassUniform || storage == StorageClassStorageBuffer) &&
7120 			    descriptor_set_is_argument_buffer(desc_set))
7121 			{
7122 				// An awkward case where we need to emit *more* address space declarations (yay!).
7123 				// An example is where we pass down an array of buffer pointers to leaf functions.
7124 				// It's a constant array containing pointers to constants.
7125 				// The pointer array is always constant however. E.g.
7126 				// device SSBO * constant (&array)[N].
7127 				// const device SSBO * constant (&array)[N].
7128 				// constant SSBO * constant (&array)[N].
7129 				// However, this only matters for argument buffers, since for MSL 1.0 style codegen,
7130 				// we emit the buffer array on the stack instead, and that seems to work fine.
7131 				decl += " constant";
7132 			}
7133 		}
7134 
7135 		decl += " (&";
7136 		decl += to_expression(name_id);
7137 		decl += ")";
7138 		decl += type_to_array_glsl(type);
7139 	}
7140 	else if (!opaque_handle)
7141 	{
7142 		// If this is going to be a reference to a variable pointer, the address space
7143 		// for the reference has to go before the '&', but after the '*'.
7144 		if (!address_space.empty())
7145 		{
7146 			if (decl.back() == '*')
7147 				decl += join(" ", address_space, " ");
7148 			else
7149 				decl = join(address_space, " ", decl);
7150 		}
7151 		decl += "&";
7152 		decl += " ";
7153 		decl += to_expression(name_id);
7154 	}
7155 	else
7156 	{
7157 		if (!address_space.empty())
7158 			decl = join(address_space, " ", decl);
7159 		decl += " ";
7160 		decl += to_expression(name_id);
7161 	}
7162 
7163 	return decl;
7164 }
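// Illustrative declarations produced above (names hypothetical):
//   constant UBO& ubo                  // plain uniform buffer reference
//   thread const float (&coeffs)[4]    // array received "by value"
//   device SSBO* constant (&ssbos)[2]  // buffer array inside an argument buffer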
7165 
7166 // If we're currently in the entry point function, and the object
7167 // has a qualified name, use it; otherwise use the standard name.
7168 string CompilerMSL::to_name(uint32_t id, bool allow_alias) const
7169 {
7170 	if (current_function && (current_function->self == ir.default_entry_point))
7171 	{
7172 		auto *m = ir.find_meta(id);
7173 		if (m && !m->decoration.qualified_alias.empty())
7174 			return m->decoration.qualified_alias;
7175 	}
7176 	return Compiler::to_name(id, allow_alias);
7177 }
7178 
7179 // Returns a name that combines the name of the struct with the name of the member, except for Builtins
7180 string CompilerMSL::to_qualified_member_name(const SPIRType &type, uint32_t index)
7181 {
7182 	// Don't qualify Builtin names because they are unique and are treated as such when building expressions
7183 	BuiltIn builtin = BuiltInMax;
7184 	if (is_member_builtin(type, index, &builtin))
7185 		return builtin_to_glsl(builtin, type.storage);
7186 
7187 	// Strip any underscore prefix from member name
7188 	string mbr_name = to_member_name(type, index);
7189 	size_t startPos = mbr_name.find_first_not_of("_");
7190 	mbr_name = (startPos != string::npos) ? mbr_name.substr(startPos) : "";
7191 	return join(to_name(type.self), "_", mbr_name);
7192 }
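// For example, member "_color" at index 1 of a struct named "Frag" yields
// "Frag_color", while a builtin member resolves to its builtin name instead.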
7193 
7194 // Ensures that the specified name is permanently usable by prepending a prefix
7195 // if the first chars are _ and a digit, which indicate a transient name.
7196 string CompilerMSL::ensure_valid_name(string name, string pfx)
7197 {
7198 	return (name.size() >= 2 && name[0] == '_' && isdigit(name[1])) ? (pfx + name) : name;
7199 }
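// For example, ensure_valid_name("_4", "m") yields "m_4", while a stable name
// such as "color" is returned unchanged.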
7200 
7201 // Replace all names that match MSL keywords or Metal Standard Library functions.
7202 void CompilerMSL::replace_illegal_names()
7203 {
7204 	// FIXME: MSL and GLSL are doing two different things here.
7205 	// Agree on convention and remove this override.
7206 	static const unordered_set<string> keywords = {
7207 		"kernel",
7208 		"vertex",
7209 		"fragment",
7210 		"compute",
7211 		"bias",
7212 		"assert",
7213 		"VARIABLE_TRACEPOINT",
7214 		"STATIC_DATA_TRACEPOINT",
7215 		"STATIC_DATA_TRACEPOINT_V",
7216 		"METAL_ALIGN",
7217 		"METAL_ASM",
7218 		"METAL_CONST",
7219 		"METAL_DEPRECATED",
7220 		"METAL_ENABLE_IF",
7221 		"METAL_FUNC",
7222 		"METAL_INTERNAL",
7223 		"METAL_NON_NULL_RETURN",
7224 		"METAL_NORETURN",
7225 		"METAL_NOTHROW",
7226 		"METAL_PURE",
7227 		"METAL_UNAVAILABLE",
7228 		"METAL_IMPLICIT",
7229 		"METAL_EXPLICIT",
7230 		"METAL_CONST_ARG",
7231 		"METAL_ARG_UNIFORM",
7232 		"METAL_ZERO_ARG",
7233 		"METAL_VALID_LOD_ARG",
7234 		"METAL_VALID_LEVEL_ARG",
7235 		"METAL_VALID_STORE_ORDER",
7236 		"METAL_VALID_LOAD_ORDER",
7237 		"METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
7238 		"METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
7239 		"METAL_VALID_RENDER_TARGET",
7240 		"is_function_constant_defined",
7241 		"CHAR_BIT",
7242 		"SCHAR_MAX",
7243 		"SCHAR_MIN",
7244 		"UCHAR_MAX",
7245 		"CHAR_MAX",
7246 		"CHAR_MIN",
7247 		"USHRT_MAX",
7248 		"SHRT_MAX",
7249 		"SHRT_MIN",
7250 		"UINT_MAX",
7251 		"INT_MAX",
7252 		"INT_MIN",
7253 		"FLT_DIG",
7254 		"FLT_MANT_DIG",
7255 		"FLT_MAX_10_EXP",
7256 		"FLT_MAX_EXP",
7257 		"FLT_MIN_10_EXP",
7258 		"FLT_MIN_EXP",
7259 		"FLT_RADIX",
7260 		"FLT_MAX",
7261 		"FLT_MIN",
7262 		"FLT_EPSILON",
7263 		"FP_ILOGB0",
7264 		"FP_ILOGBNAN",
7265 		"MAXFLOAT",
7266 		"HUGE_VALF",
7267 		"INFINITY",
7268 		"NAN",
7269 		"M_E_F",
7270 		"M_LOG2E_F",
7271 		"M_LOG10E_F",
7272 		"M_LN2_F",
7273 		"M_LN10_F",
7274 		"M_PI_F",
7275 		"M_PI_2_F",
7276 		"M_PI_4_F",
7277 		"M_1_PI_F",
7278 		"M_2_PI_F",
7279 		"M_2_SQRTPI_F",
7280 		"M_SQRT2_F",
7281 		"M_SQRT1_2_F",
7282 		"HALF_DIG",
7283 		"HALF_MANT_DIG",
7284 		"HALF_MAX_10_EXP",
7285 		"HALF_MAX_EXP",
7286 		"HALF_MIN_10_EXP",
7287 		"HALF_MIN_EXP",
7288 		"HALF_RADIX",
7289 		"HALF_MAX",
7290 		"HALF_MIN",
7291 		"HALF_EPSILON",
7292 		"MAXHALF",
7293 		"HUGE_VALH",
7294 		"M_E_H",
7295 		"M_LOG2E_H",
7296 		"M_LOG10E_H",
7297 		"M_LN2_H",
7298 		"M_LN10_H",
7299 		"M_PI_H",
7300 		"M_PI_2_H",
7301 		"M_PI_4_H",
7302 		"M_1_PI_H",
7303 		"M_2_PI_H",
7304 		"M_2_SQRTPI_H",
7305 		"M_SQRT2_H",
7306 		"M_SQRT1_2_H",
7307 		"DBL_DIG",
7308 		"DBL_MANT_DIG",
7309 		"DBL_MAX_10_EXP",
7310 		"DBL_MAX_EXP",
7311 		"DBL_MIN_10_EXP",
7312 		"DBL_MIN_EXP",
7313 		"DBL_RADIX",
7314 		"DBL_MAX",
7315 		"DBL_MIN",
7316 		"DBL_EPSILON",
7317 		"HUGE_VAL",
7318 		"M_E",
7319 		"M_LOG2E",
7320 		"M_LOG10E",
7321 		"M_LN2",
7322 		"M_LN10",
7323 		"M_PI",
7324 		"M_PI_2",
7325 		"M_PI_4",
7326 		"M_1_PI",
7327 		"M_2_PI",
7328 		"M_2_SQRTPI",
7329 		"M_SQRT2",
7330 		"M_SQRT1_2",
7331 		"quad_broadcast",
7332 	};
7333 
7334 	static const unordered_set<string> illegal_func_names = {
7335 		"main",
7336 		"saturate",
7337 		"assert",
7338 		"VARIABLE_TRACEPOINT",
7339 		"STATIC_DATA_TRACEPOINT",
7340 		"STATIC_DATA_TRACEPOINT_V",
7341 		"METAL_ALIGN",
7342 		"METAL_ASM",
7343 		"METAL_CONST",
7344 		"METAL_DEPRECATED",
7345 		"METAL_ENABLE_IF",
7346 		"METAL_FUNC",
7347 		"METAL_INTERNAL",
7348 		"METAL_NON_NULL_RETURN",
7349 		"METAL_NORETURN",
7350 		"METAL_NOTHROW",
7351 		"METAL_PURE",
7352 		"METAL_UNAVAILABLE",
7353 		"METAL_IMPLICIT",
7354 		"METAL_EXPLICIT",
7355 		"METAL_CONST_ARG",
7356 		"METAL_ARG_UNIFORM",
7357 		"METAL_ZERO_ARG",
7358 		"METAL_VALID_LOD_ARG",
7359 		"METAL_VALID_LEVEL_ARG",
7360 		"METAL_VALID_STORE_ORDER",
7361 		"METAL_VALID_LOAD_ORDER",
7362 		"METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
7363 		"METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
7364 		"METAL_VALID_RENDER_TARGET",
7365 		"is_function_constant_defined",
7366 		"CHAR_BIT",
7367 		"SCHAR_MAX",
7368 		"SCHAR_MIN",
7369 		"UCHAR_MAX",
7370 		"CHAR_MAX",
7371 		"CHAR_MIN",
7372 		"USHRT_MAX",
7373 		"SHRT_MAX",
7374 		"SHRT_MIN",
7375 		"UINT_MAX",
7376 		"INT_MAX",
7377 		"INT_MIN",
7378 		"FLT_DIG",
7379 		"FLT_MANT_DIG",
7380 		"FLT_MAX_10_EXP",
7381 		"FLT_MAX_EXP",
7382 		"FLT_MIN_10_EXP",
7383 		"FLT_MIN_EXP",
7384 		"FLT_RADIX",
7385 		"FLT_MAX",
7386 		"FLT_MIN",
7387 		"FLT_EPSILON",
7388 		"FP_ILOGB0",
7389 		"FP_ILOGBNAN",
7390 		"MAXFLOAT",
7391 		"HUGE_VALF",
7392 		"INFINITY",
7393 		"NAN",
7394 		"M_E_F",
7395 		"M_LOG2E_F",
7396 		"M_LOG10E_F",
7397 		"M_LN2_F",
7398 		"M_LN10_F",
7399 		"M_PI_F",
7400 		"M_PI_2_F",
7401 		"M_PI_4_F",
7402 		"M_1_PI_F",
7403 		"M_2_PI_F",
7404 		"M_2_SQRTPI_F",
7405 		"M_SQRT2_F",
7406 		"M_SQRT1_2_F",
7407 		"HALF_DIG",
7408 		"HALF_MANT_DIG",
7409 		"HALF_MAX_10_EXP",
7410 		"HALF_MAX_EXP",
7411 		"HALF_MIN_10_EXP",
7412 		"HALF_MIN_EXP",
7413 		"HALF_RADIX",
7414 		"HALF_MAX",
7415 		"HALF_MIN",
7416 		"HALF_EPSILON",
7417 		"MAXHALF",
7418 		"HUGE_VALH",
7419 		"M_E_H",
7420 		"M_LOG2E_H",
7421 		"M_LOG10E_H",
7422 		"M_LN2_H",
7423 		"M_LN10_H",
7424 		"M_PI_H",
7425 		"M_PI_2_H",
7426 		"M_PI_4_H",
7427 		"M_1_PI_H",
7428 		"M_2_PI_H",
7429 		"M_2_SQRTPI_H",
7430 		"M_SQRT2_H",
7431 		"M_SQRT1_2_H",
7432 		"DBL_DIG",
7433 		"DBL_MANT_DIG",
7434 		"DBL_MAX_10_EXP",
7435 		"DBL_MAX_EXP",
7436 		"DBL_MIN_10_EXP",
7437 		"DBL_MIN_EXP",
7438 		"DBL_RADIX",
7439 		"DBL_MAX",
7440 		"DBL_MIN",
7441 		"DBL_EPSILON",
7442 		"HUGE_VAL",
7443 		"M_E",
7444 		"M_LOG2E",
7445 		"M_LOG10E",
7446 		"M_LN2",
7447 		"M_LN10",
7448 		"M_PI",
7449 		"M_PI_2",
7450 		"M_PI_4",
7451 		"M_1_PI",
7452 		"M_2_PI",
7453 		"M_2_SQRTPI",
7454 		"M_SQRT2",
7455 		"M_SQRT1_2",
7456 	};
7457 
7458 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &) {
7459 		auto &dec = ir.meta[self].decoration;
7460 		if (keywords.find(dec.alias) != end(keywords))
7461 			dec.alias += "0";
7462 	});
7463 
7464 	ir.for_each_typed_id<SPIRFunction>([&](uint32_t self, SPIRFunction &) {
7465 		auto &dec = ir.meta[self].decoration;
7466 		if (illegal_func_names.find(dec.alias) != end(illegal_func_names))
7467 			dec.alias += "0";
7468 	});
7469 
7470 	ir.for_each_typed_id<SPIRType>([&](uint32_t self, SPIRType &) {
7471 		for (auto &mbr_dec : ir.meta[self].members)
7472 			if (keywords.find(mbr_dec.alias) != end(keywords))
7473 				mbr_dec.alias += "0";
7474 	});
7475 
7476 	for (auto &entry : ir.entry_points)
7477 	{
7478 		// Change both the entry point name and the alias, to keep them synced.
7479 		string &ep_name = entry.second.name;
7480 		if (illegal_func_names.find(ep_name) != end(illegal_func_names))
7481 			ep_name += "0";
7482 
7483 		// Always write this because entry point might have been renamed earlier.
7484 		ir.meta[entry.first].decoration.alias = ep_name;
7485 	}
7486 
7487 	CompilerGLSL::replace_illegal_names();
7488 }
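// For example, an entry point named "main" is emitted as "main0", and a
// variable named "bias" (an MSL keyword) becomes "bias0".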
7489 
7490 string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain)
7491 {
7492 	auto *var = maybe_get<SPIRVariable>(base);
7493 	// If this is a buffer array, we have to dereference the buffer pointers.
7494 	// Otherwise, if this is a pointer expression, dereference it.
7495 
7496 	bool declared_as_pointer = false;
7497 
7498 	if (var)
7499 	{
7500 		bool is_buffer_variable = var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer;
7501 		declared_as_pointer = is_buffer_variable && is_array(get<SPIRType>(var->basetype));
7502 	}
7503 
7504 	if (declared_as_pointer || (!ptr_chain && should_dereference(base)))
7505 		return join("->", to_member_name(type, index));
7506 	else
7507 		return join(".", to_member_name(type, index));
7508 }
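// For example (names hypothetical), accessing a member through an element of a
// buffer array emits "ssbos[i]->member", while ordinary struct access emits
// "obj.member".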
7509 
7510 string CompilerMSL::to_qualifiers_glsl(uint32_t id)
7511 {
7512 	string quals;
7513 
7514 	auto &type = expression_type(id);
7515 	if (type.storage == StorageClassWorkgroup)
7516 		quals += "threadgroup ";
7517 
7518 	return quals;
7519 }
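// For example, a variable in Workgroup storage picks up the "threadgroup"
// qualifier, e.g. "threadgroup float sdata[64];" (name and size illustrative).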
7520 
7521 // The optional id parameter indicates the object whose type we are trying
7522 // to find the description for. Most type descriptions do not
7523 // depend on a specific object's use of that type.
7524 string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
7525 {
7526 	string type_name;
7527 
7528 	// Pointer?
7529 	if (type.pointer)
7530 	{
7531 		type_name = join(get_type_address_space(type, id), " ", type_to_glsl(get<SPIRType>(type.parent_type), id));
7532 		switch (type.basetype)
7533 		{
7534 		case SPIRType::Image:
7535 		case SPIRType::SampledImage:
7536 		case SPIRType::Sampler:
7537 			// These are handles.
7538 			break;
7539 		default:
7540 			// Anything else can be a raw pointer.
7541 			type_name += "*";
7542 			break;
7543 		}
7544 		return type_name;
7545 	}
7546 
7547 	switch (type.basetype)
7548 	{
7549 	case SPIRType::Struct:
7550 		// Need OpName lookup here to get a "sensible" name for a struct.
7551 		return to_name(type.self);
7552 
7553 	case SPIRType::Image:
7554 	case SPIRType::SampledImage:
7555 		return image_type_glsl(type, id);
7556 
7557 	case SPIRType::Sampler:
7558 		return sampler_type(type);
7559 
7560 	case SPIRType::Void:
7561 		return "void";
7562 
7563 	case SPIRType::AtomicCounter:
7564 		return "atomic_uint";
7565 
7566 	case SPIRType::ControlPointArray:
7567 		return join("patch_control_point<", type_to_glsl(get<SPIRType>(type.parent_type), id), ">");
7568 
7569 	// Scalars
7570 	case SPIRType::Boolean:
7571 		type_name = "bool";
7572 		break;
7573 	case SPIRType::Char:
7574 	case SPIRType::SByte:
7575 		type_name = "char";
7576 		break;
7577 	case SPIRType::UByte:
7578 		type_name = "uchar";
7579 		break;
7580 	case SPIRType::Short:
7581 		type_name = "short";
7582 		break;
7583 	case SPIRType::UShort:
7584 		type_name = "ushort";
7585 		break;
7586 	case SPIRType::Int:
7587 		type_name = "int";
7588 		break;
7589 	case SPIRType::UInt:
7590 		type_name = "uint";
7591 		break;
7592 	case SPIRType::Int64:
7593 		if (!msl_options.supports_msl_version(2, 2))
7594 			SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
7595 		type_name = "long";
7596 		break;
7597 	case SPIRType::UInt64:
7598 		if (!msl_options.supports_msl_version(2, 2))
7599 			SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
7600 		type_name = "ulong";
7601 		break;
7602 	case SPIRType::Half:
7603 		type_name = "half";
7604 		break;
7605 	case SPIRType::Float:
7606 		type_name = "float";
7607 		break;
7608 	case SPIRType::Double:
7609 		type_name = "double"; // Currently unsupported
7610 		break;
7611 
7612 	default:
7613 		return "unknown_type";
7614 	}
7615 
7616 	// Matrix?
7617 	if (type.columns > 1)
7618 		type_name += to_string(type.columns) + "x";
7619 
7620 	// Vector or Matrix?
7621 	if (type.vecsize > 1)
7622 		type_name += to_string(type.vecsize);
7623 
7624 	return type_name;
7625 }
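// Illustrative mappings: a 4-component float vector becomes "float4", a matrix
// of 3 columns of 4-component float vectors becomes "float3x4", and a pointer
// to a buffer struct becomes e.g. "device SSBO*" (struct name from OpName).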
7626 
7627 std::string CompilerMSL::sampler_type(const SPIRType &type)
7628 {
7629 	if (!type.array.empty())
7630 	{
7631 		if (!msl_options.supports_msl_version(2))
7632 			SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of samplers.");
7633 
7634 		if (type.array.size() > 1)
7635 			SPIRV_CROSS_THROW("Arrays of arrays of samplers are not supported in MSL.");
7636 
7637 		// Arrays of samplers in MSL must be declared with a special array<T, N> syntax a la C++11 std::array.
7638 		uint32_t array_size = to_array_size_literal(type);
7639 		if (array_size == 0)
7640 			SPIRV_CROSS_THROW("Unsized array of samplers is not supported in MSL.");
7641 
7642 		auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
7643 		return join("array<", sampler_type(parent), ", ", array_size, ">");
7644 	}
7645 	else
7646 		return "sampler";
7647 }
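// For example, a 4-element sampler array is declared as "array<sampler, 4>"
// rather than the C-style "sampler s[4]".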
7648 
7649 // Returns an MSL string describing the SPIR-V image type
7650 string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id)
7651 {
7652 	auto *var = maybe_get<SPIRVariable>(id);
7653 	if (var && var->basevariable)
7654 	{
7655 		// For comparison images, check against the base variable,
7656 		// and not the fake ID which might have been generated for this variable.
7657 		id = var->basevariable;
7658 	}
7659 
7660 	if (!type.array.empty())
7661 	{
7662 		uint32_t major = 2, minor = 0;
7663 		if (msl_options.is_ios())
7664 		{
7665 			major = 1;
7666 			minor = 2;
7667 		}
7668 		if (!msl_options.supports_msl_version(major, minor))
7669 		{
7670 			if (msl_options.is_ios())
7671 				SPIRV_CROSS_THROW("MSL 1.2 or greater is required for arrays of textures.");
7672 			else
7673 				SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of textures.");
7674 		}
7675 
7676 		if (type.array.size() > 1)
7677 			SPIRV_CROSS_THROW("Arrays of arrays of textures are not supported in MSL.");
7678 
7679 		// Arrays of images in MSL must be declared with a special array<T, N> syntax a la C++11 std::array.
7680 		uint32_t array_size = to_array_size_literal(type);
7681 		if (array_size == 0)
7682 			SPIRV_CROSS_THROW("Unsized array of images is not supported in MSL.");
7683 
7684 		auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
7685 		return join("array<", image_type_glsl(parent, id), ", ", array_size, ">");
7686 	}
7687 
7688 	string img_type_name;
7689 
7690 	// Bypass pointers because we need the real image struct
7691 	auto &img_type = get<SPIRType>(type.self).image;
7692 	if (image_is_comparison(type, id))
7693 	{
7694 		switch (img_type.dim)
7695 		{
7696 		case Dim1D:
7697 			img_type_name += "depth1d_unsupported_by_metal";
7698 			break;
7699 		case Dim2D:
7700 			if (img_type.ms && img_type.arrayed)
7701 			{
7702 				if (!msl_options.supports_msl_version(2, 1))
7703 					SPIRV_CROSS_THROW("Multisampled array textures require MSL 2.1.");
7704 				img_type_name += "depth2d_ms_array";
7705 			}
7706 			else if (img_type.ms)
7707 				img_type_name += "depth2d_ms";
7708 			else if (img_type.arrayed)
7709 				img_type_name += "depth2d_array";
7710 			else
7711 				img_type_name += "depth2d";
7712 			break;
7713 		case Dim3D:
7714 			img_type_name += "depth3d_unsupported_by_metal";
7715 			break;
7716 		case DimCube:
7717 			img_type_name += (img_type.arrayed ? "depthcube_array" : "depthcube");
7718 			break;
7719 		default:
7720 			img_type_name += "unknown_depth_texture_type";
7721 			break;
7722 		}
7723 	}
7724 	else
7725 	{
7726 		switch (img_type.dim)
7727 		{
7728 		case Dim1D:
7729 			img_type_name += (img_type.arrayed ? "texture1d_array" : "texture1d");
7730 			break;
7731 		case DimBuffer:
7732 			if (img_type.ms || img_type.arrayed)
7733 				SPIRV_CROSS_THROW("Cannot use texel buffers with multisampling or array layers.");
7734 
7735 			if (msl_options.texture_buffer_native)
7736 			{
7737 				if (!msl_options.supports_msl_version(2, 1))
7738 					SPIRV_CROSS_THROW("Native texture_buffer type is only supported in MSL 2.1 and above.");
7739 				img_type_name = "texture_buffer";
7740 			}
7741 			else
7742 				img_type_name += "texture2d";
7743 			break;
7744 		case Dim2D:
7745 		case DimSubpassData:
7746 			if (img_type.ms && img_type.arrayed)
7747 			{
7748 				if (!msl_options.supports_msl_version(2, 1))
7749 					SPIRV_CROSS_THROW("Multisampled array textures require MSL 2.1.");
7750 				img_type_name += "texture2d_ms_array";
7751 			}
7752 			else if (img_type.ms)
7753 				img_type_name += "texture2d_ms";
7754 			else if (img_type.arrayed)
7755 				img_type_name += "texture2d_array";
7756 			else
7757 				img_type_name += "texture2d";
7758 			break;
7759 		case Dim3D:
7760 			img_type_name += "texture3d";
7761 			break;
7762 		case DimCube:
7763 			img_type_name += (img_type.arrayed ? "texturecube_array" : "texturecube");
7764 			break;
7765 		default:
7766 			img_type_name += "unknown_texture_type";
7767 			break;
7768 		}
7769 	}
7770 
7771 	// Append the pixel type
7772 	img_type_name += "<";
7773 	img_type_name += type_to_glsl(get<SPIRType>(img_type.type));
7774 
7775 	// For unsampled images, append the sample/read/write access qualifier.
7776 	// For kernel images, the access qualifier may be supplied directly by SPIR-V.
7777 	// Otherwise it may be set based on whether the image is read from or written to within the shader.
7778 	if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData)
7779 	{
7780 		switch (img_type.access)
7781 		{
7782 		case AccessQualifierReadOnly:
7783 			img_type_name += ", access::read";
7784 			break;
7785 
7786 		case AccessQualifierWriteOnly:
7787 			img_type_name += ", access::write";
7788 			break;
7789 
7790 		case AccessQualifierReadWrite:
7791 			img_type_name += ", access::read_write";
7792 			break;
7793 
7794 		default:
7795 		{
7796 			auto *p_var = maybe_get_backing_variable(id);
7797 			if (p_var && p_var->basevariable)
7798 				p_var = maybe_get<SPIRVariable>(p_var->basevariable);
7799 			if (p_var && !has_decoration(p_var->self, DecorationNonWritable))
7800 			{
7801 				img_type_name += ", access::";
7802 
7803 				if (!has_decoration(p_var->self, DecorationNonReadable))
7804 					img_type_name += "read_";
7805 
7806 				img_type_name += "write";
7807 			}
7808 			break;
7809 		}
7810 		}
7811 	}
7812 
7813 	img_type_name += ">";
7814 
7815 	return img_type_name;
7816 }
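// Illustrative mappings: a sampled 2D depth texture becomes "depth2d<float>", a
// cube array becomes "texturecube_array<float>", and a storage image read and
// written in the shader becomes "texture2d<float, access::read_write>".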
7817 
7818 void CompilerMSL::emit_subgroup_op(const Instruction &i)
7819 {
7820 	const uint32_t *ops = stream(i);
7821 	auto op = static_cast<Op>(i.op);
7822 
7823 	// Metal 2.0 is required. iOS only supports quad ops. macOS only supports
7824 	// broadcast and shuffle on 10.13 (2.0), with full support in 10.14 (2.1).
7825 	// Note that iOS makes no distinction between a quad-group and a subgroup;
7826 	// all subgroups are quad-groups there.
7827 	if (!msl_options.supports_msl_version(2))
7828 		SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up.");
7829 
7830 	if (msl_options.is_ios())
7831 	{
7832 		switch (op)
7833 		{
7834 		default:
7835 			SPIRV_CROSS_THROW("iOS only supports quad-group operations.");
7836 		case OpGroupNonUniformBroadcast:
7837 		case OpGroupNonUniformShuffle:
7838 		case OpGroupNonUniformShuffleXor:
7839 		case OpGroupNonUniformShuffleUp:
7840 		case OpGroupNonUniformShuffleDown:
7841 		case OpGroupNonUniformQuadSwap:
7842 		case OpGroupNonUniformQuadBroadcast:
7843 			break;
7844 		}
7845 	}
7846 
7847 	if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
7848 	{
7849 		switch (op)
7850 		{
7851 		default:
7852 			SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.1 and up.");
7853 		case OpGroupNonUniformBroadcast:
7854 		case OpGroupNonUniformShuffle:
7855 		case OpGroupNonUniformShuffleXor:
7856 		case OpGroupNonUniformShuffleUp:
7857 		case OpGroupNonUniformShuffleDown:
7858 			break;
7859 		}
7860 	}
7861 
7862 	uint32_t result_type = ops[0];
7863 	uint32_t id = ops[1];
7864 
7865 	auto scope = static_cast<Scope>(get<SPIRConstant>(ops[2]).scalar());
7866 	if (scope != ScopeSubgroup)
7867 		SPIRV_CROSS_THROW("Only subgroup scope is supported.");
7868 
7869 	switch (op)
7870 	{
7871 	case OpGroupNonUniformElect:
7872 		emit_op(result_type, id, "simd_is_first()", true);
7873 		break;
7874 
7875 	case OpGroupNonUniformBroadcast:
7876 		emit_binary_func_op(result_type, id, ops[3], ops[4],
7877 		                    msl_options.is_ios() ? "quad_broadcast" : "simd_broadcast");
7878 		break;
7879 
7880 	case OpGroupNonUniformBroadcastFirst:
7881 		emit_unary_func_op(result_type, id, ops[3], "simd_broadcast_first");
7882 		break;
7883 
7884 	case OpGroupNonUniformBallot:
7885 		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallot");
7886 		break;
7887 
7888 	case OpGroupNonUniformInverseBallot:
7889 		emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract");
7890 		break;
7891 
7892 	case OpGroupNonUniformBallotBitExtract:
7893 		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBallotBitExtract");
7894 		break;
7895 
7896 	case OpGroupNonUniformBallotFindLSB:
7897 		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindLSB");
7898 		break;
7899 
7900 	case OpGroupNonUniformBallotFindMSB:
7901 		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindMSB");
7902 		break;
7903 
7904 	case OpGroupNonUniformBallotBitCount:
7905 	{
7906 		auto operation = static_cast<GroupOperation>(ops[3]);
7907 		if (operation == GroupOperationReduce)
7908 			emit_unary_func_op(result_type, id, ops[4], "spvSubgroupBallotBitCount");
7909 		else if (operation == GroupOperationInclusiveScan)
7910 			emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
7911 			                    "spvSubgroupBallotInclusiveBitCount");
7912 		else if (operation == GroupOperationExclusiveScan)
7913 			emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
7914 			                    "spvSubgroupBallotExclusiveBitCount");
7915 		else
7916 			SPIRV_CROSS_THROW("Invalid BitCount operation.");
7917 		break;
7918 	}
7919 
7920 	case OpGroupNonUniformShuffle:
7921 		emit_binary_func_op(result_type, id, ops[3], ops[4], msl_options.is_ios() ? "quad_shuffle" : "simd_shuffle");
7922 		break;
7923 
7924 	case OpGroupNonUniformShuffleXor:
7925 		emit_binary_func_op(result_type, id, ops[3], ops[4],
7926 		                    msl_options.is_ios() ? "quad_shuffle_xor" : "simd_shuffle_xor");
7927 		break;
7928 
7929 	case OpGroupNonUniformShuffleUp:
7930 		emit_binary_func_op(result_type, id, ops[3], ops[4],
7931 		                    msl_options.is_ios() ? "quad_shuffle_up" : "simd_shuffle_up");
7932 		break;
7933 
7934 	case OpGroupNonUniformShuffleDown:
7935 		emit_binary_func_op(result_type, id, ops[3], ops[4],
7936 		                    msl_options.is_ios() ? "quad_shuffle_down" : "simd_shuffle_down");
7937 		break;
7938 
7939 	case OpGroupNonUniformAll:
7940 		emit_unary_func_op(result_type, id, ops[3], "simd_all");
7941 		break;
7942 
7943 	case OpGroupNonUniformAny:
7944 		emit_unary_func_op(result_type, id, ops[3], "simd_any");
7945 		break;
7946 
7947 	case OpGroupNonUniformAllEqual:
7948 		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupAllEqual");
7949 		break;
7950 
7951 		// clang-format off
7952 #define MSL_GROUP_OP(op, msl_op) \
7953 case OpGroupNonUniform##op: \
7954 	{ \
7955 		auto operation = static_cast<GroupOperation>(ops[3]); \
7956 		if (operation == GroupOperationReduce) \
7957 			emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
7958 		else if (operation == GroupOperationInclusiveScan) \
7959 			emit_unary_func_op(result_type, id, ops[4], "simd_prefix_inclusive_" #msl_op); \
7960 		else if (operation == GroupOperationExclusiveScan) \
7961 			emit_unary_func_op(result_type, id, ops[4], "simd_prefix_exclusive_" #msl_op); \
7962 		else if (operation == GroupOperationClusteredReduce) \
7963 		{ \
7964 			/* Only cluster sizes of 4 are supported. */ \
7965 			uint32_t cluster_size = get<SPIRConstant>(ops[5]).scalar(); \
7966 			if (cluster_size != 4) \
7967 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
7968 			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
7969 		} \
7970 		else \
7971 			SPIRV_CROSS_THROW("Invalid group operation."); \
7972 		break; \
7973 	}
7974 	MSL_GROUP_OP(FAdd, sum)
7975 	MSL_GROUP_OP(FMul, product)
7976 	MSL_GROUP_OP(IAdd, sum)
7977 	MSL_GROUP_OP(IMul, product)
7978 #undef MSL_GROUP_OP
7979 	// The others, unfortunately, don't support InclusiveScan or ExclusiveScan.
7980 #define MSL_GROUP_OP(op, msl_op) \
7981 case OpGroupNonUniform##op: \
7982 	{ \
7983 		auto operation = static_cast<GroupOperation>(ops[3]); \
7984 		if (operation == GroupOperationReduce) \
7985 			emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
7986 		else if (operation == GroupOperationInclusiveScan) \
7987 			SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
7988 		else if (operation == GroupOperationExclusiveScan) \
7989 			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
7990 		else if (operation == GroupOperationClusteredReduce) \
7991 		{ \
7992 			/* Only cluster sizes of 4 are supported. */ \
7993 			uint32_t cluster_size = get<SPIRConstant>(ops[5]).scalar(); \
7994 			if (cluster_size != 4) \
7995 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
7996 			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
7997 		} \
7998 		else \
7999 			SPIRV_CROSS_THROW("Invalid group operation."); \
8000 		break; \
8001 	}
8002 	MSL_GROUP_OP(FMin, min)
8003 	MSL_GROUP_OP(FMax, max)
8004 	MSL_GROUP_OP(SMin, min)
8005 	MSL_GROUP_OP(SMax, max)
8006 	MSL_GROUP_OP(UMin, min)
8007 	MSL_GROUP_OP(UMax, max)
8008 	MSL_GROUP_OP(BitwiseAnd, and)
8009 	MSL_GROUP_OP(BitwiseOr, or)
8010 	MSL_GROUP_OP(BitwiseXor, xor)
8011 	MSL_GROUP_OP(LogicalAnd, and)
8012 	MSL_GROUP_OP(LogicalOr, or)
8013 	MSL_GROUP_OP(LogicalXor, xor)
8014 		// clang-format on
8015 
8016 	case OpGroupNonUniformQuadSwap:
8017 	{
8018 		// We can implement this easily based on the following table giving
8019 		// the target lane ID from the direction and current lane ID:
8020 		//        Direction
8021 		//      | 0 | 1 | 2 |
8022 		//   ---+---+---+---+
8023 		// L 0  | 1   2   3
8024 		// a 1  | 0   3   2
8025 		// n 2  | 3   0   1
8026 		// e 3  | 2   1   0
8027 		// Notice that target = source ^ (direction + 1).
8028 		uint32_t mask = get<SPIRConstant>(ops[4]).scalar() + 1;
8029 		uint32_t mask_id = ir.increase_bound_by(1);
8030 		set<SPIRConstant>(mask_id, expression_type_id(ops[4]), mask, false);
8031 		emit_binary_func_op(result_type, id, ops[3], mask_id, "quad_shuffle_xor");
8032 		break;
8033 	}
8034 
8035 	case OpGroupNonUniformQuadBroadcast:
8036 		emit_binary_func_op(result_type, id, ops[3], ops[4], "quad_broadcast");
8037 		break;
8038 
8039 	default:
8040 		SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
8041 	}
8042 
8043 	register_control_dependent_expression(id);
8044 }
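// Worked example of the quad-swap table above: a horizontal swap (direction 0)
// emits quad_shuffle_xor(value, 1), a vertical swap (direction 1) emits
// quad_shuffle_xor(value, 2), and a diagonal swap (direction 2) emits
// quad_shuffle_xor(value, 3).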
8045 
8046 string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
8047 {
8048 	if (out_type.basetype == in_type.basetype)
8049 		return "";
8050 
8051 	assert(out_type.basetype != SPIRType::Boolean);
8052 	assert(in_type.basetype != SPIRType::Boolean);
8053 
8054 	bool integral_cast = type_is_integral(out_type) && type_is_integral(in_type);
8055 	bool same_size_cast = out_type.width == in_type.width;
8056 
8057 	if (integral_cast && same_size_cast)
8058 	{
8059 		// Trivial bitcast case, casts between integers.
8060 		return type_to_glsl(out_type);
8061 	}
8062 	else
8063 	{
8064 		// Fall back to the catch-all bitcast in MSL.
8065 		return "as_type<" + type_to_glsl(out_type) + ">";
8066 	}
8067 }
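// For example, a uint -> int bitcast is emitted as a plain constructor cast
// such as "int(x)", while a uint -> float bitcast requires "as_type<float>(x)".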
8068 
8069 // Returns an MSL string identifying the name of a SPIR-V builtin.
8070 // Output builtins are qualified with the name of the stage out structure.
8071 string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
8072 {
8073 	switch (builtin)
8074 	{
8075 
8076 	// Override GLSL compiler strictness
8077 	case BuiltInVertexId:
8078 		return "gl_VertexID";
8079 	case BuiltInInstanceId:
8080 		return "gl_InstanceID";
8081 	case BuiltInVertexIndex:
8082 		return "gl_VertexIndex";
8083 	case BuiltInInstanceIndex:
8084 		return "gl_InstanceIndex";
8085 	case BuiltInBaseVertex:
8086 		return "gl_BaseVertex";
8087 	case BuiltInBaseInstance:
8088 		return "gl_BaseInstance";
8089 	case BuiltInDrawIndex:
8090 		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
8091 
8092 	// When used in the entry function, output builtins are qualified with output struct name.
8093 	// Test that the storage class is NOT Input, as output builtins might be part of a generic type.
8094 	// Also don't do this for tessellation control shaders.
8095 	case BuiltInViewportIndex:
8096 		if (!msl_options.supports_msl_version(2, 0))
8097 			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
8098 		/* fallthrough */
8099 	case BuiltInPosition:
8100 	case BuiltInPointSize:
8101 	case BuiltInClipDistance:
8102 	case BuiltInCullDistance:
8103 	case BuiltInLayer:
8104 	case BuiltInFragDepth:
8105 	case BuiltInFragStencilRefEXT:
8106 	case BuiltInSampleMask:
8107 		if (get_execution_model() == ExecutionModelTessellationControl)
8108 			break;
8109 		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
8110 			return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
8111 
8112 		break;
8113 
8114 	case BuiltInBaryCoordNV:
8115 	case BuiltInBaryCoordNoPerspNV:
8116 		if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
8117 			return stage_in_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
8118 		break;
8119 
8120 	case BuiltInTessLevelOuter:
8121 		if (get_execution_model() == ExecutionModelTessellationEvaluation)
8122 		{
8123 			if (storage != StorageClassOutput && !get_entry_point().flags.get(ExecutionModeTriangles) &&
8124 			    current_function && (current_function->self == ir.default_entry_point))
8125 				return join(patch_stage_in_var_name, ".", CompilerGLSL::builtin_to_glsl(builtin, storage));
8126 			else
8127 				break;
8128 		}
8129 		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
8130 			return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
8131 			            "].edgeTessellationFactor");
8132 		break;
8133 
8134 	case BuiltInTessLevelInner:
8135 		if (get_execution_model() == ExecutionModelTessellationEvaluation)
8136 		{
8137 			if (storage != StorageClassOutput && !get_entry_point().flags.get(ExecutionModeTriangles) &&
8138 			    current_function && (current_function->self == ir.default_entry_point))
8139 				return join(patch_stage_in_var_name, ".", CompilerGLSL::builtin_to_glsl(builtin, storage));
8140 			else
8141 				break;
8142 		}
8143 		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
8144 			return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
8145 			            "].insideTessellationFactor");
8146 		break;
8147 
8148 	default:
8149 		break;
8150 	}
8151 
8152 	return CompilerGLSL::builtin_to_glsl(builtin, storage);
8153 }
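// For example, inside the entry point, BuiltInPosition with Output storage
// resolves to "out.gl_Position" (assuming the default stage-out name "out"),
// while the same builtin in a leaf function stays "gl_Position".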
8154 
8155 // Returns an MSL string attribute qualifier for a SPIR-V builtin
8156 string CompilerMSL::builtin_qualifier(BuiltIn builtin)
8157 {
8158 	auto &execution = get_entry_point();
8159 
8160 	switch (builtin)
8161 	{
8162 	// Vertex function in
8163 	case BuiltInVertexId:
8164 		return "vertex_id";
8165 	case BuiltInVertexIndex:
8166 		return "vertex_id";
8167 	case BuiltInBaseVertex:
8168 		return "base_vertex";
8169 	case BuiltInInstanceId:
8170 		return "instance_id";
8171 	case BuiltInInstanceIndex:
8172 		return "instance_id";
8173 	case BuiltInBaseInstance:
8174 		return "base_instance";
8175 	case BuiltInDrawIndex:
8176 		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
8177 
8178 	// Vertex function out
8179 	case BuiltInClipDistance:
8180 		return "clip_distance";
8181 	case BuiltInPointSize:
8182 		return "point_size";
8183 	case BuiltInPosition:
8184 		if (position_invariant)
8185 		{
8186 			if (!msl_options.supports_msl_version(2, 1))
8187 				SPIRV_CROSS_THROW("Invariant position is only supported on MSL 2.1 and up.");
8188 			return "position, invariant";
8189 		}
8190 		else
8191 			return "position";
8192 	case BuiltInLayer:
8193 		return "render_target_array_index";
8194 	case BuiltInViewportIndex:
8195 		if (!msl_options.supports_msl_version(2, 0))
8196 			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
8197 		return "viewport_array_index";
8198 
8199 	// Tess. control function in
8200 	case BuiltInInvocationId:
8201 		return "thread_index_in_threadgroup";
8202 	case BuiltInPatchVertices:
8203 		// Shouldn't be reached.
8204 		SPIRV_CROSS_THROW("PatchVertices is derived from the auxiliary buffer in MSL.");
8205 	case BuiltInPrimitiveId:
8206 		switch (execution.model)
8207 		{
8208 		case ExecutionModelTessellationControl:
8209 			return "threadgroup_position_in_grid";
8210 		case ExecutionModelTessellationEvaluation:
8211 			return "patch_id";
8212 		case ExecutionModelFragment:
8213 			if (msl_options.is_ios())
8214 				SPIRV_CROSS_THROW("PrimitiveId is not supported in fragment on iOS.");
8215 			else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 2))
8216 				SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2.");
8217 			return "primitive_id";
8218 		default:
8219 			SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model.");
8220 		}
8221 
8222 	// Tess. control function out
8223 	case BuiltInTessLevelOuter:
8224 	case BuiltInTessLevelInner:
8225 		// Shouldn't be reached.
8226 		SPIRV_CROSS_THROW("Tessellation levels are handled specially in MSL.");
8227 
8228 	// Tess. evaluation function in
8229 	case BuiltInTessCoord:
8230 		return "position_in_patch";
8231 
8232 	// Fragment function in
8233 	case BuiltInFrontFacing:
8234 		return "front_facing";
8235 	case BuiltInPointCoord:
8236 		return "point_coord";
8237 	case BuiltInFragCoord:
8238 		return "position";
8239 	case BuiltInSampleId:
8240 		return "sample_id";
8241 	case BuiltInSampleMask:
8242 		return "sample_mask";
8243 	case BuiltInSamplePosition:
8244 		// Shouldn't be reached.
8245 		SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL.");
8246 	case BuiltInViewIndex:
8247 		if (execution.model != ExecutionModelFragment)
8248 			SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders.");
8249 		// The ViewIndex was implicitly used in the prior stages to set the render_target_array_index,
8250 		// so we can get it from there.
8251 		return "render_target_array_index";
8252 
8253 	// Fragment function out
8254 	case BuiltInFragDepth:
8255 		if (execution.flags.get(ExecutionModeDepthGreater))
8256 			return "depth(greater)";
8257 		else if (execution.flags.get(ExecutionModeDepthLess))
8258 			return "depth(less)";
8259 		else
8260 			return "depth(any)";
8261 
8262 	case BuiltInFragStencilRefEXT:
8263 		return "stencil";
8264 
8265 	// Compute function in
8266 	case BuiltInGlobalInvocationId:
8267 		return "thread_position_in_grid";
8268 
8269 	case BuiltInWorkgroupId:
8270 		return "threadgroup_position_in_grid";
8271 
8272 	case BuiltInNumWorkgroups:
8273 		return "threadgroups_per_grid";
8274 
8275 	case BuiltInLocalInvocationId:
8276 		return "thread_position_in_threadgroup";
8277 
8278 	case BuiltInLocalInvocationIndex:
8279 		return "thread_index_in_threadgroup";
8280 
8281 	case BuiltInSubgroupSize:
8282 		if (execution.model == ExecutionModelFragment)
8283 		{
8284 			if (!msl_options.supports_msl_version(2, 2))
8285 				SPIRV_CROSS_THROW("threads_per_simdgroup requires Metal 2.2 in fragment shaders.");
8286 			return "threads_per_simdgroup";
8287 		}
8288 		else
8289 		{
8290 			// thread_execution_width is an alias for threads_per_simdgroup; it has been available since
8291 			// MSL 1.0, but it is not available in fragment functions.
8292 			return "thread_execution_width";
8293 		}
8294 
8295 	case BuiltInNumSubgroups:
8296 		if (!msl_options.supports_msl_version(2))
8297 			SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
8298 		return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
8299 
8300 	case BuiltInSubgroupId:
8301 		if (!msl_options.supports_msl_version(2))
8302 			SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
8303 		return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
8304 
8305 	case BuiltInSubgroupLocalInvocationId:
8306 		if (execution.model == ExecutionModelFragment)
8307 		{
8308 			if (!msl_options.supports_msl_version(2, 2))
8309 				SPIRV_CROSS_THROW("thread_index_in_simdgroup requires Metal 2.2 in fragment shaders.");
8310 			return "thread_index_in_simdgroup";
8311 		}
8312 		else
8313 		{
8314 			if (!msl_options.supports_msl_version(2))
8315 				SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
8316 			return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
8317 		}
8318 
8319 	case BuiltInSubgroupEqMask:
8320 	case BuiltInSubgroupGeMask:
8321 	case BuiltInSubgroupGtMask:
8322 	case BuiltInSubgroupLeMask:
8323 	case BuiltInSubgroupLtMask:
8324 		// Shouldn't be reached.
8325 		SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL.");
8326 
8327 	case BuiltInBaryCoordNV:
8328 		// TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
8329 		if (msl_options.is_ios())
8330 			SPIRV_CROSS_THROW("Barycentrics not supported on iOS.");
8331 		else if (!msl_options.supports_msl_version(2, 2))
8332 			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
8333 		return "barycentric_coord, center_perspective";
8334 
8335 	case BuiltInBaryCoordNoPerspNV:
8336 		// TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
8337 		if (msl_options.is_ios())
8338 			SPIRV_CROSS_THROW("Barycentrics not supported on iOS.");
8339 		else if (!msl_options.supports_msl_version(2, 2))
8340 			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
8341 		return "barycentric_coord, center_no_perspective";
8342 
8343 	default:
8344 		return "unsupported-built-in";
8345 	}
8346 }
8347 
8348 // Returns an MSL string type declaration for a SPIR-V builtin
8349 string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id)
8350 {
8351 	const SPIREntryPoint &execution = get_entry_point();
8352 	switch (builtin)
8353 	{
8354 	// Vertex function in
8355 	case BuiltInVertexId:
8356 		return "uint";
8357 	case BuiltInVertexIndex:
8358 		return "uint";
8359 	case BuiltInBaseVertex:
8360 		return "uint";
8361 	case BuiltInInstanceId:
8362 		return "uint";
8363 	case BuiltInInstanceIndex:
8364 		return "uint";
8365 	case BuiltInBaseInstance:
8366 		return "uint";
8367 	case BuiltInDrawIndex:
8368 		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
8369 
8370 	// Vertex function out
8371 	case BuiltInClipDistance:
8372 		return "float";
8373 	case BuiltInPointSize:
8374 		return "float";
8375 	case BuiltInPosition:
8376 		return "float4";
8377 	case BuiltInLayer:
8378 		return "uint";
8379 	case BuiltInViewportIndex:
8380 		if (!msl_options.supports_msl_version(2, 0))
8381 			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
8382 		return "uint";
8383 
8384 	// Tess. control function in
8385 	case BuiltInInvocationId:
8386 		return "uint";
8387 	case BuiltInPatchVertices:
8388 		return "uint";
8389 	case BuiltInPrimitiveId:
8390 		return "uint";
8391 
8392 	// Tess. control function out
8393 	case BuiltInTessLevelInner:
8394 		if (execution.model == ExecutionModelTessellationEvaluation)
8395 			return !execution.flags.get(ExecutionModeTriangles) ? "float2" : "float";
8396 		return "half";
8397 	case BuiltInTessLevelOuter:
8398 		if (execution.model == ExecutionModelTessellationEvaluation)
8399 			return !execution.flags.get(ExecutionModeTriangles) ? "float4" : "float";
8400 		return "half";
8401 
8402 	// Tess. evaluation function in
8403 	case BuiltInTessCoord:
8404 		return execution.flags.get(ExecutionModeTriangles) ? "float3" : "float2";
8405 
8406 	// Fragment function in
8407 	case BuiltInFrontFacing:
8408 		return "bool";
8409 	case BuiltInPointCoord:
8410 		return "float2";
8411 	case BuiltInFragCoord:
8412 		return "float4";
8413 	case BuiltInSampleId:
8414 		return "uint";
8415 	case BuiltInSampleMask:
8416 		return "uint";
8417 	case BuiltInSamplePosition:
8418 		return "float2";
8419 	case BuiltInViewIndex:
8420 		return "uint";
8421 
8422 	// Fragment function out
8423 	case BuiltInFragDepth:
8424 		return "float";
8425 
8426 	case BuiltInFragStencilRefEXT:
8427 		return "uint";
8428 
8429 	// Compute function in
8430 	case BuiltInGlobalInvocationId:
8431 	case BuiltInLocalInvocationId:
8432 	case BuiltInNumWorkgroups:
8433 	case BuiltInWorkgroupId:
8434 		return "uint3";
8435 	case BuiltInLocalInvocationIndex:
8436 	case BuiltInNumSubgroups:
8437 	case BuiltInSubgroupId:
8438 	case BuiltInSubgroupSize:
8439 	case BuiltInSubgroupLocalInvocationId:
8440 		return "uint";
8441 	case BuiltInSubgroupEqMask:
8442 	case BuiltInSubgroupGeMask:
8443 	case BuiltInSubgroupGtMask:
8444 	case BuiltInSubgroupLeMask:
8445 	case BuiltInSubgroupLtMask:
8446 		return "uint4";
8447 
8448 	case BuiltInHelperInvocation:
8449 		return "bool";
8450 
8451 	case BuiltInBaryCoordNV:
8452 	case BuiltInBaryCoordNoPerspNV:
8453 		// Use the type as declared, can be 1, 2 or 3 components.
8454 		return type_to_glsl(get_variable_data_type(get<SPIRVariable>(id)));
8455 
8456 	default:
8457 		return "unsupported-built-in-type";
8458 	}
8459 }
8460 
8461 // Returns the declaration of a built-in argument to a function
8462 string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma)
8463 {
8464 	string bi_arg;
8465 	if (prefix_comma)
8466 		bi_arg += ", ";
8467 
8468 	bi_arg += builtin_type_decl(builtin);
8469 	bi_arg += " " + builtin_to_glsl(builtin, StorageClassInput);
8470 	bi_arg += " [[" + builtin_qualifier(builtin) + "]]";
8471 
8472 	return bi_arg;
8473 }
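// For example, built_in_func_arg(BuiltInSampleId, true) yields
// ", uint gl_SampleID [[sample_id]]", ready to append to a function signature.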
8474 
8475 // Returns the byte size of a struct member.
8476 size_t CompilerMSL::get_declared_struct_member_size_msl(const SPIRType &struct_type, uint32_t index) const
8477 {
8478 	auto &type = get<SPIRType>(struct_type.member_types[index]);
8479 
8480 	switch (type.basetype)
8481 	{
8482 	case SPIRType::Unknown:
8483 	case SPIRType::Void:
8484 	case SPIRType::AtomicCounter:
8485 	case SPIRType::Image:
8486 	case SPIRType::SampledImage:
8487 	case SPIRType::Sampler:
8488 		SPIRV_CROSS_THROW("Querying size of opaque object.");
8489 
8490 	default:
8491 	{
8492 		// For arrays, we can use the ArrayStride decoration to compute the size directly.
8493 		// Runtime arrays will have a zero size, so force a minimum of one.
8494 		if (!type.array.empty())
8495 		{
8496 			uint32_t array_size = to_array_size_literal(type);
8497 #ifdef RARCH_INTERNAL
8498 			return type_struct_member_array_stride(struct_type, index) * MAX(array_size, 1u);
8499 #else
8500 			return type_struct_member_array_stride(struct_type, index) * max(array_size, 1u);
8501 #endif
8502 		}
8503 
8504 		if (type.basetype == SPIRType::Struct)
8505 		{
8506 			// The size of a struct in Metal is aligned up to its natural alignment.
8507 			auto size = get_declared_struct_size(type);
8508 			auto alignment = get_declared_struct_member_alignment(struct_type, index);
8509 			return (size + alignment - 1) & ~(alignment - 1);
8510 		}
8511 
8512 		uint32_t component_size = type.width / 8;
8513 		uint32_t vecsize = type.vecsize;
8514 		uint32_t columns = type.columns;
8515 
8516 		// An unpacked 3-element vector or matrix column is the same memory size as a 4-element.
8517 		if (vecsize == 3 && !has_extended_member_decoration(struct_type.self, index, SPIRVCrossDecorationPacked))
8518 			vecsize = 4;
8519 
8520 		return component_size * vecsize * columns;
8521 	}
8522 	}
8523 }
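// For example, an unpacked float3 member occupies 16 bytes (padded to float4
// size), while a member decorated as packed occupies only 12.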
8524 
8525 // Returns the byte alignment of a struct member.
8526 size_t CompilerMSL::get_declared_struct_member_alignment(const SPIRType &struct_type, uint32_t index) const
8527 {
8528 	auto &type = get<SPIRType>(struct_type.member_types[index]);
8529 
8530 	switch (type.basetype)
8531 	{
8532 	case SPIRType::Unknown:
8533 	case SPIRType::Void:
8534 	case SPIRType::AtomicCounter:
8535 	case SPIRType::Image:
8536 	case SPIRType::SampledImage:
8537 	case SPIRType::Sampler:
8538 		SPIRV_CROSS_THROW("Querying alignment of opaque object.");
8539 
8540 	case SPIRType::Int64:
8541 		SPIRV_CROSS_THROW("long types are not supported in buffers in MSL.");
8542 	case SPIRType::UInt64:
8543 		SPIRV_CROSS_THROW("ulong types are not supported in buffers in MSL.");
8544 	case SPIRType::Double:
8545 		SPIRV_CROSS_THROW("double types are not supported in buffers in MSL.");
8546 
8547 	case SPIRType::Struct:
8548 	{
8549 		// In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
8550 		uint32_t alignment = 1;
8551 		for (uint32_t i = 0; i < type.member_types.size(); i++)
8552 			alignment = max(alignment, uint32_t(get_declared_struct_member_alignment(type, i)));
8553 		return alignment;
8554 	}
8555 
8556 	default:
8557 	{
8558 		// Alignment of packed type is the same as the underlying component or column size.
8559 		// Alignment of unpacked type is the same as the vector size.
8560 		// Alignment of a 3-element vector is the same as a 4-element vector (including packed matrix columns).
8561 		if (member_is_packed_type(struct_type, index))
8562 		{
8563 			// This is getting pretty complicated.
8564 			// The special case of array of float/float2 needs to be handled here.
8565 			uint32_t packed_type_id =
8566 			    get_extended_member_decoration(struct_type.self, index, SPIRVCrossDecorationPackedType);
8567 			const SPIRType *packed_type = packed_type_id != 0 ? &get<SPIRType>(packed_type_id) : nullptr;
8568 			if (packed_type && is_array(*packed_type) && !is_matrix(*packed_type) &&
8569 			    packed_type->basetype != SPIRType::Struct)
8570 			{
8571 				uint32_t stride = type_struct_member_array_stride(struct_type, index);
8572 				if (stride == (packed_type->width / 8) * 4)
8573 					return stride;
8574 				else
8575 					return packed_type->width / 8;
8576 			}
8577 			else
8578 				return type.width / 8;
8579 		}
8580 		else
8581 			return (type.width / 8) * (type.vecsize == 3 ? 4 : type.vecsize);
8582 	}
8583 	}
8584 }
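
// For illustration, the rules above yield the usual Metal layout results:
//   float         -> size 4,  alignment 4
//   float3        -> size 16, alignment 16 (unpacked 3-element vectors round up to 4 elements)
//   packed_float3 -> size 12, alignment 4  (packed vectors align to their component size)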

bool CompilerMSL::skip_argument(uint32_t) const
{
	return false;
}

void CompilerMSL::analyze_sampled_image_usage()
{
	if (msl_options.swizzle_texture_samples)
	{
		SampledImageScanner scanner(*this);
		traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), scanner);
	}
}

bool CompilerMSL::SampledImageScanner::handle(spv::Op opcode, const uint32_t *args, uint32_t length)
{
	switch (opcode)
	{
	case OpLoad:
	case OpImage:
	case OpSampledImage:
	{
		if (length < 3)
			return false;

		uint32_t result_type = args[0];
		auto &type = compiler.get<SPIRType>(result_type);
		if ((type.basetype != SPIRType::Image && type.basetype != SPIRType::SampledImage) || type.image.sampled != 1)
			return true;

		uint32_t id = args[1];
		compiler.set<SPIRExpression>(id, "", result_type, true);
		break;
	}
	case OpImageSampleExplicitLod:
	case OpImageSampleProjExplicitLod:
	case OpImageSampleDrefExplicitLod:
	case OpImageSampleProjDrefExplicitLod:
	case OpImageSampleImplicitLod:
	case OpImageSampleProjImplicitLod:
	case OpImageSampleDrefImplicitLod:
	case OpImageSampleProjDrefImplicitLod:
	case OpImageFetch:
	case OpImageGather:
	case OpImageDrefGather:
		compiler.has_sampled_images =
		    compiler.has_sampled_images || compiler.is_sampled_image_type(compiler.expression_type(args[2]));
		compiler.needs_swizzle_buffer_def = compiler.needs_swizzle_buffer_def || compiler.has_sampled_images;
		break;
	default:
		break;
	}
	return true;
}
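
// Note (editorial, for context): when msl_options.swizzle_texture_samples is enabled, the
// scanner above is what triggers emission of the swizzle buffer (spvSwizzleConstants, see
// analyze_argument_buffers below) that the generated texture-swizzle helper functions read
// in order to emulate descriptor swizzles on Metal.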

bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, uint32_t length)
{
	// Since MSL exists in a single execution scope, function prototype declarations are not
	// needed, and clutter the output. If secondary functions are output (either as a SPIR-V
	// function implementation or as indicated by the presence of OpFunctionCall), then set
	// suppress_missing_prototypes to suppress compiler warnings of missing function prototypes.

	// Mark if the input requires the implementation of a SPIR-V function that does not exist in Metal.
	SPVFuncImpl spv_func = get_spv_func_impl(opcode, args);
	if (spv_func != SPVFuncImplNone)
	{
		compiler.spv_function_implementations.insert(spv_func);
		suppress_missing_prototypes = true;
	}

	switch (opcode)
	{

	case OpFunctionCall:
		suppress_missing_prototypes = true;
		break;

	case OpImageWrite:
		uses_resource_write = true;
		break;

	case OpStore:
		check_resource_write(args[0]);
		break;

	case OpAtomicExchange:
	case OpAtomicCompareExchange:
	case OpAtomicCompareExchangeWeak:
	case OpAtomicIIncrement:
	case OpAtomicIDecrement:
	case OpAtomicIAdd:
	case OpAtomicISub:
	case OpAtomicSMin:
	case OpAtomicUMin:
	case OpAtomicSMax:
	case OpAtomicUMax:
	case OpAtomicAnd:
	case OpAtomicOr:
	case OpAtomicXor:
		uses_atomics = true;
		check_resource_write(args[2]);
		break;

	case OpAtomicLoad:
		uses_atomics = true;
		break;

	case OpGroupNonUniformInverseBallot:
		needs_subgroup_invocation_id = true;
		break;

	case OpGroupNonUniformBallotBitCount:
		if (args[3] != GroupOperationReduce)
			needs_subgroup_invocation_id = true;
		break;

	case OpArrayLength:
	{
		auto *var = compiler.maybe_get_backing_variable(args[2]);
		if (var)
			compiler.buffers_requiring_array_length.insert(var->self);
		break;
	}

	case OpInBoundsAccessChain:
	case OpAccessChain:
	case OpPtrAccessChain:
	{
		// OpArrayLength might want to know if taking ArrayLength of an array of SSBOs.
		uint32_t result_type = args[0];
		uint32_t id = args[1];
		uint32_t ptr = args[2];
		compiler.set<SPIRExpression>(id, "", result_type, true);
		compiler.register_read(id, ptr, true);
		compiler.ir.ids[id].set_allow_type_rewrite();
		break;
	}

	default:
		break;
	}

	// If it has one, keep track of the instruction's result type, mapped by ID.
	uint32_t result_type, result_id;
	if (compiler.instruction_to_result_type(result_type, result_id, opcode, args, length))
		result_types[result_id] = result_type;

	return true;
}

// If the variable is a Uniform or StorageBuffer, mark that a resource has been written to.
void CompilerMSL::OpCodePreprocessor::check_resource_write(uint32_t var_id)
{
	auto *p_var = compiler.maybe_get_backing_variable(var_id);
	StorageClass sc = p_var ? p_var->storage : StorageClassMax;
	if (sc == StorageClassUniform || sc == StorageClassStorageBuffer)
		uses_resource_write = true;
}

// Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes.
CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op opcode, const uint32_t *args)
{
	switch (opcode)
	{
	case OpFMod:
		return SPVFuncImplMod;

	case OpFunctionCall:
	{
		auto &return_type = compiler.get<SPIRType>(args[0]);
		if (return_type.array.size() > 1)
		{
			if (return_type.array.size() > SPVFuncImplArrayCopyMultidimMax)
				SPIRV_CROSS_THROW("Cannot support this many dimensions for arrays of arrays.");
			return static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + return_type.array.size());
		}
		else if (return_type.array.size() > 0)
			return SPVFuncImplArrayCopy;

		break;
	}

	case OpStore:
	{
		// Get the result type of the RHS. Since this is run as a pre-processing stage,
		// we must extract the result type directly from the Instruction, rather than the ID.
		uint32_t id_lhs = args[0];
		uint32_t id_rhs = args[1];

		const SPIRType *type = nullptr;
		if (compiler.ir.ids[id_rhs].get_type() != TypeNone)
		{
			// Could be a constant, or similar.
			type = &compiler.expression_type(id_rhs);
		}
		else
		{
			// Or ... an expression.
			uint32_t tid = result_types[id_rhs];
			if (tid)
				type = &compiler.get<SPIRType>(tid);
		}

		auto *var = compiler.maybe_get<SPIRVariable>(id_lhs);

		// Are we simply assigning to a statically assigned variable which takes a constant?
		// Don't bother emitting this function.
		bool static_expression_lhs =
		    var && var->storage == StorageClassFunction && var->statically_assigned && var->remapped_variable;
		if (type && compiler.is_array(*type) && !static_expression_lhs)
		{
			if (type->array.size() > 1)
			{
				if (type->array.size() > SPVFuncImplArrayCopyMultidimMax)
					SPIRV_CROSS_THROW("Cannot support this many dimensions for arrays of arrays.");
				return static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + type->array.size());
			}
			else
				return SPVFuncImplArrayCopy;
		}

		break;
	}

	case OpImageFetch:
	case OpImageRead:
	case OpImageWrite:
	{
		// Retrieve the image type, and if it's a Buffer, emit a texel coordinate function.
		uint32_t tid = result_types[args[opcode == OpImageWrite ? 0 : 2]];
		if (tid && compiler.get<SPIRType>(tid).image.dim == DimBuffer && !compiler.msl_options.texture_buffer_native)
			return SPVFuncImplTexelBufferCoords;

		if (opcode == OpImageFetch && compiler.msl_options.swizzle_texture_samples)
			return SPVFuncImplTextureSwizzle;

		break;
	}

	case OpImageSampleExplicitLod:
	case OpImageSampleProjExplicitLod:
	case OpImageSampleDrefExplicitLod:
	case OpImageSampleProjDrefExplicitLod:
	case OpImageSampleImplicitLod:
	case OpImageSampleProjImplicitLod:
	case OpImageSampleDrefImplicitLod:
	case OpImageSampleProjDrefImplicitLod:
	case OpImageGather:
	case OpImageDrefGather:
		if (compiler.msl_options.swizzle_texture_samples)
			return SPVFuncImplTextureSwizzle;
		break;

	case OpExtInst:
	{
		uint32_t extension_set = args[2];
		if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
		{
			auto op_450 = static_cast<GLSLstd450>(args[3]);
			switch (op_450)
			{
			case GLSLstd450Radians:
				return SPVFuncImplRadians;
			case GLSLstd450Degrees:
				return SPVFuncImplDegrees;
			case GLSLstd450FindILsb:
				return SPVFuncImplFindILsb;
			case GLSLstd450FindSMsb:
				return SPVFuncImplFindSMsb;
			case GLSLstd450FindUMsb:
				return SPVFuncImplFindUMsb;
			case GLSLstd450SSign:
				return SPVFuncImplSSign;
			case GLSLstd450Reflect:
			{
				auto &type = compiler.get<SPIRType>(args[0]);
				if (type.vecsize == 1)
					return SPVFuncImplReflectScalar;
				else
					return SPVFuncImplNone;
			}
			case GLSLstd450Refract:
			{
				auto &type = compiler.get<SPIRType>(args[0]);
				if (type.vecsize == 1)
					return SPVFuncImplRefractScalar;
				else
					return SPVFuncImplNone;
			}
			case GLSLstd450MatrixInverse:
			{
				auto &mat_type = compiler.get<SPIRType>(args[0]);
				switch (mat_type.columns)
				{
				case 2:
					return SPVFuncImplInverse2x2;
				case 3:
					return SPVFuncImplInverse3x3;
				case 4:
					return SPVFuncImplInverse4x4;
				default:
					break;
				}
				break;
			}
			default:
				break;
			}
		}
		break;
	}

	case OpGroupNonUniformBallot:
		return SPVFuncImplSubgroupBallot;

	case OpGroupNonUniformInverseBallot:
	case OpGroupNonUniformBallotBitExtract:
		return SPVFuncImplSubgroupBallotBitExtract;

	case OpGroupNonUniformBallotFindLSB:
		return SPVFuncImplSubgroupBallotFindLSB;

	case OpGroupNonUniformBallotFindMSB:
		return SPVFuncImplSubgroupBallotFindMSB;

	case OpGroupNonUniformBallotBitCount:
		return SPVFuncImplSubgroupBallotBitCount;

	case OpGroupNonUniformAllEqual:
		return SPVFuncImplSubgroupAllEqual;

	default:
		break;
	}
	return SPVFuncImplNone;
}
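
// For illustration: MSL cannot assign arrays by value, so a store of e.g. float[3]
// selected above as SPVFuncImplArrayCopy is later emitted as a call to a generated copy
// helper along the lines of (a sketch, not the verbatim emitted code):
//
//   template<typename T, uint N>
//   void spvArrayCopy(thread T (&dst)[N], thread const T (&src)[N])
//   {
//       for (uint i = 0; i < N; i++)
//           dst[i] = src[i]; // element-wise copy stands in for array assignment
//   }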

// Sort both type and meta member content based on builtin status (put builtins at end),
// then by the required sorting aspect.
void CompilerMSL::MemberSorter::sort()
{
	// Create a temporary array of consecutive member indices and sort it based on how
	// the members should be reordered, based on builtin and sorting aspect meta info.
	size_t mbr_cnt = type.member_types.size();
	SmallVector<uint32_t> mbr_idxs(mbr_cnt);
	iota(mbr_idxs.begin(), mbr_idxs.end(), 0); // Fill with consecutive indices
	std::sort(mbr_idxs.begin(), mbr_idxs.end(), *this); // Sort member indices based on sorting aspect

	// Move type and meta member info to the order defined by the sorted member indices.
	// This is done by creating temporary copies of both member types and meta, and then
	// copying back to the original content at the sorted indices.
	auto mbr_types_cpy = type.member_types;
	auto mbr_meta_cpy = meta.members;
	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
	{
		type.member_types[mbr_idx] = mbr_types_cpy[mbr_idxs[mbr_idx]];
		meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]];
	}
}

// Sort first by builtin status (put builtins at end), then by the sorting aspect.
bool CompilerMSL::MemberSorter::operator()(uint32_t mbr_idx1, uint32_t mbr_idx2)
{
	auto &mbr_meta1 = meta.members[mbr_idx1];
	auto &mbr_meta2 = meta.members[mbr_idx2];
	if (mbr_meta1.builtin != mbr_meta2.builtin)
		return mbr_meta2.builtin;
	else
		switch (sort_aspect)
		{
		case Location:
			return mbr_meta1.location < mbr_meta2.location;
		case LocationReverse:
			return mbr_meta1.location > mbr_meta2.location;
		case Offset:
			return mbr_meta1.offset < mbr_meta2.offset;
		case OffsetThenLocationReverse:
			return (mbr_meta1.offset < mbr_meta2.offset) ||
			       ((mbr_meta1.offset == mbr_meta2.offset) && (mbr_meta1.location > mbr_meta2.location));
		case Alphabetical:
			return mbr_meta1.alias < mbr_meta2.alias;
		default:
			return false;
		}
}

CompilerMSL::MemberSorter::MemberSorter(SPIRType &t, Meta &m, SortAspect sa)
    : type(t)
    , meta(m)
    , sort_aspect(sa)
{
	// Ensure enough meta info is available
	meta.members.resize(max(type.member_types.size(), meta.members.size()));
}
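
// Typical use (illustrative sketch; `ib_type` is a hypothetical interface-block type):
// reorder members by byte offset while keeping builtins at the end.
//
//   MemberSorter sorter(ib_type, ir.meta[ib_type.self], MemberSorter::Offset);
//   sorter.sort();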

void CompilerMSL::remap_constexpr_sampler(uint32_t id, const MSLConstexprSampler &sampler)
{
	auto &type = get<SPIRType>(get<SPIRVariable>(id).basetype);
	if (type.basetype != SPIRType::SampledImage && type.basetype != SPIRType::Sampler)
		SPIRV_CROSS_THROW("Can only remap SampledImage and Sampler type.");
	if (!type.array.empty())
		SPIRV_CROSS_THROW("Can not remap array of samplers.");
	constexpr_samplers_by_id[id] = sampler;
}

void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t binding,
                                                     const MSLConstexprSampler &sampler)
{
	constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
}
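
// Example API usage (a minimal sketch; the MSLConstexprSampler fields shown are not
// exhaustive and rely on their declared defaults for everything else):
//
//   MSLConstexprSampler s;
//   s.min_filter = MSL_SAMPLER_FILTER_LINEAR;
//   s.mag_filter = MSL_SAMPLER_FILTER_LINEAR;
//   s.s_address = MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE;
//   s.t_address = MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE;
//   compiler.remap_constexpr_sampler_by_binding(0, 1, s); // desc set 0, binding 1
//
// The sampler is then emitted as a constexpr sampler in the MSL source instead of being
// passed in as a resource argument.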

void CompilerMSL::bitcast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
{
	auto *var = maybe_get_backing_variable(source_id);
	if (var)
		source_id = var->self;

	// Only interested in standalone builtin variables.
	if (!has_decoration(source_id, DecorationBuiltIn))
		return;

	auto builtin = static_cast<BuiltIn>(get_decoration(source_id, DecorationBuiltIn));
	auto expected_type = expr_type.basetype;
	switch (builtin)
	{
	case BuiltInGlobalInvocationId:
	case BuiltInLocalInvocationId:
	case BuiltInWorkgroupId:
	case BuiltInLocalInvocationIndex:
	case BuiltInWorkgroupSize:
	case BuiltInNumWorkgroups:
	case BuiltInLayer:
	case BuiltInViewportIndex:
	case BuiltInFragStencilRefEXT:
	case BuiltInPrimitiveId:
	case BuiltInSubgroupSize:
	case BuiltInSubgroupLocalInvocationId:
	case BuiltInViewIndex:
		expected_type = SPIRType::UInt;
		break;

	case BuiltInTessLevelInner:
	case BuiltInTessLevelOuter:
		if (get_execution_model() == ExecutionModelTessellationControl)
			expected_type = SPIRType::Half;
		break;

	default:
		break;
	}

	if (expected_type != expr_type.basetype)
		expr = bitcast_expression(expr_type, expected_type, expr);

	if (builtin == BuiltInTessCoord && get_entry_point().flags.get(ExecutionModeQuads) && expr_type.vecsize == 3)
	{
		// In SPIR-V, this is always a vec3, even for quads. In Metal, though, it's a float2 for quads.
		// The code is expecting a float3, so we need to widen this.
		expr = join("float3(", expr, ", 0)");
	}
}
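
// For illustration: if a shader loads gl_GlobalInvocationID with a signed result type,
// the uint-typed Metal builtin must be reinterpreted on load, yielding an expression
// along the lines of: as_type<int3>(gl_GlobalInvocationID).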

void CompilerMSL::bitcast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
{
	auto *var = maybe_get_backing_variable(target_id);
	if (var)
		target_id = var->self;

	// Only interested in standalone builtin variables.
	if (!has_decoration(target_id, DecorationBuiltIn))
		return;

	auto builtin = static_cast<BuiltIn>(get_decoration(target_id, DecorationBuiltIn));
	auto expected_type = expr_type.basetype;
	switch (builtin)
	{
	case BuiltInLayer:
	case BuiltInViewportIndex:
	case BuiltInFragStencilRefEXT:
	case BuiltInPrimitiveId:
	case BuiltInViewIndex:
		expected_type = SPIRType::UInt;
		break;

	case BuiltInTessLevelInner:
	case BuiltInTessLevelOuter:
		expected_type = SPIRType::Half;
		break;

	default:
		break;
	}

	if (expected_type != expr_type.basetype)
	{
		if (expected_type == SPIRType::Half && expr_type.basetype == SPIRType::Float)
		{
			// These are of different widths, so we cannot do a straight bitcast.
			expr = join("half(", expr, ")");
		}
		else
		{
			auto type = expr_type;
			type.basetype = expected_type;
			expr = bitcast_expression(type, expr_type.basetype, expr);
		}
	}
}
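
// For illustration: storing the signed gl_Layer to Metal's uint-typed
// [[render_target_array_index]] becomes a bitcast such as as_type<uint>(layer), whereas
// float tess levels narrow to half via a value conversion, half(level), since a
// 32-bit-to-16-bit bitcast would be invalid.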

std::string CompilerMSL::to_initializer_expression(const SPIRVariable &var)
{
	// With MSL, we risk getting an array initializer here if the variable is an array.
	// FIXME: We cannot handle non-constant arrays being initialized.
	// We will need to inject spvArrayCopy here somehow ...
	auto &type = get<SPIRType>(var.basetype);
	if (ir.ids[var.initializer].get_type() == TypeConstant &&
	    (!type.array.empty() || type.basetype == SPIRType::Struct))
		return constant_expression(get<SPIRConstant>(var.initializer));
	else
		return CompilerGLSL::to_initializer_expression(var);
}

bool CompilerMSL::descriptor_set_is_argument_buffer(uint32_t desc_set) const
{
	if (!msl_options.argument_buffers)
		return false;
	if (desc_set >= kMaxArgumentBuffers)
		return false;

	return (argument_buffer_discrete_mask & (1u << desc_set)) == 0;
}
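
// Example (illustrative): with msl_options.argument_buffers enabled, every descriptor set
// below kMaxArgumentBuffers becomes an argument buffer unless explicitly opted out:
//
//   compiler.add_discrete_descriptor_set(0); // set 0 keeps discrete resource bindings
//   // sets not marked discrete are emitted as argument buffers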

void CompilerMSL::analyze_argument_buffers()
{
	// Gather all used resources and sort them out into argument buffers.
	// Each argument buffer corresponds to a descriptor set in SPIR-V.
	// The [[id(N)]] values used correspond to the resource mapping we have for MSL.
	// Otherwise, the binding number is used, but this is generally not safe for some types like
	// combined image samplers and arrays of resources. Metal needs different indices here,
	// while SPIR-V can have one descriptor set binding. To use argument buffers in practice,
	// you will need to use the remapping from the API.
	for (auto &id : argument_buffer_ids)
		id = 0;

	// Output resources, sorted by resource index & type.
	struct Resource
	{
		SPIRVariable *var;
		string name;
		SPIRType::BaseType basetype;
		uint32_t index;
	};
	SmallVector<Resource> resources_in_set[kMaxArgumentBuffers];

	bool set_needs_swizzle_buffer[kMaxArgumentBuffers] = {};
	bool set_needs_buffer_sizes[kMaxArgumentBuffers] = {};
	bool needs_buffer_sizes = false;

	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &var) {
		if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
		     var.storage == StorageClassStorageBuffer) &&
		    !is_hidden_variable(var))
		{
			uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
			// Ignore if it's part of a push descriptor set.
			if (!descriptor_set_is_argument_buffer(desc_set))
				return;

			uint32_t var_id = var.self;
			auto &type = get_variable_data_type(var);

			if (desc_set >= kMaxArgumentBuffers)
				SPIRV_CROSS_THROW("Descriptor set index is out of range.");

			const MSLConstexprSampler *constexpr_sampler = nullptr;
			if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
			{
				constexpr_sampler = find_constexpr_sampler(var_id);
				if (constexpr_sampler)
				{
					// Mark this ID as a constexpr sampler for later in case it came from set/bindings.
					constexpr_samplers_by_id[var_id] = *constexpr_sampler;
				}
			}

			if (type.basetype == SPIRType::SampledImage)
			{
				add_resource_name(var_id);

				uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image);
				uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler);

				resources_in_set[desc_set].push_back({ &var, to_name(var_id), SPIRType::Image, image_resource_index });

				if (type.image.dim != DimBuffer && !constexpr_sampler)
				{
					resources_in_set[desc_set].push_back(
					    { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index });
				}
			}
			else if (!constexpr_sampler)
			{
				// constexpr samplers are not declared as resources.
				add_resource_name(var_id);
				resources_in_set[desc_set].push_back(
				    { &var, to_name(var_id), type.basetype, get_metal_resource_index(var, type.basetype) });
			}

			// Check if this descriptor set needs a swizzle buffer.
			if (needs_swizzle_buffer_def && is_sampled_image_type(type))
				set_needs_swizzle_buffer[desc_set] = true;
			else if (buffers_requiring_array_length.count(var_id) != 0)
			{
				set_needs_buffer_sizes[desc_set] = true;
				needs_buffer_sizes = true;
			}
		}
	});

	if (needs_swizzle_buffer_def || needs_buffer_sizes)
	{
		uint32_t uint_ptr_type_id = 0;

		// We might have to add a swizzle buffer resource to the set.
		for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
		{
			if (!set_needs_swizzle_buffer[desc_set] && !set_needs_buffer_sizes[desc_set])
				continue;

			if (uint_ptr_type_id == 0)
			{
				uint32_t offset = ir.increase_bound_by(2);
				uint32_t type_id = offset;
				uint_ptr_type_id = offset + 1;

				// Create a buffer to hold extra data, including the swizzle constants.
				SPIRType uint_type;
				uint_type.basetype = SPIRType::UInt;
				uint_type.width = 32;
				set<SPIRType>(type_id, uint_type);

				SPIRType uint_type_pointer = uint_type;
				uint_type_pointer.pointer = true;
				uint_type_pointer.pointer_depth = 1;
				uint_type_pointer.parent_type = type_id;
				uint_type_pointer.storage = StorageClassUniform;
				set<SPIRType>(uint_ptr_type_id, uint_type_pointer);
				set_decoration(uint_ptr_type_id, DecorationArrayStride, 4);
			}

			if (set_needs_swizzle_buffer[desc_set])
			{
				uint32_t var_id = ir.increase_bound_by(1);
				auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
				set_name(var_id, "spvSwizzleConstants");
				set_decoration(var_id, DecorationDescriptorSet, desc_set);
				set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding);
				resources_in_set[desc_set].push_back(
				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt) });
			}

			if (set_needs_buffer_sizes[desc_set])
			{
				uint32_t var_id = ir.increase_bound_by(1);
				auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
				set_name(var_id, "spvBufferSizeConstants");
				set_decoration(var_id, DecorationDescriptorSet, desc_set);
				set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding);
				resources_in_set[desc_set].push_back(
				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt) });
			}
		}
	}

	for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
	{
		auto &resources = resources_in_set[desc_set];
		if (resources.empty())
			continue;

		assert(descriptor_set_is_argument_buffer(desc_set));

		uint32_t next_id = ir.increase_bound_by(3);
		uint32_t type_id = next_id + 1;
		uint32_t ptr_type_id = next_id + 2;
		argument_buffer_ids[desc_set] = next_id;

		auto &buffer_type = set<SPIRType>(type_id);
		buffer_type.storage = StorageClassUniform;
		buffer_type.basetype = SPIRType::Struct;
		set_name(type_id, join("spvDescriptorSetBuffer", desc_set));

		auto &ptr_type = set<SPIRType>(ptr_type_id);
		ptr_type = buffer_type;
		ptr_type.pointer = true;
		ptr_type.pointer_depth = 1;
		ptr_type.parent_type = type_id;

		uint32_t buffer_variable_id = next_id;
		set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
		set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));

		// IDs must be emitted in ID order.
		sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
			return tie(lhs.index, lhs.basetype) < tie(rhs.index, rhs.basetype);
		});

		uint32_t member_index = 0;
		for (auto &resource : resources)
		{
			auto &var = *resource.var;
			auto &type = get_variable_data_type(var);
			string mbr_name = ensure_valid_name(resource.name, "m");
			set_member_name(buffer_type.self, member_index, mbr_name);

			if (resource.basetype == SPIRType::Sampler && type.basetype != SPIRType::Sampler)
			{
				// Have to synthesize a sampler type here.

				bool type_is_array = !type.array.empty();
				uint32_t sampler_type_id = ir.increase_bound_by(type_is_array ? 2 : 1);
				auto &new_sampler_type = set<SPIRType>(sampler_type_id);
				new_sampler_type.basetype = SPIRType::Sampler;
				new_sampler_type.storage = StorageClassUniformConstant;

				if (type_is_array)
				{
					uint32_t sampler_type_array_id = sampler_type_id + 1;
					auto &sampler_type_array = set<SPIRType>(sampler_type_array_id);
					sampler_type_array = new_sampler_type;
					sampler_type_array.array = type.array;
					sampler_type_array.array_size_literal = type.array_size_literal;
					sampler_type_array.parent_type = sampler_type_id;
					buffer_type.member_types.push_back(sampler_type_array_id);
				}
				else
					buffer_type.member_types.push_back(sampler_type_id);
			}
			else
			{
				if (resource.basetype == SPIRType::Image || resource.basetype == SPIRType::Sampler ||
				    resource.basetype == SPIRType::SampledImage)
				{
					// Drop pointer information when we emit the resources into a struct.
					buffer_type.member_types.push_back(get_variable_data_type_id(var));
					set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
				}
				else
				{
					// Resources will be declared as pointers not references, so automatically dereference as appropriate.
					buffer_type.member_types.push_back(var.basetype);
					if (type.array.empty())
						set_qualified_name(var.self, join("(*", to_name(buffer_variable_id), ".", mbr_name, ")"));
					else
						set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
				}
			}

			set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationResourceIndexPrimary,
			                               resource.index);
			set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationInterfaceOrigID,
			                               var.self);
			member_index++;
		}
	}
}
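
// For illustration, a set containing a texture, a sampler and a UBO might end up emitted
// as something like this (a sketch; member names and [[id(N)]] values depend on the
// resource remapping supplied through the API):
//
//   struct spvDescriptorSetBuffer0
//   {
//       texture2d<float> uTexture [[id(0)]];
//       sampler uTextureSmplr [[id(1)]];
//       constant UBO* uUBO [[id(2)]];
//   };
//
// with the entry point taking: constant spvDescriptorSetBuffer0& spvDescriptorSet0 [[buffer(0)]].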

bool CompilerMSL::SetBindingPair::operator==(const SetBindingPair &other) const
{
	return desc_set == other.desc_set && binding == other.binding;
}

bool CompilerMSL::StageSetBinding::operator==(const StageSetBinding &other) const
{
	return model == other.model && desc_set == other.desc_set && binding == other.binding;
}

size_t CompilerMSL::InternalHasher::operator()(const SetBindingPair &value) const
{
	// Quality of hash doesn't really matter here.
	auto hash_set = std::hash<uint32_t>()(value.desc_set);
	auto hash_binding = std::hash<uint32_t>()(value.binding);
	return (hash_set * 0x10001b31) ^ hash_binding;
}

size_t CompilerMSL::InternalHasher::operator()(const StageSetBinding &value) const
{
	// Quality of hash doesn't really matter here.
	auto hash_model = std::hash<uint32_t>()(value.model);
	auto hash_set = std::hash<uint32_t>()(value.desc_set);
	auto tmp_hash = (hash_model * 0x10001b31) ^ hash_set;
	return (tmp_hash * 0x10001b31) ^ value.binding;
}