1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16
17package flight_test
18
19import (
20	"context"
21	"errors"
22	"fmt"
23	"io"
24	"testing"
25
26	"github.com/apache/arrow/go/v6/arrow/array"
27	"github.com/apache/arrow/go/v6/arrow/flight"
28	"github.com/apache/arrow/go/v6/arrow/internal/arrdata"
29	"github.com/apache/arrow/go/v6/arrow/ipc"
30	"github.com/apache/arrow/go/v6/arrow/memory"
31	"google.golang.org/grpc"
32	"google.golang.org/grpc/codes"
33	"google.golang.org/grpc/status"
34)
35
36type flightServer struct {
37	mem memory.Allocator
38}
39
40func (f *flightServer) getmem() memory.Allocator {
41	if f.mem == nil {
42		f.mem = memory.NewGoAllocator()
43	}
44
45	return f.mem
46}
47
48func (f *flightServer) ListFlights(c *flight.Criteria, fs flight.FlightService_ListFlightsServer) error {
49	expr := string(c.GetExpression())
50
51	auth := ""
52	authVal := flight.AuthFromContext(fs.Context())
53	if authVal != nil {
54		auth = authVal.(string)
55	}
56
57	for _, name := range arrdata.RecordNames {
58		if expr != "" && expr != name {
59			continue
60		}
61
62		recs := arrdata.Records[name]
63		totalRows := int64(0)
64		for _, r := range recs {
65			totalRows += r.NumRows()
66		}
67
68		fs.Send(&flight.FlightInfo{
69			Schema: flight.SerializeSchema(recs[0].Schema(), f.getmem()),
70			FlightDescriptor: &flight.FlightDescriptor{
71				Type: flight.FlightDescriptor_PATH,
72				Path: []string{name, auth},
73			},
74			TotalRecords: totalRows,
75			TotalBytes:   -1,
76		})
77	}
78
79	return nil
80}
81
82func (f *flightServer) GetSchema(_ context.Context, in *flight.FlightDescriptor) (*flight.SchemaResult, error) {
83	if in == nil {
84		return nil, status.Error(codes.InvalidArgument, "invalid flight descriptor")
85	}
86
87	recs, ok := arrdata.Records[in.Path[0]]
88	if !ok {
89		return nil, status.Error(codes.NotFound, "flight not found")
90	}
91
92	return &flight.SchemaResult{Schema: flight.SerializeSchema(recs[0].Schema(), f.getmem())}, nil
93}
94
95func (f *flightServer) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error {
96	recs := arrdata.Records[string(tkt.GetTicket())]
97
98	w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema()))
99	for _, r := range recs {
100		w.Write(r)
101	}
102
103	return nil
104}
105
106type servAuth struct{}
107
108func (a *servAuth) Authenticate(c flight.AuthConn) error {
109	tok, err := c.Read()
110	if err == io.EOF {
111		return nil
112	}
113
114	if string(tok) != "foobar" {
115		return errors.New("novalid")
116	}
117
118	if err != nil {
119		return err
120	}
121
122	return c.Send([]byte("baz"))
123}
124
125func (a *servAuth) IsValid(token string) (interface{}, error) {
126	if token == "baz" {
127		return "bar", nil
128	}
129	return "", errors.New("novalid")
130}
131
132type ctxauth struct{}
133
134type clientAuth struct{}
135
136func (a *clientAuth) Authenticate(ctx context.Context, c flight.AuthConn) error {
137	if err := c.Send(ctx.Value(ctxauth{}).([]byte)); err != nil {
138		return err
139	}
140
141	_, err := c.Read()
142	return err
143}
144
145func (a *clientAuth) GetToken(ctx context.Context) (string, error) {
146	return ctx.Value(ctxauth{}).(string), nil
147}
148
149func TestListFlights(t *testing.T) {
150	s := flight.NewFlightServer(nil)
151	s.Init("localhost:0")
152	f := &flightServer{}
153	s.RegisterFlightService(&flight.FlightServiceService{
154		ListFlights: f.ListFlights,
155	})
156
157	go s.Serve()
158	defer s.Shutdown()
159
160	client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure())
161	if err != nil {
162		t.Error(err)
163	}
164	defer client.Close()
165
166	flightStream, err := client.ListFlights(context.Background(), &flight.Criteria{})
167	if err != nil {
168		t.Error(err)
169	}
170
171	for {
172		info, err := flightStream.Recv()
173		if err == io.EOF {
174			break
175		} else if err != nil {
176			t.Error(err)
177		}
178
179		fname := info.GetFlightDescriptor().GetPath()[0]
180		recs, ok := arrdata.Records[fname]
181		if !ok {
182			t.Fatalf("got unknown flight info: %s", fname)
183		}
184
185		sc, err := flight.DeserializeSchema(info.GetSchema(), f.mem)
186		if err != nil {
187			t.Fatal(err)
188		}
189
190		if !recs[0].Schema().Equal(sc) {
191			t.Fatalf("flight info schema transfer failed: \ngot = %#v\nwant = %#v\n", sc, recs[0].Schema())
192		}
193
194		var total int64 = 0
195		for _, r := range recs {
196			total += r.NumRows()
197		}
198
199		if info.TotalRecords != total {
200			t.Fatalf("got wrong number of total records: got = %d, wanted = %d", info.TotalRecords, total)
201		}
202	}
203}
204
205func TestGetSchema(t *testing.T) {
206	s := flight.NewFlightServer(nil)
207	s.Init("localhost:0")
208	f := &flightServer{}
209	s.RegisterFlightService(&flight.FlightServiceService{
210		GetSchema: f.GetSchema,
211	})
212
213	go s.Serve()
214	defer s.Shutdown()
215
216	client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure())
217	if err != nil {
218		t.Error(err)
219	}
220	defer client.Close()
221
222	for name, testrecs := range arrdata.Records {
223		t.Run("flight get schema: "+name, func(t *testing.T) {
224			res, err := client.GetSchema(context.Background(), &flight.FlightDescriptor{Path: []string{name}})
225			if err != nil {
226				t.Fatal(err)
227			}
228
229			schema, err := flight.DeserializeSchema(res.GetSchema(), f.getmem())
230			if err != nil {
231				t.Fatal(err)
232			}
233
234			if !testrecs[0].Schema().Equal(schema) {
235				t.Fatalf("schema not match: \ngot = %#v\nwant = %#v\n", schema, testrecs[0].Schema())
236			}
237		})
238	}
239}
240
241func TestServer(t *testing.T) {
242	f := &flightServer{}
243	service := &flight.FlightServiceService{
244		ListFlights: f.ListFlights,
245		DoGet:       f.DoGet,
246	}
247
248	s := flight.NewFlightServer(&servAuth{})
249	s.Init("localhost:0")
250	s.RegisterFlightService(service)
251
252	go s.Serve()
253	defer s.Shutdown()
254
255	client, err := flight.NewFlightClient(s.Addr().String(), &clientAuth{}, grpc.WithInsecure())
256	if err != nil {
257		t.Error(err)
258	}
259	defer client.Close()
260
261	err = client.Authenticate(context.WithValue(context.Background(), ctxauth{}, []byte("foobar")))
262	if err != nil {
263		t.Error(err)
264	}
265
266	ctx := context.WithValue(context.Background(), ctxauth{}, "baz")
267
268	fistream, err := client.ListFlights(ctx, &flight.Criteria{Expression: []byte("decimal128")})
269	if err != nil {
270		t.Error(err)
271	}
272
273	fi, err := fistream.Recv()
274	if err != nil {
275		t.Fatal(err)
276	}
277
278	if len(fi.FlightDescriptor.GetPath()) != 2 || fi.FlightDescriptor.GetPath()[1] != "bar" {
279		t.Fatalf("path should have auth info: want %s got %s", "bar", fi.FlightDescriptor.GetPath()[1])
280	}
281
282	fdata, err := client.DoGet(ctx, &flight.Ticket{Ticket: []byte("decimal128")})
283	if err != nil {
284		t.Error(err)
285	}
286
287	r, err := flight.NewRecordReader(fdata)
288	if err != nil {
289		t.Error(err)
290	}
291
292	expected := arrdata.Records["decimal128"]
293	idx := 0
294	var numRows int64 = 0
295	for {
296		rec, err := r.Read()
297		if err != nil {
298			if err == io.EOF {
299				break
300			}
301			t.Error(err)
302		}
303
304		numRows += rec.NumRows()
305		if !array.RecordEqual(expected[idx], rec) {
306			t.Errorf("flight data stream records don't match: \ngot = %#v\nwant = %#v", rec, expected[idx])
307		}
308		idx++
309	}
310
311	if numRows != fi.TotalRecords {
312		t.Fatalf("got %d, want %d", numRows, fi.TotalRecords)
313	}
314}
315
316type flightMetadataWriterServer struct{}
317
318func (f *flightMetadataWriterServer) DoGet(tkt *flight.Ticket, fs flight.FlightService_DoGetServer) error {
319	recs := arrdata.Records[string(tkt.GetTicket())]
320
321	w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema()))
322	defer w.Close()
323	for idx, r := range recs {
324		w.WriteWithAppMetadata(r, []byte(fmt.Sprintf("%d_%s", idx, string(tkt.GetTicket()))) /*metadata*/)
325	}
326	return nil
327}
328
329func TestFlightWithAppMetadata(t *testing.T) {
330	f := &flightMetadataWriterServer{}
331	s := flight.NewFlightServer(nil)
332	s.RegisterFlightService(&flight.FlightServiceService{DoGet: f.DoGet})
333	s.Init("localhost:0")
334
335	go s.Serve()
336	defer s.Shutdown()
337
338	client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure())
339	if err != nil {
340		t.Fatal(err)
341	}
342	defer client.Close()
343
344	fdata, err := client.DoGet(context.Background(), &flight.Ticket{Ticket: []byte("primitives")})
345	if err != nil {
346		t.Fatal(err)
347	}
348
349	r, err := flight.NewRecordReader(fdata)
350	if err != nil {
351		t.Fatal(err)
352	}
353
354	expected := arrdata.Records["primitives"]
355	idx := 0
356	for {
357		rec, err := r.Read()
358		if err != nil {
359			if err == io.EOF {
360				break
361			}
362			t.Fatal(err)
363		}
364
365		appMeta := r.LatestAppMetadata()
366		if !array.RecordEqual(expected[idx], rec) {
367			t.Errorf("flight data stream records for idx: %d don't match: \ngot = %#v\nwant = %#v", idx, rec, expected[idx])
368		}
369
370		exMeta := fmt.Sprintf("%d_primitives", idx)
371		if string(appMeta) != exMeta {
372			t.Errorf("flight data stream application metadata mismatch: got: %v, want: %v\n", string(appMeta), exMeta)
373		}
374		idx++
375	}
376}
377
378type flightErrorReturn struct{}
379
380func (f *flightErrorReturn) DoGet(_ *flight.Ticket, _ flight.FlightService_DoGetServer) error {
381	return status.Error(codes.NotFound, "nofound")
382}
383
384func TestReaderError(t *testing.T) {
385	f := &flightErrorReturn{}
386	s := flight.NewFlightServer(nil)
387	s.RegisterFlightService(&flight.FlightServiceService{DoGet: f.DoGet})
388	s.Init("localhost:0")
389
390	go s.Serve()
391	defer s.Shutdown()
392
393	client, err := flight.NewFlightClient(s.Addr().String(), nil, grpc.WithInsecure())
394	if err != nil {
395		t.Fatal(err)
396	}
397	defer client.Close()
398
399	fdata, err := client.DoGet(context.Background(), &flight.Ticket{})
400	if err != nil {
401		t.Fatal(err)
402	}
403
404	_, err = flight.NewRecordReader(fdata)
405	if err == nil {
406		t.Fatal("should have errored")
407	}
408}
409