1 // Copyright (c) 2018 The Khronos Group Inc.
2 // Copyright (c) 2018 Valve Corporation
3 // Copyright (c) 2018 LunarG Inc.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 #include "inst_bindless_check_pass.h"
18 
namespace {

// Operand indices passed to Instruction::GetSingleWordInOperand, grouped by
// the SPIR-V instruction they apply to. (`constexpr` already gives internal
// linkage here; the anonymous namespace makes it explicit.)

// Image-consuming instructions (sample, gather, fetch, read, write, query):
// the image / sampled-image operand.
constexpr int kSpvImageSampleImageIdInIdx = 0;

// OpSampledImage: image and sampler operands.
constexpr int kSpvSampledImageImageIdInIdx = 0;
constexpr int kSpvSampledImageSamplerIdInIdx = 1;

// OpImage: sampled-image operand.
constexpr int kSpvImageSampledImageIdInIdx = 0;

// OpCopyObject: copied operand.
constexpr int kSpvCopyObjectOperandIdInIdx = 0;

// OpLoad (and OpStore): pointer operand.
constexpr int kSpvLoadPtrIdInIdx = 0;

// OpAccessChain: base pointer and first index.
constexpr int kSpvAccessChainBaseIdInIdx = 0;
constexpr int kSpvAccessChainIndex0IdInIdx = 1;

// OpTypeArray: element type and length constant.
constexpr int kSpvTypeArrayTypeIdInIdx = 0;
constexpr int kSpvTypeArrayLengthIdInIdx = 1;

// OpConstant: first value word.
constexpr int kSpvConstantValueInIdx = 0;

// OpVariable: storage class.
constexpr int kSpvVariableStorageClassInIdx = 0;

// OpTypePointer: pointee type.
constexpr int kSpvTypePtrTypeIdInIdx = 1;

// OpTypeImage: Dim, Depth, Arrayed, MS and Sampled operands.
constexpr int kSpvTypeImageDim = 1;
constexpr int kSpvTypeImageDepth = 2;
constexpr int kSpvTypeImageArrayed = 3;
constexpr int kSpvTypeImageMS = 4;
constexpr int kSpvTypeImageSampled = 5;

}  // anonymous namespace
41 
// USE_ASSERT(x): identical to assert(x) in debug builds. In release builds
// (NDEBUG defined) it still evaluates and discards x, so a variable that is
// only consumed by an assertion does not trigger an unused-variable
// warning/error (observed on Linux).
#ifdef NDEBUG
#define USE_ASSERT(x) (static_cast<void>(x))
#else
#define USE_ASSERT(x) assert(x)
#endif
48 
49 namespace spvtools {
50 namespace opt {
51 
GenDebugReadLength(uint32_t var_id,InstructionBuilder * builder)52 uint32_t InstBindlessCheckPass::GenDebugReadLength(
53     uint32_t var_id, InstructionBuilder* builder) {
54   uint32_t desc_set_idx =
55       var2desc_set_[var_id] + kDebugInputBindlessOffsetLengths;
56   uint32_t desc_set_idx_id = builder->GetUintConstantId(desc_set_idx);
57   uint32_t binding_idx_id = builder->GetUintConstantId(var2binding_[var_id]);
58   return GenDebugDirectRead({desc_set_idx_id, binding_idx_id}, builder);
59 }
60 
GenDebugReadInit(uint32_t var_id,uint32_t desc_idx_id,InstructionBuilder * builder)61 uint32_t InstBindlessCheckPass::GenDebugReadInit(uint32_t var_id,
62                                                  uint32_t desc_idx_id,
63                                                  InstructionBuilder* builder) {
64   uint32_t binding_idx_id = builder->GetUintConstantId(var2binding_[var_id]);
65   uint32_t u_desc_idx_id = GenUintCastCode(desc_idx_id, builder);
66   // If desc index checking is not enabled, we know the offset of initialization
67   // entries is 1, so we can avoid loading this value and just add 1 to the
68   // descriptor set.
69   if (!desc_idx_enabled_) {
70     uint32_t desc_set_idx_id =
71         builder->GetUintConstantId(var2desc_set_[var_id] + 1);
72     return GenDebugDirectRead({desc_set_idx_id, binding_idx_id, u_desc_idx_id},
73                               builder);
74   } else {
75     uint32_t desc_set_base_id =
76         builder->GetUintConstantId(kDebugInputBindlessInitOffset);
77     uint32_t desc_set_idx_id =
78         builder->GetUintConstantId(var2desc_set_[var_id]);
79     return GenDebugDirectRead(
80         {desc_set_base_id, desc_set_idx_id, binding_idx_id, u_desc_idx_id},
81         builder);
82   }
83 }
84 
CloneOriginalImage(uint32_t old_image_id,InstructionBuilder * builder)85 uint32_t InstBindlessCheckPass::CloneOriginalImage(
86     uint32_t old_image_id, InstructionBuilder* builder) {
87   Instruction* new_image_inst;
88   Instruction* old_image_inst = get_def_use_mgr()->GetDef(old_image_id);
89   if (old_image_inst->opcode() == SpvOpLoad) {
90     new_image_inst = builder->AddLoad(
91         old_image_inst->type_id(),
92         old_image_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx));
93   } else if (old_image_inst->opcode() == SpvOp::SpvOpSampledImage) {
94     uint32_t clone_id = CloneOriginalImage(
95         old_image_inst->GetSingleWordInOperand(kSpvSampledImageImageIdInIdx),
96         builder);
97     new_image_inst = builder->AddBinaryOp(
98         old_image_inst->type_id(), SpvOpSampledImage, clone_id,
99         old_image_inst->GetSingleWordInOperand(kSpvSampledImageSamplerIdInIdx));
100   } else if (old_image_inst->opcode() == SpvOp::SpvOpImage) {
101     uint32_t clone_id = CloneOriginalImage(
102         old_image_inst->GetSingleWordInOperand(kSpvImageSampledImageIdInIdx),
103         builder);
104     new_image_inst =
105         builder->AddUnaryOp(old_image_inst->type_id(), SpvOpImage, clone_id);
106   } else {
107     assert(old_image_inst->opcode() == SpvOp::SpvOpCopyObject &&
108            "expecting OpCopyObject");
109     uint32_t clone_id = CloneOriginalImage(
110         old_image_inst->GetSingleWordInOperand(kSpvCopyObjectOperandIdInIdx),
111         builder);
112     // Since we are cloning, no need to create new copy
113     new_image_inst = get_def_use_mgr()->GetDef(clone_id);
114   }
115   uid2offset_[new_image_inst->unique_id()] =
116       uid2offset_[old_image_inst->unique_id()];
117   uint32_t new_image_id = new_image_inst->result_id();
118   get_decoration_mgr()->CloneDecorations(old_image_id, new_image_id);
119   return new_image_id;
120 }
121 
CloneOriginalReference(RefAnalysis * ref,InstructionBuilder * builder)122 uint32_t InstBindlessCheckPass::CloneOriginalReference(
123     RefAnalysis* ref, InstructionBuilder* builder) {
124   // If original is image based, start by cloning descriptor load
125   uint32_t new_image_id = 0;
126   if (ref->desc_load_id != 0) {
127     uint32_t old_image_id =
128         ref->ref_inst->GetSingleWordInOperand(kSpvImageSampleImageIdInIdx);
129     new_image_id = CloneOriginalImage(old_image_id, builder);
130   }
131   // Clone original reference
132   std::unique_ptr<Instruction> new_ref_inst(ref->ref_inst->Clone(context()));
133   uint32_t ref_result_id = ref->ref_inst->result_id();
134   uint32_t new_ref_id = 0;
135   if (ref_result_id != 0) {
136     new_ref_id = TakeNextId();
137     new_ref_inst->SetResultId(new_ref_id);
138   }
139   // Update new ref with new image if created
140   if (new_image_id != 0)
141     new_ref_inst->SetInOperand(kSpvImageSampleImageIdInIdx, {new_image_id});
142   // Register new reference and add to new block
143   Instruction* added_inst = builder->AddInstruction(std::move(new_ref_inst));
144   uid2offset_[added_inst->unique_id()] =
145       uid2offset_[ref->ref_inst->unique_id()];
146   if (new_ref_id != 0)
147     get_decoration_mgr()->CloneDecorations(ref_result_id, new_ref_id);
148   return new_ref_id;
149 }
150 
GetImageId(Instruction * inst)151 uint32_t InstBindlessCheckPass::GetImageId(Instruction* inst) {
152   switch (inst->opcode()) {
153     case SpvOp::SpvOpImageSampleImplicitLod:
154     case SpvOp::SpvOpImageSampleExplicitLod:
155     case SpvOp::SpvOpImageSampleDrefImplicitLod:
156     case SpvOp::SpvOpImageSampleDrefExplicitLod:
157     case SpvOp::SpvOpImageSampleProjImplicitLod:
158     case SpvOp::SpvOpImageSampleProjExplicitLod:
159     case SpvOp::SpvOpImageSampleProjDrefImplicitLod:
160     case SpvOp::SpvOpImageSampleProjDrefExplicitLod:
161     case SpvOp::SpvOpImageGather:
162     case SpvOp::SpvOpImageDrefGather:
163     case SpvOp::SpvOpImageQueryLod:
164     case SpvOp::SpvOpImageSparseSampleImplicitLod:
165     case SpvOp::SpvOpImageSparseSampleExplicitLod:
166     case SpvOp::SpvOpImageSparseSampleDrefImplicitLod:
167     case SpvOp::SpvOpImageSparseSampleDrefExplicitLod:
168     case SpvOp::SpvOpImageSparseSampleProjImplicitLod:
169     case SpvOp::SpvOpImageSparseSampleProjExplicitLod:
170     case SpvOp::SpvOpImageSparseSampleProjDrefImplicitLod:
171     case SpvOp::SpvOpImageSparseSampleProjDrefExplicitLod:
172     case SpvOp::SpvOpImageSparseGather:
173     case SpvOp::SpvOpImageSparseDrefGather:
174     case SpvOp::SpvOpImageFetch:
175     case SpvOp::SpvOpImageRead:
176     case SpvOp::SpvOpImageQueryFormat:
177     case SpvOp::SpvOpImageQueryOrder:
178     case SpvOp::SpvOpImageQuerySizeLod:
179     case SpvOp::SpvOpImageQuerySize:
180     case SpvOp::SpvOpImageQueryLevels:
181     case SpvOp::SpvOpImageQuerySamples:
182     case SpvOp::SpvOpImageSparseFetch:
183     case SpvOp::SpvOpImageSparseRead:
184     case SpvOp::SpvOpImageWrite:
185       return inst->GetSingleWordInOperand(kSpvImageSampleImageIdInIdx);
186     default:
187       break;
188   }
189   return 0;
190 }
191 
GetPointeeTypeInst(Instruction * ptr_inst)192 Instruction* InstBindlessCheckPass::GetPointeeTypeInst(Instruction* ptr_inst) {
193   uint32_t pte_ty_id = GetPointeeTypeId(ptr_inst);
194   return get_def_use_mgr()->GetDef(pte_ty_id);
195 }
196 
// Determines whether ref_inst accesses memory or an image through a
// descriptor and, if so, records the pieces of that access in *ref.
// Returns true for a recognized descriptor reference, false otherwise.
// ref->ref_inst is always set; other fields may be partially written even
// when false is returned.
//
// Two shapes are recognized:
//  - Buffer references: OpLoad/OpStore whose pointer is an OpAccessChain
//    based on an OpVariable in Uniform or StorageBuffer storage class. Sets
//    desc_load_id (0), ptr_id, var_id, strg_class and desc_idx_id (0 when the
//    descriptor is not arrayed). A Uniform block decorated BufferBlock
//    (deprecated SSBO form) is reported as StorageBuffer.
//  - Image references: an instruction recognized by GetImageId whose image
//    operand traces back, through OpSampledImage/OpImage/OpCopyObject, to an
//    OpLoad of an OpVariable or of a single-index OpAccessChain on an
//    OpVariable. Sets image_id, desc_load_id, ptr_id, desc_idx_id (0 when
//    loaded directly from the variable) and var_id. strg_class is not set on
//    this path.
bool InstBindlessCheckPass::AnalyzeDescriptorReference(Instruction* ref_inst,
                                                       RefAnalysis* ref) {
  ref->ref_inst = ref_inst;
  if (ref_inst->opcode() == SpvOpLoad || ref_inst->opcode() == SpvOpStore) {
    ref->desc_load_id = 0;
    // In-operand 0 is the pointer for both OpLoad and OpStore.
    ref->ptr_id = ref_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx);
    Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
    if (ptr_inst->opcode() != SpvOp::SpvOpAccessChain) return false;
    ref->var_id = ptr_inst->GetSingleWordInOperand(kSpvAccessChainBaseIdInIdx);
    Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
    if (var_inst->opcode() != SpvOp::SpvOpVariable) return false;
    uint32_t storage_class =
        var_inst->GetSingleWordInOperand(kSpvVariableStorageClassInIdx);
    // Only uniform and storage buffers are instrumented here.
    switch (storage_class) {
      case SpvStorageClassUniform:
      case SpvStorageClassStorageBuffer:
        break;
      default:
        return false;
        break;
    }
    // Check for deprecated storage block form
    if (storage_class == SpvStorageClassUniform) {
      uint32_t var_ty_id = var_inst->type_id();
      Instruction* var_ty_inst = get_def_use_mgr()->GetDef(var_ty_id);
      uint32_t ptr_ty_id =
          var_ty_inst->GetSingleWordInOperand(kSpvTypePtrTypeIdInIdx);
      Instruction* ptr_ty_inst = get_def_use_mgr()->GetDef(ptr_ty_id);
      SpvOp ptr_ty_op = ptr_ty_inst->opcode();
      // Strip one level of descriptor array to reach the block struct type.
      uint32_t block_ty_id =
          (ptr_ty_op == SpvOpTypeArray || ptr_ty_op == SpvOpTypeRuntimeArray)
              ? ptr_ty_inst->GetSingleWordInOperand(kSpvTypeArrayTypeIdInIdx)
              : ptr_ty_id;
      assert(get_def_use_mgr()->GetDef(block_ty_id)->opcode() ==
                 SpvOpTypeStruct &&
             "unexpected block type");
      bool block_found = get_decoration_mgr()->FindDecoration(
          block_ty_id, SpvDecorationBlock,
          [](const Instruction&) { return true; });
      if (!block_found) {
        // If block decoration not found, verify deprecated form of SSBO
        bool buffer_block_found = get_decoration_mgr()->FindDecoration(
            block_ty_id, SpvDecorationBufferBlock,
            [](const Instruction&) { return true; });
        USE_ASSERT(buffer_block_found && "block decoration not found");
        // Uniform + BufferBlock is treated as a storage buffer from here on.
        storage_class = SpvStorageClassStorageBuffer;
      }
    }
    ref->strg_class = storage_class;
    Instruction* desc_type_inst = GetPointeeTypeInst(var_inst);
    switch (desc_type_inst->opcode()) {
      case SpvOpTypeArray:
      case SpvOpTypeRuntimeArray:
        // A load through a descriptor array will have at least 3 operands. We
        // do not want to instrument loads of descriptors here which are part of
        // an image-based reference.
        if (ptr_inst->NumInOperands() < 3) return false;
        ref->desc_idx_id =
            ptr_inst->GetSingleWordInOperand(kSpvAccessChainIndex0IdInIdx);
        break;
      default:
        // Non-arrayed descriptor: no descriptor index.
        ref->desc_idx_id = 0;
        break;
    }
    return true;
  }
  // Reference is not load or store. If not an image-based reference, return.
  ref->image_id = GetImageId(ref_inst);
  if (ref->image_id == 0) return false;
  // Search for descriptor load
  // Walk back through OpSampledImage / OpImage / OpCopyObject wrappers until
  // reaching an instruction that is none of these.
  uint32_t desc_load_id = ref->image_id;
  Instruction* desc_load_inst;
  for (;;) {
    desc_load_inst = get_def_use_mgr()->GetDef(desc_load_id);
    if (desc_load_inst->opcode() == SpvOp::SpvOpSampledImage)
      desc_load_id =
          desc_load_inst->GetSingleWordInOperand(kSpvSampledImageImageIdInIdx);
    else if (desc_load_inst->opcode() == SpvOp::SpvOpImage)
      desc_load_id =
          desc_load_inst->GetSingleWordInOperand(kSpvImageSampledImageIdInIdx);
    else if (desc_load_inst->opcode() == SpvOp::SpvOpCopyObject)
      desc_load_id =
          desc_load_inst->GetSingleWordInOperand(kSpvCopyObjectOperandIdInIdx);
    else
      break;
  }
  if (desc_load_inst->opcode() != SpvOp::SpvOpLoad) {
    // TODO(greg-lunarg): Handle additional possibilities?
    return false;
  }
  ref->desc_load_id = desc_load_id;
  ref->ptr_id = desc_load_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx);
  Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
  if (ptr_inst->opcode() == SpvOp::SpvOpVariable) {
    // Image loaded straight from a non-arrayed descriptor variable.
    ref->desc_idx_id = 0;
    ref->var_id = ref->ptr_id;
  } else if (ptr_inst->opcode() == SpvOp::SpvOpAccessChain) {
    // Arrayed descriptor: expect exactly base + one descriptor index.
    if (ptr_inst->NumInOperands() != 2) {
      assert(false && "unexpected bindless index number");
      return false;
    }
    ref->desc_idx_id =
        ptr_inst->GetSingleWordInOperand(kSpvAccessChainIndex0IdInIdx);
    ref->var_id = ptr_inst->GetSingleWordInOperand(kSpvAccessChainBaseIdInIdx);
    Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
    if (var_inst->opcode() != SpvOpVariable) {
      assert(false && "unexpected bindless base");
      return false;
    }
  } else {
    // TODO(greg-lunarg): Handle additional possibilities?
    return false;
  }
  return true;
}
312 
FindStride(uint32_t ty_id,uint32_t stride_deco)313 uint32_t InstBindlessCheckPass::FindStride(uint32_t ty_id,
314                                            uint32_t stride_deco) {
315   uint32_t stride = 0xdeadbeef;
316   bool found = get_decoration_mgr()->FindDecoration(
317       ty_id, stride_deco, [&stride](const Instruction& deco_inst) {
318         stride = deco_inst.GetSingleWordInOperand(2u);
319         return true;
320       });
321   USE_ASSERT(found && "stride not found");
322   return stride;
323 }
324 
ByteSize(uint32_t ty_id,uint32_t matrix_stride,bool col_major,bool in_matrix)325 uint32_t InstBindlessCheckPass::ByteSize(uint32_t ty_id, uint32_t matrix_stride,
326                                          bool col_major, bool in_matrix) {
327   analysis::TypeManager* type_mgr = context()->get_type_mgr();
328   const analysis::Type* sz_ty = type_mgr->GetType(ty_id);
329   if (sz_ty->kind() == analysis::Type::kPointer) {
330     // Assuming PhysicalStorageBuffer pointer
331     return 8;
332   }
333   if (sz_ty->kind() == analysis::Type::kMatrix) {
334     assert(matrix_stride != 0 && "missing matrix stride");
335     const analysis::Matrix* m_ty = sz_ty->AsMatrix();
336     if (col_major) {
337       return m_ty->element_count() * matrix_stride;
338     } else {
339       const analysis::Vector* v_ty = m_ty->element_type()->AsVector();
340       return v_ty->element_count() * matrix_stride;
341     }
342   }
343   uint32_t size = 1;
344   if (sz_ty->kind() == analysis::Type::kVector) {
345     const analysis::Vector* v_ty = sz_ty->AsVector();
346     size = v_ty->element_count();
347     const analysis::Type* comp_ty = v_ty->element_type();
348     // if vector in row major matrix, the vector is strided so return the
349     // number of bytes spanned by the vector
350     if (in_matrix && !col_major && matrix_stride > 0) {
351       uint32_t comp_ty_id = type_mgr->GetId(comp_ty);
352       return (size - 1) * matrix_stride + ByteSize(comp_ty_id, 0, false, false);
353     }
354     sz_ty = comp_ty;
355   }
356   switch (sz_ty->kind()) {
357     case analysis::Type::kFloat: {
358       const analysis::Float* f_ty = sz_ty->AsFloat();
359       size *= f_ty->width();
360     } break;
361     case analysis::Type::kInteger: {
362       const analysis::Integer* i_ty = sz_ty->AsInteger();
363       size *= i_ty->width();
364     } break;
365     default: { assert(false && "unexpected type"); } break;
366   }
367   size /= 8;
368   return size;
369 }
370 
// Emits code computing the byte offset, within the buffer bound to
// ref->var_id, of the LAST byte of the object addressed by the access chain
// ref->ptr_id. Each access-chain index past the descriptor (array) index
// contributes an offset of uint type (GetUintId()):
//   array/runtime array: index * ArrayStride
//   matrix:             column index * (MatrixStride if ColMajor, else
//                       component byte size)
//   vector:             index * (MatrixStride inside a row-major matrix,
//                       else component byte size)
//   struct:             the member's Offset decoration
// The sum is then increased by ByteSize(final type) - 1. Returns the id of
// that final sum.
// NOTE(review): if ref->ptr_id has no indices after the descriptor index,
// sum_id stays 0 and the final OpIAdd would use id 0; callers presumably
// only pass non-aggregate (scalar/vector/matrix) pointees, which need at
// least one member index — confirm.
uint32_t InstBindlessCheckPass::GenLastByteIdx(RefAnalysis* ref,
                                               InstructionBuilder* builder) {
  // Find outermost buffer type and its access chain index
  Instruction* var_inst = get_def_use_mgr()->GetDef(ref->var_id);
  Instruction* desc_ty_inst = GetPointeeTypeInst(var_inst);
  uint32_t buff_ty_id;
  // In-operand 0 of the access chain is the base; indices start at 1.
  uint32_t ac_in_idx = 1;
  switch (desc_ty_inst->opcode()) {
    case SpvOpTypeArray:
    case SpvOpTypeRuntimeArray:
      // Arrayed descriptor: skip the descriptor index as well.
      buff_ty_id = desc_ty_inst->GetSingleWordInOperand(0);
      ++ac_in_idx;
      break;
    default:
      assert(desc_ty_inst->opcode() == SpvOpTypeStruct &&
             "unexpected descriptor type");
      buff_ty_id = desc_ty_inst->result_id();
      break;
  }
  // Process remaining access chain indices
  Instruction* ac_inst = get_def_use_mgr()->GetDef(ref->ptr_id);
  uint32_t curr_ty_id = buff_ty_id;
  uint32_t sum_id = 0u;
  // matrix_stride / col_major come from the enclosing struct member's
  // decorations and carry over to nested matrix and vector steps.
  uint32_t matrix_stride = 0u;
  bool col_major = false;
  uint32_t matrix_stride_id = 0u;
  bool in_matrix = false;
  while (ac_in_idx < ac_inst->NumInOperands()) {
    uint32_t curr_idx_id = ac_inst->GetSingleWordInOperand(ac_in_idx);
    Instruction* curr_ty_inst = get_def_use_mgr()->GetDef(curr_ty_id);
    uint32_t curr_offset_id = 0;
    switch (curr_ty_inst->opcode()) {
      case SpvOpTypeArray:
      case SpvOpTypeRuntimeArray: {
        // Get array stride and multiply by current index
        uint32_t arr_stride = FindStride(curr_ty_id, SpvDecorationArrayStride);
        uint32_t arr_stride_id = builder->GetUintConstantId(arr_stride);
        uint32_t curr_idx_32b_id = Gen32BitCvtCode(curr_idx_id, builder);
        Instruction* curr_offset_inst = builder->AddBinaryOp(
            GetUintId(), SpvOpIMul, arr_stride_id, curr_idx_32b_id);
        curr_offset_id = curr_offset_inst->result_id();
        // Get element type for next step
        curr_ty_id = curr_ty_inst->GetSingleWordInOperand(0);
      } break;
      case SpvOpTypeMatrix: {
        assert(matrix_stride != 0 && "missing matrix stride");
        matrix_stride_id = builder->GetUintConstantId(matrix_stride);
        uint32_t vec_ty_id = curr_ty_inst->GetSingleWordInOperand(0);
        // If column major, multiply column index by matrix stride, otherwise
        // by vector component size and save matrix stride for vector (row)
        // index
        uint32_t col_stride_id;
        if (col_major) {
          col_stride_id = matrix_stride_id;
        } else {
          Instruction* vec_ty_inst = get_def_use_mgr()->GetDef(vec_ty_id);
          uint32_t comp_ty_id = vec_ty_inst->GetSingleWordInOperand(0u);
          uint32_t col_stride = ByteSize(comp_ty_id, 0u, false, false);
          col_stride_id = builder->GetUintConstantId(col_stride);
        }
        uint32_t curr_idx_32b_id = Gen32BitCvtCode(curr_idx_id, builder);
        Instruction* curr_offset_inst = builder->AddBinaryOp(
            GetUintId(), SpvOpIMul, col_stride_id, curr_idx_32b_id);
        curr_offset_id = curr_offset_inst->result_id();
        // Get element type for next step
        curr_ty_id = vec_ty_id;
        in_matrix = true;
      } break;
      case SpvOpTypeVector: {
        // If inside a row major matrix type, multiply index by matrix stride,
        // else multiply by component size
        uint32_t comp_ty_id = curr_ty_inst->GetSingleWordInOperand(0u);
        uint32_t curr_idx_32b_id = Gen32BitCvtCode(curr_idx_id, builder);
        if (in_matrix && !col_major) {
          // matrix_stride_id was set by the preceding matrix step.
          Instruction* curr_offset_inst = builder->AddBinaryOp(
              GetUintId(), SpvOpIMul, matrix_stride_id, curr_idx_32b_id);
          curr_offset_id = curr_offset_inst->result_id();
        } else {
          uint32_t comp_ty_sz = ByteSize(comp_ty_id, 0u, false, false);
          uint32_t comp_ty_sz_id = builder->GetUintConstantId(comp_ty_sz);
          Instruction* curr_offset_inst = builder->AddBinaryOp(
              GetUintId(), SpvOpIMul, comp_ty_sz_id, curr_idx_32b_id);
          curr_offset_id = curr_offset_inst->result_id();
        }
        // Get element type for next step
        curr_ty_id = comp_ty_id;
      } break;
      case SpvOpTypeStruct: {
        // Get buffer byte offset for the referenced member
        // Struct indices must be OpConstant; in-operand 0 is the member number.
        Instruction* curr_idx_inst = get_def_use_mgr()->GetDef(curr_idx_id);
        assert(curr_idx_inst->opcode() == SpvOpConstant &&
               "unexpected struct index");
        uint32_t member_idx = curr_idx_inst->GetSingleWordInOperand(0);
        uint32_t member_offset = 0xdeadbeef;
        // OpMemberDecorate in-operands: 0 struct, 1 member, 2 decoration,
        // 3 literal.
        bool found = get_decoration_mgr()->FindDecoration(
            curr_ty_id, SpvDecorationOffset,
            [&member_idx, &member_offset](const Instruction& deco_inst) {
              if (deco_inst.GetSingleWordInOperand(1u) != member_idx)
                return false;
              member_offset = deco_inst.GetSingleWordInOperand(3u);
              return true;
            });
        USE_ASSERT(found && "member offset not found");
        curr_offset_id = builder->GetUintConstantId(member_offset);
        // Look for matrix stride for this member if there is one. The matrix
        // stride is not on the matrix type, but in a OpMemberDecorate on the
        // enclosing struct type at the member index. If none found, reset
        // stride to 0.
        found = get_decoration_mgr()->FindDecoration(
            curr_ty_id, SpvDecorationMatrixStride,
            [&member_idx, &matrix_stride](const Instruction& deco_inst) {
              if (deco_inst.GetSingleWordInOperand(1u) != member_idx)
                return false;
              matrix_stride = deco_inst.GetSingleWordInOperand(3u);
              return true;
            });
        if (!found) matrix_stride = 0;
        // Look for column major decoration
        // Absence of ColMajor is treated as row major.
        found = get_decoration_mgr()->FindDecoration(
            curr_ty_id, SpvDecorationColMajor,
            [&member_idx, &col_major](const Instruction& deco_inst) {
              if (deco_inst.GetSingleWordInOperand(1u) != member_idx)
                return false;
              col_major = true;
              return true;
            });
        if (!found) col_major = false;
        // Get element type for next step
        curr_ty_id = curr_ty_inst->GetSingleWordInOperand(member_idx);
      } break;
      default: { assert(false && "unexpected non-composite type"); } break;
    }
    // Accumulate this step's offset into the running sum.
    if (sum_id == 0)
      sum_id = curr_offset_id;
    else {
      Instruction* sum_inst =
          builder->AddBinaryOp(GetUintId(), SpvOpIAdd, sum_id, curr_offset_id);
      sum_id = sum_inst->result_id();
    }
    ++ac_in_idx;
  }
  // Add in offset of last byte of referenced object
  uint32_t bsize = ByteSize(curr_ty_id, matrix_stride, col_major, in_matrix);
  uint32_t last = bsize - 1;
  uint32_t last_id = builder->GetUintConstantId(last);
  Instruction* sum_inst =
      builder->AddBinaryOp(GetUintId(), SpvOpIAdd, sum_id, last_id);
  return sum_inst->result_id();
}
520 
// Appends to *new_blocks the control flow guarding the reference ref:
//   back block:    OpSelectionMerge + conditional branch on check_id
//   valid block:   a clone of the original reference (CloneOriginalReference)
//   invalid block: a debug stream record via GenDebugStreamWrite
//   merge block:   an OpPhi of the cloned result and a null value, if the
//                  reference produces a result; uses of the original result
//                  are redirected to the phi.
// The original reference instruction is then killed.
// Parameters:
//   check_id   - bool id; true selects the valid path.
//   error_id   - id of the error code written to the debug stream.
//   offset_id  - non-zero for buffer out-of-bounds errors (byte offset);
//                0 for descriptor-level errors.
//   length_id  - length/bound value written to the debug stream.
//   stage_idx  - shader stage index forwarded to GenDebugStreamWrite.
//   ref        - analyzed reference; ref->desc_idx_id is also recorded.
//   new_blocks - block list; its back() is the block to branch from, and the
//                merge block is left at its back() on return.
void InstBindlessCheckPass::GenCheckCode(
    uint32_t check_id, uint32_t error_id, uint32_t offset_id,
    uint32_t length_id, uint32_t stage_idx, RefAnalysis* ref,
    std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  BasicBlock* back_blk_ptr = &*new_blocks->back();
  InstructionBuilder builder(
      context(), back_blk_ptr,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  // Gen conditional branch on check_id. Valid branch generates original
  // reference. Invalid generates debug output and zero result (if needed).
  uint32_t merge_blk_id = TakeNextId();
  uint32_t valid_blk_id = TakeNextId();
  uint32_t invalid_blk_id = TakeNextId();
  std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
  std::unique_ptr<Instruction> valid_label(NewLabel(valid_blk_id));
  std::unique_ptr<Instruction> invalid_label(NewLabel(invalid_blk_id));
  (void)builder.AddConditionalBranch(check_id, valid_blk_id, invalid_blk_id,
                                     merge_blk_id, SpvSelectionControlMaskNone);
  // Gen valid bounds branch
  std::unique_ptr<BasicBlock> new_blk_ptr(
      new BasicBlock(std::move(valid_label)));
  builder.SetInsertPoint(&*new_blk_ptr);
  uint32_t new_ref_id = CloneOriginalReference(ref, &builder);
  (void)builder.AddBranch(merge_blk_id);
  new_blocks->push_back(std::move(new_blk_ptr));
  // Gen invalid block
  new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
  builder.SetInsertPoint(&*new_blk_ptr);
  uint32_t u_index_id = GenUintCastCode(ref->desc_idx_id, &builder);
  if (offset_id != 0) {
    // Buffer OOB
    uint32_t u_offset_id = GenUintCastCode(offset_id, &builder);
    uint32_t u_length_id = GenUintCastCode(length_id, &builder);
    GenDebugStreamWrite(uid2offset_[ref->ref_inst->unique_id()], stage_idx,
                        {error_id, u_index_id, u_offset_id, u_length_id},
                        &builder);
  } else if (buffer_bounds_enabled_ || texel_buffer_enabled_) {
    // Descriptor-level error (descriptor index out of bounds or uninitialized
    // descriptor, per error_id) - Return additional unused zero so all error
    // modes will use same debug stream write function
    uint32_t u_length_id = GenUintCastCode(length_id, &builder);
    GenDebugStreamWrite(
        uid2offset_[ref->ref_inst->unique_id()], stage_idx,
        {error_id, u_index_id, u_length_id, builder.GetUintConstantId(0)},
        &builder);
  } else {
    // Descriptor-level error (descriptor index out of bounds or uninitialized
    // descriptor, per error_id) - Normal error return
    uint32_t u_length_id = GenUintCastCode(length_id, &builder);
    GenDebugStreamWrite(uid2offset_[ref->ref_inst->unique_id()], stage_idx,
                        {error_id, u_index_id, u_length_id}, &builder);
  }
  // Remember last invalid block id
  uint32_t last_invalid_blk_id = new_blk_ptr->GetLabelInst()->result_id();
  // Record the reference's result type; the null (zero) value for the
  // invalid path is produced in the merge-block phi below via GetNullId.
  uint32_t ref_type_id = ref->ref_inst->type_id();
  (void)builder.AddBranch(merge_blk_id);
  new_blocks->push_back(std::move(new_blk_ptr));
  // Gen merge block
  new_blk_ptr.reset(new BasicBlock(std::move(merge_label)));
  builder.SetInsertPoint(&*new_blk_ptr);
  // Gen phi of new reference and zero, if necessary, and replace the
  // result id of the original reference with that of the Phi. Kill original
  // reference.
  if (new_ref_id != 0) {
    Instruction* phi_inst = builder.AddPhi(
        ref_type_id, {new_ref_id, valid_blk_id, GetNullId(ref_type_id),
                      last_invalid_blk_id});
    context()->ReplaceAllUsesWith(ref->ref_inst->result_id(),
                                  phi_inst->result_id());
  }
  new_blocks->push_back(std::move(new_blk_ptr));
  context()->KillInst(ref->ref_inst);
}
593 
// Instruments the reference at ref_inst_itr with a descriptor-array bounds
// check, if it is a reference through an indexed descriptor (via
// AnalyzeDescriptorReference and an OpAccessChain pointer).
//  - Fixed-size arrays (OpTypeArray) are checked against their length
//    constant, unless both index and length are OpConstant and the index is
//    already in range.
//  - Runtime arrays are checked only when desc_idx_enabled_; their length is
//    read from the debug input buffer (GenDebugReadLength).
//  - Anything else is left untouched.
// When instrumenting, the original block is split: preceding instructions go
// into a first new block, GenCheckCode emits the ULessThan test and the
// valid/invalid/merge blocks (error kInstErrorBindlessBounds), and the
// remaining instructions are moved into the final (merge) block. All new
// blocks are appended to *new_blocks.
// NOTE(review): the constant fast path compares only the first value word of
// each OpConstant as unsigned 32-bit values.
void InstBindlessCheckPass::GenDescIdxCheckCode(
    BasicBlock::iterator ref_inst_itr,
    UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
    std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
  // Look for reference through indexed descriptor. If found, analyze and
  // save components. If not, return.
  RefAnalysis ref;
  if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
  Instruction* ptr_inst = get_def_use_mgr()->GetDef(ref.ptr_id);
  if (ptr_inst->opcode() != SpvOp::SpvOpAccessChain) return;
  // If index and bound both compile-time constants and index < bound,
  // return without changing
  Instruction* var_inst = get_def_use_mgr()->GetDef(ref.var_id);
  Instruction* desc_type_inst = GetPointeeTypeInst(var_inst);
  uint32_t length_id = 0;
  if (desc_type_inst->opcode() == SpvOpTypeArray) {
    length_id =
        desc_type_inst->GetSingleWordInOperand(kSpvTypeArrayLengthIdInIdx);
    Instruction* index_inst = get_def_use_mgr()->GetDef(ref.desc_idx_id);
    Instruction* length_inst = get_def_use_mgr()->GetDef(length_id);
    if (index_inst->opcode() == SpvOpConstant &&
        length_inst->opcode() == SpvOpConstant &&
        index_inst->GetSingleWordInOperand(kSpvConstantValueInIdx) <
            length_inst->GetSingleWordInOperand(kSpvConstantValueInIdx))
      return;
  } else if (!desc_idx_enabled_ ||
             desc_type_inst->opcode() != SpvOpTypeRuntimeArray) {
    // Not arrayed, or a runtime array with index checking disabled.
    return;
  }
  // Move original block's preceding instructions into first new block
  std::unique_ptr<BasicBlock> new_blk_ptr;
  MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
  InstructionBuilder builder(
      context(), &*new_blk_ptr,
      IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
  new_blocks->push_back(std::move(new_blk_ptr));
  uint32_t error_id = builder.GetUintConstantId(kInstErrorBindlessBounds);
  // If length id not yet set, descriptor array is runtime size so
  // generate load of length from stage's debug input buffer.
  if (length_id == 0) {
    assert(desc_type_inst->opcode() == SpvOpTypeRuntimeArray &&
           "unexpected bindless type");
    length_id = GenDebugReadLength(ref.var_id, &builder);
  }
  // Generate full runtime bounds test code with true branch
  // being full reference and false branch being debug output and zero
  // for the referenced value.
  uint32_t desc_idx_32b_id = Gen32BitCvtCode(ref.desc_idx_id, &builder);
  uint32_t length_32b_id = Gen32BitCvtCode(length_id, &builder);
  Instruction* ult_inst = builder.AddBinaryOp(GetBoolId(), SpvOpULessThan,
                                              desc_idx_32b_id, length_32b_id);
  // Report the 32-bit index in the debug record.
  ref.desc_idx_id = desc_idx_32b_id;
  GenCheckCode(ult_inst->result_id(), error_id, 0u, length_id, stage_idx, &ref,
               new_blocks);
  // Move original block's remaining code into remainder/merge block and add
  // to new blocks
  BasicBlock* back_blk_ptr = &*new_blocks->back();
  MovePostludeCode(ref_block_itr, back_blk_ptr);
}
653 
GenDescInitCheckCode(BasicBlock::iterator ref_inst_itr,UptrVectorIterator<BasicBlock> ref_block_itr,uint32_t stage_idx,std::vector<std::unique_ptr<BasicBlock>> * new_blocks)654 void InstBindlessCheckPass::GenDescInitCheckCode(
655     BasicBlock::iterator ref_inst_itr,
656     UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
657     std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
658   // Look for reference through descriptor. If not, return.
659   RefAnalysis ref;
660   if (!AnalyzeDescriptorReference(&*ref_inst_itr, &ref)) return;
661   // Determine if we can only do initialization check
662   bool init_check = false;
663   if (ref.desc_load_id != 0 || !buffer_bounds_enabled_) {
664     init_check = true;
665   } else {
666     // For now, only do bounds check for non-aggregate types. Otherwise
667     // just do descriptor initialization check.
668     // TODO(greg-lunarg): Do bounds check for aggregate loads and stores
669     Instruction* ref_ptr_inst = get_def_use_mgr()->GetDef(ref.ptr_id);
670     Instruction* pte_type_inst = GetPointeeTypeInst(ref_ptr_inst);
671     uint32_t pte_type_op = pte_type_inst->opcode();
672     if (pte_type_op == SpvOpTypeArray || pte_type_op == SpvOpTypeRuntimeArray ||
673         pte_type_op == SpvOpTypeStruct)
674       init_check = true;
675   }
676   // If initialization check and not enabled, return
677   if (init_check && !desc_init_enabled_) return;
678   // Move original block's preceding instructions into first new block
679   std::unique_ptr<BasicBlock> new_blk_ptr;
680   MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
681   InstructionBuilder builder(
682       context(), &*new_blk_ptr,
683       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
684   new_blocks->push_back(std::move(new_blk_ptr));
685   // If initialization check, use reference value of zero.
686   // Else use the index of the last byte referenced.
687   uint32_t ref_id = init_check ? builder.GetUintConstantId(0u)
688                                : GenLastByteIdx(&ref, &builder);
689   // Read initialization/bounds from debug input buffer. If index id not yet
690   // set, binding is single descriptor, so set index to constant 0.
691   if (ref.desc_idx_id == 0) ref.desc_idx_id = builder.GetUintConstantId(0u);
692   uint32_t init_id = GenDebugReadInit(ref.var_id, ref.desc_idx_id, &builder);
693   // Generate runtime initialization/bounds test code with true branch
694   // being full reference and false branch being debug output and zero
695   // for the referenced value.
696   Instruction* ult_inst =
697       builder.AddBinaryOp(GetBoolId(), SpvOpULessThan, ref_id, init_id);
698   uint32_t error = init_check ? kInstErrorBindlessUninit
699                               : (ref.strg_class == SpvStorageClassUniform
700                                      ? kInstErrorBuffOOBUniform
701                                      : kInstErrorBuffOOBStorage);
702   uint32_t error_id = builder.GetUintConstantId(error);
703   GenCheckCode(ult_inst->result_id(), error_id, init_check ? 0 : ref_id,
704                init_check ? builder.GetUintConstantId(0u) : init_id, stage_idx,
705                &ref, new_blocks);
706   // Move original block's remaining code into remainder/merge block and add
707   // to new blocks
708   BasicBlock* back_blk_ptr = &*new_blocks->back();
709   MovePostludeCode(ref_block_itr, back_blk_ptr);
710 }
711 
GenTexBuffCheckCode(BasicBlock::iterator ref_inst_itr,UptrVectorIterator<BasicBlock> ref_block_itr,uint32_t stage_idx,std::vector<std::unique_ptr<BasicBlock>> * new_blocks)712 void InstBindlessCheckPass::GenTexBuffCheckCode(
713     BasicBlock::iterator ref_inst_itr,
714     UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
715     std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
716   // Only process OpImageRead and OpImageWrite with no optional operands
717   Instruction* ref_inst = &*ref_inst_itr;
718   SpvOp op = ref_inst->opcode();
719   uint32_t num_in_oprnds = ref_inst->NumInOperands();
720   if (!((op == SpvOpImageRead && num_in_oprnds == 2) ||
721         (op == SpvOpImageFetch && num_in_oprnds == 2) ||
722         (op == SpvOpImageWrite && num_in_oprnds == 3)))
723     return;
724   // Pull components from descriptor reference
725   RefAnalysis ref;
726   if (!AnalyzeDescriptorReference(ref_inst, &ref)) return;
727   // Only process if image is texel buffer
728   Instruction* image_inst = get_def_use_mgr()->GetDef(ref.image_id);
729   uint32_t image_ty_id = image_inst->type_id();
730   Instruction* image_ty_inst = get_def_use_mgr()->GetDef(image_ty_id);
731   if (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDim) != SpvDimBuffer)
732     return;
733   if (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageDepth) != 0) return;
734   if (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageArrayed) != 0) return;
735   if (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageMS) != 0) return;
736   // Enable ImageQuery Capability if not yet enabled
737   if (!get_feature_mgr()->HasCapability(SpvCapabilityImageQuery)) {
738     std::unique_ptr<Instruction> cap_image_query_inst(new Instruction(
739         context(), SpvOpCapability, 0, 0,
740         std::initializer_list<Operand>{
741             {SPV_OPERAND_TYPE_CAPABILITY, {SpvCapabilityImageQuery}}}));
742     get_def_use_mgr()->AnalyzeInstDefUse(&*cap_image_query_inst);
743     context()->AddCapability(std::move(cap_image_query_inst));
744   }
745   // Move original block's preceding instructions into first new block
746   std::unique_ptr<BasicBlock> new_blk_ptr;
747   MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
748   InstructionBuilder builder(
749       context(), &*new_blk_ptr,
750       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
751   new_blocks->push_back(std::move(new_blk_ptr));
752   // Get texel coordinate
753   uint32_t coord_id =
754       GenUintCastCode(ref_inst->GetSingleWordInOperand(1), &builder);
755   // If index id not yet set, binding is single descriptor, so set index to
756   // constant 0.
757   if (ref.desc_idx_id == 0) ref.desc_idx_id = builder.GetUintConstantId(0u);
758   // Get texel buffer size.
759   Instruction* size_inst =
760       builder.AddUnaryOp(GetUintId(), SpvOpImageQuerySize, ref.image_id);
761   uint32_t size_id = size_inst->result_id();
762   // Generate runtime initialization/bounds test code with true branch
763   // being full reference and false branch being debug output and zero
764   // for the referenced value.
765   Instruction* ult_inst =
766       builder.AddBinaryOp(GetBoolId(), SpvOpULessThan, coord_id, size_id);
767   uint32_t error =
768       (image_ty_inst->GetSingleWordInOperand(kSpvTypeImageSampled) == 2)
769           ? kInstErrorBuffOOBStorageTexel
770           : kInstErrorBuffOOBUniformTexel;
771   uint32_t error_id = builder.GetUintConstantId(error);
772   GenCheckCode(ult_inst->result_id(), error_id, coord_id, size_id, stage_idx,
773                &ref, new_blocks);
774   // Move original block's remaining code into remainder/merge block and add
775   // to new blocks
776   BasicBlock* back_blk_ptr = &*new_blocks->back();
777   MovePostludeCode(ref_block_itr, back_blk_ptr);
778 }
779 
InitializeInstBindlessCheck()780 void InstBindlessCheckPass::InitializeInstBindlessCheck() {
781   // Initialize base class
782   InitializeInstrument();
783   // If runtime array length support or buffer bounds checking are enabled,
784   // create variable mappings. Length support is always enabled if descriptor
785   // init check is enabled.
786   if (desc_idx_enabled_ || buffer_bounds_enabled_ || texel_buffer_enabled_)
787     for (auto& anno : get_module()->annotations())
788       if (anno.opcode() == SpvOpDecorate) {
789         if (anno.GetSingleWordInOperand(1u) == SpvDecorationDescriptorSet)
790           var2desc_set_[anno.GetSingleWordInOperand(0u)] =
791               anno.GetSingleWordInOperand(2u);
792         else if (anno.GetSingleWordInOperand(1u) == SpvDecorationBinding)
793           var2binding_[anno.GetSingleWordInOperand(0u)] =
794               anno.GetSingleWordInOperand(2u);
795       }
796 }
797 
ProcessImpl()798 Pass::Status InstBindlessCheckPass::ProcessImpl() {
799   // Perform bindless bounds check on each entry point function in module
800   InstProcessFunction pfn =
801       [this](BasicBlock::iterator ref_inst_itr,
802              UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
803              std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
804         return GenDescIdxCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
805                                    new_blocks);
806       };
807   bool modified = InstProcessEntryPointCallTree(pfn);
808   if (desc_init_enabled_ || buffer_bounds_enabled_) {
809     // Perform descriptor initialization and/or buffer bounds check on each
810     // entry point function in module
811     pfn = [this](BasicBlock::iterator ref_inst_itr,
812                  UptrVectorIterator<BasicBlock> ref_block_itr,
813                  uint32_t stage_idx,
814                  std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
815       return GenDescInitCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
816                                   new_blocks);
817     };
818     modified |= InstProcessEntryPointCallTree(pfn);
819   }
820   if (texel_buffer_enabled_) {
821     // Perform texel buffer bounds check on each entry point function in
822     // module. Generate after descriptor bounds and initialization checks.
823     pfn = [this](BasicBlock::iterator ref_inst_itr,
824                  UptrVectorIterator<BasicBlock> ref_block_itr,
825                  uint32_t stage_idx,
826                  std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
827       return GenTexBuffCheckCode(ref_inst_itr, ref_block_itr, stage_idx,
828                                  new_blocks);
829     };
830     modified |= InstProcessEntryPointCallTree(pfn);
831   }
832   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
833 }
834 
Process()835 Pass::Status InstBindlessCheckPass::Process() {
836   InitializeInstBindlessCheck();
837   return ProcessImpl();
838 }
839 
840 }  // namespace opt
841 }  // namespace spvtools
842