1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17 18import ast 19import base64 20import itertools 21import os 22import signal 23import struct 24import tempfile 25import threading 26import time 27import traceback 28import json 29 30import numpy as np 31import pytest 32import pyarrow as pa 33 34from pyarrow.lib import tobytes 35from pyarrow.util import pathlib, find_free_port 36from pyarrow.tests import util 37 38try: 39 from pyarrow import flight 40 from pyarrow.flight import ( 41 FlightClient, FlightServerBase, 42 ServerAuthHandler, ClientAuthHandler, 43 ServerMiddleware, ServerMiddlewareFactory, 44 ClientMiddleware, ClientMiddlewareFactory, 45 ) 46except ImportError: 47 flight = None 48 FlightClient, FlightServerBase = object, object 49 ServerAuthHandler, ClientAuthHandler = object, object 50 ServerMiddleware, ServerMiddlewareFactory = object, object 51 ClientMiddleware, ClientMiddlewareFactory = object, object 52 53# Marks all of the tests in this module 54# Ignore these with pytest ... -m 'not flight' 55pytestmark = pytest.mark.flight 56 57 58def test_import(): 59 # So we see the ImportError somewhere 60 import pyarrow.flight # noqa 61 62 63def resource_root(): 64 """Get the path to the test resources directory.""" 65 if not os.environ.get("ARROW_TEST_DATA"): 66 raise RuntimeError("Test resources not found; set " 67 "ARROW_TEST_DATA to <repo root>/testing/data") 68 return pathlib.Path(os.environ["ARROW_TEST_DATA"]) / "flight" 69 70 71def read_flight_resource(path): 72 """Get the contents of a test resource file.""" 73 root = resource_root() 74 if not root: 75 return None 76 try: 77 with (root / path).open("rb") as f: 78 return f.read() 79 except FileNotFoundError: 80 raise RuntimeError( 81 "Test resource {} not found; did you initialize the " 82 "test resource submodule?\n{}".format(root / path, 83 traceback.format_exc())) 84 85 86def example_tls_certs(): 87 """Get the paths to test TLS certificates.""" 88 return { 89 "root_cert": read_flight_resource("root-ca.pem"), 90 "certificates": [ 91 flight.CertKeyPair( 92 cert=read_flight_resource("cert0.pem"), 93 key=read_flight_resource("cert0.key"), 94 ), 95 flight.CertKeyPair( 96 cert=read_flight_resource("cert1.pem"), 97 key=read_flight_resource("cert1.key"), 98 ), 99 ] 100 } 101 102 103def simple_ints_table(): 104 data = [ 105 pa.array([-10, -5, 0, 5, 10]) 106 ] 107 return pa.Table.from_arrays(data, names=['some_ints']) 108 109 110def simple_dicts_table(): 111 dict_values = pa.array(["foo", "baz", "quux"], type=pa.utf8()) 112 data = [ 113 pa.chunked_array([ 114 pa.DictionaryArray.from_arrays([1, 0, None], dict_values), 115 pa.DictionaryArray.from_arrays([2, 1], dict_values) 116 ]) 117 ] 118 return pa.Table.from_arrays(data, names=['some_dicts']) 119 120 121class ConstantFlightServer(FlightServerBase): 122 """A Flight server that always returns the same data. 123 124 See ARROW-4796: this server implementation will segfault if Flight 125 does not properly hold a reference to the Table object. 126 """ 127 128 CRITERIA = b"the expected criteria" 129 130 def __init__(self, location=None, options=None, **kwargs): 131 super().__init__(location, **kwargs) 132 # Ticket -> Table 133 self.table_factories = { 134 b'ints': simple_ints_table, 135 b'dicts': simple_dicts_table, 136 } 137 self.options = options 138 139 def list_flights(self, context, criteria): 140 if criteria == self.CRITERIA: 141 yield flight.FlightInfo( 142 pa.schema([]), 143 flight.FlightDescriptor.for_path('/foo'), 144 [], 145 -1, -1 146 ) 147 148 def do_get(self, context, ticket): 149 # Return a fresh table, so that Flight is the only one keeping a 150 # reference. 151 table = self.table_factories[ticket.ticket]() 152 return flight.RecordBatchStream(table, options=self.options) 153 154 155class MetadataFlightServer(FlightServerBase): 156 """A Flight server that numbers incoming/outgoing data.""" 157 158 def __init__(self, options=None, **kwargs): 159 super().__init__(**kwargs) 160 self.options = options 161 162 def do_get(self, context, ticket): 163 data = [ 164 pa.array([-10, -5, 0, 5, 10]) 165 ] 166 table = pa.Table.from_arrays(data, names=['a']) 167 return flight.GeneratorStream( 168 table.schema, 169 self.number_batches(table), 170 options=self.options) 171 172 def do_put(self, context, descriptor, reader, writer): 173 counter = 0 174 expected_data = [-10, -5, 0, 5, 10] 175 while True: 176 try: 177 batch, buf = reader.read_chunk() 178 assert batch.equals(pa.RecordBatch.from_arrays( 179 [pa.array([expected_data[counter]])], 180 ['a'] 181 )) 182 assert buf is not None 183 client_counter, = struct.unpack('<i', buf.to_pybytes()) 184 assert counter == client_counter 185 writer.write(struct.pack('<i', counter)) 186 counter += 1 187 except StopIteration: 188 return 189 190 @staticmethod 191 def number_batches(table): 192 for idx, batch in enumerate(table.to_batches()): 193 buf = struct.pack('<i', idx) 194 yield batch, buf 195 196 197class EchoFlightServer(FlightServerBase): 198 """A Flight server that returns the last data uploaded.""" 199 200 def __init__(self, location=None, expected_schema=None, **kwargs): 201 super().__init__(location, **kwargs) 202 self.last_message = None 203 self.expected_schema = expected_schema 204 205 def do_get(self, context, ticket): 206 return flight.RecordBatchStream(self.last_message) 207 208 def do_put(self, context, descriptor, reader, writer): 209 if self.expected_schema: 210 assert self.expected_schema == reader.schema 211 self.last_message = reader.read_all() 212 213 def do_exchange(self, context, descriptor, reader, writer): 214 for chunk in reader: 215 pass 216 217 218class EchoStreamFlightServer(EchoFlightServer): 219 """An echo server that streams individual record batches.""" 220 221 def do_get(self, context, ticket): 222 return flight.GeneratorStream( 223 self.last_message.schema, 224 self.last_message.to_batches(max_chunksize=1024)) 225 226 def list_actions(self, context): 227 return [] 228 229 def do_action(self, context, action): 230 if action.type == "who-am-i": 231 return [context.peer_identity(), context.peer().encode("utf-8")] 232 raise NotImplementedError 233 234 235class GetInfoFlightServer(FlightServerBase): 236 """A Flight server that tests GetFlightInfo.""" 237 238 def get_flight_info(self, context, descriptor): 239 return flight.FlightInfo( 240 pa.schema([('a', pa.int32())]), 241 descriptor, 242 [ 243 flight.FlightEndpoint(b'', ['grpc://test']), 244 flight.FlightEndpoint( 245 b'', 246 [flight.Location.for_grpc_tcp('localhost', 5005)], 247 ), 248 ], 249 -1, 250 -1, 251 ) 252 253 def get_schema(self, context, descriptor): 254 info = self.get_flight_info(context, descriptor) 255 return flight.SchemaResult(info.schema) 256 257 258class ListActionsFlightServer(FlightServerBase): 259 """A Flight server that tests ListActions.""" 260 261 @classmethod 262 def expected_actions(cls): 263 return [ 264 ("action-1", "description"), 265 ("action-2", ""), 266 flight.ActionType("action-3", "more detail"), 267 ] 268 269 def list_actions(self, context): 270 yield from self.expected_actions() 271 272 273class ListActionsErrorFlightServer(FlightServerBase): 274 """A Flight server that tests ListActions.""" 275 276 def list_actions(self, context): 277 yield ("action-1", "") 278 yield "foo" 279 280 281class CheckTicketFlightServer(FlightServerBase): 282 """A Flight server that compares the given ticket to an expected value.""" 283 284 def __init__(self, expected_ticket, location=None, **kwargs): 285 super().__init__(location, **kwargs) 286 self.expected_ticket = expected_ticket 287 288 def do_get(self, context, ticket): 289 assert self.expected_ticket == ticket.ticket 290 data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] 291 table = pa.Table.from_arrays(data1, names=['a']) 292 return flight.RecordBatchStream(table) 293 294 def do_put(self, context, descriptor, reader): 295 self.last_message = reader.read_all() 296 297 298class InvalidStreamFlightServer(FlightServerBase): 299 """A Flight server that tries to return messages with differing schemas.""" 300 301 schema = pa.schema([('a', pa.int32())]) 302 303 def do_get(self, context, ticket): 304 data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] 305 data2 = [pa.array([-10.0, -5.0, 0.0, 5.0, 10.0], type=pa.float64())] 306 assert data1.type != data2.type 307 table1 = pa.Table.from_arrays(data1, names=['a']) 308 table2 = pa.Table.from_arrays(data2, names=['a']) 309 assert table1.schema == self.schema 310 311 return flight.GeneratorStream(self.schema, [table1, table2]) 312 313 314class NeverSendsDataFlightServer(FlightServerBase): 315 """A Flight server that never actually yields data.""" 316 317 schema = pa.schema([('a', pa.int32())]) 318 319 def do_get(self, context, ticket): 320 if ticket.ticket == b'yield_data': 321 # Check that the server handler will ignore empty tables 322 # up to a certain extent 323 data = [ 324 self.schema.empty_table(), 325 self.schema.empty_table(), 326 pa.RecordBatch.from_arrays([range(5)], schema=self.schema), 327 ] 328 return flight.GeneratorStream(self.schema, data) 329 return flight.GeneratorStream( 330 self.schema, itertools.repeat(self.schema.empty_table())) 331 332 333class SlowFlightServer(FlightServerBase): 334 """A Flight server that delays its responses to test timeouts.""" 335 336 def do_get(self, context, ticket): 337 return flight.GeneratorStream(pa.schema([('a', pa.int32())]), 338 self.slow_stream()) 339 340 def do_action(self, context, action): 341 time.sleep(0.5) 342 return [] 343 344 @staticmethod 345 def slow_stream(): 346 data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] 347 yield pa.Table.from_arrays(data1, names=['a']) 348 # The second message should never get sent; the client should 349 # cancel before we send this 350 time.sleep(10) 351 yield pa.Table.from_arrays(data1, names=['a']) 352 353 354class ErrorFlightServer(FlightServerBase): 355 """A Flight server that uses all the Flight-specific errors.""" 356 357 def do_action(self, context, action): 358 if action.type == "internal": 359 raise flight.FlightInternalError("foo") 360 elif action.type == "timedout": 361 raise flight.FlightTimedOutError("foo") 362 elif action.type == "cancel": 363 raise flight.FlightCancelledError("foo") 364 elif action.type == "unauthenticated": 365 raise flight.FlightUnauthenticatedError("foo") 366 elif action.type == "unauthorized": 367 raise flight.FlightUnauthorizedError("foo") 368 elif action.type == "protobuf": 369 err_msg = b'this is an error message' 370 raise flight.FlightUnauthorizedError("foo", err_msg) 371 raise NotImplementedError 372 373 def list_flights(self, context, criteria): 374 yield flight.FlightInfo( 375 pa.schema([]), 376 flight.FlightDescriptor.for_path('/foo'), 377 [], 378 -1, -1 379 ) 380 raise flight.FlightInternalError("foo") 381 382 def do_put(self, context, descriptor, reader, writer): 383 if descriptor.command == b"internal": 384 raise flight.FlightInternalError("foo") 385 elif descriptor.command == b"timedout": 386 raise flight.FlightTimedOutError("foo") 387 elif descriptor.command == b"cancel": 388 raise flight.FlightCancelledError("foo") 389 elif descriptor.command == b"unauthenticated": 390 raise flight.FlightUnauthenticatedError("foo") 391 elif descriptor.command == b"unauthorized": 392 raise flight.FlightUnauthorizedError("foo") 393 elif descriptor.command == b"protobuf": 394 err_msg = b'this is an error message' 395 raise flight.FlightUnauthorizedError("foo", err_msg) 396 397 398class ExchangeFlightServer(FlightServerBase): 399 """A server for testing DoExchange.""" 400 401 def __init__(self, options=None, **kwargs): 402 super().__init__(**kwargs) 403 self.options = options 404 405 def do_exchange(self, context, descriptor, reader, writer): 406 if descriptor.descriptor_type != flight.DescriptorType.CMD: 407 raise pa.ArrowInvalid("Must provide a command descriptor") 408 elif descriptor.command == b"echo": 409 return self.exchange_echo(context, reader, writer) 410 elif descriptor.command == b"get": 411 return self.exchange_do_get(context, reader, writer) 412 elif descriptor.command == b"put": 413 return self.exchange_do_put(context, reader, writer) 414 elif descriptor.command == b"transform": 415 return self.exchange_transform(context, reader, writer) 416 else: 417 raise pa.ArrowInvalid( 418 "Unknown command: {}".format(descriptor.command)) 419 420 def exchange_do_get(self, context, reader, writer): 421 """Emulate DoGet with DoExchange.""" 422 data = pa.Table.from_arrays([ 423 pa.array(range(0, 10 * 1024)) 424 ], names=["a"]) 425 writer.begin(data.schema) 426 writer.write_table(data) 427 428 def exchange_do_put(self, context, reader, writer): 429 """Emulate DoPut with DoExchange.""" 430 num_batches = 0 431 for chunk in reader: 432 if not chunk.data: 433 raise pa.ArrowInvalid("All chunks must have data.") 434 num_batches += 1 435 writer.write_metadata(str(num_batches).encode("utf-8")) 436 437 def exchange_echo(self, context, reader, writer): 438 """Run a simple echo server.""" 439 started = False 440 for chunk in reader: 441 if not started and chunk.data: 442 writer.begin(chunk.data.schema, options=self.options) 443 started = True 444 if chunk.app_metadata and chunk.data: 445 writer.write_with_metadata(chunk.data, chunk.app_metadata) 446 elif chunk.app_metadata: 447 writer.write_metadata(chunk.app_metadata) 448 elif chunk.data: 449 writer.write_batch(chunk.data) 450 else: 451 assert False, "Should not happen" 452 453 def exchange_transform(self, context, reader, writer): 454 """Sum rows in an uploaded table.""" 455 for field in reader.schema: 456 if not pa.types.is_integer(field.type): 457 raise pa.ArrowInvalid("Invalid field: " + repr(field)) 458 table = reader.read_all() 459 sums = [0] * table.num_rows 460 for column in table: 461 for row, value in enumerate(column): 462 sums[row] += value.as_py() 463 result = pa.Table.from_arrays([pa.array(sums)], names=["sum"]) 464 writer.begin(result.schema) 465 writer.write_table(result) 466 467 468class HttpBasicServerAuthHandler(ServerAuthHandler): 469 """An example implementation of HTTP basic authentication.""" 470 471 def __init__(self, creds): 472 super().__init__() 473 self.creds = creds 474 475 def authenticate(self, outgoing, incoming): 476 buf = incoming.read() 477 auth = flight.BasicAuth.deserialize(buf) 478 if auth.username not in self.creds: 479 raise flight.FlightUnauthenticatedError("unknown user") 480 if self.creds[auth.username] != auth.password: 481 raise flight.FlightUnauthenticatedError("wrong password") 482 outgoing.write(tobytes(auth.username)) 483 484 def is_valid(self, token): 485 if not token: 486 raise flight.FlightUnauthenticatedError("token not provided") 487 if token not in self.creds: 488 raise flight.FlightUnauthenticatedError("unknown user") 489 return token 490 491 492class HttpBasicClientAuthHandler(ClientAuthHandler): 493 """An example implementation of HTTP basic authentication.""" 494 495 def __init__(self, username, password): 496 super().__init__() 497 self.basic_auth = flight.BasicAuth(username, password) 498 self.token = None 499 500 def authenticate(self, outgoing, incoming): 501 auth = self.basic_auth.serialize() 502 outgoing.write(auth) 503 self.token = incoming.read() 504 505 def get_token(self): 506 return self.token 507 508 509class TokenServerAuthHandler(ServerAuthHandler): 510 """An example implementation of authentication via handshake.""" 511 512 def __init__(self, creds): 513 super().__init__() 514 self.creds = creds 515 516 def authenticate(self, outgoing, incoming): 517 username = incoming.read() 518 password = incoming.read() 519 if username in self.creds and self.creds[username] == password: 520 outgoing.write(base64.b64encode(b'secret:' + username)) 521 else: 522 raise flight.FlightUnauthenticatedError( 523 "invalid username/password") 524 525 def is_valid(self, token): 526 token = base64.b64decode(token) 527 if not token.startswith(b'secret:'): 528 raise flight.FlightUnauthenticatedError("invalid token") 529 return token[7:] 530 531 532class TokenClientAuthHandler(ClientAuthHandler): 533 """An example implementation of authentication via handshake.""" 534 535 def __init__(self, username, password): 536 super().__init__() 537 self.username = username 538 self.password = password 539 self.token = b'' 540 541 def authenticate(self, outgoing, incoming): 542 outgoing.write(self.username) 543 outgoing.write(self.password) 544 self.token = incoming.read() 545 546 def get_token(self): 547 return self.token 548 549 550class NoopAuthHandler(ServerAuthHandler): 551 """A no-op auth handler.""" 552 553 def authenticate(self, outgoing, incoming): 554 """Do nothing.""" 555 556 def is_valid(self, token): 557 """ 558 Returning an empty string. 559 Returning None causes Type error. 560 """ 561 return "" 562 563 564def case_insensitive_header_lookup(headers, lookup_key): 565 """Lookup the value of given key in the given headers. 566 The key lookup is case insensitive. 567 """ 568 for key in headers: 569 if key.lower() == lookup_key.lower(): 570 return headers.get(key) 571 572 573class ClientHeaderAuthMiddlewareFactory(ClientMiddlewareFactory): 574 """ClientMiddlewareFactory that creates ClientAuthHeaderMiddleware.""" 575 576 def __init__(self): 577 self.call_credential = [] 578 579 def start_call(self, info): 580 return ClientHeaderAuthMiddleware(self) 581 582 def set_call_credential(self, call_credential): 583 self.call_credential = call_credential 584 585 586class ClientHeaderAuthMiddleware(ClientMiddleware): 587 """ 588 ClientMiddleware that extracts the authorization header 589 from the server. 590 591 This is an example of a ClientMiddleware that can extract 592 the bearer token authorization header from a HTTP header 593 authentication enabled server. 594 595 Parameters 596 ---------- 597 factory : ClientHeaderAuthMiddlewareFactory 598 This factory is used to set call credentials if an 599 authorization header is found in the headers from the server. 600 """ 601 602 def __init__(self, factory): 603 self.factory = factory 604 605 def received_headers(self, headers): 606 auth_header = case_insensitive_header_lookup(headers, 'Authorization') 607 self.factory.set_call_credential([ 608 b'authorization', 609 auth_header[0].encode("utf-8")]) 610 611 612class HeaderAuthServerMiddlewareFactory(ServerMiddlewareFactory): 613 """Validates incoming username and password.""" 614 615 def start_call(self, info, headers): 616 auth_header = case_insensitive_header_lookup( 617 headers, 618 'Authorization' 619 ) 620 values = auth_header[0].split(' ') 621 token = '' 622 error_message = 'Invalid credentials' 623 624 if values[0] == 'Basic': 625 decoded = base64.b64decode(values[1]) 626 pair = decoded.decode("utf-8").split(':') 627 if not (pair[0] == 'test' and pair[1] == 'password'): 628 raise flight.FlightUnauthenticatedError(error_message) 629 token = 'token1234' 630 elif values[0] == 'Bearer': 631 token = values[1] 632 if not token == 'token1234': 633 raise flight.FlightUnauthenticatedError(error_message) 634 else: 635 raise flight.FlightUnauthenticatedError(error_message) 636 637 return HeaderAuthServerMiddleware(token) 638 639 640class HeaderAuthServerMiddleware(ServerMiddleware): 641 """A ServerMiddleware that transports incoming username and passowrd.""" 642 643 def __init__(self, token): 644 self.token = token 645 646 def sending_headers(self): 647 return {'authorization': 'Bearer ' + self.token} 648 649 650class HeaderAuthFlightServer(FlightServerBase): 651 """A Flight server that tests with basic token authentication. """ 652 653 def do_action(self, context, action): 654 middleware = context.get_middleware("auth") 655 if middleware: 656 auth_header = case_insensitive_header_lookup( 657 middleware.sending_headers(), 'Authorization') 658 values = auth_header.split(' ') 659 return [values[1].encode("utf-8")] 660 raise flight.FlightUnauthenticatedError( 661 'No token auth middleware found.') 662 663 664class ArbitraryHeadersServerMiddlewareFactory(ServerMiddlewareFactory): 665 """A ServerMiddlewareFactory that transports arbitrary headers.""" 666 667 def start_call(self, info, headers): 668 return ArbitraryHeadersServerMiddleware(headers) 669 670 671class ArbitraryHeadersServerMiddleware(ServerMiddleware): 672 """A ServerMiddleware that transports arbitrary headers.""" 673 674 def __init__(self, incoming): 675 self.incoming = incoming 676 677 def sending_headers(self): 678 return self.incoming 679 680 681class ArbitraryHeadersFlightServer(FlightServerBase): 682 """A Flight server that tests multiple arbitrary headers.""" 683 684 def do_action(self, context, action): 685 middleware = context.get_middleware("arbitrary-headers") 686 if middleware: 687 headers = middleware.sending_headers() 688 header_1 = case_insensitive_header_lookup( 689 headers, 690 'test-header-1' 691 ) 692 header_2 = case_insensitive_header_lookup( 693 headers, 694 'test-header-2' 695 ) 696 value1 = header_1[0].encode("utf-8") 697 value2 = header_2[0].encode("utf-8") 698 return [value1, value2] 699 raise flight.FlightServerError("No headers middleware found") 700 701 702class HeaderServerMiddleware(ServerMiddleware): 703 """Expose a per-call value to the RPC method body.""" 704 705 def __init__(self, special_value): 706 self.special_value = special_value 707 708 709class HeaderServerMiddlewareFactory(ServerMiddlewareFactory): 710 """Expose a per-call hard-coded value to the RPC method body.""" 711 712 def start_call(self, info, headers): 713 return HeaderServerMiddleware("right value") 714 715 716class HeaderFlightServer(FlightServerBase): 717 """Echo back the per-call hard-coded value.""" 718 719 def do_action(self, context, action): 720 middleware = context.get_middleware("test") 721 if middleware: 722 return [middleware.special_value.encode()] 723 return [b""] 724 725 726class MultiHeaderFlightServer(FlightServerBase): 727 """Test sending/receiving multiple (binary-valued) headers.""" 728 729 def do_action(self, context, action): 730 middleware = context.get_middleware("test") 731 headers = repr(middleware.client_headers).encode("utf-8") 732 return [headers] 733 734 735class SelectiveAuthServerMiddlewareFactory(ServerMiddlewareFactory): 736 """Deny access to certain methods based on a header.""" 737 738 def start_call(self, info, headers): 739 if info.method == flight.FlightMethod.LIST_ACTIONS: 740 # No auth needed 741 return 742 743 token = headers.get("x-auth-token") 744 if not token: 745 raise flight.FlightUnauthenticatedError("No token") 746 747 token = token[0] 748 if token != "password": 749 raise flight.FlightUnauthenticatedError("Invalid token") 750 751 return HeaderServerMiddleware(token) 752 753 754class SelectiveAuthClientMiddlewareFactory(ClientMiddlewareFactory): 755 def start_call(self, info): 756 return SelectiveAuthClientMiddleware() 757 758 759class SelectiveAuthClientMiddleware(ClientMiddleware): 760 def sending_headers(self): 761 return { 762 "x-auth-token": "password", 763 } 764 765 766class RecordingServerMiddlewareFactory(ServerMiddlewareFactory): 767 """Record what methods were called.""" 768 769 def __init__(self): 770 super().__init__() 771 self.methods = [] 772 773 def start_call(self, info, headers): 774 self.methods.append(info.method) 775 return None 776 777 778class RecordingClientMiddlewareFactory(ClientMiddlewareFactory): 779 """Record what methods were called.""" 780 781 def __init__(self): 782 super().__init__() 783 self.methods = [] 784 785 def start_call(self, info): 786 self.methods.append(info.method) 787 return None 788 789 790class MultiHeaderClientMiddlewareFactory(ClientMiddlewareFactory): 791 """Test sending/receiving multiple (binary-valued) headers.""" 792 793 def __init__(self): 794 # Read in test_middleware_multi_header below. 795 # The middleware instance will update this value. 796 self.last_headers = {} 797 798 def start_call(self, info): 799 return MultiHeaderClientMiddleware(self) 800 801 802class MultiHeaderClientMiddleware(ClientMiddleware): 803 """Test sending/receiving multiple (binary-valued) headers.""" 804 805 EXPECTED = { 806 "x-text": ["foo", "bar"], 807 "x-binary-bin": [b"\x00", b"\x01"], 808 } 809 810 def __init__(self, factory): 811 self.factory = factory 812 813 def sending_headers(self): 814 return self.EXPECTED 815 816 def received_headers(self, headers): 817 # Let the test code know what the last set of headers we 818 # received were. 819 self.factory.last_headers = headers 820 821 822class MultiHeaderServerMiddlewareFactory(ServerMiddlewareFactory): 823 """Test sending/receiving multiple (binary-valued) headers.""" 824 825 def start_call(self, info, headers): 826 return MultiHeaderServerMiddleware(headers) 827 828 829class MultiHeaderServerMiddleware(ServerMiddleware): 830 """Test sending/receiving multiple (binary-valued) headers.""" 831 832 def __init__(self, client_headers): 833 self.client_headers = client_headers 834 835 def sending_headers(self): 836 return MultiHeaderClientMiddleware.EXPECTED 837 838 839class LargeMetadataFlightServer(FlightServerBase): 840 """Regression test for ARROW-13253.""" 841 842 def __init__(self, *args, **kwargs): 843 super().__init__(*args, **kwargs) 844 self._metadata = b' ' * (2 ** 31 + 1) 845 846 def do_get(self, context, ticket): 847 schema = pa.schema([('a', pa.int64())]) 848 return flight.GeneratorStream(schema, [ 849 (pa.record_batch([[1]], schema=schema), self._metadata), 850 ]) 851 852 def do_exchange(self, context, descriptor, reader, writer): 853 writer.write_metadata(self._metadata) 854 855 856def test_flight_server_location_argument(): 857 locations = [ 858 None, 859 'grpc://localhost:0', 860 ('localhost', find_free_port()), 861 ] 862 for location in locations: 863 with FlightServerBase(location) as server: 864 assert isinstance(server, FlightServerBase) 865 866 867def test_server_exit_reraises_exception(): 868 with pytest.raises(ValueError): 869 with FlightServerBase(): 870 raise ValueError() 871 872 873@pytest.mark.slow 874def test_client_wait_for_available(): 875 location = ('localhost', find_free_port()) 876 server = None 877 878 def serve(): 879 global server 880 time.sleep(0.5) 881 server = FlightServerBase(location) 882 server.serve() 883 884 client = FlightClient(location) 885 thread = threading.Thread(target=serve, daemon=True) 886 thread.start() 887 888 started = time.time() 889 client.wait_for_available(timeout=5) 890 elapsed = time.time() - started 891 assert elapsed >= 0.5 892 893 894def test_flight_list_flights(): 895 """Try a simple list_flights call.""" 896 with ConstantFlightServer() as server: 897 client = flight.connect(('localhost', server.port)) 898 assert list(client.list_flights()) == [] 899 flights = client.list_flights(ConstantFlightServer.CRITERIA) 900 assert len(list(flights)) == 1 901 902 903def test_flight_do_get_ints(): 904 """Try a simple do_get call.""" 905 table = simple_ints_table() 906 907 with ConstantFlightServer() as server: 908 client = flight.connect(('localhost', server.port)) 909 data = client.do_get(flight.Ticket(b'ints')).read_all() 910 assert data.equals(table) 911 912 options = pa.ipc.IpcWriteOptions( 913 metadata_version=pa.ipc.MetadataVersion.V4) 914 with ConstantFlightServer(options=options) as server: 915 client = flight.connect(('localhost', server.port)) 916 data = client.do_get(flight.Ticket(b'ints')).read_all() 917 assert data.equals(table) 918 919 # Also test via RecordBatchReader interface 920 data = client.do_get(flight.Ticket(b'ints')).to_reader().read_all() 921 assert data.equals(table) 922 923 with pytest.raises(flight.FlightServerError, 924 match="expected IpcWriteOptions, got <class 'int'>"): 925 with ConstantFlightServer(options=42) as server: 926 client = flight.connect(('localhost', server.port)) 927 data = client.do_get(flight.Ticket(b'ints')).read_all() 928 929 930@pytest.mark.pandas 931def test_do_get_ints_pandas(): 932 """Try a simple do_get call.""" 933 table = simple_ints_table() 934 935 with ConstantFlightServer() as server: 936 client = flight.connect(('localhost', server.port)) 937 data = client.do_get(flight.Ticket(b'ints')).read_pandas() 938 assert list(data['some_ints']) == table.column(0).to_pylist() 939 940 941def test_flight_do_get_dicts(): 942 table = simple_dicts_table() 943 944 with ConstantFlightServer() as server: 945 client = flight.connect(('localhost', server.port)) 946 data = client.do_get(flight.Ticket(b'dicts')).read_all() 947 assert data.equals(table) 948 949 950def test_flight_do_get_ticket(): 951 """Make sure Tickets get passed to the server.""" 952 data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] 953 table = pa.Table.from_arrays(data1, names=['a']) 954 with CheckTicketFlightServer(expected_ticket=b'the-ticket') as server: 955 client = flight.connect(('localhost', server.port)) 956 data = client.do_get(flight.Ticket(b'the-ticket')).read_all() 957 assert data.equals(table) 958 959 960def test_flight_get_info(): 961 """Make sure FlightEndpoint accepts string and object URIs.""" 962 with GetInfoFlightServer() as server: 963 client = FlightClient(('localhost', server.port)) 964 info = client.get_flight_info(flight.FlightDescriptor.for_command(b'')) 965 assert info.total_records == -1 966 assert info.total_bytes == -1 967 assert info.schema == pa.schema([('a', pa.int32())]) 968 assert len(info.endpoints) == 2 969 assert len(info.endpoints[0].locations) == 1 970 assert info.endpoints[0].locations[0] == flight.Location('grpc://test') 971 assert info.endpoints[1].locations[0] == \ 972 flight.Location.for_grpc_tcp('localhost', 5005) 973 974 975def test_flight_get_schema(): 976 """Make sure GetSchema returns correct schema.""" 977 with GetInfoFlightServer() as server: 978 client = FlightClient(('localhost', server.port)) 979 info = client.get_schema(flight.FlightDescriptor.for_command(b'')) 980 assert info.schema == pa.schema([('a', pa.int32())]) 981 982 983def test_list_actions(): 984 """Make sure the return type of ListActions is validated.""" 985 # ARROW-6392 986 with ListActionsErrorFlightServer() as server: 987 client = FlightClient(('localhost', server.port)) 988 with pytest.raises( 989 flight.FlightServerError, 990 match=("Results of list_actions must be " 991 "ActionType or tuple") 992 ): 993 list(client.list_actions()) 994 995 with ListActionsFlightServer() as server: 996 client = FlightClient(('localhost', server.port)) 997 assert list(client.list_actions()) == \ 998 ListActionsFlightServer.expected_actions() 999 1000 1001class ConvenienceServer(FlightServerBase): 1002 """ 1003 Server for testing various implementation conveniences (auto-boxing, etc.) 1004 """ 1005 1006 @property 1007 def simple_action_results(self): 1008 return [b'foo', b'bar', b'baz'] 1009 1010 def do_action(self, context, action): 1011 if action.type == 'simple-action': 1012 return self.simple_action_results 1013 elif action.type == 'echo': 1014 return [action.body] 1015 elif action.type == 'bad-action': 1016 return ['foo'] 1017 elif action.type == 'arrow-exception': 1018 raise pa.ArrowMemoryError() 1019 1020 1021def test_do_action_result_convenience(): 1022 with ConvenienceServer() as server: 1023 client = FlightClient(('localhost', server.port)) 1024 1025 # do_action as action type without body 1026 results = [x.body for x in client.do_action('simple-action')] 1027 assert results == server.simple_action_results 1028 1029 # do_action with tuple of type and body 1030 body = b'the-body' 1031 results = [x.body for x in client.do_action(('echo', body))] 1032 assert results == [body] 1033 1034 1035def test_nicer_server_exceptions(): 1036 with ConvenienceServer() as server: 1037 client = FlightClient(('localhost', server.port)) 1038 with pytest.raises(flight.FlightServerError, 1039 match="a bytes-like object is required"): 1040 list(client.do_action('bad-action')) 1041 # While Flight/C++ sends across the original status code, it 1042 # doesn't get mapped to the equivalent code here, since we 1043 # want to be able to distinguish between client- and server- 1044 # side errors. 1045 with pytest.raises(flight.FlightServerError, 1046 match="ArrowMemoryError"): 1047 list(client.do_action('arrow-exception')) 1048 1049 1050def test_get_port(): 1051 """Make sure port() works.""" 1052 server = GetInfoFlightServer("grpc://localhost:0") 1053 try: 1054 assert server.port > 0 1055 finally: 1056 server.shutdown() 1057 1058 1059@pytest.mark.skipif(os.name == 'nt', 1060 reason="Unix sockets can't be tested on Windows") 1061def test_flight_domain_socket(): 1062 """Try a simple do_get call over a Unix domain socket.""" 1063 with tempfile.NamedTemporaryFile() as sock: 1064 sock.close() 1065 location = flight.Location.for_grpc_unix(sock.name) 1066 with ConstantFlightServer(location=location): 1067 client = FlightClient(location) 1068 1069 reader = client.do_get(flight.Ticket(b'ints')) 1070 table = simple_ints_table() 1071 assert reader.schema.equals(table.schema) 1072 data = reader.read_all() 1073 assert data.equals(table) 1074 1075 reader = client.do_get(flight.Ticket(b'dicts')) 1076 table = simple_dicts_table() 1077 assert reader.schema.equals(table.schema) 1078 data = reader.read_all() 1079 assert data.equals(table) 1080 1081 1082@pytest.mark.slow 1083def test_flight_large_message(): 1084 """Try sending/receiving a large message via Flight. 1085 1086 See ARROW-4421: by default, gRPC won't allow us to send messages > 1087 4MiB in size. 1088 """ 1089 data = pa.Table.from_arrays([ 1090 pa.array(range(0, 10 * 1024 * 1024)) 1091 ], names=['a']) 1092 1093 with EchoFlightServer(expected_schema=data.schema) as server: 1094 client = FlightClient(('localhost', server.port)) 1095 writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'), 1096 data.schema) 1097 # Write a single giant chunk 1098 writer.write_table(data, 10 * 1024 * 1024) 1099 writer.close() 1100 result = client.do_get(flight.Ticket(b'')).read_all() 1101 assert result.equals(data) 1102 1103 1104def test_flight_generator_stream(): 1105 """Try downloading a flight of RecordBatches in a GeneratorStream.""" 1106 data = pa.Table.from_arrays([ 1107 pa.array(range(0, 10 * 1024)) 1108 ], names=['a']) 1109 1110 with EchoStreamFlightServer() as server: 1111 client = FlightClient(('localhost', server.port)) 1112 writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'), 1113 data.schema) 1114 writer.write_table(data) 1115 writer.close() 1116 result = client.do_get(flight.Ticket(b'')).read_all() 1117 assert result.equals(data) 1118 1119 1120def test_flight_invalid_generator_stream(): 1121 """Try streaming data with mismatched schemas.""" 1122 with InvalidStreamFlightServer() as server: 1123 client = FlightClient(('localhost', server.port)) 1124 with pytest.raises(pa.ArrowException): 1125 client.do_get(flight.Ticket(b'')).read_all() 1126 1127 1128def test_timeout_fires(): 1129 """Make sure timeouts fire on slow requests.""" 1130 # Do this in a separate thread so that if it fails, we don't hang 1131 # the entire test process 1132 with SlowFlightServer() as server: 1133 client = FlightClient(('localhost', server.port)) 1134 action = flight.Action("", b"") 1135 options = flight.FlightCallOptions(timeout=0.2) 1136 # gRPC error messages change based on version, so don't look 1137 # for a particular error 1138 with pytest.raises(flight.FlightTimedOutError): 1139 list(client.do_action(action, options=options)) 1140 1141 1142def test_timeout_passes(): 1143 """Make sure timeouts do not fire on fast requests.""" 1144 with ConstantFlightServer() as server: 1145 client = FlightClient(('localhost', server.port)) 1146 options = flight.FlightCallOptions(timeout=5.0) 1147 client.do_get(flight.Ticket(b'ints'), options=options).read_all() 1148 1149 1150basic_auth_handler = HttpBasicServerAuthHandler(creds={ 1151 b"test": b"p4ssw0rd", 1152}) 1153 1154token_auth_handler = TokenServerAuthHandler(creds={ 1155 b"test": b"p4ssw0rd", 1156}) 1157 1158 1159@pytest.mark.slow 1160def test_http_basic_unauth(): 1161 """Test that auth fails when not authenticated.""" 1162 with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server: 1163 client = FlightClient(('localhost', server.port)) 1164 action = flight.Action("who-am-i", b"") 1165 with pytest.raises(flight.FlightUnauthenticatedError, 1166 match=".*unauthenticated.*"): 1167 list(client.do_action(action)) 1168 1169 1170@pytest.mark.skipif(os.name == 'nt', 1171 reason="ARROW-10013: gRPC on Windows corrupts peer()") 1172def test_http_basic_auth(): 1173 """Test a Python implementation of HTTP basic authentication.""" 1174 with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server: 1175 client = FlightClient(('localhost', server.port)) 1176 action = flight.Action("who-am-i", b"") 1177 client.authenticate(HttpBasicClientAuthHandler('test', 'p4ssw0rd')) 1178 results = client.do_action(action) 1179 identity = next(results) 1180 assert identity.body.to_pybytes() == b'test' 1181 peer_address = next(results) 1182 assert peer_address.body.to_pybytes() != b'' 1183 1184 1185def test_http_basic_auth_invalid_password(): 1186 """Test that auth fails with the wrong password.""" 1187 with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server: 1188 client = FlightClient(('localhost', server.port)) 1189 action = flight.Action("who-am-i", b"") 1190 with pytest.raises(flight.FlightUnauthenticatedError, 1191 match=".*wrong password.*"): 1192 client.authenticate(HttpBasicClientAuthHandler('test', 'wrong')) 1193 next(client.do_action(action)) 1194 1195 1196def test_token_auth(): 1197 """Test an auth mechanism that uses a handshake.""" 1198 with EchoStreamFlightServer(auth_handler=token_auth_handler) as server: 1199 client = FlightClient(('localhost', server.port)) 1200 action = flight.Action("who-am-i", b"") 1201 client.authenticate(TokenClientAuthHandler('test', 'p4ssw0rd')) 1202 identity = next(client.do_action(action)) 1203 assert identity.body.to_pybytes() == b'test' 1204 1205 1206def test_token_auth_invalid(): 1207 """Test an auth mechanism that uses a handshake.""" 1208 with EchoStreamFlightServer(auth_handler=token_auth_handler) as server: 1209 client = FlightClient(('localhost', server.port)) 1210 with pytest.raises(flight.FlightUnauthenticatedError): 1211 client.authenticate(TokenClientAuthHandler('test', 'wrong')) 1212 1213 1214header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory() 1215no_op_auth_handler = NoopAuthHandler() 1216 1217 1218def test_authenticate_basic_token(): 1219 """Test authenticate_basic_token with bearer token and auth headers.""" 1220 with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ 1221 "auth": HeaderAuthServerMiddlewareFactory() 1222 }) as server: 1223 client = FlightClient(('localhost', server.port)) 1224 token_pair = client.authenticate_basic_token(b'test', b'password') 1225 assert token_pair[0] == b'authorization' 1226 assert token_pair[1] == b'Bearer token1234' 1227 1228 1229def test_authenticate_basic_token_invalid_password(): 1230 """Test authenticate_basic_token with an invalid password.""" 1231 with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ 1232 "auth": HeaderAuthServerMiddlewareFactory() 1233 }) as server: 1234 client = FlightClient(('localhost', server.port)) 1235 with pytest.raises(flight.FlightUnauthenticatedError): 1236 client.authenticate_basic_token(b'test', b'badpassword') 1237 1238 1239def test_authenticate_basic_token_and_action(): 1240 """Test authenticate_basic_token and doAction after authentication.""" 1241 with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ 1242 "auth": HeaderAuthServerMiddlewareFactory() 1243 }) as server: 1244 client = FlightClient(('localhost', server.port)) 1245 token_pair = client.authenticate_basic_token(b'test', b'password') 1246 assert token_pair[0] == b'authorization' 1247 assert token_pair[1] == b'Bearer token1234' 1248 options = flight.FlightCallOptions(headers=[token_pair]) 1249 result = list(client.do_action( 1250 action=flight.Action('test-action', b''), options=options)) 1251 assert result[0].body.to_pybytes() == b'token1234' 1252 1253 1254def test_authenticate_basic_token_with_client_middleware(): 1255 """Test authenticate_basic_token with client middleware 1256 to intercept authorization header returned by the 1257 HTTP header auth enabled server. 1258 """ 1259 with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ 1260 "auth": HeaderAuthServerMiddlewareFactory() 1261 }) as server: 1262 client_auth_middleware = ClientHeaderAuthMiddlewareFactory() 1263 client = FlightClient( 1264 ('localhost', server.port), 1265 middleware=[client_auth_middleware] 1266 ) 1267 encoded_credentials = base64.b64encode(b'test:password') 1268 options = flight.FlightCallOptions(headers=[ 1269 (b'authorization', b'Basic ' + encoded_credentials) 1270 ]) 1271 result = list(client.do_action( 1272 action=flight.Action('test-action', b''), options=options)) 1273 assert result[0].body.to_pybytes() == b'token1234' 1274 assert client_auth_middleware.call_credential[0] == b'authorization' 1275 assert client_auth_middleware.call_credential[1] == \ 1276 b'Bearer ' + b'token1234' 1277 result2 = list(client.do_action( 1278 action=flight.Action('test-action', b''), options=options)) 1279 assert result2[0].body.to_pybytes() == b'token1234' 1280 assert client_auth_middleware.call_credential[0] == b'authorization' 1281 assert client_auth_middleware.call_credential[1] == \ 1282 b'Bearer ' + b'token1234' 1283 1284 1285def test_arbitrary_headers_in_flight_call_options(): 1286 """Test passing multiple arbitrary headers to the middleware.""" 1287 with ArbitraryHeadersFlightServer( 1288 auth_handler=no_op_auth_handler, 1289 middleware={ 1290 "auth": HeaderAuthServerMiddlewareFactory(), 1291 "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory() 1292 }) as server: 1293 client = FlightClient(('localhost', server.port)) 1294 token_pair = client.authenticate_basic_token(b'test', b'password') 1295 assert token_pair[0] == b'authorization' 1296 assert token_pair[1] == b'Bearer token1234' 1297 options = flight.FlightCallOptions(headers=[ 1298 token_pair, 1299 (b'test-header-1', b'value1'), 1300 (b'test-header-2', b'value2') 1301 ]) 1302 result = list(client.do_action(flight.Action( 1303 "test-action", b""), options=options)) 1304 assert result[0].body.to_pybytes() == b'value1' 1305 assert result[1].body.to_pybytes() == b'value2' 1306 1307 1308def test_location_invalid(): 1309 """Test constructing invalid URIs.""" 1310 with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"): 1311 flight.connect("%") 1312 1313 with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"): 1314 ConstantFlightServer("%") 1315 1316 1317def test_location_unknown_scheme(): 1318 """Test creating locations for unknown schemes.""" 1319 assert flight.Location("s3://foo").uri == b"s3://foo" 1320 assert flight.Location("https://example.com/bar.parquet").uri == \ 1321 b"https://example.com/bar.parquet" 1322 1323 1324@pytest.mark.slow 1325@pytest.mark.requires_testing_data 1326def test_tls_fails(): 1327 """Make sure clients cannot connect when cert verification fails.""" 1328 certs = example_tls_certs() 1329 1330 with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: 1331 # Ensure client doesn't connect when certificate verification 1332 # fails (this is a slow test since gRPC does retry a few times) 1333 client = FlightClient("grpc+tls://localhost:" + str(s.port)) 1334 1335 # gRPC error messages change based on version, so don't look 1336 # for a particular error 1337 with pytest.raises(flight.FlightUnavailableError): 1338 client.do_get(flight.Ticket(b'ints')).read_all() 1339 1340 1341@pytest.mark.requires_testing_data 1342def test_tls_do_get(): 1343 """Try a simple do_get call over TLS.""" 1344 table = simple_ints_table() 1345 certs = example_tls_certs() 1346 1347 with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: 1348 client = FlightClient(('localhost', s.port), 1349 tls_root_certs=certs["root_cert"]) 1350 data = client.do_get(flight.Ticket(b'ints')).read_all() 1351 assert data.equals(table) 1352 1353 1354@pytest.mark.requires_testing_data 1355def test_tls_disable_server_verification(): 1356 """Try a simple do_get call over TLS with server verification disabled.""" 1357 table = simple_ints_table() 1358 certs = example_tls_certs() 1359 1360 with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: 1361 try: 1362 client = FlightClient(('localhost', s.port), 1363 disable_server_verification=True) 1364 except NotImplementedError: 1365 pytest.skip('disable_server_verification feature is not available') 1366 data = client.do_get(flight.Ticket(b'ints')).read_all() 1367 assert data.equals(table) 1368 1369 1370@pytest.mark.requires_testing_data 1371def test_tls_override_hostname(): 1372 """Check that incorrectly overriding the hostname fails.""" 1373 certs = example_tls_certs() 1374 1375 with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: 1376 client = flight.connect(('localhost', s.port), 1377 tls_root_certs=certs["root_cert"], 1378 override_hostname="fakehostname") 1379 with pytest.raises(flight.FlightUnavailableError): 1380 client.do_get(flight.Ticket(b'ints')) 1381 1382 1383def test_flight_do_get_metadata(): 1384 """Try a simple do_get call with metadata.""" 1385 data = [ 1386 pa.array([-10, -5, 0, 5, 10]) 1387 ] 1388 table = pa.Table.from_arrays(data, names=['a']) 1389 1390 batches = [] 1391 with MetadataFlightServer() as server: 1392 client = FlightClient(('localhost', server.port)) 1393 reader = client.do_get(flight.Ticket(b'')) 1394 idx = 0 1395 while True: 1396 try: 1397 batch, metadata = reader.read_chunk() 1398 batches.append(batch) 1399 server_idx, = struct.unpack('<i', metadata.to_pybytes()) 1400 assert idx == server_idx 1401 idx += 1 1402 except StopIteration: 1403 break 1404 data = pa.Table.from_batches(batches) 1405 assert data.equals(table) 1406 1407 1408def test_flight_do_get_metadata_v4(): 1409 """Try a simple do_get call with V4 metadata version.""" 1410 table = pa.Table.from_arrays( 1411 [pa.array([-10, -5, 0, 5, 10])], names=['a']) 1412 options = pa.ipc.IpcWriteOptions( 1413 metadata_version=pa.ipc.MetadataVersion.V4) 1414 with MetadataFlightServer(options=options) as server: 1415 client = FlightClient(('localhost', server.port)) 1416 reader = client.do_get(flight.Ticket(b'')) 1417 data = reader.read_all() 1418 assert data.equals(table) 1419 1420 1421def test_flight_do_put_metadata(): 1422 """Try a simple do_put call with metadata.""" 1423 data = [ 1424 pa.array([-10, -5, 0, 5, 10]) 1425 ] 1426 table = pa.Table.from_arrays(data, names=['a']) 1427 1428 with MetadataFlightServer() as server: 1429 client = FlightClient(('localhost', server.port)) 1430 writer, metadata_reader = client.do_put( 1431 flight.FlightDescriptor.for_path(''), 1432 table.schema) 1433 with writer: 1434 for idx, batch in enumerate(table.to_batches(max_chunksize=1)): 1435 metadata = struct.pack('<i', idx) 1436 writer.write_with_metadata(batch, metadata) 1437 buf = metadata_reader.read() 1438 assert buf is not None 1439 server_idx, = struct.unpack('<i', buf.to_pybytes()) 1440 assert idx == server_idx 1441 1442 1443def test_flight_do_put_limit(): 1444 """Try a simple do_put call with a size limit.""" 1445 large_batch = pa.RecordBatch.from_arrays([ 1446 pa.array(np.ones(768, dtype=np.int64())), 1447 ], names=['a']) 1448 1449 with EchoFlightServer() as server: 1450 client = FlightClient(('localhost', server.port), 1451 write_size_limit_bytes=4096) 1452 writer, metadata_reader = client.do_put( 1453 flight.FlightDescriptor.for_path(''), 1454 large_batch.schema) 1455 with writer: 1456 with pytest.raises(flight.FlightWriteSizeExceededError, 1457 match="exceeded soft limit") as excinfo: 1458 writer.write_batch(large_batch) 1459 assert excinfo.value.limit == 4096 1460 smaller_batches = [ 1461 large_batch.slice(0, 384), 1462 large_batch.slice(384), 1463 ] 1464 for batch in smaller_batches: 1465 writer.write_batch(batch) 1466 expected = pa.Table.from_batches([large_batch]) 1467 actual = client.do_get(flight.Ticket(b'')).read_all() 1468 assert expected == actual 1469 1470 1471@pytest.mark.slow 1472def test_cancel_do_get(): 1473 """Test canceling a DoGet operation on the client side.""" 1474 with ConstantFlightServer() as server: 1475 client = FlightClient(('localhost', server.port)) 1476 reader = client.do_get(flight.Ticket(b'ints')) 1477 reader.cancel() 1478 with pytest.raises(flight.FlightCancelledError, match=".*Cancel.*"): 1479 reader.read_chunk() 1480 1481 1482@pytest.mark.slow 1483def test_cancel_do_get_threaded(): 1484 """Test canceling a DoGet operation from another thread.""" 1485 with SlowFlightServer() as server: 1486 client = FlightClient(('localhost', server.port)) 1487 reader = client.do_get(flight.Ticket(b'ints')) 1488 1489 read_first_message = threading.Event() 1490 stream_canceled = threading.Event() 1491 result_lock = threading.Lock() 1492 raised_proper_exception = threading.Event() 1493 1494 def block_read(): 1495 reader.read_chunk() 1496 read_first_message.set() 1497 stream_canceled.wait(timeout=5) 1498 try: 1499 reader.read_chunk() 1500 except flight.FlightCancelledError: 1501 with result_lock: 1502 raised_proper_exception.set() 1503 1504 thread = threading.Thread(target=block_read, daemon=True) 1505 thread.start() 1506 read_first_message.wait(timeout=5) 1507 reader.cancel() 1508 stream_canceled.set() 1509 thread.join(timeout=1) 1510 1511 with result_lock: 1512 assert raised_proper_exception.is_set() 1513 1514 1515def test_roundtrip_types(): 1516 """Make sure serializable types round-trip.""" 1517 ticket = flight.Ticket("foo") 1518 assert ticket == flight.Ticket.deserialize(ticket.serialize()) 1519 1520 desc = flight.FlightDescriptor.for_command("test") 1521 assert desc == flight.FlightDescriptor.deserialize(desc.serialize()) 1522 1523 desc = flight.FlightDescriptor.for_path("a", "b", "test.arrow") 1524 assert desc == flight.FlightDescriptor.deserialize(desc.serialize()) 1525 1526 info = flight.FlightInfo( 1527 pa.schema([('a', pa.int32())]), 1528 desc, 1529 [ 1530 flight.FlightEndpoint(b'', ['grpc://test']), 1531 flight.FlightEndpoint( 1532 b'', 1533 [flight.Location.for_grpc_tcp('localhost', 5005)], 1534 ), 1535 ], 1536 -1, 1537 -1, 1538 ) 1539 info2 = flight.FlightInfo.deserialize(info.serialize()) 1540 assert info.schema == info2.schema 1541 assert info.descriptor == info2.descriptor 1542 assert info.total_bytes == info2.total_bytes 1543 assert info.total_records == info2.total_records 1544 assert info.endpoints == info2.endpoints 1545 1546 1547def test_roundtrip_errors(): 1548 """Ensure that Flight errors propagate from server to client.""" 1549 with ErrorFlightServer() as server: 1550 client = FlightClient(('localhost', server.port)) 1551 1552 with pytest.raises(flight.FlightInternalError, match=".*foo.*"): 1553 list(client.do_action(flight.Action("internal", b""))) 1554 with pytest.raises(flight.FlightTimedOutError, match=".*foo.*"): 1555 list(client.do_action(flight.Action("timedout", b""))) 1556 with pytest.raises(flight.FlightCancelledError, match=".*foo.*"): 1557 list(client.do_action(flight.Action("cancel", b""))) 1558 with pytest.raises(flight.FlightUnauthenticatedError, match=".*foo.*"): 1559 list(client.do_action(flight.Action("unauthenticated", b""))) 1560 with pytest.raises(flight.FlightUnauthorizedError, match=".*foo.*"): 1561 list(client.do_action(flight.Action("unauthorized", b""))) 1562 with pytest.raises(flight.FlightInternalError, match=".*foo.*"): 1563 list(client.list_flights()) 1564 1565 data = [pa.array([-10, -5, 0, 5, 10])] 1566 table = pa.Table.from_arrays(data, names=['a']) 1567 1568 exceptions = { 1569 'internal': flight.FlightInternalError, 1570 'timedout': flight.FlightTimedOutError, 1571 'cancel': flight.FlightCancelledError, 1572 'unauthenticated': flight.FlightUnauthenticatedError, 1573 'unauthorized': flight.FlightUnauthorizedError, 1574 } 1575 1576 for command, exception in exceptions.items(): 1577 1578 with pytest.raises(exception, match=".*foo.*"): 1579 writer, reader = client.do_put( 1580 flight.FlightDescriptor.for_command(command), 1581 table.schema) 1582 writer.write_table(table) 1583 writer.close() 1584 1585 with pytest.raises(exception, match=".*foo.*"): 1586 writer, reader = client.do_put( 1587 flight.FlightDescriptor.for_command(command), 1588 table.schema) 1589 writer.close() 1590 1591 1592def test_do_put_independent_read_write(): 1593 """Ensure that separate threads can read/write on a DoPut.""" 1594 # ARROW-6063: previously this would cause gRPC to abort when the 1595 # writer was closed (due to simultaneous reads), or would hang 1596 # forever. 1597 data = [ 1598 pa.array([-10, -5, 0, 5, 10]) 1599 ] 1600 table = pa.Table.from_arrays(data, names=['a']) 1601 1602 with MetadataFlightServer() as server: 1603 client = FlightClient(('localhost', server.port)) 1604 writer, metadata_reader = client.do_put( 1605 flight.FlightDescriptor.for_path(''), 1606 table.schema) 1607 1608 count = [0] 1609 1610 def _reader_thread(): 1611 while metadata_reader.read() is not None: 1612 count[0] += 1 1613 1614 thread = threading.Thread(target=_reader_thread) 1615 thread.start() 1616 1617 batches = table.to_batches(max_chunksize=1) 1618 with writer: 1619 for idx, batch in enumerate(batches): 1620 metadata = struct.pack('<i', idx) 1621 writer.write_with_metadata(batch, metadata) 1622 # Causes the server to stop writing and end the call 1623 writer.done_writing() 1624 # Thus reader thread will break out of loop 1625 thread.join() 1626 # writer.close() won't segfault since reader thread has 1627 # stopped 1628 assert count[0] == len(batches) 1629 1630 1631def test_server_middleware_same_thread(): 1632 """Ensure that server middleware run on the same thread as the RPC.""" 1633 with HeaderFlightServer(middleware={ 1634 "test": HeaderServerMiddlewareFactory(), 1635 }) as server: 1636 client = FlightClient(('localhost', server.port)) 1637 results = list(client.do_action(flight.Action(b"test", b""))) 1638 assert len(results) == 1 1639 value = results[0].body.to_pybytes() 1640 assert b"right value" == value 1641 1642 1643def test_middleware_reject(): 1644 """Test rejecting an RPC with server middleware.""" 1645 with HeaderFlightServer(middleware={ 1646 "test": SelectiveAuthServerMiddlewareFactory(), 1647 }) as server: 1648 client = FlightClient(('localhost', server.port)) 1649 # The middleware allows this through without auth. 1650 with pytest.raises(pa.ArrowNotImplementedError): 1651 list(client.list_actions()) 1652 1653 # But not anything else. 1654 with pytest.raises(flight.FlightUnauthenticatedError): 1655 list(client.do_action(flight.Action(b"", b""))) 1656 1657 client = FlightClient( 1658 ('localhost', server.port), 1659 middleware=[SelectiveAuthClientMiddlewareFactory()] 1660 ) 1661 response = next(client.do_action(flight.Action(b"", b""))) 1662 assert b"password" == response.body.to_pybytes() 1663 1664 1665def test_middleware_mapping(): 1666 """Test that middleware records methods correctly.""" 1667 server_middleware = RecordingServerMiddlewareFactory() 1668 client_middleware = RecordingClientMiddlewareFactory() 1669 with FlightServerBase(middleware={"test": server_middleware}) as server: 1670 client = FlightClient( 1671 ('localhost', server.port), 1672 middleware=[client_middleware] 1673 ) 1674 1675 descriptor = flight.FlightDescriptor.for_command(b"") 1676 with pytest.raises(NotImplementedError): 1677 list(client.list_flights()) 1678 with pytest.raises(NotImplementedError): 1679 client.get_flight_info(descriptor) 1680 with pytest.raises(NotImplementedError): 1681 client.get_schema(descriptor) 1682 with pytest.raises(NotImplementedError): 1683 client.do_get(flight.Ticket(b"")) 1684 with pytest.raises(NotImplementedError): 1685 writer, _ = client.do_put(descriptor, pa.schema([])) 1686 writer.close() 1687 with pytest.raises(NotImplementedError): 1688 list(client.do_action(flight.Action(b"", b""))) 1689 with pytest.raises(NotImplementedError): 1690 list(client.list_actions()) 1691 with pytest.raises(NotImplementedError): 1692 writer, _ = client.do_exchange(descriptor) 1693 writer.close() 1694 1695 expected = [ 1696 flight.FlightMethod.LIST_FLIGHTS, 1697 flight.FlightMethod.GET_FLIGHT_INFO, 1698 flight.FlightMethod.GET_SCHEMA, 1699 flight.FlightMethod.DO_GET, 1700 flight.FlightMethod.DO_PUT, 1701 flight.FlightMethod.DO_ACTION, 1702 flight.FlightMethod.LIST_ACTIONS, 1703 flight.FlightMethod.DO_EXCHANGE, 1704 ] 1705 assert server_middleware.methods == expected 1706 assert client_middleware.methods == expected 1707 1708 1709def test_extra_info(): 1710 with ErrorFlightServer() as server: 1711 client = FlightClient(('localhost', server.port)) 1712 try: 1713 list(client.do_action(flight.Action("protobuf", b""))) 1714 assert False 1715 except flight.FlightUnauthorizedError as e: 1716 assert e.extra_info is not None 1717 ei = e.extra_info 1718 assert ei == b'this is an error message' 1719 1720 1721@pytest.mark.requires_testing_data 1722def test_mtls(): 1723 """Test mutual TLS (mTLS) with gRPC.""" 1724 certs = example_tls_certs() 1725 table = simple_ints_table() 1726 1727 with ConstantFlightServer( 1728 tls_certificates=[certs["certificates"][0]], 1729 verify_client=True, 1730 root_certificates=certs["root_cert"]) as s: 1731 client = FlightClient( 1732 ('localhost', s.port), 1733 tls_root_certs=certs["root_cert"], 1734 cert_chain=certs["certificates"][0].cert, 1735 private_key=certs["certificates"][0].key) 1736 data = client.do_get(flight.Ticket(b'ints')).read_all() 1737 assert data.equals(table) 1738 1739 1740def test_doexchange_get(): 1741 """Emulate DoGet with DoExchange.""" 1742 expected = pa.Table.from_arrays([ 1743 pa.array(range(0, 10 * 1024)) 1744 ], names=["a"]) 1745 1746 with ExchangeFlightServer() as server: 1747 client = FlightClient(("localhost", server.port)) 1748 descriptor = flight.FlightDescriptor.for_command(b"get") 1749 writer, reader = client.do_exchange(descriptor) 1750 with writer: 1751 table = reader.read_all() 1752 assert expected == table 1753 1754 1755def test_doexchange_put(): 1756 """Emulate DoPut with DoExchange.""" 1757 data = pa.Table.from_arrays([ 1758 pa.array(range(0, 10 * 1024)) 1759 ], names=["a"]) 1760 batches = data.to_batches(max_chunksize=512) 1761 1762 with ExchangeFlightServer() as server: 1763 client = FlightClient(("localhost", server.port)) 1764 descriptor = flight.FlightDescriptor.for_command(b"put") 1765 writer, reader = client.do_exchange(descriptor) 1766 with writer: 1767 writer.begin(data.schema) 1768 for batch in batches: 1769 writer.write_batch(batch) 1770 writer.done_writing() 1771 chunk = reader.read_chunk() 1772 assert chunk.data is None 1773 expected_buf = str(len(batches)).encode("utf-8") 1774 assert chunk.app_metadata == expected_buf 1775 1776 1777def test_doexchange_echo(): 1778 """Try a DoExchange echo server.""" 1779 data = pa.Table.from_arrays([ 1780 pa.array(range(0, 10 * 1024)) 1781 ], names=["a"]) 1782 batches = data.to_batches(max_chunksize=512) 1783 1784 with ExchangeFlightServer() as server: 1785 client = FlightClient(("localhost", server.port)) 1786 descriptor = flight.FlightDescriptor.for_command(b"echo") 1787 writer, reader = client.do_exchange(descriptor) 1788 with writer: 1789 # Read/write metadata before starting data. 1790 for i in range(10): 1791 buf = str(i).encode("utf-8") 1792 writer.write_metadata(buf) 1793 chunk = reader.read_chunk() 1794 assert chunk.data is None 1795 assert chunk.app_metadata == buf 1796 1797 # Now write data without metadata. 1798 writer.begin(data.schema) 1799 for batch in batches: 1800 writer.write_batch(batch) 1801 assert reader.schema == data.schema 1802 chunk = reader.read_chunk() 1803 assert chunk.data == batch 1804 assert chunk.app_metadata is None 1805 1806 # And write data with metadata. 1807 for i, batch in enumerate(batches): 1808 buf = str(i).encode("utf-8") 1809 writer.write_with_metadata(batch, buf) 1810 chunk = reader.read_chunk() 1811 assert chunk.data == batch 1812 assert chunk.app_metadata == buf 1813 1814 1815def test_doexchange_echo_v4(): 1816 """Try a DoExchange echo server using the V4 metadata version.""" 1817 data = pa.Table.from_arrays([ 1818 pa.array(range(0, 10 * 1024)) 1819 ], names=["a"]) 1820 batches = data.to_batches(max_chunksize=512) 1821 1822 options = pa.ipc.IpcWriteOptions( 1823 metadata_version=pa.ipc.MetadataVersion.V4) 1824 with ExchangeFlightServer(options=options) as server: 1825 client = FlightClient(("localhost", server.port)) 1826 descriptor = flight.FlightDescriptor.for_command(b"echo") 1827 writer, reader = client.do_exchange(descriptor) 1828 with writer: 1829 # Now write data without metadata. 1830 writer.begin(data.schema, options=options) 1831 for batch in batches: 1832 writer.write_batch(batch) 1833 assert reader.schema == data.schema 1834 chunk = reader.read_chunk() 1835 assert chunk.data == batch 1836 assert chunk.app_metadata is None 1837 1838 1839def test_doexchange_transform(): 1840 """Transform a table with a service.""" 1841 data = pa.Table.from_arrays([ 1842 pa.array(range(0, 1024)), 1843 pa.array(range(1, 1025)), 1844 pa.array(range(2, 1026)), 1845 ], names=["a", "b", "c"]) 1846 expected = pa.Table.from_arrays([ 1847 pa.array(range(3, 1024 * 3 + 3, 3)), 1848 ], names=["sum"]) 1849 1850 with ExchangeFlightServer() as server: 1851 client = FlightClient(("localhost", server.port)) 1852 descriptor = flight.FlightDescriptor.for_command(b"transform") 1853 writer, reader = client.do_exchange(descriptor) 1854 with writer: 1855 writer.begin(data.schema) 1856 writer.write_table(data) 1857 writer.done_writing() 1858 table = reader.read_all() 1859 assert expected == table 1860 1861 1862def test_middleware_multi_header(): 1863 """Test sending/receiving multiple (binary-valued) headers.""" 1864 with MultiHeaderFlightServer(middleware={ 1865 "test": MultiHeaderServerMiddlewareFactory(), 1866 }) as server: 1867 headers = MultiHeaderClientMiddlewareFactory() 1868 client = FlightClient(('localhost', server.port), middleware=[headers]) 1869 response = next(client.do_action(flight.Action(b"", b""))) 1870 # The server echoes the headers it got back to us. 1871 raw_headers = response.body.to_pybytes().decode("utf-8") 1872 client_headers = ast.literal_eval(raw_headers) 1873 # Don't directly compare; gRPC may add headers like User-Agent. 1874 for header, values in MultiHeaderClientMiddleware.EXPECTED.items(): 1875 assert client_headers.get(header) == values 1876 assert headers.last_headers.get(header) == values 1877 1878 1879@pytest.mark.requires_testing_data 1880def test_generic_options(): 1881 """Test setting generic client options.""" 1882 certs = example_tls_certs() 1883 1884 with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: 1885 # Try setting a string argument that will make requests fail 1886 options = [("grpc.ssl_target_name_override", "fakehostname")] 1887 client = flight.connect(('localhost', s.port), 1888 tls_root_certs=certs["root_cert"], 1889 generic_options=options) 1890 with pytest.raises(flight.FlightUnavailableError): 1891 client.do_get(flight.Ticket(b'ints')) 1892 # Try setting an int argument that will make requests fail 1893 options = [("grpc.max_receive_message_length", 32)] 1894 client = flight.connect(('localhost', s.port), 1895 tls_root_certs=certs["root_cert"], 1896 generic_options=options) 1897 with pytest.raises(pa.ArrowInvalid): 1898 client.do_get(flight.Ticket(b'ints')) 1899 1900 1901class CancelFlightServer(FlightServerBase): 1902 """A server for testing StopToken.""" 1903 1904 def do_get(self, context, ticket): 1905 schema = pa.schema([]) 1906 rb = pa.RecordBatch.from_arrays([], schema=schema) 1907 return flight.GeneratorStream(schema, itertools.repeat(rb)) 1908 1909 def do_exchange(self, context, descriptor, reader, writer): 1910 schema = pa.schema([]) 1911 rb = pa.RecordBatch.from_arrays([], schema=schema) 1912 writer.begin(schema) 1913 while not context.is_cancelled(): 1914 writer.write_batch(rb) 1915 time.sleep(0.5) 1916 1917 1918def test_interrupt(): 1919 if threading.current_thread().ident != threading.main_thread().ident: 1920 pytest.skip("test only works from main Python thread") 1921 # Skips test if not available 1922 raise_signal = util.get_raise_signal() 1923 1924 def signal_from_thread(): 1925 time.sleep(0.5) 1926 raise_signal(signal.SIGINT) 1927 1928 exc_types = (KeyboardInterrupt, pa.ArrowCancelled) 1929 1930 def test(read_all): 1931 try: 1932 try: 1933 t = threading.Thread(target=signal_from_thread) 1934 with pytest.raises(exc_types) as exc_info: 1935 t.start() 1936 read_all() 1937 finally: 1938 t.join() 1939 except KeyboardInterrupt: 1940 # In case KeyboardInterrupt didn't interrupt read_all 1941 # above, at least prevent it from stopping the test suite 1942 pytest.fail("KeyboardInterrupt didn't interrupt Flight read_all") 1943 e = exc_info.value.__context__ 1944 assert isinstance(e, pa.ArrowCancelled) or \ 1945 isinstance(e, KeyboardInterrupt) 1946 1947 with CancelFlightServer() as server: 1948 client = FlightClient(("localhost", server.port)) 1949 1950 reader = client.do_get(flight.Ticket(b"")) 1951 test(reader.read_all) 1952 1953 descriptor = flight.FlightDescriptor.for_command(b"echo") 1954 writer, reader = client.do_exchange(descriptor) 1955 test(reader.read_all) 1956 1957 1958def test_never_sends_data(): 1959 # Regression test for ARROW-12779 1960 match = "application server implementation error" 1961 with NeverSendsDataFlightServer() as server: 1962 client = flight.connect(('localhost', server.port)) 1963 with pytest.raises(flight.FlightServerError, match=match): 1964 client.do_get(flight.Ticket(b'')).read_all() 1965 1966 # Check that the server handler will ignore empty tables 1967 # up to a certain extent 1968 table = client.do_get(flight.Ticket(b'yield_data')).read_all() 1969 assert table.num_rows == 5 1970 1971 1972@pytest.mark.large_memory 1973@pytest.mark.slow 1974def test_large_descriptor(): 1975 # Regression test for ARROW-13253. Placed here with appropriate marks 1976 # since some CI pipelines can't run the C++ equivalent 1977 large_descriptor = flight.FlightDescriptor.for_command( 1978 b' ' * (2 ** 31 + 1)) 1979 with FlightServerBase() as server: 1980 client = flight.connect(('localhost', server.port)) 1981 with pytest.raises(OSError, 1982 match="Failed to serialize Flight descriptor"): 1983 writer, _ = client.do_put(large_descriptor, pa.schema([])) 1984 writer.close() 1985 with pytest.raises(pa.ArrowException, 1986 match="Failed to serialize Flight descriptor"): 1987 client.do_exchange(large_descriptor) 1988 1989 1990@pytest.mark.large_memory 1991@pytest.mark.slow 1992def test_large_metadata_client(): 1993 # Regression test for ARROW-13253 1994 descriptor = flight.FlightDescriptor.for_command(b'') 1995 metadata = b' ' * (2 ** 31 + 1) 1996 with EchoFlightServer() as server: 1997 client = flight.connect(('localhost', server.port)) 1998 with pytest.raises(pa.ArrowCapacityError, 1999 match="app_metadata size overflow"): 2000 writer, _ = client.do_put(descriptor, pa.schema([])) 2001 with writer: 2002 writer.write_metadata(metadata) 2003 writer.close() 2004 with pytest.raises(pa.ArrowCapacityError, 2005 match="app_metadata size overflow"): 2006 writer, reader = client.do_exchange(descriptor) 2007 with writer: 2008 writer.write_metadata(metadata) 2009 2010 del metadata 2011 with LargeMetadataFlightServer() as server: 2012 client = flight.connect(('localhost', server.port)) 2013 with pytest.raises(flight.FlightServerError, 2014 match="app_metadata size overflow"): 2015 reader = client.do_get(flight.Ticket(b'')) 2016 reader.read_all() 2017 with pytest.raises(pa.ArrowException, 2018 match="app_metadata size overflow"): 2019 writer, reader = client.do_exchange(descriptor) 2020 with writer: 2021 reader.read_all() 2022 2023 2024class ActionNoneFlightServer(EchoFlightServer): 2025 """A server that implements a side effect to a non iterable action.""" 2026 VALUES = [] 2027 2028 def do_action(self, context, action): 2029 if action.type == "get_value": 2030 return [json.dumps(self.VALUES).encode('utf-8')] 2031 elif action.type == "append": 2032 self.VALUES.append(True) 2033 return None 2034 raise NotImplementedError 2035 2036 2037def test_none_action_side_effect(): 2038 """Ensure that actions are executed even when we don't consume iterator. 2039 2040 See https://issues.apache.org/jira/browse/ARROW-14255 2041 """ 2042 2043 with ActionNoneFlightServer() as server: 2044 client = FlightClient(('localhost', server.port)) 2045 client.do_action(flight.Action("append", b"")) 2046 r = client.do_action(flight.Action("get_value", b"")) 2047 assert json.loads(next(r).body.to_pybytes()) == [True] 2048