// Copyright (c) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This pass injects code in a graphics shader to implement guarantees
// satisfying Vulkan's robustBufferAccess rules.  Robust access rules permit
// an out-of-bounds access to be redirected to an access of the same type
// (load, store, etc.) but within the same root object.
//
// We assume baseline functionality in Vulkan, i.e. the module uses
// logical addressing mode, without VK_KHR_variable_pointers.
//
//    - Logical addressing mode implies:
//      - Each root pointer (a pointer that exists other than by the
//        execution of a shader instruction) is the result of an OpVariable.
//
//      - Instructions that result in pointers are:
//          OpVariable
//          OpAccessChain
//          OpInBoundsAccessChain
//          OpFunctionParameter
//          OpImageTexelPointer
//          OpCopyObject
//
//      - Instructions that use a pointer are:
//          OpLoad
//          OpStore
//          OpAccessChain
//          OpInBoundsAccessChain
//          OpFunctionCall
//          OpImageTexelPointer
//          OpCopyMemory
//          OpCopyObject
//          all OpAtomic* instructions
//
// We classify pointer-users into:
//  - Accesses:
//    - OpLoad
//    - OpStore
//    - OpAtomic*
//    - OpCopyMemory
//
//  - Address calculations:
//    - OpAccessChain
//    - OpInBoundsAccessChain
//
//  - Pass-through:
//    - OpFunctionCall
//    - OpFunctionParameter
//    - OpCopyObject
//
// The strategy is:
//
//  - Handle only logical addressing mode. In particular, don't handle a module
//    if it uses one of the variable-pointers capabilities.
//
//  - Don't handle modules using capability RuntimeDescriptorArrayEXT.  So the
//    only runtime arrays are those that are the last member in a
//    Block-decorated struct.  This allows us to feasibly/easily compute the
//    length of the runtime array. See below.
//
//  - The memory locations accessed by OpLoad, OpStore, OpCopyMemory, and
//    OpAtomic* are determined by their pointer parameter or parameters.
//    Pointers are always (correctly) typed and so the address and number of
//    consecutive locations are fully determined by the pointer.
//
//  - A pointer value originates as one of a few cases:
//
//    - OpVariable for an interface object or an array of them: image,
//      buffer (UBO or SSBO), sampler, sampled-image, push-constant, input
//      variable, output variable. The execution environment is responsible for
//      allocating the correct amount of storage for these, and for ensuring
//      each resource bound to such a variable is big enough to contain the
//      SPIR-V pointee type of the variable.
//
//    - OpVariable for a non-interface object.  These are variables in
//      Workgroup, Private, and Function storage classes.  The compiler ensures
//      the underlying allocation is big enough to store the entire SPIR-V
//      pointee type of the variable.
//
//    - An OpFunctionParameter. This always maps to a pointer parameter to an
//      OpFunctionCall.
//
//      - In logical addressing mode, these are severely limited:
//        "Any pointer operand to an OpFunctionCall must be:
//          - a memory object declaration, or
//          - a pointer to an element in an array that is a memory object
//          declaration, where the element type is OpTypeSampler or OpTypeImage"
//
//      - This has an important simplifying consequence:
//
//        - When looking for a pointer to the structure containing a runtime
//          array, you begin with a pointer to the runtime array and trace
//          backward in the function.  You never have to trace back beyond
//          your function call boundary.  So you can't take a partial access
//          chain into an SSBO, then pass that pointer into a function.  So
//          we don't resort to using fat pointers to compute array length.
//          We can trace back to a pointer to the containing structure,
//          and use that in an OpArrayLength instruction. (The structure type
//          gives us the member index of the runtime array.)
//
//        - Otherwise, the pointer type fully encodes the range of valid
//          addresses. In particular, the type of a pointer to an aggregate
//          value fully encodes the range of indices when indexing into
//          that aggregate.
//
//    - The pointer is the result of an access chain instruction.  We clamp
//      indices contributing to address calculations.  As noted above, the
//      valid ranges are either bound by the length of a runtime array, or
//      by the type of the base pointer.  The length of a runtime array is
//      the result of an OpArrayLength instruction acting on the pointer of
//      the containing structure as noted above.
//
//      - Access chain indices are always treated as signed, so:
//        - Clamp the upper bound at the signed integer maximum.
//        - Use SClamp for all clamping.
//
//    - TODO(dneto): OpImageTexelPointer:
//      - Clamp coordinate to the image size returned by OpImageQuerySize
//      - If multi-sampled, clamp the sample index to the count returned by
//        OpImageQuerySamples.
//      - If not multi-sampled, set the sample index to 0.
//
//  - Rely on the external validator to check that pointers are only
//    used by the instructions as above.
//
//  - Handles OpTypeRuntimeArray
//    Track pointer back to original resource (pointer to struct), so we can
//    query the runtime array size.
//
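// As a rough illustration of the rewrite (the IDs and the exact instruction
// sequence here are hypothetical; the pass may emit something slightly
// different), an access chain indexing a runtime array such as
//
//     %ptr = OpAccessChain %ptr_float %ssbo %uint_0 %i
//
// is preceded by code that clamps %i against the runtime array's length:
//
//     %glsl    = OpExtInstImport "GLSL.std.450"
//     ...
//     %len     = OpArrayLength %uint %ssbo 0
//     %len_m1  = OpISub %uint %len %uint_1
//     %bound   = OpExtInst %uint %glsl UMin %len_m1 %int_max
//     %clamped = OpExtInst %uint %glsl SClamp %i %uint_0 %bound
//     %ptr     = OpAccessChain %ptr_float %ssbo %uint_0 %clamped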

#include "graphics_robust_access_pass.h"

#include <algorithm>
#include <cstring>
#include <functional>
#include <initializer_list>
#include <limits>
#include <utility>

#include "constants.h"
#include "def_use_manager.h"
#include "function.h"
#include "ir_context.h"
#include "module.h"
#include "pass.h"
#include "source/diagnostic.h"
#include "source/util/make_unique.h"
#include "spirv-tools/libspirv.h"
#include "spirv/unified1/GLSL.std.450.h"
#include "spirv/unified1/spirv.h"
#include "type_manager.h"
#include "types.h"

namespace spvtools {
namespace opt {

using opt::Instruction;
using opt::Operand;
using spvtools::MakeUnique;

GraphicsRobustAccessPass::GraphicsRobustAccessPass() : module_status_() {}

Pass::Status GraphicsRobustAccessPass::Process() {
  module_status_ = PerModuleState();

  ProcessCurrentModule();

  auto result = module_status_.failed
                    ? Status::Failure
                    : (module_status_.modified ? Status::SuccessWithChange
                                               : Status::SuccessWithoutChange);

  return result;
}

spvtools::DiagnosticStream GraphicsRobustAccessPass::Fail() {
  module_status_.failed = true;
  // We don't really have a position, and we'll ignore the result.
  return std::move(
      spvtools::DiagnosticStream({}, consumer(), "", SPV_ERROR_INVALID_BINARY)
      << name() << ": ");
}

spv_result_t GraphicsRobustAccessPass::IsCompatibleModule() {
  auto* feature_mgr = context()->get_feature_mgr();
  if (!feature_mgr->HasCapability(SpvCapabilityShader))
    return Fail() << "Can only process Shader modules";
  if (feature_mgr->HasCapability(SpvCapabilityVariablePointers))
    return Fail() << "Can't process modules with VariablePointers capability";
  if (feature_mgr->HasCapability(SpvCapabilityVariablePointersStorageBuffer))
    return Fail() << "Can't process modules with VariablePointersStorageBuffer "
                     "capability";
  if (feature_mgr->HasCapability(SpvCapabilityRuntimeDescriptorArrayEXT)) {
    // These have a RuntimeArray outside of a Block-decorated struct.  There
    // is no way to compute the array length from within SPIR-V.
    return Fail() << "Can't process modules with RuntimeDescriptorArrayEXT "
                     "capability";
  }

  {
    auto* inst = context()->module()->GetMemoryModel();
    const auto addressing_model = inst->GetSingleWordOperand(0);
    if (addressing_model != SpvAddressingModelLogical)
      return Fail() << "Addressing model must be Logical.  Found "
                    << inst->PrettyPrint();
  }
  return SPV_SUCCESS;
}

spv_result_t GraphicsRobustAccessPass::ProcessCurrentModule() {
  auto err = IsCompatibleModule();
  if (err != SPV_SUCCESS) return err;

  ProcessFunction fn = [this](opt::Function* f) { return ProcessAFunction(f); };
  module_status_.modified |= context()->ProcessReachableCallTree(fn);

  // Need something here.  It's the price we pay for easier failure paths.
  return SPV_SUCCESS;
}

bool GraphicsRobustAccessPass::ProcessAFunction(opt::Function* function) {
  // Ensure that all pointers computed inside a function are within bounds.
  // Find the access chains in this function before trying to modify them.
  std::vector<Instruction*> access_chains;
  std::vector<Instruction*> image_texel_pointers;
  for (auto& block : *function) {
    for (auto& inst : block) {
      switch (inst.opcode()) {
        case SpvOpAccessChain:
        case SpvOpInBoundsAccessChain:
          access_chains.push_back(&inst);
          break;
        case SpvOpImageTexelPointer:
          image_texel_pointers.push_back(&inst);
          break;
        default:
          break;
      }
    }
  }
  for (auto* inst : access_chains) {
    ClampIndicesForAccessChain(inst);
    if (module_status_.failed) return module_status_.modified;
  }

  for (auto* inst : image_texel_pointers) {
    if (SPV_SUCCESS != ClampCoordinateForImageTexelPointer(inst)) break;
  }
  return module_status_.modified;
}

void GraphicsRobustAccessPass::ClampIndicesForAccessChain(
    Instruction* access_chain) {
  Instruction& inst = *access_chain;

  auto* constant_mgr = context()->get_constant_mgr();
  auto* def_use_mgr = context()->get_def_use_mgr();
  auto* type_mgr = context()->get_type_mgr();
  const bool have_int64_cap =
      context()->get_feature_mgr()->HasCapability(SpvCapabilityInt64);

  // Replaces one of the OpAccessChain index operands with a new value.
  // Updates def-use analysis.
  auto replace_index = [&inst, def_use_mgr](uint32_t operand_index,
                                            Instruction* new_value) {
    inst.SetOperand(operand_index, {new_value->result_id()});
    def_use_mgr->AnalyzeInstUse(&inst);
    return SPV_SUCCESS;
  };

  // Replaces one of the OpAccessChain index operands with a clamped value.
  // Replaces the operand at |operand_index| with the value computed from
  // signed_clamp(%old_value, %min_value, %max_value).  It also analyzes
  // the new instruction and records that the module is modified.
  // Assumes %min_value is signed-less-than-or-equal to %max_value.  (All
  // callers use 0 for %min_value.)
  auto clamp_index = [&inst, type_mgr, this, &replace_index](
                         uint32_t operand_index, Instruction* old_value,
                         Instruction* min_value, Instruction* max_value) {
    auto* clamp_inst =
        MakeSClampInst(*type_mgr, old_value, min_value, max_value, &inst);
    return replace_index(operand_index, clamp_inst);
  };

  // Ensures the specified index of access chain |inst| has a value that is
  // at most |count| - 1.  If the index is already a constant value less than
  // |count| then no change is made.
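  // For example, if the pointee type is a 4-element vector (so |count| is 4),
  // the index operand is forced into the range [0, 3].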
  auto clamp_to_literal_count =
      [&inst, this, &constant_mgr, &type_mgr, have_int64_cap, &replace_index,
       &clamp_index](uint32_t operand_index, uint64_t count) -> spv_result_t {
    Instruction* index_inst =
        this->GetDef(inst.GetSingleWordOperand(operand_index));
    const auto* index_type =
        type_mgr->GetType(index_inst->type_id())->AsInteger();
    assert(index_type);
    const auto index_width = index_type->width();

    if (count <= 1) {
      // Replace the index with 0.
      return replace_index(operand_index, GetValueForType(0, index_type));
    }

    uint64_t maxval = count - 1;

    // Compute the bit width of a viable type to hold |maxval|.
    // Look for a bit width, up to 64 bits wide, to fit maxval.
    uint32_t maxval_width = index_width;
    while ((maxval_width < 64) && (0 != (maxval >> maxval_width))) {
      maxval_width *= 2;
    }
    // Determine the type for |maxval|.
    analysis::Integer signed_type_for_query(maxval_width, true);
    auto* maxval_type =
        type_mgr->GetRegisteredType(&signed_type_for_query)->AsInteger();
    // Access chain indices are treated as signed, so limit the maximum value
    // of the index so it will always be positive for a signed clamp operation.
    maxval = std::min(maxval, ((uint64_t(1) << (maxval_width - 1)) - 1));

    if (index_width > 64) {
      return this->Fail() << "Can't handle indices wider than 64 bits, found "
                             "constant index with "
                          << index_width << " bits as index number "
                          << operand_index << " of access chain "
                          << inst.PrettyPrint();
    }

    // Split into two cases: the current index is a constant, or not.

    // If the index is a constant then |index_constant| will not be a null
    // pointer.  (Even if the index is an |OpConstantNull|, |index_constant|
    // will not be a null pointer.)  Since access chain indices must be scalar
    // integers, this can't be a spec constant.
    if (auto* index_constant = constant_mgr->GetConstantFromInst(index_inst)) {
      auto* int_index_constant = index_constant->AsIntConstant();
      int64_t value = 0;
      // OpAccessChain indices are treated as signed.  So get the signed
      // constant value here.
      if (index_width <= 32) {
        value = int64_t(int_index_constant->GetS32BitValue());
      } else if (index_width <= 64) {
        value = int_index_constant->GetS64BitValue();
      }
      if (value < 0) {
        return replace_index(operand_index, GetValueForType(0, index_type));
      } else if (uint64_t(value) <= maxval) {
        // Nothing to do.
        return SPV_SUCCESS;
      } else {
        // Replace with maxval.
        assert(count > 0);  // Already took care of this case above.
        return replace_index(operand_index,
                             GetValueForType(maxval, maxval_type));
      }
    } else {
      // Generate a clamp instruction.
      assert(maxval >= 1);
      assert(index_width <= 64);  // Otherwise, already returned above.
      if (index_width >= 64 && !have_int64_cap) {
        // An inconsistent module.
        return Fail() << "Access chain index is 64 bits wide, but Int64 is "
                         "not declared: "
                      << index_inst->PrettyPrint();
      }
      // Widen the index value if necessary.
      if (maxval_width > index_width) {
        // Find the wider type.  We only need this case if a constant array
        // bound is too big.

        // From how we calculated maxval_width, widening won't require adding
        // the Int64 capability.
        assert(have_int64_cap || maxval_width <= 32);
        if (!have_int64_cap && maxval_width >= 64) {
          // Be defensive, but this shouldn't happen.
          return this->Fail()
                 << "Clamping index would require adding Int64 capability. "
                 << "Can't clamp 32-bit index " << operand_index
                 << " of access chain " << inst.PrettyPrint();
        }
        index_inst = WidenInteger(index_type->IsSigned(), maxval_width,
                                  index_inst, &inst);
      }

      // Finally, clamp the index.
      return clamp_index(operand_index, index_inst,
                         GetValueForType(0, maxval_type),
                         GetValueForType(maxval, maxval_type));
    }
    return SPV_SUCCESS;
  };

  // Ensures the specified index of access chain |inst| has a value that is at
  // most the value of |count_inst| minus 1, where |count_inst| is treated as an
  // unsigned integer. This can log a failure.
  auto clamp_to_count = [&inst, this, &constant_mgr, &clamp_to_literal_count,
                         &clamp_index,
                         &type_mgr](uint32_t operand_index,
                                    Instruction* count_inst) -> spv_result_t {
    Instruction* index_inst =
        this->GetDef(inst.GetSingleWordOperand(operand_index));
    const auto* index_type =
        type_mgr->GetType(index_inst->type_id())->AsInteger();
    const auto* count_type =
        type_mgr->GetType(count_inst->type_id())->AsInteger();
    assert(index_type);
    if (const auto* count_constant =
            constant_mgr->GetConstantFromInst(count_inst)) {
      uint64_t value = 0;
      const auto width = count_constant->type()->AsInteger()->width();
      if (width <= 32) {
        value = count_constant->AsIntConstant()->GetU32BitValue();
      } else if (width <= 64) {
        value = count_constant->AsIntConstant()->GetU64BitValue();
      } else {
        return this->Fail() << "Can't handle counts wider than 64 bits, found "
                               "constant count with "
                            << width << " bits";
      }
      return clamp_to_literal_count(operand_index, value);
    } else {
      // Widen them to the same width.
      const auto index_width = index_type->width();
      const auto count_width = count_type->width();
      const auto target_width = std::max(index_width, count_width);
      // UConvert requires the result type to have 0 signedness.  So enforce
      // that here.
      auto* wider_type = index_width < count_width ? count_type : index_type;
      if (index_type->width() < target_width) {
        // Access chain indices are treated as signed integers.
        index_inst = WidenInteger(true, target_width, index_inst, &inst);
      } else if (count_type->width() < target_width) {
        // Counts (e.g. array sizes) are treated as unsigned.
        count_inst = WidenInteger(false, target_width, count_inst, &inst);
      }
      // Compute count - 1.
      // It doesn't matter if 1 is signed or unsigned.
      auto* one = GetValueForType(1, wider_type);
      auto* count_minus_1 = InsertInst(
          &inst, SpvOpISub, type_mgr->GetId(wider_type), TakeNextId(),
          {{SPV_OPERAND_TYPE_ID, {count_inst->result_id()}},
           {SPV_OPERAND_TYPE_ID, {one->result_id()}}});
      auto* zero = GetValueForType(0, wider_type);
      // Make sure we clamp to an upper bound that is at most the signed max
      // for the target type.
      const uint64_t max_signed_value =
          ((uint64_t(1) << (target_width - 1)) - 1);
      // Use unsigned-min to ensure that the result is always non-negative.
      // That ensures we satisfy the invariant for SClamp, where the "min"
      // argument we give it (zero) is no larger than the third argument.
      auto* upper_bound =
          MakeUMinInst(*type_mgr, count_minus_1,
                       GetValueForType(max_signed_value, wider_type), &inst);
      // Now clamp the index to this upper bound.
      return clamp_index(operand_index, index_inst, zero, upper_bound);
    }
    return SPV_SUCCESS;
  };

  const Instruction* base_inst = GetDef(inst.GetSingleWordInOperand(0));
  const Instruction* base_type = GetDef(base_inst->type_id());
  Instruction* pointee_type = GetDef(base_type->GetSingleWordInOperand(1));

  // Walk the indices from earliest to latest, replacing indices with a
  // clamped value, and updating the pointee_type.  The order matters for
  // the case when we have to compute the length of a runtime array.  In
  // that case the algorithm relies on the fact that the earlier indices
  // have already been clamped.
  const uint32_t num_operands = inst.NumOperands();
  for (uint32_t idx = 3; !module_status_.failed && idx < num_operands; ++idx) {
    const uint32_t index_id = inst.GetSingleWordOperand(idx);
    Instruction* index_inst = GetDef(index_id);

    switch (pointee_type->opcode()) {
      case SpvOpTypeMatrix:  // Use column count
      case SpvOpTypeVector:  // Use component count
      {
        const uint32_t count = pointee_type->GetSingleWordOperand(2);
        clamp_to_literal_count(idx, count);
        pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
      } break;

      case SpvOpTypeArray: {
        // The array length can be a spec constant, so go through the general
        // case.
        Instruction* array_len = GetDef(pointee_type->GetSingleWordOperand(2));
        clamp_to_count(idx, array_len);
        pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
      } break;

      case SpvOpTypeStruct: {
        // SPIR-V requires the index to be an OpConstant.
        // We need to know the index literal value so we can compute the next
        // pointee type.
        if (index_inst->opcode() != SpvOpConstant ||
            !constant_mgr->GetConstantFromInst(index_inst)
                 ->type()
                 ->AsInteger()) {
          Fail() << "Member index into struct is not a constant integer: "
                 << index_inst->PrettyPrint(
                        SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)
                 << "\nin access chain: "
                 << inst.PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
          return;
        }
        const auto num_members = pointee_type->NumInOperands();
        const auto* index_constant =
            constant_mgr->GetConstantFromInst(index_inst);
        // Get the sign-extended value, since the access index is always
        // treated as signed.
        const auto index_value = index_constant->GetSignExtendedValue();
        if (index_value < 0 || index_value >= num_members) {
          Fail() << "Member index " << index_value
                 << " is out of bounds for struct type: "
                 << pointee_type->PrettyPrint(
                        SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES)
                 << "\nin access chain: "
                 << inst.PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
          return;
        }
        pointee_type = GetDef(pointee_type->GetSingleWordInOperand(
            static_cast<uint32_t>(index_value)));
        // No need to clamp this index.  We just checked that it's valid.
      } break;

      case SpvOpTypeRuntimeArray: {
        auto* array_len = MakeRuntimeArrayLengthInst(&inst, idx);
        if (!array_len) {  // We've already signaled an error.
          return;
        }
        clamp_to_count(idx, array_len);
        if (module_status_.failed) return;
        pointee_type = GetDef(pointee_type->GetSingleWordOperand(1));
      } break;

      default:
        Fail() << "Unhandled pointee type for access chain "
               << pointee_type->PrettyPrint(
                      SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
    }
  }
}

uint32_t GraphicsRobustAccessPass::GetGlslInsts() {
  if (module_status_.glsl_insts_id == 0) {
    // This string serves double-duty as raw data for a string and for a vector
    // of 32-bit words.
    const char glsl[] = "GLSL.std.450\0\0\0\0";
    const size_t glsl_str_byte_len = 16;
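    // In assembly form, the import we are looking for (or will create) looks
    // like this (result ID name illustrative):
    //   %glsl = OpExtInstImport "GLSL.std.450"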
    // Use an existing import if we can.
    for (auto& inst : context()->module()->ext_inst_imports()) {
      const auto& name_words = inst.GetInOperand(0).words;
      if (0 == std::strncmp(reinterpret_cast<const char*>(name_words.data()),
                            glsl, glsl_str_byte_len)) {
        module_status_.glsl_insts_id = inst.result_id();
      }
    }
    if (module_status_.glsl_insts_id == 0) {
      // Make a new import instruction.
      module_status_.glsl_insts_id = TakeNextId();
      std::vector<uint32_t> words(glsl_str_byte_len / sizeof(uint32_t));
      std::memcpy(words.data(), glsl, glsl_str_byte_len);
      auto import_inst = MakeUnique<Instruction>(
          context(), SpvOpExtInstImport, 0, module_status_.glsl_insts_id,
          std::initializer_list<Operand>{
              Operand{SPV_OPERAND_TYPE_LITERAL_STRING, std::move(words)}});
      Instruction* inst = import_inst.get();
      context()->module()->AddExtInstImport(std::move(import_inst));
      module_status_.modified = true;
      context()->AnalyzeDefUse(inst);
      // Reanalyze the feature list, since we added an extended instruction
      // set import.
      context()->get_feature_mgr()->Analyze(context()->module());
    }
  }
  return module_status_.glsl_insts_id;
}

opt::Instruction* opt::GraphicsRobustAccessPass::GetValueForType(
    uint64_t value, const analysis::Integer* type) {
  auto* mgr = context()->get_constant_mgr();
  assert(type->width() <= 64);
  std::vector<uint32_t> words;
  words.push_back(uint32_t(value));
  if (type->width() > 32) {
    words.push_back(uint32_t(value >> 32u));
  }
  const auto* constant = mgr->GetConstant(type, words);
  return mgr->GetDefiningInstruction(
      constant, context()->get_type_mgr()->GetTypeInstruction(type));
}

opt::Instruction* opt::GraphicsRobustAccessPass::WidenInteger(
    bool sign_extend, uint32_t bit_width, Instruction* value,
    Instruction* before_inst) {
  analysis::Integer unsigned_type_for_query(bit_width, false);
  auto* type_mgr = context()->get_type_mgr();
  auto* unsigned_type = type_mgr->GetRegisteredType(&unsigned_type_for_query);
  auto type_id = context()->get_type_mgr()->GetId(unsigned_type);
  auto conversion_id = TakeNextId();
  auto* conversion = InsertInst(
      before_inst, (sign_extend ? SpvOpSConvert : SpvOpUConvert), type_id,
      conversion_id, {{SPV_OPERAND_TYPE_ID, {value->result_id()}}});
  return conversion;
}

Instruction* GraphicsRobustAccessPass::MakeUMinInst(
    const analysis::TypeManager& tm, Instruction* x, Instruction* y,
    Instruction* where) {
  // Get IDs of instructions we'll be referencing. Evaluate them before calling
  // the function so we force a deterministic ordering in case both of them need
  // to take a new ID.
  const uint32_t glsl_insts_id = GetGlslInsts();
  uint32_t smin_id = TakeNextId();
  const auto xwidth = tm.GetType(x->type_id())->AsInteger()->width();
  const auto ywidth = tm.GetType(y->type_id())->AsInteger()->width();
  assert(xwidth == ywidth);
  (void)xwidth;
  (void)ywidth;
  auto* smin_inst = InsertInst(
      where, SpvOpExtInst, x->type_id(), smin_id,
      {
          {SPV_OPERAND_TYPE_ID, {glsl_insts_id}},
          {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, {GLSLstd450UMin}},
          {SPV_OPERAND_TYPE_ID, {x->result_id()}},
          {SPV_OPERAND_TYPE_ID, {y->result_id()}},
      });
  return smin_inst;
}

Instruction* GraphicsRobustAccessPass::MakeSClampInst(
    const analysis::TypeManager& tm, Instruction* x, Instruction* min,
    Instruction* max, Instruction* where) {
  // Get IDs of instructions we'll be referencing. Evaluate them before calling
  // the function so we force a deterministic ordering in case both of them need
  // to take a new ID.
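  // The emitted instruction has the form (result and operand ID names here
  // are illustrative):
  //   %clamped = OpExtInst %int %glsl SClamp %x %min %max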
  const uint32_t glsl_insts_id = GetGlslInsts();
  uint32_t clamp_id = TakeNextId();
  const auto xwidth = tm.GetType(x->type_id())->AsInteger()->width();
  const auto minwidth = tm.GetType(min->type_id())->AsInteger()->width();
  const auto maxwidth = tm.GetType(max->type_id())->AsInteger()->width();
  assert(xwidth == minwidth);
  assert(xwidth == maxwidth);
  (void)xwidth;
  (void)minwidth;
  (void)maxwidth;
  auto* clamp_inst = InsertInst(
      where, SpvOpExtInst, x->type_id(), clamp_id,
      {
          {SPV_OPERAND_TYPE_ID, {glsl_insts_id}},
          {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, {GLSLstd450SClamp}},
          {SPV_OPERAND_TYPE_ID, {x->result_id()}},
          {SPV_OPERAND_TYPE_ID, {min->result_id()}},
          {SPV_OPERAND_TYPE_ID, {max->result_id()}},
      });
  return clamp_inst;
}

Instruction* GraphicsRobustAccessPass::MakeRuntimeArrayLengthInst(
    Instruction* access_chain, uint32_t operand_index) {
  // The Index parameter to the access chain at |operand_index| is indexing
  // *into* the runtime-array.  To get the number of elements in the runtime
  // array we need a pointer to the Block-decorated struct that contains the
  // runtime array. So conceptually we have to go 2 steps backward in the
  // access chain.  The two steps backward might force us to traverse backward
  // across multiple dominating instructions.
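  // For example (IDs illustrative), if %ssbo points to a block whose last
  // member (member 0 here) is the runtime array, and the original access
  // chain is
  //   %p = OpAccessChain %ptr_elem %ssbo %uint_0 %i
  // then walking two steps back from the %i index lands on %ssbo, and we can
  // emit
  //   %len = OpArrayLength %uint %ssbo 0
  // just before the original access chain.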
  auto* type_mgr = context()->get_type_mgr();

  // How many access chain indices do we have to unwind to find the pointer
  // to the struct containing the runtime array?
  uint32_t steps_remaining = 2;
  // Find or create an instruction computing the pointer to the structure
  // containing the runtime array.
  // Walk backward through pointer address calculations until we either get
  // to exactly the right base pointer, or to an access chain instruction
  // that we can replicate but truncate to compute the address of the right
  // struct.
  Instruction* current_access_chain = access_chain;
  Instruction* pointer_to_containing_struct = nullptr;
  while (steps_remaining > 0) {
    switch (current_access_chain->opcode()) {
      case SpvOpCopyObject:
        // Whoops. Walk right through this one.
        current_access_chain =
            GetDef(current_access_chain->GetSingleWordInOperand(0));
        break;
      case SpvOpAccessChain:
      case SpvOpInBoundsAccessChain: {
        const int first_index_operand = 3;
        // How many indices in this access chain contribute to getting us
        // to an element in the runtime array?
        const auto num_contributing_indices =
            current_access_chain == access_chain
                ? operand_index - (first_index_operand - 1)
                : current_access_chain->NumInOperands() - 1 /* skip the base */;
        Instruction* base =
            GetDef(current_access_chain->GetSingleWordInOperand(0));
        if (num_contributing_indices == steps_remaining) {
          // The base pointer points to the structure.
          pointer_to_containing_struct = base;
          steps_remaining = 0;
          break;
        } else if (num_contributing_indices < steps_remaining) {
          // Peel off the index and keep going backward.
          steps_remaining -= num_contributing_indices;
          current_access_chain = base;
        } else {
          // This access chain has more indices than needed.  Generate a new
          // access chain instruction, but truncate the list of indices.
          const int base_operand = 2;
          // We'll use the base pointer and the indices up to but not including
          // the one indexing into the runtime array.
          Instruction::OperandList ops;
          // Use the base pointer.
          ops.push_back(current_access_chain->GetOperand(base_operand));
          const uint32_t num_indices_to_keep =
              num_contributing_indices - steps_remaining - 1;
          for (uint32_t i = 0; i <= num_indices_to_keep; i++) {
            ops.push_back(
                current_access_chain->GetOperand(first_index_operand + i));
          }
          // Compute the type of the result of the new access chain.  Start at
          // the base and walk the indices in a forward direction.
          auto* constant_mgr = context()->get_constant_mgr();
          std::vector<uint32_t> indices_for_type;
          for (uint32_t i = 0; i < ops.size() - 1; i++) {
            uint32_t index_for_type_calculation = 0;
            Instruction* index =
                GetDef(current_access_chain->GetSingleWordOperand(
                    first_index_operand + i));
            if (auto* index_constant =
                    constant_mgr->GetConstantFromInst(index)) {
              // We only need 32 bits. For the type calculation, it's sufficient
              // to take the zero-extended value. It only matters for the struct
              // case, and struct member indices are unsigned.
              index_for_type_calculation =
                  uint32_t(index_constant->GetZeroExtendedValue());
            } else {
              // Indexing into a variably-sized thing like an array.  Use 0.
              index_for_type_calculation = 0;
            }
            indices_for_type.push_back(index_for_type_calculation);
          }
          auto* base_ptr_type = type_mgr->GetType(base->type_id())->AsPointer();
          auto* base_pointee_type = base_ptr_type->pointee_type();
          auto* new_access_chain_result_pointee_type =
              type_mgr->GetMemberType(base_pointee_type, indices_for_type);
          const uint32_t new_access_chain_type_id = type_mgr->FindPointerToType(
              type_mgr->GetId(new_access_chain_result_pointee_type),
              base_ptr_type->storage_class());

          // Create the instruction and insert it.
          const auto new_access_chain_id = TakeNextId();
          auto* new_access_chain =
              InsertInst(current_access_chain, current_access_chain->opcode(),
                         new_access_chain_type_id, new_access_chain_id, ops);
          pointer_to_containing_struct = new_access_chain;
          steps_remaining = 0;
          break;
        }
      } break;
      default:
        Fail() << "Unhandled access chain in logical addressing mode passes "
                  "through "
               << current_access_chain->PrettyPrint(
                      SPV_BINARY_TO_TEXT_OPTION_SHOW_BYTE_OFFSET |
                      SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
        return nullptr;
    }
  }
  assert(pointer_to_containing_struct);
  auto* pointee_type =
      type_mgr->GetType(pointer_to_containing_struct->type_id())
          ->AsPointer()
          ->pointee_type();

  auto* struct_type = pointee_type->AsStruct();
  const uint32_t member_index_of_runtime_array =
      uint32_t(struct_type->element_types().size() - 1);
  // Create the length-of-array instruction before the original access chain,
  // but after the generation of the pointer to the struct.
  const auto array_len_id = TakeNextId();
  analysis::Integer uint_type_for_query(32, false);
  auto* uint_type = type_mgr->GetRegisteredType(&uint_type_for_query);
  auto* array_len = InsertInst(
      access_chain, SpvOpArrayLength, type_mgr->GetId(uint_type), array_len_id,
      {{SPV_OPERAND_TYPE_ID, {pointer_to_containing_struct->result_id()}},
       {SPV_OPERAND_TYPE_LITERAL_INTEGER, {member_index_of_runtime_array}}});
  return array_len;
}

spv_result_t GraphicsRobustAccessPass::ClampCoordinateForImageTexelPointer(
    opt::Instruction* image_texel_pointer) {
  // TODO(dneto): Write tests for this code.
  // TODO(dneto): Use signed-clamp
  return SPV_SUCCESS;

  // Example:
  //   %texel_ptr = OpImageTexelPointer %texel_ptr_type %image_ptr %coord
  //   %sample
  //
  // We want to clamp %coord components between vector-0 and the result
  // of OpImageQuerySize acting on the underlying image.  So insert:
  //     %image = OpLoad %image_type %image_ptr
  //     %query_size = OpImageQuerySize %query_size_type %image
  //
  // For a multi-sampled image, %sample is the sample index, and we need
  // to clamp it between zero and the number of samples in the image.
  //     %sample_count = OpImageQuerySamples %uint %image
  //     %max_sample_index = OpISub %uint %sample_count %uint_1
  // For non-multi-sampled images, the sample index must be constant zero.

  auto* def_use_mgr = context()->get_def_use_mgr();
  auto* type_mgr = context()->get_type_mgr();
  auto* constant_mgr = context()->get_constant_mgr();

  auto* image_ptr = GetDef(image_texel_pointer->GetSingleWordInOperand(0));
  auto* image_ptr_type = GetDef(image_ptr->type_id());
  auto image_type_id = image_ptr_type->GetSingleWordInOperand(1);
  auto* image_type = GetDef(image_type_id);
  auto* coord = GetDef(image_texel_pointer->GetSingleWordInOperand(1));
  auto* samples = GetDef(image_texel_pointer->GetSingleWordInOperand(2));

  // We will modify the module, at least by adding image query instructions.
  module_status_.modified = true;

  // Declare the ImageQuery capability if the module doesn't already have it.
  auto* feature_mgr = context()->get_feature_mgr();
  if (!feature_mgr->HasCapability(SpvCapabilityImageQuery)) {
    auto cap = MakeUnique<Instruction>(
        context(), SpvOpCapability, 0, 0,
        std::initializer_list<Operand>{
            {SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityImageQuery}}});
    def_use_mgr->AnalyzeInstDefUse(cap.get());
    context()->AddCapability(std::move(cap));
    feature_mgr->Analyze(context()->module());
  }

  // OpImageTexelPointer is used to translate a coordinate and sample index
  // into an address for use with an atomic operation.  That is, it may only
  // be used with what Vulkan calls a "storage image"
  // (OpTypeImage parameter Sampled=2).
  // Note: A storage image never has a level-of-detail associated with it.

  // Constraints on the sample id:
  //  - Only 2D images can be multi-sampled: OpTypeImage parameter MS=1
  //    only if Dim=2D.
  //  - Non-multi-sampled images (OpTypeImage parameter MS=0) must have a
  //    sample ID of constant 0.

  // The coordinate is treated as unsigned, and should be clamped against the
  // image "size", returned by OpImageQuerySize. (Note: OpImageQuerySizeLod
  // is only usable with a sampled image, i.e. its image type has Sampled=1).

  // Determine the result type for the OpImageQuerySize.
  // For non-arrayed images:
  //   non-Cube:
  //     - Always the same as the coordinate type
  //   Cube:
  //     - Use all but the last component of the coordinate (which is the face
  //       index from 0 to 5).
  // For arrayed images (in Vulkan the Dim is 1D, 2D, or Cube):
  //   non-Cube:
  //     - A vector with the components in the coordinate, and one more for
  //       the layer index.
  //   Cube:
  //     - The same as the coordinate type: 3-element integer vector.
  //     - The third component from the size query is the layer count.
  //     - The third component in the texel pointer calculation is
  //       6 * layer + face, where 0 <= face < 6.
  const auto dim = SpvDim(image_type->GetSingleWordInOperand(1));
  const bool arrayed = image_type->GetSingleWordInOperand(3) == 1;
  const bool multisampled = image_type->GetSingleWordInOperand(4) != 0;
  const auto query_num_components = [dim, arrayed, this]() -> int {
    const int arrayness_bonus = arrayed ? 1 : 0;
    int num_coords = 0;
    switch (dim) {
      case SpvDimBuffer:
      case SpvDim1D:
        num_coords = 1;
        break;
      case SpvDimCube:
        // For cube, we need bounds for x, y, but not face.
      case SpvDimRect:
      case SpvDim2D:
        num_coords = 2;
        break;
      case SpvDim3D:
        num_coords = 3;
        break;
      case SpvDimSubpassData:
      case SpvDimMax:
        return Fail() << "Invalid image dimension for OpImageTexelPointer: "
                      << int(dim);
        break;
    }
    return num_coords + arrayness_bonus;
  }();
  const auto* coord_component_type = [type_mgr, coord]() {
    const analysis::Type* coord_type = type_mgr->GetType(coord->type_id());
    if (auto* vector_type = coord_type->AsVector()) {
      return vector_type->element_type()->AsInteger();
    }
    return coord_type->AsInteger();
  }();
  // For now, only handle 32-bit case for coordinates.
  if (!coord_component_type) {
    return Fail() << " Coordinates for OpImageTexelPointer are not integral: "
                  << image_texel_pointer->PrettyPrint(
                         SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  }
  if (coord_component_type->width() != 32) {
    return Fail() << " Expected OpImageTexelPointer coordinate components to "
                     "be 32-bits wide. They are "
                  << coord_component_type->width() << " bits. "
                  << image_texel_pointer->PrettyPrint(
                         SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
  }
  const auto* query_size_type =
      [type_mgr, coord_component_type,
       query_num_components]() -> const analysis::Type* {
    if (query_num_components == 1) return coord_component_type;
    analysis::Vector proposed(coord_component_type, query_num_components);
    return type_mgr->GetRegisteredType(&proposed);
  }();

  const uint32_t image_id = TakeNextId();
  auto* image =
      InsertInst(image_texel_pointer, SpvOpLoad, image_type_id, image_id,
                 {{SPV_OPERAND_TYPE_ID, {image_ptr->result_id()}}});

  const uint32_t query_size_id = TakeNextId();
  auto* query_size =
      InsertInst(image_texel_pointer, SpvOpImageQuerySize,
                 type_mgr->GetTypeInstruction(query_size_type), query_size_id,
                 {{SPV_OPERAND_TYPE_ID, {image->result_id()}}});

  auto* component_1 = constant_mgr->GetConstant(coord_component_type, {1});
  const uint32_t component_1_id =
      constant_mgr->GetDefiningInstruction(component_1)->result_id();
  auto* component_0 = constant_mgr->GetConstant(coord_component_type, {0});
  const uint32_t component_0_id =
      constant_mgr->GetDefiningInstruction(component_0)->result_id();

  // If the image is a cube array, then the last component of the queried
  // size is the layer count.  In the query, we have to accommodate folding
  // in the face index ranging from 0 through 5. The inclusive upper bound
  // on the third coordinate therefore is multiplied by 6.
  auto* query_size_including_faces = query_size;
  if (arrayed && (dim == SpvDimCube)) {
    // Multiply the last coordinate by 6.
    auto* component_6 = constant_mgr->GetConstant(coord_component_type, {6});
    const uint32_t component_6_id =
        constant_mgr->GetDefiningInstruction(component_6)->result_id();
    assert(query_num_components == 3);
    auto* multiplicand = constant_mgr->GetConstant(
        query_size_type, {component_1_id, component_1_id, component_6_id});
    auto* multiplicand_inst =
        constant_mgr->GetDefiningInstruction(multiplicand);
    const auto query_size_including_faces_id = TakeNextId();
    query_size_including_faces = InsertInst(
        image_texel_pointer, SpvOpIMul,
        type_mgr->GetTypeInstruction(query_size_type),
        query_size_including_faces_id,
        {{SPV_OPERAND_TYPE_ID, {query_size_including_faces->result_id()}},
         {SPV_OPERAND_TYPE_ID, {multiplicand_inst->result_id()}}});
  }

  // Make a coordinate-type with all 1 components.
  auto* coordinate_1 =
      query_num_components == 1
          ? component_1
          : constant_mgr->GetConstant(
                query_size_type,
                std::vector<uint32_t>(query_num_components, component_1_id));
  // Make a coordinate-type with all 0 components.
  auto* coordinate_0 =
      query_num_components == 1
          ? component_0
          : constant_mgr->GetConstant(
                query_size_type,
                std::vector<uint32_t>(query_num_components, component_0_id));

  const uint32_t query_max_including_faces_id = TakeNextId();
  auto* query_max_including_faces = InsertInst(
      image_texel_pointer, SpvOpISub,
      type_mgr->GetTypeInstruction(query_size_type),
      query_max_including_faces_id,
      {{SPV_OPERAND_TYPE_ID, {query_size_including_faces->result_id()}},
       {SPV_OPERAND_TYPE_ID,
        {constant_mgr->GetDefiningInstruction(coordinate_1)->result_id()}}});

  // Clamp the coordinate
  auto* clamp_coord = MakeSClampInst(
      *type_mgr, coord, constant_mgr->GetDefiningInstruction(coordinate_0),
      query_max_including_faces, image_texel_pointer);
  image_texel_pointer->SetInOperand(1, {clamp_coord->result_id()});

  // Clamp the sample index
  if (multisampled) {
    // Get the sample count via OpImageQuerySamples
    const auto query_samples_id = TakeNextId();
    auto* query_samples = InsertInst(
        image_texel_pointer, SpvOpImageQuerySamples,
        constant_mgr->GetDefiningInstruction(component_0)->type_id(),
        query_samples_id, {{SPV_OPERAND_TYPE_ID, {image->result_id()}}});

    const auto max_samples_id = TakeNextId();
    // Compute max_samples = sample_count - 1.
    auto* max_samples = InsertInst(image_texel_pointer, SpvOpISub,
                                   query_samples->type_id(), max_samples_id,
                                   {{SPV_OPERAND_TYPE_ID, {query_samples_id}},
                                    {SPV_OPERAND_TYPE_ID, {component_1_id}}});

    auto* clamp_samples = MakeSClampInst(
        *type_mgr, samples, constant_mgr->GetDefiningInstruction(coordinate_0),
        max_samples, image_texel_pointer);
    image_texel_pointer->SetInOperand(2, {clamp_samples->result_id()});

  } else {
    // Just replace it with 0.  Don't even check what was there before.
    image_texel_pointer->SetInOperand(2, {component_0_id});
  }

  def_use_mgr->AnalyzeInstUse(image_texel_pointer);

  return SPV_SUCCESS;
}

opt::Instruction* GraphicsRobustAccessPass::InsertInst(
    opt::Instruction* where_inst, SpvOp opcode, uint32_t type_id,
    uint32_t result_id, const Instruction::OperandList& operands) {
  module_status_.modified = true;
  auto* result = where_inst->InsertBefore(
      MakeUnique<Instruction>(context(), opcode, type_id, result_id, operands));
  context()->get_def_use_mgr()->AnalyzeInstDefUse(result);
  auto* basic_block = context()->get_instr_block(where_inst);
  context()->set_instr_block(result, basic_block);
  return result;
}

}  // namespace opt
}  // namespace spvtools