1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #pragma once
19 
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "arrow/flight/api.h"
#include "arrow/ipc/dictionary.h"
#include "arrow/python/common.h"
27 
// Symbol-visibility macro for the Python Flight bindings.
//
// On Windows (native or Cygwin) ARROW_PYFLIGHT_EXPORT expands to nothing for
// static builds (ARROW_STATIC), to __declspec(dllexport) while building the
// library itself (ARROW_PYFLIGHT_EXPORTING), and to __declspec(dllimport)
// otherwise. On other platforms it requests default symbol visibility, unless
// a definition was already provided before this header was included.
#if defined(_WIN32) || defined(__CYGWIN__)  // Windows
// Silence warnings triggered by the exported classes declared below.
// NOTE(review): these pragmas have no push/pop, so the suppression remains in
// effect for every translation unit that includes this header.
#if defined(_MSC_VER)
// C4251: a member's type has no DLL interface (e.g. the std::function and
// std::unique_ptr members of the exported classes in this header).
#pragma warning(disable : 4251)
#else
// Non-MSVC Windows compilers: presumably silences -Wattributes diagnostics
// about the __declspec annotations -- confirm with the build logs.
#pragma GCC diagnostic ignored "-Wattributes"
#endif

#ifdef ARROW_STATIC
#define ARROW_PYFLIGHT_EXPORT
#elif defined(ARROW_PYFLIGHT_EXPORTING)
#define ARROW_PYFLIGHT_EXPORT __declspec(dllexport)
#else
#define ARROW_PYFLIGHT_EXPORT __declspec(dllimport)
#endif

#else  // Not Windows
// Respect a pre-existing definition (e.g. one supplied by the build system).
// NOTE(review): the Windows branch above has no such guard and always defines it.
#ifndef ARROW_PYFLIGHT_EXPORT
#define ARROW_PYFLIGHT_EXPORT __attribute__((visibility("default")))
#endif
#endif  // Non-Windows
48 
49 namespace arrow {
50 
51 namespace py {
52 
53 namespace flight {
54 
55 ARROW_PYFLIGHT_EXPORT
56 extern const char* kPyServerMiddlewareName;
57 
58 /// \brief A table of function pointers for calling from C++ into
59 /// Python.
60 class ARROW_PYFLIGHT_EXPORT PyFlightServerVtable {
61  public:
62   std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
63                        const arrow::flight::Criteria*,
64                        std::unique_ptr<arrow::flight::FlightListing>*)>
65       list_flights;
66   std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
67                        const arrow::flight::FlightDescriptor&,
68                        std::unique_ptr<arrow::flight::FlightInfo>*)>
69       get_flight_info;
70   std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
71                        const arrow::flight::FlightDescriptor&,
72                        std::unique_ptr<arrow::flight::SchemaResult>*)>
73       get_schema;
74   std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
75                        const arrow::flight::Ticket&,
76                        std::unique_ptr<arrow::flight::FlightDataStream>*)>
77       do_get;
78   std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
79                        std::unique_ptr<arrow::flight::FlightMessageReader>,
80                        std::unique_ptr<arrow::flight::FlightMetadataWriter>)>
81       do_put;
82   std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
83                        std::unique_ptr<arrow::flight::FlightMessageReader>,
84                        std::unique_ptr<arrow::flight::FlightMessageWriter>)>
85       do_exchange;
86   std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
87                        const arrow::flight::Action&,
88                        std::unique_ptr<arrow::flight::ResultStream>*)>
89       do_action;
90   std::function<Status(PyObject*, const arrow::flight::ServerCallContext&,
91                        std::vector<arrow::flight::ActionType>*)>
92       list_actions;
93 };
94 
95 class ARROW_PYFLIGHT_EXPORT PyServerAuthHandlerVtable {
96  public:
97   std::function<Status(PyObject*, arrow::flight::ServerAuthSender*,
98                        arrow::flight::ServerAuthReader*)>
99       authenticate;
100   std::function<Status(PyObject*, const std::string&, std::string*)> is_valid;
101 };
102 
103 class ARROW_PYFLIGHT_EXPORT PyClientAuthHandlerVtable {
104  public:
105   std::function<Status(PyObject*, arrow::flight::ClientAuthSender*,
106                        arrow::flight::ClientAuthReader*)>
107       authenticate;
108   std::function<Status(PyObject*, std::string*)> get_token;
109 };
110 
111 /// \brief A helper to implement an auth mechanism in Python.
112 class ARROW_PYFLIGHT_EXPORT PyServerAuthHandler
113     : public arrow::flight::ServerAuthHandler {
114  public:
115   explicit PyServerAuthHandler(PyObject* handler,
116                                const PyServerAuthHandlerVtable& vtable);
117   Status Authenticate(arrow::flight::ServerAuthSender* outgoing,
118                       arrow::flight::ServerAuthReader* incoming) override;
119   Status IsValid(const std::string& token, std::string* peer_identity) override;
120 
121  private:
122   OwnedRefNoGIL handler_;
123   PyServerAuthHandlerVtable vtable_;
124 };
125 
126 /// \brief A helper to implement an auth mechanism in Python.
127 class ARROW_PYFLIGHT_EXPORT PyClientAuthHandler
128     : public arrow::flight::ClientAuthHandler {
129  public:
130   explicit PyClientAuthHandler(PyObject* handler,
131                                const PyClientAuthHandlerVtable& vtable);
132   Status Authenticate(arrow::flight::ClientAuthSender* outgoing,
133                       arrow::flight::ClientAuthReader* incoming) override;
134   Status GetToken(std::string* token) override;
135 
136  private:
137   OwnedRefNoGIL handler_;
138   PyClientAuthHandlerVtable vtable_;
139 };
140 
141 class ARROW_PYFLIGHT_EXPORT PyFlightServer : public arrow::flight::FlightServerBase {
142  public:
143   explicit PyFlightServer(PyObject* server, const PyFlightServerVtable& vtable);
144 
145   // Like Serve(), but set up signals and invoke Python signal handlers
146   // if necessary.  This function may return with a Python exception set.
147   Status ServeWithSignals();
148 
149   Status ListFlights(const arrow::flight::ServerCallContext& context,
150                      const arrow::flight::Criteria* criteria,
151                      std::unique_ptr<arrow::flight::FlightListing>* listings) override;
152   Status GetFlightInfo(const arrow::flight::ServerCallContext& context,
153                        const arrow::flight::FlightDescriptor& request,
154                        std::unique_ptr<arrow::flight::FlightInfo>* info) override;
155   Status GetSchema(const arrow::flight::ServerCallContext& context,
156                    const arrow::flight::FlightDescriptor& request,
157                    std::unique_ptr<arrow::flight::SchemaResult>* result) override;
158   Status DoGet(const arrow::flight::ServerCallContext& context,
159                const arrow::flight::Ticket& request,
160                std::unique_ptr<arrow::flight::FlightDataStream>* stream) override;
161   Status DoPut(const arrow::flight::ServerCallContext& context,
162                std::unique_ptr<arrow::flight::FlightMessageReader> reader,
163                std::unique_ptr<arrow::flight::FlightMetadataWriter> writer) override;
164   Status DoExchange(const arrow::flight::ServerCallContext& context,
165                     std::unique_ptr<arrow::flight::FlightMessageReader> reader,
166                     std::unique_ptr<arrow::flight::FlightMessageWriter> writer) override;
167   Status DoAction(const arrow::flight::ServerCallContext& context,
168                   const arrow::flight::Action& action,
169                   std::unique_ptr<arrow::flight::ResultStream>* result) override;
170   Status ListActions(const arrow::flight::ServerCallContext& context,
171                      std::vector<arrow::flight::ActionType>* actions) override;
172 
173  private:
174   OwnedRefNoGIL server_;
175   PyFlightServerVtable vtable_;
176 };
177 
178 /// \brief A callback that obtains the next result from a Flight action.
179 typedef std::function<Status(PyObject*, std::unique_ptr<arrow::flight::Result>*)>
180     PyFlightResultStreamCallback;
181 
182 /// \brief A ResultStream built around a Python callback.
183 class ARROW_PYFLIGHT_EXPORT PyFlightResultStream : public arrow::flight::ResultStream {
184  public:
185   /// \brief Construct a FlightResultStream from a Python object and callback.
186   /// Must only be called while holding the GIL.
187   explicit PyFlightResultStream(PyObject* generator,
188                                 PyFlightResultStreamCallback callback);
189   Status Next(std::unique_ptr<arrow::flight::Result>* result) override;
190 
191  private:
192   OwnedRefNoGIL generator_;
193   PyFlightResultStreamCallback callback_;
194 };
195 
196 /// \brief A wrapper around a FlightDataStream that keeps alive a
197 /// Python object backing it.
198 class ARROW_PYFLIGHT_EXPORT PyFlightDataStream : public arrow::flight::FlightDataStream {
199  public:
200   /// \brief Construct a FlightDataStream from a Python object and underlying stream.
201   /// Must only be called while holding the GIL.
202   explicit PyFlightDataStream(PyObject* data_source,
203                               std::unique_ptr<arrow::flight::FlightDataStream> stream);
204 
205   std::shared_ptr<Schema> schema() override;
206   Status GetSchemaPayload(arrow::flight::FlightPayload* payload) override;
207   Status Next(arrow::flight::FlightPayload* payload) override;
208 
209  private:
210   OwnedRefNoGIL data_source_;
211   std::unique_ptr<arrow::flight::FlightDataStream> stream_;
212 };
213 
214 class ARROW_PYFLIGHT_EXPORT PyServerMiddlewareFactory
215     : public arrow::flight::ServerMiddlewareFactory {
216  public:
217   /// \brief A callback to create the middleware instance in Python
218   typedef std::function<Status(
219       PyObject*, const arrow::flight::CallInfo& info,
220       const arrow::flight::CallHeaders& incoming_headers,
221       std::shared_ptr<arrow::flight::ServerMiddleware>* middleware)>
222       StartCallCallback;
223 
224   /// \brief Must only be called while holding the GIL.
225   explicit PyServerMiddlewareFactory(PyObject* factory, StartCallCallback start_call);
226 
227   Status StartCall(const arrow::flight::CallInfo& info,
228                    const arrow::flight::CallHeaders& incoming_headers,
229                    std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override;
230 
231  private:
232   OwnedRefNoGIL factory_;
233   StartCallCallback start_call_;
234 };
235 
236 class ARROW_PYFLIGHT_EXPORT PyServerMiddleware : public arrow::flight::ServerMiddleware {
237  public:
238   typedef std::function<Status(PyObject*,
239                                arrow::flight::AddCallHeaders* outgoing_headers)>
240       SendingHeadersCallback;
241   typedef std::function<Status(PyObject*, const Status& status)> CallCompletedCallback;
242 
243   struct Vtable {
244     SendingHeadersCallback sending_headers;
245     CallCompletedCallback call_completed;
246   };
247 
248   /// \brief Must only be called while holding the GIL.
249   explicit PyServerMiddleware(PyObject* middleware, Vtable vtable);
250 
251   void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override;
252   void CallCompleted(const Status& status) override;
253   std::string name() const override;
254   /// \brief Get the underlying Python object.
255   PyObject* py_object() const;
256 
257  private:
258   OwnedRefNoGIL middleware_;
259   Vtable vtable_;
260 };
261 
262 class ARROW_PYFLIGHT_EXPORT PyClientMiddlewareFactory
263     : public arrow::flight::ClientMiddlewareFactory {
264  public:
265   /// \brief A callback to create the middleware instance in Python
266   typedef std::function<Status(
267       PyObject*, const arrow::flight::CallInfo& info,
268       std::unique_ptr<arrow::flight::ClientMiddleware>* middleware)>
269       StartCallCallback;
270 
271   /// \brief Must only be called while holding the GIL.
272   explicit PyClientMiddlewareFactory(PyObject* factory, StartCallCallback start_call);
273 
274   void StartCall(const arrow::flight::CallInfo& info,
275                  std::unique_ptr<arrow::flight::ClientMiddleware>* middleware) override;
276 
277  private:
278   OwnedRefNoGIL factory_;
279   StartCallCallback start_call_;
280 };
281 
282 class ARROW_PYFLIGHT_EXPORT PyClientMiddleware : public arrow::flight::ClientMiddleware {
283  public:
284   typedef std::function<Status(PyObject*,
285                                arrow::flight::AddCallHeaders* outgoing_headers)>
286       SendingHeadersCallback;
287   typedef std::function<Status(PyObject*,
288                                const arrow::flight::CallHeaders& incoming_headers)>
289       ReceivedHeadersCallback;
290   typedef std::function<Status(PyObject*, const Status& status)> CallCompletedCallback;
291 
292   struct Vtable {
293     SendingHeadersCallback sending_headers;
294     ReceivedHeadersCallback received_headers;
295     CallCompletedCallback call_completed;
296   };
297 
298   /// \brief Must only be called while holding the GIL.
299   explicit PyClientMiddleware(PyObject* factory, Vtable vtable);
300 
301   void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override;
302   void ReceivedHeaders(const arrow::flight::CallHeaders& incoming_headers) override;
303   void CallCompleted(const Status& status) override;
304 
305  private:
306   OwnedRefNoGIL middleware_;
307   Vtable vtable_;
308 };
309 
310 /// \brief A callback that obtains the next payload from a Flight result stream.
311 typedef std::function<Status(PyObject*, arrow::flight::FlightPayload*)>
312     PyGeneratorFlightDataStreamCallback;
313 
314 /// \brief A FlightDataStream built around a Python callback.
315 class ARROW_PYFLIGHT_EXPORT PyGeneratorFlightDataStream
316     : public arrow::flight::FlightDataStream {
317  public:
318   /// \brief Construct a FlightDataStream from a Python object and underlying stream.
319   /// Must only be called while holding the GIL.
320   explicit PyGeneratorFlightDataStream(PyObject* generator,
321                                        std::shared_ptr<arrow::Schema> schema,
322                                        PyGeneratorFlightDataStreamCallback callback,
323                                        const ipc::IpcWriteOptions& options);
324   std::shared_ptr<Schema> schema() override;
325   Status GetSchemaPayload(arrow::flight::FlightPayload* payload) override;
326   Status Next(arrow::flight::FlightPayload* payload) override;
327 
328  private:
329   OwnedRefNoGIL generator_;
330   std::shared_ptr<arrow::Schema> schema_;
331   ipc::DictionaryFieldMapper mapper_;
332   ipc::IpcWriteOptions options_;
333   PyGeneratorFlightDataStreamCallback callback_;
334 };
335 
336 ARROW_PYFLIGHT_EXPORT
337 Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
338                         const arrow::flight::FlightDescriptor& descriptor,
339                         const std::vector<arrow::flight::FlightEndpoint>& endpoints,
340                         int64_t total_records, int64_t total_bytes,
341                         std::unique_ptr<arrow::flight::FlightInfo>* out);
342 
343 ARROW_PYFLIGHT_EXPORT
344 Status DeserializeBasicAuth(const std::string& buf,
345                             std::unique_ptr<arrow::flight::BasicAuth>* out);
346 
347 ARROW_PYFLIGHT_EXPORT
348 Status SerializeBasicAuth(const arrow::flight::BasicAuth& basic_auth, std::string* out);
349 
350 /// \brief Create a SchemaResult from schema.
351 ARROW_PYFLIGHT_EXPORT
352 Status CreateSchemaResult(const std::shared_ptr<arrow::Schema>& schema,
353                           std::unique_ptr<arrow::flight::SchemaResult>* out);
354 
355 }  // namespace flight
356 }  // namespace py
357 }  // namespace arrow
358