1 // Copyright (c) 2017 Pierre Moreau
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef TEST_LINK_LINKER_FIXTURE_H_
16 #define TEST_LINK_LINKER_FIXTURE_H_
17 
18 #include <iostream>
19 #include <string>
20 #include <vector>
21 
22 #include "effcee/effcee.h"
23 #include "re2/re2.h"
24 #include "source/spirv_constant.h"
25 #include "spirv-tools/linker.hpp"
26 #include "test/unit_spirv.h"
27 
28 namespace spvtest {
29 
// One SPIR-V module, held as its sequence of 32-bit words.
using Binary = std::vector<uint32_t>;
// Several SPIR-V modules, e.g. the set of inputs handed to the linker.
using Binaries = std::vector<std::vector<uint32_t>>;
32 
33 class LinkerTest : public ::testing::Test {
34  public:
LinkerTest()35   LinkerTest()
36       : context_(SPV_ENV_UNIVERSAL_1_2),
37         tools_(SPV_ENV_UNIVERSAL_1_2),
38         assemble_options_(spvtools::SpirvTools::kDefaultAssembleOption),
39         disassemble_options_(spvtools::SpirvTools::kDefaultDisassembleOption) {
40     const auto consumer = [this](spv_message_level_t level, const char*,
41                                  const spv_position_t& position,
42                                  const char* message) {
43       if (!error_message_.empty()) error_message_ += "\n";
44       switch (level) {
45         case SPV_MSG_FATAL:
46         case SPV_MSG_INTERNAL_ERROR:
47         case SPV_MSG_ERROR:
48           error_message_ += "ERROR";
49           break;
50         case SPV_MSG_WARNING:
51           error_message_ += "WARNING";
52           break;
53         case SPV_MSG_INFO:
54           error_message_ += "INFO";
55           break;
56         case SPV_MSG_DEBUG:
57           error_message_ += "DEBUG";
58           break;
59       }
60       error_message_ += ": " + std::to_string(position.index) + ": " + message;
61     };
62     context_.SetMessageConsumer(consumer);
63     tools_.SetMessageConsumer(consumer);
64   }
65 
TearDown()66   void TearDown() override { error_message_.clear(); }
67 
68   // Assembles each of the given strings into SPIR-V binaries before linking
69   // them together. SPV_ERROR_INVALID_TEXT is returned if the assembling failed
70   // for any of the input strings, and SPV_ERROR_INVALID_POINTER if
71   // |linked_binary| is a null pointer.
72   spv_result_t AssembleAndLink(
73       const std::vector<std::string>& bodies, spvtest::Binary* linked_binary,
74       spvtools::LinkerOptions options = spvtools::LinkerOptions()) {
75     if (!linked_binary) return SPV_ERROR_INVALID_POINTER;
76 
77     spvtest::Binaries binaries(bodies.size());
78     for (size_t i = 0u; i < bodies.size(); ++i)
79       if (!tools_.Assemble(bodies[i], binaries.data() + i, assemble_options_))
80         return SPV_ERROR_INVALID_TEXT;
81 
82     return spvtools::Link(context_, binaries, linked_binary, options);
83   }
84 
85   // Assembles and links a vector of SPIR-V bodies based on the |templateBody|.
86   // Template arguments to be replaced are written as {a,b,...}.
87   // SPV_ERROR_INVALID_TEXT is returned if the assembling failed for any of the
88   // resulting bodies (or errors in the template), and SPV_ERROR_INVALID_POINTER
89   // if |linked_binary| is a null pointer.
90   spv_result_t ExpandAndLink(
91       const std::string& templateBody, spvtest::Binary* linked_binary,
92       spvtools::LinkerOptions options = spvtools::LinkerOptions()) {
93     if (!linked_binary) return SPV_ERROR_INVALID_POINTER;
94 
95     // Find out how many template arguments there are, we assume they all have
96     // the same number. We'll error later if they don't.
97     re2::StringPiece temp(templateBody);
98     re2::StringPiece x;
99     int cnt = 0;
100     if (!RE2::FindAndConsume(&temp, "{")) return SPV_ERROR_INVALID_TEXT;
101     while (RE2::FindAndConsume(&temp, "([,}])", &x) && x[0] == ',') cnt++;
102     cnt++;
103     if (cnt <= 1) return SPV_ERROR_INVALID_TEXT;
104 
105     // Construct a regex for a single common strip and template expansion.
106     std::string regex("([^{]*){");
107     for (int i = 0; i < cnt; i++) regex += (i > 0) ? ",([^,]*)" : "([^,]*)";
108     regex += "}";
109     RE2 pattern(regex);
110 
111     // Prepare the RE2::Args for processing.
112     re2::StringPiece common;
113     std::vector<re2::StringPiece> variants(cnt);
114     std::vector<RE2::Arg> args(cnt + 1);
115     args[0] = RE2::Arg(&common);
116     std::vector<RE2::Arg*> pargs(cnt + 1);
117     pargs[0] = &args[0];
118     for (int i = 0; i < cnt; i++) {
119       args[i + 1] = RE2::Arg(&variants[i]);
120       pargs[i + 1] = &args[i + 1];
121     }
122 
123     // Reset and construct the bodies bit by bit.
124     std::vector<std::string> bodies(cnt);
125     re2::StringPiece temp2(templateBody);
126     while (RE2::ConsumeN(&temp2, pattern, pargs.data(), cnt + 1)) {
127       for (int i = 0; i < cnt; i++) {
128         bodies[i].append(common.begin(), common.end());
129         bodies[i].append(variants[i].begin(), variants[i].end());
130       }
131     }
132     RE2::Consume(&temp2, "([^{]*)", &common);
133     for (int i = 0; i < cnt; i++)
134       bodies[i].append(common.begin(), common.end());
135 
136     // Run through the assemble and link stages of the process.
137     return AssembleAndLink(bodies, linked_binary, options);
138   }
139 
140   // Expand the |templateBody| and link the results as with ExpandAndLink,
141   // then disassemble and test that the result matches the |expected|.
142   void ExpandAndCheck(
143       const std::string& templateBody, const std::string& expected,
144       const spvtools::LinkerOptions options = spvtools::LinkerOptions()) {
145     spvtest::Binary linked_binary;
146     spv_result_t res = ExpandAndLink(templateBody, &linked_binary, options);
147     EXPECT_EQ(SPV_SUCCESS, res) << GetErrorMessage() << "\nExpanded from:\n"
148                                 << templateBody;
149     if (res == SPV_SUCCESS) {
150       std::string result;
151       EXPECT_TRUE(
152           tools_.Disassemble(linked_binary, &result, disassemble_options_))
153           << GetErrorMessage();
154       EXPECT_EQ(expected, result);
155     }
156   }
157 
158   // An alternative to ExpandAndCheck, which uses the |templateBody| as the
159   // match pattern for the disassembled linked result.
160   void ExpandAndMatch(
161       const std::string& templateBody,
162       const spvtools::LinkerOptions options = spvtools::LinkerOptions()) {
163     spvtest::Binary linked_binary;
164     spv_result_t res = ExpandAndLink(templateBody, &linked_binary, options);
165     EXPECT_EQ(SPV_SUCCESS, res) << GetErrorMessage() << "\nExpanded from:\n"
166                                 << templateBody;
167     if (res == SPV_SUCCESS) {
168       std::string result;
169       EXPECT_TRUE(
170           tools_.Disassemble(linked_binary, &result, disassemble_options_))
171           << GetErrorMessage();
172       auto match_res = effcee::Match(result, templateBody);
173       EXPECT_EQ(effcee::Result::Status::Ok, match_res.status())
174           << match_res.message() << "\nExpanded from:\n"
175           << templateBody << "\nChecking result:\n"
176           << result;
177     }
178   }
179 
180   // Links the given SPIR-V binaries together; SPV_ERROR_INVALID_POINTER is
181   // returned if |linked_binary| is a null pointer.
182   spv_result_t Link(
183       const spvtest::Binaries& binaries, spvtest::Binary* linked_binary,
184       spvtools::LinkerOptions options = spvtools::LinkerOptions()) {
185     if (!linked_binary) return SPV_ERROR_INVALID_POINTER;
186     return spvtools::Link(context_, binaries, linked_binary, options);
187   }
188 
189   // Disassembles |binary| and outputs the result in |text|. If |text| is a
190   // null pointer, SPV_ERROR_INVALID_POINTER is returned.
Disassemble(const spvtest::Binary & binary,std::string * text)191   spv_result_t Disassemble(const spvtest::Binary& binary, std::string* text) {
192     if (!text) return SPV_ERROR_INVALID_POINTER;
193     return tools_.Disassemble(binary, text, disassemble_options_)
194                ? SPV_SUCCESS
195                : SPV_ERROR_INVALID_BINARY;
196   }
197 
198   // Sets the options for the assembler.
SetAssembleOptions(uint32_t assemble_options)199   void SetAssembleOptions(uint32_t assemble_options) {
200     assemble_options_ = assemble_options;
201   }
202 
203   // Sets the options used by the disassembler.
SetDisassembleOptions(uint32_t disassemble_options)204   void SetDisassembleOptions(uint32_t disassemble_options) {
205     disassemble_options_ = disassemble_options;
206   }
207 
208   // Returns the accumulated error messages for the test.
GetErrorMessage()209   std::string GetErrorMessage() const { return error_message_; }
210 
211  private:
212   spvtools::Context context_;
213   spvtools::SpirvTools
214       tools_;  // An instance for calling SPIRV-Tools functionalities.
215   uint32_t assemble_options_;
216   uint32_t disassemble_options_;
217   std::string error_message_;
218 };
219 
220 }  // namespace spvtest
221 
222 #endif  // TEST_LINK_LINKER_FIXTURE_H_
223