1 /*
2  * Copyright 2016-2021 Robert Konrad
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  *
16  */
17 
18 /*
19  * At your option, you may choose to accept this material under either:
20  *  1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
21  *  2. The MIT License, found at <http://opensource.org/licenses/MIT>.
22  * SPDX-License-Identifier: Apache-2.0 OR MIT.
23  */
24 
25 #include "spirv_hlsl.hpp"
26 #include "GLSL.std.450.h"
27 #include <algorithm>
28 #include <assert.h>
29 
30 using namespace spv;
31 using namespace SPIRV_CROSS_NAMESPACE;
32 using namespace std;
33 
34 enum class ImageFormatNormalizedState
35 {
36 	None = 0,
37 	Unorm = 1,
38 	Snorm = 2
39 };
40 
image_format_to_normalized_state(ImageFormat fmt)41 static ImageFormatNormalizedState image_format_to_normalized_state(ImageFormat fmt)
42 {
43 	switch (fmt)
44 	{
45 	case ImageFormatR8:
46 	case ImageFormatR16:
47 	case ImageFormatRg8:
48 	case ImageFormatRg16:
49 	case ImageFormatRgba8:
50 	case ImageFormatRgba16:
51 	case ImageFormatRgb10A2:
52 		return ImageFormatNormalizedState::Unorm;
53 
54 	case ImageFormatR8Snorm:
55 	case ImageFormatR16Snorm:
56 	case ImageFormatRg8Snorm:
57 	case ImageFormatRg16Snorm:
58 	case ImageFormatRgba8Snorm:
59 	case ImageFormatRgba16Snorm:
60 		return ImageFormatNormalizedState::Snorm;
61 
62 	default:
63 		break;
64 	}
65 
66 	return ImageFormatNormalizedState::None;
67 }
68 
image_format_to_components(ImageFormat fmt)69 static unsigned image_format_to_components(ImageFormat fmt)
70 {
71 	switch (fmt)
72 	{
73 	case ImageFormatR8:
74 	case ImageFormatR16:
75 	case ImageFormatR8Snorm:
76 	case ImageFormatR16Snorm:
77 	case ImageFormatR16f:
78 	case ImageFormatR32f:
79 	case ImageFormatR8i:
80 	case ImageFormatR16i:
81 	case ImageFormatR32i:
82 	case ImageFormatR8ui:
83 	case ImageFormatR16ui:
84 	case ImageFormatR32ui:
85 		return 1;
86 
87 	case ImageFormatRg8:
88 	case ImageFormatRg16:
89 	case ImageFormatRg8Snorm:
90 	case ImageFormatRg16Snorm:
91 	case ImageFormatRg16f:
92 	case ImageFormatRg32f:
93 	case ImageFormatRg8i:
94 	case ImageFormatRg16i:
95 	case ImageFormatRg32i:
96 	case ImageFormatRg8ui:
97 	case ImageFormatRg16ui:
98 	case ImageFormatRg32ui:
99 		return 2;
100 
101 	case ImageFormatR11fG11fB10f:
102 		return 3;
103 
104 	case ImageFormatRgba8:
105 	case ImageFormatRgba16:
106 	case ImageFormatRgb10A2:
107 	case ImageFormatRgba8Snorm:
108 	case ImageFormatRgba16Snorm:
109 	case ImageFormatRgba16f:
110 	case ImageFormatRgba32f:
111 	case ImageFormatRgba8i:
112 	case ImageFormatRgba16i:
113 	case ImageFormatRgba32i:
114 	case ImageFormatRgba8ui:
115 	case ImageFormatRgba16ui:
116 	case ImageFormatRgba32ui:
117 	case ImageFormatRgb10a2ui:
118 		return 4;
119 
120 	case ImageFormatUnknown:
121 		return 4; // Assume 4.
122 
123 	default:
124 		SPIRV_CROSS_THROW("Unrecognized typed image format.");
125 	}
126 }
127 
image_format_to_type(ImageFormat fmt,SPIRType::BaseType basetype)128 static string image_format_to_type(ImageFormat fmt, SPIRType::BaseType basetype)
129 {
130 	switch (fmt)
131 	{
132 	case ImageFormatR8:
133 	case ImageFormatR16:
134 		if (basetype != SPIRType::Float)
135 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
136 		return "unorm float";
137 	case ImageFormatRg8:
138 	case ImageFormatRg16:
139 		if (basetype != SPIRType::Float)
140 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
141 		return "unorm float2";
142 	case ImageFormatRgba8:
143 	case ImageFormatRgba16:
144 		if (basetype != SPIRType::Float)
145 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
146 		return "unorm float4";
147 	case ImageFormatRgb10A2:
148 		if (basetype != SPIRType::Float)
149 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
150 		return "unorm float4";
151 
152 	case ImageFormatR8Snorm:
153 	case ImageFormatR16Snorm:
154 		if (basetype != SPIRType::Float)
155 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
156 		return "snorm float";
157 	case ImageFormatRg8Snorm:
158 	case ImageFormatRg16Snorm:
159 		if (basetype != SPIRType::Float)
160 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
161 		return "snorm float2";
162 	case ImageFormatRgba8Snorm:
163 	case ImageFormatRgba16Snorm:
164 		if (basetype != SPIRType::Float)
165 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
166 		return "snorm float4";
167 
168 	case ImageFormatR16f:
169 	case ImageFormatR32f:
170 		if (basetype != SPIRType::Float)
171 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
172 		return "float";
173 	case ImageFormatRg16f:
174 	case ImageFormatRg32f:
175 		if (basetype != SPIRType::Float)
176 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
177 		return "float2";
178 	case ImageFormatRgba16f:
179 	case ImageFormatRgba32f:
180 		if (basetype != SPIRType::Float)
181 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
182 		return "float4";
183 
184 	case ImageFormatR11fG11fB10f:
185 		if (basetype != SPIRType::Float)
186 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
187 		return "float3";
188 
189 	case ImageFormatR8i:
190 	case ImageFormatR16i:
191 	case ImageFormatR32i:
192 		if (basetype != SPIRType::Int)
193 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
194 		return "int";
195 	case ImageFormatRg8i:
196 	case ImageFormatRg16i:
197 	case ImageFormatRg32i:
198 		if (basetype != SPIRType::Int)
199 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
200 		return "int2";
201 	case ImageFormatRgba8i:
202 	case ImageFormatRgba16i:
203 	case ImageFormatRgba32i:
204 		if (basetype != SPIRType::Int)
205 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
206 		return "int4";
207 
208 	case ImageFormatR8ui:
209 	case ImageFormatR16ui:
210 	case ImageFormatR32ui:
211 		if (basetype != SPIRType::UInt)
212 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
213 		return "uint";
214 	case ImageFormatRg8ui:
215 	case ImageFormatRg16ui:
216 	case ImageFormatRg32ui:
217 		if (basetype != SPIRType::UInt)
218 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
219 		return "uint2";
220 	case ImageFormatRgba8ui:
221 	case ImageFormatRgba16ui:
222 	case ImageFormatRgba32ui:
223 		if (basetype != SPIRType::UInt)
224 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
225 		return "uint4";
226 	case ImageFormatRgb10a2ui:
227 		if (basetype != SPIRType::UInt)
228 			SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
229 		return "uint4";
230 
231 	case ImageFormatUnknown:
232 		switch (basetype)
233 		{
234 		case SPIRType::Float:
235 			return "float4";
236 		case SPIRType::Int:
237 			return "int4";
238 		case SPIRType::UInt:
239 			return "uint4";
240 		default:
241 			SPIRV_CROSS_THROW("Unsupported base type for image.");
242 		}
243 
244 	default:
245 		SPIRV_CROSS_THROW("Unrecognized typed image format.");
246 	}
247 }
248 
image_type_hlsl_modern(const SPIRType & type,uint32_t id)249 string CompilerHLSL::image_type_hlsl_modern(const SPIRType &type, uint32_t id)
250 {
251 	auto &imagetype = get<SPIRType>(type.image.type);
252 	const char *dim = nullptr;
253 	bool typed_load = false;
254 	uint32_t components = 4;
255 
256 	bool force_image_srv = hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id, DecorationNonWritable);
257 
258 	switch (type.image.dim)
259 	{
260 	case Dim1D:
261 		typed_load = type.image.sampled == 2;
262 		dim = "1D";
263 		break;
264 	case Dim2D:
265 		typed_load = type.image.sampled == 2;
266 		dim = "2D";
267 		break;
268 	case Dim3D:
269 		typed_load = type.image.sampled == 2;
270 		dim = "3D";
271 		break;
272 	case DimCube:
273 		if (type.image.sampled == 2)
274 			SPIRV_CROSS_THROW("RWTextureCube does not exist in HLSL.");
275 		dim = "Cube";
276 		break;
277 	case DimRect:
278 		SPIRV_CROSS_THROW("Rectangle texture support is not yet implemented for HLSL."); // TODO
279 	case DimBuffer:
280 		if (type.image.sampled == 1)
281 			return join("Buffer<", type_to_glsl(imagetype), components, ">");
282 		else if (type.image.sampled == 2)
283 		{
284 			if (interlocked_resources.count(id))
285 				return join("RasterizerOrderedBuffer<", image_format_to_type(type.image.format, imagetype.basetype),
286 				            ">");
287 
288 			typed_load = !force_image_srv && type.image.sampled == 2;
289 
290 			const char *rw = force_image_srv ? "" : "RW";
291 			return join(rw, "Buffer<",
292 			            typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
293 			                         join(type_to_glsl(imagetype), components),
294 			            ">");
295 		}
296 		else
297 			SPIRV_CROSS_THROW("Sampler buffers must be either sampled or unsampled. Cannot deduce in runtime.");
298 	case DimSubpassData:
299 		dim = "2D";
300 		typed_load = false;
301 		break;
302 	default:
303 		SPIRV_CROSS_THROW("Invalid dimension.");
304 	}
305 	const char *arrayed = type.image.arrayed ? "Array" : "";
306 	const char *ms = type.image.ms ? "MS" : "";
307 	const char *rw = typed_load && !force_image_srv ? "RW" : "";
308 
309 	if (force_image_srv)
310 		typed_load = false;
311 
312 	if (typed_load && interlocked_resources.count(id))
313 		rw = "RasterizerOrdered";
314 
315 	return join(rw, "Texture", dim, ms, arrayed, "<",
316 	            typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
317 	                         join(type_to_glsl(imagetype), components),
318 	            ">");
319 }
320 
image_type_hlsl_legacy(const SPIRType & type,uint32_t)321 string CompilerHLSL::image_type_hlsl_legacy(const SPIRType &type, uint32_t /*id*/)
322 {
323 	auto &imagetype = get<SPIRType>(type.image.type);
324 	string res;
325 
326 	switch (imagetype.basetype)
327 	{
328 	case SPIRType::Int:
329 		res = "i";
330 		break;
331 	case SPIRType::UInt:
332 		res = "u";
333 		break;
334 	default:
335 		break;
336 	}
337 
338 	if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData)
339 		return res + "subpassInput" + (type.image.ms ? "MS" : "");
340 
341 	// If we're emulating subpassInput with samplers, force sampler2D
342 	// so we don't have to specify format.
343 	if (type.basetype == SPIRType::Image && type.image.dim != DimSubpassData)
344 	{
345 		// Sampler buffers are always declared as samplerBuffer even though they might be separate images in the SPIR-V.
346 		if (type.image.dim == DimBuffer && type.image.sampled == 1)
347 			res += "sampler";
348 		else
349 			res += type.image.sampled == 2 ? "image" : "texture";
350 	}
351 	else
352 		res += "sampler";
353 
354 	switch (type.image.dim)
355 	{
356 	case Dim1D:
357 		res += "1D";
358 		break;
359 	case Dim2D:
360 		res += "2D";
361 		break;
362 	case Dim3D:
363 		res += "3D";
364 		break;
365 	case DimCube:
366 		res += "CUBE";
367 		break;
368 
369 	case DimBuffer:
370 		res += "Buffer";
371 		break;
372 
373 	case DimSubpassData:
374 		res += "2D";
375 		break;
376 	default:
377 		SPIRV_CROSS_THROW("Only 1D, 2D, 3D, Buffer, InputTarget and Cube textures supported.");
378 	}
379 
380 	if (type.image.ms)
381 		res += "MS";
382 	if (type.image.arrayed)
383 		res += "Array";
384 
385 	return res;
386 }
387 
image_type_hlsl(const SPIRType & type,uint32_t id)388 string CompilerHLSL::image_type_hlsl(const SPIRType &type, uint32_t id)
389 {
390 	if (hlsl_options.shader_model <= 30)
391 		return image_type_hlsl_legacy(type, id);
392 	else
393 		return image_type_hlsl_modern(type, id);
394 }
395 
396 // The optional id parameter indicates the object whose type we are trying
397 // to find the description for. It is optional. Most type descriptions do not
398 // depend on a specific object's use of that type.
type_to_glsl(const SPIRType & type,uint32_t id)399 string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
400 {
401 	// Ignore the pointer type since GLSL doesn't have pointers.
402 
403 	switch (type.basetype)
404 	{
405 	case SPIRType::Struct:
406 		// Need OpName lookup here to get a "sensible" name for a struct.
407 		if (backend.explicit_struct_type)
408 			return join("struct ", to_name(type.self));
409 		else
410 			return to_name(type.self);
411 
412 	case SPIRType::Image:
413 	case SPIRType::SampledImage:
414 		return image_type_hlsl(type, id);
415 
416 	case SPIRType::Sampler:
417 		return comparison_ids.count(id) ? "SamplerComparisonState" : "SamplerState";
418 
419 	case SPIRType::Void:
420 		return "void";
421 
422 	default:
423 		break;
424 	}
425 
426 	if (type.vecsize == 1 && type.columns == 1) // Scalar builtin
427 	{
428 		switch (type.basetype)
429 		{
430 		case SPIRType::Boolean:
431 			return "bool";
432 		case SPIRType::Int:
433 			return backend.basic_int_type;
434 		case SPIRType::UInt:
435 			return backend.basic_uint_type;
436 		case SPIRType::AtomicCounter:
437 			return "atomic_uint";
438 		case SPIRType::Half:
439 			if (hlsl_options.enable_16bit_types)
440 				return "half";
441 			else
442 				return "min16float";
443 		case SPIRType::Short:
444 			if (hlsl_options.enable_16bit_types)
445 				return "int16_t";
446 			else
447 				return "min16int";
448 		case SPIRType::UShort:
449 			if (hlsl_options.enable_16bit_types)
450 				return "uint16_t";
451 			else
452 				return "min16uint";
453 		case SPIRType::Float:
454 			return "float";
455 		case SPIRType::Double:
456 			return "double";
457 		case SPIRType::Int64:
458 			if (hlsl_options.shader_model < 60)
459 				SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
460 			return "int64_t";
461 		case SPIRType::UInt64:
462 			if (hlsl_options.shader_model < 60)
463 				SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
464 			return "uint64_t";
465 		default:
466 			return "???";
467 		}
468 	}
469 	else if (type.vecsize > 1 && type.columns == 1) // Vector builtin
470 	{
471 		switch (type.basetype)
472 		{
473 		case SPIRType::Boolean:
474 			return join("bool", type.vecsize);
475 		case SPIRType::Int:
476 			return join("int", type.vecsize);
477 		case SPIRType::UInt:
478 			return join("uint", type.vecsize);
479 		case SPIRType::Half:
480 			return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.vecsize);
481 		case SPIRType::Short:
482 			return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.vecsize);
483 		case SPIRType::UShort:
484 			return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.vecsize);
485 		case SPIRType::Float:
486 			return join("float", type.vecsize);
487 		case SPIRType::Double:
488 			return join("double", type.vecsize);
489 		case SPIRType::Int64:
490 			return join("i64vec", type.vecsize);
491 		case SPIRType::UInt64:
492 			return join("u64vec", type.vecsize);
493 		default:
494 			return "???";
495 		}
496 	}
497 	else
498 	{
499 		switch (type.basetype)
500 		{
501 		case SPIRType::Boolean:
502 			return join("bool", type.columns, "x", type.vecsize);
503 		case SPIRType::Int:
504 			return join("int", type.columns, "x", type.vecsize);
505 		case SPIRType::UInt:
506 			return join("uint", type.columns, "x", type.vecsize);
507 		case SPIRType::Half:
508 			return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.columns, "x", type.vecsize);
509 		case SPIRType::Short:
510 			return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.columns, "x", type.vecsize);
511 		case SPIRType::UShort:
512 			return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.columns, "x", type.vecsize);
513 		case SPIRType::Float:
514 			return join("float", type.columns, "x", type.vecsize);
515 		case SPIRType::Double:
516 			return join("double", type.columns, "x", type.vecsize);
517 		// Matrix types not supported for int64/uint64.
518 		default:
519 			return "???";
520 		}
521 	}
522 }
523 
emit_header()524 void CompilerHLSL::emit_header()
525 {
526 	for (auto &header : header_lines)
527 		statement(header);
528 
529 	if (header_lines.size() > 0)
530 	{
531 		statement("");
532 	}
533 }
534 
emit_interface_block_globally(const SPIRVariable & var)535 void CompilerHLSL::emit_interface_block_globally(const SPIRVariable &var)
536 {
537 	add_resource_name(var.self);
538 
539 	// The global copies of I/O variables should not contain interpolation qualifiers.
540 	// These are emitted inside the interface structs.
541 	auto &flags = ir.meta[var.self].decoration.decoration_flags;
542 	auto old_flags = flags;
543 	flags.reset();
544 	statement("static ", variable_decl(var), ";");
545 	flags = old_flags;
546 }
547 
to_storage_qualifiers_glsl(const SPIRVariable & var)548 const char *CompilerHLSL::to_storage_qualifiers_glsl(const SPIRVariable &var)
549 {
550 	// Input and output variables are handled specially in HLSL backend.
551 	// The variables are declared as global, private variables, and do not need any qualifiers.
552 	if (var.storage == StorageClassUniformConstant || var.storage == StorageClassUniform ||
553 	    var.storage == StorageClassPushConstant)
554 	{
555 		return "uniform ";
556 	}
557 
558 	return "";
559 }
560 
emit_builtin_outputs_in_struct()561 void CompilerHLSL::emit_builtin_outputs_in_struct()
562 {
563 	auto &execution = get_entry_point();
564 
565 	bool legacy = hlsl_options.shader_model <= 30;
566 	active_output_builtins.for_each_bit([&](uint32_t i) {
567 		const char *type = nullptr;
568 		const char *semantic = nullptr;
569 		auto builtin = static_cast<BuiltIn>(i);
570 		switch (builtin)
571 		{
572 		case BuiltInPosition:
573 			type = "float4";
574 			semantic = legacy ? "POSITION" : "SV_Position";
575 			break;
576 
577 		case BuiltInSampleMask:
578 			if (hlsl_options.shader_model < 41 || execution.model != ExecutionModelFragment)
579 				SPIRV_CROSS_THROW("Sample Mask output is only supported in PS 4.1 or higher.");
580 			type = "uint";
581 			semantic = "SV_Coverage";
582 			break;
583 
584 		case BuiltInFragDepth:
585 			type = "float";
586 			if (legacy)
587 			{
588 				semantic = "DEPTH";
589 			}
590 			else
591 			{
592 				if (hlsl_options.shader_model >= 50 && execution.flags.get(ExecutionModeDepthGreater))
593 					semantic = "SV_DepthGreaterEqual";
594 				else if (hlsl_options.shader_model >= 50 && execution.flags.get(ExecutionModeDepthLess))
595 					semantic = "SV_DepthLessEqual";
596 				else
597 					semantic = "SV_Depth";
598 			}
599 			break;
600 
601 		case BuiltInClipDistance:
602 			// HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
603 			for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
604 			{
605 				uint32_t to_declare = clip_distance_count - clip;
606 				if (to_declare > 4)
607 					to_declare = 4;
608 
609 				uint32_t semantic_index = clip / 4;
610 
611 				static const char *types[] = { "float", "float2", "float3", "float4" };
612 				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
613 				          " : SV_ClipDistance", semantic_index, ";");
614 			}
615 			break;
616 
617 		case BuiltInCullDistance:
618 			// HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
619 			for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
620 			{
621 				uint32_t to_declare = cull_distance_count - cull;
622 				if (to_declare > 4)
623 					to_declare = 4;
624 
625 				uint32_t semantic_index = cull / 4;
626 
627 				static const char *types[] = { "float", "float2", "float3", "float4" };
628 				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
629 				          " : SV_CullDistance", semantic_index, ";");
630 			}
631 			break;
632 
633 		case BuiltInPointSize:
634 			// If point_size_compat is enabled, just ignore PointSize.
635 			// PointSize does not exist in HLSL, but some code bases might want to be able to use these shaders,
636 			// even if it means working around the missing feature.
637 			if (hlsl_options.point_size_compat)
638 				break;
639 			else
640 				SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
641 
642 		default:
643 			SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
644 			break;
645 		}
646 
647 		if (type && semantic)
648 			statement(type, " ", builtin_to_glsl(builtin, StorageClassOutput), " : ", semantic, ";");
649 	});
650 }
651 
emit_builtin_inputs_in_struct()652 void CompilerHLSL::emit_builtin_inputs_in_struct()
653 {
654 	bool legacy = hlsl_options.shader_model <= 30;
655 	active_input_builtins.for_each_bit([&](uint32_t i) {
656 		const char *type = nullptr;
657 		const char *semantic = nullptr;
658 		auto builtin = static_cast<BuiltIn>(i);
659 		switch (builtin)
660 		{
661 		case BuiltInFragCoord:
662 			type = "float4";
663 			semantic = legacy ? "VPOS" : "SV_Position";
664 			break;
665 
666 		case BuiltInVertexId:
667 		case BuiltInVertexIndex:
668 			if (legacy)
669 				SPIRV_CROSS_THROW("Vertex index not supported in SM 3.0 or lower.");
670 			type = "uint";
671 			semantic = "SV_VertexID";
672 			break;
673 
674 		case BuiltInInstanceId:
675 		case BuiltInInstanceIndex:
676 			if (legacy)
677 				SPIRV_CROSS_THROW("Instance index not supported in SM 3.0 or lower.");
678 			type = "uint";
679 			semantic = "SV_InstanceID";
680 			break;
681 
682 		case BuiltInSampleId:
683 			if (legacy)
684 				SPIRV_CROSS_THROW("Sample ID not supported in SM 3.0 or lower.");
685 			type = "uint";
686 			semantic = "SV_SampleIndex";
687 			break;
688 
689 		case BuiltInSampleMask:
690 			if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
691 				SPIRV_CROSS_THROW("Sample Mask input is only supported in PS 5.0 or higher.");
692 			type = "uint";
693 			semantic = "SV_Coverage";
694 			break;
695 
696 		case BuiltInGlobalInvocationId:
697 			type = "uint3";
698 			semantic = "SV_DispatchThreadID";
699 			break;
700 
701 		case BuiltInLocalInvocationId:
702 			type = "uint3";
703 			semantic = "SV_GroupThreadID";
704 			break;
705 
706 		case BuiltInLocalInvocationIndex:
707 			type = "uint";
708 			semantic = "SV_GroupIndex";
709 			break;
710 
711 		case BuiltInWorkgroupId:
712 			type = "uint3";
713 			semantic = "SV_GroupID";
714 			break;
715 
716 		case BuiltInFrontFacing:
717 			type = "bool";
718 			semantic = "SV_IsFrontFace";
719 			break;
720 
721 		case BuiltInNumWorkgroups:
722 		case BuiltInSubgroupSize:
723 		case BuiltInSubgroupLocalInvocationId:
724 		case BuiltInSubgroupEqMask:
725 		case BuiltInSubgroupLtMask:
726 		case BuiltInSubgroupLeMask:
727 		case BuiltInSubgroupGtMask:
728 		case BuiltInSubgroupGeMask:
729 			// Handled specially.
730 			break;
731 
732 		case BuiltInClipDistance:
733 			// HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
734 			for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
735 			{
736 				uint32_t to_declare = clip_distance_count - clip;
737 				if (to_declare > 4)
738 					to_declare = 4;
739 
740 				uint32_t semantic_index = clip / 4;
741 
742 				static const char *types[] = { "float", "float2", "float3", "float4" };
743 				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassInput), semantic_index,
744 				          " : SV_ClipDistance", semantic_index, ";");
745 			}
746 			break;
747 
748 		case BuiltInCullDistance:
749 			// HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
750 			for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
751 			{
752 				uint32_t to_declare = cull_distance_count - cull;
753 				if (to_declare > 4)
754 					to_declare = 4;
755 
756 				uint32_t semantic_index = cull / 4;
757 
758 				static const char *types[] = { "float", "float2", "float3", "float4" };
759 				statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassInput), semantic_index,
760 				          " : SV_CullDistance", semantic_index, ";");
761 			}
762 			break;
763 
764 		case BuiltInPointCoord:
765 			// PointCoord is not supported, but provide a way to just ignore that, similar to PointSize.
766 			if (hlsl_options.point_coord_compat)
767 				break;
768 			else
769 				SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
770 
771 		default:
772 			SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
773 			break;
774 		}
775 
776 		if (type && semantic)
777 			statement(type, " ", builtin_to_glsl(builtin, StorageClassInput), " : ", semantic, ";");
778 	});
779 }
780 
type_to_consumed_locations(const SPIRType & type) const781 uint32_t CompilerHLSL::type_to_consumed_locations(const SPIRType &type) const
782 {
783 	// TODO: Need to verify correctness.
784 	uint32_t elements = 0;
785 
786 	if (type.basetype == SPIRType::Struct)
787 	{
788 		for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
789 			elements += type_to_consumed_locations(get<SPIRType>(type.member_types[i]));
790 	}
791 	else
792 	{
793 		uint32_t array_multiplier = 1;
794 		for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
795 		{
796 			if (type.array_size_literal[i])
797 				array_multiplier *= type.array[i];
798 			else
799 				array_multiplier *= evaluate_constant_u32(type.array[i]);
800 		}
801 		elements += array_multiplier * type.columns;
802 	}
803 	return elements;
804 }
805 
to_interpolation_qualifiers(const Bitset & flags)806 string CompilerHLSL::to_interpolation_qualifiers(const Bitset &flags)
807 {
808 	string res;
809 	//if (flags & (1ull << DecorationSmooth))
810 	//    res += "linear ";
811 	if (flags.get(DecorationFlat))
812 		res += "nointerpolation ";
813 	if (flags.get(DecorationNoPerspective))
814 		res += "noperspective ";
815 	if (flags.get(DecorationCentroid))
816 		res += "centroid ";
817 	if (flags.get(DecorationPatch))
818 		res += "patch "; // Seems to be different in actual HLSL.
819 	if (flags.get(DecorationSample))
820 		res += "sample ";
821 	if (flags.get(DecorationInvariant))
822 		res += "invariant "; // Not supported?
823 
824 	return res;
825 }
826 
to_semantic(uint32_t location,ExecutionModel em,StorageClass sc)827 std::string CompilerHLSL::to_semantic(uint32_t location, ExecutionModel em, StorageClass sc)
828 {
829 	if (em == ExecutionModelVertex && sc == StorageClassInput)
830 	{
831 		// We have a vertex attribute - we should look at remapping it if the user provided
832 		// vertex attribute hints.
833 		for (auto &attribute : remap_vertex_attributes)
834 			if (attribute.location == location)
835 				return attribute.semantic;
836 	}
837 
838 	// Not a vertex attribute, or no remap_vertex_attributes entry.
839 	return join("TEXCOORD", location);
840 }
841 
to_initializer_expression(const SPIRVariable & var)842 std::string CompilerHLSL::to_initializer_expression(const SPIRVariable &var)
843 {
844 	// We cannot emit static const initializer for block constants for practical reasons,
845 	// so just inline the initializer.
846 	// FIXME: There is a theoretical problem here if someone tries to composite extract
847 	// into this initializer since we don't declare it properly, but that is somewhat non-sensical.
848 	auto &type = get<SPIRType>(var.basetype);
849 	bool is_block = has_decoration(type.self, DecorationBlock);
850 	auto *c = maybe_get<SPIRConstant>(var.initializer);
851 	if (is_block && c)
852 		return constant_expression(*c);
853 	else
854 		return CompilerGLSL::to_initializer_expression(var);
855 }
856 
emit_io_block(const SPIRVariable & var)857 void CompilerHLSL::emit_io_block(const SPIRVariable &var)
858 {
859 	auto &execution = get_entry_point();
860 
861 	auto &type = get<SPIRType>(var.basetype);
862 	add_resource_name(type.self);
863 
864 	statement("struct ", to_name(type.self));
865 	begin_scope();
866 	type.member_name_cache.clear();
867 
868 	uint32_t base_location = get_decoration(var.self, DecorationLocation);
869 
870 	for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
871 	{
872 		string semantic;
873 		if (has_member_decoration(type.self, i, DecorationLocation))
874 		{
875 			uint32_t location = get_member_decoration(type.self, i, DecorationLocation);
876 			semantic = join(" : ", to_semantic(location, execution.model, var.storage));
877 		}
878 		else
879 		{
880 			// If the block itself has a location, but not its members, use the implicit location.
881 			// There could be a conflict if the block members partially specialize the locations.
882 			// It is unclear how SPIR-V deals with this. Assume this does not happen for now.
883 			uint32_t location = base_location + i;
884 			semantic = join(" : ", to_semantic(location, execution.model, var.storage));
885 		}
886 
887 		add_member_name(type, i);
888 
889 		auto &membertype = get<SPIRType>(type.member_types[i]);
890 		statement(to_interpolation_qualifiers(get_member_decoration_bitset(type.self, i)),
891 		          variable_decl(membertype, to_member_name(type, i)), semantic, ";");
892 	}
893 
894 	end_scope_decl();
895 	statement("");
896 
897 	statement("static ", variable_decl(var), ";");
898 	statement("");
899 }
900 
emit_interface_block_in_struct(const SPIRVariable & var,unordered_set<uint32_t> & active_locations)901 void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unordered_set<uint32_t> &active_locations)
902 {
903 	auto &execution = get_entry_point();
904 	auto type = get<SPIRType>(var.basetype);
905 
906 	string binding;
907 	bool use_location_number = true;
908 	bool legacy = hlsl_options.shader_model <= 30;
909 	if (execution.model == ExecutionModelFragment && var.storage == StorageClassOutput)
910 	{
911 		// Dual-source blending is achieved in HLSL by emitting to SV_Target0 and 1.
912 		uint32_t index = get_decoration(var.self, DecorationIndex);
913 		uint32_t location = get_decoration(var.self, DecorationLocation);
914 
915 		if (index != 0 && location != 0)
916 			SPIRV_CROSS_THROW("Dual-source blending is only supported on MRT #0 in HLSL.");
917 
918 		binding = join(legacy ? "COLOR" : "SV_Target", location + index);
919 		use_location_number = false;
920 		if (legacy) // COLOR must be a four-component vector on legacy shader model targets (HLSL ERR_COLOR_4COMP)
921 			type.vecsize = 4;
922 	}
923 
924 	const auto get_vacant_location = [&]() -> uint32_t {
925 		for (uint32_t i = 0; i < 64; i++)
926 			if (!active_locations.count(i))
927 				return i;
928 		SPIRV_CROSS_THROW("All locations from 0 to 63 are exhausted.");
929 	};
930 
931 	bool need_matrix_unroll = var.storage == StorageClassInput && execution.model == ExecutionModelVertex;
932 
933 	auto &m = ir.meta[var.self].decoration;
934 	auto name = to_name(var.self);
935 	if (use_location_number)
936 	{
937 		uint32_t location_number;
938 
939 		// If an explicit location exists, use it with TEXCOORD[N] semantic.
940 		// Otherwise, pick a vacant location.
941 		if (m.decoration_flags.get(DecorationLocation))
942 			location_number = m.location;
943 		else
944 			location_number = get_vacant_location();
945 
946 		// Allow semantic remap if specified.
947 		auto semantic = to_semantic(location_number, execution.model, var.storage);
948 
949 		if (need_matrix_unroll && type.columns > 1)
950 		{
951 			if (!type.array.empty())
952 				SPIRV_CROSS_THROW("Arrays of matrices used as input/output. This is not supported.");
953 
954 			// Unroll matrices.
955 			for (uint32_t i = 0; i < type.columns; i++)
956 			{
957 				SPIRType newtype = type;
958 				newtype.columns = 1;
959 
960 				string effective_semantic;
961 				if (hlsl_options.flatten_matrix_vertex_input_semantics)
962 					effective_semantic = to_semantic(location_number, execution.model, var.storage);
963 				else
964 					effective_semantic = join(semantic, "_", i);
965 
966 				statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)),
967 				          variable_decl(newtype, join(name, "_", i)), " : ", effective_semantic, ";");
968 				active_locations.insert(location_number++);
969 			}
970 		}
971 		else
972 		{
973 			statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)), variable_decl(type, name), " : ",
974 			          semantic, ";");
975 
976 			// Structs and arrays should consume more locations.
977 			uint32_t consumed_locations = type_to_consumed_locations(type);
978 			for (uint32_t i = 0; i < consumed_locations; i++)
979 				active_locations.insert(location_number + i);
980 		}
981 	}
982 	else
983 		statement(variable_decl(type, name), " : ", binding, ";");
984 }
985 
builtin_to_glsl(spv::BuiltIn builtin,spv::StorageClass storage)986 std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage)
987 {
988 	switch (builtin)
989 	{
990 	case BuiltInVertexId:
991 		return "gl_VertexID";
992 	case BuiltInInstanceId:
993 		return "gl_InstanceID";
994 	case BuiltInNumWorkgroups:
995 	{
996 		if (!num_workgroups_builtin)
997 			SPIRV_CROSS_THROW("NumWorkgroups builtin is used, but remap_num_workgroups_builtin() was not called. "
998 			                  "Cannot emit code for this builtin.");
999 
1000 		auto &var = get<SPIRVariable>(num_workgroups_builtin);
1001 		auto &type = get<SPIRType>(var.basetype);
1002 		auto ret = join(to_name(num_workgroups_builtin), "_", get_member_name(type.self, 0));
1003 		ParsedIR::sanitize_underscores(ret);
1004 		return ret;
1005 	}
1006 	case BuiltInPointCoord:
1007 		// Crude hack, but there is no real alternative. This path is only enabled if point_coord_compat is set.
1008 		return "float2(0.5f, 0.5f)";
1009 	case BuiltInSubgroupLocalInvocationId:
1010 		return "WaveGetLaneIndex()";
1011 	case BuiltInSubgroupSize:
1012 		return "WaveGetLaneCount()";
1013 
1014 	default:
1015 		return CompilerGLSL::builtin_to_glsl(builtin, storage);
1016 	}
1017 }
1018 
emit_builtin_variables()1019 void CompilerHLSL::emit_builtin_variables()
1020 {
1021 	Bitset builtins = active_input_builtins;
1022 	builtins.merge_or(active_output_builtins);
1023 
1024 	bool need_base_vertex_info = false;
1025 
1026 	std::unordered_map<uint32_t, ID> builtin_to_initializer;
1027 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1028 		if (!is_builtin_variable(var) || var.storage != StorageClassOutput || !var.initializer)
1029 			return;
1030 
1031 		auto *c = this->maybe_get<SPIRConstant>(var.initializer);
1032 		if (!c)
1033 			return;
1034 
1035 		auto &type = this->get<SPIRType>(var.basetype);
1036 		if (type.basetype == SPIRType::Struct)
1037 		{
1038 			uint32_t member_count = uint32_t(type.member_types.size());
1039 			for (uint32_t i = 0; i < member_count; i++)
1040 			{
1041 				if (has_member_decoration(type.self, i, DecorationBuiltIn))
1042 				{
1043 					builtin_to_initializer[get_member_decoration(type.self, i, DecorationBuiltIn)] =
1044 						c->subconstants[i];
1045 				}
1046 			}
1047 		}
1048 		else if (has_decoration(var.self, DecorationBuiltIn))
1049 			builtin_to_initializer[get_decoration(var.self, DecorationBuiltIn)] = var.initializer;
1050 	});
1051 
1052 	// Emit global variables for the interface variables which are statically used by the shader.
1053 	builtins.for_each_bit([&](uint32_t i) {
1054 		const char *type = nullptr;
1055 		auto builtin = static_cast<BuiltIn>(i);
1056 		uint32_t array_size = 0;
1057 
1058 		string init_expr;
1059 		auto init_itr = builtin_to_initializer.find(builtin);
1060 		if (init_itr != builtin_to_initializer.end())
1061 			init_expr = join(" = ", to_expression(init_itr->second));
1062 
1063 		switch (builtin)
1064 		{
1065 		case BuiltInFragCoord:
1066 		case BuiltInPosition:
1067 			type = "float4";
1068 			break;
1069 
1070 		case BuiltInFragDepth:
1071 			type = "float";
1072 			break;
1073 
1074 		case BuiltInVertexId:
1075 		case BuiltInVertexIndex:
1076 		case BuiltInInstanceIndex:
1077 			type = "int";
1078 			if (hlsl_options.support_nonzero_base_vertex_base_instance)
1079 				need_base_vertex_info = true;
1080 			break;
1081 
1082 		case BuiltInInstanceId:
1083 		case BuiltInSampleId:
1084 			type = "int";
1085 			break;
1086 
1087 		case BuiltInPointSize:
1088 			if (hlsl_options.point_size_compat)
1089 			{
1090 				// Just emit the global variable, it will be ignored.
1091 				type = "float";
1092 				break;
1093 			}
1094 			else
1095 				SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
1096 
1097 		case BuiltInGlobalInvocationId:
1098 		case BuiltInLocalInvocationId:
1099 		case BuiltInWorkgroupId:
1100 			type = "uint3";
1101 			break;
1102 
1103 		case BuiltInLocalInvocationIndex:
1104 			type = "uint";
1105 			break;
1106 
1107 		case BuiltInFrontFacing:
1108 			type = "bool";
1109 			break;
1110 
1111 		case BuiltInNumWorkgroups:
1112 		case BuiltInPointCoord:
1113 			// Handled specially.
1114 			break;
1115 
1116 		case BuiltInSubgroupLocalInvocationId:
1117 		case BuiltInSubgroupSize:
1118 			if (hlsl_options.shader_model < 60)
1119 				SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
1120 			break;
1121 
1122 		case BuiltInSubgroupEqMask:
1123 		case BuiltInSubgroupLtMask:
1124 		case BuiltInSubgroupLeMask:
1125 		case BuiltInSubgroupGtMask:
1126 		case BuiltInSubgroupGeMask:
1127 			if (hlsl_options.shader_model < 60)
1128 				SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
1129 			type = "uint4";
1130 			break;
1131 
1132 		case BuiltInClipDistance:
1133 			array_size = clip_distance_count;
1134 			type = "float";
1135 			break;
1136 
1137 		case BuiltInCullDistance:
1138 			array_size = cull_distance_count;
1139 			type = "float";
1140 			break;
1141 
1142 		case BuiltInSampleMask:
1143 			type = "int";
1144 			break;
1145 
1146 		default:
1147 			SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
1148 		}
1149 
1150 		StorageClass storage = active_input_builtins.get(i) ? StorageClassInput : StorageClassOutput;
1151 
1152 		if (type)
1153 		{
1154 			if (array_size)
1155 				statement("static ", type, " ", builtin_to_glsl(builtin, storage), "[", array_size, "]", init_expr, ";");
1156 			else
1157 				statement("static ", type, " ", builtin_to_glsl(builtin, storage), init_expr, ";");
1158 		}
1159 
1160 		// SampleMask can be both in and out with sample builtin, in this case we have already
1161 		// declared the input variable and we need to add the output one now.
1162 		if (builtin == BuiltInSampleMask && storage == StorageClassInput && this->active_output_builtins.get(i))
1163 		{
1164 			statement("static ", type, " ", this->builtin_to_glsl(builtin, StorageClassOutput), init_expr, ";");
1165 		}
1166 	});
1167 
1168 	if (need_base_vertex_info)
1169 	{
1170 		statement("cbuffer SPIRV_Cross_VertexInfo");
1171 		begin_scope();
1172 		statement("int SPIRV_Cross_BaseVertex;");
1173 		statement("int SPIRV_Cross_BaseInstance;");
1174 		end_scope_decl();
1175 		statement("");
1176 	}
1177 }
1178 
emit_composite_constants()1179 void CompilerHLSL::emit_composite_constants()
1180 {
1181 	// HLSL cannot declare structs or arrays inline, so we must move them out to
1182 	// global constants directly.
1183 	bool emitted = false;
1184 
1185 	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
1186 		if (c.specialization)
1187 			return;
1188 
1189 		auto &type = this->get<SPIRType>(c.constant_type);
1190 
1191 		// Cannot declare block type constants here.
1192 		// We do not have the struct type yet.
1193 		bool is_block = has_decoration(type.self, DecorationBlock);
1194 		if (!is_block && (type.basetype == SPIRType::Struct || !type.array.empty()))
1195 		{
1196 			auto name = to_name(c.self);
1197 			statement("static const ", variable_decl(type, name), " = ", constant_expression(c), ";");
1198 			emitted = true;
1199 		}
1200 	});
1201 
1202 	if (emitted)
1203 		statement("");
1204 }
1205 
emit_specialization_constants_and_structs()1206 void CompilerHLSL::emit_specialization_constants_and_structs()
1207 {
1208 	bool emitted = false;
1209 	SpecializationConstant wg_x, wg_y, wg_z;
1210 	ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
1211 
1212 	auto loop_lock = ir.create_loop_hard_lock();
1213 	for (auto &id_ : ir.ids_for_constant_or_type)
1214 	{
1215 		auto &id = ir.ids[id_];
1216 
1217 		if (id.get_type() == TypeConstant)
1218 		{
1219 			auto &c = id.get<SPIRConstant>();
1220 
1221 			if (c.self == workgroup_size_id)
1222 			{
1223 				statement("static const uint3 gl_WorkGroupSize = ",
1224 				          constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
1225 				emitted = true;
1226 			}
1227 			else if (c.specialization)
1228 			{
1229 				auto &type = get<SPIRType>(c.constant_type);
1230 				auto name = to_name(c.self);
1231 
1232 				// HLSL does not support specialization constants, so fallback to macros.
1233 				c.specialization_constant_macro_name =
1234 				    constant_value_macro_name(get_decoration(c.self, DecorationSpecId));
1235 
1236 				statement("#ifndef ", c.specialization_constant_macro_name);
1237 				statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
1238 				statement("#endif");
1239 				statement("static const ", variable_decl(type, name), " = ", c.specialization_constant_macro_name, ";");
1240 				emitted = true;
1241 			}
1242 		}
1243 		else if (id.get_type() == TypeConstantOp)
1244 		{
1245 			auto &c = id.get<SPIRConstantOp>();
1246 			auto &type = get<SPIRType>(c.basetype);
1247 			auto name = to_name(c.self);
1248 			statement("static const ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
1249 			emitted = true;
1250 		}
1251 		else if (id.get_type() == TypeType)
1252 		{
1253 			auto &type = id.get<SPIRType>();
1254 			if (type.basetype == SPIRType::Struct && type.array.empty() && !type.pointer &&
1255 			    (!ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock) &&
1256 			     !ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock)))
1257 			{
1258 				if (emitted)
1259 					statement("");
1260 				emitted = false;
1261 
1262 				emit_struct(type);
1263 			}
1264 		}
1265 	}
1266 
1267 	if (emitted)
1268 		statement("");
1269 }
1270 
replace_illegal_names()1271 void CompilerHLSL::replace_illegal_names()
1272 {
1273 	static const unordered_set<string> keywords = {
1274 		// Additional HLSL specific keywords.
1275 		"line", "linear", "matrix", "point", "row_major", "sampler",
1276 	};
1277 
1278 	CompilerGLSL::replace_illegal_names(keywords);
1279 	CompilerGLSL::replace_illegal_names();
1280 }
1281 
declare_undefined_values()1282 void CompilerHLSL::declare_undefined_values()
1283 {
1284 	bool emitted = false;
1285 	ir.for_each_typed_id<SPIRUndef>([&](uint32_t, const SPIRUndef &undef) {
1286 		auto &type = this->get<SPIRType>(undef.basetype);
1287 		// OpUndef can be void for some reason ...
1288 		if (type.basetype == SPIRType::Void)
1289 			return;
1290 
1291 		string initializer;
1292 		if (options.force_zero_initialized_variables && type_can_zero_initialize(type))
1293 			initializer = join(" = ", to_zero_initialized_expression(undef.basetype));
1294 
1295 		statement("static ", variable_decl(type, to_name(undef.self), undef.self), initializer, ";");
1296 		emitted = true;
1297 	});
1298 
1299 	if (emitted)
1300 		statement("");
1301 }
1302 
emit_resources()1303 void CompilerHLSL::emit_resources()
1304 {
1305 	auto &execution = get_entry_point();
1306 
1307 	replace_illegal_names();
1308 
1309 	emit_specialization_constants_and_structs();
1310 	emit_composite_constants();
1311 
1312 	bool emitted = false;
1313 
1314 	// Output UBOs and SSBOs
1315 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1316 		auto &type = this->get<SPIRType>(var.basetype);
1317 
1318 		bool is_block_storage = type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform;
1319 		bool has_block_flags = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock) ||
1320 		                       ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock);
1321 
1322 		if (var.storage != StorageClassFunction && type.pointer && is_block_storage && !is_hidden_variable(var) &&
1323 		    has_block_flags)
1324 		{
1325 			emit_buffer_block(var);
1326 			emitted = true;
1327 		}
1328 	});
1329 
1330 	// Output push constant blocks
1331 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1332 		auto &type = this->get<SPIRType>(var.basetype);
1333 		if (var.storage != StorageClassFunction && type.pointer && type.storage == StorageClassPushConstant &&
1334 		    !is_hidden_variable(var))
1335 		{
1336 			emit_push_constant_block(var);
1337 			emitted = true;
1338 		}
1339 	});
1340 
1341 	if (execution.model == ExecutionModelVertex && hlsl_options.shader_model <= 30)
1342 	{
1343 		statement("uniform float4 gl_HalfPixel;");
1344 		emitted = true;
1345 	}
1346 
1347 	bool skip_separate_image_sampler = !combined_image_samplers.empty() || hlsl_options.shader_model <= 30;
1348 
1349 	// Output Uniform Constants (values, samplers, images, etc).
1350 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1351 		auto &type = this->get<SPIRType>(var.basetype);
1352 
1353 		// If we're remapping separate samplers and images, only emit the combined samplers.
1354 		if (skip_separate_image_sampler)
1355 		{
1356 			// Sampler buffers are always used without a sampler, and they will also work in regular D3D.
1357 			bool sampler_buffer = type.basetype == SPIRType::Image && type.image.dim == DimBuffer;
1358 			bool separate_image = type.basetype == SPIRType::Image && type.image.sampled == 1;
1359 			bool separate_sampler = type.basetype == SPIRType::Sampler;
1360 			if (!sampler_buffer && (separate_image || separate_sampler))
1361 				return;
1362 		}
1363 
1364 		if (var.storage != StorageClassFunction && !is_builtin_variable(var) && !var.remapped_variable &&
1365 		    type.pointer && (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter))
1366 		{
1367 			emit_uniform(var);
1368 			emitted = true;
1369 		}
1370 	});
1371 
1372 	if (emitted)
1373 		statement("");
1374 	emitted = false;
1375 
1376 	// Emit builtin input and output variables here.
1377 	emit_builtin_variables();
1378 
1379 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1380 		auto &type = this->get<SPIRType>(var.basetype);
1381 		bool block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock);
1382 
1383 		// Do not emit I/O blocks here.
1384 		// I/O blocks can be arrayed, so we must deal with them separately to support geometry shaders
1385 		// and tessellation down the line.
1386 		if (!block && var.storage != StorageClassFunction && !var.remapped_variable && type.pointer &&
1387 		    (var.storage == StorageClassInput || var.storage == StorageClassOutput) && !is_builtin_variable(var) &&
1388 		    interface_variable_exists_in_entry_point(var.self))
1389 		{
1390 			// Only emit non-builtins which are not blocks here. Builtin variables are handled separately.
1391 			emit_interface_block_globally(var);
1392 			emitted = true;
1393 		}
1394 	});
1395 
1396 	if (emitted)
1397 		statement("");
1398 	emitted = false;
1399 
1400 	require_input = false;
1401 	require_output = false;
1402 	unordered_set<uint32_t> active_inputs;
1403 	unordered_set<uint32_t> active_outputs;
1404 	SmallVector<SPIRVariable *> input_variables;
1405 	SmallVector<SPIRVariable *> output_variables;
1406 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1407 		auto &type = this->get<SPIRType>(var.basetype);
1408 		bool block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock);
1409 
1410 		if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
1411 			return;
1412 
1413 		// Do not emit I/O blocks here.
1414 		// I/O blocks can be arrayed, so we must deal with them separately to support geometry shaders
1415 		// and tessellation down the line.
1416 		if (!block && !var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
1417 		    interface_variable_exists_in_entry_point(var.self))
1418 		{
1419 			if (var.storage == StorageClassInput)
1420 				input_variables.push_back(&var);
1421 			else
1422 				output_variables.push_back(&var);
1423 		}
1424 
1425 		// Reserve input and output locations for block variables as necessary.
1426 		if (block && !is_builtin_variable(var) && interface_variable_exists_in_entry_point(var.self))
1427 		{
1428 			auto &active = var.storage == StorageClassInput ? active_inputs : active_outputs;
1429 			for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
1430 			{
1431 				if (has_member_decoration(type.self, i, DecorationLocation))
1432 				{
1433 					uint32_t location = get_member_decoration(type.self, i, DecorationLocation);
1434 					active.insert(location);
1435 				}
1436 			}
1437 
1438 			// Emit the block struct and a global variable here.
1439 			emit_io_block(var);
1440 		}
1441 	});
1442 
1443 	const auto variable_compare = [&](const SPIRVariable *a, const SPIRVariable *b) -> bool {
1444 		// Sort input and output variables based on, from more robust to less robust:
1445 		// - Location
1446 		// - Variable has a location
1447 		// - Name comparison
1448 		// - Variable has a name
1449 		// - Fallback: ID
1450 		bool has_location_a = has_decoration(a->self, DecorationLocation);
1451 		bool has_location_b = has_decoration(b->self, DecorationLocation);
1452 
1453 		if (has_location_a && has_location_b)
1454 		{
1455 			return get_decoration(a->self, DecorationLocation) < get_decoration(b->self, DecorationLocation);
1456 		}
1457 		else if (has_location_a && !has_location_b)
1458 			return true;
1459 		else if (!has_location_a && has_location_b)
1460 			return false;
1461 
1462 		const auto &name1 = to_name(a->self);
1463 		const auto &name2 = to_name(b->self);
1464 
1465 		if (name1.empty() && name2.empty())
1466 			return a->self < b->self;
1467 		else if (name1.empty())
1468 			return true;
1469 		else if (name2.empty())
1470 			return false;
1471 
1472 		return name1.compare(name2) < 0;
1473 	};
1474 
1475 	auto input_builtins = active_input_builtins;
1476 	input_builtins.clear(BuiltInNumWorkgroups);
1477 	input_builtins.clear(BuiltInPointCoord);
1478 	input_builtins.clear(BuiltInSubgroupSize);
1479 	input_builtins.clear(BuiltInSubgroupLocalInvocationId);
1480 	input_builtins.clear(BuiltInSubgroupEqMask);
1481 	input_builtins.clear(BuiltInSubgroupLtMask);
1482 	input_builtins.clear(BuiltInSubgroupLeMask);
1483 	input_builtins.clear(BuiltInSubgroupGtMask);
1484 	input_builtins.clear(BuiltInSubgroupGeMask);
1485 
1486 	if (!input_variables.empty() || !input_builtins.empty())
1487 	{
1488 		require_input = true;
1489 		statement("struct SPIRV_Cross_Input");
1490 
1491 		begin_scope();
1492 		sort(input_variables.begin(), input_variables.end(), variable_compare);
1493 		for (auto var : input_variables)
1494 			emit_interface_block_in_struct(*var, active_inputs);
1495 		emit_builtin_inputs_in_struct();
1496 		end_scope_decl();
1497 		statement("");
1498 	}
1499 
1500 	if (!output_variables.empty() || !active_output_builtins.empty())
1501 	{
1502 		require_output = true;
1503 		statement("struct SPIRV_Cross_Output");
1504 
1505 		begin_scope();
1506 		// FIXME: Use locations properly if they exist.
1507 		sort(output_variables.begin(), output_variables.end(), variable_compare);
1508 		for (auto var : output_variables)
1509 			emit_interface_block_in_struct(*var, active_outputs);
1510 		emit_builtin_outputs_in_struct();
1511 		end_scope_decl();
1512 		statement("");
1513 	}
1514 
1515 	// Global variables.
1516 	for (auto global : global_variables)
1517 	{
1518 		auto &var = get<SPIRVariable>(global);
1519 		if (var.storage != StorageClassOutput)
1520 		{
1521 			if (!variable_is_lut(var))
1522 			{
1523 				add_resource_name(var.self);
1524 
1525 				const char *storage = nullptr;
1526 				switch (var.storage)
1527 				{
1528 				case StorageClassWorkgroup:
1529 					storage = "groupshared";
1530 					break;
1531 
1532 				default:
1533 					storage = "static";
1534 					break;
1535 				}
1536 
1537 				string initializer;
1538 				if (options.force_zero_initialized_variables && var.storage == StorageClassPrivate &&
1539 				    !var.initializer && !var.static_expression && type_can_zero_initialize(get_variable_data_type(var)))
1540 				{
1541 					initializer = join(" = ", to_zero_initialized_expression(get_variable_data_type_id(var)));
1542 				}
1543 				statement(storage, " ", variable_decl(var), initializer, ";");
1544 
1545 				emitted = true;
1546 			}
1547 		}
1548 	}
1549 
1550 	if (emitted)
1551 		statement("");
1552 
1553 	declare_undefined_values();
1554 
1555 	if (requires_op_fmod)
1556 	{
1557 		static const char *types[] = {
1558 			"float",
1559 			"float2",
1560 			"float3",
1561 			"float4",
1562 		};
1563 
1564 		for (auto &type : types)
1565 		{
1566 			statement(type, " mod(", type, " x, ", type, " y)");
1567 			begin_scope();
1568 			statement("return x - y * floor(x / y);");
1569 			end_scope();
1570 			statement("");
1571 		}
1572 	}
1573 
1574 	emit_texture_size_variants(required_texture_size_variants.srv, "4", false, "");
1575 	for (uint32_t norm = 0; norm < 3; norm++)
1576 	{
1577 		for (uint32_t comp = 0; comp < 4; comp++)
1578 		{
1579 			static const char *qualifiers[] = { "", "unorm ", "snorm " };
1580 			static const char *vecsizes[] = { "", "2", "3", "4" };
1581 			emit_texture_size_variants(required_texture_size_variants.uav[norm][comp], vecsizes[comp], true,
1582 			                           qualifiers[norm]);
1583 		}
1584 	}
1585 
1586 	if (requires_fp16_packing)
1587 	{
1588 		// HLSL does not pack into a single word sadly :(
1589 		statement("uint spvPackHalf2x16(float2 value)");
1590 		begin_scope();
1591 		statement("uint2 Packed = f32tof16(value);");
1592 		statement("return Packed.x | (Packed.y << 16);");
1593 		end_scope();
1594 		statement("");
1595 
1596 		statement("float2 spvUnpackHalf2x16(uint value)");
1597 		begin_scope();
1598 		statement("return f16tof32(uint2(value & 0xffff, value >> 16));");
1599 		end_scope();
1600 		statement("");
1601 	}
1602 
1603 	if (requires_uint2_packing)
1604 	{
1605 		statement("uint64_t spvPackUint2x32(uint2 value)");
1606 		begin_scope();
1607 		statement("return (uint64_t(value.y) << 32) | uint64_t(value.x);");
1608 		end_scope();
1609 		statement("");
1610 
1611 		statement("uint2 spvUnpackUint2x32(uint64_t value)");
1612 		begin_scope();
1613 		statement("uint2 Unpacked;");
1614 		statement("Unpacked.x = uint(value & 0xffffffff);");
1615 		statement("Unpacked.y = uint(value >> 32);");
1616 		statement("return Unpacked;");
1617 		end_scope();
1618 		statement("");
1619 	}
1620 
1621 	if (requires_explicit_fp16_packing)
1622 	{
1623 		// HLSL does not pack into a single word sadly :(
1624 		statement("uint spvPackFloat2x16(min16float2 value)");
1625 		begin_scope();
1626 		statement("uint2 Packed = f32tof16(value);");
1627 		statement("return Packed.x | (Packed.y << 16);");
1628 		end_scope();
1629 		statement("");
1630 
1631 		statement("min16float2 spvUnpackFloat2x16(uint value)");
1632 		begin_scope();
1633 		statement("return min16float2(f16tof32(uint2(value & 0xffff, value >> 16)));");
1634 		end_scope();
1635 		statement("");
1636 	}
1637 
1638 	// HLSL does not seem to have builtins for these operation, so roll them by hand ...
1639 	if (requires_unorm8_packing)
1640 	{
1641 		statement("uint spvPackUnorm4x8(float4 value)");
1642 		begin_scope();
1643 		statement("uint4 Packed = uint4(round(saturate(value) * 255.0));");
1644 		statement("return Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24);");
1645 		end_scope();
1646 		statement("");
1647 
1648 		statement("float4 spvUnpackUnorm4x8(uint value)");
1649 		begin_scope();
1650 		statement("uint4 Packed = uint4(value & 0xff, (value >> 8) & 0xff, (value >> 16) & 0xff, value >> 24);");
1651 		statement("return float4(Packed) / 255.0;");
1652 		end_scope();
1653 		statement("");
1654 	}
1655 
1656 	if (requires_snorm8_packing)
1657 	{
1658 		statement("uint spvPackSnorm4x8(float4 value)");
1659 		begin_scope();
1660 		statement("int4 Packed = int4(round(clamp(value, -1.0, 1.0) * 127.0)) & 0xff;");
1661 		statement("return uint(Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24));");
1662 		end_scope();
1663 		statement("");
1664 
1665 		statement("float4 spvUnpackSnorm4x8(uint value)");
1666 		begin_scope();
1667 		statement("int SignedValue = int(value);");
1668 		statement("int4 Packed = int4(SignedValue << 24, SignedValue << 16, SignedValue << 8, SignedValue) >> 24;");
1669 		statement("return clamp(float4(Packed) / 127.0, -1.0, 1.0);");
1670 		end_scope();
1671 		statement("");
1672 	}
1673 
1674 	if (requires_unorm16_packing)
1675 	{
1676 		statement("uint spvPackUnorm2x16(float2 value)");
1677 		begin_scope();
1678 		statement("uint2 Packed = uint2(round(saturate(value) * 65535.0));");
1679 		statement("return Packed.x | (Packed.y << 16);");
1680 		end_scope();
1681 		statement("");
1682 
1683 		statement("float2 spvUnpackUnorm2x16(uint value)");
1684 		begin_scope();
1685 		statement("uint2 Packed = uint2(value & 0xffff, value >> 16);");
1686 		statement("return float2(Packed) / 65535.0;");
1687 		end_scope();
1688 		statement("");
1689 	}
1690 
1691 	if (requires_snorm16_packing)
1692 	{
1693 		statement("uint spvPackSnorm2x16(float2 value)");
1694 		begin_scope();
1695 		statement("int2 Packed = int2(round(clamp(value, -1.0, 1.0) * 32767.0)) & 0xffff;");
1696 		statement("return uint(Packed.x | (Packed.y << 16));");
1697 		end_scope();
1698 		statement("");
1699 
1700 		statement("float2 spvUnpackSnorm2x16(uint value)");
1701 		begin_scope();
1702 		statement("int SignedValue = int(value);");
1703 		statement("int2 Packed = int2(SignedValue << 16, SignedValue) >> 16;");
1704 		statement("return clamp(float2(Packed) / 32767.0, -1.0, 1.0);");
1705 		end_scope();
1706 		statement("");
1707 	}
1708 
1709 	if (requires_bitfield_insert)
1710 	{
1711 		static const char *types[] = { "uint", "uint2", "uint3", "uint4" };
1712 		for (auto &type : types)
1713 		{
1714 			statement(type, " spvBitfieldInsert(", type, " Base, ", type, " Insert, uint Offset, uint Count)");
1715 			begin_scope();
1716 			statement("uint Mask = Count == 32 ? 0xffffffff : (((1u << Count) - 1) << (Offset & 31));");
1717 			statement("return (Base & ~Mask) | ((Insert << Offset) & Mask);");
1718 			end_scope();
1719 			statement("");
1720 		}
1721 	}
1722 
1723 	if (requires_bitfield_extract)
1724 	{
1725 		static const char *unsigned_types[] = { "uint", "uint2", "uint3", "uint4" };
1726 		for (auto &type : unsigned_types)
1727 		{
1728 			statement(type, " spvBitfieldUExtract(", type, " Base, uint Offset, uint Count)");
1729 			begin_scope();
1730 			statement("uint Mask = Count == 32 ? 0xffffffff : ((1 << Count) - 1);");
1731 			statement("return (Base >> Offset) & Mask;");
1732 			end_scope();
1733 			statement("");
1734 		}
1735 
1736 		// In this overload, we will have to do sign-extension, which we will emulate by shifting up and down.
1737 		static const char *signed_types[] = { "int", "int2", "int3", "int4" };
1738 		for (auto &type : signed_types)
1739 		{
1740 			statement(type, " spvBitfieldSExtract(", type, " Base, int Offset, int Count)");
1741 			begin_scope();
1742 			statement("int Mask = Count == 32 ? -1 : ((1 << Count) - 1);");
1743 			statement(type, " Masked = (Base >> Offset) & Mask;");
1744 			statement("int ExtendShift = (32 - Count) & 31;");
1745 			statement("return (Masked << ExtendShift) >> ExtendShift;");
1746 			end_scope();
1747 			statement("");
1748 		}
1749 	}
1750 
1751 	if (requires_inverse_2x2)
1752 	{
1753 		statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1754 		statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
1755 		statement("float2x2 spvInverse(float2x2 m)");
1756 		begin_scope();
1757 		statement("float2x2 adj;	// The adjoint matrix (inverse after dividing by determinant)");
1758 		statement_no_indent("");
1759 		statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1760 		statement("adj[0][0] =  m[1][1];");
1761 		statement("adj[0][1] = -m[0][1];");
1762 		statement_no_indent("");
1763 		statement("adj[1][0] = -m[1][0];");
1764 		statement("adj[1][1] =  m[0][0];");
1765 		statement_no_indent("");
1766 		statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1767 		statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
1768 		statement_no_indent("");
1769 		statement("// Divide the classical adjoint matrix by the determinant.");
1770 		statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
1771 		statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1772 		end_scope();
1773 		statement("");
1774 	}
1775 
1776 	if (requires_inverse_3x3)
1777 	{
1778 		statement("// Returns the determinant of a 2x2 matrix.");
1779 		statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
1780 		begin_scope();
1781 		statement("return a1 * b2 - b1 * a2;");
1782 		end_scope();
1783 		statement_no_indent("");
1784 		statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1785 		statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
1786 		statement("float3x3 spvInverse(float3x3 m)");
1787 		begin_scope();
1788 		statement("float3x3 adj;	// The adjoint matrix (inverse after dividing by determinant)");
1789 		statement_no_indent("");
1790 		statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1791 		statement("adj[0][0] =  spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
1792 		statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
1793 		statement("adj[0][2] =  spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
1794 		statement_no_indent("");
1795 		statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
1796 		statement("adj[1][1] =  spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
1797 		statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
1798 		statement_no_indent("");
1799 		statement("adj[2][0] =  spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
1800 		statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
1801 		statement("adj[2][2] =  spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
1802 		statement_no_indent("");
1803 		statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1804 		statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
1805 		statement_no_indent("");
1806 		statement("// Divide the classical adjoint matrix by the determinant.");
1807 		statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
1808 		statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1809 		end_scope();
1810 		statement("");
1811 	}
1812 
1813 	if (requires_inverse_4x4)
1814 	{
1815 		if (!requires_inverse_3x3)
1816 		{
1817 			statement("// Returns the determinant of a 2x2 matrix.");
1818 			statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
1819 			begin_scope();
1820 			statement("return a1 * b2 - b1 * a2;");
1821 			end_scope();
1822 			statement("");
1823 		}
1824 
1825 		statement("// Returns the determinant of a 3x3 matrix.");
1826 		statement("float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
1827 		          "float c2, float c3)");
1828 		begin_scope();
1829 		statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * "
1830 		          "spvDet2x2(a2, a3, "
1831 		          "b2, b3);");
1832 		end_scope();
1833 		statement_no_indent("");
1834 		statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1835 		statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
1836 		statement("float4x4 spvInverse(float4x4 m)");
1837 		begin_scope();
1838 		statement("float4x4 adj;	// The adjoint matrix (inverse after dividing by determinant)");
1839 		statement_no_indent("");
1840 		statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1841 		statement(
1842 		    "adj[0][0] =  spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
1843 		    "m[3][3]);");
1844 		statement(
1845 		    "adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
1846 		    "m[3][3]);");
1847 		statement(
1848 		    "adj[0][2] =  spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
1849 		    "m[3][3]);");
1850 		statement(
1851 		    "adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
1852 		    "m[2][3]);");
1853 		statement_no_indent("");
1854 		statement(
1855 		    "adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
1856 		    "m[3][3]);");
1857 		statement(
1858 		    "adj[1][1] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
1859 		    "m[3][3]);");
1860 		statement(
1861 		    "adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
1862 		    "m[3][3]);");
1863 		statement(
1864 		    "adj[1][3] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
1865 		    "m[2][3]);");
1866 		statement_no_indent("");
1867 		statement(
1868 		    "adj[2][0] =  spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
1869 		    "m[3][3]);");
1870 		statement(
1871 		    "adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
1872 		    "m[3][3]);");
1873 		statement(
1874 		    "adj[2][2] =  spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
1875 		    "m[3][3]);");
1876 		statement(
1877 		    "adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
1878 		    "m[2][3]);");
1879 		statement_no_indent("");
1880 		statement(
1881 		    "adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
1882 		    "m[3][2]);");
1883 		statement(
1884 		    "adj[3][1] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
1885 		    "m[3][2]);");
1886 		statement(
1887 		    "adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
1888 		    "m[3][2]);");
1889 		statement(
1890 		    "adj[3][3] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
1891 		    "m[2][2]);");
1892 		statement_no_indent("");
1893 		statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1894 		statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
1895 		          "* m[3][0]);");
1896 		statement_no_indent("");
1897 		statement("// Divide the classical adjoint matrix by the determinant.");
1898 		statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
1899 		statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1900 		end_scope();
1901 		statement("");
1902 	}
1903 
1904 	if (requires_scalar_reflect)
1905 	{
1906 		// FP16/FP64? No templates in HLSL.
1907 		statement("float spvReflect(float i, float n)");
1908 		begin_scope();
1909 		statement("return i - 2.0 * dot(n, i) * n;");
1910 		end_scope();
1911 		statement("");
1912 	}
1913 
1914 	if (requires_scalar_refract)
1915 	{
1916 		// FP16/FP64? No templates in HLSL.
1917 		statement("float spvRefract(float i, float n, float eta)");
1918 		begin_scope();
1919 		statement("float NoI = n * i;");
1920 		statement("float NoI2 = NoI * NoI;");
1921 		statement("float k = 1.0 - eta * eta * (1.0 - NoI2);");
1922 		statement("if (k < 0.0)");
1923 		begin_scope();
1924 		statement("return 0.0;");
1925 		end_scope();
1926 		statement("else");
1927 		begin_scope();
1928 		statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
1929 		end_scope();
1930 		end_scope();
1931 		statement("");
1932 	}
1933 
1934 	if (requires_scalar_faceforward)
1935 	{
1936 		// FP16/FP64? No templates in HLSL.
1937 		statement("float spvFaceForward(float n, float i, float nref)");
1938 		begin_scope();
1939 		statement("return i * nref < 0.0 ? n : -n;");
1940 		end_scope();
1941 		statement("");
1942 	}
1943 }
1944 
emit_texture_size_variants(uint64_t variant_mask,const char * vecsize_qualifier,bool uav,const char * type_qualifier)1945 void CompilerHLSL::emit_texture_size_variants(uint64_t variant_mask, const char *vecsize_qualifier, bool uav,
1946                                               const char *type_qualifier)
1947 {
1948 	if (variant_mask == 0)
1949 		return;
1950 
1951 	static const char *types[QueryTypeCount] = { "float", "int", "uint" };
1952 	static const char *dims[QueryDimCount] = { "Texture1D",   "Texture1DArray",  "Texture2D",   "Texture2DArray",
1953 		                                       "Texture3D",   "Buffer",          "TextureCube", "TextureCubeArray",
1954 		                                       "Texture2DMS", "Texture2DMSArray" };
1955 
1956 	static const bool has_lod[QueryDimCount] = { true, true, true, true, true, false, true, true, false, false };
1957 
1958 	static const char *ret_types[QueryDimCount] = {
1959 		"uint", "uint2", "uint2", "uint3", "uint3", "uint", "uint2", "uint3", "uint2", "uint3",
1960 	};
1961 
1962 	static const uint32_t return_arguments[QueryDimCount] = {
1963 		1, 2, 2, 3, 3, 1, 2, 3, 2, 3,
1964 	};
1965 
1966 	for (uint32_t index = 0; index < QueryDimCount; index++)
1967 	{
1968 		for (uint32_t type_index = 0; type_index < QueryTypeCount; type_index++)
1969 		{
1970 			uint32_t bit = 16 * type_index + index;
1971 			uint64_t mask = 1ull << bit;
1972 
1973 			if ((variant_mask & mask) == 0)
1974 				continue;
1975 
1976 			statement(ret_types[index], " spv", (uav ? "Image" : "Texture"), "Size(", (uav ? "RW" : ""),
1977 			          dims[index], "<", type_qualifier, types[type_index], vecsize_qualifier, "> Tex, ",
1978 			          (uav ? "" : "uint Level, "), "out uint Param)");
1979 			begin_scope();
1980 			statement(ret_types[index], " ret;");
1981 			switch (return_arguments[index])
1982 			{
1983 			case 1:
1984 				if (has_lod[index] && !uav)
1985 					statement("Tex.GetDimensions(Level, ret.x, Param);");
1986 				else
1987 				{
1988 					statement("Tex.GetDimensions(ret.x);");
1989 					statement("Param = 0u;");
1990 				}
1991 				break;
1992 			case 2:
1993 				if (has_lod[index] && !uav)
1994 					statement("Tex.GetDimensions(Level, ret.x, ret.y, Param);");
1995 				else if (!uav)
1996 					statement("Tex.GetDimensions(ret.x, ret.y, Param);");
1997 				else
1998 				{
1999 					statement("Tex.GetDimensions(ret.x, ret.y);");
2000 					statement("Param = 0u;");
2001 				}
2002 				break;
2003 			case 3:
2004 				if (has_lod[index] && !uav)
2005 					statement("Tex.GetDimensions(Level, ret.x, ret.y, ret.z, Param);");
2006 				else if (!uav)
2007 					statement("Tex.GetDimensions(ret.x, ret.y, ret.z, Param);");
2008 				else
2009 				{
2010 					statement("Tex.GetDimensions(ret.x, ret.y, ret.z);");
2011 					statement("Param = 0u;");
2012 				}
2013 				break;
2014 			}
2015 
2016 			statement("return ret;");
2017 			end_scope();
2018 			statement("");
2019 		}
2020 	}
2021 }
2022 
layout_for_member(const SPIRType & type,uint32_t index)2023 string CompilerHLSL::layout_for_member(const SPIRType &type, uint32_t index)
2024 {
2025 	auto &flags = get_member_decoration_bitset(type.self, index);
2026 
2027 	// HLSL can emit row_major or column_major decoration in any struct.
2028 	// Do not try to merge combined decorations for children like in GLSL.
2029 
2030 	// Flip the convention. HLSL is a bit odd in that the memory layout is column major ... but the language API is "row-major".
2031 	// The way to deal with this is to multiply everything in inverse order, and reverse the memory layout.
2032 	if (flags.get(DecorationColMajor))
2033 		return "row_major ";
2034 	else if (flags.get(DecorationRowMajor))
2035 		return "column_major ";
2036 
2037 	return "";
2038 }
2039 
emit_struct_member(const SPIRType & type,uint32_t member_type_id,uint32_t index,const string & qualifier,uint32_t base_offset)2040 void CompilerHLSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
2041                                       const string &qualifier, uint32_t base_offset)
2042 {
2043 	auto &membertype = get<SPIRType>(member_type_id);
2044 
2045 	Bitset memberflags;
2046 	auto &memb = ir.meta[type.self].members;
2047 	if (index < memb.size())
2048 		memberflags = memb[index].decoration_flags;
2049 
2050 	string qualifiers;
2051 	bool is_block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock) ||
2052 	                ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock);
2053 
2054 	if (is_block)
2055 		qualifiers = to_interpolation_qualifiers(memberflags);
2056 
2057 	string packing_offset;
2058 	bool is_push_constant = type.storage == StorageClassPushConstant;
2059 
2060 	if ((has_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset) || is_push_constant) &&
2061 	    has_member_decoration(type.self, index, DecorationOffset))
2062 	{
2063 		uint32_t offset = memb[index].offset - base_offset;
2064 		if (offset & 3)
2065 			SPIRV_CROSS_THROW("Cannot pack on tighter bounds than 4 bytes in HLSL.");
2066 
2067 		static const char *packing_swizzle[] = { "", ".y", ".z", ".w" };
2068 		packing_offset = join(" : packoffset(c", offset / 16, packing_swizzle[(offset & 15) >> 2], ")");
2069 	}
2070 
2071 	statement(layout_for_member(type, index), qualifiers, qualifier,
2072 	          variable_decl(membertype, to_member_name(type, index)), packing_offset, ";");
2073 }
2074 
emit_buffer_block(const SPIRVariable & var)2075 void CompilerHLSL::emit_buffer_block(const SPIRVariable &var)
2076 {
2077 	auto &type = get<SPIRType>(var.basetype);
2078 
2079 	bool is_uav = var.storage == StorageClassStorageBuffer || has_decoration(type.self, DecorationBufferBlock);
2080 
2081 	if (is_uav)
2082 	{
2083 		Bitset flags = ir.get_buffer_block_flags(var);
2084 		bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
2085 		bool is_coherent = flags.get(DecorationCoherent) && !is_readonly;
2086 		bool is_interlocked = interlocked_resources.count(var.self) > 0;
2087 		const char *type_name = "ByteAddressBuffer ";
2088 		if (!is_readonly)
2089 			type_name = is_interlocked ? "RasterizerOrderedByteAddressBuffer " : "RWByteAddressBuffer ";
2090 		add_resource_name(var.self);
2091 		statement(is_coherent ? "globallycoherent " : "", type_name, to_name(var.self), type_to_array_glsl(type),
2092 		          to_resource_binding(var), ";");
2093 	}
2094 	else
2095 	{
2096 		if (type.array.empty())
2097 		{
2098 			// Flatten the top-level struct so we can use packoffset,
2099 			// this restriction is similar to GLSL where layout(offset) is not possible on sub-structs.
2100 			flattened_structs[var.self] = false;
2101 
2102 			// Prefer the block name if possible.
2103 			auto buffer_name = to_name(type.self, false);
2104 			if (ir.meta[type.self].decoration.alias.empty() ||
2105 			    resource_names.find(buffer_name) != end(resource_names) ||
2106 			    block_names.find(buffer_name) != end(block_names))
2107 			{
2108 				buffer_name = get_block_fallback_name(var.self);
2109 			}
2110 
2111 			add_variable(block_names, resource_names, buffer_name);
2112 
2113 			// If for some reason buffer_name is an illegal name, make a final fallback to a workaround name.
2114 			// This cannot conflict with anything else, so we're safe now.
2115 			if (buffer_name.empty())
2116 				buffer_name = join("_", get<SPIRType>(var.basetype).self, "_", var.self);
2117 
2118 			uint32_t failed_index = 0;
2119 			if (buffer_is_packing_standard(type, BufferPackingHLSLCbufferPackOffset, &failed_index))
2120 				set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset);
2121 			else
2122 			{
2123 				SPIRV_CROSS_THROW(join("cbuffer ID ", var.self, " (name: ", buffer_name, "), member index ",
2124 				                       failed_index, " (name: ", to_member_name(type, failed_index),
2125 				                       ") cannot be expressed with either HLSL packing layout or packoffset."));
2126 			}
2127 
2128 			block_names.insert(buffer_name);
2129 
2130 			// Save for post-reflection later.
2131 			declared_block_names[var.self] = buffer_name;
2132 
2133 			type.member_name_cache.clear();
2134 			// var.self can be used as a backup name for the block name,
2135 			// so we need to make sure we don't disturb the name here on a recompile.
2136 			// It will need to be reset if we have to recompile.
2137 			preserve_alias_on_reset(var.self);
2138 			add_resource_name(var.self);
2139 			statement("cbuffer ", buffer_name, to_resource_binding(var));
2140 			begin_scope();
2141 
2142 			uint32_t i = 0;
2143 			for (auto &member : type.member_types)
2144 			{
2145 				add_member_name(type, i);
2146 				auto backup_name = get_member_name(type.self, i);
2147 				auto member_name = to_member_name(type, i);
2148 				member_name = join(to_name(var.self), "_", member_name);
2149 				ParsedIR::sanitize_underscores(member_name);
2150 				set_member_name(type.self, i, member_name);
2151 				emit_struct_member(type, member, i, "");
2152 				set_member_name(type.self, i, backup_name);
2153 				i++;
2154 			}
2155 
2156 			end_scope_decl();
2157 			statement("");
2158 		}
2159 		else
2160 		{
2161 			if (hlsl_options.shader_model < 51)
2162 				SPIRV_CROSS_THROW(
2163 				    "Need ConstantBuffer<T> to use arrays of UBOs, but this is only supported in SM 5.1.");
2164 
2165 			add_resource_name(type.self);
2166 			add_resource_name(var.self);
2167 
2168 			// ConstantBuffer<T> does not support packoffset, so it is unuseable unless everything aligns as we expect.
2169 			uint32_t failed_index = 0;
2170 			if (!buffer_is_packing_standard(type, BufferPackingHLSLCbuffer, &failed_index))
2171 			{
2172 				SPIRV_CROSS_THROW(join("HLSL ConstantBuffer<T> ID ", var.self, " (name: ", to_name(type.self),
2173 				                       "), member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2174 				                       ") cannot be expressed with normal HLSL packing rules."));
2175 			}
2176 
2177 			emit_struct(get<SPIRType>(type.self));
2178 			statement("ConstantBuffer<", to_name(type.self), "> ", to_name(var.self), type_to_array_glsl(type),
2179 			          to_resource_binding(var), ";");
2180 		}
2181 	}
2182 }
2183 
emit_push_constant_block(const SPIRVariable & var)2184 void CompilerHLSL::emit_push_constant_block(const SPIRVariable &var)
2185 {
2186 	if (root_constants_layout.empty())
2187 	{
2188 		emit_buffer_block(var);
2189 	}
2190 	else
2191 	{
2192 		for (const auto &layout : root_constants_layout)
2193 		{
2194 			auto &type = get<SPIRType>(var.basetype);
2195 
2196 			uint32_t failed_index = 0;
2197 			if (buffer_is_packing_standard(type, BufferPackingHLSLCbufferPackOffset, &failed_index, layout.start,
2198 			                               layout.end))
2199 				set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset);
2200 			else
2201 			{
2202 				SPIRV_CROSS_THROW(join("Root constant cbuffer ID ", var.self, " (name: ", to_name(type.self), ")",
2203 				                       ", member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2204 				                       ") cannot be expressed with either HLSL packing layout or packoffset."));
2205 			}
2206 
2207 			flattened_structs[var.self] = false;
2208 			type.member_name_cache.clear();
2209 			add_resource_name(var.self);
2210 			auto &memb = ir.meta[type.self].members;
2211 
2212 			statement("cbuffer SPIRV_CROSS_RootConstant_", to_name(var.self),
2213 			          to_resource_register(HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT, 'b', layout.binding, layout.space));
2214 			begin_scope();
2215 
2216 			// Index of the next field in the generated root constant constant buffer
2217 			auto constant_index = 0u;
2218 
2219 			// Iterate over all member of the push constant and check which of the fields
2220 			// fit into the given root constant layout.
2221 			for (auto i = 0u; i < memb.size(); i++)
2222 			{
2223 				const auto offset = memb[i].offset;
2224 				if (layout.start <= offset && offset < layout.end)
2225 				{
2226 					const auto &member = type.member_types[i];
2227 
2228 					add_member_name(type, constant_index);
2229 					auto backup_name = get_member_name(type.self, i);
2230 					auto member_name = to_member_name(type, i);
2231 					member_name = join(to_name(var.self), "_", member_name);
2232 					ParsedIR::sanitize_underscores(member_name);
2233 					set_member_name(type.self, constant_index, member_name);
2234 					emit_struct_member(type, member, i, "", layout.start);
2235 					set_member_name(type.self, constant_index, backup_name);
2236 
2237 					constant_index++;
2238 				}
2239 			}
2240 
2241 			end_scope_decl();
2242 		}
2243 	}
2244 }
2245 
to_sampler_expression(uint32_t id)2246 string CompilerHLSL::to_sampler_expression(uint32_t id)
2247 {
2248 	auto expr = join("_", to_expression(id));
2249 	auto index = expr.find_first_of('[');
2250 	if (index == string::npos)
2251 	{
2252 		return expr + "_sampler";
2253 	}
2254 	else
2255 	{
2256 		// We have an expression like _ident[array], so we cannot tack on _sampler, insert it inside the string instead.
2257 		return expr.insert(index, "_sampler");
2258 	}
2259 }
2260 
emit_sampled_image_op(uint32_t result_type,uint32_t result_id,uint32_t image_id,uint32_t samp_id)2261 void CompilerHLSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
2262 {
2263 	if (hlsl_options.shader_model >= 40 && combined_image_samplers.empty())
2264 	{
2265 		set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
2266 	}
2267 	else
2268 	{
2269 		// Make sure to suppress usage tracking. It is illegal to create temporaries of opaque types.
2270 		emit_op(result_type, result_id, to_combined_image_sampler(image_id, samp_id), true, true);
2271 	}
2272 }
2273 
to_func_call_arg(const SPIRFunction::Parameter & arg,uint32_t id)2274 string CompilerHLSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
2275 {
2276 	string arg_str = CompilerGLSL::to_func_call_arg(arg, id);
2277 
2278 	if (hlsl_options.shader_model <= 30)
2279 		return arg_str;
2280 
2281 	// Manufacture automatic sampler arg if the arg is a SampledImage texture and we're in modern HLSL.
2282 	auto &type = expression_type(id);
2283 
2284 	// We don't have to consider combined image samplers here via OpSampledImage because
2285 	// those variables cannot be passed as arguments to functions.
2286 	// Only global SampledImage variables may be used as arguments.
2287 	if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
2288 		arg_str += ", " + to_sampler_expression(id);
2289 
2290 	return arg_str;
2291 }
2292 
emit_function_prototype(SPIRFunction & func,const Bitset & return_flags)2293 void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &return_flags)
2294 {
2295 	if (func.self != ir.default_entry_point)
2296 		add_function_overload(func);
2297 
2298 	auto &execution = get_entry_point();
2299 	// Avoid shadow declarations.
2300 	local_variable_names = resource_names;
2301 
2302 	string decl;
2303 
2304 	auto &type = get<SPIRType>(func.return_type);
2305 	if (type.array.empty())
2306 	{
2307 		decl += flags_to_qualifiers_glsl(type, return_flags);
2308 		decl += type_to_glsl(type);
2309 		decl += " ";
2310 	}
2311 	else
2312 	{
2313 		// We cannot return arrays in HLSL, so "return" through an out variable.
2314 		decl = "void ";
2315 	}
2316 
2317 	if (func.self == ir.default_entry_point)
2318 	{
2319 		if (execution.model == ExecutionModelVertex)
2320 			decl += "vert_main";
2321 		else if (execution.model == ExecutionModelFragment)
2322 			decl += "frag_main";
2323 		else if (execution.model == ExecutionModelGLCompute)
2324 			decl += "comp_main";
2325 		else
2326 			SPIRV_CROSS_THROW("Unsupported execution model.");
2327 		processing_entry_point = true;
2328 	}
2329 	else
2330 		decl += to_name(func.self);
2331 
2332 	decl += "(";
2333 	SmallVector<string> arglist;
2334 
2335 	if (!type.array.empty())
2336 	{
2337 		// Fake array returns by writing to an out array instead.
2338 		string out_argument;
2339 		out_argument += "out ";
2340 		out_argument += type_to_glsl(type);
2341 		out_argument += " ";
2342 		out_argument += "spvReturnValue";
2343 		out_argument += type_to_array_glsl(type);
2344 		arglist.push_back(move(out_argument));
2345 	}
2346 
2347 	for (auto &arg : func.arguments)
2348 	{
2349 		// Do not pass in separate images or samplers if we're remapping
2350 		// to combined image samplers.
2351 		if (skip_argument(arg.id))
2352 			continue;
2353 
2354 		// Might change the variable name if it already exists in this function.
2355 		// SPIRV OpName doesn't have any semantic effect, so it's valid for an implementation
2356 		// to use same name for variables.
2357 		// Since we want to make the GLSL debuggable and somewhat sane, use fallback names for variables which are duplicates.
2358 		add_local_variable_name(arg.id);
2359 
2360 		arglist.push_back(argument_decl(arg));
2361 
2362 		// Flatten a combined sampler to two separate arguments in modern HLSL.
2363 		auto &arg_type = get<SPIRType>(arg.type);
2364 		if (hlsl_options.shader_model > 30 && arg_type.basetype == SPIRType::SampledImage &&
2365 		    arg_type.image.dim != DimBuffer)
2366 		{
2367 			// Manufacture automatic sampler arg for SampledImage texture
2368 			arglist.push_back(join(image_is_comparison(arg_type, arg.id) ? "SamplerComparisonState " : "SamplerState ",
2369 			                       to_sampler_expression(arg.id), type_to_array_glsl(arg_type)));
2370 		}
2371 
2372 		// Hold a pointer to the parameter so we can invalidate the readonly field if needed.
2373 		auto *var = maybe_get<SPIRVariable>(arg.id);
2374 		if (var)
2375 			var->parameter = &arg;
2376 	}
2377 
2378 	for (auto &arg : func.shadow_arguments)
2379 	{
2380 		// Might change the variable name if it already exists in this function.
2381 		// SPIRV OpName doesn't have any semantic effect, so it's valid for an implementation
2382 		// to use same name for variables.
2383 		// Since we want to make the GLSL debuggable and somewhat sane, use fallback names for variables which are duplicates.
2384 		add_local_variable_name(arg.id);
2385 
2386 		arglist.push_back(argument_decl(arg));
2387 
2388 		// Hold a pointer to the parameter so we can invalidate the readonly field if needed.
2389 		auto *var = maybe_get<SPIRVariable>(arg.id);
2390 		if (var)
2391 			var->parameter = &arg;
2392 	}
2393 
2394 	decl += merge(arglist);
2395 	decl += ")";
2396 	statement(decl);
2397 }
2398 
emit_hlsl_entry_point()2399 void CompilerHLSL::emit_hlsl_entry_point()
2400 {
2401 	SmallVector<string> arguments;
2402 
2403 	if (require_input)
2404 		arguments.push_back("SPIRV_Cross_Input stage_input");
2405 
2406 	// Add I/O blocks as separate arguments with appropriate storage qualifier.
2407 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2408 		auto &type = this->get<SPIRType>(var.basetype);
2409 		bool block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock);
2410 
2411 		if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
2412 			return;
2413 
2414 		if (block && !is_builtin_variable(var) && interface_variable_exists_in_entry_point(var.self))
2415 		{
2416 			if (var.storage == StorageClassInput)
2417 			{
2418 				arguments.push_back(join("in ", variable_decl(type, join("stage_input", to_name(var.self)))));
2419 			}
2420 			else if (var.storage == StorageClassOutput)
2421 			{
2422 				arguments.push_back(join("out ", variable_decl(type, join("stage_output", to_name(var.self)))));
2423 			}
2424 		}
2425 	});
2426 
2427 	auto &execution = get_entry_point();
2428 
2429 	switch (execution.model)
2430 	{
2431 	case ExecutionModelGLCompute:
2432 	{
2433 		SpecializationConstant wg_x, wg_y, wg_z;
2434 		get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
2435 
2436 		uint32_t x = execution.workgroup_size.x;
2437 		uint32_t y = execution.workgroup_size.y;
2438 		uint32_t z = execution.workgroup_size.z;
2439 
2440 		auto x_expr = wg_x.id ? get<SPIRConstant>(wg_x.id).specialization_constant_macro_name : to_string(x);
2441 		auto y_expr = wg_y.id ? get<SPIRConstant>(wg_y.id).specialization_constant_macro_name : to_string(y);
2442 		auto z_expr = wg_z.id ? get<SPIRConstant>(wg_z.id).specialization_constant_macro_name : to_string(z);
2443 
2444 		statement("[numthreads(", x_expr, ", ", y_expr, ", ", z_expr, ")]");
2445 		break;
2446 	}
2447 	case ExecutionModelFragment:
2448 		if (execution.flags.get(ExecutionModeEarlyFragmentTests))
2449 			statement("[earlydepthstencil]");
2450 		break;
2451 	default:
2452 		break;
2453 	}
2454 
2455 	statement(require_output ? "SPIRV_Cross_Output " : "void ", "main(", merge(arguments), ")");
2456 	begin_scope();
2457 	bool legacy = hlsl_options.shader_model <= 30;
2458 
2459 	// Copy builtins from entry point arguments to globals.
2460 	active_input_builtins.for_each_bit([&](uint32_t i) {
2461 		auto builtin = builtin_to_glsl(static_cast<BuiltIn>(i), StorageClassInput);
2462 		switch (static_cast<BuiltIn>(i))
2463 		{
2464 		case BuiltInFragCoord:
2465 			// VPOS in D3D9 is sampled at integer locations, apply half-pixel offset to be consistent.
2466 			// TODO: Do we need an option here? Any reason why a D3D9 shader would be used
2467 			// on a D3D10+ system with a different rasterization config?
2468 			if (legacy)
2469 				statement(builtin, " = stage_input.", builtin, " + float4(0.5f, 0.5f, 0.0f, 0.0f);");
2470 			else
2471 			{
2472 				statement(builtin, " = stage_input.", builtin, ";");
2473 				// ZW are undefined in D3D9, only do this fixup here.
2474 				statement(builtin, ".w = 1.0 / ", builtin, ".w;");
2475 			}
2476 			break;
2477 
2478 		case BuiltInVertexId:
2479 		case BuiltInVertexIndex:
2480 		case BuiltInInstanceIndex:
2481 			// D3D semantics are uint, but shader wants int.
2482 			if (hlsl_options.support_nonzero_base_vertex_base_instance)
2483 			{
2484 				if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
2485 					statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseInstance;");
2486 				else
2487 					statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseVertex;");
2488 			}
2489 			else
2490 				statement(builtin, " = int(stage_input.", builtin, ");");
2491 			break;
2492 
2493 		case BuiltInInstanceId:
2494 			// D3D semantics are uint, but shader wants int.
2495 			statement(builtin, " = int(stage_input.", builtin, ");");
2496 			break;
2497 
2498 		case BuiltInNumWorkgroups:
2499 		case BuiltInPointCoord:
2500 		case BuiltInSubgroupSize:
2501 		case BuiltInSubgroupLocalInvocationId:
2502 			break;
2503 
2504 		case BuiltInSubgroupEqMask:
2505 			// Emulate these ...
2506 			// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2507 			statement("gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));");
2508 			statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;");
2509 			statement("if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;");
2510 			statement("if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;");
2511 			statement("if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;");
2512 			break;
2513 
2514 		case BuiltInSubgroupGeMask:
2515 			// Emulate these ...
2516 			// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2517 			statement("gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);");
2518 			statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;");
2519 			statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;");
2520 			statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;");
2521 			statement("if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;");
2522 			statement("if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;");
2523 			statement("if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;");
2524 			break;
2525 
2526 		case BuiltInSubgroupGtMask:
2527 			// Emulate these ...
2528 			// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2529 			statement("uint gt_lane_index = WaveGetLaneIndex() + 1;");
2530 			statement("gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);");
2531 			statement("if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;");
2532 			statement("if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;");
2533 			statement("if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;");
2534 			statement("if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;");
2535 			statement("if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;");
2536 			statement("if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;");
2537 			statement("if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;");
2538 			break;
2539 
2540 		case BuiltInSubgroupLeMask:
2541 			// Emulate these ...
2542 			// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2543 			statement("uint le_lane_index = WaveGetLaneIndex() + 1;");
2544 			statement("gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;");
2545 			statement("if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;");
2546 			statement("if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;");
2547 			statement("if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;");
2548 			statement("if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;");
2549 			statement("if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;");
2550 			statement("if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;");
2551 			statement("if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;");
2552 			break;
2553 
2554 		case BuiltInSubgroupLtMask:
2555 			// Emulate these ...
2556 			// No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2557 			statement("gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;");
2558 			statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;");
2559 			statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;");
2560 			statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;");
2561 			statement("if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;");
2562 			statement("if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;");
2563 			statement("if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;");
2564 			break;
2565 
2566 		case BuiltInClipDistance:
2567 			for (uint32_t clip = 0; clip < clip_distance_count; clip++)
2568 				statement("gl_ClipDistance[", clip, "] = stage_input.gl_ClipDistance", clip / 4, ".", "xyzw"[clip & 3],
2569 				          ";");
2570 			break;
2571 
2572 		case BuiltInCullDistance:
2573 			for (uint32_t cull = 0; cull < cull_distance_count; cull++)
2574 				statement("gl_CullDistance[", cull, "] = stage_input.gl_CullDistance", cull / 4, ".", "xyzw"[cull & 3],
2575 				          ";");
2576 			break;
2577 
2578 		default:
2579 			statement(builtin, " = stage_input.", builtin, ";");
2580 			break;
2581 		}
2582 	});
2583 
2584 	// Copy from stage input struct to globals.
2585 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2586 		auto &type = this->get<SPIRType>(var.basetype);
2587 		bool block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock);
2588 
2589 		if (var.storage != StorageClassInput)
2590 			return;
2591 
2592 		bool need_matrix_unroll = var.storage == StorageClassInput && execution.model == ExecutionModelVertex;
2593 
2594 		if (!block && !var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
2595 		    interface_variable_exists_in_entry_point(var.self))
2596 		{
2597 			auto name = to_name(var.self);
2598 			auto &mtype = this->get<SPIRType>(var.basetype);
2599 			if (need_matrix_unroll && mtype.columns > 1)
2600 			{
2601 				// Unroll matrices.
2602 				for (uint32_t col = 0; col < mtype.columns; col++)
2603 					statement(name, "[", col, "] = stage_input.", name, "_", col, ";");
2604 			}
2605 			else
2606 			{
2607 				statement(name, " = stage_input.", name, ";");
2608 			}
2609 		}
2610 
2611 		// I/O blocks don't use the common stage input/output struct, but separate outputs.
2612 		if (block && !is_builtin_variable(var) && interface_variable_exists_in_entry_point(var.self))
2613 		{
2614 			auto name = to_name(var.self);
2615 			statement(name, " = stage_input", name, ";");
2616 		}
2617 	});
2618 
2619 	// Run the shader.
2620 	if (execution.model == ExecutionModelVertex)
2621 		statement("vert_main();");
2622 	else if (execution.model == ExecutionModelFragment)
2623 		statement("frag_main();");
2624 	else if (execution.model == ExecutionModelGLCompute)
2625 		statement("comp_main();");
2626 	else
2627 		SPIRV_CROSS_THROW("Unsupported shader stage.");
2628 
2629 	// Copy block outputs.
2630 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2631 		auto &type = this->get<SPIRType>(var.basetype);
2632 		bool block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock);
2633 
2634 		if (var.storage != StorageClassOutput)
2635 			return;
2636 
2637 		// I/O blocks don't use the common stage input/output struct, but separate outputs.
2638 		if (block && !is_builtin_variable(var) && interface_variable_exists_in_entry_point(var.self))
2639 		{
2640 			auto name = to_name(var.self);
2641 			statement("stage_output", name, " = ", name, ";");
2642 		}
2643 	});
2644 
2645 	// Copy stage outputs.
2646 	if (require_output)
2647 	{
2648 		statement("SPIRV_Cross_Output stage_output;");
2649 
2650 		// Copy builtins from globals to return struct.
2651 		active_output_builtins.for_each_bit([&](uint32_t i) {
2652 			// PointSize doesn't exist in HLSL.
2653 			if (i == BuiltInPointSize)
2654 				return;
2655 
2656 			switch (static_cast<BuiltIn>(i))
2657 			{
2658 			case BuiltInClipDistance:
2659 				for (uint32_t clip = 0; clip < clip_distance_count; clip++)
2660 					statement("stage_output.gl_ClipDistance", clip / 4, ".", "xyzw"[clip & 3], " = gl_ClipDistance[",
2661 					          clip, "];");
2662 				break;
2663 
2664 			case BuiltInCullDistance:
2665 				for (uint32_t cull = 0; cull < cull_distance_count; cull++)
2666 					statement("stage_output.gl_CullDistance", cull / 4, ".", "xyzw"[cull & 3], " = gl_CullDistance[",
2667 					          cull, "];");
2668 				break;
2669 
2670 			default:
2671 			{
2672 				auto builtin_expr = builtin_to_glsl(static_cast<BuiltIn>(i), StorageClassOutput);
2673 				statement("stage_output.", builtin_expr, " = ", builtin_expr, ";");
2674 				break;
2675 			}
2676 			}
2677 		});
2678 
2679 		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2680 			auto &type = this->get<SPIRType>(var.basetype);
2681 			bool block = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock);
2682 
2683 			if (var.storage != StorageClassOutput)
2684 				return;
2685 
2686 			if (!block && var.storage != StorageClassFunction && !var.remapped_variable && type.pointer &&
2687 			    !is_builtin_variable(var) && interface_variable_exists_in_entry_point(var.self))
2688 			{
2689 				auto name = to_name(var.self);
2690 
2691 				if (legacy && execution.model == ExecutionModelFragment)
2692 				{
2693 					string output_filler;
2694 					for (uint32_t size = type.vecsize; size < 4; ++size)
2695 						output_filler += ", 0.0";
2696 
2697 					statement("stage_output.", name, " = float4(", name, output_filler, ");");
2698 				}
2699 				else
2700 				{
2701 					statement("stage_output.", name, " = ", name, ";");
2702 				}
2703 			}
2704 		});
2705 
2706 		statement("return stage_output;");
2707 	}
2708 
2709 	end_scope();
2710 }
2711 
emit_fixup()2712 void CompilerHLSL::emit_fixup()
2713 {
2714 	if (is_vertex_like_shader())
2715 	{
2716 		// Do various mangling on the gl_Position.
2717 		if (hlsl_options.shader_model <= 30)
2718 		{
2719 			statement("gl_Position.x = gl_Position.x - gl_HalfPixel.x * "
2720 			          "gl_Position.w;");
2721 			statement("gl_Position.y = gl_Position.y + gl_HalfPixel.y * "
2722 			          "gl_Position.w;");
2723 		}
2724 
2725 		if (options.vertex.flip_vert_y)
2726 			statement("gl_Position.y = -gl_Position.y;");
2727 		if (options.vertex.fixup_clipspace)
2728 			statement("gl_Position.z = (gl_Position.z + gl_Position.w) * 0.5;");
2729 	}
2730 }
2731 
emit_texture_op(const Instruction & i,bool sparse)2732 void CompilerHLSL::emit_texture_op(const Instruction &i, bool sparse)
2733 {
2734 	if (sparse)
2735 		SPIRV_CROSS_THROW("Sparse feedback not yet supported in HLSL.");
2736 
2737 	auto *ops = stream(i);
2738 	auto op = static_cast<Op>(i.op);
2739 	uint32_t length = i.length;
2740 
2741 	SmallVector<uint32_t> inherited_expressions;
2742 
2743 	uint32_t result_type = ops[0];
2744 	uint32_t id = ops[1];
2745 	VariableID img = ops[2];
2746 	uint32_t coord = ops[3];
2747 	uint32_t dref = 0;
2748 	uint32_t comp = 0;
2749 	bool gather = false;
2750 	bool proj = false;
2751 	const uint32_t *opt = nullptr;
2752 	auto *combined_image = maybe_get<SPIRCombinedImageSampler>(img);
2753 	auto img_expr = to_expression(combined_image ? combined_image->image : img);
2754 
2755 	inherited_expressions.push_back(coord);
2756 
2757 	// Make sure non-uniform decoration is back-propagated to where it needs to be.
2758 	if (has_decoration(img, DecorationNonUniformEXT))
2759 		propagate_nonuniform_qualifier(img);
2760 
2761 	switch (op)
2762 	{
2763 	case OpImageSampleDrefImplicitLod:
2764 	case OpImageSampleDrefExplicitLod:
2765 		dref = ops[4];
2766 		opt = &ops[5];
2767 		length -= 5;
2768 		break;
2769 
2770 	case OpImageSampleProjDrefImplicitLod:
2771 	case OpImageSampleProjDrefExplicitLod:
2772 		dref = ops[4];
2773 		proj = true;
2774 		opt = &ops[5];
2775 		length -= 5;
2776 		break;
2777 
2778 	case OpImageDrefGather:
2779 		dref = ops[4];
2780 		opt = &ops[5];
2781 		gather = true;
2782 		length -= 5;
2783 		break;
2784 
2785 	case OpImageGather:
2786 		comp = ops[4];
2787 		opt = &ops[5];
2788 		gather = true;
2789 		length -= 5;
2790 		break;
2791 
2792 	case OpImageSampleProjImplicitLod:
2793 	case OpImageSampleProjExplicitLod:
2794 		opt = &ops[4];
2795 		length -= 4;
2796 		proj = true;
2797 		break;
2798 
2799 	case OpImageQueryLod:
2800 		opt = &ops[4];
2801 		length -= 4;
2802 		break;
2803 
2804 	default:
2805 		opt = &ops[4];
2806 		length -= 4;
2807 		break;
2808 	}
2809 
2810 	auto &imgtype = expression_type(img);
2811 	uint32_t coord_components = 0;
2812 	switch (imgtype.image.dim)
2813 	{
2814 	case spv::Dim1D:
2815 		coord_components = 1;
2816 		break;
2817 	case spv::Dim2D:
2818 		coord_components = 2;
2819 		break;
2820 	case spv::Dim3D:
2821 		coord_components = 3;
2822 		break;
2823 	case spv::DimCube:
2824 		coord_components = 3;
2825 		break;
2826 	case spv::DimBuffer:
2827 		coord_components = 1;
2828 		break;
2829 	default:
2830 		coord_components = 2;
2831 		break;
2832 	}
2833 
2834 	if (dref)
2835 		inherited_expressions.push_back(dref);
2836 
2837 	if (imgtype.image.arrayed)
2838 		coord_components++;
2839 
2840 	uint32_t bias = 0;
2841 	uint32_t lod = 0;
2842 	uint32_t grad_x = 0;
2843 	uint32_t grad_y = 0;
2844 	uint32_t coffset = 0;
2845 	uint32_t offset = 0;
2846 	uint32_t coffsets = 0;
2847 	uint32_t sample = 0;
2848 	uint32_t minlod = 0;
2849 	uint32_t flags = 0;
2850 
2851 	if (length)
2852 	{
2853 		flags = opt[0];
2854 		opt++;
2855 		length--;
2856 	}
2857 
2858 	auto test = [&](uint32_t &v, uint32_t flag) {
2859 		if (length && (flags & flag))
2860 		{
2861 			v = *opt++;
2862 			inherited_expressions.push_back(v);
2863 			length--;
2864 		}
2865 	};
2866 
2867 	test(bias, ImageOperandsBiasMask);
2868 	test(lod, ImageOperandsLodMask);
2869 	test(grad_x, ImageOperandsGradMask);
2870 	test(grad_y, ImageOperandsGradMask);
2871 	test(coffset, ImageOperandsConstOffsetMask);
2872 	test(offset, ImageOperandsOffsetMask);
2873 	test(coffsets, ImageOperandsConstOffsetsMask);
2874 	test(sample, ImageOperandsSampleMask);
2875 	test(minlod, ImageOperandsMinLodMask);
2876 
2877 	string expr;
2878 	string texop;
2879 
2880 	if (minlod != 0)
2881 		SPIRV_CROSS_THROW("MinLod texture operand not supported in HLSL.");
2882 
2883 	if (op == OpImageFetch)
2884 	{
2885 		if (hlsl_options.shader_model < 40)
2886 		{
2887 			SPIRV_CROSS_THROW("texelFetch is not supported in HLSL shader model 2/3.");
2888 		}
2889 		texop += img_expr;
2890 		texop += ".Load";
2891 	}
2892 	else if (op == OpImageQueryLod)
2893 	{
2894 		texop += img_expr;
2895 		texop += ".CalculateLevelOfDetail";
2896 	}
2897 	else
2898 	{
2899 		auto &imgformat = get<SPIRType>(imgtype.image.type);
2900 		if (imgformat.basetype != SPIRType::Float)
2901 		{
2902 			SPIRV_CROSS_THROW("Sampling non-float textures is not supported in HLSL.");
2903 		}
2904 
2905 		if (hlsl_options.shader_model >= 40)
2906 		{
2907 			texop += img_expr;
2908 
2909 			if (image_is_comparison(imgtype, img))
2910 			{
2911 				if (gather)
2912 				{
2913 					SPIRV_CROSS_THROW("GatherCmp does not exist in HLSL.");
2914 				}
2915 				else if (lod || grad_x || grad_y)
2916 				{
2917 					// Assume we want a fixed level, and the only thing we can get in HLSL is SampleCmpLevelZero.
2918 					texop += ".SampleCmpLevelZero";
2919 				}
2920 				else
2921 					texop += ".SampleCmp";
2922 			}
2923 			else if (gather)
2924 			{
2925 				uint32_t comp_num = evaluate_constant_u32(comp);
2926 				if (hlsl_options.shader_model >= 50)
2927 				{
2928 					switch (comp_num)
2929 					{
2930 					case 0:
2931 						texop += ".GatherRed";
2932 						break;
2933 					case 1:
2934 						texop += ".GatherGreen";
2935 						break;
2936 					case 2:
2937 						texop += ".GatherBlue";
2938 						break;
2939 					case 3:
2940 						texop += ".GatherAlpha";
2941 						break;
2942 					default:
2943 						SPIRV_CROSS_THROW("Invalid component.");
2944 					}
2945 				}
2946 				else
2947 				{
2948 					if (comp_num == 0)
2949 						texop += ".Gather";
2950 					else
2951 						SPIRV_CROSS_THROW("HLSL shader model 4 can only gather from the red component.");
2952 				}
2953 			}
2954 			else if (bias)
2955 				texop += ".SampleBias";
2956 			else if (grad_x || grad_y)
2957 				texop += ".SampleGrad";
2958 			else if (lod)
2959 				texop += ".SampleLevel";
2960 			else
2961 				texop += ".Sample";
2962 		}
2963 		else
2964 		{
2965 			switch (imgtype.image.dim)
2966 			{
2967 			case Dim1D:
2968 				texop += "tex1D";
2969 				break;
2970 			case Dim2D:
2971 				texop += "tex2D";
2972 				break;
2973 			case Dim3D:
2974 				texop += "tex3D";
2975 				break;
2976 			case DimCube:
2977 				texop += "texCUBE";
2978 				break;
2979 			case DimRect:
2980 			case DimBuffer:
2981 			case DimSubpassData:
2982 				SPIRV_CROSS_THROW("Buffer texture support is not yet implemented for HLSL"); // TODO
2983 			default:
2984 				SPIRV_CROSS_THROW("Invalid dimension.");
2985 			}
2986 
2987 			if (gather)
2988 				SPIRV_CROSS_THROW("textureGather is not supported in HLSL shader model 2/3.");
2989 			if (offset || coffset)
2990 				SPIRV_CROSS_THROW("textureOffset is not supported in HLSL shader model 2/3.");
2991 
2992 			if (grad_x || grad_y)
2993 				texop += "grad";
2994 			else if (lod)
2995 				texop += "lod";
2996 			else if (bias)
2997 				texop += "bias";
2998 			else if (proj || dref)
2999 				texop += "proj";
3000 		}
3001 	}
3002 
3003 	expr += texop;
3004 	expr += "(";
3005 	if (hlsl_options.shader_model < 40)
3006 	{
3007 		if (combined_image)
3008 			SPIRV_CROSS_THROW("Separate images/samplers are not supported in HLSL shader model 2/3.");
3009 		expr += to_expression(img);
3010 	}
3011 	else if (op != OpImageFetch)
3012 	{
3013 		string sampler_expr;
3014 		if (combined_image)
3015 			sampler_expr = to_expression(combined_image->sampler);
3016 		else
3017 			sampler_expr = to_sampler_expression(img);
3018 		expr += sampler_expr;
3019 	}
3020 
3021 	auto swizzle = [](uint32_t comps, uint32_t in_comps) -> const char * {
3022 		if (comps == in_comps)
3023 			return "";
3024 
3025 		switch (comps)
3026 		{
3027 		case 1:
3028 			return ".x";
3029 		case 2:
3030 			return ".xy";
3031 		case 3:
3032 			return ".xyz";
3033 		default:
3034 			return "";
3035 		}
3036 	};
3037 
3038 	bool forward = should_forward(coord);
3039 
3040 	// The IR can give us more components than we need, so chop them off as needed.
3041 	string coord_expr;
3042 	auto &coord_type = expression_type(coord);
3043 	if (coord_components != coord_type.vecsize)
3044 		coord_expr = to_enclosed_expression(coord) + swizzle(coord_components, expression_type(coord).vecsize);
3045 	else
3046 		coord_expr = to_expression(coord);
3047 
3048 	if (proj && hlsl_options.shader_model >= 40) // Legacy HLSL has "proj" operations which do this for us.
3049 		coord_expr = coord_expr + " / " + to_extract_component_expression(coord, coord_components);
3050 
3051 	if (hlsl_options.shader_model < 40)
3052 	{
3053 		if (dref)
3054 		{
3055 			if (imgtype.image.dim != spv::Dim1D && imgtype.image.dim != spv::Dim2D)
3056 			{
3057 				SPIRV_CROSS_THROW(
3058 				    "Depth comparison is only supported for 1D and 2D textures in HLSL shader model 2/3.");
3059 			}
3060 
3061 			if (grad_x || grad_y)
3062 				SPIRV_CROSS_THROW("Depth comparison is not supported for grad sampling in HLSL shader model 2/3.");
3063 
3064 			for (uint32_t size = coord_components; size < 2; ++size)
3065 				coord_expr += ", 0.0";
3066 
3067 			forward = forward && should_forward(dref);
3068 			coord_expr += ", " + to_expression(dref);
3069 		}
3070 		else if (lod || bias || proj)
3071 		{
3072 			for (uint32_t size = coord_components; size < 3; ++size)
3073 				coord_expr += ", 0.0";
3074 		}
3075 
3076 		if (lod)
3077 		{
3078 			coord_expr = "float4(" + coord_expr + ", " + to_expression(lod) + ")";
3079 		}
3080 		else if (bias)
3081 		{
3082 			coord_expr = "float4(" + coord_expr + ", " + to_expression(bias) + ")";
3083 		}
3084 		else if (proj)
3085 		{
3086 			coord_expr = "float4(" + coord_expr + ", " + to_extract_component_expression(coord, coord_components) + ")";
3087 		}
3088 		else if (dref)
3089 		{
3090 			// A "normal" sample gets fed into tex2Dproj as well, because the
3091 			// regular tex2D accepts only two coordinates.
3092 			coord_expr = "float4(" + coord_expr + ", 1.0)";
3093 		}
3094 
3095 		if (!!lod + !!bias + !!proj > 1)
3096 			SPIRV_CROSS_THROW("Legacy HLSL can only use one of lod/bias/proj modifiers.");
3097 	}
3098 
3099 	if (op == OpImageFetch)
3100 	{
3101 		if (imgtype.image.dim != DimBuffer && !imgtype.image.ms)
3102 			coord_expr =
3103 			    join("int", coord_components + 1, "(", coord_expr, ", ", lod ? to_expression(lod) : string("0"), ")");
3104 	}
3105 	else
3106 		expr += ", ";
3107 	expr += coord_expr;
3108 
3109 	if (dref && hlsl_options.shader_model >= 40)
3110 	{
3111 		forward = forward && should_forward(dref);
3112 		expr += ", ";
3113 
3114 		if (proj)
3115 			expr += to_enclosed_expression(dref) + " / " + to_extract_component_expression(coord, coord_components);
3116 		else
3117 			expr += to_expression(dref);
3118 	}
3119 
3120 	if (!dref && (grad_x || grad_y))
3121 	{
3122 		forward = forward && should_forward(grad_x);
3123 		forward = forward && should_forward(grad_y);
3124 		expr += ", ";
3125 		expr += to_expression(grad_x);
3126 		expr += ", ";
3127 		expr += to_expression(grad_y);
3128 	}
3129 
3130 	if (!dref && lod && hlsl_options.shader_model >= 40 && op != OpImageFetch)
3131 	{
3132 		forward = forward && should_forward(lod);
3133 		expr += ", ";
3134 		expr += to_expression(lod);
3135 	}
3136 
3137 	if (!dref && bias && hlsl_options.shader_model >= 40)
3138 	{
3139 		forward = forward && should_forward(bias);
3140 		expr += ", ";
3141 		expr += to_expression(bias);
3142 	}
3143 
3144 	if (coffset)
3145 	{
3146 		forward = forward && should_forward(coffset);
3147 		expr += ", ";
3148 		expr += to_expression(coffset);
3149 	}
3150 	else if (offset)
3151 	{
3152 		forward = forward && should_forward(offset);
3153 		expr += ", ";
3154 		expr += to_expression(offset);
3155 	}
3156 
3157 	if (sample)
3158 	{
3159 		expr += ", ";
3160 		expr += to_expression(sample);
3161 	}
3162 
3163 	expr += ")";
3164 
3165 	if (dref && hlsl_options.shader_model < 40)
3166 		expr += ".x";
3167 
3168 	if (op == OpImageQueryLod)
3169 	{
3170 		// This is rather awkward.
3171 		// textureQueryLod returns two values, the "accessed level",
3172 		// as well as the actual LOD lambda.
3173 		// As far as I can tell, there is no way to get the .x component
3174 		// according to GLSL spec, and it depends on the sampler itself.
3175 		// Just assume X == Y, so we will need to splat the result to a float2.
3176 		statement("float _", id, "_tmp = ", expr, ";");
3177 		statement("float2 _", id, " = _", id, "_tmp.xx;");
3178 		set<SPIRExpression>(id, join("_", id), result_type, true);
3179 	}
3180 	else
3181 	{
3182 		emit_op(result_type, id, expr, forward, false);
3183 	}
3184 
3185 	for (auto &inherit : inherited_expressions)
3186 		inherit_expression_dependencies(id, inherit);
3187 
3188 	switch (op)
3189 	{
3190 	case OpImageSampleDrefImplicitLod:
3191 	case OpImageSampleImplicitLod:
3192 	case OpImageSampleProjImplicitLod:
3193 	case OpImageSampleProjDrefImplicitLod:
3194 		register_control_dependent_expression(id);
3195 		break;
3196 
3197 	default:
3198 		break;
3199 	}
3200 }
3201 
to_resource_binding(const SPIRVariable & var)3202 string CompilerHLSL::to_resource_binding(const SPIRVariable &var)
3203 {
3204 	const auto &type = get<SPIRType>(var.basetype);
3205 
3206 	// We can remap push constant blocks, even if they don't have any binding decoration.
3207 	if (type.storage != StorageClassPushConstant && !has_decoration(var.self, DecorationBinding))
3208 		return "";
3209 
3210 	char space = '\0';
3211 
3212 	HLSLBindingFlagBits resource_flags = HLSL_BINDING_AUTO_NONE_BIT;
3213 
3214 	switch (type.basetype)
3215 	{
3216 	case SPIRType::SampledImage:
3217 		space = 't'; // SRV
3218 		resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3219 		break;
3220 
3221 	case SPIRType::Image:
3222 		if (type.image.sampled == 2 && type.image.dim != DimSubpassData)
3223 		{
3224 			if (has_decoration(var.self, DecorationNonWritable) && hlsl_options.nonwritable_uav_texture_as_srv)
3225 			{
3226 				space = 't'; // SRV
3227 				resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3228 			}
3229 			else
3230 			{
3231 				space = 'u'; // UAV
3232 				resource_flags = HLSL_BINDING_AUTO_UAV_BIT;
3233 			}
3234 		}
3235 		else
3236 		{
3237 			space = 't'; // SRV
3238 			resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3239 		}
3240 		break;
3241 
3242 	case SPIRType::Sampler:
3243 		space = 's';
3244 		resource_flags = HLSL_BINDING_AUTO_SAMPLER_BIT;
3245 		break;
3246 
3247 	case SPIRType::Struct:
3248 	{
3249 		auto storage = type.storage;
3250 		if (storage == StorageClassUniform)
3251 		{
3252 			if (has_decoration(type.self, DecorationBufferBlock))
3253 			{
3254 				Bitset flags = ir.get_buffer_block_flags(var);
3255 				bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
3256 				space = is_readonly ? 't' : 'u'; // UAV
3257 				resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3258 			}
3259 			else if (has_decoration(type.self, DecorationBlock))
3260 			{
3261 				space = 'b'; // Constant buffers
3262 				resource_flags = HLSL_BINDING_AUTO_CBV_BIT;
3263 			}
3264 		}
3265 		else if (storage == StorageClassPushConstant)
3266 		{
3267 			space = 'b'; // Constant buffers
3268 			resource_flags = HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT;
3269 		}
3270 		else if (storage == StorageClassStorageBuffer)
3271 		{
3272 			// UAV or SRV depending on readonly flag.
3273 			Bitset flags = ir.get_buffer_block_flags(var);
3274 			bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
3275 			space = is_readonly ? 't' : 'u';
3276 			resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3277 		}
3278 
3279 		break;
3280 	}
3281 	default:
3282 		break;
3283 	}
3284 
3285 	if (!space)
3286 		return "";
3287 
3288 	uint32_t desc_set =
3289 	    resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantDescriptorSet : 0u;
3290 	uint32_t binding = resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantBinding : 0u;
3291 
3292 	if (has_decoration(var.self, DecorationBinding))
3293 		binding = get_decoration(var.self, DecorationBinding);
3294 	if (has_decoration(var.self, DecorationDescriptorSet))
3295 		desc_set = get_decoration(var.self, DecorationDescriptorSet);
3296 
3297 	return to_resource_register(resource_flags, space, binding, desc_set);
3298 }
3299 
to_resource_binding_sampler(const SPIRVariable & var)3300 string CompilerHLSL::to_resource_binding_sampler(const SPIRVariable &var)
3301 {
3302 	// For combined image samplers.
3303 	if (!has_decoration(var.self, DecorationBinding))
3304 		return "";
3305 
3306 	return to_resource_register(HLSL_BINDING_AUTO_SAMPLER_BIT, 's', get_decoration(var.self, DecorationBinding),
3307 	                            get_decoration(var.self, DecorationDescriptorSet));
3308 }
3309 
remap_hlsl_resource_binding(HLSLBindingFlagBits type,uint32_t & desc_set,uint32_t & binding)3310 void CompilerHLSL::remap_hlsl_resource_binding(HLSLBindingFlagBits type, uint32_t &desc_set, uint32_t &binding)
3311 {
3312 	auto itr = resource_bindings.find({ get_execution_model(), desc_set, binding });
3313 	if (itr != end(resource_bindings))
3314 	{
3315 		auto &remap = itr->second;
3316 		remap.second = true;
3317 
3318 		switch (type)
3319 		{
3320 		case HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT:
3321 		case HLSL_BINDING_AUTO_CBV_BIT:
3322 			desc_set = remap.first.cbv.register_space;
3323 			binding = remap.first.cbv.register_binding;
3324 			break;
3325 
3326 		case HLSL_BINDING_AUTO_SRV_BIT:
3327 			desc_set = remap.first.srv.register_space;
3328 			binding = remap.first.srv.register_binding;
3329 			break;
3330 
3331 		case HLSL_BINDING_AUTO_SAMPLER_BIT:
3332 			desc_set = remap.first.sampler.register_space;
3333 			binding = remap.first.sampler.register_binding;
3334 			break;
3335 
3336 		case HLSL_BINDING_AUTO_UAV_BIT:
3337 			desc_set = remap.first.uav.register_space;
3338 			binding = remap.first.uav.register_binding;
3339 			break;
3340 
3341 		default:
3342 			break;
3343 		}
3344 	}
3345 }
3346 
to_resource_register(HLSLBindingFlagBits flag,char space,uint32_t binding,uint32_t space_set)3347 string CompilerHLSL::to_resource_register(HLSLBindingFlagBits flag, char space, uint32_t binding, uint32_t space_set)
3348 {
3349 	if ((flag & resource_binding_flags) == 0)
3350 	{
3351 		remap_hlsl_resource_binding(flag, space_set, binding);
3352 
3353 		// The push constant block did not have a binding, and there were no remap for it,
3354 		// so, declare without register binding.
3355 		if (flag == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT && space_set == ResourceBindingPushConstantDescriptorSet)
3356 			return "";
3357 
3358 		if (hlsl_options.shader_model >= 51)
3359 			return join(" : register(", space, binding, ", space", space_set, ")");
3360 		else
3361 			return join(" : register(", space, binding, ")");
3362 	}
3363 	else
3364 		return "";
3365 }
3366 
emit_modern_uniform(const SPIRVariable & var)3367 void CompilerHLSL::emit_modern_uniform(const SPIRVariable &var)
3368 {
3369 	auto &type = get<SPIRType>(var.basetype);
3370 	switch (type.basetype)
3371 	{
3372 	case SPIRType::SampledImage:
3373 	case SPIRType::Image:
3374 	{
3375 		bool is_coherent = false;
3376 		if (type.basetype == SPIRType::Image && type.image.sampled == 2)
3377 			is_coherent = has_decoration(var.self, DecorationCoherent);
3378 
3379 		statement(is_coherent ? "globallycoherent " : "", image_type_hlsl_modern(type, var.self), " ",
3380 		          to_name(var.self), type_to_array_glsl(type), to_resource_binding(var), ";");
3381 
3382 		if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
3383 		{
3384 			// For combined image samplers, also emit a combined image sampler.
3385 			if (image_is_comparison(type, var.self))
3386 				statement("SamplerComparisonState ", to_sampler_expression(var.self), type_to_array_glsl(type),
3387 				          to_resource_binding_sampler(var), ";");
3388 			else
3389 				statement("SamplerState ", to_sampler_expression(var.self), type_to_array_glsl(type),
3390 				          to_resource_binding_sampler(var), ";");
3391 		}
3392 		break;
3393 	}
3394 
3395 	case SPIRType::Sampler:
3396 		if (comparison_ids.count(var.self))
3397 			statement("SamplerComparisonState ", to_name(var.self), type_to_array_glsl(type), to_resource_binding(var),
3398 			          ";");
3399 		else
3400 			statement("SamplerState ", to_name(var.self), type_to_array_glsl(type), to_resource_binding(var), ";");
3401 		break;
3402 
3403 	default:
3404 		statement(variable_decl(var), to_resource_binding(var), ";");
3405 		break;
3406 	}
3407 }
3408 
emit_legacy_uniform(const SPIRVariable & var)3409 void CompilerHLSL::emit_legacy_uniform(const SPIRVariable &var)
3410 {
3411 	auto &type = get<SPIRType>(var.basetype);
3412 	switch (type.basetype)
3413 	{
3414 	case SPIRType::Sampler:
3415 	case SPIRType::Image:
3416 		SPIRV_CROSS_THROW("Separate image and samplers not supported in legacy HLSL.");
3417 
3418 	default:
3419 		statement(variable_decl(var), ";");
3420 		break;
3421 	}
3422 }
3423 
emit_uniform(const SPIRVariable & var)3424 void CompilerHLSL::emit_uniform(const SPIRVariable &var)
3425 {
3426 	add_resource_name(var.self);
3427 	if (hlsl_options.shader_model >= 40)
3428 		emit_modern_uniform(var);
3429 	else
3430 		emit_legacy_uniform(var);
3431 }
3432 
emit_complex_bitcast(uint32_t,uint32_t,uint32_t)3433 bool CompilerHLSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
3434 {
3435 	return false;
3436 }
3437 
bitcast_glsl_op(const SPIRType & out_type,const SPIRType & in_type)3438 string CompilerHLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
3439 {
3440 	if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int)
3441 		return type_to_glsl(out_type);
3442 	else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Int64)
3443 		return type_to_glsl(out_type);
3444 	else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Float)
3445 		return "asuint";
3446 	else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::UInt)
3447 		return type_to_glsl(out_type);
3448 	else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::UInt64)
3449 		return type_to_glsl(out_type);
3450 	else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::Float)
3451 		return "asint";
3452 	else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::UInt)
3453 		return "asfloat";
3454 	else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::Int)
3455 		return "asfloat";
3456 	else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::Double)
3457 		SPIRV_CROSS_THROW("Double to Int64 is not supported in HLSL.");
3458 	else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Double)
3459 		SPIRV_CROSS_THROW("Double to UInt64 is not supported in HLSL.");
3460 	else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::Int64)
3461 		return "asdouble";
3462 	else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::UInt64)
3463 		return "asdouble";
3464 	else if (out_type.basetype == SPIRType::Half && in_type.basetype == SPIRType::UInt && in_type.vecsize == 1)
3465 	{
3466 		if (!requires_explicit_fp16_packing)
3467 		{
3468 			requires_explicit_fp16_packing = true;
3469 			force_recompile();
3470 		}
3471 		return "spvUnpackFloat2x16";
3472 	}
3473 	else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Half && in_type.vecsize == 2)
3474 	{
3475 		if (!requires_explicit_fp16_packing)
3476 		{
3477 			requires_explicit_fp16_packing = true;
3478 			force_recompile();
3479 		}
3480 		return "spvPackFloat2x16";
3481 	}
3482 	else
3483 		return "";
3484 }
3485 
emit_glsl_op(uint32_t result_type,uint32_t id,uint32_t eop,const uint32_t * args,uint32_t count)3486 void CompilerHLSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
3487 {
3488 	auto op = static_cast<GLSLstd450>(eop);
3489 
3490 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
3491 	uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
3492 	auto int_type = to_signed_basetype(integer_width);
3493 	auto uint_type = to_unsigned_basetype(integer_width);
3494 
3495 	switch (op)
3496 	{
3497 	case GLSLstd450InverseSqrt:
3498 		emit_unary_func_op(result_type, id, args[0], "rsqrt");
3499 		break;
3500 
3501 	case GLSLstd450Fract:
3502 		emit_unary_func_op(result_type, id, args[0], "frac");
3503 		break;
3504 
3505 	case GLSLstd450RoundEven:
3506 		if (hlsl_options.shader_model < 40)
3507 			SPIRV_CROSS_THROW("roundEven is not supported in HLSL shader model 2/3.");
3508 		emit_unary_func_op(result_type, id, args[0], "round");
3509 		break;
3510 
3511 	case GLSLstd450Acosh:
3512 	case GLSLstd450Asinh:
3513 	case GLSLstd450Atanh:
3514 		SPIRV_CROSS_THROW("Inverse hyperbolics are not supported on HLSL.");
3515 
3516 	case GLSLstd450FMix:
3517 	case GLSLstd450IMix:
3518 		emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "lerp");
3519 		break;
3520 
3521 	case GLSLstd450Atan2:
3522 		emit_binary_func_op(result_type, id, args[0], args[1], "atan2");
3523 		break;
3524 
3525 	case GLSLstd450Fma:
3526 		emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "mad");
3527 		break;
3528 
3529 	case GLSLstd450InterpolateAtCentroid:
3530 		emit_unary_func_op(result_type, id, args[0], "EvaluateAttributeAtCentroid");
3531 		break;
3532 	case GLSLstd450InterpolateAtSample:
3533 		emit_binary_func_op(result_type, id, args[0], args[1], "EvaluateAttributeAtSample");
3534 		break;
3535 	case GLSLstd450InterpolateAtOffset:
3536 		emit_binary_func_op(result_type, id, args[0], args[1], "EvaluateAttributeSnapped");
3537 		break;
3538 
3539 	case GLSLstd450PackHalf2x16:
3540 		if (!requires_fp16_packing)
3541 		{
3542 			requires_fp16_packing = true;
3543 			force_recompile();
3544 		}
3545 		emit_unary_func_op(result_type, id, args[0], "spvPackHalf2x16");
3546 		break;
3547 
3548 	case GLSLstd450UnpackHalf2x16:
3549 		if (!requires_fp16_packing)
3550 		{
3551 			requires_fp16_packing = true;
3552 			force_recompile();
3553 		}
3554 		emit_unary_func_op(result_type, id, args[0], "spvUnpackHalf2x16");
3555 		break;
3556 
3557 	case GLSLstd450PackSnorm4x8:
3558 		if (!requires_snorm8_packing)
3559 		{
3560 			requires_snorm8_packing = true;
3561 			force_recompile();
3562 		}
3563 		emit_unary_func_op(result_type, id, args[0], "spvPackSnorm4x8");
3564 		break;
3565 
3566 	case GLSLstd450UnpackSnorm4x8:
3567 		if (!requires_snorm8_packing)
3568 		{
3569 			requires_snorm8_packing = true;
3570 			force_recompile();
3571 		}
3572 		emit_unary_func_op(result_type, id, args[0], "spvUnpackSnorm4x8");
3573 		break;
3574 
3575 	case GLSLstd450PackUnorm4x8:
3576 		if (!requires_unorm8_packing)
3577 		{
3578 			requires_unorm8_packing = true;
3579 			force_recompile();
3580 		}
3581 		emit_unary_func_op(result_type, id, args[0], "spvPackUnorm4x8");
3582 		break;
3583 
3584 	case GLSLstd450UnpackUnorm4x8:
3585 		if (!requires_unorm8_packing)
3586 		{
3587 			requires_unorm8_packing = true;
3588 			force_recompile();
3589 		}
3590 		emit_unary_func_op(result_type, id, args[0], "spvUnpackUnorm4x8");
3591 		break;
3592 
3593 	case GLSLstd450PackSnorm2x16:
3594 		if (!requires_snorm16_packing)
3595 		{
3596 			requires_snorm16_packing = true;
3597 			force_recompile();
3598 		}
3599 		emit_unary_func_op(result_type, id, args[0], "spvPackSnorm2x16");
3600 		break;
3601 
3602 	case GLSLstd450UnpackSnorm2x16:
3603 		if (!requires_snorm16_packing)
3604 		{
3605 			requires_snorm16_packing = true;
3606 			force_recompile();
3607 		}
3608 		emit_unary_func_op(result_type, id, args[0], "spvUnpackSnorm2x16");
3609 		break;
3610 
3611 	case GLSLstd450PackUnorm2x16:
3612 		if (!requires_unorm16_packing)
3613 		{
3614 			requires_unorm16_packing = true;
3615 			force_recompile();
3616 		}
3617 		emit_unary_func_op(result_type, id, args[0], "spvPackUnorm2x16");
3618 		break;
3619 
3620 	case GLSLstd450UnpackUnorm2x16:
3621 		if (!requires_unorm16_packing)
3622 		{
3623 			requires_unorm16_packing = true;
3624 			force_recompile();
3625 		}
3626 		emit_unary_func_op(result_type, id, args[0], "spvUnpackUnorm2x16");
3627 		break;
3628 
3629 	case GLSLstd450PackDouble2x32:
3630 	case GLSLstd450UnpackDouble2x32:
3631 		SPIRV_CROSS_THROW("packDouble2x32/unpackDouble2x32 not supported in HLSL.");
3632 
3633 	case GLSLstd450FindILsb:
3634 	{
3635 		auto basetype = expression_type(args[0]).basetype;
3636 		emit_unary_func_op_cast(result_type, id, args[0], "firstbitlow", basetype, basetype);
3637 		break;
3638 	}
3639 
3640 	case GLSLstd450FindSMsb:
3641 		emit_unary_func_op_cast(result_type, id, args[0], "firstbithigh", int_type, int_type);
3642 		break;
3643 
3644 	case GLSLstd450FindUMsb:
3645 		emit_unary_func_op_cast(result_type, id, args[0], "firstbithigh", uint_type, uint_type);
3646 		break;
3647 
3648 	case GLSLstd450MatrixInverse:
3649 	{
3650 		auto &type = get<SPIRType>(result_type);
3651 		if (type.vecsize == 2 && type.columns == 2)
3652 		{
3653 			if (!requires_inverse_2x2)
3654 			{
3655 				requires_inverse_2x2 = true;
3656 				force_recompile();
3657 			}
3658 		}
3659 		else if (type.vecsize == 3 && type.columns == 3)
3660 		{
3661 			if (!requires_inverse_3x3)
3662 			{
3663 				requires_inverse_3x3 = true;
3664 				force_recompile();
3665 			}
3666 		}
3667 		else if (type.vecsize == 4 && type.columns == 4)
3668 		{
3669 			if (!requires_inverse_4x4)
3670 			{
3671 				requires_inverse_4x4 = true;
3672 				force_recompile();
3673 			}
3674 		}
3675 		emit_unary_func_op(result_type, id, args[0], "spvInverse");
3676 		break;
3677 	}
3678 
3679 	case GLSLstd450Normalize:
3680 		// HLSL does not support scalar versions here.
3681 		if (expression_type(args[0]).vecsize == 1)
3682 		{
3683 			// Returns -1 or 1 for valid input, sign() does the job.
3684 			emit_unary_func_op(result_type, id, args[0], "sign");
3685 		}
3686 		else
3687 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3688 		break;
3689 
3690 	case GLSLstd450Reflect:
3691 		if (get<SPIRType>(result_type).vecsize == 1)
3692 		{
3693 			if (!requires_scalar_reflect)
3694 			{
3695 				requires_scalar_reflect = true;
3696 				force_recompile();
3697 			}
3698 			emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
3699 		}
3700 		else
3701 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3702 		break;
3703 
3704 	case GLSLstd450Refract:
3705 		if (get<SPIRType>(result_type).vecsize == 1)
3706 		{
3707 			if (!requires_scalar_refract)
3708 			{
3709 				requires_scalar_refract = true;
3710 				force_recompile();
3711 			}
3712 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
3713 		}
3714 		else
3715 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3716 		break;
3717 
3718 	case GLSLstd450FaceForward:
3719 		if (get<SPIRType>(result_type).vecsize == 1)
3720 		{
3721 			if (!requires_scalar_faceforward)
3722 			{
3723 				requires_scalar_faceforward = true;
3724 				force_recompile();
3725 			}
3726 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvFaceForward");
3727 		}
3728 		else
3729 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3730 		break;
3731 
3732 	default:
3733 		CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3734 		break;
3735 	}
3736 }
3737 
read_access_chain_array(const string & lhs,const SPIRAccessChain & chain)3738 void CompilerHLSL::read_access_chain_array(const string &lhs, const SPIRAccessChain &chain)
3739 {
3740 	auto &type = get<SPIRType>(chain.basetype);
3741 
3742 	// Need to use a reserved identifier here since it might shadow an identifier in the access chain input or other loops.
3743 	auto ident = get_unique_identifier();
3744 
3745 	statement("[unroll]");
3746 	statement("for (int ", ident, " = 0; ", ident, " < ", to_array_size(type, uint32_t(type.array.size() - 1)), "; ",
3747 	          ident, "++)");
3748 	begin_scope();
3749 	auto subchain = chain;
3750 	subchain.dynamic_index = join(ident, " * ", chain.array_stride, " + ", chain.dynamic_index);
3751 	subchain.basetype = type.parent_type;
3752 	if (!get<SPIRType>(subchain.basetype).array.empty())
3753 		subchain.array_stride = get_decoration(subchain.basetype, DecorationArrayStride);
3754 	read_access_chain(nullptr, join(lhs, "[", ident, "]"), subchain);
3755 	end_scope();
3756 }
3757 
read_access_chain_struct(const string & lhs,const SPIRAccessChain & chain)3758 void CompilerHLSL::read_access_chain_struct(const string &lhs, const SPIRAccessChain &chain)
3759 {
3760 	auto &type = get<SPIRType>(chain.basetype);
3761 	auto subchain = chain;
3762 	uint32_t member_count = uint32_t(type.member_types.size());
3763 
3764 	for (uint32_t i = 0; i < member_count; i++)
3765 	{
3766 		uint32_t offset = type_struct_member_offset(type, i);
3767 		subchain.static_index = chain.static_index + offset;
3768 		subchain.basetype = type.member_types[i];
3769 
3770 		subchain.matrix_stride = 0;
3771 		subchain.array_stride = 0;
3772 		subchain.row_major_matrix = false;
3773 
3774 		auto &member_type = get<SPIRType>(subchain.basetype);
3775 		if (member_type.columns > 1)
3776 		{
3777 			subchain.matrix_stride = type_struct_member_matrix_stride(type, i);
3778 			subchain.row_major_matrix = has_member_decoration(type.self, i, DecorationRowMajor);
3779 		}
3780 
3781 		if (!member_type.array.empty())
3782 			subchain.array_stride = type_struct_member_array_stride(type, i);
3783 
3784 		read_access_chain(nullptr, join(lhs, ".", to_member_name(type, i)), subchain);
3785 	}
3786 }
3787 
read_access_chain(string * expr,const string & lhs,const SPIRAccessChain & chain)3788 void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIRAccessChain &chain)
3789 {
3790 	auto &type = get<SPIRType>(chain.basetype);
3791 
3792 	SPIRType target_type;
3793 	target_type.basetype = SPIRType::UInt;
3794 	target_type.vecsize = type.vecsize;
3795 	target_type.columns = type.columns;
3796 
3797 	if (!type.array.empty())
3798 	{
3799 		read_access_chain_array(lhs, chain);
3800 		return;
3801 	}
3802 	else if (type.basetype == SPIRType::Struct)
3803 	{
3804 		read_access_chain_struct(lhs, chain);
3805 		return;
3806 	}
3807 	else if (type.width != 32 && !hlsl_options.enable_16bit_types)
3808 		SPIRV_CROSS_THROW("Reading types other than 32-bit from ByteAddressBuffer not yet supported, unless SM 6.2 and "
3809 		                  "native 16-bit types are enabled.");
3810 
3811 	bool templated_load = hlsl_options.shader_model >= 62;
3812 	string load_expr;
3813 
3814 	string template_expr;
3815 	if (templated_load)
3816 		template_expr = join("<", type_to_glsl(type), ">");
3817 
3818 	// Load a vector or scalar.
3819 	if (type.columns == 1 && !chain.row_major_matrix)
3820 	{
3821 		const char *load_op = nullptr;
3822 		switch (type.vecsize)
3823 		{
3824 		case 1:
3825 			load_op = "Load";
3826 			break;
3827 		case 2:
3828 			load_op = "Load2";
3829 			break;
3830 		case 3:
3831 			load_op = "Load3";
3832 			break;
3833 		case 4:
3834 			load_op = "Load4";
3835 			break;
3836 		default:
3837 			SPIRV_CROSS_THROW("Unknown vector size.");
3838 		}
3839 
3840 		if (templated_load)
3841 			load_op = "Load";
3842 
3843 		load_expr = join(chain.base, ".", load_op, template_expr, "(", chain.dynamic_index, chain.static_index, ")");
3844 	}
3845 	else if (type.columns == 1)
3846 	{
3847 		// Strided load since we are loading a column from a row-major matrix.
3848 		if (templated_load)
3849 		{
3850 			auto scalar_type = type;
3851 			scalar_type.vecsize = 1;
3852 			scalar_type.columns = 1;
3853 			template_expr = join("<", type_to_glsl(scalar_type), ">");
3854 			if (type.vecsize > 1)
3855 				load_expr += type_to_glsl(type) + "(";
3856 		}
3857 		else if (type.vecsize > 1)
3858 		{
3859 			load_expr = type_to_glsl(target_type);
3860 			load_expr += "(";
3861 		}
3862 
3863 		for (uint32_t r = 0; r < type.vecsize; r++)
3864 		{
3865 			load_expr += join(chain.base, ".Load", template_expr, "(", chain.dynamic_index,
3866 			                  chain.static_index + r * chain.matrix_stride, ")");
3867 			if (r + 1 < type.vecsize)
3868 				load_expr += ", ";
3869 		}
3870 
3871 		if (type.vecsize > 1)
3872 			load_expr += ")";
3873 	}
3874 	else if (!chain.row_major_matrix)
3875 	{
3876 		// Load a matrix, column-major, the easy case.
3877 		const char *load_op = nullptr;
3878 		switch (type.vecsize)
3879 		{
3880 		case 1:
3881 			load_op = "Load";
3882 			break;
3883 		case 2:
3884 			load_op = "Load2";
3885 			break;
3886 		case 3:
3887 			load_op = "Load3";
3888 			break;
3889 		case 4:
3890 			load_op = "Load4";
3891 			break;
3892 		default:
3893 			SPIRV_CROSS_THROW("Unknown vector size.");
3894 		}
3895 
3896 		if (templated_load)
3897 		{
3898 			auto vector_type = type;
3899 			vector_type.columns = 1;
3900 			template_expr = join("<", type_to_glsl(vector_type), ">");
3901 			load_expr = type_to_glsl(type);
3902 			load_op = "Load";
3903 		}
3904 		else
3905 		{
3906 			// Note, this loading style in HLSL is *actually* row-major, but we always treat matrices as transposed in this backend,
3907 			// so row-major is technically column-major ...
3908 			load_expr = type_to_glsl(target_type);
3909 		}
3910 		load_expr += "(";
3911 
3912 		for (uint32_t c = 0; c < type.columns; c++)
3913 		{
3914 			load_expr += join(chain.base, ".", load_op, template_expr, "(", chain.dynamic_index,
3915 			                  chain.static_index + c * chain.matrix_stride, ")");
3916 			if (c + 1 < type.columns)
3917 				load_expr += ", ";
3918 		}
3919 		load_expr += ")";
3920 	}
3921 	else
3922 	{
3923 		// Pick out elements one by one ... Hopefully compilers are smart enough to recognize this pattern
3924 		// considering HLSL is "row-major decl", but "column-major" memory layout (basically implicit transpose model, ugh) ...
3925 
3926 		if (templated_load)
3927 		{
3928 			load_expr = type_to_glsl(type);
3929 			auto scalar_type = type;
3930 			scalar_type.vecsize = 1;
3931 			scalar_type.columns = 1;
3932 			template_expr = join("<", type_to_glsl(scalar_type), ">");
3933 		}
3934 		else
3935 			load_expr = type_to_glsl(target_type);
3936 
3937 		load_expr += "(";
3938 
3939 		for (uint32_t c = 0; c < type.columns; c++)
3940 		{
3941 			for (uint32_t r = 0; r < type.vecsize; r++)
3942 			{
3943 				load_expr += join(chain.base, ".Load", template_expr, "(", chain.dynamic_index,
3944 				                  chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ")");
3945 
3946 				if ((r + 1 < type.vecsize) || (c + 1 < type.columns))
3947 					load_expr += ", ";
3948 			}
3949 		}
3950 		load_expr += ")";
3951 	}
3952 
3953 	if (!templated_load)
3954 	{
3955 		auto bitcast_op = bitcast_glsl_op(type, target_type);
3956 		if (!bitcast_op.empty())
3957 			load_expr = join(bitcast_op, "(", load_expr, ")");
3958 	}
3959 
3960 	if (lhs.empty())
3961 	{
3962 		assert(expr);
3963 		*expr = move(load_expr);
3964 	}
3965 	else
3966 		statement(lhs, " = ", load_expr, ";");
3967 }
3968 
emit_load(const Instruction & instruction)3969 void CompilerHLSL::emit_load(const Instruction &instruction)
3970 {
3971 	auto ops = stream(instruction);
3972 
3973 	auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
3974 	if (chain)
3975 	{
3976 		uint32_t result_type = ops[0];
3977 		uint32_t id = ops[1];
3978 		uint32_t ptr = ops[2];
3979 
3980 		if (has_decoration(ptr, DecorationNonUniformEXT))
3981 			propagate_nonuniform_qualifier(ptr);
3982 
3983 		auto &type = get<SPIRType>(result_type);
3984 		bool composite_load = !type.array.empty() || type.basetype == SPIRType::Struct;
3985 
3986 		if (composite_load)
3987 		{
3988 			// We cannot make this work in one single expression as we might have nested structures and arrays,
3989 			// so unroll the load to an uninitialized temporary.
3990 			emit_uninitialized_temporary_expression(result_type, id);
3991 			read_access_chain(nullptr, to_expression(id), *chain);
3992 			track_expression_read(chain->self);
3993 		}
3994 		else
3995 		{
3996 			string load_expr;
3997 			read_access_chain(&load_expr, "", *chain);
3998 
3999 			bool forward = should_forward(ptr) && forced_temporaries.find(id) == end(forced_temporaries);
4000 
4001 			// If we are forwarding this load,
4002 			// don't register the read to access chain here, defer that to when we actually use the expression,
4003 			// using the add_implied_read_expression mechanism.
4004 			if (!forward)
4005 				track_expression_read(chain->self);
4006 
4007 			// Do not forward complex load sequences like matrices, structs and arrays.
4008 			if (type.columns > 1)
4009 				forward = false;
4010 
4011 			auto &e = emit_op(result_type, id, load_expr, forward, true);
4012 			e.need_transpose = false;
4013 			register_read(id, ptr, forward);
4014 			inherit_expression_dependencies(id, ptr);
4015 			if (forward)
4016 				add_implied_read_expression(e, chain->self);
4017 		}
4018 	}
4019 	else
4020 		CompilerGLSL::emit_instruction(instruction);
4021 }
4022 
write_access_chain_array(const SPIRAccessChain & chain,uint32_t value,const SmallVector<uint32_t> & composite_chain)4023 void CompilerHLSL::write_access_chain_array(const SPIRAccessChain &chain, uint32_t value,
4024                                             const SmallVector<uint32_t> &composite_chain)
4025 {
4026 	auto &type = get<SPIRType>(chain.basetype);
4027 
4028 	// Need to use a reserved identifier here since it might shadow an identifier in the access chain input or other loops.
4029 	auto ident = get_unique_identifier();
4030 
4031 	uint32_t id = ir.increase_bound_by(2);
4032 	uint32_t int_type_id = id + 1;
4033 	SPIRType int_type;
4034 	int_type.basetype = SPIRType::Int;
4035 	int_type.width = 32;
4036 	set<SPIRType>(int_type_id, int_type);
4037 	set<SPIRExpression>(id, ident, int_type_id, true);
4038 	set_name(id, ident);
4039 	suppressed_usage_tracking.insert(id);
4040 
4041 	statement("[unroll]");
4042 	statement("for (int ", ident, " = 0; ", ident, " < ", to_array_size(type, uint32_t(type.array.size() - 1)), "; ",
4043 	          ident, "++)");
4044 	begin_scope();
4045 	auto subchain = chain;
4046 	subchain.dynamic_index = join(ident, " * ", chain.array_stride, " + ", chain.dynamic_index);
4047 	subchain.basetype = type.parent_type;
4048 
4049 	// Forcefully allow us to use an ID here by setting MSB.
4050 	auto subcomposite_chain = composite_chain;
4051 	subcomposite_chain.push_back(0x80000000u | id);
4052 
4053 	if (!get<SPIRType>(subchain.basetype).array.empty())
4054 		subchain.array_stride = get_decoration(subchain.basetype, DecorationArrayStride);
4055 
4056 	write_access_chain(subchain, value, subcomposite_chain);
4057 	end_scope();
4058 }
4059 
write_access_chain_struct(const SPIRAccessChain & chain,uint32_t value,const SmallVector<uint32_t> & composite_chain)4060 void CompilerHLSL::write_access_chain_struct(const SPIRAccessChain &chain, uint32_t value,
4061                                              const SmallVector<uint32_t> &composite_chain)
4062 {
4063 	auto &type = get<SPIRType>(chain.basetype);
4064 	uint32_t member_count = uint32_t(type.member_types.size());
4065 	auto subchain = chain;
4066 
4067 	auto subcomposite_chain = composite_chain;
4068 	subcomposite_chain.push_back(0);
4069 
4070 	for (uint32_t i = 0; i < member_count; i++)
4071 	{
4072 		uint32_t offset = type_struct_member_offset(type, i);
4073 		subchain.static_index = chain.static_index + offset;
4074 		subchain.basetype = type.member_types[i];
4075 
4076 		subchain.matrix_stride = 0;
4077 		subchain.array_stride = 0;
4078 		subchain.row_major_matrix = false;
4079 
4080 		auto &member_type = get<SPIRType>(subchain.basetype);
4081 		if (member_type.columns > 1)
4082 		{
4083 			subchain.matrix_stride = type_struct_member_matrix_stride(type, i);
4084 			subchain.row_major_matrix = has_member_decoration(type.self, i, DecorationRowMajor);
4085 		}
4086 
4087 		if (!member_type.array.empty())
4088 			subchain.array_stride = type_struct_member_array_stride(type, i);
4089 
4090 		subcomposite_chain.back() = i;
4091 		write_access_chain(subchain, value, subcomposite_chain);
4092 	}
4093 }
4094 
write_access_chain_value(uint32_t value,const SmallVector<uint32_t> & composite_chain,bool enclose)4095 string CompilerHLSL::write_access_chain_value(uint32_t value, const SmallVector<uint32_t> &composite_chain,
4096                                               bool enclose)
4097 {
4098 	string ret;
4099 	if (composite_chain.empty())
4100 		ret = to_expression(value);
4101 	else
4102 	{
4103 		AccessChainMeta meta;
4104 		ret = access_chain_internal(value, composite_chain.data(), uint32_t(composite_chain.size()),
4105 		                            ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_LITERAL_MSB_FORCE_ID, &meta);
4106 	}
4107 
4108 	if (enclose)
4109 		ret = enclose_expression(ret);
4110 	return ret;
4111 }
4112 
write_access_chain(const SPIRAccessChain & chain,uint32_t value,const SmallVector<uint32_t> & composite_chain)4113 void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t value,
4114                                       const SmallVector<uint32_t> &composite_chain)
4115 {
4116 	auto &type = get<SPIRType>(chain.basetype);
4117 
4118 	// Make sure we trigger a read of the constituents in the access chain.
4119 	track_expression_read(chain.self);
4120 
4121 	if (has_decoration(chain.self, DecorationNonUniformEXT))
4122 		propagate_nonuniform_qualifier(chain.self);
4123 
4124 	SPIRType target_type;
4125 	target_type.basetype = SPIRType::UInt;
4126 	target_type.vecsize = type.vecsize;
4127 	target_type.columns = type.columns;
4128 
4129 	if (!type.array.empty())
4130 	{
4131 		write_access_chain_array(chain, value, composite_chain);
4132 		register_write(chain.self);
4133 		return;
4134 	}
4135 	else if (type.basetype == SPIRType::Struct)
4136 	{
4137 		write_access_chain_struct(chain, value, composite_chain);
4138 		register_write(chain.self);
4139 		return;
4140 	}
4141 	else if (type.width != 32 && !hlsl_options.enable_16bit_types)
4142 		SPIRV_CROSS_THROW("Writing types other than 32-bit to RWByteAddressBuffer not yet supported, unless SM 6.2 and "
4143 		                  "native 16-bit types are enabled.");
4144 
4145 	bool templated_store = hlsl_options.shader_model >= 62;
4146 
4147 	string template_expr;
4148 	if (templated_store)
4149 		template_expr = join("<", type_to_glsl(type), ">");
4150 
4151 	if (type.columns == 1 && !chain.row_major_matrix)
4152 	{
4153 		const char *store_op = nullptr;
4154 		switch (type.vecsize)
4155 		{
4156 		case 1:
4157 			store_op = "Store";
4158 			break;
4159 		case 2:
4160 			store_op = "Store2";
4161 			break;
4162 		case 3:
4163 			store_op = "Store3";
4164 			break;
4165 		case 4:
4166 			store_op = "Store4";
4167 			break;
4168 		default:
4169 			SPIRV_CROSS_THROW("Unknown vector size.");
4170 		}
4171 
4172 		auto store_expr = write_access_chain_value(value, composite_chain, false);
4173 
4174 		if (!templated_store)
4175 		{
4176 			auto bitcast_op = bitcast_glsl_op(target_type, type);
4177 			if (!bitcast_op.empty())
4178 				store_expr = join(bitcast_op, "(", store_expr, ")");
4179 		}
4180 		else
4181 			store_op = "Store";
4182 		statement(chain.base, ".", store_op, template_expr, "(", chain.dynamic_index, chain.static_index, ", ",
4183 		          store_expr, ");");
4184 	}
4185 	else if (type.columns == 1)
4186 	{
4187 		if (templated_store)
4188 		{
4189 			auto scalar_type = type;
4190 			scalar_type.vecsize = 1;
4191 			scalar_type.columns = 1;
4192 			template_expr = join("<", type_to_glsl(scalar_type), ">");
4193 		}
4194 
4195 		// Strided store.
4196 		for (uint32_t r = 0; r < type.vecsize; r++)
4197 		{
4198 			auto store_expr = write_access_chain_value(value, composite_chain, true);
4199 			if (type.vecsize > 1)
4200 			{
4201 				store_expr += ".";
4202 				store_expr += index_to_swizzle(r);
4203 			}
4204 			remove_duplicate_swizzle(store_expr);
4205 
4206 			if (!templated_store)
4207 			{
4208 				auto bitcast_op = bitcast_glsl_op(target_type, type);
4209 				if (!bitcast_op.empty())
4210 					store_expr = join(bitcast_op, "(", store_expr, ")");
4211 			}
4212 
4213 			statement(chain.base, ".Store", template_expr, "(", chain.dynamic_index,
4214 			          chain.static_index + chain.matrix_stride * r, ", ", store_expr, ");");
4215 		}
4216 	}
4217 	else if (!chain.row_major_matrix)
4218 	{
4219 		const char *store_op = nullptr;
4220 		switch (type.vecsize)
4221 		{
4222 		case 1:
4223 			store_op = "Store";
4224 			break;
4225 		case 2:
4226 			store_op = "Store2";
4227 			break;
4228 		case 3:
4229 			store_op = "Store3";
4230 			break;
4231 		case 4:
4232 			store_op = "Store4";
4233 			break;
4234 		default:
4235 			SPIRV_CROSS_THROW("Unknown vector size.");
4236 		}
4237 
4238 		if (templated_store)
4239 		{
4240 			store_op = "Store";
4241 			auto vector_type = type;
4242 			vector_type.columns = 1;
4243 			template_expr = join("<", type_to_glsl(vector_type), ">");
4244 		}
4245 
4246 		for (uint32_t c = 0; c < type.columns; c++)
4247 		{
4248 			auto store_expr = join(write_access_chain_value(value, composite_chain, true), "[", c, "]");
4249 
4250 			if (!templated_store)
4251 			{
4252 				auto bitcast_op = bitcast_glsl_op(target_type, type);
4253 				if (!bitcast_op.empty())
4254 					store_expr = join(bitcast_op, "(", store_expr, ")");
4255 			}
4256 
4257 			statement(chain.base, ".", store_op, template_expr, "(", chain.dynamic_index,
4258 			          chain.static_index + c * chain.matrix_stride, ", ", store_expr, ");");
4259 		}
4260 	}
4261 	else
4262 	{
4263 		if (templated_store)
4264 		{
4265 			auto scalar_type = type;
4266 			scalar_type.vecsize = 1;
4267 			scalar_type.columns = 1;
4268 			template_expr = join("<", type_to_glsl(scalar_type), ">");
4269 		}
4270 
4271 		for (uint32_t r = 0; r < type.vecsize; r++)
4272 		{
4273 			for (uint32_t c = 0; c < type.columns; c++)
4274 			{
4275 				auto store_expr =
4276 				    join(write_access_chain_value(value, composite_chain, true), "[", c, "].", index_to_swizzle(r));
4277 				remove_duplicate_swizzle(store_expr);
4278 				auto bitcast_op = bitcast_glsl_op(target_type, type);
4279 				if (!bitcast_op.empty())
4280 					store_expr = join(bitcast_op, "(", store_expr, ")");
4281 				statement(chain.base, ".Store", template_expr, "(", chain.dynamic_index,
4282 				          chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ", ", store_expr, ");");
4283 			}
4284 		}
4285 	}
4286 
4287 	register_write(chain.self);
4288 }
4289 
emit_store(const Instruction & instruction)4290 void CompilerHLSL::emit_store(const Instruction &instruction)
4291 {
4292 	auto ops = stream(instruction);
4293 	auto *chain = maybe_get<SPIRAccessChain>(ops[0]);
4294 	if (chain)
4295 		write_access_chain(*chain, ops[1], {});
4296 	else
4297 		CompilerGLSL::emit_instruction(instruction);
4298 }
4299 
emit_access_chain(const Instruction & instruction)4300 void CompilerHLSL::emit_access_chain(const Instruction &instruction)
4301 {
4302 	auto ops = stream(instruction);
4303 	uint32_t length = instruction.length;
4304 
4305 	bool need_byte_access_chain = false;
4306 	auto &type = expression_type(ops[2]);
4307 	const auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
4308 
4309 	if (chain)
4310 	{
4311 		// Keep tacking on an existing access chain.
4312 		need_byte_access_chain = true;
4313 	}
4314 	else if (type.storage == StorageClassStorageBuffer || has_decoration(type.self, DecorationBufferBlock))
4315 	{
4316 		// If we are starting to poke into an SSBO, we are dealing with ByteAddressBuffers, and we need
4317 		// to emit SPIRAccessChain rather than a plain SPIRExpression.
4318 		uint32_t chain_arguments = length - 3;
4319 		if (chain_arguments > type.array.size())
4320 			need_byte_access_chain = true;
4321 	}
4322 
4323 	if (need_byte_access_chain)
4324 	{
4325 		// If we have a chain variable, we are already inside the SSBO, and any array type will refer to arrays within a block,
4326 		// and not array of SSBO.
4327 		uint32_t to_plain_buffer_length = chain ? 0u : static_cast<uint32_t>(type.array.size());
4328 
4329 		auto *backing_variable = maybe_get_backing_variable(ops[2]);
4330 
4331 		string base;
4332 		if (to_plain_buffer_length != 0)
4333 			base = access_chain(ops[2], &ops[3], to_plain_buffer_length, get<SPIRType>(ops[0]));
4334 		else if (chain)
4335 			base = chain->base;
4336 		else
4337 			base = to_expression(ops[2]);
4338 
4339 		// Start traversing type hierarchy at the proper non-pointer types.
4340 		auto *basetype = &get_pointee_type(type);
4341 
4342 		// Traverse the type hierarchy down to the actual buffer types.
4343 		for (uint32_t i = 0; i < to_plain_buffer_length; i++)
4344 		{
4345 			assert(basetype->parent_type);
4346 			basetype = &get<SPIRType>(basetype->parent_type);
4347 		}
4348 
4349 		uint32_t matrix_stride = 0;
4350 		uint32_t array_stride = 0;
4351 		bool row_major_matrix = false;
4352 
4353 		// Inherit matrix information.
4354 		if (chain)
4355 		{
4356 			matrix_stride = chain->matrix_stride;
4357 			row_major_matrix = chain->row_major_matrix;
4358 			array_stride = chain->array_stride;
4359 		}
4360 
4361 		auto offsets = flattened_access_chain_offset(*basetype, &ops[3 + to_plain_buffer_length],
4362 		                                             length - 3 - to_plain_buffer_length, 0, 1, &row_major_matrix,
4363 		                                             &matrix_stride, &array_stride);
4364 
4365 		auto &e = set<SPIRAccessChain>(ops[1], ops[0], type.storage, base, offsets.first, offsets.second);
4366 		e.row_major_matrix = row_major_matrix;
4367 		e.matrix_stride = matrix_stride;
4368 		e.array_stride = array_stride;
4369 		e.immutable = should_forward(ops[2]);
4370 		e.loaded_from = backing_variable ? backing_variable->self : ID(0);
4371 
4372 		if (chain)
4373 		{
4374 			e.dynamic_index += chain->dynamic_index;
4375 			e.static_index += chain->static_index;
4376 		}
4377 
4378 		for (uint32_t i = 2; i < length; i++)
4379 		{
4380 			inherit_expression_dependencies(ops[1], ops[i]);
4381 			add_implied_read_expression(e, ops[i]);
4382 		}
4383 
4384 		if (has_decoration(ops[1], DecorationNonUniformEXT))
4385 			propagate_nonuniform_qualifier(ops[1]);
4386 	}
4387 	else
4388 	{
4389 		CompilerGLSL::emit_instruction(instruction);
4390 	}
4391 }
4392 
emit_atomic(const uint32_t * ops,uint32_t length,spv::Op op)4393 void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
4394 {
4395 	const char *atomic_op = nullptr;
4396 
4397 	string value_expr;
4398 	if (op != OpAtomicIDecrement && op != OpAtomicIIncrement && op != OpAtomicLoad && op != OpAtomicStore)
4399 		value_expr = to_expression(ops[op == OpAtomicCompareExchange ? 6 : 5]);
4400 
4401 	bool is_atomic_store = false;
4402 
4403 	switch (op)
4404 	{
4405 	case OpAtomicIIncrement:
4406 		atomic_op = "InterlockedAdd";
4407 		value_expr = "1";
4408 		break;
4409 
4410 	case OpAtomicIDecrement:
4411 		atomic_op = "InterlockedAdd";
4412 		value_expr = "-1";
4413 		break;
4414 
4415 	case OpAtomicLoad:
4416 		atomic_op = "InterlockedAdd";
4417 		value_expr = "0";
4418 		break;
4419 
4420 	case OpAtomicISub:
4421 		atomic_op = "InterlockedAdd";
4422 		value_expr = join("-", enclose_expression(value_expr));
4423 		break;
4424 
4425 	case OpAtomicSMin:
4426 	case OpAtomicUMin:
4427 		atomic_op = "InterlockedMin";
4428 		break;
4429 
4430 	case OpAtomicSMax:
4431 	case OpAtomicUMax:
4432 		atomic_op = "InterlockedMax";
4433 		break;
4434 
4435 	case OpAtomicAnd:
4436 		atomic_op = "InterlockedAnd";
4437 		break;
4438 
4439 	case OpAtomicOr:
4440 		atomic_op = "InterlockedOr";
4441 		break;
4442 
4443 	case OpAtomicXor:
4444 		atomic_op = "InterlockedXor";
4445 		break;
4446 
4447 	case OpAtomicIAdd:
4448 		atomic_op = "InterlockedAdd";
4449 		break;
4450 
4451 	case OpAtomicExchange:
4452 		atomic_op = "InterlockedExchange";
4453 		break;
4454 
4455 	case OpAtomicStore:
4456 		atomic_op = "InterlockedExchange";
4457 		is_atomic_store = true;
4458 		break;
4459 
4460 	case OpAtomicCompareExchange:
4461 		if (length < 8)
4462 			SPIRV_CROSS_THROW("Not enough data for opcode.");
4463 		atomic_op = "InterlockedCompareExchange";
4464 		value_expr = join(to_expression(ops[7]), ", ", value_expr);
4465 		break;
4466 
4467 	default:
4468 		SPIRV_CROSS_THROW("Unknown atomic opcode.");
4469 	}
4470 
4471 	if (is_atomic_store)
4472 	{
4473 		auto &data_type = expression_type(ops[0]);
4474 		auto *chain = maybe_get<SPIRAccessChain>(ops[0]);
4475 
4476 		auto &tmp_id = extra_sub_expressions[ops[0]];
4477 		if (!tmp_id)
4478 		{
4479 			tmp_id = ir.increase_bound_by(1);
4480 			emit_uninitialized_temporary_expression(get_pointee_type(data_type).self, tmp_id);
4481 		}
4482 
4483 		if (data_type.storage == StorageClassImage || !chain)
4484 		{
4485 			statement(atomic_op, "(", to_expression(ops[0]), ", ", to_expression(ops[3]), ", ", to_expression(tmp_id),
4486 			          ");");
4487 		}
4488 		else
4489 		{
4490 			// RWByteAddress buffer is always uint in its underlying type.
4491 			statement(chain->base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ",
4492 			          to_expression(ops[3]), ", ", to_expression(tmp_id), ");");
4493 		}
4494 	}
4495 	else
4496 	{
4497 		uint32_t result_type = ops[0];
4498 		uint32_t id = ops[1];
4499 		forced_temporaries.insert(ops[1]);
4500 
4501 		auto &type = get<SPIRType>(result_type);
4502 		statement(variable_decl(type, to_name(id)), ";");
4503 
4504 		auto &data_type = expression_type(ops[2]);
4505 		auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
4506 		SPIRType::BaseType expr_type;
4507 		if (data_type.storage == StorageClassImage || !chain)
4508 		{
4509 			statement(atomic_op, "(", to_expression(ops[2]), ", ", value_expr, ", ", to_name(id), ");");
4510 			expr_type = data_type.basetype;
4511 		}
4512 		else
4513 		{
4514 			// RWByteAddress buffer is always uint in its underlying type.
4515 			expr_type = SPIRType::UInt;
4516 			statement(chain->base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ", value_expr,
4517 			          ", ", to_name(id), ");");
4518 		}
4519 
4520 		auto expr = bitcast_expression(type, expr_type, to_name(id));
4521 		set<SPIRExpression>(id, expr, result_type, true);
4522 	}
4523 	flush_all_atomic_capable_variables();
4524 }
4525 
emit_subgroup_op(const Instruction & i)4526 void CompilerHLSL::emit_subgroup_op(const Instruction &i)
4527 {
4528 	if (hlsl_options.shader_model < 60)
4529 		SPIRV_CROSS_THROW("Wave ops requires SM 6.0 or higher.");
4530 
4531 	const uint32_t *ops = stream(i);
4532 	auto op = static_cast<Op>(i.op);
4533 
4534 	uint32_t result_type = ops[0];
4535 	uint32_t id = ops[1];
4536 
4537 	auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
4538 	if (scope != ScopeSubgroup)
4539 		SPIRV_CROSS_THROW("Only subgroup scope is supported.");
4540 
4541 	const auto make_inclusive_Sum = [&](const string &expr) -> string {
4542 		return join(expr, " + ", to_expression(ops[4]));
4543 	};
4544 
4545 	const auto make_inclusive_Product = [&](const string &expr) -> string {
4546 		return join(expr, " * ", to_expression(ops[4]));
4547 	};
4548 
4549 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
4550 	uint32_t integer_width = get_integer_width_for_instruction(i);
4551 	auto int_type = to_signed_basetype(integer_width);
4552 	auto uint_type = to_unsigned_basetype(integer_width);
4553 
4554 #define make_inclusive_BitAnd(expr) ""
4555 #define make_inclusive_BitOr(expr) ""
4556 #define make_inclusive_BitXor(expr) ""
4557 #define make_inclusive_Min(expr) ""
4558 #define make_inclusive_Max(expr) ""
4559 
4560 	switch (op)
4561 	{
4562 	case OpGroupNonUniformElect:
4563 		emit_op(result_type, id, "WaveIsFirstLane()", true);
4564 		break;
4565 
4566 	case OpGroupNonUniformBroadcast:
4567 		emit_binary_func_op(result_type, id, ops[3], ops[4], "WaveReadLaneAt");
4568 		break;
4569 
4570 	case OpGroupNonUniformBroadcastFirst:
4571 		emit_unary_func_op(result_type, id, ops[3], "WaveReadLaneFirst");
4572 		break;
4573 
4574 	case OpGroupNonUniformBallot:
4575 		emit_unary_func_op(result_type, id, ops[3], "WaveActiveBallot");
4576 		break;
4577 
4578 	case OpGroupNonUniformInverseBallot:
4579 		SPIRV_CROSS_THROW("Cannot trivially implement InverseBallot in HLSL.");
4580 		break;
4581 
4582 	case OpGroupNonUniformBallotBitExtract:
4583 		SPIRV_CROSS_THROW("Cannot trivially implement BallotBitExtract in HLSL.");
4584 		break;
4585 
4586 	case OpGroupNonUniformBallotFindLSB:
4587 		SPIRV_CROSS_THROW("Cannot trivially implement BallotFindLSB in HLSL.");
4588 		break;
4589 
4590 	case OpGroupNonUniformBallotFindMSB:
4591 		SPIRV_CROSS_THROW("Cannot trivially implement BallotFindMSB in HLSL.");
4592 		break;
4593 
4594 	case OpGroupNonUniformBallotBitCount:
4595 	{
4596 		auto operation = static_cast<GroupOperation>(ops[3]);
4597 		if (operation == GroupOperationReduce)
4598 		{
4599 			bool forward = should_forward(ops[4]);
4600 			auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x) + countbits(",
4601 			                 to_enclosed_expression(ops[4]), ".y)");
4602 			auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z) + countbits(",
4603 			                  to_enclosed_expression(ops[4]), ".w)");
4604 			emit_op(result_type, id, join(left, " + ", right), forward);
4605 			inherit_expression_dependencies(id, ops[4]);
4606 		}
4607 		else if (operation == GroupOperationInclusiveScan)
4608 			SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Inclusive Scan in HLSL.");
4609 		else if (operation == GroupOperationExclusiveScan)
4610 			SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Exclusive Scan in HLSL.");
4611 		else
4612 			SPIRV_CROSS_THROW("Invalid BitCount operation.");
4613 		break;
4614 	}
4615 
4616 	case OpGroupNonUniformShuffle:
4617 		SPIRV_CROSS_THROW("Cannot trivially implement Shuffle in HLSL.");
4618 	case OpGroupNonUniformShuffleXor:
4619 		SPIRV_CROSS_THROW("Cannot trivially implement ShuffleXor in HLSL.");
4620 	case OpGroupNonUniformShuffleUp:
4621 		SPIRV_CROSS_THROW("Cannot trivially implement ShuffleUp in HLSL.");
4622 	case OpGroupNonUniformShuffleDown:
4623 		SPIRV_CROSS_THROW("Cannot trivially implement ShuffleDown in HLSL.");
4624 
4625 	case OpGroupNonUniformAll:
4626 		emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllTrue");
4627 		break;
4628 
4629 	case OpGroupNonUniformAny:
4630 		emit_unary_func_op(result_type, id, ops[3], "WaveActiveAnyTrue");
4631 		break;
4632 
4633 	case OpGroupNonUniformAllEqual:
4634 	{
4635 		auto &type = get<SPIRType>(result_type);
4636 		emit_unary_func_op(result_type, id, ops[3],
4637 		                   type.basetype == SPIRType::Boolean ? "WaveActiveAllEqualBool" : "WaveActiveAllEqual");
4638 		break;
4639 	}
4640 
4641 	// clang-format off
4642 #define HLSL_GROUP_OP(op, hlsl_op, supports_scan) \
4643 case OpGroupNonUniform##op: \
4644 	{ \
4645 		auto operation = static_cast<GroupOperation>(ops[3]); \
4646 		if (operation == GroupOperationReduce) \
4647 			emit_unary_func_op(result_type, id, ops[4], "WaveActive" #hlsl_op); \
4648 		else if (operation == GroupOperationInclusiveScan && supports_scan) \
4649         { \
4650 			bool forward = should_forward(ops[4]); \
4651 			emit_op(result_type, id, make_inclusive_##hlsl_op (join("WavePrefix" #hlsl_op, "(", to_expression(ops[4]), ")")), forward); \
4652 			inherit_expression_dependencies(id, ops[4]); \
4653         } \
4654 		else if (operation == GroupOperationExclusiveScan && supports_scan) \
4655 			emit_unary_func_op(result_type, id, ops[4], "WavePrefix" #hlsl_op); \
4656 		else if (operation == GroupOperationClusteredReduce) \
4657 			SPIRV_CROSS_THROW("Cannot trivially implement ClusteredReduce in HLSL."); \
4658 		else \
4659 			SPIRV_CROSS_THROW("Invalid group operation."); \
4660 		break; \
4661 	}
4662 
4663 #define HLSL_GROUP_OP_CAST(op, hlsl_op, type) \
4664 case OpGroupNonUniform##op: \
4665 	{ \
4666 		auto operation = static_cast<GroupOperation>(ops[3]); \
4667 		if (operation == GroupOperationReduce) \
4668 			emit_unary_func_op_cast(result_type, id, ops[4], "WaveActive" #hlsl_op, type, type); \
4669 		else \
4670 			SPIRV_CROSS_THROW("Invalid group operation."); \
4671 		break; \
4672 	}
4673 
4674 	HLSL_GROUP_OP(FAdd, Sum, true)
4675 	HLSL_GROUP_OP(FMul, Product, true)
4676 	HLSL_GROUP_OP(FMin, Min, false)
4677 	HLSL_GROUP_OP(FMax, Max, false)
4678 	HLSL_GROUP_OP(IAdd, Sum, true)
4679 	HLSL_GROUP_OP(IMul, Product, true)
4680 	HLSL_GROUP_OP_CAST(SMin, Min, int_type)
4681 	HLSL_GROUP_OP_CAST(SMax, Max, int_type)
4682 	HLSL_GROUP_OP_CAST(UMin, Min, uint_type)
4683 	HLSL_GROUP_OP_CAST(UMax, Max, uint_type)
4684 	HLSL_GROUP_OP(BitwiseAnd, BitAnd, false)
4685 	HLSL_GROUP_OP(BitwiseOr, BitOr, false)
4686 	HLSL_GROUP_OP(BitwiseXor, BitXor, false)
4687 
4688 #undef HLSL_GROUP_OP
4689 #undef HLSL_GROUP_OP_CAST
4690 		// clang-format on
4691 
4692 	case OpGroupNonUniformQuadSwap:
4693 	{
4694 		uint32_t direction = evaluate_constant_u32(ops[4]);
4695 		if (direction == 0)
4696 			emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossX");
4697 		else if (direction == 1)
4698 			emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossY");
4699 		else if (direction == 2)
4700 			emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossDiagonal");
4701 		else
4702 			SPIRV_CROSS_THROW("Invalid quad swap direction.");
4703 		break;
4704 	}
4705 
4706 	case OpGroupNonUniformQuadBroadcast:
4707 	{
4708 		emit_binary_func_op(result_type, id, ops[3], ops[4], "QuadReadLaneAt");
4709 		break;
4710 	}
4711 
4712 	default:
4713 		SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
4714 	}
4715 
4716 	register_control_dependent_expression(id);
4717 }
4718 
emit_instruction(const Instruction & instruction)4719 void CompilerHLSL::emit_instruction(const Instruction &instruction)
4720 {
4721 	auto ops = stream(instruction);
4722 	auto opcode = static_cast<Op>(instruction.op);
4723 
4724 #define HLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
4725 #define HLSL_BOP_CAST(op, type) \
4726 	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
4727 #define HLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
4728 #define HLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
4729 #define HLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
4730 #define HLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
4731 #define HLSL_BFOP_CAST(op, type) \
4732 	emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
4733 #define HLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
4734 #define HLSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
4735 
4736 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
4737 	uint32_t integer_width = get_integer_width_for_instruction(instruction);
4738 	auto int_type = to_signed_basetype(integer_width);
4739 	auto uint_type = to_unsigned_basetype(integer_width);
4740 
4741 	switch (opcode)
4742 	{
4743 	case OpAccessChain:
4744 	case OpInBoundsAccessChain:
4745 	{
4746 		emit_access_chain(instruction);
4747 		break;
4748 	}
4749 	case OpBitcast:
4750 	{
4751 		auto bitcast_type = get_bitcast_type(ops[0], ops[2]);
4752 		if (bitcast_type == CompilerHLSL::TypeNormal)
4753 			CompilerGLSL::emit_instruction(instruction);
4754 		else
4755 		{
4756 			if (!requires_uint2_packing)
4757 			{
4758 				requires_uint2_packing = true;
4759 				force_recompile();
4760 			}
4761 
4762 			if (bitcast_type == CompilerHLSL::TypePackUint2x32)
4763 				emit_unary_func_op(ops[0], ops[1], ops[2], "spvPackUint2x32");
4764 			else
4765 				emit_unary_func_op(ops[0], ops[1], ops[2], "spvUnpackUint2x32");
4766 		}
4767 
4768 		break;
4769 	}
4770 
4771 	case OpStore:
4772 	{
4773 		emit_store(instruction);
4774 		break;
4775 	}
4776 
4777 	case OpLoad:
4778 	{
4779 		emit_load(instruction);
4780 		break;
4781 	}
4782 
4783 	case OpMatrixTimesVector:
4784 	{
4785 		// Matrices are kept in a transposed state all the time, flip multiplication order always.
4786 		emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
4787 		break;
4788 	}
4789 
4790 	case OpVectorTimesMatrix:
4791 	{
4792 		// Matrices are kept in a transposed state all the time, flip multiplication order always.
4793 		emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
4794 		break;
4795 	}
4796 
4797 	case OpMatrixTimesMatrix:
4798 	{
4799 		// Matrices are kept in a transposed state all the time, flip multiplication order always.
4800 		emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
4801 		break;
4802 	}
4803 
4804 	case OpOuterProduct:
4805 	{
4806 		uint32_t result_type = ops[0];
4807 		uint32_t id = ops[1];
4808 		uint32_t a = ops[2];
4809 		uint32_t b = ops[3];
4810 
4811 		auto &type = get<SPIRType>(result_type);
4812 		string expr = type_to_glsl_constructor(type);
4813 		expr += "(";
4814 		for (uint32_t col = 0; col < type.columns; col++)
4815 		{
4816 			expr += to_enclosed_expression(a);
4817 			expr += " * ";
4818 			expr += to_extract_component_expression(b, col);
4819 			if (col + 1 < type.columns)
4820 				expr += ", ";
4821 		}
4822 		expr += ")";
4823 		emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
4824 		inherit_expression_dependencies(id, a);
4825 		inherit_expression_dependencies(id, b);
4826 		break;
4827 	}
4828 
4829 	case OpFMod:
4830 	{
4831 		if (!requires_op_fmod)
4832 		{
4833 			requires_op_fmod = true;
4834 			force_recompile();
4835 		}
4836 		CompilerGLSL::emit_instruction(instruction);
4837 		break;
4838 	}
4839 
4840 	case OpFRem:
4841 		emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], "fmod");
4842 		break;
4843 
4844 	case OpImage:
4845 	{
4846 		uint32_t result_type = ops[0];
4847 		uint32_t id = ops[1];
4848 		auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);
4849 
4850 		if (combined)
4851 		{
4852 			auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
4853 			auto *var = maybe_get_backing_variable(combined->image);
4854 			if (var)
4855 				e.loaded_from = var->self;
4856 		}
4857 		else
4858 		{
4859 			auto &e = emit_op(result_type, id, to_expression(ops[2]), true, true);
4860 			auto *var = maybe_get_backing_variable(ops[2]);
4861 			if (var)
4862 				e.loaded_from = var->self;
4863 		}
4864 		break;
4865 	}
4866 
4867 	case OpDPdx:
4868 		HLSL_UFOP(ddx);
4869 		register_control_dependent_expression(ops[1]);
4870 		break;
4871 
4872 	case OpDPdy:
4873 		HLSL_UFOP(ddy);
4874 		register_control_dependent_expression(ops[1]);
4875 		break;
4876 
4877 	case OpDPdxFine:
4878 		HLSL_UFOP(ddx_fine);
4879 		register_control_dependent_expression(ops[1]);
4880 		break;
4881 
4882 	case OpDPdyFine:
4883 		HLSL_UFOP(ddy_fine);
4884 		register_control_dependent_expression(ops[1]);
4885 		break;
4886 
4887 	case OpDPdxCoarse:
4888 		HLSL_UFOP(ddx_coarse);
4889 		register_control_dependent_expression(ops[1]);
4890 		break;
4891 
4892 	case OpDPdyCoarse:
4893 		HLSL_UFOP(ddy_coarse);
4894 		register_control_dependent_expression(ops[1]);
4895 		break;
4896 
4897 	case OpFwidth:
4898 	case OpFwidthCoarse:
4899 	case OpFwidthFine:
4900 		HLSL_UFOP(fwidth);
4901 		register_control_dependent_expression(ops[1]);
4902 		break;
4903 
4904 	case OpLogicalNot:
4905 	{
4906 		auto result_type = ops[0];
4907 		auto id = ops[1];
4908 		auto &type = get<SPIRType>(result_type);
4909 
4910 		if (type.vecsize > 1)
4911 			emit_unrolled_unary_op(result_type, id, ops[2], "!");
4912 		else
4913 			HLSL_UOP(!);
4914 		break;
4915 	}
4916 
4917 	case OpIEqual:
4918 	{
4919 		auto result_type = ops[0];
4920 		auto id = ops[1];
4921 
4922 		if (expression_type(ops[2]).vecsize > 1)
4923 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "==", false, SPIRType::Unknown);
4924 		else
4925 			HLSL_BOP_CAST(==, int_type);
4926 		break;
4927 	}
4928 
4929 	case OpLogicalEqual:
4930 	case OpFOrdEqual:
4931 	case OpFUnordEqual:
4932 	{
4933 		// HLSL != operator is unordered.
4934 		// https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
4935 		// isnan() is apparently implemented as x != x as well.
4936 		// We cannot implement UnordEqual as !(OrdNotEqual), as HLSL cannot express OrdNotEqual.
4937 		// HACK: FUnordEqual will be implemented as FOrdEqual.
4938 
4939 		auto result_type = ops[0];
4940 		auto id = ops[1];
4941 
4942 		if (expression_type(ops[2]).vecsize > 1)
4943 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "==", false, SPIRType::Unknown);
4944 		else
4945 			HLSL_BOP(==);
4946 		break;
4947 	}
4948 
4949 	case OpINotEqual:
4950 	{
4951 		auto result_type = ops[0];
4952 		auto id = ops[1];
4953 
4954 		if (expression_type(ops[2]).vecsize > 1)
4955 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "!=", false, SPIRType::Unknown);
4956 		else
4957 			HLSL_BOP_CAST(!=, int_type);
4958 		break;
4959 	}
4960 
4961 	case OpLogicalNotEqual:
4962 	case OpFOrdNotEqual:
4963 	case OpFUnordNotEqual:
4964 	{
4965 		// HLSL != operator is unordered.
4966 		// https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
4967 		// isnan() is apparently implemented as x != x as well.
4968 
4969 		// FIXME: FOrdNotEqual cannot be implemented in a crisp and simple way here.
4970 		// We would need to do something like not(UnordEqual), but that cannot be expressed either.
4971 		// Adding a lot of NaN checks would be a breaking change from perspective of performance.
4972 		// SPIR-V will generally use isnan() checks when this even matters.
4973 		// HACK: FOrdNotEqual will be implemented as FUnordEqual.
4974 
4975 		auto result_type = ops[0];
4976 		auto id = ops[1];
4977 
4978 		if (expression_type(ops[2]).vecsize > 1)
4979 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "!=", false, SPIRType::Unknown);
4980 		else
4981 			HLSL_BOP(!=);
4982 		break;
4983 	}
4984 
4985 	case OpUGreaterThan:
4986 	case OpSGreaterThan:
4987 	{
4988 		auto result_type = ops[0];
4989 		auto id = ops[1];
4990 		auto type = opcode == OpUGreaterThan ? uint_type : int_type;
4991 
4992 		if (expression_type(ops[2]).vecsize > 1)
4993 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", false, type);
4994 		else
4995 			HLSL_BOP_CAST(>, type);
4996 		break;
4997 	}
4998 
4999 	case OpFOrdGreaterThan:
5000 	{
5001 		auto result_type = ops[0];
5002 		auto id = ops[1];
5003 
5004 		if (expression_type(ops[2]).vecsize > 1)
5005 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", false, SPIRType::Unknown);
5006 		else
5007 			HLSL_BOP(>);
5008 		break;
5009 	}
5010 
5011 	case OpFUnordGreaterThan:
5012 	{
5013 		auto result_type = ops[0];
5014 		auto id = ops[1];
5015 
5016 		if (expression_type(ops[2]).vecsize > 1)
5017 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", true, SPIRType::Unknown);
5018 		else
5019 			CompilerGLSL::emit_instruction(instruction);
5020 		break;
5021 	}
5022 
5023 	case OpUGreaterThanEqual:
5024 	case OpSGreaterThanEqual:
5025 	{
5026 		auto result_type = ops[0];
5027 		auto id = ops[1];
5028 
5029 		auto type = opcode == OpUGreaterThanEqual ? uint_type : int_type;
5030 		if (expression_type(ops[2]).vecsize > 1)
5031 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", false, type);
5032 		else
5033 			HLSL_BOP_CAST(>=, type);
5034 		break;
5035 	}
5036 
5037 	case OpFOrdGreaterThanEqual:
5038 	{
5039 		auto result_type = ops[0];
5040 		auto id = ops[1];
5041 
5042 		if (expression_type(ops[2]).vecsize > 1)
5043 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", false, SPIRType::Unknown);
5044 		else
5045 			HLSL_BOP(>=);
5046 		break;
5047 	}
5048 
5049 	case OpFUnordGreaterThanEqual:
5050 	{
5051 		auto result_type = ops[0];
5052 		auto id = ops[1];
5053 
5054 		if (expression_type(ops[2]).vecsize > 1)
5055 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", true, SPIRType::Unknown);
5056 		else
5057 			CompilerGLSL::emit_instruction(instruction);
5058 		break;
5059 	}
5060 
5061 	case OpULessThan:
5062 	case OpSLessThan:
5063 	{
5064 		auto result_type = ops[0];
5065 		auto id = ops[1];
5066 
5067 		auto type = opcode == OpULessThan ? uint_type : int_type;
5068 		if (expression_type(ops[2]).vecsize > 1)
5069 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", false, type);
5070 		else
5071 			HLSL_BOP_CAST(<, type);
5072 		break;
5073 	}
5074 
5075 	case OpFOrdLessThan:
5076 	{
5077 		auto result_type = ops[0];
5078 		auto id = ops[1];
5079 
5080 		if (expression_type(ops[2]).vecsize > 1)
5081 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", false, SPIRType::Unknown);
5082 		else
5083 			HLSL_BOP(<);
5084 		break;
5085 	}
5086 
5087 	case OpFUnordLessThan:
5088 	{
5089 		auto result_type = ops[0];
5090 		auto id = ops[1];
5091 
5092 		if (expression_type(ops[2]).vecsize > 1)
5093 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", true, SPIRType::Unknown);
5094 		else
5095 			CompilerGLSL::emit_instruction(instruction);
5096 		break;
5097 	}
5098 
5099 	case OpULessThanEqual:
5100 	case OpSLessThanEqual:
5101 	{
5102 		auto result_type = ops[0];
5103 		auto id = ops[1];
5104 
5105 		auto type = opcode == OpULessThanEqual ? uint_type : int_type;
5106 		if (expression_type(ops[2]).vecsize > 1)
5107 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", false, type);
5108 		else
5109 			HLSL_BOP_CAST(<=, type);
5110 		break;
5111 	}
5112 
5113 	case OpFOrdLessThanEqual:
5114 	{
5115 		auto result_type = ops[0];
5116 		auto id = ops[1];
5117 
5118 		if (expression_type(ops[2]).vecsize > 1)
5119 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", false, SPIRType::Unknown);
5120 		else
5121 			HLSL_BOP(<=);
5122 		break;
5123 	}
5124 
5125 	case OpFUnordLessThanEqual:
5126 	{
5127 		auto result_type = ops[0];
5128 		auto id = ops[1];
5129 
5130 		if (expression_type(ops[2]).vecsize > 1)
5131 			emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", true, SPIRType::Unknown);
5132 		else
5133 			CompilerGLSL::emit_instruction(instruction);
5134 		break;
5135 	}
5136 
5137 	case OpImageQueryLod:
5138 		emit_texture_op(instruction, false);
5139 		break;
5140 
5141 	case OpImageQuerySizeLod:
5142 	{
5143 		auto result_type = ops[0];
5144 		auto id = ops[1];
5145 
5146 		require_texture_query_variant(ops[2]);
5147 		auto dummy_samples_levels = join(get_fallback_name(id), "_dummy_parameter");
5148 		statement("uint ", dummy_samples_levels, ";");
5149 
5150 		auto expr = join("spvTextureSize(", to_expression(ops[2]), ", ",
5151 		                 bitcast_expression(SPIRType::UInt, ops[3]), ", ", dummy_samples_levels, ")");
5152 
5153 		auto &restype = get<SPIRType>(ops[0]);
5154 		expr = bitcast_expression(restype, SPIRType::UInt, expr);
5155 		emit_op(result_type, id, expr, true);
5156 		break;
5157 	}
5158 
5159 	case OpImageQuerySize:
5160 	{
5161 		auto result_type = ops[0];
5162 		auto id = ops[1];
5163 
5164 		require_texture_query_variant(ops[2]);
5165 		bool uav = expression_type(ops[2]).image.sampled == 2;
5166 
5167 		if (const auto *var = maybe_get_backing_variable(ops[2]))
5168 			if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var->self, DecorationNonWritable))
5169 				uav = false;
5170 
5171 		auto dummy_samples_levels = join(get_fallback_name(id), "_dummy_parameter");
5172 		statement("uint ", dummy_samples_levels, ";");
5173 
5174 		string expr;
5175 		if (uav)
5176 			expr = join("spvImageSize(", to_expression(ops[2]), ", ", dummy_samples_levels, ")");
5177 		else
5178 			expr = join("spvTextureSize(", to_expression(ops[2]), ", 0u, ", dummy_samples_levels, ")");
5179 
5180 		auto &restype = get<SPIRType>(ops[0]);
5181 		expr = bitcast_expression(restype, SPIRType::UInt, expr);
5182 		emit_op(result_type, id, expr, true);
5183 		break;
5184 	}
5185 
5186 	case OpImageQuerySamples:
5187 	case OpImageQueryLevels:
5188 	{
5189 		auto result_type = ops[0];
5190 		auto id = ops[1];
5191 
5192 		require_texture_query_variant(ops[2]);
5193 		bool uav = expression_type(ops[2]).image.sampled == 2;
5194 		if (opcode == OpImageQueryLevels && uav)
5195 			SPIRV_CROSS_THROW("Cannot query levels for UAV images.");
5196 
5197 		if (const auto *var = maybe_get_backing_variable(ops[2]))
5198 			if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var->self, DecorationNonWritable))
5199 				uav = false;
5200 
5201 		// Keep it simple and do not emit special variants to make this look nicer ...
5202 		// This stuff is barely, if ever, used.
5203 		forced_temporaries.insert(id);
5204 		auto &type = get<SPIRType>(result_type);
5205 		statement(variable_decl(type, to_name(id)), ";");
5206 
5207 		if (uav)
5208 			statement("spvImageSize(", to_expression(ops[2]), ", ", to_name(id), ");");
5209 		else
5210 			statement("spvTextureSize(", to_expression(ops[2]), ", 0u, ", to_name(id), ");");
5211 
5212 		auto &restype = get<SPIRType>(ops[0]);
5213 		auto expr = bitcast_expression(restype, SPIRType::UInt, to_name(id));
5214 		set<SPIRExpression>(id, expr, result_type, true);
5215 		break;
5216 	}
5217 
5218 	case OpImageRead:
5219 	{
5220 		uint32_t result_type = ops[0];
5221 		uint32_t id = ops[1];
5222 		auto *var = maybe_get_backing_variable(ops[2]);
5223 		auto &type = expression_type(ops[2]);
5224 		bool subpass_data = type.image.dim == DimSubpassData;
5225 		bool pure = false;
5226 
5227 		string imgexpr;
5228 
5229 		if (subpass_data)
5230 		{
5231 			if (hlsl_options.shader_model < 40)
5232 				SPIRV_CROSS_THROW("Subpass loads are not supported in HLSL shader model 2/3.");
5233 
5234 			// Similar to GLSL, implement subpass loads using texelFetch.
5235 			if (type.image.ms)
5236 			{
5237 				uint32_t operands = ops[4];
5238 				if (operands != ImageOperandsSampleMask || instruction.length != 6)
5239 					SPIRV_CROSS_THROW("Multisampled image used in OpImageRead, but unexpected operand mask was used.");
5240 				uint32_t sample = ops[5];
5241 				imgexpr = join(to_expression(ops[2]), ".Load(int2(gl_FragCoord.xy), ", to_expression(sample), ")");
5242 			}
5243 			else
5244 				imgexpr = join(to_expression(ops[2]), ".Load(int3(int2(gl_FragCoord.xy), 0))");
5245 
5246 			pure = true;
5247 		}
5248 		else
5249 		{
5250 			imgexpr = join(to_expression(ops[2]), "[", to_expression(ops[3]), "]");
5251 			// The underlying image type in HLSL depends on the image format, unlike GLSL, where all images are "vec4",
5252 			// except that the underlying type changes how the data is interpreted.
5253 
5254 			bool force_srv =
5255 			    hlsl_options.nonwritable_uav_texture_as_srv && var && has_decoration(var->self, DecorationNonWritable);
5256 			pure = force_srv;
5257 
5258 			if (var && !subpass_data && !force_srv)
5259 				imgexpr = remap_swizzle(get<SPIRType>(result_type),
5260 				                        image_format_to_components(get<SPIRType>(var->basetype).image.format), imgexpr);
5261 		}
5262 
5263 		if (var && var->forwardable)
5264 		{
5265 			bool forward = forced_temporaries.find(id) == end(forced_temporaries);
5266 			auto &e = emit_op(result_type, id, imgexpr, forward);
5267 
5268 			if (!pure)
5269 			{
5270 				e.loaded_from = var->self;
5271 				if (forward)
5272 					var->dependees.push_back(id);
5273 			}
5274 		}
5275 		else
5276 			emit_op(result_type, id, imgexpr, false);
5277 
5278 		inherit_expression_dependencies(id, ops[2]);
5279 		if (type.image.ms)
5280 			inherit_expression_dependencies(id, ops[5]);
5281 		break;
5282 	}
5283 
5284 	case OpImageWrite:
5285 	{
5286 		auto *var = maybe_get_backing_variable(ops[0]);
5287 
5288 		// The underlying image type in HLSL depends on the image format, unlike GLSL, where all images are "vec4",
5289 		// except that the underlying type changes how the data is interpreted.
5290 		auto value_expr = to_expression(ops[2]);
5291 		if (var)
5292 		{
5293 			auto &type = get<SPIRType>(var->basetype);
5294 			auto narrowed_type = get<SPIRType>(type.image.type);
5295 			narrowed_type.vecsize = image_format_to_components(type.image.format);
5296 			value_expr = remap_swizzle(narrowed_type, expression_type(ops[2]).vecsize, value_expr);
5297 		}
5298 
5299 		statement(to_expression(ops[0]), "[", to_expression(ops[1]), "] = ", value_expr, ";");
5300 		if (var && variable_storage_is_aliased(*var))
5301 			flush_all_aliased_variables();
5302 		break;
5303 	}
5304 
5305 	case OpImageTexelPointer:
5306 	{
5307 		uint32_t result_type = ops[0];
5308 		uint32_t id = ops[1];
5309 
5310 		auto expr = to_expression(ops[2]);
5311 		if (has_decoration(id, DecorationNonUniformEXT) || has_decoration(ops[2], DecorationNonUniformEXT))
5312 			convert_non_uniform_expression(expression_type(ops[2]), expr);
5313 		expr += join("[", to_expression(ops[3]), "]");
5314 
5315 		auto &e = set<SPIRExpression>(id, expr, result_type, true);
5316 
5317 		// When using the pointer, we need to know which variable it is actually loaded from.
5318 		auto *var = maybe_get_backing_variable(ops[2]);
5319 		e.loaded_from = var ? var->self : ID(0);
5320 		inherit_expression_dependencies(id, ops[3]);
5321 		break;
5322 	}
5323 
5324 	case OpAtomicCompareExchange:
5325 	case OpAtomicExchange:
5326 	case OpAtomicISub:
5327 	case OpAtomicSMin:
5328 	case OpAtomicUMin:
5329 	case OpAtomicSMax:
5330 	case OpAtomicUMax:
5331 	case OpAtomicAnd:
5332 	case OpAtomicOr:
5333 	case OpAtomicXor:
5334 	case OpAtomicIAdd:
5335 	case OpAtomicIIncrement:
5336 	case OpAtomicIDecrement:
5337 	case OpAtomicLoad:
5338 	case OpAtomicStore:
5339 	{
5340 		emit_atomic(ops, instruction.length, opcode);
5341 		break;
5342 	}
5343 
5344 	case OpControlBarrier:
5345 	case OpMemoryBarrier:
5346 	{
5347 		uint32_t memory;
5348 		uint32_t semantics;
5349 
5350 		if (opcode == OpMemoryBarrier)
5351 		{
5352 			memory = evaluate_constant_u32(ops[0]);
5353 			semantics = evaluate_constant_u32(ops[1]);
5354 		}
5355 		else
5356 		{
5357 			memory = evaluate_constant_u32(ops[1]);
5358 			semantics = evaluate_constant_u32(ops[2]);
5359 		}
5360 
5361 		if (memory == ScopeSubgroup)
5362 		{
5363 			// No Wave-barriers in HLSL.
5364 			break;
5365 		}
5366 
5367 		// We only care about these flags, acquire/release and friends are not relevant to GLSL.
5368 		semantics = mask_relevant_memory_semantics(semantics);
5369 
5370 		if (opcode == OpMemoryBarrier)
5371 		{
5372 			// If we are a memory barrier, and the next instruction is a control barrier, check if that memory barrier
5373 			// does what we need, so we avoid redundant barriers.
5374 			const Instruction *next = get_next_instruction_in_block(instruction);
5375 			if (next && next->op == OpControlBarrier)
5376 			{
5377 				auto *next_ops = stream(*next);
5378 				uint32_t next_memory = evaluate_constant_u32(next_ops[1]);
5379 				uint32_t next_semantics = evaluate_constant_u32(next_ops[2]);
5380 				next_semantics = mask_relevant_memory_semantics(next_semantics);
5381 
5382 				// There is no "just execution barrier" in HLSL.
5383 				// If there are no memory semantics for next instruction, we will imply group shared memory is synced.
5384 				if (next_semantics == 0)
5385 					next_semantics = MemorySemanticsWorkgroupMemoryMask;
5386 
5387 				bool memory_scope_covered = false;
5388 				if (next_memory == memory)
5389 					memory_scope_covered = true;
5390 				else if (next_semantics == MemorySemanticsWorkgroupMemoryMask)
5391 				{
5392 					// If we only care about workgroup memory, either Device or Workgroup scope is fine,
5393 					// scope does not have to match.
5394 					if ((next_memory == ScopeDevice || next_memory == ScopeWorkgroup) &&
5395 					    (memory == ScopeDevice || memory == ScopeWorkgroup))
5396 					{
5397 						memory_scope_covered = true;
5398 					}
5399 				}
5400 				else if (memory == ScopeWorkgroup && next_memory == ScopeDevice)
5401 				{
5402 					// The control barrier has device scope, but the memory barrier just has workgroup scope.
5403 					memory_scope_covered = true;
5404 				}
5405 
5406 				// If we have the same memory scope, and all memory types are covered, we're good.
5407 				if (memory_scope_covered && (semantics & next_semantics) == semantics)
5408 					break;
5409 			}
5410 		}
5411 
5412 		// We are synchronizing some memory or syncing execution,
5413 		// so we cannot forward any loads beyond the memory barrier.
5414 		if (semantics || opcode == OpControlBarrier)
5415 		{
5416 			assert(current_emitting_block);
5417 			flush_control_dependent_expressions(current_emitting_block->self);
5418 			flush_all_active_variables();
5419 		}
5420 
5421 		if (opcode == OpControlBarrier)
5422 		{
5423 			// We cannot emit just execution barrier, for no memory semantics pick the cheapest option.
5424 			if (semantics == MemorySemanticsWorkgroupMemoryMask || semantics == 0)
5425 				statement("GroupMemoryBarrierWithGroupSync();");
5426 			else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
5427 				statement("DeviceMemoryBarrierWithGroupSync();");
5428 			else
5429 				statement("AllMemoryBarrierWithGroupSync();");
5430 		}
5431 		else
5432 		{
5433 			if (semantics == MemorySemanticsWorkgroupMemoryMask)
5434 				statement("GroupMemoryBarrier();");
5435 			else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
5436 				statement("DeviceMemoryBarrier();");
5437 			else
5438 				statement("AllMemoryBarrier();");
5439 		}
5440 		break;
5441 	}
5442 
5443 	case OpBitFieldInsert:
5444 	{
5445 		if (!requires_bitfield_insert)
5446 		{
5447 			requires_bitfield_insert = true;
5448 			force_recompile();
5449 		}
5450 
5451 		auto expr = join("spvBitfieldInsert(", to_expression(ops[2]), ", ", to_expression(ops[3]), ", ",
5452 		                 to_expression(ops[4]), ", ", to_expression(ops[5]), ")");
5453 
5454 		bool forward =
5455 		    should_forward(ops[2]) && should_forward(ops[3]) && should_forward(ops[4]) && should_forward(ops[5]);
5456 
5457 		auto &restype = get<SPIRType>(ops[0]);
5458 		expr = bitcast_expression(restype, SPIRType::UInt, expr);
5459 		emit_op(ops[0], ops[1], expr, forward);
5460 		break;
5461 	}
5462 
5463 	case OpBitFieldSExtract:
5464 	case OpBitFieldUExtract:
5465 	{
5466 		if (!requires_bitfield_extract)
5467 		{
5468 			requires_bitfield_extract = true;
5469 			force_recompile();
5470 		}
5471 
5472 		if (opcode == OpBitFieldSExtract)
5473 			HLSL_TFOP(spvBitfieldSExtract);
5474 		else
5475 			HLSL_TFOP(spvBitfieldUExtract);
5476 		break;
5477 	}
5478 
5479 	case OpBitCount:
5480 	{
5481 		auto basetype = expression_type(ops[2]).basetype;
5482 		emit_unary_func_op_cast(ops[0], ops[1], ops[2], "countbits", basetype, basetype);
5483 		break;
5484 	}
5485 
5486 	case OpBitReverse:
5487 		HLSL_UFOP(reversebits);
5488 		break;
5489 
5490 	case OpArrayLength:
5491 	{
5492 		auto *var = maybe_get<SPIRVariable>(ops[2]);
5493 		if (!var)
5494 			SPIRV_CROSS_THROW("Array length must point directly to an SSBO block.");
5495 
5496 		auto &type = get<SPIRType>(var->basetype);
5497 		if (!has_decoration(type.self, DecorationBlock) && !has_decoration(type.self, DecorationBufferBlock))
5498 			SPIRV_CROSS_THROW("Array length expression must point to a block type.");
5499 
5500 		// This must be 32-bit uint, so we're good to go.
5501 		emit_uninitialized_temporary_expression(ops[0], ops[1]);
5502 		statement(to_expression(ops[2]), ".GetDimensions(", to_expression(ops[1]), ");");
5503 		uint32_t offset = type_struct_member_offset(type, ops[3]);
5504 		uint32_t stride = type_struct_member_array_stride(type, ops[3]);
5505 		statement(to_expression(ops[1]), " = (", to_expression(ops[1]), " - ", offset, ") / ", stride, ";");
5506 		break;
5507 	}
5508 
5509 	case OpIsHelperInvocationEXT:
5510 		SPIRV_CROSS_THROW("helperInvocationEXT() is not supported in HLSL.");
5511 
5512 	case OpBeginInvocationInterlockEXT:
5513 	case OpEndInvocationInterlockEXT:
5514 		if (hlsl_options.shader_model < 51)
5515 			SPIRV_CROSS_THROW("Rasterizer order views require Shader Model 5.1.");
5516 		break; // Nothing to do in the body
5517 
5518 	default:
5519 		CompilerGLSL::emit_instruction(instruction);
5520 		break;
5521 	}
5522 }
5523 
require_texture_query_variant(uint32_t var_id)5524 void CompilerHLSL::require_texture_query_variant(uint32_t var_id)
5525 {
5526 	if (const auto *var = maybe_get_backing_variable(var_id))
5527 		var_id = var->self;
5528 
5529 	auto &type = expression_type(var_id);
5530 	bool uav = type.image.sampled == 2;
5531 	if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var_id, DecorationNonWritable))
5532 		uav = false;
5533 
5534 	uint32_t bit = 0;
5535 	switch (type.image.dim)
5536 	{
5537 	case Dim1D:
5538 		bit = type.image.arrayed ? Query1DArray : Query1D;
5539 		break;
5540 
5541 	case Dim2D:
5542 		if (type.image.ms)
5543 			bit = type.image.arrayed ? Query2DMSArray : Query2DMS;
5544 		else
5545 			bit = type.image.arrayed ? Query2DArray : Query2D;
5546 		break;
5547 
5548 	case Dim3D:
5549 		bit = Query3D;
5550 		break;
5551 
5552 	case DimCube:
5553 		bit = type.image.arrayed ? QueryCubeArray : QueryCube;
5554 		break;
5555 
5556 	case DimBuffer:
5557 		bit = QueryBuffer;
5558 		break;
5559 
5560 	default:
5561 		SPIRV_CROSS_THROW("Unsupported query type.");
5562 	}
5563 
5564 	switch (get<SPIRType>(type.image.type).basetype)
5565 	{
5566 	case SPIRType::Float:
5567 		bit += QueryTypeFloat;
5568 		break;
5569 
5570 	case SPIRType::Int:
5571 		bit += QueryTypeInt;
5572 		break;
5573 
5574 	case SPIRType::UInt:
5575 		bit += QueryTypeUInt;
5576 		break;
5577 
5578 	default:
5579 		SPIRV_CROSS_THROW("Unsupported query type.");
5580 	}
5581 
5582 	auto norm_state = image_format_to_normalized_state(type.image.format);
5583 	auto &variant = uav ? required_texture_size_variants
5584 	                          .uav[uint32_t(norm_state)][image_format_to_components(type.image.format) - 1] :
5585 	                      required_texture_size_variants.srv;
5586 
5587 	uint64_t mask = 1ull << bit;
5588 	if ((variant & mask) == 0)
5589 	{
5590 		force_recompile();
5591 		variant |= mask;
5592 	}
5593 }
5594 
set_root_constant_layouts(std::vector<RootConstants> layout)5595 void CompilerHLSL::set_root_constant_layouts(std::vector<RootConstants> layout)
5596 {
5597 	root_constants_layout = move(layout);
5598 }
5599 
add_vertex_attribute_remap(const HLSLVertexAttributeRemap & vertex_attributes)5600 void CompilerHLSL::add_vertex_attribute_remap(const HLSLVertexAttributeRemap &vertex_attributes)
5601 {
5602 	remap_vertex_attributes.push_back(vertex_attributes);
5603 }
5604 
remap_num_workgroups_builtin()5605 VariableID CompilerHLSL::remap_num_workgroups_builtin()
5606 {
5607 	update_active_builtins();
5608 
5609 	if (!active_input_builtins.get(BuiltInNumWorkgroups))
5610 		return 0;
5611 
5612 	// Create a new, fake UBO.
5613 	uint32_t offset = ir.increase_bound_by(4);
5614 
5615 	uint32_t uint_type_id = offset;
5616 	uint32_t block_type_id = offset + 1;
5617 	uint32_t block_pointer_type_id = offset + 2;
5618 	uint32_t variable_id = offset + 3;
5619 
5620 	SPIRType uint_type;
5621 	uint_type.basetype = SPIRType::UInt;
5622 	uint_type.width = 32;
5623 	uint_type.vecsize = 3;
5624 	uint_type.columns = 1;
5625 	set<SPIRType>(uint_type_id, uint_type);
5626 
5627 	SPIRType block_type;
5628 	block_type.basetype = SPIRType::Struct;
5629 	block_type.member_types.push_back(uint_type_id);
5630 	set<SPIRType>(block_type_id, block_type);
5631 	set_decoration(block_type_id, DecorationBlock);
5632 	set_member_name(block_type_id, 0, "count");
5633 	set_member_decoration(block_type_id, 0, DecorationOffset, 0);
5634 
5635 	SPIRType block_pointer_type = block_type;
5636 	block_pointer_type.pointer = true;
5637 	block_pointer_type.storage = StorageClassUniform;
5638 	block_pointer_type.parent_type = block_type_id;
5639 	auto &ptr_type = set<SPIRType>(block_pointer_type_id, block_pointer_type);
5640 
5641 	// Preserve self.
5642 	ptr_type.self = block_type_id;
5643 
5644 	set<SPIRVariable>(variable_id, block_pointer_type_id, StorageClassUniform);
5645 	ir.meta[variable_id].decoration.alias = "SPIRV_Cross_NumWorkgroups";
5646 
5647 	num_workgroups_builtin = variable_id;
5648 	return variable_id;
5649 }
5650 
set_resource_binding_flags(HLSLBindingFlags flags)5651 void CompilerHLSL::set_resource_binding_flags(HLSLBindingFlags flags)
5652 {
5653 	resource_binding_flags = flags;
5654 }
5655 
validate_shader_model()5656 void CompilerHLSL::validate_shader_model()
5657 {
5658 	// Check for nonuniform qualifier.
5659 	// Instead of looping over all decorations to find this, just look at capabilities.
5660 	for (auto &cap : ir.declared_capabilities)
5661 	{
5662 		switch (cap)
5663 		{
5664 		case CapabilityShaderNonUniformEXT:
5665 		case CapabilityRuntimeDescriptorArrayEXT:
5666 			if (hlsl_options.shader_model < 51)
5667 				SPIRV_CROSS_THROW(
5668 				    "Shader model 5.1 or higher is required to use bindless resources or NonUniformResourceIndex.");
5669 			break;
5670 
5671 		case CapabilityVariablePointers:
5672 		case CapabilityVariablePointersStorageBuffer:
5673 			SPIRV_CROSS_THROW("VariablePointers capability is not supported in HLSL.");
5674 
5675 		default:
5676 			break;
5677 		}
5678 	}
5679 
5680 	if (ir.addressing_model != AddressingModelLogical)
5681 		SPIRV_CROSS_THROW("Only Logical addressing model can be used with HLSL.");
5682 
5683 	if (hlsl_options.enable_16bit_types && hlsl_options.shader_model < 62)
5684 		SPIRV_CROSS_THROW("Need at least shader model 6.2 when enabling native 16-bit type support.");
5685 }
5686 
compile()5687 string CompilerHLSL::compile()
5688 {
5689 	ir.fixup_reserved_names();
5690 
5691 	// Do not deal with ES-isms like precision, older extensions and such.
5692 	options.es = false;
5693 	options.version = 450;
5694 	options.vulkan_semantics = true;
5695 	backend.float_literal_suffix = true;
5696 	backend.double_literal_suffix = false;
5697 	backend.long_long_literal_suffix = true;
5698 	backend.uint32_t_literal_suffix = true;
5699 	backend.int16_t_literal_suffix = "";
5700 	backend.uint16_t_literal_suffix = "u";
5701 	backend.basic_int_type = "int";
5702 	backend.basic_uint_type = "uint";
5703 	backend.demote_literal = "discard";
5704 	backend.boolean_mix_function = "";
5705 	backend.swizzle_is_function = false;
5706 	backend.shared_is_implied = true;
5707 	backend.unsized_array_supported = true;
5708 	backend.explicit_struct_type = false;
5709 	backend.use_initializer_list = true;
5710 	backend.use_constructor_splatting = false;
5711 	backend.can_swizzle_scalar = true;
5712 	backend.can_declare_struct_inline = false;
5713 	backend.can_declare_arrays_inline = false;
5714 	backend.can_return_array = false;
5715 	backend.nonuniform_qualifier = "NonUniformResourceIndex";
5716 	backend.support_case_fallthrough = false;
5717 
5718 	fixup_type_alias();
5719 	reorder_type_alias();
5720 	build_function_control_flow_graphs_and_analyze();
5721 	validate_shader_model();
5722 	update_active_builtins();
5723 	analyze_image_and_sampler_usage();
5724 	analyze_interlocked_resource_usage();
5725 
5726 	// Subpass input needs SV_Position.
5727 	if (need_subpass_input)
5728 		active_input_builtins.set(BuiltInFragCoord);
5729 
5730 	uint32_t pass_count = 0;
5731 	do
5732 	{
5733 		if (pass_count >= 3)
5734 			SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");
5735 
5736 		reset();
5737 
5738 		// Move constructor for this type is broken on GCC 4.9 ...
5739 		buffer.reset();
5740 
5741 		emit_header();
5742 		emit_resources();
5743 
5744 		emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());
5745 		emit_hlsl_entry_point();
5746 
5747 		pass_count++;
5748 	} while (is_forcing_recompilation());
5749 
5750 	// Entry point in HLSL is always main() for the time being.
5751 	get_entry_point().name = "main";
5752 
5753 	return buffer.str();
5754 }
5755 
emit_block_hints(const SPIRBlock & block)5756 void CompilerHLSL::emit_block_hints(const SPIRBlock &block)
5757 {
5758 	switch (block.hint)
5759 	{
5760 	case SPIRBlock::HintFlatten:
5761 		statement("[flatten]");
5762 		break;
5763 	case SPIRBlock::HintDontFlatten:
5764 		statement("[branch]");
5765 		break;
5766 	case SPIRBlock::HintUnroll:
5767 		statement("[unroll]");
5768 		break;
5769 	case SPIRBlock::HintDontUnroll:
5770 		statement("[loop]");
5771 		break;
5772 	default:
5773 		break;
5774 	}
5775 }
5776 
get_unique_identifier()5777 string CompilerHLSL::get_unique_identifier()
5778 {
5779 	return join("_", unique_identifier_count++, "ident");
5780 }
5781 
add_hlsl_resource_binding(const HLSLResourceBinding & binding)5782 void CompilerHLSL::add_hlsl_resource_binding(const HLSLResourceBinding &binding)
5783 {
5784 	StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
5785 	resource_bindings[tuple] = { binding, false };
5786 }
5787 
is_hlsl_resource_binding_used(ExecutionModel model,uint32_t desc_set,uint32_t binding) const5788 bool CompilerHLSL::is_hlsl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
5789 {
5790 	StageSetBinding tuple = { model, desc_set, binding };
5791 	auto itr = resource_bindings.find(tuple);
5792 	return itr != end(resource_bindings) && itr->second.second;
5793 }
5794 
get_bitcast_type(uint32_t result_type,uint32_t op0)5795 CompilerHLSL::BitcastType CompilerHLSL::get_bitcast_type(uint32_t result_type, uint32_t op0)
5796 {
5797 	auto &rslt_type = get<SPIRType>(result_type);
5798 	auto &expr_type = expression_type(op0);
5799 
5800 	if (rslt_type.basetype == SPIRType::BaseType::UInt64 && expr_type.basetype == SPIRType::BaseType::UInt &&
5801 	    expr_type.vecsize == 2)
5802 		return BitcastType::TypePackUint2x32;
5803 	else if (rslt_type.basetype == SPIRType::BaseType::UInt && rslt_type.vecsize == 2 &&
5804 	         expr_type.basetype == SPIRType::BaseType::UInt64)
5805 		return BitcastType::TypeUnpackUint64;
5806 
5807 	return BitcastType::TypeNormal;
5808 }
5809 
is_hlsl_force_storage_buffer_as_uav(ID id) const5810 bool CompilerHLSL::is_hlsl_force_storage_buffer_as_uav(ID id) const
5811 {
5812 	if (hlsl_options.force_storage_buffer_as_uav)
5813 	{
5814 		return true;
5815 	}
5816 
5817 	const uint32_t desc_set = get_decoration(id, spv::DecorationDescriptorSet);
5818 	const uint32_t binding = get_decoration(id, spv::DecorationBinding);
5819 
5820 	return (force_uav_buffer_bindings.find({ desc_set, binding }) != force_uav_buffer_bindings.end());
5821 }
5822 
set_hlsl_force_storage_buffer_as_uav(uint32_t desc_set,uint32_t binding)5823 void CompilerHLSL::set_hlsl_force_storage_buffer_as_uav(uint32_t desc_set, uint32_t binding)
5824 {
5825 	SetBindingPair pair = { desc_set, binding };
5826 	force_uav_buffer_bindings.insert(pair);
5827 }
5828 
builtin_translates_to_nonarray(spv::BuiltIn builtin) const5829 bool CompilerHLSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
5830 {
5831 	return (builtin == BuiltInSampleMask);
5832 }
5833