1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
#' Load a Python Flight server
#'
#' Imports the named Python module from `path` using
#' `reticulate::import_from_path()`.
#'
#' @param name string Python module name
#' @param path file system path where the Python module is found. Default is
#' to look in the `inst/` directory for included modules.
#' @return The result of `reticulate::import_from_path()` for `name` and `path`.
#' @export
#' @examplesIf FALSE
#' load_flight_server("demo_flight_server")
load_flight_server <- function(name, path = system.file(package = "arrow")) {
  reticulate::import_from_path(module = name, path = path)
}
29
#' Connect to a Flight server
#'
#' Builds a location of the form `<scheme>://<host>:<port>` and passes it to
#' `pyarrow.flight.FlightClient`.
#'
#' @param host string hostname to connect to
#' @param port integer port to connect on
#' @param scheme URL scheme, default is "grpc+tcp"
#' @return A `pyarrow.flight.FlightClient`.
#' @export
flight_connect <- function(host = "localhost", port, scheme = "grpc+tcp") {
  pyarrow <- reticulate::import("pyarrow")
  # Assemble the server location: <scheme>://<host>:<port>
  uri <- paste0(scheme, "://", host, ":", port)
  flight_module <- pyarrow$flight
  flight_module$FlightClient(uri)
}
42
#' Send data to a Flight server
#'
#' A `data.frame` is first converted to a [Table]. The data is then converted
#' to its Python counterpart and written to the server under `path`.
#'
#' @param client `pyarrow.flight.FlightClient`, as returned by [flight_connect()]
#' @param data `data.frame`, [RecordBatch], or [Table] to upload
#' @param path string identifier to store the data under
#' @param overwrite logical: if `path` exists on `client` already, should we
#' replace it with the contents of `data`? Default is `TRUE`; if `FALSE` and
#' `path` exists, the function will error.
#' @return `client`, invisibly.
#' @export
flight_put <- function(client, data, path, overwrite = TRUE) {
  if (!overwrite && flight_path_exists(client, path)) {
    stop(path, " exists.", call. = FALSE)
  }
  if (is.data.frame(data)) {
    data <- Table$create(data)
  }
  py_data <- reticulate::r_to_py(data)
  # The first element of do_put()'s result is used as the stream writer
  writer <- client$do_put(descriptor_for_path(path), py_data$schema)[[1]]

  # If a write below fails, still make a best-effort attempt to close the
  # writer so the stream is not left open. On the normal path the explicit
  # close() further down runs instead, and any error it raises propagates.
  close_attempted <- FALSE
  on.exit(
    if (!close_attempted) try(writer$close(), silent = TRUE),
    add = TRUE
  )

  if (inherits(data, "RecordBatch")) {
    writer$write_batch(py_data)
  } else {
    writer$write_table(py_data)
  }

  # Flag first so the exit handler never attempts a second close()
  close_attempted <- TRUE
  writer$close()
  invisible(client)
}
70
#' Get data from a Flight server
#'
#' Opens a reader for `path` on `client` and reads everything it provides.
#'
#' @param client `pyarrow.flight.FlightClient`, as returned by [flight_connect()]
#' @param path string identifier under which data is stored
#' @return A [Table]
#' @export
flight_get <- function(client, path) {
  flight_reader(client, path)$read_all()
}
81
# Open a reader for the data stored under `path` on `client`.
# TODO: could use this as a RecordBatch iterator, call $read_chunk() on this
flight_reader <- function(client, path) {
  descriptor <- descriptor_for_path(path)
  flight_info <- client$get_flight_info(descriptor)
  # Hack: assume a single ticket, on the same server as client is already
  # connected, so only the first endpoint is consulted
  first_endpoint <- flight_info$endpoints[[1]]
  client$do_get(first_endpoint$ticket)
}
89
# Build a pyarrow FlightDescriptor for `path` via FlightDescriptor$for_path()
descriptor_for_path <- function(path) {
  flight_module <- reticulate::import("pyarrow")$flight
  flight_module$FlightDescriptor$for_path(path)
}
94
#' See available resources on a Flight server
#'
#' @inheritParams flight_get
#' @return `list_flights()` returns a character vector of paths.
#' `flight_path_exists()` returns a logical value, the equivalent of `path %in% list_flights()`
#' @export
list_flights <- function(client) {
  generator <- client$list_flights()
  # Take the first path component of each listed flight's descriptor
  paths <- reticulate::iterate(
    generator,
    function(x) as.character(x$descriptor$path[[1]])
  )
  # iterate() simplifies its result to a vector only when it can, so it may
  # hand back a list (for instance when nothing is listed). Coerce so that the
  # documented character vector is always what callers receive.
  as.character(unlist(paths, use.names = FALSE))
}
106
#' @rdname list_flights
#' @export
flight_path_exists <- function(client, path) {
  # Return the tryCatch() value directly rather than as the value of an
  # assignment, so that the logical result is visible to the caller.
  tryCatch(
    expr = {
      # Only the success or failure of this lookup matters, not its value
      client$get_flight_info(descriptor_for_path(path))
      TRUE
    },
    error = function(e) {
      # An error whose message mentions ArrowKeyError is treated as
      # "path not found"
      not_found <- any(grepl("ArrowKeyError", conditionMessage(e), fixed = TRUE))
      if (!not_found) {
        # Raise an error if this fails for any reason other than not found
        stop(e)
      }
      FALSE
    }
  )
}
125