1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16
17package ipc // import "github.com/apache/arrow/go/arrow/ipc"
18
19import (
20	"bytes"
21	"encoding/binary"
22	"io"
23
24	"github.com/apache/arrow/go/arrow"
25	"github.com/apache/arrow/go/arrow/array"
26	"github.com/apache/arrow/go/arrow/bitutil"
27	"github.com/apache/arrow/go/arrow/internal/flatbuf"
28	"github.com/apache/arrow/go/arrow/memory"
29	"golang.org/x/xerrors"
30)
31
// FileReader is an Arrow file reader.
//
// A FileReader is created by NewFileReader, which decodes the footer, the
// dictionaries and the schema up front. Records are then obtained by index
// (Record, ReadAt) or sequentially (Read).
type FileReader struct {
	// r is the underlying source the file content is read from.
	r ReadAtSeeker

	// footer describes the file footer: where the file ends, the raw footer
	// bytes and the flatbuffer view decoded from those bytes.
	footer struct {
		offset int64
		buffer *memory.Buffer
		data   *flatbuf.Footer
	}

	// fields holds the dictionary types extracted from the footer schema.
	fields dictTypeMap
	// memo collects the dictionaries read from the file, registered by id.
	memo dictMemo

	// schema is the schema decoded from the footer.
	schema *arrow.Schema
	// record is the record most recently returned by Record. It is released
	// on the next call to Record and on Close.
	record array.Record

	irec int   // current record index. used for the arrio.Reader interface
	err  error // last error
}
51
52// NewFileReader opens an Arrow file using the provided reader r.
53func NewFileReader(r ReadAtSeeker, opts ...Option) (*FileReader, error) {
54	var (
55		cfg = newConfig(opts...)
56		err error
57
58		f = FileReader{
59			r:      r,
60			fields: make(dictTypeMap),
61			memo:   newMemo(),
62		}
63	)
64
65	if cfg.footer.offset <= 0 {
66		cfg.footer.offset, err = f.r.Seek(0, io.SeekEnd)
67		if err != nil {
68			return nil, xerrors.Errorf("arrow/ipc: could retrieve footer offset: %w", err)
69		}
70	}
71	f.footer.offset = cfg.footer.offset
72
73	err = f.readFooter()
74	if err != nil {
75		return nil, xerrors.Errorf("arrow/ipc: could not decode footer: %w", err)
76	}
77
78	err = f.readSchema()
79	if err != nil {
80		return nil, xerrors.Errorf("arrow/ipc: could not decode schema: %w", err)
81	}
82
83	if cfg.schema != nil && !cfg.schema.Equal(f.schema) {
84		return nil, xerrors.Errorf("arrow/ipc: inconsistent schema for reading (got: %v, want: %v)", f.schema, cfg.schema)
85	}
86
87	return &f, err
88}
89
90func (f *FileReader) readFooter() error {
91	var err error
92
93	if f.footer.offset <= int64(len(Magic)*2+4) {
94		return xerrors.Errorf("arrow/ipc: file too small (size=%d)", f.footer.offset)
95	}
96
97	eof := int64(len(Magic) + 4)
98	buf := make([]byte, eof)
99	n, err := f.r.ReadAt(buf, f.footer.offset-eof)
100	if err != nil {
101		return xerrors.Errorf("arrow/ipc: could not read footer: %w", err)
102	}
103	if n != len(buf) {
104		return xerrors.Errorf("arrow/ipc: could not read %d bytes from end of file", len(buf))
105	}
106
107	if !bytes.Equal(buf[4:], Magic) {
108		return errNotArrowFile
109	}
110
111	size := int64(binary.LittleEndian.Uint32(buf[:4]))
112	if size <= 0 || size+int64(len(Magic)*2+4) > f.footer.offset {
113		return errInconsistentFileMetadata
114	}
115
116	buf = make([]byte, size)
117	n, err = f.r.ReadAt(buf, f.footer.offset-size-eof)
118	if err != nil {
119		return xerrors.Errorf("arrow/ipc: could not read footer data: %w", err)
120	}
121	if n != len(buf) {
122		return xerrors.Errorf("arrow/ipc: could not read %d bytes from footer data", len(buf))
123	}
124
125	f.footer.buffer = memory.NewBufferBytes(buf)
126	f.footer.data = flatbuf.GetRootAsFooter(buf, 0)
127	return err
128}
129
130func (f *FileReader) readSchema() error {
131	var err error
132	f.fields, err = dictTypesFromFB(f.footer.data.Schema(nil))
133	if err != nil {
134		return xerrors.Errorf("arrow/ipc: could not load dictionary types from file: %w", err)
135	}
136
137	for i := 0; i < f.NumDictionaries(); i++ {
138		blk, err := f.dict(i)
139		if err != nil {
140			return xerrors.Errorf("arrow/ipc: could read dictionary[%d]: %w", i, err)
141		}
142		switch {
143		case !bitutil.IsMultipleOf8(blk.Offset):
144			return xerrors.Errorf("arrow/ipc: invalid file offset=%d for dictionary %d", blk.Offset, i)
145		case !bitutil.IsMultipleOf8(int64(blk.Meta)):
146			return xerrors.Errorf("arrow/ipc: invalid file metadata=%d position for dictionary %d", blk.Meta, i)
147		case !bitutil.IsMultipleOf8(blk.Body):
148			return xerrors.Errorf("arrow/ipc: invalid file body=%d position for dictionary %d", blk.Body, i)
149		}
150
151		msg, err := blk.NewMessage()
152		if err != nil {
153			return err
154		}
155
156		id, dict, err := readDictionary(msg.meta, f.fields, f.r)
157		msg.Release()
158		if err != nil {
159			return xerrors.Errorf("arrow/ipc: could not read dictionary %d from file: %w", i, err)
160		}
161		f.memo.Add(id, dict)
162		dict.Release() // memo.Add increases ref-count of dict.
163	}
164
165	schema := f.footer.data.Schema(nil)
166	if schema == nil {
167		return xerrors.Errorf("arrow/ipc: could not load schema from flatbuffer data")
168	}
169	f.schema, err = schemaFromFB(schema, &f.memo)
170	if err != nil {
171		return xerrors.Errorf("arrow/ipc: could not read schema: %w", err)
172	}
173
174	return err
175}
176
177func (f *FileReader) block(i int) (fileBlock, error) {
178	var blk flatbuf.Block
179	if !f.footer.data.RecordBatches(&blk, i) {
180		return fileBlock{}, xerrors.Errorf("arrow/ipc: could not extract file block %d", i)
181	}
182
183	return fileBlock{
184		Offset: blk.Offset(),
185		Meta:   blk.MetaDataLength(),
186		Body:   blk.BodyLength(),
187		r:      f.r,
188	}, nil
189}
190
191func (f *FileReader) dict(i int) (fileBlock, error) {
192	var blk flatbuf.Block
193	if !f.footer.data.Dictionaries(&blk, i) {
194		return fileBlock{}, xerrors.Errorf("arrow/ipc: could not extract dictionary block %d", i)
195	}
196
197	return fileBlock{
198		Offset: blk.Offset(),
199		Meta:   blk.MetaDataLength(),
200		Body:   blk.BodyLength(),
201		r:      f.r,
202	}, nil
203}
204
// Schema returns the schema that was decoded from the file footer when the
// reader was created.
func (f *FileReader) Schema() *arrow.Schema {
	return f.schema
}
208
// NumDictionaries returns the number of dictionary blocks listed in the file
// footer. It returns 0 when no footer is loaded, which is the case after
// Close.
func (f *FileReader) NumDictionaries() int {
	if f.footer.data == nil {
		return 0
	}
	return f.footer.data.DictionariesLength()
}
215
216func (f *FileReader) NumRecords() int {
217	return f.footer.data.RecordBatchesLength()
218}
219
// Version returns the metadata version recorded in the file footer.
//
// NOTE(review): the footer is used without a nil check and Close resets it
// to nil, so this presumably must not be called after Close — confirm.
func (f *FileReader) Version() MetadataVersion {
	return MetadataVersion(f.footer.data.Version())
}
223
224// Close cleans up resources used by the File.
225// Close does not close the underlying reader.
226func (f *FileReader) Close() error {
227	if f.footer.data != nil {
228		f.footer.data = nil
229	}
230
231	if f.footer.buffer != nil {
232		f.footer.buffer.Release()
233		f.footer.buffer = nil
234	}
235
236	if f.record != nil {
237		f.record.Release()
238		f.record = nil
239	}
240	return nil
241}
242
243// Record returns the i-th record from the file.
244// The returned value is valid until the next call to Record.
245// Users need to call Retain on that Record to keep it valid for longer.
246func (f *FileReader) Record(i int) (array.Record, error) {
247	if i < 0 || i > f.NumRecords() {
248		panic("arrow/ipc: record index out of bounds")
249	}
250
251	blk, err := f.block(i)
252	if err != nil {
253		return nil, err
254	}
255	switch {
256	case !bitutil.IsMultipleOf8(blk.Offset):
257		return nil, xerrors.Errorf("arrow/ipc: invalid file offset=%d for record %d", blk.Offset, i)
258	case !bitutil.IsMultipleOf8(int64(blk.Meta)):
259		return nil, xerrors.Errorf("arrow/ipc: invalid file metadata=%d position for record %d", blk.Meta, i)
260	case !bitutil.IsMultipleOf8(blk.Body):
261		return nil, xerrors.Errorf("arrow/ipc: invalid file body=%d position for record %d", blk.Body, i)
262	}
263
264	msg, err := blk.NewMessage()
265	if err != nil {
266		return nil, err
267	}
268	defer msg.Release()
269
270	if msg.Type() != MessageRecordBatch {
271		return nil, xerrors.Errorf("arrow/ipc: message %d is not a Record", i)
272	}
273
274	if f.record != nil {
275		f.record.Release()
276	}
277
278	f.record = newRecord(f.schema, msg.meta, bytes.NewReader(msg.body.Bytes()))
279	return f.record, nil
280}
281
282// Read reads the current record from the underlying stream and an error, if any.
283// When the Reader reaches the end of the underlying stream, it returns (nil, io.EOF).
284//
285// The returned record value is valid until the next call to Read.
286// Users need to call Retain on that Record to keep it valid for longer.
287func (f *FileReader) Read() (rec array.Record, err error) {
288	if f.irec == f.NumRecords() {
289		return nil, io.EOF
290	}
291	rec, f.err = f.Record(f.irec)
292	f.irec++
293	return rec, f.err
294}
295
// ReadAt reads the i-th record from the underlying stream and an error, if any.
//
// It converts i to int and delegates to Record, so the same validity rules
// apply to the returned record: it stays valid until the next call to Record.
func (f *FileReader) ReadAt(i int64) (array.Record, error) {
	return f.Record(int(i))
}
300
301func newRecord(schema *arrow.Schema, meta *memory.Buffer, body ReadAtSeeker) array.Record {
302	var (
303		msg = flatbuf.GetRootAsMessage(meta.Bytes(), 0)
304		md  flatbuf.RecordBatch
305	)
306	initFB(&md, msg.Header)
307	rows := md.Length()
308
309	ctx := &arrayLoaderContext{
310		src: ipcSource{
311			meta: &md,
312			r:    body,
313		},
314		max: kMaxNestingDepth,
315	}
316
317	cols := make([]array.Interface, len(schema.Fields()))
318	for i, field := range schema.Fields() {
319		cols[i] = ctx.loadArray(field.Type)
320	}
321
322	return array.NewRecord(schema, cols, rows)
323}
324
// ipcSource pairs the metadata of a record batch with the reader holding the
// bytes of its buffers.
type ipcSource struct {
	// meta lists the buffers and field nodes of the record batch.
	meta *flatbuf.RecordBatch
	// r is the reader the buffer contents are read from, at the offsets
	// recorded in meta.
	r ReadAtSeeker
}
329
330func (src *ipcSource) buffer(i int) *memory.Buffer {
331	var buf flatbuf.Buffer
332	if !src.meta.Buffers(&buf, i) {
333		panic("buffer index out of bound")
334	}
335	if buf.Length() == 0 {
336		return memory.NewBufferBytes(nil)
337	}
338
339	raw := make([]byte, buf.Length())
340	_, err := src.r.ReadAt(raw, buf.Offset())
341	if err != nil {
342		panic(err)
343	}
344
345	return memory.NewBufferBytes(raw)
346}
347
348func (src *ipcSource) fieldMetadata(i int) *flatbuf.FieldNode {
349	var node flatbuf.FieldNode
350	if !src.meta.Nodes(&node, i) {
351		panic("field metadata out of bound")
352	}
353	return &node
354}
355
// arrayLoaderContext carries the state needed to decode the arrays of a
// record batch one after the other.
type arrayLoaderContext struct {
	// src provides the field nodes and buffers of the record batch.
	src ipcSource
	// ifield is the index of the next field node to consume.
	ifield int
	// ibuffer is the index of the next buffer to consume.
	ibuffer int
	// max is the number of nesting levels that may still be entered;
	// loadChild panics once it reaches 0.
	max int
}
362
363func (ctx *arrayLoaderContext) field() *flatbuf.FieldNode {
364	field := ctx.src.fieldMetadata(ctx.ifield)
365	ctx.ifield++
366	return field
367}
368
369func (ctx *arrayLoaderContext) buffer() *memory.Buffer {
370	buf := ctx.src.buffer(ctx.ibuffer)
371	ctx.ibuffer++
372	return buf
373}
374
375func (ctx *arrayLoaderContext) loadArray(dt arrow.DataType) array.Interface {
376	switch dt := dt.(type) {
377	case *arrow.NullType:
378		return ctx.loadNull()
379
380	case *arrow.BooleanType,
381		*arrow.Int8Type, *arrow.Int16Type, *arrow.Int32Type, *arrow.Int64Type,
382		*arrow.Uint8Type, *arrow.Uint16Type, *arrow.Uint32Type, *arrow.Uint64Type,
383		*arrow.Float16Type, *arrow.Float32Type, *arrow.Float64Type,
384		*arrow.Decimal128Type,
385		*arrow.Time32Type, *arrow.Time64Type,
386		*arrow.TimestampType,
387		*arrow.Date32Type, *arrow.Date64Type,
388		*arrow.MonthIntervalType, *arrow.DayTimeIntervalType,
389		*arrow.DurationType:
390		return ctx.loadPrimitive(dt)
391
392	case *arrow.BinaryType, *arrow.StringType:
393		return ctx.loadBinary(dt)
394
395	case *arrow.FixedSizeBinaryType:
396		return ctx.loadFixedSizeBinary(dt)
397
398	case *arrow.ListType:
399		return ctx.loadList(dt)
400
401	case *arrow.FixedSizeListType:
402		return ctx.loadFixedSizeList(dt)
403
404	case *arrow.StructType:
405		return ctx.loadStruct(dt)
406
407	default:
408		panic(xerrors.Errorf("array type %T not handled yet", dt))
409	}
410}
411
412func (ctx *arrayLoaderContext) loadCommon(nbufs int) (*flatbuf.FieldNode, []*memory.Buffer) {
413	buffers := make([]*memory.Buffer, 0, nbufs)
414	field := ctx.field()
415
416	var buf *memory.Buffer
417	switch field.NullCount() {
418	case 0:
419		ctx.ibuffer++
420	default:
421		buf = ctx.buffer()
422	}
423	buffers = append(buffers, buf)
424
425	return field, buffers
426}
427
428func (ctx *arrayLoaderContext) loadChild(dt arrow.DataType) array.Interface {
429	if ctx.max == 0 {
430		panic("arrow/ipc: nested type limit reached")
431	}
432	ctx.max--
433	sub := ctx.loadArray(dt)
434	ctx.max++
435	return sub
436}
437
438func (ctx *arrayLoaderContext) loadNull() array.Interface {
439	field := ctx.field()
440	data := array.NewData(arrow.Null, int(field.Length()), nil, nil, int(field.NullCount()), 0)
441	defer data.Release()
442
443	return array.MakeFromData(data)
444}
445
446func (ctx *arrayLoaderContext) loadPrimitive(dt arrow.DataType) array.Interface {
447	field, buffers := ctx.loadCommon(2)
448
449	switch field.Length() {
450	case 0:
451		buffers = append(buffers, nil)
452		ctx.ibuffer++
453	default:
454		buffers = append(buffers, ctx.buffer())
455	}
456
457	data := array.NewData(dt, int(field.Length()), buffers, nil, int(field.NullCount()), 0)
458	defer data.Release()
459
460	return array.MakeFromData(data)
461}
462
463func (ctx *arrayLoaderContext) loadBinary(dt arrow.DataType) array.Interface {
464	field, buffers := ctx.loadCommon(3)
465	buffers = append(buffers, ctx.buffer(), ctx.buffer())
466
467	data := array.NewData(dt, int(field.Length()), buffers, nil, int(field.NullCount()), 0)
468	defer data.Release()
469
470	return array.MakeFromData(data)
471}
472
473func (ctx *arrayLoaderContext) loadFixedSizeBinary(dt *arrow.FixedSizeBinaryType) array.Interface {
474	field, buffers := ctx.loadCommon(2)
475	buffers = append(buffers, ctx.buffer())
476
477	data := array.NewData(dt, int(field.Length()), buffers, nil, int(field.NullCount()), 0)
478	defer data.Release()
479
480	return array.MakeFromData(data)
481}
482
483func (ctx *arrayLoaderContext) loadList(dt *arrow.ListType) array.Interface {
484	field, buffers := ctx.loadCommon(2)
485	buffers = append(buffers, ctx.buffer())
486
487	sub := ctx.loadChild(dt.Elem())
488	defer sub.Release()
489
490	data := array.NewData(dt, int(field.Length()), buffers, []*array.Data{sub.Data()}, int(field.NullCount()), 0)
491	defer data.Release()
492
493	return array.NewListData(data)
494}
495
496func (ctx *arrayLoaderContext) loadFixedSizeList(dt *arrow.FixedSizeListType) array.Interface {
497	field, buffers := ctx.loadCommon(1)
498
499	sub := ctx.loadChild(dt.Elem())
500	defer sub.Release()
501
502	data := array.NewData(dt, int(field.Length()), buffers, []*array.Data{sub.Data()}, int(field.NullCount()), 0)
503	defer data.Release()
504
505	return array.NewFixedSizeListData(data)
506}
507
508func (ctx *arrayLoaderContext) loadStruct(dt *arrow.StructType) array.Interface {
509	field, buffers := ctx.loadCommon(1)
510
511	arrs := make([]array.Interface, len(dt.Fields()))
512	subs := make([]*array.Data, len(dt.Fields()))
513	for i, f := range dt.Fields() {
514		arrs[i] = ctx.loadChild(f.Type)
515		subs[i] = arrs[i].Data()
516	}
517	defer func() {
518		for i := range arrs {
519			arrs[i].Release()
520		}
521	}()
522
523	data := array.NewData(dt, int(field.Length()), buffers, subs, int(field.NullCount()), 0)
524	defer data.Release()
525
526	return array.NewStructData(data)
527}
528
// readDictionary is meant to decode a dictionary batch message and return the
// dictionary id together with its values.
//
// It is not implemented yet: every call panics with "not implemented". As a
// consequence, FileReader.readSchema panics for any file whose footer lists
// at least one dictionary block.
//
// The commented-out code below is an unfinished sketch of the intended
// implementation and is kept for reference only.
func readDictionary(meta *memory.Buffer, types dictTypeMap, r ReadAtSeeker) (int64, array.Interface, error) {
	//	msg := flatbuf.GetRootAsMessage(meta.Bytes(), 0)
	//	var dictBatch flatbuf.DictionaryBatch
	//	initFB(&dictBatch, msg.Header)
	//
	//	id := dictBatch.Id()
	//	v, ok := types[id]
	//	if !ok {
	//		return id, nil, errors.Errorf("arrow/ipc: no type metadata for dictionary with ID=%d", id)
	//	}
	//
	//	fields := []arrow.Field{v}
	//
	//	// we need a schema for the record batch.
	//	schema := arrow.NewSchema(fields, nil)
	//
	//	// the dictionary is embedded in a record batch with a single column.
	//	recBatch := dictBatch.Data(nil)
	//
	//	var (
	//		batchMeta *memory.Buffer
	//		body      *memory.Buffer
	//	)
	//
	//
	//	ctx := &arrayLoaderContext{
	//		src: ipcSource{
	//			meta: &md,
	//			r:    bytes.NewReader(body.Bytes()),
	//		},
	//		max: kMaxNestingDepth,
	//	}
	//
	//	cols := make([]array.Interface, len(schema.Fields()))
	//	for i, field := range schema.Fields() {
	//		cols[i] = ctx.loadArray(field.Type)
	//	}
	//
	//	batch := array.NewRecord(schema, cols, rows)

	panic("not implemented")
}
571