// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <array>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include <gflags/gflags.h>

#include "arrow/io/file.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/api.h"
#include "arrow/record_batch.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/compression.h"
#include "arrow/util/stopwatch.h"
#include "arrow/util/tdigest.h"
#include "arrow/util/thread_pool.h"

#include "arrow/flight/api.h"
#include "arrow/flight/perf.pb.h"
#include "arrow/flight/test_util.h"

DEFINE_string(server_host, "",
              "An existing performance server to benchmark against (leave blank to spawn "
              "one automatically)");
DEFINE_int32(server_port, 31337, "The port to connect to");
DEFINE_string(server_unix, "",
              "An existing performance server listening on Unix socket (leave blank to "
              "spawn one automatically)");
DEFINE_bool(test_unix, false, "Test Unix socket instead of TCP");
DEFINE_int32(num_perf_runs, 1,
             "Number of times to run the perf test to "
             "increase precision");
DEFINE_int32(num_servers, 1, "Number of performance servers to run");
DEFINE_int32(num_streams, 4, "Number of streams for each server");
DEFINE_int32(num_threads, 4, "Number of concurrent gets");
DEFINE_int64(records_per_stream, 10000000, "Total records per stream");
DEFINE_int32(records_per_batch, 4096, "Total records per batch within stream");
DEFINE_bool(test_put, false, "Test DoPut instead of DoGet");
DEFINE_string(compression, "",
              "Select compression method (\"zstd\", \"lz4\"). "
              "Leave blank to disable compression.\n"
              "E.g., \"zstd\": zstd with default compression level.\n"
              "     \"zstd:7\": zstd with compression level = 7.\n");
DEFINE_string(
    data_file, "",
    "Instead of random data, use data from the given IPC file. Only affects -test_put.");
DEFINE_string(cert_file, "", "Path to TLS certificate");
DEFINE_string(key_file, "", "Path to TLS private key (used when spawning a server)");
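
// Example invocations (illustrative only; the benchmark binary name may differ
// depending on how the project is built):
//   flight-benchmark                                 # spawn a TCP perf server, run DoGet
//   flight-benchmark -test_put -compression zstd:7   # DoPut with zstd level 7
//   flight-benchmark -server_host perf.example.com -num_streams 8 -num_threads 8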

namespace perf = arrow::flight::perf;

namespace arrow {

using internal::StopWatch;
using internal::ThreadPool;

namespace flight {

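// Counters produced by a single DoGet/DoPut stream; aggregated into
// PerformanceStats once the stream completes.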
struct PerformanceResult {
  int64_t num_batches;
  int64_t num_records;
  int64_t num_bytes;
};

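// Totals and latency digest shared by all worker threads. Update() and
// AddLatency() serialize on the mutex; latencies are kept in a TDigest so
// approximate quantiles can be reported at the end of the run.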
struct PerformanceStats {
  std::mutex mutex;
  int64_t total_batches = 0;
  int64_t total_records = 0;
  int64_t total_bytes = 0;
  const std::array<double, 3> quantiles = {0.5, 0.95, 0.99};
  mutable arrow::internal::TDigest latencies;

  void Update(int64_t total_batches, int64_t total_records, int64_t total_bytes) {
    std::lock_guard<std::mutex> lock(this->mutex);
    this->total_batches += total_batches;
    this->total_records += total_records;
    this->total_bytes += total_bytes;
  }

  // Invoked once per batch in the test loop. Holding a lock here does not look
  // scalable, but with 1 to 8 threads no noticeable overhead was observed.
  // A better approach may be to compute per-thread quantiles and merge them.
  void AddLatency(uint64_t elapsed_nanos) {
    std::lock_guard<std::mutex> lock(this->mutex);
    latencies.Add(static_cast<double>(elapsed_nanos));
  }

  // ns -> us
  uint64_t max_latency() const { return latencies.Max() / 1000; }

  uint64_t mean_latency() const { return latencies.Mean() / 1000; }

  uint64_t quantile_latency(double q) const { return latencies.Quantile(q) / 1000; }
};

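// Poll the server with the "ping" action until it responds, retrying up to
// 10 times with a one-second pause between attempts.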
Status WaitForReady(FlightClient* client, const FlightCallOptions& call_options) {
  Action action{"ping", nullptr};
  for (int attempt = 0; attempt < 10; attempt++) {
    std::unique_ptr<ResultStream> stream;
    if (client->DoAction(call_options, action, &stream).ok()) {
      return Status::OK();
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  }
  return Status::IOError("Server was not available after 10 attempts");
}

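// Read one stream to completion, timing each call to Next() and accumulating
// batch/record/byte counts. The byte count assumes the hard-coded perf schema
// of four int64 columns (32 bytes per record).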
arrow::Result<PerformanceResult> RunDoGetTest(FlightClient* client,
                                              const FlightCallOptions& call_options,
                                              const perf::Token& token,
                                              const FlightEndpoint& endpoint,
                                              PerformanceStats* stats) {
  std::unique_ptr<FlightStreamReader> reader;
  RETURN_NOT_OK(client->DoGet(call_options, endpoint.ticket, &reader));

  FlightStreamChunk batch;

  // This is hard-coded for right now, 4 columns each with int64
  const int bytes_per_record = 32;

  // This must also be set in perf_server.cc
  const bool verify = false;

  int64_t num_bytes = 0;
  int64_t num_records = 0;
  int64_t num_batches = 0;
  StopWatch timer;
  while (true) {
    timer.Start();
    RETURN_NOT_OK(reader->Next(&batch));
    stats->AddLatency(timer.Stop());
    if (!batch.data) {
      break;
    }

    if (verify) {
      auto values = batch.data->column_data(0)->GetValues<int64_t>(1);
      const int64_t start = token.start() + num_records;
      for (int64_t i = 0; i < batch.data->num_rows(); ++i) {
        if (values[i] != start + i) {
          return Status::Invalid("verification failure");
        }
      }
    }

    ++num_batches;
    num_records += batch.data->num_rows();

    // Hard-coded
    num_bytes += batch.data->num_rows() * bytes_per_record;
  }
  return PerformanceResult{num_batches, num_records, num_bytes};
}

struct SizedBatch {
  std::shared_ptr<arrow::RecordBatch> batch;
  int64_t bytes;
};

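// Prepare the batches to upload for the DoPut test: either the contents of
// -data_file, or randomly generated batches matching the hard-coded perf
// schema, sliced so the total row count equals records_per_stream.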
arrow::Result<std::vector<SizedBatch>> GetPutData(const perf::Token& token) {
  if (!FLAGS_data_file.empty()) {
    ARROW_ASSIGN_OR_RAISE(auto file, arrow::io::ReadableFile::Open(FLAGS_data_file));
    ARROW_ASSIGN_OR_RAISE(auto reader,
                          arrow::ipc::RecordBatchFileReader::Open(std::move(file)));
    std::vector<SizedBatch> batches(reader->num_record_batches());
    for (int i = 0; i < reader->num_record_batches(); i++) {
      ARROW_ASSIGN_OR_RAISE(batches[i].batch, reader->ReadRecordBatch(i));
      RETURN_NOT_OK(arrow::ipc::GetRecordBatchSize(*batches[i].batch, &batches[i].bytes));
    }
    return batches;
  }

  std::shared_ptr<Schema> schema =
      arrow::schema({field("a", int64()), field("b", int64()), field("c", int64()),
                     field("d", int64())});

  // This is hard-coded for right now, 4 columns each with int64
  const int bytes_per_record = 32;

  std::shared_ptr<ResizableBuffer> buffer;
  std::vector<std::shared_ptr<Array>> arrays;

  const int64_t total_records = token.definition().records_per_stream();
  const int32_t length = token.definition().records_per_batch();
  const int32_t ncolumns = 4;
  for (int i = 0; i < ncolumns; ++i) {
    RETURN_NOT_OK(MakeRandomByteBuffer(length * sizeof(int64_t), default_memory_pool(),
                                       &buffer, static_cast<int32_t>(i) /* seed */));
    arrays.push_back(std::make_shared<Int64Array>(length, buffer));
    RETURN_NOT_OK(arrays.back()->Validate());
  }

  std::shared_ptr<RecordBatch> batch = RecordBatch::Make(schema, length, arrays);
  std::vector<SizedBatch> batches;

  int64_t records_sent = 0;
  while (records_sent < total_records) {
    if (records_sent + length > total_records) {
      const int last_length = total_records - records_sent;
      // Hard-coded
      batches.push_back(SizedBatch{batch->Slice(0, last_length),
                                   /*bytes=*/last_length * bytes_per_record});
      records_sent += last_length;
    } else {
      // Hard-coded
      batches.push_back(SizedBatch{batch, /*bytes=*/length * bytes_per_record});
      records_sent += length;
    }
  }
  return batches;
}

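// Upload the prepared batches over a single DoPut stream, timing every
// WriteRecordBatch() call except the final one.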
arrow::Result<PerformanceResult> RunDoPutTest(FlightClient* client,
                                              const FlightCallOptions& call_options,
                                              const perf::Token& token,
                                              const FlightEndpoint& endpoint,
                                              PerformanceStats* stats) {
  ARROW_ASSIGN_OR_RAISE(const auto batches, GetPutData(token));
  StopWatch timer;
  int64_t num_records = 0;
  int64_t num_bytes = 0;
  std::unique_ptr<FlightStreamWriter> writer;
  std::unique_ptr<FlightMetadataReader> reader;
  RETURN_NOT_OK(client->DoPut(call_options, FlightDescriptor{},
                              batches[0].batch->schema(), &writer, &reader));
  for (size_t i = 0; i < batches.size(); i++) {
    auto batch = batches[i];
    auto is_last = i == (batches.size() - 1);
    if (is_last) {
      RETURN_NOT_OK(writer->WriteRecordBatch(*batch.batch));
      num_records += batch.batch->num_rows();
      num_bytes += batch.bytes;
    } else {
      timer.Start();
      RETURN_NOT_OK(writer->WriteRecordBatch(*batch.batch));
      stats->AddLatency(timer.Stop());
      num_records += batch.batch->num_rows();
      num_bytes += batch.bytes;
    }
  }
  RETURN_NOT_OK(writer->Close());
  return PerformanceResult{static_cast<int64_t>(batches.size()), num_records, num_bytes};
}

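// Plan a query with GetFlightInfo, then consume every returned endpoint on a
// thread pool. Each task opens its own FlightClient connection to the
// endpoint's first location and runs the DoGet or DoPut test loop.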
Status DoSinglePerfRun(FlightClient* client, const FlightClientOptions client_options,
                       const FlightCallOptions& call_options, bool test_put,
                       PerformanceStats* stats) {
  // schema not needed
  perf::Perf perf;
  perf.set_stream_count(FLAGS_num_streams);
  perf.set_records_per_stream(FLAGS_records_per_stream);
  perf.set_records_per_batch(FLAGS_records_per_batch);

  // Plan the query
  FlightDescriptor descriptor;
  descriptor.type = FlightDescriptor::CMD;
  perf.SerializeToString(&descriptor.cmd);

  std::unique_ptr<FlightInfo> plan;
  RETURN_NOT_OK(client->GetFlightInfo(call_options, descriptor, &plan));

  // Read the streams in parallel
  std::shared_ptr<Schema> schema;
  ipc::DictionaryMemo dict_memo;
  RETURN_NOT_OK(plan->GetSchema(&dict_memo, &schema));

  int64_t start_total_records = stats->total_records;

  auto test_loop = test_put ? &RunDoPutTest : &RunDoGetTest;
  auto ConsumeStream = [&stats, &test_loop, &client_options,
                        &call_options](const FlightEndpoint& endpoint) {
    std::unique_ptr<FlightClient> client;
    RETURN_NOT_OK(
        FlightClient::Connect(endpoint.locations.front(), client_options, &client));

    perf::Token token;
    token.ParseFromString(endpoint.ticket.ticket);

    const auto& result = test_loop(client.get(), call_options, token, endpoint, stats);
    if (result.ok()) {
      const PerformanceResult& perf = result.ValueOrDie();
      stats->Update(perf.num_batches, perf.num_records, perf.num_bytes);
    }
    return result.status();
  };

  // XXX(wesm): Serial version for debugging
  // for (const auto& endpoint : plan->endpoints()) {
  //   RETURN_NOT_OK(ConsumeStream(endpoint));
  // }

  ARROW_ASSIGN_OR_RAISE(auto pool, ThreadPool::Make(FLAGS_num_threads));
  std::vector<Future<>> tasks;
  for (const auto& endpoint : plan->endpoints()) {
    ARROW_ASSIGN_OR_RAISE(auto task, pool->Submit(ConsumeStream, endpoint));
    tasks.push_back(std::move(task));
  }

  // Wait for tasks to finish
  for (auto&& task : tasks) {
    RETURN_NOT_OK(task.status());
  }

  if (FLAGS_data_file.empty()) {
    // Check that number of rows read / written is as expected
    int64_t records_for_run = stats->total_records - start_total_records;
    if (records_for_run != static_cast<int64_t>(plan->total_records())) {
      return Status::Invalid("Did not consume expected number of records");
    }
  }
  return Status::OK();
}

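// Run the configured number of perf runs and print aggregate throughput and
// latency statistics.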
Status RunPerformanceTest(FlightClient* client, const FlightClientOptions& client_options,
                          const FlightCallOptions& call_options, bool test_put) {
  StopWatch timer;
  timer.Start();

  PerformanceStats stats;
  for (int i = 0; i < FLAGS_num_perf_runs; ++i) {
    RETURN_NOT_OK(
        DoSinglePerfRun(client, client_options, call_options, test_put, &stats));
  }

  // Elapsed time in seconds
  uint64_t elapsed_nanos = timer.Stop();
  double time_elapsed =
      static_cast<double>(elapsed_nanos) / static_cast<double>(1000000000);

  constexpr double kMegabyte = static_cast<double>(1 << 20);

  std::cout << "Number of perf runs: " << FLAGS_num_perf_runs << std::endl;
  std::cout << "Number of concurrent gets/puts: " << FLAGS_num_threads << std::endl;
  std::cout << "Batch size: " << stats.total_bytes / stats.total_batches << std::endl;
  if (FLAGS_test_put) {
    std::cout << "Batches written: " << stats.total_batches << std::endl;
    std::cout << "Bytes written: " << stats.total_bytes << std::endl;
  } else {
    std::cout << "Batches read: " << stats.total_batches << std::endl;
    std::cout << "Bytes read: " << stats.total_bytes << std::endl;
  }

  std::cout << "Nanos: " << elapsed_nanos << std::endl;
  std::cout << "Speed: "
            << (static_cast<double>(stats.total_bytes) / kMegabyte / time_elapsed)
            << " MB/s" << std::endl;

  // Calculate throughput (IOPS) and latency vs batch size
  std::cout << "Throughput: " << (static_cast<double>(stats.total_batches) / time_elapsed)
            << " batches/s" << std::endl;
  std::cout << "Latency mean: " << stats.mean_latency() << " us" << std::endl;
  for (auto q : stats.quantiles) {
    std::cout << "Latency quantile=" << q << ": " << stats.quantile_latency(q) << " us"
              << std::endl;
  }
  std::cout << "Latency max: " << stats.max_latency() << " us" << std::endl;

  return Status::OK();
}

}  // namespace flight
}  // namespace arrow

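// Parses flags, optionally spawns a perf server (TCP or Unix socket, with or
// without TLS), connects a client, waits for the server to become ready, and
// runs the benchmark.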
int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  std::cout << "Testing method: ";
  if (FLAGS_test_put) {
    std::cout << "DoPut";
  } else {
    std::cout << "DoGet";
  }
  std::cout << std::endl;

  arrow::flight::FlightCallOptions call_options;
  if (!FLAGS_compression.empty()) {
    if (!FLAGS_test_put) {
      std::cerr << "Compression is currently only supported for the DoPut test; "
                   "please add \"-test_put\" to the command line"
                << std::endl;
      std::abort();
    }

402 // "zstd" -> name = "zstd", level = default
403 // "zstd:7" -> name = "zstd", level = 7
404 const size_t delim = FLAGS_compression.find(":");
405 const std::string name = FLAGS_compression.substr(0, delim);
406 const std::string level_str =
407 delim == std::string::npos
408 ? ""
409 : FLAGS_compression.substr(delim + 1, FLAGS_compression.length() - delim - 1);
410 const int level = level_str.empty() ? arrow::util::kUseDefaultCompressionLevel
411 : std::stoi(level_str);
412 const auto type = arrow::util::Codec::GetCompressionType(name).ValueOrDie();
413 auto codec = arrow::util::Codec::Create(type, level).ValueOrDie();
414 std::cout << "Compression method: " << name;
415 if (!level_str.empty()) {
416 std::cout << ", level " << level;
417 }
418 std::cout << std::endl;
419
420 call_options.write_options.codec = std::move(codec);
421 }
422 if (!FLAGS_data_file.empty() && !FLAGS_test_put) {
423 std::cerr << "A data file can only be specified with \"-test_put\"" << std::endl;
424 return 1;
425 }
426
  std::unique_ptr<arrow::flight::TestServer> server;
  arrow::flight::Location location;
  auto options = arrow::flight::FlightClientOptions::Defaults();
  if (FLAGS_test_unix || !FLAGS_server_unix.empty()) {
    if (FLAGS_server_unix == "") {
      FLAGS_server_unix = "/tmp/flight-bench-spawn.sock";
      std::cout << "Using spawned Unix server" << std::endl;
      server.reset(
          new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_unix));
      server->Start();
    } else {
      std::cout << "Using standalone Unix server" << std::endl;
    }
    std::cout << "Server unix socket: " << FLAGS_server_unix << std::endl;
    ABORT_NOT_OK(arrow::flight::Location::ForGrpcUnix(FLAGS_server_unix, &location));
  } else {
    if (FLAGS_server_host == "") {
      FLAGS_server_host = "localhost";
      std::cout << "Using spawned TCP server" << std::endl;
      server.reset(
          new arrow::flight::TestServer("arrow-flight-perf-server", FLAGS_server_port));
      std::vector<std::string> args;
      if (!FLAGS_cert_file.empty() || !FLAGS_key_file.empty()) {
        if (!FLAGS_cert_file.empty() && !FLAGS_key_file.empty()) {
          std::cout << "Enabling TLS for spawned server" << std::endl;
          args.push_back("-cert_file");
          args.push_back(FLAGS_cert_file);
          args.push_back("-key_file");
          args.push_back(FLAGS_key_file);
        } else {
          std::cerr << "If providing TLS cert/key, must provide both" << std::endl;
          return 1;
        }
      }
      server->Start(args);
    } else {
      std::cout << "Using standalone TCP server" << std::endl;
    }
    std::cout << "Server host: " << FLAGS_server_host << std::endl
              << "Server port: " << FLAGS_server_port << std::endl;
    if (FLAGS_cert_file.empty()) {
      ABORT_NOT_OK(arrow::flight::Location::ForGrpcTcp(FLAGS_server_host,
                                                       FLAGS_server_port, &location));
    } else {
      ABORT_NOT_OK(arrow::flight::Location::ForGrpcTls(FLAGS_server_host,
                                                       FLAGS_server_port, &location));
      options.disable_server_verification = true;
    }
  }

  std::unique_ptr<arrow::flight::FlightClient> client;
  ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, options, &client));
  ABORT_NOT_OK(arrow::flight::WaitForReady(client.get(), call_options));

  arrow::Status s = arrow::flight::RunPerformanceTest(client.get(), options, call_options,
                                                      FLAGS_test_put);

  if (server) {
    server->Stop();
  }

  if (!s.ok()) {
    std::cerr << "Failed with error: " << s.ToString() << std::endl;
  }

  return 0;
}