1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include <gmock/gmock.h>
19 #include <gtest/gtest.h>
20 
21 #include <atomic>
22 #include <chrono>
23 #include <cstdint>
24 #include <cstdio>
25 #include <cstring>
26 #include <iostream>
27 #include <memory>
28 #include <sstream>
29 #include <string>
30 #include <thread>
31 #include <vector>
32 
33 #include "arrow/flight/api.h"
34 #include "arrow/ipc/test_common.h"
35 #include "arrow/status.h"
36 #include "arrow/testing/generator.h"
37 #include "arrow/testing/gtest_util.h"
38 #include "arrow/testing/util.h"
39 #include "arrow/util/base64.h"
40 #include "arrow/util/logging.h"
41 #include "arrow/util/make_unique.h"
42 #include "arrow/util/string.h"
43 
44 #ifdef GRPCPP_GRPCPP_H
45 #error "gRPC headers should not be in public API"
46 #endif
47 
48 #include "arrow/flight/client_cookie_middleware.h"
49 #include "arrow/flight/client_header_internal.h"
50 #include "arrow/flight/internal.h"
51 #include "arrow/flight/middleware_internal.h"
52 #include "arrow/flight/test_util.h"
53 
54 namespace arrow {
55 namespace flight {
56 
// Short alias for the generated Flight protobuf namespace
// (arrow::flight::protocol), used where the tests touch protobuf messages.
namespace pb = arrow::flight::protocol;
58 
// Fixed test credentials, plus the pieces of an `authorization` header value:
// the header name, the "Basic "/"Bearer " scheme prefixes and a bearer token.
// These are compile-time constants; like the `const` arrays they replace, they
// have internal linkage at namespace scope.
constexpr char kValidUsername[] = "flight_username";
constexpr char kValidPassword[] = "flight_password";
constexpr char kInvalidUsername[] = "invalid_flight_username";
constexpr char kInvalidPassword[] = "invalid_flight_password";
constexpr char kBearerToken[] = "bearertoken";
constexpr char kBasicPrefix[] = "Basic ";
constexpr char kBearerPrefix[] = "Bearer ";
constexpr char kAuthHeader[] = "authorization";
67 
// Records a fatal gtest failure (returning from this helper only) unless the
// two ActionTypes have equal `type` and equal `description`.
void AssertEqual(const ActionType& expected, const ActionType& actual) {
  ASSERT_EQ(expected.type, actual.type);
  ASSERT_EQ(expected.description, actual.description);
}
72 
// Records a fatal gtest failure unless FlightDescriptor::Equals reports the
// two descriptors as equal.
void AssertEqual(const FlightDescriptor& expected, const FlightDescriptor& actual) {
  ASSERT_TRUE(expected.Equals(actual));
}
76 
// Records a fatal gtest failure unless both Tickets hold the same `ticket` value.
void AssertEqual(const Ticket& expected, const Ticket& actual) {
  ASSERT_EQ(expected.ticket, actual.ticket);
}
80 
// Records a fatal gtest failure unless the two Locations compare equal with ==.
void AssertEqual(const Location& expected, const Location& actual) {
  ASSERT_EQ(expected, actual);
}
84 
// Compares two endpoint lists element by element: the lists must have the same
// length, and each pair of endpoints must have an equal ticket and the same
// locations in the same order. Stops at the first mismatch (fatal assertion).
void AssertEqual(const std::vector<FlightEndpoint>& expected,
                 const std::vector<FlightEndpoint>& actual) {
  ASSERT_EQ(expected.size(), actual.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    AssertEqual(expected[i].ticket, actual[i].ticket);

    // Location order is significant: locations are compared index by index.
    ASSERT_EQ(expected[i].locations.size(), actual[i].locations.size());
    for (size_t j = 0; j < expected[i].locations.size(); ++j) {
      AssertEqual(expected[i].locations[j], actual[i].locations[j]);
    }
  }
}
97 
// Generic element-wise comparison of two vectors of equal length. T must have
// an AssertEqual(const T&, const T&) overload visible at instantiation.
template <typename T>
void AssertEqual(const std::vector<T>& expected, const std::vector<T>& actual) {
  ASSERT_EQ(expected.size(), actual.size());
  for (size_t i = 0; i < expected.size(); ++i) {
    AssertEqual(expected[i], actual[i]);
  }
}
105 
// Compares two FlightInfos on their decoded schema, total_records, total_bytes,
// descriptor and endpoints.
void AssertEqual(const FlightInfo& expected, const FlightInfo& actual) {
  std::shared_ptr<Schema> ex_schema, actual_schema;
  // Each side decodes its schema with its own dictionary memo.
  ipc::DictionaryMemo expected_memo;
  ipc::DictionaryMemo actual_memo;
  ASSERT_OK(expected.GetSchema(&expected_memo, &ex_schema));
  ASSERT_OK(actual.GetSchema(&actual_memo, &actual_schema));

  AssertSchemaEqual(*ex_schema, *actual_schema);
  ASSERT_EQ(expected.total_records(), actual.total_records());
  ASSERT_EQ(expected.total_bytes(), actual.total_bytes());

  AssertEqual(expected.descriptor(), actual.descriptor());
  AssertEqual(expected.endpoints(), actual.endpoints());
}
120 
// Checks ToString() for a CMD and a PATH descriptor, and that Equals()
// distinguishes the descriptor kind as well as the command text / path parts.
TEST(TestFlightDescriptor, Basics) {
  auto a = FlightDescriptor::Command("select * from table");
  auto b = FlightDescriptor::Command("select * from table");
  auto c = FlightDescriptor::Command("select foo from table");
  auto d = FlightDescriptor::Path({"foo", "bar"});
  auto e = FlightDescriptor::Path({"foo", "baz"});
  auto f = FlightDescriptor::Path({"foo", "baz"});

  ASSERT_EQ(a.ToString(), "FlightDescriptor<cmd = 'select * from table'>");
  ASSERT_EQ(d.ToString(), "FlightDescriptor<path = 'foo/bar'>");
  ASSERT_TRUE(a.Equals(b));
  ASSERT_FALSE(a.Equals(c));
  // A CMD descriptor never equals a PATH descriptor.
  ASSERT_FALSE(a.Equals(d));
  ASSERT_FALSE(d.Equals(e));
  ASSERT_TRUE(e.Equals(f));
}
137 
// This tests the internal protobuf types, which are not exported from the
// Flight DLL, so it is compiled out on Windows.
#ifndef _WIN32
// Round-trips a PATH descriptor and then a CMD descriptor through
// internal::ToProto / internal::FromProto and checks each result for equality.
// The protobuf message and the output descriptor are reused across both cases.
TEST(TestFlightDescriptor, ToFromProto) {
  FlightDescriptor descr_test;
  pb::FlightDescriptor pb_descr;

  FlightDescriptor descr1{FlightDescriptor::PATH, "", {"foo", "bar"}};
  ASSERT_OK(internal::ToProto(descr1, &pb_descr));
  ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
  AssertEqual(descr1, descr_test);

  FlightDescriptor descr2{FlightDescriptor::CMD, "command", {}};
  ASSERT_OK(internal::ToProto(descr2, &pb_descr));
  ASSERT_OK(internal::FromProto(pb_descr, &descr_test));
  AssertEqual(descr2, descr_test);
}
#endif
155 
// Starts a TestServer named "flight-test-server", checks it is still running
// after ~200 ms, then stops it and checks the reported exit code.
// Disabled by default (DISABLED_ prefix).
TEST(TestFlight, DISABLED_StartStopTestServer) {
  TestServer server("flight-test-server");
  server.Start();
  ASSERT_TRUE(server.IsRunning());

  std::this_thread::sleep_for(std::chrono::duration<double>(0.2));

  ASSERT_TRUE(server.IsRunning());
  int exit_code = server.Stop();
#ifdef _WIN32
  // We do a hard kill on Windows
  // NOTE(review): 259 presumably corresponds to the Windows STILL_ACTIVE
  // exit code reported for a forcibly terminated process — confirm.
  ASSERT_EQ(259, exit_code);
#else
  ASSERT_EQ(0, exit_code);
#endif
  ASSERT_FALSE(server.IsRunning());
}
173 
// ARROW-6017: we should be able to construct locations for unknown
// schemes. Location::Parse must succeed for URIs whose scheme is not one of
// the grpc* schemes (here "s3" and "https").
TEST(TestFlight, UnknownLocationScheme) {
  Location location;
  ASSERT_OK(Location::Parse("s3://test", &location));
  ASSERT_OK(Location::Parse("https://example.com/foo", &location));
}
181 
// A client can connect using a Location parsed from a "grpc://host:port" URI.
// The same URI is parsed twice, and the second Connect() replaces the client.
TEST(TestFlight, ConnectUri) {
  TestServer server("flight-test-server");
  server.Start();
  ASSERT_TRUE(server.IsRunning());

  const std::string uri = "grpc://localhost:" + std::to_string(server.port());

  std::unique_ptr<FlightClient> client;
  Location location1;
  Location location2;
  ASSERT_OK(Location::Parse(uri, &location1));
  ASSERT_OK(Location::Parse(uri, &location2));
  ASSERT_OK(FlightClient::Connect(location1, &client));
  ASSERT_OK(FlightClient::Connect(location2, &client));
}
199 
200 #ifndef _WIN32
TEST(TestFlight,ConnectUriUnix)201 TEST(TestFlight, ConnectUriUnix) {
202   TestServer server("flight-test-server", "/tmp/flight-test.sock");
203   server.Start();
204   ASSERT_TRUE(server.IsRunning());
205 
206   std::stringstream ss;
207   ss << "grpc+unix://" << server.unix_sock();
208   std::string uri = ss.str();
209 
210   std::unique_ptr<FlightClient> client;
211   Location location1;
212   Location location2;
213   ASSERT_OK(Location::Parse(uri, &location1));
214   ASSERT_OK(Location::Parse(uri, &location2));
215   ASSERT_OK(FlightClient::Connect(location1, &client));
216   ASSERT_OK(FlightClient::Connect(location2, &client));
217 }
218 #endif
219 
// Serializes and deserializes each public Flight message type and checks the
// round trip: a Ticket, a FlightDescriptor in both CMD and PATH form, and a
// FlightInfo (built from the PATH descriptor) carrying TCP, TLS and Unix
// endpoints.
TEST(TestFlight, RoundTripTypes) {
  Ticket ticket{"foo"};
  std::string ticket_serialized;
  Ticket ticket_deserialized;
  ASSERT_OK(ticket.SerializeToString(&ticket_serialized));
  ASSERT_OK(Ticket::Deserialize(ticket_serialized, &ticket_deserialized));
  ASSERT_EQ(ticket.ticket, ticket_deserialized.ticket);

  FlightDescriptor desc = FlightDescriptor::Command("select * from foo;");
  std::string desc_serialized;
  FlightDescriptor desc_deserialized;
  ASSERT_OK(desc.SerializeToString(&desc_serialized));
  ASSERT_OK(FlightDescriptor::Deserialize(desc_serialized, &desc_deserialized));
  ASSERT_TRUE(desc.Equals(desc_deserialized));

  // The PATH form reuses the same buffers; it is also the descriptor embedded
  // in the FlightInfo below.
  desc = FlightDescriptor::Path({"a", "b", "test.arrow"});
  ASSERT_OK(desc.SerializeToString(&desc_serialized));
  ASSERT_OK(FlightDescriptor::Deserialize(desc_serialized, &desc_deserialized));
  ASSERT_TRUE(desc.Equals(desc_deserialized));

  FlightInfo::Data data;
  std::shared_ptr<Schema> schema =
      arrow::schema({field("a", int64()), field("b", int64()), field("c", int64()),
                     field("d", int64())});
  Location location1, location2, location3;
  ASSERT_OK(Location::ForGrpcTcp("localhost", 10010, &location1));
  ASSERT_OK(Location::ForGrpcTls("localhost", 10010, &location2));
  ASSERT_OK(Location::ForGrpcUnix("/tmp/test.sock", &location3));
  std::vector<FlightEndpoint> endpoints{FlightEndpoint{ticket, {location1, location2}},
                                        FlightEndpoint{ticket, {location3}}};
  ASSERT_OK(MakeFlightInfo(*schema, desc, endpoints, -1, -1, &data));
  std::unique_ptr<FlightInfo> info = arrow::internal::make_unique<FlightInfo>(data);
  std::string info_serialized;
  std::unique_ptr<FlightInfo> info_deserialized;
  ASSERT_OK(info->SerializeToString(&info_serialized));
  ASSERT_OK(FlightInfo::Deserialize(info_serialized, &info_deserialized));
  ASSERT_TRUE(info->descriptor().Equals(info_deserialized->descriptor()));
  ASSERT_EQ(info->endpoints(), info_deserialized->endpoints());
  ASSERT_EQ(info->total_records(), info_deserialized->total_records());
  ASSERT_EQ(info->total_bytes(), info_deserialized->total_bytes());
  // None of the checks above look at the schema. The FlightInfo AssertEqual
  // helper decodes both schemas and compares them, together with the
  // remaining fields.
  AssertEqual(*info, *info_deserialized);
}
261 
TEST(TestFlight, RoundtripStatus) {
  // Make sure status codes round trip through our conversions

  // Part 1: every FlightStatusCode wrapped by MakeFlightError must be
  // recoverable through FlightStatusDetail::UnwrapStatus.
  std::shared_ptr<FlightStatusDetail> detail;
  detail = FlightStatusDetail::UnwrapStatus(
      MakeFlightError(FlightStatusCode::Internal, "Test message"));
  ASSERT_NE(nullptr, detail);
  ASSERT_EQ(FlightStatusCode::Internal, detail->code());

  detail = FlightStatusDetail::UnwrapStatus(
      MakeFlightError(FlightStatusCode::TimedOut, "Test message"));
  ASSERT_NE(nullptr, detail);
  ASSERT_EQ(FlightStatusCode::TimedOut, detail->code());

  detail = FlightStatusDetail::UnwrapStatus(
      MakeFlightError(FlightStatusCode::Cancelled, "Test message"));
  ASSERT_NE(nullptr, detail);
  ASSERT_EQ(FlightStatusCode::Cancelled, detail->code());

  detail = FlightStatusDetail::UnwrapStatus(
      MakeFlightError(FlightStatusCode::Unauthenticated, "Test message"));
  ASSERT_NE(nullptr, detail);
  ASSERT_EQ(FlightStatusCode::Unauthenticated, detail->code());

  detail = FlightStatusDetail::UnwrapStatus(
      MakeFlightError(FlightStatusCode::Unauthorized, "Test message"));
  ASSERT_NE(nullptr, detail);
  ASSERT_EQ(FlightStatusCode::Unauthorized, detail->code());

  detail = FlightStatusDetail::UnwrapStatus(
      MakeFlightError(FlightStatusCode::Unavailable, "Test message"));
  ASSERT_NE(nullptr, detail);
  ASSERT_EQ(FlightStatusCode::Unavailable, detail->code());

  // Part 2: plain Arrow status codes must survive the conversion
  // Status -> gRPC status -> Status, keeping the original message as a substring.
  Status status = internal::FromGrpcStatus(
      internal::ToGrpcStatus(Status::NotImplemented("Sentinel")));
  ASSERT_TRUE(status.IsNotImplemented());
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));

  status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::Invalid("Sentinel")));
  ASSERT_TRUE(status.IsInvalid());
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));

  status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::KeyError("Sentinel")));
  ASSERT_TRUE(status.IsKeyError());
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));

  status =
      internal::FromGrpcStatus(internal::ToGrpcStatus(Status::AlreadyExists("Sentinel")));
  ASSERT_TRUE(status.IsAlreadyExists());
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
}
314 
// Initializing a server with port 0 must leave a concrete, positive port()
// afterwards.
TEST(TestFlight, GetPort) {
  Location location;
  std::unique_ptr<FlightServerBase> server = ExampleTestServer();

  ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location));
  FlightServerOptions options(location);
  ASSERT_OK(server->Init(options));
  ASSERT_GT(server->port(), 0);
}
324 
// Disabled because CI environments don't have an IPv6 interface configured.
// Binds a server to the IPv6 loopback "[::1]" with port 0, then connects a
// client to the resulting port and issues ListFlights over it.
TEST(TestFlight, DISABLED_IpV6Port) {
  Location location, location2;
  std::unique_ptr<FlightServerBase> server = ExampleTestServer();

  ASSERT_OK(Location::ForGrpcTcp("[::1]", 0, &location));
  FlightServerOptions options(location);
  ASSERT_OK(server->Init(options));
  ASSERT_GT(server->port(), 0);

  ASSERT_OK(Location::ForGrpcTcp("[::1]", server->port(), &location2));
  std::unique_ptr<FlightClient> client;
  ASSERT_OK(FlightClient::Connect(location2, &client));
  std::unique_ptr<FlightListing> listing;
  ASSERT_OK(client->ListFlights(&listing));
}
341 
// FlightServerOptions::builder_hook must be invoked during Init() with a
// non-null builder pointer.
// NOTE(review): the void* is presumably the underlying gRPC server builder —
// confirm against FlightServerOptions documentation.
TEST(TestFlight, BuilderHook) {
  Location location;
  std::unique_ptr<FlightServerBase> server = ExampleTestServer();

  ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location));
  FlightServerOptions options(location);
  bool builder_hook_run = false;
  options.builder_hook = [&builder_hook_run](void* builder) {
    ASSERT_NE(nullptr, builder);
    builder_hook_run = true;
  };
  ASSERT_OK(server->Init(options));
  ASSERT_TRUE(builder_hook_run);
  ASSERT_GT(server->port(), 0);
  ASSERT_OK(server->Shutdown());
}
358 
359 // ----------------------------------------------------------------------
360 // Client tests
361 
// Helper to initialize a server and matching client with callbacks to
// populate options.
//
// The helper runs these steps, returning the first error encountered:
// 1. Constructs a T from `server_args` into *server.
// 2. Lets `make_server_options` adjust options that bind to localhost on
//    port 0, then calls Init().
// 3. Connects *client to localhost on the port the server actually reports,
//    using FlightClientOptions::Defaults() as adjusted by
//    `make_client_options`.
// Note that *server is already assigned if a later step fails.
template <typename T, typename... Args>
Status MakeServer(std::unique_ptr<FlightServerBase>* server,
                  std::unique_ptr<FlightClient>* client,
                  std::function<Status(FlightServerOptions*)> make_server_options,
                  std::function<Status(FlightClientOptions*)> make_client_options,
                  Args&&... server_args) {
  Location location;
  RETURN_NOT_OK(Location::ForGrpcTcp("localhost", 0, &location));
  *server = arrow::internal::make_unique<T>(std::forward<Args>(server_args)...);
  FlightServerOptions server_options(location);
  RETURN_NOT_OK(make_server_options(&server_options));
  RETURN_NOT_OK((*server)->Init(server_options));
  // Port 0 was requested above; connect to the port the server actually has.
  Location real_location;
  RETURN_NOT_OK(Location::ForGrpcTcp("localhost", (*server)->port(), &real_location));
  FlightClientOptions client_options = FlightClientOptions::Defaults();
  RETURN_NOT_OK(make_client_options(&client_options));
  return FlightClient::Connect(real_location, client_options, client);
}
382 
383 class TestFlightClient : public ::testing::Test {
384  public:
SetUp()385   void SetUp() {
386     server_ = ExampleTestServer();
387 
388     Location location;
389     ASSERT_OK(Location::ForGrpcTcp("localhost", 0, &location));
390     FlightServerOptions options(location);
391     ASSERT_OK(server_->Init(options));
392 
393     ASSERT_OK(ConnectClient());
394   }
395 
TearDown()396   void TearDown() { ASSERT_OK(server_->Shutdown()); }
397 
ConnectClient()398   Status ConnectClient() {
399     Location location;
400     RETURN_NOT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location));
401     return FlightClient::Connect(location, &client_);
402   }
403 
404   template <typename EndpointCheckFunc>
CheckDoGet(const FlightDescriptor & descr,const BatchVector & expected_batches,EndpointCheckFunc && check_endpoints)405   void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches,
406                   EndpointCheckFunc&& check_endpoints) {
407     auto expected_schema = expected_batches[0]->schema();
408 
409     std::unique_ptr<FlightInfo> info;
410     ASSERT_OK(client_->GetFlightInfo(descr, &info));
411     check_endpoints(info->endpoints());
412 
413     std::shared_ptr<Schema> schema;
414     ipc::DictionaryMemo dict_memo;
415     ASSERT_OK(info->GetSchema(&dict_memo, &schema));
416     AssertSchemaEqual(*expected_schema, *schema);
417 
418     // By convention, fetch the first endpoint
419     Ticket ticket = info->endpoints()[0].ticket;
420     CheckDoGet(ticket, expected_batches);
421   }
422 
CheckDoGet(const Ticket & ticket,const BatchVector & expected_batches)423   void CheckDoGet(const Ticket& ticket, const BatchVector& expected_batches) {
424     auto num_batches = static_cast<int>(expected_batches.size());
425     ASSERT_GE(num_batches, 2);
426 
427     std::unique_ptr<FlightStreamReader> stream;
428     ASSERT_OK(client_->DoGet(ticket, &stream));
429 
430     std::unique_ptr<FlightStreamReader> stream2;
431     ASSERT_OK(client_->DoGet(ticket, &stream2));
432     ASSERT_OK_AND_ASSIGN(auto reader, MakeRecordBatchReader(std::move(stream2)));
433 
434     FlightStreamChunk chunk;
435     std::shared_ptr<RecordBatch> batch;
436     for (int i = 0; i < num_batches; ++i) {
437       ASSERT_OK(stream->Next(&chunk));
438       ASSERT_OK(reader->ReadNext(&batch));
439       ASSERT_NE(nullptr, chunk.data);
440       ASSERT_NE(nullptr, batch);
441 #if !defined(__MINGW32__)
442       ASSERT_BATCHES_EQUAL(*expected_batches[i], *chunk.data);
443       ASSERT_BATCHES_EQUAL(*expected_batches[i], *batch);
444 #else
445       // In MINGW32, the following code does not have the reproducibility at the LSB
446       // even when this is called twice with the same seed.
447       // As a workaround, use approxEqual
448       //   /* from GenerateTypedData in random.cc */
449       //   std::default_random_engine rng(seed);  // seed = 282475250
450       //   std::uniform_real_distribution<double> dist;
451       //   std::generate(data, data + n,          // n = 10
452       //                 [&dist, &rng] { return static_cast<ValueType>(dist(rng)); });
453       //   /* data[1] = 0x40852cdfe23d3976 or 0x40852cdfe23d3975 */
454       ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *chunk.data);
455       ASSERT_BATCHES_APPROX_EQUAL(*expected_batches[i], *batch);
456 #endif
457     }
458 
459     // Stream exhausted
460     ASSERT_OK(stream->Next(&chunk));
461     ASSERT_OK(reader->ReadNext(&batch));
462     ASSERT_EQ(nullptr, chunk.data);
463     ASSERT_EQ(nullptr, batch);
464   }
465 
466  protected:
467   std::unique_ptr<FlightClient> client_;
468   std::unique_ptr<FlightServerBase> server_;
469 };
470 
471 class AuthTestServer : public FlightServerBase {
DoAction(const ServerCallContext & context,const Action & action,std::unique_ptr<ResultStream> * result)472   Status DoAction(const ServerCallContext& context, const Action& action,
473                   std::unique_ptr<ResultStream>* result) override {
474     auto buf = Buffer::FromString(context.peer_identity());
475     auto peer = Buffer::FromString(context.peer());
476     *result = std::unique_ptr<ResultStream>(
477         new SimpleResultStream({Result{buf}, Result{peer}}));
478     return Status::OK();
479   }
480 };
481 
482 class TlsTestServer : public FlightServerBase {
DoAction(const ServerCallContext & context,const Action & action,std::unique_ptr<ResultStream> * result)483   Status DoAction(const ServerCallContext& context, const Action& action,
484                   std::unique_ptr<ResultStream>* result) override {
485     auto buf = Buffer::FromString("Hello, world!");
486     *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
487     return Status::OK();
488   }
489 };
490 
// Server that records what a client uploads through DoPut, so the TestDoPut
// fixture (declared a friend) can inspect it afterwards.
class DoPutTestServer : public FlightServerBase {
 public:
  // Stores the upload's descriptor, then reads all batches into batches_ and
  // returns ReadAll's status.
  Status DoPut(const ServerCallContext& context,
               std::unique_ptr<FlightMessageReader> reader,
               std::unique_ptr<FlightMetadataWriter> writer) override {
    descriptor_ = reader->descriptor();
    return reader->ReadAll(&batches_);
  }

 protected:
  FlightDescriptor descriptor_;
  BatchVector batches_;

  friend class TestDoPut;
};
506 
507 class MetadataTestServer : public FlightServerBase {
DoGet(const ServerCallContext & context,const Ticket & request,std::unique_ptr<FlightDataStream> * data_stream)508   Status DoGet(const ServerCallContext& context, const Ticket& request,
509                std::unique_ptr<FlightDataStream>* data_stream) override {
510     BatchVector batches;
511     if (request.ticket == "dicts") {
512       RETURN_NOT_OK(ExampleDictBatches(&batches));
513     } else if (request.ticket == "floats") {
514       RETURN_NOT_OK(ExampleFloatBatches(&batches));
515     } else {
516       RETURN_NOT_OK(ExampleIntBatches(&batches));
517     }
518     std::shared_ptr<RecordBatchReader> batch_reader =
519         std::make_shared<BatchIterator>(batches[0]->schema(), batches);
520 
521     *data_stream = std::unique_ptr<FlightDataStream>(new NumberingStream(
522         std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader))));
523     return Status::OK();
524   }
525 
DoPut(const ServerCallContext & context,std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMetadataWriter> writer)526   Status DoPut(const ServerCallContext& context,
527                std::unique_ptr<FlightMessageReader> reader,
528                std::unique_ptr<FlightMetadataWriter> writer) override {
529     FlightStreamChunk chunk;
530     int counter = 0;
531     while (true) {
532       RETURN_NOT_OK(reader->Next(&chunk));
533       if (chunk.data == nullptr) break;
534       if (chunk.app_metadata == nullptr) {
535         return Status::Invalid("Expected application metadata to be provided");
536       }
537       if (std::to_string(counter) != chunk.app_metadata->ToString()) {
538         return Status::Invalid("Expected metadata value: " + std::to_string(counter) +
539                                " but got: " + chunk.app_metadata->ToString());
540       }
541       auto metadata = Buffer::FromString(std::to_string(counter));
542       RETURN_NOT_OK(writer->WriteMetadata(*metadata));
543       counter++;
544     }
545     return Status::OK();
546   }
547 };
548 
// Server for testing custom IPC options support
class OptionsTestServer : public FlightServerBase {
  // Streams the example nested batches using RecordBatchStream with no
  // explicit write options.
  Status DoGet(const ServerCallContext& context, const Ticket& request,
               std::unique_ptr<FlightDataStream>* data_stream) override {
    BatchVector batches;
    RETURN_NOT_OK(ExampleNestedBatches(&batches));
    auto reader = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
    *data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(reader));
    return Status::OK();
  }

  // Just echo the number of batches written. The client will try to
  // call this method with different write options set.
  // The count is sent back once, as decimal text in a single metadata message.
  Status DoPut(const ServerCallContext& context,
               std::unique_ptr<FlightMessageReader> reader,
               std::unique_ptr<FlightMetadataWriter> writer) override {
    FlightStreamChunk chunk;
    int counter = 0;
    while (true) {
      RETURN_NOT_OK(reader->Next(&chunk));
      if (chunk.data == nullptr) break;
      counter++;
    }
    auto metadata = Buffer::FromString(std::to_string(counter));
    return writer->WriteMetadata(*metadata);
  }

  // Echo client data, but with write options set to limit the nesting
  // level (max_recursion_depth = 1).
  // Each incoming chunk is echoed in the same shape: data with metadata, data
  // only, or metadata only. The loop ends at the first chunk with neither.
  Status DoExchange(const ServerCallContext& context,
                    std::unique_ptr<FlightMessageReader> reader,
                    std::unique_ptr<FlightMessageWriter> writer) override {
    FlightStreamChunk chunk;
    auto options = ipc::IpcWriteOptions::Defaults();
    options.max_recursion_depth = 1;
    bool begun = false;
    while (true) {
      RETURN_NOT_OK(reader->Next(&chunk));
      if (!chunk.data && !chunk.app_metadata) {
        break;
      }
      // Begin() is called lazily, with the schema of the first data chunk and
      // the restricted options. Metadata-only chunks seen earlier are echoed
      // without it.
      if (!begun && chunk.data) {
        begun = true;
        RETURN_NOT_OK(writer->Begin(chunk.data->schema(), options));
      }
      if (chunk.data && chunk.app_metadata) {
        RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata));
      } else if (chunk.data) {
        RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data));
      } else if (chunk.app_metadata) {
        RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata));
      }
    }
    return Status::OK();
  }
};
605 
// Minimal server whose ListFlights always succeeds without assigning
// *listings (the out-parameter is left untouched).
class HeaderAuthTestServer : public FlightServerBase {
 public:
  Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
                     std::unique_ptr<FlightListing>* listings) override {
    return Status::OK();
  }
};
613 
614 class TestMetadata : public ::testing::Test {
615  public:
SetUp()616   void SetUp() {
617     ASSERT_OK(MakeServer<MetadataTestServer>(
618         &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
619         [](FlightClientOptions* options) { return Status::OK(); }));
620   }
621 
TearDown()622   void TearDown() { ASSERT_OK(server_->Shutdown()); }
623 
624  protected:
625   std::unique_ptr<FlightClient> client_;
626   std::unique_ptr<FlightServerBase> server_;
627 };
628 
629 class TestOptions : public ::testing::Test {
630  public:
SetUp()631   void SetUp() {
632     ASSERT_OK(MakeServer<OptionsTestServer>(
633         &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
634         [](FlightClientOptions* options) { return Status::OK(); }));
635   }
636 
TearDown()637   void TearDown() { ASSERT_OK(server_->Shutdown()); }
638 
639  protected:
640   std::unique_ptr<FlightClient> client_;
641   std::unique_ptr<FlightServerBase> server_;
642 };
643 
644 class TestAuthHandler : public ::testing::Test {
645  public:
SetUp()646   void SetUp() {
647     ASSERT_OK(MakeServer<AuthTestServer>(
648         &server_, &client_,
649         [](FlightServerOptions* options) {
650           options->auth_handler = std::unique_ptr<ServerAuthHandler>(
651               new TestServerAuthHandler("user", "p4ssw0rd"));
652           return Status::OK();
653         },
654         [](FlightClientOptions* options) { return Status::OK(); }));
655   }
656 
TearDown()657   void TearDown() { ASSERT_OK(server_->Shutdown()); }
658 
659  protected:
660   std::unique_ptr<FlightClient> client_;
661   std::unique_ptr<FlightServerBase> server_;
662 };
663 
664 class TestBasicAuthHandler : public ::testing::Test {
665  public:
SetUp()666   void SetUp() {
667     ASSERT_OK(MakeServer<AuthTestServer>(
668         &server_, &client_,
669         [](FlightServerOptions* options) {
670           options->auth_handler = std::unique_ptr<ServerAuthHandler>(
671               new TestServerBasicAuthHandler("user", "p4ssw0rd"));
672           return Status::OK();
673         },
674         [](FlightClientOptions* options) { return Status::OK(); }));
675   }
676 
TearDown()677   void TearDown() { ASSERT_OK(server_->Shutdown()); }
678 
679  protected:
680   std::unique_ptr<FlightClient> client_;
681   std::unique_ptr<FlightServerBase> server_;
682 };
683 
684 class TestDoPut : public ::testing::Test {
685  public:
SetUp()686   void SetUp() {
687     ASSERT_OK(MakeServer<DoPutTestServer>(
688         &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
689         [](FlightClientOptions* options) { return Status::OK(); }));
690     do_put_server_ = (DoPutTestServer*)server_.get();
691   }
692 
TearDown()693   void TearDown() { ASSERT_OK(server_->Shutdown()); }
694 
CheckBatches(FlightDescriptor expected_descriptor,const BatchVector & expected_batches)695   void CheckBatches(FlightDescriptor expected_descriptor,
696                     const BatchVector& expected_batches) {
697     ASSERT_TRUE(do_put_server_->descriptor_.Equals(expected_descriptor));
698     ASSERT_EQ(do_put_server_->batches_.size(), expected_batches.size());
699     for (size_t i = 0; i < expected_batches.size(); ++i) {
700       ASSERT_BATCHES_EQUAL(*do_put_server_->batches_[i], *expected_batches[i]);
701     }
702   }
703 
CheckDoPut(FlightDescriptor descr,const std::shared_ptr<Schema> & schema,const BatchVector & batches)704   void CheckDoPut(FlightDescriptor descr, const std::shared_ptr<Schema>& schema,
705                   const BatchVector& batches) {
706     std::unique_ptr<FlightStreamWriter> stream;
707     std::unique_ptr<FlightMetadataReader> reader;
708     ASSERT_OK(client_->DoPut(descr, schema, &stream, &reader));
709     for (const auto& batch : batches) {
710       ASSERT_OK(stream->WriteRecordBatch(*batch));
711     }
712     ASSERT_OK(stream->DoneWriting());
713     ASSERT_OK(stream->Close());
714 
715     CheckBatches(descr, batches);
716   }
717 
718  protected:
719   std::unique_ptr<FlightClient> client_;
720   std::unique_ptr<FlightServerBase> server_;
721   DoPutTestServer* do_put_server_;
722 };
723 
724 class TestTls : public ::testing::Test {
725  public:
SetUp()726   void SetUp() {
727     // Manually initialize gRPC to try to ensure some thread-locals
728     // get initialized.
729     // https://github.com/grpc/grpc/issues/13856
730     // https://github.com/grpc/grpc/issues/20311
731     // In general, gRPC on MacOS struggles with TLS (both in the sense
732     // of thread-locals and encryption)
733     grpc_init();
734 
735     server_.reset(new TlsTestServer);
736 
737     Location location;
738     ASSERT_OK(Location::ForGrpcTls("localhost", 0, &location));
739     FlightServerOptions options(location);
740     ASSERT_RAISES(UnknownError, server_->Init(options));
741     ASSERT_OK(ExampleTlsCertificates(&options.tls_certificates));
742     ASSERT_OK(server_->Init(options));
743 
744     ASSERT_OK(Location::ForGrpcTls("localhost", server_->port(), &location_));
745     ASSERT_OK(ConnectClient());
746   }
747 
TearDown()748   void TearDown() {
749     ASSERT_OK(server_->Shutdown());
750     grpc_shutdown();
751   }
752 
ConnectClient()753   Status ConnectClient() {
754     auto options = FlightClientOptions::Defaults();
755     CertKeyPair root_cert;
756     RETURN_NOT_OK(ExampleTlsCertificateRoot(&root_cert));
757     options.tls_root_certs = root_cert.pem_cert;
758     return FlightClient::Connect(location_, options, &client_);
759   }
760 
761  protected:
762   Location location_;
763   std::unique_ptr<FlightClient> client_;
764   std::unique_ptr<FlightServerBase> server_;
765 };
766 
767 // A server middleware that rejects all calls.
768 class RejectServerMiddlewareFactory : public ServerMiddlewareFactory {
StartCall(const CallInfo & info,const CallHeaders & incoming_headers,std::shared_ptr<ServerMiddleware> * middleware)769   Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
770                    std::shared_ptr<ServerMiddleware>* middleware) override {
771     return MakeFlightError(FlightStatusCode::Unauthenticated, "All calls are rejected");
772   }
773 };
774 
775 // A server middleware that counts the number of successful and failed
776 // calls.
777 class CountingServerMiddleware : public ServerMiddleware {
778  public:
CountingServerMiddleware(std::atomic<int> * successful,std::atomic<int> * failed)779   CountingServerMiddleware(std::atomic<int>* successful, std::atomic<int>* failed)
780       : successful_(successful), failed_(failed) {}
SendingHeaders(AddCallHeaders * outgoing_headers)781   void SendingHeaders(AddCallHeaders* outgoing_headers) override {}
CallCompleted(const Status & status)782   void CallCompleted(const Status& status) override {
783     if (status.ok()) {
784       ARROW_IGNORE_EXPR((*successful_)++);
785     } else {
786       ARROW_IGNORE_EXPR((*failed_)++);
787     }
788   }
789 
name() const790   std::string name() const override { return "CountingServerMiddleware"; }
791 
792  private:
793   std::atomic<int>* successful_;
794   std::atomic<int>* failed_;
795 };
796 
797 class CountingServerMiddlewareFactory : public ServerMiddlewareFactory {
798  public:
CountingServerMiddlewareFactory()799   CountingServerMiddlewareFactory() : successful_(0), failed_(0) {}
800 
StartCall(const CallInfo & info,const CallHeaders & incoming_headers,std::shared_ptr<ServerMiddleware> * middleware)801   Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
802                    std::shared_ptr<ServerMiddleware>* middleware) override {
803     *middleware = std::make_shared<CountingServerMiddleware>(&successful_, &failed_);
804     return Status::OK();
805   }
806 
807   std::atomic<int> successful_;
808   std::atomic<int> failed_;
809 };
810 
// Per-thread span ID, emulating OpenTracing-style distributed tracing. It is
// the channel through which application code hands the ID to the client
// middleware, which cannot receive parameters directly.
static thread_local std::string current_span_id;
815 
816 // A server middleware that stores the current span ID, in an
817 // emulation of OpenTracing style distributed tracing.
818 class TracingServerMiddleware : public ServerMiddleware {
819  public:
TracingServerMiddleware(const std::string & current_span_id)820   explicit TracingServerMiddleware(const std::string& current_span_id)
821       : span_id(current_span_id) {}
SendingHeaders(AddCallHeaders * outgoing_headers)822   void SendingHeaders(AddCallHeaders* outgoing_headers) override {}
CallCompleted(const Status & status)823   void CallCompleted(const Status& status) override {}
824 
name() const825   std::string name() const override { return "TracingServerMiddleware"; }
826 
827   std::string span_id;
828 };
829 
830 class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
831  public:
TracingServerMiddlewareFactory()832   TracingServerMiddlewareFactory() {}
833 
StartCall(const CallInfo & info,const CallHeaders & incoming_headers,std::shared_ptr<ServerMiddleware> * middleware)834   Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
835                    std::shared_ptr<ServerMiddleware>* middleware) override {
836     const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& iter_pair =
837         incoming_headers.equal_range("x-tracing-span-id");
838     if (iter_pair.first != iter_pair.second) {
839       const util::string_view& value = (*iter_pair.first).second;
840       *middleware = std::make_shared<TracingServerMiddleware>(std::string(value));
841     }
842     return Status::OK();
843   }
844 };
845 
846 // Function to look in CallHeaders for a key that has a value starting with prefix and
847 // return the rest of the value after the prefix.
FindKeyValPrefixInCallHeaders(const CallHeaders & incoming_headers,const std::string & key,const std::string & prefix)848 std::string FindKeyValPrefixInCallHeaders(const CallHeaders& incoming_headers,
849                                           const std::string& key,
850                                           const std::string& prefix) {
851   // Lambda function to compare characters without case sensitivity.
852   auto char_compare = [](const char& char1, const char& char2) {
853     return (::toupper(char1) == ::toupper(char2));
854   };
855 
856   auto iter = incoming_headers.find(key);
857   if (iter == incoming_headers.end()) {
858     return "";
859   }
860   const std::string val = iter->second.to_string();
861   if (val.size() > prefix.length()) {
862     if (std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(),
863                    char_compare)) {
864       return val.substr(prefix.length());
865     }
866   }
867   return "";
868 }
869 
870 class HeaderAuthServerMiddleware : public ServerMiddleware {
871  public:
SendingHeaders(AddCallHeaders * outgoing_headers)872   void SendingHeaders(AddCallHeaders* outgoing_headers) override {
873     outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerToken);
874   }
875 
CallCompleted(const Status & status)876   void CallCompleted(const Status& status) override {}
877 
name() const878   std::string name() const override { return "HeaderAuthServerMiddleware"; }
879 };
880 
ParseBasicHeader(const CallHeaders & incoming_headers,std::string & username,std::string & password)881 void ParseBasicHeader(const CallHeaders& incoming_headers, std::string& username,
882                       std::string& password) {
883   std::string encoded_credentials =
884       FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix);
885   std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials));
886   std::getline(decoded_stream, username, ':');
887   std::getline(decoded_stream, password, ':');
888 }
889 
890 // Factory for base64 header authentication testing.
891 class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
892  public:
HeaderAuthServerMiddlewareFactory()893   HeaderAuthServerMiddlewareFactory() {}
894 
StartCall(const CallInfo & info,const CallHeaders & incoming_headers,std::shared_ptr<ServerMiddleware> * middleware)895   Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
896                    std::shared_ptr<ServerMiddleware>* middleware) override {
897     std::string username, password;
898     ParseBasicHeader(incoming_headers, username, password);
899     if ((username == kValidUsername) && (password == kValidPassword)) {
900       *middleware = std::make_shared<HeaderAuthServerMiddleware>();
901     } else if ((username == kInvalidUsername) && (password == kInvalidPassword)) {
902       return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid credentials");
903     }
904     return Status::OK();
905   }
906 };
907 
908 // A server middleware for validating incoming bearer header authentication.
909 class BearerAuthServerMiddleware : public ServerMiddleware {
910  public:
BearerAuthServerMiddleware(const CallHeaders & incoming_headers,bool * isValid)911   explicit BearerAuthServerMiddleware(const CallHeaders& incoming_headers, bool* isValid)
912       : isValid_(isValid) {
913     incoming_headers_ = incoming_headers;
914   }
915 
SendingHeaders(AddCallHeaders * outgoing_headers)916   void SendingHeaders(AddCallHeaders* outgoing_headers) override {
917     std::string bearer_token =
918         FindKeyValPrefixInCallHeaders(incoming_headers_, kAuthHeader, kBearerPrefix);
919     *isValid_ = (bearer_token == std::string(kBearerToken));
920   }
921 
CallCompleted(const Status & status)922   void CallCompleted(const Status& status) override {}
923 
name() const924   std::string name() const override { return "BearerAuthServerMiddleware"; }
925 
926  private:
927   CallHeaders incoming_headers_;
928   bool* isValid_;
929 };
930 
931 // Factory for base64 header authentication testing.
932 class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
933  public:
BearerAuthServerMiddlewareFactory()934   BearerAuthServerMiddlewareFactory() : isValid_(false) {}
935 
StartCall(const CallInfo & info,const CallHeaders & incoming_headers,std::shared_ptr<ServerMiddleware> * middleware)936   Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
937                    std::shared_ptr<ServerMiddleware>* middleware) override {
938     const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& iter_pair =
939         incoming_headers.equal_range(kAuthHeader);
940     if (iter_pair.first != iter_pair.second) {
941       *middleware =
942           std::make_shared<BearerAuthServerMiddleware>(incoming_headers, &isValid_);
943     }
944     return Status::OK();
945   }
946 
GetIsValid()947   bool GetIsValid() { return isValid_; }
948 
949  private:
950   bool isValid_;
951 };
952 
953 // A client middleware that adds a thread-local "request ID" to
954 // outgoing calls as a header, and keeps track of the status of
955 // completed calls. NOT thread-safe.
956 class PropagatingClientMiddleware : public ClientMiddleware {
957  public:
PropagatingClientMiddleware(std::atomic<int> * received_headers,std::vector<Status> * recorded_status)958   explicit PropagatingClientMiddleware(std::atomic<int>* received_headers,
959                                        std::vector<Status>* recorded_status)
960       : received_headers_(received_headers), recorded_status_(recorded_status) {}
961 
SendingHeaders(AddCallHeaders * outgoing_headers)962   void SendingHeaders(AddCallHeaders* outgoing_headers) {
963     // Pick up the span ID from thread locals. We have to use a
964     // thread-local for communication, since we aren't even
965     // instantiated until after the application code has already
966     // started the call (and so there's no chance for application code
967     // to pass us parameters directly).
968     outgoing_headers->AddHeader("x-tracing-span-id", current_span_id);
969   }
970 
ReceivedHeaders(const CallHeaders & incoming_headers)971   void ReceivedHeaders(const CallHeaders& incoming_headers) { (*received_headers_)++; }
972 
CallCompleted(const Status & status)973   void CallCompleted(const Status& status) { recorded_status_->push_back(status); }
974 
975  private:
976   std::atomic<int>* received_headers_;
977   std::vector<Status>* recorded_status_;
978 };
979 
980 class PropagatingClientMiddlewareFactory : public ClientMiddlewareFactory {
981  public:
StartCall(const CallInfo & info,std::unique_ptr<ClientMiddleware> * middleware)982   void StartCall(const CallInfo& info, std::unique_ptr<ClientMiddleware>* middleware) {
983     recorded_calls_.push_back(info.method);
984     *middleware = arrow::internal::make_unique<PropagatingClientMiddleware>(
985         &received_headers_, &recorded_status_);
986   }
987 
Reset()988   void Reset() {
989     recorded_calls_.clear();
990     recorded_status_.clear();
991     received_headers_.fetch_and(0);
992   }
993 
994   std::vector<FlightMethod> recorded_calls_;
995   std::vector<Status> recorded_status_;
996   std::atomic<int> received_headers_;
997 };
998 
999 class ReportContextTestServer : public FlightServerBase {
DoAction(const ServerCallContext & context,const Action & action,std::unique_ptr<ResultStream> * result)1000   Status DoAction(const ServerCallContext& context, const Action& action,
1001                   std::unique_ptr<ResultStream>* result) override {
1002     std::shared_ptr<Buffer> buf;
1003     const ServerMiddleware* middleware = context.GetMiddleware("tracing");
1004     if (middleware == nullptr || middleware->name() != "TracingServerMiddleware") {
1005       buf = Buffer::FromString("");
1006     } else {
1007       buf = Buffer::FromString(((const TracingServerMiddleware*)middleware)->span_id);
1008     }
1009     *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
1010     return Status::OK();
1011   }
1012 };
1013 
1014 class ErrorMiddlewareServer : public FlightServerBase {
DoAction(const ServerCallContext & context,const Action & action,std::unique_ptr<ResultStream> * result)1015   Status DoAction(const ServerCallContext& context, const Action& action,
1016                   std::unique_ptr<ResultStream>* result) override {
1017     std::string msg = "error_message";
1018     auto buf = Buffer::FromString("");
1019 
1020     std::shared_ptr<FlightStatusDetail> flightStatusDetail(
1021         new FlightStatusDetail(FlightStatusCode::Failed, msg));
1022     *result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
1023     return Status(StatusCode::ExecutionError, "test failed", flightStatusDetail);
1024   }
1025 };
1026 
1027 class PropagatingTestServer : public FlightServerBase {
1028  public:
PropagatingTestServer(std::unique_ptr<FlightClient> client)1029   explicit PropagatingTestServer(std::unique_ptr<FlightClient> client)
1030       : client_(std::move(client)) {}
1031 
DoAction(const ServerCallContext & context,const Action & action,std::unique_ptr<ResultStream> * result)1032   Status DoAction(const ServerCallContext& context, const Action& action,
1033                   std::unique_ptr<ResultStream>* result) override {
1034     const ServerMiddleware* middleware = context.GetMiddleware("tracing");
1035     if (middleware == nullptr || middleware->name() != "TracingServerMiddleware") {
1036       current_span_id = "";
1037     } else {
1038       current_span_id = ((const TracingServerMiddleware*)middleware)->span_id;
1039     }
1040 
1041     return client_->DoAction(action, result);
1042   }
1043 
1044  private:
1045   std::unique_ptr<FlightClient> client_;
1046 };
1047 
1048 class TestRejectServerMiddleware : public ::testing::Test {
1049  public:
SetUp()1050   void SetUp() {
1051     ASSERT_OK(MakeServer<MetadataTestServer>(
1052         &server_, &client_,
1053         [](FlightServerOptions* options) {
1054           options->middleware.push_back(
1055               {"reject", std::make_shared<RejectServerMiddlewareFactory>()});
1056           return Status::OK();
1057         },
1058         [](FlightClientOptions* options) { return Status::OK(); }));
1059   }
1060 
TearDown()1061   void TearDown() { ASSERT_OK(server_->Shutdown()); }
1062 
1063  protected:
1064   std::unique_ptr<FlightClient> client_;
1065   std::unique_ptr<FlightServerBase> server_;
1066 };
1067 
1068 class TestCountingServerMiddleware : public ::testing::Test {
1069  public:
SetUp()1070   void SetUp() {
1071     request_counter_ = std::make_shared<CountingServerMiddlewareFactory>();
1072     ASSERT_OK(MakeServer<MetadataTestServer>(
1073         &server_, &client_,
1074         [&](FlightServerOptions* options) {
1075           options->middleware.push_back({"request_counter", request_counter_});
1076           return Status::OK();
1077         },
1078         [](FlightClientOptions* options) { return Status::OK(); }));
1079   }
1080 
TearDown()1081   void TearDown() { ASSERT_OK(server_->Shutdown()); }
1082 
1083  protected:
1084   std::shared_ptr<CountingServerMiddlewareFactory> request_counter_;
1085   std::unique_ptr<FlightClient> client_;
1086   std::unique_ptr<FlightServerBase> server_;
1087 };
1088 
1089 // Setup for this test is 2 servers
1090 // 1. Client makes request to server A with a request ID set
1091 // 2. server A extracts the request ID and makes a request to server B
1092 //    with the same request ID set
1093 // 3. server B extracts the request ID and sends it back
1094 // 4. server A returns the response of server B
1095 // 5. Client validates the response
1096 class TestPropagatingMiddleware : public ::testing::Test {
1097  public:
SetUp()1098   void SetUp() {
1099     server_middleware_ = std::make_shared<TracingServerMiddlewareFactory>();
1100     second_client_middleware_ = std::make_shared<PropagatingClientMiddlewareFactory>();
1101     client_middleware_ = std::make_shared<PropagatingClientMiddlewareFactory>();
1102 
1103     std::unique_ptr<FlightClient> server_client;
1104     ASSERT_OK(MakeServer<ReportContextTestServer>(
1105         &second_server_, &server_client,
1106         [&](FlightServerOptions* options) {
1107           options->middleware.push_back({"tracing", server_middleware_});
1108           return Status::OK();
1109         },
1110         [&](FlightClientOptions* options) {
1111           options->middleware.push_back(second_client_middleware_);
1112           return Status::OK();
1113         }));
1114 
1115     ASSERT_OK(MakeServer<PropagatingTestServer>(
1116         &first_server_, &client_,
1117         [&](FlightServerOptions* options) {
1118           options->middleware.push_back({"tracing", server_middleware_});
1119           return Status::OK();
1120         },
1121         [&](FlightClientOptions* options) {
1122           options->middleware.push_back(client_middleware_);
1123           return Status::OK();
1124         },
1125         std::move(server_client)));
1126   }
1127 
ValidateStatus(const Status & status,const FlightMethod & method)1128   void ValidateStatus(const Status& status, const FlightMethod& method) {
1129     ASSERT_EQ(1, client_middleware_->received_headers_);
1130     ASSERT_EQ(method, client_middleware_->recorded_calls_.at(0));
1131     ASSERT_EQ(status.code(), client_middleware_->recorded_status_.at(0).code());
1132   }
1133 
TearDown()1134   void TearDown() {
1135     ASSERT_OK(first_server_->Shutdown());
1136     ASSERT_OK(second_server_->Shutdown());
1137   }
1138 
CheckHeader(const std::string & header,const std::string & value,const CallHeaders::const_iterator & it)1139   void CheckHeader(const std::string& header, const std::string& value,
1140                    const CallHeaders::const_iterator& it) {
1141     // Construct a string_view before comparison to satisfy MSVC
1142     util::string_view header_view(header.data(), header.length());
1143     util::string_view value_view(value.data(), value.length());
1144     ASSERT_EQ(header_view, (*it).first);
1145     ASSERT_EQ(value_view, (*it).second);
1146   }
1147 
1148  protected:
1149   std::unique_ptr<FlightClient> client_;
1150   std::unique_ptr<FlightServerBase> first_server_;
1151   std::unique_ptr<FlightServerBase> second_server_;
1152   std::shared_ptr<TracingServerMiddlewareFactory> server_middleware_;
1153   std::shared_ptr<PropagatingClientMiddlewareFactory> second_client_middleware_;
1154   std::shared_ptr<PropagatingClientMiddlewareFactory> client_middleware_;
1155 };
1156 
1157 class TestErrorMiddleware : public ::testing::Test {
1158  public:
SetUp()1159   void SetUp() {
1160     ASSERT_OK(MakeServer<ErrorMiddlewareServer>(
1161         &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
1162         [](FlightClientOptions* options) { return Status::OK(); }));
1163   }
1164 
TearDown()1165   void TearDown() { ASSERT_OK(server_->Shutdown()); }
1166 
1167  protected:
1168   std::unique_ptr<FlightClient> client_;
1169   std::unique_ptr<FlightServerBase> server_;
1170 };
1171 
1172 class TestBasicHeaderAuthMiddleware : public ::testing::Test {
1173  public:
SetUp()1174   void SetUp() {
1175     header_middleware_ = std::make_shared<HeaderAuthServerMiddlewareFactory>();
1176     bearer_middleware_ = std::make_shared<BearerAuthServerMiddlewareFactory>();
1177     std::pair<std::string, std::string> bearer = make_pair(
1178         kAuthHeader, std::string(kBearerPrefix) + " " + std::string(kBearerToken));
1179     ASSERT_OK(MakeServer<HeaderAuthTestServer>(
1180         &server_, &client_,
1181         [&](FlightServerOptions* options) {
1182           options->auth_handler =
1183               std::unique_ptr<ServerAuthHandler>(new NoOpAuthHandler());
1184           options->middleware.push_back({"header-auth-server", header_middleware_});
1185           options->middleware.push_back({"bearer-auth-server", bearer_middleware_});
1186           return Status::OK();
1187         },
1188         [&](FlightClientOptions* options) { return Status::OK(); }));
1189   }
1190 
RunValidClientAuth()1191   void RunValidClientAuth() {
1192     arrow::Result<std::pair<std::string, std::string>> bearer_result =
1193         client_->AuthenticateBasicToken({}, kValidUsername, kValidPassword);
1194     ASSERT_OK(bearer_result.status());
1195     ASSERT_EQ(bearer_result.ValueOrDie().first, kAuthHeader);
1196     ASSERT_EQ(bearer_result.ValueOrDie().second,
1197               (std::string(kBearerPrefix) + kBearerToken));
1198     std::unique_ptr<FlightListing> listing;
1199     FlightCallOptions call_options;
1200     call_options.headers.push_back(bearer_result.ValueOrDie());
1201     ASSERT_OK(client_->ListFlights(call_options, {}, &listing));
1202     ASSERT_TRUE(bearer_middleware_->GetIsValid());
1203   }
1204 
RunInvalidClientAuth()1205   void RunInvalidClientAuth() {
1206     arrow::Result<std::pair<std::string, std::string>> bearer_result =
1207         client_->AuthenticateBasicToken({}, kInvalidUsername, kInvalidPassword);
1208     ASSERT_RAISES(IOError, bearer_result.status());
1209     ASSERT_THAT(bearer_result.status().message(),
1210                 ::testing::HasSubstr("Invalid credentials"));
1211   }
1212 
TearDown()1213   void TearDown() { ASSERT_OK(server_->Shutdown()); }
1214 
1215  protected:
1216   std::unique_ptr<FlightClient> client_;
1217   std::unique_ptr<FlightServerBase> server_;
1218   std::shared_ptr<HeaderAuthServerMiddlewareFactory> header_middleware_;
1219   std::shared_ptr<BearerAuthServerMiddlewareFactory> bearer_middleware_;
1220 };
1221 
1222 // This test keeps an internal cookie cache and compares that with the middleware.
1223 class TestCookieMiddleware : public ::testing::Test {
1224  public:
1225   // Setup function creates middleware factory and starts it up.
SetUp()1226   void SetUp() {
1227     factory_ = GetCookieFactory();
1228     CallInfo callInfo;
1229     factory_->StartCall(callInfo, &middleware_);
1230   }
1231 
1232   // Function to add incoming cookies to middleware and validate them.
AddAndValidate(const std::string & incoming_cookie)1233   void AddAndValidate(const std::string& incoming_cookie) {
1234     // Add cookie
1235     CallHeaders call_headers;
1236     call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"),
1237                                        arrow::util::string_view(incoming_cookie)));
1238     middleware_->ReceivedHeaders(call_headers);
1239     expected_cookie_cache_.UpdateCachedCookies(call_headers);
1240 
1241     // Get cookie from middleware.
1242     TestCallHeaders add_call_headers;
1243     middleware_->SendingHeaders(&add_call_headers);
1244     const std::string actual_cookies = add_call_headers.GetCookies();
1245 
1246     // Validate cookie
1247     const std::string expected_cookies = expected_cookie_cache_.GetValidCookiesAsString();
1248     const std::vector<std::string> split_expected_cookies =
1249         SplitCookies(expected_cookies);
1250     const std::vector<std::string> split_actual_cookies = SplitCookies(actual_cookies);
1251     EXPECT_EQ(split_expected_cookies, split_actual_cookies);
1252   }
1253 
1254   // Function to take a list of cookies and split them into a vector of individual
1255   // cookies. This is done because the cookie cache is a map so ordering is not
1256   // necessarily consistent.
SplitCookies(const std::string & cookies)1257   static std::vector<std::string> SplitCookies(const std::string& cookies) {
1258     std::vector<std::string> split_cookies;
1259     std::string::size_type pos1 = 0;
1260     std::string::size_type pos2 = 0;
1261     while ((pos2 = cookies.find(';', pos1)) != std::string::npos) {
1262       split_cookies.push_back(
1263           arrow::internal::TrimString(cookies.substr(pos1, pos2 - pos1)));
1264       pos1 = pos2 + 1;
1265     }
1266     if (pos1 < cookies.size()) {
1267       split_cookies.push_back(arrow::internal::TrimString(cookies.substr(pos1)));
1268     }
1269     std::sort(split_cookies.begin(), split_cookies.end());
1270     return split_cookies;
1271   }
1272 
1273  protected:
1274   // Class to allow testing of the call headers.
1275   class TestCallHeaders : public AddCallHeaders {
1276    public:
TestCallHeaders()1277     TestCallHeaders() {}
~TestCallHeaders()1278     ~TestCallHeaders() {}
1279 
1280     // Function to add cookie header.
AddHeader(const std::string & key,const std::string & value)1281     void AddHeader(const std::string& key, const std::string& value) {
1282       ASSERT_EQ(key, "cookie");
1283       outbound_cookie_ = value;
1284     }
1285 
1286     // Function to get outgoing cookie.
GetCookies()1287     std::string GetCookies() { return outbound_cookie_; }
1288 
1289    private:
1290     std::string outbound_cookie_;
1291   };
1292 
1293   internal::CookieCache expected_cookie_cache_;
1294   std::unique_ptr<ClientMiddleware> middleware_;
1295   std::shared_ptr<ClientMiddlewareFactory> factory_;
1296 };
1297 
1298 // This test is used to test the parsing capabilities of the cookie framework.
1299 class TestCookieParsing : public ::testing::Test {
1300  public:
VerifyParseCookie(const std::string & cookie_str,bool expired)1301   void VerifyParseCookie(const std::string& cookie_str, bool expired) {
1302     internal::Cookie cookie = internal::Cookie::parse(cookie_str);
1303     EXPECT_EQ(expired, cookie.IsExpired());
1304   }
1305 
VerifyCookieName(const std::string & cookie_str,const std::string & name)1306   void VerifyCookieName(const std::string& cookie_str, const std::string& name) {
1307     internal::Cookie cookie = internal::Cookie::parse(cookie_str);
1308     EXPECT_EQ(name, cookie.GetName());
1309   }
1310 
VerifyCookieString(const std::string & cookie_str,const std::string & cookie_as_string)1311   void VerifyCookieString(const std::string& cookie_str,
1312                           const std::string& cookie_as_string) {
1313     internal::Cookie cookie = internal::Cookie::parse(cookie_str);
1314     EXPECT_EQ(cookie_as_string, cookie.AsCookieString());
1315   }
1316 
VerifyCookieDateConverson(std::string date,const std::string & converted_date)1317   void VerifyCookieDateConverson(std::string date, const std::string& converted_date) {
1318     internal::Cookie::ConvertCookieDate(&date);
1319     EXPECT_EQ(converted_date, date);
1320   }
1321 
VerifyCookieAttributeParsing(const std::string cookie_str,std::string::size_type start_pos,const util::optional<std::pair<std::string,std::string>> cookie_attribute,const std::string::size_type start_pos_after)1322   void VerifyCookieAttributeParsing(
1323       const std::string cookie_str, std::string::size_type start_pos,
1324       const util::optional<std::pair<std::string, std::string>> cookie_attribute,
1325       const std::string::size_type start_pos_after) {
1326     util::optional<std::pair<std::string, std::string>> attr =
1327         internal::Cookie::ParseCookieAttribute(cookie_str, &start_pos);
1328 
1329     if (cookie_attribute == util::nullopt) {
1330       EXPECT_EQ(cookie_attribute, attr);
1331     } else {
1332       EXPECT_EQ(cookie_attribute.value(), attr.value());
1333     }
1334     EXPECT_EQ(start_pos_after, start_pos);
1335   }
1336 
AddCookieVerifyCache(const std::vector<std::string> & cookies,const std::string & expected_cookies)1337   void AddCookieVerifyCache(const std::vector<std::string>& cookies,
1338                             const std::string& expected_cookies) {
1339     internal::CookieCache cookie_cache;
1340     for (auto& cookie : cookies) {
1341       // Add cookie
1342       CallHeaders call_headers;
1343       call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"),
1344                                          arrow::util::string_view(cookie)));
1345       cookie_cache.UpdateCachedCookies(call_headers);
1346     }
1347     const std::string actual_cookies = cookie_cache.GetValidCookiesAsString();
1348     const std::vector<std::string> actual_split_cookies =
1349         TestCookieMiddleware::SplitCookies(actual_cookies);
1350     const std::vector<std::string> expected_split_cookies =
1351         TestCookieMiddleware::SplitCookies(expected_cookies);
1352   }
1353 };
1354 
TEST_F(TestErrorMiddleware,TestMetadata)1355 TEST_F(TestErrorMiddleware, TestMetadata) {
1356   Action action;
1357   std::unique_ptr<ResultStream> stream;
1358 
1359   // Run action1
1360   action.type = "action1";
1361 
1362   action.body = Buffer::FromString("action1-content");
1363   Status s = client_->DoAction(action, &stream);
1364   ASSERT_FALSE(s.ok());
1365   std::shared_ptr<FlightStatusDetail> flightStatusDetail =
1366       FlightStatusDetail::UnwrapStatus(s);
1367   ASSERT_TRUE(flightStatusDetail);
1368   ASSERT_EQ(flightStatusDetail->extra_info(), "error_message");
1369 }
1370 
TEST_F(TestFlightClient,ListFlights)1371 TEST_F(TestFlightClient, ListFlights) {
1372   std::unique_ptr<FlightListing> listing;
1373   ASSERT_OK(client_->ListFlights(&listing));
1374   ASSERT_TRUE(listing != nullptr);
1375 
1376   std::vector<FlightInfo> flights = ExampleFlightInfo();
1377 
1378   std::unique_ptr<FlightInfo> info;
1379   for (const FlightInfo& flight : flights) {
1380     ASSERT_OK(listing->Next(&info));
1381     AssertEqual(flight, *info);
1382   }
1383   ASSERT_OK(listing->Next(&info));
1384   ASSERT_TRUE(info == nullptr);
1385 
1386   ASSERT_OK(listing->Next(&info));
1387   ASSERT_TRUE(info == nullptr);
1388 }
1389 
TEST_F(TestFlightClient,ListFlightsWithCriteria)1390 TEST_F(TestFlightClient, ListFlightsWithCriteria) {
1391   std::unique_ptr<FlightListing> listing;
1392   ASSERT_OK(client_->ListFlights(FlightCallOptions(), {"foo"}, &listing));
1393   std::unique_ptr<FlightInfo> info;
1394   ASSERT_OK(listing->Next(&info));
1395   ASSERT_TRUE(info == nullptr);
1396 }
1397 
TEST_F(TestFlightClient,GetFlightInfo)1398 TEST_F(TestFlightClient, GetFlightInfo) {
1399   auto descr = FlightDescriptor::Path({"examples", "ints"});
1400   std::unique_ptr<FlightInfo> info;
1401 
1402   ASSERT_OK(client_->GetFlightInfo(descr, &info));
1403   ASSERT_NE(info, nullptr);
1404 
1405   std::vector<FlightInfo> flights = ExampleFlightInfo();
1406   AssertEqual(flights[0], *info);
1407 }
1408 
TEST_F(TestFlightClient,GetSchema)1409 TEST_F(TestFlightClient, GetSchema) {
1410   auto descr = FlightDescriptor::Path({"examples", "ints"});
1411   std::unique_ptr<SchemaResult> schema_result;
1412   std::shared_ptr<Schema> schema;
1413   ipc::DictionaryMemo dict_memo;
1414 
1415   ASSERT_OK(client_->GetSchema(descr, &schema_result));
1416   ASSERT_NE(schema_result, nullptr);
1417   ASSERT_OK(schema_result->GetSchema(&dict_memo, &schema));
1418 }
1419 
// Requesting FlightInfo for an unknown path fails with a "Flight not found"
// message.
TEST_F(TestFlightClient, GetFlightInfoNotFound) {
  auto descr = FlightDescriptor::Path({"examples", "things"});
  std::unique_ptr<FlightInfo> info;
  // XXX Ideally should be Invalid (or KeyError), but gRPC doesn't support
  // multiple error codes.
  auto st = client_->GetFlightInfo(descr, &info);
  ASSERT_RAISES(Invalid, st);
  // Use a matcher (as the other tests do) so that a failure prints the
  // actual message instead of just "false".
  ASSERT_THAT(st.message(), ::testing::HasSubstr("Flight not found"));
}
1429 
// DoGet on "examples/ints" returns the example int batches.
TEST_F(TestFlightClient, DoGetInts) {
  auto descriptor = FlightDescriptor::Path({"examples", "ints"});
  BatchVector expected;
  ASSERT_OK(ExampleIntBatches(&expected));

  // The example FlightInfo advertises two endpoints; the first carries the
  // ticket for the int data.
  auto verify_endpoints = [](const std::vector<FlightEndpoint>& eps) {
    ASSERT_EQ(2, eps.size());
    AssertEqual(Ticket{"ticket-ints-1"}, eps[0].ticket);
  };

  CheckDoGet(descriptor, expected, verify_endpoints);
}
1443 
// DoGet on "examples/floats" returns the example float batches.
TEST_F(TestFlightClient, DoGetFloats) {
  auto descriptor = FlightDescriptor::Path({"examples", "floats"});
  BatchVector expected;
  ASSERT_OK(ExampleFloatBatches(&expected));

  // Exactly one endpoint is advertised, carrying the floats ticket.
  auto verify_endpoints = [](const std::vector<FlightEndpoint>& eps) {
    ASSERT_EQ(1, eps.size());
    AssertEqual(Ticket{"ticket-floats-1"}, eps[0].ticket);
  };

  CheckDoGet(descriptor, expected, verify_endpoints);
}
1457 
// DoGet on "examples/dicts" returns the example dictionary batches.
TEST_F(TestFlightClient, DoGetDicts) {
  auto descriptor = FlightDescriptor::Path({"examples", "dicts"});
  BatchVector expected;
  ASSERT_OK(ExampleDictBatches(&expected));

  // Exactly one endpoint is advertised, carrying the dicts ticket.
  auto verify_endpoints = [](const std::vector<FlightEndpoint>& eps) {
    ASSERT_EQ(1, eps.size());
    AssertEqual(Ticket{"ticket-dicts-1"}, eps[0].ticket);
  };

  CheckDoGet(descriptor, expected, verify_endpoints);
}
1471 
1472 // Ensure the gRPC client is configured to allow large messages
1473 // Tests a 32 MiB batch
TEST_F(TestFlightClient, DoGetLargeBatch) {
  // Fetch the large example batches directly by ticket.
  Ticket large_ticket{"ticket-large-batch-1"};
  BatchVector expected;
  ASSERT_OK(ExampleLargeBatches(&expected));
  CheckDoGet(large_ticket, expected);
}
1480 
// Regression test for ARROW-13253: an oversized batch produced by the server
// must surface to the client as Invalid rather than being sent.
// N.B. this is rather a slow and memory-hungry test
TEST_F(TestFlightClient, FlightDataOverflowServerBatch) {
  const auto exceeds_limit =
      ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet");
  {
    // DoGet: the error is reported when reading the oversized batch.
    Ticket overflow_ticket{"ARROW-13253-DoGet-Batch"};
    std::unique_ptr<FlightStreamReader> stream;
    ASSERT_OK(client_->DoGet(overflow_ticket, &stream));
    FlightStreamChunk chunk;
    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, exceeds_limit, stream->Next(&chunk));
  }
  {
    // DoExchange: the error is reported while draining the server's stream.
    auto command = FlightDescriptor::Command("large_batch");
    std::unique_ptr<FlightStreamReader> reader;
    std::unique_ptr<FlightStreamWriter> writer;
    ASSERT_OK(client_->DoExchange(command, &writer, &reader));
    BatchVector received;
    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, exceeds_limit, reader->ReadAll(&received));
  }
}
1506 
// A batch that is too large for the client to send is rejected locally with
// Invalid; the stream can still be closed cleanly afterwards.
TEST_F(TestFlightClient, FlightDataOverflowClientBatch) {
  ASSERT_OK_AND_ASSIGN(auto huge_batch, VeryLargeBatch());
  const auto exceeds_limit =
      ::testing::HasSubstr("Cannot send record batches exceeding 2GiB yet");
  {
    // DoPut path.
    std::unique_ptr<FlightStreamWriter> put_writer;
    std::unique_ptr<FlightMetadataReader> put_reader;
    auto target = FlightDescriptor::Path({""});
    ASSERT_OK(client_->DoPut(target, huge_batch->schema(), &put_writer, &put_reader));
    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, exceeds_limit,
                                    put_writer->WriteRecordBatch(*huge_batch));
    ASSERT_OK(put_writer->Close());
  }
  {
    // DoExchange path.
    auto command = FlightDescriptor::Command("counter");
    std::unique_ptr<FlightStreamReader> exchange_reader;
    std::unique_ptr<FlightStreamWriter> exchange_writer;
    ASSERT_OK(client_->DoExchange(command, &exchange_writer, &exchange_reader));
    ASSERT_OK(exchange_writer->Begin(huge_batch->schema()));
    EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, exceeds_limit,
                                    exchange_writer->WriteRecordBatch(*huge_batch));
    ASSERT_OK(exchange_writer->Close());
  }
}
1533 
// The "counter" exchange first answers with app_metadata holding the number
// of batches it received ("1" here), then streams the same batches back.
TEST_F(TestFlightClient, DoExchange) {
  auto command = FlightDescriptor::Command("counter");
  auto column = ArrayFromJSON(int32(), "[4, 5, 6, null]");
  auto schema = arrow::schema({field("f1", column->type())});
  const BatchVector sent = {RecordBatch::Make(schema, column->length(), {column})};

  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(command, &writer, &reader));
  ASSERT_OK(writer->Begin(schema));
  for (const auto& outgoing : sent) {
    ASSERT_OK(writer->WriteRecordBatch(*outgoing));
  }
  ASSERT_OK(writer->DoneWriting());

  // First message: metadata only, carrying the count.
  FlightStreamChunk incoming;
  ASSERT_OK(reader->Next(&incoming));
  ASSERT_NE(nullptr, incoming.app_metadata);
  ASSERT_EQ(nullptr, incoming.data);
  ASSERT_EQ("1", incoming.app_metadata->ToString());

  ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema());
  AssertSchemaEqual(schema, server_schema);

  // Then the batches themselves.
  for (const auto& outgoing : sent) {
    ASSERT_OK(reader->Next(&incoming));
    ASSERT_BATCHES_EQUAL(*outgoing, *incoming.data);
  }
  ASSERT_OK(writer->Close());
}
1561 
1562 // Test pure-metadata DoExchange to ensure nothing blocks waiting for
1563 // schema messages
TEST_F(TestFlightClient, DoExchangeNoData) {
  // Close the write side immediately without sending a schema; the server
  // must still reply with a metadata-only count of "0".
  auto command = FlightDescriptor::Command("counter");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(command, &writer, &reader));
  ASSERT_OK(writer->DoneWriting());

  FlightStreamChunk reply;
  ASSERT_OK(reader->Next(&reply));
  ASSERT_EQ(nullptr, reply.data);
  ASSERT_NE(nullptr, reply.app_metadata);
  ASSERT_EQ("0", reply.app_metadata->ToString());
  ASSERT_OK(writer->Close());
}
1577 
1578 // Test sending a schema without any data, as this hits an edge case
1579 // in the client-side writer.
TEST_F(TestFlightClient, DoExchangeWriteOnlySchema) {
  // Send a schema and a metadata-only message, but no record batches; the
  // counter still reports "0" batches.
  auto command = FlightDescriptor::Command("counter");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(command, &writer, &reader));

  auto single_int_schema = arrow::schema({field("f1", arrow::int32())});
  ASSERT_OK(writer->Begin(single_int_schema));
  ASSERT_OK(writer->WriteMetadata(Buffer::FromString("foo")));
  ASSERT_OK(writer->DoneWriting());

  FlightStreamChunk reply;
  ASSERT_OK(reader->Next(&reply));
  ASSERT_EQ(nullptr, reply.data);
  ASSERT_NE(nullptr, reply.app_metadata);
  ASSERT_EQ("0", reply.app_metadata->ToString());
  ASSERT_OK(writer->Close());
}
1596 
1597 // Emulate DoGet
TEST_F(TestFlightClient, DoExchangeGet) {
  // The "get" command behaves like DoGet: it streams the example int
  // batches, followed by an empty end-of-stream chunk.
  auto command = FlightDescriptor::Command("get");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(command, &writer, &reader));

  ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema());
  AssertSchemaEqual(*ExampleIntSchema(), *server_schema);

  BatchVector expected;
  ASSERT_OK(ExampleIntBatches(&expected));
  FlightStreamChunk received;
  for (const auto& want : expected) {
    ASSERT_OK(reader->Next(&received));
    ASSERT_NE(nullptr, received.data);
    AssertBatchesEqual(*want, *received.data);
  }

  ASSERT_OK(reader->Next(&received));
  ASSERT_EQ(nullptr, received.data);
  ASSERT_EQ(nullptr, received.app_metadata);
  ASSERT_OK(writer->Close());
}
1618 
1619 // Emulate DoPut
TEST_F(TestFlightClient, DoExchangePut) {
  // The "put" command behaves like DoPut: after the client finishes uploading,
  // the server replies with a single "done" metadata message, then ends.
  auto command = FlightDescriptor::Command("put");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(command, &writer, &reader));
  ASSERT_OK(writer->Begin(ExampleIntSchema()));

  BatchVector uploads;
  ASSERT_OK(ExampleIntBatches(&uploads));
  for (const auto& upload : uploads) {
    ASSERT_OK(writer->WriteRecordBatch(*upload));
  }
  ASSERT_OK(writer->DoneWriting());

  FlightStreamChunk reply;
  ASSERT_OK(reader->Next(&reply));
  ASSERT_NE(nullptr, reply.app_metadata);
  AssertBufferEqual(*reply.app_metadata, "done");

  ASSERT_OK(reader->Next(&reply));
  ASSERT_EQ(nullptr, reply.data);
  ASSERT_EQ(nullptr, reply.app_metadata);
  ASSERT_OK(writer->Close());
}
1641 
1642 // Test the echo server
TEST_F(TestFlightClient, DoExchangeEcho) {
  // The "echo" command sends each message straight back. Data-only,
  // metadata-only, and data+metadata messages must each round-trip intact.
  auto command = FlightDescriptor::Command("echo");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(command, &writer, &reader));
  ASSERT_OK(writer->Begin(ExampleIntSchema()));

  BatchVector batches;
  ASSERT_OK(ExampleIntBatches(&batches));
  FlightStreamChunk echoed;

  // Phase 1: batches without metadata.
  for (const auto& outgoing : batches) {
    ASSERT_OK(writer->WriteRecordBatch(*outgoing));
    ASSERT_OK(reader->Next(&echoed));
    ASSERT_NE(nullptr, echoed.data);
    ASSERT_EQ(nullptr, echoed.app_metadata);
    AssertBatchesEqual(*outgoing, *echoed.data);
  }

  // Phase 2: metadata without batches.
  for (int n = 0; n < 10; ++n) {
    const auto metadata = Buffer::FromString(std::to_string(n));
    ASSERT_OK(writer->WriteMetadata(metadata));
    ASSERT_OK(reader->Next(&echoed));
    ASSERT_EQ(nullptr, echoed.data);
    ASSERT_NE(nullptr, echoed.app_metadata);
    AssertBufferEqual(*metadata, *echoed.app_metadata);
  }

  // Phase 3: each batch paired with its index as metadata.
  for (size_t n = 0; n < batches.size(); ++n) {
    const auto metadata = Buffer::FromString(std::to_string(n));
    ASSERT_OK(writer->WriteWithMetadata(*batches[n], metadata));
    ASSERT_OK(reader->Next(&echoed));
    ASSERT_NE(nullptr, echoed.data);
    ASSERT_NE(nullptr, echoed.app_metadata);
    AssertBatchesEqual(*batches[n], *echoed.data);
    AssertBufferEqual(*metadata, *echoed.app_metadata);
  }

  // Finishing our side ends the echoed stream.
  ASSERT_OK(writer->DoneWriting());
  ASSERT_OK(reader->Next(&echoed));
  ASSERT_EQ(nullptr, echoed.data);
  ASSERT_EQ(nullptr, echoed.app_metadata);
  ASSERT_OK(writer->Close());
}
1684 
1685 // Test interleaved reading/writing
// Interleaved reading/writing against the "total" command: each uploaded
// batch is answered by a one-row batch of running per-column sums.
TEST_F(TestFlightClient, DoExchangeTotal) {
  auto command = FlightDescriptor::Command("total");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;

  // Phase 1: a non-INT64 column is rejected by the server.
  {
    auto int32_column = ArrayFromJSON(arrow::int32(), "[4, 5, 6, null]");
    auto int32_schema = arrow::schema({field("f1", int32_column->type())});
    // XXX: as noted in flight/client.cc, Begin() is lazy and the schema
    // message won't be written until some data is also written. The error
    // may also surface at different steps depending on timing, so chain all
    // steps and check the first failing status.
    auto attempt = [&]() -> Status {
      RETURN_NOT_OK(client_->DoExchange(command, &writer, &reader));
      RETURN_NOT_OK(writer->Begin(int32_schema));
      auto bad_batch = RecordBatch::Make(int32_schema, /* num_rows */ 4, {int32_column});
      RETURN_NOT_OK(writer->WriteRecordBatch(*bad_batch));
      return writer->Close();
    };
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        Invalid, ::testing::HasSubstr("Field is not INT64: f1"), attempt());
  }

  // Phase 2: INT64 columns accumulate; nulls do not contribute to the sums.
  {
    auto left = ArrayFromJSON(arrow::int64(), "[1, 2, null, 3]");
    auto right = ArrayFromJSON(arrow::int64(), "[null, 4, 5, 6]");
    auto schema = arrow::schema({field("f1", left->type()), field("f2", right->type())});
    ASSERT_OK(client_->DoExchange(command, &writer, &reader));
    ASSERT_OK(writer->Begin(schema));
    auto upload = RecordBatch::Make(schema, /* num_rows */ 4, {left, right});

    ASSERT_OK(writer->WriteRecordBatch(*upload));
    ASSERT_OK_AND_ASSIGN(auto server_schema, reader->GetSchema());
    AssertSchemaEqual(*schema, *server_schema);

    FlightStreamChunk totals;
    ASSERT_OK(reader->Next(&totals));
    ASSERT_NE(nullptr, totals.data);
    auto after_first = RecordBatch::Make(
        schema, /* num_rows */ 1,
        {ArrayFromJSON(arrow::int64(), "[6]"), ArrayFromJSON(arrow::int64(), "[15]")});
    AssertBatchesEqual(*after_first, *totals.data);

    ASSERT_OK(writer->WriteRecordBatch(*upload));
    ASSERT_OK(reader->Next(&totals));
    ASSERT_NE(nullptr, totals.data);
    auto after_second = RecordBatch::Make(
        schema, /* num_rows */ 1,
        {ArrayFromJSON(arrow::int64(), "[12]"), ArrayFromJSON(arrow::int64(), "[30]")});
    AssertBatchesEqual(*after_second, *totals.data);

    ASSERT_OK(writer->Close());
  }
}
1736 
1737 // Ensure server errors get propagated no matter what we try
// Ensure server errors get propagated no matter which client call observes
// them first.
TEST_F(TestFlightClient, DoExchangeError) {
  auto descr = FlightDescriptor::Command("error");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  {
    // The error is reported by Close(). Previously the first Close()'s status
    // was captured but never checked; assert it, and keep checking that a
    // second Close() still reports the error.
    ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
    const auto status = writer->Close();
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        NotImplemented, ::testing::HasSubstr("Expected error"), status);
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        NotImplemented, ::testing::HasSubstr("Expected error"), writer->Close());
  }
  {
    ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
    FlightStreamChunk chunk;
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        NotImplemented, ::testing::HasSubstr("Expected error"), reader->Next(&chunk));
  }
  {
    ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        NotImplemented, ::testing::HasSubstr("Expected error"), reader->GetSchema());
  }
  // writer->Begin isn't tested here because, as noted in client.cc,
  // OpenRecordBatchWriter lazily writes the initial message - hence
  // Begin() won't fail. Additionally, it appears gRPC may buffer
  // writes - a write won't immediately fail even when the server
  // immediately fails.
}
1765 
TEST_F(TestFlightClient, ListActions) {
  // The server advertises exactly the example action types.
  std::vector<ActionType> listed;
  ASSERT_OK(client_->ListActions(&listed));
  AssertEqual(ExampleActionTypes(), listed);
}
1773 
TEST_F(TestFlightClient, DoAction) {
  std::unique_ptr<ResultStream> stream;
  std::unique_ptr<Result> result;

  // "action1" answers with three results: the body suffixed -part0..-part2.
  const std::string payload = "action1-content";
  Action action;
  action.type = "action1";
  action.body = Buffer::FromString(payload);
  ASSERT_OK(client_->DoAction(action, &stream));
  for (int part = 0; part < 3; ++part) {
    ASSERT_OK(stream->Next(&result));
    ASSERT_EQ(payload + "-part" + std::to_string(part), result->body->ToString());
  }
  // After that the stream is exhausted.
  ASSERT_OK(stream->Next(&result));
  ASSERT_EQ(nullptr, result);

  // "action2" (body reused from above) produces no results at all.
  action.type = "action2";
  ASSERT_OK(client_->DoAction(action, &stream));
  ASSERT_OK(stream->Next(&result));
  ASSERT_EQ(nullptr, result);
}
1803 
TEST_F(TestFlightClient, RoundTripStatus) {
  // An OutOfMemory status raised by the server reaches the client with its
  // status code preserved.
  std::unique_ptr<FlightInfo> info;
  ASSERT_RAISES(OutOfMemory,
                client_->GetFlightInfo(FlightDescriptor::Command("status-outofmemory"),
                                       &info));
}
1810 
TEST_F(TestFlightClient, Issue5095) {
  // Server-side error codes and messages must be reflected to the client.
  std::unique_ptr<FlightStreamReader> stream;

  const Status failure = client_->DoGet(Ticket{"ARROW-5095-fail"}, &stream);
  ASSERT_RAISES(UnknownError, failure);
  ASSERT_THAT(failure.message(), ::testing::HasSubstr("Server-side error"));

  const Status missing = client_->DoGet(Ticket{"ARROW-5095-success"}, &stream);
  ASSERT_RAISES(KeyError, missing);
  ASSERT_THAT(missing.message(), ::testing::HasSubstr("No data"));
}
1825 
1826 // Test setting generic transport options by configuring gRPC to fail
1827 // all calls.
TEST_F(TestFlightClient, GenericOptions) {
  std::unique_ptr<FlightClient> client;
  auto options = FlightClientOptions::Defaults();
  // Set a very low limit at the gRPC layer to fail all calls
  options.generic_options.emplace_back(GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH, 4);
  Location location;
  ASSERT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location));
  ASSERT_OK(FlightClient::Connect(location, options, &client));

  // The GetSchema reply exceeds the 4-byte receive cap. The test expects the
  // transport's "resource exhausted" failure to reach the caller as Invalid.
  // (The unused `schema` / `dict_memo` locals of the original were removed.)
  auto descr = FlightDescriptor::Path({"examples", "ints"});
  std::unique_ptr<SchemaResult> schema_result;
  auto status = client->GetSchema(descr, &schema_result);
  ASSERT_RAISES(Invalid, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("resource exhausted"));
}
1844 
TEST_F(TestFlightClient, TimeoutFires) {
  // Server does not exist on this port, so call should fail
  // NOTE(review): assumes nothing is listening on localhost:30001.
  std::unique_ptr<FlightClient> client;
  Location location;
  ASSERT_OK(Location::ForGrpcTcp("localhost", 30001, &location));
  ASSERT_OK(FlightClient::Connect(location, &client));
  FlightCallOptions options;
  options.timeout = TimeoutDuration{0.2};
  std::unique_ptr<FlightInfo> info;
  // Measure elapsed time with a monotonic clock. system_clock can jump
  // (NTP adjustments, manual changes), which would make the bound unreliable.
  const auto start = std::chrono::steady_clock::now();
  Status status = client->GetFlightInfo(options, FlightDescriptor{}, &info);
  const auto end = std::chrono::steady_clock::now();
#ifdef ARROW_WITH_TIMING_TESTS
  // The 0.2 s deadline should end the call well within 400 ms.
  EXPECT_LE(end - start, std::chrono::milliseconds{400});
#else
  ARROW_UNUSED(end - start);
#endif
  ASSERT_RAISES(IOError, status);
}
1864 
TEST_F(TestFlightClient, NoTimeout) {
  // Call should complete quickly, so timeout should not fire
  FlightCallOptions options;
  options.timeout = TimeoutDuration{5.0};  // account for slow server process startup
  std::unique_ptr<FlightInfo> info;
  auto descriptor = FlightDescriptor::Path({"examples", "ints"});
  // Time only the RPC itself, using a monotonic clock. system_clock may jump
  // and is unsuitable for measuring intervals.
  const auto start = std::chrono::steady_clock::now();
  Status status = client_->GetFlightInfo(options, descriptor, &info);
  const auto end = std::chrono::steady_clock::now();
#ifdef ARROW_WITH_TIMING_TESTS
  EXPECT_LE(end - start, std::chrono::milliseconds{600});
#else
  ARROW_UNUSED(end - start);
#endif
  ASSERT_OK(status);
  ASSERT_NE(nullptr, info);
}
1882 
TEST_F(TestDoPut, DoPutInts) {
  // One column per integer type (named f0..f7). Each column holds zero,
  // mid-range values, the type's extreme values, and a null.
  auto descr = FlightDescriptor::Path({"ints"});
  const std::vector<std::shared_ptr<Array>> columns = {
      ArrayFromJSON(int8(), "[0, 1, 127, -128, null]"),
      ArrayFromJSON(uint8(), "[0, 1, 127, 255, null]"),
      ArrayFromJSON(int16(), "[0, 258, 32767, -32768, null]"),
      ArrayFromJSON(uint16(), "[0, 258, 32767, 65535, null]"),
      ArrayFromJSON(int32(), "[0, 65538, 2147483647, -2147483648, null]"),
      ArrayFromJSON(uint32(), "[0, 65538, 2147483647, 4294967295, null]"),
      ArrayFromJSON(int64(),
                    "[0, 4294967298, 9223372036854775807, -9223372036854775808, null]"),
      ArrayFromJSON(uint64(),
                    "[0, 4294967298, 9223372036854775807, 18446744073709551615, null]"),
  };

  std::vector<std::shared_ptr<Field>> fields;
  for (size_t i = 0; i < columns.size(); ++i) {
    fields.push_back(field("f" + std::to_string(i), columns[i]->type()));
  }
  auto schema = arrow::schema(fields);

  BatchVector batches;
  batches.push_back(RecordBatch::Make(schema, columns[0]->length(), columns));

  CheckDoPut(descr, schema, batches);
}
1905 
TEST_F(TestDoPut, DoPutFloats) {
  // The same values in a float32 and a float64 column.
  auto descr = FlightDescriptor::Path({"floats"});
  const char* values = "[0, 1.2, -3.4, 5.6, null]";
  auto single = ArrayFromJSON(float32(), values);
  auto dbl = ArrayFromJSON(float64(), values);
  auto schema = arrow::schema({field("f0", single->type()), field("f1", dbl->type())});

  BatchVector batches;
  batches.push_back(RecordBatch::Make(schema, single->length(), {single, dbl}));
  CheckDoPut(descr, schema, batches);
}
1916 
TEST_F(TestDoPut, DoPutEmptyBatch) {
  // Round-tripping a zero-length batch must succeed.
  auto descr = FlightDescriptor::Path({"ints"});
  auto empty_column = ArrayFromJSON(int32(), "[]");
  auto schema = arrow::schema({field("f1", empty_column->type())});

  BatchVector batches;
  batches.push_back(RecordBatch::Make(schema, empty_column->length(), {empty_column}));
  CheckDoPut(descr, schema, batches);
}
1927 
TEST_F(TestDoPut, DoPutDicts) {
  // Several batches of int8-indexed dictionary arrays sharing one dictionary.
  auto descr = FlightDescriptor::Path({"dicts"});
  auto dictionary_values = ArrayFromJSON(utf8(), "[\"foo\", \"bar\", \"quux\"]");
  auto dict_type = dictionary(int8(), dictionary_values->type());
  auto schema = arrow::schema({field("f1", dict_type)});

  BatchVector batches;
  const std::vector<const char*> index_sets = {"[1, 0, 1]", "[null]", "[null, 1]"};
  for (const char* index_json : index_sets) {
    auto column = std::make_shared<DictionaryArray>(
        dict_type, ArrayFromJSON(int8(), index_json), dictionary_values);
    batches.push_back(RecordBatch::Make(schema, column->length(), {column}));
  }

  CheckDoPut(descr, schema, batches);
}
1943 
1944 // Ensure the gRPC server is configured to allow large messages
1945 // Tests a 32 MiB batch
TEST_F(TestDoPut, DoPutLargeBatch) {
  // Upload the large example batches under the "large-batches" path.
  BatchVector large;
  ASSERT_OK(ExampleLargeBatches(&large));
  auto descr = FlightDescriptor::Path({"large-batches"});
  CheckDoPut(descr, ExampleLargeSchema(), large);
}
1953 
// A client configured with a soft write-size limit rejects oversized batches
// with a FlightWriteSizeStatusDetail. The stream stays usable for smaller
// batches afterwards.
TEST_F(TestDoPut, DoPutSizeLimit) {
  const int64_t size_limit = 4096;
  Location location;
  ASSERT_OK(Location::ForGrpcTcp("localhost", server_->port(), &location));
  auto client_options = FlightClientOptions::Defaults();
  client_options.write_size_limit_bytes = size_limit;
  std::unique_ptr<FlightClient> client;
  ASSERT_OK(FlightClient::Connect(location, client_options, &client));

  auto descr = FlightDescriptor::Path({"ints"});
  // Batch is too large to fit in one message: 768 int64 values (6144 bytes
  // of data) exceed the 4096-byte limit, while each 384-row half does not.
  auto schema = arrow::schema({field("f1", arrow::int64())});
  auto batch = arrow::ConstantArrayGenerator::Zeroes(768, schema);
  BatchVector batches;
  batches.push_back(batch->Slice(0, 384));
  batches.push_back(batch->Slice(384));

  std::unique_ptr<FlightStreamWriter> stream;
  std::unique_ptr<FlightMetadataReader> reader;
  ASSERT_OK(client->DoPut(descr, schema, &stream, &reader));

  // Large batch will exceed the limit
  const auto status = stream->WriteRecordBatch(*batch);
  EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("exceeded soft limit"),
                                  status);
  auto detail = FlightWriteSizeStatusDetail::UnwrapStatus(status);
  ASSERT_NE(nullptr, detail);
  ASSERT_EQ(size_limit, detail->limit());
  ASSERT_GT(detail->actual(), size_limit);

  // But we can retry with smaller batches. The loop variable no longer
  // shadows the oversized `batch` above.
  for (const auto& slice : batches) {
    ASSERT_OK(stream->WriteRecordBatch(*slice));
  }

  ASSERT_OK(stream->DoneWriting());
  ASSERT_OK(stream->Close());
  CheckBatches(descr, batches);
}
1993 
TEST_F(TestAuthHandler, PassAuthenticatedCalls) {
  std::unique_ptr<ClientAuthHandler> handler(new TestClientAuthHandler("user", "p4ssw0rd"));
  ASSERT_OK(client_->Authenticate({}, std::move(handler)));

  // Once authenticated, calls get past auth. The test server reports most
  // RPCs as NotImplemented, while DoAction and opening DoPut succeed.
  std::unique_ptr<FlightListing> listing;
  ASSERT_RAISES(NotImplemented, client_->ListFlights(&listing));

  std::unique_ptr<ResultStream> results;
  Action action;
  action.type = "";
  action.body = Buffer::FromString("");
  ASSERT_OK(client_->DoAction(action, &results));

  std::vector<ActionType> actions;
  ASSERT_RAISES(NotImplemented, client_->ListActions(&actions));

  std::unique_ptr<FlightInfo> info;
  ASSERT_RAISES(NotImplemented, client_->GetFlightInfo(FlightDescriptor{}, &info));

  std::unique_ptr<FlightStreamReader> stream;
  ASSERT_RAISES(NotImplemented, client_->DoGet(Ticket{}, &stream));

  // DoPut opens fine; the server's NotImplemented surfaces on Close().
  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  std::shared_ptr<Schema> empty_schema = arrow::schema({});
  ASSERT_OK(client_->DoPut(FlightDescriptor{}, empty_schema, &writer, &reader));
  ASSERT_RAISES(NotImplemented, writer->Close());
}
2031 
// Without authenticating first, every RPC is rejected with IOError. Where
// the message is deterministic, it reports "Invalid token".
TEST_F(TestAuthHandler, FailUnauthenticatedCalls) {
  Status status;
  std::unique_ptr<FlightListing> listing;
  status = client_->ListFlights(&listing);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<ResultStream> results;
  Action action;
  action.type = "";
  action.body = Buffer::FromString("");
  status = client_->DoAction(action, &results);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::vector<ActionType> actions;
  status = client_->ListActions(&actions);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<FlightInfo> info;
  status = client_->GetFlightInfo(FlightDescriptor{}, &info);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<FlightStreamReader> stream;
  status = client_->DoGet(Ticket{}, &stream);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  // Use the schema factory (as PassAuthenticatedCalls does) instead of a raw
  // `new arrow::Schema(...)`.
  std::shared_ptr<Schema> schema = arrow::schema({});
  status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
  ASSERT_OK(status);
  status = writer->Close();
  ASSERT_RAISES(IOError, status);
  // ARROW-7583: don't check the error message here.
  // Because gRPC reports errors in some paths with booleans, instead
  // of statuses, we can fail the call without knowing why it fails,
  // instead reporting a generic error message. This is
  // nondeterministic, so don't assert any particular message here.
}
2076 
TEST_F(TestAuthHandler, CheckPeerIdentity) {
  std::unique_ptr<ClientAuthHandler> handler(new TestClientAuthHandler("user", "p4ssw0rd"));
  ASSERT_OK(client_->Authenticate({}, std::move(handler)));

  // The "who-am-i" action reports the peer identity, then the peer address.
  Action who_am_i;
  who_am_i.type = "who-am-i";
  who_am_i.body = Buffer::FromString("");
  std::unique_ptr<ResultStream> results;
  ASSERT_OK(client_->DoAction(who_am_i, &results));
  ASSERT_NE(nullptr, results);

  std::unique_ptr<Result> item;
  ASSERT_OK(results->Next(&item));
  ASSERT_NE(nullptr, item);
  ASSERT_EQ("user", item->body->ToString());

  ASSERT_OK(results->Next(&item));
  ASSERT_NE(nullptr, item);
#ifndef _WIN32
  // gRPC on Windows sometimes returns a blank peer address, so the address
  // is only checked elsewhere.
  ASSERT_NE("", item->body->ToString());
#endif
}
2104 
TEST_F(TestBasicAuthHandler, PassAuthenticatedCalls) {
  std::unique_ptr<ClientAuthHandler> basic_handler(
      new TestClientBasicAuthHandler("user", "p4ssw0rd"));
  ASSERT_OK(client_->Authenticate({}, std::move(basic_handler)));

  // With valid basic-auth credentials, calls get past auth. Most RPCs are
  // NotImplemented on the test server; DoAction and opening DoPut succeed.
  std::unique_ptr<FlightListing> listing;
  ASSERT_RAISES(NotImplemented, client_->ListFlights(&listing));

  std::unique_ptr<ResultStream> results;
  Action noop;
  noop.type = "";
  noop.body = Buffer::FromString("");
  ASSERT_OK(client_->DoAction(noop, &results));

  std::vector<ActionType> actions;
  ASSERT_RAISES(NotImplemented, client_->ListActions(&actions));

  std::unique_ptr<FlightInfo> info;
  ASSERT_RAISES(NotImplemented, client_->GetFlightInfo(FlightDescriptor{}, &info));

  std::unique_ptr<FlightStreamReader> stream;
  ASSERT_RAISES(NotImplemented, client_->DoGet(Ticket{}, &stream));

  // DoPut opens fine; the server's NotImplemented surfaces on Close().
  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  std::shared_ptr<Schema> empty_schema = arrow::schema({});
  ASSERT_OK(client_->DoPut(FlightDescriptor{}, empty_schema, &writer, &reader));
  ASSERT_RAISES(NotImplemented, writer->Close());
}
2142 
// Without authenticating first, every RPC is rejected with IOError and an
// "Invalid token" message, including the DoPut reported on Close().
TEST_F(TestBasicAuthHandler, FailUnauthenticatedCalls) {
  Status status;
  std::unique_ptr<FlightListing> listing;
  status = client_->ListFlights(&listing);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<ResultStream> results;
  Action action;
  action.type = "";
  action.body = Buffer::FromString("");
  status = client_->DoAction(action, &results);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::vector<ActionType> actions;
  status = client_->ListActions(&actions);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<FlightInfo> info;
  status = client_->GetFlightInfo(FlightDescriptor{}, &info);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<FlightStreamReader> stream;
  status = client_->DoGet(Ticket{}, &stream);
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));

  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  // Use the schema factory (as PassAuthenticatedCalls does) instead of a raw
  // `new arrow::Schema(...)`.
  std::shared_ptr<Schema> schema = arrow::schema({});
  status = client_->DoPut(FlightDescriptor{}, schema, &writer, &reader);
  ASSERT_OK(status);
  status = writer->Close();
  ASSERT_RAISES(IOError, status);
  ASSERT_THAT(status.message(), ::testing::HasSubstr("Invalid token"));
}
2183 
// The "who-am-i" action replies with the identity the server associated with
// the caller, which must be the username used to authenticate.
TEST_F(TestBasicAuthHandler, CheckPeerIdentity) {
  ASSERT_OK(client_->Authenticate(
      {}, arrow::internal::make_unique<TestClientBasicAuthHandler>("user", "p4ssw0rd")));

  Action who_am_i;
  who_am_i.type = "who-am-i";
  who_am_i.body = Buffer::FromString("");
  std::unique_ptr<ResultStream> stream;
  ASSERT_OK(client_->DoAction(who_am_i, &stream));
  ASSERT_NE(stream, nullptr);

  std::unique_ptr<Result> identity;
  ASSERT_OK(stream->Next(&identity));
  ASSERT_NE(identity, nullptr);
  ASSERT_EQ(identity->body->ToString(), "user");
}
2202 
// Round-trip a simple "test" action over the TLS-enabled client connection.
TEST_F(TestTls, DoAction) {
  FlightCallOptions call_options;
  call_options.timeout = TimeoutDuration{5.0};

  Action test_action;
  test_action.type = "test";
  test_action.body = Buffer::FromString("");

  std::unique_ptr<ResultStream> stream;
  ASSERT_OK(client_->DoAction(call_options, test_action, &stream));
  ASSERT_NE(stream, nullptr);

  std::unique_ptr<Result> reply;
  ASSERT_OK(stream->Next(&reply));
  ASSERT_NE(reply, nullptr);
  ASSERT_EQ(reply->body->ToString(), "Hello, world!");
}
2218 
2219 #if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS)
TEST_F(TestTls, DisableServerVerification) {
  std::unique_ptr<FlightClient> unverified_client;
  auto client_options = FlightClientOptions::Defaults();
  // Verification of the server must be on by default when encryption is in
  // use; this test then explicitly opts out of it.
  ASSERT_FALSE(client_options.disable_server_verification);
  client_options.disable_server_verification = true;
  ASSERT_OK(FlightClient::Connect(location_, client_options, &unverified_client));

  FlightCallOptions call_options;
  call_options.timeout = TimeoutDuration{5.0};
  Action test_action;
  test_action.type = "test";
  test_action.body = Buffer::FromString("");
  std::unique_ptr<ResultStream> stream;
  ASSERT_OK(unverified_client->DoAction(call_options, test_action, &stream));
  ASSERT_NE(stream, nullptr);

  std::unique_ptr<Result> reply;
  ASSERT_OK(stream->Next(&reply));
  ASSERT_NE(reply, nullptr);
  ASSERT_EQ(reply->body->ToString(), "Hello, world!");
}
2243 #endif
2244 
// Forcing the TLS target name to "fakehostname" must make the call fail with
// IOError (presumably the server certificate does not match that name).
TEST_F(TestTls, OverrideHostname) {
  std::unique_ptr<FlightClient> mismatched_client;
  auto client_options = FlightClientOptions::Defaults();
  client_options.override_hostname = "fakehostname";
  CertKeyPair root;
  ASSERT_OK(ExampleTlsCertificateRoot(&root));
  client_options.tls_root_certs = root.pem_cert;
  ASSERT_OK(FlightClient::Connect(location_, client_options, &mismatched_client));

  FlightCallOptions call_options;
  call_options.timeout = TimeoutDuration{5.0};
  Action test_action;
  test_action.type = "test";
  test_action.body = Buffer::FromString("");
  std::unique_ptr<ResultStream> stream;
  ASSERT_RAISES(IOError, mismatched_client->DoAction(call_options, test_action, &stream));
}
2262 
2263 // Test the facility for setting generic transport options.
TEST_F(TestTls, OverrideHostnameGeneric) {
  // Same as OverrideHostname, but the override is passed as a raw gRPC
  // channel argument through generic_options.
  std::unique_ptr<FlightClient> mismatched_client;
  auto client_options = FlightClientOptions::Defaults();
  client_options.generic_options.emplace_back(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG,
                                              "fakehostname");
  CertKeyPair root;
  ASSERT_OK(ExampleTlsCertificateRoot(&root));
  client_options.tls_root_certs = root.pem_cert;
  ASSERT_OK(FlightClient::Connect(location_, client_options, &mismatched_client));

  FlightCallOptions call_options;
  call_options.timeout = TimeoutDuration{5.0};
  Action test_action;
  test_action.type = "test";
  test_action.body = Buffer::FromString("");
  std::unique_ptr<ResultStream> stream;
  ASSERT_RAISES(IOError, mismatched_client->DoAction(call_options, test_action, &stream));
  // Only the status code is checked: gRPC's error text for this failure is
  // not guaranteed to be stable.
}
2284 
// Every streamed batch must match the example data and carry its own index
// (as a decimal string) in app_metadata.
TEST_F(TestMetadata, DoGet) {
  std::unique_ptr<FlightStreamReader> reader;
  ASSERT_OK(client_->DoGet(Ticket{""}, &reader));

  BatchVector expected;
  ASSERT_OK(ExampleIntBatches(&expected));

  FlightStreamChunk chunk;
  for (size_t index = 0; index < expected.size(); ++index) {
    ASSERT_OK(reader->Next(&chunk));
    ASSERT_NE(nullptr, chunk.data);
    ASSERT_NE(nullptr, chunk.app_metadata);
    ASSERT_BATCHES_EQUAL(*expected[index], *chunk.data);
    ASSERT_EQ(std::to_string(index), chunk.app_metadata->ToString());
  }
  // Past the last batch, the reader signals end-of-stream with a null batch.
  ASSERT_OK(reader->Next(&chunk));
  ASSERT_EQ(nullptr, chunk.data);
}
2305 
2306 // Test dictionaries. This tests a corner case in the reader:
2307 // dictionary batches come in between the schema and the first record
2308 // batch, so the server must take care to read application metadata
2309 // from the record batch, and not one of the dictionary batches.
TEST_F(TestMetadata, DoGetDictionaries) {
  // The "dicts" ticket must stream batches equal to ExampleDictBatches(),
  // each tagged with its index in app_metadata.
  std::unique_ptr<FlightStreamReader> reader;
  ASSERT_OK(client_->DoGet(Ticket{"dicts"}, &reader));

  BatchVector expected;
  ASSERT_OK(ExampleDictBatches(&expected));

  FlightStreamChunk chunk;
  for (size_t index = 0; index < expected.size(); ++index) {
    ASSERT_OK(reader->Next(&chunk));
    ASSERT_NE(nullptr, chunk.data);
    ASSERT_NE(nullptr, chunk.app_metadata);
    ASSERT_BATCHES_EQUAL(*expected[index], *chunk.data);
    ASSERT_EQ(std::to_string(index), chunk.app_metadata->ToString());
  }
  // A null batch marks the end of the stream.
  ASSERT_OK(reader->Next(&chunk));
  ASSERT_EQ(nullptr, chunk.data);
}
2330 
// Write each example batch with its index as application metadata, never
// reading the metadata messages sent back through `reader`.
TEST_F(TestMetadata, DoPut) {
  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  std::shared_ptr<Schema> schema = ExampleIntSchema();
  ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));

  BatchVector expected_batches;
  ASSERT_OK(ExampleIntBatches(&expected_batches));

  auto num_batches = static_cast<int>(expected_batches.size());
  for (int i = 0; i < num_batches; ++i) {
    ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
                                        Buffer::FromString(std::to_string(i))));
  }
  // This eventually calls grpc::ClientReaderWriter::Finish which can
  // hang if there are unread messages. So make sure our wrapper
  // around this doesn't hang (because it drains any unread messages)
  ASSERT_OK(writer->Close());
}
2352 
2353 // Test DoPut() with dictionaries. This tests a corner case in the
2354 // server-side reader; see DoGetDictionaries above.
TEST_F(TestMetadata, DoPutDictionaries) {
  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  BatchVector expected_batches;
  ASSERT_OK(ExampleDictBatches(&expected_batches));
  // ARROW-8749: don't get the schema via ExampleDictSchema because
  // DictionaryMemo uses field addresses to determine whether it's
  // seen a field before. Hence, if we use a schema that is different
  // (identity-wise) than the schema of the first batch we write,
  // we'll end up generating a duplicate set of dictionaries that
  // confuses the reader.
  ASSERT_OK(client_->DoPut(FlightDescriptor{}, expected_batches[0]->schema(), &writer,
                           &reader));
  // Tag each batch with its index as application metadata.
  auto num_batches = static_cast<int>(expected_batches.size());
  for (int i = 0; i < num_batches; ++i) {
    ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
                                        Buffer::FromString(std::to_string(i))));
  }
  ASSERT_OK(writer->Close());
}
2377 
// After each write, read back the metadata message the server returns and
// check that it echoes the batch index sent with that write.
TEST_F(TestMetadata, DoPutReadMetadata) {
  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  std::shared_ptr<Schema> schema = ExampleIntSchema();
  ASSERT_OK(client_->DoPut(FlightDescriptor{}, schema, &writer, &reader));

  BatchVector expected_batches;
  ASSERT_OK(ExampleIntBatches(&expected_batches));

  std::shared_ptr<Buffer> metadata;
  auto num_batches = static_cast<int>(expected_batches.size());
  for (int i = 0; i < num_batches; ++i) {
    ASSERT_OK(writer->WriteWithMetadata(*expected_batches[i],
                                        Buffer::FromString(std::to_string(i))));
    ASSERT_OK(reader->ReadMetadata(&metadata));
    ASSERT_NE(nullptr, metadata);
    ASSERT_EQ(std::to_string(i), metadata->ToString());
  }
  // Unlike the DoPut test, which leaves the metadata messages unread and
  // relies on Close() draining them, every message has been consumed here;
  // make sure Close() still succeeds.
  ASSERT_OK(writer->Close());
}
2401 
TEST_F(TestOptions, DoGetReadOptions) {
  // Opening the stream succeeds, but a read nesting depth of 1 is too low,
  // so reading the first chunk must fail with Invalid.
  FlightCallOptions call_options;
  call_options.read_options.max_recursion_depth = 1;

  std::unique_ptr<FlightStreamReader> reader;
  ASSERT_OK(client_->DoGet(call_options, Ticket{""}, &reader));

  FlightStreamChunk first_chunk;
  ASSERT_RAISES(Invalid, reader->Next(&first_chunk));
}
2412 
TEST_F(TestOptions, DoPutWriteOptions) {
  // A write nesting depth of 1 is too low for the nested example batches,
  // so every write must be rejected with Invalid.
  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  BatchVector nested_batches;
  ASSERT_OK(ExampleNestedBatches(&nested_batches));

  FlightCallOptions call_options;
  call_options.write_options.max_recursion_depth = 1;
  ASSERT_OK(client_->DoPut(call_options, FlightDescriptor{},
                           nested_batches[0]->schema(), &writer, &reader));
  for (const auto& nested : nested_batches) {
    ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*nested));
  }
}
2428 
TEST_F(TestOptions, DoExchangeClientWriteOptions) {
  // The low write nesting depth comes from the call options, so each nested
  // batch is rejected client-side while the exchange itself closes cleanly.
  FlightCallOptions call_options;
  call_options.write_options.max_recursion_depth = 1;
  const auto descriptor = FlightDescriptor::Command("");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(call_options, descriptor, &writer, &reader));

  BatchVector nested_batches;
  ASSERT_OK(ExampleNestedBatches(&nested_batches));
  ASSERT_OK(writer->Begin(nested_batches[0]->schema()));
  for (const auto& nested : nested_batches) {
    ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*nested));
  }
  ASSERT_OK(writer->DoneWriting());
  ASSERT_OK(writer->Close());
}
2447 
TEST_F(TestOptions, DoExchangeClientWriteOptionsBegin) {
  // Same as DoExchangeClientWriteOptions, except the low nesting depth is
  // given to Begin() as IPC write options instead of through call options.
  const auto descriptor = FlightDescriptor::Command("");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(descriptor, &writer, &reader));

  BatchVector nested_batches;
  ASSERT_OK(ExampleNestedBatches(&nested_batches));
  auto ipc_options = ipc::IpcWriteOptions::Defaults();
  ipc_options.max_recursion_depth = 1;
  ASSERT_OK(writer->Begin(nested_batches[0]->schema(), ipc_options));
  for (const auto& nested : nested_batches) {
    ASSERT_RAISES(Invalid, writer->WriteRecordBatch(*nested));
  }
  ASSERT_OK(writer->DoneWriting());
  ASSERT_OK(writer->Close());
}
2467 
TEST_F(TestOptions, DoExchangeServerWriteOptions) {
  // Call DoExchange and write nested data, but with a very low nesting depth set to fail
  // the call. (The low nesting depth is set on the server side.)
  // The client-side write succeeds; the server's failure is only reported
  // when the exchange is closed.
  auto descr = FlightDescriptor::Command("");
  std::unique_ptr<FlightStreamReader> reader;
  std::unique_ptr<FlightStreamWriter> writer;
  ASSERT_OK(client_->DoExchange(descr, &writer, &reader));
  BatchVector batches;
  ASSERT_OK(ExampleNestedBatches(&batches));
  ASSERT_OK(writer->Begin(batches[0]->schema()));
  ASSERT_OK(writer->WriteRecordBatch(*batches[0]));
  ASSERT_OK(writer->DoneWriting());
  ASSERT_RAISES(Invalid, writer->Close());
}
2483 
// The server middleware rejects every call, so even a plain GetFlightInfo
// must fail with its rejection message.
TEST_F(TestRejectServerMiddleware, Rejected) {
  std::unique_ptr<FlightInfo> info;
  const Status rejection = client_->GetFlightInfo(FlightDescriptor{}, &info);
  ASSERT_RAISES(IOError, rejection);
  ASSERT_THAT(rejection.message(), ::testing::HasSubstr("All calls are rejected"));
}
2490 
// The counting middleware must record one failed call (GetFlightInfo) and,
// once the DoGet stream has been fully consumed, one successful call.
TEST_F(TestCountingServerMiddleware, Count) {
  std::unique_ptr<FlightInfo> info;
  const Status info_status = client_->GetFlightInfo(FlightDescriptor{}, &info);
  ASSERT_RAISES(NotImplemented, info_status);

  std::unique_ptr<FlightStreamReader> stream;
  ASSERT_OK(client_->DoGet(Ticket{""}, &stream));

  ASSERT_EQ(1, request_counter_->failed_);

  // Drain the stream; a null batch marks its end.
  for (;;) {
    FlightStreamChunk chunk;
    ASSERT_OK(stream->Next(&chunk));
    if (!chunk.data) {
      break;
    }
  }

  ASSERT_EQ(1, request_counter_->successful_);
  ASSERT_EQ(1, request_counter_->failed_);
}
2513 
// With current_span_id set on the client side, the action result must echo
// "trace-id", showing the value travelled to the server and back.
TEST_F(TestPropagatingMiddleware, Propagate) {
  current_span_id = "trace-id";
  client_middleware_->Reset();

  Action echo_action;
  echo_action.type = "action1";
  echo_action.body = Buffer::FromString("action1-content");
  std::unique_ptr<ResultStream> result_stream;
  ASSERT_OK(client_->DoAction(echo_action, &result_stream));

  std::unique_ptr<Result> first_result;
  ASSERT_OK(result_stream->Next(&first_result));
  ASSERT_EQ("trace-id", first_result->body->ToString());
  ValidateStatus(Status::OK(), FlightMethod::DoAction);
}
2530 
2531 // For each method, make sure that the client middleware received
2532 // headers from the server and that the proper method enum value was
2533 // passed to the interceptor
TEST_F(TestPropagatingMiddleware, ListFlights) {
  client_middleware_->Reset();
  std::unique_ptr<FlightListing> flights;
  const Status call_status = client_->ListFlights(&flights);
  ASSERT_RAISES(NotImplemented, call_status);
  // The middleware must have observed this status for ListFlights.
  ValidateStatus(call_status, FlightMethod::ListFlights);
}
2541 
// GetFlightInfo is unimplemented; the middleware must still see the call.
TEST_F(TestPropagatingMiddleware, GetFlightInfo) {
  client_middleware_->Reset();
  const auto path_descriptor = FlightDescriptor::Path({"examples", "ints"});
  std::unique_ptr<FlightInfo> flight_info;
  const Status call_status = client_->GetFlightInfo(path_descriptor, &flight_info);
  ASSERT_RAISES(NotImplemented, call_status);
  ValidateStatus(call_status, FlightMethod::GetFlightInfo);
}
2550 
// GetSchema is unimplemented; the middleware must still see the call.
TEST_F(TestPropagatingMiddleware, GetSchema) {
  client_middleware_->Reset();
  const auto path_descriptor = FlightDescriptor::Path({"examples", "ints"});
  std::unique_ptr<SchemaResult> schema_result;
  const Status call_status = client_->GetSchema(path_descriptor, &schema_result);
  ASSERT_RAISES(NotImplemented, call_status);
  ValidateStatus(call_status, FlightMethod::GetSchema);
}
2559 
// ListActions is unimplemented; the middleware must still see the call.
TEST_F(TestPropagatingMiddleware, ListActions) {
  client_middleware_->Reset();
  std::vector<ActionType> action_types;
  const Status call_status = client_->ListActions(&action_types);
  ASSERT_RAISES(NotImplemented, call_status);
  ValidateStatus(call_status, FlightMethod::ListActions);
}
2567 
// The "ARROW-5095-fail" ticket makes DoGet fail immediately with
// NotImplemented; the middleware must still see the call.
TEST_F(TestPropagatingMiddleware, DoGet) {
  client_middleware_->Reset();
  const Ticket failing_ticket{"ARROW-5095-fail"};
  std::unique_ptr<FlightStreamReader> reader;
  const Status call_status = client_->DoGet(failing_ticket, &reader);
  ASSERT_RAISES(NotImplemented, call_status);
  ValidateStatus(call_status, FlightMethod::DoGet);
}
2576 
// DoPut opens successfully; the server's NotImplemented arrives on Close(),
// and the middleware must record it for DoPut.
TEST_F(TestPropagatingMiddleware, DoPut) {
  client_middleware_->Reset();
  const auto path_descriptor = FlightDescriptor::Path({"ints"});
  const auto values = ArrayFromJSON(int32(), "[4, 5, 6, null]");
  const auto put_schema = arrow::schema({field("f1", values->type())});

  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> metadata_reader;
  ASSERT_OK(client_->DoPut(path_descriptor, put_schema, &writer, &metadata_reader));
  const Status close_status = writer->Close();
  ASSERT_RAISES(NotImplemented, close_status);
  ValidateStatus(close_status, FlightMethod::DoPut);
}
2590 
TEST_F(TestBasicHeaderAuthMiddleware, ValidCredentials) {
  // Delegates to the fixture's valid-credentials scenario.
  RunValidClientAuth();
}
2592 
TEST_F(TestBasicHeaderAuthMiddleware, InvalidCredentials) {
  // Delegates to the fixture's invalid-credentials scenario.
  RunInvalidClientAuth();
}
2594 
TEST_F(TestCookieMiddleware, BasicParsing) {
  // Cookie strings with and without a trailing ';', with a quoted value, and
  // with extra attributes. AddAndValidate is a fixture helper defined
  // elsewhere; it presumably keeps state between calls, so order matters.
  AddAndValidate("id1=1; foo=bar;");
  AddAndValidate("id1=1; foo=bar");
  AddAndValidate("id2=2;");
  AddAndValidate("id4=\"4\"");
  AddAndValidate("id5=5; foo=bar; baz=buz;");
}
2602 
TEST_F(TestCookieMiddleware, Overwrite) {
  // Re-add cookies whose names were already seen (id0, id1, id), including
  // exact repeats, to exercise replacement of existing entries. The sequence
  // is order-dependent: AddAndValidate presumably validates against state
  // left by the earlier calls.
  AddAndValidate("id0=0");
  AddAndValidate("id0=1");
  AddAndValidate("id1=0");
  AddAndValidate("id1=1");
  AddAndValidate("id1=1");
  AddAndValidate("id1=10");
  AddAndValidate("id=3");
  AddAndValidate("id=0");
  AddAndValidate("id=0");
}
2614 
TEST_F(TestCookieMiddleware, MaxAge) {
  // max-age of 0 or -1 (with and without a trailing ';'): these cookies are
  // expired on arrival.
  AddAndValidate("id0=0; max-age=0;");
  AddAndValidate("id1=0; max-age=-1;");
  AddAndValidate("id2=0; max-age=0");
  AddAndValidate("id3=0; max-age=-1");
  // id4/id5 are first stored with a positive max-age...
  AddAndValidate("id4=0; max-age=1");
  AddAndValidate("id5=0; max-age=1");
  // ...and then re-added with max-age=0, which should presumably expire them.
  AddAndValidate("id4=0; max-age=0");
  AddAndValidate("id5=0; max-age=0");
}
2625 
TEST_F(TestCookieMiddleware, Expires) {
  // A malformed expiry date.
  AddAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT;");
  AddAndValidate("id0=0; expires=0, 0 0 0 0:0:0 GMT");
  // An expiry date in the past (2017).
  AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;");
  AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT");
  // An expiry date in the future (2038) for id0 and id1...
  AddAndValidate("id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;");
  AddAndValidate("id1=0; expires=Fri, 01 Jan 2038 22:15:36 GMT");
  // ...which are then re-added with a past date.
  AddAndValidate("id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;");
  AddAndValidate("id1=0; expires=Fri, 22 Dec 2017 22:15:36 GMT");
}
2636 
// Cookies with an expiry date in the past, or a non-positive max-age, must
// be reported as expired.
TEST_F(TestCookieParsing, Expired) {
  const std::vector<const char*> expired_cookies = {
      "id0=0; expires=Fri, 22 Dec 2017 22:15:36 GMT;",
      "id1=0; max-age=-1;",
      "id0=0; max-age=0;",
  };
  for (const char* cookie : expired_cookies) {
    VerifyParseCookie(cookie, true);
  }
}
2642 
// Malformed or empty expires/max-age values must be treated as expired.
TEST_F(TestCookieParsing, Invalid) {
  const std::vector<const char*> malformed_cookies = {
      "id1=0; expires=0, 0 0 0 0:0:0 GMT;",
      "id1=0; expires=Fri, 01 FOO 2038 22:15:36 GMT",
      "id1=0; expires=foo",
      "id1=0; expires=",
      "id1=0; max-age=FOO",
      "id1=0; max-age=",
  };
  for (const char* cookie : malformed_cookies) {
    VerifyParseCookie(cookie, true);
  }
}
2651 
// Cookies without an expires/max-age attribute (only look-alike or
// unrelated attributes) must not be considered expired.
TEST_F(TestCookieParsing, NoExpiry) {
  const std::vector<const char*> non_expiring_cookies = {
      "id1=0;",
      "id1=0; noexpiry=Fri, 01 Jan 2038 22:15:36 GMT",
      "id1=0; noexpiry=\"Fri, 01 Jan 2038 22:15:36 GMT\"",
      "id1=0; nomax-age=-1",
      "id1=0; nomax-age=\"-1\"",
      "id1=0; randomattr=foo",
  };
  for (const char* cookie : non_expiring_cookies) {
    VerifyParseCookie(cookie, false);
  }
}
2660 
// A positive max-age or a future expiry date keeps the cookie alive.
TEST_F(TestCookieParsing, NotExpired) {
  const std::vector<const char*> live_cookies = {
      "id5=0; max-age=1",
      "id0=0; expires=Fri, 01 Jan 2038 22:15:36 GMT;",
  };
  for (const char* cookie : live_cookies) {
    VerifyParseCookie(cookie, false);
  }
}
2665 
// The cookie name is the key of the first name=value pair, whatever follows.
TEST_F(TestCookieParsing, GetName) {
  const std::vector<std::pair<const char*, const char*>> cookie_to_name = {
      {"id1=1; foo=bar;", "id1"},
      {"id1=1; foo=bar", "id1"},
      {"id2=2;", "id2"},
      {"id4=\"4\"", "id4"},
      {"id5=5; foo=bar; baz=buz;", "id5"},
  };
  for (const auto& entry : cookie_to_name) {
    VerifyCookieName(entry.first, entry.second);
  }
}
2673 
// Serialization keeps only the first name=value pair and always quotes the
// value; extra attributes are dropped.
TEST_F(TestCookieParsing, ToString) {
  const std::vector<std::pair<const char*, const char*>> cookie_to_string = {
      {"id1=1; foo=bar;", "id1=\"1\""},
      {"id1=1; foo=bar", "id1=\"1\""},
      {"id2=2;", "id2=\"2\""},
      {"id4=\"4\"", "id4=\"4\""},
      {"id5=5; foo=bar; baz=buz;", "id5=\"5\""},
  };
  for (const auto& entry : cookie_to_string) {
    VerifyCookieString(entry.first, entry.second);
  }
}
2681 
// Cookie dates with day/month names in any letter case, with or without a
// trailing ';', convert to "DD MM YYYY HH:MM:SS". An empty date stays empty,
// and an unknown month name is passed through unchanged.
TEST_F(TestCookieParsing, DateConversion) {
  const std::vector<std::pair<const char*, const char*>> date_cases = {
      {"Mon, 01 jan 2038 22:15:36 GMT;", "01 01 2038 22:15:36"},
      {"TUE, 10 Feb 2038 22:15:36 GMT", "10 02 2038 22:15:36"},
      {"WED, 20 MAr 2038 22:15:36 GMT;", "20 03 2038 22:15:36"},
      {"thu, 15 APR 2038 22:15:36 GMT", "15 04 2038 22:15:36"},
      {"Fri, 30 mAY 2038 22:15:36 GMT;", "30 05 2038 22:15:36"},
      {"Sat, 03 juN 2038 22:15:36 GMT", "03 06 2038 22:15:36"},
      {"Sun, 01 JuL 2038 22:15:36 GMT;", "01 07 2038 22:15:36"},
      {"Fri, 06 aUg 2038 22:15:36 GMT", "06 08 2038 22:15:36"},
      {"Fri, 01 SEP 2038 22:15:36 GMT;", "01 09 2038 22:15:36"},
      {"Fri, 01 OCT 2038 22:15:36 GMT", "01 10 2038 22:15:36"},
      {"Fri, 01 Nov 2038 22:15:36 GMT;", "01 11 2038 22:15:36"},
      {"Fri, 01 deC 2038 22:15:36 GMT", "01 12 2038 22:15:36"},
      {"", ""},
      {"Fri, 01 INVALID 2038 22:15:36 GMT;", "01 INVALID 2038 22:15:36"},
  };
  for (const auto& date_case : date_cases) {
    VerifyCookieDateConverson(date_case.first, date_case.second);
  }
}
2699 
TEST_F(TestCookieParsing, ParseCookieAttribute) {
  // An empty string yields no attribute and an end position of npos.
  VerifyCookieAttributeParsing("", 0, util::nullopt, std::string::npos);

  // Four 7-character "attrN=N" attributes separated by "; " (34 chars total).
  std::string cookie_string = "attr0=0; attr1=1; attr2=2; attr3=3";
  // Length of one attribute including its trailing ';' (8).
  auto attr_length = std::string("attr0=0;").length();
  std::string::size_type start_pos = 0;
  // After parsing an attribute, the expected position is just past its ';'.
  VerifyCookieAttributeParsing(cookie_string, start_pos, std::make_pair("attr0", "0"),
                               cookie_string.find("attr0=0;") + attr_length);
  // Each step advances by attr_length + 1 (the ';' plus the following space)
  // to the start of the next attribute: offsets 9, 18, 27.
  VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
                               std::make_pair("attr1", "1"),
                               cookie_string.find("attr1=1;") + attr_length);
  VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
                               std::make_pair("attr2", "2"),
                               cookie_string.find("attr2=2;") + attr_length);
  // attr3 is last and has no trailing ';', so the next position is npos.
  VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)),
                               std::make_pair("attr3", "3"), std::string::npos);
  // 27 + (attr_length - 1) = 34, the end of the string: nothing left to parse.
  VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length - 1)),
                               util::nullopt, std::string::npos);
  // Starting from npos also yields nothing.
  VerifyCookieAttributeParsing(cookie_string, std::string::npos, util::nullopt,
                               std::string::npos);
}
2721 
TEST_F(TestCookieParsing, CookieCache) {
  // Each call passes a list of cookies and the expected serialized cache.
  // From the expectations below: a later cookie with the same name replaces
  // the earlier one, distinct cookies are joined with "; ", and values are
  // quoted. NOTE(review): the single-cookie case expects an empty string;
  // confirm against AddCookieVerifyCache (defined elsewhere).
  AddCookieVerifyCache({"id0=0;"}, "");
  AddCookieVerifyCache({"id0=0;", "id0=1;"}, "id0=\"1\"");
  AddCookieVerifyCache({"id0=0;", "id1=1;"}, "id0=\"0\"; id1=\"1\"");
  AddCookieVerifyCache({"id0=0;", "id1=1;", "id2=2"}, "id0=\"0\"; id1=\"1\"; id2=\"2\"");
}
2728 
2729 class ForeverFlightListing : public FlightListing {
Next(std::unique_ptr<FlightInfo> * info)2730   Status Next(std::unique_ptr<FlightInfo>* info) override {
2731     std::this_thread::sleep_for(std::chrono::milliseconds(100));
2732     *info = arrow::internal::make_unique<FlightInfo>(ExampleFlightInfo()[0]);
2733     return Status::OK();
2734   }
2735 };
2736 
2737 class ForeverResultStream : public ResultStream {
Next(std::unique_ptr<Result> * result)2738   Status Next(std::unique_ptr<Result>* result) override {
2739     std::this_thread::sleep_for(std::chrono::milliseconds(100));
2740     *result = arrow::internal::make_unique<Result>();
2741     (*result)->body = Buffer::FromString("foo");
2742     return Status::OK();
2743   }
2744 };
2745 
2746 class ForeverDataStream : public FlightDataStream {
2747  public:
ForeverDataStream()2748   ForeverDataStream() : schema_(arrow::schema({})), mapper_(*schema_) {}
schema()2749   std::shared_ptr<Schema> schema() override { return schema_; }
2750 
GetSchemaPayload(FlightPayload * payload)2751   Status GetSchemaPayload(FlightPayload* payload) override {
2752     return ipc::GetSchemaPayload(*schema_, ipc::IpcWriteOptions::Defaults(), mapper_,
2753                                  &payload->ipc_message);
2754   }
2755 
Next(FlightPayload * payload)2756   Status Next(FlightPayload* payload) override {
2757     auto batch = RecordBatch::Make(schema_, 0, ArrayVector{});
2758     return ipc::GetRecordBatchPayload(*batch, ipc::IpcWriteOptions::Defaults(),
2759                                       &payload->ipc_message);
2760   }
2761 
2762  private:
2763   std::shared_ptr<Schema> schema_;
2764   ipc::DictionaryFieldMapper mapper_;
2765 };
2766 
2767 class CancelTestServer : public FlightServerBase {
2768  public:
ListFlights(const ServerCallContext &,const Criteria *,std::unique_ptr<FlightListing> * listings)2769   Status ListFlights(const ServerCallContext&, const Criteria*,
2770                      std::unique_ptr<FlightListing>* listings) override {
2771     *listings = arrow::internal::make_unique<ForeverFlightListing>();
2772     return Status::OK();
2773   }
DoAction(const ServerCallContext &,const Action &,std::unique_ptr<ResultStream> * result)2774   Status DoAction(const ServerCallContext&, const Action&,
2775                   std::unique_ptr<ResultStream>* result) override {
2776     *result = arrow::internal::make_unique<ForeverResultStream>();
2777     return Status::OK();
2778   }
ListActions(const ServerCallContext &,std::vector<ActionType> * actions)2779   Status ListActions(const ServerCallContext&,
2780                      std::vector<ActionType>* actions) override {
2781     *actions = {};
2782     return Status::OK();
2783   }
DoGet(const ServerCallContext &,const Ticket &,std::unique_ptr<FlightDataStream> * data_stream)2784   Status DoGet(const ServerCallContext&, const Ticket&,
2785                std::unique_ptr<FlightDataStream>* data_stream) override {
2786     *data_stream = arrow::internal::make_unique<ForeverDataStream>();
2787     return Status::OK();
2788   }
2789 };
2790 
2791 class TestCancel : public ::testing::Test {
2792  public:
SetUp()2793   void SetUp() {
2794     ASSERT_OK(MakeServer<CancelTestServer>(
2795         &server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
2796         [](FlightClientOptions* options) { return Status::OK(); }));
2797   }
TearDown()2798   void TearDown() { ASSERT_OK(server_->Shutdown()); }
2799 
2800  protected:
2801   std::unique_ptr<FlightClient> client_;
2802   std::unique_ptr<FlightServerBase> server_;
2803 };
2804 
// A stop requested before the call must make ListFlights fail with the
// Cancelled status that was given to RequestStop().
TEST_F(TestCancel, ListFlights) {
  StopSource stopper;
  FlightCallOptions call_options;
  call_options.stop_token = stopper.token();
  std::unique_ptr<FlightListing> flights;
  stopper.RequestStop(Status::Cancelled("StopSource"));
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  client_->ListFlights(call_options, {}, &flights));
}
2814 
// A stop requested before the call must make DoAction fail with the
// Cancelled status that was given to RequestStop().
TEST_F(TestCancel, DoAction) {
  StopSource stopper;
  FlightCallOptions call_options;
  call_options.stop_token = stopper.token();
  std::unique_ptr<ResultStream> action_results;
  stopper.RequestStop(Status::Cancelled("StopSource"));
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  client_->DoAction(call_options, {}, &action_results));
}
2824 
// A stop requested before the call must make ListActions fail with the
// Cancelled status that was given to RequestStop().
TEST_F(TestCancel, ListActions) {
  StopSource stopper;
  FlightCallOptions call_options;
  call_options.stop_token = stopper.token();
  std::vector<ActionType> action_types;
  stopper.RequestStop(Status::Cancelled("StopSource"));
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  client_->ListActions(call_options, &action_types));
}
2834 
// Cancellation of a DoGet surfaces when the stream is read, not when the call
// is started. Two paths are checked: the stop token carried in the call
// options, and a stop token passed directly to ReadAll() on a stream opened
// with default call options.
TEST_F(TestCancel, DoGet) {
  StopSource stop_source;
  FlightCallOptions options;
  options.stop_token = stop_source.token();
  // Request the stop before the call; starting DoGet is still expected to succeed.
  stop_source.RequestStop(Status::Cancelled("StopSource"));
  std::unique_ptr<FlightStreamReader> stream;
  ASSERT_OK(client_->DoGet(options, {}, &stream));
  std::shared_ptr<Table> table;
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  stream->ReadAll(&table));

  // Same expectation, but the token is supplied to ReadAll() instead of the call.
  ASSERT_OK(client_->DoGet({}, &stream));
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  stream->ReadAll(&table, options.stop_token));
}
2851 
// Cancellation of a DoExchange surfaces when the reader side is drained, not
// when the call is started. Two paths are checked: the stop token carried in
// the call options, and a stop token passed directly to ReadAll() on a stream
// opened with default call options.
TEST_F(TestCancel, DoExchange) {
  StopSource stop_source;
  FlightCallOptions options;
  options.stop_token = stop_source.token();
  // Request the stop before the call; starting DoExchange is still expected to succeed.
  stop_source.RequestStop(Status::Cancelled("StopSource"));
  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightStreamReader> stream;
  ASSERT_OK(
      client_->DoExchange(options, FlightDescriptor::Command(""), &writer, &stream));
  std::shared_ptr<Table> table;
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  stream->ReadAll(&table));

  // Same expectation, but the token is supplied to ReadAll() instead of the call.
  ASSERT_OK(client_->DoExchange(FlightDescriptor::Command(""), &writer, &stream));
  EXPECT_RAISES_WITH_MESSAGE_THAT(Cancelled, ::testing::HasSubstr("StopSource"),
                                  stream->ReadAll(&table, options.stop_token));
}
2870 
2871 }  // namespace flight
2872 }  // namespace arrow
2873