1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/Support/LogicalResult.h"
23 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/bit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "spirv-deserialization"
36 
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given `block` is a function entry block.
isFnEntryBlock(Block * block)42 static inline bool isFnEntryBlock(Block *block) {
43   return block->isEntryBlock() &&
44          isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50 
Deserializer(ArrayRef<uint32_t> binary,MLIRContext * context)51 spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
52                                   MLIRContext *context)
53     : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
54       module(createModuleOp()), opBuilder(module->getRegion()) {}
55 
deserialize()56 LogicalResult spirv::Deserializer::deserialize() {
57   LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n");
58 
59   if (failed(processHeader()))
60     return failure();
61 
62   spirv::Opcode opcode = spirv::Opcode::OpNop;
63   ArrayRef<uint32_t> operands;
64   auto binarySize = binary.size();
65   while (curOffset < binarySize) {
66     // Slice the next instruction out and populate `opcode` and `operands`.
67     // Internally this also updates `curOffset`.
68     if (failed(sliceInstruction(opcode, operands)))
69       return failure();
70 
71     if (failed(processInstruction(opcode, operands)))
72       return failure();
73   }
74 
75   assert(curOffset == binarySize &&
76          "deserializer should never index beyond the binary end");
77 
78   for (auto &deferred : deferredInstructions) {
79     if (failed(processInstruction(deferred.first, deferred.second, false))) {
80       return failure();
81     }
82   }
83 
84   attachVCETriple();
85 
86   LLVM_DEBUG(llvm::dbgs() << "+++ completed deserialization +++\n");
87   return success();
88 }
89 
collect()90 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
91   return std::move(module);
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // Module structure
96 //===----------------------------------------------------------------------===//
97 
createModuleOp()98 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
99   OpBuilder builder(context);
100   OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
101   spirv::ModuleOp::build(builder, state);
102   return cast<spirv::ModuleOp>(Operation::create(state));
103 }
104 
processHeader()105 LogicalResult spirv::Deserializer::processHeader() {
106   if (binary.size() < spirv::kHeaderWordCount)
107     return emitError(unknownLoc,
108                      "SPIR-V binary module must have a 5-word header");
109 
110   if (binary[0] != spirv::kMagicNumber)
111     return emitError(unknownLoc, "incorrect magic number");
112 
113   // Version number bytes: 0 | major number | minor number | 0
114   uint32_t majorVersion = (binary[1] << 8) >> 24;
115   uint32_t minorVersion = (binary[1] << 16) >> 24;
116   if (majorVersion == 1) {
117     switch (minorVersion) {
118 #define MIN_VERSION_CASE(v)                                                    \
119   case v:                                                                      \
120     version = spirv::Version::V_1_##v;                                         \
121     break
122 
123       MIN_VERSION_CASE(0);
124       MIN_VERSION_CASE(1);
125       MIN_VERSION_CASE(2);
126       MIN_VERSION_CASE(3);
127       MIN_VERSION_CASE(4);
128       MIN_VERSION_CASE(5);
129 #undef MIN_VERSION_CASE
130     default:
131       return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
132              << minorVersion;
133     }
134   } else {
135     return emitError(unknownLoc, "unsupported SPIR-V major version: ")
136            << majorVersion;
137   }
138 
139   // TODO: generator number, bound, schema
140   curOffset = spirv::kHeaderWordCount;
141   return success();
142 }
143 
144 LogicalResult
processCapability(ArrayRef<uint32_t> operands)145 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
146   if (operands.size() != 1)
147     return emitError(unknownLoc, "OpMemoryModel must have one parameter");
148 
149   auto cap = spirv::symbolizeCapability(operands[0]);
150   if (!cap)
151     return emitError(unknownLoc, "unknown capability: ") << operands[0];
152 
153   capabilities.insert(*cap);
154   return success();
155 }
156 
processExtension(ArrayRef<uint32_t> words)157 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
158   if (words.empty()) {
159     return emitError(
160         unknownLoc,
161         "OpExtension must have a literal string for the extension name");
162   }
163 
164   unsigned wordIndex = 0;
165   StringRef extName = decodeStringLiteral(words, wordIndex);
166   if (wordIndex != words.size())
167     return emitError(unknownLoc,
168                      "unexpected trailing words in OpExtension instruction");
169   auto ext = spirv::symbolizeExtension(extName);
170   if (!ext)
171     return emitError(unknownLoc, "unknown extension: ") << extName;
172 
173   extensions.insert(*ext);
174   return success();
175 }
176 
177 LogicalResult
processExtInstImport(ArrayRef<uint32_t> words)178 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
179   if (words.size() < 2) {
180     return emitError(unknownLoc,
181                      "OpExtInstImport must have a result <id> and a literal "
182                      "string for the extended instruction set name");
183   }
184 
185   unsigned wordIndex = 1;
186   extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
187   if (wordIndex != words.size()) {
188     return emitError(unknownLoc,
189                      "unexpected trailing words in OpExtInstImport");
190   }
191   return success();
192 }
193 
attachVCETriple()194 void spirv::Deserializer::attachVCETriple() {
195   (*module)->setAttr(
196       spirv::ModuleOp::getVCETripleAttrName(),
197       spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
198                                 extensions.getArrayRef(), context));
199 }
200 
201 LogicalResult
processMemoryModel(ArrayRef<uint32_t> operands)202 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
203   if (operands.size() != 2)
204     return emitError(unknownLoc, "OpMemoryModel must have two operands");
205 
206   (*module)->setAttr(
207       "addressing_model",
208       opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front())));
209   (*module)->setAttr(
210       "memory_model",
211       opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back())));
212 
213   return success();
214 }
215 
processDecoration(ArrayRef<uint32_t> words)216 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
217   // TODO: This function should also be auto-generated. For now, since only a
218   // few decorations are processed/handled in a meaningful manner, going with a
219   // manual implementation.
220   if (words.size() < 2) {
221     return emitError(
222         unknownLoc, "OpDecorate must have at least result <id> and Decoration");
223   }
224   auto decorationName =
225       stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
226   if (decorationName.empty()) {
227     return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
228   }
229   auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
230   auto symbol = opBuilder.getIdentifier(attrName);
231   switch (static_cast<spirv::Decoration>(words[1])) {
232   case spirv::Decoration::DescriptorSet:
233   case spirv::Decoration::Binding:
234     if (words.size() != 3) {
235       return emitError(unknownLoc, "OpDecorate with ")
236              << decorationName << " needs a single integer literal";
237     }
238     decorations[words[0]].set(
239         symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
240     break;
241   case spirv::Decoration::BuiltIn:
242     if (words.size() != 3) {
243       return emitError(unknownLoc, "OpDecorate with ")
244              << decorationName << " needs a single integer literal";
245     }
246     decorations[words[0]].set(
247         symbol, opBuilder.getStringAttr(
248                     stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
249     break;
250   case spirv::Decoration::ArrayStride:
251     if (words.size() != 3) {
252       return emitError(unknownLoc, "OpDecorate with ")
253              << decorationName << " needs a single integer literal";
254     }
255     typeDecorations[words[0]] = words[2];
256     break;
257   case spirv::Decoration::Aliased:
258   case spirv::Decoration::Block:
259   case spirv::Decoration::BufferBlock:
260   case spirv::Decoration::Flat:
261   case spirv::Decoration::NonReadable:
262   case spirv::Decoration::NonWritable:
263   case spirv::Decoration::NoPerspective:
264   case spirv::Decoration::Restrict:
265   case spirv::Decoration::RelaxedPrecision:
266     if (words.size() != 2) {
267       return emitError(unknownLoc, "OpDecoration with ")
268              << decorationName << "needs a single target <id>";
269     }
270     // Block decoration does not affect spv.struct type, but is still stored for
271     // verification.
272     // TODO: Update StructType to contain this information since
273     // it is needed for many validation rules.
274     decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
275     break;
276   case spirv::Decoration::Location:
277   case spirv::Decoration::SpecId:
278     if (words.size() != 3) {
279       return emitError(unknownLoc, "OpDecoration with ")
280              << decorationName << "needs a single integer literal";
281     }
282     decorations[words[0]].set(
283         symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
284     break;
285   default:
286     return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
287   }
288   return success();
289 }
290 
291 LogicalResult
processMemberDecoration(ArrayRef<uint32_t> words)292 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
293   // The binary layout of OpMemberDecorate is different comparing to OpDecorate
294   if (words.size() < 3) {
295     return emitError(unknownLoc,
296                      "OpMemberDecorate must have at least 3 operands");
297   }
298 
299   auto decoration = static_cast<spirv::Decoration>(words[2]);
300   if (decoration == spirv::Decoration::Offset && words.size() != 4) {
301     return emitError(unknownLoc,
302                      " missing offset specification in OpMemberDecorate with "
303                      "Offset decoration");
304   }
305   ArrayRef<uint32_t> decorationOperands;
306   if (words.size() > 3) {
307     decorationOperands = words.slice(3);
308   }
309   memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
310   return success();
311 }
312 
processMemberName(ArrayRef<uint32_t> words)313 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
314   if (words.size() < 3) {
315     return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
316   }
317   unsigned wordIndex = 2;
318   auto name = decodeStringLiteral(words, wordIndex);
319   if (wordIndex != words.size()) {
320     return emitError(unknownLoc,
321                      "unexpected trailing words in OpMemberName instruction");
322   }
323   memberNameMap[words[0]][words[1]] = name;
324   return success();
325 }
326 
327 LogicalResult
processFunction(ArrayRef<uint32_t> operands)328 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
329   if (curFunction) {
330     return emitError(unknownLoc, "found function inside function");
331   }
332 
333   // Get the result type
334   if (operands.size() != 4) {
335     return emitError(unknownLoc, "OpFunction must have 4 parameters");
336   }
337   Type resultType = getType(operands[0]);
338   if (!resultType) {
339     return emitError(unknownLoc, "undefined result type from <id> ")
340            << operands[0];
341   }
342 
343   if (funcMap.count(operands[1])) {
344     return emitError(unknownLoc, "duplicate function definition/declaration");
345   }
346 
347   auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
348   if (!fnControl) {
349     return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
350   }
351 
352   Type fnType = getType(operands[3]);
353   if (!fnType || !fnType.isa<FunctionType>()) {
354     return emitError(unknownLoc, "unknown function type from <id> ")
355            << operands[3];
356   }
357   auto functionType = fnType.cast<FunctionType>();
358 
359   if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
360       (functionType.getNumResults() == 1 &&
361        functionType.getResult(0) != resultType)) {
362     return emitError(unknownLoc, "mismatch in function type ")
363            << functionType << " and return type " << resultType << " specified";
364   }
365 
366   std::string fnName = getFunctionSymbol(operands[1]);
367   auto funcOp = opBuilder.create<spirv::FuncOp>(
368       unknownLoc, fnName, functionType, fnControl.getValue());
369   curFunction = funcMap[operands[1]] = funcOp;
370   LLVM_DEBUG(llvm::dbgs() << "-- start function " << fnName << " (type = "
371                           << fnType << ", id = " << operands[1] << ") --\n");
372   auto *entryBlock = funcOp.addEntryBlock();
373   LLVM_DEBUG(llvm::dbgs() << "[block] created entry block " << entryBlock
374                           << "\n");
375 
376   // Parse the op argument instructions
377   if (functionType.getNumInputs()) {
378     for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
379       auto argType = functionType.getInput(i);
380       spirv::Opcode opcode = spirv::Opcode::OpNop;
381       ArrayRef<uint32_t> operands;
382       if (failed(sliceInstruction(opcode, operands,
383                                   spirv::Opcode::OpFunctionParameter))) {
384         return failure();
385       }
386       if (opcode != spirv::Opcode::OpFunctionParameter) {
387         return emitError(
388                    unknownLoc,
389                    "missing OpFunctionParameter instruction for argument ")
390                << i;
391       }
392       if (operands.size() != 2) {
393         return emitError(
394             unknownLoc,
395             "expected result type and result <id> for OpFunctionParameter");
396       }
397       auto argDefinedType = getType(operands[0]);
398       if (!argDefinedType || argDefinedType != argType) {
399         return emitError(unknownLoc,
400                          "mismatch in argument type between function type "
401                          "definition ")
402                << functionType << " and argument type definition "
403                << argDefinedType << " at argument " << i;
404       }
405       if (getValue(operands[1])) {
406         return emitError(unknownLoc, "duplicate definition of result <id> '")
407                << operands[1];
408       }
409       auto argValue = funcOp.getArgument(i);
410       valueMap[operands[1]] = argValue;
411     }
412   }
413 
414   // RAII guard to reset the insertion point to the module's region after
415   // deserializing the body of this function.
416   OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
417 
418   spirv::Opcode opcode = spirv::Opcode::OpNop;
419   ArrayRef<uint32_t> instOperands;
420 
421   // Special handling for the entry block. We need to make sure it starts with
422   // an OpLabel instruction. The entry block takes the same parameters as the
423   // function. All other blocks do not take any parameter. We have already
424   // created the entry block, here we need to register it to the correct label
425   // <id>.
426   if (failed(sliceInstruction(opcode, instOperands,
427                               spirv::Opcode::OpFunctionEnd))) {
428     return failure();
429   }
430   if (opcode == spirv::Opcode::OpFunctionEnd) {
431     LLVM_DEBUG(llvm::dbgs()
432                << "-- completed function '" << fnName << "' (type = " << fnType
433                << ", id = " << operands[1] << ") --\n");
434     return processFunctionEnd(instOperands);
435   }
436   if (opcode != spirv::Opcode::OpLabel) {
437     return emitError(unknownLoc, "a basic block must start with OpLabel");
438   }
439   if (instOperands.size() != 1) {
440     return emitError(unknownLoc, "OpLabel should only have result <id>");
441   }
442   blockMap[instOperands[0]] = entryBlock;
443   if (failed(processLabel(instOperands))) {
444     return failure();
445   }
446 
447   // Then process all the other instructions in the function until we hit
448   // OpFunctionEnd.
449   while (succeeded(sliceInstruction(opcode, instOperands,
450                                     spirv::Opcode::OpFunctionEnd)) &&
451          opcode != spirv::Opcode::OpFunctionEnd) {
452     if (failed(processInstruction(opcode, instOperands))) {
453       return failure();
454     }
455   }
456   if (opcode != spirv::Opcode::OpFunctionEnd) {
457     return failure();
458   }
459 
460   LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << fnName << "' (type = "
461                           << fnType << ", id = " << operands[1] << ") --\n");
462   return processFunctionEnd(instOperands);
463 }
464 
465 LogicalResult
processFunctionEnd(ArrayRef<uint32_t> operands)466 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
467   // Process OpFunctionEnd.
468   if (!operands.empty()) {
469     return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
470   }
471 
472   // Wire up block arguments from OpPhi instructions.
473   // Put all structured control flow in spv.mlir.selection/spv.mlir.loop ops.
474   if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
475     return failure();
476   }
477 
478   curBlock = nullptr;
479   curFunction = llvm::None;
480 
481   return success();
482 }
483 
484 Optional<std::pair<Attribute, Type>>
getConstant(uint32_t id)485 spirv::Deserializer::getConstant(uint32_t id) {
486   auto constIt = constantMap.find(id);
487   if (constIt == constantMap.end())
488     return llvm::None;
489   return constIt->getSecond();
490 }
491 
492 Optional<spirv::SpecConstOperationMaterializationInfo>
getSpecConstantOperation(uint32_t id)493 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
494   auto constIt = specConstOperationMap.find(id);
495   if (constIt == specConstOperationMap.end())
496     return llvm::None;
497   return constIt->getSecond();
498 }
499 
getFunctionSymbol(uint32_t id)500 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
501   auto funcName = nameMap.lookup(id).str();
502   if (funcName.empty()) {
503     funcName = "spirv_fn_" + std::to_string(id);
504   }
505   return funcName;
506 }
507 
getSpecConstantSymbol(uint32_t id)508 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
509   auto constName = nameMap.lookup(id).str();
510   if (constName.empty()) {
511     constName = "spirv_spec_const_" + std::to_string(id);
512   }
513   return constName;
514 }
515 
516 spirv::SpecConstantOp
createSpecConstant(Location loc,uint32_t resultID,Attribute defaultValue)517 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
518                                         Attribute defaultValue) {
519   auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
520   auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
521                                                     defaultValue);
522   if (decorations.count(resultID)) {
523     for (auto attr : decorations[resultID].getAttrs())
524       op->setAttr(attr.first, attr.second);
525   }
526   specConstMap[resultID] = op;
527   return op;
528 }
529 
530 LogicalResult
processGlobalVariable(ArrayRef<uint32_t> operands)531 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
532   unsigned wordIndex = 0;
533   if (operands.size() < 3) {
534     return emitError(
535         unknownLoc,
536         "OpVariable needs at least 3 operands, type, <id> and storage class");
537   }
538 
539   // Result Type.
540   auto type = getType(operands[wordIndex]);
541   if (!type) {
542     return emitError(unknownLoc, "unknown result type <id> : ")
543            << operands[wordIndex];
544   }
545   auto ptrType = type.dyn_cast<spirv::PointerType>();
546   if (!ptrType) {
547     return emitError(unknownLoc,
548                      "expected a result type <id> to be a spv.ptr, found : ")
549            << type;
550   }
551   wordIndex++;
552 
553   // Result <id>.
554   auto variableID = operands[wordIndex];
555   auto variableName = nameMap.lookup(variableID).str();
556   if (variableName.empty()) {
557     variableName = "spirv_var_" + std::to_string(variableID);
558   }
559   wordIndex++;
560 
561   // Storage class.
562   auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
563   if (ptrType.getStorageClass() != storageClass) {
564     return emitError(unknownLoc, "mismatch in storage class of pointer type ")
565            << type << " and that specified in OpVariable instruction  : "
566            << stringifyStorageClass(storageClass);
567   }
568   wordIndex++;
569 
570   // Initializer.
571   FlatSymbolRefAttr initializer = nullptr;
572   if (wordIndex < operands.size()) {
573     auto initializerOp = getGlobalVariable(operands[wordIndex]);
574     if (!initializerOp) {
575       return emitError(unknownLoc, "unknown <id> ")
576              << operands[wordIndex] << "used as initializer";
577     }
578     wordIndex++;
579     initializer = SymbolRefAttr::get(initializerOp.getOperation());
580   }
581   if (wordIndex != operands.size()) {
582     return emitError(unknownLoc,
583                      "found more operands than expected when deserializing "
584                      "OpVariable instruction, only ")
585            << wordIndex << " of " << operands.size() << " processed";
586   }
587   auto loc = createFileLineColLoc(opBuilder);
588   auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
589       loc, TypeAttr::get(type), opBuilder.getStringAttr(variableName),
590       initializer);
591 
592   // Decorations.
593   if (decorations.count(variableID)) {
594     for (auto attr : decorations[variableID].getAttrs()) {
595       varOp->setAttr(attr.first, attr.second);
596     }
597   }
598   globalVariableMap[variableID] = varOp;
599   return success();
600 }
601 
getConstantInt(uint32_t id)602 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
603   auto constInfo = getConstant(id);
604   if (!constInfo) {
605     return nullptr;
606   }
607   return constInfo->first.dyn_cast<IntegerAttr>();
608 }
609 
processName(ArrayRef<uint32_t> operands)610 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
611   if (operands.size() < 2) {
612     return emitError(unknownLoc, "OpName needs at least 2 operands");
613   }
614   if (!nameMap.lookup(operands[0]).empty()) {
615     return emitError(unknownLoc, "duplicate name found for result <id> ")
616            << operands[0];
617   }
618   unsigned wordIndex = 1;
619   StringRef name = decodeStringLiteral(operands, wordIndex);
620   if (wordIndex != operands.size()) {
621     return emitError(unknownLoc,
622                      "unexpected trailing words in OpName instruction");
623   }
624   nameMap[operands[0]] = name;
625   return success();
626 }
627 
628 //===----------------------------------------------------------------------===//
629 // Type
630 //===----------------------------------------------------------------------===//
631 
processType(spirv::Opcode opcode,ArrayRef<uint32_t> operands)632 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
633                                                ArrayRef<uint32_t> operands) {
634   if (operands.empty()) {
635     return emitError(unknownLoc, "type instruction with opcode ")
636            << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
637   }
638 
639   /// TODO: Types might be forward declared in some instructions and need to be
640   /// handled appropriately.
641   if (typeMap.count(operands[0])) {
642     return emitError(unknownLoc, "duplicate definition for result <id> ")
643            << operands[0];
644   }
645 
646   switch (opcode) {
647   case spirv::Opcode::OpTypeVoid:
648     if (operands.size() != 1)
649       return emitError(unknownLoc, "OpTypeVoid must have no parameters");
650     typeMap[operands[0]] = opBuilder.getNoneType();
651     break;
652   case spirv::Opcode::OpTypeBool:
653     if (operands.size() != 1)
654       return emitError(unknownLoc, "OpTypeBool must have no parameters");
655     typeMap[operands[0]] = opBuilder.getI1Type();
656     break;
657   case spirv::Opcode::OpTypeInt: {
658     if (operands.size() != 3)
659       return emitError(
660           unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
661 
662     // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
663     // to preserve or validate.
664     // 0 indicates unsigned, or no signedness semantics
665     // 1 indicates signed semantics."
666     //
667     // So we cannot differentiate signless and unsigned integers; always use
668     // signless semantics for such cases.
669     auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
670                                  : IntegerType::SignednessSemantics::Signless;
671     typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
672   } break;
673   case spirv::Opcode::OpTypeFloat: {
674     if (operands.size() != 2)
675       return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
676 
677     Type floatTy;
678     switch (operands[1]) {
679     case 16:
680       floatTy = opBuilder.getF16Type();
681       break;
682     case 32:
683       floatTy = opBuilder.getF32Type();
684       break;
685     case 64:
686       floatTy = opBuilder.getF64Type();
687       break;
688     default:
689       return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
690              << operands[1];
691     }
692     typeMap[operands[0]] = floatTy;
693   } break;
694   case spirv::Opcode::OpTypeVector: {
695     if (operands.size() != 3) {
696       return emitError(
697           unknownLoc,
698           "OpTypeVector must have element type and count parameters");
699     }
700     Type elementTy = getType(operands[1]);
701     if (!elementTy) {
702       return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
703              << operands[1];
704     }
705     typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
706   } break;
707   case spirv::Opcode::OpTypePointer: {
708     return processOpTypePointer(operands);
709   } break;
710   case spirv::Opcode::OpTypeArray:
711     return processArrayType(operands);
712   case spirv::Opcode::OpTypeCooperativeMatrixNV:
713     return processCooperativeMatrixType(operands);
714   case spirv::Opcode::OpTypeFunction:
715     return processFunctionType(operands);
716   case spirv::Opcode::OpTypeImage:
717     return processImageType(operands);
718   case spirv::Opcode::OpTypeSampledImage:
719     return processSampledImageType(operands);
720   case spirv::Opcode::OpTypeRuntimeArray:
721     return processRuntimeArrayType(operands);
722   case spirv::Opcode::OpTypeStruct:
723     return processStructType(operands);
724   case spirv::Opcode::OpTypeMatrix:
725     return processMatrixType(operands);
726   default:
727     return emitError(unknownLoc, "unhandled type instruction");
728   }
729   return success();
730 }
731 
732 LogicalResult
processOpTypePointer(ArrayRef<uint32_t> operands)733 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
734   if (operands.size() != 3)
735     return emitError(unknownLoc, "OpTypePointer must have two parameters");
736 
737   auto pointeeType = getType(operands[2]);
738   if (!pointeeType)
739     return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
740            << operands[2];
741 
742   uint32_t typePointerID = operands[0];
743   auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
744   typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
745 
746   for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
747        deferredStructIt != std::end(deferredStructTypesInfos);) {
748     for (auto *unresolvedMemberIt =
749              std::begin(deferredStructIt->unresolvedMemberTypes);
750          unresolvedMemberIt !=
751          std::end(deferredStructIt->unresolvedMemberTypes);) {
752       if (unresolvedMemberIt->first == typePointerID) {
753         // The newly constructed pointer type can resolve one of the
754         // deferred struct type members; update the memberTypes list and
755         // clean the unresolvedMemberTypes list accordingly.
756         deferredStructIt->memberTypes[unresolvedMemberIt->second] =
757             typeMap[typePointerID];
758         unresolvedMemberIt =
759             deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
760       } else {
761         ++unresolvedMemberIt;
762       }
763     }
764 
765     if (deferredStructIt->unresolvedMemberTypes.empty()) {
766       // All deferred struct type members are now resolved, set the struct body.
767       auto structType = deferredStructIt->deferredStructType;
768 
769       assert(structType && "expected a spirv::StructType");
770       assert(structType.isIdentified() && "expected an indentified struct");
771 
772       if (failed(structType.trySetBody(
773               deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
774               deferredStructIt->memberDecorationsInfo)))
775         return failure();
776 
777       deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
778     } else {
779       ++deferredStructIt;
780     }
781   }
782 
783   return success();
784 }
785 
786 LogicalResult
processArrayType(ArrayRef<uint32_t> operands)787 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
788   if (operands.size() != 3) {
789     return emitError(unknownLoc,
790                      "OpTypeArray must have element type and count parameters");
791   }
792 
793   Type elementTy = getType(operands[1]);
794   if (!elementTy) {
795     return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
796            << operands[1];
797   }
798 
799   unsigned count = 0;
800   // TODO: The count can also come frome a specialization constant.
801   auto countInfo = getConstant(operands[2]);
802   if (!countInfo) {
803     return emitError(unknownLoc, "OpTypeArray count <id> ")
804            << operands[2] << "can only come from normal constant right now";
805   }
806 
807   if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) {
808     count = intVal.getValue().getZExtValue();
809   } else {
810     return emitError(unknownLoc, "OpTypeArray count must come from a "
811                                  "scalar integer constant instruction");
812   }
813 
814   typeMap[operands[0]] = spirv::ArrayType::get(
815       elementTy, count, typeDecorations.lookup(operands[0]));
816   return success();
817 }
818 
819 LogicalResult
processFunctionType(ArrayRef<uint32_t> operands)820 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
821   assert(!operands.empty() && "No operands for processing function type");
822   if (operands.size() == 1) {
823     return emitError(unknownLoc, "missing return type for OpTypeFunction");
824   }
825   auto returnType = getType(operands[1]);
826   if (!returnType) {
827     return emitError(unknownLoc, "unknown return type in OpTypeFunction");
828   }
829   SmallVector<Type, 1> argTypes;
830   for (size_t i = 2, e = operands.size(); i < e; ++i) {
831     auto ty = getType(operands[i]);
832     if (!ty) {
833       return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
834     }
835     argTypes.push_back(ty);
836   }
837   ArrayRef<Type> returnTypes;
838   if (!isVoidType(returnType)) {
839     returnTypes = llvm::makeArrayRef(returnType);
840   }
841   typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
842   return success();
843 }
844 
845 LogicalResult
processCooperativeMatrixType(ArrayRef<uint32_t> operands)846 spirv::Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
847   if (operands.size() != 5) {
848     return emitError(unknownLoc, "OpTypeCooperativeMatrix must have element "
849                                  "type and row x column parameters");
850   }
851 
852   Type elementTy = getType(operands[1]);
853   if (!elementTy) {
854     return emitError(unknownLoc,
855                      "OpTypeCooperativeMatrix references undefined <id> ")
856            << operands[1];
857   }
858 
859   auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
860   if (!scope) {
861     return emitError(unknownLoc,
862                      "OpTypeCooperativeMatrix references undefined scope <id> ")
863            << operands[2];
864   }
865 
866   unsigned rows = getConstantInt(operands[3]).getInt();
867   unsigned columns = getConstantInt(operands[4]).getInt();
868 
869   typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
870       elementTy, scope.getValue(), rows, columns);
871   return success();
872 }
873 
874 LogicalResult
processRuntimeArrayType(ArrayRef<uint32_t> operands)875 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
876   if (operands.size() != 2) {
877     return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
878   }
879   Type memberType = getType(operands[1]);
880   if (!memberType) {
881     return emitError(unknownLoc,
882                      "OpTypeRuntimeArray references undefined <id> ")
883            << operands[1];
884   }
885   typeMap[operands[0]] = spirv::RuntimeArrayType::get(
886       memberType, typeDecorations.lookup(operands[0]));
887   return success();
888 }
889 
890 LogicalResult
processStructType(ArrayRef<uint32_t> operands)891 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
892   // TODO: Find a way to handle identified structs when debug info is stripped.
893 
894   if (operands.empty()) {
895     return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
896   }
897 
898   if (operands.size() == 1) {
899     // Handle empty struct.
900     typeMap[operands[0]] =
901         spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
902     return success();
903   }
904 
905   // First element is operand ID, second element is member index in the struct.
906   SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
907   SmallVector<Type, 4> memberTypes;
908 
909   for (auto op : llvm::drop_begin(operands, 1)) {
910     Type memberType = getType(op);
911     bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
912 
913     if (!memberType && !typeForwardPtr)
914       return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
915              << op;
916 
917     if (!memberType)
918       unresolvedMemberTypes.emplace_back(op, memberTypes.size());
919 
920     memberTypes.push_back(memberType);
921   }
922 
923   SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
924   SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
925   if (memberDecorationMap.count(operands[0])) {
926     auto &allMemberDecorations = memberDecorationMap[operands[0]];
927     for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
928       if (allMemberDecorations.count(memberIndex)) {
929         for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
930           // Check for offset.
931           if (memberDecoration.first == spirv::Decoration::Offset) {
932             // If offset info is empty, resize to the number of members;
933             if (offsetInfo.empty()) {
934               offsetInfo.resize(memberTypes.size());
935             }
936             offsetInfo[memberIndex] = memberDecoration.second[0];
937           } else {
938             if (!memberDecoration.second.empty()) {
939               memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1,
940                                                  memberDecoration.first,
941                                                  memberDecoration.second[0]);
942             } else {
943               memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0,
944                                                  memberDecoration.first, 0);
945             }
946           }
947         }
948       }
949     }
950   }
951 
952   uint32_t structID = operands[0];
953   std::string structIdentifier = nameMap.lookup(structID).str();
954 
955   if (structIdentifier.empty()) {
956     assert(unresolvedMemberTypes.empty() &&
957            "didn't expect unresolved member types");
958     typeMap[structID] =
959         spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
960   } else {
961     auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
962     typeMap[structID] = structTy;
963 
964     if (!unresolvedMemberTypes.empty())
965       deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
966                                           memberTypes, offsetInfo,
967                                           memberDecorationsInfo});
968     else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
969                                         memberDecorationsInfo)))
970       return failure();
971   }
972 
973   // TODO: Update StructType to have member name as attribute as
974   // well.
975   return success();
976 }
977 
978 LogicalResult
processMatrixType(ArrayRef<uint32_t> operands)979 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
980   if (operands.size() != 3) {
981     // Three operands are needed: result_id, column_type, and column_count
982     return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
983                                  " (result_id, column_type, and column_count)");
984   }
985   // Matrix columns must be of vector type
986   Type elementTy = getType(operands[1]);
987   if (!elementTy) {
988     return emitError(unknownLoc,
989                      "OpTypeMatrix references undefined column type.")
990            << operands[1];
991   }
992 
993   uint32_t colsCount = operands[2];
994   typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
995   return success();
996 }
997 
998 LogicalResult
processTypeForwardPointer(ArrayRef<uint32_t> operands)999 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1000   if (operands.size() != 2)
1001     return emitError(unknownLoc,
1002                      "OpTypeForwardPointer instruction must have two operands");
1003 
1004   typeForwardPointerIDs.insert(operands[0]);
1005   // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1006   // instruction that defines the actual type.
1007 
1008   return success();
1009 }
1010 
1011 LogicalResult
processImageType(ArrayRef<uint32_t> operands)1012 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1013   // TODO: Add support for Access Qualifier.
1014   if (operands.size() != 8)
1015     return emitError(
1016         unknownLoc,
1017         "OpTypeImage with non-eight operands are not supported yet");
1018 
1019   Type elementTy = getType(operands[1]);
1020   if (!elementTy)
1021     return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1022            << operands[1];
1023 
1024   auto dim = spirv::symbolizeDim(operands[2]);
1025   if (!dim)
1026     return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1027            << operands[2];
1028 
1029   auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1030   if (!depthInfo)
1031     return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1032            << operands[3];
1033 
1034   auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1035   if (!arrayedInfo)
1036     return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1037            << operands[4];
1038 
1039   auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1040   if (!samplingInfo)
1041     return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1042 
1043   auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1044   if (!samplerUseInfo)
1045     return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1046            << operands[6];
1047 
1048   auto format = spirv::symbolizeImageFormat(operands[7]);
1049   if (!format)
1050     return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1051            << operands[7];
1052 
1053   typeMap[operands[0]] = spirv::ImageType::get(
1054       elementTy, dim.getValue(), depthInfo.getValue(), arrayedInfo.getValue(),
1055       samplingInfo.getValue(), samplerUseInfo.getValue(), format.getValue());
1056   return success();
1057 }
1058 
1059 LogicalResult
processSampledImageType(ArrayRef<uint32_t> operands)1060 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1061   if (operands.size() != 2)
1062     return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1063 
1064   Type elementTy = getType(operands[1]);
1065   if (!elementTy)
1066     return emitError(unknownLoc,
1067                      "OpTypeSampledImage references undefined <id>: ")
1068            << operands[1];
1069 
1070   typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1071   return success();
1072 }
1073 
1074 //===----------------------------------------------------------------------===//
1075 // Constant
1076 //===----------------------------------------------------------------------===//
1077 
processConstant(ArrayRef<uint32_t> operands,bool isSpec)1078 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1079                                                    bool isSpec) {
1080   StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1081 
1082   if (operands.size() < 2) {
1083     return emitError(unknownLoc)
1084            << opname << " must have type <id> and result <id>";
1085   }
1086   if (operands.size() < 3) {
1087     return emitError(unknownLoc)
1088            << opname << " must have at least 1 more parameter";
1089   }
1090 
1091   Type resultType = getType(operands[0]);
1092   if (!resultType) {
1093     return emitError(unknownLoc, "undefined result type from <id> ")
1094            << operands[0];
1095   }
1096 
1097   auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1098     if (bitwidth == 64) {
1099       if (operands.size() == 4) {
1100         return success();
1101       }
1102       return emitError(unknownLoc)
1103              << opname << " should have 2 parameters for 64-bit values";
1104     }
1105     if (bitwidth <= 32) {
1106       if (operands.size() == 3) {
1107         return success();
1108       }
1109 
1110       return emitError(unknownLoc)
1111              << opname
1112              << " should have 1 parameter for values with no more than 32 bits";
1113     }
1114     return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1115            << bitwidth;
1116   };
1117 
1118   auto resultID = operands[1];
1119 
1120   if (auto intType = resultType.dyn_cast<IntegerType>()) {
1121     auto bitwidth = intType.getWidth();
1122     if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1123       return failure();
1124     }
1125 
1126     APInt value;
1127     if (bitwidth == 64) {
1128       // 64-bit integers are represented with two SPIR-V words. According to
1129       // SPIR-V spec: "When the type’s bit width is larger than one word, the
1130       // literal’s low-order words appear first."
1131       struct DoubleWord {
1132         uint32_t word1;
1133         uint32_t word2;
1134       } words = {operands[2], operands[3]};
1135       value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1136     } else if (bitwidth <= 32) {
1137       value = APInt(bitwidth, operands[2], /*isSigned=*/true);
1138     }
1139 
1140     auto attr = opBuilder.getIntegerAttr(intType, value);
1141 
1142     if (isSpec) {
1143       createSpecConstant(unknownLoc, resultID, attr);
1144     } else {
1145       // For normal constants, we just record the attribute (and its type) for
1146       // later materialization at use sites.
1147       constantMap.try_emplace(resultID, attr, intType);
1148     }
1149 
1150     return success();
1151   }
1152 
1153   if (auto floatType = resultType.dyn_cast<FloatType>()) {
1154     auto bitwidth = floatType.getWidth();
1155     if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1156       return failure();
1157     }
1158 
1159     APFloat value(0.f);
1160     if (floatType.isF64()) {
1161       // Double values are represented with two SPIR-V words. According to
1162       // SPIR-V spec: "When the type’s bit width is larger than one word, the
1163       // literal’s low-order words appear first."
1164       struct DoubleWord {
1165         uint32_t word1;
1166         uint32_t word2;
1167       } words = {operands[2], operands[3]};
1168       value = APFloat(llvm::bit_cast<double>(words));
1169     } else if (floatType.isF32()) {
1170       value = APFloat(llvm::bit_cast<float>(operands[2]));
1171     } else if (floatType.isF16()) {
1172       APInt data(16, operands[2]);
1173       value = APFloat(APFloat::IEEEhalf(), data);
1174     }
1175 
1176     auto attr = opBuilder.getFloatAttr(floatType, value);
1177     if (isSpec) {
1178       createSpecConstant(unknownLoc, resultID, attr);
1179     } else {
1180       // For normal constants, we just record the attribute (and its type) for
1181       // later materialization at use sites.
1182       constantMap.try_emplace(resultID, attr, floatType);
1183     }
1184 
1185     return success();
1186   }
1187 
1188   return emitError(unknownLoc, "OpConstant can only generate values of "
1189                                "scalar integer or floating-point type");
1190 }
1191 
processConstantBool(bool isTrue,ArrayRef<uint32_t> operands,bool isSpec)1192 LogicalResult spirv::Deserializer::processConstantBool(
1193     bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1194   if (operands.size() != 2) {
1195     return emitError(unknownLoc, "Op")
1196            << (isSpec ? "Spec" : "") << "Constant"
1197            << (isTrue ? "True" : "False")
1198            << " must have type <id> and result <id>";
1199   }
1200 
1201   auto attr = opBuilder.getBoolAttr(isTrue);
1202   auto resultID = operands[1];
1203   if (isSpec) {
1204     createSpecConstant(unknownLoc, resultID, attr);
1205   } else {
1206     // For normal constants, we just record the attribute (and its type) for
1207     // later materialization at use sites.
1208     constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1209   }
1210 
1211   return success();
1212 }
1213 
1214 LogicalResult
processConstantComposite(ArrayRef<uint32_t> operands)1215 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1216   if (operands.size() < 2) {
1217     return emitError(unknownLoc,
1218                      "OpConstantComposite must have type <id> and result <id>");
1219   }
1220   if (operands.size() < 3) {
1221     return emitError(unknownLoc,
1222                      "OpConstantComposite must have at least 1 parameter");
1223   }
1224 
1225   Type resultType = getType(operands[0]);
1226   if (!resultType) {
1227     return emitError(unknownLoc, "undefined result type from <id> ")
1228            << operands[0];
1229   }
1230 
1231   SmallVector<Attribute, 4> elements;
1232   elements.reserve(operands.size() - 2);
1233   for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1234     auto elementInfo = getConstant(operands[i]);
1235     if (!elementInfo) {
1236       return emitError(unknownLoc, "OpConstantComposite component <id> ")
1237              << operands[i] << " must come from a normal constant";
1238     }
1239     elements.push_back(elementInfo->first);
1240   }
1241 
1242   auto resultID = operands[1];
1243   if (auto vectorType = resultType.dyn_cast<VectorType>()) {
1244     auto attr = DenseElementsAttr::get(vectorType, elements);
1245     // For normal constants, we just record the attribute (and its type) for
1246     // later materialization at use sites.
1247     constantMap.try_emplace(resultID, attr, resultType);
1248   } else if (auto arrayType = resultType.dyn_cast<spirv::ArrayType>()) {
1249     auto attr = opBuilder.getArrayAttr(elements);
1250     constantMap.try_emplace(resultID, attr, resultType);
1251   } else {
1252     return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1253            << resultType;
1254   }
1255 
1256   return success();
1257 }
1258 
1259 LogicalResult
processSpecConstantComposite(ArrayRef<uint32_t> operands)1260 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1261   if (operands.size() < 2) {
1262     return emitError(unknownLoc,
1263                      "OpConstantComposite must have type <id> and result <id>");
1264   }
1265   if (operands.size() < 3) {
1266     return emitError(unknownLoc,
1267                      "OpConstantComposite must have at least 1 parameter");
1268   }
1269 
1270   Type resultType = getType(operands[0]);
1271   if (!resultType) {
1272     return emitError(unknownLoc, "undefined result type from <id> ")
1273            << operands[0];
1274   }
1275 
1276   auto resultID = operands[1];
1277   auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1278 
1279   SmallVector<Attribute, 4> elements;
1280   elements.reserve(operands.size() - 2);
1281   for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1282     auto elementInfo = getSpecConstant(operands[i]);
1283     elements.push_back(SymbolRefAttr::get(elementInfo));
1284   }
1285 
1286   auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
1287       unknownLoc, TypeAttr::get(resultType), symName,
1288       opBuilder.getArrayAttr(elements));
1289   specConstCompositeMap[resultID] = op;
1290 
1291   return success();
1292 }
1293 
1294 LogicalResult
processSpecConstantOperation(ArrayRef<uint32_t> operands)1295 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1296   if (operands.size() < 3)
1297     return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1298                                  "result <id>, and operand opcode");
1299 
1300   uint32_t resultTypeID = operands[0];
1301 
1302   if (!getType(resultTypeID))
1303     return emitError(unknownLoc, "undefined result type from <id> ")
1304            << resultTypeID;
1305 
1306   uint32_t resultID = operands[1];
1307   spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1308   auto emplaceResult = specConstOperationMap.try_emplace(
1309       resultID,
1310       SpecConstOperationMaterializationInfo{
1311           enclosedOpcode, resultTypeID,
1312           SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1313 
1314   if (!emplaceResult.second)
1315     return emitError(unknownLoc, "value with <id>: ")
1316            << resultID << " is probably defined before.";
1317 
1318   return success();
1319 }
1320 
/// Materializes a previously recorded OpSpecConstantOp as a
/// spirv::SpecConstantOperationOp and returns that op's result.
///
/// The wrapped instruction (`enclosedOpcode` with `enclosedOpOperands`) is
/// first deserialized through the regular processInstruction path, then moved
/// into the new op's region and terminated with a spirv::YieldOp. Returns a
/// null Value if processing the wrapped instruction fails.
///
/// NOTE(review): `resultID` is not referenced in this body, and the result of
/// getType(resultTypeID) is used without a null check -- presumably callers
/// only pass <id>s validated when the instruction was recorded; confirm.
Value spirv::Deserializer::materializeSpecConstantOperation(
    uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
    ArrayRef<uint32_t> enclosedOpOperands) {

  Type resultType = getType(resultTypeID);

  // Instructions wrapped by OpSpecConstantOp need an ID for their
  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
  // dialect wrapped op. For that purpose, a new value map is created and "fake"
  // ID in that map is assigned to the result of the enclosed instruction. Note
  // that there is no need to update this fake ID since we only need to
  // reference the created Value for the enclosed op from the spirv::YieldOp
  // created later in this method (both of which are the only values in their
  // region: the SpecConstantOperation's region). If we encounter another
  // SpecConstantOperation in the module, we simply re-use the fake ID since the
  // previous Value assigned to it isn't visible in the current scope anyway.
  DenseMap<uint32_t, Value> newValueMap;
  // Swap in the empty map for the rest of this function; the guard puts the
  // original `valueMap` back when it goes out of scope.
  llvm::SaveAndRestore<DenseMap<uint32_t, Value>> valueMapGuard(valueMap,
                                                                newValueMap);
  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);

  // Rebuild the operand list in the layout the wrapped instruction expects:
  // result type <id>, result <id>, then its own operands.
  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
  enclosedOpResultTypeAndOperands.push_back(fakeID);
  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
                                         enclosedOpOperands.end());

  // Process enclosed instruction before creating the enclosing
  // specConstantOperation (and its region). This way, references to constants,
  // global variables, and spec constants will be materialized outside the new
  // op's region. For more info, see Deserializer::getValue's implementation.
  if (failed(
          processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
    return Value();

  // Since the enclosed op is emitted in the current block, split it in a
  // separate new block.
  // The split point is the last op in `curBlock`, i.e. the enclosed op that
  // was just emitted, so the new block contains only that op.
  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());

  auto loc = createFileLineColLoc(opBuilder);
  auto specConstOperationOp =
      opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);

  Region &body = specConstOperationOp.body();
  // Move the new block into SpecConstantOperation's body.
  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
                          Region::iterator(enclosedBlock));
  Block &block = body.back();

  // RAII guard to reset the insertion point to the module's region after
  // deserializing the body of the specConstantOperation.
  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
  opBuilder.setInsertionPointToEnd(&block);

  // Yield the first result of the enclosed op, which is the first (and, per
  // the split above, only) op in the block.
  opBuilder.create<spirv::YieldOp>(loc, block.front().getResult(0));
  return specConstOperationOp.getResult();
}
1378 
1379 LogicalResult
processConstantNull(ArrayRef<uint32_t> operands)1380 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1381   if (operands.size() != 2) {
1382     return emitError(unknownLoc,
1383                      "OpConstantNull must have type <id> and result <id>");
1384   }
1385 
1386   Type resultType = getType(operands[0]);
1387   if (!resultType) {
1388     return emitError(unknownLoc, "undefined result type from <id> ")
1389            << operands[0];
1390   }
1391 
1392   auto resultID = operands[1];
1393   if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) {
1394     auto attr = opBuilder.getZeroAttr(resultType);
1395     // For normal constants, we just record the attribute (and its type) for
1396     // later materialization at use sites.
1397     constantMap.try_emplace(resultID, attr, resultType);
1398     return success();
1399   }
1400 
1401   return emitError(unknownLoc, "unsupported OpConstantNull type: ")
1402          << resultType;
1403 }
1404 
1405 //===----------------------------------------------------------------------===//
1406 // Control flow
1407 //===----------------------------------------------------------------------===//
1408 
getOrCreateBlock(uint32_t id)1409 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
1410   if (auto *block = getBlock(id)) {
1411     LLVM_DEBUG(llvm::dbgs() << "[block] got exiting block for id = " << id
1412                             << " @ " << block << "\n");
1413     return block;
1414   }
1415 
1416   // We don't know where this block will be placed finally (in a
1417   // spv.mlir.selection or spv.mlir.loop or function). Create it into the
1418   // function for now and sort out the proper place later.
1419   auto *block = curFunction->addBlock();
1420   LLVM_DEBUG(llvm::dbgs() << "[block] created block for id = " << id << " @ "
1421                           << block << "\n");
1422   return blockMap[id] = block;
1423 }
1424 
processBranch(ArrayRef<uint32_t> operands)1425 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
1426   if (!curBlock) {
1427     return emitError(unknownLoc, "OpBranch must appear inside a block");
1428   }
1429 
1430   if (operands.size() != 1) {
1431     return emitError(unknownLoc, "OpBranch must take exactly one target label");
1432   }
1433 
1434   auto *target = getOrCreateBlock(operands[0]);
1435   auto loc = createFileLineColLoc(opBuilder);
1436   // The preceding instruction for the OpBranch instruction could be an
1437   // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
1438   // the same OpLine information.
1439   opBuilder.create<spirv::BranchOp>(loc, target);
1440 
1441   (void)clearDebugLine();
1442   return success();
1443 }
1444 
1445 LogicalResult
processBranchConditional(ArrayRef<uint32_t> operands)1446 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
1447   if (!curBlock) {
1448     return emitError(unknownLoc,
1449                      "OpBranchConditional must appear inside a block");
1450   }
1451 
1452   if (operands.size() != 3 && operands.size() != 5) {
1453     return emitError(unknownLoc,
1454                      "OpBranchConditional must have condition, true label, "
1455                      "false label, and optionally two branch weights");
1456   }
1457 
1458   auto condition = getValue(operands[0]);
1459   auto *trueBlock = getOrCreateBlock(operands[1]);
1460   auto *falseBlock = getOrCreateBlock(operands[2]);
1461 
1462   Optional<std::pair<uint32_t, uint32_t>> weights;
1463   if (operands.size() == 5) {
1464     weights = std::make_pair(operands[3], operands[4]);
1465   }
1466   // The preceding instruction for the OpBranchConditional instruction could be
1467   // an OpSelectionMerge instruction, in this case they will have the same
1468   // OpLine information.
1469   auto loc = createFileLineColLoc(opBuilder);
1470   opBuilder.create<spirv::BranchConditionalOp>(
1471       loc, condition, trueBlock,
1472       /*trueArguments=*/ArrayRef<Value>(), falseBlock,
1473       /*falseArguments=*/ArrayRef<Value>(), weights);
1474 
1475   (void)clearDebugLine();
1476   return success();
1477 }
1478 
processLabel(ArrayRef<uint32_t> operands)1479 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1480   if (!curFunction) {
1481     return emitError(unknownLoc, "OpLabel must appear inside a function");
1482   }
1483 
1484   if (operands.size() != 1) {
1485     return emitError(unknownLoc, "OpLabel should only have result <id>");
1486   }
1487 
1488   auto labelID = operands[0];
1489   // We may have forward declared this block.
1490   auto *block = getOrCreateBlock(labelID);
1491   LLVM_DEBUG(llvm::dbgs() << "[block] populating block " << block << "\n");
1492   // If we have seen this block, make sure it was just a forward declaration.
1493   assert(block->empty() && "re-deserialize the same block!");
1494 
1495   opBuilder.setInsertionPointToStart(block);
1496   blockMap[labelID] = curBlock = block;
1497 
1498   return success();
1499 }
1500 
1501 LogicalResult
processSelectionMerge(ArrayRef<uint32_t> operands)1502 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
1503   if (!curBlock) {
1504     return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
1505   }
1506 
1507   if (operands.size() < 2) {
1508     return emitError(
1509         unknownLoc,
1510         "OpSelectionMerge must specify merge target and selection control");
1511   }
1512 
1513   auto *mergeBlock = getOrCreateBlock(operands[0]);
1514   auto loc = createFileLineColLoc(opBuilder);
1515   auto selectionControl = operands[1];
1516 
1517   if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1518            .second) {
1519     return emitError(
1520         unknownLoc,
1521         "a block cannot have more than one OpSelectionMerge instruction");
1522   }
1523 
1524   return success();
1525 }
1526 
1527 LogicalResult
processLoopMerge(ArrayRef<uint32_t> operands)1528 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
1529   if (!curBlock) {
1530     return emitError(unknownLoc, "OpLoopMerge must appear in a block");
1531   }
1532 
1533   if (operands.size() < 3) {
1534     return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
1535                                  "continue target and loop control");
1536   }
1537 
1538   auto *mergeBlock = getOrCreateBlock(operands[0]);
1539   auto *continueBlock = getOrCreateBlock(operands[1]);
1540   auto loc = createFileLineColLoc(opBuilder);
1541   uint32_t loopControl = operands[2];
1542 
1543   if (!blockMergeInfo
1544            .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1545            .second) {
1546     return emitError(
1547         unknownLoc,
1548         "a block cannot have more than one OpLoopMerge instruction");
1549   }
1550 
1551   return success();
1552 }
1553 
processPhi(ArrayRef<uint32_t> operands)1554 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
1555   if (!curBlock) {
1556     return emitError(unknownLoc, "OpPhi must appear in a block");
1557   }
1558 
1559   if (operands.size() < 4) {
1560     return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
1561                                  "and variable-parent pairs");
1562   }
1563 
1564   // Create a block argument for this OpPhi instruction.
1565   Type blockArgType = getType(operands[0]);
1566   BlockArgument blockArg = curBlock->addArgument(blockArgType);
1567   valueMap[operands[1]] = blockArg;
1568   LLVM_DEBUG(llvm::dbgs() << "[phi] created block argument " << blockArg
1569                           << " id = " << operands[1] << " of type "
1570                           << blockArgType << '\n');
1571 
1572   // For each (value, predecessor) pair, insert the value to the predecessor's
1573   // blockPhiInfo entry so later we can fix the block argument there.
1574   for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
1575     uint32_t value = operands[i];
1576     Block *predecessor = getOrCreateBlock(operands[i + 1]);
1577     std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1578     blockPhiInfo[predecessorTargetPair].push_back(value);
1579     LLVM_DEBUG(llvm::dbgs() << "[phi] predecessor @ " << predecessor
1580                             << " with arg id = " << value << '\n');
1581   }
1582 
1583   return success();
1584 }
1585 
1586 namespace {
1587 /// A class for putting all blocks in a structured selection/loop in a
1588 /// spv.mlir.selection/spv.mlir.loop op.
1589 class ControlFlowStructurizer {
1590 public:
1591   /// Structurizes the loop at the given `headerBlock`.
1592   ///
1593   /// This method will create an spv.mlir.loop op in the `mergeBlock` and move
1594   /// all blocks in the structured loop into the spv.mlir.loop's region. All
1595   /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
1596   /// method will also update `mergeInfo` by remapping all blocks inside to the
1597   /// newly cloned ones inside structured control flow op's regions.
structurize(Location loc,uint32_t control,spirv::BlockMergeInfoMap & mergeInfo,Block * headerBlock,Block * mergeBlock,Block * continueBlock)1598   static LogicalResult structurize(Location loc, uint32_t control,
1599                                    spirv::BlockMergeInfoMap &mergeInfo,
1600                                    Block *headerBlock, Block *mergeBlock,
1601                                    Block *continueBlock) {
1602     return ControlFlowStructurizer(loc, control, mergeInfo, headerBlock,
1603                                    mergeBlock, continueBlock)
1604         .structurizeImpl();
1605   }
1606 
1607 private:
ControlFlowStructurizer(Location loc,uint32_t control,spirv::BlockMergeInfoMap & mergeInfo,Block * header,Block * merge,Block * cont)1608   ControlFlowStructurizer(Location loc, uint32_t control,
1609                           spirv::BlockMergeInfoMap &mergeInfo, Block *header,
1610                           Block *merge, Block *cont)
1611       : location(loc), control(control), blockMergeInfo(mergeInfo),
1612         headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1613 
1614   /// Creates a new spv.mlir.selection op at the beginning of the `mergeBlock`.
1615   spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1616 
1617   /// Creates a new spv.mlir.loop op at the beginning of the `mergeBlock`.
1618   spirv::LoopOp createLoopOp(uint32_t loopControl);
1619 
1620   /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
1621   void collectBlocksInConstruct();
1622 
1623   LogicalResult structurizeImpl();
1624 
1625   Location location;
1626   uint32_t control;
1627 
1628   spirv::BlockMergeInfoMap &blockMergeInfo;
1629 
1630   Block *headerBlock;
1631   Block *mergeBlock;
1632   Block *continueBlock; // nullptr for spv.mlir.selection
1633 
1634   SetVector<Block *> constructBlocks;
1635 };
1636 } // namespace
1637 
1638 spirv::SelectionOp
createSelectionOp(uint32_t selectionControl)1639 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1640   // Create a builder and set the insertion point to the beginning of the
1641   // merge block so that the newly created SelectionOp will be inserted there.
1642   OpBuilder builder(&mergeBlock->front());
1643 
1644   auto control = static_cast<spirv::SelectionControl>(selectionControl);
1645   auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1646   selectionOp.addMergeBlock();
1647 
1648   return selectionOp;
1649 }
1650 
createLoopOp(uint32_t loopControl)1651 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1652   // Create a builder and set the insertion point to the beginning of the
1653   // merge block so that the newly created LoopOp will be inserted there.
1654   OpBuilder builder(&mergeBlock->front());
1655 
1656   auto control = static_cast<spirv::LoopControl>(loopControl);
1657   auto loopOp = builder.create<spirv::LoopOp>(location, control);
1658   loopOp.addEntryAndMergeBlock();
1659 
1660   return loopOp;
1661 }
1662 
collectBlocksInConstruct()1663 void ControlFlowStructurizer::collectBlocksInConstruct() {
1664   assert(constructBlocks.empty() && "expected empty constructBlocks");
1665 
1666   // Put the header block in the work list first.
1667   constructBlocks.insert(headerBlock);
1668 
1669   // For each item in the work list, add its successors excluding the merge
1670   // block.
1671   for (unsigned i = 0; i < constructBlocks.size(); ++i) {
1672     for (auto *successor : constructBlocks[i]->getSuccessors())
1673       if (successor != mergeBlock)
1674         constructBlocks.insert(successor);
1675   }
1676 }
1677 
structurizeImpl()1678 LogicalResult ControlFlowStructurizer::structurizeImpl() {
1679   Operation *op = nullptr;
1680   bool isLoop = continueBlock != nullptr;
1681   if (isLoop) {
1682     if (auto loopOp = createLoopOp(control))
1683       op = loopOp.getOperation();
1684   } else {
1685     if (auto selectionOp = createSelectionOp(control))
1686       op = selectionOp.getOperation();
1687   }
1688   if (!op)
1689     return failure();
1690   Region &body = op->getRegion(0);
1691 
1692   BlockAndValueMapping mapper;
1693   // All references to the old merge block should be directed to the
1694   // selection/loop merge block in the SelectionOp/LoopOp's region.
1695   mapper.map(mergeBlock, &body.back());
1696 
1697   collectBlocksInConstruct();
1698 
1699   // We've identified all blocks belonging to the selection/loop's region. Now
1700   // need to "move" them into the selection/loop. Instead of really moving the
1701   // blocks, in the following we copy them and remap all values and branches.
1702   // This is because:
1703   // * Inserting a block into a region requires the block not in any region
1704   //   before. But selections/loops can nest so we can create selection/loop ops
1705   //   in a nested manner, which means some blocks may already be in a
1706   //   selection/loop region when to be moved again.
1707   // * It's much trickier to fix up the branches into and out of the loop's
1708   //   region: we need to treat not-moved blocks and moved blocks differently:
1709   //   Not-moved blocks jumping to the loop header block need to jump to the
1710   //   merge point containing the new loop op but not the loop continue block's
1711   //   back edge. Moved blocks jumping out of the loop need to jump to the
1712   //   merge block inside the loop region but not other not-moved blocks.
1713   //   We cannot use replaceAllUsesWith clearly and it's harder to follow the
1714   //   logic.
1715 
1716   // Create a corresponding block in the SelectionOp/LoopOp's region for each
1717   // block in this loop construct.
1718   OpBuilder builder(body);
1719   for (auto *block : constructBlocks) {
1720     // Create a block and insert it before the selection/loop merge block in the
1721     // SelectionOp/LoopOp's region.
1722     auto *newBlock = builder.createBlock(&body.back());
1723     mapper.map(block, newBlock);
1724     LLVM_DEBUG(llvm::dbgs() << "[cf] cloned block " << newBlock
1725                             << " from block " << block << "\n");
1726     if (!isFnEntryBlock(block)) {
1727       for (BlockArgument blockArg : block->getArguments()) {
1728         auto newArg = newBlock->addArgument(blockArg.getType());
1729         mapper.map(blockArg, newArg);
1730         LLVM_DEBUG(llvm::dbgs() << "[cf] remapped block argument " << blockArg
1731                                 << " to " << newArg << '\n');
1732       }
1733     } else {
1734       LLVM_DEBUG(llvm::dbgs()
1735                  << "[cf] block " << block << " is a function entry block\n");
1736     }
1737     for (auto &op : *block)
1738       newBlock->push_back(op.clone(mapper));
1739   }
1740 
1741   // Go through all ops and remap the operands.
1742   auto remapOperands = [&](Operation *op) {
1743     for (auto &operand : op->getOpOperands())
1744       if (Value mappedOp = mapper.lookupOrNull(operand.get()))
1745         operand.set(mappedOp);
1746     for (auto &succOp : op->getBlockOperands())
1747       if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
1748         succOp.set(mappedOp);
1749   };
1750   for (auto &block : body) {
1751     block.walk(remapOperands);
1752   }
1753 
1754   // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
1755   // the selection/loop construct into its region. Next we need to fix the
1756   // connections between this new SelectionOp/LoopOp with existing blocks.
1757 
1758   // All existing incoming branches should go to the merge block, where the
1759   // SelectionOp/LoopOp resides right now.
1760   headerBlock->replaceAllUsesWith(mergeBlock);
1761 
1762   if (isLoop) {
1763     // The loop selection/loop header block may have block arguments. Since now
1764     // we place the selection/loop op inside the old merge block, we need to
1765     // make sure the old merge block has the same block argument list.
1766     assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported");
1767     for (BlockArgument blockArg : headerBlock->getArguments()) {
1768       mergeBlock->addArgument(blockArg.getType());
1769     }
1770 
1771     // If the loop header block has block arguments, make sure the spv.branch op
1772     // matches.
1773     SmallVector<Value, 4> blockArgs;
1774     if (!headerBlock->args_empty())
1775       blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1776 
1777     // The loop entry block should have a unconditional branch jumping to the
1778     // loop header block.
1779     builder.setInsertionPointToEnd(&body.front());
1780     builder.create<spirv::BranchOp>(location, mapper.lookupOrNull(headerBlock),
1781                                     ArrayRef<Value>(blockArgs));
1782   }
1783 
1784   // All the blocks cloned into the SelectionOp/LoopOp's region can now be
1785   // cleaned up.
1786   LLVM_DEBUG(llvm::dbgs() << "[cf] cleaning up blocks after clone\n");
1787   // First we need to drop all operands' references inside all blocks. This is
1788   // needed because we can have blocks referencing SSA values from one another.
1789   for (auto *block : constructBlocks)
1790     block->dropAllReferences();
1791 
1792   // Then erase all old blocks.
1793   for (auto *block : constructBlocks) {
1794     // We've cloned all blocks belonging to this construct into the structured
1795     // control flow op's region. Among these blocks, some may compose another
1796     // selection/loop. If so, they will be recorded within blockMergeInfo.
1797     // We need to update the pointers there to the newly remapped ones so we can
1798     // continue structurizing them later.
1799     // TODO: The asserts in the following assumes input SPIR-V blob
1800     // forms correctly nested selection/loop constructs. We should relax this
1801     // and support error cases better.
1802     auto it = blockMergeInfo.find(block);
1803     if (it != blockMergeInfo.end()) {
1804       Block *newHeader = mapper.lookupOrNull(block);
1805       assert(newHeader && "nested loop header block should be remapped!");
1806 
1807       Block *newContinue = it->second.continueBlock;
1808       if (newContinue) {
1809         newContinue = mapper.lookupOrNull(newContinue);
1810         assert(newContinue && "nested loop continue block should be remapped!");
1811       }
1812 
1813       Block *newMerge = it->second.mergeBlock;
1814       if (Block *mappedTo = mapper.lookupOrNull(newMerge))
1815         newMerge = mappedTo;
1816 
1817       // Keep original location for nested selection/loop ops.
1818       Location loc = it->second.loc;
1819       // The iterator should be erased before adding a new entry into
1820       // blockMergeInfo to avoid iterator invalidation.
1821       blockMergeInfo.erase(it);
1822       blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
1823                                  newContinue);
1824     }
1825 
1826     // The structured selection/loop's entry block does not have arguments.
1827     // If the function's header block is also part of the structured control
1828     // flow, we cannot just simply erase it because it may contain arguments
1829     // matching the function signature and used by the cloned blocks.
1830     if (isFnEntryBlock(block)) {
1831       LLVM_DEBUG(llvm::dbgs() << "[cf] changing entry block " << block
1832                               << " to only contain a spv.Branch op\n");
1833       // Still keep the function entry block for the potential block arguments,
1834       // but replace all ops inside with a branch to the merge block.
1835       block->clear();
1836       builder.setInsertionPointToEnd(block);
1837       builder.create<spirv::BranchOp>(location, mergeBlock);
1838     } else {
1839       LLVM_DEBUG(llvm::dbgs() << "[cf] erasing block " << block << "\n");
1840       block->erase();
1841     }
1842   }
1843 
1844   LLVM_DEBUG(
1845       llvm::dbgs() << "[cf] after structurizing construct with header block "
1846                    << headerBlock << ":\n"
1847                    << *op << '\n');
1848 
1849   return success();
1850 }
1851 
wireUpBlockArgument()1852 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
1853   LLVM_DEBUG(llvm::dbgs() << "[phi] start wiring up block arguments\n");
1854 
1855   OpBuilder::InsertionGuard guard(opBuilder);
1856 
1857   for (const auto &info : blockPhiInfo) {
1858     Block *block = info.first.first;
1859     Block *target = info.first.second;
1860     const BlockPhiInfo &phiInfo = info.second;
1861     LLVM_DEBUG(llvm::dbgs() << "[phi] block " << block << "\n");
1862     LLVM_DEBUG(llvm::dbgs() << "[phi] before creating block argument:\n");
1863     LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs()));
1864     LLVM_DEBUG(llvm::dbgs() << '\n');
1865 
1866     // Set insertion point to before this block's terminator early because we
1867     // may materialize ops via getValue() call.
1868     auto *op = block->getTerminator();
1869     opBuilder.setInsertionPoint(op);
1870 
1871     SmallVector<Value, 4> blockArgs;
1872     blockArgs.reserve(phiInfo.size());
1873     for (uint32_t valueId : phiInfo) {
1874       if (Value value = getValue(valueId)) {
1875         blockArgs.push_back(value);
1876         LLVM_DEBUG(llvm::dbgs() << "[phi] block argument " << value
1877                                 << " id = " << valueId << '\n');
1878       } else {
1879         return emitError(unknownLoc, "OpPhi references undefined value!");
1880       }
1881     }
1882 
1883     if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
1884       // Replace the previous branch op with a new one with block arguments.
1885       opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
1886                                         blockArgs);
1887       branchOp.erase();
1888     } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
1889       assert((branchCondOp.getTrueBlock() == target ||
1890               branchCondOp.getFalseBlock() == target) &&
1891              "expected target to be either the true or false target");
1892       if (target == branchCondOp.trueTarget())
1893         opBuilder.create<spirv::BranchConditionalOp>(
1894             branchCondOp.getLoc(), branchCondOp.condition(), blockArgs,
1895             branchCondOp.getFalseBlockArguments(),
1896             branchCondOp.branch_weightsAttr(), branchCondOp.trueTarget(),
1897             branchCondOp.falseTarget());
1898       else
1899         opBuilder.create<spirv::BranchConditionalOp>(
1900             branchCondOp.getLoc(), branchCondOp.condition(),
1901             branchCondOp.getTrueBlockArguments(), blockArgs,
1902             branchCondOp.branch_weightsAttr(), branchCondOp.getTrueBlock(),
1903             branchCondOp.getFalseBlock());
1904 
1905       branchCondOp.erase();
1906     } else {
1907       return emitError(unknownLoc, "unimplemented terminator for Phi creation");
1908     }
1909 
1910     LLVM_DEBUG(llvm::dbgs() << "[phi] after creating block argument:\n");
1911     LLVM_DEBUG(block->getParentOp()->print(llvm::dbgs()));
1912     LLVM_DEBUG(llvm::dbgs() << '\n');
1913   }
1914   blockPhiInfo.clear();
1915 
1916   LLVM_DEBUG(llvm::dbgs() << "[phi] completed wiring up block arguments\n");
1917   return success();
1918 }
1919 
structurizeControlFlow()1920 LogicalResult spirv::Deserializer::structurizeControlFlow() {
1921   LLVM_DEBUG(llvm::dbgs() << "[cf] start structurizing control flow\n");
1922 
1923   while (!blockMergeInfo.empty()) {
1924     Block *headerBlock = blockMergeInfo.begin()->first;
1925     BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
1926 
1927     LLVM_DEBUG(llvm::dbgs() << "[cf] header block " << headerBlock << ":\n");
1928     LLVM_DEBUG(headerBlock->print(llvm::dbgs()));
1929 
1930     auto *mergeBlock = mergeInfo.mergeBlock;
1931     assert(mergeBlock && "merge block cannot be nullptr");
1932     if (!mergeBlock->args_empty())
1933       return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
1934     LLVM_DEBUG(llvm::dbgs() << "[cf] merge block " << mergeBlock << ":\n");
1935     LLVM_DEBUG(mergeBlock->print(llvm::dbgs()));
1936 
1937     auto *continueBlock = mergeInfo.continueBlock;
1938     if (continueBlock) {
1939       LLVM_DEBUG(llvm::dbgs()
1940                  << "[cf] continue block " << continueBlock << ":\n");
1941       LLVM_DEBUG(continueBlock->print(llvm::dbgs()));
1942     }
1943     // Erase this case before calling into structurizer, who will update
1944     // blockMergeInfo.
1945     blockMergeInfo.erase(blockMergeInfo.begin());
1946     if (failed(ControlFlowStructurizer::structurize(
1947             mergeInfo.loc, mergeInfo.control, blockMergeInfo, headerBlock,
1948             mergeBlock, continueBlock)))
1949       return failure();
1950   }
1951 
1952   LLVM_DEBUG(llvm::dbgs() << "[cf] completed structurizing control flow\n");
1953   return success();
1954 }
1955 
1956 //===----------------------------------------------------------------------===//
1957 // Debug
1958 //===----------------------------------------------------------------------===//
1959 
createFileLineColLoc(OpBuilder opBuilder)1960 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
1961   if (!debugLine)
1962     return unknownLoc;
1963 
1964   auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
1965   if (fileName.empty())
1966     fileName = "<unknown>";
1967   return FileLineColLoc::get(opBuilder.getIdentifier(fileName), debugLine->line,
1968                              debugLine->col);
1969 }
1970 
1971 LogicalResult
processDebugLine(ArrayRef<uint32_t> operands)1972 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
1973   // According to SPIR-V spec:
1974   // "This location information applies to the instructions physically
1975   // following this instruction, up to the first occurrence of any of the
1976   // following: the next end of block, the next OpLine instruction, or the next
1977   // OpNoLine instruction."
1978   if (operands.size() != 3)
1979     return emitError(unknownLoc, "OpLine must have 3 operands");
1980   debugLine = DebugLine(operands[0], operands[1], operands[2]);
1981   return success();
1982 }
1983 
clearDebugLine()1984 LogicalResult spirv::Deserializer::clearDebugLine() {
1985   debugLine = llvm::None;
1986   return success();
1987 }
1988 
1989 LogicalResult
processDebugString(ArrayRef<uint32_t> operands)1990 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
1991   if (operands.size() < 2)
1992     return emitError(unknownLoc, "OpString needs at least 2 operands");
1993 
1994   if (!debugInfoMap.lookup(operands[0]).empty())
1995     return emitError(unknownLoc,
1996                      "duplicate debug string found for result <id> ")
1997            << operands[0];
1998 
1999   unsigned wordIndex = 1;
2000   StringRef debugString = decodeStringLiteral(operands, wordIndex);
2001   if (wordIndex != operands.size())
2002     return emitError(unknownLoc,
2003                      "unexpected trailing words in OpString instruction");
2004 
2005   debugInfoMap[operands[0]] = debugString;
2006   return success();
2007 }
2008