1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include <gmock/gmock.h>
19 #include <gtest/gtest.h>
20
#include <atomic>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
32
33 #include "arrow/flight/api.h"
34 #include "arrow/ipc/test_common.h"
35 #include "arrow/status.h"
36 #include "arrow/testing/generator.h"
37 #include "arrow/testing/gtest_util.h"
38 #include "arrow/testing/util.h"
39 #include "arrow/util/base64.h"
40 #include "arrow/util/logging.h"
41 #include "arrow/util/make_unique.h"
42 #include "arrow/util/string.h"
43
44 #ifdef GRPCPP_GRPCPP_H
45 #error "gRPC headers should not be in public API"
46 #endif
47
48 #include "arrow/flight/client_cookie_middleware.h"
49 #include "arrow/flight/client_header_internal.h"
50 #include "arrow/flight/internal.h"
51 #include "arrow/flight/middleware_internal.h"
52 #include "arrow/flight/test_util.h"
53
54 namespace arrow {
55 namespace flight {
56
57 namespace pb = arrow::flight::protocol;
58
// Credentials for the authentication tests (presumably the header-auth tests
// later in this file; the usages are not visible here).
constexpr char kValidUsername[] = "flight_username";
constexpr char kValidPassword[] = "flight_password";
constexpr char kInvalidUsername[] = "invalid_flight_username";
constexpr char kInvalidPassword[] = "invalid_flight_password";

// Bearer token value, the "Basic "/"Bearer " scheme prefixes, and the header name
// that carries them.
constexpr char kBearerToken[] = "bearertoken";
constexpr char kBasicPrefix[] = "Basic ";
constexpr char kBearerPrefix[] = "Bearer ";
constexpr char kAuthHeader[] = "authorization";
67
AssertEqual(const ActionType & expected,const ActionType & actual)68 void AssertEqual(const ActionType& expected, const ActionType& actual) {
69 ASSERT_EQ(expected.type, actual.type);
70 ASSERT_EQ(expected.description, actual.description);
71 }
72
AssertEqual(const FlightDescriptor & expected,const FlightDescriptor & actual)73 void AssertEqual(const FlightDescriptor& expected, const FlightDescriptor& actual) {
74 ASSERT_TRUE(expected.Equals(actual));
75 }
76
AssertEqual(const Ticket & expected,const Ticket & actual)77 void AssertEqual(const Ticket& expected, const Ticket& actual) {
78 ASSERT_EQ(expected.ticket, actual.ticket);
79 }
80
AssertEqual(const Location & expected,const Location & actual)81 void AssertEqual(const Location& expected, const Location& actual) {
82 ASSERT_EQ(expected, actual);
83 }
84
AssertEqual(const std::vector<FlightEndpoint> & expected,const std::vector<FlightEndpoint> & actual)85 void AssertEqual(const std::vector<FlightEndpoint>& expected,
86 const std::vector<FlightEndpoint>& actual) {
87 ASSERT_EQ(expected.size(), actual.size());
88 for (size_t i = 0; i < expected.size(); ++i) {
89 AssertEqual(expected[i].ticket, actual[i].ticket);
90
91 ASSERT_EQ(expected[i].locations.size(), actual[i].locations.size());
92 for (size_t j = 0; j < expected[i].locations.size(); ++j) {
93 AssertEqual(expected[i].locations[j], actual[i].locations[j]);
94 }
95 }
96 }
97
// Generic element-wise vector comparison that dispatches to the AssertEqual
// overload for T. A size mismatch fails before any element is compared.
template <typename T>
void AssertEqual(const std::vector<T>& expected, const std::vector<T>& actual) {
  ASSERT_EQ(expected.size(), actual.size());
  auto actual_it = actual.begin();
  for (const T& item : expected) {
    AssertEqual(item, *actual_it);
    ++actual_it;
  }
}
105
AssertEqual(const FlightInfo & expected,const FlightInfo & actual)106 void AssertEqual(const FlightInfo& expected, const FlightInfo& actual) {
107 std::shared_ptr<Schema> ex_schema, actual_schema;
108 ipc::DictionaryMemo expected_memo;
109 ipc::DictionaryMemo actual_memo;
110 ASSERT_OK(expected.GetSchema(&expected_memo, &ex_schema));
111 ASSERT_OK(actual.GetSchema(&actual_memo, &actual_schema));
112
113 AssertSchemaEqual(*ex_schema, *actual_schema);
114 ASSERT_EQ(expected.total_records(), actual.total_records());
115 ASSERT_EQ(expected.total_bytes(), actual.total_bytes());
116
117 AssertEqual(expected.descriptor(), actual.descriptor());
118 AssertEqual(expected.endpoints(), actual.endpoints());
119 }
120
TEST(TestFlightDescriptor, Basics) {
  auto cmd = FlightDescriptor::Command("select * from table");
  auto same_cmd = FlightDescriptor::Command("select * from table");
  auto other_cmd = FlightDescriptor::Command("select foo from table");
  auto path_bar = FlightDescriptor::Path({"foo", "bar"});
  auto path_baz = FlightDescriptor::Path({"foo", "baz"});
  auto same_path_baz = FlightDescriptor::Path({"foo", "baz"});

  // Human-readable rendering of both descriptor kinds.
  ASSERT_EQ(cmd.ToString(), "FlightDescriptor<cmd = 'select * from table'>");
  ASSERT_EQ(path_bar.ToString(), "FlightDescriptor<path = 'foo/bar'>");

  // Equality requires the same kind and the same command or path.
  ASSERT_TRUE(cmd.Equals(same_cmd));
  ASSERT_FALSE(cmd.Equals(other_cmd));
  ASSERT_FALSE(cmd.Equals(path_bar));
  ASSERT_FALSE(path_bar.Equals(path_baz));
  ASSERT_TRUE(path_baz.Equals(same_path_baz));
}
137
138 // This tests the internal protobuf types which don't get exported in the Flight DLL.
139 #ifndef _WIN32
TEST(TestFlightDescriptor,ToFromProto)140 TEST(TestFlightDescriptor, ToFromProto) {
141 FlightDescriptor descr_test;
142 pb::FlightDescriptor pb_descr;
143
144 FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}};
145 ASSERT_OK(internal::ToProto(descr1, &pb_descr));
146 ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
147 AssertEqual(descr1, descr_test);
148
149 FlightDescriptor descr2{FlightDescriptor::CMD, "command", {}};
150 ASSERT_OK(internal::ToProto(descr2, &pb_descr));
151 ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
152 AssertEqual(descr2, descr_test);
153 }
154 #endif
155
TEST(TestFlight, DISABLED_StartStopTestServer) {
#ifdef _WIN32
  // We do a hard kill on Windows
  const int expected_exit_code = 259;
#else
  const int expected_exit_code = 0;
#endif

  TestServer server("flight-test-server");
  server.Start();
  ASSERT_TRUE(server.IsRunning());

  // After a short pause the server must still be running.
  std::this_thread::sleep_for(std::chrono::milliseconds(200));
  ASSERT_TRUE(server.IsRunning());

  const int exit_code = server.Stop();
  ASSERT_EQ(expected_exit_code, exit_code);
  ASSERT_FALSE(server.IsRunning());
}
173
174 // ARROW-6017: we should be able to construct locations for unknown
175 // schemes
TEST(TestFlight,UnknownLocationScheme)176 TEST(TestFlight, UnknownLocationScheme) {
177 Location location;
178 ASSERT_OK(Location::Parse("s3://test", &location));
179 ASSERT_OK(Location::Parse("https://example.com/foo", &location));
180 }
181
TEST(TestFlight, ConnectUri) {
  TestServer server("flight-test-server");
  server.Start();
  ASSERT_TRUE(server.IsRunning());

  const std::string uri = "grpc://localhost:" + std::to_string(server.port());

  // Parse the same URI into two Locations and connect with each one, reusing a
  // single client handle for both connections.
  std::unique_ptr<FlightClient> client;
  Location location1;
  Location location2;
  ASSERT_OK(Location::Parse(uri, &location1));
  ASSERT_OK(Location::Parse(uri, &location2));
  ASSERT_OK(FlightClient::Connect(location1, &client));
  ASSERT_OK(FlightClient::Connect(location2, &client));
}
199
200 #ifndef _WIN32
TEST(TestFlight,ConnectUriUnix)201 TEST(TestFlight, ConnectUriUnix) {
202 TestServer server("flight-test-server", "/tmp/flight-test.sock");
203 server.Start();
204 ASSERT_TRUE(server.IsRunning());
205
206 std::stringstream ss;
207 ss << "grpc+unix://" << server.unix_sock();
208 std::string uri = ss.str();
209
210 std::unique_ptr<FlightClient> client;
211 Location location1;
212 Location location2;
213 ASSERT_OK(Location::Parse(uri, &location1));
214 ASSERT_OK(Location::Parse(uri, &location2));
215 ASSERT_OK(FlightClient::Connect(location1, &client));
216 ASSERT_OK(FlightClient::Connect(location2, &client));
217 }
218 #endif
219
// Serializes and deserializes each public Flight value type (Ticket,
// FlightDescriptor in both kinds, FlightInfo) and checks the copies match.
TEST(TestFlight, RoundTripTypes) {
  Ticket ticket{"foo"};
  std::string ticket_serialized;
  Ticket ticket_deserialized;
  ASSERT_OK(ticket.SerializeToString(&ticket_serialized));
  ASSERT_OK(Ticket::Deserialize(ticket_serialized, &ticket_deserialized));
  ASSERT_EQ(ticket.ticket, ticket_deserialized.ticket);

  FlightDescriptor desc = FlightDescriptor::Command("select * from foo;");
  std::string desc_serialized;
  FlightDescriptor desc_deserialized;
  ASSERT_OK(desc.SerializeToString(&desc_serialized));
  ASSERT_OK(FlightDescriptor::Deserialize(desc_serialized, &desc_deserialized));
  ASSERT_TRUE(desc.Equals(desc_deserialized));

  // Deserialize the path descriptor into a fresh object so that nothing left
  // over from the command round trip above can affect this check.
  desc = FlightDescriptor::Path({"a", "b", "test.arrow"});
  FlightDescriptor path_deserialized;
  ASSERT_OK(desc.SerializeToString(&desc_serialized));
  ASSERT_OK(FlightDescriptor::Deserialize(desc_serialized, &path_deserialized));
  ASSERT_TRUE(desc.Equals(path_deserialized));

  FlightInfo::Data data;
  std::shared_ptr<Schema> schema =
      arrow::schema({field("a", int64()), field("b", int64()), field("c", int64()),
                     field("d", int64())});
  Location location1, location2, location3;
  ASSERT_OK(Location::ForGrpcTcp("localhost", 10010, &location1));
  ASSERT_OK(Location::ForGrpcTls("localhost", 10010, &location2));
  ASSERT_OK(Location::ForGrpcUnix("/tmp/test.sock", &location3));
  std::vector<FlightEndpoint> endpoints{FlightEndpoint{ticket, {location1, location2}},
                                        FlightEndpoint{ticket, {location3}}};
  ASSERT_OK(MakeFlightInfo(*schema, desc, endpoints, -1, -1, &data));
  auto info = arrow::internal::make_unique<FlightInfo>(data);
  std::string info_serialized;
  std::unique_ptr<FlightInfo> info_deserialized;
  ASSERT_OK(info->SerializeToString(&info_serialized));
  ASSERT_OK(FlightInfo::Deserialize(info_serialized, &info_deserialized));
  ASSERT_TRUE(info->descriptor().Equals(info_deserialized->descriptor()));
  ASSERT_EQ(info->endpoints(), info_deserialized->endpoints());
  ASSERT_EQ(info->total_records(), info_deserialized->total_records());
  ASSERT_EQ(info->total_bytes(), info_deserialized->total_bytes());
  // The checks above never look at the schema; the shared helper also decodes
  // and compares it.
  AssertEqual(*info, *info_deserialized);
}
261
TEST(TestFlight, RoundtripStatus) {
  // Make sure status codes round trip through our conversions.

  // Each Flight-specific code must be recoverable from the Status that wraps it.
  for (const FlightStatusCode code :
       {FlightStatusCode::Internal, FlightStatusCode::TimedOut,
        FlightStatusCode::Cancelled, FlightStatusCode::Unauthenticated,
        FlightStatusCode::Unauthorized, FlightStatusCode::Unavailable}) {
    std::shared_ptr<FlightStatusDetail> detail =
        FlightStatusDetail::UnwrapStatus(MakeFlightError(code, "Test message"));
    ASSERT_NE(nullptr, detail);
    ASSERT_EQ(code, detail->code());
  }

  // Generic Arrow codes must survive a trip through a gRPC status and back,
  // keeping their message.
  auto via_grpc = [](const Status& original) {
    return internal::FromGrpcStatus(internal::ToGrpcStatus(original));
  };

  Status status = via_grpc(Status::NotImplemented("Sentinel"));
  ASSERT_TRUE(status.IsNotImplemented());
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));

  status = via_grpc(Status::Invalid("Sentinel"));
  ASSERT_TRUE(status.IsInvalid());
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));

  status = via_grpc(Status::KeyError("Sentinel"));
  ASSERT_TRUE(status.IsKeyError());
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));

  status = via_grpc(Status::AlreadyExists("Sentinel"));
  ASSERT_TRUE(status.IsAlreadyExists());
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
}
314
// Binding to port 0 must leave the server reporting the concrete port it bound.
TEST(TestFlight, GetPort) {
  Location location;
  std::unique_ptr<FlightServerBase> server = ExampleTestServer();

  ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location));
  FlightServerOptions options(location);
  ASSERT_OK(server->Init(options));
  ASSERT_GT(server->port(), 0);
  // Shut down explicitly (as BuilderHook does) rather than relying on destruction.
  ASSERT_OK(server->Shutdown());
}
324
325 // CI environments don't have an IPv6 interface configured
TEST(TestFlight,DISABLED_IpV6Port)326 TEST(TestFlight, DISABLED_IpV6Port) {
327 Location location, location2;
328 std::unique_ptr<FlightServerBase> server = ExampleTestServer();
329
330 ASSERT_OK(Location::ForGrpcTcp("[::1]", 0, &location));
331 FlightServerOptions options(location);
332 ASSERT_OK(server->Init(options));
333 ASSERT_GT(server->port(), 0);
334
335 ASSERT_OK(Location::ForGrpcTcp("[::1]", server->port(), &location2));
336 std::unique_ptr<FlightClient> client;
337 ASSERT_OK(FlightClient::Connect(location2, &client));
338 std::unique_ptr<FlightListing> listing;
339 ASSERT_OK(client->ListFlights(&listing));
340 }
341
TEST(TestFlight, BuilderHook) {
  std::unique_ptr<FlightServerBase> server = ExampleTestServer();

  Location location;
  ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location));
  FlightServerOptions options(location);

  // Init() must call the hook, and must pass it a non-null builder pointer.
  bool hook_called = false;
  options.builder_hook = [&hook_called](void* builder) {
    ASSERT_NE(nullptr, builder);
    hook_called = true;
  };

  ASSERT_OK(server->Init(options));
  ASSERT_TRUE(hook_called);
  ASSERT_GT(server->port(), 0);
  ASSERT_OK(server->Shutdown());
}
358
359 // ----------------------------------------------------------------------
360 // Client tests
361
362 // Helper to initialize a server and matching client with callbacks to
363 // populate options.
364 template <typename T, typename... Args>
MakeServer(std::unique_ptr<FlightServerBase> * server,std::unique_ptr<FlightClient> * client,std::function<Status (FlightServerOptions *)> make_server_options,std::function<Status (FlightClientOptions *)> make_client_options,Args &&...server_args)365 Status MakeServer(std::unique_ptr<FlightServerBase>* server,
366 std::unique_ptr<FlightClient>* client,
367 std::function<Status(FlightServerOptions*)> make_server_options,
368 std::function<Status(FlightClientOptions*)> make_client_options,
369 Args&&... server_args) {
370 Location location;
371 RETURN_NOT_OK(Location::ForGrpcTcp("localhost", 0, &location));
372 *server = arrow::internal::make_unique<T>(std::forward<Args>(server_args)...);
373 FlightServerOptions server_options(location);
374 RETURN_NOT_OK(make_server_options(&server_options));
375 RETURN_NOT_OK((*server)->Init(server_options));
376 Location real_location;
377 RETURN_NOT_OK(Location::ForGrpcTcp("localhost", (*server)->port(), &real_location));
378 FlightClientOptions client_options = FlightClientOptions::Defaults();
379 RETURN_NOT_OK(make_client_options(&client_options));
380 return FlightClient::Connect(real_location, client_options, client);
381 }
382
383 class TestFlightClient : public ::testing::Test {
384 public:
SetUp()385 void SetUp() {
386 server_ = ExampleTestServer();
387
388 Location location;
389 ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location));
390 FlightServerOptions options(location);
391 ASSERT_OK(server_->Init(options));
392
393 ASSERT_OK(ConnectClient());
394 }
395
TearDown()396 void TearDown() { ASSERT_OK(server_->Shutdown()); }
397
ConnectClient()398 Status ConnectClient() {
399 Location location;
400 RETURN_NOT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location));
401 return FlightClient::Connect(location, &client_);
402 }
403
404 template <typename EndpointCheckFunc>
CheckDoGet(const FlightDescriptor & descr,const BatchVector & expected_batches,EndpointCheckFunc && check_endpoints)405 void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches,
406 EndpointCheckFunc&& check_endpoints) {
407 auto expected_schema = expected_batches[0]->schema();
408
409 std::unique_ptr<FlightInfo> info;
410 ASSERT_OK(client_->GetFlightInfo(descr, &info));
411 check_endpoints(info->endpoints());
412
413 std::shared_ptr<Schema> schema;
414 ipc::DictionaryMemo dict_memo;
415 ASSERT_OK(info->GetSchema(&dict_memo, &schema));
416 AssertSchemaEqual(*expected_schema, *schema);
417
418 // By convention, fetch the first endpoint
419 Ticket ticket = info->endpoints()[0].ticket;
420 CheckDoGet(ticket, expected_batches);
421 }
422
CheckDoGet(const Ticket & ticket,const BatchVector & expected_batches)423 void CheckDoGet(const Ticket& ticket, const BatchVector& expected_batches) {
424 auto num_batches = static_cast<int>(expected_batches.size());
425 ASSERT_GE(num_batches, 2);
426
427 std::unique_ptr<FlightStreamReader> stream;
428 ASSERT_OK(client_->DoGet(ticket, &stream));
429
430 std::unique_ptr<FlightStreamReader> stream2;
431 ASSERT_OK(client_->DoGet(ticket, &stream2));
432 ASSERT_OK_AND_ASSIGN(auto reader, MakeRecordBatchReader(std::move(stream2)));
433
434 FlightStreamChunk chunk;
435 std::shared_ptr<RecordBatch> batch;
436 for (int i = 0; i < num_batches; ++i) {
437 ASSERT_OK(stream->Next(&chunk));
438 ASSERT_OK(reader->ReadNext(&batch));
439 ASSERT_NE(nullptr, chunk.data);
440 ASSERT_NE(nullptr, batch);
441 #if !defined(__MINGW32__)
442 ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data);
443 ASSERT_BATCHES_EQUAL(*expected_batches[i], *batch);
444 #else
445 // In MINGW32, the following code does not have the reproducibility at the LSB
446 // even when this is called twice with the same seed.
447 // As a workaround, use approxEqual
448 // /* from GenerateTypedData in random.cc */
449 // std::default_random_engine rng(seed); // seed = 282475250
450 // std::uniform_real_distribution<double> dist;
451 // std::generate(data, data + n, // n = 10
452 // [&dist, &rng] { return static_cast<ValueType>(dist(rng)); });
453 // /* data[1] = 0x40852cdfe23d3976 or 0x40852cdfe23d3975 */
454 ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *chunk.data);
455 ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *batch);
456 #endif
457 }
458
459 // Stream exhausted
460 ASSERT_OK(stream->Next(&chunk));
461 ASSERT_OK(reader->ReadNext(&batch));
462 ASSERT_EQ(nullptr, chunk.data);
463 ASSERT_EQ(nullptr, batch);
464 }
465
466 protected:
467 std::unique_ptr<FlightClient> client_;
468 std::unique_ptr<FlightServerBase> server_;
469 };
470
471 class AuthTestServer : public FlightServerBase {
DoAction(const ServerCallContext & context,const Action & action,std::unique_ptr<ResultStream> * result)472 Status DoAction(const ServerCallContext& context, const Action& action,
473 std::unique_ptr<ResultStream>* result) override {
474 auto buf = Buffer::FromString(context.peer_identity());
475 auto peer = Buffer::FromString(context.peer());
476 *result = std::unique_ptr<ResultStream>(
477 new SimpleResultStream({Result{buf}, Result{peer}}));
478 return Status::OK();
479 }
480 };
481
482 class TlsTestServer : public FlightServerBase {
DoAction(const ServerCallContext & context,const Action & action,std::unique_ptr<ResultStream> * result)483 Status DoAction(const ServerCallContext& context, const Action& action,
484 std::unique_ptr<ResultStream>* result) override {
485 auto buf = Buffer::FromString("Hello, world!");
486 *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
487 return Status::OK();
488 }
489 };
490
491 class DoPutTestServer : public FlightServerBase {
492 public:
DoPut(const ServerCallContext & context,std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMetadataWriter> writer)493 Status DoPut(const ServerCallContext& context,
494 std::unique_ptr<FlightMessageReader> reader,
495 std::unique_ptr<FlightMetadataWriter> writer) override {
496 descriptor_ = reader->descriptor();
497 return reader->ReadAll(&batches_);
498 }
499
500 protected:
501 FlightDescriptor descriptor_;
502 BatchVector batches_;
503
504 friend class TestDoPut;
505 };
506
507 class MetadataTestServer : public FlightServerBase {
DoGet(const ServerCallContext & context,const Ticket & request,std::unique_ptr<FlightDataStream> * data_stream)508 Status DoGet(const ServerCallContext& context, const Ticket& request,
509 std::unique_ptr<FlightDataStream>* data_stream) override {
510 BatchVector batches;
511 if (request.ticket == "dicts") {
512 RETURN_NOT_OK(ExampleDictBatches(&batches));
513 } else if (request.ticket == "floats") {
514 RETURN_NOT_OK(ExampleFloatBatches(&batches));
515 } else {
516 RETURN_NOT_OK(ExampleIntBatches(&batches));
517 }
518 std::shared_ptr<RecordBatchReader> batch_reader =
519 std::make_shared<BatchIterator>(batches[0]->schema(), batches);
520
521 *data_stream = std::unique_ptr<FlightDataStream>(new NumberingStream(
522 std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader))));
523 return Status::OK();
524 }
525
DoPut(const ServerCallContext & context,std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMetadataWriter> writer)526 Status DoPut(const ServerCallContext& context,
527 std::unique_ptr<FlightMessageReader> reader,
528 std::unique_ptr<FlightMetadataWriter> writer) override {
529 FlightStreamChunk chunk;
530 int counter = 0;
531 while (true) {
532 RETURN_NOT_OK(reader->Next(&chunk));
533 if (chunk.data == nullptr) break;
534 if (chunk.app_metadata == nullptr) {
535 return Status::Invalid("Expected application metadata to be provided");
536 }
537 if (std::to_string(counter) != chunk.app_metadata->ToString()) {
538 return Status::Invalid("Expected metadata value: " + std::to_string(counter) +
539 " but got: " + chunk.app_metadata->ToString());
540 }
541 auto metadata = Buffer::FromString(std::to_string(counter));
542 RETURN_NOT_OK(writer->WriteMetadata(*metadata));
543 counter++;
544 }
545 return Status::OK();
546 }
547 };
548
549 // Server for testing custom IPC options support
550 class OptionsTestServer : public FlightServerBase {
DoGet(const ServerCallContext & context,const Ticket & request,std::unique_ptr<FlightDataStream> * data_stream)551 Status DoGet(const ServerCallContext& context, const Ticket& request,
552 std::unique_ptr<FlightDataStream>* data_stream) override {
553 BatchVector batches;
554 RETURN_NOT_OK(ExampleNestedBatches(&batches));
555 auto reader = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
556 *data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(reader));
557 return Status::OK();
558 }
559
560 // Just echo the number of batches written. The client will try to
561 // call this method with different write options set.
DoPut(const ServerCallContext & context,std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMetadataWriter> writer)562 Status DoPut(const ServerCallContext& context,
563 std::unique_ptr<FlightMessageReader> reader,
564 std::unique_ptr<FlightMetadataWriter> writer) override {
565 FlightStreamChunk chunk;
566 int counter = 0;
567 while (true) {
568 RETURN_NOT_OK(reader->Next(&chunk));
569 if (chunk.data == nullptr) break;
570 counter++;
571 }
572 auto metadata = Buffer::FromString(std::to_string(counter));
573 return writer->WriteMetadata(*metadata);
574 }
575
576 // Echo client data, but with write options set to limit the nesting
577 // level.
DoExchange(const ServerCallContext & context,std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMessageWriter> writer)578 Status DoExchange(const ServerCallContext& context,
579 std::unique_ptr<FlightMessageReader> reader,
580 std::unique_ptr<FlightMessageWriter> writer) override {
581 FlightStreamChunk chunk;
582 auto options = ipc::IpcWriteOptions::Defaults();
583 options.max_recursion_depth = 1;
584 bool begun = false;
585 while (true) {
586 RETURN_NOT_OK(reader->Next(&chunk));
587 if (!chunk.data && !chunk.app_metadata) {
588 break;
589 }
590 if (!begun && chunk.data) {
591 begun = true;
592 RETURN_NOT_OK(writer->Begin(chunk.data->schema(), options));
593 }
594 if (chunk.data && chunk.app_metadata) {
595 RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata));
596 } else if (chunk.data) {
597 RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data));
598 } else if (chunk.app_metadata) {
599 RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata));
600 }
601 }
602 return Status::OK();
603 }
604 };
605
606 class HeaderAuthTestServer : public FlightServerBase {
607 public:
ListFlights(const ServerCallContext & context,const Criteria * criteria,std::unique_ptr<FlightListing> * listings)608 Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
609 std::unique_ptr<FlightListing>* listings) override {
610 return Status::OK();
611 }
612 };
613
614 class TestMetadata : public ::testing::Test {
615 public:
SetUp()616 void SetUp() {
617 ASSERT_OK(MakeServer<MetadataTestServer>(
618 &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
619 [](FlightClientOptions* options) { return Status::OK(); }));
620 }
621
TearDown()622 void TearDown() { ASSERT_OK(server_->Shutdown()); }
623
624 protected:
625 std::unique_ptr<FlightClient> client_;
626 std::unique_ptr<FlightServerBase> server_;
627 };
628
629 class TestOptions : public ::testing::Test {
630 public:
SetUp()631 void SetUp() {
632 ASSERT_OK(MakeServer<OptionsTestServer>(
633 &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
634 [](FlightClientOptions* options) { return Status::OK(); }));
635 }
636
TearDown()637 void TearDown() { ASSERT_OK(server_->Shutdown()); }
638
639 protected:
640 std::unique_ptr<FlightClient> client_;
641 std::unique_ptr<FlightServerBase> server_;
642 };
643
644 class TestAuthHandler : public ::testing::Test {
645 public:
SetUp()646 void SetUp() {
647 ASSERT_OK(MakeServer<AuthTestServer>(
648 &server_, &client_,
649 [](FlightServerOptions* options) {
650 options->auth_handler = std::unique_ptr<ServerAuthHandler>(
651 new TestServerAuthHandler("user", "p4ssw0rd"));
652 return Status::OK();
653 },
654 [](FlightClientOptions* options) { return Status::OK(); }));
655 }
656
TearDown()657 void TearDown() { ASSERT_OK(server_->Shutdown()); }
658
659 protected:
660 std::unique_ptr<FlightClient> client_;
661 std::unique_ptr<FlightServerBase> server_;
662 };
663
664 class TestBasicAuthHandler : public ::testing::Test {
665 public:
SetUp()666 void SetUp() {
667 ASSERT_OK(MakeServer<AuthTestServer>(
668 &server_, &client_,
669 [](FlightServerOptions* options) {
670 options->auth_handler = std::unique_ptr<ServerAuthHandler>(
671 new TestServerBasicAuthHandler("user", "p4ssw0rd"));
672 return Status::OK();
673 },
674 [](FlightClientOptions* options) { return Status::OK(); }));
675 }
676
TearDown()677 void TearDown() { ASSERT_OK(server_->Shutdown()); }
678
679 protected:
680 std::unique_ptr<FlightClient> client_;
681 std::unique_ptr<FlightServerBase> server_;
682 };
683
684 class TestDoPut : public ::testing::Test {
685 public:
SetUp()686 void SetUp() {
687 ASSERT_OK(MakeServer<DoPutTestServer>(
688 &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
689 [](FlightClientOptions* options) { return Status::OK(); }));
690 do_put_server_ = (DoPutTestServer*)server_.get();
691 }
692
TearDown()693 void TearDown() { ASSERT_OK(server_->Shutdown()); }
694
CheckBatches(FlightDescriptor expected_descriptor,const BatchVector & expected_batches)695 void CheckBatches(FlightDescriptor expected_descriptor,
696 const BatchVector& expected_batches) {
697 ASSERT_TRUE(do_put_server_->descriptor_.Equals(expected_descriptor));
698 ASSERT_EQ(do_put_server_->batches_.size(), expected_batches.size());
699 for (size_t i = 0; i < expected_batches.size(); ++i) {
700 ASSERT_BATCHES_EQUAL(*do_put_server_->batches_[i], *expected_batches[i]);
701 }
702 }
703
CheckDoPut(FlightDescriptor descr,const std::shared_ptr<Schema> & schema,const BatchVector & batches)704 void CheckDoPut(FlightDescriptor descr, const std::shared_ptr<Schema>& schema,
705 const BatchVector& batches) {
706 std::unique_ptr<FlightStreamWriter> stream;
707 std::unique_ptr<FlightMetadataReader> reader;
708 ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader));
709 for (const auto& batch : batches) {
710 ASSERT_OK(stream->WriteRecordBatch(*batch));
711 }
712 ASSERT_OK(stream->DoneWriting());
713 ASSERT_OK(stream->Close());
714
715 CheckBatches(descr, batches);
716 }
717
718 protected:
719 std::unique_ptr<FlightClient> client_;
720 std::unique_ptr<FlightServerBase> server_;
721 DoPutTestServer* do_put_server_;
722 };
723
724 class TestTls : public ::testing::Test {
725 public:
SetUp()726 void SetUp() {
727 // Manually initialize gRPC to try to ensure some thread-locals
728 // get initialized.
729 // https://github.com/grpc/grpc/issues/13856
730 // https://github.com/grpc/grpc/issues/20311
731 // In general, gRPC on MacOS struggles with TLS (both in the sense
732 // of thread-locals and encryption)
733 grpc_init();
734
735 server_.reset(new TlsTestServer);
736
737 Location location;
738 ASSERT_OK(Location::ForGrpcTls("localhost", 0, &location));
739 FlightServerOptions options(location);
740 ASSERT_RAISES(UnknownError, server_->Init(options));
741 ASSERT_OK(ExampleTlsCertificates(&options.tls_certificates));
742 ASSERT_OK(server_->Init(options));
743
744 ASSERT_OK(Location::ForGrpcTls("localhost", server_->port(), &location_));
745 ASSERT_OK(ConnectClient());
746 }
747
TearDown()748 void TearDown() {
749 ASSERT_OK(server_->Shutdown());
750 grpc_shutdown();
751 }
752
ConnectClient()753 Status ConnectClient() {
754 auto options = FlightClientOptions::Defaults();
755 CertKeyPair root_cert;
756 RETURN_NOT_OK(ExampleTlsCertificateRoot(&root_cert));
757 options.tls_root_certs = root_cert.pem_cert;
758 return FlightClient::Connect(location_, options, &client_);
759 }
760
761 protected:
762 Location location_;
763 std::unique_ptr<FlightClient> client_;
764 std::unique_ptr<FlightServerBase> server_;
765 };
766
767 // A server middleware that rejects all calls.
768 class RejectServerMiddlewareFactory : public ServerMiddlewareFactory {
StartCall(const CallInfo & info,const CallHeaders & incoming_headers,std::shared_ptr<ServerMiddleware> * middleware)769 Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
770 std::shared_ptr<ServerMiddleware>* middleware) override {
771 return MakeFlightError(FlightStatusCode::Unauthenticated, "All calls are rejected");
772 }
773 };
774
775 // A server middleware that counts the number of successful and failed
776 // calls.
777 class CountingServerMiddleware : public ServerMiddleware {
778 public:
CountingServerMiddleware(std::atomic<int> * successful,std::atomic<int> * failed)779 CountingServerMiddleware(std::atomic<int>* successful, std::atomic<int>* failed)
780 : successful_(successful), failed_(failed) {}
SendingHeaders(AddCallHeaders * outgoing_headers)781 void SendingHeaders(AddCallHeaders* outgoing_headers) override {}
CallCompleted(const Status & status)782 void CallCompleted(const Status& status) override {
783 if (status.ok()) {
784 ARROW_IGNORE_EXPR((*successful_)++);
785 } else {
786 ARROW_IGNORE_EXPR((*failed_)++);
787 }
788 }
789
name() const790 std::string name() const override { return "CountingServerMiddleware"; }
791
792 private:
793 std::atomic<int>* successful_;
794 std::atomic<int>* failed_;
795 };
796
797 class CountingServerMiddlewareFactory : public ServerMiddlewareFactory {
798 public:
CountingServerMiddlewareFactory()799 CountingServerMiddlewareFactory() : successful_(0), failed_(0) {}
800
StartCall(const CallInfo & info,const CallHeaders & incoming_headers,std::shared_ptr<ServerMiddleware> * middleware)801 Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
802 std::shared_ptr<ServerMiddleware>* middleware) override {
803 *middleware = std::make_shared<CountingServerMiddleware>(&successful_, &failed_);
804 return Status::OK();
805 }
806
807 std::atomic<int> successful_;
808 std::atomic<int> failed_;
809 };
810
// Per-thread "current span ID" for the OpenTracing-style tracing emulation.
// Application code (the test servers) writes it and the client middleware
// reads it when adding outgoing headers. It starts out empty on every thread.
static thread_local std::string current_span_id;
815
816 // A server middleware that stores the current span ID, in an
817 // emulation of OpenTracing style distributed tracing.
818 class TracingServerMiddleware : public ServerMiddleware {
819 public:
TracingServerMiddleware(const std::string & current_span_id)820 explicit TracingServerMiddleware(const std::string& current_span_id)
821 : span_id(current_span_id) {}
SendingHeaders(AddCallHeaders * outgoing_headers)822 void SendingHeaders(AddCallHeaders* outgoing_headers) override {}
CallCompleted(const Status & status)823 void CallCompleted(const Status& status) override {}
824
name() const825 std::string name() const override { return "TracingServerMiddleware"; }
826
827 std::string span_id;
828 };
829
830 class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
831 public:
TracingServerMiddlewareFactory()832 TracingServerMiddlewareFactory() {}
833
StartCall(const CallInfo & info,const CallHeaders & incoming_headers,std::shared_ptr<ServerMiddleware> * middleware)834 Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
835 std::shared_ptr<ServerMiddleware>* middleware) override {
836 const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& iter_pair =
837 incoming_headers.equal_range("x-tracing-span-id");
838 if (iter_pair.first != iter_pair.second) {
839 const util::string_view& value = (*iter_pair.first).second;
840 *middleware = std::make_shared<TracingServerMiddleware>(std::string(value));
841 }
842 return Status::OK();
843 }
844 };
845
846 // Function to look in CallHeaders for a key that has a value starting with prefix and
847 // return the rest of the value after the prefix.
FindKeyValPrefixInCallHeaders(const CallHeaders & incoming_headers,const std::string & key,const std::string & prefix)848 std::string FindKeyValPrefixInCallHeaders(const CallHeaders& incoming_headers,
849 const std::string& key,
850 const std::string& prefix) {
851 // Lambda function to compare characters without case sensitivity.
852 auto char_compare = [](const char& char1, const char& char2) {
853 return (::toupper(char1) == ::toupper(char2));
854 };
855
856 auto iter = incoming_headers.find(key);
857 if (iter == incoming_headers.end()) {
858 return "";
859 }
860 const std::string val = iter->second.to_string();
861 if (val.size() > prefix.length()) {
862 if (std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(),
863 char_compare)) {
864 return val.substr(prefix.length());
865 }
866 }
867 return "";
868 }
869
870 class HeaderAuthServerMiddleware : public ServerMiddleware {
871 public:
SendingHeaders(AddCallHeaders * outgoing_headers)872 void SendingHeaders(AddCallHeaders* outgoing_headers) override {
873 outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerToken);
874 }
875
CallCompleted(const Status & status)876 void CallCompleted(const Status& status) override {}
877
name() const878 std::string name() const override { return "HeaderAuthServerMiddleware"; }
879 };
880
ParseBasicHeader(const CallHeaders & incoming_headers,std::string & username,std::string & password)881 void ParseBasicHeader(const CallHeaders& incoming_headers, std::string& username,
882 std::string& password) {
883 std::string encoded_credentials =
884 FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix);
885 std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials));
886 std::getline(decoded_stream, username, ':');
887 std::getline(decoded_stream, password, ':');
888 }
889
890 // Factory for base64 header authentication testing.
891 class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
892 public:
HeaderAuthServerMiddlewareFactory()893 HeaderAuthServerMiddlewareFactory() {}
894
StartCall(const CallInfo & info,const CallHeaders & incoming_headers,std::shared_ptr<ServerMiddleware> * middleware)895 Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
896 std::shared_ptr<ServerMiddleware>* middleware) override {
897 std::string username, password;
898 ParseBasicHeader(incoming_headers, username, password);
899 if ((username == kValidUsername) && (password == kValidPassword)) {
900 *middleware = std::make_shared<HeaderAuthServerMiddleware>();
901 } else if ((username == kInvalidUsername) && (password == kInvalidPassword)) {
902 return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid credentials");
903 }
904 return Status::OK();
905 }
906 };
907
908 // A server middleware for validating incoming bearer header authentication.
909 class BearerAuthServerMiddleware : public ServerMiddleware {
910 public:
BearerAuthServerMiddleware(const CallHeaders & incoming_headers,bool * isValid)911 explicit BearerAuthServerMiddleware(const CallHeaders& incoming_headers, bool* isValid)
912 : isValid_(isValid) {
913 incoming_headers_ = incoming_headers;
914 }
915
SendingHeaders(AddCallHeaders * outgoing_headers)916 void SendingHeaders(AddCallHeaders* outgoing_headers) override {
917 std::string bearer_token =
918 FindKeyValPrefixInCallHeaders(incoming_headers_, kAuthHeader, kBearerPrefix);
919 *isValid_ = (bearer_token == std::string(kBearerToken));
920 }
921
CallCompleted(const Status & status)922 void CallCompleted(const Status& status) override {}
923
name() const924 std::string name() const override { return "BearerAuthServerMiddleware"; }
925
926 private:
927 CallHeaders incoming_headers_;
928 bool* isValid_;
929 };
930
931 // Factory for base64 header authentication testing.
932 class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
933 public:
BearerAuthServerMiddlewareFactory()934 BearerAuthServerMiddlewareFactory() : isValid_(false) {}
935
StartCall(const CallInfo & info,const CallHeaders & incoming_headers,std::shared_ptr<ServerMiddleware> * middleware)936 Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
937 std::shared_ptr<ServerMiddleware>* middleware) override {
938 const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& iter_pair =
939 incoming_headers.equal_range(kAuthHeader);
940 if (iter_pair.first != iter_pair.second) {
941 *middleware =
942 std::make_shared<BearerAuthServerMiddleware>(incoming_headers, &isValid_);
943 }
944 return Status::OK();
945 }
946
GetIsValid()947 bool GetIsValid() { return isValid_; }
948
949 private:
950 bool isValid_;
951 };
952
953 // A client middleware that adds a thread-local "request ID" to
954 // outgoing calls as a header, and keeps track of the status of
955 // completed calls. NOT thread-safe.
956 class PropagatingClientMiddleware : public ClientMiddleware {
957 public:
PropagatingClientMiddleware(std::atomic<int> * received_headers,std::vector<Status> * recorded_status)958 explicit PropagatingClientMiddleware(std::atomic<int>* received_headers,
959 std::vector<Status>* recorded_status)
960 : received_headers_(received_headers), recorded_status_(recorded_status) {}
961
SendingHeaders(AddCallHeaders * outgoing_headers)962 void SendingHeaders(AddCallHeaders* outgoing_headers) {
963 // Pick up the span ID from thread locals. We have to use a
964 // thread-local for communication, since we aren't even
965 // instantiated until after the application code has already
966 // started the call (and so there's no chance for application code
967 // to pass us parameters directly).
968 outgoing_headers->AddHeader("x-tracing-span-id", current_span_id);
969 }
970
ReceivedHeaders(const CallHeaders & incoming_headers)971 void ReceivedHeaders(const CallHeaders& incoming_headers) { (*received_headers_)++; }
972
CallCompleted(const Status & status)973 void CallCompleted(const Status& status) { recorded_status_->push_back(status); }
974
975 private:
976 std::atomic<int>* received_headers_;
977 std::vector<Status>* recorded_status_;
978 };
979
980 class PropagatingClientMiddlewareFactory : public ClientMiddlewareFactory {
981 public:
StartCall(const CallInfo & info,std::unique_ptr<ClientMiddleware> * middleware)982 void StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware) {
983 recorded_calls_.push_back(info.method);
984 *middleware = arrow::internal::make_unique<PropagatingClientMiddleware>(
985 &received_headers_, &recorded_status_);
986 }
987
Reset()988 void Reset() {
989 recorded_calls_.clear();
990 recorded_status_.clear();
991 received_headers_.fetch_and(0);
992 }
993
994 std::vector<FlightMethod> recorded_calls_;
995 std::vector<Status> recorded_status_;
996 std::atomic<int> received_headers_;
997 };
998
999 class ReportContextTestServer : public FlightServerBase {
DoAction(const ServerCallContext & context,const Action & action,std::unique_ptr<ResultStream> * result)1000 Status DoAction(const ServerCallContext& context, const Action& action,
1001 std::unique_ptr<ResultStream>* result) override {
1002 std::shared_ptr<Buffer> buf;
1003 const ServerMiddleware* middleware = context.GetMiddleware("tracing");
1004 if (middleware == nullptr || middleware->name() != "TracingServerMiddleware") {
1005 buf = Buffer::FromString("");
1006 } else {
1007 buf = Buffer::FromString(((const TracingServerMiddleware*)middleware)->span_id);
1008 }
1009 *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
1010 return Status::OK();
1011 }
1012 };
1013
1014 class ErrorMiddlewareServer : public FlightServerBase {
DoAction(const ServerCallContext & context,const Action & action,std::unique_ptr<ResultStream> * result)1015 Status DoAction(const ServerCallContext& context, const Action& action,
1016 std::unique_ptr<ResultStream>* result) override {
1017 std::string msg = "error_message";
1018 auto buf = Buffer::FromString("");
1019
1020 std::shared_ptr<FlightStatusDetail> flightStatusDetail(
1021 new FlightStatusDetail(FlightStatusCode::Failed, msg));
1022 *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
1023 return Status(StatusCode::ExecutionError, "test failed", flightStatusDetail);
1024 }
1025 };
1026
1027 class PropagatingTestServer : public FlightServerBase {
1028 public:
PropagatingTestServer(std::unique_ptr<FlightClient> client)1029 explicit PropagatingTestServer(std::unique_ptr<FlightClient> client)
1030 : client_(std::move(client)) {}
1031
DoAction(const ServerCallContext & context,const Action & action,std::unique_ptr<ResultStream> * result)1032 Status DoAction(const ServerCallContext& context, const Action& action,
1033 std::unique_ptr<ResultStream>* result) override {
1034 const ServerMiddleware* middleware = context.GetMiddleware("tracing");
1035 if (middleware == nullptr || middleware->name() != "TracingServerMiddleware") {
1036 current_span_id = "";
1037 } else {
1038 current_span_id = ((const TracingServerMiddleware*)middleware)->span_id;
1039 }
1040
1041 return client_->DoAction(action, result);
1042 }
1043
1044 private:
1045 std::unique_ptr<FlightClient> client_;
1046 };
1047
1048 class TestRejectServerMiddleware : public ::testing::Test {
1049 public:
SetUp()1050 void SetUp() {
1051 ASSERT_OK(MakeServer<MetadataTestServer>(
1052 &server_, &client_,
1053 [](FlightServerOptions* options) {
1054 options->middleware.push_back(
1055 {"reject", std::make_shared<RejectServerMiddlewareFactory>()});
1056 return Status::OK();
1057 },
1058 [](FlightClientOptions* options) { return Status::OK(); }));
1059 }
1060
TearDown()1061 void TearDown() { ASSERT_OK(server_->Shutdown()); }
1062
1063 protected:
1064 std::unique_ptr<FlightClient> client_;
1065 std::unique_ptr<FlightServerBase> server_;
1066 };
1067
1068 class TestCountingServerMiddleware : public ::testing::Test {
1069 public:
SetUp()1070 void SetUp() {
1071 request_counter_ = std::make_shared<CountingServerMiddlewareFactory>();
1072 ASSERT_OK(MakeServer<MetadataTestServer>(
1073 &server_, &client_,
1074 [&](FlightServerOptions* options) {
1075 options->middleware.push_back({"request_counter", request_counter_});
1076 return Status::OK();
1077 },
1078 [](FlightClientOptions* options) { return Status::OK(); }));
1079 }
1080
TearDown()1081 void TearDown() { ASSERT_OK(server_->Shutdown()); }
1082
1083 protected:
1084 std::shared_ptr<CountingServerMiddlewareFactory> request_counter_;
1085 std::unique_ptr<FlightClient> client_;
1086 std::unique_ptr<FlightServerBase> server_;
1087 };
1088
1089 // Setup for this test is 2 servers
1090 // 1. Client makes request to server A with a request ID set
1091 // 2. server A extracts the request ID and makes a request to server B
1092 // with the same request ID set
1093 // 3. server B extracts the request ID and sends it back
1094 // 4. server A returns the response of server B
1095 // 5. Client validates the response
1096 class TestPropagatingMiddleware : public ::testing::Test {
1097 public:
SetUp()1098 void SetUp() {
1099 server_middleware_ = std::make_shared<TracingServerMiddlewareFactory>();
1100 second_client_middleware_ = std::make_shared<PropagatingClientMiddlewareFactory>();
1101 client_middleware_ = std::make_shared<PropagatingClientMiddlewareFactory>();
1102
1103 std::unique_ptr<FlightClient> server_client;
1104 ASSERT_OK(MakeServer<ReportContextTestServer>(
1105 &second_server_, &server_client,
1106 [&](FlightServerOptions* options) {
1107 options->middleware.push_back({"tracing", server_middleware_});
1108 return Status::OK();
1109 },
1110 [&](FlightClientOptions* options) {
1111 options->middleware.push_back(second_client_middleware_);
1112 return Status::OK();
1113 }));
1114
1115 ASSERT_OK(MakeServer<PropagatingTestServer>(
1116 &first_server_, &client_,
1117 [&](FlightServerOptions* options) {
1118 options->middleware.push_back({"tracing", server_middleware_});
1119 return Status::OK();
1120 },
1121 [&](FlightClientOptions* options) {
1122 options->middleware.push_back(client_middleware_);
1123 return Status::OK();
1124 },
1125 std::move(server_client)));
1126 }
1127
ValidateStatus(const Status & status,const FlightMethod & method)1128 void ValidateStatus(const Status& status, const FlightMethod& method) {
1129 ASSERT_EQ(1, client_middleware_->received_headers_);
1130 ASSERT_EQ(method, client_middleware_->recorded_calls_.at(0));
1131 ASSERT_EQ(status.code(), client_middleware_->recorded_status_.at(0).code());
1132 }
1133
TearDown()1134 void TearDown() {
1135 ASSERT_OK(first_server_->Shutdown());
1136 ASSERT_OK(second_server_->Shutdown());
1137 }
1138
CheckHeader(const std::string & header,const std::string & value,const CallHeaders::const_iterator & it)1139 void CheckHeader(const std::string& header, const std::string& value,
1140 const CallHeaders::const_iterator& it) {
1141 // Construct a string_view before comparison to satisfy MSVC
1142 util::string_view header_view(header.data(), header.length());
1143 util::string_view value_view(value.data(), value.length());
1144 ASSERT_EQ(header_view, (*it).first);
1145 ASSERT_EQ(value_view, (*it).second);
1146 }
1147
1148 protected:
1149 std::unique_ptr<FlightClient> client_;
1150 std::unique_ptr<FlightServerBase> first_server_;
1151 std::unique_ptr<FlightServerBase> second_server_;
1152 std::shared_ptr<TracingServerMiddlewareFactory> server_middleware_;
1153 std::shared_ptr<PropagatingClientMiddlewareFactory> second_client_middleware_;
1154 std::shared_ptr<PropagatingClientMiddlewareFactory> client_middleware_;
1155 };
1156
1157 class TestErrorMiddleware : public ::testing::Test {
1158 public:
SetUp()1159 void SetUp() {
1160 ASSERT_OK(MakeServer<ErrorMiddlewareServer>(
1161 &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
1162 [](FlightClientOptions* options) { return Status::OK(); }));
1163 }
1164
TearDown()1165 void TearDown() { ASSERT_OK(server_->Shutdown()); }
1166
1167 protected:
1168 std::unique_ptr<FlightClient> client_;
1169 std::unique_ptr<FlightServerBase> server_;
1170 };
1171
1172 class TestBasicHeaderAuthMiddleware : public ::testing::Test {
1173 public:
SetUp()1174 void SetUp() {
1175 header_middleware_ = std::make_shared<HeaderAuthServerMiddlewareFactory>();
1176 bearer_middleware_ = std::make_shared<BearerAuthServerMiddlewareFactory>();
1177 std::pair<std::string, std::string> bearer = make_pair(
1178 kAuthHeader, std::string(kBearerPrefix) + " " + std::string(kBearerToken));
1179 ASSERT_OK(MakeServer<HeaderAuthTestServer>(
1180 &server_, &client_,
1181 [&](FlightServerOptions* options) {
1182 options->auth_handler =
1183 std::unique_ptr<ServerAuthHandler>(new NoOpAuthHandler());
1184 options->middleware.push_back({"header-auth-server", header_middleware_});
1185 options->middleware.push_back({"bearer-auth-server", bearer_middleware_});
1186 return Status::OK();
1187 },
1188 [&](FlightClientOptions* options) { return Status::OK(); }));
1189 }
1190
RunValidClientAuth()1191 void RunValidClientAuth() {
1192 arrow::Result<std::pair<std::string, std::string>> bearer_result =
1193 client_->AuthenticateBasicToken({}, kValidUsername, kValidPassword);
1194 ASSERT_OK(bearer_result.status());
1195 ASSERT_EQ(bearer_result.ValueOrDie().first, kAuthHeader);
1196 ASSERT_EQ(bearer_result.ValueOrDie().second,
1197 (std::string(kBearerPrefix) + kBearerToken));
1198 std::unique_ptr<FlightListing> listing;
1199 FlightCallOptions call_options;
1200 call_options.headers.push_back(bearer_result.ValueOrDie());
1201 ASSERT_OK(client_->ListFlights(call_options, {}, &listing));
1202 ASSERT_TRUE(bearer_middleware_->GetIsValid());
1203 }
1204
RunInvalidClientAuth()1205 void RunInvalidClientAuth() {
1206 arrow::Result<std::pair<std::string, std::string>> bearer_result =
1207 client_->AuthenticateBasicToken({}, kInvalidUsername, kInvalidPassword);
1208 ASSERT_RAISES(IOError, bearer_result.status());
1209 ASSERT_THAT(bearer_result.status().message(),
1210 ::testing::HasSubstr("Invalid credentials"));
1211 }
1212
TearDown()1213 void TearDown() { ASSERT_OK(server_->Shutdown()); }
1214
1215 protected:
1216 std::unique_ptr<FlightClient> client_;
1217 std::unique_ptr<FlightServerBase> server_;
1218 std::shared_ptr<HeaderAuthServerMiddlewareFactory> header_middleware_;
1219 std::shared_ptr<BearerAuthServerMiddlewareFactory> bearer_middleware_;
1220 };
1221
1222 // This test keeps an internal cookie cache and compares that with the middleware.
1223 class TestCookieMiddleware : public ::testing::Test {
1224 public:
1225 // Setup function creates middleware factory and starts it up.
SetUp()1226 void SetUp() {
1227 factory_ = GetCookieFactory();
1228 CallInfo callInfo;
1229 factory_->StartCall(callInfo, &middleware_);
1230 }
1231
1232 // Function to add incoming cookies to middleware and validate them.
AddAndValidate(const std::string & incoming_cookie)1233 void AddAndValidate(const std::string& incoming_cookie) {
1234 // Add cookie
1235 CallHeaders call_headers;
1236 call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"),
1237 arrow::util::string_view(incoming_cookie)));
1238 middleware_->ReceivedHeaders(call_headers);
1239 expected_cookie_cache_.UpdateCachedCookies(call_headers);
1240
1241 // Get cookie from middleware.
1242 TestCallHeaders add_call_headers;
1243 middleware_->SendingHeaders(&add_call_headers);
1244 const std::string actual_cookies = add_call_headers.GetCookies();
1245
1246 // Validate cookie
1247 const std::string expected_cookies = expected_cookie_cache_.GetValidCookiesAsString();
1248 const std::vector<std::string> split_expected_cookies =
1249 SplitCookies(expected_cookies);
1250 const std::vector<std::string> split_actual_cookies = SplitCookies(actual_cookies);
1251 EXPECT_EQ(split_expected_cookies, split_actual_cookies);
1252 }
1253
1254 // Function to take a list of cookies and split them into a vector of individual
1255 // cookies. This is done because the cookie cache is a map so ordering is not
1256 // necessarily consistent.
SplitCookies(const std::string & cookies)1257 static std::vector<std::string> SplitCookies(const std::string& cookies) {
1258 std::vector<std::string> split_cookies;
1259 std::string::size_type pos1 = 0;
1260 std::string::size_type pos2 = 0;
1261 while ((pos2 = cookies.find(';', pos1)) != std::string::npos) {
1262 split_cookies.push_back(
1263 arrow::internal::TrimString(cookies.substr(pos1, pos2 - pos1)));
1264 pos1 = pos2 + 1;
1265 }
1266 if (pos1 < cookies.size()) {
1267 split_cookies.push_back(arrow::internal::TrimString(cookies.substr(pos1)));
1268 }
1269 std::sort(split_cookies.begin(), split_cookies.end());
1270 return split_cookies;
1271 }
1272
1273 protected:
1274 // Class to allow testing of the call headers.
1275 class TestCallHeaders : public AddCallHeaders {
1276 public:
TestCallHeaders()1277 TestCallHeaders() {}
~TestCallHeaders()1278 ~TestCallHeaders() {}
1279
1280 // Function to add cookie header.
AddHeader(const std::string & key,const std::string & value)1281 void AddHeader(const std::string& key, const std::string& value) {
1282 ASSERT_EQ(key, "cookie");
1283 outbound_cookie_ = value;
1284 }
1285
1286 // Function to get outgoing cookie.
GetCookies()1287 std::string GetCookies() { return outbound_cookie_; }
1288
1289 private:
1290 std::string outbound_cookie_;
1291 };
1292
1293 internal::CookieCache expected_cookie_cache_;
1294 std::unique_ptr<ClientMiddleware> middleware_;
1295 std::shared_ptr<ClientMiddlewareFactory> factory_;
1296 };
1297
1298 // This test is used to test the parsing capabilities of the cookie framework.
1299 class TestCookieParsing : public ::testing::Test {
1300 public:
VerifyParseCookie(const std::string & cookie_str,bool expired)1301 void VerifyParseCookie(const std::string& cookie_str, bool expired) {
1302 internal::Cookie cookie = internal::Cookie::parse(cookie_str);
1303 EXPECT_EQ(expired, cookie.IsExpired());
1304 }
1305
VerifyCookieName(const std::string & cookie_str,const std::string & name)1306 void VerifyCookieName(const std::string& cookie_str, const std::string& name) {
1307 internal::Cookie cookie = internal::Cookie::parse(cookie_str);
1308 EXPECT_EQ(name, cookie.GetName());
1309 }
1310
VerifyCookieString(const std::string & cookie_str,const std::string & cookie_as_string)1311 void VerifyCookieString(const std::string& cookie_str,
1312 const std::string& cookie_as_string) {
1313 internal::Cookie cookie = internal::Cookie::parse(cookie_str);
1314 EXPECT_EQ(cookie_as_string, cookie.AsCookieString());
1315 }
1316
VerifyCookieDateConverson(std::string date,const std::string & converted_date)1317 void VerifyCookieDateConverson(std::string date, const std::string& converted_date) {
1318 internal::Cookie::ConvertCookieDate(&date);
1319 EXPECT_EQ(converted_date, date);
1320 }
1321
VerifyCookieAttributeParsing(const std::string cookie_str,std::string::size_type start_pos,const util::optional<std::pair<std::string,std::string>> cookie_attribute,const std::string::size_type start_pos_after)1322 void VerifyCookieAttributeParsing(
1323 const std::string cookie_str, std::string::size_type start_pos,
1324 const util::optional<std::pair<std::string, std::string>> cookie_attribute,
1325 const std::string::size_type start_pos_after) {
1326 util::optional<std::pair<std::string, std::string>> attr =
1327 internal::Cookie::ParseCookieAttribute(cookie_str, &start_pos);
1328
1329 if (cookie_attribute == util::nullopt) {
1330 EXPECT_EQ(cookie_attribute, attr);
1331 } else {
1332 EXPECT_EQ(cookie_attribute.value(), attr.value());
1333 }
1334 EXPECT_EQ(start_pos_after, start_pos);
1335 }
1336
AddCookieVerifyCache(const std::vector<std::string> & cookies,const std::string & expected_cookies)1337 void AddCookieVerifyCache(const std::vector<std::string>& cookies,
1338 const std::string& expected_cookies) {
1339 internal::CookieCache cookie_cache;
1340 for (auto& cookie : cookies) {
1341 // Add cookie
1342 CallHeaders call_headers;
1343 call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"),
1344 arrow::util::string_view(cookie)));
1345 cookie_cache.UpdateCachedCookies(call_headers);
1346 }
1347 const std::string actual_cookies = cookie_cache.GetValidCookiesAsString();
1348 const std::vector<std::string> actual_split_cookies =
1349 TestCookieMiddleware::SplitCookies(actual_cookies);
1350 const std::vector<std::string> expected_split_cookies =
1351 TestCookieMiddleware::SplitCookies(expected_cookies);
1352 }
1353 };
1354
TEST_F(TestErrorMiddleware,TestMetadata)1355 TEST_F(TestErrorMiddleware, TestMetadata) {
1356 Action action;
1357 std::unique_ptr<ResultStream> stream;
1358
1359 // Run action1
1360 action.type = "action1";
1361
1362 action.body = Buffer::FromString("action1-content");
1363 Status s = client_->DoAction(action, &stream);
1364 ASSERT_FALSE(s.ok());
1365 std::shared_ptr<FlightStatusDetail> flightStatusDetail =
1366 FlightStatusDetail::UnwrapStatus(s);
1367 ASSERT_TRUE(flightStatusDetail);
1368 ASSERT_EQ(flightStatusDetail->extra_info(), "error_message");
1369 }
1370
TEST_F(TestFlightClient,ListFlights)1371 TEST_F(TestFlightClient, ListFlights) {
1372 std::unique_ptr<FlightListing> listing;
1373 ASSERT_OK(client_->ListFlights(&listing));
1374 ASSERT_TRUE(listing != nullptr);
1375
1376 std::vector<FlightInfo> flights = ExampleFlightInfo();
1377
1378 std::unique_ptr<FlightInfo> info;
1379 for (const FlightInfo& flight : flights) {
1380 ASSERT_OK(listing->Next(&info));
1381 AssertEqual(flight, *info);
1382 }
1383 ASSERT_OK(listing->Next(&info));
1384 ASSERT_TRUE(info == nullptr);
1385
1386 ASSERT_OK(listing->Next(&info));
1387 ASSERT_TRUE(info == nullptr);
1388 }
1389
TEST_F(TestFlightClient,ListFlightsWithCriteria)1390 TEST_F(TestFlightClient, ListFlightsWithCriteria) {
1391 std::unique_ptr<FlightListing> listing;
1392 ASSERT_OK(client_->ListFlights(FlightCallOptions(), {"foo"}, &listing));
1393 std::unique_ptr<FlightInfo> info;
1394 ASSERT_OK(listing->Next(&info));
1395 ASSERT_TRUE(info == nullptr);
1396 }
1397
TEST_F(TestFlightClient,GetFlightInfo)1398 TEST_F(TestFlightClient, GetFlightInfo) {
1399 auto descr = FlightDescriptor::Path({"examples", "ints"});
1400 std::unique_ptr<FlightInfo> info;
1401
1402 ASSERT_OK(client_->GetFlightInfo(descr, &info));
1403 ASSERT_NE(info, nullptr);
1404
1405 std::vector<FlightInfo> flights = ExampleFlightInfo();
1406 AssertEqual(flights[0], *info);
1407 }
1408
TEST_F(TestFlightClient,GetSchema)1409 TEST_F(TestFlightClient, GetSchema) {
1410 auto descr = FlightDescriptor::Path({"examples", "ints"});
1411 std::unique_ptr<SchemaResult> schema_result;
1412 std::shared_ptr<Schema> schema;
1413 ipc::DictionaryMemo dict_memo;
1414
1415 ASSERT_OK(client_->GetSchema(descr, &schema_result));
1416 ASSERT_NE(schema_result, nullptr);
1417 ASSERT_OK(schema_result->GetSchema(&dict_memo, &schema));
1418 }
1419
TEST_F(TestFlightClient,GetFlightInfoNotFound)1420 TEST_F(TestFlightClient, GetFlightInfoNotFound) {
1421 auto descr = FlightDescriptor::Path({"examples", "things"});
1422 std::unique_ptr<FlightInfo> info;
1423 // XXX Ideally should be Invalid (or KeyError), but gRPC doesn't support
1424 // multiple error codes.
1425 auto st = client_->GetFlightInfo(descr, &info);
1426 ASSERT_RAISES(Invalid, st);
1427 ASSERT_NE(st.message().find("Flight not found"), std::string::npos);
1428 }
1429
TEST_F(TestFlightClient,DoGetInts)1430 TEST_F(TestFlightClient, DoGetInts) {
1431 auto descr = FlightDescriptor::Path({"examples", "ints"});
1432 BatchVector expected_batches;
1433 ASSERT_OK(ExampleIntBatches(&expected_batches));
1434
1435 auto check_endpoints = [](const std::vector<FlightEndpoint>& endpoints) {
1436 // Two endpoints in the example FlightInfo
1437 ASSERT_EQ(2, endpoints.size());
1438 AssertEqual(Ticket{"ticket-ints-1"}, endpoints[0].ticket);
1439 };
1440
1441 CheckDoGet(descr, expected_batches, check_endpoints);
1442 }
1443
TEST_F(TestFlightClient,DoGetFloats)1444 TEST_F(TestFlightClient, DoGetFloats) {
1445 auto descr = FlightDescriptor::Path({"examples", "floats"});
1446 BatchVector expected_batches;
1447 ASSERT_OK(ExampleFloatBatches(&expected_batches));
1448
1449 auto check_endpoints = [](const std::vector<FlightEndpoint>& endpoints) {
1450 // One endpoint in the example FlightInfo
1451 ASSERT_EQ(1, endpoints.size());
1452 AssertEqual(Ticket{"ticket-floats-1"}, endpoints[0].ticket);
1453 };
1454
1455 CheckDoGet(descr, expected_batches, check_endpoints);
1456 }
1457
TEST_F(TestFlightClient,DoGetDicts)1458 TEST_F(TestFlightClient, DoGetDicts) {
1459 auto descr = FlightDescriptor::Path({"examples", "dicts"});
1460 BatchVector expected_batches;
1461 ASSERT_OK(ExampleDictBatches(&expected_batches));
1462
1463 auto check_endpoints = [](const std::vector<FlightEndpoint>& endpoints) {
1464 // One endpoint in the example FlightInfo
1465 ASSERT_EQ(1, endpoints.size());
1466 AssertEqual(Ticket{"ticket-dicts-1"}, endpoints[0].ticket);
1467 };
1468
1469 CheckDoGet(descr, expected_batches, check_endpoints);
1470 }
1471
1472 // Ensure the gRPC client is configured to allow large messages
1473 // Tests a 32 MiB batch
TEST_F(TestFlightClient,DoGetLargeBatch)1474 TEST_F(TestFlightClient, DoGetLargeBatch) {
1475 BatchVector expected_batches;
1476 ASSERT_OK(ExampleLargeBatches(&expected_batches));
1477 Ticket ticket{"ticket-large-batch-1"};
1478 CheckDoGet(ticket, expected_batches);
1479 }
1480
TEST_F(TestFlightClient,FlightDataOverflowServerBatch)1481 TEST_F(TestFlightClient, FlightDataOverflowServerBatch) {
1482 // Regression test for ARROW-13253
1483 // N.B. this is rather a slow and memory-hungry test
1484 {
1485 // DoGet: check for overflow on large batch
1486 Ticket ticket{"ARROW-13253-DoGet-Batch"};
1487 std::unique_ptr<FlightStreamReader> stream;
1488 ASSERT_OK(client_->DoGet(ticket, &stream));
1489 FlightStreamChunk chunk;
1490 EXPECT_RAISES_WITH_MESSAGE_THAT(
1491 Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
1492 stream->Next(&chunk));
1493 }
1494 {
1495 // DoExchange: check for overflow on large batch from server
1496 auto descr = FlightDescriptor::Command("large_batch");
1497 std::unique_ptr<FlightStreamReader> reader;
1498 std::unique_ptr<FlightStreamWriter> writer;
1499 ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
1500 BatchVector batches;
1501 EXPECT_RAISES_WITH_MESSAGE_THAT(
1502 Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
1503 reader->ReadAll(&batches));
1504 }
1505 }
1506
TEST_F(TestFlightClient,FlightDataOverflowClientBatch)1507 TEST_F(TestFlightClient, FlightDataOverflowClientBatch) {
1508 ASSERT_OK_AND_ASSIGN(auto batch, VeryLargeBatch());
1509 {
1510 // DoPut: check for overflow on large batch
1511 std::unique_ptr<FlightStreamWriter> stream;
1512 std::unique_ptr<FlightMetadataReader> reader;
1513 auto descr = FlightDescriptor::Path({""});
1514 ASSERT_OK(client_->DoPut(descr, batch->schema(), &stream, &reader));
1515 EXPECT_RAISES_WITH_MESSAGE_THAT(
1516 Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
1517 stream->WriteRecordBatch(*batch));
1518 ASSERT_OK(stream->Close());
1519 }
1520 {
1521 // DoExchange: check for overflow on large batch from client
1522 auto descr = FlightDescriptor::Command("counter");
1523 std::unique_ptr<FlightStreamReader> reader;
1524 std::unique_ptr<FlightStreamWriter> writer;
1525 ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
1526 ASSERT_OK(writer->Begin(batch->schema()));
1527 EXPECT_RAISES_WITH_MESSAGE_THAT(
1528 Invalid, ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet"),
1529 writer->WriteRecordBatch(*batch));
1530 ASSERT_OK(writer->Close());
1531 }
1532 }
1533
TEST_F(TestFlightClient, DoExchange) {
  // Round trip through the "counter" exchange. The server first answers with a
  // metadata-only message ("1" here, presumably the number of batches it
  // received), then streams back the batches that were written.
  auto descr = FlightDescriptor::Command("counter");
  BatchVector batches;
  auto a1 = ArrayFromJSON(int32(), "[4, 5, 6, null]");
  auto schema = arrow::schema({field("f1", a1->type())});
  batches.push_back(RecordBatch::Make(schema, a1->length(), {a1}));
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
  ASSERT_OK(writer->Begin(schema));
  for (const auto& batch : batches) {
    ASSERT_OK(writer->WriteRecordBatch(*batch));
  }
  ASSERT_OK(writer->DoneWriting());

  FlightStreamChunk chunk;
  ASSERT_OK(reader->Next(&chunk));
  ASSERT_NE(nullptr, chunk.app_metadata);
  ASSERT_EQ(nullptr, chunk.data);
  ASSERT_EQ("1", chunk.app_metadata->ToString());
  ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema());
  AssertSchemaEqual(*schema, *server_schema);
  for (const auto& batch : batches) {
    ASSERT_OK(reader->Next(&chunk));
    // Guard before dereferencing: a premature end of stream yields a null
    // chunk.data, which would otherwise crash the test binary instead of
    // failing this test.
    ASSERT_NE(nullptr, chunk.data);
    ASSERT_BATCHES_EQUAL(*batch, *chunk.data);
  }
  ASSERT_OK(writer->Close());
}
1561
1562 // Test pure-metadata DoExchange to ensure nothing blocks waiting for
1563 // schema messages
TEST_F(TestFlightClient, DoExchangeNoData) {
  // Finish writing without ever calling Begin(). The server must still answer
  // with a metadata-only message ("0") rather than waiting for a schema.
  std::unique_ptr<FlightStreamReader> exchange_reader;
  std::unique_ptr<FlightStreamWriter> exchange_writer;
  ASSERT_OK(client_->DoExchange(FlightDescriptor::Command("counter"), &exchange_writer,
                                &exchange_reader));
  ASSERT_OK(exchange_writer->DoneWriting());

  FlightStreamChunk reply;
  ASSERT_OK(exchange_reader->Next(&reply));
  ASSERT_EQ(nullptr, reply.data);
  ASSERT_NE(nullptr, reply.app_metadata);
  ASSERT_EQ("0", reply.app_metadata->ToString());

  ASSERT_OK(exchange_writer->Close());
}
1577
1578 // Test sending a schema without any data, as this hits an edge case
1579 // in the client-side writer.
TEST_F(TestFlightClient, DoExchangeWriteOnlySchema) {
  // Begin() with a schema, send only a metadata message, then finish. The
  // "counter" server replies with "0" and no data.
  const auto descr = FlightDescriptor::Command("counter");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(descr, &writer, &reader));

  const auto int_schema = arrow::schema({field("f1", arrow::int32())});
  ASSERT_OK(writer->Begin(int_schema));
  ASSERT_OK(writer->WriteMetadata(Buffer::FromString("foo")));
  ASSERT_OK(writer->DoneWriting());

  FlightStreamChunk reply;
  ASSERT_OK(reader->Next(&reply));
  ASSERT_EQ(nullptr, reply.data);
  ASSERT_NE(nullptr, reply.app_metadata);
  ASSERT_EQ("0", reply.app_metadata->ToString());

  ASSERT_OK(writer->Close());
}
1596
1597 // Emulate DoGet
TEST_F(TestFlightClient, DoExchangeGet) {
  // The "get" exchange behaves like DoGet: the server streams the example
  // integer batches and then ends the stream.
  auto descr = FlightDescriptor::Command("get");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
  ASSERT_OK_AND_ASSIGN(auto received_schema, reader->GetSchema());
  AssertSchemaEqual(*ExampleIntSchema(), *received_schema);

  BatchVector expected;
  ASSERT_OK(ExampleIntBatches(&expected));
  FlightStreamChunk chunk;
  for (size_t i = 0; i < expected.size(); ++i) {
    ASSERT_OK(reader->Next(&chunk));
    ASSERT_NE(nullptr, chunk.data);
    AssertBatchesEqual(*expected[i], *chunk.data);
  }

  // End of stream: a chunk carrying neither data nor metadata.
  ASSERT_OK(reader->Next(&chunk));
  ASSERT_EQ(nullptr, chunk.data);
  ASSERT_EQ(nullptr, chunk.app_metadata);
  ASSERT_OK(writer->Close());
}
1618
1619 // Emulate DoPut
TEST_F(TestFlightClient, DoExchangePut) {
  // The "put" exchange behaves like DoPut: the client uploads the example
  // integer batches, and the server acknowledges with a "done" metadata
  // message before ending the stream.
  auto descr = FlightDescriptor::Command("put");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
  ASSERT_OK(writer->Begin(ExampleIntSchema()));

  BatchVector to_upload;
  ASSERT_OK(ExampleIntBatches(&to_upload));
  for (size_t i = 0; i < to_upload.size(); ++i) {
    ASSERT_OK(writer->WriteRecordBatch(*to_upload[i]));
  }
  ASSERT_OK(writer->DoneWriting());

  FlightStreamChunk ack;
  ASSERT_OK(reader->Next(&ack));
  ASSERT_NE(nullptr, ack.app_metadata);
  AssertBufferEqual(*ack.app_metadata, "done");

  // Nothing follows the acknowledgement.
  ASSERT_OK(reader->Next(&ack));
  ASSERT_EQ(nullptr, ack.data);
  ASSERT_EQ(nullptr, ack.app_metadata);
  ASSERT_OK(writer->Close());
}
1641
1642 // Test the echo server
TEST_F(TestFlightClient, DoExchangeEcho) {
  // The "echo" exchange sends every message straight back, so each write below
  // is immediately followed by a read of the echoed message. Three message
  // shapes are covered: data only, metadata only, and data plus metadata.
  auto descr = FlightDescriptor::Command("echo");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
  ASSERT_OK(writer->Begin(ExampleIntSchema()));
  BatchVector batches;
  FlightStreamChunk chunk;
  ASSERT_OK(ExampleIntBatches(&batches));
  // Data-only messages come back with data and no metadata.
  for (const auto& batch : batches) {
    ASSERT_OK(writer->WriteRecordBatch(*batch));
    ASSERT_OK(reader->Next(&chunk));
    ASSERT_NE(nullptr, chunk.data);
    ASSERT_EQ(nullptr, chunk.app_metadata);
    AssertBatchesEqual(*batch, *chunk.data);
  }
  // Metadata-only messages ("0".."9") come back with metadata and no data.
  for (int i = 0; i < 10; i++) {
    const auto buf = Buffer::FromString(std::to_string(i));
    ASSERT_OK(writer->WriteMetadata(buf));
    ASSERT_OK(reader->Next(&chunk));
    ASSERT_EQ(nullptr, chunk.data);
    ASSERT_NE(nullptr, chunk.app_metadata);
    AssertBufferEqual(*buf, *chunk.app_metadata);
  }
  // Data plus metadata (the batch index) round-trips both parts together.
  int index = 0;
  for (const auto& batch : batches) {
    const auto buf = Buffer::FromString(std::to_string(index));
    ASSERT_OK(writer->WriteWithMetadata(*batch, buf));
    ASSERT_OK(reader->Next(&chunk));
    ASSERT_NE(nullptr, chunk.data);
    ASSERT_NE(nullptr, chunk.app_metadata);
    AssertBatchesEqual(*batch, *chunk.data);
    AssertBufferEqual(*buf, *chunk.app_metadata);
    index++;
  }
  // Once the client is done writing, the echoed stream ends: the next chunk
  // carries neither data nor metadata.
  ASSERT_OK(writer->DoneWriting());
  ASSERT_OK(reader->Next(&chunk));
  ASSERT_EQ(nullptr, chunk.data);
  ASSERT_EQ(nullptr, chunk.app_metadata);
  ASSERT_OK(writer->Close());
}
1684
1685 // Test interleaved reading/writing
TEST_F(TestFlightClient, DoExchangeTotal) {
  // The "total" exchange rejects non-INT64 columns. For INT64 input, after each
  // batch it sends back a one-row batch of running per-column sums. The
  // expected values below show that nulls are skipped in the sums.
  auto descr = FlightDescriptor::Command("total");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  {
    // An INT32 column must be rejected with "Field is not INT64: f1".
    auto a1 = ArrayFromJSON(arrow::int32(), "[4, 5, 6, null]");
    auto schema = arrow::schema({field("f1", a1->type())});
    // XXX: as noted in flight/client.cc, Begin() is lazy and the
    // schema message won't be written until some data is also
    // written. There are also timing issues; hence we check each status
    // here.
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        Invalid, ::testing::HasSubstr("Field is not INT64: f1"), ([&]() {
          RETURN_NOT_OK(client_->DoExchange(descr, &writer, &reader));
          RETURN_NOT_OK(writer->Begin(schema));
          auto batch = RecordBatch::Make(schema, /* num_rows */ 4, {a1});
          RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
          return writer->Close();
        })());
  }
  {
    // Two INT64 columns; each write is followed by a read of the running totals.
    auto a1 = ArrayFromJSON(arrow::int64(), "[1, 2, null, 3]");
    auto a2 = ArrayFromJSON(arrow::int64(), "[null, 4, 5, 6]");
    auto schema = arrow::schema({field("f1", a1->type()), field("f2", a2->type())});
    ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
    ASSERT_OK(writer->Begin(schema));
    auto batch = RecordBatch::Make(schema, /* num_rows */ 4, {a1, a2});
    FlightStreamChunk chunk;
    ASSERT_OK(writer->WriteRecordBatch(*batch));
    // The server schema is only checked after data is written, because
    // Begin() is lazy (see the XXX note above).
    ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema());
    AssertSchemaEqual(*schema, *server_schema);

    // After one batch: 1+2+3 = 6 and 4+5+6 = 15.
    ASSERT_OK(reader->Next(&chunk));
    ASSERT_NE(nullptr, chunk.data);
    auto expected1 = RecordBatch::Make(
        schema, /* num_rows */ 1,
        {ArrayFromJSON(arrow::int64(), "[6]"), ArrayFromJSON(arrow::int64(), "[15]")});
    AssertBatchesEqual(*expected1, *chunk.data);

    // After writing the same batch again, the totals double: 12 and 30.
    ASSERT_OK(writer->WriteRecordBatch(*batch));
    ASSERT_OK(reader->Next(&chunk));
    ASSERT_NE(nullptr, chunk.data);
    auto expected2 = RecordBatch::Make(
        schema, /* num_rows */ 1,
        {ArrayFromJSON(arrow::int64(), "[12]"), ArrayFromJSON(arrow::int64(), "[30]")});
    AssertBatchesEqual(*expected2, *chunk.data);

    ASSERT_OK(writer->Close());
  }
}
1736
1737 // Ensure server errors get propagated no matter what we try
TEST_F(TestFlightClient, DoExchangeError) {
  // The "error" exchange fails on the server. That failure must be visible
  // through whichever client entry point is used first.
  auto descr = FlightDescriptor::Command("error");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  {
    ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
    // The first Close() must report the server error. Previously this status
    // was captured but never checked.
    auto status = writer->Close();
    EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr("Expected error"),
                                    status);
    // A repeated Close() must keep reporting the error rather than succeed.
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        NotImplemented, ::testing::HasSubstr("Expected error"), writer->Close());
  }
  {
    ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
    FlightStreamChunk chunk;
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        NotImplemented, ::testing::HasSubstr("Expected error"), reader->Next(&chunk));
  }
  {
    ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        NotImplemented, ::testing::HasSubstr("Expected error"), reader->GetSchema());
  }
  // writer->Begin isn't tested here because, as noted in client.cc,
  // OpenRecordBatchWriter lazily writes the initial message - hence
  // Begin() won't fail. Additionally, it appears gRPC may buffer
  // writes - a write won't immediately fail even when the server
  // immediately fails.
}
1765
TEST_F(TestFlightClient, ListActions) {
  // The server must advertise exactly the example action types.
  std::vector<ActionType> advertised;
  ASSERT_OK(client_->ListActions(&advertised));

  auto expected_actions = ExampleActionTypes();
  AssertEqual(expected_actions, advertised);
}
1773
TEST_F(TestFlightClient, DoAction) {
  Action action;
  std::unique_ptr<ResultStream> stream;
  std::unique_ptr<Result> result;

  // Run action1: it yields three results, "<body>-part0" .. "<body>-part2".
  action.type = "action1";

  const std::string action1_value = "action1-content";
  action.body = Buffer::FromString(action1_value);
  ASSERT_OK(client_->DoAction(action, &stream));

  for (int i = 0; i < 3; ++i) {
    ASSERT_OK(stream->Next(&result));
    // End of stream is reported as a null result. Fail cleanly instead of
    // dereferencing null if the server sends fewer results than expected.
    ASSERT_NE(nullptr, result);
    std::string expected = action1_value + "-part" + std::to_string(i);
    ASSERT_EQ(expected, result->body->ToString());
  }

  // stream consumed
  ASSERT_OK(stream->Next(&result));
  ASSERT_EQ(nullptr, result);

  // Run action2, no results
  action.type = "action2";
  ASSERT_OK(client_->DoAction(action, &stream));

  ASSERT_OK(stream->Next(&result));
  ASSERT_EQ(nullptr, result);
}
1803
TEST_F(TestFlightClient, RoundTripStatus) {
  // Asking for info on the "status-outofmemory" command must produce an
  // OutOfMemory status on the client; the status code survives the round trip.
  std::unique_ptr<FlightInfo> info;
  ASSERT_RAISES(OutOfMemory,
                client_->GetFlightInfo(FlightDescriptor::Command("status-outofmemory"),
                                       &info));
}
1810
TEST_F(TestFlightClient, Issue5095) {
  // Make sure the server-side error message is reflected to the client, both
  // for the "fail" ticket (UnknownError) and the "success" ticket (KeyError).
  std::unique_ptr<FlightStreamReader> stream;

  const Ticket failing_ticket{"ARROW-5095-fail"};
  Status st = client_->DoGet(failing_ticket, &stream);
  ASSERT_RAISES(UnknownError, st);
  ASSERT_THAT(st.message(), ::testing::HasSubstr("Server-side error"));

  const Ticket success_ticket{"ARROW-5095-success"};
  st = client_->DoGet(success_ticket, &stream);
  ASSERT_RAISES(KeyError, st);
  ASSERT_THAT(st.message(), ::testing::HasSubstr("No data"));
}
1825
1826 // Test setting generic transport options by configuring gRPC to fail
1827 // all calls.
TEST_F(TestFlightClient, GenericOptions) {
  std::unique_ptr<FlightClient> client;
  auto options = FlightClientOptions::Defaults();
  // Set a 4-byte receive limit at the gRPC layer. Responses presumably all
  // exceed it, so every call fails.
  options.generic_options.emplace_back(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, 4);
  Location location;
  ASSERT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location));
  ASSERT_OK(FlightClient::Connect(location, options, &client));
  auto descr = FlightDescriptor::Path({"examples", "ints"});
  // Only the SchemaResult out-param is needed; the former unused `schema` and
  // `dict_memo` locals were removed.
  std::unique_ptr<SchemaResult> schema_result;
  auto status = client->GetSchema(descr, &schema_result);
  ASSERT_RAISES(Invalid, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("resource exhausted"));
}
1844
TEST_F(TestFlightClient, TimeoutFires) {
  // Server does not exist on this port, so call should fail
  std::unique_ptr<FlightClient> client;
  Location location;
  ASSERT_OK(Location::ForGrpcTcp("localhost", 30001, &location));
  ASSERT_OK(FlightClient::Connect(location, &client));
  FlightCallOptions options;
  options.timeout = TimeoutDuration{0.2};
  std::unique_ptr<FlightInfo> info;
  // Measure with a monotonic clock. system_clock is wall-clock time and can
  // jump when the system time is adjusted, which would skew the check.
  auto start = std::chrono::steady_clock::now();
  Status status = client->GetFlightInfo(options, FlightDescriptor{}, &info);
  auto end = std::chrono::steady_clock::now();
#ifdef ARROW_WITH_TIMING_TESTS
  // The 0.2 s deadline should have fired well within 400 ms.
  EXPECT_LE(end - start, std::chrono::milliseconds{400});
#else
  ARROW_UNUSED(end - start);
#endif
  ASSERT_RAISES(IOError, status);
}
1864
TEST_F(TestFlightClient, NoTimeout) {
  // Call should complete quickly, so timeout should not fire
  FlightCallOptions options;
  options.timeout = TimeoutDuration{5.0};  // account for slow server process startup
  std::unique_ptr<FlightInfo> info;
  // Measure with a monotonic clock. system_clock is wall-clock time and can
  // jump when the system time is adjusted, which would skew the check.
  auto start = std::chrono::steady_clock::now();
  auto descriptor = FlightDescriptor::Path({"examples", "ints"});
  Status status = client_->GetFlightInfo(options, descriptor, &info);
  auto end = std::chrono::steady_clock::now();
#ifdef ARROW_WITH_TIMING_TESTS
  EXPECT_LE(end - start, std::chrono::milliseconds{600});
#else
  ARROW_UNUSED(end - start);
#endif
  ASSERT_OK(status);
  ASSERT_NE(nullptr, info);
}
1882
TEST_F(TestDoPut, DoPutInts) {
  // One column per integer width and signedness (f0..f7). Each holds zero, a
  // mid-range value, both extremes of the type and a null.
  auto descr = FlightDescriptor::Path({"ints"});
  const std::vector<std::shared_ptr<Array>> columns = {
      ArrayFromJSON(int8(), "[0, 1, 127, -128, null]"),
      ArrayFromJSON(uint8(), "[0, 1, 127, 255, null]"),
      ArrayFromJSON(int16(), "[0, 258, 32767, -32768, null]"),
      ArrayFromJSON(uint16(), "[0, 258, 32767, 65535, null]"),
      ArrayFromJSON(int32(), "[0, 65538, 2147483647, -2147483648, null]"),
      ArrayFromJSON(uint32(), "[0, 65538, 2147483647, 4294967295, null]"),
      ArrayFromJSON(int64(),
                    "[0, 4294967298, 9223372036854775807, -9223372036854775808, null]"),
      ArrayFromJSON(uint64(),
                    "[0, 4294967298, 9223372036854775807, 18446744073709551615, null]"),
  };

  std::vector<std::shared_ptr<Field>> fields;
  for (size_t i = 0; i < columns.size(); ++i) {
    fields.push_back(field("f" + std::to_string(i), columns[i]->type()));
  }
  auto schema = arrow::schema(fields);

  BatchVector batches;
  batches.push_back(RecordBatch::Make(schema, columns[0]->length(), columns));
  CheckDoPut(descr, schema, batches);
}
1905
TEST_F(TestDoPut, DoPutFloats) {
  // Single- and double-precision columns holding the same values, including a null.
  auto descr = FlightDescriptor::Path({"floats"});
  const char* const values_json = "[0, 1.2, -3.4, 5.6, null]";
  auto f32_column = ArrayFromJSON(float32(), values_json);
  auto f64_column = ArrayFromJSON(float64(), values_json);
  auto schema =
      arrow::schema({field("f0", f32_column->type()), field("f1", f64_column->type())});

  BatchVector batches;
  batches.push_back(
      RecordBatch::Make(schema, f32_column->length(), {f32_column, f64_column}));
  CheckDoPut(descr, schema, batches);
}
1916
TEST_F(TestDoPut, DoPutEmptyBatch) {
  // Sending and receiving a 0-sized batch shouldn't fail
  auto descr = FlightDescriptor::Path({"ints"});
  auto empty_column = ArrayFromJSON(int32(), "[]");
  auto schema = arrow::schema({field("f1", empty_column->type())});

  BatchVector batches;
  batches.push_back(RecordBatch::Make(schema, empty_column->length(), {empty_column}));
  CheckDoPut(descr, schema, batches);
}
1927
TEST_F(TestDoPut, DoPutDicts) {
  // Several batches that all reference the same three-entry utf8 dictionary
  // through int8 indices, some of which are null.
  auto descr = FlightDescriptor::Path({"dicts"});
  auto dict_values = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"quux\"]");
  auto dict_type = dictionary(int8(), dict_values->type());
  auto schema = arrow::schema({field("f1", dict_type)});

  BatchVector batches;
  const std::vector<std::string> index_sets = {"[1, 0, 1]", "[null]", "[null, 1]"};
  for (const auto& indices_json : index_sets) {
    auto indices = ArrayFromJSON(int8(), indices_json);
    auto column = std::make_shared<DictionaryArray>(dict_type, indices, dict_values);
    batches.push_back(RecordBatch::Make(schema, column->length(), {column}));
  }

  CheckDoPut(descr, schema, batches);
}
1943
1944 // Ensure the gRPC server is configured to allow large messages
1945 // Tests a 32 MiB batch
TEST_F(TestDoPut, DoPutLargeBatch) {
  // Upload the example "large" batches; this only succeeds if the server
  // accepts large messages.
  BatchVector large_batches;
  ASSERT_OK(ExampleLargeBatches(&large_batches));
  auto large_schema = ExampleLargeSchema();
  auto descr = FlightDescriptor::Path({"large-batches"});
  CheckDoPut(descr, large_schema, large_batches);
}
1953
TEST_F(TestDoPut, DoPutSizeLimit) {
  // The client's write_size_limit_bytes is a soft limit. An oversized write
  // fails with a FlightWriteSizeStatusDetail, but the stream stays usable for
  // smaller writes afterwards.
  const int64_t size_limit = 4096;
  Location location;
  ASSERT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location));
  auto client_options = FlightClientOptions::Defaults();
  client_options.write_size_limit_bytes = size_limit;
  std::unique_ptr<FlightClient> client;
  ASSERT_OK(FlightClient::Connect(location, client_options, &client));

  auto descr = FlightDescriptor::Path({"ints"});
  // Batch is too large to fit in one message: 768 int64 rows are 6144 bytes
  // of values. Each 384-row half (3072 bytes) is written successfully below.
  auto schema = arrow::schema({field("f1", arrow::int64())});
  auto batch = arrow::ConstantArrayGenerator::Zeroes(768, schema);
  BatchVector batches;
  batches.push_back(batch->Slice(0, 384));
  batches.push_back(batch->Slice(384));

  std::unique_ptr<FlightStreamWriter> stream;
  std::unique_ptr<FlightMetadataReader> reader;
  ASSERT_OK(client->DoPut(descr, schema, &stream, &reader));

  // Large batch will exceed the limit
  const auto status = stream->WriteRecordBatch(*batch);
  EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("exceeded soft limit"),
                                  status);
  auto detail = FlightWriteSizeStatusDetail::UnwrapStatus(status);
  ASSERT_NE(nullptr, detail);
  ASSERT_EQ(size_limit, detail->limit());
  ASSERT_GT(detail->actual(), size_limit);

  // But we can retry with the two smaller slices. The loop variable is named
  // `slice` so it no longer shadows the full `batch` above.
  for (const auto& slice : batches) {
    ASSERT_OK(stream->WriteRecordBatch(*slice));
  }

  ASSERT_OK(stream->DoneWriting());
  ASSERT_OK(stream->Close());
  CheckBatches(descr, batches);
}
1993
TEST_F(TestAuthHandler, PassAuthenticatedCalls) {
  // After authenticating, calls are no longer rejected with the "Invalid token"
  // IOError. DoAction and DoPut succeed; the other RPCs (and DoPut's Close())
  // report NotImplemented instead.
  ASSERT_OK(client_->Authenticate(
      {},
      std::unique_ptr<ClientAuthHandler>(new TestClientAuthHandler("user", "p4ssw0rd"))));

  std::unique_ptr<FlightListing> listing;
  ASSERT_RAISES(NotImplemented, client_->ListFlights(&listing));

  std::unique_ptr<ResultStream> results;
  Action empty_action;
  empty_action.type = "";
  empty_action.body = Buffer::FromString("");
  ASSERT_OK(client_->DoAction(empty_action, &results));

  std::vector<ActionType> action_types;
  ASSERT_RAISES(NotImplemented, client_->ListActions(&action_types));

  std::unique_ptr<FlightInfo> info;
  ASSERT_RAISES(NotImplemented, client_->GetFlightInfo(FlightDescriptor{}, &info));

  std::unique_ptr<FlightStreamReader> stream;
  ASSERT_RAISES(NotImplemented, client_->DoGet(Ticket{}, &stream));

  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  std::shared_ptr<Schema> empty_schema = arrow::schema({});
  ASSERT_OK(client_->DoPut(FlightDescriptor{}, empty_schema, &writer, &reader));
  ASSERT_RAISES(NotImplemented, writer->Close());
}
2031
TEST_F(TestAuthHandler, FailUnauthenticatedCalls) {
  // Without Authenticate(), every RPC must fail with an IOError. Where the
  // message is deterministic, it mentions "Invalid token".
  Status status;
  std::unique_ptr<FlightListing> listing;
  status = client_->ListFlights(&listing);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<ResultStream> results;
  Action action;
  action.type = "";
  action.body = Buffer::FromString("");
  status = client_->DoAction(action, &results);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::vector<ActionType> actions;
  status = client_->ListActions(&actions);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<FlightInfo> info;
  status = client_->GetFlightInfo(FlightDescriptor{}, &info);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<FlightStreamReader> stream;
  status = client_->DoGet(Ticket{}, &stream);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  // Empty schema built through the factory (as in the authenticated test)
  // rather than a raw `new arrow::Schema(...)`.
  std::shared_ptr<Schema> schema = arrow::schema({});
  // Opening the DoPut stream succeeds; the rejection surfaces on Close().
  status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
  ASSERT_OK(status);
  status = writer->Close();
  ASSERT_RAISES(IOError, status);
  // ARROW-7583: don't check the error message here.
  // Because gRPC reports errors in some paths with booleans, instead
  // of statuses, we can fail the call without knowing why it fails,
  // instead reporting a generic error message. This is
  // nondeterministic, so don't assert any particular message here.
}
2076
TEST_F(TestAuthHandler, CheckPeerIdentity) {
  ASSERT_OK(client_->Authenticate(
      {},
      std::unique_ptr<ClientAuthHandler>(new TestClientAuthHandler("user", "p4ssw0rd"))));

  // The "who-am-i" action answers first with the authenticated peer identity,
  // then with the peer address.
  Action who_am_i;
  who_am_i.type = "who-am-i";
  who_am_i.body = Buffer::FromString("");
  std::unique_ptr<ResultStream> replies;
  ASSERT_OK(client_->DoAction(who_am_i, &replies));
  ASSERT_NE(replies, nullptr);

  std::unique_ptr<Result> reply;
  ASSERT_OK(replies->Next(&reply));
  ASSERT_NE(reply, nullptr);
  ASSERT_EQ(reply->body->ToString(), "user");

  ASSERT_OK(replies->Next(&reply));
  ASSERT_NE(reply, nullptr);
#ifndef _WIN32
  // On Windows gRPC sometimes returns a blank peer address, so don't
  // bother checking for it.
  ASSERT_NE(reply->body->ToString(), "");
#endif
}
2104
TEST_F(TestBasicAuthHandler, PassAuthenticatedCalls) {
  // With basic (username/password) authentication in place, calls are no
  // longer rejected with the "Invalid token" IOError. DoAction and DoPut
  // succeed; the other RPCs (and DoPut's Close()) report NotImplemented.
  ASSERT_OK(
      client_->Authenticate({}, std::unique_ptr<ClientAuthHandler>(
                                    new TestClientBasicAuthHandler("user", "p4ssw0rd"))));

  std::unique_ptr<FlightListing> listing;
  ASSERT_RAISES(NotImplemented, client_->ListFlights(&listing));

  std::unique_ptr<ResultStream> results;
  Action empty_action;
  empty_action.type = "";
  empty_action.body = Buffer::FromString("");
  ASSERT_OK(client_->DoAction(empty_action, &results));

  std::vector<ActionType> action_types;
  ASSERT_RAISES(NotImplemented, client_->ListActions(&action_types));

  std::unique_ptr<FlightInfo> info;
  ASSERT_RAISES(NotImplemented, client_->GetFlightInfo(FlightDescriptor{}, &info));

  std::unique_ptr<FlightStreamReader> stream;
  ASSERT_RAISES(NotImplemented, client_->DoGet(Ticket{}, &stream));

  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  std::shared_ptr<Schema> empty_schema = arrow::schema({});
  ASSERT_OK(client_->DoPut(FlightDescriptor{}, empty_schema, &writer, &reader));
  ASSERT_RAISES(NotImplemented, writer->Close());
}
2142
TEST_F(TestBasicAuthHandler, FailUnauthenticatedCalls) {
  // Without Authenticate(), every RPC must fail with an IOError whose message
  // mentions "Invalid token".
  Status status;
  std::unique_ptr<FlightListing> listing;
  status = client_->ListFlights(&listing);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<ResultStream> results;
  Action action;
  action.type = "";
  action.body = Buffer::FromString("");
  status = client_->DoAction(action, &results);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::vector<ActionType> actions;
  status = client_->ListActions(&actions);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<FlightInfo> info;
  status = client_->GetFlightInfo(FlightDescriptor{}, &info);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<FlightStreamReader> stream;
  status = client_->DoGet(Ticket{}, &stream);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  // Empty schema built through the factory (as in the authenticated test)
  // rather than a raw `new arrow::Schema(...)`.
  std::shared_ptr<Schema> schema = arrow::schema({});
  // Opening the DoPut stream succeeds; the rejection surfaces on Close().
  status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
  ASSERT_OK(status);
  status = writer->Close();
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
}
2183
TEST_F(TestBasicAuthHandler, CheckPeerIdentity) {
  ASSERT_OK(
      client_->Authenticate({}, std::unique_ptr<ClientAuthHandler>(
                                    new TestClientBasicAuthHandler("user", "p4ssw0rd"))));

  // The first result of "who-am-i" is the authenticated peer identity.
  Action who_am_i;
  who_am_i.type = "who-am-i";
  who_am_i.body = Buffer::FromString("");
  std::unique_ptr<ResultStream> replies;
  ASSERT_OK(client_->DoAction(who_am_i, &replies));
  ASSERT_NE(replies, nullptr);

  std::unique_ptr<Result> identity;
  ASSERT_OK(replies->Next(&identity));
  ASSERT_NE(identity, nullptr);
  ASSERT_EQ(identity->body->ToString(), "user");
}
2202
TEST_F(TestTls, DoAction) {
  // Basic round trip with the fixture's client: the "test" action answers
  // "Hello, world!".
  FlightCallOptions call_options;
  call_options.timeout = TimeoutDuration{5.0};

  Action hello;
  hello.type = "test";
  hello.body = Buffer::FromString("");

  std::unique_ptr<ResultStream> results;
  ASSERT_OK(client_->DoAction(call_options, hello, &results));
  ASSERT_NE(results, nullptr);

  std::unique_ptr<Result> first;
  ASSERT_OK(results->Next(&first));
  ASSERT_NE(first, nullptr);
  ASSERT_EQ(first->body->ToString(), "Hello, world!");
}
2218
2219 #if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
TEST_F(TestTls, DisableServerVerification) {
  std::unique_ptr<FlightClient> unverified_client;
  auto client_options = FlightClientOptions::Defaults();
  // For security reasons, if encryption is being used, the client should be
  // configured to verify the server by default.
  ASSERT_EQ(client_options.disable_server_verification, false);
  client_options.disable_server_verification = true;
  ASSERT_OK(FlightClient::Connect(location_, client_options, &unverified_client));

  // With verification disabled, the "test" action still goes through.
  FlightCallOptions call_options;
  call_options.timeout = TimeoutDuration{5.0};
  Action hello;
  hello.type = "test";
  hello.body = Buffer::FromString("");

  std::unique_ptr<ResultStream> results;
  ASSERT_OK(unverified_client->DoAction(call_options, hello, &results));
  ASSERT_NE(results, nullptr);

  std::unique_ptr<Result> first;
  ASSERT_OK(results->Next(&first));
  ASSERT_NE(first, nullptr);
  ASSERT_EQ(first->body->ToString(), "Hello, world!");
}
2243 #endif
2244
TEST_F(TestTls, OverrideHostname) {
  // Trust the example root certificate but claim a different host name
  // ("fakehostname") through the dedicated client option. The call is
  // expected to fail with IOError, presumably because hostname verification
  // no longer matches.
  std::unique_ptr<FlightClient> mismatched_client;
  auto client_options = FlightClientOptions::Defaults();
  client_options.override_hostname = "fakehostname";
  CertKeyPair root_cert;
  ASSERT_OK(ExampleTlsCertificateRoot(&root_cert));
  client_options.tls_root_certs = root_cert.pem_cert;
  ASSERT_OK(FlightClient::Connect(location_, client_options, &mismatched_client));

  FlightCallOptions call_options;
  call_options.timeout = TimeoutDuration{5.0};
  Action hello;
  hello.type = "test";
  hello.body = Buffer::FromString("");
  std::unique_ptr<ResultStream> results;
  ASSERT_RAISES(IOError, mismatched_client->DoAction(call_options, hello, &results));
}
2262
2263 // Test the facility for setting generic transport options.
TEST_F(TestTls, OverrideHostnameGeneric) {
  // Same scenario as OverrideHostname, but the host name override is passed as
  // a raw gRPC channel argument through generic_options.
  std::unique_ptr<FlightClient> mismatched_client;
  auto client_options = FlightClientOptions::Defaults();
  client_options.generic_options.emplace_back(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG,
                                              "fakehostname");
  CertKeyPair root_cert;
  ASSERT_OK(ExampleTlsCertificateRoot(&root_cert));
  client_options.tls_root_certs = root_cert.pem_cert;
  ASSERT_OK(FlightClient::Connect(location_, client_options, &mismatched_client));

  FlightCallOptions call_options;
  call_options.timeout = TimeoutDuration{5.0};
  Action hello;
  hello.type = "test";
  hello.body = Buffer::FromString("");
  std::unique_ptr<ResultStream> results;
  ASSERT_RAISES(IOError, mismatched_client->DoAction(call_options, hello, &results));
  // Could check error message for the gRPC error message but it isn't
  // necessarily stable
}
2284
TEST_F(TestMetadata, DoGet) {
  // Each batch of the example integer stream arrives with its position in the
  // stream ("0", "1", ...) as application metadata.
  Ticket ticket{""};
  std::unique_ptr<FlightStreamReader> stream;
  ASSERT_OK(client_->DoGet(ticket, &stream));

  BatchVector expected_batches;
  ASSERT_OK(ExampleIntBatches(&expected_batches));

  FlightStreamChunk chunk;
  int position = 0;
  for (const auto& expected : expected_batches) {
    ASSERT_OK(stream->Next(&chunk));
    ASSERT_NE(nullptr, chunk.data);
    ASSERT_NE(nullptr, chunk.app_metadata);
    ASSERT_BATCHES_EQUAL(*expected, *chunk.data);
    ASSERT_EQ(std::to_string(position), chunk.app_metadata->ToString());
    ++position;
  }

  // The stream is exhausted after the last expected batch.
  ASSERT_OK(stream->Next(&chunk));
  ASSERT_EQ(nullptr, chunk.data);
}
2305
2306 // Test dictionaries. This tests a corner case in the reader:
2307 // dictionary batches come in between the schema and the first record
2308 // batch, so the server must take care to read application metadata
2309 // from the record batch, and not one of the dictionary batches.
TEST_F(TestMetadata,DoGetDictionaries)2310 TEST_F(TestMetadata, DoGetDictionaries) {
2311 Ticket ticket{"dicts"};
2312 std::unique_ptr<FlightStreamReader> stream;
2313 ASSERT_OK(client_->DoGet(ticket, &stream));
2314
2315 BatchVector expected_batches;
2316 ASSERT_OK(ExampleDictBatches(&expected_batches));
2317
2318 FlightStreamChunk chunk;
2319 auto num_batches = static_cast<int>(expected_batches.size());
2320 for (int i = 0; i < num_batches; ++i) {
2321 ASSERT_OK(stream->Next(&chunk));
2322 ASSERT_NE(nullptr, chunk.data);
2323 ASSERT_NE(nullptr, chunk.app_metadata);
2324 ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data);
2325 ASSERT_EQ(std::to_string(i), chunk.app_metadata->ToString());
2326 }
2327 ASSERT_OK(stream->Next(&chunk));
2328 ASSERT_EQ(nullptr, chunk.data);
2329 }
2330
// DoPut with per-batch application metadata (each batch tagged with its
// index). None of the metadata the server may send back on `reader` is read,
// which is what the Close() check below relies on.
TEST_F(TestMetadata, DoPut) {
  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  std::shared_ptr<Schema> schema = ExampleIntSchema();
  ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));

  BatchVector expected_batches;
  ASSERT_OK(ExampleIntBatches(&expected_batches));

  auto num_batches = static_cast<int>(expected_batches.size());
  for (int i = 0; i < num_batches; ++i) {
    ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
                                        Buffer::FromString(std::to_string(i))));
  }
  // This eventually calls grpc::ClientReaderWriter::Finish which can
  // hang if there are unread messages. So make sure our wrapper
  // around this doesn't hang (because it drains any unread messages)
  ASSERT_OK(writer->Close());
}
2352
2353 // Test DoPut() with dictionaries. This tests a corner case in the
2354 // server-side reader; see DoGetDictionaries above.
TEST_F(TestMetadata,DoPutDictionaries)2355 TEST_F(TestMetadata, DoPutDictionaries) {
2356 std::unique_ptr<FlightStreamWriter> writer;
2357 std::unique_ptr<FlightMetadataReader> reader;
2358 BatchVector expected_batches;
2359 ASSERT_OK(ExampleDictBatches(&expected_batches));
2360 // ARROW-8749: don't get the schema via ExampleDictSchema because
2361 // DictionaryMemo uses field addresses to determine whether it's
2362 // seen a field before. Hence, if we use a schema that is different
2363 // (identity-wise) than the schema of the first batch we write,
2364 // we'll end up generating a duplicate set of dictionaries that
2365 // confuses the reader.
2366 ASSERT_OK(client_->DoPut(FlightDescriptor{}, expected_batches[0]->schema(), &writer,
2367 &reader));
2368 std::shared_ptr<RecordBatch> chunk;
2369 std::shared_ptr<Buffer> metadata;
2370 auto num_batches = static_cast<int>(expected_batches.size());
2371 for (int i = 0; i < num_batches; ++i) {
2372 ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
2373 Buffer::FromString(std::to_string(i))));
2374 }
2375 ASSERT_OK(writer->Close());
2376 }
2377
// DoPut that reads back, after every write, the metadata the server returns,
// and expects it to echo the index sent with the batch.
TEST_F(TestMetadata, DoPutReadMetadata) {
  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  std::shared_ptr<Schema> schema = ExampleIntSchema();
  ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));

  BatchVector expected_batches;
  ASSERT_OK(ExampleIntBatches(&expected_batches));

  std::shared_ptr<Buffer> metadata;
  auto num_batches = static_cast<int>(expected_batches.size());
  for (int i = 0; i < num_batches; ++i) {
    ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
                                        Buffer::FromString(std::to_string(i))));
    ASSERT_OK(reader->ReadMetadata(&metadata));
    ASSERT_NE(nullptr, metadata);
    ASSERT_EQ(std::to_string(i), metadata->ToString());
  }
  // As opposed to the DoPut test above (which leaves the server's metadata
  // unread), all messages have been read here, so make sure Close() still
  // succeeds.
  ASSERT_OK(writer->Close());
}
2401
TEST_F(TestOptions, DoGetReadOptions) {
  // The call itself is set up fine, but the deliberately low IPC read
  // nesting depth makes reading the first chunk fail.
  FlightCallOptions call_options;
  call_options.read_options.max_recursion_depth = 1;

  std::unique_ptr<FlightStreamReader> reader;
  ASSERT_OK(client_->DoGet(call_options, Ticket{""}, &reader));

  FlightStreamChunk first_chunk;
  ASSERT_RAISES(Invalid, reader->Next(&first_chunk));
}
2412
TEST_F(TestOptions, DoPutWriteOptions) {
  // With the IPC write nesting depth capped at 1 via the call options,
  // writing each nested example batch is expected to be rejected.
  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> metadata_reader;
  BatchVector nested;
  ASSERT_OK(ExampleNestedBatches(&nested));

  FlightCallOptions call_options;
  call_options.write_options.max_recursion_depth = 1;
  ASSERT_OK(client_->DoPut(call_options, FlightDescriptor{}, nested[0]->schema(),
                           &writer, &metadata_reader));

  for (size_t i = 0; i < nested.size(); ++i) {
    ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*nested[i]));
  }
}
2428
TEST_F(TestOptions, DoExchangeClientWriteOptions) {
  // DoExchange with the IPC write nesting depth capped at 1 via the call
  // options: every nested write is rejected, while finishing and closing
  // the exchange still succeed.
  FlightCallOptions call_options;
  call_options.write_options.max_recursion_depth = 1;

  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(
      client_->DoExchange(call_options, FlightDescriptor::Command(""), &writer, &reader));

  BatchVector nested;
  ASSERT_OK(ExampleNestedBatches(&nested));
  ASSERT_OK(writer->Begin(nested[0]->schema()));
  for (size_t i = 0; i < nested.size(); ++i) {
    ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*nested[i]));
  }

  ASSERT_OK(writer->DoneWriting());
  ASSERT_OK(writer->Close());
}
2447
TEST_F(TestOptions, DoExchangeClientWriteOptionsBegin) {
  // Same as DoExchangeClientWriteOptions, except that the low nesting depth
  // is passed to Begin() as IPC write options instead of through the call
  // options.
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(FlightDescriptor::Command(""), &writer, &reader));

  BatchVector nested;
  ASSERT_OK(ExampleNestedBatches(&nested));

  ipc::IpcWriteOptions write_options = ipc::IpcWriteOptions::Defaults();
  write_options.max_recursion_depth = 1;
  ASSERT_OK(writer->Begin(nested[0]->schema(), write_options));
  for (size_t i = 0; i < nested.size(); ++i) {
    ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*nested[i]));
  }

  ASSERT_OK(writer->DoneWriting());
  ASSERT_OK(writer->Close());
}
2467
TEST_F(TestOptions, DoExchangeServerWriteOptions) {
  // Call DoExchange and write nested data, but with a very low nesting depth set to fail
  // the call. (The low nesting depth is set on the server side.) The client-side write
  // succeeds; the failure is only reported when the call is closed.
  auto descr = FlightDescriptor::Command("");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
  BatchVector batches;
  ASSERT_OK(ExampleNestedBatches(&batches));
  ASSERT_OK(writer->Begin(batches[0]->schema()));
  ASSERT_OK(writer->WriteRecordBatch(*batches[0]));
  ASSERT_OK(writer->DoneWriting());
  ASSERT_RAISES(Invalid, writer->Close());
}
2483
// A server middleware that rejects every call: GetFlightInfo must fail with
// IOError carrying the middleware's rejection message.
TEST_F(TestRejectServerMiddleware, Rejected) {
  std::unique_ptr<FlightInfo> info;
  // Hold the returned Status by value (as the other tests do) rather than
  // binding a reference to the temporary.
  const Status status = client_->GetFlightInfo(FlightDescriptor{}, &info);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("All calls are rejected"));
}
2490
// Checks the counting server middleware: one failed call (GetFlightInfo,
// NotImplemented) followed by one successful, fully drained DoGet.
TEST_F(TestCountingServerMiddleware, Count) {
  std::unique_ptr<FlightInfo> info;
  // Hold the returned Status by value rather than binding a reference to
  // the temporary.
  const Status status = client_->GetFlightInfo(FlightDescriptor{}, &info);
  ASSERT_RAISES(NotImplemented, status);

  Ticket ticket{""};
  std::unique_ptr<FlightStreamReader> stream;
  ASSERT_OK(client_->DoGet(ticket, &stream));

  ASSERT_EQ(1, request_counter_->failed_);

  // Drain the stream; presumably the DoGet is only counted as successful
  // once the call has completed.
  while (true) {
    FlightStreamChunk chunk;
    ASSERT_OK(stream->Next(&chunk));
    if (chunk.data == nullptr) {
      break;
    }
  }

  ASSERT_EQ(1, request_counter_->successful_);
  ASSERT_EQ(1, request_counter_->failed_);
}
2513
// The span id set on the client is expected to reach the server and come
// back as the body of the first DoAction result.
TEST_F(TestPropagatingMiddleware, Propagate) {
  Action action;
  current_span_id = "trace-id";
  client_middleware_->Reset();

  action.type = "action1";
  action.body = Buffer::FromString("action1-content");
  std::unique_ptr<ResultStream> results;
  ASSERT_OK(client_->DoAction(action, &results));

  std::unique_ptr<Result> first;
  ASSERT_OK(results->Next(&first));
  ASSERT_EQ("trace-id", first->body->ToString());
  ValidateStatus(Status::OK(), FlightMethod::DoAction);
}
2530
2531 // For each method, make sure that the client middleware received
2532 // headers from the server and that the proper method enum value was
2533 // passed to the interceptor
TEST_F(TestPropagatingMiddleware,ListFlights)2534 TEST_F(TestPropagatingMiddleware, ListFlights) {
2535 client_middleware_->Reset();
2536 std::unique_ptr<FlightListing> listing;
2537 const Status status = client_->ListFlights(&listing);
2538 ASSERT_RAISES(NotImplemented, status);
2539 ValidateStatus(status, FlightMethod::ListFlights);
2540 }
2541
TEST_F(TestPropagatingMiddleware, GetFlightInfo) {
  // GetFlightInfo is not implemented by the server; the middleware must
  // still observe the call as GetFlightInfo.
  client_middleware_->Reset();
  std::unique_ptr<FlightInfo> flight_info;
  const Status status =
      client_->GetFlightInfo(FlightDescriptor::Path({"examples", "ints"}), &flight_info);
  ASSERT_RAISES(NotImplemented, status);
  ValidateStatus(status, FlightMethod::GetFlightInfo);
}
2550
TEST_F(TestPropagatingMiddleware, GetSchema) {
  // GetSchema is not implemented by the server; the middleware must still
  // observe the call as GetSchema.
  client_middleware_->Reset();
  std::unique_ptr<SchemaResult> schema_result;
  const Status status =
      client_->GetSchema(FlightDescriptor::Path({"examples", "ints"}), &schema_result);
  ASSERT_RAISES(NotImplemented, status);
  ValidateStatus(status, FlightMethod::GetSchema);
}
2559
TEST_F(TestPropagatingMiddleware, ListActions) {
  // ListActions is not implemented by the server; the middleware must still
  // observe the call as ListActions.
  client_middleware_->Reset();
  std::vector<ActionType> action_types;
  const Status status = client_->ListActions(&action_types);
  ASSERT_RAISES(NotImplemented, status);
  ValidateStatus(status, FlightMethod::ListActions);
}
2567
TEST_F(TestPropagatingMiddleware, DoGet) {
  // The "ARROW-5095-fail" ticket makes DoGet itself fail; the middleware
  // must still observe the call as DoGet.
  client_middleware_->Reset();
  std::unique_ptr<FlightStreamReader> reader;
  const Status status = client_->DoGet(Ticket{"ARROW-5095-fail"}, &reader);
  ASSERT_RAISES(NotImplemented, status);
  ValidateStatus(status, FlightMethod::DoGet);
}
2576
TEST_F(TestPropagatingMiddleware, DoPut) {
  // Opening the DoPut stream succeeds; the server's NotImplemented error is
  // reported by Close(), and the middleware must observe the call as DoPut.
  client_middleware_->Reset();
  const auto column = ArrayFromJSON(int32(), "[4, 5, 6, null]");
  const auto put_schema = arrow::schema({field("f1", column->type())});

  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> metadata_reader;
  ASSERT_OK(client_->DoPut(FlightDescriptor::Path({"ints"}), put_schema, &writer,
                           &metadata_reader));
  const Status status = writer->Close();
  ASSERT_RAISES(NotImplemented, status);
  ValidateStatus(status, FlightMethod::DoPut);
}
2590
TEST_F(TestBasicHeaderAuthMiddleware, ValidCredentials) {
  // Delegates to the fixture helper that runs client auth with valid
  // credentials.
  RunValidClientAuth();
}
2592
TEST_F(TestBasicHeaderAuthMiddleware, InvalidCredentials) {
  // Delegates to the fixture helper that runs client auth with invalid
  // credentials.
  RunInvalidClientAuth();
}
2594
TEST_F(TestCookieMiddleware, BasicParsing) {
  // Cookies with and without a trailing ';', extra attributes, and a quoted
  // value, added in this order.
  const char* const cookies[] = {
      "id1=1; foo=bar;", "id1=1; foo=bar", "id2=2;",
      "id4=\"4\"",       "id5=5; foo=bar; baz=buz;",
  };
  for (const char* cookie : cookies) {
    AddAndValidate(cookie);
  }
}
2602
TEST_F(TestCookieMiddleware, Overwrite) {
  // Re-adding cookies with the same name, with both changed and unchanged
  // values, in this order.
  const char* const cookies[] = {"id0=0",  "id0=1", "id1=0", "id1=1", "id1=1",
                                 "id1=10", "id=3",  "id=0",  "id=0"};
  for (const char* cookie : cookies) {
    AddAndValidate(cookie);
  }
}
2614
TEST_F(TestCookieMiddleware, MaxAge) {
  // Zero, negative and positive max-age values. id4/id5 are first added with
  // max-age=1 and then re-added with max-age=0.
  const char* const cookies[] = {
      "id0=0; max-age=0;", "id1=0; max-age=-1;", "id2=0; max-age=0",
      "id3=0; max-age=-1", "id4=0; max-age=1",   "id5=0; max-age=1",
      "id4=0; max-age=0",  "id5=0; max-age=0",
  };
  for (const char* cookie : cookies) {
    AddAndValidate(cookie);
  }
}
2625
TEST_F(TestCookieMiddleware, Expires) {
  // Malformed, past (2017) and future (2038) expires dates, with and without
  // a trailing ';', re-using cookie names along the way.
  const char* const cookies[] = {
      "id0=0; expires=0, 0 0 0 0:0:0 GMT;",
      "id0=0; expires=0, 0 0 0 0:0:0 GMT",
      "id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;",
      "id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT",
      "id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;",
      "id1=0; expires=Fri, 01 Jan 2038 22:15:36 GMT",
      "id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;",
      "id1=0; expires=Fri, 22 Dec 2017 22:15:36 GMT",
  };
  for (const char* cookie : cookies) {
    AddAndValidate(cookie);
  }
}
2636
TEST_F(TestCookieParsing, Expired) {
  // A past expires date, a negative max-age and a zero max-age must each be
  // parsed as expired.
  const char* const expired_cookies[] = {
      "id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;",
      "id1=0; max-age=-1;",
      "id0=0; max-age=0;",
  };
  for (const char* cookie : expired_cookies) {
    VerifyParseCookie(cookie, true);
  }
}
2642
TEST_F(TestCookieParsing, Invalid) {
  // Malformed or empty expires/max-age values are expected to be reported
  // as expired, like the cases in TestCookieParsing.Expired.
  const char* const malformed_cookies[] = {
      "id1=0; expires=0, 0 0 0 0:0:0 GMT;",
      "id1=0; expires=Fri, 01 FOO 2038 22:15:36 GMT",
      "id1=0; expires=foo",
      "id1=0; expires=",
      "id1=0; max-age=FOO",
      "id1=0; max-age=",
  };
  for (const char* cookie : malformed_cookies) {
    VerifyParseCookie(cookie, true);
  }
}
2651
TEST_F(TestCookieParsing, NoExpiry) {
  // Attributes whose names only resemble expires/max-age (or are unrelated)
  // must not cause the cookie to be treated as expired.
  const char* const unexpiring_cookies[] = {
      "id1=0;",
      "id1=0; noexpiry=Fri, 01 Jan 2038 22:15:36 GMT",
      "id1=0; noexpiry=\"Fri, 01 Jan 2038 22:15:36 GMT\"",
      "id1=0; nomax-age=-1",
      "id1=0; nomax-age=\"-1\"",
      "id1=0; randomattr=foo",
  };
  for (const char* cookie : unexpiring_cookies) {
    VerifyParseCookie(cookie, false);
  }
}
2660
TEST_F(TestCookieParsing, NotExpired) {
  // A positive max-age and a future (2038) expires date are still valid.
  const char* const live_cookies[] = {
      "id5=0; max-age=1",
      "id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;",
  };
  for (const char* cookie : live_cookies) {
    VerifyParseCookie(cookie, false);
  }
}
2665
TEST_F(TestCookieParsing, GetName) {
  // The cookie name is the key of the first name=value pair; attributes that
  // follow are ignored.
  struct NameCase {
    const char* cookie;
    const char* expected_name;
  };
  const NameCase cases[] = {
      {"id1=1; foo=bar;", "id1"}, {"id1=1; foo=bar", "id1"},
      {"id2=2;", "id2"},          {"id4=\"4\"", "id4"},
      {"id5=5; foo=bar; baz=buz;", "id5"},
  };
  for (const NameCase& c : cases) {
    VerifyCookieName(c.cookie, c.expected_name);
  }
}
2673
TEST_F(TestCookieParsing, ToString) {
  // Only the name=value pair is rendered, with the value always quoted.
  struct StringCase {
    const char* cookie;
    const char* expected;
  };
  const StringCase cases[] = {
      {"id1=1; foo=bar;", "id1=\"1\""},
      {"id1=1; foo=bar", "id1=\"1\""},
      {"id2=2;", "id2=\"2\""},
      {"id4=\"4\"", "id4=\"4\""},
      {"id5=5; foo=bar; baz=buz;", "id5=\"5\""},
  };
  for (const StringCase& c : cases) {
    VerifyCookieString(c.cookie, c.expected);
  }
}
2681
TEST_F(TestCookieParsing, DateConversion) {
  // Month names in any letter case map to two-digit month numbers, and the
  // weekday, " GMT" suffix and trailing ';' are dropped. Empty input stays
  // empty, and an unknown month name is passed through unchanged.
  struct DateCase {
    const char* input;
    const char* expected;
  };
  const DateCase cases[] = {
      {"Mon, 01 jan 2038 22:15:36 GMT;", "01 01 2038 22:15:36"},
      {"TUE, 10 Feb 2038 22:15:36 GMT", "10 02 2038 22:15:36"},
      {"WED, 20 MAr 2038 22:15:36 GMT;", "20 03 2038 22:15:36"},
      {"thu, 15 APR 2038 22:15:36 GMT", "15 04 2038 22:15:36"},
      {"Fri, 30 mAY 2038 22:15:36 GMT;", "30 05 2038 22:15:36"},
      {"Sat, 03 juN 2038 22:15:36 GMT", "03 06 2038 22:15:36"},
      {"Sun, 01 JuL 2038 22:15:36 GMT;", "01 07 2038 22:15:36"},
      {"Fri, 06 aUg 2038 22:15:36 GMT", "06 08 2038 22:15:36"},
      {"Fri, 01 SEP 2038 22:15:36 GMT;", "01 09 2038 22:15:36"},
      {"Fri, 01 OCT 2038 22:15:36 GMT", "01 10 2038 22:15:36"},
      {"Fri, 01 Nov 2038 22:15:36 GMT;", "01 11 2038 22:15:36"},
      {"Fri, 01 deC 2038 22:15:36 GMT", "01 12 2038 22:15:36"},
      {"", ""},
      {"Fri, 01 INVALID 2038 22:15:36 GMT;", "01 INVALID 2038 22:15:36"},
  };
  for (const DateCase& c : cases) {
    VerifyCookieDateConverson(c.input, c.expected);
  }
}
2699
// Walks cookie attribute parsing across a four-attribute cookie string. Each
// call passes a start position and checks the expected (name, value) pair and
// the expected position at which parsing resumes.
TEST_F(TestCookieParsing, ParseCookieAttribute) {
  // Empty input: no attribute and npos are expected.
  VerifyCookieAttributeParsing("", 0, util::nullopt, std::string::npos);

  std::string cookie_string = "attr0=0; attr1=1; attr2=2; attr3=3";
  // "attrN=N;" is 8 characters. Together with the following space,
  // consecutive attributes start attr_length + 1 == 9 characters apart
  // (offsets 0, 9, 18, 27).
  auto attr_length = std::string("attr0=0;").length();
  std::string::size_type start_pos = 0;
  // For ';'-terminated attributes, the expected resume position is the index
  // just past that ';'.
  VerifyCookieAttributeParsing(cookie_string, start_pos, std::make_pair("attr0", "0"),
                               cookie_string.find("attr0=0;") + attr_length);
  VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
                               std::make_pair("attr1", "1"),
                               cookie_string.find("attr1=1;") + attr_length);
  VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
                               std::make_pair("attr2", "2"),
                               cookie_string.find("attr2=2;") + attr_length);
  // The last attribute has no trailing ';', so npos is expected as the
  // resume position.
  VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
                               std::make_pair("attr3", "3"), std::string::npos);
  // "attr3=3" is attr_length - 1 == 7 characters, so this moves start_pos to
  // the end of the string (34), where nothing remains to parse.
  VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length - 1)),
                               util::nullopt, std::string::npos);
  // Starting from npos is also expected to yield nothing.
  VerifyCookieAttributeParsing(cookie_string, std::string::npos, util::nullopt,
                               std::string::npos);
}
2721
// Each call passes a list of Set-Cookie strings to AddCookieVerifyCache along
// with the cookie string expected from the cache afterwards: a repeated name
// keeps the latest value, and distinct names are joined with "; ".
// NOTE(review): the single-cookie case expects "" while the multi-cookie cases
// list every cookie; confirm this is the intended behavior of the
// AddCookieVerifyCache helper (defined outside this chunk).
TEST_F(TestCookieParsing, CookieCache) {
  AddCookieVerifyCache({"id0=0;"}, "");
  AddCookieVerifyCache({"id0=0;", "id0=1;"}, "id0=\"1\"");
  AddCookieVerifyCache({"id0=0;", "id1=1;"}, "id0=\"0\"; id1=\"1\"");
  AddCookieVerifyCache({"id0=0;", "id1=1;", "id2=2"}, "id0=\"0\"; id1=\"1\"; id2=\"2\"");
}
2728
2729 class ForeverFlightListing : public FlightListing {
Next(std::unique_ptr<FlightInfo> * info)2730 Status Next(std::unique_ptr<FlightInfo>* info) override {
2731 std::this_thread::sleep_for(std::chrono::milliseconds(100));
2732 *info = arrow::internal::make_unique<FlightInfo>(ExampleFlightInfo()[0]);
2733 return Status::OK();
2734 }
2735 };
2736
2737 class ForeverResultStream : public ResultStream {
Next(std::unique_ptr<Result> * result)2738 Status Next(std::unique_ptr<Result>* result) override {
2739 std::this_thread::sleep_for(std::chrono::milliseconds(100));
2740 *result = arrow::internal::make_unique<Result>();
2741 (*result)->body = Buffer::FromString("foo");
2742 return Status::OK();
2743 }
2744 };
2745
2746 class ForeverDataStream : public FlightDataStream {
2747 public:
ForeverDataStream()2748 ForeverDataStream() : schema_(arrow::schema({})), mapper_(*schema_) {}
schema()2749 std::shared_ptr<Schema> schema() override { return schema_; }
2750
GetSchemaPayload(FlightPayload * payload)2751 Status GetSchemaPayload(FlightPayload* payload) override {
2752 return ipc::GetSchemaPayload(*schema_, ipc::IpcWriteOptions::Defaults(), mapper_,
2753 &payload->ipc_message);
2754 }
2755
Next(FlightPayload * payload)2756 Status Next(FlightPayload* payload) override {
2757 auto batch = RecordBatch::Make(schema_, 0, ArrayVector{});
2758 return ipc::GetRecordBatchPayload(*batch, ipc::IpcWriteOptions::Defaults(),
2759 &payload->ipc_message);
2760 }
2761
2762 private:
2763 std::shared_ptr<Schema> schema_;
2764 ipc::DictionaryFieldMapper mapper_;
2765 };
2766
2767 class CancelTestServer : public FlightServerBase {
2768 public:
ListFlights(const ServerCallContext &,const Criteria *,std::unique_ptr<FlightListing> * listings)2769 Status ListFlights(const ServerCallContext&, const Criteria*,
2770 std::unique_ptr<FlightListing>* listings) override {
2771 *listings = arrow::internal::make_unique<ForeverFlightListing>();
2772 return Status::OK();
2773 }
DoAction(const ServerCallContext &,const Action &,std::unique_ptr<ResultStream> * result)2774 Status DoAction(const ServerCallContext&, const Action&,
2775 std::unique_ptr<ResultStream>* result) override {
2776 *result = arrow::internal::make_unique<ForeverResultStream>();
2777 return Status::OK();
2778 }
ListActions(const ServerCallContext &,std::vector<ActionType> * actions)2779 Status ListActions(const ServerCallContext&,
2780 std::vector<ActionType>* actions) override {
2781 *actions = {};
2782 return Status::OK();
2783 }
DoGet(const ServerCallContext &,const Ticket &,std::unique_ptr<FlightDataStream> * data_stream)2784 Status DoGet(const ServerCallContext&, const Ticket&,
2785 std::unique_ptr<FlightDataStream>* data_stream) override {
2786 *data_stream = arrow::internal::make_unique<ForeverDataStream>();
2787 return Status::OK();
2788 }
2789 };
2790
2791 class TestCancel : public ::testing::Test {
2792 public:
SetUp()2793 void SetUp() {
2794 ASSERT_OK(MakeServer<CancelTestServer>(
2795 &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
2796 [](FlightClientOptions* options) { return Status::OK(); }));
2797 }
TearDown()2798 void TearDown() { ASSERT_OK(server_->Shutdown()); }
2799
2800 protected:
2801 std::unique_ptr<FlightClient> client_;
2802 std::unique_ptr<FlightServerBase> server_;
2803 };
2804
TEST_F(TestCancel, ListFlights) {
  // The stop token is already triggered when the call is made, so the call
  // must fail with the StopSource's Cancelled status.
  StopSource stop;
  FlightCallOptions call_options;
  call_options.stop_token = stop.token();
  stop.RequestStop(Status::Cancelled("StopSource"));

  std::unique_ptr<FlightListing> flights;
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  client_->ListFlights(call_options, {}, &flights));
}
2814
TEST_F(TestCancel, DoAction) {
  // The stop token is already triggered when the call is made, so the call
  // must fail with the StopSource's Cancelled status.
  StopSource stop;
  FlightCallOptions call_options;
  call_options.stop_token = stop.token();
  stop.RequestStop(Status::Cancelled("StopSource"));

  std::unique_ptr<ResultStream> result_stream;
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  client_->DoAction(call_options, {}, &result_stream));
}
2824
TEST_F(TestCancel, ListActions) {
  // The stop token is already triggered when the call is made, so the call
  // must fail with the StopSource's Cancelled status.
  StopSource stop;
  FlightCallOptions call_options;
  call_options.stop_token = stop.token();
  stop.RequestStop(Status::Cancelled("StopSource"));

  std::vector<ActionType> action_types;
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  client_->ListActions(call_options, &action_types));
}
2834
// Cancels reading an endless DoGet stream, first via the stop token in the
// call options, then via a stop token passed directly to ReadAll.
TEST_F(TestCancel, DoGet) {
  StopSource stop_source;
  FlightCallOptions options;
  options.stop_token = stop_source.token();
  stop_source.RequestStop(Status::Cancelled("StopSource"));
  std::unique_ptr<FlightStreamReader> stream;
  std::shared_ptr<Table> table;

  // Stop token supplied through the call options.
  ASSERT_OK(client_->DoGet(options, {}, &stream));
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  stream->ReadAll(&table));

  // Call made without options; the stop token is handed to ReadAll instead.
  ASSERT_OK(client_->DoGet({}, &stream));
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  stream->ReadAll(&table, options.stop_token));
}
2851
// Cancels reading the DoExchange stream, first via the stop token in the
// call options, then via a stop token passed directly to ReadAll.
TEST_F(TestCancel, DoExchange) {
  StopSource stop_source;
  FlightCallOptions options;
  options.stop_token = stop_source.token();
  stop_source.RequestStop(Status::Cancelled("StopSource"));
  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightStreamReader> stream;
  std::shared_ptr<Table> table;

  // Stop token supplied through the call options.
  ASSERT_OK(
      client_->DoExchange(options, FlightDescriptor::Command(""), &writer, &stream));
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  stream->ReadAll(&table));

  // Call made without options; the stop token is handed to ReadAll instead.
  ASSERT_OK(client_->DoExchange(FlightDescriptor::Command(""), &writer, &stream));
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  stream->ReadAll(&table, options.stop_token));
}
2870
2871 } // namespace flight
2872 } // namespace arrow
2873