1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import ast
19import base64
20import itertools
21import os
22import signal
23import struct
24import tempfile
25import threading
26import time
27import traceback
28import json
29
30import numpy as np
31import pytest
32import pyarrow as pa
33
34from pyarrow.lib import tobytes
35from pyarrow.util import pathlib, find_free_port
36from pyarrow.tests import util
37
38try:
39    from pyarrow import flight
40    from pyarrow.flight import (
41        FlightClient, FlightServerBase,
42        ServerAuthHandler, ClientAuthHandler,
43        ServerMiddleware, ServerMiddlewareFactory,
44        ClientMiddleware, ClientMiddlewareFactory,
45    )
46except ImportError:
47    flight = None
48    FlightClient, FlightServerBase = object, object
49    ServerAuthHandler, ClientAuthHandler = object, object
50    ServerMiddleware, ServerMiddlewareFactory = object, object
51    ClientMiddleware, ClientMiddlewareFactory = object, object
52
# Apply the ``flight`` marker to every test in this module, so the whole
# file can be deselected at once with: pytest ... -m 'not flight'
pytestmark = pytest.mark.flight
56
57
def test_import():
    """Import pyarrow.flight without a guard so a failure is reported."""
    # The module-level import is wrapped in try/except ImportError; this
    # unguarded import makes the ImportError show up in at least one test.
    import pyarrow.flight  # noqa
61
62
def resource_root():
    """Return the ``flight`` directory below ``$ARROW_TEST_DATA``.

    Raises
    ------
    RuntimeError
        If the ARROW_TEST_DATA environment variable is unset or empty.
    """
    data_dir = os.environ.get("ARROW_TEST_DATA")
    if data_dir:
        return pathlib.Path(data_dir) / "flight"
    raise RuntimeError("Test resources not found; set "
                       "ARROW_TEST_DATA to <repo root>/testing/data")
69
70
def read_flight_resource(path):
    """Return the raw bytes of a test resource file.

    Parameters
    ----------
    path : str or path-like
        Location of the file relative to the directory returned by
        ``resource_root()``.

    Returns
    -------
    bytes
        The contents of the file, read in binary mode.

    Raises
    ------
    RuntimeError
        If the file does not exist. ``resource_root()`` raises the same
        exception type when the resource directory is not configured.
    """
    # resource_root() either raises or returns a Path, so there is no
    # "missing root" value to test for here.
    resource = resource_root() / path
    try:
        with resource.open("rb") as f:
            return f.read()
    except FileNotFoundError as exc:
        # Keep the original traceback text in the message and also chain
        # the cause so it is visible to debuggers.
        raise RuntimeError(
            "Test resource {} not found; did you initialize the "
            "test resource submodule?\n{}".format(resource,
                                                  traceback.format_exc())
        ) from exc
84
85
def example_tls_certs():
    """Load the test TLS material.

    Returns a dict holding the contents of the root CA file under
    ``"root_cert"`` and a list of two ``flight.CertKeyPair`` objects under
    ``"certificates"``.
    """
    pair_files = [("cert0.pem", "cert0.key"), ("cert1.pem", "cert1.key")]
    root_cert = read_flight_resource("root-ca.pem")
    certificates = [
        flight.CertKeyPair(
            cert=read_flight_resource(cert_file),
            key=read_flight_resource(key_file),
        )
        for cert_file, key_file in pair_files
    ]
    return {"root_cert": root_cert, "certificates": certificates}
101
102
def simple_ints_table():
    """Build a one-column table named 'some_ints' holding five integers."""
    some_ints = pa.array([-10, -5, 0, 5, 10])
    return pa.Table.from_arrays([some_ints], names=['some_ints'])
108
109
def simple_dicts_table():
    """Build a one-column table of dictionary-encoded strings.

    The 'some_dicts' column is a chunked array made of two dictionary
    arrays that share the same dictionary values; the first chunk contains
    a null index.
    """
    dictionary = pa.array(["foo", "baz", "quux"], type=pa.utf8())
    index_chunks = ([1, 0, None], [2, 1])
    column = pa.chunked_array([
        pa.DictionaryArray.from_arrays(indices, dictionary)
        for indices in index_chunks
    ])
    return pa.Table.from_arrays([column], names=['some_dicts'])
119
120
class ConstantFlightServer(FlightServerBase):
    """A Flight server whose DoGet responses are fixed tables.

    See ARROW-4796: this server implementation will segfault if Flight
    does not properly hold a reference to the Table object.
    """

    CRITERIA = b"the expected criteria"

    def __init__(self, location=None, options=None, **kwargs):
        super().__init__(location, **kwargs)
        self.options = options
        # Maps ticket bytes to a zero-argument callable building the table.
        self.table_factories = {
            b'ints': simple_ints_table,
            b'dicts': simple_dicts_table,
        }

    def list_flights(self, context, criteria):
        """Yield one placeholder FlightInfo when criteria equals CRITERIA."""
        if criteria != self.CRITERIA:
            return
        yield flight.FlightInfo(
            pa.schema([]),
            flight.FlightDescriptor.for_path('/foo'),
            [],
            -1, -1
        )

    def do_get(self, context, ticket):
        """Stream the table registered for ``ticket.ticket``."""
        # Build the table on demand so that the stream handed to Flight
        # holds the only reference to it.
        make_table = self.table_factories[ticket.ticket]
        return flight.RecordBatchStream(make_table(), options=self.options)
153
154
class MetadataFlightServer(FlightServerBase):
    """A Flight server that attaches a running counter to each batch."""

    def __init__(self, options=None, **kwargs):
        super().__init__(**kwargs)
        self.options = options

    def do_get(self, context, ticket):
        """Stream a fixed table, pairing each batch with its index."""
        values = pa.array([-10, -5, 0, 5, 10])
        table = pa.Table.from_arrays([values], names=['a'])
        return flight.GeneratorStream(
            table.schema,
            self.number_batches(table),
            options=self.options)

    def do_put(self, context, descriptor, reader, writer):
        """Check every uploaded chunk and acknowledge it with its index.

        Each chunk must be a one-row batch holding the next expected value
        and carry the little-endian int32 counter as metadata.
        """
        expected_data = [-10, -5, 0, 5, 10]
        counter = 0
        while True:
            try:
                batch, buf = reader.read_chunk()
            except StopIteration:
                # End of the client's stream.
                return
            expected_batch = pa.RecordBatch.from_arrays(
                [pa.array([expected_data[counter]])],
                ['a']
            )
            assert batch.equals(expected_batch)
            assert buf is not None
            (client_counter,) = struct.unpack('<i', buf.to_pybytes())
            assert counter == client_counter
            writer.write(struct.pack('<i', counter))
            counter += 1

    @staticmethod
    def number_batches(table):
        """Yield (batch, metadata) pairs; metadata is the packed index."""
        for idx, batch in enumerate(table.to_batches()):
            yield batch, struct.pack('<i', idx)
195
196
class EchoFlightServer(FlightServerBase):
    """A Flight server that hands back whatever was uploaded last."""

    def __init__(self, location=None, expected_schema=None, **kwargs):
        super().__init__(location, **kwargs)
        self.expected_schema = expected_schema
        # Result of the most recent DoPut; None until one completes.
        self.last_message = None

    def do_get(self, context, ticket):
        """Stream the stored upload."""
        stored = self.last_message
        return flight.RecordBatchStream(stored)

    def do_put(self, context, descriptor, reader, writer):
        """Store the uploaded data, checking its schema if one was given."""
        expected = self.expected_schema
        if expected:
            assert expected == reader.schema
        self.last_message = reader.read_all()

    def do_exchange(self, context, descriptor, reader, writer):
        """Consume the incoming stream without writing anything back."""
        for _ in reader:
            continue
216
217
class EchoStreamFlightServer(EchoFlightServer):
    """Echo server variant that replies one record batch at a time."""

    def do_get(self, context, ticket):
        """Stream the stored upload in batches of at most 1024 rows."""
        stored = self.last_message
        schema = stored.schema
        batches = stored.to_batches(max_chunksize=1024)
        return flight.GeneratorStream(schema, batches)

    def list_actions(self, context):
        """Advertise no actions."""
        return list()

    def do_action(self, context, action):
        """Answer "who-am-i" with the peer identity and peer address."""
        if action.type != "who-am-i":
            raise NotImplementedError
        return [context.peer_identity(), context.peer().encode("utf-8")]
233
234
class GetInfoFlightServer(FlightServerBase):
    """A Flight server used to exercise GetFlightInfo and GetSchema."""

    def get_flight_info(self, context, descriptor):
        """Return a fixed FlightInfo that echoes ``descriptor``.

        The info has a single int32 column 'a' and two endpoints: one
        given as a URI string, one as a ``flight.Location``.
        """
        schema = pa.schema([('a', pa.int32())])
        endpoints = [
            flight.FlightEndpoint(b'', ['grpc://test']),
            flight.FlightEndpoint(
                b'',
                [flight.Location.for_grpc_tcp('localhost', 5005)],
            ),
        ]
        return flight.FlightInfo(schema, descriptor, endpoints, -1, -1)

    def get_schema(self, context, descriptor):
        """Return the schema of ``get_flight_info`` as a SchemaResult."""
        schema = self.get_flight_info(context, descriptor).schema
        return flight.SchemaResult(schema)
256
257
class ListActionsFlightServer(FlightServerBase):
    """A Flight server used to exercise ListActions."""

    @classmethod
    def expected_actions(cls):
        """Return the actions advertised by ``list_actions``.

        The entries mix plain (type, description) tuples with a
        ``flight.ActionType`` instance.
        """
        actions = [("action-1", "description"), ("action-2", "")]
        actions.append(flight.ActionType("action-3", "more detail"))
        return actions

    def list_actions(self, context):
        """Yield every entry of ``expected_actions()`` in order."""
        for action in self.expected_actions():
            yield action
271
272
class ListActionsErrorFlightServer(FlightServerBase):
    """A Flight server whose ListActions yields a tuple, then a string."""

    def list_actions(self, context):
        # A (type, description) tuple followed by a bare string.
        yield from [("action-1", ""), "foo"]
279
280
class CheckTicketFlightServer(FlightServerBase):
    """A Flight server that compares the given ticket to an expected value.

    Parameters
    ----------
    expected_ticket : bytes
        Value that ``ticket.ticket`` must equal in ``do_get``.
    location, **kwargs
        Forwarded to ``FlightServerBase``.
    """

    def __init__(self, expected_ticket, location=None, **kwargs):
        super().__init__(location, **kwargs)
        self.expected_ticket = expected_ticket
        # Defined up front so the attribute exists before any DoPut ran.
        self.last_message = None

    def do_get(self, context, ticket):
        """Assert the ticket matches, then stream a small int32 table."""
        assert self.expected_ticket == ticket.ticket
        data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
        table = pa.Table.from_arrays(data1, names=['a'])
        return flight.RecordBatchStream(table)

    def do_put(self, context, descriptor, reader, writer=None):
        """Store the uploaded data in ``last_message``.

        ``writer`` is accepted (and unused) so the signature lines up with
        the four-argument ``do_put`` of the other servers in this file; it
        defaults to None so that three-argument calls keep working.
        """
        self.last_message = reader.read_all()
296
297
class InvalidStreamFlightServer(FlightServerBase):
    """A Flight server that tries to return messages with differing schemas."""

    # Schema announced for the stream; only the first table matches it.
    schema = pa.schema([('a', pa.int32())])

    def do_get(self, context, ticket):
        """Return a generator stream whose second table breaks ``schema``.

        The stream is declared with the int32 ``schema`` but is fed an
        int32 table followed by a float64 table.
        """
        data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
        data2 = [pa.array([-10.0, -5.0, 0.0, 5.0, 10.0], type=pa.float64())]
        # data1/data2 are single-element lists of arrays, so the type must
        # be read from the array itself (a list has no ``.type``).
        assert data1[0].type != data2[0].type
        table1 = pa.Table.from_arrays(data1, names=['a'])
        table2 = pa.Table.from_arrays(data2, names=['a'])
        assert table1.schema == self.schema

        return flight.GeneratorStream(self.schema, [table1, table2])
312
313
class NeverSendsDataFlightServer(FlightServerBase):
    """A Flight server that streams empty tables.

    For every ticket except ``b'yield_data'`` the stream of empty tables
    never ends.
    """

    schema = pa.schema([('a', pa.int32())])

    def do_get(self, context, ticket):
        if ticket.ticket != b'yield_data':
            # Endless supply of the same empty table.
            empties = itertools.repeat(self.schema.empty_table())
            return flight.GeneratorStream(self.schema, empties)

        # Check that the server handler will ignore empty tables
        # up to a certain extent
        leading_empties = [self.schema.empty_table() for _ in range(2)]
        final_batch = pa.RecordBatch.from_arrays(
            [range(5)], schema=self.schema)
        return flight.GeneratorStream(
            self.schema, leading_empties + [final_batch])
331
332
class SlowFlightServer(FlightServerBase):
    """A Flight server that delays its responses to test timeouts."""

    def do_get(self, context, ticket):
        """Return a stream that pauses 10 seconds after its first table."""
        schema = pa.schema([('a', pa.int32())])
        return flight.GeneratorStream(schema, self.slow_stream())

    def do_action(self, context, action):
        """Sleep for half a second, then return no results."""
        time.sleep(0.5)
        return []

    @staticmethod
    def slow_stream():
        """Yield a table, sleep for 10 seconds, then yield it again."""
        def make_table():
            column = pa.array([-10, -5, 0, 5, 10], type=pa.int32())
            return pa.Table.from_arrays([column], names=['a'])

        yield make_table()
        # The second message should never get sent; the client should
        # cancel before we send this
        time.sleep(10)
        yield make_table()
352
353
class ErrorFlightServer(FlightServerBase):
    """A Flight server that uses all the Flight-specific errors."""

    def do_action(self, context, action):
        """Raise the Flight error selected by ``action.type``.

        Unknown action types raise NotImplementedError.
        """
        # Built per call: ``flight`` is None when the import at the top of
        # the module failed, so this cannot live at class level.
        error_classes = {
            "internal": flight.FlightInternalError,
            "timedout": flight.FlightTimedOutError,
            "cancel": flight.FlightCancelledError,
            "unauthenticated": flight.FlightUnauthenticatedError,
            "unauthorized": flight.FlightUnauthorizedError,
        }
        kind = action.type
        if kind in error_classes:
            raise error_classes[kind]("foo")
        if kind == "protobuf":
            # Same error class, with an extra bytes payload attached.
            extra_info = b'this is an error message'
            raise flight.FlightUnauthorizedError("foo", extra_info)
        raise NotImplementedError

    def list_flights(self, context, criteria):
        """Yield one FlightInfo, then fail with FlightInternalError."""
        descriptor = flight.FlightDescriptor.for_path('/foo')
        yield flight.FlightInfo(pa.schema([]), descriptor, [], -1, -1)
        raise flight.FlightInternalError("foo")

    def do_put(self, context, descriptor, reader, writer):
        """Raise the Flight error selected by ``descriptor.command``.

        Unknown commands are ignored and the call returns normally.
        """
        error_classes = {
            b"internal": flight.FlightInternalError,
            b"timedout": flight.FlightTimedOutError,
            b"cancel": flight.FlightCancelledError,
            b"unauthenticated": flight.FlightUnauthenticatedError,
            b"unauthorized": flight.FlightUnauthorizedError,
        }
        command = descriptor.command
        if command in error_classes:
            raise error_classes[command]("foo")
        if command == b"protobuf":
            extra_info = b'this is an error message'
            raise flight.FlightUnauthorizedError("foo", extra_info)
396
397
class ExchangeFlightServer(FlightServerBase):
    """A server for testing DoExchange.

    The command carried by the descriptor selects the behaviour:
    ``b"echo"``, ``b"get"``, ``b"put"`` or ``b"transform"``.
    """

    def __init__(self, options=None, **kwargs):
        super().__init__(**kwargs)
        # Passed to ``writer.begin`` by the echo handler.
        self.options = options

    def do_exchange(self, context, descriptor, reader, writer):
        """Dispatch to the handler named by the descriptor's command.

        Raises
        ------
        pyarrow.ArrowInvalid
            If the descriptor is not a command descriptor, or the command
            is not one of the supported ones.
        """
        if descriptor.descriptor_type != flight.DescriptorType.CMD:
            raise pa.ArrowInvalid("Must provide a command descriptor")
        handlers = {
            b"echo": self.exchange_echo,
            b"get": self.exchange_do_get,
            b"put": self.exchange_do_put,
            b"transform": self.exchange_transform,
        }
        handler = handlers.get(descriptor.command)
        if handler is None:
            raise pa.ArrowInvalid(
                "Unknown command: {}".format(descriptor.command))
        return handler(context, reader, writer)

    def exchange_do_get(self, context, reader, writer):
        """Emulate DoGet with DoExchange."""
        data = pa.Table.from_arrays([
            pa.array(range(0, 10 * 1024))
        ], names=["a"])
        writer.begin(data.schema)
        writer.write_table(data)

    def exchange_do_put(self, context, reader, writer):
        """Emulate DoPut with DoExchange.

        Counts the uploaded chunks and replies with the count as UTF-8
        encoded metadata.
        """
        num_batches = 0
        for chunk in reader:
            # Compare against None rather than relying on truthiness: a
            # record batch with zero rows has length 0 and would be falsy
            # even though the chunk does carry data.
            if chunk.data is None:
                raise pa.ArrowInvalid("All chunks must have data.")
            num_batches += 1
        writer.write_metadata(str(num_batches).encode("utf-8"))

    def exchange_echo(self, context, reader, writer):
        """Run a simple echo server.

        Every chunk is written back as received: data plus metadata,
        metadata only, or data only.
        """
        started = False
        for chunk in reader:
            # See exchange_do_put: an empty batch is still data.
            has_data = chunk.data is not None
            if not started and has_data:
                # The writer needs the schema before the first batch.
                writer.begin(chunk.data.schema, options=self.options)
                started = True
            if chunk.app_metadata and has_data:
                writer.write_with_metadata(chunk.data, chunk.app_metadata)
            elif chunk.app_metadata:
                writer.write_metadata(chunk.app_metadata)
            elif has_data:
                writer.write_batch(chunk.data)
            else:
                assert False, "Should not happen"

    def exchange_transform(self, context, reader, writer):
        """Sum rows in an uploaded table.

        All fields must be of an integer type, otherwise ArrowInvalid is
        raised. The reply is a table with a single column named "sum".
        """
        for field in reader.schema:
            if not pa.types.is_integer(field.type):
                raise pa.ArrowInvalid("Invalid field: " + repr(field))
        table = reader.read_all()
        sums = [0] * table.num_rows
        for column in table:
            for row, value in enumerate(column):
                sums[row] += value.as_py()
        result = pa.Table.from_arrays([pa.array(sums)], names=["sum"])
        writer.begin(result.schema)
        writer.write_table(result)
466
467
class HttpBasicServerAuthHandler(ServerAuthHandler):
    """Server side of an example HTTP basic authentication scheme.

    ``creds`` is a mapping of username to password.
    """

    def __init__(self, creds):
        super().__init__()
        self.creds = creds

    def authenticate(self, outgoing, incoming):
        """Check the BasicAuth payload and reply with the username."""
        auth = flight.BasicAuth.deserialize(incoming.read())
        username = auth.username
        if username not in self.creds:
            raise flight.FlightUnauthenticatedError("unknown user")
        expected_password = self.creds[username]
        if expected_password != auth.password:
            raise flight.FlightUnauthenticatedError("wrong password")
        outgoing.write(tobytes(username))

    def is_valid(self, token):
        """Return ``token`` if it is a known username, else raise."""
        if not token:
            raise flight.FlightUnauthenticatedError("token not provided")
        if token in self.creds:
            return token
        raise flight.FlightUnauthenticatedError("unknown user")
490
491
class HttpBasicClientAuthHandler(ClientAuthHandler):
    """Client side of an example HTTP basic authentication scheme."""

    def __init__(self, username, password):
        super().__init__()
        self.token = None
        self.basic_auth = flight.BasicAuth(username, password)

    def authenticate(self, outgoing, incoming):
        """Send the serialized credentials and keep the server's reply."""
        outgoing.write(self.basic_auth.serialize())
        self.token = incoming.read()

    def get_token(self):
        """Return the token received during ``authenticate`` (or None)."""
        return self.token
507
508
class TokenServerAuthHandler(ServerAuthHandler):
    """Server side of an example handshake-based authentication.

    ``creds`` is a mapping of username to password. A successful
    handshake issues the base64 encoding of ``b'secret:' + username``.
    """

    def __init__(self, creds):
        super().__init__()
        self.creds = creds

    def authenticate(self, outgoing, incoming):
        """Read username then password and reply with a token."""
        username = incoming.read()
        password = incoming.read()
        if username not in self.creds or self.creds[username] != password:
            raise flight.FlightUnauthenticatedError(
                "invalid username/password")
        outgoing.write(base64.b64encode(b'secret:' + username))

    def is_valid(self, token):
        """Decode ``token`` and return the part after the prefix."""
        prefix = b'secret:'
        decoded = base64.b64decode(token)
        if decoded.startswith(prefix):
            return decoded[len(prefix):]
        raise flight.FlightUnauthenticatedError("invalid token")
530
531
class TokenClientAuthHandler(ClientAuthHandler):
    """Client side of an example handshake-based authentication."""

    def __init__(self, username, password):
        super().__init__()
        self.username, self.password = username, password
        # Empty until ``authenticate`` has stored the server's reply.
        self.token = b''

    def authenticate(self, outgoing, incoming):
        """Send username then password and keep the server's reply."""
        for credential in (self.username, self.password):
            outgoing.write(credential)
        self.token = incoming.read()

    def get_token(self):
        """Return the token stored by ``authenticate``."""
        return self.token
548
549
class NoopAuthHandler(ServerAuthHandler):
    """Server auth handler that performs no checks at all."""

    def authenticate(self, outgoing, incoming):
        """Perform no handshake work."""
        return None

    def is_valid(self, token):
        """Return an empty identity for any token.

        An empty string is used because returning None causes a TypeError.
        """
        identity = ""
        return identity
562
563
def case_insensitive_header_lookup(headers, lookup_key):
    """Return the value stored in ``headers`` under ``lookup_key``.

    Keys are compared case-insensitively and the first matching key (in
    iteration order) wins. Returns None when no key matches.
    """
    matching_keys = (
        name for name in headers if name.lower() == lookup_key.lower()
    )
    found = next(matching_keys, None)
    if found is None:
        return None
    return headers.get(found)
571
572
class ClientHeaderAuthMiddlewareFactory(ClientMiddlewareFactory):
    """Factory producing ClientHeaderAuthMiddleware instances.

    ``call_credential`` holds the value most recently stored through
    ``set_call_credential``; it starts out as an empty list.
    """

    def __init__(self):
        self.call_credential = list()

    def start_call(self, info):
        """Create a middleware that reports back to this factory."""
        middleware = ClientHeaderAuthMiddleware(self)
        return middleware

    def set_call_credential(self, call_credential):
        """Replace the stored call credential."""
        self.call_credential = call_credential
584
585
class ClientHeaderAuthMiddleware(ClientMiddleware):
    """
    ClientMiddleware that picks up the authorization header
    sent by the server.

    This is an example of a ClientMiddleware that can extract
    the bearer token authorization header from a HTTP header
    authentication enabled server.

    Parameters
    ----------
    factory : ClientHeaderAuthMiddlewareFactory
        The factory that receives the call credential built from the
        authorization header found in the server's headers.
    """

    def __init__(self, factory):
        self.factory = factory

    def received_headers(self, headers):
        """Store the first Authorization value on the factory."""
        auth_header = case_insensitive_header_lookup(headers, 'Authorization')
        credential = [b'authorization', auth_header[0].encode("utf-8")]
        self.factory.set_call_credential(credential)
610
611
class HeaderAuthServerMiddlewareFactory(ServerMiddlewareFactory):
    """Validates incoming username and password.

    The Authorization header must be either ``Basic <base64>`` where the
    decoded credentials are ``test:password``, or ``Bearer token1234``.
    """

    def start_call(self, info, headers):
        """Authenticate the call and return the middleware for it.

        Returns
        -------
        HeaderAuthServerMiddleware
            Middleware carrying the bearer token ``token1234``.

        Raises
        ------
        flight.FlightUnauthenticatedError
            If the header is missing, malformed, uses another scheme or
            carries wrong credentials.
        """
        error_message = 'Invalid credentials'
        auth_header = case_insensitive_header_lookup(
            headers,
            'Authorization'
        )
        if not auth_header:
            # Header absent (None) or present with an empty value list.
            raise flight.FlightUnauthenticatedError(error_message)

        values = auth_header[0].split(' ')
        if len(values) < 2:
            # A scheme without credentials, e.g. just "Basic".
            raise flight.FlightUnauthenticatedError(error_message)
        scheme, credentials = values[0], values[1]

        if scheme == 'Basic':
            try:
                decoded = base64.b64decode(credentials).decode("utf-8")
            except ValueError as exc:
                # binascii.Error (bad base64) and UnicodeDecodeError are
                # both ValueError subclasses.
                raise flight.FlightUnauthenticatedError(
                    error_message) from exc
            # Split at the first colon only: everything after it is the
            # password, so "test:password:x" is not "test"/"password".
            username, _, password = decoded.partition(':')
            if not (username == 'test' and password == 'password'):
                raise flight.FlightUnauthenticatedError(error_message)
            token = 'token1234'
        elif scheme == 'Bearer':
            token = credentials
            if token != 'token1234':
                raise flight.FlightUnauthenticatedError(error_message)
        else:
            raise flight.FlightUnauthenticatedError(error_message)

        return HeaderAuthServerMiddleware(token)
638
639
class HeaderAuthServerMiddleware(ServerMiddleware):
    """A ServerMiddleware that sends its token as a Bearer header."""

    def __init__(self, token):
        self.token = token

    def sending_headers(self):
        """Return the authorization header carrying the token."""
        bearer = 'Bearer ' + self.token
        return {'authorization': bearer}
648
649
class HeaderAuthFlightServer(FlightServerBase):
    """A Flight server that tests with basic token authentication."""

    def do_action(self, context, action):
        """Return the token from the "auth" middleware's outgoing header."""
        middleware = context.get_middleware("auth")
        if not middleware:
            raise flight.FlightUnauthenticatedError(
                'No token auth middleware found.')
        auth_header = case_insensitive_header_lookup(
            middleware.sending_headers(), 'Authorization')
        # The header reads "<scheme> <token>"; return the token part.
        token = auth_header.split(' ')[1]
        return [token.encode("utf-8")]
662
663
class ArbitraryHeadersServerMiddlewareFactory(ServerMiddlewareFactory):
    """Factory whose middleware carries the call's incoming headers."""

    def start_call(self, info, headers):
        """Wrap the incoming headers in a middleware instance."""
        middleware = ArbitraryHeadersServerMiddleware(headers)
        return middleware
669
670
class ArbitraryHeadersServerMiddleware(ServerMiddleware):
    """A ServerMiddleware that sends back the headers it was given."""

    def __init__(self, incoming):
        self.incoming = incoming

    def sending_headers(self):
        """Return the headers passed to the constructor, unchanged."""
        outgoing = self.incoming
        return outgoing
679
680
class ArbitraryHeadersFlightServer(FlightServerBase):
    """A Flight server that tests multiple arbitrary headers."""

    def do_action(self, context, action):
        """Return the first values of the two test headers as bytes."""
        middleware = context.get_middleware("arbitrary-headers")
        if not middleware:
            raise flight.FlightServerError("No headers middleware found")
        headers = middleware.sending_headers()
        values = []
        for name in ('test-header-1', 'test-header-2'):
            header = case_insensitive_header_lookup(headers, name)
            values.append(header[0].encode("utf-8"))
        return values
700
701
class HeaderServerMiddleware(ServerMiddleware):
    """Carry a per-call value that the RPC method body can read."""

    def __init__(self, special_value):
        # Read by the servers via ``middleware.special_value``.
        self.special_value = special_value
707
708
class HeaderServerMiddlewareFactory(ServerMiddlewareFactory):
    """Give every call a middleware holding the same hard-coded value."""

    def start_call(self, info, headers):
        special_value = "right value"
        return HeaderServerMiddleware(special_value)
714
715
class HeaderFlightServer(FlightServerBase):
    """Echo back the per-call hard-coded value."""

    def do_action(self, context, action):
        """Return the "test" middleware's value, or b"" without one."""
        middleware = context.get_middleware("test")
        payload = middleware.special_value.encode() if middleware else b""
        return [payload]
724
725
class MultiHeaderFlightServer(FlightServerBase):
    """Test sending/receiving multiple (binary-valued) headers."""

    def do_action(self, context, action):
        """Return the repr() of the client's headers, UTF-8 encoded."""
        client_headers = context.get_middleware("test").client_headers
        return [repr(client_headers).encode("utf-8")]
733
734
class SelectiveAuthServerMiddlewareFactory(ServerMiddlewareFactory):
    """Deny access to certain methods based on a header.

    ListActions is always allowed; every other method requires the
    "x-auth-token" header to carry the value "password".
    """

    def start_call(self, info, headers):
        if info.method == flight.FlightMethod.LIST_ACTIONS:
            # No auth needed
            return None

        values = headers.get("x-auth-token")
        if not values:
            raise flight.FlightUnauthenticatedError("No token")

        # Only the first value of the header is considered.
        token = values[0]
        if token == "password":
            return HeaderServerMiddleware(token)
        raise flight.FlightUnauthenticatedError("Invalid token")
752
753
class SelectiveAuthClientMiddlewareFactory(ClientMiddlewareFactory):
    """Attach a SelectiveAuthClientMiddleware to every call."""

    def start_call(self, info):
        middleware = SelectiveAuthClientMiddleware()
        return middleware
757
758
class SelectiveAuthClientMiddleware(ClientMiddleware):
    """Send a fixed "x-auth-token" header with each call."""

    def sending_headers(self):
        headers = {"x-auth-token": "password"}
        return headers
764
765
class RecordingServerMiddlewareFactory(ServerMiddlewareFactory):
    """Collect the method of every call seen, in order, in ``methods``."""

    def __init__(self):
        super().__init__()
        self.methods = list()

    def start_call(self, info, headers):
        # Only record the call; no middleware instance is returned.
        self.methods.append(info.method)
776
777
class RecordingClientMiddlewareFactory(ClientMiddlewareFactory):
    """Collect the method of every call started, in order, in ``methods``."""

    def __init__(self):
        super().__init__()
        self.methods = list()

    def start_call(self, info):
        # Only record the call; no middleware instance is returned.
        self.methods.append(info.method)
788
789
class MultiHeaderClientMiddlewareFactory(ClientMiddlewareFactory):
    """Test sending/receiving multiple (binary-valued) headers."""

    def __init__(self):
        # Overwritten by the middleware instances with the headers they
        # receive; read in test_middleware_multi_header below.
        self.last_headers = dict()

    def start_call(self, info):
        """Create a middleware that reports back to this factory."""
        middleware = MultiHeaderClientMiddleware(self)
        return middleware
800
801
class MultiHeaderClientMiddleware(ClientMiddleware):
    """Test sending/receiving multiple (binary-valued) headers."""

    # Headers sent with every call: one text header and one binary header,
    # each with two values.
    EXPECTED = {
        "x-text": ["foo", "bar"],
        "x-binary-bin": [b"\x00", b"\x01"],
    }

    def __init__(self, factory):
        self.factory = factory

    def sending_headers(self):
        """Return the fixed EXPECTED headers."""
        outgoing = self.EXPECTED
        return outgoing

    def received_headers(self, headers):
        """Publish the received headers on the factory."""
        # Lets the test code inspect the last set of headers received.
        factory = self.factory
        factory.last_headers = headers
820
821
class MultiHeaderServerMiddlewareFactory(ServerMiddlewareFactory):
    """Test sending/receiving multiple (binary-valued) headers."""

    def start_call(self, info, headers):
        """Create a middleware that remembers the client's headers."""
        middleware = MultiHeaderServerMiddleware(headers)
        return middleware
827
828
class MultiHeaderServerMiddleware(ServerMiddleware):
    """Test sending/receiving multiple (binary-valued) headers."""

    def __init__(self, client_headers):
        # Headers received from the client for this call.
        self.client_headers = client_headers

    def sending_headers(self):
        """Reply with the same fixed headers the client middleware sends."""
        outgoing = MultiHeaderClientMiddleware.EXPECTED
        return outgoing
837
838
class LargeMetadataFlightServer(FlightServerBase):
    """Regression test for ARROW-13253."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # 2**31 + 1 bytes of spaces.
        self._metadata = b' ' * (2 ** 31 + 1)

    def do_get(self, context, ticket):
        """Stream a one-row batch together with the large metadata."""
        schema = pa.schema([('a', pa.int64())])
        batch = pa.record_batch([[1]], schema=schema)
        return flight.GeneratorStream(schema, [(batch, self._metadata)])

    def do_exchange(self, context, descriptor, reader, writer):
        """Write the large metadata back to the client."""
        metadata = self._metadata
        writer.write_metadata(metadata)
854
855
def test_flight_server_location_argument():
    """A server accepts None, a URI string or a (host, port) tuple."""
    free_port = find_free_port()
    for location in (None, 'grpc://localhost:0', ('localhost', free_port)):
        with FlightServerBase(location) as server:
            assert isinstance(server, FlightServerBase)
865
866
def test_server_exit_reraises_exception():
    """An exception raised inside the server context must propagate."""
    with pytest.raises(ValueError), FlightServerBase():
        raise ValueError()
871
872
@pytest.mark.slow
def test_client_wait_for_available():
    """wait_for_available() blocks until a late-starting server is up.

    The server is created in a background thread after a 0.5 second
    delay, so the call must take at least that long to return.
    """
    location = ('localhost', find_free_port())
    server = None

    def serve():
        # ``nonlocal`` (not ``global``) so the assignment lands in the
        # enclosing function's variable and the test can shut the server
        # down afterwards instead of leaking it into module globals.
        nonlocal server
        time.sleep(0.5)
        server = FlightServerBase(location)
        server.serve()

    client = FlightClient(location)
    thread = threading.Thread(target=serve, daemon=True)
    thread.start()

    try:
        started = time.time()
        client.wait_for_available(timeout=5)
        elapsed = time.time() - started
        assert elapsed >= 0.5
    finally:
        # Stop the background server so it does not outlive the test.
        # ``server`` is still None if the thread never got that far.
        if server is not None:
            server.shutdown()
            thread.join(timeout=5)
892
893
def test_flight_list_flights():
    """list_flights returns nothing by default and one hit for CRITERIA."""
    with ConstantFlightServer() as server:
        client = flight.connect(('localhost', server.port))
        unfiltered = list(client.list_flights())
        assert unfiltered == []
        matching = list(client.list_flights(ConstantFlightServer.CRITERIA))
        assert len(matching) == 1
901
902
def test_flight_do_get_ints():
    """do_get returns the ints table, with and without write options."""
    expected = simple_ints_table()

    # Default IPC write options.
    with ConstantFlightServer() as server:
        client = flight.connect(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b'ints'))
        assert reader.read_all().equals(expected)

    # Explicit IPC write options (metadata version V4).
    v4_options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)
    with ConstantFlightServer(options=v4_options) as server:
        client = flight.connect(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b'ints'))
        assert reader.read_all().equals(expected)

        # Also test via RecordBatchReader interface
        batch_reader = client.do_get(flight.Ticket(b'ints')).to_reader()
        assert batch_reader.read_all().equals(expected)

    # Options of the wrong type surface as a server error on the client.
    with pytest.raises(flight.FlightServerError,
                       match="expected IpcWriteOptions, got <class 'int'>"):
        with ConstantFlightServer(options=42) as server:
            client = flight.connect(('localhost', server.port))
            client.do_get(flight.Ticket(b'ints')).read_all()
928
929
@pytest.mark.pandas
def test_do_get_ints_pandas():
    """Read a do_get stream into pandas and compare the integer column."""
    expected_ints = simple_ints_table().column(0).to_pylist()

    with ConstantFlightServer() as server:
        client = flight.connect(('localhost', server.port))
        frame = client.do_get(flight.Ticket(b'ints')).read_pandas()
        assert list(frame['some_ints']) == expected_ints
939
940
def test_flight_do_get_dicts():
    """Fetch the table served for the b'dicts' ticket via do_get."""
    expected = simple_dicts_table()

    with ConstantFlightServer() as server:
        client = flight.connect(('localhost', server.port))
        fetched = client.do_get(flight.Ticket(b'dicts')).read_all()
        assert fetched.equals(expected)
948
949
def test_flight_do_get_ticket():
    """Check that the ticket sent by the client reaches the server."""
    column = pa.array([-10, -5, 0, 5, 10], type=pa.int32())
    expected = pa.Table.from_arrays([column], names=['a'])
    with CheckTicketFlightServer(expected_ticket=b'the-ticket') as server:
        client = flight.connect(('localhost', server.port))
        received = client.do_get(flight.Ticket(b'the-ticket')).read_all()
        assert received.equals(expected)
958
959
def test_flight_get_info():
    """Make sure FlightEndpoint accepts string and object URIs."""
    with GetInfoFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_command(b'')
        info = client.get_flight_info(descriptor)

        # Both totals are expected to be -1.
        assert info.total_records == -1
        assert info.total_bytes == -1
        assert info.schema == pa.schema([('a', pa.int32())])

        endpoints = info.endpoints
        assert len(endpoints) == 2

        # First endpoint: location compared against a URI string.
        first_locations = endpoints[0].locations
        assert len(first_locations) == 1
        assert first_locations[0] == flight.Location('grpc://test')

        # Second endpoint: location compared against a Location object.
        second_location = endpoints[1].locations[0]
        expected = flight.Location.for_grpc_tcp('localhost', 5005)
        assert second_location == expected
973
974
def test_flight_get_schema():
    """Make sure GetSchema returns the expected single-column schema."""
    expected_schema = pa.schema([('a', pa.int32())])
    with GetInfoFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_command(b'')
        result = client.get_schema(descriptor)
        assert result.schema == expected_schema
981
982
def test_list_actions():
    """Make sure ListActions validates and returns its results."""
    # ARROW-6392
    expected_error = "Results of list_actions must be ActionType or tuple"
    with ListActionsErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        with pytest.raises(flight.FlightServerError, match=expected_error):
            list(client.list_actions())

    with ListActionsFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        listed = list(client.list_actions())
        assert listed == ListActionsFlightServer.expected_actions()
999
1000
class ConvenienceServer(FlightServerBase):
    """
    Server exercising implementation conveniences such as auto-boxing.
    """

    @property
    def simple_action_results(self):
        """Bodies returned for the 'simple-action' action type."""
        return [b'foo', b'bar', b'baz']

    def do_action(self, context, action):
        """Dispatch on ``action.type``.

        Unrecognised action types fall through and return None.
        """
        kind = action.type
        if kind == 'simple-action':
            return self.simple_action_results
        if kind == 'echo':
            # Send the request body straight back.
            return [action.body]
        if kind == 'bad-action':
            # A str rather than bytes.
            return ['foo']
        if kind == 'arrow-exception':
            raise pa.ArrowMemoryError()
        return None
1019
1020
def test_do_action_result_convenience():
    """Call do_action with a bare type and with a (type, body) tuple."""
    with ConvenienceServer() as server:
        client = FlightClient(('localhost', server.port))

        # Action given only by its type, without a body.
        bodies = [res.body for res in client.do_action('simple-action')]
        assert bodies == server.simple_action_results

        # Action given as a (type, body) tuple.
        payload = b'the-body'
        echoed = [res.body for res in client.do_action(('echo', payload))]
        assert echoed == [payload]
1033
1034
def test_nicer_server_exceptions():
    """Check that server-side failures arrive as FlightServerError."""
    cases = [
        ('bad-action', "a bytes-like object is required"),
        # While Flight/C++ sends across the original status code, it
        # doesn't get mapped to the equivalent code here, since we
        # want to be able to distinguish between client- and server-
        # side errors.
        ('arrow-exception', "ArrowMemoryError"),
    ]
    with ConvenienceServer() as server:
        client = FlightClient(('localhost', server.port))
        for action_type, message in cases:
            with pytest.raises(flight.FlightServerError, match=message):
                list(client.do_action(action_type))
1048
1049
def test_get_port():
    """A server created on port 0 must report a positive port."""
    bound_server = GetInfoFlightServer("grpc://localhost:0")
    try:
        reported_port = bound_server.port
        assert reported_port > 0
    finally:
        # Shut down even when the assertion fails.
        bound_server.shutdown()
1057
1058
@pytest.mark.skipif(os.name == 'nt',
                    reason="Unix sockets can't be tested on Windows")
def test_flight_domain_socket():
    """Run do_get for both constant tables over a Unix domain socket."""
    # (ticket, function building the expected table)
    expectations = (
        (b'ints', simple_ints_table),
        (b'dicts', simple_dicts_table),
    )
    with tempfile.NamedTemporaryFile() as sock:
        # Only the temporary file's name is used for the socket path.
        sock.close()
        location = flight.Location.for_grpc_unix(sock.name)
        with ConstantFlightServer(location=location):
            client = FlightClient(location)

            for ticket_bytes, build_expected in expectations:
                reader = client.do_get(flight.Ticket(ticket_bytes))
                expected = build_expected()
                assert reader.schema.equals(expected.schema)
                received = reader.read_all()
                assert received.equals(expected)
1080
1081
@pytest.mark.slow
def test_flight_large_message():
    """Round-trip a single very large record batch through Flight.

    See ARROW-4421: by default, gRPC won't allow us to send messages >
    4MiB in size.
    """
    row_count = 10 * 1024 * 1024
    column = pa.array(range(0, row_count))
    data = pa.Table.from_arrays([column], names=['a'])

    with EchoFlightServer(expected_schema=data.schema) as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_path('test')
        writer, _ = client.do_put(descriptor, data.schema)
        # A chunk size equal to the row count yields one giant chunk.
        writer.write_table(data, row_count)
        writer.close()
        echoed = client.do_get(flight.Ticket(b'')).read_all()
        assert echoed.equals(data)
1102
1103
def test_flight_generator_stream():
    """Upload a table, then download it again from the echo server."""
    column = pa.array(range(0, 10 * 1024))
    data = pa.Table.from_arrays([column], names=['a'])

    with EchoStreamFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_path('test')
        writer, _ = client.do_put(descriptor, data.schema)
        writer.write_table(data)
        writer.close()
        echoed = client.do_get(flight.Ticket(b'')).read_all()
        assert echoed.equals(data)
1118
1119
def test_flight_invalid_generator_stream():
    """Reading a stream with mismatched schemas must raise."""
    with InvalidStreamFlightServer() as server:
        address = ('localhost', server.port)
        client = FlightClient(address)
        # Both the call and the read stay inside the raises block.
        with pytest.raises(pa.ArrowException):
            client.do_get(flight.Ticket(b'')).read_all()
1126
1127
def test_timeout_fires():
    """Make sure a 0.2 second timeout fires on a slow request."""
    with SlowFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        slow_action = flight.Action("", b"")
        short_timeout = flight.FlightCallOptions(timeout=0.2)
        # gRPC error messages differ between versions, so only the
        # exception type is checked.
        with pytest.raises(flight.FlightTimedOutError):
            list(client.do_action(slow_action, options=short_timeout))
1140
1141
def test_timeout_passes():
    """Make sure a 5 second timeout does not fire on a fast request."""
    with ConstantFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        generous_timeout = flight.FlightCallOptions(timeout=5.0)
        # Completing without an exception is the whole check.
        client.do_get(
            flight.Ticket(b'ints'), options=generous_timeout).read_all()
1148
1149
# Server-side auth handlers used by the authentication tests. Both are
# built with the single credential pair b"test" / b"p4ssw0rd".
basic_auth_handler = HttpBasicServerAuthHandler(
    creds={b"test": b"p4ssw0rd"})

token_auth_handler = TokenServerAuthHandler(
    creds={b"test": b"p4ssw0rd"})
1157
1158
@pytest.mark.slow
def test_http_basic_unauth():
    """A call made without authenticating first must be rejected."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        # client.authenticate() is deliberately never called.
        with pytest.raises(flight.FlightUnauthenticatedError,
                           match=".*unauthenticated.*"):
            list(client.do_action(who_am_i))
1168
1169
@pytest.mark.skipif(os.name == 'nt',
                    reason="ARROW-10013: gRPC on Windows corrupts peer()")
def test_http_basic_auth():
    """Authenticate with HTTP basic auth and check the echoed identity."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        client.authenticate(HttpBasicClientAuthHandler('test', 'p4ssw0rd'))
        responses = client.do_action(who_am_i)

        # First result: the authenticated user name.
        identity = next(responses)
        assert identity.body.to_pybytes() == b'test'

        # Second result: the peer address, which must not be empty.
        peer_address = next(responses)
        assert peer_address.body.to_pybytes() != b''
1183
1184
def test_http_basic_auth_invalid_password():
    """A wrong password must produce an unauthenticated error."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        bad_credentials = HttpBasicClientAuthHandler('test', 'wrong')
        # Either the handshake or the following call may raise, so both
        # stay inside the raises block.
        with pytest.raises(flight.FlightUnauthenticatedError,
                           match=".*wrong password.*"):
            client.authenticate(bad_credentials)
            next(client.do_action(who_am_i))
1194
1195
def test_token_auth():
    """Authenticate via the token handshake and check the identity."""
    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        client.authenticate(TokenClientAuthHandler('test', 'p4ssw0rd'))
        first_result = next(client.do_action(who_am_i))
        assert first_result.body.to_pybytes() == b'test'
1204
1205
def test_token_auth_invalid():
    """The token handshake must fail when given a wrong password."""
    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        bad_credentials = TokenClientAuthHandler('test', 'wrong')
        with pytest.raises(flight.FlightUnauthenticatedError):
            client.authenticate(bad_credentials)
1212
1213
# Module-level header-auth middleware factory.
# NOTE(review): the header-auth tests directly below build their own
# HeaderAuthServerMiddlewareFactory instances instead of using this one;
# check whether it is still referenced elsewhere.
header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory()
# Auth handler passed as ``auth_handler`` to the header-auth servers below.
no_op_auth_handler = NoopAuthHandler()
1216
1217
def test_authenticate_basic_token():
    """Test authenticate_basic_token with bearer token and auth headers."""
    server_middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=server_middleware) as server:
        client = FlightClient(('localhost', server.port))
        auth_header = client.authenticate_basic_token(b'test', b'password')
        # The result is a (header name, header value) pair.
        assert auth_header[0] == b'authorization'
        assert auth_header[1] == b'Bearer token1234'
1227
1228
def test_authenticate_basic_token_invalid_password():
    """authenticate_basic_token must reject an invalid password."""
    server_middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=server_middleware) as server:
        client = FlightClient(('localhost', server.port))
        with pytest.raises(flight.FlightUnauthenticatedError):
            client.authenticate_basic_token(b'test', b'badpassword')
1237
1238
def test_authenticate_basic_token_and_action():
    """Authenticate, then run doAction with the returned bearer header."""
    server_middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=server_middleware) as server:
        client = FlightClient(('localhost', server.port))
        auth_header = client.authenticate_basic_token(b'test', b'password')
        assert auth_header[0] == b'authorization'
        assert auth_header[1] == b'Bearer token1234'

        # Pass the returned pair back as a call header.
        call_options = flight.FlightCallOptions(headers=[auth_header])
        responses = list(client.do_action(
            action=flight.Action('test-action', b''), options=call_options))
        assert responses[0].body.to_pybytes() == b'token1234'
1252
1253
def test_authenticate_basic_token_with_client_middleware():
    """Test basic-token authentication with client middleware.

    The client middleware intercepts the authorization header returned
    by the HTTP header auth enabled server.
    """
    server_middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=server_middleware) as server:
        client_auth_middleware = ClientHeaderAuthMiddlewareFactory()
        client = FlightClient(
            ('localhost', server.port),
            middleware=[client_auth_middleware]
        )
        basic_credentials = b'Basic ' + base64.b64encode(b'test:password')
        options = flight.FlightCallOptions(
            headers=[(b'authorization', basic_credentials)])

        # The same call is made twice; both times the middleware must
        # hold the bearer credential afterwards.
        for _ in range(2):
            responses = list(client.do_action(
                action=flight.Action('test-action', b''), options=options))
            assert responses[0].body.to_pybytes() == b'token1234'
            captured = client_auth_middleware.call_credential
            assert captured[0] == b'authorization'
            assert captured[1] == b'Bearer token1234'
1283
1284
def test_arbitrary_headers_in_flight_call_options():
    """Send two extra headers and expect their values echoed in order."""
    server_middleware = {
        "auth": HeaderAuthServerMiddlewareFactory(),
        "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory()
    }
    with ArbitraryHeadersFlightServer(
            auth_handler=no_op_auth_handler,
            middleware=server_middleware) as server:
        client = FlightClient(('localhost', server.port))
        auth_header = client.authenticate_basic_token(b'test', b'password')
        assert auth_header[0] == b'authorization'
        assert auth_header[1] == b'Bearer token1234'

        # The bearer header goes first, followed by the extra headers.
        call_headers = [
            auth_header,
            (b'test-header-1', b'value1'),
            (b'test-header-2', b'value2')
        ]
        call_options = flight.FlightCallOptions(headers=call_headers)
        responses = list(client.do_action(
            flight.Action("test-action", b""), options=call_options))
        assert responses[0].body.to_pybytes() == b'value1'
        assert responses[1].body.to_pybytes() == b'value2'
1306
1307
def test_location_invalid():
    """Both client and server must reject an unparseable URI."""
    # Each callable receives the invalid location "%".
    for factory in (flight.connect, ConstantFlightServer):
        with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"):
            factory("%")
1315
1316
def test_location_unknown_scheme():
    """Locations with non-gRPC schemes keep their URI unchanged."""
    cases = [
        ("s3://foo", b"s3://foo"),
        ("https://example.com/bar.parquet",
         b"https://example.com/bar.parquet"),
    ]
    for uri, expected_bytes in cases:
        assert flight.Location(uri).uri == expected_bytes
1322
1323
@pytest.mark.slow
@pytest.mark.requires_testing_data
def test_tls_fails():
    """Make sure clients cannot connect when cert verification fails."""
    server_certs = example_tls_certs()["certificates"]

    with ConstantFlightServer(tls_certificates=server_certs) as s:
        # The client is created without tls_root_certs. This test is
        # slow because gRPC retries a few times before giving up.
        uri = "grpc+tls://localhost:{}".format(s.port)
        client = FlightClient(uri)

        # gRPC error messages differ between versions, so only the
        # exception type is checked.
        with pytest.raises(flight.FlightUnavailableError):
            client.do_get(flight.Ticket(b'ints')).read_all()
1339
1340
@pytest.mark.requires_testing_data
def test_tls_do_get():
    """Fetch the b'ints' table over TLS using the example root cert."""
    expected = simple_ints_table()
    tls = example_tls_certs()

    with ConstantFlightServer(tls_certificates=tls["certificates"]) as s:
        client = FlightClient(('localhost', s.port),
                              tls_root_certs=tls["root_cert"])
        fetched = client.do_get(flight.Ticket(b'ints')).read_all()
        assert fetched.equals(expected)
1352
1353
@pytest.mark.requires_testing_data
def test_tls_disable_server_verification():
    """Fetch the b'ints' table over TLS without verifying the server."""
    expected = simple_ints_table()
    tls = example_tls_certs()

    with ConstantFlightServer(tls_certificates=tls["certificates"]) as s:
        address = ('localhost', s.port)
        try:
            client = FlightClient(address,
                                  disable_server_verification=True)
        except NotImplementedError:
            # Skip when client construction reports the option as
            # unimplemented.
            pytest.skip('disable_server_verification feature is not available')
        fetched = client.do_get(flight.Ticket(b'ints')).read_all()
        assert fetched.equals(expected)
1368
1369
@pytest.mark.requires_testing_data
def test_tls_override_hostname():
    """Check that incorrectly overriding the hostname fails."""
    tls = example_tls_certs()

    with ConstantFlightServer(tls_certificates=tls["certificates"]) as s:
        address = ('localhost', s.port)
        client = flight.connect(address,
                                tls_root_certs=tls["root_cert"],
                                override_hostname="fakehostname")
        # The do_get call itself is expected to raise; nothing is read.
        with pytest.raises(flight.FlightUnavailableError):
            client.do_get(flight.Ticket(b'ints'))
1381
1382
def test_flight_do_get_metadata():
    """Read a do_get stream chunk by chunk and check each metadata index."""
    expected = pa.Table.from_arrays(
        [pa.array([-10, -5, 0, 5, 10])], names=['a'])

    received_batches = []
    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b''))
        expected_idx = 0
        while True:
            try:
                batch, metadata = reader.read_chunk()
            except StopIteration:
                # End of stream.
                break
            received_batches.append(batch)
            # Each chunk's metadata is a little-endian int32 that must
            # equal the chunk's position in the stream.
            server_idx, = struct.unpack('<i', metadata.to_pybytes())
            assert expected_idx == server_idx
            expected_idx += 1
        received = pa.Table.from_batches(received_batches)
        assert received.equals(expected)
1406
1407
def test_flight_do_get_metadata_v4():
    """Fetch the metadata server's table written with metadata V4."""
    column = pa.array([-10, -5, 0, 5, 10])
    expected = pa.Table.from_arrays([column], names=['a'])
    v4_options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)
    with MetadataFlightServer(options=v4_options) as server:
        client = FlightClient(('localhost', server.port))
        stream = client.do_get(flight.Ticket(b''))
        received = stream.read_all()
        assert received.equals(expected)
1419
1420
def test_flight_do_put_metadata():
    """Upload batches with an index as metadata and check the replies."""
    column = pa.array([-10, -5, 0, 5, 10])
    table = pa.Table.from_arrays([column], names=['a'])
    # Indices travel as little-endian 32-bit signed integers.
    index_format = struct.Struct('<i')

    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_path('')
        writer, metadata_reader = client.do_put(descriptor, table.schema)
        with writer:
            single_row_batches = table.to_batches(max_chunksize=1)
            for idx, batch in enumerate(single_row_batches):
                writer.write_with_metadata(batch, index_format.pack(idx))
                reply = metadata_reader.read()
                assert reply is not None
                server_idx, = index_format.unpack(reply.to_pybytes())
                assert idx == server_idx
1441
1442
def test_flight_do_put_limit():
    """Try a simple do_put call with a size limit.

    The client is created with ``write_size_limit_bytes=4096``. The full
    batch holds 768 int64 values (6144 bytes of data) and must be
    rejected; its two 384-value halves (3072 bytes each) must be
    accepted, after which the server is expected to return the complete
    data.
    """
    large_batch = pa.RecordBatch.from_arrays([
        # Pass the scalar type itself as dtype. ``np.int64()`` would
        # build a scalar instance (the value 0) and only work because
        # NumPy falls back to that instance's ``dtype`` attribute.
        pa.array(np.ones(768, dtype=np.int64)),
    ], names=['a'])

    with EchoFlightServer() as server:
        client = FlightClient(('localhost', server.port),
                              write_size_limit_bytes=4096)
        writer, metadata_reader = client.do_put(
            flight.FlightDescriptor.for_path(''),
            large_batch.schema)
        with writer:
            with pytest.raises(flight.FlightWriteSizeExceededError,
                               match="exceeded soft limit") as excinfo:
                writer.write_batch(large_batch)
            # The error carries the configured limit.
            assert excinfo.value.limit == 4096
            # The writer must still be usable after the rejected write.
            smaller_batches = [
                large_batch.slice(0, 384),
                large_batch.slice(384),
            ]
            for batch in smaller_batches:
                writer.write_batch(batch)
        expected = pa.Table.from_batches([large_batch])
        actual = client.do_get(flight.Ticket(b'')).read_all()
        assert expected == actual
1469
1470
@pytest.mark.slow
def test_cancel_do_get():
    """Cancel a DoGet stream on the client, then expect reads to fail."""
    with ConstantFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        stream = client.do_get(flight.Ticket(b'ints'))
        # Cancel before anything has been read.
        stream.cancel()
        with pytest.raises(flight.FlightCancelledError, match=".*Cancel.*"):
            stream.read_chunk()
1480
1481
@pytest.mark.slow
def test_cancel_do_get_threaded():
    """Test canceling a DoGet operation from another thread.

    A background thread reads one chunk, waits until the main thread has
    cancelled the stream, then reads again and records whether that
    second read raised FlightCancelledError.
    """
    with SlowFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b'ints'))

        # Set by the reader thread once its first read_chunk() returned.
        read_first_message = threading.Event()
        # Set by the main thread after reader.cancel() was called.
        stream_canceled = threading.Event()
        # Held while setting and while checking the result event below.
        result_lock = threading.Lock()
        # Set only if the second read raised FlightCancelledError.
        raised_proper_exception = threading.Event()

        def block_read():
            reader.read_chunk()
            read_first_message.set()
            # Wait (at most 5 seconds) for the main thread to cancel.
            stream_canceled.wait(timeout=5)
            try:
                reader.read_chunk()
            except flight.FlightCancelledError:
                with result_lock:
                    raised_proper_exception.set()

        thread = threading.Thread(target=block_read, daemon=True)
        thread.start()
        # NOTE(review): the return value of wait() is not checked, so a
        # timeout here goes unnoticed and the cancel below still runs.
        read_first_message.wait(timeout=5)
        reader.cancel()
        stream_canceled.set()
        # Bounded join: a stuck reader thread must not hang the test.
        thread.join(timeout=1)

        with result_lock:
            assert raised_proper_exception.is_set()
1513
1514
def test_roundtrip_types():
    """Make sure serializable Flight types survive a round trip."""
    ticket = flight.Ticket("foo")
    restored_ticket = flight.Ticket.deserialize(ticket.serialize())
    assert ticket == restored_ticket

    # Descriptors: one built from a command, one built from a path.
    command_desc = flight.FlightDescriptor.for_command("test")
    path_desc = flight.FlightDescriptor.for_path("a", "b", "test.arrow")
    for descriptor in (command_desc, path_desc):
        restored = flight.FlightDescriptor.deserialize(descriptor.serialize())
        assert descriptor == restored

    # FlightInfo with one string location and one Location object.
    endpoints = [
        flight.FlightEndpoint(b'', ['grpc://test']),
        flight.FlightEndpoint(
            b'',
            [flight.Location.for_grpc_tcp('localhost', 5005)],
        ),
    ]
    info = flight.FlightInfo(
        pa.schema([('a', pa.int32())]),
        path_desc,
        endpoints,
        -1,
        -1,
    )
    restored_info = flight.FlightInfo.deserialize(info.serialize())
    compared_attributes = (
        'schema',
        'descriptor',
        'total_bytes',
        'total_records',
        'endpoints',
    )
    for attribute in compared_attributes:
        assert getattr(info, attribute) == getattr(restored_info, attribute)
1545
1546
def test_roundtrip_errors():
    """Ensure that Flight errors propagate from server to client."""
    with ErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))

        # Action type / DoPut command -> exception expected on the client.
        error_types = {
            'internal': flight.FlightInternalError,
            'timedout': flight.FlightTimedOutError,
            'cancel': flight.FlightCancelledError,
            'unauthenticated': flight.FlightUnauthenticatedError,
            'unauthorized': flight.FlightUnauthorizedError,
        }

        # DoAction
        for action_type, error_type in error_types.items():
            with pytest.raises(error_type, match=".*foo.*"):
                list(client.do_action(flight.Action(action_type, b"")))

        # ListFlights
        with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
            list(client.list_flights())

        column = pa.array([-10, -5, 0, 5, 10])
        table = pa.Table.from_arrays([column], names=['a'])

        # DoPut, once with data written and once closed immediately.
        for command, error_type in error_types.items():

            with pytest.raises(error_type, match=".*foo.*"):
                writer, _ = client.do_put(
                    flight.FlightDescriptor.for_command(command),
                    table.schema)
                writer.write_table(table)
                writer.close()

            with pytest.raises(error_type, match=".*foo.*"):
                writer, _ = client.do_put(
                    flight.FlightDescriptor.for_command(command),
                    table.schema)
                writer.close()
1590
1591
def test_do_put_independent_read_write():
    """Ensure that separate threads can read/write on a DoPut.

    The main thread writes batches with metadata while a second thread
    drains the metadata reader; the number of metadata messages read
    must equal the number of batches written.
    """
    # ARROW-6063: previously this would cause gRPC to abort when the
    # writer was closed (due to simultaneous reads), or would hang
    # forever.
    data = [
        pa.array([-10, -5, 0, 5, 10])
    ]
    table = pa.Table.from_arrays(data, names=['a'])

    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        writer, metadata_reader = client.do_put(
            flight.FlightDescriptor.for_path(''),
            table.schema)

        # One-element list so the reader thread can update the counter
        # in place.
        count = [0]

        def _reader_thread():
            # read() returning None ends the loop.
            while metadata_reader.read() is not None:
                count[0] += 1

        thread = threading.Thread(target=_reader_thread)
        thread.start()

        # One batch per row.
        batches = table.to_batches(max_chunksize=1)
        with writer:
            for idx, batch in enumerate(batches):
                # Metadata is the batch index as a little-endian int32.
                metadata = struct.pack('<i', idx)
                writer.write_with_metadata(batch, metadata)
            # Causes the server to stop writing and end the call
            writer.done_writing()
            # Thus reader thread will break out of loop
            thread.join()
        # writer.close() won't segfault since reader thread has
        # stopped
        assert count[0] == len(batches)
1629
1630
def test_server_middleware_same_thread():
    """Ensure that server middleware run on the same thread as the RPC."""
    server_middleware = {"test": HeaderServerMiddlewareFactory()}
    with HeaderFlightServer(middleware=server_middleware) as server:
        client = FlightClient(('localhost', server.port))
        responses = list(client.do_action(flight.Action(b"test", b"")))
        # Exactly one result, whose body must be the expected marker.
        assert len(responses) == 1
        body = responses[0].body.to_pybytes()
        assert b"right value" == body
1641
1642
def test_middleware_reject():
    """Test rejecting an RPC with server middleware."""
    server_middleware = {"test": SelectiveAuthServerMiddlewareFactory()}
    with HeaderFlightServer(middleware=server_middleware) as server:
        address = ('localhost', server.port)

        # Client without any middleware.
        client = FlightClient(address)
        # list_actions gets past the middleware without auth and fails
        # with a not-implemented error instead of an auth error.
        with pytest.raises(pa.ArrowNotImplementedError):
            list(client.list_actions())

        # Any other call is rejected as unauthenticated.
        with pytest.raises(flight.FlightUnauthenticatedError):
            list(client.do_action(flight.Action(b"", b"")))

        # With the client-side auth middleware the same action succeeds.
        client = FlightClient(
            address,
            middleware=[SelectiveAuthClientMiddlewareFactory()]
        )
        first_response = next(client.do_action(flight.Action(b"", b"")))
        assert b"password" == first_response.body.to_pybytes()
1663
1664
def test_middleware_mapping():
    """Check that middleware is told the right method for each RPC.

    Every RPC is issued against a bare ``FlightServerBase`` and is
    expected to raise ``NotImplementedError``; the recording middleware
    on both sides must still list the methods in call order.
    """
    server_recorder = RecordingServerMiddlewareFactory()
    client_recorder = RecordingClientMiddlewareFactory()
    # One entry per RPC issued below, in the same order.
    method = flight.FlightMethod
    expected = [
        method.LIST_FLIGHTS,
        method.GET_FLIGHT_INFO,
        method.GET_SCHEMA,
        method.DO_GET,
        method.DO_PUT,
        method.DO_ACTION,
        method.LIST_ACTIONS,
        method.DO_EXCHANGE,
    ]
    with FlightServerBase(middleware={"test": server_recorder}) as server:
        location = ('localhost', server.port)
        client = FlightClient(location, middleware=[client_recorder])
        command = flight.FlightDescriptor.for_command(b"")

        with pytest.raises(NotImplementedError):
            list(client.list_flights())
        with pytest.raises(NotImplementedError):
            client.get_flight_info(command)
        with pytest.raises(NotImplementedError):
            client.get_schema(command)
        with pytest.raises(NotImplementedError):
            client.do_get(flight.Ticket(b""))
        with pytest.raises(NotImplementedError):
            writer, _ = client.do_put(command, pa.schema([]))
            writer.close()
        with pytest.raises(NotImplementedError):
            list(client.do_action(flight.Action(b"", b"")))
        with pytest.raises(NotImplementedError):
            list(client.list_actions())
        with pytest.raises(NotImplementedError):
            writer, _ = client.do_exchange(command)
            writer.close()

        assert server_recorder.methods == expected
        assert client_recorder.methods == expected
1707
1708
def test_extra_info():
    """Check that an error's ``extra_info`` is available on the client.

    The ``protobuf`` action must fail with ``FlightUnauthorizedError``
    carrying ``b'this is an error message'`` as its ``extra_info``.
    """
    with ErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        # pytest.raises fails the test with a clear message when the
        # call does not raise, and does not rely on ``assert``
        # statements being enabled.
        with pytest.raises(flight.FlightUnauthorizedError) as exc_info:
            list(client.do_action(flight.Action("protobuf", b"")))
        extra_info = exc_info.value.extra_info
        assert extra_info is not None
        assert extra_info == b'this is an error message'
1719
1720
@pytest.mark.requires_testing_data
def test_mtls():
    """Exercise mutual TLS (mTLS) over gRPC.

    The same certificate/key pair is used for the server and as the
    client's own certificate chain and private key.
    """
    certs = example_tls_certs()
    expected = simple_ints_table()
    cert_pair = certs["certificates"][0]
    root_cert = certs["root_cert"]

    with ConstantFlightServer(
            tls_certificates=[cert_pair],
            verify_client=True,
            root_certificates=root_cert) as server:
        client = FlightClient(
            ('localhost', server.port),
            tls_root_certs=root_cert,
            cert_chain=cert_pair.cert,
            private_key=cert_pair.key)
        received = client.do_get(flight.Ticket(b'ints')).read_all()
        assert received.equals(expected)
1738
1739
def test_doexchange_get():
    """Use DoExchange to emulate DoGet."""
    column = pa.array(range(0, 10 * 1024))
    expected = pa.Table.from_arrays([column], names=["a"])

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        get_command = flight.FlightDescriptor.for_command(b"get")
        writer, reader = client.do_exchange(get_command)
        # Nothing is written; the whole table is simply read back.
        with writer:
            received = reader.read_all()
        assert expected == received
1753
1754
def test_doexchange_put():
    """Use DoExchange to emulate DoPut.

    After all batches are written, the reply must be a data-less chunk
    whose metadata is the number of batches as UTF-8 decimal text.
    """
    column = pa.array(range(0, 10 * 1024))
    data = pa.Table.from_arrays([column], names=["a"])
    batches = data.to_batches(max_chunksize=512)
    expected_count = str(len(batches)).encode("utf-8")

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        put_command = flight.FlightDescriptor.for_command(b"put")
        writer, reader = client.do_exchange(put_command)
        with writer:
            writer.begin(data.schema)
            for batch in batches:
                writer.write_batch(batch)
            writer.done_writing()
            reply = reader.read_chunk()
            assert reply.data is None
            assert reply.app_metadata == expected_count
1775
1776
def test_doexchange_echo():
    """Run a round trip against a DoExchange echo server.

    Three phases are exercised in order: metadata-only messages sent
    before the stream is begun, data without metadata, and data with
    metadata. Each message is read back right after it is written.
    """
    column = pa.array(range(0, 10 * 1024))
    data = pa.Table.from_arrays([column], names=["a"])
    batches = data.to_batches(max_chunksize=512)

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        echo_command = flight.FlightDescriptor.for_command(b"echo")
        writer, reader = client.do_exchange(echo_command)
        with writer:
            # Phase 1: metadata only, before any schema has been sent.
            for number in range(10):
                payload = str(number).encode("utf-8")
                writer.write_metadata(payload)
                echoed = reader.read_chunk()
                assert echoed.data is None
                assert echoed.app_metadata == payload

            # Phase 2: record batches carrying no metadata.
            writer.begin(data.schema)
            for batch in batches:
                writer.write_batch(batch)
                assert reader.schema == data.schema
                echoed = reader.read_chunk()
                assert echoed.data == batch
                assert echoed.app_metadata is None

            # Phase 3: record batches accompanied by metadata.
            for index, batch in enumerate(batches):
                payload = str(index).encode("utf-8")
                writer.write_with_metadata(batch, payload)
                echoed = reader.read_chunk()
                assert echoed.data == batch
                assert echoed.app_metadata == payload
1813
1814
def test_doexchange_echo_v4():
    """Run a DoExchange echo round trip with the V4 metadata version.

    The same IPC write options are given to the server and to the
    client writer.
    """
    column = pa.array(range(0, 10 * 1024))
    data = pa.Table.from_arrays([column], names=["a"])
    batches = data.to_batches(max_chunksize=512)

    v4_options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)
    with ExchangeFlightServer(options=v4_options) as server:
        client = FlightClient(("localhost", server.port))
        echo_command = flight.FlightDescriptor.for_command(b"echo")
        writer, reader = client.do_exchange(echo_command)
        with writer:
            # Send plain record batches and read each one straight back.
            writer.begin(data.schema, options=v4_options)
            for batch in batches:
                writer.write_batch(batch)
                assert reader.schema == data.schema
                echoed = reader.read_chunk()
                assert echoed.data == batch
                assert echoed.app_metadata is None
1837
1838
def test_doexchange_transform():
    """Have a service transform a table via DoExchange.

    The expected reply is a single ``sum`` column whose values equal
    ``a + b + c`` for every row of the uploaded table.
    """
    # Columns a, b, c hold 1024 consecutive ints starting at 0, 1, 2.
    columns = [pa.array(range(start, start + 1024)) for start in range(3)]
    data = pa.Table.from_arrays(columns, names=["a", "b", "c"])
    sums = pa.array(range(3, 1024 * 3 + 3, 3))
    expected = pa.Table.from_arrays([sums], names=["sum"])

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        transform = flight.FlightDescriptor.for_command(b"transform")
        writer, reader = client.do_exchange(transform)
        with writer:
            writer.begin(data.schema)
            writer.write_table(data)
            writer.done_writing()
            received = reader.read_all()
        assert expected == received
1860
1861
def test_middleware_multi_header():
    """Check sending and receiving multiple (binary-valued) headers."""
    server_factories = {"test": MultiHeaderServerMiddlewareFactory()}
    with MultiHeaderFlightServer(middleware=server_factories) as server:
        client_factory = MultiHeaderClientMiddlewareFactory()
        client = FlightClient(
            ('localhost', server.port),
            middleware=[client_factory]
        )
        reply = next(client.do_action(flight.Action(b"", b"")))
        # The reply body is a Python literal of the headers the server
        # received from us.
        body_text = reply.body.to_pybytes().decode("utf-8")
        echoed = ast.literal_eval(body_text)
        # Check only the headers we set; gRPC may add its own, such as
        # User-Agent, so a direct comparison would be too strict.
        for name, expected in MultiHeaderClientMiddleware.EXPECTED.items():
            assert echoed.get(name) == expected
            assert client_factory.last_headers.get(name) == expected
1877
1878
@pytest.mark.requires_testing_data
def test_generic_options():
    """Check that generic client options are passed through.

    Each case sets one option chosen so that a request must fail with
    a specific error.
    """
    certs = example_tls_certs()
    failing_cases = [
        # A string-valued option that makes requests fail.
        ([("grpc.ssl_target_name_override", "fakehostname")],
         flight.FlightUnavailableError),
        # An int-valued option that makes requests fail.
        ([("grpc.max_receive_message_length", 32)], pa.ArrowInvalid),
    ]

    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
        for options, expected_error in failing_cases:
            client = flight.connect(('localhost', s.port),
                                    tls_root_certs=certs["root_cert"],
                                    generic_options=options)
            with pytest.raises(expected_error):
                client.do_get(flight.Ticket(b'ints'))
1899
1900
class CancelFlightServer(FlightServerBase):
    """Server with never-ending streams, used for testing StopToken."""

    def do_get(self, context, ticket):
        """Return a stream that repeats one zero-column batch forever."""
        empty_schema = pa.schema([])
        empty_batch = pa.RecordBatch.from_arrays([], schema=empty_schema)
        endless = itertools.repeat(empty_batch)
        return flight.GeneratorStream(empty_schema, endless)

    def do_exchange(self, context, descriptor, reader, writer):
        """Write a zero-column batch every 0.5 s until cancelled."""
        empty_schema = pa.schema([])
        empty_batch = pa.RecordBatch.from_arrays([], schema=empty_schema)
        writer.begin(empty_schema)
        while True:
            if context.is_cancelled():
                break
            writer.write_batch(empty_batch)
            time.sleep(0.5)
1916
1917
def test_interrupt():
    """Check that SIGINT interrupts a blocking Flight ``read_all``.

    A helper thread raises SIGINT half a second after the read starts.
    The read must then fail with ``KeyboardInterrupt`` or
    ``pa.ArrowCancelled``. Both a DoGet and a DoExchange reader are
    exercised against ``CancelFlightServer``.
    """
    if threading.current_thread().ident != threading.main_thread().ident:
        pytest.skip("test only works from main Python thread")
    # Skips test if not available
    raise_signal = util.get_raise_signal()

    def signal_from_thread():
        # Give read_all() time to start blocking before interrupting it.
        time.sleep(0.5)
        raise_signal(signal.SIGINT)

    exc_types = (KeyboardInterrupt, pa.ArrowCancelled)

    def test(read_all):
        # Create the thread outside the ``try`` so that ``t`` is always
        # bound when the ``finally`` clause joins it.
        t = threading.Thread(target=signal_from_thread)
        try:
            try:
                with pytest.raises(exc_types) as exc_info:
                    t.start()
                    read_all()
            finally:
                t.join()
        except KeyboardInterrupt:
            # In case KeyboardInterrupt didn't interrupt read_all
            # above, at least prevent it from stopping the test suite
            pytest.fail("KeyboardInterrupt didn't interrupt Flight read_all")
        # The error may be raised directly, in which case __context__
        # is None; only a chained exception has to be of an accepted
        # type.
        context = exc_info.value.__context__
        assert context is None or isinstance(context, exc_types)

    with CancelFlightServer() as server:
        client = FlightClient(("localhost", server.port))

        reader = client.do_get(flight.Ticket(b""))
        test(reader.read_all)

        descriptor = flight.FlightDescriptor.for_command(b"echo")
        writer, reader = client.do_exchange(descriptor)
        test(reader.read_all)
1956
1957
def test_never_sends_data():
    """Regression test for ARROW-12779."""
    error_pattern = "application server implementation error"
    with NeverSendsDataFlightServer() as server:
        location = ('localhost', server.port)
        client = flight.connect(location)
        with pytest.raises(flight.FlightServerError, match=error_pattern):
            client.do_get(flight.Ticket(b'')).read_all()

        # The server handler is expected to skip over empty tables, up
        # to a certain extent, and still deliver the rows.
        result = client.do_get(flight.Ticket(b'yield_data')).read_all()
        assert result.num_rows == 5
1970
1971
@pytest.mark.large_memory
@pytest.mark.slow
def test_large_descriptor():
    """Regression test for ARROW-13253.

    Lives here, with the appropriate marks, because some CI pipelines
    cannot run the C++ equivalent.
    """
    command_size = 2 ** 31 + 1
    oversized = flight.FlightDescriptor.for_command(b' ' * command_size)
    error_pattern = "Failed to serialize Flight descriptor"
    with FlightServerBase() as server:
        client = flight.connect(('localhost', server.port))
        with pytest.raises(OSError, match=error_pattern):
            writer, _ = client.do_put(oversized, pa.schema([]))
            writer.close()
        with pytest.raises(pa.ArrowException, match=error_pattern):
            client.do_exchange(oversized)
1988
1989
@pytest.mark.large_memory
@pytest.mark.slow
def test_large_metadata_client():
    """Regression test for ARROW-13253 (oversized ``app_metadata``).

    First the client tries to write metadata of ``2 ** 31 + 1`` bytes
    via DoPut and DoExchange; then it reads from
    ``LargeMetadataFlightServer`` via DoGet and DoExchange. All four
    calls must fail with an "app_metadata size overflow" error.
    """
    overflow = "app_metadata size overflow"
    descriptor = flight.FlightDescriptor.for_command(b'')
    huge_metadata = b' ' * (2 ** 31 + 1)
    with EchoFlightServer() as server:
        client = flight.connect(('localhost', server.port))
        with pytest.raises(pa.ArrowCapacityError, match=overflow):
            writer, _ = client.do_put(descriptor, pa.schema([]))
            with writer:
                writer.write_metadata(huge_metadata)
                writer.close()
        with pytest.raises(pa.ArrowCapacityError, match=overflow):
            writer, reader = client.do_exchange(descriptor)
            with writer:
                writer.write_metadata(huge_metadata)

    # Drop the large buffer before the next server is started.
    del huge_metadata
    with LargeMetadataFlightServer() as server:
        client = flight.connect(('localhost', server.port))
        with pytest.raises(flight.FlightServerError, match=overflow):
            reader = client.do_get(flight.Ticket(b''))
            reader.read_all()
        with pytest.raises(pa.ArrowException, match=overflow):
            writer, reader = client.do_exchange(descriptor)
            with writer:
                reader.read_all()
2022
2023
class ActionNoneFlightServer(EchoFlightServer):
    """A server that implements a side effect to a non iterable action.

    The ``append`` action records a value and returns ``None`` instead
    of an iterable; ``get_value`` returns the recorded values as JSON.
    """
    # Class-level default, kept so that ``ActionNoneFlightServer.VALUES``
    # still resolves; every instance replaces it with its own list.
    VALUES = []

    def __init__(self, *args, **kwargs):
        """Create the server with its own, initially empty, value list.

        All arguments are forwarded unchanged to the parent constructor.
        """
        # A list stored only on the class would be shared by every
        # server instance, so values appended through one server would
        # leak into the next one created in the same process.
        self.VALUES = []
        super().__init__(*args, **kwargs)

    def do_action(self, context, action):
        """Handle the ``get_value`` and ``append`` actions.

        Returns a one-element list with the JSON-encoded values for
        ``get_value`` and ``None`` for ``append``.

        Raises ``NotImplementedError`` for any other action type.
        """
        if action.type == "get_value":
            return [json.dumps(self.VALUES).encode('utf-8')]
        elif action.type == "append":
            self.VALUES.append(True)
            return None
        raise NotImplementedError
2035
2036
def test_none_action_side_effect():
    """Check that an action runs even if its result is never consumed.

    See https://issues.apache.org/jira/browse/ARROW-14255
    """
    with ActionNoneFlightServer() as server:
        location = ('localhost', server.port)
        client = FlightClient(location)
        # The returned iterator is deliberately discarded.
        client.do_action(flight.Action("append", b""))
        results = client.do_action(flight.Action("get_value", b""))
        first = next(results)
        assert json.loads(first.body.to_pybytes()) == [True]
2048