1 /*
2  * Copyright 2016-2021 The Brenwill Workshop Ltd.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 /*
18  * At your option, you may choose to accept this material under either:
19  *  1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
20  *  2. The MIT License, found at <http://opensource.org/licenses/MIT>.
21  * SPDX-License-Identifier: Apache-2.0 OR MIT.
22  */
23 
24 #include "spirv_msl.hpp"
25 #include "GLSL.std.450.h"
26 
27 #include <algorithm>
28 #include <assert.h>
29 #include <numeric>
30 
31 using namespace spv;
32 using namespace SPIRV_CROSS_NAMESPACE;
33 using namespace std;
34 
// Sentinel values meaning "no explicit location/component was specified".
static const uint32_t k_unknown_location = ~0u;
static const uint32_t k_unknown_component = ~0u;
// Qualifier string emitted in front of generated helper functions that must always be inlined.
static const char *force_inline = "static inline __attribute__((always_inline))";
38 
CompilerMSL(std::vector<uint32_t> spirv_)39 CompilerMSL::CompilerMSL(std::vector<uint32_t> spirv_)
40     : CompilerGLSL(move(spirv_))
41 {
42 }
43 
// Construct from a raw pointer to SPIR-V words plus a word count;
// delegates parsing to the CompilerGLSL base.
CompilerMSL::CompilerMSL(const uint32_t *ir_, size_t word_count)
    : CompilerGLSL(ir_, word_count)
{
}
48 
// Construct from an already-parsed IR module (copied into the base).
CompilerMSL::CompilerMSL(const ParsedIR &ir_)
    : CompilerGLSL(ir_)
{
}
53 
// Construct from an already-parsed IR module, moving it into the base.
CompilerMSL::CompilerMSL(ParsedIR &&ir_)
    : CompilerGLSL(std::move(ir_))
{
}
58 
add_msl_shader_input(const MSLShaderInput & si)59 void CompilerMSL::add_msl_shader_input(const MSLShaderInput &si)
60 {
61 	inputs_by_location[si.location] = si;
62 	if (si.builtin != BuiltInMax && !inputs_by_builtin.count(si.builtin))
63 		inputs_by_builtin[si.builtin] = si;
64 }
65 
add_msl_resource_binding(const MSLResourceBinding & binding)66 void CompilerMSL::add_msl_resource_binding(const MSLResourceBinding &binding)
67 {
68 	StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
69 	resource_bindings[tuple] = { binding, false };
70 }
71 
add_dynamic_buffer(uint32_t desc_set,uint32_t binding,uint32_t index)72 void CompilerMSL::add_dynamic_buffer(uint32_t desc_set, uint32_t binding, uint32_t index)
73 {
74 	SetBindingPair pair = { desc_set, binding };
75 	buffers_requiring_dynamic_offset[pair] = { index, 0 };
76 }
77 
add_inline_uniform_block(uint32_t desc_set,uint32_t binding)78 void CompilerMSL::add_inline_uniform_block(uint32_t desc_set, uint32_t binding)
79 {
80 	SetBindingPair pair = { desc_set, binding };
81 	inline_uniform_blocks.insert(pair);
82 }
83 
add_discrete_descriptor_set(uint32_t desc_set)84 void CompilerMSL::add_discrete_descriptor_set(uint32_t desc_set)
85 {
86 	if (desc_set < kMaxArgumentBuffers)
87 		argument_buffer_discrete_mask |= 1u << desc_set;
88 }
89 
set_argument_buffer_device_address_space(uint32_t desc_set,bool device_storage)90 void CompilerMSL::set_argument_buffer_device_address_space(uint32_t desc_set, bool device_storage)
91 {
92 	if (desc_set < kMaxArgumentBuffers)
93 	{
94 		if (device_storage)
95 			argument_buffer_device_storage_mask |= 1u << desc_set;
96 		else
97 			argument_buffer_device_storage_mask &= ~(1u << desc_set);
98 	}
99 }
100 
is_msl_shader_input_used(uint32_t location)101 bool CompilerMSL::is_msl_shader_input_used(uint32_t location)
102 {
103 	return inputs_in_use.count(location) != 0;
104 }
105 
is_msl_resource_binding_used(ExecutionModel model,uint32_t desc_set,uint32_t binding) const106 bool CompilerMSL::is_msl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
107 {
108 	StageSetBinding tuple = { model, desc_set, binding };
109 	auto itr = resource_bindings.find(tuple);
110 	return itr != end(resource_bindings) && itr->second.second;
111 }
112 
113 // Returns the size of the array of resources used by the variable with the specified id.
114 // The returned value is retrieved from the resource binding added using add_msl_resource_binding().
get_resource_array_size(uint32_t id) const115 uint32_t CompilerMSL::get_resource_array_size(uint32_t id) const
116 {
117 	StageSetBinding tuple = { get_entry_point().model, get_decoration(id, DecorationDescriptorSet),
118 		                      get_decoration(id, DecorationBinding) };
119 	auto itr = resource_bindings.find(tuple);
120 	return itr != end(resource_bindings) ? itr->second.first.count : 0;
121 }
122 
// Returns the primary automatically-assigned MSL resource index for the
// given variable id (stored as an extended decoration).
uint32_t CompilerMSL::get_automatic_msl_resource_binding(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexPrimary);
}
127 
// Returns the secondary automatically-assigned MSL resource index for the
// given variable id (stored as an extended decoration).
uint32_t CompilerMSL::get_automatic_msl_resource_binding_secondary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexSecondary);
}
132 
// Returns the tertiary automatically-assigned MSL resource index for the
// given variable id (stored as an extended decoration).
uint32_t CompilerMSL::get_automatic_msl_resource_binding_tertiary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexTertiary);
}
137 
// Returns the quaternary automatically-assigned MSL resource index for the
// given variable id (stored as an extended decoration).
uint32_t CompilerMSL::get_automatic_msl_resource_binding_quaternary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexQuaternary);
}
142 
// Records how many components the fragment output at the given location
// should be emitted with (overrides the component count later replaces).
void CompilerMSL::set_fragment_output_components(uint32_t location, uint32_t components)
{
	fragment_output_components[location] = components;
}
147 
// Returns true for builtins that SPIR-V declares as arrays but that the
// MSL backend emits as non-array values (only SampleMask here).
bool CompilerMSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
{
	return (builtin == BuiltInSampleMask);
}
152 
// Synthesizes implicit builtin variables the MSL backend needs but the
// SPIR-V module may not declare (e.g. gl_FragCoord for subpass inputs,
// gl_SampleID, subgroup builtins, vertex/tessellation parameters), reusing
// variables already present in the entry-point interface where possible.
// Also creates the auxiliary swizzle-constants, buffer-size, view-mask and
// dynamic-offset buffer variables when required.
// NOTE: the increase_bound_by() offsets must pair exactly with the set<>()
// calls that follow them — new IDs are allocated contiguously.
void CompilerMSL::build_implicit_builtins()
{
	// Work out which implicit builtins this shader needs, from the execution
	// model, enabled MSL options and the set of active input builtins.
	bool need_sample_pos = active_input_builtins.get(BuiltInSamplePosition);
	bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex &&
	                          !msl_options.vertex_for_tessellation;
	bool need_tesc_params = get_execution_model() == ExecutionModelTessellationControl;
	bool need_subgroup_mask =
	    active_input_builtins.get(BuiltInSubgroupEqMask) || active_input_builtins.get(BuiltInSubgroupGeMask) ||
	    active_input_builtins.get(BuiltInSubgroupGtMask) || active_input_builtins.get(BuiltInSubgroupLeMask) ||
	    active_input_builtins.get(BuiltInSubgroupLtMask);
	bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(BuiltInSubgroupGeMask) ||
	                                                       active_input_builtins.get(BuiltInSubgroupGtMask));
	bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index &&
	                      msl_options.multiview_layered_rendering &&
	                      (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex));
	bool need_dispatch_base =
	    msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute &&
	    (active_input_builtins.get(BuiltInWorkgroupId) || active_input_builtins.get(BuiltInGlobalInvocationId));
	bool need_grid_params = get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation;
	bool need_vertex_base_params =
	    need_grid_params &&
	    (active_input_builtins.get(BuiltInVertexId) || active_input_builtins.get(BuiltInVertexIndex) ||
	     active_input_builtins.get(BuiltInBaseVertex) || active_input_builtins.get(BuiltInInstanceId) ||
	     active_input_builtins.get(BuiltInInstanceIndex) || active_input_builtins.get(BuiltInBaseInstance));
	bool need_sample_mask = msl_options.additional_fixed_sample_mask != 0xffffffff;
	bool need_local_invocation_index = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInSubgroupId);
	bool need_workgroup_size = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInNumSubgroups);
	if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
	    need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params || needs_sample_id ||
	    needs_subgroup_invocation_id || needs_subgroup_size || need_sample_mask || need_local_invocation_index ||
	    need_workgroup_size)
	{
		// Track which needed builtins already exist in the module, so we
		// only synthesize the ones that are missing.
		bool has_frag_coord = false;
		bool has_sample_id = false;
		bool has_vertex_idx = false;
		bool has_base_vertex = false;
		bool has_instance_idx = false;
		bool has_base_instance = false;
		bool has_invocation_id = false;
		bool has_primitive_id = false;
		bool has_subgroup_invocation_id = false;
		bool has_subgroup_size = false;
		bool has_view_idx = false;
		bool has_layer = false;
		bool has_local_invocation_index = false;
		bool has_workgroup_size = false;
		uint32_t workgroup_id_type = 0;

		// First pass: scan the builtin interface variables the shader already
		// declares, record their IDs, and mark them implicit where needed.
		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
			if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
				return;
			if (!interface_variable_exists_in_entry_point(var.self))
				return;
			if (!has_decoration(var.self, DecorationBuiltIn))
				return;

			BuiltIn builtin = ir.meta[var.self].decoration.builtin_type;

			if (var.storage == StorageClassOutput)
			{
				if (need_sample_mask && builtin == BuiltInSampleMask)
				{
					builtin_sample_mask_id = var.self;
					mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var.self);
					does_shader_write_sample_mask = true;
				}
			}

			if (var.storage != StorageClassInput)
				return;

			// Use Metal's native frame-buffer fetch API for subpass inputs.
			if (need_subpass_input && (!msl_options.use_framebuffer_fetch_subpasses))
			{
				switch (builtin)
				{
				case BuiltInFragCoord:
					mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var.self);
					builtin_frag_coord_id = var.self;
					has_frag_coord = true;
					break;
				case BuiltInLayer:
					if (!msl_options.arrayed_subpass_input || msl_options.multiview)
						break;
					mark_implicit_builtin(StorageClassInput, BuiltInLayer, var.self);
					builtin_layer_id = var.self;
					has_layer = true;
					break;
				case BuiltInViewIndex:
					if (!msl_options.multiview)
						break;
					mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
					builtin_view_idx_id = var.self;
					has_view_idx = true;
					break;
				default:
					break;
				}
			}

			if ((need_sample_pos || needs_sample_id) && builtin == BuiltInSampleId)
			{
				builtin_sample_id_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var.self);
				has_sample_id = true;
			}

			if (need_vertex_params)
			{
				switch (builtin)
				{
				case BuiltInVertexIndex:
					builtin_vertex_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var.self);
					has_vertex_idx = true;
					break;
				case BuiltInBaseVertex:
					builtin_base_vertex_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var.self);
					has_base_vertex = true;
					break;
				case BuiltInInstanceIndex:
					builtin_instance_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
					has_instance_idx = true;
					break;
				case BuiltInBaseInstance:
					builtin_base_instance_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
					has_base_instance = true;
					break;
				default:
					break;
				}
			}

			if (need_tesc_params)
			{
				switch (builtin)
				{
				case BuiltInInvocationId:
					builtin_invocation_id_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var.self);
					has_invocation_id = true;
					break;
				case BuiltInPrimitiveId:
					builtin_primitive_id_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var.self);
					has_primitive_id = true;
					break;
				default:
					break;
				}
			}

			if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId)
			{
				builtin_subgroup_invocation_id_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var.self);
				has_subgroup_invocation_id = true;
			}

			if ((need_subgroup_ge_mask || needs_subgroup_size) && builtin == BuiltInSubgroupSize)
			{
				builtin_subgroup_size_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var.self);
				has_subgroup_size = true;
			}

			if (need_multiview)
			{
				switch (builtin)
				{
				case BuiltInInstanceIndex:
					// The view index here is derived from the instance index.
					builtin_instance_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
					has_instance_idx = true;
					break;
				case BuiltInBaseInstance:
					// If a non-zero base instance is used, we need to adjust for it when calculating the view index.
					builtin_base_instance_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
					has_base_instance = true;
					break;
				case BuiltInViewIndex:
					builtin_view_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
					has_view_idx = true;
					break;
				default:
					break;
				}
			}

			if (need_local_invocation_index && builtin == BuiltInLocalInvocationIndex)
			{
				builtin_local_invocation_index_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var.self);
				has_local_invocation_index = true;
			}

			if (need_workgroup_size && builtin == BuiltInLocalInvocationId)
			{
				builtin_workgroup_size_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var.self);
				has_workgroup_size = true;
			}

			// The base workgroup needs to have the same type and vector size
			// as the workgroup or invocation ID, so keep track of the type that
			// was used.
			if (need_dispatch_base && workgroup_id_type == 0 &&
			    (builtin == BuiltInWorkgroupId || builtin == BuiltInGlobalInvocationId))
				workgroup_id_type = var.basetype;
		});

		// Second pass: synthesize any still-missing builtins, allocating
		// fresh IDs for the type, pointer type and variable as needed.

		// Use Metal's native frame-buffer fetch API for subpass inputs.
		if ((!has_frag_coord || (msl_options.multiview && !has_view_idx) ||
		     (msl_options.arrayed_subpass_input && !msl_options.multiview && !has_layer)) &&
		    (!msl_options.use_framebuffer_fetch_subpasses) && need_subpass_input)
		{
			if (!has_frag_coord)
			{
				uint32_t offset = ir.increase_bound_by(3);
				uint32_t type_id = offset;
				uint32_t type_ptr_id = offset + 1;
				uint32_t var_id = offset + 2;

				// Create gl_FragCoord.
				SPIRType vec4_type;
				vec4_type.basetype = SPIRType::Float;
				vec4_type.width = 32;
				vec4_type.vecsize = 4;
				set<SPIRType>(type_id, vec4_type);

				SPIRType vec4_type_ptr;
				vec4_type_ptr = vec4_type;
				vec4_type_ptr.pointer = true;
				vec4_type_ptr.parent_type = type_id;
				vec4_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
				ptr_type.self = type_id;

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInFragCoord);
				builtin_frag_coord_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var_id);
			}

			if (!has_layer && msl_options.arrayed_subpass_input && !msl_options.multiview)
			{
				uint32_t offset = ir.increase_bound_by(2);
				uint32_t type_ptr_id = offset;
				uint32_t var_id = offset + 1;

				// Create gl_Layer.
				SPIRType uint_type_ptr;
				uint_type_ptr = get_uint_type();
				uint_type_ptr.pointer = true;
				uint_type_ptr.parent_type = get_uint_type_id();
				uint_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
				ptr_type.self = get_uint_type_id();

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
				builtin_layer_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInLayer, var_id);
			}

			if (!has_view_idx && msl_options.multiview)
			{
				uint32_t offset = ir.increase_bound_by(2);
				uint32_t type_ptr_id = offset;
				uint32_t var_id = offset + 1;

				// Create gl_ViewIndex.
				SPIRType uint_type_ptr;
				uint_type_ptr = get_uint_type();
				uint_type_ptr.pointer = true;
				uint_type_ptr.parent_type = get_uint_type_id();
				uint_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
				ptr_type.self = get_uint_type_id();

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
				builtin_view_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
			}
		}

		if (!has_sample_id && (need_sample_pos || needs_sample_id))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SampleID.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSampleId);
			builtin_sample_id_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var_id);
		}

		if ((need_vertex_params && (!has_vertex_idx || !has_base_vertex || !has_instance_idx || !has_base_instance)) ||
		    (need_multiview && (!has_instance_idx || !has_base_instance || !has_view_idx)))
		{
			// One shared uint pointer type serves all the variables below.
			uint32_t type_ptr_id = ir.increase_bound_by(1);

			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			if (need_vertex_params && !has_vertex_idx)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_VertexIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInVertexIndex);
				builtin_vertex_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var_id);
			}

			if (need_vertex_params && !has_base_vertex)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_BaseVertex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseVertex);
				builtin_base_vertex_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var_id);
			}

			if (!has_instance_idx) // Needed by both multiview and tessellation
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_InstanceIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInInstanceIndex);
				builtin_instance_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var_id);
			}

			if (!has_base_instance) // Needed by both multiview and tessellation
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_BaseInstance.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseInstance);
				builtin_base_instance_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var_id);
			}

			if (need_multiview)
			{
				// Multiview shaders are not allowed to write to gl_Layer, ostensibly because
				// it is implicitly written from gl_ViewIndex, but we have to do that explicitly.
				// Note that we can't just abuse gl_ViewIndex for this purpose: it's an input, but
				// gl_Layer is an output in vertex-pipeline shaders.
				uint32_t type_ptr_out_id = ir.increase_bound_by(2);
				SPIRType uint_type_ptr_out;
				uint_type_ptr_out = get_uint_type();
				uint_type_ptr_out.pointer = true;
				uint_type_ptr_out.parent_type = get_uint_type_id();
				uint_type_ptr_out.storage = StorageClassOutput;
				auto &ptr_out_type = set<SPIRType>(type_ptr_out_id, uint_type_ptr_out);
				ptr_out_type.self = get_uint_type_id();
				uint32_t var_id = type_ptr_out_id + 1;
				set<SPIRVariable>(var_id, type_ptr_out_id, StorageClassOutput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
				builtin_layer_id = var_id;
				mark_implicit_builtin(StorageClassOutput, BuiltInLayer, var_id);
			}

			if (need_multiview && !has_view_idx)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_ViewIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
				builtin_view_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
			}
		}

		if ((need_tesc_params && (msl_options.multi_patch_workgroup || !has_invocation_id || !has_primitive_id)) ||
		    need_grid_params)
		{
			uint32_t type_ptr_id = ir.increase_bound_by(1);

			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			if (msl_options.multi_patch_workgroup || need_grid_params)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_GlobalInvocationID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInGlobalInvocationId);
				builtin_invocation_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInGlobalInvocationId, var_id);
			}
			else if (need_tesc_params && !has_invocation_id)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_InvocationID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInInvocationId);
				builtin_invocation_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var_id);
			}

			if (need_tesc_params && !has_primitive_id)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_PrimitiveID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInPrimitiveId);
				builtin_primitive_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var_id);
			}

			if (need_grid_params)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				set<SPIRVariable>(var_id, build_extended_vector_type(get_uint_type_id(), 3), StorageClassInput);
				set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize);
				get_entry_point().interface_variables.push_back(var_id);
				set_name(var_id, "spvStageInputSize");
				builtin_stage_input_size_id = var_id;
			}
		}

		if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SubgroupInvocationID.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupLocalInvocationId);
			builtin_subgroup_invocation_id_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var_id);
		}

		if (!has_subgroup_size && (need_subgroup_ge_mask || needs_subgroup_size))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SubgroupSize.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupSize);
			builtin_subgroup_size_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id);
		}

		if (need_dispatch_base || need_vertex_base_params)
		{
			if (workgroup_id_type == 0)
				workgroup_id_type = build_extended_vector_type(get_uint_type_id(), 3);
			uint32_t var_id;
			if (msl_options.supports_msl_version(1, 2))
			{
				// If we have MSL 1.2, we can (ab)use the [[grid_origin]] builtin
				// to convey this information and save a buffer slot.
				uint32_t offset = ir.increase_bound_by(1);
				var_id = offset;

				set<SPIRVariable>(var_id, workgroup_id_type, StorageClassInput);
				set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase);
				get_entry_point().interface_variables.push_back(var_id);
			}
			else
			{
				// Otherwise, we need to fall back to a good ol' fashioned buffer.
				uint32_t offset = ir.increase_bound_by(2);
				var_id = offset;
				uint32_t type_id = offset + 1;

				SPIRType var_type = get<SPIRType>(workgroup_id_type);
				var_type.storage = StorageClassUniform;
				set<SPIRType>(type_id, var_type);

				set<SPIRVariable>(var_id, type_id, StorageClassUniform);
				// This should never match anything.
				set_decoration(var_id, DecorationDescriptorSet, ~(5u));
				set_decoration(var_id, DecorationBinding, msl_options.indirect_params_buffer_index);
				set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
				                        msl_options.indirect_params_buffer_index);
			}
			set_name(var_id, "spvDispatchBase");
			builtin_dispatch_base_id = var_id;
		}

		if (need_sample_mask && !does_shader_write_sample_mask)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t var_id = offset + 1;

			// Create gl_SampleMask.
			SPIRType uint_type_ptr_out;
			uint_type_ptr_out = get_uint_type();
			uint_type_ptr_out.pointer = true;
			uint_type_ptr_out.parent_type = get_uint_type_id();
			uint_type_ptr_out.storage = StorageClassOutput;

			auto &ptr_out_type = set<SPIRType>(offset, uint_type_ptr_out);
			ptr_out_type.self = get_uint_type_id();
			set<SPIRVariable>(var_id, offset, StorageClassOutput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSampleMask);
			builtin_sample_mask_id = var_id;
			mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var_id);
		}

		if (need_local_invocation_index && !has_local_invocation_index)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_LocalInvocationIndex.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;

			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();
			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInLocalInvocationIndex);
			builtin_local_invocation_index_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var_id);
		}

		if (need_workgroup_size && !has_workgroup_size)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_WorkgroupSize.
			uint32_t type_id = build_extended_vector_type(get_uint_type_id(), 3);
			SPIRType uint_type_ptr = get<SPIRType>(type_id);
			uint_type_ptr.pointer = true;
			uint_type_ptr.parent_type = type_id;
			uint_type_ptr.storage = StorageClassInput;

			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = type_id;
			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInWorkgroupSize);
			builtin_workgroup_size_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var_id);
		}
	}

	// Auxiliary buffers below are bound with sentinel descriptor sets that
	// should never collide with user-declared resources.

	if (needs_swizzle_buffer_def)
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvSwizzleConstants");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, kSwizzleBufferBinding);
		set_decoration(var_id, DecorationBinding, msl_options.swizzle_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.swizzle_buffer_index);
		swizzle_buffer_id = var_id;
	}

	if (!buffers_requiring_array_length.empty())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvBufferSizeConstants");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, kBufferSizeBufferBinding);
		set_decoration(var_id, DecorationBinding, msl_options.buffer_size_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.buffer_size_buffer_index);
		buffer_size_buffer_id = var_id;
	}

	if (needs_view_mask_buffer())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvViewMask");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, ~(4u));
		set_decoration(var_id, DecorationBinding, msl_options.view_mask_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.view_mask_buffer_index);
		view_mask_buffer_id = var_id;
	}

	if (!buffers_requiring_dynamic_offset.empty())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvDynamicOffsets");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, ~(5u));
		set_decoration(var_id, DecorationBinding, msl_options.dynamic_offsets_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
		                        msl_options.dynamic_offsets_buffer_index);
		dynamic_offsets_buffer_id = var_id;
	}
}
802 
803 // Checks if the specified builtin variable (e.g. gl_InstanceIndex) is marked as active.
804 // If not, it marks it as active and forces a recompilation.
805 // This might be used when the optimization of inactive builtins was too optimistic (e.g. when "spvOut" is emitted).
ensure_builtin(spv::StorageClass storage,spv::BuiltIn builtin)806 void CompilerMSL::ensure_builtin(spv::StorageClass storage, spv::BuiltIn builtin)
807 {
808 	Bitset *active_builtins = nullptr;
809 	switch (storage)
810 	{
811 	case StorageClassInput:
812 		active_builtins = &active_input_builtins;
813 		break;
814 
815 	case StorageClassOutput:
816 		active_builtins = &active_output_builtins;
817 		break;
818 
819 	default:
820 		break;
821 	}
822 
823 	// At this point, the specified builtin variable must have already been declared in the entry point.
824 	// If not, mark as active and force recompile.
825 	if (active_builtins != nullptr && !active_builtins->get(builtin))
826 	{
827 		active_builtins->set(builtin);
828 		force_recompile();
829 	}
830 }
831 
mark_implicit_builtin(StorageClass storage,BuiltIn builtin,uint32_t id)832 void CompilerMSL::mark_implicit_builtin(StorageClass storage, BuiltIn builtin, uint32_t id)
833 {
834 	Bitset *active_builtins = nullptr;
835 	switch (storage)
836 	{
837 	case StorageClassInput:
838 		active_builtins = &active_input_builtins;
839 		break;
840 
841 	case StorageClassOutput:
842 		active_builtins = &active_output_builtins;
843 		break;
844 
845 	default:
846 		break;
847 	}
848 
849 	assert(active_builtins != nullptr);
850 	active_builtins->set(builtin);
851 
852 	auto &var = get_entry_point().interface_variables;
853 	if (find(begin(var), end(var), VariableID(id)) == end(var))
854 		var.push_back(id);
855 }
856 
// Synthesizes a variable whose type is a pointer to a pointer of uint
// (i.e. "constant uint*"), used for auxiliary data such as swizzle,
// buffer-size, view-mask and dynamic-offset constants.
// Returns the ID of the newly created variable.
uint32_t CompilerMSL::build_constant_uint_array_pointer()
{
	// Reserve three consecutive IDs: pointer type, pointer-to-pointer type, and the variable.
	uint32_t offset = ir.increase_bound_by(3);
	uint32_t type_ptr_id = offset;
	uint32_t type_ptr_ptr_id = offset + 1;
	uint32_t var_id = offset + 2;

	// Create a buffer to hold extra data, including the swizzle constants.
	SPIRType uint_type_pointer = get_uint_type();
	uint_type_pointer.pointer = true;
	uint_type_pointer.pointer_depth = 1;
	uint_type_pointer.parent_type = get_uint_type_id();
	uint_type_pointer.storage = StorageClassUniform;
	set<SPIRType>(type_ptr_id, uint_type_pointer);
	// Elements are tightly packed 32-bit uints.
	set_decoration(type_ptr_id, DecorationArrayStride, 4);

	// Add a second level of indirection for the variable itself.
	SPIRType uint_type_pointer2 = uint_type_pointer;
	uint_type_pointer2.pointer_depth++;
	uint_type_pointer2.parent_type = type_ptr_id;
	set<SPIRType>(type_ptr_ptr_id, uint_type_pointer2);

	set<SPIRVariable>(var_id, type_ptr_ptr_id, StorageClassUniformConstant);
	return var_id;
}
881 
// Returns the MSL constexpr sampler address argument for the given
// addressing mode, prefixed with the given per-axis prefix (e.g. "s_").
static string create_sampler_address(const char *prefix, MSLSamplerAddress addr)
{
	const char *arg = nullptr;
	switch (addr)
	{
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE:
		arg = "address::clamp_to_edge";
		break;
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_ZERO:
		arg = "address::clamp_to_zero";
		break;
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER:
		arg = "address::clamp_to_border";
		break;
	case MSL_SAMPLER_ADDRESS_REPEAT:
		arg = "address::repeat";
		break;
	case MSL_SAMPLER_ADDRESS_MIRRORED_REPEAT:
		arg = "address::mirrored_repeat";
		break;
	default:
		SPIRV_CROSS_THROW("Invalid sampler addressing mode.");
	}
	return join(prefix, arg);
}
900 
get_stage_in_struct_type()901 SPIRType &CompilerMSL::get_stage_in_struct_type()
902 {
903 	auto &si_var = get<SPIRVariable>(stage_in_var_id);
904 	return get_variable_data_type(si_var);
905 }
906 
get_stage_out_struct_type()907 SPIRType &CompilerMSL::get_stage_out_struct_type()
908 {
909 	auto &so_var = get<SPIRVariable>(stage_out_var_id);
910 	return get_variable_data_type(so_var);
911 }
912 
get_patch_stage_in_struct_type()913 SPIRType &CompilerMSL::get_patch_stage_in_struct_type()
914 {
915 	auto &si_var = get<SPIRVariable>(patch_stage_in_var_id);
916 	return get_variable_data_type(si_var);
917 }
918 
get_patch_stage_out_struct_type()919 SPIRType &CompilerMSL::get_patch_stage_out_struct_type()
920 {
921 	auto &so_var = get<SPIRVariable>(patch_stage_out_var_id);
922 	return get_variable_data_type(so_var);
923 }
924 
// Returns the name of the Metal tessellation factor struct matching
// the entry point's tessellation domain.
std::string CompilerMSL::get_tess_factor_struct_name()
{
	bool is_triangles = get_entry_point().flags.get(ExecutionModeTriangles);
	return is_triangles ? "MTLTriangleTessellationFactorsHalf" : "MTLQuadTessellationFactorsHalf";
}
931 
// Returns the SPIRType for a 32-bit unsigned integer, creating it on demand
// via get_uint_type_id().
SPIRType &CompilerMSL::get_uint_type()
{
	return get<SPIRType>(get_uint_type_id());
}
936 
// Returns the ID of a 32-bit unsigned integer type, lazily creating
// and caching the type in the IR on first use.
uint32_t CompilerMSL::get_uint_type_id()
{
	// Reuse the cached ID if we already built the type.
	if (uint_type_id != 0)
		return uint_type_id;

	// Allocate a fresh ID and register a 32-bit UInt type under it.
	uint_type_id = ir.increase_bound_by(1);

	SPIRType type;
	type.basetype = SPIRType::UInt;
	type.width = 32;
	set<SPIRType>(uint_type_id, type);
	return uint_type_id;
}
950 
// Emits declarations which must appear at the top of the entry point function:
// complex constant arrays, constexpr samplers, dynamically-offset buffers,
// buffer argument arrays, and placeholders for disabled fragment outputs.
void CompilerMSL::emit_entry_point_declarations()
{
	// FIXME: Get test coverage here ...
	// Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries
	declare_complex_constant_arrays();

	// Emit constexpr samplers here.
	for (auto &samp : constexpr_samplers_by_id)
	{
		auto &var = get<SPIRVariable>(samp.first);
		auto &type = get<SPIRType>(var.basetype);
		if (type.basetype == SPIRType::Sampler)
			add_resource_name(samp.first);

		// Only emit sampler arguments which deviate from Metal's defaults
		// (normalized coords, nearest filtering, clamp_to_edge addressing).
		SmallVector<string> args;
		auto &s = samp.second;

		if (s.coord != MSL_SAMPLER_COORD_NORMALIZED)
			args.push_back("coord::pixel");

		if (s.min_filter == s.mag_filter)
		{
			// Min and mag agree; a single combined filter argument suffices.
			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("filter::linear");
		}
		else
		{
			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("min_filter::linear");
			if (s.mag_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("mag_filter::linear");
		}

		switch (s.mip_filter)
		{
		case MSL_SAMPLER_MIP_FILTER_NONE:
			// Default
			break;
		case MSL_SAMPLER_MIP_FILTER_NEAREST:
			args.push_back("mip_filter::nearest");
			break;
		case MSL_SAMPLER_MIP_FILTER_LINEAR:
			args.push_back("mip_filter::linear");
			break;
		default:
			SPIRV_CROSS_THROW("Invalid mip filter.");
		}

		if (s.s_address == s.t_address && s.s_address == s.r_address)
		{
			// All three axes agree; emit one unprefixed address argument.
			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("", s.s_address));
		}
		else
		{
			// Axes differ; emit per-axis prefixed address arguments.
			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("s_", s.s_address));
			if (s.t_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("t_", s.t_address));
			if (s.r_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("r_", s.r_address));
		}

		if (s.compare_enable)
		{
			switch (s.compare_func)
			{
			case MSL_SAMPLER_COMPARE_FUNC_ALWAYS:
				args.push_back("compare_func::always");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_NEVER:
				args.push_back("compare_func::never");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_EQUAL:
				args.push_back("compare_func::equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_NOT_EQUAL:
				args.push_back("compare_func::not_equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_LESS:
				args.push_back("compare_func::less");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_LESS_EQUAL:
				args.push_back("compare_func::less_equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_GREATER:
				args.push_back("compare_func::greater");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_GREATER_EQUAL:
				args.push_back("compare_func::greater_equal");
				break;
			default:
				SPIRV_CROSS_THROW("Invalid sampler compare function.");
			}
		}

		// Border color only matters if some axis uses clamp_to_border addressing.
		if (s.s_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER || s.t_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER ||
		    s.r_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER)
		{
			switch (s.border_color)
			{
			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_BLACK:
				args.push_back("border_color::opaque_black");
				break;
			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_WHITE:
				args.push_back("border_color::opaque_white");
				break;
			case MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK:
				args.push_back("border_color::transparent_black");
				break;
			default:
				SPIRV_CROSS_THROW("Invalid sampler border color.");
			}
		}

		if (s.anisotropy_enable)
			args.push_back(join("max_anisotropy(", s.max_anisotropy, ")"));
		if (s.lod_clamp_enable)
		{
			// Use locale-aware float formatting so the radix character matches the output.
			args.push_back(join("lod_clamp(", convert_to_string(s.lod_clamp_min, current_locale_radix_character), ", ",
			                    convert_to_string(s.lod_clamp_max, current_locale_radix_character), ")"));
		}

		// If we would emit no arguments, then omit the parentheses entirely. Otherwise,
		// we'll wind up with a "most vexing parse" situation.
		if (args.empty())
			statement("constexpr sampler ",
			          type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
			          ";");
		else
			statement("constexpr sampler ",
			          type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
			          "(", merge(args), ");");
	}

	// Emit dynamic buffers here.
	for (auto &dynamic_buffer : buffers_requiring_dynamic_offset)
	{
		if (!dynamic_buffer.second.second)
		{
			// Could happen if no buffer was used at requested binding point.
			continue;
		}

		const auto &var = get<SPIRVariable>(dynamic_buffer.second.second);
		uint32_t var_id = var.self;
		const auto &type = get_variable_data_type(var);
		string name = to_name(var.self);
		uint32_t desc_set = get_decoration(var.self, DecorationDescriptorSet);
		uint32_t arg_id = argument_buffer_ids[desc_set];
		uint32_t base_index = dynamic_buffer.second.first;

		if (!type.array.empty())
		{
			// This is complicated, because we need to support arrays of arrays.
			// And it's even worse if the outermost dimension is a runtime array, because now
			// all this complicated goop has to go into the shader itself. (FIXME)
			if (!type.array[type.array.size() - 1])
				SPIRV_CROSS_THROW("Runtime arrays with dynamic offsets are not supported yet.");
			else
			{
				is_using_builtin_array = true;
				statement(get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id), name,
				          type_to_array_glsl(type), " =");

				// Walk a multi-dimensional index vector, emitting one pointer expression
				// (offset by the dynamic offset buffer) per innermost element, with nested
				// braces opened/closed as the indices roll over.
				uint32_t dim = uint32_t(type.array.size());
				uint32_t j = 0;
				for (SmallVector<uint32_t> indices(type.array.size());
				     indices[type.array.size() - 1] < to_array_size_literal(type); j++)
				{
					while (dim > 0)
					{
						begin_scope();
						--dim;
					}

					string arrays;
					for (uint32_t i = uint32_t(type.array.size()); i; --i)
						arrays += join("[", indices[i - 1], "]");
					statement("(", get_argument_address_space(var), " ", type_to_glsl(type), "* ",
					          to_restrict(var_id, false), ")((", get_argument_address_space(var), " char* ",
					          to_restrict(var_id, false), ")", to_name(arg_id), ".", ensure_valid_name(name, "m"),
					          arrays, " + ", to_name(dynamic_offsets_buffer_id), "[", base_index + j, "]),");

					while (++indices[dim] >= to_array_size_literal(type, dim) && dim < type.array.size() - 1)
					{
						end_scope(",");
						indices[dim++] = 0;
					}
				}
				end_scope_decl();
				statement_no_indent("");
				is_using_builtin_array = false;
			}
		}
		else
		{
			// Scalar case: rebind the name as a reference to the byte-offset pointer.
			statement(get_argument_address_space(var), " auto& ", to_restrict(var_id), name, " = *(",
			          get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, false), ")((",
			          get_argument_address_space(var), " char* ", to_restrict(var_id, false), ")", to_name(arg_id), ".",
			          ensure_valid_name(name, "m"), " + ", to_name(dynamic_offsets_buffer_id), "[", base_index, "]);");
		}
	}

	// Emit buffer arrays here.
	for (uint32_t array_id : buffer_arrays)
	{
		const auto &var = get<SPIRVariable>(array_id);
		const auto &type = get_variable_data_type(var);
		const auto &buffer_type = get_variable_element_type(var);
		string name = to_name(array_id);
		// Gather the individually-bound buffers (name_0, name_1, ...) into a local C array.
		statement(get_argument_address_space(var), " ", type_to_glsl(buffer_type), "* ", to_restrict(array_id), name,
		          "[] =");
		begin_scope();
		for (uint32_t i = 0; i < to_array_size_literal(type); ++i)
			statement(name, "_", i, ",");
		end_scope_decl();
		statement_no_indent("");
	}
	// For some reason, without this, we end up emitting the arrays twice.
	buffer_arrays.clear();

	// Emit disabled fragment outputs.
	// Sorted so output is deterministic across runs.
	std::sort(disabled_frag_outputs.begin(), disabled_frag_outputs.end());
	for (uint32_t var_id : disabled_frag_outputs)
	{
		auto &var = get<SPIRVariable>(var_id);
		add_local_variable_name(var_id);
		statement(variable_decl(var), ";");
		// Declared here, so it is no longer deferred to the usual declaration pass.
		var.deferred_declaration = false;
	}
}
1183 
// Compiles the parsed SPIR-V IR into MSL source text.
// Configures the GLSL backend for Metal semantics, runs the analysis passes,
// builds implicit builtins and interface blocks, then emits the shader in up
// to three passes (re-running when a forced recompile is requested).
string CompilerMSL::compile()
{
	replace_illegal_entry_point_names();
	ir.fixup_reserved_names();

	// Do not deal with GLES-isms like precision, older extensions and such.
	options.vulkan_semantics = true;
	options.es = false;
	options.version = 450;
	backend.null_pointer_literal = "nullptr";
	backend.float_literal_suffix = false;
	backend.uint32_t_literal_suffix = true;
	backend.int16_t_literal_suffix = "";
	backend.uint16_t_literal_suffix = "";
	backend.basic_int_type = "int";
	backend.basic_uint_type = "uint";
	backend.basic_int8_type = "char";
	backend.basic_uint8_type = "uchar";
	backend.basic_int16_type = "short";
	backend.basic_uint16_type = "ushort";
	backend.discard_literal = "discard_fragment()";
	backend.demote_literal = "discard_fragment()";
	backend.boolean_mix_function = "select";
	backend.swizzle_is_function = false;
	backend.shared_is_implied = false;
	backend.use_initializer_list = true;
	backend.use_typed_initializer_list = true;
	backend.native_row_major_matrix = false;
	backend.unsized_array_supported = false;
	backend.can_declare_arrays_inline = false;
	backend.allow_truncated_access_chain = true;
	backend.comparison_image_samples_scalar = true;
	backend.native_pointers = true;
	backend.nonuniform_qualifier = "";
	backend.support_small_type_sampling_result = true;
	backend.supports_empty_struct = true;

	// Allow Metal to use the array<T> template unless we force it off.
	backend.can_return_array = !msl_options.force_native_arrays;
	backend.array_is_value_type = !msl_options.force_native_arrays;
	// Arrays which are part of buffer objects are never considered to be native arrays.
	backend.buffer_offset_array_is_value_type = false;

	capture_output_to_buffer = msl_options.capture_output_to_buffer;
	is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;

	// Initialize array here rather than constructor, MSVC 2013 workaround.
	for (auto &id : next_metal_resource_ids)
		id = 0;

	fixup_type_alias();
	replace_illegal_names();
	sync_entry_point_aliases_and_names();

	// Analysis passes; these must run before implicit builtins are created.
	build_function_control_flow_graphs_and_analyze();
	update_active_builtins();
	analyze_image_and_sampler_usage();
	analyze_sampled_image_usage();
	analyze_interlocked_resource_usage();
	preprocess_op_codes();
	build_implicit_builtins();

	fixup_image_load_store_access();

	set_enabled_interface_variables(get_active_interface_variables());
	if (msl_options.force_active_argument_buffer_resources)
		activate_argument_buffer_resources();

	// Keep the implicit auxiliary variables created above alive in the interface.
	if (swizzle_buffer_id)
		active_interface_variables.insert(swizzle_buffer_id);
	if (buffer_size_buffer_id)
		active_interface_variables.insert(buffer_size_buffer_id);
	if (view_mask_buffer_id)
		active_interface_variables.insert(view_mask_buffer_id);
	if (dynamic_offsets_buffer_id)
		active_interface_variables.insert(dynamic_offsets_buffer_id);
	if (builtin_layer_id)
		active_interface_variables.insert(builtin_layer_id);
	if (builtin_dispatch_base_id && !msl_options.supports_msl_version(1, 2))
		active_interface_variables.insert(builtin_dispatch_base_id);
	if (builtin_sample_mask_id)
		active_interface_variables.insert(builtin_sample_mask_id);

	// Create structs to hold input, output and uniform variables.
	// Do output first to ensure out. is declared at top of entry function.
	qual_pos_var_name = "";
	stage_out_var_id = add_interface_block(StorageClassOutput);
	patch_stage_out_var_id = add_interface_block(StorageClassOutput, true);
	stage_in_var_id = add_interface_block(StorageClassInput);
	if (get_execution_model() == ExecutionModelTessellationEvaluation)
		patch_stage_in_var_id = add_interface_block(StorageClassInput, true);

	// Tessellation stages additionally access their interface blocks through pointers.
	if (get_execution_model() == ExecutionModelTessellationControl)
		stage_out_ptr_var_id = add_interface_block_pointer(stage_out_var_id, StorageClassOutput);
	if (is_tessellation_shader())
		stage_in_ptr_var_id = add_interface_block_pointer(stage_in_var_id, StorageClassInput);

	// Metal vertex functions that define no output must disable rasterization and return void.
	if (!stage_out_var_id)
		is_rasterization_disabled = true;

	// Convert the use of global variables to recursively-passed function parameters
	localize_global_variables();
	extract_global_variables_from_functions();

	// Mark any non-stage-in structs to be tightly packed.
	mark_packable_structs();
	reorder_type_alias();

	// Add fixup hooks required by shader inputs and outputs. This needs to happen before
	// the loop, so the hooks aren't added multiple times.
	fix_up_shader_inputs_outputs();

	// If we are using argument buffers, we create argument buffer structures for them here.
	// These buffers will be used in the entry point, not the individual resources.
	if (msl_options.argument_buffers)
	{
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("Argument buffers can only be used with MSL 2.0 and up.");
		analyze_argument_buffers();
	}

	uint32_t pass_count = 0;
	do
	{
		if (pass_count >= 3)
			SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");

		// Clear per-pass state before each emission attempt.
		reset();

		// Start bindings at zero.
		next_metal_resource_index_buffer = 0;
		next_metal_resource_index_texture = 0;
		next_metal_resource_index_sampler = 0;
		for (auto &id : next_metal_resource_ids)
			id = 0;

		// Move constructor for this type is broken on GCC 4.9 ...
		buffer.reset();

		emit_header();
		emit_custom_templates();
		emit_specialization_constants_and_structs();
		emit_resources();
		emit_custom_functions();
		emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());

		pass_count++;
	} while (is_forcing_recompilation());

	return buffer.str();
}
1336 
// Register the need to output any custom functions.
// Scans all reachable opcodes from the entry point to discover features the
// shader uses (atomics, resource writes, subgroup/sample builtins) and sets
// the corresponding state before implicit builtins are created.
void CompilerMSL::preprocess_op_codes()
{
	OpCodePreprocessor preproc(*this);
	traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), preproc);

	suppress_missing_prototypes = preproc.suppress_missing_prototypes;

	if (preproc.uses_atomics)
	{
		add_header_line("#include <metal_atomic>");
		add_pragma_line("#pragma clang diagnostic ignored \"-Wunused-variable\"");
	}

	// Before MSL 2.1 (2.2 for textures), Metal vertex functions that write to
	// resources must disable rasterization and return void.
	if (preproc.uses_resource_write)
		is_rasterization_disabled = true;

	// Tessellation control shaders are run as compute functions in Metal, and so
	// must capture their output to a buffer.
	if (get_execution_model() == ExecutionModelTessellationControl ||
	    (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
	{
		is_rasterization_disabled = true;
		capture_output_to_buffer = true;
	}

	if (preproc.needs_subgroup_invocation_id)
		needs_subgroup_invocation_id = true;
	if (preproc.needs_subgroup_size)
		needs_subgroup_size = true;
	// build_implicit_builtins() hasn't run yet, and in fact, this needs to execute
	// before then so that gl_SampleID will get added; so we also need to check if
	// that function would add gl_FragCoord.
	if (preproc.needs_sample_id || msl_options.force_sample_rate_shading ||
	    (is_sample_rate() && (active_input_builtins.get(BuiltInFragCoord) ||
	                          (need_subpass_input && !msl_options.use_framebuffer_fetch_subpasses))))
		needs_sample_id = true;
}
1377 
1378 // Move the Private and Workgroup global variables to the entry function.
1379 // Non-constant variables cannot have global scope in Metal.
localize_global_variables()1380 void CompilerMSL::localize_global_variables()
1381 {
1382 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1383 	auto iter = global_variables.begin();
1384 	while (iter != global_variables.end())
1385 	{
1386 		uint32_t v_id = *iter;
1387 		auto &var = get<SPIRVariable>(v_id);
1388 		if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup)
1389 		{
1390 			if (!variable_is_lut(var))
1391 				entry_func.add_local_variable(v_id);
1392 			iter = global_variables.erase(iter);
1393 		}
1394 		else
1395 			iter++;
1396 	}
1397 }
1398 
1399 // For any global variable accessed directly by a function,
1400 // extract that variable and add it as an argument to that function.
extract_global_variables_from_functions()1401 void CompilerMSL::extract_global_variables_from_functions()
1402 {
1403 	// Uniforms
1404 	unordered_set<uint32_t> global_var_ids;
1405 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1406 		if (var.storage == StorageClassInput || var.storage == StorageClassOutput ||
1407 		    var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
1408 		    var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)
1409 		{
1410 			global_var_ids.insert(var.self);
1411 		}
1412 	});
1413 
1414 	// Local vars that are declared in the main function and accessed directly by a function
1415 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1416 	for (auto &var : entry_func.local_variables)
1417 		if (get<SPIRVariable>(var).storage != StorageClassFunction)
1418 			global_var_ids.insert(var);
1419 
1420 	std::set<uint32_t> added_arg_ids;
1421 	unordered_set<uint32_t> processed_func_ids;
1422 	extract_global_variables_from_function(ir.default_entry_point, added_arg_ids, global_var_ids, processed_func_ids);
1423 }
1424 
1425 // MSL does not support the use of global variables for shader input content.
1426 // For any global variable accessed directly by the specified function, extract that variable,
1427 // add it as an argument to that function, and the arg to the added_arg_ids collection.
extract_global_variables_from_function(uint32_t func_id,std::set<uint32_t> & added_arg_ids,unordered_set<uint32_t> & global_var_ids,unordered_set<uint32_t> & processed_func_ids)1428 void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::set<uint32_t> &added_arg_ids,
1429                                                          unordered_set<uint32_t> &global_var_ids,
1430                                                          unordered_set<uint32_t> &processed_func_ids)
1431 {
1432 	// Avoid processing a function more than once
1433 	if (processed_func_ids.find(func_id) != processed_func_ids.end())
1434 	{
1435 		// Return function global variables
1436 		added_arg_ids = function_global_vars[func_id];
1437 		return;
1438 	}
1439 
1440 	processed_func_ids.insert(func_id);
1441 
1442 	auto &func = get<SPIRFunction>(func_id);
1443 
1444 	// Recursively establish global args added to functions on which we depend.
1445 	for (auto block : func.blocks)
1446 	{
1447 		auto &b = get<SPIRBlock>(block);
1448 		for (auto &i : b.ops)
1449 		{
1450 			auto ops = stream(i);
1451 			auto op = static_cast<Op>(i.op);
1452 
1453 			switch (op)
1454 			{
1455 			case OpLoad:
1456 			case OpInBoundsAccessChain:
1457 			case OpAccessChain:
1458 			case OpPtrAccessChain:
1459 			case OpArrayLength:
1460 			{
1461 				uint32_t base_id = ops[2];
1462 				if (global_var_ids.find(base_id) != global_var_ids.end())
1463 					added_arg_ids.insert(base_id);
1464 
1465 				// Use Metal's native frame-buffer fetch API for subpass inputs.
1466 				auto &type = get<SPIRType>(ops[0]);
1467 				if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
1468 				    (!msl_options.use_framebuffer_fetch_subpasses))
1469 				{
1470 					// Implicitly reads gl_FragCoord.
1471 					assert(builtin_frag_coord_id != 0);
1472 					added_arg_ids.insert(builtin_frag_coord_id);
1473 					if (msl_options.multiview)
1474 					{
1475 						// Implicitly reads gl_ViewIndex.
1476 						assert(builtin_view_idx_id != 0);
1477 						added_arg_ids.insert(builtin_view_idx_id);
1478 					}
1479 					else if (msl_options.arrayed_subpass_input)
1480 					{
1481 						// Implicitly reads gl_Layer.
1482 						assert(builtin_layer_id != 0);
1483 						added_arg_ids.insert(builtin_layer_id);
1484 					}
1485 				}
1486 
1487 				break;
1488 			}
1489 
1490 			case OpFunctionCall:
1491 			{
1492 				// First see if any of the function call args are globals
1493 				for (uint32_t arg_idx = 3; arg_idx < i.length; arg_idx++)
1494 				{
1495 					uint32_t arg_id = ops[arg_idx];
1496 					if (global_var_ids.find(arg_id) != global_var_ids.end())
1497 						added_arg_ids.insert(arg_id);
1498 				}
1499 
1500 				// Then recurse into the function itself to extract globals used internally in the function
1501 				uint32_t inner_func_id = ops[2];
1502 				std::set<uint32_t> inner_func_args;
1503 				extract_global_variables_from_function(inner_func_id, inner_func_args, global_var_ids,
1504 				                                       processed_func_ids);
1505 				added_arg_ids.insert(inner_func_args.begin(), inner_func_args.end());
1506 				break;
1507 			}
1508 
1509 			case OpStore:
1510 			{
1511 				uint32_t base_id = ops[0];
1512 				if (global_var_ids.find(base_id) != global_var_ids.end())
1513 					added_arg_ids.insert(base_id);
1514 
1515 				uint32_t rvalue_id = ops[1];
1516 				if (global_var_ids.find(rvalue_id) != global_var_ids.end())
1517 					added_arg_ids.insert(rvalue_id);
1518 
1519 				break;
1520 			}
1521 
1522 			case OpSelect:
1523 			{
1524 				uint32_t base_id = ops[3];
1525 				if (global_var_ids.find(base_id) != global_var_ids.end())
1526 					added_arg_ids.insert(base_id);
1527 				base_id = ops[4];
1528 				if (global_var_ids.find(base_id) != global_var_ids.end())
1529 					added_arg_ids.insert(base_id);
1530 				break;
1531 			}
1532 
1533 			// Emulate texture2D atomic operations
1534 			case OpImageTexelPointer:
1535 			{
1536 				// When using the pointer, we need to know which variable it is actually loaded from.
1537 				uint32_t base_id = ops[2];
1538 				auto *var = maybe_get_backing_variable(base_id);
1539 				if (var && atomic_image_vars.count(var->self))
1540 				{
1541 					if (global_var_ids.find(base_id) != global_var_ids.end())
1542 						added_arg_ids.insert(base_id);
1543 				}
1544 				break;
1545 			}
1546 
1547 			case OpExtInst:
1548 			{
1549 				uint32_t extension_set = ops[2];
1550 				if (get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
1551 				{
1552 					auto op_450 = static_cast<GLSLstd450>(ops[3]);
1553 					switch (op_450)
1554 					{
1555 					case GLSLstd450InterpolateAtCentroid:
1556 					case GLSLstd450InterpolateAtSample:
1557 					case GLSLstd450InterpolateAtOffset:
1558 					{
1559 						// For these, we really need the stage-in block. It is theoretically possible to pass the
1560 						// interpolant object, but a) doing so would require us to create an entirely new variable
1561 						// with Interpolant type, and b) if we have a struct or array, handling all the members and
1562 						// elements could get unwieldy fast.
1563 						added_arg_ids.insert(stage_in_var_id);
1564 						break;
1565 					}
1566 					default:
1567 						break;
1568 					}
1569 				}
1570 				break;
1571 			}
1572 
1573 			case OpGroupNonUniformInverseBallot:
1574 			{
1575 				added_arg_ids.insert(builtin_subgroup_invocation_id_id);
1576 				break;
1577 			}
1578 
1579 			case OpGroupNonUniformBallotFindLSB:
1580 			case OpGroupNonUniformBallotFindMSB:
1581 			{
1582 				added_arg_ids.insert(builtin_subgroup_size_id);
1583 				break;
1584 			}
1585 
1586 			case OpGroupNonUniformBallotBitCount:
1587 			{
1588 				auto operation = static_cast<GroupOperation>(ops[3]);
1589 				switch (operation)
1590 				{
1591 				case GroupOperationReduce:
1592 					added_arg_ids.insert(builtin_subgroup_size_id);
1593 					break;
1594 				case GroupOperationInclusiveScan:
1595 				case GroupOperationExclusiveScan:
1596 					added_arg_ids.insert(builtin_subgroup_invocation_id_id);
1597 					break;
1598 				default:
1599 					break;
1600 				}
1601 				break;
1602 			}
1603 
1604 			default:
1605 				break;
1606 			}
1607 
1608 			// TODO: Add all other operations which can affect memory.
1609 			// We should consider a more unified system here to reduce boiler-plate.
1610 			// This kind of analysis is done in several places ...
1611 		}
1612 	}
1613 
1614 	function_global_vars[func_id] = added_arg_ids;
1615 
1616 	// Add the global variables as arguments to the function
1617 	if (func_id != ir.default_entry_point)
1618 	{
1619 		bool added_in = false;
1620 		bool added_out = false;
1621 		for (uint32_t arg_id : added_arg_ids)
1622 		{
1623 			auto &var = get<SPIRVariable>(arg_id);
1624 			uint32_t type_id = var.basetype;
1625 			auto *p_type = &get<SPIRType>(type_id);
1626 			BuiltIn bi_type = BuiltIn(get_decoration(arg_id, DecorationBuiltIn));
1627 
1628 			if (((is_tessellation_shader() && var.storage == StorageClassInput) ||
1629 			     (get_execution_model() == ExecutionModelTessellationControl && var.storage == StorageClassOutput)) &&
1630 			    !(has_decoration(arg_id, DecorationPatch) || is_patch_block(*p_type)) &&
1631 			    (!is_builtin_variable(var) || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
1632 			     bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
1633 			     p_type->basetype == SPIRType::Struct))
1634 			{
1635 				// Tessellation control shaders see inputs and per-vertex outputs as arrays.
1636 				// Similarly, tessellation evaluation shaders see per-vertex inputs as arrays.
1637 				// We collected them into a structure; we must pass the array of this
1638 				// structure to the function.
1639 				std::string name;
1640 				if (var.storage == StorageClassInput)
1641 				{
1642 					if (added_in)
1643 						continue;
1644 					name = "gl_in";
1645 					arg_id = stage_in_ptr_var_id;
1646 					added_in = true;
1647 				}
1648 				else if (var.storage == StorageClassOutput)
1649 				{
1650 					if (added_out)
1651 						continue;
1652 					name = "gl_out";
1653 					arg_id = stage_out_ptr_var_id;
1654 					added_out = true;
1655 				}
1656 				type_id = get<SPIRVariable>(arg_id).basetype;
1657 				uint32_t next_id = ir.increase_bound_by(1);
1658 				func.add_parameter(type_id, next_id, true);
1659 				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);
1660 
1661 				set_name(next_id, name);
1662 			}
1663 			else if (is_builtin_variable(var) && p_type->basetype == SPIRType::Struct)
1664 			{
1665 				// Get the pointee type
1666 				type_id = get_pointee_type_id(type_id);
1667 				p_type = &get<SPIRType>(type_id);
1668 
1669 				uint32_t mbr_idx = 0;
1670 				for (auto &mbr_type_id : p_type->member_types)
1671 				{
1672 					BuiltIn builtin = BuiltInMax;
1673 					bool is_builtin = is_member_builtin(*p_type, mbr_idx, &builtin);
1674 					if (is_builtin && has_active_builtin(builtin, var.storage))
1675 					{
1676 						// Add a arg variable with the same type and decorations as the member
1677 						uint32_t next_ids = ir.increase_bound_by(2);
1678 						uint32_t ptr_type_id = next_ids + 0;
1679 						uint32_t var_id = next_ids + 1;
1680 
1681 						// Make sure we have an actual pointer type,
1682 						// so that we will get the appropriate address space when declaring these builtins.
1683 						auto &ptr = set<SPIRType>(ptr_type_id, get<SPIRType>(mbr_type_id));
1684 						ptr.self = mbr_type_id;
1685 						ptr.storage = var.storage;
1686 						ptr.pointer = true;
1687 						ptr.parent_type = mbr_type_id;
1688 
1689 						func.add_parameter(mbr_type_id, var_id, true);
1690 						set<SPIRVariable>(var_id, ptr_type_id, StorageClassFunction);
1691 						ir.meta[var_id].decoration = ir.meta[type_id].members[mbr_idx];
1692 					}
1693 					mbr_idx++;
1694 				}
1695 			}
1696 			else
1697 			{
1698 				uint32_t next_id = ir.increase_bound_by(1);
1699 				func.add_parameter(type_id, next_id, true);
1700 				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);
1701 
1702 				// Ensure the existing variable has a valid name and the new variable has all the same meta info
1703 				set_name(arg_id, ensure_valid_name(to_name(arg_id), "v"));
1704 				ir.meta[next_id] = ir.meta[arg_id];
1705 			}
1706 		}
1707 	}
1708 }
1709 
1710 // For all variables that are some form of non-input-output interface block, mark that all the structs
1711 // that are recursively contained within the type referenced by that variable should be packed tightly.
mark_packable_structs()1712 void CompilerMSL::mark_packable_structs()
1713 {
1714 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1715 		if (var.storage != StorageClassFunction && !is_hidden_variable(var))
1716 		{
1717 			auto &type = this->get<SPIRType>(var.basetype);
1718 			if (type.pointer &&
1719 			    (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant ||
1720 			     type.storage == StorageClassPushConstant || type.storage == StorageClassStorageBuffer) &&
1721 			    (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
1722 				mark_as_packable(type);
1723 		}
1724 	});
1725 }
1726 
1727 // If the specified type is a struct, it and any nested structs
1728 // are marked as packable with the SPIRVCrossDecorationBufferBlockRepacked decoration,
mark_as_packable(SPIRType & type)1729 void CompilerMSL::mark_as_packable(SPIRType &type)
1730 {
1731 	// If this is not the base type (eg. it's a pointer or array), tunnel down
1732 	if (type.parent_type)
1733 	{
1734 		mark_as_packable(get<SPIRType>(type.parent_type));
1735 		return;
1736 	}
1737 
1738 	if (type.basetype == SPIRType::Struct)
1739 	{
1740 		set_extended_decoration(type.self, SPIRVCrossDecorationBufferBlockRepacked);
1741 
1742 		// Recurse
1743 		uint32_t mbr_cnt = uint32_t(type.member_types.size());
1744 		for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
1745 		{
1746 			uint32_t mbr_type_id = type.member_types[mbr_idx];
1747 			auto &mbr_type = get<SPIRType>(mbr_type_id);
1748 			mark_as_packable(mbr_type);
1749 			if (mbr_type.type_alias)
1750 			{
1751 				auto &mbr_type_alias = get<SPIRType>(mbr_type.type_alias);
1752 				mark_as_packable(mbr_type_alias);
1753 			}
1754 		}
1755 	}
1756 }
1757 
1758 // If a shader input exists at the location, it is marked as being used by this shader
mark_location_as_used_by_shader(uint32_t location,const SPIRType & type,StorageClass storage)1759 void CompilerMSL::mark_location_as_used_by_shader(uint32_t location, const SPIRType &type, StorageClass storage)
1760 {
1761 	if (storage != StorageClassInput)
1762 		return;
1763 	if (is_array(type))
1764 	{
1765 		uint32_t dim = 1;
1766 		for (uint32_t i = 0; i < type.array.size(); i++)
1767 			dim *= to_array_size_literal(type, i);
1768 		for (uint32_t i = 0; i < dim; i++)
1769 		{
1770 			if (is_matrix(type))
1771 			{
1772 				for (uint32_t j = 0; j < type.columns; j++)
1773 					inputs_in_use.insert(location++);
1774 			}
1775 			else
1776 				inputs_in_use.insert(location++);
1777 		}
1778 	}
1779 	else if (is_matrix(type))
1780 	{
1781 		for (uint32_t i = 0; i < type.columns; i++)
1782 			inputs_in_use.insert(location + i);
1783 	}
1784 	else
1785 		inputs_in_use.insert(location);
1786 }
1787 
get_target_components_for_fragment_location(uint32_t location) const1788 uint32_t CompilerMSL::get_target_components_for_fragment_location(uint32_t location) const
1789 {
1790 	auto itr = fragment_output_components.find(location);
1791 	if (itr == end(fragment_output_components))
1792 		return 4;
1793 	else
1794 		return itr->second;
1795 }
1796 
// Builds (and registers in the IR) a copy of the type at type_id with its
// vector size changed to 'components', and optionally a different base type.
// Array and pointer wrappers present on the original type are re-created
// around the new vector type. Returns the ID of the outermost new type.
uint32_t CompilerMSL::build_extended_vector_type(uint32_t type_id, uint32_t components, SPIRType::BaseType basetype)
{
	uint32_t new_type_id = ir.increase_bound_by(1);
	auto &old_type = get<SPIRType>(type_id);
	auto *type = &set<SPIRType>(new_type_id, old_type);
	type->vecsize = components;
	if (basetype != SPIRType::Unknown)
		type->basetype = basetype;
	type->self = new_type_id;
	type->parent_type = type_id;
	// Strip array/pointer traits from the base copy; they are rebuilt below.
	type->array.clear();
	type->array_size_literal.clear();
	type->pointer = false;

	if (is_array(old_type))
	{
		// Re-wrap the new vector type in the original array dimensions.
		uint32_t array_type_id = ir.increase_bound_by(1);
		type = &set<SPIRType>(array_type_id, *type);
		type->parent_type = new_type_id;
		type->array = old_type.array;
		type->array_size_literal = old_type.array_size_literal;
		new_type_id = array_type_id;
	}

	if (old_type.pointer)
	{
		// Re-create the pointer wrapper in the original storage class.
		uint32_t ptr_type_id = ir.increase_bound_by(1);
		type = &set<SPIRType>(ptr_type_id, *type);
		type->self = new_type_id;
		type->parent_type = new_type_id;
		type->storage = old_type.storage;
		type->pointer = true;
		new_type_id = ptr_type_id;
	}

	return new_type_id;
}
1834 
build_msl_interpolant_type(uint32_t type_id,bool is_noperspective)1835 uint32_t CompilerMSL::build_msl_interpolant_type(uint32_t type_id, bool is_noperspective)
1836 {
1837 	uint32_t new_type_id = ir.increase_bound_by(1);
1838 	SPIRType &type = set<SPIRType>(new_type_id, get<SPIRType>(type_id));
1839 	type.basetype = SPIRType::Interpolant;
1840 	type.parent_type = type_id;
1841 	// In Metal, the pull-model interpolant type encodes perspective-vs-no-perspective in the type itself.
1842 	// Add this decoration so we know which argument to pass to the template.
1843 	if (is_noperspective)
1844 		set_decoration(new_type_id, DecorationNoPerspective);
1845 	return new_type_id;
1846 }
1847 
// Adds a "plain" (non-composite) variable as a member of the interface block
// struct for the given storage class. Handles builtin type correction,
// fragment-output component padding, Component-packed shared locations,
// pull-model interpolation wrappers, and output initializers via entry-point
// fixup hooks.
void CompilerMSL::add_plain_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
                                                        SPIRType &ib_type, SPIRVariable &var, InterfaceBlockMeta &meta)
{
	bool is_builtin = is_builtin_variable(var);
	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
	bool is_flat = has_decoration(var.self, DecorationFlat);
	bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
	bool is_centroid = has_decoration(var.self, DecorationCentroid);
	bool is_sample = has_decoration(var.self, DecorationSample);

	// Add a reference to the variable type to the interface struct.
	uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
	uint32_t type_id = ensure_correct_builtin_type(var.basetype, builtin);
	var.basetype = type_id;

	type_id = get_pointee_type_id(var.basetype);
	// When the outer (per-vertex/per-patch) array dimension is stripped,
	// the interface member uses the element type.
	if (meta.strip_array && is_array(get<SPIRType>(type_id)))
		type_id = get<SPIRType>(type_id).parent_type;
	auto &type = get<SPIRType>(type_id);
	uint32_t target_components = 0;
	uint32_t type_components = type.vecsize;

	bool padded_output = false;
	bool padded_input = false;
	uint32_t start_component = 0;

	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);

	// Deal with Component decorations.
	InterfaceBlockMeta::LocationMeta *location_meta = nullptr;
	if (has_decoration(var.self, DecorationLocation))
	{
		auto location_meta_itr = meta.location_meta.find(get_decoration(var.self, DecorationLocation));
		if (location_meta_itr != end(meta.location_meta))
			location_meta = &location_meta_itr->second;
	}

	bool pad_fragment_output = has_decoration(var.self, DecorationLocation) &&
	                           msl_options.pad_fragment_output_components &&
	                           get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;

	// Check if we need to pad fragment output to match a certain number of components.
	if (location_meta)
	{
		// Multiple variables share this location via Component decorations.
		// A single interface member wide enough for all of them is emitted once;
		// each variable swizzles in/out of its component range.
		start_component = get_decoration(var.self, DecorationComponent);
		uint32_t num_components = location_meta->num_components;
		if (pad_fragment_output)
		{
			uint32_t locn = get_decoration(var.self, DecorationLocation);
			num_components = std::max(num_components, get_target_components_for_fragment_location(locn));
		}

		if (location_meta->ib_index != ~0u)
		{
			// We have already declared the variable. Just emit an early-declared variable and fixup as needed.
			entry_func.add_local_variable(var.self);
			vars_needing_early_declaration.push_back(var.self);

			if (var.storage == StorageClassInput)
			{
				uint32_t ib_index = location_meta->ib_index;
				entry_func.fixup_hooks_in.push_back([=, &var]() {
					statement(to_name(var.self), " = ", ib_var_ref, ".", to_member_name(ib_type, ib_index),
					          vector_swizzle(type_components, start_component), ";");
				});
			}
			else
			{
				uint32_t ib_index = location_meta->ib_index;
				entry_func.fixup_hooks_out.push_back([=, &var]() {
					statement(ib_var_ref, ".", to_member_name(ib_type, ib_index),
					          vector_swizzle(type_components, start_component), " = ", to_name(var.self), ";");
				});
			}
			return;
		}
		else
		{
			// First variable at this shared location: claim the member slot and
			// widen the member type to cover all packed components.
			location_meta->ib_index = uint32_t(ib_type.member_types.size());
			type_id = build_extended_vector_type(type_id, num_components);
			if (var.storage == StorageClassInput)
				padded_input = true;
			else
				padded_output = true;
		}
	}
	else if (pad_fragment_output)
	{
		uint32_t locn = get_decoration(var.self, DecorationLocation);
		target_components = get_target_components_for_fragment_location(locn);
		if (type_components < target_components)
		{
			// Make a new type here.
			type_id = build_extended_vector_type(type_id, target_components);
			padded_output = true;
		}
	}

	// Pull-model inputs are wrapped in a Metal interpolant type.
	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
		ib_type.member_types.push_back(build_msl_interpolant_type(type_id, is_noperspective));
	else
		ib_type.member_types.push_back(type_id);

	// Give the member a name
	string mbr_name = ensure_valid_name(to_expression(var.self), "m");
	set_member_name(ib_type.self, ib_mbr_idx, mbr_name);

	// Update the original variable reference to include the structure reference
	string qual_var_name = ib_var_ref + "." + mbr_name;
	// If using pull-model interpolation, need to add a call to the correct interpolation method.
	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
	{
		if (is_centroid)
			qual_var_name += ".interpolate_at_centroid()";
		else if (is_sample)
			qual_var_name += join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
		else
			qual_var_name += ".interpolate_at_center()";
	}

	if (padded_output || padded_input)
	{
		// Interface member is wider than the variable: copy through a locally
		// declared variable with a component swizzle at entry/exit.
		entry_func.add_local_variable(var.self);
		vars_needing_early_declaration.push_back(var.self);

		if (padded_output)
		{
			entry_func.fixup_hooks_out.push_back([=, &var]() {
				statement(qual_var_name, vector_swizzle(type_components, start_component), " = ", to_name(var.self),
				          ";");
			});
		}
		else
		{
			entry_func.fixup_hooks_in.push_back([=, &var]() {
				statement(to_name(var.self), " = ", qual_var_name, vector_swizzle(type_components, start_component),
				          ";");
			});
		}
	}
	else if (!meta.strip_array)
		ir.meta[var.self].decoration.qualified_alias = qual_var_name;

	if (var.storage == StorageClassOutput && var.initializer != ID(0))
	{
		// Emit the output variable's initializer as an assignment at the start
		// of the entry point.
		if (padded_output || padded_input)
		{
			entry_func.fixup_hooks_in.push_back(
			    [=, &var]() { statement(to_name(var.self), " = ", to_expression(var.initializer), ";"); });
		}
		else
		{
			if (meta.strip_array)
			{
				// Tessellation control: initialize only this invocation's element of gl_out.
				entry_func.fixup_hooks_in.push_back([=, &var]() {
					uint32_t index = get_extended_decoration(var.self, SPIRVCrossDecorationInterfaceMemberIndex);
					statement(to_expression(stage_out_ptr_var_id), "[",
					          builtin_to_glsl(BuiltInInvocationId, StorageClassInput), "].",
					          to_member_name(ib_type, index), " = ", to_expression(var.initializer), "[",
					          builtin_to_glsl(BuiltInInvocationId, StorageClassInput), "];");
				});
			}
			else
			{
				entry_func.fixup_hooks_in.push_back([=, &var]() {
					statement(qual_var_name, " = ", to_expression(var.initializer), ";");
				});
			}
		}
	}

	// Copy the variable location from the original variable to the member
	if (get_decoration_bitset(var.self).get(DecorationLocation))
	{
		uint32_t locn = get_decoration(var.self, DecorationLocation);
		if (storage == StorageClassInput)
		{
			// The input type may need adjustment to match the declared vertex
			// attribute format, so re-resolve the member type here.
			type_id = ensure_correct_input_type(var.basetype, locn, location_meta ? location_meta->num_components : 0);
			if (!location_meta)
				var.basetype = type_id;

			type_id = get_pointee_type_id(type_id);
			if (meta.strip_array && is_array(get<SPIRType>(type_id)))
				type_id = get<SPIRType>(type_id).parent_type;
			if (pull_model_inputs.count(var.self))
				ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id, is_noperspective);
			else
				ib_type.member_types[ib_mbr_idx] = type_id;
		}
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
		mark_location_as_used_by_shader(locn, get<SPIRType>(type_id), storage);
	}
	else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
	{
		// Tessellation builtins supplied via attributes use the location the
		// API client registered for them.
		uint32_t locn = inputs_by_builtin[builtin].location;
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
		mark_location_as_used_by_shader(locn, type, storage);
	}

	if (!location_meta)
	{
		if (get_decoration_bitset(var.self).get(DecorationComponent))
		{
			uint32_t component = get_decoration(var.self, DecorationComponent);
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, component);
		}
	}

	if (get_decoration_bitset(var.self).get(DecorationIndex))
	{
		uint32_t index = get_decoration(var.self, DecorationIndex);
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
	}

	// Mark the member as builtin if needed
	if (is_builtin)
	{
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
		if (builtin == BuiltInPosition && storage == StorageClassOutput)
			qual_pos_var_name = qual_var_name;
	}

	// Copy interpolation decorations if needed
	// (pull-model inputs express these through the interpolant calls instead).
	if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
	{
		if (is_flat)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
		if (is_noperspective)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
		if (is_centroid)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
		if (is_sample)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
	}

	// If we have location meta, there is no unique OrigID. We won't need it, since we flatten/unflatten
	// the variable to stack anyways here.
	if (!location_meta)
		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
}
2088 
// Adds a composite (array or matrix) variable to the interface block struct.
// The composite is split into one interface member per element/column, since
// MSL stage I/O cannot carry these directly; entry-point fixup hooks
// flatten/unflatten between the interface members and a local variable.
void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
                                                            SPIRType &ib_type, SPIRVariable &var,
                                                            InterfaceBlockMeta &meta)
{
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
	uint32_t elem_cnt = 0;

	if (is_matrix(var_type))
	{
		if (is_array(var_type))
			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");

		elem_cnt = var_type.columns;
	}
	else if (is_array(var_type))
	{
		if (var_type.array.size() != 1)
			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");

		elem_cnt = to_array_size_literal(var_type);
	}

	bool is_builtin = is_builtin_variable(var);
	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
	bool is_flat = has_decoration(var.self, DecorationFlat);
	bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
	bool is_centroid = has_decoration(var.self, DecorationCentroid);
	bool is_sample = has_decoration(var.self, DecorationSample);

	// Strip pointer/array/matrix wrappers to find the element type used for
	// each flattened interface member.
	auto *usable_type = &var_type;
	if (usable_type->pointer)
		usable_type = &get<SPIRType>(usable_type->parent_type);
	while (is_array(*usable_type) || is_matrix(*usable_type))
		usable_type = &get<SPIRType>(usable_type->parent_type);

	// If a builtin, force it to have the proper name.
	if (is_builtin)
		set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));

	bool flatten_from_ib_var = false;
	string flatten_from_ib_mbr_name;

	if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
	{
		// Also declare [[clip_distance]] attribute here.
		uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
		ib_type.member_types.push_back(get_variable_data_type_id(var));
		set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);

		flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
		set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);

		// When we flatten, we flatten directly from the "out" struct,
		// not from a function variable.
		flatten_from_ib_var = true;

		if (!msl_options.enable_clip_distance_user_varying)
			return;
	}
	else if (!meta.strip_array)
	{
		// Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
		entry_func.add_local_variable(var.self);
		// We need to declare the variable early and at entry-point scope.
		vars_needing_early_declaration.push_back(var.self);
	}

	// Emit one interface member per element (array element or matrix column).
	for (uint32_t i = 0; i < elem_cnt; i++)
	{
		// Add a reference to the variable type to the interface struct.
		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());

		uint32_t target_components = 0;
		bool padded_output = false;
		uint32_t type_id = usable_type->self;

		// Check if we need to pad fragment output to match a certain number of components.
		if (get_decoration_bitset(var.self).get(DecorationLocation) && msl_options.pad_fragment_output_components &&
		    get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput)
		{
			uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
			target_components = get_target_components_for_fragment_location(locn);
			if (usable_type->vecsize < target_components)
			{
				// Make a new type here.
				type_id = build_extended_vector_type(usable_type->self, target_components);
				padded_output = true;
			}
		}

		// Pull-model inputs are wrapped in a Metal interpolant type.
		if (storage == StorageClassInput && pull_model_inputs.count(var.self))
			ib_type.member_types.push_back(build_msl_interpolant_type(get_pointee_type_id(type_id), is_noperspective));
		else
			ib_type.member_types.push_back(get_pointee_type_id(type_id));

		// Give the member a name
		string mbr_name = ensure_valid_name(join(to_expression(var.self), "_", i), "m");
		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);

		// There is no qualified alias since we need to flatten the internal array on return.
		if (get_decoration_bitset(var.self).get(DecorationLocation))
		{
			uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
			if (storage == StorageClassInput)
			{
				// Re-resolve the input type to match the declared attribute format.
				var.basetype = ensure_correct_input_type(var.basetype, locn);
				uint32_t mbr_type_id = ensure_correct_input_type(usable_type->self, locn);
				if (storage == StorageClassInput && pull_model_inputs.count(var.self))
					ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
				else
					ib_type.member_types[ib_mbr_idx] = mbr_type_id;
			}
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, *usable_type, storage);
		}
		else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
		{
			// Tessellation builtins supplied via attributes use the location
			// registered by the API client, offset per element.
			uint32_t locn = inputs_by_builtin[builtin].location + i;
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, *usable_type, storage);
		}
		else if (is_builtin && builtin == BuiltInClipDistance)
		{
			// Declare the ClipDistance as [[user(clipN)]].
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, i);
		}

		if (get_decoration_bitset(var.self).get(DecorationIndex))
		{
			uint32_t index = get_decoration(var.self, DecorationIndex);
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
		}

		if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
		{
			// Copy interpolation decorations if needed
			if (is_flat)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
			if (is_noperspective)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
			if (is_centroid)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
			if (is_sample)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
		}

		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);

		// Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
		if (!meta.strip_array)
		{
			switch (storage)
			{
			case StorageClassInput:
				entry_func.fixup_hooks_in.push_back([=, &var]() {
					if (pull_model_inputs.count(var.self))
					{
						// Pull-model input: explicitly invoke the interpolation method.
						string lerp_call;
						if (is_centroid)
							lerp_call = ".interpolate_at_centroid()";
						else if (is_sample)
							lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
						else
							lerp_call = ".interpolate_at_center()";
						statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, lerp_call, ";");
					}
					else
					{
						statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, ";");
					}
				});
				break;

			case StorageClassOutput:
				entry_func.fixup_hooks_out.push_back([=, &var]() {
					if (padded_output)
					{
						// Swizzle the narrower variable into the padded member.
						auto &padded_type = this->get<SPIRType>(type_id);
						statement(
						    ib_var_ref, ".", mbr_name, " = ",
						    remap_swizzle(padded_type, usable_type->vecsize, join(to_name(var.self), "[", i, "]")),
						    ";");
					}
					else if (flatten_from_ib_var)
						statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
						          "];");
					else
						statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), "[", i, "];");
				});
				break;

			default:
				break;
			}
		}
	}
}
2288 
get_accumulated_member_location(const SPIRVariable & var,uint32_t mbr_idx,bool strip_array)2289 uint32_t CompilerMSL::get_accumulated_member_location(const SPIRVariable &var, uint32_t mbr_idx, bool strip_array)
2290 {
2291 	auto &type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2292 	uint32_t location = get_decoration(var.self, DecorationLocation);
2293 
2294 	for (uint32_t i = 0; i < mbr_idx; i++)
2295 	{
2296 		auto &mbr_type = get<SPIRType>(type.member_types[i]);
2297 
2298 		// Start counting from any place we have a new location decoration.
2299 		if (has_member_decoration(type.self, mbr_idx, DecorationLocation))
2300 			location = get_member_decoration(type.self, mbr_idx, DecorationLocation);
2301 
2302 		uint32_t location_count = 1;
2303 
2304 		if (mbr_type.columns > 1)
2305 			location_count = mbr_type.columns;
2306 
2307 		if (!mbr_type.array.empty())
2308 			for (uint32_t j = 0; j < uint32_t(mbr_type.array.size()); j++)
2309 				location_count *= to_array_size_literal(mbr_type, j);
2310 
2311 		location += location_count;
2312 	}
2313 
2314 	return location;
2315 }
2316 
// Adds a composite member (matrix or one-dimensional array) of an I/O struct
// variable to the interface block. MSL does not allow matrices or arrays in
// stage-in/stage-out variables, so the member is flattened into one interface
// member per matrix column or array element, and entry-point fixup hooks are
// emitted to copy those pieces to or from the original composite member.
void CompilerMSL::add_composite_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
                                                                   SPIRType &ib_type, SPIRVariable &var,
                                                                   uint32_t mbr_idx, InterfaceBlockMeta &meta)
{
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);

	BuiltIn builtin = BuiltInMax;
	bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
	// Interpolation qualifiers may be declared on either the member or the variable itself.
	bool is_flat =
	    has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
	bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
	                        has_decoration(var.self, DecorationNoPerspective);
	bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
	                   has_decoration(var.self, DecorationCentroid);
	bool is_sample =
	    has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);

	uint32_t mbr_type_id = var_type.member_types[mbr_idx];
	auto &mbr_type = get<SPIRType>(mbr_type_id);
	// Number of flattened interface members: one per matrix column or array element.
	uint32_t elem_cnt = 0;

	if (is_matrix(mbr_type))
	{
		if (is_array(mbr_type))
			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");

		elem_cnt = mbr_type.columns;
	}
	else if (is_array(mbr_type))
	{
		if (mbr_type.array.size() != 1)
			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");

		elem_cnt = to_array_size_literal(mbr_type);
	}

	// Strip pointer, array and matrix layers to find the scalar/vector type
	// used to declare each flattened interface member.
	auto *usable_type = &mbr_type;
	if (usable_type->pointer)
		usable_type = &get<SPIRType>(usable_type->parent_type);
	while (is_array(*usable_type) || is_matrix(*usable_type))
		usable_type = &get<SPIRType>(usable_type->parent_type);

	bool flatten_from_ib_var = false;
	string flatten_from_ib_mbr_name;

	if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
	{
		// Also declare [[clip_distance]] attribute here.
		uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
		ib_type.member_types.push_back(mbr_type_id);
		set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);

		flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
		set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);

		// When we flatten, we flatten directly from the "out" struct,
		// not from a function variable.
		flatten_from_ib_var = true;

		if (!msl_options.enable_clip_distance_user_varying)
			return;
	}

	// Emit one interface member (and matching fixup hook) per flattened element.
	for (uint32_t i = 0; i < elem_cnt; i++)
	{
		// Add a reference to the variable type to the interface struct.
		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
		if (storage == StorageClassInput && pull_model_inputs.count(var.self))
			ib_type.member_types.push_back(build_msl_interpolant_type(usable_type->self, is_noperspective));
		else
			ib_type.member_types.push_back(usable_type->self);

		// Give the member a name
		string mbr_name = ensure_valid_name(join(to_qualified_member_name(var_type, mbr_idx), "_", i), "m");
		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);

		// Each flattened element consumes one location, counted up from the base.
		if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
		{
			uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation) + i;
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, *usable_type, storage);
		}
		else if (has_decoration(var.self, DecorationLocation))
		{
			// The block itself carries the location; members receive accumulated locations.
			uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array) + i;
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, *usable_type, storage);
		}
		else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
		{
			// Builtin tessellation inputs take their base location from the input map.
			uint32_t locn = inputs_by_builtin[builtin].location + i;
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, *usable_type, storage);
		}
		else if (is_builtin && builtin == BuiltInClipDistance)
		{
			// Declare the ClipDistance as [[user(clipN)]].
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, i);
		}

		if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
			SPIRV_CROSS_THROW("DecorationComponent on matrices and arrays make little sense.");

		// Pull-model inputs carry interpolation in the interpolant type itself,
		// so decorations are only copied for the non-pull-model case.
		if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
		{
			// Copy interpolation decorations if needed
			if (is_flat)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
			if (is_noperspective)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
			if (is_centroid)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
			if (is_sample)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
		}

		// Record where this interface member came from so member indices can be
		// redirected later (see fix_up_interface_member_indices).
		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);

		// Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
		// NOTE: i and mbr_name are captured by value into the hooks below.
		if (!meta.strip_array)
		{
			switch (storage)
			{
			case StorageClassInput:
				// Copy element i from the stage-in struct into the local composite.
				entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
					if (pull_model_inputs.count(var.self))
					{
						// Pull-model: invoke the matching interpolation function.
						string lerp_call;
						if (is_centroid)
							lerp_call = ".interpolate_at_centroid()";
						else if (is_sample)
							lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
						else
							lerp_call = ".interpolate_at_center()";
						statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
						          ".", mbr_name, lerp_call, ";");
					}
					else
					{
						statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
						          ".", mbr_name, ";");
					}
				});
				break;

			case StorageClassOutput:
				// Copy element i into the flattened stage-out member, either from the
				// builtin clip-distance array in the out struct, or from the local composite.
				entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
					if (flatten_from_ib_var)
					{
						statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
						          "];");
					}
					else
					{
						statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), ".",
						          to_member_name(var_type, mbr_idx), "[", i, "];");
					}
				});
				break;

			default:
				break;
			}
		}
	}
}
2486 
add_plain_member_variable_to_interface_block(StorageClass storage,const string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var,uint32_t mbr_idx,InterfaceBlockMeta & meta)2487 void CompilerMSL::add_plain_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2488                                                                SPIRType &ib_type, SPIRVariable &var, uint32_t mbr_idx,
2489                                                                InterfaceBlockMeta &meta)
2490 {
2491 	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2492 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2493 
2494 	BuiltIn builtin = BuiltInMax;
2495 	bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
2496 	bool is_flat =
2497 	    has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
2498 	bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
2499 	                        has_decoration(var.self, DecorationNoPerspective);
2500 	bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
2501 	                   has_decoration(var.self, DecorationCentroid);
2502 	bool is_sample =
2503 	    has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
2504 
2505 	// Add a reference to the member to the interface struct.
2506 	uint32_t mbr_type_id = var_type.member_types[mbr_idx];
2507 	uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2508 	mbr_type_id = ensure_correct_builtin_type(mbr_type_id, builtin);
2509 	var_type.member_types[mbr_idx] = mbr_type_id;
2510 	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2511 		ib_type.member_types.push_back(build_msl_interpolant_type(mbr_type_id, is_noperspective));
2512 	else
2513 		ib_type.member_types.push_back(mbr_type_id);
2514 
2515 	// Give the member a name
2516 	string mbr_name = ensure_valid_name(to_qualified_member_name(var_type, mbr_idx), "m");
2517 	set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2518 
2519 	// Update the original variable reference to include the structure reference
2520 	string qual_var_name = ib_var_ref + "." + mbr_name;
2521 	// If using pull-model interpolation, need to add a call to the correct interpolation method.
2522 	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2523 	{
2524 		if (is_centroid)
2525 			qual_var_name += ".interpolate_at_centroid()";
2526 		else if (is_sample)
2527 			qual_var_name += join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
2528 		else
2529 			qual_var_name += ".interpolate_at_center()";
2530 	}
2531 
2532 	bool flatten_stage_out = false;
2533 
2534 	if (is_builtin && !meta.strip_array)
2535 	{
2536 		// For the builtin gl_PerVertex, we cannot treat it as a block anyways,
2537 		// so redirect to qualified name.
2538 		set_member_qualified_name(var_type.self, mbr_idx, qual_var_name);
2539 	}
2540 	else if (!meta.strip_array)
2541 	{
2542 		// Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
2543 		switch (storage)
2544 		{
2545 		case StorageClassInput:
2546 			entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
2547 				statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), " = ", qual_var_name, ";");
2548 			});
2549 			break;
2550 
2551 		case StorageClassOutput:
2552 			flatten_stage_out = true;
2553 			entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
2554 				statement(qual_var_name, " = ", to_name(var.self), ".", to_member_name(var_type, mbr_idx), ";");
2555 			});
2556 			break;
2557 
2558 		default:
2559 			break;
2560 		}
2561 	}
2562 
2563 	// Copy the variable location from the original variable to the member
2564 	if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
2565 	{
2566 		uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation);
2567 		if (storage == StorageClassInput)
2568 		{
2569 			mbr_type_id = ensure_correct_input_type(mbr_type_id, locn);
2570 			var_type.member_types[mbr_idx] = mbr_type_id;
2571 			if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2572 				ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
2573 			else
2574 				ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2575 		}
2576 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2577 		mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2578 	}
2579 	else if (has_decoration(var.self, DecorationLocation))
2580 	{
2581 		// The block itself might have a location and in this case, all members of the block
2582 		// receive incrementing locations.
2583 		uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array);
2584 		if (storage == StorageClassInput)
2585 		{
2586 			mbr_type_id = ensure_correct_input_type(mbr_type_id, locn);
2587 			var_type.member_types[mbr_idx] = mbr_type_id;
2588 			if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2589 				ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
2590 			else
2591 				ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2592 		}
2593 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2594 		mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2595 	}
2596 	else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2597 	{
2598 		uint32_t locn = 0;
2599 		auto builtin_itr = inputs_by_builtin.find(builtin);
2600 		if (builtin_itr != end(inputs_by_builtin))
2601 			locn = builtin_itr->second.location;
2602 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2603 		mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2604 	}
2605 
2606 	// Copy the component location, if present.
2607 	if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
2608 	{
2609 		uint32_t comp = get_member_decoration(var_type.self, mbr_idx, DecorationComponent);
2610 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
2611 	}
2612 
2613 	// Mark the member as builtin if needed
2614 	if (is_builtin)
2615 	{
2616 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2617 		if (builtin == BuiltInPosition && storage == StorageClassOutput)
2618 			qual_pos_var_name = qual_var_name;
2619 	}
2620 
2621 	const SPIRConstant *c = nullptr;
2622 	if (!flatten_stage_out && var.storage == StorageClassOutput &&
2623 	    var.initializer != ID(0) && (c = maybe_get<SPIRConstant>(var.initializer)))
2624 	{
2625 		if (meta.strip_array)
2626 		{
2627 			entry_func.fixup_hooks_in.push_back([=, &var]() {
2628 				auto &type = this->get<SPIRType>(var.basetype);
2629 				uint32_t index = get_extended_decoration(var.self, SPIRVCrossDecorationInterfaceMemberIndex);
2630 				index += mbr_idx;
2631 
2632 				AccessChainMeta chain_meta;
2633 				auto constant_chain = access_chain_internal(var.initializer, &builtin_invocation_id_id, 1, 0, &chain_meta);
2634 
2635 				statement(to_expression(stage_out_ptr_var_id), "[",
2636 				          builtin_to_glsl(BuiltInInvocationId, StorageClassInput), "].",
2637 				          to_member_name(ib_type, index), " = ",
2638 				          constant_chain, ".", to_member_name(type, mbr_idx), ";");
2639 			});
2640 		}
2641 		else
2642 		{
2643 			entry_func.fixup_hooks_in.push_back([=]() {
2644 				statement(qual_var_name, " = ", constant_expression(
2645 						this->get<SPIRConstant>(c->subconstants[mbr_idx])), ";");
2646 			});
2647 		}
2648 	}
2649 
2650 	if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
2651 	{
2652 		// Copy interpolation decorations if needed
2653 		if (is_flat)
2654 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2655 		if (is_noperspective)
2656 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2657 		if (is_centroid)
2658 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2659 		if (is_sample)
2660 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2661 	}
2662 
2663 	set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2664 	set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
2665 }
2666 
2667 // In Metal, the tessellation levels are stored as tightly packed half-precision floating point values.
2668 // But, stage-in attribute offsets and strides must be multiples of four, so we can't pass the levels
2669 // individually. Therefore, we must pass them as vectors. Triangles get a single float4, with the outer
2670 // levels in 'xyz' and the inner level in 'w'. Quads get a float4 containing the outer levels and a
2671 // float2 containing the inner levels.
void CompilerMSL::add_tess_level_input_to_interface_block(const std::string &ib_var_ref, SPIRType &ib_type,
                                                          SPIRVariable &var)
{
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	auto &var_type = get_variable_element_type(var);

	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));

	// Force the variable to have the proper name.
	set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));

	if (get_entry_point().flags.get(ExecutionModeTriangles))
	{
		// Triangles are tricky, because we want only one member in the struct.
		// Outer levels go in 'xyz' and the inner level in 'w' of a shared float4,
		// so the local gl_TessLevel* array must be unpacked element by element.

		// We need to declare the variable early and at entry-point scope.
		entry_func.add_local_variable(var.self);
		vars_needing_early_declaration.push_back(var.self);

		string mbr_name = "gl_TessLevel";

		// If we already added the other one, we can skip this step.
		if (!added_builtin_tess_level)
		{
			// Add a reference to the variable type to the interface struct.
			uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());

			// Widen to a 4-component vector holding both outer and inner levels.
			uint32_t type_id = build_extended_vector_type(var_type.self, 4);

			ib_type.member_types.push_back(type_id);

			// Give the member a name
			set_member_name(ib_type.self, ib_mbr_idx, mbr_name);

			// There is no qualified alias since we need to flatten the internal array on return.
			if (get_decoration_bitset(var.self).get(DecorationLocation))
			{
				uint32_t locn = get_decoration(var.self, DecorationLocation);
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
				mark_location_as_used_by_shader(locn, var_type, StorageClassInput);
			}
			else if (inputs_by_builtin.count(builtin))
			{
				// Fall back to the location recorded for this builtin in the input map.
				uint32_t locn = inputs_by_builtin[builtin].location;
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
				mark_location_as_used_by_shader(locn, var_type, StorageClassInput);
			}

			added_builtin_tess_level = true;
		}

		// Unpack the shared float4 into the local array at entry-point start.
		switch (builtin)
		{
		case BuiltInTessLevelOuter:
			entry_func.fixup_hooks_in.push_back([=, &var]() {
				statement(to_name(var.self), "[0] = ", ib_var_ref, ".", mbr_name, ".x;");
				statement(to_name(var.self), "[1] = ", ib_var_ref, ".", mbr_name, ".y;");
				statement(to_name(var.self), "[2] = ", ib_var_ref, ".", mbr_name, ".z;");
			});
			break;

		case BuiltInTessLevelInner:
			entry_func.fixup_hooks_in.push_back(
			    [=, &var]() { statement(to_name(var.self), "[0] = ", ib_var_ref, ".", mbr_name, ".w;"); });
			break;

		default:
			// This function is only called for TessLevelOuter/TessLevelInner.
			assert(false);
			break;
		}
	}
	else
	{
		// Quad domain: outer levels get a float4 member and inner levels a float2.

		// Add a reference to the variable type to the interface struct.
		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());

		uint32_t type_id = build_extended_vector_type(var_type.self, builtin == BuiltInTessLevelOuter ? 4 : 2);
		// Change the type of the variable, too.
		uint32_t ptr_type_id = ir.increase_bound_by(1);
		auto &new_var_type = set<SPIRType>(ptr_type_id, get<SPIRType>(type_id));
		new_var_type.pointer = true;
		new_var_type.storage = StorageClassInput;
		new_var_type.parent_type = type_id;
		var.basetype = ptr_type_id;

		ib_type.member_types.push_back(type_id);

		// Give the member a name
		string mbr_name = to_expression(var.self);
		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);

		// Since vectors can be indexed like arrays, there is no need to unpack this. We can
		// just refer to the vector directly. So give it a qualified alias.
		string qual_var_name = ib_var_ref + "." + mbr_name;
		ir.meta[var.self].decoration.qualified_alias = qual_var_name;

		if (get_decoration_bitset(var.self).get(DecorationLocation))
		{
			uint32_t locn = get_decoration(var.self, DecorationLocation);
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
		}
		else if (inputs_by_builtin.count(builtin))
		{
			// Fall back to the location recorded for this builtin in the input map.
			uint32_t locn = inputs_by_builtin[builtin].location;
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
		}
	}
}
2782 
// Adds a single I/O variable to the interface block for the given storage
// class, dispatching to the appropriate handler: structs are flattened member
// by member, tess-level inputs get dedicated packing, and scalar/vector/
// composite variables are added as plain or flattened-composite members.
void CompilerMSL::add_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, SPIRType &ib_type,
                                                  SPIRVariable &var, InterfaceBlockMeta &meta)
{
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	// Tessellation control I/O variables and tessellation evaluation per-point inputs are
	// usually declared as arrays. In these cases, we want to add the element type to the
	// interface block, since in Metal it's the interface block itself which is arrayed.
	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
	bool is_builtin = is_builtin_variable(var);
	auto builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));

	if (var_type.basetype == SPIRType::Struct)
	{
		if (!is_builtin_type(var_type) && (!capture_output_to_buffer || storage == StorageClassInput) &&
		    !meta.strip_array)
		{
			// For I/O blocks or structs, we will need to pass the block itself around
			// to functions if they are used globally in leaf functions.
			// Rather than passing down member by member,
			// we unflatten I/O blocks while running the shader,
			// and pass the actual struct type down to leaf functions.
			// We then unflatten inputs, and flatten outputs in the "fixup" stages.
			entry_func.add_local_variable(var.self);
			vars_needing_early_declaration.push_back(var.self);
		}

		if (capture_output_to_buffer && storage != StorageClassInput && !has_decoration(var_type.self, DecorationBlock))
		{
			// In Metal tessellation shaders, the interface block itself is arrayed. This makes things
			// very complicated, since stage-in structures in MSL don't support nested structures.
			// Luckily, for stage-out when capturing output, we can avoid this and just add
			// composite members directly, because the stage-out structure is stored to a buffer,
			// not returned.
			add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
		}
		else
		{
			// Flatten the struct members into the interface struct
			for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
			{
				// NOTE: builtin/is_builtin are re-evaluated per member here,
				// shadowing the variable-level values computed above.
				builtin = BuiltInMax;
				is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
				auto &mbr_type = get<SPIRType>(var_type.member_types[mbr_idx]);

				if (!is_builtin || has_active_builtin(builtin, storage))
				{
					bool is_composite_type = is_matrix(mbr_type) || is_array(mbr_type);
					bool attribute_load_store =
					    storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
					bool storage_is_stage_io =
					    (storage == StorageClassInput && !(get_execution_model() == ExecutionModelTessellationControl &&
					                                       msl_options.multi_patch_workgroup)) ||
					    storage == StorageClassOutput;

					// ClipDistance always needs to be declared as user attributes.
					if (builtin == BuiltInClipDistance)
						is_builtin = false;

					// MSL does not allow matrices or arrays in I/O members, so composites
					// must be flattened into one member per element.
					if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
					{
						add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx,
						                                                 meta);
					}
					else
					{
						add_plain_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx, meta);
					}
				}
			}
		}
	}
	else if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput &&
	         !meta.strip_array && is_builtin && (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner))
	{
		// Tess levels are packed into vectors; see add_tess_level_input_to_interface_block.
		add_tess_level_input_to_interface_block(ib_var_ref, ib_type, var);
	}
	else if (var_type.basetype == SPIRType::Boolean || var_type.basetype == SPIRType::Char ||
	         type_is_integral(var_type) || type_is_floating_point(var_type))
	{
		if (!is_builtin || has_active_builtin(builtin, storage))
		{
			bool is_composite_type = is_matrix(var_type) || is_array(var_type);
			bool storage_is_stage_io =
			    (storage == StorageClassInput &&
			     !(get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup)) ||
			    (storage == StorageClassOutput && !capture_output_to_buffer);
			bool attribute_load_store = storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;

			// ClipDistance always needs to be declared as user attributes.
			if (builtin == BuiltInClipDistance)
				is_builtin = false;

			// MSL does not allow matrices or arrays in input or output variables, so need to handle it specially.
			if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
			{
				add_composite_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
			}
			else
			{
				add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
			}
		}
	}
}
2887 
2888 // Fix up the mapping of variables to interface member indices, which is used to compile access chains
2889 // for per-vertex variables in a tessellation control shader.
fix_up_interface_member_indices(StorageClass storage,uint32_t ib_type_id)2890 void CompilerMSL::fix_up_interface_member_indices(StorageClass storage, uint32_t ib_type_id)
2891 {
2892 	// Only needed for tessellation shaders and pull-model interpolants.
2893 	// Need to redirect interface indices back to variables themselves.
2894 	// For structs, each member of the struct need a separate instance.
2895 	if (get_execution_model() != ExecutionModelTessellationControl &&
2896 	    !(get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput) &&
2897 	    !(get_execution_model() == ExecutionModelFragment && storage == StorageClassInput &&
2898 	      !pull_model_inputs.empty()))
2899 		return;
2900 
2901 	auto mbr_cnt = uint32_t(ir.meta[ib_type_id].members.size());
2902 	for (uint32_t i = 0; i < mbr_cnt; i++)
2903 	{
2904 		uint32_t var_id = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceOrigID);
2905 		if (!var_id)
2906 			continue;
2907 		auto &var = get<SPIRVariable>(var_id);
2908 
2909 		auto &type = get_variable_element_type(var);
2910 		if (storage == StorageClassInput && type.basetype == SPIRType::Struct)
2911 		{
2912 			uint32_t mbr_idx = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceMemberIndex);
2913 
2914 			// Only set the lowest InterfaceMemberIndex for each variable member.
2915 			// IB struct members will be emitted in-order w.r.t. interface member index.
2916 			if (!has_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex))
2917 				set_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, i);
2918 		}
2919 		else
2920 		{
2921 			// Only set the lowest InterfaceMemberIndex for each variable.
2922 			// IB struct members will be emitted in-order w.r.t. interface member index.
2923 			if (!has_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex))
2924 				set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i);
2925 		}
2926 	}
2927 }
2928 
2929 // Add an interface structure for the type of storage, which is either StorageClassInput or StorageClassOutput.
2930 // Returns the ID of the newly added variable, or zero if no variable was added.
add_interface_block(StorageClass storage,bool patch)2931 uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch)
2932 {
2933 	// Accumulate the variables that should appear in the interface struct.
2934 	SmallVector<SPIRVariable *> vars;
2935 	bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader();
2936 	bool has_seen_barycentric = false;
2937 
2938 	InterfaceBlockMeta meta;
2939 
2940 	// Varying interfaces between stages which use "user()" attribute can be dealt with
2941 	// without explicit packing and unpacking of components. For any variables which link against the runtime
2942 	// in some way (vertex attributes, fragment output, etc), we'll need to deal with it somehow.
2943 	bool pack_components =
2944 	    (storage == StorageClassInput && get_execution_model() == ExecutionModelVertex) ||
2945 	    (storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment) ||
2946 	    (storage == StorageClassOutput && get_execution_model() == ExecutionModelVertex && capture_output_to_buffer);
2947 
2948 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
2949 		if (var.storage != storage)
2950 			return;
2951 
2952 		auto &type = this->get<SPIRType>(var.basetype);
2953 
2954 		bool is_builtin = is_builtin_variable(var);
2955 		auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
2956 		uint32_t location = get_decoration(var_id, DecorationLocation);
2957 
2958 		// These builtins are part of the stage in/out structs.
2959 		bool is_interface_block_builtin =
2960 		    (bi_type == BuiltInPosition || bi_type == BuiltInPointSize || bi_type == BuiltInClipDistance ||
2961 		     bi_type == BuiltInCullDistance || bi_type == BuiltInLayer || bi_type == BuiltInViewportIndex ||
2962 		     bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV || bi_type == BuiltInFragDepth ||
2963 		     bi_type == BuiltInFragStencilRefEXT || bi_type == BuiltInSampleMask) ||
2964 		    (get_execution_model() == ExecutionModelTessellationEvaluation &&
2965 		     (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner));
2966 
2967 		bool is_active = interface_variable_exists_in_entry_point(var.self);
2968 		if (is_builtin && is_active)
2969 		{
2970 			// Only emit the builtin if it's active in this entry point. Interface variable list might lie.
2971 			is_active = has_active_builtin(bi_type, storage);
2972 		}
2973 
2974 		bool filter_patch_decoration = (has_decoration(var_id, DecorationPatch) || is_patch_block(type)) == patch;
2975 
2976 		bool hidden = is_hidden_variable(var, incl_builtins);
2977 
2978 		// ClipDistance is never hidden, we need to emulate it when used as an input.
2979 		if (bi_type == BuiltInClipDistance)
2980 			hidden = false;
2981 
2982 		// It's not enough to simply avoid marking fragment outputs if the pipeline won't
2983 		// accept them. We can't put them in the struct at all, or otherwise the compiler
2984 		// complains that the outputs weren't explicitly marked.
2985 		if (get_execution_model() == ExecutionModelFragment && storage == StorageClassOutput && !patch &&
2986 		    ((is_builtin && ((bi_type == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
2987 		                     (bi_type == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))) ||
2988 		     (!is_builtin && !(msl_options.enable_frag_output_mask & (1 << location)))))
2989 		{
2990 			hidden = true;
2991 			disabled_frag_outputs.push_back(var_id);
2992 			// If a builtin, force it to have the proper name.
2993 			if (is_builtin)
2994 				set_name(var_id, builtin_to_glsl(bi_type, StorageClassFunction));
2995 		}
2996 
2997 		// Barycentric inputs must be emitted in stage-in, because they can have interpolation arguments.
2998 		if (is_active && (bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV))
2999 		{
3000 			if (has_seen_barycentric)
3001 				SPIRV_CROSS_THROW("Cannot declare both BaryCoordNV and BaryCoordNoPerspNV in same shader in MSL.");
3002 			has_seen_barycentric = true;
3003 			hidden = false;
3004 		}
3005 
3006 		if (is_active && !hidden && type.pointer && filter_patch_decoration &&
3007 		    (!is_builtin || is_interface_block_builtin))
3008 		{
3009 			vars.push_back(&var);
3010 
3011 			if (!is_builtin)
3012 			{
3013 				// Need to deal specially with DecorationComponent.
3014 				// Multiple variables can alias the same Location, and try to make sure each location is declared only once.
3015 				// We will swizzle data in and out to make this work.
3016 				// We only need to consider plain variables here, not composites.
3017 				// This is only relevant for vertex inputs and fragment outputs.
3018 				// Technically tessellation as well, but it is too complicated to support.
3019 				uint32_t component = get_decoration(var_id, DecorationComponent);
3020 				if (component != 0)
3021 				{
3022 					if (is_tessellation_shader())
3023 						SPIRV_CROSS_THROW("Component decoration is not supported in tessellation shaders.");
3024 					else if (pack_components)
3025 					{
3026 						auto &location_meta = meta.location_meta[location];
3027 						location_meta.num_components = std::max(location_meta.num_components, component + type.vecsize);
3028 					}
3029 				}
3030 			}
3031 		}
3032 	});
3033 
3034 	// If no variables qualify, leave.
3035 	// For patch input in a tessellation evaluation shader, the per-vertex stage inputs
3036 	// are included in a special patch control point array.
3037 	if (vars.empty() && !(storage == StorageClassInput && patch && stage_in_var_id))
3038 		return 0;
3039 
3040 	// Add a new typed variable for this interface structure.
3041 	// The initializer expression is allocated here, but populated when the function
3042 	// declaraion is emitted, because it is cleared after each compilation pass.
3043 	uint32_t next_id = ir.increase_bound_by(3);
3044 	uint32_t ib_type_id = next_id++;
3045 	auto &ib_type = set<SPIRType>(ib_type_id);
3046 	ib_type.basetype = SPIRType::Struct;
3047 	ib_type.storage = storage;
3048 	set_decoration(ib_type_id, DecorationBlock);
3049 
3050 	uint32_t ib_var_id = next_id++;
3051 	auto &var = set<SPIRVariable>(ib_var_id, ib_type_id, storage, 0);
3052 	var.initializer = next_id++;
3053 
3054 	string ib_var_ref;
3055 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
3056 	switch (storage)
3057 	{
3058 	case StorageClassInput:
3059 		ib_var_ref = patch ? patch_stage_in_var_name : stage_in_var_name;
3060 		if (get_execution_model() == ExecutionModelTessellationControl)
3061 		{
3062 			// Add a hook to populate the shared workgroup memory containing the gl_in array.
3063 			entry_func.fixup_hooks_in.push_back([=]() {
3064 				// Can't use PatchVertices, PrimitiveId, or InvocationId yet; the hooks for those may not have run yet.
3065 				if (msl_options.multi_patch_workgroup)
3066 				{
3067 					// n.b. builtin_invocation_id_id here is the dispatch global invocation ID,
3068 					// not the TC invocation ID.
3069 					statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_in = &",
3070 					          input_buffer_var_name, "[min(", to_expression(builtin_invocation_id_id), ".x / ",
3071 					          get_entry_point().output_vertices,
3072 					          ", spvIndirectParams[1] - 1) * spvIndirectParams[0]];");
3073 				}
3074 				else
3075 				{
3076 					// It's safe to use InvocationId here because it's directly mapped to a
3077 					// Metal builtin, and therefore doesn't need a hook.
3078 					statement("if (", to_expression(builtin_invocation_id_id), " < spvIndirectParams[0])");
3079 					statement("    ", input_wg_var_name, "[", to_expression(builtin_invocation_id_id),
3080 					          "] = ", ib_var_ref, ";");
3081 					statement("threadgroup_barrier(mem_flags::mem_threadgroup);");
3082 					statement("if (", to_expression(builtin_invocation_id_id),
3083 					          " >= ", get_entry_point().output_vertices, ")");
3084 					statement("    return;");
3085 				}
3086 			});
3087 		}
3088 		break;
3089 
3090 	case StorageClassOutput:
3091 	{
3092 		ib_var_ref = patch ? patch_stage_out_var_name : stage_out_var_name;
3093 
3094 		// Add the output interface struct as a local variable to the entry function.
3095 		// If the entry point should return the output struct, set the entry function
3096 		// to return the output interface struct, otherwise to return nothing.
3097 		// Indicate the output var requires early initialization.
3098 		bool ep_should_return_output = !get_is_rasterization_disabled();
3099 		uint32_t rtn_id = ep_should_return_output ? ib_var_id : 0;
3100 		if (!capture_output_to_buffer)
3101 		{
3102 			entry_func.add_local_variable(ib_var_id);
3103 			for (auto &blk_id : entry_func.blocks)
3104 			{
3105 				auto &blk = get<SPIRBlock>(blk_id);
3106 				if (blk.terminator == SPIRBlock::Return)
3107 					blk.return_value = rtn_id;
3108 			}
3109 			vars_needing_early_declaration.push_back(ib_var_id);
3110 		}
3111 		else
3112 		{
3113 			switch (get_execution_model())
3114 			{
3115 			case ExecutionModelVertex:
3116 			case ExecutionModelTessellationEvaluation:
3117 				// Instead of declaring a struct variable to hold the output and then
3118 				// copying that to the output buffer, we'll declare the output variable
3119 				// as a reference to the final output element in the buffer. Then we can
3120 				// avoid the extra copy.
3121 				entry_func.fixup_hooks_in.push_back([=]() {
3122 					if (stage_out_var_id)
3123 					{
3124 						// The first member of the indirect buffer is always the number of vertices
3125 						// to draw.
3126 						// We zero-base the InstanceID & VertexID variables for HLSL emulation elsewhere, so don't do it twice
3127 						if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
3128 						{
3129 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3130 							          " = ", output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
3131 							          ".y * ", to_expression(builtin_stage_input_size_id), ".x + ",
3132 							          to_expression(builtin_invocation_id_id), ".x];");
3133 						}
3134 						else if (msl_options.enable_base_index_zero)
3135 						{
3136 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3137 							          " = ", output_buffer_var_name, "[", to_expression(builtin_instance_idx_id),
3138 							          " * spvIndirectParams[0] + ", to_expression(builtin_vertex_idx_id), "];");
3139 						}
3140 						else
3141 						{
3142 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3143 							          " = ", output_buffer_var_name, "[(", to_expression(builtin_instance_idx_id),
3144 							          " - ", to_expression(builtin_base_instance_id), ") * spvIndirectParams[0] + ",
3145 							          to_expression(builtin_vertex_idx_id), " - ",
3146 							          to_expression(builtin_base_vertex_id), "];");
3147 						}
3148 					}
3149 				});
3150 				break;
3151 			case ExecutionModelTessellationControl:
3152 				if (msl_options.multi_patch_workgroup)
3153 				{
3154 					// We cannot use PrimitiveId here, because the hook may not have run yet.
3155 					if (patch)
3156 					{
3157 						entry_func.fixup_hooks_in.push_back([=]() {
3158 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3159 							          " = ", patch_output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
3160 							          ".x / ", get_entry_point().output_vertices, "];");
3161 						});
3162 					}
3163 					else
3164 					{
3165 						entry_func.fixup_hooks_in.push_back([=]() {
3166 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
3167 							          output_buffer_var_name, "[", to_expression(builtin_invocation_id_id), ".x - ",
3168 							          to_expression(builtin_invocation_id_id), ".x % ",
3169 							          get_entry_point().output_vertices, "];");
3170 						});
3171 					}
3172 				}
3173 				else
3174 				{
3175 					if (patch)
3176 					{
3177 						entry_func.fixup_hooks_in.push_back([=]() {
3178 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3179 							          " = ", patch_output_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
3180 							          "];");
3181 						});
3182 					}
3183 					else
3184 					{
3185 						entry_func.fixup_hooks_in.push_back([=]() {
3186 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
3187 							          output_buffer_var_name, "[", to_expression(builtin_primitive_id_id), " * ",
3188 							          get_entry_point().output_vertices, "];");
3189 						});
3190 					}
3191 				}
3192 				break;
3193 			default:
3194 				break;
3195 			}
3196 		}
3197 		break;
3198 	}
3199 
3200 	default:
3201 		break;
3202 	}
3203 
3204 	set_name(ib_type_id, to_name(ir.default_entry_point) + "_" + ib_var_ref);
3205 	set_name(ib_var_id, ib_var_ref);
3206 
3207 	for (auto *p_var : vars)
3208 	{
3209 		bool strip_array =
3210 		    (get_execution_model() == ExecutionModelTessellationControl ||
3211 		     (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput)) &&
3212 		    !patch;
3213 
3214 		meta.strip_array = strip_array;
3215 		add_variable_to_interface_block(storage, ib_var_ref, ib_type, *p_var, meta);
3216 	}
3217 
3218 	if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup &&
3219 	    storage == StorageClassInput)
3220 	{
3221 		// For tessellation control inputs, add all outputs from the vertex shader to ensure
3222 		// the struct containing them is the correct size and layout.
3223 		for (auto &input : inputs_by_location)
3224 		{
3225 			if (is_msl_shader_input_used(input.first))
3226 				continue;
3227 
3228 			// Create a fake variable to put at the location.
3229 			uint32_t offset = ir.increase_bound_by(4);
3230 			uint32_t type_id = offset;
3231 			uint32_t array_type_id = offset + 1;
3232 			uint32_t ptr_type_id = offset + 2;
3233 			uint32_t var_id = offset + 3;
3234 
3235 			SPIRType type;
3236 			switch (input.second.format)
3237 			{
3238 			case MSL_SHADER_INPUT_FORMAT_UINT16:
3239 			case MSL_SHADER_INPUT_FORMAT_ANY16:
3240 				type.basetype = SPIRType::UShort;
3241 				type.width = 16;
3242 				break;
3243 			case MSL_SHADER_INPUT_FORMAT_ANY32:
3244 			default:
3245 				type.basetype = SPIRType::UInt;
3246 				type.width = 32;
3247 				break;
3248 			}
3249 			type.vecsize = input.second.vecsize;
3250 			set<SPIRType>(type_id, type);
3251 
3252 			type.array.push_back(0);
3253 			type.array_size_literal.push_back(true);
3254 			type.parent_type = type_id;
3255 			set<SPIRType>(array_type_id, type);
3256 
3257 			type.pointer = true;
3258 			type.parent_type = array_type_id;
3259 			type.storage = storage;
3260 			auto &ptr_type = set<SPIRType>(ptr_type_id, type);
3261 			ptr_type.self = array_type_id;
3262 
3263 			auto &fake_var = set<SPIRVariable>(var_id, ptr_type_id, storage);
3264 			set_decoration(var_id, DecorationLocation, input.first);
3265 			meta.strip_array = true;
3266 			add_variable_to_interface_block(storage, ib_var_ref, ib_type, fake_var, meta);
3267 		}
3268 	}
3269 
3270 	// Sort the members of the structure by their locations.
3271 	MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Location);
3272 	member_sorter.sort();
3273 
3274 	// The member indices were saved to the original variables, but after the members
3275 	// were sorted, those indices are now likely incorrect. Fix those up now.
3276 	if (!patch)
3277 		fix_up_interface_member_indices(storage, ib_type_id);
3278 
3279 	// For patch inputs, add one more member, holding the array of control point data.
3280 	if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput && patch &&
3281 	    stage_in_var_id)
3282 	{
3283 		uint32_t pcp_type_id = ir.increase_bound_by(1);
3284 		auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
3285 		pcp_type.basetype = SPIRType::ControlPointArray;
3286 		pcp_type.parent_type = pcp_type.type_alias = get_stage_in_struct_type().self;
3287 		pcp_type.storage = storage;
3288 		ir.meta[pcp_type_id] = ir.meta[ib_type.self];
3289 		uint32_t mbr_idx = uint32_t(ib_type.member_types.size());
3290 		ib_type.member_types.push_back(pcp_type_id);
3291 		set_member_name(ib_type.self, mbr_idx, "gl_in");
3292 	}
3293 
3294 	return ib_var_id;
3295 }
3296 
add_interface_block_pointer(uint32_t ib_var_id,StorageClass storage)3297 uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageClass storage)
3298 {
3299 	if (!ib_var_id)
3300 		return 0;
3301 
3302 	uint32_t ib_ptr_var_id;
3303 	uint32_t next_id = ir.increase_bound_by(3);
3304 	auto &ib_type = expression_type(ib_var_id);
3305 	if (get_execution_model() == ExecutionModelTessellationControl)
3306 	{
3307 		// Tessellation control per-vertex I/O is presented as an array, so we must
3308 		// do the same with our struct here.
3309 		uint32_t ib_ptr_type_id = next_id++;
3310 		auto &ib_ptr_type = set<SPIRType>(ib_ptr_type_id, ib_type);
3311 		ib_ptr_type.parent_type = ib_ptr_type.type_alias = ib_type.self;
3312 		ib_ptr_type.pointer = true;
3313 		ib_ptr_type.storage =
3314 		    storage == StorageClassInput ?
3315 		        (msl_options.multi_patch_workgroup ? StorageClassStorageBuffer : StorageClassWorkgroup) :
3316 		        StorageClassStorageBuffer;
3317 		ir.meta[ib_ptr_type_id] = ir.meta[ib_type.self];
3318 		// To ensure that get_variable_data_type() doesn't strip off the pointer,
3319 		// which we need, use another pointer.
3320 		uint32_t ib_ptr_ptr_type_id = next_id++;
3321 		auto &ib_ptr_ptr_type = set<SPIRType>(ib_ptr_ptr_type_id, ib_ptr_type);
3322 		ib_ptr_ptr_type.parent_type = ib_ptr_type_id;
3323 		ib_ptr_ptr_type.type_alias = ib_type.self;
3324 		ib_ptr_ptr_type.storage = StorageClassFunction;
3325 		ir.meta[ib_ptr_ptr_type_id] = ir.meta[ib_type.self];
3326 
3327 		ib_ptr_var_id = next_id;
3328 		set<SPIRVariable>(ib_ptr_var_id, ib_ptr_ptr_type_id, StorageClassFunction, 0);
3329 		set_name(ib_ptr_var_id, storage == StorageClassInput ? "gl_in" : "gl_out");
3330 	}
3331 	else
3332 	{
3333 		// Tessellation evaluation per-vertex inputs are also presented as arrays.
3334 		// But, in Metal, this array uses a very special type, 'patch_control_point<T>',
3335 		// which is a container that can be used to access the control point data.
3336 		// To represent this, a special 'ControlPointArray' type has been added to the
3337 		// SPIRV-Cross type system. It should only be generated by and seen in the MSL
3338 		// backend (i.e. this one).
3339 		uint32_t pcp_type_id = next_id++;
3340 		auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
3341 		pcp_type.basetype = SPIRType::ControlPointArray;
3342 		pcp_type.parent_type = pcp_type.type_alias = ib_type.self;
3343 		pcp_type.storage = storage;
3344 		ir.meta[pcp_type_id] = ir.meta[ib_type.self];
3345 
3346 		ib_ptr_var_id = next_id;
3347 		set<SPIRVariable>(ib_ptr_var_id, pcp_type_id, storage, 0);
3348 		set_name(ib_ptr_var_id, "gl_in");
3349 		ir.meta[ib_ptr_var_id].decoration.qualified_alias = join(patch_stage_in_var_name, ".gl_in");
3350 	}
3351 	return ib_ptr_var_id;
3352 }
3353 
3354 // Ensure that the type is compatible with the builtin.
3355 // If it is, simply return the given type ID.
3356 // Otherwise, create a new type, and return it's ID.
ensure_correct_builtin_type(uint32_t type_id,BuiltIn builtin)3357 uint32_t CompilerMSL::ensure_correct_builtin_type(uint32_t type_id, BuiltIn builtin)
3358 {
3359 	auto &type = get<SPIRType>(type_id);
3360 
3361 	if ((builtin == BuiltInSampleMask && is_array(type)) ||
3362 	    ((builtin == BuiltInLayer || builtin == BuiltInViewportIndex || builtin == BuiltInFragStencilRefEXT) &&
3363 	     type.basetype != SPIRType::UInt))
3364 	{
3365 		uint32_t next_id = ir.increase_bound_by(type.pointer ? 2 : 1);
3366 		uint32_t base_type_id = next_id++;
3367 		auto &base_type = set<SPIRType>(base_type_id);
3368 		base_type.basetype = SPIRType::UInt;
3369 		base_type.width = 32;
3370 
3371 		if (!type.pointer)
3372 			return base_type_id;
3373 
3374 		uint32_t ptr_type_id = next_id++;
3375 		auto &ptr_type = set<SPIRType>(ptr_type_id);
3376 		ptr_type = base_type;
3377 		ptr_type.pointer = true;
3378 		ptr_type.storage = type.storage;
3379 		ptr_type.parent_type = base_type_id;
3380 		return ptr_type_id;
3381 	}
3382 
3383 	return type_id;
3384 }
3385 
3386 // Ensure that the type is compatible with the shader input.
3387 // If it is, simply return the given type ID.
3388 // Otherwise, create a new type, and return its ID.
ensure_correct_input_type(uint32_t type_id,uint32_t location,uint32_t num_components)3389 uint32_t CompilerMSL::ensure_correct_input_type(uint32_t type_id, uint32_t location, uint32_t num_components)
3390 {
3391 	auto &type = get<SPIRType>(type_id);
3392 
3393 	auto p_va = inputs_by_location.find(location);
3394 	if (p_va == end(inputs_by_location))
3395 	{
3396 		if (num_components > type.vecsize)
3397 			return build_extended_vector_type(type_id, num_components);
3398 		else
3399 			return type_id;
3400 	}
3401 
3402 	if (num_components == 0)
3403 		num_components = p_va->second.vecsize;
3404 
3405 	switch (p_va->second.format)
3406 	{
3407 	case MSL_SHADER_INPUT_FORMAT_UINT8:
3408 	{
3409 		switch (type.basetype)
3410 		{
3411 		case SPIRType::UByte:
3412 		case SPIRType::UShort:
3413 		case SPIRType::UInt:
3414 			if (num_components > type.vecsize)
3415 				return build_extended_vector_type(type_id, num_components);
3416 			else
3417 				return type_id;
3418 
3419 		case SPIRType::Short:
3420 			return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3421 			                                  SPIRType::UShort);
3422 		case SPIRType::Int:
3423 			return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3424 			                                  SPIRType::UInt);
3425 
3426 		default:
3427 			SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
3428 		}
3429 	}
3430 
3431 	case MSL_SHADER_INPUT_FORMAT_UINT16:
3432 	{
3433 		switch (type.basetype)
3434 		{
3435 		case SPIRType::UShort:
3436 		case SPIRType::UInt:
3437 			if (num_components > type.vecsize)
3438 				return build_extended_vector_type(type_id, num_components);
3439 			else
3440 				return type_id;
3441 
3442 		case SPIRType::Int:
3443 			return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3444 			                                  SPIRType::UInt);
3445 
3446 		default:
3447 			SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
3448 		}
3449 	}
3450 
3451 	default:
3452 		if (num_components > type.vecsize)
3453 			type_id = build_extended_vector_type(type_id, num_components);
3454 		break;
3455 	}
3456 
3457 	return type_id;
3458 }
3459 
mark_struct_members_packed(const SPIRType & type)3460 void CompilerMSL::mark_struct_members_packed(const SPIRType &type)
3461 {
3462 	set_extended_decoration(type.self, SPIRVCrossDecorationPhysicalTypePacked);
3463 
3464 	// Problem case! Struct needs to be placed at an awkward alignment.
3465 	// Mark every member of the child struct as packed.
3466 	uint32_t mbr_cnt = uint32_t(type.member_types.size());
3467 	for (uint32_t i = 0; i < mbr_cnt; i++)
3468 	{
3469 		auto &mbr_type = get<SPIRType>(type.member_types[i]);
3470 		if (mbr_type.basetype == SPIRType::Struct)
3471 		{
3472 			// Recursively mark structs as packed.
3473 			auto *struct_type = &mbr_type;
3474 			while (!struct_type->array.empty())
3475 				struct_type = &get<SPIRType>(struct_type->parent_type);
3476 			mark_struct_members_packed(*struct_type);
3477 		}
3478 		else if (!is_scalar(mbr_type))
3479 			set_extended_member_decoration(type.self, i, SPIRVCrossDecorationPhysicalTypePacked);
3480 	}
3481 }
3482 
// Recursively scan struct members of 'type' for nested structs whose SPIR-V
// offsets/strides cannot be represented with MSL's default layout rules.
// Such structs are marked fully packed (see mark_struct_members_packed), and
// structs used in arrays record the ArrayStride they must be padded to via
// SPIRVCrossDecorationPaddingTarget.
void CompilerMSL::mark_scalar_layout_structs(const SPIRType &type)
{
	uint32_t mbr_cnt = uint32_t(type.member_types.size());
	for (uint32_t i = 0; i < mbr_cnt; i++)
	{
		auto &mbr_type = get<SPIRType>(type.member_types[i]);
		if (mbr_type.basetype == SPIRType::Struct)
		{
			// For arrays of structs, walk down to the underlying struct type.
			auto *struct_type = &mbr_type;
			while (!struct_type->array.empty())
				struct_type = &get<SPIRType>(struct_type->parent_type);

			// Already fully packed; nothing more to do for this struct.
			if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPhysicalTypePacked))
				continue;

			// Compare the member's MSL alignment/size against the declared SPIR-V offsets.
			uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, i);
			uint32_t msl_size = get_declared_struct_member_size_msl(type, i);
			uint32_t spirv_offset = type_struct_member_offset(type, i);
			uint32_t spirv_offset_next;
			if (i + 1 < mbr_cnt)
				spirv_offset_next = type_struct_member_offset(type, i + 1);
			else
				spirv_offset_next = spirv_offset + msl_size;

			// Both are complicated cases. In scalar layout, a struct of float3 might just consume 12 bytes,
			// and the next member will be placed at offset 12.
			bool struct_is_misaligned = (spirv_offset % msl_alignment) != 0;
			bool struct_is_too_large = spirv_offset + msl_size > spirv_offset_next;
			uint32_t array_stride = 0;
			bool struct_needs_explicit_padding = false;

			// Verify that if a struct is used as an array that ArrayStride matches the effective size of the struct.
			if (!mbr_type.array.empty())
			{
				// Divide out the outer array dimensions to get the stride of one innermost element.
				array_stride = type_struct_member_array_stride(type, i);
				uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
				for (uint32_t dim = 0; dim < dimensions; dim++)
				{
					uint32_t array_size = to_array_size_literal(mbr_type, dim);
					array_stride /= max(array_size, 1u);
				}

				// Set expected struct size based on ArrayStride.
				struct_needs_explicit_padding = true;

				// If struct size is larger than array stride, we might be able to fit, if we tightly pack.
				if (get_declared_struct_size_msl(*struct_type) > array_stride)
					struct_is_too_large = true;
			}

			// Packing must happen before the recursive scan so nested checks
			// see the packed decoration.
			if (struct_is_misaligned || struct_is_too_large)
				mark_struct_members_packed(*struct_type);
			mark_scalar_layout_structs(*struct_type);

			if (struct_needs_explicit_padding)
			{
				// Re-evaluate size including padding; the struct must fit inside its ArrayStride.
				msl_size = get_declared_struct_size_msl(*struct_type, true, true);
				if (array_stride < msl_size)
				{
					SPIRV_CROSS_THROW("Cannot express an array stride smaller than size of struct type.");
				}
				else
				{
					// A given struct type can only be padded to one target stride.
					if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
					{
						if (array_stride !=
						    get_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
							SPIRV_CROSS_THROW(
							    "A struct is used with different array strides. Cannot express this in MSL.");
					}
					else
						set_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget, array_stride);
				}
			}
		}
	}
}
3560 
// Sort the members of the struct type by offset, and pack and then pad members where needed
// to align MSL members with SPIR-V offsets. The struct members are iterated twice. Packing
// occurs first, followed by padding, because packing a member reduces both its size and its
// natural alignment, possibly requiring a padding member to be added ahead of it.
void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set<uint32_t> &aligned_structs)
{
	// We align structs recursively, so stop any redundant work.
	ID &ib_type_id = ib_type.self;
	if (aligned_structs.count(ib_type_id))
		return;
	aligned_structs.insert(ib_type_id);

	// Sort the members of the interface structure by their offset.
	// They should already be sorted per SPIR-V spec anyway.
	MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Offset);
	member_sorter.sort();

	auto mbr_cnt = uint32_t(ib_type.member_types.size());

	// First pass: align nested structs bottom-up.
	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
	{
		// Pack any dependent struct types before we pack a parent struct.
		auto &mbr_type = get<SPIRType>(ib_type.member_types[mbr_idx]);
		if (mbr_type.basetype == SPIRType::Struct)
			align_struct(mbr_type, aligned_structs);
	}

	// Second pass: test the alignment of each member, and if a member should be closer to the previous
	// member than the default spacing expects, it is likely that the previous member is in
	// a packed format. If so, and the previous member is packable, pack it.
	// For example ... this applies to any 3-element vector that is followed by a scalar.
	uint32_t msl_offset = 0;
	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
	{
		// This checks the member in isolation, if the member needs some kind of type remapping to conform to SPIR-V
		// offsets, array strides and matrix strides.
		ensure_member_packing_rules_msl(ib_type, mbr_idx);

		// Align current offset to the current member's default alignment. If the member was packed, it will observe
		// the updated alignment here.
		uint32_t msl_align_mask = get_declared_struct_member_alignment_msl(ib_type, mbr_idx) - 1;
		uint32_t aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;

		// Fetch the member offset as declared in the SPIRV.
		uint32_t spirv_mbr_offset = get_member_decoration(ib_type_id, mbr_idx, DecorationOffset);
		if (spirv_mbr_offset > aligned_msl_offset)
		{
			// Since MSL and SPIR-V have slightly different struct member alignment and
			// size rules, we'll pad to standard C-packing rules with a char[] array. If the member is farther
			// away than C-packing expects, add an inert padding member before the member.
			uint32_t padding_bytes = spirv_mbr_offset - aligned_msl_offset;
			set_extended_member_decoration(ib_type_id, mbr_idx, SPIRVCrossDecorationPaddingTarget, padding_bytes);

			// Re-align as a sanity check that aligning post-padding matches up.
			msl_offset += padding_bytes;
			aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
		}
		else if (spirv_mbr_offset < aligned_msl_offset)
		{
			// This should not happen, but deal with unexpected scenarios.
			// It *might* happen if a sub-struct has a larger alignment requirement in MSL than SPIR-V.
			SPIRV_CROSS_THROW("Cannot represent buffer block correctly in MSL.");
		}

		assert(aligned_msl_offset == spirv_mbr_offset);

		// Increment the current offset to be positioned immediately after the current member.
		// Don't do this for the last member since it can be unsized, and it is not relevant for padding purposes here.
		if (mbr_idx + 1 < mbr_cnt)
			msl_offset = aligned_msl_offset + get_declared_struct_member_size_msl(ib_type, mbr_idx);
	}
}
3633 
validate_member_packing_rules_msl(const SPIRType & type,uint32_t index) const3634 bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const
3635 {
3636 	auto &mbr_type = get<SPIRType>(type.member_types[index]);
3637 	uint32_t spirv_offset = get_member_decoration(type.self, index, DecorationOffset);
3638 
3639 	if (index + 1 < type.member_types.size())
3640 	{
3641 		// First, we will check offsets. If SPIR-V offset + MSL size > SPIR-V offset of next member,
3642 		// we *must* perform some kind of remapping, no way getting around it.
3643 		// We can always pad after this member if necessary, so that case is fine.
3644 		uint32_t spirv_offset_next = get_member_decoration(type.self, index + 1, DecorationOffset);
3645 		assert(spirv_offset_next >= spirv_offset);
3646 		uint32_t maximum_size = spirv_offset_next - spirv_offset;
3647 		uint32_t msl_mbr_size = get_declared_struct_member_size_msl(type, index);
3648 		if (msl_mbr_size > maximum_size)
3649 			return false;
3650 	}
3651 
3652 	if (!mbr_type.array.empty())
3653 	{
3654 		// If we have an array type, array stride must match exactly with SPIR-V.
3655 
3656 		// An exception to this requirement is if we have one array element.
3657 		// This comes from DX scalar layout workaround.
3658 		// If app tries to be cheeky and access the member out of bounds, this will not work, but this is the best we can do.
3659 		// In OpAccessChain with logical memory models, access chains must be in-bounds in SPIR-V specification.
3660 		bool relax_array_stride = mbr_type.array.back() == 1 && mbr_type.array_size_literal.back();
3661 
3662 		if (!relax_array_stride)
3663 		{
3664 			uint32_t spirv_array_stride = type_struct_member_array_stride(type, index);
3665 			uint32_t msl_array_stride = get_declared_struct_member_array_stride_msl(type, index);
3666 			if (spirv_array_stride != msl_array_stride)
3667 				return false;
3668 		}
3669 	}
3670 
3671 	if (is_matrix(mbr_type))
3672 	{
3673 		// Need to check MatrixStride as well.
3674 		uint32_t spirv_matrix_stride = type_struct_member_matrix_stride(type, index);
3675 		uint32_t msl_matrix_stride = get_declared_struct_member_matrix_stride_msl(type, index);
3676 		if (spirv_matrix_stride != msl_matrix_stride)
3677 			return false;
3678 	}
3679 
3680 	// Now, we check alignment.
3681 	uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, index);
3682 	if ((spirv_offset % msl_alignment) != 0)
3683 		return false;
3684 
3685 	// We're in the clear.
3686 	return true;
3687 }
3688 
3689 // Here we need to verify that the member type we declare conforms to Offset, ArrayStride or MatrixStride restrictions.
3690 // If there is a mismatch, we need to emit remapped types, either normal types, or "packed_X" types.
3691 // In odd cases we need to emit packed and remapped types, for e.g. weird matrices or arrays with weird array strides.
void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t index)
{
	// Fast path: the member already conforms to MSL packing rules as declared.
	if (validate_member_packing_rules_msl(ib_type, index))
		return;

	// We failed validation.
	// This case will be nightmare-ish to deal with. This could possibly happen if struct alignment does not quite
	// match up with what we want. Scalar block layout comes to mind here where we might have to work around the rule
	// that struct alignment == max alignment of all members and struct size depends on this alignment.
	auto &mbr_type = get<SPIRType>(ib_type.member_types[index]);
	if (mbr_type.basetype == SPIRType::Struct)
		SPIRV_CROSS_THROW("Cannot perform any repacking for structs when it is used as a member of another struct.");

	// Perform remapping here.
	// There is nothing to be gained by using packed scalars, so don't attempt it.
	// NOTE(review): this tests ib_type (the containing struct, which is never scalar) rather than mbr_type,
	// so the packed decoration is applied unconditionally here — confirm whether mbr_type was intended.
	if (!is_scalar(ib_type))
		set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);

	// Try validating again, now with packed.
	if (validate_member_packing_rules_msl(ib_type, index))
		return;

	// We're in deep trouble, and we need to create a new PhysicalType which matches up with what we expect.
	// A lot of work goes here ...
	// We will need remapping on Load and Store to translate the types between Logical and Physical.

	// First, we check if we have small vector std140 array.
	// We detect this if we have an array of vectors, and array stride is greater than number of elements.
	if (!mbr_type.array.empty() && !is_matrix(mbr_type))
	{
		uint32_t array_stride = type_struct_member_array_stride(ib_type, index);

		// Hack off array-of-arrays until we find the array stride per element we must have to make it work.
		uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
		for (uint32_t dim = 0; dim < dimensions; dim++)
			array_stride /= max(to_array_size_literal(mbr_type, dim), 1u);

		// How many scalar elements the stride can hold; e.g. std140 float[] has 4 elems per 16-byte stride.
		uint32_t elems_per_stride = array_stride / (mbr_type.width / 8);

		if (elems_per_stride == 3)
			SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
		else if (elems_per_stride > 4)
			SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");

		// Synthesize a wider "physical" vector type whose size matches the stride, and record it
		// as the declared type for this member. Loads/stores will swizzle back to the logical size.
		auto physical_type = mbr_type;
		physical_type.vecsize = elems_per_stride;
		physical_type.parent_type = 0;
		uint32_t type_id = ir.increase_bound_by(1);
		set<SPIRType>(type_id, physical_type);
		set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);
		set_decoration(type_id, DecorationArrayStride, array_stride);

		// Remove packed_ for vectors of size 1, 2 and 4.
		if (has_extended_decoration(ib_type.self, SPIRVCrossDecorationPhysicalTypePacked))
			SPIRV_CROSS_THROW("Unable to remove packed decoration as entire struct must be fully packed. Do not mix "
			                  "scalar and std140 layout rules.");
		else
			unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
	}
	else if (is_matrix(mbr_type))
	{
		// MatrixStride might be std140-esque.
		uint32_t matrix_stride = type_struct_member_matrix_stride(ib_type, index);

		uint32_t elems_per_stride = matrix_stride / (mbr_type.width / 8);

		if (elems_per_stride == 3)
			SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
		else if (elems_per_stride > 4)
			SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");

		bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);

		// Widen the stride-axis of the matrix: columns for row-major, rows (vecsize) for column-major.
		auto physical_type = mbr_type;
		physical_type.parent_type = 0;
		if (row_major)
			physical_type.columns = elems_per_stride;
		else
			physical_type.vecsize = elems_per_stride;
		uint32_t type_id = ir.increase_bound_by(1);
		set<SPIRType>(type_id, physical_type);
		set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);

		// Remove packed_ for vectors of size 1, 2 and 4.
		if (has_extended_decoration(ib_type.self, SPIRVCrossDecorationPhysicalTypePacked))
			SPIRV_CROSS_THROW("Unable to remove packed decoration as entire struct must be fully packed. Do not mix "
			                  "scalar and std140 layout rules.");
		else
			unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
	}
	else
		SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");

	// Try validating again, now with physical type remapping.
	if (validate_member_packing_rules_msl(ib_type, index))
		return;

	// We might have a particular odd scalar layout case where the last element of an array
	// does not take up as much space as the ArrayStride or MatrixStride. This can happen with DX cbuffers.
	// The "proper" workaround for this is extremely painful and essentially impossible in the edge case of float3[],
	// so we hack around it by declaring the offending array or matrix with one less array size/col/row,
	// and rely on padding to get the correct value. We will technically access arrays out of bounds into the padding region,
	// but it should spill over gracefully without too much trouble. We rely on behavior like this for unsized arrays anyways.

	// E.g. we might observe a physical layout of:
	// { float2 a[2]; float b; } in cbuffer layout where ArrayStride of a is 16, but offset of b is 24, packed right after a[1] ...
	uint32_t type_id = get_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
	auto &type = get<SPIRType>(type_id);

	// Modify the physical type in-place. This is safe since each physical type workaround is a copy.
	if (is_array(type))
	{
		if (type.array.back() > 1)
		{
			if (!type.array_size_literal.back())
				SPIRV_CROSS_THROW("Cannot apply scalar layout workaround with spec constant array size.");
			type.array.back() -= 1;
		}
		else
		{
			// We have an array of size 1, so we cannot decrement that. Our only option now is to
			// force a packed layout instead, and drop the physical type remap since ArrayStride is meaningless now.
			unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
			set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
		}
	}
	else if (is_matrix(type))
	{
		bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
		if (!row_major)
		{
			// Slice off one column. If we only have 2 columns, this might turn the matrix into a vector with one array element instead.
			if (type.columns > 2)
			{
				type.columns--;
			}
			else if (type.columns == 2)
			{
				type.columns = 1;
				assert(type.array.empty());
				type.array.push_back(1);
				type.array_size_literal.push_back(true);
			}
		}
		else
		{
			// Slice off one row. If we only have 2 rows, this might turn the matrix into a vector with one array element instead.
			if (type.vecsize > 2)
			{
				type.vecsize--;
			}
			else if (type.vecsize == 2)
			{
				type.vecsize = type.columns;
				type.columns = 1;
				assert(type.array.empty());
				type.array.push_back(1);
				type.array_size_literal.push_back(true);
			}
		}
	}

	// This better validate now, or we must fail gracefully.
	if (!validate_member_packing_rules_msl(ib_type, index))
		SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
}
3858 
// Emits a store of rhs_expression into lhs_expression, handling MSL-specific complications:
// row-major (need_transpose) matrices, packed_ types, and physical type remapping where a member
// was redeclared with a wider physical type (see ensure_member_packing_rules_msl).
void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression)
{
	auto &type = expression_type(rhs_expression);

	// Does the LHS live in a remapped physical type or a packed_ declaration?
	bool lhs_remapped_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID);
	bool lhs_packed_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypePacked);
	auto *lhs_e = maybe_get<SPIRExpression>(lhs_expression);
	auto *rhs_e = maybe_get<SPIRExpression>(rhs_expression);

	bool transpose = lhs_e && lhs_e->need_transpose;

	// No physical type remapping, and no packed type, so can just emit a store directly.
	if (!lhs_remapped_type && !lhs_packed_type)
	{
		// We might not be dealing with remapped physical types or packed types,
		// but we might be doing a clean store to a row-major matrix.
		// In this case, we just flip transpose states, and emit the store, a transpose must be in the RHS expression, if any.
		if (is_matrix(type) && lhs_e && lhs_e->need_transpose)
		{
			if (!rhs_e)
				SPIRV_CROSS_THROW("Need to transpose right-side expression of a store to row-major matrix, but it is "
				                  "not a SPIRExpression.");
			// Temporarily clear the transpose flag so to_expression() emits the raw storage expression.
			lhs_e->need_transpose = false;

			if (rhs_e && rhs_e->need_transpose)
			{
				// Direct copy, but might need to unpack RHS.
				// Skip the transpose, as we will transpose when writing to LHS and transpose(transpose(T)) == T.
				rhs_e->need_transpose = false;
				statement(to_expression(lhs_expression), " = ", to_unpacked_row_major_matrix_expression(rhs_expression),
				          ";");
				rhs_e->need_transpose = true;
			}
			else
				statement(to_expression(lhs_expression), " = transpose(", to_unpacked_expression(rhs_expression), ");");

			// Restore transpose state for any later uses of the LHS expression.
			lhs_e->need_transpose = true;
			register_write(lhs_expression);
		}
		else if (lhs_e && lhs_e->need_transpose)
		{
			lhs_e->need_transpose = false;

			// Storing a column to a row-major matrix. Unroll the write.
			for (uint32_t c = 0; c < type.vecsize; c++)
			{
				// Rewrite mat[col] into mat[c][col] by inserting an index before the last subscript.
				auto lhs_expr = to_dereferenced_expression(lhs_expression);
				auto column_index = lhs_expr.find_last_of('[');
				if (column_index != string::npos)
				{
					statement(lhs_expr.insert(column_index, join('[', c, ']')), " = ",
					          to_extract_component_expression(rhs_expression, c), ";");
				}
			}
			lhs_e->need_transpose = true;
			register_write(lhs_expression);
		}
		else
			CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
	}
	else if (!lhs_remapped_type && !is_matrix(type) && !transpose)
	{
		// Even if the target type is packed, we can directly store to it. We cannot store to packed matrices directly,
		// since they are declared as array of vectors instead, and we need the fallback path below.
		CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
	}
	else
	{
		// Special handling when storing to a remapped physical type.
		// This is mostly to deal with std140 padded matrices or vectors.

		TypeID physical_type_id = lhs_remapped_type ?
		                              ID(get_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID)) :
		                              type.self;

		auto &physical_type = get<SPIRType>(physical_type_id);

		if (is_matrix(type))
		{
			const char *packed_pfx = lhs_packed_type ? "packed_" : "";

			// Packed matrices are stored as arrays of packed vectors, so we need
			// to assign the vectors one at a time.
			// For row-major matrices, we need to transpose the *right-hand* side,
			// not the left-hand side.

			// Lots of cases to cover here ...

			bool rhs_transpose = rhs_e && rhs_e->need_transpose;
			SPIRType write_type = type;
			string cast_expr;

			// We're dealing with transpose manually.
			if (rhs_transpose)
				rhs_e->need_transpose = false;

			if (transpose)
			{
				// We're dealing with transpose manually.
				lhs_e->need_transpose = false;
				// One row of the row-major matrix; a vector of type.columns elements.
				write_type.vecsize = type.columns;
				write_type.columns = 1;

				// If the physical row width differs from the logical one, reinterpret the
				// destination so the narrower logical row can be stored directly.
				if (physical_type.columns != type.columns)
					cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");

				if (rhs_transpose)
				{
					// If RHS is also transposed, we can just copy row by row.
					for (uint32_t i = 0; i < type.vecsize; i++)
					{
						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
						          to_unpacked_row_major_matrix_expression(rhs_expression), "[", i, "];");
					}
				}
				else
				{
					auto vector_type = expression_type(rhs_expression);
					vector_type.vecsize = vector_type.columns;
					vector_type.columns = 1;

					// Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
					// so pick out individual components instead.
					for (uint32_t i = 0; i < type.vecsize; i++)
					{
						string rhs_row = type_to_glsl_constructor(vector_type) + "(";
						for (uint32_t j = 0; j < vector_type.vecsize; j++)
						{
							rhs_row += join(to_enclosed_unpacked_expression(rhs_expression), "[", j, "][", i, "]");
							if (j + 1 < vector_type.vecsize)
								rhs_row += ", ";
						}
						rhs_row += ")";

						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
					}
				}

				// We're dealing with transpose manually.
				lhs_e->need_transpose = true;
			}
			else
			{
				// Column-major store: write one column at a time.
				write_type.columns = 1;

				// Reinterpret the destination column if the physical vector is wider than the logical one.
				if (physical_type.vecsize != type.vecsize)
					cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");

				if (rhs_transpose)
				{
					auto vector_type = expression_type(rhs_expression);
					vector_type.columns = 1;

					// Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
					// so pick out individual components instead.
					for (uint32_t i = 0; i < type.columns; i++)
					{
						string rhs_row = type_to_glsl_constructor(vector_type) + "(";
						for (uint32_t j = 0; j < vector_type.vecsize; j++)
						{
							// Need to explicitly unpack expression since we've mucked with transpose state.
							auto unpacked_expr = to_unpacked_row_major_matrix_expression(rhs_expression);
							rhs_row += join(unpacked_expr, "[", j, "][", i, "]");
							if (j + 1 < vector_type.vecsize)
								rhs_row += ", ";
						}
						rhs_row += ")";

						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
					}
				}
				else
				{
					// Copy column-by-column.
					for (uint32_t i = 0; i < type.columns; i++)
					{
						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
						          to_enclosed_unpacked_expression(rhs_expression), "[", i, "];");
					}
				}
			}

			// We're dealing with transpose manually.
			if (rhs_transpose)
				rhs_e->need_transpose = true;
		}
		else if (transpose)
		{
			lhs_e->need_transpose = false;

			// A single scalar component of the (remapped) row-major matrix.
			SPIRType write_type = type;
			write_type.vecsize = 1;
			write_type.columns = 1;

			// Storing a column to a row-major matrix. Unroll the write.
			for (uint32_t c = 0; c < type.vecsize; c++)
			{
				auto lhs_expr = to_enclosed_expression(lhs_expression);
				auto column_index = lhs_expr.find_last_of('[');
				if (column_index != string::npos)
				{
					statement("((device ", type_to_glsl(write_type), "*)&",
					          lhs_expr.insert(column_index, join('[', c, ']', ")")), " = ",
					          to_extract_component_expression(rhs_expression, c), ";");
				}
			}

			lhs_e->need_transpose = true;
		}
		else if ((is_matrix(physical_type) || is_array(physical_type)) && physical_type.vecsize > type.vecsize)
		{
			assert(type.vecsize >= 1 && type.vecsize <= 3);

			// If we have packed types, we cannot use swizzled stores.
			// We could technically unroll the store for each element if needed.
			// When remapping to a std140 physical type, we always get float4,
			// and the packed decoration should always be removed.
			assert(!lhs_packed_type);

			string lhs = to_dereferenced_expression(lhs_expression);
			string rhs = to_pointer_expression(rhs_expression);

			// Unpack the expression so we can store to it with a float or float2.
			// It's still an l-value, so it's fine. Most other unpacking of expressions turn them into r-values instead.
			lhs = join("(device ", type_to_glsl(type), "&)", enclose_expression(lhs));
			if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
				statement(lhs, " = ", rhs, ";");
		}
		else if (!is_matrix(type))
		{
			// Plain packed vector/scalar store.
			string lhs = to_dereferenced_expression(lhs_expression);
			string rhs = to_pointer_expression(rhs_expression);
			if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
				statement(lhs, " = ", rhs, ";");
		}

		register_write(lhs_expression);
	}
}
4098 
// Returns true when expr_str ends with the given suffix. The empty suffix always matches.
static bool expression_ends_with(const string &expr_str, const std::string &ending)
{
	// A suffix longer than the string itself can never match.
	if (expr_str.length() < ending.length())
		return false;
	// Compare the two tails back-to-front.
	return std::equal(ending.rbegin(), ending.rend(), expr_str.rbegin());
}
4106 
4107 // Converts the format of the current expression from packed to unpacked,
4108 // by wrapping the expression in a constructor of the appropriate type.
4109 // Also, handle special physical ID remapping scenarios, similar to emit_store_statement().
string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type, uint32_t physical_type_id,
                                           bool packed, bool row_major)
{
	// Trivial case, nothing to do.
	if (physical_type_id == 0 && !packed)
		return expr_str;

	// Look up the remapped physical type, if one was assigned (see ensure_member_packing_rules_msl).
	const SPIRType *physical_type = nullptr;
	if (physical_type_id)
		physical_type = &get<SPIRType>(physical_type_id);

	// Swizzle used to narrow a wider physical vector down to a logical vecsize of 1, 2 or 3.
	static const char *swizzle_lut[] = {
		".x",
		".xy",
		".xyz",
	};

	if (physical_type && is_vector(*physical_type) && is_array(*physical_type) &&
	    physical_type->vecsize > type.vecsize && !expression_ends_with(expr_str, swizzle_lut[type.vecsize - 1]))
	{
		// std140 array cases for vectors.
		assert(type.vecsize >= 1 && type.vecsize <= 3);
		return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
	}
	else if (physical_type && is_matrix(*physical_type) && is_vector(type) && physical_type->vecsize > type.vecsize)
	{
		// Extract column from padded matrix.
		assert(type.vecsize >= 1 && type.vecsize <= 3);
		return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
	}
	else if (is_matrix(type))
	{
		// Packed matrices are stored as arrays of packed vectors. Unfortunately,
		// we can't just pass the array straight to the matrix constructor. We have to
		// pass each vector individually, so that they can be unpacked to normal vectors.
		if (!physical_type)
			physical_type = &type;

		// For row-major storage, rows and columns are swapped relative to the logical type.
		uint32_t vecsize = type.vecsize;
		uint32_t columns = type.columns;
		if (row_major)
			swap(vecsize, columns);

		uint32_t physical_vecsize = row_major ? physical_type->columns : physical_type->vecsize;

		const char *base_type = type.width == 16 ? "half" : "float";
		string unpack_expr = join(base_type, columns, "x", vecsize, "(");

		// Swizzle applied to each column/row load when the physical vector is wider than the logical one.
		const char *load_swiz = "";

		if (physical_vecsize != vecsize)
			load_swiz = swizzle_lut[vecsize - 1];

		for (uint32_t i = 0; i < columns; i++)
		{
			if (i > 0)
				unpack_expr += ", ";

			// Packed vectors must be converted through a constructor before swizzling.
			if (packed)
				unpack_expr += join(base_type, physical_vecsize, "(", expr_str, "[", i, "]", ")", load_swiz);
			else
				unpack_expr += join(expr_str, "[", i, "]", load_swiz);
		}

		unpack_expr += ")";
		return unpack_expr;
	}
	else
	{
		// Plain packed vector/scalar: wrap in a constructor of the logical type.
		return join(type_to_glsl(type), "(", expr_str, ")");
	}
}
4182 
4183 // Emits the file header info
emit_header()4184 void CompilerMSL::emit_header()
4185 {
4186 	// This particular line can be overridden during compilation, so make it a flag and not a pragma line.
4187 	if (suppress_missing_prototypes)
4188 		statement("#pragma clang diagnostic ignored \"-Wmissing-prototypes\"");
4189 
4190 	// Disable warning about missing braces for array<T> template to make arrays a value type
4191 	if (spv_function_implementations.count(SPVFuncImplUnsafeArray) != 0)
4192 		statement("#pragma clang diagnostic ignored \"-Wmissing-braces\"");
4193 
4194 	for (auto &pragma : pragma_lines)
4195 		statement(pragma);
4196 
4197 	if (!pragma_lines.empty() || suppress_missing_prototypes)
4198 		statement("");
4199 
4200 	statement("#include <metal_stdlib>");
4201 	statement("#include <simd/simd.h>");
4202 
4203 	for (auto &header : header_lines)
4204 		statement(header);
4205 
4206 	statement("");
4207 	statement("using namespace metal;");
4208 	statement("");
4209 
4210 	for (auto &td : typedef_lines)
4211 		statement(td);
4212 
4213 	if (!typedef_lines.empty())
4214 		statement("");
4215 }
4216 
add_pragma_line(const string & line)4217 void CompilerMSL::add_pragma_line(const string &line)
4218 {
4219 	auto rslt = pragma_lines.insert(line);
4220 	if (rslt.second)
4221 		force_recompile();
4222 }
4223 
add_typedef_line(const string & line)4224 void CompilerMSL::add_typedef_line(const string &line)
4225 {
4226 	auto rslt = typedef_lines.insert(line);
4227 	if (rslt.second)
4228 		force_recompile();
4229 }
4230 
4231 // Template struct like spvUnsafeArray<> need to be declared *before* any resources are declared
emit_custom_templates()4232 void CompilerMSL::emit_custom_templates()
4233 {
4234 	for (const auto &spv_func : spv_function_implementations)
4235 	{
4236 		switch (spv_func)
4237 		{
4238 		case SPVFuncImplUnsafeArray:
4239 			statement("template<typename T, size_t Num>");
4240 			statement("struct spvUnsafeArray");
4241 			begin_scope();
4242 			statement("T elements[Num ? Num : 1];");
4243 			statement("");
4244 			statement("thread T& operator [] (size_t pos) thread");
4245 			begin_scope();
4246 			statement("return elements[pos];");
4247 			end_scope();
4248 			statement("constexpr const thread T& operator [] (size_t pos) const thread");
4249 			begin_scope();
4250 			statement("return elements[pos];");
4251 			end_scope();
4252 			statement("");
4253 			statement("device T& operator [] (size_t pos) device");
4254 			begin_scope();
4255 			statement("return elements[pos];");
4256 			end_scope();
4257 			statement("constexpr const device T& operator [] (size_t pos) const device");
4258 			begin_scope();
4259 			statement("return elements[pos];");
4260 			end_scope();
4261 			statement("");
4262 			statement("constexpr const constant T& operator [] (size_t pos) const constant");
4263 			begin_scope();
4264 			statement("return elements[pos];");
4265 			end_scope();
4266 			statement("");
4267 			statement("threadgroup T& operator [] (size_t pos) threadgroup");
4268 			begin_scope();
4269 			statement("return elements[pos];");
4270 			end_scope();
4271 			statement("constexpr const threadgroup T& operator [] (size_t pos) const threadgroup");
4272 			begin_scope();
4273 			statement("return elements[pos];");
4274 			end_scope();
4275 			end_scope_decl();
4276 			statement("");
4277 			break;
4278 
4279 		default:
4280 			break;
4281 		}
4282 	}
4283 }
4284 
4285 // Emits any needed custom function bodies.
4286 // Metal helper functions must be static force-inline, i.e. static inline __attribute__((always_inline))
4287 // otherwise they will cause problems when linked together in a single Metallib.
emit_custom_functions()4288 void CompilerMSL::emit_custom_functions()
4289 {
4290 	for (uint32_t i = kArrayCopyMultidimMax; i >= 2; i--)
4291 		if (spv_function_implementations.count(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i)))
4292 			spv_function_implementations.insert(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i - 1));
4293 
4294 	if (spv_function_implementations.count(SPVFuncImplDynamicImageSampler))
4295 	{
4296 		// Unfortunately, this one needs a lot of the other functions to compile OK.
4297 		if (!msl_options.supports_msl_version(2))
4298 			SPIRV_CROSS_THROW(
4299 			    "spvDynamicImageSampler requires default-constructible texture objects, which require MSL 2.0.");
4300 		spv_function_implementations.insert(SPVFuncImplForwardArgs);
4301 		spv_function_implementations.insert(SPVFuncImplTextureSwizzle);
4302 		if (msl_options.swizzle_texture_samples)
4303 			spv_function_implementations.insert(SPVFuncImplGatherSwizzle);
4304 		for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
4305 		     i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
4306 			spv_function_implementations.insert(static_cast<SPVFuncImpl>(i));
4307 		spv_function_implementations.insert(SPVFuncImplExpandITUFullRange);
4308 		spv_function_implementations.insert(SPVFuncImplExpandITUNarrowRange);
4309 		spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT709);
4310 		spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT601);
4311 		spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT2020);
4312 	}
4313 
4314 	for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
4315 	     i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
4316 		if (spv_function_implementations.count(static_cast<SPVFuncImpl>(i)))
4317 			spv_function_implementations.insert(SPVFuncImplForwardArgs);
4318 
4319 	if (spv_function_implementations.count(SPVFuncImplTextureSwizzle) ||
4320 	    spv_function_implementations.count(SPVFuncImplGatherSwizzle) ||
4321 	    spv_function_implementations.count(SPVFuncImplGatherCompareSwizzle))
4322 	{
4323 		spv_function_implementations.insert(SPVFuncImplForwardArgs);
4324 		spv_function_implementations.insert(SPVFuncImplGetSwizzle);
4325 	}
4326 
4327 	for (const auto &spv_func : spv_function_implementations)
4328 	{
4329 		switch (spv_func)
4330 		{
4331 		case SPVFuncImplMod:
4332 			statement("// Implementation of the GLSL mod() function, which is slightly different than Metal fmod()");
4333 			statement("template<typename Tx, typename Ty>");
4334 			statement("inline Tx mod(Tx x, Ty y)");
4335 			begin_scope();
4336 			statement("return x - y * floor(x / y);");
4337 			end_scope();
4338 			statement("");
4339 			break;
4340 
4341 		case SPVFuncImplRadians:
4342 			statement("// Implementation of the GLSL radians() function");
4343 			statement("template<typename T>");
4344 			statement("inline T radians(T d)");
4345 			begin_scope();
4346 			statement("return d * T(0.01745329251);");
4347 			end_scope();
4348 			statement("");
4349 			break;
4350 
4351 		case SPVFuncImplDegrees:
4352 			statement("// Implementation of the GLSL degrees() function");
4353 			statement("template<typename T>");
4354 			statement("inline T degrees(T r)");
4355 			begin_scope();
4356 			statement("return r * T(57.2957795131);");
4357 			end_scope();
4358 			statement("");
4359 			break;
4360 
4361 		case SPVFuncImplFindILsb:
4362 			statement("// Implementation of the GLSL findLSB() function");
4363 			statement("template<typename T>");
4364 			statement("inline T spvFindLSB(T x)");
4365 			begin_scope();
4366 			statement("return select(ctz(x), T(-1), x == T(0));");
4367 			end_scope();
4368 			statement("");
4369 			break;
4370 
4371 		case SPVFuncImplFindUMsb:
4372 			statement("// Implementation of the unsigned GLSL findMSB() function");
4373 			statement("template<typename T>");
4374 			statement("inline T spvFindUMSB(T x)");
4375 			begin_scope();
4376 			statement("return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));");
4377 			end_scope();
4378 			statement("");
4379 			break;
4380 
4381 		case SPVFuncImplFindSMsb:
4382 			statement("// Implementation of the signed GLSL findMSB() function");
4383 			statement("template<typename T>");
4384 			statement("inline T spvFindSMSB(T x)");
4385 			begin_scope();
4386 			statement("T v = select(x, T(-1) - x, x < T(0));");
4387 			statement("return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));");
4388 			end_scope();
4389 			statement("");
4390 			break;
4391 
4392 		case SPVFuncImplSSign:
4393 			statement("// Implementation of the GLSL sign() function for integer types");
4394 			statement("template<typename T, typename E = typename enable_if<is_integral<T>::value>::type>");
4395 			statement("inline T sign(T x)");
4396 			begin_scope();
4397 			statement("return select(select(select(x, T(0), x == T(0)), T(1), x > T(0)), T(-1), x < T(0));");
4398 			end_scope();
4399 			statement("");
4400 			break;
4401 
4402 		case SPVFuncImplArrayCopy:
4403 		case SPVFuncImplArrayOfArrayCopy2Dim:
4404 		case SPVFuncImplArrayOfArrayCopy3Dim:
4405 		case SPVFuncImplArrayOfArrayCopy4Dim:
4406 		case SPVFuncImplArrayOfArrayCopy5Dim:
4407 		case SPVFuncImplArrayOfArrayCopy6Dim:
4408 		{
4409 			// Unfortunately we cannot template on the address space, so combinatorial explosion it is.
4410 			static const char *function_name_tags[] = {
4411 				"FromConstantToStack",     "FromConstantToThreadGroup", "FromStackToStack",
4412 				"FromStackToThreadGroup",  "FromThreadGroupToStack",    "FromThreadGroupToThreadGroup",
4413 				"FromDeviceToDevice",      "FromConstantToDevice",      "FromStackToDevice",
4414 				"FromThreadGroupToDevice", "FromDeviceToStack",         "FromDeviceToThreadGroup",
4415 			};
4416 
4417 			static const char *src_address_space[] = {
4418 				"constant",          "constant",          "thread const", "thread const",
4419 				"threadgroup const", "threadgroup const", "device const", "constant",
4420 				"thread const",      "threadgroup const", "device const", "device const",
4421 			};
4422 
4423 			static const char *dst_address_space[] = {
4424 				"thread", "threadgroup", "thread", "threadgroup", "thread", "threadgroup",
4425 				"device", "device",      "device", "device",      "thread", "threadgroup",
4426 			};
4427 
4428 			for (uint32_t variant = 0; variant < 12; variant++)
4429 			{
4430 				uint32_t dimensions = spv_func - SPVFuncImplArrayCopyMultidimBase;
4431 				string tmp = "template<typename T";
4432 				for (uint8_t i = 0; i < dimensions; i++)
4433 				{
4434 					tmp += ", uint ";
4435 					tmp += 'A' + i;
4436 				}
4437 				tmp += ">";
4438 				statement(tmp);
4439 
4440 				string array_arg;
4441 				for (uint8_t i = 0; i < dimensions; i++)
4442 				{
4443 					array_arg += "[";
4444 					array_arg += 'A' + i;
4445 					array_arg += "]";
4446 				}
4447 
4448 				statement("inline void spvArrayCopy", function_name_tags[variant], dimensions, "(",
4449 				          dst_address_space[variant], " T (&dst)", array_arg, ", ", src_address_space[variant],
4450 				          " T (&src)", array_arg, ")");
4451 
4452 				begin_scope();
4453 				statement("for (uint i = 0; i < A; i++)");
4454 				begin_scope();
4455 
4456 				if (dimensions == 1)
4457 					statement("dst[i] = src[i];");
4458 				else
4459 					statement("spvArrayCopy", function_name_tags[variant], dimensions - 1, "(dst[i], src[i]);");
4460 				end_scope();
4461 				end_scope();
4462 				statement("");
4463 			}
4464 			break;
4465 		}
4466 
		// Support for Metal 2.1's new texture_buffer type.
		// Emits spvTexelBufferCoord(), mapping a 1D texel-buffer index onto the
		// 2D texture used to emulate a texel buffer.
		case SPVFuncImplTexelBufferCoords:
		{
			if (msl_options.texel_buffer_texture_width > 0)
			{
				// Texture width is fixed at compile time: emit a real function
				// with the width baked in.
				string tex_width_str = convert_to_string(msl_options.texel_buffer_texture_width);
				statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
				statement(force_inline);
				statement("uint2 spvTexelBufferCoord(uint tc)");
				begin_scope();
				statement(join("return uint2(tc % ", tex_width_str, ", tc / ", tex_width_str, ");"));
				end_scope();
				statement("");
			}
			else
			{
				// Width not known here: emit a macro that queries the texture's
				// width at runtime instead.
				statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
				statement(
				    "#define spvTexelBufferCoord(tc, tex) uint2((tc) % (tex).get_width(), (tc) / (tex).get_width())");
				statement("");
			}
			break;
		}
4490 
		// Emulate texture2D atomic operations
		// Emits spvImage2DAtomicCoord(), converting 2D texel coords into an index
		// into the linear buffer that aliases the texture for atomic access.
		case SPVFuncImplImage2DAtomicCoords:
		{
			if (msl_options.supports_msl_version(1, 2))
			{
				// MSL 1.2+ supports function constants, so allow the row alignment
				// to be overridden at pipeline-creation time; otherwise fall back
				// to the statically configured value.
				statement("// The required alignment of a linear texture of R32Uint format.");
				statement("constant uint spvLinearTextureAlignmentOverride [[function_constant(",
				          msl_options.r32ui_alignment_constant_id, ")]];");
				statement("constant uint spvLinearTextureAlignment = ",
				          "is_function_constant_defined(spvLinearTextureAlignmentOverride) ? ",
				          "spvLinearTextureAlignmentOverride : ", msl_options.r32ui_linear_texture_alignment, ";");
			}
			else
			{
				statement("// The required alignment of a linear texture of R32Uint format.");
				statement("constant uint spvLinearTextureAlignment = ", msl_options.r32ui_linear_texture_alignment,
				          ";");
			}
			// Row stride is the texture width rounded up to the alignment
			// (expressed in 4-byte texels), times y, plus x.
			statement("// Returns buffer coords corresponding to 2D texture coords for emulating 2D texture atomics");
			statement("#define spvImage2DAtomicCoord(tc, tex) (((((tex).get_width() + ",
			          " spvLinearTextureAlignment / 4 - 1) & ~(",
			          " spvLinearTextureAlignment / 4 - 1)) * (tc).y) + (tc).x)");
			statement("");
			break;
		}
4516 
		// "fadd" intrinsic support
		// Emits spvFAdd(), which computes l + r as fma(1, l, r) — presumably to
		// keep the Metal compiler from contracting/reassociating the addition
		// with neighboring operations; TODO confirm rationale.
		case SPVFuncImplFAdd:
			statement("template<typename T>");
			statement("T spvFAdd(T l, T r)");
			begin_scope();
			statement("return fma(T(1), l, r);");
			end_scope();
			statement("");
			break;
4526 
		// "fmul' intrinsic support
		// Emits spvFMul() and the vector/matrix product helpers, all expressed
		// through fma() — presumably to keep the Metal compiler from re-fusing
		// the multiplies; TODO confirm rationale.
		case SPVFuncImplFMul:
			statement("template<typename T>");
			statement("T spvFMul(T l, T r)");
			begin_scope();
			statement("return fma(l, r, T(0));");
			end_scope();
			statement("");

			// vector * matrix: gather row (i - 1) of m into tmp, then accumulate
			// it scaled by the matching vector lane.
			statement("template<typename T, int Cols, int Rows>");
			statement("vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)");
			begin_scope();
			statement("vec<T, Cols> res = vec<T, Cols>(0);");
			statement("for (uint i = Rows; i > 0; --i)");
			begin_scope();
			statement("vec<T, Cols> tmp(0);");
			statement("for (uint j = 0; j < Cols; ++j)");
			begin_scope();
			statement("tmp[j] = m[j][i - 1];");
			end_scope();
			statement("res = fma(tmp, vec<T, Cols>(v[i - 1]), res);");
			end_scope();
			statement("return res;");
			end_scope();
			statement("");

			// matrix * vector: accumulate each column scaled by its vector lane.
			statement("template<typename T, int Cols, int Rows>");
			statement("vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)");
			begin_scope();
			statement("vec<T, Rows> res = vec<T, Rows>(0);");
			statement("for (uint i = Cols; i > 0; --i)");
			begin_scope();
			statement("res = fma(m[i - 1], vec<T, Rows>(v[i - 1]), res);");
			end_scope();
			statement("return res;");
			end_scope();
			statement("");

			// matrix * matrix.
			// NOTE(review): tmp is declared vec<T, RCols> but accumulates l[j]
			// (a vec<T, LRows>) and is assigned to res[i] (also vec<T, LRows>),
			// so the emitted MSL would seem to only compile when RCols == LRows —
			// confirm against upstream / non-square matrix usage.
			statement("template<typename T, int LCols, int LRows, int RCols, int RRows>");
			statement(
			    "matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)");
			begin_scope();
			statement("matrix<T, RCols, LRows> res;");
			statement("for (uint i = 0; i < RCols; i++)");
			begin_scope();
			statement("vec<T, RCols> tmp(0);");
			statement("for (uint j = 0; j < LCols; j++)");
			begin_scope();
			statement("tmp = fma(vec<T, RCols>(r[i][j]), l[j], tmp);");
			end_scope();
			statement("res[i] = tmp;");
			end_scope();
			statement("return res;");
			end_scope();
			statement("");
			break;
4583 
		// Emulate texturecube_array with texture2d_array for iOS where this type is not available
		// Emits spvCubemapTo2DArrayFace(): picks the dominant axis of the cube
		// direction P, selects the corresponding face index (+X=0, -X=1, +Y=2,
		// -Y=3, +Z=4, -Z=5), and projects the remaining two components into
		// [0, 1] face UVs. Returns float3(u, v, face).
		case SPVFuncImplCubemapTo2DArrayFace:
			statement(force_inline);
			statement("float3 spvCubemapTo2DArrayFace(float3 P)");
			begin_scope();
			statement("float3 Coords = abs(P.xyz);");
			statement("float CubeFace = 0;");
			statement("float ProjectionAxis = 0;");
			statement("float u = 0;");
			statement("float v = 0;");
			statement("if (Coords.x >= Coords.y && Coords.x >= Coords.z)");
			begin_scope();
			statement("CubeFace = P.x >= 0 ? 0 : 1;");
			statement("ProjectionAxis = Coords.x;");
			statement("u = P.x >= 0 ? -P.z : P.z;");
			statement("v = -P.y;");
			end_scope();
			statement("else if (Coords.y >= Coords.x && Coords.y >= Coords.z)");
			begin_scope();
			statement("CubeFace = P.y >= 0 ? 2 : 3;");
			statement("ProjectionAxis = Coords.y;");
			statement("u = P.x;");
			statement("v = P.y >= 0 ? P.z : -P.z;");
			end_scope();
			statement("else");
			begin_scope();
			statement("CubeFace = P.z >= 0 ? 4 : 5;");
			statement("ProjectionAxis = Coords.z;");
			statement("u = P.z >= 0 ? P.x : -P.x;");
			statement("v = -P.y;");
			end_scope();
			statement("u = 0.5 * (u/ProjectionAxis + 1);");
			statement("v = 0.5 * (v/ProjectionAxis + 1);");
			statement("return float3(u, v, CubeFace);");
			end_scope();
			statement("");
			break;
4621 
		// Emits spvInverse4x4() (plus the spvDet2x2/spvDet3x3 cofactor helpers)
		// using the classical-adjoint method. Note SPVFuncImplInverse3x3 guards
		// against re-emitting spvDet2x2 when this case is also emitted.
		case SPVFuncImplInverse4x4:
			statement("// Returns the determinant of a 2x2 matrix.");
			statement(force_inline);
			statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
			begin_scope();
			statement("return a1 * b2 - b1 * a2;");
			end_scope();
			statement("");

			statement("// Returns the determinant of a 3x3 matrix.");
			statement(force_inline);
			statement("float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
			          "float c2, float c3)");
			begin_scope();
			statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * spvDet2x2(a2, a3, "
			          "b2, b3);");
			end_scope();
			statement("");
			statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
			statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
			statement(force_inline);
			statement("float4x4 spvInverse4x4(float4x4 m)");
			begin_scope();
			statement("float4x4 adj;	// The adjoint matrix (inverse after dividing by determinant)");
			statement_no_indent("");
			statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
			statement("adj[0][0] =  spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
			          "m[3][3]);");
			statement("adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
			          "m[3][3]);");
			statement("adj[0][2] =  spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
			          "m[3][3]);");
			statement("adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
			          "m[2][3]);");
			statement_no_indent("");
			statement("adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
			          "m[3][3]);");
			statement("adj[1][1] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
			          "m[3][3]);");
			statement("adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
			          "m[3][3]);");
			statement("adj[1][3] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
			          "m[2][3]);");
			statement_no_indent("");
			statement("adj[2][0] =  spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
			          "m[3][3]);");
			statement("adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
			          "m[3][3]);");
			statement("adj[2][2] =  spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
			          "m[3][3]);");
			statement("adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
			          "m[2][3]);");
			statement_no_indent("");
			statement("adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
			          "m[3][2]);");
			statement("adj[3][1] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
			          "m[3][2]);");
			statement("adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
			          "m[3][2]);");
			statement("adj[3][3] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
			          "m[2][2]);");
			statement_no_indent("");
			statement("// Calculate the determinant as a combination of the cofactors of the first row.");
			statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
			          "* m[3][0]);");
			statement_no_indent("");
			statement("// Divide the classical adjoint matrix by the determinant.");
			statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
			statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
			end_scope();
			statement("");
			break;
4694 
		// Emits spvInverse3x3() via the classical adjoint, like the 4x4 case.
		case SPVFuncImplInverse3x3:
			if (spv_function_implementations.count(SPVFuncImplInverse4x4) == 0)
			{
				// Only emit spvDet2x2 if the 4x4 case didn't already emit it,
				// to avoid a duplicate definition in the generated MSL.
				statement("// Returns the determinant of a 2x2 matrix.");
				statement(force_inline);
				statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
				begin_scope();
				statement("return a1 * b2 - b1 * a2;");
				end_scope();
				statement("");
			}

			statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
			statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
			statement(force_inline);
			statement("float3x3 spvInverse3x3(float3x3 m)");
			begin_scope();
			statement("float3x3 adj;	// The adjoint matrix (inverse after dividing by determinant)");
			statement_no_indent("");
			statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
			statement("adj[0][0] =  spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
			statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
			statement("adj[0][2] =  spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
			statement_no_indent("");
			statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
			statement("adj[1][1] =  spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
			statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
			statement_no_indent("");
			statement("adj[2][0] =  spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
			statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
			statement("adj[2][2] =  spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
			statement_no_indent("");
			statement("// Calculate the determinant as a combination of the cofactors of the first row.");
			statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
			statement_no_indent("");
			statement("// Divide the classical adjoint matrix by the determinant.");
			statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
			statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
			end_scope();
			statement("");
			break;
4736 
		// Emits spvInverse2x2(); 2x2 cofactors are single elements, so no
		// determinant helper is needed.
		case SPVFuncImplInverse2x2:
			statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
			statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
			statement(force_inline);
			statement("float2x2 spvInverse2x2(float2x2 m)");
			begin_scope();
			statement("float2x2 adj;	// The adjoint matrix (inverse after dividing by determinant)");
			statement_no_indent("");
			statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
			statement("adj[0][0] =  m[1][1];");
			statement("adj[0][1] = -m[0][1];");
			statement_no_indent("");
			statement("adj[1][0] = -m[1][0];");
			statement("adj[1][1] =  m[0][0];");
			statement_no_indent("");
			statement("// Calculate the determinant as a combination of the cofactors of the first row.");
			statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
			statement_no_indent("");
			statement("// Divide the classical adjoint matrix by the determinant.");
			statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
			statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
			end_scope();
			statement("");
			break;
4761 
		// Emits spvRemoveReference/spvForward, MSL equivalents of
		// std::remove_reference/std::forward specialized for the thread address
		// space, used to perfect-forward variadic args in other spv helpers.
		case SPVFuncImplForwardArgs:
			statement("template<typename T> struct spvRemoveReference { typedef T type; };");
			statement("template<typename T> struct spvRemoveReference<thread T&> { typedef T type; };");
			statement("template<typename T> struct spvRemoveReference<thread T&&> { typedef T type; };");
			statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
			          "spvRemoveReference<T>::type& x)");
			begin_scope();
			statement("return static_cast<thread T&&>(x);");
			end_scope();
			statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
			          "spvRemoveReference<T>::type&& x)");
			begin_scope();
			statement("return static_cast<thread T&&>(x);");
			end_scope();
			statement("");
			break;
4778 
		// Emits the spvSwizzle enum and spvGetSwizzle(), which resolves one
		// component of a texture result according to a swizzle selector.
		case SPVFuncImplGetSwizzle:
			statement("enum class spvSwizzle : uint");
			begin_scope();
			statement("none = 0,");
			statement("zero,");
			statement("one,");
			statement("red,");
			statement("green,");
			statement("blue,");
			statement("alpha");
			// end_scope_decl() emits the closing "};" the enum declaration needs.
			end_scope_decl();
			statement("");
			statement("template<typename T>");
			statement("inline T spvGetSwizzle(vec<T, 4> x, T c, spvSwizzle s)");
			begin_scope();
			statement("switch (s)");
			begin_scope();
			statement("case spvSwizzle::none:");
			statement("    return c;");
			statement("case spvSwizzle::zero:");
			statement("    return 0;");
			statement("case spvSwizzle::one:");
			statement("    return 1;");
			statement("case spvSwizzle::red:");
			statement("    return x.r;");
			statement("case spvSwizzle::green:");
			statement("    return x.g;");
			statement("case spvSwizzle::blue:");
			statement("    return x.b;");
			statement("case spvSwizzle::alpha:");
			statement("    return x.a;");
			end_scope();
			end_scope();
			statement("");
			break;
4814 
		// Emits spvTextureSwizzle(): applies a packed 4x8-bit swizzle (one byte
		// per component) to a texture sample/fetch result via spvGetSwizzle.
		case SPVFuncImplTextureSwizzle:
			statement("// Wrapper function that swizzles texture samples and fetches.");
			statement("template<typename T>");
			statement("inline vec<T, 4> spvTextureSwizzle(vec<T, 4> x, uint s)");
			begin_scope();
			// A zero swizzle word means identity; return the sample unchanged.
			statement("if (!s)");
			statement("    return x;");
			statement("return vec<T, 4>(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), "
			          "spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) "
			          "& 0xFF)), "
			          "spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));");
			end_scope();
			statement("");
			// Scalar overload: promote to a (x, 0, 0, 1) vector, swizzle, take .x.
			statement("template<typename T>");
			statement("inline T spvTextureSwizzle(T x, uint s)");
			begin_scope();
			statement("return spvTextureSwizzle(vec<T, 4>(x, 0, 0, 1), s).x;");
			end_scope();
			statement("");
			break;
4835 
		// Emits spvGatherSwizzle(): performs a texture gather with the requested
		// component remapped through a packed swizzle word.
		case SPVFuncImplGatherSwizzle:
			statement("// Wrapper function that swizzles texture gathers.");
			statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
			          "typename... Ts>");
			statement("inline vec<T, 4> spvGatherSwizzle(const thread Tex<T>& t, sampler s, "
			          "uint sw, component c, Ts... params) METAL_CONST_ARG(c)");
			begin_scope();
			statement("if (sw)");
			begin_scope();
			// Look up the swizzle byte for the gathered component and redirect
			// the gather (or return a constant) accordingly.
			statement("switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))");
			begin_scope();
			statement("case spvSwizzle::none:");
			statement("    break;");
			statement("case spvSwizzle::zero:");
			statement("    return vec<T, 4>(0, 0, 0, 0);");
			statement("case spvSwizzle::one:");
			statement("    return vec<T, 4>(1, 1, 1, 1);");
			statement("case spvSwizzle::red:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::x);");
			statement("case spvSwizzle::green:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::y);");
			statement("case spvSwizzle::blue:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::z);");
			statement("case spvSwizzle::alpha:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::w);");
			end_scope();
			end_scope();
			// texture::gather insists on its component parameter being a constant
			// expression, so we need this silly workaround just to compile the shader.
			statement("switch (c)");
			begin_scope();
			statement("case component::x:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::x);");
			statement("case component::y:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::y);");
			statement("case component::z:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::z);");
			statement("case component::w:");
			statement("    return t.gather(s, spvForward<Ts>(params)..., component::w);");
			end_scope();
			end_scope();
			statement("");
			break;
4879 
		// Emits spvGatherCompareSwizzle(): depth gathers only have one meaningful
		// component (red), so the swizzle only decides between pass-through,
		// all-zeros, and all-ones.
		case SPVFuncImplGatherCompareSwizzle:
			statement("// Wrapper function that swizzles depth texture gathers.");
			statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
			          "typename... Ts>");
			statement("inline vec<T, 4> spvGatherCompareSwizzle(const thread Tex<T>& t, sampler "
			          "s, uint sw, Ts... params) ");
			begin_scope();
			statement("if (sw)");
			begin_scope();
			statement("switch (spvSwizzle(sw & 0xFF))");
			begin_scope();
			statement("case spvSwizzle::none:");
			statement("case spvSwizzle::red:");
			statement("    break;");
			statement("case spvSwizzle::zero:");
			statement("case spvSwizzle::green:");
			statement("case spvSwizzle::blue:");
			statement("case spvSwizzle::alpha:");
			statement("    return vec<T, 4>(0, 0, 0, 0);");
			statement("case spvSwizzle::one:");
			statement("    return vec<T, 4>(1, 1, 1, 1);");
			end_scope();
			end_scope();
			statement("return t.gather_compare(s, spvForward<Ts>(params)...);");
			end_scope();
			statement("");
			break;
4907 
		// Emits spvSubgroupBroadcast(): a generic template plus bool and
		// vec<bool, N> overloads. On iOS without simdgroup functions, quad_*
		// intrinsics are used instead of simd_*.
		case SPVFuncImplSubgroupBroadcast:
			// Metal doesn't allow broadcasting boolean values directly, but we can work around that by broadcasting
			// them as integers.
			statement("template<typename T>");
			statement("inline T spvSubgroupBroadcast(T value, ushort lane)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return quad_broadcast(value, lane);");
			else
				statement("return simd_broadcast(value, lane);");
			end_scope();
			statement("");
			statement("template<>");
			statement("inline bool spvSubgroupBroadcast(bool value, ushort lane)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return !!quad_broadcast((ushort)value, lane);");
			else
				statement("return !!simd_broadcast((ushort)value, lane);");
			end_scope();
			statement("");
			statement("template<uint N>");
			statement("inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
			else
				statement("return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);");
			end_scope();
			statement("");
			break;
4939 
		// Emits spvSubgroupBroadcastFirst(): like SubgroupBroadcast, with bool
		// variants routed through ushort since Metal can't broadcast bools.
		case SPVFuncImplSubgroupBroadcastFirst:
			statement("template<typename T>");
			statement("inline T spvSubgroupBroadcastFirst(T value)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return quad_broadcast_first(value);");
			else
				statement("return simd_broadcast_first(value);");
			end_scope();
			statement("");
			statement("template<>");
			statement("inline bool spvSubgroupBroadcastFirst(bool value)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return !!quad_broadcast_first((ushort)value);");
			else
				statement("return !!simd_broadcast_first((ushort)value);");
			end_scope();
			statement("");
			statement("template<uint N>");
			statement("inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value);");
			else
				statement("return (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value);");
			end_scope();
			statement("");
			break;
4969 
		// Emits spvSubgroupBallot(): converts Metal's vote object into the uint4
		// ballot format SPIR-V expects.
		case SPVFuncImplSubgroupBallot:
			statement("inline uint4 spvSubgroupBallot(bool value)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
			{
				statement("return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);");
			}
			else if (msl_options.is_ios())
			{
				// The current simd_vote on iOS uses a 32-bit integer-like object.
				statement("return uint4((simd_vote::vote_t)simd_ballot(value), 0, 0, 0);");
			}
			else
			{
				statement("simd_vote vote = simd_ballot(value);");
				statement("// simd_ballot() returns a 64-bit integer-like object, but");
				statement("// SPIR-V callers expect a uint4. We must convert.");
				statement("// FIXME: This won't include higher bits if Apple ever supports");
				statement("// 128 lanes in an SIMD-group.");
				statement(
				    "return uint4((uint)((simd_vote::vote_t)vote & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)vote >> "
				    "32) & 0xFFFFFFFF), 0, 0);");
			}
			end_scope();
			statement("");
			break;
4996 
		// Emits spvSubgroupBallotBitExtract(): tests one lane's bit in a uint4
		// ballot (word = bit / 32, bit within word = bit % 32).
		case SPVFuncImplSubgroupBallotBitExtract:
			statement("inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)");
			begin_scope();
			statement("return !!extract_bits(ballot[bit / 32], bit % 32, 1);");
			end_scope();
			statement("");
			break;
5004 
		// Emits spvSubgroupBallotFindLSB(): lowest set lane in the ballot, after
		// masking off bits beyond gl_SubgroupSize; uint(-1) if none set.
		case SPVFuncImplSubgroupBallotFindLSB:
			statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)");
			begin_scope();
			if (msl_options.is_ios())
			{
				// iOS subgroups fit in a single 32-bit word.
				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
			}
			else
			{
				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
				          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
			}
			statement("ballot &= mask;");
			// select(a, b, c) == c ? b : a, so scan words x..w in order.
			statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
			          "ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
			end_scope();
			statement("");
			break;
5023 
		// Emits spvSubgroupBallotFindMSB(): highest set lane in the ballot, after
		// masking off bits beyond gl_SubgroupSize; uint(-1) if none set.
		case SPVFuncImplSubgroupBallotFindMSB:
			statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)");
			begin_scope();
			if (msl_options.is_ios())
			{
				// iOS subgroups fit in a single 32-bit word.
				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
			}
			else
			{
				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
				          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
			}
			statement("ballot &= mask;");
			// Scan words w..x from the top; select(a, b, c) == c ? b : a.
			statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
			          "(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
			          "ballot.z == 0), ballot.w == 0);");
			end_scope();
			statement("");
			break;
5043 
5044 		case SPVFuncImplSubgroupBallotBitCount:
5045 			statement("inline uint spvPopCount4(uint4 ballot)");
5046 			begin_scope();
5047 			statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
5048 			end_scope();
5049 			statement("");
5050 			statement("inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)");
5051 			begin_scope();
5052 			if (msl_options.is_ios())
5053 			{
5054 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
5055 			}
5056 			else
5057 			{
5058 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
5059 				          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
5060 			}
5061 			statement("return spvPopCount4(ballot & mask);");
5062 			end_scope();
5063 			statement("");
5064 			statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
5065 			begin_scope();
5066 			if (msl_options.is_ios())
5067 			{
5068 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID + 1), uint3(0));");
5069 			}
5070 			else
5071 			{
5072 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
5073 				          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
5074 				          "uint2(0));");
5075 			}
5076 			statement("return spvPopCount4(ballot & mask);");
5077 			end_scope();
5078 			statement("");
5079 			statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
5080 			begin_scope();
5081 			if (msl_options.is_ios())
5082 			{
5083 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID), uint2(0));");
5084 			}
5085 			else
5086 			{
5087 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
5088 				          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
5089 			}
5090 			statement("return spvPopCount4(ballot & mask);");
5091 			end_scope();
5092 			statement("");
5093 			break;
5094 
		case SPVFuncImplSubgroupAllEqual:
			// Metal doesn't provide a function to evaluate this directly. But, we can
			// implement this by comparing every thread's value to one thread's value
			// (in this case, the value of the first active thread). Then, by the transitive
			// property of equality, if all comparisons return true, then they are all equal.
			statement("template<typename T>");
			statement("inline bool spvSubgroupAllEqual(T value)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return quad_all(all(value == quad_broadcast_first(value)));");
			else
				statement("return simd_all(all(value == simd_broadcast_first(value)));");
			end_scope();
			statement("");
			// bool specialization: all equal iff all true or all false.
			statement("template<>");
			statement("inline bool spvSubgroupAllEqual(bool value)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return quad_all(value) || !quad_any(value);");
			else
				statement("return simd_all(value) || !simd_any(value);");
			end_scope();
			statement("");
			// vec<bool, N> overload: broadcast via ushort since bools can't be
			// broadcast directly.
			statement("template<uint N>");
			statement("inline bool spvSubgroupAllEqual(vec<bool, N> value)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return quad_all(all(value == (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value)));");
			else
				statement("return simd_all(all(value == (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value)));");
			end_scope();
			statement("");
			break;
5128 
		// Emits spvSubgroupShuffle(): generic, bool, and vec<bool, N> overloads,
		// with bools shuffled as ushort (same workaround as the broadcasts).
		case SPVFuncImplSubgroupShuffle:
			statement("template<typename T>");
			statement("inline T spvSubgroupShuffle(T value, ushort lane)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return quad_shuffle(value, lane);");
			else
				statement("return simd_shuffle(value, lane);");
			end_scope();
			statement("");
			statement("template<>");
			statement("inline bool spvSubgroupShuffle(bool value, ushort lane)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return !!quad_shuffle((ushort)value, lane);");
			else
				statement("return !!simd_shuffle((ushort)value, lane);");
			end_scope();
			statement("");
			statement("template<uint N>");
			statement("inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);");
			else
				statement("return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);");
			end_scope();
			statement("");
			break;
5158 
		// Emits spvSubgroupShuffleXor(): same overload structure as
		// spvSubgroupShuffle, using the *_shuffle_xor intrinsics.
		case SPVFuncImplSubgroupShuffleXor:
			statement("template<typename T>");
			statement("inline T spvSubgroupShuffleXor(T value, ushort mask)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return quad_shuffle_xor(value, mask);");
			else
				statement("return simd_shuffle_xor(value, mask);");
			end_scope();
			statement("");
			statement("template<>");
			statement("inline bool spvSubgroupShuffleXor(bool value, ushort mask)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return !!quad_shuffle_xor((ushort)value, mask);");
			else
				statement("return !!simd_shuffle_xor((ushort)value, mask);");
			end_scope();
			statement("");
			statement("template<uint N>");
			statement("inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)");
			begin_scope();
			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
				statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, mask);");
			else
				statement("return (vec<bool, N>)simd_shuffle_xor((vec<ushort, N>)value, mask);");
			end_scope();
			statement("");
			break;
5188 
5189 		case SPVFuncImplSubgroupShuffleUp:
			// Emit spvSubgroupShuffleUp(): generic/bool/vec<bool, N> overloads
			// wrapping quad_shuffle_up() or simd_shuffle_up(); bools travel
			// through ushort as in the other shuffle helpers.
5190 			statement("template<typename T>");
5191 			statement("inline T spvSubgroupShuffleUp(T value, ushort delta)");
5192 			begin_scope();
5193 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5194 				statement("return quad_shuffle_up(value, delta);");
5195 			else
5196 				statement("return simd_shuffle_up(value, delta);");
5197 			end_scope();
5198 			statement("");
5199 			statement("template<>");
5200 			statement("inline bool spvSubgroupShuffleUp(bool value, ushort delta)");
5201 			begin_scope();
5202 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5203 				statement("return !!quad_shuffle_up((ushort)value, delta);");
5204 			else
5205 				statement("return !!simd_shuffle_up((ushort)value, delta);");
5206 			end_scope();
5207 			statement("");
5208 			statement("template<uint N>");
5209 			statement("inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)");
5210 			begin_scope();
5211 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5212 				statement("return (vec<bool, N>)quad_shuffle_up((vec<ushort, N>)value, delta);");
5213 			else
5214 				statement("return (vec<bool, N>)simd_shuffle_up((vec<ushort, N>)value, delta);");
5215 			end_scope();
5216 			statement("");
5217 			break;
5218 
5219 		case SPVFuncImplSubgroupShuffleDown:
			// Emit spvSubgroupShuffleDown(): generic/bool/vec<bool, N>
			// overloads wrapping quad_shuffle_down() or simd_shuffle_down(),
			// mirroring the ShuffleUp helper above.
5220 			statement("template<typename T>");
5221 			statement("inline T spvSubgroupShuffleDown(T value, ushort delta)");
5222 			begin_scope();
5223 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5224 				statement("return quad_shuffle_down(value, delta);");
5225 			else
5226 				statement("return simd_shuffle_down(value, delta);");
5227 			end_scope();
5228 			statement("");
5229 			statement("template<>");
5230 			statement("inline bool spvSubgroupShuffleDown(bool value, ushort delta)");
5231 			begin_scope();
5232 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5233 				statement("return !!quad_shuffle_down((ushort)value, delta);");
5234 			else
5235 				statement("return !!simd_shuffle_down((ushort)value, delta);");
5236 			end_scope();
5237 			statement("");
5238 			statement("template<uint N>");
5239 			statement("inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)");
5240 			begin_scope();
5241 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5242 				statement("return (vec<bool, N>)quad_shuffle_down((vec<ushort, N>)value, delta);");
5243 			else
5244 				statement("return (vec<bool, N>)simd_shuffle_down((vec<ushort, N>)value, delta);");
5245 			end_scope();
5246 			statement("");
5247 			break;
5248 
5249 		case SPVFuncImplQuadBroadcast:
			// Emit spvQuadBroadcast(): thin wrappers over quad_broadcast()
			// with bool and vec<bool, N> specializations that round-trip
			// through ushort. No simd/quad branch here — quad_broadcast()
			// is used unconditionally.
5250 			statement("template<typename T>");
5251 			statement("inline T spvQuadBroadcast(T value, uint lane)");
5252 			begin_scope();
5253 			statement("return quad_broadcast(value, lane);");
5254 			end_scope();
5255 			statement("");
5256 			statement("template<>");
5257 			statement("inline bool spvQuadBroadcast(bool value, uint lane)");
5258 			begin_scope();
5259 			statement("return !!quad_broadcast((ushort)value, lane);");
5260 			end_scope();
5261 			statement("");
5262 			statement("template<uint N>");
5263 			statement("inline vec<bool, N> spvQuadBroadcast(vec<bool, N> value, uint lane)");
5264 			begin_scope();
5265 			statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
5266 			end_scope();
5267 			statement("");
5268 			break;
5269 
5270 		case SPVFuncImplQuadSwap:
			// Emit spvQuadSwap(): implements quad swap (horizontal, vertical,
			// diagonal) as a single quad_shuffle_xor with mask dir + 1.
5271 			// We can implement this easily based on the following table giving
5272 			// the target lane ID from the direction and current lane ID:
5273 			//        Direction
5274 			//      | 0 | 1 | 2 |
5275 			//   ---+---+---+---+
5276 			// L 0  | 1   2   3
5277 			// a 1  | 0   3   2
5278 			// n 2  | 3   0   1
5279 			// e 3  | 2   1   0
5280 			// Notice that target = source ^ (direction + 1).
5281 			statement("template<typename T>");
5282 			statement("inline T spvQuadSwap(T value, uint dir)");
5283 			begin_scope();
5284 			statement("return quad_shuffle_xor(value, dir + 1);");
5285 			end_scope();
5286 			statement("");
5287 			statement("template<>");
5288 			statement("inline bool spvQuadSwap(bool value, uint dir)");
5289 			begin_scope();
5290 			statement("return !!quad_shuffle_xor((ushort)value, dir + 1);");
5291 			end_scope();
5292 			statement("");
5293 			statement("template<uint N>");
5294 			statement("inline vec<bool, N> spvQuadSwap(vec<bool, N> value, uint dir)");
5295 			begin_scope();
5296 			statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, dir + 1);");
5297 			end_scope();
5298 			statement("");
5299 			break;
5300 
5301 		case SPVFuncImplReflectScalar:
5302 			// Metal does not support scalar versions of these functions.
			// Emit spvReflect() for scalars: i - 2 * dot(n, i) * n, where for
			// scalars dot(n, i) == n * i, giving i - T(2) * i * n * n.
5303 			statement("template<typename T>");
5304 			statement("inline T spvReflect(T i, T n)");
5305 			begin_scope();
5306 			statement("return i - T(2) * i * n * n;");
5307 			end_scope();
5308 			statement("");
5309 			break;
5310 
5311 		case SPVFuncImplRefractScalar:
5312 			// Metal does not support scalar versions of these functions.
			// Emit spvRefract() for scalars using the standard GLSL refract
			// formula; k < 0 (total internal reflection) yields T(0).
5313 			statement("template<typename T>");
5314 			statement("inline T spvRefract(T i, T n, T eta)");
5315 			begin_scope();
5316 			statement("T NoI = n * i;");
5317 			statement("T NoI2 = NoI * NoI;");
5318 			statement("T k = T(1) - eta * eta * (T(1) - NoI2);");
5319 			statement("if (k < T(0))");
5320 			begin_scope();
5321 			statement("return T(0);");
5322 			end_scope();
5323 			statement("else");
5324 			begin_scope();
5325 			statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
5326 			end_scope();
5327 			end_scope();
5328 			statement("");
5329 			break;
5330 
5331 		case SPVFuncImplFaceForwardScalar:
5332 			// Metal does not support scalar versions of these functions.
			// Emit spvFaceForward(): returns n when dot(nref, i) < 0
			// (scalar dot is i * nref), otherwise -n — GLSL faceforward().
5333 			statement("template<typename T>");
5334 			statement("inline T spvFaceForward(T n, T i, T nref)");
5335 			begin_scope();
5336 			statement("return i * nref < T(0) ? n : -n;");
5337 			end_scope();
5338 			statement("");
5339 			break;
5340 
5341 		case SPVFuncImplChromaReconstructNearest2Plane:
			// Emit spvChromaReconstructNearest() for a two-plane Y'CbCr
			// texture: luma (g) from plane0's .r, chroma (br) from plane1's
			// .rg, sampled directly with no filtering between chroma samples.
			// Alpha is fixed to 1. LodOptions forwards any sampler lod args.
5342 			statement("template<typename T, typename... LodOptions>");
5343 			statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, sampler "
5344 			          "samp, float2 coord, LodOptions... options)");
5345 			begin_scope();
5346 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5347 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5348 			statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
5349 			statement("return ycbcr;");
5350 			end_scope();
5351 			statement("");
5352 			break;
5353 
5354 		case SPVFuncImplChromaReconstructNearest3Plane:
			// Three-plane variant: Cb (b) comes from plane1's .r and
			// Cr (r) from plane2's .r; otherwise identical to the two-plane
			// overload above.
5355 			statement("template<typename T, typename... LodOptions>");
5356 			statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, "
5357 			          "texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5358 			begin_scope();
5359 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5360 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5361 			statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5362 			statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5363 			statement("return ycbcr;");
5364 			end_scope();
5365 			statement("");
5366 			break;
5367 
5368 		case SPVFuncImplChromaReconstructLinear422CositedEven2Plane:
			// Emit spvChromaReconstructLinear422CositedEven() (two-plane):
			// 4:2:2 chroma, cosited-even X location. When coord lands between
			// two chroma samples (fractional chroma-plane x coordinate), the
			// chroma value is the 0.5/0.5 mix of the sample and its +x
			// neighbour; otherwise it is sampled directly.
5369 			statement("template<typename T, typename... LodOptions>");
5370 			statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
5371 			          "plane1, sampler samp, float2 coord, LodOptions... options)");
5372 			begin_scope();
5373 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5374 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5375 			statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
5376 			begin_scope();
5377 			statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5378 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).rg);");
5379 			end_scope();
5380 			statement("else");
5381 			begin_scope();
5382 			statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
5383 			end_scope();
5384 			statement("return ycbcr;");
5385 			end_scope();
5386 			statement("");
5387 			break;
5388 
5389 		case SPVFuncImplChromaReconstructLinear422CositedEven3Plane:
			// Three-plane variant of the above: the same 0.5 mix is applied
			// separately to plane1 (Cb -> b) and plane2 (Cr -> r).
5390 			statement("template<typename T, typename... LodOptions>");
5391 			statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
5392 			          "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5393 			begin_scope();
5394 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5395 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5396 			statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
5397 			begin_scope();
5398 			statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5399 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
5400 			statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5401 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
5402 			end_scope();
5403 			statement("else");
5404 			begin_scope();
5405 			statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5406 			statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5407 			end_scope();
5408 			statement("return ycbcr;");
5409 			end_scope();
5410 			statement("");
5411 			break;
5412 
5413 		case SPVFuncImplChromaReconstructLinear422Midpoint2Plane:
			// Emit spvChromaReconstructLinear422Midpoint() (two-plane):
			// 4:2:2 chroma with midpoint X location. The neighbour offset is
			// +1 or -1 in x depending on which side of the chroma sample the
			// coordinate falls, and the mix weight is 0.25 toward that
			// neighbour.
5414 			statement("template<typename T, typename... LodOptions>");
5415 			statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
5416 			          "plane1, sampler samp, float2 coord, LodOptions... options)");
5417 			begin_scope();
5418 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5419 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5420 			statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
5421 			statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5422 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).rg);");
5423 			statement("return ycbcr;");
5424 			end_scope();
5425 			statement("");
5426 			break;
5427 
5428 		case SPVFuncImplChromaReconstructLinear422Midpoint3Plane:
			// Three-plane variant: the same 0.25 mix is applied separately
			// to plane1 (Cb -> b) and plane2 (Cr -> r).
5429 			statement("template<typename T, typename... LodOptions>");
5430 			statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
5431 			          "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5432 			begin_scope();
5433 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5434 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5435 			statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
5436 			statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5437 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
5438 			statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5439 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
5440 			statement("return ycbcr;");
5441 			end_scope();
5442 			statement("");
5443 			break;
5444 
5445 		case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane:
			// Emit spvChromaReconstructLinear420XCositedEvenYCositedEven()
			// (two-plane): 4:2:0 chroma, cosited-even in both axes. 'ab'
			// holds the bilinear weights derived from the luma-plane texel
			// position; chroma is a 2x2 bilinear mix of plane1 at offsets
			// (0,0), (1,0), (0,1), (1,1).
5446 			statement("template<typename T, typename... LodOptions>");
5447 			statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
5448 			          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5449 			begin_scope();
5450 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5451 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5452 			statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
5453 			statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5454 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5455 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5456 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5457 			statement("return ycbcr;");
5458 			end_scope();
5459 			statement("");
5460 			break;
5461 
5462 		case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane:
			// Three-plane variant: the same 2x2 bilinear mix is applied
			// separately to plane1 (Cb -> b) and plane2 (Cr -> r).
5463 			statement("template<typename T, typename... LodOptions>");
5464 			statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
5465 			          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5466 			begin_scope();
5467 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5468 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5469 			statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
5470 			statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5471 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5472 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5473 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5474 			statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5475 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5476 			          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5477 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5478 			statement("return ycbcr;");
5479 			end_scope();
5480 			statement("");
5481 			break;
5482 
5483 		case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane:
			// Emit spvChromaReconstructLinear420XMidpointYCositedEven()
			// (two-plane): as the cosited/cosited case, but the bilinear
			// weights are shifted by half a luma texel in x (the
			// float2(0.5, 0) bias) for the midpoint X chroma location.
5484 			statement("template<typename T, typename... LodOptions>");
5485 			statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
5486 			          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5487 			begin_scope();
5488 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5489 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5490 			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5491 			          "0)) * 0.5);");
5492 			statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5493 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5494 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5495 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5496 			statement("return ycbcr;");
5497 			end_scope();
5498 			statement("");
5499 			break;
5500 
5501 		case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane:
			// Three-plane variant: same half-texel x bias, mix applied
			// separately to plane1 (Cb -> b) and plane2 (Cr -> r).
5502 			statement("template<typename T, typename... LodOptions>");
5503 			statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
5504 			          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5505 			begin_scope();
5506 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5507 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5508 			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5509 			          "0)) * 0.5);");
5510 			statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5511 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5512 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5513 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5514 			statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5515 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5516 			          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5517 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5518 			statement("return ycbcr;");
5519 			end_scope();
5520 			statement("");
5521 			break;
5522 
5523 		case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane:
			// Emit spvChromaReconstructLinear420XCositedEvenYMidpoint()
			// (two-plane): as the cosited/cosited case, but the bilinear
			// weights are shifted by half a luma texel in y (the
			// float2(0, 0.5) bias) for the midpoint Y chroma location.
5524 			statement("template<typename T, typename... LodOptions>");
5525 			statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
5526 			          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5527 			begin_scope();
5528 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5529 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5530 			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
5531 			          "0.5)) * 0.5);");
5532 			statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5533 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5534 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5535 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5536 			statement("return ycbcr;");
5537 			end_scope();
5538 			statement("");
5539 			break;
5540 
5541 		case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane:
			// Three-plane variant: same half-texel y bias, mix applied
			// separately to plane1 (Cb -> b) and plane2 (Cr -> r).
5542 			statement("template<typename T, typename... LodOptions>");
5543 			statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
5544 			          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5545 			begin_scope();
5546 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5547 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5548 			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
5549 			          "0.5)) * 0.5);");
5550 			statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5551 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5552 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5553 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5554 			statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5555 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5556 			          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5557 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5558 			statement("return ycbcr;");
5559 			end_scope();
5560 			statement("");
5561 			break;
5562 
5563 		case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane:
			// Emit spvChromaReconstructLinear420XMidpointYMidpoint()
			// (two-plane): midpoint chroma location in both axes — the
			// bilinear weights are shifted by half a luma texel in both x
			// and y (the float2(0.5, 0.5) bias).
5564 			statement("template<typename T, typename... LodOptions>");
5565 			statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
5566 			          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5567 			begin_scope();
5568 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5569 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5570 			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5571 			          "0.5)) * 0.5);");
5572 			statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5573 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5574 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5575 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5576 			statement("return ycbcr;");
5577 			end_scope();
5578 			statement("");
5579 			break;
5580 
5581 		case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane:
			// Three-plane variant: same half-texel x and y bias, mix applied
			// separately to plane1 (Cb -> b) and plane2 (Cr -> r).
5582 			statement("template<typename T, typename... LodOptions>");
5583 			statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
5584 			          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5585 			begin_scope();
5586 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5587 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5588 			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5589 			          "0.5)) * 0.5);");
5590 			statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5591 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5592 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5593 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5594 			statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5595 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5596 			          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5597 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5598 			statement("return ycbcr;");
5599 			end_scope();
5600 			statement("");
5601 			break;
5602 
5603 		case SPVFuncImplExpandITUFullRange:
			// Emit spvExpandITUFullRange(): for n-bit full-range encoding,
			// re-centres the chroma channels by subtracting
			// 2^(n-1) / (2^n - 1); luma is left untouched.
5604 			statement("template<typename T>");
5605 			statement("inline vec<T, 4> spvExpandITUFullRange(vec<T, 4> ycbcr, int n)");
5606 			begin_scope();
5607 			statement("ycbcr.br -= exp2(T(n-1))/(exp2(T(n))-1);");
5608 			statement("return ycbcr;");
5609 			end_scope();
5610 			statement("");
5611 			break;
5612 
5613 		case SPVFuncImplExpandITUNarrowRange:
			// Emit spvExpandITUNarrowRange(): expands n-bit ITU narrow-range
			// encoding to full range. The constants 16/219 (luma) and
			// 128/224 (chroma) are the standard 8-bit narrow-range
			// offset/scale pairs, shifted to n bits via ldexp(..., n - 8).
5614 			statement("template<typename T>");
5615 			statement("inline vec<T, 4> spvExpandITUNarrowRange(vec<T, 4> ycbcr, int n)");
5616 			begin_scope();
5617 			statement("ycbcr.g = (ycbcr.g * (exp2(T(n)) - 1) - ldexp(T(16), n - 8))/ldexp(T(219), n - 8);");
5618 			statement("ycbcr.br = (ycbcr.br * (exp2(T(n)) - 1) - ldexp(T(128), n - 8))/ldexp(T(224), n - 8);");
5619 			statement("return ycbcr;");
5620 			end_scope();
5621 			statement("");
5622 			break;
5623 
5624 		case SPVFuncImplConvertYCbCrBT709:
			// Emit spvConvertYCbCrBT709(): a constant column-major 3x3
			// factor matrix plus a helper that multiplies it by ycbcr.gbr
			// (Y', Cb, Cr order) to produce RGB; alpha passes through.
5625 			statement("// cf. Khronos Data Format Specification, section 15.1.1");
5626 			statement("constant float3x3 spvBT709Factors = {{1, 1, 1}, {0, -0.13397432/0.7152, 1.8556}, {1.5748, "
5627 			          "-0.33480248/0.7152, 0}};");
5628 			statement("");
5629 			statement("template<typename T>");
5630 			statement("inline vec<T, 4> spvConvertYCbCrBT709(vec<T, 4> ycbcr)");
5631 			begin_scope();
5632 			statement("vec<T, 4> rgba;");
5633 			statement("rgba.rgb = vec<T, 3>(spvBT709Factors * ycbcr.gbr);");
5634 			statement("rgba.a = ycbcr.a;");
5635 			statement("return rgba;");
5636 			end_scope();
5637 			statement("");
5638 			break;
5639 
5640 		case SPVFuncImplConvertYCbCrBT601:
			// Emit spvConvertYCbCrBT601(): same structure as the BT.709
			// helper, with the BT.601 factor matrix.
5641 			statement("// cf. Khronos Data Format Specification, section 15.1.2");
5642 			statement("constant float3x3 spvBT601Factors = {{1, 1, 1}, {0, -0.202008/0.587, 1.772}, {1.402, "
5643 			          "-0.419198/0.587, 0}};");
5644 			statement("");
5645 			statement("template<typename T>");
5646 			statement("inline vec<T, 4> spvConvertYCbCrBT601(vec<T, 4> ycbcr)");
5647 			begin_scope();
5648 			statement("vec<T, 4> rgba;");
5649 			statement("rgba.rgb = vec<T, 3>(spvBT601Factors * ycbcr.gbr);");
5650 			statement("rgba.a = ycbcr.a;");
5651 			statement("return rgba;");
5652 			end_scope();
5653 			statement("");
5654 			break;
5655 
5656 		case SPVFuncImplConvertYCbCrBT2020:
			// Emit spvConvertYCbCrBT2020(): same structure as the BT.709
			// helper, with the BT.2020 factor matrix.
5657 			statement("// cf. Khronos Data Format Specification, section 15.1.3");
5658 			statement("constant float3x3 spvBT2020Factors = {{1, 1, 1}, {0, -0.11156702/0.6780, 1.8814}, {1.4746, "
5659 			          "-0.38737742/0.6780, 0}};");
5660 			statement("");
5661 			statement("template<typename T>");
5662 			statement("inline vec<T, 4> spvConvertYCbCrBT2020(vec<T, 4> ycbcr)");
5663 			begin_scope();
5664 			statement("vec<T, 4> rgba;");
5665 			statement("rgba.rgb = vec<T, 3>(spvBT2020Factors * ycbcr.gbr);");
5666 			statement("rgba.a = ycbcr.a;");
5667 			statement("return rgba;");
5668 			end_scope();
5669 			statement("");
5670 			break;
5671 
5672 		case SPVFuncImplDynamicImageSampler:
5673 			statement("enum class spvFormatResolution");
5674 			begin_scope();
5675 			statement("_444 = 0,");
5676 			statement("_422,");
5677 			statement("_420");
5678 			end_scope_decl();
5679 			statement("");
5680 			statement("enum class spvChromaFilter");
5681 			begin_scope();
5682 			statement("nearest = 0,");
5683 			statement("linear");
5684 			end_scope_decl();
5685 			statement("");
5686 			statement("enum class spvXChromaLocation");
5687 			begin_scope();
5688 			statement("cosited_even = 0,");
5689 			statement("midpoint");
5690 			end_scope_decl();
5691 			statement("");
5692 			statement("enum class spvYChromaLocation");
5693 			begin_scope();
5694 			statement("cosited_even = 0,");
5695 			statement("midpoint");
5696 			end_scope_decl();
5697 			statement("");
5698 			statement("enum class spvYCbCrModelConversion");
5699 			begin_scope();
5700 			statement("rgb_identity = 0,");
5701 			statement("ycbcr_identity,");
5702 			statement("ycbcr_bt_709,");
5703 			statement("ycbcr_bt_601,");
5704 			statement("ycbcr_bt_2020");
5705 			end_scope_decl();
5706 			statement("");
5707 			statement("enum class spvYCbCrRange");
5708 			begin_scope();
5709 			statement("itu_full = 0,");
5710 			statement("itu_narrow");
5711 			end_scope_decl();
5712 			statement("");
5713 			statement("struct spvComponentBits");
5714 			begin_scope();
5715 			statement("constexpr explicit spvComponentBits(int v) thread : value(v) {}");
5716 			statement("uchar value : 6;");
5717 			end_scope_decl();
5718 			statement("// A class corresponding to metal::sampler which holds sampler");
5719 			statement("// Y'CbCr conversion info.");
5720 			statement("struct spvYCbCrSampler");
5721 			begin_scope();
5722 			statement("constexpr spvYCbCrSampler() thread : val(build()) {}");
5723 			statement("template<typename... Ts>");
5724 			statement("constexpr spvYCbCrSampler(Ts... t) thread : val(build(t...)) {}");
5725 			statement("constexpr spvYCbCrSampler(const thread spvYCbCrSampler& s) thread = default;");
5726 			statement("");
5727 			statement("spvFormatResolution get_resolution() const thread");
5728 			begin_scope();
5729 			statement("return spvFormatResolution((val & resolution_mask) >> resolution_base);");
5730 			end_scope();
5731 			statement("spvChromaFilter get_chroma_filter() const thread");
5732 			begin_scope();
5733 			statement("return spvChromaFilter((val & chroma_filter_mask) >> chroma_filter_base);");
5734 			end_scope();
5735 			statement("spvXChromaLocation get_x_chroma_offset() const thread");
5736 			begin_scope();
5737 			statement("return spvXChromaLocation((val & x_chroma_off_mask) >> x_chroma_off_base);");
5738 			end_scope();
5739 			statement("spvYChromaLocation get_y_chroma_offset() const thread");
5740 			begin_scope();
5741 			statement("return spvYChromaLocation((val & y_chroma_off_mask) >> y_chroma_off_base);");
5742 			end_scope();
5743 			statement("spvYCbCrModelConversion get_ycbcr_model() const thread");
5744 			begin_scope();
5745 			statement("return spvYCbCrModelConversion((val & ycbcr_model_mask) >> ycbcr_model_base);");
5746 			end_scope();
5747 			statement("spvYCbCrRange get_ycbcr_range() const thread");
5748 			begin_scope();
5749 			statement("return spvYCbCrRange((val & ycbcr_range_mask) >> ycbcr_range_base);");
5750 			end_scope();
5751 			statement("int get_bpc() const thread { return (val & bpc_mask) >> bpc_base; }");
5752 			statement("");
5753 			statement("private:");
5754 			statement("ushort val;");
5755 			statement("");
5756 			statement("constexpr static constant ushort resolution_bits = 2;");
5757 			statement("constexpr static constant ushort chroma_filter_bits = 2;");
5758 			statement("constexpr static constant ushort x_chroma_off_bit = 1;");
5759 			statement("constexpr static constant ushort y_chroma_off_bit = 1;");
5760 			statement("constexpr static constant ushort ycbcr_model_bits = 3;");
5761 			statement("constexpr static constant ushort ycbcr_range_bit = 1;");
5762 			statement("constexpr static constant ushort bpc_bits = 6;");
5763 			statement("");
5764 			statement("constexpr static constant ushort resolution_base = 0;");
5765 			statement("constexpr static constant ushort chroma_filter_base = 2;");
5766 			statement("constexpr static constant ushort x_chroma_off_base = 4;");
5767 			statement("constexpr static constant ushort y_chroma_off_base = 5;");
5768 			statement("constexpr static constant ushort ycbcr_model_base = 6;");
5769 			statement("constexpr static constant ushort ycbcr_range_base = 9;");
5770 			statement("constexpr static constant ushort bpc_base = 10;");
5771 			statement("");
5772 			statement(
5773 			    "constexpr static constant ushort resolution_mask = ((1 << resolution_bits) - 1) << resolution_base;");
5774 			statement("constexpr static constant ushort chroma_filter_mask = ((1 << chroma_filter_bits) - 1) << "
5775 			          "chroma_filter_base;");
5776 			statement("constexpr static constant ushort x_chroma_off_mask = ((1 << x_chroma_off_bit) - 1) << "
5777 			          "x_chroma_off_base;");
5778 			statement("constexpr static constant ushort y_chroma_off_mask = ((1 << y_chroma_off_bit) - 1) << "
5779 			          "y_chroma_off_base;");
5780 			statement("constexpr static constant ushort ycbcr_model_mask = ((1 << ycbcr_model_bits) - 1) << "
5781 			          "ycbcr_model_base;");
5782 			statement("constexpr static constant ushort ycbcr_range_mask = ((1 << ycbcr_range_bit) - 1) << "
5783 			          "ycbcr_range_base;");
5784 			statement("constexpr static constant ushort bpc_mask = ((1 << bpc_bits) - 1) << bpc_base;");
5785 			statement("");
5786 			statement("static constexpr ushort build()");
5787 			begin_scope();
5788 			statement("return 0;");
5789 			end_scope();
5790 			statement("");
5791 			statement("template<typename... Ts>");
5792 			statement("static constexpr ushort build(spvFormatResolution res, Ts... t)");
5793 			begin_scope();
5794 			statement("return (ushort(res) << resolution_base) | (build(t...) & ~resolution_mask);");
5795 			end_scope();
5796 			statement("");
5797 			statement("template<typename... Ts>");
5798 			statement("static constexpr ushort build(spvChromaFilter filt, Ts... t)");
5799 			begin_scope();
5800 			statement("return (ushort(filt) << chroma_filter_base) | (build(t...) & ~chroma_filter_mask);");
5801 			end_scope();
5802 			statement("");
5803 			statement("template<typename... Ts>");
5804 			statement("static constexpr ushort build(spvXChromaLocation loc, Ts... t)");
5805 			begin_scope();
5806 			statement("return (ushort(loc) << x_chroma_off_base) | (build(t...) & ~x_chroma_off_mask);");
5807 			end_scope();
5808 			statement("");
5809 			statement("template<typename... Ts>");
5810 			statement("static constexpr ushort build(spvYChromaLocation loc, Ts... t)");
5811 			begin_scope();
5812 			statement("return (ushort(loc) << y_chroma_off_base) | (build(t...) & ~y_chroma_off_mask);");
5813 			end_scope();
5814 			statement("");
5815 			statement("template<typename... Ts>");
5816 			statement("static constexpr ushort build(spvYCbCrModelConversion model, Ts... t)");
5817 			begin_scope();
5818 			statement("return (ushort(model) << ycbcr_model_base) | (build(t...) & ~ycbcr_model_mask);");
5819 			end_scope();
5820 			statement("");
5821 			statement("template<typename... Ts>");
5822 			statement("static constexpr ushort build(spvYCbCrRange range, Ts... t)");
5823 			begin_scope();
5824 			statement("return (ushort(range) << ycbcr_range_base) | (build(t...) & ~ycbcr_range_mask);");
5825 			end_scope();
5826 			statement("");
5827 			statement("template<typename... Ts>");
5828 			statement("static constexpr ushort build(spvComponentBits bpc, Ts... t)");
5829 			begin_scope();
5830 			statement("return (ushort(bpc.value) << bpc_base) | (build(t...) & ~bpc_mask);");
5831 			end_scope();
5832 			end_scope_decl();
5833 			statement("");
5834 			statement("// A class which can hold up to three textures and a sampler, including");
5835 			statement("// Y'CbCr conversion info, used to pass combined image-samplers");
5836 			statement("// dynamically to functions.");
5837 			statement("template<typename T>");
5838 			statement("struct spvDynamicImageSampler");
5839 			begin_scope();
5840 			statement("texture2d<T> plane0;");
5841 			statement("texture2d<T> plane1;");
5842 			statement("texture2d<T> plane2;");
5843 			statement("sampler samp;");
5844 			statement("spvYCbCrSampler ycbcr_samp;");
5845 			statement("uint swizzle = 0;");
5846 			statement("");
5847 			if (msl_options.swizzle_texture_samples)
5848 			{
5849 				statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, uint sw) thread :");
5850 				statement("    plane0(tex), samp(samp), swizzle(sw) {}");
5851 			}
5852 			else
5853 			{
5854 				statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp) thread :");
5855 				statement("    plane0(tex), samp(samp) {}");
5856 			}
5857 			statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, spvYCbCrSampler ycbcr_samp, "
5858 			          "uint sw) thread :");
5859 			statement("    plane0(tex), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
5860 			statement("constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1,");
5861 			statement("                                 sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
5862 			statement("    plane0(plane0), plane1(plane1), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
5863 			statement(
5864 			    "constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1, texture2d<T> plane2,");
5865 			statement("                                 sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
5866 			statement("    plane0(plane0), plane1(plane1), plane2(plane2), samp(samp), ycbcr_samp(ycbcr_samp), "
5867 			          "swizzle(sw) {}");
5868 			statement("");
5869 			// XXX This is really hard to follow... I've left comments to make it a bit easier.
5870 			statement("template<typename... LodOptions>");
5871 			statement("vec<T, 4> do_sample(float2 coord, LodOptions... options) const thread");
5872 			begin_scope();
5873 			statement("if (!is_null_texture(plane1))");
5874 			begin_scope();
5875 			statement("if (ycbcr_samp.get_resolution() == spvFormatResolution::_444 ||");
5876 			statement("    ycbcr_samp.get_chroma_filter() == spvChromaFilter::nearest)");
5877 			begin_scope();
5878 			statement("if (!is_null_texture(plane2))");
5879 			statement("    return spvChromaReconstructNearest(plane0, plane1, plane2, samp, coord,");
5880 			statement("                                       spvForward<LodOptions>(options)...);");
5881 			statement(
5882 			    "return spvChromaReconstructNearest(plane0, plane1, samp, coord, spvForward<LodOptions>(options)...);");
5883 			end_scope(); // if (resolution == 422 || chroma_filter == nearest)
5884 			statement("switch (ycbcr_samp.get_resolution())");
5885 			begin_scope();
5886 			statement("case spvFormatResolution::_444: break;");
5887 			statement("case spvFormatResolution::_422:");
5888 			begin_scope();
5889 			statement("switch (ycbcr_samp.get_x_chroma_offset())");
5890 			begin_scope();
5891 			statement("case spvXChromaLocation::cosited_even:");
5892 			statement("    if (!is_null_texture(plane2))");
5893 			statement("        return spvChromaReconstructLinear422CositedEven(");
5894 			statement("            plane0, plane1, plane2, samp,");
5895 			statement("            coord, spvForward<LodOptions>(options)...);");
5896 			statement("    return spvChromaReconstructLinear422CositedEven(");
5897 			statement("        plane0, plane1, samp, coord,");
5898 			statement("        spvForward<LodOptions>(options)...);");
5899 			statement("case spvXChromaLocation::midpoint:");
5900 			statement("    if (!is_null_texture(plane2))");
5901 			statement("        return spvChromaReconstructLinear422Midpoint(");
5902 			statement("            plane0, plane1, plane2, samp,");
5903 			statement("            coord, spvForward<LodOptions>(options)...);");
5904 			statement("    return spvChromaReconstructLinear422Midpoint(");
5905 			statement("        plane0, plane1, samp, coord,");
5906 			statement("        spvForward<LodOptions>(options)...);");
5907 			end_scope(); // switch (x_chroma_offset)
5908 			end_scope(); // case 422:
5909 			statement("case spvFormatResolution::_420:");
5910 			begin_scope();
5911 			statement("switch (ycbcr_samp.get_x_chroma_offset())");
5912 			begin_scope();
5913 			statement("case spvXChromaLocation::cosited_even:");
5914 			begin_scope();
5915 			statement("switch (ycbcr_samp.get_y_chroma_offset())");
5916 			begin_scope();
5917 			statement("case spvYChromaLocation::cosited_even:");
5918 			statement("    if (!is_null_texture(plane2))");
5919 			statement("        return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
5920 			statement("            plane0, plane1, plane2, samp,");
5921 			statement("            coord, spvForward<LodOptions>(options)...);");
5922 			statement("    return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
5923 			statement("        plane0, plane1, samp, coord,");
5924 			statement("        spvForward<LodOptions>(options)...);");
5925 			statement("case spvYChromaLocation::midpoint:");
5926 			statement("    if (!is_null_texture(plane2))");
5927 			statement("        return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
5928 			statement("            plane0, plane1, plane2, samp,");
5929 			statement("            coord, spvForward<LodOptions>(options)...);");
5930 			statement("    return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
5931 			statement("        plane0, plane1, samp, coord,");
5932 			statement("        spvForward<LodOptions>(options)...);");
5933 			end_scope(); // switch (y_chroma_offset)
5934 			end_scope(); // case x::cosited_even:
5935 			statement("case spvXChromaLocation::midpoint:");
5936 			begin_scope();
5937 			statement("switch (ycbcr_samp.get_y_chroma_offset())");
5938 			begin_scope();
5939 			statement("case spvYChromaLocation::cosited_even:");
5940 			statement("    if (!is_null_texture(plane2))");
5941 			statement("        return spvChromaReconstructLinear420XMidpointYCositedEven(");
5942 			statement("            plane0, plane1, plane2, samp,");
5943 			statement("            coord, spvForward<LodOptions>(options)...);");
5944 			statement("    return spvChromaReconstructLinear420XMidpointYCositedEven(");
5945 			statement("        plane0, plane1, samp, coord,");
5946 			statement("        spvForward<LodOptions>(options)...);");
5947 			statement("case spvYChromaLocation::midpoint:");
5948 			statement("    if (!is_null_texture(plane2))");
5949 			statement("        return spvChromaReconstructLinear420XMidpointYMidpoint(");
5950 			statement("            plane0, plane1, plane2, samp,");
5951 			statement("            coord, spvForward<LodOptions>(options)...);");
5952 			statement("    return spvChromaReconstructLinear420XMidpointYMidpoint(");
5953 			statement("        plane0, plane1, samp, coord,");
5954 			statement("        spvForward<LodOptions>(options)...);");
5955 			end_scope(); // switch (y_chroma_offset)
5956 			end_scope(); // case x::midpoint
5957 			end_scope(); // switch (x_chroma_offset)
5958 			end_scope(); // case 420:
5959 			end_scope(); // switch (resolution)
5960 			end_scope(); // if (multiplanar)
5961 			statement("return plane0.sample(samp, coord, spvForward<LodOptions>(options)...);");
5962 			end_scope(); // do_sample()
5963 			statement("template <typename... LodOptions>");
5964 			statement("vec<T, 4> sample(float2 coord, LodOptions... options) const thread");
5965 			begin_scope();
5966 			statement(
5967 			    "vec<T, 4> s = spvTextureSwizzle(do_sample(coord, spvForward<LodOptions>(options)...), swizzle);");
5968 			statement("if (ycbcr_samp.get_ycbcr_model() == spvYCbCrModelConversion::rgb_identity)");
5969 			statement("    return s;");
5970 			statement("");
5971 			statement("switch (ycbcr_samp.get_ycbcr_range())");
5972 			begin_scope();
5973 			statement("case spvYCbCrRange::itu_full:");
5974 			statement("    s = spvExpandITUFullRange(s, ycbcr_samp.get_bpc());");
5975 			statement("    break;");
5976 			statement("case spvYCbCrRange::itu_narrow:");
5977 			statement("    s = spvExpandITUNarrowRange(s, ycbcr_samp.get_bpc());");
5978 			statement("    break;");
5979 			end_scope();
5980 			statement("");
5981 			statement("switch (ycbcr_samp.get_ycbcr_model())");
5982 			begin_scope();
5983 			statement("case spvYCbCrModelConversion::rgb_identity:"); // Silence Clang warning
5984 			statement("case spvYCbCrModelConversion::ycbcr_identity:");
5985 			statement("    return s;");
5986 			statement("case spvYCbCrModelConversion::ycbcr_bt_709:");
5987 			statement("    return spvConvertYCbCrBT709(s);");
5988 			statement("case spvYCbCrModelConversion::ycbcr_bt_601:");
5989 			statement("    return spvConvertYCbCrBT601(s);");
5990 			statement("case spvYCbCrModelConversion::ycbcr_bt_2020:");
5991 			statement("    return spvConvertYCbCrBT2020(s);");
5992 			end_scope();
5993 			end_scope();
5994 			statement("");
5995 			// Sampler Y'CbCr conversion forbids offsets.
5996 			statement("vec<T, 4> sample(float2 coord, int2 offset) const thread");
5997 			begin_scope();
5998 			if (msl_options.swizzle_texture_samples)
5999 				statement("return spvTextureSwizzle(plane0.sample(samp, coord, offset), swizzle);");
6000 			else
6001 				statement("return plane0.sample(samp, coord, offset);");
6002 			end_scope();
6003 			statement("template<typename lod_options>");
6004 			statement("vec<T, 4> sample(float2 coord, lod_options options, int2 offset) const thread");
6005 			begin_scope();
6006 			if (msl_options.swizzle_texture_samples)
6007 				statement("return spvTextureSwizzle(plane0.sample(samp, coord, options, offset), swizzle);");
6008 			else
6009 				statement("return plane0.sample(samp, coord, options, offset);");
6010 			end_scope();
6011 			statement("#if __HAVE_MIN_LOD_CLAMP__");
6012 			statement("vec<T, 4> sample(float2 coord, bias b, min_lod_clamp min_lod, int2 offset) const thread");
6013 			begin_scope();
6014 			statement("return plane0.sample(samp, coord, b, min_lod, offset);");
6015 			end_scope();
6016 			statement(
6017 			    "vec<T, 4> sample(float2 coord, gradient2d grad, min_lod_clamp min_lod, int2 offset) const thread");
6018 			begin_scope();
6019 			statement("return plane0.sample(samp, coord, grad, min_lod, offset);");
6020 			end_scope();
6021 			statement("#endif");
6022 			statement("");
6023 			// Y'CbCr conversion forbids all operations but sampling.
6024 			statement("vec<T, 4> read(uint2 coord, uint lod = 0) const thread");
6025 			begin_scope();
6026 			statement("return plane0.read(coord, lod);");
6027 			end_scope();
6028 			statement("");
6029 			statement("vec<T, 4> gather(float2 coord, int2 offset = int2(0), component c = component::x) const thread");
6030 			begin_scope();
6031 			if (msl_options.swizzle_texture_samples)
6032 				statement("return spvGatherSwizzle(plane0, samp, swizzle, c, coord, offset);");
6033 			else
6034 				statement("return plane0.gather(samp, coord, offset, c);");
6035 			end_scope();
6036 			end_scope_decl();
6037 			statement("");
6038 
6039 		default:
6040 			break;
6041 		}
6042 	}
6043 }
6044 
6045 // Undefined global memory is not allowed in MSL.
6046 // Declare constant and init to zeros. Use {}, as global constructors can break Metal.
declare_undefined_values()6047 void CompilerMSL::declare_undefined_values()
6048 {
6049 	bool emitted = false;
6050 	ir.for_each_typed_id<SPIRUndef>([&](uint32_t, SPIRUndef &undef) {
6051 		auto &type = this->get<SPIRType>(undef.basetype);
6052 		// OpUndef can be void for some reason ...
6053 		if (type.basetype == SPIRType::Void)
6054 			return;
6055 
6056 		statement("constant ", variable_decl(type, to_name(undef.self), undef.self), " = {};");
6057 		emitted = true;
6058 	});
6059 
6060 	if (emitted)
6061 		statement("");
6062 }
6063 
declare_constant_arrays()6064 void CompilerMSL::declare_constant_arrays()
6065 {
6066 	bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
6067 
6068 	// MSL cannot declare arrays inline (except when declaring a variable), so we must move them out to
6069 	// global constants directly, so we are able to use constants as variable expressions.
6070 	bool emitted = false;
6071 
6072 	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
6073 		if (c.specialization)
6074 			return;
6075 
6076 		auto &type = this->get<SPIRType>(c.constant_type);
6077 		// Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries.
6078 		// FIXME: However, hoisting constants to main() means we need to pass down constant arrays to leaf functions if they are used there.
6079 		// If there are multiple functions in the module, drop this case to avoid breaking use cases which do not need to
6080 		// link into Metal libraries. This is hacky.
6081 		if (!type.array.empty() && (!fully_inlined || is_scalar(type) || is_vector(type)))
6082 		{
6083 			auto name = to_name(c.self);
6084 			statement("constant ", variable_decl(type, name), " = ", constant_expression(c), ";");
6085 			emitted = true;
6086 		}
6087 	});
6088 
6089 	if (emitted)
6090 		statement("");
6091 }
6092 
6093 // Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries
declare_complex_constant_arrays()6094 void CompilerMSL::declare_complex_constant_arrays()
6095 {
6096 	// If we do not have a fully inlined module, we did not opt in to
6097 	// declaring constant arrays of complex types. See CompilerMSL::declare_constant_arrays().
6098 	bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
6099 	if (!fully_inlined)
6100 		return;
6101 
6102 	// MSL cannot declare arrays inline (except when declaring a variable), so we must move them out to
6103 	// global constants directly, so we are able to use constants as variable expressions.
6104 	bool emitted = false;
6105 
6106 	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
6107 		if (c.specialization)
6108 			return;
6109 
6110 		auto &type = this->get<SPIRType>(c.constant_type);
6111 		if (!type.array.empty() && !(is_scalar(type) || is_vector(type)))
6112 		{
6113 			auto name = to_name(c.self);
6114 			statement("", variable_decl(type, name), " = ", constant_expression(c), ";");
6115 			emitted = true;
6116 		}
6117 	});
6118 
6119 	if (emitted)
6120 		statement("");
6121 }
6122 
emit_resources()6123 void CompilerMSL::emit_resources()
6124 {
6125 	declare_constant_arrays();
6126 	declare_undefined_values();
6127 
6128 	// Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created.
6129 	emit_interface_block(stage_out_var_id);
6130 	emit_interface_block(patch_stage_out_var_id);
6131 	emit_interface_block(stage_in_var_id);
6132 	emit_interface_block(patch_stage_in_var_id);
6133 }
6134 
// Emit declarations for the specialization Metal function constants
void CompilerMSL::emit_specialization_constants_and_structs()
{
	SpecializationConstant wg_x, wg_y, wg_z;
	ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
	bool emitted = false;

	// Track struct types already declared / already padded, so each is handled only once.
	unordered_set<uint32_t> declared_structs;
	unordered_set<uint32_t> aligned_structs;

	// First, we need to deal with scalar block layout.
	// It is possible that a struct may have to be placed at an alignment which does not match the innate alignment of the struct itself.
	// In that case, if such a case exists for a struct, we must force that all elements of the struct become packed_ types.
	// This makes the struct alignment as small as physically possible.
	// When we actually align the struct later, we can insert padding as necessary to make the packed members behave like normally aligned types.
	ir.for_each_typed_id<SPIRType>([&](uint32_t type_id, const SPIRType &type) {
		if (type.basetype == SPIRType::Struct &&
		    has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
			mark_scalar_layout_structs(type);
	});

	bool builtin_block_type_is_required = false;
	// Very special case. If gl_PerVertex is initialized as an array (tessellation)
	// we have to potentially emit the gl_PerVertex struct type so that we can emit a constant LUT.
	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
		auto &type = this->get<SPIRType>(c.constant_type);
		if (is_array(type) && has_decoration(type.self, DecorationBlock) && is_builtin_type(type))
			builtin_block_type_is_required = true;
	});

	// Very particular use of the soft loop lock.
	// align_struct may need to create custom types on the fly, but we don't care about
	// these types for purpose of iterating over them in ir.ids_for_type and friends.
	auto loop_lock = ir.create_loop_soft_lock();

	// Walk constants and types in declaration order so dependencies are emitted before users.
	for (auto &id_ : ir.ids_for_constant_or_type)
	{
		auto &id = ir.ids[id_];

		if (id.get_type() == TypeConstant)
		{
			auto &c = id.get<SPIRConstant>();

			if (c.self == workgroup_size_id)
			{
				// TODO: This can be expressed as a [[threads_per_threadgroup]] input semantic, but we need to know
				// the work group size at compile time in SPIR-V, and [[threads_per_threadgroup]] would need to be passed around as a global.
				// The work group size may be a specialization constant.
				statement("constant uint3 ", builtin_to_glsl(BuiltInWorkgroupSize, StorageClassWorkgroup),
				          " [[maybe_unused]] = ", constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
				emitted = true;
			}
			else if (c.specialization)
			{
				auto &type = get<SPIRType>(c.constant_type);
				string sc_type_name = type_to_glsl(type);
				string sc_name = to_name(c.self);
				// Temporary holds the raw [[function_constant]] value before defaulting.
				string sc_tmp_name = sc_name + "_tmp";

				// Function constants are only supported in MSL 1.2 and later.
				// If we don't support it just declare the "default" directly.
				// This "default" value can be overridden to the true specialization constant by the API user.
				// Specialization constants which are used as array length expressions cannot be function constants in MSL,
				// so just fall back to macros.
				if (msl_options.supports_msl_version(1, 2) && has_decoration(c.self, DecorationSpecId) &&
				    !c.is_used_as_array_length)
				{
					uint32_t constant_id = get_decoration(c.self, DecorationSpecId);
					// Only scalar, non-composite values can be function constants.
					statement("constant ", sc_type_name, " ", sc_tmp_name, " [[function_constant(", constant_id,
					          ")]];");
					// Use the SPIR-V default when the API did not override the function constant.
					statement("constant ", sc_type_name, " ", sc_name, " = is_function_constant_defined(", sc_tmp_name,
					          ") ? ", sc_tmp_name, " : ", constant_expression(c), ";");
				}
				else if (has_decoration(c.self, DecorationSpecId))
				{
					// Fallback to macro overrides.
					c.specialization_constant_macro_name =
					    constant_value_macro_name(get_decoration(c.self, DecorationSpecId));

					statement("#ifndef ", c.specialization_constant_macro_name);
					statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
					statement("#endif");
					statement("constant ", sc_type_name, " ", sc_name, " = ", c.specialization_constant_macro_name,
					          ";");
				}
				else
				{
					// Composite specialization constants must be built from other specialization constants.
					statement("constant ", sc_type_name, " ", sc_name, " = ", constant_expression(c), ";");
				}
				emitted = true;
			}
		}
		else if (id.get_type() == TypeConstantOp)
		{
			// Spec-constant operations (OpSpecConstantOp) become derived constant expressions.
			auto &c = id.get<SPIRConstantOp>();
			auto &type = get<SPIRType>(c.basetype);
			auto name = to_name(c.self);
			statement("constant ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
			emitted = true;
		}
		else if (id.get_type() == TypeType)
		{
			// Output non-builtin interface structs. These include local function structs
			// and structs nested within uniform and read-write buffers.
			auto &type = id.get<SPIRType>();
			TypeID type_id = type.self;

			bool is_struct = (type.basetype == SPIRType::Struct) && type.array.empty() && !type.pointer;
			bool is_block =
			    has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);

			bool is_builtin_block = is_block && is_builtin_type(type);
			bool is_declarable_struct = is_struct && (!is_builtin_block || builtin_block_type_is_required);

			// We'll declare this later.
			if (stage_out_var_id && get_stage_out_struct_type().self == type_id)
				is_declarable_struct = false;
			if (patch_stage_out_var_id && get_patch_stage_out_struct_type().self == type_id)
				is_declarable_struct = false;
			if (stage_in_var_id && get_stage_in_struct_type().self == type_id)
				is_declarable_struct = false;
			if (patch_stage_in_var_id && get_patch_stage_in_struct_type().self == type_id)
				is_declarable_struct = false;

			// Align and emit declarable structs...but avoid declaring each more than once.
			if (is_declarable_struct && declared_structs.count(type_id) == 0)
			{
				if (emitted)
					statement("");
				emitted = false;

				declared_structs.insert(type_id);

				// Insert explicit padding if scalar block layout forced packed members earlier.
				if (has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
					align_struct(type, aligned_structs);

				// Make sure we declare the underlying struct type, and not the "decorated" type with pointers, etc.
				emit_struct(get<SPIRType>(type_id));
			}
		}
	}

	if (emitted)
		statement("");
}
6282 
emit_binary_unord_op(uint32_t result_type,uint32_t result_id,uint32_t op0,uint32_t op1,const char * op)6283 void CompilerMSL::emit_binary_unord_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
6284                                        const char *op)
6285 {
6286 	bool forward = should_forward(op0) && should_forward(op1);
6287 	emit_op(result_type, result_id,
6288 	        join("(isunordered(", to_enclosed_unpacked_expression(op0), ", ", to_enclosed_unpacked_expression(op1),
6289 	             ") || ", to_enclosed_unpacked_expression(op0), " ", op, " ", to_enclosed_unpacked_expression(op1),
6290 	             ")"),
6291 	        forward);
6292 
6293 	inherit_expression_dependencies(result_id, op0);
6294 	inherit_expression_dependencies(result_id, op1);
6295 }
6296 
emit_tessellation_io_load(uint32_t result_type_id,uint32_t id,uint32_t ptr)6297 bool CompilerMSL::emit_tessellation_io_load(uint32_t result_type_id, uint32_t id, uint32_t ptr)
6298 {
6299 	auto &ptr_type = expression_type(ptr);
6300 	auto &result_type = get<SPIRType>(result_type_id);
6301 	if (ptr_type.storage != StorageClassInput && ptr_type.storage != StorageClassOutput)
6302 		return false;
6303 	if (ptr_type.storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationEvaluation)
6304 		return false;
6305 
6306 	bool multi_patch_tess_ctl = get_execution_model() == ExecutionModelTessellationControl &&
6307 	                            msl_options.multi_patch_workgroup && ptr_type.storage == StorageClassInput;
6308 	bool flat_matrix = is_matrix(result_type) && ptr_type.storage == StorageClassInput && !multi_patch_tess_ctl;
6309 	bool flat_struct = result_type.basetype == SPIRType::Struct && ptr_type.storage == StorageClassInput;
6310 	bool flat_data_type = flat_matrix || is_array(result_type) || flat_struct;
6311 	if (!flat_data_type)
6312 		return false;
6313 
6314 	if (has_decoration(ptr, DecorationPatch))
6315 		return false;
6316 
6317 	// Now, we must unflatten a composite type and take care of interleaving array access with gl_in/gl_out.
6318 	// Lots of painful code duplication since we *really* should not unroll these kinds of loads in entry point fixup
6319 	// unless we're forced to do this when the code is emitting inoptimal OpLoads.
6320 	string expr;
6321 
6322 	uint32_t interface_index = get_extended_decoration(ptr, SPIRVCrossDecorationInterfaceMemberIndex);
6323 	auto *var = maybe_get_backing_variable(ptr);
6324 	bool ptr_is_io_variable = ir.ids[ptr].get_type() == TypeVariable;
6325 	auto &expr_type = get_pointee_type(ptr_type.self);
6326 
6327 	const auto &iface_type = expression_type(stage_in_ptr_var_id);
6328 
6329 	if (result_type.array.size() > 2)
6330 	{
6331 		SPIRV_CROSS_THROW("Cannot load tessellation IO variables with more than 2 dimensions.");
6332 	}
6333 	else if (result_type.array.size() == 2)
6334 	{
6335 		if (!ptr_is_io_variable)
6336 			SPIRV_CROSS_THROW("Loading an array-of-array must be loaded directly from an IO variable.");
6337 		if (interface_index == uint32_t(-1))
6338 			SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6339 		if (result_type.basetype == SPIRType::Struct || flat_matrix)
6340 			SPIRV_CROSS_THROW("Cannot load array-of-array of composite type in tessellation IO.");
6341 
6342 		expr += type_to_glsl(result_type) + "({ ";
6343 		uint32_t num_control_points = to_array_size_literal(result_type, 1);
6344 		uint32_t base_interface_index = interface_index;
6345 
6346 		auto &sub_type = get<SPIRType>(result_type.parent_type);
6347 
6348 		for (uint32_t i = 0; i < num_control_points; i++)
6349 		{
6350 			expr += type_to_glsl(sub_type) + "({ ";
6351 			interface_index = base_interface_index;
6352 			uint32_t array_size = to_array_size_literal(result_type, 0);
6353 			if (multi_patch_tess_ctl)
6354 			{
6355 				for (uint32_t j = 0; j < array_size; j++)
6356 				{
6357 					const uint32_t indices[3] = { i, interface_index, j };
6358 
6359 					AccessChainMeta meta;
6360 					expr +=
6361 					    access_chain_internal(stage_in_ptr_var_id, indices, 3,
6362 					                          ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6363 					// If the expression has more vector components than the result type, insert
6364 					// a swizzle. This shouldn't happen normally on valid SPIR-V, but it might
6365 					// happen if we replace the type of an input variable.
6366 					if (!is_matrix(sub_type) && sub_type.basetype != SPIRType::Struct &&
6367 					    expr_type.vecsize > sub_type.vecsize)
6368 						expr += vector_swizzle(sub_type.vecsize, 0);
6369 
6370 					if (j + 1 < array_size)
6371 						expr += ", ";
6372 				}
6373 			}
6374 			else
6375 			{
6376 				for (uint32_t j = 0; j < array_size; j++, interface_index++)
6377 				{
6378 					const uint32_t indices[2] = { i, interface_index };
6379 
6380 					AccessChainMeta meta;
6381 					expr +=
6382 					    access_chain_internal(stage_in_ptr_var_id, indices, 2,
6383 					                          ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6384 					if (!is_matrix(sub_type) && sub_type.basetype != SPIRType::Struct &&
6385 					    expr_type.vecsize > sub_type.vecsize)
6386 						expr += vector_swizzle(sub_type.vecsize, 0);
6387 
6388 					if (j + 1 < array_size)
6389 						expr += ", ";
6390 				}
6391 			}
6392 			expr += " })";
6393 			if (i + 1 < num_control_points)
6394 				expr += ", ";
6395 		}
6396 		expr += " })";
6397 	}
6398 	else if (flat_struct)
6399 	{
6400 		bool is_array_of_struct = is_array(result_type);
6401 		if (is_array_of_struct && !ptr_is_io_variable)
6402 			SPIRV_CROSS_THROW("Loading array of struct from IO variable must come directly from IO variable.");
6403 
6404 		uint32_t num_control_points = 1;
6405 		if (is_array_of_struct)
6406 		{
6407 			num_control_points = to_array_size_literal(result_type, 0);
6408 			expr += type_to_glsl(result_type) + "({ ";
6409 		}
6410 
6411 		auto &struct_type = is_array_of_struct ? get<SPIRType>(result_type.parent_type) : result_type;
6412 		assert(struct_type.array.empty());
6413 
6414 		for (uint32_t i = 0; i < num_control_points; i++)
6415 		{
6416 			expr += type_to_glsl(struct_type) + "{ ";
6417 			for (uint32_t j = 0; j < uint32_t(struct_type.member_types.size()); j++)
6418 			{
6419 				// The base interface index is stored per variable for structs.
6420 				if (var)
6421 				{
6422 					interface_index =
6423 					    get_extended_member_decoration(var->self, j, SPIRVCrossDecorationInterfaceMemberIndex);
6424 				}
6425 
6426 				if (interface_index == uint32_t(-1))
6427 					SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6428 
6429 				const auto &mbr_type = get<SPIRType>(struct_type.member_types[j]);
6430 				const auto &expr_mbr_type = get<SPIRType>(expr_type.member_types[j]);
6431 				if (is_matrix(mbr_type) && ptr_type.storage == StorageClassInput && !multi_patch_tess_ctl)
6432 				{
6433 					expr += type_to_glsl(mbr_type) + "(";
6434 					for (uint32_t k = 0; k < mbr_type.columns; k++, interface_index++)
6435 					{
6436 						if (is_array_of_struct)
6437 						{
6438 							const uint32_t indices[2] = { i, interface_index };
6439 							AccessChainMeta meta;
6440 							expr += access_chain_internal(
6441 							    stage_in_ptr_var_id, indices, 2,
6442 							    ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6443 						}
6444 						else
6445 							expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6446 						if (expr_mbr_type.vecsize > mbr_type.vecsize)
6447 							expr += vector_swizzle(mbr_type.vecsize, 0);
6448 
6449 						if (k + 1 < mbr_type.columns)
6450 							expr += ", ";
6451 					}
6452 					expr += ")";
6453 				}
6454 				else if (is_array(mbr_type))
6455 				{
6456 					expr += type_to_glsl(mbr_type) + "({ ";
6457 					uint32_t array_size = to_array_size_literal(mbr_type, 0);
6458 					if (multi_patch_tess_ctl)
6459 					{
6460 						for (uint32_t k = 0; k < array_size; k++)
6461 						{
6462 							if (is_array_of_struct)
6463 							{
6464 								const uint32_t indices[3] = { i, interface_index, k };
6465 								AccessChainMeta meta;
6466 								expr += access_chain_internal(
6467 								    stage_in_ptr_var_id, indices, 3,
6468 								    ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6469 							}
6470 							else
6471 								expr += join(to_expression(ptr), ".", to_member_name(iface_type, interface_index), "[",
6472 								             k, "]");
6473 							if (expr_mbr_type.vecsize > mbr_type.vecsize)
6474 								expr += vector_swizzle(mbr_type.vecsize, 0);
6475 
6476 							if (k + 1 < array_size)
6477 								expr += ", ";
6478 						}
6479 					}
6480 					else
6481 					{
6482 						for (uint32_t k = 0; k < array_size; k++, interface_index++)
6483 						{
6484 							if (is_array_of_struct)
6485 							{
6486 								const uint32_t indices[2] = { i, interface_index };
6487 								AccessChainMeta meta;
6488 								expr += access_chain_internal(
6489 								    stage_in_ptr_var_id, indices, 2,
6490 								    ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6491 							}
6492 							else
6493 								expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6494 							if (expr_mbr_type.vecsize > mbr_type.vecsize)
6495 								expr += vector_swizzle(mbr_type.vecsize, 0);
6496 
6497 							if (k + 1 < array_size)
6498 								expr += ", ";
6499 						}
6500 					}
6501 					expr += " })";
6502 				}
6503 				else
6504 				{
6505 					if (is_array_of_struct)
6506 					{
6507 						const uint32_t indices[2] = { i, interface_index };
6508 						AccessChainMeta meta;
6509 						expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
6510 						                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT,
6511 						                              &meta);
6512 					}
6513 					else
6514 						expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6515 					if (expr_mbr_type.vecsize > mbr_type.vecsize)
6516 						expr += vector_swizzle(mbr_type.vecsize, 0);
6517 				}
6518 
6519 				if (j + 1 < struct_type.member_types.size())
6520 					expr += ", ";
6521 			}
6522 			expr += " }";
6523 			if (i + 1 < num_control_points)
6524 				expr += ", ";
6525 		}
6526 		if (is_array_of_struct)
6527 			expr += " })";
6528 	}
6529 	else if (flat_matrix)
6530 	{
6531 		bool is_array_of_matrix = is_array(result_type);
6532 		if (is_array_of_matrix && !ptr_is_io_variable)
6533 			SPIRV_CROSS_THROW("Loading array of matrix from IO variable must come directly from IO variable.");
6534 		if (interface_index == uint32_t(-1))
6535 			SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6536 
6537 		if (is_array_of_matrix)
6538 		{
6539 			// Loading a matrix from each control point.
6540 			uint32_t base_interface_index = interface_index;
6541 			uint32_t num_control_points = to_array_size_literal(result_type, 0);
6542 			expr += type_to_glsl(result_type) + "({ ";
6543 
6544 			auto &matrix_type = get_variable_element_type(get<SPIRVariable>(ptr));
6545 
6546 			for (uint32_t i = 0; i < num_control_points; i++)
6547 			{
6548 				interface_index = base_interface_index;
6549 				expr += type_to_glsl(matrix_type) + "(";
6550 				for (uint32_t j = 0; j < result_type.columns; j++, interface_index++)
6551 				{
6552 					const uint32_t indices[2] = { i, interface_index };
6553 
6554 					AccessChainMeta meta;
6555 					expr +=
6556 					    access_chain_internal(stage_in_ptr_var_id, indices, 2,
6557 					                          ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6558 					if (expr_type.vecsize > result_type.vecsize)
6559 						expr += vector_swizzle(result_type.vecsize, 0);
6560 					if (j + 1 < result_type.columns)
6561 						expr += ", ";
6562 				}
6563 				expr += ")";
6564 				if (i + 1 < num_control_points)
6565 					expr += ", ";
6566 			}
6567 
6568 			expr += " })";
6569 		}
6570 		else
6571 		{
6572 			expr += type_to_glsl(result_type) + "(";
6573 			for (uint32_t i = 0; i < result_type.columns; i++, interface_index++)
6574 			{
6575 				expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6576 				if (expr_type.vecsize > result_type.vecsize)
6577 					expr += vector_swizzle(result_type.vecsize, 0);
6578 				if (i + 1 < result_type.columns)
6579 					expr += ", ";
6580 			}
6581 			expr += ")";
6582 		}
6583 	}
6584 	else if (ptr_is_io_variable)
6585 	{
6586 		assert(is_array(result_type));
6587 		assert(result_type.array.size() == 1);
6588 		if (interface_index == uint32_t(-1))
6589 			SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6590 
6591 		// We're loading an array directly from a global variable.
6592 		// This means we're loading one member from each control point.
6593 		expr += type_to_glsl(result_type) + "({ ";
6594 		uint32_t num_control_points = to_array_size_literal(result_type, 0);
6595 
6596 		for (uint32_t i = 0; i < num_control_points; i++)
6597 		{
6598 			const uint32_t indices[2] = { i, interface_index };
6599 
6600 			AccessChainMeta meta;
6601 			expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
6602 			                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6603 			if (expr_type.vecsize > result_type.vecsize)
6604 				expr += vector_swizzle(result_type.vecsize, 0);
6605 
6606 			if (i + 1 < num_control_points)
6607 				expr += ", ";
6608 		}
6609 		expr += " })";
6610 	}
6611 	else
6612 	{
6613 		// We're loading an array from a concrete control point.
6614 		assert(is_array(result_type));
6615 		assert(result_type.array.size() == 1);
6616 		if (interface_index == uint32_t(-1))
6617 			SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6618 
6619 		expr += type_to_glsl(result_type) + "({ ";
6620 		uint32_t array_size = to_array_size_literal(result_type, 0);
6621 		for (uint32_t i = 0; i < array_size; i++, interface_index++)
6622 		{
6623 			expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6624 			if (expr_type.vecsize > result_type.vecsize)
6625 				expr += vector_swizzle(result_type.vecsize, 0);
6626 			if (i + 1 < array_size)
6627 				expr += ", ";
6628 		}
6629 		expr += " })";
6630 	}
6631 
6632 	emit_op(result_type_id, id, expr, false);
6633 	register_read(id, ptr, false);
6634 	return true;
6635 }
6636 
// Emits a remapped access chain for tessellation I/O variables whose members were
// flattened into the stage-in/stage-out interface arrays (gl_in/gl_out).
// ops/length describe an OpAccessChain-style instruction:
//   ops[0] = result type id, ops[1] = result id, ops[2] = base pointer, ops[3...] = indices.
// Returns true if this function fully handled the access chain; false tells the
// caller to fall back to the default access-chain emission.
bool CompilerMSL::emit_tessellation_access_chain(const uint32_t *ops, uint32_t length)
{
	// If this is a per-vertex output, remap it to the I/O array buffer.

	// Any object which did not go through IO flattening shenanigans will go there instead.
	// We will unflatten on-demand instead as needed, but not all possible cases can be supported, especially with arrays.

	auto *var = maybe_get_backing_variable(ops[2]);
	bool patch = false;
	bool flat_data = false;
	bool ptr_is_chain = false;
	bool multi_patch = get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup;

	if (var)
	{
		// Per-patch data is not per-control-point, so it is not remapped through gl_in/gl_out.
		patch = has_decoration(ops[2], DecorationPatch) || is_patch_block(get_variable_data_type(*var));

		// Should match strip_array in add_interface_block.
		flat_data = var->storage == StorageClassInput ||
		            (var->storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationControl);

		// We might have a chained access chain, where
		// we first take the access chain to the control point, and then we chain into a member or something similar.
		// In this case, we need to skip gl_in/gl_out remapping.
		ptr_is_chain = var->self != ID(ops[2]);
	}

	BuiltIn bi_type = BuiltIn(get_decoration(ops[2], DecorationBuiltIn));
	// Only remap variables that went through IO flattening: user interface variables,
	// the handful of builtins that live in gl_in/gl_out, and struct-typed blocks.
	if (var && flat_data && !patch &&
	    (!is_builtin_variable(*var) || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
	     bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
	     get_variable_data_type(*var).basetype == SPIRType::Struct))
	{
		AccessChainMeta meta;
		SmallVector<uint32_t> indices;
		// Reserve one fresh ID for the synthesized member-index constant below.
		uint32_t next_id = ir.increase_bound_by(1);

		indices.reserve(length - 3 + 1);

		// For a fresh chain, ops[3] is the control-point index; for a chained access
		// chain the control point was already resolved, so member indices start at ops[3].
		uint32_t first_non_array_index = ptr_is_chain ? 3 : 4;
		VariableID stage_var_id = var->storage == StorageClassInput ? stage_in_ptr_var_id : stage_out_ptr_var_id;
		VariableID ptr = ptr_is_chain ? VariableID(ops[2]) : stage_var_id;
		if (!ptr_is_chain)
		{
			// Index into gl_in/gl_out with first array index.
			indices.push_back(ops[3]);
		}

		auto &result_ptr_type = get<SPIRType>(ops[0]);

		uint32_t const_mbr_id = next_id++;
		// Base member index into the flattened IO block, recorded during interface construction.
		uint32_t index = get_extended_decoration(var->self, SPIRVCrossDecorationInterfaceMemberIndex);
		if (var->storage == StorageClassInput || has_decoration(get_variable_element_type(*var).self, DecorationBlock))
		{
			uint32_t i = first_non_array_index;
			auto *type = &get_variable_element_type(*var);
			if (index == uint32_t(-1) && length >= (first_non_array_index + 1))
			{
				// Maybe this is a struct type in the input class, in which case
				// we put it as a decoration on the corresponding member.
				index = get_extended_member_decoration(var->self, get_constant(ops[first_non_array_index]).scalar(),
				                                       SPIRVCrossDecorationInterfaceMemberIndex);
				assert(index != uint32_t(-1));
				i++;
				type = &get<SPIRType>(type->member_types[get_constant(ops[first_non_array_index]).scalar()]);
			}

			// In this case, we're poking into flattened structures and arrays, so now we have to
			// combine the following indices. If we encounter a non-constant index,
			// we're hosed.
			for (; i < length; ++i)
			{
				// Stop once the remaining type is a plain scalar/vector (or, in multi-patch
				// mode, anything non-struct): further indices index into real MSL types.
				if ((multi_patch || (!is_array(*type) && !is_matrix(*type))) && type->basetype != SPIRType::Struct)
					break;

				auto *c = maybe_get<SPIRConstant>(ops[i]);
				if (!c || c->specialization)
					SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable in tessellation. "
					                  "This is currently unsupported.");

				// We're in flattened space, so just increment the member index into IO block.
				// We can only do this once in the current implementation, so either:
				// Struct, Matrix or 1-dimensional array for a control point.
				index += c->scalar();

				// Descend one level in the type hierarchy to mirror the consumed index.
				if (type->parent_type)
					type = &get<SPIRType>(type->parent_type);
				else if (type->basetype == SPIRType::Struct)
					type = &get<SPIRType>(type->member_types[c->scalar()]);
			}

			if ((!multi_patch && (is_matrix(result_ptr_type) || is_array(result_ptr_type))) ||
			    result_ptr_type.basetype == SPIRType::Struct)
			{
				// We're not going to emit the actual member name, we let any further OpLoad take care of that.
				// Tag the access chain with the member index we're referencing.
				set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, index);
			}
			else
			{
				// Access the appropriate member of gl_in/gl_out.
				set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
				indices.push_back(const_mbr_id);

				// Append any straggling access chain indices.
				if (i < length)
					indices.insert(indices.end(), ops + i, ops + length);
			}
		}
		else
		{
			// Non-block output: the variable's own member index must be known.
			assert(index != uint32_t(-1));
			set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
			indices.push_back(const_mbr_id);

			indices.insert(indices.end(), ops + 4, ops + length);
		}

		// We use the pointer to the base of the input/output array here,
		// so this is always a pointer chain.
		string e;

		if (!ptr_is_chain)
		{
			// This is the start of an access chain, use ptr_chain to index into control point array.
			e = access_chain(ptr, indices.data(), uint32_t(indices.size()), result_ptr_type, &meta, true);
		}
		else
		{
			// If we're accessing a struct, we need to use member indices which are based on the IO block,
			// not actual struct type, so we have to use a split access chain here where
			// first path resolves the control point index, i.e. gl_in[index], and second half deals with
			// looking up flattened member name.

			// However, it is possible that we partially accessed a struct,
			// by taking pointer to member inside the control-point array.
			// For this case, we fall back to a natural access chain since we have already dealt with remapping struct members.
			// One way to check this here is if we have 2 implied read expressions.
			// First one is the gl_in/gl_out struct itself, then an index into that array.
			// If we have traversed further, we use a normal access chain formulation.
			auto *ptr_expr = maybe_get<SPIRExpression>(ptr);
			if (ptr_expr && ptr_expr->implied_read_expressions.size() == 2)
			{
				e = join(to_expression(ptr),
				         access_chain_internal(stage_var_id, indices.data(), uint32_t(indices.size()),
				                               ACCESS_CHAIN_CHAIN_ONLY_BIT, &meta));
			}
			else
			{
				e = access_chain_internal(ptr, indices.data(), uint32_t(indices.size()), 0, &meta);
			}
		}

		// Get the actual type of the object that was accessed. If it's a vector type and we changed it,
		// then we'll need to add a swizzle.
		// For this, we can't necessarily rely on the type of the base expression, because it might be
		// another access chain, and it will therefore already have the "correct" type.
		auto *expr_type = &get_variable_data_type(*var);
		if (has_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID))
			expr_type = &get<SPIRType>(get_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID));
		for (uint32_t i = 3; i < length; i++)
		{
			if (!is_array(*expr_type) && expr_type->basetype == SPIRType::Struct)
				expr_type = &get<SPIRType>(expr_type->member_types[get<SPIRConstant>(ops[i]).scalar()]);
			else
				expr_type = &get<SPIRType>(expr_type->parent_type);
		}
		// If the interface type was widened relative to what the shader declared,
		// swizzle back down to the declared component count.
		if (!is_array(*expr_type) && !is_matrix(*expr_type) && expr_type->basetype != SPIRType::Struct &&
		    expr_type->vecsize > result_ptr_type.vecsize)
			e += vector_swizzle(result_ptr_type.vecsize, 0);

		auto &expr = set<SPIRExpression>(ops[1], move(e), ops[0], should_forward(ops[2]));
		expr.loaded_from = var->self;
		expr.need_transpose = meta.need_transpose;
		expr.access_chain = true;

		// Mark the result as being packed if necessary.
		if (meta.storage_is_packed)
			set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypePacked);
		if (meta.storage_physical_type != 0)
			set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypeID, meta.storage_physical_type);
		if (meta.storage_is_invariant)
			set_decoration(ops[1], DecorationInvariant);
		// Save the type we found in case the result is used in another access chain.
		set_extended_decoration(ops[1], SPIRVCrossDecorationTessIOOriginalInputTypeID, expr_type->self);

		// If we have some expression dependencies in our access chain, this access chain is technically a forwarded
		// temporary which could be subject to invalidation.
		// Need to assume we're forwarded while calling inherit_expression_dependencies.
		forwarded_temporaries.insert(ops[1]);
		// The access chain itself is never forced to a temporary, but its dependencies might.
		suppressed_usage_tracking.insert(ops[1]);

		for (uint32_t i = 2; i < length; i++)
		{
			inherit_expression_dependencies(ops[1], ops[i]);
			add_implied_read_expression(expr, ops[i]);
		}

		// If we have no dependencies after all, i.e., all indices in the access chain are immutable temporaries,
		// we're not forwarded after all.
		if (expr.expression_dependencies.empty())
			forwarded_temporaries.erase(ops[1]);

		return true;
	}

	// If this is the inner tessellation level, and we're tessellating triangles,
	// drop the last index. It isn't an array in this case, so we can't have an
	// array reference here. We need to make this ID a variable instead of an
	// expression so we don't try to dereference it as a variable pointer.
	// Don't do this if the index is a constant 1, though. We need to drop stores
	// to that one.
	auto *m = ir.find_meta(var ? var->self : ID(0));
	if (get_execution_model() == ExecutionModelTessellationControl && var && m &&
	    m->decoration.builtin_type == BuiltInTessLevelInner && get_entry_point().flags.get(ExecutionModeTriangles))
	{
		auto *c = maybe_get<SPIRConstant>(ops[3]);
		if (c && c->scalar() == 1)
			return false;
		// Alias the result ID to the variable itself (dropping the index) and copy
		// its metadata so later references treat it as the plain builtin variable.
		auto &dest_var = set<SPIRVariable>(ops[1], *var);
		dest_var.basetype = ops[0];
		ir.meta[ops[1]] = ir.meta[ops[2]];
		inherit_expression_dependencies(ops[1], ops[2]);
		return true;
	}

	return false;
}
6866 
is_out_of_bounds_tessellation_level(uint32_t id_lhs)6867 bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
6868 {
6869 	if (!get_entry_point().flags.get(ExecutionModeTriangles))
6870 		return false;
6871 
6872 	// In SPIR-V, TessLevelInner always has two elements and TessLevelOuter always has
6873 	// four. This is true even if we are tessellating triangles. This allows clients
6874 	// to use a single tessellation control shader with multiple tessellation evaluation
6875 	// shaders.
6876 	// In Metal, however, only the first element of TessLevelInner and the first three
6877 	// of TessLevelOuter are accessible. This stems from how in Metal, the tessellation
6878 	// levels must be stored to a dedicated buffer in a particular format that depends
6879 	// on the patch type. Therefore, in Triangles mode, any access to the second
6880 	// inner level or the fourth outer level must be dropped.
6881 	const auto *e = maybe_get<SPIRExpression>(id_lhs);
6882 	if (!e || !e->access_chain)
6883 		return false;
6884 	BuiltIn builtin = BuiltIn(get_decoration(e->loaded_from, DecorationBuiltIn));
6885 	if (builtin != BuiltInTessLevelInner && builtin != BuiltInTessLevelOuter)
6886 		return false;
6887 	auto *c = maybe_get<SPIRConstant>(e->implied_read_expressions[1]);
6888 	if (!c)
6889 		return false;
6890 	return (builtin == BuiltInTessLevelInner && c->scalar() == 1) ||
6891 	       (builtin == BuiltInTessLevelOuter && c->scalar() == 3);
6892 }
6893 
prepare_access_chain_for_scalar_access(std::string & expr,const SPIRType & type,spv::StorageClass storage,bool & is_packed)6894 void CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
6895                                                          spv::StorageClass storage, bool &is_packed)
6896 {
6897 	// If there is any risk of writes happening with the access chain in question,
6898 	// and there is a risk of concurrent write access to other components,
6899 	// we must cast the access chain to a plain pointer to ensure we only access the exact scalars we expect.
6900 	// The MSL compiler refuses to allow component-level access for any non-packed vector types.
6901 	if (!is_packed && (storage == StorageClassStorageBuffer || storage == StorageClassWorkgroup))
6902 	{
6903 		const char *addr_space = storage == StorageClassWorkgroup ? "threadgroup" : "device";
6904 		expr = join("((", addr_space, " ", type_to_glsl(type), "*)&", enclose_expression(expr), ")");
6905 
6906 		// Further indexing should happen with packed rules (array index, not swizzle).
6907 		is_packed = true;
6908 	}
6909 }
6910 
// Sets the interface member index for an access chain to a pull-model interpolant.
// ops/length describe an OpAccessChain instruction (ops[1] = result id,
// ops[2] = base pointer, ops[3...] = indices). The computed index is stored as
// SPIRVCrossDecorationInterfaceMemberIndex on the result so the interpolation
// function call can recover it later. No-op for variables that are not
// pull-model inputs.
void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length)
{
	auto *var = maybe_get_backing_variable(ops[2]);
	if (!var || !pull_model_inputs.count(var->self))
		return;
	// Get the base index.
	uint32_t interface_index;
	auto &var_type = get_variable_data_type(*var);
	auto &result_type = get<SPIRType>(ops[0]);
	auto *type = &var_type;
	if (has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex))
	{
		// The base pointer already carries a resolved interface index (e.g. from an
		// earlier access chain); start from it.
		interface_index = get_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex);
	}
	else
	{
		// Assume an access chain into a struct variable.
		assert(var_type.basetype == SPIRType::Struct);
		// The struct member index follows any array indices in the chain.
		auto &c = get<SPIRConstant>(ops[3 + var_type.array.size()]);
		interface_index =
		    get_extended_member_decoration(var->self, c.scalar(), SPIRVCrossDecorationInterfaceMemberIndex);
	}
	// Accumulate indices. We'll have to skip over the one for the struct, if present, because we already accounted
	// for that getting the base index.
	for (uint32_t i = 3; i < length; ++i)
	{
		if (is_vector(*type) && is_scalar(result_type))
		{
			// We don't want to combine the next index. Actually, we need to save it
			// so we know to apply a swizzle to the result of the interpolation.
			set_extended_decoration(ops[1], SPIRVCrossDecorationInterpolantComponentExpr, ops[i]);
			break;
		}

		auto *c = maybe_get<SPIRConstant>(ops[i]);
		if (!c || c->specialization)
			SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable using pull-model "
			                  "interpolation. This is currently unsupported.");

		// Descend one level in the type hierarchy to mirror the consumed index.
		if (type->parent_type)
			type = &get<SPIRType>(type->parent_type);
		else if (type->basetype == SPIRType::Struct)
			type = &get<SPIRType>(type->member_types[c->scalar()]);

		// Skip the struct-member index itself; it was already folded into the base
		// index in the else-branch above.
		if (!has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex) &&
		    i - 3 == var_type.array.size())
			continue;

		interface_index += c->scalar();
	}
	// Save this to the access chain itself so we can recover it later when calling an interpolation function.
	set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, interface_index);
}
6965 
6966 // Override for MSL-specific syntax instructions
emit_instruction(const Instruction & instruction)6967 void CompilerMSL::emit_instruction(const Instruction &instruction)
6968 {
6969 #define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
6970 #define MSL_BOP_CAST(op, type) \
6971 	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
6972 #define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
6973 #define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
6974 #define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
6975 #define MSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
6976 #define MSL_BFOP_CAST(op, type) \
6977 	emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
6978 #define MSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
6979 #define MSL_UNORD_BOP(op) emit_binary_unord_op(ops[0], ops[1], ops[2], ops[3], #op)
6980 
6981 	auto ops = stream(instruction);
6982 	auto opcode = static_cast<Op>(instruction.op);
6983 
6984 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
6985 	uint32_t integer_width = get_integer_width_for_instruction(instruction);
6986 	auto int_type = to_signed_basetype(integer_width);
6987 	auto uint_type = to_unsigned_basetype(integer_width);
6988 
6989 	switch (opcode)
6990 	{
6991 	case OpLoad:
6992 	{
6993 		uint32_t id = ops[1];
6994 		uint32_t ptr = ops[2];
6995 		if (is_tessellation_shader())
6996 		{
6997 			if (!emit_tessellation_io_load(ops[0], id, ptr))
6998 				CompilerGLSL::emit_instruction(instruction);
6999 		}
7000 		else
7001 		{
7002 			// Sample mask input for Metal is not an array
7003 			if (BuiltIn(get_decoration(ptr, DecorationBuiltIn)) == BuiltInSampleMask)
7004 				set_decoration(id, DecorationBuiltIn, BuiltInSampleMask);
7005 			CompilerGLSL::emit_instruction(instruction);
7006 		}
7007 		break;
7008 	}
7009 
7010 	// Comparisons
7011 	case OpIEqual:
7012 		MSL_BOP_CAST(==, int_type);
7013 		break;
7014 
7015 	case OpLogicalEqual:
7016 	case OpFOrdEqual:
7017 		MSL_BOP(==);
7018 		break;
7019 
7020 	case OpINotEqual:
7021 		MSL_BOP_CAST(!=, int_type);
7022 		break;
7023 
7024 	case OpLogicalNotEqual:
7025 	case OpFOrdNotEqual:
7026 		MSL_BOP(!=);
7027 		break;
7028 
7029 	case OpUGreaterThan:
7030 		MSL_BOP_CAST(>, uint_type);
7031 		break;
7032 
7033 	case OpSGreaterThan:
7034 		MSL_BOP_CAST(>, int_type);
7035 		break;
7036 
7037 	case OpFOrdGreaterThan:
7038 		MSL_BOP(>);
7039 		break;
7040 
7041 	case OpUGreaterThanEqual:
7042 		MSL_BOP_CAST(>=, uint_type);
7043 		break;
7044 
7045 	case OpSGreaterThanEqual:
7046 		MSL_BOP_CAST(>=, int_type);
7047 		break;
7048 
7049 	case OpFOrdGreaterThanEqual:
7050 		MSL_BOP(>=);
7051 		break;
7052 
7053 	case OpULessThan:
7054 		MSL_BOP_CAST(<, uint_type);
7055 		break;
7056 
7057 	case OpSLessThan:
7058 		MSL_BOP_CAST(<, int_type);
7059 		break;
7060 
7061 	case OpFOrdLessThan:
7062 		MSL_BOP(<);
7063 		break;
7064 
7065 	case OpULessThanEqual:
7066 		MSL_BOP_CAST(<=, uint_type);
7067 		break;
7068 
7069 	case OpSLessThanEqual:
7070 		MSL_BOP_CAST(<=, int_type);
7071 		break;
7072 
7073 	case OpFOrdLessThanEqual:
7074 		MSL_BOP(<=);
7075 		break;
7076 
7077 	case OpFUnordEqual:
7078 		MSL_UNORD_BOP(==);
7079 		break;
7080 
7081 	case OpFUnordNotEqual:
7082 		MSL_UNORD_BOP(!=);
7083 		break;
7084 
7085 	case OpFUnordGreaterThan:
7086 		MSL_UNORD_BOP(>);
7087 		break;
7088 
7089 	case OpFUnordGreaterThanEqual:
7090 		MSL_UNORD_BOP(>=);
7091 		break;
7092 
7093 	case OpFUnordLessThan:
7094 		MSL_UNORD_BOP(<);
7095 		break;
7096 
7097 	case OpFUnordLessThanEqual:
7098 		MSL_UNORD_BOP(<=);
7099 		break;
7100 
7101 	// Derivatives
7102 	case OpDPdx:
7103 	case OpDPdxFine:
7104 	case OpDPdxCoarse:
7105 		MSL_UFOP(dfdx);
7106 		register_control_dependent_expression(ops[1]);
7107 		break;
7108 
7109 	case OpDPdy:
7110 	case OpDPdyFine:
7111 	case OpDPdyCoarse:
7112 		MSL_UFOP(dfdy);
7113 		register_control_dependent_expression(ops[1]);
7114 		break;
7115 
7116 	case OpFwidth:
7117 	case OpFwidthCoarse:
7118 	case OpFwidthFine:
7119 		MSL_UFOP(fwidth);
7120 		register_control_dependent_expression(ops[1]);
7121 		break;
7122 
7123 	// Bitfield
7124 	case OpBitFieldInsert:
7125 	{
7126 		emit_bitfield_insert_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], "insert_bits", SPIRType::UInt);
7127 		break;
7128 	}
7129 
7130 	case OpBitFieldSExtract:
7131 	{
7132 		emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", int_type, int_type,
7133 		                                SPIRType::UInt, SPIRType::UInt);
7134 		break;
7135 	}
7136 
7137 	case OpBitFieldUExtract:
7138 	{
7139 		emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", uint_type, uint_type,
7140 		                                SPIRType::UInt, SPIRType::UInt);
7141 		break;
7142 	}
7143 
7144 	case OpBitReverse:
7145 		// BitReverse does not have issues with sign since result type must match input type.
7146 		MSL_UFOP(reverse_bits);
7147 		break;
7148 
7149 	case OpBitCount:
7150 	{
7151 		auto basetype = expression_type(ops[2]).basetype;
7152 		emit_unary_func_op_cast(ops[0], ops[1], ops[2], "popcount", basetype, basetype);
7153 		break;
7154 	}
7155 
7156 	case OpFRem:
7157 		MSL_BFOP(fmod);
7158 		break;
7159 
7160 	case OpFMul:
7161 		if (msl_options.invariant_float_math)
7162 			MSL_BFOP(spvFMul);
7163 		else
7164 			MSL_BOP(*);
7165 		break;
7166 
7167 	case OpFAdd:
7168 		if (msl_options.invariant_float_math)
7169 			MSL_BFOP(spvFAdd);
7170 		else
7171 			MSL_BOP(+);
7172 		break;
7173 
7174 	// Atomics
7175 	case OpAtomicExchange:
7176 	{
7177 		uint32_t result_type = ops[0];
7178 		uint32_t id = ops[1];
7179 		uint32_t ptr = ops[2];
7180 		uint32_t mem_sem = ops[4];
7181 		uint32_t val = ops[5];
7182 		emit_atomic_func_op(result_type, id, "atomic_exchange_explicit", mem_sem, mem_sem, false, ptr, val);
7183 		break;
7184 	}
7185 
7186 	case OpAtomicCompareExchange:
7187 	{
7188 		uint32_t result_type = ops[0];
7189 		uint32_t id = ops[1];
7190 		uint32_t ptr = ops[2];
7191 		uint32_t mem_sem_pass = ops[4];
7192 		uint32_t mem_sem_fail = ops[5];
7193 		uint32_t val = ops[6];
7194 		uint32_t comp = ops[7];
7195 		emit_atomic_func_op(result_type, id, "atomic_compare_exchange_weak_explicit", mem_sem_pass, mem_sem_fail, true,
7196 		                    ptr, comp, true, false, val);
7197 		break;
7198 	}
7199 
7200 	case OpAtomicCompareExchangeWeak:
7201 		SPIRV_CROSS_THROW("OpAtomicCompareExchangeWeak is only supported in kernel profile.");
7202 
7203 	case OpAtomicLoad:
7204 	{
7205 		uint32_t result_type = ops[0];
7206 		uint32_t id = ops[1];
7207 		uint32_t ptr = ops[2];
7208 		uint32_t mem_sem = ops[4];
7209 		emit_atomic_func_op(result_type, id, "atomic_load_explicit", mem_sem, mem_sem, false, ptr, 0);
7210 		break;
7211 	}
7212 
7213 	case OpAtomicStore:
7214 	{
7215 		uint32_t result_type = expression_type(ops[0]).self;
7216 		uint32_t id = ops[0];
7217 		uint32_t ptr = ops[0];
7218 		uint32_t mem_sem = ops[2];
7219 		uint32_t val = ops[3];
7220 		emit_atomic_func_op(result_type, id, "atomic_store_explicit", mem_sem, mem_sem, false, ptr, val);
7221 		break;
7222 	}
7223 
7224 #define MSL_AFMO_IMPL(op, valsrc, valconst)                                                                      \
7225 	do                                                                                                           \
7226 	{                                                                                                            \
7227 		uint32_t result_type = ops[0];                                                                           \
7228 		uint32_t id = ops[1];                                                                                    \
7229 		uint32_t ptr = ops[2];                                                                                   \
7230 		uint32_t mem_sem = ops[4];                                                                               \
7231 		uint32_t val = valsrc;                                                                                   \
7232 		emit_atomic_func_op(result_type, id, "atomic_fetch_" #op "_explicit", mem_sem, mem_sem, false, ptr, val, \
7233 		                    false, valconst);                                                                    \
7234 	} while (false)
7235 
7236 #define MSL_AFMO(op) MSL_AFMO_IMPL(op, ops[5], false)
7237 #define MSL_AFMIO(op) MSL_AFMO_IMPL(op, 1, true)
7238 
7239 	case OpAtomicIIncrement:
7240 		MSL_AFMIO(add);
7241 		break;
7242 
7243 	case OpAtomicIDecrement:
7244 		MSL_AFMIO(sub);
7245 		break;
7246 
7247 	case OpAtomicIAdd:
7248 		MSL_AFMO(add);
7249 		break;
7250 
7251 	case OpAtomicISub:
7252 		MSL_AFMO(sub);
7253 		break;
7254 
7255 	case OpAtomicSMin:
7256 	case OpAtomicUMin:
7257 		MSL_AFMO(min);
7258 		break;
7259 
7260 	case OpAtomicSMax:
7261 	case OpAtomicUMax:
7262 		MSL_AFMO(max);
7263 		break;
7264 
7265 	case OpAtomicAnd:
7266 		MSL_AFMO(and);
7267 		break;
7268 
7269 	case OpAtomicOr:
7270 		MSL_AFMO(or);
7271 		break;
7272 
7273 	case OpAtomicXor:
7274 		MSL_AFMO(xor);
7275 		break;
7276 
7277 	// Images
7278 
7279 	// Reads == Fetches in Metal
7280 	case OpImageRead:
7281 	{
7282 		// Mark that this shader reads from this image
7283 		uint32_t img_id = ops[2];
7284 		auto &type = expression_type(img_id);
7285 		if (type.image.dim != DimSubpassData)
7286 		{
7287 			auto *p_var = maybe_get_backing_variable(img_id);
7288 			if (p_var && has_decoration(p_var->self, DecorationNonReadable))
7289 			{
7290 				unset_decoration(p_var->self, DecorationNonReadable);
7291 				force_recompile();
7292 			}
7293 		}
7294 
7295 		emit_texture_op(instruction, false);
7296 		break;
7297 	}
7298 
7299 	// Emulate texture2D atomic operations
7300 	case OpImageTexelPointer:
7301 	{
7302 		// When using the pointer, we need to know which variable it is actually loaded from.
7303 		auto *var = maybe_get_backing_variable(ops[2]);
7304 		if (var && atomic_image_vars.count(var->self))
7305 		{
7306 			uint32_t result_type = ops[0];
7307 			uint32_t id = ops[1];
7308 
7309 			std::string coord = to_expression(ops[3]);
7310 			auto &type = expression_type(ops[2]);
7311 			if (type.image.dim == Dim2D)
7312 			{
7313 				coord = join("spvImage2DAtomicCoord(", coord, ", ", to_expression(ops[2]), ")");
7314 			}
7315 
7316 			auto &e = set<SPIRExpression>(id, join(to_expression(ops[2]), "_atomic[", coord, "]"), result_type, true);
7317 			e.loaded_from = var ? var->self : ID(0);
7318 			inherit_expression_dependencies(id, ops[3]);
7319 		}
7320 		else
7321 		{
7322 			uint32_t result_type = ops[0];
7323 			uint32_t id = ops[1];
7324 			auto &e =
7325 			    set<SPIRExpression>(id, join(to_expression(ops[2]), ", ", to_expression(ops[3])), result_type, true);
7326 
7327 			// When using the pointer, we need to know which variable it is actually loaded from.
7328 			e.loaded_from = var ? var->self : ID(0);
7329 			inherit_expression_dependencies(id, ops[3]);
7330 		}
7331 		break;
7332 	}
7333 
7334 	case OpImageWrite:
7335 	{
7336 		uint32_t img_id = ops[0];
7337 		uint32_t coord_id = ops[1];
7338 		uint32_t texel_id = ops[2];
7339 		const uint32_t *opt = &ops[3];
7340 		uint32_t length = instruction.length - 3;
7341 
7342 		// Bypass pointers because we need the real image struct
7343 		auto &type = expression_type(img_id);
7344 		auto &img_type = get<SPIRType>(type.self);
7345 
7346 		// Ensure this image has been marked as being written to and force a
7347 		// recommpile so that the image type output will include write access
7348 		auto *p_var = maybe_get_backing_variable(img_id);
7349 		if (p_var && has_decoration(p_var->self, DecorationNonWritable))
7350 		{
7351 			unset_decoration(p_var->self, DecorationNonWritable);
7352 			force_recompile();
7353 		}
7354 
7355 		bool forward = false;
7356 		uint32_t bias = 0;
7357 		uint32_t lod = 0;
7358 		uint32_t flags = 0;
7359 
7360 		if (length)
7361 		{
7362 			flags = *opt++;
7363 			length--;
7364 		}
7365 
7366 		auto test = [&](uint32_t &v, uint32_t flag) {
7367 			if (length && (flags & flag))
7368 			{
7369 				v = *opt++;
7370 				length--;
7371 			}
7372 		};
7373 
7374 		test(bias, ImageOperandsBiasMask);
7375 		test(lod, ImageOperandsLodMask);
7376 
7377 		auto &texel_type = expression_type(texel_id);
7378 		auto store_type = texel_type;
7379 		store_type.vecsize = 4;
7380 
7381 		TextureFunctionArguments args = {};
7382 		args.base.img = img_id;
7383 		args.base.imgtype = &img_type;
7384 		args.base.is_fetch = true;
7385 		args.coord = coord_id;
7386 		args.lod = lod;
7387 		statement(join(to_expression(img_id), ".write(",
7388 		               remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)), ", ",
7389 		               CompilerMSL::to_function_args(args, &forward), ");"));
7390 
7391 		if (p_var && variable_storage_is_aliased(*p_var))
7392 			flush_all_aliased_variables();
7393 
7394 		break;
7395 	}
7396 
7397 	case OpImageQuerySize:
7398 	case OpImageQuerySizeLod:
7399 	{
7400 		uint32_t rslt_type_id = ops[0];
7401 		auto &rslt_type = get<SPIRType>(rslt_type_id);
7402 
7403 		uint32_t id = ops[1];
7404 
7405 		uint32_t img_id = ops[2];
7406 		string img_exp = to_expression(img_id);
7407 		auto &img_type = expression_type(img_id);
7408 		Dim img_dim = img_type.image.dim;
7409 		bool img_is_array = img_type.image.arrayed;
7410 
7411 		if (img_type.basetype != SPIRType::Image)
7412 			SPIRV_CROSS_THROW("Invalid type for OpImageQuerySize.");
7413 
7414 		string lod;
7415 		if (opcode == OpImageQuerySizeLod)
7416 		{
7417 			// LOD index defaults to zero, so don't bother outputing level zero index
7418 			string decl_lod = to_expression(ops[3]);
7419 			if (decl_lod != "0")
7420 				lod = decl_lod;
7421 		}
7422 
7423 		string expr = type_to_glsl(rslt_type) + "(";
7424 		expr += img_exp + ".get_width(" + lod + ")";
7425 
7426 		if (img_dim == Dim2D || img_dim == DimCube || img_dim == Dim3D)
7427 			expr += ", " + img_exp + ".get_height(" + lod + ")";
7428 
7429 		if (img_dim == Dim3D)
7430 			expr += ", " + img_exp + ".get_depth(" + lod + ")";
7431 
7432 		if (img_is_array)
7433 		{
7434 			expr += ", " + img_exp + ".get_array_size()";
7435 			if (img_dim == DimCube && msl_options.emulate_cube_array)
7436 				expr += " / 6";
7437 		}
7438 
7439 		expr += ")";
7440 
7441 		emit_op(rslt_type_id, id, expr, should_forward(img_id));
7442 
7443 		break;
7444 	}
7445 
7446 	case OpImageQueryLod:
7447 	{
7448 		if (!msl_options.supports_msl_version(2, 2))
7449 			SPIRV_CROSS_THROW("ImageQueryLod is only supported on MSL 2.2 and up.");
7450 		uint32_t result_type = ops[0];
7451 		uint32_t id = ops[1];
7452 		uint32_t image_id = ops[2];
7453 		uint32_t coord_id = ops[3];
7454 		emit_uninitialized_temporary_expression(result_type, id);
7455 
7456 		auto sampler_expr = to_sampler_expression(image_id);
7457 		auto *combined = maybe_get<SPIRCombinedImageSampler>(image_id);
7458 		auto image_expr = combined ? to_expression(combined->image) : to_expression(image_id);
7459 
7460 		// TODO: It is unclear if calculcate_clamped_lod also conditionally rounds
7461 		// the reported LOD based on the sampler. NEAREST miplevel should
7462 		// round the LOD, but LINEAR miplevel should not round.
7463 		// Let's hope this does not become an issue ...
7464 		statement(to_expression(id), ".x = ", image_expr, ".calculate_clamped_lod(", sampler_expr, ", ",
7465 		          to_expression(coord_id), ");");
7466 		statement(to_expression(id), ".y = ", image_expr, ".calculate_unclamped_lod(", sampler_expr, ", ",
7467 		          to_expression(coord_id), ");");
7468 		register_control_dependent_expression(id);
7469 		break;
7470 	}
7471 
7472 #define MSL_ImgQry(qrytype)                                                                 \
7473 	do                                                                                      \
7474 	{                                                                                       \
7475 		uint32_t rslt_type_id = ops[0];                                                     \
7476 		auto &rslt_type = get<SPIRType>(rslt_type_id);                                      \
7477 		uint32_t id = ops[1];                                                               \
7478 		uint32_t img_id = ops[2];                                                           \
7479 		string img_exp = to_expression(img_id);                                             \
7480 		string expr = type_to_glsl(rslt_type) + "(" + img_exp + ".get_num_" #qrytype "())"; \
7481 		emit_op(rslt_type_id, id, expr, should_forward(img_id));                            \
7482 	} while (false)
7483 
7484 	case OpImageQueryLevels:
7485 		MSL_ImgQry(mip_levels);
7486 		break;
7487 
7488 	case OpImageQuerySamples:
7489 		MSL_ImgQry(samples);
7490 		break;
7491 
7492 	case OpImage:
7493 	{
7494 		uint32_t result_type = ops[0];
7495 		uint32_t id = ops[1];
7496 		auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);
7497 
7498 		if (combined)
7499 		{
7500 			auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
7501 			auto *var = maybe_get_backing_variable(combined->image);
7502 			if (var)
7503 				e.loaded_from = var->self;
7504 		}
7505 		else
7506 		{
7507 			auto *var = maybe_get_backing_variable(ops[2]);
7508 			SPIRExpression *e;
7509 			if (var && has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler))
7510 				e = &emit_op(result_type, id, join(to_expression(ops[2]), ".plane0"), true, true);
7511 			else
7512 				e = &emit_op(result_type, id, to_expression(ops[2]), true, true);
7513 			if (var)
7514 				e->loaded_from = var->self;
7515 		}
7516 		break;
7517 	}
7518 
7519 	// Casting
7520 	case OpQuantizeToF16:
7521 	{
7522 		uint32_t result_type = ops[0];
7523 		uint32_t id = ops[1];
7524 		uint32_t arg = ops[2];
7525 
7526 		string exp;
7527 		auto &type = get<SPIRType>(result_type);
7528 
7529 		switch (type.vecsize)
7530 		{
7531 		case 1:
7532 			exp = join("float(half(", to_expression(arg), "))");
7533 			break;
7534 		case 2:
7535 			exp = join("float2(half2(", to_expression(arg), "))");
7536 			break;
7537 		case 3:
7538 			exp = join("float3(half3(", to_expression(arg), "))");
7539 			break;
7540 		case 4:
7541 			exp = join("float4(half4(", to_expression(arg), "))");
7542 			break;
7543 		default:
7544 			SPIRV_CROSS_THROW("Illegal argument to OpQuantizeToF16.");
7545 		}
7546 
7547 		emit_op(result_type, id, exp, should_forward(arg));
7548 		break;
7549 	}
7550 
7551 	case OpInBoundsAccessChain:
7552 	case OpAccessChain:
7553 	case OpPtrAccessChain:
7554 		if (is_tessellation_shader())
7555 		{
7556 			if (!emit_tessellation_access_chain(ops, instruction.length))
7557 				CompilerGLSL::emit_instruction(instruction);
7558 		}
7559 		else
7560 			CompilerGLSL::emit_instruction(instruction);
7561 		fix_up_interpolant_access_chain(ops, instruction.length);
7562 		break;
7563 
7564 	case OpStore:
7565 		if (is_out_of_bounds_tessellation_level(ops[0]))
7566 			break;
7567 
7568 		if (maybe_emit_array_assignment(ops[0], ops[1]))
7569 			break;
7570 
7571 		CompilerGLSL::emit_instruction(instruction);
7572 		break;
7573 
7574 	// Compute barriers
7575 	case OpMemoryBarrier:
7576 		emit_barrier(0, ops[0], ops[1]);
7577 		break;
7578 
7579 	case OpControlBarrier:
7580 		// In GLSL a memory barrier is often followed by a control barrier.
7581 		// But in MSL, memory barriers are also control barriers, so don't
7582 		// emit a simple control barrier if a memory barrier has just been emitted.
7583 		if (previous_instruction_opcode != OpMemoryBarrier)
7584 			emit_barrier(ops[0], ops[1], ops[2]);
7585 		break;
7586 
7587 	case OpOuterProduct:
7588 	{
7589 		uint32_t result_type = ops[0];
7590 		uint32_t id = ops[1];
7591 		uint32_t a = ops[2];
7592 		uint32_t b = ops[3];
7593 
7594 		auto &type = get<SPIRType>(result_type);
7595 		string expr = type_to_glsl_constructor(type);
7596 		expr += "(";
7597 		for (uint32_t col = 0; col < type.columns; col++)
7598 		{
7599 			expr += to_enclosed_expression(a);
7600 			expr += " * ";
7601 			expr += to_extract_component_expression(b, col);
7602 			if (col + 1 < type.columns)
7603 				expr += ", ";
7604 		}
7605 		expr += ")";
7606 		emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
7607 		inherit_expression_dependencies(id, a);
7608 		inherit_expression_dependencies(id, b);
7609 		break;
7610 	}
7611 
7612 	case OpVectorTimesMatrix:
7613 	case OpMatrixTimesVector:
7614 	{
7615 		if (!msl_options.invariant_float_math)
7616 		{
7617 			CompilerGLSL::emit_instruction(instruction);
7618 			break;
7619 		}
7620 
7621 		// If the matrix needs transpose, just flip the multiply order.
7622 		auto *e = maybe_get<SPIRExpression>(ops[opcode == OpMatrixTimesVector ? 2 : 3]);
7623 		if (e && e->need_transpose)
7624 		{
7625 			e->need_transpose = false;
7626 			string expr;
7627 
7628 			if (opcode == OpMatrixTimesVector)
7629 			{
7630 				expr = join("spvFMulVectorMatrix(", to_enclosed_unpacked_expression(ops[3]), ", ",
7631 				            to_unpacked_row_major_matrix_expression(ops[2]), ")");
7632 			}
7633 			else
7634 			{
7635 				expr = join("spvFMulMatrixVector(", to_unpacked_row_major_matrix_expression(ops[3]), ", ",
7636 				            to_enclosed_unpacked_expression(ops[2]), ")");
7637 			}
7638 
7639 			bool forward = should_forward(ops[2]) && should_forward(ops[3]);
7640 			emit_op(ops[0], ops[1], expr, forward);
7641 			e->need_transpose = true;
7642 			inherit_expression_dependencies(ops[1], ops[2]);
7643 			inherit_expression_dependencies(ops[1], ops[3]);
7644 		}
7645 		else
7646 		{
7647 			if (opcode == OpMatrixTimesVector)
7648 				MSL_BFOP(spvFMulMatrixVector);
7649 			else
7650 				MSL_BFOP(spvFMulVectorMatrix);
7651 		}
7652 		break;
7653 	}
7654 
7655 	case OpMatrixTimesMatrix:
7656 	{
7657 		if (!msl_options.invariant_float_math)
7658 		{
7659 			CompilerGLSL::emit_instruction(instruction);
7660 			break;
7661 		}
7662 
7663 		auto *a = maybe_get<SPIRExpression>(ops[2]);
7664 		auto *b = maybe_get<SPIRExpression>(ops[3]);
7665 
7666 		// If both matrices need transpose, we can multiply in flipped order and tag the expression as transposed.
7667 		// a^T * b^T = (b * a)^T.
7668 		if (a && b && a->need_transpose && b->need_transpose)
7669 		{
7670 			a->need_transpose = false;
7671 			b->need_transpose = false;
7672 
7673 			auto expr =
7674 			    join("spvFMulMatrixMatrix(", enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), ", ",
7675 			         enclose_expression(to_unpacked_row_major_matrix_expression(ops[2])), ")");
7676 
7677 			bool forward = should_forward(ops[2]) && should_forward(ops[3]);
7678 			auto &e = emit_op(ops[0], ops[1], expr, forward);
7679 			e.need_transpose = true;
7680 			a->need_transpose = true;
7681 			b->need_transpose = true;
7682 			inherit_expression_dependencies(ops[1], ops[2]);
7683 			inherit_expression_dependencies(ops[1], ops[3]);
7684 		}
7685 		else
7686 			MSL_BFOP(spvFMulMatrixMatrix);
7687 
7688 		break;
7689 	}
7690 
7691 	case OpIAddCarry:
7692 	case OpISubBorrow:
7693 	{
7694 		uint32_t result_type = ops[0];
7695 		uint32_t result_id = ops[1];
7696 		uint32_t op0 = ops[2];
7697 		uint32_t op1 = ops[3];
7698 		auto &type = get<SPIRType>(result_type);
7699 		emit_uninitialized_temporary_expression(result_type, result_id);
7700 
7701 		auto &res_type = get<SPIRType>(type.member_types[1]);
7702 		if (opcode == OpIAddCarry)
7703 		{
7704 			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " + ",
7705 			          to_enclosed_expression(op1), ";");
7706 			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
7707 			          "(1), ", type_to_glsl(res_type), "(0), ", to_expression(result_id), ".", to_member_name(type, 0),
7708 			          " >= max(", to_expression(op0), ", ", to_expression(op1), "));");
7709 		}
7710 		else
7711 		{
7712 			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " - ",
7713 			          to_enclosed_expression(op1), ";");
7714 			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
7715 			          "(1), ", type_to_glsl(res_type), "(0), ", to_enclosed_expression(op0),
7716 			          " >= ", to_enclosed_expression(op1), ");");
7717 		}
7718 		break;
7719 	}
7720 
7721 	case OpUMulExtended:
7722 	case OpSMulExtended:
7723 	{
7724 		uint32_t result_type = ops[0];
7725 		uint32_t result_id = ops[1];
7726 		uint32_t op0 = ops[2];
7727 		uint32_t op1 = ops[3];
7728 		auto &type = get<SPIRType>(result_type);
7729 		emit_uninitialized_temporary_expression(result_type, result_id);
7730 
7731 		statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " * ",
7732 		          to_enclosed_expression(op1), ";");
7733 		statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(", to_expression(op0), ", ",
7734 		          to_expression(op1), ");");
7735 		break;
7736 	}
7737 
7738 	case OpArrayLength:
7739 	{
7740 		auto &type = expression_type(ops[2]);
7741 		uint32_t offset = type_struct_member_offset(type, ops[3]);
7742 		uint32_t stride = type_struct_member_array_stride(type, ops[3]);
7743 
7744 		auto expr = join("(", to_buffer_size_expression(ops[2]), " - ", offset, ") / ", stride);
7745 		emit_op(ops[0], ops[1], expr, true);
7746 		break;
7747 	}
7748 
7749 	// SPV_INTEL_shader_integer_functions2
7750 	case OpUCountLeadingZerosINTEL:
7751 		MSL_UFOP(clz);
7752 		break;
7753 
7754 	case OpUCountTrailingZerosINTEL:
7755 		MSL_UFOP(ctz);
7756 		break;
7757 
7758 	case OpAbsISubINTEL:
7759 	case OpAbsUSubINTEL:
7760 		MSL_BFOP(absdiff);
7761 		break;
7762 
7763 	case OpIAddSatINTEL:
7764 	case OpUAddSatINTEL:
7765 		MSL_BFOP(addsat);
7766 		break;
7767 
7768 	case OpIAverageINTEL:
7769 	case OpUAverageINTEL:
7770 		MSL_BFOP(hadd);
7771 		break;
7772 
7773 	case OpIAverageRoundedINTEL:
7774 	case OpUAverageRoundedINTEL:
7775 		MSL_BFOP(rhadd);
7776 		break;
7777 
7778 	case OpISubSatINTEL:
7779 	case OpUSubSatINTEL:
7780 		MSL_BFOP(subsat);
7781 		break;
7782 
7783 	case OpIMul32x16INTEL:
7784 	{
7785 		uint32_t result_type = ops[0];
7786 		uint32_t id = ops[1];
7787 		uint32_t a = ops[2], b = ops[3];
7788 		bool forward = should_forward(a) && should_forward(b);
7789 		emit_op(result_type, id, join("int(short(", to_expression(a), ")) * int(short(", to_expression(b), "))"),
7790 		        forward);
7791 		inherit_expression_dependencies(id, a);
7792 		inherit_expression_dependencies(id, b);
7793 		break;
7794 	}
7795 
7796 	case OpUMul32x16INTEL:
7797 	{
7798 		uint32_t result_type = ops[0];
7799 		uint32_t id = ops[1];
7800 		uint32_t a = ops[2], b = ops[3];
7801 		bool forward = should_forward(a) && should_forward(b);
7802 		emit_op(result_type, id, join("uint(ushort(", to_expression(a), ")) * uint(ushort(", to_expression(b), "))"),
7803 		        forward);
7804 		inherit_expression_dependencies(id, a);
7805 		inherit_expression_dependencies(id, b);
7806 		break;
7807 	}
7808 
7809 	// SPV_EXT_demote_to_helper_invocation
7810 	case OpDemoteToHelperInvocationEXT:
7811 		if (!msl_options.supports_msl_version(2, 3))
7812 			SPIRV_CROSS_THROW("discard_fragment() does not formally have demote semantics until MSL 2.3.");
7813 		CompilerGLSL::emit_instruction(instruction);
7814 		break;
7815 
7816 	case OpIsHelperInvocationEXT:
7817 		if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
7818 			SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.3 on iOS.");
7819 		else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
7820 			SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.1 on macOS.");
7821 		emit_op(ops[0], ops[1], "simd_is_helper_thread()", false);
7822 		break;
7823 
7824 	case OpBeginInvocationInterlockEXT:
7825 	case OpEndInvocationInterlockEXT:
7826 		if (!msl_options.supports_msl_version(2, 0))
7827 			SPIRV_CROSS_THROW("Raster order groups require MSL 2.0.");
7828 		break; // Nothing to do in the body
7829 
7830 	default:
7831 		CompilerGLSL::emit_instruction(instruction);
7832 		break;
7833 	}
7834 
7835 	previous_instruction_opcode = opcode;
7836 }
7837 
emit_texture_op(const Instruction & i,bool sparse)7838 void CompilerMSL::emit_texture_op(const Instruction &i, bool sparse)
7839 {
7840 	if (sparse)
7841 		SPIRV_CROSS_THROW("Sparse feedback not yet supported in MSL.");
7842 
7843 	if (msl_options.use_framebuffer_fetch_subpasses)
7844 	{
7845 		auto *ops = stream(i);
7846 
7847 		uint32_t result_type_id = ops[0];
7848 		uint32_t id = ops[1];
7849 		uint32_t img = ops[2];
7850 
7851 		auto &type = expression_type(img);
7852 		auto &imgtype = get<SPIRType>(type.self);
7853 
7854 		// Use Metal's native frame-buffer fetch API for subpass inputs.
7855 		if (imgtype.image.dim == DimSubpassData)
7856 		{
7857 			// Subpass inputs cannot be invalidated,
7858 			// so just forward the expression directly.
7859 			string expr = to_expression(img);
7860 			emit_op(result_type_id, id, expr, true);
7861 			return;
7862 		}
7863 	}
7864 
7865 	// Fallback to default implementation
7866 	CompilerGLSL::emit_texture_op(i, sparse);
7867 }
7868 
emit_barrier(uint32_t id_exe_scope,uint32_t id_mem_scope,uint32_t id_mem_sem)7869 void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem)
7870 {
7871 	if (get_execution_model() != ExecutionModelGLCompute && get_execution_model() != ExecutionModelTessellationControl)
7872 		return;
7873 
7874 	uint32_t exe_scope = id_exe_scope ? evaluate_constant_u32(id_exe_scope) : uint32_t(ScopeInvocation);
7875 	uint32_t mem_scope = id_mem_scope ? evaluate_constant_u32(id_mem_scope) : uint32_t(ScopeInvocation);
7876 	// Use the wider of the two scopes (smaller value)
7877 	exe_scope = min(exe_scope, mem_scope);
7878 
7879 	if (msl_options.emulate_subgroups && exe_scope >= ScopeSubgroup && !id_mem_sem)
7880 		// In this case, we assume a "subgroup" size of 1. The barrier, then, is a noop.
7881 		return;
7882 
7883 	string bar_stmt;
7884 	if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
7885 		bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
7886 	else
7887 		bar_stmt = "threadgroup_barrier";
7888 	bar_stmt += "(";
7889 
7890 	uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id_mem_sem) : uint32_t(MemorySemanticsMaskNone);
7891 
7892 	// Use the | operator to combine flags if we can.
7893 	if (msl_options.supports_msl_version(1, 2))
7894 	{
7895 		string mem_flags = "";
7896 		// For tesc shaders, this also affects objects in the Output storage class.
7897 		// Since in Metal, these are placed in a device buffer, we have to sync device memory here.
7898 		if (get_execution_model() == ExecutionModelTessellationControl ||
7899 		    (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)))
7900 			mem_flags += "mem_flags::mem_device";
7901 
7902 		// Fix tessellation patch function processing
7903 		if (get_execution_model() == ExecutionModelTessellationControl ||
7904 		    (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
7905 		{
7906 			if (!mem_flags.empty())
7907 				mem_flags += " | ";
7908 			mem_flags += "mem_flags::mem_threadgroup";
7909 		}
7910 		if (mem_sem & MemorySemanticsImageMemoryMask)
7911 		{
7912 			if (!mem_flags.empty())
7913 				mem_flags += " | ";
7914 			mem_flags += "mem_flags::mem_texture";
7915 		}
7916 
7917 		if (mem_flags.empty())
7918 			mem_flags = "mem_flags::mem_none";
7919 
7920 		bar_stmt += mem_flags;
7921 	}
7922 	else
7923 	{
7924 		if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) &&
7925 		    (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
7926 			bar_stmt += "mem_flags::mem_device_and_threadgroup";
7927 		else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))
7928 			bar_stmt += "mem_flags::mem_device";
7929 		else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask))
7930 			bar_stmt += "mem_flags::mem_threadgroup";
7931 		else if (mem_sem & MemorySemanticsImageMemoryMask)
7932 			bar_stmt += "mem_flags::mem_texture";
7933 		else
7934 			bar_stmt += "mem_flags::mem_none";
7935 	}
7936 
7937 	bar_stmt += ");";
7938 
7939 	statement(bar_stmt);
7940 
7941 	assert(current_emitting_block);
7942 	flush_control_dependent_expressions(current_emitting_block->self);
7943 	flush_all_active_variables();
7944 }
7945 
// Emits an array-to-array copy statement. MSL cannot assign plain C arrays,
// so depending on the address spaces involved this either uses direct
// assignment of the array<T>-style wrapper (spvUnsafeArray) or calls one of
// the generated spvArrayCopy* helper functions, selected by a storage-class tag.
// `lhs` is the already-formatted destination expression; `rhs_id` is the source
// expression ID; the two storage classes describe the effective address spaces.
void CompilerMSL::emit_array_copy(const string &lhs, uint32_t rhs_id, StorageClass lhs_storage,
                                  StorageClass rhs_storage)
{
	// Allow Metal to use the array<T> template to make arrays a value type.
	// This, however, cannot be used for threadgroup address specifiers, so consider the custom array copy as fallback.
	// "thread" here means thread-local (stack) addressable in MSL terms.
	bool lhs_thread = (lhs_storage == StorageClassOutput || lhs_storage == StorageClassFunction ||
	                   lhs_storage == StorageClassGeneric || lhs_storage == StorageClassPrivate);
	bool rhs_thread = (rhs_storage == StorageClassInput || rhs_storage == StorageClassFunction ||
	                   rhs_storage == StorageClassGeneric || rhs_storage == StorageClassPrivate);

	// If threadgroup storage qualifiers are *not* used:
	// Avoid spvCopy* wrapper functions; Otherwise, spvUnsafeArray<> template cannot be used with that storage qualifier.
	if (lhs_thread && rhs_thread && !using_builtin_array())
	{
		statement(lhs, " = ", to_expression(rhs_id), ";");
	}
	else
	{
		// Assignment from an array initializer is fine.
		auto &type = expression_type(rhs_id);
		auto *var = maybe_get_backing_variable(rhs_id);

		// Unfortunately, we cannot template on address space in MSL,
		// so explicit address space redirection it is ...
		// Detect whether the source lives in constant address space:
		// either it is a literal SPIR-V constant, a remapped variable
		// statically assigned from one, or a Uniform (constant buffer) resource.
		bool is_constant = false;
		if (ir.ids[rhs_id].get_type() == TypeConstant)
		{
			is_constant = true;
		}
		else if (var && var->remapped_variable && var->statically_assigned &&
		         ir.ids[var->static_expression].get_type() == TypeConstant)
		{
			is_constant = true;
		}
		else if (rhs_storage == StorageClassUniform)
		{
			is_constant = true;
		}

		// For the case where we have OpLoad triggering an array copy,
		// we cannot easily detect this case ahead of time since it's
		// context dependent. We might have to force a recompile here
		// if this is the only use of array copies in our shader.
		if (type.array.size() > 1)
		{
			if (type.array.size() > kArrayCopyMultidimMax)
				SPIRV_CROSS_THROW("Cannot support this many dimensions for arrays of arrays.");
			auto func = static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + type.array.size());
			add_spv_func_and_recompile(func);
		}
		else
			add_spv_func_and_recompile(SPVFuncImplArrayCopy);

		// Map the (source, destination) address-space pair to the matching
		// spvArrayCopy* helper suffix. NOTE: the order of these checks matters;
		// constant sources must be matched before the generic storage checks.
		const char *tag = nullptr;
		if (lhs_thread && is_constant)
			tag = "FromConstantToStack";
		else if (lhs_storage == StorageClassWorkgroup && is_constant)
			tag = "FromConstantToThreadGroup";
		else if (lhs_thread && rhs_thread)
			tag = "FromStackToStack";
		else if (lhs_storage == StorageClassWorkgroup && rhs_thread)
			tag = "FromStackToThreadGroup";
		else if (lhs_thread && rhs_storage == StorageClassWorkgroup)
			tag = "FromThreadGroupToStack";
		else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassWorkgroup)
			tag = "FromThreadGroupToThreadGroup";
		else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassStorageBuffer)
			tag = "FromDeviceToDevice";
		else if (lhs_storage == StorageClassStorageBuffer && is_constant)
			tag = "FromConstantToDevice";
		else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassWorkgroup)
			tag = "FromThreadGroupToDevice";
		else if (lhs_storage == StorageClassStorageBuffer && rhs_thread)
			tag = "FromStackToDevice";
		else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassStorageBuffer)
			tag = "FromDeviceToThreadGroup";
		else if (lhs_thread && rhs_storage == StorageClassStorageBuffer)
			tag = "FromDeviceToStack";
		else
			SPIRV_CROSS_THROW("Unknown storage class used for copying arrays.");

		// Pass internal array of spvUnsafeArray<> into wrapper functions
		if (lhs_thread && !msl_options.force_native_arrays)
			statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ".elements, ", to_expression(rhs_id), ");");
		else if (rhs_thread && !msl_options.force_native_arrays)
			statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ".elements);");
		else
			statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ");");
	}
}
8036 
8037 // Since MSL does not allow arrays to be copied via simple variable assignment,
8038 // if the LHS and RHS represent an assignment of an entire array, it must be
8039 // implemented by calling an array copy function.
8040 // Returns whether the struct assignment was emitted.
maybe_emit_array_assignment(uint32_t id_lhs,uint32_t id_rhs)8041 bool CompilerMSL::maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs)
8042 {
8043 	// We only care about assignments of an entire array
8044 	auto &type = expression_type(id_rhs);
8045 	if (type.array.size() == 0)
8046 		return false;
8047 
8048 	auto *var = maybe_get<SPIRVariable>(id_lhs);
8049 
8050 	// Is this a remapped, static constant? Don't do anything.
8051 	if (var && var->remapped_variable && var->statically_assigned)
8052 		return true;
8053 
8054 	if (ir.ids[id_rhs].get_type() == TypeConstant && var && var->deferred_declaration)
8055 	{
8056 		// Special case, if we end up declaring a variable when assigning the constant array,
8057 		// we can avoid the copy by directly assigning the constant expression.
8058 		// This is likely necessary to be able to use a variable as a true look-up table, as it is unlikely
8059 		// the compiler will be able to optimize the spvArrayCopy() into a constant LUT.
8060 		// After a variable has been declared, we can no longer assign constant arrays in MSL unfortunately.
8061 		statement(to_expression(id_lhs), " = ", constant_expression(get<SPIRConstant>(id_rhs)), ";");
8062 		return true;
8063 	}
8064 
8065 	// Ensure the LHS variable has been declared
8066 	auto *p_v_lhs = maybe_get_backing_variable(id_lhs);
8067 	if (p_v_lhs)
8068 		flush_variable_declaration(p_v_lhs->self);
8069 
8070 	emit_array_copy(to_expression(id_lhs), id_rhs, get_expression_effective_storage_class(id_lhs),
8071 	                get_expression_effective_storage_class(id_rhs));
8072 	register_write(id_lhs);
8073 
8074 	return true;
8075 }
8076 
// Emits one of the atomic functions. In MSL, the atomic functions operate on pointers
// to atomic_* types, so the object expression is cast to a pointer in the appropriate
// address space before the call.
//
// result_type/result_id: SPIR-V type and id of the instruction's result.
// op:                    name of the MSL atomic function to emit.
// mem_order_1/2:         SPIR-V memory-semantics operands (both map to relaxed in MSL).
// has_mem_order_2:       whether the MSL function takes a second memory order (CAS only).
// obj:                   id of the atomic object being operated on.
// op1/op2:               optional operands; op1 is the CAS comparator when op1_is_pointer
//                        is set, or an inline literal when op1_is_literal is set.
void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, uint32_t mem_order_1,
                                      uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
                                      bool op1_is_pointer, bool op1_is_literal, uint32_t op2)
{
	string exp = string(op) + "(";

	auto &type = get_pointee_type(expression_type(obj));
	exp += "(";
	auto *var = maybe_get_backing_variable(obj);
	if (!var)
		SPIRV_CROSS_THROW("No backing variable for atomic operation.");

	// Emulate texture2D atomic operations
	const auto &res_type = get<SPIRType>(var->basetype);
	if (res_type.storage == StorageClassUniformConstant && res_type.basetype == SPIRType::Image)
	{
		exp += "device";
	}
	else
	{
		exp += get_argument_address_space(*var);
	}

	// Cast the object to a pointer of the matching atomic_* type;
	// MSL atomic functions only accept pointers to atomic types.
	exp += " atomic_";
	exp += type_to_glsl(type);
	exp += "*)";

	exp += "&";
	exp += to_enclosed_expression(obj);

	// A pointer op1 marks the comparator operand of OpAtomicCompareExchange.
	bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;

	if (is_atomic_compare_exchange_strong)
	{
		assert(strcmp(op, "atomic_compare_exchange_weak_explicit") == 0);
		assert(op2);
		assert(has_mem_order_2);
		exp += ", &";
		exp += to_name(result_id);
		exp += ", ";
		exp += to_expression(op2);
		exp += ", ";
		exp += get_memory_order(mem_order_1);
		exp += ", ";
		exp += get_memory_order(mem_order_2);
		exp += ")";

		// MSL only supports the weak atomic compare exchange, so emit a CAS loop here.
		// The MSL function returns false if the atomic write fails OR the comparison test fails,
		// so we must validate that it wasn't the comparison test that failed before continuing
		// the CAS loop, otherwise it will loop infinitely, with the comparison test always failing.
		// The function updates the comparitor value from the memory value, so the additional
		// comparison test evaluates the memory value against the expected value.
		emit_uninitialized_temporary_expression(result_type, result_id);
		statement("do");
		begin_scope();
		statement(to_name(result_id), " = ", to_expression(op1), ";");
		end_scope_decl(join("while (!", exp, " && ", to_name(result_id), " == ", to_enclosed_expression(op1), ")"));
	}
	else
	{
		assert(strcmp(op, "atomic_compare_exchange_weak_explicit") != 0);
		if (op1)
		{
			// Literal operands (e.g. from OpAtomicIIncrement lowering) are emitted inline.
			if (op1_is_literal)
				exp += join(", ", op1);
			else
				exp += ", " + to_expression(op1);
		}
		if (op2)
			exp += ", " + to_expression(op2);

		exp += string(", ") + get_memory_order(mem_order_1);
		if (has_mem_order_2)
			exp += string(", ") + get_memory_order(mem_order_2);

		exp += ")";

		// atomic_store has no result value, so it is emitted as a bare statement.
		if (strcmp(op, "atomic_store_explicit") != 0)
			emit_op(result_type, result_id, exp, false);
		else
			statement(exp, ";");
	}

	// Atomics modify memory, so any forwarded expressions reading these variables are now stale.
	flush_all_atomic_capable_variables();
}
8164 
8165 // Metal only supports relaxed memory order for now
get_memory_order(uint32_t)8166 const char *CompilerMSL::get_memory_order(uint32_t)
8167 {
8168 	return "memory_order_relaxed";
8169 }
8170 
// Override for MSL-specific extension syntax instructions.
// Emits the MSL equivalent of a GLSL.std.450 extended instruction; any opcode
// with no MSL-specific handling falls through to the generic GLSL implementation.
void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
{
	auto op = static_cast<GLSLstd450>(eop);

	// If we need to do implicit bitcasts, make sure we do it with the correct type.
	uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
	auto int_type = to_signed_basetype(integer_width);
	auto uint_type = to_unsigned_basetype(integer_width);

	switch (op)
	{
	// Simple renames: these GLSL functions exist in MSL under different names.
	case GLSLstd450Atan2:
		emit_binary_func_op(result_type, id, args[0], args[1], "atan2");
		break;
	case GLSLstd450InverseSqrt:
		emit_unary_func_op(result_type, id, args[0], "rsqrt");
		break;
	case GLSLstd450RoundEven:
		emit_unary_func_op(result_type, id, args[0], "rint");
		break;

	case GLSLstd450FindILsb:
	{
		// In this template version of findLSB, we return T.
		auto basetype = expression_type(args[0]).basetype;
		emit_unary_func_op_cast(result_type, id, args[0], "spvFindLSB", basetype, basetype);
		break;
	}

	case GLSLstd450FindSMsb:
		emit_unary_func_op_cast(result_type, id, args[0], "spvFindSMSB", int_type, int_type);
		break;

	case GLSLstd450FindUMsb:
		emit_unary_func_op_cast(result_type, id, args[0], "spvFindUMSB", uint_type, uint_type);
		break;

	// Packing/unpacking maps directly onto MSL's pack_*/unpack_* intrinsics.
	case GLSLstd450PackSnorm4x8:
		emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm4x8");
		break;
	case GLSLstd450PackUnorm4x8:
		emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm4x8");
		break;
	case GLSLstd450PackSnorm2x16:
		emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm2x16");
		break;
	case GLSLstd450PackUnorm2x16:
		emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm2x16");
		break;

	case GLSLstd450PackHalf2x16:
	{
		// No direct MSL intrinsic: convert to half2, then bitcast to uint.
		auto expr = join("as_type<uint>(half2(", to_expression(args[0]), "))");
		emit_op(result_type, id, expr, should_forward(args[0]));
		inherit_expression_dependencies(id, args[0]);
		break;
	}

	case GLSLstd450UnpackSnorm4x8:
		emit_unary_func_op(result_type, id, args[0], "unpack_snorm4x8_to_float");
		break;
	case GLSLstd450UnpackUnorm4x8:
		emit_unary_func_op(result_type, id, args[0], "unpack_unorm4x8_to_float");
		break;
	case GLSLstd450UnpackSnorm2x16:
		emit_unary_func_op(result_type, id, args[0], "unpack_snorm2x16_to_float");
		break;
	case GLSLstd450UnpackUnorm2x16:
		emit_unary_func_op(result_type, id, args[0], "unpack_unorm2x16_to_float");
		break;

	case GLSLstd450UnpackHalf2x16:
	{
		// Inverse of PackHalf2x16: bitcast to half2, then widen to float2.
		auto expr = join("float2(as_type<half2>(", to_expression(args[0]), "))");
		emit_op(result_type, id, expr, should_forward(args[0]));
		inherit_expression_dependencies(id, args[0]);
		break;
	}

	case GLSLstd450PackDouble2x32:
		emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450PackDouble2x32"); // Currently unsupported
		break;
	case GLSLstd450UnpackDouble2x32:
		emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450UnpackDouble2x32"); // Currently unsupported
		break;

	case GLSLstd450MatrixInverse:
	{
		// MSL has no inverse(); dispatch to a hand-written spvInverse helper per matrix size.
		auto &mat_type = get<SPIRType>(result_type);
		switch (mat_type.columns)
		{
		case 2:
			emit_unary_func_op(result_type, id, args[0], "spvInverse2x2");
			break;
		case 3:
			emit_unary_func_op(result_type, id, args[0], "spvInverse3x3");
			break;
		case 4:
			emit_unary_func_op(result_type, id, args[0], "spvInverse4x4");
			break;
		default:
			break;
		}
		break;
	}

	case GLSLstd450FMin:
		// If the result type isn't float, don't bother calling the specific
		// precise::/fast:: version. Metal doesn't have those for half and
		// double types.
		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
			emit_binary_func_op(result_type, id, args[0], args[1], "min");
		else
			emit_binary_func_op(result_type, id, args[0], args[1], "fast::min");
		break;

	case GLSLstd450FMax:
		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
			emit_binary_func_op(result_type, id, args[0], args[1], "max");
		else
			emit_binary_func_op(result_type, id, args[0], args[1], "fast::max");
		break;

	case GLSLstd450FClamp:
		// TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
		else
			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "fast::clamp");
		break;

	// The N* variants must honor NaN semantics, hence precise:: rather than fast::.
	case GLSLstd450NMin:
		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
			emit_binary_func_op(result_type, id, args[0], args[1], "min");
		else
			emit_binary_func_op(result_type, id, args[0], args[1], "precise::min");
		break;

	case GLSLstd450NMax:
		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
			emit_binary_func_op(result_type, id, args[0], args[1], "max");
		else
			emit_binary_func_op(result_type, id, args[0], args[1], "precise::max");
		break;

	case GLSLstd450NClamp:
		// TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
		else
			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "precise::clamp");
		break;

	case GLSLstd450InterpolateAtCentroid:
	{
		// We can't just emit the expression normally, because the qualified name contains a call to the default
		// interpolate method, or refers to a local variable. We saved the interface index we need; use it to construct
		// the base for the method call.
		uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
		string component;
		if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
		{
			uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
			auto *c = maybe_get<SPIRConstant>(index_expr);
			// Non-constant (or spec-constant) component indices must be emitted as a
			// dynamic subscript; constant indices become a swizzle.
			if (!c || c->specialization)
				component = join("[", to_expression(index_expr), "]");
			else
				component = join(".", index_to_swizzle(c->scalar()));
		}
		emit_op(result_type, id,
		        join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
		             ".interpolate_at_centroid()", component),
		        should_forward(args[0]));
		break;
	}

	case GLSLstd450InterpolateAtSample:
	{
		uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
		string component;
		if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
		{
			uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
			auto *c = maybe_get<SPIRConstant>(index_expr);
			if (!c || c->specialization)
				component = join("[", to_expression(index_expr), "]");
			else
				component = join(".", index_to_swizzle(c->scalar()));
		}
		emit_op(result_type, id,
		        join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
		             ".interpolate_at_sample(", to_expression(args[1]), ")", component),
		        should_forward(args[0]) && should_forward(args[1]));
		break;
	}

	case GLSLstd450InterpolateAtOffset:
	{
		uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
		string component;
		if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
		{
			uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
			auto *c = maybe_get<SPIRConstant>(index_expr);
			if (!c || c->specialization)
				component = join("[", to_expression(index_expr), "]");
			else
				component = join(".", index_to_swizzle(c->scalar()));
		}
		// Like Direct3D, Metal puts the (0, 0) at the upper-left corner, not the center as SPIR-V and GLSL do.
		// Offset the offset by (1/2 - 1/16), or 0.4375, to compensate for this.
		// It has to be (1/2 - 1/16) and not 1/2, or several CTS tests subtly break on Intel.
		emit_op(result_type, id,
		        join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
		             ".interpolate_at_offset(", to_expression(args[1]), " + 0.4375)", component),
		        should_forward(args[0]) && should_forward(args[1]));
		break;
	}

	case GLSLstd450Distance:
		// MSL does not support scalar versions here.
		if (expression_type(args[0]).vecsize == 1)
		{
			// Equivalent to length(a - b) -> abs(a - b).
			emit_op(result_type, id,
			        join("abs(", to_enclosed_unpacked_expression(args[0]), " - ",
			             to_enclosed_unpacked_expression(args[1]), ")"),
			        should_forward(args[0]) && should_forward(args[1]));
			inherit_expression_dependencies(id, args[0]);
			inherit_expression_dependencies(id, args[1]);
		}
		else
			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
		break;

	case GLSLstd450Length:
		// MSL does not support scalar versions here.
		if (expression_type(args[0]).vecsize == 1)
		{
			// Equivalent to abs().
			emit_unary_func_op(result_type, id, args[0], "abs");
		}
		else
			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
		break;

	case GLSLstd450Normalize:
		// MSL does not support scalar versions here.
		if (expression_type(args[0]).vecsize == 1)
		{
			// Returns -1 or 1 for valid input, sign() does the job.
			emit_unary_func_op(result_type, id, args[0], "sign");
		}
		else
			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
		break;

	// Scalar reflect/refract/faceforward are emulated via spv* helper functions.
	case GLSLstd450Reflect:
		if (get<SPIRType>(result_type).vecsize == 1)
			emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
		else
			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
		break;

	case GLSLstd450Refract:
		if (get<SPIRType>(result_type).vecsize == 1)
			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
		else
			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
		break;

	case GLSLstd450FaceForward:
		if (get<SPIRType>(result_type).vecsize == 1)
			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvFaceForward");
		else
			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
		break;

	case GLSLstd450Modf:
	case GLSLstd450Frexp:
	{
		// Special case. If the variable is a scalar access chain, we cannot use it directly. We have to emit a temporary.
		auto *ptr = maybe_get<SPIRExpression>(args[1]);
		if (ptr && ptr->access_chain && is_scalar(expression_type(args[1])))
		{
			register_call_out_argument(args[1]);
			forced_temporaries.insert(id);

			// Need to create temporaries and copy over to access chain after.
			// We cannot directly take the reference of a vector swizzle in MSL, even if it's scalar ...
			uint32_t &tmp_id = extra_sub_expressions[id];
			if (!tmp_id)
				tmp_id = ir.increase_bound_by(1);

			uint32_t tmp_type_id = get_pointee_type_id(ptr->expression_type);
			emit_uninitialized_temporary_expression(tmp_type_id, tmp_id);
			emit_binary_func_op(result_type, id, args[0], tmp_id, eop == GLSLstd450Modf ? "modf" : "frexp");
			statement(to_expression(args[1]), " = ", to_expression(tmp_id), ";");
		}
		else
			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
		break;
	}

	default:
		CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
		break;
	}
}
8481 
emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type,uint32_t id,uint32_t eop,const uint32_t * args,uint32_t count)8482 void CompilerMSL::emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type, uint32_t id, uint32_t eop,
8483                                                         const uint32_t *args, uint32_t count)
8484 {
8485 	enum AMDShaderTrinaryMinMax
8486 	{
8487 		FMin3AMD = 1,
8488 		UMin3AMD = 2,
8489 		SMin3AMD = 3,
8490 		FMax3AMD = 4,
8491 		UMax3AMD = 5,
8492 		SMax3AMD = 6,
8493 		FMid3AMD = 7,
8494 		UMid3AMD = 8,
8495 		SMid3AMD = 9
8496 	};
8497 
8498 	if (!msl_options.supports_msl_version(2, 1))
8499 		SPIRV_CROSS_THROW("Trinary min/max functions require MSL 2.1.");
8500 
8501 	auto op = static_cast<AMDShaderTrinaryMinMax>(eop);
8502 
8503 	switch (op)
8504 	{
8505 	case FMid3AMD:
8506 	case UMid3AMD:
8507 	case SMid3AMD:
8508 		emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "median3");
8509 		break;
8510 	default:
8511 		CompilerGLSL::emit_spv_amd_shader_trinary_minmax_op(result_type, id, eop, args, count);
8512 		break;
8513 	}
8514 }
8515 
8516 // Emit a structure declaration for the specified interface variable.
emit_interface_block(uint32_t ib_var_id)8517 void CompilerMSL::emit_interface_block(uint32_t ib_var_id)
8518 {
8519 	if (ib_var_id)
8520 	{
8521 		auto &ib_var = get<SPIRVariable>(ib_var_id);
8522 		auto &ib_type = get_variable_data_type(ib_var);
8523 		assert(ib_type.basetype == SPIRType::Struct && !ib_type.member_types.empty());
8524 		emit_struct(ib_type);
8525 	}
8526 }
8527 
// Emits the declaration signature of the specified function.
// If this is the entry point function, Metal-specific return value and function arguments are added.
void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &)
{
	if (func.self != ir.default_entry_point)
		add_function_overload(func);

	// Start fresh with the resource names so local names do not collide with resources.
	local_variable_names = resource_names;
	string decl;

	processing_entry_point = func.self == ir.default_entry_point;

	// Metal helper functions must be static force-inline otherwise they will cause problems when linked together in a single Metallib.
	if (!processing_entry_point)
		statement(force_inline);

	auto &type = get<SPIRType>(func.return_type);

	if (!type.array.empty() && msl_options.force_native_arrays)
	{
		// We cannot return native arrays in MSL, so "return" through an out variable.
		decl += "void";
	}
	else
	{
		decl += func_type_decl(type);
	}

	decl += " ";
	decl += to_name(func.self);
	decl += "(";

	if (!type.array.empty() && msl_options.force_native_arrays)
	{
		// Fake arrays returns by writing to an out array instead.
		decl += "thread ";
		decl += type_to_glsl(type);
		decl += " (&spvReturnValue)";
		decl += type_to_array_glsl(type);
		if (!func.arguments.empty())
			decl += ", ";
	}

	if (processing_entry_point)
	{
		// The entry point receives Metal-specific arguments (buffers, textures,
		// builtins), emitted either through argument buffers or classic bindings.
		if (msl_options.argument_buffers)
			decl += entry_point_args_argument_buffer(!func.arguments.empty());
		else
			decl += entry_point_args_classic(!func.arguments.empty());

		// If entry point function has variables that require early declaration,
		// ensure they each have an empty initializer, creating one if needed.
		// This is done at this late stage because the initialization expression
		// is cleared after each compilation pass.
		for (auto var_id : vars_needing_early_declaration)
		{
			auto &ed_var = get<SPIRVariable>(var_id);
			ID &initializer = ed_var.initializer;
			if (!initializer)
				initializer = ir.increase_bound_by(1);

			// Do not override proper initializers.
			if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression)
				set<SPIRExpression>(ed_var.initializer, "{}", ed_var.basetype, true);
		}
	}

	for (auto &arg : func.arguments)
	{
		uint32_t name_id = arg.id;

		auto *var = maybe_get<SPIRVariable>(arg.id);
		if (var)
		{
			// If we need to modify the name of the variable, make sure we modify the original variable.
			// Our alias is just a shadow variable.
			if (arg.alias_global_variable && var->basevariable)
				name_id = var->basevariable;

			var->parameter = &arg; // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
		}

		add_local_variable_name(name_id);

		decl += argument_decl(arg);

		bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);

		auto &arg_type = get<SPIRType>(arg.type);
		if (arg_type.basetype == SPIRType::SampledImage && !is_dynamic_img_sampler)
		{
			// Manufacture automatic plane args for multiplanar texture
			uint32_t planes = 1;
			if (auto *constexpr_sampler = find_constexpr_sampler(name_id))
				if (constexpr_sampler->ycbcr_conversion_enable)
					planes = constexpr_sampler->planes;
			for (uint32_t i = 1; i < planes; i++)
				decl += join(", ", argument_decl(arg), plane_name_suffix, i);

			// Manufacture automatic sampler arg for SampledImage texture
			if (arg_type.image.dim != DimBuffer)
				decl += join(", thread const ", sampler_type(arg_type, arg.id), " ", to_sampler_expression(arg.id));
		}

		// Manufacture automatic swizzle arg.
		if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(arg_type) &&
		    !is_dynamic_img_sampler)
		{
			bool arg_is_array = !arg_type.array.empty();
			decl += join(", constant uint", arg_is_array ? "* " : "& ", to_swizzle_expression(arg.id));
		}

		// Buffers whose SPIR-V length is queried need a companion size argument.
		if (buffers_requiring_array_length.count(name_id))
		{
			bool arg_is_array = !arg_type.array.empty();
			decl += join(", constant uint", arg_is_array ? "* " : "& ", to_buffer_size_expression(name_id));
		}

		if (&arg != &func.arguments.back())
			decl += ", ";
	}

	decl += ")";
	statement(decl);
}
8653 
needs_chroma_reconstruction(const MSLConstexprSampler * constexpr_sampler)8654 static bool needs_chroma_reconstruction(const MSLConstexprSampler *constexpr_sampler)
8655 {
8656 	// For now, only multiplanar images need explicit reconstruction. GBGR and BGRG images
8657 	// use implicit reconstruction.
8658 	return constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && constexpr_sampler->planes > 1;
8659 }
8660 
// Returns the texture sampling function string for the specified image and sampling characteristics.
// Depending on the sampler configuration this is either a spv* helper function name
// (gather swizzle, chroma reconstruction) or "<image expr>.sample/read/gather[...]".
string CompilerMSL::to_function_name(const TextureFunctionNameArguments &args)
{
	VariableID img = args.base.img;
	auto &imgtype = *args.base.imgtype;

	const MSLConstexprSampler *constexpr_sampler = nullptr;
	bool is_dynamic_img_sampler = false;
	if (auto *var = maybe_get_backing_variable(img))
	{
		constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
		is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
	}

	// Special-case gather. We have to alter the component being looked up
	// in the swizzle case.
	if (msl_options.swizzle_texture_samples && args.base.is_gather && !is_dynamic_img_sampler &&
	    (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable))
	{
		add_spv_func_and_recompile(imgtype.image.depth ? SPVFuncImplGatherCompareSwizzle : SPVFuncImplGatherSwizzle);
		return imgtype.image.depth ? "spvGatherCompareSwizzle" : "spvGatherSwizzle";
	}

	auto *combined = maybe_get<SPIRCombinedImageSampler>(img);

	// Texture reference
	string fname;
	if (needs_chroma_reconstruction(constexpr_sampler) && !is_dynamic_img_sampler)
	{
		if (constexpr_sampler->planes != 2 && constexpr_sampler->planes != 3)
			SPIRV_CROSS_THROW("Unhandled number of color image planes!");
		// 444 images aren't downsampled, so we don't need to do linear filtering.
		if (constexpr_sampler->resolution == MSL_FORMAT_RESOLUTION_444 ||
		    constexpr_sampler->chroma_filter == MSL_SAMPLER_FILTER_NEAREST)
		{
			if (constexpr_sampler->planes == 2)
				add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest2Plane);
			else
				add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest3Plane);
			fname = "spvChromaReconstructNearest";
		}
		else // Linear with a downsampled format
		{
			// The helper function name is built up from resolution and chroma offsets,
			// mirroring the SPVFuncImpl* enum chosen below.
			fname = "spvChromaReconstructLinear";
			switch (constexpr_sampler->resolution)
			{
			case MSL_FORMAT_RESOLUTION_444:
				assert(false);
				break; // not reached
			case MSL_FORMAT_RESOLUTION_422:
				switch (constexpr_sampler->x_chroma_offset)
				{
				case MSL_CHROMA_LOCATION_COSITED_EVEN:
					if (constexpr_sampler->planes == 2)
						add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven2Plane);
					else
						add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven3Plane);
					fname += "422CositedEven";
					break;
				case MSL_CHROMA_LOCATION_MIDPOINT:
					if (constexpr_sampler->planes == 2)
						add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint2Plane);
					else
						add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint3Plane);
					fname += "422Midpoint";
					break;
				default:
					SPIRV_CROSS_THROW("Invalid chroma location.");
				}
				break;
			case MSL_FORMAT_RESOLUTION_420:
				// 420 is downsampled in both dimensions, so both X and Y offsets matter.
				fname += "420";
				switch (constexpr_sampler->x_chroma_offset)
				{
				case MSL_CHROMA_LOCATION_COSITED_EVEN:
					switch (constexpr_sampler->y_chroma_offset)
					{
					case MSL_CHROMA_LOCATION_COSITED_EVEN:
						if (constexpr_sampler->planes == 2)
							add_spv_func_and_recompile(
							    SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane);
						else
							add_spv_func_and_recompile(
							    SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane);
						fname += "XCositedEvenYCositedEven";
						break;
					case MSL_CHROMA_LOCATION_MIDPOINT:
						if (constexpr_sampler->planes == 2)
							add_spv_func_and_recompile(
							    SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane);
						else
							add_spv_func_and_recompile(
							    SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane);
						fname += "XCositedEvenYMidpoint";
						break;
					default:
						SPIRV_CROSS_THROW("Invalid Y chroma location.");
					}
					break;
				case MSL_CHROMA_LOCATION_MIDPOINT:
					switch (constexpr_sampler->y_chroma_offset)
					{
					case MSL_CHROMA_LOCATION_COSITED_EVEN:
						if (constexpr_sampler->planes == 2)
							add_spv_func_and_recompile(
							    SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane);
						else
							add_spv_func_and_recompile(
							    SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane);
						fname += "XMidpointYCositedEven";
						break;
					case MSL_CHROMA_LOCATION_MIDPOINT:
						if (constexpr_sampler->planes == 2)
							add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane);
						else
							add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane);
						fname += "XMidpointYMidpoint";
						break;
					default:
						SPIRV_CROSS_THROW("Invalid Y chroma location.");
					}
					break;
				default:
					SPIRV_CROSS_THROW("Invalid X chroma location.");
				}
				break;
			default:
				SPIRV_CROSS_THROW("Invalid format resolution.");
			}
		}
	}
	else
	{
		fname = to_expression(combined ? combined->image : img) + ".";

		// Texture function and sampler
		if (args.base.is_fetch)
			fname += "read";
		else if (args.base.is_gather)
			fname += "gather";
		else
			fname += "sample";

		if (args.has_dref)
			fname += "_compare";
	}

	return fname;
}
8810 
// Wraps an expression in a float constructor cast of the requested vector width.
string CompilerMSL::convert_to_f32(const string &expr, uint32_t components)
{
	// Build a single-precision float type with the given number of components.
	SPIRType f32_type;
	f32_type.basetype = SPIRType::Float;
	f32_type.vecsize = components;
	f32_type.columns = 1;

	return join(type_to_glsl_constructor(f32_type), "(", expr, ")");
}
8819 
sampling_type_needs_f32_conversion(const SPIRType & type)8820 static inline bool sampling_type_needs_f32_conversion(const SPIRType &type)
8821 {
8822 	// Double is not supported to begin with, but doesn't hurt to check for completion.
8823 	return type.basetype == SPIRType::Half || type.basetype == SPIRType::Double;
8824 }
8825 
// Returns the function args for a texture sampling function for the specified image and sampling characteristics.
// Assembles, in MSL argument order: texture/sampler references (unless the image is a
// "dynamic image-sampler" which carries its own), texture coordinates converted per
// dimensionality, depth-compare reference, LOD/bias/gradient/min-LOD options, offsets,
// gather component, and sample index. *p_forward receives whether all consumed
// expressions may be forwarded (i.e. none require a temporary).
string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool *p_forward)
{
	VariableID img = args.base.img;
	auto &imgtype = *args.base.imgtype;
	// lod/grad/bias are locals because macOS sample_compare handling below may zero them out.
	uint32_t lod = args.lod;
	uint32_t grad_x = args.grad_x;
	uint32_t grad_y = args.grad_y;
	uint32_t bias = args.bias;

	const MSLConstexprSampler *constexpr_sampler = nullptr;
	bool is_dynamic_img_sampler = false;
	if (auto *var = maybe_get_backing_variable(img))
	{
		constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
		is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
	}

	string farg_str;
	bool forward = true;

	if (!is_dynamic_img_sampler)
	{
		// Texture reference (for some cases)
		if (needs_chroma_reconstruction(constexpr_sampler))
		{
			// Multiplanar images need two or three textures.
			farg_str += to_expression(img);
			for (uint32_t i = 1; i < constexpr_sampler->planes; i++)
				farg_str += join(", ", to_expression(img), plane_name_suffix, i);
		}
		else if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
		         msl_options.swizzle_texture_samples && args.base.is_gather)
		{
			// Swizzled gathers go through the spvGatherSwizzle helper, which takes
			// the texture itself as its first argument.
			auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
			farg_str += to_expression(combined ? combined->image : img);
		}

		// Sampler reference
		if (!args.base.is_fetch)
		{
			if (!farg_str.empty())
				farg_str += ", ";
			farg_str += to_sampler_expression(img);
		}

		if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
		    msl_options.swizzle_texture_samples && args.base.is_gather)
		{
			// Add the swizzle constant from the swizzle buffer.
			farg_str += ", " + to_swizzle_expression(img);
			used_swizzle_buffer = true;
		}

		// Swizzled gather puts the component before the other args, to allow template
		// deduction to work.
		if (args.component && msl_options.swizzle_texture_samples)
		{
			forward = should_forward(args.component);
			farg_str += ", " + to_component_argument(args.component);
		}
	}

	// Texture coordinates
	forward = forward && should_forward(args.coord);
	auto coord_expr = to_enclosed_expression(args.coord);
	auto &coord_type = expression_type(args.coord);
	bool coord_is_fp = type_is_floating_point(coord_type);
	bool is_cube_fetch = false;

	string tex_coords = coord_expr;
	// Index of the coordinate component holding the array layer (and, for
	// projective sampling, the divisor). Set per dimensionality below.
	uint32_t alt_coord_component = 0;

	switch (imgtype.image.dim)
	{

	case Dim1D:
		if (coord_type.vecsize > 1)
			tex_coords = enclose_expression(tex_coords) + ".x";

		if (args.base.is_fetch)
			tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
		else if (sampling_type_needs_f32_conversion(coord_type))
			tex_coords = convert_to_f32(tex_coords, 1);

		if (msl_options.texture_1D_as_2D)
		{
			// 1D-as-2D emulation samples the center row (y = 0.5) or row 0 for fetch.
			if (args.base.is_fetch)
				tex_coords = "uint2(" + tex_coords + ", 0)";
			else
				tex_coords = "float2(" + tex_coords + ", 0.5)";
		}

		alt_coord_component = 1;
		break;

	case DimBuffer:
		if (coord_type.vecsize > 1)
			tex_coords = enclose_expression(tex_coords) + ".x";

		if (msl_options.texture_buffer_native)
		{
			tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
		}
		else
		{
			// Metal texel buffer textures are 2D, so convert 1D coord to 2D.
			// Support for Metal 2.1's new texture_buffer type.
			if (args.base.is_fetch)
			{
				if (msl_options.texel_buffer_texture_width > 0)
				{
					tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
				}
				else
				{
					tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ", " +
					             to_expression(img) + ")";
				}
			}
		}

		alt_coord_component = 1;
		break;

	case DimSubpassData:
		// If we're using Metal's native frame-buffer fetch API for subpass inputs,
		// this path will not be hit.
		tex_coords = "uint2(gl_FragCoord.xy)";
		alt_coord_component = 2;
		break;

	case Dim2D:
		if (coord_type.vecsize > 2)
			tex_coords = enclose_expression(tex_coords) + ".xy";

		if (args.base.is_fetch)
			tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
		else if (sampling_type_needs_f32_conversion(coord_type))
			tex_coords = convert_to_f32(tex_coords, 2);

		alt_coord_component = 2;
		break;

	case Dim3D:
		if (coord_type.vecsize > 3)
			tex_coords = enclose_expression(tex_coords) + ".xyz";

		if (args.base.is_fetch)
			tex_coords = "uint3(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
		else if (sampling_type_needs_f32_conversion(coord_type))
			tex_coords = convert_to_f32(tex_coords, 3);

		alt_coord_component = 3;
		break;

	case DimCube:
		if (args.base.is_fetch)
		{
			// Cube fetches take a 2D texel coordinate; the face is passed separately below.
			is_cube_fetch = true;
			tex_coords += ".xy";
			tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
		}
		else
		{
			if (coord_type.vecsize > 3)
				tex_coords = enclose_expression(tex_coords) + ".xyz";
		}

		if (sampling_type_needs_f32_conversion(coord_type))
			tex_coords = convert_to_f32(tex_coords, 3);

		alt_coord_component = 3;
		break;

	default:
		break;
	}

	if (args.base.is_fetch && (args.offset || args.coffset))
	{
		uint32_t offset_expr = args.offset ? args.offset : args.coffset;
		// Fetch offsets must be applied directly to the coordinate.
		forward = forward && should_forward(offset_expr);
		auto &type = expression_type(offset_expr);
		if (imgtype.image.dim == Dim1D && msl_options.texture_1D_as_2D)
		{
			if (type.basetype != SPIRType::UInt)
				tex_coords += join(" + uint2(", bitcast_expression(SPIRType::UInt, offset_expr), ", 0)");
			else
				tex_coords += join(" + uint2(", to_enclosed_expression(offset_expr), ", 0)");
		}
		else
		{
			if (type.basetype != SPIRType::UInt)
				tex_coords += " + " + bitcast_expression(SPIRType::UInt, offset_expr);
			else
				tex_coords += " + " + to_enclosed_expression(offset_expr);
		}
	}

	// If projection, use alt coord as divisor
	if (args.base.is_proj)
	{
		if (sampling_type_needs_f32_conversion(coord_type))
			tex_coords += " / " + convert_to_f32(to_extract_component_expression(args.coord, alt_coord_component), 1);
		else
			tex_coords += " / " + to_extract_component_expression(args.coord, alt_coord_component);
	}

	if (!farg_str.empty())
		farg_str += ", ";

	// Emulated cube arrays are lowered to 2D arrays with 6 layers per cube;
	// spvCubemapTo2DArrayFace maps the direction vector to (u, v, face).
	if (imgtype.image.dim == DimCube && imgtype.image.arrayed && msl_options.emulate_cube_array)
	{
		farg_str += "spvCubemapTo2DArrayFace(" + tex_coords + ").xy";

		if (is_cube_fetch)
			farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ")";
		else
			farg_str +=
			    ", uint(spvCubemapTo2DArrayFace(" + tex_coords + ").z) + (uint(" +
			    round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
			    ") * 6u)";

		add_spv_func_and_recompile(SPVFuncImplCubemapTo2DArrayFace);
	}
	else
	{
		farg_str += tex_coords;

		// If fetch from cube, add face explicitly
		if (is_cube_fetch)
		{
			// Special case for cube arrays, face and layer are packed in one dimension.
			if (imgtype.image.arrayed)
				farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") % 6u";
			else
				farg_str +=
				    ", uint(" + round_fp_tex_coords(to_extract_component_expression(args.coord, 2), coord_is_fp) + ")";
		}

		// If array, use alt coord
		if (imgtype.image.arrayed)
		{
			// Special case for cube arrays, face and layer are packed in one dimension.
			if (imgtype.image.dim == DimCube && args.base.is_fetch)
			{
				farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") / 6u";
			}
			else
			{
				farg_str +=
				    ", uint(" +
				    round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
				    ")";
				if (imgtype.image.dim == DimSubpassData)
				{
					// Arrayed subpass inputs are indexed by view or layer on top of the coord layer.
					if (msl_options.multiview)
						farg_str += " + gl_ViewIndex";
					else if (msl_options.arrayed_subpass_input)
						farg_str += " + gl_Layer";
				}
			}
		}
		else if (imgtype.image.dim == DimSubpassData)
		{
			if (msl_options.multiview)
				farg_str += ", gl_ViewIndex";
			else if (msl_options.arrayed_subpass_input)
				farg_str += ", gl_Layer";
		}
	}

	// Depth compare reference value
	if (args.dref)
	{
		forward = forward && should_forward(args.dref);
		farg_str += ", ";

		auto &dref_type = expression_type(args.dref);

		string dref_expr;
		if (args.base.is_proj)
			dref_expr = join(to_enclosed_expression(args.dref), " / ",
			                 to_extract_component_expression(args.coord, alt_coord_component));
		else
			dref_expr = to_expression(args.dref);

		if (sampling_type_needs_f32_conversion(dref_type))
			dref_expr = convert_to_f32(dref_expr, 1);

		farg_str += dref_expr;

		if (msl_options.is_macos() && (grad_x || grad_y))
		{
			// For sample compare, MSL does not support gradient2d for all targets (only iOS apparently according to docs).
			// However, the most common case here is to have a constant gradient of 0, as that is the only way to express
			// LOD == 0 in GLSL with sampler2DArrayShadow (cascaded shadow mapping).
			// We will detect a compile-time constant 0 value for gradient and promote that to level(0) on MSL.
			bool constant_zero_x = !grad_x || expression_is_constant_null(grad_x);
			bool constant_zero_y = !grad_y || expression_is_constant_null(grad_y);
			if (constant_zero_x && constant_zero_y)
			{
				lod = 0;
				grad_x = 0;
				grad_y = 0;
				farg_str += ", level(0)";
			}
			else if (!msl_options.supports_msl_version(2, 3))
			{
				SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not "
				                  "supported on macOS prior to MSL 2.3.");
			}
		}

		if (msl_options.is_macos() && bias)
		{
			// Bias is not supported either on macOS with sample_compare.
			// Verify it is compile-time zero, and drop the argument.
			if (expression_is_constant_null(bias))
			{
				bias = 0;
			}
			else if (!msl_options.supports_msl_version(2, 3))
			{
				SPIRV_CROSS_THROW("Using non-constant 0.0 bias() qualifier for sample_compare. This is not supported "
				                  "on macOS prior to MSL 2.3.");
			}
		}
	}

	// LOD Options
	// Metal does not support LOD for 1D textures.
	if (bias && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
	{
		forward = forward && should_forward(bias);
		farg_str += ", bias(" + to_expression(bias) + ")";
	}

	// Metal does not support LOD for 1D textures.
	if (lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
	{
		forward = forward && should_forward(lod);
		if (args.base.is_fetch)
		{
			farg_str += ", " + to_expression(lod);
		}
		else
		{
			farg_str += ", level(" + to_expression(lod) + ")";
		}
	}
	else if (args.base.is_fetch && !lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D) &&
	         imgtype.image.dim != DimBuffer && !imgtype.image.ms && imgtype.image.sampled != 2)
	{
		// Lod argument is optional in OpImageFetch, but we require a LOD value, pick 0 as the default.
		// Check for sampled type as well, because is_fetch is also used for OpImageRead in MSL.
		farg_str += ", 0";
	}

	// Metal does not support LOD for 1D textures.
	if ((grad_x || grad_y) && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
	{
		forward = forward && should_forward(grad_x);
		forward = forward && should_forward(grad_y);
		// Pick the gradient*() variant matching the texture dimensionality.
		string grad_opt;
		switch (imgtype.image.dim)
		{
		case Dim1D:
		case Dim2D:
			grad_opt = "2d";
			break;
		case Dim3D:
			grad_opt = "3d";
			break;
		case DimCube:
			if (imgtype.image.arrayed && msl_options.emulate_cube_array)
				grad_opt = "2d";
			else
				grad_opt = "cube";
			break;
		default:
			grad_opt = "unsupported_gradient_dimension";
			break;
		}
		farg_str += ", gradient" + grad_opt + "(" + to_expression(grad_x) + ", " + to_expression(grad_y) + ")";
	}

	if (args.min_lod)
	{
		if (!msl_options.supports_msl_version(2, 2))
			SPIRV_CROSS_THROW("min_lod_clamp() is only supported in MSL 2.2+ and up.");

		forward = forward && should_forward(args.min_lod);
		farg_str += ", min_lod_clamp(" + to_expression(args.min_lod) + ")";
	}

	// Add offsets
	// Fetch offsets were already folded into the coordinate above, so only
	// non-fetch offsets are emitted as arguments here.
	string offset_expr;
	const SPIRType *offset_type = nullptr;
	if (args.coffset && !args.base.is_fetch)
	{
		forward = forward && should_forward(args.coffset);
		offset_expr = to_expression(args.coffset);
		offset_type = &expression_type(args.coffset);
	}
	else if (args.offset && !args.base.is_fetch)
	{
		forward = forward && should_forward(args.offset);
		offset_expr = to_expression(args.offset);
		offset_type = &expression_type(args.offset);
	}

	if (!offset_expr.empty())
	{
		switch (imgtype.image.dim)
		{
		case Dim1D:
			if (!msl_options.texture_1D_as_2D)
				break;
			if (offset_type->vecsize > 1)
				offset_expr = enclose_expression(offset_expr) + ".x";

			farg_str += join(", int2(", offset_expr, ", 0)");
			break;

		case Dim2D:
			if (offset_type->vecsize > 2)
				offset_expr = enclose_expression(offset_expr) + ".xy";

			farg_str += ", " + offset_expr;
			break;

		case Dim3D:
			if (offset_type->vecsize > 3)
				offset_expr = enclose_expression(offset_expr) + ".xyz";

			farg_str += ", " + offset_expr;
			break;

		default:
			break;
		}
	}

	if (args.component)
	{
		// If 2D has gather component, ensure it also has an offset arg
		if (imgtype.image.dim == Dim2D && offset_expr.empty())
			farg_str += ", int2(0)";

		// For swizzled samples the component was already emitted up front (see above).
		if (!msl_options.swizzle_texture_samples || is_dynamic_img_sampler)
		{
			forward = forward && should_forward(args.component);

			uint32_t image_var = 0;
			if (const auto *combined = maybe_get<SPIRCombinedImageSampler>(img))
			{
				if (const auto *img_var = maybe_get_backing_variable(combined->image))
					image_var = img_var->self;
			}
			else if (const auto *var = maybe_get_backing_variable(img))
			{
				image_var = var->self;
			}

			// Depth textures have no component argument in MSL gathers.
			if (image_var == 0 || !image_is_comparison(expression_type(image_var), image_var))
				farg_str += ", " + to_component_argument(args.component);
		}
	}

	if (args.sample)
	{
		forward = forward && should_forward(args.sample);
		farg_str += ", ";
		farg_str += to_expression(args.sample);
	}

	*p_forward = forward;

	return farg_str;
}
9309 
9310 // If the texture coordinates are floating point, invokes MSL round() function to round them.
round_fp_tex_coords(string tex_coords,bool coord_is_fp)9311 string CompilerMSL::round_fp_tex_coords(string tex_coords, bool coord_is_fp)
9312 {
9313 	return coord_is_fp ? ("round(" + tex_coords + ")") : tex_coords;
9314 }
9315 
9316 // Returns a string to use in an image sampling function argument.
9317 // The ID must be a scalar constant.
to_component_argument(uint32_t id)9318 string CompilerMSL::to_component_argument(uint32_t id)
9319 {
9320 	uint32_t component_index = evaluate_constant_u32(id);
9321 	switch (component_index)
9322 	{
9323 	case 0:
9324 		return "component::x";
9325 	case 1:
9326 		return "component::y";
9327 	case 2:
9328 		return "component::z";
9329 	case 3:
9330 		return "component::w";
9331 
9332 	default:
9333 		SPIRV_CROSS_THROW("The value (" + to_string(component_index) + ") of OpConstant ID " + to_string(id) +
9334 		                  " is not a valid Component index, which must be one of 0, 1, 2, or 3.");
9335 		return "component::x";
9336 	}
9337 }
9338 
// Establish sampled image as expression object and assign the sampler to it.
// The resulting SPIRCombinedImageSampler ties image_id and samp_id together so
// later sampling ops can emit both the texture and sampler arguments.
void CompilerMSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
{
	set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
}
9344 
// Emits the expression for a texture sampling instruction. Delegates the core
// sampling expression to CompilerGLSL::to_texture_op(), then wraps it with
// Y'CbCr conversion helpers and/or constexpr-sampler or runtime swizzling as
// required by the image's constexpr sampler configuration and MSL options.
string CompilerMSL::to_texture_op(const Instruction &i, bool sparse, bool *forward,
                                  SmallVector<uint32_t> &inherited_expressions)
{
	auto *ops = stream(i);
	uint32_t result_type_id = ops[0];
	uint32_t img = ops[2];
	auto &result_type = get<SPIRType>(result_type_id);
	auto op = static_cast<Op>(i.op);
	bool is_gather = (op == OpImageGather || op == OpImageDrefGather);

	// Bypass pointers because we need the real image struct
	auto &type = expression_type(img);
	auto &imgtype = get<SPIRType>(type.self);

	const MSLConstexprSampler *constexpr_sampler = nullptr;
	bool is_dynamic_img_sampler = false;
	if (auto *var = maybe_get_backing_variable(img))
	{
		constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
		is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
	}

	// Open the wrapper function calls first; the matching close parens are
	// appended after the inner sampling expression below.
	string expr;
	if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
	{
		// If this needs sampler Y'CbCr conversion, we need to do some additional
		// processing.
		switch (constexpr_sampler->ycbcr_model)
		{
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
			// Default
			break;
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
			add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT709);
			expr += "spvConvertYCbCrBT709(";
			break;
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
			add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT601);
			expr += "spvConvertYCbCrBT601(";
			break;
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
			add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT2020);
			expr += "spvConvertYCbCrBT2020(";
			break;
		default:
			SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
		}

		if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
		{
			switch (constexpr_sampler->ycbcr_range)
			{
			case MSL_SAMPLER_YCBCR_RANGE_ITU_FULL:
				add_spv_func_and_recompile(SPVFuncImplExpandITUFullRange);
				expr += "spvExpandITUFullRange(";
				break;
			case MSL_SAMPLER_YCBCR_RANGE_ITU_NARROW:
				add_spv_func_and_recompile(SPVFuncImplExpandITUNarrowRange);
				expr += "spvExpandITUNarrowRange(";
				break;
			default:
				SPIRV_CROSS_THROW("Invalid Y'CbCr range.");
			}
		}
	}
	else if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
	         !is_dynamic_img_sampler)
	{
		add_spv_func_and_recompile(SPVFuncImplTextureSwizzle);
		expr += "spvTextureSwizzle(";
	}

	string inner_expr = CompilerGLSL::to_texture_op(i, sparse, forward, inherited_expressions);

	if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
	{
		if (!constexpr_sampler->swizzle_is_identity())
		{
			static const char swizzle_names[] = "rgba";
			if (!constexpr_sampler->swizzle_has_one_or_zero())
			{
				// If we can, do it inline.
				expr += inner_expr + ".";
				for (uint32_t c = 0; c < 4; c++)
				{
					switch (constexpr_sampler->swizzle[c])
					{
					case MSL_COMPONENT_SWIZZLE_IDENTITY:
						expr += swizzle_names[c];
						break;
					case MSL_COMPONENT_SWIZZLE_R:
					case MSL_COMPONENT_SWIZZLE_G:
					case MSL_COMPONENT_SWIZZLE_B:
					case MSL_COMPONENT_SWIZZLE_A:
						expr += swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
						break;
					default:
						SPIRV_CROSS_THROW("Invalid component swizzle.");
					}
				}
			}
			else
			{
				// Otherwise, we need to emit a temporary and swizzle that.
				// (Zero/one components cannot be expressed as a plain .xyzw swizzle.)
				uint32_t temp_id = ir.increase_bound_by(1);
				emit_op(result_type_id, temp_id, inner_expr, false);
				for (auto &inherit : inherited_expressions)
					inherit_expression_dependencies(temp_id, inherit);
				inherited_expressions.clear();
				inherited_expressions.push_back(temp_id);

				// Implicit-LOD sampling results depend on derivatives, so the
				// temporary must be treated as control-dependent.
				switch (op)
				{
				case OpImageSampleDrefImplicitLod:
				case OpImageSampleImplicitLod:
				case OpImageSampleProjImplicitLod:
				case OpImageSampleProjDrefImplicitLod:
					register_control_dependent_expression(temp_id);
					break;

				default:
					break;
				}
				expr += type_to_glsl(result_type) + "(";
				for (uint32_t c = 0; c < 4; c++)
				{
					switch (constexpr_sampler->swizzle[c])
					{
					case MSL_COMPONENT_SWIZZLE_IDENTITY:
						expr += to_expression(temp_id) + "." + swizzle_names[c];
						break;
					case MSL_COMPONENT_SWIZZLE_ZERO:
						expr += "0";
						break;
					case MSL_COMPONENT_SWIZZLE_ONE:
						expr += "1";
						break;
					case MSL_COMPONENT_SWIZZLE_R:
					case MSL_COMPONENT_SWIZZLE_G:
					case MSL_COMPONENT_SWIZZLE_B:
					case MSL_COMPONENT_SWIZZLE_A:
						expr += to_expression(temp_id) + "." +
						        swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
						break;
					default:
						SPIRV_CROSS_THROW("Invalid component swizzle.");
					}
					if (c < 3)
						expr += ", ";
				}
				expr += ")";
			}
		}
		else
			expr += inner_expr;
		// Close the range-expansion and model-conversion wrappers opened above.
		if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
		{
			expr += join(", ", constexpr_sampler->bpc, ")");
			if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY)
				expr += ")";
		}
	}
	else
	{
		expr += inner_expr;
		if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
		    !is_dynamic_img_sampler)
		{
			// Add the swizzle constant from the swizzle buffer.
			expr += ", " + to_swizzle_expression(img) + ")";
			used_swizzle_buffer = true;
		}
	}

	return expr;
}
9522 
create_swizzle(MSLComponentSwizzle swizzle)9523 static string create_swizzle(MSLComponentSwizzle swizzle)
9524 {
9525 	switch (swizzle)
9526 	{
9527 	case MSL_COMPONENT_SWIZZLE_IDENTITY:
9528 		return "spvSwizzle::none";
9529 	case MSL_COMPONENT_SWIZZLE_ZERO:
9530 		return "spvSwizzle::zero";
9531 	case MSL_COMPONENT_SWIZZLE_ONE:
9532 		return "spvSwizzle::one";
9533 	case MSL_COMPONENT_SWIZZLE_R:
9534 		return "spvSwizzle::red";
9535 	case MSL_COMPONENT_SWIZZLE_G:
9536 		return "spvSwizzle::green";
9537 	case MSL_COMPONENT_SWIZZLE_B:
9538 		return "spvSwizzle::blue";
9539 	case MSL_COMPONENT_SWIZZLE_A:
9540 		return "spvSwizzle::alpha";
9541 	default:
9542 		SPIRV_CROSS_THROW("Invalid component swizzle.");
9543 		return "";
9544 	}
9545 }
9546 
// Returns a string representation of the ID, usable as a function arg.
// Manufacture automatic sampler arg for SampledImage texture.
// Also appends, as needed: extra plane textures for multiplanar Y'CbCr images,
// Y'CbCr sampler/swizzle configuration for dynamic image-samplers, swizzle
// constants, buffer-size arguments, and emulated texture-atomic buffers.
string CompilerMSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
{
	string arg_str;

	auto &type = expression_type(id);
	bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
	// If the argument *itself* is a "dynamic" combined-image sampler, then we can just pass that around.
	bool arg_is_dynamic_img_sampler = has_extended_decoration(id, SPIRVCrossDecorationDynamicImageSampler);
	if (is_dynamic_img_sampler && !arg_is_dynamic_img_sampler)
		arg_str = join("spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">(");

	auto *c = maybe_get<SPIRConstant>(id);
	if (msl_options.force_native_arrays && c && !get<SPIRType>(c->constant_type).array.empty())
	{
		// If we are passing a constant array directly to a function for some reason,
		// the callee will expect an argument in thread const address space
		// (since we can only bind to arrays with references in MSL).
		// To resolve this, we must emit a copy in this address space.
		// This kind of code gen should be rare enough that performance is not a real concern.
		// Inline the SPIR-V to avoid this kind of suboptimal codegen.
		//
		// We risk calling this inside a continue block (invalid code),
		// so just create a thread local copy in the current function.
		arg_str = join("_", id, "_array_copy");
		auto &constants = current_function->constant_arrays_needed_on_stack;
		auto itr = find(begin(constants), end(constants), ID(id));
		if (itr == end(constants))
		{
			// The copy doesn't exist yet; request it and recompile.
			force_recompile();
			constants.push_back(id);
		}
	}
	else
		arg_str += CompilerGLSL::to_func_call_arg(arg, id);

	// Need to check the base variable in case we need to apply a qualified alias.
	uint32_t var_id = 0;
	auto *var = maybe_get<SPIRVariable>(id);
	if (var)
		var_id = var->basevariable;

	if (!arg_is_dynamic_img_sampler)
	{
		auto *constexpr_sampler = find_constexpr_sampler(var_id ? var_id : id);
		if (type.basetype == SPIRType::SampledImage)
		{
			// Manufacture automatic plane args for multiplanar texture
			uint32_t planes = 1;
			if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
			{
				planes = constexpr_sampler->planes;
				// If this parameter isn't aliasing a global, then we need to use
				// the special "dynamic image-sampler" class to pass it--and we need
				// to use it for *every* non-alias parameter, in case a combined
				// image-sampler with a Y'CbCr conversion is passed. Hopefully, this
				// pathological case is so rare that it should never be hit in practice.
				if (!arg.alias_global_variable)
					add_spv_func_and_recompile(SPVFuncImplDynamicImageSampler);
			}
			for (uint32_t i = 1; i < planes; i++)
				arg_str += join(", ", CompilerGLSL::to_func_call_arg(arg, id), plane_name_suffix, i);
			// Manufacture automatic sampler arg if the arg is a SampledImage texture.
			if (type.image.dim != DimBuffer)
				arg_str += ", " + to_sampler_expression(var_id ? var_id : id);

			// Add sampler Y'CbCr conversion info if we have it
			if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
			{
				// Collect only the non-default settings; the spvYCbCrSampler
				// constructor defaults cover the rest.
				SmallVector<string> samp_args;

				switch (constexpr_sampler->resolution)
				{
				case MSL_FORMAT_RESOLUTION_444:
					// Default
					break;
				case MSL_FORMAT_RESOLUTION_422:
					samp_args.push_back("spvFormatResolution::_422");
					break;
				case MSL_FORMAT_RESOLUTION_420:
					samp_args.push_back("spvFormatResolution::_420");
					break;
				default:
					SPIRV_CROSS_THROW("Invalid format resolution.");
				}

				if (constexpr_sampler->chroma_filter != MSL_SAMPLER_FILTER_NEAREST)
					samp_args.push_back("spvChromaFilter::linear");

				if (constexpr_sampler->x_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
					samp_args.push_back("spvXChromaLocation::midpoint");
				if (constexpr_sampler->y_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
					samp_args.push_back("spvYChromaLocation::midpoint");
				switch (constexpr_sampler->ycbcr_model)
				{
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
					// Default
					break;
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
					samp_args.push_back("spvYCbCrModelConversion::ycbcr_identity");
					break;
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
					samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_709");
					break;
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
					samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_601");
					break;
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
					samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_2020");
					break;
				default:
					SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
				}
				if (constexpr_sampler->ycbcr_range != MSL_SAMPLER_YCBCR_RANGE_ITU_FULL)
					samp_args.push_back("spvYCbCrRange::itu_narrow");
				samp_args.push_back(join("spvComponentBits(", constexpr_sampler->bpc, ")"));
				arg_str += join(", spvYCbCrSampler(", merge(samp_args), ")");
			}
		}

		// Pack the four component swizzles into one 32-bit constant (one byte each, x in the low byte).
		if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
			arg_str += join(", (uint(", create_swizzle(constexpr_sampler->swizzle[3]), ") << 24) | (uint(",
			                create_swizzle(constexpr_sampler->swizzle[2]), ") << 16) | (uint(",
			                create_swizzle(constexpr_sampler->swizzle[1]), ") << 8) | uint(",
			                create_swizzle(constexpr_sampler->swizzle[0]), ")");
		else if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
			arg_str += ", " + to_swizzle_expression(var_id ? var_id : id);

		if (buffers_requiring_array_length.count(var_id))
			arg_str += ", " + to_buffer_size_expression(var_id ? var_id : id);

		// Close the spvDynamicImageSampler<> constructor opened at the top.
		if (is_dynamic_img_sampler)
			arg_str += ")";
	}

	// Emulate texture2D atomic operations
	auto *backing_var = maybe_get_backing_variable(var_id);
	if (backing_var && atomic_image_vars.count(backing_var->self))
	{
		arg_str += ", " + to_expression(var_id) + "_atomic";
	}

	return arg_str;
}
9692 
9693 // If the ID represents a sampled image that has been assigned a sampler already,
9694 // generate an expression for the sampler, otherwise generate a fake sampler name
9695 // by appending a suffix to the expression constructed from the ID.
to_sampler_expression(uint32_t id)9696 string CompilerMSL::to_sampler_expression(uint32_t id)
9697 {
9698 	auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
9699 	auto expr = to_expression(combined ? combined->image : VariableID(id));
9700 	auto index = expr.find_first_of('[');
9701 
9702 	uint32_t samp_id = 0;
9703 	if (combined)
9704 		samp_id = combined->sampler;
9705 
9706 	if (index == string::npos)
9707 		return samp_id ? to_expression(samp_id) : expr + sampler_name_suffix;
9708 	else
9709 	{
9710 		auto image_expr = expr.substr(0, index);
9711 		auto array_expr = expr.substr(index);
9712 		return samp_id ? to_expression(samp_id) : (image_expr + sampler_name_suffix + array_expr);
9713 	}
9714 }
9715 
to_swizzle_expression(uint32_t id)9716 string CompilerMSL::to_swizzle_expression(uint32_t id)
9717 {
9718 	auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
9719 
9720 	auto expr = to_expression(combined ? combined->image : VariableID(id));
9721 	auto index = expr.find_first_of('[');
9722 
9723 	// If an image is part of an argument buffer translate this to a legal identifier.
9724 	string::size_type period = 0;
9725 	while ((period = expr.find_first_of('.', period)) != string::npos && period < index)
9726 		expr[period] = '_';
9727 
9728 	if (index == string::npos)
9729 		return expr + swizzle_name_suffix;
9730 	else
9731 	{
9732 		auto image_expr = expr.substr(0, index);
9733 		auto array_expr = expr.substr(index);
9734 		return image_expr + swizzle_name_suffix + array_expr;
9735 	}
9736 }
9737 
to_buffer_size_expression(uint32_t id)9738 string CompilerMSL::to_buffer_size_expression(uint32_t id)
9739 {
9740 	auto expr = to_expression(id);
9741 	auto index = expr.find_first_of('[');
9742 
9743 	// This is quite crude, but we need to translate the reference name (*spvDescriptorSetN.name) to
9744 	// the pointer expression spvDescriptorSetN.name to make a reasonable expression here.
9745 	// This only happens if we have argument buffers and we are using OpArrayLength on a lone SSBO in that set.
9746 	if (expr.size() >= 3 && expr[0] == '(' && expr[1] == '*')
9747 		expr = address_of_expression(expr);
9748 
9749 	// If a buffer is part of an argument buffer translate this to a legal identifier.
9750 	for (auto &c : expr)
9751 		if (c == '.')
9752 			c = '_';
9753 
9754 	if (index == string::npos)
9755 		return expr + buffer_size_name_suffix;
9756 	else
9757 	{
9758 		auto buffer_expr = expr.substr(0, index);
9759 		auto array_expr = expr.substr(index);
9760 		return buffer_expr + buffer_size_name_suffix + array_expr;
9761 	}
9762 }
9763 
9764 // Checks whether the type is a Block all of whose members have DecorationPatch.
is_patch_block(const SPIRType & type)9765 bool CompilerMSL::is_patch_block(const SPIRType &type)
9766 {
9767 	if (!has_decoration(type.self, DecorationBlock))
9768 		return false;
9769 
9770 	for (uint32_t i = 0; i < type.member_types.size(); i++)
9771 	{
9772 		if (!has_member_decoration(type.self, i, DecorationPatch))
9773 			return false;
9774 	}
9775 
9776 	return true;
9777 }
9778 
9779 // Checks whether the ID is a row_major matrix that requires conversion before use
is_non_native_row_major_matrix(uint32_t id)9780 bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id)
9781 {
9782 	auto *e = maybe_get<SPIRExpression>(id);
9783 	if (e)
9784 		return e->need_transpose;
9785 	else
9786 		return has_decoration(id, DecorationRowMajor);
9787 }
9788 
9789 // Checks whether the member is a row_major matrix that requires conversion before use
member_is_non_native_row_major_matrix(const SPIRType & type,uint32_t index)9790 bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index)
9791 {
9792 	return has_member_decoration(type.self, index, DecorationRowMajor);
9793 }
9794 
convert_row_major_matrix(string exp_str,const SPIRType & exp_type,uint32_t physical_type_id,bool is_packed)9795 string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type, uint32_t physical_type_id,
9796                                              bool is_packed)
9797 {
9798 	if (!is_matrix(exp_type))
9799 	{
9800 		return CompilerGLSL::convert_row_major_matrix(move(exp_str), exp_type, physical_type_id, is_packed);
9801 	}
9802 	else
9803 	{
9804 		strip_enclosed_expression(exp_str);
9805 		if (physical_type_id != 0 || is_packed)
9806 			exp_str = unpack_expression_type(exp_str, exp_type, physical_type_id, is_packed, true);
9807 		return join("transpose(", exp_str, ")");
9808 	}
9809 }
9810 
9811 // Called automatically at the end of the entry point function
emit_fixup()9812 void CompilerMSL::emit_fixup()
9813 {
9814 	if (is_vertex_like_shader() && stage_out_var_id && !qual_pos_var_name.empty() && !capture_output_to_buffer)
9815 	{
9816 		if (options.vertex.fixup_clipspace)
9817 			statement(qual_pos_var_name, ".z = (", qual_pos_var_name, ".z + ", qual_pos_var_name,
9818 			          ".w) * 0.5;       // Adjust clip-space for Metal");
9819 
9820 		if (options.vertex.flip_vert_y)
9821 			statement(qual_pos_var_name, ".y = -(", qual_pos_var_name, ".y);", "    // Invert Y-axis for Metal");
9822 	}
9823 }
9824 
// Return a string defining a structure member, with padding and packing.
// The declared type may be swapped for a physical (storage) type, a packed
// typedef, or a transposed matrix type depending on member decorations.
string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
                                     const string &qualifier)
{
	// If the member was remapped to a different physical type, declare that type instead.
	if (member_is_remapped_physical_type(type, index))
		member_type_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID);
	auto &physical_type = get<SPIRType>(member_type_id);

	// If this member is packed, mark it as so.
	string pack_pfx;

	// Allow Metal to use the array<T> template to make arrays a value type
	uint32_t orig_id = 0;
	if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID))
		orig_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID);

	bool row_major = false;
	if (is_matrix(physical_type))
		row_major = has_member_decoration(type.self, index, DecorationRowMajor);

	SPIRType row_major_physical_type;
	const SPIRType *declared_type = &physical_type;

	// If a struct is being declared with physical layout,
	// do not use array<T> wrappers.
	// This avoids a lot of complicated cases with packed vectors and matrices,
	// and generally we cannot copy full arrays in and out of buffers into Function
	// address space.
	// Array of resources should also be declared as builtin arrays.
	if (has_member_decoration(type.self, index, DecorationOffset))
		is_using_builtin_array = true;
	else if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
		is_using_builtin_array = true;

	if (member_is_packed_physical_type(type, index))
	{
		// If we're packing a matrix, output an appropriate typedef
		if (physical_type.basetype == SPIRType::Struct)
		{
			SPIRV_CROSS_THROW("Cannot emit a packed struct currently.");
		}
		else if (is_matrix(physical_type))
		{
			uint32_t rows = physical_type.vecsize;
			uint32_t cols = physical_type.columns;
			pack_pfx = "packed_";
			if (row_major)
			{
				// These are stored transposed.
				rows = physical_type.columns;
				cols = physical_type.vecsize;
				pack_pfx = "packed_rm_";
			}
			// Emit e.g. "typedef packed_float3 packed_float3x3[3];" so the member
			// can be declared as an array of packed row/column vectors.
			string base_type = physical_type.width == 16 ? "half" : "float";
			string td_line = "typedef ";
			td_line += "packed_" + base_type + to_string(rows);
			td_line += " " + pack_pfx;
			// Use the actual matrix size here.
			td_line += base_type + to_string(physical_type.columns) + "x" + to_string(physical_type.vecsize);
			td_line += "[" + to_string(cols) + "]";
			td_line += ";";
			add_typedef_line(td_line);
		}
		else if (!is_scalar(physical_type)) // scalar type is already packed.
			pack_pfx = "packed_";
	}
	else if (row_major)
	{
		// Need to declare type with flipped vecsize/columns.
		row_major_physical_type = physical_type;
		swap(row_major_physical_type.vecsize, row_major_physical_type.columns);
		declared_type = &row_major_physical_type;
	}

	// Very specifically, image load-store in argument buffers are disallowed on MSL on iOS.
	if (msl_options.is_ios() && physical_type.basetype == SPIRType::Image && physical_type.image.sampled == 2)
	{
		if (!has_decoration(orig_id, DecorationNonWritable))
			SPIRV_CROSS_THROW("Writable images are not allowed in argument buffers on iOS.");
	}

	// Array information is baked into these types.
	string array_type;
	if (physical_type.basetype != SPIRType::Image && physical_type.basetype != SPIRType::Sampler &&
	    physical_type.basetype != SPIRType::SampledImage)
	{
		BuiltIn builtin = BuiltInMax;

		// Special handling. In [[stage_out]] or [[stage_in]] blocks,
		// we need flat arrays, but if we're somehow declaring gl_PerVertex for constant array reasons, we want
		// template array types to be declared.
		bool is_ib_in_out =
				((stage_out_var_id && get_stage_out_struct_type().self == type.self) ||
				 (stage_in_var_id && get_stage_in_struct_type().self == type.self));
		if (is_ib_in_out && is_member_builtin(type, index, &builtin))
			is_using_builtin_array = true;
		array_type = type_to_array_glsl(physical_type);
	}

	auto result = join(pack_pfx, type_to_glsl(*declared_type, orig_id), " ", qualifier, to_member_name(type, index),
	                   member_attribute_qualifier(type, index), array_type, ";");

	// The flag only applies to the declaration just produced; reset it for the next member.
	is_using_builtin_array = false;
	return result;
}
9930 
9931 // Emit a structure member, padding and packing to maintain the correct memeber alignments.
emit_struct_member(const SPIRType & type,uint32_t member_type_id,uint32_t index,const string & qualifier,uint32_t)9932 void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
9933                                      const string &qualifier, uint32_t)
9934 {
9935 	// If this member requires padding to maintain its declared offset, emit a dummy padding member before it.
9936 	if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget))
9937 	{
9938 		uint32_t pad_len = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget);
9939 		statement("char _m", index, "_pad", "[", pad_len, "];");
9940 	}
9941 
9942 	// Handle HLSL-style 0-based vertex/instance index.
9943 	builtin_declaration = true;
9944 	statement(to_struct_member(type, member_type_id, index, qualifier));
9945 	builtin_declaration = false;
9946 }
9947 
emit_struct_padding_target(const SPIRType & type)9948 void CompilerMSL::emit_struct_padding_target(const SPIRType &type)
9949 {
9950 	uint32_t struct_size = get_declared_struct_size_msl(type, true, true);
9951 	uint32_t target_size = get_extended_decoration(type.self, SPIRVCrossDecorationPaddingTarget);
9952 	if (target_size < struct_size)
9953 		SPIRV_CROSS_THROW("Cannot pad with negative bytes.");
9954 	else if (target_size > struct_size)
9955 		statement("char _m0_final_padding[", target_size - struct_size, "];");
9956 }
9957 
// Return a MSL qualifier for the specified function attribute member.
// Depending on the execution model and storage class this is a builtin
// qualifier (e.g. " [[position]]"), a stage I/O binding (" [[attribute(N)]]",
// " [[user(locnN)]]", " [[color(N)]]"), an argument buffer index (" [[id(N)]]"),
// possibly combined with interpolation qualifiers — or an empty string.
string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t index)
{
	auto &execution = get_entry_point();

	uint32_t mbr_type_id = type.member_types[index];
	auto &mbr_type = get<SPIRType>(mbr_type_id);

	BuiltIn builtin = BuiltInMax;
	bool is_builtin = is_member_builtin(type, index, &builtin);

	// Argument buffer members bind with [[id(N)]] rather than stage I/O qualifiers.
	if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
	{
		string quals = join(
		    " [[id(", get_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary), ")");
		if (interlocked_resources.count(
		        get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID)))
			quals += ", raster_order_group(0)";
		quals += "]]";
		return quals;
	}

	// Vertex function inputs
	if (execution.model == ExecutionModelVertex && type.storage == StorageClassInput)
	{
		if (is_builtin)
		{
			switch (builtin)
			{
			case BuiltInVertexId:
			case BuiltInVertexIndex:
			case BuiltInBaseVertex:
			case BuiltInInstanceId:
			case BuiltInInstanceIndex:
			case BuiltInBaseInstance:
				// When the vertex shader runs as a compute kernel for tessellation,
				// these indices are provided some other way, not as stage-in builtins.
				if (msl_options.vertex_for_tessellation)
					return "";
				return string(" [[") + builtin_qualifier(builtin) + "]]";

			case BuiltInDrawIndex:
				SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");

			default:
				return "";
			}
		}
		uint32_t locn = get_ordered_member_location(type.self, index);
		if (locn != k_unknown_location)
			return string(" [[attribute(") + convert_to_string(locn) + ")]]";
	}

	// Vertex and tessellation evaluation function outputs
	if (((execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation) ||
	     execution.model == ExecutionModelTessellationEvaluation) &&
	    type.storage == StorageClassOutput)
	{
		if (is_builtin)
		{
			switch (builtin)
			{
			case BuiltInPointSize:
				// Only mark the PointSize builtin if really rendering points.
				// Some shaders may include a PointSize builtin even when used to render
				// non-point topologies, and Metal will reject this builtin when compiling
				// the shader into a render pipeline that uses a non-point topology.
				return msl_options.enable_point_size_builtin ? (string(" [[") + builtin_qualifier(builtin) + "]]") : "";

			case BuiltInViewportIndex:
				if (!msl_options.supports_msl_version(2, 0))
					SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
				/* fallthrough */
			case BuiltInPosition:
			case BuiltInLayer:
				return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");

			case BuiltInClipDistance:
				// Clip distances with an explicit Location are passed as user varyings.
				if (has_member_decoration(type.self, index, DecorationLocation))
					return join(" [[user(clip", get_member_decoration(type.self, index, DecorationLocation), ")]]");
				else
					return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");

			default:
				return "";
			}
		}
		uint32_t comp;
		uint32_t locn = get_ordered_member_location(type.self, index, &comp);
		if (locn != k_unknown_location)
		{
			// Encode both location and component in the user() name so the
			// fragment side can match it (see the fragment input path below).
			if (comp != k_unknown_component)
				return string(" [[user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")]]";
			else
				return string(" [[user(locn") + convert_to_string(locn) + ")]]";
		}
	}

	// Tessellation control function inputs
	if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassInput)
	{
		if (is_builtin)
		{
			switch (builtin)
			{
			case BuiltInInvocationId:
			case BuiltInPrimitiveId:
				// In multi-patch mode these are computed from the global invocation ID instead.
				if (msl_options.multi_patch_workgroup)
					return "";
				return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
			case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
			case BuiltInSubgroupSize: // FIXME: Should work in any stage
				if (msl_options.emulate_subgroups)
					return "";
				return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
			case BuiltInPatchVertices:
				return "";
			// Others come from stage input.
			default:
				break;
			}
		}
		// Multi-patch input comes from a buffer, so no stage-in attributes apply.
		if (msl_options.multi_patch_workgroup)
			return "";
		uint32_t locn = get_ordered_member_location(type.self, index);
		if (locn != k_unknown_location)
			return string(" [[attribute(") + convert_to_string(locn) + ")]]";
	}

	// Tessellation control function outputs
	if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassOutput)
	{
		// For this type of shader, we always arrange for it to capture its
		// output to a buffer. For this reason, qualifiers are irrelevant here.
		return "";
	}

	// Tessellation evaluation function inputs
	if (execution.model == ExecutionModelTessellationEvaluation && type.storage == StorageClassInput)
	{
		if (is_builtin)
		{
			switch (builtin)
			{
			case BuiltInPrimitiveId:
			case BuiltInTessCoord:
				return string(" [[") + builtin_qualifier(builtin) + "]]";
			case BuiltInPatchVertices:
				return "";
			// Others come from stage input.
			default:
				break;
			}
		}
		// The special control point array must not be marked with an attribute.
		if (get_type(type.member_types[index]).basetype == SPIRType::ControlPointArray)
			return "";
		uint32_t locn = get_ordered_member_location(type.self, index);
		if (locn != k_unknown_location)
			return string(" [[attribute(") + convert_to_string(locn) + ")]]";
	}

	// Tessellation evaluation function outputs were handled above.

	// Fragment function inputs
	if (execution.model == ExecutionModelFragment && type.storage == StorageClassInput)
	{
		string quals;
		if (is_builtin)
		{
			switch (builtin)
			{
			case BuiltInViewIndex:
				// ViewIndex is only a real builtin with layered multiview rendering.
				if (!msl_options.multiview || !msl_options.multiview_layered_rendering)
					break;
				/* fallthrough */
			case BuiltInFrontFacing:
			case BuiltInPointCoord:
			case BuiltInFragCoord:
			case BuiltInSampleId:
			case BuiltInSampleMask:
			case BuiltInLayer:
			case BuiltInBaryCoordNV:
			case BuiltInBaryCoordNoPerspNV:
				quals = builtin_qualifier(builtin);
				break;

			case BuiltInClipDistance:
				return join(" [[user(clip", get_member_decoration(type.self, index, DecorationLocation), ")]]");

			default:
				break;
			}
		}
		else
		{
			uint32_t comp;
			uint32_t locn = get_ordered_member_location(type.self, index, &comp);
			if (locn != k_unknown_location)
			{
				// For user-defined attributes, this is fine. From Vulkan spec:
				// A user-defined output variable is considered to match an input variable in the subsequent stage if
				// the two variables are declared with the same Location and Component decoration and match in type
				// and decoration, except that interpolation decorations are not required to match. For the purposes
				// of interface matching, variables declared without a Component decoration are considered to have a
				// Component decoration of zero.

				if (comp != k_unknown_component && comp != 0)
					quals = string("user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")";
				else
					quals = string("user(locn") + convert_to_string(locn) + ")";
			}
		}

		if (builtin == BuiltInBaryCoordNV || builtin == BuiltInBaryCoordNoPerspNV)
		{
			if (has_member_decoration(type.self, index, DecorationFlat) ||
			    has_member_decoration(type.self, index, DecorationCentroid) ||
			    has_member_decoration(type.self, index, DecorationSample) ||
			    has_member_decoration(type.self, index, DecorationNoPerspective))
			{
				// NoPerspective is baked into the builtin type.
				SPIRV_CROSS_THROW(
				    "Flat, Centroid, Sample, NoPerspective decorations are not supported for BaryCoord inputs.");
			}
		}

		// Don't bother decorating integers with the 'flat' attribute; it's
		// the default (in fact, the only option). Also don't bother with the
		// FragCoord builtin; it's always noperspective on Metal.
		if (!type_is_integral(mbr_type) && (!is_builtin || builtin != BuiltInFragCoord))
		{
			if (has_member_decoration(type.self, index, DecorationFlat))
			{
				if (!quals.empty())
					quals += ", ";
				quals += "flat";
			}
			else if (has_member_decoration(type.self, index, DecorationCentroid))
			{
				if (!quals.empty())
					quals += ", ";
				if (has_member_decoration(type.self, index, DecorationNoPerspective))
					quals += "centroid_no_perspective";
				else
					quals += "centroid_perspective";
			}
			else if (has_member_decoration(type.self, index, DecorationSample))
			{
				if (!quals.empty())
					quals += ", ";
				if (has_member_decoration(type.self, index, DecorationNoPerspective))
					quals += "sample_no_perspective";
				else
					quals += "sample_perspective";
			}
			else if (has_member_decoration(type.self, index, DecorationNoPerspective))
			{
				if (!quals.empty())
					quals += ", ";
				quals += "center_no_perspective";
			}
		}

		if (!quals.empty())
			return " [[" + quals + "]]";
	}

	// Fragment function outputs
	if (execution.model == ExecutionModelFragment && type.storage == StorageClassOutput)
	{
		if (is_builtin)
		{
			switch (builtin)
			{
			case BuiltInFragStencilRefEXT:
				// Similar to PointSize, only mark FragStencilRef if there's a stencil buffer.
				// Some shaders may include a FragStencilRef builtin even when used to render
				// without a stencil attachment, and Metal will reject this builtin
				// when compiling the shader into a render pipeline that does not set
				// stencilAttachmentPixelFormat.
				if (!msl_options.enable_frag_stencil_ref_builtin)
					return "";
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Stencil export only supported in MSL 2.1 and up.");
				return string(" [[") + builtin_qualifier(builtin) + "]]";

			case BuiltInFragDepth:
				// Ditto FragDepth.
				if (!msl_options.enable_frag_depth_builtin)
					return "";
				/* fallthrough */
			case BuiltInSampleMask:
				return string(" [[") + builtin_qualifier(builtin) + "]]";

			default:
				return "";
			}
		}
		uint32_t locn = get_ordered_member_location(type.self, index);
		// Metal will likely complain about missing color attachments, too.
		if (locn != k_unknown_location && !(msl_options.enable_frag_output_mask & (1 << locn)))
			return "";
		if (locn != k_unknown_location && has_member_decoration(type.self, index, DecorationIndex))
			return join(" [[color(", locn, "), index(", get_member_decoration(type.self, index, DecorationIndex),
			            ")]]");
		else if (locn != k_unknown_location)
			return join(" [[color(", locn, ")]]");
		else if (has_member_decoration(type.self, index, DecorationIndex))
			return join(" [[index(", get_member_decoration(type.self, index, DecorationIndex), ")]]");
		else
			return "";
	}

	// Compute function inputs
	if (execution.model == ExecutionModelGLCompute && type.storage == StorageClassInput)
	{
		if (is_builtin)
		{
			switch (builtin)
			{
			case BuiltInNumSubgroups:
			case BuiltInSubgroupId:
			case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
			case BuiltInSubgroupSize: // FIXME: Should work in any stage
				// Emulated subgroup builtins are derived, not declared.
				if (msl_options.emulate_subgroups)
					break;
				/* fallthrough */
			case BuiltInGlobalInvocationId:
			case BuiltInWorkgroupId:
			case BuiltInNumWorkgroups:
			case BuiltInLocalInvocationId:
			case BuiltInLocalInvocationIndex:
				return string(" [[") + builtin_qualifier(builtin) + "]]";

			default:
				return "";
			}
		}
	}

	return "";
}
10299 
10300 // Returns the location decoration of the member with the specified index in the specified type.
10301 // If the location of the member has been explicitly set, that location is used. If not, this
10302 // function assumes the members are ordered in their location order, and simply returns the
10303 // index as the location.
get_ordered_member_location(uint32_t type_id,uint32_t index,uint32_t * comp)10304 uint32_t CompilerMSL::get_ordered_member_location(uint32_t type_id, uint32_t index, uint32_t *comp)
10305 {
10306 	auto &m = ir.meta[type_id];
10307 	if (index < m.members.size())
10308 	{
10309 		auto &dec = m.members[index];
10310 		if (comp)
10311 		{
10312 			if (dec.decoration_flags.get(DecorationComponent))
10313 				*comp = dec.component;
10314 			else
10315 				*comp = k_unknown_component;
10316 		}
10317 		if (dec.decoration_flags.get(DecorationLocation))
10318 			return dec.location;
10319 	}
10320 
10321 	return index;
10322 }
10323 
// Returns the type declaration for a function, including the
// entry type if the current function is the entry point function.
// For entry points, the result is "<entry-kind> <return-type>", where the
// entry kind is one of vertex / fragment / kernel (optionally with
// patch / early_fragment_tests attributes), chosen by the execution model.
string CompilerMSL::func_type_decl(SPIRType &type)
{
	// The regular function return type. If not processing the entry point function, that's all we need
	string return_type = type_to_glsl(type) + type_to_array_glsl(type);
	if (!processing_entry_point)
		return return_type;

	// If an outgoing interface block has been defined, and it should be returned, override the entry point return type
	bool ep_should_return_output = !get_is_rasterization_disabled();
	if (stage_out_var_id && ep_should_return_output)
		return_type = type_to_glsl(get_stage_out_struct_type()) + type_to_array_glsl(type);

	// Prepend a entry type, based on the execution model
	string entry_type;
	auto &execution = get_entry_point();
	switch (execution.model)
	{
	case ExecutionModelVertex:
		// Vertex-for-tessellation runs the vertex shader as a compute kernel.
		if (msl_options.vertex_for_tessellation && !msl_options.supports_msl_version(1, 2))
			SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
		entry_type = msl_options.vertex_for_tessellation ? "kernel" : "vertex";
		break;
	case ExecutionModelTessellationEvaluation:
		if (!msl_options.supports_msl_version(1, 2))
			SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
		if (execution.flags.get(ExecutionModeIsolines))
			SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
		// iOS does not take the output vertex count in the patch attribute.
		if (msl_options.is_ios())
			entry_type =
			    join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ") ]] vertex");
		else
			entry_type = join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ", ",
			                  execution.output_vertices, ") ]] vertex");
		break;
	case ExecutionModelFragment:
		entry_type = execution.flags.get(ExecutionModeEarlyFragmentTests) ||
		                     execution.flags.get(ExecutionModePostDepthCoverage) ?
		                 "[[ early_fragment_tests ]] fragment" :
		                 "fragment";
		break;
	case ExecutionModelTessellationControl:
		// Tessellation control shaders are always emitted as compute kernels.
		if (!msl_options.supports_msl_version(1, 2))
			SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
		if (execution.flags.get(ExecutionModeIsolines))
			SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
		/* fallthrough */
	case ExecutionModelGLCompute:
	case ExecutionModelKernel:
		entry_type = "kernel";
		break;
	default:
		entry_type = "unknown";
		break;
	}

	return entry_type + " " + return_type;
}
10383 
10384 // In MSL, address space qualifiers are required for all pointer or reference variables
get_argument_address_space(const SPIRVariable & argument)10385 string CompilerMSL::get_argument_address_space(const SPIRVariable &argument)
10386 {
10387 	const auto &type = get<SPIRType>(argument.basetype);
10388 	return get_type_address_space(type, argument.self, true);
10389 }
10390 
// Returns the MSL address space qualifier ("device", "constant", "threadgroup",
// "thread", or "") for the given type/id, optionally prefixed with "volatile".
// 'argument' is true when this is resolved for a function argument.
string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument)
{
	// This can be called for variable pointer contexts as well, so be very careful about which method we choose.
	Bitset flags;
	auto *var = maybe_get<SPIRVariable>(id);
	if (var && type.basetype == SPIRType::Struct &&
	    (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
		flags = get_buffer_block_flags(id);
	else
		flags = get_decoration_bitset(id);

	const char *addr_space = nullptr;
	switch (type.storage)
	{
	case StorageClassWorkgroup:
		addr_space = "threadgroup";
		break;

	case StorageClassStorageBuffer:
	{
		// For arguments from variable pointers, we use the write count deduction, so
		// we should not assume any constness here. Only for global SSBOs.
		bool readonly = false;
		if (!var || has_decoration(type.self, DecorationBlock))
			readonly = flags.get(DecorationNonWritable);

		addr_space = readonly ? "const device" : "device";
		break;
	}

	case StorageClassUniform:
	case StorageClassUniformConstant:
	case StorageClassPushConstant:
		if (type.basetype == SPIRType::Struct)
		{
			// BufferBlock-decorated structs are SSBOs; plain Blocks are UBOs.
			bool ssbo = has_decoration(type.self, DecorationBufferBlock);
			if (ssbo)
				addr_space = flags.get(DecorationNonWritable) ? "const device" : "device";
			else
				addr_space = "constant";
		}
		else if (!argument)
		{
			addr_space = "constant";
		}
		else if (type_is_msl_framebuffer_fetch(type))
		{
			// Subpass inputs are passed around by value.
			addr_space = "";
		}
		break;

	case StorageClassFunction:
	case StorageClassGeneric:
		break;

	case StorageClassInput:
		// Tess-control stage-in lives in threadgroup (or constant for multi-patch);
		// fragment stage-in is a plain thread-local struct.
		if (get_execution_model() == ExecutionModelTessellationControl && var &&
		    var->basevariable == stage_in_ptr_var_id)
			addr_space = msl_options.multi_patch_workgroup ? "constant" : "threadgroup";
		if (get_execution_model() == ExecutionModelFragment && var && var->basevariable == stage_in_var_id)
			addr_space = "thread";
		break;

	case StorageClassOutput:
		if (capture_output_to_buffer)
			addr_space = "device";
		break;

	default:
		break;
	}

	if (!addr_space)
		// No address space for plain values.
		addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : "";

	return join(flags.get(DecorationVolatile) || flags.get(DecorationCoherent) ? "volatile " : "", addr_space);
}
10470 
to_restrict(uint32_t id,bool space)10471 const char *CompilerMSL::to_restrict(uint32_t id, bool space)
10472 {
10473 	// This can be called for variable pointer contexts as well, so be very careful about which method we choose.
10474 	Bitset flags;
10475 	if (ir.ids[id].get_type() == TypeVariable)
10476 	{
10477 		uint32_t type_id = expression_type_id(id);
10478 		auto &type = expression_type(id);
10479 		if (type.basetype == SPIRType::Struct &&
10480 		    (has_decoration(type_id, DecorationBlock) || has_decoration(type_id, DecorationBufferBlock)))
10481 			flags = get_buffer_block_flags(id);
10482 		else
10483 			flags = get_decoration_bitset(id);
10484 	}
10485 	else
10486 		flags = get_decoration_bitset(id);
10487 
10488 	return flags.get(DecorationRestrict) ? (space ? "restrict " : "restrict") : "";
10489 }
10490 
entry_point_arg_stage_in()10491 string CompilerMSL::entry_point_arg_stage_in()
10492 {
10493 	string decl;
10494 
10495 	if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup)
10496 		return decl;
10497 
10498 	// Stage-in structure
10499 	uint32_t stage_in_id;
10500 	if (get_execution_model() == ExecutionModelTessellationEvaluation)
10501 		stage_in_id = patch_stage_in_var_id;
10502 	else
10503 		stage_in_id = stage_in_var_id;
10504 
10505 	if (stage_in_id)
10506 	{
10507 		auto &var = get<SPIRVariable>(stage_in_id);
10508 		auto &type = get_variable_data_type(var);
10509 
10510 		add_resource_name(var.self);
10511 		decl = join(type_to_glsl(type), " ", to_name(var.self), " [[stage_in]]");
10512 	}
10513 
10514 	return decl;
10515 }
10516 
10517 // Returns true if this input builtin should be a direct parameter on a shader function parameter list,
10518 // and false for builtins that should be passed or calculated some other way.
is_direct_input_builtin(BuiltIn bi_type)10519 bool CompilerMSL::is_direct_input_builtin(BuiltIn bi_type)
10520 {
10521 	switch (bi_type)
10522 	{
10523 	// Vertex function in
10524 	case BuiltInVertexId:
10525 	case BuiltInVertexIndex:
10526 	case BuiltInBaseVertex:
10527 	case BuiltInInstanceId:
10528 	case BuiltInInstanceIndex:
10529 	case BuiltInBaseInstance:
10530 		return get_execution_model() != ExecutionModelVertex || !msl_options.vertex_for_tessellation;
10531 	// Tess. control function in
10532 	case BuiltInPosition:
10533 	case BuiltInPointSize:
10534 	case BuiltInClipDistance:
10535 	case BuiltInCullDistance:
10536 	case BuiltInPatchVertices:
10537 		return false;
10538 	case BuiltInInvocationId:
10539 	case BuiltInPrimitiveId:
10540 		return get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup;
10541 	// Tess. evaluation function in
10542 	case BuiltInTessLevelInner:
10543 	case BuiltInTessLevelOuter:
10544 		return false;
10545 	// Fragment function in
10546 	case BuiltInSamplePosition:
10547 	case BuiltInHelperInvocation:
10548 	case BuiltInBaryCoordNV:
10549 	case BuiltInBaryCoordNoPerspNV:
10550 		return false;
10551 	case BuiltInViewIndex:
10552 		return get_execution_model() == ExecutionModelFragment && msl_options.multiview &&
10553 		       msl_options.multiview_layered_rendering;
10554 	// Compute function in
10555 	case BuiltInSubgroupId:
10556 	case BuiltInNumSubgroups:
10557 		return !msl_options.emulate_subgroups;
10558 	// Any stage function in
10559 	case BuiltInDeviceIndex:
10560 	case BuiltInSubgroupEqMask:
10561 	case BuiltInSubgroupGeMask:
10562 	case BuiltInSubgroupGtMask:
10563 	case BuiltInSubgroupLeMask:
10564 	case BuiltInSubgroupLtMask:
10565 		return false;
10566 	case BuiltInSubgroupSize:
10567 		if (msl_options.fixed_subgroup_size != 0)
10568 			return false;
10569 		/* fallthrough */
10570 	case BuiltInSubgroupLocalInvocationId:
10571 		return !msl_options.emulate_subgroups;
10572 	default:
10573 		return true;
10574 	}
10575 }
10576 
10577 // Returns true if this is a fragment shader that runs per sample, and false otherwise.
is_sample_rate() const10578 bool CompilerMSL::is_sample_rate() const
10579 {
10580 	auto &caps = get_declared_capabilities();
10581 	return get_execution_model() == ExecutionModelFragment &&
10582 	       (msl_options.force_sample_rate_shading ||
10583 	        std::find(caps.begin(), caps.end(), CapabilitySampleRateShading) != caps.end() ||
10584 	        (msl_options.use_framebuffer_fetch_subpasses && need_subpass_input));
10585 }
10586 
// Appends the entry point argument declarations for all active input builtin
// variables to ep_args (comma-delimited), including the implicit grid_origin /
// grid_size builtins and the extra buffer/threadgroup parameters needed when
// output is captured to a buffer (tessellation pipelines, vertex-for-tessellation).
void CompilerMSL::entry_point_args_builtin(string &ep_args)
{
	// Builtin variables
	SmallVector<pair<SPIRVariable *, BuiltIn>, 8> active_builtins;
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
		// Only input variables can be entry point builtin arguments.
		if (var.storage != StorageClassInput)
			return;

		auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));

		// Don't emit SamplePosition as a separate parameter. In the entry
		// point, we get that by calling get_sample_position() on the sample ID.
		if (is_builtin_variable(var) &&
		    get_variable_data_type(var).basetype != SPIRType::Struct &&
		    get_variable_data_type(var).basetype != SPIRType::ControlPointArray)
		{
			// If the builtin is not part of the active input builtin set, don't emit it.
			// Relevant for multiple entry-point modules which might declare unused builtins.
			if (!active_input_builtins.get(bi_type) || !interface_variable_exists_in_entry_point(var_id))
				return;

			// Remember this variable. We may need to correct its type.
			active_builtins.push_back(make_pair(&var, bi_type));

			// Builtins that are not "direct" are passed or computed some other way
			// (fixup hooks, buffers), so only direct ones become parameters here.
			if (is_direct_input_builtin(bi_type))
			{
				if (!ep_args.empty())
					ep_args += ", ";

				// Handle HLSL-style 0-based vertex/instance index.
				// builtin_declaration steers builtin_type_decl/builtin_qualifier output.
				builtin_declaration = true;
				ep_args += builtin_type_decl(bi_type, var_id) + " " + to_expression(var_id);
				ep_args += " [[" + builtin_qualifier(bi_type);
				if (bi_type == BuiltInSampleMask && get_entry_point().flags.get(ExecutionModePostDepthCoverage))
				{
					if (!msl_options.supports_msl_version(2))
						SPIRV_CROSS_THROW("Post-depth coverage requires MSL 2.0.");
					if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3))
						SPIRV_CROSS_THROW("Post-depth coverage on Mac requires MSL 2.3.");
					ep_args += ", post_depth_coverage";
				}
				ep_args += "]]";
				builtin_declaration = false;
			}
		}

		if (has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase))
		{
			// This is a special implicit builtin, not corresponding to any SPIR-V builtin,
			// which holds the base that was passed to vkCmdDispatchBase() or vkCmdDrawIndexed(). If it's present,
			// assume we emitted it for a good reason.
			assert(msl_options.supports_msl_version(1, 2));
			if (!ep_args.empty())
				ep_args += ", ";

			ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_origin]]";
		}

		if (has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize))
		{
			// This is another special implicit builtin, not corresponding to any SPIR-V builtin,
			// which holds the number of vertices and instances to draw. If it's present,
			// assume we emitted it for a good reason.
			assert(msl_options.supports_msl_version(1, 2));
			if (!ep_args.empty())
				ep_args += ", ";

			ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_size]]";
		}
	});

	// Correct the types of all encountered active builtins. We couldn't do this before
	// because ensure_correct_builtin_type() may increase the bound, which isn't allowed
	// while iterating over IDs.
	for (auto &var : active_builtins)
		var.first->basetype = ensure_correct_builtin_type(var.first->basetype, var.second);

	// Handle HLSL-style 0-based vertex/instance index.
	// These extra arguments are only appended if code generation discovered a need for them.
	if (needs_base_vertex_arg == TriState::Yes)
		ep_args += built_in_func_arg(BuiltInBaseVertex, !ep_args.empty());

	if (needs_base_instance_arg == TriState::Yes)
		ep_args += built_in_func_arg(BuiltInBaseInstance, !ep_args.empty());

	if (capture_output_to_buffer)
	{
		// Add parameters to hold the indirect draw parameters and the shader output. This has to be handled
		// specially because it needs to be a pointer, not a reference.
		if (stage_out_var_id)
		{
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args += join("device ", type_to_glsl(get_stage_out_struct_type()), "* ", output_buffer_var_name,
			                " [[buffer(", msl_options.shader_output_buffer_index, ")]]");
		}

		// Tess. control always needs the indirect params ("constant"); other capturing
		// stages need them writable ("device"), except vertex-for-tessellation.
		if (get_execution_model() == ExecutionModelTessellationControl)
		{
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args +=
			    join("constant uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
		}
		else if (stage_out_var_id &&
		         !(get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
		{
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args +=
			    join("device uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
		}

		if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation &&
		    (active_input_builtins.get(BuiltInVertexIndex) || active_input_builtins.get(BuiltInVertexId)) &&
		    msl_options.vertex_index_type != Options::IndexType::None)
		{
			// Add the index buffer so we can set gl_VertexIndex correctly.
			if (!ep_args.empty())
				ep_args += ", ";
			switch (msl_options.vertex_index_type)
			{
			case Options::IndexType::None:
				break;
			case Options::IndexType::UInt16:
				ep_args += join("const device ushort* ", index_buffer_var_name, " [[buffer(",
				                msl_options.shader_index_buffer_index, ")]]");
				break;
			case Options::IndexType::UInt32:
				ep_args += join("const device uint* ", index_buffer_var_name, " [[buffer(",
				                msl_options.shader_index_buffer_index, ")]]");
				break;
			}
		}

		// Tessellation control shaders get three additional parameters:
		// a buffer to hold the per-patch data, a buffer to hold the per-patch
		// tessellation levels, and a block of workgroup memory to hold the
		// input control point data.
		if (get_execution_model() == ExecutionModelTessellationControl)
		{
			if (patch_stage_out_var_id)
			{
				if (!ep_args.empty())
					ep_args += ", ";
				ep_args +=
				    join("device ", type_to_glsl(get_patch_stage_out_struct_type()), "* ", patch_output_buffer_var_name,
				         " [[buffer(", convert_to_string(msl_options.shader_patch_output_buffer_index), ")]]");
			}
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args += join("device ", get_tess_factor_struct_name(), "* ", tess_factor_buffer_var_name, " [[buffer(",
			                convert_to_string(msl_options.shader_tess_factor_buffer_index), ")]]");

			// Initializer for tess factors must be handled specially since it's never declared as a normal variable.
			uint32_t outer_factor_initializer_id = 0;
			uint32_t inner_factor_initializer_id = 0;
			ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
				if (!has_decoration(var.self, DecorationBuiltIn) || var.storage != StorageClassOutput || !var.initializer)
					return;

				BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
				if (builtin == BuiltInTessLevelInner)
					inner_factor_initializer_id = var.initializer;
				else if (builtin == BuiltInTessLevelOuter)
					outer_factor_initializer_id = var.initializer;
			});

			const SPIRConstant *c = nullptr;

			// The initializers are emitted as fixup hooks assigning to the tess factor
			// builtins at the top of the entry function (factors are half-typed in MSL).
			if (outer_factor_initializer_id && (c = maybe_get<SPIRConstant>(outer_factor_initializer_id)))
			{
				auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
				entry_func.fixup_hooks_in.push_back([=]() {
					// Triangle domains only have 3 outer factors; quads have 4.
					uint32_t components = get_execution_mode_bitset().get(ExecutionModeTriangles) ? 3 : 4;
					for (uint32_t i = 0; i < components; i++)
					{
						statement(builtin_to_glsl(BuiltInTessLevelOuter, StorageClassOutput), "[", i, "] = ",
						          "half(", to_expression(c->subconstants[i]), ");");
					}
				});
			}

			if (inner_factor_initializer_id && (c = maybe_get<SPIRConstant>(inner_factor_initializer_id)))
			{
				auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
				// Triangles have a single (scalar) inner factor; quads have two.
				if (get_execution_mode_bitset().get(ExecutionModeTriangles))
				{
					entry_func.fixup_hooks_in.push_back([=]() {
						statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), " = ", "half(",
						          to_expression(c->subconstants[0]), ");");
					});
				}
				else
				{
					entry_func.fixup_hooks_in.push_back([=]() {
						for (uint32_t i = 0; i < 2; i++)
						{
							statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), "[", i, "] = ",
							          "half(", to_expression(c->subconstants[i]), ");");
						}
					});
				}
			}

			if (stage_in_var_id)
			{
				if (!ep_args.empty())
					ep_args += ", ";
				// Multi-patch mode reads control point input from a device buffer;
				// otherwise it is staged through threadgroup memory.
				if (msl_options.multi_patch_workgroup)
				{
					ep_args += join("device ", type_to_glsl(get_stage_in_struct_type()), "* ", input_buffer_var_name,
					                " [[buffer(", convert_to_string(msl_options.shader_input_buffer_index), ")]]");
				}
				else
				{
					ep_args += join("threadgroup ", type_to_glsl(get_stage_in_struct_type()), "* ", input_wg_var_name,
					                " [[threadgroup(", convert_to_string(msl_options.shader_input_wg_index), ")]]");
				}
			}
		}
	}
}
10809 
entry_point_args_argument_buffer(bool append_comma)10810 string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
10811 {
10812 	string ep_args = entry_point_arg_stage_in();
10813 	Bitset claimed_bindings;
10814 
10815 	for (uint32_t i = 0; i < kMaxArgumentBuffers; i++)
10816 	{
10817 		uint32_t id = argument_buffer_ids[i];
10818 		if (id == 0)
10819 			continue;
10820 
10821 		add_resource_name(id);
10822 		auto &var = get<SPIRVariable>(id);
10823 		auto &type = get_variable_data_type(var);
10824 
10825 		if (!ep_args.empty())
10826 			ep_args += ", ";
10827 
10828 		// Check if the argument buffer binding itself has been remapped.
10829 		uint32_t buffer_binding;
10830 		auto itr = resource_bindings.find({ get_entry_point().model, i, kArgumentBufferBinding });
10831 		if (itr != end(resource_bindings))
10832 		{
10833 			buffer_binding = itr->second.first.msl_buffer;
10834 			itr->second.second = true;
10835 		}
10836 		else
10837 		{
10838 			// As a fallback, directly map desc set <-> binding.
10839 			// If that was taken, take the next buffer binding.
10840 			if (claimed_bindings.get(i))
10841 				buffer_binding = next_metal_resource_index_buffer;
10842 			else
10843 				buffer_binding = i;
10844 		}
10845 
10846 		claimed_bindings.set(buffer_binding);
10847 
10848 		ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id) + to_name(id);
10849 		ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";
10850 
10851 		next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
10852 	}
10853 
10854 	entry_point_args_discrete_descriptors(ep_args);
10855 	entry_point_args_builtin(ep_args);
10856 
10857 	if (!ep_args.empty() && append_comma)
10858 		ep_args += ", ";
10859 
10860 	return ep_args;
10861 }
10862 
find_constexpr_sampler(uint32_t id) const10863 const MSLConstexprSampler *CompilerMSL::find_constexpr_sampler(uint32_t id) const
10864 {
10865 	// Try by ID.
10866 	{
10867 		auto itr = constexpr_samplers_by_id.find(id);
10868 		if (itr != end(constexpr_samplers_by_id))
10869 			return &itr->second;
10870 	}
10871 
10872 	// Try by binding.
10873 	{
10874 		uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
10875 		uint32_t binding = get_decoration(id, DecorationBinding);
10876 
10877 		auto itr = constexpr_samplers_by_binding.find({ desc_set, binding });
10878 		if (itr != end(constexpr_samplers_by_binding))
10879 			return &itr->second;
10880 	}
10881 
10882 	return nullptr;
10883 }
10884 
// Appends entry point argument declarations for all discrete (non-argument-buffer)
// resources — buffers, textures, samplers — to ep_args, assigning each its
// [[buffer(n)]] / [[texture(n)]] / [[sampler(n)]] attribute.
void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
{
	// Output resources, sorted by resource index & type
	// We need to sort to work around a bug on macOS 10.13 with NVidia drivers where switching between shaders
	// with different order of buffers can result in issues with buffer assignments inside the driver.
	struct Resource
	{
		SPIRVariable *var;       // The variable backing this resource.
		string name;             // Name to emit in the argument list.
		SPIRType::BaseType basetype;
		uint32_t index;          // MSL resource index (buffer/texture/sampler slot).
		uint32_t plane;          // Plane index for multi-planar (YCbCr) images.
		uint32_t secondary_index; // Buffer slot for emulated texture atomics.
	};

	SmallVector<Resource> resources;

	// Gather every visible resource variable into the list.
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
		if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
		     var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer) &&
		    !is_hidden_variable(var))
		{
			auto &type = get_variable_data_type(var);

			// Skip resources that live inside an argument buffer; those are emitted elsewhere.
			if (is_supported_argument_buffer_type(type) && var.storage != StorageClassPushConstant)
			{
				uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
				if (descriptor_set_is_argument_buffer(desc_set))
					return;
			}

			const MSLConstexprSampler *constexpr_sampler = nullptr;
			if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
			{
				constexpr_sampler = find_constexpr_sampler(var_id);
				if (constexpr_sampler)
				{
					// Mark this ID as a constexpr sampler for later in case it came from set/bindings.
					constexpr_samplers_by_id[var_id] = *constexpr_sampler;
				}
			}

			// Emulate texture2D atomic operations
			uint32_t secondary_index = 0;
			if (atomic_image_vars.count(var.self))
			{
				secondary_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
			}

			if (type.basetype == SPIRType::SampledImage)
			{
				add_resource_name(var_id);

				// YCbCr conversion may split the image into multiple planes.
				uint32_t plane_count = 1;
				if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
					plane_count = constexpr_sampler->planes;

				for (uint32_t i = 0; i < plane_count; i++)
					resources.push_back({ &var, to_name(var_id), SPIRType::Image,
					                      get_metal_resource_index(var, SPIRType::Image, i), i, secondary_index });

				// A combined image-sampler also needs its sampler half, unless it's a
				// buffer texture (no sampler) or the sampler is constexpr.
				if (type.image.dim != DimBuffer && !constexpr_sampler)
				{
					resources.push_back({ &var, to_sampler_expression(var_id), SPIRType::Sampler,
					                      get_metal_resource_index(var, SPIRType::Sampler), 0, 0 });
				}
			}
			else if (!constexpr_sampler)
			{
				// constexpr samplers are not declared as resources.
				add_resource_name(var_id);
				resources.push_back({ &var, to_name(var_id), type.basetype,
				                      get_metal_resource_index(var, type.basetype), 0, secondary_index });
			}
		}
	});

	// Stable, driver-friendly declaration order: by base type, then resource index.
	sort(resources.begin(), resources.end(), [](const Resource &lhs, const Resource &rhs) {
		return tie(lhs.basetype, lhs.index) < tie(rhs.basetype, rhs.index);
	});

	// Emit each gathered resource as an entry point argument.
	for (auto &r : resources)
	{
		auto &var = *r.var;
		auto &type = get_variable_data_type(var);

		uint32_t var_id = var.self;

		switch (r.basetype)
		{
		case SPIRType::Struct:
		{
			auto &m = ir.meta[type.self];
			// Empty structs declare nothing.
			if (m.members.size() == 0)
				break;
			if (!type.array.empty())
			{
				if (type.array.size() > 1)
					SPIRV_CROSS_THROW("Arrays of arrays of buffers are not supported.");

				// Metal doesn't directly support this, so we must expand the
				// array. We'll declare a local array to hold these elements
				// later.
				uint32_t array_size = to_array_size_literal(type);

				if (array_size == 0)
					SPIRV_CROSS_THROW("Unsized arrays of buffers are not supported in MSL.");

				// Allow Metal to use the array<T> template to make arrays a value type
				is_using_builtin_array = true;
				buffer_arrays.push_back(var_id);
				// One pointer argument per element, named <name>_<i>, at consecutive buffer slots.
				for (uint32_t i = 0; i < array_size; ++i)
				{
					if (!ep_args.empty())
						ep_args += ", ";
					ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + to_restrict(var_id) +
					           r.name + "_" + convert_to_string(i);
					ep_args += " [[buffer(" + convert_to_string(r.index + i) + ")";
					if (interlocked_resources.count(var_id))
						ep_args += ", raster_order_group(0)";
					ep_args += "]]";
				}
				is_using_builtin_array = false;
			}
			else
			{
				if (!ep_args.empty())
					ep_args += ", ";
				ep_args +=
				    get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(var_id) + r.name;
				ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
				if (interlocked_resources.count(var_id))
					ep_args += ", raster_order_group(0)";
				ep_args += "]]";
			}
			break;
		}
		case SPIRType::Sampler:
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args += sampler_type(type, var_id) + " " + r.name;
			ep_args += " [[sampler(" + convert_to_string(r.index) + ")]]";
			break;
		case SPIRType::Image:
		{
			if (!ep_args.empty())
				ep_args += ", ";

			// Use Metal's native frame-buffer fetch API for subpass inputs.
			const auto &basetype = get<SPIRType>(var.basetype);
			if (!type_is_msl_framebuffer_fetch(basetype))
			{
				ep_args += image_type_glsl(type, var_id) + " " + r.name;
				// Extra planes of a multi-planar image get a suffixed name.
				if (r.plane > 0)
					ep_args += join(plane_name_suffix, r.plane);
				ep_args += " [[texture(" + convert_to_string(r.index) + ")";
				if (interlocked_resources.count(var_id))
					ep_args += ", raster_order_group(0)";
				ep_args += "]]";
			}
			else
			{
				if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3))
					SPIRV_CROSS_THROW("Framebuffer fetch on Mac is not supported before MSL 2.3.");
				ep_args += image_type_glsl(type, var_id) + " " + r.name;
				ep_args += " [[color(" + convert_to_string(r.index) + ")]]";
			}

			// Emulate texture2D atomic operations
			// The texture gets a companion device buffer of atomics at its secondary index.
			if (atomic_image_vars.count(var.self))
			{
				ep_args += ", device atomic_" + type_to_glsl(get<SPIRType>(basetype.image.type), 0);
				ep_args += "* " + r.name + "_atomic";
				ep_args += " [[buffer(" + convert_to_string(r.secondary_index) + ")";
				if (interlocked_resources.count(var_id))
					ep_args += ", raster_order_group(0)";
				ep_args += "]]";
			}
			break;
		}
		default:
			// Everything else (e.g. texel buffers / non-struct uniforms) is a plain buffer argument.
			if (!ep_args.empty())
				ep_args += ", ";
			if (!type.pointer)
				ep_args += get_type_address_space(get<SPIRType>(var.basetype), var_id) + " " +
				           type_to_glsl(type, var_id) + "& " + r.name;
			else
				ep_args += type_to_glsl(type, var_id) + " " + r.name;
			ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
			if (interlocked_resources.count(var_id))
				ep_args += ", raster_order_group(0)";
			ep_args += "]]";
			break;
		}
	}
}
11081 
11082 // Returns a string containing a comma-delimited list of args for the entry point function
11083 // This is the "classic" method of MSL 1 when we don't have argument buffer support.
entry_point_args_classic(bool append_comma)11084 string CompilerMSL::entry_point_args_classic(bool append_comma)
11085 {
11086 	string ep_args = entry_point_arg_stage_in();
11087 	entry_point_args_discrete_descriptors(ep_args);
11088 	entry_point_args_builtin(ep_args);
11089 
11090 	if (!ep_args.empty() && append_comma)
11091 		ep_args += ", ";
11092 
11093 	return ep_args;
11094 }
11095 
fix_up_shader_inputs_outputs()11096 void CompilerMSL::fix_up_shader_inputs_outputs()
11097 {
11098 	auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
11099 
11100 	// Emit a guard to ensure we don't execute beyond the last vertex.
11101 	// Vertex shaders shouldn't have the problems with barriers in non-uniform control flow that
11102 	// tessellation control shaders do, so early returns should be OK. We may need to revisit this
11103 	// if it ever becomes possible to use barriers from a vertex shader.
11104 	if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
11105 	{
11106 		entry_func.fixup_hooks_in.push_back([this]() {
11107 			statement("if (any(", to_expression(builtin_invocation_id_id),
11108 			          " >= ", to_expression(builtin_stage_input_size_id), "))");
11109 			statement("    return;");
11110 		});
11111 	}
11112 
11113 	// Look for sampled images and buffer. Add hooks to set up the swizzle constants or array lengths.
11114 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
11115 		auto &type = get_variable_data_type(var);
11116 		uint32_t var_id = var.self;
11117 		bool ssbo = has_decoration(type.self, DecorationBufferBlock);
11118 
11119 		if (var.storage == StorageClassUniformConstant && !is_hidden_variable(var))
11120 		{
11121 			if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
11122 			{
11123 				entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
11124 					bool is_array_type = !type.array.empty();
11125 
11126 					uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
11127 					if (descriptor_set_is_argument_buffer(desc_set))
11128 					{
11129 						statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
11130 						          is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
11131 						          ".spvSwizzleConstants", "[",
11132 						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
11133 					}
11134 					else
11135 					{
11136 						// If we have an array of images, we need to be able to index into it, so take a pointer instead.
11137 						statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
11138 						          is_array_type ? " = &" : " = ", to_name(swizzle_buffer_id), "[",
11139 						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
11140 					}
11141 				});
11142 			}
11143 		}
11144 		else if ((var.storage == StorageClassStorageBuffer || (var.storage == StorageClassUniform && ssbo)) &&
11145 		         !is_hidden_variable(var))
11146 		{
11147 			if (buffers_requiring_array_length.count(var.self))
11148 			{
11149 				entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
11150 					bool is_array_type = !type.array.empty();
11151 
11152 					uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
11153 					if (descriptor_set_is_argument_buffer(desc_set))
11154 					{
11155 						statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
11156 						          is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
11157 						          ".spvBufferSizeConstants", "[",
11158 						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
11159 					}
11160 					else
11161 					{
11162 						// If we have an array of images, we need to be able to index into it, so take a pointer instead.
11163 						statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
11164 						          is_array_type ? " = &" : " = ", to_name(buffer_size_buffer_id), "[",
11165 						          convert_to_string(get_metal_resource_index(var, type.basetype)), "];");
11166 					}
11167 				});
11168 			}
11169 		}
11170 	});
11171 
11172 	// Builtin variables
11173 	ir.for_each_typed_id<SPIRVariable>([this, &entry_func](uint32_t, SPIRVariable &var) {
11174 		uint32_t var_id = var.self;
11175 		BuiltIn bi_type = ir.meta[var_id].decoration.builtin_type;
11176 
11177 		if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
11178 			return;
11179 		if (!interface_variable_exists_in_entry_point(var.self))
11180 			return;
11181 
11182 		if (var.storage == StorageClassInput && is_builtin_variable(var) && active_input_builtins.get(bi_type))
11183 		{
11184 			switch (bi_type)
11185 			{
11186 			case BuiltInSamplePosition:
11187 				entry_func.fixup_hooks_in.push_back([=]() {
11188 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = get_sample_position(",
11189 					          to_expression(builtin_sample_id_id), ");");
11190 				});
11191 				break;
11192 			case BuiltInFragCoord:
11193 				if (is_sample_rate())
11194 				{
11195 					entry_func.fixup_hooks_in.push_back([=]() {
11196 						statement(to_expression(var_id), ".xy += get_sample_position(",
11197 						          to_expression(builtin_sample_id_id), ") - 0.5;");
11198 					});
11199 				}
11200 				break;
11201 			case BuiltInHelperInvocation:
11202 				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
11203 					SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.3 on iOS.");
11204 				else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
11205 					SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
11206 
11207 				entry_func.fixup_hooks_in.push_back([=]() {
11208 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_is_helper_thread();");
11209 				});
11210 				break;
11211 			case BuiltInInvocationId:
11212 				// This is direct-mapped without multi-patch workgroups.
11213 				if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
11214 					break;
11215 
11216 				entry_func.fixup_hooks_in.push_back([=]() {
11217 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11218 					          to_expression(builtin_invocation_id_id), ".x % ", this->get_entry_point().output_vertices,
11219 					          ";");
11220 				});
11221 				break;
11222 			case BuiltInPrimitiveId:
11223 				// This is natively supported by fragment and tessellation evaluation shaders.
11224 				// In tessellation control shaders, this is direct-mapped without multi-patch workgroups.
11225 				if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
11226 					break;
11227 
11228 				entry_func.fixup_hooks_in.push_back([=]() {
11229 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = min(",
11230 					          to_expression(builtin_invocation_id_id), ".x / ", this->get_entry_point().output_vertices,
11231 					          ", spvIndirectParams[1]);");
11232 				});
11233 				break;
11234 			case BuiltInPatchVertices:
11235 				if (get_execution_model() == ExecutionModelTessellationEvaluation)
11236 					entry_func.fixup_hooks_in.push_back([=]() {
11237 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11238 						          to_expression(patch_stage_in_var_id), ".gl_in.size();");
11239 					});
11240 				else
11241 					entry_func.fixup_hooks_in.push_back([=]() {
11242 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = spvIndirectParams[0];");
11243 					});
11244 				break;
11245 			case BuiltInTessCoord:
11246 				// Emit a fixup to account for the shifted domain. Don't do this for triangles;
11247 				// MoltenVK will just reverse the winding order instead.
11248 				if (msl_options.tess_domain_origin_lower_left && !get_entry_point().flags.get(ExecutionModeTriangles))
11249 				{
11250 					string tc = to_expression(var_id);
11251 					entry_func.fixup_hooks_in.push_back([=]() { statement(tc, ".y = 1.0 - ", tc, ".y;"); });
11252 				}
11253 				break;
11254 			case BuiltInSubgroupId:
11255 				if (!msl_options.emulate_subgroups)
11256 					break;
11257 				// For subgroup emulation, this is the same as the local invocation index.
11258 				entry_func.fixup_hooks_in.push_back([=]() {
11259 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11260 					          to_expression(builtin_local_invocation_index_id), ";");
11261 				});
11262 				break;
11263 			case BuiltInNumSubgroups:
11264 				if (!msl_options.emulate_subgroups)
11265 					break;
11266 				// For subgroup emulation, this is the same as the workgroup size.
11267 				entry_func.fixup_hooks_in.push_back([=]() {
11268 					auto &type = expression_type(builtin_workgroup_size_id);
11269 					string size_expr = to_expression(builtin_workgroup_size_id);
11270 					if (type.vecsize >= 3)
11271 						size_expr = join(size_expr, ".x * ", size_expr, ".y * ", size_expr, ".z");
11272 					else if (type.vecsize == 2)
11273 						size_expr = join(size_expr, ".x * ", size_expr, ".y");
11274 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", size_expr, ";");
11275 				});
11276 				break;
11277 			case BuiltInSubgroupLocalInvocationId:
11278 				if (!msl_options.emulate_subgroups)
11279 					break;
11280 				// For subgroup emulation, assume subgroups of size 1.
11281 				entry_func.fixup_hooks_in.push_back(
11282 				    [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;"); });
11283 				break;
11284 			case BuiltInSubgroupSize:
11285 				if (msl_options.emulate_subgroups)
11286 				{
11287 					// For subgroup emulation, assume subgroups of size 1.
11288 					entry_func.fixup_hooks_in.push_back(
11289 					    [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 1;"); });
11290 				}
11291 				else if (msl_options.fixed_subgroup_size != 0)
11292 				{
11293 					entry_func.fixup_hooks_in.push_back([=]() {
11294 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11295 						          msl_options.fixed_subgroup_size, ";");
11296 					});
11297 				}
11298 				break;
11299 			case BuiltInSubgroupEqMask:
11300 				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
11301 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
11302 				if (!msl_options.supports_msl_version(2, 1))
11303 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
11304 				entry_func.fixup_hooks_in.push_back([=]() {
11305 					if (msl_options.is_ios())
11306 					{
11307 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", "uint4(1 << ",
11308 						          to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
11309 					}
11310 					else
11311 					{
11312 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11313 						          to_expression(builtin_subgroup_invocation_id_id), " >= 32 ? uint4(0, (1 << (",
11314 						          to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
11315 						          to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
11316 					}
11317 				});
11318 				break;
11319 			case BuiltInSubgroupGeMask:
11320 				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
11321 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
11322 				if (!msl_options.supports_msl_version(2, 1))
11323 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
11324 				if (msl_options.fixed_subgroup_size != 0)
11325 					add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
11326 				entry_func.fixup_hooks_in.push_back([=]() {
11327 					// Case where index < 32, size < 32:
11328 					// mask0 = bfi(0, 0xFFFFFFFF, index, size - index);
11329 					// mask1 = bfi(0, 0xFFFFFFFF, 0, 0); // Gives 0
11330 					// Case where index < 32 but size >= 32:
11331 					// mask0 = bfi(0, 0xFFFFFFFF, index, 32 - index);
11332 					// mask1 = bfi(0, 0xFFFFFFFF, 0, size - 32);
11333 					// Case where index >= 32:
11334 					// mask0 = bfi(0, 0xFFFFFFFF, 32, 0); // Gives 0
11335 					// mask1 = bfi(0, 0xFFFFFFFF, index - 32, size - index);
11336 					// This is expressed without branches to avoid divergent
11337 					// control flow--hence the complicated min/max expressions.
11338 					// This is further complicated by the fact that if you attempt
11339 					// to bfi/bfe out-of-bounds on Metal, undefined behavior is the
11340 					// result.
11341 					if (msl_options.fixed_subgroup_size > 32)
11342 					{
11343 						// Don't use the subgroup size variable with fixed subgroup sizes,
11344 						// since the variables could be defined in the wrong order.
11345 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11346 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
11347 						          to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(32 - (int)",
11348 						          to_expression(builtin_subgroup_invocation_id_id),
11349 						          ", 0)), insert_bits(0u, 0xFFFFFFFF,"
11350 						          " (uint)max((int)",
11351 						          to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), ",
11352 						          msl_options.fixed_subgroup_size, " - max(",
11353 						          to_expression(builtin_subgroup_invocation_id_id),
11354 						          ", 32u)), uint2(0));");
11355 					}
11356 					else if (msl_options.fixed_subgroup_size != 0)
11357 					{
11358 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11359 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
11360 						          to_expression(builtin_subgroup_invocation_id_id), ", ",
11361 						          msl_options.fixed_subgroup_size, " - ",
11362 						          to_expression(builtin_subgroup_invocation_id_id),
11363 						          "), uint3(0));");
11364 					}
11365 					else if (msl_options.is_ios())
11366 					{
11367 						// On iOS, the SIMD-group size will currently never exceed 32.
11368 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11369 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
11370 						          to_expression(builtin_subgroup_invocation_id_id), ", ",
11371 						          to_expression(builtin_subgroup_size_id), " - ",
11372 						          to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));");
11373 					}
11374 					else
11375 					{
11376 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11377 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
11378 						          to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
11379 						          to_expression(builtin_subgroup_size_id), ", 32) - (int)",
11380 						          to_expression(builtin_subgroup_invocation_id_id),
11381 						          ", 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
11382 						          to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
11383 						          to_expression(builtin_subgroup_size_id), " - (int)max(",
11384 						          to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
11385 					}
11386 				});
11387 				break;
11388 			case BuiltInSubgroupGtMask:
11389 				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
11390 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
11391 				if (!msl_options.supports_msl_version(2, 1))
11392 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
11393 				add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
11394 				entry_func.fixup_hooks_in.push_back([=]() {
11395 					// The same logic applies here, except now the index is one
11396 					// more than the subgroup invocation ID.
11397 					if (msl_options.fixed_subgroup_size > 32)
11398 					{
11399 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11400 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
11401 						          to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(32 - (int)",
11402 						          to_expression(builtin_subgroup_invocation_id_id),
11403 						          " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
11404 						          to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), ",
11405 						          msl_options.fixed_subgroup_size, " - max(",
11406 						          to_expression(builtin_subgroup_invocation_id_id),
11407 						          " + 1, 32u)), uint2(0));");
11408 					}
11409 					else if (msl_options.fixed_subgroup_size != 0)
11410 					{
11411 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11412 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
11413 						          to_expression(builtin_subgroup_invocation_id_id), " + 1, ",
11414 						          msl_options.fixed_subgroup_size, " - ",
11415 						          to_expression(builtin_subgroup_invocation_id_id),
11416 						          " - 1), uint3(0));");
11417 					}
11418 					else if (msl_options.is_ios())
11419 					{
11420 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11421 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
11422 						          to_expression(builtin_subgroup_invocation_id_id), " + 1, ",
11423 						          to_expression(builtin_subgroup_size_id), " - ",
11424 						          to_expression(builtin_subgroup_invocation_id_id), " - 1), uint3(0));");
11425 					}
11426 					else
11427 					{
11428 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11429 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
11430 						          to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
11431 						          to_expression(builtin_subgroup_size_id), ", 32) - (int)",
11432 						          to_expression(builtin_subgroup_invocation_id_id),
11433 						          " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
11434 						          to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
11435 						          to_expression(builtin_subgroup_size_id), " - (int)max(",
11436 						          to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");
11437 					}
11438 				});
11439 				break;
11440 			case BuiltInSubgroupLeMask:
11441 				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
11442 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
11443 				if (!msl_options.supports_msl_version(2, 1))
11444 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
11445 				add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
11446 				entry_func.fixup_hooks_in.push_back([=]() {
11447 					if (msl_options.is_ios())
11448 					{
11449 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11450 						          " = uint4(extract_bits(0xFFFFFFFF, 0, ",
11451 						          to_expression(builtin_subgroup_invocation_id_id), " + 1), uint3(0));");
11452 					}
11453 					else
11454 					{
11455 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11456 						          " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
11457 						          to_expression(builtin_subgroup_invocation_id_id),
11458 						          " + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
11459 						          to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0)), uint2(0));");
11460 					}
11461 				});
11462 				break;
11463 			case BuiltInSubgroupLtMask:
11464 				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
11465 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
11466 				if (!msl_options.supports_msl_version(2, 1))
11467 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
11468 				add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
11469 				entry_func.fixup_hooks_in.push_back([=]() {
11470 					if (msl_options.is_ios())
11471 					{
11472 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11473 						          " = uint4(extract_bits(0xFFFFFFFF, 0, ",
11474 						          to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));");
11475 					}
11476 					else
11477 					{
11478 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11479 						          " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
11480 						          to_expression(builtin_subgroup_invocation_id_id),
11481 						          ", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
11482 						          to_expression(builtin_subgroup_invocation_id_id), " - 32, 0)), uint2(0));");
11483 					}
11484 				});
11485 				break;
11486 			case BuiltInViewIndex:
11487 				if (!msl_options.multiview)
11488 				{
11489 					// According to the Vulkan spec, when not running under a multiview
11490 					// render pass, ViewIndex is 0.
11491 					entry_func.fixup_hooks_in.push_back([=]() {
11492 						statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;");
11493 					});
11494 				}
11495 				else if (msl_options.view_index_from_device_index)
11496 				{
11497 					// In this case, we take the view index from that of the device we're running on.
11498 					entry_func.fixup_hooks_in.push_back([=]() {
11499 						statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11500 						          msl_options.device_index, ";");
11501 					});
11502 					// We actually don't want to set the render_target_array_index here.
11503 					// Since every physical device is rendering a different view,
11504 					// there's no need for layered rendering here.
11505 				}
11506 				else if (!msl_options.multiview_layered_rendering)
11507 				{
11508 					// In this case, the views are rendered one at a time. The view index, then,
11509 					// is just the first part of the "view mask".
11510 					entry_func.fixup_hooks_in.push_back([=]() {
11511 						statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11512 						          to_expression(view_mask_buffer_id), "[0];");
11513 					});
11514 				}
11515 				else if (get_execution_model() == ExecutionModelFragment)
11516 				{
11517 					// Because we adjusted the view index in the vertex shader, we have to
11518 					// adjust it back here.
11519 					entry_func.fixup_hooks_in.push_back([=]() {
11520 						statement(to_expression(var_id), " += ", to_expression(view_mask_buffer_id), "[0];");
11521 					});
11522 				}
11523 				else if (get_execution_model() == ExecutionModelVertex)
11524 				{
11525 					// Metal provides no special support for multiview, so we smuggle
11526 					// the view index in the instance index.
11527 					entry_func.fixup_hooks_in.push_back([=]() {
11528 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11529 						          to_expression(view_mask_buffer_id), "[0] + (", to_expression(builtin_instance_idx_id),
11530 						          " - ", to_expression(builtin_base_instance_id), ") % ",
11531 						          to_expression(view_mask_buffer_id), "[1];");
11532 						statement(to_expression(builtin_instance_idx_id), " = (",
11533 						          to_expression(builtin_instance_idx_id), " - ",
11534 						          to_expression(builtin_base_instance_id), ") / ", to_expression(view_mask_buffer_id),
11535 						          "[1] + ", to_expression(builtin_base_instance_id), ";");
11536 					});
11537 					// In addition to setting the variable itself, we also need to
11538 					// set the render_target_array_index with it on output. We have to
11539 					// offset this by the base view index, because Metal isn't in on
11540 					// our little game here.
11541 					entry_func.fixup_hooks_out.push_back([=]() {
11542 						statement(to_expression(builtin_layer_id), " = ", to_expression(var_id), " - ",
11543 						          to_expression(view_mask_buffer_id), "[0];");
11544 					});
11545 				}
11546 				break;
11547 			case BuiltInDeviceIndex:
11548 				// Metal pipelines belong to the devices which create them, so we'll
11549 				// need to create a MTLPipelineState for every MTLDevice in a grouped
11550 				// VkDevice. We can assume, then, that the device index is constant.
11551 				entry_func.fixup_hooks_in.push_back([=]() {
11552 					statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11553 					          msl_options.device_index, ";");
11554 				});
11555 				break;
11556 			case BuiltInWorkgroupId:
11557 				if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInWorkgroupId))
11558 					break;
11559 
11560 				// The vkCmdDispatchBase() command lets the client set the base value
11561 				// of WorkgroupId. Metal has no direct equivalent; we must make this
11562 				// adjustment ourselves.
11563 				entry_func.fixup_hooks_in.push_back([=]() {
11564 					statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), ";");
11565 				});
11566 				break;
11567 			case BuiltInGlobalInvocationId:
11568 				if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInGlobalInvocationId))
11569 					break;
11570 
11571 				// GlobalInvocationId is defined as LocalInvocationId + WorkgroupId * WorkgroupSize.
11572 				// This needs to be adjusted too.
11573 				entry_func.fixup_hooks_in.push_back([=]() {
11574 					auto &execution = this->get_entry_point();
11575 					uint32_t workgroup_size_id = execution.workgroup_size.constant;
11576 					if (workgroup_size_id)
11577 						statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
11578 						          " * ", to_expression(workgroup_size_id), ";");
11579 					else
11580 						statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
11581 						          " * uint3(", execution.workgroup_size.x, ", ", execution.workgroup_size.y, ", ",
11582 						          execution.workgroup_size.z, ");");
11583 				});
11584 				break;
11585 			case BuiltInVertexId:
11586 			case BuiltInVertexIndex:
11587 				// This is direct-mapped normally.
11588 				if (!msl_options.vertex_for_tessellation)
11589 					break;
11590 
11591 				entry_func.fixup_hooks_in.push_back([=]() {
11592 					builtin_declaration = true;
11593 					switch (msl_options.vertex_index_type)
11594 					{
11595 					case Options::IndexType::None:
11596 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11597 						          to_expression(builtin_invocation_id_id), ".x + ",
11598 						          to_expression(builtin_dispatch_base_id), ".x;");
11599 						break;
11600 					case Options::IndexType::UInt16:
11601 					case Options::IndexType::UInt32:
11602 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", index_buffer_var_name,
11603 						          "[", to_expression(builtin_invocation_id_id), ".x] + ",
11604 						          to_expression(builtin_dispatch_base_id), ".x;");
11605 						break;
11606 					}
11607 					builtin_declaration = false;
11608 				});
11609 				break;
11610 			case BuiltInBaseVertex:
11611 				// This is direct-mapped normally.
11612 				if (!msl_options.vertex_for_tessellation)
11613 					break;
11614 
11615 				entry_func.fixup_hooks_in.push_back([=]() {
11616 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11617 					          to_expression(builtin_dispatch_base_id), ".x;");
11618 				});
11619 				break;
11620 			case BuiltInInstanceId:
11621 			case BuiltInInstanceIndex:
11622 				// This is direct-mapped normally.
11623 				if (!msl_options.vertex_for_tessellation)
11624 					break;
11625 
11626 				entry_func.fixup_hooks_in.push_back([=]() {
11627 					builtin_declaration = true;
11628 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11629 					          to_expression(builtin_invocation_id_id), ".y + ", to_expression(builtin_dispatch_base_id),
11630 					          ".y;");
11631 					builtin_declaration = false;
11632 				});
11633 				break;
11634 			case BuiltInBaseInstance:
11635 				// This is direct-mapped normally.
11636 				if (!msl_options.vertex_for_tessellation)
11637 					break;
11638 
11639 				entry_func.fixup_hooks_in.push_back([=]() {
11640 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11641 					          to_expression(builtin_dispatch_base_id), ".y;");
11642 				});
11643 				break;
11644 			default:
11645 				break;
11646 			}
11647 		}
11648 		else if (var.storage == StorageClassOutput && is_builtin_variable(var) && active_output_builtins.get(bi_type))
11649 		{
11650 			if (bi_type == BuiltInSampleMask && get_execution_model() == ExecutionModelFragment &&
11651 			    msl_options.additional_fixed_sample_mask != 0xffffffff)
11652 			{
11653 				// If the additional fixed sample mask was set, we need to adjust the sample_mask
11654 				// output to reflect that. If the shader outputs the sample_mask itself too, we need
11655 				// to AND the two masks to get the final one.
11656 				if (does_shader_write_sample_mask)
11657 				{
11658 					entry_func.fixup_hooks_out.push_back([=]() {
11659 						statement(to_expression(builtin_sample_mask_id),
11660 						          " &= ", msl_options.additional_fixed_sample_mask, ";");
11661 					});
11662 				}
11663 				else
11664 				{
11665 					entry_func.fixup_hooks_out.push_back([=]() {
11666 						statement(to_expression(builtin_sample_mask_id), " = ",
11667 						          msl_options.additional_fixed_sample_mask, ";");
11668 					});
11669 				}
11670 			}
11671 		}
11672 	});
11673 }
11674 
11675 // Returns the Metal index of the resource of the specified type as used by the specified variable.
get_metal_resource_index(SPIRVariable & var,SPIRType::BaseType basetype,uint32_t plane)11676 uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype, uint32_t plane)
11677 {
11678 	auto &execution = get_entry_point();
11679 	auto &var_dec = ir.meta[var.self].decoration;
11680 	auto &var_type = get<SPIRType>(var.basetype);
11681 	uint32_t var_desc_set = (var.storage == StorageClassPushConstant) ? kPushConstDescSet : var_dec.set;
11682 	uint32_t var_binding = (var.storage == StorageClassPushConstant) ? kPushConstBinding : var_dec.binding;
11683 
11684 	// If a matching binding has been specified, find and use it.
11685 	auto itr = resource_bindings.find({ execution.model, var_desc_set, var_binding });
11686 
11687 	// Atomic helper buffers for image atomics need to use secondary bindings as well.
11688 	bool use_secondary_binding = (var_type.basetype == SPIRType::SampledImage && basetype == SPIRType::Sampler) ||
11689 	                             basetype == SPIRType::AtomicCounter;
11690 
11691 	auto resource_decoration =
11692 	    use_secondary_binding ? SPIRVCrossDecorationResourceIndexSecondary : SPIRVCrossDecorationResourceIndexPrimary;
11693 
11694 	if (plane == 1)
11695 		resource_decoration = SPIRVCrossDecorationResourceIndexTertiary;
11696 	if (plane == 2)
11697 		resource_decoration = SPIRVCrossDecorationResourceIndexQuaternary;
11698 
11699 	if (itr != end(resource_bindings))
11700 	{
11701 		auto &remap = itr->second;
11702 		remap.second = true;
11703 		switch (basetype)
11704 		{
11705 		case SPIRType::Image:
11706 			set_extended_decoration(var.self, resource_decoration, remap.first.msl_texture + plane);
11707 			return remap.first.msl_texture + plane;
11708 		case SPIRType::Sampler:
11709 			set_extended_decoration(var.self, resource_decoration, remap.first.msl_sampler);
11710 			return remap.first.msl_sampler;
11711 		default:
11712 			set_extended_decoration(var.self, resource_decoration, remap.first.msl_buffer);
11713 			return remap.first.msl_buffer;
11714 		}
11715 	}
11716 
11717 	// If we have already allocated an index, keep using it.
11718 	if (has_extended_decoration(var.self, resource_decoration))
11719 		return get_extended_decoration(var.self, resource_decoration);
11720 
11721 	auto &type = get<SPIRType>(var.basetype);
11722 
11723 	if (type_is_msl_framebuffer_fetch(type))
11724 	{
11725 		// Frame-buffer fetch gets its fallback resource index from the input attachment index,
11726 		// which is then treated as color index.
11727 		return get_decoration(var.self, DecorationInputAttachmentIndex);
11728 	}
11729 	else if (msl_options.enable_decoration_binding)
11730 	{
11731 		// Allow user to enable decoration binding.
11732 		// If there is no explicit mapping of bindings to MSL, use the declared binding as a fallback.
11733 		if (has_decoration(var.self, DecorationBinding))
11734 		{
11735 			var_binding = get_decoration(var.self, DecorationBinding);
11736 			// Avoid emitting sentinel bindings.
11737 			if (var_binding < 0x80000000u)
11738 				return var_binding;
11739 		}
11740 	}
11741 
11742 	// If we did not explicitly remap, allocate bindings on demand.
11743 	// We cannot reliably use Binding decorations since SPIR-V and MSL's binding models are very different.
11744 
11745 	bool allocate_argument_buffer_ids = false;
11746 
11747 	if (var.storage != StorageClassPushConstant)
11748 		allocate_argument_buffer_ids = descriptor_set_is_argument_buffer(var_desc_set);
11749 
11750 	uint32_t binding_stride = 1;
11751 	for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
11752 		binding_stride *= to_array_size_literal(type, i);
11753 
11754 	assert(binding_stride != 0);
11755 
11756 	// If a binding has not been specified, revert to incrementing resource indices.
11757 	uint32_t resource_index;
11758 
11759 	if (allocate_argument_buffer_ids)
11760 	{
11761 		// Allocate from a flat ID binding space.
11762 		resource_index = next_metal_resource_ids[var_desc_set];
11763 		next_metal_resource_ids[var_desc_set] += binding_stride;
11764 	}
11765 	else
11766 	{
11767 		// Allocate from plain bindings which are allocated per resource type.
11768 		switch (basetype)
11769 		{
11770 		case SPIRType::Image:
11771 			resource_index = next_metal_resource_index_texture;
11772 			next_metal_resource_index_texture += binding_stride;
11773 			break;
11774 		case SPIRType::Sampler:
11775 			resource_index = next_metal_resource_index_sampler;
11776 			next_metal_resource_index_sampler += binding_stride;
11777 			break;
11778 		default:
11779 			resource_index = next_metal_resource_index_buffer;
11780 			next_metal_resource_index_buffer += binding_stride;
11781 			break;
11782 		}
11783 	}
11784 
11785 	set_extended_decoration(var.self, resource_decoration, resource_index);
11786 	return resource_index;
11787 }
11788 
type_is_msl_framebuffer_fetch(const SPIRType & type) const11789 bool CompilerMSL::type_is_msl_framebuffer_fetch(const SPIRType &type) const
11790 {
11791 	return type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
11792 	       msl_options.use_framebuffer_fetch_subpasses;
11793 }
11794 
argument_decl(const SPIRFunction::Parameter & arg)11795 string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
11796 {
11797 	auto &var = get<SPIRVariable>(arg.id);
11798 	auto &type = get_variable_data_type(var);
11799 	auto &var_type = get<SPIRType>(arg.type);
11800 	StorageClass storage = var_type.storage;
11801 	bool is_pointer = var_type.pointer;
11802 
11803 	// If we need to modify the name of the variable, make sure we use the original variable.
11804 	// Our alias is just a shadow variable.
11805 	uint32_t name_id = var.self;
11806 	if (arg.alias_global_variable && var.basevariable)
11807 		name_id = var.basevariable;
11808 
11809 	bool constref = !arg.alias_global_variable && is_pointer && arg.write_count == 0;
11810 	// Framebuffer fetch is plain value, const looks out of place, but it is not wrong.
11811 	if (type_is_msl_framebuffer_fetch(type))
11812 		constref = false;
11813 
11814 	bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage ||
11815 	                     type.basetype == SPIRType::Sampler;
11816 
11817 	// Arrays of images/samplers in MSL are always const.
11818 	if (!type.array.empty() && type_is_image)
11819 		constref = true;
11820 
11821 	string decl;
11822 	if (constref)
11823 		decl += "const ";
11824 
11825 	// If this is a combined image-sampler for a 2D image with floating-point type,
11826 	// we emitted the 'spvDynamicImageSampler' type, and this is *not* an alias parameter
11827 	// for a global, then we need to emit a "dynamic" combined image-sampler.
11828 	// Unfortunately, this is necessary to properly support passing around
11829 	// combined image-samplers with Y'CbCr conversions on them.
11830 	bool is_dynamic_img_sampler = !arg.alias_global_variable && type.basetype == SPIRType::SampledImage &&
11831 	                              type.image.dim == Dim2D && type_is_floating_point(get<SPIRType>(type.image.type)) &&
11832 	                              spv_function_implementations.count(SPVFuncImplDynamicImageSampler);
11833 
11834 	// Allow Metal to use the array<T> template to make arrays a value type
11835 	string address_space = get_argument_address_space(var);
11836 	bool builtin = is_builtin_variable(var);
11837 	is_using_builtin_array = builtin;
11838 	if (address_space == "threadgroup")
11839 		is_using_builtin_array = true;
11840 
11841 	if (var.basevariable && (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id))
11842 		decl += type_to_glsl(type, arg.id);
11843 	else if (builtin)
11844 		decl += builtin_type_decl(static_cast<BuiltIn>(get_decoration(arg.id, DecorationBuiltIn)), arg.id);
11845 	else if ((storage == StorageClassUniform || storage == StorageClassStorageBuffer) && is_array(type))
11846 	{
11847 		is_using_builtin_array = true;
11848 		decl += join(type_to_glsl(type, arg.id), "*");
11849 	}
11850 	else if (is_dynamic_img_sampler)
11851 	{
11852 		decl += join("spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">");
11853 		// Mark the variable so that we can handle passing it to another function.
11854 		set_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
11855 	}
11856 	else
11857 		decl += type_to_glsl(type, arg.id);
11858 
11859 	bool opaque_handle = storage == StorageClassUniformConstant;
11860 
11861 	if (!builtin && !opaque_handle && !is_pointer &&
11862 	    (storage == StorageClassFunction || storage == StorageClassGeneric))
11863 	{
11864 		// If the argument is a pure value and not an opaque type, we will pass by value.
11865 		if (msl_options.force_native_arrays && is_array(type))
11866 		{
11867 			// We are receiving an array by value. This is problematic.
11868 			// We cannot be sure of the target address space since we are supposed to receive a copy,
11869 			// but this is not possible with MSL without some extra work.
11870 			// We will have to assume we're getting a reference in thread address space.
11871 			// If we happen to get a reference in constant address space, the caller must emit a copy and pass that.
11872 			// Thread const therefore becomes the only logical choice, since we cannot "create" a constant array from
11873 			// non-constant arrays, but we can create thread const from constant.
11874 			decl = string("thread const ") + decl;
11875 			decl += " (&";
11876 			const char *restrict_kw = to_restrict(name_id);
11877 			if (*restrict_kw)
11878 			{
11879 				decl += " ";
11880 				decl += restrict_kw;
11881 			}
11882 			decl += to_expression(name_id);
11883 			decl += ")";
11884 			decl += type_to_array_glsl(type);
11885 		}
11886 		else
11887 		{
11888 			if (!address_space.empty())
11889 				decl = join(address_space, " ", decl);
11890 			decl += " ";
11891 			decl += to_expression(name_id);
11892 		}
11893 	}
11894 	else if (is_array(type) && !type_is_image)
11895 	{
11896 		// Arrays of images and samplers are special cased.
11897 		if (!address_space.empty())
11898 			decl = join(address_space, " ", decl);
11899 
11900 		if (msl_options.argument_buffers)
11901 		{
11902 			uint32_t desc_set = get_decoration(name_id, DecorationDescriptorSet);
11903 			if ((storage == StorageClassUniform || storage == StorageClassStorageBuffer) &&
11904 			    descriptor_set_is_argument_buffer(desc_set))
11905 			{
11906 				// An awkward case where we need to emit *more* address space declarations (yay!).
11907 				// An example is where we pass down an array of buffer pointers to leaf functions.
11908 				// It's a constant array containing pointers to constants.
11909 				// The pointer array is always constant however. E.g.
11910 				// device SSBO * constant (&array)[N].
11911 				// const device SSBO * constant (&array)[N].
11912 				// constant SSBO * constant (&array)[N].
11913 				// However, this only matters for argument buffers, since for MSL 1.0 style codegen,
11914 				// we emit the buffer array on stack instead, and that seems to work just fine apparently.
11915 
11916 				// If the argument was marked as being in device address space, any pointer to member would
11917 				// be const device, not constant.
11918 				if (argument_buffer_device_storage_mask & (1u << desc_set))
11919 					decl += " const device";
11920 				else
11921 					decl += " constant";
11922 			}
11923 		}
11924 
11925 		decl += " (&";
11926 		const char *restrict_kw = to_restrict(name_id);
11927 		if (*restrict_kw)
11928 		{
11929 			decl += " ";
11930 			decl += restrict_kw;
11931 		}
11932 		decl += to_expression(name_id);
11933 		decl += ")";
11934 		decl += type_to_array_glsl(type);
11935 	}
11936 	else if (!opaque_handle && (!pull_model_inputs.count(var.basevariable) || type.basetype == SPIRType::Struct))
11937 	{
11938 		// If this is going to be a reference to a variable pointer, the address space
11939 		// for the reference has to go before the '&', but after the '*'.
11940 		if (!address_space.empty())
11941 		{
11942 			if (decl.back() == '*')
11943 				decl += join(" ", address_space, " ");
11944 			else
11945 				decl = join(address_space, " ", decl);
11946 		}
11947 		decl += "&";
11948 		decl += " ";
11949 		decl += to_restrict(name_id);
11950 		decl += to_expression(name_id);
11951 	}
11952 	else
11953 	{
11954 		if (!address_space.empty())
11955 			decl = join(address_space, " ", decl);
11956 		decl += " ";
11957 		decl += to_expression(name_id);
11958 	}
11959 
11960 	// Emulate texture2D atomic operations
11961 	auto *backing_var = maybe_get_backing_variable(name_id);
11962 	if (backing_var && atomic_image_vars.count(backing_var->self))
11963 	{
11964 		decl += ", device atomic_" + type_to_glsl(get<SPIRType>(var_type.image.type), 0);
11965 		decl += "* " + to_expression(name_id) + "_atomic";
11966 	}
11967 
11968 	is_using_builtin_array = false;
11969 
11970 	return decl;
11971 }
11972 
11973 // If we're currently in the entry point function, and the object
11974 // has a qualified name, use it, otherwise use the standard name.
to_name(uint32_t id,bool allow_alias) const11975 string CompilerMSL::to_name(uint32_t id, bool allow_alias) const
11976 {
11977 	if (current_function && (current_function->self == ir.default_entry_point))
11978 	{
11979 		auto *m = ir.find_meta(id);
11980 		if (m && !m->decoration.qualified_alias.empty())
11981 			return m->decoration.qualified_alias;
11982 	}
11983 	return Compiler::to_name(id, allow_alias);
11984 }
11985 
11986 // Returns a name that combines the name of the struct with the name of the member, except for Builtins
to_qualified_member_name(const SPIRType & type,uint32_t index)11987 string CompilerMSL::to_qualified_member_name(const SPIRType &type, uint32_t index)
11988 {
11989 	// Don't qualify Builtin names because they are unique and are treated as such when building expressions
11990 	BuiltIn builtin = BuiltInMax;
11991 	if (is_member_builtin(type, index, &builtin))
11992 		return builtin_to_glsl(builtin, type.storage);
11993 
11994 	// Strip any underscore prefix from member name
11995 	string mbr_name = to_member_name(type, index);
11996 	size_t startPos = mbr_name.find_first_not_of("_");
11997 	mbr_name = (startPos != string::npos) ? mbr_name.substr(startPos) : "";
11998 	return join(to_name(type.self), "_", mbr_name);
11999 }
12000 
12001 // Ensures that the specified name is permanently usable by prepending a prefix
12002 // if the first chars are _ and a digit, which indicate a transient name.
ensure_valid_name(string name,string pfx)12003 string CompilerMSL::ensure_valid_name(string name, string pfx)
12004 {
12005 	return (name.size() >= 2 && name[0] == '_' && isdigit(name[1])) ? (pfx + name) : name;
12006 }
12007 
const std::unordered_set<std::string> &CompilerMSL::get_reserved_keyword_set()
{
	// Identifiers that must not be used as names in generated MSL:
	// language keywords, sampler-argument keywords, Metal library macros,
	// and <climits>/<cfloat>/<cmath>-style constants defined by metal_stdlib.
	static const unordered_set<string> keywords = {
		// Shader stage / function qualifiers.
		"kernel",
		"vertex",
		"fragment",
		"compute",
		// Sampler and texture sampling argument keywords.
		"bias",
		"level",
		"gradient2d",
		"gradientcube",
		"gradient3d",
		"min_lod_clamp",
		"assert",
		// Metal internal macros.
		"VARIABLE_TRACEPOINT",
		"STATIC_DATA_TRACEPOINT",
		"STATIC_DATA_TRACEPOINT_V",
		"METAL_ALIGN",
		"METAL_ASM",
		"METAL_CONST",
		"METAL_DEPRECATED",
		"METAL_ENABLE_IF",
		"METAL_FUNC",
		"METAL_INTERNAL",
		"METAL_NON_NULL_RETURN",
		"METAL_NORETURN",
		"METAL_NOTHROW",
		"METAL_PURE",
		"METAL_UNAVAILABLE",
		"METAL_IMPLICIT",
		"METAL_EXPLICIT",
		"METAL_CONST_ARG",
		"METAL_ARG_UNIFORM",
		"METAL_ZERO_ARG",
		"METAL_VALID_LOD_ARG",
		"METAL_VALID_LEVEL_ARG",
		"METAL_VALID_STORE_ORDER",
		"METAL_VALID_LOAD_ORDER",
		"METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
		"METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
		"METAL_VALID_RENDER_TARGET",
		"is_function_constant_defined",
		// Integer limit constants.
		"CHAR_BIT",
		"SCHAR_MAX",
		"SCHAR_MIN",
		"UCHAR_MAX",
		"CHAR_MAX",
		"CHAR_MIN",
		"USHRT_MAX",
		"SHRT_MAX",
		"SHRT_MIN",
		"UINT_MAX",
		"INT_MAX",
		"INT_MIN",
		// Single-precision float constants.
		"FLT_DIG",
		"FLT_MANT_DIG",
		"FLT_MAX_10_EXP",
		"FLT_MAX_EXP",
		"FLT_MIN_10_EXP",
		"FLT_MIN_EXP",
		"FLT_RADIX",
		"FLT_MAX",
		"FLT_MIN",
		"FLT_EPSILON",
		"FP_ILOGB0",
		"FP_ILOGBNAN",
		"MAXFLOAT",
		"HUGE_VALF",
		"INFINITY",
		"NAN",
		"M_E_F",
		"M_LOG2E_F",
		"M_LOG10E_F",
		"M_LN2_F",
		"M_LN10_F",
		"M_PI_F",
		"M_PI_2_F",
		"M_PI_4_F",
		"M_1_PI_F",
		"M_2_PI_F",
		"M_2_SQRTPI_F",
		"M_SQRT2_F",
		"M_SQRT1_2_F",
		// Half-precision float constants.
		"HALF_DIG",
		"HALF_MANT_DIG",
		"HALF_MAX_10_EXP",
		"HALF_MAX_EXP",
		"HALF_MIN_10_EXP",
		"HALF_MIN_EXP",
		"HALF_RADIX",
		"HALF_MAX",
		"HALF_MIN",
		"HALF_EPSILON",
		"MAXHALF",
		"HUGE_VALH",
		"M_E_H",
		"M_LOG2E_H",
		"M_LOG10E_H",
		"M_LN2_H",
		"M_LN10_H",
		"M_PI_H",
		"M_PI_2_H",
		"M_PI_4_H",
		"M_1_PI_H",
		"M_2_PI_H",
		"M_2_SQRTPI_H",
		"M_SQRT2_H",
		"M_SQRT1_2_H",
		// Double-precision float constants.
		"DBL_DIG",
		"DBL_MANT_DIG",
		"DBL_MAX_10_EXP",
		"DBL_MAX_EXP",
		"DBL_MIN_10_EXP",
		"DBL_MIN_EXP",
		"DBL_RADIX",
		"DBL_MAX",
		"DBL_MIN",
		"DBL_EPSILON",
		"HUGE_VAL",
		"M_E",
		"M_LOG2E",
		"M_LOG10E",
		"M_LN2",
		"M_LN10",
		"M_PI",
		"M_PI_2",
		"M_PI_4",
		"M_1_PI",
		"M_2_PI",
		"M_2_SQRTPI",
		"M_SQRT2",
		"M_SQRT1_2",
		"quad_broadcast",
	};

	return keywords;
}
12145 
const std::unordered_set<std::string> &CompilerMSL::get_illegal_func_names()
{
	// Function names which must be renamed in generated MSL because they
	// collide with metal_stdlib functions, Metal macros, or reserved
	// entry-point names ("main" is reserved in MSL).
	static const unordered_set<string> illegal_func_names = {
		"main",
		"saturate",
		"assert",
		"fmin3",
		"fmax3",
		// Metal internal macros.
		"VARIABLE_TRACEPOINT",
		"STATIC_DATA_TRACEPOINT",
		"STATIC_DATA_TRACEPOINT_V",
		"METAL_ALIGN",
		"METAL_ASM",
		"METAL_CONST",
		"METAL_DEPRECATED",
		"METAL_ENABLE_IF",
		"METAL_FUNC",
		"METAL_INTERNAL",
		"METAL_NON_NULL_RETURN",
		"METAL_NORETURN",
		"METAL_NOTHROW",
		"METAL_PURE",
		"METAL_UNAVAILABLE",
		"METAL_IMPLICIT",
		"METAL_EXPLICIT",
		"METAL_CONST_ARG",
		"METAL_ARG_UNIFORM",
		"METAL_ZERO_ARG",
		"METAL_VALID_LOD_ARG",
		"METAL_VALID_LEVEL_ARG",
		"METAL_VALID_STORE_ORDER",
		"METAL_VALID_LOAD_ORDER",
		"METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
		"METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
		"METAL_VALID_RENDER_TARGET",
		"is_function_constant_defined",
		// Integer limit constants.
		"CHAR_BIT",
		"SCHAR_MAX",
		"SCHAR_MIN",
		"UCHAR_MAX",
		"CHAR_MAX",
		"CHAR_MIN",
		"USHRT_MAX",
		"SHRT_MAX",
		"SHRT_MIN",
		"UINT_MAX",
		"INT_MAX",
		"INT_MIN",
		// Single-precision float constants.
		"FLT_DIG",
		"FLT_MANT_DIG",
		"FLT_MAX_10_EXP",
		"FLT_MAX_EXP",
		"FLT_MIN_10_EXP",
		"FLT_MIN_EXP",
		"FLT_RADIX",
		"FLT_MAX",
		"FLT_MIN",
		"FLT_EPSILON",
		"FP_ILOGB0",
		"FP_ILOGBNAN",
		"MAXFLOAT",
		"HUGE_VALF",
		"INFINITY",
		"NAN",
		"M_E_F",
		"M_LOG2E_F",
		"M_LOG10E_F",
		"M_LN2_F",
		"M_LN10_F",
		"M_PI_F",
		"M_PI_2_F",
		"M_PI_4_F",
		"M_1_PI_F",
		"M_2_PI_F",
		"M_2_SQRTPI_F",
		"M_SQRT2_F",
		"M_SQRT1_2_F",
		// Half-precision float constants.
		"HALF_DIG",
		"HALF_MANT_DIG",
		"HALF_MAX_10_EXP",
		"HALF_MAX_EXP",
		"HALF_MIN_10_EXP",
		"HALF_MIN_EXP",
		"HALF_RADIX",
		"HALF_MAX",
		"HALF_MIN",
		"HALF_EPSILON",
		"MAXHALF",
		"HUGE_VALH",
		"M_E_H",
		"M_LOG2E_H",
		"M_LOG10E_H",
		"M_LN2_H",
		"M_LN10_H",
		"M_PI_H",
		"M_PI_2_H",
		"M_PI_4_H",
		"M_1_PI_H",
		"M_2_PI_H",
		"M_2_SQRTPI_H",
		"M_SQRT2_H",
		"M_SQRT1_2_H",
		// Double-precision float constants.
		"DBL_DIG",
		"DBL_MANT_DIG",
		"DBL_MAX_10_EXP",
		"DBL_MAX_EXP",
		"DBL_MIN_10_EXP",
		"DBL_MIN_EXP",
		"DBL_RADIX",
		"DBL_MAX",
		"DBL_MIN",
		"DBL_EPSILON",
		"HUGE_VAL",
		"M_E",
		"M_LOG2E",
		"M_LOG10E",
		"M_LN2",
		"M_LN10",
		"M_PI",
		"M_PI_2",
		"M_PI_4",
		"M_1_PI",
		"M_2_PI",
		"M_2_SQRTPI",
		"M_SQRT2",
		"M_SQRT1_2",
	};

	return illegal_func_names;
}
12276 
12277 // Replace all names that match MSL keywords or Metal Standard Library functions.
replace_illegal_names()12278 void CompilerMSL::replace_illegal_names()
12279 {
12280 	// FIXME: MSL and GLSL are doing two different things here.
12281 	// Agree on convention and remove this override.
12282 	auto &keywords = get_reserved_keyword_set();
12283 	auto &illegal_func_names = get_illegal_func_names();
12284 
12285 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &) {
12286 		auto *meta = ir.find_meta(self);
12287 		if (!meta)
12288 			return;
12289 
12290 		auto &dec = meta->decoration;
12291 		if (keywords.find(dec.alias) != end(keywords))
12292 			dec.alias += "0";
12293 	});
12294 
12295 	ir.for_each_typed_id<SPIRFunction>([&](uint32_t self, SPIRFunction &) {
12296 		auto *meta = ir.find_meta(self);
12297 		if (!meta)
12298 			return;
12299 
12300 		auto &dec = meta->decoration;
12301 		if (illegal_func_names.find(dec.alias) != end(illegal_func_names))
12302 			dec.alias += "0";
12303 	});
12304 
12305 	ir.for_each_typed_id<SPIRType>([&](uint32_t self, SPIRType &) {
12306 		auto *meta = ir.find_meta(self);
12307 		if (!meta)
12308 			return;
12309 
12310 		for (auto &mbr_dec : meta->members)
12311 			if (keywords.find(mbr_dec.alias) != end(keywords))
12312 				mbr_dec.alias += "0";
12313 	});
12314 
12315 	CompilerGLSL::replace_illegal_names();
12316 }
12317 
replace_illegal_entry_point_names()12318 void CompilerMSL::replace_illegal_entry_point_names()
12319 {
12320 	auto &illegal_func_names = get_illegal_func_names();
12321 
12322 	// It is important to this before we fixup identifiers,
12323 	// since if ep_name is reserved, we will need to fix that up,
12324 	// and then copy alias back into entry.name after the fixup.
12325 	for (auto &entry : ir.entry_points)
12326 	{
12327 		// Change both the entry point name and the alias, to keep them synced.
12328 		string &ep_name = entry.second.name;
12329 		if (illegal_func_names.find(ep_name) != end(illegal_func_names))
12330 			ep_name += "0";
12331 
12332 		ir.meta[entry.first].decoration.alias = ep_name;
12333 	}
12334 }
12335 
sync_entry_point_aliases_and_names()12336 void CompilerMSL::sync_entry_point_aliases_and_names()
12337 {
12338 	for (auto &entry : ir.entry_points)
12339 		entry.second.name = ir.meta[entry.first].decoration.alias;
12340 }
12341 
to_member_reference(uint32_t base,const SPIRType & type,uint32_t index,bool ptr_chain)12342 string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain)
12343 {
12344 	if (index < uint32_t(type.member_type_index_redirection.size()))
12345 		index = type.member_type_index_redirection[index];
12346 
12347 	auto *var = maybe_get<SPIRVariable>(base);
12348 	// If this is a buffer array, we have to dereference the buffer pointers.
12349 	// Otherwise, if this is a pointer expression, dereference it.
12350 
12351 	bool declared_as_pointer = false;
12352 
12353 	if (var)
12354 	{
12355 		// Only allow -> dereference for block types. This is so we get expressions like
12356 		// buffer[i]->first_member.second_member, rather than buffer[i]->first->second.
12357 		bool is_block = has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
12358 
12359 		bool is_buffer_variable =
12360 		    is_block && (var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer);
12361 		declared_as_pointer = is_buffer_variable && is_array(get<SPIRType>(var->basetype));
12362 	}
12363 
12364 	if (declared_as_pointer || (!ptr_chain && should_dereference(base)))
12365 		return join("->", to_member_name(type, index));
12366 	else
12367 		return join(".", to_member_name(type, index));
12368 }
12369 
to_qualifiers_glsl(uint32_t id)12370 string CompilerMSL::to_qualifiers_glsl(uint32_t id)
12371 {
12372 	string quals;
12373 
12374 	auto &type = expression_type(id);
12375 	if (type.storage == StorageClassWorkgroup)
12376 		quals += "threadgroup ";
12377 
12378 	return quals;
12379 }
12380 
12381 // The optional id parameter indicates the object whose type we are trying
12382 // to find the description for. It is optional. Most type descriptions do not
12383 // depend on a specific object's use of that type.
string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
{
	string type_name;

	// Pointer?
	if (type.pointer)
	{
		// Pointers are emitted as "<address space> <pointee type>", with "*"
		// appended for non-handle types.
		const char *restrict_kw;
		type_name = join(get_type_address_space(type, id), " ", type_to_glsl(get<SPIRType>(type.parent_type), id));

		switch (type.basetype)
		{
		case SPIRType::Image:
		case SPIRType::SampledImage:
		case SPIRType::Sampler:
			// These are handles.
			break;
		default:
			// Anything else can be a raw pointer.
			type_name += "*";
			restrict_kw = to_restrict(id);
			if (*restrict_kw)
			{
				type_name += " ";
				type_name += restrict_kw;
			}
			break;
		}
		return type_name;
	}

	switch (type.basetype)
	{
	case SPIRType::Struct:
		// Need OpName lookup here to get a "sensible" name for a struct.
		// Allow Metal to use the array<T> template to make arrays a value type
		type_name = to_name(type.self);
		break;

	case SPIRType::Image:
	case SPIRType::SampledImage:
		return image_type_glsl(type, id);

	case SPIRType::Sampler:
		return sampler_type(type, id);

	case SPIRType::Void:
		return "void";

	case SPIRType::AtomicCounter:
		return "atomic_uint";

	case SPIRType::ControlPointArray:
		return join("patch_control_point<", type_to_glsl(get<SPIRType>(type.parent_type), id), ">");

	case SPIRType::Interpolant:
		return join("interpolant<", type_to_glsl(get<SPIRType>(type.parent_type), id), ", interpolation::",
		            has_decoration(type.self, DecorationNoPerspective) ? "no_perspective" : "perspective", ">");

	// Scalars
	case SPIRType::Boolean:
		type_name = "bool";
		break;
	case SPIRType::Char:
	case SPIRType::SByte:
		type_name = "char";
		break;
	case SPIRType::UByte:
		type_name = "uchar";
		break;
	case SPIRType::Short:
		type_name = "short";
		break;
	case SPIRType::UShort:
		type_name = "ushort";
		break;
	case SPIRType::Int:
		type_name = "int";
		break;
	case SPIRType::UInt:
		type_name = "uint";
		break;
	case SPIRType::Int64:
		// 64-bit integer types only exist in MSL 2.2+.
		if (!msl_options.supports_msl_version(2, 2))
			SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
		type_name = "long";
		break;
	case SPIRType::UInt64:
		if (!msl_options.supports_msl_version(2, 2))
			SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
		type_name = "ulong";
		break;
	case SPIRType::Half:
		type_name = "half";
		break;
	case SPIRType::Float:
		type_name = "float";
		break;
	case SPIRType::Double:
		type_name = "double"; // Currently unsupported
		break;

	default:
		return "unknown_type";
	}

	// Matrix?
	if (type.columns > 1)
		type_name += to_string(type.columns) + "x";

	// Vector or Matrix?
	if (type.vecsize > 1)
		type_name += to_string(type.vecsize);

	if (type.array.empty() || using_builtin_array())
	{
		return type_name;
	}
	else
	{
		// Allow Metal to use the array<T> template to make arrays a value type
		add_spv_func_and_recompile(SPVFuncImplUnsafeArray);
		string res;
		string sizes;

		// Build nested spvUnsafeArray<...> wrappers for multi-dimensional arrays,
		// e.g. spvUnsafeArray<spvUnsafeArray<float, N>, M>.
		for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
		{
			res += "spvUnsafeArray<";
			sizes += ", ";
			sizes += to_array_size(type, i);
			sizes += ">";
		}

		res += type_name + sizes;
		return res;
	}
}
12521 
type_to_array_glsl(const SPIRType & type)12522 string CompilerMSL::type_to_array_glsl(const SPIRType &type)
12523 {
12524 	// Allow Metal to use the array<T> template to make arrays a value type
12525 	switch (type.basetype)
12526 	{
12527 	case SPIRType::AtomicCounter:
12528 	case SPIRType::ControlPointArray:
12529 	{
12530 		return CompilerGLSL::type_to_array_glsl(type);
12531 	}
12532 	default:
12533 	{
12534 		if (using_builtin_array())
12535 			return CompilerGLSL::type_to_array_glsl(type);
12536 		else
12537 			return "";
12538 	}
12539 	}
12540 }
12541 
12542 // Threadgroup arrays can't have a wrapper type
variable_decl(const SPIRVariable & variable)12543 std::string CompilerMSL::variable_decl(const SPIRVariable &variable)
12544 {
12545 	if (variable.storage == StorageClassWorkgroup)
12546 	{
12547 		is_using_builtin_array = true;
12548 	}
12549 	std::string expr = CompilerGLSL::variable_decl(variable);
12550 	if (variable.storage == StorageClassWorkgroup)
12551 	{
12552 		is_using_builtin_array = false;
12553 	}
12554 	return expr;
12555 }
12556 
12557 // GCC workaround of lambdas calling protected funcs
std::string CompilerMSL::variable_decl(const SPIRType &type, const std::string &name, uint32_t id)
{
	// Pure forwarding wrapper: re-exposes the protected base implementation
	// so lambdas in this class can call it (GCC workaround).
	return CompilerGLSL::variable_decl(type, name, id);
}
12562 
sampler_type(const SPIRType & type,uint32_t id)12563 std::string CompilerMSL::sampler_type(const SPIRType &type, uint32_t id)
12564 {
12565 	auto *var = maybe_get<SPIRVariable>(id);
12566 	if (var && var->basevariable)
12567 	{
12568 		// Check against the base variable, and not a fake ID which might have been generated for this variable.
12569 		id = var->basevariable;
12570 	}
12571 
12572 	if (!type.array.empty())
12573 	{
12574 		if (!msl_options.supports_msl_version(2))
12575 			SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of samplers.");
12576 
12577 		if (type.array.size() > 1)
12578 			SPIRV_CROSS_THROW("Arrays of arrays of samplers are not supported in MSL.");
12579 
12580 		// Arrays of samplers in MSL must be declared with a special array<T, N> syntax ala C++11 std::array.
12581 		// If we have a runtime array, it could be a variable-count descriptor set binding.
12582 		uint32_t array_size = to_array_size_literal(type);
12583 		if (array_size == 0)
12584 			array_size = get_resource_array_size(id);
12585 
12586 		if (array_size == 0)
12587 			SPIRV_CROSS_THROW("Unsized array of samplers is not supported in MSL.");
12588 
12589 		auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
12590 		return join("array<", sampler_type(parent, id), ", ", array_size, ">");
12591 	}
12592 	else
12593 		return "sampler";
12594 }
12595 
12596 // Returns an MSL string describing the SPIR-V image type
string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id)
{
	auto *var = maybe_get<SPIRVariable>(id);
	if (var && var->basevariable)
	{
		// For comparison images, check against the base variable,
		// and not the fake ID which might have been generated for this variable.
		id = var->basevariable;
	}

	if (!type.array.empty())
	{
		// Arrays of textures require MSL 2.0 (desktop) or MSL 1.2 (iOS).
		uint32_t major = 2, minor = 0;
		if (msl_options.is_ios())
		{
			major = 1;
			minor = 2;
		}
		if (!msl_options.supports_msl_version(major, minor))
		{
			if (msl_options.is_ios())
				SPIRV_CROSS_THROW("MSL 1.2 or greater is required for arrays of textures.");
			else
				SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of textures.");
		}

		if (type.array.size() > 1)
			SPIRV_CROSS_THROW("Arrays of arrays of textures are not supported in MSL.");

		// Arrays of images in MSL must be declared with a special array<T, N> syntax ala C++11 std::array.
		// If we have a runtime array, it could be a variable-count descriptor set binding.
		uint32_t array_size = to_array_size_literal(type);
		if (array_size == 0)
			array_size = get_resource_array_size(id);

		if (array_size == 0)
			SPIRV_CROSS_THROW("Unsized array of images is not supported in MSL.");

		auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
		return join("array<", image_type_glsl(parent, id), ", ", array_size, ">");
	}

	string img_type_name;

	// Bypass pointers because we need the real image struct
	auto &img_type = get<SPIRType>(type.self).image;
	if (image_is_comparison(type, id))
	{
		// Depth (comparison) texture types.
		switch (img_type.dim)
		{
		case Dim1D:
		case Dim2D:
			if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
			{
				// Use a native Metal 1D texture
				// (Metal has no 1D depth texture type, so emit a marker name.)
				img_type_name += "depth1d_unsupported_by_metal";
				break;
			}

			if (img_type.ms && img_type.arrayed)
			{
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
				img_type_name += "depth2d_ms_array";
			}
			else if (img_type.ms)
				img_type_name += "depth2d_ms";
			else if (img_type.arrayed)
				img_type_name += "depth2d_array";
			else
				img_type_name += "depth2d";
			break;
		case Dim3D:
			// Metal has no 3D depth texture type.
			img_type_name += "depth3d_unsupported_by_metal";
			break;
		case DimCube:
			if (!msl_options.emulate_cube_array)
				img_type_name += (img_type.arrayed ? "depthcube_array" : "depthcube");
			else
				img_type_name += (img_type.arrayed ? "depth2d_array" : "depthcube");
			break;
		default:
			img_type_name += "unknown_depth_texture_type";
			break;
		}
	}
	else
	{
		// Non-comparison texture types.
		switch (img_type.dim)
		{
		case DimBuffer:
			if (img_type.ms || img_type.arrayed)
				SPIRV_CROSS_THROW("Cannot use texel buffers with multisampling or array layers.");

			if (msl_options.texture_buffer_native)
			{
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Native texture_buffer type is only supported in MSL 2.1.");
				img_type_name = "texture_buffer";
			}
			else
				img_type_name += "texture2d";
			break;
		case Dim1D:
		case Dim2D:
		case DimSubpassData:
		{
			// Multiview or explicitly-arrayed subpass inputs are emitted as texture arrays.
			bool subpass_array =
			    img_type.dim == DimSubpassData && (msl_options.multiview || msl_options.arrayed_subpass_input);
			if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
			{
				// Use a native Metal 1D texture
				img_type_name += (img_type.arrayed ? "texture1d_array" : "texture1d");
				break;
			}

			// Use Metal's native frame-buffer fetch API for subpass inputs.
			if (type_is_msl_framebuffer_fetch(type))
			{
				auto img_type_4 = get<SPIRType>(img_type.type);
				img_type_4.vecsize = 4;
				return type_to_glsl(img_type_4);
			}
			if (img_type.ms && (img_type.arrayed || subpass_array))
			{
				if (!msl_options.supports_msl_version(2, 1))
					SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
				img_type_name += "texture2d_ms_array";
			}
			else if (img_type.ms)
				img_type_name += "texture2d_ms";
			else if (img_type.arrayed || subpass_array)
				img_type_name += "texture2d_array";
			else
				img_type_name += "texture2d";
			break;
		}
		case Dim3D:
			img_type_name += "texture3d";
			break;
		case DimCube:
			if (!msl_options.emulate_cube_array)
				img_type_name += (img_type.arrayed ? "texturecube_array" : "texturecube");
			else
				img_type_name += (img_type.arrayed ? "texture2d_array" : "texturecube");
			break;
		default:
			img_type_name += "unknown_texture_type";
			break;
		}
	}

	// Append the pixel type
	img_type_name += "<";
	img_type_name += type_to_glsl(get<SPIRType>(img_type.type));

	// For unsampled images, append the sample/read/write access qualifier.
	// For kernel images, the access qualifier may be supplied directly by SPIR-V.
	// Otherwise it may be set based on whether the image is read from or written to within the shader.
	if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData)
	{
		switch (img_type.access)
		{
		case AccessQualifierReadOnly:
			img_type_name += ", access::read";
			break;

		case AccessQualifierWriteOnly:
			img_type_name += ", access::write";
			break;

		case AccessQualifierReadWrite:
			img_type_name += ", access::read_write";
			break;

		default:
		{
			// No explicit qualifier in SPIR-V: derive access from the
			// NonWritable/NonReadable decorations on the backing variable.
			auto *p_var = maybe_get_backing_variable(id);
			if (p_var && p_var->basevariable)
				p_var = maybe_get<SPIRVariable>(p_var->basevariable);
			if (p_var && !has_decoration(p_var->self, DecorationNonWritable))
			{
				img_type_name += ", access::";

				if (!has_decoration(p_var->self, DecorationNonReadable))
					img_type_name += "read_";

				img_type_name += "write";
			}
			break;
		}
		}
	}

	img_type_name += ">";

	return img_type_name;
}
12795 
// Emits MSL for a SPIR-V subgroup (OpGroupNonUniform*) instruction, mapping it
// onto Metal simd-group/quad-group intrinsics or spv* helper wrappers.
void CompilerMSL::emit_subgroup_op(const Instruction &i)
{
	const uint32_t *ops = stream(i);
	auto op = static_cast<Op>(i.op);

	if (msl_options.emulate_subgroups)
	{
		// In this mode, only the GroupNonUniform cap is supported. The only op
		// we need to handle, then, is OpGroupNonUniformElect.
		if (op != OpGroupNonUniformElect)
			SPIRV_CROSS_THROW("Subgroup emulation does not support operations other than Elect.");
		// In this mode, the subgroup size is assumed to be one, so every invocation
		// is elected.
		emit_op(ops[0], ops[1], "true", true);
		return;
	}

	// Metal 2.0 is required. iOS only supports quad ops on 11.0 (2.0), with
	// full support in 13.0 (2.2). macOS only supports broadcast and shuffle on
	// 10.13 (2.0), with full support in 10.14 (2.1).
	// Note that Apple GPUs before A13 make no distinction between a quad-group
	// and a SIMD-group; all SIMD-groups are quad-groups on those.
	if (!msl_options.supports_msl_version(2))
		SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up.");

	// If we need to do implicit bitcasts, make sure we do it with the correct type.
	uint32_t integer_width = get_integer_width_for_instruction(i);
	auto int_type = to_signed_basetype(integer_width);
	auto uint_type = to_unsigned_basetype(integer_width);

	// Reject ops this iOS MSL version/configuration cannot express before
	// emitting anything. The default label is the rejection path; supported
	// ops fall through the labelled cases below.
	if (msl_options.is_ios() && (!msl_options.supports_msl_version(2, 3) || !msl_options.ios_use_simdgroup_functions))
	{
		switch (op)
		{
		default:
			SPIRV_CROSS_THROW("Subgroup ops beyond broadcast, ballot, and shuffle on iOS require Metal 2.3 and up.");
		case OpGroupNonUniformBroadcastFirst:
			if (!msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("BroadcastFirst on iOS requires Metal 2.2 and up.");
			break;
		case OpGroupNonUniformElect:
			if (!msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("Elect on iOS requires Metal 2.2 and up.");
			break;
		case OpGroupNonUniformAny:
		case OpGroupNonUniformAll:
		case OpGroupNonUniformAllEqual:
		case OpGroupNonUniformBallot:
		case OpGroupNonUniformInverseBallot:
		case OpGroupNonUniformBallotBitExtract:
		case OpGroupNonUniformBallotFindLSB:
		case OpGroupNonUniformBallotFindMSB:
		case OpGroupNonUniformBallotBitCount:
			if (!msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("Ballot ops on iOS requires Metal 2.2 and up.");
			break;
		case OpGroupNonUniformBroadcast:
		case OpGroupNonUniformShuffle:
		case OpGroupNonUniformShuffleXor:
		case OpGroupNonUniformShuffleUp:
		case OpGroupNonUniformShuffleDown:
		case OpGroupNonUniformQuadSwap:
		case OpGroupNonUniformQuadBroadcast:
			break;
		}
	}

	// Same gating for macOS: pre-2.1 only broadcast and shuffle are available.
	if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
	{
		switch (op)
		{
		default:
			SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.1 and up.");
		case OpGroupNonUniformBroadcast:
		case OpGroupNonUniformShuffle:
		case OpGroupNonUniformShuffleXor:
		case OpGroupNonUniformShuffleUp:
		case OpGroupNonUniformShuffleDown:
			break;
		}
	}

	// SPIR-V result layout common to these ops: ops[0] = result type id, ops[1] = result id.
	uint32_t result_type = ops[0];
	uint32_t id = ops[1];

	// ops[2] is the execution scope; only subgroup scope maps to Metal.
	auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
	if (scope != ScopeSubgroup)
		SPIRV_CROSS_THROW("Only subgroup scope is supported.");

	switch (op)
	{
	case OpGroupNonUniformElect:
		// On iOS without simdgroup functions, quad-group intrinsics stand in for
		// SIMD-group ones (pre-A13 GPUs treat them identically; see note above).
		if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
			emit_op(result_type, id, "quad_is_first()", false);
		else
			emit_op(result_type, id, "simd_is_first()", false);
		break;

	case OpGroupNonUniformBroadcast:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBroadcast");
		break;

	case OpGroupNonUniformBroadcastFirst:
		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBroadcastFirst");
		break;

	case OpGroupNonUniformBallot:
		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallot");
		break;

	case OpGroupNonUniformInverseBallot:
		// InverseBallot is BallotBitExtract at this invocation's own index.
		emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract");
		break;

	case OpGroupNonUniformBallotBitExtract:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBallotBitExtract");
		break;

	case OpGroupNonUniformBallotFindLSB:
		emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindLSB");
		break;

	case OpGroupNonUniformBallotFindMSB:
		emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindMSB");
		break;

	case OpGroupNonUniformBallotBitCount:
	{
		// ops[3] selects reduce vs. inclusive/exclusive scan; the ballot value is ops[4].
		auto operation = static_cast<GroupOperation>(ops[3]);
		switch (operation)
		{
		case GroupOperationReduce:
			emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_size_id, "spvSubgroupBallotBitCount");
			break;
		case GroupOperationInclusiveScan:
			emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
			                    "spvSubgroupBallotInclusiveBitCount");
			break;
		case GroupOperationExclusiveScan:
			emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
			                    "spvSubgroupBallotExclusiveBitCount");
			break;
		default:
			SPIRV_CROSS_THROW("Invalid BitCount operation.");
			break;
		}
		break;
	}

	case OpGroupNonUniformShuffle:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffle");
		break;

	case OpGroupNonUniformShuffleXor:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleXor");
		break;

	case OpGroupNonUniformShuffleUp:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleUp");
		break;

	case OpGroupNonUniformShuffleDown:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleDown");
		break;

	case OpGroupNonUniformAll:
		if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
			emit_unary_func_op(result_type, id, ops[3], "quad_all");
		else
			emit_unary_func_op(result_type, id, ops[3], "simd_all");
		break;

	case OpGroupNonUniformAny:
		if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
			emit_unary_func_op(result_type, id, ops[3], "quad_any");
		else
			emit_unary_func_op(result_type, id, ops[3], "simd_any");
		break;

	case OpGroupNonUniformAllEqual:
		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupAllEqual");
		break;

	// Arithmetic group operations: FAdd/FMul/IAdd/IMul support all of reduce,
	// inclusive scan, exclusive scan and quad clustered-reduce in MSL.
		// clang-format off
#define MSL_GROUP_OP(op, msl_op) \
case OpGroupNonUniform##op: \
	{ \
		auto operation = static_cast<GroupOperation>(ops[3]); \
		if (operation == GroupOperationReduce) \
			emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
		else if (operation == GroupOperationInclusiveScan) \
			emit_unary_func_op(result_type, id, ops[4], "simd_prefix_inclusive_" #msl_op); \
		else if (operation == GroupOperationExclusiveScan) \
			emit_unary_func_op(result_type, id, ops[4], "simd_prefix_exclusive_" #msl_op); \
		else if (operation == GroupOperationClusteredReduce) \
		{ \
			/* Only cluster sizes of 4 are supported. */ \
			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
			if (cluster_size != 4) \
				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
		} \
		else \
			SPIRV_CROSS_THROW("Invalid group operation."); \
		break; \
	}
	MSL_GROUP_OP(FAdd, sum)
	MSL_GROUP_OP(FMul, product)
	MSL_GROUP_OP(IAdd, sum)
	MSL_GROUP_OP(IMul, product)
#undef MSL_GROUP_OP
	// The others, unfortunately, don't support InclusiveScan or ExclusiveScan.

#define MSL_GROUP_OP(op, msl_op) \
case OpGroupNonUniform##op: \
	{ \
		auto operation = static_cast<GroupOperation>(ops[3]); \
		if (operation == GroupOperationReduce) \
			emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
		else if (operation == GroupOperationInclusiveScan) \
			SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
		else if (operation == GroupOperationExclusiveScan) \
			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
		else if (operation == GroupOperationClusteredReduce) \
		{ \
			/* Only cluster sizes of 4 are supported. */ \
			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
			if (cluster_size != 4) \
				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
		} \
		else \
			SPIRV_CROSS_THROW("Invalid group operation."); \
		break; \
	}

	// Same as MSL_GROUP_OP, but bitcasts the operand/result to a specific
	// signedness first, since MSL min/max overloads are signedness-sensitive.
#define MSL_GROUP_OP_CAST(op, msl_op, type) \
case OpGroupNonUniform##op: \
	{ \
		auto operation = static_cast<GroupOperation>(ops[3]); \
		if (operation == GroupOperationReduce) \
			emit_unary_func_op_cast(result_type, id, ops[4], "simd_" #msl_op, type, type); \
		else if (operation == GroupOperationInclusiveScan) \
			SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
		else if (operation == GroupOperationExclusiveScan) \
			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
		else if (operation == GroupOperationClusteredReduce) \
		{ \
			/* Only cluster sizes of 4 are supported. */ \
			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
			if (cluster_size != 4) \
				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
			emit_unary_func_op_cast(result_type, id, ops[4], "quad_" #msl_op, type, type); \
		} \
		else \
			SPIRV_CROSS_THROW("Invalid group operation."); \
		break; \
	}

	MSL_GROUP_OP(FMin, min)
	MSL_GROUP_OP(FMax, max)
	MSL_GROUP_OP_CAST(SMin, min, int_type)
	MSL_GROUP_OP_CAST(SMax, max, int_type)
	MSL_GROUP_OP_CAST(UMin, min, uint_type)
	MSL_GROUP_OP_CAST(UMax, max, uint_type)
	MSL_GROUP_OP(BitwiseAnd, and)
	MSL_GROUP_OP(BitwiseOr, or)
	MSL_GROUP_OP(BitwiseXor, xor)
	MSL_GROUP_OP(LogicalAnd, and)
	MSL_GROUP_OP(LogicalOr, or)
	MSL_GROUP_OP(LogicalXor, xor)
		// clang-format on
#undef MSL_GROUP_OP
#undef MSL_GROUP_OP_CAST

	case OpGroupNonUniformQuadSwap:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvQuadSwap");
		break;

	case OpGroupNonUniformQuadBroadcast:
		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvQuadBroadcast");
		break;

	default:
		SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
	}

	// Subgroup results depend on which invocations are active; mark the
	// expression control-dependent so it is not hoisted across control flow.
	register_control_dependent_expression(id);
}
13085 
bitcast_glsl_op(const SPIRType & out_type,const SPIRType & in_type)13086 string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
13087 {
13088 	if (out_type.basetype == in_type.basetype)
13089 		return "";
13090 
13091 	assert(out_type.basetype != SPIRType::Boolean);
13092 	assert(in_type.basetype != SPIRType::Boolean);
13093 
13094 	bool integral_cast = type_is_integral(out_type) && type_is_integral(in_type);
13095 	bool same_size_cast = out_type.width == in_type.width;
13096 
13097 	if (integral_cast && same_size_cast)
13098 	{
13099 		// Trivial bitcast case, casts between integers.
13100 		return type_to_glsl(out_type);
13101 	}
13102 	else
13103 	{
13104 		// Fall back to the catch-all bitcast in MSL.
13105 		return "as_type<" + type_to_glsl(out_type) + ">";
13106 	}
13107 }
13108 
// MSL never needs multi-statement bitcast expansion: bitcast_glsl_op() covers
// every reinterpretation via as_type<>. Returning false here presumably tells
// the caller to fall back to that generic single-expression path.
bool CompilerMSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
{
	return false;
}
13113 
13114 // Returns an MSL string identifying the name of a SPIR-V builtin.
13115 // Output builtins are qualified with the name of the stage out structure.
builtin_to_glsl(BuiltIn builtin,StorageClass storage)13116 string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
13117 {
13118 	switch (builtin)
13119 	{
13120 
13121 	// Handle HLSL-style 0-based vertex/instance index.
13122 	// Override GLSL compiler strictness
13123 	case BuiltInVertexId:
13124 		ensure_builtin(StorageClassInput, BuiltInVertexId);
13125 		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13126 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13127 		{
13128 			if (builtin_declaration)
13129 			{
13130 				if (needs_base_vertex_arg != TriState::No)
13131 					needs_base_vertex_arg = TriState::Yes;
13132 				return "gl_VertexID";
13133 			}
13134 			else
13135 			{
13136 				ensure_builtin(StorageClassInput, BuiltInBaseVertex);
13137 				return "(gl_VertexID - gl_BaseVertex)";
13138 			}
13139 		}
13140 		else
13141 		{
13142 			return "gl_VertexID";
13143 		}
13144 	case BuiltInInstanceId:
13145 		ensure_builtin(StorageClassInput, BuiltInInstanceId);
13146 		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13147 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13148 		{
13149 			if (builtin_declaration)
13150 			{
13151 				if (needs_base_instance_arg != TriState::No)
13152 					needs_base_instance_arg = TriState::Yes;
13153 				return "gl_InstanceID";
13154 			}
13155 			else
13156 			{
13157 				ensure_builtin(StorageClassInput, BuiltInBaseInstance);
13158 				return "(gl_InstanceID - gl_BaseInstance)";
13159 			}
13160 		}
13161 		else
13162 		{
13163 			return "gl_InstanceID";
13164 		}
13165 	case BuiltInVertexIndex:
13166 		ensure_builtin(StorageClassInput, BuiltInVertexIndex);
13167 		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13168 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13169 		{
13170 			if (builtin_declaration)
13171 			{
13172 				if (needs_base_vertex_arg != TriState::No)
13173 					needs_base_vertex_arg = TriState::Yes;
13174 				return "gl_VertexIndex";
13175 			}
13176 			else
13177 			{
13178 				ensure_builtin(StorageClassInput, BuiltInBaseVertex);
13179 				return "(gl_VertexIndex - gl_BaseVertex)";
13180 			}
13181 		}
13182 		else
13183 		{
13184 			return "gl_VertexIndex";
13185 		}
13186 	case BuiltInInstanceIndex:
13187 		ensure_builtin(StorageClassInput, BuiltInInstanceIndex);
13188 		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13189 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13190 		{
13191 			if (builtin_declaration)
13192 			{
13193 				if (needs_base_instance_arg != TriState::No)
13194 					needs_base_instance_arg = TriState::Yes;
13195 				return "gl_InstanceIndex";
13196 			}
13197 			else
13198 			{
13199 				ensure_builtin(StorageClassInput, BuiltInBaseInstance);
13200 				return "(gl_InstanceIndex - gl_BaseInstance)";
13201 			}
13202 		}
13203 		else
13204 		{
13205 			return "gl_InstanceIndex";
13206 		}
13207 	case BuiltInBaseVertex:
13208 		if (msl_options.supports_msl_version(1, 1) &&
13209 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13210 		{
13211 			needs_base_vertex_arg = TriState::No;
13212 			return "gl_BaseVertex";
13213 		}
13214 		else
13215 		{
13216 			SPIRV_CROSS_THROW("BaseVertex requires Metal 1.1 and Mac or Apple A9+ hardware.");
13217 		}
13218 	case BuiltInBaseInstance:
13219 		if (msl_options.supports_msl_version(1, 1) &&
13220 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13221 		{
13222 			needs_base_instance_arg = TriState::No;
13223 			return "gl_BaseInstance";
13224 		}
13225 		else
13226 		{
13227 			SPIRV_CROSS_THROW("BaseInstance requires Metal 1.1 and Mac or Apple A9+ hardware.");
13228 		}
13229 	case BuiltInDrawIndex:
13230 		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
13231 
13232 	// When used in the entry function, output builtins are qualified with output struct name.
13233 	// Test storage class as NOT Input, as output builtins might be part of generic type.
13234 	// Also don't do this for tessellation control shaders.
13235 	case BuiltInViewportIndex:
13236 		if (!msl_options.supports_msl_version(2, 0))
13237 			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
13238 		/* fallthrough */
13239 	case BuiltInFragDepth:
13240 	case BuiltInFragStencilRefEXT:
13241 		if ((builtin == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
13242 		    (builtin == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))
13243 			break;
13244 		/* fallthrough */
13245 	case BuiltInPosition:
13246 	case BuiltInPointSize:
13247 	case BuiltInClipDistance:
13248 	case BuiltInCullDistance:
13249 	case BuiltInLayer:
13250 	case BuiltInSampleMask:
13251 		if (get_execution_model() == ExecutionModelTessellationControl)
13252 			break;
13253 		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
13254 			return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
13255 
13256 		break;
13257 
13258 	case BuiltInBaryCoordNV:
13259 	case BuiltInBaryCoordNoPerspNV:
13260 		if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
13261 			return stage_in_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
13262 		break;
13263 
13264 	case BuiltInTessLevelOuter:
13265 		if (get_execution_model() == ExecutionModelTessellationEvaluation)
13266 		{
13267 			if (storage != StorageClassOutput && !get_entry_point().flags.get(ExecutionModeTriangles) &&
13268 			    current_function && (current_function->self == ir.default_entry_point))
13269 				return join(patch_stage_in_var_name, ".", CompilerGLSL::builtin_to_glsl(builtin, storage));
13270 			else
13271 				break;
13272 		}
13273 		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
13274 			return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
13275 			            "].edgeTessellationFactor");
13276 		break;
13277 
13278 	case BuiltInTessLevelInner:
13279 		if (get_execution_model() == ExecutionModelTessellationEvaluation)
13280 		{
13281 			if (storage != StorageClassOutput && !get_entry_point().flags.get(ExecutionModeTriangles) &&
13282 			    current_function && (current_function->self == ir.default_entry_point))
13283 				return join(patch_stage_in_var_name, ".", CompilerGLSL::builtin_to_glsl(builtin, storage));
13284 			else
13285 				break;
13286 		}
13287 		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
13288 			return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
13289 			            "].insideTessellationFactor");
13290 		break;
13291 
13292 	default:
13293 		break;
13294 	}
13295 
13296 	return CompilerGLSL::builtin_to_glsl(builtin, storage);
13297 }
13298 
// Returns an MSL string attribute qualifier for a SPIR-V builtin, i.e. the
// text placed inside the [[...]] attribute of a declaration.
string CompilerMSL::builtin_qualifier(BuiltIn builtin)
{
	auto &execution = get_entry_point();

	switch (builtin)
	{
	// Vertex function in
	case BuiltInVertexId:
		return "vertex_id";
	case BuiltInVertexIndex:
		return "vertex_id";
	case BuiltInBaseVertex:
		return "base_vertex";
	case BuiltInInstanceId:
		return "instance_id";
	case BuiltInInstanceIndex:
		return "instance_id";
	case BuiltInBaseInstance:
		return "base_instance";
	case BuiltInDrawIndex:
		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");

	// Vertex function out
	case BuiltInClipDistance:
		return "clip_distance";
	case BuiltInPointSize:
		return "point_size";
	case BuiltInPosition:
		if (position_invariant)
		{
			if (!msl_options.supports_msl_version(2, 1))
				SPIRV_CROSS_THROW("Invariant position is only supported on MSL 2.1 and up.");
			return "position, invariant";
		}
		else
			return "position";
	case BuiltInLayer:
		return "render_target_array_index";
	case BuiltInViewportIndex:
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
		return "viewport_array_index";

	// Tess. control function in
	case BuiltInInvocationId:
		if (msl_options.multi_patch_workgroup)
		{
			// Shouldn't be reached.
			SPIRV_CROSS_THROW("InvocationId is computed manually with multi-patch workgroups in MSL.");
		}
		return "thread_index_in_threadgroup";
	case BuiltInPatchVertices:
		// Shouldn't be reached.
		SPIRV_CROSS_THROW("PatchVertices is derived from the auxiliary buffer in MSL.");
	case BuiltInPrimitiveId:
		// The qualifier depends on which stage is consuming the primitive ID.
		switch (execution.model)
		{
		case ExecutionModelTessellationControl:
			if (msl_options.multi_patch_workgroup)
			{
				// Shouldn't be reached.
				SPIRV_CROSS_THROW("PrimitiveId is computed manually with multi-patch workgroups in MSL.");
			}
			return "threadgroup_position_in_grid";
		case ExecutionModelTessellationEvaluation:
			return "patch_id";
		case ExecutionModelFragment:
			if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
				SPIRV_CROSS_THROW("PrimitiveId on iOS requires MSL 2.3.");
			else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2.");
			return "primitive_id";
		default:
			SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model.");
		}

	// Tess. control function out
	case BuiltInTessLevelOuter:
	case BuiltInTessLevelInner:
		// Shouldn't be reached.
		SPIRV_CROSS_THROW("Tessellation levels are handled specially in MSL.");

	// Tess. evaluation function in
	case BuiltInTessCoord:
		return "position_in_patch";

	// Fragment function in
	case BuiltInFrontFacing:
		return "front_facing";
	case BuiltInPointCoord:
		return "point_coord";
	case BuiltInFragCoord:
		return "position";
	case BuiltInSampleId:
		return "sample_id";
	case BuiltInSampleMask:
		return "sample_mask";
	case BuiltInSamplePosition:
		// Shouldn't be reached.
		SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL.");
	case BuiltInViewIndex:
		if (execution.model != ExecutionModelFragment)
			SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders.");
		// The ViewIndex was implicitly used in the prior stages to set the render_target_array_index,
		// so we can get it from there.
		return "render_target_array_index";

	// Fragment function out
	case BuiltInFragDepth:
		// Propagate the depth-bound execution modes into the qualifier so Metal
		// can keep early depth optimizations where possible.
		if (execution.flags.get(ExecutionModeDepthGreater))
			return "depth(greater)";
		else if (execution.flags.get(ExecutionModeDepthLess))
			return "depth(less)";
		else
			return "depth(any)";

	case BuiltInFragStencilRefEXT:
		return "stencil";

	// Compute function in
	case BuiltInGlobalInvocationId:
		return "thread_position_in_grid";

	case BuiltInWorkgroupId:
		return "threadgroup_position_in_grid";

	case BuiltInNumWorkgroups:
		return "threadgroups_per_grid";

	case BuiltInLocalInvocationId:
		return "thread_position_in_threadgroup";

	case BuiltInLocalInvocationIndex:
		return "thread_index_in_threadgroup";

	case BuiltInSubgroupSize:
		if (msl_options.emulate_subgroups || msl_options.fixed_subgroup_size != 0)
			// Shouldn't be reached.
			SPIRV_CROSS_THROW("Emitting threads_per_simdgroup attribute with fixed subgroup size??");
		if (execution.model == ExecutionModelFragment)
		{
			if (!msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("threads_per_simdgroup requires Metal 2.2 in fragment shaders.");
			return "threads_per_simdgroup";
		}
		else
		{
			// thread_execution_width is an alias for threads_per_simdgroup, and it's only available since 1.0,
			// but not in fragment.
			return "thread_execution_width";
		}

	case BuiltInNumSubgroups:
		if (msl_options.emulate_subgroups)
			// Shouldn't be reached.
			SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation.");
		if (!msl_options.supports_msl_version(2))
			SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
		return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";

	case BuiltInSubgroupId:
		if (msl_options.emulate_subgroups)
			// Shouldn't be reached.
			SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation.");
		if (!msl_options.supports_msl_version(2))
			SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
		return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";

	case BuiltInSubgroupLocalInvocationId:
		if (msl_options.emulate_subgroups)
			// Shouldn't be reached.
			SPIRV_CROSS_THROW("SubgroupLocalInvocationId is handled specially with emulation.");
		if (execution.model == ExecutionModelFragment)
		{
			if (!msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("thread_index_in_simdgroup requires Metal 2.2 in fragment shaders.");
			return "thread_index_in_simdgroup";
		}
		else if (execution.model == ExecutionModelKernel || execution.model == ExecutionModelGLCompute ||
		         execution.model == ExecutionModelTessellationControl ||
		         (execution.model == ExecutionModelVertex && msl_options.vertex_for_tessellation))
		{
			// We are generating a Metal kernel function.
			if (!msl_options.supports_msl_version(2))
				SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0.");
			return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
		}
		else
			SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function.");

	case BuiltInSubgroupEqMask:
	case BuiltInSubgroupGeMask:
	case BuiltInSubgroupGtMask:
	case BuiltInSubgroupLeMask:
	case BuiltInSubgroupLtMask:
		// Shouldn't be reached.
		SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL.");

	case BuiltInBaryCoordNV:
		// TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
		if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS.");
		else if (!msl_options.supports_msl_version(2, 2))
			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
		return "barycentric_coord, center_perspective";

	case BuiltInBaryCoordNoPerspNV:
		// TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
		if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS.");
		else if (!msl_options.supports_msl_version(2, 2))
			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
		return "barycentric_coord, center_no_perspective";

	default:
		return "unsupported-built-in";
	}
}
13518 
13519 // Returns an MSL string type declaration for a SPIR-V builtin
builtin_type_decl(BuiltIn builtin,uint32_t id)13520 string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id)
13521 {
13522 	const SPIREntryPoint &execution = get_entry_point();
13523 	switch (builtin)
13524 	{
13525 	// Vertex function in
13526 	case BuiltInVertexId:
13527 		return "uint";
13528 	case BuiltInVertexIndex:
13529 		return "uint";
13530 	case BuiltInBaseVertex:
13531 		return "uint";
13532 	case BuiltInInstanceId:
13533 		return "uint";
13534 	case BuiltInInstanceIndex:
13535 		return "uint";
13536 	case BuiltInBaseInstance:
13537 		return "uint";
13538 	case BuiltInDrawIndex:
13539 		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
13540 
13541 	// Vertex function out
13542 	case BuiltInClipDistance:
13543 		return "float";
13544 	case BuiltInPointSize:
13545 		return "float";
13546 	case BuiltInPosition:
13547 		return "float4";
13548 	case BuiltInLayer:
13549 		return "uint";
13550 	case BuiltInViewportIndex:
13551 		if (!msl_options.supports_msl_version(2, 0))
13552 			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
13553 		return "uint";
13554 
13555 	// Tess. control function in
13556 	case BuiltInInvocationId:
13557 		return "uint";
13558 	case BuiltInPatchVertices:
13559 		return "uint";
13560 	case BuiltInPrimitiveId:
13561 		return "uint";
13562 
13563 	// Tess. control function out
13564 	case BuiltInTessLevelInner:
13565 		if (execution.model == ExecutionModelTessellationEvaluation)
13566 			return !execution.flags.get(ExecutionModeTriangles) ? "float2" : "float";
13567 		return "half";
13568 	case BuiltInTessLevelOuter:
13569 		if (execution.model == ExecutionModelTessellationEvaluation)
13570 			return !execution.flags.get(ExecutionModeTriangles) ? "float4" : "float";
13571 		return "half";
13572 
13573 	// Tess. evaluation function in
13574 	case BuiltInTessCoord:
13575 		return execution.flags.get(ExecutionModeTriangles) ? "float3" : "float2";
13576 
13577 	// Fragment function in
13578 	case BuiltInFrontFacing:
13579 		return "bool";
13580 	case BuiltInPointCoord:
13581 		return "float2";
13582 	case BuiltInFragCoord:
13583 		return "float4";
13584 	case BuiltInSampleId:
13585 		return "uint";
13586 	case BuiltInSampleMask:
13587 		return "uint";
13588 	case BuiltInSamplePosition:
13589 		return "float2";
13590 	case BuiltInViewIndex:
13591 		return "uint";
13592 
13593 	case BuiltInHelperInvocation:
13594 		return "bool";
13595 
13596 	case BuiltInBaryCoordNV:
13597 	case BuiltInBaryCoordNoPerspNV:
13598 		// Use the type as declared, can be 1, 2 or 3 components.
13599 		return type_to_glsl(get_variable_data_type(get<SPIRVariable>(id)));
13600 
13601 	// Fragment function out
13602 	case BuiltInFragDepth:
13603 		return "float";
13604 
13605 	case BuiltInFragStencilRefEXT:
13606 		return "uint";
13607 
13608 	// Compute function in
13609 	case BuiltInGlobalInvocationId:
13610 	case BuiltInLocalInvocationId:
13611 	case BuiltInNumWorkgroups:
13612 	case BuiltInWorkgroupId:
13613 		return "uint3";
13614 	case BuiltInLocalInvocationIndex:
13615 	case BuiltInNumSubgroups:
13616 	case BuiltInSubgroupId:
13617 	case BuiltInSubgroupSize:
13618 	case BuiltInSubgroupLocalInvocationId:
13619 		return "uint";
13620 	case BuiltInSubgroupEqMask:
13621 	case BuiltInSubgroupGeMask:
13622 	case BuiltInSubgroupGtMask:
13623 	case BuiltInSubgroupLeMask:
13624 	case BuiltInSubgroupLtMask:
13625 		return "uint4";
13626 
13627 	case BuiltInDeviceIndex:
13628 		return "int";
13629 
13630 	default:
13631 		return "unsupported-built-in-type";
13632 	}
13633 }
13634 
13635 // Returns the declaration of a built-in argument to a function
built_in_func_arg(BuiltIn builtin,bool prefix_comma)13636 string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma)
13637 {
13638 	string bi_arg;
13639 	if (prefix_comma)
13640 		bi_arg += ", ";
13641 
13642 	// Handle HLSL-style 0-based vertex/instance index.
13643 	builtin_declaration = true;
13644 	bi_arg += builtin_type_decl(builtin);
13645 	bi_arg += " " + builtin_to_glsl(builtin, StorageClassInput);
13646 	bi_arg += " [[" + builtin_qualifier(builtin) + "]]";
13647 	builtin_declaration = false;
13648 
13649 	return bi_arg;
13650 }
13651 
// Resolves struct member 'index' to its physical (storage) type: members that
// were remapped to a packed/physical representation resolve through the
// recorded physical type ID, all others through their declared member type.
const SPIRType &CompilerMSL::get_physical_member_type(const SPIRType &type, uint32_t index) const
{
	uint32_t type_id = member_is_remapped_physical_type(type, index) ?
	                       get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID) :
	                       uint32_t(type.member_types[index]);
	return get<SPIRType>(type_id);
}
13659 
SPIRType CompilerMSL::get_presumed_input_type(const SPIRType &ib_type, uint32_t index) const
{
	// Returns the physical type of a stage-input member, widened if the API
	// declared a larger vector size for that location than the shader consumes.
	SPIRType type = get_physical_member_type(ib_type, index);
	uint32_t loc = get_member_decoration(ib_type.self, index, DecorationLocation);
	// Single lookup instead of count() followed by two at() calls.
	auto itr = inputs_by_location.find(loc);
	if (itr != inputs_by_location.end() && itr->second.vecsize > type.vecsize)
		type.vecsize = itr->second.vecsize;
	return type;
}
13671 
get_declared_type_array_stride_msl(const SPIRType & type,bool is_packed,bool row_major) const13672 uint32_t CompilerMSL::get_declared_type_array_stride_msl(const SPIRType &type, bool is_packed, bool row_major) const
13673 {
13674 	// Array stride in MSL is always size * array_size. sizeof(float3) == 16,
13675 	// unlike GLSL and HLSL where array stride would be 16 and size 12.
13676 
13677 	// We could use parent type here and recurse, but that makes creating physical type remappings
13678 	// far more complicated. We'd rather just create the final type, and ignore having to create the entire type
13679 	// hierarchy in order to compute this value, so make a temporary type on the stack.
13680 
13681 	auto basic_type = type;
13682 	basic_type.array.clear();
13683 	basic_type.array_size_literal.clear();
13684 	uint32_t value_size = get_declared_type_size_msl(basic_type, is_packed, row_major);
13685 
13686 	uint32_t dimensions = uint32_t(type.array.size());
13687 	assert(dimensions > 0);
13688 	dimensions--;
13689 
13690 	// Multiply together every dimension, except the last one.
13691 	for (uint32_t dim = 0; dim < dimensions; dim++)
13692 	{
13693 		uint32_t array_size = to_array_size_literal(type, dim);
13694 		value_size *= max(array_size, 1u);
13695 	}
13696 
13697 	return value_size;
13698 }
13699 
get_declared_struct_member_array_stride_msl(const SPIRType & type,uint32_t index) const13700 uint32_t CompilerMSL::get_declared_struct_member_array_stride_msl(const SPIRType &type, uint32_t index) const
13701 {
13702 	return get_declared_type_array_stride_msl(get_physical_member_type(type, index),
13703 	                                          member_is_packed_physical_type(type, index),
13704 	                                          has_member_decoration(type.self, index, DecorationRowMajor));
13705 }
13706 
get_declared_input_array_stride_msl(const SPIRType & type,uint32_t index) const13707 uint32_t CompilerMSL::get_declared_input_array_stride_msl(const SPIRType &type, uint32_t index) const
13708 {
13709 	return get_declared_type_array_stride_msl(get_presumed_input_type(type, index), false,
13710 	                                          has_member_decoration(type.self, index, DecorationRowMajor));
13711 }
13712 
get_declared_type_matrix_stride_msl(const SPIRType & type,bool packed,bool row_major) const13713 uint32_t CompilerMSL::get_declared_type_matrix_stride_msl(const SPIRType &type, bool packed, bool row_major) const
13714 {
13715 	// For packed matrices, we just use the size of the vector type.
13716 	// Otherwise, MatrixStride == alignment, which is the size of the underlying vector type.
13717 	if (packed)
13718 		return (type.width / 8) * ((row_major && type.columns > 1) ? type.columns : type.vecsize);
13719 	else
13720 		return get_declared_type_alignment_msl(type, false, row_major);
13721 }
13722 
get_declared_struct_member_matrix_stride_msl(const SPIRType & type,uint32_t index) const13723 uint32_t CompilerMSL::get_declared_struct_member_matrix_stride_msl(const SPIRType &type, uint32_t index) const
13724 {
13725 	return get_declared_type_matrix_stride_msl(get_physical_member_type(type, index),
13726 	                                           member_is_packed_physical_type(type, index),
13727 	                                           has_member_decoration(type.self, index, DecorationRowMajor));
13728 }
13729 
get_declared_input_matrix_stride_msl(const SPIRType & type,uint32_t index) const13730 uint32_t CompilerMSL::get_declared_input_matrix_stride_msl(const SPIRType &type, uint32_t index) const
13731 {
13732 	return get_declared_type_matrix_stride_msl(get_presumed_input_type(type, index), false,
13733 	                                           has_member_decoration(type.self, index, DecorationRowMajor));
13734 }
13735 
get_declared_struct_size_msl(const SPIRType & struct_type,bool ignore_alignment,bool ignore_padding) const13736 uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type, bool ignore_alignment,
13737                                                    bool ignore_padding) const
13738 {
13739 	// If we have a target size, that is the declared size as well.
13740 	if (!ignore_padding && has_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget))
13741 		return get_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget);
13742 
13743 	if (struct_type.member_types.empty())
13744 		return 0;
13745 
13746 	uint32_t mbr_cnt = uint32_t(struct_type.member_types.size());
13747 
13748 	// In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
13749 	uint32_t alignment = 1;
13750 
13751 	if (!ignore_alignment)
13752 	{
13753 		for (uint32_t i = 0; i < mbr_cnt; i++)
13754 		{
13755 			uint32_t mbr_alignment = get_declared_struct_member_alignment_msl(struct_type, i);
13756 			alignment = max(alignment, mbr_alignment);
13757 		}
13758 	}
13759 
13760 	// Last member will always be matched to the final Offset decoration, but size of struct in MSL now depends
13761 	// on physical size in MSL, and the size of the struct itself is then aligned to struct alignment.
13762 	uint32_t spirv_offset = type_struct_member_offset(struct_type, mbr_cnt - 1);
13763 	uint32_t msl_size = spirv_offset + get_declared_struct_member_size_msl(struct_type, mbr_cnt - 1);
13764 	msl_size = (msl_size + alignment - 1) & ~(alignment - 1);
13765 	return msl_size;
13766 }
13767 
13768 // Returns the byte size of a struct member.
get_declared_type_size_msl(const SPIRType & type,bool is_packed,bool row_major) const13769 uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
13770 {
13771 	switch (type.basetype)
13772 	{
13773 	case SPIRType::Unknown:
13774 	case SPIRType::Void:
13775 	case SPIRType::AtomicCounter:
13776 	case SPIRType::Image:
13777 	case SPIRType::SampledImage:
13778 	case SPIRType::Sampler:
13779 		SPIRV_CROSS_THROW("Querying size of opaque object.");
13780 
13781 	default:
13782 	{
13783 		if (!type.array.empty())
13784 		{
13785 			uint32_t array_size = to_array_size_literal(type);
13786 			return get_declared_type_array_stride_msl(type, is_packed, row_major) * max(array_size, 1u);
13787 		}
13788 
13789 		if (type.basetype == SPIRType::Struct)
13790 			return get_declared_struct_size_msl(type);
13791 
13792 		if (is_packed)
13793 		{
13794 			return type.vecsize * type.columns * (type.width / 8);
13795 		}
13796 		else
13797 		{
13798 			// An unpacked 3-element vector or matrix column is the same memory size as a 4-element.
13799 			uint32_t vecsize = type.vecsize;
13800 			uint32_t columns = type.columns;
13801 
13802 			if (row_major && columns > 1)
13803 				swap(vecsize, columns);
13804 
13805 			if (vecsize == 3)
13806 				vecsize = 4;
13807 
13808 			return vecsize * columns * (type.width / 8);
13809 		}
13810 	}
13811 	}
13812 }
13813 
get_declared_struct_member_size_msl(const SPIRType & type,uint32_t index) const13814 uint32_t CompilerMSL::get_declared_struct_member_size_msl(const SPIRType &type, uint32_t index) const
13815 {
13816 	return get_declared_type_size_msl(get_physical_member_type(type, index),
13817 	                                  member_is_packed_physical_type(type, index),
13818 	                                  has_member_decoration(type.self, index, DecorationRowMajor));
13819 }
13820 
get_declared_input_size_msl(const SPIRType & type,uint32_t index) const13821 uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t index) const
13822 {
13823 	return get_declared_type_size_msl(get_presumed_input_type(type, index), false,
13824 	                                  has_member_decoration(type.self, index, DecorationRowMajor));
13825 }
13826 
13827 // Returns the byte alignment of a type.
get_declared_type_alignment_msl(const SPIRType & type,bool is_packed,bool row_major) const13828 uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const
13829 {
13830 	switch (type.basetype)
13831 	{
13832 	case SPIRType::Unknown:
13833 	case SPIRType::Void:
13834 	case SPIRType::AtomicCounter:
13835 	case SPIRType::Image:
13836 	case SPIRType::SampledImage:
13837 	case SPIRType::Sampler:
13838 		SPIRV_CROSS_THROW("Querying alignment of opaque object.");
13839 
13840 	case SPIRType::Int64:
13841 		SPIRV_CROSS_THROW("long types are not supported in buffers in MSL.");
13842 	case SPIRType::UInt64:
13843 		SPIRV_CROSS_THROW("ulong types are not supported in buffers in MSL.");
13844 	case SPIRType::Double:
13845 		SPIRV_CROSS_THROW("double types are not supported in buffers in MSL.");
13846 
13847 	case SPIRType::Struct:
13848 	{
13849 		// In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
13850 		uint32_t alignment = 1;
13851 		for (uint32_t i = 0; i < type.member_types.size(); i++)
13852 			alignment = max(alignment, uint32_t(get_declared_struct_member_alignment_msl(type, i)));
13853 		return alignment;
13854 	}
13855 
13856 	default:
13857 	{
13858 		// Alignment of packed type is the same as the underlying component or column size.
13859 		// Alignment of unpacked type is the same as the vector size.
13860 		// Alignment of 3-elements vector is the same as 4-elements (including packed using column).
13861 		if (is_packed)
13862 		{
13863 			// If we have packed_T and friends, the alignment is always scalar.
13864 			return type.width / 8;
13865 		}
13866 		else
13867 		{
13868 			// This is the general rule for MSL. Size == alignment.
13869 			uint32_t vecsize = (row_major && type.columns > 1) ? type.columns : type.vecsize;
13870 			return (type.width / 8) * (vecsize == 3 ? 4 : vecsize);
13871 		}
13872 	}
13873 	}
13874 }
13875 
get_declared_struct_member_alignment_msl(const SPIRType & type,uint32_t index) const13876 uint32_t CompilerMSL::get_declared_struct_member_alignment_msl(const SPIRType &type, uint32_t index) const
13877 {
13878 	return get_declared_type_alignment_msl(get_physical_member_type(type, index),
13879 	                                       member_is_packed_physical_type(type, index),
13880 	                                       has_member_decoration(type.self, index, DecorationRowMajor));
13881 }
13882 
get_declared_input_alignment_msl(const SPIRType & type,uint32_t index) const13883 uint32_t CompilerMSL::get_declared_input_alignment_msl(const SPIRType &type, uint32_t index) const
13884 {
13885 	return get_declared_type_alignment_msl(get_presumed_input_type(type, index), false,
13886 	                                       has_member_decoration(type.self, index, DecorationRowMajor));
13887 }
13888 
// MSL never skips function arguments; this override of the base-class hook
// unconditionally reports that every argument should be emitted.
bool CompilerMSL::skip_argument(uint32_t) const
{
	return false;
}
13893 
analyze_sampled_image_usage()13894 void CompilerMSL::analyze_sampled_image_usage()
13895 {
13896 	if (msl_options.swizzle_texture_samples)
13897 	{
13898 		SampledImageScanner scanner(*this);
13899 		traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), scanner);
13900 	}
13901 }
13902 
handle(spv::Op opcode,const uint32_t * args,uint32_t length)13903 bool CompilerMSL::SampledImageScanner::handle(spv::Op opcode, const uint32_t *args, uint32_t length)
13904 {
13905 	switch (opcode)
13906 	{
13907 	case OpLoad:
13908 	case OpImage:
13909 	case OpSampledImage:
13910 	{
13911 		if (length < 3)
13912 			return false;
13913 
13914 		uint32_t result_type = args[0];
13915 		auto &type = compiler.get<SPIRType>(result_type);
13916 		if ((type.basetype != SPIRType::Image && type.basetype != SPIRType::SampledImage) || type.image.sampled != 1)
13917 			return true;
13918 
13919 		uint32_t id = args[1];
13920 		compiler.set<SPIRExpression>(id, "", result_type, true);
13921 		break;
13922 	}
13923 	case OpImageSampleExplicitLod:
13924 	case OpImageSampleProjExplicitLod:
13925 	case OpImageSampleDrefExplicitLod:
13926 	case OpImageSampleProjDrefExplicitLod:
13927 	case OpImageSampleImplicitLod:
13928 	case OpImageSampleProjImplicitLod:
13929 	case OpImageSampleDrefImplicitLod:
13930 	case OpImageSampleProjDrefImplicitLod:
13931 	case OpImageFetch:
13932 	case OpImageGather:
13933 	case OpImageDrefGather:
13934 		compiler.has_sampled_images =
13935 		    compiler.has_sampled_images || compiler.is_sampled_image_type(compiler.expression_type(args[2]));
13936 		compiler.needs_swizzle_buffer_def = compiler.needs_swizzle_buffer_def || compiler.has_sampled_images;
13937 		break;
13938 	default:
13939 		break;
13940 	}
13941 	return true;
13942 }
13943 
13944 // If a needed custom function wasn't added before, add it and force a recompile.
add_spv_func_and_recompile(SPVFuncImpl spv_func)13945 void CompilerMSL::add_spv_func_and_recompile(SPVFuncImpl spv_func)
13946 {
13947 	if (spv_function_implementations.count(spv_func) == 0)
13948 	{
13949 		spv_function_implementations.insert(spv_func);
13950 		suppress_missing_prototypes = true;
13951 		force_recompile();
13952 	}
13953 }
13954 
// Pre-pass over every reachable instruction. Records which emulation helpers,
// built-ins and buffers the main compilation pass will need, before any MSL is
// emitted. Always returns true (continue traversal) except where noted.
bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, uint32_t length)
{
	// Since MSL exists in a single execution scope, function prototype declarations are not
	// needed, and clutter the output. If secondary functions are output (either as a SPIR-V
	// function implementation or as indicated by the presence of OpFunctionCall), then set
	// suppress_missing_prototypes to suppress compiler warnings of missing function prototypes.

	// Mark if the input requires the implementation of an SPIR-V function that does not exist in Metal.
	SPVFuncImpl spv_func = get_spv_func_impl(opcode, args);
	if (spv_func != SPVFuncImplNone)
	{
		compiler.spv_function_implementations.insert(spv_func);
		suppress_missing_prototypes = true;
	}

	switch (opcode)
	{

	case OpFunctionCall:
		suppress_missing_prototypes = true;
		break;

	// Emulate texture2D atomic operations
	case OpImageTexelPointer:
	{
		// Map the texel-pointer result ID to its backing image variable so a
		// later atomic op on this pointer can mark that image for emulation.
		auto *var = compiler.maybe_get_backing_variable(args[2]);
		image_pointers[args[1]] = var ? var->self : ID(0);
		break;
	}

	case OpImageWrite:
		// Resource writes in fragment functions need MSL 2.2's capability.
		if (!compiler.msl_options.supports_msl_version(2, 2))
			uses_resource_write = true;
		break;

	case OpStore:
		check_resource_write(args[0]);
		break;

	// Emulate texture2D atomic operations
	case OpAtomicExchange:
	case OpAtomicCompareExchange:
	case OpAtomicCompareExchangeWeak:
	case OpAtomicIIncrement:
	case OpAtomicIDecrement:
	case OpAtomicIAdd:
	case OpAtomicISub:
	case OpAtomicSMin:
	case OpAtomicUMin:
	case OpAtomicSMax:
	case OpAtomicUMax:
	case OpAtomicAnd:
	case OpAtomicOr:
	case OpAtomicXor:
	{
		uses_atomics = true;
		// If the pointer operand came from OpImageTexelPointer, flag the
		// backing image variable as needing atomic emulation.
		auto it = image_pointers.find(args[2]);
		if (it != image_pointers.end())
		{
			compiler.atomic_image_vars.insert(it->second);
		}
		check_resource_write(args[2]);
		break;
	}

	case OpAtomicStore:
	{
		uses_atomics = true;
		// Note: unlike the ops above, OpAtomicStore has its pointer in args[0].
		auto it = image_pointers.find(args[0]);
		if (it != image_pointers.end())
		{
			compiler.atomic_image_vars.insert(it->second);
		}
		check_resource_write(args[0]);
		break;
	}

	case OpAtomicLoad:
	{
		uses_atomics = true;
		auto it = image_pointers.find(args[2]);
		if (it != image_pointers.end())
		{
			compiler.atomic_image_vars.insert(it->second);
		}
		break;
	}

	case OpGroupNonUniformInverseBallot:
		needs_subgroup_invocation_id = true;
		break;

	case OpGroupNonUniformBallotFindLSB:
	case OpGroupNonUniformBallotFindMSB:
		needs_subgroup_size = true;
		break;

	case OpGroupNonUniformBallotBitCount:
		// args[3] is the group operation: Reduce needs the subgroup size,
		// the scan variants need the invocation ID.
		if (args[3] == GroupOperationReduce)
			needs_subgroup_size = true;
		else
			needs_subgroup_invocation_id = true;
		break;

	case OpArrayLength:
	{
		// Record which buffers must have their byte length passed in.
		auto *var = compiler.maybe_get_backing_variable(args[2]);
		if (var)
			compiler.buffers_requiring_array_length.insert(var->self);
		break;
	}

	case OpInBoundsAccessChain:
	case OpAccessChain:
	case OpPtrAccessChain:
	{
		// OpArrayLength might want to know if taking ArrayLength of an array of SSBOs.
		uint32_t result_type = args[0];
		uint32_t id = args[1];
		uint32_t ptr = args[2];

		compiler.set<SPIRExpression>(id, "", result_type, true);
		compiler.register_read(id, ptr, true);
		compiler.ir.ids[id].set_allow_type_rewrite();
		break;
	}

	case OpExtInst:
	{
		uint32_t extension_set = args[2];
		if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
		{
			auto op_450 = static_cast<GLSLstd450>(args[3]);
			switch (op_450)
			{
			case GLSLstd450InterpolateAtCentroid:
			case GLSLstd450InterpolateAtSample:
			case GLSLstd450InterpolateAtOffset:
			{
				if (!compiler.msl_options.supports_msl_version(2, 3))
					SPIRV_CROSS_THROW("Pull-model interpolation requires MSL 2.3.");
				// Fragment varyings used with pull-model interpolation need special handling,
				// due to the way pull-model interpolation works in Metal.
				auto *var = compiler.maybe_get_backing_variable(args[4]);
				if (var)
				{
					compiler.pull_model_inputs.insert(var->self);
					auto &var_type = compiler.get_variable_element_type(*var);
					// In addition, if this variable has a 'Sample' decoration, we need the sample ID
					// in order to do default interpolation.
					if (compiler.has_decoration(var->self, DecorationSample))
					{
						needs_sample_id = true;
					}
					else if (var_type.basetype == SPIRType::Struct)
					{
						// Now we need to check each member and see if it has this decoration.
						for (uint32_t i = 0; i < var_type.member_types.size(); ++i)
						{
							if (compiler.has_member_decoration(var_type.self, i, DecorationSample))
							{
								needs_sample_id = true;
								break;
							}
						}
					}
				}
				break;
			}
			default:
				break;
			}
		}
		break;
	}

	default:
		break;
	}

	// If it has one, keep track of the instruction's result type, mapped by ID
	uint32_t result_type, result_id;
	if (compiler.instruction_to_result_type(result_type, result_id, opcode, args, length))
		result_types[result_id] = result_type;

	return true;
}
14142 
14143 // If the variable is a Uniform or StorageBuffer, mark that a resource has been written to.
check_resource_write(uint32_t var_id)14144 void CompilerMSL::OpCodePreprocessor::check_resource_write(uint32_t var_id)
14145 {
14146 	auto *p_var = compiler.maybe_get_backing_variable(var_id);
14147 	StorageClass sc = p_var ? p_var->storage : StorageClassMax;
14148 	if (!compiler.msl_options.supports_msl_version(2, 1) &&
14149 	    (sc == StorageClassUniform || sc == StorageClassStorageBuffer))
14150 		uses_resource_write = true;
14151 }
14152 
14153 // Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes.
// Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes.
// Returns SPVFuncImplNone when the opcode maps directly to native MSL and no
// helper function must be emitted.
CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op opcode, const uint32_t *args)
{
	switch (opcode)
	{
	case OpFMod:
		return SPVFuncImplMod;

	case OpFAdd:
		// Invariant float math wraps adds in a helper to keep ordering exact.
		if (compiler.msl_options.invariant_float_math)
		{
			return SPVFuncImplFAdd;
		}
		break;

	case OpFMul:
	case OpOuterProduct:
	case OpMatrixTimesVector:
	case OpVectorTimesMatrix:
	case OpMatrixTimesMatrix:
		// Same for every multiply-flavored op.
		if (compiler.msl_options.invariant_float_math)
		{
			return SPVFuncImplFMul;
		}
		break;

	case OpTypeArray:
	{
		// Allow Metal to use the array<T> template to make arrays a value type
		return SPVFuncImplUnsafeArray;
	}

	// Emulate texture2D atomic operations
	case OpAtomicExchange:
	case OpAtomicCompareExchange:
	case OpAtomicCompareExchangeWeak:
	case OpAtomicIIncrement:
	case OpAtomicIDecrement:
	case OpAtomicIAdd:
	case OpAtomicISub:
	case OpAtomicSMin:
	case OpAtomicUMin:
	case OpAtomicSMax:
	case OpAtomicUMax:
	case OpAtomicAnd:
	case OpAtomicOr:
	case OpAtomicXor:
	case OpAtomicLoad:
	case OpAtomicStore:
	{
		// The pointer operand is args[0] for OpAtomicStore, args[2] otherwise.
		// Atomics on a 2D image need the coordinate-conversion helper.
		auto it = image_pointers.find(args[opcode == OpAtomicStore ? 0 : 2]);
		if (it != image_pointers.end())
		{
			uint32_t tid = compiler.get<SPIRVariable>(it->second).basetype;
			if (tid && compiler.get<SPIRType>(tid).image.dim == Dim2D)
				return SPVFuncImplImage2DAtomicCoords;
		}
		break;
	}

	case OpImageFetch:
	case OpImageRead:
	case OpImageWrite:
	{
		// Retrieve the image type, and if it's a Buffer, emit a texel coordinate function
		uint32_t tid = result_types[args[opcode == OpImageWrite ? 0 : 2]];
		if (tid && compiler.get<SPIRType>(tid).image.dim == DimBuffer && !compiler.msl_options.texture_buffer_native)
			return SPVFuncImplTexelBufferCoords;
		break;
	}

	case OpExtInst:
	{
		// GLSL.std.450 extended instructions without a direct MSL counterpart
		// each map to a hand-written helper.
		uint32_t extension_set = args[2];
		if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
		{
			auto op_450 = static_cast<GLSLstd450>(args[3]);
			switch (op_450)
			{
			case GLSLstd450Radians:
				return SPVFuncImplRadians;
			case GLSLstd450Degrees:
				return SPVFuncImplDegrees;
			case GLSLstd450FindILsb:
				return SPVFuncImplFindILsb;
			case GLSLstd450FindSMsb:
				return SPVFuncImplFindSMsb;
			case GLSLstd450FindUMsb:
				return SPVFuncImplFindUMsb;
			case GLSLstd450SSign:
				return SPVFuncImplSSign;
			case GLSLstd450Reflect:
			{
				// MSL's reflect/refract/faceforward only take vectors; scalar
				// forms need helpers.
				auto &type = compiler.get<SPIRType>(args[0]);
				if (type.vecsize == 1)
					return SPVFuncImplReflectScalar;
				break;
			}
			case GLSLstd450Refract:
			{
				auto &type = compiler.get<SPIRType>(args[0]);
				if (type.vecsize == 1)
					return SPVFuncImplRefractScalar;
				break;
			}
			case GLSLstd450FaceForward:
			{
				auto &type = compiler.get<SPIRType>(args[0]);
				if (type.vecsize == 1)
					return SPVFuncImplFaceForwardScalar;
				break;
			}
			case GLSLstd450MatrixInverse:
			{
				// One helper per supported matrix dimension.
				auto &mat_type = compiler.get<SPIRType>(args[0]);
				switch (mat_type.columns)
				{
				case 2:
					return SPVFuncImplInverse2x2;
				case 3:
					return SPVFuncImplInverse3x3;
				case 4:
					return SPVFuncImplInverse4x4;
				default:
					break;
				}
				break;
			}
			default:
				break;
			}
		}
		break;
	}

	// Subgroup operations are emulated with helper functions.
	case OpGroupNonUniformBroadcast:
		return SPVFuncImplSubgroupBroadcast;

	case OpGroupNonUniformBroadcastFirst:
		return SPVFuncImplSubgroupBroadcastFirst;

	case OpGroupNonUniformBallot:
		return SPVFuncImplSubgroupBallot;

	case OpGroupNonUniformInverseBallot:
	case OpGroupNonUniformBallotBitExtract:
		return SPVFuncImplSubgroupBallotBitExtract;

	case OpGroupNonUniformBallotFindLSB:
		return SPVFuncImplSubgroupBallotFindLSB;

	case OpGroupNonUniformBallotFindMSB:
		return SPVFuncImplSubgroupBallotFindMSB;

	case OpGroupNonUniformBallotBitCount:
		return SPVFuncImplSubgroupBallotBitCount;

	case OpGroupNonUniformAllEqual:
		return SPVFuncImplSubgroupAllEqual;

	case OpGroupNonUniformShuffle:
		return SPVFuncImplSubgroupShuffle;

	case OpGroupNonUniformShuffleXor:
		return SPVFuncImplSubgroupShuffleXor;

	case OpGroupNonUniformShuffleUp:
		return SPVFuncImplSubgroupShuffleUp;

	case OpGroupNonUniformShuffleDown:
		return SPVFuncImplSubgroupShuffleDown;

	case OpGroupNonUniformQuadBroadcast:
		return SPVFuncImplQuadBroadcast;

	case OpGroupNonUniformQuadSwap:
		return SPVFuncImplQuadSwap;

	default:
		break;
	}
	return SPVFuncImplNone;
}
14336 
14337 // Sort both type and meta member content based on builtin status (put builtins at end),
14338 // then by the required sorting aspect.
sort()14339 void CompilerMSL::MemberSorter::sort()
14340 {
14341 	// Create a temporary array of consecutive member indices and sort it based on how
14342 	// the members should be reordered, based on builtin and sorting aspect meta info.
14343 	size_t mbr_cnt = type.member_types.size();
14344 	SmallVector<uint32_t> mbr_idxs(mbr_cnt);
14345 	std::iota(mbr_idxs.begin(), mbr_idxs.end(), 0); // Fill with consecutive indices
14346 	std::stable_sort(mbr_idxs.begin(), mbr_idxs.end(), *this); // Sort member indices based on sorting aspect
14347 
14348 	bool sort_is_identity = true;
14349 	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
14350 	{
14351 		if (mbr_idx != mbr_idxs[mbr_idx])
14352 		{
14353 			sort_is_identity = false;
14354 			break;
14355 		}
14356 	}
14357 
14358 	if (sort_is_identity)
14359 		return;
14360 
14361 	if (meta.members.size() < type.member_types.size())
14362 	{
14363 		// This should never trigger in normal circumstances, but to be safe.
14364 		meta.members.resize(type.member_types.size());
14365 	}
14366 
14367 	// Move type and meta member info to the order defined by the sorted member indices.
14368 	// This is done by creating temporary copies of both member types and meta, and then
14369 	// copying back to the original content at the sorted indices.
14370 	auto mbr_types_cpy = type.member_types;
14371 	auto mbr_meta_cpy = meta.members;
14372 	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
14373 	{
14374 		type.member_types[mbr_idx] = mbr_types_cpy[mbr_idxs[mbr_idx]];
14375 		meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]];
14376 	}
14377 
14378 	if (sort_aspect == SortAspect::Offset)
14379 	{
14380 		// If we're sorting by Offset, this might affect user code which accesses a buffer block.
14381 		// We will need to redirect member indices from one index to sorted index.
14382 		type.member_type_index_redirection = std::move(mbr_idxs);
14383 	}
14384 }
14385 
14386 // Sort first by builtin status (put builtins at end), then by the sorting aspect.
operator ()(uint32_t mbr_idx1,uint32_t mbr_idx2)14387 bool CompilerMSL::MemberSorter::operator()(uint32_t mbr_idx1, uint32_t mbr_idx2)
14388 {
14389 	auto &mbr_meta1 = meta.members[mbr_idx1];
14390 	auto &mbr_meta2 = meta.members[mbr_idx2];
14391 	if (mbr_meta1.builtin != mbr_meta2.builtin)
14392 		return mbr_meta2.builtin;
14393 	else
14394 		switch (sort_aspect)
14395 		{
14396 		case Location:
14397 			return mbr_meta1.location < mbr_meta2.location;
14398 		case LocationReverse:
14399 			return mbr_meta1.location > mbr_meta2.location;
14400 		case Offset:
14401 			return mbr_meta1.offset < mbr_meta2.offset;
14402 		case OffsetThenLocationReverse:
14403 			return (mbr_meta1.offset < mbr_meta2.offset) ||
14404 			       ((mbr_meta1.offset == mbr_meta2.offset) && (mbr_meta1.location > mbr_meta2.location));
14405 		case Alphabetical:
14406 			return mbr_meta1.alias < mbr_meta2.alias;
14407 		default:
14408 			return false;
14409 		}
14410 }
14411 
MemberSorter(SPIRType & t,Meta & m,SortAspect sa)14412 CompilerMSL::MemberSorter::MemberSorter(SPIRType &t, Meta &m, SortAspect sa)
14413     : type(t)
14414     , meta(m)
14415     , sort_aspect(sa)
14416 {
14417 	// Ensure enough meta info is available
14418 	meta.members.resize(max(type.member_types.size(), meta.members.size()));
14419 }
14420 
remap_constexpr_sampler(VariableID id,const MSLConstexprSampler & sampler)14421 void CompilerMSL::remap_constexpr_sampler(VariableID id, const MSLConstexprSampler &sampler)
14422 {
14423 	auto &type = get<SPIRType>(get<SPIRVariable>(id).basetype);
14424 	if (type.basetype != SPIRType::SampledImage && type.basetype != SPIRType::Sampler)
14425 		SPIRV_CROSS_THROW("Can only remap SampledImage and Sampler type.");
14426 	if (!type.array.empty())
14427 		SPIRV_CROSS_THROW("Can not remap array of samplers.");
14428 	constexpr_samplers_by_id[id] = sampler;
14429 }
14430 
// Records a constexpr sampler for the given descriptor set/binding pair.
// Unlike remap_constexpr_sampler(), no type validation happens here, since no
// variable ID is available at registration time.
void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t binding,
                                                     const MSLConstexprSampler &sampler)
{
	constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
}
14436 
// When loading a built-in variable, rewrites 'expr' so the value matches the
// type the SPIR-V expects ('expr_type'), since several built-ins are declared
// with a fixed type in MSL that may differ from the shader's declared type.
void CompilerMSL::cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
{
	auto *var = maybe_get_backing_variable(source_id);
	if (var)
		source_id = var->self;

	// Only interested in standalone builtin variables.
	if (!has_decoration(source_id, DecorationBuiltIn))
		return;

	auto builtin = static_cast<BuiltIn>(get_decoration(source_id, DecorationBuiltIn));
	auto expected_type = expr_type.basetype;
	auto expected_width = expr_type.width;
	switch (builtin)
	{
	// These built-ins are all declared as 32-bit uint in the generated MSL.
	case BuiltInGlobalInvocationId:
	case BuiltInLocalInvocationId:
	case BuiltInWorkgroupId:
	case BuiltInLocalInvocationIndex:
	case BuiltInWorkgroupSize:
	case BuiltInNumWorkgroups:
	case BuiltInLayer:
	case BuiltInViewportIndex:
	case BuiltInFragStencilRefEXT:
	case BuiltInPrimitiveId:
	case BuiltInSubgroupSize:
	case BuiltInSubgroupLocalInvocationId:
	case BuiltInViewIndex:
	case BuiltInVertexIndex:
	case BuiltInInstanceIndex:
	case BuiltInBaseInstance:
	case BuiltInBaseVertex:
		expected_type = SPIRType::UInt;
		expected_width = 32;
		break;

	// Tess levels are declared as half in tessellation-control functions.
	case BuiltInTessLevelInner:
	case BuiltInTessLevelOuter:
		if (get_execution_model() == ExecutionModelTessellationControl)
		{
			expected_type = SPIRType::Half;
			expected_width = 16;
		}
		break;

	default:
		break;
	}

	if (expected_type != expr_type.basetype)
	{
		if (expected_width != expr_type.width)
		{
			// These are of different widths, so we cannot do a straight bitcast.
			expr = join(type_to_glsl(expr_type), "(", expr, ")");
		}
		else
		{
			expr = bitcast_expression(expr_type, expected_type, expr);
		}
	}

	if (builtin == BuiltInTessCoord && get_entry_point().flags.get(ExecutionModeQuads) && expr_type.vecsize == 3)
	{
		// In SPIR-V, this is always a vec3, even for quads. In Metal, though, it's a float2 for quads.
		// The code is expecting a float3, so we need to widen this.
		expr = join("float3(", expr, ", 0)");
	}
}
14506 
cast_to_builtin_store(uint32_t target_id,std::string & expr,const SPIRType & expr_type)14507 void CompilerMSL::cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
14508 {
14509 	auto *var = maybe_get_backing_variable(target_id);
14510 	if (var)
14511 		target_id = var->self;
14512 
14513 	// Only interested in standalone builtin variables.
14514 	if (!has_decoration(target_id, DecorationBuiltIn))
14515 		return;
14516 
14517 	auto builtin = static_cast<BuiltIn>(get_decoration(target_id, DecorationBuiltIn));
14518 	auto expected_type = expr_type.basetype;
14519 	auto expected_width = expr_type.width;
14520 	switch (builtin)
14521 	{
14522 	case BuiltInLayer:
14523 	case BuiltInViewportIndex:
14524 	case BuiltInFragStencilRefEXT:
14525 	case BuiltInPrimitiveId:
14526 	case BuiltInViewIndex:
14527 		expected_type = SPIRType::UInt;
14528 		expected_width = 32;
14529 		break;
14530 
14531 	case BuiltInTessLevelInner:
14532 	case BuiltInTessLevelOuter:
14533 		expected_type = SPIRType::Half;
14534 		expected_width = 16;
14535 		break;
14536 
14537 	default:
14538 		break;
14539 	}
14540 
14541 	if (expected_type != expr_type.basetype)
14542 	{
14543 		if (expected_width != expr_type.width)
14544 		{
14545 			// These are of different widths, so we cannot do a straight bitcast.
14546 			auto type = expr_type;
14547 			type.basetype = expected_type;
14548 			type.width = expected_width;
14549 			expr = join(type_to_glsl(type), "(", expr, ")");
14550 		}
14551 		else
14552 		{
14553 			auto type = expr_type;
14554 			type.basetype = expected_type;
14555 			expr = bitcast_expression(type, expr_type.basetype, expr);
14556 		}
14557 	}
14558 }
14559 
to_initializer_expression(const SPIRVariable & var)14560 string CompilerMSL::to_initializer_expression(const SPIRVariable &var)
14561 {
14562 	// We risk getting an array initializer here with MSL. If we have an array.
14563 	// FIXME: We cannot handle non-constant arrays being initialized.
14564 	// We will need to inject spvArrayCopy here somehow ...
14565 	auto &type = get<SPIRType>(var.basetype);
14566 	string expr;
14567 	if (ir.ids[var.initializer].get_type() == TypeConstant &&
14568 	    (!type.array.empty() || type.basetype == SPIRType::Struct))
14569 		expr = constant_expression(get<SPIRConstant>(var.initializer));
14570 	else
14571 		expr = CompilerGLSL::to_initializer_expression(var);
14572 	// If the initializer has more vector components than the variable, add a swizzle.
14573 	// FIXME: This can't handle arrays or structs.
14574 	auto &init_type = expression_type(var.initializer);
14575 	if (type.array.empty() && type.basetype != SPIRType::Struct && init_type.vecsize > type.vecsize)
14576 		expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0));
14577 	return expr;
14578 }
14579 
to_zero_initialized_expression(uint32_t)14580 string CompilerMSL::to_zero_initialized_expression(uint32_t)
14581 {
14582 	return "{}";
14583 }
14584 
descriptor_set_is_argument_buffer(uint32_t desc_set) const14585 bool CompilerMSL::descriptor_set_is_argument_buffer(uint32_t desc_set) const
14586 {
14587 	if (!msl_options.argument_buffers)
14588 		return false;
14589 	if (desc_set >= kMaxArgumentBuffers)
14590 		return false;
14591 
14592 	return (argument_buffer_discrete_mask & (1u << desc_set)) == 0;
14593 }
14594 
is_supported_argument_buffer_type(const SPIRType & type) const14595 bool CompilerMSL::is_supported_argument_buffer_type(const SPIRType &type) const
14596 {
14597 	// Very specifically, image load-store in argument buffers are disallowed on MSL on iOS.
14598 	// But we won't know when the argument buffer is encoded whether this image will have
14599 	// a NonWritable decoration. So just use discrete arguments for all storage images
14600 	// on iOS.
14601 	bool is_storage_image = type.basetype == SPIRType::Image && type.image.sampled == 2;
14602 	bool is_supported_type = !msl_options.is_ios() || !is_storage_image;
14603 	return !type_is_msl_framebuffer_fetch(type) && is_supported_type;
14604 }
14605 
void CompilerMSL::analyze_argument_buffers()
{
	// Gather all used resources and sort them out into argument buffers.
	// Each argument buffer corresponds to a descriptor set in SPIR-V.
	// The [[id(N)]] values used correspond to the resource mapping we have for MSL.
	// Otherwise, the binding number is used, but this is generally not safe some types like
	// combined image samplers and arrays of resources. Metal needs different indices here,
	// while SPIR-V can have one descriptor set binding. To use argument buffers in practice,
	// you will need to use the remapping from the API.
	for (auto &id : argument_buffer_ids)
		id = 0;

	// Output resources, sorted by resource index & type.
	struct Resource
	{
		SPIRVariable *var;
		string name;
		SPIRType::BaseType basetype;
		uint32_t index;
		uint32_t plane;
	};
	SmallVector<Resource> resources_in_set[kMaxArgumentBuffers];
	SmallVector<uint32_t> inline_block_vars;

	bool set_needs_swizzle_buffer[kMaxArgumentBuffers] = {};
	bool set_needs_buffer_sizes[kMaxArgumentBuffers] = {};
	bool needs_buffer_sizes = false;

	// Pass 1: bucket every visible buffer/image/sampler variable into its
	// descriptor set, expanding multi-planar images and noting which sets need
	// auxiliary swizzle / buffer-size data.
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &var) {
		if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
		     var.storage == StorageClassStorageBuffer) &&
		    !is_hidden_variable(var))
		{
			uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
			// Ignore if it's part of a push descriptor set.
			if (!descriptor_set_is_argument_buffer(desc_set))
				return;

			uint32_t var_id = var.self;
			auto &type = get_variable_data_type(var);

			if (desc_set >= kMaxArgumentBuffers)
				SPIRV_CROSS_THROW("Descriptor set index is out of range.");

			const MSLConstexprSampler *constexpr_sampler = nullptr;
			if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
			{
				constexpr_sampler = find_constexpr_sampler(var_id);
				if (constexpr_sampler)
				{
					// Mark this ID as a constexpr sampler for later in case it came from set/bindings.
					constexpr_samplers_by_id[var_id] = *constexpr_sampler;
				}
			}

			uint32_t binding = get_decoration(var_id, DecorationBinding);
			if (type.basetype == SPIRType::SampledImage)
			{
				add_resource_name(var_id);

				// YCbCr-converted samplers may cover multiple image planes;
				// each plane is emitted as its own image resource.
				uint32_t plane_count = 1;
				if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
					plane_count = constexpr_sampler->planes;

				for (uint32_t i = 0; i < plane_count; i++)
				{
					uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image, i);
					resources_in_set[desc_set].push_back(
					    { &var, to_name(var_id), SPIRType::Image, image_resource_index, i });
				}

				// Buffer-dimension images have no sampler, and constexpr samplers
				// are inlined rather than declared as resources.
				if (type.image.dim != DimBuffer && !constexpr_sampler)
				{
					uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler);
					resources_in_set[desc_set].push_back(
					    { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 0 });
				}
			}
			else if (inline_uniform_blocks.count(SetBindingPair{ desc_set, binding }))
			{
				inline_block_vars.push_back(var_id);
			}
			else if (!constexpr_sampler && is_supported_argument_buffer_type(type))
			{
				// constexpr samplers are not declared as resources.
				// Inline uniform blocks are always emitted at the end.
				add_resource_name(var_id);
				resources_in_set[desc_set].push_back(
					{ &var, to_name(var_id), type.basetype, get_metal_resource_index(var, type.basetype), 0 });

				// Emulate texture2D atomic operations
				if (atomic_image_vars.count(var.self))
				{
					uint32_t buffer_resource_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
					resources_in_set[desc_set].push_back(
						{ &var, to_name(var_id) + "_atomic", SPIRType::Struct, buffer_resource_index, 0 });
				}
			}

			// Check if this descriptor set needs a swizzle buffer.
			if (needs_swizzle_buffer_def && is_sampled_image_type(type))
				set_needs_swizzle_buffer[desc_set] = true;
			else if (buffers_requiring_array_length.count(var_id) != 0)
			{
				set_needs_buffer_sizes[desc_set] = true;
				needs_buffer_sizes = true;
			}
		}
	});

	// Pass 2: synthesize hidden uint-pointer variables for swizzle constants
	// and buffer sizes in each set that needs them.
	if (needs_swizzle_buffer_def || needs_buffer_sizes)
	{
		uint32_t uint_ptr_type_id = 0;

		// We might have to add a swizzle buffer resource to the set.
		for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
		{
			if (!set_needs_swizzle_buffer[desc_set] && !set_needs_buffer_sizes[desc_set])
				continue;

			// Lazily create the uint* type, shared by all synthesized buffers.
			if (uint_ptr_type_id == 0)
			{
				uint_ptr_type_id = ir.increase_bound_by(1);

				// Create a buffer to hold extra data, including the swizzle constants.
				SPIRType uint_type_pointer = get_uint_type();
				uint_type_pointer.pointer = true;
				uint_type_pointer.pointer_depth = 1;
				uint_type_pointer.parent_type = get_uint_type_id();
				uint_type_pointer.storage = StorageClassUniform;
				set<SPIRType>(uint_ptr_type_id, uint_type_pointer);
				set_decoration(uint_ptr_type_id, DecorationArrayStride, 4);
			}

			if (set_needs_swizzle_buffer[desc_set])
			{
				uint32_t var_id = ir.increase_bound_by(1);
				auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
				set_name(var_id, "spvSwizzleConstants");
				set_decoration(var_id, DecorationDescriptorSet, desc_set);
				set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding);
				resources_in_set[desc_set].push_back(
				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
			}

			if (set_needs_buffer_sizes[desc_set])
			{
				uint32_t var_id = ir.increase_bound_by(1);
				auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
				set_name(var_id, "spvBufferSizeConstants");
				set_decoration(var_id, DecorationDescriptorSet, desc_set);
				set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding);
				resources_in_set[desc_set].push_back(
				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
			}
		}
	}

	// Now add inline uniform blocks.
	for (uint32_t var_id : inline_block_vars)
	{
		auto &var = get<SPIRVariable>(var_id);
		uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
		add_resource_name(var_id);
		resources_in_set[desc_set].push_back(
		    { &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 0 });
	}

	// Pass 3: for every non-empty set, synthesize the argument buffer struct
	// type, its pointer type, and the variable that carries it, then emit each
	// gathered resource as a struct member in resource-index order.
	for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
	{
		auto &resources = resources_in_set[desc_set];
		if (resources.empty())
			continue;

		assert(descriptor_set_is_argument_buffer(desc_set));

		// Three consecutive IDs: variable, struct type, pointer-to-struct type.
		uint32_t next_id = ir.increase_bound_by(3);
		uint32_t type_id = next_id + 1;
		uint32_t ptr_type_id = next_id + 2;
		argument_buffer_ids[desc_set] = next_id;

		auto &buffer_type = set<SPIRType>(type_id);

		buffer_type.basetype = SPIRType::Struct;

		if ((argument_buffer_device_storage_mask & (1u << desc_set)) != 0)
		{
			buffer_type.storage = StorageClassStorageBuffer;
			// Make sure the argument buffer gets marked as const device.
			set_decoration(next_id, DecorationNonWritable);
			// Need to mark the type as a Block to enable this.
			set_decoration(type_id, DecorationBlock);
		}
		else
			buffer_type.storage = StorageClassUniform;

		set_name(type_id, join("spvDescriptorSetBuffer", desc_set));

		auto &ptr_type = set<SPIRType>(ptr_type_id);
		ptr_type = buffer_type;
		ptr_type.pointer = true;
		ptr_type.pointer_depth = 1;
		ptr_type.parent_type = type_id;

		uint32_t buffer_variable_id = next_id;
		set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
		set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));

		// Ids must be emitted in ID order.
		sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
			return tie(lhs.index, lhs.basetype) < tie(rhs.index, rhs.basetype);
		});

		uint32_t member_index = 0;
		for (auto &resource : resources)
		{
			auto &var = *resource.var;
			auto &type = get_variable_data_type(var);
			string mbr_name = ensure_valid_name(resource.name, "m");
			if (resource.plane > 0)
				mbr_name += join(plane_name_suffix, resource.plane);
			set_member_name(buffer_type.self, member_index, mbr_name);

			if (resource.basetype == SPIRType::Sampler && type.basetype != SPIRType::Sampler)
			{
				// Have to synthesize a sampler type here.

				bool type_is_array = !type.array.empty();
				uint32_t sampler_type_id = ir.increase_bound_by(type_is_array ? 2 : 1);
				auto &new_sampler_type = set<SPIRType>(sampler_type_id);
				new_sampler_type.basetype = SPIRType::Sampler;
				new_sampler_type.storage = StorageClassUniformConstant;

				if (type_is_array)
				{
					uint32_t sampler_type_array_id = sampler_type_id + 1;
					auto &sampler_type_array = set<SPIRType>(sampler_type_array_id);
					sampler_type_array = new_sampler_type;
					sampler_type_array.array = type.array;
					sampler_type_array.array_size_literal = type.array_size_literal;
					sampler_type_array.parent_type = sampler_type_id;
					buffer_type.member_types.push_back(sampler_type_array_id);
				}
				else
					buffer_type.member_types.push_back(sampler_type_id);
			}
			else
			{
				uint32_t binding = get_decoration(var.self, DecorationBinding);
				SetBindingPair pair = { desc_set, binding };

				if (resource.basetype == SPIRType::Image || resource.basetype == SPIRType::Sampler ||
				    resource.basetype == SPIRType::SampledImage)
				{
					// Drop pointer information when we emit the resources into a struct.
					buffer_type.member_types.push_back(get_variable_data_type_id(var));
					if (resource.plane == 0)
						set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
				}
				else if (buffers_requiring_dynamic_offset.count(pair))
				{
					// Don't set the qualified name here; we'll define a variable holding the corrected buffer address later.
					buffer_type.member_types.push_back(var.basetype);
					buffers_requiring_dynamic_offset[pair].second = var.self;
				}
				else if (inline_uniform_blocks.count(pair))
				{
					// Put the buffer block itself into the argument buffer.
					buffer_type.member_types.push_back(get_variable_data_type_id(var));
					set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
				}
				else if (atomic_image_vars.count(var.self))
				{
					// Emulate texture2D atomic operations.
					// Don't set the qualified name: it's already set for this variable,
					// and the code that references the buffer manually appends "_atomic"
					// to the name.
					uint32_t offset = ir.increase_bound_by(2);
					uint32_t atomic_type_id = offset;
					uint32_t type_ptr_id = offset + 1;

					SPIRType atomic_type;
					atomic_type.basetype = SPIRType::AtomicCounter;
					atomic_type.width = 32;
					atomic_type.vecsize = 1;
					set<SPIRType>(atomic_type_id, atomic_type);

					atomic_type.pointer = true;
					atomic_type.parent_type = atomic_type_id;
					atomic_type.storage = StorageClassStorageBuffer;
					auto &atomic_ptr_type = set<SPIRType>(type_ptr_id, atomic_type);
					atomic_ptr_type.self = atomic_type_id;

					buffer_type.member_types.push_back(type_ptr_id);
				}
				else
				{
					// Resources will be declared as pointers not references, so automatically dereference as appropriate.
					buffer_type.member_types.push_back(var.basetype);
					if (type.array.empty())
						set_qualified_name(var.self, join("(*", to_name(buffer_variable_id), ".", mbr_name, ")"));
					else
						set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
				}
			}

			// Record which [[id(N)]] slot this member occupies and which
			// original variable it came from.
			set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationResourceIndexPrimary,
			                               resource.index);
			set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationInterfaceOrigID,
			                               var.self);
			member_index++;
		}
	}
}
14920 
activate_argument_buffer_resources()14921 void CompilerMSL::activate_argument_buffer_resources()
14922 {
14923 	// For ABI compatibility, force-enable all resources which are part of argument buffers.
14924 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, const SPIRVariable &) {
14925 		if (!has_decoration(self, DecorationDescriptorSet))
14926 			return;
14927 
14928 		uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
14929 		if (descriptor_set_is_argument_buffer(desc_set))
14930 			active_interface_variables.insert(self);
14931 	});
14932 }
14933 
using_builtin_array() const14934 bool CompilerMSL::using_builtin_array() const
14935 {
14936 	return msl_options.force_native_arrays || is_using_builtin_array;
14937 }
14938 
set_combined_sampler_suffix(const char * suffix)14939 void CompilerMSL::set_combined_sampler_suffix(const char *suffix)
14940 {
14941 	sampler_name_suffix = suffix;
14942 }
14943 
get_combined_sampler_suffix() const14944 const char *CompilerMSL::get_combined_sampler_suffix() const
14945 {
14946 	return sampler_name_suffix.c_str();
14947 }
14948