// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using System;
using System.Buffers.Binary;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Threading.Tasks;
using Apache.Arrow.Ipc;
using Apache.Arrow.Types;
using Xunit;

namespace Apache.Arrow.Tests
{
    /// <summary>
    /// Tests for <see cref="ArrowStreamWriter"/>: stream ownership on dispose, writing to
    /// network and memory streams, buffer padding, end-of-stream markers, the legacy IPC
    /// format option and schema/field metadata.
    /// </summary>
    public class ArrowStreamWriterTests
    {
        /// <summary>
        /// Without a <c>leaveOpen</c> argument, disposing the writer disposes the stream.
        /// </summary>
        [Fact]
        public void Ctor_LeaveOpenDefault_StreamClosedOnDispose()
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
            var stream = new MemoryStream();
            new ArrowStreamWriter(stream, originalBatch.Schema).Dispose();
            Assert.Throws<ObjectDisposedException>(() => stream.Position);
        }

        /// <summary>
        /// With <c>leaveOpen: false</c>, disposing the writer disposes the stream.
        /// </summary>
        [Fact]
        public void Ctor_LeaveOpenFalse_StreamClosedOnDispose()
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
            var stream = new MemoryStream();
            new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: false).Dispose();
            Assert.Throws<ObjectDisposedException>(() => stream.Position);
        }

        /// <summary>
        /// With <c>leaveOpen: true</c>, the stream is still accessible after the writer is
        /// disposed and its position is still 0.
        /// </summary>
        [Fact]
        public void Ctor_LeaveOpenTrue_StreamValidOnDispose()
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
            var stream = new MemoryStream();
            new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true).Dispose();
            Assert.Equal(0, stream.Position);
        }

        /// <summary>
        /// Writes a batch over a loopback TCP connection and reads it back on the accepting side.
        /// </summary>
        [Fact]
        public void CanWriteToNetworkStream()
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);

            // Port 0 lets the OS pick a free port, so the test cannot collide with another
            // process (or a parallel test run) that already holds a fixed port.
            TcpListener listener = new TcpListener(IPAddress.Loopback, 0);
            listener.Start();
            try
            {
                int port = ((IPEndPoint)listener.LocalEndpoint).Port;

                using (TcpClient sender = new TcpClient())
                {
                    sender.Connect(IPAddress.Loopback, port);
                    NetworkStream stream = sender.GetStream();

                    using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema))
                    {
                        writer.WriteRecordBatch(originalBatch);
                        writer.WriteEnd();

                        stream.Flush();
                    }
                }

                using (TcpClient receiver = listener.AcceptTcpClient())
                {
                    NetworkStream stream = receiver.GetStream();
                    using (var reader = new ArrowStreamReader(stream))
                    {
                        RecordBatch newBatch = reader.ReadNextRecordBatch();
                        ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
                    }
                }
            }
            finally
            {
                // Release the listening socket even when an assertion above fails.
                listener.Stop();
            }
        }

        /// <summary>
        /// Async counterpart of <see cref="CanWriteToNetworkStream"/>: the connection, the
        /// writes and the accept are awaited instead of blocking.
        /// </summary>
        [Fact]
        public async Task CanWriteToNetworkStreamAsync()
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);

            // Port 0 lets the OS pick a free port, so the test cannot collide with another
            // process (or a parallel test run) that already holds a fixed port.
            TcpListener listener = new TcpListener(IPAddress.Loopback, 0);
            listener.Start();
            try
            {
                int port = ((IPEndPoint)listener.LocalEndpoint).Port;

                using (TcpClient sender = new TcpClient())
                {
                    await sender.ConnectAsync(IPAddress.Loopback, port);
                    NetworkStream stream = sender.GetStream();

                    using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema))
                    {
                        await writer.WriteRecordBatchAsync(originalBatch);
                        await writer.WriteEndAsync();

                        await stream.FlushAsync();
                    }
                }

                using (TcpClient receiver = await listener.AcceptTcpClientAsync())
                {
                    NetworkStream stream = receiver.GetStream();
                    using (var reader = new ArrowStreamReader(stream))
                    {
                        RecordBatch newBatch = reader.ReadNextRecordBatch();
                        ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
                    }
                }
            }
            finally
            {
                // Release the listening socket even when an assertion above fails.
                listener.Stop();
            }
        }

        /// <summary>Round-trips a batch of length 0.</summary>
        [Fact]
        public void WriteEmptyBatch()
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0);

            TestRoundTripRecordBatch(originalBatch);
        }

        /// <summary>Round-trips a batch of length 0 using the async write methods.</summary>
        [Fact]
        public async Task WriteEmptyBatchAsync()
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0);

            await TestRoundTripRecordBatchAsync(originalBatch);
        }

        /// <summary>Round-trips a batch whose columns carry null bitmaps.</summary>
        [Fact]
        public void WriteBatchWithNulls()
        {
            RecordBatch originalBatch = CreateBatchWithNulls();

            TestRoundTripRecordBatch(originalBatch);
        }

        /// <summary>
        /// Round-trips a batch whose columns carry null bitmaps using the async write methods.
        /// </summary>
        [Fact]
        public async Task WriteBatchWithNullsAsync()
        {
            RecordBatch originalBatch = CreateBatchWithNulls();

            await TestRoundTripRecordBatchAsync(originalBatch);
        }

        /// <summary>
        /// Builds a 10-row batch with one non-nullable Int32 column, two nullable Int32 columns
        /// and one nullable Boolean column, each nullable column with an explicit null bitmap.
        /// </summary>
        private static RecordBatch CreateBatchWithNulls()
        {
            return new RecordBatch.Builder()
                .Append("Column1", false, col => col.Int32(array => array.AppendRange(Enumerable.Range(0, 10))))
                .Append("Column2", true, new Int32Array(
                    valueBuffer: new ArrowBuffer.Builder<int>().AppendRange(Enumerable.Range(0, 10)).Build(),
                    // 0xfd = 0b1111_1101: only bit 1 is cleared among the first 10 bits,
                    // so the declared null count must be 1 to agree with the bitmap.
                    nullBitmapBuffer: new ArrowBuffer.Builder<byte>().Append(0xfd).Append(0xff).Build(),
                    length: 10,
                    nullCount: 1,
                    offset: 0))
                .Append("Column3", true, new Int32Array(
                    valueBuffer: new ArrowBuffer.Builder<int>().AppendRange(Enumerable.Range(0, 10)).Build(),
                    // All bits cleared: every one of the 10 slots is null.
                    nullBitmapBuffer: new ArrowBuffer.Builder<byte>().Append(0x00).Append(0x00).Build(),
                    length: 10,
                    nullCount: 10,
                    offset: 0))
                .Append("NullableBooleanColumn", true, new BooleanArray(
                    valueBuffer: new ArrowBuffer.Builder<byte>().Append(0xfd).Append(0xff).Build(),
                    // 0xed = 0b1110_1101: bits 1 and 4 are cleared among the first 10 bits,
                    // so the declared null count must be 2 to agree with the bitmap.
                    nullBitmapBuffer: new ArrowBuffer.Builder<byte>().Append(0xed).Append(0xff).Build(),
                    length: 10,
                    nullCount: 2,
                    offset: 0))
                .Build();
        }

        /// <summary>
        /// Writes <paramref name="originalBatch"/> to a memory stream with the synchronous
        /// writer API, reads it back and compares the result with the original.
        /// </summary>
        /// <param name="originalBatch">The batch to write and compare against.</param>
        /// <param name="options">Options passed to the writer; may be <c>null</c>.</param>
        private static void TestRoundTripRecordBatch(RecordBatch originalBatch, IpcOptions options = null)
        {
            using (MemoryStream stream = new MemoryStream())
            {
                using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options))
                {
                    writer.WriteRecordBatch(originalBatch);
                    writer.WriteEnd();
                }

                // Rewind so the reader starts at the first written byte.
                stream.Position = 0;

                using (var reader = new ArrowStreamReader(stream))
                {
                    RecordBatch newBatch = reader.ReadNextRecordBatch();
                    ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
                }
            }
        }

        /// <summary>
        /// Writes <paramref name="originalBatch"/> to a memory stream with the asynchronous
        /// writer API, reads it back and compares the result with the original.
        /// </summary>
        /// <param name="originalBatch">The batch to write and compare against.</param>
        /// <param name="options">Options passed to the writer; may be <c>null</c>.</param>
        private static async Task TestRoundTripRecordBatchAsync(RecordBatch originalBatch, IpcOptions options = null)
        {
            using (MemoryStream stream = new MemoryStream())
            {
                using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true, options))
                {
                    await writer.WriteRecordBatchAsync(originalBatch);
                    await writer.WriteEndAsync();
                }

                // Rewind so the reader starts at the first written byte.
                stream.Position = 0;

                using (var reader = new ArrowStreamReader(stream))
                {
                    RecordBatch newBatch = reader.ReadNextRecordBatch();
                    ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
                }
            }
        }

        /// <summary>
        /// Writes <paramref name="batch"/> followed by the end marker with the synchronous
        /// writer API and returns every byte that was written.
        /// </summary>
        private static byte[] WriteBatchToBytes(RecordBatch batch, IpcOptions options = null)
        {
            using (MemoryStream stream = new MemoryStream())
            {
                using (var writer = new ArrowStreamWriter(stream, batch.Schema, leaveOpen: true, options))
                {
                    writer.WriteRecordBatch(batch);
                    writer.WriteEnd();
                }

                return stream.ToArray();
            }
        }

        /// <summary>
        /// Writes <paramref name="batch"/> followed by the end marker with the asynchronous
        /// writer API and returns every byte that was written.
        /// </summary>
        private static async Task<byte[]> WriteBatchToBytesAsync(RecordBatch batch, IpcOptions options = null)
        {
            using (MemoryStream stream = new MemoryStream())
            {
                using (var writer = new ArrowStreamWriter(stream, batch.Schema, leaveOpen: true, options))
                {
                    await writer.WriteRecordBatchAsync(batch);
                    await writer.WriteEndAsync();
                }

                return stream.ToArray();
            }
        }

        /// <summary>
        /// Builds a one-row batch with two Int32 columns whose 4-byte value buffers start with
        /// <paramref name="value1"/> and <paramref name="value2"/> repeated twice.
        /// </summary>
        private static RecordBatch CreatePaddingTestBatch(byte value1, byte value2)
        {
            Schema schema = new Schema.Builder()
                .Field(f => f.Name("age").DataType(Int32Type.Default))
                .Field(f => f.Name("characterCount").DataType(Int32Type.Default))
                .Build();

            var arrays = new IArrowArray[]
            {
                new Int32Array(
                    new ArrowBuffer(new byte[] { value1, value1, 0x00, 0x00 }),
                    ArrowBuffer.Empty,
                    length: 1,
                    nullCount: 0,
                    offset: 0),
                new Int32Array(
                    new ArrowBuffer(new byte[] { value2, value2, 0x00, 0x00 }),
                    ArrowBuffer.Empty,
                    length: 1,
                    nullCount: 0,
                    offset: 0)
            };

            return new RecordBatch(schema, arrays, length: 1);
        }

        /// <summary>
        /// Checks the last 24 bytes of <paramref name="writtenBytes"/>: two 8-byte slots holding
        /// the 4-byte value buffers plus zero padding, followed by the 8-byte end-of-stream marker.
        /// </summary>
        private static void AssertBuffersPaddedAndEndOfStreamWritten(byte[] writtenBytes, byte value1, byte value2)
        {
            int end = writtenBytes.Length;

            // ensure that the data buffers at the end are 8-byte aligned
            Assert.Equal(value1, writtenBytes[end - 24]);
            Assert.Equal(value1, writtenBytes[end - 23]);
            for (int i = 22; i > 16; i--)
            {
                Assert.Equal(0, writtenBytes[end - i]);
            }

            Assert.Equal(value2, writtenBytes[end - 16]);
            Assert.Equal(value2, writtenBytes[end - 15]);
            for (int i = 14; i > 8; i--)
            {
                Assert.Equal(0, writtenBytes[end - i]);
            }

            // verify the EOS is written correctly: four 0xFF bytes followed by four 0x00 bytes
            for (int i = 8; i > 4; i--)
            {
                Assert.Equal(0xFF, writtenBytes[end - i]);
            }
            for (int i = 4; i > 0; i--)
            {
                Assert.Equal(0x00, writtenBytes[end - i]);
            }
        }

        /// <summary>
        /// Round-trips a small two-column batch, then checks the padding of the trailing data
        /// buffers and the end-of-stream marker in the raw written bytes.
        /// </summary>
        [Fact]
        public void WriteBatchWithCorrectPadding()
        {
            byte value1 = 0x04;
            byte value2 = 0x14;
            RecordBatch batch = CreatePaddingTestBatch(value1, value2);

            TestRoundTripRecordBatch(batch);

            byte[] writtenBytes = WriteBatchToBytes(batch);

            AssertBuffersPaddedAndEndOfStreamWritten(writtenBytes, value1, value2);
        }

        /// <summary>
        /// Async counterpart of <see cref="WriteBatchWithCorrectPadding"/>.
        /// </summary>
        [Fact]
        public async Task WriteBatchWithCorrectPaddingAsync()
        {
            byte value1 = 0x04;
            byte value2 = 0x14;
            RecordBatch batch = CreatePaddingTestBatch(value1, value2);

            await TestRoundTripRecordBatchAsync(batch);

            byte[] writtenBytes = await WriteBatchToBytesAsync(batch);

            AssertBuffersPaddedAndEndOfStreamWritten(writtenBytes, value1, value2);
        }

        /// <summary>Round-trips a batch written with <c>WriteLegacyIpcFormat = true</c>.</summary>
        [Fact]
        public void LegacyIpcFormatRoundTrips()
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
            TestRoundTripRecordBatch(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true });
        }

        /// <summary>
        /// Round-trips a batch written asynchronously with <c>WriteLegacyIpcFormat = true</c>.
        /// </summary>
        [Fact]
        public async Task LegacyIpcFormatRoundTripsAsync()
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
            await TestRoundTripRecordBatchAsync(originalBatch, new IpcOptions() { WriteLegacyIpcFormat = true });
        }

        /// <summary>
        /// Checks the first and last Int32 values of the written stream for both settings of
        /// <c>WriteLegacyIpcFormat</c>.
        /// </summary>
        [Theory]
        [InlineData(true)]
        [InlineData(false)]
        public void WriteLegacyIpcFormat(bool writeLegacyIpcFormat)
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
            var options = new IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat };

            byte[] buffer = WriteBatchToBytes(originalBatch, options);

            AssertContinuationTokens(buffer, writeLegacyIpcFormat);
        }

        /// <summary>
        /// Async counterpart of <see cref="WriteLegacyIpcFormat"/>.
        /// </summary>
        [Theory]
        [InlineData(true)]
        [InlineData(false)]
        public async Task WriteLegacyIpcFormatAsync(bool writeLegacyIpcFormat)
        {
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
            var options = new IpcOptions() { WriteLegacyIpcFormat = writeLegacyIpcFormat };

            byte[] buffer = await WriteBatchToBytesAsync(originalBatch, options);

            AssertContinuationTokens(buffer, writeLegacyIpcFormat);
        }

        /// <summary>
        /// Ensures the continuation is written correctly: reads the first Int32 of
        /// <paramref name="buffer"/> and its last two Int32 values (all little-endian) and
        /// checks them against what the selected IPC format is expected to produce.
        /// </summary>
        private static void AssertContinuationTokens(byte[] buffer, bool writeLegacyIpcFormat)
        {
            int messageLength = BinaryPrimitives.ReadInt32LittleEndian(buffer);
            int endOfBuffer1 = BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan(buffer.Length - 8));
            int endOfBuffer2 = BinaryPrimitives.ReadInt32LittleEndian(buffer.AsSpan(buffer.Length - 4));

            if (writeLegacyIpcFormat)
            {
                // the legacy IPC format doesn't have a continuation token at the start
                Assert.NotEqual(-1, messageLength);
                Assert.NotEqual(-1, endOfBuffer1);
            }
            else
            {
                // the latest IPC format has a continuation token at the start
                Assert.Equal(-1, messageLength);
                Assert.Equal(-1, endOfBuffer1);
            }

            // Both formats end with a zero Int32.
            Assert.Equal(0, endOfBuffer2);
        }

        /// <summary>
        /// Round-trips a batch whose schema has schema-level metadata and per-field metadata,
        /// including metadata on a field nested inside a struct type.
        /// </summary>
        [Fact]
        public void WritesMetadataCorrectly()
        {
            Schema.Builder schemaBuilder = new Schema.Builder()
                .Metadata("index", "1, 2, 3, 4, 5")
                .Metadata("reverseIndex", "5, 4, 3, 2, 1")
                .Field(f => f
                    .Name("IntCol")
                    .DataType(UInt32Type.Default)
                    .Metadata("custom1", "false")
                    .Metadata("custom2", "true"))
                .Field(f => f
                    .Name("StringCol")
                    .DataType(StringType.Default)
                    .Metadata("custom2", "false")
                    .Metadata("custom3", "4"))
                .Field(f => f
                    .Name("StructCol")
                    .DataType(new StructType(new[] {
                        new Field("Inner1", FloatType.Default, nullable: false),
                        new Field("Inner2", DoubleType.Default, nullable: true, new Dictionary<string, string>() { { "customInner", "1" }, { "customInner2", "3" } })
                    }))
                    .Metadata("custom4", "6.4")
                    .Metadata("custom1", "true"));

            var schema = schemaBuilder.Build();
            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(schema, length: 10);

            TestRoundTripRecordBatch(originalBatch);
        }
    }
}