1 // Copyright (c) 2017 Pierre Moreau
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "spirv-tools/linker.hpp"
16 
17 #include <algorithm>
18 #include <cstdio>
19 #include <cstring>
20 #include <iostream>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <utility>
26 #include <vector>
27 
28 #include "source/assembly_grammar.h"
29 #include "source/diagnostic.h"
30 #include "source/opt/build_module.h"
31 #include "source/opt/compact_ids_pass.h"
32 #include "source/opt/decoration_manager.h"
33 #include "source/opt/ir_loader.h"
34 #include "source/opt/pass_manager.h"
35 #include "source/opt/remove_duplicates_pass.h"
36 #include "source/opt/type_manager.h"
37 #include "source/spirv_constant.h"
38 #include "source/spirv_target_env.h"
39 #include "source/util/make_unique.h"
40 #include "spirv-tools/libspirv.hpp"
41 
42 namespace spvtools {
43 namespace {
44 
45 using opt::Instruction;
46 using opt::IRContext;
47 using opt::Module;
48 using opt::PassManager;
49 using opt::RemoveDuplicatesPass;
50 using opt::analysis::DecorationManager;
51 using opt::analysis::DefUseManager;
52 using opt::analysis::Type;
53 using opt::analysis::TypeManager;
54 
55 // Stores various information about an imported or exported symbol.
56 struct LinkageSymbolInfo {
57   SpvId id;          // ID of the symbol
58   SpvId type_id;     // ID of the type of the symbol
59   std::string name;  // unique name defining the symbol and used for matching
60                      // imports and exports together
61   std::vector<SpvId> parameter_ids;  // ID of the parameters of the symbol, if
62                                      // it is a function
63 };
64 struct LinkageEntry {
65   LinkageSymbolInfo imported_symbol;
66   LinkageSymbolInfo exported_symbol;
67 
LinkageEntryspvtools::__anonc9c15a0e0111::LinkageEntry68   LinkageEntry(const LinkageSymbolInfo& import_info,
69                const LinkageSymbolInfo& export_info)
70       : imported_symbol(import_info), exported_symbol(export_info) {}
71 };
72 using LinkageTable = std::vector<LinkageEntry>;
73 
// Shifts the IDs used in each binary of |modules| so that they occupy a
// disjoint range from the other binaries, and computes the new ID bound, which
// is returned in |max_id_bound|.
77 //
78 // Both |modules| and |max_id_bound| should not be null, and |modules| should
79 // not be empty either. Furthermore |modules| should not contain any null
80 // pointers.
81 spv_result_t ShiftIdsInModules(const MessageConsumer& consumer,
82                                std::vector<opt::Module*>* modules,
83                                uint32_t* max_id_bound);
84 
// Generates the header for the linked module and returns it in |header|.
//
// |header| should not be null, |modules| should not be empty and should not
// contain null pointers. |max_id_bound| should be strictly greater than 0.
89 //
90 // TODO(pierremoreau): What to do when binaries use different versions of
91 //                     SPIR-V? For now, use the max of all versions found in
92 //                     the input modules.
93 spv_result_t GenerateHeader(const MessageConsumer& consumer,
94                             const std::vector<opt::Module*>& modules,
95                             uint32_t max_id_bound, opt::ModuleHeader* header);
96 
// Merges all the modules from |in_modules| into a single module owned by
// |linked_context|.
99 //
100 // |linked_context| should not be null.
101 spv_result_t MergeModules(const MessageConsumer& consumer,
102                           const std::vector<Module*>& in_modules,
103                           const AssemblyGrammar& grammar,
104                           IRContext* linked_context);
105 
// Computes all pairs of imports and exports and returns them in
// |linkings_to_do|.
//
// |linkings_to_do| should not be null. Built-in symbols will be ignored.
109 //
110 // TODO(pierremoreau): Linkage attributes applied by a group decoration are
111 //                     currently not handled. (You could have a group being
112 //                     applied to a single ID.)
113 // TODO(pierremoreau): What should be the proper behaviour with built-in
114 //                     symbols?
115 spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
116                                   const opt::IRContext& linked_context,
117                                   const DefUseManager& def_use_manager,
118                                   const DecorationManager& decoration_manager,
119                                   bool allow_partial_linkage,
120                                   LinkageTable* linkings_to_do);
121 
122 // Checks that for each pair of import and export, the import and export have
123 // the same type as well as the same decorations.
124 //
125 // TODO(pierremoreau): Decorations on functions parameters are currently not
126 // checked.
127 spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
128                                             const LinkageTable& linkings_to_do,
129                                             opt::IRContext* context);
130 
// Removes linkage-specific instructions, such as prototypes of imported
// functions, declarations of imported variables, and import (and export if
// necessary) linkage attributes.
//
// |linked_context| and |decoration_manager| should not be null, and the
// 'RemoveDuplicatesPass' should be run first.
137 //
138 // TODO(pierremoreau): Linkage attributes applied by a group decoration are
139 //                     currently not handled. (You could have a group being
140 //                     applied to a single ID.)
141 spv_result_t RemoveLinkageSpecificInstructions(
142     const MessageConsumer& consumer, const LinkerOptions& options,
143     const LinkageTable& linkings_to_do, DecorationManager* decoration_manager,
144     opt::IRContext* linked_context);
145 
// Verifies that the unique ids of each instruction in |linked_context| (i.e.
// the merged module) are truly unique. Does not check the validity of other
// ids.
148 spv_result_t VerifyIds(const MessageConsumer& consumer,
149                        opt::IRContext* linked_context);
150 
ShiftIdsInModules(const MessageConsumer & consumer,std::vector<opt::Module * > * modules,uint32_t * max_id_bound)151 spv_result_t ShiftIdsInModules(const MessageConsumer& consumer,
152                                std::vector<opt::Module*>* modules,
153                                uint32_t* max_id_bound) {
154   spv_position_t position = {};
155 
156   if (modules == nullptr)
157     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
158            << "|modules| of ShiftIdsInModules should not be null.";
159   if (modules->empty())
160     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
161            << "|modules| of ShiftIdsInModules should not be empty.";
162   if (max_id_bound == nullptr)
163     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
164            << "|max_id_bound| of ShiftIdsInModules should not be null.";
165 
166   uint32_t id_bound = modules->front()->IdBound() - 1u;
167   for (auto module_iter = modules->begin() + 1; module_iter != modules->end();
168        ++module_iter) {
169     Module* module = *module_iter;
170     module->ForEachInst([&id_bound](Instruction* insn) {
171       insn->ForEachId([&id_bound](uint32_t* id) { *id += id_bound; });
172     });
173     id_bound += module->IdBound() - 1u;
174     if (id_bound > 0x3FFFFF)
175       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_ID)
176              << "The limit of IDs, 4194303, was exceeded:"
177              << " " << id_bound << " is the current ID bound.";
178 
179     // Invalidate the DefUseManager
180     module->context()->InvalidateAnalyses(opt::IRContext::kAnalysisDefUse);
181   }
182   ++id_bound;
183   if (id_bound > 0x3FFFFF)
184     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_ID)
185            << "The limit of IDs, 4194303, was exceeded:"
186            << " " << id_bound << " is the current ID bound.";
187 
188   *max_id_bound = id_bound;
189 
190   return SPV_SUCCESS;
191 }
192 
GenerateHeader(const MessageConsumer & consumer,const std::vector<opt::Module * > & modules,uint32_t max_id_bound,opt::ModuleHeader * header)193 spv_result_t GenerateHeader(const MessageConsumer& consumer,
194                             const std::vector<opt::Module*>& modules,
195                             uint32_t max_id_bound, opt::ModuleHeader* header) {
196   spv_position_t position = {};
197 
198   if (modules.empty())
199     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
200            << "|modules| of GenerateHeader should not be empty.";
201   if (max_id_bound == 0u)
202     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
203            << "|max_id_bound| of GenerateHeader should not be null.";
204 
205   uint32_t version = 0u;
206   for (const auto& module : modules)
207     version = std::max(version, module->version());
208 
209   header->magic_number = SpvMagicNumber;
210   header->version = version;
211   header->generator = SPV_GENERATOR_WORD(SPV_GENERATOR_KHRONOS_LINKER, 0);
212   header->bound = max_id_bound;
213   header->reserved = 0u;
214 
215   return SPV_SUCCESS;
216 }
217 
MergeModules(const MessageConsumer & consumer,const std::vector<Module * > & input_modules,const AssemblyGrammar & grammar,IRContext * linked_context)218 spv_result_t MergeModules(const MessageConsumer& consumer,
219                           const std::vector<Module*>& input_modules,
220                           const AssemblyGrammar& grammar,
221                           IRContext* linked_context) {
222   spv_position_t position = {};
223 
224   if (linked_context == nullptr)
225     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
226            << "|linked_module| of MergeModules should not be null.";
227   Module* linked_module = linked_context->module();
228 
229   if (input_modules.empty()) return SPV_SUCCESS;
230 
231   for (const auto& module : input_modules)
232     for (const auto& inst : module->capabilities())
233       linked_module->AddCapability(
234           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
235 
236   for (const auto& module : input_modules)
237     for (const auto& inst : module->extensions())
238       linked_module->AddExtension(
239           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
240 
241   for (const auto& module : input_modules)
242     for (const auto& inst : module->ext_inst_imports())
243       linked_module->AddExtInstImport(
244           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
245 
246   do {
247     const Instruction* memory_model_inst = input_modules[0]->GetMemoryModel();
248     if (memory_model_inst == nullptr) break;
249 
250     uint32_t addressing_model = memory_model_inst->GetSingleWordOperand(0u);
251     uint32_t memory_model = memory_model_inst->GetSingleWordOperand(1u);
252     for (const auto& module : input_modules) {
253       memory_model_inst = module->GetMemoryModel();
254       if (memory_model_inst == nullptr) continue;
255 
256       if (addressing_model != memory_model_inst->GetSingleWordOperand(0u)) {
257         spv_operand_desc initial_desc = nullptr, current_desc = nullptr;
258         grammar.lookupOperand(SPV_OPERAND_TYPE_ADDRESSING_MODEL,
259                               addressing_model, &initial_desc);
260         grammar.lookupOperand(SPV_OPERAND_TYPE_ADDRESSING_MODEL,
261                               memory_model_inst->GetSingleWordOperand(0u),
262                               &current_desc);
263         return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
264                << "Conflicting addressing models: " << initial_desc->name
265                << " vs " << current_desc->name << ".";
266       }
267       if (memory_model != memory_model_inst->GetSingleWordOperand(1u)) {
268         spv_operand_desc initial_desc = nullptr, current_desc = nullptr;
269         grammar.lookupOperand(SPV_OPERAND_TYPE_MEMORY_MODEL, memory_model,
270                               &initial_desc);
271         grammar.lookupOperand(SPV_OPERAND_TYPE_MEMORY_MODEL,
272                               memory_model_inst->GetSingleWordOperand(1u),
273                               &current_desc);
274         return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
275                << "Conflicting memory models: " << initial_desc->name << " vs "
276                << current_desc->name << ".";
277       }
278     }
279 
280     if (memory_model_inst != nullptr)
281       linked_module->SetMemoryModel(std::unique_ptr<Instruction>(
282           memory_model_inst->Clone(linked_context)));
283   } while (false);
284 
285   std::vector<std::pair<uint32_t, const char*>> entry_points;
286   for (const auto& module : input_modules)
287     for (const auto& inst : module->entry_points()) {
288       const uint32_t model = inst.GetSingleWordInOperand(0);
289       const char* const name =
290           reinterpret_cast<const char*>(inst.GetInOperand(2).words.data());
291       const auto i = std::find_if(
292           entry_points.begin(), entry_points.end(),
293           [model, name](const std::pair<uint32_t, const char*>& v) {
294             return v.first == model && strcmp(name, v.second) == 0;
295           });
296       if (i != entry_points.end()) {
297         spv_operand_desc desc = nullptr;
298         grammar.lookupOperand(SPV_OPERAND_TYPE_EXECUTION_MODEL, model, &desc);
299         return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
300                << "The entry point \"" << name << "\", with execution model "
301                << desc->name << ", was already defined.";
302       }
303       linked_module->AddEntryPoint(
304           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
305       entry_points.emplace_back(model, name);
306     }
307 
308   for (const auto& module : input_modules)
309     for (const auto& inst : module->execution_modes())
310       linked_module->AddExecutionMode(
311           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
312 
313   for (const auto& module : input_modules)
314     for (const auto& inst : module->debugs1())
315       linked_module->AddDebug1Inst(
316           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
317 
318   for (const auto& module : input_modules)
319     for (const auto& inst : module->debugs2())
320       linked_module->AddDebug2Inst(
321           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
322 
323   for (const auto& module : input_modules)
324     for (const auto& inst : module->debugs3())
325       linked_module->AddDebug3Inst(
326           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
327 
328   for (const auto& module : input_modules)
329     for (const auto& inst : module->ext_inst_debuginfo())
330       linked_module->AddExtInstDebugInfo(
331           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
332 
333   // If the generated module uses SPIR-V 1.1 or higher, add an
334   // OpModuleProcessed instruction about the linking step.
335   if (linked_module->version() >= 0x10100) {
336     const std::string processed_string("Linked by SPIR-V Tools Linker");
337     const auto num_chars = processed_string.size();
338     // Compute num words, accommodate the terminating null character.
339     const auto num_words = (num_chars + 1 + 3) / 4;
340     std::vector<uint32_t> processed_words(num_words, 0u);
341     std::memcpy(processed_words.data(), processed_string.data(), num_chars);
342     linked_module->AddDebug3Inst(std::unique_ptr<Instruction>(
343         new Instruction(linked_context, SpvOpModuleProcessed, 0u, 0u,
344                         {{SPV_OPERAND_TYPE_LITERAL_STRING, processed_words}})));
345   }
346 
347   for (const auto& module : input_modules)
348     for (const auto& inst : module->annotations())
349       linked_module->AddAnnotationInst(
350           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
351 
352   // TODO(pierremoreau): Since the modules have not been validate, should we
353   //                     expect SpvStorageClassFunction variables outside
354   //                     functions?
355   uint32_t num_global_values = 0u;
356   for (const auto& module : input_modules) {
357     for (const auto& inst : module->types_values()) {
358       linked_module->AddType(
359           std::unique_ptr<Instruction>(inst.Clone(linked_context)));
360       num_global_values += inst.opcode() == SpvOpVariable;
361     }
362   }
363   if (num_global_values > 0xFFFF)
364     return DiagnosticStream(position, consumer, "", SPV_ERROR_INTERNAL)
365            << "The limit of global values, 65535, was exceeded;"
366            << " " << num_global_values << " global values were found.";
367 
368   // Process functions and their basic blocks
369   for (const auto& module : input_modules) {
370     for (const auto& func : *module) {
371       std::unique_ptr<opt::Function> cloned_func(func.Clone(linked_context));
372       linked_module->AddFunction(std::move(cloned_func));
373     }
374   }
375 
376   return SPV_SUCCESS;
377 }
378 
GetImportExportPairs(const MessageConsumer & consumer,const opt::IRContext & linked_context,const DefUseManager & def_use_manager,const DecorationManager & decoration_manager,bool allow_partial_linkage,LinkageTable * linkings_to_do)379 spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
380                                   const opt::IRContext& linked_context,
381                                   const DefUseManager& def_use_manager,
382                                   const DecorationManager& decoration_manager,
383                                   bool allow_partial_linkage,
384                                   LinkageTable* linkings_to_do) {
385   spv_position_t position = {};
386 
387   if (linkings_to_do == nullptr)
388     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
389            << "|linkings_to_do| of GetImportExportPairs should not be empty.";
390 
391   std::vector<LinkageSymbolInfo> imports;
392   std::unordered_map<std::string, std::vector<LinkageSymbolInfo>> exports;
393 
394   // Figure out the imports and exports
395   for (const auto& decoration : linked_context.annotations()) {
396     if (decoration.opcode() != SpvOpDecorate ||
397         decoration.GetSingleWordInOperand(1u) != SpvDecorationLinkageAttributes)
398       continue;
399 
400     const SpvId id = decoration.GetSingleWordInOperand(0u);
401     // Ignore if the targeted symbol is a built-in
402     bool is_built_in = false;
403     for (const auto& id_decoration :
404          decoration_manager.GetDecorationsFor(id, false)) {
405       if (id_decoration->GetSingleWordInOperand(1u) == SpvDecorationBuiltIn) {
406         is_built_in = true;
407         break;
408       }
409     }
410     if (is_built_in) {
411       continue;
412     }
413 
414     const uint32_t type = decoration.GetSingleWordInOperand(3u);
415 
416     LinkageSymbolInfo symbol_info;
417     symbol_info.name =
418         reinterpret_cast<const char*>(decoration.GetInOperand(2u).words.data());
419     symbol_info.id = id;
420     symbol_info.type_id = 0u;
421 
422     // Retrieve the type of the current symbol. This information will be used
423     // when checking that the imported and exported symbols have the same
424     // types.
425     const Instruction* def_inst = def_use_manager.GetDef(id);
426     if (def_inst == nullptr)
427       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
428              << "ID " << id << " is never defined:\n";
429 
430     if (def_inst->opcode() == SpvOpVariable) {
431       symbol_info.type_id = def_inst->type_id();
432     } else if (def_inst->opcode() == SpvOpFunction) {
433       symbol_info.type_id = def_inst->GetSingleWordInOperand(1u);
434 
435       // range-based for loop calls begin()/end(), but never cbegin()/cend(),
436       // which will not work here.
437       for (auto func_iter = linked_context.module()->cbegin();
438            func_iter != linked_context.module()->cend(); ++func_iter) {
439         if (func_iter->result_id() != id) continue;
440         func_iter->ForEachParam([&symbol_info](const Instruction* inst) {
441           symbol_info.parameter_ids.push_back(inst->result_id());
442         });
443       }
444     } else {
445       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
446              << "Only global variables and functions can be decorated using"
447              << " LinkageAttributes; " << id << " is neither of them.\n";
448     }
449 
450     if (type == SpvLinkageTypeImport)
451       imports.push_back(symbol_info);
452     else if (type == SpvLinkageTypeExport)
453       exports[symbol_info.name].push_back(symbol_info);
454   }
455 
456   // Find the import/export pairs
457   for (const auto& import : imports) {
458     std::vector<LinkageSymbolInfo> possible_exports;
459     const auto& exp = exports.find(import.name);
460     if (exp != exports.end()) possible_exports = exp->second;
461     if (possible_exports.empty() && !allow_partial_linkage)
462       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
463              << "Unresolved external reference to \"" << import.name << "\".";
464     else if (possible_exports.size() > 1u)
465       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
466              << "Too many external references, " << possible_exports.size()
467              << ", were found for \"" << import.name << "\".";
468 
469     if (!possible_exports.empty())
470       linkings_to_do->emplace_back(import, possible_exports.front());
471   }
472 
473   return SPV_SUCCESS;
474 }
475 
CheckImportExportCompatibility(const MessageConsumer & consumer,const LinkageTable & linkings_to_do,opt::IRContext * context)476 spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
477                                             const LinkageTable& linkings_to_do,
478                                             opt::IRContext* context) {
479   spv_position_t position = {};
480 
481   // Ensure the import and export types are the same.
482   const DecorationManager& decoration_manager = *context->get_decoration_mgr();
483   const TypeManager& type_manager = *context->get_type_mgr();
484   for (const auto& linking_entry : linkings_to_do) {
485     Type* imported_symbol_type =
486         type_manager.GetType(linking_entry.imported_symbol.type_id);
487     Type* exported_symbol_type =
488         type_manager.GetType(linking_entry.exported_symbol.type_id);
489     if (!(*imported_symbol_type == *exported_symbol_type))
490       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
491              << "Type mismatch on symbol \""
492              << linking_entry.imported_symbol.name
493              << "\" between imported variable/function %"
494              << linking_entry.imported_symbol.id
495              << " and exported variable/function %"
496              << linking_entry.exported_symbol.id << ".";
497   }
498 
499   // Ensure the import and export decorations are similar
500   for (const auto& linking_entry : linkings_to_do) {
501     if (!decoration_manager.HaveTheSameDecorations(
502             linking_entry.imported_symbol.id, linking_entry.exported_symbol.id))
503       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
504              << "Decorations mismatch on symbol \""
505              << linking_entry.imported_symbol.name
506              << "\" between imported variable/function %"
507              << linking_entry.imported_symbol.id
508              << " and exported variable/function %"
509              << linking_entry.exported_symbol.id << ".";
510     // TODO(pierremoreau): Decorations on function parameters should probably
511     //                     match, except for FuncParamAttr if I understand the
512     //                     spec correctly.
513     // TODO(pierremoreau): Decorations on the function return type should
514     //                     match, except for FuncParamAttr.
515   }
516 
517   return SPV_SUCCESS;
518 }
519 
RemoveLinkageSpecificInstructions(const MessageConsumer & consumer,const LinkerOptions & options,const LinkageTable & linkings_to_do,DecorationManager * decoration_manager,opt::IRContext * linked_context)520 spv_result_t RemoveLinkageSpecificInstructions(
521     const MessageConsumer& consumer, const LinkerOptions& options,
522     const LinkageTable& linkings_to_do, DecorationManager* decoration_manager,
523     opt::IRContext* linked_context) {
524   spv_position_t position = {};
525 
526   if (decoration_manager == nullptr)
527     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
528            << "|decoration_manager| of RemoveLinkageSpecificInstructions "
529               "should not be empty.";
530   if (linked_context == nullptr)
531     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_DATA)
532            << "|linked_module| of RemoveLinkageSpecificInstructions should not "
533               "be empty.";
534 
535   // TODO(pierremoreau): Remove FuncParamAttr decorations of imported
536   // functions' return type.
537 
538   // Remove prototypes of imported functions
539   for (const auto& linking_entry : linkings_to_do) {
540     for (auto func_iter = linked_context->module()->begin();
541          func_iter != linked_context->module()->end();) {
542       if (func_iter->result_id() == linking_entry.imported_symbol.id)
543         func_iter = func_iter.Erase();
544       else
545         ++func_iter;
546     }
547   }
548 
549   // Remove declarations of imported variables
550   for (const auto& linking_entry : linkings_to_do) {
551     auto next = linked_context->types_values_begin();
552     for (auto inst = next; inst != linked_context->types_values_end();
553          inst = next) {
554       ++next;
555       if (inst->result_id() == linking_entry.imported_symbol.id) {
556         linked_context->KillInst(&*inst);
557       }
558     }
559   }
560 
561   // If partial linkage is allowed, we need an efficient way to check whether
562   // an imported ID had a corresponding export symbol. As uses of the imported
563   // symbol have already been replaced by the exported symbol, use the exported
564   // symbol ID.
565   // TODO(pierremoreau): This will not work if the decoration is applied
566   //                     through a group, but the linker does not support that
567   //                     either.
568   std::unordered_set<SpvId> imports;
569   if (options.GetAllowPartialLinkage()) {
570     imports.reserve(linkings_to_do.size());
571     for (const auto& linking_entry : linkings_to_do)
572       imports.emplace(linking_entry.exported_symbol.id);
573   }
574 
575   // Remove import linkage attributes
576   auto next = linked_context->annotation_begin();
577   for (auto inst = next; inst != linked_context->annotation_end();
578        inst = next) {
579     ++next;
580     // If this is an import annotation:
581     // * if we do not allow partial linkage, remove all import annotations;
582     // * otherwise, remove the annotation only if there was a corresponding
583     //   export.
584     if (inst->opcode() == SpvOpDecorate &&
585         inst->GetSingleWordOperand(1u) == SpvDecorationLinkageAttributes &&
586         inst->GetSingleWordOperand(3u) == SpvLinkageTypeImport &&
587         (!options.GetAllowPartialLinkage() ||
588          imports.find(inst->GetSingleWordOperand(0u)) != imports.end())) {
589       linked_context->KillInst(&*inst);
590     }
591   }
592 
593   // Remove export linkage attributes if making an executable
594   if (!options.GetCreateLibrary()) {
595     next = linked_context->annotation_begin();
596     for (auto inst = next; inst != linked_context->annotation_end();
597          inst = next) {
598       ++next;
599       if (inst->opcode() == SpvOpDecorate &&
600           inst->GetSingleWordOperand(1u) == SpvDecorationLinkageAttributes &&
601           inst->GetSingleWordOperand(3u) == SpvLinkageTypeExport) {
602         linked_context->KillInst(&*inst);
603       }
604     }
605   }
606 
607   // Remove Linkage capability if making an executable and partial linkage is
608   // not allowed
609   if (!options.GetCreateLibrary() && !options.GetAllowPartialLinkage()) {
610     for (auto& inst : linked_context->capabilities())
611       if (inst.GetSingleWordInOperand(0u) == SpvCapabilityLinkage) {
612         linked_context->KillInst(&inst);
613         // The RemoveDuplicatesPass did remove duplicated capabilities, so we
614         // now there aren’t more SpvCapabilityLinkage further down.
615         break;
616       }
617   }
618 
619   return SPV_SUCCESS;
620 }
621 
VerifyIds(const MessageConsumer & consumer,opt::IRContext * linked_context)622 spv_result_t VerifyIds(const MessageConsumer& consumer,
623                        opt::IRContext* linked_context) {
624   std::unordered_set<uint32_t> ids;
625   bool ok = true;
626   linked_context->module()->ForEachInst(
627       [&ids, &ok](const opt::Instruction* inst) {
628         ok &= ids.insert(inst->unique_id()).second;
629       });
630 
631   if (!ok) {
632     consumer(SPV_MSG_INTERNAL_ERROR, "", {}, "Non-unique id in merged module");
633     return SPV_ERROR_INVALID_ID;
634   }
635 
636   return SPV_SUCCESS;
637 }
638 
639 }  // namespace
640 
Link(const Context & context,const std::vector<std::vector<uint32_t>> & binaries,std::vector<uint32_t> * linked_binary,const LinkerOptions & options)641 spv_result_t Link(const Context& context,
642                   const std::vector<std::vector<uint32_t>>& binaries,
643                   std::vector<uint32_t>* linked_binary,
644                   const LinkerOptions& options) {
645   std::vector<const uint32_t*> binary_ptrs;
646   binary_ptrs.reserve(binaries.size());
647   std::vector<size_t> binary_sizes;
648   binary_sizes.reserve(binaries.size());
649 
650   for (const auto& binary : binaries) {
651     binary_ptrs.push_back(binary.data());
652     binary_sizes.push_back(binary.size());
653   }
654 
655   return Link(context, binary_ptrs.data(), binary_sizes.data(), binaries.size(),
656               linked_binary, options);
657 }
658 
Link(const Context & context,const uint32_t * const * binaries,const size_t * binary_sizes,size_t num_binaries,std::vector<uint32_t> * linked_binary,const LinkerOptions & options)659 spv_result_t Link(const Context& context, const uint32_t* const* binaries,
660                   const size_t* binary_sizes, size_t num_binaries,
661                   std::vector<uint32_t>* linked_binary,
662                   const LinkerOptions& options) {
663   spv_position_t position = {};
664   const spv_context& c_context = context.CContext();
665   const MessageConsumer& consumer = c_context->consumer;
666 
667   linked_binary->clear();
668   if (num_binaries == 0u)
669     return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
670            << "No modules were given.";
671 
672   std::vector<std::unique_ptr<IRContext>> ir_contexts;
673   std::vector<Module*> modules;
674   modules.reserve(num_binaries);
675   for (size_t i = 0u; i < num_binaries; ++i) {
676     const uint32_t schema = binaries[i][4u];
677     if (schema != 0u) {
678       position.index = 4u;
679       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
680              << "Schema is non-zero for module " << i + 1 << ".";
681     }
682 
683     std::unique_ptr<IRContext> ir_context = BuildModule(
684         c_context->target_env, consumer, binaries[i], binary_sizes[i]);
685     if (ir_context == nullptr)
686       return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
687              << "Failed to build module " << i + 1 << " out of " << num_binaries
688              << ".";
689     modules.push_back(ir_context->module());
690     ir_contexts.push_back(std::move(ir_context));
691   }
692 
693   // Phase 1: Shift the IDs used in each binary so that they occupy a disjoint
694   //          range from the other binaries, and compute the new ID bound.
695   uint32_t max_id_bound = 0u;
696   spv_result_t res = ShiftIdsInModules(consumer, &modules, &max_id_bound);
697   if (res != SPV_SUCCESS) return res;
698 
699   // Phase 2: Generate the header
700   opt::ModuleHeader header;
701   res = GenerateHeader(consumer, modules, max_id_bound, &header);
702   if (res != SPV_SUCCESS) return res;
703   IRContext linked_context(c_context->target_env, consumer);
704   linked_context.module()->SetHeader(header);
705 
706   // Phase 3: Merge all the binaries into a single one.
707   AssemblyGrammar grammar(c_context);
708   res = MergeModules(consumer, modules, grammar, &linked_context);
709   if (res != SPV_SUCCESS) return res;
710 
711   if (options.GetVerifyIds()) {
712     res = VerifyIds(consumer, &linked_context);
713     if (res != SPV_SUCCESS) return res;
714   }
715 
716   // Phase 4: Find the import/export pairs
717   LinkageTable linkings_to_do;
718   res = GetImportExportPairs(consumer, linked_context,
719                              *linked_context.get_def_use_mgr(),
720                              *linked_context.get_decoration_mgr(),
721                              options.GetAllowPartialLinkage(), &linkings_to_do);
722   if (res != SPV_SUCCESS) return res;
723 
724   // Phase 5: Ensure the import and export have the same types and decorations.
725   res =
726       CheckImportExportCompatibility(consumer, linkings_to_do, &linked_context);
727   if (res != SPV_SUCCESS) return res;
728 
729   // Phase 6: Remove duplicates
730   PassManager manager;
731   manager.SetMessageConsumer(consumer);
732   manager.AddPass<RemoveDuplicatesPass>();
733   opt::Pass::Status pass_res = manager.Run(&linked_context);
734   if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
735 
736   // Phase 7: Remove all names and decorations of import variables/functions
737   for (const auto& linking_entry : linkings_to_do) {
738     linked_context.KillNamesAndDecorates(linking_entry.imported_symbol.id);
739     for (const auto parameter_id :
740          linking_entry.imported_symbol.parameter_ids) {
741       linked_context.KillNamesAndDecorates(parameter_id);
742     }
743   }
744 
745   // Phase 8: Rematch import variables/functions to export variables/functions
746   for (const auto& linking_entry : linkings_to_do) {
747     linked_context.ReplaceAllUsesWith(linking_entry.imported_symbol.id,
748                                       linking_entry.exported_symbol.id);
749   }
750 
751   // Phase 9: Remove linkage specific instructions, such as import/export
752   // attributes, linkage capability, etc. if applicable
753   res = RemoveLinkageSpecificInstructions(consumer, options, linkings_to_do,
754                                           linked_context.get_decoration_mgr(),
755                                           &linked_context);
756   if (res != SPV_SUCCESS) return res;
757 
758   // Phase 10: Compact the IDs used in the module
759   manager.AddPass<opt::CompactIdsPass>();
760   pass_res = manager.Run(&linked_context);
761   if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;
762 
763   // Phase 11: Output the module
764   linked_context.module()->ToBinary(linked_binary, true);
765 
766   return SPV_SUCCESS;
767 }
768 
769 }  // namespace spvtools
770