1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/flight/platform.h"
19
20 #ifdef __APPLE__
21 #include <limits.h>
22 #include <mach-o/dyld.h>
23 #endif
24
25 #include <cstdlib>
26 #include <sstream>
27
28 #include <boost/filesystem.hpp>
29 // boost/process/detail/windows/handle_workaround.hpp doesn't work
30 // without BOOST_USE_WINDOWS_H with MinGW because MinGW doesn't
31 // provide __kernel_entry without winternl.h.
32 //
33 // See also:
34 // https://github.com/boostorg/process/blob/develop/include/boost/process/detail/windows/handle_workaround.hpp
35 #define BOOST_USE_WINDOWS_H 1
36 #include <boost/process.hpp>
37
38 #include <gtest/gtest.h>
39
40 #include "arrow/ipc/test_common.h"
41 #include "arrow/testing/gtest_util.h"
42 #include "arrow/util/logging.h"
43
44 #include "arrow/flight/api.h"
45 #include "arrow/flight/internal.h"
46 #include "arrow/flight/test_util.h"
47
48 namespace arrow {
49 namespace flight {
50
51 namespace bp = boost::process;
52 namespace fs = boost::filesystem;
53
54 namespace {
55
ResolveCurrentExecutable(fs::path * out)56 Status ResolveCurrentExecutable(fs::path* out) {
57 // See https://stackoverflow.com/a/1024937/10194 for various
58 // platform-specific recipes.
59
60 boost::system::error_code ec;
61
62 #if defined(__linux__)
63 *out = fs::canonical("/proc/self/exe", ec);
64 #elif defined(__APPLE__)
65 char buf[PATH_MAX + 1];
66 uint32_t bufsize = sizeof(buf);
67 if (_NSGetExecutablePath(buf, &bufsize) < 0) {
68 return Status::Invalid("Can't resolve current exe: path too large");
69 }
70 *out = fs::canonical(buf, ec);
71 #elif defined(_WIN32)
72 char buf[MAX_PATH + 1];
73 if (!GetModuleFileNameA(NULL, buf, sizeof(buf))) {
74 return Status::Invalid("Can't get executable file path");
75 }
76 *out = fs::canonical(buf, ec);
77 #else
78 ARROW_UNUSED(ec);
79 return Status::NotImplemented("Not available on this system");
80 #endif
81 if (ec) {
82 // XXX fold this into the Status class?
83 return Status::IOError("Can't resolve current exe: ", ec.message());
84 } else {
85 return Status::OK();
86 }
87 }
88
89 } // namespace
90
Start()91 void TestServer::Start() {
92 namespace fs = boost::filesystem;
93
94 std::string str_port = std::to_string(port_);
95 std::vector<fs::path> search_path = ::boost::this_process::path();
96 // If possible, prepend current executable directory to search path,
97 // since it's likely that the test server executable is located in
98 // the same directory as the running unit test.
99 fs::path current_exe;
100 Status st = ResolveCurrentExecutable(¤t_exe);
101 if (st.ok()) {
102 search_path.insert(search_path.begin(), current_exe.parent_path());
103 } else if (st.IsNotImplemented()) {
104 ARROW_CHECK(st.IsNotImplemented()) << st.ToString();
105 }
106
107 try {
108 server_process_ = std::make_shared<bp::child>(
109 bp::search_path(executable_name_, search_path), "-port", str_port);
110 } catch (...) {
111 std::stringstream ss;
112 ss << "Failed to launch test server '" << executable_name_ << "', looked in ";
113 for (const auto& path : search_path) {
114 ss << path << " : ";
115 }
116 ARROW_LOG(FATAL) << ss.str();
117 throw;
118 }
119 std::cout << "Server running with pid " << server_process_->id() << std::endl;
120 }
121
Stop()122 int TestServer::Stop() {
123 if (server_process_ && server_process_->valid()) {
124 #ifndef _WIN32
125 kill(server_process_->id(), SIGTERM);
126 #else
127 // This would use SIGKILL on POSIX, which is more brutal than SIGTERM
128 server_process_->terminate();
129 #endif
130 server_process_->wait();
131 return server_process_->exit_code();
132 } else {
133 // Presumably the server wasn't able to start
134 return -1;
135 }
136 }
137
IsRunning()138 bool TestServer::IsRunning() { return server_process_->running(); }
139
port() const140 int TestServer::port() const { return port_; }
141
GetBatchForFlight(const Ticket & ticket,std::shared_ptr<RecordBatchReader> * out)142 Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr<RecordBatchReader>* out) {
143 if (ticket.ticket == "ticket-ints-1") {
144 BatchVector batches;
145 RETURN_NOT_OK(ExampleIntBatches(&batches));
146 *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
147 return Status::OK();
148 } else if (ticket.ticket == "ticket-dicts-1") {
149 BatchVector batches;
150 RETURN_NOT_OK(ExampleDictBatches(&batches));
151 *out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
152 return Status::OK();
153 } else {
154 return Status::NotImplemented("no stream implemented for this ticket");
155 }
156 }
157
158 class FlightTestServer : public FlightServerBase {
ListFlights(const ServerCallContext & context,const Criteria * criteria,std::unique_ptr<FlightListing> * listings)159 Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
160 std::unique_ptr<FlightListing>* listings) override {
161 std::vector<FlightInfo> flights = ExampleFlightInfo();
162 if (criteria && criteria->expression != "") {
163 // For test purposes, if we get criteria, return no results
164 flights.clear();
165 }
166 *listings = std::unique_ptr<FlightListing>(new SimpleFlightListing(flights));
167 return Status::OK();
168 }
169
GetFlightInfo(const ServerCallContext & context,const FlightDescriptor & request,std::unique_ptr<FlightInfo> * out)170 Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
171 std::unique_ptr<FlightInfo>* out) override {
172 // Test that Arrow-C++ status codes can make it through gRPC
173 if (request.type == FlightDescriptor::DescriptorType::CMD &&
174 request.cmd == "status-outofmemory") {
175 return Status::OutOfMemory("Sentinel");
176 }
177
178 std::vector<FlightInfo> flights = ExampleFlightInfo();
179
180 for (const auto& info : flights) {
181 if (info.descriptor().Equals(request)) {
182 *out = std::unique_ptr<FlightInfo>(new FlightInfo(info));
183 return Status::OK();
184 }
185 }
186 return Status::Invalid("Flight not found: ", request.ToString());
187 }
188
DoGet(const ServerCallContext & context,const Ticket & request,std::unique_ptr<FlightDataStream> * data_stream)189 Status DoGet(const ServerCallContext& context, const Ticket& request,
190 std::unique_ptr<FlightDataStream>* data_stream) override {
191 // Test for ARROW-5095
192 if (request.ticket == "ARROW-5095-fail") {
193 return Status::UnknownError("Server-side error");
194 }
195 if (request.ticket == "ARROW-5095-success") {
196 return Status::OK();
197 }
198
199 std::shared_ptr<RecordBatchReader> batch_reader;
200 RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader));
201
202 *data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader));
203 return Status::OK();
204 }
205
DoExchange(const ServerCallContext & context,std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMessageWriter> writer)206 Status DoExchange(const ServerCallContext& context,
207 std::unique_ptr<FlightMessageReader> reader,
208 std::unique_ptr<FlightMessageWriter> writer) override {
209 // Test various scenarios for a DoExchange
210 if (reader->descriptor().type != FlightDescriptor::DescriptorType::CMD) {
211 return Status::Invalid("Must provide a command descriptor");
212 }
213
214 const std::string& cmd = reader->descriptor().cmd;
215 if (cmd == "error") {
216 // Immediately return an error to the client.
217 return Status::NotImplemented("Expected error");
218 } else if (cmd == "get") {
219 return RunExchangeGet(std::move(reader), std::move(writer));
220 } else if (cmd == "put") {
221 return RunExchangePut(std::move(reader), std::move(writer));
222 } else if (cmd == "counter") {
223 return RunExchangeCounter(std::move(reader), std::move(writer));
224 } else if (cmd == "total") {
225 return RunExchangeTotal(std::move(reader), std::move(writer));
226 } else if (cmd == "echo") {
227 return RunExchangeEcho(std::move(reader), std::move(writer));
228 } else {
229 return Status::NotImplemented("Scenario not implemented: ", cmd);
230 }
231 }
232
233 // A simple example - act like DoGet.
RunExchangeGet(std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMessageWriter> writer)234 Status RunExchangeGet(std::unique_ptr<FlightMessageReader> reader,
235 std::unique_ptr<FlightMessageWriter> writer) {
236 RETURN_NOT_OK(writer->Begin(ExampleIntSchema()));
237 BatchVector batches;
238 RETURN_NOT_OK(ExampleIntBatches(&batches));
239 for (const auto& batch : batches) {
240 RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
241 }
242 return Status::OK();
243 }
244
245 // A simple example - act like DoPut
RunExchangePut(std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMessageWriter> writer)246 Status RunExchangePut(std::unique_ptr<FlightMessageReader> reader,
247 std::unique_ptr<FlightMessageWriter> writer) {
248 ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
249 if (!schema->Equals(ExampleIntSchema(), false)) {
250 return Status::Invalid("Schema is not as expected");
251 }
252 BatchVector batches;
253 RETURN_NOT_OK(ExampleIntBatches(&batches));
254 FlightStreamChunk chunk;
255 for (const auto& batch : batches) {
256 RETURN_NOT_OK(reader->Next(&chunk));
257 if (!chunk.data) {
258 return Status::Invalid("Expected another batch");
259 }
260 if (!batch->Equals(*chunk.data)) {
261 return Status::Invalid("Batch does not match");
262 }
263 }
264 RETURN_NOT_OK(reader->Next(&chunk));
265 if (chunk.data || chunk.app_metadata) {
266 return Status::Invalid("Too many batches");
267 }
268
269 RETURN_NOT_OK(writer->WriteMetadata(Buffer::FromString("done")));
270 return Status::OK();
271 }
272
273 // Read some number of record batches from the client, send a
274 // metadata message back with the count, then echo the batches back.
RunExchangeCounter(std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMessageWriter> writer)275 Status RunExchangeCounter(std::unique_ptr<FlightMessageReader> reader,
276 std::unique_ptr<FlightMessageWriter> writer) {
277 std::vector<std::shared_ptr<RecordBatch>> batches;
278 FlightStreamChunk chunk;
279 int chunks = 0;
280 while (true) {
281 RETURN_NOT_OK(reader->Next(&chunk));
282 if (!chunk.data && !chunk.app_metadata) {
283 break;
284 }
285 if (chunk.data) {
286 batches.push_back(chunk.data);
287 chunks++;
288 }
289 }
290
291 // Echo back the number of record batches read.
292 std::shared_ptr<Buffer> buf = Buffer::FromString(std::to_string(chunks));
293 RETURN_NOT_OK(writer->WriteMetadata(buf));
294 // Echo the record batches themselves.
295 if (chunks > 0) {
296 ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
297 RETURN_NOT_OK(writer->Begin(schema));
298
299 for (const auto& batch : batches) {
300 RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
301 }
302 }
303
304 return Status::OK();
305 }
306
307 // Read int64 batches from the client, each time sending back a
308 // batch with a running sum of columns.
RunExchangeTotal(std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMessageWriter> writer)309 Status RunExchangeTotal(std::unique_ptr<FlightMessageReader> reader,
310 std::unique_ptr<FlightMessageWriter> writer) {
311 FlightStreamChunk chunk{};
312 ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema());
313 // Ensure the schema contains only int64 columns
314 for (const auto& field : schema->fields()) {
315 if (field->type()->id() != Type::type::INT64) {
316 return Status::Invalid("Field is not INT64: ", field->name());
317 }
318 }
319 std::vector<int64_t> sums(schema->num_fields());
320 std::vector<std::shared_ptr<Array>> columns(schema->num_fields());
321 RETURN_NOT_OK(writer->Begin(schema));
322 while (true) {
323 RETURN_NOT_OK(reader->Next(&chunk));
324 if (!chunk.data && !chunk.app_metadata) {
325 break;
326 }
327 if (chunk.data) {
328 if (!chunk.data->schema()->Equals(schema, false)) {
329 // A compliant client implementation would make this impossible
330 return Status::Invalid("Schemas are incompatible");
331 }
332
333 // Update the running totals
334 auto builder = std::make_shared<Int64Builder>();
335 int col_index = 0;
336 for (const auto& column : chunk.data->columns()) {
337 auto arr = std::dynamic_pointer_cast<Int64Array>(column);
338 if (!arr) {
339 return MakeFlightError(FlightStatusCode::Internal, "Could not cast array");
340 }
341 for (int row = 0; row < column->length(); row++) {
342 if (!arr->IsNull(row)) {
343 sums[col_index] += arr->Value(row);
344 }
345 }
346
347 builder->Reset();
348 RETURN_NOT_OK(builder->Append(sums[col_index]));
349 RETURN_NOT_OK(builder->Finish(&columns[col_index]));
350
351 col_index++;
352 }
353
354 // Echo the totals to the client
355 auto response = RecordBatch::Make(schema, /* num_rows */ 1, columns);
356 RETURN_NOT_OK(writer->WriteRecordBatch(*response));
357 }
358 }
359 return Status::OK();
360 }
361
362 // Echo the client's messages back.
RunExchangeEcho(std::unique_ptr<FlightMessageReader> reader,std::unique_ptr<FlightMessageWriter> writer)363 Status RunExchangeEcho(std::unique_ptr<FlightMessageReader> reader,
364 std::unique_ptr<FlightMessageWriter> writer) {
365 FlightStreamChunk chunk;
366 bool begun = false;
367 while (true) {
368 RETURN_NOT_OK(reader->Next(&chunk));
369 if (!chunk.data && !chunk.app_metadata) {
370 break;
371 }
372 if (!begun && chunk.data) {
373 begun = true;
374 RETURN_NOT_OK(writer->Begin(chunk.data->schema()));
375 }
376 if (chunk.data && chunk.app_metadata) {
377 RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata));
378 } else if (chunk.data) {
379 RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data));
380 } else if (chunk.app_metadata) {
381 RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata));
382 }
383 }
384 return Status::OK();
385 }
386
RunAction1(const Action & action,std::unique_ptr<ResultStream> * out)387 Status RunAction1(const Action& action, std::unique_ptr<ResultStream>* out) {
388 std::vector<Result> results;
389 for (int i = 0; i < 3; ++i) {
390 Result result;
391 std::string value = action.body->ToString() + "-part" + std::to_string(i);
392 result.body = Buffer::FromString(std::move(value));
393 results.push_back(result);
394 }
395 *out = std::unique_ptr<ResultStream>(new SimpleResultStream(std::move(results)));
396 return Status::OK();
397 }
398
RunAction2(std::unique_ptr<ResultStream> * out)399 Status RunAction2(std::unique_ptr<ResultStream>* out) {
400 // Empty
401 *out = std::unique_ptr<ResultStream>(new SimpleResultStream({}));
402 return Status::OK();
403 }
404
DoAction(const ServerCallContext & context,const Action & action,std::unique_ptr<ResultStream> * out)405 Status DoAction(const ServerCallContext& context, const Action& action,
406 std::unique_ptr<ResultStream>* out) override {
407 if (action.type == "action1") {
408 return RunAction1(action, out);
409 } else if (action.type == "action2") {
410 return RunAction2(out);
411 } else {
412 return Status::NotImplemented(action.type);
413 }
414 }
415
ListActions(const ServerCallContext & context,std::vector<ActionType> * out)416 Status ListActions(const ServerCallContext& context,
417 std::vector<ActionType>* out) override {
418 std::vector<ActionType> actions = ExampleActionTypes();
419 *out = std::move(actions);
420 return Status::OK();
421 }
422
GetSchema(const ServerCallContext & context,const FlightDescriptor & request,std::unique_ptr<SchemaResult> * schema)423 Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request,
424 std::unique_ptr<SchemaResult>* schema) override {
425 std::vector<FlightInfo> flights = ExampleFlightInfo();
426
427 for (const auto& info : flights) {
428 if (info.descriptor().Equals(request)) {
429 *schema =
430 std::unique_ptr<SchemaResult>(new SchemaResult(info.serialized_schema()));
431 return Status::OK();
432 }
433 }
434 return Status::Invalid("Flight not found: ", request.ToString());
435 }
436 };
437
ExampleTestServer()438 std::unique_ptr<FlightServerBase> ExampleTestServer() {
439 return std::unique_ptr<FlightServerBase>(new FlightTestServer);
440 }
441
MakeFlightInfo(const Schema & schema,const FlightDescriptor & descriptor,const std::vector<FlightEndpoint> & endpoints,int64_t total_records,int64_t total_bytes,FlightInfo::Data * out)442 Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor,
443 const std::vector<FlightEndpoint>& endpoints, int64_t total_records,
444 int64_t total_bytes, FlightInfo::Data* out) {
445 out->descriptor = descriptor;
446 out->endpoints = endpoints;
447 out->total_records = total_records;
448 out->total_bytes = total_bytes;
449 return internal::SchemaToString(schema, &out->schema);
450 }
451
NumberingStream(std::unique_ptr<FlightDataStream> stream)452 NumberingStream::NumberingStream(std::unique_ptr<FlightDataStream> stream)
453 : counter_(0), stream_(std::move(stream)) {}
454
schema()455 std::shared_ptr<Schema> NumberingStream::schema() { return stream_->schema(); }
456
GetSchemaPayload(FlightPayload * payload)457 Status NumberingStream::GetSchemaPayload(FlightPayload* payload) {
458 return stream_->GetSchemaPayload(payload);
459 }
460
Next(FlightPayload * payload)461 Status NumberingStream::Next(FlightPayload* payload) {
462 RETURN_NOT_OK(stream_->Next(payload));
463 if (payload && payload->ipc_message.type == ipc::Message::RECORD_BATCH) {
464 payload->app_metadata = Buffer::FromString(std::to_string(counter_));
465 counter_++;
466 }
467 return Status::OK();
468 }
469
ExampleIntSchema()470 std::shared_ptr<Schema> ExampleIntSchema() {
471 auto f0 = field("f0", int32());
472 auto f1 = field("f1", int32());
473 return ::arrow::schema({f0, f1});
474 }
475
ExampleStringSchema()476 std::shared_ptr<Schema> ExampleStringSchema() {
477 auto f0 = field("f0", utf8());
478 auto f1 = field("f1", binary());
479 return ::arrow::schema({f0, f1});
480 }
481
ExampleDictSchema()482 std::shared_ptr<Schema> ExampleDictSchema() {
483 std::shared_ptr<RecordBatch> batch;
484 ABORT_NOT_OK(ipc::test::MakeDictionary(&batch));
485 return batch->schema();
486 }
487
ExampleFlightInfo()488 std::vector<FlightInfo> ExampleFlightInfo() {
489 Location location1;
490 Location location2;
491 Location location3;
492 Location location4;
493 ARROW_EXPECT_OK(Location::ForGrpcTcp("foo1.bar.com", 12345, &location1));
494 ARROW_EXPECT_OK(Location::ForGrpcTcp("foo2.bar.com", 12345, &location2));
495 ARROW_EXPECT_OK(Location::ForGrpcTcp("foo3.bar.com", 12345, &location3));
496 ARROW_EXPECT_OK(Location::ForGrpcTcp("foo4.bar.com", 12345, &location4));
497
498 FlightInfo::Data flight1, flight2, flight3;
499
500 FlightEndpoint endpoint1({{"ticket-ints-1"}, {location1}});
501 FlightEndpoint endpoint2({{"ticket-ints-2"}, {location2}});
502 FlightEndpoint endpoint3({{"ticket-cmd"}, {location3}});
503 FlightEndpoint endpoint4({{"ticket-dicts-1"}, {location4}});
504
505 FlightDescriptor descr1{FlightDescriptor::PATH, "", {"examples", "ints"}};
506 FlightDescriptor descr2{FlightDescriptor::CMD, "my_command", {}};
507 FlightDescriptor descr3{FlightDescriptor::PATH, "", {"examples", "dicts"}};
508
509 auto schema1 = ExampleIntSchema();
510 auto schema2 = ExampleStringSchema();
511 auto schema3 = ExampleDictSchema();
512
513 ARROW_EXPECT_OK(
514 MakeFlightInfo(*schema1, descr1, {endpoint1, endpoint2}, 1000, 100000, &flight1));
515 ARROW_EXPECT_OK(MakeFlightInfo(*schema2, descr2, {endpoint3}, 1000, 100000, &flight2));
516 ARROW_EXPECT_OK(MakeFlightInfo(*schema3, descr3, {endpoint4}, -1, -1, &flight3));
517 return {FlightInfo(flight1), FlightInfo(flight2), FlightInfo(flight3)};
518 }
519
ExampleIntBatches(BatchVector * out)520 Status ExampleIntBatches(BatchVector* out) {
521 std::shared_ptr<RecordBatch> batch;
522 for (int i = 0; i < 5; ++i) {
523 // Make all different sizes, use different random seed
524 RETURN_NOT_OK(ipc::test::MakeIntBatchSized(10 + i, &batch, i));
525 out->push_back(batch);
526 }
527 return Status::OK();
528 }
529
ExampleDictBatches(BatchVector * out)530 Status ExampleDictBatches(BatchVector* out) {
531 // Just the same batch, repeated a few times
532 std::shared_ptr<RecordBatch> batch;
533 for (int i = 0; i < 3; ++i) {
534 RETURN_NOT_OK(ipc::test::MakeDictionary(&batch));
535 out->push_back(batch);
536 }
537 return Status::OK();
538 }
539
ExampleActionTypes()540 std::vector<ActionType> ExampleActionTypes() {
541 return {{"drop", "drop a dataset"}, {"cache", "cache a dataset"}};
542 }
543
TestServerAuthHandler(const std::string & username,const std::string & password)544 TestServerAuthHandler::TestServerAuthHandler(const std::string& username,
545 const std::string& password)
546 : username_(username), password_(password) {}
547
~TestServerAuthHandler()548 TestServerAuthHandler::~TestServerAuthHandler() {}
549
Authenticate(ServerAuthSender * outgoing,ServerAuthReader * incoming)550 Status TestServerAuthHandler::Authenticate(ServerAuthSender* outgoing,
551 ServerAuthReader* incoming) {
552 std::string token;
553 RETURN_NOT_OK(incoming->Read(&token));
554 if (token != password_) {
555 return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
556 }
557 RETURN_NOT_OK(outgoing->Write(username_));
558 return Status::OK();
559 }
560
IsValid(const std::string & token,std::string * peer_identity)561 Status TestServerAuthHandler::IsValid(const std::string& token,
562 std::string* peer_identity) {
563 if (token != password_) {
564 return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
565 }
566 *peer_identity = username_;
567 return Status::OK();
568 }
569
TestServerBasicAuthHandler(const std::string & username,const std::string & password)570 TestServerBasicAuthHandler::TestServerBasicAuthHandler(const std::string& username,
571 const std::string& password) {
572 basic_auth_.username = username;
573 basic_auth_.password = password;
574 }
575
~TestServerBasicAuthHandler()576 TestServerBasicAuthHandler::~TestServerBasicAuthHandler() {}
577
Authenticate(ServerAuthSender * outgoing,ServerAuthReader * incoming)578 Status TestServerBasicAuthHandler::Authenticate(ServerAuthSender* outgoing,
579 ServerAuthReader* incoming) {
580 std::string token;
581 RETURN_NOT_OK(incoming->Read(&token));
582 BasicAuth incoming_auth;
583 RETURN_NOT_OK(BasicAuth::Deserialize(token, &incoming_auth));
584 if (incoming_auth.username != basic_auth_.username ||
585 incoming_auth.password != basic_auth_.password) {
586 return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
587 }
588 RETURN_NOT_OK(outgoing->Write(basic_auth_.username));
589 return Status::OK();
590 }
591
IsValid(const std::string & token,std::string * peer_identity)592 Status TestServerBasicAuthHandler::IsValid(const std::string& token,
593 std::string* peer_identity) {
594 if (token != basic_auth_.username) {
595 return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
596 }
597 *peer_identity = basic_auth_.username;
598 return Status::OK();
599 }
600
TestClientAuthHandler(const std::string & username,const std::string & password)601 TestClientAuthHandler::TestClientAuthHandler(const std::string& username,
602 const std::string& password)
603 : username_(username), password_(password) {}
604
~TestClientAuthHandler()605 TestClientAuthHandler::~TestClientAuthHandler() {}
606
Authenticate(ClientAuthSender * outgoing,ClientAuthReader * incoming)607 Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing,
608 ClientAuthReader* incoming) {
609 RETURN_NOT_OK(outgoing->Write(password_));
610 std::string username;
611 RETURN_NOT_OK(incoming->Read(&username));
612 if (username != username_) {
613 return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token");
614 }
615 return Status::OK();
616 }
617
GetToken(std::string * token)618 Status TestClientAuthHandler::GetToken(std::string* token) {
619 *token = password_;
620 return Status::OK();
621 }
622
TestClientBasicAuthHandler(const std::string & username,const std::string & password)623 TestClientBasicAuthHandler::TestClientBasicAuthHandler(const std::string& username,
624 const std::string& password) {
625 basic_auth_.username = username;
626 basic_auth_.password = password;
627 }
628
~TestClientBasicAuthHandler()629 TestClientBasicAuthHandler::~TestClientBasicAuthHandler() {}
630
Authenticate(ClientAuthSender * outgoing,ClientAuthReader * incoming)631 Status TestClientBasicAuthHandler::Authenticate(ClientAuthSender* outgoing,
632 ClientAuthReader* incoming) {
633 std::string pb_result;
634 RETURN_NOT_OK(BasicAuth::Serialize(basic_auth_, &pb_result));
635 RETURN_NOT_OK(outgoing->Write(pb_result));
636 RETURN_NOT_OK(incoming->Read(&token_));
637 return Status::OK();
638 }
639
GetToken(std::string * token)640 Status TestClientBasicAuthHandler::GetToken(std::string* token) {
641 *token = token_;
642 return Status::OK();
643 }
644
GetTestResourceRoot(std::string * out)645 Status GetTestResourceRoot(std::string* out) {
646 const char* c_root = std::getenv("ARROW_TEST_DATA");
647 if (!c_root) {
648 return Status::IOError(
649 "Test resources not found, set ARROW_TEST_DATA to <repo root>/testing/data");
650 }
651 *out = std::string(c_root);
652 return Status::OK();
653 }
654
ExampleTlsCertificates(std::vector<CertKeyPair> * out)655 Status ExampleTlsCertificates(std::vector<CertKeyPair>* out) {
656 std::string root;
657 RETURN_NOT_OK(GetTestResourceRoot(&root));
658
659 *out = std::vector<CertKeyPair>();
660 for (int i = 0; i < 2; i++) {
661 try {
662 std::stringstream cert_path;
663 cert_path << root << "/flight/cert" << i << ".pem";
664 std::stringstream key_path;
665 key_path << root << "/flight/cert" << i << ".key";
666
667 std::ifstream cert_file(cert_path.str());
668 if (!cert_file) {
669 return Status::IOError("Could not open certificate: " + cert_path.str());
670 }
671 std::stringstream cert;
672 cert << cert_file.rdbuf();
673
674 std::ifstream key_file(key_path.str());
675 if (!key_file) {
676 return Status::IOError("Could not open key: " + key_path.str());
677 }
678 std::stringstream key;
679 key << key_file.rdbuf();
680
681 out->push_back(CertKeyPair{cert.str(), key.str()});
682 } catch (const std::ifstream::failure& e) {
683 return Status::IOError(e.what());
684 }
685 }
686 return Status::OK();
687 }
688
ExampleTlsCertificateRoot(CertKeyPair * out)689 Status ExampleTlsCertificateRoot(CertKeyPair* out) {
690 std::string root;
691 RETURN_NOT_OK(GetTestResourceRoot(&root));
692
693 std::stringstream path;
694 path << root << "/flight/root-ca.pem";
695
696 try {
697 std::ifstream cert_file(path.str());
698 if (!cert_file) {
699 return Status::IOError("Could not open certificate: " + path.str());
700 }
701 std::stringstream cert;
702 cert << cert_file.rdbuf();
703 out->pem_cert = cert.str();
704 out->pem_key = "";
705 return Status::OK();
706 } catch (const std::ifstream::failure& e) {
707 return Status::IOError(e.what());
708 }
709 }
710
711 } // namespace flight
712 } // namespace arrow
713