/*
 * Copyright 2016-2021 The Brenwill Workshop Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * At your option, you may choose to accept this material under either:
 *  1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
 *  2. The MIT License, found at <http://opensource.org/licenses/MIT>.
 * SPDX-License-Identifier: Apache-2.0 OR MIT.
 */

#include "spirv_msl.hpp"
#include "GLSL.std.450.h"

#include <algorithm>
#include <assert.h>
#include <numeric>

using namespace spv;
using namespace SPIRV_CROSS_NAMESPACE;
using namespace std;

static const uint32_t k_unknown_location = ~0u;
static const uint32_t k_unknown_component = ~0u;
static const char *force_inline = "static inline __attribute__((always_inline))";

CompilerMSL::CompilerMSL(std::vector<uint32_t> spirv_)
    : CompilerGLSL(std::move(spirv_))
{
}

CompilerMSL::CompilerMSL(const uint32_t *ir_, size_t word_count)
    : CompilerGLSL(ir_, word_count)
{
}

CompilerMSL::CompilerMSL(const ParsedIR &ir_)
    : CompilerGLSL(ir_)
{
}

CompilerMSL::CompilerMSL(ParsedIR &&ir_)
    : CompilerGLSL(std::move(ir_))
{
}

void CompilerMSL::add_msl_shader_input(const MSLShaderInput &si)
{
    inputs_by_location[si.location] = si;
    if (si.builtin != BuiltInMax && !inputs_by_builtin.count(si.builtin))
        inputs_by_builtin[si.builtin] = si;
}
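
// Usage sketch (illustrative, not part of this file): a client such as MoltenVK
// can describe a vertex-stage input before compiling, with fields per this
// version's MSLShaderInput:
//
//     MSLShaderInput si = {};
//     si.location = 0;
//     si.vecsize = 4;
//     compiler.add_msl_shader_input(si);
//
// Note that builtin-keyed inputs are only recorded the first time a given
// builtin is seen.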

void CompilerMSL::add_msl_resource_binding(const MSLResourceBinding &binding)
{
    StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
    resource_bindings[tuple] = { binding, false };
}

void CompilerMSL::add_dynamic_buffer(uint32_t desc_set, uint32_t binding, uint32_t index)
{
    SetBindingPair pair = { desc_set, binding };
    buffers_requiring_dynamic_offset[pair] = { index, 0 };
}
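
// Illustrative sketch: mark the buffer at (set = 0, binding = 2) as taking its
// byte offset from index 3 of the dynamic-offset buffer:
//
//     compiler.add_dynamic_buffer(0, 2, 3);
//
// The zero stored next to the index is a variable ID placeholder; it is filled
// in once the matching buffer variable is discovered during compilation.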

void CompilerMSL::add_inline_uniform_block(uint32_t desc_set, uint32_t binding)
{
    SetBindingPair pair = { desc_set, binding };
    inline_uniform_blocks.insert(pair);
}

void CompilerMSL::add_discrete_descriptor_set(uint32_t desc_set)
{
    if (desc_set < kMaxArgumentBuffers)
        argument_buffer_discrete_mask |= 1u << desc_set;
}

void CompilerMSL::set_argument_buffer_device_address_space(uint32_t desc_set, bool device_storage)
{
    if (desc_set < kMaxArgumentBuffers)
    {
        if (device_storage)
            argument_buffer_device_storage_mask |= 1u << desc_set;
        else
            argument_buffer_device_storage_mask &= ~(1u << desc_set);
    }
}
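
// Illustrative sketch: keep descriptor set 1 out of argument buffers entirely,
// and place set 2's argument buffer in the device address space (e.g. because it
// holds writable resources):
//
//     compiler.add_discrete_descriptor_set(1);
//     compiler.set_argument_buffer_device_address_space(2, true);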

bool CompilerMSL::is_msl_shader_input_used(uint32_t location)
{
    return inputs_in_use.count(location) != 0;
}

bool CompilerMSL::is_msl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
{
    StageSetBinding tuple = { model, desc_set, binding };
    auto itr = resource_bindings.find(tuple);
    return itr != end(resource_bindings) && itr->second.second;
}
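
// Illustrative sketch: after compile(), query whether a remapping added with
// add_msl_resource_binding() was actually consumed by the shader:
//
//     bool used = compiler.is_msl_resource_binding_used(spv::ExecutionModelFragment, 0, 1);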

// Returns the size of the array of resources used by the variable with the specified id.
// The returned value is retrieved from the resource binding added using add_msl_resource_binding().
uint32_t CompilerMSL::get_resource_array_size(uint32_t id) const
{
    StageSetBinding tuple = { get_entry_point().model, get_decoration(id, DecorationDescriptorSet),
                              get_decoration(id, DecorationBinding) };
    auto itr = resource_bindings.find(tuple);
    return itr != end(resource_bindings) ? itr->second.first.count : 0;
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding(uint32_t id) const
{
    return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexPrimary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_secondary(uint32_t id) const
{
    return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexSecondary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_tertiary(uint32_t id) const
{
    return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexTertiary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_quaternary(uint32_t id) const
{
    return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexQuaternary);
}

void CompilerMSL::set_fragment_output_components(uint32_t location, uint32_t components)
{
    fragment_output_components[location] = components;
}

bool CompilerMSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
{
    return (builtin == BuiltInSampleMask);
}

void CompilerMSL::build_implicit_builtins()
{
    bool need_sample_pos = active_input_builtins.get(BuiltInSamplePosition);
    bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex &&
                              !msl_options.vertex_for_tessellation;
    bool need_tesc_params = get_execution_model() == ExecutionModelTessellationControl;
    bool need_subgroup_mask =
        active_input_builtins.get(BuiltInSubgroupEqMask) || active_input_builtins.get(BuiltInSubgroupGeMask) ||
        active_input_builtins.get(BuiltInSubgroupGtMask) || active_input_builtins.get(BuiltInSubgroupLeMask) ||
        active_input_builtins.get(BuiltInSubgroupLtMask);
    bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(BuiltInSubgroupGeMask) ||
                                                           active_input_builtins.get(BuiltInSubgroupGtMask));
    bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index &&
                          msl_options.multiview_layered_rendering &&
                          (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex));
    bool need_dispatch_base =
        msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute &&
        (active_input_builtins.get(BuiltInWorkgroupId) || active_input_builtins.get(BuiltInGlobalInvocationId));
    bool need_grid_params = get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation;
    bool need_vertex_base_params =
        need_grid_params &&
        (active_input_builtins.get(BuiltInVertexId) || active_input_builtins.get(BuiltInVertexIndex) ||
         active_input_builtins.get(BuiltInBaseVertex) || active_input_builtins.get(BuiltInInstanceId) ||
         active_input_builtins.get(BuiltInInstanceIndex) || active_input_builtins.get(BuiltInBaseInstance));
    bool need_sample_mask = msl_options.additional_fixed_sample_mask != 0xffffffff;
    bool need_local_invocation_index = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInSubgroupId);
    bool need_workgroup_size = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInNumSubgroups);
    if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
        need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params || needs_sample_id ||
        needs_subgroup_invocation_id || needs_subgroup_size || need_sample_mask || need_local_invocation_index ||
        need_workgroup_size)
    {
        bool has_frag_coord = false;
        bool has_sample_id = false;
        bool has_vertex_idx = false;
        bool has_base_vertex = false;
        bool has_instance_idx = false;
        bool has_base_instance = false;
        bool has_invocation_id = false;
        bool has_primitive_id = false;
        bool has_subgroup_invocation_id = false;
        bool has_subgroup_size = false;
        bool has_view_idx = false;
        bool has_layer = false;
        bool has_local_invocation_index = false;
        bool has_workgroup_size = false;
        uint32_t workgroup_id_type = 0;

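        // First pass: walk the entry point's interface variables and record which of
        // the builtins needed above are already declared, capturing their variable IDs.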
        ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
            if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
                return;
            if (!interface_variable_exists_in_entry_point(var.self))
                return;
            if (!has_decoration(var.self, DecorationBuiltIn))
                return;

            BuiltIn builtin = ir.meta[var.self].decoration.builtin_type;

            if (var.storage == StorageClassOutput)
            {
                if (need_sample_mask && builtin == BuiltInSampleMask)
                {
                    builtin_sample_mask_id = var.self;
                    mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var.self);
                    does_shader_write_sample_mask = true;
                }
            }

            if (var.storage != StorageClassInput)
                return;

            // Use Metal's native frame-buffer fetch API for subpass inputs.
            if (need_subpass_input && (!msl_options.use_framebuffer_fetch_subpasses))
            {
                switch (builtin)
                {
                case BuiltInFragCoord:
                    mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var.self);
                    builtin_frag_coord_id = var.self;
                    has_frag_coord = true;
                    break;
                case BuiltInLayer:
                    if (!msl_options.arrayed_subpass_input || msl_options.multiview)
                        break;
                    mark_implicit_builtin(StorageClassInput, BuiltInLayer, var.self);
                    builtin_layer_id = var.self;
                    has_layer = true;
                    break;
                case BuiltInViewIndex:
                    if (!msl_options.multiview)
                        break;
                    mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
                    builtin_view_idx_id = var.self;
                    has_view_idx = true;
                    break;
                default:
                    break;
                }
            }

            if ((need_sample_pos || needs_sample_id) && builtin == BuiltInSampleId)
            {
                builtin_sample_id_id = var.self;
                mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var.self);
                has_sample_id = true;
            }

            if (need_vertex_params)
            {
                switch (builtin)
                {
                case BuiltInVertexIndex:
                    builtin_vertex_idx_id = var.self;
                    mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var.self);
                    has_vertex_idx = true;
                    break;
                case BuiltInBaseVertex:
                    builtin_base_vertex_id = var.self;
                    mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var.self);
                    has_base_vertex = true;
                    break;
                case BuiltInInstanceIndex:
                    builtin_instance_idx_id = var.self;
                    mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
                    has_instance_idx = true;
                    break;
                case BuiltInBaseInstance:
                    builtin_base_instance_id = var.self;
                    mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
                    has_base_instance = true;
                    break;
                default:
                    break;
                }
            }

            if (need_tesc_params)
            {
                switch (builtin)
                {
                case BuiltInInvocationId:
                    builtin_invocation_id_id = var.self;
                    mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var.self);
                    has_invocation_id = true;
                    break;
                case BuiltInPrimitiveId:
                    builtin_primitive_id_id = var.self;
                    mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var.self);
                    has_primitive_id = true;
                    break;
                default:
                    break;
                }
            }

            if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId)
            {
                builtin_subgroup_invocation_id_id = var.self;
                mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var.self);
                has_subgroup_invocation_id = true;
            }

            if ((need_subgroup_ge_mask || needs_subgroup_size) && builtin == BuiltInSubgroupSize)
            {
                builtin_subgroup_size_id = var.self;
                mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var.self);
                has_subgroup_size = true;
            }

            if (need_multiview)
            {
                switch (builtin)
                {
                case BuiltInInstanceIndex:
                    // The view index here is derived from the instance index.
                    builtin_instance_idx_id = var.self;
                    mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
                    has_instance_idx = true;
                    break;
                case BuiltInBaseInstance:
                    // If a non-zero base instance is used, we need to adjust for it when calculating the view index.
                    builtin_base_instance_id = var.self;
                    mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
                    has_base_instance = true;
                    break;
                case BuiltInViewIndex:
                    builtin_view_idx_id = var.self;
                    mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
                    has_view_idx = true;
                    break;
                default:
                    break;
                }
            }

            if (need_local_invocation_index && builtin == BuiltInLocalInvocationIndex)
            {
                builtin_local_invocation_index_id = var.self;
                mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var.self);
                has_local_invocation_index = true;
            }

            if (need_workgroup_size && builtin == BuiltInLocalInvocationId)
            {
                builtin_workgroup_size_id = var.self;
                mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var.self);
                has_workgroup_size = true;
            }

            // The base workgroup needs to have the same type and vector size
            // as the workgroup or invocation ID, so keep track of the type that
            // was used.
            if (need_dispatch_base && workgroup_id_type == 0 &&
                (builtin == BuiltInWorkgroupId || builtin == BuiltInGlobalInvocationId))
                workgroup_id_type = var.basetype;
        });

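        // Second pass: any builtin still missing after the scan above is synthesized
        // here, creating the pointer type and the input/output variable directly in the IR.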
        // Use Metal's native frame-buffer fetch API for subpass inputs.
        if ((!has_frag_coord || (msl_options.multiview && !has_view_idx) ||
             (msl_options.arrayed_subpass_input && !msl_options.multiview && !has_layer)) &&
            (!msl_options.use_framebuffer_fetch_subpasses) && need_subpass_input)
        {
            if (!has_frag_coord)
            {
                uint32_t offset = ir.increase_bound_by(3);
                uint32_t type_id = offset;
                uint32_t type_ptr_id = offset + 1;
                uint32_t var_id = offset + 2;

                // Create gl_FragCoord.
                SPIRType vec4_type;
                vec4_type.basetype = SPIRType::Float;
                vec4_type.width = 32;
                vec4_type.vecsize = 4;
                set<SPIRType>(type_id, vec4_type);

                SPIRType vec4_type_ptr;
                vec4_type_ptr = vec4_type;
                vec4_type_ptr.pointer = true;
                vec4_type_ptr.parent_type = type_id;
                vec4_type_ptr.storage = StorageClassInput;
                auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
                ptr_type.self = type_id;

                set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
                set_decoration(var_id, DecorationBuiltIn, BuiltInFragCoord);
                builtin_frag_coord_id = var_id;
                mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var_id);
            }

            if (!has_layer && msl_options.arrayed_subpass_input && !msl_options.multiview)
            {
                uint32_t offset = ir.increase_bound_by(2);
                uint32_t type_ptr_id = offset;
                uint32_t var_id = offset + 1;

                // Create gl_Layer.
                SPIRType uint_type_ptr;
                uint_type_ptr = get_uint_type();
                uint_type_ptr.pointer = true;
                uint_type_ptr.parent_type = get_uint_type_id();
                uint_type_ptr.storage = StorageClassInput;
                auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
                ptr_type.self = get_uint_type_id();

                set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
                set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
                builtin_layer_id = var_id;
                mark_implicit_builtin(StorageClassInput, BuiltInLayer, var_id);
            }

            if (!has_view_idx && msl_options.multiview)
            {
                uint32_t offset = ir.increase_bound_by(2);
                uint32_t type_ptr_id = offset;
                uint32_t var_id = offset + 1;

                // Create gl_ViewIndex.
                SPIRType uint_type_ptr;
                uint_type_ptr = get_uint_type();
                uint_type_ptr.pointer = true;
                uint_type_ptr.parent_type = get_uint_type_id();
                uint_type_ptr.storage = StorageClassInput;
                auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
                ptr_type.self = get_uint_type_id();

                set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
                set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
                builtin_view_idx_id = var_id;
                mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
            }
        }

        if (!has_sample_id && (need_sample_pos || needs_sample_id))
        {
            uint32_t offset = ir.increase_bound_by(2);
            uint32_t type_ptr_id = offset;
            uint32_t var_id = offset + 1;

            // Create gl_SampleID.
            SPIRType uint_type_ptr;
            uint_type_ptr = get_uint_type();
            uint_type_ptr.pointer = true;
            uint_type_ptr.parent_type = get_uint_type_id();
            uint_type_ptr.storage = StorageClassInput;
            auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
            ptr_type.self = get_uint_type_id();

            set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
            set_decoration(var_id, DecorationBuiltIn, BuiltInSampleId);
            builtin_sample_id_id = var_id;
            mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var_id);
        }

        if ((need_vertex_params && (!has_vertex_idx || !has_base_vertex || !has_instance_idx || !has_base_instance)) ||
            (need_multiview && (!has_instance_idx || !has_base_instance || !has_view_idx)))
        {
            uint32_t type_ptr_id = ir.increase_bound_by(1);

            SPIRType uint_type_ptr;
            uint_type_ptr = get_uint_type();
            uint_type_ptr.pointer = true;
            uint_type_ptr.parent_type = get_uint_type_id();
            uint_type_ptr.storage = StorageClassInput;
            auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
            ptr_type.self = get_uint_type_id();

            if (need_vertex_params && !has_vertex_idx)
            {
                uint32_t var_id = ir.increase_bound_by(1);

                // Create gl_VertexIndex.
                set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
                set_decoration(var_id, DecorationBuiltIn, BuiltInVertexIndex);
                builtin_vertex_idx_id = var_id;
                mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var_id);
            }

            if (need_vertex_params && !has_base_vertex)
            {
                uint32_t var_id = ir.increase_bound_by(1);

                // Create gl_BaseVertex.
                set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
                set_decoration(var_id, DecorationBuiltIn, BuiltInBaseVertex);
                builtin_base_vertex_id = var_id;
                mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var_id);
            }

            if (!has_instance_idx) // Needed by both multiview and tessellation
            {
                uint32_t var_id = ir.increase_bound_by(1);

                // Create gl_InstanceIndex.
                set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
                set_decoration(var_id, DecorationBuiltIn, BuiltInInstanceIndex);
                builtin_instance_idx_id = var_id;
                mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var_id);
            }

            if (!has_base_instance) // Needed by both multiview and tessellation
            {
                uint32_t var_id = ir.increase_bound_by(1);

                // Create gl_BaseInstance.
                set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
                set_decoration(var_id, DecorationBuiltIn, BuiltInBaseInstance);
                builtin_base_instance_id = var_id;
                mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var_id);
            }

            if (need_multiview)
            {
                // Multiview shaders are not allowed to write to gl_Layer, ostensibly because
                // it is implicitly written from gl_ViewIndex, but we have to do that explicitly.
                // Note that we can't just abuse gl_ViewIndex for this purpose: it's an input, but
                // gl_Layer is an output in vertex-pipeline shaders.
                uint32_t type_ptr_out_id = ir.increase_bound_by(2);
                SPIRType uint_type_ptr_out;
                uint_type_ptr_out = get_uint_type();
                uint_type_ptr_out.pointer = true;
                uint_type_ptr_out.parent_type = get_uint_type_id();
                uint_type_ptr_out.storage = StorageClassOutput;
                auto &ptr_out_type = set<SPIRType>(type_ptr_out_id, uint_type_ptr_out);
                ptr_out_type.self = get_uint_type_id();
                uint32_t var_id = type_ptr_out_id + 1;
                set<SPIRVariable>(var_id, type_ptr_out_id, StorageClassOutput);
                set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
                builtin_layer_id = var_id;
                mark_implicit_builtin(StorageClassOutput, BuiltInLayer, var_id);
            }

            if (need_multiview && !has_view_idx)
            {
                uint32_t var_id = ir.increase_bound_by(1);

                // Create gl_ViewIndex.
                set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
                set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
                builtin_view_idx_id = var_id;
                mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
            }
        }

        if ((need_tesc_params && (msl_options.multi_patch_workgroup || !has_invocation_id || !has_primitive_id)) ||
            need_grid_params)
        {
            uint32_t type_ptr_id = ir.increase_bound_by(1);

            SPIRType uint_type_ptr;
            uint_type_ptr = get_uint_type();
            uint_type_ptr.pointer = true;
            uint_type_ptr.parent_type = get_uint_type_id();
            uint_type_ptr.storage = StorageClassInput;
            auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
            ptr_type.self = get_uint_type_id();

            if (msl_options.multi_patch_workgroup || need_grid_params)
            {
                uint32_t var_id = ir.increase_bound_by(1);

                // Create gl_GlobalInvocationID.
                set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
                set_decoration(var_id, DecorationBuiltIn, BuiltInGlobalInvocationId);
                builtin_invocation_id_id = var_id;
                mark_implicit_builtin(StorageClassInput, BuiltInGlobalInvocationId, var_id);
            }
            else if (need_tesc_params && !has_invocation_id)
            {
                uint32_t var_id = ir.increase_bound_by(1);

                // Create gl_InvocationID.
                set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
                set_decoration(var_id, DecorationBuiltIn, BuiltInInvocationId);
                builtin_invocation_id_id = var_id;
                mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var_id);
            }

            if (need_tesc_params && !has_primitive_id)
            {
                uint32_t var_id = ir.increase_bound_by(1);

                // Create gl_PrimitiveID.
                set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
                set_decoration(var_id, DecorationBuiltIn, BuiltInPrimitiveId);
                builtin_primitive_id_id = var_id;
                mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var_id);
            }

            if (need_grid_params)
            {
                uint32_t var_id = ir.increase_bound_by(1);

                set<SPIRVariable>(var_id, build_extended_vector_type(get_uint_type_id(), 3), StorageClassInput);
                set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize);
                get_entry_point().interface_variables.push_back(var_id);
                set_name(var_id, "spvStageInputSize");
                builtin_stage_input_size_id = var_id;
            }
        }

        if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id))
        {
            uint32_t offset = ir.increase_bound_by(2);
            uint32_t type_ptr_id = offset;
            uint32_t var_id = offset + 1;

            // Create gl_SubgroupInvocationID.
            SPIRType uint_type_ptr;
            uint_type_ptr = get_uint_type();
            uint_type_ptr.pointer = true;
            uint_type_ptr.parent_type = get_uint_type_id();
            uint_type_ptr.storage = StorageClassInput;
            auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
            ptr_type.self = get_uint_type_id();

            set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
            set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupLocalInvocationId);
            builtin_subgroup_invocation_id_id = var_id;
            mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var_id);
        }

        if (!has_subgroup_size && (need_subgroup_ge_mask || needs_subgroup_size))
        {
            uint32_t offset = ir.increase_bound_by(2);
            uint32_t type_ptr_id = offset;
            uint32_t var_id = offset + 1;

            // Create gl_SubgroupSize.
            SPIRType uint_type_ptr;
            uint_type_ptr = get_uint_type();
            uint_type_ptr.pointer = true;
            uint_type_ptr.parent_type = get_uint_type_id();
            uint_type_ptr.storage = StorageClassInput;
            auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
            ptr_type.self = get_uint_type_id();

            set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
            set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupSize);
            builtin_subgroup_size_id = var_id;
            mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id);
        }

        if (need_dispatch_base || need_vertex_base_params)
        {
            if (workgroup_id_type == 0)
                workgroup_id_type = build_extended_vector_type(get_uint_type_id(), 3);
            uint32_t var_id;
            if (msl_options.supports_msl_version(1, 2))
            {
                // If we have MSL 1.2, we can (ab)use the [[grid_origin]] builtin
                // to convey this information and save a buffer slot.
                uint32_t offset = ir.increase_bound_by(1);
                var_id = offset;

                set<SPIRVariable>(var_id, workgroup_id_type, StorageClassInput);
                set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase);
                get_entry_point().interface_variables.push_back(var_id);
            }
            else
            {
                // Otherwise, we need to fall back to a good ol' fashioned buffer.
                uint32_t offset = ir.increase_bound_by(2);
                var_id = offset;
                uint32_t type_id = offset + 1;

                SPIRType var_type = get<SPIRType>(workgroup_id_type);
                var_type.storage = StorageClassUniform;
                set<SPIRType>(type_id, var_type);

                set<SPIRVariable>(var_id, type_id, StorageClassUniform);
                // This should never match anything.
                set_decoration(var_id, DecorationDescriptorSet, ~(5u));
                set_decoration(var_id, DecorationBinding, msl_options.indirect_params_buffer_index);
                set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
                                        msl_options.indirect_params_buffer_index);
            }
            set_name(var_id, "spvDispatchBase");
            builtin_dispatch_base_id = var_id;
        }

        if (need_sample_mask && !does_shader_write_sample_mask)
        {
            uint32_t offset = ir.increase_bound_by(2);
            uint32_t var_id = offset + 1;

            // Create gl_SampleMask.
            SPIRType uint_type_ptr_out;
            uint_type_ptr_out = get_uint_type();
            uint_type_ptr_out.pointer = true;
            uint_type_ptr_out.parent_type = get_uint_type_id();
            uint_type_ptr_out.storage = StorageClassOutput;

            auto &ptr_out_type = set<SPIRType>(offset, uint_type_ptr_out);
            ptr_out_type.self = get_uint_type_id();
            set<SPIRVariable>(var_id, offset, StorageClassOutput);
            set_decoration(var_id, DecorationBuiltIn, BuiltInSampleMask);
            builtin_sample_mask_id = var_id;
            mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var_id);
        }

        if (need_local_invocation_index && !has_local_invocation_index)
        {
            uint32_t offset = ir.increase_bound_by(2);
            uint32_t type_ptr_id = offset;
            uint32_t var_id = offset + 1;

            // Create gl_LocalInvocationIndex.
            SPIRType uint_type_ptr;
            uint_type_ptr = get_uint_type();
            uint_type_ptr.pointer = true;
            uint_type_ptr.parent_type = get_uint_type_id();
            uint_type_ptr.storage = StorageClassInput;

            auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
            ptr_type.self = get_uint_type_id();
            set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
            set_decoration(var_id, DecorationBuiltIn, BuiltInLocalInvocationIndex);
            builtin_local_invocation_index_id = var_id;
            mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var_id);
        }

        if (need_workgroup_size && !has_workgroup_size)
        {
            uint32_t offset = ir.increase_bound_by(2);
            uint32_t type_ptr_id = offset;
            uint32_t var_id = offset + 1;

            // Create gl_WorkgroupSize.
            uint32_t type_id = build_extended_vector_type(get_uint_type_id(), 3);
            SPIRType uint_type_ptr = get<SPIRType>(type_id);
            uint_type_ptr.pointer = true;
            uint_type_ptr.parent_type = type_id;
            uint_type_ptr.storage = StorageClassInput;

            auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
            ptr_type.self = type_id;
            set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
            set_decoration(var_id, DecorationBuiltIn, BuiltInWorkgroupSize);
            builtin_workgroup_size_id = var_id;
            mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var_id);
        }
    }

    if (needs_swizzle_buffer_def)
    {
        uint32_t var_id = build_constant_uint_array_pointer();
        set_name(var_id, "spvSwizzleConstants");
        // This should never match anything.
        set_decoration(var_id, DecorationDescriptorSet, kSwizzleBufferBinding);
        set_decoration(var_id, DecorationBinding, msl_options.swizzle_buffer_index);
        set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.swizzle_buffer_index);
        swizzle_buffer_id = var_id;
    }

    if (!buffers_requiring_array_length.empty())
    {
        uint32_t var_id = build_constant_uint_array_pointer();
        set_name(var_id, "spvBufferSizeConstants");
        // This should never match anything.
        set_decoration(var_id, DecorationDescriptorSet, kBufferSizeBufferBinding);
        set_decoration(var_id, DecorationBinding, msl_options.buffer_size_buffer_index);
        set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.buffer_size_buffer_index);
        buffer_size_buffer_id = var_id;
    }

    if (needs_view_mask_buffer())
    {
        uint32_t var_id = build_constant_uint_array_pointer();
        set_name(var_id, "spvViewMask");
        // This should never match anything.
        set_decoration(var_id, DecorationDescriptorSet, ~(4u));
        set_decoration(var_id, DecorationBinding, msl_options.view_mask_buffer_index);
        set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.view_mask_buffer_index);
        view_mask_buffer_id = var_id;
    }

    if (!buffers_requiring_dynamic_offset.empty())
    {
        uint32_t var_id = build_constant_uint_array_pointer();
        set_name(var_id, "spvDynamicOffsets");
        // This should never match anything.
        set_decoration(var_id, DecorationDescriptorSet, ~(5u));
        set_decoration(var_id, DecorationBinding, msl_options.dynamic_offsets_buffer_index);
        set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
                                msl_options.dynamic_offsets_buffer_index);
        dynamic_offsets_buffer_id = var_id;
    }
}

// Checks if the specified builtin variable (e.g. gl_InstanceIndex) is marked as active.
// If not, it marks it as active and forces a recompilation.
// This might be used when the optimization of inactive builtins was too optimistic (e.g. when "spvOut" is emitted).
void CompilerMSL::ensure_builtin(spv::StorageClass storage, spv::BuiltIn builtin)
{
    Bitset *active_builtins = nullptr;
    switch (storage)
    {
    case StorageClassInput:
        active_builtins = &active_input_builtins;
        break;

    case StorageClassOutput:
        active_builtins = &active_output_builtins;
        break;

    default:
        break;
    }

    // At this point, the specified builtin variable must have already been declared in the entry point.
    // If not, mark as active and force recompile.
    if (active_builtins != nullptr && !active_builtins->get(builtin))
    {
        active_builtins->set(builtin);
        force_recompile();
    }
}

void CompilerMSL::mark_implicit_builtin(StorageClass storage, BuiltIn builtin, uint32_t id)
{
    Bitset *active_builtins = nullptr;
    switch (storage)
    {
    case StorageClassInput:
        active_builtins = &active_input_builtins;
        break;

    case StorageClassOutput:
        active_builtins = &active_output_builtins;
        break;

    default:
        break;
    }

    assert(active_builtins != nullptr);
    active_builtins->set(builtin);

    auto &var = get_entry_point().interface_variables;
    if (find(begin(var), end(var), VariableID(id)) == end(var))
        var.push_back(id);
}

uint32_t CompilerMSL::build_constant_uint_array_pointer()
{
    uint32_t offset = ir.increase_bound_by(3);
    uint32_t type_ptr_id = offset;
    uint32_t type_ptr_ptr_id = offset + 1;
    uint32_t var_id = offset + 2;

    // Create a buffer to hold extra data, including the swizzle constants.
    SPIRType uint_type_pointer = get_uint_type();
    uint_type_pointer.pointer = true;
    uint_type_pointer.pointer_depth = 1;
    uint_type_pointer.parent_type = get_uint_type_id();
    uint_type_pointer.storage = StorageClassUniform;
    set<SPIRType>(type_ptr_id, uint_type_pointer);
    set_decoration(type_ptr_id, DecorationArrayStride, 4);

    SPIRType uint_type_pointer2 = uint_type_pointer;
    uint_type_pointer2.pointer_depth++;
    uint_type_pointer2.parent_type = type_ptr_id;
    set<SPIRType>(type_ptr_ptr_id, uint_type_pointer2);

    set<SPIRVariable>(var_id, type_ptr_ptr_id, StorageClassUniformConstant);
    return var_id;
}
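
// A note on the result (illustrative): the variable built above surfaces in the
// generated MSL as a plain constant uint array buffer argument, e.g. something
// like "constant uint* spvSwizzleConstants [[buffer(30)]]" when it is used for
// the swizzle constants with the default buffer index.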

static string create_sampler_address(const char *prefix, MSLSamplerAddress addr)
{
    switch (addr)
    {
    case MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE:
        return join(prefix, "address::clamp_to_edge");
    case MSL_SAMPLER_ADDRESS_CLAMP_TO_ZERO:
        return join(prefix, "address::clamp_to_zero");
    case MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER:
        return join(prefix, "address::clamp_to_border");
    case MSL_SAMPLER_ADDRESS_REPEAT:
        return join(prefix, "address::repeat");
    case MSL_SAMPLER_ADDRESS_MIRRORED_REPEAT:
        return join(prefix, "address::mirrored_repeat");
    default:
        SPIRV_CROSS_THROW("Invalid sampler addressing mode.");
    }
}
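
// For example, create_sampler_address("s_", MSL_SAMPLER_ADDRESS_REPEAT) yields
// "s_address::repeat", matching the per-coordinate argument names accepted by
// MSL's constexpr sampler constructor.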

SPIRType &CompilerMSL::get_stage_in_struct_type()
{
    auto &si_var = get<SPIRVariable>(stage_in_var_id);
    return get_variable_data_type(si_var);
}

SPIRType &CompilerMSL::get_stage_out_struct_type()
{
    auto &so_var = get<SPIRVariable>(stage_out_var_id);
    return get_variable_data_type(so_var);
}

SPIRType &CompilerMSL::get_patch_stage_in_struct_type()
{
    auto &si_var = get<SPIRVariable>(patch_stage_in_var_id);
    return get_variable_data_type(si_var);
}

SPIRType &CompilerMSL::get_patch_stage_out_struct_type()
{
    auto &so_var = get<SPIRVariable>(patch_stage_out_var_id);
    return get_variable_data_type(so_var);
}

std::string CompilerMSL::get_tess_factor_struct_name()
{
    if (get_entry_point().flags.get(ExecutionModeTriangles))
        return "MTLTriangleTessellationFactorsHalf";
    return "MTLQuadTessellationFactorsHalf";
}

SPIRType &CompilerMSL::get_uint_type()
{
    return get<SPIRType>(get_uint_type_id());
}

uint32_t CompilerMSL::get_uint_type_id()
{
    if (uint_type_id != 0)
        return uint_type_id;

    uint_type_id = ir.increase_bound_by(1);

    SPIRType type;
    type.basetype = SPIRType::UInt;
    type.width = 32;
    set<SPIRType>(uint_type_id, type);
    return uint_type_id;
}

void CompilerMSL::emit_entry_point_declarations()
{
    // FIXME: Get test coverage here ...
    // Constant arrays of non-primitive types (e.g. matrices) won't link properly into Metal libraries.
    declare_complex_constant_arrays();

    // Emit constexpr samplers here.
    for (auto &samp : constexpr_samplers_by_id)
    {
        auto &var = get<SPIRVariable>(samp.first);
        auto &type = get<SPIRType>(var.basetype);
        if (type.basetype == SPIRType::Sampler)
            add_resource_name(samp.first);

        SmallVector<string> args;
        auto &s = samp.second;

        if (s.coord != MSL_SAMPLER_COORD_NORMALIZED)
            args.push_back("coord::pixel");

        if (s.min_filter == s.mag_filter)
        {
            if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
                args.push_back("filter::linear");
        }
        else
        {
            if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
                args.push_back("min_filter::linear");
            if (s.mag_filter != MSL_SAMPLER_FILTER_NEAREST)
                args.push_back("mag_filter::linear");
        }

        switch (s.mip_filter)
        {
        case MSL_SAMPLER_MIP_FILTER_NONE:
            // Default
            break;
        case MSL_SAMPLER_MIP_FILTER_NEAREST:
            args.push_back("mip_filter::nearest");
            break;
        case MSL_SAMPLER_MIP_FILTER_LINEAR:
            args.push_back("mip_filter::linear");
            break;
        default:
            SPIRV_CROSS_THROW("Invalid mip filter.");
        }

        if (s.s_address == s.t_address && s.s_address == s.r_address)
        {
            if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
                args.push_back(create_sampler_address("", s.s_address));
        }
        else
        {
            if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
                args.push_back(create_sampler_address("s_", s.s_address));
            if (s.t_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
                args.push_back(create_sampler_address("t_", s.t_address));
            if (s.r_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
                args.push_back(create_sampler_address("r_", s.r_address));
        }

        if (s.compare_enable)
        {
            switch (s.compare_func)
            {
            case MSL_SAMPLER_COMPARE_FUNC_ALWAYS:
                args.push_back("compare_func::always");
                break;
            case MSL_SAMPLER_COMPARE_FUNC_NEVER:
                args.push_back("compare_func::never");
                break;
            case MSL_SAMPLER_COMPARE_FUNC_EQUAL:
                args.push_back("compare_func::equal");
                break;
            case MSL_SAMPLER_COMPARE_FUNC_NOT_EQUAL:
                args.push_back("compare_func::not_equal");
                break;
            case MSL_SAMPLER_COMPARE_FUNC_LESS:
                args.push_back("compare_func::less");
                break;
            case MSL_SAMPLER_COMPARE_FUNC_LESS_EQUAL:
                args.push_back("compare_func::less_equal");
                break;
            case MSL_SAMPLER_COMPARE_FUNC_GREATER:
                args.push_back("compare_func::greater");
                break;
            case MSL_SAMPLER_COMPARE_FUNC_GREATER_EQUAL:
                args.push_back("compare_func::greater_equal");
                break;
            default:
                SPIRV_CROSS_THROW("Invalid sampler compare function.");
            }
        }

        if (s.s_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER || s.t_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER ||
            s.r_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER)
        {
            switch (s.border_color)
            {
            case MSL_SAMPLER_BORDER_COLOR_OPAQUE_BLACK:
                args.push_back("border_color::opaque_black");
                break;
            case MSL_SAMPLER_BORDER_COLOR_OPAQUE_WHITE:
                args.push_back("border_color::opaque_white");
                break;
            case MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK:
                args.push_back("border_color::transparent_black");
                break;
            default:
                SPIRV_CROSS_THROW("Invalid sampler border color.");
            }
        }

        if (s.anisotropy_enable)
            args.push_back(join("max_anisotropy(", s.max_anisotropy, ")"));
        if (s.lod_clamp_enable)
        {
            args.push_back(join("lod_clamp(", convert_to_string(s.lod_clamp_min, current_locale_radix_character), ", ",
                                convert_to_string(s.lod_clamp_max, current_locale_radix_character), ")"));
        }

        // If we would emit no arguments, then omit the parentheses entirely. Otherwise,
        // we'll wind up with a "most vexing parse" situation.
        if (args.empty())
            statement("constexpr sampler ",
                      type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
                      ";");
        else
            statement("constexpr sampler ",
                      type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
                      "(", merge(args), ");");
    }
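
    // For a typical combined image sampler, the loop above emits something like
    // (illustrative): "constexpr sampler mySamp(filter::linear, mip_filter::linear,
    // address::repeat);" With no arguments, the parentheses are omitted entirely
    // to sidestep the most vexing parse.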

    // Emit dynamic buffers here.
    for (auto &dynamic_buffer : buffers_requiring_dynamic_offset)
    {
        if (!dynamic_buffer.second.second)
        {
            // Could happen if no buffer was used at requested binding point.
            continue;
        }

        const auto &var = get<SPIRVariable>(dynamic_buffer.second.second);
        uint32_t var_id = var.self;
        const auto &type = get_variable_data_type(var);
        string name = to_name(var.self);
        uint32_t desc_set = get_decoration(var.self, DecorationDescriptorSet);
        uint32_t arg_id = argument_buffer_ids[desc_set];
        uint32_t base_index = dynamic_buffer.second.first;

        if (!type.array.empty())
        {
            // This is complicated, because we need to support arrays of arrays.
            // And it's even worse if the outermost dimension is a runtime array, because now
            // all this complicated goop has to go into the shader itself. (FIXME)
            if (!type.array[type.array.size() - 1])
                SPIRV_CROSS_THROW("Runtime arrays with dynamic offsets are not supported yet.");
            else
            {
                is_using_builtin_array = true;
                statement(get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id), name,
                          type_to_array_glsl(type), " =");

                uint32_t dim = uint32_t(type.array.size());
                uint32_t j = 0;
                for (SmallVector<uint32_t> indices(type.array.size());
                     indices[type.array.size() - 1] < to_array_size_literal(type); j++)
                {
                    while (dim > 0)
                    {
                        begin_scope();
                        --dim;
                    }

                    string arrays;
                    for (uint32_t i = uint32_t(type.array.size()); i; --i)
                        arrays += join("[", indices[i - 1], "]");
                    statement("(", get_argument_address_space(var), " ", type_to_glsl(type), "* ",
                              to_restrict(var_id, false), ")((", get_argument_address_space(var), " char* ",
                              to_restrict(var_id, false), ")", to_name(arg_id), ".", ensure_valid_name(name, "m"),
                              arrays, " + ", to_name(dynamic_offsets_buffer_id), "[", base_index + j, "]),");

                    while (++indices[dim] >= to_array_size_literal(type, dim) && dim < type.array.size() - 1)
                    {
                        end_scope(",");
                        indices[dim++] = 0;
                    }
                }
                end_scope_decl();
                statement_no_indent("");
                is_using_builtin_array = false;
            }
        }
        else
        {
            statement(get_argument_address_space(var), " auto& ", to_restrict(var_id), name, " = *(",
                      get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, false), ")((",
                      get_argument_address_space(var), " char* ", to_restrict(var_id, false), ")", to_name(arg_id), ".",
                      ensure_valid_name(name, "m"), " + ", to_name(dynamic_offsets_buffer_id), "[", base_index, "]);");
        }
    }

    // Emit buffer arrays here.
    for (uint32_t array_id : buffer_arrays)
    {
        const auto &var = get<SPIRVariable>(array_id);
        const auto &type = get_variable_data_type(var);
        const auto &buffer_type = get_variable_element_type(var);
        string name = to_name(array_id);
        statement(get_argument_address_space(var), " ", type_to_glsl(buffer_type), "* ", to_restrict(array_id), name,
                  "[] =");
        begin_scope();
        for (uint32_t i = 0; i < to_array_size_literal(type); ++i)
            statement(name, "_", i, ",");
        end_scope_decl();
        statement_no_indent("");
    }
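    // The loop above expands an array of buffer pointers into a plain C array, e.g.
    // (illustrative): "device Foo* bufs[] = { bufs_0, bufs_1, };" so that the rest
    // of the shader can index the argument uniformly.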
    // For some reason, without this, we end up emitting the arrays twice.
    buffer_arrays.clear();

    // Emit disabled fragment outputs.
    std::sort(disabled_frag_outputs.begin(), disabled_frag_outputs.end());
    for (uint32_t var_id : disabled_frag_outputs)
    {
        auto &var = get<SPIRVariable>(var_id);
        add_local_variable_name(var_id);
        statement(variable_decl(var), ";");
        var.deferred_declaration = false;
    }
}

string CompilerMSL::compile()
{
    replace_illegal_entry_point_names();
    ir.fixup_reserved_names();

    // Do not deal with GLES-isms like precision, older extensions and such.
    options.vulkan_semantics = true;
    options.es = false;
    options.version = 450;
    backend.null_pointer_literal = "nullptr";
    backend.float_literal_suffix = false;
    backend.uint32_t_literal_suffix = true;
    backend.int16_t_literal_suffix = "";
    backend.uint16_t_literal_suffix = "";
    backend.basic_int_type = "int";
    backend.basic_uint_type = "uint";
    backend.basic_int8_type = "char";
    backend.basic_uint8_type = "uchar";
    backend.basic_int16_type = "short";
    backend.basic_uint16_type = "ushort";
    backend.discard_literal = "discard_fragment()";
    backend.demote_literal = "discard_fragment()";
    backend.boolean_mix_function = "select";
    backend.swizzle_is_function = false;
    backend.shared_is_implied = false;
    backend.use_initializer_list = true;
    backend.use_typed_initializer_list = true;
    backend.native_row_major_matrix = false;
    backend.unsized_array_supported = false;
    backend.can_declare_arrays_inline = false;
    backend.allow_truncated_access_chain = true;
    backend.comparison_image_samples_scalar = true;
    backend.native_pointers = true;
    backend.nonuniform_qualifier = "";
    backend.support_small_type_sampling_result = true;
    backend.supports_empty_struct = true;

    // Allow Metal to use the array<T> template unless we force it off.
    backend.can_return_array = !msl_options.force_native_arrays;
    backend.array_is_value_type = !msl_options.force_native_arrays;
    // Arrays which are part of buffer objects are never considered to be native arrays.
    backend.buffer_offset_array_is_value_type = false;

    capture_output_to_buffer = msl_options.capture_output_to_buffer;
    is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;

    // Initialize array here rather than in the constructor (MSVC 2013 workaround).
    for (auto &id : next_metal_resource_ids)
        id = 0;

    fixup_type_alias();
    replace_illegal_names();
    sync_entry_point_aliases_and_names();

    build_function_control_flow_graphs_and_analyze();
    update_active_builtins();
    analyze_image_and_sampler_usage();
    analyze_sampled_image_usage();
    analyze_interlocked_resource_usage();
    preprocess_op_codes();
    build_implicit_builtins();

    fixup_image_load_store_access();

    set_enabled_interface_variables(get_active_interface_variables());
    if (msl_options.force_active_argument_buffer_resources)
        activate_argument_buffer_resources();

    if (swizzle_buffer_id)
        active_interface_variables.insert(swizzle_buffer_id);
    if (buffer_size_buffer_id)
        active_interface_variables.insert(buffer_size_buffer_id);
    if (view_mask_buffer_id)
        active_interface_variables.insert(view_mask_buffer_id);
    if (dynamic_offsets_buffer_id)
        active_interface_variables.insert(dynamic_offsets_buffer_id);
    if (builtin_layer_id)
        active_interface_variables.insert(builtin_layer_id);
    if (builtin_dispatch_base_id && !msl_options.supports_msl_version(1, 2))
        active_interface_variables.insert(builtin_dispatch_base_id);
    if (builtin_sample_mask_id)
        active_interface_variables.insert(builtin_sample_mask_id);

    // Create structs to hold input, output and uniform variables.
    // Do output first to ensure the output struct ("out") is declared at the top of the entry function.
    qual_pos_var_name = "";
    stage_out_var_id = add_interface_block(StorageClassOutput);
    patch_stage_out_var_id = add_interface_block(StorageClassOutput, true);
    stage_in_var_id = add_interface_block(StorageClassInput);
    if (get_execution_model() == ExecutionModelTessellationEvaluation)
        patch_stage_in_var_id = add_interface_block(StorageClassInput, true);

    if (get_execution_model() == ExecutionModelTessellationControl)
        stage_out_ptr_var_id = add_interface_block_pointer(stage_out_var_id, StorageClassOutput);
    if (is_tessellation_shader())
        stage_in_ptr_var_id = add_interface_block_pointer(stage_in_var_id, StorageClassInput);

    // Metal vertex functions that define no output must disable rasterization and return void.
    if (!stage_out_var_id)
        is_rasterization_disabled = true;

    // Convert the use of global variables to recursively-passed function parameters.
    localize_global_variables();
    extract_global_variables_from_functions();

    // Mark any non-stage-in structs to be tightly packed.
    mark_packable_structs();
    reorder_type_alias();

    // Add fixup hooks required by shader inputs and outputs. This needs to happen before
    // the loop, so the hooks aren't added multiple times.
    fix_up_shader_inputs_outputs();

    // If we are using argument buffers, we create argument buffer structures for them here.
    // These buffers will be used in the entry point, not the individual resources.
    if (msl_options.argument_buffers)
    {
        if (!msl_options.supports_msl_version(2, 0))
            SPIRV_CROSS_THROW("Argument buffers can only be used with MSL 2.0 and up.");
        analyze_argument_buffers();
    }

    uint32_t pass_count = 0;
    do
    {
        if (pass_count >= 3)
            SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");

        reset();

        // Start bindings at zero.
        next_metal_resource_index_buffer = 0;
        next_metal_resource_index_texture = 0;
        next_metal_resource_index_sampler = 0;
        for (auto &id : next_metal_resource_ids)
            id = 0;

        // Move constructor for this type is broken on GCC 4.9 ...
        buffer.reset();

        emit_header();
        emit_custom_templates();
        emit_specialization_constants_and_structs();
        emit_resources();
        emit_custom_functions();
        emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());

        pass_count++;
    } while (is_forcing_recompilation());

    return buffer.str();
}
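
// Typical driver-side usage of compile() (a sketch; see spirv_msl.hpp for the
// full option set):
//
//     CompilerMSL msl(std::move(spirv_words));
//     CompilerMSL::Options opts;
//     opts.set_msl_version(2, 1);
//     msl.set_msl_options(opts);
//     std::string source = msl.compile();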
1336
1337 // Register the need to output any custom functions.
preprocess_op_codes()1338 void CompilerMSL::preprocess_op_codes()
1339 {
1340 OpCodePreprocessor preproc(*this);
1341 traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), preproc);
1342
1343 suppress_missing_prototypes = preproc.suppress_missing_prototypes;
1344
1345 if (preproc.uses_atomics)
1346 {
1347 add_header_line("#include <metal_atomic>");
1348 add_pragma_line("#pragma clang diagnostic ignored \"-Wunused-variable\"");
1349 }
1350
1351 // Before MSL 2.1 (2.2 for textures), Metal vertex functions that write to
1352 // resources must disable rasterization and return void.
1353 if (preproc.uses_resource_write)
1354 is_rasterization_disabled = true;
1355
1356 // Tessellation control shaders are run as compute functions in Metal, and so
1357 // must capture their output to a buffer.
1358 if (get_execution_model() == ExecutionModelTessellationControl ||
1359 (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
1360 {
1361 is_rasterization_disabled = true;
1362 capture_output_to_buffer = true;
1363 }
1364
1365 if (preproc.needs_subgroup_invocation_id)
1366 needs_subgroup_invocation_id = true;
1367 if (preproc.needs_subgroup_size)
1368 needs_subgroup_size = true;
1369 // build_implicit_builtins() hasn't run yet, and in fact, this needs to execute
1370 // before then so that gl_SampleID will get added; so we also need to check if
1371 // that function would add gl_FragCoord.
1372 if (preproc.needs_sample_id || msl_options.force_sample_rate_shading ||
1373 (is_sample_rate() && (active_input_builtins.get(BuiltInFragCoord) ||
1374 (need_subpass_input && !msl_options.use_framebuffer_fetch_subpasses))))
1375 needs_sample_id = true;
1376 }
1377
1378 // Move the Private and Workgroup global variables to the entry function.
1379 // Non-constant variables cannot have global scope in Metal.
localize_global_variables()1380 void CompilerMSL::localize_global_variables()
1381 {
1382 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1383 auto iter = global_variables.begin();
1384 while (iter != global_variables.end())
1385 {
1386 uint32_t v_id = *iter;
1387 auto &var = get<SPIRVariable>(v_id);
1388 if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup)
1389 {
1390 if (!variable_is_lut(var))
1391 entry_func.add_local_variable(v_id);
1392 iter = global_variables.erase(iter);
1393 }
1394 else
1395 iter++;
1396 }
1397 }
1398
1399 // For any global variable accessed directly by a function,
1400 // extract that variable and add it as an argument to that function.
extract_global_variables_from_functions()1401 void CompilerMSL::extract_global_variables_from_functions()
1402 {
1403 // Uniforms
1404 unordered_set<uint32_t> global_var_ids;
1405 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1406 if (var.storage == StorageClassInput || var.storage == StorageClassOutput ||
1407 var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
1408 var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)
1409 {
1410 global_var_ids.insert(var.self);
1411 }
1412 });
1413
1414 // Local vars that are declared in the main function and accessed directly by a function
1415 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1416 for (auto &var : entry_func.local_variables)
1417 if (get<SPIRVariable>(var).storage != StorageClassFunction)
1418 global_var_ids.insert(var);
1419
1420 std::set<uint32_t> added_arg_ids;
1421 unordered_set<uint32_t> processed_func_ids;
1422 extract_global_variables_from_function(ir.default_entry_point, added_arg_ids, global_var_ids, processed_func_ids);
1423 }
1424
1425 // MSL does not support the use of global variables for shader input content.
1426 // For any global variable accessed directly by the specified function, extract that variable,
1427 // add it as an argument to that function, and the arg to the added_arg_ids collection.
void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::set<uint32_t> &added_arg_ids,
                                                         unordered_set<uint32_t> &global_var_ids,
                                                         unordered_set<uint32_t> &processed_func_ids)
{
	// Avoid processing a function more than once
	if (processed_func_ids.find(func_id) != processed_func_ids.end())
	{
		// Return the cached set of global variables for this function
		added_arg_ids = function_global_vars[func_id];
		return;
	}

	processed_func_ids.insert(func_id);

	auto &func = get<SPIRFunction>(func_id);

	// Recursively establish global args added to functions on which we depend.
	for (auto block : func.blocks)
	{
		auto &b = get<SPIRBlock>(block);
		for (auto &i : b.ops)
		{
			auto ops = stream(i);
			auto op = static_cast<Op>(i.op);

			switch (op)
			{
			case OpLoad:
			case OpInBoundsAccessChain:
			case OpAccessChain:
			case OpPtrAccessChain:
			case OpArrayLength:
			{
				uint32_t base_id = ops[2];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);

				// Use Metal's native frame-buffer fetch API for subpass inputs.
				auto &type = get<SPIRType>(ops[0]);
				if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
				    (!msl_options.use_framebuffer_fetch_subpasses))
				{
					// Implicitly reads gl_FragCoord.
					assert(builtin_frag_coord_id != 0);
					added_arg_ids.insert(builtin_frag_coord_id);
					if (msl_options.multiview)
					{
						// Implicitly reads gl_ViewIndex.
						assert(builtin_view_idx_id != 0);
						added_arg_ids.insert(builtin_view_idx_id);
					}
					else if (msl_options.arrayed_subpass_input)
					{
						// Implicitly reads gl_Layer.
						assert(builtin_layer_id != 0);
						added_arg_ids.insert(builtin_layer_id);
					}
				}

				break;
			}

			case OpFunctionCall:
			{
				// First see if any of the function call args are globals
				for (uint32_t arg_idx = 3; arg_idx < i.length; arg_idx++)
				{
					uint32_t arg_id = ops[arg_idx];
					if (global_var_ids.find(arg_id) != global_var_ids.end())
						added_arg_ids.insert(arg_id);
				}

				// Then recurse into the function itself to extract globals used internally in the function
				uint32_t inner_func_id = ops[2];
				std::set<uint32_t> inner_func_args;
				extract_global_variables_from_function(inner_func_id, inner_func_args, global_var_ids,
				                                       processed_func_ids);
				added_arg_ids.insert(inner_func_args.begin(), inner_func_args.end());
				break;
			}

			case OpStore:
			{
				uint32_t base_id = ops[0];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);

				uint32_t rvalue_id = ops[1];
				if (global_var_ids.find(rvalue_id) != global_var_ids.end())
					added_arg_ids.insert(rvalue_id);

				break;
			}

			case OpSelect:
			{
				uint32_t base_id = ops[3];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				base_id = ops[4];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				break;
			}

			// Emulate texture2D atomic operations
			case OpImageTexelPointer:
			{
				// When using the pointer, we need to know which variable it is actually loaded from.
				uint32_t base_id = ops[2];
				auto *var = maybe_get_backing_variable(base_id);
				if (var && atomic_image_vars.count(var->self))
				{
					if (global_var_ids.find(base_id) != global_var_ids.end())
						added_arg_ids.insert(base_id);
				}
				break;
			}

			case OpExtInst:
			{
				uint32_t extension_set = ops[2];
				if (get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
				{
					auto op_450 = static_cast<GLSLstd450>(ops[3]);
					switch (op_450)
					{
					case GLSLstd450InterpolateAtCentroid:
					case GLSLstd450InterpolateAtSample:
					case GLSLstd450InterpolateAtOffset:
					{
						// For these, we really need the stage-in block. It is theoretically possible to pass the
						// interpolant object, but a) doing so would require us to create an entirely new variable
						// with Interpolant type, and b) if we have a struct or array, handling all the members and
						// elements could get unwieldy fast.
						added_arg_ids.insert(stage_in_var_id);
						break;
					}
					default:
						break;
					}
				}
				break;
			}

			case OpGroupNonUniformInverseBallot:
			{
				added_arg_ids.insert(builtin_subgroup_invocation_id_id);
				break;
			}

			case OpGroupNonUniformBallotFindLSB:
			case OpGroupNonUniformBallotFindMSB:
			{
				added_arg_ids.insert(builtin_subgroup_size_id);
				break;
			}

			case OpGroupNonUniformBallotBitCount:
			{
				auto operation = static_cast<GroupOperation>(ops[3]);
				switch (operation)
				{
				case GroupOperationReduce:
					added_arg_ids.insert(builtin_subgroup_size_id);
					break;
				case GroupOperationInclusiveScan:
				case GroupOperationExclusiveScan:
					added_arg_ids.insert(builtin_subgroup_invocation_id_id);
					break;
				default:
					break;
				}
				break;
			}

			default:
				break;
			}

			// TODO: Add all other operations which can affect memory.
			// We should consider a more unified system here to reduce boiler-plate.
			// This kind of analysis is done in several places ...
		}
	}

	function_global_vars[func_id] = added_arg_ids;

	// Add the global variables as arguments to the function
	if (func_id != ir.default_entry_point)
	{
		bool added_in = false;
		bool added_out = false;
		for (uint32_t arg_id : added_arg_ids)
		{
			auto &var = get<SPIRVariable>(arg_id);
			uint32_t type_id = var.basetype;
			auto *p_type = &get<SPIRType>(type_id);
			BuiltIn bi_type = BuiltIn(get_decoration(arg_id, DecorationBuiltIn));

			if (((is_tessellation_shader() && var.storage == StorageClassInput) ||
			     (get_execution_model() == ExecutionModelTessellationControl && var.storage == StorageClassOutput)) &&
			    !(has_decoration(arg_id, DecorationPatch) || is_patch_block(*p_type)) &&
			    (!is_builtin_variable(var) || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
			     bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
			     p_type->basetype == SPIRType::Struct))
			{
				// Tessellation control shaders see inputs and per-vertex outputs as arrays.
				// Similarly, tessellation evaluation shaders see per-vertex inputs as arrays.
				// We collected them into a structure; we must pass the array of this
				// structure to the function.
				std::string name;
				if (var.storage == StorageClassInput)
				{
					if (added_in)
						continue;
					name = "gl_in";
					arg_id = stage_in_ptr_var_id;
					added_in = true;
				}
				else if (var.storage == StorageClassOutput)
				{
					if (added_out)
						continue;
					name = "gl_out";
					arg_id = stage_out_ptr_var_id;
					added_out = true;
				}
				type_id = get<SPIRVariable>(arg_id).basetype;
				uint32_t next_id = ir.increase_bound_by(1);
				func.add_parameter(type_id, next_id, true);
				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);

				set_name(next_id, name);
			}
			else if (is_builtin_variable(var) && p_type->basetype == SPIRType::Struct)
			{
				// Get the pointee type
				type_id = get_pointee_type_id(type_id);
				p_type = &get<SPIRType>(type_id);

				uint32_t mbr_idx = 0;
				for (auto &mbr_type_id : p_type->member_types)
				{
					BuiltIn builtin = BuiltInMax;
					bool is_builtin = is_member_builtin(*p_type, mbr_idx, &builtin);
					if (is_builtin && has_active_builtin(builtin, var.storage))
					{
						// Add an arg variable with the same type and decorations as the member
						uint32_t next_ids = ir.increase_bound_by(2);
						uint32_t ptr_type_id = next_ids + 0;
						uint32_t var_id = next_ids + 1;

						// Make sure we have an actual pointer type,
						// so that we will get the appropriate address space when declaring these builtins.
						auto &ptr = set<SPIRType>(ptr_type_id, get<SPIRType>(mbr_type_id));
						ptr.self = mbr_type_id;
						ptr.storage = var.storage;
						ptr.pointer = true;
						ptr.parent_type = mbr_type_id;

						func.add_parameter(mbr_type_id, var_id, true);
						set<SPIRVariable>(var_id, ptr_type_id, StorageClassFunction);
						ir.meta[var_id].decoration = ir.meta[type_id].members[mbr_idx];
					}
					mbr_idx++;
				}
			}
			else
			{
				uint32_t next_id = ir.increase_bound_by(1);
				func.add_parameter(type_id, next_id, true);
				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);

				// Ensure the existing variable has a valid name and the new variable has all the same meta info
				set_name(arg_id, ensure_valid_name(to_name(arg_id), "v"));
				ir.meta[next_id] = ir.meta[arg_id];
			}
		}
	}
}

// For all variables that are some form of non-input-output interface block, mark that all the structs
// that are recursively contained within the type referenced by that variable should be packed tightly.
void CompilerMSL::mark_packable_structs()
{
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
		if (var.storage != StorageClassFunction && !is_hidden_variable(var))
		{
			auto &type = this->get<SPIRType>(var.basetype);
			if (type.pointer &&
			    (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant ||
			     type.storage == StorageClassPushConstant || type.storage == StorageClassStorageBuffer) &&
			    (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
				mark_as_packable(type);
		}
	});
}

// If the specified type is a struct, it and any nested structs
// are marked as packable with the SPIRVCrossDecorationBufferBlockRepacked decoration.
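// For example, a uniform block 'UBO { Light lights[4]; }' marks both UBO and
// the nested Light struct, so both are later considered for tight packing
// when the buffer layout is emitted.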
void CompilerMSL::mark_as_packable(SPIRType &type)
{
	// If this is not the base type (e.g. it's a pointer or array), tunnel down
	if (type.parent_type)
	{
		mark_as_packable(get<SPIRType>(type.parent_type));
		return;
	}

	if (type.basetype == SPIRType::Struct)
	{
		set_extended_decoration(type.self, SPIRVCrossDecorationBufferBlockRepacked);

		// Recurse
		uint32_t mbr_cnt = uint32_t(type.member_types.size());
		for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
		{
			uint32_t mbr_type_id = type.member_types[mbr_idx];
			auto &mbr_type = get<SPIRType>(mbr_type_id);
			mark_as_packable(mbr_type);
			if (mbr_type.type_alias)
			{
				auto &mbr_type_alias = get<SPIRType>(mbr_type.type_alias);
				mark_as_packable(mbr_type_alias);
			}
		}
	}
}

// If a shader input exists at the location, it is marked as being used by this shader
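// Matrices and arrays consume one location per column/element, so e.g. a
// float4x4 input at location 4 marks locations 4-7 as in use, and an array of
// two mat3s starting at location 6 marks locations 6-11.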
void CompilerMSL::mark_location_as_used_by_shader(uint32_t location, const SPIRType &type, StorageClass storage)
{
	if (storage != StorageClassInput)
		return;
	if (is_array(type))
	{
		uint32_t dim = 1;
		for (uint32_t i = 0; i < type.array.size(); i++)
			dim *= to_array_size_literal(type, i);
		for (uint32_t i = 0; i < dim; i++)
		{
			if (is_matrix(type))
			{
				for (uint32_t j = 0; j < type.columns; j++)
					inputs_in_use.insert(location++);
			}
			else
				inputs_in_use.insert(location++);
		}
	}
	else if (is_matrix(type))
	{
		for (uint32_t i = 0; i < type.columns; i++)
			inputs_in_use.insert(location + i);
	}
	else
		inputs_in_use.insert(location);
}

uint32_t CompilerMSL::get_target_components_for_fragment_location(uint32_t location) const
{
	auto itr = fragment_output_components.find(location);
	if (itr == end(fragment_output_components))
		return 4;
	else
		return itr->second;
}

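// Build a variant of the given type with the requested number of vector
// components (and optionally a different base type), preserving any array and
// pointer wrappers. For example (sketch), widening a float3 to 4 components
// yields a new float4 type whose parent_type points back at the original.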
uint32_t CompilerMSL::build_extended_vector_type(uint32_t type_id, uint32_t components, SPIRType::BaseType basetype)
{
	uint32_t new_type_id = ir.increase_bound_by(1);
	auto &old_type = get<SPIRType>(type_id);
	auto *type = &set<SPIRType>(new_type_id, old_type);
	type->vecsize = components;
	if (basetype != SPIRType::Unknown)
		type->basetype = basetype;
	type->self = new_type_id;
	type->parent_type = type_id;
	type->array.clear();
	type->array_size_literal.clear();
	type->pointer = false;

	if (is_array(old_type))
	{
		uint32_t array_type_id = ir.increase_bound_by(1);
		type = &set<SPIRType>(array_type_id, *type);
		type->parent_type = new_type_id;
		type->array = old_type.array;
		type->array_size_literal = old_type.array_size_literal;
		new_type_id = array_type_id;
	}

	if (old_type.pointer)
	{
		uint32_t ptr_type_id = ir.increase_bound_by(1);
		type = &set<SPIRType>(ptr_type_id, *type);
		type->self = new_type_id;
		type->parent_type = new_type_id;
		type->storage = old_type.storage;
		type->pointer = true;
		new_type_id = ptr_type_id;
	}

	return new_type_id;
}

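// Wrap the given type in Metal's pull-model interpolant template, which is how
// a [[stage_in]] member must be declared when the shader interpolates it
// manually. E.g. a float4 input becomes (roughly)
// interpolant<float4, interpolation::perspective> in the generated MSL.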
uint32_t CompilerMSL::build_msl_interpolant_type(uint32_t type_id, bool is_noperspective)
{
	uint32_t new_type_id = ir.increase_bound_by(1);
	SPIRType &type = set<SPIRType>(new_type_id, get<SPIRType>(type_id));
	type.basetype = SPIRType::Interpolant;
	type.parent_type = type_id;
	// In Metal, the pull-model interpolant type encodes perspective-vs-no-perspective in the type itself.
	// Add this decoration so we know which argument to pass to the template.
	if (is_noperspective)
		set_decoration(new_type_id, DecorationNoPerspective);
	return new_type_id;
}

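// Add a plain (non-struct, non-composite) variable to the stage interface
// block, copying its location, component and interpolation decorations onto
// the new member, and redirecting references to the variable through the block.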
void CompilerMSL::add_plain_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
                                                        SPIRType &ib_type, SPIRVariable &var, InterfaceBlockMeta &meta)
{
	bool is_builtin = is_builtin_variable(var);
	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
	bool is_flat = has_decoration(var.self, DecorationFlat);
	bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
	bool is_centroid = has_decoration(var.self, DecorationCentroid);
	bool is_sample = has_decoration(var.self, DecorationSample);

	// Add a reference to the variable type to the interface struct.
	uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
	uint32_t type_id = ensure_correct_builtin_type(var.basetype, builtin);
	var.basetype = type_id;

	type_id = get_pointee_type_id(var.basetype);
	if (meta.strip_array && is_array(get<SPIRType>(type_id)))
		type_id = get<SPIRType>(type_id).parent_type;
	auto &type = get<SPIRType>(type_id);
	uint32_t target_components = 0;
	uint32_t type_components = type.vecsize;

	bool padded_output = false;
	bool padded_input = false;
	uint32_t start_component = 0;

	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);

	// Deal with Component decorations.
	InterfaceBlockMeta::LocationMeta *location_meta = nullptr;
	if (has_decoration(var.self, DecorationLocation))
	{
		auto location_meta_itr = meta.location_meta.find(get_decoration(var.self, DecorationLocation));
		if (location_meta_itr != end(meta.location_meta))
			location_meta = &location_meta_itr->second;
	}

	bool pad_fragment_output = has_decoration(var.self, DecorationLocation) &&
	                           msl_options.pad_fragment_output_components &&
	                           get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;

	// Check if we need to pad fragment output to match a certain number of components.
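	// E.g. if the API declared 4 components for this render-target location but
	// the shader only writes a float3, the interface member is widened to float4
	// and a fixup hook stores through a swizzle, roughly: out.m.xyz = my_color;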
	if (location_meta)
	{
		start_component = get_decoration(var.self, DecorationComponent);
		uint32_t num_components = location_meta->num_components;
		if (pad_fragment_output)
		{
			uint32_t locn = get_decoration(var.self, DecorationLocation);
			num_components = std::max(num_components, get_target_components_for_fragment_location(locn));
		}

		if (location_meta->ib_index != ~0u)
		{
			// We have already declared the variable. Just emit an early-declared variable and fixup as needed.
			entry_func.add_local_variable(var.self);
			vars_needing_early_declaration.push_back(var.self);

			if (var.storage == StorageClassInput)
			{
				uint32_t ib_index = location_meta->ib_index;
				entry_func.fixup_hooks_in.push_back([=, &var]() {
					statement(to_name(var.self), " = ", ib_var_ref, ".", to_member_name(ib_type, ib_index),
					          vector_swizzle(type_components, start_component), ";");
				});
			}
			else
			{
				uint32_t ib_index = location_meta->ib_index;
				entry_func.fixup_hooks_out.push_back([=, &var]() {
					statement(ib_var_ref, ".", to_member_name(ib_type, ib_index),
					          vector_swizzle(type_components, start_component), " = ", to_name(var.self), ";");
				});
			}
			return;
		}
		else
		{
			location_meta->ib_index = uint32_t(ib_type.member_types.size());
			type_id = build_extended_vector_type(type_id, num_components);
			if (var.storage == StorageClassInput)
				padded_input = true;
			else
				padded_output = true;
		}
	}
	else if (pad_fragment_output)
	{
		uint32_t locn = get_decoration(var.self, DecorationLocation);
		target_components = get_target_components_for_fragment_location(locn);
		if (type_components < target_components)
		{
			// Make a new type here.
			type_id = build_extended_vector_type(type_id, target_components);
			padded_output = true;
		}
	}

	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
		ib_type.member_types.push_back(build_msl_interpolant_type(type_id, is_noperspective));
	else
		ib_type.member_types.push_back(type_id);

	// Give the member a name
	string mbr_name = ensure_valid_name(to_expression(var.self), "m");
	set_member_name(ib_type.self, ib_mbr_idx, mbr_name);

	// Update the original variable reference to include the structure reference
	string qual_var_name = ib_var_ref + "." + mbr_name;
	// If using pull-model interpolation, we need to add a call to the correct interpolation method.
	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
	{
		if (is_centroid)
			qual_var_name += ".interpolate_at_centroid()";
		else if (is_sample)
			qual_var_name += join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
		else
			qual_var_name += ".interpolate_at_center()";
	}

	if (padded_output || padded_input)
	{
		entry_func.add_local_variable(var.self);
		vars_needing_early_declaration.push_back(var.self);

		if (padded_output)
		{
			entry_func.fixup_hooks_out.push_back([=, &var]() {
				statement(qual_var_name, vector_swizzle(type_components, start_component), " = ", to_name(var.self),
				          ";");
			});
		}
		else
		{
			entry_func.fixup_hooks_in.push_back([=, &var]() {
				statement(to_name(var.self), " = ", qual_var_name, vector_swizzle(type_components, start_component),
				          ";");
			});
		}
	}
	else if (!meta.strip_array)
		ir.meta[var.self].decoration.qualified_alias = qual_var_name;

	if (var.storage == StorageClassOutput && var.initializer != ID(0))
	{
		if (padded_output || padded_input)
		{
			entry_func.fixup_hooks_in.push_back(
			    [=, &var]() { statement(to_name(var.self), " = ", to_expression(var.initializer), ";"); });
		}
		else
		{
			if (meta.strip_array)
			{
				entry_func.fixup_hooks_in.push_back([=, &var]() {
					uint32_t index = get_extended_decoration(var.self, SPIRVCrossDecorationInterfaceMemberIndex);
					statement(to_expression(stage_out_ptr_var_id), "[",
					          builtin_to_glsl(BuiltInInvocationId, StorageClassInput), "].",
					          to_member_name(ib_type, index), " = ", to_expression(var.initializer), "[",
					          builtin_to_glsl(BuiltInInvocationId, StorageClassInput), "];");
				});
			}
			else
			{
				entry_func.fixup_hooks_in.push_back([=, &var]() {
					statement(qual_var_name, " = ", to_expression(var.initializer), ";");
				});
			}
		}
	}

	// Copy the variable location from the original variable to the member
	if (get_decoration_bitset(var.self).get(DecorationLocation))
	{
		uint32_t locn = get_decoration(var.self, DecorationLocation);
		if (storage == StorageClassInput)
		{
			type_id = ensure_correct_input_type(var.basetype, locn, location_meta ? location_meta->num_components : 0);
			if (!location_meta)
				var.basetype = type_id;

			type_id = get_pointee_type_id(type_id);
			if (meta.strip_array && is_array(get<SPIRType>(type_id)))
				type_id = get<SPIRType>(type_id).parent_type;
			if (pull_model_inputs.count(var.self))
				ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id, is_noperspective);
			else
				ib_type.member_types[ib_mbr_idx] = type_id;
		}
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
		mark_location_as_used_by_shader(locn, get<SPIRType>(type_id), storage);
	}
	else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
	{
		uint32_t locn = inputs_by_builtin[builtin].location;
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
		mark_location_as_used_by_shader(locn, type, storage);
	}

	if (!location_meta)
	{
		if (get_decoration_bitset(var.self).get(DecorationComponent))
		{
			uint32_t component = get_decoration(var.self, DecorationComponent);
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, component);
		}
	}

	if (get_decoration_bitset(var.self).get(DecorationIndex))
	{
		uint32_t index = get_decoration(var.self, DecorationIndex);
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
	}

	// Mark the member as builtin if needed
	if (is_builtin)
	{
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
		if (builtin == BuiltInPosition && storage == StorageClassOutput)
			qual_pos_var_name = qual_var_name;
	}

	// Copy interpolation decorations if needed
	if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
	{
		if (is_flat)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
		if (is_noperspective)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
		if (is_centroid)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
		if (is_sample)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
	}

	// If we have location meta, there is no unique OrigID. We won't need it, since we flatten/unflatten
	// the variable to stack anyways here.
	if (!location_meta)
		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
}

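// Add a composite (matrix or single-dimensional array) variable to the stage
// interface block. MSL stage I/O cannot carry matrices or arrays directly, so
// each column/element becomes its own member; e.g. an input 'float4 foo[2]'
// is emitted (roughly) as members foo_0 and foo_1 at consecutive locations.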
void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
                                                            SPIRType &ib_type, SPIRVariable &var,
                                                            InterfaceBlockMeta &meta)
{
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
	uint32_t elem_cnt = 0;

	if (is_matrix(var_type))
	{
		if (is_array(var_type))
			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");

		elem_cnt = var_type.columns;
	}
	else if (is_array(var_type))
	{
		if (var_type.array.size() != 1)
			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");

		elem_cnt = to_array_size_literal(var_type);
	}

	bool is_builtin = is_builtin_variable(var);
	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
	bool is_flat = has_decoration(var.self, DecorationFlat);
	bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
	bool is_centroid = has_decoration(var.self, DecorationCentroid);
	bool is_sample = has_decoration(var.self, DecorationSample);

	auto *usable_type = &var_type;
	if (usable_type->pointer)
		usable_type = &get<SPIRType>(usable_type->parent_type);
	while (is_array(*usable_type) || is_matrix(*usable_type))
		usable_type = &get<SPIRType>(usable_type->parent_type);

	// If a builtin, force it to have the proper name.
	if (is_builtin)
		set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));

	bool flatten_from_ib_var = false;
	string flatten_from_ib_mbr_name;

	if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
	{
		// Also declare [[clip_distance]] attribute here.
		uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
		ib_type.member_types.push_back(get_variable_data_type_id(var));
		set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);

		flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
		set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);

		// When we flatten, we flatten directly from the "out" struct,
		// not from a function variable.
		flatten_from_ib_var = true;

		if (!msl_options.enable_clip_distance_user_varying)
			return;
	}
	else if (!meta.strip_array)
	{
		// Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
		entry_func.add_local_variable(var.self);
		// We need to declare the variable early and at entry-point scope.
		vars_needing_early_declaration.push_back(var.self);
	}

	for (uint32_t i = 0; i < elem_cnt; i++)
	{
		// Add a reference to the variable type to the interface struct.
		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());

		uint32_t target_components = 0;
		bool padded_output = false;
		uint32_t type_id = usable_type->self;

		// Check if we need to pad fragment output to match a certain number of components.
		if (get_decoration_bitset(var.self).get(DecorationLocation) && msl_options.pad_fragment_output_components &&
		    get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput)
		{
			uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
			target_components = get_target_components_for_fragment_location(locn);
			if (usable_type->vecsize < target_components)
			{
				// Make a new type here.
				type_id = build_extended_vector_type(usable_type->self, target_components);
				padded_output = true;
			}
		}

		if (storage == StorageClassInput && pull_model_inputs.count(var.self))
			ib_type.member_types.push_back(build_msl_interpolant_type(get_pointee_type_id(type_id), is_noperspective));
		else
			ib_type.member_types.push_back(get_pointee_type_id(type_id));

		// Give the member a name
		string mbr_name = ensure_valid_name(join(to_expression(var.self), "_", i), "m");
		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);

		// There is no qualified alias since we need to flatten the internal array on return.
		if (get_decoration_bitset(var.self).get(DecorationLocation))
		{
			uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
			if (storage == StorageClassInput)
			{
				var.basetype = ensure_correct_input_type(var.basetype, locn);
				uint32_t mbr_type_id = ensure_correct_input_type(usable_type->self, locn);
				if (storage == StorageClassInput && pull_model_inputs.count(var.self))
					ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
				else
					ib_type.member_types[ib_mbr_idx] = mbr_type_id;
			}
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, *usable_type, storage);
		}
		else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
		{
			uint32_t locn = inputs_by_builtin[builtin].location + i;
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, *usable_type, storage);
		}
		else if (is_builtin && builtin == BuiltInClipDistance)
		{
			// Declare the ClipDistance as [[user(clipN)]].
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, i);
		}

		if (get_decoration_bitset(var.self).get(DecorationIndex))
		{
			uint32_t index = get_decoration(var.self, DecorationIndex);
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
		}

		if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
		{
			// Copy interpolation decorations if needed
			if (is_flat)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
			if (is_noperspective)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
			if (is_centroid)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
			if (is_sample)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
		}

		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);

		// Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
		if (!meta.strip_array)
		{
			switch (storage)
			{
			case StorageClassInput:
				entry_func.fixup_hooks_in.push_back([=, &var]() {
					if (pull_model_inputs.count(var.self))
					{
						string lerp_call;
						if (is_centroid)
							lerp_call = ".interpolate_at_centroid()";
						else if (is_sample)
							lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
						else
							lerp_call = ".interpolate_at_center()";
						statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, lerp_call, ";");
					}
					else
					{
						statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, ";");
					}
				});
				break;

			case StorageClassOutput:
				entry_func.fixup_hooks_out.push_back([=, &var]() {
					if (padded_output)
					{
						auto &padded_type = this->get<SPIRType>(type_id);
						statement(
						    ib_var_ref, ".", mbr_name, " = ",
						    remap_swizzle(padded_type, usable_type->vecsize, join(to_name(var.self), "[", i, "]")),
						    ";");
					}
					else if (flatten_from_ib_var)
						statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
						          "];");
					else
						statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), "[", i, "];");
				});
				break;

			default:
				break;
			}
		}
	}
}

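// Compute the effective location of a block member when the block itself
// carries the Location decoration: each preceding member advances the running
// location by its column count multiplied by its total array size (restarting
// wherever a member carries its own explicit location).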
uint32_t CompilerMSL::get_accumulated_member_location(const SPIRVariable &var, uint32_t mbr_idx, bool strip_array)
{
	auto &type = strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
	uint32_t location = get_decoration(var.self, DecorationLocation);

	for (uint32_t i = 0; i < mbr_idx; i++)
	{
		auto &mbr_type = get<SPIRType>(type.member_types[i]);

		// Start counting from any place we have a new location decoration.
		if (has_member_decoration(type.self, i, DecorationLocation))
			location = get_member_decoration(type.self, i, DecorationLocation);

		uint32_t location_count = 1;

		if (mbr_type.columns > 1)
			location_count = mbr_type.columns;

		if (!mbr_type.array.empty())
			for (uint32_t j = 0; j < uint32_t(mbr_type.array.size()); j++)
				location_count *= to_array_size_literal(mbr_type, j);

		location += location_count;
	}

	return location;
}

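// Add a composite (matrix or array) struct member to the stage interface
// block. As with standalone composites, each column/element becomes its own
// member, named after the qualified member name with an _i suffix.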
void CompilerMSL::add_composite_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
                                                                   SPIRType &ib_type, SPIRVariable &var,
                                                                   uint32_t mbr_idx, InterfaceBlockMeta &meta)
{
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);

	BuiltIn builtin = BuiltInMax;
	bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
	bool is_flat =
	    has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
	bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
	                        has_decoration(var.self, DecorationNoPerspective);
	bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
	                   has_decoration(var.self, DecorationCentroid);
	bool is_sample =
	    has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);

	uint32_t mbr_type_id = var_type.member_types[mbr_idx];
	auto &mbr_type = get<SPIRType>(mbr_type_id);
	uint32_t elem_cnt = 0;

	if (is_matrix(mbr_type))
	{
		if (is_array(mbr_type))
			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");

		elem_cnt = mbr_type.columns;
	}
	else if (is_array(mbr_type))
	{
		if (mbr_type.array.size() != 1)
			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");

		elem_cnt = to_array_size_literal(mbr_type);
	}

	auto *usable_type = &mbr_type;
	if (usable_type->pointer)
		usable_type = &get<SPIRType>(usable_type->parent_type);
	while (is_array(*usable_type) || is_matrix(*usable_type))
		usable_type = &get<SPIRType>(usable_type->parent_type);

	bool flatten_from_ib_var = false;
	string flatten_from_ib_mbr_name;

	if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
	{
		// Also declare [[clip_distance]] attribute here.
		uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
		ib_type.member_types.push_back(mbr_type_id);
		set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);

		flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
		set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);

		// When we flatten, we flatten directly from the "out" struct,
		// not from a function variable.
		flatten_from_ib_var = true;

		if (!msl_options.enable_clip_distance_user_varying)
			return;
	}

	for (uint32_t i = 0; i < elem_cnt; i++)
	{
		// Add a reference to the variable type to the interface struct.
		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
		if (storage == StorageClassInput && pull_model_inputs.count(var.self))
			ib_type.member_types.push_back(build_msl_interpolant_type(usable_type->self, is_noperspective));
		else
			ib_type.member_types.push_back(usable_type->self);

		// Give the member a name
		string mbr_name = ensure_valid_name(join(to_qualified_member_name(var_type, mbr_idx), "_", i), "m");
		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);

		if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
		{
			uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation) + i;
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, *usable_type, storage);
		}
		else if (has_decoration(var.self, DecorationLocation))
		{
			uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array) + i;
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, *usable_type, storage);
		}
		else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
		{
			uint32_t locn = inputs_by_builtin[builtin].location + i;
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, *usable_type, storage);
		}
		else if (is_builtin && builtin == BuiltInClipDistance)
		{
			// Declare the ClipDistance as [[user(clipN)]].
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, i);
		}

		if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
			SPIRV_CROSS_THROW("DecorationComponent on matrices and arrays makes little sense.");

		if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
		{
			// Copy interpolation decorations if needed
			if (is_flat)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
			if (is_noperspective)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
			if (is_centroid)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
			if (is_sample)
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
		}

		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);

		// Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
		if (!meta.strip_array)
		{
			switch (storage)
			{
			case StorageClassInput:
				entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
					if (pull_model_inputs.count(var.self))
					{
						string lerp_call;
						if (is_centroid)
							lerp_call = ".interpolate_at_centroid()";
						else if (is_sample)
							lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
						else
							lerp_call = ".interpolate_at_center()";
						statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
						          ".", mbr_name, lerp_call, ";");
					}
					else
					{
						statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
						          ".", mbr_name, ";");
					}
				});
				break;

			case StorageClassOutput:
				entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
					if (flatten_from_ib_var)
					{
						statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
						          "];");
					}
					else
					{
						statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), ".",
						          to_member_name(var_type, mbr_idx), "[", i, "];");
					}
				});
				break;

			default:
				break;
			}
		}
	}
}

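// Add a plain (non-composite) struct member to the stage interface block,
// copying its location, component and interpolation decorations, and emitting
// flatten/unflatten fixups between the struct and the block as needed.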
void CompilerMSL::add_plain_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
                                                               SPIRType &ib_type, SPIRVariable &var, uint32_t mbr_idx,
                                                               InterfaceBlockMeta &meta)
{
	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);

	BuiltIn builtin = BuiltInMax;
	bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
	bool is_flat =
	    has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
	bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
	                        has_decoration(var.self, DecorationNoPerspective);
	bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
	                   has_decoration(var.self, DecorationCentroid);
	bool is_sample =
	    has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);

	// Add a reference to the member to the interface struct.
	uint32_t mbr_type_id = var_type.member_types[mbr_idx];
	uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
	mbr_type_id = ensure_correct_builtin_type(mbr_type_id, builtin);
	var_type.member_types[mbr_idx] = mbr_type_id;
	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
		ib_type.member_types.push_back(build_msl_interpolant_type(mbr_type_id, is_noperspective));
	else
		ib_type.member_types.push_back(mbr_type_id);

	// Give the member a name
	string mbr_name = ensure_valid_name(to_qualified_member_name(var_type, mbr_idx), "m");
	set_member_name(ib_type.self, ib_mbr_idx, mbr_name);

	// Update the original variable reference to include the structure reference
	string qual_var_name = ib_var_ref + "." + mbr_name;
	// If using pull-model interpolation, we need to add a call to the correct interpolation method.
	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
	{
		if (is_centroid)
			qual_var_name += ".interpolate_at_centroid()";
		else if (is_sample)
			qual_var_name += join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
		else
			qual_var_name += ".interpolate_at_center()";
	}

	bool flatten_stage_out = false;

	if (is_builtin && !meta.strip_array)
	{
		// For the builtin gl_PerVertex, we cannot treat it as a block anyways,
		// so redirect to qualified name.
		set_member_qualified_name(var_type.self, mbr_idx, qual_var_name);
	}
	else if (!meta.strip_array)
	{
		// Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
		switch (storage)
		{
		case StorageClassInput:
			entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
				statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), " = ", qual_var_name, ";");
			});
			break;

		case StorageClassOutput:
			flatten_stage_out = true;
			entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
				statement(qual_var_name, " = ", to_name(var.self), ".", to_member_name(var_type, mbr_idx), ";");
			});
			break;

		default:
			break;
		}
	}

	// Copy the variable location from the original variable to the member
	if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
	{
		uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation);
		if (storage == StorageClassInput)
		{
			mbr_type_id = ensure_correct_input_type(mbr_type_id, locn);
			var_type.member_types[mbr_idx] = mbr_type_id;
			if (storage == StorageClassInput && pull_model_inputs.count(var.self))
				ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
			else
				ib_type.member_types[ib_mbr_idx] = mbr_type_id;
		}
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
		mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
	}
	else if (has_decoration(var.self, DecorationLocation))
	{
		// The block itself might have a location and in this case, all members of the block
		// receive incrementing locations.
		uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array);
		if (storage == StorageClassInput)
		{
			mbr_type_id = ensure_correct_input_type(mbr_type_id, locn);
			var_type.member_types[mbr_idx] = mbr_type_id;
			if (storage == StorageClassInput && pull_model_inputs.count(var.self))
				ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
			else
				ib_type.member_types[ib_mbr_idx] = mbr_type_id;
		}
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
		mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
	}
	else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
	{
		uint32_t locn = 0;
		auto builtin_itr = inputs_by_builtin.find(builtin);
		if (builtin_itr != end(inputs_by_builtin))
			locn = builtin_itr->second.location;
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
		mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
	}

	// Copy the component location, if present.
	if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
	{
		uint32_t comp = get_member_decoration(var_type.self, mbr_idx, DecorationComponent);
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
	}

	// Mark the member as builtin if needed
	if (is_builtin)
	{
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
		if (builtin == BuiltInPosition && storage == StorageClassOutput)
			qual_pos_var_name = qual_var_name;
	}

	const SPIRConstant *c = nullptr;
	if (!flatten_stage_out && var.storage == StorageClassOutput &&
	    var.initializer != ID(0) && (c = maybe_get<SPIRConstant>(var.initializer)))
	{
		if (meta.strip_array)
		{
			entry_func.fixup_hooks_in.push_back([=, &var]() {
				auto &type = this->get<SPIRType>(var.basetype);
				uint32_t index = get_extended_decoration(var.self, SPIRVCrossDecorationInterfaceMemberIndex);
				index += mbr_idx;

				AccessChainMeta chain_meta;
				auto constant_chain = access_chain_internal(var.initializer, &builtin_invocation_id_id, 1, 0, &chain_meta);

				statement(to_expression(stage_out_ptr_var_id), "[",
				          builtin_to_glsl(BuiltInInvocationId, StorageClassInput), "].",
				          to_member_name(ib_type, index), " = ",
				          constant_chain, ".", to_member_name(type, mbr_idx), ";");
			});
		}
		else
		{
			entry_func.fixup_hooks_in.push_back([=]() {
				statement(qual_var_name, " = ", constant_expression(
				          this->get<SPIRConstant>(c->subconstants[mbr_idx])), ";");
			});
		}
	}

	if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
	{
		// Copy interpolation decorations if needed
		if (is_flat)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
		if (is_noperspective)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
		if (is_centroid)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
		if (is_sample)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
	}

	set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
	set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
}

// In Metal, the tessellation levels are stored as tightly packed half-precision floating point values.
// But, stage-in attribute offsets and strides must be multiples of four, so we can't pass the levels
// individually. Therefore, we must pass them as vectors. Triangles get a single float4, with the outer
// levels in 'xyz' and the inner level in 'w'. Quads get a float4 containing the outer levels and a
// float2 containing the inner levels.
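// For triangle domains the fixups below therefore unpack, roughly:
//   gl_TessLevelOuter[0..2] = <stage-in>.gl_TessLevel.xyz (componentwise);
//   gl_TessLevelInner[0]    = <stage-in>.gl_TessLevel.w;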
void CompilerMSL::add_tess_level_input_to_interface_block(const std::string &ib_var_ref, SPIRType &ib_type,
                                                          SPIRVariable &var)
{
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	auto &var_type = get_variable_element_type(var);

	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));

	// Force the variable to have the proper name.
	set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));

	if (get_entry_point().flags.get(ExecutionModeTriangles))
	{
		// Triangles are tricky, because we want only one member in the struct.

		// We need to declare the variable early and at entry-point scope.
		entry_func.add_local_variable(var.self);
		vars_needing_early_declaration.push_back(var.self);

		string mbr_name = "gl_TessLevel";

		// If we already added the other one, we can skip this step.
		if (!added_builtin_tess_level)
		{
			// Add a reference to the variable type to the interface struct.
			uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());

			uint32_t type_id = build_extended_vector_type(var_type.self, 4);

			ib_type.member_types.push_back(type_id);

			// Give the member a name
			set_member_name(ib_type.self, ib_mbr_idx, mbr_name);

			// There is no qualified alias since we need to flatten the internal array on return.
			if (get_decoration_bitset(var.self).get(DecorationLocation))
			{
				uint32_t locn = get_decoration(var.self, DecorationLocation);
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
				mark_location_as_used_by_shader(locn, var_type, StorageClassInput);
			}
			else if (inputs_by_builtin.count(builtin))
			{
				uint32_t locn = inputs_by_builtin[builtin].location;
				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
				mark_location_as_used_by_shader(locn, var_type, StorageClassInput);
			}

			added_builtin_tess_level = true;
		}

		switch (builtin)
		{
		case BuiltInTessLevelOuter:
			entry_func.fixup_hooks_in.push_back([=, &var]() {
				statement(to_name(var.self), "[0] = ", ib_var_ref, ".", mbr_name, ".x;");
				statement(to_name(var.self), "[1] = ", ib_var_ref, ".", mbr_name, ".y;");
				statement(to_name(var.self), "[2] = ", ib_var_ref, ".", mbr_name, ".z;");
			});
			break;

		case BuiltInTessLevelInner:
			entry_func.fixup_hooks_in.push_back(
			    [=, &var]() { statement(to_name(var.self), "[0] = ", ib_var_ref, ".", mbr_name, ".w;"); });
			break;

		default:
			assert(false);
			break;
		}
	}
	else
	{
		// Add a reference to the variable type to the interface struct.
		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());

		uint32_t type_id = build_extended_vector_type(var_type.self, builtin == BuiltInTessLevelOuter ? 4 : 2);
		// Change the type of the variable, too.
		uint32_t ptr_type_id = ir.increase_bound_by(1);
		auto &new_var_type = set<SPIRType>(ptr_type_id, get<SPIRType>(type_id));
		new_var_type.pointer = true;
		new_var_type.storage = StorageClassInput;
		new_var_type.parent_type = type_id;
		var.basetype = ptr_type_id;

		ib_type.member_types.push_back(type_id);

		// Give the member a name
		string mbr_name = to_expression(var.self);
		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);

		// Since vectors can be indexed like arrays, there is no need to unpack this. We can
		// just refer to the vector directly. So give it a qualified alias.
		string qual_var_name = ib_var_ref + "." + mbr_name;
		ir.meta[var.self].decoration.qualified_alias = qual_var_name;

		if (get_decoration_bitset(var.self).get(DecorationLocation))
		{
			uint32_t locn = get_decoration(var.self, DecorationLocation);
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
		}
		else if (inputs_by_builtin.count(builtin))
		{
			uint32_t locn = inputs_by_builtin[builtin].location;
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
			mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
		}
	}
}

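// Add a variable of any shape to the stage interface block for the given
// storage class, dispatching to the struct-member, tessellation-level,
// composite, or plain paths as appropriate.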
void CompilerMSL::add_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, SPIRType &ib_type,
                                                  SPIRVariable &var, InterfaceBlockMeta &meta)
{
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	// Tessellation control I/O variables and tessellation evaluation per-point inputs are
	// usually declared as arrays. In these cases, we want to add the element type to the
	// interface block, since in Metal it's the interface block itself which is arrayed.
	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
	bool is_builtin = is_builtin_variable(var);
	auto builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));

	if (var_type.basetype == SPIRType::Struct)
	{
		if (!is_builtin_type(var_type) && (!capture_output_to_buffer || storage == StorageClassInput) &&
		    !meta.strip_array)
		{
			// For I/O blocks or structs, we will need to pass the block itself around
			// to functions if they are used globally in leaf functions.
			// Rather than passing down member by member,
			// we unflatten I/O blocks while running the shader,
			// and pass the actual struct type down to leaf functions.
			// We then unflatten inputs, and flatten outputs in the "fixup" stages.
			entry_func.add_local_variable(var.self);
			vars_needing_early_declaration.push_back(var.self);
		}

		if (capture_output_to_buffer && storage != StorageClassInput && !has_decoration(var_type.self, DecorationBlock))
		{
			// In Metal tessellation shaders, the interface block itself is arrayed. This makes things
			// very complicated, since stage-in structures in MSL don't support nested structures.
			// Luckily, for stage-out when capturing output, we can avoid this and just add
			// composite members directly, because the stage-out structure is stored to a buffer,
			// not returned.
			add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
		}
		else
		{
			// Flatten the struct members into the interface struct
			for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
			{
				builtin = BuiltInMax;
				is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
				auto &mbr_type = get<SPIRType>(var_type.member_types[mbr_idx]);

				if (!is_builtin || has_active_builtin(builtin, storage))
				{
					bool is_composite_type = is_matrix(mbr_type) || is_array(mbr_type);
					bool attribute_load_store =
					    storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
					bool storage_is_stage_io =
					    (storage == StorageClassInput && !(get_execution_model() == ExecutionModelTessellationControl &&
					                                       msl_options.multi_patch_workgroup)) ||
					    storage == StorageClassOutput;

					// ClipDistance always needs to be declared as user attributes.
					if (builtin == BuiltInClipDistance)
						is_builtin = false;

					if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
					{
						add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx,
						                                                 meta);
					}
					else
					{
						add_plain_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx, meta);
					}
				}
			}
		}
	}
	else if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput &&
	         !meta.strip_array && is_builtin && (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner))
	{
		add_tess_level_input_to_interface_block(ib_var_ref, ib_type, var);
	}
	else if (var_type.basetype == SPIRType::Boolean || var_type.basetype == SPIRType::Char ||
	         type_is_integral(var_type) || type_is_floating_point(var_type))
	{
		if (!is_builtin || has_active_builtin(builtin, storage))
		{
			bool is_composite_type = is_matrix(var_type) || is_array(var_type);
			bool storage_is_stage_io =
			    (storage == StorageClassInput &&
			     !(get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup)) ||
			    (storage == StorageClassOutput && !capture_output_to_buffer);
			bool attribute_load_store = storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;

			// ClipDistance always needs to be declared as user attributes.
			if (builtin == BuiltInClipDistance)
				is_builtin = false;

			// MSL does not allow matrices or arrays in input or output variables, so we need to handle those specially.
			if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
			{
				add_composite_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
			}
			else
			{
				add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
			}
		}
	}
}

// Fix up the mapping of variables to interface member indices, which is used to compile access chains
// for per-vertex variables in a tessellation control shader.
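// E.g. an access chain like gl_in[gl_InvocationID].gl_Position is resolved by
// looking up the interface member index recorded on the variable (or, for
// struct inputs, on each of its members).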
fix_up_interface_member_indices(StorageClass storage,uint32_t ib_type_id)2890 void CompilerMSL::fix_up_interface_member_indices(StorageClass storage, uint32_t ib_type_id)
2891 {
2892 // Only needed for tessellation shaders and pull-model interpolants.
2893 // Need to redirect interface indices back to variables themselves.
2894 // For structs, each member of the struct needs a separate instance.
2895 if (get_execution_model() != ExecutionModelTessellationControl &&
2896 !(get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput) &&
2897 !(get_execution_model() == ExecutionModelFragment && storage == StorageClassInput &&
2898 !pull_model_inputs.empty()))
2899 return;
2900
2901 auto mbr_cnt = uint32_t(ir.meta[ib_type_id].members.size());
2902 for (uint32_t i = 0; i < mbr_cnt; i++)
2903 {
2904 uint32_t var_id = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceOrigID);
2905 if (!var_id)
2906 continue;
2907 auto &var = get<SPIRVariable>(var_id);
2908
2909 auto &type = get_variable_element_type(var);
2910 if (storage == StorageClassInput && type.basetype == SPIRType::Struct)
2911 {
2912 uint32_t mbr_idx = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceMemberIndex);
2913
2914 // Only set the lowest InterfaceMemberIndex for each variable member.
2915 // IB struct members will be emitted in-order w.r.t. interface member index.
2916 if (!has_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex))
2917 set_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, i);
2918 }
2919 else
2920 {
2921 // Only set the lowest InterfaceMemberIndex for each variable.
2922 // IB struct members will be emitted in-order w.r.t. interface member index.
2923 if (!has_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex))
2924 set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i);
2925 }
2926 }
2927 }
2928
2929 // Add an interface structure for the type of storage, which is either StorageClassInput or StorageClassOutput.
2930 // Returns the ID of the newly added variable, or zero if no variable was added.
2931 uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch)
2932 {
2933 // Accumulate the variables that should appear in the interface struct.
2934 SmallVector<SPIRVariable *> vars;
2935 bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader();
2936 bool has_seen_barycentric = false;
2937
2938 InterfaceBlockMeta meta;
2939
2940 // Varying interfaces between stages which use the "user()" attribute can be dealt with
2941 // without explicit packing and unpacking of components. For any variables which link against the runtime
2942 // in some way (vertex attributes, fragment outputs, etc.), we'll need to pack and unpack components explicitly.
2943 bool pack_components =
2944 (storage == StorageClassInput && get_execution_model() == ExecutionModelVertex) ||
2945 (storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment) ||
2946 (storage == StorageClassOutput && get_execution_model() == ExecutionModelVertex && capture_output_to_buffer);
2947
2948 ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
2949 if (var.storage != storage)
2950 return;
2951
2952 auto &type = this->get<SPIRType>(var.basetype);
2953
2954 bool is_builtin = is_builtin_variable(var);
2955 auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
2956 uint32_t location = get_decoration(var_id, DecorationLocation);
2957
2958 // These builtins are part of the stage in/out structs.
2959 bool is_interface_block_builtin =
2960 (bi_type == BuiltInPosition || bi_type == BuiltInPointSize || bi_type == BuiltInClipDistance ||
2961 bi_type == BuiltInCullDistance || bi_type == BuiltInLayer || bi_type == BuiltInViewportIndex ||
2962 bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV || bi_type == BuiltInFragDepth ||
2963 bi_type == BuiltInFragStencilRefEXT || bi_type == BuiltInSampleMask) ||
2964 (get_execution_model() == ExecutionModelTessellationEvaluation &&
2965 (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner));
2966
2967 bool is_active = interface_variable_exists_in_entry_point(var.self);
2968 if (is_builtin && is_active)
2969 {
2970 // Only emit the builtin if it's active in this entry point. Interface variable list might lie.
2971 is_active = has_active_builtin(bi_type, storage);
2972 }
2973
2974 bool filter_patch_decoration = (has_decoration(var_id, DecorationPatch) || is_patch_block(type)) == patch;
2975
2976 bool hidden = is_hidden_variable(var, incl_builtins);
2977
2978 // ClipDistance is never hidden, we need to emulate it when used as an input.
2979 if (bi_type == BuiltInClipDistance)
2980 hidden = false;
2981
2982 // It's not enough to simply avoid marking fragment outputs if the pipeline won't
2983 // accept them. We can't put them in the struct at all; otherwise the compiler
2984 // complains that the outputs weren't explicitly marked.
2985 if (get_execution_model() == ExecutionModelFragment && storage == StorageClassOutput && !patch &&
2986 ((is_builtin && ((bi_type == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
2987 (bi_type == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))) ||
2988 (!is_builtin && !(msl_options.enable_frag_output_mask & (1 << location)))))
2989 {
2990 hidden = true;
2991 disabled_frag_outputs.push_back(var_id);
2992 // If a builtin, force it to have the proper name.
2993 if (is_builtin)
2994 set_name(var_id, builtin_to_glsl(bi_type, StorageClassFunction));
2995 }
2996
2997 // Barycentric inputs must be emitted in stage-in, because they can have interpolation arguments.
2998 if (is_active && (bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV))
2999 {
3000 if (has_seen_barycentric)
3001 SPIRV_CROSS_THROW("Cannot declare both BaryCoordNV and BaryCoordNoPerspNV in same shader in MSL.");
3002 has_seen_barycentric = true;
3003 hidden = false;
3004 }
3005
3006 if (is_active && !hidden && type.pointer && filter_patch_decoration &&
3007 (!is_builtin || is_interface_block_builtin))
3008 {
3009 vars.push_back(&var);
3010
3011 if (!is_builtin)
3012 {
3013 // Need to deal specially with DecorationComponent.
3014 // Multiple variables can alias the same Location, so we try to make sure each location is declared only once.
3015 // We will swizzle data in and out to make this work.
3016 // We only need to consider plain variables here, not composites.
3017 // This is only relevant for vertex inputs and fragment outputs.
3018 // Technically tessellation as well, but it is too complicated to support.
3019 uint32_t component = get_decoration(var_id, DecorationComponent);
3020 if (component != 0)
3021 {
3022 if (is_tessellation_shader())
3023 SPIRV_CROSS_THROW("Component decoration is not supported in tessellation shaders.");
3024 else if (pack_components)
3025 {
3026 auto &location_meta = meta.location_meta[location];
3027 location_meta.num_components = std::max(location_meta.num_components, component + type.vecsize);
3028 }
3029 }
3030 }
3031 }
3032 });
3033
3034 // If no variables qualify, leave.
3035 // For patch input in a tessellation evaluation shader, the per-vertex stage inputs
3036 // are included in a special patch control point array.
3037 if (vars.empty() && !(storage == StorageClassInput && patch && stage_in_var_id))
3038 return 0;
3039
3040 // Add a new typed variable for this interface structure.
3041 // The initializer expression is allocated here, but populated when the function
3042 // declaration is emitted, because it is cleared after each compilation pass.
3043 uint32_t next_id = ir.increase_bound_by(3);
3044 uint32_t ib_type_id = next_id++;
3045 auto &ib_type = set<SPIRType>(ib_type_id);
3046 ib_type.basetype = SPIRType::Struct;
3047 ib_type.storage = storage;
3048 set_decoration(ib_type_id, DecorationBlock);
3049
3050 uint32_t ib_var_id = next_id++;
3051 auto &var = set<SPIRVariable>(ib_var_id, ib_type_id, storage, 0);
3052 var.initializer = next_id++;
3053
3054 string ib_var_ref;
3055 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
3056 switch (storage)
3057 {
3058 case StorageClassInput:
3059 ib_var_ref = patch ? patch_stage_in_var_name : stage_in_var_name;
3060 if (get_execution_model() == ExecutionModelTessellationControl)
3061 {
3062 // Add a hook to populate the shared workgroup memory containing the gl_in array.
3063 entry_func.fixup_hooks_in.push_back([=]() {
3064 // Can't use PatchVertices, PrimitiveId, or InvocationId yet; the hooks for those may not have run yet.
3065 if (msl_options.multi_patch_workgroup)
3066 {
3067 // n.b. builtin_invocation_id_id here is the dispatch global invocation ID,
3068 // not the TC invocation ID.
3069 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_in = &",
3070 input_buffer_var_name, "[min(", to_expression(builtin_invocation_id_id), ".x / ",
3071 get_entry_point().output_vertices,
3072 ", spvIndirectParams[1] - 1) * spvIndirectParams[0]];");
3073 }
3074 else
3075 {
3076 // It's safe to use InvocationId here because it's directly mapped to a
3077 // Metal builtin, and therefore doesn't need a hook.
3078 statement("if (", to_expression(builtin_invocation_id_id), " < spvIndirectParams[0])");
3079 statement(" ", input_wg_var_name, "[", to_expression(builtin_invocation_id_id),
3080 "] = ", ib_var_ref, ";");
3081 statement("threadgroup_barrier(mem_flags::mem_threadgroup);");
3082 statement("if (", to_expression(builtin_invocation_id_id),
3083 " >= ", get_entry_point().output_vertices, ")");
3084 statement(" return;");
3085 }
3086 });
3087 }
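// As an illustrative sketch (not emitted verbatim; names follow the usual SPIRV-Cross
// defaults such as "main0" for the entry point, "spvIn" for the input buffer, and an
// output vertex count of 4), the multi-patch hook above expands to roughly:
//
//   device main0_in* gl_in = &spvIn[min(gl_GlobalInvocationID.x / 4,
//                                       spvIndirectParams[1] - 1) * spvIndirectParams[0]];
//
// The threadgroup variant instead copies each invocation's stage-in struct into shared
// memory and, after a threadgroup barrier, retires invocations beyond the output vertex count.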
3088 break;
3089
3090 case StorageClassOutput:
3091 {
3092 ib_var_ref = patch ? patch_stage_out_var_name : stage_out_var_name;
3093
3094 // Add the output interface struct as a local variable to the entry function.
3095 // If the entry point should return the output struct, set the entry function
3096 // to return the output interface struct, otherwise to return nothing.
3097 // Indicate the output var requires early initialization.
3098 bool ep_should_return_output = !get_is_rasterization_disabled();
3099 uint32_t rtn_id = ep_should_return_output ? ib_var_id : 0;
3100 if (!capture_output_to_buffer)
3101 {
3102 entry_func.add_local_variable(ib_var_id);
3103 for (auto &blk_id : entry_func.blocks)
3104 {
3105 auto &blk = get<SPIRBlock>(blk_id);
3106 if (blk.terminator == SPIRBlock::Return)
3107 blk.return_value = rtn_id;
3108 }
3109 vars_needing_early_declaration.push_back(ib_var_id);
3110 }
3111 else
3112 {
3113 switch (get_execution_model())
3114 {
3115 case ExecutionModelVertex:
3116 case ExecutionModelTessellationEvaluation:
3117 // Instead of declaring a struct variable to hold the output and then
3118 // copying that to the output buffer, we'll declare the output variable
3119 // as a reference to the final output element in the buffer. Then we can
3120 // avoid the extra copy.
3121 entry_func.fixup_hooks_in.push_back([=]() {
3122 if (stage_out_var_id)
3123 {
3124 // The first member of the indirect buffer is always the number of vertices
3125 // to draw.
3126 // We zero-base the InstanceID & VertexID variables for HLSL emulation elsewhere, so don't do it twice
3127 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
3128 {
3129 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3130 " = ", output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
3131 ".y * ", to_expression(builtin_stage_input_size_id), ".x + ",
3132 to_expression(builtin_invocation_id_id), ".x];");
3133 }
3134 else if (msl_options.enable_base_index_zero)
3135 {
3136 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3137 " = ", output_buffer_var_name, "[", to_expression(builtin_instance_idx_id),
3138 " * spvIndirectParams[0] + ", to_expression(builtin_vertex_idx_id), "];");
3139 }
3140 else
3141 {
3142 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3143 " = ", output_buffer_var_name, "[(", to_expression(builtin_instance_idx_id),
3144 " - ", to_expression(builtin_base_instance_id), ") * spvIndirectParams[0] + ",
3145 to_expression(builtin_vertex_idx_id), " - ",
3146 to_expression(builtin_base_vertex_id), "];");
3147 }
3148 }
3149 });
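// Illustrative sketch of the plain vertex-capture case above (assuming the default
// "spvOut" buffer name and HLSL-style zero-basing disabled):
//
//   device main0_out& out = spvOut[(gl_InstanceIndex - gl_BaseInstance) * spvIndirectParams[0] +
//                                  gl_VertexIndex - gl_BaseVertex];
//
// i.e. one output element per (instance, vertex) pair, with spvIndirectParams[0]
// holding the number of vertices to draw.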
3150 break;
3151 case ExecutionModelTessellationControl:
3152 if (msl_options.multi_patch_workgroup)
3153 {
3154 // We cannot use PrimitiveId here, because the hook may not have run yet.
3155 if (patch)
3156 {
3157 entry_func.fixup_hooks_in.push_back([=]() {
3158 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3159 " = ", patch_output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
3160 ".x / ", get_entry_point().output_vertices, "];");
3161 });
3162 }
3163 else
3164 {
3165 entry_func.fixup_hooks_in.push_back([=]() {
3166 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
3167 output_buffer_var_name, "[", to_expression(builtin_invocation_id_id), ".x - ",
3168 to_expression(builtin_invocation_id_id), ".x % ",
3169 get_entry_point().output_vertices, "];");
3170 });
3171 }
3172 }
3173 else
3174 {
3175 if (patch)
3176 {
3177 entry_func.fixup_hooks_in.push_back([=]() {
3178 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3179 " = ", patch_output_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
3180 "];");
3181 });
3182 }
3183 else
3184 {
3185 entry_func.fixup_hooks_in.push_back([=]() {
3186 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
3187 output_buffer_var_name, "[", to_expression(builtin_primitive_id_id), " * ",
3188 get_entry_point().output_vertices, "];");
3189 });
3190 }
3191 }
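// Illustrative sketch of the single-patch-per-workgroup branch above (assuming an
// output vertex count of 4 and the default "spvOut" buffer name):
//
//   device main0_out* gl_out = &spvOut[gl_PrimitiveID * 4];
//
// so each patch writes its control points to a contiguous slice of the buffer.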
3192 break;
3193 default:
3194 break;
3195 }
3196 }
3197 break;
3198 }
3199
3200 default:
3201 break;
3202 }
3203
3204 set_name(ib_type_id, to_name(ir.default_entry_point) + "_" + ib_var_ref);
3205 set_name(ib_var_id, ib_var_ref);
3206
3207 for (auto *p_var : vars)
3208 {
3209 bool strip_array =
3210 (get_execution_model() == ExecutionModelTessellationControl ||
3211 (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput)) &&
3212 !patch;
3213
3214 meta.strip_array = strip_array;
3215 add_variable_to_interface_block(storage, ib_var_ref, ib_type, *p_var, meta);
3216 }
3217
3218 if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup &&
3219 storage == StorageClassInput)
3220 {
3221 // For tessellation control inputs, add all outputs from the vertex shader to ensure
3222 // the struct containing them is the correct size and layout.
3223 for (auto &input : inputs_by_location)
3224 {
3225 if (is_msl_shader_input_used(input.first))
3226 continue;
3227
3228 // Create a fake variable to put at the location.
3229 uint32_t offset = ir.increase_bound_by(4);
3230 uint32_t type_id = offset;
3231 uint32_t array_type_id = offset + 1;
3232 uint32_t ptr_type_id = offset + 2;
3233 uint32_t var_id = offset + 3;
3234
3235 SPIRType type;
3236 switch (input.second.format)
3237 {
3238 case MSL_SHADER_INPUT_FORMAT_UINT16:
3239 case MSL_SHADER_INPUT_FORMAT_ANY16:
3240 type.basetype = SPIRType::UShort;
3241 type.width = 16;
3242 break;
3243 case MSL_SHADER_INPUT_FORMAT_ANY32:
3244 default:
3245 type.basetype = SPIRType::UInt;
3246 type.width = 32;
3247 break;
3248 }
3249 type.vecsize = input.second.vecsize;
3250 set<SPIRType>(type_id, type);
3251
3252 type.array.push_back(0);
3253 type.array_size_literal.push_back(true);
3254 type.parent_type = type_id;
3255 set<SPIRType>(array_type_id, type);
3256
3257 type.pointer = true;
3258 type.parent_type = array_type_id;
3259 type.storage = storage;
3260 auto &ptr_type = set<SPIRType>(ptr_type_id, type);
3261 ptr_type.self = array_type_id;
3262
3263 auto &fake_var = set<SPIRVariable>(var_id, ptr_type_id, storage);
3264 set_decoration(var_id, DecorationLocation, input.first);
3265 meta.strip_array = true;
3266 add_variable_to_interface_block(storage, ib_var_ref, ib_type, fake_var, meta);
3267 }
3268 }
3269
3270 // Sort the members of the structure by their locations.
3271 MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Location);
3272 member_sorter.sort();
3273
3274 // The member indices were saved to the original variables, but after the members
3275 // were sorted, those indices are now likely incorrect. Fix those up now.
3276 if (!patch)
3277 fix_up_interface_member_indices(storage, ib_type_id);
3278
3279 // For patch inputs, add one more member, holding the array of control point data.
3280 if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput && patch &&
3281 stage_in_var_id)
3282 {
3283 uint32_t pcp_type_id = ir.increase_bound_by(1);
3284 auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
3285 pcp_type.basetype = SPIRType::ControlPointArray;
3286 pcp_type.parent_type = pcp_type.type_alias = get_stage_in_struct_type().self;
3287 pcp_type.storage = storage;
3288 ir.meta[pcp_type_id] = ir.meta[ib_type.self];
3289 uint32_t mbr_idx = uint32_t(ib_type.member_types.size());
3290 ib_type.member_types.push_back(pcp_type_id);
3291 set_member_name(ib_type.self, mbr_idx, "gl_in");
3292 }
3293
3294 return ib_var_id;
3295 }
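// For reference, a typical interface block built by this function for a vertex
// shader's output looks roughly like this in the emitted MSL (member names and
// locations are illustrative):
//
//   struct main0_out
//   {
//       float4 m_color [[user(locn0)]];
//       float4 gl_Position [[position]];
//   };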
3296
3297 uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageClass storage)
3298 {
3299 if (!ib_var_id)
3300 return 0;
3301
3302 uint32_t ib_ptr_var_id;
3303 uint32_t next_id = ir.increase_bound_by(3);
3304 auto &ib_type = expression_type(ib_var_id);
3305 if (get_execution_model() == ExecutionModelTessellationControl)
3306 {
3307 // Tessellation control per-vertex I/O is presented as an array, so we must
3308 // do the same with our struct here.
3309 uint32_t ib_ptr_type_id = next_id++;
3310 auto &ib_ptr_type = set<SPIRType>(ib_ptr_type_id, ib_type);
3311 ib_ptr_type.parent_type = ib_ptr_type.type_alias = ib_type.self;
3312 ib_ptr_type.pointer = true;
3313 ib_ptr_type.storage =
3314 storage == StorageClassInput ?
3315 (msl_options.multi_patch_workgroup ? StorageClassStorageBuffer : StorageClassWorkgroup) :
3316 StorageClassStorageBuffer;
3317 ir.meta[ib_ptr_type_id] = ir.meta[ib_type.self];
3318 // To ensure that get_variable_data_type() doesn't strip off the pointer,
3319 // which we need, use another pointer.
3320 uint32_t ib_ptr_ptr_type_id = next_id++;
3321 auto &ib_ptr_ptr_type = set<SPIRType>(ib_ptr_ptr_type_id, ib_ptr_type);
3322 ib_ptr_ptr_type.parent_type = ib_ptr_type_id;
3323 ib_ptr_ptr_type.type_alias = ib_type.self;
3324 ib_ptr_ptr_type.storage = StorageClassFunction;
3325 ir.meta[ib_ptr_ptr_type_id] = ir.meta[ib_type.self];
3326
3327 ib_ptr_var_id = next_id;
3328 set<SPIRVariable>(ib_ptr_var_id, ib_ptr_ptr_type_id, StorageClassFunction, 0);
3329 set_name(ib_ptr_var_id, storage == StorageClassInput ? "gl_in" : "gl_out");
3330 }
3331 else
3332 {
3333 // Tessellation evaluation per-vertex inputs are also presented as arrays.
3334 // But, in Metal, this array uses a very special type, 'patch_control_point<T>',
3335 // which is a container that can be used to access the control point data.
3336 // To represent this, a special 'ControlPointArray' type has been added to the
3337 // SPIRV-Cross type system. It should only be generated by and seen in the MSL
3338 // backend (i.e. this one).
3339 uint32_t pcp_type_id = next_id++;
3340 auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
3341 pcp_type.basetype = SPIRType::ControlPointArray;
3342 pcp_type.parent_type = pcp_type.type_alias = ib_type.self;
3343 pcp_type.storage = storage;
3344 ir.meta[pcp_type_id] = ir.meta[ib_type.self];
3345
3346 ib_ptr_var_id = next_id;
3347 set<SPIRVariable>(ib_ptr_var_id, pcp_type_id, storage, 0);
3348 set_name(ib_ptr_var_id, "gl_in");
3349 ir.meta[ib_ptr_var_id].decoration.qualified_alias = join(patch_stage_in_var_name, ".gl_in");
3350 }
3351 return ib_ptr_var_id;
3352 }
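// Illustrative sketch of what the tessellation evaluation path above yields in the
// emitted MSL (assuming the default patch input name "patchIn"):
//
//   struct main0_patchIn
//   {
//       ...
//       patch_control_point<main0_in> gl_in;
//   };
//
// Per-vertex inputs are then read as patchIn.gl_in[i].some_member, via the
// qualified alias set above.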
3353
3354 // Ensure that the type is compatible with the builtin.
3355 // If it is, simply return the given type ID.
3356 // Otherwise, create a new type, and return its ID.
3357 uint32_t CompilerMSL::ensure_correct_builtin_type(uint32_t type_id, BuiltIn builtin)
3358 {
3359 auto &type = get<SPIRType>(type_id);
3360
3361 if ((builtin == BuiltInSampleMask && is_array(type)) ||
3362 ((builtin == BuiltInLayer || builtin == BuiltInViewportIndex || builtin == BuiltInFragStencilRefEXT) &&
3363 type.basetype != SPIRType::UInt))
3364 {
3365 uint32_t next_id = ir.increase_bound_by(type.pointer ? 2 : 1);
3366 uint32_t base_type_id = next_id++;
3367 auto &base_type = set<SPIRType>(base_type_id);
3368 base_type.basetype = SPIRType::UInt;
3369 base_type.width = 32;
3370
3371 if (!type.pointer)
3372 return base_type_id;
3373
3374 uint32_t ptr_type_id = next_id++;
3375 auto &ptr_type = set<SPIRType>(ptr_type_id);
3376 ptr_type = base_type;
3377 ptr_type.pointer = true;
3378 ptr_type.storage = type.storage;
3379 ptr_type.parent_type = base_type_id;
3380 return ptr_type_id;
3381 }
3382
3383 return type_id;
3384 }
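// For example, if the SPIR-V declares gl_Layer as a signed int, the variable is
// redeclared via the uint type created above, so the emitted MSL becomes roughly
// (a sketch; the attribute is shown for context only):
//
//   uint gl_Layer [[render_target_array_index]];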
3385
3386 // Ensure that the type is compatible with the shader input.
3387 // If it is, simply return the given type ID.
3388 // Otherwise, create a new type, and return its ID.
3389 uint32_t CompilerMSL::ensure_correct_input_type(uint32_t type_id, uint32_t location, uint32_t num_components)
3390 {
3391 auto &type = get<SPIRType>(type_id);
3392
3393 auto p_va = inputs_by_location.find(location);
3394 if (p_va == end(inputs_by_location))
3395 {
3396 if (num_components > type.vecsize)
3397 return build_extended_vector_type(type_id, num_components);
3398 else
3399 return type_id;
3400 }
3401
3402 if (num_components == 0)
3403 num_components = p_va->second.vecsize;
3404
3405 switch (p_va->second.format)
3406 {
3407 case MSL_SHADER_INPUT_FORMAT_UINT8:
3408 {
3409 switch (type.basetype)
3410 {
3411 case SPIRType::UByte:
3412 case SPIRType::UShort:
3413 case SPIRType::UInt:
3414 if (num_components > type.vecsize)
3415 return build_extended_vector_type(type_id, num_components);
3416 else
3417 return type_id;
3418
3419 case SPIRType::Short:
3420 return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3421 SPIRType::UShort);
3422 case SPIRType::Int:
3423 return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3424 SPIRType::UInt);
3425
3426 default:
3427 SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
3428 }
3429 }
3430
3431 case MSL_SHADER_INPUT_FORMAT_UINT16:
3432 {
3433 switch (type.basetype)
3434 {
3435 case SPIRType::UShort:
3436 case SPIRType::UInt:
3437 if (num_components > type.vecsize)
3438 return build_extended_vector_type(type_id, num_components);
3439 else
3440 return type_id;
3441
3442 case SPIRType::Int:
3443 return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3444 SPIRType::UInt);
3445
3446 default:
3447 SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
3448 }
3449 }
3450
3451 default:
3452 if (num_components > type.vecsize)
3453 type_id = build_extended_vector_type(type_id, num_components);
3454 break;
3455 }
3456
3457 return type_id;
3458 }
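// Sketch of the remapping above: if the host feeds a uchar4 attribute
// (MSL_SHADER_INPUT_FORMAT_UINT8) into a shader-declared int4 input, the stage-in
// member is widened to the unsigned type so Metal's attribute conversion rules apply:
//
//   uint4 a_color [[attribute(2)]];   // hypothetical name and location
//
// The shader-visible variable keeps its original signed type; the signedness is
// reconciled when the interface value is copied into it.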
3459
3460 void CompilerMSL::mark_struct_members_packed(const SPIRType &type)
3461 {
3462 set_extended_decoration(type.self, SPIRVCrossDecorationPhysicalTypePacked);
3463
3464 // Problem case! Struct needs to be placed at an awkward alignment.
3465 // Mark every member of the child struct as packed.
3466 uint32_t mbr_cnt = uint32_t(type.member_types.size());
3467 for (uint32_t i = 0; i < mbr_cnt; i++)
3468 {
3469 auto &mbr_type = get<SPIRType>(type.member_types[i]);
3470 if (mbr_type.basetype == SPIRType::Struct)
3471 {
3472 // Recursively mark structs as packed.
3473 auto *struct_type = &mbr_type;
3474 while (!struct_type->array.empty())
3475 struct_type = &get<SPIRType>(struct_type->parent_type);
3476 mark_struct_members_packed(*struct_type);
3477 }
3478 else if (!is_scalar(mbr_type))
3479 set_extended_member_decoration(type.self, i, SPIRVCrossDecorationPhysicalTypePacked);
3480 }
3481 }
3482
3483 void CompilerMSL::mark_scalar_layout_structs(const SPIRType &type)
3484 {
3485 uint32_t mbr_cnt = uint32_t(type.member_types.size());
3486 for (uint32_t i = 0; i < mbr_cnt; i++)
3487 {
3488 auto &mbr_type = get<SPIRType>(type.member_types[i]);
3489 if (mbr_type.basetype == SPIRType::Struct)
3490 {
3491 auto *struct_type = &mbr_type;
3492 while (!struct_type->array.empty())
3493 struct_type = &get<SPIRType>(struct_type->parent_type);
3494
3495 if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPhysicalTypePacked))
3496 continue;
3497
3498 uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, i);
3499 uint32_t msl_size = get_declared_struct_member_size_msl(type, i);
3500 uint32_t spirv_offset = type_struct_member_offset(type, i);
3501 uint32_t spirv_offset_next;
3502 if (i + 1 < mbr_cnt)
3503 spirv_offset_next = type_struct_member_offset(type, i + 1);
3504 else
3505 spirv_offset_next = spirv_offset + msl_size;
3506
3507 // Both are complicated cases. In scalar layout, a struct of float3 might just consume 12 bytes,
3508 // and the next member will be placed at offset 12.
3509 bool struct_is_misaligned = (spirv_offset % msl_alignment) != 0;
3510 bool struct_is_too_large = spirv_offset + msl_size > spirv_offset_next;
3511 uint32_t array_stride = 0;
3512 bool struct_needs_explicit_padding = false;
3513
3514 // Verify that if a struct is used as an array that ArrayStride matches the effective size of the struct.
3515 if (!mbr_type.array.empty())
3516 {
3517 array_stride = type_struct_member_array_stride(type, i);
3518 uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
3519 for (uint32_t dim = 0; dim < dimensions; dim++)
3520 {
3521 uint32_t array_size = to_array_size_literal(mbr_type, dim);
3522 array_stride /= max(array_size, 1u);
3523 }
3524
3525 // Set expected struct size based on ArrayStride.
3526 struct_needs_explicit_padding = true;
3527
3528 // If struct size is larger than array stride, we might be able to fit, if we tightly pack.
3529 if (get_declared_struct_size_msl(*struct_type) > array_stride)
3530 struct_is_too_large = true;
3531 }
3532
3533 if (struct_is_misaligned || struct_is_too_large)
3534 mark_struct_members_packed(*struct_type);
3535 mark_scalar_layout_structs(*struct_type);
3536
3537 if (struct_needs_explicit_padding)
3538 {
3539 msl_size = get_declared_struct_size_msl(*struct_type, true, true);
3540 if (array_stride < msl_size)
3541 {
3542 SPIRV_CROSS_THROW("Cannot express an array stride smaller than size of struct type.");
3543 }
3544 else
3545 {
3546 if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
3547 {
3548 if (array_stride !=
3549 get_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
3550 SPIRV_CROSS_THROW(
3551 "A struct is used with different array strides. Cannot express this in MSL.");
3552 }
3553 else
3554 set_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget, array_stride);
3555 }
3556 }
3557 }
3558 }
3559 }
3560
3561 // Sort the members of the struct type by offset, and pack and then pad members where needed
3562 // to align MSL members with SPIR-V offsets. The struct members are iterated twice. Packing
3563 // occurs first, followed by padding, because packing a member reduces both its size and its
3564 // natural alignment, possibly requiring a padding member to be added ahead of it.
3565 void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set<uint32_t> &aligned_structs)
3566 {
3567 // We align structs recursively, so stop any redundant work.
3568 ID &ib_type_id = ib_type.self;
3569 if (aligned_structs.count(ib_type_id))
3570 return;
3571 aligned_structs.insert(ib_type_id);
3572
3573 // Sort the members of the interface structure by their offset.
3574 // They should already be sorted per SPIR-V spec anyway.
3575 MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Offset);
3576 member_sorter.sort();
3577
3578 auto mbr_cnt = uint32_t(ib_type.member_types.size());
3579
3580 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
3581 {
3582 // Pack any dependent struct types before we pack a parent struct.
3583 auto &mbr_type = get<SPIRType>(ib_type.member_types[mbr_idx]);
3584 if (mbr_type.basetype == SPIRType::Struct)
3585 align_struct(mbr_type, aligned_structs);
3586 }
3587
3588 // Test the alignment of each member, and if a member should be closer to the previous
3589 // member than the default spacing expects, it is likely that the previous member is in
3590 // a packed format. If so, and the previous member is packable, pack it.
3591 // For example ... this applies to any 3-element vector that is followed by a scalar.
3592 uint32_t msl_offset = 0;
3593 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
3594 {
3595 // This checks the member in isolation, if the member needs some kind of type remapping to conform to SPIR-V
3596 // offsets, array strides and matrix strides.
3597 ensure_member_packing_rules_msl(ib_type, mbr_idx);
3598
3599 // Align current offset to the current member's default alignment. If the member was packed, it will observe
3600 // the updated alignment here.
3601 uint32_t msl_align_mask = get_declared_struct_member_alignment_msl(ib_type, mbr_idx) - 1;
3602 uint32_t aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
3603
3604 // Fetch the member offset as declared in the SPIRV.
3605 uint32_t spirv_mbr_offset = get_member_decoration(ib_type_id, mbr_idx, DecorationOffset);
3606 if (spirv_mbr_offset > aligned_msl_offset)
3607 {
3608 // Since MSL and SPIR-V have slightly different struct member alignment and
3609 // size rules, we'll pad to standard C-packing rules with a char[] array. If the member is farther
3610 // away than C-packing expects, add an inert padding member before the member.
3611 uint32_t padding_bytes = spirv_mbr_offset - aligned_msl_offset;
3612 set_extended_member_decoration(ib_type_id, mbr_idx, SPIRVCrossDecorationPaddingTarget, padding_bytes);
3613
3614 // Re-align as a sanity check that aligning post-padding matches up.
3615 msl_offset += padding_bytes;
3616 aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
3617 }
3618 else if (spirv_mbr_offset < aligned_msl_offset)
3619 {
3620 // This should not happen, but deal with unexpected scenarios.
3621 // It *might* happen if a sub-struct has a larger alignment requirement in MSL than SPIR-V.
3622 SPIRV_CROSS_THROW("Cannot represent buffer block correctly in MSL.");
3623 }
3624
3625 assert(aligned_msl_offset == spirv_mbr_offset);
3626
3627 // Increment the current offset to be positioned immediately after the current member.
3628 // Don't do this for the last member since it can be unsized, and it is not relevant for padding purposes here.
3629 if (mbr_idx + 1 < mbr_cnt)
3630 msl_offset = aligned_msl_offset + get_declared_struct_member_size_msl(ib_type, mbr_idx);
3631 }
3632 }
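// A minimal sketch of the padding mechanism above: given SPIR-V offsets 0 and 8 for
// two floats, MSL's natural C-packing would place the second float at offset 4, so a
// padding member is emitted ahead of it (the "_m1_pad" name is illustrative):
//
//   struct buf
//   {
//       float a;
//       char _m1_pad[4];
//       float b;   // lands at offset 8, matching the SPIR-V
//   };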
3633
3634 bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const
3635 {
3636 auto &mbr_type = get<SPIRType>(type.member_types[index]);
3637 uint32_t spirv_offset = get_member_decoration(type.self, index, DecorationOffset);
3638
3639 if (index + 1 < type.member_types.size())
3640 {
3641 // First, we will check offsets. If SPIR-V offset + MSL size > SPIR-V offset of next member,
3642 // we *must* perform some kind of remapping, no way getting around it.
3643 // We can always pad after this member if necessary, so that case is fine.
3644 uint32_t spirv_offset_next = get_member_decoration(type.self, index + 1, DecorationOffset);
3645 assert(spirv_offset_next >= spirv_offset);
3646 uint32_t maximum_size = spirv_offset_next - spirv_offset;
3647 uint32_t msl_mbr_size = get_declared_struct_member_size_msl(type, index);
3648 if (msl_mbr_size > maximum_size)
3649 return false;
3650 }
3651
3652 if (!mbr_type.array.empty())
3653 {
3654 // If we have an array type, array stride must match exactly with SPIR-V.
3655
3656 // An exception to this requirement is if we have one array element.
3657 // This comes from the DX scalar layout workaround.
3658 // If the app tries to be cheeky and access the member out of bounds, this will not work, but this is the best we can do.
3659 // Per the SPIR-V specification, access chains (OpAccessChain) must be in-bounds under the logical addressing model anyway.
3660 bool relax_array_stride = mbr_type.array.back() == 1 && mbr_type.array_size_literal.back();
3661
3662 if (!relax_array_stride)
3663 {
3664 uint32_t spirv_array_stride = type_struct_member_array_stride(type, index);
3665 uint32_t msl_array_stride = get_declared_struct_member_array_stride_msl(type, index);
3666 if (spirv_array_stride != msl_array_stride)
3667 return false;
3668 }
3669 }
3670
3671 if (is_matrix(mbr_type))
3672 {
3673 // Need to check MatrixStride as well.
3674 uint32_t spirv_matrix_stride = type_struct_member_matrix_stride(type, index);
3675 uint32_t msl_matrix_stride = get_declared_struct_member_matrix_stride_msl(type, index);
3676 if (spirv_matrix_stride != msl_matrix_stride)
3677 return false;
3678 }
3679
3680 // Now, we check alignment.
3681 uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, index);
3682 if ((spirv_offset % msl_alignment) != 0)
3683 return false;
3684
3685 // We're in the clear.
3686 return true;
3687 }
3688
3689 // Here we need to verify that the member type we declare conforms to Offset, ArrayStride or MatrixStride restrictions.
3690 // If there is a mismatch, we need to emit remapped types, either normal types, or "packed_X" types.
3691 // In odd cases we need to emit packed and remapped types, for e.g. weird matrices or arrays with weird array strides.
3692 void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t index)
3693 {
3694 if (validate_member_packing_rules_msl(ib_type, index))
3695 return;
3696
3697 // We failed validation.
3698 // This case will be nightmare-ish to deal with. This could possibly happen if struct alignment does not quite
3699 // match up with what we want. Scalar block layout comes to mind here where we might have to work around the rule
3700 // that struct alignment == max alignment of all members and struct size depends on this alignment.
3701 auto &mbr_type = get<SPIRType>(ib_type.member_types[index]);
3702 if (mbr_type.basetype == SPIRType::Struct)
3703 SPIRV_CROSS_THROW("Cannot perform any repacking for structs when it is used as a member of another struct.");
3704
3705 // Perform remapping here.
3706 // There is nothing to be gained by using packed scalars, so don't attempt it.
3707 if (!is_scalar(ib_type))
3708 set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
3709
3710 // Try validating again, now with packed.
3711 if (validate_member_packing_rules_msl(ib_type, index))
3712 return;
3713
3714 // We're in deep trouble, and we need to create a new PhysicalType which matches up with what we expect.
3715 // A lot of work goes here ...
3716 // We will need remapping on Load and Store to translate the types between Logical and Physical.
3717
3718 // First, we check if we have a small-vector std140 array.
3719 // We detect this if we have an array of vectors, and array stride is greater than number of elements.
3720 if (!mbr_type.array.empty() && !is_matrix(mbr_type))
3721 {
3722 uint32_t array_stride = type_struct_member_array_stride(ib_type, index);
3723
3724 // Peel off outer array dimensions until we find the per-element array stride we need to make this work.
3725 uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
3726 for (uint32_t dim = 0; dim < dimensions; dim++)
3727 array_stride /= max(to_array_size_literal(mbr_type, dim), 1u);
3728
3729 uint32_t elems_per_stride = array_stride / (mbr_type.width / 8);
3730
3731 if (elems_per_stride == 3)
3732 SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
3733 else if (elems_per_stride > 4)
3734 SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
3735
3736 auto physical_type = mbr_type;
3737 physical_type.vecsize = elems_per_stride;
3738 physical_type.parent_type = 0;
3739 uint32_t type_id = ir.increase_bound_by(1);
3740 set<SPIRType>(type_id, physical_type);
3741 set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);
3742 set_decoration(type_id, DecorationArrayStride, array_stride);
3743
3744 // Remove packed_ for vectors of size 1, 2 and 4.
3745 if (has_extended_decoration(ib_type.self, SPIRVCrossDecorationPhysicalTypePacked))
3746 SPIRV_CROSS_THROW("Unable to remove packed decoration as entire struct must be fully packed. Do not mix "
3747 "scalar and std140 layout rules.");
3748 else
3749 unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
3750 }
3751 else if (is_matrix(mbr_type))
3752 {
3753 // MatrixStride might be std140-esque.
3754 uint32_t matrix_stride = type_struct_member_matrix_stride(ib_type, index);
3755
3756 uint32_t elems_per_stride = matrix_stride / (mbr_type.width / 8);
3757
3758 if (elems_per_stride == 3)
3759 SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
3760 else if (elems_per_stride > 4)
3761 SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
3762
3763 bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
3764
3765 auto physical_type = mbr_type;
3766 physical_type.parent_type = 0;
3767 if (row_major)
3768 physical_type.columns = elems_per_stride;
3769 else
3770 physical_type.vecsize = elems_per_stride;
3771 uint32_t type_id = ir.increase_bound_by(1);
3772 set<SPIRType>(type_id, physical_type);
3773 set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);
3774
3775 // Remove packed_ for vectors of size 1, 2 and 4.
3776 if (has_extended_decoration(ib_type.self, SPIRVCrossDecorationPhysicalTypePacked))
3777 SPIRV_CROSS_THROW("Unable to remove packed decoration as entire struct must be fully packed. Do not mix "
3778 "scalar and std140 layout rules.");
3779 else
3780 unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
3781 }
3782 else
3783 SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
3784
3785 // Try validating again, now with physical type remapping.
3786 if (validate_member_packing_rules_msl(ib_type, index))
3787 return;
3788
3789 // We might have a particular odd scalar layout case where the last element of an array
3790 // does not take up as much space as the ArrayStride or MatrixStride. This can happen with DX cbuffers.
3791 // The "proper" workaround for this is extremely painful and essentially impossible in the edge case of float3[],
3792 // so we hack around it by declaring the offending array or matrix with one less array size/col/row,
3793 // and rely on padding to get the correct value. We will technically access arrays out of bounds into the padding region,
3794 // but it should spill over gracefully without too much trouble. We rely on behavior like this for unsized arrays anyway.
3795
3796 // E.g. we might observe a physical layout of:
3797 // { float2 a[2]; float b; } in cbuffer layout where ArrayStride of a is 16, but offset of b is 24, packed right after a[1] ...
3798 uint32_t type_id = get_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
3799 auto &type = get<SPIRType>(type_id);
3800
3801 // Modify the physical type in-place. This is safe since each physical type workaround is a copy.
3802 if (is_array(type))
3803 {
3804 if (type.array.back() > 1)
3805 {
3806 if (!type.array_size_literal.back())
3807 SPIRV_CROSS_THROW("Cannot apply scalar layout workaround with spec constant array size.");
3808 type.array.back() -= 1;
3809 }
3810 else
3811 {
3812 // We have an array of size 1, so we cannot decrement that. Our only option now is to
3813 // force a packed layout instead, and drop the physical type remap since ArrayStride is meaningless now.
3814 unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
3815 set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
3816 }
3817 }
3818 else if (is_matrix(type))
3819 {
3820 bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
3821 if (!row_major)
3822 {
3823 // Slice off one column. If we only have 2 columns, this might turn the matrix into a vector with one array element instead.
3824 if (type.columns > 2)
3825 {
3826 type.columns--;
3827 }
3828 else if (type.columns == 2)
3829 {
3830 type.columns = 1;
3831 assert(type.array.empty());
3832 type.array.push_back(1);
3833 type.array_size_literal.push_back(true);
3834 }
3835 }
3836 else
3837 {
3838 // Slice off one row. If we only have 2 rows, this might turn the matrix into a vector with one array element instead.
3839 if (type.vecsize > 2)
3840 {
3841 type.vecsize--;
3842 }
3843 else if (type.vecsize == 2)
3844 {
3845 type.vecsize = type.columns;
3846 type.columns = 1;
3847 assert(type.array.empty());
3848 type.array.push_back(1);
3849 type.array_size_literal.push_back(true);
3850 }
3851 }
3852 }
3853
3854 // This better validate now, or we must fail gracefully.
3855 if (!validate_member_packing_rules_msl(ib_type, index))
3856 SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
3857 }
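// Sketch of the workaround applied to the cbuffer example above: 'a' is first
// remapped to a float4[2] physical type (stride 16), which would occupy 32 bytes and
// collide with 'b' at offset 24, so the last array element is sliced off:
//
//   struct buf
//   {
//       float4 a[1];       // a[1] deliberately reads into the padding below
//       char _m1_pad[8];   // padding member name is illustrative
//       float b;           // offset 24, as the SPIR-V requires
//   };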
3858
3859 void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression)
3860 {
3861 auto &type = expression_type(rhs_expression);
3862
3863 bool lhs_remapped_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID);
3864 bool lhs_packed_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypePacked);
3865 auto *lhs_e = maybe_get<SPIRExpression>(lhs_expression);
3866 auto *rhs_e = maybe_get<SPIRExpression>(rhs_expression);
3867
3868 bool transpose = lhs_e && lhs_e->need_transpose;
3869
3870 // No physical type remapping, and no packed type, so we can just emit a store directly.
3871 if (!lhs_remapped_type && !lhs_packed_type)
3872 {
3873 // We might not be dealing with remapped physical types or packed types,
3874 // but we might be doing a clean store to a row-major matrix.
3875 // In this case, we just flip the transpose state and emit the store; the transpose, if any, must live in the RHS expression.
3876 if (is_matrix(type) && lhs_e && lhs_e->need_transpose)
3877 {
3878 if (!rhs_e)
3879 SPIRV_CROSS_THROW("Need to transpose right-side expression of a store to row-major matrix, but it is "
3880 "not a SPIRExpression.");
3881 lhs_e->need_transpose = false;
3882
3883 if (rhs_e && rhs_e->need_transpose)
3884 {
3885 // Direct copy, but might need to unpack RHS.
3886 // Skip the transpose, as we will transpose when writing to LHS and transpose(transpose(T)) == T.
3887 rhs_e->need_transpose = false;
3888 statement(to_expression(lhs_expression), " = ", to_unpacked_row_major_matrix_expression(rhs_expression),
3889 ";");
3890 rhs_e->need_transpose = true;
3891 }
3892 else
3893 statement(to_expression(lhs_expression), " = transpose(", to_unpacked_expression(rhs_expression), ");");
3894
3895 lhs_e->need_transpose = true;
3896 register_write(lhs_expression);
3897 }
3898 else if (lhs_e && lhs_e->need_transpose)
3899 {
3900 lhs_e->need_transpose = false;
3901
3902 // Storing a column to a row-major matrix. Unroll the write.
3903 for (uint32_t c = 0; c < type.vecsize; c++)
3904 {
3905 auto lhs_expr = to_dereferenced_expression(lhs_expression);
3906 auto column_index = lhs_expr.find_last_of('[');
3907 if (column_index != string::npos)
3908 {
3909 statement(lhs_expr.insert(column_index, join('[', c, ']')), " = ",
3910 to_extract_component_expression(rhs_expression, c), ";");
3911 }
3912 }
3913 lhs_e->need_transpose = true;
3914 register_write(lhs_expression);
3915 }
3916 else
3917 CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
3918 }
3919 else if (!lhs_remapped_type && !is_matrix(type) && !transpose)
3920 {
3921 // Even if the target type is packed, we can directly store to it. We cannot store to packed matrices directly,
3922 // since they are declared as array of vectors instead, and we need the fallback path below.
3923 CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
3924 }
3925 else
3926 {
3927 // Special handling when storing to a remapped physical type.
3928 // This is mostly to deal with std140 padded matrices or vectors.
3929
3930 TypeID physical_type_id = lhs_remapped_type ?
3931 ID(get_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID)) :
3932 type.self;
3933
3934 auto &physical_type = get<SPIRType>(physical_type_id);
3935
3936 if (is_matrix(type))
3937 {
3938 const char *packed_pfx = lhs_packed_type ? "packed_" : "";
3939
3940 // Packed matrices are stored as arrays of packed vectors, so we need
3941 // to assign the vectors one at a time.
3942 // For row-major matrices, we need to transpose the *right-hand* side,
3943 // not the left-hand side.
3944
3945 // Lots of cases to cover here ...
3946
3947 bool rhs_transpose = rhs_e && rhs_e->need_transpose;
3948 SPIRType write_type = type;
3949 string cast_expr;
3950
3951 // We're dealing with transpose manually.
3952 if (rhs_transpose)
3953 rhs_e->need_transpose = false;
3954
3955 if (transpose)
3956 {
3957 // We're dealing with transpose manually.
3958 lhs_e->need_transpose = false;
3959 write_type.vecsize = type.columns;
3960 write_type.columns = 1;
3961
3962 if (physical_type.columns != type.columns)
3963 cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");
3964
3965 if (rhs_transpose)
3966 {
3967 // If RHS is also transposed, we can just copy row by row.
3968 for (uint32_t i = 0; i < type.vecsize; i++)
3969 {
3970 statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
3971 to_unpacked_row_major_matrix_expression(rhs_expression), "[", i, "];");
3972 }
3973 }
3974 else
3975 {
3976 auto vector_type = expression_type(rhs_expression);
3977 vector_type.vecsize = vector_type.columns;
3978 vector_type.columns = 1;
3979
3980 // Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
3981 // so pick out individual components instead.
3982 for (uint32_t i = 0; i < type.vecsize; i++)
3983 {
3984 string rhs_row = type_to_glsl_constructor(vector_type) + "(";
3985 for (uint32_t j = 0; j < vector_type.vecsize; j++)
3986 {
3987 rhs_row += join(to_enclosed_unpacked_expression(rhs_expression), "[", j, "][", i, "]");
3988 if (j + 1 < vector_type.vecsize)
3989 rhs_row += ", ";
3990 }
3991 rhs_row += ")";
3992
3993 statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
3994 }
3995 }
3996
3997 // We're dealing with transpose manually.
3998 lhs_e->need_transpose = true;
3999 }
4000 else
4001 {
4002 write_type.columns = 1;
4003
4004 if (physical_type.vecsize != type.vecsize)
4005 cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");
4006
4007 if (rhs_transpose)
4008 {
4009 auto vector_type = expression_type(rhs_expression);
4010 vector_type.columns = 1;
4011
4012 // Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
4013 // so pick out individual components instead.
4014 for (uint32_t i = 0; i < type.columns; i++)
4015 {
4016 string rhs_row = type_to_glsl_constructor(vector_type) + "(";
4017 for (uint32_t j = 0; j < vector_type.vecsize; j++)
4018 {
4019 // Need to explicitly unpack expression since we've mucked with transpose state.
4020 auto unpacked_expr = to_unpacked_row_major_matrix_expression(rhs_expression);
4021 rhs_row += join(unpacked_expr, "[", j, "][", i, "]");
4022 if (j + 1 < vector_type.vecsize)
4023 rhs_row += ", ";
4024 }
4025 rhs_row += ")";
4026
4027 statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
4028 }
4029 }
4030 else
4031 {
4032 // Copy column-by-column.
4033 for (uint32_t i = 0; i < type.columns; i++)
4034 {
4035 statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
4036 to_enclosed_unpacked_expression(rhs_expression), "[", i, "];");
4037 }
4038 }
4039 }
4040
4041 // We're dealing with transpose manually.
4042 if (rhs_transpose)
4043 rhs_e->need_transpose = true;
4044 }
4045 else if (transpose)
4046 {
4047 lhs_e->need_transpose = false;
4048
4049 SPIRType write_type = type;
4050 write_type.vecsize = 1;
4051 write_type.columns = 1;
4052
4053 // Storing a column to a row-major matrix. Unroll the write.
4054 for (uint32_t c = 0; c < type.vecsize; c++)
4055 {
4056 auto lhs_expr = to_enclosed_expression(lhs_expression);
4057 auto column_index = lhs_expr.find_last_of('[');
4058 if (column_index != string::npos)
4059 {
4060 statement("((device ", type_to_glsl(write_type), "*)&",
4061 lhs_expr.insert(column_index, join('[', c, ']', ")")), " = ",
4062 to_extract_component_expression(rhs_expression, c), ";");
4063 }
4064 }
4065
4066 lhs_e->need_transpose = true;
4067 }
4068 else if ((is_matrix(physical_type) || is_array(physical_type)) && physical_type.vecsize > type.vecsize)
4069 {
4070 assert(type.vecsize >= 1 && type.vecsize <= 3);
4071
4072 // If we have packed types, we cannot use swizzled stores.
4073 // We could technically unroll the store for each element if needed.
4074 // When remapping to a std140 physical type, we always get float4,
4075 // and the packed decoration should always be removed.
4076 assert(!lhs_packed_type);
4077
4078 string lhs = to_dereferenced_expression(lhs_expression);
4079 string rhs = to_pointer_expression(rhs_expression);
4080
4081 // Unpack the expression so we can store to it with a float or float2.
4082 // It's still an l-value, so it's fine. Most other unpacking of expressions turn them into r-values instead.
4083 lhs = join("(device ", type_to_glsl(type), "&)", enclose_expression(lhs));
4084 if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
4085 statement(lhs, " = ", rhs, ";");
4086 }
4087 else if (!is_matrix(type))
4088 {
4089 string lhs = to_dereferenced_expression(lhs_expression);
4090 string rhs = to_pointer_expression(rhs_expression);
4091 if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
4092 statement(lhs, " = ", rhs, ";");
4093 }
4094
4095 register_write(lhs_expression);
4096 }
4097 }
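// Illustrative examples of what the paths above emit for a row-major (transposed in
// MSL) float3x3 member 'm' (names are hypothetical):
//
//   out.m = transpose(v);   // whole-matrix store: flip transpose state, transpose RHS
//
//   out.m[0][1] = v.x;      // storing column 1: unrolled into per-component writes
//   out.m[1][1] = v.y;
//   out.m[2][1] = v.z;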
4098
4099 static bool expression_ends_with(const string &expr_str, const std::string &ending)
4100 {
4101 if (expr_str.length() >= ending.length())
4102 return (expr_str.compare(expr_str.length() - ending.length(), ending.length(), ending) == 0);
4103 else
4104 return false;
4105 }
4106
4107 // Converts the format of the current expression from packed to unpacked,
4108 // by wrapping the expression in a constructor of the appropriate type.
4109 // Also, handle special physical ID remapping scenarios, similar to emit_store_statement().
4110 string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type, uint32_t physical_type_id,
4111 bool packed, bool row_major)
4112 {
4113 // Trivial case, nothing to do.
4114 if (physical_type_id == 0 && !packed)
4115 return expr_str;
4116
4117 const SPIRType *physical_type = nullptr;
4118 if (physical_type_id)
4119 physical_type = &get<SPIRType>(physical_type_id);
4120
4121 static const char *swizzle_lut[] = {
4122 ".x",
4123 ".xy",
4124 ".xyz",
4125 };
4126
4127 if (physical_type && is_vector(*physical_type) && is_array(*physical_type) &&
4128 physical_type->vecsize > type.vecsize && !expression_ends_with(expr_str, swizzle_lut[type.vecsize - 1]))
4129 {
4130 // std140 array cases for vectors.
4131 assert(type.vecsize >= 1 && type.vecsize <= 3);
4132 return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
4133 }
4134 else if (physical_type && is_matrix(*physical_type) && is_vector(type) && physical_type->vecsize > type.vecsize)
4135 {
4136 // Extract column from padded matrix.
4137 assert(type.vecsize >= 1 && type.vecsize <= 3);
4138 return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
4139 }
4140 else if (is_matrix(type))
4141 {
4142 // Packed matrices are stored as arrays of packed vectors. Unfortunately,
4143 // we can't just pass the array straight to the matrix constructor. We have to
4144 // pass each vector individually, so that they can be unpacked to normal vectors.
4145 if (!physical_type)
4146 physical_type = &type;
4147
4148 uint32_t vecsize = type.vecsize;
4149 uint32_t columns = type.columns;
4150 if (row_major)
4151 swap(vecsize, columns);
4152
4153 uint32_t physical_vecsize = row_major ? physical_type->columns : physical_type->vecsize;
4154
4155 const char *base_type = type.width == 16 ? "half" : "float";
4156 string unpack_expr = join(base_type, columns, "x", vecsize, "(");
4157
4158 const char *load_swiz = "";
4159
4160 if (physical_vecsize != vecsize)
4161 load_swiz = swizzle_lut[vecsize - 1];
4162
4163 for (uint32_t i = 0; i < columns; i++)
4164 {
4165 if (i > 0)
4166 unpack_expr += ", ";
4167
4168 if (packed)
4169 unpack_expr += join(base_type, physical_vecsize, "(", expr_str, "[", i, "]", ")", load_swiz);
4170 else
4171 unpack_expr += join(expr_str, "[", i, "]", load_swiz);
4172 }
4173
4174 unpack_expr += ")";
4175 return unpack_expr;
4176 }
4177 else
4178 {
4179 return join(type_to_glsl(type), "(", expr_str, ")");
4180 }
4181 }
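// For instance (a sketch; 'm' is hypothetical), a packed float3x3 stored as an array
// of packed_float3 unpacks as:
//
//   float3x3(float3(m[0]), float3(m[1]), float3(m[2]))
//
// while a float3x3 remapped to a float4-per-column std140 physical type unpacks as:
//
//   float3x3(m[0].xyz, m[1].xyz, m[2].xyz)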
4182
4183 // Emits the file header info
4184 void CompilerMSL::emit_header()
4185 {
4186 // This particular line can be overridden during compilation, so make it a flag and not a pragma line.
4187 if (suppress_missing_prototypes)
4188 statement("#pragma clang diagnostic ignored \"-Wmissing-prototypes\"");
4189
4190 // Disable warning about missing braces for array<T> template to make arrays a value type
4191 if (spv_function_implementations.count(SPVFuncImplUnsafeArray) != 0)
4192 statement("#pragma clang diagnostic ignored \"-Wmissing-braces\"");
4193
4194 for (auto &pragma : pragma_lines)
4195 statement(pragma);
4196
4197 if (!pragma_lines.empty() || suppress_missing_prototypes)
4198 statement("");
4199
4200 statement("#include <metal_stdlib>");
4201 statement("#include <simd/simd.h>");
4202
4203 for (auto &header : header_lines)
4204 statement(header);
4205
4206 statement("");
4207 statement("using namespace metal;");
4208 statement("");
4209
4210 for (auto &td : typedef_lines)
4211 statement(td);
4212
4213 if (!typedef_lines.empty())
4214 statement("");
4215 }
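// For a shader needing none of the optional pragma, header or typedef lines, the
// emitted preamble is simply:
//
//   #include <metal_stdlib>
//   #include <simd/simd.h>
//
//   using namespace metal;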
4216
4217 void CompilerMSL::add_pragma_line(const string &line)
4218 {
4219 auto rslt = pragma_lines.insert(line);
4220 if (rslt.second)
4221 force_recompile();
4222 }
4223
4224 void CompilerMSL::add_typedef_line(const string &line)
4225 {
4226 auto rslt = typedef_lines.insert(line);
4227 if (rslt.second)
4228 force_recompile();
4229 }
4230
4231 // Template structs like spvUnsafeArray<> need to be declared *before* any resources are declared
4232 void CompilerMSL::emit_custom_templates()
4233 {
4234 for (const auto &spv_func : spv_function_implementations)
4235 {
4236 switch (spv_func)
4237 {
4238 case SPVFuncImplUnsafeArray:
4239 statement("template<typename T, size_t Num>");
4240 statement("struct spvUnsafeArray");
4241 begin_scope();
4242 statement("T elements[Num ? Num : 1];");
4243 statement("");
4244 statement("thread T& operator [] (size_t pos) thread");
4245 begin_scope();
4246 statement("return elements[pos];");
4247 end_scope();
4248 statement("constexpr const thread T& operator [] (size_t pos) const thread");
4249 begin_scope();
4250 statement("return elements[pos];");
4251 end_scope();
4252 statement("");
4253 statement("device T& operator [] (size_t pos) device");
4254 begin_scope();
4255 statement("return elements[pos];");
4256 end_scope();
4257 statement("constexpr const device T& operator [] (size_t pos) const device");
4258 begin_scope();
4259 statement("return elements[pos];");
4260 end_scope();
4261 statement("");
4262 statement("constexpr const constant T& operator [] (size_t pos) const constant");
4263 begin_scope();
4264 statement("return elements[pos];");
4265 end_scope();
4266 statement("");
4267 statement("threadgroup T& operator [] (size_t pos) threadgroup");
4268 begin_scope();
4269 statement("return elements[pos];");
4270 end_scope();
4271 statement("constexpr const threadgroup T& operator [] (size_t pos) const threadgroup");
4272 begin_scope();
4273 statement("return elements[pos];");
4274 end_scope();
4275 end_scope_decl();
4276 statement("");
4277 break;
4278
4279 default:
4280 break;
4281 }
4282 }
4283 }
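// For illustration only: spvUnsafeArray<T, Num> mirrors Metal's array<T, Num>
// but declares "T elements[Num ? Num : 1]" (so zero-length arrays stay legal)
// and provides subscript overloads for the thread, device, constant and
// threadgroup address spaces. A shader might then declare, hypothetically:
//
//     spvUnsafeArray<float4, 3> colors;
//     colors[0] = float4(1.0);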
4284
4285 // Emits any needed custom function bodies.
4286 // Metal helper functions must be static force-inline, i.e. static inline __attribute__((always_inline)),
4287 // otherwise they will cause problems when linked together in a single Metallib.
4288 void CompilerMSL::emit_custom_functions()
4289 {
4290 for (uint32_t i = kArrayCopyMultidimMax; i >= 2; i--)
4291 if (spv_function_implementations.count(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i)))
4292 spv_function_implementations.insert(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i - 1));
4293
4294 if (spv_function_implementations.count(SPVFuncImplDynamicImageSampler))
4295 {
4296 // Unfortunately, this one needs a lot of the other functions to compile OK.
4297 if (!msl_options.supports_msl_version(2))
4298 SPIRV_CROSS_THROW(
4299 "spvDynamicImageSampler requires default-constructible texture objects, which require MSL 2.0.");
4300 spv_function_implementations.insert(SPVFuncImplForwardArgs);
4301 spv_function_implementations.insert(SPVFuncImplTextureSwizzle);
4302 if (msl_options.swizzle_texture_samples)
4303 spv_function_implementations.insert(SPVFuncImplGatherSwizzle);
4304 for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
4305 i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
4306 spv_function_implementations.insert(static_cast<SPVFuncImpl>(i));
4307 spv_function_implementations.insert(SPVFuncImplExpandITUFullRange);
4308 spv_function_implementations.insert(SPVFuncImplExpandITUNarrowRange);
4309 spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT709);
4310 spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT601);
4311 spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT2020);
4312 }
4313
4314 for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
4315 i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
4316 if (spv_function_implementations.count(static_cast<SPVFuncImpl>(i)))
4317 spv_function_implementations.insert(SPVFuncImplForwardArgs);
4318
4319 if (spv_function_implementations.count(SPVFuncImplTextureSwizzle) ||
4320 spv_function_implementations.count(SPVFuncImplGatherSwizzle) ||
4321 spv_function_implementations.count(SPVFuncImplGatherCompareSwizzle))
4322 {
4323 spv_function_implementations.insert(SPVFuncImplForwardArgs);
4324 spv_function_implementations.insert(SPVFuncImplGetSwizzle);
4325 }
4326
4327 for (const auto &spv_func : spv_function_implementations)
4328 {
4329 switch (spv_func)
4330 {
4331 case SPVFuncImplMod:
4332 statement("// Implementation of the GLSL mod() function, which is slightly different than Metal fmod()");
4333 statement("template<typename Tx, typename Ty>");
4334 statement("inline Tx mod(Tx x, Ty y)");
4335 begin_scope();
4336 statement("return x - y * floor(x / y);");
4337 end_scope();
4338 statement("");
4339 break;
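// Worked example of the difference: mod(-0.5, 1.0) = -0.5 - 1.0 * floor(-0.5)
// = 0.5, whereas Metal's fmod(-0.5, 1.0) truncates toward zero and yields -0.5.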
4340
4341 case SPVFuncImplRadians:
4342 statement("// Implementation of the GLSL radians() function");
4343 statement("template<typename T>");
4344 statement("inline T radians(T d)");
4345 begin_scope();
4346 statement("return d * T(0.01745329251);");
4347 end_scope();
4348 statement("");
4349 break;
4350
4351 case SPVFuncImplDegrees:
4352 statement("// Implementation of the GLSL degrees() function");
4353 statement("template<typename T>");
4354 statement("inline T degrees(T r)");
4355 begin_scope();
4356 statement("return r * T(57.2957795131);");
4357 end_scope();
4358 statement("");
4359 break;
4360
4361 case SPVFuncImplFindILsb:
4362 statement("// Implementation of the GLSL findLSB() function");
4363 statement("template<typename T>");
4364 statement("inline T spvFindLSB(T x)");
4365 begin_scope();
4366 statement("return select(ctz(x), T(-1), x == T(0));");
4367 end_scope();
4368 statement("");
4369 break;
4370
4371 case SPVFuncImplFindUMsb:
4372 statement("// Implementation of the unsigned GLSL findMSB() function");
4373 statement("template<typename T>");
4374 statement("inline T spvFindUMSB(T x)");
4375 begin_scope();
4376 statement("return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));");
4377 end_scope();
4378 statement("");
4379 break;
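// Worked example: for a 32-bit T, clz(T(0)) evaluates to the bit width (32),
// so x = 16u gives clz(x) = 27 and spvFindUMSB returns 32 - (27 + 1) = 4.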
4380
4381 case SPVFuncImplFindSMsb:
4382 statement("// Implementation of the signed GLSL findMSB() function");
4383 statement("template<typename T>");
4384 statement("inline T spvFindSMSB(T x)");
4385 begin_scope();
4386 statement("T v = select(x, T(-1) - x, x < T(0));");
4387 statement("return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));");
4388 end_scope();
4389 statement("");
4390 break;
4391
4392 case SPVFuncImplSSign:
4393 statement("// Implementation of the GLSL sign() function for integer types");
4394 statement("template<typename T, typename E = typename enable_if<is_integral<T>::value>::type>");
4395 statement("inline T sign(T x)");
4396 begin_scope();
4397 statement("return select(select(select(x, T(0), x == T(0)), T(1), x > T(0)), T(-1), x < T(0));");
4398 end_scope();
4399 statement("");
4400 break;
4401
4402 case SPVFuncImplArrayCopy:
4403 case SPVFuncImplArrayOfArrayCopy2Dim:
4404 case SPVFuncImplArrayOfArrayCopy3Dim:
4405 case SPVFuncImplArrayOfArrayCopy4Dim:
4406 case SPVFuncImplArrayOfArrayCopy5Dim:
4407 case SPVFuncImplArrayOfArrayCopy6Dim:
4408 {
4409 // Unfortunately we cannot template on the address space, so combinatorial explosion it is.
4410 static const char *function_name_tags[] = {
4411 "FromConstantToStack", "FromConstantToThreadGroup", "FromStackToStack",
4412 "FromStackToThreadGroup", "FromThreadGroupToStack", "FromThreadGroupToThreadGroup",
4413 "FromDeviceToDevice", "FromConstantToDevice", "FromStackToDevice",
4414 "FromThreadGroupToDevice", "FromDeviceToStack", "FromDeviceToThreadGroup",
4415 };
4416
4417 static const char *src_address_space[] = {
4418 "constant", "constant", "thread const", "thread const",
4419 "threadgroup const", "threadgroup const", "device const", "constant",
4420 "thread const", "threadgroup const", "device const", "device const",
4421 };
4422
4423 static const char *dst_address_space[] = {
4424 "thread", "threadgroup", "thread", "threadgroup", "thread", "threadgroup",
4425 "device", "device", "device", "device", "thread", "threadgroup",
4426 };
4427
4428 for (uint32_t variant = 0; variant < 12; variant++)
4429 {
4430 uint32_t dimensions = spv_func - SPVFuncImplArrayCopyMultidimBase;
4431 string tmp = "template<typename T";
4432 for (uint8_t i = 0; i < dimensions; i++)
4433 {
4434 tmp += ", uint ";
4435 tmp += 'A' + i;
4436 }
4437 tmp += ">";
4438 statement(tmp);
4439
4440 string array_arg;
4441 for (uint8_t i = 0; i < dimensions; i++)
4442 {
4443 array_arg += "[";
4444 array_arg += 'A' + i;
4445 array_arg += "]";
4446 }
4447
4448 statement("inline void spvArrayCopy", function_name_tags[variant], dimensions, "(",
4449 dst_address_space[variant], " T (&dst)", array_arg, ", ", src_address_space[variant],
4450 " T (&src)", array_arg, ")");
4451
4452 begin_scope();
4453 statement("for (uint i = 0; i < A; i++)");
4454 begin_scope();
4455
4456 if (dimensions == 1)
4457 statement("dst[i] = src[i];");
4458 else
4459 statement("spvArrayCopy", function_name_tags[variant], dimensions - 1, "(dst[i], src[i]);");
4460 end_scope();
4461 end_scope();
4462 statement("");
4463 }
4464 break;
4465 }
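// For illustration only: variant 0 ("FromConstantToStack") at one dimension
// emits a helper of this shape:
//
//     template<typename T, uint A>
//     inline void spvArrayCopyFromConstantToStack1(thread T (&dst)[A], constant T (&src)[A])
//     {
//         for (uint i = 0; i < A; i++)
//         {
//             dst[i] = src[i];
//         }
//     }
//
// Higher dimensions recurse into the (dimensions - 1) helper per element.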
4466
4467 // Support for Metal 2.1's new texture_buffer type.
4468 case SPVFuncImplTexelBufferCoords:
4469 {
4470 if (msl_options.texel_buffer_texture_width > 0)
4471 {
4472 string tex_width_str = convert_to_string(msl_options.texel_buffer_texture_width);
4473 statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
4474 statement(force_inline);
4475 statement("uint2 spvTexelBufferCoord(uint tc)");
4476 begin_scope();
4477 statement(join("return uint2(tc % ", tex_width_str, ", tc / ", tex_width_str, ");"));
4478 end_scope();
4479 statement("");
4480 }
4481 else
4482 {
4483 statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
4484 statement(
4485 "#define spvTexelBufferCoord(tc, tex) uint2((tc) % (tex).get_width(), (tc) / (tex).get_width())");
4486 statement("");
4487 }
4488 break;
4489 }
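// Worked example: with texel_buffer_texture_width = 4096, texel index 5000
// maps to uint2(5000 % 4096, 5000 / 4096) = uint2(904, 1).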
4490
4491 // Emulate texture2D atomic operations
4492 case SPVFuncImplImage2DAtomicCoords:
4493 {
4494 if (msl_options.supports_msl_version(1, 2))
4495 {
4496 statement("// The required alignment of a linear texture of R32Uint format.");
4497 statement("constant uint spvLinearTextureAlignmentOverride [[function_constant(",
4498 msl_options.r32ui_alignment_constant_id, ")]];");
4499 statement("constant uint spvLinearTextureAlignment = ",
4500 "is_function_constant_defined(spvLinearTextureAlignmentOverride) ? ",
4501 "spvLinearTextureAlignmentOverride : ", msl_options.r32ui_linear_texture_alignment, ";");
4502 }
4503 else
4504 {
4505 statement("// The required alignment of a linear texture of R32Uint format.");
4506 statement("constant uint spvLinearTextureAlignment = ", msl_options.r32ui_linear_texture_alignment,
4507 ";");
4508 }
4509 statement("// Returns buffer coords corresponding to 2D texture coords for emulating 2D texture atomics");
4510 statement("#define spvImage2DAtomicCoord(tc, tex) (((((tex).get_width() + ",
4511 " spvLinearTextureAlignment / 4 - 1) & ~(",
4512 " spvLinearTextureAlignment / 4 - 1)) * (tc).y) + (tc).x)");
4513 statement("");
4514 break;
4515 }
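// Worked example: with spvLinearTextureAlignment = 16 bytes (four R32Uint
// texels), a texture 101 texels wide gets a padded row pitch of
// (101 + 3) & ~3 = 104 texels, so texel (x, y) maps to buffer index 104 * y + x.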
4516
4517 // "fadd" intrinsic support
4518 case SPVFuncImplFAdd:
4519 statement("template<typename T>");
4520 statement("T spvFAdd(T l, T r)");
4521 begin_scope();
4522 statement("return fma(T(1), l, r);");
4523 end_scope();
4524 statement("");
4525 break;
4526
4527 // "fmul' intrinsic support
4528 case SPVFuncImplFMul:
4529 statement("template<typename T>");
4530 statement("T spvFMul(T l, T r)");
4531 begin_scope();
4532 statement("return fma(l, r, T(0));");
4533 end_scope();
4534 statement("");
4535
4536 statement("template<typename T, int Cols, int Rows>");
4537 statement("vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)");
4538 begin_scope();
4539 statement("vec<T, Cols> res = vec<T, Cols>(0);");
4540 statement("for (uint i = Rows; i > 0; --i)");
4541 begin_scope();
4542 statement("vec<T, Cols> tmp(0);");
4543 statement("for (uint j = 0; j < Cols; ++j)");
4544 begin_scope();
4545 statement("tmp[j] = m[j][i - 1];");
4546 end_scope();
4547 statement("res = fma(tmp, vec<T, Cols>(v[i - 1]), res);");
4548 end_scope();
4549 statement("return res;");
4550 end_scope();
4551 statement("");
4552
4553 statement("template<typename T, int Cols, int Rows>");
4554 statement("vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)");
4555 begin_scope();
4556 statement("vec<T, Rows> res = vec<T, Rows>(0);");
4557 statement("for (uint i = Cols; i > 0; --i)");
4558 begin_scope();
4559 statement("res = fma(m[i - 1], vec<T, Rows>(v[i - 1]), res);");
4560 end_scope();
4561 statement("return res;");
4562 end_scope();
4563 statement("");
4564
4565 statement("template<typename T, int LCols, int LRows, int RCols, int RRows>");
4566 statement(
4567 "matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)");
4568 begin_scope();
4569 statement("matrix<T, RCols, LRows> res;");
4570 statement("for (uint i = 0; i < RCols; i++)");
4571 begin_scope();
4572 statement("vec<T, RCols> tmp(0);");
4573 statement("for (uint j = 0; j < LCols; j++)");
4574 begin_scope();
4575 statement("tmp = fma(vec<T, RCols>(r[i][j]), l[j], tmp);");
4576 end_scope();
4577 statement("res[i] = tmp;");
4578 end_scope();
4579 statement("return res;");
4580 end_scope();
4581 statement("");
4582 break;
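// Note on the fma() forms above: fma(T(1), l, r) and fma(l, r, T(0)) compute
// exactly l + r and l * r, but routing them through fma() keeps the Metal
// compiler from contracting the add/multiply with neighboring operations and
// changing rounding, which matters for SPIR-V arithmetic that must not be
// fused.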
4583
4584 // Emulate texturecube_array with texture2d_array for iOS where this type is not available
4585 case SPVFuncImplCubemapTo2DArrayFace:
4586 statement(force_inline);
4587 statement("float3 spvCubemapTo2DArrayFace(float3 P)");
4588 begin_scope();
4589 statement("float3 Coords = abs(P.xyz);");
4590 statement("float CubeFace = 0;");
4591 statement("float ProjectionAxis = 0;");
4592 statement("float u = 0;");
4593 statement("float v = 0;");
4594 statement("if (Coords.x >= Coords.y && Coords.x >= Coords.z)");
4595 begin_scope();
4596 statement("CubeFace = P.x >= 0 ? 0 : 1;");
4597 statement("ProjectionAxis = Coords.x;");
4598 statement("u = P.x >= 0 ? -P.z : P.z;");
4599 statement("v = -P.y;");
4600 end_scope();
4601 statement("else if (Coords.y >= Coords.x && Coords.y >= Coords.z)");
4602 begin_scope();
4603 statement("CubeFace = P.y >= 0 ? 2 : 3;");
4604 statement("ProjectionAxis = Coords.y;");
4605 statement("u = P.x;");
4606 statement("v = P.y >= 0 ? P.z : -P.z;");
4607 end_scope();
4608 statement("else");
4609 begin_scope();
4610 statement("CubeFace = P.z >= 0 ? 4 : 5;");
4611 statement("ProjectionAxis = Coords.z;");
4612 statement("u = P.z >= 0 ? P.x : -P.x;");
4613 statement("v = -P.y;");
4614 end_scope();
4615 statement("u = 0.5 * (u/ProjectionAxis + 1);");
4616 statement("v = 0.5 * (v/ProjectionAxis + 1);");
4617 statement("return float3(u, v, CubeFace);");
4618 end_scope();
4619 statement("");
4620 break;
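// Worked example: P = float3(1, 0, 0) picks CubeFace 0 (+X) with
// ProjectionAxis = 1, u = -P.z = 0 and v = -P.y = 0, so the function returns
// float3(0.5, 0.5, 0), the center of the +X face.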
4621
4622 case SPVFuncImplInverse4x4:
4623 statement("// Returns the determinant of a 2x2 matrix.");
4624 statement(force_inline);
4625 statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
4626 begin_scope();
4627 statement("return a1 * b2 - b1 * a2;");
4628 end_scope();
4629 statement("");
4630
4631 statement("// Returns the determinant of a 3x3 matrix.");
4632 statement(force_inline);
4633 statement("float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
4634 "float c2, float c3)");
4635 begin_scope();
4636 statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * spvDet2x2(a2, a3, "
4637 "b2, b3);");
4638 end_scope();
4639 statement("");
4640 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
4641 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
4642 statement(force_inline);
4643 statement("float4x4 spvInverse4x4(float4x4 m)");
4644 begin_scope();
4645 statement("float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)");
4646 statement_no_indent("");
4647 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
4648 statement("adj[0][0] = spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
4649 "m[3][3]);");
4650 statement("adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
4651 "m[3][3]);");
4652 statement("adj[0][2] = spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
4653 "m[3][3]);");
4654 statement("adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
4655 "m[2][3]);");
4656 statement_no_indent("");
4657 statement("adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
4658 "m[3][3]);");
4659 statement("adj[1][1] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
4660 "m[3][3]);");
4661 statement("adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
4662 "m[3][3]);");
4663 statement("adj[1][3] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
4664 "m[2][3]);");
4665 statement_no_indent("");
4666 statement("adj[2][0] = spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
4667 "m[3][3]);");
4668 statement("adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
4669 "m[3][3]);");
4670 statement("adj[2][2] = spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
4671 "m[3][3]);");
4672 statement("adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
4673 "m[2][3]);");
4674 statement_no_indent("");
4675 statement("adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
4676 "m[3][2]);");
4677 statement("adj[3][1] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
4678 "m[3][2]);");
4679 statement("adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
4680 "m[3][2]);");
4681 statement("adj[3][3] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
4682 "m[2][2]);");
4683 statement_no_indent("");
4684 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
4685 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
4686 "* m[3][0]);");
4687 statement_no_indent("");
4688 statement("// Divide the classical adjoint matrix by the determinant.");
4689 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
4690 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
4691 end_scope();
4692 statement("");
4693 break;
4694
4695 case SPVFuncImplInverse3x3:
4696 if (spv_function_implementations.count(SPVFuncImplInverse4x4) == 0)
4697 {
4698 statement("// Returns the determinant of a 2x2 matrix.");
4699 statement(force_inline);
4700 statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
4701 begin_scope();
4702 statement("return a1 * b2 - b1 * a2;");
4703 end_scope();
4704 statement("");
4705 }
4706
4707 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
4708 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
4709 statement(force_inline);
4710 statement("float3x3 spvInverse3x3(float3x3 m)");
4711 begin_scope();
4712 statement("float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)");
4713 statement_no_indent("");
4714 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
4715 statement("adj[0][0] = spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
4716 statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
4717 statement("adj[0][2] = spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
4718 statement_no_indent("");
4719 statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
4720 statement("adj[1][1] = spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
4721 statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
4722 statement_no_indent("");
4723 statement("adj[2][0] = spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
4724 statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
4725 statement("adj[2][2] = spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
4726 statement_no_indent("");
4727 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
4728 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
4729 statement_no_indent("");
4730 statement("// Divide the classical adjoint matrix by the determinant.");
4731 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
4732 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
4733 end_scope();
4734 statement("");
4735 break;
4736
4737 case SPVFuncImplInverse2x2:
4738 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
4739 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
4740 statement(force_inline);
4741 statement("float2x2 spvInverse2x2(float2x2 m)");
4742 begin_scope();
4743 statement("float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)");
4744 statement_no_indent("");
4745 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
4746 statement("adj[0][0] = m[1][1];");
4747 statement("adj[0][1] = -m[0][1];");
4748 statement_no_indent("");
4749 statement("adj[1][0] = -m[1][0];");
4750 statement("adj[1][1] = m[0][0];");
4751 statement_no_indent("");
4752 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
4753 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
4754 statement_no_indent("");
4755 statement("// Divide the classical adjoint matrix by the determinant.");
4756 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
4757 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
4758 end_scope();
4759 statement("");
4760 break;
4761
4762 case SPVFuncImplForwardArgs:
4763 statement("template<typename T> struct spvRemoveReference { typedef T type; };");
4764 statement("template<typename T> struct spvRemoveReference<thread T&> { typedef T type; };");
4765 statement("template<typename T> struct spvRemoveReference<thread T&&> { typedef T type; };");
4766 statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
4767 "spvRemoveReference<T>::type& x)");
4768 begin_scope();
4769 statement("return static_cast<thread T&&>(x);");
4770 end_scope();
4771 statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
4772 "spvRemoveReference<T>::type&& x)");
4773 begin_scope();
4774 statement("return static_cast<thread T&&>(x);");
4775 end_scope();
4776 statement("");
4777 break;
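// spvRemoveReference/spvForward are thread-address-space analogues of
// std::remove_reference and std::forward (which, to our knowledge, the Metal
// Standard Library does not expose), letting the Ts... parameter packs below
// be perfectly forwarded into gather calls.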
4778
4779 case SPVFuncImplGetSwizzle:
4780 statement("enum class spvSwizzle : uint");
4781 begin_scope();
4782 statement("none = 0,");
4783 statement("zero,");
4784 statement("one,");
4785 statement("red,");
4786 statement("green,");
4787 statement("blue,");
4788 statement("alpha");
4789 end_scope_decl();
4790 statement("");
4791 statement("template<typename T>");
4792 statement("inline T spvGetSwizzle(vec<T, 4> x, T c, spvSwizzle s)");
4793 begin_scope();
4794 statement("switch (s)");
4795 begin_scope();
4796 statement("case spvSwizzle::none:");
4797 statement(" return c;");
4798 statement("case spvSwizzle::zero:");
4799 statement(" return 0;");
4800 statement("case spvSwizzle::one:");
4801 statement(" return 1;");
4802 statement("case spvSwizzle::red:");
4803 statement(" return x.r;");
4804 statement("case spvSwizzle::green:");
4805 statement(" return x.g;");
4806 statement("case spvSwizzle::blue:");
4807 statement(" return x.b;");
4808 statement("case spvSwizzle::alpha:");
4809 statement(" return x.a;");
4810 end_scope();
4811 end_scope();
4812 statement("");
4813 break;
4814
4815 case SPVFuncImplTextureSwizzle:
4816 statement("// Wrapper function that swizzles texture samples and fetches.");
4817 statement("template<typename T>");
4818 statement("inline vec<T, 4> spvTextureSwizzle(vec<T, 4> x, uint s)");
4819 begin_scope();
4820 statement("if (!s)");
4821 statement(" return x;");
4822 statement("return vec<T, 4>(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), "
4823 "spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) "
4824 "& 0xFF)), "
4825 "spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));");
4826 end_scope();
4827 statement("");
4828 statement("template<typename T>");
4829 statement("inline T spvTextureSwizzle(T x, uint s)");
4830 begin_scope();
4831 statement("return spvTextureSwizzle(vec<T, 4>(x, 0, 0, 1), s).x;");
4832 end_scope();
4833 statement("");
4834 break;
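// The swizzle constant packs one spvSwizzle per destination component, 8 bits
// each: bits 0-7 drive red, 8-15 green, 16-23 blue, 24-31 alpha. E.g. a value
// of 5 (spvSwizzle::blue in the low byte, none elsewhere) rewrites red from
// x.b and leaves the other components untouched.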
4835
4836 case SPVFuncImplGatherSwizzle:
4837 statement("// Wrapper function that swizzles texture gathers.");
4838 statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
4839 "typename... Ts>");
4840 statement("inline vec<T, 4> spvGatherSwizzle(const thread Tex<T>& t, sampler s, "
4841 "uint sw, component c, Ts... params) METAL_CONST_ARG(c)");
4842 begin_scope();
4843 statement("if (sw)");
4844 begin_scope();
4845 statement("switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))");
4846 begin_scope();
4847 statement("case spvSwizzle::none:");
4848 statement(" break;");
4849 statement("case spvSwizzle::zero:");
4850 statement(" return vec<T, 4>(0, 0, 0, 0);");
4851 statement("case spvSwizzle::one:");
4852 statement(" return vec<T, 4>(1, 1, 1, 1);");
4853 statement("case spvSwizzle::red:");
4854 statement(" return t.gather(s, spvForward<Ts>(params)..., component::x);");
4855 statement("case spvSwizzle::green:");
4856 statement(" return t.gather(s, spvForward<Ts>(params)..., component::y);");
4857 statement("case spvSwizzle::blue:");
4858 statement(" return t.gather(s, spvForward<Ts>(params)..., component::z);");
4859 statement("case spvSwizzle::alpha:");
4860 statement(" return t.gather(s, spvForward<Ts>(params)..., component::w);");
4861 end_scope();
4862 end_scope();
4863 // texture::gather insists on its component parameter being a constant
4864 // expression, so we need this silly workaround just to compile the shader.
4865 statement("switch (c)");
4866 begin_scope();
4867 statement("case component::x:");
4868 statement(" return t.gather(s, spvForward<Ts>(params)..., component::x);");
4869 statement("case component::y:");
4870 statement(" return t.gather(s, spvForward<Ts>(params)..., component::y);");
4871 statement("case component::z:");
4872 statement(" return t.gather(s, spvForward<Ts>(params)..., component::z);");
4873 statement("case component::w:");
4874 statement(" return t.gather(s, spvForward<Ts>(params)..., component::w);");
4875 end_scope();
4876 end_scope();
4877 statement("");
4878 break;
4879
4880 case SPVFuncImplGatherCompareSwizzle:
4881 statement("// Wrapper function that swizzles depth texture gathers.");
4882 statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
4883 "typename... Ts>");
4884 statement("inline vec<T, 4> spvGatherCompareSwizzle(const thread Tex<T>& t, sampler "
4885 "s, uint sw, Ts... params) ");
4886 begin_scope();
4887 statement("if (sw)");
4888 begin_scope();
4889 statement("switch (spvSwizzle(sw & 0xFF))");
4890 begin_scope();
4891 statement("case spvSwizzle::none:");
4892 statement("case spvSwizzle::red:");
4893 statement(" break;");
4894 statement("case spvSwizzle::zero:");
4895 statement("case spvSwizzle::green:");
4896 statement("case spvSwizzle::blue:");
4897 statement("case spvSwizzle::alpha:");
4898 statement(" return vec<T, 4>(0, 0, 0, 0);");
4899 statement("case spvSwizzle::one:");
4900 statement(" return vec<T, 4>(1, 1, 1, 1);");
4901 end_scope();
4902 end_scope();
4903 statement("return t.gather_compare(s, spvForward<Ts>(params)...);");
4904 end_scope();
4905 statement("");
4906 break;
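// Depth textures expose a single component, so swizzles selecting green, blue
// or alpha read as zero and only red/none fall through to the real
// gather_compare; this mirrors how the plain gather wrapper handles swizzles.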
4907
4908 case SPVFuncImplSubgroupBroadcast:
4909 // Metal doesn't allow broadcasting boolean values directly, but we can work around that by broadcasting
4910 // them as integers.
4911 statement("template<typename T>");
4912 statement("inline T spvSubgroupBroadcast(T value, ushort lane)");
4913 begin_scope();
4914 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
4915 statement("return quad_broadcast(value, lane);");
4916 else
4917 statement("return simd_broadcast(value, lane);");
4918 end_scope();
4919 statement("");
4920 statement("template<>");
4921 statement("inline bool spvSubgroupBroadcast(bool value, ushort lane)");
4922 begin_scope();
4923 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
4924 statement("return !!quad_broadcast((ushort)value, lane);");
4925 else
4926 statement("return !!simd_broadcast((ushort)value, lane);");
4927 end_scope();
4928 statement("");
4929 statement("template<uint N>");
4930 statement("inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)");
4931 begin_scope();
4932 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
4933 statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
4934 else
4935 statement("return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);");
4936 end_scope();
4937 statement("");
4938 break;
4939
4940 case SPVFuncImplSubgroupBroadcastFirst:
4941 statement("template<typename T>");
4942 statement("inline T spvSubgroupBroadcastFirst(T value)");
4943 begin_scope();
4944 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
4945 statement("return quad_broadcast_first(value);");
4946 else
4947 statement("return simd_broadcast_first(value);");
4948 end_scope();
4949 statement("");
4950 statement("template<>");
4951 statement("inline bool spvSubgroupBroadcastFirst(bool value)");
4952 begin_scope();
4953 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
4954 statement("return !!quad_broadcast_first((ushort)value);");
4955 else
4956 statement("return !!simd_broadcast_first((ushort)value);");
4957 end_scope();
4958 statement("");
4959 statement("template<uint N>");
4960 statement("inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)");
4961 begin_scope();
4962 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
4963 statement("return (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value);");
4964 else
4965 statement("return (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value);");
4966 end_scope();
4967 statement("");
4968 break;
4969
4970 case SPVFuncImplSubgroupBallot:
4971 statement("inline uint4 spvSubgroupBallot(bool value)");
4972 begin_scope();
4973 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
4974 {
4975 statement("return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);");
4976 }
4977 else if (msl_options.is_ios())
4978 {
4979 // The current simd_vote on iOS uses a 32-bit integer-like object.
4980 statement("return uint4((simd_vote::vote_t)simd_ballot(value), 0, 0, 0);");
4981 }
4982 else
4983 {
4984 statement("simd_vote vote = simd_ballot(value);");
4985 statement("// simd_ballot() returns a 64-bit integer-like object, but");
4986 statement("// SPIR-V callers expect a uint4. We must convert.");
4987 statement("// FIXME: This won't include higher bits if Apple ever supports");
4988 statement("// 128 lanes in an SIMD-group.");
4989 statement(
4990 "return uint4((uint)((simd_vote::vote_t)vote & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)vote >> "
4991 "32) & 0xFFFFFFFF), 0, 0);");
4992 }
4993 end_scope();
4994 statement("");
4995 break;
4996
4997 case SPVFuncImplSubgroupBallotBitExtract:
4998 statement("inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)");
4999 begin_scope();
5000 statement("return !!extract_bits(ballot[bit / 32], bit % 32, 1);");
5001 end_scope();
5002 statement("");
5003 break;
5004
5005 case SPVFuncImplSubgroupBallotFindLSB:
5006 statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)");
5007 begin_scope();
5008 if (msl_options.is_ios())
5009 {
5010 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
5011 }
5012 else
5013 {
5014 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
5015 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
5016 }
5017 statement("ballot &= mask;");
5018 statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
5019 "ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
5020 end_scope();
5021 statement("");
5022 break;
5023
5024 case SPVFuncImplSubgroupBallotFindMSB:
5025 statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)");
5026 begin_scope();
5027 if (msl_options.is_ios())
5028 {
5029 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
5030 }
5031 else
5032 {
5033 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
5034 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
5035 }
5036 statement("ballot &= mask;");
5037 statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
5038 "(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
5039 "ballot.z == 0), ballot.w == 0);");
5040 end_scope();
5041 statement("");
5042 break;
5043
5044 case SPVFuncImplSubgroupBallotBitCount:
5045 statement("inline uint spvPopCount4(uint4 ballot)");
5046 begin_scope();
5047 statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
5048 end_scope();
5049 statement("");
5050 statement("inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)");
5051 begin_scope();
5052 if (msl_options.is_ios())
5053 {
5054 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
5055 }
5056 else
5057 {
5058 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
5059 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
5060 }
5061 statement("return spvPopCount4(ballot & mask);");
5062 end_scope();
5063 statement("");
5064 statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
5065 begin_scope();
5066 if (msl_options.is_ios())
5067 {
5068 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID + 1), uint3(0));");
5069 }
5070 else
5071 {
5072 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
5073 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
5074 "uint2(0));");
5075 }
5076 statement("return spvPopCount4(ballot & mask);");
5077 end_scope();
5078 statement("");
5079 statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
5080 begin_scope();
5081 if (msl_options.is_ios())
5082 {
5083 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID), uint2(0));");
5084 }
5085 else
5086 {
5087 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
5088 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
5089 }
5090 statement("return spvPopCount4(ballot & mask);");
5091 end_scope();
5092 statement("");
5093 break;
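// Worked example: for gl_SubgroupInvocationID = 5, the inclusive variant masks
// the ballot with bits 0-5 set (extract_bits(0xFFFFFFFF, 0, 6)), while the
// exclusive variant masks bits 0-4 (extract_bits(0xFFFFFFFF, 0, 5)), before
// counting the set bits with spvPopCount4.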
5094
5095 case SPVFuncImplSubgroupAllEqual:
5096 // Metal doesn't provide a function to evaluate this directly. But, we can
5097 // implement this by comparing every thread's value to one thread's value
5098 // (in this case, the value of the first active thread). Then, by the transitive
5099 // property of equality, if all comparisons return true, then they are all equal.
5100 statement("template<typename T>");
5101 statement("inline bool spvSubgroupAllEqual(T value)");
5102 begin_scope();
5103 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5104 statement("return quad_all(all(value == quad_broadcast_first(value)));");
5105 else
5106 statement("return simd_all(all(value == simd_broadcast_first(value)));");
5107 end_scope();
5108 statement("");
5109 statement("template<>");
5110 statement("inline bool spvSubgroupAllEqual(bool value)");
5111 begin_scope();
5112 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5113 statement("return quad_all(value) || !quad_any(value);");
5114 else
5115 statement("return simd_all(value) || !simd_any(value);");
5116 end_scope();
5117 statement("");
5118 statement("template<uint N>");
5119 statement("inline bool spvSubgroupAllEqual(vec<bool, N> value)");
5120 begin_scope();
5121 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5122 statement("return quad_all(all(value == (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value)));");
5123 else
5124 statement("return simd_all(all(value == (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value)));");
5125 end_scope();
5126 statement("");
5127 break;
5128
5129 case SPVFuncImplSubgroupShuffle:
5130 statement("template<typename T>");
5131 statement("inline T spvSubgroupShuffle(T value, ushort lane)");
5132 begin_scope();
5133 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5134 statement("return quad_shuffle(value, lane);");
5135 else
5136 statement("return simd_shuffle(value, lane);");
5137 end_scope();
5138 statement("");
5139 statement("template<>");
5140 statement("inline bool spvSubgroupShuffle(bool value, ushort lane)");
5141 begin_scope();
5142 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5143 statement("return !!quad_shuffle((ushort)value, lane);");
5144 else
5145 statement("return !!simd_shuffle((ushort)value, lane);");
5146 end_scope();
5147 statement("");
5148 statement("template<uint N>");
5149 statement("inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)");
5150 begin_scope();
5151 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5152 statement("return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);");
5153 else
5154 statement("return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);");
5155 end_scope();
5156 statement("");
5157 break;
5158
5159 case SPVFuncImplSubgroupShuffleXor:
5160 statement("template<typename T>");
5161 statement("inline T spvSubgroupShuffleXor(T value, ushort mask)");
5162 begin_scope();
5163 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5164 statement("return quad_shuffle_xor(value, mask);");
5165 else
5166 statement("return simd_shuffle_xor(value, mask);");
5167 end_scope();
5168 statement("");
5169 statement("template<>");
5170 statement("inline bool spvSubgroupShuffleXor(bool value, ushort mask)");
5171 begin_scope();
5172 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5173 statement("return !!quad_shuffle_xor((ushort)value, mask);");
5174 else
5175 statement("return !!simd_shuffle_xor((ushort)value, mask);");
5176 end_scope();
5177 statement("");
5178 statement("template<uint N>");
5179 statement("inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)");
5180 begin_scope();
5181 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5182 statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, mask);");
5183 else
5184 statement("return (vec<bool, N>)simd_shuffle_xor((vec<ushort, N>)value, mask);");
5185 end_scope();
5186 statement("");
5187 break;
5188
5189 case SPVFuncImplSubgroupShuffleUp:
5190 statement("template<typename T>");
5191 statement("inline T spvSubgroupShuffleUp(T value, ushort delta)");
5192 begin_scope();
5193 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5194 statement("return quad_shuffle_up(value, delta);");
5195 else
5196 statement("return simd_shuffle_up(value, delta);");
5197 end_scope();
5198 statement("");
5199 statement("template<>");
5200 statement("inline bool spvSubgroupShuffleUp(bool value, ushort delta)");
5201 begin_scope();
5202 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5203 statement("return !!quad_shuffle_up((ushort)value, delta);");
5204 else
5205 statement("return !!simd_shuffle_up((ushort)value, delta);");
5206 end_scope();
5207 statement("");
5208 statement("template<uint N>");
5209 statement("inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)");
5210 begin_scope();
5211 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5212 statement("return (vec<bool, N>)quad_shuffle_up((vec<ushort, N>)value, delta);");
5213 else
5214 statement("return (vec<bool, N>)simd_shuffle_up((vec<ushort, N>)value, delta);");
5215 end_scope();
5216 statement("");
5217 break;
5218
5219 case SPVFuncImplSubgroupShuffleDown:
5220 statement("template<typename T>");
5221 statement("inline T spvSubgroupShuffleDown(T value, ushort delta)");
5222 begin_scope();
5223 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5224 statement("return quad_shuffle_down(value, delta);");
5225 else
5226 statement("return simd_shuffle_down(value, delta);");
5227 end_scope();
5228 statement("");
5229 statement("template<>");
5230 statement("inline bool spvSubgroupShuffleDown(bool value, ushort delta)");
5231 begin_scope();
5232 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5233 statement("return !!quad_shuffle_down((ushort)value, delta);");
5234 else
5235 statement("return !!simd_shuffle_down((ushort)value, delta);");
5236 end_scope();
5237 statement("");
5238 statement("template<uint N>");
5239 statement("inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)");
5240 begin_scope();
5241 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5242 statement("return (vec<bool, N>)quad_shuffle_down((vec<ushort, N>)value, delta);");
5243 else
5244 statement("return (vec<bool, N>)simd_shuffle_down((vec<ushort, N>)value, delta);");
5245 end_scope();
5246 statement("");
5247 break;
5248
5249 case SPVFuncImplQuadBroadcast:
5250 statement("template<typename T>");
5251 statement("inline T spvQuadBroadcast(T value, uint lane)");
5252 begin_scope();
5253 statement("return quad_broadcast(value, lane);");
5254 end_scope();
5255 statement("");
5256 statement("template<>");
5257 statement("inline bool spvQuadBroadcast(bool value, uint lane)");
5258 begin_scope();
5259 statement("return !!quad_broadcast((ushort)value, lane);");
5260 end_scope();
5261 statement("");
5262 statement("template<uint N>");
5263 statement("inline vec<bool, N> spvQuadBroadcast(vec<bool, N> value, uint lane)");
5264 begin_scope();
5265 statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
5266 end_scope();
5267 statement("");
5268 break;
5269
5270 case SPVFuncImplQuadSwap:
5271 // We can implement this easily based on the following table giving
5272 // the target lane ID from the direction and current lane ID:
5273 //        Direction
5274 //      | 0 | 1 | 2 |
5275 //   ---+---+---+---+
5276 //  L 0 | 1   2   3
5277 //  a 1 | 0   3   2
5278 //  n 2 | 3   0   1
5279 //  e 3 | 2   1   0
5280 // Notice that target = source ^ (direction + 1).
5281 statement("template<typename T>");
5282 statement("inline T spvQuadSwap(T value, uint dir)");
5283 begin_scope();
5284 statement("return quad_shuffle_xor(value, dir + 1);");
5285 end_scope();
5286 statement("");
5287 statement("template<>");
5288 statement("inline bool spvQuadSwap(bool value, uint dir)");
5289 begin_scope();
5290 statement("return !!quad_shuffle_xor((ushort)value, dir + 1);");
5291 end_scope();
5292 statement("");
5293 statement("template<uint N>");
5294 statement("inline vec<bool, N> spvQuadSwap(vec<bool, N> value, uint dir)");
5295 begin_scope();
5296 statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, dir + 1);");
5297 end_scope();
5298 statement("");
5299 break;
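// Worked example: lane 2 with direction 1 (a vertical swap) targets lane
// 2 ^ (1 + 1) = 0, matching row "n 2", column 1 of the table above.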
5300
5301 case SPVFuncImplReflectScalar:
5302 // Metal does not support scalar versions of these functions.
5303 statement("template<typename T>");
5304 statement("inline T spvReflect(T i, T n)");
5305 begin_scope();
5306 statement("return i - T(2) * i * n * n;");
5307 end_scope();
5308 statement("");
5309 break;
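// This is the scalar specialization of reflect(i, n) = i - 2 * dot(n, i) * n;
// with scalars, dot(n, i) reduces to n * i, giving i - T(2) * i * n * n.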
5310
5311 case SPVFuncImplRefractScalar:
5312 // Metal does not support scalar versions of these functions.
5313 statement("template<typename T>");
5314 statement("inline T spvRefract(T i, T n, T eta)");
5315 begin_scope();
5316 statement("T NoI = n * i;");
5317 statement("T NoI2 = NoI * NoI;");
5318 statement("T k = T(1) - eta * eta * (T(1) - NoI2);");
5319 statement("if (k < T(0))");
5320 begin_scope();
5321 statement("return T(0);");
5322 end_scope();
5323 statement("else");
5324 begin_scope();
5325 statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
5326 end_scope();
5327 end_scope();
5328 statement("");
5329 break;
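// Scalar form of the GLSL refract() equation: k = 1 - eta^2 * (1 - (n.i)^2);
// k < 0 signals total internal reflection and returns 0, otherwise the result
// is eta * i - (eta * (n.i) + sqrt(k)) * n.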
5330
5331 case SPVFuncImplFaceForwardScalar:
5332 // Metal does not support scalar versions of these functions.
5333 statement("template<typename T>");
5334 statement("inline T spvFaceForward(T n, T i, T nref)");
5335 begin_scope();
5336 statement("return i * nref < T(0) ? n : -n;");
5337 end_scope();
5338 statement("");
5339 break;
5340
5341 case SPVFuncImplChromaReconstructNearest2Plane:
5342 statement("template<typename T, typename... LodOptions>");
5343 statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, sampler "
5344 "samp, float2 coord, LodOptions... options)");
5345 begin_scope();
5346 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5347 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5348 statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
5349 statement("return ycbcr;");
5350 end_scope();
5351 statement("");
5352 break;
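// For two-plane formats (e.g. NV12-style layouts): plane 0 carries luma
// (Y -> ycbcr.g) and plane 1 the interleaved chroma (CbCr -> ycbcr.br), both
// sampled with nearest filtering at the same normalized coordinate.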
5353
5354 case SPVFuncImplChromaReconstructNearest3Plane:
5355 statement("template<typename T, typename... LodOptions>");
5356 statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, "
5357 "texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5358 begin_scope();
5359 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5360 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5361 statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5362 statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5363 statement("return ycbcr;");
5364 end_scope();
5365 statement("");
5366 break;
5367
5368 case SPVFuncImplChromaReconstructLinear422CositedEven2Plane:
5369 statement("template<typename T, typename... LodOptions>");
5370 statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
5371 "plane1, sampler samp, float2 coord, LodOptions... options)");
5372 begin_scope();
5373 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5374 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5375 statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
5376 begin_scope();
5377 statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5378 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).rg);");
5379 end_scope();
5380 statement("else");
5381 begin_scope();
5382 statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
5383 end_scope();
5384 statement("return ycbcr;");
5385 end_scope();
5386 statement("");
5387 break;
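// 4:2:2 cosited-even reconstruction: chroma samples sit on even luma columns,
// so when the lookup lands between chroma texels (fract(coord.x * width) != 0)
// the two horizontal neighbors are averaged with mix(..., 0.5); otherwise the
// chroma sample is taken directly.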
5388
5389 case SPVFuncImplChromaReconstructLinear422CositedEven3Plane:
5390 statement("template<typename T, typename... LodOptions>");
5391 statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
5392 "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5393 begin_scope();
5394 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5395 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5396 statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
5397 begin_scope();
5398 statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5399 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
5400 statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5401 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
5402 end_scope();
5403 statement("else");
5404 begin_scope();
5405 statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5406 statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5407 end_scope();
5408 statement("return ycbcr;");
5409 end_scope();
5410 statement("");
5411 break;
5412
5413 case SPVFuncImplChromaReconstructLinear422Midpoint2Plane:
5414 statement("template<typename T, typename... LodOptions>");
5415 statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
5416 "plane1, sampler samp, float2 coord, LodOptions... options)");
5417 begin_scope();
5418 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5419 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5420 statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
5421 statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5422 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).rg);");
5423 statement("return ycbcr;");
5424 end_scope();
5425 statement("");
5426 break;
5427
5428 case SPVFuncImplChromaReconstructLinear422Midpoint3Plane:
5429 statement("template<typename T, typename... LodOptions>");
5430 statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
5431 "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5432 begin_scope();
5433 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5434 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5435 statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
5436 statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5437 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
5438 statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5439 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
5440 statement("return ycbcr;");
5441 end_scope();
5442 statement("");
5443 break;
5444
5445 case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane:
5446 statement("template<typename T, typename... LodOptions>");
5447 statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
5448 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5449 begin_scope();
5450 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5451 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5452 statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
5453 statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5454 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5455 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5456 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5457 statement("return ycbcr;");
5458 end_scope();
5459 statement("");
5460 break;
5461
5462 case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane:
5463 statement("template<typename T, typename... LodOptions>");
5464 statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
5465 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5466 begin_scope();
5467 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5468 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5469 statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
5470 statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5471 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5472 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5473 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5474 statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5475 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5476 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5477 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5478 statement("return ycbcr;");
5479 end_scope();
5480 statement("");
5481 break;
5482
5483 case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane:
5484 statement("template<typename T, typename... LodOptions>");
5485 statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
5486 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5487 begin_scope();
5488 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5489 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5490 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5491 "0)) * 0.5);");
5492 statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5493 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5494 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5495 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5496 statement("return ycbcr;");
5497 end_scope();
5498 statement("");
5499 break;
5500
5501 case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane:
5502 statement("template<typename T, typename... LodOptions>");
5503 statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
5504 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5505 begin_scope();
5506 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5507 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5508 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5509 "0)) * 0.5);");
5510 statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5511 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5512 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5513 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5514 statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5515 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5516 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5517 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5518 statement("return ycbcr;");
5519 end_scope();
5520 statement("");
5521 break;
5522
5523 case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane:
5524 statement("template<typename T, typename... LodOptions>");
5525 statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
5526 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5527 begin_scope();
5528 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5529 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5530 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
5531 "0.5)) * 0.5);");
5532 statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5533 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5534 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5535 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5536 statement("return ycbcr;");
5537 end_scope();
5538 statement("");
5539 break;
5540
5541 case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane:
5542 statement("template<typename T, typename... LodOptions>");
5543 statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
5544 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5545 begin_scope();
5546 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5547 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5548 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
5549 "0.5)) * 0.5);");
5550 statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5551 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5552 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5553 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5554 statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5555 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5556 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5557 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5558 statement("return ycbcr;");
5559 end_scope();
5560 statement("");
5561 break;
5562
5563 case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane:
5564 statement("template<typename T, typename... LodOptions>");
5565 statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
5566 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5567 begin_scope();
5568 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5569 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5570 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5571 "0.5)) * 0.5);");
5572 statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5573 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5574 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5575 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5576 statement("return ycbcr;");
5577 end_scope();
5578 statement("");
5579 break;
5580
5581 case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane:
5582 statement("template<typename T, typename... LodOptions>");
5583 statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
5584 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5585 begin_scope();
5586 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5587 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5588 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5589 "0.5)) * 0.5);");
5590 statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5591 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5592 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5593 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5594 statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5595 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5596 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5597 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5598 statement("return ycbcr;");
5599 end_scope();
5600 statement("");
5601 break;
5602
5603 case SPVFuncImplExpandITUFullRange:
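		// Added note (not from the original source): for n-bit components, full-range
		// Y'CbCr stores Cb/Cr biased by the normalized midpoint 2^(n-1)/(2^n - 1),
		// e.g. 128.0/255.0 for n == 8. The helper emitted below removes that bias
		// from the two chroma channels.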
5604 statement("template<typename T>");
5605 statement("inline vec<T, 4> spvExpandITUFullRange(vec<T, 4> ycbcr, int n)");
5606 begin_scope();
5607 statement("ycbcr.br -= exp2(T(n-1))/(exp2(T(n))-1);");
5608 statement("return ycbcr;");
5609 end_scope();
5610 statement("");
5611 break;
5612
5613 case SPVFuncImplExpandITUNarrowRange:
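		// Added note: ITU narrow (video) range uses luma codes [16, 235] and chroma
		// codes [16, 240] at 8 bits, both scaled by 2^(n-8) for n-bit components.
		// The helper emitted below inverts that mapping; for luma, with
		// code = g * (2^n - 1):
		//   Y = (code - 16*2^(n-8)) / (219*2^(n-8))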
5614 statement("template<typename T>");
5615 statement("inline vec<T, 4> spvExpandITUNarrowRange(vec<T, 4> ycbcr, int n)");
5616 begin_scope();
5617 statement("ycbcr.g = (ycbcr.g * (exp2(T(n)) - 1) - ldexp(T(16), n - 8))/ldexp(T(219), n - 8);");
5618 statement("ycbcr.br = (ycbcr.br * (exp2(T(n)) - 1) - ldexp(T(128), n - 8))/ldexp(T(224), n - 8);");
5619 statement("return ycbcr;");
5620 end_scope();
5621 statement("");
5622 break;
5623
5624 case SPVFuncImplConvertYCbCrBT709:
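		// Added note: the factor matrix below is column-major and is multiplied by
		// ycbcr.gbr = (Y', Cb, Cr), so it expands to the familiar BT.709 equations:
		//   R' = Y' + 1.5748*Cr
		//   G' = Y' - (0.13397432/0.7152)*Cb - (0.33480248/0.7152)*Cr
		//   B' = Y' + 1.8556*Cb
		// The BT.601 and BT.2020 cases below follow the same pattern with their own factors.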
5625 statement("// cf. Khronos Data Format Specification, section 15.1.1");
5626 statement("constant float3x3 spvBT709Factors = {{1, 1, 1}, {0, -0.13397432/0.7152, 1.8556}, {1.5748, "
5627 "-0.33480248/0.7152, 0}};");
5628 statement("");
5629 statement("template<typename T>");
5630 statement("inline vec<T, 4> spvConvertYCbCrBT709(vec<T, 4> ycbcr)");
5631 begin_scope();
5632 statement("vec<T, 4> rgba;");
5633 statement("rgba.rgb = vec<T, 3>(spvBT709Factors * ycbcr.gbr);");
5634 statement("rgba.a = ycbcr.a;");
5635 statement("return rgba;");
5636 end_scope();
5637 statement("");
5638 break;
5639
5640 case SPVFuncImplConvertYCbCrBT601:
5641 statement("// cf. Khronos Data Format Specification, section 15.1.2");
5642 statement("constant float3x3 spvBT601Factors = {{1, 1, 1}, {0, -0.202008/0.587, 1.772}, {1.402, "
5643 "-0.419198/0.587, 0}};");
5644 statement("");
5645 statement("template<typename T>");
5646 statement("inline vec<T, 4> spvConvertYCbCrBT601(vec<T, 4> ycbcr)");
5647 begin_scope();
5648 statement("vec<T, 4> rgba;");
5649 statement("rgba.rgb = vec<T, 3>(spvBT601Factors * ycbcr.gbr);");
5650 statement("rgba.a = ycbcr.a;");
5651 statement("return rgba;");
5652 end_scope();
5653 statement("");
5654 break;
5655
5656 case SPVFuncImplConvertYCbCrBT2020:
5657 statement("// cf. Khronos Data Format Specification, section 15.1.3");
5658 statement("constant float3x3 spvBT2020Factors = {{1, 1, 1}, {0, -0.11156702/0.6780, 1.8814}, {1.4746, "
5659 "-0.38737742/0.6780, 0}};");
5660 statement("");
5661 statement("template<typename T>");
5662 statement("inline vec<T, 4> spvConvertYCbCrBT2020(vec<T, 4> ycbcr)");
5663 begin_scope();
5664 statement("vec<T, 4> rgba;");
5665 statement("rgba.rgb = vec<T, 3>(spvBT2020Factors * ycbcr.gbr);");
5666 statement("rgba.a = ycbcr.a;");
5667 statement("return rgba;");
5668 end_scope();
5669 statement("");
5670 break;
5671
5672 case SPVFuncImplDynamicImageSampler:
5673 statement("enum class spvFormatResolution");
5674 begin_scope();
5675 statement("_444 = 0,");
5676 statement("_422,");
5677 statement("_420");
5678 end_scope_decl();
5679 statement("");
5680 statement("enum class spvChromaFilter");
5681 begin_scope();
5682 statement("nearest = 0,");
5683 statement("linear");
5684 end_scope_decl();
5685 statement("");
5686 statement("enum class spvXChromaLocation");
5687 begin_scope();
5688 statement("cosited_even = 0,");
5689 statement("midpoint");
5690 end_scope_decl();
5691 statement("");
5692 statement("enum class spvYChromaLocation");
5693 begin_scope();
5694 statement("cosited_even = 0,");
5695 statement("midpoint");
5696 end_scope_decl();
5697 statement("");
5698 statement("enum class spvYCbCrModelConversion");
5699 begin_scope();
5700 statement("rgb_identity = 0,");
5701 statement("ycbcr_identity,");
5702 statement("ycbcr_bt_709,");
5703 statement("ycbcr_bt_601,");
5704 statement("ycbcr_bt_2020");
5705 end_scope_decl();
5706 statement("");
5707 statement("enum class spvYCbCrRange");
5708 begin_scope();
5709 statement("itu_full = 0,");
5710 statement("itu_narrow");
5711 end_scope_decl();
5712 statement("");
5713 statement("struct spvComponentBits");
5714 begin_scope();
5715 statement("constexpr explicit spvComponentBits(int v) thread : value(v) {}");
5716 statement("uchar value : 6;");
5717 end_scope_decl();
5718 statement("// A class corresponding to metal::sampler which holds sampler");
5719 statement("// Y'CbCr conversion info.");
5720 statement("struct spvYCbCrSampler");
5721 begin_scope();
5722 statement("constexpr spvYCbCrSampler() thread : val(build()) {}");
5723 statement("template<typename... Ts>");
5724 statement("constexpr spvYCbCrSampler(Ts... t) thread : val(build(t...)) {}");
5725 statement("constexpr spvYCbCrSampler(const thread spvYCbCrSampler& s) thread = default;");
5726 statement("");
5727 statement("spvFormatResolution get_resolution() const thread");
5728 begin_scope();
5729 statement("return spvFormatResolution((val & resolution_mask) >> resolution_base);");
5730 end_scope();
5731 statement("spvChromaFilter get_chroma_filter() const thread");
5732 begin_scope();
5733 statement("return spvChromaFilter((val & chroma_filter_mask) >> chroma_filter_base);");
5734 end_scope();
5735 statement("spvXChromaLocation get_x_chroma_offset() const thread");
5736 begin_scope();
5737 statement("return spvXChromaLocation((val & x_chroma_off_mask) >> x_chroma_off_base);");
5738 end_scope();
5739 statement("spvYChromaLocation get_y_chroma_offset() const thread");
5740 begin_scope();
5741 statement("return spvYChromaLocation((val & y_chroma_off_mask) >> y_chroma_off_base);");
5742 end_scope();
5743 statement("spvYCbCrModelConversion get_ycbcr_model() const thread");
5744 begin_scope();
5745 statement("return spvYCbCrModelConversion((val & ycbcr_model_mask) >> ycbcr_model_base);");
5746 end_scope();
5747 statement("spvYCbCrRange get_ycbcr_range() const thread");
5748 begin_scope();
5749 statement("return spvYCbCrRange((val & ycbcr_range_mask) >> ycbcr_range_base);");
5750 end_scope();
5751 statement("int get_bpc() const thread { return (val & bpc_mask) >> bpc_base; }");
5752 statement("");
5753 statement("private:");
5754 statement("ushort val;");
5755 statement("");
5756 statement("constexpr static constant ushort resolution_bits = 2;");
5757 statement("constexpr static constant ushort chroma_filter_bits = 2;");
5758 statement("constexpr static constant ushort x_chroma_off_bit = 1;");
5759 statement("constexpr static constant ushort y_chroma_off_bit = 1;");
5760 statement("constexpr static constant ushort ycbcr_model_bits = 3;");
5761 statement("constexpr static constant ushort ycbcr_range_bit = 1;");
5762 statement("constexpr static constant ushort bpc_bits = 6;");
5763 statement("");
5764 statement("constexpr static constant ushort resolution_base = 0;");
5765 statement("constexpr static constant ushort chroma_filter_base = 2;");
5766 statement("constexpr static constant ushort x_chroma_off_base = 4;");
5767 statement("constexpr static constant ushort y_chroma_off_base = 5;");
5768 statement("constexpr static constant ushort ycbcr_model_base = 6;");
5769 statement("constexpr static constant ushort ycbcr_range_base = 9;");
5770 statement("constexpr static constant ushort bpc_base = 10;");
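		// Added note: the base/bit constants above imply this packed layout for 'val'
		// (bit 0 on the right):
		//   [15:10] bpc | [9] ycbcr_range | [8:6] ycbcr_model |
		//   [5] y_chroma_off | [4] x_chroma_off | [3:2] chroma_filter | [1:0] resolution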
5771 statement("");
5772 statement(
5773 "constexpr static constant ushort resolution_mask = ((1 << resolution_bits) - 1) << resolution_base;");
5774 statement("constexpr static constant ushort chroma_filter_mask = ((1 << chroma_filter_bits) - 1) << "
5775 "chroma_filter_base;");
5776 statement("constexpr static constant ushort x_chroma_off_mask = ((1 << x_chroma_off_bit) - 1) << "
5777 "x_chroma_off_base;");
5778 statement("constexpr static constant ushort y_chroma_off_mask = ((1 << y_chroma_off_bit) - 1) << "
5779 "y_chroma_off_base;");
5780 statement("constexpr static constant ushort ycbcr_model_mask = ((1 << ycbcr_model_bits) - 1) << "
5781 "ycbcr_model_base;");
5782 statement("constexpr static constant ushort ycbcr_range_mask = ((1 << ycbcr_range_bit) - 1) << "
5783 "ycbcr_range_base;");
5784 statement("constexpr static constant ushort bpc_mask = ((1 << bpc_bits) - 1) << bpc_base;");
5785 statement("");
5786 statement("static constexpr ushort build()");
5787 begin_scope();
5788 statement("return 0;");
5789 end_scope();
5790 statement("");
5791 statement("template<typename... Ts>");
5792 statement("static constexpr ushort build(spvFormatResolution res, Ts... t)");
5793 begin_scope();
5794 statement("return (ushort(res) << resolution_base) | (build(t...) & ~resolution_mask);");
5795 end_scope();
5796 statement("");
5797 statement("template<typename... Ts>");
5798 statement("static constexpr ushort build(spvChromaFilter filt, Ts... t)");
5799 begin_scope();
5800 statement("return (ushort(filt) << chroma_filter_base) | (build(t...) & ~chroma_filter_mask);");
5801 end_scope();
5802 statement("");
5803 statement("template<typename... Ts>");
5804 statement("static constexpr ushort build(spvXChromaLocation loc, Ts... t)");
5805 begin_scope();
5806 statement("return (ushort(loc) << x_chroma_off_base) | (build(t...) & ~x_chroma_off_mask);");
5807 end_scope();
5808 statement("");
5809 statement("template<typename... Ts>");
5810 statement("static constexpr ushort build(spvYChromaLocation loc, Ts... t)");
5811 begin_scope();
5812 statement("return (ushort(loc) << y_chroma_off_base) | (build(t...) & ~y_chroma_off_mask);");
5813 end_scope();
5814 statement("");
5815 statement("template<typename... Ts>");
5816 statement("static constexpr ushort build(spvYCbCrModelConversion model, Ts... t)");
5817 begin_scope();
5818 statement("return (ushort(model) << ycbcr_model_base) | (build(t...) & ~ycbcr_model_mask);");
5819 end_scope();
5820 statement("");
5821 statement("template<typename... Ts>");
5822 statement("static constexpr ushort build(spvYCbCrRange range, Ts... t)");
5823 begin_scope();
5824 statement("return (ushort(range) << ycbcr_range_base) | (build(t...) & ~ycbcr_range_mask);");
5825 end_scope();
5826 statement("");
5827 statement("template<typename... Ts>");
5828 statement("static constexpr ushort build(spvComponentBits bpc, Ts... t)");
5829 begin_scope();
5830 statement("return (ushort(bpc.value) << bpc_base) | (build(t...) & ~bpc_mask);");
5831 end_scope();
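		// Usage sketch (illustrative): spvYCbCrSampler(spvFormatResolution::_420,
		// spvChromaFilter::linear, spvComponentBits(8)) folds through the build()
		// overloads above into a single constexpr ushort, one field at a time.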
5832 end_scope_decl();
5833 statement("");
5834 statement("// A class which can hold up to three textures and a sampler, including");
5835 statement("// Y'CbCr conversion info, used to pass combined image-samplers");
5836 statement("// dynamically to functions.");
5837 statement("template<typename T>");
5838 statement("struct spvDynamicImageSampler");
5839 begin_scope();
5840 statement("texture2d<T> plane0;");
5841 statement("texture2d<T> plane1;");
5842 statement("texture2d<T> plane2;");
5843 statement("sampler samp;");
5844 statement("spvYCbCrSampler ycbcr_samp;");
5845 statement("uint swizzle = 0;");
5846 statement("");
5847 if (msl_options.swizzle_texture_samples)
5848 {
5849 statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, uint sw) thread :");
5850 statement(" plane0(tex), samp(samp), swizzle(sw) {}");
5851 }
5852 else
5853 {
5854 statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp) thread :");
5855 statement(" plane0(tex), samp(samp) {}");
5856 }
5857 statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, spvYCbCrSampler ycbcr_samp, "
5858 "uint sw) thread :");
5859 statement(" plane0(tex), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
5860 statement("constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1,");
5861 statement(" sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
5862 statement(" plane0(plane0), plane1(plane1), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
5863 statement(
5864 "constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1, texture2d<T> plane2,");
5865 statement(" sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
5866 statement(" plane0(plane0), plane1(plane1), plane2(plane2), samp(samp), ycbcr_samp(ycbcr_samp), "
5867 "swizzle(sw) {}");
5868 statement("");
5869 // XXX This is really hard to follow... I've left comments to make it a bit easier.
5870 statement("template<typename... LodOptions>");
5871 statement("vec<T, 4> do_sample(float2 coord, LodOptions... options) const thread");
5872 begin_scope();
5873 statement("if (!is_null_texture(plane1))");
5874 begin_scope();
5875 statement("if (ycbcr_samp.get_resolution() == spvFormatResolution::_444 ||");
5876 statement(" ycbcr_samp.get_chroma_filter() == spvChromaFilter::nearest)");
5877 begin_scope();
5878 statement("if (!is_null_texture(plane2))");
5879 statement(" return spvChromaReconstructNearest(plane0, plane1, plane2, samp, coord,");
5880 statement(" spvForward<LodOptions>(options)...);");
5881 statement(
5882 "return spvChromaReconstructNearest(plane0, plane1, samp, coord, spvForward<LodOptions>(options)...);");
5883 			end_scope(); // if (resolution == 444 || chroma_filter == nearest)
5884 statement("switch (ycbcr_samp.get_resolution())");
5885 begin_scope();
5886 statement("case spvFormatResolution::_444: break;");
5887 statement("case spvFormatResolution::_422:");
5888 begin_scope();
5889 statement("switch (ycbcr_samp.get_x_chroma_offset())");
5890 begin_scope();
5891 statement("case spvXChromaLocation::cosited_even:");
5892 statement(" if (!is_null_texture(plane2))");
5893 statement(" return spvChromaReconstructLinear422CositedEven(");
5894 statement(" plane0, plane1, plane2, samp,");
5895 statement(" coord, spvForward<LodOptions>(options)...);");
5896 statement(" return spvChromaReconstructLinear422CositedEven(");
5897 statement(" plane0, plane1, samp, coord,");
5898 statement(" spvForward<LodOptions>(options)...);");
5899 statement("case spvXChromaLocation::midpoint:");
5900 statement(" if (!is_null_texture(plane2))");
5901 statement(" return spvChromaReconstructLinear422Midpoint(");
5902 statement(" plane0, plane1, plane2, samp,");
5903 statement(" coord, spvForward<LodOptions>(options)...);");
5904 statement(" return spvChromaReconstructLinear422Midpoint(");
5905 statement(" plane0, plane1, samp, coord,");
5906 statement(" spvForward<LodOptions>(options)...);");
5907 end_scope(); // switch (x_chroma_offset)
5908 end_scope(); // case 422:
5909 statement("case spvFormatResolution::_420:");
5910 begin_scope();
5911 statement("switch (ycbcr_samp.get_x_chroma_offset())");
5912 begin_scope();
5913 statement("case spvXChromaLocation::cosited_even:");
5914 begin_scope();
5915 statement("switch (ycbcr_samp.get_y_chroma_offset())");
5916 begin_scope();
5917 statement("case spvYChromaLocation::cosited_even:");
5918 statement(" if (!is_null_texture(plane2))");
5919 statement(" return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
5920 statement(" plane0, plane1, plane2, samp,");
5921 statement(" coord, spvForward<LodOptions>(options)...);");
5922 statement(" return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
5923 statement(" plane0, plane1, samp, coord,");
5924 statement(" spvForward<LodOptions>(options)...);");
5925 statement("case spvYChromaLocation::midpoint:");
5926 statement(" if (!is_null_texture(plane2))");
5927 statement(" return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
5928 statement(" plane0, plane1, plane2, samp,");
5929 statement(" coord, spvForward<LodOptions>(options)...);");
5930 statement(" return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
5931 statement(" plane0, plane1, samp, coord,");
5932 statement(" spvForward<LodOptions>(options)...);");
5933 end_scope(); // switch (y_chroma_offset)
5934 end_scope(); // case x::cosited_even:
5935 statement("case spvXChromaLocation::midpoint:");
5936 begin_scope();
5937 statement("switch (ycbcr_samp.get_y_chroma_offset())");
5938 begin_scope();
5939 statement("case spvYChromaLocation::cosited_even:");
5940 statement(" if (!is_null_texture(plane2))");
5941 statement(" return spvChromaReconstructLinear420XMidpointYCositedEven(");
5942 statement(" plane0, plane1, plane2, samp,");
5943 statement(" coord, spvForward<LodOptions>(options)...);");
5944 statement(" return spvChromaReconstructLinear420XMidpointYCositedEven(");
5945 statement(" plane0, plane1, samp, coord,");
5946 statement(" spvForward<LodOptions>(options)...);");
5947 statement("case spvYChromaLocation::midpoint:");
5948 statement(" if (!is_null_texture(plane2))");
5949 statement(" return spvChromaReconstructLinear420XMidpointYMidpoint(");
5950 statement(" plane0, plane1, plane2, samp,");
5951 statement(" coord, spvForward<LodOptions>(options)...);");
5952 statement(" return spvChromaReconstructLinear420XMidpointYMidpoint(");
5953 statement(" plane0, plane1, samp, coord,");
5954 statement(" spvForward<LodOptions>(options)...);");
5955 end_scope(); // switch (y_chroma_offset)
5956 end_scope(); // case x::midpoint
5957 end_scope(); // switch (x_chroma_offset)
5958 end_scope(); // case 420:
5959 end_scope(); // switch (resolution)
5960 end_scope(); // if (multiplanar)
5961 statement("return plane0.sample(samp, coord, spvForward<LodOptions>(options)...);");
5962 end_scope(); // do_sample()
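		// Added summary: do_sample() dispatches multi-planar images to a
		// spvChromaReconstruct* helper chosen by (resolution, chroma filter,
		// X/Y chroma siting); single-plane images fall through to plane0.sample().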
5963 statement("template <typename... LodOptions>");
5964 statement("vec<T, 4> sample(float2 coord, LodOptions... options) const thread");
5965 begin_scope();
5966 statement(
5967 "vec<T, 4> s = spvTextureSwizzle(do_sample(coord, spvForward<LodOptions>(options)...), swizzle);");
5968 statement("if (ycbcr_samp.get_ycbcr_model() == spvYCbCrModelConversion::rgb_identity)");
5969 statement(" return s;");
5970 statement("");
5971 statement("switch (ycbcr_samp.get_ycbcr_range())");
5972 begin_scope();
5973 statement("case spvYCbCrRange::itu_full:");
5974 statement(" s = spvExpandITUFullRange(s, ycbcr_samp.get_bpc());");
5975 statement(" break;");
5976 statement("case spvYCbCrRange::itu_narrow:");
5977 statement(" s = spvExpandITUNarrowRange(s, ycbcr_samp.get_bpc());");
5978 statement(" break;");
5979 end_scope();
5980 statement("");
5981 statement("switch (ycbcr_samp.get_ycbcr_model())");
5982 begin_scope();
5983 statement("case spvYCbCrModelConversion::rgb_identity:"); // Silence Clang warning
5984 statement("case spvYCbCrModelConversion::ycbcr_identity:");
5985 statement(" return s;");
5986 statement("case spvYCbCrModelConversion::ycbcr_bt_709:");
5987 statement(" return spvConvertYCbCrBT709(s);");
5988 statement("case spvYCbCrModelConversion::ycbcr_bt_601:");
5989 statement(" return spvConvertYCbCrBT601(s);");
5990 statement("case spvYCbCrModelConversion::ycbcr_bt_2020:");
5991 statement(" return spvConvertYCbCrBT2020(s);");
5992 end_scope();
5993 end_scope();
5994 statement("");
5995 // Sampler Y'CbCr conversion forbids offsets.
5996 statement("vec<T, 4> sample(float2 coord, int2 offset) const thread");
5997 begin_scope();
5998 if (msl_options.swizzle_texture_samples)
5999 statement("return spvTextureSwizzle(plane0.sample(samp, coord, offset), swizzle);");
6000 else
6001 statement("return plane0.sample(samp, coord, offset);");
6002 end_scope();
6003 statement("template<typename lod_options>");
6004 statement("vec<T, 4> sample(float2 coord, lod_options options, int2 offset) const thread");
6005 begin_scope();
6006 if (msl_options.swizzle_texture_samples)
6007 statement("return spvTextureSwizzle(plane0.sample(samp, coord, options, offset), swizzle);");
6008 else
6009 statement("return plane0.sample(samp, coord, options, offset);");
6010 end_scope();
6011 statement("#if __HAVE_MIN_LOD_CLAMP__");
6012 statement("vec<T, 4> sample(float2 coord, bias b, min_lod_clamp min_lod, int2 offset) const thread");
6013 begin_scope();
6014 statement("return plane0.sample(samp, coord, b, min_lod, offset);");
6015 end_scope();
6016 statement(
6017 "vec<T, 4> sample(float2 coord, gradient2d grad, min_lod_clamp min_lod, int2 offset) const thread");
6018 begin_scope();
6019 statement("return plane0.sample(samp, coord, grad, min_lod, offset);");
6020 end_scope();
6021 statement("#endif");
6022 statement("");
6023 // Y'CbCr conversion forbids all operations but sampling.
6024 statement("vec<T, 4> read(uint2 coord, uint lod = 0) const thread");
6025 begin_scope();
6026 statement("return plane0.read(coord, lod);");
6027 end_scope();
6028 statement("");
6029 statement("vec<T, 4> gather(float2 coord, int2 offset = int2(0), component c = component::x) const thread");
6030 begin_scope();
6031 if (msl_options.swizzle_texture_samples)
6032 statement("return spvGatherSwizzle(plane0, samp, swizzle, c, coord, offset);");
6033 else
6034 statement("return plane0.gather(samp, coord, offset, c);");
6035 end_scope();
6036 end_scope_decl();
6037 statement("");
6038 		break;

6039 default:
6040 break;
6041 }
6042 }
6043 }
6044
6045 // Undefined global memory is not allowed in MSL.
6046 // Declare constant and init to zeros. Use {}, as global constructors can break Metal.
6047 void CompilerMSL::declare_undefined_values()
6048 {
6049 bool emitted = false;
6050 ir.for_each_typed_id<SPIRUndef>([&](uint32_t, SPIRUndef &undef) {
6051 auto &type = this->get<SPIRType>(undef.basetype);
6052 // OpUndef can be void for some reason ...
6053 if (type.basetype == SPIRType::Void)
6054 return;
6055
6056 statement("constant ", variable_decl(type, to_name(undef.self), undef.self), " = {};");
6057 emitted = true;
6058 });
6059
6060 if (emitted)
6061 statement("");
6062 }
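// For example (illustrative, name hypothetical), an OpUndef of type float4 named _37
// is emitted as:
//   constant float4 _37 = {};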
6063
6064 void CompilerMSL::declare_constant_arrays()
6065 {
6066 bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
6067
6068 	// MSL cannot declare arrays inline (except when declaring a variable), so we must hoist them out to
6069 	// global constants, which lets us use the constants as variable expressions.
6070 bool emitted = false;
6071
6072 ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
6073 if (c.specialization)
6074 return;
6075
6076 auto &type = this->get<SPIRType>(c.constant_type);
6077 		// Constant arrays of non-primitive types (e.g. matrices) won't link properly into Metal libraries.
6078 // FIXME: However, hoisting constants to main() means we need to pass down constant arrays to leaf functions if they are used there.
6079 // If there are multiple functions in the module, drop this case to avoid breaking use cases which do not need to
6080 // link into Metal libraries. This is hacky.
6081 if (!type.array.empty() && (!fully_inlined || is_scalar(type) || is_vector(type)))
6082 {
6083 auto name = to_name(c.self);
6084 statement("constant ", variable_decl(type, name), " = ", constant_expression(c), ";");
6085 emitted = true;
6086 }
6087 });
6088
6089 if (emitted)
6090 statement("");
6091 }
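// Illustrative output (names hypothetical) for a three-element float constant array:
//   constant float _42[3] = { 1.0, 2.0, 3.0 };
// The exact array spelling comes from variable_decl(), so it may differ (e.g. when
// spvUnsafeArray is used to work around native array limitations).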
6092
6093 // Constant arrays of non-primitive types (e.g. matrices) won't link properly into Metal libraries
6094 void CompilerMSL::declare_complex_constant_arrays()
6095 {
6096 // If we do not have a fully inlined module, we did not opt in to
6097 // declaring constant arrays of complex types. See CompilerMSL::declare_constant_arrays().
6098 bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
6099 if (!fully_inlined)
6100 return;
6101
6102 	// MSL cannot declare arrays inline (except when declaring a variable), so we must hoist them out to
6103 	// global constants, which lets us use the constants as variable expressions.
6104 bool emitted = false;
6105
6106 ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
6107 if (c.specialization)
6108 return;
6109
6110 auto &type = this->get<SPIRType>(c.constant_type);
6111 if (!type.array.empty() && !(is_scalar(type) || is_vector(type)))
6112 {
6113 auto name = to_name(c.self);
6114 statement("", variable_decl(type, name), " = ", constant_expression(c), ";");
6115 emitted = true;
6116 }
6117 });
6118
6119 if (emitted)
6120 statement("");
6121 }
6122
6123 void CompilerMSL::emit_resources()
6124 {
6125 declare_constant_arrays();
6126 declare_undefined_values();
6127
6128 // Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created.
6129 emit_interface_block(stage_out_var_id);
6130 emit_interface_block(patch_stage_out_var_id);
6131 emit_interface_block(stage_in_var_id);
6132 emit_interface_block(patch_stage_in_var_id);
6133 }
6134
6135 // Emit declarations for the specialization Metal function constants
6136 void CompilerMSL::emit_specialization_constants_and_structs()
6137 {
6138 SpecializationConstant wg_x, wg_y, wg_z;
6139 ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
6140 bool emitted = false;
6141
6142 unordered_set<uint32_t> declared_structs;
6143 unordered_set<uint32_t> aligned_structs;
6144
6145 // First, we need to deal with scalar block layout.
6146 	// A struct may need to be placed at an alignment which does not match its innate alignment.
6147 	// If that happens for a struct, we must force all elements of that struct to become packed_ types.
6148 // This makes the struct alignment as small as physically possible.
6149 // When we actually align the struct later, we can insert padding as necessary to make the packed members behave like normally aligned types.
6150 ir.for_each_typed_id<SPIRType>([&](uint32_t type_id, const SPIRType &type) {
6151 if (type.basetype == SPIRType::Struct &&
6152 has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
6153 mark_scalar_layout_structs(type);
6154 });
6155
6156 bool builtin_block_type_is_required = false;
6157 	// Very special case. If gl_PerVertex is initialized as an array (tessellation),
6158 	// we may have to emit the gl_PerVertex struct type so that we can emit a constant LUT.
6159 ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
6160 auto &type = this->get<SPIRType>(c.constant_type);
6161 if (is_array(type) && has_decoration(type.self, DecorationBlock) && is_builtin_type(type))
6162 builtin_block_type_is_required = true;
6163 });
6164
6165 // Very particular use of the soft loop lock.
6166 // align_struct may need to create custom types on the fly, but we don't care about
6167 // these types for purpose of iterating over them in ir.ids_for_type and friends.
6168 auto loop_lock = ir.create_loop_soft_lock();
6169
6170 for (auto &id_ : ir.ids_for_constant_or_type)
6171 {
6172 auto &id = ir.ids[id_];
6173
6174 if (id.get_type() == TypeConstant)
6175 {
6176 auto &c = id.get<SPIRConstant>();
6177
6178 if (c.self == workgroup_size_id)
6179 {
6180 // TODO: This can be expressed as a [[threads_per_threadgroup]] input semantic, but we need to know
6181 // the work group size at compile time in SPIR-V, and [[threads_per_threadgroup]] would need to be passed around as a global.
6182 // The work group size may be a specialization constant.
6183 statement("constant uint3 ", builtin_to_glsl(BuiltInWorkgroupSize, StorageClassWorkgroup),
6184 " [[maybe_unused]] = ", constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
6185 emitted = true;
6186 }
6187 else if (c.specialization)
6188 {
6189 auto &type = get<SPIRType>(c.constant_type);
6190 string sc_type_name = type_to_glsl(type);
6191 string sc_name = to_name(c.self);
6192 string sc_tmp_name = sc_name + "_tmp";
6193
6194 // Function constants are only supported in MSL 1.2 and later.
6195 // If we don't support it just declare the "default" directly.
6196 // This "default" value can be overridden to the true specialization constant by the API user.
6197 // Specialization constants which are used as array length expressions cannot be function constants in MSL,
6198 // so just fall back to macros.
6199 if (msl_options.supports_msl_version(1, 2) && has_decoration(c.self, DecorationSpecId) &&
6200 !c.is_used_as_array_length)
6201 {
6202 uint32_t constant_id = get_decoration(c.self, DecorationSpecId);
6203 // Only scalar, non-composite values can be function constants.
6204 statement("constant ", sc_type_name, " ", sc_tmp_name, " [[function_constant(", constant_id,
6205 ")]];");
6206 statement("constant ", sc_type_name, " ", sc_name, " = is_function_constant_defined(", sc_tmp_name,
6207 ") ? ", sc_tmp_name, " : ", constant_expression(c), ";");
6208 }
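				// Illustrative result (names and type hypothetical) for a spec constant
				// 'foo' of type float with SpecId 10:
				//   constant float foo_tmp [[function_constant(10)]];
				//   constant float foo = is_function_constant_defined(foo_tmp) ? foo_tmp : 1.0;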
6209 else if (has_decoration(c.self, DecorationSpecId))
6210 {
6211 // Fallback to macro overrides.
6212 c.specialization_constant_macro_name =
6213 constant_value_macro_name(get_decoration(c.self, DecorationSpecId));
6214
6215 statement("#ifndef ", c.specialization_constant_macro_name);
6216 statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
6217 statement("#endif");
6218 statement("constant ", sc_type_name, " ", sc_name, " = ", c.specialization_constant_macro_name,
6219 ";");
6220 }
6221 else
6222 {
6223 // Composite specialization constants must be built from other specialization constants.
6224 statement("constant ", sc_type_name, " ", sc_name, " = ", constant_expression(c), ";");
6225 }
6226 emitted = true;
6227 }
6228 }
6229 else if (id.get_type() == TypeConstantOp)
6230 {
6231 auto &c = id.get<SPIRConstantOp>();
6232 auto &type = get<SPIRType>(c.basetype);
6233 auto name = to_name(c.self);
6234 statement("constant ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
6235 emitted = true;
6236 }
6237 else if (id.get_type() == TypeType)
6238 {
6239 // Output non-builtin interface structs. These include local function structs
6240 // and structs nested within uniform and read-write buffers.
6241 auto &type = id.get<SPIRType>();
6242 TypeID type_id = type.self;
6243
6244 bool is_struct = (type.basetype == SPIRType::Struct) && type.array.empty() && !type.pointer;
6245 bool is_block =
6246 has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
6247
6248 bool is_builtin_block = is_block && is_builtin_type(type);
6249 bool is_declarable_struct = is_struct && (!is_builtin_block || builtin_block_type_is_required);
6250
6251 // We'll declare this later.
6252 if (stage_out_var_id && get_stage_out_struct_type().self == type_id)
6253 is_declarable_struct = false;
6254 if (patch_stage_out_var_id && get_patch_stage_out_struct_type().self == type_id)
6255 is_declarable_struct = false;
6256 if (stage_in_var_id && get_stage_in_struct_type().self == type_id)
6257 is_declarable_struct = false;
6258 if (patch_stage_in_var_id && get_patch_stage_in_struct_type().self == type_id)
6259 is_declarable_struct = false;
6260
6261 // Align and emit declarable structs...but avoid declaring each more than once.
6262 if (is_declarable_struct && declared_structs.count(type_id) == 0)
6263 {
6264 if (emitted)
6265 statement("");
6266 emitted = false;
6267
6268 declared_structs.insert(type_id);
6269
6270 if (has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
6271 align_struct(type, aligned_structs);
6272
6273 // Make sure we declare the underlying struct type, and not the "decorated" type with pointers, etc.
6274 emit_struct(get<SPIRType>(type_id));
6275 }
6276 }
6277 }
6278
6279 if (emitted)
6280 statement("");
6281 }
6282
6283 void CompilerMSL::emit_binary_unord_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
6284 const char *op)
6285 {
6286 bool forward = should_forward(op0) && should_forward(op1);
6287 emit_op(result_type, result_id,
6288 join("(isunordered(", to_enclosed_unpacked_expression(op0), ", ", to_enclosed_unpacked_expression(op1),
6289 ") || ", to_enclosed_unpacked_expression(op0), " ", op, " ", to_enclosed_unpacked_expression(op1),
6290 ")"),
6291 forward);
6292
6293 inherit_expression_dependencies(result_id, op0);
6294 inherit_expression_dependencies(result_id, op1);
6295 }
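// Illustrative emission for an unordered compare, e.g. OpFUnordGreaterThan a, b
// with op == ">":
//   (isunordered(a, b) || a > b)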
6296
6297 bool CompilerMSL::emit_tessellation_io_load(uint32_t result_type_id, uint32_t id, uint32_t ptr)
6298 {
6299 auto &ptr_type = expression_type(ptr);
6300 auto &result_type = get<SPIRType>(result_type_id);
6301 if (ptr_type.storage != StorageClassInput && ptr_type.storage != StorageClassOutput)
6302 return false;
6303 if (ptr_type.storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationEvaluation)
6304 return false;
6305
6306 bool multi_patch_tess_ctl = get_execution_model() == ExecutionModelTessellationControl &&
6307 msl_options.multi_patch_workgroup && ptr_type.storage == StorageClassInput;
6308 bool flat_matrix = is_matrix(result_type) && ptr_type.storage == StorageClassInput && !multi_patch_tess_ctl;
6309 bool flat_struct = result_type.basetype == SPIRType::Struct && ptr_type.storage == StorageClassInput;
6310 bool flat_data_type = flat_matrix || is_array(result_type) || flat_struct;
6311 if (!flat_data_type)
6312 return false;
6313
6314 if (has_decoration(ptr, DecorationPatch))
6315 return false;
6316
6317 // Now, we must unflatten a composite type and take care of interleaving array access with gl_in/gl_out.
6318 	// Lots of painful code duplication here, since we *really* should not unroll these kinds of loads in entry
6319 	// point fixup unless we're forced to because the code emits suboptimal OpLoads.
6320 string expr;
6321
6322 uint32_t interface_index = get_extended_decoration(ptr, SPIRVCrossDecorationInterfaceMemberIndex);
6323 auto *var = maybe_get_backing_variable(ptr);
6324 bool ptr_is_io_variable = ir.ids[ptr].get_type() == TypeVariable;
6325 auto &expr_type = get_pointee_type(ptr_type.self);
6326
6327 const auto &iface_type = expression_type(stage_in_ptr_var_id);
6328
6329 if (result_type.array.size() > 2)
6330 {
6331 SPIRV_CROSS_THROW("Cannot load tessellation IO variables with more than 2 dimensions.");
6332 }
6333 else if (result_type.array.size() == 2)
6334 {
6335 if (!ptr_is_io_variable)
6336 			SPIRV_CROSS_THROW("An array-of-array must be loaded directly from an IO variable.");
6337 if (interface_index == uint32_t(-1))
6338 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6339 if (result_type.basetype == SPIRType::Struct || flat_matrix)
6340 SPIRV_CROSS_THROW("Cannot load array-of-array of composite type in tessellation IO.");
6341
6342 expr += type_to_glsl(result_type) + "({ ";
6343 uint32_t num_control_points = to_array_size_literal(result_type, 1);
6344 uint32_t base_interface_index = interface_index;
6345
6346 auto &sub_type = get<SPIRType>(result_type.parent_type);
6347
6348 for (uint32_t i = 0; i < num_control_points; i++)
6349 {
6350 expr += type_to_glsl(sub_type) + "({ ";
6351 interface_index = base_interface_index;
6352 uint32_t array_size = to_array_size_literal(result_type, 0);
6353 if (multi_patch_tess_ctl)
6354 {
6355 for (uint32_t j = 0; j < array_size; j++)
6356 {
6357 const uint32_t indices[3] = { i, interface_index, j };
6358
6359 AccessChainMeta meta;
6360 expr +=
6361 access_chain_internal(stage_in_ptr_var_id, indices, 3,
6362 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6363 // If the expression has more vector components than the result type, insert
6364 // a swizzle. This shouldn't happen normally on valid SPIR-V, but it might
6365 // happen if we replace the type of an input variable.
6366 if (!is_matrix(sub_type) && sub_type.basetype != SPIRType::Struct &&
6367 expr_type.vecsize > sub_type.vecsize)
6368 expr += vector_swizzle(sub_type.vecsize, 0);
6369
6370 if (j + 1 < array_size)
6371 expr += ", ";
6372 }
6373 }
6374 else
6375 {
6376 for (uint32_t j = 0; j < array_size; j++, interface_index++)
6377 {
6378 const uint32_t indices[2] = { i, interface_index };
6379
6380 AccessChainMeta meta;
6381 expr +=
6382 access_chain_internal(stage_in_ptr_var_id, indices, 2,
6383 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6384 if (!is_matrix(sub_type) && sub_type.basetype != SPIRType::Struct &&
6385 expr_type.vecsize > sub_type.vecsize)
6386 expr += vector_swizzle(sub_type.vecsize, 0);
6387
6388 if (j + 1 < array_size)
6389 expr += ", ";
6390 }
6391 }
6392 expr += " })";
6393 if (i + 1 < num_control_points)
6394 expr += ", ";
6395 }
6396 expr += " })";
6397 }
6398 else if (flat_struct)
6399 {
6400 bool is_array_of_struct = is_array(result_type);
6401 if (is_array_of_struct && !ptr_is_io_variable)
6402 			SPIRV_CROSS_THROW("Loading an array of structs must come directly from an IO variable.");
6403
6404 uint32_t num_control_points = 1;
6405 if (is_array_of_struct)
6406 {
6407 num_control_points = to_array_size_literal(result_type, 0);
6408 expr += type_to_glsl(result_type) + "({ ";
6409 }
6410
6411 auto &struct_type = is_array_of_struct ? get<SPIRType>(result_type.parent_type) : result_type;
6412 assert(struct_type.array.empty());
6413
6414 for (uint32_t i = 0; i < num_control_points; i++)
6415 {
6416 expr += type_to_glsl(struct_type) + "{ ";
6417 for (uint32_t j = 0; j < uint32_t(struct_type.member_types.size()); j++)
6418 {
6419 // The base interface index is stored per variable for structs.
6420 if (var)
6421 {
6422 interface_index =
6423 get_extended_member_decoration(var->self, j, SPIRVCrossDecorationInterfaceMemberIndex);
6424 }
6425
6426 if (interface_index == uint32_t(-1))
6427 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6428
6429 const auto &mbr_type = get<SPIRType>(struct_type.member_types[j]);
6430 const auto &expr_mbr_type = get<SPIRType>(expr_type.member_types[j]);
6431 if (is_matrix(mbr_type) && ptr_type.storage == StorageClassInput && !multi_patch_tess_ctl)
6432 {
6433 expr += type_to_glsl(mbr_type) + "(";
6434 for (uint32_t k = 0; k < mbr_type.columns; k++, interface_index++)
6435 {
6436 if (is_array_of_struct)
6437 {
6438 const uint32_t indices[2] = { i, interface_index };
6439 AccessChainMeta meta;
6440 expr += access_chain_internal(
6441 stage_in_ptr_var_id, indices, 2,
6442 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6443 }
6444 else
6445 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6446 if (expr_mbr_type.vecsize > mbr_type.vecsize)
6447 expr += vector_swizzle(mbr_type.vecsize, 0);
6448
6449 if (k + 1 < mbr_type.columns)
6450 expr += ", ";
6451 }
6452 expr += ")";
6453 }
6454 else if (is_array(mbr_type))
6455 {
6456 expr += type_to_glsl(mbr_type) + "({ ";
6457 uint32_t array_size = to_array_size_literal(mbr_type, 0);
6458 if (multi_patch_tess_ctl)
6459 {
6460 for (uint32_t k = 0; k < array_size; k++)
6461 {
6462 if (is_array_of_struct)
6463 {
6464 const uint32_t indices[3] = { i, interface_index, k };
6465 AccessChainMeta meta;
6466 expr += access_chain_internal(
6467 stage_in_ptr_var_id, indices, 3,
6468 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6469 }
6470 else
6471 expr += join(to_expression(ptr), ".", to_member_name(iface_type, interface_index), "[",
6472 k, "]");
6473 if (expr_mbr_type.vecsize > mbr_type.vecsize)
6474 expr += vector_swizzle(mbr_type.vecsize, 0);
6475
6476 if (k + 1 < array_size)
6477 expr += ", ";
6478 }
6479 }
6480 else
6481 {
6482 for (uint32_t k = 0; k < array_size; k++, interface_index++)
6483 {
6484 if (is_array_of_struct)
6485 {
6486 const uint32_t indices[2] = { i, interface_index };
6487 AccessChainMeta meta;
6488 expr += access_chain_internal(
6489 stage_in_ptr_var_id, indices, 2,
6490 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6491 }
6492 else
6493 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6494 if (expr_mbr_type.vecsize > mbr_type.vecsize)
6495 expr += vector_swizzle(mbr_type.vecsize, 0);
6496
6497 if (k + 1 < array_size)
6498 expr += ", ";
6499 }
6500 }
6501 expr += " })";
6502 }
6503 else
6504 {
6505 if (is_array_of_struct)
6506 {
6507 const uint32_t indices[2] = { i, interface_index };
6508 AccessChainMeta meta;
6509 expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
6510 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT,
6511 &meta);
6512 }
6513 else
6514 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6515 if (expr_mbr_type.vecsize > mbr_type.vecsize)
6516 expr += vector_swizzle(mbr_type.vecsize, 0);
6517 }
6518
6519 if (j + 1 < struct_type.member_types.size())
6520 expr += ", ";
6521 }
6522 expr += " }";
6523 if (i + 1 < num_control_points)
6524 expr += ", ";
6525 }
6526 if (is_array_of_struct)
6527 expr += " })";
6528 }
6529 else if (flat_matrix)
6530 {
6531 bool is_array_of_matrix = is_array(result_type);
6532 if (is_array_of_matrix && !ptr_is_io_variable)
6533 			SPIRV_CROSS_THROW("Loading an array of matrices must come directly from an IO variable.");
6534 if (interface_index == uint32_t(-1))
6535 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6536
6537 if (is_array_of_matrix)
6538 {
6539 // Loading a matrix from each control point.
6540 uint32_t base_interface_index = interface_index;
6541 uint32_t num_control_points = to_array_size_literal(result_type, 0);
6542 expr += type_to_glsl(result_type) + "({ ";
6543
6544 auto &matrix_type = get_variable_element_type(get<SPIRVariable>(ptr));
6545
6546 for (uint32_t i = 0; i < num_control_points; i++)
6547 {
6548 interface_index = base_interface_index;
6549 expr += type_to_glsl(matrix_type) + "(";
6550 for (uint32_t j = 0; j < result_type.columns; j++, interface_index++)
6551 {
6552 const uint32_t indices[2] = { i, interface_index };
6553
6554 AccessChainMeta meta;
6555 expr +=
6556 access_chain_internal(stage_in_ptr_var_id, indices, 2,
6557 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6558 if (expr_type.vecsize > result_type.vecsize)
6559 expr += vector_swizzle(result_type.vecsize, 0);
6560 if (j + 1 < result_type.columns)
6561 expr += ", ";
6562 }
6563 expr += ")";
6564 if (i + 1 < num_control_points)
6565 expr += ", ";
6566 }
6567
6568 expr += " })";
6569 }
6570 else
6571 {
6572 expr += type_to_glsl(result_type) + "(";
6573 for (uint32_t i = 0; i < result_type.columns; i++, interface_index++)
6574 {
6575 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6576 if (expr_type.vecsize > result_type.vecsize)
6577 expr += vector_swizzle(result_type.vecsize, 0);
6578 if (i + 1 < result_type.columns)
6579 expr += ", ";
6580 }
6581 expr += ")";
6582 }
6583 }
6584 else if (ptr_is_io_variable)
6585 {
6586 assert(is_array(result_type));
6587 assert(result_type.array.size() == 1);
6588 if (interface_index == uint32_t(-1))
6589 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6590
6591 // We're loading an array directly from a global variable.
6592 // This means we're loading one member from each control point.
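		// Added note: the generated expression takes the shape (illustrative)
		//   <result type>({ <control point 0 expr>, <control point 1 expr>, ... })
		// with one access-chain expression per control point.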
6593 expr += type_to_glsl(result_type) + "({ ";
6594 uint32_t num_control_points = to_array_size_literal(result_type, 0);
6595
6596 for (uint32_t i = 0; i < num_control_points; i++)
6597 {
6598 const uint32_t indices[2] = { i, interface_index };
6599
6600 AccessChainMeta meta;
6601 expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
6602 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6603 if (expr_type.vecsize > result_type.vecsize)
6604 expr += vector_swizzle(result_type.vecsize, 0);
6605
6606 if (i + 1 < num_control_points)
6607 expr += ", ";
6608 }
6609 expr += " })";
6610 }
6611 else
6612 {
6613 // We're loading an array from a concrete control point.
6614 assert(is_array(result_type));
6615 assert(result_type.array.size() == 1);
6616 if (interface_index == uint32_t(-1))
6617 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6618
6619 expr += type_to_glsl(result_type) + "({ ";
6620 uint32_t array_size = to_array_size_literal(result_type, 0);
6621 for (uint32_t i = 0; i < array_size; i++, interface_index++)
6622 {
6623 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6624 if (expr_type.vecsize > result_type.vecsize)
6625 expr += vector_swizzle(result_type.vecsize, 0);
6626 if (i + 1 < array_size)
6627 expr += ", ";
6628 }
6629 expr += " })";
6630 }
6631
6632 emit_op(result_type_id, id, expr, false);
6633 register_read(id, ptr, false);
6634 return true;
6635 }
6636
6637 bool CompilerMSL::emit_tessellation_access_chain(const uint32_t *ops, uint32_t length)
6638 {
6639 // If this is a per-vertex output, remap it to the I/O array buffer.
6640
6641 // Any object which did not go through IO flattening shenanigans will go there instead.
6642 // We will unflatten on-demand instead as needed, but not all possible cases can be supported, especially with arrays.
6643
6644 auto *var = maybe_get_backing_variable(ops[2]);
6645 bool patch = false;
6646 bool flat_data = false;
6647 bool ptr_is_chain = false;
6648 bool multi_patch = get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup;
6649
6650 if (var)
6651 {
6652 patch = has_decoration(ops[2], DecorationPatch) || is_patch_block(get_variable_data_type(*var));
6653
6654 // Should match strip_array in add_interface_block.
6655 flat_data = var->storage == StorageClassInput ||
6656 (var->storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationControl);
6657
6658 // We might have a chained access chain, where
6659 // we first take the access chain to the control point, and then we chain into a member or something similar.
6660 // In this case, we need to skip gl_in/gl_out remapping.
6661 ptr_is_chain = var->self != ID(ops[2]);
6662 }
6663
6664 BuiltIn bi_type = BuiltIn(get_decoration(ops[2], DecorationBuiltIn));
6665 if (var && flat_data && !patch &&
6666 (!is_builtin_variable(*var) || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
6667 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
6668 get_variable_data_type(*var).basetype == SPIRType::Struct))
6669 {
6670 AccessChainMeta meta;
6671 SmallVector<uint32_t> indices;
6672 uint32_t next_id = ir.increase_bound_by(1);
6673
6674 indices.reserve(length - 3 + 1);
6675
6676 uint32_t first_non_array_index = ptr_is_chain ? 3 : 4;
6677 VariableID stage_var_id = var->storage == StorageClassInput ? stage_in_ptr_var_id : stage_out_ptr_var_id;
6678 VariableID ptr = ptr_is_chain ? VariableID(ops[2]) : stage_var_id;
6679 if (!ptr_is_chain)
6680 {
6681 // Index into gl_in/gl_out with first array index.
6682 indices.push_back(ops[3]);
6683 }
6684
6685 auto &result_ptr_type = get<SPIRType>(ops[0]);
6686
6687 uint32_t const_mbr_id = next_id++;
6688 uint32_t index = get_extended_decoration(var->self, SPIRVCrossDecorationInterfaceMemberIndex);
6689 if (var->storage == StorageClassInput || has_decoration(get_variable_element_type(*var).self, DecorationBlock))
6690 {
6691 uint32_t i = first_non_array_index;
6692 auto *type = &get_variable_element_type(*var);
6693 if (index == uint32_t(-1) && length >= (first_non_array_index + 1))
6694 {
6695 // Maybe this is a struct type in the input class, in which case
6696 // we put it as a decoration on the corresponding member.
6697 index = get_extended_member_decoration(var->self, get_constant(ops[first_non_array_index]).scalar(),
6698 SPIRVCrossDecorationInterfaceMemberIndex);
6699 assert(index != uint32_t(-1));
6700 i++;
6701 type = &get<SPIRType>(type->member_types[get_constant(ops[first_non_array_index]).scalar()]);
6702 }
6703
6704 // In this case, we're poking into flattened structures and arrays, so now we have to
6705 // combine the following indices. If we encounter a non-constant index,
6706 // we're hosed.
6707 for (; i < length; ++i)
6708 {
6709 if ((multi_patch || (!is_array(*type) && !is_matrix(*type))) && type->basetype != SPIRType::Struct)
6710 break;
6711
6712 auto *c = maybe_get<SPIRConstant>(ops[i]);
6713 if (!c || c->specialization)
6714 SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable in tessellation. "
6715 "This is currently unsupported.");
6716
6717 // We're in flattened space, so just increment the member index into IO block.
6718 // We can only do this once in the current implementation, so either:
6719 // Struct, Matrix or 1-dimensional array for a control point.
6720 index += c->scalar();
6721
6722 if (type->parent_type)
6723 type = &get<SPIRType>(type->parent_type);
6724 else if (type->basetype == SPIRType::Struct)
6725 type = &get<SPIRType>(type->member_types[c->scalar()]);
6726 }
6727
6728 if ((!multi_patch && (is_matrix(result_ptr_type) || is_array(result_ptr_type))) ||
6729 result_ptr_type.basetype == SPIRType::Struct)
6730 {
6731 // We're not going to emit the actual member name, we let any further OpLoad take care of that.
6732 // Tag the access chain with the member index we're referencing.
6733 set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, index);
6734 }
6735 else
6736 {
6737 // Access the appropriate member of gl_in/gl_out.
6738 set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
6739 indices.push_back(const_mbr_id);
6740
6741 // Append any straggling access chain indices.
6742 if (i < length)
6743 indices.insert(indices.end(), ops + i, ops + length);
6744 }
6745 }
6746 else
6747 {
6748 assert(index != uint32_t(-1));
6749 set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
6750 indices.push_back(const_mbr_id);
6751
6752 indices.insert(indices.end(), ops + 4, ops + length);
6753 }
6754
6755 // We use the pointer to the base of the input/output array here,
6756 // so this is always a pointer chain.
6757 string e;
6758
6759 if (!ptr_is_chain)
6760 {
6761 // This is the start of an access chain, use ptr_chain to index into control point array.
6762 e = access_chain(ptr, indices.data(), uint32_t(indices.size()), result_ptr_type, &meta, true);
6763 }
6764 else
6765 {
6766 // If we're accessing a struct, we need to use member indices which are based on the IO block,
6767 // not actual struct type, so we have to use a split access chain here where
6768 // first path resolves the control point index, i.e. gl_in[index], and second half deals with
6769 // looking up flattened member name.
6770
6771 // However, it is possible that we partially accessed a struct,
6772 // by taking pointer to member inside the control-point array.
6773 // For this case, we fall back to a natural access chain since we have already dealt with remapping struct members.
6774 // One way to check this here is if we have 2 implied read expressions.
6775 // First one is the gl_in/gl_out struct itself, then an index into that array.
6776 // If we have traversed further, we use a normal access chain formulation.
6777 auto *ptr_expr = maybe_get<SPIRExpression>(ptr);
6778 if (ptr_expr && ptr_expr->implied_read_expressions.size() == 2)
6779 {
6780 e = join(to_expression(ptr),
6781 access_chain_internal(stage_var_id, indices.data(), uint32_t(indices.size()),
6782 ACCESS_CHAIN_CHAIN_ONLY_BIT, &meta));
6783 }
6784 else
6785 {
6786 e = access_chain_internal(ptr, indices.data(), uint32_t(indices.size()), 0, &meta);
6787 }
6788 }
6789
6790 // Get the actual type of the object that was accessed. If it's a vector type and we changed it,
6791 // then we'll need to add a swizzle.
6792 // For this, we can't necessarily rely on the type of the base expression, because it might be
6793 // another access chain, and it will therefore already have the "correct" type.
6794 auto *expr_type = &get_variable_data_type(*var);
6795 if (has_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID))
6796 expr_type = &get<SPIRType>(get_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID));
6797 for (uint32_t i = 3; i < length; i++)
6798 {
6799 if (!is_array(*expr_type) && expr_type->basetype == SPIRType::Struct)
6800 expr_type = &get<SPIRType>(expr_type->member_types[get<SPIRConstant>(ops[i]).scalar()]);
6801 else
6802 expr_type = &get<SPIRType>(expr_type->parent_type);
6803 }
6804 if (!is_array(*expr_type) && !is_matrix(*expr_type) && expr_type->basetype != SPIRType::Struct &&
6805 expr_type->vecsize > result_ptr_type.vecsize)
6806 e += vector_swizzle(result_ptr_type.vecsize, 0);
6807
6808 auto &expr = set<SPIRExpression>(ops[1], move(e), ops[0], should_forward(ops[2]));
6809 expr.loaded_from = var->self;
6810 expr.need_transpose = meta.need_transpose;
6811 expr.access_chain = true;
6812
6813 // Mark the result as being packed if necessary.
6814 if (meta.storage_is_packed)
6815 set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypePacked);
6816 if (meta.storage_physical_type != 0)
6817 set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypeID, meta.storage_physical_type);
6818 if (meta.storage_is_invariant)
6819 set_decoration(ops[1], DecorationInvariant);
6820 // Save the type we found in case the result is used in another access chain.
6821 set_extended_decoration(ops[1], SPIRVCrossDecorationTessIOOriginalInputTypeID, expr_type->self);
6822
6823 // If we have some expression dependencies in our access chain, this access chain is technically a forwarded
6824 // temporary which could be subject to invalidation.
6825 	// Need to assume we're forwarded while calling inherit_expression_dependencies.
6826 forwarded_temporaries.insert(ops[1]);
6827 // The access chain itself is never forced to a temporary, but its dependencies might.
6828 suppressed_usage_tracking.insert(ops[1]);
6829
6830 for (uint32_t i = 2; i < length; i++)
6831 {
6832 inherit_expression_dependencies(ops[1], ops[i]);
6833 add_implied_read_expression(expr, ops[i]);
6834 }
6835
6836 // If we have no dependencies after all, i.e., all indices in the access chain are immutable temporaries,
6837 // we're not forwarded after all.
6838 if (expr.expression_dependencies.empty())
6839 forwarded_temporaries.erase(ops[1]);
6840
6841 return true;
6842 }
6843
6844 // If this is the inner tessellation level, and we're tessellating triangles,
6845 // drop the last index. It isn't an array in this case, so we can't have an
6846 // array reference here. We need to make this ID a variable instead of an
6847 // expression so we don't try to dereference it as a variable pointer.
6848 // Don't do this if the index is a constant 1, though. We need to drop stores
6849 // to that one.
6850 auto *m = ir.find_meta(var ? var->self : ID(0));
6851 if (get_execution_model() == ExecutionModelTessellationControl && var && m &&
6852 m->decoration.builtin_type == BuiltInTessLevelInner && get_entry_point().flags.get(ExecutionModeTriangles))
6853 {
6854 auto *c = maybe_get<SPIRConstant>(ops[3]);
6855 if (c && c->scalar() == 1)
6856 return false;
6857 auto &dest_var = set<SPIRVariable>(ops[1], *var);
6858 dest_var.basetype = ops[0];
6859 ir.meta[ops[1]] = ir.meta[ops[2]];
6860 inherit_expression_dependencies(ops[1], ops[2]);
6861 return true;
6862 }
6863
6864 return false;
6865 }
6866
is_out_of_bounds_tessellation_level(uint32_t id_lhs)6867 bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
6868 {
6869 if (!get_entry_point().flags.get(ExecutionModeTriangles))
6870 return false;
6871
6872 // In SPIR-V, TessLevelInner always has two elements and TessLevelOuter always has
6873 // four. This is true even if we are tessellating triangles. This allows clients
6874 // to use a single tessellation control shader with multiple tessellation evaluation
6875 // shaders.
6876 // In Metal, however, only the first element of TessLevelInner and the first three
6877 // of TessLevelOuter are accessible. This stems from how in Metal, the tessellation
6878 // levels must be stored to a dedicated buffer in a particular format that depends
6879 // on the patch type. Therefore, in Triangles mode, any access to the second
6880 // inner level or the fourth outer level must be dropped.
6881 const auto *e = maybe_get<SPIRExpression>(id_lhs);
6882 if (!e || !e->access_chain)
6883 return false;
6884 BuiltIn builtin = BuiltIn(get_decoration(e->loaded_from, DecorationBuiltIn));
6885 if (builtin != BuiltInTessLevelInner && builtin != BuiltInTessLevelOuter)
6886 return false;
6887 auto *c = maybe_get<SPIRConstant>(e->implied_read_expressions[1]);
6888 if (!c)
6889 return false;
6890 return (builtin == BuiltInTessLevelInner && c->scalar() == 1) ||
6891 (builtin == BuiltInTessLevelOuter && c->scalar() == 3);
6892 }
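
// Illustrative sketch (for exposition; not part of the emitted output): under
// ExecutionModeTriangles, stores such as
//     gl_TessLevelInner[1] = x; // second inner level
//     gl_TessLevelOuter[3] = y; // fourth outer level
// hit the constant-index check above and are silently dropped, since Metal's
// triangle tessellation factors have no storage for those elements.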

void CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
                                                         spv::StorageClass storage, bool &is_packed)
{
	// If there is any risk of writes happening with the access chain in question,
	// and there is a risk of concurrent write access to other components,
	// we must cast the access chain to a plain pointer to ensure we only access the exact scalars we expect.
	// The MSL compiler refuses to allow component-level access for any non-packed vector types.
	if (!is_packed && (storage == StorageClassStorageBuffer || storage == StorageClassWorkgroup))
	{
		const char *addr_space = storage == StorageClassWorkgroup ? "threadgroup" : "device";
		expr = join("((", addr_space, " ", type_to_glsl(type), "*)&", enclose_expression(expr), ")");

		// Further indexing should happen with packed rules (array index, not swizzle).
		is_packed = true;
	}
}
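
// Illustrative sketch (assumed names, not emitted verbatim by any particular shader):
// for a device-buffer access chain "buf.v" of type float3, the cast above turns a
// component write like
//     buf.v.y = x;
// into
//     ((device float*)&buf.v)[1] = x;
// so neighbouring components can be written concurrently without the MSL compiler
// widening the access to the whole vector.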

// Sets the interface member index for an access chain to a pull-model interpolant.
void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length)
{
	auto *var = maybe_get_backing_variable(ops[2]);
	if (!var || !pull_model_inputs.count(var->self))
		return;
	// Get the base index.
	uint32_t interface_index;
	auto &var_type = get_variable_data_type(*var);
	auto &result_type = get<SPIRType>(ops[0]);
	auto *type = &var_type;
	if (has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex))
	{
		interface_index = get_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex);
	}
	else
	{
		// Assume an access chain into a struct variable.
		assert(var_type.basetype == SPIRType::Struct);
		auto &c = get<SPIRConstant>(ops[3 + var_type.array.size()]);
		interface_index =
		    get_extended_member_decoration(var->self, c.scalar(), SPIRVCrossDecorationInterfaceMemberIndex);
	}
	// Accumulate indices. We'll have to skip over the one for the struct, if present, because we already
	// accounted for that when getting the base index.
	for (uint32_t i = 3; i < length; ++i)
	{
		if (is_vector(*type) && is_scalar(result_type))
		{
			// We don't want to combine the next index. Instead, we save it
			// so we know to apply a swizzle to the result of the interpolation.
			set_extended_decoration(ops[1], SPIRVCrossDecorationInterpolantComponentExpr, ops[i]);
			break;
		}

		auto *c = maybe_get<SPIRConstant>(ops[i]);
		if (!c || c->specialization)
			SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable using pull-model "
			                  "interpolation. This is currently unsupported.");

		if (type->parent_type)
			type = &get<SPIRType>(type->parent_type);
		else if (type->basetype == SPIRType::Struct)
			type = &get<SPIRType>(type->member_types[c->scalar()]);

		if (!has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex) &&
		    i - 3 == var_type.array.size())
			continue;

		interface_index += c->scalar();
	}
	// Save this to the access chain itself so we can recover it later when calling an interpolation function.
	set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, interface_index);
}
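
// Illustrative flow (a sketch under assumed member names): for a pull-model input
// accessed as "interpolateAtSample(colors[2], s)", the loop above folds the constant
// indices into a single interface member index, which emit_glsl_op() later uses to
// emit something shaped like
//     in.m_colors_2.interpolate_at_sample(s)
// with any saved component index applied afterwards as a swizzle or subscript.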

// Override for MSL-specific syntax instructions
void CompilerMSL::emit_instruction(const Instruction &instruction)
{
#define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
#define MSL_BOP_CAST(op, type) \
	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
#define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
#define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
#define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
#define MSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
#define MSL_BFOP_CAST(op, type) \
	emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
#define MSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
#define MSL_UNORD_BOP(op) emit_binary_unord_op(ops[0], ops[1], ops[2], ops[3], #op)

	auto ops = stream(instruction);
	auto opcode = static_cast<Op>(instruction.op);

	// If we need to do implicit bitcasts, make sure we do it with the correct type.
	uint32_t integer_width = get_integer_width_for_instruction(instruction);
	auto int_type = to_signed_basetype(integer_width);
	auto uint_type = to_unsigned_basetype(integer_width);

	switch (opcode)
	{
	case OpLoad:
	{
		uint32_t id = ops[1];
		uint32_t ptr = ops[2];
		if (is_tessellation_shader())
		{
			if (!emit_tessellation_io_load(ops[0], id, ptr))
				CompilerGLSL::emit_instruction(instruction);
		}
		else
		{
			// Sample mask input for Metal is not an array.
			if (BuiltIn(get_decoration(ptr, DecorationBuiltIn)) == BuiltInSampleMask)
				set_decoration(id, DecorationBuiltIn, BuiltInSampleMask);
			CompilerGLSL::emit_instruction(instruction);
		}
		break;
	}

	// Comparisons
	case OpIEqual:
		MSL_BOP_CAST(==, int_type);
		break;

	case OpLogicalEqual:
	case OpFOrdEqual:
		MSL_BOP(==);
		break;

	case OpINotEqual:
		MSL_BOP_CAST(!=, int_type);
		break;

	case OpLogicalNotEqual:
	case OpFOrdNotEqual:
		MSL_BOP(!=);
		break;

	case OpUGreaterThan:
		MSL_BOP_CAST(>, uint_type);
		break;

	case OpSGreaterThan:
		MSL_BOP_CAST(>, int_type);
		break;

	case OpFOrdGreaterThan:
		MSL_BOP(>);
		break;

	case OpUGreaterThanEqual:
		MSL_BOP_CAST(>=, uint_type);
		break;

	case OpSGreaterThanEqual:
		MSL_BOP_CAST(>=, int_type);
		break;

	case OpFOrdGreaterThanEqual:
		MSL_BOP(>=);
		break;

	case OpULessThan:
		MSL_BOP_CAST(<, uint_type);
		break;

	case OpSLessThan:
		MSL_BOP_CAST(<, int_type);
		break;

	case OpFOrdLessThan:
		MSL_BOP(<);
		break;

	case OpULessThanEqual:
		MSL_BOP_CAST(<=, uint_type);
		break;

	case OpSLessThanEqual:
		MSL_BOP_CAST(<=, int_type);
		break;

	case OpFOrdLessThanEqual:
		MSL_BOP(<=);
		break;

	case OpFUnordEqual:
		MSL_UNORD_BOP(==);
		break;

	case OpFUnordNotEqual:
		MSL_UNORD_BOP(!=);
		break;

	case OpFUnordGreaterThan:
		MSL_UNORD_BOP(>);
		break;

	case OpFUnordGreaterThanEqual:
		MSL_UNORD_BOP(>=);
		break;

	case OpFUnordLessThan:
		MSL_UNORD_BOP(<);
		break;

	case OpFUnordLessThanEqual:
		MSL_UNORD_BOP(<=);
		break;

	// Derivatives
	case OpDPdx:
	case OpDPdxFine:
	case OpDPdxCoarse:
		MSL_UFOP(dfdx);
		register_control_dependent_expression(ops[1]);
		break;

	case OpDPdy:
	case OpDPdyFine:
	case OpDPdyCoarse:
		MSL_UFOP(dfdy);
		register_control_dependent_expression(ops[1]);
		break;

	case OpFwidth:
	case OpFwidthCoarse:
	case OpFwidthFine:
		MSL_UFOP(fwidth);
		register_control_dependent_expression(ops[1]);
		break;

	// Bitfield
	case OpBitFieldInsert:
	{
		emit_bitfield_insert_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], "insert_bits", SPIRType::UInt);
		break;
	}

	case OpBitFieldSExtract:
	{
		emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", int_type, int_type,
		                                SPIRType::UInt, SPIRType::UInt);
		break;
	}

	case OpBitFieldUExtract:
	{
		emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", uint_type, uint_type,
		                                SPIRType::UInt, SPIRType::UInt);
		break;
	}

	case OpBitReverse:
		// BitReverse does not have issues with sign since result type must match input type.
		MSL_UFOP(reverse_bits);
		break;

	case OpBitCount:
	{
		auto basetype = expression_type(ops[2]).basetype;
		emit_unary_func_op_cast(ops[0], ops[1], ops[2], "popcount", basetype, basetype);
		break;
	}

	case OpFRem:
		MSL_BFOP(fmod);
		break;

	case OpFMul:
		if (msl_options.invariant_float_math)
			MSL_BFOP(spvFMul);
		else
			MSL_BOP(*);
		break;

	case OpFAdd:
		if (msl_options.invariant_float_math)
			MSL_BFOP(spvFAdd);
		else
			MSL_BOP(+);
		break;

	// Atomics
	case OpAtomicExchange:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t ptr = ops[2];
		uint32_t mem_sem = ops[4];
		uint32_t val = ops[5];
		emit_atomic_func_op(result_type, id, "atomic_exchange_explicit", mem_sem, mem_sem, false, ptr, val);
		break;
	}

	case OpAtomicCompareExchange:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t ptr = ops[2];
		uint32_t mem_sem_pass = ops[4];
		uint32_t mem_sem_fail = ops[5];
		uint32_t val = ops[6];
		uint32_t comp = ops[7];
		emit_atomic_func_op(result_type, id, "atomic_compare_exchange_weak_explicit", mem_sem_pass, mem_sem_fail, true,
		                    ptr, comp, true, false, val);
		break;
	}

	case OpAtomicCompareExchangeWeak:
		SPIRV_CROSS_THROW("OpAtomicCompareExchangeWeak is only supported in kernel profile.");

	case OpAtomicLoad:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t ptr = ops[2];
		uint32_t mem_sem = ops[4];
		emit_atomic_func_op(result_type, id, "atomic_load_explicit", mem_sem, mem_sem, false, ptr, 0);
		break;
	}

	case OpAtomicStore:
	{
		uint32_t result_type = expression_type(ops[0]).self;
		uint32_t id = ops[0];
		uint32_t ptr = ops[0];
		uint32_t mem_sem = ops[2];
		uint32_t val = ops[3];
		emit_atomic_func_op(result_type, id, "atomic_store_explicit", mem_sem, mem_sem, false, ptr, val);
		break;
	}

#define MSL_AFMO_IMPL(op, valsrc, valconst) \
	do \
	{ \
		uint32_t result_type = ops[0]; \
		uint32_t id = ops[1]; \
		uint32_t ptr = ops[2]; \
		uint32_t mem_sem = ops[4]; \
		uint32_t val = valsrc; \
		emit_atomic_func_op(result_type, id, "atomic_fetch_" #op "_explicit", mem_sem, mem_sem, false, ptr, val, \
		                    false, valconst); \
	} while (false)

#define MSL_AFMO(op) MSL_AFMO_IMPL(op, ops[5], false)
#define MSL_AFMIO(op) MSL_AFMO_IMPL(op, 1, true)
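
// Illustrative expansion (a sketch with assumed operand names): for OpAtomicIAdd,
// MSL_AFMO(add) ultimately emits a call shaped like
//     atomic_fetch_add_explicit((device atomic_uint*)&buf.counter, val, memory_order_relaxed);
// while MSL_AFMIO(add) passes the literal 1 as the operand, which is how
// OpAtomicIIncrement is lowered.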

	case OpAtomicIIncrement:
		MSL_AFMIO(add);
		break;

	case OpAtomicIDecrement:
		MSL_AFMIO(sub);
		break;

	case OpAtomicIAdd:
		MSL_AFMO(add);
		break;

	case OpAtomicISub:
		MSL_AFMO(sub);
		break;

	case OpAtomicSMin:
	case OpAtomicUMin:
		MSL_AFMO(min);
		break;

	case OpAtomicSMax:
	case OpAtomicUMax:
		MSL_AFMO(max);
		break;

	case OpAtomicAnd:
		MSL_AFMO(and);
		break;

	case OpAtomicOr:
		MSL_AFMO(or);
		break;

	case OpAtomicXor:
		MSL_AFMO(xor);
		break;

	// Images

	// Reads == Fetches in Metal
	case OpImageRead:
	{
		// Mark that this shader reads from this image.
		uint32_t img_id = ops[2];
		auto &type = expression_type(img_id);
		if (type.image.dim != DimSubpassData)
		{
			auto *p_var = maybe_get_backing_variable(img_id);
			if (p_var && has_decoration(p_var->self, DecorationNonReadable))
			{
				unset_decoration(p_var->self, DecorationNonReadable);
				force_recompile();
			}
		}

		emit_texture_op(instruction, false);
		break;
	}

	// Emulate texture2D atomic operations
	case OpImageTexelPointer:
	{
		// When using the pointer, we need to know which variable it is actually loaded from.
		auto *var = maybe_get_backing_variable(ops[2]);
		if (var && atomic_image_vars.count(var->self))
		{
			uint32_t result_type = ops[0];
			uint32_t id = ops[1];

			std::string coord = to_expression(ops[3]);
			auto &type = expression_type(ops[2]);
			if (type.image.dim == Dim2D)
			{
				coord = join("spvImage2DAtomicCoord(", coord, ", ", to_expression(ops[2]), ")");
			}

			auto &e = set<SPIRExpression>(id, join(to_expression(ops[2]), "_atomic[", coord, "]"), result_type, true);
			e.loaded_from = var ? var->self : ID(0);
			inherit_expression_dependencies(id, ops[3]);
		}
		else
		{
			uint32_t result_type = ops[0];
			uint32_t id = ops[1];
			auto &e =
			    set<SPIRExpression>(id, join(to_expression(ops[2]), ", ", to_expression(ops[3])), result_type, true);

			// When using the pointer, we need to know which variable it is actually loaded from.
			e.loaded_from = var ? var->self : ID(0);
			inherit_expression_dependencies(id, ops[3]);
		}
		break;
	}

	case OpImageWrite:
	{
		uint32_t img_id = ops[0];
		uint32_t coord_id = ops[1];
		uint32_t texel_id = ops[2];
		const uint32_t *opt = &ops[3];
		uint32_t length = instruction.length - 3;

		// Bypass pointers because we need the real image struct.
		auto &type = expression_type(img_id);
		auto &img_type = get<SPIRType>(type.self);

		// Ensure this image has been marked as being written to, and force a
		// recompile so that the image type output will include write access.
		auto *p_var = maybe_get_backing_variable(img_id);
		if (p_var && has_decoration(p_var->self, DecorationNonWritable))
		{
			unset_decoration(p_var->self, DecorationNonWritable);
			force_recompile();
		}

		bool forward = false;
		uint32_t bias = 0;
		uint32_t lod = 0;
		uint32_t flags = 0;

		if (length)
		{
			flags = *opt++;
			length--;
		}

		auto test = [&](uint32_t &v, uint32_t flag) {
			if (length && (flags & flag))
			{
				v = *opt++;
				length--;
			}
		};

		test(bias, ImageOperandsBiasMask);
		test(lod, ImageOperandsLodMask);

		auto &texel_type = expression_type(texel_id);
		auto store_type = texel_type;
		store_type.vecsize = 4;

		TextureFunctionArguments args = {};
		args.base.img = img_id;
		args.base.imgtype = &img_type;
		args.base.is_fetch = true;
		args.coord = coord_id;
		args.lod = lod;
		statement(join(to_expression(img_id), ".write(",
		               remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)), ", ",
		               CompilerMSL::to_function_args(args, &forward), ");"));

		if (p_var && variable_storage_is_aliased(*p_var))
			flush_all_aliased_variables();

		break;
	}

	case OpImageQuerySize:
	case OpImageQuerySizeLod:
	{
		uint32_t rslt_type_id = ops[0];
		auto &rslt_type = get<SPIRType>(rslt_type_id);

		uint32_t id = ops[1];

		uint32_t img_id = ops[2];
		string img_exp = to_expression(img_id);
		auto &img_type = expression_type(img_id);
		Dim img_dim = img_type.image.dim;
		bool img_is_array = img_type.image.arrayed;

		if (img_type.basetype != SPIRType::Image)
			SPIRV_CROSS_THROW("Invalid type for OpImageQuerySize.");

		string lod;
		if (opcode == OpImageQuerySizeLod)
		{
			// The LOD index defaults to zero, so don't bother outputting a level-zero index.
			string decl_lod = to_expression(ops[3]);
			if (decl_lod != "0")
				lod = decl_lod;
		}

		string expr = type_to_glsl(rslt_type) + "(";
		expr += img_exp + ".get_width(" + lod + ")";

		if (img_dim == Dim2D || img_dim == DimCube || img_dim == Dim3D)
			expr += ", " + img_exp + ".get_height(" + lod + ")";

		if (img_dim == Dim3D)
			expr += ", " + img_exp + ".get_depth(" + lod + ")";

		if (img_is_array)
		{
			expr += ", " + img_exp + ".get_array_size()";
			if (img_dim == DimCube && msl_options.emulate_cube_array)
				expr += " / 6";
		}

		expr += ")";

		emit_op(rslt_type_id, id, expr, should_forward(img_id));

		break;
	}

	case OpImageQueryLod:
	{
		if (!msl_options.supports_msl_version(2, 2))
			SPIRV_CROSS_THROW("ImageQueryLod is only supported on MSL 2.2 and up.");
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t image_id = ops[2];
		uint32_t coord_id = ops[3];
		emit_uninitialized_temporary_expression(result_type, id);

		auto sampler_expr = to_sampler_expression(image_id);
		auto *combined = maybe_get<SPIRCombinedImageSampler>(image_id);
		auto image_expr = combined ? to_expression(combined->image) : to_expression(image_id);

		// TODO: It is unclear if calculate_clamped_lod also conditionally rounds
		// the reported LOD based on the sampler. NEAREST miplevel should
		// round the LOD, but LINEAR miplevel should not round.
		// Let's hope this does not become an issue ...
		statement(to_expression(id), ".x = ", image_expr, ".calculate_clamped_lod(", sampler_expr, ", ",
		          to_expression(coord_id), ");");
		statement(to_expression(id), ".y = ", image_expr, ".calculate_unclamped_lod(", sampler_expr, ", ",
		          to_expression(coord_id), ");");
		register_control_dependent_expression(id);
		break;
	}

#define MSL_ImgQry(qrytype) \
	do \
	{ \
		uint32_t rslt_type_id = ops[0]; \
		auto &rslt_type = get<SPIRType>(rslt_type_id); \
		uint32_t id = ops[1]; \
		uint32_t img_id = ops[2]; \
		string img_exp = to_expression(img_id); \
		string expr = type_to_glsl(rslt_type) + "(" + img_exp + ".get_num_" #qrytype "())"; \
		emit_op(rslt_type_id, id, expr, should_forward(img_id)); \
	} while (false)
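
// Illustrative expansion (a sketch with an assumed texture name): MSL_ImgQry(mip_levels)
// for OpImageQueryLevels emits something shaped like
//     uint(tex.get_num_mip_levels())
// with the constructor taken from the SPIR-V result type.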

	case OpImageQueryLevels:
		MSL_ImgQry(mip_levels);
		break;

	case OpImageQuerySamples:
		MSL_ImgQry(samples);
		break;

	case OpImage:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);

		if (combined)
		{
			auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
			auto *var = maybe_get_backing_variable(combined->image);
			if (var)
				e.loaded_from = var->self;
		}
		else
		{
			auto *var = maybe_get_backing_variable(ops[2]);
			SPIRExpression *e;
			if (var && has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler))
				e = &emit_op(result_type, id, join(to_expression(ops[2]), ".plane0"), true, true);
			else
				e = &emit_op(result_type, id, to_expression(ops[2]), true, true);
			if (var)
				e->loaded_from = var->self;
		}
		break;
	}

	// Casting
	case OpQuantizeToF16:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t arg = ops[2];

		string exp;
		auto &type = get<SPIRType>(result_type);

		switch (type.vecsize)
		{
		case 1:
			exp = join("float(half(", to_expression(arg), "))");
			break;
		case 2:
			exp = join("float2(half2(", to_expression(arg), "))");
			break;
		case 3:
			exp = join("float3(half3(", to_expression(arg), "))");
			break;
		case 4:
			exp = join("float4(half4(", to_expression(arg), "))");
			break;
		default:
			SPIRV_CROSS_THROW("Illegal argument to OpQuantizeToF16.");
		}

		emit_op(result_type, id, exp, should_forward(arg));
		break;
	}

	case OpInBoundsAccessChain:
	case OpAccessChain:
	case OpPtrAccessChain:
		if (is_tessellation_shader())
		{
			if (!emit_tessellation_access_chain(ops, instruction.length))
				CompilerGLSL::emit_instruction(instruction);
		}
		else
			CompilerGLSL::emit_instruction(instruction);
		fix_up_interpolant_access_chain(ops, instruction.length);
		break;

	case OpStore:
		if (is_out_of_bounds_tessellation_level(ops[0]))
			break;

		if (maybe_emit_array_assignment(ops[0], ops[1]))
			break;

		CompilerGLSL::emit_instruction(instruction);
		break;

	// Compute barriers
	case OpMemoryBarrier:
		emit_barrier(0, ops[0], ops[1]);
		break;

	case OpControlBarrier:
		// In GLSL a memory barrier is often followed by a control barrier.
		// But in MSL, memory barriers are also control barriers, so don't
		// emit a simple control barrier if a memory barrier has just been emitted.
		if (previous_instruction_opcode != OpMemoryBarrier)
			emit_barrier(ops[0], ops[1], ops[2]);
		break;

	case OpOuterProduct:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t a = ops[2];
		uint32_t b = ops[3];

		auto &type = get<SPIRType>(result_type);
		string expr = type_to_glsl_constructor(type);
		expr += "(";
		for (uint32_t col = 0; col < type.columns; col++)
		{
			expr += to_enclosed_expression(a);
			expr += " * ";
			expr += to_extract_component_expression(b, col);
			if (col + 1 < type.columns)
				expr += ", ";
		}
		expr += ")";
		emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
		inherit_expression_dependencies(id, a);
		inherit_expression_dependencies(id, b);
		break;
	}

	case OpVectorTimesMatrix:
	case OpMatrixTimesVector:
	{
		if (!msl_options.invariant_float_math)
		{
			CompilerGLSL::emit_instruction(instruction);
			break;
		}

		// If the matrix needs transpose, just flip the multiply order.
		auto *e = maybe_get<SPIRExpression>(ops[opcode == OpMatrixTimesVector ? 2 : 3]);
		if (e && e->need_transpose)
		{
			e->need_transpose = false;
			string expr;

			if (opcode == OpMatrixTimesVector)
			{
				expr = join("spvFMulVectorMatrix(", to_enclosed_unpacked_expression(ops[3]), ", ",
				            to_unpacked_row_major_matrix_expression(ops[2]), ")");
			}
			else
			{
				expr = join("spvFMulMatrixVector(", to_unpacked_row_major_matrix_expression(ops[3]), ", ",
				            to_enclosed_unpacked_expression(ops[2]), ")");
			}

			bool forward = should_forward(ops[2]) && should_forward(ops[3]);
			emit_op(ops[0], ops[1], expr, forward);
			e->need_transpose = true;
			inherit_expression_dependencies(ops[1], ops[2]);
			inherit_expression_dependencies(ops[1], ops[3]);
		}
		else
		{
			if (opcode == OpMatrixTimesVector)
				MSL_BFOP(spvFMulMatrixVector);
			else
				MSL_BFOP(spvFMulVectorMatrix);
		}
		break;
	}

	case OpMatrixTimesMatrix:
	{
		if (!msl_options.invariant_float_math)
		{
			CompilerGLSL::emit_instruction(instruction);
			break;
		}

		auto *a = maybe_get<SPIRExpression>(ops[2]);
		auto *b = maybe_get<SPIRExpression>(ops[3]);

		// If both matrices need transpose, we can multiply in flipped order and tag the expression as transposed.
		// a^T * b^T = (b * a)^T.
		if (a && b && a->need_transpose && b->need_transpose)
		{
			a->need_transpose = false;
			b->need_transpose = false;

			auto expr =
			    join("spvFMulMatrixMatrix(", enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), ", ",
			         enclose_expression(to_unpacked_row_major_matrix_expression(ops[2])), ")");

			bool forward = should_forward(ops[2]) && should_forward(ops[3]);
			auto &e = emit_op(ops[0], ops[1], expr, forward);
			e.need_transpose = true;
			a->need_transpose = true;
			b->need_transpose = true;
			inherit_expression_dependencies(ops[1], ops[2]);
			inherit_expression_dependencies(ops[1], ops[3]);
		}
		else
			MSL_BFOP(spvFMulMatrixMatrix);

		break;
	}

	case OpIAddCarry:
	case OpISubBorrow:
	{
		uint32_t result_type = ops[0];
		uint32_t result_id = ops[1];
		uint32_t op0 = ops[2];
		uint32_t op1 = ops[3];
		auto &type = get<SPIRType>(result_type);
		emit_uninitialized_temporary_expression(result_type, result_id);

		auto &res_type = get<SPIRType>(type.member_types[1]);
		if (opcode == OpIAddCarry)
		{
			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " + ",
			          to_enclosed_expression(op1), ";");
			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
			          "(1), ", type_to_glsl(res_type), "(0), ", to_expression(result_id), ".", to_member_name(type, 0),
			          " >= max(", to_expression(op0), ", ", to_expression(op1), "));");
		}
		else
		{
			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " - ",
			          to_enclosed_expression(op1), ";");
			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
			          "(1), ", type_to_glsl(res_type), "(0), ", to_enclosed_expression(op0),
			          " >= ", to_enclosed_expression(op1), ");");
		}
		break;
	}

	case OpUMulExtended:
	case OpSMulExtended:
	{
		uint32_t result_type = ops[0];
		uint32_t result_id = ops[1];
		uint32_t op0 = ops[2];
		uint32_t op1 = ops[3];
		auto &type = get<SPIRType>(result_type);
		emit_uninitialized_temporary_expression(result_type, result_id);

		statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " * ",
		          to_enclosed_expression(op1), ";");
		statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(", to_expression(op0), ", ",
		          to_expression(op1), ");");
		break;
	}

	case OpArrayLength:
	{
		auto &type = expression_type(ops[2]);
		uint32_t offset = type_struct_member_offset(type, ops[3]);
		uint32_t stride = type_struct_member_array_stride(type, ops[3]);

		auto expr = join("(", to_buffer_size_expression(ops[2]), " - ", offset, ") / ", stride);
		emit_op(ops[0], ops[1], expr, true);
		break;
	}

	// SPV_INTEL_shader_integer_functions2
	case OpUCountLeadingZerosINTEL:
		MSL_UFOP(clz);
		break;

	case OpUCountTrailingZerosINTEL:
		MSL_UFOP(ctz);
		break;

	case OpAbsISubINTEL:
	case OpAbsUSubINTEL:
		MSL_BFOP(absdiff);
		break;

	case OpIAddSatINTEL:
	case OpUAddSatINTEL:
		MSL_BFOP(addsat);
		break;

	case OpIAverageINTEL:
	case OpUAverageINTEL:
		MSL_BFOP(hadd);
		break;

	case OpIAverageRoundedINTEL:
	case OpUAverageRoundedINTEL:
		MSL_BFOP(rhadd);
		break;

	case OpISubSatINTEL:
	case OpUSubSatINTEL:
		MSL_BFOP(subsat);
		break;

	case OpIMul32x16INTEL:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t a = ops[2], b = ops[3];
		bool forward = should_forward(a) && should_forward(b);
		emit_op(result_type, id, join("int(short(", to_expression(a), ")) * int(short(", to_expression(b), "))"),
		        forward);
		inherit_expression_dependencies(id, a);
		inherit_expression_dependencies(id, b);
		break;
	}

	case OpUMul32x16INTEL:
	{
		uint32_t result_type = ops[0];
		uint32_t id = ops[1];
		uint32_t a = ops[2], b = ops[3];
		bool forward = should_forward(a) && should_forward(b);
		emit_op(result_type, id, join("uint(ushort(", to_expression(a), ")) * uint(ushort(", to_expression(b), "))"),
		        forward);
		inherit_expression_dependencies(id, a);
		inherit_expression_dependencies(id, b);
		break;
	}

	// SPV_EXT_demote_to_helper_invocation
	case OpDemoteToHelperInvocationEXT:
		if (!msl_options.supports_msl_version(2, 3))
			SPIRV_CROSS_THROW("discard_fragment() does not formally have demote semantics until MSL 2.3.");
		CompilerGLSL::emit_instruction(instruction);
		break;

	case OpIsHelperInvocationEXT:
		if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
			SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.3 on iOS.");
		else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
			SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.1 on macOS.");
		emit_op(ops[0], ops[1], "simd_is_helper_thread()", false);
		break;

	case OpBeginInvocationInterlockEXT:
	case OpEndInvocationInterlockEXT:
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("Raster order groups require MSL 2.0.");
		break; // Nothing to do in the body

	default:
		CompilerGLSL::emit_instruction(instruction);
		break;
	}

	previous_instruction_opcode = opcode;
}

void CompilerMSL::emit_texture_op(const Instruction &i, bool sparse)
{
	if (sparse)
		SPIRV_CROSS_THROW("Sparse feedback not yet supported in MSL.");

	if (msl_options.use_framebuffer_fetch_subpasses)
	{
		auto *ops = stream(i);

		uint32_t result_type_id = ops[0];
		uint32_t id = ops[1];
		uint32_t img = ops[2];

		auto &type = expression_type(img);
		auto &imgtype = get<SPIRType>(type.self);

		// Use Metal's native framebuffer fetch API for subpass inputs.
		if (imgtype.image.dim == DimSubpassData)
		{
			// Subpass inputs cannot be invalidated,
			// so just forward the expression directly.
			string expr = to_expression(img);
			emit_op(result_type_id, id, expr, true);
			return;
		}
	}

	// Fall back to the default implementation.
	CompilerGLSL::emit_texture_op(i, sparse);
}
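
// Illustrative sketch (an assumption about the surrounding codegen, with a placeholder
// name): with use_framebuffer_fetch_subpasses enabled, a subpass-input read never becomes
// a texture fetch; the entry point receives the attachment as a fragment function
// argument along the lines of
//     float4 subpassInput0 [[color(0)]]
// and the read above forwards that expression directly.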

void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem)
{
	if (get_execution_model() != ExecutionModelGLCompute && get_execution_model() != ExecutionModelTessellationControl)
		return;

	uint32_t exe_scope = id_exe_scope ? evaluate_constant_u32(id_exe_scope) : uint32_t(ScopeInvocation);
	uint32_t mem_scope = id_mem_scope ? evaluate_constant_u32(id_mem_scope) : uint32_t(ScopeInvocation);
	// Use the wider of the two scopes (smaller value).
	exe_scope = min(exe_scope, mem_scope);

	if (msl_options.emulate_subgroups && exe_scope >= ScopeSubgroup && !id_mem_sem)
		// In this case, we assume a "subgroup" size of 1. The barrier, then, is a no-op.
		return;

	string bar_stmt;
	if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
		bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
	else
		bar_stmt = "threadgroup_barrier";
	bar_stmt += "(";

	uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id_mem_sem) : uint32_t(MemorySemanticsMaskNone);

	// Use the | operator to combine flags if we can.
	if (msl_options.supports_msl_version(1, 2))
	{
		string mem_flags = "";
		// For tesc shaders, this also affects objects in the Output storage class.
		// Since in Metal, these are placed in a device buffer, we have to sync device memory here.
		if (get_execution_model() == ExecutionModelTessellationControl ||
		    (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)))
			mem_flags += "mem_flags::mem_device";

		// Fix tessellation patch function processing.
		if (get_execution_model() == ExecutionModelTessellationControl ||
		    (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
		{
			if (!mem_flags.empty())
				mem_flags += " | ";
			mem_flags += "mem_flags::mem_threadgroup";
		}
		if (mem_sem & MemorySemanticsImageMemoryMask)
		{
			if (!mem_flags.empty())
				mem_flags += " | ";
			mem_flags += "mem_flags::mem_texture";
		}

		if (mem_flags.empty())
			mem_flags = "mem_flags::mem_none";

		bar_stmt += mem_flags;
	}
	else
	{
		if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) &&
		    (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
			bar_stmt += "mem_flags::mem_device_and_threadgroup";
		else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))
			bar_stmt += "mem_flags::mem_device";
		else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask))
			bar_stmt += "mem_flags::mem_threadgroup";
		else if (mem_sem & MemorySemanticsImageMemoryMask)
			bar_stmt += "mem_flags::mem_texture";
		else
			bar_stmt += "mem_flags::mem_none";
	}

	bar_stmt += ");";

	statement(bar_stmt);

	assert(current_emitting_block);
	flush_control_dependent_expressions(current_emitting_block->self);
	flush_all_active_variables();
}
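
// Illustrative result (assumption for exposition): an OpControlBarrier with Workgroup
// scope and WorkgroupMemory semantics in a compute shader comes out as
//     threadgroup_barrier(mem_flags::mem_threadgroup);
// whereas subgroup-scoped barriers map to simdgroup_barrier() on MSL 2.0+
// (or iOS MSL 1.2+).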

void CompilerMSL::emit_array_copy(const string &lhs, uint32_t rhs_id, StorageClass lhs_storage,
                                  StorageClass rhs_storage)
{
	// Allow Metal to use the array<T> template to make arrays a value type.
	// This, however, cannot be used for threadgroup address specifiers, so consider the custom array copy as fallback.
	bool lhs_thread = (lhs_storage == StorageClassOutput || lhs_storage == StorageClassFunction ||
	                   lhs_storage == StorageClassGeneric || lhs_storage == StorageClassPrivate);
	bool rhs_thread = (rhs_storage == StorageClassInput || rhs_storage == StorageClassFunction ||
	                   rhs_storage == StorageClassGeneric || rhs_storage == StorageClassPrivate);

	// If threadgroup storage qualifiers are *not* used, avoid the spvCopy* wrapper functions;
	// otherwise, the spvUnsafeArray<> template cannot be used with that storage qualifier.
	if (lhs_thread && rhs_thread && !using_builtin_array())
	{
		statement(lhs, " = ", to_expression(rhs_id), ";");
	}
	else
	{
		// Assignment from an array initializer is fine.
		auto &type = expression_type(rhs_id);
		auto *var = maybe_get_backing_variable(rhs_id);

		// Unfortunately, we cannot template on address space in MSL,
		// so explicit address space redirection it is ...
		bool is_constant = false;
		if (ir.ids[rhs_id].get_type() == TypeConstant)
		{
			is_constant = true;
		}
		else if (var && var->remapped_variable && var->statically_assigned &&
		         ir.ids[var->static_expression].get_type() == TypeConstant)
		{
			is_constant = true;
		}
		else if (rhs_storage == StorageClassUniform)
		{
			is_constant = true;
		}

		// For the case where we have OpLoad triggering an array copy,
		// we cannot easily detect this case ahead of time since it's
		// context dependent. We might have to force a recompile here
		// if this is the only use of array copies in our shader.
		if (type.array.size() > 1)
		{
			if (type.array.size() > kArrayCopyMultidimMax)
				SPIRV_CROSS_THROW("Cannot support this many dimensions for arrays of arrays.");
			auto func = static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + type.array.size());
			add_spv_func_and_recompile(func);
		}
		else
			add_spv_func_and_recompile(SPVFuncImplArrayCopy);

		const char *tag = nullptr;
		if (lhs_thread && is_constant)
			tag = "FromConstantToStack";
		else if (lhs_storage == StorageClassWorkgroup && is_constant)
			tag = "FromConstantToThreadGroup";
		else if (lhs_thread && rhs_thread)
			tag = "FromStackToStack";
		else if (lhs_storage == StorageClassWorkgroup && rhs_thread)
			tag = "FromStackToThreadGroup";
		else if (lhs_thread && rhs_storage == StorageClassWorkgroup)
			tag = "FromThreadGroupToStack";
		else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassWorkgroup)
			tag = "FromThreadGroupToThreadGroup";
		else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassStorageBuffer)
			tag = "FromDeviceToDevice";
		else if (lhs_storage == StorageClassStorageBuffer && is_constant)
			tag = "FromConstantToDevice";
		else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassWorkgroup)
			tag = "FromThreadGroupToDevice";
		else if (lhs_storage == StorageClassStorageBuffer && rhs_thread)
			tag = "FromStackToDevice";
		else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassStorageBuffer)
			tag = "FromDeviceToThreadGroup";
		else if (lhs_thread && rhs_storage == StorageClassStorageBuffer)
			tag = "FromDeviceToStack";
		else
			SPIRV_CROSS_THROW("Unknown storage class used for copying arrays.");

		// Pass the internal array of a spvUnsafeArray<> into the wrapper functions.
		if (lhs_thread && !msl_options.force_native_arrays)
			statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ".elements, ", to_expression(rhs_id), ");");
		else if (rhs_thread && !msl_options.force_native_arrays)
			statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ".elements);");
		else
			statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ");");
	}
}
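
// Illustrative result (a sketch with assumed variable names): copying a thread-local
// array "src" into a threadgroup array "dst" goes through the generated helper, e.g.
//     spvArrayCopyFromStackToThreadGroup1(dst, src.elements);
// where the trailing digit is the number of array dimensions.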

// Since MSL does not allow arrays to be copied via simple variable assignment,
// if the LHS and RHS represent an assignment of an entire array, it must be
// implemented by calling an array copy function.
// Returns whether the array assignment was emitted.
bool CompilerMSL::maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs)
{
	// We only care about assignments of an entire array.
	auto &type = expression_type(id_rhs);
	if (type.array.size() == 0)
		return false;

	auto *var = maybe_get<SPIRVariable>(id_lhs);

	// Is this a remapped, static constant? Don't do anything.
	if (var && var->remapped_variable && var->statically_assigned)
		return true;

	if (ir.ids[id_rhs].get_type() == TypeConstant && var && var->deferred_declaration)
	{
		// Special case: if we end up declaring a variable when assigning the constant array,
		// we can avoid the copy by directly assigning the constant expression.
		// This is likely necessary to be able to use a variable as a true look-up table, as it is unlikely
		// the compiler will be able to optimize the spvArrayCopy() into a constant LUT.
		// After a variable has been declared, we can no longer assign constant arrays in MSL unfortunately.
		statement(to_expression(id_lhs), " = ", constant_expression(get<SPIRConstant>(id_rhs)), ";");
		return true;
	}

	// Ensure the LHS variable has been declared.
	auto *p_v_lhs = maybe_get_backing_variable(id_lhs);
	if (p_v_lhs)
		flush_variable_declaration(p_v_lhs->self);

	emit_array_copy(to_expression(id_lhs), id_rhs, get_expression_effective_storage_class(id_lhs),
	                get_expression_effective_storage_class(id_rhs));
	register_write(id_lhs);

	return true;
}

// Emits one of the atomic functions. In MSL, the atomic functions operate on pointers.
void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, uint32_t mem_order_1,
                                      uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
                                      bool op1_is_pointer, bool op1_is_literal, uint32_t op2)
{
	string exp = string(op) + "(";

	auto &type = get_pointee_type(expression_type(obj));
	exp += "(";
	auto *var = maybe_get_backing_variable(obj);
	if (!var)
		SPIRV_CROSS_THROW("No backing variable for atomic operation.");

	// Emulate texture2D atomic operations.
	const auto &res_type = get<SPIRType>(var->basetype);
	if (res_type.storage == StorageClassUniformConstant && res_type.basetype == SPIRType::Image)
	{
		exp += "device";
	}
	else
	{
		exp += get_argument_address_space(*var);
	}

	exp += " atomic_";
	exp += type_to_glsl(type);
	exp += "*)";

	exp += "&";
	exp += to_enclosed_expression(obj);

	bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;

	if (is_atomic_compare_exchange_strong)
	{
		assert(strcmp(op, "atomic_compare_exchange_weak_explicit") == 0);
		assert(op2);
		assert(has_mem_order_2);
		exp += ", &";
		exp += to_name(result_id);
		exp += ", ";
		exp += to_expression(op2);
		exp += ", ";
		exp += get_memory_order(mem_order_1);
		exp += ", ";
		exp += get_memory_order(mem_order_2);
		exp += ")";

		// MSL only supports the weak atomic compare exchange, so emit a CAS loop here.
		// The MSL function returns false if the atomic write fails OR the comparison test fails,
		// so we must validate that it wasn't the comparison test that failed before continuing
		// the CAS loop, otherwise it will loop infinitely, with the comparison test always failing.
		// The function updates the comparator value from the memory value, so the additional
		// comparison test evaluates the memory value against the expected value.
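		// Illustrative shape of the emitted MSL (names are placeholders):
		//     do
		//     {
		//         _result = expected;
		//     } while (!atomic_compare_exchange_weak_explicit(ptr, &_result, desired,
		//              memory_order_relaxed, memory_order_relaxed) && _result == expected);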
		emit_uninitialized_temporary_expression(result_type, result_id);
		statement("do");
		begin_scope();
		statement(to_name(result_id), " = ", to_expression(op1), ";");
		end_scope_decl(join("while (!", exp, " && ", to_name(result_id), " == ", to_enclosed_expression(op1), ")"));
	}
	else
	{
		assert(strcmp(op, "atomic_compare_exchange_weak_explicit") != 0);
		if (op1)
		{
			if (op1_is_literal)
				exp += join(", ", op1);
			else
				exp += ", " + to_expression(op1);
		}
		if (op2)
			exp += ", " + to_expression(op2);

		exp += string(", ") + get_memory_order(mem_order_1);
		if (has_mem_order_2)
			exp += string(", ") + get_memory_order(mem_order_2);

		exp += ")";

		if (strcmp(op, "atomic_store_explicit") != 0)
			emit_op(result_type, result_id, exp, false);
		else
			statement(exp, ";");
	}

	flush_all_atomic_capable_variables();
}

// Metal only supports relaxed memory order for now.
const char *CompilerMSL::get_memory_order(uint32_t)
{
	return "memory_order_relaxed";
}
8170
8171 // Override for MSL-specific extension syntax instructions
emit_glsl_op(uint32_t result_type,uint32_t id,uint32_t eop,const uint32_t * args,uint32_t count)8172 void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
8173 {
8174 auto op = static_cast<GLSLstd450>(eop);
8175
8176 // If we need to do implicit bitcasts, make sure we do it with the correct type.
8177 uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
8178 auto int_type = to_signed_basetype(integer_width);
8179 auto uint_type = to_unsigned_basetype(integer_width);
8180
8181 switch (op)
8182 {
8183 case GLSLstd450Atan2:
8184 emit_binary_func_op(result_type, id, args[0], args[1], "atan2");
8185 break;
8186 case GLSLstd450InverseSqrt:
8187 emit_unary_func_op(result_type, id, args[0], "rsqrt");
8188 break;
8189 case GLSLstd450RoundEven:
8190 emit_unary_func_op(result_type, id, args[0], "rint");
8191 break;
8192
8193 case GLSLstd450FindILsb:
8194 {
8195 // In this template version of findLSB, we return T.
8196 auto basetype = expression_type(args[0]).basetype;
8197 emit_unary_func_op_cast(result_type, id, args[0], "spvFindLSB", basetype, basetype);
8198 break;
8199 }
8200
8201 case GLSLstd450FindSMsb:
8202 emit_unary_func_op_cast(result_type, id, args[0], "spvFindSMSB", int_type, int_type);
8203 break;
8204
8205 case GLSLstd450FindUMsb:
8206 emit_unary_func_op_cast(result_type, id, args[0], "spvFindUMSB", uint_type, uint_type);
8207 break;
8208
8209 case GLSLstd450PackSnorm4x8:
8210 emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm4x8");
8211 break;
8212 case GLSLstd450PackUnorm4x8:
8213 emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm4x8");
8214 break;
8215 case GLSLstd450PackSnorm2x16:
8216 emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm2x16");
8217 break;
8218 case GLSLstd450PackUnorm2x16:
8219 emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm2x16");
8220 break;
8221
8222 case GLSLstd450PackHalf2x16:
8223 {
8224 auto expr = join("as_type<uint>(half2(", to_expression(args[0]), "))");
8225 emit_op(result_type, id, expr, should_forward(args[0]));
8226 inherit_expression_dependencies(id, args[0]);
8227 break;
8228 }
8229
8230 case GLSLstd450UnpackSnorm4x8:
8231 emit_unary_func_op(result_type, id, args[0], "unpack_snorm4x8_to_float");
8232 break;
8233 case GLSLstd450UnpackUnorm4x8:
8234 emit_unary_func_op(result_type, id, args[0], "unpack_unorm4x8_to_float");
8235 break;
8236 case GLSLstd450UnpackSnorm2x16:
8237 emit_unary_func_op(result_type, id, args[0], "unpack_snorm2x16_to_float");
8238 break;
8239 case GLSLstd450UnpackUnorm2x16:
8240 emit_unary_func_op(result_type, id, args[0], "unpack_unorm2x16_to_float");
8241 break;
8242
8243 case GLSLstd450UnpackHalf2x16:
8244 {
8245 auto expr = join("float2(as_type<half2>(", to_expression(args[0]), "))");
8246 emit_op(result_type, id, expr, should_forward(args[0]));
8247 inherit_expression_dependencies(id, args[0]);
8248 break;
8249 }
8250
8251 case GLSLstd450PackDouble2x32:
8252 emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450PackDouble2x32"); // Currently unsupported
8253 break;
8254 case GLSLstd450UnpackDouble2x32:
8255 emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450UnpackDouble2x32"); // Currently unsupported
8256 break;
8257
8258 case GLSLstd450MatrixInverse:
8259 {
8260 auto &mat_type = get<SPIRType>(result_type);
8261 switch (mat_type.columns)
8262 {
8263 case 2:
8264 emit_unary_func_op(result_type, id, args[0], "spvInverse2x2");
8265 break;
8266 case 3:
8267 emit_unary_func_op(result_type, id, args[0], "spvInverse3x3");
8268 break;
8269 case 4:
8270 emit_unary_func_op(result_type, id, args[0], "spvInverse4x4");
8271 break;
8272 default:
8273 break;
8274 }
8275 break;
8276 }
8277
8278 case GLSLstd450FMin:
8279 // If the result type isn't float, don't bother calling the specific
8280 // precise::/fast:: version. Metal doesn't have those for half and
8281 // double types.
8282 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8283 emit_binary_func_op(result_type, id, args[0], args[1], "min");
8284 else
8285 emit_binary_func_op(result_type, id, args[0], args[1], "fast::min");
8286 break;
8287
8288 case GLSLstd450FMax:
8289 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8290 emit_binary_func_op(result_type, id, args[0], args[1], "max");
8291 else
8292 emit_binary_func_op(result_type, id, args[0], args[1], "fast::max");
8293 break;
8294
8295 case GLSLstd450FClamp:
8296 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
8297 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8298 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
8299 else
8300 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "fast::clamp");
8301 break;
8302
8303 case GLSLstd450NMin:
8304 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8305 emit_binary_func_op(result_type, id, args[0], args[1], "min");
8306 else
8307 emit_binary_func_op(result_type, id, args[0], args[1], "precise::min");
8308 break;
8309
8310 case GLSLstd450NMax:
8311 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8312 emit_binary_func_op(result_type, id, args[0], args[1], "max");
8313 else
8314 emit_binary_func_op(result_type, id, args[0], args[1], "precise::max");
8315 break;
8316
8317 case GLSLstd450NClamp:
8318 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
8319 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8320 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
8321 else
8322 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "precise::clamp");
8323 break;
8324
8325 case GLSLstd450InterpolateAtCentroid:
8326 {
8327 // We can't just emit the expression normally, because the qualified name contains a call to the default
8328 // interpolate method, or refers to a local variable. We saved the interface index we need; use it to construct
8329 // the base for the method call.
8330 uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
8331 string component;
8332 if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
8333 {
8334 uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
8335 auto *c = maybe_get<SPIRConstant>(index_expr);
8336 if (!c || c->specialization)
8337 component = join("[", to_expression(index_expr), "]");
8338 else
8339 component = join(".", index_to_swizzle(c->scalar()));
8340 }
8341 emit_op(result_type, id,
8342 join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
8343 ".interpolate_at_centroid()", component),
8344 should_forward(args[0]));
8345 break;
8346 }
8347
8348 case GLSLstd450InterpolateAtSample:
8349 {
8350 uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
8351 string component;
8352 if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
8353 {
8354 uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
8355 auto *c = maybe_get<SPIRConstant>(index_expr);
8356 if (!c || c->specialization)
8357 component = join("[", to_expression(index_expr), "]");
8358 else
8359 component = join(".", index_to_swizzle(c->scalar()));
8360 }
8361 emit_op(result_type, id,
8362 join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
8363 ".interpolate_at_sample(", to_expression(args[1]), ")", component),
8364 should_forward(args[0]) && should_forward(args[1]));
8365 break;
8366 }
8367
8368 case GLSLstd450InterpolateAtOffset:
8369 {
8370 uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
8371 string component;
8372 if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
8373 {
8374 uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
8375 auto *c = maybe_get<SPIRConstant>(index_expr);
8376 if (!c || c->specialization)
8377 component = join("[", to_expression(index_expr), "]");
8378 else
8379 component = join(".", index_to_swizzle(c->scalar()));
8380 }
8381 // Like Direct3D, Metal puts the (0, 0) at the upper-left corner, not the center as SPIR-V and GLSL do.
8382 // Offset the offset by (1/2 - 1/16), or 0.4375, to compensate for this.
8383 // It has to be (1/2 - 1/16) and not 1/2, or several CTS tests subtly break on Intel.
8384 emit_op(result_type, id,
8385 join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
8386 ".interpolate_at_offset(", to_expression(args[1]), " + 0.4375)", component),
8387 should_forward(args[0]) && should_forward(args[1]));
8388 break;
8389 }
8390
8391 case GLSLstd450Distance:
8392 // MSL does not support scalar versions here.
8393 if (expression_type(args[0]).vecsize == 1)
8394 {
8395 // Equivalent to length(a - b) -> abs(a - b).
8396 emit_op(result_type, id,
8397 join("abs(", to_enclosed_unpacked_expression(args[0]), " - ",
8398 to_enclosed_unpacked_expression(args[1]), ")"),
8399 should_forward(args[0]) && should_forward(args[1]));
8400 inherit_expression_dependencies(id, args[0]);
8401 inherit_expression_dependencies(id, args[1]);
8402 }
8403 else
8404 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
8405 break;
8406
8407 case GLSLstd450Length:
8408 // MSL does not support scalar versions here.
8409 if (expression_type(args[0]).vecsize == 1)
8410 {
8411 // Equivalent to abs().
8412 emit_unary_func_op(result_type, id, args[0], "abs");
8413 }
8414 else
8415 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
8416 break;
8417
8418 case GLSLstd450Normalize:
8419 // MSL does not support scalar versions here.
8420 if (expression_type(args[0]).vecsize == 1)
8421 {
8422 // Returns -1 or 1 for valid input, sign() does the job.
8423 emit_unary_func_op(result_type, id, args[0], "sign");
8424 }
8425 else
8426 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
8427 break;
8428
8429 case GLSLstd450Reflect:
8430 if (get<SPIRType>(result_type).vecsize == 1)
8431 emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
8432 else
8433 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
8434 break;
8435
8436 case GLSLstd450Refract:
8437 if (get<SPIRType>(result_type).vecsize == 1)
8438 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
8439 else
8440 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
8441 break;
8442
8443 case GLSLstd450FaceForward:
8444 if (get<SPIRType>(result_type).vecsize == 1)
8445 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvFaceForward");
8446 else
8447 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
8448 break;
8449
8450 case GLSLstd450Modf:
8451 case GLSLstd450Frexp:
8452 {
8453 // Special case. If the variable is a scalar access chain, we cannot use it directly. We have to emit a temporary.
8454 auto *ptr = maybe_get<SPIRExpression>(args[1]);
8455 if (ptr && ptr->access_chain && is_scalar(expression_type(args[1])))
8456 {
8457 register_call_out_argument(args[1]);
8458 forced_temporaries.insert(id);
8459
8460 // Need to create temporaries and copy over to access chain after.
8461 // We cannot directly take the reference of a vector swizzle in MSL, even if it's scalar ...
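// A rough sketch of the emitted MSL for modf(x, v.y), with illustrative names:
//   float _tmp;                  // the uninitialized temporary
//   float _res = modf(x, _tmp);  // call with an addressable out-argument
//   v.y = _tmp;                  // copy back through the access chain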
8462 uint32_t &tmp_id = extra_sub_expressions[id];
8463 if (!tmp_id)
8464 tmp_id = ir.increase_bound_by(1);
8465
8466 uint32_t tmp_type_id = get_pointee_type_id(ptr->expression_type);
8467 emit_uninitialized_temporary_expression(tmp_type_id, tmp_id);
8468 emit_binary_func_op(result_type, id, args[0], tmp_id, eop == GLSLstd450Modf ? "modf" : "frexp");
8469 statement(to_expression(args[1]), " = ", to_expression(tmp_id), ";");
8470 }
8471 else
8472 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
8473 break;
8474 }
8475
8476 default:
8477 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
8478 break;
8479 }
8480 }
8481
8482 void CompilerMSL::emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type, uint32_t id, uint32_t eop,
8483 const uint32_t *args, uint32_t count)
8484 {
8485 enum AMDShaderTrinaryMinMax
8486 {
8487 FMin3AMD = 1,
8488 UMin3AMD = 2,
8489 SMin3AMD = 3,
8490 FMax3AMD = 4,
8491 UMax3AMD = 5,
8492 SMax3AMD = 6,
8493 FMid3AMD = 7,
8494 UMid3AMD = 8,
8495 SMid3AMD = 9
8496 };
8497
8498 if (!msl_options.supports_msl_version(2, 1))
8499 SPIRV_CROSS_THROW("Trinary min/max functions require MSL 2.1.");
8500
8501 auto op = static_cast<AMDShaderTrinaryMinMax>(eop);
8502
8503 switch (op)
8504 {
8505 case FMid3AMD:
8506 case UMid3AMD:
8507 case SMid3AMD:
8508 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "median3");
8509 break;
8510 default:
8511 CompilerGLSL::emit_spv_amd_shader_trinary_minmax_op(result_type, id, eop, args, count);
8512 break;
8513 }
8514 }
8515
8516 // Emit a structure declaration for the specified interface variable.
8517 void CompilerMSL::emit_interface_block(uint32_t ib_var_id)
8518 {
8519 if (ib_var_id)
8520 {
8521 auto &ib_var = get<SPIRVariable>(ib_var_id);
8522 auto &ib_type = get_variable_data_type(ib_var);
8523 assert(ib_type.basetype == SPIRType::Struct && !ib_type.member_types.empty());
8524 emit_struct(ib_type);
8525 }
8526 }
8527
8528 // Emits the declaration signature of the specified function.
8529 // If this is the entry point function, Metal-specific return value and function arguments are added.
8530 void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &)
8531 {
8532 if (func.self != ir.default_entry_point)
8533 add_function_overload(func);
8534
8535 local_variable_names = resource_names;
8536 string decl;
8537
8538 processing_entry_point = func.self == ir.default_entry_point;
8539
8540 // Metal helper functions must be static force-inline, otherwise they will cause problems when linked together in a single Metallib.
8541 if (!processing_entry_point)
8542 statement(force_inline);
8543
8544 auto &type = get<SPIRType>(func.return_type);
8545
8546 if (!type.array.empty() && msl_options.force_native_arrays)
8547 {
8548 // We cannot return native arrays in MSL, so "return" through an out variable.
8549 decl += "void";
8550 }
8551 else
8552 {
8553 decl += func_type_decl(type);
8554 }
8555
8556 decl += " ";
8557 decl += to_name(func.self);
8558 decl += "(";
8559
8560 if (!type.array.empty() && msl_options.force_native_arrays)
8561 {
8562 // Fake array returns by writing to an out array instead.
8563 decl += "thread ";
8564 decl += type_to_glsl(type);
8565 decl += " (&spvReturnValue)";
8566 decl += type_to_array_glsl(type);
8567 if (!func.arguments.empty())
8568 decl += ", ";
8569 }
8570
8571 if (processing_entry_point)
8572 {
8573 if (msl_options.argument_buffers)
8574 decl += entry_point_args_argument_buffer(!func.arguments.empty());
8575 else
8576 decl += entry_point_args_classic(!func.arguments.empty());
8577
8578 // If entry point function has variables that require early declaration,
8579 // ensure they each have an empty initializer, creating one if needed.
8580 // This is done at this late stage because the initialization expression
8581 // is cleared after each compilation pass.
8582 for (auto var_id : vars_needing_early_declaration)
8583 {
8584 auto &ed_var = get<SPIRVariable>(var_id);
8585 ID &initializer = ed_var.initializer;
8586 if (!initializer)
8587 initializer = ir.increase_bound_by(1);
8588
8589 // Do not override proper initializers.
8590 if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression)
8591 set<SPIRExpression>(ed_var.initializer, "{}", ed_var.basetype, true);
8592 }
8593 }
8594
8595 for (auto &arg : func.arguments)
8596 {
8597 uint32_t name_id = arg.id;
8598
8599 auto *var = maybe_get<SPIRVariable>(arg.id);
8600 if (var)
8601 {
8602 // If we need to modify the name of the variable, make sure we modify the original variable.
8603 // Our alias is just a shadow variable.
8604 if (arg.alias_global_variable && var->basevariable)
8605 name_id = var->basevariable;
8606
8607 var->parameter = &arg; // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
8608 }
8609
8610 add_local_variable_name(name_id);
8611
8612 decl += argument_decl(arg);
8613
8614 bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
8615
8616 auto &arg_type = get<SPIRType>(arg.type);
8617 if (arg_type.basetype == SPIRType::SampledImage && !is_dynamic_img_sampler)
8618 {
8619 // Manufacture automatic plane args for multiplanar texture
8620 uint32_t planes = 1;
8621 if (auto *constexpr_sampler = find_constexpr_sampler(name_id))
8622 if (constexpr_sampler->ycbcr_conversion_enable)
8623 planes = constexpr_sampler->planes;
8624 for (uint32_t i = 1; i < planes; i++)
8625 decl += join(", ", argument_decl(arg), plane_name_suffix, i);
8626
8627 // Manufacture automatic sampler arg for SampledImage texture
8628 if (arg_type.image.dim != DimBuffer)
8629 decl += join(", thread const ", sampler_type(arg_type, arg.id), " ", to_sampler_expression(arg.id));
8630 }
8631
8632 // Manufacture automatic swizzle arg.
8633 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(arg_type) &&
8634 !is_dynamic_img_sampler)
8635 {
8636 bool arg_is_array = !arg_type.array.empty();
8637 decl += join(", constant uint", arg_is_array ? "* " : "& ", to_swizzle_expression(arg.id));
8638 }
8639
8640 if (buffers_requiring_array_length.count(name_id))
8641 {
8642 bool arg_is_array = !arg_type.array.empty();
8643 decl += join(", constant uint", arg_is_array ? "* " : "& ", to_buffer_size_expression(name_id));
8644 }
8645
8646 if (&arg != &func.arguments.back())
8647 decl += ", ";
8648 }
8649
8650 decl += ")";
8651 statement(decl);
8652 }
8653
8654 static bool needs_chroma_reconstruction(const MSLConstexprSampler *constexpr_sampler)
8655 {
8656 // For now, only multiplanar images need explicit reconstruction. GBGR and BGRG images
8657 // use implicit reconstruction.
8658 return constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && constexpr_sampler->planes > 1;
8659 }
8660
8661 // Returns the texture sampling function string for the specified image and sampling characteristics.
8662 string CompilerMSL::to_function_name(const TextureFunctionNameArguments &args)
8663 {
8664 VariableID img = args.base.img;
8665 auto &imgtype = *args.base.imgtype;
8666
8667 const MSLConstexprSampler *constexpr_sampler = nullptr;
8668 bool is_dynamic_img_sampler = false;
8669 if (auto *var = maybe_get_backing_variable(img))
8670 {
8671 constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
8672 is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
8673 }
8674
8675 // Special-case gather. We have to alter the component being looked up
8676 // in the swizzle case.
8677 if (msl_options.swizzle_texture_samples && args.base.is_gather && !is_dynamic_img_sampler &&
8678 (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable))
8679 {
8680 add_spv_func_and_recompile(imgtype.image.depth ? SPVFuncImplGatherCompareSwizzle : SPVFuncImplGatherSwizzle);
8681 return imgtype.image.depth ? "spvGatherCompareSwizzle" : "spvGatherSwizzle";
8682 }
8683
8684 auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
8685
8686 // Texture reference
8687 string fname;
8688 if (needs_chroma_reconstruction(constexpr_sampler) && !is_dynamic_img_sampler)
8689 {
8690 if (constexpr_sampler->planes != 2 && constexpr_sampler->planes != 3)
8691 SPIRV_CROSS_THROW("Unhandled number of color image planes!");
8692 // 444 images aren't downsampled, so we don't need to do linear filtering.
8693 if (constexpr_sampler->resolution == MSL_FORMAT_RESOLUTION_444 ||
8694 constexpr_sampler->chroma_filter == MSL_SAMPLER_FILTER_NEAREST)
8695 {
8696 if (constexpr_sampler->planes == 2)
8697 add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest2Plane);
8698 else
8699 add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest3Plane);
8700 fname = "spvChromaReconstructNearest";
8701 }
8702 else // Linear with a downsampled format
8703 {
8704 fname = "spvChromaReconstructLinear";
8705 switch (constexpr_sampler->resolution)
8706 {
8707 case MSL_FORMAT_RESOLUTION_444:
8708 assert(false);
8709 break; // not reached
8710 case MSL_FORMAT_RESOLUTION_422:
8711 switch (constexpr_sampler->x_chroma_offset)
8712 {
8713 case MSL_CHROMA_LOCATION_COSITED_EVEN:
8714 if (constexpr_sampler->planes == 2)
8715 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven2Plane);
8716 else
8717 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven3Plane);
8718 fname += "422CositedEven";
8719 break;
8720 case MSL_CHROMA_LOCATION_MIDPOINT:
8721 if (constexpr_sampler->planes == 2)
8722 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint2Plane);
8723 else
8724 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint3Plane);
8725 fname += "422Midpoint";
8726 break;
8727 default:
8728 SPIRV_CROSS_THROW("Invalid chroma location.");
8729 }
8730 break;
8731 case MSL_FORMAT_RESOLUTION_420:
8732 fname += "420";
8733 switch (constexpr_sampler->x_chroma_offset)
8734 {
8735 case MSL_CHROMA_LOCATION_COSITED_EVEN:
8736 switch (constexpr_sampler->y_chroma_offset)
8737 {
8738 case MSL_CHROMA_LOCATION_COSITED_EVEN:
8739 if (constexpr_sampler->planes == 2)
8740 add_spv_func_and_recompile(
8741 SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane);
8742 else
8743 add_spv_func_and_recompile(
8744 SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane);
8745 fname += "XCositedEvenYCositedEven";
8746 break;
8747 case MSL_CHROMA_LOCATION_MIDPOINT:
8748 if (constexpr_sampler->planes == 2)
8749 add_spv_func_and_recompile(
8750 SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane);
8751 else
8752 add_spv_func_and_recompile(
8753 SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane);
8754 fname += "XCositedEvenYMidpoint";
8755 break;
8756 default:
8757 SPIRV_CROSS_THROW("Invalid Y chroma location.");
8758 }
8759 break;
8760 case MSL_CHROMA_LOCATION_MIDPOINT:
8761 switch (constexpr_sampler->y_chroma_offset)
8762 {
8763 case MSL_CHROMA_LOCATION_COSITED_EVEN:
8764 if (constexpr_sampler->planes == 2)
8765 add_spv_func_and_recompile(
8766 SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane);
8767 else
8768 add_spv_func_and_recompile(
8769 SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane);
8770 fname += "XMidpointYCositedEven";
8771 break;
8772 case MSL_CHROMA_LOCATION_MIDPOINT:
8773 if (constexpr_sampler->planes == 2)
8774 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane);
8775 else
8776 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane);
8777 fname += "XMidpointYMidpoint";
8778 break;
8779 default:
8780 SPIRV_CROSS_THROW("Invalid Y chroma location.");
8781 }
8782 break;
8783 default:
8784 SPIRV_CROSS_THROW("Invalid X chroma location.");
8785 }
8786 break;
8787 default:
8788 SPIRV_CROSS_THROW("Invalid format resolution.");
8789 }
8790 }
8791 }
8792 else
8793 {
8794 fname = to_expression(combined ? combined->image : img) + ".";
8795
8796 // Texture function and sampler
8797 if (args.base.is_fetch)
8798 fname += "read";
8799 else if (args.base.is_gather)
8800 fname += "gather";
8801 else
8802 fname += "sample";
8803
8804 if (args.has_dref)
8805 fname += "_compare";
8806 }
8807
8808 return fname;
8809 }
8810
8811 string CompilerMSL::convert_to_f32(const string &expr, uint32_t components)
8812 {
8813 SPIRType t;
8814 t.basetype = SPIRType::Float;
8815 t.vecsize = components;
8816 t.columns = 1;
8817 return join(type_to_glsl_constructor(t), "(", expr, ")");
8818 }
8819
8820 static inline bool sampling_type_needs_f32_conversion(const SPIRType &type)
8821 {
8822 // Double is not supported to begin with, but it doesn't hurt to check for completeness.
8823 return type.basetype == SPIRType::Half || type.basetype == SPIRType::Double;
8824 }
8825
8826 // Returns the function args for a texture sampling function for the specified image and sampling characteristics.
8827 string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool *p_forward)
8828 {
8829 VariableID img = args.base.img;
8830 auto &imgtype = *args.base.imgtype;
8831 uint32_t lod = args.lod;
8832 uint32_t grad_x = args.grad_x;
8833 uint32_t grad_y = args.grad_y;
8834 uint32_t bias = args.bias;
8835
8836 const MSLConstexprSampler *constexpr_sampler = nullptr;
8837 bool is_dynamic_img_sampler = false;
8838 if (auto *var = maybe_get_backing_variable(img))
8839 {
8840 constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
8841 is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
8842 }
8843
8844 string farg_str;
8845 bool forward = true;
8846
8847 if (!is_dynamic_img_sampler)
8848 {
8849 // Texture reference (for some cases)
8850 if (needs_chroma_reconstruction(constexpr_sampler))
8851 {
8852 // Multiplanar images need two or three textures.
8853 farg_str += to_expression(img);
8854 for (uint32_t i = 1; i < constexpr_sampler->planes; i++)
8855 farg_str += join(", ", to_expression(img), plane_name_suffix, i);
8856 }
8857 else if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
8858 msl_options.swizzle_texture_samples && args.base.is_gather)
8859 {
8860 auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
8861 farg_str += to_expression(combined ? combined->image : img);
8862 }
8863
8864 // Sampler reference
8865 if (!args.base.is_fetch)
8866 {
8867 if (!farg_str.empty())
8868 farg_str += ", ";
8869 farg_str += to_sampler_expression(img);
8870 }
8871
8872 if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
8873 msl_options.swizzle_texture_samples && args.base.is_gather)
8874 {
8875 // Add the swizzle constant from the swizzle buffer.
8876 farg_str += ", " + to_swizzle_expression(img);
8877 used_swizzle_buffer = true;
8878 }
8879
8880 // Swizzled gather puts the component before the other args, to allow template
8881 // deduction to work.
8882 if (args.component && msl_options.swizzle_texture_samples)
8883 {
8884 forward = should_forward(args.component);
8885 farg_str += ", " + to_component_argument(args.component);
8886 }
8887 }
8888
8889 // Texture coordinates
8890 forward = forward && should_forward(args.coord);
8891 auto coord_expr = to_enclosed_expression(args.coord);
8892 auto &coord_type = expression_type(args.coord);
8893 bool coord_is_fp = type_is_floating_point(coord_type);
8894 bool is_cube_fetch = false;
8895
8896 string tex_coords = coord_expr;
8897 uint32_t alt_coord_component = 0;
8898
8899 switch (imgtype.image.dim)
8900 {
8901
8902 case Dim1D:
8903 if (coord_type.vecsize > 1)
8904 tex_coords = enclose_expression(tex_coords) + ".x";
8905
8906 if (args.base.is_fetch)
8907 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
8908 else if (sampling_type_needs_f32_conversion(coord_type))
8909 tex_coords = convert_to_f32(tex_coords, 1);
8910
8911 if (msl_options.texture_1D_as_2D)
8912 {
8913 if (args.base.is_fetch)
8914 tex_coords = "uint2(" + tex_coords + ", 0)";
8915 else
8916 tex_coords = "float2(" + tex_coords + ", 0.5)";
8917 }
8918
8919 alt_coord_component = 1;
8920 break;
8921
8922 case DimBuffer:
8923 if (coord_type.vecsize > 1)
8924 tex_coords = enclose_expression(tex_coords) + ".x";
8925
8926 if (msl_options.texture_buffer_native)
8927 {
8928 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
8929 }
8930 else
8931 {
8932 // Metal texel buffer textures are 2D, so convert 1D coord to 2D.
8933 // Support for Metal 2.1's new texture_buffer type.
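// spvTexelBufferCoord() conceptually maps a linear index i to uint2(i % width, i / width),
// where the width is either the fixed texel_buffer_texture_width or taken from the texture itself.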
8934 if (args.base.is_fetch)
8935 {
8936 if (msl_options.texel_buffer_texture_width > 0)
8937 {
8938 tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
8939 }
8940 else
8941 {
8942 tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ", " +
8943 to_expression(img) + ")";
8944 }
8945 }
8946 }
8947
8948 alt_coord_component = 1;
8949 break;
8950
8951 case DimSubpassData:
8952 // If we're using Metal's native frame-buffer fetch API for subpass inputs,
8953 // this path will not be hit.
8954 tex_coords = "uint2(gl_FragCoord.xy)";
8955 alt_coord_component = 2;
8956 break;
8957
8958 case Dim2D:
8959 if (coord_type.vecsize > 2)
8960 tex_coords = enclose_expression(tex_coords) + ".xy";
8961
8962 if (args.base.is_fetch)
8963 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
8964 else if (sampling_type_needs_f32_conversion(coord_type))
8965 tex_coords = convert_to_f32(tex_coords, 2);
8966
8967 alt_coord_component = 2;
8968 break;
8969
8970 case Dim3D:
8971 if (coord_type.vecsize > 3)
8972 tex_coords = enclose_expression(tex_coords) + ".xyz";
8973
8974 if (args.base.is_fetch)
8975 tex_coords = "uint3(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
8976 else if (sampling_type_needs_f32_conversion(coord_type))
8977 tex_coords = convert_to_f32(tex_coords, 3);
8978
8979 alt_coord_component = 3;
8980 break;
8981
8982 case DimCube:
8983 if (args.base.is_fetch)
8984 {
8985 is_cube_fetch = true;
8986 tex_coords += ".xy";
8987 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
8988 }
8989 else
8990 {
8991 if (coord_type.vecsize > 3)
8992 tex_coords = enclose_expression(tex_coords) + ".xyz";
8993 }
8994
8995 if (sampling_type_needs_f32_conversion(coord_type))
8996 tex_coords = convert_to_f32(tex_coords, 3);
8997
8998 alt_coord_component = 3;
8999 break;
9000
9001 default:
9002 break;
9003 }
9004
9005 if (args.base.is_fetch && (args.offset || args.coffset))
9006 {
9007 uint32_t offset_expr = args.offset ? args.offset : args.coffset;
9008 // Fetch offsets must be applied directly to the coordinate.
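// For example, a texelFetch with offset "off" reads at (coord + off) rather than
// passing the offset as a separate argument.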
9009 forward = forward && should_forward(offset_expr);
9010 auto &type = expression_type(offset_expr);
9011 if (imgtype.image.dim == Dim1D && msl_options.texture_1D_as_2D)
9012 {
9013 if (type.basetype != SPIRType::UInt)
9014 tex_coords += join(" + uint2(", bitcast_expression(SPIRType::UInt, offset_expr), ", 0)");
9015 else
9016 tex_coords += join(" + uint2(", to_enclosed_expression(offset_expr), ", 0)");
9017 }
9018 else
9019 {
9020 if (type.basetype != SPIRType::UInt)
9021 tex_coords += " + " + bitcast_expression(SPIRType::UInt, offset_expr);
9022 else
9023 tex_coords += " + " + to_enclosed_expression(offset_expr);
9024 }
9025 }
9026
9027 // If projection, use alt coord as divisor
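// e.g., for a 2D projective sample with coordinate q, this yields q.xy / q.z.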
9028 if (args.base.is_proj)
9029 {
9030 if (sampling_type_needs_f32_conversion(coord_type))
9031 tex_coords += " / " + convert_to_f32(to_extract_component_expression(args.coord, alt_coord_component), 1);
9032 else
9033 tex_coords += " / " + to_extract_component_expression(args.coord, alt_coord_component);
9034 }
9035
9036 if (!farg_str.empty())
9037 farg_str += ", ";
9038
9039 if (imgtype.image.dim == DimCube && imgtype.image.arrayed && msl_options.emulate_cube_array)
9040 {
9041 farg_str += "spvCubemapTo2DArrayFace(" + tex_coords + ").xy";
9042
9043 if (is_cube_fetch)
9044 farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ")";
9045 else
9046 farg_str +=
9047 ", uint(spvCubemapTo2DArrayFace(" + tex_coords + ").z) + (uint(" +
9048 round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
9049 ") * 6u)";
9050
9051 add_spv_func_and_recompile(SPVFuncImplCubemapTo2DArrayFace);
9052 }
9053 else
9054 {
9055 farg_str += tex_coords;
9056
9057 // If fetch from cube, add face explicitly
9058 if (is_cube_fetch)
9059 {
9060 // Special case for cube arrays: face and layer are packed in one dimension.
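// For example, layer L and face F live at array index L * 6 + F,
// so face = index % 6 and layer = index / 6.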
9061 if (imgtype.image.arrayed)
9062 farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") % 6u";
9063 else
9064 farg_str +=
9065 ", uint(" + round_fp_tex_coords(to_extract_component_expression(args.coord, 2), coord_is_fp) + ")";
9066 }
9067
9068 // If array, use alt coord
9069 if (imgtype.image.arrayed)
9070 {
9071 // Special case for cube arrays: face and layer are packed in one dimension.
9072 if (imgtype.image.dim == DimCube && args.base.is_fetch)
9073 {
9074 farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") / 6u";
9075 }
9076 else
9077 {
9078 farg_str +=
9079 ", uint(" +
9080 round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
9081 ")";
9082 if (imgtype.image.dim == DimSubpassData)
9083 {
9084 if (msl_options.multiview)
9085 farg_str += " + gl_ViewIndex";
9086 else if (msl_options.arrayed_subpass_input)
9087 farg_str += " + gl_Layer";
9088 }
9089 }
9090 }
9091 else if (imgtype.image.dim == DimSubpassData)
9092 {
9093 if (msl_options.multiview)
9094 farg_str += ", gl_ViewIndex";
9095 else if (msl_options.arrayed_subpass_input)
9096 farg_str += ", gl_Layer";
9097 }
9098 }
9099
9100 // Depth compare reference value
9101 if (args.dref)
9102 {
9103 forward = forward && should_forward(args.dref);
9104 farg_str += ", ";
9105
9106 auto &dref_type = expression_type(args.dref);
9107
9108 string dref_expr;
9109 if (args.base.is_proj)
9110 dref_expr = join(to_enclosed_expression(args.dref), " / ",
9111 to_extract_component_expression(args.coord, alt_coord_component));
9112 else
9113 dref_expr = to_expression(args.dref);
9114
9115 if (sampling_type_needs_f32_conversion(dref_type))
9116 dref_expr = convert_to_f32(dref_expr, 1);
9117
9118 farg_str += dref_expr;
9119
9120 if (msl_options.is_macos() && (grad_x || grad_y))
9121 {
9122 // For sample compare, MSL does not support gradient2d on all targets (apparently only iOS, according to the docs).
9123 // However, the most common case here is to have a constant gradient of 0, as that is the only way to express
9124 // LOD == 0 in GLSL with sampler2DArrayShadow (cascaded shadow mapping).
9125 // We will detect a compile-time constant 0 value for gradient and promote that to level(0) on MSL.
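// For example, a textureGrad() on a shadow sampler with both gradients equal to
// vec2(0.0) is emitted with level(0) rather than a gradient2d(...) qualifier.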
9126 bool constant_zero_x = !grad_x || expression_is_constant_null(grad_x);
9127 bool constant_zero_y = !grad_y || expression_is_constant_null(grad_y);
9128 if (constant_zero_x && constant_zero_y)
9129 {
9130 lod = 0;
9131 grad_x = 0;
9132 grad_y = 0;
9133 farg_str += ", level(0)";
9134 }
9135 else if (!msl_options.supports_msl_version(2, 3))
9136 {
9137 SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not "
9138 "supported on macOS prior to MSL 2.3.");
9139 }
9140 }
9141
9142 if (msl_options.is_macos() && bias)
9143 {
9144 // Bias is not supported either on macOS with sample_compare.
9145 // Verify it is compile-time zero, and drop the argument.
9146 if (expression_is_constant_null(bias))
9147 {
9148 bias = 0;
9149 }
9150 else if (!msl_options.supports_msl_version(2, 3))
9151 {
9152 SPIRV_CROSS_THROW("Using non-constant 0.0 bias() qualifier for sample_compare. This is not supported "
9153 "on macOS prior to MSL 2.3.");
9154 }
9155 }
9156 }
9157
9158 // LOD Options
9159 // Metal does not support LOD for 1D textures.
9160 if (bias && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
9161 {
9162 forward = forward && should_forward(bias);
9163 farg_str += ", bias(" + to_expression(bias) + ")";
9164 }
9165
9166 // Metal does not support LOD for 1D textures.
9167 if (lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
9168 {
9169 forward = forward && should_forward(lod);
9170 if (args.base.is_fetch)
9171 {
9172 farg_str += ", " + to_expression(lod);
9173 }
9174 else
9175 {
9176 farg_str += ", level(" + to_expression(lod) + ")";
9177 }
9178 }
9179 else if (args.base.is_fetch && !lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D) &&
9180 imgtype.image.dim != DimBuffer && !imgtype.image.ms && imgtype.image.sampled != 2)
9181 {
9182 // The Lod argument is optional in OpImageFetch, but we require a LOD value, so pick 0 as the default.
9183 // Check for sampled type as well, because is_fetch is also used for OpImageRead in MSL.
9184 farg_str += ", 0";
9185 }
9186
9187 // Metal does not support LOD for 1D textures.
9188 if ((grad_x || grad_y) && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
9189 {
9190 forward = forward && should_forward(grad_x);
9191 forward = forward && should_forward(grad_y);
9192 string grad_opt;
9193 switch (imgtype.image.dim)
9194 {
9195 case Dim1D:
9196 case Dim2D:
9197 grad_opt = "2d";
9198 break;
9199 case Dim3D:
9200 grad_opt = "3d";
9201 break;
9202 case DimCube:
9203 if (imgtype.image.arrayed && msl_options.emulate_cube_array)
9204 grad_opt = "2d";
9205 else
9206 grad_opt = "cube";
9207 break;
9208 default:
9209 grad_opt = "unsupported_gradient_dimension";
9210 break;
9211 }
9212 farg_str += ", gradient" + grad_opt + "(" + to_expression(grad_x) + ", " + to_expression(grad_y) + ")";
9213 }
9214
9215 if (args.min_lod)
9216 {
9217 if (!msl_options.supports_msl_version(2, 2))
9218 SPIRV_CROSS_THROW("min_lod_clamp() is only supported in MSL 2.2+ and up.");
9219
9220 forward = forward && should_forward(args.min_lod);
9221 farg_str += ", min_lod_clamp(" + to_expression(args.min_lod) + ")";
9222 }
9223
9224 // Add offsets
9225 string offset_expr;
9226 const SPIRType *offset_type = nullptr;
9227 if (args.coffset && !args.base.is_fetch)
9228 {
9229 forward = forward && should_forward(args.coffset);
9230 offset_expr = to_expression(args.coffset);
9231 offset_type = &expression_type(args.coffset);
9232 }
9233 else if (args.offset && !args.base.is_fetch)
9234 {
9235 forward = forward && should_forward(args.offset);
9236 offset_expr = to_expression(args.offset);
9237 offset_type = &expression_type(args.offset);
9238 }
9239
9240 if (!offset_expr.empty())
9241 {
9242 switch (imgtype.image.dim)
9243 {
9244 case Dim1D:
9245 if (!msl_options.texture_1D_as_2D)
9246 break;
9247 if (offset_type->vecsize > 1)
9248 offset_expr = enclose_expression(offset_expr) + ".x";
9249
9250 farg_str += join(", int2(", offset_expr, ", 0)");
9251 break;
9252
9253 case Dim2D:
9254 if (offset_type->vecsize > 2)
9255 offset_expr = enclose_expression(offset_expr) + ".xy";
9256
9257 farg_str += ", " + offset_expr;
9258 break;
9259
9260 case Dim3D:
9261 if (offset_type->vecsize > 3)
9262 offset_expr = enclose_expression(offset_expr) + ".xyz";
9263
9264 farg_str += ", " + offset_expr;
9265 break;
9266
9267 default:
9268 break;
9269 }
9270 }
9271
9272 if (args.component)
9273 {
9274 // If a 2D gather has a component argument, ensure it also has an offset argument.
9275 if (imgtype.image.dim == Dim2D && offset_expr.empty())
9276 farg_str += ", int2(0)";
9277
9278 if (!msl_options.swizzle_texture_samples || is_dynamic_img_sampler)
9279 {
9280 forward = forward && should_forward(args.component);
9281
9282 uint32_t image_var = 0;
9283 if (const auto *combined = maybe_get<SPIRCombinedImageSampler>(img))
9284 {
9285 if (const auto *img_var = maybe_get_backing_variable(combined->image))
9286 image_var = img_var->self;
9287 }
9288 else if (const auto *var = maybe_get_backing_variable(img))
9289 {
9290 image_var = var->self;
9291 }
9292
9293 if (image_var == 0 || !image_is_comparison(expression_type(image_var), image_var))
9294 farg_str += ", " + to_component_argument(args.component);
9295 }
9296 }
9297
9298 if (args.sample)
9299 {
9300 forward = forward && should_forward(args.sample);
9301 farg_str += ", ";
9302 farg_str += to_expression(args.sample);
9303 }
9304
9305 *p_forward = forward;
9306
9307 return farg_str;
9308 }
9309
9310 // If the texture coordinates are floating point, invokes the MSL round() function to round them.
9311 string CompilerMSL::round_fp_tex_coords(string tex_coords, bool coord_is_fp)
9312 {
9313 return coord_is_fp ? ("round(" + tex_coords + ")") : tex_coords;
9314 }
9315
9316 // Returns a string to use in an image sampling function argument.
9317 // The ID must be a scalar constant.
9318 string CompilerMSL::to_component_argument(uint32_t id)
9319 {
9320 uint32_t component_index = evaluate_constant_u32(id);
9321 switch (component_index)
9322 {
9323 case 0:
9324 return "component::x";
9325 case 1:
9326 return "component::y";
9327 case 2:
9328 return "component::z";
9329 case 3:
9330 return "component::w";
9331
9332 default:
9333 SPIRV_CROSS_THROW("The value (" + to_string(component_index) + ") of OpConstant ID " + to_string(id) +
9334 " is not a valid Component index, which must be one of 0, 1, 2, or 3.");
9335 return "component::x";
9336 }
9337 }
9338
9339 // Establish sampled image as expression object and assign the sampler to it.
9340 void CompilerMSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
9341 {
9342 set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
9343 }
9344
9345 string CompilerMSL::to_texture_op(const Instruction &i, bool sparse, bool *forward,
9346 SmallVector<uint32_t> &inherited_expressions)
9347 {
9348 auto *ops = stream(i);
9349 uint32_t result_type_id = ops[0];
9350 uint32_t img = ops[2];
9351 auto &result_type = get<SPIRType>(result_type_id);
9352 auto op = static_cast<Op>(i.op);
9353 bool is_gather = (op == OpImageGather || op == OpImageDrefGather);
9354
9355 // Bypass pointers because we need the real image struct
9356 auto &type = expression_type(img);
9357 auto &imgtype = get<SPIRType>(type.self);
9358
9359 const MSLConstexprSampler *constexpr_sampler = nullptr;
9360 bool is_dynamic_img_sampler = false;
9361 if (auto *var = maybe_get_backing_variable(img))
9362 {
9363 constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
9364 is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
9365 }
9366
9367 string expr;
9368 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
9369 {
9370 // If this needs sampler Y'CbCr conversion, we need to do some additional
9371 // processing.
9372 switch (constexpr_sampler->ycbcr_model)
9373 {
9374 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
9375 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
9376 // Default
9377 break;
9378 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
9379 add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT709);
9380 expr += "spvConvertYCbCrBT709(";
9381 break;
9382 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
9383 add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT601);
9384 expr += "spvConvertYCbCrBT601(";
9385 break;
9386 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
9387 add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT2020);
9388 expr += "spvConvertYCbCrBT2020(";
9389 break;
9390 default:
9391 SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
9392 }
9393
9394 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
9395 {
9396 switch (constexpr_sampler->ycbcr_range)
9397 {
9398 case MSL_SAMPLER_YCBCR_RANGE_ITU_FULL:
9399 add_spv_func_and_recompile(SPVFuncImplExpandITUFullRange);
9400 expr += "spvExpandITUFullRange(";
9401 break;
9402 case MSL_SAMPLER_YCBCR_RANGE_ITU_NARROW:
9403 add_spv_func_and_recompile(SPVFuncImplExpandITUNarrowRange);
9404 expr += "spvExpandITUNarrowRange(";
9405 break;
9406 default:
9407 SPIRV_CROSS_THROW("Invalid Y'CbCr range.");
9408 }
9409 }
9410 }
9411 else if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
9412 !is_dynamic_img_sampler)
9413 {
9414 add_spv_func_and_recompile(SPVFuncImplTextureSwizzle);
9415 expr += "spvTextureSwizzle(";
9416 }
9417
9418 string inner_expr = CompilerGLSL::to_texture_op(i, sparse, forward, inherited_expressions);
9419
9420 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
9421 {
9422 if (!constexpr_sampler->swizzle_is_identity())
9423 {
9424 static const char swizzle_names[] = "rgba";
9425 if (!constexpr_sampler->swizzle_has_one_or_zero())
9426 {
9427 // If we can, do it inline.
9428 expr += inner_expr + ".";
9429 for (uint32_t c = 0; c < 4; c++)
9430 {
9431 switch (constexpr_sampler->swizzle[c])
9432 {
9433 case MSL_COMPONENT_SWIZZLE_IDENTITY:
9434 expr += swizzle_names[c];
9435 break;
9436 case MSL_COMPONENT_SWIZZLE_R:
9437 case MSL_COMPONENT_SWIZZLE_G:
9438 case MSL_COMPONENT_SWIZZLE_B:
9439 case MSL_COMPONENT_SWIZZLE_A:
9440 expr += swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
9441 break;
9442 default:
9443 SPIRV_CROSS_THROW("Invalid component swizzle.");
9444 }
9445 }
9446 }
9447 else
9448 {
9449 // Otherwise, we need to emit a temporary and swizzle that.
9450 uint32_t temp_id = ir.increase_bound_by(1);
9451 emit_op(result_type_id, temp_id, inner_expr, false);
9452 for (auto &inherit : inherited_expressions)
9453 inherit_expression_dependencies(temp_id, inherit);
9454 inherited_expressions.clear();
9455 inherited_expressions.push_back(temp_id);
9456
9457 switch (op)
9458 {
9459 case OpImageSampleDrefImplicitLod:
9460 case OpImageSampleImplicitLod:
9461 case OpImageSampleProjImplicitLod:
9462 case OpImageSampleProjDrefImplicitLod:
9463 register_control_dependent_expression(temp_id);
9464 break;
9465
9466 default:
9467 break;
9468 }
9469 expr += type_to_glsl(result_type) + "(";
9470 for (uint32_t c = 0; c < 4; c++)
9471 {
9472 switch (constexpr_sampler->swizzle[c])
9473 {
9474 case MSL_COMPONENT_SWIZZLE_IDENTITY:
9475 expr += to_expression(temp_id) + "." + swizzle_names[c];
9476 break;
9477 case MSL_COMPONENT_SWIZZLE_ZERO:
9478 expr += "0";
9479 break;
9480 case MSL_COMPONENT_SWIZZLE_ONE:
9481 expr += "1";
9482 break;
9483 case MSL_COMPONENT_SWIZZLE_R:
9484 case MSL_COMPONENT_SWIZZLE_G:
9485 case MSL_COMPONENT_SWIZZLE_B:
9486 case MSL_COMPONENT_SWIZZLE_A:
9487 expr += to_expression(temp_id) + "." +
9488 swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
9489 break;
9490 default:
9491 SPIRV_CROSS_THROW("Invalid component swizzle.");
9492 }
9493 if (c < 3)
9494 expr += ", ";
9495 }
9496 expr += ")";
9497 }
9498 }
9499 else
9500 expr += inner_expr;
9501 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
9502 {
9503 expr += join(", ", constexpr_sampler->bpc, ")");
9504 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY)
9505 expr += ")";
9506 }
9507 }
9508 else
9509 {
9510 expr += inner_expr;
9511 if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
9512 !is_dynamic_img_sampler)
9513 {
9514 // Add the swizzle constant from the swizzle buffer.
9515 expr += ", " + to_swizzle_expression(img) + ")";
9516 used_swizzle_buffer = true;
9517 }
9518 }
9519
9520 return expr;
9521 }
9522
9523 static string create_swizzle(MSLComponentSwizzle swizzle)
9524 {
9525 switch (swizzle)
9526 {
9527 case MSL_COMPONENT_SWIZZLE_IDENTITY:
9528 return "spvSwizzle::none";
9529 case MSL_COMPONENT_SWIZZLE_ZERO:
9530 return "spvSwizzle::zero";
9531 case MSL_COMPONENT_SWIZZLE_ONE:
9532 return "spvSwizzle::one";
9533 case MSL_COMPONENT_SWIZZLE_R:
9534 return "spvSwizzle::red";
9535 case MSL_COMPONENT_SWIZZLE_G:
9536 return "spvSwizzle::green";
9537 case MSL_COMPONENT_SWIZZLE_B:
9538 return "spvSwizzle::blue";
9539 case MSL_COMPONENT_SWIZZLE_A:
9540 return "spvSwizzle::alpha";
9541 default:
9542 SPIRV_CROSS_THROW("Invalid component swizzle.");
9543 return "";
9544 }
9545 }
9546
9547 // Returns a string representation of the ID, usable as a function arg.
9548 // Manufacture automatic sampler arg for SampledImage texture.
9549 string CompilerMSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
9550 {
9551 string arg_str;
9552
9553 auto &type = expression_type(id);
9554 bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
9555 // If the argument *itself* is a "dynamic" combined-image sampler, then we can just pass that around.
9556 bool arg_is_dynamic_img_sampler = has_extended_decoration(id, SPIRVCrossDecorationDynamicImageSampler);
9557 if (is_dynamic_img_sampler && !arg_is_dynamic_img_sampler)
9558 arg_str = join("spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">(");
9559
9560 auto *c = maybe_get<SPIRConstant>(id);
9561 if (msl_options.force_native_arrays && c && !get<SPIRType>(c->constant_type).array.empty())
9562 {
9563 // If we are passing a constant array directly to a function for some reason,
9564 // the callee will expect an argument in thread const address space
9565 // (since we can only bind to arrays with references in MSL).
9566 // To resolve this, we must emit a copy in this address space.
9567 // This kind of code gen should be rare enough that performance is not a real concern.
9568 // Inline the SPIR-V to avoid this kind of suboptimal codegen.
9569 //
9570 // We risk calling this inside a continue block (invalid code),
9571 // so just create a thread local copy in the current function.
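// For example, a constant array with ID %20 is passed as a stack copy named
// "_20_array_copy", declared locally in the current function.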
9572 arg_str = join("_", id, "_array_copy");
9573 auto &constants = current_function->constant_arrays_needed_on_stack;
9574 auto itr = find(begin(constants), end(constants), ID(id));
9575 if (itr == end(constants))
9576 {
9577 force_recompile();
9578 constants.push_back(id);
9579 }
9580 }
9581 else
9582 arg_str += CompilerGLSL::to_func_call_arg(arg, id);
9583
9584 // Need to check the base variable in case we need to apply a qualified alias.
9585 uint32_t var_id = 0;
9586 auto *var = maybe_get<SPIRVariable>(id);
9587 if (var)
9588 var_id = var->basevariable;
9589
9590 if (!arg_is_dynamic_img_sampler)
9591 {
9592 auto *constexpr_sampler = find_constexpr_sampler(var_id ? var_id : id);
9593 if (type.basetype == SPIRType::SampledImage)
9594 {
9595 // Manufacture automatic plane args for multiplanar texture
9596 uint32_t planes = 1;
9597 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
9598 {
9599 planes = constexpr_sampler->planes;
9600 // If this parameter isn't aliasing a global, then we need to use
9601 // the special "dynamic image-sampler" class to pass it--and we need
9602 // to use it for *every* non-alias parameter, in case a combined
9603 // image-sampler with a Y'CbCr conversion is passed. Hopefully, this
9604 // pathological case is so rare that it should never be hit in practice.
9605 if (!arg.alias_global_variable)
9606 add_spv_func_and_recompile(SPVFuncImplDynamicImageSampler);
9607 }
9608 for (uint32_t i = 1; i < planes; i++)
9609 arg_str += join(", ", CompilerGLSL::to_func_call_arg(arg, id), plane_name_suffix, i);
9610 // Manufacture automatic sampler arg if the arg is a SampledImage texture.
9611 if (type.image.dim != DimBuffer)
9612 arg_str += ", " + to_sampler_expression(var_id ? var_id : id);
9613
9614 // Add sampler Y'CbCr conversion info if we have it
9615 if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
9616 {
9617 SmallVector<string> samp_args;
9618
9619 switch (constexpr_sampler->resolution)
9620 {
9621 case MSL_FORMAT_RESOLUTION_444:
9622 // Default
9623 break;
9624 case MSL_FORMAT_RESOLUTION_422:
9625 samp_args.push_back("spvFormatResolution::_422");
9626 break;
9627 case MSL_FORMAT_RESOLUTION_420:
9628 samp_args.push_back("spvFormatResolution::_420");
9629 break;
9630 default:
9631 SPIRV_CROSS_THROW("Invalid format resolution.");
9632 }
9633
9634 if (constexpr_sampler->chroma_filter != MSL_SAMPLER_FILTER_NEAREST)
9635 samp_args.push_back("spvChromaFilter::linear");
9636
9637 if (constexpr_sampler->x_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
9638 samp_args.push_back("spvXChromaLocation::midpoint");
9639 if (constexpr_sampler->y_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
9640 samp_args.push_back("spvYChromaLocation::midpoint");
9641 switch (constexpr_sampler->ycbcr_model)
9642 {
9643 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
9644 // Default
9645 break;
9646 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
9647 samp_args.push_back("spvYCbCrModelConversion::ycbcr_identity");
9648 break;
9649 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
9650 samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_709");
9651 break;
9652 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
9653 samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_601");
9654 break;
9655 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
9656 samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_2020");
9657 break;
9658 default:
9659 SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
9660 }
9661 if (constexpr_sampler->ycbcr_range != MSL_SAMPLER_YCBCR_RANGE_ITU_FULL)
9662 samp_args.push_back("spvYCbCrRange::itu_narrow");
9663 samp_args.push_back(join("spvComponentBits(", constexpr_sampler->bpc, ")"));
9664 arg_str += join(", spvYCbCrSampler(", merge(samp_args), ")");
9665 }
9666 }
9667
9668 if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
9669 arg_str += join(", (uint(", create_swizzle(constexpr_sampler->swizzle[3]), ") << 24) | (uint(",
9670 create_swizzle(constexpr_sampler->swizzle[2]), ") << 16) | (uint(",
9671 create_swizzle(constexpr_sampler->swizzle[1]), ") << 8) | uint(",
9672 create_swizzle(constexpr_sampler->swizzle[0]), ")");
9673 else if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
9674 arg_str += ", " + to_swizzle_expression(var_id ? var_id : id);
9675
9676 if (buffers_requiring_array_length.count(var_id))
9677 arg_str += ", " + to_buffer_size_expression(var_id ? var_id : id);
9678
9679 if (is_dynamic_img_sampler)
9680 arg_str += ")";
9681 }
9682
9683 // Emulate texture2D atomic operations
9684 auto *backing_var = maybe_get_backing_variable(var_id);
9685 if (backing_var && atomic_image_vars.count(backing_var->self))
9686 {
9687 arg_str += ", " + to_expression(var_id) + "_atomic";
9688 }
9689
9690 return arg_str;
9691 }
9692
9693 // If the ID represents a sampled image that has been assigned a sampler already,
9694 // generate an expression for the sampler, otherwise generate a fake sampler name
9695 // by appending a suffix to the expression constructed from the ID.
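// For example, assuming the default sampler suffix, "tex" yields "texSmplr",
// and an arrayed "tex[i]" yields "texSmplr[i]".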
9696 string CompilerMSL::to_sampler_expression(uint32_t id)
9697 {
9698 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
9699 auto expr = to_expression(combined ? combined->image : VariableID(id));
9700 auto index = expr.find_first_of('[');
9701
9702 uint32_t samp_id = 0;
9703 if (combined)
9704 samp_id = combined->sampler;
9705
9706 if (index == string::npos)
9707 return samp_id ? to_expression(samp_id) : expr + sampler_name_suffix;
9708 else
9709 {
9710 auto image_expr = expr.substr(0, index);
9711 auto array_expr = expr.substr(index);
9712 return samp_id ? to_expression(samp_id) : (image_expr + sampler_name_suffix + array_expr);
9713 }
9714 }
9715
9716 string CompilerMSL::to_swizzle_expression(uint32_t id)
9717 {
9718 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
9719
9720 auto expr = to_expression(combined ? combined->image : VariableID(id));
9721 auto index = expr.find_first_of('[');
9722
9723 // If an image is part of an argument buffer, translate this to a legal identifier.
9724 string::size_type period = 0;
9725 while ((period = expr.find_first_of('.', period)) != string::npos && period < index)
9726 expr[period] = '_';
9727
9728 if (index == string::npos)
9729 return expr + swizzle_name_suffix;
9730 else
9731 {
9732 auto image_expr = expr.substr(0, index);
9733 auto array_expr = expr.substr(index);
9734 return image_expr + swizzle_name_suffix + array_expr;
9735 }
9736 }
9737
9738 string CompilerMSL::to_buffer_size_expression(uint32_t id)
9739 {
9740 auto expr = to_expression(id);
9741 auto index = expr.find_first_of('[');
9742
9743 // This is quite crude, but we need to translate the reference name (*spvDescriptorSetN.name) to
9744 // the pointer expression spvDescriptorSetN.name to make a reasonable expression here.
9745 // This only happens if we have argument buffers and we are using OpArrayLength on a lone SSBO in that set.
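// For example, "(*spvDescriptorSet0.foo)" roughly becomes "spvDescriptorSet0.foo",
// and then "spvDescriptorSet0_foo" plus the buffer-size suffix.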
9746 if (expr.size() >= 3 && expr[0] == '(' && expr[1] == '*')
9747 expr = address_of_expression(expr);
9748
9749 // If a buffer is part of an argument buffer, translate this to a legal identifier.
9750 for (auto &c : expr)
9751 if (c == '.')
9752 c = '_';
9753
9754 if (index == string::npos)
9755 return expr + buffer_size_name_suffix;
9756 else
9757 {
9758 auto buffer_expr = expr.substr(0, index);
9759 auto array_expr = expr.substr(index);
9760 return buffer_expr + buffer_size_name_suffix + array_expr;
9761 }
9762 }
9763
9764 // Checks whether the type is a Block all of whose members have DecorationPatch.
9765 bool CompilerMSL::is_patch_block(const SPIRType &type)
9766 {
9767 if (!has_decoration(type.self, DecorationBlock))
9768 return false;
9769
9770 for (uint32_t i = 0; i < type.member_types.size(); i++)
9771 {
9772 if (!has_member_decoration(type.self, i, DecorationPatch))
9773 return false;
9774 }
9775
9776 return true;
9777 }
9778
9779 // Checks whether the ID is a row_major matrix that requires conversion before use
9780 bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id)
9781 {
9782 auto *e = maybe_get<SPIRExpression>(id);
9783 if (e)
9784 return e->need_transpose;
9785 else
9786 return has_decoration(id, DecorationRowMajor);
9787 }
9788
9789 // Checks whether the member is a row_major matrix that requires conversion before use
9790 bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index)
9791 {
9792 return has_member_decoration(type.self, index, DecorationRowMajor);
9793 }
9794
9795 string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type, uint32_t physical_type_id,
9796 bool is_packed)
9797 {
9798 if (!is_matrix(exp_type))
9799 {
9800 return CompilerGLSL::convert_row_major_matrix(move(exp_str), exp_type, physical_type_id, is_packed);
9801 }
9802 else
9803 {
9804 strip_enclosed_expression(exp_str);
9805 if (physical_type_id != 0 || is_packed)
9806 exp_str = unpack_expression_type(exp_str, exp_type, physical_type_id, is_packed, true);
9807 return join("transpose(", exp_str, ")");
9808 }
9809 }
9810
9811 // Called automatically at the end of the entry point function
9812 void CompilerMSL::emit_fixup()
9813 {
9814 if (is_vertex_like_shader() && stage_out_var_id && !qual_pos_var_name.empty() && !capture_output_to_buffer)
9815 {
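// GL-style clip space has z in [-w, w], while Metal expects [0, w];
// (z + w) * 0.5 remaps the former onto the latter.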
9816 if (options.vertex.fixup_clipspace)
9817 statement(qual_pos_var_name, ".z = (", qual_pos_var_name, ".z + ", qual_pos_var_name,
9818 ".w) * 0.5; // Adjust clip-space for Metal");
9819
9820 if (options.vertex.flip_vert_y)
9821 statement(qual_pos_var_name, ".y = -(", qual_pos_var_name, ".y);", " // Invert Y-axis for Metal");
9822 }
9823 }
9824
9825 // Return a string defining a structure member, with padding and packing.
9826 string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
9827 const string &qualifier)
9828 {
9829 if (member_is_remapped_physical_type(type, index))
9830 member_type_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID);
9831 auto &physical_type = get<SPIRType>(member_type_id);
9832
9833 // If this member is packed, mark it as so.
9834 string pack_pfx;
9835
9836 // Allow Metal to use the array<T> template to make arrays a value type
9837 uint32_t orig_id = 0;
9838 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID))
9839 orig_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID);
9840
9841 bool row_major = false;
9842 if (is_matrix(physical_type))
9843 row_major = has_member_decoration(type.self, index, DecorationRowMajor);
9844
9845 SPIRType row_major_physical_type;
9846 const SPIRType *declared_type = &physical_type;
9847
9848 // If a struct is being declared with physical layout,
9849 // do not use array<T> wrappers.
9850 // This avoids a lot of complicated cases with packed vectors and matrices,
9851 // and generally we cannot copy full arrays in and out of buffers into Function
9852 // address space.
9853 // Array of resources should also be declared as builtin arrays.
9854 if (has_member_decoration(type.self, index, DecorationOffset))
9855 is_using_builtin_array = true;
9856 else if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
9857 is_using_builtin_array = true;
9858
9859 if (member_is_packed_physical_type(type, index))
9860 {
9861 // If we're packing a matrix, output an appropriate typedef
9862 if (physical_type.basetype == SPIRType::Struct)
9863 {
9864 SPIRV_CROSS_THROW("Cannot emit a packed struct currently.");
9865 }
9866 else if (is_matrix(physical_type))
9867 {
9868 uint32_t rows = physical_type.vecsize;
9869 uint32_t cols = physical_type.columns;
9870 pack_pfx = "packed_";
9871 if (row_major)
9872 {
9873 // These are stored transposed.
9874 rows = physical_type.columns;
9875 cols = physical_type.vecsize;
9876 pack_pfx = "packed_rm_";
9877 }
9878 string base_type = physical_type.width == 16 ? "half" : "float";
9879 string td_line = "typedef ";
9880 td_line += "packed_" + base_type + to_string(rows);
9881 td_line += " " + pack_pfx;
9882 // Use the actual matrix size here.
9883 td_line += base_type + to_string(physical_type.columns) + "x" + to_string(physical_type.vecsize);
9884 td_line += "[" + to_string(cols) + "]";
9885 td_line += ";";
9886 add_typedef_line(td_line);
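// For example, a packed column-major float4x3 yields:
//   typedef packed_float3 packed_float4x3[4];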
9887 }
9888 else if (!is_scalar(physical_type)) // scalar type is already packed.
9889 pack_pfx = "packed_";
9890 }
9891 else if (row_major)
9892 {
9893 // Need to declare type with flipped vecsize/columns.
9894 row_major_physical_type = physical_type;
9895 swap(row_major_physical_type.vecsize, row_major_physical_type.columns);
9896 declared_type = &row_major_physical_type;
9897 }
9898
9899 // Very specifically, image load-store in argument buffers are disallowed on MSL on iOS.
9900 if (msl_options.is_ios() && physical_type.basetype == SPIRType::Image && physical_type.image.sampled == 2)
9901 {
9902 if (!has_decoration(orig_id, DecorationNonWritable))
9903 SPIRV_CROSS_THROW("Writable images are not allowed in argument buffers on iOS.");
9904 }
9905
9906 // Array information is baked into these types.
9907 string array_type;
9908 if (physical_type.basetype != SPIRType::Image && physical_type.basetype != SPIRType::Sampler &&
9909 physical_type.basetype != SPIRType::SampledImage)
9910 {
9911 BuiltIn builtin = BuiltInMax;
9912
9913 // Special handling. In [[stage_out]] or [[stage_in]] blocks,
9914 // we need flat arrays, but if we're somehow declaring gl_PerVertex for constant array reasons, we want
9915 // template array types to be declared.
9916 bool is_ib_in_out =
9917 ((stage_out_var_id && get_stage_out_struct_type().self == type.self) ||
9918 (stage_in_var_id && get_stage_in_struct_type().self == type.self));
9919 if (is_ib_in_out && is_member_builtin(type, index, &builtin))
9920 is_using_builtin_array = true;
9921 array_type = type_to_array_glsl(physical_type);
9922 }
9923
9924 auto result = join(pack_pfx, type_to_glsl(*declared_type, orig_id), " ", qualifier, to_member_name(type, index),
9925 member_attribute_qualifier(type, index), array_type, ";");
9926
9927 is_using_builtin_array = false;
9928 return result;
9929 }
9930
9931 // Emit a structure member, padding and packing to maintain the correct member alignments.
9932 void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
9933 const string &qualifier, uint32_t)
9934 {
9935 // If this member requires padding to maintain its declared offset, emit a dummy padding member before it.
9936 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget))
9937 {
9938 uint32_t pad_len = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget);
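// For example, a member at index 3 needing 8 bytes of leading padding emits "char _m3_pad[8];".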
9939 statement("char _m", index, "_pad", "[", pad_len, "];");
9940 }
9941
9942 // Handle HLSL-style 0-based vertex/instance index.
9943 builtin_declaration = true;
9944 statement(to_struct_member(type, member_type_id, index, qualifier));
9945 builtin_declaration = false;
9946 }
9947
9948 void CompilerMSL::emit_struct_padding_target(const SPIRType &type)
9949 {
9950 uint32_t struct_size = get_declared_struct_size_msl(type, true, true);
9951 uint32_t target_size = get_extended_decoration(type.self, SPIRVCrossDecorationPaddingTarget);
9952 if (target_size < struct_size)
9953 SPIRV_CROSS_THROW("Cannot pad with negative bytes.");
9954 else if (target_size > struct_size)
9955 statement("char _m0_final_padding[", target_size - struct_size, "];");
9956 }
9957
9958 // Return a MSL qualifier for the specified function attribute member
9959 string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t index)
9960 {
9961 auto &execution = get_entry_point();
9962
9963 uint32_t mbr_type_id = type.member_types[index];
9964 auto &mbr_type = get<SPIRType>(mbr_type_id);
9965
9966 BuiltIn builtin = BuiltInMax;
9967 bool is_builtin = is_member_builtin(type, index, &builtin);
9968
9969 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
9970 {
9971 string quals = join(
9972 " [[id(", get_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary), ")");
9973 if (interlocked_resources.count(
9974 get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID)))
9975 quals += ", raster_order_group(0)";
9976 quals += "]]";
9977 return quals;
9978 }
9979
9980 // Vertex function inputs
9981 if (execution.model == ExecutionModelVertex && type.storage == StorageClassInput)
9982 {
9983 if (is_builtin)
9984 {
9985 switch (builtin)
9986 {
9987 case BuiltInVertexId:
9988 case BuiltInVertexIndex:
9989 case BuiltInBaseVertex:
9990 case BuiltInInstanceId:
9991 case BuiltInInstanceIndex:
9992 case BuiltInBaseInstance:
9993 if (msl_options.vertex_for_tessellation)
9994 return "";
9995 return string(" [[") + builtin_qualifier(builtin) + "]]";
9996
9997 case BuiltInDrawIndex:
9998 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
9999
10000 default:
10001 return "";
10002 }
10003 }
10004 uint32_t locn = get_ordered_member_location(type.self, index);
10005 if (locn != k_unknown_location)
10006 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
10007 }
10008
10009 // Vertex and tessellation evaluation function outputs
10010 if (((execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation) ||
10011 execution.model == ExecutionModelTessellationEvaluation) &&
10012 type.storage == StorageClassOutput)
10013 {
10014 if (is_builtin)
10015 {
10016 switch (builtin)
10017 {
10018 case BuiltInPointSize:
10019 // Only mark the PointSize builtin if really rendering points.
10020 // Some shaders may include a PointSize builtin even when used to render
10021 // non-point topologies, and Metal will reject this builtin when compiling
10022 // the shader into a render pipeline that uses a non-point topology.
10023 return msl_options.enable_point_size_builtin ? (string(" [[") + builtin_qualifier(builtin) + "]]") : "";
10024
10025 case BuiltInViewportIndex:
10026 if (!msl_options.supports_msl_version(2, 0))
10027 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
10028 /* fallthrough */
10029 case BuiltInPosition:
10030 case BuiltInLayer:
10031 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10032
10033 case BuiltInClipDistance:
10034 if (has_member_decoration(type.self, index, DecorationLocation))
10035 return join(" [[user(clip", get_member_decoration(type.self, index, DecorationLocation), ")]]");
10036 else
10037 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10038
10039 default:
10040 return "";
10041 }
10042 }
10043 uint32_t comp;
10044 uint32_t locn = get_ordered_member_location(type.self, index, &comp);
10045 if (locn != k_unknown_location)
10046 {
10047 if (comp != k_unknown_component)
10048 return string(" [[user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")]]";
10049 else
10050 return string(" [[user(locn") + convert_to_string(locn) + ")]]";
10051 }
10052 }
10053
10054 // Tessellation control function inputs
10055 if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassInput)
10056 {
10057 if (is_builtin)
10058 {
10059 switch (builtin)
10060 {
10061 case BuiltInInvocationId:
10062 case BuiltInPrimitiveId:
10063 if (msl_options.multi_patch_workgroup)
10064 return "";
10065 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10066 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
10067 case BuiltInSubgroupSize: // FIXME: Should work in any stage
10068 if (msl_options.emulate_subgroups)
10069 return "";
10070 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10071 case BuiltInPatchVertices:
10072 return "";
10073 // Others come from stage input.
10074 default:
10075 break;
10076 }
10077 }
10078 if (msl_options.multi_patch_workgroup)
10079 return "";
10080 uint32_t locn = get_ordered_member_location(type.self, index);
10081 if (locn != k_unknown_location)
10082 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
10083 }
10084
10085 // Tessellation control function outputs
10086 if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassOutput)
10087 {
10088 // For this type of shader, we always arrange for it to capture its
10089 // output to a buffer. For this reason, qualifiers are irrelevant here.
10090 return "";
10091 }
10092
10093 // Tessellation evaluation function inputs
10094 if (execution.model == ExecutionModelTessellationEvaluation && type.storage == StorageClassInput)
10095 {
10096 if (is_builtin)
10097 {
10098 switch (builtin)
10099 {
10100 case BuiltInPrimitiveId:
10101 case BuiltInTessCoord:
10102 return string(" [[") + builtin_qualifier(builtin) + "]]";
10103 case BuiltInPatchVertices:
10104 return "";
10105 // Others come from stage input.
10106 default:
10107 break;
10108 }
10109 }
10110 // The special control point array must not be marked with an attribute.
10111 if (get_type(type.member_types[index]).basetype == SPIRType::ControlPointArray)
10112 return "";
10113 uint32_t locn = get_ordered_member_location(type.self, index);
10114 if (locn != k_unknown_location)
10115 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
10116 }
10117
10118 // Tessellation evaluation function outputs were handled above.
10119
10120 // Fragment function inputs
10121 if (execution.model == ExecutionModelFragment && type.storage == StorageClassInput)
10122 {
10123 string quals;
10124 if (is_builtin)
10125 {
10126 switch (builtin)
10127 {
10128 case BuiltInViewIndex:
10129 if (!msl_options.multiview || !msl_options.multiview_layered_rendering)
10130 break;
10131 /* fallthrough */
10132 case BuiltInFrontFacing:
10133 case BuiltInPointCoord:
10134 case BuiltInFragCoord:
10135 case BuiltInSampleId:
10136 case BuiltInSampleMask:
10137 case BuiltInLayer:
10138 case BuiltInBaryCoordNV:
10139 case BuiltInBaryCoordNoPerspNV:
10140 quals = builtin_qualifier(builtin);
10141 break;
10142
10143 case BuiltInClipDistance:
10144 return join(" [[user(clip", get_member_decoration(type.self, index, DecorationLocation), ")]]");
10145
10146 default:
10147 break;
10148 }
10149 }
10150 else
10151 {
10152 uint32_t comp;
10153 uint32_t locn = get_ordered_member_location(type.self, index, &comp);
10154 if (locn != k_unknown_location)
10155 {
10156 // For user-defined attributes, this is fine. From Vulkan spec:
10157 // A user-defined output variable is considered to match an input variable in the subsequent stage if
10158 // the two variables are declared with the same Location and Component decoration and match in type
10159 // and decoration, except that interpolation decorations are not required to match. For the purposes
10160 // of interface matching, variables declared without a Component decoration are considered to have a
10161 // Component decoration of zero.
10162
10163 if (comp != k_unknown_component && comp != 0)
10164 quals = string("user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")";
10165 else
10166 quals = string("user(locn") + convert_to_string(locn) + ")";
10167 }
10168 }
10169
10170 if (builtin == BuiltInBaryCoordNV || builtin == BuiltInBaryCoordNoPerspNV)
10171 {
10172 if (has_member_decoration(type.self, index, DecorationFlat) ||
10173 has_member_decoration(type.self, index, DecorationCentroid) ||
10174 has_member_decoration(type.self, index, DecorationSample) ||
10175 has_member_decoration(type.self, index, DecorationNoPerspective))
10176 {
10177 // NoPerspective is baked into the builtin type.
10178 SPIRV_CROSS_THROW(
10179 "Flat, Centroid, Sample, NoPerspective decorations are not supported for BaryCoord inputs.");
10180 }
10181 }
10182
10183 // Don't bother decorating integers with the 'flat' attribute; it's
10184 // the default (in fact, the only option). Also don't bother with the
10185 // FragCoord builtin; it's always noperspective on Metal.
10186 if (!type_is_integral(mbr_type) && (!is_builtin || builtin != BuiltInFragCoord))
10187 {
10188 if (has_member_decoration(type.self, index, DecorationFlat))
10189 {
10190 if (!quals.empty())
10191 quals += ", ";
10192 quals += "flat";
10193 }
10194 else if (has_member_decoration(type.self, index, DecorationCentroid))
10195 {
10196 if (!quals.empty())
10197 quals += ", ";
10198 if (has_member_decoration(type.self, index, DecorationNoPerspective))
10199 quals += "centroid_no_perspective";
10200 else
10201 quals += "centroid_perspective";
10202 }
10203 else if (has_member_decoration(type.self, index, DecorationSample))
10204 {
10205 if (!quals.empty())
10206 quals += ", ";
10207 if (has_member_decoration(type.self, index, DecorationNoPerspective))
10208 quals += "sample_no_perspective";
10209 else
10210 quals += "sample_perspective";
10211 }
10212 else if (has_member_decoration(type.self, index, DecorationNoPerspective))
10213 {
10214 if (!quals.empty())
10215 quals += ", ";
10216 quals += "center_no_perspective";
10217 }
10218 }
10219
10220 if (!quals.empty())
10221 return " [[" + quals + "]]";
10222 }
10223
10224 // Fragment function outputs
10225 if (execution.model == ExecutionModelFragment && type.storage == StorageClassOutput)
10226 {
10227 if (is_builtin)
10228 {
10229 switch (builtin)
10230 {
10231 case BuiltInFragStencilRefEXT:
10232 // Similar to PointSize, only mark FragStencilRef if there's a stencil buffer.
10233 // Some shaders may include a FragStencilRef builtin even when used to render
10234 // without a stencil attachment, and Metal will reject this builtin
10235 // when compiling the shader into a render pipeline that does not set
10236 // stencilAttachmentPixelFormat.
10237 if (!msl_options.enable_frag_stencil_ref_builtin)
10238 return "";
10239 if (!msl_options.supports_msl_version(2, 1))
10240 SPIRV_CROSS_THROW("Stencil export only supported in MSL 2.1 and up.");
10241 return string(" [[") + builtin_qualifier(builtin) + "]]";
10242
10243 case BuiltInFragDepth:
10244 // Ditto FragDepth.
10245 if (!msl_options.enable_frag_depth_builtin)
10246 return "";
10247 /* fallthrough */
10248 case BuiltInSampleMask:
10249 return string(" [[") + builtin_qualifier(builtin) + "]]";
10250
10251 default:
10252 return "";
10253 }
10254 }
10255 uint32_t locn = get_ordered_member_location(type.self, index);
10256 // As with FragDepth and FragStencilRef above, skip color outputs that are masked off;
10256 // Metal will likely complain about missing color attachments, too.
10257 if (locn != k_unknown_location && !(msl_options.enable_frag_output_mask & (1 << locn)))
10258 return "";
10259 if (locn != k_unknown_location && has_member_decoration(type.self, index, DecorationIndex))
10260 return join(" [[color(", locn, "), index(", get_member_decoration(type.self, index, DecorationIndex),
10261 ")]]");
10262 else if (locn != k_unknown_location)
10263 return join(" [[color(", locn, ")]]");
10264 else if (has_member_decoration(type.self, index, DecorationIndex))
10265 return join(" [[index(", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10266 else
10267 return "";
10268 }
10269
10270 // Compute function inputs
10271 if (execution.model == ExecutionModelGLCompute && type.storage == StorageClassInput)
10272 {
10273 if (is_builtin)
10274 {
10275 switch (builtin)
10276 {
10277 case BuiltInNumSubgroups:
10278 case BuiltInSubgroupId:
10279 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
10280 case BuiltInSubgroupSize: // FIXME: Should work in any stage
10281 if (msl_options.emulate_subgroups)
10282 break;
10283 /* fallthrough */
10284 case BuiltInGlobalInvocationId:
10285 case BuiltInWorkgroupId:
10286 case BuiltInNumWorkgroups:
10287 case BuiltInLocalInvocationId:
10288 case BuiltInLocalInvocationIndex:
10289 return string(" [[") + builtin_qualifier(builtin) + "]]";
10290
10291 default:
10292 return "";
10293 }
10294 }
10295 }
10296
10297 return "";
10298 }
10299
10300 // Returns the location decoration of the member with the specified index in the specified type.
10301 // If the location of the member has been explicitly set, that location is used. If not, this
10302 // function assumes the members are declared in location order, and simply returns the
10303 // index as the location.
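// e.g. (illustrative) if member 2 carries DecorationLocation = 7, this returns 7;
// without an explicit Location decoration it simply returns the member index 2.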
10304 uint32_t CompilerMSL::get_ordered_member_location(uint32_t type_id, uint32_t index, uint32_t *comp)
10305 {
10306 auto &m = ir.meta[type_id];
10307 if (index < m.members.size())
10308 {
10309 auto &dec = m.members[index];
10310 if (comp)
10311 {
10312 if (dec.decoration_flags.get(DecorationComponent))
10313 *comp = dec.component;
10314 else
10315 *comp = k_unknown_component;
10316 }
10317 if (dec.decoration_flags.get(DecorationLocation))
10318 return dec.location;
10319 }
10320
10321 return index;
10322 }
10323
10324 // Returns the type declaration for a function, including the
10325 // entry type if the current function is the entry point function
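// Illustrative results, assuming an output struct named main0_out (name hypothetical):
//   "vertex main0_out", "fragment main0_out",
//   "[[ patch(triangle, 3) ]] vertex main0_out", or "kernel void".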
10326 string CompilerMSL::func_type_decl(SPIRType &type)
10327 {
10328 // The regular function return type. If not processing the entry point function, that's all we need
10329 string return_type = type_to_glsl(type) + type_to_array_glsl(type);
10330 if (!processing_entry_point)
10331 return return_type;
10332
10333 // If an outgoing interface block has been defined, and it should be returned, override the entry point return type
10334 bool ep_should_return_output = !get_is_rasterization_disabled();
10335 if (stage_out_var_id && ep_should_return_output)
10336 return_type = type_to_glsl(get_stage_out_struct_type()) + type_to_array_glsl(type);
10337
10338 // Prepend an entry type, based on the execution model.
10339 string entry_type;
10340 auto &execution = get_entry_point();
10341 switch (execution.model)
10342 {
10343 case ExecutionModelVertex:
10344 if (msl_options.vertex_for_tessellation && !msl_options.supports_msl_version(1, 2))
10345 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
10346 entry_type = msl_options.vertex_for_tessellation ? "kernel" : "vertex";
10347 break;
10348 case ExecutionModelTessellationEvaluation:
10349 if (!msl_options.supports_msl_version(1, 2))
10350 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
10351 if (execution.flags.get(ExecutionModeIsolines))
10352 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
10353 if (msl_options.is_ios())
10354 entry_type =
10355 join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ") ]] vertex");
10356 else
10357 entry_type = join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ", ",
10358 execution.output_vertices, ") ]] vertex");
10359 break;
10360 case ExecutionModelFragment:
10361 entry_type = execution.flags.get(ExecutionModeEarlyFragmentTests) ||
10362 execution.flags.get(ExecutionModePostDepthCoverage) ?
10363 "[[ early_fragment_tests ]] fragment" :
10364 "fragment";
10365 break;
10366 case ExecutionModelTessellationControl:
10367 if (!msl_options.supports_msl_version(1, 2))
10368 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
10369 if (execution.flags.get(ExecutionModeIsolines))
10370 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
10371 /* fallthrough */
10372 case ExecutionModelGLCompute:
10373 case ExecutionModelKernel:
10374 entry_type = "kernel";
10375 break;
10376 default:
10377 entry_type = "unknown";
10378 break;
10379 }
10380
10381 return entry_type + " " + return_type;
10382 }
10383
10384 // In MSL, address space qualifiers are required for all pointer or reference variables
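// e.g. "device", "const device", "constant", "threadgroup", or "thread"; the switch
// below maps each storage class to one of these.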
10385 string CompilerMSL::get_argument_address_space(const SPIRVariable &argument)
10386 {
10387 const auto &type = get<SPIRType>(argument.basetype);
10388 return get_type_address_space(type, argument.self, true);
10389 }
10390
10391 string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument)
10392 {
10393 // This can be called for variable pointer contexts as well, so be very careful about which method we choose.
10394 Bitset flags;
10395 auto *var = maybe_get<SPIRVariable>(id);
10396 if (var && type.basetype == SPIRType::Struct &&
10397 (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
10398 flags = get_buffer_block_flags(id);
10399 else
10400 flags = get_decoration_bitset(id);
10401
10402 const char *addr_space = nullptr;
10403 switch (type.storage)
10404 {
10405 case StorageClassWorkgroup:
10406 addr_space = "threadgroup";
10407 break;
10408
10409 case StorageClassStorageBuffer:
10410 {
10411 // For arguments that come from variable pointers, writability is deduced from the write count,
10412 // so we should not assume any constness here. Only global SSBOs are marked const based on NonWritable.
10413 bool readonly = false;
10414 if (!var || has_decoration(type.self, DecorationBlock))
10415 readonly = flags.get(DecorationNonWritable);
10416
10417 addr_space = readonly ? "const device" : "device";
10418 break;
10419 }
10420
10421 case StorageClassUniform:
10422 case StorageClassUniformConstant:
10423 case StorageClassPushConstant:
10424 if (type.basetype == SPIRType::Struct)
10425 {
10426 bool ssbo = has_decoration(type.self, DecorationBufferBlock);
10427 if (ssbo)
10428 addr_space = flags.get(DecorationNonWritable) ? "const device" : "device";
10429 else
10430 addr_space = "constant";
10431 }
10432 else if (!argument)
10433 {
10434 addr_space = "constant";
10435 }
10436 else if (type_is_msl_framebuffer_fetch(type))
10437 {
10438 // Subpass inputs are passed around by value.
10439 addr_space = "";
10440 }
10441 break;
10442
10443 case StorageClassFunction:
10444 case StorageClassGeneric:
10445 break;
10446
10447 case StorageClassInput:
10448 if (get_execution_model() == ExecutionModelTessellationControl && var &&
10449 var->basevariable == stage_in_ptr_var_id)
10450 addr_space = msl_options.multi_patch_workgroup ? "constant" : "threadgroup";
10451 if (get_execution_model() == ExecutionModelFragment && var && var->basevariable == stage_in_var_id)
10452 addr_space = "thread";
10453 break;
10454
10455 case StorageClassOutput:
10456 if (capture_output_to_buffer)
10457 addr_space = "device";
10458 break;
10459
10460 default:
10461 break;
10462 }
10463
10464 if (!addr_space)
10465 // No address space for plain values.
10466 addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : "";
10467
10468 return join(flags.get(DecorationVolatile) || flags.get(DecorationCoherent) ? "volatile " : "", addr_space);
10469 }
10470
10471 const char *CompilerMSL::to_restrict(uint32_t id, bool space)
10472 {
10473 // This can be called for variable pointer contexts as well, so be very careful about which method we choose.
10474 Bitset flags;
10475 if (ir.ids[id].get_type() == TypeVariable)
10476 {
10477 uint32_t type_id = expression_type_id(id);
10478 auto &type = expression_type(id);
10479 if (type.basetype == SPIRType::Struct &&
10480 (has_decoration(type_id, DecorationBlock) || has_decoration(type_id, DecorationBufferBlock)))
10481 flags = get_buffer_block_flags(id);
10482 else
10483 flags = get_decoration_bitset(id);
10484 }
10485 else
10486 flags = get_decoration_bitset(id);
10487
10488 return flags.get(DecorationRestrict) ? (space ? "restrict " : "restrict") : "";
10489 }
10490
10491 string CompilerMSL::entry_point_arg_stage_in()
10492 {
10493 string decl;
10494
10495 if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup)
10496 return decl;
10497
10498 // Stage-in structure
10499 uint32_t stage_in_id;
10500 if (get_execution_model() == ExecutionModelTessellationEvaluation)
10501 stage_in_id = patch_stage_in_var_id;
10502 else
10503 stage_in_id = stage_in_var_id;
10504
10505 if (stage_in_id)
10506 {
10507 auto &var = get<SPIRVariable>(stage_in_id);
10508 auto &type = get_variable_data_type(var);
10509
10510 add_resource_name(var.self);
10511 decl = join(type_to_glsl(type), " ", to_name(var.self), " [[stage_in]]");
10512 }
10513
10514 return decl;
10515 }
10516
10517 // Returns true if this input builtin should be a direct parameter on a shader function parameter list,
10518 // and false for builtins that should be passed or calculated some other way.
10519 bool CompilerMSL::is_direct_input_builtin(BuiltIn bi_type)
10520 {
10521 switch (bi_type)
10522 {
10523 // Vertex function in
10524 case BuiltInVertexId:
10525 case BuiltInVertexIndex:
10526 case BuiltInBaseVertex:
10527 case BuiltInInstanceId:
10528 case BuiltInInstanceIndex:
10529 case BuiltInBaseInstance:
10530 return get_execution_model() != ExecutionModelVertex || !msl_options.vertex_for_tessellation;
10531 // Tess. control function in
10532 case BuiltInPosition:
10533 case BuiltInPointSize:
10534 case BuiltInClipDistance:
10535 case BuiltInCullDistance:
10536 case BuiltInPatchVertices:
10537 return false;
10538 case BuiltInInvocationId:
10539 case BuiltInPrimitiveId:
10540 return get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup;
10541 // Tess. evaluation function in
10542 case BuiltInTessLevelInner:
10543 case BuiltInTessLevelOuter:
10544 return false;
10545 // Fragment function in
10546 case BuiltInSamplePosition:
10547 case BuiltInHelperInvocation:
10548 case BuiltInBaryCoordNV:
10549 case BuiltInBaryCoordNoPerspNV:
10550 return false;
10551 case BuiltInViewIndex:
10552 return get_execution_model() == ExecutionModelFragment && msl_options.multiview &&
10553 msl_options.multiview_layered_rendering;
10554 // Compute function in
10555 case BuiltInSubgroupId:
10556 case BuiltInNumSubgroups:
10557 return !msl_options.emulate_subgroups;
10558 // Any stage function in
10559 case BuiltInDeviceIndex:
10560 case BuiltInSubgroupEqMask:
10561 case BuiltInSubgroupGeMask:
10562 case BuiltInSubgroupGtMask:
10563 case BuiltInSubgroupLeMask:
10564 case BuiltInSubgroupLtMask:
10565 return false;
10566 case BuiltInSubgroupSize:
10567 if (msl_options.fixed_subgroup_size != 0)
10568 return false;
10569 /* fallthrough */
10570 case BuiltInSubgroupLocalInvocationId:
10571 return !msl_options.emulate_subgroups;
10572 default:
10573 return true;
10574 }
10575 }
10576
10577 // Returns true if this is a fragment shader that runs per sample, and false otherwise.
10578 bool CompilerMSL::is_sample_rate() const
10579 {
10580 auto &caps = get_declared_capabilities();
10581 return get_execution_model() == ExecutionModelFragment &&
10582 (msl_options.force_sample_rate_shading ||
10583 std::find(caps.begin(), caps.end(), CapabilitySampleRateShading) != caps.end() ||
10584 (msl_options.use_framebuffer_fetch_subpasses && need_subpass_input));
10585 }
10586
10587 void CompilerMSL::entry_point_args_builtin(string &ep_args)
10588 {
10589 // Builtin variables
10590 SmallVector<pair<SPIRVariable *, BuiltIn>, 8> active_builtins;
10591 ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
10592 if (var.storage != StorageClassInput)
10593 return;
10594
10595 auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
10596
10597 // Don't emit SamplePosition as a separate parameter. In the entry
10598 // point, we get that by calling get_sample_position() on the sample ID.
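// Direct builtin parameters emitted below look like, e.g. (illustrative):
//   "uint gl_VertexIndex [[vertex_id]]" or "float4 gl_FragCoord [[position]]"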
10599 if (is_builtin_variable(var) &&
10600 get_variable_data_type(var).basetype != SPIRType::Struct &&
10601 get_variable_data_type(var).basetype != SPIRType::ControlPointArray)
10602 {
10603 // If the builtin is not part of the active input builtin set, don't emit it.
10604 // Relevant for multiple entry-point modules which might declare unused builtins.
10605 if (!active_input_builtins.get(bi_type) || !interface_variable_exists_in_entry_point(var_id))
10606 return;
10607
10608 // Remember this variable. We may need to correct its type.
10609 active_builtins.push_back(make_pair(&var, bi_type));
10610
10611 if (is_direct_input_builtin(bi_type))
10612 {
10613 if (!ep_args.empty())
10614 ep_args += ", ";
10615
10616 // Handle HLSL-style 0-based vertex/instance index.
10617 builtin_declaration = true;
10618 ep_args += builtin_type_decl(bi_type, var_id) + " " + to_expression(var_id);
10619 ep_args += " [[" + builtin_qualifier(bi_type);
10620 if (bi_type == BuiltInSampleMask && get_entry_point().flags.get(ExecutionModePostDepthCoverage))
10621 {
10622 if (!msl_options.supports_msl_version(2))
10623 SPIRV_CROSS_THROW("Post-depth coverage requires MSL 2.0.");
10624 if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3))
10625 SPIRV_CROSS_THROW("Post-depth coverage on Mac requires MSL 2.3.");
10626 ep_args += ", post_depth_coverage";
10627 }
10628 ep_args += "]]";
10629 builtin_declaration = false;
10630 }
10631 }
10632
10633 if (has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase))
10634 {
10635 // This is a special implicit builtin, not corresponding to any SPIR-V builtin,
10636 // which holds the base that was passed to vkCmdDispatchBase() or vkCmdDrawIndexed(). If it's present,
10637 // assume we emitted it for a good reason.
10638 assert(msl_options.supports_msl_version(1, 2));
10639 if (!ep_args.empty())
10640 ep_args += ", ";
10641
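// Emits, e.g. (variable name is illustrative): "uint3 spvDispatchBase [[grid_origin]]"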
10642 ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_origin]]";
10643 }
10644
10645 if (has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize))
10646 {
10647 // This is another special implicit builtin, not corresponding to any SPIR-V builtin,
10648 // which holds the number of vertices and instances to draw. If it's present,
10649 // assume we emitted it for a good reason.
10650 assert(msl_options.supports_msl_version(1, 2));
10651 if (!ep_args.empty())
10652 ep_args += ", ";
10653
10654 ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_size]]";
10655 }
10656 });
10657
10658 // Correct the types of all encountered active builtins. We couldn't do this before
10659 // because ensure_correct_builtin_type() may increase the bound, which isn't allowed
10660 // while iterating over IDs.
10661 for (auto &var : active_builtins)
10662 var.first->basetype = ensure_correct_builtin_type(var.first->basetype, var.second);
10663
10664 // Handle HLSL-style 0-based vertex/instance index.
10665 if (needs_base_vertex_arg == TriState::Yes)
10666 ep_args += built_in_func_arg(BuiltInBaseVertex, !ep_args.empty());
10667
10668 if (needs_base_instance_arg == TriState::Yes)
10669 ep_args += built_in_func_arg(BuiltInBaseInstance, !ep_args.empty());
10670
10671 if (capture_output_to_buffer)
10672 {
10673 // Add parameters to hold the indirect draw parameters and the shader output. This has to be handled
10674 // specially because it needs to be a pointer, not a reference.
10675 if (stage_out_var_id)
10676 {
10677 if (!ep_args.empty())
10678 ep_args += ", ";
10679 ep_args += join("device ", type_to_glsl(get_stage_out_struct_type()), "* ", output_buffer_var_name,
10680 " [[buffer(", msl_options.shader_output_buffer_index, ")]]");
10681 }
10682
10683 if (get_execution_model() == ExecutionModelTessellationControl)
10684 {
10685 if (!ep_args.empty())
10686 ep_args += ", ";
10687 ep_args +=
10688 join("constant uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
10689 }
10690 else if (stage_out_var_id &&
10691 !(get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
10692 {
10693 if (!ep_args.empty())
10694 ep_args += ", ";
10695 ep_args +=
10696 join("device uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
10697 }
10698
10699 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation &&
10700 (active_input_builtins.get(BuiltInVertexIndex) || active_input_builtins.get(BuiltInVertexId)) &&
10701 msl_options.vertex_index_type != Options::IndexType::None)
10702 {
10703 // Add the index buffer so we can set gl_VertexIndex correctly.
10704 if (!ep_args.empty())
10705 ep_args += ", ";
10706 switch (msl_options.vertex_index_type)
10707 {
10708 case Options::IndexType::None:
10709 break;
10710 case Options::IndexType::UInt16:
10711 ep_args += join("const device ushort* ", index_buffer_var_name, " [[buffer(",
10712 msl_options.shader_index_buffer_index, ")]]");
10713 break;
10714 case Options::IndexType::UInt32:
10715 ep_args += join("const device uint* ", index_buffer_var_name, " [[buffer(",
10716 msl_options.shader_index_buffer_index, ")]]");
10717 break;
10718 }
10719 }
10720
10721 // Tessellation control shaders get three additional parameters:
10722 // a buffer to hold the per-patch data, a buffer to hold the per-patch
10723 // tessellation levels, and a block of workgroup memory to hold the
10724 // input control point data.
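// Illustrative emitted arguments, assuming the default buffer indices and the usual
// SPIRV-Cross names (all hypothetical for any given shader):
//   device main0_patchOut* spvPatchOut [[buffer(27)]],
//   device MTLTriangleTessellationFactorsHalf* spvTessLevel [[buffer(26)]],
//   threadgroup main0_in* gl_in [[threadgroup(0)]]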
10725 if (get_execution_model() == ExecutionModelTessellationControl)
10726 {
10727 if (patch_stage_out_var_id)
10728 {
10729 if (!ep_args.empty())
10730 ep_args += ", ";
10731 ep_args +=
10732 join("device ", type_to_glsl(get_patch_stage_out_struct_type()), "* ", patch_output_buffer_var_name,
10733 " [[buffer(", convert_to_string(msl_options.shader_patch_output_buffer_index), ")]]");
10734 }
10735 if (!ep_args.empty())
10736 ep_args += ", ";
10737 ep_args += join("device ", get_tess_factor_struct_name(), "* ", tess_factor_buffer_var_name, " [[buffer(",
10738 convert_to_string(msl_options.shader_tess_factor_buffer_index), ")]]");
10739
10740 // Initializer for tess factors must be handled specially since it's never declared as a normal variable.
10741 uint32_t outer_factor_initializer_id = 0;
10742 uint32_t inner_factor_initializer_id = 0;
10743 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
10744 if (!has_decoration(var.self, DecorationBuiltIn) || var.storage != StorageClassOutput || !var.initializer)
10745 return;
10746
10747 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
10748 if (builtin == BuiltInTessLevelInner)
10749 inner_factor_initializer_id = var.initializer;
10750 else if (builtin == BuiltInTessLevelOuter)
10751 outer_factor_initializer_id = var.initializer;
10752 });
10753
10754 const SPIRConstant *c = nullptr;
10755
10756 if (outer_factor_initializer_id && (c = maybe_get<SPIRConstant>(outer_factor_initializer_id)))
10757 {
10758 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
10759 entry_func.fixup_hooks_in.push_back([=]() {
10760 uint32_t components = get_execution_mode_bitset().get(ExecutionModeTriangles) ? 3 : 4;
10761 for (uint32_t i = 0; i < components; i++)
10762 {
10763 statement(builtin_to_glsl(BuiltInTessLevelOuter, StorageClassOutput), "[", i, "] = ",
10764 "half(", to_expression(c->subconstants[i]), ");");
10765 }
10766 });
10767 }
10768
10769 if (inner_factor_initializer_id && (c = maybe_get<SPIRConstant>(inner_factor_initializer_id)))
10770 {
10771 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
10772 if (get_execution_mode_bitset().get(ExecutionModeTriangles))
10773 {
10774 entry_func.fixup_hooks_in.push_back([=]() {
10775 statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), " = ", "half(",
10776 to_expression(c->subconstants[0]), ");");
10777 });
10778 }
10779 else
10780 {
10781 entry_func.fixup_hooks_in.push_back([=]() {
10782 for (uint32_t i = 0; i < 2; i++)
10783 {
10784 statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), "[", i, "] = ",
10785 "half(", to_expression(c->subconstants[i]), ");");
10786 }
10787 });
10788 }
10789 }
10790
10791 if (stage_in_var_id)
10792 {
10793 if (!ep_args.empty())
10794 ep_args += ", ";
10795 if (msl_options.multi_patch_workgroup)
10796 {
10797 ep_args += join("device ", type_to_glsl(get_stage_in_struct_type()), "* ", input_buffer_var_name,
10798 " [[buffer(", convert_to_string(msl_options.shader_input_buffer_index), ")]]");
10799 }
10800 else
10801 {
10802 ep_args += join("threadgroup ", type_to_glsl(get_stage_in_struct_type()), "* ", input_wg_var_name,
10803 " [[threadgroup(", convert_to_string(msl_options.shader_input_wg_index), ")]]");
10804 }
10805 }
10806 }
10807 }
10808 }
10809
10810 string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
10811 {
10812 string ep_args = entry_point_arg_stage_in();
10813 Bitset claimed_bindings;
10814
10815 for (uint32_t i = 0; i < kMaxArgumentBuffers; i++)
10816 {
10817 uint32_t id = argument_buffer_ids[i];
10818 if (id == 0)
10819 continue;
10820
10821 add_resource_name(id);
10822 auto &var = get<SPIRVariable>(id);
10823 auto &type = get_variable_data_type(var);
10824
10825 if (!ep_args.empty())
10826 ep_args += ", ";
10827
10828 // Check if the argument buffer binding itself has been remapped.
10829 uint32_t buffer_binding;
10830 auto itr = resource_bindings.find({ get_entry_point().model, i, kArgumentBufferBinding });
10831 if (itr != end(resource_bindings))
10832 {
10833 buffer_binding = itr->second.first.msl_buffer;
10834 itr->second.second = true;
10835 }
10836 else
10837 {
10838 // As a fallback, directly map desc set <-> binding.
10839 // If that was taken, take the next buffer binding.
10840 if (claimed_bindings.get(i))
10841 buffer_binding = next_metal_resource_index_buffer;
10842 else
10843 buffer_binding = i;
10844 }
10845
10846 claimed_bindings.set(buffer_binding);
10847
10848 ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id) + to_name(id);
10849 ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";
10850
10851 next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
10852 }
10853
10854 entry_point_args_discrete_descriptors(ep_args);
10855 entry_point_args_builtin(ep_args);
10856
10857 if (!ep_args.empty() && append_comma)
10858 ep_args += ", ";
10859
10860 return ep_args;
10861 }
10862
10863 const MSLConstexprSampler *CompilerMSL::find_constexpr_sampler(uint32_t id) const
10864 {
10865 // Try by ID.
10866 {
10867 auto itr = constexpr_samplers_by_id.find(id);
10868 if (itr != end(constexpr_samplers_by_id))
10869 return &itr->second;
10870 }
10871
10872 // Try by binding.
10873 {
10874 uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
10875 uint32_t binding = get_decoration(id, DecorationBinding);
10876
10877 auto itr = constexpr_samplers_by_binding.find({ desc_set, binding });
10878 if (itr != end(constexpr_samplers_by_binding))
10879 return &itr->second;
10880 }
10881
10882 return nullptr;
10883 }
10884
10885 void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
10886 {
10887 // Output resources, sorted by resource index & type
10888 // We need to sort to work around a bug on macOS 10.13 with NVidia drivers where switching between shaders
10889 // with different buffer orderings can result in incorrect buffer assignments inside the driver.
10890 struct Resource
10891 {
10892 SPIRVariable *var;
10893 string name;
10894 SPIRType::BaseType basetype;
10895 uint32_t index;
10896 uint32_t plane;
10897 uint32_t secondary_index;
10898 };
10899
10900 SmallVector<Resource> resources;
10901
10902 ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
10903 if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
10904 var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer) &&
10905 !is_hidden_variable(var))
10906 {
10907 auto &type = get_variable_data_type(var);
10908
10909 if (is_supported_argument_buffer_type(type) && var.storage != StorageClassPushConstant)
10910 {
10911 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
10912 if (descriptor_set_is_argument_buffer(desc_set))
10913 return;
10914 }
10915
10916 const MSLConstexprSampler *constexpr_sampler = nullptr;
10917 if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
10918 {
10919 constexpr_sampler = find_constexpr_sampler(var_id);
10920 if (constexpr_sampler)
10921 {
10922 // Mark this ID as a constexpr sampler for later in case it came from set/bindings.
10923 constexpr_samplers_by_id[var_id] = *constexpr_sampler;
10924 }
10925 }
10926
10927 // Emulate texture2D atomic operations
10928 uint32_t secondary_index = 0;
10929 if (atomic_image_vars.count(var.self))
10930 {
10931 secondary_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
10932 }
10933
10934 if (type.basetype == SPIRType::SampledImage)
10935 {
10936 add_resource_name(var_id);
10937
10938 uint32_t plane_count = 1;
10939 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
10940 plane_count = constexpr_sampler->planes;
10941
10942 for (uint32_t i = 0; i < plane_count; i++)
10943 resources.push_back({ &var, to_name(var_id), SPIRType::Image,
10944 get_metal_resource_index(var, SPIRType::Image, i), i, secondary_index });
10945
10946 if (type.image.dim != DimBuffer && !constexpr_sampler)
10947 {
10948 resources.push_back({ &var, to_sampler_expression(var_id), SPIRType::Sampler,
10949 get_metal_resource_index(var, SPIRType::Sampler), 0, 0 });
10950 }
10951 }
10952 else if (!constexpr_sampler)
10953 {
10954 // constexpr samplers are not declared as resources.
10955 add_resource_name(var_id);
10956 resources.push_back({ &var, to_name(var_id), type.basetype,
10957 get_metal_resource_index(var, type.basetype), 0, secondary_index });
10958 }
10959 }
10960 });
10961
10962 sort(resources.begin(), resources.end(), [](const Resource &lhs, const Resource &rhs) {
10963 return tie(lhs.basetype, lhs.index) < tie(rhs.basetype, rhs.index);
10964 });
10965
10966 for (auto &r : resources)
10967 {
10968 auto &var = *r.var;
10969 auto &type = get_variable_data_type(var);
10970
10971 uint32_t var_id = var.self;
10972
10973 switch (r.basetype)
10974 {
10975 case SPIRType::Struct:
10976 {
10977 auto &m = ir.meta[type.self];
10978 if (m.members.size() == 0)
10979 break;
10980 if (!type.array.empty())
10981 {
10982 if (type.array.size() > 1)
10983 SPIRV_CROSS_THROW("Arrays of arrays of buffers are not supported.");
10984
10985 // Metal doesn't directly support this, so we must expand the
10986 // array. We'll declare a local array to hold these elements
10987 // later.
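// e.g. (illustrative) a two-element buffer array "b" bound at index 4 becomes two
// parameters: "device B* b_0 [[buffer(4)]], device B* b_1 [[buffer(5)]]"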
10988 uint32_t array_size = to_array_size_literal(type);
10989
10990 if (array_size == 0)
10991 SPIRV_CROSS_THROW("Unsized arrays of buffers are not supported in MSL.");
10992
10993 // Allow Metal to use the array<T> template to make arrays a value type
10994 is_using_builtin_array = true;
10995 buffer_arrays.push_back(var_id);
10996 for (uint32_t i = 0; i < array_size; ++i)
10997 {
10998 if (!ep_args.empty())
10999 ep_args += ", ";
11000 ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + to_restrict(var_id) +
11001 r.name + "_" + convert_to_string(i);
11002 ep_args += " [[buffer(" + convert_to_string(r.index + i) + ")";
11003 if (interlocked_resources.count(var_id))
11004 ep_args += ", raster_order_group(0)";
11005 ep_args += "]]";
11006 }
11007 is_using_builtin_array = false;
11008 }
11009 else
11010 {
11011 if (!ep_args.empty())
11012 ep_args += ", ";
11013 ep_args +=
11014 get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(var_id) + r.name;
11015 ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
11016 if (interlocked_resources.count(var_id))
11017 ep_args += ", raster_order_group(0)";
11018 ep_args += "]]";
11019 }
11020 break;
11021 }
11022 case SPIRType::Sampler:
11023 if (!ep_args.empty())
11024 ep_args += ", ";
11025 ep_args += sampler_type(type, var_id) + " " + r.name;
11026 ep_args += " [[sampler(" + convert_to_string(r.index) + ")]]";
11027 break;
11028 case SPIRType::Image:
11029 {
11030 if (!ep_args.empty())
11031 ep_args += ", ";
11032
11033 // Use Metal's native frame-buffer fetch API for subpass inputs.
11034 const auto &basetype = get<SPIRType>(var.basetype);
11035 if (!type_is_msl_framebuffer_fetch(basetype))
11036 {
11037 ep_args += image_type_glsl(type, var_id) + " " + r.name;
11038 if (r.plane > 0)
11039 ep_args += join(plane_name_suffix, r.plane);
11040 ep_args += " [[texture(" + convert_to_string(r.index) + ")";
11041 if (interlocked_resources.count(var_id))
11042 ep_args += ", raster_order_group(0)";
11043 ep_args += "]]";
11044 }
11045 else
11046 {
11047 if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3))
11048 SPIRV_CROSS_THROW("Framebuffer fetch on Mac is not supported before MSL 2.3.");
11049 ep_args += image_type_glsl(type, var_id) + " " + r.name;
11050 ep_args += " [[color(" + convert_to_string(r.index) + ")]]";
11051 }
11052
11053 // Emulate texture2D atomic operations
11054 if (atomic_image_vars.count(var.self))
11055 {
11056 ep_args += ", device atomic_" + type_to_glsl(get<SPIRType>(basetype.image.type), 0);
11057 ep_args += "* " + r.name + "_atomic";
11058 ep_args += " [[buffer(" + convert_to_string(r.secondary_index) + ")";
11059 if (interlocked_resources.count(var_id))
11060 ep_args += ", raster_order_group(0)";
11061 ep_args += "]]";
11062 }
11063 break;
11064 }
11065 default:
11066 if (!ep_args.empty())
11067 ep_args += ", ";
11068 if (!type.pointer)
11069 ep_args += get_type_address_space(get<SPIRType>(var.basetype), var_id) + " " +
11070 type_to_glsl(type, var_id) + "& " + r.name;
11071 else
11072 ep_args += type_to_glsl(type, var_id) + " " + r.name;
11073 ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
11074 if (interlocked_resources.count(var_id))
11075 ep_args += ", raster_order_group(0)";
11076 ep_args += "]]";
11077 break;
11078 }
11079 }
11080 }
11081
11082 // Returns a string containing a comma-delimited list of args for the entry point function
11083 // This is the "classic" method of MSL 1 when we don't have argument buffer support.
11084 string CompilerMSL::entry_point_args_classic(bool append_comma)
11085 {
11086 string ep_args = entry_point_arg_stage_in();
11087 entry_point_args_discrete_descriptors(ep_args);
11088 entry_point_args_builtin(ep_args);
11089
11090 if (!ep_args.empty() && append_comma)
11091 ep_args += ", ";
11092
11093 return ep_args;
11094 }
11095
11096 void CompilerMSL::fix_up_shader_inputs_outputs()
11097 {
11098 auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
11099
11100 // Emit a guard to ensure we don't execute beyond the last vertex.
11101 // Vertex shaders shouldn't have the problems with barriers in non-uniform control flow that
11102 // tessellation control shaders do, so early returns should be OK. We may need to revisit this
11103 // if it ever becomes possible to use barriers from a vertex shader.
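// The emitted guard looks like, e.g. (names illustrative):
//   if (any(gl_GlobalInvocationID >= spvStageInputSize))
//       return;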
11104 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
11105 {
11106 entry_func.fixup_hooks_in.push_back([this]() {
11107 statement("if (any(", to_expression(builtin_invocation_id_id),
11108 " >= ", to_expression(builtin_stage_input_size_id), "))");
11109 statement(" return;");
11110 });
11111 }
11112
11113 // Look for sampled images and buffers. Add hooks to set up the swizzle constants or array lengths.
11114 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
11115 auto &type = get_variable_data_type(var);
11116 uint32_t var_id = var.self;
11117 bool ssbo = has_decoration(type.self, DecorationBufferBlock);
11118
11119 if (var.storage == StorageClassUniformConstant && !is_hidden_variable(var))
11120 {
11121 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
11122 {
11123 entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
11124 bool is_array_type = !type.array.empty();
11125
11126 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
11127 if (descriptor_set_is_argument_buffer(desc_set))
11128 {
11129 statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
11130 is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
11131 ".spvSwizzleConstants", "[",
11132 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
11133 }
11134 else
11135 {
11136 // If we have an array of images, we need to be able to index into it, so take a pointer instead.
11137 statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
11138 is_array_type ? " = &" : " = ", to_name(swizzle_buffer_id), "[",
11139 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
11140 }
11141 });
11142 }
11143 }
11144 else if ((var.storage == StorageClassStorageBuffer || (var.storage == StorageClassUniform && ssbo)) &&
11145 !is_hidden_variable(var))
11146 {
11147 if (buffers_requiring_array_length.count(var.self))
11148 {
11149 entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
11150 bool is_array_type = !type.array.empty();
11151
11152 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
11153 if (descriptor_set_is_argument_buffer(desc_set))
11154 {
11155 statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
11156 is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
11157 ".spvBufferSizeConstants", "[",
11158 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
11159 }
11160 else
11161 {
11162 // If we have an array of buffers, we need to be able to index into it, so take a pointer instead.
11163 statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
11164 is_array_type ? " = &" : " = ", to_name(buffer_size_buffer_id), "[",
11165 convert_to_string(get_metal_resource_index(var, type.basetype)), "];");
11166 }
11167 });
11168 }
11169 }
11170 });
11171
11172 // Builtin variables
11173 ir.for_each_typed_id<SPIRVariable>([this, &entry_func](uint32_t, SPIRVariable &var) {
11174 uint32_t var_id = var.self;
11175 BuiltIn bi_type = ir.meta[var_id].decoration.builtin_type;
11176
11177 if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
11178 return;
11179 if (!interface_variable_exists_in_entry_point(var.self))
11180 return;
11181
11182 if (var.storage == StorageClassInput && is_builtin_variable(var) && active_input_builtins.get(bi_type))
11183 {
11184 switch (bi_type)
11185 {
11186 case BuiltInSamplePosition:
11187 entry_func.fixup_hooks_in.push_back([=]() {
11188 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = get_sample_position(",
11189 to_expression(builtin_sample_id_id), ");");
11190 });
11191 break;
11192 case BuiltInFragCoord:
11193 if (is_sample_rate())
11194 {
11195 entry_func.fixup_hooks_in.push_back([=]() {
11196 statement(to_expression(var_id), ".xy += get_sample_position(",
11197 to_expression(builtin_sample_id_id), ") - 0.5;");
11198 });
11199 }
11200 break;
11201 case BuiltInHelperInvocation:
11202 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
11203 SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.3 on iOS.");
11204 else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
11205 SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
11206
11207 entry_func.fixup_hooks_in.push_back([=]() {
11208 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_is_helper_thread();");
11209 });
11210 break;
11211 case BuiltInInvocationId:
11212 // This is direct-mapped without multi-patch workgroups.
11213 if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
11214 break;
11215
11216 entry_func.fixup_hooks_in.push_back([=]() {
11217 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11218 to_expression(builtin_invocation_id_id), ".x % ", this->get_entry_point().output_vertices,
11219 ";");
11220 });
11221 break;
11222 case BuiltInPrimitiveId:
11223 // This is natively supported by fragment and tessellation evaluation shaders.
11224 // In tessellation control shaders, this is direct-mapped without multi-patch workgroups.
11225 if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
11226 break;
11227
11228 entry_func.fixup_hooks_in.push_back([=]() {
11229 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = min(",
11230 to_expression(builtin_invocation_id_id), ".x / ", this->get_entry_point().output_vertices,
11231 ", spvIndirectParams[1]);");
11232 });
11233 break;
11234 case BuiltInPatchVertices:
11235 if (get_execution_model() == ExecutionModelTessellationEvaluation)
11236 entry_func.fixup_hooks_in.push_back([=]() {
11237 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11238 to_expression(patch_stage_in_var_id), ".gl_in.size();");
11239 });
11240 else
11241 entry_func.fixup_hooks_in.push_back([=]() {
11242 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = spvIndirectParams[0];");
11243 });
11244 break;
11245 case BuiltInTessCoord:
11246 // Emit a fixup to account for the shifted domain. Don't do this for triangles;
11247 // MoltenVK will just reverse the winding order instead.
11248 if (msl_options.tess_domain_origin_lower_left && !get_entry_point().flags.get(ExecutionModeTriangles))
11249 {
11250 string tc = to_expression(var_id);
11251 entry_func.fixup_hooks_in.push_back([=]() { statement(tc, ".y = 1.0 - ", tc, ".y;"); });
11252 }
11253 break;
11254 case BuiltInSubgroupId:
11255 if (!msl_options.emulate_subgroups)
11256 break;
11257 // For subgroup emulation, this is the same as the local invocation index.
11258 entry_func.fixup_hooks_in.push_back([=]() {
11259 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11260 to_expression(builtin_local_invocation_index_id), ";");
11261 });
11262 break;
11263 case BuiltInNumSubgroups:
11264 if (!msl_options.emulate_subgroups)
11265 break;
11266 // For subgroup emulation, this is the same as the workgroup size.
11267 entry_func.fixup_hooks_in.push_back([=]() {
11268 auto &type = expression_type(builtin_workgroup_size_id);
11269 string size_expr = to_expression(builtin_workgroup_size_id);
11270 if (type.vecsize >= 3)
11271 size_expr = join(size_expr, ".x * ", size_expr, ".y * ", size_expr, ".z");
11272 else if (type.vecsize == 2)
11273 size_expr = join(size_expr, ".x * ", size_expr, ".y");
11274 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", size_expr, ";");
11275 });
11276 break;
11277 case BuiltInSubgroupLocalInvocationId:
11278 if (!msl_options.emulate_subgroups)
11279 break;
11280 // For subgroup emulation, assume subgroups of size 1.
11281 entry_func.fixup_hooks_in.push_back(
11282 [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;"); });
11283 break;
11284 case BuiltInSubgroupSize:
11285 if (msl_options.emulate_subgroups)
11286 {
11287 // For subgroup emulation, assume subgroups of size 1.
11288 entry_func.fixup_hooks_in.push_back(
11289 [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 1;"); });
11290 }
11291 else if (msl_options.fixed_subgroup_size != 0)
11292 {
11293 entry_func.fixup_hooks_in.push_back([=]() {
11294 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11295 msl_options.fixed_subgroup_size, ";");
11296 });
11297 }
11298 break;
11299 case BuiltInSubgroupEqMask:
11300 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
11301 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
11302 if (!msl_options.supports_msl_version(2, 1))
11303 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
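// e.g. for invocation index 5 (illustrative), the iOS path below yields
// uint4(1 << 5, uint3(0)) == uint4(32u, 0, 0, 0).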
11304 entry_func.fixup_hooks_in.push_back([=]() {
11305 if (msl_options.is_ios())
11306 {
11307 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", "uint4(1 << ",
11308 to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
11309 }
11310 else
11311 {
11312 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11313 to_expression(builtin_subgroup_invocation_id_id), " >= 32 ? uint4(0, (1 << (",
11314 to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
11315 to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
11316 }
11317 });
11318 break;
11319 case BuiltInSubgroupGeMask:
11320 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
11321 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
11322 if (!msl_options.supports_msl_version(2, 1))
11323 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
11324 if (msl_options.fixed_subgroup_size != 0)
11325 add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
11326 entry_func.fixup_hooks_in.push_back([=]() {
11327 // Case where index < 32, size < 32:
11328 // mask0 = bfi(0, 0xFFFFFFFF, index, size - index);
11329 // mask1 = bfi(0, 0xFFFFFFFF, 0, 0); // Gives 0
11330 // Case where index < 32 but size >= 32:
11331 // mask0 = bfi(0, 0xFFFFFFFF, index, 32 - index);
11332 // mask1 = bfi(0, 0xFFFFFFFF, 0, size - 32);
11333 // Case where index >= 32:
11334 // mask0 = bfi(0, 0xFFFFFFFF, 32, 0); // Gives 0
11335 // mask1 = bfi(0, 0xFFFFFFFF, index - 32, size - index);
11336 // This is expressed without branches to avoid divergent
11337 // control flow--hence the complicated min/max expressions.
11338 // This is further complicated by the fact that if you attempt
11339 // to bfi/bfe out-of-bounds on Metal, undefined behavior is the
11340 // result.
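// Worked example, assuming invocation index 5 in a subgroup of size 16:
//   mask0 = insert_bits(0u, 0xFFFFFFFF, 5, 11) = 0x0000FFE0 (lanes 5..15 set)
//   mask1 = 0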
11341 if (msl_options.fixed_subgroup_size > 32)
11342 {
11343 // Don't use the subgroup size variable with fixed subgroup sizes,
11344 // since the variables could be defined in the wrong order.
11345 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11346 " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
11347 to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(32 - (int)",
11348 to_expression(builtin_subgroup_invocation_id_id),
11349 ", 0)), insert_bits(0u, 0xFFFFFFFF,"
11350 " (uint)max((int)",
11351 to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), ",
11352 msl_options.fixed_subgroup_size, " - max(",
11353 to_expression(builtin_subgroup_invocation_id_id),
11354 ", 32u)), uint2(0));");
11355 }
11356 else if (msl_options.fixed_subgroup_size != 0)
11357 {
11358 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11359 " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
11360 to_expression(builtin_subgroup_invocation_id_id), ", ",
11361 msl_options.fixed_subgroup_size, " - ",
11362 to_expression(builtin_subgroup_invocation_id_id),
11363 "), uint3(0));");
11364 }
11365 else if (msl_options.is_ios())
11366 {
11367 // On iOS, the SIMD-group size will currently never exceed 32.
11368 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11369 " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
11370 to_expression(builtin_subgroup_invocation_id_id), ", ",
11371 to_expression(builtin_subgroup_size_id), " - ",
11372 to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));");
11373 }
11374 else
11375 {
11376 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11377 " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
11378 to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
11379 to_expression(builtin_subgroup_size_id), ", 32) - (int)",
11380 to_expression(builtin_subgroup_invocation_id_id),
11381 ", 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
11382 to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
11383 to_expression(builtin_subgroup_size_id), " - (int)max(",
11384 to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
11385 }
11386 });
11387 break;
11388 case BuiltInSubgroupGtMask:
11389 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
11390 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
11391 if (!msl_options.supports_msl_version(2, 1))
11392 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
11393 add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
11394 entry_func.fixup_hooks_in.push_back([=]() {
11395 // The same logic applies here, except now the index is one
11396 // more than the subgroup invocation ID.
11397 if (msl_options.fixed_subgroup_size > 32)
11398 {
11399 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11400 " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
11401 to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(32 - (int)",
11402 to_expression(builtin_subgroup_invocation_id_id),
11403 " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
11404 to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), ",
11405 msl_options.fixed_subgroup_size, " - max(",
11406 to_expression(builtin_subgroup_invocation_id_id),
11407 " + 1, 32u)), uint2(0));");
11408 }
11409 else if (msl_options.fixed_subgroup_size != 0)
11410 {
11411 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11412 " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
11413 to_expression(builtin_subgroup_invocation_id_id), " + 1, ",
11414 msl_options.fixed_subgroup_size, " - ",
11415 to_expression(builtin_subgroup_invocation_id_id),
11416 " - 1), uint3(0));");
11417 }
11418 else if (msl_options.is_ios())
11419 {
11420 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11421 " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
11422 to_expression(builtin_subgroup_invocation_id_id), " + 1, ",
11423 to_expression(builtin_subgroup_size_id), " - ",
11424 to_expression(builtin_subgroup_invocation_id_id), " - 1), uint3(0));");
11425 }
11426 else
11427 {
11428 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11429 " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
11430 to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
11431 to_expression(builtin_subgroup_size_id), ", 32) - (int)",
11432 to_expression(builtin_subgroup_invocation_id_id),
11433 " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
11434 to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
11435 to_expression(builtin_subgroup_size_id), " - (int)max(",
11436 to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");
11437 }
11438 });
11439 break;
11440 case BuiltInSubgroupLeMask:
11441 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
11442 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
11443 if (!msl_options.supports_msl_version(2, 1))
11444 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
11445 add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
11446 entry_func.fixup_hooks_in.push_back([=]() {
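// Since the SIMD-group size never exceeds 32 on iOS, a single 32-bit
// extract_bits suffices there. Illustrative emitted MSL (names approximate):
//   uint4 gl_SubgroupLeMask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID + 1), uint3(0));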
11447 if (msl_options.is_ios())
11448 {
11449 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11450 " = uint4(extract_bits(0xFFFFFFFF, 0, ",
11451 to_expression(builtin_subgroup_invocation_id_id), " + 1), uint3(0));");
11452 }
11453 else
11454 {
11455 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11456 " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
11457 to_expression(builtin_subgroup_invocation_id_id),
11458 " + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
11459 to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0)), uint2(0));");
11460 }
11461 });
11462 break;
11463 case BuiltInSubgroupLtMask:
11464 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
11465 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
11466 if (!msl_options.supports_msl_version(2, 1))
11467 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
11468 add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
11469 entry_func.fixup_hooks_in.push_back([=]() {
11470 if (msl_options.is_ios())
11471 {
11472 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11473 " = uint4(extract_bits(0xFFFFFFFF, 0, ",
11474 to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));");
11475 }
11476 else
11477 {
11478 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
11479 " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
11480 to_expression(builtin_subgroup_invocation_id_id),
11481 ", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
11482 to_expression(builtin_subgroup_invocation_id_id), " - 32, 0)), uint2(0));");
11483 }
11484 });
11485 break;
11486 case BuiltInViewIndex:
11487 if (!msl_options.multiview)
11488 {
11489 // According to the Vulkan spec, when not running under a multiview
11490 // render pass, ViewIndex is 0.
11491 entry_func.fixup_hooks_in.push_back([=]() {
11492 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;");
11493 });
11494 }
11495 else if (msl_options.view_index_from_device_index)
11496 {
11497 // In this case, we take the view index from that of the device we're running on.
11498 entry_func.fixup_hooks_in.push_back([=]() {
11499 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11500 msl_options.device_index, ";");
11501 });
11502 // We actually don't want to set the render_target_array_index here.
11503 // Since every physical device is rendering a different view,
11504 // there's no need for layered rendering.
11505 }
11506 else if (!msl_options.multiview_layered_rendering)
11507 {
11508 // In this case, the views are rendered one at a time. The view index, then,
11509 // is just the first part of the "view mask".
11510 entry_func.fixup_hooks_in.push_back([=]() {
11511 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11512 to_expression(view_mask_buffer_id), "[0];");
11513 });
11514 }
11515 else if (get_execution_model() == ExecutionModelFragment)
11516 {
11517 // Because we adjusted the view index in the vertex shader, we have to
11518 // adjust it back here.
11519 entry_func.fixup_hooks_in.push_back([=]() {
11520 statement(to_expression(var_id), " += ", to_expression(view_mask_buffer_id), "[0];");
11521 });
11522 }
11523 else if (get_execution_model() == ExecutionModelVertex)
11524 {
11525 // Metal provides no special support for multiview, so we smuggle
11526 // the view index in the instance index.
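// Illustrative example (names approximate): with a view mask buffer holding
// { base = 2, count = 3 } and instance index i, the hooks below emit roughly:
//   gl_ViewIndex = 2 + (i - gl_BaseInstance) % 3;
//   i = (i - gl_BaseInstance) / 3 + gl_BaseInstance;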
11527 entry_func.fixup_hooks_in.push_back([=]() {
11528 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11529 to_expression(view_mask_buffer_id), "[0] + (", to_expression(builtin_instance_idx_id),
11530 " - ", to_expression(builtin_base_instance_id), ") % ",
11531 to_expression(view_mask_buffer_id), "[1];");
11532 statement(to_expression(builtin_instance_idx_id), " = (",
11533 to_expression(builtin_instance_idx_id), " - ",
11534 to_expression(builtin_base_instance_id), ") / ", to_expression(view_mask_buffer_id),
11535 "[1] + ", to_expression(builtin_base_instance_id), ";");
11536 });
11537 // In addition to setting the variable itself, we also need to
11538 // set the render_target_array_index with it on output. We have to
11539 // offset this by the base view index, because Metal isn't in on
11540 // our little game here.
11541 entry_func.fixup_hooks_out.push_back([=]() {
11542 statement(to_expression(builtin_layer_id), " = ", to_expression(var_id), " - ",
11543 to_expression(view_mask_buffer_id), "[0];");
11544 });
11545 }
11546 break;
11547 case BuiltInDeviceIndex:
11548 // Metal pipelines belong to the devices which create them, so we'll
11549 // need to create an MTLPipelineState for every MTLDevice in a grouped
11550 // VkDevice. We can assume, then, that the device index is constant.
11551 entry_func.fixup_hooks_in.push_back([=]() {
11552 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11553 msl_options.device_index, ";");
11554 });
11555 break;
11556 case BuiltInWorkgroupId:
11557 if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInWorkgroupId))
11558 break;
11559
11560 // The vkCmdDispatchBase() command lets the client set the base value
11561 // of WorkgroupId. Metal has no direct equivalent; we must make this
11562 // adjustment ourselves.
11563 entry_func.fixup_hooks_in.push_back([=]() {
11564 statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), ";");
11565 });
11566 break;
11567 case BuiltInGlobalInvocationId:
11568 if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInGlobalInvocationId))
11569 break;
11570
11571 // GlobalInvocationId is defined as LocalInvocationId + WorkgroupId * WorkgroupSize.
11572 // This needs to be adjusted too.
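// For example (names and workgroup size illustrative), with a literal
// workgroup size of 8x8x1 the hook emits:
//   gl_GlobalInvocationID += spvDispatchBase * uint3(8, 8, 1);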
11573 entry_func.fixup_hooks_in.push_back([=]() {
11574 auto &execution = this->get_entry_point();
11575 uint32_t workgroup_size_id = execution.workgroup_size.constant;
11576 if (workgroup_size_id)
11577 statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
11578 " * ", to_expression(workgroup_size_id), ";");
11579 else
11580 statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
11581 " * uint3(", execution.workgroup_size.x, ", ", execution.workgroup_size.y, ", ",
11582 execution.workgroup_size.z, ");");
11583 });
11584 break;
11585 case BuiltInVertexId:
11586 case BuiltInVertexIndex:
11587 // This is direct-mapped normally.
11588 if (!msl_options.vertex_for_tessellation)
11589 break;
11590
11591 entry_func.fixup_hooks_in.push_back([=]() {
11592 builtin_declaration = true;
11593 switch (msl_options.vertex_index_type)
11594 {
11595 case Options::IndexType::None:
11596 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11597 to_expression(builtin_invocation_id_id), ".x + ",
11598 to_expression(builtin_dispatch_base_id), ".x;");
11599 break;
11600 case Options::IndexType::UInt16:
11601 case Options::IndexType::UInt32:
11602 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", index_buffer_var_name,
11603 "[", to_expression(builtin_invocation_id_id), ".x] + ",
11604 to_expression(builtin_dispatch_base_id), ".x;");
11605 break;
11606 }
11607 builtin_declaration = false;
11608 });
11609 break;
11610 case BuiltInBaseVertex:
11611 // This is direct-mapped normally.
11612 if (!msl_options.vertex_for_tessellation)
11613 break;
11614
11615 entry_func.fixup_hooks_in.push_back([=]() {
11616 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11617 to_expression(builtin_dispatch_base_id), ".x;");
11618 });
11619 break;
11620 case BuiltInInstanceId:
11621 case BuiltInInstanceIndex:
11622 // This is direct-mapped normally.
11623 if (!msl_options.vertex_for_tessellation)
11624 break;
11625
11626 entry_func.fixup_hooks_in.push_back([=]() {
11627 builtin_declaration = true;
11628 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11629 to_expression(builtin_invocation_id_id), ".y + ", to_expression(builtin_dispatch_base_id),
11630 ".y;");
11631 builtin_declaration = false;
11632 });
11633 break;
11634 case BuiltInBaseInstance:
11635 // This is direct-mapped normally.
11636 if (!msl_options.vertex_for_tessellation)
11637 break;
11638
11639 entry_func.fixup_hooks_in.push_back([=]() {
11640 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11641 to_expression(builtin_dispatch_base_id), ".y;");
11642 });
11643 break;
11644 default:
11645 break;
11646 }
11647 }
11648 else if (var.storage == StorageClassOutput && is_builtin_variable(var) && active_output_builtins.get(bi_type))
11649 {
11650 if (bi_type == BuiltInSampleMask && get_execution_model() == ExecutionModelFragment &&
11651 msl_options.additional_fixed_sample_mask != 0xffffffff)
11652 {
11653 // If the additional fixed sample mask was set, we need to adjust the sample_mask
11654 // output to reflect that. If the shader outputs the sample_mask itself too, we need
11655 // to AND the two masks to get the final one.
11656 if (does_shader_write_sample_mask)
11657 {
11658 entry_func.fixup_hooks_out.push_back([=]() {
11659 statement(to_expression(builtin_sample_mask_id),
11660 " &= ", msl_options.additional_fixed_sample_mask, ";");
11661 });
11662 }
11663 else
11664 {
11665 entry_func.fixup_hooks_out.push_back([=]() {
11666 statement(to_expression(builtin_sample_mask_id), " = ",
11667 msl_options.additional_fixed_sample_mask, ";");
11668 });
11669 }
11670 }
11671 }
11672 });
11673 }
11674
11675 // Returns the Metal index of the resource of the specified type as used by the specified variable.
11676 uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype, uint32_t plane)
11677 {
11678 auto &execution = get_entry_point();
11679 auto &var_dec = ir.meta[var.self].decoration;
11680 auto &var_type = get<SPIRType>(var.basetype);
11681 uint32_t var_desc_set = (var.storage == StorageClassPushConstant) ? kPushConstDescSet : var_dec.set;
11682 uint32_t var_binding = (var.storage == StorageClassPushConstant) ? kPushConstBinding : var_dec.binding;
11683
11684 // If a matching binding has been specified, find and use it.
11685 auto itr = resource_bindings.find({ execution.model, var_desc_set, var_binding });
11686
11687 // Atomic helper buffers for image atomics need to use secondary bindings as well.
11688 bool use_secondary_binding = (var_type.basetype == SPIRType::SampledImage && basetype == SPIRType::Sampler) ||
11689 basetype == SPIRType::AtomicCounter;
11690
11691 auto resource_decoration =
11692 use_secondary_binding ? SPIRVCrossDecorationResourceIndexSecondary : SPIRVCrossDecorationResourceIndexPrimary;
11693
11694 if (plane == 1)
11695 resource_decoration = SPIRVCrossDecorationResourceIndexTertiary;
11696 if (plane == 2)
11697 resource_decoration = SPIRVCrossDecorationResourceIndexQuaternary;
11698
11699 if (itr != end(resource_bindings))
11700 {
11701 auto &remap = itr->second;
11702 remap.second = true;
11703 switch (basetype)
11704 {
11705 case SPIRType::Image:
11706 set_extended_decoration(var.self, resource_decoration, remap.first.msl_texture + plane);
11707 return remap.first.msl_texture + plane;
11708 case SPIRType::Sampler:
11709 set_extended_decoration(var.self, resource_decoration, remap.first.msl_sampler);
11710 return remap.first.msl_sampler;
11711 default:
11712 set_extended_decoration(var.self, resource_decoration, remap.first.msl_buffer);
11713 return remap.first.msl_buffer;
11714 }
11715 }
11716
11717 // If we have already allocated an index, keep using it.
11718 if (has_extended_decoration(var.self, resource_decoration))
11719 return get_extended_decoration(var.self, resource_decoration);
11720
11721 auto &type = get<SPIRType>(var.basetype);
11722
11723 if (type_is_msl_framebuffer_fetch(type))
11724 {
11725 // Frame-buffer fetch gets its fallback resource index from the input attachment index,
11726 // which is then treated as the color attachment index.
11727 return get_decoration(var.self, DecorationInputAttachmentIndex);
11728 }
11729 else if (msl_options.enable_decoration_binding)
11730 {
11731 // Allow user to enable decoration binding.
11732 // If there is no explicit mapping of bindings to MSL, use the declared binding as a fallback.
11733 if (has_decoration(var.self, DecorationBinding))
11734 {
11735 var_binding = get_decoration(var.self, DecorationBinding);
11736 // Avoid emitting sentinel bindings.
11737 if (var_binding < 0x80000000u)
11738 return var_binding;
11739 }
11740 }
11741
11742 // If we did not explicitly remap, allocate bindings on demand.
11743 // We cannot reliably use Binding decorations since SPIR-V and MSL's binding models are very different.
11744
11745 bool allocate_argument_buffer_ids = false;
11746
11747 if (var.storage != StorageClassPushConstant)
11748 allocate_argument_buffer_ids = descriptor_set_is_argument_buffer(var_desc_set);
11749
11750 uint32_t binding_stride = 1;
11751 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
11752 binding_stride *= to_array_size_literal(type, i);
11753
11754 assert(binding_stride != 0);
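// E.g. a resource declared as an array of 4 textures consumes 4 consecutive
// index slots, so the next allocation starts 4 slots later.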
11755
11756 // If a binding has not been specified, revert to incrementing resource indices.
11757 uint32_t resource_index;
11758
11759 if (allocate_argument_buffer_ids)
11760 {
11761 // Allocate from a flat ID binding space.
11762 resource_index = next_metal_resource_ids[var_desc_set];
11763 next_metal_resource_ids[var_desc_set] += binding_stride;
11764 }
11765 else
11766 {
11767 // Allocate from plain bindings which are allocated per resource type.
11768 switch (basetype)
11769 {
11770 case SPIRType::Image:
11771 resource_index = next_metal_resource_index_texture;
11772 next_metal_resource_index_texture += binding_stride;
11773 break;
11774 case SPIRType::Sampler:
11775 resource_index = next_metal_resource_index_sampler;
11776 next_metal_resource_index_sampler += binding_stride;
11777 break;
11778 default:
11779 resource_index = next_metal_resource_index_buffer;
11780 next_metal_resource_index_buffer += binding_stride;
11781 break;
11782 }
11783 }
11784
11785 set_extended_decoration(var.self, resource_decoration, resource_index);
11786 return resource_index;
11787 }
11788
11789 bool CompilerMSL::type_is_msl_framebuffer_fetch(const SPIRType &type) const
11790 {
11791 return type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
11792 msl_options.use_framebuffer_fetch_subpasses;
11793 }
11794
11795 string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
11796 {
11797 auto &var = get<SPIRVariable>(arg.id);
11798 auto &type = get_variable_data_type(var);
11799 auto &var_type = get<SPIRType>(arg.type);
11800 StorageClass storage = var_type.storage;
11801 bool is_pointer = var_type.pointer;
11802
11803 // If we need to modify the name of the variable, make sure we use the original variable.
11804 // Our alias is just a shadow variable.
11805 uint32_t name_id = var.self;
11806 if (arg.alias_global_variable && var.basevariable)
11807 name_id = var.basevariable;
11808
11809 bool constref = !arg.alias_global_variable && is_pointer && arg.write_count == 0;
11810 // Framebuffer fetch is a plain value; const would look out of place here, even though it would not be wrong.
11811 if (type_is_msl_framebuffer_fetch(type))
11812 constref = false;
11813
11814 bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage ||
11815 type.basetype == SPIRType::Sampler;
11816
11817 // Arrays of images/samplers in MSL are always const.
11818 if (!type.array.empty() && type_is_image)
11819 constref = true;
11820
11821 string decl;
11822 if (constref)
11823 decl += "const ";
11824
11825 // If this is a combined image-sampler for a 2D image with floating-point type,
11826 // we emitted the 'spvDynamicImageSampler' type, and this is *not* an alias parameter
11827 // for a global, then we need to emit a "dynamic" combined image-sampler.
11828 // Unfortunately, this is necessary to properly support passing around
11829 // combined image-samplers with Y'CbCr conversions on them.
11830 bool is_dynamic_img_sampler = !arg.alias_global_variable && type.basetype == SPIRType::SampledImage &&
11831 type.image.dim == Dim2D && type_is_floating_point(get<SPIRType>(type.image.type)) &&
11832 spv_function_implementations.count(SPVFuncImplDynamicImageSampler);
11833
11834 // Allow Metal to use the array<T> template to make arrays a value type
11835 string address_space = get_argument_address_space(var);
11836 bool builtin = is_builtin_variable(var);
11837 is_using_builtin_array = builtin;
11838 if (address_space == "threadgroup")
11839 is_using_builtin_array = true;
11840
11841 if (var.basevariable && (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id))
11842 decl += type_to_glsl(type, arg.id);
11843 else if (builtin)
11844 decl += builtin_type_decl(static_cast<BuiltIn>(get_decoration(arg.id, DecorationBuiltIn)), arg.id);
11845 else if ((storage == StorageClassUniform || storage == StorageClassStorageBuffer) && is_array(type))
11846 {
11847 is_using_builtin_array = true;
11848 decl += join(type_to_glsl(type, arg.id), "*");
11849 }
11850 else if (is_dynamic_img_sampler)
11851 {
11852 decl += join("spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">");
11853 // Mark the variable so that we can handle passing it to another function.
11854 set_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
11855 }
11856 else
11857 decl += type_to_glsl(type, arg.id);
11858
11859 bool opaque_handle = storage == StorageClassUniformConstant;
11860
11861 if (!builtin && !opaque_handle && !is_pointer &&
11862 (storage == StorageClassFunction || storage == StorageClassGeneric))
11863 {
11864 // If the argument is a pure value and not an opaque type, we will pass by value.
11865 if (msl_options.force_native_arrays && is_array(type))
11866 {
11867 // We are receiving an array by value. This is problematic.
11868 // We cannot be sure of the target address space since we are supposed to receive a copy,
11869 // but this is not possible with MSL without some extra work.
11870 // We will have to assume we're getting a reference in thread address space.
11871 // If we happen to get a reference in constant address space, the caller must emit a copy and pass that.
11872 // Thread const therefore becomes the only logical choice, since we cannot "create" a constant array from
11873 // non-constant arrays, but we can create thread const from constant.
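// Illustrative result (parameter name hypothetical): a float[4] argument is
// declared as
//   thread const float (&foo)[4]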
11874 decl = string("thread const ") + decl;
11875 decl += " (&";
11876 const char *restrict_kw = to_restrict(name_id);
11877 if (*restrict_kw)
11878 {
11879 decl += " ";
11880 decl += restrict_kw;
11881 }
11882 decl += to_expression(name_id);
11883 decl += ")";
11884 decl += type_to_array_glsl(type);
11885 }
11886 else
11887 {
11888 if (!address_space.empty())
11889 decl = join(address_space, " ", decl);
11890 decl += " ";
11891 decl += to_expression(name_id);
11892 }
11893 }
11894 else if (is_array(type) && !type_is_image)
11895 {
11896 // Arrays of images and samplers are special cased.
11897 if (!address_space.empty())
11898 decl = join(address_space, " ", decl);
11899
11900 if (msl_options.argument_buffers)
11901 {
11902 uint32_t desc_set = get_decoration(name_id, DecorationDescriptorSet);
11903 if ((storage == StorageClassUniform || storage == StorageClassStorageBuffer) &&
11904 descriptor_set_is_argument_buffer(desc_set))
11905 {
11906 // An awkward case where we need to emit *more* address space declarations (yay!).
11907 // An example is where we pass down an array of buffer pointers to leaf functions.
11908 // It's a constant array containing pointers to constants.
11909 // The pointer array is always constant however. E.g.
11910 // device SSBO * constant (&array)[N].
11911 // const device SSBO * constant (&array)[N].
11912 // constant SSBO * constant (&array)[N].
11913 // However, this only matters for argument buffers, since for MSL 1.0 style codegen,
11914 // we emit the buffer array on stack instead, and that seems to work just fine apparently.
11915
11916 // If the argument was marked as being in device address space, any pointer to member would
11917 // be const device, not constant.
11918 if (argument_buffer_device_storage_mask & (1u << desc_set))
11919 decl += " const device";
11920 else
11921 decl += " constant";
11922 }
11923 }
11924
11925 decl += " (&";
11926 const char *restrict_kw = to_restrict(name_id);
11927 if (*restrict_kw)
11928 {
11929 decl += " ";
11930 decl += restrict_kw;
11931 }
11932 decl += to_expression(name_id);
11933 decl += ")";
11934 decl += type_to_array_glsl(type);
11935 }
11936 else if (!opaque_handle && (!pull_model_inputs.count(var.basevariable) || type.basetype == SPIRType::Struct))
11937 {
11938 // If this is going to be a reference to a variable pointer, the address space
11939 // for the reference has to go before the '&', but after the '*'.
11940 if (!address_space.empty())
11941 {
11942 if (decl.back() == '*')
11943 decl += join(" ", address_space, " ");
11944 else
11945 decl = join(address_space, " ", decl);
11946 }
11947 decl += "&";
11948 decl += " ";
11949 decl += to_restrict(name_id);
11950 decl += to_expression(name_id);
11951 }
11952 else
11953 {
11954 if (!address_space.empty())
11955 decl = join(address_space, " ", decl);
11956 decl += " ";
11957 decl += to_expression(name_id);
11958 }
11959
11960 // Emulate texture2D atomic operations
11961 auto *backing_var = maybe_get_backing_variable(name_id);
11962 if (backing_var && atomic_image_vars.count(backing_var->self))
11963 {
11964 decl += ", device atomic_" + type_to_glsl(get<SPIRType>(var_type.image.type), 0);
11965 decl += "* " + to_expression(name_id) + "_atomic";
11966 }
11967
11968 is_using_builtin_array = false;
11969
11970 return decl;
11971 }
11972
11973 // If we're currently in the entry point function, and the object
11974 // has a qualified name, use it, otherwise use the standard name.
11975 string CompilerMSL::to_name(uint32_t id, bool allow_alias) const
11976 {
11977 if (current_function && (current_function->self == ir.default_entry_point))
11978 {
11979 auto *m = ir.find_meta(id);
11980 if (m && !m->decoration.qualified_alias.empty())
11981 return m->decoration.qualified_alias;
11982 }
11983 return Compiler::to_name(id, allow_alias);
11984 }
11985
11986 // Returns a name that combines the name of the struct with the name of the member, except for Builtins
11987 string CompilerMSL::to_qualified_member_name(const SPIRType &type, uint32_t index)
11988 {
11989 // Don't qualify Builtin names because they are unique and are treated as such when building expressions
11990 BuiltIn builtin = BuiltInMax;
11991 if (is_member_builtin(type, index, &builtin))
11992 return builtin_to_glsl(builtin, type.storage);
11993
11994 // Strip any underscore prefix from member name
11995 string mbr_name = to_member_name(type, index);
11996 size_t startPos = mbr_name.find_first_not_of("_");
11997 mbr_name = (startPos != string::npos) ? mbr_name.substr(startPos) : "";
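// E.g. (illustrative) member "_m0" of struct "Foo" yields "Foo_m0".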
11998 return join(to_name(type.self), "_", mbr_name);
11999 }
12000
12001 // Ensures that the specified name is permanently usable by prepending a prefix
12002 // if the first chars are _ and a digit, which indicate a transient name.
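// E.g. (illustrative) ensure_valid_name("_9", "m") returns "m_9", while "foo" is returned unchanged.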
12003 string CompilerMSL::ensure_valid_name(string name, string pfx)
12004 {
12005 return (name.size() >= 2 && name[0] == '_' && isdigit(name[1])) ? (pfx + name) : name;
12006 }
12007
12008 const std::unordered_set<std::string> &CompilerMSL::get_reserved_keyword_set()
12009 {
12010 static const unordered_set<string> keywords = {
12011 "kernel",
12012 "vertex",
12013 "fragment",
12014 "compute",
12015 "bias",
12016 "level",
12017 "gradient2d",
12018 "gradientcube",
12019 "gradient3d",
12020 "min_lod_clamp",
12021 "assert",
12022 "VARIABLE_TRACEPOINT",
12023 "STATIC_DATA_TRACEPOINT",
12024 "STATIC_DATA_TRACEPOINT_V",
12025 "METAL_ALIGN",
12026 "METAL_ASM",
12027 "METAL_CONST",
12028 "METAL_DEPRECATED",
12029 "METAL_ENABLE_IF",
12030 "METAL_FUNC",
12031 "METAL_INTERNAL",
12032 "METAL_NON_NULL_RETURN",
12033 "METAL_NORETURN",
12034 "METAL_NOTHROW",
12035 "METAL_PURE",
12036 "METAL_UNAVAILABLE",
12037 "METAL_IMPLICIT",
12038 "METAL_EXPLICIT",
12039 "METAL_CONST_ARG",
12040 "METAL_ARG_UNIFORM",
12041 "METAL_ZERO_ARG",
12042 "METAL_VALID_LOD_ARG",
12043 "METAL_VALID_LEVEL_ARG",
12044 "METAL_VALID_STORE_ORDER",
12045 "METAL_VALID_LOAD_ORDER",
12046 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
12047 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
12048 "METAL_VALID_RENDER_TARGET",
12049 "is_function_constant_defined",
12050 "CHAR_BIT",
12051 "SCHAR_MAX",
12052 "SCHAR_MIN",
12053 "UCHAR_MAX",
12054 "CHAR_MAX",
12055 "CHAR_MIN",
12056 "USHRT_MAX",
12057 "SHRT_MAX",
12058 "SHRT_MIN",
12059 "UINT_MAX",
12060 "INT_MAX",
12061 "INT_MIN",
12062 "FLT_DIG",
12063 "FLT_MANT_DIG",
12064 "FLT_MAX_10_EXP",
12065 "FLT_MAX_EXP",
12066 "FLT_MIN_10_EXP",
12067 "FLT_MIN_EXP",
12068 "FLT_RADIX",
12069 "FLT_MAX",
12070 "FLT_MIN",
12071 "FLT_EPSILON",
12072 "FP_ILOGB0",
12073 "FP_ILOGBNAN",
12074 "MAXFLOAT",
12075 "HUGE_VALF",
12076 "INFINITY",
12077 "NAN",
12078 "M_E_F",
12079 "M_LOG2E_F",
12080 "M_LOG10E_F",
12081 "M_LN2_F",
12082 "M_LN10_F",
12083 "M_PI_F",
12084 "M_PI_2_F",
12085 "M_PI_4_F",
12086 "M_1_PI_F",
12087 "M_2_PI_F",
12088 "M_2_SQRTPI_F",
12089 "M_SQRT2_F",
12090 "M_SQRT1_2_F",
12091 "HALF_DIG",
12092 "HALF_MANT_DIG",
12093 "HALF_MAX_10_EXP",
12094 "HALF_MAX_EXP",
12095 "HALF_MIN_10_EXP",
12096 "HALF_MIN_EXP",
12097 "HALF_RADIX",
12098 "HALF_MAX",
12099 "HALF_MIN",
12100 "HALF_EPSILON",
12101 "MAXHALF",
12102 "HUGE_VALH",
12103 "M_E_H",
12104 "M_LOG2E_H",
12105 "M_LOG10E_H",
12106 "M_LN2_H",
12107 "M_LN10_H",
12108 "M_PI_H",
12109 "M_PI_2_H",
12110 "M_PI_4_H",
12111 "M_1_PI_H",
12112 "M_2_PI_H",
12113 "M_2_SQRTPI_H",
12114 "M_SQRT2_H",
12115 "M_SQRT1_2_H",
12116 "DBL_DIG",
12117 "DBL_MANT_DIG",
12118 "DBL_MAX_10_EXP",
12119 "DBL_MAX_EXP",
12120 "DBL_MIN_10_EXP",
12121 "DBL_MIN_EXP",
12122 "DBL_RADIX",
12123 "DBL_MAX",
12124 "DBL_MIN",
12125 "DBL_EPSILON",
12126 "HUGE_VAL",
12127 "M_E",
12128 "M_LOG2E",
12129 "M_LOG10E",
12130 "M_LN2",
12131 "M_LN10",
12132 "M_PI",
12133 "M_PI_2",
12134 "M_PI_4",
12135 "M_1_PI",
12136 "M_2_PI",
12137 "M_2_SQRTPI",
12138 "M_SQRT2",
12139 "M_SQRT1_2",
12140 "quad_broadcast",
12141 };
12142
12143 return keywords;
12144 }
12145
12146 const std::unordered_set<std::string> &CompilerMSL::get_illegal_func_names()
12147 {
12148 static const unordered_set<string> illegal_func_names = {
12149 "main",
12150 "saturate",
12151 "assert",
12152 "fmin3",
12153 "fmax3",
12154 "VARIABLE_TRACEPOINT",
12155 "STATIC_DATA_TRACEPOINT",
12156 "STATIC_DATA_TRACEPOINT_V",
12157 "METAL_ALIGN",
12158 "METAL_ASM",
12159 "METAL_CONST",
12160 "METAL_DEPRECATED",
12161 "METAL_ENABLE_IF",
12162 "METAL_FUNC",
12163 "METAL_INTERNAL",
12164 "METAL_NON_NULL_RETURN",
12165 "METAL_NORETURN",
12166 "METAL_NOTHROW",
12167 "METAL_PURE",
12168 "METAL_UNAVAILABLE",
12169 "METAL_IMPLICIT",
12170 "METAL_EXPLICIT",
12171 "METAL_CONST_ARG",
12172 "METAL_ARG_UNIFORM",
12173 "METAL_ZERO_ARG",
12174 "METAL_VALID_LOD_ARG",
12175 "METAL_VALID_LEVEL_ARG",
12176 "METAL_VALID_STORE_ORDER",
12177 "METAL_VALID_LOAD_ORDER",
12178 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
12179 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
12180 "METAL_VALID_RENDER_TARGET",
12181 "is_function_constant_defined",
12182 "CHAR_BIT",
12183 "SCHAR_MAX",
12184 "SCHAR_MIN",
12185 "UCHAR_MAX",
12186 "CHAR_MAX",
12187 "CHAR_MIN",
12188 "USHRT_MAX",
12189 "SHRT_MAX",
12190 "SHRT_MIN",
12191 "UINT_MAX",
12192 "INT_MAX",
12193 "INT_MIN",
12194 "FLT_DIG",
12195 "FLT_MANT_DIG",
12196 "FLT_MAX_10_EXP",
12197 "FLT_MAX_EXP",
12198 "FLT_MIN_10_EXP",
12199 "FLT_MIN_EXP",
12200 "FLT_RADIX",
12201 "FLT_MAX",
12202 "FLT_MIN",
12203 "FLT_EPSILON",
12204 "FP_ILOGB0",
12205 "FP_ILOGBNAN",
12206 "MAXFLOAT",
12207 "HUGE_VALF",
12208 "INFINITY",
12209 "NAN",
12210 "M_E_F",
12211 "M_LOG2E_F",
12212 "M_LOG10E_F",
12213 "M_LN2_F",
12214 "M_LN10_F",
12215 "M_PI_F",
12216 "M_PI_2_F",
12217 "M_PI_4_F",
12218 "M_1_PI_F",
12219 "M_2_PI_F",
12220 "M_2_SQRTPI_F",
12221 "M_SQRT2_F",
12222 "M_SQRT1_2_F",
12223 "HALF_DIG",
12224 "HALF_MANT_DIG",
12225 "HALF_MAX_10_EXP",
12226 "HALF_MAX_EXP",
12227 "HALF_MIN_10_EXP",
12228 "HALF_MIN_EXP",
12229 "HALF_RADIX",
12230 "HALF_MAX",
12231 "HALF_MIN",
12232 "HALF_EPSILON",
12233 "MAXHALF",
12234 "HUGE_VALH",
12235 "M_E_H",
12236 "M_LOG2E_H",
12237 "M_LOG10E_H",
12238 "M_LN2_H",
12239 "M_LN10_H",
12240 "M_PI_H",
12241 "M_PI_2_H",
12242 "M_PI_4_H",
12243 "M_1_PI_H",
12244 "M_2_PI_H",
12245 "M_2_SQRTPI_H",
12246 "M_SQRT2_H",
12247 "M_SQRT1_2_H",
12248 "DBL_DIG",
12249 "DBL_MANT_DIG",
12250 "DBL_MAX_10_EXP",
12251 "DBL_MAX_EXP",
12252 "DBL_MIN_10_EXP",
12253 "DBL_MIN_EXP",
12254 "DBL_RADIX",
12255 "DBL_MAX",
12256 "DBL_MIN",
12257 "DBL_EPSILON",
12258 "HUGE_VAL",
12259 "M_E",
12260 "M_LOG2E",
12261 "M_LOG10E",
12262 "M_LN2",
12263 "M_LN10",
12264 "M_PI",
12265 "M_PI_2",
12266 "M_PI_4",
12267 "M_1_PI",
12268 "M_2_PI",
12269 "M_2_SQRTPI",
12270 "M_SQRT2",
12271 "M_SQRT1_2",
12272 };
12273
12274 return illegal_func_names;
12275 }
12276
12277 // Replace all names that match MSL keywords or Metal Standard Library functions.
12278 void CompilerMSL::replace_illegal_names()
12279 {
12280 // FIXME: MSL and GLSL are doing two different things here.
12281 // Agree on convention and remove this override.
12282 auto &keywords = get_reserved_keyword_set();
12283 auto &illegal_func_names = get_illegal_func_names();
12284
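// E.g. a shader variable named "kernel" would be renamed to "kernel0" here.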
12285 ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &) {
12286 auto *meta = ir.find_meta(self);
12287 if (!meta)
12288 return;
12289
12290 auto &dec = meta->decoration;
12291 if (keywords.find(dec.alias) != end(keywords))
12292 dec.alias += "0";
12293 });
12294
12295 ir.for_each_typed_id<SPIRFunction>([&](uint32_t self, SPIRFunction &) {
12296 auto *meta = ir.find_meta(self);
12297 if (!meta)
12298 return;
12299
12300 auto &dec = meta->decoration;
12301 if (illegal_func_names.find(dec.alias) != end(illegal_func_names))
12302 dec.alias += "0";
12303 });
12304
12305 ir.for_each_typed_id<SPIRType>([&](uint32_t self, SPIRType &) {
12306 auto *meta = ir.find_meta(self);
12307 if (!meta)
12308 return;
12309
12310 for (auto &mbr_dec : meta->members)
12311 if (keywords.find(mbr_dec.alias) != end(keywords))
12312 mbr_dec.alias += "0";
12313 });
12314
12315 CompilerGLSL::replace_illegal_names();
12316 }
12317
12318 void CompilerMSL::replace_illegal_entry_point_names()
12319 {
12320 auto &illegal_func_names = get_illegal_func_names();
12321
12322 // It is important to do this before we fix up identifiers,
12323 // since if ep_name is reserved, we will need to fix that up,
12324 // and then copy the alias back into entry.name after the fixup.
12325 for (auto &entry : ir.entry_points)
12326 {
12327 // Change both the entry point name and the alias, to keep them synced.
12328 string &ep_name = entry.second.name;
12329 if (illegal_func_names.find(ep_name) != end(illegal_func_names))
12330 ep_name += "0";
12331
12332 ir.meta[entry.first].decoration.alias = ep_name;
12333 }
12334 }
12335
12336 void CompilerMSL::sync_entry_point_aliases_and_names()
12337 {
12338 for (auto &entry : ir.entry_points)
12339 entry.second.name = ir.meta[entry.first].decoration.alias;
12340 }
12341
12342 string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain)
12343 {
12344 if (index < uint32_t(type.member_type_index_redirection.size()))
12345 index = type.member_type_index_redirection[index];
12346
12347 auto *var = maybe_get<SPIRVariable>(base);
12348 // If this is a buffer array, we have to dereference the buffer pointers.
12349 // Otherwise, if this is a pointer expression, dereference it.
12350
12351 bool declared_as_pointer = false;
12352
12353 if (var)
12354 {
12355 // Only allow -> dereference for block types. This is so we get expressions like
12356 // buffer[i]->first_member.second_member, rather than buffer[i]->first->second.
12357 bool is_block = has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
12358
12359 bool is_buffer_variable =
12360 is_block && (var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer);
12361 declared_as_pointer = is_buffer_variable && is_array(get<SPIRType>(var->basetype));
12362 }
12363
12364 if (declared_as_pointer || (!ptr_chain && should_dereference(base)))
12365 return join("->", to_member_name(type, index));
12366 else
12367 return join(".", to_member_name(type, index));
12368 }
12369
12370 string CompilerMSL::to_qualifiers_glsl(uint32_t id)
12371 {
12372 string quals;
12373
12374 auto &type = expression_type(id);
12375 if (type.storage == StorageClassWorkgroup)
12376 quals += "threadgroup ";
12377
12378 return quals;
12379 }
12380
12381 // The optional id parameter indicates the object whose type we are trying
12382 // to find the description for. Most type descriptions do not
12383 // depend on a specific object's use of that type.
12384 string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
12385 {
12386 string type_name;
12387
12388 // Pointer?
12389 if (type.pointer)
12390 {
12391 const char *restrict_kw;
12392 type_name = join(get_type_address_space(type, id), " ", type_to_glsl(get<SPIRType>(type.parent_type), id));
12393
12394 switch (type.basetype)
12395 {
12396 case SPIRType::Image:
12397 case SPIRType::SampledImage:
12398 case SPIRType::Sampler:
12399 // These are handles.
12400 break;
12401 default:
12402 // Anything else can be a raw pointer.
12403 type_name += "*";
12404 restrict_kw = to_restrict(id);
12405 if (*restrict_kw)
12406 {
12407 type_name += " ";
12408 type_name += restrict_kw;
12409 }
12410 break;
12411 }
12412 return type_name;
12413 }
12414
12415 switch (type.basetype)
12416 {
12417 case SPIRType::Struct:
12418 // Need OpName lookup here to get a "sensible" name for a struct.
12419 // Allow Metal to use the array<T> template to make arrays a value type
12420 type_name = to_name(type.self);
12421 break;
12422
12423 case SPIRType::Image:
12424 case SPIRType::SampledImage:
12425 return image_type_glsl(type, id);
12426
12427 case SPIRType::Sampler:
12428 return sampler_type(type, id);
12429
12430 case SPIRType::Void:
12431 return "void";
12432
12433 case SPIRType::AtomicCounter:
12434 return "atomic_uint";
12435
12436 case SPIRType::ControlPointArray:
12437 return join("patch_control_point<", type_to_glsl(get<SPIRType>(type.parent_type), id), ">");
12438
12439 case SPIRType::Interpolant:
12440 return join("interpolant<", type_to_glsl(get<SPIRType>(type.parent_type), id), ", interpolation::",
12441 has_decoration(type.self, DecorationNoPerspective) ? "no_perspective" : "perspective", ">");
12442
12443 // Scalars
12444 case SPIRType::Boolean:
12445 type_name = "bool";
12446 break;
12447 case SPIRType::Char:
12448 case SPIRType::SByte:
12449 type_name = "char";
12450 break;
12451 case SPIRType::UByte:
12452 type_name = "uchar";
12453 break;
12454 case SPIRType::Short:
12455 type_name = "short";
12456 break;
12457 case SPIRType::UShort:
12458 type_name = "ushort";
12459 break;
12460 case SPIRType::Int:
12461 type_name = "int";
12462 break;
12463 case SPIRType::UInt:
12464 type_name = "uint";
12465 break;
12466 case SPIRType::Int64:
12467 if (!msl_options.supports_msl_version(2, 2))
12468 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
12469 type_name = "long";
12470 break;
12471 case SPIRType::UInt64:
12472 if (!msl_options.supports_msl_version(2, 2))
12473 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
12474 type_name = "ulong";
12475 break;
12476 case SPIRType::Half:
12477 type_name = "half";
12478 break;
12479 case SPIRType::Float:
12480 type_name = "float";
12481 break;
12482 case SPIRType::Double:
12483 type_name = "double"; // Currently unsupported
12484 break;
12485
12486 default:
12487 return "unknown_type";
12488 }
12489
12490 // Matrix?
12491 if (type.columns > 1)
12492 type_name += to_string(type.columns) + "x";
12493
12494 // Vector or Matrix?
12495 if (type.vecsize > 1)
12496 type_name += to_string(type.vecsize);
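// E.g. a float matrix with 4 columns and 4 rows becomes "float4x4";
// a 3-component uint vector becomes "uint3".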
12497
12498 if (type.array.empty() || using_builtin_array())
12499 {
12500 return type_name;
12501 }
12502 else
12503 {
12504 // Allow Metal to use the array<T> template to make arrays a value type
12505 add_spv_func_and_recompile(SPVFuncImplUnsafeArray);
12506 string res;
12507 string sizes;
12508
12509 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
12510 {
12511 res += "spvUnsafeArray<";
12512 sizes += ", ";
12513 sizes += to_array_size(type, i);
12514 sizes += ">";
12515 }
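// E.g. (dimensions illustrative) a two-dimensional float array becomes the
// nested value type "spvUnsafeArray<spvUnsafeArray<float, N>, M>".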
12516
12517 res += type_name + sizes;
12518 return res;
12519 }
12520 }
12521
12522 string CompilerMSL::type_to_array_glsl(const SPIRType &type)
12523 {
12524 // Allow Metal to use the array<T> template to make arrays a value type
12525 switch (type.basetype)
12526 {
12527 case SPIRType::AtomicCounter:
12528 case SPIRType::ControlPointArray:
12529 {
12530 return CompilerGLSL::type_to_array_glsl(type);
12531 }
12532 default:
12533 {
12534 if (using_builtin_array())
12535 return CompilerGLSL::type_to_array_glsl(type);
12536 else
12537 return "";
12538 }
12539 }
12540 }
12541
12542 // Threadgroup arrays can't have a wrapper type
12543 std::string CompilerMSL::variable_decl(const SPIRVariable &variable)
12544 {
12545 if (variable.storage == StorageClassWorkgroup)
12546 {
12547 is_using_builtin_array = true;
12548 }
12549 std::string expr = CompilerGLSL::variable_decl(variable);
12550 if (variable.storage == StorageClassWorkgroup)
12551 {
12552 is_using_builtin_array = false;
12553 }
12554 return expr;
12555 }
12556
12557 // GCC workaround for lambdas calling protected functions.
12558 std::string CompilerMSL::variable_decl(const SPIRType &type, const std::string &name, uint32_t id)
12559 {
12560 return CompilerGLSL::variable_decl(type, name, id);
12561 }
12562
12563 std::string CompilerMSL::sampler_type(const SPIRType &type, uint32_t id)
12564 {
12565 auto *var = maybe_get<SPIRVariable>(id);
12566 if (var && var->basevariable)
12567 {
12568 // Check against the base variable, and not a fake ID which might have been generated for this variable.
12569 id = var->basevariable;
12570 }
12571
12572 if (!type.array.empty())
12573 {
12574 if (!msl_options.supports_msl_version(2))
12575 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of samplers.");
12576
12577 if (type.array.size() > 1)
12578 SPIRV_CROSS_THROW("Arrays of arrays of samplers are not supported in MSL.");
12579
12580 // Arrays of samplers in MSL must be declared with a special array<T, N> syntax a la C++11 std::array.
12581 // If we have a runtime array, it could be a variable-count descriptor set binding.
12582 uint32_t array_size = to_array_size_literal(type);
12583 if (array_size == 0)
12584 array_size = get_resource_array_size(id);
12585
12586 if (array_size == 0)
12587 SPIRV_CROSS_THROW("Unsized array of samplers is not supported in MSL.");
12588
12589 auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
12590 return join("array<", sampler_type(parent, id), ", ", array_size, ">");
12591 }
12592 else
12593 return "sampler";
12594 }
12595
12596 // Returns an MSL string describing the SPIR-V image type
12597 string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id)
12598 {
12599 auto *var = maybe_get<SPIRVariable>(id);
12600 if (var && var->basevariable)
12601 {
12602 // For comparison images, check against the base variable,
12603 // and not the fake ID which might have been generated for this variable.
12604 id = var->basevariable;
12605 }
12606
12607 if (!type.array.empty())
12608 {
12609 uint32_t major = 2, minor = 0;
12610 if (msl_options.is_ios())
12611 {
12612 major = 1;
12613 minor = 2;
12614 }
12615 if (!msl_options.supports_msl_version(major, minor))
12616 {
12617 if (msl_options.is_ios())
12618 SPIRV_CROSS_THROW("MSL 1.2 or greater is required for arrays of textures.");
12619 else
12620 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of textures.");
12621 }
12622
12623 if (type.array.size() > 1)
12624 SPIRV_CROSS_THROW("Arrays of arrays of textures are not supported in MSL.");
12625
12626 // Arrays of images in MSL must be declared with a special array<T, N> syntax a la C++11 std::array.
12627 // If we have a runtime array, it could be a variable-count descriptor set binding.
12628 uint32_t array_size = to_array_size_literal(type);
12629 if (array_size == 0)
12630 array_size = get_resource_array_size(id);
12631
12632 if (array_size == 0)
12633 SPIRV_CROSS_THROW("Unsized array of images is not supported in MSL.");
12634
12635 auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
12636 return join("array<", image_type_glsl(parent, id), ", ", array_size, ">");
12637 }
12638
12639 string img_type_name;
12640
12641 // Bypass pointers because we need the real image struct
12642 auto &img_type = get<SPIRType>(type.self).image;
12643 if (image_is_comparison(type, id))
12644 {
12645 switch (img_type.dim)
12646 {
12647 case Dim1D:
12648 case Dim2D:
12649 if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
12650 {
12651 // Use a native Metal 1D texture
12652 img_type_name += "depth1d_unsupported_by_metal";
12653 break;
12654 }
12655
12656 if (img_type.ms && img_type.arrayed)
12657 {
12658 if (!msl_options.supports_msl_version(2, 1))
12659 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
12660 img_type_name += "depth2d_ms_array";
12661 }
12662 else if (img_type.ms)
12663 img_type_name += "depth2d_ms";
12664 else if (img_type.arrayed)
12665 img_type_name += "depth2d_array";
12666 else
12667 img_type_name += "depth2d";
12668 break;
12669 case Dim3D:
12670 img_type_name += "depth3d_unsupported_by_metal";
12671 break;
12672 case DimCube:
12673 if (!msl_options.emulate_cube_array)
12674 img_type_name += (img_type.arrayed ? "depthcube_array" : "depthcube");
12675 else
12676 img_type_name += (img_type.arrayed ? "depth2d_array" : "depthcube");
12677 break;
12678 default:
12679 img_type_name += "unknown_depth_texture_type";
12680 break;
12681 }
12682 }
12683 else
12684 {
12685 switch (img_type.dim)
12686 {
12687 case DimBuffer:
12688 if (img_type.ms || img_type.arrayed)
12689 SPIRV_CROSS_THROW("Cannot use texel buffers with multisampling or array layers.");
12690
12691 if (msl_options.texture_buffer_native)
12692 {
12693 if (!msl_options.supports_msl_version(2, 1))
12694 SPIRV_CROSS_THROW("Native texture_buffer type is only supported in MSL 2.1.");
12695 img_type_name = "texture_buffer";
12696 }
12697 else
12698 img_type_name += "texture2d";
12699 break;
12700 case Dim1D:
12701 case Dim2D:
12702 case DimSubpassData:
12703 {
12704 bool subpass_array =
12705 img_type.dim == DimSubpassData && (msl_options.multiview || msl_options.arrayed_subpass_input);
12706 if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
12707 {
12708 // Use a native Metal 1D texture
12709 img_type_name += (img_type.arrayed ? "texture1d_array" : "texture1d");
12710 break;
12711 }
12712
12713 // Use Metal's native frame-buffer fetch API for subpass inputs.
12714 if (type_is_msl_framebuffer_fetch(type))
12715 {
12716 auto img_type_4 = get<SPIRType>(img_type.type);
12717 img_type_4.vecsize = 4;
12718 return type_to_glsl(img_type_4);
12719 }
12720 if (img_type.ms && (img_type.arrayed || subpass_array))
12721 {
12722 if (!msl_options.supports_msl_version(2, 1))
12723 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
12724 img_type_name += "texture2d_ms_array";
12725 }
12726 else if (img_type.ms)
12727 img_type_name += "texture2d_ms";
12728 else if (img_type.arrayed || subpass_array)
12729 img_type_name += "texture2d_array";
12730 else
12731 img_type_name += "texture2d";
12732 break;
12733 }
12734 case Dim3D:
12735 img_type_name += "texture3d";
12736 break;
12737 case DimCube:
12738 if (!msl_options.emulate_cube_array)
12739 img_type_name += (img_type.arrayed ? "texturecube_array" : "texturecube");
12740 else
12741 img_type_name += (img_type.arrayed ? "texture2d_array" : "texturecube");
12742 break;
12743 default:
12744 img_type_name += "unknown_texture_type";
12745 break;
12746 }
12747 }
12748
12749 // Append the pixel type
12750 img_type_name += "<";
12751 img_type_name += type_to_glsl(get<SPIRType>(img_type.type));
12752
12753 // For unsampled images, append the sample/read/write access qualifier.
12754 // For kernel images, the access qualifier may be supplied directly by SPIR-V.
12755 // Otherwise it may be set based on whether the image is read from or written to within the shader.
12756 if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData)
12757 {
12758 switch (img_type.access)
12759 {
12760 case AccessQualifierReadOnly:
12761 img_type_name += ", access::read";
12762 break;
12763
12764 case AccessQualifierWriteOnly:
12765 img_type_name += ", access::write";
12766 break;
12767
12768 case AccessQualifierReadWrite:
12769 img_type_name += ", access::read_write";
12770 break;
12771
12772 default:
12773 {
12774 auto *p_var = maybe_get_backing_variable(id);
12775 if (p_var && p_var->basevariable)
12776 p_var = maybe_get<SPIRVariable>(p_var->basevariable);
12777 if (p_var && !has_decoration(p_var->self, DecorationNonWritable))
12778 {
12779 img_type_name += ", access::";
12780
12781 if (!has_decoration(p_var->self, DecorationNonReadable))
12782 img_type_name += "read_";
12783
12784 img_type_name += "write";
12785 }
12786 break;
12787 }
12788 }
12789 }
12790
12791 img_type_name += ">";
12792
12793 return img_type_name;
12794 }
12795
12796 void CompilerMSL::emit_subgroup_op(const Instruction &i)
12797 {
12798 const uint32_t *ops = stream(i);
12799 auto op = static_cast<Op>(i.op);
12800
12801 if (msl_options.emulate_subgroups)
12802 {
12803 // In this mode, only the GroupNonUniform cap is supported. The only op
12804 // we need to handle, then, is OpGroupNonUniformElect.
12805 if (op != OpGroupNonUniformElect)
12806 SPIRV_CROSS_THROW("Subgroup emulation does not support operations other than Elect.");
12807 // In this mode, the subgroup size is assumed to be one, so every invocation
12808 // is elected.
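// I.e. the result of OpGroupNonUniformElect is simply the constant "true".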
12809 emit_op(ops[0], ops[1], "true", true);
12810 return;
12811 }
12812
12813 // Metal 2.0 is required. iOS only supports quad ops on 11.0 (2.0), with
12814 // full support in 13.0 (2.2). macOS only supports broadcast and shuffle on
12815 // 10.13 (2.0), with full support in 10.14 (2.1).
12816 // Note that Apple GPUs before A13 make no distinction between a quad-group
12817 // and a SIMD-group; all SIMD-groups are quad-groups on those.
12818 if (!msl_options.supports_msl_version(2))
12819 SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up.");
12820
12821 // If we need to do implicit bitcasts, make sure we do it with the correct type.
12822 uint32_t integer_width = get_integer_width_for_instruction(i);
12823 auto int_type = to_signed_basetype(integer_width);
12824 auto uint_type = to_unsigned_basetype(integer_width);
12825
12826 if (msl_options.is_ios() && (!msl_options.supports_msl_version(2, 3) || !msl_options.ios_use_simdgroup_functions))
12827 {
12828 switch (op)
12829 {
12830 default:
12831 SPIRV_CROSS_THROW("Subgroup ops beyond broadcast, ballot, and shuffle on iOS require Metal 2.3 and up.");
12832 case OpGroupNonUniformBroadcastFirst:
12833 if (!msl_options.supports_msl_version(2, 2))
12834 SPIRV_CROSS_THROW("BroadcastFirst on iOS requires Metal 2.2 and up.");
12835 break;
12836 case OpGroupNonUniformElect:
12837 if (!msl_options.supports_msl_version(2, 2))
12838 SPIRV_CROSS_THROW("Elect on iOS requires Metal 2.2 and up.");
12839 break;
12840 case OpGroupNonUniformAny:
12841 case OpGroupNonUniformAll:
12842 case OpGroupNonUniformAllEqual:
12843 case OpGroupNonUniformBallot:
12844 case OpGroupNonUniformInverseBallot:
12845 case OpGroupNonUniformBallotBitExtract:
12846 case OpGroupNonUniformBallotFindLSB:
12847 case OpGroupNonUniformBallotFindMSB:
12848 case OpGroupNonUniformBallotBitCount:
12849 if (!msl_options.supports_msl_version(2, 2))
12850 SPIRV_CROSS_THROW("Ballot ops on iOS requires Metal 2.2 and up.");
12851 break;
12852 case OpGroupNonUniformBroadcast:
12853 case OpGroupNonUniformShuffle:
12854 case OpGroupNonUniformShuffleXor:
12855 case OpGroupNonUniformShuffleUp:
12856 case OpGroupNonUniformShuffleDown:
12857 case OpGroupNonUniformQuadSwap:
12858 case OpGroupNonUniformQuadBroadcast:
12859 break;
12860 }
12861 }
12862
12863 if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
12864 {
12865 switch (op)
12866 {
12867 default:
12868 SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.1 and up.");
12869 case OpGroupNonUniformBroadcast:
12870 case OpGroupNonUniformShuffle:
12871 case OpGroupNonUniformShuffleXor:
12872 case OpGroupNonUniformShuffleUp:
12873 case OpGroupNonUniformShuffleDown:
12874 break;
12875 }
12876 }
12877
12878 uint32_t result_type = ops[0];
12879 uint32_t id = ops[1];
12880
12881 auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
12882 if (scope != ScopeSubgroup)
12883 SPIRV_CROSS_THROW("Only subgroup scope is supported.");
12884
12885 switch (op)
12886 {
12887 case OpGroupNonUniformElect:
12888 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
12889 emit_op(result_type, id, "quad_is_first()", false);
12890 else
12891 emit_op(result_type, id, "simd_is_first()", false);
12892 break;
12893
12894 case OpGroupNonUniformBroadcast:
12895 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBroadcast");
12896 break;
12897
12898 case OpGroupNonUniformBroadcastFirst:
12899 emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBroadcastFirst");
12900 break;
12901
12902 case OpGroupNonUniformBallot:
12903 emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallot");
12904 break;
12905
12906 case OpGroupNonUniformInverseBallot:
12907 emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract");
12908 break;
12909
12910 case OpGroupNonUniformBallotBitExtract:
12911 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBallotBitExtract");
12912 break;
12913
12914 case OpGroupNonUniformBallotFindLSB:
12915 emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindLSB");
12916 break;
12917
12918 case OpGroupNonUniformBallotFindMSB:
12919 emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindMSB");
12920 break;
12921
12922 case OpGroupNonUniformBallotBitCount:
12923 {
12924 auto operation = static_cast<GroupOperation>(ops[3]);
12925 switch (operation)
12926 {
12927 case GroupOperationReduce:
12928 emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_size_id, "spvSubgroupBallotBitCount");
12929 break;
12930 case GroupOperationInclusiveScan:
12931 emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
12932 "spvSubgroupBallotInclusiveBitCount");
12933 break;
12934 case GroupOperationExclusiveScan:
12935 emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
12936 "spvSubgroupBallotExclusiveBitCount");
12937 break;
12938 default:
12939 SPIRV_CROSS_THROW("Invalid BitCount operation.");
12940 break;
12941 }
12942 break;
12943 }
12944
12945 case OpGroupNonUniformShuffle:
12946 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffle");
12947 break;
12948
12949 case OpGroupNonUniformShuffleXor:
12950 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleXor");
12951 break;
12952
12953 case OpGroupNonUniformShuffleUp:
12954 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleUp");
12955 break;
12956
12957 case OpGroupNonUniformShuffleDown:
12958 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleDown");
12959 break;
12960
12961 case OpGroupNonUniformAll:
12962 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
12963 emit_unary_func_op(result_type, id, ops[3], "quad_all");
12964 else
12965 emit_unary_func_op(result_type, id, ops[3], "simd_all");
12966 break;
12967
12968 case OpGroupNonUniformAny:
12969 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
12970 emit_unary_func_op(result_type, id, ops[3], "quad_any");
12971 else
12972 emit_unary_func_op(result_type, id, ops[3], "simd_any");
12973 break;
12974
12975 case OpGroupNonUniformAllEqual:
12976 emit_unary_func_op(result_type, id, ops[3], "spvSubgroupAllEqual");
12977 break;
12978
12979 // clang-format off
12980 #define MSL_GROUP_OP(op, msl_op) \
12981 case OpGroupNonUniform##op: \
12982 { \
12983 auto operation = static_cast<GroupOperation>(ops[3]); \
12984 if (operation == GroupOperationReduce) \
12985 emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
12986 else if (operation == GroupOperationInclusiveScan) \
12987 emit_unary_func_op(result_type, id, ops[4], "simd_prefix_inclusive_" #msl_op); \
12988 else if (operation == GroupOperationExclusiveScan) \
12989 emit_unary_func_op(result_type, id, ops[4], "simd_prefix_exclusive_" #msl_op); \
12990 else if (operation == GroupOperationClusteredReduce) \
12991 { \
12992 /* Only cluster sizes of 4 are supported. */ \
12993 uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
12994 if (cluster_size != 4) \
12995 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
12996 emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
12997 } \
12998 else \
12999 SPIRV_CROSS_THROW("Invalid group operation."); \
13000 break; \
13001 }
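// E.g. OpGroupNonUniformFAdd with GroupOperationReduce lowers to simd_sum(x),
// and with GroupOperationClusteredReduce (cluster size 4) to quad_sum(x).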
13002 MSL_GROUP_OP(FAdd, sum)
13003 MSL_GROUP_OP(FMul, product)
13004 MSL_GROUP_OP(IAdd, sum)
13005 MSL_GROUP_OP(IMul, product)
13006 #undef MSL_GROUP_OP
13007 // The others, unfortunately, don't support InclusiveScan or ExclusiveScan.
13008
13009 #define MSL_GROUP_OP(op, msl_op) \
13010 case OpGroupNonUniform##op: \
13011 { \
13012 auto operation = static_cast<GroupOperation>(ops[3]); \
13013 if (operation == GroupOperationReduce) \
13014 emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
13015 else if (operation == GroupOperationInclusiveScan) \
13016 SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
13017 else if (operation == GroupOperationExclusiveScan) \
13018 SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
13019 else if (operation == GroupOperationClusteredReduce) \
13020 { \
13021 /* Only cluster sizes of 4 are supported. */ \
13022 uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
13023 if (cluster_size != 4) \
13024 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
13025 emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
13026 } \
13027 else \
13028 SPIRV_CROSS_THROW("Invalid group operation."); \
13029 break; \
13030 }
13031
13032 #define MSL_GROUP_OP_CAST(op, msl_op, type) \
13033 case OpGroupNonUniform##op: \
13034 { \
13035 auto operation = static_cast<GroupOperation>(ops[3]); \
13036 if (operation == GroupOperationReduce) \
13037 emit_unary_func_op_cast(result_type, id, ops[4], "simd_" #msl_op, type, type); \
13038 else if (operation == GroupOperationInclusiveScan) \
13039 SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
13040 else if (operation == GroupOperationExclusiveScan) \
13041 SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
13042 else if (operation == GroupOperationClusteredReduce) \
13043 { \
13044 /* Only cluster sizes of 4 are supported. */ \
13045 uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
13046 if (cluster_size != 4) \
13047 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
13048 emit_unary_func_op_cast(result_type, id, ops[4], "quad_" #msl_op, type, type); \
13049 } \
13050 else \
13051 SPIRV_CROSS_THROW("Invalid group operation."); \
13052 break; \
13053 }
13054
13055 MSL_GROUP_OP(FMin, min)
13056 MSL_GROUP_OP(FMax, max)
13057 MSL_GROUP_OP_CAST(SMin, min, int_type)
13058 MSL_GROUP_OP_CAST(SMax, max, int_type)
13059 MSL_GROUP_OP_CAST(UMin, min, uint_type)
13060 MSL_GROUP_OP_CAST(UMax, max, uint_type)
13061 MSL_GROUP_OP(BitwiseAnd, and)
13062 MSL_GROUP_OP(BitwiseOr, or)
13063 MSL_GROUP_OP(BitwiseXor, xor)
13064 MSL_GROUP_OP(LogicalAnd, and)
13065 MSL_GROUP_OP(LogicalOr, or)
13066 MSL_GROUP_OP(LogicalXor, xor)
13067 // clang-format on
13068 #undef MSL_GROUP_OP
13069 #undef MSL_GROUP_OP_CAST
13070
13071 case OpGroupNonUniformQuadSwap:
13072 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvQuadSwap");
13073 break;
13074
13075 case OpGroupNonUniformQuadBroadcast:
13076 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvQuadBroadcast");
13077 break;
13078
13079 default:
13080 SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
13081 }
13082
13083 register_control_dependent_expression(id);
13084 }
13085
13086 string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
13087 {
13088 if (out_type.basetype == in_type.basetype)
13089 return "";
13090
13091 assert(out_type.basetype != SPIRType::Boolean);
13092 assert(in_type.basetype != SPIRType::Boolean);
13093
13094 bool integral_cast = type_is_integral(out_type) && type_is_integral(in_type);
13095 bool same_size_cast = out_type.width == in_type.width;
13096
13097 if (integral_cast && same_size_cast)
13098 {
13099 // Trivial bitcast case, casts between integers.
13100 return type_to_glsl(out_type);
13101 }
13102 else
13103 {
13104 // Fall back to the catch-all bitcast in MSL.
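		// e.g. a float -> uint bitcast is emitted as "as_type<uint>(...)".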
13105 return "as_type<" + type_to_glsl(out_type) + ">";
13106 }
13107 }
13108
13109 bool CompilerMSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
13110 {
13111 return false;
13112 }
13113
13114 // Returns an MSL string identifying the name of a SPIR-V builtin.
13115 // Output builtins are qualified with the name of the stage out structure.
13116 string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
13117 {
13118 switch (builtin)
13119 {
13120
13121 // Handle HLSL-style 0-based vertex/instance index.
13122 // Override GLSL compiler strictness
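	// With base-index-zero active, a use site subtracts the base, e.g. it emits
	// "(gl_VertexID - gl_BaseVertex)", while a declaration site keeps the raw
	// name and requests the base vertex/instance argument instead.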
13123 case BuiltInVertexId:
13124 ensure_builtin(StorageClassInput, BuiltInVertexId);
13125 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13126 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13127 {
13128 if (builtin_declaration)
13129 {
13130 if (needs_base_vertex_arg != TriState::No)
13131 needs_base_vertex_arg = TriState::Yes;
13132 return "gl_VertexID";
13133 }
13134 else
13135 {
13136 ensure_builtin(StorageClassInput, BuiltInBaseVertex);
13137 return "(gl_VertexID - gl_BaseVertex)";
13138 }
13139 }
13140 else
13141 {
13142 return "gl_VertexID";
13143 }
13144 case BuiltInInstanceId:
13145 ensure_builtin(StorageClassInput, BuiltInInstanceId);
13146 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13147 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13148 {
13149 if (builtin_declaration)
13150 {
13151 if (needs_base_instance_arg != TriState::No)
13152 needs_base_instance_arg = TriState::Yes;
13153 return "gl_InstanceID";
13154 }
13155 else
13156 {
13157 ensure_builtin(StorageClassInput, BuiltInBaseInstance);
13158 return "(gl_InstanceID - gl_BaseInstance)";
13159 }
13160 }
13161 else
13162 {
13163 return "gl_InstanceID";
13164 }
13165 case BuiltInVertexIndex:
13166 ensure_builtin(StorageClassInput, BuiltInVertexIndex);
13167 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13168 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13169 {
13170 if (builtin_declaration)
13171 {
13172 if (needs_base_vertex_arg != TriState::No)
13173 needs_base_vertex_arg = TriState::Yes;
13174 return "gl_VertexIndex";
13175 }
13176 else
13177 {
13178 ensure_builtin(StorageClassInput, BuiltInBaseVertex);
13179 return "(gl_VertexIndex - gl_BaseVertex)";
13180 }
13181 }
13182 else
13183 {
13184 return "gl_VertexIndex";
13185 }
13186 case BuiltInInstanceIndex:
13187 ensure_builtin(StorageClassInput, BuiltInInstanceIndex);
13188 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13189 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13190 {
13191 if (builtin_declaration)
13192 {
13193 if (needs_base_instance_arg != TriState::No)
13194 needs_base_instance_arg = TriState::Yes;
13195 return "gl_InstanceIndex";
13196 }
13197 else
13198 {
13199 ensure_builtin(StorageClassInput, BuiltInBaseInstance);
13200 return "(gl_InstanceIndex - gl_BaseInstance)";
13201 }
13202 }
13203 else
13204 {
13205 return "gl_InstanceIndex";
13206 }
13207 case BuiltInBaseVertex:
13208 if (msl_options.supports_msl_version(1, 1) &&
13209 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13210 {
13211 needs_base_vertex_arg = TriState::No;
13212 return "gl_BaseVertex";
13213 }
13214 else
13215 {
13216 SPIRV_CROSS_THROW("BaseVertex requires Metal 1.1 and Mac or Apple A9+ hardware.");
13217 }
13218 case BuiltInBaseInstance:
13219 if (msl_options.supports_msl_version(1, 1) &&
13220 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13221 {
13222 needs_base_instance_arg = TriState::No;
13223 return "gl_BaseInstance";
13224 }
13225 else
13226 {
13227 SPIRV_CROSS_THROW("BaseInstance requires Metal 1.1 and Mac or Apple A9+ hardware.");
13228 }
13229 case BuiltInDrawIndex:
13230 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
13231
13232 	// When used in the entry function, output builtins are qualified with the output struct name.
13233 	// Test that the storage class is NOT Input, as output builtins might be part of a generic type.
13234 	// Also, don't do this for tessellation control shaders.
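	// For example, gl_Position inside the entry function is emitted as
	// "<stage_out_var_name>.gl_Position" ("out.gl_Position" with the default name).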
13235 case BuiltInViewportIndex:
13236 if (!msl_options.supports_msl_version(2, 0))
13237 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
13238 /* fallthrough */
13239 case BuiltInFragDepth:
13240 case BuiltInFragStencilRefEXT:
13241 if ((builtin == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
13242 (builtin == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))
13243 break;
13244 /* fallthrough */
13245 case BuiltInPosition:
13246 case BuiltInPointSize:
13247 case BuiltInClipDistance:
13248 case BuiltInCullDistance:
13249 case BuiltInLayer:
13250 case BuiltInSampleMask:
13251 if (get_execution_model() == ExecutionModelTessellationControl)
13252 break;
13253 if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
13254 return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
13255
13256 break;
13257
13258 case BuiltInBaryCoordNV:
13259 case BuiltInBaryCoordNoPerspNV:
13260 if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
13261 return stage_in_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
13262 break;
13263
13264 case BuiltInTessLevelOuter:
13265 if (get_execution_model() == ExecutionModelTessellationEvaluation)
13266 {
13267 if (storage != StorageClassOutput && !get_entry_point().flags.get(ExecutionModeTriangles) &&
13268 current_function && (current_function->self == ir.default_entry_point))
13269 return join(patch_stage_in_var_name, ".", CompilerGLSL::builtin_to_glsl(builtin, storage));
13270 else
13271 break;
13272 }
13273 if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
13274 return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
13275 "].edgeTessellationFactor");
13276 break;
13277
13278 case BuiltInTessLevelInner:
13279 if (get_execution_model() == ExecutionModelTessellationEvaluation)
13280 {
13281 if (storage != StorageClassOutput && !get_entry_point().flags.get(ExecutionModeTriangles) &&
13282 current_function && (current_function->self == ir.default_entry_point))
13283 return join(patch_stage_in_var_name, ".", CompilerGLSL::builtin_to_glsl(builtin, storage));
13284 else
13285 break;
13286 }
13287 if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
13288 return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
13289 "].insideTessellationFactor");
13290 break;
13291
13292 default:
13293 break;
13294 }
13295
13296 return CompilerGLSL::builtin_to_glsl(builtin, storage);
13297 }
13298
13299 // Returns an MSL string attribute qualifier for a SPIR-V builtin.
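// For example, BuiltInFragCoord yields "position" and BuiltInSampleId yields
// "sample_id"; built_in_func_arg() wraps the result in [[...]].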
13300 string CompilerMSL::builtin_qualifier(BuiltIn builtin)
13301 {
13302 auto &execution = get_entry_point();
13303
13304 switch (builtin)
13305 {
13306 // Vertex function in
13307 case BuiltInVertexId:
13308 return "vertex_id";
13309 case BuiltInVertexIndex:
13310 return "vertex_id";
13311 case BuiltInBaseVertex:
13312 return "base_vertex";
13313 case BuiltInInstanceId:
13314 return "instance_id";
13315 case BuiltInInstanceIndex:
13316 return "instance_id";
13317 case BuiltInBaseInstance:
13318 return "base_instance";
13319 case BuiltInDrawIndex:
13320 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
13321
13322 // Vertex function out
13323 case BuiltInClipDistance:
13324 return "clip_distance";
13325 case BuiltInPointSize:
13326 return "point_size";
13327 case BuiltInPosition:
13328 if (position_invariant)
13329 {
13330 if (!msl_options.supports_msl_version(2, 1))
13331 SPIRV_CROSS_THROW("Invariant position is only supported on MSL 2.1 and up.");
13332 return "position, invariant";
13333 }
13334 else
13335 return "position";
13336 case BuiltInLayer:
13337 return "render_target_array_index";
13338 case BuiltInViewportIndex:
13339 if (!msl_options.supports_msl_version(2, 0))
13340 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
13341 return "viewport_array_index";
13342
13343 // Tess. control function in
13344 case BuiltInInvocationId:
13345 if (msl_options.multi_patch_workgroup)
13346 {
13347 // Shouldn't be reached.
13348 SPIRV_CROSS_THROW("InvocationId is computed manually with multi-patch workgroups in MSL.");
13349 }
13350 return "thread_index_in_threadgroup";
13351 case BuiltInPatchVertices:
13352 // Shouldn't be reached.
13353 SPIRV_CROSS_THROW("PatchVertices is derived from the auxiliary buffer in MSL.");
13354 case BuiltInPrimitiveId:
13355 switch (execution.model)
13356 {
13357 case ExecutionModelTessellationControl:
13358 if (msl_options.multi_patch_workgroup)
13359 {
13360 // Shouldn't be reached.
13361 SPIRV_CROSS_THROW("PrimitiveId is computed manually with multi-patch workgroups in MSL.");
13362 }
13363 return "threadgroup_position_in_grid";
13364 case ExecutionModelTessellationEvaluation:
13365 return "patch_id";
13366 case ExecutionModelFragment:
13367 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
13368 SPIRV_CROSS_THROW("PrimitiveId on iOS requires MSL 2.3.");
13369 else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 2))
13370 SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2.");
13371 return "primitive_id";
13372 default:
13373 SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model.");
13374 }
13375
13376 // Tess. control function out
13377 case BuiltInTessLevelOuter:
13378 case BuiltInTessLevelInner:
13379 // Shouldn't be reached.
13380 SPIRV_CROSS_THROW("Tessellation levels are handled specially in MSL.");
13381
13382 // Tess. evaluation function in
13383 case BuiltInTessCoord:
13384 return "position_in_patch";
13385
13386 // Fragment function in
13387 case BuiltInFrontFacing:
13388 return "front_facing";
13389 case BuiltInPointCoord:
13390 return "point_coord";
13391 case BuiltInFragCoord:
13392 return "position";
13393 case BuiltInSampleId:
13394 return "sample_id";
13395 case BuiltInSampleMask:
13396 return "sample_mask";
13397 case BuiltInSamplePosition:
13398 // Shouldn't be reached.
13399 SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL.");
13400 case BuiltInViewIndex:
13401 if (execution.model != ExecutionModelFragment)
13402 SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders.");
13403 // The ViewIndex was implicitly used in the prior stages to set the render_target_array_index,
13404 // so we can get it from there.
13405 return "render_target_array_index";
13406
13407 // Fragment function out
13408 case BuiltInFragDepth:
13409 if (execution.flags.get(ExecutionModeDepthGreater))
13410 return "depth(greater)";
13411 else if (execution.flags.get(ExecutionModeDepthLess))
13412 return "depth(less)";
13413 else
13414 return "depth(any)";
13415
13416 case BuiltInFragStencilRefEXT:
13417 return "stencil";
13418
13419 // Compute function in
13420 case BuiltInGlobalInvocationId:
13421 return "thread_position_in_grid";
13422
13423 case BuiltInWorkgroupId:
13424 return "threadgroup_position_in_grid";
13425
13426 case BuiltInNumWorkgroups:
13427 return "threadgroups_per_grid";
13428
13429 case BuiltInLocalInvocationId:
13430 return "thread_position_in_threadgroup";
13431
13432 case BuiltInLocalInvocationIndex:
13433 return "thread_index_in_threadgroup";
13434
13435 case BuiltInSubgroupSize:
13436 if (msl_options.emulate_subgroups || msl_options.fixed_subgroup_size != 0)
13437 // Shouldn't be reached.
13438 SPIRV_CROSS_THROW("Emitting threads_per_simdgroup attribute with fixed subgroup size??");
13439 if (execution.model == ExecutionModelFragment)
13440 {
13441 if (!msl_options.supports_msl_version(2, 2))
13442 SPIRV_CROSS_THROW("threads_per_simdgroup requires Metal 2.2 in fragment shaders.");
13443 return "threads_per_simdgroup";
13444 }
13445 else
13446 {
13447 			// thread_execution_width is an alias for threads_per_simdgroup, and it has been
13448 			// available since MSL 1.0, but it is not available in fragment functions.
13449 return "thread_execution_width";
13450 }
13451
13452 case BuiltInNumSubgroups:
13453 if (msl_options.emulate_subgroups)
13454 // Shouldn't be reached.
13455 SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation.");
13456 if (!msl_options.supports_msl_version(2))
13457 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
13458 return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
13459
13460 case BuiltInSubgroupId:
13461 if (msl_options.emulate_subgroups)
13462 // Shouldn't be reached.
13463 SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation.");
13464 if (!msl_options.supports_msl_version(2))
13465 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
13466 return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
13467
13468 case BuiltInSubgroupLocalInvocationId:
13469 if (msl_options.emulate_subgroups)
13470 // Shouldn't be reached.
13471 SPIRV_CROSS_THROW("SubgroupLocalInvocationId is handled specially with emulation.");
13472 if (execution.model == ExecutionModelFragment)
13473 {
13474 if (!msl_options.supports_msl_version(2, 2))
13475 SPIRV_CROSS_THROW("thread_index_in_simdgroup requires Metal 2.2 in fragment shaders.");
13476 return "thread_index_in_simdgroup";
13477 }
13478 else if (execution.model == ExecutionModelKernel || execution.model == ExecutionModelGLCompute ||
13479 execution.model == ExecutionModelTessellationControl ||
13480 (execution.model == ExecutionModelVertex && msl_options.vertex_for_tessellation))
13481 {
13482 // We are generating a Metal kernel function.
13483 if (!msl_options.supports_msl_version(2))
13484 SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0.");
13485 return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
13486 }
13487 else
13488 SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function.");
13489
13490 case BuiltInSubgroupEqMask:
13491 case BuiltInSubgroupGeMask:
13492 case BuiltInSubgroupGtMask:
13493 case BuiltInSubgroupLeMask:
13494 case BuiltInSubgroupLtMask:
13495 // Shouldn't be reached.
13496 SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL.");
13497
13498 case BuiltInBaryCoordNV:
13499 // TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
13500 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
13501 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS.");
13502 else if (!msl_options.supports_msl_version(2, 2))
13503 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
13504 return "barycentric_coord, center_perspective";
13505
13506 case BuiltInBaryCoordNoPerspNV:
13507 // TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
13508 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
13509 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS.");
13510 else if (!msl_options.supports_msl_version(2, 2))
13511 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
13512 return "barycentric_coord, center_no_perspective";
13513
13514 default:
13515 return "unsupported-built-in";
13516 }
13517 }
13518
13519 // Returns an MSL string type declaration for a SPIR-V builtin
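// For example, BuiltInFragCoord is declared as "float4", and the subgroup
// ballot masks as "uint4".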
13520 string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id)
13521 {
13522 const SPIREntryPoint &execution = get_entry_point();
13523 switch (builtin)
13524 {
13525 // Vertex function in
13526 case BuiltInVertexId:
13527 return "uint";
13528 case BuiltInVertexIndex:
13529 return "uint";
13530 case BuiltInBaseVertex:
13531 return "uint";
13532 case BuiltInInstanceId:
13533 return "uint";
13534 case BuiltInInstanceIndex:
13535 return "uint";
13536 case BuiltInBaseInstance:
13537 return "uint";
13538 case BuiltInDrawIndex:
13539 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
13540
13541 // Vertex function out
13542 case BuiltInClipDistance:
13543 return "float";
13544 case BuiltInPointSize:
13545 return "float";
13546 case BuiltInPosition:
13547 return "float4";
13548 case BuiltInLayer:
13549 return "uint";
13550 case BuiltInViewportIndex:
13551 if (!msl_options.supports_msl_version(2, 0))
13552 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
13553 return "uint";
13554
13555 // Tess. control function in
13556 case BuiltInInvocationId:
13557 return "uint";
13558 case BuiltInPatchVertices:
13559 return "uint";
13560 case BuiltInPrimitiveId:
13561 return "uint";
13562
13563 // Tess. control function out
13564 case BuiltInTessLevelInner:
13565 if (execution.model == ExecutionModelTessellationEvaluation)
13566 return !execution.flags.get(ExecutionModeTriangles) ? "float2" : "float";
13567 return "half";
13568 case BuiltInTessLevelOuter:
13569 if (execution.model == ExecutionModelTessellationEvaluation)
13570 return !execution.flags.get(ExecutionModeTriangles) ? "float4" : "float";
13571 return "half";
13572
13573 // Tess. evaluation function in
13574 case BuiltInTessCoord:
13575 return execution.flags.get(ExecutionModeTriangles) ? "float3" : "float2";
13576
13577 // Fragment function in
13578 case BuiltInFrontFacing:
13579 return "bool";
13580 case BuiltInPointCoord:
13581 return "float2";
13582 case BuiltInFragCoord:
13583 return "float4";
13584 case BuiltInSampleId:
13585 return "uint";
13586 case BuiltInSampleMask:
13587 return "uint";
13588 case BuiltInSamplePosition:
13589 return "float2";
13590 case BuiltInViewIndex:
13591 return "uint";
13592
13593 case BuiltInHelperInvocation:
13594 return "bool";
13595
13596 case BuiltInBaryCoordNV:
13597 case BuiltInBaryCoordNoPerspNV:
13598 		// Use the type as declared; it can have 1, 2 or 3 components.
13599 return type_to_glsl(get_variable_data_type(get<SPIRVariable>(id)));
13600
13601 // Fragment function out
13602 case BuiltInFragDepth:
13603 return "float";
13604
13605 case BuiltInFragStencilRefEXT:
13606 return "uint";
13607
13608 // Compute function in
13609 case BuiltInGlobalInvocationId:
13610 case BuiltInLocalInvocationId:
13611 case BuiltInNumWorkgroups:
13612 case BuiltInWorkgroupId:
13613 return "uint3";
13614 case BuiltInLocalInvocationIndex:
13615 case BuiltInNumSubgroups:
13616 case BuiltInSubgroupId:
13617 case BuiltInSubgroupSize:
13618 case BuiltInSubgroupLocalInvocationId:
13619 return "uint";
13620 case BuiltInSubgroupEqMask:
13621 case BuiltInSubgroupGeMask:
13622 case BuiltInSubgroupGtMask:
13623 case BuiltInSubgroupLeMask:
13624 case BuiltInSubgroupLtMask:
13625 return "uint4";
13626
13627 case BuiltInDeviceIndex:
13628 return "int";
13629
13630 default:
13631 return "unsupported-built-in-type";
13632 }
13633 }
13634
13635 // Returns the declaration of a built-in argument to a function
13636 string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma)
13637 {
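	// The result reads like ", uint gl_InstanceIndex [[instance_id]]" when a
	// prefix comma is requested.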
13638 string bi_arg;
13639 if (prefix_comma)
13640 bi_arg += ", ";
13641
13642 // Handle HLSL-style 0-based vertex/instance index.
13643 builtin_declaration = true;
13644 bi_arg += builtin_type_decl(builtin);
13645 bi_arg += " " + builtin_to_glsl(builtin, StorageClassInput);
13646 bi_arg += " [[" + builtin_qualifier(builtin) + "]]";
13647 builtin_declaration = false;
13648
13649 return bi_arg;
13650 }
13651
13652 const SPIRType &CompilerMSL::get_physical_member_type(const SPIRType &type, uint32_t index) const
13653 {
13654 if (member_is_remapped_physical_type(type, index))
13655 return get<SPIRType>(get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID));
13656 else
13657 return get<SPIRType>(type.member_types[index]);
13658 }
13659
13660 SPIRType CompilerMSL::get_presumed_input_type(const SPIRType &ib_type, uint32_t index) const
13661 {
13662 SPIRType type = get_physical_member_type(ib_type, index);
13663 uint32_t loc = get_member_decoration(ib_type.self, index, DecorationLocation);
13664 if (inputs_by_location.count(loc))
13665 {
13666 if (inputs_by_location.at(loc).vecsize > type.vecsize)
13667 type.vecsize = inputs_by_location.at(loc).vecsize;
13668 }
13669 return type;
13670 }
13671
13672 uint32_t CompilerMSL::get_declared_type_array_stride_msl(const SPIRType &type, bool is_packed, bool row_major) const
13673 {
13674 // Array stride in MSL is always size * array_size. sizeof(float3) == 16,
13675 // unlike GLSL and HLSL where array stride would be 16 and size 12.
13676
13677 // We could use parent type here and recurse, but that makes creating physical type remappings
13678 // far more complicated. We'd rather just create the final type, and ignore having to create the entire type
13679 // hierarchy in order to compute this value, so make a temporary type on the stack.
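	// For example, an array of float3 has a stride of 16 bytes per element
	// (MSL's sizeof(float3)); for a multi-dimensional array, every dimension
	// except the outermost one is folded into that stride as well.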
13680
13681 auto basic_type = type;
13682 basic_type.array.clear();
13683 basic_type.array_size_literal.clear();
13684 uint32_t value_size = get_declared_type_size_msl(basic_type, is_packed, row_major);
13685
13686 uint32_t dimensions = uint32_t(type.array.size());
13687 assert(dimensions > 0);
13688 dimensions--;
13689
13690 // Multiply together every dimension, except the last one.
13691 for (uint32_t dim = 0; dim < dimensions; dim++)
13692 {
13693 uint32_t array_size = to_array_size_literal(type, dim);
13694 value_size *= max(array_size, 1u);
13695 }
13696
13697 return value_size;
13698 }
13699
13700 uint32_t CompilerMSL::get_declared_struct_member_array_stride_msl(const SPIRType &type, uint32_t index) const
13701 {
13702 return get_declared_type_array_stride_msl(get_physical_member_type(type, index),
13703 member_is_packed_physical_type(type, index),
13704 has_member_decoration(type.self, index, DecorationRowMajor));
13705 }
13706
13707 uint32_t CompilerMSL::get_declared_input_array_stride_msl(const SPIRType &type, uint32_t index) const
13708 {
13709 return get_declared_type_array_stride_msl(get_presumed_input_type(type, index), false,
13710 has_member_decoration(type.self, index, DecorationRowMajor));
13711 }
13712
13713 uint32_t CompilerMSL::get_declared_type_matrix_stride_msl(const SPIRType &type, bool packed, bool row_major) const
13714 {
13715 // For packed matrices, we just use the size of the vector type.
13716 // Otherwise, MatrixStride == alignment, which is the size of the underlying vector type.
13717 if (packed)
13718 return (type.width / 8) * ((row_major && type.columns > 1) ? type.columns : type.vecsize);
13719 else
13720 return get_declared_type_alignment_msl(type, false, row_major);
13721 }
13722
13723 uint32_t CompilerMSL::get_declared_struct_member_matrix_stride_msl(const SPIRType &type, uint32_t index) const
13724 {
13725 return get_declared_type_matrix_stride_msl(get_physical_member_type(type, index),
13726 member_is_packed_physical_type(type, index),
13727 has_member_decoration(type.self, index, DecorationRowMajor));
13728 }
13729
13730 uint32_t CompilerMSL::get_declared_input_matrix_stride_msl(const SPIRType &type, uint32_t index) const
13731 {
13732 return get_declared_type_matrix_stride_msl(get_presumed_input_type(type, index), false,
13733 has_member_decoration(type.self, index, DecorationRowMajor));
13734 }
13735
13736 uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type, bool ignore_alignment,
13737 bool ignore_padding) const
13738 {
13739 // If we have a target size, that is the declared size as well.
13740 if (!ignore_padding && has_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget))
13741 return get_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget);
13742
13743 if (struct_type.member_types.empty())
13744 return 0;
13745
13746 uint32_t mbr_cnt = uint32_t(struct_type.member_types.size());
13747
13748 // In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
13749 uint32_t alignment = 1;
13750
13751 if (!ignore_alignment)
13752 {
13753 for (uint32_t i = 0; i < mbr_cnt; i++)
13754 {
13755 uint32_t mbr_alignment = get_declared_struct_member_alignment_msl(struct_type, i);
13756 alignment = max(alignment, mbr_alignment);
13757 }
13758 }
13759
13760 	// The last member will always be matched to the final Offset decoration, but the size of the struct in MSL
13761 	// depends on its physical size in MSL, and the size of the struct itself is then rounded up to the struct alignment.
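	// For example, if the last member ends at byte offset 24 and the widest member
	// needs 16-byte alignment, the declared struct size rounds up to 32 bytes.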
13762 uint32_t spirv_offset = type_struct_member_offset(struct_type, mbr_cnt - 1);
13763 uint32_t msl_size = spirv_offset + get_declared_struct_member_size_msl(struct_type, mbr_cnt - 1);
13764 msl_size = (msl_size + alignment - 1) & ~(alignment - 1);
13765 return msl_size;
13766 }
13767
13768 // Returns the byte size of a type, as declared for struct members.
13769 uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
13770 {
13771 switch (type.basetype)
13772 {
13773 case SPIRType::Unknown:
13774 case SPIRType::Void:
13775 case SPIRType::AtomicCounter:
13776 case SPIRType::Image:
13777 case SPIRType::SampledImage:
13778 case SPIRType::Sampler:
13779 SPIRV_CROSS_THROW("Querying size of opaque object.");
13780
13781 default:
13782 {
13783 if (!type.array.empty())
13784 {
13785 uint32_t array_size = to_array_size_literal(type);
13786 return get_declared_type_array_stride_msl(type, is_packed, row_major) * max(array_size, 1u);
13787 }
13788
13789 if (type.basetype == SPIRType::Struct)
13790 return get_declared_struct_size_msl(type);
13791
13792 if (is_packed)
13793 {
13794 return type.vecsize * type.columns * (type.width / 8);
13795 }
13796 else
13797 {
13798 // An unpacked 3-element vector or matrix column is the same memory size as a 4-element.
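			// e.g. an unpacked float3 reports 16 bytes here, while the packed
			// branch above reports 12 for packed_float3.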
13799 uint32_t vecsize = type.vecsize;
13800 uint32_t columns = type.columns;
13801
13802 if (row_major && columns > 1)
13803 swap(vecsize, columns);
13804
13805 if (vecsize == 3)
13806 vecsize = 4;
13807
13808 return vecsize * columns * (type.width / 8);
13809 }
13810 }
13811 }
13812 }
13813
13814 uint32_t CompilerMSL::get_declared_struct_member_size_msl(const SPIRType &type, uint32_t index) const
13815 {
13816 return get_declared_type_size_msl(get_physical_member_type(type, index),
13817 member_is_packed_physical_type(type, index),
13818 has_member_decoration(type.self, index, DecorationRowMajor));
13819 }
13820
13821 uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t index) const
13822 {
13823 return get_declared_type_size_msl(get_presumed_input_type(type, index), false,
13824 has_member_decoration(type.self, index, DecorationRowMajor));
13825 }
13826
13827 // Returns the byte alignment of a type.
13828 uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const
13829 {
13830 switch (type.basetype)
13831 {
13832 case SPIRType::Unknown:
13833 case SPIRType::Void:
13834 case SPIRType::AtomicCounter:
13835 case SPIRType::Image:
13836 case SPIRType::SampledImage:
13837 case SPIRType::Sampler:
13838 SPIRV_CROSS_THROW("Querying alignment of opaque object.");
13839
13840 case SPIRType::Int64:
13841 SPIRV_CROSS_THROW("long types are not supported in buffers in MSL.");
13842 case SPIRType::UInt64:
13843 SPIRV_CROSS_THROW("ulong types are not supported in buffers in MSL.");
13844 case SPIRType::Double:
13845 SPIRV_CROSS_THROW("double types are not supported in buffers in MSL.");
13846
13847 case SPIRType::Struct:
13848 {
13849 // In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
13850 uint32_t alignment = 1;
13851 for (uint32_t i = 0; i < type.member_types.size(); i++)
13852 alignment = max(alignment, uint32_t(get_declared_struct_member_alignment_msl(type, i)));
13853 return alignment;
13854 }
13855
13856 default:
13857 {
13858 		// Alignment of a packed type is the same as the underlying component or column size.
13859 		// Alignment of an unpacked type is the same as the vector size.
13860 		// Alignment of a 3-element vector is the same as that of a 4-element one (including packed, via the column).
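		// e.g. float3 aligns to 16 bytes just like float4, while packed_float3
		// aligns to 4 bytes (its scalar width).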
13861 if (is_packed)
13862 {
13863 // If we have packed_T and friends, the alignment is always scalar.
13864 return type.width / 8;
13865 }
13866 else
13867 {
13868 // This is the general rule for MSL. Size == alignment.
13869 uint32_t vecsize = (row_major && type.columns > 1) ? type.columns : type.vecsize;
13870 return (type.width / 8) * (vecsize == 3 ? 4 : vecsize);
13871 }
13872 }
13873 }
13874 }
13875
13876 uint32_t CompilerMSL::get_declared_struct_member_alignment_msl(const SPIRType &type, uint32_t index) const
13877 {
13878 return get_declared_type_alignment_msl(get_physical_member_type(type, index),
13879 member_is_packed_physical_type(type, index),
13880 has_member_decoration(type.self, index, DecorationRowMajor));
13881 }
13882
13883 uint32_t CompilerMSL::get_declared_input_alignment_msl(const SPIRType &type, uint32_t index) const
13884 {
13885 return get_declared_type_alignment_msl(get_presumed_input_type(type, index), false,
13886 has_member_decoration(type.self, index, DecorationRowMajor));
13887 }
13888
13889 bool CompilerMSL::skip_argument(uint32_t) const
13890 {
13891 return false;
13892 }
13893
13894 void CompilerMSL::analyze_sampled_image_usage()
13895 {
13896 if (msl_options.swizzle_texture_samples)
13897 {
13898 SampledImageScanner scanner(*this);
13899 traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), scanner);
13900 }
13901 }
13902
13903 bool CompilerMSL::SampledImageScanner::handle(spv::Op opcode, const uint32_t *args, uint32_t length)
13904 {
13905 switch (opcode)
13906 {
13907 case OpLoad:
13908 case OpImage:
13909 case OpSampledImage:
13910 {
13911 if (length < 3)
13912 return false;
13913
13914 uint32_t result_type = args[0];
13915 auto &type = compiler.get<SPIRType>(result_type);
13916 if ((type.basetype != SPIRType::Image && type.basetype != SPIRType::SampledImage) || type.image.sampled != 1)
13917 return true;
13918
13919 uint32_t id = args[1];
13920 compiler.set<SPIRExpression>(id, "", result_type, true);
13921 break;
13922 }
13923 case OpImageSampleExplicitLod:
13924 case OpImageSampleProjExplicitLod:
13925 case OpImageSampleDrefExplicitLod:
13926 case OpImageSampleProjDrefExplicitLod:
13927 case OpImageSampleImplicitLod:
13928 case OpImageSampleProjImplicitLod:
13929 case OpImageSampleDrefImplicitLod:
13930 case OpImageSampleProjDrefImplicitLod:
13931 case OpImageFetch:
13932 case OpImageGather:
13933 case OpImageDrefGather:
13934 compiler.has_sampled_images =
13935 compiler.has_sampled_images || compiler.is_sampled_image_type(compiler.expression_type(args[2]));
13936 compiler.needs_swizzle_buffer_def = compiler.needs_swizzle_buffer_def || compiler.has_sampled_images;
13937 break;
13938 default:
13939 break;
13940 }
13941 return true;
13942 }
13943
13944 // If a needed custom function wasn't added before, add it and force a recompile.
13945 void CompilerMSL::add_spv_func_and_recompile(SPVFuncImpl spv_func)
13946 {
13947 if (spv_function_implementations.count(spv_func) == 0)
13948 {
13949 spv_function_implementations.insert(spv_func);
13950 suppress_missing_prototypes = true;
13951 force_recompile();
13952 }
13953 }
13954
13955 bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, uint32_t length)
13956 {
13957 // Since MSL exists in a single execution scope, function prototype declarations are not
13958 // needed, and clutter the output. If secondary functions are output (either as a SPIR-V
13959 // function implementation or as indicated by the presence of OpFunctionCall), then set
13960 // suppress_missing_prototypes to suppress compiler warnings of missing function prototypes.
13961
13962 	// Mark if the input requires the implementation of a SPIR-V function that does not exist in Metal.
13963 SPVFuncImpl spv_func = get_spv_func_impl(opcode, args);
13964 if (spv_func != SPVFuncImplNone)
13965 {
13966 compiler.spv_function_implementations.insert(spv_func);
13967 suppress_missing_prototypes = true;
13968 }
13969
13970 switch (opcode)
13971 {
13972
13973 case OpFunctionCall:
13974 suppress_missing_prototypes = true;
13975 break;
13976
13977 // Emulate texture2D atomic operations
13978 case OpImageTexelPointer:
13979 {
13980 auto *var = compiler.maybe_get_backing_variable(args[2]);
13981 image_pointers[args[1]] = var ? var->self : ID(0);
13982 break;
13983 }
13984
13985 case OpImageWrite:
13986 if (!compiler.msl_options.supports_msl_version(2, 2))
13987 uses_resource_write = true;
13988 break;
13989
13990 case OpStore:
13991 check_resource_write(args[0]);
13992 break;
13993
13994 // Emulate texture2D atomic operations
13995 case OpAtomicExchange:
13996 case OpAtomicCompareExchange:
13997 case OpAtomicCompareExchangeWeak:
13998 case OpAtomicIIncrement:
13999 case OpAtomicIDecrement:
14000 case OpAtomicIAdd:
14001 case OpAtomicISub:
14002 case OpAtomicSMin:
14003 case OpAtomicUMin:
14004 case OpAtomicSMax:
14005 case OpAtomicUMax:
14006 case OpAtomicAnd:
14007 case OpAtomicOr:
14008 case OpAtomicXor:
14009 {
14010 uses_atomics = true;
14011 auto it = image_pointers.find(args[2]);
14012 if (it != image_pointers.end())
14013 {
14014 compiler.atomic_image_vars.insert(it->second);
14015 }
14016 check_resource_write(args[2]);
14017 break;
14018 }
14019
14020 case OpAtomicStore:
14021 {
14022 uses_atomics = true;
14023 auto it = image_pointers.find(args[0]);
14024 if (it != image_pointers.end())
14025 {
14026 compiler.atomic_image_vars.insert(it->second);
14027 }
14028 check_resource_write(args[0]);
14029 break;
14030 }
14031
14032 case OpAtomicLoad:
14033 {
14034 uses_atomics = true;
14035 auto it = image_pointers.find(args[2]);
14036 if (it != image_pointers.end())
14037 {
14038 compiler.atomic_image_vars.insert(it->second);
14039 }
14040 break;
14041 }
14042
14043 case OpGroupNonUniformInverseBallot:
14044 needs_subgroup_invocation_id = true;
14045 break;
14046
14047 case OpGroupNonUniformBallotFindLSB:
14048 case OpGroupNonUniformBallotFindMSB:
14049 needs_subgroup_size = true;
14050 break;
14051
14052 case OpGroupNonUniformBallotBitCount:
14053 if (args[3] == GroupOperationReduce)
14054 needs_subgroup_size = true;
14055 else
14056 needs_subgroup_invocation_id = true;
14057 break;
14058
14059 case OpArrayLength:
14060 {
14061 auto *var = compiler.maybe_get_backing_variable(args[2]);
14062 if (var)
14063 compiler.buffers_requiring_array_length.insert(var->self);
14064 break;
14065 }
14066
14067 case OpInBoundsAccessChain:
14068 case OpAccessChain:
14069 case OpPtrAccessChain:
14070 {
14071 		// OpArrayLength might want to know whether we are taking the ArrayLength of an array of SSBOs.
14072 uint32_t result_type = args[0];
14073 uint32_t id = args[1];
14074 uint32_t ptr = args[2];
14075
14076 compiler.set<SPIRExpression>(id, "", result_type, true);
14077 compiler.register_read(id, ptr, true);
14078 compiler.ir.ids[id].set_allow_type_rewrite();
14079 break;
14080 }
14081
14082 case OpExtInst:
14083 {
14084 uint32_t extension_set = args[2];
14085 if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
14086 {
14087 auto op_450 = static_cast<GLSLstd450>(args[3]);
14088 switch (op_450)
14089 {
14090 case GLSLstd450InterpolateAtCentroid:
14091 case GLSLstd450InterpolateAtSample:
14092 case GLSLstd450InterpolateAtOffset:
14093 {
14094 if (!compiler.msl_options.supports_msl_version(2, 3))
14095 SPIRV_CROSS_THROW("Pull-model interpolation requires MSL 2.3.");
14096 // Fragment varyings used with pull-model interpolation need special handling,
14097 // due to the way pull-model interpolation works in Metal.
14098 auto *var = compiler.maybe_get_backing_variable(args[4]);
14099 if (var)
14100 {
14101 compiler.pull_model_inputs.insert(var->self);
14102 auto &var_type = compiler.get_variable_element_type(*var);
14103 // In addition, if this variable has a 'Sample' decoration, we need the sample ID
14104 // in order to do default interpolation.
14105 if (compiler.has_decoration(var->self, DecorationSample))
14106 {
14107 needs_sample_id = true;
14108 }
14109 else if (var_type.basetype == SPIRType::Struct)
14110 {
14111 // Now we need to check each member and see if it has this decoration.
14112 for (uint32_t i = 0; i < var_type.member_types.size(); ++i)
14113 {
14114 if (compiler.has_member_decoration(var_type.self, i, DecorationSample))
14115 {
14116 needs_sample_id = true;
14117 break;
14118 }
14119 }
14120 }
14121 }
14122 break;
14123 }
14124 default:
14125 break;
14126 }
14127 }
14128 break;
14129 }
14130
14131 default:
14132 break;
14133 }
14134
14135 // If it has one, keep track of the instruction's result type, mapped by ID
14136 uint32_t result_type, result_id;
14137 if (compiler.instruction_to_result_type(result_type, result_id, opcode, args, length))
14138 result_types[result_id] = result_type;
14139
14140 return true;
14141 }
14142
14143 // If the variable is a Uniform or StorageBuffer, mark that a resource has been written to.
14144 void CompilerMSL::OpCodePreprocessor::check_resource_write(uint32_t var_id)
14145 {
14146 auto *p_var = compiler.maybe_get_backing_variable(var_id);
14147 StorageClass sc = p_var ? p_var->storage : StorageClassMax;
14148 if (!compiler.msl_options.supports_msl_version(2, 1) &&
14149 (sc == StorageClassUniform || sc == StorageClassStorageBuffer))
14150 uses_resource_write = true;
14151 }
14152
14153 // Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes.
14154 CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op opcode, const uint32_t *args)
14155 {
14156 switch (opcode)
14157 {
14158 case OpFMod:
14159 return SPVFuncImplMod;
14160
14161 case OpFAdd:
14162 if (compiler.msl_options.invariant_float_math)
14163 {
14164 return SPVFuncImplFAdd;
14165 }
14166 break;
14167
14168 case OpFMul:
14169 case OpOuterProduct:
14170 case OpMatrixTimesVector:
14171 case OpVectorTimesMatrix:
14172 case OpMatrixTimesMatrix:
14173 if (compiler.msl_options.invariant_float_math)
14174 {
14175 return SPVFuncImplFMul;
14176 }
14177 break;
14178
14179 case OpTypeArray:
14180 {
14181 // Allow Metal to use the array<T> template to make arrays a value type
14182 return SPVFuncImplUnsafeArray;
14183 }
14184
14185 // Emulate texture2D atomic operations
14186 case OpAtomicExchange:
14187 case OpAtomicCompareExchange:
14188 case OpAtomicCompareExchangeWeak:
14189 case OpAtomicIIncrement:
14190 case OpAtomicIDecrement:
14191 case OpAtomicIAdd:
14192 case OpAtomicISub:
14193 case OpAtomicSMin:
14194 case OpAtomicUMin:
14195 case OpAtomicSMax:
14196 case OpAtomicUMax:
14197 case OpAtomicAnd:
14198 case OpAtomicOr:
14199 case OpAtomicXor:
14200 case OpAtomicLoad:
14201 case OpAtomicStore:
14202 {
14203 auto it = image_pointers.find(args[opcode == OpAtomicStore ? 0 : 2]);
14204 if (it != image_pointers.end())
14205 {
14206 uint32_t tid = compiler.get<SPIRVariable>(it->second).basetype;
14207 if (tid && compiler.get<SPIRType>(tid).image.dim == Dim2D)
14208 return SPVFuncImplImage2DAtomicCoords;
14209 }
14210 break;
14211 }
14212
14213 case OpImageFetch:
14214 case OpImageRead:
14215 case OpImageWrite:
14216 {
14217 // Retrieve the image type, and if it's a Buffer, emit a texel coordinate function
14218 uint32_t tid = result_types[args[opcode == OpImageWrite ? 0 : 2]];
14219 if (tid && compiler.get<SPIRType>(tid).image.dim == DimBuffer && !compiler.msl_options.texture_buffer_native)
14220 return SPVFuncImplTexelBufferCoords;
14221 break;
14222 }
14223
14224 case OpExtInst:
14225 {
14226 uint32_t extension_set = args[2];
14227 if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
14228 {
14229 auto op_450 = static_cast<GLSLstd450>(args[3]);
14230 switch (op_450)
14231 {
14232 case GLSLstd450Radians:
14233 return SPVFuncImplRadians;
14234 case GLSLstd450Degrees:
14235 return SPVFuncImplDegrees;
14236 case GLSLstd450FindILsb:
14237 return SPVFuncImplFindILsb;
14238 case GLSLstd450FindSMsb:
14239 return SPVFuncImplFindSMsb;
14240 case GLSLstd450FindUMsb:
14241 return SPVFuncImplFindUMsb;
14242 case GLSLstd450SSign:
14243 return SPVFuncImplSSign;
14244 case GLSLstd450Reflect:
14245 {
14246 auto &type = compiler.get<SPIRType>(args[0]);
14247 if (type.vecsize == 1)
14248 return SPVFuncImplReflectScalar;
14249 break;
14250 }
14251 case GLSLstd450Refract:
14252 {
14253 auto &type = compiler.get<SPIRType>(args[0]);
14254 if (type.vecsize == 1)
14255 return SPVFuncImplRefractScalar;
14256 break;
14257 }
14258 case GLSLstd450FaceForward:
14259 {
14260 auto &type = compiler.get<SPIRType>(args[0]);
14261 if (type.vecsize == 1)
14262 return SPVFuncImplFaceForwardScalar;
14263 break;
14264 }
14265 case GLSLstd450MatrixInverse:
14266 {
14267 auto &mat_type = compiler.get<SPIRType>(args[0]);
14268 switch (mat_type.columns)
14269 {
14270 case 2:
14271 return SPVFuncImplInverse2x2;
14272 case 3:
14273 return SPVFuncImplInverse3x3;
14274 case 4:
14275 return SPVFuncImplInverse4x4;
14276 default:
14277 break;
14278 }
14279 break;
14280 }
14281 default:
14282 break;
14283 }
14284 }
14285 break;
14286 }
14287
14288 case OpGroupNonUniformBroadcast:
14289 return SPVFuncImplSubgroupBroadcast;
14290
14291 case OpGroupNonUniformBroadcastFirst:
14292 return SPVFuncImplSubgroupBroadcastFirst;
14293
14294 case OpGroupNonUniformBallot:
14295 return SPVFuncImplSubgroupBallot;
14296
14297 case OpGroupNonUniformInverseBallot:
14298 case OpGroupNonUniformBallotBitExtract:
14299 return SPVFuncImplSubgroupBallotBitExtract;
14300
14301 case OpGroupNonUniformBallotFindLSB:
14302 return SPVFuncImplSubgroupBallotFindLSB;
14303
14304 case OpGroupNonUniformBallotFindMSB:
14305 return SPVFuncImplSubgroupBallotFindMSB;
14306
14307 case OpGroupNonUniformBallotBitCount:
14308 return SPVFuncImplSubgroupBallotBitCount;
14309
14310 case OpGroupNonUniformAllEqual:
14311 return SPVFuncImplSubgroupAllEqual;
14312
14313 case OpGroupNonUniformShuffle:
14314 return SPVFuncImplSubgroupShuffle;
14315
14316 case OpGroupNonUniformShuffleXor:
14317 return SPVFuncImplSubgroupShuffleXor;
14318
14319 case OpGroupNonUniformShuffleUp:
14320 return SPVFuncImplSubgroupShuffleUp;
14321
14322 case OpGroupNonUniformShuffleDown:
14323 return SPVFuncImplSubgroupShuffleDown;
14324
14325 case OpGroupNonUniformQuadBroadcast:
14326 return SPVFuncImplQuadBroadcast;
14327
14328 case OpGroupNonUniformQuadSwap:
14329 return SPVFuncImplQuadSwap;
14330
14331 default:
14332 break;
14333 }
14334 return SPVFuncImplNone;
14335 }
14336
14337 // Sort both type and meta member content based on builtin status (put builtins at end),
14338 // then by the required sorting aspect.
14339 void CompilerMSL::MemberSorter::sort()
14340 {
14341 // Create a temporary array of consecutive member indices and sort it based on how
14342 // the members should be reordered, based on builtin and sorting aspect meta info.
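	// e.g. if members {A, B, C} sort to index order {2, 0, 1}, the copy-back below
	// places C in slot 0, A in slot 1 and B in slot 2.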
14343 size_t mbr_cnt = type.member_types.size();
14344 SmallVector<uint32_t> mbr_idxs(mbr_cnt);
14345 std::iota(mbr_idxs.begin(), mbr_idxs.end(), 0); // Fill with consecutive indices
14346 std::stable_sort(mbr_idxs.begin(), mbr_idxs.end(), *this); // Sort member indices based on sorting aspect
14347
14348 bool sort_is_identity = true;
14349 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
14350 {
14351 if (mbr_idx != mbr_idxs[mbr_idx])
14352 {
14353 sort_is_identity = false;
14354 break;
14355 }
14356 }
14357
14358 if (sort_is_identity)
14359 return;
14360
14361 if (meta.members.size() < type.member_types.size())
14362 {
14363 // This should never trigger in normal circumstances, but to be safe.
14364 meta.members.resize(type.member_types.size());
14365 }
14366
14367 // Move type and meta member info to the order defined by the sorted member indices.
14368 // This is done by creating temporary copies of both member types and meta, and then
14369 // copying back to the original content at the sorted indices.
14370 auto mbr_types_cpy = type.member_types;
14371 auto mbr_meta_cpy = meta.members;
14372 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
14373 {
14374 type.member_types[mbr_idx] = mbr_types_cpy[mbr_idxs[mbr_idx]];
14375 meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]];
14376 }
14377
14378 if (sort_aspect == SortAspect::Offset)
14379 {
14380 // If we're sorting by Offset, this might affect user code which accesses a buffer block.
14381 // We will need to redirect member indices from one index to sorted index.
14382 type.member_type_index_redirection = std::move(mbr_idxs);
14383 }
14384 }
14385
14386 // Sort first by builtin status (put builtins at end), then by the sorting aspect.
14387 bool CompilerMSL::MemberSorter::operator()(uint32_t mbr_idx1, uint32_t mbr_idx2)
14388 {
14389 auto &mbr_meta1 = meta.members[mbr_idx1];
14390 auto &mbr_meta2 = meta.members[mbr_idx2];
14391 if (mbr_meta1.builtin != mbr_meta2.builtin)
14392 return mbr_meta2.builtin;
14393 else
14394 switch (sort_aspect)
14395 {
14396 case Location:
14397 return mbr_meta1.location < mbr_meta2.location;
14398 case LocationReverse:
14399 return mbr_meta1.location > mbr_meta2.location;
14400 case Offset:
14401 return mbr_meta1.offset < mbr_meta2.offset;
14402 case OffsetThenLocationReverse:
14403 return (mbr_meta1.offset < mbr_meta2.offset) ||
14404 ((mbr_meta1.offset == mbr_meta2.offset) && (mbr_meta1.location > mbr_meta2.location));
14405 case Alphabetical:
14406 return mbr_meta1.alias < mbr_meta2.alias;
14407 default:
14408 return false;
14409 }
14410 }
14411
14412 CompilerMSL::MemberSorter::MemberSorter(SPIRType &t, Meta &m, SortAspect sa)
14413 : type(t)
14414 , meta(m)
14415 , sort_aspect(sa)
14416 {
14417 // Ensure enough meta info is available
14418 meta.members.resize(max(type.member_types.size(), meta.members.size()));
14419 }
14420
14421 void CompilerMSL::remap_constexpr_sampler(VariableID id, const MSLConstexprSampler &sampler)
14422 {
14423 auto &type = get<SPIRType>(get<SPIRVariable>(id).basetype);
14424 if (type.basetype != SPIRType::SampledImage && type.basetype != SPIRType::Sampler)
14425 SPIRV_CROSS_THROW("Can only remap SampledImage and Sampler type.");
14426 if (!type.array.empty())
14427 SPIRV_CROSS_THROW("Can not remap array of samplers.");
14428 constexpr_samplers_by_id[id] = sampler;
14429 }
14430
14431 void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t binding,
14432 const MSLConstexprSampler &sampler)
14433 {
14434 constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
14435 }
14436
14437 void CompilerMSL::cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
14438 {
14439 auto *var = maybe_get_backing_variable(source_id);
14440 if (var)
14441 source_id = var->self;
14442
14443 // Only interested in standalone builtin variables.
14444 if (!has_decoration(source_id, DecorationBuiltIn))
14445 return;
14446
14447 auto builtin = static_cast<BuiltIn>(get_decoration(source_id, DecorationBuiltIn));
14448 auto expected_type = expr_type.basetype;
14449 auto expected_width = expr_type.width;
14450 switch (builtin)
14451 {
14452 case BuiltInGlobalInvocationId:
14453 case BuiltInLocalInvocationId:
14454 case BuiltInWorkgroupId:
14455 case BuiltInLocalInvocationIndex:
14456 case BuiltInWorkgroupSize:
14457 case BuiltInNumWorkgroups:
14458 case BuiltInLayer:
14459 case BuiltInViewportIndex:
14460 case BuiltInFragStencilRefEXT:
14461 case BuiltInPrimitiveId:
14462 case BuiltInSubgroupSize:
14463 case BuiltInSubgroupLocalInvocationId:
14464 case BuiltInViewIndex:
14465 case BuiltInVertexIndex:
14466 case BuiltInInstanceIndex:
14467 case BuiltInBaseInstance:
14468 case BuiltInBaseVertex:
14469 expected_type = SPIRType::UInt;
14470 expected_width = 32;
14471 break;
14472
14473 case BuiltInTessLevelInner:
14474 case BuiltInTessLevelOuter:
14475 if (get_execution_model() == ExecutionModelTessellationControl)
14476 {
14477 expected_type = SPIRType::Half;
14478 expected_width = 16;
14479 }
14480 break;
14481
14482 default:
14483 break;
14484 }
14485
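	// For example, if the shader declared BuiltInViewIndex as a signed int, the
	// uint value Metal provides is converted back with a same-width cast, e.g. "int(...)".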
14486 if (expected_type != expr_type.basetype)
14487 {
14488 if (expected_width != expr_type.width)
14489 {
14490 // These are of different widths, so we cannot do a straight bitcast.
14491 expr = join(type_to_glsl(expr_type), "(", expr, ")");
14492 }
14493 else
14494 {
14495 expr = bitcast_expression(expr_type, expected_type, expr);
14496 }
14497 }
14498
14499 if (builtin == BuiltInTessCoord && get_entry_point().flags.get(ExecutionModeQuads) && expr_type.vecsize == 3)
14500 {
14501 // In SPIR-V, this is always a vec3, even for quads. In Metal, though, it's a float2 for quads.
14502 // The code is expecting a float3, so we need to widen this.
14503 expr = join("float3(", expr, ", 0)");
14504 }
14505 }
14506
14507 void CompilerMSL::cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
14508 {
14509 auto *var = maybe_get_backing_variable(target_id);
14510 if (var)
14511 target_id = var->self;
14512
14513 // Only interested in standalone builtin variables.
14514 if (!has_decoration(target_id, DecorationBuiltIn))
14515 return;
14516
14517 auto builtin = static_cast<BuiltIn>(get_decoration(target_id, DecorationBuiltIn));
14518 auto expected_type = expr_type.basetype;
14519 auto expected_width = expr_type.width;
14520 switch (builtin)
14521 {
14522 case BuiltInLayer:
14523 case BuiltInViewportIndex:
14524 case BuiltInFragStencilRefEXT:
14525 case BuiltInPrimitiveId:
14526 case BuiltInViewIndex:
14527 expected_type = SPIRType::UInt;
14528 expected_width = 32;
14529 break;
14530
14531 case BuiltInTessLevelInner:
14532 case BuiltInTessLevelOuter:
14533 expected_type = SPIRType::Half;
14534 expected_width = 16;
14535 break;
14536
14537 default:
14538 break;
14539 }
14540
14541 if (expected_type != expr_type.basetype)
14542 {
14543 if (expected_width != expr_type.width)
14544 {
14545 // These are of different widths, so we cannot do a straight bitcast.
14546 auto type = expr_type;
14547 type.basetype = expected_type;
14548 type.width = expected_width;
14549 expr = join(type_to_glsl(type), "(", expr, ")");
14550 }
14551 else
14552 {
14553 auto type = expr_type;
14554 type.basetype = expected_type;
14555 expr = bitcast_expression(type, expr_type.basetype, expr);
14556 }
14557 }
14558 }
14559
14560 string CompilerMSL::to_initializer_expression(const SPIRVariable &var)
14561 {
14562 	// We risk getting an array initializer here with MSL if we have an array.
14563 // FIXME: We cannot handle non-constant arrays being initialized.
14564 // We will need to inject spvArrayCopy here somehow ...
14565 auto &type = get<SPIRType>(var.basetype);
14566 string expr;
14567 if (ir.ids[var.initializer].get_type() == TypeConstant &&
14568 (!type.array.empty() || type.basetype == SPIRType::Struct))
14569 expr = constant_expression(get<SPIRConstant>(var.initializer));
14570 else
14571 expr = CompilerGLSL::to_initializer_expression(var);
14572 // If the initializer has more vector components than the variable, add a swizzle.
14573 // FIXME: This can't handle arrays or structs.
14574 auto &init_type = expression_type(var.initializer);
14575 if (type.array.empty() && type.basetype != SPIRType::Struct && init_type.vecsize > type.vecsize)
14576 expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0));
14577 return expr;
14578 }
14579
to_zero_initialized_expression(uint32_t)14580 string CompilerMSL::to_zero_initialized_expression(uint32_t)
14581 {
14582 return "{}";
14583 }
14584
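// Usage sketch (assumed host-side setup): with msl_options.argument_buffers enabled
// and add_discrete_descriptor_set(1) called before compilation, set 0 is emitted as
// a Metal argument buffer while set 1 falls back to discrete [[buffer]] / [[texture]]
// bindings.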
bool CompilerMSL::descriptor_set_is_argument_buffer(uint32_t desc_set) const
{
	if (!msl_options.argument_buffers)
		return false;
	if (desc_set >= kMaxArgumentBuffers)
		return false;

	return (argument_buffer_discrete_mask & (1u << desc_set)) == 0;
}

bool CompilerMSL::is_supported_argument_buffer_type(const SPIRType &type) const
{
	// Very specifically, image load-store in argument buffers is disallowed in MSL on iOS.
	// But we won't know when the argument buffer is encoded whether this image will have
	// a NonWritable decoration, so just use discrete arguments for all storage images
	// on iOS.
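	// Here, "storage image" means a sampled == 2 image, e.g. a GLSL image2D that would
	// be emitted as something like texture2d<float, access::read_write> in MSL
	// (illustrative types).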
	bool is_storage_image = type.basetype == SPIRType::Image && type.image.sampled == 2;
	bool is_supported_type = !msl_options.is_ios() || !is_storage_image;
	return !type_is_msl_framebuffer_fetch(type) && is_supported_type;
}

void CompilerMSL::analyze_argument_buffers()
{
	// Gather all used resources and sort them out into argument buffers.
	// Each argument buffer corresponds to a descriptor set in SPIR-V.
	// The [[id(N)]] values used correspond to the resource mapping we have for MSL.
	// Otherwise, the binding number is used, but this is generally not safe for some types,
	// like combined image samplers and arrays of resources. Metal needs different indices here,
	// while SPIR-V can have one descriptor set binding. To use argument buffers in practice,
	// you will need to use the remapping from the API.
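	//
	// For illustration (assumed resources, not verbatim output), a set containing a
	// combined image sampler and a UBO typically becomes something like:
	//
	//     struct spvDescriptorSetBuffer0
	//     {
	//         texture2d<float> uTex [[id(0)]];
	//         sampler uTexSmplr [[id(1)]];
	//         constant UBO* uUBO [[id(2)]];
	//     };
	//
	// with the entry point taking it as a [[buffer(N)]] argument.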
	for (auto &id : argument_buffer_ids)
		id = 0;

	// Output resources, sorted by resource index & type.
	struct Resource
	{
		SPIRVariable *var;
		string name;
		SPIRType::BaseType basetype;
		uint32_t index;
		uint32_t plane;
	};
	SmallVector<Resource> resources_in_set[kMaxArgumentBuffers];
	SmallVector<uint32_t> inline_block_vars;

	bool set_needs_swizzle_buffer[kMaxArgumentBuffers] = {};
	bool set_needs_buffer_sizes[kMaxArgumentBuffers] = {};
	bool needs_buffer_sizes = false;

	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &var) {
		if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
		     var.storage == StorageClassStorageBuffer) &&
		    !is_hidden_variable(var))
		{
			uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
			// Ignore sets that are not emitted as argument buffers, e.g. discrete or push descriptor sets.
			if (!descriptor_set_is_argument_buffer(desc_set))
				return;

			uint32_t var_id = var.self;
			auto &type = get_variable_data_type(var);

			if (desc_set >= kMaxArgumentBuffers)
				SPIRV_CROSS_THROW("Descriptor set index is out of range.");

			const MSLConstexprSampler *constexpr_sampler = nullptr;
			if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
			{
				constexpr_sampler = find_constexpr_sampler(var_id);
				if (constexpr_sampler)
				{
					// Mark this ID as a constexpr sampler for later in case it came from set/bindings.
					constexpr_samplers_by_id[var_id] = *constexpr_sampler;
				}
			}

			uint32_t binding = get_decoration(var_id, DecorationBinding);
			if (type.basetype == SPIRType::SampledImage)
			{
				add_resource_name(var_id);

				uint32_t plane_count = 1;
				if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
					plane_count = constexpr_sampler->planes;

				for (uint32_t i = 0; i < plane_count; i++)
				{
					uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image, i);
					resources_in_set[desc_set].push_back(
					    { &var, to_name(var_id), SPIRType::Image, image_resource_index, i });
				}

				if (type.image.dim != DimBuffer && !constexpr_sampler)
				{
					uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler);
					resources_in_set[desc_set].push_back(
					    { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 0 });
				}
			}
			else if (inline_uniform_blocks.count(SetBindingPair{ desc_set, binding }))
			{
				inline_block_vars.push_back(var_id);
			}
			else if (!constexpr_sampler && is_supported_argument_buffer_type(type))
			{
				// constexpr samplers are not declared as resources.
				// Inline uniform blocks are always emitted at the end.
				add_resource_name(var_id);
				resources_in_set[desc_set].push_back(
				    { &var, to_name(var_id), type.basetype, get_metal_resource_index(var, type.basetype), 0 });

				// Emulate texture2D atomic operations.
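				// (Metal lacks native texture atomics here, so a companion buffer, named
				// by appending "_atomic" to the image name, is bound alongside the texture
				// and the atomics are performed on that buffer instead.)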
				if (atomic_image_vars.count(var.self))
				{
					uint32_t buffer_resource_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
					resources_in_set[desc_set].push_back(
					    { &var, to_name(var_id) + "_atomic", SPIRType::Struct, buffer_resource_index, 0 });
				}
			}

			// Check if this descriptor set needs a swizzle buffer.
			if (needs_swizzle_buffer_def && is_sampled_image_type(type))
				set_needs_swizzle_buffer[desc_set] = true;
			else if (buffers_requiring_array_length.count(var_id) != 0)
			{
				set_needs_buffer_sizes[desc_set] = true;
				needs_buffer_sizes = true;
			}
		}
	});

	if (needs_swizzle_buffer_def || needs_buffer_sizes)
	{
		uint32_t uint_ptr_type_id = 0;

		// We might have to add a swizzle buffer resource to the set.
		for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
		{
			if (!set_needs_swizzle_buffer[desc_set] && !set_needs_buffer_sizes[desc_set])
				continue;

			if (uint_ptr_type_id == 0)
			{
				uint_ptr_type_id = ir.increase_bound_by(1);

				// Create a buffer to hold extra data, including the swizzle constants.
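				// (This uint pointer type backs the "spvSwizzleConstants" and
				// "spvBufferSizeConstants" resources declared below; the ArrayStride of 4
				// lays the entries out as tightly packed uints.)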
				SPIRType uint_type_pointer = get_uint_type();
				uint_type_pointer.pointer = true;
				uint_type_pointer.pointer_depth = 1;
				uint_type_pointer.parent_type = get_uint_type_id();
				uint_type_pointer.storage = StorageClassUniform;
				set<SPIRType>(uint_ptr_type_id, uint_type_pointer);
				set_decoration(uint_ptr_type_id, DecorationArrayStride, 4);
			}

			if (set_needs_swizzle_buffer[desc_set])
			{
				uint32_t var_id = ir.increase_bound_by(1);
				auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
				set_name(var_id, "spvSwizzleConstants");
				set_decoration(var_id, DecorationDescriptorSet, desc_set);
				set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding);
				resources_in_set[desc_set].push_back(
				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
			}

			if (set_needs_buffer_sizes[desc_set])
			{
				uint32_t var_id = ir.increase_bound_by(1);
				auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
				set_name(var_id, "spvBufferSizeConstants");
				set_decoration(var_id, DecorationDescriptorSet, desc_set);
				set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding);
				resources_in_set[desc_set].push_back(
				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
			}
		}
	}

	// Now add inline uniform blocks.
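	// (Deferred from the main scan above; the block's data is embedded by value in the
	// argument buffer struct rather than referenced through a pointer.)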
	for (uint32_t var_id : inline_block_vars)
	{
		auto &var = get<SPIRVariable>(var_id);
		uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
		add_resource_name(var_id);
		resources_in_set[desc_set].push_back(
		    { &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 0 });
	}

	for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
	{
		auto &resources = resources_in_set[desc_set];
		if (resources.empty())
			continue;

		assert(descriptor_set_is_argument_buffer(desc_set));

		uint32_t next_id = ir.increase_bound_by(3);
		uint32_t type_id = next_id + 1;
		uint32_t ptr_type_id = next_id + 2;
		argument_buffer_ids[desc_set] = next_id;

		auto &buffer_type = set<SPIRType>(type_id);

		buffer_type.basetype = SPIRType::Struct;

		if ((argument_buffer_device_storage_mask & (1u << desc_set)) != 0)
		{
			buffer_type.storage = StorageClassStorageBuffer;
			// Make sure the argument buffer gets marked as const device.
			set_decoration(next_id, DecorationNonWritable);
			// Need to mark the type as a Block to enable this.
			set_decoration(type_id, DecorationBlock);
		}
		else
			buffer_type.storage = StorageClassUniform;

		set_name(type_id, join("spvDescriptorSetBuffer", desc_set));

		auto &ptr_type = set<SPIRType>(ptr_type_id);
		ptr_type = buffer_type;
		ptr_type.pointer = true;
		ptr_type.pointer_depth = 1;
		ptr_type.parent_type = type_id;

		uint32_t buffer_variable_id = next_id;
		set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
		set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));

		// Members must be emitted in [[id(N)]] order: sort by resource index, with basetype as a tiebreaker.
		sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
			return tie(lhs.index, lhs.basetype) < tie(rhs.index, rhs.basetype);
		});

		uint32_t member_index = 0;
		for (auto &resource : resources)
		{
			auto &var = *resource.var;
			auto &type = get_variable_data_type(var);
			string mbr_name = ensure_valid_name(resource.name, "m");
			if (resource.plane > 0)
				mbr_name += join(plane_name_suffix, resource.plane);
			set_member_name(buffer_type.self, member_index, mbr_name);

			if (resource.basetype == SPIRType::Sampler && type.basetype != SPIRType::Sampler)
			{
				// Have to synthesize a standalone sampler type here: the variable is a
				// combined image sampler in SPIR-V, but the argument buffer needs a
				// separate sampler member.

				bool type_is_array = !type.array.empty();
				uint32_t sampler_type_id = ir.increase_bound_by(type_is_array ? 2 : 1);
				auto &new_sampler_type = set<SPIRType>(sampler_type_id);
				new_sampler_type.basetype = SPIRType::Sampler;
				new_sampler_type.storage = StorageClassUniformConstant;

				if (type_is_array)
				{
					uint32_t sampler_type_array_id = sampler_type_id + 1;
					auto &sampler_type_array = set<SPIRType>(sampler_type_array_id);
					sampler_type_array = new_sampler_type;
					sampler_type_array.array = type.array;
					sampler_type_array.array_size_literal = type.array_size_literal;
					sampler_type_array.parent_type = sampler_type_id;
					buffer_type.member_types.push_back(sampler_type_array_id);
				}
				else
					buffer_type.member_types.push_back(sampler_type_id);
			}
			else
			{
				uint32_t binding = get_decoration(var.self, DecorationBinding);
				SetBindingPair pair = { desc_set, binding };

				if (resource.basetype == SPIRType::Image || resource.basetype == SPIRType::Sampler ||
				    resource.basetype == SPIRType::SampledImage)
				{
					// Drop pointer information when we emit the resources into a struct.
					buffer_type.member_types.push_back(get_variable_data_type_id(var));
					if (resource.plane == 0)
						set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
				}
				else if (buffers_requiring_dynamic_offset.count(pair))
				{
					// Don't set the qualified name here; we'll define a variable holding the corrected buffer address later.
					buffer_type.member_types.push_back(var.basetype);
					buffers_requiring_dynamic_offset[pair].second = var.self;
				}
				else if (inline_uniform_blocks.count(pair))
				{
					// Put the buffer block itself into the argument buffer.
					buffer_type.member_types.push_back(get_variable_data_type_id(var));
					set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
				}
				else if (atomic_image_vars.count(var.self))
				{
					// Emulate texture2D atomic operations.
					// Don't set the qualified name: it's already set for this variable,
					// and the code that references the buffer manually appends "_atomic"
					// to the name.
					uint32_t offset = ir.increase_bound_by(2);
					uint32_t atomic_type_id = offset;
					uint32_t type_ptr_id = offset + 1;

					SPIRType atomic_type;
					atomic_type.basetype = SPIRType::AtomicCounter;
					atomic_type.width = 32;
					atomic_type.vecsize = 1;
					set<SPIRType>(atomic_type_id, atomic_type);

					atomic_type.pointer = true;
					atomic_type.parent_type = atomic_type_id;
					atomic_type.storage = StorageClassStorageBuffer;
					auto &atomic_ptr_type = set<SPIRType>(type_ptr_id, atomic_type);
					atomic_ptr_type.self = atomic_type_id;

					buffer_type.member_types.push_back(type_ptr_id);
				}
				else
				{
					// Resources will be declared as pointers, not references, so automatically dereference as appropriate.
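					// e.g. a non-array buffer member "m_ubo" is referenced as
					// "(*spvDescriptorSet0.m_ubo)" (illustrative names), while arrays
					// keep plain member syntax.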
					buffer_type.member_types.push_back(var.basetype);
					if (type.array.empty())
						set_qualified_name(var.self, join("(*", to_name(buffer_variable_id), ".", mbr_name, ")"));
					else
						set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
				}
			}

			set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationResourceIndexPrimary,
			                               resource.index);
			set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationInterfaceOrigID,
			                               var.self);
			member_index++;
		}
	}
}

void CompilerMSL::activate_argument_buffer_resources()
{
	// For ABI compatibility, force-enable all resources which are part of argument buffers.
	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, const SPIRVariable &) {
		if (!has_decoration(self, DecorationDescriptorSet))
			return;

		uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
		if (descriptor_set_is_argument_buffer(desc_set))
			active_interface_variables.insert(self);
	});
}

bool CompilerMSL::using_builtin_array() const
{
	return msl_options.force_native_arrays || is_using_builtin_array;
}

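// Usage sketch (assumed host code): for a combined image sampler "uTex", the MSL
// backend emits a companion sampler named by appending this suffix ("Smplr" by
// default), so calling set_combined_sampler_suffix("Sampler") would rename it from
// "uTexSmplr" to "uTexSampler".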
void CompilerMSL::set_combined_sampler_suffix(const char *suffix)
{
	sampler_name_suffix = suffix;
}

const char *CompilerMSL::get_combined_sampler_suffix() const
{
	return sampler_name_suffix.c_str();
}