1 // Licensed to the Apache Software Foundation (ASF) under one 2 // or more contributor license agreements. See the NOTICE file 3 // distributed with this work for additional information 4 // regarding copyright ownership. The ASF licenses this file 5 // to you under the Apache License, Version 2.0 (the 6 // "License"); you may not use this file except in compliance 7 // with the License. You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, 12 // software distributed under the License is distributed on an 13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14 // KIND, either express or implied. See the License for the 15 // specific language governing permissions and limitations 16 // under the License. 17 18 #pragma once 19 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "arrow/flight/api.h" 25 #include "arrow/ipc/dictionary.h" 26 #include "arrow/python/common.h" 27 28 #if defined(_WIN32) || defined(__CYGWIN__) // Windows 29 #if defined(_MSC_VER) 30 #pragma warning(disable : 4251) 31 #else 32 #pragma GCC diagnostic ignored "-Wattributes" 33 #endif 34 35 #ifdef ARROW_STATIC 36 #define ARROW_PYFLIGHT_EXPORT 37 #elif defined(ARROW_PYFLIGHT_EXPORTING) 38 #define ARROW_PYFLIGHT_EXPORT __declspec(dllexport) 39 #else 40 #define ARROW_PYFLIGHT_EXPORT __declspec(dllimport) 41 #endif 42 43 #else // Not Windows 44 #ifndef ARROW_PYFLIGHT_EXPORT 45 #define ARROW_PYFLIGHT_EXPORT __attribute__((visibility("default"))) 46 #endif 47 #endif // Non-Windows 48 49 namespace arrow { 50 51 namespace py { 52 53 namespace flight { 54 55 ARROW_PYFLIGHT_EXPORT 56 extern const char* kPyServerMiddlewareName; 57 58 /// \brief A table of function pointers for calling from C++ into 59 /// Python. 60 class ARROW_PYFLIGHT_EXPORT PyFlightServerVtable { 61 public: 62 std::function<Status(PyObject*, const arrow::flight::ServerCallContext&, 63 const arrow::flight::Criteria*, 64 std::unique_ptr<arrow::flight::FlightListing>*)> 65 list_flights; 66 std::function<Status(PyObject*, const arrow::flight::ServerCallContext&, 67 const arrow::flight::FlightDescriptor&, 68 std::unique_ptr<arrow::flight::FlightInfo>*)> 69 get_flight_info; 70 std::function<Status(PyObject*, const arrow::flight::ServerCallContext&, 71 const arrow::flight::FlightDescriptor&, 72 std::unique_ptr<arrow::flight::SchemaResult>*)> 73 get_schema; 74 std::function<Status(PyObject*, const arrow::flight::ServerCallContext&, 75 const arrow::flight::Ticket&, 76 std::unique_ptr<arrow::flight::FlightDataStream>*)> 77 do_get; 78 std::function<Status(PyObject*, const arrow::flight::ServerCallContext&, 79 std::unique_ptr<arrow::flight::FlightMessageReader>, 80 std::unique_ptr<arrow::flight::FlightMetadataWriter>)> 81 do_put; 82 std::function<Status(PyObject*, const arrow::flight::ServerCallContext&, 83 std::unique_ptr<arrow::flight::FlightMessageReader>, 84 std::unique_ptr<arrow::flight::FlightMessageWriter>)> 85 do_exchange; 86 std::function<Status(PyObject*, const arrow::flight::ServerCallContext&, 87 const arrow::flight::Action&, 88 std::unique_ptr<arrow::flight::ResultStream>*)> 89 do_action; 90 std::function<Status(PyObject*, const arrow::flight::ServerCallContext&, 91 std::vector<arrow::flight::ActionType>*)> 92 list_actions; 93 }; 94 95 class ARROW_PYFLIGHT_EXPORT PyServerAuthHandlerVtable { 96 public: 97 std::function<Status(PyObject*, arrow::flight::ServerAuthSender*, 98 arrow::flight::ServerAuthReader*)> 99 authenticate; 100 std::function<Status(PyObject*, const std::string&, std::string*)> is_valid; 101 }; 102 103 class ARROW_PYFLIGHT_EXPORT PyClientAuthHandlerVtable { 104 public: 105 std::function<Status(PyObject*, arrow::flight::ClientAuthSender*, 106 arrow::flight::ClientAuthReader*)> 107 authenticate; 108 std::function<Status(PyObject*, std::string*)> get_token; 109 }; 110 111 /// \brief A helper to implement an auth mechanism in Python. 112 class ARROW_PYFLIGHT_EXPORT PyServerAuthHandler 113 : public arrow::flight::ServerAuthHandler { 114 public: 115 explicit PyServerAuthHandler(PyObject* handler, 116 const PyServerAuthHandlerVtable& vtable); 117 Status Authenticate(arrow::flight::ServerAuthSender* outgoing, 118 arrow::flight::ServerAuthReader* incoming) override; 119 Status IsValid(const std::string& token, std::string* peer_identity) override; 120 121 private: 122 OwnedRefNoGIL handler_; 123 PyServerAuthHandlerVtable vtable_; 124 }; 125 126 /// \brief A helper to implement an auth mechanism in Python. 127 class ARROW_PYFLIGHT_EXPORT PyClientAuthHandler 128 : public arrow::flight::ClientAuthHandler { 129 public: 130 explicit PyClientAuthHandler(PyObject* handler, 131 const PyClientAuthHandlerVtable& vtable); 132 Status Authenticate(arrow::flight::ClientAuthSender* outgoing, 133 arrow::flight::ClientAuthReader* incoming) override; 134 Status GetToken(std::string* token) override; 135 136 private: 137 OwnedRefNoGIL handler_; 138 PyClientAuthHandlerVtable vtable_; 139 }; 140 141 class ARROW_PYFLIGHT_EXPORT PyFlightServer : public arrow::flight::FlightServerBase { 142 public: 143 explicit PyFlightServer(PyObject* server, const PyFlightServerVtable& vtable); 144 145 // Like Serve(), but set up signals and invoke Python signal handlers 146 // if necessary. This function may return with a Python exception set. 147 Status ServeWithSignals(); 148 149 Status ListFlights(const arrow::flight::ServerCallContext& context, 150 const arrow::flight::Criteria* criteria, 151 std::unique_ptr<arrow::flight::FlightListing>* listings) override; 152 Status GetFlightInfo(const arrow::flight::ServerCallContext& context, 153 const arrow::flight::FlightDescriptor& request, 154 std::unique_ptr<arrow::flight::FlightInfo>* info) override; 155 Status GetSchema(const arrow::flight::ServerCallContext& context, 156 const arrow::flight::FlightDescriptor& request, 157 std::unique_ptr<arrow::flight::SchemaResult>* result) override; 158 Status DoGet(const arrow::flight::ServerCallContext& context, 159 const arrow::flight::Ticket& request, 160 std::unique_ptr<arrow::flight::FlightDataStream>* stream) override; 161 Status DoPut(const arrow::flight::ServerCallContext& context, 162 std::unique_ptr<arrow::flight::FlightMessageReader> reader, 163 std::unique_ptr<arrow::flight::FlightMetadataWriter> writer) override; 164 Status DoExchange(const arrow::flight::ServerCallContext& context, 165 std::unique_ptr<arrow::flight::FlightMessageReader> reader, 166 std::unique_ptr<arrow::flight::FlightMessageWriter> writer) override; 167 Status DoAction(const arrow::flight::ServerCallContext& context, 168 const arrow::flight::Action& action, 169 std::unique_ptr<arrow::flight::ResultStream>* result) override; 170 Status ListActions(const arrow::flight::ServerCallContext& context, 171 std::vector<arrow::flight::ActionType>* actions) override; 172 173 private: 174 OwnedRefNoGIL server_; 175 PyFlightServerVtable vtable_; 176 }; 177 178 /// \brief A callback that obtains the next result from a Flight action. 179 typedef std::function<Status(PyObject*, std::unique_ptr<arrow::flight::Result>*)> 180 PyFlightResultStreamCallback; 181 182 /// \brief A ResultStream built around a Python callback. 183 class ARROW_PYFLIGHT_EXPORT PyFlightResultStream : public arrow::flight::ResultStream { 184 public: 185 /// \brief Construct a FlightResultStream from a Python object and callback. 186 /// Must only be called while holding the GIL. 187 explicit PyFlightResultStream(PyObject* generator, 188 PyFlightResultStreamCallback callback); 189 Status Next(std::unique_ptr<arrow::flight::Result>* result) override; 190 191 private: 192 OwnedRefNoGIL generator_; 193 PyFlightResultStreamCallback callback_; 194 }; 195 196 /// \brief A wrapper around a FlightDataStream that keeps alive a 197 /// Python object backing it. 198 class ARROW_PYFLIGHT_EXPORT PyFlightDataStream : public arrow::flight::FlightDataStream { 199 public: 200 /// \brief Construct a FlightDataStream from a Python object and underlying stream. 201 /// Must only be called while holding the GIL. 202 explicit PyFlightDataStream(PyObject* data_source, 203 std::unique_ptr<arrow::flight::FlightDataStream> stream); 204 205 std::shared_ptr<Schema> schema() override; 206 Status GetSchemaPayload(arrow::flight::FlightPayload* payload) override; 207 Status Next(arrow::flight::FlightPayload* payload) override; 208 209 private: 210 OwnedRefNoGIL data_source_; 211 std::unique_ptr<arrow::flight::FlightDataStream> stream_; 212 }; 213 214 class ARROW_PYFLIGHT_EXPORT PyServerMiddlewareFactory 215 : public arrow::flight::ServerMiddlewareFactory { 216 public: 217 /// \brief A callback to create the middleware instance in Python 218 typedef std::function<Status( 219 PyObject*, const arrow::flight::CallInfo& info, 220 const arrow::flight::CallHeaders& incoming_headers, 221 std::shared_ptr<arrow::flight::ServerMiddleware>* middleware)> 222 StartCallCallback; 223 224 /// \brief Must only be called while holding the GIL. 225 explicit PyServerMiddlewareFactory(PyObject* factory, StartCallCallback start_call); 226 227 Status StartCall(const arrow::flight::CallInfo& info, 228 const arrow::flight::CallHeaders& incoming_headers, 229 std::shared_ptr<arrow::flight::ServerMiddleware>* middleware) override; 230 231 private: 232 OwnedRefNoGIL factory_; 233 StartCallCallback start_call_; 234 }; 235 236 class ARROW_PYFLIGHT_EXPORT PyServerMiddleware : public arrow::flight::ServerMiddleware { 237 public: 238 typedef std::function<Status(PyObject*, 239 arrow::flight::AddCallHeaders* outgoing_headers)> 240 SendingHeadersCallback; 241 typedef std::function<Status(PyObject*, const Status& status)> CallCompletedCallback; 242 243 struct Vtable { 244 SendingHeadersCallback sending_headers; 245 CallCompletedCallback call_completed; 246 }; 247 248 /// \brief Must only be called while holding the GIL. 249 explicit PyServerMiddleware(PyObject* middleware, Vtable vtable); 250 251 void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override; 252 void CallCompleted(const Status& status) override; 253 std::string name() const override; 254 /// \brief Get the underlying Python object. 255 PyObject* py_object() const; 256 257 private: 258 OwnedRefNoGIL middleware_; 259 Vtable vtable_; 260 }; 261 262 class ARROW_PYFLIGHT_EXPORT PyClientMiddlewareFactory 263 : public arrow::flight::ClientMiddlewareFactory { 264 public: 265 /// \brief A callback to create the middleware instance in Python 266 typedef std::function<Status( 267 PyObject*, const arrow::flight::CallInfo& info, 268 std::unique_ptr<arrow::flight::ClientMiddleware>* middleware)> 269 StartCallCallback; 270 271 /// \brief Must only be called while holding the GIL. 272 explicit PyClientMiddlewareFactory(PyObject* factory, StartCallCallback start_call); 273 274 void StartCall(const arrow::flight::CallInfo& info, 275 std::unique_ptr<arrow::flight::ClientMiddleware>* middleware) override; 276 277 private: 278 OwnedRefNoGIL factory_; 279 StartCallCallback start_call_; 280 }; 281 282 class ARROW_PYFLIGHT_EXPORT PyClientMiddleware : public arrow::flight::ClientMiddleware { 283 public: 284 typedef std::function<Status(PyObject*, 285 arrow::flight::AddCallHeaders* outgoing_headers)> 286 SendingHeadersCallback; 287 typedef std::function<Status(PyObject*, 288 const arrow::flight::CallHeaders& incoming_headers)> 289 ReceivedHeadersCallback; 290 typedef std::function<Status(PyObject*, const Status& status)> CallCompletedCallback; 291 292 struct Vtable { 293 SendingHeadersCallback sending_headers; 294 ReceivedHeadersCallback received_headers; 295 CallCompletedCallback call_completed; 296 }; 297 298 /// \brief Must only be called while holding the GIL. 299 explicit PyClientMiddleware(PyObject* factory, Vtable vtable); 300 301 void SendingHeaders(arrow::flight::AddCallHeaders* outgoing_headers) override; 302 void ReceivedHeaders(const arrow::flight::CallHeaders& incoming_headers) override; 303 void CallCompleted(const Status& status) override; 304 305 private: 306 OwnedRefNoGIL middleware_; 307 Vtable vtable_; 308 }; 309 310 /// \brief A callback that obtains the next payload from a Flight result stream. 311 typedef std::function<Status(PyObject*, arrow::flight::FlightPayload*)> 312 PyGeneratorFlightDataStreamCallback; 313 314 /// \brief A FlightDataStream built around a Python callback. 315 class ARROW_PYFLIGHT_EXPORT PyGeneratorFlightDataStream 316 : public arrow::flight::FlightDataStream { 317 public: 318 /// \brief Construct a FlightDataStream from a Python object and underlying stream. 319 /// Must only be called while holding the GIL. 320 explicit PyGeneratorFlightDataStream(PyObject* generator, 321 std::shared_ptr<arrow::Schema> schema, 322 PyGeneratorFlightDataStreamCallback callback, 323 const ipc::IpcWriteOptions& options); 324 std::shared_ptr<Schema> schema() override; 325 Status GetSchemaPayload(arrow::flight::FlightPayload* payload) override; 326 Status Next(arrow::flight::FlightPayload* payload) override; 327 328 private: 329 OwnedRefNoGIL generator_; 330 std::shared_ptr<arrow::Schema> schema_; 331 ipc::DictionaryFieldMapper mapper_; 332 ipc::IpcWriteOptions options_; 333 PyGeneratorFlightDataStreamCallback callback_; 334 }; 335 336 ARROW_PYFLIGHT_EXPORT 337 Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema, 338 const arrow::flight::FlightDescriptor& descriptor, 339 const std::vector<arrow::flight::FlightEndpoint>& endpoints, 340 int64_t total_records, int64_t total_bytes, 341 std::unique_ptr<arrow::flight::FlightInfo>* out); 342 343 ARROW_PYFLIGHT_EXPORT 344 Status DeserializeBasicAuth(const std::string& buf, 345 std::unique_ptr<arrow::flight::BasicAuth>* out); 346 347 ARROW_PYFLIGHT_EXPORT 348 Status SerializeBasicAuth(const arrow::flight::BasicAuth& basic_auth, std::string* out); 349 350 /// \brief Create a SchemaResult from schema. 351 ARROW_PYFLIGHT_EXPORT 352 Status CreateSchemaResult(const std::shared_ptr<arrow::Schema>& schema, 353 std::unique_ptr<arrow::flight::SchemaResult>* out); 354 355 } // namespace flight 356 } // namespace py 357 } // namespace arrow 358