// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using Apache.Arrow.Memory;
using System;
using System.Buffers;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

namespace Apache.Arrow.Ipc
{
    /// <summary>
    /// Reader for the Arrow IPC file format. The file is validated by checking the magic
    /// bytes at both ends of the stream. The footer is located through the footer-length
    /// field that precedes the trailing magic. Record batches are read by seeking to the
    /// block offsets listed in that footer.
    /// </summary>
    /// <remarks>
    /// Every operation repositions <c>BaseStream</c>, so the stream must support seeking
    /// and reporting its length.
    /// </remarks>
    internal sealed class ArrowFileReaderImplementation : ArrowStreamReaderImplementation
    {
        /// <summary>
        /// Size in bytes of the 32-bit footer-length field stored immediately before the
        /// trailing magic.
        /// </summary>
        private const int FooterLengthSize = sizeof(int);

        /// <summary>
        /// True once the leading and trailing magic bytes have been verified.
        /// After that, validation is not repeated.
        /// </summary>
        public bool IsFileValid { get; private set; }

        /// <summary>
        /// Index of the record batch that the next call to
        /// <see cref="ReadNextRecordBatch"/> or <see cref="ReadNextRecordBatchAsync"/> returns.
        /// </summary>
        private int _recordBatchIndex;

        /// <summary>
        /// Footer deserialized while reading the schema; remains null until then.
        /// </summary>
        private ArrowFooter _footer;

        public ArrowFileReaderImplementation(Stream stream, MemoryAllocator allocator, bool leaveOpen)
            : base(stream, allocator, leaveOpen)
        {
        }

        /// <summary>
        /// Returns the number of record batches listed in the file footer. Reads the schema
        /// (and therefore the footer) first if that has not happened yet.
        /// </summary>
        public async ValueTask<int> RecordBatchCountAsync()
        {
            if (!HasReadSchema)
            {
                await ReadSchemaAsync().ConfigureAwait(false);
            }

            return _footer.RecordBatchCount;
        }

        /// <summary>
        /// Validates the file and deserializes the footer, which supplies the schema and the
        /// record batch locations. Does nothing if the schema has already been read.
        /// </summary>
        /// <exception cref="InvalidDataException">
        /// The stream is too short, the magic bytes do not match, or the stored footer
        /// length is out of range.
        /// </exception>
        protected override async ValueTask ReadSchemaAsync()
        {
            if (HasReadSchema)
            {
                return;
            }

            await ValidateFileAsync().ConfigureAwait(false);

            int footerLength = 0;
            await ArrayPool<byte>.Shared.RentReturnAsync(FooterLengthSize, async (buffer) =>
            {
                await ReadFullBufferAtAsync(GetFooterLengthPosition(), buffer).ConfigureAwait(false);
                footerLength = ReadFooterLength(buffer);
            }).ConfigureAwait(false);

            await ArrayPool<byte>.Shared.RentReturnAsync(footerLength, async (buffer) =>
            {
                // The footer ends exactly where the footer-length field begins.
                await ReadFullBufferAtAsync(GetFooterLengthPosition() - footerLength, buffer).ConfigureAwait(false);
                ReadSchema(buffer);
            }).ConfigureAwait(false);
        }

        /// <summary>
        /// Synchronous counterpart of <see cref="ReadSchemaAsync"/>.
        /// </summary>
        /// <exception cref="InvalidDataException">
        /// The stream is too short, the magic bytes do not match, or the stored footer
        /// length is out of range.
        /// </exception>
        protected override void ReadSchema()
        {
            if (HasReadSchema)
            {
                return;
            }

            ValidateFile();

            int footerLength = 0;
            ArrayPool<byte>.Shared.RentReturn(FooterLengthSize, (buffer) =>
            {
                ReadFullBufferAt(GetFooterLengthPosition(), buffer);
                footerLength = ReadFooterLength(buffer);
            });

            ArrayPool<byte>.Shared.RentReturn(footerLength, (buffer) =>
            {
                // The footer ends exactly where the footer-length field begins.
                ReadFullBufferAt(GetFooterLengthPosition() - footerLength, buffer);
                ReadSchema(buffer);
            });
        }

        /// <summary>
        /// Offset of the footer-length field: the last bytes before the trailing magic.
        /// </summary>
        private long GetFooterLengthPosition()
        {
            return BaseStream.Length - ArrowFileConstants.Magic.Length - FooterLengthSize;
        }

        /// <summary>
        /// Decodes the footer length from <paramref name="buffer"/> and checks that it fits in the file.
        /// </summary>
        /// <exception cref="InvalidDataException">
        /// The length is not positive, or it is larger than the space between the leading
        /// magic and the footer-length field.
        /// </exception>
        private int ReadFooterLength(Memory<byte> buffer)
        {
            int footerLength = BitUtility.ReadInt32(buffer);

            if (footerLength <= 0)
            {
                throw new InvalidDataException(
                    $"Footer length has invalid size <{footerLength}>");
            }

            // The footer lies between the leading magic and the footer-length field. Rejecting
            // anything larger keeps a corrupt length from triggering an oversized buffer rental
            // or a seek to a negative position.
            long maxFooterLength = GetFooterLengthPosition() - ArrowFileConstants.Magic.Length;
            if (footerLength > maxFooterLength)
            {
                throw new InvalidDataException(
                    $"Footer length <{footerLength}> exceeds the available file size");
            }

            return footerLength;
        }

        /// <summary>
        /// Deserializes the footer flatbuffer held in <paramref name="buffer"/> and publishes its schema.
        /// </summary>
        private void ReadSchema(Memory<byte> buffer)
        {
            _footer = new ArrowFooter(Flatbuf.Footer.GetRootAsFooter(CreateByteBuffer(buffer)));

            Schema = _footer.Schema;
        }

        /// <summary>
        /// Reads the record batch at <paramref name="index"/> in footer order.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="index"/> is negative or not less than the record batch count.
        /// </exception>
        public async ValueTask<RecordBatch> ReadRecordBatchAsync(int index, CancellationToken cancellationToken)
        {
            await ReadSchemaAsync().ConfigureAwait(false);

            Block block = GetValidatedRecordBatchBlock(index);
            BaseStream.Position = block.Offset;

            return await ReadRecordBatchAsync(cancellationToken).ConfigureAwait(false);
        }

        /// <summary>
        /// Synchronous counterpart of <see cref="ReadRecordBatchAsync(int, CancellationToken)"/>.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="index"/> is negative or not less than the record batch count.
        /// </exception>
        public RecordBatch ReadRecordBatch(int index)
        {
            ReadSchema();

            Block block = GetValidatedRecordBatchBlock(index);
            BaseStream.Position = block.Offset;

            return ReadRecordBatch();
        }

        /// <summary>
        /// Returns the next record batch in footer order, or null once every batch has been read.
        /// </summary>
        public override async ValueTask<RecordBatch> ReadNextRecordBatchAsync(CancellationToken cancellationToken)
        {
            await ReadSchemaAsync().ConfigureAwait(false);

            if (_recordBatchIndex >= _footer.RecordBatchCount)
            {
                return null;
            }

            RecordBatch result = await ReadRecordBatchAsync(_recordBatchIndex, cancellationToken).ConfigureAwait(false);

            // Advance only after a successful read, so a failed read can be retried.
            _recordBatchIndex++;

            return result;
        }

        /// <summary>
        /// Returns the next record batch in footer order, or null once every batch has been read.
        /// </summary>
        public override RecordBatch ReadNextRecordBatch()
        {
            ReadSchema();

            if (_recordBatchIndex >= _footer.RecordBatchCount)
            {
                return null;
            }

            RecordBatch result = ReadRecordBatch(_recordBatchIndex);

            // Advance only after a successful read, so a failed read can be retried.
            _recordBatchIndex++;

            return result;
        }

        /// <summary>
        /// Looks up the footer block for <paramref name="index"/> after range-checking it.
        /// Assumes the footer has already been read.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="index"/> is negative or not less than the record batch count.
        /// </exception>
        private Block GetValidatedRecordBatchBlock(int index)
        {
            if (index < 0 || index >= _footer.RecordBatchCount)
            {
                throw new ArgumentOutOfRangeException(nameof(index));
            }

            return _footer.GetRecordBatchBlock(index);
        }

        /// <summary>
        /// Check if file format is valid. If it's valid don't run the validation again.
        /// </summary>
        private async ValueTask ValidateFileAsync()
        {
            if (IsFileValid)
            {
                return;
            }

            await ValidateMagicAsync().ConfigureAwait(false);

            IsFileValid = true;
        }

        /// <summary>
        /// Check if file format is valid. If it's valid don't run the validation again.
        /// </summary>
        private void ValidateFile()
        {
            if (IsFileValid)
            {
                return;
            }

            ValidateMagic();

            IsFileValid = true;
        }

        /// <summary>
        /// Verifies the magic bytes at the start and end of the stream. The stream position
        /// is restored afterwards, whether or not validation succeeds.
        /// </summary>
        /// <exception cref="InvalidDataException">The stream is too short or either magic is wrong.</exception>
        private async ValueTask ValidateMagicAsync()
        {
            EnsureMinimumFileLength();

            long startingPosition = BaseStream.Position;
            int magicLength = ArrowFileConstants.Magic.Length;

            try
            {
                await ArrayPool<byte>.Shared.RentReturnAsync(magicLength, async (buffer) =>
                {
                    const long leadingMagicOffset = 0;
                    await ReadFullBufferAtAsync(leadingMagicOffset, buffer).ConfigureAwait(false);
                    VerifyMagic(buffer, leadingMagicOffset);

                    long trailingMagicOffset = BaseStream.Length - magicLength;
                    await ReadFullBufferAtAsync(trailingMagicOffset, buffer).ConfigureAwait(false);
                    VerifyMagic(buffer, trailingMagicOffset);
                }).ConfigureAwait(false);
            }
            finally
            {
                BaseStream.Position = startingPosition;
            }
        }

        /// <summary>
        /// Synchronous counterpart of <see cref="ValidateMagicAsync"/>.
        /// </summary>
        /// <exception cref="InvalidDataException">The stream is too short or either magic is wrong.</exception>
        private void ValidateMagic()
        {
            EnsureMinimumFileLength();

            long startingPosition = BaseStream.Position;
            int magicLength = ArrowFileConstants.Magic.Length;

            try
            {
                ArrayPool<byte>.Shared.RentReturn(magicLength, (buffer) =>
                {
                    const long leadingMagicOffset = 0;
                    ReadFullBufferAt(leadingMagicOffset, buffer);
                    VerifyMagic(buffer, leadingMagicOffset);

                    long trailingMagicOffset = BaseStream.Length - magicLength;
                    ReadFullBufferAt(trailingMagicOffset, buffer);
                    VerifyMagic(buffer, trailingMagicOffset);
                });
            }
            finally
            {
                BaseStream.Position = startingPosition;
            }
        }

        /// <summary>
        /// Rejects streams that cannot hold both magics plus the footer-length field.
        /// Without this check, locating the trailing magic would seek to a negative position.
        /// </summary>
        /// <exception cref="InvalidDataException">The stream is shorter than that minimum.</exception>
        private void EnsureMinimumFileLength()
        {
            long minimumLength = 2L * ArrowFileConstants.Magic.Length + FooterLengthSize;
            if (BaseStream.Length < minimumLength)
            {
                throw new InvalidDataException(
                    $"Stream length <{BaseStream.Length}> is too short to be an Arrow file");
            }
        }

        /// <summary>
        /// Seeks to <paramref name="position"/> and fills <paramref name="buffer"/> completely.
        /// <c>EnsureFullRead</c> rejects a short read.
        /// </summary>
        private void ReadFullBufferAt(long position, Memory<byte> buffer)
        {
            BaseStream.Position = position;

            int bytesRead = BaseStream.ReadFullBuffer(buffer);
            EnsureFullRead(buffer, bytesRead);
        }

        /// <summary>
        /// Asynchronous counterpart of <see cref="ReadFullBufferAt"/>.
        /// </summary>
        private async ValueTask ReadFullBufferAtAsync(long position, Memory<byte> buffer)
        {
            BaseStream.Position = position;

            int bytesRead = await BaseStream.ReadFullBufferAsync(buffer).ConfigureAwait(false);
            EnsureFullRead(buffer, bytesRead);
        }

        /// <summary>
        /// Compares <paramref name="buffer"/> with the Arrow file magic.
        /// </summary>
        /// <param name="buffer">Bytes read from the stream; must be exactly the magic length.</param>
        /// <param name="offset">Stream offset the bytes were read from. Used only in the error message.</param>
        /// <exception cref="InvalidDataException">The bytes do not match the magic.</exception>
        private static void VerifyMagic(Memory<byte> buffer, long offset)
        {
            if (!ArrowFileConstants.Magic.AsSpan().SequenceEqual(buffer.Span))
            {
                throw new InvalidDataException(
                    $"Invalid magic at offset <{offset}>");
            }
        }
    }
}