1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <thrift/conformance/GTestHarness.h>
18 
19 #include <memory>
20 #include <stdexcept>
21 
22 #include <fmt/core.h>
23 #include <folly/lang/Exception.h>
24 #include <thrift/conformance/cpp2/AnyRegistry.h>
25 #include <thrift/conformance/cpp2/Object.h>
26 #include <thrift/lib/cpp2/op/Compare.h>
27 #include <thrift/lib/cpp2/protocol/Serializer.h>
28 
29 namespace apache::thrift::conformance {
30 namespace {
31 
32 // From a newer version of gtest.
33 //
34 // TODO(afuller): Delete once gtest is updated.
35 template <typename Factory>
RegisterTest(const char * test_suite_name,const char * test_name,const char * type_param,const char * value_param,const char * file,int line,Factory factory)36 testing::TestInfo* RegisterTest(
37     const char* test_suite_name,
38     const char* test_name,
39     const char* type_param,
40     const char* value_param,
41     const char* file,
42     int line,
43     Factory factory) {
44   using TestT = typename std::remove_pointer<decltype(factory())>::type;
45 
46   class FactoryImpl : public testing::internal::TestFactoryBase {
47    public:
48     explicit FactoryImpl(Factory f) : factory_(std::move(f)) {}
49     testing::Test* CreateTest() override { return factory_(); }
50 
51    private:
52     Factory factory_;
53   };
54 
55   return testing::internal::MakeAndRegisterTestInfo(
56       test_suite_name,
57       test_name,
58       type_param,
59       value_param,
60       testing::internal::CodeLocation(file, line),
61       testing::internal::GetTypeId<TestT>(),
62       TestT::SetUpTestCase,
63       TestT::TearDownTestCase,
64       new FactoryImpl{std::move(factory)});
65 }
66 
RunRoundTripTest(ConformanceServiceAsyncClient & client,RoundTripTestCase roundTrip)67 testing::AssertionResult RunRoundTripTest(
68     ConformanceServiceAsyncClient& client, RoundTripTestCase roundTrip) {
69   RoundTripResponse res;
70   client.sync_roundTrip(res, *roundTrip.request_ref());
71 
72   const Any& expectedAny = roundTrip.expectedResponse_ref()
73       ? *roundTrip.expectedResponse_ref().value_unchecked().value_ref()
74       : *roundTrip.request_ref()->value_ref();
75 
76   // TODO(afuller): Make add asValueStruct support to AnyRegistry and use that
77   // instead of hard coding type.
78   auto actual = AnyRegistry::generated().load<Value>(*res.value_ref());
79   auto expected = AnyRegistry::generated().load<Value>(expectedAny);
80   if (!op::identical<type::struct_t<Value>>(actual, expected)) {
81     // TODO(afuller): Report out the delta
82     return testing::AssertionFailure();
83   }
84   return testing::AssertionSuccess();
85 }
86 
87 } // namespace
88 
// Splits a server entry into a (name, command) pair.
//
// Custom name:
//   An entry of the form "<cmd>#<name>" uses <name> as the server name and
//   <cmd> as the command, provided the '#' comes after every path separator.
//   A trailing '#' with nothing after it is stripped and ignored.
//
// Default name:
//   Otherwise the name is the command's parent directory
//   (e.g. "path/to/server/bin" -> "server").
//   If the command has no separator at all, the whole command is used as the
//   name. If its only separator is a leading one (e.g. "/bin"), the text after
//   that separator is used.
//
// Both '/' and '\\' are treated as path separators, consistently for both the
// custom-name check and the parent-directory fallback. The returned views
// point into `entry`.
std::pair<std::string_view, std::string_view> parseNameAndCmd(
    std::string_view entry) {
  constexpr std::string_view kSeparators = "\\/";
  constexpr auto npos = std::string_view::npos;

  // Look for a custom name: a '#' that follows every path separator.
  auto pos = entry.find_last_of("#\\/");
  if (pos != npos && entry[pos] == '#') {
    if (pos == entry.size() - 1) {
      // Just a trailing delim, remove it.
      entry = entry.substr(0, pos);
    } else {
      // Use the custom name.
      return {entry.substr(pos + 1), entry.substr(0, pos)};
    }
  }

  // No custom name, so use the parent directory as the name.
  const size_t last = entry.find_last_of(kSeparators);
  if (last == npos) {
    // No directory component; fall back to the command itself.
    return {entry, entry};
  }
  if (last == 0) {
    // Only a leading separator; there is no parent directory name, so use
    // what follows the separator.
    return {entry.substr(1), entry};
  }
  const size_t prev = entry.find_last_of(kSeparators, last - 1);
  const size_t start = prev == npos ? 0 : prev + 1;
  return {entry.substr(start, last - start), entry};
}
108 
parseCmds(std::string_view cmdsStr)109 std::map<std::string_view, std::string_view> parseCmds(
110     std::string_view cmdsStr) {
111   std::map<std::string_view, std::string_view> result;
112   std::vector<folly::StringPiece> cmds;
113   folly::split(',', cmdsStr, cmds);
114   for (auto cmd : cmds) {
115     auto entry = parseNameAndCmd(folly::trimWhitespace(cmd));
116     auto res = result.emplace(entry);
117     if (!res.second) {
118       folly::throw_exception<std::invalid_argument>(fmt::format(
119           "Multiple servers have the name {}: {} vs {}",
120           entry.first,
121           res.first->second,
122           entry.second));
123     }
124   }
125   return result;
126 }
127 
parseNonconforming(std::string_view data)128 std::set<std::string> parseNonconforming(std::string_view data) {
129   std::vector<folly::StringPiece> lines;
130   folly::split("\n", data, lines);
131   std::set<std::string> result;
132   for (auto& line : lines) {
133     // Strip any comments.
134     if (auto pos = line.find_first_of('#'); pos != folly::StringPiece::npos) {
135       line = line.subpiece(0, pos);
136     }
137     // Add trimmed, non-empty lines.
138     line = folly::trimWhitespace(line);
139     if (!line.empty()) {
140       result.emplace(line);
141     }
142   }
143   return result;
144 }
145 
RunTestCase(ConformanceServiceAsyncClient & client,const TestCase & testCase)146 testing::AssertionResult RunTestCase(
147     ConformanceServiceAsyncClient& client, const TestCase& testCase) {
148   switch (testCase.test_ref()->getType()) {
149     case TestCaseUnion::roundTrip:
150       return RunRoundTripTest(client, *testCase.roundTrip_ref());
151     default:
152       return testing::AssertionFailure()
153           << "Unsupported test case type: " << testCase.test_ref()->getType();
154   }
155 }
156 
157 class ConformanceTest : public testing::Test {
158  public:
ConformanceTest(ConformanceServiceAsyncClient * client,const TestCase * testCase,bool conforming)159   ConformanceTest(
160       ConformanceServiceAsyncClient* client,
161       const TestCase* testCase,
162       bool conforming)
163       : client_(client), testCase_(*testCase), conforming_(conforming) {}
164 
165  protected:
TestBody()166   void TestBody() override {
167     EXPECT_EQ(RunTestCase(*client_, testCase_), conforming_);
168   }
169 
170  private:
171   ConformanceServiceAsyncClient* const client_;
172   const TestCase& testCase_;
173   const bool conforming_;
174 };
175 
RegisterTests(std::string_view category,const TestSuite * suite,const std::set<std::string> & nonconforming,std::function<ConformanceServiceAsyncClient & ()> clientFn,const char * file,int line)176 void RegisterTests(
177     std::string_view category,
178     const TestSuite* suite,
179     const std::set<std::string>& nonconforming,
180     std::function<ConformanceServiceAsyncClient&()> clientFn,
181     const char* file,
182     int line) {
183   for (const auto& test : *suite->tests_ref()) {
184     for (const auto& testCase : *test.testCases_ref()) {
185       std::string suiteName = fmt::format(
186           "{}/{}/{}", category, *suite->name_ref(), *testCase.name_ref());
187       std::string fullName = fmt::format("{}.{}", suiteName, *test.name_ref());
188       bool conforming = nonconforming.find(fullName) == nonconforming.end();
189       RegisterTest(
190           suiteName.c_str(),
191           test.name_ref()->c_str(),
192           nullptr,
193           conforming ? nullptr : "nonconforming",
194           file,
195           line,
196           [&testCase, clientFn, conforming]() {
197             return new ConformanceTest(&clientFn(), &testCase, conforming);
198           });
199     }
200   }
201 }
202 
203 } // namespace apache::thrift::conformance
204