1// Licensed to the Apache Software Foundation (ASF) under one 2// or more contributor license agreements. See the NOTICE file 3// distributed with this work for additional information 4// regarding copyright ownership. The ASF licenses this file 5// to you under the Apache License, Version 2.0 (the 6// "License"); you may not use this file except in compliance 7// with the License. You may obtain a copy of the License at 8// 9// http://www.apache.org/licenses/LICENSE-2.0 10// 11// Unless required by applicable law or agreed to in writing, software 12// distributed under the License is distributed on an "AS IS" BASIS, 13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14// See the License for the specific language governing permissions and 15// limitations under the License. 16 17package flight_test 18 19import ( 20 "context" 21 "errors" 22 "fmt" 23 "io" 24 "testing" 25 26 "github.com/apache/arrow/go/v6/arrow/array" 27 "github.com/apache/arrow/go/v6/arrow/flight" 28 "github.com/apache/arrow/go/v6/arrow/internal/arrdata" 29 "github.com/apache/arrow/go/v6/arrow/ipc" 30 "github.com/apache/arrow/go/v6/arrow/memory" 31 "google.golang.org/grpc" 32 "google.golang.org/grpc/codes" 33 "google.golang.org/grpc/status" 34) 35 36type flightServer struct { 37 mem memory.Allocator 38} 39 40func (f *flightServer) getmem() memory.Allocator { 41 if f.mem == nil { 42 f.mem = memory.NewGoAllocator() 43 } 44 45 return f.mem 46} 47 48func (f *flightServer) ListFlights(c *flight.Criteria, fs flight.FlightService_ListFlightsServer) error { 49 expr := string(c.GetExpression()) 50 51 auth := "" 52 authVal := flight.AuthFromContext(fs.Context()) 53 if authVal != nil { 54 auth = authVal.(string) 55 } 56 57 for _, name := range arrdata.RecordNames { 58 if expr != "" && expr != name { 59 continue 60 } 61 62 recs := arrdata.Records[name] 63 totalRows := int64(0) 64 for _, r := range recs { 65 totalRows += r.NumRows() 66 } 67 68 fs.Send(&flight.FlightInfo{ 69 Schema: flight.SerializeSchema(recs[0].Schema(), f.getmem()), 70 FlightDescriptor: &flight.FlightDescriptor{ 71 Type: flight.FlightDescriptor_PATH, 72 Path: []string{name, auth}, 73 }, 74 TotalRecords: totalRows, 75 TotalBytes: -1, 76 }) 77 } 78 79 return nil 80} 81 82func (f *flightServer) GetSchema(_ context.Context, in *flight.FlightDescriptor) (*flight.SchemaResult, error) { 83 if in == nil { 84 return nil, status.Error(codes.InvalidArgument, "invalid flight descriptor") 85 } 86 87 recs, ok := arrdata.Records[in.Path[0]] 88 if !ok { 89 return nil, status.Error(codes.NotFound, "flight not found") 90 } 91 92 return &flight.SchemaResult{Schema: flight.SerializeSchema(recs[0].Schema(), f.getmem())}, nil 93} 94 95func (f *flightServer) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error { 96 recs := arrdata.Records[string(tkt.GetTicket())] 97 98 w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema())) 99 for _, r := range recs { 100 w.Write(r) 101 } 102 103 return nil 104} 105 106type servAuth struct{} 107 108func (a *servAuth) Authenticate(c flight.AuthConn) error { 109 tok, err := c.Read() 110 if err == io.EOF { 111 return nil 112 } 113 114 if string(tok) != "foobar" { 115 return errors.New("novalid") 116 } 117 118 if err != nil { 119 return err 120 } 121 122 return c.Send([]byte("baz")) 123} 124 125func (a *servAuth) IsValid(token string) (interface{}, error) { 126 if token == "baz" { 127 return "bar", nil 128 } 129 return "", errors.New("novalid") 130} 131 132type ctxauth struct{} 133 134type clientAuth struct{} 135 136func (a *clientAuth) Authenticate(ctx context.Context, c flight.AuthConn) error { 137 if err := c.Send(ctx.Value(ctxauth{}).([]byte)); err != nil { 138 return err 139 } 140 141 _, err := c.Read() 142 return err 143} 144 145func (a *clientAuth) GetToken(ctx context.Context) (string, error) { 146 return ctx.Value(ctxauth{}).(string), nil 147} 148 149func TestListFlights(t *testing.T) { 150 s := flight.NewFlightServer(nil) 151 s.Init("localhost:0") 152 f := &flightServer{} 153 s.RegisterFlightService(&flight.FlightServiceService{ 154 ListFlights: f.ListFlights, 155 }) 156 157 go s.Serve() 158 defer s.Shutdown() 159 160 client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure()) 161 if err != nil { 162 t.Error(err) 163 } 164 defer client.Close() 165 166 flightStream, err := client.ListFlights(context.Background(), &flight.Criteria{}) 167 if err != nil { 168 t.Error(err) 169 } 170 171 for { 172 info, err := flightStream.Recv() 173 if err == io.EOF { 174 break 175 } else if err != nil { 176 t.Error(err) 177 } 178 179 fname := info.GetFlightDescriptor().GetPath()[0] 180 recs, ok := arrdata.Records[fname] 181 if !ok { 182 t.Fatalf("got unknown flight info: %s", fname) 183 } 184 185 sc, err := flight.DeserializeSchema(info.GetSchema(), f.mem) 186 if err != nil { 187 t.Fatal(err) 188 } 189 190 if !recs[0].Schema().Equal(sc) { 191 t.Fatalf("flight info schema transfer failed: \ngot = %#v\nwant = %#v\n", sc, recs[0].Schema()) 192 } 193 194 var total int64 = 0 195 for _, r := range recs { 196 total += r.NumRows() 197 } 198 199 if info.TotalRecords != total { 200 t.Fatalf("got wrong number of total records: got = %d, wanted = %d", info.TotalRecords, total) 201 } 202 } 203} 204 205func TestGetSchema(t *testing.T) { 206 s := flight.NewFlightServer(nil) 207 s.Init("localhost:0") 208 f := &flightServer{} 209 s.RegisterFlightService(&flight.FlightServiceService{ 210 GetSchema: f.GetSchema, 211 }) 212 213 go s.Serve() 214 defer s.Shutdown() 215 216 client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure()) 217 if err != nil { 218 t.Error(err) 219 } 220 defer client.Close() 221 222 for name, testrecs := range arrdata.Records { 223 t.Run("flight get schema: "+name, func(t *testing.T) { 224 res, err := client.GetSchema(context.Background(), &flight.FlightDescriptor{Path: []string{name}}) 225 if err != nil { 226 t.Fatal(err) 227 } 228 229 schema, err := flight.DeserializeSchema(res.GetSchema(), f.getmem()) 230 if err != nil { 231 t.Fatal(err) 232 } 233 234 if !testrecs[0].Schema().Equal(schema) { 235 t.Fatalf("schema not match: \ngot = %#v\nwant = %#v\n", schema, testrecs[0].Schema()) 236 } 237 }) 238 } 239} 240 241func TestServer(t *testing.T) { 242 f := &flightServer{} 243 service := &flight.FlightServiceService{ 244 ListFlights: f.ListFlights, 245 DoGet: f.DoGet, 246 } 247 248 s := flight.NewFlightServer(&servAuth{}) 249 s.Init("localhost:0") 250 s.RegisterFlightService(service) 251 252 go s.Serve() 253 defer s.Shutdown() 254 255 client, err := flight.NewFlightClient(s.Addr().String(), &clientAuth{}, grpc.WithInsecure()) 256 if err != nil { 257 t.Error(err) 258 } 259 defer client.Close() 260 261 err = client.Authenticate(context.WithValue(context.Background(), ctxauth{}, []byte("foobar"))) 262 if err != nil { 263 t.Error(err) 264 } 265 266 ctx := context.WithValue(context.Background(), ctxauth{}, "baz") 267 268 fistream, err := client.ListFlights(ctx, &flight.Criteria{Expression: []byte("decimal128")}) 269 if err != nil { 270 t.Error(err) 271 } 272 273 fi, err := fistream.Recv() 274 if err != nil { 275 t.Fatal(err) 276 } 277 278 if len(fi.FlightDescriptor.GetPath()) != 2 || fi.FlightDescriptor.GetPath()[1] != "bar" { 279 t.Fatalf("path should have auth info: want %s got %s", "bar", fi.FlightDescriptor.GetPath()[1]) 280 } 281 282 fdata, err := client.DoGet(ctx, &flight.Ticket{Ticket: []byte("decimal128")}) 283 if err != nil { 284 t.Error(err) 285 } 286 287 r, err := flight.NewRecordReader(fdata) 288 if err != nil { 289 t.Error(err) 290 } 291 292 expected := arrdata.Records["decimal128"] 293 idx := 0 294 var numRows int64 = 0 295 for { 296 rec, err := r.Read() 297 if err != nil { 298 if err == io.EOF { 299 break 300 } 301 t.Error(err) 302 } 303 304 numRows += rec.NumRows() 305 if !array.RecordEqual(expected[idx], rec) { 306 t.Errorf("flight data stream records don't match: \ngot = %#v\nwant = %#v", rec, expected[idx]) 307 } 308 idx++ 309 } 310 311 if numRows != fi.TotalRecords { 312 t.Fatalf("got %d, want %d", numRows, fi.TotalRecords) 313 } 314} 315 316type flightMetadataWriterServer struct{} 317 318func (f *flightMetadataWriterServer) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error { 319 recs := arrdata.Records[string(tkt.GetTicket())] 320 321 w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema())) 322 defer w.Close() 323 for idx, r := range recs { 324 w.WriteWithAppMetadata(r, []byte(fmt.Sprintf("%d_%s", idx, string(tkt.GetTicket()))) /*metadata*/) 325 } 326 return nil 327} 328 329func TestFlightWithAppMetadata(t *testing.T) { 330 f := &flightMetadataWriterServer{} 331 s := flight.NewFlightServer(nil) 332 s.RegisterFlightService(&flight.FlightServiceService{DoGet: f.DoGet}) 333 s.Init("localhost:0") 334 335 go s.Serve() 336 defer s.Shutdown() 337 338 client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure()) 339 if err != nil { 340 t.Fatal(err) 341 } 342 defer client.Close() 343 344 fdata, err := client.DoGet(context.Background(), &flight.Ticket{Ticket: []byte("primitives")}) 345 if err != nil { 346 t.Fatal(err) 347 } 348 349 r, err := flight.NewRecordReader(fdata) 350 if err != nil { 351 t.Fatal(err) 352 } 353 354 expected := arrdata.Records["primitives"] 355 idx := 0 356 for { 357 rec, err := r.Read() 358 if err != nil { 359 if err == io.EOF { 360 break 361 } 362 t.Fatal(err) 363 } 364 365 appMeta := r.LatestAppMetadata() 366 if !array.RecordEqual(expected[idx], rec) { 367 t.Errorf("flight data stream records for idx: %d don't match: \ngot = %#v\nwant = %#v", idx, rec, expected[idx]) 368 } 369 370 exMeta := fmt.Sprintf("%d_primitives", idx) 371 if string(appMeta) != exMeta { 372 t.Errorf("flight data stream application metadata mismatch: got: %v, want: %v\n", string(appMeta), exMeta) 373 } 374 idx++ 375 } 376} 377 378type flightErrorReturn struct{} 379 380func (f *flightErrorReturn) DoGet(_ *flight.Ticket, _ flight.FlightService_DoGetServer) error { 381 return status.Error(codes.NotFound, "nofound") 382} 383 384func TestReaderError(t *testing.T) { 385 f := &flightErrorReturn{} 386 s := flight.NewFlightServer(nil) 387 s.RegisterFlightService(&flight.FlightServiceService{DoGet: f.DoGet}) 388 s.Init("localhost:0") 389 390 go s.Serve() 391 defer s.Shutdown() 392 393 client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure()) 394 if err != nil { 395 t.Fatal(err) 396 } 397 defer client.Close() 398 399 fdata, err := client.DoGet(context.Background(), &flight.Ticket{}) 400 if err != nil { 401 t.Fatal(err) 402 } 403 404 _, err = flight.NewRecordReader(fdata) 405 if err == nil { 406 t.Fatal("should have errored") 407 } 408} 409