1 // Copyright (c) 2017 Google Inc.
2 // Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights
3 // reserved.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 
17 // Validates correctness of atomic SPIR-V instructions.
18 
19 #include "source/val/validate.h"
20 
21 #include "source/diagnostic.h"
22 #include "source/opcode.h"
23 #include "source/spirv_target_env.h"
24 #include "source/util/bitutils.h"
25 #include "source/val/instruction.h"
26 #include "source/val/validate_memory_semantics.h"
27 #include "source/val/validate_scopes.h"
28 #include "source/val/validation_state.h"
29 
30 namespace {
31 
IsStorageClassAllowedByUniversalRules(uint32_t storage_class)32 bool IsStorageClassAllowedByUniversalRules(uint32_t storage_class) {
33   switch (storage_class) {
34     case SpvStorageClassUniform:
35     case SpvStorageClassStorageBuffer:
36     case SpvStorageClassWorkgroup:
37     case SpvStorageClassCrossWorkgroup:
38     case SpvStorageClassGeneric:
39     case SpvStorageClassAtomicCounter:
40     case SpvStorageClassImage:
41     case SpvStorageClassFunction:
42     case SpvStorageClassPhysicalStorageBufferEXT:
43       return true;
44       break;
45     default:
46       return false;
47   }
48 }
49 
50 }  // namespace
51 
52 namespace spvtools {
53 namespace val {
54 
55 // Validates correctness of atomic instructions.
AtomicsPass(ValidationState_t & _,const Instruction * inst)56 spv_result_t AtomicsPass(ValidationState_t& _, const Instruction* inst) {
57   const SpvOp opcode = inst->opcode();
58   const uint32_t result_type = inst->type_id();
59   bool is_atomic_float_opcode = false;
60   if (opcode == SpvOpAtomicLoad || opcode == SpvOpAtomicStore ||
61       opcode == SpvOpAtomicFAddEXT || opcode == SpvOpAtomicExchange) {
62     is_atomic_float_opcode = true;
63   }
64   switch (opcode) {
65     case SpvOpAtomicLoad:
66     case SpvOpAtomicStore:
67     case SpvOpAtomicExchange:
68     case SpvOpAtomicFAddEXT:
69     case SpvOpAtomicCompareExchange:
70     case SpvOpAtomicCompareExchangeWeak:
71     case SpvOpAtomicIIncrement:
72     case SpvOpAtomicIDecrement:
73     case SpvOpAtomicIAdd:
74     case SpvOpAtomicISub:
75     case SpvOpAtomicSMin:
76     case SpvOpAtomicUMin:
77     case SpvOpAtomicSMax:
78     case SpvOpAtomicUMax:
79     case SpvOpAtomicAnd:
80     case SpvOpAtomicOr:
81     case SpvOpAtomicXor:
82     case SpvOpAtomicFlagTestAndSet:
83     case SpvOpAtomicFlagClear: {
84       if (_.HasCapability(SpvCapabilityKernel) &&
85           (opcode == SpvOpAtomicLoad || opcode == SpvOpAtomicExchange ||
86            opcode == SpvOpAtomicCompareExchange)) {
87         if (!_.IsFloatScalarType(result_type) &&
88             !_.IsIntScalarType(result_type)) {
89           return _.diag(SPV_ERROR_INVALID_DATA, inst)
90                  << spvOpcodeString(opcode)
91                  << ": expected Result Type to be int or float scalar type";
92         }
93       } else if (opcode == SpvOpAtomicFlagTestAndSet) {
94         if (!_.IsBoolScalarType(result_type)) {
95           return _.diag(SPV_ERROR_INVALID_DATA, inst)
96                  << spvOpcodeString(opcode)
97                  << ": expected Result Type to be bool scalar type";
98         }
99       } else if (opcode == SpvOpAtomicFlagClear || opcode == SpvOpAtomicStore) {
100         assert(result_type == 0);
101       } else {
102         if (_.IsFloatScalarType(result_type)) {
103           if (is_atomic_float_opcode) {
104             if (opcode == SpvOpAtomicFAddEXT) {
105               if ((_.GetBitWidth(result_type) == 32) &&
106                   (!_.HasCapability(SpvCapabilityAtomicFloat32AddEXT))) {
107                 return _.diag(SPV_ERROR_INVALID_DATA, inst)
108                        << spvOpcodeString(opcode)
109                        << ": float add atomics require the AtomicFloat32AddEXT "
110                           "capability";
111               }
112               if ((_.GetBitWidth(result_type) == 64) &&
113                   (!_.HasCapability(SpvCapabilityAtomicFloat64AddEXT))) {
114                 return _.diag(SPV_ERROR_INVALID_DATA, inst)
115                        << spvOpcodeString(opcode)
116                        << ": float add atomics require the AtomicFloat64AddEXT "
117                           "capability";
118               }
119             }
120           } else {
121             return _.diag(SPV_ERROR_INVALID_DATA, inst)
122                    << spvOpcodeString(opcode)
123                    << ": expected Result Type to be int scalar type";
124           }
125         } else if (_.IsIntScalarType(result_type) &&
126                    opcode == SpvOpAtomicFAddEXT) {
127           return _.diag(SPV_ERROR_INVALID_DATA, inst)
128                  << spvOpcodeString(opcode)
129                  << ": expected Result Type to be float scalar type";
130         } else if (!_.IsFloatScalarType(result_type) &&
131                    !_.IsIntScalarType(result_type)) {
132           switch (opcode) {
133             case SpvOpAtomicFAddEXT:
134               return _.diag(SPV_ERROR_INVALID_DATA, inst)
135                      << spvOpcodeString(opcode)
136                      << ": expected Result Type to be float scalar type";
137             case SpvOpAtomicIIncrement:
138             case SpvOpAtomicIDecrement:
139             case SpvOpAtomicIAdd:
140             case SpvOpAtomicISub:
141             case SpvOpAtomicSMin:
142             case SpvOpAtomicSMax:
143             case SpvOpAtomicUMin:
144             case SpvOpAtomicUMax:
145               return _.diag(SPV_ERROR_INVALID_DATA, inst)
146                      << spvOpcodeString(opcode)
147                      << ": expected Result Type to be integer scalar type";
148             default:
149               return _.diag(SPV_ERROR_INVALID_DATA, inst)
150                      << spvOpcodeString(opcode)
151                      << ": expected Result Type to be int or float scalar type";
152           }
153         }
154 
155         if (spvIsVulkanEnv(_.context()->target_env) &&
156             (_.GetBitWidth(result_type) != 32 &&
157              (_.GetBitWidth(result_type) != 64 ||
158               !_.HasCapability(SpvCapabilityInt64ImageEXT)))) {
159           switch (opcode) {
160             case SpvOpAtomicSMin:
161             case SpvOpAtomicUMin:
162             case SpvOpAtomicSMax:
163             case SpvOpAtomicUMax:
164             case SpvOpAtomicAnd:
165             case SpvOpAtomicOr:
166             case SpvOpAtomicXor:
167             case SpvOpAtomicIAdd:
168             case SpvOpAtomicISub:
169             case SpvOpAtomicFAddEXT:
170             case SpvOpAtomicLoad:
171             case SpvOpAtomicStore:
172             case SpvOpAtomicExchange:
173             case SpvOpAtomicIIncrement:
174             case SpvOpAtomicIDecrement:
175             case SpvOpAtomicCompareExchangeWeak:
176             case SpvOpAtomicCompareExchange: {
177               if (_.GetBitWidth(result_type) == 64 &&
178                   _.IsIntScalarType(result_type) &&
179                   !_.HasCapability(SpvCapabilityInt64Atomics))
180                 return _.diag(SPV_ERROR_INVALID_DATA, inst)
181                        << spvOpcodeString(opcode)
182                        << ": 64-bit atomics require the Int64Atomics "
183                           "capability";
184             } break;
185             default:
186               return _.diag(SPV_ERROR_INVALID_DATA, inst)
187                      << spvOpcodeString(opcode)
188                      << ": according to the Vulkan spec atomic Result Type "
189                         "needs "
190                         "to be a 32-bit int scalar type";
191           }
192         }
193       }
194 
195       uint32_t operand_index =
196           opcode == SpvOpAtomicFlagClear || opcode == SpvOpAtomicStore ? 0 : 2;
197       const uint32_t pointer_type = _.GetOperandTypeId(inst, operand_index++);
198 
199       uint32_t data_type = 0;
200       uint32_t storage_class = 0;
201       if (!_.GetPointerTypeInfo(pointer_type, &data_type, &storage_class)) {
202         return _.diag(SPV_ERROR_INVALID_DATA, inst)
203                << spvOpcodeString(opcode)
204                << ": expected Pointer to be of type OpTypePointer";
205       }
206 
207       // Validate storage class against universal rules
208       if (!IsStorageClassAllowedByUniversalRules(storage_class)) {
209         return _.diag(SPV_ERROR_INVALID_DATA, inst)
210                << spvOpcodeString(opcode)
211                << ": storage class forbidden by universal validation rules.";
212       }
213 
214       // Then Shader rules
215       if (_.HasCapability(SpvCapabilityShader)) {
216         if (storage_class == SpvStorageClassFunction) {
217           return _.diag(SPV_ERROR_INVALID_DATA, inst)
218                  << spvOpcodeString(opcode)
219                  << ": Function storage class forbidden when the Shader "
220                     "capability is declared.";
221         }
222       }
223 
224       // And finally OpenCL environment rules
225       if (spvIsOpenCLEnv(_.context()->target_env)) {
226         if ((storage_class != SpvStorageClassFunction) &&
227             (storage_class != SpvStorageClassWorkgroup) &&
228             (storage_class != SpvStorageClassCrossWorkgroup) &&
229             (storage_class != SpvStorageClassGeneric)) {
230           return _.diag(SPV_ERROR_INVALID_DATA, inst)
231                  << spvOpcodeString(opcode)
232                  << ": storage class must be Function, Workgroup, "
233                     "CrossWorkGroup or Generic in the OpenCL environment.";
234         }
235 
236         if (_.context()->target_env == SPV_ENV_OPENCL_1_2) {
237           if (storage_class == SpvStorageClassGeneric) {
238             return _.diag(SPV_ERROR_INVALID_DATA, inst)
239                    << "Storage class cannot be Generic in OpenCL 1.2 "
240                       "environment";
241           }
242         }
243       }
244 
245       if (opcode == SpvOpAtomicFlagTestAndSet ||
246           opcode == SpvOpAtomicFlagClear) {
247         if (!_.IsIntScalarType(data_type) || _.GetBitWidth(data_type) != 32) {
248           return _.diag(SPV_ERROR_INVALID_DATA, inst)
249                  << spvOpcodeString(opcode)
250                  << ": expected Pointer to point to a value of 32-bit int type";
251         }
252       } else if (opcode == SpvOpAtomicStore) {
253         if (!_.IsFloatScalarType(data_type) && !_.IsIntScalarType(data_type)) {
254           return _.diag(SPV_ERROR_INVALID_DATA, inst)
255                  << spvOpcodeString(opcode)
256                  << ": expected Pointer to be a pointer to int or float "
257                  << "scalar type";
258         }
259       } else {
260         if (data_type != result_type) {
261           return _.diag(SPV_ERROR_INVALID_DATA, inst)
262                  << spvOpcodeString(opcode)
263                  << ": expected Pointer to point to a value of type Result "
264                     "Type";
265         }
266       }
267 
268       auto memory_scope = inst->GetOperandAs<const uint32_t>(operand_index++);
269       if (auto error = ValidateMemoryScope(_, inst, memory_scope)) {
270         return error;
271       }
272 
273       const auto equal_semantics_index = operand_index++;
274       if (auto error = ValidateMemorySemantics(_, inst, equal_semantics_index))
275         return error;
276 
277       if (opcode == SpvOpAtomicCompareExchange ||
278           opcode == SpvOpAtomicCompareExchangeWeak) {
279         const auto unequal_semantics_index = operand_index++;
280         if (auto error =
281                 ValidateMemorySemantics(_, inst, unequal_semantics_index))
282           return error;
283 
284         // Volatile bits must match for equal and unequal semantics. Previous
285         // checks guarantee they are 32-bit constants, but we need to recheck
286         // whether they are evaluatable constants.
287         bool is_int32 = false;
288         bool is_equal_const = false;
289         bool is_unequal_const = false;
290         uint32_t equal_value = 0;
291         uint32_t unequal_value = 0;
292         std::tie(is_int32, is_equal_const, equal_value) = _.EvalInt32IfConst(
293             inst->GetOperandAs<uint32_t>(equal_semantics_index));
294         std::tie(is_int32, is_unequal_const, unequal_value) =
295             _.EvalInt32IfConst(
296                 inst->GetOperandAs<uint32_t>(unequal_semantics_index));
297         if (is_equal_const && is_unequal_const &&
298             ((equal_value & SpvMemorySemanticsVolatileMask) ^
299              (unequal_value & SpvMemorySemanticsVolatileMask))) {
300           return _.diag(SPV_ERROR_INVALID_ID, inst)
301                  << "Volatile mask setting must match for Equal and Unequal "
302                     "memory semantics";
303         }
304       }
305 
306       if (opcode == SpvOpAtomicStore) {
307         const uint32_t value_type = _.GetOperandTypeId(inst, 3);
308         if (value_type != data_type) {
309           return _.diag(SPV_ERROR_INVALID_DATA, inst)
310                  << spvOpcodeString(opcode)
311                  << ": expected Value type and the type pointed to by "
312                     "Pointer to be the same";
313         }
314       } else if (opcode != SpvOpAtomicLoad && opcode != SpvOpAtomicIIncrement &&
315                  opcode != SpvOpAtomicIDecrement &&
316                  opcode != SpvOpAtomicFlagTestAndSet &&
317                  opcode != SpvOpAtomicFlagClear) {
318         const uint32_t value_type = _.GetOperandTypeId(inst, operand_index++);
319         if (value_type != result_type) {
320           return _.diag(SPV_ERROR_INVALID_DATA, inst)
321                  << spvOpcodeString(opcode)
322                  << ": expected Value to be of type Result Type";
323         }
324       }
325 
326       if (opcode == SpvOpAtomicCompareExchange ||
327           opcode == SpvOpAtomicCompareExchangeWeak) {
328         const uint32_t comparator_type =
329             _.GetOperandTypeId(inst, operand_index++);
330         if (comparator_type != result_type) {
331           return _.diag(SPV_ERROR_INVALID_DATA, inst)
332                  << spvOpcodeString(opcode)
333                  << ": expected Comparator to be of type Result Type";
334         }
335       }
336 
337       break;
338     }
339 
340     default:
341       break;
342   }
343 
344   return SPV_SUCCESS;
345 }
346 
347 }  // namespace val
348 }  // namespace spvtools
349