1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <algorithm>
16
17 #include "source/opcode.h"
18 #include "source/val/instruction.h"
19 #include "source/val/validate.h"
20 #include "source/val/validation_state.h"
21
22 namespace spvtools {
23 namespace val {
24 namespace {
25
26 // Returns true if |a| and |b| are instructions defining pointers that point to
27 // types logically match and the decorations that apply to |b| are a subset
28 // of the decorations that apply to |a|.
DoPointeesLogicallyMatch(val::Instruction * a,val::Instruction * b,ValidationState_t & _)29 bool DoPointeesLogicallyMatch(val::Instruction* a, val::Instruction* b,
30 ValidationState_t& _) {
31 if (a->opcode() != SpvOpTypePointer || b->opcode() != SpvOpTypePointer) {
32 return false;
33 }
34
35 const auto& dec_a = _.id_decorations(a->id());
36 const auto& dec_b = _.id_decorations(b->id());
37 for (const auto& dec : dec_b) {
38 if (std::find(dec_a.begin(), dec_a.end(), dec) == dec_a.end()) {
39 return false;
40 }
41 }
42
43 uint32_t a_type = a->GetOperandAs<uint32_t>(2);
44 uint32_t b_type = b->GetOperandAs<uint32_t>(2);
45
46 if (a_type == b_type) {
47 return true;
48 }
49
50 Instruction* a_type_inst = _.FindDef(a_type);
51 Instruction* b_type_inst = _.FindDef(b_type);
52
53 return _.LogicallyMatch(a_type_inst, b_type_inst, true);
54 }
55
ValidateFunction(ValidationState_t & _,const Instruction * inst)56 spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
57 const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
58 const auto function_type = _.FindDef(function_type_id);
59 if (!function_type || SpvOpTypeFunction != function_type->opcode()) {
60 return _.diag(SPV_ERROR_INVALID_ID, inst)
61 << "OpFunction Function Type <id> '" << _.getIdName(function_type_id)
62 << "' is not a function type.";
63 }
64
65 const auto return_id = function_type->GetOperandAs<uint32_t>(1);
66 if (return_id != inst->type_id()) {
67 return _.diag(SPV_ERROR_INVALID_ID, inst)
68 << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id())
69 << "' does not match the Function Type's return type <id> '"
70 << _.getIdName(return_id) << "'.";
71 }
72
73 const std::vector<SpvOp> acceptable = {
74 SpvOpGroupDecorate,
75 SpvOpDecorate,
76 SpvOpEnqueueKernel,
77 SpvOpEntryPoint,
78 SpvOpExecutionMode,
79 SpvOpExecutionModeId,
80 SpvOpFunctionCall,
81 SpvOpGetKernelNDrangeSubGroupCount,
82 SpvOpGetKernelNDrangeMaxSubGroupSize,
83 SpvOpGetKernelWorkGroupSize,
84 SpvOpGetKernelPreferredWorkGroupSizeMultiple,
85 SpvOpGetKernelLocalSizeForSubgroupCount,
86 SpvOpGetKernelMaxNumSubgroups,
87 SpvOpName};
88 for (auto& pair : inst->uses()) {
89 const auto* use = pair.first;
90 if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
91 acceptable.end() &&
92 !use->IsNonSemantic() && !use->IsDebugInfo()) {
93 return _.diag(SPV_ERROR_INVALID_ID, use)
94 << "Invalid use of function result id " << _.getIdName(inst->id())
95 << ".";
96 }
97 }
98
99 return SPV_SUCCESS;
100 }
101
ValidateFunctionParameter(ValidationState_t & _,const Instruction * inst)102 spv_result_t ValidateFunctionParameter(ValidationState_t& _,
103 const Instruction* inst) {
104 // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
105 size_t param_index = 0;
106 size_t inst_num = inst->LineNum() - 1;
107 if (inst_num == 0) {
108 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
109 << "Function parameter cannot be the first instruction.";
110 }
111
112 auto func_inst = &_.ordered_instructions()[inst_num];
113 while (--inst_num) {
114 func_inst = &_.ordered_instructions()[inst_num];
115 if (func_inst->opcode() == SpvOpFunction) {
116 break;
117 } else if (func_inst->opcode() == SpvOpFunctionParameter) {
118 ++param_index;
119 }
120 }
121
122 if (func_inst->opcode() != SpvOpFunction) {
123 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
124 << "Function parameter must be preceded by a function.";
125 }
126
127 const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
128 const auto function_type = _.FindDef(function_type_id);
129 if (!function_type) {
130 return _.diag(SPV_ERROR_INVALID_ID, func_inst)
131 << "Missing function type definition.";
132 }
133 if (param_index >= function_type->words().size() - 3) {
134 return _.diag(SPV_ERROR_INVALID_ID, inst)
135 << "Too many OpFunctionParameters for " << func_inst->id()
136 << ": expected " << function_type->words().size() - 3
137 << " based on the function's type";
138 }
139
140 const auto param_type =
141 _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
142 if (!param_type || inst->type_id() != param_type->id()) {
143 return _.diag(SPV_ERROR_INVALID_ID, inst)
144 << "OpFunctionParameter Result Type <id> '"
145 << _.getIdName(inst->type_id())
146 << "' does not match the OpTypeFunction parameter "
147 "type of the same index.";
148 }
149
150 // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased,
151 // RestrictPointerEXT, or AliasedPointerEXT.
152 auto param_nonarray_type_id = param_type->id();
153 while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) {
154 param_nonarray_type_id =
155 _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
156 }
157 if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) {
158 auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
159 if (param_nonarray_type->GetOperandAs<uint32_t>(1u) ==
160 SpvStorageClassPhysicalStorageBufferEXT) {
161 // check for Aliased or Restrict
162 const auto& decorations = _.id_decorations(inst->id());
163
164 bool foundAliased = std::any_of(
165 decorations.begin(), decorations.end(), [](const Decoration& d) {
166 return SpvDecorationAliased == d.dec_type();
167 });
168
169 bool foundRestrict = std::any_of(
170 decorations.begin(), decorations.end(), [](const Decoration& d) {
171 return SpvDecorationRestrict == d.dec_type();
172 });
173
174 if (!foundAliased && !foundRestrict) {
175 return _.diag(SPV_ERROR_INVALID_ID, inst)
176 << "OpFunctionParameter " << inst->id()
177 << ": expected Aliased or Restrict for PhysicalStorageBufferEXT "
178 "pointer.";
179 }
180 if (foundAliased && foundRestrict) {
181 return _.diag(SPV_ERROR_INVALID_ID, inst)
182 << "OpFunctionParameter " << inst->id()
183 << ": can't specify both Aliased and Restrict for "
184 "PhysicalStorageBufferEXT pointer.";
185 }
186 } else {
187 const auto pointee_type_id =
188 param_nonarray_type->GetOperandAs<uint32_t>(2);
189 const auto pointee_type = _.FindDef(pointee_type_id);
190 if (SpvOpTypePointer == pointee_type->opcode() &&
191 pointee_type->GetOperandAs<uint32_t>(1u) ==
192 SpvStorageClassPhysicalStorageBufferEXT) {
193 // check for AliasedPointerEXT/RestrictPointerEXT
194 const auto& decorations = _.id_decorations(inst->id());
195
196 bool foundAliased = std::any_of(
197 decorations.begin(), decorations.end(), [](const Decoration& d) {
198 return SpvDecorationAliasedPointerEXT == d.dec_type();
199 });
200
201 bool foundRestrict = std::any_of(
202 decorations.begin(), decorations.end(), [](const Decoration& d) {
203 return SpvDecorationRestrictPointerEXT == d.dec_type();
204 });
205
206 if (!foundAliased && !foundRestrict) {
207 return _.diag(SPV_ERROR_INVALID_ID, inst)
208 << "OpFunctionParameter " << inst->id()
209 << ": expected AliasedPointerEXT or RestrictPointerEXT for "
210 "PhysicalStorageBufferEXT pointer.";
211 }
212 if (foundAliased && foundRestrict) {
213 return _.diag(SPV_ERROR_INVALID_ID, inst)
214 << "OpFunctionParameter " << inst->id()
215 << ": can't specify both AliasedPointerEXT and "
216 "RestrictPointerEXT for PhysicalStorageBufferEXT pointer.";
217 }
218 }
219 }
220 }
221
222 return SPV_SUCCESS;
223 }
224
ValidateFunctionCall(ValidationState_t & _,const Instruction * inst)225 spv_result_t ValidateFunctionCall(ValidationState_t& _,
226 const Instruction* inst) {
227 const auto function_id = inst->GetOperandAs<uint32_t>(2);
228 const auto function = _.FindDef(function_id);
229 if (!function || SpvOpFunction != function->opcode()) {
230 return _.diag(SPV_ERROR_INVALID_ID, inst)
231 << "OpFunctionCall Function <id> '" << _.getIdName(function_id)
232 << "' is not a function.";
233 }
234
235 auto return_type = _.FindDef(function->type_id());
236 if (!return_type || return_type->id() != inst->type_id()) {
237 return _.diag(SPV_ERROR_INVALID_ID, inst)
238 << "OpFunctionCall Result Type <id> '"
239 << _.getIdName(inst->type_id())
240 << "'s type does not match Function <id> '"
241 << _.getIdName(return_type->id()) << "'s return type.";
242 }
243
244 const auto function_type_id = function->GetOperandAs<uint32_t>(3);
245 const auto function_type = _.FindDef(function_type_id);
246 if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
247 return _.diag(SPV_ERROR_INVALID_ID, inst)
248 << "Missing function type definition.";
249 }
250
251 const auto function_call_arg_count = inst->words().size() - 4;
252 const auto function_param_count = function_type->words().size() - 3;
253 if (function_param_count != function_call_arg_count) {
254 return _.diag(SPV_ERROR_INVALID_ID, inst)
255 << "OpFunctionCall Function <id>'s parameter count does not match "
256 "the argument count.";
257 }
258
259 for (size_t argument_index = 3, param_index = 2;
260 argument_index < inst->operands().size();
261 argument_index++, param_index++) {
262 const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
263 const auto argument = _.FindDef(argument_id);
264 if (!argument) {
265 return _.diag(SPV_ERROR_INVALID_ID, inst)
266 << "Missing argument " << argument_index - 3 << " definition.";
267 }
268
269 const auto argument_type = _.FindDef(argument->type_id());
270 if (!argument_type) {
271 return _.diag(SPV_ERROR_INVALID_ID, inst)
272 << "Missing argument " << argument_index - 3
273 << " type definition.";
274 }
275
276 const auto parameter_type_id =
277 function_type->GetOperandAs<uint32_t>(param_index);
278 const auto parameter_type = _.FindDef(parameter_type_id);
279 if (!parameter_type || argument_type->id() != parameter_type->id()) {
280 if (!_.options()->before_hlsl_legalization ||
281 !DoPointeesLogicallyMatch(argument_type, parameter_type, _)) {
282 return _.diag(SPV_ERROR_INVALID_ID, inst)
283 << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id)
284 << "'s type does not match Function <id> '"
285 << _.getIdName(parameter_type_id) << "'s parameter type.";
286 }
287 }
288
289 if (_.addressing_model() == SpvAddressingModelLogical) {
290 if (parameter_type->opcode() == SpvOpTypePointer &&
291 !_.options()->relax_logical_pointer) {
292 SpvStorageClass sc = parameter_type->GetOperandAs<SpvStorageClass>(1u);
293 // Validate which storage classes can be pointer operands.
294 switch (sc) {
295 case SpvStorageClassUniformConstant:
296 case SpvStorageClassFunction:
297 case SpvStorageClassPrivate:
298 case SpvStorageClassWorkgroup:
299 case SpvStorageClassAtomicCounter:
300 // These are always allowed.
301 break;
302 case SpvStorageClassStorageBuffer:
303 if (!_.features().variable_pointers_storage_buffer) {
304 return _.diag(SPV_ERROR_INVALID_ID, inst)
305 << "StorageBuffer pointer operand "
306 << _.getIdName(argument_id)
307 << " requires a variable pointers capability";
308 }
309 break;
310 default:
311 return _.diag(SPV_ERROR_INVALID_ID, inst)
312 << "Invalid storage class for pointer operand "
313 << _.getIdName(argument_id);
314 }
315
316 // Validate memory object declaration requirements.
317 if (argument->opcode() != SpvOpVariable &&
318 argument->opcode() != SpvOpFunctionParameter) {
319 const bool ssbo_vptr =
320 _.features().variable_pointers_storage_buffer &&
321 sc == SpvStorageClassStorageBuffer;
322 const bool wg_vptr =
323 _.features().variable_pointers && sc == SpvStorageClassWorkgroup;
324 const bool uc_ptr = sc == SpvStorageClassUniformConstant;
325 if (!ssbo_vptr && !wg_vptr && !uc_ptr) {
326 return _.diag(SPV_ERROR_INVALID_ID, inst)
327 << "Pointer operand " << _.getIdName(argument_id)
328 << " must be a memory object declaration";
329 }
330 }
331 }
332 }
333 }
334 return SPV_SUCCESS;
335 }
336
337 } // namespace
338
FunctionPass(ValidationState_t & _,const Instruction * inst)339 spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
340 switch (inst->opcode()) {
341 case SpvOpFunction:
342 if (auto error = ValidateFunction(_, inst)) return error;
343 break;
344 case SpvOpFunctionParameter:
345 if (auto error = ValidateFunctionParameter(_, inst)) return error;
346 break;
347 case SpvOpFunctionCall:
348 if (auto error = ValidateFunctionCall(_, inst)) return error;
349 break;
350 default:
351 break;
352 }
353
354 return SPV_SUCCESS;
355 }
356
357 } // namespace val
358 } // namespace spvtools
359