1 // Licensed to the Apache Software Foundation (ASF) under one or more
2 // contributor license agreements. See the NOTICE file distributed with
3 // this work for additional information regarding copyright ownership.
4 // The ASF licenses this file to You under the Apache License, Version 2.0
5 // (the "License"); you may not use this file except in compliance with
6 // the License.  You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 
16 using System;
17 using System.Buffers.Binary;
18 using System.Collections.Generic;
19 using System.IO;
20 using System.Linq;
21 using System.Net;
22 using System.Net.Sockets;
23 using System.Threading.Tasks;
24 using Apache.Arrow.Ipc;
25 using Apache.Arrow.Types;
26 using Xunit;
27 
28 namespace Apache.Arrow.Tests
29 {
30     public class ArrowStreamWriterTests
31     {
32         [Fact]
Ctor_LeaveOpenDefault_StreamClosedOnDispose()33         public void Ctor_LeaveOpenDefault_StreamClosedOnDispose()
34         {
35             RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
36             var stream = new MemoryStream();
37             new ArrowStreamWriter(stream, originalBatch.Schema).Dispose();
38             Assert.Throws<ObjectDisposedException>(() => stream.Position);
39         }
40 
41         [Fact]
Ctor_LeaveOpenFalse_StreamClosedOnDispose()42         public void Ctor_LeaveOpenFalse_StreamClosedOnDispose()
43         {
44             RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
45             var stream = new MemoryStream();
46             new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: false).Dispose();
47             Assert.Throws<ObjectDisposedException>(() => stream.Position);
48         }
49 
50         [Fact]
Ctor_LeaveOpenTrue_StreamValidOnDispose()51         public void Ctor_LeaveOpenTrue_StreamValidOnDispose()
52         {
53             RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
54             var stream = new MemoryStream();
55             new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true).Dispose();
56             Assert.Equal(0, stream.Position);
57         }
58 
59         [Fact]
CanWriteToNetworkStream()60         public void CanWriteToNetworkStream()
61         {
62             RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
63 
64             const int port = 32153;
65             TcpListener listener = new TcpListener(IPAddress.Loopback, port);
66             listener.Start();
67 
68             using (TcpClient sender = new TcpClient())
69             {
70                 sender.Connect(IPAddress.Loopback, port);
71                 NetworkStream stream = sender.GetStream();
72 
73                 using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema))
74                 {
75                     writer.WriteRecordBatch(originalBatch);
76                     writer.WriteEnd();
77 
78                     stream.Flush();
79                 }
80             }
81 
82             using (TcpClient receiver = listener.AcceptTcpClient())
83             {
84                 NetworkStream stream = receiver.GetStream();
85                 using (var reader = new ArrowStreamReader(stream))
86                 {
87                     RecordBatch newBatch = reader.ReadNextRecordBatch();
88                     ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
89                 }
90             }
91         }
92 
93         [Fact]
CanWriteToNetworkStreamAsync()94         public async Task CanWriteToNetworkStreamAsync()
95         {
96             RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
97 
98             const int port = 32154;
99             TcpListener listener = new TcpListener(IPAddress.Loopback, port);
100             listener.Start();
101 
102             using (TcpClient sender = new TcpClient())
103             {
104                 sender.Connect(IPAddress.Loopback, port);
105                 NetworkStream stream = sender.GetStream();
106 
107                 using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema))
108                 {
109                     await writer.WriteRecordBatchAsync(originalBatch);
110                     await writer.WriteEndAsync();
111 
112                     stream.Flush();
113                 }
114             }
115 
116             using (TcpClient receiver = listener.AcceptTcpClient())
117             {
118                 NetworkStream stream = receiver.GetStream();
119                 using (var reader = new ArrowStreamReader(stream))
120                 {
121                     RecordBatch newBatch = reader.ReadNextRecordBatch();
122                     ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
123                 }
124             }
125         }
126 
127         [Fact]
WriteEmptyBatch()128         public void WriteEmptyBatch()
129         {
130             RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0);
131 
132             TestRoundTripRecordBatch(originalBatch);
133         }
134 
135         [Fact]
WriteEmptyBatchAsync()136         public async Task WriteEmptyBatchAsync()
137         {
138             RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0);
139 
140             await TestRoundTripRecordBatchAsync(originalBatch);
141         }
142 
143         [Fact]
WriteBatchWithNulls()144         public void WriteBatchWithNulls()
145         {
146             RecordBatch originalBatch = new RecordBatch.Builder()
147                 .Append("Column1", false, col => col.Int32(array => array.AppendRange(Enumerable.Range(0, 10))))
148                 .Append("Column2", true, new Int32Array(
149                     valueBuffer: new ArrowBuffer.Builder<int>().AppendRange(Enumerable.Range(0, 10)).Build(),
150                     nullBitmapBuffer: new ArrowBuffer.Builder<byte>().Append(0xfd).Append(0xff).Build(),
151                     length: 10,
152                     nullCount: 2,
153                     offset: 0))
154                 .Append("Column3", true, new Int32Array(
155                     valueBuffer: new ArrowBuffer.Builder<int>().AppendRange(Enumerable.Range(0, 10)).Build(),
156                     nullBitmapBuffer: new ArrowBuffer.Builder<byte>().Append(0x00).Append(0x00).Build(),
157                     length: 10,
158                     nullCount: 10,
159                     offset: 0))
160                 .Append("NullableBooleanColumn", true, new BooleanArray(
161                     valueBuffer: new ArrowBuffer.Builder<byte>().Append(0xfd).Append(0xff).Build(),
162                     nullBitmapBuffer: new ArrowBuffer.Builder<byte>().Append(0xed).Append(0xff).Build(),
163                     length: 10,
164                     nullCount: 3,
165                     offset: 0))
166                 .Build();
167 
168             TestRoundTripRecordBatch(originalBatch);
169         }
170 
171         [Fact]
WriteBatchWithNullsAsync()172         public async Task WriteBatchWithNullsAsync()
173         {
174             RecordBatch originalBatch = new RecordBatch.Builder()
175                 .Append("Column1", false, col => col.Int32(array => array.AppendRange(Enumerable.Range(0, 10))))
176                 .Append("Column2", true, new Int32Array(
177                     valueBuffer: new ArrowBuffer.Builder<int>().AppendRange(Enumerable.Range(0, 10)).Build(),
178                     nullBitmapBuffer: new ArrowBuffer.Builder<byte>().Append(0xfd).Append(0xff).Build(),
179                     length: 10,
180                     nullCount: 2,
181                     offset: 0))
182                 .Append("Column3", true, new Int32Array(
183                     valueBuffer: new ArrowBuffer.Builder<int>().AppendRange(Enumerable.Range(0, 10)).Build(),
184                     nullBitmapBuffer: new ArrowBuffer.Builder<byte>().Append(0x00).Append(0x00).Build(),
185                     length: 10,
186                     nullCount: 10,
187                     offset: 0))
188                 .Append("NullableBooleanColumn", true, new BooleanArray(
189                     valueBuffer: new ArrowBuffer.Builder<byte>().Append(0xfd).Append(0xff).Build(),
190                     nullBitmapBuffer: new ArrowBuffer.Builder<byte>().Append(0xed).Append(0xff).Build(),
191                     length: 10,
192                     nullCount: 3,
193                     offset: 0))
194                 .Build();
195 
196             await TestRoundTripRecordBatchAsync(originalBatch);
197         }
198 
TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptions options = null)199         private static void TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptions options = null)
200         {
201             using (MemoryStream stream = new MemoryStream())
202             {
203                 using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options))
204                 {
205                     writer.WriteRecordBatch(originalBatch);
206                     writer.WriteEnd();
207                 }
208 
209                 stream.Position = 0;
210 
211                 using (var reader = new ArrowStreamReader(stream))
212                 {
213                     RecordBatch newBatch = reader.ReadNextRecordBatch();
214                     ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
215                 }
216             }
217         }
218 
219 
TestRoundTripRecordBatchAsync(RecordBatch originalBatch, IpcOptions options = null)220         private static async Task TestRoundTripRecordBatchAsync(RecordBatch originalBatch, IpcOptions options = null)
221         {
222             using (MemoryStream stream = new MemoryStream())
223             {
224                 using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options))
225                 {
226                     await writer.WriteRecordBatchAsync(originalBatch);
227                     await writer.WriteEndAsync();
228                 }
229 
230                 stream.Position = 0;
231 
232                 using (var reader = new ArrowStreamReader(stream))
233                 {
234                     RecordBatch newBatch = reader.ReadNextRecordBatch();
235                     ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
236                 }
237             }
238         }
239 
240         [Fact]
WriteBatchWithCorrectPadding()241         public void WriteBatchWithCorrectPadding()
242         {
243             byte value1 = 0x04;
244             byte value2 = 0x14;
245             var batch = new RecordBatch(
246                 new Schema.Builder()
247                     .Field(f => f.Name("age").DataType(Int32Type.Default))
248                     .Field(f => f.Name("characterCount").DataType(Int32Type.Default))
249                     .Build(),
250                 new IArrowArray[]
251                 {
252                     new Int32Array(
253                         new ArrowBuffer(new byte[] { value1, value1, 0x00, 0x00 }),
254                         ArrowBuffer.Empty,
255                         length: 1,
256                         nullCount: 0,
257                         offset: 0),
258                     new Int32Array(
259                         new ArrowBuffer(new byte[] { value2, value2, 0x00, 0x00 }),
260                         ArrowBuffer.Empty,
261                         length: 1,
262                         nullCount: 0,
263                         offset: 0)
264                 },
265                 length: 1);
266 
267             TestRoundTripRecordBatch(batch);
268 
269             using (MemoryStream stream = new MemoryStream())
270             {
271                 using (var writer = new ArrowStreamWriter(stream, batch.Schema, leaveOpen: true))
272                 {
273                     writer.WriteRecordBatch(batch);
274                     writer.WriteEnd();
275                 }
276 
277                 byte[] writtenBytes = stream.ToArray();
278 
279                 // ensure that the data buffers at the end are 8-byte aligned
280                 Assert.Equal(value1, writtenBytes[writtenBytes.Length - 24]);
281                 Assert.Equal(value1, writtenBytes[writtenBytes.Length - 23]);
282                 for (int i = 22; i > 16; i--)
283                 {
284                     Assert.Equal(0, writtenBytes[writtenBytes.Length - i]);
285                 }
286 
287                 Assert.Equal(value2, writtenBytes[writtenBytes.Length - 16]);
288                 Assert.Equal(value2, writtenBytes[writtenBytes.Length - 15]);
289                 for (int i = 14; i > 8; i--)
290                 {
291                     Assert.Equal(0, writtenBytes[writtenBytes.Length - i]);
292                 }
293 
294                 // verify the EOS is written correctly
295                 for (int i = 8; i > 4; i--)
296                 {
297                     Assert.Equal(0xFF, writtenBytes[writtenBytes.Length - i]);
298                 }
299                 for (int i = 4; i > 0; i--)
300                 {
301                     Assert.Equal(0x00, writtenBytes[writtenBytes.Length - i]);
302                 }
303             }
304         }
305 
306         [Fact]
WriteBatchWithCorrectPaddingAsync()307         public async Task WriteBatchWithCorrectPaddingAsync()
308         {
309             byte value1 = 0x04;
310             byte value2 = 0x14;
311             var batch = new RecordBatch(
312                 new Schema.Builder()
313                     .Field(f => f.Name("age").DataType(Int32Type.Default))
314                     .Field(f => f.Name("characterCount").DataType(Int32Type.Default))
315                     .Build(),
316                 new IArrowArray[]
317                 {
318                     new Int32Array(
319                         new ArrowBuffer(new byte[] { value1, value1, 0x00, 0x00 }),
320                         ArrowBuffer.Empty,
321                         length: 1,
322                         nullCount: 0,
323                         offset: 0),
324                     new Int32Array(
325                         new ArrowBuffer(new byte[] { value2, value2, 0x00, 0x00 }),
326                         ArrowBuffer.Empty,
327                         length: 1,
328                         nullCount: 0,
329                         offset: 0)
330                 },
331                 length: 1);
332 
333             await TestRoundTripRecordBatchAsync(batch);
334 
335             using (MemoryStream stream = new MemoryStream())
336             {
337                 using (var writer = new ArrowStreamWriter(stream, batch.Schema, leaveOpen: true))
338                 {
339                     await writer.WriteRecordBatchAsync(batch);
340                     await writer.WriteEndAsync();
341                 }
342 
343                 byte[] writtenBytes = stream.ToArray();
344 
345                 // ensure that the data buffers at the end are 8-byte aligned
346                 Assert.Equal(value1, writtenBytes[writtenBytes.Length - 24]);
347                 Assert.Equal(value1, writtenBytes[writtenBytes.Length - 23]);
348                 for (int i = 22; i > 16; i--)
349                 {
350                     Assert.Equal(0, writtenBytes[writtenBytes.Length - i]);
351                 }
352 
353                 Assert.Equal(value2, writtenBytes[writtenBytes.Length - 16]);
354                 Assert.Equal(value2, writtenBytes[writtenBytes.Length - 15]);
355                 for (int i = 14; i > 8; i--)
356                 {
357                     Assert.Equal(0, writtenBytes[writtenBytes.Length - i]);
358                 }
359 
360                 // verify the EOS is written correctly
361                 for (int i = 8; i > 4; i--)
362                 {
363                     Assert.Equal(0xFF, writtenBytes[writtenBytes.Length - i]);
364                 }
365                 for (int i = 4; i > 0; i--)
366                 {
367                     Assert.Equal(0x00, writtenBytes[writtenBytes.Length - i]);
368                 }
369             }
370         }
371 
372         [Fact]
LegacyIpcFormatRoundTrips()373         public void LegacyIpcFormatRoundTrips()
374         {
375             RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
376             TestRoundTripRecordBatch(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true });
377         }
378 
379 
380         [Fact]
LegacyIpcFormatRoundTripsAsync()381         public async Task LegacyIpcFormatRoundTripsAsync()
382         {
383             RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
384             await TestRoundTripRecordBatchAsync(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true });
385         }
386 
387         [Theory]
388         [InlineData(true)]
389         [InlineData(false)]
WriteLegacyIpcFormat(bool writeLegacyIpcFormat)390         public void WriteLegacyIpcFormat(bool writeLegacyIpcFormat)
391         {
392             RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
393             var options = new IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat };
394 
395             using (MemoryStream stream = new MemoryStream())
396             {
397                 using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options))
398                 {
399                     writer.WriteRecordBatch(originalBatch);
400                     writer.WriteEnd();
401                 }
402 
403                 stream.Position = 0;
404 
405                 // ensure the continuation is written correctly
406                 byte[] buffer = stream.ToArray();
407                 int messageLength = BinaryPrimitives.ReadInt32LittleEndian(buffer);
408                 int endOfBuffer1 = BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan(buffer.Length - 8));
409                 int endOfBuffer2 = BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan(buffer.Length - 4));
410                 if (writeLegacyIpcFormat)
411                 {
412                     // the legacy IPC format doesn't have a continuation token at the start
413                     Assert.NotEqual(-1, messageLength);
414                     Assert.NotEqual(-1, endOfBuffer1);
415                 }
416                 else
417                 {
418                     // the latest IPC format has a continuation token at the start
419                     Assert.Equal(-1, messageLength);
420                     Assert.Equal(-1, endOfBuffer1);
421                 }
422 
423                 Assert.Equal(0, endOfBuffer2);
424             }
425         }
426 
427         [Theory]
428         [InlineData(true)]
429         [InlineData(false)]
WriteLegacyIpcFormatAsync(bool writeLegacyIpcFormat)430         public async Task WriteLegacyIpcFormatAsync(bool writeLegacyIpcFormat)
431         {
432             RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
433             var options = new IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat };
434 
435             using (MemoryStream stream = new MemoryStream())
436             {
437                 using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options))
438                 {
439                     await writer.WriteRecordBatchAsync(originalBatch);
440                     await writer.WriteEndAsync();
441                 }
442 
443                 stream.Position = 0;
444 
445                 // ensure the continuation is written correctly
446                 byte[] buffer = stream.ToArray();
447                 int messageLength = BinaryPrimitives.ReadInt32LittleEndian(buffer);
448                 int endOfBuffer1 = BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan(buffer.Length - 8));
449                 int endOfBuffer2 = BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan(buffer.Length - 4));
450                 if (writeLegacyIpcFormat)
451                 {
452                     // the legacy IPC format doesn't have a continuation token at the start
453                     Assert.NotEqual(-1, messageLength);
454                     Assert.NotEqual(-1, endOfBuffer1);
455                 }
456                 else
457                 {
458                     // the latest IPC format has a continuation token at the start
459                     Assert.Equal(-1, messageLength);
460                     Assert.Equal(-1, endOfBuffer1);
461                 }
462 
463                 Assert.Equal(0, endOfBuffer2);
464             }
465         }
466 
467         [Fact]
WritesMetadataCorrectly()468         public void WritesMetadataCorrectly()
469         {
470             Schema.Builder schemaBuilder = new Schema.Builder()
471                 .Metadata("index", "1, 2, 3, 4, 5")
472                 .Metadata("reverseIndex", "5, 4, 3, 2, 1")
473                 .Field(f => f
474                     .Name("IntCol")
475                     .DataType(UInt32Type.Default)
476                     .Metadata("custom1", "false")
477                     .Metadata("custom2", "true"))
478                 .Field(f => f
479                     .Name("StringCol")
480                     .DataType(StringType.Default)
481                     .Metadata("custom2", "false")
482                     .Metadata("custom3", "4"))
483                 .Field(f => f
484                     .Name("StructCol")
485                     .DataType(new StructType(new[] {
486                         new Field("Inner1", FloatType.Default, nullable: false),
487                         new Field("Inner2", DoubleType.Default, nullable: true, new Dictionary<string, string>() { { "customInner", "1" }, { "customInner2", "3" } })
488                     }))
489                     .Metadata("custom4", "6.4")
490                     .Metadata("custom1", "true"));
491 
492             var schema = schemaBuilder.Build();
493             RecordBatch originalBatch = TestData.CreateSampleRecordBatch(schema, length: 10);
494 
495             TestRoundTripRecordBatch(originalBatch);
496         }
497     }
498 }
499