1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "Deserializer.h"
14
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/Support/LogicalResult.h"
23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/bit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/raw_ostream.h"
32
33 using namespace mlir;
34
35 #define DEBUG_TYPE "spirv-deserialization"
36
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40
41 /// Returns true if the given `block` is a function entry block.
isFnEntryBlock(Block * block)42 static inline bool isFnEntryBlock(Block *block) {
43 return block->isEntryBlock() &&
44 isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50
/// Creates a deserializer over `binary`, a sequence of 32-bit SPIR-V words.
///
/// A fresh spirv::ModuleOp is created via createModuleOp() to receive the
/// deserialized content, and `opBuilder` is constructed on that module's
/// region.
///
/// NOTE(review): createModuleOp() reads `context` and `unknownLoc`, and
/// `opBuilder` reads `module`. Members are initialized in declaration order,
/// not initializer-list order, so this relies on Deserializer.h declaring
/// `context`, `unknownLoc`, `module` and `opBuilder` in that relative order.
/// TODO confirm against the header before reordering anything.
spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
                                  MLIRContext *context)
    : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
      module(createModuleOp()), opBuilder(module->getRegion()) {}
55
deserialize()56 LogicalResult spirv::Deserializer::deserialize() {
57 LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n");
58
59 if (failed(processHeader()))
60 return failure();
61
62 spirv::Opcode opcode = spirv::Opcode::OpNop;
63 ArrayRef<uint32_t> operands;
64 auto binarySize = binary.size();
65 while (curOffset < binarySize) {
66 // Slice the next instruction out and populate `opcode` and `operands`.
67 // Internally this also updates `curOffset`.
68 if (failed(sliceInstruction(opcode, operands)))
69 return failure();
70
71 if (failed(processInstruction(opcode, operands)))
72 return failure();
73 }
74
75 assert(curOffset == binarySize &&
76 "deserializer should never index beyond the binary end");
77
78 for (auto &deferred : deferredInstructions) {
79 if (failed(processInstruction(deferred.first, deferred.second, false))) {
80 return failure();
81 }
82 }
83
84 attachVCETriple();
85
86 LLVM_DEBUG(llvm::dbgs() << "+++ completed deserialization +++\n");
87 return success();
88 }
89
collect()90 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
91 return std::move(module);
92 }
93
94 //===----------------------------------------------------------------------===//
95 // Module structure
96 //===----------------------------------------------------------------------===//
97
createModuleOp()98 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
99 OpBuilder builder(context);
100 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
101 spirv::ModuleOp::build(builder, state);
102 return cast<spirv::ModuleOp>(Operation::create(state));
103 }
104
processHeader()105 LogicalResult spirv::Deserializer::processHeader() {
106 if (binary.size() < spirv::kHeaderWordCount)
107 return emitError(unknownLoc,
108 "SPIR-V binary module must have a 5-word header");
109
110 if (binary[0] != spirv::kMagicNumber)
111 return emitError(unknownLoc, "incorrect magic number");
112
113 // Version number bytes: 0 | major number | minor number | 0
114 uint32_t majorVersion = (binary[1] << 8) >> 24;
115 uint32_t minorVersion = (binary[1] << 16) >> 24;
116 if (majorVersion == 1) {
117 switch (minorVersion) {
118 #define MIN_VERSION_CASE(v) \
119 case v: \
120 version = spirv::Version::V_1_##v; \
121 break
122
123 MIN_VERSION_CASE(0);
124 MIN_VERSION_CASE(1);
125 MIN_VERSION_CASE(2);
126 MIN_VERSION_CASE(3);
127 MIN_VERSION_CASE(4);
128 MIN_VERSION_CASE(5);
129 #undef MIN_VERSION_CASE
130 default:
131 return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
132 << minorVersion;
133 }
134 } else {
135 return emitError(unknownLoc, "unsupported SPIR-V major version: ")
136 << majorVersion;
137 }
138
139 // TODO: generator number, bound, schema
140 curOffset = spirv::kHeaderWordCount;
141 return success();
142 }
143
144 LogicalResult
processCapability(ArrayRef<uint32_t> operands)145 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
146 if (operands.size() != 1)
147 return emitError(unknownLoc, "OpMemoryModel must have one parameter");
148
149 auto cap = spirv::symbolizeCapability(operands[0]);
150 if (!cap)
151 return emitError(unknownLoc, "unknown capability: ") << operands[0];
152
153 capabilities.insert(*cap);
154 return success();
155 }
156
processExtension(ArrayRef<uint32_t> words)157 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
158 if (words.empty()) {
159 return emitError(
160 unknownLoc,
161 "OpExtension must have a literal string for the extension name");
162 }
163
164 unsigned wordIndex = 0;
165 StringRef extName = decodeStringLiteral(words, wordIndex);
166 if (wordIndex != words.size())
167 return emitError(unknownLoc,
168 "unexpected trailing words in OpExtension instruction");
169 auto ext = spirv::symbolizeExtension(extName);
170 if (!ext)
171 return emitError(unknownLoc, "unknown extension: ") << extName;
172
173 extensions.insert(*ext);
174 return success();
175 }
176
177 LogicalResult
processExtInstImport(ArrayRef<uint32_t> words)178 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
179 if (words.size() < 2) {
180 return emitError(unknownLoc,
181 "OpExtInstImport must have a result <id> and a literal "
182 "string for the extended instruction set name");
183 }
184
185 unsigned wordIndex = 1;
186 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
187 if (wordIndex != words.size()) {
188 return emitError(unknownLoc,
189 "unexpected trailing words in OpExtInstImport");
190 }
191 return success();
192 }
193
attachVCETriple()194 void spirv::Deserializer::attachVCETriple() {
195 (*module)->setAttr(
196 spirv::ModuleOp::getVCETripleAttrName(),
197 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
198 extensions.getArrayRef(), context));
199 }
200
201 LogicalResult
processMemoryModel(ArrayRef<uint32_t> operands)202 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
203 if (operands.size() != 2)
204 return emitError(unknownLoc, "OpMemoryModel must have two operands");
205
206 (*module)->setAttr(
207 "addressing_model",
208 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front())));
209 (*module)->setAttr(
210 "memory_model",
211 opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back())));
212
213 return success();
214 }
215
processDecoration(ArrayRef<uint32_t> words)216 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
217 // TODO: This function should also be auto-generated. For now, since only a
218 // few decorations are processed/handled in a meaningful manner, going with a
219 // manual implementation.
220 if (words.size() < 2) {
221 return emitError(
222 unknownLoc, "OpDecorate must have at least result <id> and Decoration");
223 }
224 auto decorationName =
225 stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
226 if (decorationName.empty()) {
227 return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
228 }
229 auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
230 auto symbol = opBuilder.getIdentifier(attrName);
231 switch (static_cast<spirv::Decoration>(words[1])) {
232 case spirv::Decoration::DescriptorSet:
233 case spirv::Decoration::Binding:
234 if (words.size() != 3) {
235 return emitError(unknownLoc, "OpDecorate with ")
236 << decorationName << " needs a single integer literal";
237 }
238 decorations[words[0]].set(
239 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
240 break;
241 case spirv::Decoration::BuiltIn:
242 if (words.size() != 3) {
243 return emitError(unknownLoc, "OpDecorate with ")
244 << decorationName << " needs a single integer literal";
245 }
246 decorations[words[0]].set(
247 symbol, opBuilder.getStringAttr(
248 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
249 break;
250 case spirv::Decoration::ArrayStride:
251 if (words.size() != 3) {
252 return emitError(unknownLoc, "OpDecorate with ")
253 << decorationName << " needs a single integer literal";
254 }
255 typeDecorations[words[0]] = words[2];
256 break;
257 case spirv::Decoration::Aliased:
258 case spirv::Decoration::Block:
259 case spirv::Decoration::BufferBlock:
260 case spirv::Decoration::Flat:
261 case spirv::Decoration::NonReadable:
262 case spirv::Decoration::NonWritable:
263 case spirv::Decoration::NoPerspective:
264 case spirv::Decoration::Restrict:
265 case spirv::Decoration::RelaxedPrecision:
266 if (words.size() != 2) {
267 return emitError(unknownLoc, "OpDecoration with ")
268 << decorationName << "needs a single target <id>";
269 }
270 // Block decoration does not affect spv.struct type, but is still stored for
271 // verification.
272 // TODO: Update StructType to contain this information since
273 // it is needed for many validation rules.
274 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
275 break;
276 case spirv::Decoration::Location:
277 case spirv::Decoration::SpecId:
278 if (words.size() != 3) {
279 return emitError(unknownLoc, "OpDecoration with ")
280 << decorationName << "needs a single integer literal";
281 }
282 decorations[words[0]].set(
283 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
284 break;
285 default:
286 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
287 }
288 return success();
289 }
290
291 LogicalResult
processMemberDecoration(ArrayRef<uint32_t> words)292 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
293 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
294 if (words.size() < 3) {
295 return emitError(unknownLoc,
296 "OpMemberDecorate must have at least 3 operands");
297 }
298
299 auto decoration = static_cast<spirv::Decoration>(words[2]);
300 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
301 return emitError(unknownLoc,
302 " missing offset specification in OpMemberDecorate with "
303 "Offset decoration");
304 }
305 ArrayRef<uint32_t> decorationOperands;
306 if (words.size() > 3) {
307 decorationOperands = words.slice(3);
308 }
309 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
310 return success();
311 }
312
processMemberName(ArrayRef<uint32_t> words)313 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
314 if (words.size() < 3) {
315 return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
316 }
317 unsigned wordIndex = 2;
318 auto name = decodeStringLiteral(words, wordIndex);
319 if (wordIndex != words.size()) {
320 return emitError(unknownLoc,
321 "unexpected trailing words in OpMemberName instruction");
322 }
323 memberNameMap[words[0]][words[1]] = name;
324 return success();
325 }
326
327 LogicalResult
processFunction(ArrayRef<uint32_t> operands)328 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
329 if (curFunction) {
330 return emitError(unknownLoc, "found function inside function");
331 }
332
333 // Get the result type
334 if (operands.size() != 4) {
335 return emitError(unknownLoc, "OpFunction must have 4 parameters");
336 }
337 Type resultType = getType(operands[0]);
338 if (!resultType) {
339 return emitError(unknownLoc, "undefined result type from <id> ")
340 << operands[0];
341 }
342
343 if (funcMap.count(operands[1])) {
344 return emitError(unknownLoc, "duplicate function definition/declaration");
345 }
346
347 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
348 if (!fnControl) {
349 return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
350 }
351
352 Type fnType = getType(operands[3]);
353 if (!fnType || !fnType.isa<FunctionType>()) {
354 return emitError(unknownLoc, "unknown function type from <id> ")
355 << operands[3];
356 }
357 auto functionType = fnType.cast<FunctionType>();
358
359 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
360 (functionType.getNumResults() == 1 &&
361 functionType.getResult(0) != resultType)) {
362 return emitError(unknownLoc, "mismatch in function type ")
363 << functionType << " and return type " << resultType << " specified";
364 }
365
366 std::string fnName = getFunctionSymbol(operands[1]);
367 auto funcOp = opBuilder.create<spirv::FuncOp>(
368 unknownLoc, fnName, functionType, fnControl.getValue());
369 curFunction = funcMap[operands[1]] = funcOp;
370 LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = "
371 << fnType << ", id = " << operands[1] << ") --\n");
372 auto *entryBlock = funcOp.addEntryBlock();
373 LLVM_DEBUG(llvm::dbgs() << "[block] created entry block " << entryBlock
374 << "\n");
375
376 // Parse the op argument instructions
377 if (functionType.getNumInputs()) {
378 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
379 auto argType = functionType.getInput(i);
380 spirv::Opcode opcode = spirv::Opcode::OpNop;
381 ArrayRef<uint32_t> operands;
382 if (failed(sliceInstruction(opcode, operands,
383 spirv::Opcode::OpFunctionParameter))) {
384 return failure();
385 }
386 if (opcode != spirv::Opcode::OpFunctionParameter) {
387 return emitError(
388 unknownLoc,
389 "missing OpFunctionParameter instruction for argument ")
390 << i;
391 }
392 if (operands.size() != 2) {
393 return emitError(
394 unknownLoc,
395 "expected result type and result <id> for OpFunctionParameter");
396 }
397 auto argDefinedType = getType(operands[0]);
398 if (!argDefinedType || argDefinedType != argType) {
399 return emitError(unknownLoc,
400 "mismatch in argument type between function type "
401 "definition ")
402 << functionType << " and argument type definition "
403 << argDefinedType << " at argument " << i;
404 }
405 if (getValue(operands[1])) {
406 return emitError(unknownLoc, "duplicate definition of result <id> '")
407 << operands[1];
408 }
409 auto argValue = funcOp.getArgument(i);
410 valueMap[operands[1]] = argValue;
411 }
412 }
413
414 // RAII guard to reset the insertion point to the module's region after
415 // deserializing the body of this function.
416 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
417
418 spirv::Opcode opcode = spirv::Opcode::OpNop;
419 ArrayRef<uint32_t> instOperands;
420
421 // Special handling for the entry block. We need to make sure it starts with
422 // an OpLabel instruction. The entry block takes the same parameters as the
423 // function. All other blocks do not take any parameter. We have already
424 // created the entry block, here we need to register it to the correct label
425 // <id>.
426 if (failed(sliceInstruction(opcode, instOperands,
427 spirv::Opcode::OpFunctionEnd))) {
428 return failure();
429 }
430 if (opcode == spirv::Opcode::OpFunctionEnd) {
431 LLVM_DEBUG(llvm::dbgs()
432 << "-- completed function '" << fnName << "' (type = " << fnType
433 << ", id = " << operands[1] << ") --\n");
434 return processFunctionEnd(instOperands);
435 }
436 if (opcode != spirv::Opcode::OpLabel) {
437 return emitError(unknownLoc, "a basic block must start with OpLabel");
438 }
439 if (instOperands.size() != 1) {
440 return emitError(unknownLoc, "OpLabel should only have result <id>");
441 }
442 blockMap[instOperands[0]] = entryBlock;
443 if (failed(processLabel(instOperands))) {
444 return failure();
445 }
446
447 // Then process all the other instructions in the function until we hit
448 // OpFunctionEnd.
449 while (succeeded(sliceInstruction(opcode, instOperands,
450 spirv::Opcode::OpFunctionEnd)) &&
451 opcode != spirv::Opcode::OpFunctionEnd) {
452 if (failed(processInstruction(opcode, instOperands))) {
453 return failure();
454 }
455 }
456 if (opcode != spirv::Opcode::OpFunctionEnd) {
457 return failure();
458 }
459
460 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = "
461 << fnType << ", id = " << operands[1] << ") --\n");
462 return processFunctionEnd(instOperands);
463 }
464
465 LogicalResult
processFunctionEnd(ArrayRef<uint32_t> operands)466 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
467 // Process OpFunctionEnd.
468 if (!operands.empty()) {
469 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
470 }
471
472 // Wire up block arguments from OpPhi instructions.
473 // Put all structured control flow in spv.mlir.selection/spv.mlir.loop ops.
474 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
475 return failure();
476 }
477
478 curBlock = nullptr;
479 curFunction = llvm::None;
480
481 return success();
482 }
483
484 Optional<std::pair<Attribute, Type>>
getConstant(uint32_t id)485 spirv::Deserializer::getConstant(uint32_t id) {
486 auto constIt = constantMap.find(id);
487 if (constIt == constantMap.end())
488 return llvm::None;
489 return constIt->getSecond();
490 }
491
492 Optional<spirv::SpecConstOperationMaterializationInfo>
getSpecConstantOperation(uint32_t id)493 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
494 auto constIt = specConstOperationMap.find(id);
495 if (constIt == specConstOperationMap.end())
496 return llvm::None;
497 return constIt->getSecond();
498 }
499
getFunctionSymbol(uint32_t id)500 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
501 auto funcName = nameMap.lookup(id).str();
502 if (funcName.empty()) {
503 funcName = "spirv_fn_" + std::to_string(id);
504 }
505 return funcName;
506 }
507
getSpecConstantSymbol(uint32_t id)508 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
509 auto constName = nameMap.lookup(id).str();
510 if (constName.empty()) {
511 constName = "spirv_spec_const_" + std::to_string(id);
512 }
513 return constName;
514 }
515
516 spirv::SpecConstantOp
createSpecConstant(Location loc,uint32_t resultID,Attribute defaultValue)517 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
518 Attribute defaultValue) {
519 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
520 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
521 defaultValue);
522 if (decorations.count(resultID)) {
523 for (auto attr : decorations[resultID].getAttrs())
524 op->setAttr(attr.first, attr.second);
525 }
526 specConstMap[resultID] = op;
527 return op;
528 }
529
530 LogicalResult
processGlobalVariable(ArrayRef<uint32_t> operands)531 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
532 unsigned wordIndex = 0;
533 if (operands.size() < 3) {
534 return emitError(
535 unknownLoc,
536 "OpVariable needs at least 3 operands, type, <id> and storage class");
537 }
538
539 // Result Type.
540 auto type = getType(operands[wordIndex]);
541 if (!type) {
542 return emitError(unknownLoc, "unknown result type <id> : ")
543 << operands[wordIndex];
544 }
545 auto ptrType = type.dyn_cast<spirv::PointerType>();
546 if (!ptrType) {
547 return emitError(unknownLoc,
548 "expected a result type <id> to be a spv.ptr, found : ")
549 << type;
550 }
551 wordIndex++;
552
553 // Result <id>.
554 auto variableID = operands[wordIndex];
555 auto variableName = nameMap.lookup(variableID).str();
556 if (variableName.empty()) {
557 variableName = "spirv_var_" + std::to_string(variableID);
558 }
559 wordIndex++;
560
561 // Storage class.
562 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
563 if (ptrType.getStorageClass() != storageClass) {
564 return emitError(unknownLoc, "mismatch in storage class of pointer type ")
565 << type << " and that specified in OpVariable instruction : "
566 << stringifyStorageClass(storageClass);
567 }
568 wordIndex++;
569
570 // Initializer.
571 FlatSymbolRefAttr initializer = nullptr;
572 if (wordIndex < operands.size()) {
573 auto initializerOp = getGlobalVariable(operands[wordIndex]);
574 if (!initializerOp) {
575 return emitError(unknownLoc, "unknown <id> ")
576 << operands[wordIndex] << "used as initializer";
577 }
578 wordIndex++;
579 initializer = SymbolRefAttr::get(initializerOp.getOperation());
580 }
581 if (wordIndex != operands.size()) {
582 return emitError(unknownLoc,
583 "found more operands than expected when deserializing "
584 "OpVariable instruction, only ")
585 << wordIndex << " of " << operands.size() << " processed";
586 }
587 auto loc = createFileLineColLoc(opBuilder);
588 auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
589 loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
590 initializer);
591
592 // Decorations.
593 if (decorations.count(variableID)) {
594 for (auto attr : decorations[variableID].getAttrs()) {
595 varOp->setAttr(attr.first, attr.second);
596 }
597 }
598 globalVariableMap[variableID] = varOp;
599 return success();
600 }
601
getConstantInt(uint32_t id)602 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
603 auto constInfo = getConstant(id);
604 if (!constInfo) {
605 return nullptr;
606 }
607 return constInfo->first.dyn_cast<IntegerAttr>();
608 }
609
processName(ArrayRef<uint32_t> operands)610 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
611 if (operands.size() < 2) {
612 return emitError(unknownLoc, "OpName needs at least 2 operands");
613 }
614 if (!nameMap.lookup(operands[0]).empty()) {
615 return emitError(unknownLoc, "duplicate name found for result <id> ")
616 << operands[0];
617 }
618 unsigned wordIndex = 1;
619 StringRef name = decodeStringLiteral(operands, wordIndex);
620 if (wordIndex != operands.size()) {
621 return emitError(unknownLoc,
622 "unexpected trailing words in OpName instruction");
623 }
624 nameMap[operands[0]] = name;
625 return success();
626 }
627
628 //===----------------------------------------------------------------------===//
629 // Type
630 //===----------------------------------------------------------------------===//
631
processType(spirv::Opcode opcode,ArrayRef<uint32_t> operands)632 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
633 ArrayRef<uint32_t> operands) {
634 if (operands.empty()) {
635 return emitError(unknownLoc, "type instruction with opcode ")
636 << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
637 }
638
639 /// TODO: Types might be forward declared in some instructions and need to be
640 /// handled appropriately.
641 if (typeMap.count(operands[0])) {
642 return emitError(unknownLoc, "duplicate definition for result <id> ")
643 << operands[0];
644 }
645
646 switch (opcode) {
647 case spirv::Opcode::OpTypeVoid:
648 if (operands.size() != 1)
649 return emitError(unknownLoc, "OpTypeVoid must have no parameters");
650 typeMap[operands[0]] = opBuilder.getNoneType();
651 break;
652 case spirv::Opcode::OpTypeBool:
653 if (operands.size() != 1)
654 return emitError(unknownLoc, "OpTypeBool must have no parameters");
655 typeMap[operands[0]] = opBuilder.getI1Type();
656 break;
657 case spirv::Opcode::OpTypeInt: {
658 if (operands.size() != 3)
659 return emitError(
660 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
661
662 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
663 // to preserve or validate.
664 // 0 indicates unsigned, or no signedness semantics
665 // 1 indicates signed semantics."
666 //
667 // So we cannot differentiate signless and unsigned integers; always use
668 // signless semantics for such cases.
669 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
670 : IntegerType::SignednessSemantics::Signless;
671 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
672 } break;
673 case spirv::Opcode::OpTypeFloat: {
674 if (operands.size() != 2)
675 return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
676
677 Type floatTy;
678 switch (operands[1]) {
679 case 16:
680 floatTy = opBuilder.getF16Type();
681 break;
682 case 32:
683 floatTy = opBuilder.getF32Type();
684 break;
685 case 64:
686 floatTy = opBuilder.getF64Type();
687 break;
688 default:
689 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
690 << operands[1];
691 }
692 typeMap[operands[0]] = floatTy;
693 } break;
694 case spirv::Opcode::OpTypeVector: {
695 if (operands.size() != 3) {
696 return emitError(
697 unknownLoc,
698 "OpTypeVector must have element type and count parameters");
699 }
700 Type elementTy = getType(operands[1]);
701 if (!elementTy) {
702 return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
703 << operands[1];
704 }
705 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
706 } break;
707 case spirv::Opcode::OpTypePointer: {
708 return processOpTypePointer(operands);
709 } break;
710 case spirv::Opcode::OpTypeArray:
711 return processArrayType(operands);
712 case spirv::Opcode::OpTypeCooperativeMatrixNV:
713 return processCooperativeMatrixType(operands);
714 case spirv::Opcode::OpTypeFunction:
715 return processFunctionType(operands);
716 case spirv::Opcode::OpTypeImage:
717 return processImageType(operands);
718 case spirv::Opcode::OpTypeSampledImage:
719 return processSampledImageType(operands);
720 case spirv::Opcode::OpTypeRuntimeArray:
721 return processRuntimeArrayType(operands);
722 case spirv::Opcode::OpTypeStruct:
723 return processStructType(operands);
724 case spirv::Opcode::OpTypeMatrix:
725 return processMatrixType(operands);
726 default:
727 return emitError(unknownLoc, "unhandled type instruction");
728 }
729 return success();
730 }
731
732 LogicalResult
processOpTypePointer(ArrayRef<uint32_t> operands)733 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
734 if (operands.size() != 3)
735 return emitError(unknownLoc, "OpTypePointer must have two parameters");
736
737 auto pointeeType = getType(operands[2]);
738 if (!pointeeType)
739 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
740 << operands[2];
741
742 uint32_t typePointerID = operands[0];
743 auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
744 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
745
746 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
747 deferredStructIt != std::end(deferredStructTypesInfos);) {
748 for (auto *unresolvedMemberIt =
749 std::begin(deferredStructIt->unresolvedMemberTypes);
750 unresolvedMemberIt !=
751 std::end(deferredStructIt->unresolvedMemberTypes);) {
752 if (unresolvedMemberIt->first == typePointerID) {
753 // The newly constructed pointer type can resolve one of the
754 // deferred struct type members; update the memberTypes list and
755 // clean the unresolvedMemberTypes list accordingly.
756 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
757 typeMap[typePointerID];
758 unresolvedMemberIt =
759 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
760 } else {
761 ++unresolvedMemberIt;
762 }
763 }
764
765 if (deferredStructIt->unresolvedMemberTypes.empty()) {
766 // All deferred struct type members are now resolved, set the struct body.
767 auto structType = deferredStructIt->deferredStructType;
768
769 assert(structType && "expected a spirv::StructType");
770 assert(structType.isIdentified() && "expected an indentified struct");
771
772 if (failed(structType.trySetBody(
773 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
774 deferredStructIt->memberDecorationsInfo)))
775 return failure();
776
777 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
778 } else {
779 ++deferredStructIt;
780 }
781 }
782
783 return success();
784 }
785
786 LogicalResult
processArrayType(ArrayRef<uint32_t> operands)787 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
788 if (operands.size() != 3) {
789 return emitError(unknownLoc,
790 "OpTypeArray must have element type and count parameters");
791 }
792
793 Type elementTy = getType(operands[1]);
794 if (!elementTy) {
795 return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
796 << operands[1];
797 }
798
799 unsigned count = 0;
800 // TODO: The count can also come frome a specialization constant.
801 auto countInfo = getConstant(operands[2]);
802 if (!countInfo) {
803 return emitError(unknownLoc, "OpTypeArray count <id> ")
804 << operands[2] << "can only come from normal constant right now";
805 }
806
807 if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) {
808 count = intVal.getValue().getZExtValue();
809 } else {
810 return emitError(unknownLoc, "OpTypeArray count must come from a "
811 "scalar integer constant instruction");
812 }
813
814 typeMap[operands[0]] = spirv::ArrayType::get(
815 elementTy, count, typeDecorations.lookup(operands[0]));
816 return success();
817 }
818
819 LogicalResult
processFunctionType(ArrayRef<uint32_t> operands)820 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
821 assert(!operands.empty() && "No operands for processing function type");
822 if (operands.size() == 1) {
823 return emitError(unknownLoc, "missing return type for OpTypeFunction");
824 }
825 auto returnType = getType(operands[1]);
826 if (!returnType) {
827 return emitError(unknownLoc, "unknown return type in OpTypeFunction");
828 }
829 SmallVector<Type, 1> argTypes;
830 for (size_t i = 2, e = operands.size(); i < e; ++i) {
831 auto ty = getType(operands[i]);
832 if (!ty) {
833 return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
834 }
835 argTypes.push_back(ty);
836 }
837 ArrayRef<Type> returnTypes;
838 if (!isVoidType(returnType)) {
839 returnTypes = llvm::makeArrayRef(returnType);
840 }
841 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
842 return success();
843 }
844
845 LogicalResult
processCooperativeMatrixType(ArrayRef<uint32_t> operands)846 spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
847 if (operands.size() != 5) {
848 return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
849 "type and row x column parameters");
850 }
851
852 Type elementTy = getType(operands[1]);
853 if (!elementTy) {
854 return emitError(unknownLoc,
855 "OpTypeCooperativeMatrix references undefined <id> ")
856 << operands[1];
857 }
858
859 auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
860 if (!scope) {
861 return emitError(unknownLoc,
862 "OpTypeCooperativeMatrix references undefined scope <id> ")
863 << operands[2];
864 }
865
866 unsigned rows = getConstantInt(operands[3]).getInt();
867 unsigned columns = getConstantInt(operands[4]).getInt();
868
869 typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
870 elementTy, scope.getValue(), rows, columns);
871 return success();
872 }
873
874 LogicalResult
processRuntimeArrayType(ArrayRef<uint32_t> operands)875 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
876 if (operands.size() != 2) {
877 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
878 }
879 Type memberType = getType(operands[1]);
880 if (!memberType) {
881 return emitError(unknownLoc,
882 "OpTypeRuntimeArray references undefined <id> ")
883 << operands[1];
884 }
885 typeMap[operands[0]] = spirv::RuntimeArrayType::get(
886 memberType, typeDecorations.lookup(operands[0]));
887 return success();
888 }
889
890 LogicalResult
processStructType(ArrayRef<uint32_t> operands)891 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
892 // TODO: Find a way to handle identified structs when debug info is stripped.
893
894 if (operands.empty()) {
895 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
896 }
897
898 if (operands.size() == 1) {
899 // Handle empty struct.
900 typeMap[operands[0]] =
901 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
902 return success();
903 }
904
905 // First element is operand ID, second element is member index in the struct.
906 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
907 SmallVector<Type, 4> memberTypes;
908
909 for (auto op : llvm::drop_begin(operands, 1)) {
910 Type memberType = getType(op);
911 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
912
913 if (!memberType && !typeForwardPtr)
914 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
915 << op;
916
917 if (!memberType)
918 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
919
920 memberTypes.push_back(memberType);
921 }
922
923 SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
924 SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
925 if (memberDecorationMap.count(operands[0])) {
926 auto &allMemberDecorations = memberDecorationMap[operands[0]];
927 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
928 if (allMemberDecorations.count(memberIndex)) {
929 for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
930 // Check for offset.
931 if (memberDecoration.first == spirv::Decoration::Offset) {
932 // If offset info is empty, resize to the number of members;
933 if (offsetInfo.empty()) {
934 offsetInfo.resize(memberTypes.size());
935 }
936 offsetInfo[memberIndex] = memberDecoration.second[0];
937 } else {
938 if (!memberDecoration.second.empty()) {
939 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
940 memberDecoration.first,
941 memberDecoration.second[0]);
942 } else {
943 memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
944 memberDecoration.first, 0);
945 }
946 }
947 }
948 }
949 }
950 }
951
952 uint32_t structID = operands[0];
953 std::string structIdentifier = nameMap.lookup(structID).str();
954
955 if (structIdentifier.empty()) {
956 assert(unresolvedMemberTypes.empty() &&
957 "didn't expect unresolved member types");
958 typeMap[structID] =
959 spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
960 } else {
961 auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
962 typeMap[structID] = structTy;
963
964 if (!unresolvedMemberTypes.empty())
965 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
966 memberTypes, offsetInfo,
967 memberDecorationsInfo});
968 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
969 memberDecorationsInfo)))
970 return failure();
971 }
972
973 // TODO: Update StructType to have member name as attribute as
974 // well.
975 return success();
976 }
977
978 LogicalResult
processMatrixType(ArrayRef<uint32_t> operands)979 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
980 if (operands.size() != 3) {
981 // Three operands are needed: result_id, column_type, and column_count
982 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
983 " (result_id, column_type, and column_count)");
984 }
985 // Matrix columns must be of vector type
986 Type elementTy = getType(operands[1]);
987 if (!elementTy) {
988 return emitError(unknownLoc,
989 "OpTypeMatrix references undefined column type.")
990 << operands[1];
991 }
992
993 uint32_t colsCount = operands[2];
994 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
995 return success();
996 }
997
998 LogicalResult
processTypeForwardPointer(ArrayRef<uint32_t> operands)999 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1000 if (operands.size() != 2)
1001 return emitError(unknownLoc,
1002 "OpTypeForwardPointer instruction must have two operands");
1003
1004 typeForwardPointerIDs.insert(operands[0]);
1005 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1006 // instruction that defines the actual type.
1007
1008 return success();
1009 }
1010
1011 LogicalResult
processImageType(ArrayRef<uint32_t> operands)1012 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1013 // TODO: Add support for Access Qualifier.
1014 if (operands.size() != 8)
1015 return emitError(
1016 unknownLoc,
1017 "OpTypeImage with non-eight operands are not supported yet");
1018
1019 Type elementTy = getType(operands[1]);
1020 if (!elementTy)
1021 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1022 << operands[1];
1023
1024 auto dim = spirv::symbolizeDim(operands[2]);
1025 if (!dim)
1026 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1027 << operands[2];
1028
1029 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1030 if (!depthInfo)
1031 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1032 << operands[3];
1033
1034 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1035 if (!arrayedInfo)
1036 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1037 << operands[4];
1038
1039 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1040 if (!samplingInfo)
1041 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1042
1043 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1044 if (!samplerUseInfo)
1045 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1046 << operands[6];
1047
1048 auto format = spirv::symbolizeImageFormat(operands[7]);
1049 if (!format)
1050 return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1051 << operands[7];
1052
1053 typeMap[operands[0]] = spirv::ImageType::get(
1054 elementTy, dim.getValue(), depthInfo.getValue(), arrayedInfo.getValue(),
1055 samplingInfo.getValue(), samplerUseInfo.getValue(), format.getValue());
1056 return success();
1057 }
1058
1059 LogicalResult
processSampledImageType(ArrayRef<uint32_t> operands)1060 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1061 if (operands.size() != 2)
1062 return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1063
1064 Type elementTy = getType(operands[1]);
1065 if (!elementTy)
1066 return emitError(unknownLoc,
1067 "OpTypeSampledImage references undefined <id>: ")
1068 << operands[1];
1069
1070 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1071 return success();
1072 }
1073
1074 //===----------------------------------------------------------------------===//
1075 // Constant
1076 //===----------------------------------------------------------------------===//
1077
processConstant(ArrayRef<uint32_t> operands,bool isSpec)1078 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1079 bool isSpec) {
1080 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1081
1082 if (operands.size() < 2) {
1083 return emitError(unknownLoc)
1084 << opname << " must have type <id> and result <id>";
1085 }
1086 if (operands.size() < 3) {
1087 return emitError(unknownLoc)
1088 << opname << " must have at least 1 more parameter";
1089 }
1090
1091 Type resultType = getType(operands[0]);
1092 if (!resultType) {
1093 return emitError(unknownLoc, "undefined result type from <id> ")
1094 << operands[0];
1095 }
1096
1097 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1098 if (bitwidth == 64) {
1099 if (operands.size() == 4) {
1100 return success();
1101 }
1102 return emitError(unknownLoc)
1103 << opname << " should have 2 parameters for 64-bit values";
1104 }
1105 if (bitwidth <= 32) {
1106 if (operands.size() == 3) {
1107 return success();
1108 }
1109
1110 return emitError(unknownLoc)
1111 << opname
1112 << " should have 1 parameter for values with no more than 32 bits";
1113 }
1114 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1115 << bitwidth;
1116 };
1117
1118 auto resultID = operands[1];
1119
1120 if (auto intType = resultType.dyn_cast<IntegerType>()) {
1121 auto bitwidth = intType.getWidth();
1122 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1123 return failure();
1124 }
1125
1126 APInt value;
1127 if (bitwidth == 64) {
1128 // 64-bit integers are represented with two SPIR-V words. According to
1129 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1130 // literal’s low-order words appear first."
1131 struct DoubleWord {
1132 uint32_t word1;
1133 uint32_t word2;
1134 } words = {operands[2], operands[3]};
1135 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1136 } else if (bitwidth <= 32) {
1137 value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1138 }
1139
1140 auto attr = opBuilder.getIntegerAttr(intType, value);
1141
1142 if (isSpec) {
1143 createSpecConstant(unknownLoc, resultID, attr);
1144 } else {
1145 // For normal constants, we just record the attribute (and its type) for
1146 // later materialization at use sites.
1147 constantMap.try_emplace(resultID, attr, intType);
1148 }
1149
1150 return success();
1151 }
1152
1153 if (auto floatType = resultType.dyn_cast<FloatType>()) {
1154 auto bitwidth = floatType.getWidth();
1155 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1156 return failure();
1157 }
1158
1159 APFloat value(0.f);
1160 if (floatType.isF64()) {
1161 // Double values are represented with two SPIR-V words. According to
1162 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1163 // literal’s low-order words appear first."
1164 struct DoubleWord {
1165 uint32_t word1;
1166 uint32_t word2;
1167 } words = {operands[2], operands[3]};
1168 value = APFloat(llvm::bit_cast<double>(words));
1169 } else if (floatType.isF32()) {
1170 value = APFloat(llvm::bit_cast<float>(operands[2]));
1171 } else if (floatType.isF16()) {
1172 APInt data(16, operands[2]);
1173 value = APFloat(APFloat::IEEEhalf(), data);
1174 }
1175
1176 auto attr = opBuilder.getFloatAttr(floatType, value);
1177 if (isSpec) {
1178 createSpecConstant(unknownLoc, resultID, attr);
1179 } else {
1180 // For normal constants, we just record the attribute (and its type) for
1181 // later materialization at use sites.
1182 constantMap.try_emplace(resultID, attr, floatType);
1183 }
1184
1185 return success();
1186 }
1187
1188 return emitError(unknownLoc, "OpConstant can only generate values of "
1189 "scalar integer or floating-point type");
1190 }
1191
processConstantBool(bool isTrue,ArrayRef<uint32_t> operands,bool isSpec)1192 LogicalResult spirv::Deserializer::processConstantBool(
1193 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1194 if (operands.size() != 2) {
1195 return emitError(unknownLoc, "Op")
1196 << (isSpec ? "Spec" : "") << "Constant"
1197 << (isTrue ? "True" : "False")
1198 << " must have type <id> and result <id>";
1199 }
1200
1201 auto attr = opBuilder.getBoolAttr(isTrue);
1202 auto resultID = operands[1];
1203 if (isSpec) {
1204 createSpecConstant(unknownLoc, resultID, attr);
1205 } else {
1206 // For normal constants, we just record the attribute (and its type) for
1207 // later materialization at use sites.
1208 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1209 }
1210
1211 return success();
1212 }
1213
1214 LogicalResult
processConstantComposite(ArrayRef<uint32_t> operands)1215 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1216 if (operands.size() < 2) {
1217 return emitError(unknownLoc,
1218 "OpConstantComposite must have type <id> and result <id>");
1219 }
1220 if (operands.size() < 3) {
1221 return emitError(unknownLoc,
1222 "OpConstantComposite must have at least 1 parameter");
1223 }
1224
1225 Type resultType = getType(operands[0]);
1226 if (!resultType) {
1227 return emitError(unknownLoc, "undefined result type from <id> ")
1228 << operands[0];
1229 }
1230
1231 SmallVector<Attribute, 4> elements;
1232 elements.reserve(operands.size() - 2);
1233 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1234 auto elementInfo = getConstant(operands[i]);
1235 if (!elementInfo) {
1236 return emitError(unknownLoc, "OpConstantComposite component <id> ")
1237 << operands[i] << " must come from a normal constant";
1238 }
1239 elements.push_back(elementInfo->first);
1240 }
1241
1242 auto resultID = operands[1];
1243 if (auto vectorType = resultType.dyn_cast<VectorType>()) {
1244 auto attr = DenseElementsAttr::get(vectorType, elements);
1245 // For normal constants, we just record the attribute (and its type) for
1246 // later materialization at use sites.
1247 constantMap.try_emplace(resultID, attr, resultType);
1248 } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
1249 auto attr = opBuilder.getArrayAttr(elements);
1250 constantMap.try_emplace(resultID, attr, resultType);
1251 } else {
1252 return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1253 << resultType;
1254 }
1255
1256 return success();
1257 }
1258
1259 LogicalResult
processSpecConstantComposite(ArrayRef<uint32_t> operands)1260 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1261 if (operands.size() < 2) {
1262 return emitError(unknownLoc,
1263 "OpConstantComposite must have type <id> and result <id>");
1264 }
1265 if (operands.size() < 3) {
1266 return emitError(unknownLoc,
1267 "OpConstantComposite must have at least 1 parameter");
1268 }
1269
1270 Type resultType = getType(operands[0]);
1271 if (!resultType) {
1272 return emitError(unknownLoc, "undefined result type from <id> ")
1273 << operands[0];
1274 }
1275
1276 auto resultID = operands[1];
1277 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1278
1279 SmallVector<Attribute, 4> elements;
1280 elements.reserve(operands.size() - 2);
1281 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1282 auto elementInfo = getSpecConstant(operands[i]);
1283 elements.push_back(SymbolRefAttr::get(elementInfo));
1284 }
1285
1286 auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1287 unknownLoc, TypeAttr::get(resultType), symName,
1288 opBuilder.getArrayAttr(elements));
1289 specConstCompositeMap[resultID] = op;
1290
1291 return success();
1292 }
1293
1294 LogicalResult
processSpecConstantOperation(ArrayRef<uint32_t> operands)1295 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1296 if (operands.size() < 3)
1297 return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1298 "result <id>, and operand opcode");
1299
1300 uint32_t resultTypeID = operands[0];
1301
1302 if (!getType(resultTypeID))
1303 return emitError(unknownLoc, "undefined result type from <id> ")
1304 << resultTypeID;
1305
1306 uint32_t resultID = operands[1];
1307 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1308 auto emplaceResult = specConstOperationMap.try_emplace(
1309 resultID,
1310 SpecConstOperationMaterializationInfo{
1311 enclosedOpcode, resultTypeID,
1312 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1313
1314 if (!emplaceResult.second)
1315 return emitError(unknownLoc, "value with <id>: ")
1316 << resultID << " is probably defined before.";
1317
1318 return success();
1319 }
1320
materializeSpecConstantOperation(uint32_t resultID,spirv::Opcode enclosedOpcode,uint32_t resultTypeID,ArrayRef<uint32_t> enclosedOpOperands)1321 Value spirv::Deserializer::materializeSpecConstantOperation(
1322 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1323 ArrayRef<uint32_t> enclosedOpOperands) {
1324
1325 Type resultType = getType(resultTypeID);
1326
1327 // Instructions wrapped by OpSpecConstantOp need an ID for their
1328 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1329 // dialect wrapped op. For that purpose, a new value map is created and "fake"
1330 // ID in that map is assigned to the result of the enclosed instruction. Note
1331 // that there is no need to update this fake ID since we only need to
1332 // reference the created Value for the enclosed op from the spv::YieldOp
1333 // created later in this method (both of which are the only values in their
1334 // region: the SpecConstantOperation's region). If we encounter another
1335 // SpecConstantOperation in the module, we simply re-use the fake ID since the
1336 // previous Value assigned to it isn't visible in the current scope anyway.
1337 DenseMap<uint32_t, Value> newValueMap;
1338 llvm::SaveAndRestore<DenseMap<uint32_t, Value>> valueMapGuard(valueMap,
1339 newValueMap);
1340 constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1341
1342 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1343 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1344 enclosedOpResultTypeAndOperands.push_back(fakeID);
1345 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1346 enclosedOpOperands.end());
1347
1348 // Process enclosed instruction before creating the enclosing
1349 // specConstantOperation (and its region). This way, references to constants,
1350 // global variables, and spec constants will be materialized outside the new
1351 // op's region. For more info, see Deserializer::getValue's implementation.
1352 if (failed(
1353 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1354 return Value();
1355
1356 // Since the enclosed op is emitted in the current block, split it in a
1357 // separate new block.
1358 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
1359
1360 auto loc = createFileLineColLoc(opBuilder);
1361 auto specConstOperationOp =
1362 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1363
1364 Region &body = specConstOperationOp.body();
1365 // Move the new block into SpecConstantOperation's body.
1366 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1367 Region::iterator(enclosedBlock));
1368 Block &block = body.back();
1369
1370 // RAII guard to reset the insertion point to the module's region after
1371 // deserializing the body of the specConstantOperation.
1372 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
1373 opBuilder.setInsertionPointToEnd(&block);
1374
1375 opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
1376 return specConstOperationOp.getResult();
1377 }
1378
1379 LogicalResult
processConstantNull(ArrayRef<uint32_t> operands)1380 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1381 if (operands.size() != 2) {
1382 return emitError(unknownLoc,
1383 "OpConstantNull must have type <id> and result <id>");
1384 }
1385
1386 Type resultType = getType(operands[0]);
1387 if (!resultType) {
1388 return emitError(unknownLoc, "undefined result type from <id> ")
1389 << operands[0];
1390 }
1391
1392 auto resultID = operands[1];
1393 if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) {
1394 auto attr = opBuilder.getZeroAttr(resultType);
1395 // For normal constants, we just record the attribute (and its type) for
1396 // later materialization at use sites.
1397 constantMap.try_emplace(resultID, attr, resultType);
1398 return success();
1399 }
1400
1401 return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1402 << resultType;
1403 }
1404
1405 //===----------------------------------------------------------------------===//
1406 // Control flow
1407 //===----------------------------------------------------------------------===//
1408
getOrCreateBlock(uint32_t id)1409 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1410 if (auto *block = getBlock(id)) {
1411 LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id
1412 << " @ " << block << "\n");
1413 return block;
1414 }
1415
1416 // We don't know where this block will be placed finally (in a
1417 // spv.mlir.selection or spv.mlir.loop or function). Create it into the
1418 // function for now and sort out the proper place later.
1419 auto *block = curFunction->addBlock();
1420 LLVM_DEBUG(llvm::dbgs() << "[block] created block for id = " << id << " @ "
1421 << block << "\n");
1422 return blockMap[id] = block;
1423 }
1424
processBranch(ArrayRef<uint32_t> operands)1425 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1426 if (!curBlock) {
1427 return emitError(unknownLoc, "OpBranch must appear inside a block");
1428 }
1429
1430 if (operands.size() != 1) {
1431 return emitError(unknownLoc, "OpBranch must take exactly one target label");
1432 }
1433
1434 auto *target = getOrCreateBlock(operands[0]);
1435 auto loc = createFileLineColLoc(opBuilder);
1436 // The preceding instruction for the OpBranch instruction could be an
1437 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1438 // the same OpLine information.
1439 opBuilder.create<spirv::BranchOp>(loc, target);
1440
1441 (void)clearDebugLine();
1442 return success();
1443 }
1444
1445 LogicalResult
processBranchConditional(ArrayRef<uint32_t> operands)1446 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1447 if (!curBlock) {
1448 return emitError(unknownLoc,
1449 "OpBranchConditional must appear inside a block");
1450 }
1451
1452 if (operands.size() != 3 && operands.size() != 5) {
1453 return emitError(unknownLoc,
1454 "OpBranchConditional must have condition, true label, "
1455 "false label, and optionally two branch weights");
1456 }
1457
1458 auto condition = getValue(operands[0]);
1459 auto *trueBlock = getOrCreateBlock(operands[1]);
1460 auto *falseBlock = getOrCreateBlock(operands[2]);
1461
1462 Optional<std::pair<uint32_t, uint32_t>> weights;
1463 if (operands.size() == 5) {
1464 weights = std::make_pair(operands[3], operands[4]);
1465 }
1466 // The preceding instruction for the OpBranchConditional instruction could be
1467 // an OpSelectionMerge instruction, in this case they will have the same
1468 // OpLine information.
1469 auto loc = createFileLineColLoc(opBuilder);
1470 opBuilder.create<spirv::BranchConditionalOp>(
1471 loc, condition, trueBlock,
1472 /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1473 /*falseArguments=*/ArrayRef<Value>(), weights);
1474
1475 (void)clearDebugLine();
1476 return success();
1477 }
1478
processLabel(ArrayRef<uint32_t> operands)1479 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1480 if (!curFunction) {
1481 return emitError(unknownLoc, "OpLabel must appear inside a function");
1482 }
1483
1484 if (operands.size() != 1) {
1485 return emitError(unknownLoc, "OpLabel should only have result <id>");
1486 }
1487
1488 auto labelID = operands[0];
1489 // We may have forward declared this block.
1490 auto *block = getOrCreateBlock(labelID);
1491 LLVM_DEBUG(llvm::dbgs() << "[block] populating block " << block << "\n");
1492 // If we have seen this block, make sure it was just a forward declaration.
1493 assert(block->empty() && "re-deserialize the same block!");
1494
1495 opBuilder.setInsertionPointToStart(block);
1496 blockMap[labelID] = curBlock = block;
1497
1498 return success();
1499 }
1500
1501 LogicalResult
processSelectionMerge(ArrayRef<uint32_t> operands)1502 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1503 if (!curBlock) {
1504 return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1505 }
1506
1507 if (operands.size() < 2) {
1508 return emitError(
1509 unknownLoc,
1510 "OpSelectionMerge must specify merge target and selection control");
1511 }
1512
1513 auto *mergeBlock = getOrCreateBlock(operands[0]);
1514 auto loc = createFileLineColLoc(opBuilder);
1515 auto selectionControl = operands[1];
1516
1517 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1518 .second) {
1519 return emitError(
1520 unknownLoc,
1521 "a block cannot have more than one OpSelectionMerge instruction");
1522 }
1523
1524 return success();
1525 }
1526
1527 LogicalResult
processLoopMerge(ArrayRef<uint32_t> operands)1528 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1529 if (!curBlock) {
1530 return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1531 }
1532
1533 if (operands.size() < 3) {
1534 return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1535 "continue target and loop control");
1536 }
1537
1538 auto *mergeBlock = getOrCreateBlock(operands[0]);
1539 auto *continueBlock = getOrCreateBlock(operands[1]);
1540 auto loc = createFileLineColLoc(opBuilder);
1541 uint32_t loopControl = operands[2];
1542
1543 if (!blockMergeInfo
1544 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1545 .second) {
1546 return emitError(
1547 unknownLoc,
1548 "a block cannot have more than one OpLoopMerge instruction");
1549 }
1550
1551 return success();
1552 }
1553
processPhi(ArrayRef<uint32_t> operands)1554 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1555 if (!curBlock) {
1556 return emitError(unknownLoc, "OpPhi must appear in a block");
1557 }
1558
1559 if (operands.size() < 4) {
1560 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1561 "and variable-parent pairs");
1562 }
1563
1564 // Create a block argument for this OpPhi instruction.
1565 Type blockArgType = getType(operands[0]);
1566 BlockArgument blockArg = curBlock->addArgument(blockArgType);
1567 valueMap[operands[1]] = blockArg;
1568 LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg
1569 << " id = " << operands[1] << " of type "
1570 << blockArgType << '\n');
1571
1572 // For each (value, predecessor) pair, insert the value to the predecessor's
1573 // blockPhiInfo entry so later we can fix the block argument there.
1574 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1575 uint32_t value = operands[i];
1576 Block *predecessor = getOrCreateBlock(operands[i + 1]);
1577 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1578 blockPhiInfo[predecessorTargetPair].push_back(value);
1579 LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor
1580 << " with arg id = " << value << '\n');
1581 }
1582
1583 return success();
1584 }
1585
1586 namespace {
1587 /// A class for putting all blocks in a structured selection/loop in a
1588 /// spv.mlir.selection/spv.mlir.loop op.
1589 class ControlFlowStructurizer {
1590 public:
1591 /// Structurizes the loop at the given `headerBlock`.
1592 ///
1593 /// This method will create an spv.mlir.loop op in the `mergeBlock` and move
1594 /// all blocks in the structured loop into the spv.mlir.loop's region. All
1595 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1596 /// method will also update `mergeInfo` by remapping all blocks inside to the
1597 /// newly cloned ones inside structured control flow op's regions.
structurize(Location loc,uint32_t control,spirv::BlockMergeInfoMap & mergeInfo,Block * headerBlock,Block * mergeBlock,Block * continueBlock)1598 static LogicalResult structurize(Location loc, uint32_t control,
1599 spirv::BlockMergeInfoMap &mergeInfo,
1600 Block *headerBlock, Block *mergeBlock,
1601 Block *continueBlock) {
1602 return ControlFlowStructurizer(loc, control, mergeInfo, headerBlock,
1603 mergeBlock, continueBlock)
1604 .structurizeImpl();
1605 }
1606
1607 private:
ControlFlowStructurizer(Location loc,uint32_t control,spirv::BlockMergeInfoMap & mergeInfo,Block * header,Block * merge,Block * cont)1608 ControlFlowStructurizer(Location loc, uint32_t control,
1609 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1610 Block *merge, Block *cont)
1611 : location(loc), control(control), blockMergeInfo(mergeInfo),
1612 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1613
1614 /// Creates a new spv.mlir.selection op at the beginning of the `mergeBlock`.
1615 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1616
1617 /// Creates a new spv.mlir.loop op at the beginning of the `mergeBlock`.
1618 spirv::LoopOp createLoopOp(uint32_t loopControl);
1619
1620 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1621 void collectBlocksInConstruct();
1622
1623 LogicalResult structurizeImpl();
1624
1625 Location location;
1626 uint32_t control;
1627
1628 spirv::BlockMergeInfoMap &blockMergeInfo;
1629
1630 Block *headerBlock;
1631 Block *mergeBlock;
1632 Block *continueBlock; // nullptr for spv.mlir.selection
1633
1634 SetVector<Block *> constructBlocks;
1635 };
1636 } // namespace
1637
1638 spirv::SelectionOp
createSelectionOp(uint32_t selectionControl)1639 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1640 // Create a builder and set the insertion point to the beginning of the
1641 // merge block so that the newly created SelectionOp will be inserted there.
1642 OpBuilder builder(&mergeBlock->front());
1643
1644 auto control = static_cast<spirv::SelectionControl>(selectionControl);
1645 auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1646 selectionOp.addMergeBlock();
1647
1648 return selectionOp;
1649 }
1650
createLoopOp(uint32_t loopControl)1651 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1652 // Create a builder and set the insertion point to the beginning of the
1653 // merge block so that the newly created LoopOp will be inserted there.
1654 OpBuilder builder(&mergeBlock->front());
1655
1656 auto control = static_cast<spirv::LoopControl>(loopControl);
1657 auto loopOp = builder.create<spirv::LoopOp>(location, control);
1658 loopOp.addEntryAndMergeBlock();
1659
1660 return loopOp;
1661 }
1662
collectBlocksInConstruct()1663 void ControlFlowStructurizer::collectBlocksInConstruct() {
1664 assert(constructBlocks.empty() && "expected empty constructBlocks");
1665
1666 // Put the header block in the work list first.
1667 constructBlocks.insert(headerBlock);
1668
1669 // For each item in the work list, add its successors excluding the merge
1670 // block.
1671 for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1672 for (auto *successor : constructBlocks[i]->getSuccessors())
1673 if (successor != mergeBlock)
1674 constructBlocks.insert(successor);
1675 }
1676 }
1677
structurizeImpl()1678 LogicalResult ControlFlowStructurizer::structurizeImpl() {
1679 Operation *op = nullptr;
1680 bool isLoop = continueBlock != nullptr;
1681 if (isLoop) {
1682 if (auto loopOp = createLoopOp(control))
1683 op = loopOp.getOperation();
1684 } else {
1685 if (auto selectionOp = createSelectionOp(control))
1686 op = selectionOp.getOperation();
1687 }
1688 if (!op)
1689 return failure();
1690 Region &body = op->getRegion(0);
1691
1692 BlockAndValueMapping mapper;
1693 // All references to the old merge block should be directed to the
1694 // selection/loop merge block in the SelectionOp/LoopOp's region.
1695 mapper.map(mergeBlock, &body.back());
1696
1697 collectBlocksInConstruct();
1698
1699 // We've identified all blocks belonging to the selection/loop's region. Now
1700 // need to "move" them into the selection/loop. Instead of really moving the
1701 // blocks, in the following we copy them and remap all values and branches.
1702 // This is because:
1703 // * Inserting a block into a region requires the block not in any region
1704 // before. But selections/loops can nest so we can create selection/loop ops
1705 // in a nested manner, which means some blocks may already be in a
1706 // selection/loop region when to be moved again.
1707 // * It's much trickier to fix up the branches into and out of the loop's
1708 // region: we need to treat not-moved blocks and moved blocks differently:
1709 // Not-moved blocks jumping to the loop header block need to jump to the
1710 // merge point containing the new loop op but not the loop continue block's
1711 // back edge. Moved blocks jumping out of the loop need to jump to the
1712 // merge block inside the loop region but not other not-moved blocks.
1713 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
1714 // logic.
1715
1716 // Create a corresponding block in the SelectionOp/LoopOp's region for each
1717 // block in this loop construct.
1718 OpBuilder builder(body);
1719 for (auto *block : constructBlocks) {
1720 // Create a block and insert it before the selection/loop merge block in the
1721 // SelectionOp/LoopOp's region.
1722 auto *newBlock = builder.createBlock(&body.back());
1723 mapper.map(block, newBlock);
1724 LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock
1725 << " from block " << block << "\n");
1726 if (!isFnEntryBlock(block)) {
1727 for (BlockArgument blockArg : block->getArguments()) {
1728 auto newArg = newBlock->addArgument(blockArg.getType());
1729 mapper.map(blockArg, newArg);
1730 LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg
1731 << " to " << newArg << '\n');
1732 }
1733 } else {
1734 LLVM_DEBUG(llvm::dbgs()
1735 << "[cf] block " << block << " is a function entry block\n");
1736 }
1737 for (auto &op : *block)
1738 newBlock->push_back(op.clone(mapper));
1739 }
1740
1741 // Go through all ops and remap the operands.
1742 auto remapOperands = [&](Operation *op) {
1743 for (auto &operand : op->getOpOperands())
1744 if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1745 operand.set(mappedOp);
1746 for (auto &succOp : op->getBlockOperands())
1747 if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1748 succOp.set(mappedOp);
1749 };
1750 for (auto &block : body) {
1751 block.walk(remapOperands);
1752 }
1753
1754 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1755 // the selection/loop construct into its region. Next we need to fix the
1756 // connections between this new SelectionOp/LoopOp with existing blocks.
1757
1758 // All existing incoming branches should go to the merge block, where the
1759 // SelectionOp/LoopOp resides right now.
1760 headerBlock->replaceAllUsesWith(mergeBlock);
1761
1762 if (isLoop) {
1763 // The loop selection/loop header block may have block arguments. Since now
1764 // we place the selection/loop op inside the old merge block, we need to
1765 // make sure the old merge block has the same block argument list.
1766 assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported");
1767 for (BlockArgument blockArg : headerBlock->getArguments()) {
1768 mergeBlock->addArgument(blockArg.getType());
1769 }
1770
1771 // If the loop header block has block arguments, make sure the spv.branch op
1772 // matches.
1773 SmallVector<Value, 4> blockArgs;
1774 if (!headerBlock->args_empty())
1775 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1776
1777 // The loop entry block should have a unconditional branch jumping to the
1778 // loop header block.
1779 builder.setInsertionPointToEnd(&body.front());
1780 builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1781 ArrayRef<Value>(blockArgs));
1782 }
1783
1784 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1785 // cleaned up.
1786 LLVM_DEBUG(llvm::dbgs() << "[cf] cleaning up blocks after clone\n");
1787 // First we need to drop all operands' references inside all blocks. This is
1788 // needed because we can have blocks referencing SSA values from one another.
1789 for (auto *block : constructBlocks)
1790 block->dropAllReferences();
1791
1792 // Then erase all old blocks.
1793 for (auto *block : constructBlocks) {
1794 // We've cloned all blocks belonging to this construct into the structured
1795 // control flow op's region. Among these blocks, some may compose another
1796 // selection/loop. If so, they will be recorded within blockMergeInfo.
1797 // We need to update the pointers there to the newly remapped ones so we can
1798 // continue structurizing them later.
1799 // TODO: The asserts in the following assumes input SPIR-V blob
1800 // forms correctly nested selection/loop constructs. We should relax this
1801 // and support error cases better.
1802 auto it = blockMergeInfo.find(block);
1803 if (it != blockMergeInfo.end()) {
1804 Block *newHeader = mapper.lookupOrNull(block);
1805 assert(newHeader && "nested loop header block should be remapped!");
1806
1807 Block *newContinue = it->second.continueBlock;
1808 if (newContinue) {
1809 newContinue = mapper.lookupOrNull(newContinue);
1810 assert(newContinue && "nested loop continue block should be remapped!");
1811 }
1812
1813 Block *newMerge = it->second.mergeBlock;
1814 if (Block *mappedTo = mapper.lookupOrNull(newMerge))
1815 newMerge = mappedTo;
1816
1817 // Keep original location for nested selection/loop ops.
1818 Location loc = it->second.loc;
1819 // The iterator should be erased before adding a new entry into
1820 // blockMergeInfo to avoid iterator invalidation.
1821 blockMergeInfo.erase(it);
1822 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
1823 newContinue);
1824 }
1825
1826 // The structured selection/loop's entry block does not have arguments.
1827 // If the function's header block is also part of the structured control
1828 // flow, we cannot just simply erase it because it may contain arguments
1829 // matching the function signature and used by the cloned blocks.
1830 if (isFnEntryBlock(block)) {
1831 LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block
1832 << " to only contain a spv.Branch op\n");
1833 // Still keep the function entry block for the potential block arguments,
1834 // but replace all ops inside with a branch to the merge block.
1835 block->clear();
1836 builder.setInsertionPointToEnd(block);
1837 builder.create<spirv::BranchOp>(location, mergeBlock);
1838 } else {
1839 LLVM_DEBUG(llvm::dbgs() << "[cf] erasing block " << block << "\n");
1840 block->erase();
1841 }
1842 }
1843
1844 LLVM_DEBUG(
1845 llvm::dbgs() << "[cf] after structurizing construct with header block "
1846 << headerBlock << ":\n"
1847 << *op << '\n');
1848
1849 return success();
1850 }
1851
wireUpBlockArgument()1852 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
1853 LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n");
1854
1855 OpBuilder::InsertionGuard guard(opBuilder);
1856
1857 for (const auto &info : blockPhiInfo) {
1858 Block *block = info.first.first;
1859 Block *target = info.first.second;
1860 const BlockPhiInfo &phiInfo = info.second;
1861 LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n");
1862 LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n");
1863 LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs()));
1864 LLVM_DEBUG(llvm::dbgs() << '\n');
1865
1866 // Set insertion point to before this block's terminator early because we
1867 // may materialize ops via getValue() call.
1868 auto *op = block->getTerminator();
1869 opBuilder.setInsertionPoint(op);
1870
1871 SmallVector<Value, 4> blockArgs;
1872 blockArgs.reserve(phiInfo.size());
1873 for (uint32_t valueId : phiInfo) {
1874 if (Value value = getValue(valueId)) {
1875 blockArgs.push_back(value);
1876 LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value
1877 << " id = " << valueId << '\n');
1878 } else {
1879 return emitError(unknownLoc, "OpPhi references undefined value!");
1880 }
1881 }
1882
1883 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
1884 // Replace the previous branch op with a new one with block arguments.
1885 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
1886 blockArgs);
1887 branchOp.erase();
1888 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
1889 assert((branchCondOp.getTrueBlock() == target ||
1890 branchCondOp.getFalseBlock() == target) &&
1891 "expected target to be either the true or false target");
1892 if (target == branchCondOp.trueTarget())
1893 opBuilder.create<spirv::BranchConditionalOp>(
1894 branchCondOp.getLoc(), branchCondOp.condition(), blockArgs,
1895 branchCondOp.getFalseBlockArguments(),
1896 branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(),
1897 branchCondOp.falseTarget());
1898 else
1899 opBuilder.create<spirv::BranchConditionalOp>(
1900 branchCondOp.getLoc(), branchCondOp.condition(),
1901 branchCondOp.getTrueBlockArguments(), blockArgs,
1902 branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(),
1903 branchCondOp.getFalseBlock());
1904
1905 branchCondOp.erase();
1906 } else {
1907 return emitError(unknownLoc, "unimplemented terminator for Phi creation");
1908 }
1909
1910 LLVM_DEBUG(llvm::dbgs() << "[phi] after creating block argument:\n");
1911 LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs()));
1912 LLVM_DEBUG(llvm::dbgs() << '\n');
1913 }
1914 blockPhiInfo.clear();
1915
1916 LLVM_DEBUG(llvm::dbgs() << "[phi] completed wiring up block arguments\n");
1917 return success();
1918 }
1919
structurizeControlFlow()1920 LogicalResult spirv::Deserializer::structurizeControlFlow() {
1921 LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n");
1922
1923 while (!blockMergeInfo.empty()) {
1924 Block *headerBlock = blockMergeInfo.begin()->first;
1925 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
1926
1927 LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n");
1928 LLVM_DEBUG(headerBlock->print(llvm::dbgs()));
1929
1930 auto *mergeBlock = mergeInfo.mergeBlock;
1931 assert(mergeBlock && "merge block cannot be nullptr");
1932 if (!mergeBlock->args_empty())
1933 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
1934 LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n");
1935 LLVM_DEBUG(mergeBlock->print(llvm::dbgs()));
1936
1937 auto *continueBlock = mergeInfo.continueBlock;
1938 if (continueBlock) {
1939 LLVM_DEBUG(llvm::dbgs()
1940 << "[cf] continue block " << continueBlock << ":\n");
1941 LLVM_DEBUG(continueBlock->print(llvm::dbgs()));
1942 }
1943 // Erase this case before calling into structurizer, who will update
1944 // blockMergeInfo.
1945 blockMergeInfo.erase(blockMergeInfo.begin());
1946 if (failed(ControlFlowStructurizer::structurize(
1947 mergeInfo.loc, mergeInfo.control, blockMergeInfo, headerBlock,
1948 mergeBlock, continueBlock)))
1949 return failure();
1950 }
1951
1952 LLVM_DEBUG(llvm::dbgs() << "[cf] completed structurizing control flow\n");
1953 return success();
1954 }
1955
1956 //===----------------------------------------------------------------------===//
1957 // Debug
1958 //===----------------------------------------------------------------------===//
1959
createFileLineColLoc(OpBuilder opBuilder)1960 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
1961 if (!debugLine)
1962 return unknownLoc;
1963
1964 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
1965 if (fileName.empty())
1966 fileName = "<unknown>";
1967 return FileLineColLoc::get(opBuilder.getIdentifier(fileName), debugLine->line,
1968 debugLine->col);
1969 }
1970
1971 LogicalResult
processDebugLine(ArrayRef<uint32_t> operands)1972 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
1973 // According to SPIR-V spec:
1974 // "This location information applies to the instructions physically
1975 // following this instruction, up to the first occurrence of any of the
1976 // following: the next end of block, the next OpLine instruction, or the next
1977 // OpNoLine instruction."
1978 if (operands.size() != 3)
1979 return emitError(unknownLoc, "OpLine must have 3 operands");
1980 debugLine = DebugLine(operands[0], operands[1], operands[2]);
1981 return success();
1982 }
1983
clearDebugLine()1984 LogicalResult spirv::Deserializer::clearDebugLine() {
1985 debugLine = llvm::None;
1986 return success();
1987 }
1988
1989 LogicalResult
processDebugString(ArrayRef<uint32_t> operands)1990 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
1991 if (operands.size() < 2)
1992 return emitError(unknownLoc, "OpString needs at least 2 operands");
1993
1994 if (!debugInfoMap.lookup(operands[0]).empty())
1995 return emitError(unknownLoc,
1996 "duplicate debug string found for result <id> ")
1997 << operands[0];
1998
1999 unsigned wordIndex = 1;
2000 StringRef debugString = decodeStringLiteral(operands, wordIndex);
2001 if (wordIndex != operands.size())
2002 return emitError(unknownLoc,
2003 "unexpected trailing words in OpString instruction");
2004
2005 debugInfoMap[operands[0]] = debugString;
2006 return success();
2007 }
2008