1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <algorithm>
16 
17 #include "source/opcode.h"
18 #include "source/val/instruction.h"
19 #include "source/val/validate.h"
20 #include "source/val/validation_state.h"
21 
22 namespace spvtools {
23 namespace val {
24 namespace {
25 
26 // Returns true if |a| and |b| are instructions defining pointers that point to
27 // types logically match and the decorations that apply to |b| are a subset
28 // of the decorations that apply to |a|.
DoPointeesLogicallyMatch(val::Instruction * a,val::Instruction * b,ValidationState_t & _)29 bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
30                               ValidationState_t& _) {
31   if (a->opcode() != SpvOpTypePointer || b->opcode() != SpvOpTypePointer) {
32     return false;
33   }
34 
35   const auto& dec_a = _.id_decorations(a->id());
36   const auto& dec_b = _.id_decorations(b->id());
37   for (const auto& dec : dec_b) {
38     if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
39       return false;
40     }
41   }
42 
43   uint32_t a_type = a->GetOperandAs<uint32_t>(2);
44   uint32_t b_type = b->GetOperandAs<uint32_t>(2);
45 
46   if (a_type == b_type) {
47     return true;
48   }
49 
50   Instruction* a_type_inst = _.FindDef(a_type);
51   Instruction* b_type_inst = _.FindDef(b_type);
52 
53   return _.LogicallyMatch(a_type_inst, b_type_inst, true);
54 }
55 
ValidateFunction(ValidationState_t & _,const Instruction * inst)56 spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
57   const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
58   const auto function_type = _.FindDef(function_type_id);
59   if (!function_type || SpvOpTypeFunction != function_type->opcode()) {
60     return _.diag(SPV_ERROR_INVALID_ID, inst)
61            << "OpFunction Function Type <id> '" << _.getIdName(function_type_id)
62            << "' is not a function type.";
63   }
64 
65   const auto return_id = function_type->GetOperandAs<uint32_t>(1);
66   if (return_id != inst->type_id()) {
67     return _.diag(SPV_ERROR_INVALID_ID, inst)
68            << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id())
69            << "' does not match the Function Type's return type <id> '"
70            << _.getIdName(return_id) << "'.";
71   }
72 
73   const std::vector<SpvOp> acceptable = {
74       SpvOpGroupDecorate,
75       SpvOpDecorate,
76       SpvOpEnqueueKernel,
77       SpvOpEntryPoint,
78       SpvOpExecutionMode,
79       SpvOpExecutionModeId,
80       SpvOpFunctionCall,
81       SpvOpGetKernelNDrangeSubGroupCount,
82       SpvOpGetKernelNDrangeMaxSubGroupSize,
83       SpvOpGetKernelWorkGroupSize,
84       SpvOpGetKernelPreferredWorkGroupSizeMultiple,
85       SpvOpGetKernelLocalSizeForSubgroupCount,
86       SpvOpGetKernelMaxNumSubgroups,
87       SpvOpName};
88   for (auto& pair : inst->uses()) {
89     const auto* use = pair.first;
90     if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
91             acceptable.end() &&
92         !use->IsNonSemantic() && !use->IsDebugInfo()) {
93       return _.diag(SPV_ERROR_INVALID_ID, use)
94              << "Invalid use of function result id " << _.getIdName(inst->id())
95              << ".";
96     }
97   }
98 
99   return SPV_SUCCESS;
100 }
101 
ValidateFunctionParameter(ValidationState_t & _,const Instruction * inst)102 spv_result_t ValidateFunctionParameter(ValidationState_t& _,
103                                        const Instruction* inst) {
104   // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
105   size_t param_index = 0;
106   size_t inst_num = inst->LineNum() - 1;
107   if (inst_num == 0) {
108     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
109            << "Function parameter cannot be the first instruction.";
110   }
111 
112   auto func_inst = &_.ordered_instructions()[inst_num];
113   while (--inst_num) {
114     func_inst = &_.ordered_instructions()[inst_num];
115     if (func_inst->opcode() == SpvOpFunction) {
116       break;
117     } else if (func_inst->opcode() == SpvOpFunctionParameter) {
118       ++param_index;
119     }
120   }
121 
122   if (func_inst->opcode() != SpvOpFunction) {
123     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
124            << "Function parameter must be preceded by a function.";
125   }
126 
127   const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
128   const auto function_type = _.FindDef(function_type_id);
129   if (!function_type) {
130     return _.diag(SPV_ERROR_INVALID_ID, func_inst)
131            << "Missing function type definition.";
132   }
133   if (param_index >= function_type->words().size() - 3) {
134     return _.diag(SPV_ERROR_INVALID_ID, inst)
135            << "Too many OpFunctionParameters for " << func_inst->id()
136            << ": expected " << function_type->words().size() - 3
137            << " based on the function's type";
138   }
139 
140   const auto param_type =
141       _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
142   if (!param_type || inst->type_id() != param_type->id()) {
143     return _.diag(SPV_ERROR_INVALID_ID, inst)
144            << "OpFunctionParameter Result Type <id> '"
145            << _.getIdName(inst->type_id())
146            << "' does not match the OpTypeFunction parameter "
147               "type of the same index.";
148   }
149 
150   // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased,
151   // RestrictPointerEXT, or AliasedPointerEXT.
152   auto param_nonarray_type_id = param_type->id();
153   while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) {
154     param_nonarray_type_id =
155         _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
156   }
157   if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) {
158     auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
159     if (param_nonarray_type->GetOperandAs<uint32_t>(1u) ==
160         SpvStorageClassPhysicalStorageBufferEXT) {
161       // check for Aliased or Restrict
162       const auto& decorations = _.id_decorations(inst->id());
163 
164       bool foundAliased = std::any_of(
165           decorations.begin(), decorations.end(), [](const Decoration& d) {
166             return SpvDecorationAliased == d.dec_type();
167           });
168 
169       bool foundRestrict = std::any_of(
170           decorations.begin(), decorations.end(), [](const Decoration& d) {
171             return SpvDecorationRestrict == d.dec_type();
172           });
173 
174       if (!foundAliased && !foundRestrict) {
175         return _.diag(SPV_ERROR_INVALID_ID, inst)
176                << "OpFunctionParameter " << inst->id()
177                << ": expected Aliased or Restrict for PhysicalStorageBufferEXT "
178                   "pointer.";
179       }
180       if (foundAliased && foundRestrict) {
181         return _.diag(SPV_ERROR_INVALID_ID, inst)
182                << "OpFunctionParameter " << inst->id()
183                << ": can't specify both Aliased and Restrict for "
184                   "PhysicalStorageBufferEXT pointer.";
185       }
186     } else {
187       const auto pointee_type_id =
188           param_nonarray_type->GetOperandAs<uint32_t>(2);
189       const auto pointee_type = _.FindDef(pointee_type_id);
190       if (SpvOpTypePointer == pointee_type->opcode() &&
191           pointee_type->GetOperandAs<uint32_t>(1u) ==
192               SpvStorageClassPhysicalStorageBufferEXT) {
193         // check for AliasedPointerEXT/RestrictPointerEXT
194         const auto& decorations = _.id_decorations(inst->id());
195 
196         bool foundAliased = std::any_of(
197             decorations.begin(), decorations.end(), [](const Decoration& d) {
198               return SpvDecorationAliasedPointerEXT == d.dec_type();
199             });
200 
201         bool foundRestrict = std::any_of(
202             decorations.begin(), decorations.end(), [](const Decoration& d) {
203               return SpvDecorationRestrictPointerEXT == d.dec_type();
204             });
205 
206         if (!foundAliased && !foundRestrict) {
207           return _.diag(SPV_ERROR_INVALID_ID, inst)
208                  << "OpFunctionParameter " << inst->id()
209                  << ": expected AliasedPointerEXT or RestrictPointerEXT for "
210                     "PhysicalStorageBufferEXT pointer.";
211         }
212         if (foundAliased && foundRestrict) {
213           return _.diag(SPV_ERROR_INVALID_ID, inst)
214                  << "OpFunctionParameter " << inst->id()
215                  << ": can't specify both AliasedPointerEXT and "
216                     "RestrictPointerEXT for PhysicalStorageBufferEXT pointer.";
217         }
218       }
219     }
220   }
221 
222   return SPV_SUCCESS;
223 }
224 
ValidateFunctionCall(ValidationState_t & _,const Instruction * inst)225 spv_result_t ValidateFunctionCall(ValidationState_t& _,
226                                   const Instruction* inst) {
227   const auto function_id = inst->GetOperandAs<uint32_t>(2);
228   const auto function = _.FindDef(function_id);
229   if (!function || SpvOpFunction != function->opcode()) {
230     return _.diag(SPV_ERROR_INVALID_ID, inst)
231            << "OpFunctionCall Function <id> '" << _.getIdName(function_id)
232            << "' is not a function.";
233   }
234 
235   auto return_type = _.FindDef(function->type_id());
236   if (!return_type || return_type->id() != inst->type_id()) {
237     return _.diag(SPV_ERROR_INVALID_ID, inst)
238            << "OpFunctionCall Result Type <id> '"
239            << _.getIdName(inst->type_id())
240            << "'s type does not match Function <id> '"
241            << _.getIdName(return_type->id()) << "'s return type.";
242   }
243 
244   const auto function_type_id = function->GetOperandAs<uint32_t>(3);
245   const auto function_type = _.FindDef(function_type_id);
246   if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
247     return _.diag(SPV_ERROR_INVALID_ID, inst)
248            << "Missing function type definition.";
249   }
250 
251   const auto function_call_arg_count = inst->words().size() - 4;
252   const auto function_param_count = function_type->words().size() - 3;
253   if (function_param_count != function_call_arg_count) {
254     return _.diag(SPV_ERROR_INVALID_ID, inst)
255            << "OpFunctionCall Function <id>'s parameter count does not match "
256               "the argument count.";
257   }
258 
259   for (size_t argument_index = 3, param_index = 2;
260        argument_index < inst->operands().size();
261        argument_index++, param_index++) {
262     const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
263     const auto argument = _.FindDef(argument_id);
264     if (!argument) {
265       return _.diag(SPV_ERROR_INVALID_ID, inst)
266              << "Missing argument " << argument_index - 3 << " definition.";
267     }
268 
269     const auto argument_type = _.FindDef(argument->type_id());
270     if (!argument_type) {
271       return _.diag(SPV_ERROR_INVALID_ID, inst)
272              << "Missing argument " << argument_index - 3
273              << " type definition.";
274     }
275 
276     const auto parameter_type_id =
277         function_type->GetOperandAs<uint32_t>(param_index);
278     const auto parameter_type = _.FindDef(parameter_type_id);
279     if (!parameter_type || argument_type->id() != parameter_type->id()) {
280       if (!_.options()->before_hlsl_legalization ||
281           !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) {
282         return _.diag(SPV_ERROR_INVALID_ID, inst)
283                << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id)
284                << "'s type does not match Function <id> '"
285                << _.getIdName(parameter_type_id) << "'s parameter type.";
286       }
287     }
288 
289     if (_.addressing_model() == SpvAddressingModelLogical) {
290       if (parameter_type->opcode() == SpvOpTypePointer &&
291           !_.options()->relax_logical_pointer) {
292         SpvStorageClass sc = parameter_type->GetOperandAs<SpvStorageClass>(1u);
293         // Validate which storage classes can be pointer operands.
294         switch (sc) {
295           case SpvStorageClassUniformConstant:
296           case SpvStorageClassFunction:
297           case SpvStorageClassPrivate:
298           case SpvStorageClassWorkgroup:
299           case SpvStorageClassAtomicCounter:
300             // These are always allowed.
301             break;
302           case SpvStorageClassStorageBuffer:
303             if (!_.features().variable_pointers_storage_buffer) {
304               return _.diag(SPV_ERROR_INVALID_ID, inst)
305                      << "StorageBuffer pointer operand "
306                      << _.getIdName(argument_id)
307                      << " requires a variable pointers capability";
308             }
309             break;
310           default:
311             return _.diag(SPV_ERROR_INVALID_ID, inst)
312                    << "Invalid storage class for pointer operand "
313                    << _.getIdName(argument_id);
314         }
315 
316         // Validate memory object declaration requirements.
317         if (argument->opcode() != SpvOpVariable &&
318             argument->opcode() != SpvOpFunctionParameter) {
319           const bool ssbo_vptr =
320               _.features().variable_pointers_storage_buffer &&
321               sc == SpvStorageClassStorageBuffer;
322           const bool wg_vptr =
323               _.features().variable_pointers && sc == SpvStorageClassWorkgroup;
324           const bool uc_ptr = sc == SpvStorageClassUniformConstant;
325           if (!ssbo_vptr && !wg_vptr && !uc_ptr) {
326             return _.diag(SPV_ERROR_INVALID_ID, inst)
327                    << "Pointer operand " << _.getIdName(argument_id)
328                    << " must be a memory object declaration";
329           }
330         }
331       }
332     }
333   }
334   return SPV_SUCCESS;
335 }
336 
337 }  // namespace
338 
FunctionPass(ValidationState_t & _,const Instruction * inst)339 spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
340   switch (inst->opcode()) {
341     case SpvOpFunction:
342       if (auto error = ValidateFunction(_, inst)) return error;
343       break;
344     case SpvOpFunctionParameter:
345       if (auto error = ValidateFunctionParameter(_, inst)) return error;
346       break;
347     case SpvOpFunctionCall:
348       if (auto error = ValidateFunctionCall(_, inst)) return error;
349       break;
350     default:
351       break;
352   }
353 
354   return SPV_SUCCESS;
355 }
356 
357 }  // namespace val
358 }  // namespace spvtools
359