1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <thrift/conformance/GTestHarness.h>
18
19 #include <memory>
20 #include <stdexcept>
21
22 #include <fmt/core.h>
23 #include <folly/lang/Exception.h>
24 #include <thrift/conformance/cpp2/AnyRegistry.h>
25 #include <thrift/conformance/cpp2/Object.h>
26 #include <thrift/lib/cpp2/op/Compare.h>
27 #include <thrift/lib/cpp2/protocol/Serializer.h>
28
29 namespace apache::thrift::conformance {
30 namespace {
31
32 // From a newer version of gtest.
33 //
34 // TODO(afuller): Delete once gtest is updated.
35 template <typename Factory>
RegisterTest(const char * test_suite_name,const char * test_name,const char * type_param,const char * value_param,const char * file,int line,Factory factory)36 testing::TestInfo* RegisterTest(
37 const char* test_suite_name,
38 const char* test_name,
39 const char* type_param,
40 const char* value_param,
41 const char* file,
42 int line,
43 Factory factory) {
44 using TestT = typename std::remove_pointer<decltype(factory())>::type;
45
46 class FactoryImpl : public testing::internal::TestFactoryBase {
47 public:
48 explicit FactoryImpl(Factory f) : factory_(std::move(f)) {}
49 testing::Test* CreateTest() override { return factory_(); }
50
51 private:
52 Factory factory_;
53 };
54
55 return testing::internal::MakeAndRegisterTestInfo(
56 test_suite_name,
57 test_name,
58 type_param,
59 value_param,
60 testing::internal::CodeLocation(file, line),
61 testing::internal::GetTypeId<TestT>(),
62 TestT::SetUpTestCase,
63 TestT::TearDownTestCase,
64 new FactoryImpl{std::move(factory)});
65 }
66
RunRoundTripTest(ConformanceServiceAsyncClient & client,RoundTripTestCase roundTrip)67 testing::AssertionResult RunRoundTripTest(
68 ConformanceServiceAsyncClient& client, RoundTripTestCase roundTrip) {
69 RoundTripResponse res;
70 client.sync_roundTrip(res, *roundTrip.request_ref());
71
72 const Any& expectedAny = roundTrip.expectedResponse_ref()
73 ? *roundTrip.expectedResponse_ref().value_unchecked().value_ref()
74 : *roundTrip.request_ref()->value_ref();
75
76 // TODO(afuller): Make add asValueStruct support to AnyRegistry and use that
77 // instead of hard coding type.
78 auto actual = AnyRegistry::generated().load<Value>(*res.value_ref());
79 auto expected = AnyRegistry::generated().load<Value>(expectedAny);
80 if (!op::identical<type::struct_t<Value>>(actual, expected)) {
81 // TODO(afuller): Report out the delta
82 return testing::AssertionFailure();
83 }
84 return testing::AssertionSuccess();
85 }
86
87 } // namespace
88
// Splits a server entry into {name, command}.
//
// An entry of the form "<command>#<name>" uses <name> as the server name; a
// trailing '#' with nothing after it is ignored. Otherwise the name is the
// parent directory of the command, i.e. the path component just before the
// last '/' or '\\' separator. If the entry contains no separator, the whole
// entry is used as the name; if its only separator is the leading one, the
// component after it is used.
//
// Both returned views point into `entry`.
std::pair<std::string_view, std::string_view> parseNameAndCmd(
    std::string_view entry) {
  constexpr std::string_view kPathSeparators = "\\/";

  // Look for a custom name. Only a '#' after the last path separator (of
  // either kind) delimits a name; a '#' inside a directory name is part of
  // the path.
  auto pos = entry.find_last_of("#\\/");
  if (pos != std::string_view::npos && entry[pos] == '#') {
    if (pos == entry.size() - 1) {
      // Just a trailing delim, remove it.
      entry = entry.substr(0, pos);
    } else {
      // Use the custom name.
      return {entry.substr(pos + 1), entry.substr(0, pos)};
    }
  }

  // No custom name, so use the parent directory as the name.
  const size_t last = entry.find_last_of(kPathSeparators);
  if (last == std::string_view::npos) {
    // No directory component: the entry itself doubles as the name.
    return {entry, entry};
  }
  if (last == 0) {
    // Only a leading separator, so there is no parent directory name; fall
    // back to the final path component.
    return {entry.substr(1), entry};
  }
  const size_t prev = entry.find_last_of(kPathSeparators, last - 1);
  const size_t start = prev == std::string_view::npos ? 0 : prev + 1;
  return {entry.substr(start, last - start), entry};
}
108
parseCmds(std::string_view cmdsStr)109 std::map<std::string_view, std::string_view> parseCmds(
110 std::string_view cmdsStr) {
111 std::map<std::string_view, std::string_view> result;
112 std::vector<folly::StringPiece> cmds;
113 folly::split(',', cmdsStr, cmds);
114 for (auto cmd : cmds) {
115 auto entry = parseNameAndCmd(folly::trimWhitespace(cmd));
116 auto res = result.emplace(entry);
117 if (!res.second) {
118 folly::throw_exception<std::invalid_argument>(fmt::format(
119 "Multiple servers have the name {}: {} vs {}",
120 entry.first,
121 res.first->second,
122 entry.second));
123 }
124 }
125 return result;
126 }
127
parseNonconforming(std::string_view data)128 std::set<std::string> parseNonconforming(std::string_view data) {
129 std::vector<folly::StringPiece> lines;
130 folly::split("\n", data, lines);
131 std::set<std::string> result;
132 for (auto& line : lines) {
133 // Strip any comments.
134 if (auto pos = line.find_first_of('#'); pos != folly::StringPiece::npos) {
135 line = line.subpiece(0, pos);
136 }
137 // Add trimmed, non-empty lines.
138 line = folly::trimWhitespace(line);
139 if (!line.empty()) {
140 result.emplace(line);
141 }
142 }
143 return result;
144 }
145
RunTestCase(ConformanceServiceAsyncClient & client,const TestCase & testCase)146 testing::AssertionResult RunTestCase(
147 ConformanceServiceAsyncClient& client, const TestCase& testCase) {
148 switch (testCase.test_ref()->getType()) {
149 case TestCaseUnion::roundTrip:
150 return RunRoundTripTest(client, *testCase.roundTrip_ref());
151 default:
152 return testing::AssertionFailure()
153 << "Unsupported test case type: " << testCase.test_ref()->getType();
154 }
155 }
156
157 class ConformanceTest : public testing::Test {
158 public:
ConformanceTest(ConformanceServiceAsyncClient * client,const TestCase * testCase,bool conforming)159 ConformanceTest(
160 ConformanceServiceAsyncClient* client,
161 const TestCase* testCase,
162 bool conforming)
163 : client_(client), testCase_(*testCase), conforming_(conforming) {}
164
165 protected:
TestBody()166 void TestBody() override {
167 EXPECT_EQ(RunTestCase(*client_, testCase_), conforming_);
168 }
169
170 private:
171 ConformanceServiceAsyncClient* const client_;
172 const TestCase& testCase_;
173 const bool conforming_;
174 };
175
RegisterTests(std::string_view category,const TestSuite * suite,const std::set<std::string> & nonconforming,std::function<ConformanceServiceAsyncClient & ()> clientFn,const char * file,int line)176 void RegisterTests(
177 std::string_view category,
178 const TestSuite* suite,
179 const std::set<std::string>& nonconforming,
180 std::function<ConformanceServiceAsyncClient&()> clientFn,
181 const char* file,
182 int line) {
183 for (const auto& test : *suite->tests_ref()) {
184 for (const auto& testCase : *test.testCases_ref()) {
185 std::string suiteName = fmt::format(
186 "{}/{}/{}", category, *suite->name_ref(), *testCase.name_ref());
187 std::string fullName = fmt::format("{}.{}", suiteName, *test.name_ref());
188 bool conforming = nonconforming.find(fullName) == nonconforming.end();
189 RegisterTest(
190 suiteName.c_str(),
191 test.name_ref()->c_str(),
192 nullptr,
193 conforming ? nullptr : "nonconforming",
194 file,
195 line,
196 [&testCase, clientFn, conforming]() {
197 return new ConformanceTest(&clientFn(), &testCase, conforming);
198 });
199 }
200 }
201 }
202
203 } // namespace apache::thrift::conformance
204