1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include <cstdint>
19 #include <mutex>
20 #include <sstream>
21 #include <string>
22 #include <vector>
23 
24 #include <gflags/gflags.h>
25 
26 #include "arrow/io/file.h"
27 #include "arrow/io/memory.h"
28 #include "arrow/ipc/api.h"
29 #include "arrow/record_batch.h"
30 #include "arrow/testing/gtest_util.h"
31 #include "arrow/util/compression.h"
32 #include "arrow/util/stopwatch.h"
33 #include "arrow/util/tdigest.h"
34 #include "arrow/util/thread_pool.h"
35 
36 #include "arrow/flight/api.h"
37 #include "arrow/flight/perf.pb.h"
38 #include "arrow/flight/test_util.h"
39 
40 DEFINE_string(server_host, "",
41               "An existing performance server to benchmark against (leave blank to spawn "
42               "one automatically)");
43 DEFINE_int32(server_port, 31337, "The port to connect to");
44 DEFINE_string(server_unix, "",
45               "An existing performance server listening on Unix socket (leave blank to "
46               "spawn one automatically)");
47 DEFINE_bool(test_unix, false, "Test Unix socket instead of TCP");
48 DEFINE_int32(num_perf_runs, 1,
49              "Number of times to run the perf test to "
50              "increase precision");
51 DEFINE_int32(num_servers, 1, "Number of performance servers to run");
52 DEFINE_int32(num_streams, 4, "Number of streams for each server");
53 DEFINE_int32(num_threads, 4, "Number of concurrent gets");
54 DEFINE_int64(records_per_stream, 10000000, "Total records per stream");
55 DEFINE_int32(records_per_batch, 4096, "Total records per batch within stream");
56 DEFINE_bool(test_put, false, "Test DoPut instead of DoGet");
57 DEFINE_string(compression, "",
58               "Select compression method (\"zstd\", \"lz4\"). "
59               "Leave blank to disable compression.\n"
60               "E.g., \"zstd\":   zstd with default compression level.\n"
61               "      \"zstd:7\": zstd with compression leve = 7.\n");
62 DEFINE_string(
63     data_file, "",
64     "Instead of random data, use data from the given IPC file. Only affects -test_put.");
65 DEFINE_string(cert_file, "", "Path to TLS certificate");
66 DEFINE_string(key_file, "", "Path to TLS private key (used when spawning a server)");
67 
68 namespace perf = arrow::flight::perf;
69 
70 namespace arrow {
71 
72 using internal::StopWatch;
73 using internal::ThreadPool;
74 
75 namespace flight {
76 
// Totals produced by a single stream (one DoGet or DoPut call).
//
// Kept as a plain aggregate: it is built with brace initialization in the
// order {num_batches, num_records, num_bytes}.
struct PerformanceResult {
  // Number of record batches transferred on the stream.
  int64_t num_batches;
  // Number of rows transferred on the stream.
  int64_t num_records;
  // Number of payload bytes accounted for the stream.
  int64_t num_bytes;
};
82 
83 struct PerformanceStats {
84   std::mutex mutex;
85   int64_t total_batches = 0;
86   int64_t total_records = 0;
87   int64_t total_bytes = 0;
88   const std::array<double, 3> quantiles = {0.5, 0.95, 0.99};
89   mutable arrow::internal::TDigest latencies;
90 
Updatearrow::flight::PerformanceStats91   void Update(int64_t total_batches, int64_t total_records, int64_t total_bytes) {
92     std::lock_guard<std::mutex> lock(this->mutex);
93     this->total_batches += total_batches;
94     this->total_records += total_records;
95     this->total_bytes += total_bytes;
96   }
97 
98   // Invoked per batch in the test loop. Holding a lock looks not scalable.
99   // Tested with 1 ~ 8 threads, no noticeable overhead is observed.
100   // A better approach may be calculate per-thread quantiles and merge.
AddLatencyarrow::flight::PerformanceStats101   void AddLatency(uint64_t elapsed_nanos) {
102     std::lock_guard<std::mutex> lock(this->mutex);
103     latencies.Add(static_cast<double>(elapsed_nanos));
104   }
105 
106   // ns -> us
max_latencyarrow::flight::PerformanceStats107   uint64_t max_latency() const { return latencies.Max() / 1000; }
108 
mean_latencyarrow::flight::PerformanceStats109   uint64_t mean_latency() const { return latencies.Mean() / 1000; }
110 
quantile_latencyarrow::flight::PerformanceStats111   uint64_t quantile_latency(double q) const { return latencies.Quantile(q) / 1000; }
112 };
113 
WaitForReady(FlightClient * client,const FlightCallOptions & call_options)114 Status WaitForReady(FlightClient* client, const FlightCallOptions& call_options) {
115   Action action{"ping", nullptr};
116   for (int attempt = 0; attempt < 10; attempt++) {
117     std::unique_ptr<ResultStream> stream;
118     if (client->DoAction(call_options, action, &stream).ok()) {
119       return Status::OK();
120     }
121     std::this_thread::sleep_for(std::chrono::milliseconds(1000));
122   }
123   return Status::IOError("Server was not available after 10 attempts");
124 }
125 
RunDoGetTest(FlightClient * client,const FlightCallOptions & call_options,const perf::Token & token,const FlightEndpoint & endpoint,PerformanceStats * stats)126 arrow::Result<PerformanceResult> RunDoGetTest(FlightClient* client,
127                                               const FlightCallOptions& call_options,
128                                               const perf::Token& token,
129                                               const FlightEndpoint& endpoint,
130                                               PerformanceStats* stats) {
131   std::unique_ptr<FlightStreamReader> reader;
132   RETURN_NOT_OK(client->DoGet(call_options, endpoint.ticket, &reader));
133 
134   FlightStreamChunk batch;
135 
136   // This is hard-coded for right now, 4 columns each with int64
137   const int bytes_per_record = 32;
138 
139   // This must also be set in perf_server.cc
140   const bool verify = false;
141 
142   int64_t num_bytes = 0;
143   int64_t num_records = 0;
144   int64_t num_batches = 0;
145   StopWatch timer;
146   while (true) {
147     timer.Start();
148     RETURN_NOT_OK(reader->Next(&batch));
149     stats->AddLatency(timer.Stop());
150     if (!batch.data) {
151       break;
152     }
153 
154     if (verify) {
155       auto values = batch.data->column_data(0)->GetValues<int64_t>(1);
156       const int64_t start = token.start() + num_records;
157       for (int64_t i = 0; i < batch.data->num_rows(); ++i) {
158         if (values[i] != start + i) {
159           return Status::Invalid("verification failure");
160         }
161       }
162     }
163 
164     ++num_batches;
165     num_records += batch.data->num_rows();
166 
167     // Hard-coded
168     num_bytes += batch.data->num_rows() * bytes_per_record;
169   }
170   return PerformanceResult{num_batches, num_records, num_bytes};
171 }
172 
173 struct SizedBatch {
174   std::shared_ptr<arrow::RecordBatch> batch;
175   int64_t bytes;
176 };
177 
GetPutData(const perf::Token & token)178 arrow::Result<std::vector<SizedBatch>> GetPutData(const perf::Token& token) {
179   if (!FLAGS_data_file.empty()) {
180     ARROW_ASSIGN_OR_RAISE(auto file, arrow::io::ReadableFile::Open(FLAGS_data_file));
181     ARROW_ASSIGN_OR_RAISE(auto reader,
182                           arrow::ipc::RecordBatchFileReader::Open(std::move(file)));
183     std::vector<SizedBatch> batches(reader->num_record_batches());
184     for (int i = 0; i < reader->num_record_batches(); i++) {
185       ARROW_ASSIGN_OR_RAISE(batches[i].batch, reader->ReadRecordBatch(i));
186       RETURN_NOT_OK(arrow::ipc::GetRecordBatchSize(*batches[i].batch, &batches[i].bytes));
187     }
188     return batches;
189   }
190 
191   std::shared_ptr<Schema> schema =
192       arrow::schema({field("a", int64()), field("b", int64()), field("c", int64()),
193                      field("d", int64())});
194 
195   // This is hard-coded for right now, 4 columns each with int64
196   const int bytes_per_record = 32;
197 
198   std::shared_ptr<ResizableBuffer> buffer;
199   std::vector<std::shared_ptr<Array>> arrays;
200 
201   const int64_t total_records = token.definition().records_per_stream();
202   const int32_t length = token.definition().records_per_batch();
203   const int32_t ncolumns = 4;
204   for (int i = 0; i < ncolumns; ++i) {
205     RETURN_NOT_OK(MakeRandomByteBuffer(length * sizeof(int64_t), default_memory_pool(),
206                                        &buffer, static_cast<int32_t>(i) /* seed */));
207     arrays.push_back(std::make_shared<Int64Array>(length, buffer));
208     RETURN_NOT_OK(arrays.back()->Validate());
209   }
210 
211   std::shared_ptr<RecordBatch> batch = RecordBatch::Make(schema, length, arrays);
212   std::vector<SizedBatch> batches;
213 
214   int64_t records_sent = 0;
215   while (records_sent < total_records) {
216     if (records_sent + length > total_records) {
217       const int last_length = total_records - records_sent;
218       // Hard-coded
219       batches.push_back(SizedBatch{batch->Slice(0, last_length),
220                                    /*bytes=*/last_length * bytes_per_record});
221       records_sent += last_length;
222     } else {
223       // Hard-coded
224       batches.push_back(SizedBatch{batch, /*bytes=*/length * bytes_per_record});
225       records_sent += length;
226     }
227   }
228   return batches;
229 }
230 
RunDoPutTest(FlightClient * client,const FlightCallOptions & call_options,const perf::Token & token,const FlightEndpoint & endpoint,PerformanceStats * stats)231 arrow::Result<PerformanceResult> RunDoPutTest(FlightClient* client,
232                                               const FlightCallOptions& call_options,
233                                               const perf::Token& token,
234                                               const FlightEndpoint& endpoint,
235                                               PerformanceStats* stats) {
236   ARROW_ASSIGN_OR_RAISE(const auto batches, GetPutData(token));
237   StopWatch timer;
238   int64_t num_records = 0;
239   int64_t num_bytes = 0;
240   std::unique_ptr<FlightStreamWriter> writer;
241   std::unique_ptr<FlightMetadataReader> reader;
242   RETURN_NOT_OK(client->DoPut(call_options, FlightDescriptor{},
243                               batches[0].batch->schema(), &writer, &reader));
244   for (size_t i = 0; i < batches.size(); i++) {
245     auto batch = batches[i];
246     auto is_last = i == (batches.size() - 1);
247     if (is_last) {
248       RETURN_NOT_OK(writer->WriteRecordBatch(*batch.batch));
249       num_records += batch.batch->num_rows();
250       num_bytes += batch.bytes;
251     } else {
252       timer.Start();
253       RETURN_NOT_OK(writer->WriteRecordBatch(*batch.batch));
254       stats->AddLatency(timer.Stop());
255       num_records += batch.batch->num_rows();
256       num_bytes += batch.bytes;
257     }
258   }
259   RETURN_NOT_OK(writer->Close());
260   return PerformanceResult{static_cast<int64_t>(batches.size()), num_records, num_bytes};
261 }
262 
DoSinglePerfRun(FlightClient * client,const FlightClientOptions client_options,const FlightCallOptions & call_options,bool test_put,PerformanceStats * stats)263 Status DoSinglePerfRun(FlightClient* client, const FlightClientOptions client_options,
264                        const FlightCallOptions& call_options, bool test_put,
265                        PerformanceStats* stats) {
266   // schema not needed
267   perf::Perf perf;
268   perf.set_stream_count(FLAGS_num_streams);
269   perf.set_records_per_stream(FLAGS_records_per_stream);
270   perf.set_records_per_batch(FLAGS_records_per_batch);
271 
272   // Plan the query
273   FlightDescriptor descriptor;
274   descriptor.type = FlightDescriptor::CMD;
275   perf.SerializeToString(&descriptor.cmd);
276 
277   std::unique_ptr<FlightInfo> plan;
278   RETURN_NOT_OK(client->GetFlightInfo(call_options, descriptor, &plan));
279 
280   // Read the streams in parallel
281   std::shared_ptr<Schema> schema;
282   ipc::DictionaryMemo dict_memo;
283   RETURN_NOT_OK(plan->GetSchema(&dict_memo, &schema));
284 
285   int64_t start_total_records = stats->total_records;
286 
287   auto test_loop = test_put ? &RunDoPutTest : &RunDoGetTest;
288   auto ConsumeStream = [&stats, &test_loop, &client_options,
289                         &call_options](const FlightEndpoint& endpoint) {
290     std::unique_ptr<FlightClient> client;
291     RETURN_NOT_OK(
292         FlightClient::Connect(endpoint.locations.front(), client_options, &client));
293 
294     perf::Token token;
295     token.ParseFromString(endpoint.ticket.ticket);
296 
297     const auto& result = test_loop(client.get(), call_options, token, endpoint, stats);
298     if (result.ok()) {
299       const PerformanceResult& perf = result.ValueOrDie();
300       stats->Update(perf.num_batches, perf.num_records, perf.num_bytes);
301     }
302     return result.status();
303   };
304 
305   // XXX(wesm): Serial version for debugging
306   // for (const auto& endpoint : plan->endpoints()) {
307   //   RETURN_NOT_OK(ConsumeStream(endpoint));
308   // }
309 
310   ARROW_ASSIGN_OR_RAISE(auto pool, ThreadPool::Make(FLAGS_num_threads));
311   std::vector<Future<>> tasks;
312   for (const auto& endpoint : plan->endpoints()) {
313     ARROW_ASSIGN_OR_RAISE(auto task, pool->Submit(ConsumeStream, endpoint));
314     tasks.push_back(std::move(task));
315   }
316 
317   // Wait for tasks to finish
318   for (auto&& task : tasks) {
319     RETURN_NOT_OK(task.status());
320   }
321 
322   if (FLAGS_data_file.empty()) {
323     // Check that number of rows read / written is as expected
324     int64_t records_for_run = stats->total_records - start_total_records;
325     if (records_for_run != static_cast<int64_t>(plan->total_records())) {
326       return Status::Invalid("Did not consume expected number of records");
327     }
328   }
329   return Status::OK();
330 }
331 
RunPerformanceTest(FlightClient * client,const FlightClientOptions & client_options,const FlightCallOptions & call_options,bool test_put)332 Status RunPerformanceTest(FlightClient* client, const FlightClientOptions& client_options,
333                           const FlightCallOptions& call_options, bool test_put) {
334   StopWatch timer;
335   timer.Start();
336 
337   PerformanceStats stats;
338   for (int i = 0; i < FLAGS_num_perf_runs; ++i) {
339     RETURN_NOT_OK(
340         DoSinglePerfRun(client, client_options, call_options, test_put, &stats));
341   }
342 
343   // Elapsed time in seconds
344   uint64_t elapsed_nanos = timer.Stop();
345   double time_elapsed =
346       static_cast<double>(elapsed_nanos) / static_cast<double>(1000000000);
347 
348   constexpr double kMegabyte = static_cast<double>(1 << 20);
349 
350   std::cout << "Number of perf runs: " << FLAGS_num_perf_runs << std::endl;
351   std::cout << "Number of concurrent gets/puts: " << FLAGS_num_threads << std::endl;
352   std::cout << "Batch size: " << stats.total_bytes / stats.total_batches << std::endl;
353   if (FLAGS_test_put) {
354     std::cout << "Batches written: " << stats.total_batches << std::endl;
355     std::cout << "Bytes written: " << stats.total_bytes << std::endl;
356   } else {
357     std::cout << "Batches read: " << stats.total_batches << std::endl;
358     std::cout << "Bytes read: " << stats.total_bytes << std::endl;
359   }
360 
361   std::cout << "Nanos: " << elapsed_nanos << std::endl;
362   std::cout << "Speed: "
363             << (static_cast<double>(stats.total_bytes) / kMegabyte / time_elapsed)
364             << " MB/s" << std::endl;
365 
366   // Calculate throughput(IOPS) and latency vs batch size
367   std::cout << "Throughput: " << (static_cast<double>(stats.total_batches) / time_elapsed)
368             << " batches/s" << std::endl;
369   std::cout << "Latency mean: " << stats.mean_latency() << " us" << std::endl;
370   for (auto q : stats.quantiles) {
371     std::cout << "Latency quantile=" << q << ": " << stats.quantile_latency(q) << " us"
372               << std::endl;
373   }
374   std::cout << "Latency max: " << stats.max_latency() << " us" << std::endl;
375 
376   return Status::OK();
377 }
378 
379 }  // namespace flight
380 }  // namespace arrow
381 
main(int argc,char ** argv)382 int main(int argc, char** argv) {
383   gflags::ParseCommandLineFlags(&argc, &argv, true);
384 
385   std::cout << "Testing method: ";
386   if (FLAGS_test_put) {
387     std::cout << "DoPut";
388   } else {
389     std::cout << "DoGet";
390   }
391   std::cout << std::endl;
392 
393   arrow::flight::FlightCallOptions call_options;
394   if (!FLAGS_compression.empty()) {
395     if (!FLAGS_test_put) {
396       std::cerr << "Compression is only useful for Put test now, "
397                    "please append \"-test_put\" to command line"
398                 << std::endl;
399       std::abort();
400     }
401 
402     // "zstd"   -> name = "zstd", level = default
403     // "zstd:7" -> name = "zstd", level = 7
404     const size_t delim = FLAGS_compression.find(":");
405     const std::string name = FLAGS_compression.substr(0, delim);
406     const std::string level_str =
407         delim == std::string::npos
408             ? ""
409             : FLAGS_compression.substr(delim + 1, FLAGS_compression.length() - delim - 1);
410     const int level = level_str.empty() ? arrow::util::kUseDefaultCompressionLevel
411                                         : std::stoi(level_str);
412     const auto type = arrow::util::Codec::GetCompressionType(name).ValueOrDie();
413     auto codec = arrow::util::Codec::Create(type, level).ValueOrDie();
414     std::cout << "Compression method: " << name;
415     if (!level_str.empty()) {
416       std::cout << ", level " << level;
417     }
418     std::cout << std::endl;
419 
420     call_options.write_options.codec = std::move(codec);
421   }
422   if (!FLAGS_data_file.empty() && !FLAGS_test_put) {
423     std::cerr << "A data file can only be specified with \"-test_put\"" << std::endl;
424     return 1;
425   }
426 
427   std::unique_ptr<arrow::flight::TestServer> server;
428   arrow::flight::Location location;
429   auto options = arrow::flight::FlightClientOptions::Defaults();
430   if (FLAGS_test_unix || !FLAGS_server_unix.empty()) {
431     if (FLAGS_server_unix == "") {
432       FLAGS_server_unix = "/tmp/flight-bench-spawn.sock";
433       std::cout << "Using spawned Unix server" << std::endl;
434       server.reset(
435           new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_unix));
436       server->Start();
437     } else {
438       std::cout << "Using standalone Unix server" << std::endl;
439     }
440     std::cout << "Server unix socket: " << FLAGS_server_unix << std::endl;
441     ABORT_NOT_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &location));
442   } else {
443     if (FLAGS_server_host == "") {
444       FLAGS_server_host = "localhost";
445       std::cout << "Using spawned TCP server" << std::endl;
446       server.reset(
447           new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_port));
448       std::vector<std::string> args;
449       if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
450         if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) {
451           std::cout << "Enabling TLS for spawned server" << std::endl;
452           args.push_back("-cert_file");
453           args.push_back(FLAGS_cert_file);
454           args.push_back("-key_file");
455           args.push_back(FLAGS_key_file);
456         } else {
457           std::cerr << "If providing TLS cert/key, must provide both" << std::endl;
458           return 1;
459         }
460       }
461       server->Start(args);
462     } else {
463       std::cout << "Using standalone TCP server" << std::endl;
464     }
465     std::cout << "Server host: " << FLAGS_server_host << std::endl
466               << "Server port: " << FLAGS_server_port << std::endl;
467     if (FLAGS_cert_file.empty()) {
468       ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host,
469                                                        FLAGS_server_port, &location));
470     } else {
471       ABORT_NOT_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host,
472                                                        FLAGS_server_port, &location));
473       options.disable_server_verification = true;
474     }
475   }
476 
477   std::unique_ptr<arrow::flight::FlightClient> client;
478   ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, options, &client));
479   ABORT_NOT_OK(arrow::flight::WaitForReady(client.get(), call_options));
480 
481   arrow::Status s = arrow::flight::RunPerformanceTest(client.get(), options, call_options,
482                                                       FLAGS_test_put);
483 
484   if (server) {
485     server->Stop();
486   }
487 
488   if (!s.ok()) {
489     std::cerr << "Failed with error: << " << s.ToString() << std::endl;
490   }
491 
492   return 0;
493 }
494