1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include "arrow/flight/platform.h"
19 
20 #ifdef __APPLE__
21 #include <limits.h>
22 #include <mach-o/dyld.h>
23 #endif
24 
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
27 
28 #include <boost/filesystem.hpp>
29 // boost/process/detail/windows/handle_workaround.hpp doesn't work
30 // without BOOST_USE_WINDOWS_H with MinGW because MinGW doesn't
31 // provide __kernel_entry without winternl.h.
32 //
33 // See also:
34 // https://github.com/boostorg/process/blob/develop/include/boost/process/detail/windows/handle_workaround.hpp
35 #define BOOST_USE_WINDOWS_H 1
36 #include <boost/process.hpp>
37 
38 #include <gtest/gtest.h>
39 
40 #include "arrow/ipc/test_common.h"
41 #include "arrow/testing/gtest_util.h"
42 #include "arrow/util/logging.h"
43 
44 #include "arrow/flight/api.h"
45 #include "arrow/flight/internal.h"
46 #include "arrow/flight/test_util.h"
47 
48 namespace arrow {
49 namespace flight {
50 
51 namespace bp = boost::process;
52 namespace fs = boost::filesystem;
53 
54 namespace {
55 
ResolveCurrentExecutable(fs::path * out)56 Status ResolveCurrentExecutable(fs::path* out) {
57   // See https://stackoverflow.com/a/1024937/10194 for various
58   // platform-specific recipes.
59 
60   boost::system::error_code ec;
61 
62 #if defined(__linux__)
63   *out = fs::canonical("/proc/self/exe", ec);
64 #elif defined(__APPLE__)
65   char buf[PATH_MAX + 1];
66   uint32_t bufsize = sizeof(buf);
67   if (_NSGetExecutablePath(buf, &bufsize) < 0) {
68     return Status::Invalid("Can't resolve current exe: path too large");
69   }
70   *out = fs::canonical(buf, ec);
71 #elif defined(_WIN32)
72   char buf[MAX_PATH + 1];
73   if (!GetModuleFileNameA(NULL, buf, sizeof(buf))) {
74     return Status::Invalid("Can't get executable file path");
75   }
76   *out = fs::canonical(buf, ec);
77 #else
78   ARROW_UNUSED(ec);
79   return Status::NotImplemented("Not available on this system");
80 #endif
81   if (ec) {
82     // XXX fold this into the Status class?
83     return Status::IOError("Can't resolve current exe: ", ec.message());
84   } else {
85     return Status::OK();
86   }
87 }
88 
89 }  // namespace
90 
Start()91 void TestServer::Start() {
92   namespace fs = boost::filesystem;
93 
94   std::string str_port = std::to_string(port_);
95   std::vector<fs::path> search_path = ::boost::this_process::path();
96   // If possible, prepend current executable directory to search path,
97   // since it's likely that the test server executable is located in
98   // the same directory as the running unit test.
99   fs::path current_exe;
100   Status st = ResolveCurrentExecutable(&current_exe);
101   if (st.ok()) {
102     search_path.insert(search_path.begin(), current_exe.parent_path());
103   } else if (st.IsNotImplemented()) {
104     ARROW_CHECK(st.IsNotImplemented()) << st.ToString();
105   }
106 
107   try {
108     server_process_ = std::make_shared<bp::child>(
109         bp::search_path(executable_name_, search_path), "-port", str_port);
110   } catch (...) {
111     std::stringstream ss;
112     ss << "Failed to launch test server '" << executable_name_ << "', looked in ";
113     for (const auto& path : search_path) {
114       ss << path << " : ";
115     }
116     ARROW_LOG(FATAL) << ss.str();
117     throw;
118   }
119   std::cout << "Server running with pid " << server_process_->id() << std::endl;
120 }
121 
Stop()122 int TestServer::Stop() {
123   if (server_process_ && server_process_->valid()) {
124 #ifndef _WIN32
125     kill(server_process_->id(), SIGTERM);
126 #else
127     // This would use SIGKILL on POSIX, which is more brutal than SIGTERM
128     server_process_->terminate();
129 #endif
130     server_process_->wait();
131     return server_process_->exit_code();
132   } else {
133     // Presumably the server wasn't able to start
134     return -1;
135   }
136 }
137 
IsRunning()138 bool TestServer::IsRunning() { return server_process_->running(); }
139 
port() const140 int TestServer::port() const { return port_; }
141 
GetBatchForFlight(const Ticket & ticket,std::shared_ptr<RecordBatchReader> * out)142 Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr<RecordBatchReader>* out) {
143   if (ticket.ticket == "ticket-ints-1") {
144     BatchVector batches;
145     RETURN_NOT_OK(ExampleIntBatches(&batches));
146     *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
147     return Status::OK();
148   } else if (ticket.ticket == "ticket-dicts-1") {
149     BatchVector batches;
150     RETURN_NOT_OK(ExampleDictBatches(&batches));
151     *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
152     return Status::OK();
153   } else {
154     return Status::NotImplemented("no stream implemented for this ticket");
155   }
156 }
157 
158 class FlightTestServer : public FlightServerBase {
ListFlights(const ServerCallContext & context,const Criteria * criteria,std::unique_ptr<FlightListing> * listings)159   Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
160                      std::unique_ptr<FlightListing>* listings) override {
161     std::vector<FlightInfo> flights = ExampleFlightInfo();
162     if (criteria && criteria->expression != "") {
163       // For test purposes, if we get criteria, return no results
164       flights.clear();
165     }
166     *listings = std::unique_ptr<FlightListing>(new SimpleFlightListing(flights));
167     return Status::OK();
168   }
169 
GetFlightInfo(const ServerCallContext & context,const FlightDescriptor & request,std::unique_ptr<FlightInfo> * out)170   Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
171                        std::unique_ptr<FlightInfo>* out) override {
172     // Test that Arrow-C++ status codes can make it through gRPC
173     if (request.type == FlightDescriptor::DescriptorType::CMD &&
174         request.cmd == "status-outofmemory") {
175       return Status::OutOfMemory("Sentinel");
176     }
177 
178     std::vector<FlightInfo> flights = ExampleFlightInfo();
179 
180     for (const auto& info : flights) {
181       if (info.descriptor().Equals(request)) {
182         *out = std::unique_ptr<FlightInfo>(new FlightInfo(info));
183         return Status::OK();
184       }
185     }
186     return Status::Invalid("Flight not found: ", request.ToString());
187   }
188 
DoGet(const ServerCallContext & context,const Ticket & request,std::unique_ptr<FlightDataStream> * data_stream)189   Status DoGet(const ServerCallContext& context, const Ticket& request,
190                std::unique_ptr<FlightDataStream>* data_stream) override {
191     // Test for ARROW-5095
192     if (request.ticket == "ARROW-5095-fail") {
193       return Status::UnknownError("Server-side error");
194     }
195     if (request.ticket == "ARROW-5095-success") {
196       return Status::OK();
197     }
198 
199     std::shared_ptr<RecordBatchReader> batch_reader;
200     RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader));
201 
202     *data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader));
203     return Status::OK();
204   }
205 
DoExchange(const ServerCallContext & context,std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMessageWriter> writer)206   Status DoExchange(const ServerCallContext& context,
207                     std::unique_ptr<FlightMessageReader> reader,
208                     std::unique_ptr<FlightMessageWriter> writer) override {
209     // Test various scenarios for a DoExchange
210     if (reader->descriptor().type != FlightDescriptor::DescriptorType::CMD) {
211       return Status::Invalid("Must provide a command descriptor");
212     }
213 
214     const std::string& cmd = reader->descriptor().cmd;
215     if (cmd == "error") {
216       // Immediately return an error to the client.
217       return Status::NotImplemented("Expected error");
218     } else if (cmd == "get") {
219       return RunExchangeGet(std::move(reader), std::move(writer));
220     } else if (cmd == "put") {
221       return RunExchangePut(std::move(reader), std::move(writer));
222     } else if (cmd == "counter") {
223       return RunExchangeCounter(std::move(reader), std::move(writer));
224     } else if (cmd == "total") {
225       return RunExchangeTotal(std::move(reader), std::move(writer));
226     } else if (cmd == "echo") {
227       return RunExchangeEcho(std::move(reader), std::move(writer));
228     } else {
229       return Status::NotImplemented("Scenario not implemented: ", cmd);
230     }
231   }
232 
233   // A simple example - act like DoGet.
RunExchangeGet(std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMessageWriter> writer)234   Status RunExchangeGet(std::unique_ptr<FlightMessageReader> reader,
235                         std::unique_ptr<FlightMessageWriter> writer) {
236     RETURN_NOT_OK(writer->Begin(ExampleIntSchema()));
237     BatchVector batches;
238     RETURN_NOT_OK(ExampleIntBatches(&batches));
239     for (const auto& batch : batches) {
240       RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
241     }
242     return Status::OK();
243   }
244 
245   // A simple example - act like DoPut
RunExchangePut(std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMessageWriter> writer)246   Status RunExchangePut(std::unique_ptr<FlightMessageReader> reader,
247                         std::unique_ptr<FlightMessageWriter> writer) {
248     ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
249     if (!schema->Equals(ExampleIntSchema(), false)) {
250       return Status::Invalid("Schema is not as expected");
251     }
252     BatchVector batches;
253     RETURN_NOT_OK(ExampleIntBatches(&batches));
254     FlightStreamChunk chunk;
255     for (const auto& batch : batches) {
256       RETURN_NOT_OK(reader->Next(&chunk));
257       if (!chunk.data) {
258         return Status::Invalid("Expected another batch");
259       }
260       if (!batch->Equals(*chunk.data)) {
261         return Status::Invalid("Batch does not match");
262       }
263     }
264     RETURN_NOT_OK(reader->Next(&chunk));
265     if (chunk.data || chunk.app_metadata) {
266       return Status::Invalid("Too many batches");
267     }
268 
269     RETURN_NOT_OK(writer->WriteMetadata(Buffer::FromString("done")));
270     return Status::OK();
271   }
272 
273   // Read some number of record batches from the client, send a
274   // metadata message back with the count, then echo the batches back.
RunExchangeCounter(std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMessageWriter> writer)275   Status RunExchangeCounter(std::unique_ptr<FlightMessageReader> reader,
276                             std::unique_ptr<FlightMessageWriter> writer) {
277     std::vector<std::shared_ptr<RecordBatch>> batches;
278     FlightStreamChunk chunk;
279     int chunks = 0;
280     while (true) {
281       RETURN_NOT_OK(reader->Next(&chunk));
282       if (!chunk.data && !chunk.app_metadata) {
283         break;
284       }
285       if (chunk.data) {
286         batches.push_back(chunk.data);
287         chunks++;
288       }
289     }
290 
291     // Echo back the number of record batches read.
292     std::shared_ptr<Buffer> buf = Buffer::FromString(std::to_string(chunks));
293     RETURN_NOT_OK(writer->WriteMetadata(buf));
294     // Echo the record batches themselves.
295     if (chunks > 0) {
296       ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
297       RETURN_NOT_OK(writer->Begin(schema));
298 
299       for (const auto& batch : batches) {
300         RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
301       }
302     }
303 
304     return Status::OK();
305   }
306 
307   // Read int64 batches from the client, each time sending back a
308   // batch with a running sum of columns.
RunExchangeTotal(std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMessageWriter> writer)309   Status RunExchangeTotal(std::unique_ptr<FlightMessageReader> reader,
310                           std::unique_ptr<FlightMessageWriter> writer) {
311     FlightStreamChunk chunk{};
312     ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
313     // Ensure the schema contains only int64 columns
314     for (const auto& field : schema->fields()) {
315       if (field->type()->id() != Type::type::INT64) {
316         return Status::Invalid("Field is not INT64: ", field->name());
317       }
318     }
319     std::vector<int64_t> sums(schema->num_fields());
320     std::vector<std::shared_ptr<Array>> columns(schema->num_fields());
321     RETURN_NOT_OK(writer->Begin(schema));
322     while (true) {
323       RETURN_NOT_OK(reader->Next(&chunk));
324       if (!chunk.data && !chunk.app_metadata) {
325         break;
326       }
327       if (chunk.data) {
328         if (!chunk.data->schema()->Equals(schema, false)) {
329           // A compliant client implementation would make this impossible
330           return Status::Invalid("Schemas are incompatible");
331         }
332 
333         // Update the running totals
334         auto builder = std::make_shared<Int64Builder>();
335         int col_index = 0;
336         for (const auto& column : chunk.data->columns()) {
337           auto arr = std::dynamic_pointer_cast<Int64Array>(column);
338           if (!arr) {
339             return MakeFlightError(FlightStatusCode::Internal, "Could not cast array");
340           }
341           for (int row = 0; row < column->length(); row++) {
342             if (!arr->IsNull(row)) {
343               sums[col_index] += arr->Value(row);
344             }
345           }
346 
347           builder->Reset();
348           RETURN_NOT_OK(builder->Append(sums[col_index]));
349           RETURN_NOT_OK(builder->Finish(&columns[col_index]));
350 
351           col_index++;
352         }
353 
354         // Echo the totals to the client
355         auto response = RecordBatch::Make(schema, /* num_rows */ 1, columns);
356         RETURN_NOT_OK(writer->WriteRecordBatch(*response));
357       }
358     }
359     return Status::OK();
360   }
361 
362   // Echo the client's messages back.
RunExchangeEcho(std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMessageWriter> writer)363   Status RunExchangeEcho(std::unique_ptr<FlightMessageReader> reader,
364                          std::unique_ptr<FlightMessageWriter> writer) {
365     FlightStreamChunk chunk;
366     bool begun = false;
367     while (true) {
368       RETURN_NOT_OK(reader->Next(&chunk));
369       if (!chunk.data && !chunk.app_metadata) {
370         break;
371       }
372       if (!begun && chunk.data) {
373         begun = true;
374         RETURN_NOT_OK(writer->Begin(chunk.data->schema()));
375       }
376       if (chunk.data && chunk.app_metadata) {
377         RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata));
378       } else if (chunk.data) {
379         RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data));
380       } else if (chunk.app_metadata) {
381         RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata));
382       }
383     }
384     return Status::OK();
385   }
386 
RunAction1(const Action & action,std::unique_ptr<ResultStream> * out)387   Status RunAction1(const Action& action, std::unique_ptr<ResultStream>* out) {
388     std::vector<Result> results;
389     for (int i = 0; i < 3; ++i) {
390       Result result;
391       std::string value = action.body->ToString() + "-part" + std::to_string(i);
392       result.body = Buffer::FromString(std::move(value));
393       results.push_back(result);
394     }
395     *out = std::unique_ptr<ResultStream>(new SimpleResultStream(std::move(results)));
396     return Status::OK();
397   }
398 
RunAction2(std::unique_ptr<ResultStream> * out)399   Status RunAction2(std::unique_ptr<ResultStream>* out) {
400     // Empty
401     *out = std::unique_ptr<ResultStream>(new SimpleResultStream({}));
402     return Status::OK();
403   }
404 
DoAction(const ServerCallContext & context,const Action & action,std::unique_ptr<ResultStream> * out)405   Status DoAction(const ServerCallContext& context, const Action& action,
406                   std::unique_ptr<ResultStream>* out) override {
407     if (action.type == "action1") {
408       return RunAction1(action, out);
409     } else if (action.type == "action2") {
410       return RunAction2(out);
411     } else {
412       return Status::NotImplemented(action.type);
413     }
414   }
415 
ListActions(const ServerCallContext & context,std::vector<ActionType> * out)416   Status ListActions(const ServerCallContext& context,
417                      std::vector<ActionType>* out) override {
418     std::vector<ActionType> actions = ExampleActionTypes();
419     *out = std::move(actions);
420     return Status::OK();
421   }
422 
GetSchema(const ServerCallContext & context,const FlightDescriptor & request,std::unique_ptr<SchemaResult> * schema)423   Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request,
424                    std::unique_ptr<SchemaResult>* schema) override {
425     std::vector<FlightInfo> flights = ExampleFlightInfo();
426 
427     for (const auto& info : flights) {
428       if (info.descriptor().Equals(request)) {
429         *schema =
430             std::unique_ptr<SchemaResult>(new SchemaResult(info.serialized_schema()));
431         return Status::OK();
432       }
433     }
434     return Status::Invalid("Flight not found: ", request.ToString());
435   }
436 };
437 
ExampleTestServer()438 std::unique_ptr<FlightServerBase> ExampleTestServer() {
439   return std::unique_ptr<FlightServerBase>(new FlightTestServer);
440 }
441 
MakeFlightInfo(const Schema & schema,const FlightDescriptor & descriptor,const std::vector<FlightEndpoint> & endpoints,int64_t total_records,int64_t total_bytes,FlightInfo::Data * out)442 Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor,
443                       const std::vector<FlightEndpoint>& endpoints, int64_t total_records,
444                       int64_t total_bytes, FlightInfo::Data* out) {
445   out->descriptor = descriptor;
446   out->endpoints = endpoints;
447   out->total_records = total_records;
448   out->total_bytes = total_bytes;
449   return internal::SchemaToString(schema, &out->schema);
450 }
451 
NumberingStream(std::unique_ptr<FlightDataStream> stream)452 NumberingStream::NumberingStream(std::unique_ptr<FlightDataStream> stream)
453     : counter_(0), stream_(std::move(stream)) {}
454 
schema()455 std::shared_ptr<Schema> NumberingStream::schema() { return stream_->schema(); }
456 
GetSchemaPayload(FlightPayload * payload)457 Status NumberingStream::GetSchemaPayload(FlightPayload* payload) {
458   return stream_->GetSchemaPayload(payload);
459 }
460 
Next(FlightPayload * payload)461 Status NumberingStream::Next(FlightPayload* payload) {
462   RETURN_NOT_OK(stream_->Next(payload));
463   if (payload && payload->ipc_message.type == ipc::Message::RECORD_BATCH) {
464     payload->app_metadata = Buffer::FromString(std::to_string(counter_));
465     counter_++;
466   }
467   return Status::OK();
468 }
469 
ExampleIntSchema()470 std::shared_ptr<Schema> ExampleIntSchema() {
471   auto f0 = field("f0", int32());
472   auto f1 = field("f1", int32());
473   return ::arrow::schema({f0, f1});
474 }
475 
ExampleStringSchema()476 std::shared_ptr<Schema> ExampleStringSchema() {
477   auto f0 = field("f0", utf8());
478   auto f1 = field("f1", binary());
479   return ::arrow::schema({f0, f1});
480 }
481 
ExampleDictSchema()482 std::shared_ptr<Schema> ExampleDictSchema() {
483   std::shared_ptr<RecordBatch> batch;
484   ABORT_NOT_OK(ipc::test::MakeDictionary(&batch));
485   return batch->schema();
486 }
487 
ExampleFlightInfo()488 std::vector<FlightInfo> ExampleFlightInfo() {
489   Location location1;
490   Location location2;
491   Location location3;
492   Location location4;
493   ARROW_EXPECT_OK(Location::ForGrpcTcp("foo1.bar.com", 12345, &location1));
494   ARROW_EXPECT_OK(Location::ForGrpcTcp("foo2.bar.com", 12345, &location2));
495   ARROW_EXPECT_OK(Location::ForGrpcTcp("foo3.bar.com", 12345, &location3));
496   ARROW_EXPECT_OK(Location::ForGrpcTcp("foo4.bar.com", 12345, &location4));
497 
498   FlightInfo::Data flight1, flight2, flight3;
499 
500   FlightEndpoint endpoint1({{"ticket-ints-1"}, {location1}});
501   FlightEndpoint endpoint2({{"ticket-ints-2"}, {location2}});
502   FlightEndpoint endpoint3({{"ticket-cmd"}, {location3}});
503   FlightEndpoint endpoint4({{"ticket-dicts-1"}, {location4}});
504 
505   FlightDescriptor descr1{FlightDescriptor::PATH, "", {"examples", "ints"}};
506   FlightDescriptor descr2{FlightDescriptor::CMD, "my_command", {}};
507   FlightDescriptor descr3{FlightDescriptor::PATH, "", {"examples", "dicts"}};
508 
509   auto schema1 = ExampleIntSchema();
510   auto schema2 = ExampleStringSchema();
511   auto schema3 = ExampleDictSchema();
512 
513   ARROW_EXPECT_OK(
514       MakeFlightInfo(*schema1, descr1, {endpoint1, endpoint2}, 1000, 100000, &flight1));
515   ARROW_EXPECT_OK(MakeFlightInfo(*schema2, descr2, {endpoint3}, 1000, 100000, &flight2));
516   ARROW_EXPECT_OK(MakeFlightInfo(*schema3, descr3, {endpoint4}, -1, -1, &flight3));
517   return {FlightInfo(flight1), FlightInfo(flight2), FlightInfo(flight3)};
518 }
519 
ExampleIntBatches(BatchVector * out)520 Status ExampleIntBatches(BatchVector* out) {
521   std::shared_ptr<RecordBatch> batch;
522   for (int i = 0; i < 5; ++i) {
523     // Make all different sizes, use different random seed
524     RETURN_NOT_OK(ipc::test::MakeIntBatchSized(10 + i, &batch, i));
525     out->push_back(batch);
526   }
527   return Status::OK();
528 }
529 
ExampleDictBatches(BatchVector * out)530 Status ExampleDictBatches(BatchVector* out) {
531   // Just the same batch, repeated a few times
532   std::shared_ptr<RecordBatch> batch;
533   for (int i = 0; i < 3; ++i) {
534     RETURN_NOT_OK(ipc::test::MakeDictionary(&batch));
535     out->push_back(batch);
536   }
537   return Status::OK();
538 }
539 
ExampleActionTypes()540 std::vector<ActionType> ExampleActionTypes() {
541   return {{"drop", "drop a dataset"}, {"cache", "cache a dataset"}};
542 }
543 
TestServerAuthHandler(const std::string & username,const std::string & password)544 TestServerAuthHandler::TestServerAuthHandler(const std::string& username,
545                                              const std::string& password)
546     : username_(username), password_(password) {}
547 
~TestServerAuthHandler()548 TestServerAuthHandler::~TestServerAuthHandler() {}
549 
Authenticate(ServerAuthSender * outgoing,ServerAuthReader * incoming)550 Status TestServerAuthHandler::Authenticate(ServerAuthSender* outgoing,
551                                            ServerAuthReader* incoming) {
552   std::string token;
553   RETURN_NOT_OK(incoming->Read(&token));
554   if (token != password_) {
555     return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
556   }
557   RETURN_NOT_OK(outgoing->Write(username_));
558   return Status::OK();
559 }
560 
IsValid(const std::string & token,std::string * peer_identity)561 Status TestServerAuthHandler::IsValid(const std::string& token,
562                                       std::string* peer_identity) {
563   if (token != password_) {
564     return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
565   }
566   *peer_identity = username_;
567   return Status::OK();
568 }
569 
TestServerBasicAuthHandler(const std::string & username,const std::string & password)570 TestServerBasicAuthHandler::TestServerBasicAuthHandler(const std::string& username,
571                                                        const std::string& password) {
572   basic_auth_.username = username;
573   basic_auth_.password = password;
574 }
575 
~TestServerBasicAuthHandler()576 TestServerBasicAuthHandler::~TestServerBasicAuthHandler() {}
577 
Authenticate(ServerAuthSender * outgoing,ServerAuthReader * incoming)578 Status TestServerBasicAuthHandler::Authenticate(ServerAuthSender* outgoing,
579                                                 ServerAuthReader* incoming) {
580   std::string token;
581   RETURN_NOT_OK(incoming->Read(&token));
582   BasicAuth incoming_auth;
583   RETURN_NOT_OK(BasicAuth::Deserialize(token, &incoming_auth));
584   if (incoming_auth.username != basic_auth_.username ||
585       incoming_auth.password != basic_auth_.password) {
586     return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
587   }
588   RETURN_NOT_OK(outgoing->Write(basic_auth_.username));
589   return Status::OK();
590 }
591 
IsValid(const std::string & token,std::string * peer_identity)592 Status TestServerBasicAuthHandler::IsValid(const std::string& token,
593                                            std::string* peer_identity) {
594   if (token != basic_auth_.username) {
595     return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
596   }
597   *peer_identity = basic_auth_.username;
598   return Status::OK();
599 }
600 
TestClientAuthHandler(const std::string & username,const std::string & password)601 TestClientAuthHandler::TestClientAuthHandler(const std::string& username,
602                                              const std::string& password)
603     : username_(username), password_(password) {}
604 
~TestClientAuthHandler()605 TestClientAuthHandler::~TestClientAuthHandler() {}
606 
Authenticate(ClientAuthSender * outgoing,ClientAuthReader * incoming)607 Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing,
608                                            ClientAuthReader* incoming) {
609   RETURN_NOT_OK(outgoing->Write(password_));
610   std::string username;
611   RETURN_NOT_OK(incoming->Read(&username));
612   if (username != username_) {
613     return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
614   }
615   return Status::OK();
616 }
617 
GetToken(std::string * token)618 Status TestClientAuthHandler::GetToken(std::string* token) {
619   *token = password_;
620   return Status::OK();
621 }
622 
TestClientBasicAuthHandler(const std::string & username,const std::string & password)623 TestClientBasicAuthHandler::TestClientBasicAuthHandler(const std::string& username,
624                                                        const std::string& password) {
625   basic_auth_.username = username;
626   basic_auth_.password = password;
627 }
628 
~TestClientBasicAuthHandler()629 TestClientBasicAuthHandler::~TestClientBasicAuthHandler() {}
630 
Authenticate(ClientAuthSender * outgoing,ClientAuthReader * incoming)631 Status TestClientBasicAuthHandler::Authenticate(ClientAuthSender* outgoing,
632                                                 ClientAuthReader* incoming) {
633   std::string pb_result;
634   RETURN_NOT_OK(BasicAuth::Serialize(basic_auth_, &pb_result));
635   RETURN_NOT_OK(outgoing->Write(pb_result));
636   RETURN_NOT_OK(incoming->Read(&token_));
637   return Status::OK();
638 }
639 
GetToken(std::string * token)640 Status TestClientBasicAuthHandler::GetToken(std::string* token) {
641   *token = token_;
642   return Status::OK();
643 }
644 
GetTestResourceRoot(std::string * out)645 Status GetTestResourceRoot(std::string* out) {
646   const char* c_root = std::getenv("ARROW_TEST_DATA");
647   if (!c_root) {
648     return Status::IOError(
649         "Test resources not found, set ARROW_TEST_DATA to <repo root>/testing/data");
650   }
651   *out = std::string(c_root);
652   return Status::OK();
653 }
654 
ExampleTlsCertificates(std::vector<CertKeyPair> * out)655 Status ExampleTlsCertificates(std::vector<CertKeyPair>* out) {
656   std::string root;
657   RETURN_NOT_OK(GetTestResourceRoot(&root));
658 
659   *out = std::vector<CertKeyPair>();
660   for (int i = 0; i < 2; i++) {
661     try {
662       std::stringstream cert_path;
663       cert_path << root << "/flight/cert" << i << ".pem";
664       std::stringstream key_path;
665       key_path << root << "/flight/cert" << i << ".key";
666 
667       std::ifstream cert_file(cert_path.str());
668       if (!cert_file) {
669         return Status::IOError("Could not open certificate: " + cert_path.str());
670       }
671       std::stringstream cert;
672       cert << cert_file.rdbuf();
673 
674       std::ifstream key_file(key_path.str());
675       if (!key_file) {
676         return Status::IOError("Could not open key: " + key_path.str());
677       }
678       std::stringstream key;
679       key << key_file.rdbuf();
680 
681       out->push_back(CertKeyPair{cert.str(), key.str()});
682     } catch (const std::ifstream::failure& e) {
683       return Status::IOError(e.what());
684     }
685   }
686   return Status::OK();
687 }
688 
ExampleTlsCertificateRoot(CertKeyPair * out)689 Status ExampleTlsCertificateRoot(CertKeyPair* out) {
690   std::string root;
691   RETURN_NOT_OK(GetTestResourceRoot(&root));
692 
693   std::stringstream path;
694   path << root << "/flight/root-ca.pem";
695 
696   try {
697     std::ifstream cert_file(path.str());
698     if (!cert_file) {
699       return Status::IOError("Could not open certificate: " + path.str());
700     }
701     std::stringstream cert;
702     cert << cert_file.rdbuf();
703     out->pem_cert = cert.str();
704     out->pem_key = "";
705     return Status::OK();
706   } catch (const std::ifstream::failure& e) {
707     return Status::IOError(e.what());
708   }
709 }
710 
711 }  // namespace flight
712 }  // namespace arrow
713