1 // Licensed to the Apache Software Foundation (ASF) under one or more
2 // contributor license agreements. See the NOTICE file distributed with
3 // this work for additional information regarding copyright ownership.
4 // The ASF licenses this file to You under the Apache License, Version 2.0
5 // (the "License"); you may not use this file except in compliance with
6 // the License.  You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 using System;
17 using System.Buffers;
18 using System.Buffers.Binary;
19 using System.Collections.Generic;
20 using System.Diagnostics;
21 using System.IO;
22 using System.Threading;
23 using System.Threading.Tasks;
24 using Apache.Arrow.Arrays;
25 using Apache.Arrow.Types;
26 using FlatBuffers;
27 
28 namespace Apache.Arrow.Ipc
29 {
30     public class ArrowStreamWriter : IDisposable
31     {
/// <summary>
/// Array visitor that records, in visit order, every buffer belonging to the arrays it is
/// applied to, together with the offset assigned to that buffer inside the message body.
/// </summary>
/// <remarks>
/// Offsets are assigned by accumulating each buffer's length, rounded up via
/// <c>BitUtility.RoundUpToMultipleOf8</c>, into <see cref="TotalLength"/>.
/// </remarks>
internal class ArrowRecordBatchFlatBufferBuilder :
    IArrowArrayVisitor<Int8Array>,
    IArrowArrayVisitor<Int16Array>,
    IArrowArrayVisitor<Int32Array>,
    IArrowArrayVisitor<Int64Array>,
    IArrowArrayVisitor<UInt8Array>,
    IArrowArrayVisitor<UInt16Array>,
    IArrowArrayVisitor<UInt32Array>,
    IArrowArrayVisitor<UInt64Array>,
    IArrowArrayVisitor<FloatArray>,
    IArrowArrayVisitor<DoubleArray>,
    IArrowArrayVisitor<BooleanArray>,
    IArrowArrayVisitor<TimestampArray>,
    IArrowArrayVisitor<Date32Array>,
    IArrowArrayVisitor<Date64Array>,
    IArrowArrayVisitor<ListArray>,
    IArrowArrayVisitor<StringArray>,
    IArrowArrayVisitor<BinaryArray>,
    IArrowArrayVisitor<FixedSizeBinaryArray>,
    IArrowArrayVisitor<StructArray>,
    IArrowArrayVisitor<Decimal128Array>,
    IArrowArrayVisitor<Decimal256Array>,
    IArrowArrayVisitor<DictionaryArray>
{
    /// <summary>
    /// An Arrow buffer paired with the offset that was assigned to it when it was collected.
    /// </summary>
    public readonly struct Buffer
    {
        /// <summary>The collected buffer.</summary>
        public readonly ArrowBuffer DataBuffer;

        /// <summary>Value of <see cref="TotalLength"/> at the moment the buffer was collected.</summary>
        public readonly int Offset;

        public Buffer(ArrowBuffer buffer, int offset)
        {
            DataBuffer = buffer;
            Offset = offset;
        }
    }

    private readonly List<Buffer> _buffers;

    /// <summary>Buffers collected so far, in the order they were visited.</summary>
    public IReadOnlyList<Buffer> Buffers => _buffers;

    /// <summary>Sum of the padded lengths of all buffers collected so far.</summary>
    public int TotalLength { get; private set; }

    public ArrowRecordBatchFlatBufferBuilder()
    {
        _buffers = new List<Buffer>();
        TotalLength = 0;
    }

    public void Visit(Int8Array array) => CreateBuffers(array);
    public void Visit(Int16Array array) => CreateBuffers(array);
    public void Visit(Int32Array array) => CreateBuffers(array);
    public void Visit(Int64Array array) => CreateBuffers(array);
    public void Visit(UInt8Array array) => CreateBuffers(array);
    public void Visit(UInt16Array array) => CreateBuffers(array);
    public void Visit(UInt32Array array) => CreateBuffers(array);
    public void Visit(UInt64Array array) => CreateBuffers(array);
    public void Visit(FloatArray array) => CreateBuffers(array);
    public void Visit(DoubleArray array) => CreateBuffers(array);
    public void Visit(TimestampArray array) => CreateBuffers(array);
    public void Visit(BooleanArray array) => CreateBuffers(array);
    public void Visit(Date32Array array) => CreateBuffers(array);
    public void Visit(Date64Array array) => CreateBuffers(array);

    /// <summary>
    /// Collects the list's null bitmap and value offsets, then visits the child values array.
    /// </summary>
    public void Visit(ListArray array)
    {
        _buffers.Add(CreateBuffer(array.NullBitmapBuffer));
        _buffers.Add(CreateBuffer(array.ValueOffsetsBuffer));

        array.Values.Accept(this);
    }

    // Handled through the BinaryArray overload; the cast is a plain upcast.
    public void Visit(StringArray array) => Visit((BinaryArray)array);

    /// <summary>Collects the null bitmap, value offsets and value buffers, in that order.</summary>
    public void Visit(BinaryArray array)
    {
        _buffers.Add(CreateBuffer(array.NullBitmapBuffer));
        _buffers.Add(CreateBuffer(array.ValueOffsetsBuffer));
        _buffers.Add(CreateBuffer(array.ValueBuffer));
    }

    /// <summary>Collects the null bitmap and value buffers, in that order.</summary>
    public void Visit(FixedSizeBinaryArray array)
    {
        _buffers.Add(CreateBuffer(array.NullBitmapBuffer));
        _buffers.Add(CreateBuffer(array.ValueBuffer));
    }

    /// <summary>Collects the null bitmap and value buffers, in that order.</summary>
    public void Visit(Decimal128Array array)
    {
        _buffers.Add(CreateBuffer(array.NullBitmapBuffer));
        _buffers.Add(CreateBuffer(array.ValueBuffer));
    }

    /// <summary>Collects the null bitmap and value buffers, in that order.</summary>
    public void Visit(Decimal256Array array)
    {
        _buffers.Add(CreateBuffer(array.NullBitmapBuffer));
        _buffers.Add(CreateBuffer(array.ValueBuffer));
    }

    /// <summary>Collects the struct's null bitmap, then visits each child field array in order.</summary>
    public void Visit(StructArray array)
    {
        _buffers.Add(CreateBuffer(array.NullBitmapBuffer));

        for (int i = 0; i < array.Fields.Count; i++)
        {
            array.Fields[i].Accept(this);
        }
    }

    public void Visit(DictionaryArray array)
    {
        // Dictionary is serialized separately in Dictionary serialization.
        // We are only interested in indices at this context.

        _buffers.Add(CreateBuffer(array.NullBitmapBuffer));
        _buffers.Add(CreateBuffer(array.IndicesBuffer));
    }

    private void CreateBuffers(BooleanArray array)
    {
        _buffers.Add(CreateBuffer(array.NullBitmapBuffer));
        _buffers.Add(CreateBuffer(array.ValueBuffer));
    }

    private void CreateBuffers<T>(PrimitiveArray<T> array)
        where T : struct
    {
        _buffers.Add(CreateBuffer(array.NullBitmapBuffer));
        _buffers.Add(CreateBuffer(array.ValueBuffer));
    }

    /// <summary>
    /// Assigns the next offset to <paramref name="buffer"/> and advances <see cref="TotalLength"/>
    /// by the buffer's padded length.
    /// </summary>
    /// <exception cref="OverflowException">
    /// The padded length, or the running total, does not fit in an <see cref="int"/>.
    /// </exception>
    private Buffer CreateBuffer(ArrowBuffer buffer)
    {
        int offset = TotalLength;

        int paddedLength = checked((int)BitUtility.RoundUpToMultipleOf8(buffer.Length));

        // The accumulation must be checked as well: an unchecked sum would silently wrap to a
        // negative total and produce corrupt offsets for very large bodies.
        TotalLength = checked(TotalLength + paddedLength);

        return new Buffer(buffer, offset);
    }

    /// <summary>
    /// Fallback for array types that have no dedicated overload on this visitor.
    /// </summary>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public void Visit(IArrowArray array)
    {
        throw new NotImplementedException(
            $"No buffer layout is implemented for array type '{array?.GetType().Name}'.");
    }
}
177 
/// <summary>Metadata version constant shared with derived writers.</summary>
private protected const Flatbuf.MetadataVersion CurrentMetadataVersion = Flatbuf.MetadataVersion.V4;

// 64 zero bytes; presumably the source for padding writes (the padding helpers are not
// visible in this part of the file).
private static readonly byte[] s_padding = new byte[64];

// Captured from the constructor arguments.
private readonly bool _leaveOpen;
private readonly IpcOptions _options;

// Builds flatbuffer field types using the same FlatBufferBuilder as this writer.
private readonly ArrowTypeFlatbufferBuilder _fieldTypeBuilder;

// Backing store for DictionaryMemo; also passed by reference to DictionaryCollector.Collect.
private DictionaryMemo _dictionaryMemo;

/// <summary>The stream every message is written to.</summary>
protected Stream BaseStream { get; }

/// <summary>Byte array pool created for this writer instance.</summary>
protected ArrayPool<byte> Buffers { get; }

/// <summary>FlatBuffer builder reused for every message's metadata.</summary>
private protected FlatBufferBuilder Builder { get; }

/// <summary>The schema supplied at construction.</summary>
protected Schema Schema { get; }

/// <summary>True once the schema message has been written.</summary>
protected bool HasWrittenSchema { get; set; }

// True once the dictionaries of the first record batch have been written.
private bool HasWrittenDictionaryBatch { get; set; }

// Guards so that WriteStart / WriteEnd only do their work once.
private bool HasWrittenStart { get; set; }

private bool HasWrittenEnd { get; set; }

// Lazily created dictionary memo.
private DictionaryMemo DictionaryMemo
{
    get
    {
        if (_dictionaryMemo is null)
        {
            _dictionaryMemo = new DictionaryMemo();
        }

        return _dictionaryMemo;
    }
}
205 
/// <summary>
/// Creates a writer with <c>leaveOpen</c> set to <c>false</c> and default IPC options.
/// </summary>
/// <param name="baseStream">Destination stream; must not be null.</param>
/// <param name="schema">Schema of the batches to be written; must not be null.</param>
public ArrowStreamWriter(Stream baseStream, Schema schema)
    : this(baseStream, schema, leaveOpen: false, options: null)
{
}
210 
/// <summary>
/// Creates a writer with default IPC options.
/// </summary>
/// <param name="baseStream">Destination stream; must not be null.</param>
/// <param name="schema">Schema of the batches to be written; must not be null.</param>
/// <param name="leaveOpen">Stored by the writer; forwarded unchanged to the full constructor.</param>
public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen)
    : this(
        baseStream,
        schema,
        leaveOpen,
        options: null)
{
}
215 
/// <summary>
/// Creates a writer over <paramref name="baseStream"/> for batches of <paramref name="schema"/>.
/// </summary>
/// <param name="baseStream">Destination stream.</param>
/// <param name="schema">Schema of the batches to be written.</param>
/// <param name="leaveOpen">Stored in <c>_leaveOpen</c>.</param>
/// <param name="options">IPC options; <c>IpcOptions.Default</c> is used when null.</param>
/// <exception cref="ArgumentNullException">
/// <paramref name="baseStream"/> or <paramref name="schema"/> is null.
/// </exception>
public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen, IpcOptions options)
{
    if (baseStream is null)
    {
        throw new ArgumentNullException(nameof(baseStream));
    }

    if (schema is null)
    {
        throw new ArgumentNullException(nameof(schema));
    }

    BaseStream = baseStream;
    Schema = schema;
    _leaveOpen = leaveOpen;

    Buffers = ArrayPool<byte>.Create();
    Builder = new FlatBufferBuilder(1024);
    HasWrittenSchema = false;

    // The type builder shares the FlatBufferBuilder created above.
    _fieldTypeBuilder = new ArrowTypeFlatbufferBuilder(Builder);
    _options = options is null ? IpcOptions.Default : options;
}
229 
230 
/// <summary>
/// Emits a flatbuffer field node for <paramref name="data"/>, preceded by the nodes of all of
/// its children when its data type is a nested type.
/// </summary>
/// <param name="data">Array data whose length and null count are recorded.</param>
private void CreateSelfAndChildrenFieldNodes(ArrayData data)
{
    if (data.DataType is NestedType)
    {
        // flatbuffer struct vectors have to be created in reverse order,
        // so the last child is emitted first
        for (int remaining = data.Children.Length; remaining > 0; remaining--)
        {
            CreateSelfAndChildrenFieldNodes(data.Children[remaining - 1]);
        }
    }

    Flatbuf.FieldNode.CreateFieldNode(Builder, data.Length, data.NullCount);
}
243 
/// <summary>
/// Counts the field nodes needed for <paramref name="fields"/>: one per field plus one for
/// every field nested beneath it.
/// </summary>
/// <param name="fields">Fields to count, keyed by name.</param>
/// <returns>The total number of nodes.</returns>
private static int CountAllNodes(IReadOnlyDictionary<string, Field> fields)
{
    int nodeTotal = 0;

    foreach (Field field in fields.Values)
    {
        CountSelfAndChildrenNodes(field.DataType, ref nodeTotal);
    }

    return nodeTotal;
}
253 
/// <summary>
/// Adds to <paramref name="count"/> one node for <paramref name="type"/> and, recursively, one
/// node for each field nested inside it.
/// </summary>
/// <param name="type">Type whose nodes are counted.</param>
/// <param name="count">Running total, incremented in place.</param>
private static void CountSelfAndChildrenNodes(IArrowType type, ref int count)
{
    if (!(type is NestedType container))
    {
        // Leaf type: it contributes only itself.
        count++;
        return;
    }

    foreach (Field nested in container.Fields)
    {
        CountSelfAndChildrenNodes(nested.DataType, ref count);
    }

    count++;
}
265 
/// <summary>
/// Writes <paramref name="recordBatch"/> to the stream. The schema is written first if it has
/// not been written yet, and the batch's dictionaries are collected and written the first
/// time this method runs.
/// </summary>
/// <param name="recordBatch">The batch to serialize.</param>
private protected void WriteRecordBatchInternal(RecordBatch recordBatch)
{
    // TODO: Truncate buffers with extraneous padding / unused capacity

    if (!HasWrittenSchema)
    {
        WriteSchema(Schema);
        HasWrittenSchema = true;
    }

    if (!HasWrittenDictionaryBatch)
    {
        DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
        WriteDictionaries(recordBatch);
        HasWrittenDictionaryBatch = true;
    }

    Tuple<ArrowRecordBatchFlatBufferBuilder, VectorOffset> prepared = PreparingWritingRecordBatch(recordBatch);
    ArrowRecordBatchFlatBufferBuilder bufferCollector = prepared.Item1;
    VectorOffset nodesVector = prepared.Item2;

    // Closes the buffers vector that the preparation step left open.
    VectorOffset buffersVector = Builder.EndVector();

    // Serialize record batch

    StartingWritingRecordBatch();

    Offset<Flatbuf.RecordBatch> batchOffset = Flatbuf.RecordBatch.CreateRecordBatch(
        Builder,
        recordBatch.Length,
        nodesVector,
        buffersVector);

    long metadataLength = WriteMessage(
        Flatbuf.MessageHeader.RecordBatch,
        batchOffset,
        bufferCollector.TotalLength);

    long bufferLength = WriteBufferData(bufferCollector.Buffers);

    FinishedWritingRecordBatch(bufferLength, metadataLength);
}
303 
/// <summary>
/// Asynchronously writes <paramref name="recordBatch"/> to the stream. The schema is written
/// first if it has not been written yet, and the batch's dictionaries are collected and
/// written the first time this method runs.
/// </summary>
/// <param name="recordBatch">The batch to serialize.</param>
/// <param name="cancellationToken">Token forwarded to the schema, dictionary, message and buffer writes.</param>
private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch,
    CancellationToken cancellationToken = default)
{
    // TODO: Truncate buffers with extraneous padding / unused capacity

    if (!HasWrittenSchema)
    {
        await WriteSchemaAsync(Schema, cancellationToken).ConfigureAwait(false);
        HasWrittenSchema = true;
    }

    if (!HasWrittenDictionaryBatch)
    {
        DictionaryCollector.Collect(recordBatch, ref _dictionaryMemo);
        await WriteDictionariesAsync(recordBatch, cancellationToken).ConfigureAwait(false);
        HasWrittenDictionaryBatch = true;
    }

    Tuple<ArrowRecordBatchFlatBufferBuilder, VectorOffset> prepared = PreparingWritingRecordBatch(recordBatch);
    ArrowRecordBatchFlatBufferBuilder bufferCollector = prepared.Item1;
    VectorOffset nodesVector = prepared.Item2;

    // Closes the buffers vector that the preparation step left open.
    VectorOffset buffersVector = Builder.EndVector();

    // Serialize record batch

    StartingWritingRecordBatch();

    Offset<Flatbuf.RecordBatch> batchOffset = Flatbuf.RecordBatch.CreateRecordBatch(
        Builder,
        recordBatch.Length,
        nodesVector,
        buffersVector);

    long metadataLength = await WriteMessageAsync(
        Flatbuf.MessageHeader.RecordBatch,
        batchOffset,
        bufferCollector.TotalLength,
        cancellationToken).ConfigureAwait(false);

    long bufferLength = await WriteBufferDataAsync(bufferCollector.Buffers, cancellationToken)
        .ConfigureAwait(false);

    FinishedWritingRecordBatch(bufferLength, metadataLength);
}
343 
/// <summary>
/// Writes every non-empty buffer in <paramref name="buffers"/> to the stream, padding each to
/// its rounded-up length, then pads the body as a whole.
/// </summary>
/// <param name="buffers">Buffers to write, in order.</param>
/// <returns>The number of body bytes accounted for, including all padding.</returns>
private long WriteBufferData(IReadOnlyList<ArrowRecordBatchFlatBufferBuilder.Buffer> buffers)
{
    long written = 0;

    foreach (ArrowRecordBatchFlatBufferBuilder.Buffer entry in buffers)
    {
        ArrowBuffer data = entry.DataBuffer;
        if (!data.IsEmpty)
        {
            WriteBuffer(data);

            int roundedLength = checked((int)BitUtility.RoundUpToMultipleOf8(data.Length));
            int gap = roundedLength - data.Length;
            if (gap > 0)
            {
                WritePadding(gap);
            }

            written += roundedLength;
        }
    }

    // Write padding so the record batch message body length is a multiple of 8 bytes

    int trailingPadding = CalculatePadding(written);

    WritePadding(trailingPadding);

    return written + trailingPadding;
}
374 
/// <summary>
/// Asynchronously writes every non-empty buffer in <paramref name="buffers"/> to the stream,
/// padding each to its rounded-up length, then pads the body as a whole.
/// </summary>
/// <param name="buffers">Buffers to write, in order.</param>
/// <param name="cancellationToken">Token forwarded to the buffer writes.</param>
/// <returns>The number of body bytes accounted for, including all padding.</returns>
private async ValueTask<long> WriteBufferDataAsync(IReadOnlyList<ArrowRecordBatchFlatBufferBuilder.Buffer> buffers, CancellationToken cancellationToken = default)
{
    long written = 0;

    foreach (ArrowRecordBatchFlatBufferBuilder.Buffer entry in buffers)
    {
        ArrowBuffer data = entry.DataBuffer;
        if (!data.IsEmpty)
        {
            await WriteBufferAsync(data, cancellationToken).ConfigureAwait(false);

            int roundedLength = checked((int)BitUtility.RoundUpToMultipleOf8(data.Length));
            int gap = roundedLength - data.Length;
            if (gap > 0)
            {
                await WritePaddingAsync(gap).ConfigureAwait(false);
            }

            written += roundedLength;
        }
    }

    // Write padding so the record batch message body length is a multiple of 8 bytes

    int trailingPadding = CalculatePadding(written);

    await WritePaddingAsync(trailingPadding).ConfigureAwait(false);

    return written + trailingPadding;
}
405 
/// <summary>
/// Prepares the field nodes and buffers of <paramref name="recordBatch"/> by delegating to the
/// overload that takes the batch's schema fields and arrays.
/// </summary>
/// <param name="recordBatch">The batch about to be written.</param>
/// <returns>The buffer collector and the offset of the finished field-nodes vector.</returns>
private Tuple<ArrowRecordBatchFlatBufferBuilder, VectorOffset> PreparingWritingRecordBatch(RecordBatch recordBatch)
{
    var schemaFields = recordBatch.Schema.Fields;
    var columns = recordBatch.ArrayList;

    return PreparingWritingRecordBatch(schemaFields, columns);
}
410 
/// <summary>
/// Clears the flatbuffer builder, serializes the field nodes of <paramref name="arrays"/>,
/// collects their buffers and starts (but does not end) the flatbuffer buffers vector.
/// </summary>
/// <remarks>
/// The caller is expected to call <c>Builder.EndVector()</c> afterwards to close the buffers
/// vector, as the callers in this file do.
/// </remarks>
/// <param name="fields">Fields describing the arrays; their count bounds the arrays that are read.</param>
/// <param name="arrays">Arrays to serialize, indexed in field order.</param>
/// <returns>The buffer collector and the offset of the finished field-nodes vector.</returns>
private Tuple<ArrowRecordBatchFlatBufferBuilder, VectorOffset> PreparingWritingRecordBatch(IReadOnlyDictionary<string, Field> fields, IReadOnlyList<IArrowArray> arrays)
{
    Builder.Clear();

    // Serialize field nodes

    int columnCount = fields.Count;

    Flatbuf.RecordBatch.StartNodesVector(Builder, CountAllNodes(fields));

    // flatbuffer struct vectors have to be created in reverse order
    for (int column = columnCount; column > 0; column--)
    {
        CreateSelfAndChildrenFieldNodes(arrays[column - 1].Data);
    }

    VectorOffset nodesVector = Builder.EndVector();

    // Serialize buffers

    var bufferCollector = new ArrowRecordBatchFlatBufferBuilder();
    for (int column = 0; column < columnCount; column++)
    {
        arrays[column].Accept(bufferCollector);
    }

    IReadOnlyList<ArrowRecordBatchFlatBufferBuilder.Buffer> collected = bufferCollector.Buffers;

    Flatbuf.RecordBatch.StartBuffersVector(Builder, collected.Count);

    // flatbuffer struct vectors have to be created in reverse order
    for (int slot = collected.Count; slot > 0; slot--)
    {
        ArrowRecordBatchFlatBufferBuilder.Buffer entry = collected[slot - 1];
        Flatbuf.Buffer.CreateBuffer(Builder, entry.Offset, entry.DataBuffer.Length);
    }

    return new Tuple<ArrowRecordBatchFlatBufferBuilder, VectorOffset>(bufferCollector, nodesVector);
}
451 
452 
/// <summary>
/// Calls <c>WriteDictionary</c> for every field in the schema of <paramref name="recordBatch"/>.
/// </summary>
/// <param name="recordBatch">Batch whose schema fields are walked.</param>
private protected void WriteDictionaries(RecordBatch recordBatch)
{
    var schemaFields = recordBatch.Schema.Fields.Values;

    foreach (Field schemaField in schemaFields)
    {
        WriteDictionary(schemaField);
    }
}
460 
/// <summary>
/// Writes the dictionary batch for <paramref name="field"/> when it is dictionary-typed;
/// otherwise recurses into the child fields of a nested type and writes nothing for other types.
/// </summary>
/// <param name="field">Field to inspect.</param>
private protected void WriteDictionary(Field field)
{
    if (field.DataType.TypeId == ArrowTypeId.Dictionary)
    {
        Tuple<ArrowRecordBatchFlatBufferBuilder, Offset<Flatbuf.DictionaryBatch>> prepared =
            CreateDictionaryBatchOffset(field);
        ArrowRecordBatchFlatBufferBuilder bufferCollector = prepared.Item1;
        Offset<Flatbuf.DictionaryBatch> batchOffset = prepared.Item2;

        WriteMessage(Flatbuf.MessageHeader.DictionaryBatch, batchOffset, bufferCollector.TotalLength);

        WriteBufferData(bufferCollector.Buffers);
        return;
    }

    // Not a dictionary itself: look for dictionaries among nested children.
    if (field.DataType is NestedType parent)
    {
        foreach (Field nested in parent.Fields)
        {
            WriteDictionary(nested);
        }
    }
}
483 
/// <summary>
/// Awaits <c>WriteDictionaryAsync</c> for every field in the schema of
/// <paramref name="recordBatch"/>, one field at a time.
/// </summary>
/// <param name="recordBatch">Batch whose schema fields are walked.</param>
/// <param name="cancellationToken">Token forwarded to each dictionary write.</param>
private protected async Task WriteDictionariesAsync(RecordBatch recordBatch, CancellationToken cancellationToken)
{
    var schemaFields = recordBatch.Schema.Fields.Values;

    foreach (Field schemaField in schemaFields)
    {
        await WriteDictionaryAsync(schemaField, cancellationToken).ConfigureAwait(false);
    }
}
491 
/// <summary>
/// Asynchronously writes the dictionary batch for <paramref name="field"/> when it is
/// dictionary-typed; otherwise recurses into the child fields of a nested type and writes
/// nothing for other types.
/// </summary>
/// <param name="field">Field to inspect.</param>
/// <param name="cancellationToken">Token forwarded to the message and buffer writes.</param>
private protected async Task WriteDictionaryAsync(Field field, CancellationToken cancellationToken)
{
    if (field.DataType.TypeId == ArrowTypeId.Dictionary)
    {
        Tuple<ArrowRecordBatchFlatBufferBuilder, Offset<Flatbuf.DictionaryBatch>> prepared =
            CreateDictionaryBatchOffset(field);
        ArrowRecordBatchFlatBufferBuilder bufferCollector = prepared.Item1;
        Offset<Flatbuf.DictionaryBatch> batchOffset = prepared.Item2;

        await WriteMessageAsync(
            Flatbuf.MessageHeader.DictionaryBatch,
            batchOffset,
            bufferCollector.TotalLength,
            cancellationToken).ConfigureAwait(false);

        await WriteBufferDataAsync(bufferCollector.Buffers, cancellationToken).ConfigureAwait(false);
        return;
    }

    // Not a dictionary itself: look for dictionaries among nested children.
    if (field.DataType is NestedType parent)
    {
        foreach (Field nested in parent.Fields)
        {
            await WriteDictionaryAsync(nested, cancellationToken).ConfigureAwait(false);
        }
    }
}
514 
/// <summary>
/// Builds the flatbuffer dictionary batch for the dictionary registered for
/// <paramref name="field"/> in the dictionary memo.
/// </summary>
/// <remarks>
/// The dictionary values are serialized as a one-column record batch described by a
/// placeholder field named "dummy" with the dictionary's value type.
/// </remarks>
/// <param name="field">A field whose data type is a <c>DictionaryType</c>.</param>
/// <returns>The buffer collector for the dictionary values and the dictionary batch offset.</returns>
private Tuple<ArrowRecordBatchFlatBufferBuilder, Offset<Flatbuf.DictionaryBatch>> CreateDictionaryBatchOffset(Field field)
{
    IArrowType valueType = ((DictionaryType)field.DataType).ValueType;
    Field placeholderField = new Field("dummy", valueType, false);

    long dictionaryId = DictionaryMemo.GetId(field);
    IArrowArray dictionaryValues = DictionaryMemo.GetDictionary(dictionaryId);

    var singleField = new Dictionary<string, Field>();
    singleField.Add(placeholderField.Name, placeholderField);

    var singleArray = new List<IArrowArray>();
    singleArray.Add(dictionaryValues);

    Tuple<ArrowRecordBatchFlatBufferBuilder, VectorOffset> prepared =
        PreparingWritingRecordBatch(singleField, singleArray);
    ArrowRecordBatchFlatBufferBuilder bufferCollector = prepared.Item1;
    VectorOffset nodesVector = prepared.Item2;

    // Closes the buffers vector that the preparation step left open.
    VectorOffset buffersVector = Builder.EndVector();

    // Serialize record batch
    Offset<Flatbuf.RecordBatch> valuesBatchOffset = Flatbuf.RecordBatch.CreateRecordBatch(
        Builder,
        dictionaryValues.Length,
        nodesVector,
        buffersVector);

    // TODO: Support delta.
    Offset<Flatbuf.DictionaryBatch> batchOffset =
        Flatbuf.DictionaryBatch.CreateDictionaryBatch(Builder, dictionaryId, valuesBatchOffset, false);

    return Tuple.Create(bufferCollector, batchOffset);
}
540 
/// <summary>
/// Writes the schema unless it has already been written, and records that it has been.
/// </summary>
private protected virtual void WriteStartInternal()
{
    if (HasWrittenSchema)
    {
        return;
    }

    WriteSchema(Schema);
    HasWrittenSchema = true;
}
549 
/// <summary>
/// Asynchronously writes the schema unless it has already been written, and records that it
/// has been.
/// </summary>
/// <param name="cancellationToken">Token forwarded to the schema write.</param>
private protected async virtual ValueTask WriteStartInternalAsync(CancellationToken cancellationToken)
{
    if (HasWrittenSchema)
    {
        return;
    }

    await WriteSchemaAsync(Schema, cancellationToken).ConfigureAwait(false);
    HasWrittenSchema = true;
}
558 
/// <summary>
/// Ends the stream by writing an IPC message length of zero.
/// </summary>
private protected virtual void WriteEndInternal() => WriteIpcMessageLength(length: 0);
563 
/// <summary>
/// Ends the stream by asynchronously writing an IPC message length of zero.
/// </summary>
/// <param name="cancellationToken">Token forwarded to the write.</param>
/// <returns>The pending write operation.</returns>
private protected virtual ValueTask WriteEndInternalAsync(CancellationToken cancellationToken)
    => WriteIpcMessageLengthAsync(length: 0, cancellationToken);
568 
/// <summary>
/// Hook invoked just before a record batch's flatbuffer table is created.
/// The base implementation does nothing.
/// </summary>
private protected virtual void StartingWritingRecordBatch()
{
    // Intentionally empty: an extension point for derived writers.
}
572 
/// <summary>
/// Hook invoked after a record batch's message and buffers have been written.
/// The base implementation does nothing.
/// </summary>
/// <param name="bodyLength">Length reported by the buffer-data write, including padding.</param>
/// <param name="metadataLength">Length reported by the message write.</param>
private protected virtual void FinishedWritingRecordBatch(long bodyLength, long metadataLength)
{
    // Intentionally empty: an extension point for derived writers.
}
576 
/// <summary>
/// Writes <paramref name="recordBatch"/> to the underlying stream.
/// </summary>
/// <param name="recordBatch">The batch to write.</param>
public virtual void WriteRecordBatch(RecordBatch recordBatch)
    => WriteRecordBatchInternal(recordBatch);
581 
/// <summary>
/// Asynchronously writes <paramref name="recordBatch"/> to the underlying stream.
/// </summary>
/// <param name="recordBatch">The batch to write.</param>
/// <param name="cancellationToken">Token forwarded to the internal write.</param>
/// <returns>A task that completes when the batch has been written.</returns>
public virtual Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default)
    => WriteRecordBatchInternalAsync(recordBatch, cancellationToken);
586 
/// <summary>
/// Writes the start of the stream. Calls after the first successful one do nothing.
/// </summary>
public void WriteStart()
{
    if (HasWrittenStart)
    {
        return;
    }

    WriteStartInternal();
    HasWrittenStart = true;
}
595 
/// <summary>
/// Asynchronously writes the start of the stream. Calls after the first successful one do
/// nothing.
/// </summary>
/// <param name="cancellationToken">Token forwarded to the internal start write.</param>
/// <returns>A task that completes when the start has been written.</returns>
public async Task WriteStartAsync(CancellationToken cancellationToken = default)
{
    if (!HasWrittenStart)
    {
        // ConfigureAwait(false) matches every other await in this writer: the continuation
        // does not need the caller's synchronization context.
        await WriteStartInternalAsync(cancellationToken).ConfigureAwait(false);
        HasWrittenStart = true;
    }
}
604 
/// <summary>
/// Writes the end of the stream. Calls after the first successful one do nothing.
/// </summary>
public void WriteEnd()
{
    if (HasWrittenEnd)
    {
        return;
    }

    WriteEndInternal();
    HasWrittenEnd = true;
}
613 
/// <summary>
/// Asynchronously writes the end of the stream. Calls after the first successful one do
/// nothing.
/// </summary>
/// <param name="cancellationToken">Token forwarded to the internal end write.</param>
/// <returns>A task that completes when the end has been written.</returns>
public async Task WriteEndAsync(CancellationToken cancellationToken = default)
{
    if (!HasWrittenEnd)
    {
        // ConfigureAwait(false) matches every other await in this writer: the continuation
        // does not need the caller's synchronization context.
        await WriteEndInternalAsync(cancellationToken).ConfigureAwait(false);
        HasWrittenEnd = true;
    }
}
622 
/// <summary>
/// Writes the memory of <paramref name="arrowBuffer"/> to the base stream.
/// </summary>
/// <param name="arrowBuffer">Buffer whose contents are written.</param>
private void WriteBuffer(ArrowBuffer arrowBuffer) => BaseStream.Write(arrowBuffer.Memory);
627 
/// <summary>
/// Asynchronously writes the memory of <paramref name="arrowBuffer"/> to the base stream.
/// </summary>
/// <param name="arrowBuffer">Buffer whose contents are written.</param>
/// <param name="cancellationToken">Token forwarded to the stream write.</param>
/// <returns>The pending write operation.</returns>
private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken cancellationToken = default)
    => BaseStream.WriteAsync(arrowBuffer.Memory, cancellationToken);
632 
/// <summary>
/// Serializes <paramref name="schema"/> (custom metadata, fields and endianness) into the
/// flatbuffer builder.
/// </summary>
/// <param name="schema">Schema to serialize.</param>
/// <returns>Offset of the flatbuffer schema table.</returns>
private protected Offset<Flatbuf.Schema> SerializeSchema(Schema schema)
{
    // Build metadata
    VectorOffset schemaMetadataVector = schema.HasMetadata
        ? Flatbuf.Schema.CreateCustomMetadataVector(Builder, GetMetadataOffsets(schema.Metadata))
        : default(VectorOffset);

    // Build fields
    int fieldCount = schema.Fields.Count;
    var serializedFields = new Offset<Flatbuf.Field>[fieldCount];

    for (int index = 0; index < fieldCount; index++)
    {
        Field current = schema.GetFieldByIndex(index);

        // Every nested object is created before the Field table that refers to it.
        StringOffset nameOffset = Builder.CreateString(current.Name);
        ArrowTypeFlatbufferBuilder.FieldType typeInfo = _fieldTypeBuilder.BuildFieldType(current);
        VectorOffset childrenVector = GetChildrenFieldOffset(current);
        VectorOffset fieldMetadataVector = GetFieldMetadataOffset(current);
        Offset<Flatbuf.DictionaryEncoding> dictionaryEncoding = GetDictionaryOffset(current);

        serializedFields[index] = Flatbuf.Field.CreateField(
            Builder,
            nameOffset,
            current.IsNullable,
            typeInfo.Type,
            typeInfo.Offset,
            dictionaryEncoding,
            childrenVector,
            fieldMetadataVector);
    }

    VectorOffset fieldsVector = Flatbuf.Schema.CreateFieldsVector(Builder, serializedFields);

    // Build schema

    Flatbuf.Endianness byteOrder;
    if (BitConverter.IsLittleEndian)
    {
        byteOrder = Flatbuf.Endianness.Little;
    }
    else
    {
        byteOrder = Flatbuf.Endianness.Big;
    }

    return Flatbuf.Schema.CreateSchema(Builder, byteOrder, fieldsVector, schemaMetadataVector);
}
669 
/// <summary>
/// Serializes the child fields of <paramref name="field"/> into a flatbuffer vector of tables.
/// For a dictionary-typed field the children of the dictionary's value type are used.
/// </summary>
/// <param name="field">Field whose children are serialized.</param>
/// <returns>
/// Offset of the children vector, or the default offset when the relevant type is not nested.
/// </returns>
private VectorOffset GetChildrenFieldOffset(Field field)
{
    IArrowType effectiveType;
    if (field.DataType is DictionaryType encoded)
    {
        effectiveType = encoded.ValueType;
    }
    else
    {
        effectiveType = field.DataType;
    }

    if (!(effectiveType is NestedType container))
    {
        return default;
    }

    int count = container.Fields.Count;
    var serializedChildren = new Offset<Flatbuf.Field>[count];

    for (int index = 0; index < count; index++)
    {
        Field child = container.Fields[index];

        // Every nested object is created before the Field table that refers to it.
        StringOffset nameOffset = Builder.CreateString(child.Name);
        ArrowTypeFlatbufferBuilder.FieldType typeInfo = _fieldTypeBuilder.BuildFieldType(child);
        VectorOffset grandChildrenVector = GetChildrenFieldOffset(child);
        VectorOffset metadataVector = GetFieldMetadataOffset(child);
        Offset<Flatbuf.DictionaryEncoding> dictionaryEncoding = GetDictionaryOffset(child);

        serializedChildren[index] = Flatbuf.Field.CreateField(
            Builder,
            nameOffset,
            child.IsNullable,
            typeInfo.Type,
            typeInfo.Offset,
            dictionaryEncoding,
            grandChildrenVector,
            metadataVector);
    }

    return Builder.CreateVectorOfTables(serializedChildren);
}
701 
/// <summary>
/// Serializes the custom metadata of <paramref name="field"/>, if it has any.
/// </summary>
/// <param name="field">Field whose metadata is serialized.</param>
/// <returns>
/// Offset of the metadata vector, or the default offset when the field has no metadata.
/// </returns>
private VectorOffset GetFieldMetadataOffset(Field field)
{
    if (field.HasMetadata)
    {
        Offset<Flatbuf.KeyValue>[] entries = GetMetadataOffsets(field.Metadata);
        return Flatbuf.Field.CreateCustomMetadataVector(Builder, entries);
    }

    return default;
}
712 
GetDictionaryOffset(Field field)713         private Offset<Flatbuf.DictionaryEncoding> GetDictionaryOffset(Field field)
714         {
715             if (field.DataType.TypeId != ArrowTypeId.Dictionary)
716             {
717                 return default;
718             }
719 
720             long id = DictionaryMemo.GetOrAssignId(field);
721             var dicType = field.DataType as DictionaryType;
722             var indexType = dicType.IndexType as NumberType;
723 
724             Offset<Flatbuf.Int> indexOffset = Flatbuf.Int.CreateInt(Builder, indexType.BitWidth, indexType.IsSigned);
725             return Flatbuf.DictionaryEncoding.CreateDictionaryEncoding(Builder, id, indexOffset, dicType.Ordered);
726         }
727 
GetMetadataOffsets(IReadOnlyDictionary<string, string> metadata)728         private Offset<Flatbuf.KeyValue>[] GetMetadataOffsets(IReadOnlyDictionary<string, string> metadata)
729         {
730             Debug.Assert(metadata != null);
731             Debug.Assert(metadata.Count > 0);
732 
733             Offset<Flatbuf.KeyValue>[] metadataOffsets = new Offset<Flatbuf.KeyValue>[metadata.Count];
734             int index = 0;
735             foreach (KeyValuePair<string, string> metadatum in metadata)
736             {
737                 StringOffset keyOffset = Builder.CreateString(metadatum.Key);
738                 StringOffset valueOffset = Builder.CreateString(metadatum.Value);
739 
740                 metadataOffsets[index++] = Flatbuf.KeyValue.CreateKeyValue(Builder, keyOffset, valueOffset);
741             }
742 
743             return metadataOffsets;
744         }
745 
WriteSchema(Schema schema)746         private Offset<Flatbuf.Schema> WriteSchema(Schema schema)
747         {
748             Builder.Clear();
749 
750             // Build schema
751 
752             Offset<Flatbuf.Schema> schemaOffset = SerializeSchema(schema);
753 
754             // Build message
755 
756             WriteMessage(Flatbuf.MessageHeader.Schema, schemaOffset, 0);
757 
758             return schemaOffset;
759         }
760 
WriteSchemaAsync(Schema schema, CancellationToken cancellationToken)761         private async ValueTask<Offset<Flatbuf.Schema>> WriteSchemaAsync(Schema schema, CancellationToken cancellationToken)
762         {
763             Builder.Clear();
764 
765             // Build schema
766 
767             Offset<Flatbuf.Schema> schemaOffset = SerializeSchema(schema);
768 
769             // Build message
770 
771             await WriteMessageAsync(Flatbuf.MessageHeader.Schema, schemaOffset, 0, cancellationToken)
772                 .ConfigureAwait(false);
773 
774             return schemaOffset;
775         }
776 
777         /// <summary>
778         /// Writes the message to the <see cref="BaseStream"/>.
779         /// </summary>
780         /// <returns>
781         /// The number of bytes written to the stream.
782         /// </returns>
783         private protected long WriteMessage<T>(
784             Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int bodyLength)
785             where T : struct
786         {
787             Offset<Flatbuf.Message> messageOffset = Flatbuf.Message.CreateMessage(
788                 Builder, CurrentMetadataVersion, headerType, headerOffset.Value,
789                 bodyLength);
790 
Builder.FinishApache.Arrow.Ipc.ArrowStreamWriter.__anon2791             Builder.Finish(messageOffset.Value);
792 
793             ReadOnlyMemory<byte> messageData = Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset);
794             int messagePaddingLength = CalculatePadding(_options.SizeOfIpcLength + messageData.Length);
795 
796             WriteIpcMessageLength(messageData.Length + messagePaddingLength);
797 
BaseStream.WriteApache.Arrow.Ipc.ArrowStreamWriter.__anon2798             BaseStream.Write(messageData);
WritePaddingApache.Arrow.Ipc.ArrowStreamWriter.__anon2799             WritePadding(messagePaddingLength);
800 
801             checked
802             {
803                 return _options.SizeOfIpcLength + messageData.Length + messagePaddingLength;
804             }
805         }
806 
807         /// <summary>
808         /// Writes the message to the <see cref="BaseStream"/>.
809         /// </summary>
810         /// <returns>
811         /// The number of bytes written to the stream.
812         /// </returns>
813         private protected virtual async ValueTask<long> WriteMessageAsync<T>(
814             Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int bodyLength,
815             CancellationToken cancellationToken)
816             where T : struct
817         {
818             Offset<Flatbuf.Message> messageOffset = Flatbuf.Message.CreateMessage(
819                 Builder, CurrentMetadataVersion, headerType, headerOffset.Value,
820                 bodyLength);
821 
Builder.FinishApache.Arrow.Ipc.ArrowStreamWriter.__anon3822             Builder.Finish(messageOffset.Value);
823 
824             ReadOnlyMemory<byte> messageData = Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset);
825             int messagePaddingLength = CalculatePadding(_options.SizeOfIpcLength + messageData.Length);
826 
827             await WriteIpcMessageLengthAsync(messageData.Length + messagePaddingLength, cancellationToken)
ConfigureAwaitApache.Arrow.Ipc.ArrowStreamWriter.__anon3828                 .ConfigureAwait(false);
829 
ConfigureAwaitApache.Arrow.Ipc.ArrowStreamWriter.__anon3830             await BaseStream.WriteAsync(messageData, cancellationToken).ConfigureAwait(false);
ConfigureAwaitApache.Arrow.Ipc.ArrowStreamWriter.__anon3831             await WritePaddingAsync(messagePaddingLength).ConfigureAwait(false);
832 
833             checked
834             {
835                 return _options.SizeOfIpcLength + messageData.Length + messagePaddingLength;
836             }
837         }
838 
WriteFlatBuffer()839         private protected void WriteFlatBuffer()
840         {
841             ReadOnlyMemory<byte> segment = Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset);
842 
843             BaseStream.Write(segment);
844         }
845 
WriteFlatBufferAsync(CancellationToken cancellationToken = default)846         private protected async ValueTask WriteFlatBufferAsync(CancellationToken cancellationToken = default)
847         {
848             ReadOnlyMemory<byte> segment = Builder.DataBuffer.ToReadOnlyMemory(Builder.DataBuffer.Position, Builder.Offset);
849 
850             await BaseStream.WriteAsync(segment, cancellationToken).ConfigureAwait(false);
851         }
852 
WriteIpcMessageLength(int length)853         private void WriteIpcMessageLength(int length)
854         {
855             Buffers.RentReturn(_options.SizeOfIpcLength, (buffer) =>
856             {
857                 Memory<byte> currentBufferPosition = buffer;
858                 if (!_options.WriteLegacyIpcFormat)
859                 {
860                     BinaryPrimitives.WriteInt32LittleEndian(
861                         currentBufferPosition.Span, MessageSerializer.IpcContinuationToken);
862                     currentBufferPosition = currentBufferPosition.Slice(sizeof(int));
863                 }
864 
865                 BinaryPrimitives.WriteInt32LittleEndian(currentBufferPosition.Span, length);
866                 BaseStream.Write(buffer);
867             });
868         }
869 
WriteIpcMessageLengthAsync(int length, CancellationToken cancellationToken)870         private async ValueTask WriteIpcMessageLengthAsync(int length, CancellationToken cancellationToken)
871         {
872             await Buffers.RentReturnAsync(_options.SizeOfIpcLength, async (buffer) =>
873             {
874                 Memory<byte> currentBufferPosition = buffer;
875                 if (!_options.WriteLegacyIpcFormat)
876                 {
877                     BinaryPrimitives.WriteInt32LittleEndian(
878                         currentBufferPosition.Span, MessageSerializer.IpcContinuationToken);
879                     currentBufferPosition = currentBufferPosition.Slice(sizeof(int));
880                 }
881 
882                 BinaryPrimitives.WriteInt32LittleEndian(currentBufferPosition.Span, length);
883                 await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
884             }).ConfigureAwait(false);
885         }
886 
CalculatePadding(long offset, int alignment = 8)887         protected int CalculatePadding(long offset, int alignment = 8)
888         {
889             long result = BitUtility.RoundUpToMultiplePowerOfTwo(offset, alignment) - offset;
890             checked
891             {
892                 return (int)result;
893             }
894         }
895 
WritePadding(int length)896         private protected void WritePadding(int length)
897         {
898             if (length > 0)
899             {
900                 BaseStream.Write(s_padding.AsMemory(0, Math.Min(s_padding.Length, length)));
901             }
902         }
903 
WritePaddingAsync(int length)904         private protected ValueTask WritePaddingAsync(int length)
905         {
906             if (length > 0)
907             {
908                 return BaseStream.WriteAsync(s_padding.AsMemory(0, Math.Min(s_padding.Length, length)));
909             }
910 
911             return default;
912         }
913 
Dispose()914         public virtual void Dispose()
915         {
916             if (!_leaveOpen)
917             {
918                 BaseStream.Dispose();
919             }
920         }
921     }
922 
923     internal static class DictionaryCollector
924     {
Collect(RecordBatch recordBatch, ref DictionaryMemo dictionaryMemo)925         internal static void Collect(RecordBatch recordBatch, ref DictionaryMemo dictionaryMemo)
926         {
927             Schema schema = recordBatch.Schema;
928             for (int i = 0; i < schema.Fields.Count; i++)
929             {
930                 Field field = schema.GetFieldByIndex(i);
931                 IArrowArray array = recordBatch.Column(i);
932 
933                 CollectDictionary(field, array.Data, ref dictionaryMemo);
934             }
935         }
936 
CollectDictionary(Field field, ArrayData arrayData, ref DictionaryMemo dictionaryMemo)937         private static void CollectDictionary(Field field, ArrayData arrayData, ref DictionaryMemo dictionaryMemo)
938         {
939             if (field.DataType is DictionaryType dictionaryType)
940             {
941                 if (arrayData.Dictionary == null)
942                 {
943                     throw new ArgumentException($"{nameof(arrayData.Dictionary)} must not be null");
944                 }
945                 arrayData.Dictionary.EnsureDataType(dictionaryType.ValueType.TypeId);
946 
947                 IArrowArray dictionary = ArrowArrayFactory.BuildArray(arrayData.Dictionary);
948 
949                 dictionaryMemo ??= new DictionaryMemo();
950                 long id = dictionaryMemo.GetOrAssignId(field);
951 
952                 dictionaryMemo.AddOrReplaceDictionary(id, dictionary);
953                 WalkChildren(dictionary.Data, ref dictionaryMemo);
954             }
955             else
956             {
957                 WalkChildren(arrayData, ref dictionaryMemo);
958             }
959         }
960 
WalkChildren(ArrayData arrayData, ref DictionaryMemo dictionaryMemo)961         private static void WalkChildren(ArrayData arrayData, ref DictionaryMemo dictionaryMemo)
962         {
963             ArrayData[] children = arrayData.Children;
964 
965             if (children == null)
966             {
967                 return;
968             }
969 
970             if (arrayData.DataType is NestedType nestedType)
971             {
972                 for (int i = 0; i < nestedType.Fields.Count; i++)
973                 {
974                     Field childField = nestedType.Fields[i];
975                     ArrayData child = children[i];
976 
977                     CollectDictionary(childField, child, ref dictionaryMemo);
978                 }
979             }
980         }
981     }
982 }
983