1 // Copyright (c) 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Validates correctness of derivative SPIR-V instructions.
16 
17 #include "source/val/validate.h"
18 
19 #include <string>
20 
21 #include "source/diagnostic.h"
22 #include "source/opcode.h"
23 #include "source/val/instruction.h"
24 #include "source/val/validation_state.h"
25 
26 namespace spvtools {
27 namespace val {
28 
29 // Validates correctness of derivative instructions.
DerivativesPass(ValidationState_t & _,const Instruction * inst)30 spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
31   const SpvOp opcode = inst->opcode();
32   const uint32_t result_type = inst->type_id();
33 
34   switch (opcode) {
35     case SpvOpDPdx:
36     case SpvOpDPdy:
37     case SpvOpFwidth:
38     case SpvOpDPdxFine:
39     case SpvOpDPdyFine:
40     case SpvOpFwidthFine:
41     case SpvOpDPdxCoarse:
42     case SpvOpDPdyCoarse:
43     case SpvOpFwidthCoarse: {
44       if (!_.IsFloatScalarOrVectorType(result_type)) {
45         return _.diag(SPV_ERROR_INVALID_DATA, inst)
46                << "Expected Result Type to be float scalar or vector type: "
47                << spvOpcodeString(opcode);
48       }
49       if (!_.ContainsSizedIntOrFloatType(result_type, SpvOpTypeFloat, 32)) {
50         return _.diag(SPV_ERROR_INVALID_DATA, inst)
51                << "Result type component width must be 32 bits";
52       }
53 
54       const uint32_t p_type = _.GetOperandTypeId(inst, 2);
55       if (p_type != result_type) {
56         return _.diag(SPV_ERROR_INVALID_DATA, inst)
57                << "Expected P type and Result Type to be the same: "
58                << spvOpcodeString(opcode);
59       }
60       _.function(inst->function()->id())
61           ->RegisterExecutionModelLimitation([opcode](SpvExecutionModel model,
62                                                       std::string* message) {
63             if (model != SpvExecutionModelFragment &&
64                 model != SpvExecutionModelGLCompute) {
65               if (message) {
66                 *message =
67                     std::string(
68                         "Derivative instructions require Fragment or GLCompute "
69                         "execution model: ") +
70                     spvOpcodeString(opcode);
71               }
72               return false;
73             }
74             return true;
75           });
76       _.function(inst->function()->id())
77           ->RegisterLimitation([opcode](const ValidationState_t& state,
78                                         const Function* entry_point,
79                                         std::string* message) {
80             const auto* models = state.GetExecutionModels(entry_point->id());
81             const auto* modes = state.GetExecutionModes(entry_point->id());
82             if (models &&
83                 models->find(SpvExecutionModelGLCompute) != models->end() &&
84                 (!modes ||
85                  (modes->find(SpvExecutionModeDerivativeGroupLinearNV) ==
86                       modes->end() &&
87                   modes->find(SpvExecutionModeDerivativeGroupQuadsNV) ==
88                       modes->end()))) {
89               if (message) {
90                 *message = std::string(
91                                "Derivative instructions require "
92                                "DerivativeGroupQuadsNV "
93                                "or DerivativeGroupLinearNV execution mode for "
94                                "GLCompute execution model: ") +
95                            spvOpcodeString(opcode);
96               }
97               return false;
98             }
99             return true;
100           });
101       break;
102     }
103 
104     default:
105       break;
106   }
107 
108   return SPV_SUCCESS;
109 }
110 
111 }  // namespace val
112 }  // namespace spvtools
113