1 // Copyright 2016 The Shaderc Authors. All rights reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "libshaderc_util/spirv_tools_wrapper.h"
16 
17 #include <algorithm>
18 #include <sstream>
19 
20 #include "spirv-tools/optimizer.hpp"
21 
22 namespace shaderc_util {
23 
24 namespace {
25 
26 // Gets the corresponding target environment used in SPIRV-Tools.
GetSpirvToolsTargetEnv(Compiler::TargetEnv env,Compiler::TargetEnvVersion version)27 spv_target_env GetSpirvToolsTargetEnv(Compiler::TargetEnv env,
28                                       Compiler::TargetEnvVersion version) {
29   switch (env) {
30     case Compiler::TargetEnv::Vulkan:
31       switch (version) {
32         case Compiler::TargetEnvVersion::Default:
33           return SPV_ENV_VULKAN_1_0;
34         case Compiler::TargetEnvVersion::Vulkan_1_0:
35           return SPV_ENV_VULKAN_1_0;
36         case Compiler::TargetEnvVersion::Vulkan_1_1:
37           return SPV_ENV_VULKAN_1_1;
38         case Compiler::TargetEnvVersion::Vulkan_1_2:
39           return SPV_ENV_VULKAN_1_2;
40         default:
41           break;
42       }
43       break;
44     case Compiler::TargetEnv::OpenGL:
45       return SPV_ENV_OPENGL_4_5;
46     case Compiler::TargetEnv::OpenGLCompat:  // Deprecated
47       return SPV_ENV_OPENGL_4_5;
48     case Compiler::TargetEnv::WebGPU:
49       return SPV_ENV_WEBGPU_0;
50   }
51   assert(false && "unexpected target environment or version");
52   return SPV_ENV_VULKAN_1_0;
53 }
54 
55 }  // anonymous namespace
56 
SpirvToolsDisassemble(Compiler::TargetEnv env,Compiler::TargetEnvVersion version,const std::vector<uint32_t> & binary,std::string * text_or_error)57 bool SpirvToolsDisassemble(Compiler::TargetEnv env,
58                            Compiler::TargetEnvVersion version,
59                            const std::vector<uint32_t>& binary,
60                            std::string* text_or_error) {
61   spvtools::SpirvTools tools(GetSpirvToolsTargetEnv(env, version));
62   std::ostringstream oss;
63   tools.SetMessageConsumer([&oss](spv_message_level_t, const char*,
64                                   const spv_position_t& position,
65                                   const char* message) {
66     oss << position.index << ": " << message;
67   });
68   const bool success =
69       tools.Disassemble(binary, text_or_error,
70                         SPV_BINARY_TO_TEXT_OPTION_INDENT |
71                             SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES);
72   if (!success) {
73     *text_or_error = oss.str();
74   }
75   return success;
76 }
77 
SpirvToolsAssemble(Compiler::TargetEnv env,Compiler::TargetEnvVersion version,const string_piece assembly,spv_binary * binary,std::string * errors)78 bool SpirvToolsAssemble(Compiler::TargetEnv env,
79                         Compiler::TargetEnvVersion version,
80                         const string_piece assembly, spv_binary* binary,
81                         std::string* errors) {
82   auto spvtools_context =
83       spvContextCreate(GetSpirvToolsTargetEnv(env, version));
84   spv_diagnostic spvtools_diagnostic = nullptr;
85 
86   *binary = nullptr;
87   errors->clear();
88 
89   const bool success =
90       spvTextToBinary(spvtools_context, assembly.data(), assembly.size(),
91                       binary, &spvtools_diagnostic) == SPV_SUCCESS;
92   if (!success) {
93     std::ostringstream oss;
94     oss << spvtools_diagnostic->position.line + 1 << ":"
95         << spvtools_diagnostic->position.column + 1 << ": "
96         << spvtools_diagnostic->error;
97     *errors = oss.str();
98   }
99 
100   spvDiagnosticDestroy(spvtools_diagnostic);
101   spvContextDestroy(spvtools_context);
102 
103   return success;
104 }
105 
SpirvToolsOptimize(Compiler::TargetEnv env,Compiler::TargetEnvVersion version,const std::vector<PassId> & enabled_passes,std::vector<uint32_t> * binary,std::string * errors)106 bool SpirvToolsOptimize(Compiler::TargetEnv env,
107                         Compiler::TargetEnvVersion version,
108                         const std::vector<PassId>& enabled_passes,
109                         std::vector<uint32_t>* binary, std::string* errors) {
110   errors->clear();
111   if (enabled_passes.empty()) return true;
112   if (std::all_of(
113           enabled_passes.cbegin(), enabled_passes.cend(),
114           [](const PassId& pass) { return pass == PassId::kNullPass; })) {
115     return true;
116   }
117 
118   spvtools::ValidatorOptions val_opts;
119   // This allows flexible memory layout for HLSL.
120   val_opts.SetSkipBlockLayout(true);
121   // This allows HLSL legalization regarding resources.
122   val_opts.SetRelaxLogicalPointer(true);
123   // This uses relaxed rules for pre-legalized HLSL.
124   val_opts.SetBeforeHlslLegalization(true);
125 
126   spvtools::OptimizerOptions opt_opts;
127   opt_opts.set_validator_options(val_opts);
128   opt_opts.set_run_validator(true);
129 
130   spvtools::Optimizer optimizer(GetSpirvToolsTargetEnv(env, version));
131 
132   std::ostringstream oss;
133   optimizer.SetMessageConsumer(
134       [&oss](spv_message_level_t, const char*, const spv_position_t&,
135              const char* message) { oss << message << "\n"; });
136 
137   for (const auto& pass : enabled_passes) {
138     switch (pass) {
139       case PassId::kLegalizationPasses:
140         optimizer.RegisterLegalizationPasses();
141         break;
142       case PassId::kPerformancePasses:
143         optimizer.RegisterPerformancePasses();
144         break;
145       case PassId::kSizePasses:
146         optimizer.RegisterSizePasses();
147         break;
148       case PassId::kVulkanToWebGPUPasses:
149         optimizer.RegisterVulkanToWebGPUPasses();
150         break;
151       case PassId::kNullPass:
152         // We actually don't need to do anything for null pass.
153         break;
154       case PassId::kStripDebugInfo:
155         optimizer.RegisterPass(spvtools::CreateStripDebugInfoPass());
156         break;
157       case PassId::kCompactIds:
158         optimizer.RegisterPass(spvtools::CreateCompactIdsPass());
159         break;
160     }
161   }
162 
163   if (!optimizer.Run(binary->data(), binary->size(), binary, opt_opts)) {
164     *errors = oss.str();
165     return false;
166   }
167   return true;
168 }
169 
170 }  // namespace shaderc_util
171