/*
 * Copyright 2016-2020 The Brenwill Workshop Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "spirv_msl.hpp"
#include "GLSL.std.450.h"

#include <algorithm>
#include <assert.h>
#include <numeric>

using namespace spv;
using namespace SPIRV_CROSS_NAMESPACE;
using namespace std;

static const uint32_t k_unknown_location = ~0u;
static const uint32_t k_unknown_component = ~0u;
static const char *force_inline = "static inline __attribute__((always_inline))";

CompilerMSL::CompilerMSL(std::vector<uint32_t> spirv_)
    : CompilerGLSL(std::move(spirv_))
{
}

CompilerMSL::CompilerMSL(const uint32_t *ir_, size_t word_count)
    : CompilerGLSL(ir_, word_count)
{
}

CompilerMSL::CompilerMSL(const ParsedIR &ir_)
    : CompilerGLSL(ir_)
{
}

CompilerMSL::CompilerMSL(ParsedIR &&ir_)
    : CompilerGLSL(std::move(ir_))
{
}

void CompilerMSL::add_msl_shader_input(const MSLShaderInput &si)
{
	inputs_by_location[si.location] = si;
	if (si.builtin != BuiltInMax && !inputs_by_builtin.count(si.builtin))
		inputs_by_builtin[si.builtin] = si;
}

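// Records an explicit MSL resource binding for a (stage, set, binding) tuple.
// The bool in the stored pair starts out false and is flipped to true once the
// binding is actually consumed during compilation; see is_msl_resource_binding_used().
// Typical usage (illustrative values; field names per MSLResourceBinding in spirv_msl.hpp):
//   MSLResourceBinding b;
//   b.stage = spv::ExecutionModelFragment;
//   b.desc_set = 0;
//   b.binding = 1;
//   b.msl_texture = b.msl_sampler = 0;
//   compiler.add_msl_resource_binding(b);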
void CompilerMSL::add_msl_resource_binding(const MSLResourceBinding &binding)
{
	StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
	resource_bindings[tuple] = { binding, false };
}

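// Marks the buffer at (desc_set, binding) as having a dynamic offset. The index selects
// the entry in the dynamic-offsets buffer (spvDynamicOffsets); the second element of the
// stored pair starts at 0 and is later filled in with the ID of the matching buffer
// variable, if one is actually used at that binding point.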
void CompilerMSL::add_dynamic_buffer(uint32_t desc_set, uint32_t binding, uint32_t index)
{
	SetBindingPair pair = { desc_set, binding };
	buffers_requiring_dynamic_offset[pair] = { index, 0 };
}

void CompilerMSL::add_inline_uniform_block(uint32_t desc_set, uint32_t binding)
{
	SetBindingPair pair = { desc_set, binding };
	inline_uniform_blocks.insert(pair);
}

void CompilerMSL::add_discrete_descriptor_set(uint32_t desc_set)
{
	if (desc_set < kMaxArgumentBuffers)
		argument_buffer_discrete_mask |= 1u << desc_set;
}

void CompilerMSL::set_argument_buffer_device_address_space(uint32_t desc_set, bool device_storage)
{
	if (desc_set < kMaxArgumentBuffers)
	{
		if (device_storage)
			argument_buffer_device_storage_mask |= 1u << desc_set;
		else
			argument_buffer_device_storage_mask &= ~(1u << desc_set);
	}
}

bool CompilerMSL::is_msl_shader_input_used(uint32_t location)
{
	return inputs_in_use.count(location) != 0;
}

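// Returns true only if the binding was registered via add_msl_resource_binding()
// and was actually consumed during compilation; the bool in the stored pair records the latter.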
bool CompilerMSL::is_msl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
{
	StageSetBinding tuple = { model, desc_set, binding };
	auto itr = resource_bindings.find(tuple);
	return itr != end(resource_bindings) && itr->second.second;
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexPrimary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_secondary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexSecondary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_tertiary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexTertiary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_quaternary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexQuaternary);
}

void CompilerMSL::set_fragment_output_components(uint32_t location, uint32_t components)
{
	fragment_output_components[location] = components;
}

bool CompilerMSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
{
	return (builtin == BuiltInSampleMask);
}

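// Metal does not implicitly provide several builtins that Vulkan SPIR-V relies on
// (e.g. gl_SampleID when only gl_SamplePosition is used, or gl_ViewIndex for multiview).
// This pass scans the entry point's needs and synthesizes any missing builtin variables,
// along with the auxiliary buffers (swizzle constants, buffer sizes, view mask, dynamic
// offsets) that the generated MSL requires.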
void CompilerMSL::build_implicit_builtins()
{
	bool need_sample_pos = active_input_builtins.get(BuiltInSamplePosition);
	bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex &&
	                          !msl_options.vertex_for_tessellation;
	bool need_tesc_params = get_execution_model() == ExecutionModelTessellationControl;
	bool need_subgroup_mask =
	    active_input_builtins.get(BuiltInSubgroupEqMask) || active_input_builtins.get(BuiltInSubgroupGeMask) ||
	    active_input_builtins.get(BuiltInSubgroupGtMask) || active_input_builtins.get(BuiltInSubgroupLeMask) ||
	    active_input_builtins.get(BuiltInSubgroupLtMask);
	bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(BuiltInSubgroupGeMask) ||
	                                                       active_input_builtins.get(BuiltInSubgroupGtMask));
	bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index &&
	                      msl_options.multiview_layered_rendering &&
	                      (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex));
	bool need_dispatch_base =
	    msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute &&
	    (active_input_builtins.get(BuiltInWorkgroupId) || active_input_builtins.get(BuiltInGlobalInvocationId));
	bool need_grid_params = get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation;
	bool need_vertex_base_params =
	    need_grid_params &&
	    (active_input_builtins.get(BuiltInVertexId) || active_input_builtins.get(BuiltInVertexIndex) ||
	     active_input_builtins.get(BuiltInBaseVertex) || active_input_builtins.get(BuiltInInstanceId) ||
	     active_input_builtins.get(BuiltInInstanceIndex) || active_input_builtins.get(BuiltInBaseInstance));
	bool need_sample_mask = msl_options.additional_fixed_sample_mask != 0xffffffff;
	if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
	    need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params ||
	    needs_subgroup_invocation_id || need_sample_mask)
	{
		bool has_frag_coord = false;
		bool has_sample_id = false;
		bool has_vertex_idx = false;
		bool has_base_vertex = false;
		bool has_instance_idx = false;
		bool has_base_instance = false;
		bool has_invocation_id = false;
		bool has_primitive_id = false;
		bool has_subgroup_invocation_id = false;
		bool has_subgroup_size = false;
		bool has_view_idx = false;
		bool has_layer = false;
		uint32_t workgroup_id_type = 0;

		// FIXME: Investigate the fact that there are no checks for the entry point interface variables.
		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
			if (!ir.meta[var.self].decoration.builtin)
				return;

			// Use Metal's native frame-buffer fetch API for subpass inputs.
			BuiltIn builtin = ir.meta[var.self].decoration.builtin_type;

			if (var.storage == StorageClassOutput)
			{
				if (need_sample_mask && builtin == BuiltInSampleMask)
				{
					builtin_sample_mask_id = var.self;
					mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var.self);
					does_shader_write_sample_mask = true;
				}
			}

			if (var.storage != StorageClassInput)
				return;

			if (need_subpass_input && (!msl_options.is_ios() || !msl_options.ios_use_framebuffer_fetch_subpasses))
			{
				switch (builtin)
				{
				case BuiltInFragCoord:
					mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var.self);
					builtin_frag_coord_id = var.self;
					has_frag_coord = true;
					break;
				case BuiltInLayer:
					if (!msl_options.arrayed_subpass_input || msl_options.multiview)
						break;
					mark_implicit_builtin(StorageClassInput, BuiltInLayer, var.self);
					builtin_layer_id = var.self;
					has_layer = true;
					break;
				case BuiltInViewIndex:
					if (!msl_options.multiview)
						break;
					mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
					builtin_view_idx_id = var.self;
					has_view_idx = true;
					break;
				default:
					break;
				}
			}

			if (need_sample_pos && builtin == BuiltInSampleId)
			{
				builtin_sample_id_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var.self);
				has_sample_id = true;
			}

			if (need_vertex_params)
			{
				switch (builtin)
				{
				case BuiltInVertexIndex:
					builtin_vertex_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var.self);
					has_vertex_idx = true;
					break;
				case BuiltInBaseVertex:
					builtin_base_vertex_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var.self);
					has_base_vertex = true;
					break;
				case BuiltInInstanceIndex:
					builtin_instance_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
					has_instance_idx = true;
					break;
				case BuiltInBaseInstance:
					builtin_base_instance_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
					has_base_instance = true;
					break;
				default:
					break;
				}
			}

			if (need_tesc_params)
			{
				switch (builtin)
				{
				case BuiltInInvocationId:
					builtin_invocation_id_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var.self);
					has_invocation_id = true;
					break;
				case BuiltInPrimitiveId:
					builtin_primitive_id_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var.self);
					has_primitive_id = true;
					break;
				default:
					break;
				}
			}

			if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId)
			{
				builtin_subgroup_invocation_id_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var.self);
				has_subgroup_invocation_id = true;
			}

			if (need_subgroup_ge_mask && builtin == BuiltInSubgroupSize)
			{
				builtin_subgroup_size_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var.self);
				has_subgroup_size = true;
			}

			if (need_multiview)
			{
				switch (builtin)
				{
				case BuiltInInstanceIndex:
					// The view index here is derived from the instance index.
					builtin_instance_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
					has_instance_idx = true;
					break;
				case BuiltInBaseInstance:
					// If a non-zero base instance is used, we need to adjust for it when calculating the view index.
					builtin_base_instance_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
					has_base_instance = true;
					break;
				case BuiltInViewIndex:
					builtin_view_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
					has_view_idx = true;
					break;
				default:
					break;
				}
			}

			// The base workgroup needs to have the same type and vector size
			// as the workgroup or invocation ID, so keep track of the type that
			// was used.
			if (need_dispatch_base && workgroup_id_type == 0 &&
			    (builtin == BuiltInWorkgroupId || builtin == BuiltInGlobalInvocationId))
				workgroup_id_type = var.basetype;
		});

		// Use Metal's native frame-buffer fetch API for subpass inputs.
		if ((!has_frag_coord || (msl_options.multiview && !has_view_idx) ||
		     (msl_options.arrayed_subpass_input && !msl_options.multiview && !has_layer)) &&
		    (!msl_options.is_ios() || !msl_options.ios_use_framebuffer_fetch_subpasses) && need_subpass_input)
		{
			if (!has_frag_coord)
			{
				uint32_t offset = ir.increase_bound_by(3);
				uint32_t type_id = offset;
				uint32_t type_ptr_id = offset + 1;
				uint32_t var_id = offset + 2;

				// Create gl_FragCoord.
				SPIRType vec4_type;
				vec4_type.basetype = SPIRType::Float;
				vec4_type.width = 32;
				vec4_type.vecsize = 4;
				set<SPIRType>(type_id, vec4_type);

				SPIRType vec4_type_ptr;
				vec4_type_ptr = vec4_type;
				vec4_type_ptr.pointer = true;
				vec4_type_ptr.parent_type = type_id;
				vec4_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
				ptr_type.self = type_id;

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInFragCoord);
				builtin_frag_coord_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var_id);
			}

			if (!has_layer && msl_options.arrayed_subpass_input && !msl_options.multiview)
			{
				uint32_t offset = ir.increase_bound_by(2);
				uint32_t type_ptr_id = offset;
				uint32_t var_id = offset + 1;

				// Create gl_Layer.
				SPIRType uint_type_ptr;
				uint_type_ptr = get_uint_type();
				uint_type_ptr.pointer = true;
				uint_type_ptr.parent_type = get_uint_type_id();
				uint_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
				ptr_type.self = get_uint_type_id();

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
				builtin_layer_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInLayer, var_id);
			}

			if (!has_view_idx && msl_options.multiview)
			{
				uint32_t offset = ir.increase_bound_by(2);
				uint32_t type_ptr_id = offset;
				uint32_t var_id = offset + 1;

				// Create gl_ViewIndex.
				SPIRType uint_type_ptr;
				uint_type_ptr = get_uint_type();
				uint_type_ptr.pointer = true;
				uint_type_ptr.parent_type = get_uint_type_id();
				uint_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
				ptr_type.self = get_uint_type_id();

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
				builtin_view_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
			}
		}

		if (!has_sample_id && need_sample_pos)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SampleID.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSampleId);
			builtin_sample_id_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var_id);
		}

		if ((need_vertex_params && (!has_vertex_idx || !has_base_vertex || !has_instance_idx || !has_base_instance)) ||
		    (need_multiview && (!has_instance_idx || !has_base_instance || !has_view_idx)))
		{
			uint32_t type_ptr_id = ir.increase_bound_by(1);

			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			if (need_vertex_params && !has_vertex_idx)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_VertexIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInVertexIndex);
				builtin_vertex_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var_id);
			}

			if (need_vertex_params && !has_base_vertex)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_BaseVertex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseVertex);
				builtin_base_vertex_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var_id);
			}

			if (!has_instance_idx) // Needed by both multiview and tessellation
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_InstanceIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInInstanceIndex);
				builtin_instance_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var_id);
			}

			if (!has_base_instance) // Needed by both multiview and tessellation
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_BaseInstance.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseInstance);
				builtin_base_instance_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var_id);
			}

			if (need_multiview)
			{
				// Multiview shaders are not allowed to write to gl_Layer, ostensibly because
				// it is implicitly written from gl_ViewIndex, but we have to do that explicitly.
				// Note that we can't just abuse gl_ViewIndex for this purpose: it's an input, but
				// gl_Layer is an output in vertex-pipeline shaders.
				uint32_t type_ptr_out_id = ir.increase_bound_by(2);
				SPIRType uint_type_ptr_out;
				uint_type_ptr_out = get_uint_type();
				uint_type_ptr_out.pointer = true;
				uint_type_ptr_out.parent_type = get_uint_type_id();
				uint_type_ptr_out.storage = StorageClassOutput;
				auto &ptr_out_type = set<SPIRType>(type_ptr_out_id, uint_type_ptr_out);
				ptr_out_type.self = get_uint_type_id();
				uint32_t var_id = type_ptr_out_id + 1;
				set<SPIRVariable>(var_id, type_ptr_out_id, StorageClassOutput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
				builtin_layer_id = var_id;
				mark_implicit_builtin(StorageClassOutput, BuiltInLayer, var_id);
			}

			if (need_multiview && !has_view_idx)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_ViewIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
				builtin_view_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
			}
		}

		if ((need_tesc_params && (msl_options.multi_patch_workgroup || !has_invocation_id || !has_primitive_id)) ||
		    need_grid_params)
		{
			uint32_t type_ptr_id = ir.increase_bound_by(1);

			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			if (msl_options.multi_patch_workgroup || need_grid_params)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_GlobalInvocationID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInGlobalInvocationId);
				builtin_invocation_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInGlobalInvocationId, var_id);
			}
			else if (need_tesc_params && !has_invocation_id)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_InvocationID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInInvocationId);
				builtin_invocation_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var_id);
			}

			if (need_tesc_params && !has_primitive_id)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_PrimitiveID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInPrimitiveId);
				builtin_primitive_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var_id);
			}

			if (need_grid_params)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				set<SPIRVariable>(var_id, build_extended_vector_type(get_uint_type_id(), 3), StorageClassInput);
				set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize);
				get_entry_point().interface_variables.push_back(var_id);
				set_name(var_id, "spvStageInputSize");
				builtin_stage_input_size_id = var_id;
			}
		}

		if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SubgroupInvocationID.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupLocalInvocationId);
			builtin_subgroup_invocation_id_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var_id);
		}

		if (!has_subgroup_size && need_subgroup_ge_mask)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SubgroupSize.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupSize);
			builtin_subgroup_size_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id);
		}

		if (need_dispatch_base || need_vertex_base_params)
		{
			if (workgroup_id_type == 0)
				workgroup_id_type = build_extended_vector_type(get_uint_type_id(), 3);
			uint32_t var_id;
			if (msl_options.supports_msl_version(1, 2))
			{
				// If we have MSL 1.2, we can (ab)use the [[grid_origin]] builtin
				// to convey this information and save a buffer slot.
				uint32_t offset = ir.increase_bound_by(1);
				var_id = offset;

				set<SPIRVariable>(var_id, workgroup_id_type, StorageClassInput);
				set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase);
				get_entry_point().interface_variables.push_back(var_id);
			}
			else
			{
				// Otherwise, we need to fall back to a good ol' fashioned buffer.
				uint32_t offset = ir.increase_bound_by(2);
				var_id = offset;
				uint32_t type_id = offset + 1;

				SPIRType var_type = get<SPIRType>(workgroup_id_type);
				var_type.storage = StorageClassUniform;
				set<SPIRType>(type_id, var_type);

				set<SPIRVariable>(var_id, type_id, StorageClassUniform);
				// This should never match anything.
				set_decoration(var_id, DecorationDescriptorSet, ~(5u));
				set_decoration(var_id, DecorationBinding, msl_options.indirect_params_buffer_index);
				set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
				                        msl_options.indirect_params_buffer_index);
			}
			set_name(var_id, "spvDispatchBase");
			builtin_dispatch_base_id = var_id;
		}

		if (need_sample_mask && !does_shader_write_sample_mask)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t var_id = offset + 1;

			// Create gl_SampleMask.
			SPIRType uint_type_ptr_out;
			uint_type_ptr_out = get_uint_type();
			uint_type_ptr_out.pointer = true;
			uint_type_ptr_out.parent_type = get_uint_type_id();
			uint_type_ptr_out.storage = StorageClassOutput;

			auto &ptr_out_type = set<SPIRType>(offset, uint_type_ptr_out);
			ptr_out_type.self = get_uint_type_id();
			set<SPIRVariable>(var_id, offset, StorageClassOutput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSampleMask);
			builtin_sample_mask_id = var_id;
			mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var_id);
		}
	}

	if (needs_swizzle_buffer_def)
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvSwizzleConstants");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, kSwizzleBufferBinding);
		set_decoration(var_id, DecorationBinding, msl_options.swizzle_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.swizzle_buffer_index);
		swizzle_buffer_id = var_id;
	}

	if (!buffers_requiring_array_length.empty())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvBufferSizeConstants");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, kBufferSizeBufferBinding);
		set_decoration(var_id, DecorationBinding, msl_options.buffer_size_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.buffer_size_buffer_index);
		buffer_size_buffer_id = var_id;
	}

	if (needs_view_mask_buffer())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvViewMask");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, ~(4u));
		set_decoration(var_id, DecorationBinding, msl_options.view_mask_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.view_mask_buffer_index);
		view_mask_buffer_id = var_id;
	}

	if (!buffers_requiring_dynamic_offset.empty())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvDynamicOffsets");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, ~(5u));
		set_decoration(var_id, DecorationBinding, msl_options.dynamic_offsets_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
		                        msl_options.dynamic_offsets_buffer_index);
		dynamic_offsets_buffer_id = var_id;
	}
}

// Checks if the specified builtin variable (e.g. gl_InstanceIndex) is marked as active.
// If not, it marks it as active and forces a recompilation.
// This might be used when the optimization of inactive builtins was too optimistic (e.g. when "spvOut" is emitted).
void CompilerMSL::ensure_builtin(spv::StorageClass storage, spv::BuiltIn builtin)
{
	Bitset *active_builtins = nullptr;
	switch (storage)
	{
	case StorageClassInput:
		active_builtins = &active_input_builtins;
		break;

	case StorageClassOutput:
		active_builtins = &active_output_builtins;
		break;

	default:
		break;
	}

	// At this point, the specified builtin variable must have already been declared in the entry point.
	// If not, mark as active and force recompile.
	if (active_builtins != nullptr && !active_builtins->get(builtin))
	{
		active_builtins->set(builtin);
		force_recompile();
	}
}

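// Marks the given builtin as active for the given storage class and ensures the
// implicitly created variable is listed among the entry point's interface variables.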
void CompilerMSL::mark_implicit_builtin(StorageClass storage, BuiltIn builtin, uint32_t id)
{
	Bitset *active_builtins = nullptr;
	switch (storage)
	{
	case StorageClassInput:
		active_builtins = &active_input_builtins;
		break;

	case StorageClassOutput:
		active_builtins = &active_output_builtins;
		break;

	default:
		break;
	}

	assert(active_builtins != nullptr);
	active_builtins->set(builtin);

	auto &var = get_entry_point().interface_variables;
	if (find(begin(var), end(var), VariableID(id)) == end(var))
		var.push_back(id);
}

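// Creates a variable whose type is a pointer to a tightly packed uint array
// (ArrayStride 4) and returns its ID. Used for the auxiliary buffers that hold
// swizzle constants, buffer sizes, the view mask and dynamic offsets.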
uint32_t CompilerMSL::build_constant_uint_array_pointer()
{
	uint32_t offset = ir.increase_bound_by(3);
	uint32_t type_ptr_id = offset;
	uint32_t type_ptr_ptr_id = offset + 1;
	uint32_t var_id = offset + 2;

	// Create a buffer to hold extra data, including the swizzle constants.
	SPIRType uint_type_pointer = get_uint_type();
	uint_type_pointer.pointer = true;
	uint_type_pointer.pointer_depth = 1;
	uint_type_pointer.parent_type = get_uint_type_id();
	uint_type_pointer.storage = StorageClassUniform;
	set<SPIRType>(type_ptr_id, uint_type_pointer);
	set_decoration(type_ptr_id, DecorationArrayStride, 4);

	SPIRType uint_type_pointer2 = uint_type_pointer;
	uint_type_pointer2.pointer_depth++;
	uint_type_pointer2.parent_type = type_ptr_id;
	set<SPIRType>(type_ptr_ptr_id, uint_type_pointer2);

	set<SPIRVariable>(var_id, type_ptr_ptr_id, StorageClassUniformConstant);
	return var_id;
}

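// Maps an MSLSamplerAddress value to the corresponding MSL sampler constructor argument,
// with an optional "s_"/"t_"/"r_" prefix when per-axis addressing modes differ.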
static string create_sampler_address(const char *prefix, MSLSamplerAddress addr)
{
	switch (addr)
	{
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE:
		return join(prefix, "address::clamp_to_edge");
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_ZERO:
		return join(prefix, "address::clamp_to_zero");
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER:
		return join(prefix, "address::clamp_to_border");
	case MSL_SAMPLER_ADDRESS_REPEAT:
		return join(prefix, "address::repeat");
	case MSL_SAMPLER_ADDRESS_MIRRORED_REPEAT:
		return join(prefix, "address::mirrored_repeat");
	default:
		SPIRV_CROSS_THROW("Invalid sampler addressing mode.");
	}
}

SPIRType &CompilerMSL::get_stage_in_struct_type()
{
	auto &si_var = get<SPIRVariable>(stage_in_var_id);
	return get_variable_data_type(si_var);
}

SPIRType &CompilerMSL::get_stage_out_struct_type()
{
	auto &so_var = get<SPIRVariable>(stage_out_var_id);
	return get_variable_data_type(so_var);
}

SPIRType &CompilerMSL::get_patch_stage_in_struct_type()
{
	auto &si_var = get<SPIRVariable>(patch_stage_in_var_id);
	return get_variable_data_type(si_var);
}

SPIRType &CompilerMSL::get_patch_stage_out_struct_type()
{
	auto &so_var = get<SPIRVariable>(patch_stage_out_var_id);
	return get_variable_data_type(so_var);
}

std::string CompilerMSL::get_tess_factor_struct_name()
{
	if (get_entry_point().flags.get(ExecutionModeTriangles))
		return "MTLTriangleTessellationFactorsHalf";
	return "MTLQuadTessellationFactorsHalf";
}

SPIRType &CompilerMSL::get_uint_type()
{
	return get<SPIRType>(get_uint_type_id());
}

uint32_t CompilerMSL::get_uint_type_id()
{
	if (uint_type_id != 0)
		return uint_type_id;

	uint_type_id = ir.increase_bound_by(1);

	SPIRType type;
	type.basetype = SPIRType::UInt;
	type.width = 32;
	set<SPIRType>(uint_type_id, type);
	return uint_type_id;
}

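// Emits declarations that must appear at the top of the entry function: constexpr
// samplers, buffers with dynamic offsets, arrays of buffer pointers, and any
// fragment outputs that were disabled but still need a local declaration.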
void CompilerMSL::emit_entry_point_declarations()
{
	// FIXME: Get test coverage here ...
	// Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries
	declare_complex_constant_arrays();

	// Emit constexpr samplers here.
	for (auto &samp : constexpr_samplers_by_id)
	{
		auto &var = get<SPIRVariable>(samp.first);
		auto &type = get<SPIRType>(var.basetype);
		if (type.basetype == SPIRType::Sampler)
			add_resource_name(samp.first);

		SmallVector<string> args;
		auto &s = samp.second;

		if (s.coord != MSL_SAMPLER_COORD_NORMALIZED)
			args.push_back("coord::pixel");

		if (s.min_filter == s.mag_filter)
		{
			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("filter::linear");
		}
		else
		{
			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("min_filter::linear");
			if (s.mag_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("mag_filter::linear");
		}

		switch (s.mip_filter)
		{
		case MSL_SAMPLER_MIP_FILTER_NONE:
			// Default
			break;
		case MSL_SAMPLER_MIP_FILTER_NEAREST:
			args.push_back("mip_filter::nearest");
			break;
		case MSL_SAMPLER_MIP_FILTER_LINEAR:
			args.push_back("mip_filter::linear");
			break;
		default:
			SPIRV_CROSS_THROW("Invalid mip filter.");
		}

		if (s.s_address == s.t_address && s.s_address == s.r_address)
		{
			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("", s.s_address));
		}
		else
		{
			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("s_", s.s_address));
			if (s.t_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("t_", s.t_address));
			if (s.r_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("r_", s.r_address));
		}

		if (s.compare_enable)
		{
			switch (s.compare_func)
			{
			case MSL_SAMPLER_COMPARE_FUNC_ALWAYS:
				args.push_back("compare_func::always");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_NEVER:
				args.push_back("compare_func::never");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_EQUAL:
				args.push_back("compare_func::equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_NOT_EQUAL:
				args.push_back("compare_func::not_equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_LESS:
				args.push_back("compare_func::less");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_LESS_EQUAL:
				args.push_back("compare_func::less_equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_GREATER:
				args.push_back("compare_func::greater");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_GREATER_EQUAL:
				args.push_back("compare_func::greater_equal");
				break;
			default:
				SPIRV_CROSS_THROW("Invalid sampler compare function.");
			}
		}

		if (s.s_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER || s.t_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER ||
		    s.r_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER)
		{
			switch (s.border_color)
			{
			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_BLACK:
				args.push_back("border_color::opaque_black");
				break;
			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_WHITE:
				args.push_back("border_color::opaque_white");
				break;
			case MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK:
				args.push_back("border_color::transparent_black");
				break;
			default:
				SPIRV_CROSS_THROW("Invalid sampler border color.");
			}
		}

		if (s.anisotropy_enable)
			args.push_back(join("max_anisotropy(", s.max_anisotropy, ")"));
		if (s.lod_clamp_enable)
		{
			args.push_back(join("lod_clamp(", convert_to_string(s.lod_clamp_min, current_locale_radix_character), ", ",
			                    convert_to_string(s.lod_clamp_max, current_locale_radix_character), ")"));
		}

		// If we would emit no arguments, then omit the parentheses entirely. Otherwise,
		// we'll wind up with a "most vexing parse" situation.
		if (args.empty())
			statement("constexpr sampler ",
			          type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
			          ";");
		else
			statement("constexpr sampler ",
			          type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
			          "(", merge(args), ");");
	}

	// Emit dynamic buffers here.
	for (auto &dynamic_buffer : buffers_requiring_dynamic_offset)
	{
		if (!dynamic_buffer.second.second)
		{
			// Could happen if no buffer was used at requested binding point.
			continue;
		}

		const auto &var = get<SPIRVariable>(dynamic_buffer.second.second);
		uint32_t var_id = var.self;
		const auto &type = get_variable_data_type(var);
		string name = to_name(var.self);
		uint32_t desc_set = get_decoration(var.self, DecorationDescriptorSet);
		uint32_t arg_id = argument_buffer_ids[desc_set];
		uint32_t base_index = dynamic_buffer.second.first;

		if (!type.array.empty())
		{
			// This is complicated, because we need to support arrays of arrays.
			// And it's even worse if the outermost dimension is a runtime array, because now
			// all this complicated goop has to go into the shader itself. (FIXME)
			if (!type.array[type.array.size() - 1])
				SPIRV_CROSS_THROW("Runtime arrays with dynamic offsets are not supported yet.");
			else
			{
				is_using_builtin_array = true;
				statement(get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id), name,
				          type_to_array_glsl(type), " =");

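				// Walk every element of the (possibly nested) array: 'indices' acts as a
				// mixed-radix counter over the array dimensions, 'j' indexes the entries of
				// the dynamic-offsets buffer, and begin_scope()/end_scope() emit the nested
				// initializer braces.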
				uint32_t dim = uint32_t(type.array.size());
				uint32_t j = 0;
				for (SmallVector<uint32_t> indices(type.array.size());
				     indices[type.array.size() - 1] < to_array_size_literal(type); j++)
				{
					while (dim > 0)
					{
						begin_scope();
						--dim;
					}

					string arrays;
					for (uint32_t i = uint32_t(type.array.size()); i; --i)
						arrays += join("[", indices[i - 1], "]");
					statement("(", get_argument_address_space(var), " ", type_to_glsl(type), "* ",
					          to_restrict(var_id, false), ")((", get_argument_address_space(var), " char* ",
					          to_restrict(var_id, false), ")", to_name(arg_id), ".", ensure_valid_name(name, "m"),
					          arrays, " + ", to_name(dynamic_offsets_buffer_id), "[", base_index + j, "]),");

					while (++indices[dim] >= to_array_size_literal(type, dim) && dim < type.array.size() - 1)
					{
						end_scope(",");
						indices[dim++] = 0;
					}
				}
				end_scope_decl();
				statement_no_indent("");
				is_using_builtin_array = false;
			}
		}
		else
		{
			statement(get_argument_address_space(var), " auto& ", to_restrict(var_id), name, " = *(",
			          get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, false), ")((",
			          get_argument_address_space(var), " char* ", to_restrict(var_id, false), ")", to_name(arg_id), ".",
			          ensure_valid_name(name, "m"), " + ", to_name(dynamic_offsets_buffer_id), "[", base_index, "]);");
		}
	}

	// Emit buffer arrays here.
	for (uint32_t array_id : buffer_arrays)
	{
		const auto &var = get<SPIRVariable>(array_id);
		const auto &type = get_variable_data_type(var);
		const auto &buffer_type = get_variable_element_type(var);
		string name = to_name(array_id);
		statement(get_argument_address_space(var), " ", type_to_glsl(buffer_type), "* ", to_restrict(array_id), name,
		          "[] =");
		begin_scope();
		for (uint32_t i = 0; i < to_array_size_literal(type); ++i)
			statement(name, "_", i, ",");
		end_scope_decl();
		statement_no_indent("");
	}
	// For some reason, without this, we end up emitting the arrays twice.
	buffer_arrays.clear();

	// Emit disabled fragment outputs.
	std::sort(disabled_frag_outputs.begin(), disabled_frag_outputs.end());
	for (uint32_t var_id : disabled_frag_outputs)
	{
		auto &var = get<SPIRVariable>(var_id);
		add_local_variable_name(var_id);
		statement(variable_decl(var), ";");
		var.deferred_declaration = false;
	}
}

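// Compiles the parsed SPIR-V IR into MSL source. Configures the GLSL backend for
// Metal semantics, builds implicit builtins and interface blocks, then runs the
// emit loop until no forced recompilation is requested.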
string CompilerMSL::compile()
{
	ir.fixup_reserved_names();

	// Do not deal with GLES-isms like precision, older extensions and such.
	options.vulkan_semantics = true;
	options.es = false;
	options.version = 450;
	backend.null_pointer_literal = "nullptr";
	backend.float_literal_suffix = false;
	backend.uint32_t_literal_suffix = true;
	backend.int16_t_literal_suffix = "";
	backend.uint16_t_literal_suffix = "";
	backend.basic_int_type = "int";
	backend.basic_uint_type = "uint";
	backend.basic_int8_type = "char";
	backend.basic_uint8_type = "uchar";
	backend.basic_int16_type = "short";
	backend.basic_uint16_type = "ushort";
	backend.discard_literal = "discard_fragment()";
	backend.demote_literal = "unsupported-demote";
	backend.boolean_mix_function = "select";
	backend.swizzle_is_function = false;
	backend.shared_is_implied = false;
	backend.use_initializer_list = true;
	backend.use_typed_initializer_list = true;
	backend.native_row_major_matrix = false;
	backend.unsized_array_supported = false;
	backend.can_declare_arrays_inline = false;
	backend.allow_truncated_access_chain = true;
	backend.comparison_image_samples_scalar = true;
	backend.native_pointers = true;
	backend.nonuniform_qualifier = "";
	backend.support_small_type_sampling_result = true;
	backend.supports_empty_struct = true;

	// Allow Metal to use the array<T> template unless we force it off.
	backend.can_return_array = !msl_options.force_native_arrays;
	backend.array_is_value_type = !msl_options.force_native_arrays;
	// Arrays which are part of buffer objects are never considered to be native arrays.
	backend.buffer_offset_array_is_value_type = false;

	capture_output_to_buffer = msl_options.capture_output_to_buffer;
	is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;

	// Initialize array here rather than constructor, MSVC 2013 workaround.
	for (auto &id : next_metal_resource_ids)
		id = 0;

	fixup_type_alias();
	replace_illegal_names();

	build_function_control_flow_graphs_and_analyze();
	update_active_builtins();
	analyze_image_and_sampler_usage();
	analyze_sampled_image_usage();
	analyze_interlocked_resource_usage();
	preprocess_op_codes();
	build_implicit_builtins();

	fixup_image_load_store_access();

	set_enabled_interface_variables(get_active_interface_variables());
	if (msl_options.force_active_argument_buffer_resources)
		activate_argument_buffer_resources();

	if (swizzle_buffer_id)
		active_interface_variables.insert(swizzle_buffer_id);
	if (buffer_size_buffer_id)
		active_interface_variables.insert(buffer_size_buffer_id);
	if (view_mask_buffer_id)
		active_interface_variables.insert(view_mask_buffer_id);
	if (dynamic_offsets_buffer_id)
		active_interface_variables.insert(dynamic_offsets_buffer_id);
	if (builtin_layer_id)
		active_interface_variables.insert(builtin_layer_id);
	if (builtin_dispatch_base_id && !msl_options.supports_msl_version(1, 2))
		active_interface_variables.insert(builtin_dispatch_base_id);
	if (builtin_sample_mask_id)
		active_interface_variables.insert(builtin_sample_mask_id);

	// Create structs to hold input, output and uniform variables.
	// Do output first to ensure out. is declared at top of entry function.
	qual_pos_var_name = "";
	stage_out_var_id = add_interface_block(StorageClassOutput);
	patch_stage_out_var_id = add_interface_block(StorageClassOutput, true);
	stage_in_var_id = add_interface_block(StorageClassInput);
	if (get_execution_model() == ExecutionModelTessellationEvaluation)
		patch_stage_in_var_id = add_interface_block(StorageClassInput, true);

	if (get_execution_model() == ExecutionModelTessellationControl)
		stage_out_ptr_var_id = add_interface_block_pointer(stage_out_var_id, StorageClassOutput);
	if (is_tessellation_shader())
		stage_in_ptr_var_id = add_interface_block_pointer(stage_in_var_id, StorageClassInput);

	// Metal vertex functions that define no output must disable rasterization and return void.
	if (!stage_out_var_id)
		is_rasterization_disabled = true;

	// Convert the use of global variables to recursively-passed function parameters
	localize_global_variables();
	extract_global_variables_from_functions();

	// Mark any non-stage-in structs to be tightly packed.
	mark_packable_structs();
	reorder_type_alias();

	// Add fixup hooks required by shader inputs and outputs. This needs to happen before
	// the loop, so the hooks aren't added multiple times.
	fix_up_shader_inputs_outputs();

	// If we are using argument buffers, we create argument buffer structures for them here.
	// These buffers will be used in the entry point, not the individual resources.
	if (msl_options.argument_buffers)
	{
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("Argument buffers can only be used with MSL 2.0 and up.");
		analyze_argument_buffers();
	}

	uint32_t pass_count = 0;
	do
	{
		if (pass_count >= 3)
			SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");

		reset();

		// Start bindings at zero.
		next_metal_resource_index_buffer = 0;
		next_metal_resource_index_texture = 0;
		next_metal_resource_index_sampler = 0;
		for (auto &id : next_metal_resource_ids)
			id = 0;

		// Move constructor for this type is broken on GCC 4.9 ...
		buffer.reset();

		emit_header();
		emit_custom_templates();
		emit_specialization_constants_and_structs();
		emit_resources();
		emit_custom_functions();
		emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());

		pass_count++;
	} while (is_forcing_recompilation());

	return buffer.str();
}

// Register the need to output any custom functions.
void CompilerMSL::preprocess_op_codes()
{
	OpCodePreprocessor preproc(*this);
	traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), preproc);

	suppress_missing_prototypes = preproc.suppress_missing_prototypes;

	if (preproc.uses_atomics)
	{
		add_header_line("#include <metal_atomic>");
		add_pragma_line("#pragma clang diagnostic ignored \"-Wunused-variable\"");
	}

	// Metal vertex functions that write to resources must disable rasterization and return void.
	if (preproc.uses_resource_write)
		is_rasterization_disabled = true;

	// Tessellation control shaders are run as compute functions in Metal, and so
	// must capture their output to a buffer.
	if (get_execution_model() == ExecutionModelTessellationControl ||
	    (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
	{
		is_rasterization_disabled = true;
		capture_output_to_buffer = true;
	}

	if (preproc.needs_subgroup_invocation_id)
		needs_subgroup_invocation_id = true;
}

// Move the Private and Workgroup global variables to the entry function.
// Non-constant variables cannot have global scope in Metal.
void CompilerMSL::localize_global_variables()
{
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	auto iter = global_variables.begin();
	while (iter != global_variables.end())
	{
		uint32_t v_id = *iter;
		auto &var = get<SPIRVariable>(v_id);
		if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup)
		{
			if (!variable_is_lut(var))
				entry_func.add_local_variable(v_id);
			iter = global_variables.erase(iter);
		}
		else
			iter++;
	}
}

// For any global variable accessed directly by a function,
// extract that variable and add it as an argument to that function.
void CompilerMSL::extract_global_variables_from_functions()
{
	// Uniforms
	unordered_set<uint32_t> global_var_ids;
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
		if (var.storage == StorageClassInput || var.storage == StorageClassOutput ||
		    var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
		    var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)
		{
			global_var_ids.insert(var.self);
		}
	});

	// Local vars that are declared in the main function and accessed directly by a function
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	for (auto &var : entry_func.local_variables)
		if (get<SPIRVariable>(var).storage != StorageClassFunction)
			global_var_ids.insert(var);

	std::set<uint32_t> added_arg_ids;
	unordered_set<uint32_t> processed_func_ids;
	extract_global_variables_from_function(ir.default_entry_point, added_arg_ids, global_var_ids, processed_func_ids);
}

// MSL does not support the use of global variables for shader input content.
// For any global variable accessed directly by the specified function, extract that variable,
// add it as an argument to that function, and add the argument's ID to the added_arg_ids collection.
void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::set<uint32_t> &added_arg_ids,
                                                         unordered_set<uint32_t> &global_var_ids,
                                                         unordered_set<uint32_t> &processed_func_ids)
{
	// Avoid processing a function more than once
	if (processed_func_ids.find(func_id) != processed_func_ids.end())
	{
		// Return function global variables
		added_arg_ids = function_global_vars[func_id];
		return;
	}

	processed_func_ids.insert(func_id);

	auto &func = get<SPIRFunction>(func_id);

	// Recursively establish global args added to functions on which we depend.
	for (auto block : func.blocks)
	{
		auto &b = get<SPIRBlock>(block);
		for (auto &i : b.ops)
		{
			auto ops = stream(i);
			auto op = static_cast<Op>(i.op);

			switch (op)
			{
			case OpLoad:
			case OpInBoundsAccessChain:
			case OpAccessChain:
			case OpPtrAccessChain:
			case OpArrayLength:
			{
				uint32_t base_id = ops[2];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);

				// Use Metal's native frame-buffer fetch API for subpass inputs.
				auto &type = get<SPIRType>(ops[0]);
				if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
				    (!msl_options.is_ios() || !msl_options.ios_use_framebuffer_fetch_subpasses))
				{
					// Implicitly reads gl_FragCoord.
					assert(builtin_frag_coord_id != 0);
					added_arg_ids.insert(builtin_frag_coord_id);
					if (msl_options.multiview)
					{
						// Implicitly reads gl_ViewIndex.
						assert(builtin_view_idx_id != 0);
						added_arg_ids.insert(builtin_view_idx_id);
					}
					else if (msl_options.arrayed_subpass_input)
					{
						// Implicitly reads gl_Layer.
						assert(builtin_layer_id != 0);
						added_arg_ids.insert(builtin_layer_id);
					}
				}

				break;
			}

			case OpFunctionCall:
			{
				// First see if any of the function call args are globals
				for (uint32_t arg_idx = 3; arg_idx < i.length; arg_idx++)
				{
					uint32_t arg_id = ops[arg_idx];
					if (global_var_ids.find(arg_id) != global_var_ids.end())
						added_arg_ids.insert(arg_id);
				}

				// Then recurse into the function itself to extract globals used internally in the function
				uint32_t inner_func_id = ops[2];
				std::set<uint32_t> inner_func_args;
				extract_global_variables_from_function(inner_func_id, inner_func_args, global_var_ids,
				                                       processed_func_ids);
				added_arg_ids.insert(inner_func_args.begin(), inner_func_args.end());
				break;
			}

			case OpStore:
			{
				uint32_t base_id = ops[0];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);

				uint32_t rvalue_id = ops[1];
				if (global_var_ids.find(rvalue_id) != global_var_ids.end())
					added_arg_ids.insert(rvalue_id);

				break;
			}

			case OpSelect:
			{
				uint32_t base_id = ops[3];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				base_id = ops[4];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				break;
			}

			// Emulate texture2D atomic operations
			case OpImageTexelPointer:
			{
				// When using the pointer, we need to know which variable it is actually loaded from.
				uint32_t base_id = ops[2];
				auto *var = maybe_get_backing_variable(base_id);
				if (var && atomic_image_vars.count(var->self))
				{
					if (global_var_ids.find(base_id) != global_var_ids.end())
						added_arg_ids.insert(base_id);
				}
				break;
			}

			default:
				break;
			}

			// TODO: Add all other operations which can affect memory.
			// We should consider a more unified system here to reduce boiler-plate.
			// This kind of analysis is done in several places ...
		}
	}

	function_global_vars[func_id] = added_arg_ids;

	// Add the global variables as arguments to the function
	if (func_id != ir.default_entry_point)
	{
		bool added_in = false;
		bool added_out = false;
		for (uint32_t arg_id : added_arg_ids)
		{
			auto &var = get<SPIRVariable>(arg_id);
			uint32_t type_id = var.basetype;
			auto *p_type = &get<SPIRType>(type_id);
			BuiltIn bi_type = BuiltIn(get_decoration(arg_id, DecorationBuiltIn));

			if (((is_tessellation_shader() && var.storage == StorageClassInput) ||
			     (get_execution_model() == ExecutionModelTessellationControl && var.storage == StorageClassOutput)) &&
			    !(has_decoration(arg_id, DecorationPatch) || is_patch_block(*p_type)) &&
			    (!is_builtin_variable(var) || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
			     bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
			     p_type->basetype == SPIRType::Struct))
			{
				// Tessellation control shaders see inputs and per-vertex outputs as arrays.
				// Similarly, tessellation evaluation shaders see per-vertex inputs as arrays.
				// We collected them into a structure; we must pass the array of this
				// structure to the function.
				std::string name;
				if (var.storage == StorageClassInput)
				{
					if (added_in)
						continue;
					name = "gl_in";
					arg_id = stage_in_ptr_var_id;
					added_in = true;
				}
				else if (var.storage == StorageClassOutput)
				{
					if (added_out)
						continue;
					name = "gl_out";
					arg_id = stage_out_ptr_var_id;
					added_out = true;
				}
				type_id = get<SPIRVariable>(arg_id).basetype;
				uint32_t next_id = ir.increase_bound_by(1);
				func.add_parameter(type_id, next_id, true);
				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);

				set_name(next_id, name);
1512 			}
1513 			else if (is_builtin_variable(var) && p_type->basetype == SPIRType::Struct)
1514 			{
1515 				// Get the pointee type
1516 				type_id = get_pointee_type_id(type_id);
1517 				p_type = &get<SPIRType>(type_id);
1518 
1519 				uint32_t mbr_idx = 0;
1520 				for (auto &mbr_type_id : p_type->member_types)
1521 				{
1522 					BuiltIn builtin = BuiltInMax;
1523 					bool is_builtin = is_member_builtin(*p_type, mbr_idx, &builtin);
1524 					if (is_builtin && has_active_builtin(builtin, var.storage))
1525 					{
1526 						// Add an argument variable with the same type and decorations as the member
1527 						uint32_t next_ids = ir.increase_bound_by(2);
1528 						uint32_t ptr_type_id = next_ids + 0;
1529 						uint32_t var_id = next_ids + 1;
1530 
1531 						// Make sure we have an actual pointer type,
1532 						// so that we will get the appropriate address space when declaring these builtins.
1533 						auto &ptr = set<SPIRType>(ptr_type_id, get<SPIRType>(mbr_type_id));
1534 						ptr.self = mbr_type_id;
1535 						ptr.storage = var.storage;
1536 						ptr.pointer = true;
1537 						ptr.parent_type = mbr_type_id;
1538 
1539 						func.add_parameter(mbr_type_id, var_id, true);
1540 						set<SPIRVariable>(var_id, ptr_type_id, StorageClassFunction);
1541 						ir.meta[var_id].decoration = ir.meta[type_id].members[mbr_idx];
1542 					}
1543 					mbr_idx++;
1544 				}
1545 			}
1546 			else
1547 			{
1548 				uint32_t next_id = ir.increase_bound_by(1);
1549 				func.add_parameter(type_id, next_id, true);
1550 				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);
1551 
1552 				// Ensure the existing variable has a valid name and the new variable has all the same meta info
1553 				set_name(arg_id, ensure_valid_name(to_name(arg_id), "v"));
1554 				ir.meta[next_id] = ir.meta[arg_id];
1555 			}
1556 		}
1557 	}
1558 }
1559 
1560 // For all variables that are some form of non-input-output interface block, mark that all the structs
1561 // that are recursively contained within the type referenced by that variable should be packed tightly.
1562 void CompilerMSL::mark_packable_structs()
1563 {
1564 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1565 		if (var.storage != StorageClassFunction && !is_hidden_variable(var))
1566 		{
1567 			auto &type = this->get<SPIRType>(var.basetype);
1568 			if (type.pointer &&
1569 			    (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant ||
1570 			     type.storage == StorageClassPushConstant || type.storage == StorageClassStorageBuffer) &&
1571 			    (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
1572 				mark_as_packable(type);
1573 		}
1574 	});
1575 }
1576 
1577 // If the specified type is a struct, it and any nested structs
1578 // are marked as packable with the SPIRVCrossDecorationBufferBlockRepacked decoration.
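//
// For example, a std430 buffer block declared in GLSL as
//
//     layout(std430) buffer SSBO { vec3 pos; float radius; };
//
// places 'radius' at offset 12, which MSL's 16-byte-aligned float3 cannot satisfy;
// marking the struct as packable lets later passes emit 'pos' roughly as a
// packed_float3 (an illustrative sketch of the intent, not the exact output).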
1579 void CompilerMSL::mark_as_packable(SPIRType &type)
1580 {
1581 	// If this is not the base type (e.g. it's a pointer or array), tunnel down.
1582 	if (type.parent_type)
1583 	{
1584 		mark_as_packable(get<SPIRType>(type.parent_type));
1585 		return;
1586 	}
1587 
1588 	if (type.basetype == SPIRType::Struct)
1589 	{
1590 		set_extended_decoration(type.self, SPIRVCrossDecorationBufferBlockRepacked);
1591 
1592 		// Recurse
1593 		uint32_t mbr_cnt = uint32_t(type.member_types.size());
1594 		for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
1595 		{
1596 			uint32_t mbr_type_id = type.member_types[mbr_idx];
1597 			auto &mbr_type = get<SPIRType>(mbr_type_id);
1598 			mark_as_packable(mbr_type);
1599 			if (mbr_type.type_alias)
1600 			{
1601 				auto &mbr_type_alias = get<SPIRType>(mbr_type.type_alias);
1602 				mark_as_packable(mbr_type_alias);
1603 			}
1604 		}
1605 	}
1606 }
1607 
1608 // If a shader input exists at the location, it is marked as being used by this shader
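// A matrix or array input occupies one location per column or element: a mat4 input at
// location 2 marks locations 2 through 5 as in use, and a float4[3] input at location 0
// marks locations 0 through 2.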
1609 void CompilerMSL::mark_location_as_used_by_shader(uint32_t location, const SPIRType &type, StorageClass storage)
1610 {
1611 	if (storage != StorageClassInput)
1612 		return;
1613 	if (is_array(type))
1614 	{
1615 		uint32_t dim = 1;
1616 		for (uint32_t i = 0; i < type.array.size(); i++)
1617 			dim *= to_array_size_literal(type, i);
1618 		for (uint32_t i = 0; i < dim; i++)
1619 		{
1620 			if (is_matrix(type))
1621 			{
1622 				for (uint32_t j = 0; j < type.columns; j++)
1623 					inputs_in_use.insert(location++);
1624 			}
1625 			else
1626 				inputs_in_use.insert(location++);
1627 		}
1628 	}
1629 	else if (is_matrix(type))
1630 	{
1631 		for (uint32_t i = 0; i < type.columns; i++)
1632 			inputs_in_use.insert(location + i);
1633 	}
1634 	else
1635 		inputs_in_use.insert(location);
1636 }
1637 
1638 uint32_t CompilerMSL::get_target_components_for_fragment_location(uint32_t location) const
1639 {
1640 	auto itr = fragment_output_components.find(location);
1641 	if (itr == end(fragment_output_components))
1642 		return 4;
1643 	else
1644 		return itr->second;
1645 }
1646 
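// Builds a variant of an existing type with a new vector width (and optionally a new base
// type), preserving any array and pointer wrappers around it. For example, padding a
// fragment output declared as float3 to four components creates a new float4 type whose
// parent_type chain leads back to the original float3.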
1647 uint32_t CompilerMSL::build_extended_vector_type(uint32_t type_id, uint32_t components, SPIRType::BaseType basetype)
1648 {
1649 	uint32_t new_type_id = ir.increase_bound_by(1);
1650 	auto &old_type = get<SPIRType>(type_id);
1651 	auto *type = &set<SPIRType>(new_type_id, old_type);
1652 	type->vecsize = components;
1653 	if (basetype != SPIRType::Unknown)
1654 		type->basetype = basetype;
1655 	type->self = new_type_id;
1656 	type->parent_type = type_id;
1657 	type->array.clear();
1658 	type->array_size_literal.clear();
1659 	type->pointer = false;
1660 
1661 	if (is_array(old_type))
1662 	{
1663 		uint32_t array_type_id = ir.increase_bound_by(1);
1664 		type = &set<SPIRType>(array_type_id, *type);
1665 		type->parent_type = new_type_id;
1666 		type->array = old_type.array;
1667 		type->array_size_literal = old_type.array_size_literal;
1668 		new_type_id = array_type_id;
1669 	}
1670 
1671 	if (old_type.pointer)
1672 	{
1673 		uint32_t ptr_type_id = ir.increase_bound_by(1);
1674 		type = &set<SPIRType>(ptr_type_id, *type);
1675 		type->self = new_type_id;
1676 		type->parent_type = new_type_id;
1677 		type->storage = old_type.storage;
1678 		type->pointer = true;
1679 		new_type_id = ptr_type_id;
1680 	}
1681 
1682 	return new_type_id;
1683 }
1684 
1685 void CompilerMSL::add_plain_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
1686                                                         SPIRType &ib_type, SPIRVariable &var, InterfaceBlockMeta &meta)
1687 {
1688 	bool is_builtin = is_builtin_variable(var);
1689 	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
1690 	bool is_flat = has_decoration(var.self, DecorationFlat);
1691 	bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
1692 	bool is_centroid = has_decoration(var.self, DecorationCentroid);
1693 	bool is_sample = has_decoration(var.self, DecorationSample);
1694 
1695 	// Add a reference to the variable type to the interface struct.
1696 	uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1697 	uint32_t type_id = ensure_correct_builtin_type(var.basetype, builtin);
1698 	var.basetype = type_id;
1699 
1700 	type_id = get_pointee_type_id(var.basetype);
1701 	if (meta.strip_array && is_array(get<SPIRType>(type_id)))
1702 		type_id = get<SPIRType>(type_id).parent_type;
1703 	auto &type = get<SPIRType>(type_id);
1704 	uint32_t target_components = 0;
1705 	uint32_t type_components = type.vecsize;
1706 
1707 	bool padded_output = false;
1708 	bool padded_input = false;
1709 	uint32_t start_component = 0;
1710 
1711 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1712 
1713 	// Deal with Component decorations.
1714 	InterfaceBlockMeta::LocationMeta *location_meta = nullptr;
1715 	if (has_decoration(var.self, DecorationLocation))
1716 	{
1717 		auto location_meta_itr = meta.location_meta.find(get_decoration(var.self, DecorationLocation));
1718 		if (location_meta_itr != end(meta.location_meta))
1719 			location_meta = &location_meta_itr->second;
1720 	}
1721 
1722 	bool pad_fragment_output = has_decoration(var.self, DecorationLocation) &&
1723 	                           msl_options.pad_fragment_output_components &&
1724 	                           get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;
1725 
1726 	// Check if we need to pad fragment output to match a certain number of components.
1727 	if (location_meta)
1728 	{
1729 		start_component = get_decoration(var.self, DecorationComponent);
1730 		uint32_t num_components = location_meta->num_components;
1731 		if (pad_fragment_output)
1732 		{
1733 			uint32_t locn = get_decoration(var.self, DecorationLocation);
1734 			num_components = std::max(num_components, get_target_components_for_fragment_location(locn));
1735 		}
1736 
1737 		if (location_meta->ib_index != ~0u)
1738 		{
1739 			// We have already declared the variable. Just emit an early-declared variable and fix it up as needed.
1740 			entry_func.add_local_variable(var.self);
1741 			vars_needing_early_declaration.push_back(var.self);
1742 
1743 			if (var.storage == StorageClassInput)
1744 			{
1745 				uint32_t ib_index = location_meta->ib_index;
1746 				entry_func.fixup_hooks_in.push_back([=, &var]() {
1747 					statement(to_name(var.self), " = ", ib_var_ref, ".", to_member_name(ib_type, ib_index),
1748 					          vector_swizzle(type_components, start_component), ";");
1749 				});
1750 			}
1751 			else
1752 			{
1753 				uint32_t ib_index = location_meta->ib_index;
1754 				entry_func.fixup_hooks_out.push_back([=, &var]() {
1755 					statement(ib_var_ref, ".", to_member_name(ib_type, ib_index),
1756 					          vector_swizzle(type_components, start_component), " = ", to_name(var.self), ";");
1757 				});
1758 			}
1759 			return;
1760 		}
1761 		else
1762 		{
1763 			location_meta->ib_index = uint32_t(ib_type.member_types.size());
1764 			type_id = build_extended_vector_type(type_id, num_components);
1765 			if (var.storage == StorageClassInput)
1766 				padded_input = true;
1767 			else
1768 				padded_output = true;
1769 		}
1770 	}
1771 	else if (pad_fragment_output)
1772 	{
1773 		uint32_t locn = get_decoration(var.self, DecorationLocation);
1774 		target_components = get_target_components_for_fragment_location(locn);
1775 		if (type_components < target_components)
1776 		{
1777 			// Make a new type here.
1778 			type_id = build_extended_vector_type(type_id, target_components);
1779 			padded_output = true;
1780 		}
1781 	}
1782 
1783 	ib_type.member_types.push_back(type_id);
1784 
1785 	// Give the member a name
1786 	string mbr_name = ensure_valid_name(to_expression(var.self), "m");
1787 	set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1788 
1789 	// Update the original variable reference to include the structure reference
1790 	string qual_var_name = ib_var_ref + "." + mbr_name;
1791 
1792 	if (padded_output || padded_input)
1793 	{
1794 		entry_func.add_local_variable(var.self);
1795 		vars_needing_early_declaration.push_back(var.self);
1796 
1797 		if (padded_output)
1798 		{
1799 			entry_func.fixup_hooks_out.push_back([=, &var]() {
1800 				statement(qual_var_name, vector_swizzle(type_components, start_component), " = ", to_name(var.self),
1801 				          ";");
1802 			});
1803 		}
1804 		else
1805 		{
1806 			entry_func.fixup_hooks_in.push_back([=, &var]() {
1807 				statement(to_name(var.self), " = ", qual_var_name, vector_swizzle(type_components, start_component),
1808 				          ";");
1809 			});
1810 		}
1811 	}
1812 	else if (!meta.strip_array)
1813 		ir.meta[var.self].decoration.qualified_alias = qual_var_name;
1814 
1815 	if (var.storage == StorageClassOutput && var.initializer != ID(0))
1816 	{
1817 		if (padded_output || padded_input)
1818 		{
1819 			entry_func.fixup_hooks_in.push_back(
1820 			    [=, &var]() { statement(to_name(var.self), " = ", to_expression(var.initializer), ";"); });
1821 		}
1822 		else
1823 		{
1824 			entry_func.fixup_hooks_in.push_back(
1825 			    [=, &var]() { statement(qual_var_name, " = ", to_expression(var.initializer), ";"); });
1826 		}
1827 	}
1828 
1829 	// Copy the variable location from the original variable to the member
1830 	if (get_decoration_bitset(var.self).get(DecorationLocation))
1831 	{
1832 		uint32_t locn = get_decoration(var.self, DecorationLocation);
1833 		if (storage == StorageClassInput)
1834 		{
1835 			type_id = ensure_correct_input_type(var.basetype, locn, location_meta ? location_meta->num_components : 0);
1836 			if (!location_meta)
1837 				var.basetype = type_id;
1838 
1839 			type_id = get_pointee_type_id(type_id);
1840 			if (meta.strip_array && is_array(get<SPIRType>(type_id)))
1841 				type_id = get<SPIRType>(type_id).parent_type;
1842 			ib_type.member_types[ib_mbr_idx] = type_id;
1843 		}
1844 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1845 		mark_location_as_used_by_shader(locn, get<SPIRType>(type_id), storage);
1846 	}
1847 	else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
1848 	{
1849 		uint32_t locn = inputs_by_builtin[builtin].location;
1850 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
1851 		mark_location_as_used_by_shader(locn, type, storage);
1852 	}
1853 
1854 	if (!location_meta)
1855 	{
1856 		if (get_decoration_bitset(var.self).get(DecorationComponent))
1857 		{
1858 			uint32_t component = get_decoration(var.self, DecorationComponent);
1859 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, component);
1860 		}
1861 	}
1862 
1863 	if (get_decoration_bitset(var.self).get(DecorationIndex))
1864 	{
1865 		uint32_t index = get_decoration(var.self, DecorationIndex);
1866 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
1867 	}
1868 
1869 	// Mark the member as builtin if needed
1870 	if (is_builtin)
1871 	{
1872 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
1873 		if (builtin == BuiltInPosition && storage == StorageClassOutput)
1874 			qual_pos_var_name = qual_var_name;
1875 	}
1876 
1877 	// Copy interpolation decorations if needed
1878 	if (is_flat)
1879 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
1880 	if (is_noperspective)
1881 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
1882 	if (is_centroid)
1883 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
1884 	if (is_sample)
1885 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
1886 
1887 	// If we have location meta, there is no unique OrigID. We won't need it, since we flatten/unflatten
1888 	// the variable to the stack here anyway.
1889 	if (!location_meta)
1890 		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
1891 }
1892 
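// MSL stage-in and stage-out structs cannot hold matrices or arrays directly, so composite
// I/O variables are flattened into one interface member per column or element. For example,
// an output declared as vec4 v[2] at location 1 becomes roughly (an illustrative sketch;
// actual names depend on the shader):
//
//     float4 v_0 [[user(locn1)]];
//     float4 v_1 [[user(locn2)]];
//
// with fixup hooks copying v[0] and v[1] into these members on return.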
1893 void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
1894                                                             SPIRType &ib_type, SPIRVariable &var,
1895                                                             InterfaceBlockMeta &meta)
1896 {
1897 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1898 	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
1899 	uint32_t elem_cnt = 0;
1900 
1901 	if (is_matrix(var_type))
1902 	{
1903 		if (is_array(var_type))
1904 			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
1905 
1906 		elem_cnt = var_type.columns;
1907 	}
1908 	else if (is_array(var_type))
1909 	{
1910 		if (var_type.array.size() != 1)
1911 			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
1912 
1913 		elem_cnt = to_array_size_literal(var_type);
1914 	}
1915 
1916 	bool is_builtin = is_builtin_variable(var);
1917 	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
1918 	bool is_flat = has_decoration(var.self, DecorationFlat);
1919 	bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
1920 	bool is_centroid = has_decoration(var.self, DecorationCentroid);
1921 	bool is_sample = has_decoration(var.self, DecorationSample);
1922 
1923 	auto *usable_type = &var_type;
1924 	if (usable_type->pointer)
1925 		usable_type = &get<SPIRType>(usable_type->parent_type);
1926 	while (is_array(*usable_type) || is_matrix(*usable_type))
1927 		usable_type = &get<SPIRType>(usable_type->parent_type);
1928 
1929 	// If a builtin, force it to have the proper name.
1930 	if (is_builtin)
1931 		set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));
1932 
1933 	bool flatten_from_ib_var = false;
1934 	string flatten_from_ib_mbr_name;
1935 
1936 	if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
1937 	{
1938 		// Also declare [[clip_distance]] attribute here.
1939 		uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
1940 		ib_type.member_types.push_back(get_variable_data_type_id(var));
1941 		set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
1942 
1943 		flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
1944 		set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);
1945 
1946 		// When we flatten, we flatten directly from the "out" struct,
1947 		// not from a function variable.
1948 		flatten_from_ib_var = true;
1949 
1950 		if (!msl_options.enable_clip_distance_user_varying)
1951 			return;
1952 	}
1953 	else if (!meta.strip_array)
1954 	{
1955 		// Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
1956 		entry_func.add_local_variable(var.self);
1957 		// We need to declare the variable early and at entry-point scope.
1958 		vars_needing_early_declaration.push_back(var.self);
1959 	}
1960 
1961 	for (uint32_t i = 0; i < elem_cnt; i++)
1962 	{
1963 		// Add a reference to the variable type to the interface struct.
1964 		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
1965 
1966 		uint32_t target_components = 0;
1967 		bool padded_output = false;
1968 		uint32_t type_id = usable_type->self;
1969 
1970 		// Check if we need to pad fragment output to match a certain number of components.
1971 		if (get_decoration_bitset(var.self).get(DecorationLocation) && msl_options.pad_fragment_output_components &&
1972 		    get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput)
1973 		{
1974 			uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
1975 			target_components = get_target_components_for_fragment_location(locn);
1976 			if (usable_type->vecsize < target_components)
1977 			{
1978 				// Make a new type here.
1979 				type_id = build_extended_vector_type(usable_type->self, target_components);
1980 				padded_output = true;
1981 			}
1982 		}
1983 
1984 		ib_type.member_types.push_back(get_pointee_type_id(type_id));
1985 
1986 		// Give the member a name
1987 		string mbr_name = ensure_valid_name(join(to_expression(var.self), "_", i), "m");
1988 		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
1989 
1990 		// There is no qualified alias since we need to flatten the internal array on return.
1991 		if (get_decoration_bitset(var.self).get(DecorationLocation))
1992 		{
1993 			uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
1994 			if (storage == StorageClassInput)
1995 			{
1996 				var.basetype = ensure_correct_input_type(var.basetype, locn);
1997 				uint32_t mbr_type_id = ensure_correct_input_type(usable_type->self, locn);
1998 				ib_type.member_types[ib_mbr_idx] = mbr_type_id;
1999 			}
2000 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2001 			mark_location_as_used_by_shader(locn, *usable_type, storage);
2002 		}
2003 		else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2004 		{
2005 			uint32_t locn = inputs_by_builtin[builtin].location + i;
2006 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2007 			mark_location_as_used_by_shader(locn, *usable_type, storage);
2008 		}
2009 		else if (is_builtin && builtin == BuiltInClipDistance)
2010 		{
2011 			// Declare the ClipDistance as [[user(clipN)]].
2012 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
2013 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, i);
2014 		}
2015 
2016 		if (get_decoration_bitset(var.self).get(DecorationIndex))
2017 		{
2018 			uint32_t index = get_decoration(var.self, DecorationIndex);
2019 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
2020 		}
2021 
2022 		// Copy interpolation decorations if needed
2023 		if (is_flat)
2024 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2025 		if (is_noperspective)
2026 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2027 		if (is_centroid)
2028 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2029 		if (is_sample)
2030 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2031 
2032 		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2033 
2034 		// Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
2035 		if (!meta.strip_array)
2036 		{
2037 			switch (storage)
2038 			{
2039 			case StorageClassInput:
2040 				entry_func.fixup_hooks_in.push_back(
2041 				    [=, &var]() { statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, ";"); });
2042 				break;
2043 
2044 			case StorageClassOutput:
2045 				entry_func.fixup_hooks_out.push_back([=, &var]() {
2046 					if (padded_output)
2047 					{
2048 						auto &padded_type = this->get<SPIRType>(type_id);
2049 						statement(
2050 						    ib_var_ref, ".", mbr_name, " = ",
2051 						    remap_swizzle(padded_type, usable_type->vecsize, join(to_name(var.self), "[", i, "]")),
2052 						    ";");
2053 					}
2054 					else if (flatten_from_ib_var)
2055 						statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
2056 						          "];");
2057 					else
2058 						statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), "[", i, "];");
2059 				});
2060 				break;
2061 
2062 			default:
2063 				break;
2064 			}
2065 		}
2066 	}
2067 }
2068 
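// Computes the effective location of a block member when only the block itself carries a
// Location decoration: earlier members consume consecutive locations. For example, for a
// block at location 3 whose first two members are a vec4 and a mat3, member index 2 lands
// at location 3 + 1 + 3 = 7.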
2069 uint32_t CompilerMSL::get_accumulated_member_location(const SPIRVariable &var, uint32_t mbr_idx, bool strip_array)
2070 {
2071 	auto &type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2072 	uint32_t location = get_decoration(var.self, DecorationLocation);
2073 
2074 	for (uint32_t i = 0; i < mbr_idx; i++)
2075 	{
2076 		auto &mbr_type = get<SPIRType>(type.member_types[i]);
2077 
2078 		// Start counting from any place we have a new location decoration.
2079 		if (has_member_decoration(type.self, i, DecorationLocation))
2080 			location = get_member_decoration(type.self, i, DecorationLocation);
2081 
2082 		uint32_t location_count = 1;
2083 
2084 		if (mbr_type.columns > 1)
2085 			location_count = mbr_type.columns;
2086 
2087 		if (!mbr_type.array.empty())
2088 			for (uint32_t j = 0; j < uint32_t(mbr_type.array.size()); j++)
2089 				location_count *= to_array_size_literal(mbr_type, j);
2090 
2091 		location += location_count;
2092 	}
2093 
2094 	return location;
2095 }
2096 
2097 void CompilerMSL::add_composite_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2098                                                                    SPIRType &ib_type, SPIRVariable &var,
2099                                                                    uint32_t mbr_idx, InterfaceBlockMeta &meta)
2100 {
2101 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2102 	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2103 
2104 	BuiltIn builtin = BuiltInMax;
2105 	bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
2106 	bool is_flat =
2107 	    has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
2108 	bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
2109 	                        has_decoration(var.self, DecorationNoPerspective);
2110 	bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
2111 	                   has_decoration(var.self, DecorationCentroid);
2112 	bool is_sample =
2113 	    has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
2114 
2115 	uint32_t mbr_type_id = var_type.member_types[mbr_idx];
2116 	auto &mbr_type = get<SPIRType>(mbr_type_id);
2117 	uint32_t elem_cnt = 0;
2118 
2119 	if (is_matrix(mbr_type))
2120 	{
2121 		if (is_array(mbr_type))
2122 			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
2123 
2124 		elem_cnt = mbr_type.columns;
2125 	}
2126 	else if (is_array(mbr_type))
2127 	{
2128 		if (mbr_type.array.size() != 1)
2129 			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
2130 
2131 		elem_cnt = to_array_size_literal(mbr_type);
2132 	}
2133 
2134 	auto *usable_type = &mbr_type;
2135 	if (usable_type->pointer)
2136 		usable_type = &get<SPIRType>(usable_type->parent_type);
2137 	while (is_array(*usable_type) || is_matrix(*usable_type))
2138 		usable_type = &get<SPIRType>(usable_type->parent_type);
2139 
2140 	bool flatten_from_ib_var = false;
2141 	string flatten_from_ib_mbr_name;
2142 
2143 	if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
2144 	{
2145 		// Also declare [[clip_distance]] attribute here.
2146 		uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
2147 		ib_type.member_types.push_back(mbr_type_id);
2148 		set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
2149 
2150 		flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
2151 		set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);
2152 
2153 		// When we flatten, we flatten directly from the "out" struct,
2154 		// not from a function variable.
2155 		flatten_from_ib_var = true;
2156 
2157 		if (!msl_options.enable_clip_distance_user_varying)
2158 			return;
2159 	}
2160 
2161 	for (uint32_t i = 0; i < elem_cnt; i++)
2162 	{
2163 		// Add a reference to the variable type to the interface struct.
2164 		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2165 		ib_type.member_types.push_back(usable_type->self);
2166 
2167 		// Give the member a name
2168 		string mbr_name = ensure_valid_name(join(to_qualified_member_name(var_type, mbr_idx), "_", i), "m");
2169 		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2170 
2171 		if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
2172 		{
2173 			uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation) + i;
2174 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2175 			mark_location_as_used_by_shader(locn, *usable_type, storage);
2176 		}
2177 		else if (has_decoration(var.self, DecorationLocation))
2178 		{
2179 			uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array) + i;
2180 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2181 			mark_location_as_used_by_shader(locn, *usable_type, storage);
2182 		}
2183 		else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2184 		{
2185 			uint32_t locn = inputs_by_builtin[builtin].location + i;
2186 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2187 			mark_location_as_used_by_shader(locn, *usable_type, storage);
2188 		}
2189 		else if (is_builtin && builtin == BuiltInClipDistance)
2190 		{
2191 			// Declare the ClipDistance as [[user(clipN)]].
2192 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
2193 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, i);
2194 		}
2195 
2196 		if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
2197 			SPIRV_CROSS_THROW("DecorationComponent on matrices and arrays makes little sense.");
2198 
2199 		// Copy interpolation decorations if needed
2200 		if (is_flat)
2201 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2202 		if (is_noperspective)
2203 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2204 		if (is_centroid)
2205 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2206 		if (is_sample)
2207 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2208 
2209 		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2210 		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
2211 
2212 		// Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
2213 		if (!meta.strip_array)
2214 		{
2215 			switch (storage)
2216 			{
2217 			case StorageClassInput:
2218 				entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
2219 					statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
2220 					          ".", mbr_name, ";");
2221 				});
2222 				break;
2223 
2224 			case StorageClassOutput:
2225 				entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
2226 					if (flatten_from_ib_var)
2227 					{
2228 						statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
2229 						          "];");
2230 					}
2231 					else
2232 					{
2233 						statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), ".",
2234 						          to_member_name(var_type, mbr_idx), "[", i, "];");
2235 					}
2236 				});
2237 				break;
2238 
2239 			default:
2240 				break;
2241 			}
2242 		}
2243 	}
2244 }
2245 
2246 void CompilerMSL::add_plain_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2247                                                                SPIRType &ib_type, SPIRVariable &var, uint32_t mbr_idx,
2248                                                                InterfaceBlockMeta &meta)
2249 {
2250 	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2251 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2252 
2253 	BuiltIn builtin = BuiltInMax;
2254 	bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
2255 	bool is_flat =
2256 	    has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
2257 	bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
2258 	                        has_decoration(var.self, DecorationNoPerspective);
2259 	bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
2260 	                   has_decoration(var.self, DecorationCentroid);
2261 	bool is_sample =
2262 	    has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
2263 
2264 	// Add a reference to the member to the interface struct.
2265 	uint32_t mbr_type_id = var_type.member_types[mbr_idx];
2266 	uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2267 	mbr_type_id = ensure_correct_builtin_type(mbr_type_id, builtin);
2268 	var_type.member_types[mbr_idx] = mbr_type_id;
2269 	ib_type.member_types.push_back(mbr_type_id);
2270 
2271 	// Give the member a name
2272 	string mbr_name = ensure_valid_name(to_qualified_member_name(var_type, mbr_idx), "m");
2273 	set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2274 
2275 	// Update the original variable reference to include the structure reference
2276 	string qual_var_name = ib_var_ref + "." + mbr_name;
2277 
2278 	if (is_builtin && !meta.strip_array)
2279 	{
2280 		// For the builtin gl_PerVertex, we cannot treat it as a block anyway,
2281 		// so redirect to qualified name.
2282 		set_member_qualified_name(var_type.self, mbr_idx, qual_var_name);
2283 	}
2284 	else if (!meta.strip_array)
2285 	{
2286 		// Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
2287 		switch (storage)
2288 		{
2289 		case StorageClassInput:
2290 			entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
2291 				statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), " = ", qual_var_name, ";");
2292 			});
2293 			break;
2294 
2295 		case StorageClassOutput:
2296 			entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
2297 				statement(qual_var_name, " = ", to_name(var.self), ".", to_member_name(var_type, mbr_idx), ";");
2298 			});
2299 			break;
2300 
2301 		default:
2302 			break;
2303 		}
2304 	}
2305 
2306 	// Copy the variable location from the original variable to the member
2307 	if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
2308 	{
2309 		uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation);
2310 		if (storage == StorageClassInput)
2311 		{
2312 			mbr_type_id = ensure_correct_input_type(mbr_type_id, locn);
2313 			var_type.member_types[mbr_idx] = mbr_type_id;
2314 			ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2315 		}
2316 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2317 		mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2318 	}
2319 	else if (has_decoration(var.self, DecorationLocation))
2320 	{
2321 		// The block itself might have a location and in this case, all members of the block
2322 		// receive incrementing locations.
2323 		uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array);
2324 		if (storage == StorageClassInput)
2325 		{
2326 			mbr_type_id = ensure_correct_input_type(mbr_type_id, locn);
2327 			var_type.member_types[mbr_idx] = mbr_type_id;
2328 			ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2329 		}
2330 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2331 		mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2332 	}
2333 	else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2334 	{
2335 		uint32_t locn = 0;
2336 		auto builtin_itr = inputs_by_builtin.find(builtin);
2337 		if (builtin_itr != end(inputs_by_builtin))
2338 			locn = builtin_itr->second.location;
2339 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2340 		mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2341 	}
2342 
2343 	// Copy the component location, if present.
2344 	if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
2345 	{
2346 		uint32_t comp = get_member_decoration(var_type.self, mbr_idx, DecorationComponent);
2347 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
2348 	}
2349 
2350 	// Mark the member as builtin if needed
2351 	if (is_builtin)
2352 	{
2353 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2354 		if (builtin == BuiltInPosition && storage == StorageClassOutput)
2355 			qual_pos_var_name = qual_var_name;
2356 	}
2357 
2358 	// Copy interpolation decorations if needed
2359 	if (is_flat)
2360 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2361 	if (is_noperspective)
2362 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2363 	if (is_centroid)
2364 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2365 	if (is_sample)
2366 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2367 
2368 	set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2369 	set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
2370 }
2371 
2372 // In Metal, the tessellation levels are stored as tightly packed half-precision floating point values.
2373 // But, stage-in attribute offsets and strides must be multiples of four, so we can't pass the levels
2374 // individually. Therefore, we must pass them as vectors. Triangles get a single float4, with the outer
2375 // levels in 'xyz' and the inner level in 'w'. Quads get a float4 containing the outer levels and a
2376 // float2 containing the inner levels.
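// For example, for a quad patch the levels arrive roughly as (an illustrative sketch;
// the attribute indices N and M depend on the input layout):
//
//     float4 gl_TessLevelOuter [[attribute(N)]];
//     float2 gl_TessLevelInner [[attribute(M)]];
//
// while a triangle patch packs all four values into a single float4 member, with the
// outer levels in .xyz and the inner level in .w.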
2377 void CompilerMSL::add_tess_level_input_to_interface_block(const std::string &ib_var_ref, SPIRType &ib_type,
2378                                                           SPIRVariable &var)
2379 {
2380 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2381 	auto &var_type = get_variable_element_type(var);
2382 
2383 	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
2384 
2385 	// Force the variable to have the proper name.
2386 	set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));
2387 
2388 	if (get_entry_point().flags.get(ExecutionModeTriangles))
2389 	{
2390 		// Triangles are tricky, because we want only one member in the struct.
2391 
2392 		// We need to declare the variable early and at entry-point scope.
2393 		entry_func.add_local_variable(var.self);
2394 		vars_needing_early_declaration.push_back(var.self);
2395 
2396 		string mbr_name = "gl_TessLevel";
2397 
2398 		// If we already added the other one, we can skip this step.
2399 		if (!added_builtin_tess_level)
2400 		{
2401 			// Add a reference to the variable type to the interface struct.
2402 			uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2403 
2404 			uint32_t type_id = build_extended_vector_type(var_type.self, 4);
2405 
2406 			ib_type.member_types.push_back(type_id);
2407 
2408 			// Give the member a name
2409 			set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2410 
2411 			// There is no qualified alias since we need to flatten the internal array on return.
2412 			if (get_decoration_bitset(var.self).get(DecorationLocation))
2413 			{
2414 				uint32_t locn = get_decoration(var.self, DecorationLocation);
2415 				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2416 				mark_location_as_used_by_shader(locn, var_type, StorageClassInput);
2417 			}
2418 			else if (inputs_by_builtin.count(builtin))
2419 			{
2420 				uint32_t locn = inputs_by_builtin[builtin].location;
2421 				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2422 				mark_location_as_used_by_shader(locn, var_type, StorageClassInput);
2423 			}
2424 
2425 			added_builtin_tess_level = true;
2426 		}
2427 
2428 		switch (builtin)
2429 		{
2430 		case BuiltInTessLevelOuter:
2431 			entry_func.fixup_hooks_in.push_back([=, &var]() {
2432 				statement(to_name(var.self), "[0] = ", ib_var_ref, ".", mbr_name, ".x;");
2433 				statement(to_name(var.self), "[1] = ", ib_var_ref, ".", mbr_name, ".y;");
2434 				statement(to_name(var.self), "[2] = ", ib_var_ref, ".", mbr_name, ".z;");
2435 			});
2436 			break;
2437 
2438 		case BuiltInTessLevelInner:
2439 			entry_func.fixup_hooks_in.push_back(
2440 			    [=, &var]() { statement(to_name(var.self), "[0] = ", ib_var_ref, ".", mbr_name, ".w;"); });
2441 			break;
2442 
2443 		default:
2444 			assert(false);
2445 			break;
2446 		}
2447 	}
2448 	else
2449 	{
2450 		// Add a reference to the variable type to the interface struct.
2451 		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2452 
2453 		uint32_t type_id = build_extended_vector_type(var_type.self, builtin == BuiltInTessLevelOuter ? 4 : 2);
2454 		// Change the type of the variable, too.
2455 		uint32_t ptr_type_id = ir.increase_bound_by(1);
2456 		auto &new_var_type = set<SPIRType>(ptr_type_id, get<SPIRType>(type_id));
2457 		new_var_type.pointer = true;
2458 		new_var_type.storage = StorageClassInput;
2459 		new_var_type.parent_type = type_id;
2460 		var.basetype = ptr_type_id;
2461 
2462 		ib_type.member_types.push_back(type_id);
2463 
2464 		// Give the member a name
2465 		string mbr_name = to_expression(var.self);
2466 		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2467 
2468 		// Since vectors can be indexed like arrays, there is no need to unpack this. We can
2469 		// just refer to the vector directly. So give it a qualified alias.
2470 		string qual_var_name = ib_var_ref + "." + mbr_name;
2471 		ir.meta[var.self].decoration.qualified_alias = qual_var_name;
2472 
2473 		if (get_decoration_bitset(var.self).get(DecorationLocation))
2474 		{
2475 			uint32_t locn = get_decoration(var.self, DecorationLocation);
2476 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2477 			mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
2478 		}
2479 		else if (inputs_by_builtin.count(builtin))
2480 		{
2481 			uint32_t locn = inputs_by_builtin[builtin].location;
2482 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2483 			mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
2484 		}
2485 	}
2486 }
2487 
2488 void CompilerMSL::add_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, SPIRType &ib_type,
2489                                                   SPIRVariable &var, InterfaceBlockMeta &meta)
2490 {
2491 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2492 	// Tessellation control I/O variables and tessellation evaluation per-point inputs are
2493 	// usually declared as arrays. In these cases, we want to add the element type to the
2494 	// interface block, since in Metal it's the interface block itself which is arrayed.
2495 	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2496 	bool is_builtin = is_builtin_variable(var);
2497 	auto builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
2498 
2499 	if (var_type.basetype == SPIRType::Struct)
2500 	{
2501 		if (!is_builtin_type(var_type) && (!capture_output_to_buffer || storage == StorageClassInput) &&
2502 		    !meta.strip_array)
2503 		{
2504 			// For I/O blocks or structs, we will need to pass the block itself around
2505 			// to functions if they are used globally in leaf functions.
2506 			// Rather than passing down member by member,
2507 			// we unflatten I/O blocks while running the shader,
2508 			// and pass the actual struct type down to leaf functions.
2509 			// We then unflatten inputs, and flatten outputs in the "fixup" stages.
2510 			entry_func.add_local_variable(var.self);
2511 			vars_needing_early_declaration.push_back(var.self);
2512 		}
2513 
2514 		if (capture_output_to_buffer && storage != StorageClassInput && !has_decoration(var_type.self, DecorationBlock))
2515 		{
2516 			// In Metal tessellation shaders, the interface block itself is arrayed. This makes things
2517 			// very complicated, since stage-in structures in MSL don't support nested structures.
2518 			// Luckily, for stage-out when capturing output, we can avoid this and just add
2519 			// composite members directly, because the stage-out structure is stored to a buffer,
2520 			// not returned.
2521 			add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
2522 		}
2523 		else
2524 		{
2525 			// Flatten the struct members into the interface struct
2526 			for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
2527 			{
2528 				builtin = BuiltInMax;
2529 				is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
2530 				auto &mbr_type = get<SPIRType>(var_type.member_types[mbr_idx]);
2531 
2532 				if (!is_builtin || has_active_builtin(builtin, storage))
2533 				{
2534 					bool is_composite_type = is_matrix(mbr_type) || is_array(mbr_type);
2535 					bool attribute_load_store =
2536 					    storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
2537 					bool storage_is_stage_io =
2538 					    (storage == StorageClassInput && !(get_execution_model() == ExecutionModelTessellationControl &&
2539 					                                       msl_options.multi_patch_workgroup)) ||
2540 					    storage == StorageClassOutput;
2541 
2542 					// ClipDistance always needs to be declared as user attributes.
2543 					if (builtin == BuiltInClipDistance)
2544 						is_builtin = false;
2545 
2546 					if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
2547 					{
2548 						add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx,
2549 						                                                 meta);
2550 					}
2551 					else
2552 					{
2553 						add_plain_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx, meta);
2554 					}
2555 				}
2556 			}
2557 		}
2558 	}
2559 	else if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput &&
2560 	         !meta.strip_array && is_builtin && (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner))
2561 	{
2562 		add_tess_level_input_to_interface_block(ib_var_ref, ib_type, var);
2563 	}
2564 	else if (var_type.basetype == SPIRType::Boolean || var_type.basetype == SPIRType::Char ||
2565 	         type_is_integral(var_type) || type_is_floating_point(var_type))
2566 	{
2567 		if (!is_builtin || has_active_builtin(builtin, storage))
2568 		{
2569 			bool is_composite_type = is_matrix(var_type) || is_array(var_type);
2570 			bool storage_is_stage_io =
2571 			    (storage == StorageClassInput &&
2572 			     !(get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup)) ||
2573 			    (storage == StorageClassOutput && !capture_output_to_buffer);
2574 			bool attribute_load_store = storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
2575 
2576 			// ClipDistance always needs to be declared as user attributes.
2577 			if (builtin == BuiltInClipDistance)
2578 				is_builtin = false;
2579 
2580 			// MSL does not allow matrices or arrays in input or output variables, so we need to handle them specially.
2581 			if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
2582 			{
2583 				add_composite_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
2584 			}
2585 			else
2586 			{
2587 				add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
2588 			}
2589 		}
2590 	}
2591 }
2592 
2593 // Fix up the mapping of variables to interface member indices, which is used to compile access chains
2594 // for per-vertex variables in a tessellation control shader.
2595 void CompilerMSL::fix_up_interface_member_indices(StorageClass storage, uint32_t ib_type_id)
2596 {
2597 	// Only needed for tessellation shaders.
2598 	// Need to redirect interface indices back to variables themselves.
2599 	// For structs, each member of the struct needs a separate instance.
2600 	if (get_execution_model() != ExecutionModelTessellationControl &&
2601 	    !(get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput))
2602 		return;
2603 
2604 	auto mbr_cnt = uint32_t(ir.meta[ib_type_id].members.size());
2605 	for (uint32_t i = 0; i < mbr_cnt; i++)
2606 	{
2607 		uint32_t var_id = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceOrigID);
2608 		if (!var_id)
2609 			continue;
2610 		auto &var = get<SPIRVariable>(var_id);
2611 
2612 		auto &type = get_variable_element_type(var);
2613 		if (storage == StorageClassInput && type.basetype == SPIRType::Struct)
2614 		{
2615 			uint32_t mbr_idx = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceMemberIndex);
2616 
2617 			// Only set the lowest InterfaceMemberIndex for each variable member.
2618 			// IB struct members will be emitted in-order w.r.t. interface member index.
2619 			if (!has_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex))
2620 				set_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, i);
2621 		}
2622 		else
2623 		{
2624 			// Only set the lowest InterfaceMemberIndex for each variable.
2625 			// IB struct members will be emitted in-order w.r.t. interface member index.
2626 			if (!has_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex))
2627 				set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i);
2628 		}
2629 	}
2630 }
2631 
2632 // Add an interface structure for the type of storage, which is either StorageClassInput or StorageClassOutput.
2633 // Returns the ID of the newly added variable, or zero if no variable was added.
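// For a simple vertex shader, the input structure assembled here is emitted roughly as
// (an illustrative sketch; member names and attribute indices depend on the shader):
//
//     struct main0_in
//     {
//         float4 position [[attribute(0)]];
//         float2 uv       [[attribute(1)]];
//     };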
2634 uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch)
2635 {
2636 	// Accumulate the variables that should appear in the interface struct.
2637 	SmallVector<SPIRVariable *> vars;
2638 	bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader();
2639 	bool has_seen_barycentric = false;
2640 
2641 	InterfaceBlockMeta meta;
2642 
2643 	// Varying interfaces between stages which use the "user()" attribute can be dealt with
2644 	// without explicit packing and unpacking of components. For any variables which link against the runtime
2645 	// in some way (vertex attributes, fragment output, etc.), we need to pack and unpack components explicitly.
2646 	bool pack_components =
2647 	    (storage == StorageClassInput && get_execution_model() == ExecutionModelVertex) ||
2648 	    (storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment) ||
2649 	    (storage == StorageClassOutput && get_execution_model() == ExecutionModelVertex && capture_output_to_buffer);
2650 
2651 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
2652 		if (var.storage != storage)
2653 			return;
2654 
2655 		auto &type = this->get<SPIRType>(var.basetype);
2656 
2657 		bool is_builtin = is_builtin_variable(var);
2658 		auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
2659 		uint32_t location = get_decoration(var_id, DecorationLocation);
2660 
2661 		// These builtins are part of the stage in/out structs.
2662 		bool is_interface_block_builtin =
2663 		    (bi_type == BuiltInPosition || bi_type == BuiltInPointSize || bi_type == BuiltInClipDistance ||
2664 		     bi_type == BuiltInCullDistance || bi_type == BuiltInLayer || bi_type == BuiltInViewportIndex ||
2665 		     bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV || bi_type == BuiltInFragDepth ||
2666 		     bi_type == BuiltInFragStencilRefEXT || bi_type == BuiltInSampleMask) ||
2667 		    (get_execution_model() == ExecutionModelTessellationEvaluation &&
2668 		     (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner));
2669 
2670 		bool is_active = interface_variable_exists_in_entry_point(var.self);
2671 		if (is_builtin && is_active)
2672 		{
2673 			// Only emit the builtin if it's active in this entry point. Interface variable list might lie.
2674 			is_active = has_active_builtin(bi_type, storage);
2675 		}
2676 
2677 		bool filter_patch_decoration = (has_decoration(var_id, DecorationPatch) || is_patch_block(type)) == patch;
2678 
2679 		bool hidden = is_hidden_variable(var, incl_builtins);
2680 
2681 		// ClipDistance is never hidden, we need to emulate it when used as an input.
2682 		if (bi_type == BuiltInClipDistance)
2683 			hidden = false;
2684 
		// It's not enough to simply avoid marking fragment outputs if the pipeline won't
		// accept them. We can't put them in the struct at all; otherwise, the compiler
		// complains that the outputs weren't explicitly marked.
2688 		if (get_execution_model() == ExecutionModelFragment && storage == StorageClassOutput && !patch &&
2689 		    ((is_builtin && ((bi_type == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
2690 		                     (bi_type == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))) ||
2691 		     (!is_builtin && !(msl_options.enable_frag_output_mask & (1 << location)))))
2692 		{
2693 			hidden = true;
2694 			disabled_frag_outputs.push_back(var_id);
2695 			// If a builtin, force it to have the proper name.
2696 			if (is_builtin)
2697 				set_name(var_id, builtin_to_glsl(bi_type, StorageClassFunction));
2698 		}
2699 
2700 		// Barycentric inputs must be emitted in stage-in, because they can have interpolation arguments.
2701 		if (is_active && (bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV))
2702 		{
2703 			if (has_seen_barycentric)
2704 				SPIRV_CROSS_THROW("Cannot declare both BaryCoordNV and BaryCoordNoPerspNV in same shader in MSL.");
2705 			has_seen_barycentric = true;
2706 			hidden = false;
2707 		}
2708 
2709 		if (is_active && !hidden && type.pointer && filter_patch_decoration &&
2710 		    (!is_builtin || is_interface_block_builtin))
2711 		{
2712 			vars.push_back(&var);
2713 
2714 			if (!is_builtin)
2715 			{
2716 				// Need to deal specially with DecorationComponent.
				// Multiple variables can alias the same Location, so we try to make sure each location is declared only once.
2718 				// We will swizzle data in and out to make this work.
2719 				// We only need to consider plain variables here, not composites.
2720 				// This is only relevant for vertex inputs and fragment outputs.
2721 				// Technically tessellation as well, but it is too complicated to support.
2722 				uint32_t component = get_decoration(var_id, DecorationComponent);
2723 				if (component != 0)
2724 				{
2725 					if (is_tessellation_shader())
2726 						SPIRV_CROSS_THROW("Component decoration is not supported in tessellation shaders.");
2727 					else if (pack_components)
2728 					{
2729 						auto &location_meta = meta.location_meta[location];
2730 						location_meta.num_components = std::max(location_meta.num_components, component + type.vecsize);
2731 					}
2732 				}
2733 			}
2734 		}
2735 	});
2736 
2737 	// If no variables qualify, leave.
2738 	// For patch input in a tessellation evaluation shader, the per-vertex stage inputs
2739 	// are included in a special patch control point array.
2740 	if (vars.empty() && !(storage == StorageClassInput && patch && stage_in_var_id))
2741 		return 0;
2742 
2743 	// Add a new typed variable for this interface structure.
2744 	// The initializer expression is allocated here, but populated when the function
	// declaration is emitted, because it is cleared after each compilation pass.
2746 	uint32_t next_id = ir.increase_bound_by(3);
2747 	uint32_t ib_type_id = next_id++;
2748 	auto &ib_type = set<SPIRType>(ib_type_id);
2749 	ib_type.basetype = SPIRType::Struct;
2750 	ib_type.storage = storage;
2751 	set_decoration(ib_type_id, DecorationBlock);
2752 
2753 	uint32_t ib_var_id = next_id++;
2754 	auto &var = set<SPIRVariable>(ib_var_id, ib_type_id, storage, 0);
2755 	var.initializer = next_id++;
2756 
2757 	string ib_var_ref;
2758 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2759 	switch (storage)
2760 	{
2761 	case StorageClassInput:
2762 		ib_var_ref = patch ? patch_stage_in_var_name : stage_in_var_name;
2763 		if (get_execution_model() == ExecutionModelTessellationControl)
2764 		{
2765 			// Add a hook to populate the shared workgroup memory containing the gl_in array.
2766 			entry_func.fixup_hooks_in.push_back([=]() {
2767 				// Can't use PatchVertices, PrimitiveId, or InvocationId yet; the hooks for those may not have run yet.
2768 				if (msl_options.multi_patch_workgroup)
2769 				{
2770 					// n.b. builtin_invocation_id_id here is the dispatch global invocation ID,
2771 					// not the TC invocation ID.
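				// A sketch of the fixup emitted here, assuming the usual generated names
				// ("main0", "spvIn") and 4 output vertices (all illustrative, not fixed):
				//   device main0_in* gl_in = &spvIn[min(gl_GlobalInvocationID.x / 4, spvIndirectParams[1] - 1) * spvIndirectParams[0]];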
2772 					statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_in = &",
2773 					          input_buffer_var_name, "[min(", to_expression(builtin_invocation_id_id), ".x / ",
2774 					          get_entry_point().output_vertices,
2775 					          ", spvIndirectParams[1] - 1) * spvIndirectParams[0]];");
2776 				}
2777 				else
2778 				{
2779 					// It's safe to use InvocationId here because it's directly mapped to a
2780 					// Metal builtin, and therefore doesn't need a hook.
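				// The emitted fixup looks roughly like this (a sketch; "gl_in", "in" and the
				// vertex count of 4 are illustrative assumptions):
				//   if (gl_InvocationID < spvIndirectParams[0])
				//       gl_in[gl_InvocationID] = in;
				//   threadgroup_barrier(mem_flags::mem_threadgroup);
				//   if (gl_InvocationID >= 4)
				//       return;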
2781 					statement("if (", to_expression(builtin_invocation_id_id), " < spvIndirectParams[0])");
2782 					statement("    ", input_wg_var_name, "[", to_expression(builtin_invocation_id_id),
2783 					          "] = ", ib_var_ref, ";");
2784 					statement("threadgroup_barrier(mem_flags::mem_threadgroup);");
2785 					statement("if (", to_expression(builtin_invocation_id_id),
2786 					          " >= ", get_entry_point().output_vertices, ")");
2787 					statement("    return;");
2788 				}
2789 			});
2790 		}
2791 		break;
2792 
2793 	case StorageClassOutput:
2794 	{
2795 		ib_var_ref = patch ? patch_stage_out_var_name : stage_out_var_name;
2796 
2797 		// Add the output interface struct as a local variable to the entry function.
2798 		// If the entry point should return the output struct, set the entry function
2799 		// to return the output interface struct, otherwise to return nothing.
2800 		// Indicate the output var requires early initialization.
2801 		bool ep_should_return_output = !get_is_rasterization_disabled();
2802 		uint32_t rtn_id = ep_should_return_output ? ib_var_id : 0;
2803 		if (!capture_output_to_buffer)
2804 		{
2805 			entry_func.add_local_variable(ib_var_id);
2806 			for (auto &blk_id : entry_func.blocks)
2807 			{
2808 				auto &blk = get<SPIRBlock>(blk_id);
2809 				if (blk.terminator == SPIRBlock::Return)
2810 					blk.return_value = rtn_id;
2811 			}
2812 			vars_needing_early_declaration.push_back(ib_var_id);
2813 		}
2814 		else
2815 		{
2816 			switch (get_execution_model())
2817 			{
2818 			case ExecutionModelVertex:
2819 			case ExecutionModelTessellationEvaluation:
2820 				// Instead of declaring a struct variable to hold the output and then
2821 				// copying that to the output buffer, we'll declare the output variable
2822 				// as a reference to the final output element in the buffer. Then we can
2823 				// avoid the extra copy.
2824 				entry_func.fixup_hooks_in.push_back([=]() {
2825 					if (stage_out_var_id)
2826 					{
2827 						// The first member of the indirect buffer is always the number of vertices
2828 						// to draw.
						// We zero-base the InstanceID & VertexID variables for HLSL emulation elsewhere, so don't do it twice.
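						// A sketch of the reference binding emitted by the first branch below,
						// with illustrative names ("main0_out", "spvOut", "spvStageInputSize"):
						//   device main0_out& out = spvOut[gl_GlobalInvocationID.y * spvStageInputSize.x + gl_GlobalInvocationID.x];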
2830 						if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
2831 						{
2832 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
2833 							          " = ", output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
2834 							          ".y * ", to_expression(builtin_stage_input_size_id), ".x + ",
2835 							          to_expression(builtin_invocation_id_id), ".x];");
2836 						}
2837 						else if (msl_options.enable_base_index_zero)
2838 						{
2839 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
2840 							          " = ", output_buffer_var_name, "[", to_expression(builtin_instance_idx_id),
2841 							          " * spvIndirectParams[0] + ", to_expression(builtin_vertex_idx_id), "];");
2842 						}
2843 						else
2844 						{
2845 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
2846 							          " = ", output_buffer_var_name, "[(", to_expression(builtin_instance_idx_id),
2847 							          " - ", to_expression(builtin_base_instance_id), ") * spvIndirectParams[0] + ",
2848 							          to_expression(builtin_vertex_idx_id), " - ",
2849 							          to_expression(builtin_base_vertex_id), "];");
2850 						}
2851 					}
2852 				});
2853 				break;
2854 			case ExecutionModelTessellationControl:
2855 				if (msl_options.multi_patch_workgroup)
2856 				{
2857 					// We cannot use PrimitiveId here, because the hook may not have run yet.
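				// What gets emitted below, roughly (names and the vertex count of 4 are
				// illustrative assumptions):
				//   patch:      device main0_patchOut& patchOut = spvPatchOut[gl_GlobalInvocationID.x / 4];
				//   per-vertex: device main0_out* gl_out = &spvOut[gl_GlobalInvocationID.x - gl_GlobalInvocationID.x % 4];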
2858 					if (patch)
2859 					{
2860 						entry_func.fixup_hooks_in.push_back([=]() {
2861 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
2862 							          " = ", patch_output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
2863 							          ".x / ", get_entry_point().output_vertices, "];");
2864 						});
2865 					}
2866 					else
2867 					{
2868 						entry_func.fixup_hooks_in.push_back([=]() {
2869 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
2870 							          output_buffer_var_name, "[", to_expression(builtin_invocation_id_id), ".x - ",
2871 							          to_expression(builtin_invocation_id_id), ".x % ",
2872 							          get_entry_point().output_vertices, "];");
2873 						});
2874 					}
2875 				}
2876 				else
2877 				{
2878 					if (patch)
2879 					{
2880 						entry_func.fixup_hooks_in.push_back([=]() {
2881 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
2882 							          " = ", patch_output_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
2883 							          "];");
2884 						});
2885 					}
2886 					else
2887 					{
2888 						entry_func.fixup_hooks_in.push_back([=]() {
2889 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
2890 							          output_buffer_var_name, "[", to_expression(builtin_primitive_id_id), " * ",
2891 							          get_entry_point().output_vertices, "];");
2892 						});
2893 					}
2894 				}
2895 				break;
2896 			default:
2897 				break;
2898 			}
2899 		}
2900 		break;
2901 	}
2902 
2903 	default:
2904 		break;
2905 	}
2906 
2907 	set_name(ib_type_id, to_name(ir.default_entry_point) + "_" + ib_var_ref);
2908 	set_name(ib_var_id, ib_var_ref);
2909 
2910 	for (auto *p_var : vars)
2911 	{
2912 		bool strip_array =
2913 		    (get_execution_model() == ExecutionModelTessellationControl ||
2914 		     (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput)) &&
2915 		    !patch;
2916 
2917 		meta.strip_array = strip_array;
2918 		add_variable_to_interface_block(storage, ib_var_ref, ib_type, *p_var, meta);
2919 	}
2920 
2921 	if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup &&
2922 	    storage == StorageClassInput)
2923 	{
2924 		// For tessellation control inputs, add all outputs from the vertex shader to ensure
2925 		// the struct containing them is the correct size and layout.
2926 		for (auto &input : inputs_by_location)
2927 		{
2928 			if (is_msl_shader_input_used(input.first))
2929 				continue;
2930 
2931 			// Create a fake variable to put at the location.
2932 			uint32_t offset = ir.increase_bound_by(4);
2933 			uint32_t type_id = offset;
2934 			uint32_t array_type_id = offset + 1;
2935 			uint32_t ptr_type_id = offset + 2;
2936 			uint32_t var_id = offset + 3;
2937 
2938 			SPIRType type;
2939 			switch (input.second.format)
2940 			{
2941 			case MSL_SHADER_INPUT_FORMAT_UINT16:
2942 			case MSL_SHADER_INPUT_FORMAT_ANY16:
2943 				type.basetype = SPIRType::UShort;
2944 				type.width = 16;
2945 				break;
2946 			case MSL_SHADER_INPUT_FORMAT_ANY32:
2947 			default:
2948 				type.basetype = SPIRType::UInt;
2949 				type.width = 32;
2950 				break;
2951 			}
2952 			type.vecsize = input.second.vecsize;
2953 			set<SPIRType>(type_id, type);
2954 
2955 			type.array.push_back(0);
2956 			type.array_size_literal.push_back(true);
2957 			type.parent_type = type_id;
2958 			set<SPIRType>(array_type_id, type);
2959 
2960 			type.pointer = true;
2961 			type.parent_type = array_type_id;
2962 			type.storage = storage;
2963 			auto &ptr_type = set<SPIRType>(ptr_type_id, type);
2964 			ptr_type.self = array_type_id;
2965 
2966 			auto &fake_var = set<SPIRVariable>(var_id, ptr_type_id, storage);
2967 			set_decoration(var_id, DecorationLocation, input.first);
2968 			meta.strip_array = true;
2969 			add_variable_to_interface_block(storage, ib_var_ref, ib_type, fake_var, meta);
2970 		}
2971 	}
2972 
2973 	// Sort the members of the structure by their locations.
2974 	MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Location);
2975 	member_sorter.sort();
2976 
2977 	// The member indices were saved to the original variables, but after the members
2978 	// were sorted, those indices are now likely incorrect. Fix those up now.
2979 	if (!patch)
2980 		fix_up_interface_member_indices(storage, ib_type_id);
2981 
2982 	// For patch inputs, add one more member, holding the array of control point data.
2983 	if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput && patch &&
2984 	    stage_in_var_id)
2985 	{
2986 		uint32_t pcp_type_id = ir.increase_bound_by(1);
2987 		auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
2988 		pcp_type.basetype = SPIRType::ControlPointArray;
2989 		pcp_type.parent_type = pcp_type.type_alias = get_stage_in_struct_type().self;
2990 		pcp_type.storage = storage;
2991 		ir.meta[pcp_type_id] = ir.meta[ib_type.self];
2992 		uint32_t mbr_idx = uint32_t(ib_type.member_types.size());
2993 		ib_type.member_types.push_back(pcp_type_id);
2994 		set_member_name(ib_type.self, mbr_idx, "gl_in");
2995 	}
2996 
2997 	return ib_var_id;
2998 }
2999 
uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageClass storage)
3001 {
3002 	if (!ib_var_id)
3003 		return 0;
3004 
3005 	uint32_t ib_ptr_var_id;
3006 	uint32_t next_id = ir.increase_bound_by(3);
3007 	auto &ib_type = expression_type(ib_var_id);
3008 	if (get_execution_model() == ExecutionModelTessellationControl)
3009 	{
3010 		// Tessellation control per-vertex I/O is presented as an array, so we must
3011 		// do the same with our struct here.
3012 		uint32_t ib_ptr_type_id = next_id++;
3013 		auto &ib_ptr_type = set<SPIRType>(ib_ptr_type_id, ib_type);
3014 		ib_ptr_type.parent_type = ib_ptr_type.type_alias = ib_type.self;
3015 		ib_ptr_type.pointer = true;
3016 		ib_ptr_type.storage =
3017 		    storage == StorageClassInput ?
3018 		        (msl_options.multi_patch_workgroup ? StorageClassStorageBuffer : StorageClassWorkgroup) :
3019 		        StorageClassStorageBuffer;
3020 		ir.meta[ib_ptr_type_id] = ir.meta[ib_type.self];
3021 		// To ensure that get_variable_data_type() doesn't strip off the pointer,
3022 		// which we need, use another pointer.
3023 		uint32_t ib_ptr_ptr_type_id = next_id++;
3024 		auto &ib_ptr_ptr_type = set<SPIRType>(ib_ptr_ptr_type_id, ib_ptr_type);
3025 		ib_ptr_ptr_type.parent_type = ib_ptr_type_id;
3026 		ib_ptr_ptr_type.type_alias = ib_type.self;
3027 		ib_ptr_ptr_type.storage = StorageClassFunction;
3028 		ir.meta[ib_ptr_ptr_type_id] = ir.meta[ib_type.self];
3029 
3030 		ib_ptr_var_id = next_id;
3031 		set<SPIRVariable>(ib_ptr_var_id, ib_ptr_ptr_type_id, StorageClassFunction, 0);
3032 		set_name(ib_ptr_var_id, storage == StorageClassInput ? "gl_in" : "gl_out");
3033 	}
3034 	else
3035 	{
3036 		// Tessellation evaluation per-vertex inputs are also presented as arrays.
3037 		// But, in Metal, this array uses a very special type, 'patch_control_point<T>',
3038 		// which is a container that can be used to access the control point data.
3039 		// To represent this, a special 'ControlPointArray' type has been added to the
3040 		// SPIRV-Cross type system. It should only be generated by and seen in the MSL
3041 		// backend (i.e. this one).
3042 		uint32_t pcp_type_id = next_id++;
3043 		auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
3044 		pcp_type.basetype = SPIRType::ControlPointArray;
3045 		pcp_type.parent_type = pcp_type.type_alias = ib_type.self;
3046 		pcp_type.storage = storage;
3047 		ir.meta[pcp_type_id] = ir.meta[ib_type.self];
3048 
3049 		ib_ptr_var_id = next_id;
3050 		set<SPIRVariable>(ib_ptr_var_id, pcp_type_id, storage, 0);
3051 		set_name(ib_ptr_var_id, "gl_in");
3052 		ir.meta[ib_ptr_var_id].decoration.qualified_alias = join(patch_stage_in_var_name, ".gl_in");
3053 	}
3054 	return ib_ptr_var_id;
3055 }
3056 
3057 // Ensure that the type is compatible with the builtin.
3058 // If it is, simply return the given type ID.
// Otherwise, create a new type, and return its ID.
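// For example (illustrative): a SPIR-V module declaring BuiltInLayer through an int pointer
// gets a fresh 32-bit uint type plus a matching pointer type, since Metal expects uint here.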
uint32_t CompilerMSL::ensure_correct_builtin_type(uint32_t type_id, BuiltIn builtin)
3061 {
3062 	auto &type = get<SPIRType>(type_id);
3063 
3064 	if ((builtin == BuiltInSampleMask && is_array(type)) ||
3065 	    ((builtin == BuiltInLayer || builtin == BuiltInViewportIndex || builtin == BuiltInFragStencilRefEXT) &&
3066 	     type.basetype != SPIRType::UInt))
3067 	{
3068 		uint32_t next_id = ir.increase_bound_by(type.pointer ? 2 : 1);
3069 		uint32_t base_type_id = next_id++;
3070 		auto &base_type = set<SPIRType>(base_type_id);
3071 		base_type.basetype = SPIRType::UInt;
3072 		base_type.width = 32;
3073 
3074 		if (!type.pointer)
3075 			return base_type_id;
3076 
3077 		uint32_t ptr_type_id = next_id++;
3078 		auto &ptr_type = set<SPIRType>(ptr_type_id);
3079 		ptr_type = base_type;
3080 		ptr_type.pointer = true;
3081 		ptr_type.storage = type.storage;
3082 		ptr_type.parent_type = base_type_id;
3083 		return ptr_type_id;
3084 	}
3085 
3086 	return type_id;
3087 }
3088 
3089 // Ensure that the type is compatible with the shader input.
3090 // If it is, simply return the given type ID.
3091 // Otherwise, create a new type, and return its ID.
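// For example (illustrative): an int2 input at a location the API feeds as
// MSL_SHADER_INPUT_FORMAT_UINT16 data is redeclared as uint2, and widened further
// if the API supplies more components than the shader declares.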
uint32_t CompilerMSL::ensure_correct_input_type(uint32_t type_id, uint32_t location, uint32_t num_components)
3093 {
3094 	auto &type = get<SPIRType>(type_id);
3095 
3096 	auto p_va = inputs_by_location.find(location);
3097 	if (p_va == end(inputs_by_location))
3098 	{
3099 		if (num_components > type.vecsize)
3100 			return build_extended_vector_type(type_id, num_components);
3101 		else
3102 			return type_id;
3103 	}
3104 
3105 	if (num_components == 0)
3106 		num_components = p_va->second.vecsize;
3107 
3108 	switch (p_va->second.format)
3109 	{
3110 	case MSL_SHADER_INPUT_FORMAT_UINT8:
3111 	{
3112 		switch (type.basetype)
3113 		{
3114 		case SPIRType::UByte:
3115 		case SPIRType::UShort:
3116 		case SPIRType::UInt:
3117 			if (num_components > type.vecsize)
3118 				return build_extended_vector_type(type_id, num_components);
3119 			else
3120 				return type_id;
3121 
3122 		case SPIRType::Short:
3123 			return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3124 			                                  SPIRType::UShort);
3125 		case SPIRType::Int:
3126 			return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3127 			                                  SPIRType::UInt);
3128 
3129 		default:
3130 			SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
3131 		}
3132 	}
3133 
3134 	case MSL_SHADER_INPUT_FORMAT_UINT16:
3135 	{
3136 		switch (type.basetype)
3137 		{
3138 		case SPIRType::UShort:
3139 		case SPIRType::UInt:
3140 			if (num_components > type.vecsize)
3141 				return build_extended_vector_type(type_id, num_components);
3142 			else
3143 				return type_id;
3144 
3145 		case SPIRType::Int:
3146 			return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3147 			                                  SPIRType::UInt);
3148 
3149 		default:
3150 			SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
3151 		}
3152 	}
3153 
3154 	default:
3155 		if (num_components > type.vecsize)
3156 			type_id = build_extended_vector_type(type_id, num_components);
3157 		break;
3158 	}
3159 
3160 	return type_id;
3161 }
3162 
void CompilerMSL::mark_struct_members_packed(const SPIRType &type)
3164 {
3165 	set_extended_decoration(type.self, SPIRVCrossDecorationPhysicalTypePacked);
3166 
3167 	// Problem case! Struct needs to be placed at an awkward alignment.
3168 	// Mark every member of the child struct as packed.
3169 	uint32_t mbr_cnt = uint32_t(type.member_types.size());
3170 	for (uint32_t i = 0; i < mbr_cnt; i++)
3171 	{
3172 		auto &mbr_type = get<SPIRType>(type.member_types[i]);
3173 		if (mbr_type.basetype == SPIRType::Struct)
3174 		{
3175 			// Recursively mark structs as packed.
3176 			auto *struct_type = &mbr_type;
3177 			while (!struct_type->array.empty())
3178 				struct_type = &get<SPIRType>(struct_type->parent_type);
3179 			mark_struct_members_packed(*struct_type);
3180 		}
3181 		else if (!is_scalar(mbr_type))
3182 			set_extended_member_decoration(type.self, i, SPIRVCrossDecorationPhysicalTypePacked);
3183 	}
3184 }
3185 
void CompilerMSL::mark_scalar_layout_structs(const SPIRType &type)
3187 {
3188 	uint32_t mbr_cnt = uint32_t(type.member_types.size());
3189 	for (uint32_t i = 0; i < mbr_cnt; i++)
3190 	{
3191 		auto &mbr_type = get<SPIRType>(type.member_types[i]);
3192 		if (mbr_type.basetype == SPIRType::Struct)
3193 		{
3194 			auto *struct_type = &mbr_type;
3195 			while (!struct_type->array.empty())
3196 				struct_type = &get<SPIRType>(struct_type->parent_type);
3197 
3198 			if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPhysicalTypePacked))
3199 				continue;
3200 
3201 			uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, i);
3202 			uint32_t msl_size = get_declared_struct_member_size_msl(type, i);
3203 			uint32_t spirv_offset = type_struct_member_offset(type, i);
3204 			uint32_t spirv_offset_next;
3205 			if (i + 1 < mbr_cnt)
3206 				spirv_offset_next = type_struct_member_offset(type, i + 1);
3207 			else
3208 				spirv_offset_next = spirv_offset + msl_size;
3209 
			// Both of the following cases are complicated to resolve. In scalar layout, a struct of float3 might just consume 12 bytes,
			// and the next member will be placed at offset 12.
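			// Worked example (illustrative): struct S { float3 v; } used with scalar layout
			// can place the following member at offset 12, but MSL gives S size and alignment
			// 16, so S must be marked packed for members to land on the SPIR-V offsets.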
3212 			bool struct_is_misaligned = (spirv_offset % msl_alignment) != 0;
3213 			bool struct_is_too_large = spirv_offset + msl_size > spirv_offset_next;
3214 			uint32_t array_stride = 0;
3215 			bool struct_needs_explicit_padding = false;
3216 
3217 			// Verify that if a struct is used as an array that ArrayStride matches the effective size of the struct.
3218 			if (!mbr_type.array.empty())
3219 			{
3220 				array_stride = type_struct_member_array_stride(type, i);
3221 				uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
3222 				for (uint32_t dim = 0; dim < dimensions; dim++)
3223 				{
3224 					uint32_t array_size = to_array_size_literal(mbr_type, dim);
3225 					array_stride /= max(array_size, 1u);
3226 				}
3227 
3228 				// Set expected struct size based on ArrayStride.
3229 				struct_needs_explicit_padding = true;
3230 
3231 				// If struct size is larger than array stride, we might be able to fit, if we tightly pack.
3232 				if (get_declared_struct_size_msl(*struct_type) > array_stride)
3233 					struct_is_too_large = true;
3234 			}
3235 
3236 			if (struct_is_misaligned || struct_is_too_large)
3237 				mark_struct_members_packed(*struct_type);
3238 			mark_scalar_layout_structs(*struct_type);
3239 
3240 			if (struct_needs_explicit_padding)
3241 			{
3242 				msl_size = get_declared_struct_size_msl(*struct_type, true, true);
3243 				if (array_stride < msl_size)
3244 				{
3245 					SPIRV_CROSS_THROW("Cannot express an array stride smaller than size of struct type.");
3246 				}
3247 				else
3248 				{
3249 					if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
3250 					{
3251 						if (array_stride !=
3252 						    get_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
3253 							SPIRV_CROSS_THROW(
3254 							    "A struct is used with different array strides. Cannot express this in MSL.");
3255 					}
3256 					else
3257 						set_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget, array_stride);
3258 				}
3259 			}
3260 		}
3261 	}
3262 }
3263 
3264 // Sort the members of the struct type by offset, and pack and then pad members where needed
3265 // to align MSL members with SPIR-V offsets. The struct members are iterated twice. Packing
3266 // occurs first, followed by padding, because packing a member reduces both its size and its
3267 // natural alignment, possibly requiring a padding member to be added ahead of it.
void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set<uint32_t> &aligned_structs)
3269 {
3270 	// We align structs recursively, so stop any redundant work.
3271 	ID &ib_type_id = ib_type.self;
3272 	if (aligned_structs.count(ib_type_id))
3273 		return;
3274 	aligned_structs.insert(ib_type_id);
3275 
3276 	// Sort the members of the interface structure by their offset.
3277 	// They should already be sorted per SPIR-V spec anyway.
3278 	MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Offset);
3279 	member_sorter.sort();
3280 
3281 	auto mbr_cnt = uint32_t(ib_type.member_types.size());
3282 
3283 	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
3284 	{
3285 		// Pack any dependent struct types before we pack a parent struct.
3286 		auto &mbr_type = get<SPIRType>(ib_type.member_types[mbr_idx]);
3287 		if (mbr_type.basetype == SPIRType::Struct)
3288 			align_struct(mbr_type, aligned_structs);
3289 	}
3290 
3291 	// Test the alignment of each member, and if a member should be closer to the previous
3292 	// member than the default spacing expects, it is likely that the previous member is in
3293 	// a packed format. If so, and the previous member is packable, pack it.
	// For example, this applies to any 3-element vector that is followed by a scalar.
3295 	uint32_t msl_offset = 0;
3296 	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
3297 	{
		// This checks, for the member in isolation, whether it needs some kind of type remapping to conform to SPIR-V
		// offsets, array strides and matrix strides.
3300 		ensure_member_packing_rules_msl(ib_type, mbr_idx);
3301 
3302 		// Align current offset to the current member's default alignment. If the member was packed, it will observe
3303 		// the updated alignment here.
3304 		uint32_t msl_align_mask = get_declared_struct_member_alignment_msl(ib_type, mbr_idx) - 1;
3305 		uint32_t aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
3306 
		// Fetch the member offset as declared in the SPIR-V.
3308 		uint32_t spirv_mbr_offset = get_member_decoration(ib_type_id, mbr_idx, DecorationOffset);
3309 		if (spirv_mbr_offset > aligned_msl_offset)
3310 		{
			// Since MSL and SPIR-V have slightly different struct member alignment and
			// size rules, we'll pad to standard C-packing rules with a char[] array. If the member is farther
			// away than C-packing expects, add an inert padding member before the member.
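			// E.g. if the previous member ends at MSL offset 12 but this member's SPIR-V
			// offset is 16, a "char _m1_pad[4];" style filler (illustrative name) is emitted ahead of it.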
3314 			uint32_t padding_bytes = spirv_mbr_offset - aligned_msl_offset;
3315 			set_extended_member_decoration(ib_type_id, mbr_idx, SPIRVCrossDecorationPaddingTarget, padding_bytes);
3316 
3317 			// Re-align as a sanity check that aligning post-padding matches up.
3318 			msl_offset += padding_bytes;
3319 			aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
3320 		}
3321 		else if (spirv_mbr_offset < aligned_msl_offset)
3322 		{
3323 			// This should not happen, but deal with unexpected scenarios.
3324 			// It *might* happen if a sub-struct has a larger alignment requirement in MSL than SPIR-V.
3325 			SPIRV_CROSS_THROW("Cannot represent buffer block correctly in MSL.");
3326 		}
3327 
3328 		assert(aligned_msl_offset == spirv_mbr_offset);
3329 
3330 		// Increment the current offset to be positioned immediately after the current member.
3331 		// Don't do this for the last member since it can be unsized, and it is not relevant for padding purposes here.
3332 		if (mbr_idx + 1 < mbr_cnt)
3333 			msl_offset = aligned_msl_offset + get_declared_struct_member_size_msl(ib_type, mbr_idx);
3334 	}
3335 }
3336 
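// Returns true if the member at 'index' can be declared as-is in MSL: its MSL size must not
// overlap the next member's SPIR-V offset, its array/matrix strides must match SPIR-V exactly
// (with a relaxation for single-element arrays), and its SPIR-V offset must satisfy MSL alignment.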
bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const
3338 {
3339 	auto &mbr_type = get<SPIRType>(type.member_types[index]);
3340 	uint32_t spirv_offset = get_member_decoration(type.self, index, DecorationOffset);
3341 
3342 	if (index + 1 < type.member_types.size())
3343 	{
3344 		// First, we will check offsets. If SPIR-V offset + MSL size > SPIR-V offset of next member,
3345 		// we *must* perform some kind of remapping, no way getting around it.
3346 		// We can always pad after this member if necessary, so that case is fine.
3347 		uint32_t spirv_offset_next = get_member_decoration(type.self, index + 1, DecorationOffset);
3348 		assert(spirv_offset_next >= spirv_offset);
3349 		uint32_t maximum_size = spirv_offset_next - spirv_offset;
3350 		uint32_t msl_mbr_size = get_declared_struct_member_size_msl(type, index);
3351 		if (msl_mbr_size > maximum_size)
3352 			return false;
3353 	}
3354 
3355 	if (!mbr_type.array.empty())
3356 	{
3357 		// If we have an array type, array stride must match exactly with SPIR-V.
3358 
3359 		// An exception to this requirement is if we have one array element.
3360 		// This comes from DX scalar layout workaround.
		// If the app tries to be cheeky and access the member out of bounds, this will not work, but this is the best we can do.
		// With the logical memory model, the SPIR-V specification requires OpAccessChain access chains to be in-bounds anyway.
3363 		bool relax_array_stride = mbr_type.array.back() == 1 && mbr_type.array_size_literal.back();
3364 
3365 		if (!relax_array_stride)
3366 		{
3367 			uint32_t spirv_array_stride = type_struct_member_array_stride(type, index);
3368 			uint32_t msl_array_stride = get_declared_struct_member_array_stride_msl(type, index);
3369 			if (spirv_array_stride != msl_array_stride)
3370 				return false;
3371 		}
3372 	}
3373 
3374 	if (is_matrix(mbr_type))
3375 	{
3376 		// Need to check MatrixStride as well.
3377 		uint32_t spirv_matrix_stride = type_struct_member_matrix_stride(type, index);
3378 		uint32_t msl_matrix_stride = get_declared_struct_member_matrix_stride_msl(type, index);
3379 		if (spirv_matrix_stride != msl_matrix_stride)
3380 			return false;
3381 	}
3382 
3383 	// Now, we check alignment.
3384 	uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, index);
3385 	if ((spirv_offset % msl_alignment) != 0)
3386 		return false;
3387 
3388 	// We're in the clear.
3389 	return true;
3390 }
3391 
3392 // Here we need to verify that the member type we declare conforms to Offset, ArrayStride or MatrixStride restrictions.
3393 // If there is a mismatch, we need to emit remapped types, either normal types, or "packed_X" types.
// In odd cases we need to emit packed and remapped types, e.g. for weird matrices or arrays with unusual array strides.
void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t index)
3396 {
3397 	if (validate_member_packing_rules_msl(ib_type, index))
3398 		return;
3399 
3400 	// We failed validation.
3401 	// This case will be nightmare-ish to deal with. This could possibly happen if struct alignment does not quite
3402 	// match up with what we want. Scalar block layout comes to mind here where we might have to work around the rule
3403 	// that struct alignment == max alignment of all members and struct size depends on this alignment.
3404 	auto &mbr_type = get<SPIRType>(ib_type.member_types[index]);
3405 	if (mbr_type.basetype == SPIRType::Struct)
3406 		SPIRV_CROSS_THROW("Cannot perform any repacking for structs when it is used as a member of another struct.");
3407 
3408 	// Perform remapping here.
3409 	// There is nothing to be gained by using packed scalars, so don't attempt it.
3410 	if (!is_scalar(ib_type))
3411 		set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
3412 
3413 	// Try validating again, now with packed.
3414 	if (validate_member_packing_rules_msl(ib_type, index))
3415 		return;
3416 
3417 	// We're in deep trouble, and we need to create a new PhysicalType which matches up with what we expect.
3418 	// A lot of work goes here ...
3419 	// We will need remapping on Load and Store to translate the types between Logical and Physical.
3420 
3421 	// First, we check if we have small vector std140 array.
3422 	// We detect this if we have an array of vectors, and array stride is greater than number of elements.
3423 	if (!mbr_type.array.empty() && !is_matrix(mbr_type))
3424 	{
3425 		uint32_t array_stride = type_struct_member_array_stride(ib_type, index);
3426 
3427 		// Hack off array-of-arrays until we find the array stride per element we must have to make it work.
3428 		uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
3429 		for (uint32_t dim = 0; dim < dimensions; dim++)
3430 			array_stride /= max(to_array_size_literal(mbr_type, dim), 1u);
3431 
3432 		uint32_t elems_per_stride = array_stride / (mbr_type.width / 8);
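		// E.g. a std140 "float arr[8]" has ArrayStride 16 and a 4-byte element, so
		// elems_per_stride == 4: the physical type becomes float4[8], and loads swizzle
		// the scalar back out (see unpack_expression_type()).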
3433 
3434 		if (elems_per_stride == 3)
3435 			SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
3436 		else if (elems_per_stride > 4)
3437 			SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
3438 
3439 		auto physical_type = mbr_type;
3440 		physical_type.vecsize = elems_per_stride;
3441 		physical_type.parent_type = 0;
3442 		uint32_t type_id = ir.increase_bound_by(1);
3443 		set<SPIRType>(type_id, physical_type);
3444 		set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);
3445 		set_decoration(type_id, DecorationArrayStride, array_stride);
3446 
3447 		// Remove packed_ for vectors of size 1, 2 and 4.
3448 		if (has_extended_decoration(ib_type.self, SPIRVCrossDecorationPhysicalTypePacked))
3449 			SPIRV_CROSS_THROW("Unable to remove packed decoration as entire struct must be fully packed. Do not mix "
3450 			                  "scalar and std140 layout rules.");
3451 		else
3452 			unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
3453 	}
3454 	else if (is_matrix(mbr_type))
3455 	{
3456 		// MatrixStride might be std140-esque.
3457 		uint32_t matrix_stride = type_struct_member_matrix_stride(ib_type, index);
3458 
3459 		uint32_t elems_per_stride = matrix_stride / (mbr_type.width / 8);
3460 
3461 		if (elems_per_stride == 3)
3462 			SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
3463 		else if (elems_per_stride > 4)
3464 			SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
3465 
3466 		bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
3467 
3468 		auto physical_type = mbr_type;
3469 		physical_type.parent_type = 0;
3470 		if (row_major)
3471 			physical_type.columns = elems_per_stride;
3472 		else
3473 			physical_type.vecsize = elems_per_stride;
3474 		uint32_t type_id = ir.increase_bound_by(1);
3475 		set<SPIRType>(type_id, physical_type);
3476 		set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);
3477 
3478 		// Remove packed_ for vectors of size 1, 2 and 4.
3479 		if (has_extended_decoration(ib_type.self, SPIRVCrossDecorationPhysicalTypePacked))
3480 			SPIRV_CROSS_THROW("Unable to remove packed decoration as entire struct must be fully packed. Do not mix "
3481 			                  "scalar and std140 layout rules.");
3482 		else
3483 			unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
3484 	}
3485 	else
3486 		SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
3487 
3488 	// Try validating again, now with physical type remapping.
3489 	if (validate_member_packing_rules_msl(ib_type, index))
3490 		return;
3491 
3492 	// We might have a particular odd scalar layout case where the last element of an array
3493 	// does not take up as much space as the ArrayStride or MatrixStride. This can happen with DX cbuffers.
3494 	// The "proper" workaround for this is extremely painful and essentially impossible in the edge case of float3[],
3495 	// so we hack around it by declaring the offending array or matrix with one less array size/col/row,
3496 	// and rely on padding to get the correct value. We will technically access arrays out of bounds into the padding region,
3497 	// but it should spill over gracefully without too much trouble. We rely on behavior like this for unsized arrays anyways.
3498 
3499 	// E.g. we might observe a physical layout of:
3500 	// { float2 a[2]; float b; } in cbuffer layout where ArrayStride of a is 16, but offset of b is 24, packed right after a[1] ...
3501 	uint32_t type_id = get_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
3502 	auto &type = get<SPIRType>(type_id);
3503 
3504 	// Modify the physical type in-place. This is safe since each physical type workaround is a copy.
3505 	if (is_array(type))
3506 	{
3507 		if (type.array.back() > 1)
3508 		{
3509 			if (!type.array_size_literal.back())
3510 				SPIRV_CROSS_THROW("Cannot apply scalar layout workaround with spec constant array size.");
3511 			type.array.back() -= 1;
3512 		}
3513 		else
3514 		{
3515 			// We have an array of size 1, so we cannot decrement that. Our only option now is to
3516 			// force a packed layout instead, and drop the physical type remap since ArrayStride is meaningless now.
3517 			unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
3518 			set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
3519 		}
3520 	}
3521 	else if (is_matrix(type))
3522 	{
3523 		bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
3524 		if (!row_major)
3525 		{
3526 			// Slice off one column. If we only have 2 columns, this might turn the matrix into a vector with one array element instead.
3527 			if (type.columns > 2)
3528 			{
3529 				type.columns--;
3530 			}
3531 			else if (type.columns == 2)
3532 			{
3533 				type.columns = 1;
3534 				assert(type.array.empty());
3535 				type.array.push_back(1);
3536 				type.array_size_literal.push_back(true);
3537 			}
3538 		}
3539 		else
3540 		{
3541 			// Slice off one row. If we only have 2 rows, this might turn the matrix into a vector with one array element instead.
3542 			if (type.vecsize > 2)
3543 			{
3544 				type.vecsize--;
3545 			}
3546 			else if (type.vecsize == 2)
3547 			{
3548 				type.vecsize = type.columns;
3549 				type.columns = 1;
3550 				assert(type.array.empty());
3551 				type.array.push_back(1);
3552 				type.array_size_literal.push_back(true);
3553 			}
3554 		}
3555 	}
3556 
3557 	// This better validate now, or we must fail gracefully.
3558 	if (!validate_member_packing_rules_msl(ib_type, index))
3559 		SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
3560 }
3561 
void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression)
3563 {
3564 	auto &type = expression_type(rhs_expression);
3565 
3566 	bool lhs_remapped_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID);
3567 	bool lhs_packed_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypePacked);
3568 	auto *lhs_e = maybe_get<SPIRExpression>(lhs_expression);
3569 	auto *rhs_e = maybe_get<SPIRExpression>(rhs_expression);
3570 
3571 	bool transpose = lhs_e && lhs_e->need_transpose;
3572 
3573 	// No physical type remapping, and no packed type, so can just emit a store directly.
3574 	if (!lhs_remapped_type && !lhs_packed_type)
3575 	{
		// We might not be dealing with remapped physical types or packed types,
		// but we might be doing a clean store to a row-major matrix.
		// In this case, we just flip transpose states and emit the store; any transpose must be in the RHS expression.
3579 		if (is_matrix(type) && lhs_e && lhs_e->need_transpose)
3580 		{
3581 			if (!rhs_e)
3582 				SPIRV_CROSS_THROW("Need to transpose right-side expression of a store to row-major matrix, but it is "
3583 				                  "not a SPIRExpression.");
3584 			lhs_e->need_transpose = false;
3585 
3586 			if (rhs_e && rhs_e->need_transpose)
3587 			{
3588 				// Direct copy, but might need to unpack RHS.
3589 				// Skip the transpose, as we will transpose when writing to LHS and transpose(transpose(T)) == T.
3590 				rhs_e->need_transpose = false;
3591 				statement(to_expression(lhs_expression), " = ", to_unpacked_row_major_matrix_expression(rhs_expression),
3592 				          ";");
3593 				rhs_e->need_transpose = true;
3594 			}
3595 			else
3596 				statement(to_expression(lhs_expression), " = transpose(", to_unpacked_expression(rhs_expression), ");");
3597 
3598 			lhs_e->need_transpose = true;
3599 			register_write(lhs_expression);
3600 		}
3601 		else if (lhs_e && lhs_e->need_transpose)
3602 		{
3603 			lhs_e->need_transpose = false;
3604 
3605 			// Storing a column to a row-major matrix. Unroll the write.
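			// E.g. a store to column 1 of a row-major float3x3 "m" unrolls to (illustrative):
			//   m[0][1] = rhs.x; m[1][1] = rhs.y; m[2][1] = rhs.z;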
3606 			for (uint32_t c = 0; c < type.vecsize; c++)
3607 			{
3608 				auto lhs_expr = to_dereferenced_expression(lhs_expression);
3609 				auto column_index = lhs_expr.find_last_of('[');
3610 				if (column_index != string::npos)
3611 				{
3612 					statement(lhs_expr.insert(column_index, join('[', c, ']')), " = ",
3613 					          to_extract_component_expression(rhs_expression, c), ";");
3614 				}
3615 			}
3616 			lhs_e->need_transpose = true;
3617 			register_write(lhs_expression);
3618 		}
3619 		else
3620 			CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
3621 	}
3622 	else if (!lhs_remapped_type && !is_matrix(type) && !transpose)
3623 	{
3624 		// Even if the target type is packed, we can directly store to it. We cannot store to packed matrices directly,
3625 		// since they are declared as array of vectors instead, and we need the fallback path below.
3626 		CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
3627 	}
3628 	else
3629 	{
3630 		// Special handling when storing to a remapped physical type.
3631 		// This is mostly to deal with std140 padded matrices or vectors.
3632 
3633 		TypeID physical_type_id = lhs_remapped_type ?
3634 		                              ID(get_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID)) :
3635 		                              type.self;
3636 
3637 		auto &physical_type = get<SPIRType>(physical_type_id);
3638 
3639 		if (is_matrix(type))
3640 		{
3641 			const char *packed_pfx = lhs_packed_type ? "packed_" : "";
3642 
3643 			// Packed matrices are stored as arrays of packed vectors, so we need
3644 			// to assign the vectors one at a time.
3645 			// For row-major matrices, we need to transpose the *right-hand* side,
3646 			// not the left-hand side.
3647 
3648 			// Lots of cases to cover here ...
3649 
3650 			bool rhs_transpose = rhs_e && rhs_e->need_transpose;
3651 			SPIRType write_type = type;
3652 			string cast_expr;
3653 
3654 			// We're dealing with transpose manually.
3655 			if (rhs_transpose)
3656 				rhs_e->need_transpose = false;
3657 
3658 			if (transpose)
3659 			{
3660 				// We're dealing with transpose manually.
3661 				lhs_e->need_transpose = false;
3662 				write_type.vecsize = type.columns;
3663 				write_type.columns = 1;
3664 
3665 				if (physical_type.columns != type.columns)
3666 					cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");
3667 
3668 				if (rhs_transpose)
3669 				{
3670 					// If RHS is also transposed, we can just copy row by row.
3671 					for (uint32_t i = 0; i < type.vecsize; i++)
3672 					{
3673 						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
3674 						          to_unpacked_row_major_matrix_expression(rhs_expression), "[", i, "];");
3675 					}
3676 				}
3677 				else
3678 				{
3679 					auto vector_type = expression_type(rhs_expression);
3680 					vector_type.vecsize = vector_type.columns;
3681 					vector_type.columns = 1;
3682 
3683 					// Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
3684 					// so pick out individual components instead.
3685 					for (uint32_t i = 0; i < type.vecsize; i++)
3686 					{
3687 						string rhs_row = type_to_glsl_constructor(vector_type) + "(";
3688 						for (uint32_t j = 0; j < vector_type.vecsize; j++)
3689 						{
3690 							rhs_row += join(to_enclosed_unpacked_expression(rhs_expression), "[", j, "][", i, "]");
3691 							if (j + 1 < vector_type.vecsize)
3692 								rhs_row += ", ";
3693 						}
3694 						rhs_row += ")";
3695 
3696 						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
3697 					}
3698 				}
3699 
3700 				// We're dealing with transpose manually.
3701 				lhs_e->need_transpose = true;
3702 			}
3703 			else
3704 			{
3705 				write_type.columns = 1;
3706 
3707 				if (physical_type.vecsize != type.vecsize)
3708 					cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");
3709 
3710 				if (rhs_transpose)
3711 				{
3712 					auto vector_type = expression_type(rhs_expression);
3713 					vector_type.columns = 1;
3714 
3715 					// Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
3716 					// so pick out individual components instead.
3717 					for (uint32_t i = 0; i < type.columns; i++)
3718 					{
3719 						string rhs_row = type_to_glsl_constructor(vector_type) + "(";
3720 						for (uint32_t j = 0; j < vector_type.vecsize; j++)
3721 						{
3722 							// Need to explicitly unpack expression since we've mucked with transpose state.
3723 							auto unpacked_expr = to_unpacked_row_major_matrix_expression(rhs_expression);
3724 							rhs_row += join(unpacked_expr, "[", j, "][", i, "]");
3725 							if (j + 1 < vector_type.vecsize)
3726 								rhs_row += ", ";
3727 						}
3728 						rhs_row += ")";
3729 
3730 						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
3731 					}
3732 				}
3733 				else
3734 				{
3735 					// Copy column-by-column.
3736 					for (uint32_t i = 0; i < type.columns; i++)
3737 					{
3738 						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
3739 						          to_enclosed_unpacked_expression(rhs_expression), "[", i, "];");
3740 					}
3741 				}
3742 			}
3743 
3744 			// We're dealing with transpose manually.
3745 			if (rhs_transpose)
3746 				rhs_e->need_transpose = true;
3747 		}
3748 		else if (transpose)
3749 		{
3750 			lhs_e->need_transpose = false;
3751 
3752 			SPIRType write_type = type;
3753 			write_type.vecsize = 1;
3754 			write_type.columns = 1;
3755 
3756 			// Storing a column to a row-major matrix. Unroll the write.
3757 			for (uint32_t c = 0; c < type.vecsize; c++)
3758 			{
3759 				auto lhs_expr = to_enclosed_expression(lhs_expression);
3760 				auto column_index = lhs_expr.find_last_of('[');
3761 				if (column_index != string::npos)
3762 				{
3763 					statement("((device ", type_to_glsl(write_type), "*)&",
3764 					          lhs_expr.insert(column_index, join('[', c, ']', ")")), " = ",
3765 					          to_extract_component_expression(rhs_expression, c), ";");
3766 				}
3767 			}
3768 
3769 			lhs_e->need_transpose = true;
3770 		}
3771 		else if ((is_matrix(physical_type) || is_array(physical_type)) && physical_type.vecsize > type.vecsize)
3772 		{
3773 			assert(type.vecsize >= 1 && type.vecsize <= 3);
3774 
3775 			// If we have packed types, we cannot use swizzled stores.
3776 			// We could technically unroll the store for each element if needed.
3777 			// When remapping to a std140 physical type, we always get float4,
3778 			// and the packed decoration should always be removed.
3779 			assert(!lhs_packed_type);
3780 
3781 			string lhs = to_dereferenced_expression(lhs_expression);
3782 			string rhs = to_pointer_expression(rhs_expression);
3783 
3784 			// Unpack the expression so we can store to it with a float or float2.
3785 			// It's still an l-value, so it's fine. Most other unpacking of expressions turn them into r-values instead.
3786 			lhs = join("(device ", type_to_glsl(type), "&)", enclose_expression(lhs));
3787 			if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
3788 				statement(lhs, " = ", rhs, ";");
3789 		}
3790 		else if (!is_matrix(type))
3791 		{
3792 			string lhs = to_dereferenced_expression(lhs_expression);
3793 			string rhs = to_pointer_expression(rhs_expression);
3794 			if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
3795 				statement(lhs, " = ", rhs, ";");
3796 		}
3797 
3798 		register_write(lhs_expression);
3799 	}
3800 }
3801 
static bool expression_ends_with(const string &expr_str, const std::string &ending)
3803 {
3804 	if (expr_str.length() >= ending.length())
3805 		return (expr_str.compare(expr_str.length() - ending.length(), ending.length(), ending) == 0);
3806 	else
3807 		return false;
3808 }
3809 
3810 // Converts the format of the current expression from packed to unpacked,
3811 // by wrapping the expression in a constructor of the appropriate type.
3812 // Also, handle special physical ID remapping scenarios, similar to emit_store_statement().
string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type, uint32_t physical_type_id,
                                           bool packed, bool row_major)
3815 {
3816 	// Trivial case, nothing to do.
3817 	if (physical_type_id == 0 && !packed)
3818 		return expr_str;
3819 
3820 	const SPIRType *physical_type = nullptr;
3821 	if (physical_type_id)
3822 		physical_type = &get<SPIRType>(physical_type_id);
3823 
3824 	static const char *swizzle_lut[] = {
3825 		".x",
3826 		".xy",
3827 		".xyz",
3828 	};
3829 
3830 	if (physical_type && is_vector(*physical_type) && is_array(*physical_type) &&
3831 	    physical_type->vecsize > type.vecsize && !expression_ends_with(expr_str, swizzle_lut[type.vecsize - 1]))
3832 	{
3833 		// std140 array cases for vectors.
3834 		assert(type.vecsize >= 1 && type.vecsize <= 3);
3835 		return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
3836 	}
3837 	else if (physical_type && is_matrix(*physical_type) && is_vector(type) && physical_type->vecsize > type.vecsize)
3838 	{
3839 		// Extract column from padded matrix.
3840 		assert(type.vecsize >= 1 && type.vecsize <= 3);
3841 		return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
3842 	}
3843 	else if (is_matrix(type))
3844 	{
3845 		// Packed matrices are stored as arrays of packed vectors. Unfortunately,
3846 		// we can't just pass the array straight to the matrix constructor. We have to
3847 		// pass each vector individually, so that they can be unpacked to normal vectors.
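		// E.g. a packed float2x3 "m" unpacks as (illustrative):
		//   float2x3(float3(m[0]), float3(m[1]))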
3848 		if (!physical_type)
3849 			physical_type = &type;
3850 
3851 		uint32_t vecsize = type.vecsize;
3852 		uint32_t columns = type.columns;
3853 		if (row_major)
3854 			swap(vecsize, columns);
3855 
3856 		uint32_t physical_vecsize = row_major ? physical_type->columns : physical_type->vecsize;
3857 
3858 		const char *base_type = type.width == 16 ? "half" : "float";
3859 		string unpack_expr = join(base_type, columns, "x", vecsize, "(");
3860 
3861 		const char *load_swiz = "";
3862 
3863 		if (physical_vecsize != vecsize)
3864 			load_swiz = swizzle_lut[vecsize - 1];
3865 
3866 		for (uint32_t i = 0; i < columns; i++)
3867 		{
3868 			if (i > 0)
3869 				unpack_expr += ", ";
3870 
3871 			if (packed)
3872 				unpack_expr += join(base_type, physical_vecsize, "(", expr_str, "[", i, "]", ")", load_swiz);
3873 			else
3874 				unpack_expr += join(expr_str, "[", i, "]", load_swiz);
3875 		}
3876 
3877 		unpack_expr += ")";
3878 		return unpack_expr;
3879 	}
3880 	else
3881 	{
3882 		return join(type_to_glsl(type), "(", expr_str, ")");
3883 	}
3884 }
3885 
3886 // Emits the file header info
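// Typical output, as a sketch (pragmas and extra headers vary per shader):
//   #include <metal_stdlib>
//   #include <simd/simd.h>
//
//   using namespace metal;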
void CompilerMSL::emit_header()
3888 {
3889 	// This particular line can be overridden during compilation, so make it a flag and not a pragma line.
3890 	if (suppress_missing_prototypes)
3891 		statement("#pragma clang diagnostic ignored \"-Wmissing-prototypes\"");
3892 
3893 	// Disable warning about missing braces for array<T> template to make arrays a value type
3894 	if (spv_function_implementations.count(SPVFuncImplUnsafeArray) != 0)
3895 		statement("#pragma clang diagnostic ignored \"-Wmissing-braces\"");
3896 
3897 	for (auto &pragma : pragma_lines)
3898 		statement(pragma);
3899 
3900 	if (!pragma_lines.empty() || suppress_missing_prototypes)
3901 		statement("");
3902 
3903 	statement("#include <metal_stdlib>");
3904 	statement("#include <simd/simd.h>");
3905 
3906 	for (auto &header : header_lines)
3907 		statement(header);
3908 
3909 	statement("");
3910 	statement("using namespace metal;");
3911 	statement("");
3912 
3913 	for (auto &td : typedef_lines)
3914 		statement(td);
3915 
3916 	if (!typedef_lines.empty())
3917 		statement("");
3918 }
3919 
void CompilerMSL::add_pragma_line(const string &line)
3921 {
3922 	auto rslt = pragma_lines.insert(line);
3923 	if (rslt.second)
3924 		force_recompile();
3925 }
3926 
void CompilerMSL::add_typedef_line(const string &line)
3928 {
3929 	auto rslt = typedef_lines.insert(line);
3930 	if (rslt.second)
3931 		force_recompile();
3932 }
3933 
// Template structs like spvUnsafeArray<> need to be declared *before* any resources are declared
void CompilerMSL::emit_custom_templates()
3936 {
3937 	for (const auto &spv_func : spv_function_implementations)
3938 	{
3939 		switch (spv_func)
3940 		{
3941 		case SPVFuncImplUnsafeArray:
3942 			statement("template<typename T, size_t Num>");
3943 			statement("struct spvUnsafeArray");
3944 			begin_scope();
3945 			statement("T elements[Num ? Num : 1];");
3946 			statement("");
3947 			statement("thread T& operator [] (size_t pos) thread");
3948 			begin_scope();
3949 			statement("return elements[pos];");
3950 			end_scope();
3951 			statement("constexpr const thread T& operator [] (size_t pos) const thread");
3952 			begin_scope();
3953 			statement("return elements[pos];");
3954 			end_scope();
3955 			statement("");
3956 			statement("device T& operator [] (size_t pos) device");
3957 			begin_scope();
3958 			statement("return elements[pos];");
3959 			end_scope();
3960 			statement("constexpr const device T& operator [] (size_t pos) const device");
3961 			begin_scope();
3962 			statement("return elements[pos];");
3963 			end_scope();
3964 			statement("");
3965 			statement("constexpr const constant T& operator [] (size_t pos) const constant");
3966 			begin_scope();
3967 			statement("return elements[pos];");
3968 			end_scope();
3969 			statement("");
3970 			statement("threadgroup T& operator [] (size_t pos) threadgroup");
3971 			begin_scope();
3972 			statement("return elements[pos];");
3973 			end_scope();
3974 			statement("constexpr const threadgroup T& operator [] (size_t pos) const threadgroup");
3975 			begin_scope();
3976 			statement("return elements[pos];");
3977 			end_scope();
3978 			end_scope_decl();
3979 			statement("");
3980 			break;
3981 
3982 		default:
3983 			break;
3984 		}
3985 	}
3986 }

// Emits any needed custom function bodies.
// Metal helper functions must be static force-inline, i.e. static inline __attribute__((always_inline)),
// otherwise they will cause problems when linked together in a single Metallib.
void CompilerMSL::emit_custom_functions()
{
	for (uint32_t i = kArrayCopyMultidimMax; i >= 2; i--)
		if (spv_function_implementations.count(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i)))
			spv_function_implementations.insert(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i - 1));

	if (spv_function_implementations.count(SPVFuncImplDynamicImageSampler))
	{
		// Unfortunately, this one needs a lot of the other functions to compile OK.
		if (!msl_options.supports_msl_version(2))
			SPIRV_CROSS_THROW(
			    "spvDynamicImageSampler requires default-constructible texture objects, which require MSL 2.0.");
		spv_function_implementations.insert(SPVFuncImplForwardArgs);
		spv_function_implementations.insert(SPVFuncImplTextureSwizzle);
		if (msl_options.swizzle_texture_samples)
			spv_function_implementations.insert(SPVFuncImplGatherSwizzle);
		for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
		     i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
			spv_function_implementations.insert(static_cast<SPVFuncImpl>(i));
		spv_function_implementations.insert(SPVFuncImplExpandITUFullRange);
		spv_function_implementations.insert(SPVFuncImplExpandITUNarrowRange);
		spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT709);
		spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT601);
		spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT2020);
	}

	for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
	     i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
		if (spv_function_implementations.count(static_cast<SPVFuncImpl>(i)))
			spv_function_implementations.insert(SPVFuncImplForwardArgs);

	if (spv_function_implementations.count(SPVFuncImplTextureSwizzle) ||
	    spv_function_implementations.count(SPVFuncImplGatherSwizzle) ||
	    spv_function_implementations.count(SPVFuncImplGatherCompareSwizzle))
	{
		spv_function_implementations.insert(SPVFuncImplForwardArgs);
		spv_function_implementations.insert(SPVFuncImplGetSwizzle);
	}

	for (const auto &spv_func : spv_function_implementations)
	{
		switch (spv_func)
		{
		case SPVFuncImplMod:
			statement("// Implementation of the GLSL mod() function, which is slightly different from Metal fmod()");
			statement("template<typename Tx, typename Ty>");
			statement("inline Tx mod(Tx x, Ty y)");
			begin_scope();
			statement("return x - y * floor(x / y);");
			end_scope();
			statement("");
			break;
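
		// Worked example of the difference: for x = -3.5 and y = 2.0, mod() above
		// yields -3.5 - 2.0 * floor(-1.75) = 0.5, whereas Metal's fmod(-3.5, 2.0)
		// truncates toward zero and yields -1.5.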

		case SPVFuncImplRadians:
			statement("// Implementation of the GLSL radians() function");
			statement("template<typename T>");
			statement("inline T radians(T d)");
			begin_scope();
			statement("return d * T(0.01745329251);");
			end_scope();
			statement("");
			break;

		case SPVFuncImplDegrees:
			statement("// Implementation of the GLSL degrees() function");
			statement("template<typename T>");
			statement("inline T degrees(T r)");
			begin_scope();
			statement("return r * T(57.2957795131);");
			end_scope();
			statement("");
			break;

		case SPVFuncImplFindILsb:
			statement("// Implementation of the GLSL findLSB() function");
			statement("template<typename T>");
			statement("inline T spvFindLSB(T x)");
			begin_scope();
			statement("return select(ctz(x), T(-1), x == T(0));");
			end_scope();
			statement("");
			break;

		case SPVFuncImplFindUMsb:
			statement("// Implementation of the unsigned GLSL findMSB() function");
			statement("template<typename T>");
			statement("inline T spvFindUMSB(T x)");
			begin_scope();
			statement("return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));");
			end_scope();
			statement("");
			break;

		case SPVFuncImplFindSMsb:
			statement("// Implementation of the signed GLSL findMSB() function");
			statement("template<typename T>");
			statement("inline T spvFindSMSB(T x)");
			begin_scope();
			statement("T v = select(x, T(-1) - x, x < T(0));");
			statement("return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));");
			end_scope();
			statement("");
			break;
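
		// Worked examples for the bit-scan helpers above (32-bit T, clz(0) == 32):
		//   spvFindLSB(0x00000008u) == 3, spvFindLSB(0u) == -1 (no bits set)
		//   spvFindUMSB(1u) == 0, spvFindUMSB(0x80000000u) == 31
		//   spvFindSMSB(-1) == -1, since negative inputs are bit-flipped first
		//   (T(-1) - x == ~x) and ~(-1) == 0 has no set bits to scan.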

		case SPVFuncImplSSign:
			statement("// Implementation of the GLSL sign() function for integer types");
			statement("template<typename T, typename E = typename enable_if<is_integral<T>::value>::type>");
			statement("inline T sign(T x)");
			begin_scope();
			statement("return select(select(select(x, T(0), x == T(0)), T(1), x > T(0)), T(-1), x < T(0));");
			end_scope();
			statement("");
			break;

		case SPVFuncImplArrayCopy:
		case SPVFuncImplArrayOfArrayCopy2Dim:
		case SPVFuncImplArrayOfArrayCopy3Dim:
		case SPVFuncImplArrayOfArrayCopy4Dim:
		case SPVFuncImplArrayOfArrayCopy5Dim:
		case SPVFuncImplArrayOfArrayCopy6Dim:
		{
			// Unfortunately we cannot template on the address space, so combinatorial explosion it is.
			static const char *function_name_tags[] = {
				"FromConstantToStack",     "FromConstantToThreadGroup", "FromStackToStack",
				"FromStackToThreadGroup",  "FromThreadGroupToStack",    "FromThreadGroupToThreadGroup",
				"FromDeviceToDevice",      "FromConstantToDevice",      "FromStackToDevice",
				"FromThreadGroupToDevice", "FromDeviceToStack",         "FromDeviceToThreadGroup",
			};

			static const char *src_address_space[] = {
				"constant",          "constant",          "thread const", "thread const",
				"threadgroup const", "threadgroup const", "device const", "constant",
				"thread const",      "threadgroup const", "device const", "device const",
			};

			static const char *dst_address_space[] = {
				"thread", "threadgroup", "thread", "threadgroup", "thread", "threadgroup",
				"device", "device",      "device", "device",      "thread", "threadgroup",
			};

			for (uint32_t variant = 0; variant < 12; variant++)
			{
				uint32_t dimensions = spv_func - SPVFuncImplArrayCopyMultidimBase;
				string tmp = "template<typename T";
				for (uint8_t i = 0; i < dimensions; i++)
				{
					tmp += ", uint ";
					tmp += 'A' + i;
				}
				tmp += ">";
				statement(tmp);

				string array_arg;
				for (uint8_t i = 0; i < dimensions; i++)
				{
					array_arg += "[";
					array_arg += 'A' + i;
					array_arg += "]";
				}

				statement("inline void spvArrayCopy", function_name_tags[variant], dimensions, "(",
				          dst_address_space[variant], " T (&dst)", array_arg, ", ", src_address_space[variant],
				          " T (&src)", array_arg, ")");

				begin_scope();
				statement("for (uint i = 0; i < A; i++)");
				begin_scope();

				if (dimensions == 1)
					statement("dst[i] = src[i];");
				else
					statement("spvArrayCopy", function_name_tags[variant], dimensions - 1, "(dst[i], src[i]);");
				end_scope();
				end_scope();
				statement("");
			}
			break;
		}
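
		// For example, the (FromConstantToStack, 1-dimensional) combination above
		// expands to roughly this MSL (a sketch):
		//
		//     template<typename T, uint A>
		//     inline void spvArrayCopyFromConstantToStack1(thread T (&dst)[A], constant T (&src)[A])
		//     {
		//         for (uint i = 0; i < A; i++)
		//             dst[i] = src[i];
		//     }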

		// Support for Metal 2.1's new texture_buffer type.
		case SPVFuncImplTexelBufferCoords:
		{
			if (msl_options.texel_buffer_texture_width > 0)
			{
				string tex_width_str = convert_to_string(msl_options.texel_buffer_texture_width);
				statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
				statement(force_inline);
				statement("uint2 spvTexelBufferCoord(uint tc)");
				begin_scope();
				statement(join("return uint2(tc % ", tex_width_str, ", tc / ", tex_width_str, ");"));
				end_scope();
				statement("");
			}
			else
			{
				statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
				statement(
				    "#define spvTexelBufferCoord(tc, tex) uint2((tc) % (tex).get_width(), (tc) / (tex).get_width())");
				statement("");
			}
			break;
		}
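
		// Worked example: with texel_buffer_texture_width == 4096, texel index
		// 5000 maps to spvTexelBufferCoord(5000) == uint2(5000 % 4096, 5000 / 4096)
		// == uint2(904, 1).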

		// Emulate texture2D atomic operations
		case SPVFuncImplImage2DAtomicCoords:
		{
			statement("// Returns buffer coords corresponding to 2D texture coords for emulating 2D texture atomics");
			statement("#define spvImage2DAtomicCoord(tc, tex) (((tex).get_width() * (tc).x) + (tc).y)");
			statement("");
			break;
		}
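
		// Worked example: for a texture with get_width() == 64, texel coordinate
		// (2, 3) linearizes to (64 * 2) + 3 == 131 in the backing buffer.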

		// "fadd" intrinsic support
		case SPVFuncImplFAdd:
			statement("template<typename T>");
			statement("T spvFAdd(T l, T r)");
			begin_scope();
			statement("return fma(T(1), l, r);");
			end_scope();
			statement("");
			break;
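
		// Note on the fma() formulations used by spvFAdd above and spvFMul below:
		// fma(T(1), l, r) and fma(l, r, T(0)) compute a plain add/multiply through
		// a fused multiply-add, which the Metal compiler will not re-associate or
		// contract under fast-math. These helpers exist, presumably, to keep
		// floating-point results invariant rather than to improve performance.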

		// "fmul" intrinsic support
		case SPVFuncImplFMul:
			statement("template<typename T>");
			statement("T spvFMul(T l, T r)");
			begin_scope();
			statement("return fma(l, r, T(0));");
			end_scope();
			statement("");

			statement("template<typename T, int Cols, int Rows>");
			statement("vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)");
			begin_scope();
			statement("vec<T, Cols> res = vec<T, Cols>(0);");
			statement("for (uint i = Rows; i > 0; --i)");
			begin_scope();
			statement("vec<T, Cols> tmp(0);");
			statement("for (uint j = 0; j < Cols; ++j)");
			begin_scope();
			statement("tmp[j] = m[j][i - 1];");
			end_scope();
			statement("res = fma(tmp, vec<T, Cols>(v[i - 1]), res);");
			end_scope();
			statement("return res;");
			end_scope();
			statement("");

			statement("template<typename T, int Cols, int Rows>");
			statement("vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)");
			begin_scope();
			statement("vec<T, Rows> res = vec<T, Rows>(0);");
			statement("for (uint i = Cols; i > 0; --i)");
			begin_scope();
			statement("res = fma(m[i - 1], vec<T, Rows>(v[i - 1]), res);");
			end_scope();
			statement("return res;");
			end_scope();
			statement("");

			statement("template<typename T, int LCols, int LRows, int RCols, int RRows>");
			statement(
			    "matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)");
			begin_scope();
			statement("matrix<T, RCols, LRows> res;");
			statement("for (uint i = 0; i < RCols; i++)");
			begin_scope();
			statement("vec<T, LRows> tmp(0);");
			statement("for (uint j = 0; j < LCols; j++)");
			begin_scope();
			statement("tmp = fma(vec<T, LRows>(r[i][j]), l[j], tmp);");
			end_scope();
			statement("res[i] = tmp;");
			end_scope();
			statement("return res;");
			end_scope();
			statement("");
			break;

		// Emulate texturecube_array with texture2d_array for iOS where this type is not available
		case SPVFuncImplCubemapTo2DArrayFace:
			statement(force_inline);
			statement("float3 spvCubemapTo2DArrayFace(float3 P)");
			begin_scope();
			statement("float3 Coords = abs(P.xyz);");
			statement("float CubeFace = 0;");
			statement("float ProjectionAxis = 0;");
			statement("float u = 0;");
			statement("float v = 0;");
			statement("if (Coords.x >= Coords.y && Coords.x >= Coords.z)");
			begin_scope();
			statement("CubeFace = P.x >= 0 ? 0 : 1;");
			statement("ProjectionAxis = Coords.x;");
			statement("u = P.x >= 0 ? -P.z : P.z;");
			statement("v = -P.y;");
			end_scope();
			statement("else if (Coords.y >= Coords.x && Coords.y >= Coords.z)");
			begin_scope();
			statement("CubeFace = P.y >= 0 ? 2 : 3;");
			statement("ProjectionAxis = Coords.y;");
			statement("u = P.x;");
			statement("v = P.y >= 0 ? P.z : -P.z;");
			end_scope();
			statement("else");
			begin_scope();
			statement("CubeFace = P.z >= 0 ? 4 : 5;");
			statement("ProjectionAxis = Coords.z;");
			statement("u = P.z >= 0 ? P.x : -P.x;");
			statement("v = -P.y;");
			end_scope();
			statement("u = 0.5 * (u/ProjectionAxis + 1);");
			statement("v = 0.5 * (v/ProjectionAxis + 1);");
			statement("return float3(u, v, CubeFace);");
			end_scope();
			statement("");
			break;
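
		// Worked example: P = (1, 0, 0) takes the first branch (+X, CubeFace = 0)
		// with ProjectionAxis = 1, u = -P.z = 0 and v = -P.y = 0, so the function
		// returns float3(0.5, 0.5, 0), the center of array layer 0.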

		case SPVFuncImplInverse4x4:
			statement("// Returns the determinant of a 2x2 matrix.");
			statement(force_inline);
			statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
			begin_scope();
			statement("return a1 * b2 - b1 * a2;");
			end_scope();
			statement("");

			statement("// Returns the determinant of a 3x3 matrix.");
			statement(force_inline);
			statement("float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
			          "float c2, float c3)");
			begin_scope();
			statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * spvDet2x2(a2, a3, "
			          "b2, b3);");
			end_scope();
			statement("");
			statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
			statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
			statement(force_inline);
			statement("float4x4 spvInverse4x4(float4x4 m)");
			begin_scope();
			statement("float4x4 adj;	// The adjoint matrix (inverse after dividing by determinant)");
			statement_no_indent("");
			statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
			statement("adj[0][0] =  spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
			          "m[3][3]);");
			statement("adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
			          "m[3][3]);");
			statement("adj[0][2] =  spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
			          "m[3][3]);");
			statement("adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
			          "m[2][3]);");
			statement_no_indent("");
			statement("adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
			          "m[3][3]);");
			statement("adj[1][1] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
			          "m[3][3]);");
			statement("adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
			          "m[3][3]);");
			statement("adj[1][3] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
			          "m[2][3]);");
			statement_no_indent("");
			statement("adj[2][0] =  spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
			          "m[3][3]);");
			statement("adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
			          "m[3][3]);");
			statement("adj[2][2] =  spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
			          "m[3][3]);");
			statement("adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
			          "m[2][3]);");
			statement_no_indent("");
			statement("adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
			          "m[3][2]);");
			statement("adj[3][1] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
			          "m[3][2]);");
			statement("adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
			          "m[3][2]);");
			statement("adj[3][3] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
			          "m[2][2]);");
			statement_no_indent("");
			statement("// Calculate the determinant as a combination of the cofactors of the first row.");
			statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
			          "* m[3][0]);");
			statement_no_indent("");
			statement("// Divide the classical adjoint matrix by the determinant.");
			statement("// If determinant is zero, matrix is not invertible, so leave it unchanged.");
			statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplInverse3x3:
			if (spv_function_implementations.count(SPVFuncImplInverse4x4) == 0)
			{
				statement("// Returns the determinant of a 2x2 matrix.");
				statement(force_inline);
				statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
				begin_scope();
				statement("return a1 * b2 - b1 * a2;");
				end_scope();
				statement("");
			}

			statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
			statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
			statement(force_inline);
			statement("float3x3 spvInverse3x3(float3x3 m)");
			begin_scope();
			statement("float3x3 adj;	// The adjoint matrix (inverse after dividing by determinant)");
			statement_no_indent("");
			statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
			statement("adj[0][0] =  spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
			statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
			statement("adj[0][2] =  spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
			statement_no_indent("");
			statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
			statement("adj[1][1] =  spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
			statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
			statement_no_indent("");
			statement("adj[2][0] =  spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
			statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
			statement("adj[2][2] =  spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
			statement_no_indent("");
			statement("// Calculate the determinant as a combination of the cofactors of the first row.");
			statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
			statement_no_indent("");
			statement("// Divide the classical adjoint matrix by the determinant.");
			statement("// If determinant is zero, matrix is not invertible, so leave it unchanged.");
			statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplInverse2x2:
			statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
			statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
			statement(force_inline);
			statement("float2x2 spvInverse2x2(float2x2 m)");
			begin_scope();
			statement("float2x2 adj;	// The adjoint matrix (inverse after dividing by determinant)");
			statement_no_indent("");
			statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
			statement("adj[0][0] =  m[1][1];");
			statement("adj[0][1] = -m[0][1];");
			statement_no_indent("");
			statement("adj[1][0] = -m[1][0];");
			statement("adj[1][1] =  m[0][0];");
			statement_no_indent("");
			statement("// Calculate the determinant as a combination of the cofactors of the first row.");
			statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
			statement_no_indent("");
			statement("// Divide the classical adjoint matrix by the determinant.");
			statement("// If determinant is zero, matrix is not invertible, so leave it unchanged.");
			statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
			end_scope();
			statement("");
			break;
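
		// Sanity check of the 2x2 case: det = adj[0][0]*m[0][0] + adj[0][1]*m[1][0]
		// = m[1][1]*m[0][0] - m[0][1]*m[1][0], the familiar ad - bc, and adj/det
		// reproduces the textbook inverse (swap the diagonal, negate the
		// off-diagonal entries, divide by the determinant).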

		case SPVFuncImplForwardArgs:
			statement("template<typename T> struct spvRemoveReference { typedef T type; };");
			statement("template<typename T> struct spvRemoveReference<thread T&> { typedef T type; };");
			statement("template<typename T> struct spvRemoveReference<thread T&&> { typedef T type; };");
			statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
			          "spvRemoveReference<T>::type& x)");
			begin_scope();
			statement("return static_cast<thread T&&>(x);");
			end_scope();
			statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
			          "spvRemoveReference<T>::type&& x)");
			begin_scope();
			statement("return static_cast<thread T&&>(x);");
			end_scope();
			statement("");
			break;

		case SPVFuncImplGetSwizzle:
			statement("enum class spvSwizzle : uint");
			begin_scope();
			statement("none = 0,");
			statement("zero,");
			statement("one,");
			statement("red,");
			statement("green,");
			statement("blue,");
			statement("alpha");
			end_scope_decl();
			statement("");
			statement("template<typename T>");
			statement("inline T spvGetSwizzle(vec<T, 4> x, T c, spvSwizzle s)");
			begin_scope();
			statement("switch (s)");
			begin_scope();
			statement("case spvSwizzle::none:");
			statement("    return c;");
			statement("case spvSwizzle::zero:");
			statement("    return 0;");
			statement("case spvSwizzle::one:");
			statement("    return 1;");
			statement("case spvSwizzle::red:");
			statement("    return x.r;");
			statement("case spvSwizzle::green:");
			statement("    return x.g;");
			statement("case spvSwizzle::blue:");
			statement("    return x.b;");
			statement("case spvSwizzle::alpha:");
			statement("    return x.a;");
			end_scope();
			end_scope();
			statement("");
			break;

		case SPVFuncImplTextureSwizzle:
			statement("// Wrapper function that swizzles texture samples and fetches.");
			statement("template<typename T>");
			statement("inline vec<T, 4> spvTextureSwizzle(vec<T, 4> x, uint s)");
			begin_scope();
			statement("if (!s)");
			statement("    return x;");
			statement("return vec<T, 4>(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), "
			          "spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) "
			          "& 0xFF)), "
			          "spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));");
			end_scope();
			statement("");
			statement("template<typename T>");
			statement("inline T spvTextureSwizzle(T x, uint s)");
			begin_scope();
			statement("return spvTextureSwizzle(vec<T, 4>(x, 0, 0, 1), s).x;");
			end_scope();
			statement("");
			break;
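
		// The swizzle constants above are packed one byte per component into the
		// uint `s`: bits 0-7 hold the spvSwizzle for the red result, 8-15 green,
		// 16-23 blue and 24-31 alpha. Since spvSwizzle::none == 0, s == 0 means
		// "identity swizzle" and is handled by the early-out in spvTextureSwizzle.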

		case SPVFuncImplGatherSwizzle:
			statement("// Wrapper function that swizzles texture gathers.");
			statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
			          "typename... Ts>");
			statement("inline vec<T, 4> spvGatherSwizzle(const thread Tex<T>& t, sampler s, "
			          "uint sw, component c, Ts... params) METAL_CONST_ARG(c)");
			begin_scope();
			statement("if (sw)");
			begin_scope();
			statement("switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))");
			begin_scope();
			statement("case spvSwizzle::none:");
			statement("    break;");
			statement("case spvSwizzle::zero:");
			statement("    return vec<T, 4>(0, 0, 0, 0);");
			statement("case spvSwizzle::one:");
			statement("    return vec<T, 4>(1, 1, 1, 1);");
			statement("case spvSwizzle::red:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::x);");
			statement("case spvSwizzle::green:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::y);");
			statement("case spvSwizzle::blue:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::z);");
			statement("case spvSwizzle::alpha:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::w);");
			end_scope();
			end_scope();
			// texture::gather insists on its component parameter being a constant
			// expression, so we need this silly workaround just to compile the shader.
			statement("switch (c)");
			begin_scope();
			statement("case component::x:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::x);");
			statement("case component::y:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::y);");
			statement("case component::z:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::z);");
			statement("case component::w:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::w);");
			end_scope();
			end_scope();
			statement("");
			break;

		case SPVFuncImplGatherCompareSwizzle:
			statement("// Wrapper function that swizzles depth texture gathers.");
			statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
			          "typename... Ts>");
			statement("inline vec<T, 4> spvGatherCompareSwizzle(const thread Tex<T>& t, sampler "
			          "s, uint sw, Ts... params)");
			begin_scope();
			statement("if (sw)");
			begin_scope();
			statement("switch (spvSwizzle(sw & 0xFF))");
			begin_scope();
			statement("case spvSwizzle::none:");
			statement("case spvSwizzle::red:");
			statement("    break;");
			statement("case spvSwizzle::zero:");
			statement("case spvSwizzle::green:");
			statement("case spvSwizzle::blue:");
			statement("case spvSwizzle::alpha:");
			statement("    return vec<T, 4>(0, 0, 0, 0);");
			statement("case spvSwizzle::one:");
			statement("    return vec<T, 4>(1, 1, 1, 1);");
			end_scope();
			end_scope();
			statement("return t.gather_compare(s, spvForward<Ts>(params)...);");
			end_scope();
			statement("");
			break;

		case SPVFuncImplSubgroupBallot:
			statement("inline uint4 spvSubgroupBallot(bool value)");
			begin_scope();
			statement("simd_vote vote = simd_ballot(value);");
			statement("// simd_ballot() returns a 64-bit integer-like object, but");
			statement("// SPIR-V callers expect a uint4. We must convert.");
			statement("// FIXME: This won't include higher bits if Apple ever supports");
			statement("// 128 lanes in a SIMD-group.");
			statement("return uint4((uint)((simd_vote::vote_t)vote & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)vote >> "
			          "32) & 0xFFFFFFFF), 0, 0);");
			end_scope();
			statement("");
			break;

		case SPVFuncImplSubgroupBallotBitExtract:
			statement("inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)");
			begin_scope();
			statement("return !!extract_bits(ballot[bit / 32], bit % 32, 1);");
			end_scope();
			statement("");
			break;

		case SPVFuncImplSubgroupBallotFindLSB:
			statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot)");
			begin_scope();
			statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
			          "ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
			end_scope();
			statement("");
			break;

		case SPVFuncImplSubgroupBallotFindMSB:
			statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot)");
			begin_scope();
			statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
			          "(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
			          "ballot.z == 0), ballot.w == 0);");
			end_scope();
			statement("");
			break;

		case SPVFuncImplSubgroupBallotBitCount:
			statement("inline uint spvSubgroupBallotBitCount(uint4 ballot)");
			begin_scope();
			statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
			end_scope();
			statement("");
			statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
			begin_scope();
			statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
			          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
			          "uint2(0));");
			statement("return spvSubgroupBallotBitCount(ballot & mask);");
			end_scope();
			statement("");
			statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
			begin_scope();
			statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
			          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
			statement("return spvSubgroupBallotBitCount(ballot & mask);");
			end_scope();
			statement("");
			break;
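
		// Worked example of the inclusive mask above: for gl_SubgroupInvocationID
		// == 35, the first word keeps min(36, 32) == 32 bits (0xFFFFFFFF) and the
		// second keeps max(36 - 32, 0) == 4 bits (0x0000000F), so exactly lanes
		// 0-35 of the ballot are counted.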

		case SPVFuncImplSubgroupAllEqual:
			// Metal doesn't provide a function to evaluate this directly. But, we can
			// implement this by comparing every thread's value to one thread's value
			// (in this case, the value of the first active thread). Then, by the transitive
			// property of equality, if all comparisons return true, then they are all equal.
			statement("template<typename T>");
			statement("inline bool spvSubgroupAllEqual(T value)");
			begin_scope();
			statement("return simd_all(value == simd_broadcast_first(value));");
			end_scope();
			statement("");
			statement("template<>");
			statement("inline bool spvSubgroupAllEqual(bool value)");
			begin_scope();
			statement("return simd_all(value) || !simd_any(value);");
			end_scope();
			statement("");
			break;

		case SPVFuncImplReflectScalar:
			// Metal does not support scalar versions of these functions.
			statement("template<typename T>");
			statement("inline T spvReflect(T i, T n)");
			begin_scope();
			statement("return i - T(2) * i * n * n;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplRefractScalar:
			// Metal does not support scalar versions of these functions.
			statement("template<typename T>");
			statement("inline T spvRefract(T i, T n, T eta)");
			begin_scope();
			statement("T NoI = n * i;");
			statement("T NoI2 = NoI * NoI;");
			statement("T k = T(1) - eta * eta * (T(1) - NoI2);");
			statement("if (k < T(0))");
			begin_scope();
			statement("return T(0);");
			end_scope();
			statement("else");
			begin_scope();
			statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
			end_scope();
			end_scope();
			statement("");
			break;

		case SPVFuncImplFaceForwardScalar:
			// Metal does not support scalar versions of these functions.
			statement("template<typename T>");
			statement("inline T spvFaceForward(T n, T i, T nref)");
			begin_scope();
			statement("return i * nref < T(0) ? n : -n;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructNearest2Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, sampler "
			          "samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructNearest3Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, "
			          "texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructLinear422CositedEven2Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
			          "plane1, sampler samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
			begin_scope();
			statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).rg);");
			end_scope();
			statement("else");
			begin_scope();
			statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
			end_scope();
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructLinear422CositedEven3Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
			          "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
			begin_scope();
			statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
			statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
			end_scope();
			statement("else");
			begin_scope();
			statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			end_scope();
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructLinear422Midpoint2Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
			          "plane1, sampler samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
			statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).rg);");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructLinear422Midpoint3Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
			          "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
			statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
			statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
			          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
			statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
			          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
			statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
			statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
			          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
			          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
			          "0)) * 0.5);");
			statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
			          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
			          "0)) * 0.5);");
			statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
			statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
			          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
			          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
			          "0.5)) * 0.5);");
			statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
			          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
			          "0.5)) * 0.5);");
			statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
			statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
			          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
			          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
			          "0.5)) * 0.5);");
			statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane:
			statement("template<typename T, typename... LodOptions>");
			statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
			          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
			begin_scope();
			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
			          "0.5)) * 0.5);");
			statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
			statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
			          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplExpandITUFullRange:
			statement("template<typename T>");
			statement("inline vec<T, 4> spvExpandITUFullRange(vec<T, 4> ycbcr, int n)");
			begin_scope();
			statement("ycbcr.br -= exp2(T(n-1))/(exp2(T(n))-1);");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;

		case SPVFuncImplExpandITUNarrowRange:
			statement("template<typename T>");
			statement("inline vec<T, 4> spvExpandITUNarrowRange(vec<T, 4> ycbcr, int n)");
			begin_scope();
			statement("ycbcr.g = (ycbcr.g * (exp2(T(n)) - 1) - ldexp(T(16), n - 8))/ldexp(T(219), n - 8);");
			statement("ycbcr.br = (ycbcr.br * (exp2(T(n)) - 1) - ldexp(T(128), n - 8))/ldexp(T(224), n - 8);");
			statement("return ycbcr;");
			end_scope();
			statement("");
			break;
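
		// Worked example at n == 8: the full-range expansion subtracts 128/255
		// from CbCr, while the narrow-range expansion maps Y to (Y*255 - 16)/219
		// and CbCr to (CbCr*255 - 128)/224, matching the ITU narrow ("studio
		// swing") encoding of 8-bit video.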

		case SPVFuncImplConvertYCbCrBT709:
			statement("// cf. Khronos Data Format Specification, section 15.1.1");
			statement("constant float3x3 spvBT709Factors = {{1, 1, 1}, {0, -0.13397432/0.7152, 1.8556}, {1.5748, "
			          "-0.33480248/0.7152, 0}};");
			statement("");
			statement("template<typename T>");
			statement("inline vec<T, 4> spvConvertYCbCrBT709(vec<T, 4> ycbcr)");
			begin_scope();
			statement("vec<T, 4> rgba;");
			statement("rgba.rgb = vec<T, 3>(spvBT709Factors * ycbcr.gbr);");
			statement("rgba.a = ycbcr.a;");
			statement("return rgba;");
			end_scope();
			statement("");
			break;
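
		// The constants above follow from the BT.709 luma weights Kr = 0.2126,
		// Kg = 0.7152, Kb = 0.0722: 1.8556 = 2*(1 - Kb), 1.5748 = 2*(1 - Kr),
		// 0.13397432 = 2*Kb*(1 - Kb) and 0.33480248 = 2*Kr*(1 - Kr), with the
		// green-channel terms divided by Kg.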
5015 
5016 		case SPVFuncImplConvertYCbCrBT601:
5017 			statement("// cf. Khronos Data Format Specification, section 15.1.2");
5018 			statement("constant float3x3 spvBT601Factors = {{1, 1, 1}, {0, -0.202008/0.587, 1.772}, {1.402, "
5019 			          "-0.419198/0.587, 0}};");
5020 			statement("");
5021 			statement("template<typename T>");
5022 			statement("inline vec<T, 4> spvConvertYCbCrBT601(vec<T, 4> ycbcr)");
5023 			begin_scope();
5024 			statement("vec<T, 4> rgba;");
5025 			statement("rgba.rgb = vec<T, 3>(spvBT601Factors * ycbcr.gbr);");
5026 			statement("rgba.a = ycbcr.a;");
5027 			statement("return rgba;");
5028 			end_scope();
5029 			statement("");
5030 			break;
5031 
5032 		case SPVFuncImplConvertYCbCrBT2020:
5033 			statement("// cf. Khronos Data Format Specification, section 15.1.3");
5034 			statement("constant float3x3 spvBT2020Factors = {{1, 1, 1}, {0, -0.11156702/0.6780, 1.8814}, {1.4746, "
5035 			          "-0.38737742/0.6780, 0}};");
5036 			statement("");
5037 			statement("template<typename T>");
5038 			statement("inline vec<T, 4> spvConvertYCbCrBT2020(vec<T, 4> ycbcr)");
5039 			begin_scope();
5040 			statement("vec<T, 4> rgba;");
5041 			statement("rgba.rgb = vec<T, 3>(spvBT2020Factors * ycbcr.gbr);");
5042 			statement("rgba.a = ycbcr.a;");
5043 			statement("return rgba;");
5044 			end_scope();
5045 			statement("");
5046 			break;
5047 
5048 		case SPVFuncImplDynamicImageSampler:
5049 			statement("enum class spvFormatResolution");
5050 			begin_scope();
5051 			statement("_444 = 0,");
5052 			statement("_422,");
5053 			statement("_420");
5054 			end_scope_decl();
5055 			statement("");
5056 			statement("enum class spvChromaFilter");
5057 			begin_scope();
5058 			statement("nearest = 0,");
5059 			statement("linear");
5060 			end_scope_decl();
5061 			statement("");
5062 			statement("enum class spvXChromaLocation");
5063 			begin_scope();
5064 			statement("cosited_even = 0,");
5065 			statement("midpoint");
5066 			end_scope_decl();
5067 			statement("");
5068 			statement("enum class spvYChromaLocation");
5069 			begin_scope();
5070 			statement("cosited_even = 0,");
5071 			statement("midpoint");
5072 			end_scope_decl();
5073 			statement("");
5074 			statement("enum class spvYCbCrModelConversion");
5075 			begin_scope();
5076 			statement("rgb_identity = 0,");
5077 			statement("ycbcr_identity,");
5078 			statement("ycbcr_bt_709,");
5079 			statement("ycbcr_bt_601,");
5080 			statement("ycbcr_bt_2020");
5081 			end_scope_decl();
5082 			statement("");
5083 			statement("enum class spvYCbCrRange");
5084 			begin_scope();
5085 			statement("itu_full = 0,");
5086 			statement("itu_narrow");
5087 			end_scope_decl();
5088 			statement("");
5089 			statement("struct spvComponentBits");
5090 			begin_scope();
5091 			statement("constexpr explicit spvComponentBits(int v) thread : value(v) {}");
5092 			statement("uchar value : 6;");
5093 			end_scope_decl();
5094 			statement("// A class corresponding to metal::sampler which holds sampler");
5095 			statement("// Y'CbCr conversion info.");
5096 			statement("struct spvYCbCrSampler");
5097 			begin_scope();
5098 			statement("constexpr spvYCbCrSampler() thread : val(build()) {}");
5099 			statement("template<typename... Ts>");
5100 			statement("constexpr spvYCbCrSampler(Ts... t) thread : val(build(t...)) {}");
5101 			statement("constexpr spvYCbCrSampler(const thread spvYCbCrSampler& s) thread = default;");
5102 			statement("");
5103 			statement("spvFormatResolution get_resolution() const thread");
5104 			begin_scope();
5105 			statement("return spvFormatResolution((val & resolution_mask) >> resolution_base);");
5106 			end_scope();
5107 			statement("spvChromaFilter get_chroma_filter() const thread");
5108 			begin_scope();
5109 			statement("return spvChromaFilter((val & chroma_filter_mask) >> chroma_filter_base);");
5110 			end_scope();
5111 			statement("spvXChromaLocation get_x_chroma_offset() const thread");
5112 			begin_scope();
5113 			statement("return spvXChromaLocation((val & x_chroma_off_mask) >> x_chroma_off_base);");
5114 			end_scope();
5115 			statement("spvYChromaLocation get_y_chroma_offset() const thread");
5116 			begin_scope();
5117 			statement("return spvYChromaLocation((val & y_chroma_off_mask) >> y_chroma_off_base);");
5118 			end_scope();
5119 			statement("spvYCbCrModelConversion get_ycbcr_model() const thread");
5120 			begin_scope();
5121 			statement("return spvYCbCrModelConversion((val & ycbcr_model_mask) >> ycbcr_model_base);");
5122 			end_scope();
5123 			statement("spvYCbCrRange get_ycbcr_range() const thread");
5124 			begin_scope();
5125 			statement("return spvYCbCrRange((val & ycbcr_range_mask) >> ycbcr_range_base);");
5126 			end_scope();
5127 			statement("int get_bpc() const thread { return (val & bpc_mask) >> bpc_base; }");
5128 			statement("");
5129 			statement("private:");
5130 			statement("ushort val;");
5131 			statement("");
5132 			statement("constexpr static constant ushort resolution_bits = 2;");
5133 			statement("constexpr static constant ushort chroma_filter_bits = 2;");
5134 			statement("constexpr static constant ushort x_chroma_off_bit = 1;");
5135 			statement("constexpr static constant ushort y_chroma_off_bit = 1;");
5136 			statement("constexpr static constant ushort ycbcr_model_bits = 3;");
5137 			statement("constexpr static constant ushort ycbcr_range_bit = 1;");
5138 			statement("constexpr static constant ushort bpc_bits = 6;");
5139 			statement("");
5140 			statement("constexpr static constant ushort resolution_base = 0;");
5141 			statement("constexpr static constant ushort chroma_filter_base = 2;");
5142 			statement("constexpr static constant ushort x_chroma_off_base = 4;");
5143 			statement("constexpr static constant ushort y_chroma_off_base = 5;");
5144 			statement("constexpr static constant ushort ycbcr_model_base = 6;");
5145 			statement("constexpr static constant ushort ycbcr_range_base = 9;");
5146 			statement("constexpr static constant ushort bpc_base = 10;");
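			// For reference (derived directly from the constants above, not emitted): the
			// packed ushort layout is bits 0-1 resolution, 2-3 chroma filter, bit 4 x-offset,
			// bit 5 y-offset, bits 6-8 ycbcr model, bit 9 ycbcr range, bits 10-15 bpc.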
5147 			statement("");
5148 			statement(
5149 			    "constexpr static constant ushort resolution_mask = ((1 << resolution_bits) - 1) << resolution_base;");
5150 			statement("constexpr static constant ushort chroma_filter_mask = ((1 << chroma_filter_bits) - 1) << "
5151 			          "chroma_filter_base;");
5152 			statement("constexpr static constant ushort x_chroma_off_mask = ((1 << x_chroma_off_bit) - 1) << "
5153 			          "x_chroma_off_base;");
5154 			statement("constexpr static constant ushort y_chroma_off_mask = ((1 << y_chroma_off_bit) - 1) << "
5155 			          "y_chroma_off_base;");
5156 			statement("constexpr static constant ushort ycbcr_model_mask = ((1 << ycbcr_model_bits) - 1) << "
5157 			          "ycbcr_model_base;");
5158 			statement("constexpr static constant ushort ycbcr_range_mask = ((1 << ycbcr_range_bit) - 1) << "
5159 			          "ycbcr_range_base;");
5160 			statement("constexpr static constant ushort bpc_mask = ((1 << bpc_bits) - 1) << bpc_base;");
5161 			statement("");
5162 			statement("static constexpr ushort build()");
5163 			begin_scope();
5164 			statement("return 0;");
5165 			end_scope();
5166 			statement("");
5167 			statement("template<typename... Ts>");
5168 			statement("static constexpr ushort build(spvFormatResolution res, Ts... t)");
5169 			begin_scope();
5170 			statement("return (ushort(res) << resolution_base) | (build(t...) & ~resolution_mask);");
5171 			end_scope();
5172 			statement("");
5173 			statement("template<typename... Ts>");
5174 			statement("static constexpr ushort build(spvChromaFilter filt, Ts... t)");
5175 			begin_scope();
5176 			statement("return (ushort(filt) << chroma_filter_base) | (build(t...) & ~chroma_filter_mask);");
5177 			end_scope();
5178 			statement("");
5179 			statement("template<typename... Ts>");
5180 			statement("static constexpr ushort build(spvXChromaLocation loc, Ts... t)");
5181 			begin_scope();
5182 			statement("return (ushort(loc) << x_chroma_off_base) | (build(t...) & ~x_chroma_off_mask);");
5183 			end_scope();
5184 			statement("");
5185 			statement("template<typename... Ts>");
5186 			statement("static constexpr ushort build(spvYChromaLocation loc, Ts... t)");
5187 			begin_scope();
5188 			statement("return (ushort(loc) << y_chroma_off_base) | (build(t...) & ~y_chroma_off_mask);");
5189 			end_scope();
5190 			statement("");
5191 			statement("template<typename... Ts>");
5192 			statement("static constexpr ushort build(spvYCbCrModelConversion model, Ts... t)");
5193 			begin_scope();
5194 			statement("return (ushort(model) << ycbcr_model_base) | (build(t...) & ~ycbcr_model_mask);");
5195 			end_scope();
5196 			statement("");
5197 			statement("template<typename... Ts>");
5198 			statement("static constexpr ushort build(spvYCbCrRange range, Ts... t)");
5199 			begin_scope();
5200 			statement("return (ushort(range) << ycbcr_range_base) | (build(t...) & ~ycbcr_range_mask);");
5201 			end_scope();
5202 			statement("");
5203 			statement("template<typename... Ts>");
5204 			statement("static constexpr ushort build(spvComponentBits bpc, Ts... t)");
5205 			begin_scope();
5206 			statement("return (ushort(bpc.value) << bpc_base) | (build(t...) & ~bpc_mask);");
5207 			end_scope();
5208 			end_scope_decl();
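			// Illustrative use of the variadic build() chain (assumed caller code, not part
			// of the emitted output):
			//   spvYCbCrSampler(spvFormatResolution::_420, spvChromaFilter::linear)
			// Each build() overload ORs its own field into place and recurses on the rest,
			// so the arguments may be passed in any order.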
5209 			statement("");
5210 			statement("// A class which can hold up to three textures and a sampler, including");
5211 			statement("// Y'CbCr conversion info, used to pass combined image-samplers");
5212 			statement("// dynamically to functions.");
5213 			statement("template<typename T>");
5214 			statement("struct spvDynamicImageSampler");
5215 			begin_scope();
5216 			statement("texture2d<T> plane0;");
5217 			statement("texture2d<T> plane1;");
5218 			statement("texture2d<T> plane2;");
5219 			statement("sampler samp;");
5220 			statement("spvYCbCrSampler ycbcr_samp;");
5221 			statement("uint swizzle = 0;");
5222 			statement("");
5223 			if (msl_options.swizzle_texture_samples)
5224 			{
5225 				statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, uint sw) thread :");
5226 				statement("    plane0(tex), samp(samp), swizzle(sw) {}");
5227 			}
5228 			else
5229 			{
5230 				statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp) thread :");
5231 				statement("    plane0(tex), samp(samp) {}");
5232 			}
5233 			statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, spvYCbCrSampler ycbcr_samp, "
5234 			          "uint sw) thread :");
5235 			statement("    plane0(tex), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
5236 			statement("constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1,");
5237 			statement("                                 sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
5238 			statement("    plane0(plane0), plane1(plane1), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
5239 			statement(
5240 			    "constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1, texture2d<T> plane2,");
5241 			statement("                                 sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
5242 			statement("    plane0(plane0), plane1(plane1), plane2(plane2), samp(samp), ycbcr_samp(ycbcr_samp), "
5243 			          "swizzle(sw) {}");
5244 			statement("");
5245 			// XXX This is really hard to follow... I've left comments to make it a bit easier.
5246 			statement("template<typename... LodOptions>");
5247 			statement("vec<T, 4> do_sample(float2 coord, LodOptions... options) const thread");
5248 			begin_scope();
5249 			statement("if (!is_null_texture(plane1))");
5250 			begin_scope();
5251 			statement("if (ycbcr_samp.get_resolution() == spvFormatResolution::_444 ||");
5252 			statement("    ycbcr_samp.get_chroma_filter() == spvChromaFilter::nearest)");
5253 			begin_scope();
5254 			statement("if (!is_null_texture(plane2))");
5255 			statement("    return spvChromaReconstructNearest(plane0, plane1, plane2, samp, coord,");
5256 			statement("                                       spvForward<LodOptions>(options)...);");
5257 			statement(
5258 			    "return spvChromaReconstructNearest(plane0, plane1, samp, coord, spvForward<LodOptions>(options)...);");
5259 			end_scope(); // if (resolution == 444 || chroma_filter == nearest)
5260 			statement("switch (ycbcr_samp.get_resolution())");
5261 			begin_scope();
5262 			statement("case spvFormatResolution::_444: break;");
5263 			statement("case spvFormatResolution::_422:");
5264 			begin_scope();
5265 			statement("switch (ycbcr_samp.get_x_chroma_offset())");
5266 			begin_scope();
5267 			statement("case spvXChromaLocation::cosited_even:");
5268 			statement("    if (!is_null_texture(plane2))");
5269 			statement("        return spvChromaReconstructLinear422CositedEven(");
5270 			statement("            plane0, plane1, plane2, samp,");
5271 			statement("            coord, spvForward<LodOptions>(options)...);");
5272 			statement("    return spvChromaReconstructLinear422CositedEven(");
5273 			statement("        plane0, plane1, samp, coord,");
5274 			statement("        spvForward<LodOptions>(options)...);");
5275 			statement("case spvXChromaLocation::midpoint:");
5276 			statement("    if (!is_null_texture(plane2))");
5277 			statement("        return spvChromaReconstructLinear422Midpoint(");
5278 			statement("            plane0, plane1, plane2, samp,");
5279 			statement("            coord, spvForward<LodOptions>(options)...);");
5280 			statement("    return spvChromaReconstructLinear422Midpoint(");
5281 			statement("        plane0, plane1, samp, coord,");
5282 			statement("        spvForward<LodOptions>(options)...);");
5283 			end_scope(); // switch (x_chroma_offset)
5284 			end_scope(); // case 422:
5285 			statement("case spvFormatResolution::_420:");
5286 			begin_scope();
5287 			statement("switch (ycbcr_samp.get_x_chroma_offset())");
5288 			begin_scope();
5289 			statement("case spvXChromaLocation::cosited_even:");
5290 			begin_scope();
5291 			statement("switch (ycbcr_samp.get_y_chroma_offset())");
5292 			begin_scope();
5293 			statement("case spvYChromaLocation::cosited_even:");
5294 			statement("    if (!is_null_texture(plane2))");
5295 			statement("        return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
5296 			statement("            plane0, plane1, plane2, samp,");
5297 			statement("            coord, spvForward<LodOptions>(options)...);");
5298 			statement("    return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
5299 			statement("        plane0, plane1, samp, coord,");
5300 			statement("        spvForward<LodOptions>(options)...);");
5301 			statement("case spvYChromaLocation::midpoint:");
5302 			statement("    if (!is_null_texture(plane2))");
5303 			statement("        return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
5304 			statement("            plane0, plane1, plane2, samp,");
5305 			statement("            coord, spvForward<LodOptions>(options)...);");
5306 			statement("    return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
5307 			statement("        plane0, plane1, samp, coord,");
5308 			statement("        spvForward<LodOptions>(options)...);");
5309 			end_scope(); // switch (y_chroma_offset)
5310 			end_scope(); // case x::cosited_even:
5311 			statement("case spvXChromaLocation::midpoint:");
5312 			begin_scope();
5313 			statement("switch (ycbcr_samp.get_y_chroma_offset())");
5314 			begin_scope();
5315 			statement("case spvYChromaLocation::cosited_even:");
5316 			statement("    if (!is_null_texture(plane2))");
5317 			statement("        return spvChromaReconstructLinear420XMidpointYCositedEven(");
5318 			statement("            plane0, plane1, plane2, samp,");
5319 			statement("            coord, spvForward<LodOptions>(options)...);");
5320 			statement("    return spvChromaReconstructLinear420XMidpointYCositedEven(");
5321 			statement("        plane0, plane1, samp, coord,");
5322 			statement("        spvForward<LodOptions>(options)...);");
5323 			statement("case spvYChromaLocation::midpoint:");
5324 			statement("    if (!is_null_texture(plane2))");
5325 			statement("        return spvChromaReconstructLinear420XMidpointYMidpoint(");
5326 			statement("            plane0, plane1, plane2, samp,");
5327 			statement("            coord, spvForward<LodOptions>(options)...);");
5328 			statement("    return spvChromaReconstructLinear420XMidpointYMidpoint(");
5329 			statement("        plane0, plane1, samp, coord,");
5330 			statement("        spvForward<LodOptions>(options)...);");
5331 			end_scope(); // switch (y_chroma_offset)
5332 			end_scope(); // case x::midpoint
5333 			end_scope(); // switch (x_chroma_offset)
5334 			end_scope(); // case 420:
5335 			end_scope(); // switch (resolution)
5336 			end_scope(); // if (multiplanar)
5337 			statement("return plane0.sample(samp, coord, spvForward<LodOptions>(options)...);");
5338 			end_scope(); // do_sample()
5339 			statement("template<typename... LodOptions>");
5340 			statement("vec<T, 4> sample(float2 coord, LodOptions... options) const thread");
5341 			begin_scope();
5342 			statement(
5343 			    "vec<T, 4> s = spvTextureSwizzle(do_sample(coord, spvForward<LodOptions>(options)...), swizzle);");
5344 			statement("if (ycbcr_samp.get_ycbcr_model() == spvYCbCrModelConversion::rgb_identity)");
5345 			statement("    return s;");
5346 			statement("");
5347 			statement("switch (ycbcr_samp.get_ycbcr_range())");
5348 			begin_scope();
5349 			statement("case spvYCbCrRange::itu_full:");
5350 			statement("    s = spvExpandITUFullRange(s, ycbcr_samp.get_bpc());");
5351 			statement("    break;");
5352 			statement("case spvYCbCrRange::itu_narrow:");
5353 			statement("    s = spvExpandITUNarrowRange(s, ycbcr_samp.get_bpc());");
5354 			statement("    break;");
5355 			end_scope();
5356 			statement("");
5357 			statement("switch (ycbcr_samp.get_ycbcr_model())");
5358 			begin_scope();
5359 			statement("case spvYCbCrModelConversion::rgb_identity:"); // Silence Clang warning
5360 			statement("case spvYCbCrModelConversion::ycbcr_identity:");
5361 			statement("    return s;");
5362 			statement("case spvYCbCrModelConversion::ycbcr_bt_709:");
5363 			statement("    return spvConvertYCbCrBT709(s);");
5364 			statement("case spvYCbCrModelConversion::ycbcr_bt_601:");
5365 			statement("    return spvConvertYCbCrBT601(s);");
5366 			statement("case spvYCbCrModelConversion::ycbcr_bt_2020:");
5367 			statement("    return spvConvertYCbCrBT2020(s);");
5368 			end_scope();
5369 			end_scope();
5370 			statement("");
5371 			// Sampler Y'CbCr conversion forbids offsets.
5372 			statement("vec<T, 4> sample(float2 coord, int2 offset) const thread");
5373 			begin_scope();
5374 			if (msl_options.swizzle_texture_samples)
5375 				statement("return spvTextureSwizzle(plane0.sample(samp, coord, offset), swizzle);");
5376 			else
5377 				statement("return plane0.sample(samp, coord, offset);");
5378 			end_scope();
5379 			statement("template<typename lod_options>");
5380 			statement("vec<T, 4> sample(float2 coord, lod_options options, int2 offset) const thread");
5381 			begin_scope();
5382 			if (msl_options.swizzle_texture_samples)
5383 				statement("return spvTextureSwizzle(plane0.sample(samp, coord, options, offset), swizzle);");
5384 			else
5385 				statement("return plane0.sample(samp, coord, options, offset);");
5386 			end_scope();
5387 			statement("#if __HAVE_MIN_LOD_CLAMP__");
5388 			statement("vec<T, 4> sample(float2 coord, bias b, min_lod_clamp min_lod, int2 offset) const thread");
5389 			begin_scope();
5390 			statement("return plane0.sample(samp, coord, b, min_lod, offset);");
5391 			end_scope();
5392 			statement(
5393 			    "vec<T, 4> sample(float2 coord, gradient2d grad, min_lod_clamp min_lod, int2 offset) const thread");
5394 			begin_scope();
5395 			statement("return plane0.sample(samp, coord, grad, min_lod, offset);");
5396 			end_scope();
5397 			statement("#endif");
5398 			statement("");
5399 			// Y'CbCr conversion forbids all operations but sampling.
5400 			statement("vec<T, 4> read(uint2 coord, uint lod = 0) const thread");
5401 			begin_scope();
5402 			statement("return plane0.read(coord, lod);");
5403 			end_scope();
5404 			statement("");
5405 			statement("vec<T, 4> gather(float2 coord, int2 offset = int2(0), component c = component::x) const thread");
5406 			begin_scope();
5407 			if (msl_options.swizzle_texture_samples)
5408 				statement("return spvGatherSwizzle(plane0, samp, swizzle, c, coord, offset);");
5409 			else
5410 				statement("return plane0.gather(samp, coord, offset, c);");
5411 			end_scope();
5412 			end_scope_decl();
5413 			statement("");
5414 			break;
5415 		default:
5416 			break;
5417 		}
5418 	}
5419 }
5420 
5421 // Undefined global memory is not allowed in MSL.
5422 // Declare a constant and initialize it to zeros. Use {}, since global constructors can break Metal.
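// For example (illustrative ID name), an OpUndef of type float4 named _42 yields:
//   constant float4 _42 = {};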
5423 void CompilerMSL::declare_undefined_values()
5424 {
5425 	bool emitted = false;
5426 	ir.for_each_typed_id<SPIRUndef>([&](uint32_t, SPIRUndef &undef) {
5427 		auto &type = this->get<SPIRType>(undef.basetype);
5428 		// OpUndef can be void for some reason ...
5429 		if (type.basetype == SPIRType::Void)
5430 			return;
5431 
5432 		statement("constant ", variable_decl(type, to_name(undef.self), undef.self), " = {};");
5433 		emitted = true;
5434 	});
5435 
5436 	if (emitted)
5437 		statement("");
5438 }
5439 
5440 void CompilerMSL::declare_constant_arrays()
5441 {
5442 	bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
5443 
5444 	// MSL cannot declare arrays inline (except when declaring a variable), so we must hoist them
5445 	// out to global constants so that they can be used as variable expressions.
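	// Illustrative (names assumed): a constant array of three floats is hoisted to roughly
	//   constant float _16[3] = { 1.0, 2.0, 3.0 };
	// (or the spvUnsafeArray<float, 3> spelling, depending on how the array type is emitted).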
5446 	bool emitted = false;
5447 
5448 	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
5449 		if (c.specialization)
5450 			return;
5451 
5452 		auto &type = this->get<SPIRType>(c.constant_type);
5453 		// Constant arrays of non-primitive types (e.g. matrices) won't link properly into Metal libraries.
5454 		// FIXME: However, hoisting constants to main() means we need to pass down constant arrays to leaf functions if they are used there.
5455 		// If there are multiple functions in the module, drop this case to avoid breaking use cases which do not need to
5456 		// link into Metal libraries. This is hacky.
5457 		if (!type.array.empty() && (!fully_inlined || is_scalar(type) || is_vector(type)))
5458 		{
5459 			auto name = to_name(c.self);
5460 			statement("constant ", variable_decl(type, name), " = ", constant_expression(c), ";");
5461 			emitted = true;
5462 		}
5463 	});
5464 
5465 	if (emitted)
5466 		statement("");
5467 }
5468 
5469 // Constant arrays of non-primitive types (e.g. matrices) won't link properly into Metal libraries.
5470 void CompilerMSL::declare_complex_constant_arrays()
5471 {
5472 	// If we do not have a fully inlined module, we did not opt in to
5473 	// declaring constant arrays of complex types. See CompilerMSL::declare_constant_arrays().
5474 	bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
5475 	if (!fully_inlined)
5476 		return;
5477 
5478 	// MSL cannot declare arrays inline (except when declaring a variable), so we must hoist them
5479 	// out to global constants so that they can be used as variable expressions.
5480 	bool emitted = false;
5481 
5482 	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
5483 		if (c.specialization)
5484 			return;
5485 
5486 		auto &type = this->get<SPIRType>(c.constant_type);
5487 		if (!type.array.empty() && !(is_scalar(type) || is_vector(type)))
5488 		{
5489 			auto name = to_name(c.self);
5490 			statement("", variable_decl(type, name), " = ", constant_expression(c), ";");
5491 			emitted = true;
5492 		}
5493 	});
5494 
5495 	if (emitted)
5496 		statement("");
5497 }
5498 
5499 void CompilerMSL::emit_resources()
5500 {
5501 	declare_constant_arrays();
5502 	declare_undefined_values();
5503 
5504 	// Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created.
5505 	emit_interface_block(stage_out_var_id);
5506 	emit_interface_block(patch_stage_out_var_id);
5507 	emit_interface_block(stage_in_var_id);
5508 	emit_interface_block(patch_stage_in_var_id);
5509 }
5510 
5511 // Emit declarations for the specialization Metal function constants
5512 void CompilerMSL::emit_specialization_constants_and_structs()
5513 {
5514 	SpecializationConstant wg_x, wg_y, wg_z;
5515 	ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
5516 	bool emitted = false;
5517 
5518 	unordered_set<uint32_t> declared_structs;
5519 	unordered_set<uint32_t> aligned_structs;
5520 
5521 	// First, we need to deal with scalar block layout.
5522 	// It is possible that a struct has to be placed at an alignment which does not match its innate alignment.
5523 	// If that is the case for a struct, we must force all elements of the struct to become packed_ types.
5524 	// This makes the struct alignment as small as physically possible.
5525 	// When we actually align the struct later, we can insert padding as necessary to make the packed members behave like normally aligned types.
5526 	ir.for_each_typed_id<SPIRType>([&](uint32_t type_id, const SPIRType &type) {
5527 		if (type.basetype == SPIRType::Struct &&
5528 		    has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
5529 			mark_scalar_layout_structs(type);
5530 	});
5531 
5532 	// Very particular use of the soft loop lock.
5533 	// align_struct may need to create custom types on the fly, but we don't care about
5534 	// these types for purpose of iterating over them in ir.ids_for_type and friends.
5535 	auto loop_lock = ir.create_loop_soft_lock();
5536 
5537 	for (auto &id_ : ir.ids_for_constant_or_type)
5538 	{
5539 		auto &id = ir.ids[id_];
5540 
5541 		if (id.get_type() == TypeConstant)
5542 		{
5543 			auto &c = id.get<SPIRConstant>();
5544 
5545 			if (c.self == workgroup_size_id)
5546 			{
5547 				// TODO: This can be expressed as a [[threads_per_threadgroup]] input semantic, but we need to know
5548 				// the work group size at compile time in SPIR-V, and [[threads_per_threadgroup]] would need to be passed around as a global.
5549 				// The work group size may be a specialization constant.
5550 				statement("constant uint3 ", builtin_to_glsl(BuiltInWorkgroupSize, StorageClassWorkgroup),
5551 				          " [[maybe_unused]] = ", constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
5552 				emitted = true;
5553 			}
5554 			else if (c.specialization)
5555 			{
5556 				auto &type = get<SPIRType>(c.constant_type);
5557 				string sc_type_name = type_to_glsl(type);
5558 				string sc_name = to_name(c.self);
5559 				string sc_tmp_name = sc_name + "_tmp";
5560 
5561 				// Function constants are only supported in MSL 1.2 and later.
5562 				// If we don't support it, just declare the "default" value directly.
5563 				// This "default" value can be overridden to the true specialization constant by the API user.
5564 				// Specialization constants which are used as array length expressions cannot be function constants in MSL,
5565 				// so just fall back to macros.
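				// Illustrative output for a float constant with SpecId 10 (name assumed):
				//   constant float foo_tmp [[function_constant(10)]];
				//   constant float foo = is_function_constant_defined(foo_tmp) ? foo_tmp : 1.0;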
5566 				if (msl_options.supports_msl_version(1, 2) && has_decoration(c.self, DecorationSpecId) &&
5567 				    !c.is_used_as_array_length)
5568 				{
5569 					uint32_t constant_id = get_decoration(c.self, DecorationSpecId);
5570 					// Only scalar, non-composite values can be function constants.
5571 					statement("constant ", sc_type_name, " ", sc_tmp_name, " [[function_constant(", constant_id,
5572 					          ")]];");
5573 					statement("constant ", sc_type_name, " ", sc_name, " = is_function_constant_defined(", sc_tmp_name,
5574 					          ") ? ", sc_tmp_name, " : ", constant_expression(c), ";");
5575 				}
5576 				else if (has_decoration(c.self, DecorationSpecId))
5577 				{
5578 					// Fallback to macro overrides.
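					// The generated macro name typically has the form SPIRV_CROSS_CONSTANT_ID_<n>,
					// so (illustratively) an API user can override the default with e.g.
					// -DSPIRV_CROSS_CONSTANT_ID_10=2.5 when compiling the MSL.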
5579 					c.specialization_constant_macro_name =
5580 					    constant_value_macro_name(get_decoration(c.self, DecorationSpecId));
5581 
5582 					statement("#ifndef ", c.specialization_constant_macro_name);
5583 					statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
5584 					statement("#endif");
5585 					statement("constant ", sc_type_name, " ", sc_name, " = ", c.specialization_constant_macro_name,
5586 					          ";");
5587 				}
5588 				else
5589 				{
5590 					// Composite specialization constants must be built from other specialization constants.
5591 					statement("constant ", sc_type_name, " ", sc_name, " = ", constant_expression(c), ";");
5592 				}
5593 				emitted = true;
5594 			}
5595 		}
5596 		else if (id.get_type() == TypeConstantOp)
5597 		{
5598 			auto &c = id.get<SPIRConstantOp>();
5599 			auto &type = get<SPIRType>(c.basetype);
5600 			auto name = to_name(c.self);
5601 			statement("constant ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
5602 			emitted = true;
5603 		}
5604 		else if (id.get_type() == TypeType)
5605 		{
5606 			// Output non-builtin interface structs. These include local function structs
5607 			// and structs nested within uniform and read-write buffers.
5608 			auto &type = id.get<SPIRType>();
5609 			TypeID type_id = type.self;
5610 
5611 			bool is_struct = (type.basetype == SPIRType::Struct) && type.array.empty() && !type.pointer;
5612 			bool is_block =
5613 			    has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
5614 
5615 			bool is_builtin_block = is_block && is_builtin_type(type);
5616 			bool is_declarable_struct = is_struct && !is_builtin_block;
5617 
5618 			// We'll declare this later.
5619 			if (stage_out_var_id && get_stage_out_struct_type().self == type_id)
5620 				is_declarable_struct = false;
5621 			if (patch_stage_out_var_id && get_patch_stage_out_struct_type().self == type_id)
5622 				is_declarable_struct = false;
5623 			if (stage_in_var_id && get_stage_in_struct_type().self == type_id)
5624 				is_declarable_struct = false;
5625 			if (patch_stage_in_var_id && get_patch_stage_in_struct_type().self == type_id)
5626 				is_declarable_struct = false;
5627 
5628 			// Align and emit declarable structs...but avoid declaring each more than once.
5629 			if (is_declarable_struct && declared_structs.count(type_id) == 0)
5630 			{
5631 				if (emitted)
5632 					statement("");
5633 				emitted = false;
5634 
5635 				declared_structs.insert(type_id);
5636 
5637 				if (has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
5638 					align_struct(type, aligned_structs);
5639 
5640 				// Make sure we declare the underlying struct type, and not the "decorated" type with pointers, etc.
5641 				emit_struct(get<SPIRType>(type_id));
5642 			}
5643 		}
5644 	}
5645 
5646 	if (emitted)
5647 		statement("");
5648 }
5649 
5650 void CompilerMSL::emit_binary_unord_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
5651                                        const char *op)
5652 {
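	// An "unordered" comparison is true when either operand is NaN; e.g. OpFUnordGreaterThan
	// on (a, b) lowers to (isunordered(a, b) || a > b) below.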
5653 	bool forward = should_forward(op0) && should_forward(op1);
5654 	emit_op(result_type, result_id,
5655 	        join("(isunordered(", to_enclosed_unpacked_expression(op0), ", ", to_enclosed_unpacked_expression(op1),
5656 	             ") || ", to_enclosed_unpacked_expression(op0), " ", op, " ", to_enclosed_unpacked_expression(op1),
5657 	             ")"),
5658 	        forward);
5659 
5660 	inherit_expression_dependencies(result_id, op0);
5661 	inherit_expression_dependencies(result_id, op1);
5662 }
5663 
5664 bool CompilerMSL::emit_tessellation_io_load(uint32_t result_type_id, uint32_t id, uint32_t ptr)
5665 {
5666 	auto &ptr_type = expression_type(ptr);
5667 	auto &result_type = get<SPIRType>(result_type_id);
5668 	if (ptr_type.storage != StorageClassInput && ptr_type.storage != StorageClassOutput)
5669 		return false;
5670 	if (ptr_type.storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationEvaluation)
5671 		return false;
5672 
5673 	bool multi_patch_tess_ctl = get_execution_model() == ExecutionModelTessellationControl &&
5674 	                            msl_options.multi_patch_workgroup && ptr_type.storage == StorageClassInput;
5675 	bool flat_matrix = is_matrix(result_type) && ptr_type.storage == StorageClassInput && !multi_patch_tess_ctl;
5676 	bool flat_struct = result_type.basetype == SPIRType::Struct && ptr_type.storage == StorageClassInput;
5677 	bool flat_data_type = flat_matrix || is_array(result_type) || flat_struct;
5678 	if (!flat_data_type)
5679 		return false;
5680 
5681 	if (has_decoration(ptr, DecorationPatch))
5682 		return false;
5683 
5684 	// Now, we must unflatten a composite type and take care of interleaving array access with gl_in/gl_out.
5685 	// Lots of painful code duplication since we *really* should not unroll these kinds of loads in entry point fixup
5686 	// unless we're forced to do this when the code is emitting suboptimal OpLoads.
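	// Sketch (illustrative names): loading a per-vertex float3 member m across three
	// control points unrolls to roughly T({ gl_in[0].m, gl_in[1].m, gl_in[2].m }),
	// where T is the MSL spelling of the result array type.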
5687 	string expr;
5688 
5689 	uint32_t interface_index = get_extended_decoration(ptr, SPIRVCrossDecorationInterfaceMemberIndex);
5690 	auto *var = maybe_get_backing_variable(ptr);
5691 	bool ptr_is_io_variable = ir.ids[ptr].get_type() == TypeVariable;
5692 	auto &expr_type = get_pointee_type(ptr_type.self);
5693 
5694 	const auto &iface_type = expression_type(stage_in_ptr_var_id);
5695 
5696 	if (result_type.array.size() > 2)
5697 	{
5698 		SPIRV_CROSS_THROW("Cannot load tessellation IO variables with more than 2 dimensions.");
5699 	}
5700 	else if (result_type.array.size() == 2)
5701 	{
5702 		if (!ptr_is_io_variable)
5703 			SPIRV_CROSS_THROW("An array-of-array must be loaded directly from an IO variable.");
5704 		if (interface_index == uint32_t(-1))
5705 			SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
5706 		if (result_type.basetype == SPIRType::Struct || flat_matrix)
5707 			SPIRV_CROSS_THROW("Cannot load array-of-array of composite type in tessellation IO.");
5708 
5709 		expr += type_to_glsl(result_type) + "({ ";
5710 		uint32_t num_control_points = to_array_size_literal(result_type, 1);
5711 		uint32_t base_interface_index = interface_index;
5712 
5713 		auto &sub_type = get<SPIRType>(result_type.parent_type);
5714 
5715 		for (uint32_t i = 0; i < num_control_points; i++)
5716 		{
5717 			expr += type_to_glsl(sub_type) + "({ ";
5718 			interface_index = base_interface_index;
5719 			uint32_t array_size = to_array_size_literal(result_type, 0);
5720 			if (multi_patch_tess_ctl)
5721 			{
5722 				for (uint32_t j = 0; j < array_size; j++)
5723 				{
5724 					const uint32_t indices[3] = { i, interface_index, j };
5725 
5726 					AccessChainMeta meta;
5727 					expr +=
5728 					    access_chain_internal(stage_in_ptr_var_id, indices, 3,
5729 					                          ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
5730 					// If the expression has more vector components than the result type, insert
5731 					// a swizzle. This shouldn't happen normally on valid SPIR-V, but it might
5732 					// happen if we replace the type of an input variable.
5733 					if (!is_matrix(sub_type) && sub_type.basetype != SPIRType::Struct &&
5734 					    expr_type.vecsize > sub_type.vecsize)
5735 						expr += vector_swizzle(sub_type.vecsize, 0);
5736 
5737 					if (j + 1 < array_size)
5738 						expr += ", ";
5739 				}
5740 			}
5741 			else
5742 			{
5743 				for (uint32_t j = 0; j < array_size; j++, interface_index++)
5744 				{
5745 					const uint32_t indices[2] = { i, interface_index };
5746 
5747 					AccessChainMeta meta;
5748 					expr +=
5749 					    access_chain_internal(stage_in_ptr_var_id, indices, 2,
5750 					                          ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
5751 					if (!is_matrix(sub_type) && sub_type.basetype != SPIRType::Struct &&
5752 					    expr_type.vecsize > sub_type.vecsize)
5753 						expr += vector_swizzle(sub_type.vecsize, 0);
5754 
5755 					if (j + 1 < array_size)
5756 						expr += ", ";
5757 				}
5758 			}
5759 			expr += " })";
5760 			if (i + 1 < num_control_points)
5761 				expr += ", ";
5762 		}
5763 		expr += " })";
5764 	}
5765 	else if (flat_struct)
5766 	{
5767 		bool is_array_of_struct = is_array(result_type);
5768 		if (is_array_of_struct && !ptr_is_io_variable)
5769 			SPIRV_CROSS_THROW("Loading an array of structs must come directly from an IO variable.");
5770 
5771 		uint32_t num_control_points = 1;
5772 		if (is_array_of_struct)
5773 		{
5774 			num_control_points = to_array_size_literal(result_type, 0);
5775 			expr += type_to_glsl(result_type) + "({ ";
5776 		}
5777 
5778 		auto &struct_type = is_array_of_struct ? get<SPIRType>(result_type.parent_type) : result_type;
5779 		assert(struct_type.array.empty());
5780 
5781 		for (uint32_t i = 0; i < num_control_points; i++)
5782 		{
5783 			expr += type_to_glsl(struct_type) + "{ ";
5784 			for (uint32_t j = 0; j < uint32_t(struct_type.member_types.size()); j++)
5785 			{
5786 				// The base interface index is stored per variable for structs.
5787 				if (var)
5788 				{
5789 					interface_index =
5790 					    get_extended_member_decoration(var->self, j, SPIRVCrossDecorationInterfaceMemberIndex);
5791 				}
5792 
5793 				if (interface_index == uint32_t(-1))
5794 					SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
5795 
5796 				const auto &mbr_type = get<SPIRType>(struct_type.member_types[j]);
5797 				const auto &expr_mbr_type = get<SPIRType>(expr_type.member_types[j]);
5798 				if (is_matrix(mbr_type) && ptr_type.storage == StorageClassInput && !multi_patch_tess_ctl)
5799 				{
5800 					expr += type_to_glsl(mbr_type) + "(";
5801 					for (uint32_t k = 0; k < mbr_type.columns; k++, interface_index++)
5802 					{
5803 						if (is_array_of_struct)
5804 						{
5805 							const uint32_t indices[2] = { i, interface_index };
5806 							AccessChainMeta meta;
5807 							expr += access_chain_internal(
5808 							    stage_in_ptr_var_id, indices, 2,
5809 							    ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
5810 						}
5811 						else
5812 							expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
5813 						if (expr_mbr_type.vecsize > mbr_type.vecsize)
5814 							expr += vector_swizzle(mbr_type.vecsize, 0);
5815 
5816 						if (k + 1 < mbr_type.columns)
5817 							expr += ", ";
5818 					}
5819 					expr += ")";
5820 				}
5821 				else if (is_array(mbr_type))
5822 				{
5823 					expr += type_to_glsl(mbr_type) + "({ ";
5824 					uint32_t array_size = to_array_size_literal(mbr_type, 0);
5825 					if (multi_patch_tess_ctl)
5826 					{
5827 						for (uint32_t k = 0; k < array_size; k++)
5828 						{
5829 							if (is_array_of_struct)
5830 							{
5831 								const uint32_t indices[3] = { i, interface_index, k };
5832 								AccessChainMeta meta;
5833 								expr += access_chain_internal(
5834 								    stage_in_ptr_var_id, indices, 3,
5835 								    ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
5836 							}
5837 							else
5838 								expr += join(to_expression(ptr), ".", to_member_name(iface_type, interface_index), "[",
5839 								             k, "]");
5840 							if (expr_mbr_type.vecsize > mbr_type.vecsize)
5841 								expr += vector_swizzle(mbr_type.vecsize, 0);
5842 
5843 							if (k + 1 < array_size)
5844 								expr += ", ";
5845 						}
5846 					}
5847 					else
5848 					{
5849 						for (uint32_t k = 0; k < array_size; k++, interface_index++)
5850 						{
5851 							if (is_array_of_struct)
5852 							{
5853 								const uint32_t indices[2] = { i, interface_index };
5854 								AccessChainMeta meta;
5855 								expr += access_chain_internal(
5856 								    stage_in_ptr_var_id, indices, 2,
5857 								    ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
5858 							}
5859 							else
5860 								expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
5861 							if (expr_mbr_type.vecsize > mbr_type.vecsize)
5862 								expr += vector_swizzle(mbr_type.vecsize, 0);
5863 
5864 							if (k + 1 < array_size)
5865 								expr += ", ";
5866 						}
5867 					}
5868 					expr += " })";
5869 				}
5870 				else
5871 				{
5872 					if (is_array_of_struct)
5873 					{
5874 						const uint32_t indices[2] = { i, interface_index };
5875 						AccessChainMeta meta;
5876 						expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
5877 						                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT,
5878 						                              &meta);
5879 					}
5880 					else
5881 						expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
5882 					if (expr_mbr_type.vecsize > mbr_type.vecsize)
5883 						expr += vector_swizzle(mbr_type.vecsize, 0);
5884 				}
5885 
5886 				if (j + 1 < struct_type.member_types.size())
5887 					expr += ", ";
5888 			}
5889 			expr += " }";
5890 			if (i + 1 < num_control_points)
5891 				expr += ", ";
5892 		}
5893 		if (is_array_of_struct)
5894 			expr += " })";
5895 	}
5896 	else if (flat_matrix)
5897 	{
5898 		bool is_array_of_matrix = is_array(result_type);
5899 		if (is_array_of_matrix && !ptr_is_io_variable)
5900 			SPIRV_CROSS_THROW("Loading an array of matrices must come directly from an IO variable.");
5901 		if (interface_index == uint32_t(-1))
5902 			SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
5903 
5904 		if (is_array_of_matrix)
5905 		{
5906 			// Loading a matrix from each control point.
5907 			uint32_t base_interface_index = interface_index;
5908 			uint32_t num_control_points = to_array_size_literal(result_type, 0);
5909 			expr += type_to_glsl(result_type) + "({ ";
5910 
5911 			auto &matrix_type = get_variable_element_type(get<SPIRVariable>(ptr));
5912 
5913 			for (uint32_t i = 0; i < num_control_points; i++)
5914 			{
5915 				interface_index = base_interface_index;
5916 				expr += type_to_glsl(matrix_type) + "(";
5917 				for (uint32_t j = 0; j < result_type.columns; j++, interface_index++)
5918 				{
5919 					const uint32_t indices[2] = { i, interface_index };
5920 
5921 					AccessChainMeta meta;
5922 					expr +=
5923 					    access_chain_internal(stage_in_ptr_var_id, indices, 2,
5924 					                          ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
5925 					if (expr_type.vecsize > result_type.vecsize)
5926 						expr += vector_swizzle(result_type.vecsize, 0);
5927 					if (j + 1 < result_type.columns)
5928 						expr += ", ";
5929 				}
5930 				expr += ")";
5931 				if (i + 1 < num_control_points)
5932 					expr += ", ";
5933 			}
5934 
5935 			expr += " })";
5936 		}
5937 		else
5938 		{
5939 			expr += type_to_glsl(result_type) + "(";
5940 			for (uint32_t i = 0; i < result_type.columns; i++, interface_index++)
5941 			{
5942 				expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
5943 				if (expr_type.vecsize > result_type.vecsize)
5944 					expr += vector_swizzle(result_type.vecsize, 0);
5945 				if (i + 1 < result_type.columns)
5946 					expr += ", ";
5947 			}
5948 			expr += ")";
5949 		}
5950 	}
5951 	else if (ptr_is_io_variable)
5952 	{
5953 		assert(is_array(result_type));
5954 		assert(result_type.array.size() == 1);
5955 		if (interface_index == uint32_t(-1))
5956 			SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
5957 
5958 		// We're loading an array directly from a global variable.
5959 		// This means we're loading one member from each control point.
5960 		expr += type_to_glsl(result_type) + "({ ";
5961 		uint32_t num_control_points = to_array_size_literal(result_type, 0);
5962 
5963 		for (uint32_t i = 0; i < num_control_points; i++)
5964 		{
5965 			const uint32_t indices[2] = { i, interface_index };
5966 
5967 			AccessChainMeta meta;
5968 			expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
5969 			                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
5970 			if (expr_type.vecsize > result_type.vecsize)
5971 				expr += vector_swizzle(result_type.vecsize, 0);
5972 
5973 			if (i + 1 < num_control_points)
5974 				expr += ", ";
5975 		}
5976 		expr += " })";
5977 	}
5978 	else
5979 	{
5980 		// We're loading an array from a concrete control point.
5981 		assert(is_array(result_type));
5982 		assert(result_type.array.size() == 1);
5983 		if (interface_index == uint32_t(-1))
5984 			SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
5985 
5986 		expr += type_to_glsl(result_type) + "({ ";
5987 		uint32_t array_size = to_array_size_literal(result_type, 0);
5988 		for (uint32_t i = 0; i < array_size; i++, interface_index++)
5989 		{
5990 			expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
5991 			if (expr_type.vecsize > result_type.vecsize)
5992 				expr += vector_swizzle(result_type.vecsize, 0);
5993 			if (i + 1 < array_size)
5994 				expr += ", ";
5995 		}
5996 		expr += " })";
5997 	}
5998 
5999 	emit_op(result_type_id, id, expr, false);
6000 	register_read(id, ptr, false);
6001 	return true;
6002 }
6003 
6004 bool CompilerMSL::emit_tessellation_access_chain(const uint32_t *ops, uint32_t length)
6005 {
6006 	// If this is a per-vertex output, remap it to the I/O array buffer.
6007 
6008 	// Any object which did not go through the IO flattening shenanigans is remapped there as well.
6009 	// We unflatten on demand as needed, but not all possible cases can be supported, especially with arrays.
6010 
6011 	auto *var = maybe_get_backing_variable(ops[2]);
6012 	bool patch = false;
6013 	bool flat_data = false;
6014 	bool ptr_is_chain = false;
6015 	bool multi_patch = get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup;
6016 
6017 	if (var)
6018 	{
6019 		patch = has_decoration(ops[2], DecorationPatch) || is_patch_block(get_variable_data_type(*var));
6020 
6021 		// Should match strip_array in add_interface_block.
6022 		flat_data = var->storage == StorageClassInput ||
6023 		            (var->storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationControl);
6024 
6025 		// We might have a chained access chain, where
6026 		// we first take the access chain to the control point, and then we chain into a member or something similar.
6027 		// In this case, we need to skip gl_in/gl_out remapping.
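		// E.g. (sketch): a first chain yields the control point, gl_in[i]; a second chain
		// into one of its members must not apply the gl_in/gl_out remapping again.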
6028 		ptr_is_chain = var->self != ID(ops[2]);
6029 	}
6030 
6031 	BuiltIn bi_type = BuiltIn(get_decoration(ops[2], DecorationBuiltIn));
6032 	if (var && flat_data && !patch &&
6033 	    (!is_builtin_variable(*var) || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
6034 	     bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
6035 	     get_variable_data_type(*var).basetype == SPIRType::Struct))
6036 	{
6037 		AccessChainMeta meta;
6038 		SmallVector<uint32_t> indices;
6039 		uint32_t next_id = ir.increase_bound_by(1);
6040 
6041 		indices.reserve(length - 3 + 1);
6042 
6043 		uint32_t first_non_array_index = ptr_is_chain ? 3 : 4;
6044 		VariableID stage_var_id = var->storage == StorageClassInput ? stage_in_ptr_var_id : stage_out_ptr_var_id;
6045 		VariableID ptr = ptr_is_chain ? VariableID(ops[2]) : stage_var_id;
6046 		if (!ptr_is_chain)
6047 		{
6048 			// Index into gl_in/gl_out with first array index.
6049 			indices.push_back(ops[3]);
6050 		}
6051 
6052 		auto &result_ptr_type = get<SPIRType>(ops[0]);
6053 
6054 		uint32_t const_mbr_id = next_id++;
6055 		uint32_t index = get_extended_decoration(var->self, SPIRVCrossDecorationInterfaceMemberIndex);
6056 		if (var->storage == StorageClassInput || has_decoration(get_variable_element_type(*var).self, DecorationBlock))
6057 		{
6058 			uint32_t i = first_non_array_index;
6059 			auto *type = &get_variable_element_type(*var);
6060 			if (index == uint32_t(-1) && length >= (first_non_array_index + 1))
6061 			{
6062 				// Maybe this is a struct type in the Input storage class, in which case
6063 				// we put it as a decoration on the corresponding member.
6064 				index = get_extended_member_decoration(var->self, get_constant(ops[first_non_array_index]).scalar(),
6065 				                                       SPIRVCrossDecorationInterfaceMemberIndex);
6066 				assert(index != uint32_t(-1));
6067 				i++;
6068 				type = &get<SPIRType>(type->member_types[get_constant(ops[first_non_array_index]).scalar()]);
6069 			}
6070 
6071 			// In this case, we're poking into flattened structures and arrays, so now we have to
6072 			// combine the following indices. If we encounter a non-constant index,
6073 			// we're hosed.
6074 			for (; i < length; ++i)
6075 			{
6076 				if ((multi_patch || (!is_array(*type) && !is_matrix(*type))) && type->basetype != SPIRType::Struct)
6077 					break;
6078 
6079 				auto *c = maybe_get<SPIRConstant>(ops[i]);
6080 				if (!c || c->specialization)
6081 					SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable in tessellation. "
6082 					                  "This is currently unsupported.");
6083 
6084 				// We're in flattened space, so just increment the member index into IO block.
6085 				// We can only do this once in the current implementation, so either:
6086 				// Struct, Matrix or 1-dimensional array for a control point.
6087 				index += c->scalar();
6088 
6089 				if (type->parent_type)
6090 					type = &get<SPIRType>(type->parent_type);
6091 				else if (type->basetype == SPIRType::Struct)
6092 					type = &get<SPIRType>(type->member_types[c->scalar()]);
6093 			}
6094 
6095 			if ((!multi_patch && (is_matrix(result_ptr_type) || is_array(result_ptr_type))) ||
6096 			    result_ptr_type.basetype == SPIRType::Struct)
6097 			{
6098 				// We're not going to emit the actual member name, we let any further OpLoad take care of that.
6099 				// Tag the access chain with the member index we're referencing.
6100 				set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, index);
6101 			}
6102 			else
6103 			{
6104 				// Access the appropriate member of gl_in/gl_out.
6105 				set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
6106 				indices.push_back(const_mbr_id);
6107 
6108 				// Append any straggling access chain indices.
6109 				if (i < length)
6110 					indices.insert(indices.end(), ops + i, ops + length);
6111 			}
6112 		}
6113 		else
6114 		{
6115 			assert(index != uint32_t(-1));
6116 			set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
6117 			indices.push_back(const_mbr_id);
6118 
6119 			indices.insert(indices.end(), ops + 4, ops + length);
6120 		}
6121 
6122 		// We use the pointer to the base of the input/output array here,
6123 		// so this is always a pointer chain.
6124 		string e;
6125 
6126 		if (!ptr_is_chain)
6127 		{
6128 			// This is the start of an access chain, use ptr_chain to index into control point array.
6129 			e = access_chain(ptr, indices.data(), uint32_t(indices.size()), result_ptr_type, &meta, true);
6130 		}
6131 		else
6132 		{
6133 			// If we're accessing a struct, we need to use member indices which are based on the IO block,
6134 			// not actual struct type, so we have to use a split access chain here where
6135 			// first path resolves the control point index, i.e. gl_in[index], and second half deals with
6136 			// looking up flattened member name.
6137 
6138 			// However, it is possible that we partially accessed a struct,
6139 			// by taking pointer to member inside the control-point array.
6140 			// For this case, we fall back to a natural access chain since we have already dealt with remapping struct members.
6141 			// One way to check this here is if we have 2 implied read expressions.
6142 			// First one is the gl_in/gl_out struct itself, then an index into that array.
6143 			// If we have traversed further, we use a normal access chain formulation.
6144 			auto *ptr_expr = maybe_get<SPIRExpression>(ptr);
6145 			if (ptr_expr && ptr_expr->implied_read_expressions.size() == 2)
6146 			{
6147 				e = join(to_expression(ptr),
6148 				         access_chain_internal(stage_var_id, indices.data(), uint32_t(indices.size()),
6149 				                               ACCESS_CHAIN_CHAIN_ONLY_BIT, &meta));
6150 			}
6151 			else
6152 			{
6153 				e = access_chain_internal(ptr, indices.data(), uint32_t(indices.size()), 0, &meta);
6154 			}
6155 		}
6156 
6157 		// Get the actual type of the object that was accessed. If it's a vector type and we changed it,
6158 		// then we'll need to add a swizzle.
6159 		// For this, we can't necessarily rely on the type of the base expression, because it might be
6160 		// another access chain, and it will therefore already have the "correct" type.
6161 		auto *expr_type = &get_variable_data_type(*var);
6162 		if (has_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID))
6163 			expr_type = &get<SPIRType>(get_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID));
6164 		for (uint32_t i = 3; i < length; i++)
6165 		{
6166 			if (!is_array(*expr_type) && expr_type->basetype == SPIRType::Struct)
6167 				expr_type = &get<SPIRType>(expr_type->member_types[get<SPIRConstant>(ops[i]).scalar()]);
6168 			else
6169 				expr_type = &get<SPIRType>(expr_type->parent_type);
6170 		}
6171 		if (!is_array(*expr_type) && !is_matrix(*expr_type) && expr_type->basetype != SPIRType::Struct &&
6172 		    expr_type->vecsize > result_ptr_type.vecsize)
6173 			e += vector_swizzle(result_ptr_type.vecsize, 0);
6174 
6175 		auto &expr = set<SPIRExpression>(ops[1], move(e), ops[0], should_forward(ops[2]));
6176 		expr.loaded_from = var->self;
6177 		expr.need_transpose = meta.need_transpose;
6178 		expr.access_chain = true;
6179 
6180 		// Mark the result as being packed if necessary.
6181 		if (meta.storage_is_packed)
6182 			set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypePacked);
6183 		if (meta.storage_physical_type != 0)
6184 			set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypeID, meta.storage_physical_type);
6185 		if (meta.storage_is_invariant)
6186 			set_decoration(ops[1], DecorationInvariant);
6187 		// Save the type we found in case the result is used in another access chain.
6188 		set_extended_decoration(ops[1], SPIRVCrossDecorationTessIOOriginalInputTypeID, expr_type->self);
6189 
6190 		// If we have some expression dependencies in our access chain, this access chain is technically a forwarded
6191 		// temporary which could be subject to invalidation.
6192 		// Need to assume we're forwarded while calling inherit_expression_dependencies.
6193 		forwarded_temporaries.insert(ops[1]);
6194 		// The access chain itself is never forced to a temporary, but its dependencies might.
6195 		suppressed_usage_tracking.insert(ops[1]);
6196 
6197 		for (uint32_t i = 2; i < length; i++)
6198 		{
6199 			inherit_expression_dependencies(ops[1], ops[i]);
6200 			add_implied_read_expression(expr, ops[i]);
6201 		}
6202 
6203 		// If we have no dependencies after all, i.e., all indices in the access chain are immutable temporaries,
6204 		// we're not forwarded after all.
6205 		if (expr.expression_dependencies.empty())
6206 			forwarded_temporaries.erase(ops[1]);
6207 
6208 		return true;
6209 	}
6210 
6211 	// If this is the inner tessellation level, and we're tessellating triangles,
6212 	// drop the last index. It isn't an array in this case, so we can't have an
6213 	// array reference here. We need to make this ID a variable instead of an
6214 	// expression so we don't try to dereference it as a variable pointer.
6215 	// Don't do this if the index is a constant 1, though. We need to drop stores
6216 	// to that one.
6217 	auto *m = ir.find_meta(var ? var->self : ID(0));
6218 	if (get_execution_model() == ExecutionModelTessellationControl && var && m &&
6219 	    m->decoration.builtin_type == BuiltInTessLevelInner && get_entry_point().flags.get(ExecutionModeTriangles))
6220 	{
6221 		auto *c = maybe_get<SPIRConstant>(ops[3]);
6222 		if (c && c->scalar() == 1)
6223 			return false;
6224 		auto &dest_var = set<SPIRVariable>(ops[1], *var);
6225 		dest_var.basetype = ops[0];
6226 		ir.meta[ops[1]] = ir.meta[ops[2]];
6227 		inherit_expression_dependencies(ops[1], ops[2]);
6228 		return true;
6229 	}
6230 
6231 	return false;
6232 }
6233 
6234 bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
6235 {
6236 	if (!get_entry_point().flags.get(ExecutionModeTriangles))
6237 		return false;
6238 
6239 	// In SPIR-V, TessLevelInner always has two elements and TessLevelOuter always has
6240 	// four. This is true even if we are tessellating triangles. This allows clients
6241 	// to use a single tessellation control shader with multiple tessellation evaluation
6242 	// shaders.
6243 	// In Metal, however, only the first element of TessLevelInner and the first three
6244 	// of TessLevelOuter are accessible. This stems from how in Metal, the tessellation
6245 	// levels must be stored to a dedicated buffer in a particular format that depends
6246 	// on the patch type. Therefore, in Triangles mode, any access to the second
6247 	// inner level or the fourth outer level must be dropped.
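	// E.g. (sketch): an access chain targeting gl_TessLevelInner[1] or gl_TessLevelOuter[3]
	// in Triangles mode is reported here, and the caller then drops the store.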
6248 	const auto *e = maybe_get<SPIRExpression>(id_lhs);
6249 	if (!e || !e->access_chain)
6250 		return false;
6251 	BuiltIn builtin = BuiltIn(get_decoration(e->loaded_from, DecorationBuiltIn));
6252 	if (builtin != BuiltInTessLevelInner && builtin != BuiltInTessLevelOuter)
6253 		return false;
6254 	auto *c = maybe_get<SPIRConstant>(e->implied_read_expressions[1]);
6255 	if (!c)
6256 		return false;
6257 	return (builtin == BuiltInTessLevelInner && c->scalar() == 1) ||
6258 	       (builtin == BuiltInTessLevelOuter && c->scalar() == 3);
6259 }
6260 
6261 void CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
6262                                                          spv::StorageClass storage, bool &is_packed)
6263 {
6264 	// If there is any risk of writes happening with the access chain in question,
6265 	// and there is a risk of concurrent write access to other components,
6266 	// we must cast the access chain to a plain pointer to ensure we only access the exact scalars we expect.
6267 	// The MSL compiler refuses to allow component-level access for any non-packed vector types.
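	// Sketch (illustrative, exact text depends on the chain): a component access such as
	// foo.v.y on a storage buffer is routed through a plain pointer cast, e.g.
	// ((device float3*)&foo.v), and the component is then selected by index under
	// packed rules rather than by swizzle.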
6268 	if (!is_packed && (storage == StorageClassStorageBuffer || storage == StorageClassWorkgroup))
6269 	{
6270 		const char *addr_space = storage == StorageClassWorkgroup ? "threadgroup" : "device";
6271 		expr = join("((", addr_space, " ", type_to_glsl(type), "*)&", enclose_expression(expr), ")");
6272 
6273 		// Further indexing should happen with packed rules (array index, not swizzle).
6274 		is_packed = true;
6275 	}
6276 }
6277 
6278 // Override for MSL-specific syntax instructions
6279 void CompilerMSL::emit_instruction(const Instruction &instruction)
6280 {
6281 #define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
6282 #define MSL_BOP_CAST(op, type) \
6283 	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
6284 #define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
6285 #define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
6286 #define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
6287 #define MSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
6288 #define MSL_BFOP_CAST(op, type) \
6289 	emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
6290 #define MSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
6291 #define MSL_UNORD_BOP(op) emit_binary_unord_op(ops[0], ops[1], ops[2], ops[3], #op)
6292 
6293 	auto ops = stream(instruction);
6294 	auto opcode = static_cast<Op>(instruction.op);
6295 
6296 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
6297 	uint32_t integer_width = get_integer_width_for_instruction(instruction);
6298 	auto int_type = to_signed_basetype(integer_width);
6299 	auto uint_type = to_unsigned_basetype(integer_width);
6300 
6301 	switch (opcode)
6302 	{
6303 	case OpLoad:
6304 	{
6305 		uint32_t id = ops[1];
6306 		uint32_t ptr = ops[2];
6307 		if (is_tessellation_shader())
6308 		{
6309 			if (!emit_tessellation_io_load(ops[0], id, ptr))
6310 				CompilerGLSL::emit_instruction(instruction);
6311 		}
6312 		else
6313 		{
6314 			// Sample mask input for Metal is not an array
6315 			if (BuiltIn(get_decoration(ptr, DecorationBuiltIn)) == BuiltInSampleMask)
6316 				set_decoration(id, DecorationBuiltIn, BuiltInSampleMask);
6317 			CompilerGLSL::emit_instruction(instruction);
6318 		}
6319 		break;
6320 	}
6321 
6322 	// Comparisons
6323 	case OpIEqual:
6324 		MSL_BOP_CAST(==, int_type);
6325 		break;
6326 
6327 	case OpLogicalEqual:
6328 	case OpFOrdEqual:
6329 		MSL_BOP(==);
6330 		break;
6331 
6332 	case OpINotEqual:
6333 		MSL_BOP_CAST(!=, int_type);
6334 		break;
6335 
6336 	case OpLogicalNotEqual:
6337 	case OpFOrdNotEqual:
6338 		MSL_BOP(!=);
6339 		break;
6340 
6341 	case OpUGreaterThan:
6342 		MSL_BOP_CAST(>, uint_type);
6343 		break;
6344 
6345 	case OpSGreaterThan:
6346 		MSL_BOP_CAST(>, int_type);
6347 		break;
6348 
6349 	case OpFOrdGreaterThan:
6350 		MSL_BOP(>);
6351 		break;
6352 
6353 	case OpUGreaterThanEqual:
6354 		MSL_BOP_CAST(>=, uint_type);
6355 		break;
6356 
6357 	case OpSGreaterThanEqual:
6358 		MSL_BOP_CAST(>=, int_type);
6359 		break;
6360 
6361 	case OpFOrdGreaterThanEqual:
6362 		MSL_BOP(>=);
6363 		break;
6364 
6365 	case OpULessThan:
6366 		MSL_BOP_CAST(<, uint_type);
6367 		break;
6368 
6369 	case OpSLessThan:
6370 		MSL_BOP_CAST(<, int_type);
6371 		break;
6372 
6373 	case OpFOrdLessThan:
6374 		MSL_BOP(<);
6375 		break;
6376 
6377 	case OpULessThanEqual:
6378 		MSL_BOP_CAST(<=, uint_type);
6379 		break;
6380 
6381 	case OpSLessThanEqual:
6382 		MSL_BOP_CAST(<=, int_type);
6383 		break;
6384 
6385 	case OpFOrdLessThanEqual:
6386 		MSL_BOP(<=);
6387 		break;
6388 
6389 	case OpFUnordEqual:
6390 		MSL_UNORD_BOP(==);
6391 		break;
6392 
6393 	case OpFUnordNotEqual:
6394 		MSL_UNORD_BOP(!=);
6395 		break;
6396 
6397 	case OpFUnordGreaterThan:
6398 		MSL_UNORD_BOP(>);
6399 		break;
6400 
6401 	case OpFUnordGreaterThanEqual:
6402 		MSL_UNORD_BOP(>=);
6403 		break;
6404 
6405 	case OpFUnordLessThan:
6406 		MSL_UNORD_BOP(<);
6407 		break;
6408 
6409 	case OpFUnordLessThanEqual:
6410 		MSL_UNORD_BOP(<=);
6411 		break;
6412 
6413 	// Derivatives
6414 	case OpDPdx:
6415 	case OpDPdxFine:
6416 	case OpDPdxCoarse:
6417 		MSL_UFOP(dfdx);
6418 		register_control_dependent_expression(ops[1]);
6419 		break;
6420 
6421 	case OpDPdy:
6422 	case OpDPdyFine:
6423 	case OpDPdyCoarse:
6424 		MSL_UFOP(dfdy);
6425 		register_control_dependent_expression(ops[1]);
6426 		break;
6427 
6428 	case OpFwidth:
6429 	case OpFwidthCoarse:
6430 	case OpFwidthFine:
6431 		MSL_UFOP(fwidth);
6432 		register_control_dependent_expression(ops[1]);
6433 		break;
6434 
6435 	// Bitfield
6436 	case OpBitFieldInsert:
6437 	{
6438 		emit_bitfield_insert_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], "insert_bits", SPIRType::UInt);
6439 		break;
6440 	}
6441 
6442 	case OpBitFieldSExtract:
6443 	{
6444 		emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", int_type, int_type,
6445 		                                SPIRType::UInt, SPIRType::UInt);
6446 		break;
6447 	}
6448 
6449 	case OpBitFieldUExtract:
6450 	{
6451 		emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", uint_type, uint_type,
6452 		                                SPIRType::UInt, SPIRType::UInt);
6453 		break;
6454 	}
6455 
6456 	case OpBitReverse:
6457 		// BitReverse does not have issues with sign since result type must match input type.
6458 		MSL_UFOP(reverse_bits);
6459 		break;
6460 
6461 	case OpBitCount:
6462 	{
6463 		auto basetype = expression_type(ops[2]).basetype;
6464 		emit_unary_func_op_cast(ops[0], ops[1], ops[2], "popcount", basetype, basetype);
6465 		break;
6466 	}
6467 
6468 	case OpFRem:
6469 		MSL_BFOP(fmod);
6470 		break;
6471 
6472 	case OpFMul:
6473 		if (msl_options.invariant_float_math)
6474 			MSL_BFOP(spvFMul);
6475 		else
6476 			MSL_BOP(*);
6477 		break;
6478 
6479 	case OpFAdd:
6480 		if (msl_options.invariant_float_math)
6481 			MSL_BFOP(spvFAdd);
6482 		else
6483 			MSL_BOP(+);
6484 		break;
6485 
6486 	// Atomics
6487 	case OpAtomicExchange:
6488 	{
6489 		uint32_t result_type = ops[0];
6490 		uint32_t id = ops[1];
6491 		uint32_t ptr = ops[2];
6492 		uint32_t mem_sem = ops[4];
6493 		uint32_t val = ops[5];
6494 		emit_atomic_func_op(result_type, id, "atomic_exchange_explicit", mem_sem, mem_sem, false, ptr, val);
6495 		break;
6496 	}
6497 
6498 	case OpAtomicCompareExchange:
6499 	{
6500 		uint32_t result_type = ops[0];
6501 		uint32_t id = ops[1];
6502 		uint32_t ptr = ops[2];
6503 		uint32_t mem_sem_pass = ops[4];
6504 		uint32_t mem_sem_fail = ops[5];
6505 		uint32_t val = ops[6];
6506 		uint32_t comp = ops[7];
6507 		emit_atomic_func_op(result_type, id, "atomic_compare_exchange_weak_explicit", mem_sem_pass, mem_sem_fail, true,
6508 		                    ptr, comp, true, false, val);
6509 		break;
6510 	}
6511 
6512 	case OpAtomicCompareExchangeWeak:
6513 		SPIRV_CROSS_THROW("OpAtomicCompareExchangeWeak is only supported in kernel profile.");
6514 
6515 	case OpAtomicLoad:
6516 	{
6517 		uint32_t result_type = ops[0];
6518 		uint32_t id = ops[1];
6519 		uint32_t ptr = ops[2];
6520 		uint32_t mem_sem = ops[4];
6521 		emit_atomic_func_op(result_type, id, "atomic_load_explicit", mem_sem, mem_sem, false, ptr, 0);
6522 		break;
6523 	}
6524 
6525 	case OpAtomicStore:
6526 	{
6527 		uint32_t result_type = expression_type(ops[0]).self;
6528 		uint32_t id = ops[0];
6529 		uint32_t ptr = ops[0];
6530 		uint32_t mem_sem = ops[2];
6531 		uint32_t val = ops[3];
6532 		emit_atomic_func_op(result_type, id, "atomic_store_explicit", mem_sem, mem_sem, false, ptr, val);
6533 		break;
6534 	}
6535 
6536 #define MSL_AFMO_IMPL(op, valsrc, valconst)                                                                      \
6537 	do                                                                                                           \
6538 	{                                                                                                            \
6539 		uint32_t result_type = ops[0];                                                                           \
6540 		uint32_t id = ops[1];                                                                                    \
6541 		uint32_t ptr = ops[2];                                                                                   \
6542 		uint32_t mem_sem = ops[4];                                                                               \
6543 		uint32_t val = valsrc;                                                                                   \
6544 		emit_atomic_func_op(result_type, id, "atomic_fetch_" #op "_explicit", mem_sem, mem_sem, false, ptr, val, \
6545 		                    false, valconst);                                                                    \
6546 	} while (false)
6547 
6548 #define MSL_AFMO(op) MSL_AFMO_IMPL(op, ops[5], false)
6549 #define MSL_AFMIO(op) MSL_AFMO_IMPL(op, 1, true)
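// For instance, OpAtomicIIncrement (MSL_AFMIO(add)) lowers to a call like
// "atomic_fetch_add_explicit((device atomic_uint*)&<obj>, 1, memory_order_relaxed)".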
6550 
6551 	case OpAtomicIIncrement:
6552 		MSL_AFMIO(add);
6553 		break;
6554 
6555 	case OpAtomicIDecrement:
6556 		MSL_AFMIO(sub);
6557 		break;
6558 
6559 	case OpAtomicIAdd:
6560 		MSL_AFMO(add);
6561 		break;
6562 
6563 	case OpAtomicISub:
6564 		MSL_AFMO(sub);
6565 		break;
6566 
6567 	case OpAtomicSMin:
6568 	case OpAtomicUMin:
6569 		MSL_AFMO(min);
6570 		break;
6571 
6572 	case OpAtomicSMax:
6573 	case OpAtomicUMax:
6574 		MSL_AFMO(max);
6575 		break;
6576 
6577 	case OpAtomicAnd:
6578 		MSL_AFMO(and);
6579 		break;
6580 
6581 	case OpAtomicOr:
6582 		MSL_AFMO(or);
6583 		break;
6584 
6585 	case OpAtomicXor:
6586 		MSL_AFMO(xor);
6587 		break;
6588 
6589 	// Images
6590 
6591 	// Reads == Fetches in Metal
6592 	case OpImageRead:
6593 	{
6594 		// Mark that this shader reads from this image
6595 		uint32_t img_id = ops[2];
6596 		auto &type = expression_type(img_id);
6597 		if (type.image.dim != DimSubpassData)
6598 		{
6599 			auto *p_var = maybe_get_backing_variable(img_id);
6600 			if (p_var && has_decoration(p_var->self, DecorationNonReadable))
6601 			{
6602 				unset_decoration(p_var->self, DecorationNonReadable);
6603 				force_recompile();
6604 			}
6605 		}
6606 
6607 		emit_texture_op(instruction, false);
6608 		break;
6609 	}
6610 
6611 	// Emulate texture2D atomic operations
6612 	case OpImageTexelPointer:
6613 	{
6614 		// When using the pointer, we need to know which variable it is actually loaded from.
6615 		auto *var = maybe_get_backing_variable(ops[2]);
6616 		if (var && atomic_image_vars.count(var->self))
6617 		{
6618 			uint32_t result_type = ops[0];
6619 			uint32_t id = ops[1];
6620 
6621 			std::string coord = to_expression(ops[3]);
6622 			auto &type = expression_type(ops[2]);
6623 			if (type.image.dim == Dim2D)
6624 			{
6625 				coord = join("spvImage2DAtomicCoord(", coord, ", ", to_expression(ops[2]), ")");
6626 			}
6627 
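			// The "texel pointer" resolves to an index into a companion "<image>_atomic"
			// device buffer that shadows the texture; spvImage2DAtomicCoord() linearizes
			// the 2D coordinate into that buffer.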
6628 			auto &e = set<SPIRExpression>(id, join(to_expression(ops[2]), "_atomic[", coord, "]"), result_type, true);
6629 			e.loaded_from = var ? var->self : ID(0);
6630 			inherit_expression_dependencies(id, ops[3]);
6631 		}
6632 		else
6633 		{
6634 			uint32_t result_type = ops[0];
6635 			uint32_t id = ops[1];
6636 			auto &e =
6637 			    set<SPIRExpression>(id, join(to_expression(ops[2]), ", ", to_expression(ops[3])), result_type, true);
6638 
6639 			// When using the pointer, we need to know which variable it is actually loaded from.
6640 			e.loaded_from = var ? var->self : ID(0);
6641 			inherit_expression_dependencies(id, ops[3]);
6642 		}
6643 		break;
6644 	}
6645 
6646 	case OpImageWrite:
6647 	{
6648 		uint32_t img_id = ops[0];
6649 		uint32_t coord_id = ops[1];
6650 		uint32_t texel_id = ops[2];
6651 		const uint32_t *opt = &ops[3];
6652 		uint32_t length = instruction.length - 3;
6653 
6654 		// Bypass pointers because we need the real image struct
6655 		auto &type = expression_type(img_id);
6656 		auto &img_type = get<SPIRType>(type.self);
6657 
6658 		// Ensure this image has been marked as being written to, and force a
6659 		// recompile so that the image type output will include write access.
6660 		auto *p_var = maybe_get_backing_variable(img_id);
6661 		if (p_var && has_decoration(p_var->self, DecorationNonWritable))
6662 		{
6663 			unset_decoration(p_var->self, DecorationNonWritable);
6664 			force_recompile();
6665 		}
6666 
6667 		bool forward = false;
6668 		uint32_t bias = 0;
6669 		uint32_t lod = 0;
6670 		uint32_t flags = 0;
6671 
6672 		if (length)
6673 		{
6674 			flags = *opt++;
6675 			length--;
6676 		}
6677 
6678 		auto test = [&](uint32_t &v, uint32_t flag) {
6679 			if (length && (flags & flag))
6680 			{
6681 				v = *opt++;
6682 				length--;
6683 			}
6684 		};
6685 
6686 		test(bias, ImageOperandsBiasMask);
6687 		test(lod, ImageOperandsLodMask);
6688 
6689 		auto &texel_type = expression_type(texel_id);
6690 		auto store_type = texel_type;
6691 		store_type.vecsize = 4;
6692 
6693 		TextureFunctionArguments args = {};
6694 		args.base.img = img_id;
6695 		args.base.imgtype = &img_type;
6696 		args.base.is_fetch = true;
6697 		args.coord = coord_id;
6698 		args.lod = lod;
6699 		statement(join(to_expression(img_id), ".write(",
6700 		               remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)), ", ",
6701 		               CompilerMSL::to_function_args(args, &forward), ");"));
6702 
6703 		if (p_var && variable_storage_is_aliased(*p_var))
6704 			flush_all_aliased_variables();
6705 
6706 		break;
6707 	}
6708 
6709 	case OpImageQuerySize:
6710 	case OpImageQuerySizeLod:
6711 	{
6712 		uint32_t rslt_type_id = ops[0];
6713 		auto &rslt_type = get<SPIRType>(rslt_type_id);
6714 
6715 		uint32_t id = ops[1];
6716 
6717 		uint32_t img_id = ops[2];
6718 		string img_exp = to_expression(img_id);
6719 		auto &img_type = expression_type(img_id);
6720 		Dim img_dim = img_type.image.dim;
6721 		bool img_is_array = img_type.image.arrayed;
6722 
6723 		if (img_type.basetype != SPIRType::Image)
6724 			SPIRV_CROSS_THROW("Invalid type for OpImageQuerySize.");
6725 
6726 		string lod;
6727 		if (opcode == OpImageQuerySizeLod)
6728 		{
6729 			// The LOD index defaults to zero, so don't bother outputting an explicit zero index
6730 			string decl_lod = to_expression(ops[3]);
6731 			if (decl_lod != "0")
6732 				lod = decl_lod;
6733 		}
6734 
6735 		string expr = type_to_glsl(rslt_type) + "(";
6736 		expr += img_exp + ".get_width(" + lod + ")";
6737 
6738 		if (img_dim == Dim2D || img_dim == DimCube || img_dim == Dim3D)
6739 			expr += ", " + img_exp + ".get_height(" + lod + ")";
6740 
6741 		if (img_dim == Dim3D)
6742 			expr += ", " + img_exp + ".get_depth(" + lod + ")";
6743 
6744 		if (img_is_array)
6745 		{
6746 			expr += ", " + img_exp + ".get_array_size()";
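			// Emulated cube arrays are 2D texture arrays with six layers per cube,
			// so divide the layer count by 6 to recover the number of cubes.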
6747 			if (img_dim == DimCube && msl_options.emulate_cube_array)
6748 				expr += " / 6";
6749 		}
6750 
6751 		expr += ")";
6752 
6753 		emit_op(rslt_type_id, id, expr, should_forward(img_id));
6754 
6755 		break;
6756 	}
6757 
6758 	case OpImageQueryLod:
6759 	{
6760 		if (!msl_options.supports_msl_version(2, 2))
6761 			SPIRV_CROSS_THROW("ImageQueryLod is only supported on MSL 2.2 and up.");
6762 		uint32_t result_type = ops[0];
6763 		uint32_t id = ops[1];
6764 		uint32_t image_id = ops[2];
6765 		uint32_t coord_id = ops[3];
6766 		emit_uninitialized_temporary_expression(result_type, id);
6767 
6768 		auto sampler_expr = to_sampler_expression(image_id);
6769 		auto *combined = maybe_get<SPIRCombinedImageSampler>(image_id);
6770 		auto image_expr = combined ? to_expression(combined->image) : to_expression(image_id);
6771 
6772 		// TODO: It is unclear if calculate_clamped_lod also conditionally rounds
6773 		// the reported LOD based on the sampler. A NEAREST mip filter should
6774 		// round the LOD, but a LINEAR mip filter should not.
6775 		// Let's hope this does not become an issue ...
6776 		statement(to_expression(id), ".x = ", image_expr, ".calculate_clamped_lod(", sampler_expr, ", ",
6777 		          to_expression(coord_id), ");");
6778 		statement(to_expression(id), ".y = ", image_expr, ".calculate_unclamped_lod(", sampler_expr, ", ",
6779 		          to_expression(coord_id), ");");
6780 		register_control_dependent_expression(id);
6781 		break;
6782 	}
6783 
6784 #define MSL_ImgQry(qrytype)                                                                 \
6785 	do                                                                                      \
6786 	{                                                                                       \
6787 		uint32_t rslt_type_id = ops[0];                                                     \
6788 		auto &rslt_type = get<SPIRType>(rslt_type_id);                                      \
6789 		uint32_t id = ops[1];                                                               \
6790 		uint32_t img_id = ops[2];                                                           \
6791 		string img_exp = to_expression(img_id);                                             \
6792 		string expr = type_to_glsl(rslt_type) + "(" + img_exp + ".get_num_" #qrytype "())"; \
6793 		emit_op(rslt_type_id, id, expr, should_forward(img_id));                            \
6794 	} while (false)
6795 
6796 	case OpImageQueryLevels:
6797 		MSL_ImgQry(mip_levels);
6798 		break;
6799 
6800 	case OpImageQuerySamples:
6801 		MSL_ImgQry(samples);
6802 		break;
6803 
6804 	case OpImage:
6805 	{
6806 		uint32_t result_type = ops[0];
6807 		uint32_t id = ops[1];
6808 		auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);
6809 
6810 		if (combined)
6811 		{
6812 			auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
6813 			auto *var = maybe_get_backing_variable(combined->image);
6814 			if (var)
6815 				e.loaded_from = var->self;
6816 		}
6817 		else
6818 		{
6819 			auto *var = maybe_get_backing_variable(ops[2]);
6820 			SPIRExpression *e;
6821 			if (var && has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler))
6822 				e = &emit_op(result_type, id, join(to_expression(ops[2]), ".plane0"), true, true);
6823 			else
6824 				e = &emit_op(result_type, id, to_expression(ops[2]), true, true);
6825 			if (var)
6826 				e->loaded_from = var->self;
6827 		}
6828 		break;
6829 	}
6830 
6831 	// Casting
6832 	case OpQuantizeToF16:
6833 	{
6834 		uint32_t result_type = ops[0];
6835 		uint32_t id = ops[1];
6836 		uint32_t arg = ops[2];
6837 
6838 		string exp;
6839 		auto &type = get<SPIRType>(result_type);
6840 
6841 		switch (type.vecsize)
6842 		{
6843 		case 1:
6844 			exp = join("float(half(", to_expression(arg), "))");
6845 			break;
6846 		case 2:
6847 			exp = join("float2(half2(", to_expression(arg), "))");
6848 			break;
6849 		case 3:
6850 			exp = join("float3(half3(", to_expression(arg), "))");
6851 			break;
6852 		case 4:
6853 			exp = join("float4(half4(", to_expression(arg), "))");
6854 			break;
6855 		default:
6856 			SPIRV_CROSS_THROW("Illegal argument to OpQuantizeToF16.");
6857 		}
6858 
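		// The float -> half -> float round trip performs the quantization;
		// e.g. a float4 argument "v" becomes "float4(half4(v))".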
6859 		emit_op(result_type, id, exp, should_forward(arg));
6860 		break;
6861 	}
6862 
6863 	case OpInBoundsAccessChain:
6864 	case OpAccessChain:
6865 	case OpPtrAccessChain:
6866 		if (is_tessellation_shader())
6867 		{
6868 			if (!emit_tessellation_access_chain(ops, instruction.length))
6869 				CompilerGLSL::emit_instruction(instruction);
6870 		}
6871 		else
6872 			CompilerGLSL::emit_instruction(instruction);
6873 		break;
6874 
6875 	case OpStore:
6876 		if (is_out_of_bounds_tessellation_level(ops[0]))
6877 			break;
6878 
6879 		if (maybe_emit_array_assignment(ops[0], ops[1]))
6880 			break;
6881 
6882 		CompilerGLSL::emit_instruction(instruction);
6883 		break;
6884 
6885 	// Compute barriers
6886 	case OpMemoryBarrier:
6887 		emit_barrier(0, ops[0], ops[1]);
6888 		break;
6889 
6890 	case OpControlBarrier:
6891 		// In GLSL a memory barrier is often followed by a control barrier.
6892 		// But in MSL, memory barriers are also control barriers, so don't
6893 		// emit a simple control barrier if a memory barrier has just been emitted.
6894 		if (previous_instruction_opcode != OpMemoryBarrier)
6895 			emit_barrier(ops[0], ops[1], ops[2]);
6896 		break;
6897 
6898 	case OpOuterProduct:
6899 	{
6900 		uint32_t result_type = ops[0];
6901 		uint32_t id = ops[1];
6902 		uint32_t a = ops[2];
6903 		uint32_t b = ops[3];
6904 
6905 		auto &type = get<SPIRType>(result_type);
6906 		string expr = type_to_glsl_constructor(type);
6907 		expr += "(";
6908 		for (uint32_t col = 0; col < type.columns; col++)
6909 		{
6910 			expr += to_enclosed_expression(a);
6911 			expr += " * ";
6912 			expr += to_extract_component_expression(b, col);
6913 			if (col + 1 < type.columns)
6914 				expr += ", ";
6915 		}
6916 		expr += ")";
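		// e.g. the outer product of two float3 vectors becomes
		// "float3x3(a * b.x, a * b.y, a * b.z)".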
6917 		emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
6918 		inherit_expression_dependencies(id, a);
6919 		inherit_expression_dependencies(id, b);
6920 		break;
6921 	}
6922 
6923 	case OpVectorTimesMatrix:
6924 	case OpMatrixTimesVector:
6925 	{
6926 		if (!msl_options.invariant_float_math)
6927 		{
6928 			CompilerGLSL::emit_instruction(instruction);
6929 			break;
6930 		}
6931 
6932 		// If the matrix needs transpose, just flip the multiply order.
6933 		auto *e = maybe_get<SPIRExpression>(ops[opcode == OpMatrixTimesVector ? 2 : 3]);
6934 		if (e && e->need_transpose)
6935 		{
6936 			e->need_transpose = false;
6937 			string expr;
6938 
6939 			if (opcode == OpMatrixTimesVector)
6940 			{
6941 				expr = join("spvFMulVectorMatrix(", to_enclosed_unpacked_expression(ops[3]), ", ",
6942 				            to_unpacked_row_major_matrix_expression(ops[2]), ")");
6943 			}
6944 			else
6945 			{
6946 				expr = join("spvFMulMatrixVector(", to_unpacked_row_major_matrix_expression(ops[3]), ", ",
6947 				            to_enclosed_unpacked_expression(ops[2]), ")");
6948 			}
6949 
6950 			bool forward = should_forward(ops[2]) && should_forward(ops[3]);
6951 			emit_op(ops[0], ops[1], expr, forward);
6952 			e->need_transpose = true;
6953 			inherit_expression_dependencies(ops[1], ops[2]);
6954 			inherit_expression_dependencies(ops[1], ops[3]);
6955 		}
6956 		else
6957 		{
6958 			if (opcode == OpMatrixTimesVector)
6959 				MSL_BFOP(spvFMulMatrixVector);
6960 			else
6961 				MSL_BFOP(spvFMulVectorMatrix);
6962 		}
6963 		break;
6964 	}
6965 
6966 	case OpMatrixTimesMatrix:
6967 	{
6968 		if (!msl_options.invariant_float_math)
6969 		{
6970 			CompilerGLSL::emit_instruction(instruction);
6971 			break;
6972 		}
6973 
6974 		auto *a = maybe_get<SPIRExpression>(ops[2]);
6975 		auto *b = maybe_get<SPIRExpression>(ops[3]);
6976 
6977 		// If both matrices need transpose, we can multiply in flipped order and tag the expression as transposed.
6978 		// a^T * b^T = (b * a)^T.
6979 		if (a && b && a->need_transpose && b->need_transpose)
6980 		{
6981 			a->need_transpose = false;
6982 			b->need_transpose = false;
6983 
6984 			auto expr =
6985 			    join("spvFMulMatrixMatrix(", enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), ", ",
6986 			         enclose_expression(to_unpacked_row_major_matrix_expression(ops[2])), ")");
6987 
6988 			bool forward = should_forward(ops[2]) && should_forward(ops[3]);
6989 			auto &e = emit_op(ops[0], ops[1], expr, forward);
6990 			e.need_transpose = true;
6991 			a->need_transpose = true;
6992 			b->need_transpose = true;
6993 			inherit_expression_dependencies(ops[1], ops[2]);
6994 			inherit_expression_dependencies(ops[1], ops[3]);
6995 		}
6996 		else
6997 			MSL_BFOP(spvFMulMatrixMatrix);
6998 
6999 		break;
7000 	}
7001 
7002 	case OpIAddCarry:
7003 	case OpISubBorrow:
7004 	{
7005 		uint32_t result_type = ops[0];
7006 		uint32_t result_id = ops[1];
7007 		uint32_t op0 = ops[2];
7008 		uint32_t op1 = ops[3];
7009 		auto &type = get<SPIRType>(result_type);
7010 		emit_uninitialized_temporary_expression(result_type, result_id);
7011 
7012 		auto &res_type = get<SPIRType>(type.member_types[1]);
7013 		if (opcode == OpIAddCarry)
7014 		{
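			// Unsigned addition wraps exactly when the result is smaller than either
			// operand, so the carry bit is (sum < max(op0, op1)); recall that MSL's
			// select(x, y, cond) returns y when cond is true.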
7015 			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " + ",
7016 			          to_enclosed_expression(op1), ";");
7017 			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
7018 			          "(1), ", type_to_glsl(res_type), "(0), ", to_expression(result_id), ".", to_member_name(type, 0),
7019 			          " >= max(", to_expression(op0), ", ", to_expression(op1), "));");
7020 		}
7021 		else
7022 		{
7023 			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " - ",
7024 			          to_enclosed_expression(op1), ";");
7025 			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
7026 			          "(1), ", type_to_glsl(res_type), "(0), ", to_enclosed_expression(op0),
7027 			          " >= ", to_enclosed_expression(op1), ");");
7028 		}
7029 		break;
7030 	}
7031 
7032 	case OpUMulExtended:
7033 	case OpSMulExtended:
7034 	{
7035 		uint32_t result_type = ops[0];
7036 		uint32_t result_id = ops[1];
7037 		uint32_t op0 = ops[2];
7038 		uint32_t op1 = ops[3];
7039 		auto &type = get<SPIRType>(result_type);
7040 		emit_uninitialized_temporary_expression(result_type, result_id);
7041 
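		// The low half is the ordinary wrapping product; mulhi() supplies the high half.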
7042 		statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " * ",
7043 		          to_enclosed_expression(op1), ";");
7044 		statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(", to_expression(op0), ", ",
7045 		          to_expression(op1), ");");
7046 		break;
7047 	}
7048 
7049 	case OpArrayLength:
7050 	{
7051 		auto &type = expression_type(ops[2]);
7052 		uint32_t offset = type_struct_member_offset(type, ops[3]);
7053 		uint32_t stride = type_struct_member_array_stride(type, ops[3]);
7054 
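		// The runtime array length is derived from the buffer size reported by the API;
		// for example, a 4096-byte buffer with offset 16 and stride 16 yields 255 elements.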
7055 		auto expr = join("(", to_buffer_size_expression(ops[2]), " - ", offset, ") / ", stride);
7056 		emit_op(ops[0], ops[1], expr, true);
7057 		break;
7058 	}
7059 
7060 	// SPV_INTEL_shader_integer_functions2
7061 	case OpUCountLeadingZerosINTEL:
7062 		MSL_UFOP(clz);
7063 		break;
7064 
7065 	case OpUCountTrailingZerosINTEL:
7066 		MSL_UFOP(ctz);
7067 		break;
7068 
7069 	case OpAbsISubINTEL:
7070 	case OpAbsUSubINTEL:
7071 		MSL_BFOP(absdiff);
7072 		break;
7073 
7074 	case OpIAddSatINTEL:
7075 	case OpUAddSatINTEL:
7076 		MSL_BFOP(addsat);
7077 		break;
7078 
7079 	case OpIAverageINTEL:
7080 	case OpUAverageINTEL:
7081 		MSL_BFOP(hadd);
7082 		break;
7083 
7084 	case OpIAverageRoundedINTEL:
7085 	case OpUAverageRoundedINTEL:
7086 		MSL_BFOP(rhadd);
7087 		break;
7088 
7089 	case OpISubSatINTEL:
7090 	case OpUSubSatINTEL:
7091 		MSL_BFOP(subsat);
7092 		break;
7093 
7094 	case OpIMul32x16INTEL:
7095 	{
7096 		uint32_t result_type = ops[0];
7097 		uint32_t id = ops[1];
7098 		uint32_t a = ops[2], b = ops[3];
7099 		bool forward = should_forward(a) && should_forward(b);
7100 		emit_op(result_type, id, join("int(short(", to_expression(a), ")) * int(short(", to_expression(b), "))"),
7101 		        forward);
7102 		inherit_expression_dependencies(id, a);
7103 		inherit_expression_dependencies(id, b);
7104 		break;
7105 	}
7106 
7107 	case OpUMul32x16INTEL:
7108 	{
7109 		uint32_t result_type = ops[0];
7110 		uint32_t id = ops[1];
7111 		uint32_t a = ops[2], b = ops[3];
7112 		bool forward = should_forward(a) && should_forward(b);
7113 		emit_op(result_type, id, join("uint(ushort(", to_expression(a), ")) * uint(ushort(", to_expression(b), "))"),
7114 		        forward);
7115 		inherit_expression_dependencies(id, a);
7116 		inherit_expression_dependencies(id, b);
7117 		break;
7118 	}
7119 
7120 	case OpIsHelperInvocationEXT:
7121 		if (msl_options.is_ios())
7122 			SPIRV_CROSS_THROW("simd_is_helper_thread() is only supported on macOS.");
7123 		else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
7124 			SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
7125 		emit_op(ops[0], ops[1], "simd_is_helper_thread()", false);
7126 		break;
7127 
7128 	case OpBeginInvocationInterlockEXT:
7129 	case OpEndInvocationInterlockEXT:
7130 		if (!msl_options.supports_msl_version(2, 0))
7131 			SPIRV_CROSS_THROW("Raster order groups require MSL 2.0.");
7132 		break; // Nothing to do in the body
7133 
7134 	default:
7135 		CompilerGLSL::emit_instruction(instruction);
7136 		break;
7137 	}
7138 
7139 	previous_instruction_opcode = opcode;
7140 }
7141 
7142 void CompilerMSL::emit_texture_op(const Instruction &i, bool sparse)
7143 {
7144 	if (sparse)
7145 		SPIRV_CROSS_THROW("Sparse feedback not yet supported in MSL.");
7146 
7147 	if (msl_options.is_ios() && msl_options.ios_use_framebuffer_fetch_subpasses)
7148 	{
7149 		auto *ops = stream(i);
7150 
7151 		uint32_t result_type_id = ops[0];
7152 		uint32_t id = ops[1];
7153 		uint32_t img = ops[2];
7154 
7155 		auto &type = expression_type(img);
7156 		auto &imgtype = get<SPIRType>(type.self);
7157 
7158 		// Use Metal's native frame-buffer fetch API for subpass inputs.
7159 		if (imgtype.image.dim == DimSubpassData)
7160 		{
7161 			// Subpass inputs cannot be invalidated,
7162 			// so just forward the expression directly.
7163 			string expr = to_expression(img);
7164 			emit_op(result_type_id, id, expr, true);
7165 			return;
7166 		}
7167 	}
7168 
7169 	// Fallback to default implementation
7170 	CompilerGLSL::emit_texture_op(i, sparse);
7171 }
7172 
7173 void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem)
7174 {
7175 	if (get_execution_model() != ExecutionModelGLCompute && get_execution_model() != ExecutionModelTessellationControl)
7176 		return;
7177 
7178 	uint32_t exe_scope = id_exe_scope ? evaluate_constant_u32(id_exe_scope) : uint32_t(ScopeInvocation);
7179 	uint32_t mem_scope = id_mem_scope ? evaluate_constant_u32(id_mem_scope) : uint32_t(ScopeInvocation);
7180 	// Use the wider of the two scopes (a smaller value denotes a wider scope).
7181 	exe_scope = min(exe_scope, mem_scope);
7182 
7183 	string bar_stmt;
7184 	if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
7185 		bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
7186 	else
7187 		bar_stmt = "threadgroup_barrier";
7188 	bar_stmt += "(";
7189 
7190 	uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id_mem_sem) : uint32_t(MemorySemanticsMaskNone);
7191 
7192 	// Use the | operator to combine flags if we can.
7193 	if (msl_options.supports_msl_version(1, 2))
7194 	{
7195 		string mem_flags = "";
7196 		// For tesc shaders, this also affects objects in the Output storage class.
7197 		// Since in Metal, these are placed in a device buffer, we have to sync device memory here.
7198 		if (get_execution_model() == ExecutionModelTessellationControl ||
7199 		    (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)))
7200 			mem_flags += "mem_flags::mem_device";
7201 
7202 		// Tessellation control shaders also need threadgroup synchronization for patch processing.
7203 		if (get_execution_model() == ExecutionModelTessellationControl ||
7204 		    (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
7205 		{
7206 			if (!mem_flags.empty())
7207 				mem_flags += " | ";
7208 			mem_flags += "mem_flags::mem_threadgroup";
7209 		}
7210 		if (mem_sem & MemorySemanticsImageMemoryMask)
7211 		{
7212 			if (!mem_flags.empty())
7213 				mem_flags += " | ";
7214 			mem_flags += "mem_flags::mem_texture";
7215 		}
7216 
7217 		if (mem_flags.empty())
7218 			mem_flags = "mem_flags::mem_none";
7219 
7220 		bar_stmt += mem_flags;
7221 	}
7222 	else
7223 	{
7224 		if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) &&
7225 		    (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
7226 			bar_stmt += "mem_flags::mem_device_and_threadgroup";
7227 		else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))
7228 			bar_stmt += "mem_flags::mem_device";
7229 		else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask))
7230 			bar_stmt += "mem_flags::mem_threadgroup";
7231 		else if (mem_sem & MemorySemanticsImageMemoryMask)
7232 			bar_stmt += "mem_flags::mem_texture";
7233 		else
7234 			bar_stmt += "mem_flags::mem_none";
7235 	}
7236 
7237 	bar_stmt += ");";
7238 
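	// For example, a plain GLSL barrier() in a compute shader typically becomes
	// "threadgroup_barrier(mem_flags::mem_threadgroup);" here.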
7239 	statement(bar_stmt);
7240 
7241 	assert(current_emitting_block);
7242 	flush_control_dependent_expressions(current_emitting_block->self);
7243 	flush_all_active_variables();
7244 }
7245 
7246 void CompilerMSL::emit_array_copy(const string &lhs, uint32_t rhs_id, StorageClass lhs_storage,
7247                                   StorageClass rhs_storage)
7248 {
7249 	// Allow Metal to use the array<T> template to make arrays a value type.
7250 	// This, however, cannot be used with the threadgroup address space, so consider the custom array copy a fallback.
7251 	bool lhs_thread = (lhs_storage == StorageClassOutput || lhs_storage == StorageClassFunction ||
7252 	                   lhs_storage == StorageClassGeneric || lhs_storage == StorageClassPrivate);
7253 	bool rhs_thread = (rhs_storage == StorageClassInput || rhs_storage == StorageClassFunction ||
7254 	                   rhs_storage == StorageClassGeneric || rhs_storage == StorageClassPrivate);
7255 
7256 	// If threadgroup storage qualifiers are *not* used, avoid the spvArrayCopy*
7257 	// wrapper functions; otherwise, the spvUnsafeArray<> template cannot be used with that storage qualifier.
7258 	if (lhs_thread && rhs_thread && !using_builtin_array())
7259 	{
7260 		statement(lhs, " = ", to_expression(rhs_id), ";");
7261 	}
7262 	else
7263 	{
7264 		// Assignment from an array initializer is fine.
7265 		auto &type = expression_type(rhs_id);
7266 		auto *var = maybe_get_backing_variable(rhs_id);
7267 
7268 		// Unfortunately, we cannot template on address space in MSL,
7269 		// so explicit address space redirection it is ...
7270 		bool is_constant = false;
7271 		if (ir.ids[rhs_id].get_type() == TypeConstant)
7272 		{
7273 			is_constant = true;
7274 		}
7275 		else if (var && var->remapped_variable && var->statically_assigned &&
7276 		         ir.ids[var->static_expression].get_type() == TypeConstant)
7277 		{
7278 			is_constant = true;
7279 		}
7280 		else if (rhs_storage == StorageClassUniform)
7281 		{
7282 			is_constant = true;
7283 		}
7284 
7285 		// For the case where we have OpLoad triggering an array copy,
7286 		// we cannot easily detect this case ahead of time since it's
7287 		// context dependent. We might have to force a recompile here
7288 		// if this is the only use of array copies in our shader.
7289 		if (type.array.size() > 1)
7290 		{
7291 			if (type.array.size() > kArrayCopyMultidimMax)
7292 				SPIRV_CROSS_THROW("Cannot support this many dimensions for arrays of arrays.");
7293 			auto func = static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + type.array.size());
7294 			add_spv_func_and_recompile(func);
7295 		}
7296 		else
7297 			add_spv_func_and_recompile(SPVFuncImplArrayCopy);
7298 
7299 		const char *tag = nullptr;
7300 		if (lhs_thread && is_constant)
7301 			tag = "FromConstantToStack";
7302 		else if (lhs_storage == StorageClassWorkgroup && is_constant)
7303 			tag = "FromConstantToThreadGroup";
7304 		else if (lhs_thread && rhs_thread)
7305 			tag = "FromStackToStack";
7306 		else if (lhs_storage == StorageClassWorkgroup && rhs_thread)
7307 			tag = "FromStackToThreadGroup";
7308 		else if (lhs_thread && rhs_storage == StorageClassWorkgroup)
7309 			tag = "FromThreadGroupToStack";
7310 		else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassWorkgroup)
7311 			tag = "FromThreadGroupToThreadGroup";
7312 		else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassStorageBuffer)
7313 			tag = "FromDeviceToDevice";
7314 		else if (lhs_storage == StorageClassStorageBuffer && is_constant)
7315 			tag = "FromConstantToDevice";
7316 		else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassWorkgroup)
7317 			tag = "FromThreadGroupToDevice";
7318 		else if (lhs_storage == StorageClassStorageBuffer && rhs_thread)
7319 			tag = "FromStackToDevice";
7320 		else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassStorageBuffer)
7321 			tag = "FromDeviceToThreadGroup";
7322 		else if (lhs_thread && rhs_storage == StorageClassStorageBuffer)
7323 			tag = "FromDeviceToStack";
7324 		else
7325 			SPIRV_CROSS_THROW("Unknown storage class used for copying arrays.");
7326 
7327 		// Pass internal array of spvUnsafeArray<> into wrapper functions
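		// For example, copying a one-dimensional constant array into threadgroup
		// storage emits "spvArrayCopyFromConstantToThreadGroup1(<lhs>, <rhs>);".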
7328 		if (lhs_thread && !msl_options.force_native_arrays)
7329 			statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ".elements, ", to_expression(rhs_id), ");");
7330 		else if (rhs_thread && !msl_options.force_native_arrays)
7331 			statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ".elements);");
7332 		else
7333 			statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ");");
7334 	}
7335 }
7336 
7337 // Since MSL does not allow arrays to be copied via simple variable assignment,
7338 // if the LHS and RHS represent an assignment of an entire array, it must be
7339 // implemented by calling an array copy function.
7340 // Returns whether the struct assignment was emitted.
7341 bool CompilerMSL::maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs)
7342 {
7343 	// We only care about assignments of an entire array
7344 	auto &type = expression_type(id_rhs);
7345 	if (type.array.size() == 0)
7346 		return false;
7347 
7348 	auto *var = maybe_get<SPIRVariable>(id_lhs);
7349 
7350 	// Is this a remapped, static constant? Don't do anything.
7351 	if (var && var->remapped_variable && var->statically_assigned)
7352 		return true;
7353 
7354 	if (ir.ids[id_rhs].get_type() == TypeConstant && var && var->deferred_declaration)
7355 	{
7356 		// Special case, if we end up declaring a variable when assigning the constant array,
7357 		// we can avoid the copy by directly assigning the constant expression.
7358 		// This is likely necessary to be able to use a variable as a true look-up table, as it is unlikely
7359 		// the compiler will be able to optimize the spvArrayCopy() into a constant LUT.
7360 		// After a variable has been declared, we can no longer assign constant arrays in MSL unfortunately.
7361 		statement(to_expression(id_lhs), " = ", constant_expression(get<SPIRConstant>(id_rhs)), ";");
7362 		return true;
7363 	}
7364 
7365 	// Ensure the LHS variable has been declared
7366 	auto *p_v_lhs = maybe_get_backing_variable(id_lhs);
7367 	if (p_v_lhs)
7368 		flush_variable_declaration(p_v_lhs->self);
7369 
7370 	emit_array_copy(to_expression(id_lhs), id_rhs, get_expression_effective_storage_class(id_lhs),
7371 	                get_expression_effective_storage_class(id_rhs));
7372 	register_write(id_lhs);
7373 
7374 	return true;
7375 }
7376 
7377 // Emits one of the atomic functions. In MSL, the atomic functions operate on pointers
7378 void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, uint32_t mem_order_1,
7379                                       uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
7380                                       bool op1_is_pointer, bool op1_is_literal, uint32_t op2)
7381 {
7382 	string exp = string(op) + "(";
7383 
7384 	auto &type = get_pointee_type(expression_type(obj));
7385 	exp += "(";
7386 	auto *var = maybe_get_backing_variable(obj);
7387 	if (!var)
7388 		SPIRV_CROSS_THROW("No backing variable for atomic operation.");
7389 
7390 	// Emulate texture2D atomic operations
7391 	const auto &res_type = get<SPIRType>(var->basetype);
7392 	if (res_type.storage == StorageClassUniformConstant && res_type.basetype == SPIRType::Image)
7393 	{
7394 		exp += "device";
7395 	}
7396 	else
7397 	{
7398 		exp += get_argument_address_space(*var);
7399 	}
7400 
7401 	exp += " atomic_";
7402 	exp += type_to_glsl(type);
7403 	exp += "*)";
7404 
7405 	exp += "&";
7406 	exp += to_enclosed_expression(obj);
7407 
7408 	bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;
7409 
7410 	if (is_atomic_compare_exchange_strong)
7411 	{
7412 		assert(strcmp(op, "atomic_compare_exchange_weak_explicit") == 0);
7413 		assert(op2);
7414 		assert(has_mem_order_2);
7415 		exp += ", &";
7416 		exp += to_name(result_id);
7417 		exp += ", ";
7418 		exp += to_expression(op2);
7419 		exp += ", ";
7420 		exp += get_memory_order(mem_order_1);
7421 		exp += ", ";
7422 		exp += get_memory_order(mem_order_2);
7423 		exp += ")";
7424 
7425 		// MSL only supports the weak atomic compare exchange, so emit a CAS loop here.
7426 		// The MSL function returns false if the atomic write fails OR the comparison test fails,
7427 		// so we must validate that it wasn't the comparison test that failed before continuing
7428 		// the CAS loop, otherwise it will loop infinitely, with the comparison test always failing.
7429 		// The function updates the comparator value from the memory value, so the additional
7430 		// comparison test evaluates the memory value against the expected value.
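		// The emitted loop therefore has the shape:
		//     do { <result> = <expected>; }
		//     while (!atomic_compare_exchange_weak_explicit(<ptr>, &<result>, <desired>,
		//            memory_order_relaxed, memory_order_relaxed) && <result> == <expected>);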
7431 		emit_uninitialized_temporary_expression(result_type, result_id);
7432 		statement("do");
7433 		begin_scope();
7434 		statement(to_name(result_id), " = ", to_expression(op1), ";");
7435 		end_scope_decl(join("while (!", exp, " && ", to_name(result_id), " == ", to_enclosed_expression(op1), ")"));
7436 	}
7437 	else
7438 	{
7439 		assert(strcmp(op, "atomic_compare_exchange_weak_explicit") != 0);
7440 		if (op1)
7441 		{
7442 			if (op1_is_literal)
7443 				exp += join(", ", op1);
7444 			else
7445 				exp += ", " + to_expression(op1);
7446 		}
7447 		if (op2)
7448 			exp += ", " + to_expression(op2);
7449 
7450 		exp += string(", ") + get_memory_order(mem_order_1);
7451 		if (has_mem_order_2)
7452 			exp += string(", ") + get_memory_order(mem_order_2);
7453 
7454 		exp += ")";
7455 
7456 		if (strcmp(op, "atomic_store_explicit") != 0)
7457 			emit_op(result_type, result_id, exp, false);
7458 		else
7459 			statement(exp, ";");
7460 	}
7461 
7462 	flush_all_atomic_capable_variables();
7463 }
7464 
7465 // Metal only supports relaxed memory order for now
7466 const char *CompilerMSL::get_memory_order(uint32_t)
7467 {
7468 	return "memory_order_relaxed";
7469 }
7470 
7471 // Override for MSL-specific extension syntax instructions
7472 void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
7473 {
7474 	auto op = static_cast<GLSLstd450>(eop);
7475 
7476 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
7477 	uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
7478 	auto int_type = to_signed_basetype(integer_width);
7479 	auto uint_type = to_unsigned_basetype(integer_width);
7480 
7481 	switch (op)
7482 	{
7483 	case GLSLstd450Atan2:
7484 		emit_binary_func_op(result_type, id, args[0], args[1], "atan2");
7485 		break;
7486 	case GLSLstd450InverseSqrt:
7487 		emit_unary_func_op(result_type, id, args[0], "rsqrt");
7488 		break;
7489 	case GLSLstd450RoundEven:
7490 		emit_unary_func_op(result_type, id, args[0], "rint");
7491 		break;
7492 
7493 	case GLSLstd450FindILsb:
7494 	{
7495 		// In this template version of findLSB, we return T.
7496 		auto basetype = expression_type(args[0]).basetype;
7497 		emit_unary_func_op_cast(result_type, id, args[0], "spvFindLSB", basetype, basetype);
7498 		break;
7499 	}
7500 
7501 	case GLSLstd450FindSMsb:
7502 		emit_unary_func_op_cast(result_type, id, args[0], "spvFindSMSB", int_type, int_type);
7503 		break;
7504 
7505 	case GLSLstd450FindUMsb:
7506 		emit_unary_func_op_cast(result_type, id, args[0], "spvFindUMSB", uint_type, uint_type);
7507 		break;
7508 
7509 	case GLSLstd450PackSnorm4x8:
7510 		emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm4x8");
7511 		break;
7512 	case GLSLstd450PackUnorm4x8:
7513 		emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm4x8");
7514 		break;
7515 	case GLSLstd450PackSnorm2x16:
7516 		emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm2x16");
7517 		break;
7518 	case GLSLstd450PackUnorm2x16:
7519 		emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm2x16");
7520 		break;
7521 
7522 	case GLSLstd450PackHalf2x16:
7523 	{
7524 		auto expr = join("as_type<uint>(half2(", to_expression(args[0]), "))");
7525 		emit_op(result_type, id, expr, should_forward(args[0]));
7526 		inherit_expression_dependencies(id, args[0]);
7527 		break;
7528 	}
7529 
7530 	case GLSLstd450UnpackSnorm4x8:
7531 		emit_unary_func_op(result_type, id, args[0], "unpack_snorm4x8_to_float");
7532 		break;
7533 	case GLSLstd450UnpackUnorm4x8:
7534 		emit_unary_func_op(result_type, id, args[0], "unpack_unorm4x8_to_float");
7535 		break;
7536 	case GLSLstd450UnpackSnorm2x16:
7537 		emit_unary_func_op(result_type, id, args[0], "unpack_snorm2x16_to_float");
7538 		break;
7539 	case GLSLstd450UnpackUnorm2x16:
7540 		emit_unary_func_op(result_type, id, args[0], "unpack_unorm2x16_to_float");
7541 		break;
7542 
7543 	case GLSLstd450UnpackHalf2x16:
7544 	{
7545 		auto expr = join("float2(as_type<half2>(", to_expression(args[0]), "))");
7546 		emit_op(result_type, id, expr, should_forward(args[0]));
7547 		inherit_expression_dependencies(id, args[0]);
7548 		break;
7549 	}
7550 
7551 	case GLSLstd450PackDouble2x32:
7552 		emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450PackDouble2x32"); // Currently unsupported
7553 		break;
7554 	case GLSLstd450UnpackDouble2x32:
7555 		emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450UnpackDouble2x32"); // Currently unsupported
7556 		break;
7557 
7558 	case GLSLstd450MatrixInverse:
7559 	{
7560 		auto &mat_type = get<SPIRType>(result_type);
7561 		switch (mat_type.columns)
7562 		{
7563 		case 2:
7564 			emit_unary_func_op(result_type, id, args[0], "spvInverse2x2");
7565 			break;
7566 		case 3:
7567 			emit_unary_func_op(result_type, id, args[0], "spvInverse3x3");
7568 			break;
7569 		case 4:
7570 			emit_unary_func_op(result_type, id, args[0], "spvInverse4x4");
7571 			break;
7572 		default:
7573 			break;
7574 		}
7575 		break;
7576 	}
7577 
7578 	case GLSLstd450FMin:
7579 		// If the result type isn't float, don't bother calling the specific
7580 		// precise::/fast:: version. Metal doesn't have those for half and
7581 		// double types.
7582 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
7583 			emit_binary_func_op(result_type, id, args[0], args[1], "min");
7584 		else
7585 			emit_binary_func_op(result_type, id, args[0], args[1], "fast::min");
7586 		break;
7587 
7588 	case GLSLstd450FMax:
7589 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
7590 			emit_binary_func_op(result_type, id, args[0], args[1], "max");
7591 		else
7592 			emit_binary_func_op(result_type, id, args[0], args[1], "fast::max");
7593 		break;
7594 
7595 	case GLSLstd450FClamp:
7596 		// TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
7597 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
7598 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
7599 		else
7600 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "fast::clamp");
7601 		break;
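	// The N* variants below use precise:: rather than fast::, since they must honor
	// NaN semantics that fast math is free to ignore.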
7602 
7603 	case GLSLstd450NMin:
7604 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
7605 			emit_binary_func_op(result_type, id, args[0], args[1], "min");
7606 		else
7607 			emit_binary_func_op(result_type, id, args[0], args[1], "precise::min");
7608 		break;
7609 
7610 	case GLSLstd450NMax:
7611 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
7612 			emit_binary_func_op(result_type, id, args[0], args[1], "max");
7613 		else
7614 			emit_binary_func_op(result_type, id, args[0], args[1], "precise::max");
7615 		break;
7616 
7617 	case GLSLstd450NClamp:
7618 		// TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
7619 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
7620 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
7621 		else
7622 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "precise::clamp");
7623 		break;
7624 
7625 		// TODO:
7626 		//        GLSLstd450InterpolateAtCentroid (centroid_no_perspective qualifier)
7627 		//        GLSLstd450InterpolateAtSample (sample_no_perspective qualifier)
7628 		//        GLSLstd450InterpolateAtOffset
7629 
7630 	case GLSLstd450Distance:
7631 		// MSL does not support scalar versions here.
7632 		if (expression_type(args[0]).vecsize == 1)
7633 		{
7634 			// Equivalent to length(a - b) -> abs(a - b).
7635 			emit_op(result_type, id,
7636 			        join("abs(", to_enclosed_unpacked_expression(args[0]), " - ",
7637 			             to_enclosed_unpacked_expression(args[1]), ")"),
7638 			        should_forward(args[0]) && should_forward(args[1]));
7639 			inherit_expression_dependencies(id, args[0]);
7640 			inherit_expression_dependencies(id, args[1]);
7641 		}
7642 		else
7643 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7644 		break;
7645 
7646 	case GLSLstd450Length:
7647 		// MSL does not support scalar versions here.
7648 		if (expression_type(args[0]).vecsize == 1)
7649 		{
7650 			// Equivalent to abs().
7651 			emit_unary_func_op(result_type, id, args[0], "abs");
7652 		}
7653 		else
7654 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7655 		break;
7656 
7657 	case GLSLstd450Normalize:
7658 		// MSL does not support scalar versions here.
7659 		if (expression_type(args[0]).vecsize == 1)
7660 		{
7661 			// Returns -1 or 1 for valid input, sign() does the job.
7662 			emit_unary_func_op(result_type, id, args[0], "sign");
7663 		}
7664 		else
7665 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7666 		break;
7667 
7668 	case GLSLstd450Reflect:
7669 		if (get<SPIRType>(result_type).vecsize == 1)
7670 			emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
7671 		else
7672 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7673 		break;
7674 
7675 	case GLSLstd450Refract:
7676 		if (get<SPIRType>(result_type).vecsize == 1)
7677 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
7678 		else
7679 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7680 		break;
7681 
7682 	case GLSLstd450FaceForward:
7683 		if (get<SPIRType>(result_type).vecsize == 1)
7684 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvFaceForward");
7685 		else
7686 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7687 		break;
7688 
7689 	case GLSLstd450Modf:
7690 	case GLSLstd450Frexp:
7691 	{
7692 		// Special case. If the variable is a scalar access chain, we cannot use it directly. We have to emit a temporary.
7693 		auto *ptr = maybe_get<SPIRExpression>(args[1]);
7694 		if (ptr && ptr->access_chain && is_scalar(expression_type(args[1])))
7695 		{
7696 			register_call_out_argument(args[1]);
7697 			forced_temporaries.insert(id);
7698 
7699 			// Need to create temporaries and copy over to access chain after.
7700 			// We cannot directly take the reference of a vector swizzle in MSL, even if it's scalar ...
7701 			uint32_t &tmp_id = extra_sub_expressions[id];
7702 			if (!tmp_id)
7703 				tmp_id = ir.increase_bound_by(1);
7704 
7705 			uint32_t tmp_type_id = get_pointee_type_id(ptr->expression_type);
7706 			emit_uninitialized_temporary_expression(tmp_type_id, tmp_id);
7707 			emit_binary_func_op(result_type, id, args[0], tmp_id, eop == GLSLstd450Modf ? "modf" : "frexp");
7708 			statement(to_expression(args[1]), " = ", to_expression(tmp_id), ";");
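			// i.e. we emit "<result> = modf(<x>, <tmp>);" followed by "<chain> = <tmp>;".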
7709 		}
7710 		else
7711 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7712 		break;
7713 	}
7714 
7715 	default:
7716 		CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
7717 		break;
7718 	}
7719 }
7720 
7721 void CompilerMSL::emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type, uint32_t id, uint32_t eop,
7722                                                         const uint32_t *args, uint32_t count)
7723 {
7724 	enum AMDShaderTrinaryMinMax
7725 	{
7726 		FMin3AMD = 1,
7727 		UMin3AMD = 2,
7728 		SMin3AMD = 3,
7729 		FMax3AMD = 4,
7730 		UMax3AMD = 5,
7731 		SMax3AMD = 6,
7732 		FMid3AMD = 7,
7733 		UMid3AMD = 8,
7734 		SMid3AMD = 9
7735 	};
7736 
7737 	if (!msl_options.supports_msl_version(2, 1))
7738 		SPIRV_CROSS_THROW("Trinary min/max functions require MSL 2.1.");
7739 
7740 	auto op = static_cast<AMDShaderTrinaryMinMax>(eop);
7741 
7742 	switch (op)
7743 	{
7744 	case FMid3AMD:
7745 	case UMid3AMD:
7746 	case SMid3AMD:
7747 		emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "median3");
7748 		break;
7749 	default:
7750 		CompilerGLSL::emit_spv_amd_shader_trinary_minmax_op(result_type, id, eop, args, count);
7751 		break;
7752 	}
7753 }
7754 
7755 // Emit a structure declaration for the specified interface variable.
7756 void CompilerMSL::emit_interface_block(uint32_t ib_var_id)
7757 {
7758 	if (ib_var_id)
7759 	{
7760 		auto &ib_var = get<SPIRVariable>(ib_var_id);
7761 		auto &ib_type = get_variable_data_type(ib_var);
7762 		assert(ib_type.basetype == SPIRType::Struct && !ib_type.member_types.empty());
7763 		emit_struct(ib_type);
7764 	}
7765 }
7766 
7767 // Emits the declaration signature of the specified function.
7768 // If this is the entry point function, Metal-specific return value and function arguments are added.
7769 void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &)
7770 {
7771 	if (func.self != ir.default_entry_point)
7772 		add_function_overload(func);
7773 
7774 	local_variable_names = resource_names;
7775 	string decl;
7776 
7777 	processing_entry_point = func.self == ir.default_entry_point;
7778 
7779 	// Metal helper functions must be static force-inline; otherwise they will cause problems when linked together in a single Metallib.
7780 	if (!processing_entry_point)
7781 		statement(force_inline);
7782 
7783 	auto &type = get<SPIRType>(func.return_type);
7784 
7785 	if (!type.array.empty() && msl_options.force_native_arrays)
7786 	{
7787 		// We cannot return native arrays in MSL, so "return" through an out variable.
7788 		decl += "void";
7789 	}
7790 	else
7791 	{
7792 		decl += func_type_decl(type);
7793 	}
7794 
7795 	decl += " ";
7796 	decl += to_name(func.self);
7797 	decl += "(";
7798 
7799 	if (!type.array.empty() && msl_options.force_native_arrays)
7800 	{
7801 		// Fake array returns by writing to an out array instead.
7802 		decl += "thread ";
7803 		decl += type_to_glsl(type);
7804 		decl += " (&SPIRV_Cross_return_value)";
7805 		decl += type_to_array_glsl(type);
7806 		if (!func.arguments.empty())
7807 			decl += ", ";
7808 	}

	if (processing_entry_point)
	{
		if (msl_options.argument_buffers)
			decl += entry_point_args_argument_buffer(!func.arguments.empty());
		else
			decl += entry_point_args_classic(!func.arguments.empty());

		// If entry point function has variables that require early declaration,
		// ensure they each have an empty initializer, creating one if needed.
		// This is done at this late stage because the initialization expression
		// is cleared after each compilation pass.
		for (auto var_id : vars_needing_early_declaration)
		{
			auto &ed_var = get<SPIRVariable>(var_id);
			ID &initializer = ed_var.initializer;
			if (!initializer)
				initializer = ir.increase_bound_by(1);

			// Do not override proper initializers.
			if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression)
				set<SPIRExpression>(ed_var.initializer, "{}", ed_var.basetype, true);
		}
	}

	for (auto &arg : func.arguments)
	{
		uint32_t name_id = arg.id;

		auto *var = maybe_get<SPIRVariable>(arg.id);
		if (var)
		{
			// If we need to modify the name of the variable, make sure we modify the original variable.
			// Our alias is just a shadow variable.
			if (arg.alias_global_variable && var->basevariable)
				name_id = var->basevariable;

			var->parameter = &arg; // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
		}

		add_local_variable_name(name_id);

		decl += argument_decl(arg);

		bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);

		auto &arg_type = get<SPIRType>(arg.type);
		if (arg_type.basetype == SPIRType::SampledImage && !is_dynamic_img_sampler)
		{
			// Manufacture automatic plane args for multiplanar texture
			uint32_t planes = 1;
			if (auto *constexpr_sampler = find_constexpr_sampler(name_id))
				if (constexpr_sampler->ycbcr_conversion_enable)
					planes = constexpr_sampler->planes;
			for (uint32_t i = 1; i < planes; i++)
				decl += join(", ", argument_decl(arg), plane_name_suffix, i);

			// Manufacture automatic sampler arg for SampledImage texture
			if (arg_type.image.dim != DimBuffer)
				decl += join(", thread const ", sampler_type(arg_type), " ", to_sampler_expression(arg.id));
		}

		// Manufacture automatic swizzle arg.
		if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(arg_type) &&
		    !is_dynamic_img_sampler)
		{
			bool arg_is_array = !arg_type.array.empty();
			decl += join(", constant uint", arg_is_array ? "* " : "& ", to_swizzle_expression(arg.id));
		}

		if (buffers_requiring_array_length.count(name_id))
		{
			bool arg_is_array = !arg_type.array.empty();
			decl += join(", constant uint", arg_is_array ? "* " : "& ", to_buffer_size_expression(name_id));
		}

		if (&arg != &func.arguments.back())
			decl += ", ";
	}

	decl += ")";
	statement(decl);
}

static bool needs_chroma_reconstruction(const MSLConstexprSampler *constexpr_sampler)
{
	// For now, only multiplanar images need explicit reconstruction. GBGR and BGRG images
	// use implicit reconstruction.
	return constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && constexpr_sampler->planes > 1;
}

// Returns the texture sampling function string for the specified image and sampling characteristics.
string CompilerMSL::to_function_name(const TextureFunctionNameArguments &args)
{
	VariableID img = args.base.img;
	auto &imgtype = *args.base.imgtype;

	const MSLConstexprSampler *constexpr_sampler = nullptr;
	bool is_dynamic_img_sampler = false;
	if (auto *var = maybe_get_backing_variable(img))
	{
		constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
		is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
	}

	// Special-case gather. We have to alter the component being looked up
	// in the swizzle case.
	if (msl_options.swizzle_texture_samples && args.base.is_gather && !is_dynamic_img_sampler &&
	    (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable))
	{
		add_spv_func_and_recompile(imgtype.image.depth ? SPVFuncImplGatherCompareSwizzle : SPVFuncImplGatherSwizzle);
		return imgtype.image.depth ? "spvGatherCompareSwizzle" : "spvGatherSwizzle";
	}

	auto *combined = maybe_get<SPIRCombinedImageSampler>(img);

	// Texture reference
	string fname;
	if (needs_chroma_reconstruction(constexpr_sampler) && !is_dynamic_img_sampler)
	{
		if (constexpr_sampler->planes != 2 && constexpr_sampler->planes != 3)
			SPIRV_CROSS_THROW("Unhandled number of color image planes!");
		// 444 images aren't downsampled, so we don't need to do linear filtering.
		if (constexpr_sampler->resolution == MSL_FORMAT_RESOLUTION_444 ||
		    constexpr_sampler->chroma_filter == MSL_SAMPLER_FILTER_NEAREST)
		{
			if (constexpr_sampler->planes == 2)
				add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest2Plane);
			else
				add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest3Plane);
			fname = "spvChromaReconstructNearest";
		}
		else // Linear with a downsampled format
		{
			fname = "spvChromaReconstructLinear";
			switch (constexpr_sampler->resolution)
			{
			case MSL_FORMAT_RESOLUTION_444:
				assert(false);
				break; // not reached
			case MSL_FORMAT_RESOLUTION_422:
				switch (constexpr_sampler->x_chroma_offset)
				{
				case MSL_CHROMA_LOCATION_COSITED_EVEN:
					if (constexpr_sampler->planes == 2)
						add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven2Plane);
					else
						add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven3Plane);
					fname += "422CositedEven";
					break;
				case MSL_CHROMA_LOCATION_MIDPOINT:
					if (constexpr_sampler->planes == 2)
						add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint2Plane);
					else
						add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint3Plane);
					fname += "422Midpoint";
					break;
				default:
					SPIRV_CROSS_THROW("Invalid chroma location.");
				}
				break;
			case MSL_FORMAT_RESOLUTION_420:
				fname += "420";
				switch (constexpr_sampler->x_chroma_offset)
				{
				case MSL_CHROMA_LOCATION_COSITED_EVEN:
					switch (constexpr_sampler->y_chroma_offset)
					{
					case MSL_CHROMA_LOCATION_COSITED_EVEN:
						if (constexpr_sampler->planes == 2)
							add_spv_func_and_recompile(
							    SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane);
						else
							add_spv_func_and_recompile(
							    SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane);
						fname += "XCositedEvenYCositedEven";
						break;
					case MSL_CHROMA_LOCATION_MIDPOINT:
						if (constexpr_sampler->planes == 2)
							add_spv_func_and_recompile(
							    SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane);
						else
							add_spv_func_and_recompile(
							    SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane);
						fname += "XCositedEvenYMidpoint";
						break;
					default:
						SPIRV_CROSS_THROW("Invalid Y chroma location.");
					}
					break;
				case MSL_CHROMA_LOCATION_MIDPOINT:
					switch (constexpr_sampler->y_chroma_offset)
					{
					case MSL_CHROMA_LOCATION_COSITED_EVEN:
						if (constexpr_sampler->planes == 2)
							add_spv_func_and_recompile(
							    SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane);
						else
							add_spv_func_and_recompile(
							    SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane);
						fname += "XMidpointYCositedEven";
						break;
					case MSL_CHROMA_LOCATION_MIDPOINT:
						if (constexpr_sampler->planes == 2)
							add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane);
						else
							add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane);
						fname += "XMidpointYMidpoint";
						break;
					default:
						SPIRV_CROSS_THROW("Invalid Y chroma location.");
					}
					break;
				default:
					SPIRV_CROSS_THROW("Invalid X chroma location.");
				}
				break;
			default:
				SPIRV_CROSS_THROW("Invalid format resolution.");
			}
		}
	}
	else
	{
		fname = to_expression(combined ? combined->image : img) + ".";

		// Texture function and sampler
		if (args.base.is_fetch)
			fname += "read";
		else if (args.base.is_gather)
			fname += "gather";
		else
			fname += "sample";

		if (args.has_dref)
			fname += "_compare";
	}

	return fname;
}

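// Wraps an expression in a 32-bit float constructor of the given component count,
// e.g. convert_to_f32("v", 2) yields "float2(v)". Used to promote half (and, defensively,
// double) sampling arguments, since MSL sampling functions take float parameters.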
string CompilerMSL::convert_to_f32(const string &expr, uint32_t components)
{
	SPIRType t;
	t.basetype = SPIRType::Float;
	t.vecsize = components;
	t.columns = 1;
	return join(type_to_glsl_constructor(t), "(", expr, ")");
}

static inline bool sampling_type_needs_f32_conversion(const SPIRType &type)
{
	// Double is not supported to begin with, but it doesn't hurt to check for completeness.
	return type.basetype == SPIRType::Half || type.basetype == SPIRType::Double;
}

// Returns the function args for a texture sampling function for the specified image and sampling characteristics.
string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool *p_forward)
{
	VariableID img = args.base.img;
	auto &imgtype = *args.base.imgtype;
	uint32_t lod = args.lod;
	uint32_t grad_x = args.grad_x;
	uint32_t grad_y = args.grad_y;
	uint32_t bias = args.bias;

	const MSLConstexprSampler *constexpr_sampler = nullptr;
	bool is_dynamic_img_sampler = false;
	if (auto *var = maybe_get_backing_variable(img))
	{
		constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
		is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
	}

	string farg_str;
	bool forward = true;

	if (!is_dynamic_img_sampler)
	{
		// Texture reference (for some cases)
		if (needs_chroma_reconstruction(constexpr_sampler))
		{
			// Multiplanar images need two or three textures.
			farg_str += to_expression(img);
			for (uint32_t i = 1; i < constexpr_sampler->planes; i++)
				farg_str += join(", ", to_expression(img), plane_name_suffix, i);
		}
		else if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
		         msl_options.swizzle_texture_samples && args.base.is_gather)
		{
			auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
			farg_str += to_expression(combined ? combined->image : img);
		}

		// Sampler reference
		if (!args.base.is_fetch)
		{
			if (!farg_str.empty())
				farg_str += ", ";
			farg_str += to_sampler_expression(img);
		}

		if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
		    msl_options.swizzle_texture_samples && args.base.is_gather)
		{
			// Add the swizzle constant from the swizzle buffer.
			farg_str += ", " + to_swizzle_expression(img);
			used_swizzle_buffer = true;
		}

		// Swizzled gather puts the component before the other args, to allow template
		// deduction to work.
		if (args.component && msl_options.swizzle_texture_samples)
		{
			forward = should_forward(args.component);
			farg_str += ", " + to_component_argument(args.component);
		}
	}

	// Texture coordinates
	forward = forward && should_forward(args.coord);
	auto coord_expr = to_enclosed_expression(args.coord);
	auto &coord_type = expression_type(args.coord);
	bool coord_is_fp = type_is_floating_point(coord_type);
	bool is_cube_fetch = false;

	string tex_coords = coord_expr;
	uint32_t alt_coord_component = 0;

	switch (imgtype.image.dim)
	{

	case Dim1D:
		if (coord_type.vecsize > 1)
			tex_coords = enclose_expression(tex_coords) + ".x";

		if (args.base.is_fetch)
			tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
		else if (sampling_type_needs_f32_conversion(coord_type))
			tex_coords = convert_to_f32(tex_coords, 1);

		if (msl_options.texture_1D_as_2D)
		{
			if (args.base.is_fetch)
				tex_coords = "uint2(" + tex_coords + ", 0)";
			else
				tex_coords = "float2(" + tex_coords + ", 0.5)";
		}

		alt_coord_component = 1;
		break;

	case DimBuffer:
		if (coord_type.vecsize > 1)
			tex_coords = enclose_expression(tex_coords) + ".x";

		if (msl_options.texture_buffer_native)
		{
			tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
		}
		else
		{
			// Metal texel buffer textures are 2D, so convert 1D coord to 2D.
			// Support for Metal 2.1's new texture_buffer type.
			if (args.base.is_fetch)
			{
				if (msl_options.texel_buffer_texture_width > 0)
				{
					tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
				}
				else
				{
					tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ", " +
					             to_expression(img) + ")";
				}
			}
		}
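		// Illustrative sketch of the emitted helper (assumed shape, not verbatim output):
		// spvTexelBufferCoord() maps a linear texel index onto the backing 2D texture, roughly
		//   uint2 spvTexelBufferCoord(uint i) { return uint2(i % W, i / W); }
		// where W is either the fixed texel_buffer_texture_width or queried from the texture.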

		alt_coord_component = 1;
		break;

	case DimSubpassData:
		// If we're using Metal's native frame-buffer fetch API for subpass inputs,
		// this path will not be hit.
		tex_coords = "uint2(gl_FragCoord.xy)";
		alt_coord_component = 2;
		break;

	case Dim2D:
		if (coord_type.vecsize > 2)
			tex_coords = enclose_expression(tex_coords) + ".xy";

		if (args.base.is_fetch)
			tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
		else if (sampling_type_needs_f32_conversion(coord_type))
			tex_coords = convert_to_f32(tex_coords, 2);

		alt_coord_component = 2;
		break;

	case Dim3D:
		if (coord_type.vecsize > 3)
			tex_coords = enclose_expression(tex_coords) + ".xyz";

		if (args.base.is_fetch)
			tex_coords = "uint3(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
		else if (sampling_type_needs_f32_conversion(coord_type))
			tex_coords = convert_to_f32(tex_coords, 3);

		alt_coord_component = 3;
		break;

	case DimCube:
		if (args.base.is_fetch)
		{
			is_cube_fetch = true;
			tex_coords += ".xy";
			tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
		}
		else
		{
			if (coord_type.vecsize > 3)
				tex_coords = enclose_expression(tex_coords) + ".xyz";
		}

		if (sampling_type_needs_f32_conversion(coord_type))
			tex_coords = convert_to_f32(tex_coords, 3);

		alt_coord_component = 3;
		break;

	default:
		break;
	}

	if (args.base.is_fetch && args.offset)
	{
		// Fetch offsets must be applied directly to the coordinate.
		forward = forward && should_forward(args.offset);
		auto &type = expression_type(args.offset);
		if (type.basetype != SPIRType::UInt)
			tex_coords += " + " + bitcast_expression(SPIRType::UInt, args.offset);
		else
			tex_coords += " + " + to_enclosed_expression(args.offset);
	}
	else if (args.base.is_fetch && args.coffset)
	{
		// Fetch offsets must be applied directly to the coordinate.
		forward = forward && should_forward(args.coffset);
		auto &type = expression_type(args.coffset);
		if (type.basetype != SPIRType::UInt)
			tex_coords += " + " + bitcast_expression(SPIRType::UInt, args.coffset);
		else
			tex_coords += " + " + to_enclosed_expression(args.coffset);
	}

	// If projection, use alt coord as divisor
	if (args.base.is_proj)
	{
		if (sampling_type_needs_f32_conversion(coord_type))
			tex_coords += " / " + convert_to_f32(to_extract_component_expression(args.coord, alt_coord_component), 1);
		else
			tex_coords += " / " + to_extract_component_expression(args.coord, alt_coord_component);
	}

	if (!farg_str.empty())
		farg_str += ", ";

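	// When emulating a cube array as a 2D array, Metal sees a single array index, so face and
	// layer are packed into one slice index as layer * 6 + face. For sampling, the face comes
	// from spvCubemapTo2DArrayFace(); for fetches, coord.z is assumed to already encode it.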
	if (imgtype.image.dim == DimCube && imgtype.image.arrayed && msl_options.emulate_cube_array)
	{
		farg_str += "spvCubemapTo2DArrayFace(" + tex_coords + ").xy";

		if (is_cube_fetch)
			farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ")";
		else
			farg_str +=
			    ", uint(spvCubemapTo2DArrayFace(" + tex_coords + ").z) + (uint(" +
			    round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
			    ") * 6u)";

		add_spv_func_and_recompile(SPVFuncImplCubemapTo2DArrayFace);
	}
	else
	{
		farg_str += tex_coords;

		// If fetch from cube, add face explicitly
		if (is_cube_fetch)
		{
			// Special case for cube arrays, face and layer are packed in one dimension.
			if (imgtype.image.arrayed)
				farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") % 6u";
			else
				farg_str +=
				    ", uint(" + round_fp_tex_coords(to_extract_component_expression(args.coord, 2), coord_is_fp) + ")";
		}

		// If array, use alt coord
		if (imgtype.image.arrayed)
		{
			// Special case for cube arrays, face and layer are packed in one dimension.
			if (imgtype.image.dim == DimCube && args.base.is_fetch)
			{
				farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") / 6u";
			}
			else
			{
				farg_str +=
				    ", uint(" +
				    round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
				    ")";
				if (imgtype.image.dim == DimSubpassData)
				{
					if (msl_options.multiview)
						farg_str += " + gl_ViewIndex";
					else if (msl_options.arrayed_subpass_input)
						farg_str += " + gl_Layer";
				}
			}
		}
		else if (imgtype.image.dim == DimSubpassData)
		{
			if (msl_options.multiview)
				farg_str += ", gl_ViewIndex";
			else if (msl_options.arrayed_subpass_input)
				farg_str += ", gl_Layer";
		}
	}

	// Depth compare reference value
	if (args.dref)
	{
		forward = forward && should_forward(args.dref);
		farg_str += ", ";

		auto &dref_type = expression_type(args.dref);

		string dref_expr;
		if (args.base.is_proj)
			dref_expr = join(to_enclosed_expression(args.dref), " / ",
			                 to_extract_component_expression(args.coord, alt_coord_component));
		else
			dref_expr = to_expression(args.dref);

		if (sampling_type_needs_f32_conversion(dref_type))
			dref_expr = convert_to_f32(dref_expr, 1);

		farg_str += dref_expr;

		if (msl_options.is_macos() && (grad_x || grad_y))
		{
			// For sample_compare, MSL does not support gradient2d() on all targets (apparently only iOS, per the docs).
			// However, the most common case here is a constant gradient of 0, as that is the only way to express
			// LOD == 0 in GLSL with sampler2DArrayShadow (cascaded shadow mapping).
			// We will detect a compile-time constant 0 value for the gradient and promote that to level(0) on MSL.
			bool constant_zero_x = !grad_x || expression_is_constant_null(grad_x);
			bool constant_zero_y = !grad_y || expression_is_constant_null(grad_y);
			if (constant_zero_x && constant_zero_y)
			{
				lod = 0;
				grad_x = 0;
				grad_y = 0;
				farg_str += ", level(0)";
			}
			else
			{
				SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not "
				                  "supported in MSL on macOS.");
			}
		}

		if (msl_options.is_macos() && bias)
		{
			// Bias is not supported either on macOS with sample_compare.
			// Verify it is compile-time zero, and drop the argument.
			if (expression_is_constant_null(bias))
			{
				bias = 0;
			}
			else
			{
				SPIRV_CROSS_THROW(
				    "Using non-constant 0.0 bias() qualifier for sample_compare. This is not supported in MSL on macOS.");
			}
		}
	}

	// LOD Options
	// Metal does not support LOD for 1D textures.
	if (bias && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
	{
		forward = forward && should_forward(bias);
		farg_str += ", bias(" + to_expression(bias) + ")";
	}

	// Metal does not support LOD for 1D textures.
	if (lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
	{
		forward = forward && should_forward(lod);
		if (args.base.is_fetch)
		{
			farg_str += ", " + to_expression(lod);
		}
		else
		{
			farg_str += ", level(" + to_expression(lod) + ")";
		}
	}
	else if (args.base.is_fetch && !lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D) &&
	         imgtype.image.dim != DimBuffer && !imgtype.image.ms && imgtype.image.sampled != 2)
	{
		// Lod argument is optional in OpImageFetch, but we require a LOD value, pick 0 as the default.
		// Check for sampled type as well, because is_fetch is also used for OpImageRead in MSL.
		farg_str += ", 0";
	}

	// Metal does not support LOD for 1D textures.
	if ((grad_x || grad_y) && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
	{
		forward = forward && should_forward(grad_x);
		forward = forward && should_forward(grad_y);
		string grad_opt;
		switch (imgtype.image.dim)
		{
		case Dim2D:
			grad_opt = "2d";
			break;
		case Dim3D:
			grad_opt = "3d";
			break;
		case DimCube:
			if (imgtype.image.arrayed && msl_options.emulate_cube_array)
				grad_opt = "2d";
			else
				grad_opt = "cube";
			break;
		default:
			grad_opt = "unsupported_gradient_dimension";
			break;
		}
		farg_str += ", gradient" + grad_opt + "(" + to_expression(grad_x) + ", " + to_expression(grad_y) + ")";
	}

	if (args.min_lod)
	{
		if (msl_options.is_macos())
		{
			if (!msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("min_lod_clamp() is only supported in MSL 2.2 and up on macOS.");
		}
		else if (msl_options.is_ios())
			SPIRV_CROSS_THROW("min_lod_clamp() is not supported on iOS.");

		forward = forward && should_forward(args.min_lod);
		farg_str += ", min_lod_clamp(" + to_expression(args.min_lod) + ")";
	}

	// Add offsets
	string offset_expr;
	if (args.coffset && !args.base.is_fetch)
	{
		forward = forward && should_forward(args.coffset);
		offset_expr = to_expression(args.coffset);
	}
	else if (args.offset && !args.base.is_fetch)
	{
		forward = forward && should_forward(args.offset);
		offset_expr = to_expression(args.offset);
	}

	if (!offset_expr.empty())
	{
		switch (imgtype.image.dim)
		{
		case Dim2D:
			if (coord_type.vecsize > 2)
				offset_expr = enclose_expression(offset_expr) + ".xy";

			farg_str += ", " + offset_expr;
			break;

		case Dim3D:
			if (coord_type.vecsize > 3)
				offset_expr = enclose_expression(offset_expr) + ".xyz";

			farg_str += ", " + offset_expr;
			break;

		default:
			break;
		}
	}

	if (args.component)
	{
		// If 2D has gather component, ensure it also has an offset arg
		if (imgtype.image.dim == Dim2D && offset_expr.empty())
			farg_str += ", int2(0)";

		if (!msl_options.swizzle_texture_samples || is_dynamic_img_sampler)
		{
			forward = forward && should_forward(args.component);
			farg_str += ", " + to_component_argument(args.component);
		}
	}

	if (args.sample)
	{
		forward = forward && should_forward(args.sample);
		farg_str += ", ";
		farg_str += to_expression(args.sample);
	}

	*p_forward = forward;

	return farg_str;
}

// If the texture coordinates are floating point, invokes MSL round() function to round them.
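// For example, round_fp_tex_coords("uv.x", true) yields "round(uv.x)"; integer coordinates
// pass through unchanged. This is applied before casting fetch coordinates to uint.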
string CompilerMSL::round_fp_tex_coords(string tex_coords, bool coord_is_fp)
{
	return coord_is_fp ? ("round(" + tex_coords + ")") : tex_coords;
}

// Returns a string to use in an image sampling function argument.
// The ID must be a scalar constant.
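// For example, a constant with value 2 maps to "component::z", as used in MSL gather() calls.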
string CompilerMSL::to_component_argument(uint32_t id)
{
	uint32_t component_index = evaluate_constant_u32(id);
	switch (component_index)
	{
	case 0:
		return "component::x";
	case 1:
		return "component::y";
	case 2:
		return "component::z";
	case 3:
		return "component::w";

	default:
		SPIRV_CROSS_THROW("The value (" + to_string(component_index) + ") of OpConstant ID " + to_string(id) +
		                  " is not a valid Component index, which must be one of 0, 1, 2, or 3.");
		return "component::x";
	}
}

// Establish sampled image as expression object and assign the sampler to it.
void CompilerMSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
{
	set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
}

string CompilerMSL::to_texture_op(const Instruction &i, bool sparse, bool *forward,
                                  SmallVector<uint32_t> &inherited_expressions)
{
	auto *ops = stream(i);
	uint32_t result_type_id = ops[0];
	uint32_t img = ops[2];
	auto &result_type = get<SPIRType>(result_type_id);
	auto op = static_cast<Op>(i.op);
	bool is_gather = (op == OpImageGather || op == OpImageDrefGather);

	// Bypass pointers because we need the real image struct
	auto &type = expression_type(img);
	auto &imgtype = get<SPIRType>(type.self);

	const MSLConstexprSampler *constexpr_sampler = nullptr;
	bool is_dynamic_img_sampler = false;
	if (auto *var = maybe_get_backing_variable(img))
	{
		constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
		is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
	}

	string expr;
	if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
	{
		// If this needs sampler Y'CbCr conversion, we need to do some additional
		// processing.
		switch (constexpr_sampler->ycbcr_model)
		{
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
			// Default
			break;
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
			add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT709);
			expr += "spvConvertYCbCrBT709(";
			break;
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
			add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT601);
			expr += "spvConvertYCbCrBT601(";
			break;
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
			add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT2020);
			expr += "spvConvertYCbCrBT2020(";
			break;
		default:
			SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
		}

		if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
		{
			switch (constexpr_sampler->ycbcr_range)
			{
			case MSL_SAMPLER_YCBCR_RANGE_ITU_FULL:
				add_spv_func_and_recompile(SPVFuncImplExpandITUFullRange);
				expr += "spvExpandITUFullRange(";
				break;
			case MSL_SAMPLER_YCBCR_RANGE_ITU_NARROW:
				add_spv_func_and_recompile(SPVFuncImplExpandITUNarrowRange);
				expr += "spvExpandITUNarrowRange(";
				break;
			default:
				SPIRV_CROSS_THROW("Invalid Y'CbCr range.");
			}
		}
	}
	else if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
	         !is_dynamic_img_sampler)
	{
		add_spv_func_and_recompile(SPVFuncImplTextureSwizzle);
		expr += "spvTextureSwizzle(";
	}

	string inner_expr = CompilerGLSL::to_texture_op(i, sparse, forward, inherited_expressions);

	if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
	{
		if (!constexpr_sampler->swizzle_is_identity())
		{
			static const char swizzle_names[] = "rgba";
			if (!constexpr_sampler->swizzle_has_one_or_zero())
			{
				// If we can, do it inline.
				expr += inner_expr + ".";
				for (uint32_t c = 0; c < 4; c++)
				{
					switch (constexpr_sampler->swizzle[c])
					{
					case MSL_COMPONENT_SWIZZLE_IDENTITY:
						expr += swizzle_names[c];
						break;
					case MSL_COMPONENT_SWIZZLE_R:
					case MSL_COMPONENT_SWIZZLE_G:
					case MSL_COMPONENT_SWIZZLE_B:
					case MSL_COMPONENT_SWIZZLE_A:
						expr += swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
						break;
					default:
						SPIRV_CROSS_THROW("Invalid component swizzle.");
					}
				}
			}
			else
			{
				// Otherwise, we need to emit a temporary and swizzle that.
				uint32_t temp_id = ir.increase_bound_by(1);
				emit_op(result_type_id, temp_id, inner_expr, false);
				for (auto &inherit : inherited_expressions)
					inherit_expression_dependencies(temp_id, inherit);
				inherited_expressions.clear();
				inherited_expressions.push_back(temp_id);

				switch (op)
				{
				case OpImageSampleDrefImplicitLod:
				case OpImageSampleImplicitLod:
				case OpImageSampleProjImplicitLod:
				case OpImageSampleProjDrefImplicitLod:
					register_control_dependent_expression(temp_id);
					break;

				default:
					break;
				}
				expr += type_to_glsl(result_type) + "(";
				for (uint32_t c = 0; c < 4; c++)
				{
					switch (constexpr_sampler->swizzle[c])
					{
					case MSL_COMPONENT_SWIZZLE_IDENTITY:
						expr += to_expression(temp_id) + "." + swizzle_names[c];
						break;
					case MSL_COMPONENT_SWIZZLE_ZERO:
						expr += "0";
						break;
					case MSL_COMPONENT_SWIZZLE_ONE:
						expr += "1";
						break;
					case MSL_COMPONENT_SWIZZLE_R:
					case MSL_COMPONENT_SWIZZLE_G:
					case MSL_COMPONENT_SWIZZLE_B:
					case MSL_COMPONENT_SWIZZLE_A:
						expr += to_expression(temp_id) + "." +
						        swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
						break;
					default:
						SPIRV_CROSS_THROW("Invalid component swizzle.");
					}
					if (c < 3)
						expr += ", ";
				}
				expr += ")";
			}
		}
		else
			expr += inner_expr;
		if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
		{
			expr += join(", ", constexpr_sampler->bpc, ")");
			if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY)
				expr += ")";
		}
	}
	else
	{
		expr += inner_expr;
		if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
		    !is_dynamic_img_sampler)
		{
			// Add the swizzle constant from the swizzle buffer.
			expr += ", " + to_swizzle_expression(img) + ")";
			used_swizzle_buffer = true;
		}
	}

	return expr;
}

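// Maps an MSLComponentSwizzle to the matching spvSwizzle enumerant used in the generated
// shader, e.g. MSL_COMPONENT_SWIZZLE_G becomes "spvSwizzle::green".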
static string create_swizzle(MSLComponentSwizzle swizzle)
{
	switch (swizzle)
	{
	case MSL_COMPONENT_SWIZZLE_IDENTITY:
		return "spvSwizzle::none";
	case MSL_COMPONENT_SWIZZLE_ZERO:
		return "spvSwizzle::zero";
	case MSL_COMPONENT_SWIZZLE_ONE:
		return "spvSwizzle::one";
	case MSL_COMPONENT_SWIZZLE_R:
		return "spvSwizzle::red";
	case MSL_COMPONENT_SWIZZLE_G:
		return "spvSwizzle::green";
	case MSL_COMPONENT_SWIZZLE_B:
		return "spvSwizzle::blue";
	case MSL_COMPONENT_SWIZZLE_A:
		return "spvSwizzle::alpha";
	default:
		SPIRV_CROSS_THROW("Invalid component swizzle.");
		return "";
	}
}

// Returns a string representation of the ID, usable as a function arg.
// Manufacture automatic sampler arg for SampledImage texture.
string CompilerMSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
{
	string arg_str;

	auto &type = expression_type(id);
	bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
	// If the argument *itself* is a "dynamic" combined-image sampler, then we can just pass that around.
	bool arg_is_dynamic_img_sampler = has_extended_decoration(id, SPIRVCrossDecorationDynamicImageSampler);
	if (is_dynamic_img_sampler && !arg_is_dynamic_img_sampler)
		arg_str = join("spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">(");

	auto *c = maybe_get<SPIRConstant>(id);
	if (msl_options.force_native_arrays && c && !get<SPIRType>(c->constant_type).array.empty())
	{
		// If we are passing a constant array directly to a function for some reason,
		// the callee will expect an argument in thread const address space
		// (since we can only bind to arrays with references in MSL).
		// To resolve this, we must emit a copy in this address space.
		// This kind of code gen should be rare enough that performance is not a real concern.
		// Inline the SPIR-V to avoid this kind of suboptimal codegen.
		//
		// We risk calling this inside a continue block (invalid code),
		// so just create a thread local copy in the current function.
		arg_str = join("_", id, "_array_copy");
		auto &constants = current_function->constant_arrays_needed_on_stack;
		auto itr = find(begin(constants), end(constants), ID(id));
		if (itr == end(constants))
		{
			force_recompile();
			constants.push_back(id);
		}
	}
	else
		arg_str += CompilerGLSL::to_func_call_arg(arg, id);

	// Need to check the base variable in case we need to apply a qualified alias.
	uint32_t var_id = 0;
	auto *var = maybe_get<SPIRVariable>(id);
	if (var)
		var_id = var->basevariable;

	if (!arg_is_dynamic_img_sampler)
	{
		auto *constexpr_sampler = find_constexpr_sampler(var_id ? var_id : id);
		if (type.basetype == SPIRType::SampledImage)
		{
			// Manufacture automatic plane args for multiplanar texture
			uint32_t planes = 1;
			if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
			{
				planes = constexpr_sampler->planes;
				// If this parameter isn't aliasing a global, then we need to use
				// the special "dynamic image-sampler" class to pass it--and we need
				// to use it for *every* non-alias parameter, in case a combined
				// image-sampler with a Y'CbCr conversion is passed. Hopefully, this
				// pathological case is so rare that it should never be hit in practice.
				if (!arg.alias_global_variable)
					add_spv_func_and_recompile(SPVFuncImplDynamicImageSampler);
			}
			for (uint32_t i = 1; i < planes; i++)
				arg_str += join(", ", CompilerGLSL::to_func_call_arg(arg, id), plane_name_suffix, i);
			// Manufacture automatic sampler arg if the arg is a SampledImage texture.
			if (type.image.dim != DimBuffer)
				arg_str += ", " + to_sampler_expression(var_id ? var_id : id);

			// Add sampler Y'CbCr conversion info if we have it
			if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
			{
				SmallVector<string> samp_args;

				switch (constexpr_sampler->resolution)
				{
				case MSL_FORMAT_RESOLUTION_444:
					// Default
					break;
				case MSL_FORMAT_RESOLUTION_422:
					samp_args.push_back("spvFormatResolution::_422");
					break;
				case MSL_FORMAT_RESOLUTION_420:
					samp_args.push_back("spvFormatResolution::_420");
					break;
				default:
					SPIRV_CROSS_THROW("Invalid format resolution.");
				}

				if (constexpr_sampler->chroma_filter != MSL_SAMPLER_FILTER_NEAREST)
					samp_args.push_back("spvChromaFilter::linear");

				if (constexpr_sampler->x_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
					samp_args.push_back("spvXChromaLocation::midpoint");
				if (constexpr_sampler->y_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
					samp_args.push_back("spvYChromaLocation::midpoint");
				switch (constexpr_sampler->ycbcr_model)
				{
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
					// Default
					break;
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
					samp_args.push_back("spvYCbCrModelConversion::ycbcr_identity");
					break;
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
					samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_709");
					break;
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
					samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_601");
					break;
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
					samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_2020");
					break;
				default:
					SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
				}
				if (constexpr_sampler->ycbcr_range != MSL_SAMPLER_YCBCR_RANGE_ITU_FULL)
					samp_args.push_back("spvYCbCrRange::itu_narrow");
				samp_args.push_back(join("spvComponentBits(", constexpr_sampler->bpc, ")"));
				arg_str += join(", spvYCbCrSampler(", merge(samp_args), ")");
			}
		}

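		// The constexpr sampler's component swizzle is packed into a single uint, one byte per
		// component: bits 0-7 hold component 0's spvSwizzle value, bits 8-15 component 1's,
		// bits 16-23 component 2's, and bits 24-31 component 3's.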
		if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
			arg_str += join(", (uint(", create_swizzle(constexpr_sampler->swizzle[3]), ") << 24) | (uint(",
			                create_swizzle(constexpr_sampler->swizzle[2]), ") << 16) | (uint(",
			                create_swizzle(constexpr_sampler->swizzle[1]), ") << 8) | uint(",
			                create_swizzle(constexpr_sampler->swizzle[0]), ")");
		else if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
			arg_str += ", " + to_swizzle_expression(var_id ? var_id : id);

		if (buffers_requiring_array_length.count(var_id))
			arg_str += ", " + to_buffer_size_expression(var_id ? var_id : id);

		if (is_dynamic_img_sampler)
			arg_str += ")";
	}

	// Emulate texture2D atomic operations
	auto *backing_var = maybe_get_backing_variable(var_id);
	if (backing_var && atomic_image_vars.count(backing_var->self))
	{
		arg_str += ", " + to_expression(var_id) + "_atomic";
	}

	return arg_str;
}

// If the ID represents a sampled image that has been assigned a sampler already,
// generate an expression for the sampler, otherwise generate a fake sampler name
// by appending a suffix to the expression constructed from the ID.
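// For example, a texture named "tex" with no assigned sampler yields "texSmplr" (assuming
// the default sampler_name_suffix, "Smplr"), and an array access "tex[1]" yields "texSmplr[1]".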
string CompilerMSL::to_sampler_expression(uint32_t id)
{
	auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
	auto expr = to_expression(combined ? combined->image : VariableID(id));
	auto index = expr.find_first_of('[');

	uint32_t samp_id = 0;
	if (combined)
		samp_id = combined->sampler;

	if (index == string::npos)
		return samp_id ? to_expression(samp_id) : expr + sampler_name_suffix;
	else
	{
		auto image_expr = expr.substr(0, index);
		auto array_expr = expr.substr(index);
		return samp_id ? to_expression(samp_id) : (image_expr + sampler_name_suffix + array_expr);
	}
}

string CompilerMSL::to_swizzle_expression(uint32_t id)
{
	auto *combined = maybe_get<SPIRCombinedImageSampler>(id);

	auto expr = to_expression(combined ? combined->image : VariableID(id));
	auto index = expr.find_first_of('[');

	// If an image is part of an argument buffer translate this to a legal identifier.
	for (auto &c : expr)
		if (c == '.')
			c = '_';

	if (index == string::npos)
		return expr + swizzle_name_suffix;
	else
	{
		auto image_expr = expr.substr(0, index);
		auto array_expr = expr.substr(index);
		return image_expr + swizzle_name_suffix + array_expr;
	}
}

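// Returns the name of the variable carrying the runtime length of the buffer with the given
// ID, e.g. a buffer "ssbo" maps to "ssboBufferSize" (assuming the default
// buffer_size_name_suffix, "BufferSize"); argument-buffer member accesses like
// "spvDescriptorSet0.ssbo" are first rewritten to the legal identifier "spvDescriptorSet0_ssbo".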
string CompilerMSL::to_buffer_size_expression(uint32_t id)
{
	auto expr = to_expression(id);
	auto index = expr.find_first_of('[');

	// This is quite crude, but we need to translate the reference name (*spvDescriptorSetN.name) to
	// the pointer expression spvDescriptorSetN.name to make a reasonable expression here.
	// This only happens if we have argument buffers and we are using OpArrayLength on a lone SSBO in that set.
	if (expr.size() >= 3 && expr[0] == '(' && expr[1] == '*')
		expr = address_of_expression(expr);

	// If a buffer is part of an argument buffer translate this to a legal identifier.
	for (auto &c : expr)
		if (c == '.')
			c = '_';

	if (index == string::npos)
		return expr + buffer_size_name_suffix;
	else
	{
		auto buffer_expr = expr.substr(0, index);
		auto array_expr = expr.substr(index);
		return buffer_expr + buffer_size_name_suffix + array_expr;
	}
}

// Checks whether the type is a Block all of whose members have DecorationPatch.
bool CompilerMSL::is_patch_block(const SPIRType &type)
{
	if (!has_decoration(type.self, DecorationBlock))
		return false;

	for (uint32_t i = 0; i < type.member_types.size(); i++)
	{
		if (!has_member_decoration(type.self, i, DecorationPatch))
			return false;
	}

	return true;
}

// Checks whether the ID is a row_major matrix that requires conversion before use
bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id)
{
	auto *e = maybe_get<SPIRExpression>(id);
	if (e)
		return e->need_transpose;
	else
		return has_decoration(id, DecorationRowMajor);
}

// Checks whether the member is a row_major matrix that requires conversion before use
bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index)
{
	return has_member_decoration(type.self, index, DecorationRowMajor);
}

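// Converts a row-major matrix expression for use in column-major contexts. For actual matrix
// types this unpacks the expression if needed and wraps it in a transpose, so an expression
// like "(m)" becomes "transpose(m)"; non-matrix types defer to the GLSL implementation.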
string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type, uint32_t physical_type_id,
                                             bool is_packed)
{
	if (!is_matrix(exp_type))
	{
		return CompilerGLSL::convert_row_major_matrix(move(exp_str), exp_type, physical_type_id, is_packed);
	}
	else
	{
		strip_enclosed_expression(exp_str);
		if (physical_type_id != 0 || is_packed)
			exp_str = unpack_expression_type(exp_str, exp_type, physical_type_id, is_packed, true);
		return join("transpose(", exp_str, ")");
	}
}

// Called automatically at the end of the entry point function
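// to emit any requested coordinate-space fixups. The clip-space adjustment remaps z from
// OpenGL's [-w, w] to Metal's [0, w] via z' = (z + w) * 0.5; the Y fixup simply negates y.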
void CompilerMSL::emit_fixup()
{
	if ((get_execution_model() == ExecutionModelVertex ||
	     get_execution_model() == ExecutionModelTessellationEvaluation) &&
	    stage_out_var_id && !qual_pos_var_name.empty() && !capture_output_to_buffer)
	{
		if (options.vertex.fixup_clipspace)
			statement(qual_pos_var_name, ".z = (", qual_pos_var_name, ".z + ", qual_pos_var_name,
			          ".w) * 0.5;       // Adjust clip-space for Metal");

		if (options.vertex.flip_vert_y)
			statement(qual_pos_var_name, ".y = -(", qual_pos_var_name, ".y);", "    // Invert Y-axis for Metal");
	}
}

// Return a string defining a structure member, with padding and packing.
string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
                                     const string &qualifier)
{
	if (member_is_remapped_physical_type(type, index))
		member_type_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID);
	auto &physical_type = get<SPIRType>(member_type_id);

	// If this member is packed, mark it as such.
	string pack_pfx;

	// Allow Metal to use the array<T> template to make arrays a value type
	uint32_t orig_id = 0;
	if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID))
		orig_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID);

	bool row_major = false;
	if (is_matrix(physical_type))
		row_major = has_member_decoration(type.self, index, DecorationRowMajor);

	SPIRType row_major_physical_type;
	const SPIRType *declared_type = &physical_type;

	// If a struct is being declared with physical layout,
	// do not use array<T> wrappers.
	// This avoids a lot of complicated cases with packed vectors and matrices,
	// and generally we cannot copy full arrays in and out of buffers into Function
	// address space.
	// Array of resources should also be declared as builtin arrays.
	if (has_member_decoration(type.self, index, DecorationOffset))
		is_using_builtin_array = true;
	else if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
		is_using_builtin_array = true;

	if (member_is_packed_physical_type(type, index))
	{
		// If we're packing a matrix, output an appropriate typedef
		if (physical_type.basetype == SPIRType::Struct)
		{
			SPIRV_CROSS_THROW("Cannot emit a packed struct currently.");
		}
		else if (is_matrix(physical_type))
		{
			uint32_t rows = physical_type.vecsize;
			uint32_t cols = physical_type.columns;
			pack_pfx = "packed_";
			if (row_major)
			{
				// These are stored transposed.
				rows = physical_type.columns;
				cols = physical_type.vecsize;
				pack_pfx = "packed_rm_";
			}
			string base_type = physical_type.width == 16 ? "half" : "float";
			string td_line = "typedef ";
			td_line += "packed_" + base_type + to_string(rows);
			td_line += " " + pack_pfx;
			// Use the actual matrix size here.
			td_line += base_type + to_string(physical_type.columns) + "x" + to_string(physical_type.vecsize);
			td_line += "[" + to_string(cols) + "]";
			td_line += ";";
			add_typedef_line(td_line);
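			// For example, a packed column-major float3x3 produces
			//   typedef packed_float3 packed_float3x3[3];
			// while a row-major one produces
			//   typedef packed_float3 packed_rm_float3x3[3];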
9106 		}
9107 		else if (!is_scalar(physical_type)) // scalar type is already packed.
9108 			pack_pfx = "packed_";
9109 	}
9110 	else if (row_major)
9111 	{
9112 		// Need to declare type with flipped vecsize/columns.
9113 		row_major_physical_type = physical_type;
9114 		swap(row_major_physical_type.vecsize, row_major_physical_type.columns);
9115 		declared_type = &row_major_physical_type;
9116 	}
9117 
9118 	// Very specifically, image load-store in argument buffers are disallowed on MSL on iOS.
9119 	if (msl_options.is_ios() && physical_type.basetype == SPIRType::Image && physical_type.image.sampled == 2)
9120 	{
9121 		if (!has_decoration(orig_id, DecorationNonWritable))
9122 			SPIRV_CROSS_THROW("Writable images are not allowed in argument buffers on iOS.");
9123 	}
9124 
9125 	// Array information is baked into these types.
9126 	string array_type;
9127 	if (physical_type.basetype != SPIRType::Image && physical_type.basetype != SPIRType::Sampler &&
9128 	    physical_type.basetype != SPIRType::SampledImage)
9129 	{
9130 		BuiltIn builtin = BuiltInMax;
9131 		if (is_member_builtin(type, index, &builtin))
9132 			is_using_builtin_array = true;
9133 		array_type = type_to_array_glsl(physical_type);
9134 	}
9135 
9136 	auto result = join(pack_pfx, type_to_glsl(*declared_type, orig_id), " ", qualifier, to_member_name(type, index),
9137 	                   member_attribute_qualifier(type, index), array_type, ";");
9138 
9139 	is_using_builtin_array = false;
9140 	return result;
9141 }
9142 
9143 // Emit a structure member, padding and packing to maintain the correct memeber alignments.
emit_struct_member(const SPIRType & type,uint32_t member_type_id,uint32_t index,const string & qualifier,uint32_t)9144 void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
9145                                      const string &qualifier, uint32_t)
9146 {
9147 	// If this member requires padding to maintain its declared offset, emit a dummy padding member before it.
9148 	if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget))
9149 	{
9150 		uint32_t pad_len = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget);
9151 		statement("char _m", index, "_pad", "[", pad_len, "];");
9152 	}
9153 
9154 	// Handle HLSL-style 0-based vertex/instance index.
9155 	builtin_declaration = true;
9156 	statement(to_struct_member(type, member_type_id, index, qualifier));
9157 	builtin_declaration = false;
9158 }
9159 
emit_struct_padding_target(const SPIRType & type)9160 void CompilerMSL::emit_struct_padding_target(const SPIRType &type)
9161 {
9162 	uint32_t struct_size = get_declared_struct_size_msl(type, true, true);
9163 	uint32_t target_size = get_extended_decoration(type.self, SPIRVCrossDecorationPaddingTarget);
9164 	if (target_size < struct_size)
9165 		SPIRV_CROSS_THROW("Cannot pad with negative bytes.");
9166 	else if (target_size > struct_size)
9167 		statement("char _m0_final_padding[", target_size - struct_size, "];");
9168 }
9169 
9170 // Return a MSL qualifier for the specified function attribute member
member_attribute_qualifier(const SPIRType & type,uint32_t index)9171 string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t index)
9172 {
9173 	auto &execution = get_entry_point();
9174 
9175 	uint32_t mbr_type_id = type.member_types[index];
9176 	auto &mbr_type = get<SPIRType>(mbr_type_id);
9177 
9178 	BuiltIn builtin = BuiltInMax;
9179 	bool is_builtin = is_member_builtin(type, index, &builtin);
9180 
9181 	if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
9182 	{
9183 		string quals = join(
9184 		    " [[id(", get_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary), ")");
9185 		if (interlocked_resources.count(
9186 		        get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID)))
9187 			quals += ", raster_order_group(0)";
9188 		quals += "]]";
9189 		return quals;
9190 	}
9191 
9192 	// Vertex function inputs
9193 	if (execution.model == ExecutionModelVertex && type.storage == StorageClassInput)
9194 	{
9195 		if (is_builtin)
9196 		{
9197 			switch (builtin)
9198 			{
9199 			case BuiltInVertexId:
9200 			case BuiltInVertexIndex:
9201 			case BuiltInBaseVertex:
9202 			case BuiltInInstanceId:
9203 			case BuiltInInstanceIndex:
9204 			case BuiltInBaseInstance:
9205 				if (msl_options.vertex_for_tessellation)
9206 					return "";
9207 				return string(" [[") + builtin_qualifier(builtin) + "]]";
9208 
9209 			case BuiltInDrawIndex:
9210 				SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
9211 
9212 			default:
9213 				return "";
9214 			}
9215 		}
9216 		uint32_t locn = get_ordered_member_location(type.self, index);
9217 		if (locn != k_unknown_location)
9218 			return string(" [[attribute(") + convert_to_string(locn) + ")]]";
9219 	}
9220 
9221 	// Vertex and tessellation evaluation function outputs
9222 	if (((execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation) ||
9223 	     execution.model == ExecutionModelTessellationEvaluation) &&
9224 	    type.storage == StorageClassOutput)
9225 	{
9226 		if (is_builtin)
9227 		{
9228 			switch (builtin)
9229 			{
9230 			case BuiltInPointSize:
9231 				// Only mark the PointSize builtin if really rendering points.
9232 				// Some shaders may include a PointSize builtin even when used to render
9233 				// non-point topologies, and Metal will reject this builtin when compiling
9234 				// the shader into a render pipeline that uses a non-point topology.
9235 				return msl_options.enable_point_size_builtin ? (string(" [[") + builtin_qualifier(builtin) + "]]") : "";
9236 
9237 			case BuiltInViewportIndex:
9238 				if (!msl_options.supports_msl_version(2, 0))
9239 					SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
9240 				/* fallthrough */
9241 			case BuiltInPosition:
9242 			case BuiltInLayer:
9243 				return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
9244 
9245 			case BuiltInClipDistance:
9246 				if (has_member_decoration(type.self, index, DecorationLocation))
9247 					return join(" [[user(clip", get_member_decoration(type.self, index, DecorationLocation), ")]]");
9248 				else
9249 					return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
9250 
9251 			default:
9252 				return "";
9253 			}
9254 		}
		uint32_t comp;
		uint32_t locn = get_ordered_member_location(type.self, index, &comp);
		if (locn != k_unknown_location)
		{
			if (comp != k_unknown_component)
				return string(" [[user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")]]";
			else
				return string(" [[user(locn") + convert_to_string(locn) + ")]]";
		}
	}

	// Tessellation control function inputs
	if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassInput)
	{
		if (is_builtin)
		{
			switch (builtin)
			{
			case BuiltInInvocationId:
			case BuiltInPrimitiveId:
				if (msl_options.multi_patch_workgroup)
					return "";
				/* fallthrough */
			case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
			case BuiltInSubgroupSize: // FIXME: Should work in any stage
				return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
			case BuiltInPatchVertices:
				return "";
			// Others come from stage input.
			default:
				break;
			}
		}
		if (msl_options.multi_patch_workgroup)
			return "";
		uint32_t locn = get_ordered_member_location(type.self, index);
		if (locn != k_unknown_location)
			return string(" [[attribute(") + convert_to_string(locn) + ")]]";
	}

	// Tessellation control function outputs
	if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassOutput)
	{
		// For this type of shader, we always arrange for it to capture its
		// output to a buffer. For this reason, qualifiers are irrelevant here.
		return "";
	}

	// Tessellation evaluation function inputs
	if (execution.model == ExecutionModelTessellationEvaluation && type.storage == StorageClassInput)
	{
		if (is_builtin)
		{
			switch (builtin)
			{
			case BuiltInPrimitiveId:
			case BuiltInTessCoord:
				return string(" [[") + builtin_qualifier(builtin) + "]]";
			case BuiltInPatchVertices:
				return "";
			// Others come from stage input.
			default:
				break;
			}
		}
		// The special control point array must not be marked with an attribute.
		if (get_type(type.member_types[index]).basetype == SPIRType::ControlPointArray)
			return "";
		uint32_t locn = get_ordered_member_location(type.self, index);
		if (locn != k_unknown_location)
			return string(" [[attribute(") + convert_to_string(locn) + ")]]";
	}

	// Tessellation evaluation function outputs were handled above.

	// Fragment function inputs
	if (execution.model == ExecutionModelFragment && type.storage == StorageClassInput)
	{
		string quals;
		if (is_builtin)
		{
			switch (builtin)
			{
			case BuiltInViewIndex:
				if (!msl_options.multiview || !msl_options.multiview_layered_rendering)
					break;
				/* fallthrough */
			case BuiltInFrontFacing:
			case BuiltInPointCoord:
			case BuiltInFragCoord:
			case BuiltInSampleId:
			case BuiltInSampleMask:
			case BuiltInLayer:
			case BuiltInBaryCoordNV:
			case BuiltInBaryCoordNoPerspNV:
				quals = builtin_qualifier(builtin);
				break;

			case BuiltInClipDistance:
				return join(" [[user(clip", get_member_decoration(type.self, index, DecorationLocation), ")]]");

			default:
				break;
			}
		}
		else
		{
			uint32_t comp;
			uint32_t locn = get_ordered_member_location(type.self, index, &comp);
			if (locn != k_unknown_location)
			{
				// For user-defined attributes, this is fine. From the Vulkan spec:
				// A user-defined output variable is considered to match an input variable in the subsequent stage if
				// the two variables are declared with the same Location and Component decoration and match in type
				// and decoration, except that interpolation decorations are not required to match. For the purposes
				// of interface matching, variables declared without a Component decoration are considered to have a
				// Component decoration of zero.

				if (comp != k_unknown_component && comp != 0)
					quals = string("user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")";
				else
					quals = string("user(locn") + convert_to_string(locn) + ")";
			}
		}

		if (builtin == BuiltInBaryCoordNV || builtin == BuiltInBaryCoordNoPerspNV)
		{
			if (has_member_decoration(type.self, index, DecorationFlat) ||
			    has_member_decoration(type.self, index, DecorationCentroid) ||
			    has_member_decoration(type.self, index, DecorationSample) ||
			    has_member_decoration(type.self, index, DecorationNoPerspective))
			{
				// NoPerspective is baked into the builtin type.
				SPIRV_CROSS_THROW(
				    "Flat, Centroid, Sample, NoPerspective decorations are not supported for BaryCoord inputs.");
			}
		}

		// Don't bother decorating integers with the 'flat' attribute; it's
		// the default (in fact, the only option). Also don't bother with the
		// FragCoord builtin; it's always noperspective on Metal.
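		// For reference, the decoration-to-MSL interpolation mapping applied below:
		//   Flat                -> flat
		//   Centroid            -> centroid_perspective (centroid_no_perspective with NoPerspective)
		//   Sample              -> sample_perspective (sample_no_perspective with NoPerspective)
		//   NoPerspective alone -> center_no_perspective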
		if (!type_is_integral(mbr_type) && (!is_builtin || builtin != BuiltInFragCoord))
		{
			if (has_member_decoration(type.self, index, DecorationFlat))
			{
				if (!quals.empty())
					quals += ", ";
				quals += "flat";
			}
			else if (has_member_decoration(type.self, index, DecorationCentroid))
			{
				if (!quals.empty())
					quals += ", ";
				if (has_member_decoration(type.self, index, DecorationNoPerspective))
					quals += "centroid_no_perspective";
				else
					quals += "centroid_perspective";
			}
			else if (has_member_decoration(type.self, index, DecorationSample))
			{
				if (!quals.empty())
					quals += ", ";
				if (has_member_decoration(type.self, index, DecorationNoPerspective))
					quals += "sample_no_perspective";
				else
					quals += "sample_perspective";
			}
			else if (has_member_decoration(type.self, index, DecorationNoPerspective))
			{
				if (!quals.empty())
					quals += ", ";
				quals += "center_no_perspective";
			}
		}

		if (!quals.empty())
			return " [[" + quals + "]]";
	}

	// Fragment function outputs
	if (execution.model == ExecutionModelFragment && type.storage == StorageClassOutput)
	{
		if (is_builtin)
		{
			switch (builtin)
			{
			case BuiltInFragStencilRefEXT:
				// Similar to PointSize, only mark FragStencilRef if there's a stencil buffer.
				// Some shaders may include a FragStencilRef builtin even when used to render
				// without a stencil attachment, and Metal will reject this builtin
				// when compiling the shader into a render pipeline that does not set
				// stencilAttachmentPixelFormat.
				if (!msl_options.enable_frag_stencil_ref_builtin)
					return "";
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Stencil export only supported in MSL 2.1 and up.");
				return string(" [[") + builtin_qualifier(builtin) + "]]";

			case BuiltInFragDepth:
				// Ditto FragDepth.
				if (!msl_options.enable_frag_depth_builtin)
					return "";
				/* fallthrough */
			case BuiltInSampleMask:
				return string(" [[") + builtin_qualifier(builtin) + "]]";

			default:
				return "";
			}
		}
		uint32_t locn = get_ordered_member_location(type.self, index);
		// Metal will likely complain about missing color attachments, too.
		if (locn != k_unknown_location && !(msl_options.enable_frag_output_mask & (1 << locn)))
			return "";
		if (locn != k_unknown_location && has_member_decoration(type.self, index, DecorationIndex))
			return join(" [[color(", locn, "), index(", get_member_decoration(type.self, index, DecorationIndex),
			            ")]]");
		else if (locn != k_unknown_location)
			return join(" [[color(", locn, ")]]");
		else if (has_member_decoration(type.self, index, DecorationIndex))
			return join(" [[index(", get_member_decoration(type.self, index, DecorationIndex), ")]]");
		else
			return "";
	}

	// Compute function inputs
	if (execution.model == ExecutionModelGLCompute && type.storage == StorageClassInput)
	{
		if (is_builtin)
		{
			switch (builtin)
			{
			case BuiltInGlobalInvocationId:
			case BuiltInWorkgroupId:
			case BuiltInNumWorkgroups:
			case BuiltInLocalInvocationId:
			case BuiltInLocalInvocationIndex:
			case BuiltInNumSubgroups:
			case BuiltInSubgroupId:
			case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
			case BuiltInSubgroupSize: // FIXME: Should work in any stage
				return string(" [[") + builtin_qualifier(builtin) + "]]";

			default:
				return "";
			}
		}
	}

	return "";
}

// Returns the location decoration of the member with the specified index in the specified type.
// If the location of the member has been explicitly set, that location is used. If not, this
// function assumes the members are ordered in their location order, and simply returns the
// index as the location.
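// For example, under that assumption, member index 2 of a struct with no explicit
// Location decorations is simply treated as location 2.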
uint32_t CompilerMSL::get_ordered_member_location(uint32_t type_id, uint32_t index, uint32_t *comp)
{
	auto &m = ir.meta[type_id];
	if (index < m.members.size())
	{
		auto &dec = m.members[index];
		if (comp)
		{
			if (dec.decoration_flags.get(DecorationComponent))
				*comp = dec.component;
			else
				*comp = k_unknown_component;
		}
		if (dec.decoration_flags.get(DecorationLocation))
			return dec.location;
	}

	return index;
}

// Returns the type declaration for a function, including the
// entry type if the current function is the entry point function
string CompilerMSL::func_type_decl(SPIRType &type)
{
	// The regular function return type. If not processing the entry point function, that's all we need
	string return_type = type_to_glsl(type) + type_to_array_glsl(type);
	if (!processing_entry_point)
		return return_type;

	// If an outgoing interface block has been defined, and it should be returned, override the entry point return type
	bool ep_should_return_output = !get_is_rasterization_disabled();
	if (stage_out_var_id && ep_should_return_output)
		return_type = type_to_glsl(get_stage_out_struct_type()) + type_to_array_glsl(type);

	// Prepend an entry type, based on the execution model
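	// For example, a fragment shader with early fragment tests yields
	// "[[ early_fragment_tests ]] fragment", while tessellation control shaders
	// are emitted as Metal compute ("kernel") functions.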
	string entry_type;
	auto &execution = get_entry_point();
	switch (execution.model)
	{
	case ExecutionModelVertex:
		if (msl_options.vertex_for_tessellation && !msl_options.supports_msl_version(1, 2))
			SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
		entry_type = msl_options.vertex_for_tessellation ? "kernel" : "vertex";
		break;
	case ExecutionModelTessellationEvaluation:
		if (!msl_options.supports_msl_version(1, 2))
			SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
		if (execution.flags.get(ExecutionModeIsolines))
			SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
		if (msl_options.is_ios())
			entry_type =
			    join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ") ]] vertex");
		else
			entry_type = join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ", ",
			                  execution.output_vertices, ") ]] vertex");
		break;
	case ExecutionModelFragment:
		entry_type = execution.flags.get(ExecutionModeEarlyFragmentTests) ||
		                     execution.flags.get(ExecutionModePostDepthCoverage) ?
		                 "[[ early_fragment_tests ]] fragment" :
		                 "fragment";
		break;
	case ExecutionModelTessellationControl:
		if (!msl_options.supports_msl_version(1, 2))
			SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
		if (execution.flags.get(ExecutionModeIsolines))
			SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
		/* fallthrough */
	case ExecutionModelGLCompute:
	case ExecutionModelKernel:
		entry_type = "kernel";
		break;
	default:
		entry_type = "unknown";
		break;
	}

	return entry_type + " " + return_type;
}

// In MSL, address space qualifiers are required for all pointer or reference variables
string CompilerMSL::get_argument_address_space(const SPIRVariable &argument)
{
	const auto &type = get<SPIRType>(argument.basetype);
	return get_type_address_space(type, argument.self, true);
}

string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument)
{
	// This can be called for variable pointer contexts as well, so be very careful about which method we choose.
	Bitset flags;
	auto *var = maybe_get<SPIRVariable>(id);
	if (var && type.basetype == SPIRType::Struct &&
	    (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
		flags = get_buffer_block_flags(id);
	else
		flags = get_decoration_bitset(id);

	const char *addr_space = nullptr;
	switch (type.storage)
	{
	case StorageClassWorkgroup:
		addr_space = "threadgroup";
		break;

	case StorageClassStorageBuffer:
	{
		// For arguments from variable pointers, we use the write count deduction, so
		// we should not assume any constness here. Only for global SSBOs.
		bool readonly = false;
		if (!var || has_decoration(type.self, DecorationBlock))
			readonly = flags.get(DecorationNonWritable);

		addr_space = readonly ? "const device" : "device";
		break;
	}

	case StorageClassUniform:
	case StorageClassUniformConstant:
	case StorageClassPushConstant:
		if (type.basetype == SPIRType::Struct)
		{
			bool ssbo = has_decoration(type.self, DecorationBufferBlock);
			if (ssbo)
				addr_space = flags.get(DecorationNonWritable) ? "const device" : "device";
			else
				addr_space = "constant";
		}
		else if (!argument)
		{
			addr_space = "constant";
		}
		else if (type_is_msl_framebuffer_fetch(type))
		{
			// Subpass inputs are passed around by value.
			addr_space = "";
		}
		break;

	case StorageClassFunction:
	case StorageClassGeneric:
		break;

	case StorageClassInput:
		if (get_execution_model() == ExecutionModelTessellationControl && var &&
		    var->basevariable == stage_in_ptr_var_id)
			addr_space = msl_options.multi_patch_workgroup ? "constant" : "threadgroup";
		break;

	case StorageClassOutput:
		if (capture_output_to_buffer)
			addr_space = "device";
		break;

	default:
		break;
	}

	if (!addr_space)
		// No address space for plain values.
		addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : "";

	return join(flags.get(DecorationVolatile) || flags.get(DecorationCoherent) ? "volatile " : "", addr_space);
}

const char *CompilerMSL::to_restrict(uint32_t id, bool space)
{
	// This can be called for variable pointer contexts as well, so be very careful about which method we choose.
	Bitset flags;
	if (ir.ids[id].get_type() == TypeVariable)
	{
		uint32_t type_id = expression_type_id(id);
		auto &type = expression_type(id);
		if (type.basetype == SPIRType::Struct &&
		    (has_decoration(type_id, DecorationBlock) || has_decoration(type_id, DecorationBufferBlock)))
			flags = get_buffer_block_flags(id);
		else
			flags = get_decoration_bitset(id);
	}
	else
		flags = get_decoration_bitset(id);

	return flags.get(DecorationRestrict) ? (space ? "restrict " : "restrict") : "";
}

string CompilerMSL::entry_point_arg_stage_in()
{
	string decl;

	if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup)
		return decl;

	// Stage-in structure
	uint32_t stage_in_id;
	if (get_execution_model() == ExecutionModelTessellationEvaluation)
		stage_in_id = patch_stage_in_var_id;
	else
		stage_in_id = stage_in_var_id;

	if (stage_in_id)
	{
		auto &var = get<SPIRVariable>(stage_in_id);
		auto &type = get_variable_data_type(var);

		add_resource_name(var.self);
		decl = join(type_to_glsl(type), " ", to_name(var.self), " [[stage_in]]");
	}

	return decl;
}

// Returns true if this input builtin should be a direct parameter on a shader function parameter list,
// and false for builtins that should be passed or calculated some other way.
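// For example, FragCoord is a direct [[position]] parameter, whereas SamplePosition
// is never passed directly; it is reconstructed in a fixup hook via
// get_sample_position() on the sample ID (see fix_up_shader_inputs_outputs()).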
bool CompilerMSL::is_direct_input_builtin(BuiltIn bi_type)
{
	switch (bi_type)
	{
	// Vertex function in
	case BuiltInVertexId:
	case BuiltInVertexIndex:
	case BuiltInBaseVertex:
	case BuiltInInstanceId:
	case BuiltInInstanceIndex:
	case BuiltInBaseInstance:
		return get_execution_model() != ExecutionModelVertex || !msl_options.vertex_for_tessellation;
	// Tess. control function in
	case BuiltInPosition:
	case BuiltInPointSize:
	case BuiltInClipDistance:
	case BuiltInCullDistance:
	case BuiltInPatchVertices:
		return false;
	case BuiltInInvocationId:
	case BuiltInPrimitiveId:
		return get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup;
	// Tess. evaluation function in
	case BuiltInTessLevelInner:
	case BuiltInTessLevelOuter:
		return false;
	// Fragment function in
	case BuiltInSamplePosition:
	case BuiltInHelperInvocation:
	case BuiltInBaryCoordNV:
	case BuiltInBaryCoordNoPerspNV:
		return false;
	case BuiltInViewIndex:
		return get_execution_model() == ExecutionModelFragment && msl_options.multiview &&
		       msl_options.multiview_layered_rendering;
	// Any stage function in
	case BuiltInDeviceIndex:
	case BuiltInSubgroupEqMask:
	case BuiltInSubgroupGeMask:
	case BuiltInSubgroupGtMask:
	case BuiltInSubgroupLeMask:
	case BuiltInSubgroupLtMask:
		return false;
	case BuiltInSubgroupLocalInvocationId:
	case BuiltInSubgroupSize:
		return get_execution_model() == ExecutionModelGLCompute ||
		       (get_execution_model() == ExecutionModelFragment && msl_options.supports_msl_version(2, 2));
	default:
		return true;
	}
}

void CompilerMSL::entry_point_args_builtin(string &ep_args)
{
	// Builtin variables
	SmallVector<pair<SPIRVariable *, BuiltIn>, 8> active_builtins;
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
		auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));

		// Don't emit SamplePosition as a separate parameter. In the entry
		// point, we get that by calling get_sample_position() on the sample ID.
		if (var.storage == StorageClassInput && is_builtin_variable(var) &&
		    get_variable_data_type(var).basetype != SPIRType::Struct &&
		    get_variable_data_type(var).basetype != SPIRType::ControlPointArray)
		{
			// If the builtin is not part of the active input builtin set, don't emit it.
			// Relevant for multiple entry-point modules which might declare unused builtins.
			if (!active_input_builtins.get(bi_type) || !interface_variable_exists_in_entry_point(var_id))
				return;

			// Remember this variable. We may need to correct its type.
			active_builtins.push_back(make_pair(&var, bi_type));

			if (is_direct_input_builtin(bi_type))
			{
				if (!ep_args.empty())
					ep_args += ", ";

				// Handle HLSL-style 0-based vertex/instance index.
				builtin_declaration = true;
				ep_args += builtin_type_decl(bi_type, var_id) + " " + to_expression(var_id);
				ep_args += " [[" + builtin_qualifier(bi_type);
				if (bi_type == BuiltInSampleMask && get_entry_point().flags.get(ExecutionModePostDepthCoverage))
				{
					if (!msl_options.supports_msl_version(2))
						SPIRV_CROSS_THROW("Post-depth coverage requires Metal 2.0.");
					if (!msl_options.is_ios())
						SPIRV_CROSS_THROW("Post-depth coverage is only supported on iOS.");
					ep_args += ", post_depth_coverage";
				}
				ep_args += "]]";
				builtin_declaration = false;
			}
		}

		if (var.storage == StorageClassInput &&
		    has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase))
		{
			// This is a special implicit builtin, not corresponding to any SPIR-V builtin,
			// which holds the base that was passed to vkCmdDispatchBase() or vkCmdDrawIndexed(). If it's present,
			// assume we emitted it for a good reason.
			assert(msl_options.supports_msl_version(1, 2));
			if (!ep_args.empty())
				ep_args += ", ";

			ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_origin]]";
		}

		if (var.storage == StorageClassInput &&
		    has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize))
		{
			// This is another special implicit builtin, not corresponding to any SPIR-V builtin,
			// which holds the number of vertices and instances to draw. If it's present,
			// assume we emitted it for a good reason.
			assert(msl_options.supports_msl_version(1, 2));
			if (!ep_args.empty())
				ep_args += ", ";

			ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_size]]";
		}
	});

	// Correct the types of all encountered active builtins. We couldn't do this before
	// because ensure_correct_builtin_type() may increase the bound, which isn't allowed
	// while iterating over IDs.
	for (auto &var : active_builtins)
		var.first->basetype = ensure_correct_builtin_type(var.first->basetype, var.second);

	// Handle HLSL-style 0-based vertex/instance index.
	if (needs_base_vertex_arg == TriState::Yes)
		ep_args += built_in_func_arg(BuiltInBaseVertex, !ep_args.empty());

	if (needs_base_instance_arg == TriState::Yes)
		ep_args += built_in_func_arg(BuiltInBaseInstance, !ep_args.empty());

	if (capture_output_to_buffer)
	{
		// Add parameters to hold the indirect draw parameters and the shader output. This has to be handled
		// specially because it needs to be a pointer, not a reference.
		if (stage_out_var_id)
		{
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args += join("device ", type_to_glsl(get_stage_out_struct_type()), "* ", output_buffer_var_name,
			                " [[buffer(", msl_options.shader_output_buffer_index, ")]]");
		}

		if (get_execution_model() == ExecutionModelTessellationControl)
		{
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args +=
			    join("constant uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
		}
		else if (stage_out_var_id &&
		         !(get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
		{
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args +=
			    join("device uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
		}

		if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation &&
		    (active_input_builtins.get(BuiltInVertexIndex) || active_input_builtins.get(BuiltInVertexId)) &&
		    msl_options.vertex_index_type != Options::IndexType::None)
		{
			// Add the index buffer so we can set gl_VertexIndex correctly.
			if (!ep_args.empty())
				ep_args += ", ";
			switch (msl_options.vertex_index_type)
			{
			case Options::IndexType::None:
				break;
			case Options::IndexType::UInt16:
				ep_args += join("const device ushort* ", index_buffer_var_name, " [[buffer(",
				                msl_options.shader_index_buffer_index, ")]]");
				break;
			case Options::IndexType::UInt32:
				ep_args += join("const device uint* ", index_buffer_var_name, " [[buffer(",
				                msl_options.shader_index_buffer_index, ")]]");
				break;
			}
		}

		// Tessellation control shaders get three additional parameters:
		// a buffer to hold the per-patch data, a buffer to hold the per-patch
		// tessellation levels, and a block of workgroup memory to hold the
		// input control point data.
		if (get_execution_model() == ExecutionModelTessellationControl)
		{
			if (patch_stage_out_var_id)
			{
				if (!ep_args.empty())
					ep_args += ", ";
				ep_args +=
				    join("device ", type_to_glsl(get_patch_stage_out_struct_type()), "* ", patch_output_buffer_var_name,
				         " [[buffer(", convert_to_string(msl_options.shader_patch_output_buffer_index), ")]]");
			}
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args += join("device ", get_tess_factor_struct_name(), "* ", tess_factor_buffer_var_name, " [[buffer(",
			                convert_to_string(msl_options.shader_tess_factor_buffer_index), ")]]");
			if (stage_in_var_id)
			{
				if (!ep_args.empty())
					ep_args += ", ";
				if (msl_options.multi_patch_workgroup)
				{
					ep_args += join("device ", type_to_glsl(get_stage_in_struct_type()), "* ", input_buffer_var_name,
					                " [[buffer(", convert_to_string(msl_options.shader_input_buffer_index), ")]]");
				}
				else
				{
					ep_args += join("threadgroup ", type_to_glsl(get_stage_in_struct_type()), "* ", input_wg_var_name,
					                " [[threadgroup(", convert_to_string(msl_options.shader_input_wg_index), ")]]");
				}
			}
		}
	}
}

string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
{
	string ep_args = entry_point_arg_stage_in();
	Bitset claimed_bindings;

	for (uint32_t i = 0; i < kMaxArgumentBuffers; i++)
	{
		uint32_t id = argument_buffer_ids[i];
		if (id == 0)
			continue;

		add_resource_name(id);
		auto &var = get<SPIRVariable>(id);
		auto &type = get_variable_data_type(var);

		if (!ep_args.empty())
			ep_args += ", ";

		// Check if the argument buffer binding itself has been remapped.
		uint32_t buffer_binding;
		auto itr = resource_bindings.find({ get_entry_point().model, i, kArgumentBufferBinding });
		if (itr != end(resource_bindings))
		{
			buffer_binding = itr->second.first.msl_buffer;
			itr->second.second = true;
		}
		else
		{
			// As a fallback, directly map desc set <-> binding.
			// If that was taken, take the next buffer binding.
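			// (For example, descriptor set 1 normally maps to [[buffer(1)]]; if that
			// slot was already claimed by a remapped argument buffer, the next unused
			// buffer index is taken instead.)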
			if (claimed_bindings.get(i))
				buffer_binding = next_metal_resource_index_buffer;
			else
				buffer_binding = i;
		}

		claimed_bindings.set(buffer_binding);

		ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id) + to_name(id);
		ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";

		next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
	}

	entry_point_args_discrete_descriptors(ep_args);
	entry_point_args_builtin(ep_args);

	if (!ep_args.empty() && append_comma)
		ep_args += ", ";

	return ep_args;
}

const MSLConstexprSampler *CompilerMSL::find_constexpr_sampler(uint32_t id) const
{
	// Try by ID.
	{
		auto itr = constexpr_samplers_by_id.find(id);
		if (itr != end(constexpr_samplers_by_id))
			return &itr->second;
	}

	// Try by binding.
	{
		uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
		uint32_t binding = get_decoration(id, DecorationBinding);

		auto itr = constexpr_samplers_by_binding.find({ desc_set, binding });
		if (itr != end(constexpr_samplers_by_binding))
			return &itr->second;
	}

	return nullptr;
}

void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
{
	// Output resources, sorted by resource index & type
	// We need to sort to work around a bug on macOS 10.13 with NVidia drivers where switching between shaders
	// with different order of buffers can result in issues with buffer assignments inside the driver.
	struct Resource
	{
		SPIRVariable *var;
		string name;
		SPIRType::BaseType basetype;
		uint32_t index;
		uint32_t plane;
		uint32_t secondary_index;
	};

	SmallVector<Resource> resources;

	ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
		if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
		     var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer) &&
		    !is_hidden_variable(var))
		{
			auto &type = get_variable_data_type(var);

			// Very specifically, image load-store in argument buffers is disallowed in MSL on iOS.
			// But we won't know when the argument buffer is encoded whether this image will have
			// a NonWritable decoration. So just use discrete arguments for all storage images
			// on iOS.
			if (!(msl_options.is_ios() && type.basetype == SPIRType::Image && type.image.sampled == 2) &&
			    var.storage != StorageClassPushConstant)
			{
				uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
				if (descriptor_set_is_argument_buffer(desc_set))
					return;
			}

			const MSLConstexprSampler *constexpr_sampler = nullptr;
			if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
			{
				constexpr_sampler = find_constexpr_sampler(var_id);
				if (constexpr_sampler)
				{
					// Mark this ID as a constexpr sampler for later in case it came from set/bindings.
					constexpr_samplers_by_id[var_id] = *constexpr_sampler;
				}
			}

			// Emulate texture2D atomic operations
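			// (Metal lacks native texture atomics, so such images also get a companion
			// "<name>_atomic" device buffer parameter, emitted further below, through
			// which the atomic operations are performed.)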
			uint32_t secondary_index = 0;
			if (atomic_image_vars.count(var.self))
			{
				secondary_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
			}

			if (type.basetype == SPIRType::SampledImage)
			{
				add_resource_name(var_id);

				uint32_t plane_count = 1;
				if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
					plane_count = constexpr_sampler->planes;

				for (uint32_t i = 0; i < plane_count; i++)
					resources.push_back({ &var, to_name(var_id), SPIRType::Image,
					                      get_metal_resource_index(var, SPIRType::Image, i), i, secondary_index });

				if (type.image.dim != DimBuffer && !constexpr_sampler)
				{
					resources.push_back({ &var, to_sampler_expression(var_id), SPIRType::Sampler,
					                      get_metal_resource_index(var, SPIRType::Sampler), 0, 0 });
				}
			}
			else if (!constexpr_sampler)
			{
				// constexpr samplers are not declared as resources.
				add_resource_name(var_id);
				resources.push_back({ &var, to_name(var_id), type.basetype,
				                      get_metal_resource_index(var, type.basetype), 0, secondary_index });
			}
		}
	});

	sort(resources.begin(), resources.end(), [](const Resource &lhs, const Resource &rhs) {
		return tie(lhs.basetype, lhs.index) < tie(rhs.basetype, rhs.index);
	});

	for (auto &r : resources)
	{
		auto &var = *r.var;
		auto &type = get_variable_data_type(var);

		uint32_t var_id = var.self;

		switch (r.basetype)
		{
		case SPIRType::Struct:
		{
			auto &m = ir.meta[type.self];
			if (m.members.size() == 0)
				break;
			if (!type.array.empty())
			{
				if (type.array.size() > 1)
					SPIRV_CROSS_THROW("Arrays of arrays of buffers are not supported.");

				// Metal doesn't directly support this, so we must expand the
				// array. We'll declare a local array to hold these elements
				// later.
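				// (For example, a two-element buffer array "ubos" assigned index 5 is
				// expanded into ubos_0 [[buffer(5)]] and ubos_1 [[buffer(6)]].)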
				uint32_t array_size = to_array_size_literal(type);

				if (array_size == 0)
					SPIRV_CROSS_THROW("Unsized arrays of buffers are not supported in MSL.");

				// Allow Metal to use the array<T> template to make arrays a value type
				is_using_builtin_array = true;
				buffer_arrays.push_back(var_id);
				for (uint32_t i = 0; i < array_size; ++i)
				{
					if (!ep_args.empty())
						ep_args += ", ";
					ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + to_restrict(var_id) +
					           r.name + "_" + convert_to_string(i);
					ep_args += " [[buffer(" + convert_to_string(r.index + i) + ")";
					if (interlocked_resources.count(var_id))
						ep_args += ", raster_order_group(0)";
					ep_args += "]]";
				}
				is_using_builtin_array = false;
			}
			else
			{
				if (!ep_args.empty())
					ep_args += ", ";
				ep_args +=
				    get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(var_id) + r.name;
				ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
				if (interlocked_resources.count(var_id))
					ep_args += ", raster_order_group(0)";
				ep_args += "]]";
			}
			break;
		}
		case SPIRType::Sampler:
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args += sampler_type(type) + " " + r.name;
			ep_args += " [[sampler(" + convert_to_string(r.index) + ")]]";
			break;
		case SPIRType::Image:
		{
			if (!ep_args.empty())
				ep_args += ", ";

			// Use Metal's native frame-buffer fetch API for subpass inputs.
			const auto &basetype = get<SPIRType>(var.basetype);
			if (!type_is_msl_framebuffer_fetch(basetype))
			{
				ep_args += image_type_glsl(type, var_id) + " " + r.name;
				if (r.plane > 0)
					ep_args += join(plane_name_suffix, r.plane);
				ep_args += " [[texture(" + convert_to_string(r.index) + ")";
				if (interlocked_resources.count(var_id))
					ep_args += ", raster_order_group(0)";
				ep_args += "]]";
			}
			else
			{
				ep_args += image_type_glsl(type, var_id) + " " + r.name;
				ep_args += " [[color(" + convert_to_string(r.index) + ")]]";
			}

			// Emulate texture2D atomic operations
			if (atomic_image_vars.count(var.self))
			{
				ep_args += ", device atomic_" + type_to_glsl(get<SPIRType>(basetype.image.type), 0);
				ep_args += "* " + r.name + "_atomic";
				ep_args += " [[buffer(" + convert_to_string(r.secondary_index) + ")]]";
			}
			break;
		}
		default:
			if (!ep_args.empty())
				ep_args += ", ";
			if (!type.pointer)
				ep_args += get_type_address_space(get<SPIRType>(var.basetype), var_id) + " " +
				           type_to_glsl(type, var_id) + "& " + r.name;
			else
				ep_args += type_to_glsl(type, var_id) + " " + r.name;
			ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
			if (interlocked_resources.count(var_id))
				ep_args += ", raster_order_group(0)";
			ep_args += "]]";
			break;
		}
	}
}

// Returns a string containing a comma-delimited list of args for the entry point function
// This is the "classic" method of MSL 1 when we don't have argument buffer support.
string CompilerMSL::entry_point_args_classic(bool append_comma)
{
	string ep_args = entry_point_arg_stage_in();
	entry_point_args_discrete_descriptors(ep_args);
	entry_point_args_builtin(ep_args);

	if (!ep_args.empty() && append_comma)
		ep_args += ", ";

	return ep_args;
}

void CompilerMSL::fix_up_shader_inputs_outputs()
{
	auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);

	// Emit a guard to ensure we don't execute beyond the last vertex.
	// Vertex shaders shouldn't have the problems with barriers in non-uniform control flow that
	// tessellation control shaders do, so early returns should be OK. We may need to revisit this
	// if it ever becomes possible to use barriers from a vertex shader.
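	// The emitted guard looks roughly like this (expression names illustrative):
	//   if (any(gl_GlobalInvocationID >= spvStageInputSize))
	//       return;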
	if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
	{
		entry_func.fixup_hooks_in.push_back([this]() {
			statement("if (any(", to_expression(builtin_invocation_id_id),
			          " >= ", to_expression(builtin_stage_input_size_id), "))");
			statement("    return;");
		});
	}

	// Look for sampled images and buffers. Add hooks to set up the swizzle constants or array lengths.
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
		auto &type = get_variable_data_type(var);
		uint32_t var_id = var.self;
		bool ssbo = has_decoration(type.self, DecorationBufferBlock);

		if (var.storage == StorageClassUniformConstant && !is_hidden_variable(var))
		{
			if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
			{
				entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
					bool is_array_type = !type.array.empty();

					uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
					if (descriptor_set_is_argument_buffer(desc_set))
					{
						statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
						          is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
						          ".spvSwizzleConstants", "[",
						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
					}
					else
					{
						// If we have an array of images, we need to be able to index into it, so take a pointer instead.
						statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
						          is_array_type ? " = &" : " = ", to_name(swizzle_buffer_id), "[",
						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
					}
				});
			}
		}
		else if ((var.storage == StorageClassStorageBuffer || (var.storage == StorageClassUniform && ssbo)) &&
		         !is_hidden_variable(var))
		{
			if (buffers_requiring_array_length.count(var.self))
			{
				entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
					bool is_array_type = !type.array.empty();

					uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
					if (descriptor_set_is_argument_buffer(desc_set))
					{
						statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
						          is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
						          ".spvBufferSizeConstants", "[",
						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
					}
					else
					{
						// If we have an array of images, we need to be able to index into it, so take a pointer instead.
						statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
						          is_array_type ? " = &" : " = ", to_name(buffer_size_buffer_id), "[",
						          convert_to_string(get_metal_resource_index(var, type.basetype)), "];");
					}
				});
			}
		}
	});

	// Builtin variables
	ir.for_each_typed_id<SPIRVariable>([this, &entry_func](uint32_t, SPIRVariable &var) {
		uint32_t var_id = var.self;
		BuiltIn bi_type = ir.meta[var_id].decoration.builtin_type;

		if (var.storage == StorageClassInput && is_builtin_variable(var))
		{
			switch (bi_type)
			{
			case BuiltInSamplePosition:
				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = get_sample_position(",
					          to_expression(builtin_sample_id_id), ");");
				});
				break;
			case BuiltInHelperInvocation:
				if (msl_options.is_ios())
					SPIRV_CROSS_THROW("simd_is_helper_thread() is only supported on macOS.");
				else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");

				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_is_helper_thread();");
				});
				break;
			case BuiltInInvocationId:
				// This is direct-mapped without multi-patch workgroups.
				if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
					break;

				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
					          to_expression(builtin_invocation_id_id), ".x % ", this->get_entry_point().output_vertices,
					          ";");
				});
				break;
			case BuiltInPrimitiveId:
				// This is natively supported by fragment and tessellation evaluation shaders.
				// In tessellation control shaders, this is direct-mapped without multi-patch workgroups.
				if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
					break;

				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = min(",
					          to_expression(builtin_invocation_id_id), ".x / ", this->get_entry_point().output_vertices,
					          ", spvIndirectParams[1]);");
				});
				break;
			case BuiltInPatchVertices:
				if (get_execution_model() == ExecutionModelTessellationEvaluation)
					entry_func.fixup_hooks_in.push_back([=]() {
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
						          to_expression(patch_stage_in_var_id), ".gl_in.size();");
					});
				else
					entry_func.fixup_hooks_in.push_back([=]() {
						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = spvIndirectParams[0];");
					});
				break;
			case BuiltInTessCoord:
				// Emit a fixup to account for the shifted domain. Don't do this for triangles;
				// MoltenVK will just reverse the winding order instead.
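				// (Concretely, for non-triangle (quad) domains this emits
				// "<tess coord>.y = 1.0 - <tess coord>.y;" to flip the domain's Y axis.)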
				if (msl_options.tess_domain_origin_lower_left && !get_entry_point().flags.get(ExecutionModeTriangles))
				{
					string tc = to_expression(var_id);
					entry_func.fixup_hooks_in.push_back([=]() { statement(tc, ".y = 1.0 - ", tc, ".y;"); });
				}
				break;
			case BuiltInSubgroupLocalInvocationId:
				// This is natively supported in compute shaders.
				if (get_execution_model() == ExecutionModelGLCompute)
					break;

				// This is natively supported in fragment shaders in MSL 2.2.
				if (get_execution_model() == ExecutionModelFragment && msl_options.supports_msl_version(2, 2))
					break;

				if (msl_options.is_ios())
					SPIRV_CROSS_THROW(
					    "SubgroupLocalInvocationId cannot be used outside of compute shaders before MSL 2.2 on iOS.");

				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW(
					    "SubgroupLocalInvocationId cannot be used outside of compute shaders before MSL 2.1.");

				// Shaders other than compute shaders don't support the SIMD-group
				// builtins directly, but we can emulate them using the SIMD-group
				// functions. This might break if some invocations in the subgroup
				// terminated before reaching the entry point.
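				// The emitted fixup is "<id> = simd_prefix_exclusive_sum(1);": each active
				// lane receives the number of active lanes before it, which serves as its
				// subgroup-local invocation ID.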
				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
					          " = simd_prefix_exclusive_sum(1);");
				});
				break;
			case BuiltInSubgroupSize:
				// This is natively supported in compute shaders.
				if (get_execution_model() == ExecutionModelGLCompute)
					break;

				// This is natively supported in fragment shaders in MSL 2.2.
				if (get_execution_model() == ExecutionModelFragment && msl_options.supports_msl_version(2, 2))
					break;

				if (msl_options.is_ios())
					SPIRV_CROSS_THROW("SubgroupSize cannot be used outside of compute shaders on iOS.");

				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("SubgroupSize cannot be used outside of compute shaders before Metal 2.1.");

				entry_func.fixup_hooks_in.push_back(
				    [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_sum(1);"); });
				break;
			case BuiltInSubgroupEqMask:
				if (msl_options.is_ios())
					SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
					          to_expression(builtin_subgroup_invocation_id_id), " > 32 ? uint4(0, (1 << (",
					          to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
					          to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
				});
				break;
			case BuiltInSubgroupGeMask:
				if (msl_options.is_ios())
					SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
				entry_func.fixup_hooks_in.push_back([=]() {
					// Case where index < 32, size < 32:
					// mask0 = bfe(0xFFFFFFFF, index, size - index);
					// mask1 = bfe(0xFFFFFFFF, 0, 0); // Gives 0
					// Case where index < 32 but size >= 32:
					// mask0 = bfe(0xFFFFFFFF, index, 32 - index);
					// mask1 = bfe(0xFFFFFFFF, 0, size - 32);
					// Case where index >= 32:
					// mask0 = bfe(0xFFFFFFFF, 32, 0); // Gives 0
					// mask1 = bfe(0xFFFFFFFF, index - 32, size - index);
					// This is expressed without branches to avoid divergent
					// control flow--hence the complicated min/max expressions.
					// This is further complicated by the fact that if you attempt
					// to bfe out-of-bounds on Metal, undefined behavior is the
					// result.
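					// Worked example: index = 5, size = 32 gives
					// mask0 = bfe(0xFFFFFFFF, 5, 27), i.e. bits 5..31 set, and
					// mask1 = bfe(0xFFFFFFFF, 0, 0) = 0, so exactly lanes 5..31.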
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
					          " = uint4(extract_bits(0xFFFFFFFF, min(",
					          to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
					          to_expression(builtin_subgroup_size_id), ", 32) - (int)",
					          to_expression(builtin_subgroup_invocation_id_id),
					          ", 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
					          to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
					          to_expression(builtin_subgroup_size_id), " - (int)max(",
					          to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
				});
				break;
			case BuiltInSubgroupGtMask:
				if (msl_options.is_ios())
					SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
				entry_func.fixup_hooks_in.push_back([=]() {
					// The same logic applies here, except now the index is one
					// more than the subgroup invocation ID.
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
					          " = uint4(extract_bits(0xFFFFFFFF, min(",
					          to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
					          to_expression(builtin_subgroup_size_id), ", 32) - (int)",
					          to_expression(builtin_subgroup_invocation_id_id),
					          " - 1, 0)), extract_bits(0xFFFFFFFF, (uint)max((int)",
					          to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
					          to_expression(builtin_subgroup_size_id), " - (int)max(",
					          to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");
				});
				break;
			case BuiltInSubgroupLeMask:
				if (msl_options.is_ios())
					SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
					          " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
					          to_expression(builtin_subgroup_invocation_id_id),
					          " + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
					          to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0)), uint2(0));");
				});
				break;
			case BuiltInSubgroupLtMask:
				if (msl_options.is_ios())
					SPIRV_CROSS_THROW("Subgroup ballot functionality is unavailable on iOS.");
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
				entry_func.fixup_hooks_in.push_back([=]() {
					statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
					          " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
					          to_expression(builtin_subgroup_invocation_id_id),
					          ", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
					          to_expression(builtin_subgroup_invocation_id_id), " - 32, 0)), uint2(0));");
				});
				break;
			case BuiltInViewIndex:
				if (!msl_options.multiview)
				{
					// According to the Vulkan spec, when not running under a multiview
					// render pass, ViewIndex is 0.
					entry_func.fixup_hooks_in.push_back([=]() {
						statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;");
					});
				}
				else if (msl_options.view_index_from_device_index)
				{
					// In this case, we take the view index from that of the device we're running on.
					entry_func.fixup_hooks_in.push_back([=]() {
						statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
						          msl_options.device_index, ";");
					});
					// We actually don't want to set the render_target_array_index here.
					// Since every physical device is rendering a different view,
					// there's no need for layered rendering here.
				}
				else if (!msl_options.multiview_layered_rendering)
				{
					// In this case, the views are rendered one at a time. The view index, then,
					// is just the first part of the "view mask".
					entry_func.fixup_hooks_in.push_back([=]() {
						statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
						          to_expression(view_mask_buffer_id), "[0];");
					});
				}
				else if (get_execution_model() == ExecutionModelFragment)
				{
					// Because we adjusted the view index in the vertex shader, we have to
					// adjust it back here.
					entry_func.fixup_hooks_in.push_back([=]() {
						statement(to_expression(var_id), " += ", to_expression(view_mask_buffer_id), "[0];");
					});
				}
				else if (get_execution_model() == ExecutionModelVertex)
				{
					// Metal provides no special support for multiview, so we smuggle
					// the view index in the instance index.
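					// Worked example: if the view mask buffer holds { 2, 2 } (base view 2,
					// two views) and the raw instance index is 5 with base instance 0, the
					// view index becomes 2 + (5 % 2) = 3 and the true instance 5 / 2 = 2.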
10549 					entry_func.fixup_hooks_in.push_back([=]() {
10550 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10551 						          to_expression(view_mask_buffer_id), "[0] + (", to_expression(builtin_instance_idx_id),
10552 						          " - ", to_expression(builtin_base_instance_id), ") % ",
10553 						          to_expression(view_mask_buffer_id), "[1];");
10554 						statement(to_expression(builtin_instance_idx_id), " = (",
10555 						          to_expression(builtin_instance_idx_id), " - ",
10556 						          to_expression(builtin_base_instance_id), ") / ", to_expression(view_mask_buffer_id),
10557 						          "[1] + ", to_expression(builtin_base_instance_id), ";");
10558 					});
10559 					// In addition to setting the variable itself, we also need to
10560 					// set the render_target_array_index with it on output. We have to
10561 					// offset this by the base view index, because Metal isn't in on
10562 					// our little game here.
10563 					entry_func.fixup_hooks_out.push_back([=]() {
10564 						statement(to_expression(builtin_layer_id), " = ", to_expression(var_id), " - ",
10565 						          to_expression(view_mask_buffer_id), "[0];");
10566 					});
10567 				}
10568 				break;
10569 			case BuiltInDeviceIndex:
10570 				// Metal pipelines belong to the devices which create them, so we'll
10571 				// need to create a MTLPipelineState for every MTLDevice in a grouped
10572 				// VkDevice. We can assume, then, that the device index is constant.
10573 				entry_func.fixup_hooks_in.push_back([=]() {
10574 					statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10575 					          msl_options.device_index, ";");
10576 				});
10577 				break;
10578 			case BuiltInWorkgroupId:
10579 				if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInWorkgroupId))
10580 					break;
10581 
10582 				// The vkCmdDispatchBase() command lets the client set the base value
10583 				// of WorkgroupId. Metal has no direct equivalent; we must make this
10584 				// adjustment ourselves.
10585 				entry_func.fixup_hooks_in.push_back([=]() {
10586 					statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), ";");
10587 				});
10588 				break;
10589 			case BuiltInGlobalInvocationId:
10590 				if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInGlobalInvocationId))
10591 					break;
10592 
10593 				// GlobalInvocationId is defined as LocalInvocationId + WorkgroupId * WorkgroupSize.
10594 				// This needs to be adjusted too.
10595 				entry_func.fixup_hooks_in.push_back([=]() {
10596 					auto &execution = this->get_entry_point();
10597 					uint32_t workgroup_size_id = execution.workgroup_size.constant;
10598 					if (workgroup_size_id)
10599 						statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
10600 						          " * ", to_expression(workgroup_size_id), ";");
10601 					else
10602 						statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
10603 						          " * uint3(", execution.workgroup_size.x, ", ", execution.workgroup_size.y, ", ",
10604 						          execution.workgroup_size.z, ");");
10605 				});
10606 				break;
10607 			case BuiltInVertexId:
10608 			case BuiltInVertexIndex:
10609 				// This is direct-mapped normally.
10610 				if (!msl_options.vertex_for_tessellation)
10611 					break;
10612 
10613 				entry_func.fixup_hooks_in.push_back([=]() {
10614 					builtin_declaration = true;
10615 					switch (msl_options.vertex_index_type)
10616 					{
10617 					case Options::IndexType::None:
10618 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10619 						          to_expression(builtin_invocation_id_id), ".x + ",
10620 						          to_expression(builtin_dispatch_base_id), ".x;");
10621 						break;
10622 					case Options::IndexType::UInt16:
10623 					case Options::IndexType::UInt32:
10624 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", index_buffer_var_name,
10625 						          "[", to_expression(builtin_invocation_id_id), ".x] + ",
10626 						          to_expression(builtin_dispatch_base_id), ".x;");
10627 						break;
10628 					}
10629 					builtin_declaration = false;
10630 				});
10631 				break;
10632 			case BuiltInBaseVertex:
10633 				// This is direct-mapped normally.
10634 				if (!msl_options.vertex_for_tessellation)
10635 					break;
10636 
10637 				entry_func.fixup_hooks_in.push_back([=]() {
10638 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10639 					          to_expression(builtin_dispatch_base_id), ".x;");
10640 				});
10641 				break;
10642 			case BuiltInInstanceId:
10643 			case BuiltInInstanceIndex:
10644 				// This is direct-mapped normally.
10645 				if (!msl_options.vertex_for_tessellation)
10646 					break;
10647 
10648 				entry_func.fixup_hooks_in.push_back([=]() {
10649 					builtin_declaration = true;
10650 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10651 					          to_expression(builtin_invocation_id_id), ".y + ", to_expression(builtin_dispatch_base_id),
10652 					          ".y;");
10653 					builtin_declaration = false;
10654 				});
10655 				break;
10656 			case BuiltInBaseInstance:
10657 				// This is direct-mapped normally.
10658 				if (!msl_options.vertex_for_tessellation)
10659 					break;
10660 
10661 				entry_func.fixup_hooks_in.push_back([=]() {
10662 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
10663 					          to_expression(builtin_dispatch_base_id), ".y;");
10664 				});
10665 				break;
10666 			default:
10667 				break;
10668 			}
10669 		}
10670 		else if (var.storage == StorageClassOutput && is_builtin_variable(var))
10671 		{
10672 			if (bi_type == BuiltInSampleMask && get_execution_model() == ExecutionModelFragment &&
10673 			    msl_options.additional_fixed_sample_mask != 0xffffffff)
10674 			{
10675 				// If the additional fixed sample mask was set, we need to adjust the sample_mask
10676 				// output to reflect that. If the shader outputs the sample_mask itself too, we need
10677 				// to AND the two masks to get the final one.
10678 				if (does_shader_write_sample_mask)
10679 				{
10680 					entry_func.fixup_hooks_out.push_back([=]() {
10681 						statement(to_expression(builtin_sample_mask_id),
10682 						          " &= ", msl_options.additional_fixed_sample_mask, ";");
10683 					});
10684 				}
10685 				else
10686 				{
10687 					entry_func.fixup_hooks_out.push_back([=]() {
10688 						statement(to_expression(builtin_sample_mask_id), " = ",
10689 						          msl_options.additional_fixed_sample_mask, ";");
10690 					});
10691 				}
10692 			}
10693 		}
10694 	});
10695 }
10696 
10697 // Returns the Metal index of the resource of the specified type as used by the specified variable.
10698 uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype, uint32_t plane)
10699 {
10700 	auto &execution = get_entry_point();
10701 	auto &var_dec = ir.meta[var.self].decoration;
10702 	auto &var_type = get<SPIRType>(var.basetype);
10703 	uint32_t var_desc_set = (var.storage == StorageClassPushConstant) ? kPushConstDescSet : var_dec.set;
10704 	uint32_t var_binding = (var.storage == StorageClassPushConstant) ? kPushConstBinding : var_dec.binding;
10705 
10706 	// If a matching binding has been specified, find and use it.
10707 	auto itr = resource_bindings.find({ execution.model, var_desc_set, var_binding });
10708 
10709 	// Atomic helper buffers for image atomics need to use secondary bindings as well.
10710 	bool use_secondary_binding = (var_type.basetype == SPIRType::SampledImage && basetype == SPIRType::Sampler) ||
10711 	                             basetype == SPIRType::AtomicCounter;
10712 
10713 	auto resource_decoration =
10714 	    use_secondary_binding ? SPIRVCrossDecorationResourceIndexSecondary : SPIRVCrossDecorationResourceIndexPrimary;
10715 
10716 	if (plane == 1)
10717 		resource_decoration = SPIRVCrossDecorationResourceIndexTertiary;
10718 	if (plane == 2)
10719 		resource_decoration = SPIRVCrossDecorationResourceIndexQuaternary;
10720 
10721 	if (itr != end(resource_bindings))
10722 	{
10723 		auto &remap = itr->second;
10724 		remap.second = true;
10725 		switch (basetype)
10726 		{
10727 		case SPIRType::Image:
10728 			set_extended_decoration(var.self, resource_decoration, remap.first.msl_texture + plane);
10729 			return remap.first.msl_texture + plane;
10730 		case SPIRType::Sampler:
10731 			set_extended_decoration(var.self, resource_decoration, remap.first.msl_sampler);
10732 			return remap.first.msl_sampler;
10733 		default:
10734 			set_extended_decoration(var.self, resource_decoration, remap.first.msl_buffer);
10735 			return remap.first.msl_buffer;
10736 		}
10737 	}
10738 
10739 	// If we have already allocated an index, keep using it.
10740 	if (has_extended_decoration(var.self, resource_decoration))
10741 		return get_extended_decoration(var.self, resource_decoration);
10742 
10743 	// Allow user to enable decoration binding
10744 	if (msl_options.enable_decoration_binding)
10745 	{
10746 		// If there is no explicit mapping of bindings to MSL, use the declared binding.
10747 		if (has_decoration(var.self, DecorationBinding))
10748 		{
10749 			var_binding = get_decoration(var.self, DecorationBinding);
10750 			// Avoid emitting sentinel bindings.
10751 			if (var_binding < 0x80000000u)
10752 				return var_binding;
10753 		}
10754 	}
10755 
10756 	// If we did not explicitly remap, allocate bindings on demand.
10757 	// We cannot reliably use Binding decorations since SPIR-V and MSL's binding models are very different.
10758 
10759 	bool allocate_argument_buffer_ids = false;
10760 
10761 	if (var.storage != StorageClassPushConstant)
10762 		allocate_argument_buffer_ids = descriptor_set_is_argument_buffer(var_desc_set);
10763 
10764 	uint32_t binding_stride = 1;
10765 	auto &type = get<SPIRType>(var.basetype);
10766 	for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
10767 		binding_stride *= to_array_size_literal(type, i);
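	// E.g. an array of 4 textures occupies 4 consecutive slots, so binding_stride
	// becomes 4 and any auto-allocated index below advances by that amount.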
10768 
10769 	assert(binding_stride != 0);
10770 
10771 	// If a binding has not been specified, revert to incrementing resource indices.
10772 	uint32_t resource_index;
10773 
10774 	if (type_is_msl_framebuffer_fetch(type))
10775 	{
10776 		// Frame-buffer fetch gets its fallback resource index from the input attachment index,
10777 		// which is then treated as color index.
10778 		resource_index = get_decoration(var.self, DecorationInputAttachmentIndex);
10779 	}
10780 	else if (allocate_argument_buffer_ids)
10781 	{
10782 		// Allocate from a flat ID binding space.
10783 		resource_index = next_metal_resource_ids[var_desc_set];
10784 		next_metal_resource_ids[var_desc_set] += binding_stride;
10785 	}
10786 	else
10787 	{
10788 		// Allocate from plain bindings which are allocated per resource type.
10789 		switch (basetype)
10790 		{
10791 		case SPIRType::Image:
10792 			resource_index = next_metal_resource_index_texture;
10793 			next_metal_resource_index_texture += binding_stride;
10794 			break;
10795 		case SPIRType::Sampler:
10796 			resource_index = next_metal_resource_index_sampler;
10797 			next_metal_resource_index_sampler += binding_stride;
10798 			break;
10799 		default:
10800 			resource_index = next_metal_resource_index_buffer;
10801 			next_metal_resource_index_buffer += binding_stride;
10802 			break;
10803 		}
10804 	}
10805 
10806 	set_extended_decoration(var.self, resource_decoration, resource_index);
10807 	return resource_index;
10808 }
10809 
10810 bool CompilerMSL::type_is_msl_framebuffer_fetch(const SPIRType &type) const
10811 {
10812 	return type.basetype == SPIRType::Image && type.image.dim == DimSubpassData && msl_options.is_ios() &&
10813 	       msl_options.ios_use_framebuffer_fetch_subpasses;
10814 }
10815 
10816 string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
10817 {
10818 	auto &var = get<SPIRVariable>(arg.id);
10819 	auto &type = get_variable_data_type(var);
10820 	auto &var_type = get<SPIRType>(arg.type);
10821 	StorageClass storage = var_type.storage;
10822 	bool is_pointer = var_type.pointer;
10823 
10824 	// If we need to modify the name of the variable, make sure we use the original variable.
10825 	// Our alias is just a shadow variable.
10826 	uint32_t name_id = var.self;
10827 	if (arg.alias_global_variable && var.basevariable)
10828 		name_id = var.basevariable;
10829 
10830 	bool constref = !arg.alias_global_variable && is_pointer && arg.write_count == 0;
10831 	// Framebuffer fetch is a plain value; const looks out of place, but it is not wrong.
10832 	if (type_is_msl_framebuffer_fetch(type))
10833 		constref = false;
10834 
10835 	bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage ||
10836 	                     type.basetype == SPIRType::Sampler;
10837 
10838 	// Arrays of images/samplers in MSL are always const.
10839 	if (!type.array.empty() && type_is_image)
10840 		constref = true;
10841 
10842 	string decl;
10843 	if (constref)
10844 		decl += "const ";
10845 
10846 	// If this is a combined image-sampler for a 2D image with floating-point type,
10847 	// we emitted the 'spvDynamicImageSampler' type, and this is *not* an alias parameter
10848 	// for a global, then we need to emit a "dynamic" combined image-sampler.
10849 	// Unfortunately, this is necessary to properly support passing around
10850 	// combined image-samplers with Y'CbCr conversions on them.
10851 	bool is_dynamic_img_sampler = !arg.alias_global_variable && type.basetype == SPIRType::SampledImage &&
10852 	                              type.image.dim == Dim2D && type_is_floating_point(get<SPIRType>(type.image.type)) &&
10853 	                              spv_function_implementations.count(SPVFuncImplDynamicImageSampler);
10854 
10855 	// Allow Metal to use the array<T> template to make arrays a value type
10856 	string address_space = get_argument_address_space(var);
10857 	bool builtin = is_builtin_variable(var);
10858 	is_using_builtin_array = builtin;
10859 	if (address_space == "threadgroup")
10860 		is_using_builtin_array = true;
10861 
10862 	if (var.basevariable && (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id))
10863 		decl += type_to_glsl(type, arg.id);
10864 	else if (builtin)
10865 		decl += builtin_type_decl(static_cast<BuiltIn>(get_decoration(arg.id, DecorationBuiltIn)), arg.id);
10866 	else if ((storage == StorageClassUniform || storage == StorageClassStorageBuffer) && is_array(type))
10867 	{
10868 		is_using_builtin_array = true;
10869 		decl += join(type_to_glsl(type, arg.id), "*");
10870 	}
10871 	else if (is_dynamic_img_sampler)
10872 	{
10873 		decl += join("spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">");
10874 		// Mark the variable so that we can handle passing it to another function.
10875 		set_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
10876 	}
10877 	else
10878 		decl += type_to_glsl(type, arg.id);
10879 
10880 	bool opaque_handle = storage == StorageClassUniformConstant;
10881 
10882 	if (!builtin && !opaque_handle && !is_pointer &&
10883 	    (storage == StorageClassFunction || storage == StorageClassGeneric))
10884 	{
10885 		// If the argument is a pure value and not an opaque type, we will pass by value.
10886 		if (msl_options.force_native_arrays && is_array(type))
10887 		{
10888 			// We are receiving an array by value. This is problematic.
10889 			// We cannot be sure of the target address space since we are supposed to receive a copy,
10890 			// but this is not possible with MSL without some extra work.
10891 			// We will have to assume we're getting a reference in thread address space.
10892 			// If we happen to get a reference in constant address space, the caller must emit a copy and pass that.
10893 			// Thread const therefore becomes the only logical choice, since we cannot "create" a constant array from
10894 			// non-constant arrays, but we can create thread const from constant.
10895 			decl = string("thread const ") + decl;
10896 			decl += " (&";
10897 			const char *restrict_kw = to_restrict(name_id);
10898 			if (*restrict_kw)
10899 			{
10900 				decl += " ";
10901 				decl += restrict_kw;
10902 			}
10903 			decl += to_expression(name_id);
10904 			decl += ")";
10905 			decl += type_to_array_glsl(type);
10906 		}
10907 		else
10908 		{
10909 			if (!address_space.empty())
10910 				decl = join(address_space, " ", decl);
10911 			decl += " ";
10912 			decl += to_expression(name_id);
10913 		}
10914 	}
10915 	else if (is_array(type) && !type_is_image)
10916 	{
10917 		// Arrays of images and samplers are special cased.
10918 		if (!address_space.empty())
10919 			decl = join(address_space, " ", decl);
10920 
10921 		if (msl_options.argument_buffers)
10922 		{
10923 			uint32_t desc_set = get_decoration(name_id, DecorationDescriptorSet);
10924 			if ((storage == StorageClassUniform || storage == StorageClassStorageBuffer) &&
10925 			    descriptor_set_is_argument_buffer(desc_set))
10926 			{
10927 				// An awkward case where we need to emit *more* address space declarations (yay!).
10928 				// An example is where we pass down an array of buffer pointers to leaf functions.
10929 				// It's a constant array containing pointers to constants.
10930 				// The pointer array is always constant however. E.g.
10931 				// device SSBO * constant (&array)[N].
10932 				// const device SSBO * constant (&array)[N].
10933 				// constant SSBO * constant (&array)[N].
10934 				// However, this only matters for argument buffers, since for MSL 1.0 style codegen,
10935 				// we emit the buffer array on stack instead, and that seems to work just fine apparently.
10936 
10937 				// If the argument was marked as being in device address space, any pointer to member would
10938 				// be const device, not constant.
10939 				if (argument_buffer_device_storage_mask & (1u << desc_set))
10940 					decl += " const device";
10941 				else
10942 					decl += " constant";
10943 			}
10944 		}
10945 
10946 		decl += " (&";
10947 		const char *restrict_kw = to_restrict(name_id);
10948 		if (*restrict_kw)
10949 		{
10950 			decl += " ";
10951 			decl += restrict_kw;
10952 		}
10953 		decl += to_expression(name_id);
10954 		decl += ")";
10955 		decl += type_to_array_glsl(type);
10956 	}
10957 	else if (!opaque_handle)
10958 	{
10959 		// If this is going to be a reference to a variable pointer, the address space
10960 		// for the reference has to go before the '&', but after the '*'.
10961 		if (!address_space.empty())
10962 		{
10963 			if (decl.back() == '*')
10964 				decl += join(" ", address_space, " ");
10965 			else
10966 				decl = join(address_space, " ", decl);
10967 		}
10968 		decl += "&";
10969 		decl += " ";
10970 		decl += to_restrict(name_id);
10971 		decl += to_expression(name_id);
10972 	}
10973 	else
10974 	{
10975 		if (!address_space.empty())
10976 			decl = join(address_space, " ", decl);
10977 		decl += " ";
10978 		decl += to_expression(name_id);
10979 	}
10980 
10981 	// Emulate texture2D atomic operations
10982 	auto *backing_var = maybe_get_backing_variable(name_id);
10983 	if (backing_var && atomic_image_vars.count(backing_var->self))
10984 	{
10985 		decl += ", device atomic_" + type_to_glsl(get<SPIRType>(var_type.image.type), 0);
10986 		decl += "* " + to_expression(name_id) + "_atomic";
10987 	}
10988 
10989 	is_using_builtin_array = false;
10990 
10991 	return decl;
10992 }
10993 
10994 // If we're currently in the entry point function, and the object
10995 // has a qualified name, use it, otherwise use the standard name.
10996 string CompilerMSL::to_name(uint32_t id, bool allow_alias) const
10997 {
10998 	if (current_function && (current_function->self == ir.default_entry_point))
10999 	{
11000 		auto *m = ir.find_meta(id);
11001 		if (m && !m->decoration.qualified_alias.empty())
11002 			return m->decoration.qualified_alias;
11003 	}
11004 	return Compiler::to_name(id, allow_alias);
11005 }
11006 
11007 // Returns a name that combines the name of the struct with the name of the member, except for Builtins
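// E.g. member "_m3" of a struct type named "Light" becomes "Light_m3".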
11008 string CompilerMSL::to_qualified_member_name(const SPIRType &type, uint32_t index)
11009 {
11010 	// Don't qualify Builtin names because they are unique and are treated as such when building expressions
11011 	BuiltIn builtin = BuiltInMax;
11012 	if (is_member_builtin(type, index, &builtin))
11013 		return builtin_to_glsl(builtin, type.storage);
11014 
11015 	// Strip any underscore prefix from member name
11016 	string mbr_name = to_member_name(type, index);
11017 	size_t startPos = mbr_name.find_first_not_of("_");
11018 	mbr_name = (startPos != string::npos) ? mbr_name.substr(startPos) : "";
11019 	return join(to_name(type.self), "_", mbr_name);
11020 }
11021 
11022 // Ensures that the specified name is permanently usable by prepending a prefix
11023 // if the first chars are _ and a digit, which indicate a transient name.
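// E.g. with pfx "m", the transient name "_42" becomes "m_42"; "foo" is returned unchanged.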
11024 string CompilerMSL::ensure_valid_name(string name, string pfx)
11025 {
11026 	return (name.size() >= 2 && name[0] == '_' && isdigit(name[1])) ? (pfx + name) : name;
11027 }
11028 
11029 // Replace all names that match MSL keywords or Metal Standard Library functions.
11030 void CompilerMSL::replace_illegal_names()
11031 {
11032 	// FIXME: MSL and GLSL are doing two different things here.
11033 	// Agree on convention and remove this override.
11034 	static const unordered_set<string> keywords = {
11035 		"kernel",
11036 		"vertex",
11037 		"fragment",
11038 		"compute",
11039 		"bias",
11040 		"assert",
11041 		"VARIABLE_TRACEPOINT",
11042 		"STATIC_DATA_TRACEPOINT",
11043 		"STATIC_DATA_TRACEPOINT_V",
11044 		"METAL_ALIGN",
11045 		"METAL_ASM",
11046 		"METAL_CONST",
11047 		"METAL_DEPRECATED",
11048 		"METAL_ENABLE_IF",
11049 		"METAL_FUNC",
11050 		"METAL_INTERNAL",
11051 		"METAL_NON_NULL_RETURN",
11052 		"METAL_NORETURN",
11053 		"METAL_NOTHROW",
11054 		"METAL_PURE",
11055 		"METAL_UNAVAILABLE",
11056 		"METAL_IMPLICIT",
11057 		"METAL_EXPLICIT",
11058 		"METAL_CONST_ARG",
11059 		"METAL_ARG_UNIFORM",
11060 		"METAL_ZERO_ARG",
11061 		"METAL_VALID_LOD_ARG",
11062 		"METAL_VALID_LEVEL_ARG",
11063 		"METAL_VALID_STORE_ORDER",
11064 		"METAL_VALID_LOAD_ORDER",
11065 		"METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
11066 		"METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
11067 		"METAL_VALID_RENDER_TARGET",
11068 		"is_function_constant_defined",
11069 		"CHAR_BIT",
11070 		"SCHAR_MAX",
11071 		"SCHAR_MIN",
11072 		"UCHAR_MAX",
11073 		"CHAR_MAX",
11074 		"CHAR_MIN",
11075 		"USHRT_MAX",
11076 		"SHRT_MAX",
11077 		"SHRT_MIN",
11078 		"UINT_MAX",
11079 		"INT_MAX",
11080 		"INT_MIN",
11081 		"FLT_DIG",
11082 		"FLT_MANT_DIG",
11083 		"FLT_MAX_10_EXP",
11084 		"FLT_MAX_EXP",
11085 		"FLT_MIN_10_EXP",
11086 		"FLT_MIN_EXP",
11087 		"FLT_RADIX",
11088 		"FLT_MAX",
11089 		"FLT_MIN",
11090 		"FLT_EPSILON",
11091 		"FP_ILOGB0",
11092 		"FP_ILOGBNAN",
11093 		"MAXFLOAT",
11094 		"HUGE_VALF",
11095 		"INFINITY",
11096 		"NAN",
11097 		"M_E_F",
11098 		"M_LOG2E_F",
11099 		"M_LOG10E_F",
11100 		"M_LN2_F",
11101 		"M_LN10_F",
11102 		"M_PI_F",
11103 		"M_PI_2_F",
11104 		"M_PI_4_F",
11105 		"M_1_PI_F",
11106 		"M_2_PI_F",
11107 		"M_2_SQRTPI_F",
11108 		"M_SQRT2_F",
11109 		"M_SQRT1_2_F",
11110 		"HALF_DIG",
11111 		"HALF_MANT_DIG",
11112 		"HALF_MAX_10_EXP",
11113 		"HALF_MAX_EXP",
11114 		"HALF_MIN_10_EXP",
11115 		"HALF_MIN_EXP",
11116 		"HALF_RADIX",
11117 		"HALF_MAX",
11118 		"HALF_MIN",
11119 		"HALF_EPSILON",
11120 		"MAXHALF",
11121 		"HUGE_VALH",
11122 		"M_E_H",
11123 		"M_LOG2E_H",
11124 		"M_LOG10E_H",
11125 		"M_LN2_H",
11126 		"M_LN10_H",
11127 		"M_PI_H",
11128 		"M_PI_2_H",
11129 		"M_PI_4_H",
11130 		"M_1_PI_H",
11131 		"M_2_PI_H",
11132 		"M_2_SQRTPI_H",
11133 		"M_SQRT2_H",
11134 		"M_SQRT1_2_H",
11135 		"DBL_DIG",
11136 		"DBL_MANT_DIG",
11137 		"DBL_MAX_10_EXP",
11138 		"DBL_MAX_EXP",
11139 		"DBL_MIN_10_EXP",
11140 		"DBL_MIN_EXP",
11141 		"DBL_RADIX",
11142 		"DBL_MAX",
11143 		"DBL_MIN",
11144 		"DBL_EPSILON",
11145 		"HUGE_VAL",
11146 		"M_E",
11147 		"M_LOG2E",
11148 		"M_LOG10E",
11149 		"M_LN2",
11150 		"M_LN10",
11151 		"M_PI",
11152 		"M_PI_2",
11153 		"M_PI_4",
11154 		"M_1_PI",
11155 		"M_2_PI",
11156 		"M_2_SQRTPI",
11157 		"M_SQRT2",
11158 		"M_SQRT1_2",
11159 		"quad_broadcast",
11160 	};
11161 
11162 	static const unordered_set<string> illegal_func_names = {
11163 		"main",
11164 		"saturate",
11165 		"assert",
11166 		"VARIABLE_TRACEPOINT",
11167 		"STATIC_DATA_TRACEPOINT",
11168 		"STATIC_DATA_TRACEPOINT_V",
11169 		"METAL_ALIGN",
11170 		"METAL_ASM",
11171 		"METAL_CONST",
11172 		"METAL_DEPRECATED",
11173 		"METAL_ENABLE_IF",
11174 		"METAL_FUNC",
11175 		"METAL_INTERNAL",
11176 		"METAL_NON_NULL_RETURN",
11177 		"METAL_NORETURN",
11178 		"METAL_NOTHROW",
11179 		"METAL_PURE",
11180 		"METAL_UNAVAILABLE",
11181 		"METAL_IMPLICIT",
11182 		"METAL_EXPLICIT",
11183 		"METAL_CONST_ARG",
11184 		"METAL_ARG_UNIFORM",
11185 		"METAL_ZERO_ARG",
11186 		"METAL_VALID_LOD_ARG",
11187 		"METAL_VALID_LEVEL_ARG",
11188 		"METAL_VALID_STORE_ORDER",
11189 		"METAL_VALID_LOAD_ORDER",
11190 		"METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
11191 		"METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
11192 		"METAL_VALID_RENDER_TARGET",
11193 		"is_function_constant_defined",
11194 		"CHAR_BIT",
11195 		"SCHAR_MAX",
11196 		"SCHAR_MIN",
11197 		"UCHAR_MAX",
11198 		"CHAR_MAX",
11199 		"CHAR_MIN",
11200 		"USHRT_MAX",
11201 		"SHRT_MAX",
11202 		"SHRT_MIN",
11203 		"UINT_MAX",
11204 		"INT_MAX",
11205 		"INT_MIN",
11206 		"FLT_DIG",
11207 		"FLT_MANT_DIG",
11208 		"FLT_MAX_10_EXP",
11209 		"FLT_MAX_EXP",
11210 		"FLT_MIN_10_EXP",
11211 		"FLT_MIN_EXP",
11212 		"FLT_RADIX",
11213 		"FLT_MAX",
11214 		"FLT_MIN",
11215 		"FLT_EPSILON",
11216 		"FP_ILOGB0",
11217 		"FP_ILOGBNAN",
11218 		"MAXFLOAT",
11219 		"HUGE_VALF",
11220 		"INFINITY",
11221 		"NAN",
11222 		"M_E_F",
11223 		"M_LOG2E_F",
11224 		"M_LOG10E_F",
11225 		"M_LN2_F",
11226 		"M_LN10_F",
11227 		"M_PI_F",
11228 		"M_PI_2_F",
11229 		"M_PI_4_F",
11230 		"M_1_PI_F",
11231 		"M_2_PI_F",
11232 		"M_2_SQRTPI_F",
11233 		"M_SQRT2_F",
11234 		"M_SQRT1_2_F",
11235 		"HALF_DIG",
11236 		"HALF_MANT_DIG",
11237 		"HALF_MAX_10_EXP",
11238 		"HALF_MAX_EXP",
11239 		"HALF_MIN_10_EXP",
11240 		"HALF_MIN_EXP",
11241 		"HALF_RADIX",
11242 		"HALF_MAX",
11243 		"HALF_MIN",
11244 		"HALF_EPSILON",
11245 		"MAXHALF",
11246 		"HUGE_VALH",
11247 		"M_E_H",
11248 		"M_LOG2E_H",
11249 		"M_LOG10E_H",
11250 		"M_LN2_H",
11251 		"M_LN10_H",
11252 		"M_PI_H",
11253 		"M_PI_2_H",
11254 		"M_PI_4_H",
11255 		"M_1_PI_H",
11256 		"M_2_PI_H",
11257 		"M_2_SQRTPI_H",
11258 		"M_SQRT2_H",
11259 		"M_SQRT1_2_H",
11260 		"DBL_DIG",
11261 		"DBL_MANT_DIG",
11262 		"DBL_MAX_10_EXP",
11263 		"DBL_MAX_EXP",
11264 		"DBL_MIN_10_EXP",
11265 		"DBL_MIN_EXP",
11266 		"DBL_RADIX",
11267 		"DBL_MAX",
11268 		"DBL_MIN",
11269 		"DBL_EPSILON",
11270 		"HUGE_VAL",
11271 		"M_E",
11272 		"M_LOG2E",
11273 		"M_LOG10E",
11274 		"M_LN2",
11275 		"M_LN10",
11276 		"M_PI",
11277 		"M_PI_2",
11278 		"M_PI_4",
11279 		"M_1_PI",
11280 		"M_2_PI",
11281 		"M_2_SQRTPI",
11282 		"M_SQRT2",
11283 		"M_SQRT1_2",
11284 	};
11285 
11286 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &) {
11287 		auto *meta = ir.find_meta(self);
11288 		if (!meta)
11289 			return;
11290 
11291 		auto &dec = meta->decoration;
11292 		if (keywords.find(dec.alias) != end(keywords))
11293 			dec.alias += "0";
11294 	});
11295 
11296 	ir.for_each_typed_id<SPIRFunction>([&](uint32_t self, SPIRFunction &) {
11297 		auto *meta = ir.find_meta(self);
11298 		if (!meta)
11299 			return;
11300 
11301 		auto &dec = meta->decoration;
11302 		if (illegal_func_names.find(dec.alias) != end(illegal_func_names))
11303 			dec.alias += "0";
11304 	});
11305 
11306 	ir.for_each_typed_id<SPIRType>([&](uint32_t self, SPIRType &) {
11307 		auto *meta = ir.find_meta(self);
11308 		if (!meta)
11309 			return;
11310 
11311 		for (auto &mbr_dec : meta->members)
11312 			if (keywords.find(mbr_dec.alias) != end(keywords))
11313 				mbr_dec.alias += "0";
11314 	});
11315 
11316 	for (auto &entry : ir.entry_points)
11317 	{
11318 		// Change both the entry point name and the alias, to keep them synced.
11319 		string &ep_name = entry.second.name;
11320 		if (illegal_func_names.find(ep_name) != end(illegal_func_names))
11321 			ep_name += "0";
11322 
11323 		// Always write this because entry point might have been renamed earlier.
11324 		ir.meta[entry.first].decoration.alias = ep_name;
11325 	}
11326 
11327 	CompilerGLSL::replace_illegal_names();
11328 }
11329 
11330 string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain)
11331 {
11332 	if (index < uint32_t(type.member_type_index_redirection.size()))
11333 		index = type.member_type_index_redirection[index];
11334 
11335 	auto *var = maybe_get<SPIRVariable>(base);
11336 	// If this is a buffer array, we have to dereference the buffer pointers.
11337 	// Otherwise, if this is a pointer expression, dereference it.
11338 
11339 	bool declared_as_pointer = false;
11340 
11341 	if (var)
11342 	{
11343 		// Only allow -> dereference for block types. This is so we get expressions like
11344 		// buffer[i]->first_member.second_member, rather than buffer[i]->first->second.
11345 		bool is_block = has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
11346 
11347 		bool is_buffer_variable =
11348 		    is_block && (var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer);
11349 		declared_as_pointer = is_buffer_variable && is_array(get<SPIRType>(var->basetype));
11350 	}
11351 
11352 	if (declared_as_pointer || (!ptr_chain && should_dereference(base)))
11353 		return join("->", to_member_name(type, index));
11354 	else
11355 		return join(".", to_member_name(type, index));
11356 }
11357 
11358 string CompilerMSL::to_qualifiers_glsl(uint32_t id)
11359 {
11360 	string quals;
11361 
11362 	auto &type = expression_type(id);
11363 	if (type.storage == StorageClassWorkgroup)
11364 		quals += "threadgroup ";
11365 
11366 	return quals;
11367 }
11368 
11369 // The optional id parameter indicates the object whose type we are trying
11370 // to find the description for. Most type descriptions do not depend on a
11371 // specific object's use of that type.
11372 string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
11373 {
11374 	string type_name;
11375 
11376 	// Pointer?
11377 	if (type.pointer)
11378 	{
11379 		const char *restrict_kw;
11380 		type_name = join(get_type_address_space(type, id), " ", type_to_glsl(get<SPIRType>(type.parent_type), id));
11381 
11382 		switch (type.basetype)
11383 		{
11384 		case SPIRType::Image:
11385 		case SPIRType::SampledImage:
11386 		case SPIRType::Sampler:
11387 			// These are handles.
11388 			break;
11389 		default:
11390 			// Anything else can be a raw pointer.
11391 			type_name += "*";
11392 			restrict_kw = to_restrict(id);
11393 			if (*restrict_kw)
11394 			{
11395 				type_name += " ";
11396 				type_name += restrict_kw;
11397 			}
11398 			break;
11399 		}
11400 		return type_name;
11401 	}
11402 
11403 	switch (type.basetype)
11404 	{
11405 	case SPIRType::Struct:
11406 		// Need OpName lookup here to get a "sensible" name for a struct.
11407 		// Allow Metal to use the array<T> template to make arrays a value type
11408 		type_name = to_name(type.self);
11409 		break;
11410 
11411 	case SPIRType::Image:
11412 	case SPIRType::SampledImage:
11413 		return image_type_glsl(type, id);
11414 
11415 	case SPIRType::Sampler:
11416 		return sampler_type(type);
11417 
11418 	case SPIRType::Void:
11419 		return "void";
11420 
11421 	case SPIRType::AtomicCounter:
11422 		return "atomic_uint";
11423 
11424 	case SPIRType::ControlPointArray:
11425 		return join("patch_control_point<", type_to_glsl(get<SPIRType>(type.parent_type), id), ">");
11426 
11427 	// Scalars
11428 	case SPIRType::Boolean:
11429 		type_name = "bool";
11430 		break;
11431 	case SPIRType::Char:
11432 	case SPIRType::SByte:
11433 		type_name = "char";
11434 		break;
11435 	case SPIRType::UByte:
11436 		type_name = "uchar";
11437 		break;
11438 	case SPIRType::Short:
11439 		type_name = "short";
11440 		break;
11441 	case SPIRType::UShort:
11442 		type_name = "ushort";
11443 		break;
11444 	case SPIRType::Int:
11445 		type_name = "int";
11446 		break;
11447 	case SPIRType::UInt:
11448 		type_name = "uint";
11449 		break;
11450 	case SPIRType::Int64:
11451 		if (!msl_options.supports_msl_version(2, 2))
11452 			SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
11453 		type_name = "long";
11454 		break;
11455 	case SPIRType::UInt64:
11456 		if (!msl_options.supports_msl_version(2, 2))
11457 			SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
11458 		type_name = "ulong";
11459 		break;
11460 	case SPIRType::Half:
11461 		type_name = "half";
11462 		break;
11463 	case SPIRType::Float:
11464 		type_name = "float";
11465 		break;
11466 	case SPIRType::Double:
11467 		type_name = "double"; // Currently unsupported
11468 		break;
11469 
11470 	default:
11471 		return "unknown_type";
11472 	}
11473 
11474 	// Matrix?
11475 	if (type.columns > 1)
11476 		type_name += to_string(type.columns) + "x";
11477 
11478 	// Vector or Matrix?
11479 	if (type.vecsize > 1)
11480 		type_name += to_string(type.vecsize);
11481 
11482 	if (type.array.empty() || using_builtin_array())
11483 	{
11484 		return type_name;
11485 	}
11486 	else
11487 	{
11488 		// Allow Metal to use the array<T> template to make arrays a value type
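		// E.g. a float[4] variable is emitted as spvUnsafeArray<float, 4> rather than a plain C array.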
11489 		add_spv_func_and_recompile(SPVFuncImplUnsafeArray);
11490 		string res;
11491 		string sizes;
11492 
11493 		for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
11494 		{
11495 			res += "spvUnsafeArray<";
11496 			sizes += ", ";
11497 			sizes += to_array_size(type, i);
11498 			sizes += ">";
11499 		}
11500 
11501 		res += type_name + sizes;
11502 		return res;
11503 	}
11504 }
11505 
11506 string CompilerMSL::type_to_array_glsl(const SPIRType &type)
11507 {
11508 	// Allow Metal to use the array<T> template to make arrays a value type
11509 	switch (type.basetype)
11510 	{
11511 	case SPIRType::AtomicCounter:
11512 	case SPIRType::ControlPointArray:
11513 	{
11514 		return CompilerGLSL::type_to_array_glsl(type);
11515 	}
11516 	default:
11517 	{
11518 		if (using_builtin_array())
11519 			return CompilerGLSL::type_to_array_glsl(type);
11520 		else
11521 			return "";
11522 	}
11523 	}
11524 }
11525 
11526 // Threadgroup arrays can't have a wrapper type
11527 std::string CompilerMSL::variable_decl(const SPIRVariable &variable)
11528 {
11529 	if (variable.storage == StorageClassWorkgroup)
11530 	{
11531 		is_using_builtin_array = true;
11532 	}
11533 	std::string expr = CompilerGLSL::variable_decl(variable);
11534 	if (variable.storage == StorageClassWorkgroup)
11535 	{
11536 		is_using_builtin_array = false;
11537 	}
11538 	return expr;
11539 }
11540 
11541 // GCC workaround for lambdas calling protected functions
11542 std::string CompilerMSL::variable_decl(const SPIRType &type, const std::string &name, uint32_t id)
11543 {
11544 	return CompilerGLSL::variable_decl(type, name, id);
11545 }
11546 
11547 std::string CompilerMSL::sampler_type(const SPIRType &type)
11548 {
11549 	if (!type.array.empty())
11550 	{
11551 		if (!msl_options.supports_msl_version(2))
11552 			SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of samplers.");
11553 
11554 		if (type.array.size() > 1)
11555 			SPIRV_CROSS_THROW("Arrays of arrays of samplers are not supported in MSL.");
11556 
11557 		// Arrays of samplers in MSL must be declared with a special array<T, N> syntax a la C++11 std::array.
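		// E.g. an 8-element sampler array becomes array<sampler, 8>.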
11558 		uint32_t array_size = to_array_size_literal(type);
11559 		if (array_size == 0)
11560 			SPIRV_CROSS_THROW("Unsized array of samplers is not supported in MSL.");
11561 
11562 		auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
11563 		return join("array<", sampler_type(parent), ", ", array_size, ">");
11564 	}
11565 	else
11566 		return "sampler";
11567 }
11568 
11569 // Returns an MSL string describing the SPIR-V image type
11570 string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id)
11571 {
11572 	auto *var = maybe_get<SPIRVariable>(id);
11573 	if (var && var->basevariable)
11574 	{
11575 		// For comparison images, check against the base variable,
11576 		// and not the fake ID which might have been generated for this variable.
11577 		id = var->basevariable;
11578 	}
11579 
11580 	if (!type.array.empty())
11581 	{
11582 		uint32_t major = 2, minor = 0;
11583 		if (msl_options.is_ios())
11584 		{
11585 			major = 1;
11586 			minor = 2;
11587 		}
11588 		if (!msl_options.supports_msl_version(major, minor))
11589 		{
11590 			if (msl_options.is_ios())
11591 				SPIRV_CROSS_THROW("MSL 1.2 or greater is required for arrays of textures.");
11592 			else
11593 				SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of textures.");
11594 		}
11595 
11596 		if (type.array.size() > 1)
11597 			SPIRV_CROSS_THROW("Arrays of arrays of textures are not supported in MSL.");
11598 
11599 		// Arrays of images in MSL must be declared with a special array<T, N> syntax a la C++11 std::array.
11600 		uint32_t array_size = to_array_size_literal(type);
11601 		if (array_size == 0)
11602 			SPIRV_CROSS_THROW("Unsized array of images is not supported in MSL.");
11603 
11604 		auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
11605 		return join("array<", image_type_glsl(parent, id), ", ", array_size, ">");
11606 	}
11607 
11608 	string img_type_name;
11609 
11610 	// Bypass pointers because we need the real image struct
11611 	auto &img_type = get<SPIRType>(type.self).image;
11612 	if (image_is_comparison(type, id))
11613 	{
11614 		switch (img_type.dim)
11615 		{
11616 		case Dim1D:
11617 		case Dim2D:
11618 			if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
11619 			{
11620 				// Use a native Metal 1D texture
11621 				img_type_name += "depth1d_unsupported_by_metal";
11622 				break;
11623 			}
11624 
11625 			if (img_type.ms && img_type.arrayed)
11626 			{
11627 				if (!msl_options.supports_msl_version(2, 1))
11628 					SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
11629 				img_type_name += "depth2d_ms_array";
11630 			}
11631 			else if (img_type.ms)
11632 				img_type_name += "depth2d_ms";
11633 			else if (img_type.arrayed)
11634 				img_type_name += "depth2d_array";
11635 			else
11636 				img_type_name += "depth2d";
11637 			break;
11638 		case Dim3D:
11639 			img_type_name += "depth3d_unsupported_by_metal";
11640 			break;
11641 		case DimCube:
11642 			if (!msl_options.emulate_cube_array)
11643 				img_type_name += (img_type.arrayed ? "depthcube_array" : "depthcube");
11644 			else
11645 				img_type_name += (img_type.arrayed ? "depth2d_array" : "depthcube");
11646 			break;
11647 		default:
11648 			img_type_name += "unknown_depth_texture_type";
11649 			break;
11650 		}
11651 	}
11652 	else
11653 	{
11654 		switch (img_type.dim)
11655 		{
11656 		case DimBuffer:
11657 			if (img_type.ms || img_type.arrayed)
11658 				SPIRV_CROSS_THROW("Cannot use texel buffers with multisampling or array layers.");
11659 
11660 			if (msl_options.texture_buffer_native)
11661 			{
11662 				if (!msl_options.supports_msl_version(2, 1))
11663 					SPIRV_CROSS_THROW("Native texture_buffer type is only supported in MSL 2.1.");
11664 				img_type_name = "texture_buffer";
11665 			}
11666 			else
11667 				img_type_name += "texture2d";
11668 			break;
11669 		case Dim1D:
11670 		case Dim2D:
11671 		case DimSubpassData:
11672 		{
11673 			bool subpass_array =
11674 			    img_type.dim == DimSubpassData && (msl_options.multiview || msl_options.arrayed_subpass_input);
11675 			if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
11676 			{
11677 				// Use a native Metal 1D texture
11678 				img_type_name += (img_type.arrayed ? "texture1d_array" : "texture1d");
11679 				break;
11680 			}
11681 
11682 			// Use Metal's native frame-buffer fetch API for subpass inputs.
11683 			if (type_is_msl_framebuffer_fetch(type))
11684 			{
11685 				auto img_type_4 = get<SPIRType>(img_type.type);
11686 				img_type_4.vecsize = 4;
11687 				return type_to_glsl(img_type_4);
11688 			}
11689 			if (img_type.ms && (img_type.arrayed || subpass_array))
11690 			{
11691 				if (!msl_options.supports_msl_version(2, 1))
11692 					SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
11693 				img_type_name += "texture2d_ms_array";
11694 			}
11695 			else if (img_type.ms)
11696 				img_type_name += "texture2d_ms";
11697 			else if (img_type.arrayed || subpass_array)
11698 				img_type_name += "texture2d_array";
11699 			else
11700 				img_type_name += "texture2d";
11701 			break;
11702 		}
11703 		case Dim3D:
11704 			img_type_name += "texture3d";
11705 			break;
11706 		case DimCube:
11707 			if (!msl_options.emulate_cube_array)
11708 				img_type_name += (img_type.arrayed ? "texturecube_array" : "texturecube");
11709 			else
11710 				img_type_name += (img_type.arrayed ? "texture2d_array" : "texturecube");
11711 			break;
11712 		default:
11713 			img_type_name += "unknown_texture_type";
11714 			break;
11715 		}
11716 	}
11717 
11718 	// Append the pixel type
11719 	img_type_name += "<";
11720 	img_type_name += type_to_glsl(get<SPIRType>(img_type.type));
11721 
11722 	// For unsampled images, append the sample/read/write access qualifier.
11723 	// For kernel images, the access qualifier may be supplied directly by SPIR-V.
11724 	// Otherwise it may be set based on whether the image is read from or written to within the shader.
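	// E.g. a storage image that is both read and written becomes texture2d<float, access::read_write>.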
11725 	if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData)
11726 	{
11727 		switch (img_type.access)
11728 		{
11729 		case AccessQualifierReadOnly:
11730 			img_type_name += ", access::read";
11731 			break;
11732 
11733 		case AccessQualifierWriteOnly:
11734 			img_type_name += ", access::write";
11735 			break;
11736 
11737 		case AccessQualifierReadWrite:
11738 			img_type_name += ", access::read_write";
11739 			break;
11740 
11741 		default:
11742 		{
11743 			auto *p_var = maybe_get_backing_variable(id);
11744 			if (p_var && p_var->basevariable)
11745 				p_var = maybe_get<SPIRVariable>(p_var->basevariable);
11746 			if (p_var && !has_decoration(p_var->self, DecorationNonWritable))
11747 			{
11748 				img_type_name += ", access::";
11749 
11750 				if (!has_decoration(p_var->self, DecorationNonReadable))
11751 					img_type_name += "read_";
11752 
11753 				img_type_name += "write";
11754 			}
11755 			break;
11756 		}
11757 		}
11758 	}
11759 
11760 	img_type_name += ">";
11761 
11762 	return img_type_name;
11763 }
11764 
11765 void CompilerMSL::emit_subgroup_op(const Instruction &i)
11766 {
11767 	const uint32_t *ops = stream(i);
11768 	auto op = static_cast<Op>(i.op);
11769 
11770 	// Metal 2.0 is required. iOS only supports quad ops. macOS only supports
11771 	// broadcast and shuffle on 10.13 (2.0), with full support in 10.14 (2.1).
11772 	// Note that iOS makes no distinction between a quad-group and a subgroup;
11773 	// all subgroups are quad-groups there.
11774 	if (!msl_options.supports_msl_version(2))
11775 		SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up.");
11776 
11777 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
11778 	uint32_t integer_width = get_integer_width_for_instruction(i);
11779 	auto int_type = to_signed_basetype(integer_width);
11780 	auto uint_type = to_unsigned_basetype(integer_width);
11781 
11782 	if (msl_options.is_ios())
11783 	{
11784 		switch (op)
11785 		{
11786 		default:
11787 			SPIRV_CROSS_THROW("iOS only supports quad-group operations.");
11788 		case OpGroupNonUniformBroadcast:
11789 		case OpGroupNonUniformShuffle:
11790 		case OpGroupNonUniformShuffleXor:
11791 		case OpGroupNonUniformShuffleUp:
11792 		case OpGroupNonUniformShuffleDown:
11793 		case OpGroupNonUniformQuadSwap:
11794 		case OpGroupNonUniformQuadBroadcast:
11795 			break;
11796 		}
11797 	}
11798 
11799 	if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
11800 	{
11801 		switch (op)
11802 		{
11803 		default:
11804 			SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.1 and up.");
11805 		case OpGroupNonUniformBroadcast:
11806 		case OpGroupNonUniformShuffle:
11807 		case OpGroupNonUniformShuffleXor:
11808 		case OpGroupNonUniformShuffleUp:
11809 		case OpGroupNonUniformShuffleDown:
11810 			break;
11811 		}
11812 	}
11813 
11814 	uint32_t result_type = ops[0];
11815 	uint32_t id = ops[1];
11816 
11817 	auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
11818 	if (scope != ScopeSubgroup)
11819 		SPIRV_CROSS_THROW("Only subgroup scope is supported.");
11820 
11821 	switch (op)
11822 	{
11823 	case OpGroupNonUniformElect:
11824 		emit_op(result_type, id, "simd_is_first()", true);
11825 		break;
11826 
11827 	case OpGroupNonUniformBroadcast:
11828 		emit_binary_func_op(result_type, id, ops[3], ops[4],
11829 		                    msl_options.is_ios() ? "quad_broadcast" : "simd_broadcast");
11830 		break;
11831 
11832 	case OpGroupNonUniformBroadcastFirst:
11833 		emit_unary_func_op(result_type, id, ops[3], "simd_broadcast_first");
11834 		break;
11835 
11836 	case OpGroupNonUniformBallot:
11837 		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallot");
11838 		break;
11839 
11840 	case OpGroupNonUniformInverseBallot:
11841 		emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract");
11842 		break;
11843 
11844 	case OpGroupNonUniformBallotBitExtract:
11845 		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBallotBitExtract");
11846 		break;
11847 
11848 	case OpGroupNonUniformBallotFindLSB:
11849 		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindLSB");
11850 		break;
11851 
11852 	case OpGroupNonUniformBallotFindMSB:
11853 		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallotFindMSB");
11854 		break;
11855 
11856 	case OpGroupNonUniformBallotBitCount:
11857 	{
11858 		auto operation = static_cast<GroupOperation>(ops[3]);
11859 		if (operation == GroupOperationReduce)
11860 			emit_unary_func_op(result_type, id, ops[4], "spvSubgroupBallotBitCount");
11861 		else if (operation == GroupOperationInclusiveScan)
11862 			emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
11863 			                    "spvSubgroupBallotInclusiveBitCount");
11864 		else if (operation == GroupOperationExclusiveScan)
11865 			emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
11866 			                    "spvSubgroupBallotExclusiveBitCount");
11867 		else
11868 			SPIRV_CROSS_THROW("Invalid BitCount operation.");
11869 		break;
11870 	}
11871 
11872 	case OpGroupNonUniformShuffle:
11873 		emit_binary_func_op(result_type, id, ops[3], ops[4], msl_options.is_ios() ? "quad_shuffle" : "simd_shuffle");
11874 		break;
11875 
11876 	case OpGroupNonUniformShuffleXor:
11877 		emit_binary_func_op(result_type, id, ops[3], ops[4],
11878 		                    msl_options.is_ios() ? "quad_shuffle_xor" : "simd_shuffle_xor");
11879 		break;
11880 
11881 	case OpGroupNonUniformShuffleUp:
11882 		emit_binary_func_op(result_type, id, ops[3], ops[4],
11883 		                    msl_options.is_ios() ? "quad_shuffle_up" : "simd_shuffle_up");
11884 		break;
11885 
11886 	case OpGroupNonUniformShuffleDown:
11887 		emit_binary_func_op(result_type, id, ops[3], ops[4],
11888 		                    msl_options.is_ios() ? "quad_shuffle_down" : "simd_shuffle_down");
11889 		break;
11890 
11891 	case OpGroupNonUniformAll:
11892 		emit_unary_func_op(result_type, id, ops[3], "simd_all");
11893 		break;
11894 
11895 	case OpGroupNonUniformAny:
11896 		emit_unary_func_op(result_type, id, ops[3], "simd_any");
11897 		break;
11898 
11899 	case OpGroupNonUniformAllEqual:
11900 		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupAllEqual");
11901 		break;
11902 
11903 		// clang-format off
11904 #define MSL_GROUP_OP(op, msl_op) \
11905 case OpGroupNonUniform##op: \
11906 	{ \
11907 		auto operation = static_cast<GroupOperation>(ops[3]); \
11908 		if (operation == GroupOperationReduce) \
11909 			emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
11910 		else if (operation == GroupOperationInclusiveScan) \
11911 			emit_unary_func_op(result_type, id, ops[4], "simd_prefix_inclusive_" #msl_op); \
11912 		else if (operation == GroupOperationExclusiveScan) \
11913 			emit_unary_func_op(result_type, id, ops[4], "simd_prefix_exclusive_" #msl_op); \
11914 		else if (operation == GroupOperationClusteredReduce) \
11915 		{ \
11916 			/* Only cluster sizes of 4 are supported. */ \
11917 			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
11918 			if (cluster_size != 4) \
11919 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
11920 			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
11921 		} \
11922 		else \
11923 			SPIRV_CROSS_THROW("Invalid group operation."); \
11924 		break; \
11925 	}
11926 	MSL_GROUP_OP(FAdd, sum)
11927 	MSL_GROUP_OP(FMul, product)
11928 	MSL_GROUP_OP(IAdd, sum)
11929 	MSL_GROUP_OP(IMul, product)
11930 #undef MSL_GROUP_OP
11931 	// The others, unfortunately, don't support InclusiveScan or ExclusiveScan.
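	// E.g. for FMin below, GroupOperationReduce maps to simd_min() and a quad
	// ClusteredReduce maps to quad_min().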
11932 
11933 #define MSL_GROUP_OP(op, msl_op) \
11934 case OpGroupNonUniform##op: \
11935 	{ \
11936 		auto operation = static_cast<GroupOperation>(ops[3]); \
11937 		if (operation == GroupOperationReduce) \
11938 			emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
11939 		else if (operation == GroupOperationInclusiveScan) \
11940 			SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
11941 		else if (operation == GroupOperationExclusiveScan) \
11942 			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
11943 		else if (operation == GroupOperationClusteredReduce) \
11944 		{ \
11945 			/* Only cluster sizes of 4 are supported. */ \
11946 			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
11947 			if (cluster_size != 4) \
11948 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
11949 			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
11950 		} \
11951 		else \
11952 			SPIRV_CROSS_THROW("Invalid group operation."); \
11953 		break; \
11954 	}
11955 
11956 #define MSL_GROUP_OP_CAST(op, msl_op, type) \
11957 case OpGroupNonUniform##op: \
11958 	{ \
11959 		auto operation = static_cast<GroupOperation>(ops[3]); \
11960 		if (operation == GroupOperationReduce) \
11961 			emit_unary_func_op_cast(result_type, id, ops[4], "simd_" #msl_op, type, type); \
11962 		else if (operation == GroupOperationInclusiveScan) \
11963 			SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
11964 		else if (operation == GroupOperationExclusiveScan) \
11965 			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
11966 		else if (operation == GroupOperationClusteredReduce) \
11967 		{ \
11968 			/* Only cluster sizes of 4 are supported. */ \
11969 			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
11970 			if (cluster_size != 4) \
11971 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
11972 			emit_unary_func_op_cast(result_type, id, ops[4], "quad_" #msl_op, type, type); \
11973 		} \
11974 		else \
11975 			SPIRV_CROSS_THROW("Invalid group operation."); \
11976 		break; \
11977 	}
11978 
11979 	MSL_GROUP_OP(FMin, min)
11980 	MSL_GROUP_OP(FMax, max)
11981 	MSL_GROUP_OP_CAST(SMin, min, int_type)
11982 	MSL_GROUP_OP_CAST(SMax, max, int_type)
11983 	MSL_GROUP_OP_CAST(UMin, min, uint_type)
11984 	MSL_GROUP_OP_CAST(UMax, max, uint_type)
11985 	MSL_GROUP_OP(BitwiseAnd, and)
11986 	MSL_GROUP_OP(BitwiseOr, or)
11987 	MSL_GROUP_OP(BitwiseXor, xor)
11988 	MSL_GROUP_OP(LogicalAnd, and)
11989 	MSL_GROUP_OP(LogicalOr, or)
11990 	MSL_GROUP_OP(LogicalXor, xor)
11991 		// clang-format on
11992 #undef MSL_GROUP_OP
11993 #undef MSL_GROUP_OP_CAST
11994 
11995 	case OpGroupNonUniformQuadSwap:
11996 	{
11997 		// We can implement this easily based on the following table giving
11998 		// the target lane ID from the direction and current lane ID:
11999 		//        Direction
12000 		//      | 0 | 1 | 2 |
12001 		//   ---+---+---+---+
12002 		// L 0  | 1   2   3
12003 		// a 1  | 0   3   2
12004 		// n 2  | 3   0   1
12005 		// e 3  | 2   1   0
12006 		// Notice that target = source ^ (direction + 1).
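		// E.g. direction 1 (vertical swap) becomes quad_shuffle_xor with mask 2,
		// exchanging lanes 0<->2 and 1<->3, as in the table above.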
12007 		uint32_t mask = evaluate_constant_u32(ops[4]) + 1;
12008 		uint32_t mask_id = ir.increase_bound_by(1);
12009 		set<SPIRConstant>(mask_id, expression_type_id(ops[4]), mask, false);
12010 		emit_binary_func_op(result_type, id, ops[3], mask_id, "quad_shuffle_xor");
12011 		break;
12012 	}
12013 
12014 	case OpGroupNonUniformQuadBroadcast:
12015 		emit_binary_func_op(result_type, id, ops[3], ops[4], "quad_broadcast");
12016 		break;
12017 
12018 	default:
12019 		SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
12020 	}
12021 
12022 	register_control_dependent_expression(id);
12023 }
12024 
12025 string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
12026 {
12027 	if (out_type.basetype == in_type.basetype)
12028 		return "";
12029 
12030 	assert(out_type.basetype != SPIRType::Boolean);
12031 	assert(in_type.basetype != SPIRType::Boolean);
12032 
12033 	bool integral_cast = type_is_integral(out_type) && type_is_integral(in_type);
12034 	bool same_size_cast = out_type.width == in_type.width;
12035 
12036 	if (integral_cast && same_size_cast)
12037 	{
12038 		// Trivial bitcast case, casts between integers.
12039 		return type_to_glsl(out_type);
12040 	}
12041 	else
12042 	{
12043 		// Fall back to the catch-all bitcast in MSL.
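		// E.g. a float -> uint bitcast is emitted as as_type<uint>(...).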
12044 		return "as_type<" + type_to_glsl(out_type) + ">";
12045 	}
12046 }
12047 
12048 bool CompilerMSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
12049 {
12050 	return false;
12051 }
12052 
12053 // Returns an MSL string identifying the name of a SPIR-V builtin.
12054 // Output builtins are qualified with the name of the stage out structure.
12055 string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
12056 {
12057 	switch (builtin)
12058 	{
12059 
12060 	// Handle HLSL-style 0-based vertex/instance index.
12061 	// Override GLSL compiler strictness
12062 	case BuiltInVertexId:
12063 		ensure_builtin(StorageClassInput, BuiltInVertexId);
12064 		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
12065 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
12066 		{
12067 			if (builtin_declaration)
12068 			{
12069 				if (needs_base_vertex_arg != TriState::No)
12070 					needs_base_vertex_arg = TriState::Yes;
12071 				return "gl_VertexID";
12072 			}
12073 			else
12074 			{
12075 				ensure_builtin(StorageClassInput, BuiltInBaseVertex);
12076 				return "(gl_VertexID - gl_BaseVertex)";
12077 			}
12078 		}
12079 		else
12080 		{
12081 			return "gl_VertexID";
12082 		}
12083 	case BuiltInInstanceId:
12084 		ensure_builtin(StorageClassInput, BuiltInInstanceId);
12085 		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
12086 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
12087 		{
12088 			if (builtin_declaration)
12089 			{
12090 				if (needs_base_instance_arg != TriState::No)
12091 					needs_base_instance_arg = TriState::Yes;
12092 				return "gl_InstanceID";
			}
			else
			{
				ensure_builtin(StorageClassInput, BuiltInBaseInstance);
				return "(gl_InstanceID - gl_BaseInstance)";
			}
		}
		else
		{
			return "gl_InstanceID";
		}
	case BuiltInVertexIndex:
		ensure_builtin(StorageClassInput, BuiltInVertexIndex);
		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			if (builtin_declaration)
			{
				if (needs_base_vertex_arg != TriState::No)
					needs_base_vertex_arg = TriState::Yes;
				return "gl_VertexIndex";
			}
			else
			{
				ensure_builtin(StorageClassInput, BuiltInBaseVertex);
				return "(gl_VertexIndex - gl_BaseVertex)";
			}
		}
		else
		{
			return "gl_VertexIndex";
		}
	case BuiltInInstanceIndex:
		ensure_builtin(StorageClassInput, BuiltInInstanceIndex);
		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			if (builtin_declaration)
			{
				if (needs_base_instance_arg != TriState::No)
					needs_base_instance_arg = TriState::Yes;
				return "gl_InstanceIndex";
			}
			else
			{
				ensure_builtin(StorageClassInput, BuiltInBaseInstance);
				return "(gl_InstanceIndex - gl_BaseInstance)";
			}
		}
		else
		{
			return "gl_InstanceIndex";
		}
	case BuiltInBaseVertex:
		if (msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			needs_base_vertex_arg = TriState::No;
			return "gl_BaseVertex";
		}
		else
		{
			SPIRV_CROSS_THROW("BaseVertex requires Metal 1.1 and Mac or Apple A9+ hardware.");
		}
	case BuiltInBaseInstance:
		if (msl_options.supports_msl_version(1, 1) &&
		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
		{
			needs_base_instance_arg = TriState::No;
			return "gl_BaseInstance";
		}
		else
		{
			SPIRV_CROSS_THROW("BaseInstance requires Metal 1.1 and Mac or Apple A9+ hardware.");
		}
	case BuiltInDrawIndex:
		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");

	// When used in the entry function, output builtins are qualified with the output struct name.
	// Test that the storage class is NOT Input, as output builtins might be part of a generic type.
	// Also don't do this for tessellation control shaders.
	case BuiltInViewportIndex:
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
		/* fallthrough */
	case BuiltInFragDepth:
	case BuiltInFragStencilRefEXT:
		if ((builtin == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
		    (builtin == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))
			break;
		/* fallthrough */
	case BuiltInPosition:
	case BuiltInPointSize:
	case BuiltInClipDistance:
	case BuiltInCullDistance:
	case BuiltInLayer:
	case BuiltInSampleMask:
		if (get_execution_model() == ExecutionModelTessellationControl)
			break;
		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
			return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);

		break;

	case BuiltInBaryCoordNV:
	case BuiltInBaryCoordNoPerspNV:
		if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
			return stage_in_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
		break;

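	// In tessellation control shaders, the tessellation levels are not part of the stage-out
	// struct; they map to the MSL tessellation factor buffer, indexed by the primitive ID.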
	case BuiltInTessLevelOuter:
		if (get_execution_model() == ExecutionModelTessellationEvaluation)
		{
			if (storage != StorageClassOutput && !get_entry_point().flags.get(ExecutionModeTriangles) &&
			    current_function && (current_function->self == ir.default_entry_point))
				return join(patch_stage_in_var_name, ".", CompilerGLSL::builtin_to_glsl(builtin, storage));
			else
				break;
		}
		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
			return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
			            "].edgeTessellationFactor");
		break;

	case BuiltInTessLevelInner:
		if (get_execution_model() == ExecutionModelTessellationEvaluation)
		{
			if (storage != StorageClassOutput && !get_entry_point().flags.get(ExecutionModeTriangles) &&
			    current_function && (current_function->self == ir.default_entry_point))
				return join(patch_stage_in_var_name, ".", CompilerGLSL::builtin_to_glsl(builtin, storage));
			else
				break;
		}
		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
			return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
			            "].insideTessellationFactor");
		break;

	default:
		break;
	}

	return CompilerGLSL::builtin_to_glsl(builtin, storage);
}

// Returns an MSL string attribute qualifier for a SPIR-V builtin
string CompilerMSL::builtin_qualifier(BuiltIn builtin)
{
	auto &execution = get_entry_point();

	switch (builtin)
	{
	// Vertex function in
	case BuiltInVertexId:
		return "vertex_id";
	case BuiltInVertexIndex:
		return "vertex_id";
	case BuiltInBaseVertex:
		return "base_vertex";
	case BuiltInInstanceId:
		return "instance_id";
	case BuiltInInstanceIndex:
		return "instance_id";
	case BuiltInBaseInstance:
		return "base_instance";
	case BuiltInDrawIndex:
		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");

	// Vertex function out
	case BuiltInClipDistance:
		return "clip_distance";
	case BuiltInPointSize:
		return "point_size";
	case BuiltInPosition:
		if (position_invariant)
		{
			if (!msl_options.supports_msl_version(2, 1))
				SPIRV_CROSS_THROW("Invariant position is only supported on MSL 2.1 and up.");
			return "position, invariant";
		}
		else
			return "position";
	case BuiltInLayer:
		return "render_target_array_index";
	case BuiltInViewportIndex:
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
		return "viewport_array_index";

	// Tess. control function in
	case BuiltInInvocationId:
		if (msl_options.multi_patch_workgroup)
		{
			// Shouldn't be reached.
			SPIRV_CROSS_THROW("InvocationId is computed manually with multi-patch workgroups in MSL.");
		}
		return "thread_index_in_threadgroup";
	case BuiltInPatchVertices:
		// Shouldn't be reached.
		SPIRV_CROSS_THROW("PatchVertices is derived from the auxiliary buffer in MSL.");
	case BuiltInPrimitiveId:
		switch (execution.model)
		{
		case ExecutionModelTessellationControl:
			if (msl_options.multi_patch_workgroup)
			{
				// Shouldn't be reached.
				SPIRV_CROSS_THROW("PrimitiveId is computed manually with multi-patch workgroups in MSL.");
			}
			return "threadgroup_position_in_grid";
		case ExecutionModelTessellationEvaluation:
			return "patch_id";
		case ExecutionModelFragment:
			if (msl_options.is_ios())
				SPIRV_CROSS_THROW("PrimitiveId is not supported in fragment on iOS.");
			else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2.");
			return "primitive_id";
		default:
			SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model.");
		}

	// Tess. control function out
	case BuiltInTessLevelOuter:
	case BuiltInTessLevelInner:
		// Shouldn't be reached.
		SPIRV_CROSS_THROW("Tessellation levels are handled specially in MSL.");

	// Tess. evaluation function in
	case BuiltInTessCoord:
		return "position_in_patch";

	// Fragment function in
	case BuiltInFrontFacing:
		return "front_facing";
	case BuiltInPointCoord:
		return "point_coord";
	case BuiltInFragCoord:
		return "position";
	case BuiltInSampleId:
		return "sample_id";
	case BuiltInSampleMask:
		return "sample_mask";
	case BuiltInSamplePosition:
		// Shouldn't be reached.
		SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL.");
	case BuiltInViewIndex:
		if (execution.model != ExecutionModelFragment)
			SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders.");
		// The ViewIndex was implicitly used in the prior stages to set the render_target_array_index,
		// so we can get it from there.
		return "render_target_array_index";

	// Fragment function out
	case BuiltInFragDepth:
		if (execution.flags.get(ExecutionModeDepthGreater))
			return "depth(greater)";
		else if (execution.flags.get(ExecutionModeDepthLess))
			return "depth(less)";
		else
			return "depth(any)";

	case BuiltInFragStencilRefEXT:
		return "stencil";

	// Compute function in
	case BuiltInGlobalInvocationId:
		return "thread_position_in_grid";

	case BuiltInWorkgroupId:
		return "threadgroup_position_in_grid";

	case BuiltInNumWorkgroups:
		return "threadgroups_per_grid";

	case BuiltInLocalInvocationId:
		return "thread_position_in_threadgroup";

	case BuiltInLocalInvocationIndex:
		return "thread_index_in_threadgroup";

	case BuiltInSubgroupSize:
		if (execution.model == ExecutionModelFragment)
		{
			if (!msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("threads_per_simdgroup requires Metal 2.2 in fragment shaders.");
			return "threads_per_simdgroup";
		}
		else
		{
			// thread_execution_width is an alias for threads_per_simdgroup; it has been available
			// since MSL 1.0, but it is not allowed in fragment functions.
			return "thread_execution_width";
		}

	case BuiltInNumSubgroups:
		if (!msl_options.supports_msl_version(2))
			SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
		return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";

	case BuiltInSubgroupId:
		if (!msl_options.supports_msl_version(2))
			SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
		return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";

	case BuiltInSubgroupLocalInvocationId:
		if (execution.model == ExecutionModelFragment)
		{
			if (!msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("thread_index_in_simdgroup requires Metal 2.2 in fragment shaders.");
			return "thread_index_in_simdgroup";
		}
		else
		{
			if (!msl_options.supports_msl_version(2))
				SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
			return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
		}

	case BuiltInSubgroupEqMask:
	case BuiltInSubgroupGeMask:
	case BuiltInSubgroupGtMask:
	case BuiltInSubgroupLeMask:
	case BuiltInSubgroupLtMask:
		// Shouldn't be reached.
		SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL.");

	case BuiltInBaryCoordNV:
		// TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
		if (msl_options.is_ios())
			SPIRV_CROSS_THROW("Barycentrics not supported on iOS.");
		else if (!msl_options.supports_msl_version(2, 2))
			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
		return "barycentric_coord, center_perspective";

	case BuiltInBaryCoordNoPerspNV:
		// TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
		if (msl_options.is_ios())
			SPIRV_CROSS_THROW("Barycentrics not supported on iOS.");
		else if (!msl_options.supports_msl_version(2, 2))
			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
		return "barycentric_coord, center_no_perspective";

	default:
		return "unsupported-built-in";
	}
}

// Returns an MSL string type declaration for a SPIR-V builtin
string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id)
{
	const SPIREntryPoint &execution = get_entry_point();
	switch (builtin)
	{
	// Vertex function in
	case BuiltInVertexId:
		return "uint";
	case BuiltInVertexIndex:
		return "uint";
	case BuiltInBaseVertex:
		return "uint";
	case BuiltInInstanceId:
		return "uint";
	case BuiltInInstanceIndex:
		return "uint";
	case BuiltInBaseInstance:
		return "uint";
	case BuiltInDrawIndex:
		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");

	// Vertex function out
	case BuiltInClipDistance:
		return "float";
	case BuiltInPointSize:
		return "float";
	case BuiltInPosition:
		return "float4";
	case BuiltInLayer:
		return "uint";
	case BuiltInViewportIndex:
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
		return "uint";

	// Tess. control function in
	case BuiltInInvocationId:
		return "uint";
	case BuiltInPatchVertices:
		return "uint";
	case BuiltInPrimitiveId:
		return "uint";

	// Tess. control function out
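	// Metal's tessellation factor buffer stores factors as half precision
	// (MTLTriangleTessellationFactorsHalf / MTLQuadTessellationFactorsHalf),
	// hence the "half" type for tessellation control outputs below.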
	case BuiltInTessLevelInner:
		if (execution.model == ExecutionModelTessellationEvaluation)
			return !execution.flags.get(ExecutionModeTriangles) ? "float2" : "float";
		return "half";
	case BuiltInTessLevelOuter:
		if (execution.model == ExecutionModelTessellationEvaluation)
			return !execution.flags.get(ExecutionModeTriangles) ? "float4" : "float";
		return "half";

	// Tess. evaluation function in
	case BuiltInTessCoord:
		return execution.flags.get(ExecutionModeTriangles) ? "float3" : "float2";

	// Fragment function in
	case BuiltInFrontFacing:
		return "bool";
	case BuiltInPointCoord:
		return "float2";
	case BuiltInFragCoord:
		return "float4";
	case BuiltInSampleId:
		return "uint";
	case BuiltInSampleMask:
		return "uint";
	case BuiltInSamplePosition:
		return "float2";
	case BuiltInViewIndex:
		return "uint";

	case BuiltInHelperInvocation:
		return "bool";

	case BuiltInBaryCoordNV:
	case BuiltInBaryCoordNoPerspNV:
		// Use the type as declared, can be 1, 2 or 3 components.
		return type_to_glsl(get_variable_data_type(get<SPIRVariable>(id)));

	// Fragment function out
	case BuiltInFragDepth:
		return "float";

	case BuiltInFragStencilRefEXT:
		return "uint";

	// Compute function in
	case BuiltInGlobalInvocationId:
	case BuiltInLocalInvocationId:
	case BuiltInNumWorkgroups:
	case BuiltInWorkgroupId:
		return "uint3";
	case BuiltInLocalInvocationIndex:
	case BuiltInNumSubgroups:
	case BuiltInSubgroupId:
	case BuiltInSubgroupSize:
	case BuiltInSubgroupLocalInvocationId:
		return "uint";
	case BuiltInSubgroupEqMask:
	case BuiltInSubgroupGeMask:
	case BuiltInSubgroupGtMask:
	case BuiltInSubgroupLeMask:
	case BuiltInSubgroupLtMask:
		return "uint4";

	case BuiltInDeviceIndex:
		return "int";

	default:
		return "unsupported-built-in-type";
	}
}

// Returns the declaration of a built-in argument to a function
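// For example, built_in_func_arg(BuiltInVertexIndex, true) yields ", uint gl_VertexIndex [[vertex_id]]".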
string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma)
{
	string bi_arg;
	if (prefix_comma)
		bi_arg += ", ";

	// Handle HLSL-style 0-based vertex/instance index.
	builtin_declaration = true;
	bi_arg += builtin_type_decl(builtin);
	bi_arg += " " + builtin_to_glsl(builtin, StorageClassInput);
	bi_arg += " [[" + builtin_qualifier(builtin) + "]]";
	builtin_declaration = false;

	return bi_arg;
}

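// Returns the effective (physical) type of a struct member, taking any
// physical type remapping performed by SPIRV-Cross into account.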
const SPIRType &CompilerMSL::get_physical_member_type(const SPIRType &type, uint32_t index) const
{
	if (member_is_remapped_physical_type(type, index))
		return get<SPIRType>(get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID));
	else
		return get<SPIRType>(type.member_types[index]);
}

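// Returns the presumed type of a stage input member, widened to the vector size
// specified by the API for that input location, if larger.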
SPIRType CompilerMSL::get_presumed_input_type(const SPIRType &ib_type, uint32_t index) const
{
	SPIRType type = get_physical_member_type(ib_type, index);
	uint32_t loc = get_member_decoration(ib_type.self, index, DecorationLocation);
	if (inputs_by_location.count(loc))
	{
		if (inputs_by_location.at(loc).vecsize > type.vecsize)
			type.vecsize = inputs_by_location.at(loc).vecsize;
	}
	return type;
}

uint32_t CompilerMSL::get_declared_type_array_stride_msl(const SPIRType &type, bool is_packed, bool row_major) const
{
	// Array stride in MSL is always size * array_size. sizeof(float3) == 16,
	// unlike GLSL and HLSL where array stride would be 16 and size 12.
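	// For example, an array of 4 float3 elements has an MSL array stride of 16 bytes
	// and a total declared size of 64 bytes.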

	// We could use parent type here and recurse, but that makes creating physical type remappings
	// far more complicated. We'd rather just create the final type, and ignore having to create the entire type
	// hierarchy in order to compute this value, so make a temporary type on the stack.

	auto basic_type = type;
	basic_type.array.clear();
	basic_type.array_size_literal.clear();
	uint32_t value_size = get_declared_type_size_msl(basic_type, is_packed, row_major);

	uint32_t dimensions = uint32_t(type.array.size());
	assert(dimensions > 0);
	dimensions--;

	// Multiply together every dimension, except the last one.
	for (uint32_t dim = 0; dim < dimensions; dim++)
	{
		uint32_t array_size = to_array_size_literal(type, dim);
		value_size *= max(array_size, 1u);
	}

	return value_size;
}

uint32_t CompilerMSL::get_declared_struct_member_array_stride_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_array_stride_msl(get_physical_member_type(type, index),
	                                          member_is_packed_physical_type(type, index),
	                                          has_member_decoration(type.self, index, DecorationRowMajor));
}

uint32_t CompilerMSL::get_declared_input_array_stride_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_array_stride_msl(get_presumed_input_type(type, index), false,
	                                          has_member_decoration(type.self, index, DecorationRowMajor));
}

uint32_t CompilerMSL::get_declared_type_matrix_stride_msl(const SPIRType &type, bool packed, bool row_major) const
{
	// For packed matrices, we just use the size of the vector type.
	// Otherwise, MatrixStride == alignment, which is the size of the underlying vector type.
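	// For example, a column of an unpacked float3x3 occupies 16 bytes (stride 16),
	// while a packed float3x3 column occupies only 12.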
	if (packed)
		return (type.width / 8) * ((row_major && type.columns > 1) ? type.columns : type.vecsize);
	else
		return get_declared_type_alignment_msl(type, false, row_major);
}

uint32_t CompilerMSL::get_declared_struct_member_matrix_stride_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_matrix_stride_msl(get_physical_member_type(type, index),
	                                           member_is_packed_physical_type(type, index),
	                                           has_member_decoration(type.self, index, DecorationRowMajor));
}

uint32_t CompilerMSL::get_declared_input_matrix_stride_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_matrix_stride_msl(get_presumed_input_type(type, index), false,
	                                           has_member_decoration(type.self, index, DecorationRowMajor));
}

uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type, bool ignore_alignment,
                                                   bool ignore_padding) const
{
	// If we have a target size, that is the declared size as well.
	if (!ignore_padding && has_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget))
		return get_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget);

	if (struct_type.member_types.empty())
		return 0;

	uint32_t mbr_cnt = uint32_t(struct_type.member_types.size());

	// In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
	uint32_t alignment = 1;

	if (!ignore_alignment)
	{
		for (uint32_t i = 0; i < mbr_cnt; i++)
		{
			uint32_t mbr_alignment = get_declared_struct_member_alignment_msl(struct_type, i);
			alignment = max(alignment, mbr_alignment);
		}
	}

	// The last member is always matched to the final Offset decoration, but the size of the struct
	// in MSL depends on the physical size of its members, and the total is then rounded up to the
	// struct alignment.
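	// For example, if the last member sits at offset 20 with a declared size of 8 and the
	// maximum member alignment is 16, the raw size of 28 rounds up to 32.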
	uint32_t spirv_offset = type_struct_member_offset(struct_type, mbr_cnt - 1);
	uint32_t msl_size = spirv_offset + get_declared_struct_member_size_msl(struct_type, mbr_cnt - 1);
	msl_size = (msl_size + alignment - 1) & ~(alignment - 1);
	return msl_size;
}

// Returns the byte size of a type as it is declared in MSL.
uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
{
	switch (type.basetype)
	{
	case SPIRType::Unknown:
	case SPIRType::Void:
	case SPIRType::AtomicCounter:
	case SPIRType::Image:
	case SPIRType::SampledImage:
	case SPIRType::Sampler:
		SPIRV_CROSS_THROW("Querying size of opaque object.");

	default:
	{
		if (!type.array.empty())
		{
			uint32_t array_size = to_array_size_literal(type);
			return get_declared_type_array_stride_msl(type, is_packed, row_major) * max(array_size, 1u);
		}

		if (type.basetype == SPIRType::Struct)
			return get_declared_struct_size_msl(type);

		if (is_packed)
		{
			return type.vecsize * type.columns * (type.width / 8);
		}
		else
		{
			// An unpacked 3-element vector or matrix column is the same memory size as a 4-element one.
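			// For example, float3 occupies 16 bytes here, whereas packed_float3 above
			// would take the packed path and occupy 12.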
			uint32_t vecsize = type.vecsize;
			uint32_t columns = type.columns;

			if (row_major && columns > 1)
				swap(vecsize, columns);

			if (vecsize == 3)
				vecsize = 4;

			return vecsize * columns * (type.width / 8);
		}
	}
	}
}

uint32_t CompilerMSL::get_declared_struct_member_size_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_size_msl(get_physical_member_type(type, index),
	                                  member_is_packed_physical_type(type, index),
	                                  has_member_decoration(type.self, index, DecorationRowMajor));
}

uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_size_msl(get_presumed_input_type(type, index), false,
	                                  has_member_decoration(type.self, index, DecorationRowMajor));
}

// Returns the byte alignment of a type.
uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const
{
	switch (type.basetype)
	{
	case SPIRType::Unknown:
	case SPIRType::Void:
	case SPIRType::AtomicCounter:
	case SPIRType::Image:
	case SPIRType::SampledImage:
	case SPIRType::Sampler:
		SPIRV_CROSS_THROW("Querying alignment of opaque object.");

	case SPIRType::Int64:
		SPIRV_CROSS_THROW("long types are not supported in buffers in MSL.");
	case SPIRType::UInt64:
		SPIRV_CROSS_THROW("ulong types are not supported in buffers in MSL.");
	case SPIRType::Double:
		SPIRV_CROSS_THROW("double types are not supported in buffers in MSL.");

	case SPIRType::Struct:
	{
		// In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
		uint32_t alignment = 1;
		for (uint32_t i = 0; i < type.member_types.size(); i++)
			alignment = max(alignment, uint32_t(get_declared_struct_member_alignment_msl(type, i)));
		return alignment;
	}

	default:
	{
		// Alignment of packed type is the same as the underlying component or column size.
		// Alignment of unpacked type is the same as the vector size.
		// Alignment of 3-elements vector is the same as 4-elements (including packed using column).
		if (is_packed)
		{
			// If we have packed_T and friends, the alignment is always scalar.
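			// For example, packed_float3 aligns to 4 bytes, whereas unpacked float3 aligns to 16.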
			return type.width / 8;
		}
		else
		{
			// This is the general rule for MSL. Size == alignment.
			uint32_t vecsize = (row_major && type.columns > 1) ? type.columns : type.vecsize;
			return (type.width / 8) * (vecsize == 3 ? 4 : vecsize);
		}
	}
	}
}

uint32_t CompilerMSL::get_declared_struct_member_alignment_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_alignment_msl(get_physical_member_type(type, index),
	                                       member_is_packed_physical_type(type, index),
	                                       has_member_decoration(type.self, index, DecorationRowMajor));
}

uint32_t CompilerMSL::get_declared_input_alignment_msl(const SPIRType &type, uint32_t index) const
{
	return get_declared_type_alignment_msl(get_presumed_input_type(type, index), false,
	                                       has_member_decoration(type.self, index, DecorationRowMajor));
}

bool CompilerMSL::skip_argument(uint32_t) const
{
	return false;
}

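// When texture swizzling is enabled, scan the shader for sampled image usage
// so the necessary swizzle plumbing can be emitted.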
void CompilerMSL::analyze_sampled_image_usage()
{
	if (msl_options.swizzle_texture_samples)
	{
		SampledImageScanner scanner(*this);
		traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), scanner);
	}
}

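// Marks usage of sampled images so that a swizzle buffer definition is emitted when needed.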
bool CompilerMSL::SampledImageScanner::handle(spv::Op opcode, const uint32_t *args, uint32_t length)
{
	switch (opcode)
	{
	case OpLoad:
	case OpImage:
	case OpSampledImage:
	{
		if (length < 3)
			return false;

		uint32_t result_type = args[0];
		auto &type = compiler.get<SPIRType>(result_type);
		if ((type.basetype != SPIRType::Image && type.basetype != SPIRType::SampledImage) || type.image.sampled != 1)
			return true;

		uint32_t id = args[1];
		compiler.set<SPIRExpression>(id, "", result_type, true);
		break;
	}
	case OpImageSampleExplicitLod:
	case OpImageSampleProjExplicitLod:
	case OpImageSampleDrefExplicitLod:
	case OpImageSampleProjDrefExplicitLod:
	case OpImageSampleImplicitLod:
	case OpImageSampleProjImplicitLod:
	case OpImageSampleDrefImplicitLod:
	case OpImageSampleProjDrefImplicitLod:
	case OpImageFetch:
	case OpImageGather:
	case OpImageDrefGather:
		compiler.has_sampled_images =
		    compiler.has_sampled_images || compiler.is_sampled_image_type(compiler.expression_type(args[2]));
		compiler.needs_swizzle_buffer_def = compiler.needs_swizzle_buffer_def || compiler.has_sampled_images;
		break;
	default:
		break;
	}
	return true;
}

// If a needed custom function wasn't added before, add it and force a recompile.
void CompilerMSL::add_spv_func_and_recompile(SPVFuncImpl spv_func)
{
	if (spv_function_implementations.count(spv_func) == 0)
	{
		spv_function_implementations.insert(spv_func);
		suppress_missing_prototypes = true;
		force_recompile();
	}
}

bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, uint32_t length)
{
	// Since MSL exists in a single execution scope, function prototype declarations are not
	// needed, and clutter the output. If secondary functions are output (either as a SPIR-V
	// function implementation or as indicated by the presence of OpFunctionCall), then set
	// suppress_missing_prototypes to suppress compiler warnings of missing function prototypes.

	// Mark if the input requires the implementation of a SPIR-V function that does not exist in Metal.
	SPVFuncImpl spv_func = get_spv_func_impl(opcode, args);
	if (spv_func != SPVFuncImplNone)
	{
		compiler.spv_function_implementations.insert(spv_func);
		suppress_missing_prototypes = true;
	}

	switch (opcode)
	{

	case OpFunctionCall:
		suppress_missing_prototypes = true;
		break;

	// Emulate texture2D atomic operations
	case OpImageTexelPointer:
	{
		auto *var = compiler.maybe_get_backing_variable(args[2]);
		image_pointers[args[1]] = var ? var->self : ID(0);
		break;
	}

	case OpImageWrite:
		uses_resource_write = true;
		break;

	case OpStore:
		check_resource_write(args[0]);
		break;

	// Emulate texture2D atomic operations
	case OpAtomicExchange:
	case OpAtomicCompareExchange:
	case OpAtomicCompareExchangeWeak:
	case OpAtomicIIncrement:
	case OpAtomicIDecrement:
	case OpAtomicIAdd:
	case OpAtomicISub:
	case OpAtomicSMin:
	case OpAtomicUMin:
	case OpAtomicSMax:
	case OpAtomicUMax:
	case OpAtomicAnd:
	case OpAtomicOr:
	case OpAtomicXor:
	{
		uses_atomics = true;
		auto it = image_pointers.find(args[2]);
		if (it != image_pointers.end())
		{
			compiler.atomic_image_vars.insert(it->second);
		}
		check_resource_write(args[2]);
		break;
	}

	case OpAtomicStore:
	{
		uses_atomics = true;
		auto it = image_pointers.find(args[0]);
		if (it != image_pointers.end())
		{
			compiler.atomic_image_vars.insert(it->second);
		}
		check_resource_write(args[0]);
		break;
	}

	case OpAtomicLoad:
	{
		uses_atomics = true;
		auto it = image_pointers.find(args[2]);
		if (it != image_pointers.end())
		{
			compiler.atomic_image_vars.insert(it->second);
		}
		break;
	}

	case OpGroupNonUniformInverseBallot:
		needs_subgroup_invocation_id = true;
		break;

	case OpGroupNonUniformBallotBitCount:
		if (args[3] != GroupOperationReduce)
			needs_subgroup_invocation_id = true;
		break;

	case OpArrayLength:
	{
		auto *var = compiler.maybe_get_backing_variable(args[2]);
		if (var)
			compiler.buffers_requiring_array_length.insert(var->self);
		break;
	}

	case OpInBoundsAccessChain:
	case OpAccessChain:
	case OpPtrAccessChain:
	{
		// OpArrayLength might want to know if taking ArrayLength of an array of SSBOs.
		uint32_t result_type = args[0];
		uint32_t id = args[1];
		uint32_t ptr = args[2];

		compiler.set<SPIRExpression>(id, "", result_type, true);
		compiler.register_read(id, ptr, true);
		compiler.ir.ids[id].set_allow_type_rewrite();
		break;
	}

	default:
		break;
	}

	// If it has one, keep track of the instruction's result type, mapped by ID
	uint32_t result_type, result_id;
	if (compiler.instruction_to_result_type(result_type, result_id, opcode, args, length))
		result_types[result_id] = result_type;

	return true;
}

// If the variable is a Uniform or StorageBuffer, mark that a resource has been written to.
void CompilerMSL::OpCodePreprocessor::check_resource_write(uint32_t var_id)
{
	auto *p_var = compiler.maybe_get_backing_variable(var_id);
	StorageClass sc = p_var ? p_var->storage : StorageClassMax;
	if (sc == StorageClassUniform || sc == StorageClassStorageBuffer)
		uses_resource_write = true;
}

// Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes.
CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op opcode, const uint32_t *args)
{
	switch (opcode)
	{
	case OpFMod:
		return SPVFuncImplMod;

	case OpFAdd:
		if (compiler.msl_options.invariant_float_math)
		{
			return SPVFuncImplFAdd;
		}
		break;

	case OpFMul:
	case OpOuterProduct:
	case OpMatrixTimesVector:
	case OpVectorTimesMatrix:
	case OpMatrixTimesMatrix:
		if (compiler.msl_options.invariant_float_math)
		{
			return SPVFuncImplFMul;
		}
		break;

	case OpTypeArray:
	{
		// Allow Metal to use the array<T> template to make arrays a value type
		return SPVFuncImplUnsafeArray;
	}

	// Emulate texture2D atomic operations
	case OpAtomicExchange:
	case OpAtomicCompareExchange:
	case OpAtomicCompareExchangeWeak:
	case OpAtomicIIncrement:
	case OpAtomicIDecrement:
	case OpAtomicIAdd:
	case OpAtomicISub:
	case OpAtomicSMin:
	case OpAtomicUMin:
	case OpAtomicSMax:
	case OpAtomicUMax:
	case OpAtomicAnd:
	case OpAtomicOr:
	case OpAtomicXor:
	case OpAtomicLoad:
	case OpAtomicStore:
	{
		auto it = image_pointers.find(args[opcode == OpAtomicStore ? 0 : 2]);
		if (it != image_pointers.end())
		{
			uint32_t tid = compiler.get<SPIRVariable>(it->second).basetype;
			if (tid && compiler.get<SPIRType>(tid).image.dim == Dim2D)
				return SPVFuncImplImage2DAtomicCoords;
		}
		break;
	}

	case OpImageFetch:
	case OpImageRead:
	case OpImageWrite:
	{
		// Retrieve the image type, and if it's a Buffer, emit a texel coordinate function
		uint32_t tid = result_types[args[opcode == OpImageWrite ? 0 : 2]];
		if (tid && compiler.get<SPIRType>(tid).image.dim == DimBuffer && !compiler.msl_options.texture_buffer_native)
			return SPVFuncImplTexelBufferCoords;
		break;
	}

	case OpExtInst:
	{
		uint32_t extension_set = args[2];
		if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
		{
			auto op_450 = static_cast<GLSLstd450>(args[3]);
			switch (op_450)
			{
			case GLSLstd450Radians:
				return SPVFuncImplRadians;
			case GLSLstd450Degrees:
				return SPVFuncImplDegrees;
			case GLSLstd450FindILsb:
				return SPVFuncImplFindILsb;
			case GLSLstd450FindSMsb:
				return SPVFuncImplFindSMsb;
			case GLSLstd450FindUMsb:
				return SPVFuncImplFindUMsb;
			case GLSLstd450SSign:
				return SPVFuncImplSSign;
			case GLSLstd450Reflect:
			{
				auto &type = compiler.get<SPIRType>(args[0]);
				if (type.vecsize == 1)
					return SPVFuncImplReflectScalar;
				break;
			}
			case GLSLstd450Refract:
			{
				auto &type = compiler.get<SPIRType>(args[0]);
				if (type.vecsize == 1)
					return SPVFuncImplRefractScalar;
				break;
			}
			case GLSLstd450FaceForward:
			{
				auto &type = compiler.get<SPIRType>(args[0]);
				if (type.vecsize == 1)
					return SPVFuncImplFaceForwardScalar;
				break;
			}
			case GLSLstd450MatrixInverse:
			{
				auto &mat_type = compiler.get<SPIRType>(args[0]);
				switch (mat_type.columns)
				{
				case 2:
					return SPVFuncImplInverse2x2;
				case 3:
					return SPVFuncImplInverse3x3;
				case 4:
					return SPVFuncImplInverse4x4;
				default:
					break;
				}
				break;
			}
			default:
				break;
			}
		}
		break;
	}

	case OpGroupNonUniformBallot:
		return SPVFuncImplSubgroupBallot;

	case OpGroupNonUniformInverseBallot:
	case OpGroupNonUniformBallotBitExtract:
		return SPVFuncImplSubgroupBallotBitExtract;

	case OpGroupNonUniformBallotFindLSB:
		return SPVFuncImplSubgroupBallotFindLSB;

	case OpGroupNonUniformBallotFindMSB:
		return SPVFuncImplSubgroupBallotFindMSB;

	case OpGroupNonUniformBallotBitCount:
		return SPVFuncImplSubgroupBallotBitCount;

	case OpGroupNonUniformAllEqual:
		return SPVFuncImplSubgroupAllEqual;

	default:
		break;
	}
	return SPVFuncImplNone;
}

// Sort both type and meta member content based on builtin status (put builtins at end),
// then by the required sorting aspect.
void CompilerMSL::MemberSorter::sort()
{
	// Create a temporary array of consecutive member indices and sort it based on how
	// the members should be reordered, based on builtin and sorting aspect meta info.
	size_t mbr_cnt = type.member_types.size();
	SmallVector<uint32_t> mbr_idxs(mbr_cnt);
	std::iota(mbr_idxs.begin(), mbr_idxs.end(), 0); // Fill with consecutive indices
	std::stable_sort(mbr_idxs.begin(), mbr_idxs.end(), *this); // Sort member indices based on sorting aspect

	bool sort_is_identity = true;
	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
	{
		if (mbr_idx != mbr_idxs[mbr_idx])
		{
			sort_is_identity = false;
			break;
		}
	}

	if (sort_is_identity)
		return;

	if (meta.members.size() < type.member_types.size())
	{
		// This should never trigger in normal circumstances, but to be safe.
		meta.members.resize(type.member_types.size());
	}

	// Move type and meta member info to the order defined by the sorted member indices.
	// This is done by creating temporary copies of both member types and meta, and then
	// copying back to the original content at the sorted indices.
	auto mbr_types_cpy = type.member_types;
	auto mbr_meta_cpy = meta.members;
	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
	{
		type.member_types[mbr_idx] = mbr_types_cpy[mbr_idxs[mbr_idx]];
		meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]];
	}

	if (sort_aspect == SortAspect::Offset)
	{
		// If we're sorting by Offset, this might affect user code which accesses a buffer block.
		// We will need to redirect member indices from one index to the sorted index.
		type.member_type_index_redirection = std::move(mbr_idxs);
	}
}

// Sort first by builtin status (put builtins at end), then by the sorting aspect.
bool CompilerMSL::MemberSorter::operator()(uint32_t mbr_idx1, uint32_t mbr_idx2)
{
	auto &mbr_meta1 = meta.members[mbr_idx1];
	auto &mbr_meta2 = meta.members[mbr_idx2];
	if (mbr_meta1.builtin != mbr_meta2.builtin)
		return mbr_meta2.builtin;
	else
		switch (sort_aspect)
		{
		case Location:
			return mbr_meta1.location < mbr_meta2.location;
		case LocationReverse:
			return mbr_meta1.location > mbr_meta2.location;
		case Offset:
			return mbr_meta1.offset < mbr_meta2.offset;
		case OffsetThenLocationReverse:
			return (mbr_meta1.offset < mbr_meta2.offset) ||
			       ((mbr_meta1.offset == mbr_meta2.offset) && (mbr_meta1.location > mbr_meta2.location));
		case Alphabetical:
			return mbr_meta1.alias < mbr_meta2.alias;
		default:
			return false;
		}
}

CompilerMSL::MemberSorter::MemberSorter(SPIRType &t, Meta &m, SortAspect sa)
    : type(t)
    , meta(m)
    , sort_aspect(sa)
{
	// Ensure enough meta info is available
	meta.members.resize(max(type.member_types.size(), meta.members.size()));
}

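// Remaps the sampler or combined image sampler at the given ID to a constexpr sampler,
// which is declared directly in the shader source rather than passed in as a resource.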
void CompilerMSL::remap_constexpr_sampler(VariableID id, const MSLConstexprSampler &sampler)
{
	auto &type = get<SPIRType>(get<SPIRVariable>(id).basetype);
	if (type.basetype != SPIRType::SampledImage && type.basetype != SPIRType::Sampler)
		SPIRV_CROSS_THROW("Can only remap SampledImage and Sampler type.");
	if (!type.array.empty())
		SPIRV_CROSS_THROW("Can not remap array of samplers.");
	constexpr_samplers_by_id[id] = sampler;
}

void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t binding,
                                                     const MSLConstexprSampler &sampler)
{
	constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
}

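// When loading from a builtin whose physical type in MSL differs from the SPIR-V declared
// type (e.g. uint vs. int), bitcast the expression back to the expected type.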
void CompilerMSL::bitcast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
{
	auto *var = maybe_get_backing_variable(source_id);
	if (var)
		source_id = var->self;

	// Only interested in standalone builtin variables.
	if (!has_decoration(source_id, DecorationBuiltIn))
		return;

	auto builtin = static_cast<BuiltIn>(get_decoration(source_id, DecorationBuiltIn));
	auto expected_type = expr_type.basetype;
	switch (builtin)
	{
	case BuiltInGlobalInvocationId:
	case BuiltInLocalInvocationId:
	case BuiltInWorkgroupId:
	case BuiltInLocalInvocationIndex:
	case BuiltInWorkgroupSize:
	case BuiltInNumWorkgroups:
	case BuiltInLayer:
	case BuiltInViewportIndex:
	case BuiltInFragStencilRefEXT:
	case BuiltInPrimitiveId:
	case BuiltInSubgroupSize:
	case BuiltInSubgroupLocalInvocationId:
	case BuiltInViewIndex:
	case BuiltInVertexIndex:
	case BuiltInInstanceIndex:
	case BuiltInBaseInstance:
	case BuiltInBaseVertex:
		expected_type = SPIRType::UInt;
		break;

	case BuiltInTessLevelInner:
	case BuiltInTessLevelOuter:
		if (get_execution_model() == ExecutionModelTessellationControl)
			expected_type = SPIRType::Half;
		break;

	default:
		break;
	}

	if (expected_type != expr_type.basetype)
		expr = bitcast_expression(expr_type, expected_type, expr);

	if (builtin == BuiltInTessCoord && get_entry_point().flags.get(ExecutionModeQuads) && expr_type.vecsize == 3)
	{
		// In SPIR-V, this is always a vec3, even for quads. In Metal, though, it's a float2 for quads.
		// The code is expecting a float3, so we need to widen this.
		expr = join("float3(", expr, ", 0)");
	}
}

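// Mirror of bitcast_from_builtin_load: converts the stored expression to the physical type
// the MSL builtin expects before writing it.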
void CompilerMSL::bitcast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
{
	auto *var = maybe_get_backing_variable(target_id);
	if (var)
		target_id = var->self;

	// Only interested in standalone builtin variables.
	if (!has_decoration(target_id, DecorationBuiltIn))
		return;

	auto builtin = static_cast<BuiltIn>(get_decoration(target_id, DecorationBuiltIn));
	auto expected_type = expr_type.basetype;
	switch (builtin)
	{
	case BuiltInLayer:
	case BuiltInViewportIndex:
	case BuiltInFragStencilRefEXT:
	case BuiltInPrimitiveId:
	case BuiltInViewIndex:
		expected_type = SPIRType::UInt;
		break;

	case BuiltInTessLevelInner:
	case BuiltInTessLevelOuter:
		expected_type = SPIRType::Half;
		break;

	default:
		break;
	}

	if (expected_type != expr_type.basetype)
	{
		if (expected_type == SPIRType::Half && expr_type.basetype == SPIRType::Float)
		{
			// These are of different widths, so we cannot do a straight bitcast.
			expr = join("half(", expr, ")");
		}
		else
		{
			auto type = expr_type;
			type.basetype = expected_type;
			expr = bitcast_expression(type, expr_type.basetype, expr);
		}
	}
}

string CompilerMSL::to_initializer_expression(const SPIRVariable &var)
{
	// With MSL, we risk getting an array initializer here if we have an array.
	// FIXME: We cannot handle non-constant arrays being initialized.
	// We will need to inject spvArrayCopy here somehow ...
	auto &type = get<SPIRType>(var.basetype);
	string expr;
	if (ir.ids[var.initializer].get_type() == TypeConstant &&
	    (!type.array.empty() || type.basetype == SPIRType::Struct))
		expr = constant_expression(get<SPIRConstant>(var.initializer));
	else
		expr = CompilerGLSL::to_initializer_expression(var);
	// If the initializer has more vector components than the variable, add a swizzle.
	// FIXME: This can't handle arrays or structs.
	auto &init_type = expression_type(var.initializer);
	if (type.array.empty() && type.basetype != SPIRType::Struct && init_type.vecsize > type.vecsize)
		expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0));
	return expr;
}

string CompilerMSL::to_zero_initialized_expression(uint32_t)
{
	return "{}";
}

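// Returns true if the given descriptor set will be emitted as a Metal argument buffer.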
descriptor_set_is_argument_buffer(uint32_t desc_set) const13404 bool CompilerMSL::descriptor_set_is_argument_buffer(uint32_t desc_set) const
13405 {
13406 	if (!msl_options.argument_buffers)
13407 		return false;
13408 	if (desc_set >= kMaxArgumentBuffers)
13409 		return false;
13410 
13411 	return (argument_buffer_discrete_mask & (1u << desc_set)) == 0;
13412 }
13413 
analyze_argument_buffers()13414 void CompilerMSL::analyze_argument_buffers()
13415 {
13416 	// Gather all used resources and sort them out into argument buffers.
13417 	// Each argument buffer corresponds to a descriptor set in SPIR-V.
13418 	// The [[id(N)]] values used correspond to the resource mapping we have for MSL.
13419 	// Otherwise, the binding number is used, but this is generally not safe some types like
13420 	// combined image samplers and arrays of resources. Metal needs different indices here,
13421 	// while SPIR-V can have one descriptor set binding. To use argument buffers in practice,
13422 	// you will need to use the remapping from the API.
13423 	for (auto &id : argument_buffer_ids)
13424 		id = 0;
13425 
13426 	// Output resources, sorted by resource index & type.
13427 	struct Resource
13428 	{
13429 		SPIRVariable *var;
13430 		string name;
13431 		SPIRType::BaseType basetype;
13432 		uint32_t index;
13433 		uint32_t plane;
13434 	};
13435 	SmallVector<Resource> resources_in_set[kMaxArgumentBuffers];
13436 	SmallVector<uint32_t> inline_block_vars;
13437 
13438 	bool set_needs_swizzle_buffer[kMaxArgumentBuffers] = {};
13439 	bool set_needs_buffer_sizes[kMaxArgumentBuffers] = {};
13440 	bool needs_buffer_sizes = false;
13441 
13442 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &var) {
13443 		if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
13444 		     var.storage == StorageClassStorageBuffer) &&
13445 		    !is_hidden_variable(var))
13446 		{
13447 			uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
13448 			// Ignore if it's part of a push descriptor set.
13449 			if (!descriptor_set_is_argument_buffer(desc_set))
13450 				return;
13451 
13452 			uint32_t var_id = var.self;
13453 			auto &type = get_variable_data_type(var);
13454 
13455 			if (desc_set >= kMaxArgumentBuffers)
13456 				SPIRV_CROSS_THROW("Descriptor set index is out of range.");
13457 
13458 			const MSLConstexprSampler *constexpr_sampler = nullptr;
13459 			if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
13460 			{
13461 				constexpr_sampler = find_constexpr_sampler(var_id);
13462 				if (constexpr_sampler)
13463 				{
13464 					// Mark this ID as a constexpr sampler for later in case it came from set/bindings.
13465 					constexpr_samplers_by_id[var_id] = *constexpr_sampler;
13466 				}
13467 			}
13468 
13469 			uint32_t binding = get_decoration(var_id, DecorationBinding);
13470 			if (type.basetype == SPIRType::SampledImage)
13471 			{
13472 				add_resource_name(var_id);
13473 
13474 				uint32_t plane_count = 1;
13475 				if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
13476 					plane_count = constexpr_sampler->planes;
13477 
13478 				for (uint32_t i = 0; i < plane_count; i++)
13479 				{
13480 					uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image, i);
13481 					resources_in_set[desc_set].push_back(
13482 					    { &var, to_name(var_id), SPIRType::Image, image_resource_index, i });
13483 				}
13484 
13485 				if (type.image.dim != DimBuffer && !constexpr_sampler)
13486 				{
13487 					uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler);
13488 					resources_in_set[desc_set].push_back(
13489 					    { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 0 });
13490 				}
13491 			}
13492 			else if (inline_uniform_blocks.count(SetBindingPair{ desc_set, binding }))
13493 			{
13494 				inline_block_vars.push_back(var_id);
13495 			}
13496 			else if (!constexpr_sampler)
13497 			{
13498 				// constexpr samplers are not declared as resources.
13499 				// Inline uniform blocks are always emitted at the end.
13500 				if (!msl_options.is_ios() || type.basetype != SPIRType::Image || type.image.sampled != 2)
13501 				{
13502 					add_resource_name(var_id);
13503 					resources_in_set[desc_set].push_back(
13504 					    { &var, to_name(var_id), type.basetype, get_metal_resource_index(var, type.basetype), 0 });
13505 				}
13506 			}
13507 
13508 			// Check if this descriptor set needs a swizzle buffer.
13509 			if (needs_swizzle_buffer_def && is_sampled_image_type(type))
13510 				set_needs_swizzle_buffer[desc_set] = true;
13511 			else if (buffers_requiring_array_length.count(var_id) != 0)
13512 			{
13513 				set_needs_buffer_sizes[desc_set] = true;
13514 				needs_buffer_sizes = true;
13515 			}
13516 		}
13517 	});
13518 
13519 	if (needs_swizzle_buffer_def || needs_buffer_sizes)
13520 	{
13521 		uint32_t uint_ptr_type_id = 0;
13522 
13523 		// We might have to add a swizzle buffer resource to the set.
13524 		for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
13525 		{
13526 			if (!set_needs_swizzle_buffer[desc_set] && !set_needs_buffer_sizes[desc_set])
13527 				continue;
13528 
13529 			if (uint_ptr_type_id == 0)
13530 			{
13531 				uint_ptr_type_id = ir.increase_bound_by(1);
13532 
13533 				// Create a buffer to hold extra data, including the swizzle constants.
13534 				SPIRType uint_type_pointer = get_uint_type();
13535 				uint_type_pointer.pointer = true;
13536 				uint_type_pointer.pointer_depth = 1;
13537 				uint_type_pointer.parent_type = get_uint_type_id();
13538 				uint_type_pointer.storage = StorageClassUniform;
13539 				set<SPIRType>(uint_ptr_type_id, uint_type_pointer);
13540 				set_decoration(uint_ptr_type_id, DecorationArrayStride, 4);
13541 			}
13542 
13543 			if (set_needs_swizzle_buffer[desc_set])
13544 			{
13545 				uint32_t var_id = ir.increase_bound_by(1);
13546 				auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
13547 				set_name(var_id, "spvSwizzleConstants");
13548 				set_decoration(var_id, DecorationDescriptorSet, desc_set);
13549 				set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding);
13550 				resources_in_set[desc_set].push_back(
13551 				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
13552 			}
13553 
13554 			if (set_needs_buffer_sizes[desc_set])
13555 			{
13556 				uint32_t var_id = ir.increase_bound_by(1);
13557 				auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
13558 				set_name(var_id, "spvBufferSizeConstants");
13559 				set_decoration(var_id, DecorationDescriptorSet, desc_set);
13560 				set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding);
13561 				resources_in_set[desc_set].push_back(
13562 				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
13563 			}
13564 		}
13565 	}
13566 
13567 	// Now add inline uniform blocks.
13568 	for (uint32_t var_id : inline_block_vars)
13569 	{
13570 		auto &var = get<SPIRVariable>(var_id);
13571 		uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
13572 		add_resource_name(var_id);
13573 		resources_in_set[desc_set].push_back(
13574 		    { &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 0 });
13575 	}
13576 
13577 	for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
13578 	{
13579 		auto &resources = resources_in_set[desc_set];
13580 		if (resources.empty())
13581 			continue;
13582 
13583 		assert(descriptor_set_is_argument_buffer(desc_set));
13584 
13585 		uint32_t next_id = ir.increase_bound_by(3);
13586 		uint32_t type_id = next_id + 1;
13587 		uint32_t ptr_type_id = next_id + 2;
13588 		argument_buffer_ids[desc_set] = next_id;
13589 
13590 		auto &buffer_type = set<SPIRType>(type_id);
13591 
13592 		buffer_type.basetype = SPIRType::Struct;
13593 
13594 		if ((argument_buffer_device_storage_mask & (1u << desc_set)) != 0)
13595 		{
13596 			buffer_type.storage = StorageClassStorageBuffer;
13597 			// Make sure the argument buffer gets marked as const device.
13598 			set_decoration(next_id, DecorationNonWritable);
13599 			// Need to mark the type as a Block to enable this.
13600 			set_decoration(type_id, DecorationBlock);
13601 		}
13602 		else
13603 			buffer_type.storage = StorageClassUniform;
13604 
13605 		set_name(type_id, join("spvDescriptorSetBuffer", desc_set));
13606 
13607 		auto &ptr_type = set<SPIRType>(ptr_type_id);
13608 		ptr_type = buffer_type;
13609 		ptr_type.pointer = true;
13610 		ptr_type.pointer_depth = 1;
13611 		ptr_type.parent_type = type_id;
13612 
13613 		uint32_t buffer_variable_id = next_id;
13614 		set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
13615 		set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));
13616 
13617 		// Ids must be emitted in ID order.
13618 		sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
13619 			return tie(lhs.index, lhs.basetype) < tie(rhs.index, rhs.basetype);
13620 		});
13621 
13622 		uint32_t member_index = 0;
13623 		for (auto &resource : resources)
13624 		{
13625 			auto &var = *resource.var;
13626 			auto &type = get_variable_data_type(var);
13627 			string mbr_name = ensure_valid_name(resource.name, "m");
13628 			if (resource.plane > 0)
13629 				mbr_name += join(plane_name_suffix, resource.plane);
13630 			set_member_name(buffer_type.self, member_index, mbr_name);
13631 
13632 			if (resource.basetype == SPIRType::Sampler && type.basetype != SPIRType::Sampler)
13633 			{
13634 				// Have to synthesize a sampler type here.

				bool type_is_array = !type.array.empty();
				uint32_t sampler_type_id = ir.increase_bound_by(type_is_array ? 2 : 1);
				auto &new_sampler_type = set<SPIRType>(sampler_type_id);
				new_sampler_type.basetype = SPIRType::Sampler;
				new_sampler_type.storage = StorageClassUniformConstant;

				if (type_is_array)
				{
					uint32_t sampler_type_array_id = sampler_type_id + 1;
					auto &sampler_type_array = set<SPIRType>(sampler_type_array_id);
					sampler_type_array = new_sampler_type;
					sampler_type_array.array = type.array;
					sampler_type_array.array_size_literal = type.array_size_literal;
					sampler_type_array.parent_type = sampler_type_id;
					buffer_type.member_types.push_back(sampler_type_array_id);
				}
				else
					buffer_type.member_types.push_back(sampler_type_id);
			}
			else
			{
				uint32_t binding = get_decoration(var.self, DecorationBinding);
				SetBindingPair pair = { desc_set, binding };

				if (resource.basetype == SPIRType::Image || resource.basetype == SPIRType::Sampler ||
				    resource.basetype == SPIRType::SampledImage)
				{
					// Drop pointer information when we emit the resources into a struct.
					buffer_type.member_types.push_back(get_variable_data_type_id(var));
					if (resource.plane == 0)
						set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
				}
				else if (buffers_requiring_dynamic_offset.count(pair))
				{
					// Don't set the qualified name here; we'll define a variable holding the corrected buffer address later.
					buffer_type.member_types.push_back(var.basetype);
					buffers_requiring_dynamic_offset[pair].second = var.self;
				}
				else if (inline_uniform_blocks.count(pair))
				{
					// Put the buffer block itself into the argument buffer.
					buffer_type.member_types.push_back(get_variable_data_type_id(var));
					set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
				}
				else
				{
					// Resources will be declared as pointers, not references, so
					// dereference automatically where appropriate.
					buffer_type.member_types.push_back(var.basetype);
					if (type.array.empty())
						set_qualified_name(var.self, join("(*", to_name(buffer_variable_id), ".", mbr_name, ")"));
					else
						set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
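					// e.g. a lone buffer is then referenced as
					// "(*spvDescriptorSet0.myBuffer)", while an array of buffers is
					// indexed through the member directly (illustrative names).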
				}
			}

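			// Record the Metal resource index for this member (it becomes the
			// [[id(n)]] qualifier) along with the SPIR-V variable it originated from.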
			set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationResourceIndexPrimary,
			                               resource.index);
			set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationInterfaceOrigID,
			                               var.self);
			member_index++;
		}
	}
}

void CompilerMSL::activate_argument_buffer_resources()
{
	// For ABI compatibility, force-enable all resources which are part of argument buffers.
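	// If a statically-unused resource were stripped, the generated struct
	// layout would no longer match what the API side binds, so every member
	// must stay in place.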
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, const SPIRVariable &) {
		if (!has_decoration(self, DecorationDescriptorSet))
			return;

		uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
		if (descriptor_set_is_argument_buffer(desc_set))
			active_interface_variables.insert(self);
	});
}

bool CompilerMSL::using_builtin_array() const
{
	return msl_options.force_native_arrays || is_using_builtin_array;
}